Add TmOverrideExtension for more safe TM overrides in tests (#1382)

* Add TmOverrideExtension for more safe TM overrides in tests

This is safer to use than calling setTmForTest() directly because this extension
also handles the corresponding call to removeTmOverrideForTest() automatically,
the forgetting of which has been a source of test flakiness/instability in the
past.

There are now broadly two ways to get tests to run in JPA: either use
DualDatabaseTest, an AppEngineExtension, and the corresponding JPA-specific
@Test annotations, OR use this override alongside a
JpaTransactionManagerExtension.
This commit is contained in:
Ben McIlwain 2021-10-07 19:26:25 -04:00 committed by GitHub
parent 3c3140dd9a
commit d1a972d1e4
10 changed files with 130 additions and 70 deletions

View file

@ -132,13 +132,16 @@ public class TransactionManagerFactory {
/**
* Sets the return of {@link #tm()} to the given instance of {@link TransactionManager}.
*
* <p>DO NOT CALL THIS DIRECTLY IF POSSIBLE. Strongly prefer the use of <code>TmOverrideExtension
* </code> in test code instead.
*
* <p>Used when overriding the per-test transaction manager for dual-database tests. Should be
* matched with a corresponding invocation of {@link #removeTmOverrideForTest()} either at the end
* of the test or in an <code>@AfterEach</code> handler.
*/
@VisibleForTesting
public static void setTmForTest(TransactionManager newTm) {
tmForTest = Optional.of(newTm);
public static void setTmOverrideForTest(TransactionManager newTmOverride) {
tmForTest = Optional.of(newTmOverride);
}
/** Resets the overridden transaction manager post-test. */

View file

@ -17,9 +17,6 @@ package google.registry.beam.invoicing;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static com.google.common.truth.Truth.assertThat;
import static google.registry.model.tld.Registry.TldState.GENERAL_AVAILABILITY;
import static google.registry.persistence.transaction.TransactionManagerFactory.jpaTm;
import static google.registry.persistence.transaction.TransactionManagerFactory.removeTmOverrideForTest;
import static google.registry.persistence.transaction.TransactionManagerFactory.setTmForTest;
import static google.registry.testing.DatabaseHelper.createTld;
import static google.registry.testing.DatabaseHelper.newRegistry;
import static google.registry.testing.DatabaseHelper.persistActiveDomain;
@ -48,6 +45,7 @@ import google.registry.persistence.transaction.JpaTestExtensions.JpaIntegrationT
import google.registry.testing.DatastoreEntityExtension;
import google.registry.testing.FakeClock;
import google.registry.testing.TestDataHelper;
import google.registry.testing.TmOverrideExtension;
import google.registry.util.ResourceUtils;
import java.io.File;
import java.nio.file.Files;
@ -77,6 +75,25 @@ import org.junit.jupiter.api.io.TempDir;
/** Unit tests for {@link InvoicingPipeline}. */
class InvoicingPipelineTest {
@RegisterExtension
@Order(Order.DEFAULT - 1)
final transient DatastoreEntityExtension datastore =
new DatastoreEntityExtension().allThreads(true);
@RegisterExtension
final TestPipelineExtension pipeline =
TestPipelineExtension.create().enableAbandonedNodeEnforcement(true);
@RegisterExtension
final JpaIntegrationTestExtension database =
new JpaTestExtensions.Builder().withClock(new FakeClock()).buildIntegrationTestExtension();
@RegisterExtension
@Order(Order.DEFAULT + 1)
TmOverrideExtension tmOverrideExtension = TmOverrideExtension.withJpa();
@TempDir Path tmpDir;
private static final String BILLING_BUCKET_URL = "billing_bucket";
private static final String YEAR_MONTH = "2017-10";
private static final String INVOICE_FILE_PREFIX = "REG-INV";
@ -225,21 +242,6 @@ class InvoicingPipelineTest {
"2017-10-01,2018-09-30,456,20.50,USD,10125,1,PURCHASE,bestdomains - test,1,"
+ "RENEW | TLD: test | TERM: 1-year,20.50,USD,116688");
@RegisterExtension
@Order(Order.DEFAULT - 1)
final transient DatastoreEntityExtension datastore =
new DatastoreEntityExtension().allThreads(true);
@RegisterExtension
final TestPipelineExtension pipeline =
TestPipelineExtension.create().enableAbandonedNodeEnforcement(true);
@RegisterExtension
final JpaIntegrationTestExtension database =
new JpaTestExtensions.Builder().withClock(new FakeClock()).buildIntegrationTestExtension();
@TempDir Path tmpDir;
private final InvoicingPipelineOptions options =
PipelineOptionsFactory.create().as(InvoicingPipelineOptions.class);
@ -261,13 +263,12 @@ class InvoicingPipelineTest {
String query = InvoicingPipeline.makeQuery("2017-10", "my-project-id");
assertThat(query)
.isEqualTo(TestDataHelper.loadFile(this.getClass(), "billing_events_test.sql"));
// This is necessary because the TestPipelineExtension verifies that the pipelien is run.
// This is necessary because the TestPipelineExtension verifies that the pipeline is run.
pipeline.run();
}
@Test
void testSuccess_fullSqlPipeline() throws Exception {
setTmForTest(jpaTm());
setupCloudSql();
options.setDatabase("CLOUD_SQL");
InvoicingPipeline invoicingPipeline = new InvoicingPipeline(options);
@ -282,18 +283,15 @@ class InvoicingPipelineTest {
+ "UnitPriceCurrency,PONumber");
assertThat(overallInvoice.subList(1, overallInvoice.size()))
.containsExactlyElementsIn(EXPECTED_INVOICE_OUTPUT);
removeTmOverrideForTest();
}
@Test
void testSuccess_readFromCloudSql() throws Exception {
setTmForTest(jpaTm());
setupCloudSql();
PCollection<BillingEvent> billingEvents = InvoicingPipeline.readFromCloudSql(options, pipeline);
billingEvents = billingEvents.apply(new ChangeDomainRepo());
PAssert.that(billingEvents).containsInAnyOrder(INPUT_EVENTS);
pipeline.run().waitUntilFinish();
removeTmOverrideForTest();
}
@Test

View file

@ -22,7 +22,6 @@ import static google.registry.beam.rde.RdePipeline.encodePendings;
import static google.registry.model.common.Cursor.CursorType.RDE_STAGING;
import static google.registry.model.rde.RdeMode.FULL;
import static google.registry.model.rde.RdeMode.THIN;
import static google.registry.persistence.transaction.TransactionManagerFactory.jpaTm;
import static google.registry.persistence.transaction.TransactionManagerFactory.tm;
import static google.registry.rde.RdeResourceType.CONTACT;
import static google.registry.rde.RdeResourceType.DOMAIN;
@ -64,7 +63,6 @@ import google.registry.model.tld.Registry;
import google.registry.persistence.VKey;
import google.registry.persistence.transaction.JpaTestExtensions;
import google.registry.persistence.transaction.JpaTestExtensions.JpaIntegrationTestExtension;
import google.registry.persistence.transaction.TransactionManagerFactory;
import google.registry.rde.DepositFragment;
import google.registry.rde.Ghostryde;
import google.registry.rde.PendingDeposit;
@ -74,6 +72,7 @@ import google.registry.testing.CloudTasksHelper.TaskMatcher;
import google.registry.testing.DatastoreEntityExtension;
import google.registry.testing.FakeClock;
import google.registry.testing.FakeKeyringModule;
import google.registry.testing.TmOverrideExtension;
import java.io.IOException;
import java.util.function.Function;
import java.util.regex.Matcher;
@ -88,7 +87,6 @@ import org.bouncycastle.openpgp.PGPPrivateKey;
import org.bouncycastle.openpgp.PGPPublicKey;
import org.joda.time.DateTime;
import org.joda.time.Duration;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Order;
import org.junit.jupiter.api.Test;
@ -156,6 +154,10 @@ public class RdePipelineTest {
final JpaIntegrationTestExtension database =
new JpaTestExtensions.Builder().withClock(clock).buildIntegrationTestExtension();
@RegisterExtension
@Order(Order.DEFAULT + 1)
TmOverrideExtension tmOverrideExtension = TmOverrideExtension.withJpa();
@RegisterExtension
final TestPipelineExtension pipeline =
TestPipelineExtension.fromOptions(options).enableAbandonedNodeEnforcement(true);
@ -164,7 +166,6 @@ public class RdePipelineTest {
@BeforeEach
void beforeEach() throws Exception {
TransactionManagerFactory.setTmForTest(jpaTm());
loadInitialData();
// Two real registrars have been created by loadInitialData(), named "New Registrar" and "The
@ -221,11 +222,6 @@ public class RdePipelineTest {
rdePipeline = new RdePipeline(options, gcsUtils, cloudTasksHelper.getTestCloudTasksUtils());
}
@AfterEach
void afterEach() {
TransactionManagerFactory.removeTmOverrideForTest();
}
@Test
void testSuccess_encodeAndDecodePendingsMap() throws Exception {
String encodedString = encodePendings(pendings);

View file

@ -18,8 +18,6 @@ import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.truth.Truth.assertThat;
import static google.registry.model.ImmutableObjectSubject.immutableObjectCorrespondence;
import static google.registry.persistence.transaction.TransactionManagerFactory.jpaTm;
import static google.registry.persistence.transaction.TransactionManagerFactory.removeTmOverrideForTest;
import static google.registry.persistence.transaction.TransactionManagerFactory.setTmForTest;
import static google.registry.testing.AppEngineExtension.makeRegistrar1;
import static google.registry.testing.DatabaseHelper.createTld;
import static google.registry.testing.DatabaseHelper.persistActiveContact;
@ -52,6 +50,7 @@ import google.registry.persistence.transaction.JpaTestExtensions.JpaIntegrationT
import google.registry.testing.DatastoreEntityExtension;
import google.registry.testing.FakeClock;
import google.registry.testing.FakeSleeper;
import google.registry.testing.TmOverrideExtension;
import google.registry.util.ResourceUtils;
import google.registry.util.Retrier;
import java.io.File;
@ -129,6 +128,10 @@ class Spec11PipelineTest {
final JpaIntegrationTestExtension database =
new JpaTestExtensions.Builder().withClock(new FakeClock()).buildIntegrationTestExtension();
@RegisterExtension
@Order(Order.DEFAULT + 1)
TmOverrideExtension tmOverrideExtension = TmOverrideExtension.withJpa();
private final Spec11PipelineOptions options =
PipelineOptionsFactory.create().as(Spec11PipelineOptions.class);
@ -233,7 +236,6 @@ class Spec11PipelineTest {
}
private void setupCloudSql() {
setTmForTest(jpaTm());
persistNewRegistrar("TheRegistrar");
persistNewRegistrar("NewRegistrar");
Registrar registrar1 =
@ -273,7 +275,6 @@ class Spec11PipelineTest {
persistResource(createDomain("no-email.com", "2A4BA9BBC-COM", registrar2, contact2));
persistResource(
createDomain("anti-anti-anti-virus.dev", "555666888-DEV", registrar3, contact3));
removeTmOverrideForTest();
}
private void verifySaveToGcs() throws Exception {

View file

@ -36,6 +36,7 @@ import google.registry.persistence.VKey;
import google.registry.persistence.transaction.JpaTestExtensions.JpaUnitTestExtension;
import google.registry.testing.DatabaseHelper;
import google.registry.testing.FakeClock;
import google.registry.testing.TmOverrideExtension;
import java.io.Serializable;
import java.math.BigInteger;
import java.sql.SQLException;
@ -48,8 +49,7 @@ import javax.persistence.IdClass;
import javax.persistence.OptimisticLockException;
import javax.persistence.RollbackException;
import org.hibernate.exception.JDBCConnectionException;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Order;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
@ -84,15 +84,9 @@ class JpaTransactionManagerImplTest {
TestEntity.class, TestCompoundIdEntity.class, TestNamedCompoundIdEntity.class)
.buildUnitTestExtension();
@BeforeEach
void beforeEach() {
TransactionManagerFactory.setTmForTest(jpaTm());
}
@AfterEach
void afterEach() {
TransactionManagerFactory.removeTmOverrideForTest();
}
@RegisterExtension
@Order(Order.DEFAULT + 1)
TmOverrideExtension tmOverrideExtension = TmOverrideExtension.withJpa();
@Test
void transact_succeeds() {

View file

@ -16,8 +16,6 @@ package google.registry.schema.registrar;
import static com.google.common.truth.Truth.assertThat;
import static google.registry.model.registrar.RegistrarContact.Type.WHOIS;
import static google.registry.persistence.transaction.TransactionManagerFactory.jpaTm;
import static google.registry.persistence.transaction.TransactionManagerFactory.setTmForTest;
import static google.registry.testing.DatabaseHelper.insertInDb;
import static google.registry.testing.DatabaseHelper.loadByEntity;
import static google.registry.testing.SqlHelper.saveRegistrar;
@ -28,6 +26,7 @@ import google.registry.model.registrar.RegistrarContact;
import google.registry.persistence.transaction.JpaTestExtensions;
import google.registry.persistence.transaction.JpaTestExtensions.JpaIntegrationWithCoverageExtension;
import google.registry.testing.DatastoreEntityExtension;
import google.registry.testing.TmOverrideExtension;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Order;
import org.junit.jupiter.api.Test;
@ -44,13 +43,16 @@ class RegistrarContactTest {
JpaIntegrationWithCoverageExtension jpa =
new JpaTestExtensions.Builder().buildIntegrationWithCoverageExtension();
@RegisterExtension
@Order(Order.DEFAULT + 1)
TmOverrideExtension tmOverrideExtension = TmOverrideExtension.withJpa();
private Registrar testRegistrar;
private RegistrarContact testRegistrarPoc;
@BeforeEach
public void beforeEach() {
setTmForTest(jpaTm());
testRegistrar = saveRegistrar("registrarId");
testRegistrarPoc =
new RegistrarContact.Builder()

View file

@ -15,7 +15,6 @@
package google.registry.schema.registrar;
import static com.google.common.truth.Truth.assertThat;
import static google.registry.persistence.transaction.TransactionManagerFactory.jpaTm;
import static google.registry.testing.DatabaseHelper.existsInDb;
import static google.registry.testing.DatabaseHelper.insertInDb;
import static google.registry.testing.DatabaseHelper.loadByKey;
@ -29,11 +28,10 @@ import google.registry.model.registrar.RegistrarAddress;
import google.registry.persistence.VKey;
import google.registry.persistence.transaction.JpaTestExtensions;
import google.registry.persistence.transaction.JpaTestExtensions.JpaIntegrationWithCoverageExtension;
import google.registry.persistence.transaction.TransactionManagerFactory;
import google.registry.testing.DatastoreEntityExtension;
import google.registry.testing.FakeClock;
import google.registry.testing.TmOverrideExtension;
import org.joda.time.DateTime;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Order;
import org.junit.jupiter.api.Test;
@ -52,13 +50,16 @@ public class RegistrarDaoTest {
JpaIntegrationWithCoverageExtension jpa =
new JpaTestExtensions.Builder().withClock(fakeClock).buildIntegrationWithCoverageExtension();
@RegisterExtension
@Order(Order.DEFAULT + 1)
TmOverrideExtension tmOverrideExtension = TmOverrideExtension.withJpa();
private final VKey<Registrar> registrarKey = VKey.createSql(Registrar.class, "registrarId");
private Registrar testRegistrar;
@BeforeEach
void beforeEach() {
TransactionManagerFactory.setTmForTest(jpaTm());
testRegistrar =
new Registrar.Builder()
.setType(Registrar.Type.TEST)
@ -75,11 +76,6 @@ public class RegistrarDaoTest {
.build();
}
@AfterEach
void afterEach() {
TransactionManagerFactory.removeTmOverrideForTest();
}
@Test
void saveNew_worksSuccessfully() {
assertThat(existsInDb(testRegistrar)).isFalse();

View file

@ -21,10 +21,9 @@ import google.registry.model.registrar.Registrar;
import google.registry.model.registrar.RegistrarContact;
import google.registry.model.registrar.RegistrarContact.RegistrarPocId;
import google.registry.persistence.VKey;
import google.registry.persistence.transaction.TransactionManagerFactory;
import google.registry.testing.AppEngineExtension;
import google.registry.testing.DatastoreEntityExtension;
import org.junit.jupiter.api.AfterEach;
import google.registry.testing.TmOverrideExtension;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Order;
import org.junit.jupiter.api.Test;
@ -41,17 +40,15 @@ public class SqlEntityTest {
final AppEngineExtension database =
new AppEngineExtension.Builder().withCloudSql().withoutCannedData().build();
@RegisterExtension
@Order(Order.DEFAULT + 1)
TmOverrideExtension tmOverrideExtension = TmOverrideExtension.withJpa();
@BeforeEach
void setup() throws Exception {
TransactionManagerFactory.setTmForTest(TransactionManagerFactory.jpaTm());
AppEngineExtension.loadInitialData();
}
@AfterEach
void teardown() {
TransactionManagerFactory.removeTmOverrideForTest();
}
@Test
void getPrimaryKeyString_oneIdColumn() {
// AppEngineExtension canned data: Registrar1

View file

@ -154,7 +154,7 @@ class DualDatabaseTestInvocationContextProvider implements TestTemplateInvocatio
context.getStore(NAMESPACE).put(ORIGINAL_TM_KEY, tm());
DatabaseType databaseType =
(DatabaseType) context.getStore(NAMESPACE).get(INJECTED_TM_SUPPLIER_KEY);
TransactionManagerFactory.setTmForTest(databaseType.getTm());
TransactionManagerFactory.setTmOverrideForTest(databaseType.getTm());
}
}

View file

@ -0,0 +1,73 @@
// Copyright 2021 The Nomulus Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package google.registry.testing;
import static google.registry.persistence.transaction.TransactionManagerFactory.jpaTm;
import static google.registry.persistence.transaction.TransactionManagerFactory.ofyTm;
import google.registry.persistence.transaction.TransactionManager;
import google.registry.persistence.transaction.TransactionManagerFactory;
import org.junit.jupiter.api.extension.AfterEachCallback;
import org.junit.jupiter.api.extension.BeforeEachCallback;
import org.junit.jupiter.api.extension.ExtensionContext;
/**
* JUnit extension for overriding the {@link TransactionManager} in tests.
*
* <p>You will typically want to run this at <code>@Order(Order.DEFAULT + 1)</code> alongside a
* {@link google.registry.persistence.transaction.JpaTransactionManagerExtension} or {@link
* DatastoreEntityExtension} with default {@link org.junit.jupiter.api.Order}. The transaction
* manager extension needs to run first so that when this override is called it's not trying to use
* the default dummy one.
*
* <p>This extension is incompatible with {@link DualDatabaseTest}. Use either that or this, but not
* both.
*/
public final class TmOverrideExtension implements BeforeEachCallback, AfterEachCallback {
private static enum TmOverride {
OFY,
JPA;
}
private final TmOverride tmOverride;
private TmOverrideExtension(TmOverride tmOverride) {
this.tmOverride = tmOverride;
}
/** Use the {@link google.registry.model.ofy.DatastoreTransactionManager} for all tests. */
public static TmOverrideExtension withOfy() {
return new TmOverrideExtension(TmOverride.OFY);
}
/**
* Use the {@link google.registry.persistence.transaction.JpaTransactionManager} for all tests.
*/
public static TmOverrideExtension withJpa() {
return new TmOverrideExtension(TmOverride.JPA);
}
@Override
public void beforeEach(ExtensionContext context) {
TransactionManagerFactory.setTmOverrideForTest(
tmOverride == TmOverride.OFY ? ofyTm() : jpaTm());
}
@Override
public void afterEach(ExtensionContext context) {
TransactionManagerFactory.removeTmOverrideForTest();
}
}