diff --git a/core/src/main/java/google/registry/model/registry/Registries.java b/core/src/main/java/google/registry/model/registry/Registries.java index 448069803..e6524f83b 100644 --- a/core/src/main/java/google/registry/model/registry/Registries.java +++ b/core/src/main/java/google/registry/model/registry/Registries.java @@ -19,6 +19,7 @@ import static com.google.common.base.Predicates.equalTo; import static com.google.common.base.Predicates.in; import static com.google.common.base.Predicates.not; import static com.google.common.base.Strings.emptyToNull; +import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static com.google.common.collect.Maps.filterValues; import static google.registry.model.CacheUtils.memoizeWithShortExpiration; @@ -31,10 +32,12 @@ import com.google.common.base.Joiner; import com.google.common.base.Supplier; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Maps; import com.google.common.collect.Streams; import com.google.common.net.InternetDomainName; import com.googlecode.objectify.Key; import google.registry.model.registry.Registry.TldType; +import java.util.Map; import java.util.Optional; /** Utilities for finding and listing {@link Registry} entities. */ @@ -54,16 +57,21 @@ public final class Registries { private static Supplier> createFreshCache() { return memoizeWithShortExpiration( () -> - tm() - .doTransactionless( + tm().doTransactionless( () -> { - ImmutableMap.Builder builder = - new ImmutableMap.Builder<>(); - for (Registry registry : - ofy().load().type(Registry.class).ancestor(getCrossTldKey())) { - builder.put(registry.getTldStr(), registry.getTldType()); - } - return builder.build(); + ImmutableSet tlds = + ofy() + .load() + .type(Registry.class) + .ancestor(getCrossTldKey()) + .keys() + .list() + .stream() + .map(Key::getName) + .collect(toImmutableSet()); + return Registry.getAll(tlds).stream() + .map(e -> Maps.immutableEntry(e.getTldStr(), e.getTldType())) + .collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue)); })); } @@ -83,11 +91,7 @@ public final class Registries { /** Returns the Registry entities themselves of the given type loaded fresh from Datastore. */ public static ImmutableSet getTldEntitiesOfType(TldType type) { - ImmutableSet> keys = - filterValues(cache.get(), equalTo(type)).keySet().stream() - .map(tld -> Key.create(getCrossTldKey(), Registry.class, tld)) - .collect(toImmutableSet()); - return ImmutableSet.copyOf(tm().doTransactionless(() -> ofy().load().keys(keys).values())); + return Registry.getAll(filterValues(cache.get(), equalTo(type)).keySet()); } /** Pass-through check that the specified TLD exists, otherwise throw an IAE. */ diff --git a/core/src/main/java/google/registry/model/registry/Registry.java b/core/src/main/java/google/registry/model/registry/Registry.java index e4501ac2f..3f818ad79 100644 --- a/core/src/main/java/google/registry/model/registry/Registry.java +++ b/core/src/main/java/google/registry/model/registry/Registry.java @@ -18,6 +18,9 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Predicates.equalTo; import static com.google.common.base.Predicates.not; +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static com.google.common.collect.Maps.toMap; import static google.registry.config.RegistryConfig.getSingletonCacheRefreshDuration; import static google.registry.model.common.EntityGroupRoot.getCrossTldKey; import static google.registry.model.ofy.ObjectifyService.ofy; @@ -30,9 +33,11 @@ import static java.util.concurrent.TimeUnit.MILLISECONDS; import static org.joda.money.CurrencyUnit.USD; import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Joiner; import com.google.common.cache.CacheBuilder; import com.google.common.cache.CacheLoader; import com.google.common.cache.LoadingCache; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.ImmutableSortedMap; import com.google.common.collect.Iterables; @@ -58,8 +63,10 @@ import google.registry.model.domain.fee.Fee; import google.registry.model.registry.label.PremiumList; import google.registry.model.registry.label.ReservedList; import google.registry.util.Idn; +import java.util.Map; import java.util.Optional; import java.util.Set; +import java.util.concurrent.ExecutionException; import java.util.function.Predicate; import java.util.regex.Pattern; import javax.annotation.Nullable; @@ -201,6 +208,25 @@ public class Registry extends ImmutableObject implements Buildable { return registry; } + /** Returns the registry entities for the given TLD strings, throwing if any don't exist. */ + static ImmutableSet getAll(Set tlds) { + try { + ImmutableMap> registries = CACHE.getAll(tlds); + ImmutableSet missingRegistries = + registries.entrySet().stream() + .filter(e -> !e.getValue().isPresent()) + .map(Map.Entry::getKey) + .collect(toImmutableSet()); + if (missingRegistries.isEmpty()) { + return registries.values().stream().map(Optional::get).collect(toImmutableSet()); + } else { + throw new RegistryNotFoundException(missingRegistries); + } + } catch (ExecutionException e) { + throw new RuntimeException("Unexpected error retrieving TLDs " + tlds, e); + } + } + /** * Invalidates the cache entry. * @@ -220,15 +246,30 @@ public class Registry extends ImmutableObject implements Buildable { new CacheLoader>() { @Override public Optional load(final String tld) { - // Enter a transactionless context briefly; we don't want to enroll every TLD in a - // transaction that might be wrapping this call. + // Enter a transaction-less context briefly; we don't want to enroll every TLD in + // a transaction that might be wrapping this call. return Optional.ofNullable( - tm() - .doTransactionless( - () -> ofy() - .load() - .key(Key.create(getCrossTldKey(), Registry.class, tld)) - .now())); + tm().doTransactionless( + () -> + ofy() + .load() + .key(Key.create(getCrossTldKey(), Registry.class, tld)) + .now())); + } + + @Override + public Map> loadAll(Iterable tlds) { + ImmutableMap> keysMap = + toMap( + ImmutableSet.copyOf(tlds), + tld -> Key.create(getCrossTldKey(), Registry.class, tld)); + Map, Registry> entities = + tm().doTransactionless(() -> ofy().load().keys(keysMap.values())); + return keysMap.entrySet().stream() + .collect( + toImmutableMap( + Map.Entry::getKey, + e -> Optional.ofNullable(entities.getOrDefault(e.getValue(), null)))); } }); @@ -883,10 +924,14 @@ public class Registry extends ImmutableObject implements Buildable { } } - /** Exception to throw when no Registry is found for a given tld. */ + /** Exception to throw when no Registry entity is found for given TLD string(s). */ public static class RegistryNotFoundException extends RuntimeException { + RegistryNotFoundException(ImmutableSet tlds) { + super("No registry object(s) found for " + Joiner.on(", ").join(tlds)); + } + RegistryNotFoundException(String tld) { - super("No registry object found for " + tld); + this(ImmutableSet.of(tld)); } } } diff --git a/core/src/test/java/google/registry/export/ExportReservedTermsActionTest.java b/core/src/test/java/google/registry/export/ExportReservedTermsActionTest.java index a4f93da24..d1644f1d8 100644 --- a/core/src/test/java/google/registry/export/ExportReservedTermsActionTest.java +++ b/core/src/test/java/google/registry/export/ExportReservedTermsActionTest.java @@ -48,10 +48,7 @@ import org.junit.runners.JUnit4; @RunWith(JUnit4.class) public class ExportReservedTermsActionTest { - @Rule - public final AppEngineRule appEngine = AppEngineRule.builder() - .withDatastore() - .build(); + @Rule public final AppEngineRule appEngine = AppEngineRule.builder().withDatastore().build(); private final DriveConnection driveConnection = mock(DriveConnection.class); private final Response response = mock(Response.class); @@ -133,6 +130,6 @@ public class ExportReservedTermsActionTest { assertThat(thrown) .hasCauseThat() .hasMessageThat() - .isEqualTo("No registry object found for fakeTld"); + .isEqualTo("No registry object(s) found for fakeTld"); } } diff --git a/core/src/test/java/google/registry/model/registry/RegistryTest.java b/core/src/test/java/google/registry/model/registry/RegistryTest.java index a7712f86d..be2c2bf09 100644 --- a/core/src/test/java/google/registry/model/registry/RegistryTest.java +++ b/core/src/test/java/google/registry/model/registry/RegistryTest.java @@ -51,6 +51,7 @@ import org.junit.Test; /** Unit tests for {@link Registry}. */ public class RegistryTest extends EntityTestCase { + Registry registry; @Before @@ -146,6 +147,19 @@ public class RegistryTest extends EntityTestCase { assertThat(registry.getReservedLists()).isEmpty(); } + @Test + public void testGetAll() { + createTld("foo"); + assertThat(Registry.getAll(ImmutableSet.of("foo", "tld"))) + .containsExactlyElementsIn( + ofy() + .load() + .keys( + Key.create(getCrossTldKey(), Registry.class, "foo"), + Key.create(getCrossTldKey(), Registry.class, "tld")) + .values()); + } + @Test public void testSetReservedLists() { ReservedList rl5 = persistReservedList(