diff --git a/core/build.gradle b/core/build.gradle index cc2f7bc29..dabd48186 100644 --- a/core/build.gradle +++ b/core/build.gradle @@ -880,6 +880,9 @@ task standardTest(type: FilteringTest) { // forkEvery 1 // Sets the maximum number of test executors that may exist at the same time. + // Also, Gradle executes tests in 1 thread and some of our test infrastructures + // depend on that, e.g. DualDatabaseTestInvocationContextProvider injects + // different implementation of TransactionManager into TransactionManagerFactory. maxParallelForks 5 systemProperty 'test.projectRoot', rootProject.projectRootDir diff --git a/core/src/main/java/google/registry/persistence/transaction/TransactionManagerFactory.java b/core/src/main/java/google/registry/persistence/transaction/TransactionManagerFactory.java index 37ce21e0c..83204ee72 100644 --- a/core/src/main/java/google/registry/persistence/transaction/TransactionManagerFactory.java +++ b/core/src/main/java/google/registry/persistence/transaction/TransactionManagerFactory.java @@ -16,6 +16,7 @@ package google.registry.persistence.transaction; import com.google.appengine.api.utils.SystemProperty; import com.google.appengine.api.utils.SystemProperty.Environment.Value; +import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Suppliers; import google.registry.model.ofy.DatastoreTransactionManager; import google.registry.persistence.DaggerPersistenceComponent; @@ -26,7 +27,9 @@ import java.util.function.Supplier; // TODO: Rename this to PersistenceFactory and move to persistence package. public class TransactionManagerFactory { - private static final TransactionManager TM = createTransactionManager(); + private static final DatastoreTransactionManager ofyTm = createTransactionManager(); + + @NonFinalForTesting private static TransactionManager tm = ofyTm; /** Supplier for jpaTm so that it is initialized only once, upon first usage. */ @NonFinalForTesting @@ -45,10 +48,7 @@ public class TransactionManagerFactory { } } - private static TransactionManager createTransactionManager() { - // TODO: Determine how to provision TransactionManager after the dual-write. During the - // dual-write transitional phase, we need the TransactionManager for both Datastore and Cloud - // SQL, and this method returns the one for Datastore. + private static DatastoreTransactionManager createTransactionManager() { return new DatastoreTransactionManager(null); } @@ -67,7 +67,7 @@ public class TransactionManagerFactory { /** Returns {@link TransactionManager} instance. */ public static TransactionManager tm() { - return TM; + return tm; } /** Returns {@link JpaTransactionManager} instance. */ @@ -75,8 +75,20 @@ public class TransactionManagerFactory { return jpaTm.get(); } + /** Returns {@link DatastoreTransactionManager} instance. */ + @VisibleForTesting + public static DatastoreTransactionManager ofyTm() { + return ofyTm; + } + /** Sets the return of {@link #jpaTm()} to the given instance of {@link JpaTransactionManager}. */ public static void setJpaTm(JpaTransactionManager newJpaTm) { jpaTm = Suppliers.ofInstance(newJpaTm); } + + /** Sets the return of {@link #tm()} to the given instance of {@link TransactionManager}. */ + @VisibleForTesting + public static void setTm(TransactionManager newTm) { + tm = newTm; + } } diff --git a/core/src/test/java/google/registry/persistence/transaction/JpaTransactionManagerImplTest.java b/core/src/test/java/google/registry/persistence/transaction/JpaTransactionManagerImplTest.java index fef501ed1..a1f03dd30 100644 --- a/core/src/test/java/google/registry/persistence/transaction/JpaTransactionManagerImplTest.java +++ b/core/src/test/java/google/registry/persistence/transaction/JpaTransactionManagerImplTest.java @@ -37,7 +37,13 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; -/** Unit tests for {@link JpaTransactionManagerImpl}. */ +/** + * Unit tests for SQL only APIs defined in {@link JpaTransactionManagerImpl}. Note that the tests + * for common APIs in {@link TransactionManager} are added in {@link TransactionManagerTest}. + * + *

TODO(shicong): Remove duplicate tests that covered by TransactionManagerTest by refactoring + * the test schema. + */ @RunWith(JUnit4.class) public class JpaTransactionManagerImplTest { @@ -62,29 +68,6 @@ public class JpaTransactionManagerImplTest { .withEntityClass(TestEntity.class, TestCompoundIdEntity.class) .buildUnitTestRule(); - @Test - public void inTransaction_returnsCorrespondingResult() { - assertThat(jpaTm().inTransaction()).isFalse(); - jpaTm().transact(() -> assertThat(jpaTm().inTransaction()).isTrue()); - assertThat(jpaTm().inTransaction()).isFalse(); - } - - @Test - public void assertInTransaction_throwsExceptionWhenNotInTransaction() { - assertThrows(IllegalStateException.class, () -> jpaTm().assertInTransaction()); - jpaTm().transact(() -> jpaTm().assertInTransaction()); - assertThrows(IllegalStateException.class, () -> jpaTm().assertInTransaction()); - } - - @Test - public void getTransactionTime_throwsExceptionWhenNotInTransaction() { - FakeClock txnClock = fakeClock; - txnClock.advanceOneMilli(); - assertThrows(IllegalStateException.class, () -> jpaTm().getTransactionTime()); - jpaTm().transact(() -> assertThat(jpaTm().getTransactionTime()).isEqualTo(txnClock.nowUtc())); - assertThrows(IllegalStateException.class, () -> jpaTm().getTransactionTime()); - } - @Test public void transact_succeeds() { assertPersonEmpty(); diff --git a/core/src/test/java/google/registry/model/ofy/DatastoreTransactionManagerTest.java b/core/src/test/java/google/registry/persistence/transaction/TransactionManagerTest.java similarity index 88% rename from core/src/test/java/google/registry/model/ofy/DatastoreTransactionManagerTest.java rename to core/src/test/java/google/registry/persistence/transaction/TransactionManagerTest.java index 62755b35e..e9d83c9e4 100644 --- a/core/src/test/java/google/registry/model/ofy/DatastoreTransactionManagerTest.java +++ b/core/src/test/java/google/registry/persistence/transaction/TransactionManagerTest.java @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package google.registry.model.ofy; +package google.registry.persistence.transaction; import static com.google.common.truth.Truth.assertThat; import static google.registry.persistence.transaction.TransactionManagerFactory.tm; @@ -23,16 +23,24 @@ import com.googlecode.objectify.Key; import com.googlecode.objectify.annotation.Entity; import com.googlecode.objectify.annotation.Id; import google.registry.model.ImmutableObject; +import google.registry.model.ofy.DatastoreTransactionManager; +import google.registry.model.ofy.Ofy; import google.registry.persistence.VKey; import google.registry.testing.AppEngineRule; +import google.registry.testing.DualDatabaseTest; import google.registry.testing.FakeClock; import google.registry.testing.InjectRule; import java.util.NoSuchElementException; -import org.junit.jupiter.api.Test; import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.TestTemplate; import org.junit.jupiter.api.extension.RegisterExtension; -public class DatastoreTransactionManagerTest { +/** + * Unit tests for common APIs in {@link DatastoreTransactionManager} and {@link + * JpaTransactionManagerImpl}. + */ +@DualDatabaseTest +public class TransactionManagerTest { private final FakeClock fakeClock = new FakeClock(); @@ -51,38 +59,31 @@ public class DatastoreTransactionManagerTest { .withClock(fakeClock) .withDatastoreAndCloudSql() .withOfyTestEntities(TestEntity.class) + .withJpaUnitTestEntities(TestEntity.class) .build(); - public DatastoreTransactionManagerTest() {} + public TransactionManagerTest() {} @BeforeEach public void setUp() { inject.setStaticField(Ofy.class, "clock", fakeClock); } - // TODO(mmuller): The tests below are just copy-pasted from JpaTransactionManagerImplTest - // (excluding the CompoundId tests and native query tests, which are not relevant to datastore, - // and the test methods using "count" which doesn't work for datastore, as well as tests for - // functionality that doesn't exist in datastore, like failures based on whether a newly saved or - // updated object exists or not). We need to merge these into a single test suite, but first we - // should move the JpaUnitTestRule functionality into AppEngineTest and migrate the whole thing - // to junit5. - - @Test + @TestTemplate public void inTransaction_returnsCorrespondingResult() { assertThat(tm().inTransaction()).isFalse(); tm().transact(() -> assertThat(tm().inTransaction()).isTrue()); assertThat(tm().inTransaction()).isFalse(); } - @Test + @TestTemplate public void assertInTransaction_throwsExceptionWhenNotInTransaction() { assertThrows(IllegalStateException.class, () -> tm().assertInTransaction()); tm().transact(() -> tm().assertInTransaction()); assertThrows(IllegalStateException.class, () -> tm().assertInTransaction()); } - @Test + @TestTemplate public void getTransactionTime_throwsExceptionWhenNotInTransaction() { FakeClock txnClock = fakeClock; txnClock.advanceOneMilli(); @@ -91,7 +92,7 @@ public class DatastoreTransactionManagerTest { assertThrows(IllegalStateException.class, () -> tm().getTransactionTime()); } - @Test + @TestTemplate public void transact_hasNoEffectWithPartialSuccess() { assertThat(tm().transact(() -> tm().checkExists(theEntity))).isFalse(); assertThrows( @@ -106,7 +107,7 @@ public class DatastoreTransactionManagerTest { assertThat(tm().transact(() -> tm().checkExists(theEntity))).isFalse(); } - @Test + @TestTemplate public void transact_reusesExistingTransaction() { assertThat(tm().transact(() -> tm().checkExists(theEntity))).isFalse(); fakeClock.advanceOneMilli(); @@ -115,7 +116,7 @@ public class DatastoreTransactionManagerTest { assertThat(tm().transact(() -> tm().checkExists(theEntity))).isTrue(); } - @Test + @TestTemplate public void saveNew_succeeds() { assertThat(tm().transact(() -> tm().checkExists(theEntity))).isFalse(); fakeClock.advanceOneMilli(); @@ -126,7 +127,7 @@ public class DatastoreTransactionManagerTest { assertThat(tm().transact(() -> tm().load(theEntity.key()))).isEqualTo(theEntity); } - @Test + @TestTemplate public void saveAllNew_succeeds() { moreEntities.forEach( entity -> assertThat(tm().transact(() -> tm().checkExists(entity))).isFalse()); @@ -137,7 +138,7 @@ public class DatastoreTransactionManagerTest { entity -> assertThat(tm().transact(() -> tm().checkExists(entity))).isTrue()); } - @Test + @TestTemplate public void saveNewOrUpdate_persistsNewEntity() { assertThat(tm().transact(() -> tm().checkExists(theEntity))).isFalse(); fakeClock.advanceOneMilli(); @@ -148,7 +149,7 @@ public class DatastoreTransactionManagerTest { assertThat(tm().transact(() -> tm().load(theEntity.key()))).isEqualTo(theEntity); } - @Test + @TestTemplate public void saveNewOrUpdate_updatesExistingEntity() { fakeClock.advanceOneMilli(); tm().transact(() -> tm().saveNew(theEntity)); @@ -163,7 +164,7 @@ public class DatastoreTransactionManagerTest { assertThat(persisted.data).isEqualTo("bar"); } - @Test + @TestTemplate public void saveNewOrUpdateAll_succeeds() { moreEntities.forEach( entity -> assertThat(tm().transact(() -> tm().checkExists(entity))).isFalse()); @@ -174,13 +175,16 @@ public class DatastoreTransactionManagerTest { entity -> assertThat(tm().transact(() -> tm().checkExists(entity))).isTrue()); } - @Test + @TestTemplate public void update_succeeds() { fakeClock.advanceOneMilli(); tm().transact(() -> tm().saveNew(theEntity)); fakeClock.advanceOneMilli(); TestEntity persisted = - tm().transact(() -> tm().load(VKey.createOfy(TestEntity.class, Key.create(theEntity)))); + tm().transact( + () -> + tm().load( + VKey.create(TestEntity.class, theEntity.name, Key.create(theEntity)))); fakeClock.advanceOneMilli(); assertThat(persisted.data).isEqualTo("foo"); theEntity.data = "bar"; @@ -190,7 +194,7 @@ public class DatastoreTransactionManagerTest { assertThat(persisted.data).isEqualTo("bar"); } - @Test + @TestTemplate public void load_succeeds() { assertThat(tm().transact(() -> tm().checkExists(theEntity))).isFalse(); fakeClock.advanceOneMilli(); @@ -201,7 +205,7 @@ public class DatastoreTransactionManagerTest { assertThat(persisted.data).isEqualTo("foo"); } - @Test + @TestTemplate public void load_throwsOnMissingElement() { assertThat(tm().transact(() -> tm().checkExists(theEntity))).isFalse(); fakeClock.advanceOneMilli(); @@ -209,7 +213,7 @@ public class DatastoreTransactionManagerTest { NoSuchElementException.class, () -> tm().transact(() -> tm().load(theEntity.key()))); } - @Test + @TestTemplate public void maybeLoad_succeeds() { assertThat(tm().transact(() -> tm().checkExists(theEntity))).isFalse(); fakeClock.advanceOneMilli(); @@ -220,14 +224,14 @@ public class DatastoreTransactionManagerTest { assertThat(persisted.data).isEqualTo("foo"); } - @Test + @TestTemplate public void maybeLoad_nonExistentObject() { assertThat(tm().transact(() -> tm().checkExists(theEntity))).isFalse(); fakeClock.advanceOneMilli(); assertThat(tm().transact(() -> tm().maybeLoad(theEntity.key())).isPresent()).isFalse(); } - @Test + @TestTemplate public void delete_succeeds() { fakeClock.advanceOneMilli(); tm().transact(() -> tm().saveNew(theEntity)); @@ -239,7 +243,7 @@ public class DatastoreTransactionManagerTest { assertThat(tm().transact(() -> tm().checkExists(theEntity))).isFalse(); } - @Test + @TestTemplate public void delete_returnsZeroWhenNoEntity() { assertThat(tm().transact(() -> tm().checkExists(theEntity))).isFalse(); fakeClock.advanceOneMilli(); @@ -249,8 +253,9 @@ public class DatastoreTransactionManagerTest { } @Entity(name = "TestEntity") + @javax.persistence.Entity(name = "TestEntity") private static class TestEntity extends ImmutableObject { - @Id private String name; + @Id @javax.persistence.Id private String name; private String data; @@ -262,7 +267,7 @@ public class DatastoreTransactionManagerTest { } public VKey key() { - return VKey.createOfy(TestEntity.class, Key.create(this)); + return VKey.create(TestEntity.class, name, Key.create(this)); } } } diff --git a/core/src/test/java/google/registry/testing/AppEngineRule.java b/core/src/test/java/google/registry/testing/AppEngineRule.java index 96dfe7b35..203c62e92 100644 --- a/core/src/test/java/google/registry/testing/AppEngineRule.java +++ b/core/src/test/java/google/registry/testing/AppEngineRule.java @@ -17,6 +17,8 @@ package google.registry.testing; import static com.google.common.base.Preconditions.checkState; import static com.google.common.truth.Truth.assertWithMessage; import static google.registry.testing.DatastoreHelper.persistSimpleResources; +import static google.registry.testing.DualDatabaseTestInvocationContextProvider.injectTmForDualDatabaseTest; +import static google.registry.testing.DualDatabaseTestInvocationContextProvider.restoreTmAfterDualDatabaseTest; import static google.registry.util.ResourceUtils.readResourceUtf8; import static java.nio.charset.StandardCharsets.UTF_8; import static org.json.XML.toJSONObject; @@ -45,6 +47,7 @@ import google.registry.persistence.transaction.JpaTestRules; import google.registry.persistence.transaction.JpaTestRules.JpaIntegrationTestRule; import google.registry.persistence.transaction.JpaTestRules.JpaIntegrationWithCoverageExtension; import google.registry.persistence.transaction.JpaTestRules.JpaIntegrationWithCoverageRule; +import google.registry.persistence.transaction.JpaTestRules.JpaUnitTestRule; import google.registry.util.Clock; import java.io.ByteArrayInputStream; import java.io.File; @@ -118,8 +121,11 @@ public final class AppEngineRule extends ExternalResource */ JpaIntegrationWithCoverageExtension jpaIntegrationWithCoverageExtension = null; + JpaUnitTestRule jpaUnitTestRule; + private boolean withDatastoreAndCloudSql; private boolean enableJpaEntityCoverageCheck; + private boolean withJpaUnitTest; private boolean withLocalModules; private boolean withTaskQueue; private boolean withUserService; @@ -131,12 +137,14 @@ public final class AppEngineRule extends ExternalResource // Test Objectify entity classes to be used with this AppEngineRule instance. private ImmutableList> ofyTestEntities; + private ImmutableList> jpaTestEntities; /** Builder for {@link AppEngineRule}. */ public static class Builder { private AppEngineRule rule = new AppEngineRule(); - private ImmutableList.Builder> ofyTestEntities = new ImmutableList.Builder(); + private ImmutableList.Builder> ofyTestEntities = new ImmutableList.Builder<>(); + private ImmutableList.Builder> jpaTestEntities = new ImmutableList.Builder<>(); /** Turn on the Datastore service and the Cloud SQL service. */ public Builder withDatastoreAndCloudSql() { @@ -205,11 +213,24 @@ public final class AppEngineRule extends ExternalResource return this; } + public Builder withJpaUnitTestEntities(Class... entities) { + jpaTestEntities.add(entities); + rule.withJpaUnitTest = true; + return this; + } + public AppEngineRule build() { checkState( !rule.enableJpaEntityCoverageCheck || rule.withDatastoreAndCloudSql, "withJpaEntityCoverageCheck enabled without Cloud SQL"); + checkState( + !rule.withJpaUnitTest || rule.withDatastoreAndCloudSql, + "withJpaUnitTestEntities enabled without Cloud SQL"); + checkState( + !rule.withJpaUnitTest || !rule.enableJpaEntityCoverageCheck, + "withJpaUnitTestEntities cannot be set when enableJpaEntityCoverageCheck"); rule.ofyTestEntities = this.ofyTestEntities.build(); + rule.jpaTestEntities = this.jpaTestEntities.build(); return rule; } } @@ -328,11 +349,18 @@ public final class AppEngineRule extends ExternalResource if (enableJpaEntityCoverageCheck) { jpaIntegrationWithCoverageExtension = builder.buildIntegrationWithCoverageExtension(); jpaIntegrationWithCoverageExtension.beforeEach(context); + } else if (withJpaUnitTest) { + jpaUnitTestRule = + builder + .withEntityClass(jpaTestEntities.toArray(new Class[jpaTestEntities.size()])) + .buildUnitTestRule(); + jpaUnitTestRule.before(); } else { jpaIntegrationTestRule = builder.buildIntegrationTestRule(); jpaIntegrationTestRule.before(); } } + injectTmForDualDatabaseTest(context); } /** Called after each test method. JUnit 5 only. */ @@ -341,11 +369,14 @@ public final class AppEngineRule extends ExternalResource if (withDatastoreAndCloudSql) { if (enableJpaEntityCoverageCheck) { jpaIntegrationWithCoverageExtension.afterEach(context); + } else if (withJpaUnitTest) { + jpaUnitTestRule.after(); } else { jpaIntegrationTestRule.after(); } } after(); + restoreTmAfterDualDatabaseTest(context); } /** @@ -560,4 +591,8 @@ public final class AppEngineRule extends ExternalResource makeRegistrarContact2(), makeRegistrarContact3())); } + + boolean isWithDatastoreAndCloudSql() { + return withDatastoreAndCloudSql; + } } diff --git a/core/src/test/java/google/registry/testing/DualDatabaseTest.java b/core/src/test/java/google/registry/testing/DualDatabaseTest.java new file mode 100644 index 000000000..75badacca --- /dev/null +++ b/core/src/test/java/google/registry/testing/DualDatabaseTest.java @@ -0,0 +1,28 @@ +// Copyright 2020 The Nomulus Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package google.registry.testing; + +import static java.lang.annotation.ElementType.TYPE; +import static java.lang.annotation.RetentionPolicy.RUNTIME; + +import java.lang.annotation.Retention; +import java.lang.annotation.Target; +import org.junit.jupiter.api.extension.ExtendWith; + +/** Annotation to add {@link DualDatabaseTestInvocationContextProvider} for the annotated test. */ +@Target({TYPE}) +@Retention(RUNTIME) +@ExtendWith(DualDatabaseTestInvocationContextProvider.class) +public @interface DualDatabaseTest {} diff --git a/core/src/test/java/google/registry/testing/DualDatabaseTestInvocationContextProvider.java b/core/src/test/java/google/registry/testing/DualDatabaseTestInvocationContextProvider.java new file mode 100644 index 000000000..7a6b268d8 --- /dev/null +++ b/core/src/test/java/google/registry/testing/DualDatabaseTestInvocationContextProvider.java @@ -0,0 +1,125 @@ +// Copyright 2020 The Nomulus Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package google.registry.testing; + +import static com.google.common.collect.ImmutableList.toImmutableList; +import static google.registry.persistence.transaction.TransactionManagerFactory.tm; + +import com.google.common.collect.ImmutableList; +import google.registry.persistence.transaction.TransactionManager; +import google.registry.persistence.transaction.TransactionManagerFactory; +import java.lang.reflect.Field; +import java.util.List; +import java.util.function.Supplier; +import java.util.stream.Stream; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.extension.Extension; +import org.junit.jupiter.api.extension.ExtensionContext; +import org.junit.jupiter.api.extension.ExtensionContext.Namespace; +import org.junit.jupiter.api.extension.TestInstancePostProcessor; +import org.junit.jupiter.api.extension.TestTemplateInvocationContext; +import org.junit.jupiter.api.extension.TestTemplateInvocationContextProvider; + +/** + * Implementation of {@link TestTemplateInvocationContextProvider} to execute tests against + * different database. The test annotated with {@link TestTemplate} will be executed twice against + * Datastore and PostgresQL respectively. + */ +class DualDatabaseTestInvocationContextProvider implements TestTemplateInvocationContextProvider { + private static final Namespace NAMESPACE = + Namespace.create(DualDatabaseTestInvocationContextProvider.class); + private static final String INJECTED_TM_SUPPLIER_KEY = "injected_tm_supplier_key"; + private static final String ORIGINAL_TM_KEY = "original_tm_key"; + + @Override + public boolean supportsTestTemplate(ExtensionContext context) { + return true; + } + + @Override + public Stream provideTestTemplateInvocationContexts( + ExtensionContext context) { + return Stream.of( + createInvocationContext("Test Datastore", TransactionManagerFactory::ofyTm), + createInvocationContext("Test PostgreSQL", TransactionManagerFactory::jpaTm)); + } + + private TestTemplateInvocationContext createInvocationContext( + String name, Supplier tmSupplier) { + return new TestTemplateInvocationContext() { + @Override + public String getDisplayName(int invocationIndex) { + return name; + } + + @Override + public List getAdditionalExtensions() { + return ImmutableList.of(new DatabaseSwitchInvocationContext(tmSupplier)); + } + }; + } + + private static class DatabaseSwitchInvocationContext implements TestInstancePostProcessor { + + private Supplier tmSupplier; + + private DatabaseSwitchInvocationContext(Supplier tmSupplier) { + this.tmSupplier = tmSupplier; + } + + @Override + public void postProcessTestInstance(Object testInstance, ExtensionContext context) + throws Exception { + List appEngineRuleFields = + Stream.of(testInstance.getClass().getFields()) + .filter(field -> field.getType().isAssignableFrom(AppEngineRule.class)) + .collect(toImmutableList()); + if (appEngineRuleFields.size() != 1) { + throw new IllegalStateException( + "@DualDatabaseTest test must have only 1 AppEngineRule field"); + } + appEngineRuleFields.get(0).setAccessible(true); + AppEngineRule appEngineRule = (AppEngineRule) appEngineRuleFields.get(0).get(testInstance); + if (!appEngineRule.isWithDatastoreAndCloudSql()) { + throw new IllegalStateException( + "AppEngineRule in @DualDatabaseTest test must set withDatastoreAndCloudSql()"); + } + context.getStore(NAMESPACE).put(INJECTED_TM_SUPPLIER_KEY, tmSupplier); + } + } + + static void injectTmForDualDatabaseTest(ExtensionContext context) { + if (isDualDatabaseTest(context)) { + context.getStore(NAMESPACE).put(ORIGINAL_TM_KEY, tm()); + Supplier tmSupplier = + (Supplier) + context.getStore(NAMESPACE).get(INJECTED_TM_SUPPLIER_KEY); + TransactionManagerFactory.setTm(tmSupplier.get()); + } + } + + static void restoreTmAfterDualDatabaseTest(ExtensionContext context) { + if (isDualDatabaseTest(context)) { + TransactionManager original = + (TransactionManager) context.getStore(NAMESPACE).get(ORIGINAL_TM_KEY); + TransactionManagerFactory.setTm(original); + } + } + + private static boolean isDualDatabaseTest(ExtensionContext context) { + Object testInstance = context.getTestInstance().orElseThrow(RuntimeException::new); + return testInstance.getClass().isAnnotationPresent(DualDatabaseTest.class); + } +}