Persist ThreatMatches into Spec11ThreatMatch (#723)

* Replace jpaTm with a JpaSupplierFactory

* Style

* Style

* Pipeline takes in a SerializableSupplier instead

* Change the ordering of imports

* Test a good domain in addition to a bad one

* Rename and check good domain for Transact Answer

* Use standard Mockito verify

* Verify transact call and no more interactions

* Remove Answer comment

* Naming chsnges

* Deploy Spec 11 pipeline correctly

* Fix formatting of deploy file

* Use a file to persist state across Cloud Build steps

Co-authored-by: Gus Brodman <gbrodman@google.com>
This commit is contained in:
Legina Chen 2020-08-03 14:40:00 -07:00 committed by GitHub
parent 917a72e2cb
commit d3098b35a4
6 changed files with 171 additions and 7 deletions

View file

@ -19,9 +19,15 @@ import static google.registry.beam.BeamUtils.getQueryFromFile;
import com.google.auth.oauth2.GoogleCredentials;
import com.google.auto.value.AutoValue;
import com.google.common.collect.ImmutableSet;
import google.registry.backup.AppEngineEnvironment;
import google.registry.beam.initsql.Transforms.SerializableSupplier;
import google.registry.beam.spec11.SafeBrowsingTransforms.EvaluateSafeBrowsingFn;
import google.registry.config.CredentialModule.LocalCredential;
import google.registry.config.RegistryConfig.Config;
import google.registry.model.reporting.Spec11ThreatMatch;
import google.registry.model.reporting.Spec11ThreatMatch.ThreatType;
import google.registry.persistence.transaction.JpaTransactionManager;
import google.registry.util.GoogleCredentialsBundle;
import google.registry.util.Retrier;
import google.registry.util.SqlTemplate;
@ -37,6 +43,7 @@ import org.apache.beam.sdk.options.Description;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.options.ValueProvider;
import org.apache.beam.sdk.options.ValueProvider.NestedValueProvider;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.GroupByKey;
import org.apache.beam.sdk.transforms.MapElements;
import org.apache.beam.sdk.transforms.ParDo;
@ -46,6 +53,7 @@ import org.apache.beam.sdk.values.TypeDescriptor;
import org.apache.beam.sdk.values.TypeDescriptors;
import org.joda.time.LocalDate;
import org.joda.time.YearMonth;
import org.joda.time.format.ISODateTimeFormat;
import org.json.JSONArray;
import org.json.JSONException;
import org.json.JSONObject;
@ -86,6 +94,7 @@ public class Spec11Pipeline implements Serializable {
private final String reportingBucketUrl;
private final GoogleCredentials googleCredentials;
private final Retrier retrier;
private final SerializableSupplier<JpaTransactionManager> jpaSupplierFactory;
@Inject
public Spec11Pipeline(
@ -93,12 +102,14 @@ public class Spec11Pipeline implements Serializable {
@Config("beamStagingUrl") String beamStagingUrl,
@Config("spec11TemplateUrl") String spec11TemplateUrl,
@Config("reportingBucketUrl") String reportingBucketUrl,
SerializableSupplier<JpaTransactionManager> jpaSupplierFactory,
@LocalCredential GoogleCredentialsBundle googleCredentialsBundle,
Retrier retrier) {
this.projectId = projectId;
this.beamStagingUrl = beamStagingUrl;
this.spec11TemplateUrl = spec11TemplateUrl;
this.reportingBucketUrl = reportingBucketUrl;
this.jpaSupplierFactory = jpaSupplierFactory;
this.googleCredentials = googleCredentialsBundle.getGoogleCredentials();
this.retrier = retrier;
}
@ -177,12 +188,40 @@ public class Spec11Pipeline implements Serializable {
EvaluateSafeBrowsingFn evaluateSafeBrowsingFn,
ValueProvider<String> dateProvider) {
PCollection<KV<Subdomain, ThreatMatch>> subdomainsSql =
domains.apply("Run through SafeBrowsing API", ParDo.of(evaluateSafeBrowsingFn));
/* Store ThreatMatch objects in SQL. */
subdomainsSql.apply(
ParDo.of(
new DoFn<KV<Subdomain, ThreatMatch>, Void>() {
@ProcessElement
public void processElement(ProcessContext context) {
// create the Spec11ThreatMatch from Subdomain and ThreatMatch
try (AppEngineEnvironment env = new AppEngineEnvironment()) {
Subdomain subdomain = context.element().getKey();
Spec11ThreatMatch threatMatch =
new Spec11ThreatMatch.Builder()
.setThreatTypes(
ImmutableSet.of(
ThreatType.valueOf(context.element().getValue().threatType())))
.setCheckDate(
LocalDate.parse(dateProvider.get(), ISODateTimeFormat.date()))
.setDomainName(subdomain.domainName())
.setDomainRepoId(subdomain.domainRepoId())
.setRegistrarId(subdomain.registrarId())
.build();
JpaTransactionManager jpaTransactionManager = jpaSupplierFactory.get();
jpaTransactionManager.transact(() -> jpaTransactionManager.saveNew(threatMatch));
}
}
}));
/* Store ThreatMatch objects in JSON. */
PCollection<KV<Subdomain, ThreatMatch>> subdomainsJson =
domains.apply("Run through SafeBrowsingAPI", ParDo.of(evaluateSafeBrowsingFn));
subdomainsJson
.apply(
"Map registrar client ID to email/ThreatMatch pair",
"Map registrar ID to email/ThreatMatch pair",
MapElements.into(
TypeDescriptors.kvs(
TypeDescriptors.strings(), TypeDescriptor.of(EmailAndThreatMatch.class)))

View file

@ -140,6 +140,16 @@ public class Spec11ThreatMatch extends ImmutableObject implements Buildable, Sql
return super.build();
}
/**
* Manually set the ID for testing or other special circumstances.
*
* <p>In general the ID is generated by SQL and there should be no need to set it manually.
*/
public Builder setId(Long id) {
getInstance().id = id;
return this;
}
public Builder setDomainName(String domainName) {
getInstance().domainName = domainName;
getInstance().tld = DomainNameUtils.getTldFromDomainName(domainName);

View file

@ -14,7 +14,10 @@
package google.registry.tools;
import com.beust.jcommander.Parameter;
import com.beust.jcommander.Parameters;
import google.registry.beam.initsql.BeamJpaModule.JpaTransactionManagerComponent;
import google.registry.beam.initsql.JpaSupplierFactory;
import google.registry.beam.spec11.Spec11Pipeline;
import google.registry.config.CredentialModule.LocalCredential;
import google.registry.config.RegistryConfig.Config;
@ -31,6 +34,12 @@ public class DeploySpec11PipelineCommand implements Command {
@Config("projectId")
String projectId;
@Parameter(
names = {"-p", "--project"},
description = "Cloud KMS project ID",
required = true)
String cloudKmsProjectId;
@Inject
@Config("beamStagingUrl")
String beamStagingUrl;
@ -53,12 +62,19 @@ public class DeploySpec11PipelineCommand implements Command {
@Override
public void run() {
JpaSupplierFactory jpaSupplierFactory =
new JpaSupplierFactory(
sqlAccessInfoFile,
cloudKmsProjectId,
JpaTransactionManagerComponent::cloudSqlJpaTransactionManager);
Spec11Pipeline pipeline =
new Spec11Pipeline(
projectId,
beamStagingUrl,
spec11TemplateUrl,
reportingBucketUrl,
jpaSupplierFactory,
googleCredentialsBundle,
retrier);
pipeline.deploy();

View file

@ -17,15 +17,22 @@ package google.registry.beam.spec11;
import static com.google.common.truth.Truth.assertThat;
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
import static org.mockito.Mockito.withSettings;
import com.google.auth.oauth2.GoogleCredentials;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.io.CharStreams;
import google.registry.beam.TestPipelineExtension;
import google.registry.beam.spec11.SafeBrowsingTransforms.EvaluateSafeBrowsingFn;
import google.registry.model.reporting.Spec11ThreatMatch;
import google.registry.model.reporting.Spec11ThreatMatch.ThreatType;
import google.registry.persistence.transaction.JpaTransactionManager;
import google.registry.testing.FakeClock;
import google.registry.testing.FakeSleeper;
import google.registry.util.GoogleCredentialsBundle;
@ -55,22 +62,42 @@ import org.apache.http.entity.BasicHttpEntity;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.message.BasicStatusLine;
import org.joda.time.DateTime;
import org.joda.time.LocalDate;
import org.joda.time.format.ISODateTimeFormat;
import org.json.JSONArray;
import org.json.JSONException;
import org.json.JSONObject;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.api.io.TempDir;
import org.mockito.Mock;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.mockito.junit.jupiter.MockitoSettings;
import org.mockito.quality.Strictness;
import org.mockito.stubbing.Answer;
/** Unit tests for {@link Spec11Pipeline}. */
@ExtendWith(MockitoExtension.class)
@MockitoSettings(strictness = Strictness.LENIENT)
class Spec11PipelineTest {
private static class SaveNewThreatMatchAnswer implements Answer<Void>, Serializable {
@Override
public Void answer(InvocationOnMock invocation) {
Runnable runnable = invocation.getArgument(0, Runnable.class);
runnable.run();
return null;
}
}
private static PipelineOptions pipelineOptions;
@Mock(serializable = true)
private static JpaTransactionManager mockJpaTm;
@BeforeAll
static void beforeAll() {
pipelineOptions = PipelineOptionsFactory.create();
@ -93,20 +120,23 @@ class Spec11PipelineTest {
void beforeEach() throws IOException {
String beamTempFolder =
Files.createDirectory(tmpDir.resolve("beam_temp")).toAbsolutePath().toString();
spec11Pipeline =
new Spec11Pipeline(
"test-project",
beamTempFolder + "/staging",
beamTempFolder + "/templates/invoicing",
tmpDir.toAbsolutePath().toString(),
() -> mockJpaTm,
GoogleCredentialsBundle.create(GoogleCredentials.create(null)),
retrier);
}
private static final ImmutableList<String> BAD_DOMAINS =
ImmutableList.of("111.com", "222.com", "444.com", "no-email.com");
ImmutableList.of(
"111.com", "222.com", "444.com", "no-email.com", "testThreatMatchToSqlBad.com");
private ImmutableList<Subdomain> getInputDomains() {
private ImmutableList<Subdomain> getInputDomainsJson() {
ImmutableList.Builder<Subdomain> subdomainsBuilder = new ImmutableList.Builder<>();
// Put in at least 2 batches worth (x > 490) to guarantee multiple executions.
// Put in half for theRegistrar and half for someRegistrar
@ -134,17 +164,18 @@ class Spec11PipelineTest {
@SuppressWarnings("unchecked")
void testEndToEndPipeline_generatesExpectedFiles() throws Exception {
// Establish mocks for testing
ImmutableList<Subdomain> inputRows = getInputDomains();
CloseableHttpClient httpClient = mock(CloseableHttpClient.class, withSettings().serializable());
ImmutableList<Subdomain> inputRows = getInputDomainsJson();
CloseableHttpClient mockHttpClient =
mock(CloseableHttpClient.class, withSettings().serializable());
// Return a mock HttpResponse that returns a JSON response based on the request.
when(httpClient.execute(any(HttpPost.class))).thenAnswer(new HttpResponder());
when(mockHttpClient.execute(any(HttpPost.class))).thenAnswer(new HttpResponder());
EvaluateSafeBrowsingFn evalFn =
new EvaluateSafeBrowsingFn(
StaticValueProvider.of("apikey"),
new Retrier(new FakeSleeper(new FakeClock()), 3),
(Serializable & Supplier) () -> httpClient);
(Serializable & Supplier) () -> mockHttpClient);
// Apply input and evaluation transforms
PCollection<Subdomain> input = testPipeline.apply(Create.of(inputRows));
@ -207,6 +238,56 @@ class Spec11PipelineTest {
.toString());
}
@Test
@SuppressWarnings("unchecked")
public void testSpec11ThreatMatchToSql() throws Exception {
doAnswer(new SaveNewThreatMatchAnswer()).when(mockJpaTm).transact(any(Runnable.class));
// Create one bad and one good Subdomain to test with evaluateUrlHealth. Only the bad one should
// be detected and persisted.
Subdomain badDomain =
Subdomain.create(
"testThreatMatchToSqlBad.com", "theDomain", "theRegistrar", "fake@theRegistrar.com");
Subdomain goodDomain =
Subdomain.create(
"testThreatMatchToSqlGood.com",
"someDomain",
"someRegistrar",
"fake@someRegistrar.com");
// Establish a mock HttpResponse that returns a JSON response based on the request.
CloseableHttpClient mockHttpClient =
mock(CloseableHttpClient.class, withSettings().serializable());
when(mockHttpClient.execute(any(HttpPost.class))).thenAnswer(new HttpResponder());
EvaluateSafeBrowsingFn evalFn =
new EvaluateSafeBrowsingFn(
StaticValueProvider.of("apikey"),
new Retrier(new FakeSleeper(new FakeClock()), 3),
(Serializable & Supplier) () -> mockHttpClient);
// Apply input and evaluation transforms
PCollection<Subdomain> input = testPipeline.apply(Create.of(badDomain, goodDomain));
spec11Pipeline.evaluateUrlHealth(input, evalFn, StaticValueProvider.of("2020-06-10"));
testPipeline.run();
// Verify that the expected threat created from the bad Subdomain and the persisted
// Spec11TThreatMatch are equal.
Spec11ThreatMatch expected =
new Spec11ThreatMatch()
.asBuilder()
.setThreatTypes(ImmutableSet.of(ThreatType.MALWARE))
.setCheckDate(LocalDate.parse("2020-06-10", ISODateTimeFormat.date()))
.setDomainName(badDomain.domainName())
.setDomainRepoId(badDomain.domainRepoId())
.setRegistrarId(badDomain.registrarId())
.build();
verify(mockJpaTm).transact(any(Runnable.class));
verify(mockJpaTm).saveNew(expected);
verifyNoMoreInteractions(mockJpaTm);
}
/**
* A serializable {@link Answer} that returns a mock HTTP response based on the HTTP request's
* content.

View file

@ -114,6 +114,8 @@ public class Spec11ThreatMatchTest extends EntityTestCase {
VKey<Spec11ThreatMatch> threatVKey = VKey.createSql(Spec11ThreatMatch.class, threat.getId());
Spec11ThreatMatch persistedThreat = jpaTm().transact(() -> jpaTm().load(threatVKey));
// Threat object saved for the first time doesn't have an ID; it is generated by SQL
threat.id = persistedThreat.id;
assertThat(threat).isEqualTo(persistedThreat);
}

View file

@ -37,6 +37,18 @@ steps:
cat tool-credential.json.enc | base64 -d | gcloud kms decrypt \
--ciphertext-file=- --plaintext-file=tool-credential.json \
--location=global --keyring=nomulus-tool-keyring --key=nomulus-tool-key
# Set the path to the file for sql_access_info to deploy the Spec 11 pipeline
- name: 'gcr.io/$PROJECT_ID/builder:latest'
entrypoint: /bin/bash
args:
- -c
- |
set -e
if [ ${_ENV} == production ]; then
echo "gs://domain-registry-beam/cloudsql/admin_credential.enc" > sql_access_path.txt
else
echo "gs://domain-registry-${_ENV}-beam/cloudsql/admin_credential.enc" > sql_access_path.txt
fi
# Deploy the Spec11 pipeline to GCS.
- name: 'gcr.io/$PROJECT_ID/nomulus-tool:latest'
args:
@ -44,7 +56,11 @@ steps:
- ${_ENV}
- --credential
- tool-credential.json
- --sql_access_info
- `cat sql_access_path.txt`
- deploy_spec11_pipeline
- --project
- $PROJECT_ID
# Deploy the invoicing pipeline to GCS.
- name: 'gcr.io/$PROJECT_ID/nomulus-tool:latest'
args: