diff --git a/java/google/registry/dns/writer/clouddns/CloudDnsWriter.java b/java/google/registry/dns/writer/clouddns/CloudDnsWriter.java index 6ccb6a52f..b9979085f 100644 --- a/java/google/registry/dns/writer/clouddns/CloudDnsWriter.java +++ b/java/google/registry/dns/writer/clouddns/CloudDnsWriter.java @@ -15,6 +15,7 @@ package google.registry.dns.writer.clouddns; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableSet.toImmutableSet; import static google.registry.model.EppResourceUtils.loadByForeignKey; import com.google.api.client.googleapis.json.GoogleJsonError.ErrorInfo; @@ -27,7 +28,6 @@ import com.google.common.base.Joiner; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; -import com.google.common.collect.ImmutableSet.Builder; import com.google.common.net.InternetDomainName; import com.google.common.util.concurrent.RateLimiter; import google.registry.config.RegistryConfig.Config; @@ -39,17 +39,21 @@ import google.registry.model.domain.secdns.DelegationSignerData; import google.registry.model.host.HostResource; import google.registry.model.registry.Registries; import google.registry.util.Clock; +import google.registry.util.Concurrent; import google.registry.util.FormattingLogger; import google.registry.util.Retrier; import java.io.IOException; import java.net.Inet4Address; import java.net.Inet6Address; import java.net.InetAddress; +import java.util.AbstractMap.SimpleImmutableEntry; import java.util.HashMap; import java.util.HashSet; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.Set; +import java.util.stream.Stream; import javax.inject.Inject; import javax.inject.Named; import org.joda.time.Duration; @@ -73,6 +77,7 @@ public class CloudDnsWriter extends BaseDnsWriter { private final Clock clock; private final RateLimiter rateLimiter; + private final int numThreads; // TODO(shikhman): This uses @Named("transientFailureRetries") which may not be tuned for this // application. private final Retrier retrier; @@ -93,6 +98,7 @@ public class CloudDnsWriter extends BaseDnsWriter { @Config("dnsDefaultNsTtl") Duration defaultNsTtl, @Config("dnsDefaultDsTtl") Duration defaultDsTtl, @Named("cloudDns") RateLimiter rateLimiter, + @Named("cloudDnsNumThreads") int numThreads, Clock clock, Retrier retrier) { this.dnsConnection = dnsConnection; @@ -104,6 +110,7 @@ public class CloudDnsWriter extends BaseDnsWriter { this.rateLimiter = rateLimiter; this.clock = clock; this.retrier = retrier; + this.numThreads = numThreads; } /** Publish the domain and all subordinate hosts. */ @@ -277,76 +284,101 @@ public class CloudDnsWriter extends BaseDnsWriter { logger.info("Wrote to Cloud DNS"); } + /** + * Returns the glue records for in-bailiwick nameservers for the given domain+records. + */ + private Stream filterGlueRecords(String domainName, Stream records) { + return records + .filter(record -> record.getType().equals("NS")) + .flatMap(record -> record.getRrdatas().stream()) + .filter(hostName -> hostName.endsWith(domainName) && !hostName.equals(domainName)); + } + /** * Mutate the zone with the provided {@code desiredRecords}. */ @VisibleForTesting - void mutateZone(ImmutableMap> desiredRecords) - throws IOException { + void mutateZone(ImmutableMap> desiredRecords) { // Fetch all existing records for names that this writer is trying to modify - Builder existingRecords = new Builder<>(); - for (String domainName : desiredRecords.keySet()) { - List existingRecordsForDomain = getResourceRecordsForDomain(domainName); - existingRecords.addAll(existingRecordsForDomain); + ImmutableSet.Builder flattenedExistingRecords = new ImmutableSet.Builder<>(); - // Fetch glue records for in-bailiwick nameservers - for (ResourceRecordSet record : existingRecordsForDomain) { - if (!record.getType().equals("NS")) { - continue; - } - for (String hostName : record.getRrdatas()) { - if (hostName.endsWith(domainName) && !hostName.equals(domainName)) { - existingRecords.addAll(getResourceRecordsForDomain(hostName)); - } - } - } - } + // First, fetch the records for the given domains + Map> domainRecords = + getResourceRecordsForDomains(desiredRecords.keySet()); + + // add the records to the list of exiting records + domainRecords.values().forEach(flattenedExistingRecords::addAll); + + // Get the glue record host names from the given records + ImmutableSet hostsToRead = + domainRecords + .entrySet() + .stream() + .flatMap(entry -> filterGlueRecords(entry.getKey(), entry.getValue().stream())) + .collect(toImmutableSet()); + + // Then fetch and add the records for these hosts + getResourceRecordsForDomains(hostsToRead).values().forEach(flattenedExistingRecords::addAll); // Flatten the desired records into one set. - Builder flattenedDesiredRecords = new Builder<>(); - for (ImmutableSet records : desiredRecords.values()) { - flattenedDesiredRecords.addAll(records); - } + ImmutableSet.Builder flattenedDesiredRecords = new ImmutableSet.Builder<>(); + desiredRecords.values().forEach(flattenedDesiredRecords::addAll); // Delete all existing records and add back the desired records - updateResourceRecords(flattenedDesiredRecords.build(), existingRecords.build()); + updateResourceRecords(flattenedDesiredRecords.build(), flattenedExistingRecords.build()); + } + + /** + * Fetch the {@link ResourceRecordSet}s for the given domain names under this zone. + * + *

The provided domain should be in absolute form. + */ + private Map> getResourceRecordsForDomains( + Set domainNames) { + logger.finefmt("Fetching records for %s", domainNames); + // As per Concurrent.transform() - if numThreads or domainNames.size() < 2, it will not use + // threading. + return ImmutableMap.copyOf( + Concurrent.transform( + domainNames, + numThreads, + domainName -> + new SimpleImmutableEntry<>(domainName, getResourceRecordsForDomain(domainName)))); } /** * Fetch the {@link ResourceRecordSet}s for the given domain name under this zone. * *

The provided domain should be in absolute form. - * - * @throws IOException if the operation could not be completed successfully */ - private List getResourceRecordsForDomain(String domainName) - throws IOException { - logger.finefmt("Fetching records for %s", domainName); - Dns.ResourceRecordSets.List listRecordsRequest = - dnsConnection.resourceRecordSets().list(projectId, zoneName).setName(domainName); + private List getResourceRecordsForDomain(String domainName) { + // TODO(b/70217860): do we want to use a retrier here? + try { + Dns.ResourceRecordSets.List listRecordsRequest = + dnsConnection.resourceRecordSets().list(projectId, zoneName).setName(domainName); - rateLimiter.acquire(); - return listRecordsRequest.execute().getRrsets(); + rateLimiter.acquire(); + return listRecordsRequest.execute().getRrsets(); + } catch (IOException e) { + throw new RuntimeException(e); + } } /** * Update {@link ResourceRecordSet}s under this zone. * - *

This call should be used in conjunction with getResourceRecordsForDomain in a get-and-set - * retry loop. + *

This call should be used in conjunction with {@link #getResourceRecordsForDomains} in a + * get-and-set retry loop. * *

See {@link "https://cloud.google.com/dns/troubleshooting"} for a list of errors produced by * the Google Cloud DNS API. * - * @throws IOException if the operation could not be completed successfully due to an - * uncorrectable error. * @throws ZoneStateException if the operation could not be completely successfully because the * records to delete do not exist, already exist or have been modified with different * attributes since being queried. */ private void updateResourceRecords( - ImmutableSet additions, ImmutableSet deletions) - throws IOException, ZoneStateException { + ImmutableSet additions, ImmutableSet deletions) { Change change = new Change().setAdditions(additions.asList()).setDeletions(deletions.asList()); rateLimiter.acquire(); @@ -356,15 +388,17 @@ public class CloudDnsWriter extends BaseDnsWriter { List errors = e.getDetails().getErrors(); // We did something really wrong here, just give up and re-throw if (errors.size() > 1) { - throw e; + throw new RuntimeException(e); } String errorReason = errors.get(0).getReason(); if (RETRYABLE_EXCEPTION_REASONS.contains(errorReason)) { throw new ZoneStateException(errorReason); } else { - throw e; + throw new RuntimeException(e); } + } catch (IOException e) { + throw new RuntimeException(e); } } diff --git a/java/google/registry/dns/writer/clouddns/CloudDnsWriterModule.java b/java/google/registry/dns/writer/clouddns/CloudDnsWriterModule.java index a71be2d23..2e8084f6b 100644 --- a/java/google/registry/dns/writer/clouddns/CloudDnsWriterModule.java +++ b/java/google/registry/dns/writer/clouddns/CloudDnsWriterModule.java @@ -76,4 +76,14 @@ public final class CloudDnsWriterModule { int cloudDnsMaxQps = 20; return RateLimiter.create(cloudDnsMaxQps); } + + @Provides + @Named("cloudDnsNumThreads") + static int provideNumThreads() { + // TODO(b/70217860): find the "best" number of threads, taking into account running time, App + // Engine constraints, and any Cloud DNS comsiderations etc. + // + // NOTE: any number below 2 will not use threading at all. + return 10; + } } diff --git a/java/google/registry/util/Concurrent.java b/java/google/registry/util/Concurrent.java index 256881bf0..0d7aa862d 100644 --- a/java/google/registry/util/Concurrent.java +++ b/java/google/registry/util/Concurrent.java @@ -45,12 +45,15 @@ public final class Concurrent { * @see #transform(Collection, int, Function) */ public static ImmutableList transform(Collection items, final Function funk) { - return transform(items, max(1, min(items.size(), MAX_THREADS)), funk); + return transform(items, MAX_THREADS, funk); } /** * Processes {@code items} in parallel using {@code funk}, with the specified number of threads. * + *

If the maxThreadCount or the number of items is less than 2, will use a non-concurrent + * transform. + * *

Note: Spawned threads will inherit the same namespace. * * @throws UncheckedExecutionException to wrap the exception thrown by {@code funk}. This will @@ -59,17 +62,18 @@ public final class Concurrent { */ public static ImmutableList transform( Collection items, - int threadCount, + int maxThreadCount, final Function funk) { checkNotNull(funk); checkNotNull(items); - ThreadFactory threadFactory = currentRequestThreadFactory(); + int threadCount = max(1, min(items.size(), maxThreadCount)); + ThreadFactory threadFactory = threadCount > 1 ? currentRequestThreadFactory() : null; if (threadFactory == null) { - // Fall back to non-concurrent transform if we can't get an App Engine thread factory (most - // likely caused by hitting this code from a command-line tool). Default Java system threads - // are not compatible with code that needs to interact with App Engine (such as Objectify), - // which we often have in funk when calling Concurrent.transform(). - // For more info see: http://stackoverflow.com/questions/15976406 + // Fall back to non-concurrent transform if we only want 1 thread, or if we can't get an App + // Engine thread factory (most likely caused by hitting this code from a command-line tool). + // Default Java system threads are not compatible with code that needs to interact with App + // Engine (such as Objectify), which we often have in funk when calling + // Concurrent.transform(). For more info see: http://stackoverflow.com/questions/15976406 return items.stream().map(funk).collect(toImmutableList()); } ExecutorService executor = newFixedThreadPool(threadCount, threadFactory); diff --git a/javatests/google/registry/dns/writer/clouddns/CloudDnsWriterTest.java b/javatests/google/registry/dns/writer/clouddns/CloudDnsWriterTest.java index eaaa499ef..9475be462 100644 --- a/javatests/google/registry/dns/writer/clouddns/CloudDnsWriterTest.java +++ b/javatests/google/registry/dns/writer/clouddns/CloudDnsWriterTest.java @@ -23,6 +23,7 @@ import static google.registry.testing.DatastoreHelper.newHostResource; import static google.registry.testing.DatastoreHelper.persistResource; import static org.mockito.Matchers.anyString; import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -77,10 +78,8 @@ public class CloudDnsWriterTest { @Mock private Dns dnsConnection; @Mock private Dns.ResourceRecordSets resourceRecordSets; - @Mock private Dns.ResourceRecordSets.List listResourceRecordSetsRequest; @Mock private Dns.Changes changes; @Mock private Dns.Changes.Create createChangeRequest; - @Captor ArgumentCaptor recordNameCaptor; @Captor ArgumentCaptor zoneNameCaptor; @Captor ArgumentCaptor changeCaptor; private CloudDnsWriter writer; @@ -90,28 +89,15 @@ public class CloudDnsWriterTest { @Rule public final AppEngineRule appEngine = AppEngineRule.builder().withDatastore().build(); - @Before - public void setUp() throws Exception { - createTld("tld"); - writer = - new CloudDnsWriter( - dnsConnection, - "projectId", - "triple.secret.tld", // used by testInvalidZoneNames() - DEFAULT_A_TTL, - DEFAULT_NS_TTL, - DEFAULT_DS_TTL, - RateLimiter.create(20), - new SystemClock(), - new Retrier(new SystemSleeper(), 5)); - - // Create an empty zone. - stubZone = ImmutableSet.of(); - - when(dnsConnection.changes()).thenReturn(changes); - when(dnsConnection.resourceRecordSets()).thenReturn(resourceRecordSets); - when(resourceRecordSets.list(anyString(), anyString())) - .thenReturn(listResourceRecordSetsRequest); + /* + * Because of multi-threading in the CloudDnsWriter, we need to return a different instance of + * List for every request, with its own ArgumentCaptor. Otherwise, we can't separate the arguments + * of the various Lists + */ + private Dns.ResourceRecordSets.List newListResourceRecordSetsRequestMock() throws Exception { + Dns.ResourceRecordSets.List listResourceRecordSetsRequest = + mock(Dns.ResourceRecordSets.List.class); + ArgumentCaptor recordNameCaptor = ArgumentCaptor.forClass(String.class); when(listResourceRecordSetsRequest.setName(recordNameCaptor.capture())) .thenReturn(listResourceRecordSetsRequest); // Return records from our stub zone when a request to list the records is executed @@ -126,7 +112,34 @@ public class CloudDnsWriterTest { rs -> rs != null && rs.getName().equals(recordNameCaptor.getValue())) .collect(toImmutableList()))); + return listResourceRecordSetsRequest; + } + + @Before + public void setUp() throws Exception { + createTld("tld"); + writer = + new CloudDnsWriter( + dnsConnection, + "projectId", + "triple.secret.tld", // used by testInvalidZoneNames() + DEFAULT_A_TTL, + DEFAULT_NS_TTL, + DEFAULT_DS_TTL, + RateLimiter.create(20), + 10, // max num threads + new SystemClock(), + new Retrier(new SystemSleeper(), 5)); + + // Create an empty zone. + stubZone = ImmutableSet.of(); + + when(dnsConnection.changes()).thenReturn(changes); + when(dnsConnection.resourceRecordSets()).thenReturn(resourceRecordSets); + when(resourceRecordSets.list(anyString(), anyString())) + .thenAnswer( + invocationOnMock -> newListResourceRecordSetsRequestMock()); when(changes.create(anyString(), zoneNameCaptor.capture(), changeCaptor.capture())) .thenReturn(createChangeRequest); // Change our stub zone when a request to change the records is executed