diff --git a/core/src/main/java/google/registry/persistence/transaction/TransactionManagerFactory.java b/core/src/main/java/google/registry/persistence/transaction/TransactionManagerFactory.java index c0762f6de..84f71be0c 100644 --- a/core/src/main/java/google/registry/persistence/transaction/TransactionManagerFactory.java +++ b/core/src/main/java/google/registry/persistence/transaction/TransactionManagerFactory.java @@ -132,13 +132,16 @@ public class TransactionManagerFactory { /** * Sets the return of {@link #tm()} to the given instance of {@link TransactionManager}. * + *

DO NOT CALL THIS DIRECTLY IF POSSIBLE. Strongly prefer the use of TmOverrideExtension + * in test code instead. + * *

Used when overriding the per-test transaction manager for dual-database tests. Should be * matched with a corresponding invocation of {@link #removeTmOverrideForTest()} either at the end * of the test or in an @AfterEach handler. */ @VisibleForTesting - public static void setTmForTest(TransactionManager newTm) { - tmForTest = Optional.of(newTm); + public static void setTmOverrideForTest(TransactionManager newTmOverride) { + tmForTest = Optional.of(newTmOverride); } /** Resets the overridden transaction manager post-test. */ diff --git a/core/src/test/java/google/registry/beam/invoicing/InvoicingPipelineTest.java b/core/src/test/java/google/registry/beam/invoicing/InvoicingPipelineTest.java index 6113655c3..62fafe6df 100644 --- a/core/src/test/java/google/registry/beam/invoicing/InvoicingPipelineTest.java +++ b/core/src/test/java/google/registry/beam/invoicing/InvoicingPipelineTest.java @@ -17,9 +17,6 @@ package google.registry.beam.invoicing; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static com.google.common.truth.Truth.assertThat; import static google.registry.model.tld.Registry.TldState.GENERAL_AVAILABILITY; -import static google.registry.persistence.transaction.TransactionManagerFactory.jpaTm; -import static google.registry.persistence.transaction.TransactionManagerFactory.removeTmOverrideForTest; -import static google.registry.persistence.transaction.TransactionManagerFactory.setTmForTest; import static google.registry.testing.DatabaseHelper.createTld; import static google.registry.testing.DatabaseHelper.newRegistry; import static google.registry.testing.DatabaseHelper.persistActiveDomain; @@ -48,6 +45,7 @@ import google.registry.persistence.transaction.JpaTestExtensions.JpaIntegrationT import google.registry.testing.DatastoreEntityExtension; import google.registry.testing.FakeClock; import google.registry.testing.TestDataHelper; +import google.registry.testing.TmOverrideExtension; import google.registry.util.ResourceUtils; import java.io.File; import java.nio.file.Files; @@ -77,6 +75,25 @@ import org.junit.jupiter.api.io.TempDir; /** Unit tests for {@link InvoicingPipeline}. */ class InvoicingPipelineTest { + @RegisterExtension + @Order(Order.DEFAULT - 1) + final transient DatastoreEntityExtension datastore = + new DatastoreEntityExtension().allThreads(true); + + @RegisterExtension + final TestPipelineExtension pipeline = + TestPipelineExtension.create().enableAbandonedNodeEnforcement(true); + + @RegisterExtension + final JpaIntegrationTestExtension database = + new JpaTestExtensions.Builder().withClock(new FakeClock()).buildIntegrationTestExtension(); + + @RegisterExtension + @Order(Order.DEFAULT + 1) + TmOverrideExtension tmOverrideExtension = TmOverrideExtension.withJpa(); + + @TempDir Path tmpDir; + private static final String BILLING_BUCKET_URL = "billing_bucket"; private static final String YEAR_MONTH = "2017-10"; private static final String INVOICE_FILE_PREFIX = "REG-INV"; @@ -225,21 +242,6 @@ class InvoicingPipelineTest { "2017-10-01,2018-09-30,456,20.50,USD,10125,1,PURCHASE,bestdomains - test,1," + "RENEW | TLD: test | TERM: 1-year,20.50,USD,116688"); - @RegisterExtension - @Order(Order.DEFAULT - 1) - final transient DatastoreEntityExtension datastore = - new DatastoreEntityExtension().allThreads(true); - - @RegisterExtension - final TestPipelineExtension pipeline = - TestPipelineExtension.create().enableAbandonedNodeEnforcement(true); - - @RegisterExtension - final JpaIntegrationTestExtension database = - new JpaTestExtensions.Builder().withClock(new FakeClock()).buildIntegrationTestExtension(); - - @TempDir Path tmpDir; - private final InvoicingPipelineOptions options = PipelineOptionsFactory.create().as(InvoicingPipelineOptions.class); @@ -261,13 +263,12 @@ class InvoicingPipelineTest { String query = InvoicingPipeline.makeQuery("2017-10", "my-project-id"); assertThat(query) .isEqualTo(TestDataHelper.loadFile(this.getClass(), "billing_events_test.sql")); - // This is necessary because the TestPipelineExtension verifies that the pipelien is run. + // This is necessary because the TestPipelineExtension verifies that the pipeline is run. pipeline.run(); } @Test void testSuccess_fullSqlPipeline() throws Exception { - setTmForTest(jpaTm()); setupCloudSql(); options.setDatabase("CLOUD_SQL"); InvoicingPipeline invoicingPipeline = new InvoicingPipeline(options); @@ -282,18 +283,15 @@ class InvoicingPipelineTest { + "UnitPriceCurrency,PONumber"); assertThat(overallInvoice.subList(1, overallInvoice.size())) .containsExactlyElementsIn(EXPECTED_INVOICE_OUTPUT); - removeTmOverrideForTest(); } @Test void testSuccess_readFromCloudSql() throws Exception { - setTmForTest(jpaTm()); setupCloudSql(); PCollection billingEvents = InvoicingPipeline.readFromCloudSql(options, pipeline); billingEvents = billingEvents.apply(new ChangeDomainRepo()); PAssert.that(billingEvents).containsInAnyOrder(INPUT_EVENTS); pipeline.run().waitUntilFinish(); - removeTmOverrideForTest(); } @Test diff --git a/core/src/test/java/google/registry/beam/rde/RdePipelineTest.java b/core/src/test/java/google/registry/beam/rde/RdePipelineTest.java index d9abb8925..9b40de040 100644 --- a/core/src/test/java/google/registry/beam/rde/RdePipelineTest.java +++ b/core/src/test/java/google/registry/beam/rde/RdePipelineTest.java @@ -22,7 +22,6 @@ import static google.registry.beam.rde.RdePipeline.encodePendings; import static google.registry.model.common.Cursor.CursorType.RDE_STAGING; import static google.registry.model.rde.RdeMode.FULL; import static google.registry.model.rde.RdeMode.THIN; -import static google.registry.persistence.transaction.TransactionManagerFactory.jpaTm; import static google.registry.persistence.transaction.TransactionManagerFactory.tm; import static google.registry.rde.RdeResourceType.CONTACT; import static google.registry.rde.RdeResourceType.DOMAIN; @@ -64,7 +63,6 @@ import google.registry.model.tld.Registry; import google.registry.persistence.VKey; import google.registry.persistence.transaction.JpaTestExtensions; import google.registry.persistence.transaction.JpaTestExtensions.JpaIntegrationTestExtension; -import google.registry.persistence.transaction.TransactionManagerFactory; import google.registry.rde.DepositFragment; import google.registry.rde.Ghostryde; import google.registry.rde.PendingDeposit; @@ -74,6 +72,7 @@ import google.registry.testing.CloudTasksHelper.TaskMatcher; import google.registry.testing.DatastoreEntityExtension; import google.registry.testing.FakeClock; import google.registry.testing.FakeKeyringModule; +import google.registry.testing.TmOverrideExtension; import java.io.IOException; import java.util.function.Function; import java.util.regex.Matcher; @@ -88,7 +87,6 @@ import org.bouncycastle.openpgp.PGPPrivateKey; import org.bouncycastle.openpgp.PGPPublicKey; import org.joda.time.DateTime; import org.joda.time.Duration; -import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Order; import org.junit.jupiter.api.Test; @@ -156,6 +154,10 @@ public class RdePipelineTest { final JpaIntegrationTestExtension database = new JpaTestExtensions.Builder().withClock(clock).buildIntegrationTestExtension(); + @RegisterExtension + @Order(Order.DEFAULT + 1) + TmOverrideExtension tmOverrideExtension = TmOverrideExtension.withJpa(); + @RegisterExtension final TestPipelineExtension pipeline = TestPipelineExtension.fromOptions(options).enableAbandonedNodeEnforcement(true); @@ -164,7 +166,6 @@ public class RdePipelineTest { @BeforeEach void beforeEach() throws Exception { - TransactionManagerFactory.setTmForTest(jpaTm()); loadInitialData(); // Two real registrars have been created by loadInitialData(), named "New Registrar" and "The @@ -221,11 +222,6 @@ public class RdePipelineTest { rdePipeline = new RdePipeline(options, gcsUtils, cloudTasksHelper.getTestCloudTasksUtils()); } - @AfterEach - void afterEach() { - TransactionManagerFactory.removeTmOverrideForTest(); - } - @Test void testSuccess_encodeAndDecodePendingsMap() throws Exception { String encodedString = encodePendings(pendings); diff --git a/core/src/test/java/google/registry/beam/spec11/Spec11PipelineTest.java b/core/src/test/java/google/registry/beam/spec11/Spec11PipelineTest.java index e0277fdb6..b73778dd5 100644 --- a/core/src/test/java/google/registry/beam/spec11/Spec11PipelineTest.java +++ b/core/src/test/java/google/registry/beam/spec11/Spec11PipelineTest.java @@ -18,8 +18,6 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.truth.Truth.assertThat; import static google.registry.model.ImmutableObjectSubject.immutableObjectCorrespondence; import static google.registry.persistence.transaction.TransactionManagerFactory.jpaTm; -import static google.registry.persistence.transaction.TransactionManagerFactory.removeTmOverrideForTest; -import static google.registry.persistence.transaction.TransactionManagerFactory.setTmForTest; import static google.registry.testing.AppEngineExtension.makeRegistrar1; import static google.registry.testing.DatabaseHelper.createTld; import static google.registry.testing.DatabaseHelper.persistActiveContact; @@ -52,6 +50,7 @@ import google.registry.persistence.transaction.JpaTestExtensions.JpaIntegrationT import google.registry.testing.DatastoreEntityExtension; import google.registry.testing.FakeClock; import google.registry.testing.FakeSleeper; +import google.registry.testing.TmOverrideExtension; import google.registry.util.ResourceUtils; import google.registry.util.Retrier; import java.io.File; @@ -129,6 +128,10 @@ class Spec11PipelineTest { final JpaIntegrationTestExtension database = new JpaTestExtensions.Builder().withClock(new FakeClock()).buildIntegrationTestExtension(); + @RegisterExtension + @Order(Order.DEFAULT + 1) + TmOverrideExtension tmOverrideExtension = TmOverrideExtension.withJpa(); + private final Spec11PipelineOptions options = PipelineOptionsFactory.create().as(Spec11PipelineOptions.class); @@ -233,7 +236,6 @@ class Spec11PipelineTest { } private void setupCloudSql() { - setTmForTest(jpaTm()); persistNewRegistrar("TheRegistrar"); persistNewRegistrar("NewRegistrar"); Registrar registrar1 = @@ -273,7 +275,6 @@ class Spec11PipelineTest { persistResource(createDomain("no-email.com", "2A4BA9BBC-COM", registrar2, contact2)); persistResource( createDomain("anti-anti-anti-virus.dev", "555666888-DEV", registrar3, contact3)); - removeTmOverrideForTest(); } private void verifySaveToGcs() throws Exception { diff --git a/core/src/test/java/google/registry/persistence/transaction/JpaTransactionManagerImplTest.java b/core/src/test/java/google/registry/persistence/transaction/JpaTransactionManagerImplTest.java index 1ceb3c729..dd04389a1 100644 --- a/core/src/test/java/google/registry/persistence/transaction/JpaTransactionManagerImplTest.java +++ b/core/src/test/java/google/registry/persistence/transaction/JpaTransactionManagerImplTest.java @@ -36,6 +36,7 @@ import google.registry.persistence.VKey; import google.registry.persistence.transaction.JpaTestExtensions.JpaUnitTestExtension; import google.registry.testing.DatabaseHelper; import google.registry.testing.FakeClock; +import google.registry.testing.TmOverrideExtension; import java.io.Serializable; import java.math.BigInteger; import java.sql.SQLException; @@ -48,8 +49,7 @@ import javax.persistence.IdClass; import javax.persistence.OptimisticLockException; import javax.persistence.RollbackException; import org.hibernate.exception.JDBCConnectionException; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Order; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; @@ -84,15 +84,9 @@ class JpaTransactionManagerImplTest { TestEntity.class, TestCompoundIdEntity.class, TestNamedCompoundIdEntity.class) .buildUnitTestExtension(); - @BeforeEach - void beforeEach() { - TransactionManagerFactory.setTmForTest(jpaTm()); - } - - @AfterEach - void afterEach() { - TransactionManagerFactory.removeTmOverrideForTest(); - } + @RegisterExtension + @Order(Order.DEFAULT + 1) + TmOverrideExtension tmOverrideExtension = TmOverrideExtension.withJpa(); @Test void transact_succeeds() { diff --git a/core/src/test/java/google/registry/schema/registrar/RegistrarContactTest.java b/core/src/test/java/google/registry/schema/registrar/RegistrarContactTest.java index d2489a92d..6bbe91eb2 100644 --- a/core/src/test/java/google/registry/schema/registrar/RegistrarContactTest.java +++ b/core/src/test/java/google/registry/schema/registrar/RegistrarContactTest.java @@ -16,8 +16,6 @@ package google.registry.schema.registrar; import static com.google.common.truth.Truth.assertThat; import static google.registry.model.registrar.RegistrarContact.Type.WHOIS; -import static google.registry.persistence.transaction.TransactionManagerFactory.jpaTm; -import static google.registry.persistence.transaction.TransactionManagerFactory.setTmForTest; import static google.registry.testing.DatabaseHelper.insertInDb; import static google.registry.testing.DatabaseHelper.loadByEntity; import static google.registry.testing.SqlHelper.saveRegistrar; @@ -28,6 +26,7 @@ import google.registry.model.registrar.RegistrarContact; import google.registry.persistence.transaction.JpaTestExtensions; import google.registry.persistence.transaction.JpaTestExtensions.JpaIntegrationWithCoverageExtension; import google.registry.testing.DatastoreEntityExtension; +import google.registry.testing.TmOverrideExtension; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Order; import org.junit.jupiter.api.Test; @@ -44,13 +43,16 @@ class RegistrarContactTest { JpaIntegrationWithCoverageExtension jpa = new JpaTestExtensions.Builder().buildIntegrationWithCoverageExtension(); + @RegisterExtension + @Order(Order.DEFAULT + 1) + TmOverrideExtension tmOverrideExtension = TmOverrideExtension.withJpa(); + private Registrar testRegistrar; private RegistrarContact testRegistrarPoc; @BeforeEach public void beforeEach() { - setTmForTest(jpaTm()); testRegistrar = saveRegistrar("registrarId"); testRegistrarPoc = new RegistrarContact.Builder() diff --git a/core/src/test/java/google/registry/schema/registrar/RegistrarDaoTest.java b/core/src/test/java/google/registry/schema/registrar/RegistrarDaoTest.java index edf52a19f..8a559359d 100644 --- a/core/src/test/java/google/registry/schema/registrar/RegistrarDaoTest.java +++ b/core/src/test/java/google/registry/schema/registrar/RegistrarDaoTest.java @@ -15,7 +15,6 @@ package google.registry.schema.registrar; import static com.google.common.truth.Truth.assertThat; -import static google.registry.persistence.transaction.TransactionManagerFactory.jpaTm; import static google.registry.testing.DatabaseHelper.existsInDb; import static google.registry.testing.DatabaseHelper.insertInDb; import static google.registry.testing.DatabaseHelper.loadByKey; @@ -29,11 +28,10 @@ import google.registry.model.registrar.RegistrarAddress; import google.registry.persistence.VKey; import google.registry.persistence.transaction.JpaTestExtensions; import google.registry.persistence.transaction.JpaTestExtensions.JpaIntegrationWithCoverageExtension; -import google.registry.persistence.transaction.TransactionManagerFactory; import google.registry.testing.DatastoreEntityExtension; import google.registry.testing.FakeClock; +import google.registry.testing.TmOverrideExtension; import org.joda.time.DateTime; -import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Order; import org.junit.jupiter.api.Test; @@ -52,13 +50,16 @@ public class RegistrarDaoTest { JpaIntegrationWithCoverageExtension jpa = new JpaTestExtensions.Builder().withClock(fakeClock).buildIntegrationWithCoverageExtension(); + @RegisterExtension + @Order(Order.DEFAULT + 1) + TmOverrideExtension tmOverrideExtension = TmOverrideExtension.withJpa(); + private final VKey registrarKey = VKey.createSql(Registrar.class, "registrarId"); private Registrar testRegistrar; @BeforeEach void beforeEach() { - TransactionManagerFactory.setTmForTest(jpaTm()); testRegistrar = new Registrar.Builder() .setType(Registrar.Type.TEST) @@ -75,11 +76,6 @@ public class RegistrarDaoTest { .build(); } - @AfterEach - void afterEach() { - TransactionManagerFactory.removeTmOverrideForTest(); - } - @Test void saveNew_worksSuccessfully() { assertThat(existsInDb(testRegistrar)).isFalse(); diff --git a/core/src/test/java/google/registry/schema/replay/SqlEntityTest.java b/core/src/test/java/google/registry/schema/replay/SqlEntityTest.java index 0bc61bb05..09502c332 100644 --- a/core/src/test/java/google/registry/schema/replay/SqlEntityTest.java +++ b/core/src/test/java/google/registry/schema/replay/SqlEntityTest.java @@ -21,10 +21,9 @@ import google.registry.model.registrar.Registrar; import google.registry.model.registrar.RegistrarContact; import google.registry.model.registrar.RegistrarContact.RegistrarPocId; import google.registry.persistence.VKey; -import google.registry.persistence.transaction.TransactionManagerFactory; import google.registry.testing.AppEngineExtension; import google.registry.testing.DatastoreEntityExtension; -import org.junit.jupiter.api.AfterEach; +import google.registry.testing.TmOverrideExtension; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Order; import org.junit.jupiter.api.Test; @@ -41,17 +40,15 @@ public class SqlEntityTest { final AppEngineExtension database = new AppEngineExtension.Builder().withCloudSql().withoutCannedData().build(); + @RegisterExtension + @Order(Order.DEFAULT + 1) + TmOverrideExtension tmOverrideExtension = TmOverrideExtension.withJpa(); + @BeforeEach void setup() throws Exception { - TransactionManagerFactory.setTmForTest(TransactionManagerFactory.jpaTm()); AppEngineExtension.loadInitialData(); } - @AfterEach - void teardown() { - TransactionManagerFactory.removeTmOverrideForTest(); - } - @Test void getPrimaryKeyString_oneIdColumn() { // AppEngineExtension canned data: Registrar1 diff --git a/core/src/test/java/google/registry/testing/DualDatabaseTestInvocationContextProvider.java b/core/src/test/java/google/registry/testing/DualDatabaseTestInvocationContextProvider.java index 1b48047e4..3f31d4498 100644 --- a/core/src/test/java/google/registry/testing/DualDatabaseTestInvocationContextProvider.java +++ b/core/src/test/java/google/registry/testing/DualDatabaseTestInvocationContextProvider.java @@ -154,7 +154,7 @@ class DualDatabaseTestInvocationContextProvider implements TestTemplateInvocatio context.getStore(NAMESPACE).put(ORIGINAL_TM_KEY, tm()); DatabaseType databaseType = (DatabaseType) context.getStore(NAMESPACE).get(INJECTED_TM_SUPPLIER_KEY); - TransactionManagerFactory.setTmForTest(databaseType.getTm()); + TransactionManagerFactory.setTmOverrideForTest(databaseType.getTm()); } } diff --git a/core/src/test/java/google/registry/testing/TmOverrideExtension.java b/core/src/test/java/google/registry/testing/TmOverrideExtension.java new file mode 100644 index 000000000..247285423 --- /dev/null +++ b/core/src/test/java/google/registry/testing/TmOverrideExtension.java @@ -0,0 +1,73 @@ +// Copyright 2021 The Nomulus Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package google.registry.testing; + +import static google.registry.persistence.transaction.TransactionManagerFactory.jpaTm; +import static google.registry.persistence.transaction.TransactionManagerFactory.ofyTm; + +import google.registry.persistence.transaction.TransactionManager; +import google.registry.persistence.transaction.TransactionManagerFactory; +import org.junit.jupiter.api.extension.AfterEachCallback; +import org.junit.jupiter.api.extension.BeforeEachCallback; +import org.junit.jupiter.api.extension.ExtensionContext; + +/** + * JUnit extension for overriding the {@link TransactionManager} in tests. + * + *

You will typically want to run this at @Order(Order.DEFAULT + 1) alongside a + * {@link google.registry.persistence.transaction.JpaTransactionManagerExtension} or {@link + * DatastoreEntityExtension} with default {@link org.junit.jupiter.api.Order}. The transaction + * manager extension needs to run first so that when this override is called it's not trying to use + * the default dummy one. + * + *

This extension is incompatible with {@link DualDatabaseTest}. Use either that or this, but not + * both. + */ +public final class TmOverrideExtension implements BeforeEachCallback, AfterEachCallback { + + private static enum TmOverride { + OFY, + JPA; + } + + private final TmOverride tmOverride; + + private TmOverrideExtension(TmOverride tmOverride) { + this.tmOverride = tmOverride; + } + + /** Use the {@link google.registry.model.ofy.DatastoreTransactionManager} for all tests. */ + public static TmOverrideExtension withOfy() { + return new TmOverrideExtension(TmOverride.OFY); + } + + /** + * Use the {@link google.registry.persistence.transaction.JpaTransactionManager} for all tests. + */ + public static TmOverrideExtension withJpa() { + return new TmOverrideExtension(TmOverride.JPA); + } + + @Override + public void beforeEach(ExtensionContext context) { + TransactionManagerFactory.setTmOverrideForTest( + tmOverride == TmOverride.OFY ? ofyTm() : jpaTm()); + } + + @Override + public void afterEach(ExtensionContext context) { + TransactionManagerFactory.removeTmOverrideForTest(); + } +}