Remove deleted BSA labels from database (#2286)

Fixed the bug that retains deleted BSA labels in the database.

Added a few simple end-to-end tests for BSA download.
This commit is contained in:
Weimin Yu 2024-01-12 14:20:56 -05:00 committed by GitHub
parent 036d35c11a
commit 9273d2bf15
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 241 additions and 23 deletions

View file

@ -206,9 +206,7 @@ class BsaDiffCreator {
BlockLabel.of(
entry.getKey(),
LabelType.NEW_ORDER_ASSOCIATION,
idnChecker.getAllValidIdns(entry.getKey()).stream()
.map(IdnTableEnum::name)
.collect(toImmutableSet()))),
getAllValidIdnNames(entry.getKey()))),
newAndRemaining.asMap().entrySet().stream()
.filter(e -> e.getValue().size() > 1 || !e.getValue().contains(ORDER_ID_SENTINEL))
.filter(entry -> !entry.getValue().contains(ORDER_ID_SENTINEL))
@ -217,13 +215,17 @@ class BsaDiffCreator {
BlockLabel.of(
entry.getKey(),
LabelType.CREATE,
idnChecker.getAllValidIdns(entry.getKey()).stream()
.map(IdnTableEnum::name)
.collect(toImmutableSet()))),
getAllValidIdnNames(entry.getKey()))),
Sets.difference(deleted.keySet(), newAndRemaining.keySet()).stream()
.map(label -> BlockLabel.of(label, LabelType.DELETE, ImmutableSet.of())))
.map(label -> BlockLabel.of(label, LabelType.DELETE, getAllValidIdnNames(label))))
.flatMap(x -> x);
}
ImmutableSet<String> getAllValidIdnNames(String label) {
return idnChecker.getAllValidIdns(label).stream()
.map(IdnTableEnum::name)
.collect(toImmutableSet());
}
}
static class Canonicals<T> {

View file

@ -162,6 +162,7 @@ public class BsaDownloadAction implements Runnable {
// Fall through
case MAKE_ORDER_AND_LABEL_DIFF:
diff = diffCreator.createDiff(schedule, lazyIdnChecker.get());
// TODO(weiminyu): log the diff stats
gcsClient.writeOrderDiffs(schedule.jobName(), diff.getOrders());
gcsClient.writeLabelDiffs(schedule.jobName(), diff.getLabels());
schedule.updateJobStage(DownloadStage.APPLY_ORDER_AND_LABEL_DIFF);

View file

@ -128,6 +128,7 @@ class BsaDiffCreatorTest {
@Test
void allRemoved() {
when(idnChecker.getAllValidIdns(anyString())).thenReturn(ImmutableSet.of(IdnTableEnum.JA));
when(gcsClient.readBlockList("first", BlockListType.BLOCK))
.thenReturn(Stream.of("domainLabel,orderIDs", "test1,1;2", "test2,3", "test3,1;4"));
when(gcsClient.readBlockList("second", BlockListType.BLOCK)).thenReturn(Stream.of());
@ -140,9 +141,9 @@ class BsaDiffCreatorTest {
BsaDiff diff = diffCreator.createDiff(schedule, idnChecker);
assertThat(diff.getLabels())
.containsExactly(
BlockLabel.of("test1", LabelType.DELETE, ImmutableSet.of()),
BlockLabel.of("test2", LabelType.DELETE, ImmutableSet.of()),
BlockLabel.of("test3", LabelType.DELETE, ImmutableSet.of()));
BlockLabel.of("test1", LabelType.DELETE, ImmutableSet.of("JA")),
BlockLabel.of("test2", LabelType.DELETE, ImmutableSet.of("JA")),
BlockLabel.of("test3", LabelType.DELETE, ImmutableSet.of("JA")));
assertThat(diff.getOrders())
.containsExactly(
BlockOrder.of(1, OrderType.DELETE),
@ -227,6 +228,7 @@ class BsaDiffCreatorTest {
@Test
void removeLabelAndOrder() {
when(idnChecker.getAllValidIdns(anyString())).thenReturn(ImmutableSet.of(IdnTableEnum.JA));
when(gcsClient.readBlockList("first", BlockListType.BLOCK))
.thenReturn(Stream.of("domainLabel,orderIDs", "test1,1;2", "test2,3", "test3,1;4"));
when(gcsClient.readBlockList("second", BlockListType.BLOCK))
@ -239,12 +241,13 @@ class BsaDiffCreatorTest {
when(schedule.latestCompleted()).thenReturn(Optional.of(completedJob));
BsaDiff diff = diffCreator.createDiff(schedule, idnChecker);
assertThat(diff.getLabels())
.containsExactly(BlockLabel.of("test2", LabelType.DELETE, ImmutableSet.of()));
.containsExactly(BlockLabel.of("test2", LabelType.DELETE, ImmutableSet.of("JA")));
assertThat(diff.getOrders()).containsExactly(BlockOrder.of(3, OrderType.DELETE));
}
@Test
void removeLabelAndOrder_multi() {
when(idnChecker.getAllValidIdns(anyString())).thenReturn(ImmutableSet.of(IdnTableEnum.JA));
when(gcsClient.readBlockList("first", BlockListType.BLOCK))
.thenReturn(Stream.of("domainLabel,orderIDs", "test1,1;2", "test2,3", "test3,1;4"));
when(gcsClient.readBlockList("second", BlockListType.BLOCK))
@ -258,8 +261,8 @@ class BsaDiffCreatorTest {
BsaDiff diff = diffCreator.createDiff(schedule, idnChecker);
assertThat(diff.getLabels())
.containsExactly(
BlockLabel.of("test1", LabelType.DELETE, ImmutableSet.of()),
BlockLabel.of("test3", LabelType.DELETE, ImmutableSet.of()));
BlockLabel.of("test1", LabelType.DELETE, ImmutableSet.of("JA")),
BlockLabel.of("test3", LabelType.DELETE, ImmutableSet.of("JA")));
assertThat(diff.getOrders())
.containsExactly(
BlockOrder.of(1, OrderType.DELETE),

View file

@ -0,0 +1,203 @@
// Copyright 2023 The Nomulus Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package google.registry.bsa;
import static com.google.common.io.BaseEncoding.base16;
import static com.google.common.truth.Truth.assertThat;
import static com.google.common.truth.Truth8.assertThat;
import static google.registry.bsa.persistence.BsaTestingUtils.createDownloadScheduler;
import static google.registry.persistence.transaction.TransactionManagerFactory.tm;
import static google.registry.testing.DatabaseHelper.createTlds;
import static google.registry.testing.DatabaseHelper.persistResource;
import static google.registry.util.DateTimeUtils.START_OF_TIME;
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.joda.time.Duration.standardDays;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import com.google.cloud.storage.contrib.nio.testing.LocalStorageHelper;
import com.google.common.base.Joiner;
import google.registry.bsa.BlockListFetcher.LazyBlockList;
import google.registry.bsa.api.BsaReportSender;
import google.registry.gcs.GcsUtils;
import google.registry.model.tld.Tld.TldType;
import google.registry.model.tld.Tlds;
import google.registry.persistence.transaction.JpaTestExtensions;
import google.registry.persistence.transaction.JpaTestExtensions.JpaIntegrationWithCoverageExtension;
import google.registry.request.Response;
import google.registry.testing.FakeClock;
import google.registry.testing.FakeLockHandler;
import google.registry.testing.FakeResponse;
import java.security.MessageDigest;
import java.util.Optional;
import java.util.function.BiConsumer;
import java.util.stream.Stream;
import org.joda.time.DateTime;
import org.joda.time.Duration;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.mockito.Mock;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.mockito.stubbing.Answer;
import org.testcontainers.shaded.com.google.common.collect.ImmutableList;
/** Functional tests of BSA block list download and processing. */
@ExtendWith(MockitoExtension.class)
class BsaDownloadFunctionalTest {
static final DateTime TEST_START_TIME = DateTime.parse("2024-01-01T00:00:00Z");
static final String BSA_CSV_HEADER = "domainLabel,orderIDs";
@Mock BlockListFetcher blockListFetcher;
@Mock BsaReportSender bsaReportSender;
private final FakeClock fakeClock = new FakeClock(TEST_START_TIME);
@RegisterExtension
JpaIntegrationWithCoverageExtension jpa =
new JpaTestExtensions.Builder().withClock(fakeClock).buildIntegrationWithCoverageExtension();
private GcsClient gcsClient;
private BsaDownloadAction action;
private Response response;
@BeforeEach
void setup() throws Exception {
createTlds("app", "dev");
Tlds.getTldEntitiesOfType(TldType.REAL)
.forEach(
tld ->
persistResource(
tld.asBuilder().setBsaEnrollStartTime(Optional.of(START_OF_TIME)).build()));
gcsClient =
new GcsClient(new GcsUtils(LocalStorageHelper.getOptions()), "my-bucket", "SHA-256");
response = new FakeResponse();
action =
new BsaDownloadAction(
createDownloadScheduler(fakeClock),
blockListFetcher,
new BsaDiffCreator(gcsClient),
bsaReportSender,
gcsClient,
() -> new IdnChecker(fakeClock),
new BsaLock(
new FakeLockHandler(/* lockSucceeds= */ true), Duration.standardSeconds(30)),
fakeClock,
/* transactionBatchSize= */ 5,
response);
}
@Test
void initialDownload_noUnblockables() throws Exception {
LazyBlockList blockList = mockBlockList(BlockListType.BLOCK, ImmutableList.of("abc,1"));
LazyBlockList blockPlusList =
mockBlockList(BlockListType.BLOCK_PLUS, ImmutableList.of("abc,2", "def,3"));
mockBlockListFetcher(blockList, blockPlusList);
action.run();
String downloadJob = "2024-01-01t000000.000z";
try (Stream<String> blockListFile = gcsClient.readBlockList(downloadJob, BlockListType.BLOCK)) {
assertThat(blockListFile).containsExactly(BSA_CSV_HEADER, "abc,1").inOrder();
}
try (Stream<String> blockListFile =
gcsClient.readBlockList(downloadJob, BlockListType.BLOCK_PLUS)) {
assertThat(blockListFile).containsExactly(BSA_CSV_HEADER, "abc,2", "def,3");
}
ImmutableList<String> persistedLabels =
ImmutableList.copyOf(
tm().transact(
() ->
tm().getEntityManager()
.createNativeQuery("SELECT label from \"BsaLabel\"")
.getResultList()));
// TODO(weiminyu): check intermediate files
assertThat(persistedLabels).containsExactly("abc", "def");
}
@Test
void initialDownload_thenDeleteLabel_noUnblockables() throws Exception {
LazyBlockList blockList = mockBlockList(BlockListType.BLOCK, ImmutableList.of("abc,1"));
LazyBlockList blockPlusList =
mockBlockList(BlockListType.BLOCK_PLUS, ImmutableList.of("abc,2", "def,3"));
LazyBlockList blockList2 = mockBlockList(BlockListType.BLOCK, ImmutableList.of("abc,1"));
LazyBlockList blockPlusList2 =
mockBlockList(BlockListType.BLOCK_PLUS, ImmutableList.of("abc,2"));
mockBlockListFetcher(blockList, blockPlusList, blockList2, blockPlusList2);
action.run();
assertThat(getPersistedLabels()).containsExactly("abc", "def");
fakeClock.advanceBy(standardDays(1));
action.run();
assertThat(getPersistedLabels()).containsExactly("abc");
}
private ImmutableList<String> getPersistedLabels() {
return ImmutableList.copyOf(
tm().transact(
() ->
tm().getEntityManager()
.createNativeQuery("SELECT label from \"BsaLabel\"")
.getResultList()));
}
private void mockBlockListFetcher(LazyBlockList blockList, LazyBlockList blockPlusList)
throws Exception {
when(blockListFetcher.fetch(BlockListType.BLOCK)).thenReturn(blockList);
when(blockListFetcher.fetch(BlockListType.BLOCK_PLUS)).thenReturn(blockPlusList);
}
private void mockBlockListFetcher(
LazyBlockList blockList1,
LazyBlockList blockPlusList1,
LazyBlockList blockList2,
LazyBlockList blockPlusList2)
throws Exception {
when(blockListFetcher.fetch(BlockListType.BLOCK)).thenReturn(blockList1, blockList2);
when(blockListFetcher.fetch(BlockListType.BLOCK_PLUS))
.thenReturn(blockPlusList1, blockPlusList2);
}
static LazyBlockList mockBlockList(BlockListType blockListType, ImmutableList<String> dataLines)
throws Exception {
byte[] bytes =
Joiner.on('\n')
.join(new ImmutableList.Builder().add(BSA_CSV_HEADER).addAll(dataLines).build())
.getBytes(UTF_8);
String checksum = generateChecksum(bytes);
LazyBlockList blockList = mock(LazyBlockList.class);
when(blockList.checksum()).thenReturn(checksum);
when(blockList.getName()).thenReturn(blockListType);
doAnswer(
new Answer() {
@Override
public Object answer(InvocationOnMock invocation) throws Throwable {
BiConsumer<byte[], Integer> consumer = invocation.getArgument(0);
consumer.accept(bytes, bytes.length);
return null;
}
})
.when(blockList)
.consumeAll(any(BiConsumer.class));
return blockList;
}
private static String generateChecksum(byte[] bytes) throws Exception {
MessageDigest messageDigest = MessageDigest.getInstance("SHA-256");
messageDigest.update(bytes, 0, bytes.length);
return base16().lowerCase().encode(messageDigest.digest());
}
}

View file

@ -15,8 +15,8 @@
package google.registry.bsa.persistence;
import static com.google.common.truth.Truth.assertThat;
import static google.registry.bsa.persistence.BsaLabelTestingUtils.persistBsaLabel;
import static google.registry.bsa.persistence.BsaLabelUtils.isLabelBlocked;
import static google.registry.bsa.persistence.BsaTestingUtils.persistBsaLabel;
import static google.registry.persistence.transaction.TransactionManagerFactory.replicaTm;
import static google.registry.persistence.transaction.TransactionManagerFactory.setJpaTm;
import static google.registry.persistence.transaction.TransactionManagerFactory.setReplicaJpaTm;

View file

@ -16,14 +16,23 @@ package google.registry.bsa.persistence;
import static google.registry.persistence.transaction.TransactionManagerFactory.tm;
import google.registry.util.Clock;
import org.joda.time.DateTime;
import org.joda.time.Duration;
/** Testing utils for users of {@link BsaLabel}. */
public final class BsaLabelTestingUtils {
/** Exposes BSA persistence entities and tools to test classes. */
public final class BsaTestingUtils {
private BsaLabelTestingUtils() {}
public static final Duration DEFAULT_DOWNLOAD_INTERVAL = Duration.standardHours(1);
public static final Duration DEFAULT_NOP_INTERVAL = Duration.standardDays(1);
private BsaTestingUtils() {}
public static void persistBsaLabel(String domainLabel, DateTime creationTime) {
tm().transact(() -> tm().put(new BsaLabel(domainLabel, creationTime)));
}
public static DownloadScheduler createDownloadScheduler(Clock clock) {
return new DownloadScheduler(DEFAULT_DOWNLOAD_INTERVAL, DEFAULT_NOP_INTERVAL, clock);
}
}

View file

@ -16,7 +16,7 @@ package google.registry.bsa.persistence;
import static com.google.common.truth.Truth.assertThat;
import static google.registry.bsa.BsaTransactions.bsaTransact;
import static google.registry.bsa.persistence.BsaLabelTestingUtils.persistBsaLabel;
import static google.registry.bsa.persistence.BsaTestingUtils.persistBsaLabel;
import static google.registry.model.tld.label.ReservationType.RESERVED_FOR_SPECIFIC_USE;
import static google.registry.persistence.transaction.TransactionManagerFactory.tm;
import static google.registry.testing.DatabaseHelper.createTld;

View file

@ -30,7 +30,7 @@ import static google.registry.testing.DatabaseHelper.persistResource;
import static google.registry.util.DateTimeUtils.START_OF_TIME;
import static org.mockito.Mockito.verify;
import google.registry.bsa.persistence.BsaLabelTestingUtils;
import google.registry.bsa.persistence.BsaTestingUtils;
import google.registry.model.tld.Tld;
import google.registry.monitoring.whitebox.CheckApiMetric;
import google.registry.monitoring.whitebox.CheckApiMetric.Availability;
@ -288,7 +288,7 @@ class CheckApiActionTest {
@Test
void testSuccess_blockedByBsa() {
BsaLabelTestingUtils.persistBsaLabel("rich", START_OF_TIME);
BsaTestingUtils.persistBsaLabel("rich", START_OF_TIME);
persistResource(
Tld.get("example").asBuilder().setBsaEnrollStartTime(Optional.of(START_OF_TIME)).build());
assertThat(getCheckResponse("rich.example"))

View file

@ -18,7 +18,7 @@ import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.io.BaseEncoding.base16;
import static com.google.common.truth.Truth.assertThat;
import static com.google.common.truth.Truth8.assertThat;
import static google.registry.bsa.persistence.BsaLabelTestingUtils.persistBsaLabel;
import static google.registry.bsa.persistence.BsaTestingUtils.persistBsaLabel;
import static google.registry.flows.FlowTestCase.UserPrivileges.SUPERUSER;
import static google.registry.model.billing.BillingBase.Flag.ANCHOR_TENANT;
import static google.registry.model.billing.BillingBase.Flag.RESERVED;

View file

@ -16,7 +16,7 @@ package google.registry.whois;
import static com.google.common.truth.Truth.assertThat;
import static com.google.common.truth.Truth8.assertThat;
import static google.registry.bsa.persistence.BsaLabelTestingUtils.persistBsaLabel;
import static google.registry.bsa.persistence.BsaTestingUtils.persistBsaLabel;
import static google.registry.model.EppResourceUtils.loadByForeignKeyCached;
import static google.registry.model.registrar.Registrar.State.ACTIVE;
import static google.registry.model.registrar.Registrar.Type.PDT;

View file

@ -16,7 +16,7 @@ package google.registry.whois;
import static com.google.common.net.MediaType.PLAIN_TEXT_UTF_8;
import static com.google.common.truth.Truth.assertThat;
import static google.registry.bsa.persistence.BsaLabelTestingUtils.persistBsaLabel;
import static google.registry.bsa.persistence.BsaTestingUtils.persistBsaLabel;
import static google.registry.model.registrar.Registrar.State.ACTIVE;
import static google.registry.testing.DatabaseHelper.createTlds;
import static google.registry.testing.DatabaseHelper.loadRegistrar;