diff --git a/core/src/main/java/google/registry/beam/spec11/Spec11Pipeline.java b/core/src/main/java/google/registry/beam/spec11/Spec11Pipeline.java index 685c64878..5f6208d6b 100644 --- a/core/src/main/java/google/registry/beam/spec11/Spec11Pipeline.java +++ b/core/src/main/java/google/registry/beam/spec11/Spec11Pipeline.java @@ -19,9 +19,15 @@ import static google.registry.beam.BeamUtils.getQueryFromFile; import com.google.auth.oauth2.GoogleCredentials; import com.google.auto.value.AutoValue; +import com.google.common.collect.ImmutableSet; +import google.registry.backup.AppEngineEnvironment; +import google.registry.beam.initsql.Transforms.SerializableSupplier; import google.registry.beam.spec11.SafeBrowsingTransforms.EvaluateSafeBrowsingFn; import google.registry.config.CredentialModule.LocalCredential; import google.registry.config.RegistryConfig.Config; +import google.registry.model.reporting.Spec11ThreatMatch; +import google.registry.model.reporting.Spec11ThreatMatch.ThreatType; +import google.registry.persistence.transaction.JpaTransactionManager; import google.registry.util.GoogleCredentialsBundle; import google.registry.util.Retrier; import google.registry.util.SqlTemplate; @@ -37,6 +43,7 @@ import org.apache.beam.sdk.options.Description; import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.options.ValueProvider; import org.apache.beam.sdk.options.ValueProvider.NestedValueProvider; +import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.GroupByKey; import org.apache.beam.sdk.transforms.MapElements; import org.apache.beam.sdk.transforms.ParDo; @@ -46,6 +53,7 @@ import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.sdk.values.TypeDescriptors; import org.joda.time.LocalDate; import org.joda.time.YearMonth; +import org.joda.time.format.ISODateTimeFormat; import org.json.JSONArray; import org.json.JSONException; import org.json.JSONObject; @@ -86,6 +94,7 @@ public class Spec11Pipeline implements Serializable { private final String reportingBucketUrl; private final GoogleCredentials googleCredentials; private final Retrier retrier; + private final SerializableSupplier jpaSupplierFactory; @Inject public Spec11Pipeline( @@ -93,12 +102,14 @@ public class Spec11Pipeline implements Serializable { @Config("beamStagingUrl") String beamStagingUrl, @Config("spec11TemplateUrl") String spec11TemplateUrl, @Config("reportingBucketUrl") String reportingBucketUrl, + SerializableSupplier jpaSupplierFactory, @LocalCredential GoogleCredentialsBundle googleCredentialsBundle, Retrier retrier) { this.projectId = projectId; this.beamStagingUrl = beamStagingUrl; this.spec11TemplateUrl = spec11TemplateUrl; this.reportingBucketUrl = reportingBucketUrl; + this.jpaSupplierFactory = jpaSupplierFactory; this.googleCredentials = googleCredentialsBundle.getGoogleCredentials(); this.retrier = retrier; } @@ -177,12 +188,40 @@ public class Spec11Pipeline implements Serializable { EvaluateSafeBrowsingFn evaluateSafeBrowsingFn, ValueProvider dateProvider) { + PCollection> subdomainsSql = + domains.apply("Run through SafeBrowsing API", ParDo.of(evaluateSafeBrowsingFn)); + /* Store ThreatMatch objects in SQL. */ + subdomainsSql.apply( + ParDo.of( + new DoFn, Void>() { + @ProcessElement + public void processElement(ProcessContext context) { + // create the Spec11ThreatMatch from Subdomain and ThreatMatch + try (AppEngineEnvironment env = new AppEngineEnvironment()) { + Subdomain subdomain = context.element().getKey(); + Spec11ThreatMatch threatMatch = + new Spec11ThreatMatch.Builder() + .setThreatTypes( + ImmutableSet.of( + ThreatType.valueOf(context.element().getValue().threatType()))) + .setCheckDate( + LocalDate.parse(dateProvider.get(), ISODateTimeFormat.date())) + .setDomainName(subdomain.domainName()) + .setDomainRepoId(subdomain.domainRepoId()) + .setRegistrarId(subdomain.registrarId()) + .build(); + JpaTransactionManager jpaTransactionManager = jpaSupplierFactory.get(); + jpaTransactionManager.transact(() -> jpaTransactionManager.saveNew(threatMatch)); + } + } + })); + /* Store ThreatMatch objects in JSON. */ PCollection> subdomainsJson = domains.apply("Run through SafeBrowsingAPI", ParDo.of(evaluateSafeBrowsingFn)); subdomainsJson .apply( - "Map registrar client ID to email/ThreatMatch pair", + "Map registrar ID to email/ThreatMatch pair", MapElements.into( TypeDescriptors.kvs( TypeDescriptors.strings(), TypeDescriptor.of(EmailAndThreatMatch.class))) diff --git a/core/src/main/java/google/registry/model/reporting/Spec11ThreatMatch.java b/core/src/main/java/google/registry/model/reporting/Spec11ThreatMatch.java index fa2e20902..592823325 100644 --- a/core/src/main/java/google/registry/model/reporting/Spec11ThreatMatch.java +++ b/core/src/main/java/google/registry/model/reporting/Spec11ThreatMatch.java @@ -140,6 +140,16 @@ public class Spec11ThreatMatch extends ImmutableObject implements Buildable, Sql return super.build(); } + /** + * Manually set the ID for testing or other special circumstances. + * + *

In general the ID is generated by SQL and there should be no need to set it manually. + */ + public Builder setId(Long id) { + getInstance().id = id; + return this; + } + public Builder setDomainName(String domainName) { getInstance().domainName = domainName; getInstance().tld = DomainNameUtils.getTldFromDomainName(domainName); diff --git a/core/src/main/java/google/registry/tools/DeploySpec11PipelineCommand.java b/core/src/main/java/google/registry/tools/DeploySpec11PipelineCommand.java index 62f5ff3e8..146660f63 100644 --- a/core/src/main/java/google/registry/tools/DeploySpec11PipelineCommand.java +++ b/core/src/main/java/google/registry/tools/DeploySpec11PipelineCommand.java @@ -14,7 +14,10 @@ package google.registry.tools; +import com.beust.jcommander.Parameter; import com.beust.jcommander.Parameters; +import google.registry.beam.initsql.BeamJpaModule.JpaTransactionManagerComponent; +import google.registry.beam.initsql.JpaSupplierFactory; import google.registry.beam.spec11.Spec11Pipeline; import google.registry.config.CredentialModule.LocalCredential; import google.registry.config.RegistryConfig.Config; @@ -31,6 +34,12 @@ public class DeploySpec11PipelineCommand implements Command { @Config("projectId") String projectId; + @Parameter( + names = {"-p", "--project"}, + description = "Cloud KMS project ID", + required = true) + String cloudKmsProjectId; + @Inject @Config("beamStagingUrl") String beamStagingUrl; @@ -53,12 +62,19 @@ public class DeploySpec11PipelineCommand implements Command { @Override public void run() { + JpaSupplierFactory jpaSupplierFactory = + new JpaSupplierFactory( + sqlAccessInfoFile, + cloudKmsProjectId, + JpaTransactionManagerComponent::cloudSqlJpaTransactionManager); + Spec11Pipeline pipeline = new Spec11Pipeline( projectId, beamStagingUrl, spec11TemplateUrl, reportingBucketUrl, + jpaSupplierFactory, googleCredentialsBundle, retrier); pipeline.deploy(); diff --git a/core/src/test/java/google/registry/beam/spec11/Spec11PipelineTest.java b/core/src/test/java/google/registry/beam/spec11/Spec11PipelineTest.java index 04ba955a4..17b8544c9 100644 --- a/core/src/test/java/google/registry/beam/spec11/Spec11PipelineTest.java +++ b/core/src/test/java/google/registry/beam/spec11/Spec11PipelineTest.java @@ -17,15 +17,22 @@ package google.registry.beam.spec11; import static com.google.common.truth.Truth.assertThat; import static java.nio.charset.StandardCharsets.UTF_8; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; import static org.mockito.Mockito.withSettings; import com.google.auth.oauth2.GoogleCredentials; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; import com.google.common.io.CharStreams; import google.registry.beam.TestPipelineExtension; import google.registry.beam.spec11.SafeBrowsingTransforms.EvaluateSafeBrowsingFn; +import google.registry.model.reporting.Spec11ThreatMatch; +import google.registry.model.reporting.Spec11ThreatMatch.ThreatType; +import google.registry.persistence.transaction.JpaTransactionManager; import google.registry.testing.FakeClock; import google.registry.testing.FakeSleeper; import google.registry.util.GoogleCredentialsBundle; @@ -55,22 +62,42 @@ import org.apache.http.entity.BasicHttpEntity; import org.apache.http.impl.client.CloseableHttpClient; import org.apache.http.message.BasicStatusLine; import org.joda.time.DateTime; +import org.joda.time.LocalDate; +import org.joda.time.format.ISODateTimeFormat; import org.json.JSONArray; import org.json.JSONException; import org.json.JSONObject; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.RegisterExtension; import org.junit.jupiter.api.io.TempDir; +import org.mockito.Mock; import org.mockito.invocation.InvocationOnMock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.mockito.junit.jupiter.MockitoSettings; +import org.mockito.quality.Strictness; import org.mockito.stubbing.Answer; /** Unit tests for {@link Spec11Pipeline}. */ +@ExtendWith(MockitoExtension.class) +@MockitoSettings(strictness = Strictness.LENIENT) class Spec11PipelineTest { + private static class SaveNewThreatMatchAnswer implements Answer, Serializable { + @Override + public Void answer(InvocationOnMock invocation) { + Runnable runnable = invocation.getArgument(0, Runnable.class); + runnable.run(); + return null; + } + } private static PipelineOptions pipelineOptions; + @Mock(serializable = true) + private static JpaTransactionManager mockJpaTm; + @BeforeAll static void beforeAll() { pipelineOptions = PipelineOptionsFactory.create(); @@ -93,20 +120,23 @@ class Spec11PipelineTest { void beforeEach() throws IOException { String beamTempFolder = Files.createDirectory(tmpDir.resolve("beam_temp")).toAbsolutePath().toString(); + spec11Pipeline = new Spec11Pipeline( "test-project", beamTempFolder + "/staging", beamTempFolder + "/templates/invoicing", tmpDir.toAbsolutePath().toString(), + () -> mockJpaTm, GoogleCredentialsBundle.create(GoogleCredentials.create(null)), retrier); } private static final ImmutableList BAD_DOMAINS = - ImmutableList.of("111.com", "222.com", "444.com", "no-email.com"); + ImmutableList.of( + "111.com", "222.com", "444.com", "no-email.com", "testThreatMatchToSqlBad.com"); - private ImmutableList getInputDomains() { + private ImmutableList getInputDomainsJson() { ImmutableList.Builder subdomainsBuilder = new ImmutableList.Builder<>(); // Put in at least 2 batches worth (x > 490) to guarantee multiple executions. // Put in half for theRegistrar and half for someRegistrar @@ -134,17 +164,18 @@ class Spec11PipelineTest { @SuppressWarnings("unchecked") void testEndToEndPipeline_generatesExpectedFiles() throws Exception { // Establish mocks for testing - ImmutableList inputRows = getInputDomains(); - CloseableHttpClient httpClient = mock(CloseableHttpClient.class, withSettings().serializable()); + ImmutableList inputRows = getInputDomainsJson(); + CloseableHttpClient mockHttpClient = + mock(CloseableHttpClient.class, withSettings().serializable()); // Return a mock HttpResponse that returns a JSON response based on the request. - when(httpClient.execute(any(HttpPost.class))).thenAnswer(new HttpResponder()); + when(mockHttpClient.execute(any(HttpPost.class))).thenAnswer(new HttpResponder()); EvaluateSafeBrowsingFn evalFn = new EvaluateSafeBrowsingFn( StaticValueProvider.of("apikey"), new Retrier(new FakeSleeper(new FakeClock()), 3), - (Serializable & Supplier) () -> httpClient); + (Serializable & Supplier) () -> mockHttpClient); // Apply input and evaluation transforms PCollection input = testPipeline.apply(Create.of(inputRows)); @@ -207,6 +238,56 @@ class Spec11PipelineTest { .toString()); } + @Test + @SuppressWarnings("unchecked") + public void testSpec11ThreatMatchToSql() throws Exception { + doAnswer(new SaveNewThreatMatchAnswer()).when(mockJpaTm).transact(any(Runnable.class)); + + // Create one bad and one good Subdomain to test with evaluateUrlHealth. Only the bad one should + // be detected and persisted. + Subdomain badDomain = + Subdomain.create( + "testThreatMatchToSqlBad.com", "theDomain", "theRegistrar", "fake@theRegistrar.com"); + Subdomain goodDomain = + Subdomain.create( + "testThreatMatchToSqlGood.com", + "someDomain", + "someRegistrar", + "fake@someRegistrar.com"); + + // Establish a mock HttpResponse that returns a JSON response based on the request. + CloseableHttpClient mockHttpClient = + mock(CloseableHttpClient.class, withSettings().serializable()); + when(mockHttpClient.execute(any(HttpPost.class))).thenAnswer(new HttpResponder()); + + EvaluateSafeBrowsingFn evalFn = + new EvaluateSafeBrowsingFn( + StaticValueProvider.of("apikey"), + new Retrier(new FakeSleeper(new FakeClock()), 3), + (Serializable & Supplier) () -> mockHttpClient); + + // Apply input and evaluation transforms + PCollection input = testPipeline.apply(Create.of(badDomain, goodDomain)); + spec11Pipeline.evaluateUrlHealth(input, evalFn, StaticValueProvider.of("2020-06-10")); + testPipeline.run(); + + // Verify that the expected threat created from the bad Subdomain and the persisted + // Spec11TThreatMatch are equal. + Spec11ThreatMatch expected = + new Spec11ThreatMatch() + .asBuilder() + .setThreatTypes(ImmutableSet.of(ThreatType.MALWARE)) + .setCheckDate(LocalDate.parse("2020-06-10", ISODateTimeFormat.date())) + .setDomainName(badDomain.domainName()) + .setDomainRepoId(badDomain.domainRepoId()) + .setRegistrarId(badDomain.registrarId()) + .build(); + + verify(mockJpaTm).transact(any(Runnable.class)); + verify(mockJpaTm).saveNew(expected); + verifyNoMoreInteractions(mockJpaTm); + } + /** * A serializable {@link Answer} that returns a mock HTTP response based on the HTTP request's * content. diff --git a/core/src/test/java/google/registry/model/reporting/Spec11ThreatMatchTest.java b/core/src/test/java/google/registry/model/reporting/Spec11ThreatMatchTest.java index 62031e256..c14205489 100644 --- a/core/src/test/java/google/registry/model/reporting/Spec11ThreatMatchTest.java +++ b/core/src/test/java/google/registry/model/reporting/Spec11ThreatMatchTest.java @@ -114,6 +114,8 @@ public class Spec11ThreatMatchTest extends EntityTestCase { VKey threatVKey = VKey.createSql(Spec11ThreatMatch.class, threat.getId()); Spec11ThreatMatch persistedThreat = jpaTm().transact(() -> jpaTm().load(threatVKey)); + + // Threat object saved for the first time doesn't have an ID; it is generated by SQL threat.id = persistedThreat.id; assertThat(threat).isEqualTo(persistedThreat); } diff --git a/release/cloudbuild-deploy.yaml b/release/cloudbuild-deploy.yaml index 26cd5eca7..45207490f 100644 --- a/release/cloudbuild-deploy.yaml +++ b/release/cloudbuild-deploy.yaml @@ -37,6 +37,18 @@ steps: cat tool-credential.json.enc | base64 -d | gcloud kms decrypt \ --ciphertext-file=- --plaintext-file=tool-credential.json \ --location=global --keyring=nomulus-tool-keyring --key=nomulus-tool-key +# Set the path to the file for sql_access_info to deploy the Spec 11 pipeline +- name: 'gcr.io/$PROJECT_ID/builder:latest' + entrypoint: /bin/bash + args: + - -c + - | + set -e + if [ ${_ENV} == production ]; then + echo "gs://domain-registry-beam/cloudsql/admin_credential.enc" > sql_access_path.txt + else + echo "gs://domain-registry-${_ENV}-beam/cloudsql/admin_credential.enc" > sql_access_path.txt + fi # Deploy the Spec11 pipeline to GCS. - name: 'gcr.io/$PROJECT_ID/nomulus-tool:latest' args: @@ -44,7 +56,11 @@ steps: - ${_ENV} - --credential - tool-credential.json + - --sql_access_info + - `cat sql_access_path.txt` - deploy_spec11_pipeline + - --project + - $PROJECT_ID # Deploy the invoicing pipeline to GCS. - name: 'gcr.io/$PROJECT_ID/nomulus-tool:latest' args: