Add some load convenience methods to DatabaseHelper (#1038)

* Add some load convenience methods to DatabaseHelper

These can only be called by test code, and they automatically wrap the load
in a transaction if one isn't already specified (for convenience).

In production code we don't want to be able to use these, as we have to be
more thoughtful about transactions in production code (e.g. make sure that
we aren't loading and then saving a resource in separate transactions in a
way that makes it prone to contention errors).
This commit is contained in:
Ben McIlwain 2021-03-25 16:14:46 -04:00 committed by GitHub
parent 14f08e9070
commit cd6cb10b37
8 changed files with 78 additions and 74 deletions

View file

@ -20,13 +20,12 @@ import static com.google.common.collect.Iterables.getOnlyElement;
import static com.google.common.collect.MoreCollectors.onlyElement;
import static com.google.common.truth.Truth.assertThat;
import static google.registry.config.RegistryConfig.getContactAutomaticTransferLength;
import static google.registry.persistence.transaction.TransactionManagerFactory.tm;
import static google.registry.persistence.transaction.TransactionManagerUtil.transactIfJpaTm;
import static google.registry.testing.ContactResourceSubject.assertAboutContacts;
import static google.registry.testing.DatabaseHelper.assertNoBillingEvents;
import static google.registry.testing.DatabaseHelper.assertPollMessagesEqual;
import static google.registry.testing.DatabaseHelper.deleteResource;
import static google.registry.testing.DatabaseHelper.getPollMessages;
import static google.registry.testing.DatabaseHelper.loadByKeys;
import static google.registry.testing.DatabaseHelper.persistActiveContact;
import static google.registry.testing.DatabaseHelper.persistResource;
import static google.registry.testing.EppExceptionSubject.assertAboutEppExceptions;
@ -136,10 +135,7 @@ class ContactTransferRequestFlowTest
// poll messages, the approval notice ones for gaining and losing registrars.
assertPollMessagesEqual(
Iterables.filter(
transactIfJpaTm(
() -> tm().loadByKeys(contact.getTransferData().getServerApproveEntities()))
.values(),
PollMessage.class),
loadByKeys(contact.getTransferData().getServerApproveEntities()), PollMessage.class),
ImmutableList.of(gainingApproveMessage, losingApproveMessage));
}

View file

@ -35,6 +35,8 @@ import static google.registry.testing.DatabaseHelper.createTld;
import static google.registry.testing.DatabaseHelper.getOnlyHistoryEntryOfType;
import static google.registry.testing.DatabaseHelper.getOnlyPollMessage;
import static google.registry.testing.DatabaseHelper.getPollMessages;
import static google.registry.testing.DatabaseHelper.loadByKey;
import static google.registry.testing.DatabaseHelper.loadByKeys;
import static google.registry.testing.DatabaseHelper.loadRegistrar;
import static google.registry.testing.DatabaseHelper.persistActiveContact;
import static google.registry.testing.DatabaseHelper.persistResource;
@ -300,15 +302,13 @@ class DomainTransferRequestFlowTest
// Assert that the domain's TransferData server-approve billing events match the above.
if (expectTransferBillingEvent) {
assertBillingEventsEqual(
transactIfJpaTm(
() -> tm().loadByKey(domain.getTransferData().getServerApproveBillingEvent())),
loadByKey(domain.getTransferData().getServerApproveBillingEvent()),
optionalTransferBillingEvent.get());
} else {
assertThat(domain.getTransferData().getServerApproveBillingEvent()).isNull();
}
assertBillingEventsEqual(
transactIfJpaTm(
() -> tm().loadByKey(domain.getTransferData().getServerApproveAutorenewEvent())),
loadByKey(domain.getTransferData().getServerApproveAutorenewEvent()),
gainingClientAutorenew);
// Assert that the full set of server-approve billing events is exactly the extra ones plus
// the transfer billing event (if present) and the gaining client autorenew.
@ -318,14 +318,10 @@ class DomainTransferRequestFlowTest
.collect(toImmutableSet());
assertBillingEventsEqual(
Iterables.filter(
transactIfJpaTm(
() ->
tm().loadByKeys(domain.getTransferData().getServerApproveEntities()).values()),
BillingEvent.class),
loadByKeys(domain.getTransferData().getServerApproveEntities()), BillingEvent.class),
Sets.union(expectedServeApproveBillingEvents, extraBillingEvents));
// The domain's autorenew billing event should still point to the losing client's event.
BillingEvent.Recurring domainAutorenewEvent =
transactIfJpaTm(() -> tm().loadByKey(domain.getAutorenewBillingEvent()));
BillingEvent.Recurring domainAutorenewEvent = loadByKey(domain.getAutorenewBillingEvent());
assertThat(domainAutorenewEvent.getClientId()).isEqualTo("TheRegistrar");
assertThat(domainAutorenewEvent.getRecurrenceEndTime()).isEqualTo(implicitTransferTime);
// The original grace periods should remain untouched.
@ -421,17 +417,13 @@ class DomainTransferRequestFlowTest
// Assert that the poll messages show up in the TransferData server approve entities.
assertPollMessagesEqual(
transactIfJpaTm(
() -> tm().loadByKey(domain.getTransferData().getServerApproveAutorenewPollMessage())),
loadByKey(domain.getTransferData().getServerApproveAutorenewPollMessage()),
autorenewPollMessage);
// Assert that the full set of server-approve poll messages is exactly the server approve
// OneTime messages to gaining and losing registrars plus the gaining client autorenew.
assertPollMessagesEqual(
Iterables.filter(
transactIfJpaTm(
() ->
tm().loadByKeys(domain.getTransferData().getServerApproveEntities()).values()),
PollMessage.class),
loadByKeys(domain.getTransferData().getServerApproveEntities()), PollMessage.class),
ImmutableList.of(
transferApprovedPollMessage, losingTransferApprovedPollMessage, autorenewPollMessage));
}
@ -449,11 +441,7 @@ class DomainTransferRequestFlowTest
.hasLastEppUpdateTime(implicitTransferTime)
.and()
.hasLastEppUpdateClientId("NewRegistrar");
assertThat(
transactIfJpaTm(
() ->
tm().loadByKey(domainAfterAutomaticTransfer.getAutorenewBillingEvent())
.getEventTime()))
assertThat(loadByKey(domainAfterAutomaticTransfer.getAutorenewBillingEvent()).getEventTime())
.isEqualTo(expectedExpirationTime);
// And after the expected grace time, the grace period should be gone.
DomainBase afterGracePeriod =

View file

@ -19,10 +19,10 @@ import static com.google.common.truth.Truth.assertThat;
import static google.registry.batch.AsyncTaskEnqueuer.QUEUE_ASYNC_HOST_RENAME;
import static google.registry.model.EppResourceUtils.loadByForeignKey;
import static google.registry.persistence.transaction.TransactionManagerFactory.tm;
import static google.registry.persistence.transaction.TransactionManagerUtil.transactIfJpaTm;
import static google.registry.testing.DatabaseHelper.assertNoBillingEvents;
import static google.registry.testing.DatabaseHelper.createTld;
import static google.registry.testing.DatabaseHelper.getOnlyHistoryEntryOfType;
import static google.registry.testing.DatabaseHelper.loadByEntity;
import static google.registry.testing.DatabaseHelper.newDomainBase;
import static google.registry.testing.DatabaseHelper.newHostResource;
import static google.registry.testing.DatabaseHelper.persistActiveDomain;
@ -302,8 +302,7 @@ class HostUpdateFlowTest extends ResourceFlowTestCase<HostUpdateFlow, HostResour
.hasPersistedCurrentSponsorClientId("TheRegistrar")
.and()
.hasLastTransferTime(oneDayAgo);
DomainBase reloadedDomain =
transactIfJpaTm(() -> tm().loadByEntity(domain)).cloneProjectedAtTime(now);
DomainBase reloadedDomain = loadByEntity(domain).cloneProjectedAtTime(now);
assertThat(reloadedDomain.getSubordinateHosts()).containsExactly("ns2.example.tld");
assertDnsTasksEnqueued("ns1.example.tld", "ns2.example.tld");
}
@ -337,15 +336,8 @@ class HostUpdateFlowTest extends ResourceFlowTestCase<HostUpdateFlow, HostResour
.hasPersistedCurrentSponsorClientId("TheRegistrar")
.and()
.hasLastTransferTime(null);
assertThat(
transactIfJpaTm(() -> tm().loadByEntity(foo))
.cloneProjectedAtTime(now)
.getSubordinateHosts())
.isEmpty();
assertThat(
transactIfJpaTm(() -> tm().loadByEntity(example))
.cloneProjectedAtTime(now)
.getSubordinateHosts())
assertThat(loadByEntity(foo).cloneProjectedAtTime(now).getSubordinateHosts()).isEmpty();
assertThat(loadByEntity(example).cloneProjectedAtTime(now).getSubordinateHosts())
.containsExactly("ns2.example.tld");
assertDnsTasksEnqueued("ns2.foo.tld", "ns2.example.tld");
}
@ -380,11 +372,9 @@ class HostUpdateFlowTest extends ResourceFlowTestCase<HostUpdateFlow, HostResour
.hasPersistedCurrentSponsorClientId("TheRegistrar")
.and()
.hasLastTransferTime(null);
DomainBase reloadedFooDomain =
transactIfJpaTm(() -> tm().loadByEntity(fooDomain)).cloneProjectedAtTime(now);
DomainBase reloadedFooDomain = loadByEntity(fooDomain).cloneProjectedAtTime(now);
assertThat(reloadedFooDomain.getSubordinateHosts()).isEmpty();
DomainBase reloadedTldDomain =
transactIfJpaTm(() -> tm().loadByEntity(tldDomain)).cloneProjectedAtTime(now);
DomainBase reloadedTldDomain = loadByEntity(tldDomain).cloneProjectedAtTime(now);
assertThat(reloadedTldDomain.getSubordinateHosts()).containsExactly("ns2.example.tld");
assertDnsTasksEnqueued("ns1.example.foo", "ns2.example.tld");
}
@ -427,8 +417,7 @@ class HostUpdateFlowTest extends ResourceFlowTestCase<HostUpdateFlow, HostResour
.and()
.hasLastSuperordinateChange(clock.nowUtc());
assertThat(renamedHost.getLastTransferTime()).isEqualTo(oneDayAgo);
DomainBase reloadedDomain =
transactIfJpaTm(() -> tm().loadByEntity(domain)).cloneProjectedAtTime(clock.nowUtc());
DomainBase reloadedDomain = loadByEntity(domain).cloneProjectedAtTime(clock.nowUtc());
assertThat(reloadedDomain.getSubordinateHosts()).isEmpty();
assertDnsTasksEnqueued("ns1.example.foo");
}
@ -464,10 +453,7 @@ class HostUpdateFlowTest extends ResourceFlowTestCase<HostUpdateFlow, HostResour
.hasPersistedCurrentSponsorClientId("TheRegistrar")
.and()
.hasLastTransferTime(null);
assertThat(
transactIfJpaTm(() -> tm().loadByEntity(domain))
.cloneProjectedAtTime(now)
.getSubordinateHosts())
assertThat(loadByEntity(domain).cloneProjectedAtTime(now).getSubordinateHosts())
.containsExactly("ns2.example.tld");
assertDnsTasksEnqueued("ns2.example.tld");
}

View file

@ -19,8 +19,9 @@ import static google.registry.model.domain.token.AllocationToken.TokenType.UNLIM
import static google.registry.model.ofy.ObjectifyService.ofy;
import static google.registry.persistence.transaction.TransactionManagerFactory.tm;
import static google.registry.persistence.transaction.TransactionManagerUtil.ofyTmOrDoNothing;
import static google.registry.persistence.transaction.TransactionManagerUtil.transactIfJpaTm;
import static google.registry.testing.DatabaseHelper.createTld;
import static google.registry.testing.DatabaseHelper.loadByEntity;
import static google.registry.testing.DatabaseHelper.loadByKey;
import static google.registry.testing.DatabaseHelper.persistActiveDomain;
import static google.registry.testing.DatabaseHelper.persistResource;
import static google.registry.util.DateTimeUtils.END_OF_TIME;
@ -184,14 +185,11 @@ public class BillingEventTest extends EntityTestCase {
@TestOfyAndSql
void testPersistence() {
assertThat(transactIfJpaTm(() -> tm().loadByEntity(oneTime))).isEqualTo(oneTime);
assertThat(transactIfJpaTm(() -> tm().loadByEntity(oneTimeSynthetic)))
.isEqualTo(oneTimeSynthetic);
assertThat(transactIfJpaTm(() -> tm().loadByEntity(recurring))).isEqualTo(recurring);
assertThat(transactIfJpaTm(() -> tm().loadByEntity(cancellationOneTime)))
.isEqualTo(cancellationOneTime);
assertThat(transactIfJpaTm(() -> tm().loadByEntity(cancellationRecurring)))
.isEqualTo(cancellationRecurring);
assertThat(loadByEntity(oneTime)).isEqualTo(oneTime);
assertThat(loadByEntity(oneTimeSynthetic)).isEqualTo(oneTimeSynthetic);
assertThat(loadByEntity(recurring)).isEqualTo(recurring);
assertThat(loadByEntity(cancellationOneTime)).isEqualTo(cancellationOneTime);
assertThat(loadByEntity(cancellationRecurring)).isEqualTo(cancellationRecurring);
ofyTmOrDoNothing(() -> assertThat(tm().loadByEntity(modification)).isEqualTo(modification));
}
@ -220,10 +218,8 @@ public class BillingEventTest extends EntityTestCase {
@TestOfyAndSql
void testCancellationMatching() {
VKey<?> recurringKey =
transactIfJpaTm(
() -> tm().loadByEntity(oneTimeSynthetic).getCancellationMatchingBillingEvent());
assertThat(transactIfJpaTm(() -> tm().loadByKey(recurringKey))).isEqualTo(recurring);
VKey<?> recurringKey = loadByEntity(oneTimeSynthetic).getCancellationMatchingBillingEvent();
assertThat(loadByKey(recurringKey)).isEqualTo(recurring);
}
@TestOfyOnly

View file

@ -22,9 +22,8 @@ import static google.registry.model.domain.token.AllocationToken.TokenStatus.NOT
import static google.registry.model.domain.token.AllocationToken.TokenStatus.VALID;
import static google.registry.model.domain.token.AllocationToken.TokenType.SINGLE_USE;
import static google.registry.model.domain.token.AllocationToken.TokenType.UNLIMITED_USE;
import static google.registry.persistence.transaction.TransactionManagerFactory.tm;
import static google.registry.persistence.transaction.TransactionManagerUtil.transactIfJpaTm;
import static google.registry.testing.DatabaseHelper.createTld;
import static google.registry.testing.DatabaseHelper.loadByEntity;
import static google.registry.testing.DatabaseHelper.persistActiveDomain;
import static google.registry.testing.DatabaseHelper.persistResource;
import static google.registry.util.DateTimeUtils.START_OF_TIME;
@ -78,8 +77,7 @@ public class AllocationTokenTest extends EntityTestCase {
.put(DateTime.now(UTC).plusWeeks(8), TokenStatus.ENDED)
.build())
.build());
assertThat(transactIfJpaTm(() -> tm().loadByEntity(unlimitedUseToken)))
.isEqualTo(unlimitedUseToken);
assertThat(loadByEntity(unlimitedUseToken)).isEqualTo(unlimitedUseToken);
DomainBase domain = persistActiveDomain("example.foo");
Key<HistoryEntry> historyEntryKey = Key.create(Key.create(domain), HistoryEntry.class, 1);
@ -92,7 +90,7 @@ public class AllocationTokenTest extends EntityTestCase {
.setCreationTimeForTest(DateTime.parse("2010-11-12T05:00:00Z"))
.setTokenType(SINGLE_USE)
.build());
assertThat(transactIfJpaTm(() -> tm().loadByEntity(singleUseToken))).isEqualTo(singleUseToken);
assertThat(loadByEntity(singleUseToken)).isEqualTo(singleUseToken);
}
@TestOfyOnly

View file

@ -54,6 +54,7 @@ import static org.junit.jupiter.api.Assertions.fail;
import com.google.common.base.Ascii;
import com.google.common.base.Splitter;
import com.google.common.base.Supplier;
import com.google.common.collect.ImmutableCollection;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
@ -1244,5 +1245,45 @@ public class DatabaseHelper {
clientId);
}
/**
* Loads (i.e. reloads) the specified entity from the DB.
*
* <p>If the transaction manager is Cloud SQL, then this creates an inner wrapping transaction for
* convenience, so you don't need to wrap it in a transaction at the callsite.
*/
public static <T> T loadByEntity(T entity) {
return transactIfJpaTm(() -> tm().loadByEntity(entity));
}
/**
* Loads the specified entity by its key from the DB.
*
* <p>If the transaction manager is Cloud SQL, then this creates an inner wrapping transaction for
* convenience, so you don't need to wrap it in a transaction at the callsite.
*/
public static <T> T loadByKey(VKey<T> key) {
return transactIfJpaTm(() -> tm().loadByKey(key));
}
/**
* Loads the specified entities by their keys from the DB.
*
* <p>If the transaction manager is Cloud SQL, then this creates an inner wrapping transaction for
* convenience, so you don't need to wrap it in a transaction at the callsite.
*/
public static <T> ImmutableCollection<T> loadByKeys(Iterable<? extends VKey<? extends T>> keys) {
return transactIfJpaTm(() -> tm().loadByKeys(keys).values());
}
/**
* Loads all of the entities of the specified type from the DB.
*
* <p>If the transaction manager is Cloud SQL, then this creates an inner wrapping transaction for
* convenience, so you don't need to wrap it in a transaction at the callsite.
*/
public static <T> ImmutableList<T> loadAllOf(Class<T> clazz) {
return transactIfJpaTm(() -> tm().loadAllOf(clazz));
}
private DatabaseHelper() {}
}

View file

@ -19,6 +19,7 @@ import static google.registry.model.domain.token.AllocationToken.TokenType.SINGL
import static google.registry.persistence.transaction.TransactionManagerFactory.tm;
import static google.registry.persistence.transaction.TransactionManagerUtil.transactIfJpaTm;
import static google.registry.testing.DatabaseHelper.createTlds;
import static google.registry.testing.DatabaseHelper.loadAllOf;
import static google.registry.testing.DatabaseHelper.persistActiveDomain;
import static google.registry.testing.DatabaseHelper.persistResource;
import static org.junit.jupiter.api.Assertions.assertThrows;
@ -134,10 +135,9 @@ class DeleteAllocationTokensCommandTest extends CommandTestCase<DeleteAllocation
for (int i = 0; i < 50; i++) {
persistToken(String.format("batch%2d", i), null, i % 2 == 0);
}
assertThat(transactIfJpaTm(() -> tm().loadAllOf(AllocationToken.class).size())).isEqualTo(56);
assertThat(loadAllOf(AllocationToken.class).size()).isEqualTo(56);
runCommandForced("--prefix", "batch");
assertThat(transactIfJpaTm(() -> tm().loadAllOf(AllocationToken.class).size()))
.isEqualTo(56 - 25);
assertThat(loadAllOf(AllocationToken.class).size()).isEqualTo(56 - 25);
}
@TestOfyAndSql

View file

@ -17,10 +17,9 @@ package google.registry.tools;
import static com.google.common.truth.Truth.assertThat;
import static google.registry.model.domain.token.AllocationToken.TokenType.SINGLE_USE;
import static google.registry.model.domain.token.AllocationToken.TokenType.UNLIMITED_USE;
import static google.registry.persistence.transaction.TransactionManagerFactory.tm;
import static google.registry.persistence.transaction.TransactionManagerUtil.transactIfJpaTm;
import static google.registry.testing.DatabaseHelper.assertAllocationTokens;
import static google.registry.testing.DatabaseHelper.createTld;
import static google.registry.testing.DatabaseHelper.loadAllOf;
import static google.registry.testing.DatabaseHelper.persistResource;
import static google.registry.util.DateTimeUtils.START_OF_TIME;
import static java.nio.charset.StandardCharsets.UTF_8;
@ -134,7 +133,7 @@ class GenerateAllocationTokensCommandTest extends CommandTestCase<GenerateAlloca
runCommand("--prefix", "ooo", "--number", "100", "--length", "16");
// The deterministic string generator makes it too much hassle to assert about each token, so
// just assert total number.
assertThat(transactIfJpaTm(() -> tm().loadAllOf(AllocationToken.class).size())).isEqualTo(100);
assertThat(loadAllOf(AllocationToken.class).size()).isEqualTo(100);
}
@TestOfyAndSql
@ -199,7 +198,7 @@ class GenerateAllocationTokensCommandTest extends CommandTestCase<GenerateAlloca
Collection<String> sampleTokens = command.stringGenerator.createStrings(13, 100);
runCommand("--tokens", Joiner.on(",").join(sampleTokens));
assertInStdout(Iterables.toArray(sampleTokens, String.class));
assertThat(transactIfJpaTm(() -> tm().loadAllOf(AllocationToken.class).size())).isEqualTo(100);
assertThat(loadAllOf(AllocationToken.class).size()).isEqualTo(100);
}
@TestOfyAndSql