From 798a6ffc74e050a5d239e658ee1bbebfe165741f Mon Sep 17 00:00:00 2001 From: sarahcaseybot Date: Wed, 7 Jun 2023 16:00:50 -0400 Subject: [PATCH] Remove nested transaction from requestDnsRefresh (#2044) * Remove nested transaction from requestDnsRefresh * Add a bulk version * Remove transaction time as field * Only add delay once * have PublishDnsUpdatesAction use bulk refresh --- .../java/google/registry/dns/DnsUtils.java | 34 +++++++++++-- .../registry/dns/PublishDnsUpdatesAction.java | 13 ++--- .../google/registry/dns/RefreshDnsAction.java | 28 ++++++----- .../google/registry/dns/DnsInjectionTest.java | 3 +- .../google/registry/dns/DnsUtilsTest.java | 50 ++++++++++++------- 5 files changed, 87 insertions(+), 41 deletions(-) diff --git a/core/src/main/java/google/registry/dns/DnsUtils.java b/core/src/main/java/google/registry/dns/DnsUtils.java index eb6793856..82c39689b 100644 --- a/core/src/main/java/google/registry/dns/DnsUtils.java +++ b/core/src/main/java/google/registry/dns/DnsUtils.java @@ -17,6 +17,7 @@ package google.registry.dns; import static com.google.common.collect.ImmutableList.toImmutableList; import static google.registry.persistence.transaction.TransactionManagerFactory.tm; +import com.google.common.collect.ImmutableCollection; import com.google.common.collect.ImmutableList; import com.google.common.net.InternetDomainName; import google.registry.model.common.DnsRefreshRequest; @@ -36,28 +37,53 @@ public final class DnsUtils { private DnsUtils() {} private static void requestDnsRefresh(String name, TargetType type, Duration delay) { + tm().assertInTransaction(); // Throws an IllegalArgumentException if the name is not under a managed TLD -- we only update // DNS for names that are under our management. String tld = Tlds.findTldForNameOrThrow(InternetDomainName.from(name)).toString(); - tm().transact( - () -> - tm().insert( + tm().insert(new DnsRefreshRequest(type, name, tld, tm().getTransactionTime().plus(delay))); + } + + private static void requestDnsRefresh( + ImmutableCollection names, TargetType type, Duration delay) { + tm().assertInTransaction(); + DateTime requestTime = tm().getTransactionTime().plus(delay); + tm().insertAll( + names.stream() + .map( + name -> new DnsRefreshRequest( - type, name, tld, tm().getTransactionTime().plus(delay)))); + type, + name, + Tlds.findTldForNameOrThrow(InternetDomainName.from(name)).toString(), + requestTime)) + .collect(toImmutableList())); } public static void requestDomainDnsRefresh(String domainName, Duration delay) { requestDnsRefresh(domainName, TargetType.DOMAIN, delay); } + public static void requestDomainDnsRefresh(ImmutableCollection names, Duration delay) { + requestDnsRefresh(names, TargetType.DOMAIN, delay); + } + public static void requestDomainDnsRefresh(String domainName) { requestDomainDnsRefresh(domainName, Duration.ZERO); } + public static void requestDomainDnsRefresh(ImmutableCollection names) { + requestDomainDnsRefresh(names, Duration.ZERO); + } + public static void requestHostDnsRefresh(String hostName) { requestDnsRefresh(hostName, TargetType.HOST, Duration.ZERO); } + public static void requestHostDnsRefresh(ImmutableCollection hostNames) { + requestDnsRefresh(hostNames, TargetType.HOST, Duration.ZERO); + } + /** * Returns pending DNS update requests that need further processing up to batch size, in ascending * order of their request time, and updates their processing time to now. diff --git a/core/src/main/java/google/registry/dns/PublishDnsUpdatesAction.java b/core/src/main/java/google/registry/dns/PublishDnsUpdatesAction.java index 90d0675ca..ff8199073 100644 --- a/core/src/main/java/google/registry/dns/PublishDnsUpdatesAction.java +++ b/core/src/main/java/google/registry/dns/PublishDnsUpdatesAction.java @@ -26,6 +26,7 @@ import static google.registry.dns.DnsUtils.DNS_PUBLISH_PUSH_QUEUE_NAME; import static google.registry.dns.DnsUtils.requestDomainDnsRefresh; import static google.registry.dns.DnsUtils.requestHostDnsRefresh; import static google.registry.model.EppResourceUtils.loadByForeignKey; +import static google.registry.persistence.transaction.TransactionManagerFactory.tm; import static google.registry.request.Action.Method.POST; import static google.registry.request.RequestParameters.PARAM_TLD; import static google.registry.util.CollectionUtils.nullToEmpty; @@ -34,6 +35,7 @@ import static javax.servlet.http.HttpServletResponse.SC_ACCEPTED; import com.google.common.base.Joiner; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMultimap; +import com.google.common.collect.ImmutableSet; import com.google.common.flogger.FluentLogger; import com.google.common.net.InternetDomainName; import dagger.Lazy; @@ -354,12 +356,11 @@ public final class PublishDnsUpdatesAction implements Runnable, Callable { /** Adds all the domains and hosts in the batch back to the queue to be processed later. */ private void requeueBatch() { logger.atInfo().log("Requeueing batch for retry."); - for (String domain : nullToEmpty(domains)) { - requestDomainDnsRefresh(domain); - } - for (String host : nullToEmpty(hosts)) { - requestHostDnsRefresh(host); - } + tm().transact( + () -> { + requestDomainDnsRefresh(ImmutableSet.copyOf(nullToEmpty(domains))); + requestHostDnsRefresh(ImmutableSet.copyOf(nullToEmpty(hosts))); + }); } /** Returns if the lock parameters are valid for this action. */ diff --git a/core/src/main/java/google/registry/dns/RefreshDnsAction.java b/core/src/main/java/google/registry/dns/RefreshDnsAction.java index 4418b8491..af4233a14 100644 --- a/core/src/main/java/google/registry/dns/RefreshDnsAction.java +++ b/core/src/main/java/google/registry/dns/RefreshDnsAction.java @@ -17,6 +17,7 @@ package google.registry.dns; import static google.registry.dns.DnsUtils.requestDomainDnsRefresh; import static google.registry.dns.DnsUtils.requestHostDnsRefresh; import static google.registry.model.EppResourceUtils.loadByForeignKey; +import static google.registry.persistence.transaction.TransactionManagerFactory.tm; import google.registry.dns.DnsUtils.TargetType; import google.registry.model.EppResource; @@ -59,18 +60,21 @@ public final class RefreshDnsAction implements Runnable { if (!domainOrHostName.contains(".")) { throw new BadRequestException("URL parameter 'name' must be fully qualified"); } - switch (type) { - case DOMAIN: - loadAndVerifyExistence(Domain.class, domainOrHostName); - requestDomainDnsRefresh(domainOrHostName); - break; - case HOST: - verifyHostIsSubordinate(loadAndVerifyExistence(Host.class, domainOrHostName)); - requestHostDnsRefresh(domainOrHostName); - break; - default: - throw new BadRequestException("Unsupported type: " + type); - } + tm().transact( + () -> { + switch (type) { + case DOMAIN: + loadAndVerifyExistence(Domain.class, domainOrHostName); + requestDomainDnsRefresh(domainOrHostName); + break; + case HOST: + verifyHostIsSubordinate(loadAndVerifyExistence(Host.class, domainOrHostName)); + requestHostDnsRefresh(domainOrHostName); + break; + default: + throw new BadRequestException("Unsupported type: " + type); + } + }); } private diff --git a/core/src/test/java/google/registry/dns/DnsInjectionTest.java b/core/src/test/java/google/registry/dns/DnsInjectionTest.java index 37dd60245..285dd62f3 100644 --- a/core/src/test/java/google/registry/dns/DnsInjectionTest.java +++ b/core/src/test/java/google/registry/dns/DnsInjectionTest.java @@ -15,6 +15,7 @@ package google.registry.dns; import static com.google.common.truth.Truth.assertThat; +import static google.registry.persistence.transaction.TransactionManagerFactory.tm; import static google.registry.testing.DatabaseHelper.assertHostDnsRequests; import static google.registry.testing.DatabaseHelper.assertNoDnsRequests; import static google.registry.testing.DatabaseHelper.assertNoDnsRequestsExcept; @@ -68,7 +69,7 @@ public final class DnsInjectionTest { void testReadDnsRefreshRequestsAction_injectsAndWorks() { persistActiveSubordinateHost("ns1.example.lol", persistActiveDomain("example.lol")); clock.advanceOneMilli(); - DnsUtils.requestDomainDnsRefresh("example.lol"); + tm().transact(() -> DnsUtils.requestDomainDnsRefresh("example.lol")); when(req.getParameter("tld")).thenReturn("lol"); clock.advanceOneMilli(); component.readDnsRefreshRequestsAction().run(); diff --git a/core/src/test/java/google/registry/dns/DnsUtilsTest.java b/core/src/test/java/google/registry/dns/DnsUtilsTest.java index e1a8c16a3..69e54d368 100644 --- a/core/src/test/java/google/registry/dns/DnsUtilsTest.java +++ b/core/src/test/java/google/registry/dns/DnsUtilsTest.java @@ -16,6 +16,10 @@ package google.registry.dns; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.truth.Truth.assertThat; +import static google.registry.dns.DnsUtils.deleteRequests; +import static google.registry.dns.DnsUtils.readAndUpdateRequestsWithLatestProcessTime; +import static google.registry.dns.DnsUtils.requestDomainDnsRefresh; +import static google.registry.dns.DnsUtils.requestHostDnsRefresh; import static google.registry.persistence.transaction.TransactionManagerFactory.tm; import static google.registry.testing.DatabaseHelper.createTld; import static google.registry.testing.DatabaseHelper.loadAllOf; @@ -58,7 +62,8 @@ public class DnsUtilsTest { void testFailure_hostRefresh_unmanagedHost() { String unmanagedHostName = "ns1.another.example"; Assertions.assertThrows( - IllegalArgumentException.class, () -> DnsUtils.requestHostDnsRefresh(unmanagedHostName)); + IllegalArgumentException.class, + () -> tm().transact(() -> requestHostDnsRefresh(unmanagedHostName))); assertThat(loadAllOf(DnsRefreshRequest.class)).isEmpty(); } @@ -67,27 +72,38 @@ public class DnsUtilsTest { String unmanagedDomainName = "another.example"; Assertions.assertThrows( IllegalArgumentException.class, - () -> DnsUtils.requestDomainDnsRefresh(unmanagedDomainName)); + () -> tm().transact(() -> requestDomainDnsRefresh(unmanagedDomainName))); assertThat(loadAllOf(DnsRefreshRequest.class)).isEmpty(); } @Test void testSuccess_hostRefresh() { - DnsUtils.requestHostDnsRefresh(hostName); + tm().transact(() -> requestHostDnsRefresh(hostName)); DnsRefreshRequest request = Iterables.getOnlyElement(loadAllOf(DnsRefreshRequest.class)); assertRequest(request, TargetType.HOST, hostName, tld, clock.nowUtc()); } @Test void testSuccess_domainRefresh() { - DnsUtils.requestDomainDnsRefresh(domainName); + tm().transact( + () -> requestDomainDnsRefresh(ImmutableList.of(domainName, "test2.tld", "test3.tld"))); + ImmutableList requests = loadAllOf(DnsRefreshRequest.class); + assertThat(requests.size()).isEqualTo(3); + assertRequest(requests.get(0), TargetType.DOMAIN, domainName, tld, clock.nowUtc()); + assertRequest(requests.get(1), TargetType.DOMAIN, "test2.tld", tld, clock.nowUtc()); + assertRequest(requests.get(2), TargetType.DOMAIN, "test3.tld", tld, clock.nowUtc()); + } + + @Test + void testSuccess_domainRefreshMultipleDomains() { + tm().transact(() -> requestDomainDnsRefresh(domainName)); DnsRefreshRequest request = Iterables.getOnlyElement(loadAllOf(DnsRefreshRequest.class)); assertRequest(request, TargetType.DOMAIN, domainName, tld, clock.nowUtc()); } @Test void testSuccess_domainRefreshWithDelay() { - DnsUtils.requestDomainDnsRefresh(domainName, Duration.standardMinutes(3)); + tm().transact(() -> requestDomainDnsRefresh(domainName, Duration.standardMinutes(3))); DnsRefreshRequest request = Iterables.getOnlyElement(loadAllOf(DnsRefreshRequest.class)); assertRequest(request, TargetType.DOMAIN, domainName, tld, clock.nowUtc().plusMinutes(3)); } @@ -133,8 +149,7 @@ public class DnsUtilsTest { clock.advanceOneMilli(); // Requests within cooldown period not included. - requests = - DnsUtils.readAndUpdateRequestsWithLatestProcessTime("tld", Duration.standardMinutes(1), 4); + requests = readAndUpdateRequestsWithLatestProcessTime("tld", Duration.standardMinutes(1), 4); assertThat(requests.size()).isEqualTo(1); assertRequest( requests.get(0), @@ -147,7 +162,7 @@ public class DnsUtilsTest { @Test void testSuccess_deleteRequests() { - DnsUtils.deleteRequests(processRequests()); + deleteRequests(processRequests()); ImmutableList remainingRequests = loadAllOf(DnsRefreshRequest.class).stream() .sorted(Comparator.comparing(DnsRefreshRequest::getRequestTime)) @@ -174,31 +189,30 @@ public class DnsUtilsTest { tm().transact(() -> tm().delete(remainingRequests.get(2))); assertThat(loadAllOf(DnsRefreshRequest.class).size()).isEqualTo(2); // Should not throw even though one of the request is already deleted. - DnsUtils.deleteRequests(remainingRequests); + deleteRequests(remainingRequests); assertThat(loadAllOf(DnsRefreshRequest.class).size()).isEqualTo(0); } private ImmutableList processRequests() { createTld("example"); // Domain Included. - DnsUtils.requestDomainDnsRefresh("test1.tld", Duration.standardMinutes(1)); + tm().transact(() -> requestDomainDnsRefresh("test1.tld", Duration.standardMinutes(1))); // This one should be returned before test1.tld, even though it's added later, because of // the delay specified in test1.tld. - DnsUtils.requestDomainDnsRefresh("test2.tld"); + tm().transact(() -> requestDomainDnsRefresh("test2.tld")); // Not included because the TLD is not under management. - DnsUtils.requestDomainDnsRefresh("something.example", Duration.standardMinutes(2)); + tm().transact(() -> requestDomainDnsRefresh("something.example", Duration.standardMinutes(2))); clock.advanceBy(Duration.standardMinutes(3)); // Host included. - DnsUtils.requestHostDnsRefresh("ns1.test2.tld"); + tm().transact(() -> requestHostDnsRefresh("ns1.test2.tld")); // Not included because the request time is in the future - DnsUtils.requestDomainDnsRefresh("test4.tld", Duration.standardMinutes(2)); + tm().transact(() -> requestDomainDnsRefresh("test4.tld", Duration.standardMinutes(2))); // Included after the previous one. Same request time, order by insertion order (i.e. ID); - DnsUtils.requestDomainDnsRefresh("test5.tld"); + tm().transact(() -> requestDomainDnsRefresh("test5.tld")); // Not included because batch size is exceeded; - DnsUtils.requestDomainDnsRefresh("test6.tld"); + tm().transact(() -> requestDomainDnsRefresh("test6.tld")); clock.advanceBy(Duration.standardMinutes(1)); - return DnsUtils.readAndUpdateRequestsWithLatestProcessTime( - "tld", Duration.standardMinutes(1), 4); + return readAndUpdateRequestsWithLatestProcessTime("tld", Duration.standardMinutes(1), 4); } private static void assertRequest(