diff --git a/java/google/registry/beam/BUILD b/java/google/registry/beam/BUILD index 842dcf407..10890598a 100644 --- a/java/google/registry/beam/BUILD +++ b/java/google/registry/beam/BUILD @@ -8,6 +8,7 @@ java_library( name = "beam", srcs = glob(["*.java"]), deps = [ + "//java/google/registry/util", "@com_google_flogger", "@com_google_flogger_system_backend", "@com_google_guava", diff --git a/java/google/registry/beam/BeamUtils.java b/java/google/registry/beam/BeamUtils.java index faa603a81..583503075 100644 --- a/java/google/registry/beam/BeamUtils.java +++ b/java/google/registry/beam/BeamUtils.java @@ -17,6 +17,8 @@ package google.registry.beam; import com.google.common.base.Joiner; import com.google.common.collect.ImmutableList; import com.google.common.flogger.FluentLogger; +import com.google.common.io.Resources; +import google.registry.util.ResourceUtils; import org.apache.avro.generic.GenericRecord; import org.apache.beam.sdk.io.gcp.bigquery.SchemaAndRecord; @@ -54,4 +56,12 @@ public class BeamUtils { missingFieldList, record)); } } + + /** + * Returns the {@link String} contents for a file in the {@code sql/} directory relative to a + * class. + */ + public static String getQueryFromFile(Class clazz, String filename) { + return ResourceUtils.readResourceUtf8(Resources.getResource(clazz, "sql/" + filename)); + } } diff --git a/java/google/registry/beam/invoicing/InvoicingUtils.java b/java/google/registry/beam/invoicing/InvoicingUtils.java index 77c2a3c9c..07cabeb2d 100644 --- a/java/google/registry/beam/invoicing/InvoicingUtils.java +++ b/java/google/registry/beam/invoicing/InvoicingUtils.java @@ -14,8 +14,8 @@ package google.registry.beam.invoicing; -import com.google.common.io.Resources; -import google.registry.util.ResourceUtils; +import static google.registry.beam.BeamUtils.getQueryFromFile; + import google.registry.util.SqlTemplate; import java.time.LocalDateTime; import java.time.LocalTime; @@ -91,7 +91,7 @@ public class InvoicingUtils { LocalDateTime firstMoment = reportingMonth.atDay(1).atTime(LocalTime.MIDNIGHT); LocalDateTime lastMoment = reportingMonth.atEndOfMonth().atTime(LocalTime.MAX); // Construct the month's query by filling in the billing_events.sql template - return SqlTemplate.create(getQueryFromFile("billing_events.sql")) + return SqlTemplate.create(getQueryFromFile(InvoicingPipeline.class, "billing_events.sql")) .put("FIRST_TIMESTAMP_OF_MONTH", firstMoment.format(TIMESTAMP_FORMATTER)) .put("LAST_TIMESTAMP_OF_MONTH", lastMoment.format(TIMESTAMP_FORMATTER)) .put("PROJECT_ID", projectId) @@ -103,10 +103,4 @@ public class InvoicingUtils { .build(); }); } - - /** Returns the {@link String} contents for a file in the {@code beam/sql/} directory. */ - private static String getQueryFromFile(String filename) { - return ResourceUtils.readResourceUtf8( - Resources.getResource(InvoicingUtils.class, "sql/" + filename)); - } } diff --git a/java/google/registry/beam/spec11/BUILD b/java/google/registry/beam/spec11/BUILD index 99c55d0fd..eefeb4c59 100644 --- a/java/google/registry/beam/spec11/BUILD +++ b/java/google/registry/beam/spec11/BUILD @@ -11,6 +11,7 @@ java_library( deps = [ "//java/google/registry/beam", "//java/google/registry/config", + "//java/google/registry/util", "@com_google_auto_value", "@com_google_dagger", "@com_google_flogger", diff --git a/java/google/registry/beam/spec11/SafeBrowsingTransforms.java b/java/google/registry/beam/spec11/SafeBrowsingTransforms.java index 932fc0aca..e5b253d81 100644 --- a/java/google/registry/beam/spec11/SafeBrowsingTransforms.java +++ b/java/google/registry/beam/spec11/SafeBrowsingTransforms.java @@ -19,9 +19,10 @@ import static java.nio.charset.StandardCharsets.UTF_8; import static org.apache.http.HttpStatus.SC_OK; import com.google.common.annotations.VisibleForTesting; -import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; import com.google.common.flogger.FluentLogger; import com.google.common.io.CharStreams; +import google.registry.util.Retrier; import java.io.IOException; import java.io.InputStreamReader; import java.io.Serializable; @@ -62,7 +63,7 @@ public class SafeBrowsingTransforms { * * @see Lookup API */ - static class EvaluateSafeBrowsingFn extends DoFn> { + static class EvaluateSafeBrowsingFn extends DoFn> { /** * Max number of urls we can check in a single query. @@ -75,8 +76,8 @@ public class SafeBrowsingTransforms { private final ValueProvider apiKeyProvider; /** - * Maps a subdomain's HTTP URL to its corresponding {@link Subdomain} to facilitate batching - * SafeBrowsing API requests. + * Maps a subdomain's {@code fullyQualifiedDomainName} to its corresponding {@link Subdomain} to + * facilitate batching SafeBrowsing API requests. */ private final Map subdomainBuffer = new LinkedHashMap<>(BATCH_SIZE); @@ -88,6 +89,9 @@ public class SafeBrowsingTransforms { */ private final Supplier closeableHttpClientSupplier; + /** Retries on receiving transient failures such as {@link IOException}. */ + private final Retrier retrier; + /** * Constructs a {@link EvaluateSafeBrowsingFn} that gets its API key from the given provider. * @@ -99,8 +103,9 @@ public class SafeBrowsingTransforms { * @param apiKeyProvider provides the SafeBrowsing API key from {@code KMS} at runtime */ @SuppressWarnings("unchecked") - EvaluateSafeBrowsingFn(ValueProvider apiKeyProvider) { + EvaluateSafeBrowsingFn(ValueProvider apiKeyProvider, Retrier retrier) { this.apiKeyProvider = apiKeyProvider; + this.retrier = retrier; this.closeableHttpClientSupplier = (Supplier & Serializable) HttpClients::createDefault; } @@ -113,8 +118,11 @@ public class SafeBrowsingTransforms { @VisibleForTesting @SuppressWarnings("unchecked") EvaluateSafeBrowsingFn( - ValueProvider apiKeyProvider, Supplier clientSupplier) { + ValueProvider apiKeyProvider, + Retrier retrier, + Supplier clientSupplier) { this.apiKeyProvider = apiKeyProvider; + this.retrier = retrier; this.closeableHttpClientSupplier = clientSupplier; } @@ -122,7 +130,7 @@ public class SafeBrowsingTransforms { @FinishBundle public void finishBundle(FinishBundleContext context) { if (!subdomainBuffer.isEmpty()) { - ImmutableList> results = evaluateAndFlush(); + ImmutableSet> results = evaluateAndFlush(); results.forEach((kv) -> context.output(kv, Instant.now(), GlobalWindow.INSTANCE)); } } @@ -134,11 +142,9 @@ public class SafeBrowsingTransforms { @ProcessElement public void processElement(ProcessContext context) { Subdomain subdomain = context.element(); - // We put HTTP URLs into the buffer because the API requires specifying the protocol. - subdomainBuffer.put( - String.format("http://%s", subdomain.fullyQualifiedDomainName()), subdomain); + subdomainBuffer.put(subdomain.fullyQualifiedDomainName(), subdomain); if (subdomainBuffer.size() >= BATCH_SIZE) { - ImmutableList> results = evaluateAndFlush(); + ImmutableSet> results = evaluateAndFlush(); results.forEach(context::output); } } @@ -149,8 +155,8 @@ public class SafeBrowsingTransforms { * *

If a {@link Subdomain} is safe according to the API, it will not emit a report. */ - private ImmutableList> evaluateAndFlush() { - ImmutableList.Builder> resultBuilder = new ImmutableList.Builder<>(); + private ImmutableSet> evaluateAndFlush() { + ImmutableSet.Builder> resultBuilder = new ImmutableSet.Builder<>(); try { URIBuilder uriBuilder = new URIBuilder(SAFE_BROWSING_URL); // Add the API key param @@ -161,17 +167,18 @@ public class SafeBrowsingTransforms { JSONObject requestBody = createRequestBody(); httpPost.setEntity(new ByteArrayEntity(requestBody.toString().getBytes(UTF_8))); - - try (CloseableHttpClient client = closeableHttpClientSupplier.get(); - CloseableHttpResponse response = client.execute(httpPost)) { - processResponse(response, resultBuilder); - } - } catch (URISyntaxException | JSONException e) { - // TODO(b/112354588): also send an alert e-mail to indicate the pipeline failed - logger.atSevere().withCause(e).log( - "Caught parsing error during execution, skipping batch."); - } catch (IOException e) { - logger.atSevere().withCause(e).log("Caught IOException during processing, skipping batch."); + // Retry transient exceptions such as IOException + retrier.callWithRetry( + () -> { + try (CloseableHttpClient client = closeableHttpClientSupplier.get(); + CloseableHttpResponse response = client.execute(httpPost)) { + processResponse(response, resultBuilder); + } + }, + IOException.class); + } catch (URISyntaxException | JSONException e) { + // Fail the pipeline on a parsing exception- this indicates the API likely changed. + throw new RuntimeException("Caught parsing exception, failing pipeline.", e); } finally { // Flush the buffer subdomainBuffer.clear(); @@ -206,12 +213,13 @@ public class SafeBrowsingTransforms { } /** - * Iterates through all threat matches in the API response and adds them to the resultBuilder. + * Iterates through all threat matches in the API response and adds them to the {@code + * resultBuilder}. */ private void processResponse( - CloseableHttpResponse response, ImmutableList.Builder> resultBuilder) + CloseableHttpResponse response, + ImmutableSet.Builder> resultBuilder) throws JSONException, IOException { - int statusCode = response.getStatusLine().getStatusCode(); if (statusCode != SC_OK) { logger.atWarning().log("Got unexpected status code %s from response", statusCode); @@ -230,7 +238,9 @@ public class SafeBrowsingTransforms { for (int i = 0; i < threatMatches.length(); i++) { JSONObject match = threatMatches.getJSONObject(i); String url = match.getJSONObject("threat").getString("url"); - resultBuilder.add(KV.of(subdomainBuffer.get(url), match.toString())); + Subdomain subdomain = subdomainBuffer.get(url); + resultBuilder.add( + KV.of(subdomain, ThreatMatch.create(match, subdomain.fullyQualifiedDomainName()))); } } } diff --git a/java/google/registry/beam/spec11/Spec11Pipeline.java b/java/google/registry/beam/spec11/Spec11Pipeline.java index 80f3a9366..60ee35bb4 100644 --- a/java/google/registry/beam/spec11/Spec11Pipeline.java +++ b/java/google/registry/beam/spec11/Spec11Pipeline.java @@ -14,8 +14,12 @@ package google.registry.beam.spec11; +import static google.registry.beam.BeamUtils.getQueryFromFile; + import google.registry.beam.spec11.SafeBrowsingTransforms.EvaluateSafeBrowsingFn; import google.registry.config.RegistryConfig.Config; +import google.registry.util.Retrier; +import google.registry.util.SqlTemplate; import java.io.Serializable; import javax.inject.Inject; import org.apache.beam.runners.dataflow.DataflowRunner; @@ -27,10 +31,17 @@ import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO; import org.apache.beam.sdk.options.Description; import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.options.ValueProvider; +import org.apache.beam.sdk.options.ValueProvider.NestedValueProvider; +import org.apache.beam.sdk.transforms.GroupByKey; +import org.apache.beam.sdk.transforms.MapElements; import org.apache.beam.sdk.transforms.ParDo; -import org.apache.beam.sdk.transforms.Sample; -import org.apache.beam.sdk.transforms.ToString; +import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.TypeDescriptor; +import org.apache.beam.sdk.values.TypeDescriptors; +import org.json.JSONArray; +import org.json.JSONException; +import org.json.JSONObject; /** * Definition of a Dataflow pipeline template, which generates a given month's spec11 report. @@ -60,6 +71,9 @@ public class Spec11Pipeline implements Serializable { @Config("spec11BucketUrl") String spec11BucketUrl; + @Inject + Retrier retrier; + @Inject Spec11Pipeline() {} @@ -106,13 +120,21 @@ public class Spec11Pipeline implements Serializable { "Read active domains from BigQuery", BigQueryIO.read(Subdomain::parseFromRecord) .fromQuery( - // This query must be customized for your own use. - "SELECT * FROM YOUR_TABLE_HERE") + SqlTemplate.create(getQueryFromFile(Spec11Pipeline.class, "subdomains.sql")) + .put("PROJECT_ID", projectId) + .put("DATASTORE_EXPORT_DATASET", "latest_datastore_export") + .put("REGISTRAR_TABLE", "Registrar") + .put("DOMAIN_BASE_TABLE", "DomainBase") + .build()) .withCoder(SerializableCoder.of(Subdomain.class)) .usingStandardSql() .withoutValidation() .withTemplateCompatibility()); - evaluateUrlHealth(domains, new EvaluateSafeBrowsingFn(options.getSafeBrowsingApiKey())); + + evaluateUrlHealth( + domains, + new EvaluateSafeBrowsingFn(options.getSafeBrowsingApiKey(), retrier), + options.getYearMonth()); p.run(); } @@ -122,21 +144,51 @@ public class Spec11Pipeline implements Serializable { *

This is factored out to facilitate testing. */ void evaluateUrlHealth( - PCollection domains, EvaluateSafeBrowsingFn evaluateSafeBrowsingFn) { + PCollection domains, + EvaluateSafeBrowsingFn evaluateSafeBrowsingFn, + ValueProvider yearMonthProvider) { domains - // TODO(b/111545355): Remove this limiter once we're confident we won't go over quota. - .apply( - "Get just a few representative samples for now, don't want to overwhelm our quota", - Sample.any(1000)) .apply("Run through SafeBrowsingAPI", ParDo.of(evaluateSafeBrowsingFn)) - .apply("Convert results to string", ToString.elements()) + .apply( + "Map registrar e-mail to ThreatMatch", + MapElements.into( + TypeDescriptors.kvs( + TypeDescriptors.strings(), TypeDescriptor.of(ThreatMatch.class))) + .via( + (KV kv) -> + KV.of(kv.getKey().registrarEmailAddress(), kv.getValue()))) + .apply("Group by registrar email address", GroupByKey.create()) + .apply( + "Convert results to JSON format", + MapElements.into(TypeDescriptors.strings()) + .via( + (KV> kv) -> { + JSONObject output = new JSONObject(); + try { + output.put("registrarEmailAddress", kv.getKey()); + JSONArray threatMatches = new JSONArray(); + for (ThreatMatch match : kv.getValue()) { + threatMatches.put(match.toJSON()); + } + output.put("threatMatches", threatMatches); + return output.toString(); + } catch (JSONException e) { + throw new RuntimeException( + String.format( + "Encountered an error constructing the JSON for %s", kv.toString()), + e); + } + })) .apply( "Output to text file", TextIO.write() - // TODO(b/111545355): Replace this with a templated directory based on yearMonth - .to(spec11BucketUrl) + .to( + NestedValueProvider.of( + yearMonthProvider, + yearMonth -> + String.format( + "%s/%s/%s-monthly-report", spec11BucketUrl, yearMonth, yearMonth))) .withoutSharding() - .withHeader("HELLO WORLD")); + .withHeader("Map from registrar email to detected subdomain threats:")); } - } diff --git a/java/google/registry/beam/spec11/Subdomain.java b/java/google/registry/beam/spec11/Subdomain.java index 367401c0b..4a436592f 100644 --- a/java/google/registry/beam/spec11/Subdomain.java +++ b/java/google/registry/beam/spec11/Subdomain.java @@ -22,9 +22,6 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import com.google.common.flogger.FluentLogger; import java.io.Serializable; -import java.time.Instant; -import java.time.ZoneId; -import java.time.ZonedDateTime; import org.apache.avro.generic.GenericRecord; import org.apache.beam.sdk.io.gcp.bigquery.SchemaAndRecord; @@ -42,14 +39,14 @@ public abstract class Subdomain implements Serializable { private static final FluentLogger logger = FluentLogger.forEnclosingClass(); private static final ImmutableList FIELD_NAMES = - ImmutableList.of("fullyQualifiedDomainName", "statuses", "creationTime"); + ImmutableList.of("fullyQualifiedDomainName", "registrarName", "registrarEmailAddress"); /** Returns the fully qualified domain name. */ abstract String fullyQualifiedDomainName(); - /** Returns the UTC DateTime this domain was created. */ - abstract ZonedDateTime creationTime(); - /** Returns the space-delimited list of statuses on this domain. */ - abstract String statuses(); + /** Returns the name of the associated registrar for this domain. */ + abstract String registrarName(); + /** Returns the email address of the registrar associated with this domain. */ + abstract String registrarEmailAddress(); /** * Constructs a {@link Subdomain} from an Apache Avro {@code SchemaAndRecord}. @@ -63,10 +60,8 @@ public abstract class Subdomain implements Serializable { GenericRecord record = schemaAndRecord.getRecord(); return create( extractField(record, "fullyQualifiedDomainName"), - // Bigquery provides UNIX timestamps with microsecond precision. - Instant.ofEpochMilli(Long.parseLong(extractField(record, "creationTime")) / 1000) - .atZone(ZoneId.of("UTC")), - extractField(record, "statuses")); + extractField(record, "registrarName"), + extractField(record, "registrarEmailAddress")); } /** @@ -77,8 +72,8 @@ public abstract class Subdomain implements Serializable { */ @VisibleForTesting static Subdomain create( - String fullyQualifiedDomainName, ZonedDateTime creationTime, String statuses) { - return new AutoValue_Subdomain(fullyQualifiedDomainName, creationTime, statuses); + String fullyQualifiedDomainName, String registrarName, String registrarEmailAddress) { + return new AutoValue_Subdomain(fullyQualifiedDomainName, registrarName, registrarEmailAddress); } } diff --git a/java/google/registry/beam/spec11/ThreatMatch.java b/java/google/registry/beam/spec11/ThreatMatch.java new file mode 100644 index 000000000..49c420089 --- /dev/null +++ b/java/google/registry/beam/spec11/ThreatMatch.java @@ -0,0 +1,72 @@ +// Copyright 2018 The Nomulus Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package google.registry.beam.spec11; + +import com.google.auto.value.AutoValue; +import java.io.Serializable; +import org.json.JSONException; +import org.json.JSONObject; + +/** A POJO representing a threat match response from the {@code SafeBrowsing API}. */ +@AutoValue +public abstract class ThreatMatch implements Serializable { + + private static final String THREAT_TYPE_FIELD = "threatType"; + private static final String PLATFORM_TYPE_FIELD = "platformType"; + private static final String METADATA_FIELD = "threatEntryMetadata"; + + /** Returns what kind of threat it is (malware, phishing etc.) */ + abstract String threatType(); + /** Returns what platforms it affects (Windows, Linux etc.) */ + abstract String platformType(); + /** + * Returns a String representing a JSON Object containing arbitrary metadata associated with this + * threat, or "NONE" if there is no metadata to retrieve. + * + *

This ideally would be a {@link JSONObject} type, but can't be due to serialization + * requirements. + */ + abstract String metadata(); + /** Returns the fully qualified domain name [SLD].[TLD] of the matched threat. */ + abstract String fullyQualifiedDomainName(); + + /** + * Constructs a {@link ThreatMatch} by parsing a {@code SafeBrowsing API} response {@link + * JSONObject}. + * + * @throws JSONException when encountering parse errors in the response format + */ + static ThreatMatch create(JSONObject threatMatchJSON, String fullyQualifiedDomainName) + throws JSONException { + return new AutoValue_ThreatMatch( + threatMatchJSON.getString(THREAT_TYPE_FIELD), + threatMatchJSON.getString(PLATFORM_TYPE_FIELD), + threatMatchJSON.has(METADATA_FIELD) + ? threatMatchJSON.getJSONObject(METADATA_FIELD).toString() + : "NONE", + fullyQualifiedDomainName); + } + + /** Returns a {@link String} containing the simplest details about this threat. */ + String getSimpleDetails() { + return String.format("%s;%s", this.fullyQualifiedDomainName(), this.threatType()); + } + /** Returns a {@link JSONObject} representing a subset of this object's data. */ + JSONObject toJSON() throws JSONException { + return new JSONObject() + .put("fullyQualifiedDomainName", fullyQualifiedDomainName()) + .put("threatType", threatType()); + } +} diff --git a/java/google/registry/beam/spec11/sql/subdomains.sql b/java/google/registry/beam/spec11/sql/subdomains.sql new file mode 100644 index 000000000..46b51e30a --- /dev/null +++ b/java/google/registry/beam/spec11/sql/subdomains.sql @@ -0,0 +1,49 @@ +#standardSQL + -- Copyright 2018 The Nomulus Authors. All Rights Reserved. + -- + -- Licensed under the Apache License, Version 2.0 (the "License"); + -- you may not use this file except in compliance with the License. + -- You may obtain a copy of the License at + -- + -- http://www.apache.org/licenses/LICENSE-2.0 + -- + -- Unless required by applicable law or agreed to in writing, software + -- distributed under the License is distributed on an "AS IS" BASIS, + -- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + -- See the License for the specific language governing permissions and + -- limitations under the License. + + -- This query gathers all Subdomains active within a given yearMonth + -- and emits a row containing its fully qualified domain name + -- [SLD].[TLD], the current registrar's name, and the current registrar's + -- email address. + +SELECT + domain.fullyQualifiedDomainName AS fullyQualifiedDomainName, + registrar.name AS registrarName, + registrar.emailAddress AS registrarEmailAddress +FROM ( ( + SELECT + fullyQualifiedDomainName, + currentSponsorClientId, + creationTime + FROM + `%PROJECT_ID%.%DATASTORE_EXPORT_DATASET%.%DOMAIN_BASE_TABLE%` + WHERE + -- Only include active registrations + -- Registrations that are active (not deleted) will have null deletionTime + -- because END_OF_TIME is an invalid timestamp in standardSQL + (SAFE_CAST(deletionTime AS STRING) IS NULL + OR deletionTime > CURRENT_TIMESTAMP)) AS domain + JOIN ( + SELECT + __key__.name AS name, + emailAddress + FROM + `%PROJECT_ID%.%DATASTORE_EXPORT_DATASET%.%REGISTRAR_TABLE%` + WHERE + type = 'REAL') AS registrar + ON + domain.currentSponsorClientId = registrar.name) +ORDER BY + creationTime DESC diff --git a/javatests/google/registry/beam/spec11/BUILD b/javatests/google/registry/beam/spec11/BUILD index 2708100f0..539a9fff5 100644 --- a/javatests/google/registry/beam/spec11/BUILD +++ b/javatests/google/registry/beam/spec11/BUILD @@ -26,6 +26,7 @@ java_library( "@org_apache_beam_sdks_java_io_google_cloud_platform", "@org_apache_httpcomponents_httpclient", "@org_apache_httpcomponents_httpcore", + "@org_json", "@org_mockito_all", ], ) diff --git a/javatests/google/registry/beam/spec11/Spec11PipelineTest.java b/javatests/google/registry/beam/spec11/Spec11PipelineTest.java index 01aa048f3..639f09b7b 100644 --- a/javatests/google/registry/beam/spec11/Spec11PipelineTest.java +++ b/javatests/google/registry/beam/spec11/Spec11PipelineTest.java @@ -24,7 +24,10 @@ import static org.mockito.Mockito.withSettings; import com.google.common.collect.ImmutableList; import com.google.common.io.CharStreams; import google.registry.beam.spec11.SafeBrowsingTransforms.EvaluateSafeBrowsingFn; +import google.registry.testing.FakeClock; +import google.registry.testing.FakeSleeper; import google.registry.util.ResourceUtils; +import google.registry.util.Retrier; import java.io.ByteArrayInputStream; import java.io.File; import java.io.IOException; @@ -32,8 +35,7 @@ import java.io.InputStreamReader; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; import java.io.Serializable; -import java.time.ZoneId; -import java.time.ZonedDateTime; +import java.util.Comparator; import java.util.function.Supplier; import org.apache.beam.runners.direct.DirectRunner; import org.apache.beam.sdk.options.PipelineOptions; @@ -48,6 +50,9 @@ import org.apache.http.client.methods.HttpPost; import org.apache.http.entity.BasicHttpEntity; import org.apache.http.impl.client.CloseableHttpClient; import org.apache.http.message.BasicStatusLine; +import org.json.JSONArray; +import org.json.JSONException; +import org.json.JSONObject; import org.junit.Before; import org.junit.BeforeClass; import org.junit.Rule; @@ -55,6 +60,7 @@ import org.junit.Test; import org.junit.rules.TemporaryFolder; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; /** Unit tests for {@link Spec11Pipeline}. */ @@ -78,21 +84,26 @@ public class Spec11PipelineTest { public void initializePipeline() throws IOException { spec11Pipeline = new Spec11Pipeline(); spec11Pipeline.projectId = "test-project"; - spec11Pipeline.spec11BucketUrl = tempFolder.getRoot().getAbsolutePath() + "/results"; + spec11Pipeline.spec11BucketUrl = tempFolder.getRoot().getAbsolutePath(); File beamTempFolder = tempFolder.newFolder(); spec11Pipeline.beamStagingUrl = beamTempFolder.getAbsolutePath() + "/staging"; spec11Pipeline.spec11TemplateUrl = beamTempFolder.getAbsolutePath() + "/templates/invoicing"; } + private static final ImmutableList BAD_DOMAINS = + ImmutableList.of("111.com", "222.com", "444.com"); + private ImmutableList getInputDomains() { ImmutableList.Builder subdomainsBuilder = new ImmutableList.Builder<>(); - // Put in 2 batches worth (490 < max < 490*2) to get one positive and one negative example. - for (int i = 0; i < 510; i++) { + // Put in at least 2 batches worth (x > 490) to guarantee multiple executions. + // Put in half for theRegistrar and half for someRegistrar + for (int i = 0; i < 255; i++) { subdomainsBuilder.add( - Subdomain.create( - String.format("%s.com", i), - ZonedDateTime.of(2017, 9, 29, 0, 0, 0, 0, ZoneId.of("UTC")), - "OK")); + Subdomain.create(String.format("%s.com", i), "theRegistrar", "fake@theRegistrar.com")); + } + for (int i = 255; i < 510; i++) { + subdomainsBuilder.add( + Subdomain.create(String.format("%s.com", i), "someRegistrar", "fake@someRegistrar.com")); } return subdomainsBuilder.build(); } @@ -109,75 +120,124 @@ public class Spec11PipelineTest { // Establish mocks for testing ImmutableList inputRows = getInputDomains(); CloseableHttpClient httpClient = mock(CloseableHttpClient.class, withSettings().serializable()); - CloseableHttpResponse negativeResponse = - mock(CloseableHttpResponse.class, withSettings().serializable()); - CloseableHttpResponse positiveResponse = - mock(CloseableHttpResponse.class, withSettings().serializable()); - // Tailor the fake API's response based on whether or not it contains the "bad url" 111.com - when(httpClient.execute(any(HttpPost.class))) - .thenAnswer( - (Answer & Serializable) - (i) -> { - String request = - CharStreams.toString( - new InputStreamReader( - ((HttpPost) i.getArguments()[0]).getEntity().getContent(), UTF_8)); - if (request.contains("http://111.com")) { - return positiveResponse; - } else { - return negativeResponse; - } - }); - when(negativeResponse.getStatusLine()) - .thenReturn(new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 200, "Done")); - when(negativeResponse.getEntity()).thenReturn(new FakeHttpEntity("{}")); - when(positiveResponse.getStatusLine()) - .thenReturn(new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 200, "Done")); - when(positiveResponse.getEntity()) - .thenReturn(new FakeHttpEntity(getBadUrlMatch("http://111.com"))); + // Return a mock HttpResponse that returns a JSON response based on the request. + when(httpClient.execute(any(HttpPost.class))).thenAnswer(new HttpResponder()); + EvaluateSafeBrowsingFn evalFn = new EvaluateSafeBrowsingFn( - StaticValueProvider.of("apikey"), (Serializable & Supplier) () -> httpClient); + StaticValueProvider.of("apikey"), + new Retrier(new FakeSleeper(new FakeClock()), 3), + (Serializable & Supplier) () -> httpClient); // Apply input and evaluation transforms PCollection input = p.apply(Create.of(inputRows)); - spec11Pipeline.evaluateUrlHealth(input, evalFn); + spec11Pipeline.evaluateUrlHealth(input, evalFn, StaticValueProvider.of("2018-06")); p.run(); - // Verify output of text file + // Verify header and 3 threat matches for 2 registrars are found ImmutableList generatedReport = resultFileContents(); - // TODO(b/80524726): Rigorously test this output once the pipeline output is finalized. - assertThat(generatedReport).hasSize(2); - assertThat(generatedReport.get(1)).contains("http://111.com"); + assertThat(generatedReport).hasSize(3); + assertThat(generatedReport.get(0)) + .isEqualTo("Map from registrar email to detected subdomain threats:"); + // The output file can put the registrar emails and bad URLs in any order. + // So we sort by length (sorry) to put the shorter JSON first. + ImmutableList sortedLines = + generatedReport + .subList(1, 3) + .stream() + .sorted(Comparator.comparingInt(String::length)) + .collect(ImmutableList.toImmutableList()); + + JSONObject someRegistrarJSON = new JSONObject(sortedLines.get(0)); + assertThat(someRegistrarJSON.get("registrarEmailAddress")).isEqualTo("fake@someRegistrar.com"); + assertThat(someRegistrarJSON.has("threatMatches")).isTrue(); + JSONArray someThreatMatch = someRegistrarJSON.getJSONArray("threatMatches"); + assertThat(someThreatMatch.length()).isEqualTo(1); + assertThat(someThreatMatch.getJSONObject(0).get("fullyQualifiedDomainName")) + .isEqualTo("444.com"); + assertThat(someThreatMatch.getJSONObject(0).get("threatType")) + .isEqualTo("MALWARE"); + + // theRegistrar has two ThreatMatches, we have to parse it explicitly + JSONObject theRegistrarJSON = new JSONObject(sortedLines.get(1)); + assertThat(theRegistrarJSON.get("registrarEmailAddress")).isEqualTo("fake@theRegistrar.com"); + assertThat(theRegistrarJSON.has("threatMatches")).isTrue(); + JSONArray theThreatMatches = theRegistrarJSON.getJSONArray("threatMatches"); + assertThat(theThreatMatches.length()).isEqualTo(2); + ImmutableList threatMatchStrings = + ImmutableList.of( + theThreatMatches.getJSONObject(0).toString(), + theThreatMatches.getJSONObject(1).toString()); + assertThat(threatMatchStrings) + .containsExactly( + new JSONObject() + .put("fullyQualifiedDomainName", "111.com") + .put("threatType", "MALWARE") + .toString(), + new JSONObject() + .put("fullyQualifiedDomainName", "222.com") + .put("threatType", "MALWARE") + .toString()); } - /** Returns the text contents of a file under the beamBucket/results directory. */ - private ImmutableList resultFileContents() throws Exception { - File resultFile = new File(String.format("%s/results", tempFolder.getRoot().getAbsolutePath())); - return ImmutableList.copyOf( - ResourceUtils.readResourceUtf8(resultFile.toURI().toURL()).split("\n")); + /** + * A serializable {@link Answer} that returns a mock HTTP response based on the HTTP request's + * content. + */ + private static class HttpResponder implements Answer, Serializable { + @Override + public CloseableHttpResponse answer(InvocationOnMock invocation) throws Throwable { + return getMockResponse( + CharStreams.toString( + new InputStreamReader( + ((HttpPost) invocation.getArguments()[0]).getEntity().getContent(), UTF_8))); + } } - /** Returns a filled-in template for threat detected at a given url. */ - private static String getBadUrlMatch(String url) { - return "{\n" - + " \"matches\": [{\n" - + " \"threatType\": \"MALWARE\",\n" - + " \"platformType\": \"WINDOWS\",\n" - + " \"threatEntryType\": \"URL\",\n" - + String.format(" \"threat\": {\"url\": \"%s\"},\n", url) - + " \"threatEntryMetadata\": {\n" - + " \"entries\": [{\n" - + " \"key\": \"malware_threat_type\",\n" - + " \"value\": \"landing\"\n" - + " }]\n" - + " },\n" - + " \"cacheDuration\": \"300.000s\"\n" - + " }," - + "]\n" - + "}"; + /** + * Returns a {@link CloseableHttpResponse} containing either positive (threat found) or negative + * (no threat) API examples based on the request data. + */ + private static CloseableHttpResponse getMockResponse(String request) throws JSONException { + // Determine which bad URLs are in the request (if any) + ImmutableList badUrls = + BAD_DOMAINS.stream().filter(request::contains).collect(ImmutableList.toImmutableList()); + + CloseableHttpResponse httpResponse = + mock(CloseableHttpResponse.class, withSettings().serializable()); + when(httpResponse.getStatusLine()) + .thenReturn( + new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 200, "Done")); + when(httpResponse.getEntity()) + .thenReturn(new FakeHttpEntity(getAPIResponse(badUrls))); + return httpResponse; + } + + /** + * Returns the expected API response for a list of bad URLs. + * + *

If there are no badUrls in the list, this returns the empty JSON string "{}". + */ + private static String getAPIResponse(ImmutableList badUrls) throws JSONException { + JSONObject response = new JSONObject(); + if (badUrls.isEmpty()) { + return response.toString(); + } + // Create a threatMatch for each badUrl + JSONArray matches = new JSONArray(); + for (String badUrl : badUrls) { + matches.put( + new JSONObject() + .put("threatType", "MALWARE") + .put("platformType", "WINDOWS") + .put("threatEntryType", "URL") + .put("threat", new JSONObject().put("url", badUrl)) + .put("cacheDuration", "300.000s")); + } + response.put("matches", matches); + return response.toString(); } /** A serializable HttpEntity fake that returns {@link String} content. */ @@ -191,6 +251,12 @@ public class Spec11PipelineTest { oos.defaultWriteObject(); } + /** + * Sets the {@link FakeHttpEntity} content upon deserialization. + * + *

This allows us to use {@link #getContent()} as-is, fully emulating the behavior of {@link + * BasicHttpEntity} regardless of serialization. + */ private void readObject(ObjectInputStream ois) throws IOException, ClassNotFoundException { ois.defaultReadObject(); super.setContent(new ByteArrayInputStream(this.content.getBytes(UTF_8))); @@ -198,6 +264,18 @@ public class Spec11PipelineTest { FakeHttpEntity(String content) { this.content = content; + super.setContent(new ByteArrayInputStream(this.content.getBytes(UTF_8))); } } + + /** Returns the text contents of a file under the beamBucket/results directory. */ + private ImmutableList resultFileContents() throws Exception { + File resultFile = + new File( + String.format( + "%s/2018-06/2018-06-monthly-report", tempFolder.getRoot().getAbsolutePath())); + return ImmutableList.copyOf( + ResourceUtils.readResourceUtf8(resultFile.toURI().toURL()).split("\n")); + } + }