Fetch data from Cloud DNS in parallel

Before pushing an update to Cloud DNS, the CloudDnsWriter needs to read all the domain RRSs from Cloud DNS one by one to know what to delete.

Doing so sequentially results in update times that are too long (approx 200ms per domain, which is 20 seconds per batch of 100) severely limiting our QPS.

This CL uses Concurrent threading to do the Cloud DNS queries in parallel. Unfortunately, my preferred method (Set.parallelStream) doesn't work on App Engine :(

This reduces the per-item time from 200ms to 80ms, which can be further reduced to 50ms if we remove the rate limiter (currently set to 20 per second).

-------------
Created by MOE: https://github.com/google/moe
MOE_MIGRATED_REVID=178126877
This commit is contained in:
guyben 2017-12-06 11:30:22 -08:00 committed by jianglai
parent 735112def6
commit d87f01e7bf
4 changed files with 134 additions and 73 deletions

View file

@ -15,6 +15,7 @@
package google.registry.dns.writer.clouddns; package google.registry.dns.writer.clouddns;
import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static google.registry.model.EppResourceUtils.loadByForeignKey; import static google.registry.model.EppResourceUtils.loadByForeignKey;
import com.google.api.client.googleapis.json.GoogleJsonError.ErrorInfo; import com.google.api.client.googleapis.json.GoogleJsonError.ErrorInfo;
@ -27,7 +28,6 @@ import com.google.common.base.Joiner;
import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet; import com.google.common.collect.ImmutableSet;
import com.google.common.collect.ImmutableSet.Builder;
import com.google.common.net.InternetDomainName; import com.google.common.net.InternetDomainName;
import com.google.common.util.concurrent.RateLimiter; import com.google.common.util.concurrent.RateLimiter;
import google.registry.config.RegistryConfig.Config; import google.registry.config.RegistryConfig.Config;
@ -39,17 +39,21 @@ import google.registry.model.domain.secdns.DelegationSignerData;
import google.registry.model.host.HostResource; import google.registry.model.host.HostResource;
import google.registry.model.registry.Registries; import google.registry.model.registry.Registries;
import google.registry.util.Clock; import google.registry.util.Clock;
import google.registry.util.Concurrent;
import google.registry.util.FormattingLogger; import google.registry.util.FormattingLogger;
import google.registry.util.Retrier; import google.registry.util.Retrier;
import java.io.IOException; import java.io.IOException;
import java.net.Inet4Address; import java.net.Inet4Address;
import java.net.Inet6Address; import java.net.Inet6Address;
import java.net.InetAddress; import java.net.InetAddress;
import java.util.AbstractMap.SimpleImmutableEntry;
import java.util.HashMap; import java.util.HashMap;
import java.util.HashSet; import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.Optional; import java.util.Optional;
import java.util.Set; import java.util.Set;
import java.util.stream.Stream;
import javax.inject.Inject; import javax.inject.Inject;
import javax.inject.Named; import javax.inject.Named;
import org.joda.time.Duration; import org.joda.time.Duration;
@ -73,6 +77,7 @@ public class CloudDnsWriter extends BaseDnsWriter {
private final Clock clock; private final Clock clock;
private final RateLimiter rateLimiter; private final RateLimiter rateLimiter;
private final int numThreads;
// TODO(shikhman): This uses @Named("transientFailureRetries") which may not be tuned for this // TODO(shikhman): This uses @Named("transientFailureRetries") which may not be tuned for this
// application. // application.
private final Retrier retrier; private final Retrier retrier;
@ -93,6 +98,7 @@ public class CloudDnsWriter extends BaseDnsWriter {
@Config("dnsDefaultNsTtl") Duration defaultNsTtl, @Config("dnsDefaultNsTtl") Duration defaultNsTtl,
@Config("dnsDefaultDsTtl") Duration defaultDsTtl, @Config("dnsDefaultDsTtl") Duration defaultDsTtl,
@Named("cloudDns") RateLimiter rateLimiter, @Named("cloudDns") RateLimiter rateLimiter,
@Named("cloudDnsNumThreads") int numThreads,
Clock clock, Clock clock,
Retrier retrier) { Retrier retrier) {
this.dnsConnection = dnsConnection; this.dnsConnection = dnsConnection;
@ -104,6 +110,7 @@ public class CloudDnsWriter extends BaseDnsWriter {
this.rateLimiter = rateLimiter; this.rateLimiter = rateLimiter;
this.clock = clock; this.clock = clock;
this.retrier = retrier; this.retrier = retrier;
this.numThreads = numThreads;
} }
/** Publish the domain and all subordinate hosts. */ /** Publish the domain and all subordinate hosts. */
@ -277,76 +284,101 @@ public class CloudDnsWriter extends BaseDnsWriter {
logger.info("Wrote to Cloud DNS"); logger.info("Wrote to Cloud DNS");
} }
/**
* Returns the glue records for in-bailiwick nameservers for the given domain+records.
*/
private Stream<String> filterGlueRecords(String domainName, Stream<ResourceRecordSet> records) {
return records
.filter(record -> record.getType().equals("NS"))
.flatMap(record -> record.getRrdatas().stream())
.filter(hostName -> hostName.endsWith(domainName) && !hostName.equals(domainName));
}
/** /**
* Mutate the zone with the provided {@code desiredRecords}. * Mutate the zone with the provided {@code desiredRecords}.
*/ */
@VisibleForTesting @VisibleForTesting
void mutateZone(ImmutableMap<String, ImmutableSet<ResourceRecordSet>> desiredRecords) void mutateZone(ImmutableMap<String, ImmutableSet<ResourceRecordSet>> desiredRecords) {
throws IOException {
// Fetch all existing records for names that this writer is trying to modify // Fetch all existing records for names that this writer is trying to modify
Builder<ResourceRecordSet> existingRecords = new Builder<>(); ImmutableSet.Builder<ResourceRecordSet> flattenedExistingRecords = new ImmutableSet.Builder<>();
for (String domainName : desiredRecords.keySet()) {
List<ResourceRecordSet> existingRecordsForDomain = getResourceRecordsForDomain(domainName);
existingRecords.addAll(existingRecordsForDomain);
// Fetch glue records for in-bailiwick nameservers // First, fetch the records for the given domains
for (ResourceRecordSet record : existingRecordsForDomain) { Map<String, List<ResourceRecordSet>> domainRecords =
if (!record.getType().equals("NS")) { getResourceRecordsForDomains(desiredRecords.keySet());
continue;
} // add the records to the list of exiting records
for (String hostName : record.getRrdatas()) { domainRecords.values().forEach(flattenedExistingRecords::addAll);
if (hostName.endsWith(domainName) && !hostName.equals(domainName)) {
existingRecords.addAll(getResourceRecordsForDomain(hostName)); // Get the glue record host names from the given records
} ImmutableSet<String> hostsToRead =
} domainRecords
} .entrySet()
} .stream()
.flatMap(entry -> filterGlueRecords(entry.getKey(), entry.getValue().stream()))
.collect(toImmutableSet());
// Then fetch and add the records for these hosts
getResourceRecordsForDomains(hostsToRead).values().forEach(flattenedExistingRecords::addAll);
// Flatten the desired records into one set. // Flatten the desired records into one set.
Builder<ResourceRecordSet> flattenedDesiredRecords = new Builder<>(); ImmutableSet.Builder<ResourceRecordSet> flattenedDesiredRecords = new ImmutableSet.Builder<>();
for (ImmutableSet<ResourceRecordSet> records : desiredRecords.values()) { desiredRecords.values().forEach(flattenedDesiredRecords::addAll);
flattenedDesiredRecords.addAll(records);
}
// Delete all existing records and add back the desired records // Delete all existing records and add back the desired records
updateResourceRecords(flattenedDesiredRecords.build(), existingRecords.build()); updateResourceRecords(flattenedDesiredRecords.build(), flattenedExistingRecords.build());
}
/**
* Fetch the {@link ResourceRecordSet}s for the given domain names under this zone.
*
* <p>The provided domain should be in absolute form.
*/
private Map<String, List<ResourceRecordSet>> getResourceRecordsForDomains(
Set<String> domainNames) {
logger.finefmt("Fetching records for %s", domainNames);
// As per Concurrent.transform() - if numThreads or domainNames.size() < 2, it will not use
// threading.
return ImmutableMap.copyOf(
Concurrent.transform(
domainNames,
numThreads,
domainName ->
new SimpleImmutableEntry<>(domainName, getResourceRecordsForDomain(domainName))));
} }
/** /**
* Fetch the {@link ResourceRecordSet}s for the given domain name under this zone. * Fetch the {@link ResourceRecordSet}s for the given domain name under this zone.
* *
* <p>The provided domain should be in absolute form. * <p>The provided domain should be in absolute form.
*
* @throws IOException if the operation could not be completed successfully
*/ */
private List<ResourceRecordSet> getResourceRecordsForDomain(String domainName) private List<ResourceRecordSet> getResourceRecordsForDomain(String domainName) {
throws IOException { // TODO(b/70217860): do we want to use a retrier here?
logger.finefmt("Fetching records for %s", domainName); try {
Dns.ResourceRecordSets.List listRecordsRequest = Dns.ResourceRecordSets.List listRecordsRequest =
dnsConnection.resourceRecordSets().list(projectId, zoneName).setName(domainName); dnsConnection.resourceRecordSets().list(projectId, zoneName).setName(domainName);
rateLimiter.acquire(); rateLimiter.acquire();
return listRecordsRequest.execute().getRrsets(); return listRecordsRequest.execute().getRrsets();
} catch (IOException e) {
throw new RuntimeException(e);
}
} }
/** /**
* Update {@link ResourceRecordSet}s under this zone. * Update {@link ResourceRecordSet}s under this zone.
* *
* <p>This call should be used in conjunction with getResourceRecordsForDomain in a get-and-set * <p>This call should be used in conjunction with {@link #getResourceRecordsForDomains} in a
* retry loop. * get-and-set retry loop.
* *
* <p>See {@link "https://cloud.google.com/dns/troubleshooting"} for a list of errors produced by * <p>See {@link "https://cloud.google.com/dns/troubleshooting"} for a list of errors produced by
* the Google Cloud DNS API. * the Google Cloud DNS API.
* *
* @throws IOException if the operation could not be completed successfully due to an
* uncorrectable error.
* @throws ZoneStateException if the operation could not be completely successfully because the * @throws ZoneStateException if the operation could not be completely successfully because the
* records to delete do not exist, already exist or have been modified with different * records to delete do not exist, already exist or have been modified with different
* attributes since being queried. * attributes since being queried.
*/ */
private void updateResourceRecords( private void updateResourceRecords(
ImmutableSet<ResourceRecordSet> additions, ImmutableSet<ResourceRecordSet> deletions) ImmutableSet<ResourceRecordSet> additions, ImmutableSet<ResourceRecordSet> deletions) {
throws IOException, ZoneStateException {
Change change = new Change().setAdditions(additions.asList()).setDeletions(deletions.asList()); Change change = new Change().setAdditions(additions.asList()).setDeletions(deletions.asList());
rateLimiter.acquire(); rateLimiter.acquire();
@ -356,15 +388,17 @@ public class CloudDnsWriter extends BaseDnsWriter {
List<ErrorInfo> errors = e.getDetails().getErrors(); List<ErrorInfo> errors = e.getDetails().getErrors();
// We did something really wrong here, just give up and re-throw // We did something really wrong here, just give up and re-throw
if (errors.size() > 1) { if (errors.size() > 1) {
throw e; throw new RuntimeException(e);
} }
String errorReason = errors.get(0).getReason(); String errorReason = errors.get(0).getReason();
if (RETRYABLE_EXCEPTION_REASONS.contains(errorReason)) { if (RETRYABLE_EXCEPTION_REASONS.contains(errorReason)) {
throw new ZoneStateException(errorReason); throw new ZoneStateException(errorReason);
} else { } else {
throw e; throw new RuntimeException(e);
} }
} catch (IOException e) {
throw new RuntimeException(e);
} }
} }

View file

@ -76,4 +76,14 @@ public final class CloudDnsWriterModule {
int cloudDnsMaxQps = 20; int cloudDnsMaxQps = 20;
return RateLimiter.create(cloudDnsMaxQps); return RateLimiter.create(cloudDnsMaxQps);
} }
@Provides
@Named("cloudDnsNumThreads")
static int provideNumThreads() {
// TODO(b/70217860): find the "best" number of threads, taking into account running time, App
// Engine constraints, and any Cloud DNS comsiderations etc.
//
// NOTE: any number below 2 will not use threading at all.
return 10;
}
} }

View file

@ -45,12 +45,15 @@ public final class Concurrent {
* @see #transform(Collection, int, Function) * @see #transform(Collection, int, Function)
*/ */
public static <A, B> ImmutableList<B> transform(Collection<A> items, final Function<A, B> funk) { public static <A, B> ImmutableList<B> transform(Collection<A> items, final Function<A, B> funk) {
return transform(items, max(1, min(items.size(), MAX_THREADS)), funk); return transform(items, MAX_THREADS, funk);
} }
/** /**
* Processes {@code items} in parallel using {@code funk}, with the specified number of threads. * Processes {@code items} in parallel using {@code funk}, with the specified number of threads.
* *
* <p>If the maxThreadCount or the number of items is less than 2, will use a non-concurrent
* transform.
*
* <p><b>Note:</b> Spawned threads will inherit the same namespace. * <p><b>Note:</b> Spawned threads will inherit the same namespace.
* *
* @throws UncheckedExecutionException to wrap the exception thrown by {@code funk}. This will * @throws UncheckedExecutionException to wrap the exception thrown by {@code funk}. This will
@ -59,17 +62,18 @@ public final class Concurrent {
*/ */
public static <A, B> ImmutableList<B> transform( public static <A, B> ImmutableList<B> transform(
Collection<A> items, Collection<A> items,
int threadCount, int maxThreadCount,
final Function<A, B> funk) { final Function<A, B> funk) {
checkNotNull(funk); checkNotNull(funk);
checkNotNull(items); checkNotNull(items);
ThreadFactory threadFactory = currentRequestThreadFactory(); int threadCount = max(1, min(items.size(), maxThreadCount));
ThreadFactory threadFactory = threadCount > 1 ? currentRequestThreadFactory() : null;
if (threadFactory == null) { if (threadFactory == null) {
// Fall back to non-concurrent transform if we can't get an App Engine thread factory (most // Fall back to non-concurrent transform if we only want 1 thread, or if we can't get an App
// likely caused by hitting this code from a command-line tool). Default Java system threads // Engine thread factory (most likely caused by hitting this code from a command-line tool).
// are not compatible with code that needs to interact with App Engine (such as Objectify), // Default Java system threads are not compatible with code that needs to interact with App
// which we often have in funk when calling Concurrent.transform(). // Engine (such as Objectify), which we often have in funk when calling
// For more info see: http://stackoverflow.com/questions/15976406 // Concurrent.transform(). For more info see: http://stackoverflow.com/questions/15976406
return items.stream().map(funk).collect(toImmutableList()); return items.stream().map(funk).collect(toImmutableList());
} }
ExecutorService executor = newFixedThreadPool(threadCount, threadFactory); ExecutorService executor = newFixedThreadPool(threadCount, threadFactory);

View file

@ -23,6 +23,7 @@ import static google.registry.testing.DatastoreHelper.newHostResource;
import static google.registry.testing.DatastoreHelper.persistResource; import static google.registry.testing.DatastoreHelper.persistResource;
import static org.mockito.Matchers.anyString; import static org.mockito.Matchers.anyString;
import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy; import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times; import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
@ -77,10 +78,8 @@ public class CloudDnsWriterTest {
@Mock private Dns dnsConnection; @Mock private Dns dnsConnection;
@Mock private Dns.ResourceRecordSets resourceRecordSets; @Mock private Dns.ResourceRecordSets resourceRecordSets;
@Mock private Dns.ResourceRecordSets.List listResourceRecordSetsRequest;
@Mock private Dns.Changes changes; @Mock private Dns.Changes changes;
@Mock private Dns.Changes.Create createChangeRequest; @Mock private Dns.Changes.Create createChangeRequest;
@Captor ArgumentCaptor<String> recordNameCaptor;
@Captor ArgumentCaptor<String> zoneNameCaptor; @Captor ArgumentCaptor<String> zoneNameCaptor;
@Captor ArgumentCaptor<Change> changeCaptor; @Captor ArgumentCaptor<Change> changeCaptor;
private CloudDnsWriter writer; private CloudDnsWriter writer;
@ -90,28 +89,15 @@ public class CloudDnsWriterTest {
@Rule public final AppEngineRule appEngine = AppEngineRule.builder().withDatastore().build(); @Rule public final AppEngineRule appEngine = AppEngineRule.builder().withDatastore().build();
@Before /*
public void setUp() throws Exception { * Because of multi-threading in the CloudDnsWriter, we need to return a different instance of
createTld("tld"); * List for every request, with its own ArgumentCaptor. Otherwise, we can't separate the arguments
writer = * of the various Lists
new CloudDnsWriter( */
dnsConnection, private Dns.ResourceRecordSets.List newListResourceRecordSetsRequestMock() throws Exception {
"projectId", Dns.ResourceRecordSets.List listResourceRecordSetsRequest =
"triple.secret.tld", // used by testInvalidZoneNames() mock(Dns.ResourceRecordSets.List.class);
DEFAULT_A_TTL, ArgumentCaptor<String> recordNameCaptor = ArgumentCaptor.forClass(String.class);
DEFAULT_NS_TTL,
DEFAULT_DS_TTL,
RateLimiter.create(20),
new SystemClock(),
new Retrier(new SystemSleeper(), 5));
// Create an empty zone.
stubZone = ImmutableSet.of();
when(dnsConnection.changes()).thenReturn(changes);
when(dnsConnection.resourceRecordSets()).thenReturn(resourceRecordSets);
when(resourceRecordSets.list(anyString(), anyString()))
.thenReturn(listResourceRecordSetsRequest);
when(listResourceRecordSetsRequest.setName(recordNameCaptor.capture())) when(listResourceRecordSetsRequest.setName(recordNameCaptor.capture()))
.thenReturn(listResourceRecordSetsRequest); .thenReturn(listResourceRecordSetsRequest);
// Return records from our stub zone when a request to list the records is executed // Return records from our stub zone when a request to list the records is executed
@ -126,7 +112,34 @@ public class CloudDnsWriterTest {
rs -> rs ->
rs != null && rs.getName().equals(recordNameCaptor.getValue())) rs != null && rs.getName().equals(recordNameCaptor.getValue()))
.collect(toImmutableList()))); .collect(toImmutableList())));
return listResourceRecordSetsRequest;
}
@Before
public void setUp() throws Exception {
createTld("tld");
writer =
new CloudDnsWriter(
dnsConnection,
"projectId",
"triple.secret.tld", // used by testInvalidZoneNames()
DEFAULT_A_TTL,
DEFAULT_NS_TTL,
DEFAULT_DS_TTL,
RateLimiter.create(20),
10, // max num threads
new SystemClock(),
new Retrier(new SystemSleeper(), 5));
// Create an empty zone.
stubZone = ImmutableSet.of();
when(dnsConnection.changes()).thenReturn(changes);
when(dnsConnection.resourceRecordSets()).thenReturn(resourceRecordSets);
when(resourceRecordSets.list(anyString(), anyString()))
.thenAnswer(
invocationOnMock -> newListResourceRecordSetsRequestMock());
when(changes.create(anyString(), zoneNameCaptor.capture(), changeCaptor.capture())) when(changes.create(anyString(), zoneNameCaptor.capture(), changeCaptor.capture()))
.thenReturn(createChangeRequest); .thenReturn(createChangeRequest);
// Change our stub zone when a request to change the records is executed // Change our stub zone when a request to change the records is executed