Persist ThreatMatches into Spec11ThreatMatch (#723)

* Replace jpaTm with a JpaSupplierFactory

* Style

* Style

* Pipeline takes in a SerializableSupplier instead

* Change the ordering of imports

* Test a good domain in addition to a bad one

* Rename and check good domain for Transact Answer

* Use standard Mockito verify

* Verify transact call and no more interactions

* Remove Answer comment

* Naming changes

* Deploy Spec 11 pipeline correctly

* Fix formatting of deploy file

* Use a file to persist state across Cloud Build steps

Co-authored-by: Gus Brodman <gbrodman@google.com>
This commit is contained in:
Legina Chen 2020-08-03 14:40:00 -07:00 committed by GitHub
parent 917a72e2cb
commit d3098b35a4
6 changed files with 171 additions and 7 deletions

View file

@ -17,15 +17,22 @@ package google.registry.beam.spec11;
import static com.google.common.truth.Truth.assertThat;
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
import static org.mockito.Mockito.withSettings;
import com.google.auth.oauth2.GoogleCredentials;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.io.CharStreams;
import google.registry.beam.TestPipelineExtension;
import google.registry.beam.spec11.SafeBrowsingTransforms.EvaluateSafeBrowsingFn;
import google.registry.model.reporting.Spec11ThreatMatch;
import google.registry.model.reporting.Spec11ThreatMatch.ThreatType;
import google.registry.persistence.transaction.JpaTransactionManager;
import google.registry.testing.FakeClock;
import google.registry.testing.FakeSleeper;
import google.registry.util.GoogleCredentialsBundle;
@ -55,22 +62,42 @@ import org.apache.http.entity.BasicHttpEntity;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.message.BasicStatusLine;
import org.joda.time.DateTime;
import org.joda.time.LocalDate;
import org.joda.time.format.ISODateTimeFormat;
import org.json.JSONArray;
import org.json.JSONException;
import org.json.JSONObject;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.api.io.TempDir;
import org.mockito.Mock;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.mockito.junit.jupiter.MockitoSettings;
import org.mockito.quality.Strictness;
import org.mockito.stubbing.Answer;
/** Unit tests for {@link Spec11Pipeline}. */
@ExtendWith(MockitoExtension.class)
@MockitoSettings(strictness = Strictness.LENIENT)
class Spec11PipelineTest {
/**
 * A serializable {@link Answer} that immediately runs the {@link Runnable} passed as the first
 * argument of the stubbed call and returns {@code null}.
 *
 * <p>This test installs it on {@code mockJpaTm.transact(Runnable)} so that the work handed to
 * the mocked transaction manager is actually executed. It implements {@link Serializable},
 * presumably because the mock it is attached to is itself declared serializable.
 */
private static class SaveNewThreatMatchAnswer implements Answer<Void>, Serializable {
  @Override
  public Void answer(InvocationOnMock invocation) {
    // Execute the transactional work inline; there is no real transaction to open.
    invocation.getArgument(0, Runnable.class).run();
    return null;
  }
}
private static PipelineOptions pipelineOptions;
@Mock(serializable = true)
private static JpaTransactionManager mockJpaTm;
@BeforeAll
static void beforeAll() {
pipelineOptions = PipelineOptionsFactory.create();
@ -93,20 +120,23 @@ class Spec11PipelineTest {
void beforeEach() throws IOException {
  // Create a dedicated "beam_temp" directory inside the per-test temporary directory.
  String beamTempDir =
      Files.createDirectory(tmpDir.resolve("beam_temp")).toAbsolutePath().toString();
  String stagingPath = beamTempDir + "/staging";
  String templatePath = beamTempDir + "/templates/invoicing";
  String tmpDirPath = tmpDir.toAbsolutePath().toString();

  // The pipeline obtains its transaction manager through a supplier; hand it the mock.
  spec11Pipeline =
      new Spec11Pipeline(
          "test-project",
          stagingPath,
          templatePath,
          tmpDirPath,
          () -> mockJpaTm,
          GoogleCredentialsBundle.create(GoogleCredentials.create(null)),
          retrier);
}
private static final ImmutableList<String> BAD_DOMAINS =
ImmutableList.of("111.com", "222.com", "444.com", "no-email.com");
ImmutableList.of(
"111.com", "222.com", "444.com", "no-email.com", "testThreatMatchToSqlBad.com");
private ImmutableList<Subdomain> getInputDomains() {
private ImmutableList<Subdomain> getInputDomainsJson() {
ImmutableList.Builder<Subdomain> subdomainsBuilder = new ImmutableList.Builder<>();
// Put in at least 2 batches worth (x > 490) to guarantee multiple executions.
// Put in half for theRegistrar and half for someRegistrar
@ -134,17 +164,18 @@ class Spec11PipelineTest {
@SuppressWarnings("unchecked")
void testEndToEndPipeline_generatesExpectedFiles() throws Exception {
// Establish mocks for testing
ImmutableList<Subdomain> inputRows = getInputDomains();
CloseableHttpClient httpClient = mock(CloseableHttpClient.class, withSettings().serializable());
ImmutableList<Subdomain> inputRows = getInputDomainsJson();
CloseableHttpClient mockHttpClient =
mock(CloseableHttpClient.class, withSettings().serializable());
// Return a mock HttpResponse that returns a JSON response based on the request.
when(httpClient.execute(any(HttpPost.class))).thenAnswer(new HttpResponder());
when(mockHttpClient.execute(any(HttpPost.class))).thenAnswer(new HttpResponder());
EvaluateSafeBrowsingFn evalFn =
new EvaluateSafeBrowsingFn(
StaticValueProvider.of("apikey"),
new Retrier(new FakeSleeper(new FakeClock()), 3),
(Serializable & Supplier) () -> httpClient);
(Serializable & Supplier) () -> mockHttpClient);
// Apply input and evaluation transforms
PCollection<Subdomain> input = testPipeline.apply(Create.of(inputRows));
@ -207,6 +238,56 @@ class Spec11PipelineTest {
.toString());
}
/**
 * Runs {@code evaluateUrlHealth} over one bad and one good {@link Subdomain} and verifies that
 * the mocked {@link JpaTransactionManager} receives exactly one {@code transact} call and one
 * {@code saveNew} call carrying the {@link Spec11ThreatMatch} expected for the bad domain.
 */
@Test
@SuppressWarnings("unchecked")
void testSpec11ThreatMatchToSql() throws Exception {
  // Make the mocked transact(Runnable) actually run the work it is given.
  doAnswer(new SaveNewThreatMatchAnswer()).when(mockJpaTm).transact(any(Runnable.class));

  // Create one bad and one good Subdomain to test with evaluateUrlHealth. Only the bad one should
  // be detected and persisted.
  Subdomain badDomain =
      Subdomain.create(
          "testThreatMatchToSqlBad.com", "theDomain", "theRegistrar", "fake@theRegistrar.com");
  Subdomain goodDomain =
      Subdomain.create(
          "testThreatMatchToSqlGood.com",
          "someDomain",
          "someRegistrar",
          "fake@someRegistrar.com");

  // Establish a mock HttpResponse that returns a JSON response based on the request.
  CloseableHttpClient mockHttpClient =
      mock(CloseableHttpClient.class, withSettings().serializable());
  when(mockHttpClient.execute(any(HttpPost.class))).thenAnswer(new HttpResponder());

  EvaluateSafeBrowsingFn evalFn =
      new EvaluateSafeBrowsingFn(
          StaticValueProvider.of("apikey"),
          new Retrier(new FakeSleeper(new FakeClock()), 3),
          (Serializable & Supplier) () -> mockHttpClient);

  // Apply input and evaluation transforms
  PCollection<Subdomain> input = testPipeline.apply(Create.of(badDomain, goodDomain));
  spec11Pipeline.evaluateUrlHealth(input, evalFn, StaticValueProvider.of("2020-06-10"));
  testPipeline.run();

  // Verify that the expected threat created from the bad Subdomain and the persisted
  // Spec11ThreatMatch are equal.
  Spec11ThreatMatch expected =
      new Spec11ThreatMatch()
          .asBuilder()
          .setThreatTypes(ImmutableSet.of(ThreatType.MALWARE))
          .setCheckDate(LocalDate.parse("2020-06-10", ISODateTimeFormat.date()))
          .setDomainName(badDomain.domainName())
          .setDomainRepoId(badDomain.domainRepoId())
          .setRegistrarId(badDomain.registrarId())
          .build();

  // Exactly one transaction and one save (for the bad domain); nothing for the good domain.
  verify(mockJpaTm).transact(any(Runnable.class));
  verify(mockJpaTm).saveNew(expected);
  verifyNoMoreInteractions(mockJpaTm);
}
/**
* A serializable {@link Answer} that returns a mock HTTP response based on the HTTP request's
* content.

View file

@ -114,6 +114,8 @@ public class Spec11ThreatMatchTest extends EntityTestCase {
VKey<Spec11ThreatMatch> threatVKey = VKey.createSql(Spec11ThreatMatch.class, threat.getId());
Spec11ThreatMatch persistedThreat = jpaTm().transact(() -> jpaTm().load(threatVKey));
// Threat object saved for the first time doesn't have an ID; it is generated by SQL
threat.id = persistedThreat.id;
assertThat(threat).isEqualTo(persistedThreat);
}