diff --git a/core/src/main/java/google/registry/persistence/transaction/TransactionManagerFactory.java b/core/src/main/java/google/registry/persistence/transaction/TransactionManagerFactory.java
index c0762f6de..84f71be0c 100644
--- a/core/src/main/java/google/registry/persistence/transaction/TransactionManagerFactory.java
+++ b/core/src/main/java/google/registry/persistence/transaction/TransactionManagerFactory.java
@@ -132,13 +132,16 @@ public class TransactionManagerFactory {
/**
* Sets the return of {@link #tm()} to the given instance of {@link TransactionManager}.
*
+ *
DO NOT CALL THIS DIRECTLY IF POSSIBLE. Strongly prefer the use of TmOverrideExtension
+ *
in test code instead.
+ *
*
Used when overriding the per-test transaction manager for dual-database tests. Should be
* matched with a corresponding invocation of {@link #removeTmOverrideForTest()} either at the end
* of the test or in an @AfterEach
handler.
*/
@VisibleForTesting
- public static void setTmForTest(TransactionManager newTm) {
- tmForTest = Optional.of(newTm);
+ public static void setTmOverrideForTest(TransactionManager newTmOverride) {
+ tmForTest = Optional.of(newTmOverride);
}
/** Resets the overridden transaction manager post-test. */
diff --git a/core/src/test/java/google/registry/beam/invoicing/InvoicingPipelineTest.java b/core/src/test/java/google/registry/beam/invoicing/InvoicingPipelineTest.java
index 6113655c3..62fafe6df 100644
--- a/core/src/test/java/google/registry/beam/invoicing/InvoicingPipelineTest.java
+++ b/core/src/test/java/google/registry/beam/invoicing/InvoicingPipelineTest.java
@@ -17,9 +17,6 @@ package google.registry.beam.invoicing;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static com.google.common.truth.Truth.assertThat;
import static google.registry.model.tld.Registry.TldState.GENERAL_AVAILABILITY;
-import static google.registry.persistence.transaction.TransactionManagerFactory.jpaTm;
-import static google.registry.persistence.transaction.TransactionManagerFactory.removeTmOverrideForTest;
-import static google.registry.persistence.transaction.TransactionManagerFactory.setTmForTest;
import static google.registry.testing.DatabaseHelper.createTld;
import static google.registry.testing.DatabaseHelper.newRegistry;
import static google.registry.testing.DatabaseHelper.persistActiveDomain;
@@ -48,6 +45,7 @@ import google.registry.persistence.transaction.JpaTestExtensions.JpaIntegrationT
import google.registry.testing.DatastoreEntityExtension;
import google.registry.testing.FakeClock;
import google.registry.testing.TestDataHelper;
+import google.registry.testing.TmOverrideExtension;
import google.registry.util.ResourceUtils;
import java.io.File;
import java.nio.file.Files;
@@ -77,6 +75,25 @@ import org.junit.jupiter.api.io.TempDir;
/** Unit tests for {@link InvoicingPipeline}. */
class InvoicingPipelineTest {
+ @RegisterExtension
+ @Order(Order.DEFAULT - 1)
+ final transient DatastoreEntityExtension datastore =
+ new DatastoreEntityExtension().allThreads(true);
+
+ @RegisterExtension
+ final TestPipelineExtension pipeline =
+ TestPipelineExtension.create().enableAbandonedNodeEnforcement(true);
+
+ @RegisterExtension
+ final JpaIntegrationTestExtension database =
+ new JpaTestExtensions.Builder().withClock(new FakeClock()).buildIntegrationTestExtension();
+
+ @RegisterExtension
+ @Order(Order.DEFAULT + 1)
+ TmOverrideExtension tmOverrideExtension = TmOverrideExtension.withJpa();
+
+ @TempDir Path tmpDir;
+
private static final String BILLING_BUCKET_URL = "billing_bucket";
private static final String YEAR_MONTH = "2017-10";
private static final String INVOICE_FILE_PREFIX = "REG-INV";
@@ -225,21 +242,6 @@ class InvoicingPipelineTest {
"2017-10-01,2018-09-30,456,20.50,USD,10125,1,PURCHASE,bestdomains - test,1,"
+ "RENEW | TLD: test | TERM: 1-year,20.50,USD,116688");
- @RegisterExtension
- @Order(Order.DEFAULT - 1)
- final transient DatastoreEntityExtension datastore =
- new DatastoreEntityExtension().allThreads(true);
-
- @RegisterExtension
- final TestPipelineExtension pipeline =
- TestPipelineExtension.create().enableAbandonedNodeEnforcement(true);
-
- @RegisterExtension
- final JpaIntegrationTestExtension database =
- new JpaTestExtensions.Builder().withClock(new FakeClock()).buildIntegrationTestExtension();
-
- @TempDir Path tmpDir;
-
private final InvoicingPipelineOptions options =
PipelineOptionsFactory.create().as(InvoicingPipelineOptions.class);
@@ -261,13 +263,12 @@ class InvoicingPipelineTest {
String query = InvoicingPipeline.makeQuery("2017-10", "my-project-id");
assertThat(query)
.isEqualTo(TestDataHelper.loadFile(this.getClass(), "billing_events_test.sql"));
- // This is necessary because the TestPipelineExtension verifies that the pipelien is run.
+ // This is necessary because the TestPipelineExtension verifies that the pipeline is run.
pipeline.run();
}
@Test
void testSuccess_fullSqlPipeline() throws Exception {
- setTmForTest(jpaTm());
setupCloudSql();
options.setDatabase("CLOUD_SQL");
InvoicingPipeline invoicingPipeline = new InvoicingPipeline(options);
@@ -282,18 +283,15 @@ class InvoicingPipelineTest {
+ "UnitPriceCurrency,PONumber");
assertThat(overallInvoice.subList(1, overallInvoice.size()))
.containsExactlyElementsIn(EXPECTED_INVOICE_OUTPUT);
- removeTmOverrideForTest();
}
@Test
void testSuccess_readFromCloudSql() throws Exception {
- setTmForTest(jpaTm());
setupCloudSql();
PCollection billingEvents = InvoicingPipeline.readFromCloudSql(options, pipeline);
billingEvents = billingEvents.apply(new ChangeDomainRepo());
PAssert.that(billingEvents).containsInAnyOrder(INPUT_EVENTS);
pipeline.run().waitUntilFinish();
- removeTmOverrideForTest();
}
@Test
diff --git a/core/src/test/java/google/registry/beam/rde/RdePipelineTest.java b/core/src/test/java/google/registry/beam/rde/RdePipelineTest.java
index d9abb8925..9b40de040 100644
--- a/core/src/test/java/google/registry/beam/rde/RdePipelineTest.java
+++ b/core/src/test/java/google/registry/beam/rde/RdePipelineTest.java
@@ -22,7 +22,6 @@ import static google.registry.beam.rde.RdePipeline.encodePendings;
import static google.registry.model.common.Cursor.CursorType.RDE_STAGING;
import static google.registry.model.rde.RdeMode.FULL;
import static google.registry.model.rde.RdeMode.THIN;
-import static google.registry.persistence.transaction.TransactionManagerFactory.jpaTm;
import static google.registry.persistence.transaction.TransactionManagerFactory.tm;
import static google.registry.rde.RdeResourceType.CONTACT;
import static google.registry.rde.RdeResourceType.DOMAIN;
@@ -64,7 +63,6 @@ import google.registry.model.tld.Registry;
import google.registry.persistence.VKey;
import google.registry.persistence.transaction.JpaTestExtensions;
import google.registry.persistence.transaction.JpaTestExtensions.JpaIntegrationTestExtension;
-import google.registry.persistence.transaction.TransactionManagerFactory;
import google.registry.rde.DepositFragment;
import google.registry.rde.Ghostryde;
import google.registry.rde.PendingDeposit;
@@ -74,6 +72,7 @@ import google.registry.testing.CloudTasksHelper.TaskMatcher;
import google.registry.testing.DatastoreEntityExtension;
import google.registry.testing.FakeClock;
import google.registry.testing.FakeKeyringModule;
+import google.registry.testing.TmOverrideExtension;
import java.io.IOException;
import java.util.function.Function;
import java.util.regex.Matcher;
@@ -88,7 +87,6 @@ import org.bouncycastle.openpgp.PGPPrivateKey;
import org.bouncycastle.openpgp.PGPPublicKey;
import org.joda.time.DateTime;
import org.joda.time.Duration;
-import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Order;
import org.junit.jupiter.api.Test;
@@ -156,6 +154,10 @@ public class RdePipelineTest {
final JpaIntegrationTestExtension database =
new JpaTestExtensions.Builder().withClock(clock).buildIntegrationTestExtension();
+ @RegisterExtension
+ @Order(Order.DEFAULT + 1)
+ TmOverrideExtension tmOverrideExtension = TmOverrideExtension.withJpa();
+
@RegisterExtension
final TestPipelineExtension pipeline =
TestPipelineExtension.fromOptions(options).enableAbandonedNodeEnforcement(true);
@@ -164,7 +166,6 @@ public class RdePipelineTest {
@BeforeEach
void beforeEach() throws Exception {
- TransactionManagerFactory.setTmForTest(jpaTm());
loadInitialData();
// Two real registrars have been created by loadInitialData(), named "New Registrar" and "The
@@ -221,11 +222,6 @@ public class RdePipelineTest {
rdePipeline = new RdePipeline(options, gcsUtils, cloudTasksHelper.getTestCloudTasksUtils());
}
- @AfterEach
- void afterEach() {
- TransactionManagerFactory.removeTmOverrideForTest();
- }
-
@Test
void testSuccess_encodeAndDecodePendingsMap() throws Exception {
String encodedString = encodePendings(pendings);
diff --git a/core/src/test/java/google/registry/beam/spec11/Spec11PipelineTest.java b/core/src/test/java/google/registry/beam/spec11/Spec11PipelineTest.java
index e0277fdb6..b73778dd5 100644
--- a/core/src/test/java/google/registry/beam/spec11/Spec11PipelineTest.java
+++ b/core/src/test/java/google/registry/beam/spec11/Spec11PipelineTest.java
@@ -18,8 +18,6 @@ import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.truth.Truth.assertThat;
import static google.registry.model.ImmutableObjectSubject.immutableObjectCorrespondence;
import static google.registry.persistence.transaction.TransactionManagerFactory.jpaTm;
-import static google.registry.persistence.transaction.TransactionManagerFactory.removeTmOverrideForTest;
-import static google.registry.persistence.transaction.TransactionManagerFactory.setTmForTest;
import static google.registry.testing.AppEngineExtension.makeRegistrar1;
import static google.registry.testing.DatabaseHelper.createTld;
import static google.registry.testing.DatabaseHelper.persistActiveContact;
@@ -52,6 +50,7 @@ import google.registry.persistence.transaction.JpaTestExtensions.JpaIntegrationT
import google.registry.testing.DatastoreEntityExtension;
import google.registry.testing.FakeClock;
import google.registry.testing.FakeSleeper;
+import google.registry.testing.TmOverrideExtension;
import google.registry.util.ResourceUtils;
import google.registry.util.Retrier;
import java.io.File;
@@ -129,6 +128,10 @@ class Spec11PipelineTest {
final JpaIntegrationTestExtension database =
new JpaTestExtensions.Builder().withClock(new FakeClock()).buildIntegrationTestExtension();
+ @RegisterExtension
+ @Order(Order.DEFAULT + 1)
+ TmOverrideExtension tmOverrideExtension = TmOverrideExtension.withJpa();
+
private final Spec11PipelineOptions options =
PipelineOptionsFactory.create().as(Spec11PipelineOptions.class);
@@ -233,7 +236,6 @@ class Spec11PipelineTest {
}
private void setupCloudSql() {
- setTmForTest(jpaTm());
persistNewRegistrar("TheRegistrar");
persistNewRegistrar("NewRegistrar");
Registrar registrar1 =
@@ -273,7 +275,6 @@ class Spec11PipelineTest {
persistResource(createDomain("no-email.com", "2A4BA9BBC-COM", registrar2, contact2));
persistResource(
createDomain("anti-anti-anti-virus.dev", "555666888-DEV", registrar3, contact3));
- removeTmOverrideForTest();
}
private void verifySaveToGcs() throws Exception {
diff --git a/core/src/test/java/google/registry/persistence/transaction/JpaTransactionManagerImplTest.java b/core/src/test/java/google/registry/persistence/transaction/JpaTransactionManagerImplTest.java
index 1ceb3c729..dd04389a1 100644
--- a/core/src/test/java/google/registry/persistence/transaction/JpaTransactionManagerImplTest.java
+++ b/core/src/test/java/google/registry/persistence/transaction/JpaTransactionManagerImplTest.java
@@ -36,6 +36,7 @@ import google.registry.persistence.VKey;
import google.registry.persistence.transaction.JpaTestExtensions.JpaUnitTestExtension;
import google.registry.testing.DatabaseHelper;
import google.registry.testing.FakeClock;
+import google.registry.testing.TmOverrideExtension;
import java.io.Serializable;
import java.math.BigInteger;
import java.sql.SQLException;
@@ -48,8 +49,7 @@ import javax.persistence.IdClass;
import javax.persistence.OptimisticLockException;
import javax.persistence.RollbackException;
import org.hibernate.exception.JDBCConnectionException;
-import org.junit.jupiter.api.AfterEach;
-import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Order;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
@@ -84,15 +84,9 @@ class JpaTransactionManagerImplTest {
TestEntity.class, TestCompoundIdEntity.class, TestNamedCompoundIdEntity.class)
.buildUnitTestExtension();
- @BeforeEach
- void beforeEach() {
- TransactionManagerFactory.setTmForTest(jpaTm());
- }
-
- @AfterEach
- void afterEach() {
- TransactionManagerFactory.removeTmOverrideForTest();
- }
+ @RegisterExtension
+ @Order(Order.DEFAULT + 1)
+ TmOverrideExtension tmOverrideExtension = TmOverrideExtension.withJpa();
@Test
void transact_succeeds() {
diff --git a/core/src/test/java/google/registry/schema/registrar/RegistrarContactTest.java b/core/src/test/java/google/registry/schema/registrar/RegistrarContactTest.java
index d2489a92d..6bbe91eb2 100644
--- a/core/src/test/java/google/registry/schema/registrar/RegistrarContactTest.java
+++ b/core/src/test/java/google/registry/schema/registrar/RegistrarContactTest.java
@@ -16,8 +16,6 @@ package google.registry.schema.registrar;
import static com.google.common.truth.Truth.assertThat;
import static google.registry.model.registrar.RegistrarContact.Type.WHOIS;
-import static google.registry.persistence.transaction.TransactionManagerFactory.jpaTm;
-import static google.registry.persistence.transaction.TransactionManagerFactory.setTmForTest;
import static google.registry.testing.DatabaseHelper.insertInDb;
import static google.registry.testing.DatabaseHelper.loadByEntity;
import static google.registry.testing.SqlHelper.saveRegistrar;
@@ -28,6 +26,7 @@ import google.registry.model.registrar.RegistrarContact;
import google.registry.persistence.transaction.JpaTestExtensions;
import google.registry.persistence.transaction.JpaTestExtensions.JpaIntegrationWithCoverageExtension;
import google.registry.testing.DatastoreEntityExtension;
+import google.registry.testing.TmOverrideExtension;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Order;
import org.junit.jupiter.api.Test;
@@ -44,13 +43,16 @@ class RegistrarContactTest {
JpaIntegrationWithCoverageExtension jpa =
new JpaTestExtensions.Builder().buildIntegrationWithCoverageExtension();
+ @RegisterExtension
+ @Order(Order.DEFAULT + 1)
+ TmOverrideExtension tmOverrideExtension = TmOverrideExtension.withJpa();
+
private Registrar testRegistrar;
private RegistrarContact testRegistrarPoc;
@BeforeEach
public void beforeEach() {
- setTmForTest(jpaTm());
testRegistrar = saveRegistrar("registrarId");
testRegistrarPoc =
new RegistrarContact.Builder()
diff --git a/core/src/test/java/google/registry/schema/registrar/RegistrarDaoTest.java b/core/src/test/java/google/registry/schema/registrar/RegistrarDaoTest.java
index edf52a19f..8a559359d 100644
--- a/core/src/test/java/google/registry/schema/registrar/RegistrarDaoTest.java
+++ b/core/src/test/java/google/registry/schema/registrar/RegistrarDaoTest.java
@@ -15,7 +15,6 @@
package google.registry.schema.registrar;
import static com.google.common.truth.Truth.assertThat;
-import static google.registry.persistence.transaction.TransactionManagerFactory.jpaTm;
import static google.registry.testing.DatabaseHelper.existsInDb;
import static google.registry.testing.DatabaseHelper.insertInDb;
import static google.registry.testing.DatabaseHelper.loadByKey;
@@ -29,11 +28,10 @@ import google.registry.model.registrar.RegistrarAddress;
import google.registry.persistence.VKey;
import google.registry.persistence.transaction.JpaTestExtensions;
import google.registry.persistence.transaction.JpaTestExtensions.JpaIntegrationWithCoverageExtension;
-import google.registry.persistence.transaction.TransactionManagerFactory;
import google.registry.testing.DatastoreEntityExtension;
import google.registry.testing.FakeClock;
+import google.registry.testing.TmOverrideExtension;
import org.joda.time.DateTime;
-import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Order;
import org.junit.jupiter.api.Test;
@@ -52,13 +50,16 @@ public class RegistrarDaoTest {
JpaIntegrationWithCoverageExtension jpa =
new JpaTestExtensions.Builder().withClock(fakeClock).buildIntegrationWithCoverageExtension();
+ @RegisterExtension
+ @Order(Order.DEFAULT + 1)
+ TmOverrideExtension tmOverrideExtension = TmOverrideExtension.withJpa();
+
private final VKey registrarKey = VKey.createSql(Registrar.class, "registrarId");
private Registrar testRegistrar;
@BeforeEach
void beforeEach() {
- TransactionManagerFactory.setTmForTest(jpaTm());
testRegistrar =
new Registrar.Builder()
.setType(Registrar.Type.TEST)
@@ -75,11 +76,6 @@ public class RegistrarDaoTest {
.build();
}
- @AfterEach
- void afterEach() {
- TransactionManagerFactory.removeTmOverrideForTest();
- }
-
@Test
void saveNew_worksSuccessfully() {
assertThat(existsInDb(testRegistrar)).isFalse();
diff --git a/core/src/test/java/google/registry/schema/replay/SqlEntityTest.java b/core/src/test/java/google/registry/schema/replay/SqlEntityTest.java
index 0bc61bb05..09502c332 100644
--- a/core/src/test/java/google/registry/schema/replay/SqlEntityTest.java
+++ b/core/src/test/java/google/registry/schema/replay/SqlEntityTest.java
@@ -21,10 +21,9 @@ import google.registry.model.registrar.Registrar;
import google.registry.model.registrar.RegistrarContact;
import google.registry.model.registrar.RegistrarContact.RegistrarPocId;
import google.registry.persistence.VKey;
-import google.registry.persistence.transaction.TransactionManagerFactory;
import google.registry.testing.AppEngineExtension;
import google.registry.testing.DatastoreEntityExtension;
-import org.junit.jupiter.api.AfterEach;
+import google.registry.testing.TmOverrideExtension;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Order;
import org.junit.jupiter.api.Test;
@@ -41,17 +40,15 @@ public class SqlEntityTest {
final AppEngineExtension database =
new AppEngineExtension.Builder().withCloudSql().withoutCannedData().build();
+ @RegisterExtension
+ @Order(Order.DEFAULT + 1)
+ TmOverrideExtension tmOverrideExtension = TmOverrideExtension.withJpa();
+
@BeforeEach
void setup() throws Exception {
- TransactionManagerFactory.setTmForTest(TransactionManagerFactory.jpaTm());
AppEngineExtension.loadInitialData();
}
- @AfterEach
- void teardown() {
- TransactionManagerFactory.removeTmOverrideForTest();
- }
-
@Test
void getPrimaryKeyString_oneIdColumn() {
// AppEngineExtension canned data: Registrar1
diff --git a/core/src/test/java/google/registry/testing/DualDatabaseTestInvocationContextProvider.java b/core/src/test/java/google/registry/testing/DualDatabaseTestInvocationContextProvider.java
index 1b48047e4..3f31d4498 100644
--- a/core/src/test/java/google/registry/testing/DualDatabaseTestInvocationContextProvider.java
+++ b/core/src/test/java/google/registry/testing/DualDatabaseTestInvocationContextProvider.java
@@ -154,7 +154,7 @@ class DualDatabaseTestInvocationContextProvider implements TestTemplateInvocatio
context.getStore(NAMESPACE).put(ORIGINAL_TM_KEY, tm());
DatabaseType databaseType =
(DatabaseType) context.getStore(NAMESPACE).get(INJECTED_TM_SUPPLIER_KEY);
- TransactionManagerFactory.setTmForTest(databaseType.getTm());
+ TransactionManagerFactory.setTmOverrideForTest(databaseType.getTm());
}
}
diff --git a/core/src/test/java/google/registry/testing/TmOverrideExtension.java b/core/src/test/java/google/registry/testing/TmOverrideExtension.java
new file mode 100644
index 000000000..247285423
--- /dev/null
+++ b/core/src/test/java/google/registry/testing/TmOverrideExtension.java
@@ -0,0 +1,73 @@
+// Copyright 2021 The Nomulus Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package google.registry.testing;
+
+import static google.registry.persistence.transaction.TransactionManagerFactory.jpaTm;
+import static google.registry.persistence.transaction.TransactionManagerFactory.ofyTm;
+
+import google.registry.persistence.transaction.TransactionManager;
+import google.registry.persistence.transaction.TransactionManagerFactory;
+import org.junit.jupiter.api.extension.AfterEachCallback;
+import org.junit.jupiter.api.extension.BeforeEachCallback;
+import org.junit.jupiter.api.extension.ExtensionContext;
+
+/**
+ * JUnit extension for overriding the {@link TransactionManager} in tests.
+ *
+ * You will typically want to run this at @Order(Order.DEFAULT + 1)
alongside a
+ * {@link google.registry.persistence.transaction.JpaTransactionManagerExtension} or {@link
+ * DatastoreEntityExtension} with default {@link org.junit.jupiter.api.Order}. The transaction
+ * manager extension needs to run first so that when this override is called it's not trying to use
+ * the default dummy one.
+ *
+ *
This extension is incompatible with {@link DualDatabaseTest}. Use either that or this, but not
+ * both.
+ */
+public final class TmOverrideExtension implements BeforeEachCallback, AfterEachCallback {
+
+ private static enum TmOverride {
+ OFY,
+ JPA;
+ }
+
+ private final TmOverride tmOverride;
+
+ private TmOverrideExtension(TmOverride tmOverride) {
+ this.tmOverride = tmOverride;
+ }
+
+ /** Use the {@link google.registry.model.ofy.DatastoreTransactionManager} for all tests. */
+ public static TmOverrideExtension withOfy() {
+ return new TmOverrideExtension(TmOverride.OFY);
+ }
+
+ /**
+ * Use the {@link google.registry.persistence.transaction.JpaTransactionManager} for all tests.
+ */
+ public static TmOverrideExtension withJpa() {
+ return new TmOverrideExtension(TmOverride.JPA);
+ }
+
+ @Override
+ public void beforeEach(ExtensionContext context) {
+ TransactionManagerFactory.setTmOverrideForTest(
+ tmOverride == TmOverride.OFY ? ofyTm() : jpaTm());
+ }
+
+ @Override
+ public void afterEach(ExtensionContext context) {
+ TransactionManagerFactory.removeTmOverrideForTest();
+ }
+}