diff --git a/core/src/main/java/google/registry/model/ofy/DatastoreTransactionManager.java b/core/src/main/java/google/registry/model/ofy/DatastoreTransactionManager.java index a06ce5561..dbea67f74 100644 --- a/core/src/main/java/google/registry/model/ofy/DatastoreTransactionManager.java +++ b/core/src/main/java/google/registry/model/ofy/DatastoreTransactionManager.java @@ -413,5 +413,10 @@ public class DatastoreTransactionManager implements TransactionManager { public Stream stream() { return Streams.stream(buildQuery()); } + + @Override + public long count() { + return buildQuery().count(); + } } } diff --git a/core/src/main/java/google/registry/persistence/transaction/CriteriaQueryBuilder.java b/core/src/main/java/google/registry/persistence/transaction/CriteriaQueryBuilder.java index c4a206103..9befbfab5 100644 --- a/core/src/main/java/google/registry/persistence/transaction/CriteriaQueryBuilder.java +++ b/core/src/main/java/google/registry/persistence/transaction/CriteriaQueryBuilder.java @@ -41,11 +41,11 @@ public class CriteriaQueryBuilder { } private final CriteriaQuery query; - private final Root root; + private final Root root; private final ImmutableList.Builder predicates = new ImmutableList.Builder<>(); private final ImmutableList.Builder orders = new ImmutableList.Builder<>(); - private CriteriaQueryBuilder(CriteriaQuery query, Root root) { + private CriteriaQueryBuilder(CriteriaQuery query, Root root) { this.query = query; this.root = root; } @@ -106,4 +106,13 @@ public class CriteriaQueryBuilder { query = query.select(root); return new CriteriaQueryBuilder<>(query, root); } + + /** Creates a "count" query for the table for the class. */ + public static CriteriaQueryBuilder createCount(EntityManager em, Class clazz) { + CriteriaBuilder builder = em.getCriteriaBuilder(); + CriteriaQuery query = builder.createQuery(Long.class); + Root root = query.from(clazz); + query = query.select(builder.count(root)); + return new CriteriaQueryBuilder<>(query, root); + } } diff --git a/core/src/main/java/google/registry/persistence/transaction/JpaTransactionManagerImpl.java b/core/src/main/java/google/registry/persistence/transaction/JpaTransactionManagerImpl.java index 7c14557fb..b2e492da5 100644 --- a/core/src/main/java/google/registry/persistence/transaction/JpaTransactionManagerImpl.java +++ b/core/src/main/java/google/registry/persistence/transaction/JpaTransactionManagerImpl.java @@ -700,7 +700,10 @@ public class JpaTransactionManagerImpl implements JpaTransactionManager { private TypedQuery buildQuery() { CriteriaQueryBuilder queryBuilder = CriteriaQueryBuilder.create(em, entityClass); + return addCriteria(queryBuilder); + } + private TypedQuery addCriteria(CriteriaQueryBuilder queryBuilder) { for (WhereClause pred : predicates) { pred.addToCriteriaQueryBuilder(queryBuilder); } @@ -727,5 +730,11 @@ public class JpaTransactionManagerImpl implements JpaTransactionManager { public Stream stream() { return buildQuery().getResultStream(); } + + @Override + public long count() { + CriteriaQueryBuilder queryBuilder = CriteriaQueryBuilder.createCount(em, entityClass); + return addCriteria(queryBuilder).getSingleResult(); + } } } diff --git a/core/src/main/java/google/registry/persistence/transaction/QueryComposer.java b/core/src/main/java/google/registry/persistence/transaction/QueryComposer.java index 2785bb2f1..f53385f01 100644 --- a/core/src/main/java/google/registry/persistence/transaction/QueryComposer.java +++ b/core/src/main/java/google/registry/persistence/transaction/QueryComposer.java @@ -88,6 +88,9 @@ public abstract class QueryComposer { /** Returns the results of the query as a stream. */ public abstract Stream stream(); + /** Returns the number of results of the query. */ + public abstract long count(); + // We have to wrap the CriteriaQueryBuilder predicate factories in our own functions because at // the point where we pass them to the Comparator constructor, the compiler can't determine which // of the overloads to use since there is no "value" object for context. diff --git a/core/src/main/java/google/registry/tools/CountDomainsCommand.java b/core/src/main/java/google/registry/tools/CountDomainsCommand.java index 0a9c46b01..d7274a534 100644 --- a/core/src/main/java/google/registry/tools/CountDomainsCommand.java +++ b/core/src/main/java/google/registry/tools/CountDomainsCommand.java @@ -14,12 +14,13 @@ package google.registry.tools; -import static google.registry.model.ofy.ObjectifyService.ofy; import static google.registry.model.registry.Registries.assertTldsExist; +import static google.registry.persistence.transaction.QueryComposer.Comparator; +import static google.registry.persistence.transaction.TransactionManagerFactory.tm; +import static google.registry.persistence.transaction.TransactionManagerUtil.transactIfJpaTm; import com.beust.jcommander.Parameter; import com.beust.jcommander.Parameters; -import com.google.common.collect.Iterables; import google.registry.model.domain.DomainBase; import google.registry.util.Clock; import java.util.List; @@ -45,14 +46,12 @@ final class CountDomainsCommand implements CommandWithRemoteApi { .forEach(tld -> System.out.printf("%s,%d\n", tld, getCountForTld(tld, now))); } - private int getCountForTld(String tld, DateTime now) { - return Iterables.size( - ofy() - .load() - .type(DomainBase.class) - .filter("tld", tld) - .filter("deletionTime >", now) - .chunkAll() - .keys()); + private long getCountForTld(String tld, DateTime now) { + return transactIfJpaTm( + () -> + tm().createQueryComposer(DomainBase.class) + .where("tld", Comparator.EQ, tld) + .where("deletionTime", Comparator.GT, now) + .count()); } } diff --git a/core/src/test/java/google/registry/persistence/transaction/QueryComposerTest.java b/core/src/test/java/google/registry/persistence/transaction/QueryComposerTest.java index 6c9ff31da..70935e465 100644 --- a/core/src/test/java/google/registry/persistence/transaction/QueryComposerTest.java +++ b/core/src/test/java/google/registry/persistence/transaction/QueryComposerTest.java @@ -111,6 +111,17 @@ public class QueryComposerTest { .isEqualTo(alpha); } + @TestOfyAndSql + public void testCount() { + assertThat( + transactIfJpaTm( + () -> + tm().createQueryComposer(TestEntity.class) + .where("name", Comparator.GTE, "bravo") + .count())) + .isEqualTo(2L); + } + @TestOfyAndSql public void testGetSingleResult() { assertThat( diff --git a/core/src/test/java/google/registry/tools/CountDomainsCommandTest.java b/core/src/test/java/google/registry/tools/CountDomainsCommandTest.java index b1fc93be1..4020ba679 100644 --- a/core/src/test/java/google/registry/tools/CountDomainsCommandTest.java +++ b/core/src/test/java/google/registry/tools/CountDomainsCommandTest.java @@ -19,12 +19,14 @@ import static google.registry.testing.DatabaseHelper.persistActiveDomain; import static google.registry.testing.DatabaseHelper.persistDeletedDomain; import google.registry.model.ofy.Ofy; +import google.registry.testing.DualDatabaseTest; import google.registry.testing.InjectExtension; +import google.registry.testing.TestOfyAndSql; import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; /** Unit tests for {@link CountDomainsCommand}. */ +@DualDatabaseTest public class CountDomainsCommandTest extends CommandTestCase { @RegisterExtension public final InjectExtension inject = new InjectExtension(); @@ -36,7 +38,7 @@ public class CountDomainsCommandTest extends CommandTestCase