Add tm().reTransact() methods and refactor away some inner transactions (#2125)

In the future, reTransact() will be the only way to initiate a transaction that
doesn't fail when called inside an outer wrapping transaction (when wrapped,
it's a no-op). It should be used sparingly, with a preference towards
refactoring the code to move transactions outwards (which this PR also
contains).

Note that this PR includes some potential efficiency gains caused by existing
poor use of transactions. E.g. in the file RefreshDnsAction, the existing code
was using two separate transactions to refresh the DNS for domains and hosts
(one is hidden in loadAndVerifyExistence(), whereas now as of this PR it has a
single wrapping transaction to do so.
This commit is contained in:
Ben McIlwain 2023-08-25 14:03:25 -04:00 committed by GitHub
parent 739a15851d
commit f01adfb060
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 125 additions and 59 deletions

View file

@ -60,18 +60,21 @@ public final class RefreshDnsAction implements Runnable {
if (!domainOrHostName.contains(".")) { if (!domainOrHostName.contains(".")) {
throw new BadRequestException("URL parameter 'name' must be fully qualified"); throw new BadRequestException("URL parameter 'name' must be fully qualified");
} }
switch (type) { tm().transact(
case DOMAIN: () -> {
loadAndVerifyExistence(Domain.class, domainOrHostName); switch (type) {
tm().transact(() -> requestDomainDnsRefresh(domainOrHostName)); case DOMAIN:
break; loadAndVerifyExistence(Domain.class, domainOrHostName);
case HOST: requestDomainDnsRefresh(domainOrHostName);
verifyHostIsSubordinate(loadAndVerifyExistence(Host.class, domainOrHostName)); break;
tm().transact(() -> requestHostDnsRefresh(domainOrHostName)); case HOST:
break; verifyHostIsSubordinate(loadAndVerifyExistence(Host.class, domainOrHostName));
default: requestHostDnsRefresh(domainOrHostName);
throw new BadRequestException("Unsupported type: " + type); break;
} default:
throw new BadRequestException("Unsupported type: " + type);
}
});
} }
private <T extends EppResource & ForeignKeyedEppResource> private <T extends EppResource & ForeignKeyedEppResource>

View file

@ -23,6 +23,7 @@ import static google.registry.flows.domain.DomainFlowUtils.validateDomainNameWit
import static google.registry.flows.domain.DomainFlowUtils.verifyClaimsPeriodNotEnded; import static google.registry.flows.domain.DomainFlowUtils.verifyClaimsPeriodNotEnded;
import static google.registry.flows.domain.DomainFlowUtils.verifyNotInPredelegation; import static google.registry.flows.domain.DomainFlowUtils.verifyNotInPredelegation;
import static google.registry.model.domain.launch.LaunchPhase.CLAIMS; import static google.registry.model.domain.launch.LaunchPhase.CLAIMS;
import static google.registry.persistence.transaction.TransactionManagerFactory.tm;
import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet; import com.google.common.collect.ImmutableSet;
@ -108,7 +109,8 @@ public final class DomainClaimsCheckFlow implements Flow {
verifyClaimsPeriodNotEnded(tld, now); verifyClaimsPeriodNotEnded(tld, now);
} }
} }
Optional<String> claimKey = ClaimsListDao.get().getClaimKey(parsedDomain.parts().get(0)); Optional<String> claimKey =
tm().transact(() -> ClaimsListDao.get().getClaimKey(parsedDomain.parts().get(0)));
launchChecksBuilder.add( launchChecksBuilder.add(
LaunchCheck.create( LaunchCheck.create(
LaunchCheckName.create(claimKey.isPresent(), domainName), claimKey.orElse(null))); LaunchCheckName.create(claimKey.isPresent(), domainName), claimKey.orElse(null)));

View file

@ -108,7 +108,7 @@ public final class PollAckFlow implements TransactionalFlow {
// acked, then we return a special status code indicating that. Note that the query will // acked, then we return a special status code indicating that. Note that the query will
// include the message being acked. // include the message being acked.
int messageCount = tm().transact(() -> getPollMessageCount(registrarId, now)); int messageCount = getPollMessageCount(registrarId, now);
if (messageCount <= 0) { if (messageCount <= 0) {
return responseBuilder.setResultFromCode(SUCCESS_WITH_NO_MESSAGES).build(); return responseBuilder.setResultFromCode(SUCCESS_WITH_NO_MESSAGES).build();
} }

View file

@ -30,13 +30,12 @@ public final class PollFlowUtils {
/** Returns the number of poll messages for the given registrar that are not in the future. */ /** Returns the number of poll messages for the given registrar that are not in the future. */
public static int getPollMessageCount(String registrarId, DateTime now) { public static int getPollMessageCount(String registrarId, DateTime now) {
return tm().transact(() -> createPollMessageQuery(registrarId, now).count()).intValue(); return (int) createPollMessageQuery(registrarId, now).count();
} }
/** Returns the first (by event time) poll message not in the future for this registrar. */ /** Returns the first (by event time) poll message not in the future for this registrar. */
public static Optional<PollMessage> getFirstPollMessage(String registrarId, DateTime now) { public static Optional<PollMessage> getFirstPollMessage(String registrarId, DateTime now) {
return tm().transact( return createPollMessageQuery(registrarId, now).orderBy("eventTime").first();
() -> createPollMessageQuery(registrarId, now).orderBy("eventTime").first());
} }
/** /**

View file

@ -20,6 +20,7 @@ import static google.registry.flows.poll.PollFlowUtils.getPollMessageCount;
import static google.registry.model.eppoutput.Result.Code.SUCCESS_WITH_ACK_MESSAGE; import static google.registry.model.eppoutput.Result.Code.SUCCESS_WITH_ACK_MESSAGE;
import static google.registry.model.eppoutput.Result.Code.SUCCESS_WITH_NO_MESSAGES; import static google.registry.model.eppoutput.Result.Code.SUCCESS_WITH_NO_MESSAGES;
import static google.registry.model.poll.PollMessageExternalKeyConverter.makePollMessageExternalId; import static google.registry.model.poll.PollMessageExternalKeyConverter.makePollMessageExternalId;
import static google.registry.persistence.transaction.TransactionManagerFactory.tm;
import google.registry.flows.EppException; import google.registry.flows.EppException;
import google.registry.flows.EppException.ParameterValueSyntaxErrorException; import google.registry.flows.EppException.ParameterValueSyntaxErrorException;
@ -31,7 +32,6 @@ import google.registry.model.eppoutput.EppResponse;
import google.registry.model.poll.MessageQueueInfo; import google.registry.model.poll.MessageQueueInfo;
import google.registry.model.poll.PollMessage; import google.registry.model.poll.PollMessage;
import google.registry.model.poll.PollMessageExternalKeyConverter; import google.registry.model.poll.PollMessageExternalKeyConverter;
import google.registry.util.Clock;
import java.util.Optional; import java.util.Optional;
import javax.inject.Inject; import javax.inject.Inject;
import org.joda.time.DateTime; import org.joda.time.DateTime;
@ -52,7 +52,6 @@ public final class PollRequestFlow implements Flow {
@Inject ExtensionManager extensionManager; @Inject ExtensionManager extensionManager;
@Inject @RegistrarId String registrarId; @Inject @RegistrarId String registrarId;
@Inject @PollMessageId String messageId; @Inject @PollMessageId String messageId;
@Inject Clock clock;
@Inject EppResponse.Builder responseBuilder; @Inject EppResponse.Builder responseBuilder;
@Inject PollRequestFlow() {} @Inject PollRequestFlow() {}
@ -63,24 +62,28 @@ public final class PollRequestFlow implements Flow {
if (!messageId.isEmpty()) { if (!messageId.isEmpty()) {
throw new UnexpectedMessageIdException(); throw new UnexpectedMessageIdException();
} }
// Return the oldest message from the queue. // Return the oldest message from the queue.
DateTime now = clock.nowUtc(); return tm().transact(
Optional<PollMessage> maybePollMessage = getFirstPollMessage(registrarId, now); () -> {
if (!maybePollMessage.isPresent()) { DateTime now = tm().getTransactionTime();
return responseBuilder.setResultFromCode(SUCCESS_WITH_NO_MESSAGES).build(); Optional<PollMessage> maybePollMessage = getFirstPollMessage(registrarId, now);
} if (!maybePollMessage.isPresent()) {
PollMessage pollMessage = maybePollMessage.get(); return responseBuilder.setResultFromCode(SUCCESS_WITH_NO_MESSAGES).build();
return responseBuilder }
.setResultFromCode(SUCCESS_WITH_ACK_MESSAGE) PollMessage pollMessage = maybePollMessage.get();
.setMessageQueueInfo( return responseBuilder
new MessageQueueInfo.Builder() .setResultFromCode(SUCCESS_WITH_ACK_MESSAGE)
.setQueueDate(pollMessage.getEventTime()) .setMessageQueueInfo(
.setMsg(pollMessage.getMsg()) new MessageQueueInfo.Builder()
.setQueueLength(getPollMessageCount(registrarId, now)) .setQueueDate(pollMessage.getEventTime())
.setMessageId(makePollMessageExternalId(pollMessage)) .setMsg(pollMessage.getMsg())
.build()) .setQueueLength(getPollMessageCount(registrarId, now))
.setMultipleResData(pollMessage.getResponseData()) .setMessageId(makePollMessageExternalId(pollMessage))
.build(); .build())
.setMultipleResData(pollMessage.getResponseData())
.build();
});
} }
/** Unexpected message id. */ /** Unexpected message id. */

View file

@ -156,7 +156,9 @@ public final class EppResourceUtils {
T resource = T resource =
useCache useCache
? EppResource.loadCached(key) ? EppResource.loadCached(key)
: tm().transact(() -> tm().loadByKeyIfPresent(key).orElse(null)); // This transaction is buried very deeply inside many outer nested calls, hence merits
// the use of reTransact() for now pending a substantial refactoring.
: tm().reTransact(() -> tm().loadByKeyIfPresent(key).orElse(null));
if (resource == null || isAtOrAfter(now, resource.getDeletionTime())) { if (resource == null || isAtOrAfter(now, resource.getDeletionTime())) {
return Optional.empty(); return Optional.empty();
} }

View file

@ -206,13 +206,11 @@ public class ClaimsList extends ImmutableObject {
if (labelsToKeys != null) { if (labelsToKeys != null) {
return Optional.ofNullable(labelsToKeys.get(label)); return Optional.ofNullable(labelsToKeys.get(label));
} }
return tm().transact( return tm().createQueryComposer(ClaimsEntry.class)
() -> .where("revisionId", EQ, revisionId)
tm().createQueryComposer(ClaimsEntry.class) .where("domainLabel", EQ, label)
.where("revisionId", EQ, revisionId) .first()
.where("domainLabel", EQ, label) .map(ClaimsEntry::getClaimKey);
.first()
.map(ClaimsEntry::getClaimKey));
} }
public static ClaimsList create( public static ClaimsList create(

View file

@ -158,6 +158,11 @@ public class JpaTransactionManagerImpl implements JpaTransactionManager {
() -> transactNoRetry(work, isolationLevel), JpaRetries::isFailedTxnRetriable); () -> transactNoRetry(work, isolationLevel), JpaRetries::isFailedTxnRetriable);
} }
@Override
public <T> T reTransact(Supplier<T> work) {
return transact(work);
}
@Override @Override
public <T> T transact(Supplier<T> work) { public <T> T transact(Supplier<T> work) {
return transact(work, null); return transact(work, null);
@ -229,6 +234,11 @@ public class JpaTransactionManagerImpl implements JpaTransactionManager {
isolationLevel); isolationLevel);
} }
@Override
public void reTransact(Runnable work) {
transact(work);
}
@Override @Override
public void transact(Runnable work) { public void transact(Runnable work) {
transact(work, null); transact(work, null);

View file

@ -56,12 +56,44 @@ public interface TransactionManager {
*/ */
<T> T transact(Supplier<T> work, TransactionIsolationLevel isolationLevel); <T> T transact(Supplier<T> work, TransactionIsolationLevel isolationLevel);
/**
* Executes the work in a (potentially wrapped) transaction and returns the result.
*
* <p>Calls to this method are typically going to be in inner functions, that are called either as
* top-level transactions themselves or are nested inside of larger transactions (e.g. a
* transactional flow). Invocations of reTransact must be vetted to occur in both situations and
* with such complexity that it is not trivial to refactor out the nested transaction calls. New
* code should be written in such a way as to avoid requiring reTransact in the first place.
*
* <p>In the future we will be enforcing that {@link #transact(Supplier)} calls be top-level only,
* with reTransact calls being the only ones that can potentially be an inner nested transaction
* (which is a noop). Note that, as this can be a nested inner exception, there is no overload
* provided to specify a (potentially conflicting) transaction isolation level.
*/
<T> T reTransact(Supplier<T> work);
/** Executes the work in a transaction. */ /** Executes the work in a transaction. */
void transact(Runnable work); void transact(Runnable work);
/** Executes the work in a transaction at the given {@link TransactionIsolationLevel}. */ /** Executes the work in a transaction at the given {@link TransactionIsolationLevel}. */
void transact(Runnable work, TransactionIsolationLevel isolationLevel); void transact(Runnable work, TransactionIsolationLevel isolationLevel);
/**
* Executes the work in a (potentially wrapped) transaction and returns the result.
*
* <p>Calls to this method are typically going to be in inner functions, that are called either as
* top-level transactions themselves or are nested inside of larger transactions (e.g. a
* transactional flow). Invocations of reTransact must be vetted to occur in both situations and
* with such complexity that it is not trivial to refactor out the nested transaction calls. New
* code should be written in such a way as to avoid requiring reTransact in the first place.
*
* <p>In the future we will be enforcing that {@link #transact(Runnable)} calls be top-level only,
* with reTransact calls being the only ones that can potentially be an inner nested transaction
* (which is a noop). Note that, as this can be a nested inner exception, there is no overload *
* provided to specify a (potentially conflicting) transaction isolation level.
*/
void reTransact(Runnable work);
/** Returns the time associated with the start of this particular transaction attempt. */ /** Returns the time associated with the start of this particular transaction attempt. */
DateTime getTransactionTime(); DateTime getTransactionTime();

View file

@ -15,11 +15,11 @@
package google.registry.model.tmch; package google.registry.model.tmch;
import static com.google.common.truth.Truth.assertThat; import static com.google.common.truth.Truth.assertThat;
import static com.google.common.truth.Truth8.assertThat;
import static google.registry.persistence.transaction.TransactionManagerFactory.tm; import static google.registry.persistence.transaction.TransactionManagerFactory.tm;
import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertThrows;
import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableMap;
import com.google.common.truth.Truth8;
import google.registry.persistence.transaction.JpaTestExtensions; import google.registry.persistence.transaction.JpaTestExtensions;
import google.registry.persistence.transaction.JpaTestExtensions.JpaIntegrationWithCoverageExtension; import google.registry.persistence.transaction.JpaTestExtensions.JpaIntegrationWithCoverageExtension;
import google.registry.testing.FakeClock; import google.registry.testing.FakeClock;
@ -113,17 +113,21 @@ public class ClaimsListDaoTest {
tm().transact(() -> tm().insert(claimsList)); tm().transact(() -> tm().insert(claimsList));
ClaimsList fromDatabase = ClaimsListDao.get(); ClaimsList fromDatabase = ClaimsListDao.get();
// At first, we haven't loaded any entries // At first, we haven't loaded any entries
assertThat(fromDatabase.claimKeyCache.getIfPresent("label1")).isNull(); assertThat(tm().transact(() -> fromDatabase.claimKeyCache.getIfPresent("label1"))).isNull();
Truth8.assertThat(fromDatabase.getClaimKey("label1")).hasValue("key1"); assertThat(tm().transact(() -> fromDatabase.getClaimKey("label1"))).hasValue("key1");
// After retrieval, the key exists // After retrieval, the key exists
Truth8.assertThat(fromDatabase.claimKeyCache.getIfPresent("label1")).hasValue("key1"); assertThat(tm().transact(() -> fromDatabase.claimKeyCache.getIfPresent("label1")))
assertThat(fromDatabase.claimKeyCache.getIfPresent("label2")).isNull(); .hasValue("key1");
assertThat(tm().transact(() -> fromDatabase.claimKeyCache.getIfPresent("label2"))).isNull();
// Loading labels-to-keys should still work // Loading labels-to-keys should still work
assertThat(fromDatabase.getLabelsToKeys()).containsExactly("label1", "key1", "label2", "key2"); assertThat(tm().transact(() -> fromDatabase.getLabelsToKeys()))
.containsExactly("label1", "key1", "label2", "key2");
// We should also cache nonexistent values // We should also cache nonexistent values
assertThat(fromDatabase.claimKeyCache.getIfPresent("nonexistent")).isNull(); assertThat(tm().transact(() -> fromDatabase.claimKeyCache.getIfPresent("nonexistent")))
Truth8.assertThat(fromDatabase.getClaimKey("nonexistent")).isEmpty(); .isNull();
Truth8.assertThat(fromDatabase.claimKeyCache.getIfPresent("nonexistent")).isEmpty(); assertThat(tm().transact(() -> fromDatabase.getClaimKey("nonexistent"))).isEmpty();
assertThat(tm().transact(() -> fromDatabase.claimKeyCache.getIfPresent("nonexistent")))
.isEmpty();
} }
private void assertClaimsListEquals(ClaimsList left, ClaimsList right) { private void assertClaimsListEquals(ClaimsList left, ClaimsList right) {

View file

@ -116,6 +116,11 @@ public class ReplicaSimulatingJpaTransactionManager implements JpaTransactionMan
isolationLevel); isolationLevel);
} }
@Override
public <T> T reTransact(Supplier<T> work) {
return transact(work);
}
@Override @Override
public <T> T transact(Supplier<T> work) { public <T> T transact(Supplier<T> work) {
return transact(work, null); return transact(work, null);
@ -141,6 +146,11 @@ public class ReplicaSimulatingJpaTransactionManager implements JpaTransactionMan
isolationLevel); isolationLevel);
} }
@Override
public void reTransact(Runnable work) {
transact(work);
}
@Override @Override
public void transact(Runnable work) { public void transact(Runnable work) {
transact(work, null); transact(work, null);

View file

@ -17,6 +17,7 @@ package google.registry.tmch;
import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.truth.Truth.assertThat; import static com.google.common.truth.Truth.assertThat;
import static com.google.common.truth.Truth8.assertThat; import static com.google.common.truth.Truth8.assertThat;
import static google.registry.persistence.transaction.TransactionManagerFactory.tm;
import static org.mockito.Mockito.times; import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
@ -41,7 +42,8 @@ class TmchDnlActionTest extends TmchActionTestCase {
@Test @Test
void testDnl() throws Exception { void testDnl() throws Exception {
assertThat(ClaimsListDao.get().getClaimKey("xn----7sbejwbn3axu3d")).isEmpty(); assertThat(tm().transact(() -> ClaimsListDao.get().getClaimKey("xn----7sbejwbn3axu3d")))
.isEmpty();
when(httpUrlConnection.getInputStream()) when(httpUrlConnection.getInputStream())
.thenReturn(new ByteArrayInputStream(TmchTestData.loadBytes("dnl/dnl-latest.csv").read())) .thenReturn(new ByteArrayInputStream(TmchTestData.loadBytes("dnl/dnl-latest.csv").read()))
.thenReturn(new ByteArrayInputStream(TmchTestData.loadBytes("dnl/dnl-latest.sig").read())); .thenReturn(new ByteArrayInputStream(TmchTestData.loadBytes("dnl/dnl-latest.sig").read()));
@ -54,8 +56,8 @@ class TmchDnlActionTest extends TmchActionTestCase {
ClaimsList claimsList = ClaimsListDao.get(); ClaimsList claimsList = ClaimsListDao.get();
assertThat(claimsList.getTmdbGenerationTime()) assertThat(claimsList.getTmdbGenerationTime())
.isEqualTo(DateTime.parse("2013-11-24T23:15:37.4Z")); .isEqualTo(DateTime.parse("2013-11-24T23:15:37.4Z"));
assertThat(claimsList.getClaimKey("xn----7sbejwbn3axu3d")) assertThat(tm().transact(() -> claimsList.getClaimKey("xn----7sbejwbn3axu3d")))
.hasValue("2013112500/7/4/8/dIHW0DiuybvhdP8kIz"); .hasValue("2013112500/7/4/8/dIHW0DiuybvhdP8kIz");
assertThat(claimsList.getClaimKey("lolcat")).isEmpty(); assertThat(tm().transact(() -> claimsList.getClaimKey("lolcat"))).isEmpty();
} }
} }

View file

@ -16,6 +16,7 @@ package google.registry.tools;
import static com.google.common.truth.Truth.assertThat; import static com.google.common.truth.Truth.assertThat;
import static com.google.common.truth.Truth8.assertThat; import static com.google.common.truth.Truth8.assertThat;
import static google.registry.persistence.transaction.TransactionManagerFactory.tm;
import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertThrows;
import google.registry.model.tmch.ClaimsList; import google.registry.model.tmch.ClaimsList;
@ -46,11 +47,11 @@ class UploadClaimsListCommandTest extends CommandTestCase<UploadClaimsListComman
ClaimsList claimsList = ClaimsListDao.get(); ClaimsList claimsList = ClaimsListDao.get();
assertThat(claimsList.getTmdbGenerationTime()) assertThat(claimsList.getTmdbGenerationTime())
.isEqualTo(DateTime.parse("2012-08-16T00:00:00.0Z")); .isEqualTo(DateTime.parse("2012-08-16T00:00:00.0Z"));
assertThat(claimsList.getClaimKey("example")) assertThat(tm().transact(() -> claimsList.getClaimKey("example")))
.hasValue("2013041500/2/6/9/rJ1NrDO92vDsAzf7EQzgjX4R0000000001"); .hasValue("2013041500/2/6/9/rJ1NrDO92vDsAzf7EQzgjX4R0000000001");
assertThat(claimsList.getClaimKey("another-example")) assertThat(tm().transact(() -> claimsList.getClaimKey("another-example")))
.hasValue("2013041500/6/A/5/alJAqG2vI2BmCv5PfUvuDkf40000000002"); .hasValue("2013041500/6/A/5/alJAqG2vI2BmCv5PfUvuDkf40000000002");
assertThat(claimsList.getClaimKey("anotherexample")) assertThat(tm().transact(() -> claimsList.getClaimKey("anotherexample")))
.hasValue("2013041500/A/C/7/rHdC4wnrWRvPY6nneCVtQhFj0000000003"); .hasValue("2013041500/A/C/7/rHdC4wnrWRvPY6nneCVtQhFj0000000003");
} }