Prevent saving duplicate rows in spec11 pipeline (#1810)

* Prevent saving duplicate rows in spec11 pipeline

* Chain applies together
This commit is contained in:
sarahcaseybot 2022-12-15 15:51:28 -05:00 committed by GitHub
parent b45a59c892
commit f9d1945787
3 changed files with 35 additions and 21 deletions

View file

@ -26,6 +26,7 @@ import google.registry.beam.common.RegistryJpaIO;
import google.registry.beam.common.RegistryJpaIO.Read; import google.registry.beam.common.RegistryJpaIO.Read;
import google.registry.beam.spec11.SafeBrowsingTransforms.EvaluateSafeBrowsingFn; import google.registry.beam.spec11.SafeBrowsingTransforms.EvaluateSafeBrowsingFn;
import google.registry.config.RegistryConfig.ConfigModule; import google.registry.config.RegistryConfig.ConfigModule;
import google.registry.model.IdService;
import google.registry.model.domain.Domain; import google.registry.model.domain.Domain;
import google.registry.model.reporting.Spec11ThreatMatch; import google.registry.model.reporting.Spec11ThreatMatch;
import google.registry.model.reporting.Spec11ThreatMatch.ThreatType; import google.registry.model.reporting.Spec11ThreatMatch.ThreatType;
@ -45,6 +46,7 @@ import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.GroupByKey; import org.apache.beam.sdk.transforms.GroupByKey;
import org.apache.beam.sdk.transforms.MapElements; import org.apache.beam.sdk.transforms.MapElements;
import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.Reshuffle;
import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.sdk.values.TypeDescriptor;
@ -154,25 +156,36 @@ public class Spec11Pipeline implements Serializable {
static void saveToSql( static void saveToSql(
PCollection<KV<DomainNameInfo, ThreatMatch>> threatMatches, Spec11PipelineOptions options) { PCollection<KV<DomainNameInfo, ThreatMatch>> threatMatches, Spec11PipelineOptions options) {
String transformId = "Spec11 Threat Matches";
LocalDate date = LocalDate.parse(options.getDate(), ISODateTimeFormat.date()); LocalDate date = LocalDate.parse(options.getDate(), ISODateTimeFormat.date());
threatMatches.apply( String transformId = "Spec11 Threat Matches";
"Write to Sql: " + transformId, threatMatches
RegistryJpaIO.<KV<DomainNameInfo, ThreatMatch>>write() .apply(
.withName(transformId) "Construct objects",
.withBatchSize(options.getSqlWriteBatchSize()) ParDo.of(
.withJpaConverter( new DoFn<KV<DomainNameInfo, ThreatMatch>, Spec11ThreatMatch>() {
(kv) -> { @ProcessElement
DomainNameInfo domainNameInfo = kv.getKey(); public void processElement(
return new Spec11ThreatMatch.Builder() @Element KV<DomainNameInfo, ThreatMatch> input,
OutputReceiver<Spec11ThreatMatch> output) {
Spec11ThreatMatch spec11ThreatMatch =
new Spec11ThreatMatch.Builder()
.setThreatTypes( .setThreatTypes(
ImmutableSet.of(ThreatType.valueOf(kv.getValue().threatType()))) ImmutableSet.of(ThreatType.valueOf(input.getValue().threatType())))
.setCheckDate(date) .setCheckDate(date)
.setDomainName(domainNameInfo.domainName()) .setDomainName(input.getKey().domainName())
.setDomainRepoId(domainNameInfo.domainRepoId()) .setDomainRepoId(input.getKey().domainRepoId())
.setRegistrarId(domainNameInfo.registrarId()) .setRegistrarId(input.getKey().registrarId())
.setId(IdService.allocateId())
.build(); .build();
})); output.output(spec11ThreatMatch);
}
}))
.apply("Prevent Fusing", Reshuffle.viaRandomKey())
.apply(
"Write to Sql: " + transformId,
RegistryJpaIO.<Spec11ThreatMatch>write()
.withName(transformId)
.withBatchSize(options.getSqlWriteBatchSize()));
} }
static void saveToGcs( static void saveToGcs(

View file

@ -22,6 +22,7 @@ import com.google.common.collect.ImmutableSet;
import google.registry.model.Buildable; import google.registry.model.Buildable;
import google.registry.model.ImmutableObject; import google.registry.model.ImmutableObject;
import google.registry.util.DomainNameUtils; import google.registry.util.DomainNameUtils;
import java.io.Serializable;
import java.util.Set; import java.util.Set;
import javax.persistence.Column; import javax.persistence.Column;
import javax.persistence.Entity; import javax.persistence.Entity;
@ -39,7 +40,7 @@ import org.joda.time.LocalDate;
@Index(name = "spec11threatmatch_tld_idx", columnList = "tld"), @Index(name = "spec11threatmatch_tld_idx", columnList = "tld"),
@Index(name = "spec11threatmatch_check_date_idx", columnList = "checkDate") @Index(name = "spec11threatmatch_check_date_idx", columnList = "checkDate")
}) })
public class Spec11ThreatMatch extends ImmutableObject implements Buildable { public class Spec11ThreatMatch extends ImmutableObject implements Buildable, Serializable {
/** The type of threat detected. */ /** The type of threat detected. */
public enum ThreatType { public enum ThreatType {

View file

@ -280,9 +280,9 @@ class Spec11PipelineTest {
private void verifySaveToCloudSql() { private void verifySaveToCloudSql() {
tm().transact( tm().transact(
() -> { () -> {
ImmutableList<Spec11ThreatMatch> sqlThreatMatches = ImmutableList<Spec11ThreatMatch> spec11ThreatMatches =
Spec11ThreatMatchDao.loadEntriesByDate(tm(), new LocalDate(2020, 1, 27)); Spec11ThreatMatchDao.loadEntriesByDate(tm(), new LocalDate(2020, 1, 27));
assertThat(sqlThreatMatches) assertThat(spec11ThreatMatches)
.comparingElementsUsing(immutableObjectCorrespondence("id")) .comparingElementsUsing(immutableObjectCorrespondence("id"))
.containsExactlyElementsIn(sqlThreatMatches); .containsExactlyElementsIn(sqlThreatMatches);
}); });