diff --git a/core/src/main/java/google/registry/persistence/transaction/TransactionManagerFactory.java b/core/src/main/java/google/registry/persistence/transaction/TransactionManagerFactory.java index 83204ee72..d03548a73 100644 --- a/core/src/main/java/google/registry/persistence/transaction/TransactionManagerFactory.java +++ b/core/src/main/java/google/registry/persistence/transaction/TransactionManagerFactory.java @@ -14,12 +14,17 @@ package google.registry.persistence.transaction; +import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Preconditions.checkState; + import com.google.appengine.api.utils.SystemProperty; import com.google.appengine.api.utils.SystemProperty.Environment.Value; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Suppliers; +import google.registry.config.RegistryEnvironment; import google.registry.model.ofy.DatastoreTransactionManager; import google.registry.persistence.DaggerPersistenceComponent; +import google.registry.tools.RegistryToolEnvironment; import google.registry.util.NonFinalForTesting; import java.util.function.Supplier; @@ -82,8 +87,13 @@ public class TransactionManagerFactory { } /** Sets the return of {@link #jpaTm()} to the given instance of {@link JpaTransactionManager}. */ - public static void setJpaTm(JpaTransactionManager newJpaTm) { - jpaTm = Suppliers.ofInstance(newJpaTm); + public static void setJpaTm(Supplier jpaTmSupplier) { + checkNotNull(jpaTmSupplier, "jpaTmSupplier"); + checkState( + RegistryEnvironment.get().equals(RegistryEnvironment.UNITTEST) + || RegistryToolEnvironment.get() != null, + "setJpamTm() should only be called by tools and tests."); + jpaTm = jpaTmSupplier; } /** Sets the return of {@link #tm()} to the given instance of {@link TransactionManager}. */ diff --git a/core/src/main/java/google/registry/tools/RegistryCli.java b/core/src/main/java/google/registry/tools/RegistryCli.java index 2f8ee9bce..48341bed8 100644 --- a/core/src/main/java/google/registry/tools/RegistryCli.java +++ b/core/src/main/java/google/registry/tools/RegistryCli.java @@ -240,7 +240,7 @@ final class RegistryCli implements AutoCloseable, CommandRunner { // Enable Cloud SQL for command that needs remote API as they will very likely use // Cloud SQL after the database migration. Note that the DB password is stored in Datastore // and it is already initialized above. - TransactionManagerFactory.setJpaTm(component.nomulusToolJpaTransactionManager()); + TransactionManagerFactory.setJpaTm(() -> component.nomulusToolJpaTransactionManager().get()); } command.run(); diff --git a/core/src/main/java/google/registry/tools/RegistryToolComponent.java b/core/src/main/java/google/registry/tools/RegistryToolComponent.java index e2b18cfb9..120edcbfc 100644 --- a/core/src/main/java/google/registry/tools/RegistryToolComponent.java +++ b/core/src/main/java/google/registry/tools/RegistryToolComponent.java @@ -16,6 +16,7 @@ package google.registry.tools; import dagger.BindsInstance; import dagger.Component; +import dagger.Lazy; import google.registry.batch.BatchModule; import google.registry.bigquery.BigqueryModule; import google.registry.config.CredentialModule.LocalCredentialJson; @@ -124,7 +125,7 @@ interface RegistryToolComponent { String googleCredentialJson(); @NomulusToolJpaTm - JpaTransactionManager nomulusToolJpaTransactionManager(); + Lazy nomulusToolJpaTransactionManager(); @Component.Builder interface Builder { diff --git a/core/src/test/java/google/registry/persistence/transaction/JpaTransactionManagerRule.java b/core/src/test/java/google/registry/persistence/transaction/JpaTransactionManagerRule.java index fc06b53ee..c9972a5e7 100644 --- a/core/src/test/java/google/registry/persistence/transaction/JpaTransactionManagerRule.java +++ b/core/src/test/java/google/registry/persistence/transaction/JpaTransactionManagerRule.java @@ -21,6 +21,7 @@ import static org.testcontainers.containers.PostgreSQLContainer.POSTGRESQL_PORT; import com.google.common.base.Charsets; import com.google.common.base.Joiner; +import com.google.common.base.Suppliers; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Maps; @@ -199,12 +200,12 @@ abstract class JpaTransactionManagerRule extends ExternalResource { } JpaTransactionManagerImpl txnManager = new JpaTransactionManagerImpl(emf, clock); cachedTm = TransactionManagerFactory.jpaTm(); - TransactionManagerFactory.setJpaTm(txnManager); + TransactionManagerFactory.setJpaTm(Suppliers.ofInstance(txnManager)); } @Override public void after() { - TransactionManagerFactory.setJpaTm(cachedTm); + TransactionManagerFactory.setJpaTm(Suppliers.ofInstance(cachedTm)); cachedTm = null; }