diff --git a/core/src/test/java/google/registry/testing/InjectExtension.java b/core/src/test/java/google/registry/testing/InjectExtension.java index d4693f6dd..9da767fef 100644 --- a/core/src/test/java/google/registry/testing/InjectExtension.java +++ b/core/src/test/java/google/registry/testing/InjectExtension.java @@ -23,6 +23,7 @@ import java.util.List; import java.util.Set; import javax.annotation.Nullable; import org.junit.jupiter.api.extension.AfterEachCallback; +import org.junit.jupiter.api.extension.BeforeEachCallback; import org.junit.jupiter.api.extension.ExtensionContext; /** @@ -85,26 +86,38 @@ import org.junit.jupiter.api.extension.ExtensionContext; * * @see google.registry.util.NonFinalForTesting */ -public class InjectExtension implements AfterEachCallback { +public class InjectExtension implements AfterEachCallback, BeforeEachCallback { private static class Change { private final Field field; - @Nullable private final Object oldValue; + @Nullable private Object oldValue; @Nullable private final Object newValue; + private boolean active; - Change(Field field, @Nullable Object oldValue, @Nullable Object newValue) { + Change(Field field, @Nullable Object oldValue, @Nullable Object newValue, boolean active) { this.field = field; this.oldValue = oldValue; this.newValue = newValue; + this.active = active; } } private final List changes = new ArrayList<>(); private final Set injected = new HashSet<>(); + /** Adds the specified field override to those set by the extension. */ + public InjectExtension withStaticFieldOverride( + Class clazz, String fieldName, @Nullable Object newValue) { + changes.add(new Change(getField(clazz, fieldName), null, newValue, false)); + return this; + } + /** * Sets a static field and be restores its current value after the test completes. * + *

Prefer to use withStaticFieldOverride(), which is more consistent with the extension + * pattern. + * *

The field is allowed to be {@code private}, but it most not be {@code final}. * *

This method may be called either from either your {@link @@ -116,50 +129,40 @@ public class InjectExtension implements AfterEachCallback { * @throws IllegalStateException if the field has already been injected during this test. */ public void setStaticField(Class clazz, String fieldName, @Nullable Object newValue) { - Field field; - Object oldValue; - try { - field = clazz.getDeclaredField(fieldName); - field.setAccessible(true); - oldValue = field.get(null); - } catch (NoSuchFieldException - | SecurityException - | IllegalArgumentException - | IllegalAccessException e) { - throw new IllegalArgumentException( - String.format("Static field not found: %s.%s", clazz.getSimpleName(), fieldName), e); - } - checkState( - !injected.contains(field), - "Static field already injected: %s.%s", - clazz.getSimpleName(), - fieldName); - try { - field.set(null, newValue); - } catch (IllegalArgumentException | IllegalAccessException e) { - throw new IllegalArgumentException( - String.format("Static field not settable: %s.%s", clazz.getSimpleName(), fieldName), e); - } - changes.add(new Change(field, oldValue, newValue)); + Field field = getField(clazz, fieldName); + Change change = new Change(field, null, newValue, true); + activateChange(change); + changes.add(change); injected.add(field); } + @Override + public void beforeEach(ExtensionContext context) { + for (Change change : changes) { + if (!change.active) { + activateChange(change); + } + } + } + @Override public void afterEach(ExtensionContext context) { RuntimeException thrown = null; for (Change change : changes) { - try { - checkState( - change.field.get(null).equals(change.newValue), - "Static field value was changed post-injection: %s.%s", - change.field.getDeclaringClass().getSimpleName(), - change.field.getName()); - change.field.set(null, change.oldValue); - } catch (IllegalArgumentException | IllegalStateException | IllegalAccessException e) { - if (thrown == null) { - thrown = new RuntimeException(e); - } else { - thrown.addSuppressed(e); + if (change.active) { + try { + checkState( + change.field.get(null).equals(change.newValue), + "Static field value was changed post-injection: %s.%s", + change.field.getDeclaringClass().getSimpleName(), + change.field.getName()); + change.field.set(null, change.oldValue); + } catch (IllegalArgumentException | IllegalStateException | IllegalAccessException e) { + if (thrown == null) { + thrown = new RuntimeException(e); + } else { + thrown.addSuppressed(e); + } } } } @@ -169,4 +172,40 @@ public class InjectExtension implements AfterEachCallback { throw thrown; } } + + private Field getField(Class clazz, String fieldName) { + try { + return clazz.getDeclaredField(fieldName); + } catch (SecurityException | NoSuchFieldException e) { + throw new IllegalArgumentException( + String.format("Static field not found: %s.%s", clazz.getSimpleName(), fieldName), e); + } + } + + private void activateChange(Change change) { + Class clazz = change.field.getDeclaringClass(); + try { + change.field.setAccessible(true); + change.oldValue = change.field.get(null); + } catch (IllegalArgumentException | IllegalAccessException e) { + throw new IllegalArgumentException( + String.format( + "Static field not gettable: %s.%s", clazz.getSimpleName(), change.field.getName()), + e); + } + checkState( + !injected.contains(change.field), + "Static field already injected: %s.%s", + clazz.getSimpleName(), + change.field.getName()); + try { + change.field.set(null, change.newValue); + } catch (IllegalArgumentException | IllegalAccessException e) { + throw new IllegalArgumentException( + String.format( + "Static field not settable: %s.%s", clazz.getSimpleName(), change.field.getName()), + e); + } + change.active = true; + } }