Add Cloud SQL read to Spec11Pipeline (#1173)

* Add Cloud SQL read to Spec11Pipeline

* Add database option

* Add database parameter

* Add a test of the full pipeline

* Use DatabaseHelper in tests

* restore the original tm

* More test fixes
This commit is contained in:
sarahcaseybot 2021-06-11 14:25:20 -04:00 committed by GitHub
parent bcc26c486e
commit b90865b404
8 changed files with 234 additions and 29 deletions

View file

@ -23,8 +23,10 @@ import dagger.Component;
import dagger.Module;
import dagger.Provides;
import google.registry.beam.common.RegistryJpaIO;
import google.registry.beam.common.RegistryJpaIO.Read;
import google.registry.beam.spec11.SafeBrowsingTransforms.EvaluateSafeBrowsingFn;
import google.registry.config.RegistryConfig.ConfigModule;
import google.registry.model.domain.DomainBase;
import google.registry.model.reporting.Spec11ThreatMatch;
import google.registry.model.reporting.Spec11ThreatMatch.ThreatType;
import google.registry.util.Retrier;
@ -97,20 +99,9 @@ public class Spec11Pipeline implements Serializable {
void setupPipeline(Pipeline pipeline) {
PCollection<Subdomain> domains =
pipeline.apply(
"Read active domains from BigQuery",
BigQueryIO.read(Subdomain::parseFromRecord)
.fromQuery(
SqlTemplate.create(getQueryFromFile(Spec11Pipeline.class, "subdomains.sql"))
.put("PROJECT_ID", options.getProject())
.put("DATASTORE_EXPORT_DATASET", "latest_datastore_export")
.put("REGISTRAR_TABLE", "Registrar")
.put("DOMAIN_BASE_TABLE", "DomainBase")
.build())
.withCoder(SerializableCoder.of(Subdomain.class))
.usingStandardSql()
.withoutValidation()
.withTemplateCompatibility());
options.getDatabase().equals("DATASTORE")
? readFromBigQuery(options, pipeline)
: readFromCloudSql(pipeline);
PCollection<KV<Subdomain, ThreatMatch>> threatMatches =
domains.apply("Run through SafeBrowsing API", ParDo.of(safeBrowsingFn));
@ -119,6 +110,47 @@ public class Spec11Pipeline implements Serializable {
saveToGcs(threatMatches, options);
}
/**
 * Reads the domains to scan from Cloud SQL.
 *
 * <p>The JPA query joins each {@code Domain} with its sponsoring {@code Registrar}, keeps only
 * registrars of type {@code REAL} and domains whose deletion time is still in the future, and
 * maps every result row to a {@link Subdomain} via {@code parseRow}.
 *
 * @param pipeline the pipeline the read transform is applied to
 * @return the collection of subdomains produced by the read
 */
static PCollection<Subdomain> readFromCloudSql(Pipeline pipeline) {
  String activeDomainsQuery =
      "select d, r.emailAddress from Domain d join Registrar r on"
          + " d.currentSponsorClientId = r.clientIdentifier where r.type = 'REAL'"
          + " and d.deletionTime > now()";
  Read<Object[], Subdomain> activeDomainsRead =
      RegistryJpaIO.read(activeDomainsQuery, Spec11Pipeline::parseRow);
  return pipeline.apply("Read active domains from Cloud SQL", activeDomainsRead);
}
/**
 * Reads the domains to scan from BigQuery.
 *
 * <p>The query text comes from the {@code subdomains.sql} template, filled in with the project
 * id from {@code options} and the fixed dataset and table names below.
 *
 * @param options supplies the project id substituted into the query template
 * @param pipeline the pipeline the read transform is applied to
 * @return the collection of subdomains produced by the read
 */
static PCollection<Subdomain> readFromBigQuery(Spec11PipelineOptions options, Pipeline pipeline) {
  String subdomainsQuery =
      SqlTemplate.create(getQueryFromFile(Spec11Pipeline.class, "subdomains.sql"))
          .put("PROJECT_ID", options.getProject())
          .put("DATASTORE_EXPORT_DATASET", "latest_datastore_export")
          .put("REGISTRAR_TABLE", "Registrar")
          .put("DOMAIN_BASE_TABLE", "DomainBase")
          .build();
  return pipeline.apply(
      "Read active domains from BigQuery",
      BigQueryIO.read(Subdomain::parseFromRecord)
          .fromQuery(subdomainsQuery)
          .withCoder(SerializableCoder.of(Subdomain.class))
          .usingStandardSql()
          .withoutValidation()
          .withTemplateCompatibility());
}
/**
 * Converts one Cloud SQL result row into a {@link Subdomain}.
 *
 * @param row a two-element row: the {@link DomainBase} at index 0 and the registrar email address
 *     (possibly null) at index 1
 * @return the subdomain for the row, with an empty string standing in for a null email address
 */
private static Subdomain parseRow(Object[] row) {
  DomainBase domain = (DomainBase) row[0];
  // A null email is replaced with "" before being handed to Subdomain.create.
  String registrarEmail = row[1] == null ? "" : (String) row[1];
  return Subdomain.create(
      domain.getDomainName(),
      domain.getRepoId(),
      domain.getCurrentSponsorClientId(),
      registrarEmail);
}
static void saveToSql(
PCollection<KV<Subdomain, ThreatMatch>> threatMatches, Spec11PipelineOptions options) {
String transformId = "Spec11 Threat Matches";

View file

@ -34,4 +34,9 @@ public interface Spec11PipelineOptions extends RegistryPipelineOptions {
String getReportingBucketUrl();
void setReportingBucketUrl(String value);
/**
 * Selects the source of the domain data. {@code Spec11Pipeline.setupPipeline} reads from
 * BigQuery when the value equals "DATASTORE" and from Cloud SQL for any other value.
 */
@Description("The database to read data from.")
String getDatabase();
void setDatabase(String value);
}

View file

@ -14,6 +14,7 @@
package google.registry.reporting;
import static google.registry.persistence.transaction.TransactionManagerFactory.tm;
import static google.registry.request.RequestParameters.extractOptionalParameter;
import static google.registry.request.RequestParameters.extractRequiredParameter;
@ -55,6 +56,9 @@ public class ReportingModule {
/** The request parameter specifying the jobId for a running Dataflow pipeline. */
public static final String PARAM_JOB_ID = "jobId";
/** The request parameter for specifying which database reporting actions should read from. */
public static final String DATABASE = "database";
/** Provides the Cloud Dataflow jobId for a pipeline. */
@Provides
@Parameter(PARAM_JOB_ID)
@ -62,6 +66,14 @@ public class ReportingModule {
return extractRequiredParameter(req, PARAM_JOB_ID);
}
/**
 * Provides the database for the pipeline to read from.
 *
 * <p>Uses the optional {@code database} request parameter when present; otherwise falls back to
 * "DATASTORE" if the current transaction manager is Objectify-backed and "CLOUD_SQL" if not.
 */
@Provides
@Parameter(DATABASE)
static String provideDatabase(HttpServletRequest req) {
  Optional<String> requestedDatabase = extractOptionalParameter(req, DATABASE);
  String defaultDatabase = tm().isOfy() ? "DATASTORE" : "CLOUD_SQL";
  return requestedDatabase.orElse(defaultDatabase);
}
/** Extracts an optional YearMonth in yyyy-MM format from the request. */
@Provides
@Parameter(PARAM_YEAR_MONTH)

View file

@ -15,6 +15,8 @@
package google.registry.reporting.spec11;
import static google.registry.beam.BeamUtils.createJobName;
import static google.registry.persistence.transaction.TransactionManagerFactory.tm;
import static google.registry.reporting.ReportingModule.DATABASE;
import static google.registry.reporting.ReportingModule.PARAM_DATE;
import static google.registry.reporting.ReportingUtils.enqueueBeamReportingTask;
import static google.registry.request.Action.Method.POST;
@ -69,6 +71,7 @@ public class GenerateSpec11ReportAction implements Runnable {
private final Clock clock;
private final Response response;
private final Dataflow dataflow;
private final String database;
@Inject
GenerateSpec11ReportAction(
@ -78,15 +81,20 @@ public class GenerateSpec11ReportAction implements Runnable {
@Config("reportingBucketUrl") String reportingBucketUrl,
@Key("safeBrowsingAPIKey") String apiKey,
@Parameter(PARAM_DATE) LocalDate date,
@Parameter(DATABASE) String database,
Clock clock,
Response response,
Dataflow dataflow) {
this.projectId = projectId;
this.jobRegion = jobRegion;
this.stagingBucketUrl = stagingBucketUrl;
if (tm().isOfy() && database.equals("CLOUD_SQL")) {
reportingBucketUrl = reportingBucketUrl.concat("-sql");
}
this.reportingBucketUrl = reportingBucketUrl;
this.apiKey = apiKey;
this.date = date;
this.database = database;
this.clock = clock;
this.response = response;
this.dataflow = dataflow;
@ -105,6 +113,8 @@ public class GenerateSpec11ReportAction implements Runnable {
ImmutableMap.of(
"safeBrowsingApiKey",
apiKey,
"database",
database,
ReportingModule.PARAM_DATE,
date.toString(),
"reportingBucketUrl",

View file

@ -61,6 +61,14 @@
"regexes": [
"^gs:\\/\\/[^\\n\\r]+$"
]
},
{
"name": "database",
"label": "Database to read from.",
"helpText": "DATASTORE or CLOUD_SQL.",
"regexes": [
"^(DATASTORE|CLOUD_SQL)$"
]
}
]
}

View file

@ -164,7 +164,7 @@ class SafeBrowsingTransformsTest {
* A serializable {@link Answer} that returns a mock HTTP response based on the HTTP request's
* content.
*/
private static class HttpResponder implements Answer<CloseableHttpResponse>, Serializable {
static class HttpResponder implements Answer<CloseableHttpResponse>, Serializable {
@Override
public CloseableHttpResponse answer(InvocationOnMock invocation) throws Throwable {
return getMockResponse(

View file

@ -18,33 +18,60 @@ import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.truth.Truth.assertThat;
import static google.registry.model.ImmutableObjectSubject.immutableObjectCorrespondence;
import static google.registry.persistence.transaction.TransactionManagerFactory.jpaTm;
import static google.registry.persistence.transaction.TransactionManagerFactory.tm;
import static google.registry.testing.AppEngineExtension.makeRegistrar1;
import static google.registry.testing.DatabaseHelper.createTld;
import static google.registry.testing.DatabaseHelper.persistActiveContact;
import static google.registry.testing.DatabaseHelper.persistNewRegistrar;
import static google.registry.testing.DatabaseHelper.persistResource;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import static org.mockito.Mockito.withSettings;
import com.google.common.base.Suppliers;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Streams;
import com.google.common.truth.Correspondence;
import com.google.common.truth.Correspondence.BinaryPredicate;
import google.registry.beam.TestPipelineExtension;
import google.registry.beam.spec11.SafeBrowsingTransforms.EvaluateSafeBrowsingFn;
import google.registry.beam.spec11.SafeBrowsingTransformsTest.HttpResponder;
import google.registry.model.contact.ContactResource;
import google.registry.model.domain.DomainAuthInfo;
import google.registry.model.domain.DomainBase;
import google.registry.model.eppcommon.AuthInfo.PasswordAuth;
import google.registry.model.registrar.Registrar;
import google.registry.model.reporting.Spec11ThreatMatch;
import google.registry.model.reporting.Spec11ThreatMatch.ThreatType;
import google.registry.model.reporting.Spec11ThreatMatchDao;
import google.registry.persistence.transaction.JpaTestRules;
import google.registry.persistence.transaction.JpaTestRules.JpaIntegrationTestExtension;
import google.registry.persistence.transaction.TransactionManager;
import google.registry.persistence.transaction.TransactionManagerFactory;
import google.registry.testing.DatastoreEntityExtension;
import google.registry.testing.FakeClock;
import google.registry.testing.FakeSleeper;
import google.registry.util.ResourceUtils;
import google.registry.util.Retrier;
import java.io.File;
import java.nio.file.Files;
import java.nio.file.Path;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.coders.SerializableCoder;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.impl.client.CloseableHttpClient;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.joda.time.DateTime;
import org.joda.time.LocalDate;
import org.json.JSONObject;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Order;
import org.junit.jupiter.api.Test;
@ -60,9 +87,14 @@ import org.junit.jupiter.api.io.TempDir;
*/
class Spec11PipelineTest {
private static final DateTime START_TIME = DateTime.parse("2020-01-27T00:00:00.0Z");
private final FakeClock fakeClock = new FakeClock(START_TIME);
private static final String DATE = "2020-01-27";
private static final String SAFE_BROWSING_API_KEY = "api-key";
private static final String REPORTING_BUCKET_URL = "reporting_bucket";
private final CloseableHttpClient mockHttpClient =
mock(CloseableHttpClient.class, withSettings().serializable());
private static final ImmutableList<Subdomain> SUBDOMAINS =
ImmutableList.of(
@ -103,12 +135,18 @@ class Spec11PipelineTest {
private File reportingBucketUrl;
private PCollection<KV<Subdomain, ThreatMatch>> threatMatches;
ImmutableSet<Spec11ThreatMatch> sqlThreatMatches;
TransactionManager tm;
@BeforeEach
void beforeEach() throws Exception {
tm = tm();
TransactionManagerFactory.setTm(jpaTm());
reportingBucketUrl = Files.createDirectory(tmpDir.resolve(REPORTING_BUCKET_URL)).toFile();
options.setDate(DATE);
options.setSafeBrowsingApiKey(SAFE_BROWSING_API_KEY);
options.setReportingBucketUrl(reportingBucketUrl.getAbsolutePath());
options.setDatabase("DATASTORE");
threatMatches =
pipeline.apply(
Create.of(
@ -118,11 +156,8 @@ class Spec11PipelineTest {
KvCoder.of(
SerializableCoder.of(Subdomain.class),
SerializableCoder.of(ThreatMatch.class))));
}
@Test
void testSuccess_saveToSql() {
ImmutableSet<Spec11ThreatMatch> sqlThreatMatches =
sqlThreatMatches =
ImmutableSet.of(
new Spec11ThreatMatch.Builder()
.setDomainName("111.com")
@ -159,25 +194,98 @@ class Spec11PipelineTest {
.setCheckDate(new LocalDate(2020, 1, 27))
.setThreatTypes(ImmutableSet.of(ThreatType.UNWANTED_SOFTWARE))
.build());
}
/** Puts back the transaction manager that {@code beforeEach} saved before switching to JPA. */
@AfterEach
void afterEach() {
  TransactionManagerFactory.setTm(this.tm);
}
/**
 * Runs the complete pipeline with the database option set to "CLOUD_SQL" and a mocked HTTP
 * client answering the SafeBrowsing requests, then checks both the GCS and Cloud SQL outputs.
 */
@Test
void testSuccess_fullSqlPipeline() throws Exception {
  setupCloudSql();
  options.setDatabase("CLOUD_SQL");

  Retrier retrier = new Retrier(new FakeSleeper(new FakeClock()), 1);
  EvaluateSafeBrowsingFn evaluateFn =
      new EvaluateSafeBrowsingFn(
          SAFE_BROWSING_API_KEY, retrier, Suppliers.ofInstance(mockHttpClient));
  // Every POST made through the mocked client is answered by HttpResponder.
  when(mockHttpClient.execute(any(HttpPost.class))).thenAnswer(new HttpResponder());

  new Spec11Pipeline(options, evaluateFn).setupPipeline(pipeline);
  pipeline.run(options).waitUntilFinish();

  verifySaveToGcs();
  verifySaveToCloudSql();
}
/**
 * Writes the threat matches prepared in {@code beforeEach} with {@code saveToSql}, runs the
 * pipeline, and checks the entries stored for 2020-01-27.
 */
@Test
void testSuccess_saveToSql() {
Spec11Pipeline.saveToSql(threatMatches, options);
pipeline.run().waitUntilFinish();
// Compares the loaded entries with the expected sqlThreatMatches fixture. The "id" argument
// presumably excludes the generated id from the comparison -- TODO confirm against
// immutableObjectCorrespondence.
assertThat(
jpaTm()
.transact(
() ->
Spec11ThreatMatchDao.loadEntriesByDate(
jpaTm(), new LocalDate(2020, 1, 27))))
.comparingElementsUsing(immutableObjectCorrespondence("id"))
.containsExactlyElementsIn(sqlThreatMatches);
verifySaveToCloudSql();
}
/**
 * Writes the threat matches prepared in {@code beforeEach} with {@code saveToGcs}, runs the
 * pipeline, and delegates the output check to {@code verifySaveToGcs}.
 */
@Test
void testSuccess_saveToGcs() throws Exception {
  Spec11Pipeline.saveToGcs(this.threatMatches, this.options);
  this.pipeline.run().waitUntilFinish();
  verifySaveToGcs();
}
/**
 * Persists the test registrars and domains, then asserts that the Cloud SQL read yields exactly
 * the expected {@code SUBDOMAINS}, in any order.
 */
@Test
void testSuccess_readFromCloudSql() throws Exception {
  setupCloudSql();
  PCollection<Subdomain> domainsFromSql = Spec11Pipeline.readFromCloudSql(pipeline);
  // PAssert checks are evaluated when the pipeline runs.
  PAssert.that(domainsFromSql).containsInAnyOrder(SUBDOMAINS);
  pipeline.run().waitUntilFinish();
}
/**
 * Persists the registrars, TLDs, contacts and domains that the Cloud SQL read tests expect to
 * find: three registrars with distinct email addresses and five domains spread across them.
 */
private void setupCloudSql() {
  persistNewRegistrar("TheRegistrar");
  persistNewRegistrar("NewRegistrar");
  Registrar helloRegistrar = persistRegistrarWithEmail("hello-registrar", "email@hello.net");
  Registrar kittyRegistrar = persistRegistrarWithEmail("kitty-registrar", "contact@kit.ty");
  Registrar coolRegistrar = persistRegistrarWithEmail("cool-registrar", "cool@aid.net");

  for (String tld : ImmutableList.of("com", "net", "bank", "dev")) {
    createTld(tld);
  }

  ContactResource helloContact = persistActiveContact(helloRegistrar.getClientId());
  ContactResource kittyContact = persistActiveContact(kittyRegistrar.getClientId());
  ContactResource coolContact = persistActiveContact(coolRegistrar.getClientId());

  persistResource(createDomain("111.com", "123456789-COM", helloRegistrar, helloContact));
  persistResource(createDomain("party-night.net", "2244AABBC-NET", kittyRegistrar, kittyContact));
  persistResource(createDomain("bitcoin.bank", "1C3D5E7F9-BANK", helloRegistrar, helloContact));
  persistResource(createDomain("no-email.com", "2A4BA9BBC-COM", kittyRegistrar, kittyContact));
  persistResource(
      createDomain("anti-anti-anti-virus.dev", "555666888-DEV", coolRegistrar, coolContact));
}

/** Persists a copy of the {@code makeRegistrar1()} registrar with the given id and email. */
private Registrar persistRegistrarWithEmail(String clientId, String emailAddress) {
  return persistResource(
      makeRegistrar1()
          .asBuilder()
          .setClientId(clientId)
          .setEmailAddress(emailAddress)
          .build());
}
private void verifySaveToGcs() throws Exception {
ImmutableList<String> expectedFileContents =
ImmutableList.copyOf(
ResourceUtils.readResourceUtf8(this.getClass(), "test_output.txt").split("\n"));
Spec11Pipeline.saveToGcs(threatMatches, options);
pipeline.run().waitUntilFinish();
ImmutableList<String> resultFileContents = resultFileContents();
assertThat(resultFileContents.size()).isEqualTo(expectedFileContents.size());
assertThat(resultFileContents.get(0)).isEqualTo(expectedFileContents.get(0));
@ -188,6 +296,34 @@ class Spec11PipelineTest {
.containsExactlyElementsIn(expectedFileContents.subList(1, expectedFileContents.size()));
}
/**
 * Asserts that the threat matches stored in Cloud SQL for 2020-01-27 are exactly the expected
 * {@code sqlThreatMatches} fixture, compared with {@code immutableObjectCorrespondence("id")}.
 *
 * <p>The loaded list is deliberately given a name different from the {@code sqlThreatMatches}
 * field: a local with the same name shadows the field inside the lambda, which turns the
 * assertion into a comparison of the loaded list with itself that can never fail.
 */
private void verifySaveToCloudSql() {
  jpaTm()
      .transact(
          () -> {
            ImmutableList<Spec11ThreatMatch> persistedThreatMatches =
                Spec11ThreatMatchDao.loadEntriesByDate(jpaTm(), new LocalDate(2020, 1, 27));
            assertThat(persistedThreatMatches)
                .comparingElementsUsing(immutableObjectCorrespondence("id"))
                .containsExactlyElementsIn(sqlThreatMatches);
          });
}
/**
 * Builds (without persisting) a domain sponsored, created and last updated by the given
 * registrar, with the given contact as registrant.
 *
 * <p>Update and transfer times are taken from {@code fakeClock}; the registration expires one
 * year after the clock's current time.
 *
 * @param domainName fully qualified domain name
 * @param repoId repository id to assign to the domain
 * @param registrar registrar whose client id is used for the creation, update and sponsor fields
 * @param contact contact whose key becomes the registrant
 * @return the built domain
 */
private DomainBase createDomain(
    String domainName, String repoId, Registrar registrar, ContactResource contact) {
  String registrarClientId = registrar.getClientId();
  return new DomainBase.Builder()
      .setDomainName(domainName)
      .setRepoId(repoId)
      .setCreationClientId(registrarClientId)
      .setLastEppUpdateTime(fakeClock.nowUtc())
      .setLastEppUpdateClientId(registrarClientId)
      .setLastTransferTime(fakeClock.nowUtc())
      .setRegistrant(contact.createVKey())
      .setPersistedCurrentSponsorClientId(registrarClientId)
      .setRegistrationExpirationTime(fakeClock.nowUtc().plusYears(1))
      .setAuthInfo(DomainAuthInfo.create(PasswordAuth.create("password")))
      .build();
}
/** Returns the text contents of a file under the beamBucket/results directory. */
private ImmutableList<String> resultFileContents() throws Exception {
File resultFile =

View file

@ -50,6 +50,7 @@ class GenerateSpec11ReportActionTest extends BeamActionTestBase {
"gs://reporting-project/reporting-bucket/",
"api_key/a",
clock.nowUtc().toLocalDate(),
"DATASTORE",
clock,
response,
dataflow);
@ -71,6 +72,7 @@ class GenerateSpec11ReportActionTest extends BeamActionTestBase {
"gs://reporting-project/reporting-bucket/",
"api_key/a",
clock.nowUtc().toLocalDate(),
"DATASTORE",
clock,
response,
dataflow);