Remove nested transaction from requestDnsRefresh (#2044)

* Remove nested transaction from requestDnsRefresh

* Add a bulk version

* Remove transaction time as field

* Only add delay once

* have PublishDnsUpdatesAction use bulk refresh
This commit is contained in:
sarahcaseybot 2023-06-07 16:00:50 -04:00 committed by GitHub
parent fe86ef0a7d
commit 798a6ffc74
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 87 additions and 41 deletions

View file

@ -17,6 +17,7 @@ package google.registry.dns;
import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableList.toImmutableList;
import static google.registry.persistence.transaction.TransactionManagerFactory.tm; import static google.registry.persistence.transaction.TransactionManagerFactory.tm;
import com.google.common.collect.ImmutableCollection;
import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList;
import com.google.common.net.InternetDomainName; import com.google.common.net.InternetDomainName;
import google.registry.model.common.DnsRefreshRequest; import google.registry.model.common.DnsRefreshRequest;
@ -36,28 +37,53 @@ public final class DnsUtils {
private DnsUtils() {} private DnsUtils() {}
private static void requestDnsRefresh(String name, TargetType type, Duration delay) { private static void requestDnsRefresh(String name, TargetType type, Duration delay) {
tm().assertInTransaction();
// Throws an IllegalArgumentException if the name is not under a managed TLD -- we only update // Throws an IllegalArgumentException if the name is not under a managed TLD -- we only update
// DNS for names that are under our management. // DNS for names that are under our management.
String tld = Tlds.findTldForNameOrThrow(InternetDomainName.from(name)).toString(); String tld = Tlds.findTldForNameOrThrow(InternetDomainName.from(name)).toString();
tm().transact( tm().insert(new DnsRefreshRequest(type, name, tld, tm().getTransactionTime().plus(delay)));
() -> }
tm().insert(
private static void requestDnsRefresh(
ImmutableCollection<String> names, TargetType type, Duration delay) {
tm().assertInTransaction();
DateTime requestTime = tm().getTransactionTime().plus(delay);
tm().insertAll(
names.stream()
.map(
name ->
new DnsRefreshRequest( new DnsRefreshRequest(
type, name, tld, tm().getTransactionTime().plus(delay)))); type,
name,
Tlds.findTldForNameOrThrow(InternetDomainName.from(name)).toString(),
requestTime))
.collect(toImmutableList()));
} }
public static void requestDomainDnsRefresh(String domainName, Duration delay) { public static void requestDomainDnsRefresh(String domainName, Duration delay) {
requestDnsRefresh(domainName, TargetType.DOMAIN, delay); requestDnsRefresh(domainName, TargetType.DOMAIN, delay);
} }
public static void requestDomainDnsRefresh(ImmutableCollection<String> names, Duration delay) {
requestDnsRefresh(names, TargetType.DOMAIN, delay);
}
public static void requestDomainDnsRefresh(String domainName) { public static void requestDomainDnsRefresh(String domainName) {
requestDomainDnsRefresh(domainName, Duration.ZERO); requestDomainDnsRefresh(domainName, Duration.ZERO);
} }
public static void requestDomainDnsRefresh(ImmutableCollection<String> names) {
requestDomainDnsRefresh(names, Duration.ZERO);
}
public static void requestHostDnsRefresh(String hostName) { public static void requestHostDnsRefresh(String hostName) {
requestDnsRefresh(hostName, TargetType.HOST, Duration.ZERO); requestDnsRefresh(hostName, TargetType.HOST, Duration.ZERO);
} }
public static void requestHostDnsRefresh(ImmutableCollection<String> hostNames) {
requestDnsRefresh(hostNames, TargetType.HOST, Duration.ZERO);
}
/** /**
* Returns pending DNS update requests that need further processing up to batch size, in ascending * Returns pending DNS update requests that need further processing up to batch size, in ascending
* order of their request time, and updates their processing time to now. * order of their request time, and updates their processing time to now.

View file

@ -26,6 +26,7 @@ import static google.registry.dns.DnsUtils.DNS_PUBLISH_PUSH_QUEUE_NAME;
import static google.registry.dns.DnsUtils.requestDomainDnsRefresh; import static google.registry.dns.DnsUtils.requestDomainDnsRefresh;
import static google.registry.dns.DnsUtils.requestHostDnsRefresh; import static google.registry.dns.DnsUtils.requestHostDnsRefresh;
import static google.registry.model.EppResourceUtils.loadByForeignKey; import static google.registry.model.EppResourceUtils.loadByForeignKey;
import static google.registry.persistence.transaction.TransactionManagerFactory.tm;
import static google.registry.request.Action.Method.POST; import static google.registry.request.Action.Method.POST;
import static google.registry.request.RequestParameters.PARAM_TLD; import static google.registry.request.RequestParameters.PARAM_TLD;
import static google.registry.util.CollectionUtils.nullToEmpty; import static google.registry.util.CollectionUtils.nullToEmpty;
@ -34,6 +35,7 @@ import static javax.servlet.http.HttpServletResponse.SC_ACCEPTED;
import com.google.common.base.Joiner; import com.google.common.base.Joiner;
import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMultimap; import com.google.common.collect.ImmutableMultimap;
import com.google.common.collect.ImmutableSet;
import com.google.common.flogger.FluentLogger; import com.google.common.flogger.FluentLogger;
import com.google.common.net.InternetDomainName; import com.google.common.net.InternetDomainName;
import dagger.Lazy; import dagger.Lazy;
@ -354,12 +356,11 @@ public final class PublishDnsUpdatesAction implements Runnable, Callable<Void> {
/** Adds all the domains and hosts in the batch back to the queue to be processed later. */ /** Adds all the domains and hosts in the batch back to the queue to be processed later. */
private void requeueBatch() { private void requeueBatch() {
logger.atInfo().log("Requeueing batch for retry."); logger.atInfo().log("Requeueing batch for retry.");
for (String domain : nullToEmpty(domains)) { tm().transact(
requestDomainDnsRefresh(domain); () -> {
} requestDomainDnsRefresh(ImmutableSet.copyOf(nullToEmpty(domains)));
for (String host : nullToEmpty(hosts)) { requestHostDnsRefresh(ImmutableSet.copyOf(nullToEmpty(hosts)));
requestHostDnsRefresh(host); });
}
} }
/** Returns if the lock parameters are valid for this action. */ /** Returns if the lock parameters are valid for this action. */

View file

@ -17,6 +17,7 @@ package google.registry.dns;
import static google.registry.dns.DnsUtils.requestDomainDnsRefresh; import static google.registry.dns.DnsUtils.requestDomainDnsRefresh;
import static google.registry.dns.DnsUtils.requestHostDnsRefresh; import static google.registry.dns.DnsUtils.requestHostDnsRefresh;
import static google.registry.model.EppResourceUtils.loadByForeignKey; import static google.registry.model.EppResourceUtils.loadByForeignKey;
import static google.registry.persistence.transaction.TransactionManagerFactory.tm;
import google.registry.dns.DnsUtils.TargetType; import google.registry.dns.DnsUtils.TargetType;
import google.registry.model.EppResource; import google.registry.model.EppResource;
@ -59,18 +60,21 @@ public final class RefreshDnsAction implements Runnable {
if (!domainOrHostName.contains(".")) { if (!domainOrHostName.contains(".")) {
throw new BadRequestException("URL parameter 'name' must be fully qualified"); throw new BadRequestException("URL parameter 'name' must be fully qualified");
} }
switch (type) { tm().transact(
case DOMAIN: () -> {
loadAndVerifyExistence(Domain.class, domainOrHostName); switch (type) {
requestDomainDnsRefresh(domainOrHostName); case DOMAIN:
break; loadAndVerifyExistence(Domain.class, domainOrHostName);
case HOST: requestDomainDnsRefresh(domainOrHostName);
verifyHostIsSubordinate(loadAndVerifyExistence(Host.class, domainOrHostName)); break;
requestHostDnsRefresh(domainOrHostName); case HOST:
break; verifyHostIsSubordinate(loadAndVerifyExistence(Host.class, domainOrHostName));
default: requestHostDnsRefresh(domainOrHostName);
throw new BadRequestException("Unsupported type: " + type); break;
} default:
throw new BadRequestException("Unsupported type: " + type);
}
});
} }
private <T extends EppResource & ForeignKeyedEppResource> private <T extends EppResource & ForeignKeyedEppResource>

View file

@ -15,6 +15,7 @@
package google.registry.dns; package google.registry.dns;
import static com.google.common.truth.Truth.assertThat; import static com.google.common.truth.Truth.assertThat;
import static google.registry.persistence.transaction.TransactionManagerFactory.tm;
import static google.registry.testing.DatabaseHelper.assertHostDnsRequests; import static google.registry.testing.DatabaseHelper.assertHostDnsRequests;
import static google.registry.testing.DatabaseHelper.assertNoDnsRequests; import static google.registry.testing.DatabaseHelper.assertNoDnsRequests;
import static google.registry.testing.DatabaseHelper.assertNoDnsRequestsExcept; import static google.registry.testing.DatabaseHelper.assertNoDnsRequestsExcept;
@ -68,7 +69,7 @@ public final class DnsInjectionTest {
void testReadDnsRefreshRequestsAction_injectsAndWorks() { void testReadDnsRefreshRequestsAction_injectsAndWorks() {
persistActiveSubordinateHost("ns1.example.lol", persistActiveDomain("example.lol")); persistActiveSubordinateHost("ns1.example.lol", persistActiveDomain("example.lol"));
clock.advanceOneMilli(); clock.advanceOneMilli();
DnsUtils.requestDomainDnsRefresh("example.lol"); tm().transact(() -> DnsUtils.requestDomainDnsRefresh("example.lol"));
when(req.getParameter("tld")).thenReturn("lol"); when(req.getParameter("tld")).thenReturn("lol");
clock.advanceOneMilli(); clock.advanceOneMilli();
component.readDnsRefreshRequestsAction().run(); component.readDnsRefreshRequestsAction().run();

View file

@ -16,6 +16,10 @@ package google.registry.dns;
import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.truth.Truth.assertThat; import static com.google.common.truth.Truth.assertThat;
import static google.registry.dns.DnsUtils.deleteRequests;
import static google.registry.dns.DnsUtils.readAndUpdateRequestsWithLatestProcessTime;
import static google.registry.dns.DnsUtils.requestDomainDnsRefresh;
import static google.registry.dns.DnsUtils.requestHostDnsRefresh;
import static google.registry.persistence.transaction.TransactionManagerFactory.tm; import static google.registry.persistence.transaction.TransactionManagerFactory.tm;
import static google.registry.testing.DatabaseHelper.createTld; import static google.registry.testing.DatabaseHelper.createTld;
import static google.registry.testing.DatabaseHelper.loadAllOf; import static google.registry.testing.DatabaseHelper.loadAllOf;
@ -58,7 +62,8 @@ public class DnsUtilsTest {
void testFailure_hostRefresh_unmanagedHost() { void testFailure_hostRefresh_unmanagedHost() {
String unmanagedHostName = "ns1.another.example"; String unmanagedHostName = "ns1.another.example";
Assertions.assertThrows( Assertions.assertThrows(
IllegalArgumentException.class, () -> DnsUtils.requestHostDnsRefresh(unmanagedHostName)); IllegalArgumentException.class,
() -> tm().transact(() -> requestHostDnsRefresh(unmanagedHostName)));
assertThat(loadAllOf(DnsRefreshRequest.class)).isEmpty(); assertThat(loadAllOf(DnsRefreshRequest.class)).isEmpty();
} }
@ -67,27 +72,38 @@ public class DnsUtilsTest {
String unmanagedDomainName = "another.example"; String unmanagedDomainName = "another.example";
Assertions.assertThrows( Assertions.assertThrows(
IllegalArgumentException.class, IllegalArgumentException.class,
() -> DnsUtils.requestDomainDnsRefresh(unmanagedDomainName)); () -> tm().transact(() -> requestDomainDnsRefresh(unmanagedDomainName)));
assertThat(loadAllOf(DnsRefreshRequest.class)).isEmpty(); assertThat(loadAllOf(DnsRefreshRequest.class)).isEmpty();
} }
@Test @Test
void testSuccess_hostRefresh() { void testSuccess_hostRefresh() {
DnsUtils.requestHostDnsRefresh(hostName); tm().transact(() -> requestHostDnsRefresh(hostName));
DnsRefreshRequest request = Iterables.getOnlyElement(loadAllOf(DnsRefreshRequest.class)); DnsRefreshRequest request = Iterables.getOnlyElement(loadAllOf(DnsRefreshRequest.class));
assertRequest(request, TargetType.HOST, hostName, tld, clock.nowUtc()); assertRequest(request, TargetType.HOST, hostName, tld, clock.nowUtc());
} }
@Test @Test
void testSuccess_domainRefresh() { void testSuccess_domainRefresh() {
DnsUtils.requestDomainDnsRefresh(domainName); tm().transact(
() -> requestDomainDnsRefresh(ImmutableList.of(domainName, "test2.tld", "test3.tld")));
ImmutableList<DnsRefreshRequest> requests = loadAllOf(DnsRefreshRequest.class);
assertThat(requests.size()).isEqualTo(3);
assertRequest(requests.get(0), TargetType.DOMAIN, domainName, tld, clock.nowUtc());
assertRequest(requests.get(1), TargetType.DOMAIN, "test2.tld", tld, clock.nowUtc());
assertRequest(requests.get(2), TargetType.DOMAIN, "test3.tld", tld, clock.nowUtc());
}
@Test
void testSuccess_domainRefreshMultipleDomains() {
tm().transact(() -> requestDomainDnsRefresh(domainName));
DnsRefreshRequest request = Iterables.getOnlyElement(loadAllOf(DnsRefreshRequest.class)); DnsRefreshRequest request = Iterables.getOnlyElement(loadAllOf(DnsRefreshRequest.class));
assertRequest(request, TargetType.DOMAIN, domainName, tld, clock.nowUtc()); assertRequest(request, TargetType.DOMAIN, domainName, tld, clock.nowUtc());
} }
@Test @Test
void testSuccess_domainRefreshWithDelay() { void testSuccess_domainRefreshWithDelay() {
DnsUtils.requestDomainDnsRefresh(domainName, Duration.standardMinutes(3)); tm().transact(() -> requestDomainDnsRefresh(domainName, Duration.standardMinutes(3)));
DnsRefreshRequest request = Iterables.getOnlyElement(loadAllOf(DnsRefreshRequest.class)); DnsRefreshRequest request = Iterables.getOnlyElement(loadAllOf(DnsRefreshRequest.class));
assertRequest(request, TargetType.DOMAIN, domainName, tld, clock.nowUtc().plusMinutes(3)); assertRequest(request, TargetType.DOMAIN, domainName, tld, clock.nowUtc().plusMinutes(3));
} }
@ -133,8 +149,7 @@ public class DnsUtilsTest {
clock.advanceOneMilli(); clock.advanceOneMilli();
// Requests within cooldown period not included. // Requests within cooldown period not included.
requests = requests = readAndUpdateRequestsWithLatestProcessTime("tld", Duration.standardMinutes(1), 4);
DnsUtils.readAndUpdateRequestsWithLatestProcessTime("tld", Duration.standardMinutes(1), 4);
assertThat(requests.size()).isEqualTo(1); assertThat(requests.size()).isEqualTo(1);
assertRequest( assertRequest(
requests.get(0), requests.get(0),
@ -147,7 +162,7 @@ public class DnsUtilsTest {
@Test @Test
void testSuccess_deleteRequests() { void testSuccess_deleteRequests() {
DnsUtils.deleteRequests(processRequests()); deleteRequests(processRequests());
ImmutableList<DnsRefreshRequest> remainingRequests = ImmutableList<DnsRefreshRequest> remainingRequests =
loadAllOf(DnsRefreshRequest.class).stream() loadAllOf(DnsRefreshRequest.class).stream()
.sorted(Comparator.comparing(DnsRefreshRequest::getRequestTime)) .sorted(Comparator.comparing(DnsRefreshRequest::getRequestTime))
@ -174,31 +189,30 @@ public class DnsUtilsTest {
tm().transact(() -> tm().delete(remainingRequests.get(2))); tm().transact(() -> tm().delete(remainingRequests.get(2)));
assertThat(loadAllOf(DnsRefreshRequest.class).size()).isEqualTo(2); assertThat(loadAllOf(DnsRefreshRequest.class).size()).isEqualTo(2);
// Should not throw even though one of the request is already deleted. // Should not throw even though one of the request is already deleted.
DnsUtils.deleteRequests(remainingRequests); deleteRequests(remainingRequests);
assertThat(loadAllOf(DnsRefreshRequest.class).size()).isEqualTo(0); assertThat(loadAllOf(DnsRefreshRequest.class).size()).isEqualTo(0);
} }
private ImmutableList<DnsRefreshRequest> processRequests() { private ImmutableList<DnsRefreshRequest> processRequests() {
createTld("example"); createTld("example");
// Domain Included. // Domain Included.
DnsUtils.requestDomainDnsRefresh("test1.tld", Duration.standardMinutes(1)); tm().transact(() -> requestDomainDnsRefresh("test1.tld", Duration.standardMinutes(1)));
// This one should be returned before test1.tld, even though it's added later, because of // This one should be returned before test1.tld, even though it's added later, because of
// the delay specified in test1.tld. // the delay specified in test1.tld.
DnsUtils.requestDomainDnsRefresh("test2.tld"); tm().transact(() -> requestDomainDnsRefresh("test2.tld"));
// Not included because the TLD is not under management. // Not included because the TLD is not under management.
DnsUtils.requestDomainDnsRefresh("something.example", Duration.standardMinutes(2)); tm().transact(() -> requestDomainDnsRefresh("something.example", Duration.standardMinutes(2)));
clock.advanceBy(Duration.standardMinutes(3)); clock.advanceBy(Duration.standardMinutes(3));
// Host included. // Host included.
DnsUtils.requestHostDnsRefresh("ns1.test2.tld"); tm().transact(() -> requestHostDnsRefresh("ns1.test2.tld"));
// Not included because the request time is in the future // Not included because the request time is in the future
DnsUtils.requestDomainDnsRefresh("test4.tld", Duration.standardMinutes(2)); tm().transact(() -> requestDomainDnsRefresh("test4.tld", Duration.standardMinutes(2)));
// Included after the previous one. Same request time, order by insertion order (i.e. ID); // Included after the previous one. Same request time, order by insertion order (i.e. ID);
DnsUtils.requestDomainDnsRefresh("test5.tld"); tm().transact(() -> requestDomainDnsRefresh("test5.tld"));
// Not included because batch size is exceeded; // Not included because batch size is exceeded;
DnsUtils.requestDomainDnsRefresh("test6.tld"); tm().transact(() -> requestDomainDnsRefresh("test6.tld"));
clock.advanceBy(Duration.standardMinutes(1)); clock.advanceBy(Duration.standardMinutes(1));
return DnsUtils.readAndUpdateRequestsWithLatestProcessTime( return readAndUpdateRequestsWithLatestProcessTime("tld", Duration.standardMinutes(1), 4);
"tld", Duration.standardMinutes(1), 4);
} }
private static void assertRequest( private static void assertRequest(