diff --git a/core/src/test/java/google/registry/beam/TestPipelineExtension.java b/core/src/test/java/google/registry/beam/TestPipelineExtension.java
new file mode 100644
index 000000000..bef8158e6
--- /dev/null
+++ b/core/src/test/java/google/registry/beam/TestPipelineExtension.java
@@ -0,0 +1,569 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package google.registry.beam;
+
+import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.is;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+import java.io.IOException;
+import java.lang.annotation.Annotation;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.UUID;
+import javax.annotation.Nullable;
+import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.PipelineResult;
+import org.apache.beam.sdk.annotations.Internal;
+import org.apache.beam.sdk.io.FileSystems;
+import org.apache.beam.sdk.metrics.MetricNameFilter;
+import org.apache.beam.sdk.metrics.MetricResult;
+import org.apache.beam.sdk.metrics.MetricsEnvironment;
+import org.apache.beam.sdk.metrics.MetricsFilter;
+import org.apache.beam.sdk.options.ApplicationNameOptions;
+import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.options.PipelineOptions.CheckEnabled;
+import org.apache.beam.sdk.options.PipelineOptionsFactory;
+import org.apache.beam.sdk.options.ValueProvider;
+import org.apache.beam.sdk.options.ValueProvider.StaticValueProvider;
+import org.apache.beam.sdk.runners.TransformHierarchy;
+import org.apache.beam.sdk.testing.CrashingRunner;
+import org.apache.beam.sdk.testing.NeedsRunner;
+import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.TestPipelineOptions;
+import org.apache.beam.sdk.testing.ValidatesRunner;
+import org.apache.beam.sdk.transforms.SerializableFunction;
+import org.apache.beam.sdk.util.common.ReflectHelpers;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Optional;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Predicate;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Predicates;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Strings;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.FluentIterable;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps;
+import org.junit.experimental.categories.Category;
+import org.junit.rules.TestRule;
+import org.junit.runner.Description;
+import org.junit.runners.model.Statement;
+
+// NOTE: This file is copied from the Apache Beam distribution so that it can be locally modified to
+// support JUnit 5.
+
+/**
+ * A creator of test pipelines that can be used inside of tests that can be configured to run
+ * locally or against a remote pipeline runner.
+ *
+ *
It is recommended to tag hand-selected tests for this purpose using the {@link
+ * ValidatesRunner} {@link Category} annotation, as each test run against a pipeline runner will
+ * utilize resources of that pipeline runner.
+ *
+ *
In order to run tests on a pipeline runner, the following conditions must be met:
+ *
+ *
+ *
+ * Use {@link PAssert} for tests, as it integrates with this test harness in both direct and
+ * remote execution modes. For example:
+ *
+ *
+ * {@literal @Rule}
+ * public final transient TestPipeline p = TestPipeline.create();
+ *
+ * {@literal @Test}
+ * {@literal @Category}(NeedsRunner.class)
+ * public void myPipelineTest() throws Exception {
+ * final PCollection<String> pCollection = pipeline.apply(...)
+ * PAssert.that(pCollection).containsInAnyOrder(...);
+ * pipeline.run();
+ * }
+ *
+ *
+ * For pipeline runners, it is required that they must throw an {@link AssertionError} containing
+ * the message from the {@link PAssert} that failed.
+ *
+ *
See also the Testing documentation
+ * section.
+ */
+public class TestPipelineExtension extends Pipeline implements TestRule {
+
+ private final PipelineOptions options;
+
+ private static class PipelineRunEnforcement {
+
+ @SuppressWarnings("WeakerAccess")
+ protected boolean enableAutoRunIfMissing;
+
+ protected final Pipeline pipeline;
+
+ protected boolean runAttempted;
+
+ private PipelineRunEnforcement(final Pipeline pipeline) {
+ this.pipeline = pipeline;
+ }
+
+ protected void enableAutoRunIfMissing(final boolean enable) {
+ enableAutoRunIfMissing = enable;
+ }
+
+ protected void beforePipelineExecution() {
+ runAttempted = true;
+ }
+
+ protected void afterPipelineExecution() {}
+
+ protected void afterUserCodeFinished() {
+ if (!runAttempted && enableAutoRunIfMissing) {
+ pipeline.run().waitUntilFinish();
+ }
+ }
+ }
+
+ private static class PipelineAbandonedNodeEnforcement extends PipelineRunEnforcement {
+
+ // Null until the pipeline has been run
+ @Nullable private List runVisitedNodes;
+
+ private final Predicate isPAssertNode =
+ node ->
+ node.getTransform() instanceof PAssert.GroupThenAssert
+ || node.getTransform() instanceof PAssert.GroupThenAssertForSingleton
+ || node.getTransform() instanceof PAssert.OneSideInputAssert;
+
+ private static class NodeRecorder extends PipelineVisitor.Defaults {
+
+ private final List visited = new ArrayList<>();
+
+ @Override
+ public void leaveCompositeTransform(final TransformHierarchy.Node node) {
+ visited.add(node);
+ }
+
+ @Override
+ public void visitPrimitiveTransform(final TransformHierarchy.Node node) {
+ visited.add(node);
+ }
+ }
+
+ private PipelineAbandonedNodeEnforcement(final TestPipelineExtension pipeline) {
+ super(pipeline);
+ runVisitedNodes = null;
+ }
+
+ private List recordPipelineNodes(final Pipeline pipeline) {
+ final NodeRecorder nodeRecorder = new NodeRecorder();
+ pipeline.traverseTopologically(nodeRecorder);
+ return nodeRecorder.visited;
+ }
+
+ private boolean isEmptyPipeline(final Pipeline pipeline) {
+ final IsEmptyVisitor isEmptyVisitor = new IsEmptyVisitor();
+ pipeline.traverseTopologically(isEmptyVisitor);
+ return isEmptyVisitor.isEmpty();
+ }
+
+ private void verifyPipelineExecution() {
+ if (!isEmptyPipeline(pipeline)) {
+ if (!runAttempted && !enableAutoRunIfMissing) {
+ throw new PipelineRunMissingException("The pipeline has not been run.");
+
+ } else {
+ final List pipelineNodes = recordPipelineNodes(pipeline);
+ if (pipelineRunSucceeded() && !visitedAll(pipelineNodes)) {
+ final boolean hasDanglingPAssert =
+ FluentIterable.from(pipelineNodes)
+ .filter(Predicates.not(Predicates.in(runVisitedNodes)))
+ .anyMatch(isPAssertNode);
+ if (hasDanglingPAssert) {
+ throw new AbandonedNodeException("The pipeline contains abandoned PAssert(s).");
+ } else {
+ throw new AbandonedNodeException("The pipeline contains abandoned PTransform(s).");
+ }
+ }
+ }
+ }
+ }
+
+ private boolean visitedAll(final List pipelineNodes) {
+ return runVisitedNodes.equals(pipelineNodes);
+ }
+
+ private boolean pipelineRunSucceeded() {
+ return runVisitedNodes != null;
+ }
+
+ @Override
+ protected void afterPipelineExecution() {
+ runVisitedNodes = recordPipelineNodes(pipeline);
+ super.afterPipelineExecution();
+ }
+
+ @Override
+ protected void afterUserCodeFinished() {
+ super.afterUserCodeFinished();
+ verifyPipelineExecution();
+ }
+ }
+
+ /**
+ * An exception thrown in case an abandoned {@link org.apache.beam.sdk.transforms.PTransform} is
+ * detected, that is, a {@link org.apache.beam.sdk.transforms.PTransform} that has not been run.
+ */
+ public static class AbandonedNodeException extends RuntimeException {
+
+ AbandonedNodeException(final String msg) {
+ super(msg);
+ }
+ }
+
+ /** An exception thrown in case a test finishes without invoking {@link Pipeline#run()}. */
+ public static class PipelineRunMissingException extends RuntimeException {
+
+ PipelineRunMissingException(final String msg) {
+ super(msg);
+ }
+ }
+
+ /** System property used to set {@link TestPipelineOptions}. */
+ public static final String PROPERTY_BEAM_TEST_PIPELINE_OPTIONS = "beamTestPipelineOptions";
+
+ static final String PROPERTY_USE_DEFAULT_DUMMY_RUNNER = "beamUseDummyRunner";
+
+ private static final ObjectMapper MAPPER =
+ new ObjectMapper()
+ .registerModules(ObjectMapper.findModules(ReflectHelpers.findClassLoader()));
+
+ @SuppressWarnings("OptionalUsedAsFieldOrParameterType")
+ private Optional extends PipelineRunEnforcement> enforcement = Optional.absent();
+
+ /**
+ * Creates and returns a new test pipeline.
+ *
+ * Use {@link PAssert} to add tests, then call {@link Pipeline#run} to execute the pipeline and
+ * check the tests.
+ */
+ public static TestPipelineExtension create() {
+ return fromOptions(testingPipelineOptions());
+ }
+
+ public static TestPipelineExtension fromOptions(PipelineOptions options) {
+ return new TestPipelineExtension(options);
+ }
+
+ private TestPipelineExtension(final PipelineOptions options) {
+ super(options);
+ this.options = options;
+ }
+
+ @Override
+ public PipelineOptions getOptions() {
+ return this.options;
+ }
+
+ @Override
+ public Statement apply(final Statement statement, final Description description) {
+ return new Statement() {
+
+ private void setDeducedEnforcementLevel() {
+ // if the enforcement level has not been set by the user do auto-inference
+ if (!enforcement.isPresent()) {
+
+ final boolean annotatedWithNeedsRunner =
+ FluentIterable.from(description.getAnnotations())
+ .filter(Annotations.Predicates.isAnnotationOfType(Category.class))
+ .anyMatch(Annotations.Predicates.isCategoryOf(NeedsRunner.class, true));
+
+ final boolean crashingRunner = CrashingRunner.class.isAssignableFrom(options.getRunner());
+
+ checkState(
+ !(annotatedWithNeedsRunner && crashingRunner),
+ "The test was annotated with a [@%s] / [@%s] while the runner "
+ + "was set to [%s]. Please re-check your configuration.",
+ NeedsRunner.class.getSimpleName(),
+ ValidatesRunner.class.getSimpleName(),
+ CrashingRunner.class.getSimpleName());
+
+ enableAbandonedNodeEnforcement(annotatedWithNeedsRunner || !crashingRunner);
+ }
+ }
+
+ @Override
+ public void evaluate() throws Throwable {
+ options.as(ApplicationNameOptions.class).setAppName(getAppName(description));
+
+ setDeducedEnforcementLevel();
+
+ // statement.evaluate() essentially runs the user code contained in the unit test at hand.
+ // Exceptions thrown during the execution of the user's test code will propagate here,
+ // unless the user explicitly handles them with a "catch" clause in his code. If the
+ // exception is handled by a user's "catch" clause, is does not interrupt the flow and
+ // we move on to invoking the configured enforcements.
+ // If the user does not handle a thrown exception, it will propagate here and interrupt
+ // the flow, preventing the enforcement(s) from being activated.
+ // The motivation for this is avoiding enforcements over faulty pipelines.
+ statement.evaluate();
+ enforcement.get().afterUserCodeFinished();
+ }
+ };
+ }
+
+ /**
+ * Runs this {@link TestPipelineExtension}, unwrapping any {@code AssertionError} that is raised during
+ * testing.
+ */
+ @Override
+ public PipelineResult run() {
+ return run(getOptions());
+ }
+
+ /** Like {@link #run} but with the given potentially modified options. */
+ @Override
+ public PipelineResult run(PipelineOptions options) {
+ checkState(
+ enforcement.isPresent(),
+ "Is your TestPipeline declaration missing a @Rule annotation? Usage: "
+ + "@Rule public final transient TestPipeline pipeline = TestPipeline.create();");
+
+ final PipelineResult pipelineResult;
+ try {
+ enforcement.get().beforePipelineExecution();
+ PipelineOptions updatedOptions =
+ MAPPER.convertValue(MAPPER.valueToTree(options), PipelineOptions.class);
+ updatedOptions
+ .as(TestValueProviderOptions.class)
+ .setProviderRuntimeValues(StaticValueProvider.of(providerRuntimeValues));
+ pipelineResult = super.run(updatedOptions);
+ verifyPAssertsSucceeded(this, pipelineResult);
+ } catch (RuntimeException exc) {
+ Throwable cause = exc.getCause();
+ if (cause instanceof AssertionError) {
+ throw (AssertionError) cause;
+ } else {
+ throw exc;
+ }
+ }
+
+ // If we reach this point, the pipeline has been run and no exceptions have been thrown during
+ // its execution.
+ enforcement.get().afterPipelineExecution();
+ return pipelineResult;
+ }
+
+ /** Implementation detail of {@link #newProvider}, do not use. */
+ @Internal
+ public interface TestValueProviderOptions extends PipelineOptions {
+ ValueProvider