From b90865b40441657048cfe585cb9f14fe7245e2ac Mon Sep 17 00:00:00 2001 From: sarahcaseybot Date: Fri, 11 Jun 2021 14:25:20 -0400 Subject: [PATCH] Add Cloud SQL read to Spec11Pipeline (#1173) * Add Cloud SQL read to Spec11Pipeline * Add database option * Add database parameter * Add a test of the full pipeline * Use DatabaseHelper in tests * restore the original tm * More test fixes --- .../registry/beam/spec11/Spec11Pipeline.java | 60 +++++-- .../beam/spec11/Spec11PipelineOptions.java | 5 + .../registry/reporting/ReportingModule.java | 12 ++ .../spec11/GenerateSpec11ReportAction.java | 10 ++ .../beam/spec11_pipeline_metadata.json | 8 + .../spec11/SafeBrowsingTransformsTest.java | 2 +- .../beam/spec11/Spec11PipelineTest.java | 164 ++++++++++++++++-- .../GenerateSpec11ReportActionTest.java | 2 + 8 files changed, 234 insertions(+), 29 deletions(-) diff --git a/core/src/main/java/google/registry/beam/spec11/Spec11Pipeline.java b/core/src/main/java/google/registry/beam/spec11/Spec11Pipeline.java index c9cef2d5a..97a114b7d 100644 --- a/core/src/main/java/google/registry/beam/spec11/Spec11Pipeline.java +++ b/core/src/main/java/google/registry/beam/spec11/Spec11Pipeline.java @@ -23,8 +23,10 @@ import dagger.Component; import dagger.Module; import dagger.Provides; import google.registry.beam.common.RegistryJpaIO; +import google.registry.beam.common.RegistryJpaIO.Read; import google.registry.beam.spec11.SafeBrowsingTransforms.EvaluateSafeBrowsingFn; import google.registry.config.RegistryConfig.ConfigModule; +import google.registry.model.domain.DomainBase; import google.registry.model.reporting.Spec11ThreatMatch; import google.registry.model.reporting.Spec11ThreatMatch.ThreatType; import google.registry.util.Retrier; @@ -97,20 +99,9 @@ public class Spec11Pipeline implements Serializable { void setupPipeline(Pipeline pipeline) { PCollection domains = - pipeline.apply( - "Read active domains from BigQuery", - BigQueryIO.read(Subdomain::parseFromRecord) - .fromQuery( - SqlTemplate.create(getQueryFromFile(Spec11Pipeline.class, "subdomains.sql")) - .put("PROJECT_ID", options.getProject()) - .put("DATASTORE_EXPORT_DATASET", "latest_datastore_export") - .put("REGISTRAR_TABLE", "Registrar") - .put("DOMAIN_BASE_TABLE", "DomainBase") - .build()) - .withCoder(SerializableCoder.of(Subdomain.class)) - .usingStandardSql() - .withoutValidation() - .withTemplateCompatibility()); + options.getDatabase().equals("DATASTORE") + ? readFromBigQuery(options, pipeline) + : readFromCloudSql(pipeline); PCollection> threatMatches = domains.apply("Run through SafeBrowsing API", ParDo.of(safeBrowsingFn)); @@ -119,6 +110,47 @@ public class Spec11Pipeline implements Serializable { saveToGcs(threatMatches, options); } + static PCollection readFromCloudSql(Pipeline pipeline) { + Read read = + RegistryJpaIO.read( + "select d, r.emailAddress from Domain d join Registrar r on" + + " d.currentSponsorClientId = r.clientIdentifier where r.type = 'REAL'" + + " and d.deletionTime > now()", + Spec11Pipeline::parseRow); + + return pipeline.apply("Read active domains from Cloud SQL", read); + } + + static PCollection readFromBigQuery(Spec11PipelineOptions options, Pipeline pipeline) { + return pipeline.apply( + "Read active domains from BigQuery", + BigQueryIO.read(Subdomain::parseFromRecord) + .fromQuery( + SqlTemplate.create(getQueryFromFile(Spec11Pipeline.class, "subdomains.sql")) + .put("PROJECT_ID", options.getProject()) + .put("DATASTORE_EXPORT_DATASET", "latest_datastore_export") + .put("REGISTRAR_TABLE", "Registrar") + .put("DOMAIN_BASE_TABLE", "DomainBase") + .build()) + .withCoder(SerializableCoder.of(Subdomain.class)) + .usingStandardSql() + .withoutValidation() + .withTemplateCompatibility()); + } + + private static Subdomain parseRow(Object[] row) { + DomainBase domainBase = (DomainBase) row[0]; + String emailAddress = (String) row[1]; + if (emailAddress == null) { + emailAddress = ""; + } + return Subdomain.create( + domainBase.getDomainName(), + domainBase.getRepoId(), + domainBase.getCurrentSponsorClientId(), + emailAddress); + } + static void saveToSql( PCollection> threatMatches, Spec11PipelineOptions options) { String transformId = "Spec11 Threat Matches"; diff --git a/core/src/main/java/google/registry/beam/spec11/Spec11PipelineOptions.java b/core/src/main/java/google/registry/beam/spec11/Spec11PipelineOptions.java index 7e3ab546b..a04730b7c 100644 --- a/core/src/main/java/google/registry/beam/spec11/Spec11PipelineOptions.java +++ b/core/src/main/java/google/registry/beam/spec11/Spec11PipelineOptions.java @@ -34,4 +34,9 @@ public interface Spec11PipelineOptions extends RegistryPipelineOptions { String getReportingBucketUrl(); void setReportingBucketUrl(String value); + + @Description("The database to read data from.") + String getDatabase(); + + void setDatabase(String value); } diff --git a/core/src/main/java/google/registry/reporting/ReportingModule.java b/core/src/main/java/google/registry/reporting/ReportingModule.java index 8e9112252..6bece86d4 100644 --- a/core/src/main/java/google/registry/reporting/ReportingModule.java +++ b/core/src/main/java/google/registry/reporting/ReportingModule.java @@ -14,6 +14,7 @@ package google.registry.reporting; +import static google.registry.persistence.transaction.TransactionManagerFactory.tm; import static google.registry.request.RequestParameters.extractOptionalParameter; import static google.registry.request.RequestParameters.extractRequiredParameter; @@ -55,6 +56,9 @@ public class ReportingModule { /** The request parameter specifying the jobId for a running Dataflow pipeline. */ public static final String PARAM_JOB_ID = "jobId"; + /** The request parameter for specifying which database reporting actions should read from. */ + public static final String DATABASE = "database"; + /** Provides the Cloud Dataflow jobId for a pipeline. */ @Provides @Parameter(PARAM_JOB_ID) @@ -62,6 +66,14 @@ public class ReportingModule { return extractRequiredParameter(req, PARAM_JOB_ID); } + /** Provides the database for the pipeline to read from. */ + @Provides + @Parameter(DATABASE) + static String provideDatabase(HttpServletRequest req) { + Optional optionalDatabase = extractOptionalParameter(req, DATABASE); + return optionalDatabase.orElse(tm().isOfy() ? "DATASTORE" : "CLOUD_SQL"); + } + /** Extracts an optional YearMonth in yyyy-MM format from the request. */ @Provides @Parameter(PARAM_YEAR_MONTH) diff --git a/core/src/main/java/google/registry/reporting/spec11/GenerateSpec11ReportAction.java b/core/src/main/java/google/registry/reporting/spec11/GenerateSpec11ReportAction.java index a6c601d25..dcdfb51b1 100644 --- a/core/src/main/java/google/registry/reporting/spec11/GenerateSpec11ReportAction.java +++ b/core/src/main/java/google/registry/reporting/spec11/GenerateSpec11ReportAction.java @@ -15,6 +15,8 @@ package google.registry.reporting.spec11; import static google.registry.beam.BeamUtils.createJobName; +import static google.registry.persistence.transaction.TransactionManagerFactory.tm; +import static google.registry.reporting.ReportingModule.DATABASE; import static google.registry.reporting.ReportingModule.PARAM_DATE; import static google.registry.reporting.ReportingUtils.enqueueBeamReportingTask; import static google.registry.request.Action.Method.POST; @@ -69,6 +71,7 @@ public class GenerateSpec11ReportAction implements Runnable { private final Clock clock; private final Response response; private final Dataflow dataflow; + private final String database; @Inject GenerateSpec11ReportAction( @@ -78,15 +81,20 @@ public class GenerateSpec11ReportAction implements Runnable { @Config("reportingBucketUrl") String reportingBucketUrl, @Key("safeBrowsingAPIKey") String apiKey, @Parameter(PARAM_DATE) LocalDate date, + @Parameter(DATABASE) String database, Clock clock, Response response, Dataflow dataflow) { this.projectId = projectId; this.jobRegion = jobRegion; this.stagingBucketUrl = stagingBucketUrl; + if (tm().isOfy() && database.equals("CLOUD_SQL")) { + reportingBucketUrl = reportingBucketUrl.concat("-sql"); + } this.reportingBucketUrl = reportingBucketUrl; this.apiKey = apiKey; this.date = date; + this.database = database; this.clock = clock; this.response = response; this.dataflow = dataflow; @@ -105,6 +113,8 @@ public class GenerateSpec11ReportAction implements Runnable { ImmutableMap.of( "safeBrowsingApiKey", apiKey, + "database", + database, ReportingModule.PARAM_DATE, date.toString(), "reportingBucketUrl", diff --git a/core/src/main/resources/google/registry/beam/spec11_pipeline_metadata.json b/core/src/main/resources/google/registry/beam/spec11_pipeline_metadata.json index eaff2dfe8..869b0268f 100644 --- a/core/src/main/resources/google/registry/beam/spec11_pipeline_metadata.json +++ b/core/src/main/resources/google/registry/beam/spec11_pipeline_metadata.json @@ -61,6 +61,14 @@ "regexes": [ "^gs:\\/\\/[^\\n\\r]+$" ] + }, + { + "name": "database", + "label": "Database to read from.", + "helpText": "DATASTORE or CLOUD_SQL.", + "regexes": [ + "^DATASTORE|CLOUD_SQL$" + ] } ] } diff --git a/core/src/test/java/google/registry/beam/spec11/SafeBrowsingTransformsTest.java b/core/src/test/java/google/registry/beam/spec11/SafeBrowsingTransformsTest.java index 04a5dd81e..224bc86c9 100644 --- a/core/src/test/java/google/registry/beam/spec11/SafeBrowsingTransformsTest.java +++ b/core/src/test/java/google/registry/beam/spec11/SafeBrowsingTransformsTest.java @@ -164,7 +164,7 @@ class SafeBrowsingTransformsTest { * A serializable {@link Answer} that returns a mock HTTP response based on the HTTP request's * content. */ - private static class HttpResponder implements Answer, Serializable { + static class HttpResponder implements Answer, Serializable { @Override public CloseableHttpResponse answer(InvocationOnMock invocation) throws Throwable { return getMockResponse( diff --git a/core/src/test/java/google/registry/beam/spec11/Spec11PipelineTest.java b/core/src/test/java/google/registry/beam/spec11/Spec11PipelineTest.java index 8661a5855..9d2f765c2 100644 --- a/core/src/test/java/google/registry/beam/spec11/Spec11PipelineTest.java +++ b/core/src/test/java/google/registry/beam/spec11/Spec11PipelineTest.java @@ -18,33 +18,60 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.truth.Truth.assertThat; import static google.registry.model.ImmutableObjectSubject.immutableObjectCorrespondence; import static google.registry.persistence.transaction.TransactionManagerFactory.jpaTm; +import static google.registry.persistence.transaction.TransactionManagerFactory.tm; +import static google.registry.testing.AppEngineExtension.makeRegistrar1; +import static google.registry.testing.DatabaseHelper.createTld; +import static google.registry.testing.DatabaseHelper.persistActiveContact; +import static google.registry.testing.DatabaseHelper.persistNewRegistrar; +import static google.registry.testing.DatabaseHelper.persistResource; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.mockito.Mockito.withSettings; +import com.google.common.base.Suppliers; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Streams; import com.google.common.truth.Correspondence; import com.google.common.truth.Correspondence.BinaryPredicate; import google.registry.beam.TestPipelineExtension; +import google.registry.beam.spec11.SafeBrowsingTransforms.EvaluateSafeBrowsingFn; +import google.registry.beam.spec11.SafeBrowsingTransformsTest.HttpResponder; +import google.registry.model.contact.ContactResource; +import google.registry.model.domain.DomainAuthInfo; +import google.registry.model.domain.DomainBase; +import google.registry.model.eppcommon.AuthInfo.PasswordAuth; +import google.registry.model.registrar.Registrar; import google.registry.model.reporting.Spec11ThreatMatch; import google.registry.model.reporting.Spec11ThreatMatch.ThreatType; import google.registry.model.reporting.Spec11ThreatMatchDao; import google.registry.persistence.transaction.JpaTestRules; import google.registry.persistence.transaction.JpaTestRules.JpaIntegrationTestExtension; +import google.registry.persistence.transaction.TransactionManager; +import google.registry.persistence.transaction.TransactionManagerFactory; import google.registry.testing.DatastoreEntityExtension; import google.registry.testing.FakeClock; +import google.registry.testing.FakeSleeper; import google.registry.util.ResourceUtils; +import google.registry.util.Retrier; import java.io.File; import java.nio.file.Files; import java.nio.file.Path; import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.coders.SerializableCoder; import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.impl.client.CloseableHttpClient; import org.checkerframework.checker.nullness.qual.Nullable; +import org.joda.time.DateTime; import org.joda.time.LocalDate; import org.json.JSONObject; +import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Order; import org.junit.jupiter.api.Test; @@ -60,9 +87,14 @@ import org.junit.jupiter.api.io.TempDir; */ class Spec11PipelineTest { + private static final DateTime START_TIME = DateTime.parse("2020-01-27T00:00:00.0Z"); + private final FakeClock fakeClock = new FakeClock(START_TIME); + private static final String DATE = "2020-01-27"; private static final String SAFE_BROWSING_API_KEY = "api-key"; private static final String REPORTING_BUCKET_URL = "reporting_bucket"; + private final CloseableHttpClient mockHttpClient = + mock(CloseableHttpClient.class, withSettings().serializable()); private static final ImmutableList SUBDOMAINS = ImmutableList.of( @@ -103,12 +135,18 @@ class Spec11PipelineTest { private File reportingBucketUrl; private PCollection> threatMatches; + ImmutableSet sqlThreatMatches; + TransactionManager tm; + @BeforeEach void beforeEach() throws Exception { + tm = tm(); + TransactionManagerFactory.setTm(jpaTm()); reportingBucketUrl = Files.createDirectory(tmpDir.resolve(REPORTING_BUCKET_URL)).toFile(); options.setDate(DATE); options.setSafeBrowsingApiKey(SAFE_BROWSING_API_KEY); options.setReportingBucketUrl(reportingBucketUrl.getAbsolutePath()); + options.setDatabase("DATASTORE"); threatMatches = pipeline.apply( Create.of( @@ -118,11 +156,8 @@ class Spec11PipelineTest { KvCoder.of( SerializableCoder.of(Subdomain.class), SerializableCoder.of(ThreatMatch.class)))); - } - @Test - void testSuccess_saveToSql() { - ImmutableSet sqlThreatMatches = + sqlThreatMatches = ImmutableSet.of( new Spec11ThreatMatch.Builder() .setDomainName("111.com") @@ -159,25 +194,98 @@ class Spec11PipelineTest { .setCheckDate(new LocalDate(2020, 1, 27)) .setThreatTypes(ImmutableSet.of(ThreatType.UNWANTED_SOFTWARE)) .build()); + } + + @AfterEach + void afterEach() { + TransactionManagerFactory.setTm(tm); + } + + @Test + void testSuccess_fullSqlPipeline() throws Exception { + setupCloudSql(); + options.setDatabase("CLOUD_SQL"); + EvaluateSafeBrowsingFn safeBrowsingFn = + new EvaluateSafeBrowsingFn( + SAFE_BROWSING_API_KEY, + new Retrier(new FakeSleeper(new FakeClock()), 1), + Suppliers.ofInstance(mockHttpClient)); + when(mockHttpClient.execute(any(HttpPost.class))).thenAnswer(new HttpResponder()); + Spec11Pipeline spec11Pipeline = new Spec11Pipeline(options, safeBrowsingFn); + spec11Pipeline.setupPipeline(pipeline); + pipeline.run(options).waitUntilFinish(); + verifySaveToGcs(); + verifySaveToCloudSql(); + } + + @Test + void testSuccess_saveToSql() { Spec11Pipeline.saveToSql(threatMatches, options); pipeline.run().waitUntilFinish(); - assertThat( - jpaTm() - .transact( - () -> - Spec11ThreatMatchDao.loadEntriesByDate( - jpaTm(), new LocalDate(2020, 1, 27)))) - .comparingElementsUsing(immutableObjectCorrespondence("id")) - .containsExactlyElementsIn(sqlThreatMatches); + verifySaveToCloudSql(); } @Test void testSuccess_saveToGcs() throws Exception { + Spec11Pipeline.saveToGcs(threatMatches, options); + pipeline.run().waitUntilFinish(); + verifySaveToGcs(); + } + + @Test + void testSuccess_readFromCloudSql() throws Exception { + setupCloudSql(); + PCollection subdomains = Spec11Pipeline.readFromCloudSql(pipeline); + PAssert.that(subdomains).containsInAnyOrder(SUBDOMAINS); + pipeline.run().waitUntilFinish(); + } + + private void setupCloudSql() { + persistNewRegistrar("TheRegistrar"); + persistNewRegistrar("NewRegistrar"); + Registrar registrar1 = + persistResource( + makeRegistrar1() + .asBuilder() + .setClientId("hello-registrar") + .setEmailAddress("email@hello.net") + .build()); + Registrar registrar2 = + persistResource( + makeRegistrar1() + .asBuilder() + .setClientId("kitty-registrar") + .setEmailAddress("contact@kit.ty") + .build()); + Registrar registrar3 = + persistResource( + makeRegistrar1() + .asBuilder() + .setClientId("cool-registrar") + .setEmailAddress("cool@aid.net") + .build()); + + createTld("com"); + createTld("net"); + createTld("bank"); + createTld("dev"); + + ContactResource contact1 = persistActiveContact(registrar1.getClientId()); + ContactResource contact2 = persistActiveContact(registrar2.getClientId()); + ContactResource contact3 = persistActiveContact(registrar3.getClientId()); + + persistResource(createDomain("111.com", "123456789-COM", registrar1, contact1)); + persistResource(createDomain("party-night.net", "2244AABBC-NET", registrar2, contact2)); + persistResource(createDomain("bitcoin.bank", "1C3D5E7F9-BANK", registrar1, contact1)); + persistResource(createDomain("no-email.com", "2A4BA9BBC-COM", registrar2, contact2)); + persistResource( + createDomain("anti-anti-anti-virus.dev", "555666888-DEV", registrar3, contact3)); + } + + private void verifySaveToGcs() throws Exception { ImmutableList expectedFileContents = ImmutableList.copyOf( ResourceUtils.readResourceUtf8(this.getClass(), "test_output.txt").split("\n")); - Spec11Pipeline.saveToGcs(threatMatches, options); - pipeline.run().waitUntilFinish(); ImmutableList resultFileContents = resultFileContents(); assertThat(resultFileContents.size()).isEqualTo(expectedFileContents.size()); assertThat(resultFileContents.get(0)).isEqualTo(expectedFileContents.get(0)); @@ -188,6 +296,34 @@ class Spec11PipelineTest { .containsExactlyElementsIn(expectedFileContents.subList(1, expectedFileContents.size())); } + private void verifySaveToCloudSql() { + jpaTm() + .transact( + () -> { + ImmutableList sqlThreatMatches = + Spec11ThreatMatchDao.loadEntriesByDate(jpaTm(), new LocalDate(2020, 1, 27)); + assertThat(sqlThreatMatches) + .comparingElementsUsing(immutableObjectCorrespondence("id")) + .containsExactlyElementsIn(sqlThreatMatches); + }); + } + + private DomainBase createDomain( + String domainName, String repoId, Registrar registrar, ContactResource contact) { + return new DomainBase.Builder() + .setDomainName(domainName) + .setRepoId(repoId) + .setCreationClientId(registrar.getClientId()) + .setLastEppUpdateTime(fakeClock.nowUtc()) + .setLastEppUpdateClientId(registrar.getClientId()) + .setLastTransferTime(fakeClock.nowUtc()) + .setRegistrant(contact.createVKey()) + .setPersistedCurrentSponsorClientId(registrar.getClientId()) + .setRegistrationExpirationTime(fakeClock.nowUtc().plusYears(1)) + .setAuthInfo(DomainAuthInfo.create(PasswordAuth.create("password"))) + .build(); + } + /** Returns the text contents of a file under the beamBucket/results directory. */ private ImmutableList resultFileContents() throws Exception { File resultFile = diff --git a/core/src/test/java/google/registry/reporting/spec11/GenerateSpec11ReportActionTest.java b/core/src/test/java/google/registry/reporting/spec11/GenerateSpec11ReportActionTest.java index c4b6ec6d0..76d7fcff3 100644 --- a/core/src/test/java/google/registry/reporting/spec11/GenerateSpec11ReportActionTest.java +++ b/core/src/test/java/google/registry/reporting/spec11/GenerateSpec11ReportActionTest.java @@ -50,6 +50,7 @@ class GenerateSpec11ReportActionTest extends BeamActionTestBase { "gs://reporting-project/reporting-bucket/", "api_key/a", clock.nowUtc().toLocalDate(), + "DATASTORE", clock, response, dataflow); @@ -71,6 +72,7 @@ class GenerateSpec11ReportActionTest extends BeamActionTestBase { "gs://reporting-project/reporting-bucket/", "api_key/a", clock.nowUtc().toLocalDate(), + "DATASTORE", clock, response, dataflow);