diff --git a/core/src/main/java/google/registry/persistence/EntityCallbacksListener.java b/core/src/main/java/google/registry/persistence/EntityCallbacksListener.java new file mode 100644 index 000000000..931b24976 --- /dev/null +++ b/core/src/main/java/google/registry/persistence/EntityCallbacksListener.java @@ -0,0 +1,196 @@ +// Copyright 2020 The Nomulus Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package google.registry.persistence; + +import java.lang.annotation.Annotation; +import java.lang.reflect.Field; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.util.Arrays; +import java.util.Objects; +import java.util.stream.Stream; +import javax.persistence.Embeddable; +import javax.persistence.Embedded; +import javax.persistence.MappedSuperclass; +import javax.persistence.PostLoad; +import javax.persistence.PostPersist; +import javax.persistence.PostRemove; +import javax.persistence.PostUpdate; +import javax.persistence.PrePersist; +import javax.persistence.PreRemove; +import javax.persistence.PreUpdate; + +/** + * A listener class to invoke entity callbacks in cases where Hibernate doesn't invoke the callback + * as expected. + * + *

JPA defines a few annotations, e.g. {@link PostLoad}, that we can use for the application to + * react to certain events that occur inside the persistence mechanism. However, Hibernate only + * supports a few basic use cases, e.g. defining a {@link PostLoad} method directly in an {@link + * javax.persistence.Entity} class or in an {@link Embeddable} class. If the annotated method is + * defined in an {@link Embeddable} class that is a property of another {@link Embeddable} class, or + * it is defined in a parent class of the {@link Embeddable} class, Hibernate doesn't invoke it. + * + *

This listener is added in core/src/main/resources/META-INF/orm.xml as a default entity + * listener whose annotated methods will be invoked by Hibernate when corresponding events happen. + * For example, {@link EntityCallbacksListener#prePersist} will be invoked before the entity is + * persisted to the database, then it will recursively invoke any other {@link PrePersist} method + * that should be invoked but not handled by Hibernate due to the bug. + * + * @see JPA + * Callbacks + * @see HHH-13316 + */ +public class EntityCallbacksListener { + + @PrePersist + void prePersist(Object entity) { + EntityCallbackExecutor.create(PrePersist.class).execute(entity, entity.getClass()); + } + + @PreRemove + void preRemove(Object entity) { + EntityCallbackExecutor.create(PreRemove.class).execute(entity, entity.getClass()); + } + + @PostPersist + void postPersist(Object entity) { + EntityCallbackExecutor.create(PostPersist.class).execute(entity, entity.getClass()); + } + + @PostRemove + void postRemove(Object entity) { + EntityCallbackExecutor.create(PostRemove.class).execute(entity, entity.getClass()); + } + + @PreUpdate + void preUpdate(Object entity) { + EntityCallbackExecutor.create(PreUpdate.class).execute(entity, entity.getClass()); + } + + @PostUpdate + void postUpdate(Object entity) { + EntityCallbackExecutor.create(PostUpdate.class).execute(entity, entity.getClass()); + } + + @PostLoad + void postLoad(Object entity) { + EntityCallbackExecutor.create(PostLoad.class).execute(entity, entity.getClass()); + } + + private static class EntityCallbackExecutor { + Class callbackType; + + private EntityCallbackExecutor(Class callbackType) { + this.callbackType = callbackType; + } + + private static EntityCallbackExecutor create(Class callbackType) { + return new EntityCallbackExecutor(callbackType); + } + + /** + * Executes eligible callbacks in {@link Embedded} properties recursively. + * + * @param entity the Java object of the entity class + * @param entityType either the type of the entity or an ancestor type + */ + private void execute(Object entity, Class entityType) { + Class parentType = entityType.getSuperclass(); + if (parentType != null && parentType.isAnnotationPresent(MappedSuperclass.class)) { + execute(entity, parentType); + } + + findEmbeddedProperties(entity, entityType) + .forEach( + normalEmbedded -> { + // For each normal embedded property, we don't execute its callback method because + // it is handled by Hibernate. However, for the embedded property defined in the + // entity's parent class, we need to treat it as a nested embedded property and + // invoke its callback function. + if (entity.getClass().equals(entityType)) { + executeCallbackForNormalEmbeddedProperty( + normalEmbedded, normalEmbedded.getClass()); + } else { + executeCallbackForNestedEmbeddedProperty( + normalEmbedded, normalEmbedded.getClass()); + } + }); + } + + private void executeCallbackForNestedEmbeddedProperty( + Object nestedEmbeddedObject, Class nestedEmbeddedType) { + Class parentType = nestedEmbeddedType.getSuperclass(); + if (parentType != null && parentType.isAnnotationPresent(MappedSuperclass.class)) { + executeCallbackForNestedEmbeddedProperty(nestedEmbeddedObject, parentType); + } + + findEmbeddedProperties(nestedEmbeddedObject, nestedEmbeddedType) + .forEach( + embeddedProperty -> + executeCallbackForNestedEmbeddedProperty( + embeddedProperty, embeddedProperty.getClass())); + + for (Method method : nestedEmbeddedType.getDeclaredMethods()) { + if (method.isAnnotationPresent(callbackType)) { + invokeMethod(method, nestedEmbeddedObject); + } + } + } + + private void executeCallbackForNormalEmbeddedProperty( + Object normalEmbeddedObject, Class normalEmbeddedType) { + Class parentType = normalEmbeddedType.getSuperclass(); + if (parentType != null && parentType.isAnnotationPresent(MappedSuperclass.class)) { + executeCallbackForNormalEmbeddedProperty(normalEmbeddedObject, parentType); + } + + findEmbeddedProperties(normalEmbeddedObject, normalEmbeddedType) + .forEach( + embeddedProperty -> + executeCallbackForNestedEmbeddedProperty( + embeddedProperty, embeddedProperty.getClass())); + } + + private Stream findEmbeddedProperties(Object object, Class clazz) { + return Arrays.stream(clazz.getDeclaredFields()) + .filter( + field -> + field.isAnnotationPresent(Embedded.class) + || field.getType().isAnnotationPresent(Embeddable.class)) + .map(field -> getFieldObject(field, object)) + .filter(Objects::nonNull); + } + + private static Object getFieldObject(Field field, Object object) { + field.setAccessible(true); + try { + return field.get(object); + } catch (IllegalAccessException e) { + throw new RuntimeException(e); + } + } + + private static void invokeMethod(Method method, Object object) { + method.setAccessible(true); + try { + method.invoke(object); + } catch (IllegalAccessException | InvocationTargetException e) { + throw new RuntimeException(e); + } + } + } +} diff --git a/core/src/main/resources/META-INF/orm.xml b/core/src/main/resources/META-INF/orm.xml index 3bd380f39..9eb506ffa 100644 --- a/core/src/main/resources/META-INF/orm.xml +++ b/core/src/main/resources/META-INF/orm.xml @@ -10,4 +10,11 @@ + + + + + + + diff --git a/core/src/test/java/google/registry/persistence/EntityCallbacksListenerTest.java b/core/src/test/java/google/registry/persistence/EntityCallbacksListenerTest.java new file mode 100644 index 000000000..6f6989dea --- /dev/null +++ b/core/src/test/java/google/registry/persistence/EntityCallbacksListenerTest.java @@ -0,0 +1,302 @@ +// Copyright 2020 The Nomulus Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package google.registry.persistence; + +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static com.google.common.truth.Truth.assertThat; +import static com.google.common.truth.Truth.assertWithMessage; +import static google.registry.persistence.transaction.TransactionManagerFactory.jpaTm; + +import com.google.common.collect.ImmutableSet; +import google.registry.persistence.transaction.JpaTestRules; +import google.registry.persistence.transaction.JpaTestRules.JpaUnitTestRule; +import java.lang.reflect.Method; +import javax.persistence.Embeddable; +import javax.persistence.Embedded; +import javax.persistence.Entity; +import javax.persistence.Id; +import javax.persistence.MappedSuperclass; +import javax.persistence.PostLoad; +import javax.persistence.PostPersist; +import javax.persistence.PostRemove; +import javax.persistence.PostUpdate; +import javax.persistence.PrePersist; +import javax.persistence.PreRemove; +import javax.persistence.PreUpdate; +import javax.persistence.Transient; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link EntityCallbacksListener}. */ +@RunWith(JUnit4.class) +public class EntityCallbacksListenerTest { + + @Rule + public final JpaUnitTestRule jpaRule = + new JpaTestRules.Builder().withEntityClass(TestEntity.class).buildUnitTestRule(); + + @Test + public void verifyAllCallbacks_executedExpectedTimes() { + TestEntity testPersist = new TestEntity(); + jpaTm().transact(() -> jpaTm().saveNew(testPersist)); + checkAll(testPersist, 1, 0, 0, 0); + + TestEntity testUpdate = new TestEntity(); + TestEntity updated = + jpaTm() + .transact( + () -> { + TestEntity merged = jpaTm().getEntityManager().merge(testUpdate); + merged.foo++; + jpaTm().getEntityManager().flush(); + return merged; + }); + // Note that when we get the merged entity, its @PostLoad callbacks are also invoked + checkAll(updated, 0, 1, 0, 1); + + TestEntity testLoad = + jpaTm().transact(() -> jpaTm().load(VKey.createSql(TestEntity.class, "id"))).get(); + checkAll(testLoad, 0, 0, 0, 1); + + TestEntity testRemove = + jpaTm() + .transact( + () -> { + TestEntity removed = jpaTm().load(VKey.createSql(TestEntity.class, "id")).get(); + jpaTm().getEntityManager().remove(removed); + return removed; + }); + checkAll(testRemove, 0, 0, 1, 1); + } + + @Test + public void verifyAllManagedEntities_haveNoMethodWithEmbedded() { + ImmutableSet violations = + PersistenceXmlUtility.getManagedClasses().stream() + .filter(clazz -> clazz.isAnnotationPresent(Entity.class)) + .filter(EntityCallbacksListenerTest::hasMethodAnnotatedWithEmbedded) + .collect(toImmutableSet()); + assertWithMessage( + "Found entity classes having methods annotated with @Embedded. EntityCallbacksListener" + + " only supports annotating fields with @Embedded.") + .that(violations) + .isEmpty(); + } + + @Test + public void verifyHasMethodAnnotatedWithEmbedded_work() { + assertThat(hasMethodAnnotatedWithEmbedded(ViolationEntity.class)).isTrue(); + } + + private static boolean hasMethodAnnotatedWithEmbedded(Class entityType) { + boolean result = false; + Class parentType = entityType.getSuperclass(); + if (parentType != null && parentType.isAnnotationPresent(MappedSuperclass.class)) { + result = hasMethodAnnotatedWithEmbedded(parentType); + } + for (Method method : entityType.getDeclaredMethods()) { + if (method.isAnnotationPresent(Embedded.class)) { + result = true; + break; + } + } + return result; + } + + private static void checkAll( + TestEntity testEntity, + int expectedPersist, + int expectedUpdate, + int expectedRemove, + int expectedLoad) { + assertThat(testEntity.entityEmbedded.entityEmbeddedNested.entityEmbeddedNestedPostPersist) + .isEqualTo(expectedPersist); + assertThat(testEntity.entityEmbedded.entityEmbeddedNested.entityEmbeddedNestedPrePersist) + .isEqualTo(expectedPersist); + + assertThat(testEntity.entityEmbedded.entityEmbeddedNested.entityEmbeddedNestedPreUpdate) + .isEqualTo(expectedUpdate); + assertThat(testEntity.entityEmbedded.entityEmbeddedNested.entityEmbeddedNestedPostUpdate) + .isEqualTo(expectedUpdate); + + assertThat(testEntity.entityEmbedded.entityEmbeddedNested.entityEmbeddedNestedPreRemove) + .isEqualTo(expectedRemove); + assertThat(testEntity.entityEmbedded.entityEmbeddedNested.entityEmbeddedNestedPostRemove) + .isEqualTo(expectedRemove); + + assertThat(testEntity.entityPostLoad).isEqualTo(expectedLoad); + assertThat(testEntity.entityEmbedded.entityEmbeddedPostLoad).isEqualTo(expectedLoad); + assertThat(testEntity.entityEmbedded.entityEmbeddedNested.entityEmbeddedNestedPostLoad) + .isEqualTo(expectedLoad); + assertThat(testEntity.entityEmbedded.entityEmbeddedParentPostLoad).isEqualTo(expectedLoad); + + assertThat(testEntity.parentPostLoad).isEqualTo(expectedLoad); + assertThat(testEntity.parentEmbedded.parentEmbeddedPostLoad).isEqualTo(expectedLoad); + assertThat(testEntity.parentEmbedded.parentEmbeddedNested.parentEmbeddedNestedPostLoad) + .isEqualTo(expectedLoad); + assertThat(testEntity.parentEmbedded.parentEmbeddedParentPostLoad).isEqualTo(expectedLoad); + } + + @Entity(name = "TestEntity") + private static class TestEntity extends ParentEntity { + @Id String name = "id"; + int foo = 0; + + @Transient int entityPostLoad = 0; + + @Embedded EntityEmbedded entityEmbedded = new EntityEmbedded(); + + @PostLoad + void entityPostLoad() { + entityPostLoad++; + } + } + + @Embeddable + private static class EntityEmbedded extends EntityEmbeddedParent { + @Embedded EntityEmbeddedNested entityEmbeddedNested = new EntityEmbeddedNested(); + + @Transient int entityEmbeddedPostLoad = 0; + + String entityEmbedded = "placeholder"; + + @PostLoad + void entityEmbeddedPrePersist() { + entityEmbeddedPostLoad++; + } + } + + @MappedSuperclass + private static class EntityEmbeddedParent { + @Transient int entityEmbeddedParentPostLoad = 0; + + String entityEmbeddedParent = "placeholder"; + + @PostLoad + void entityEmbeddedParentPostLoad() { + entityEmbeddedParentPostLoad++; + } + } + + @Embeddable + private static class EntityEmbeddedNested { + @Transient int entityEmbeddedNestedPrePersist = 0; + @Transient int entityEmbeddedNestedPreRemove = 0; + @Transient int entityEmbeddedNestedPostPersist = 0; + @Transient int entityEmbeddedNestedPostRemove = 0; + @Transient int entityEmbeddedNestedPreUpdate = 0; + @Transient int entityEmbeddedNestedPostUpdate = 0; + @Transient int entityEmbeddedNestedPostLoad = 0; + + String entityEmbeddedNested = "placeholder"; + + @PrePersist + void entityEmbeddedNestedPrePersist() { + entityEmbeddedNestedPrePersist++; + } + + @PreRemove + void entityEmbeddedNestedPreRemove() { + entityEmbeddedNestedPreRemove++; + } + + @PostPersist + void entityEmbeddedNestedPostPersist() { + entityEmbeddedNestedPostPersist++; + } + + @PostRemove + void entityEmbeddedNestedPostRemove() { + entityEmbeddedNestedPostRemove++; + } + + @PreUpdate + void entityEmbeddedNestedPreUpdate() { + entityEmbeddedNestedPreUpdate++; + } + + @PostUpdate + void entityEmbeddedNestedPostUpdate() { + entityEmbeddedNestedPostUpdate++; + } + + @PostLoad + void entityEmbeddedNestedPostLoad() { + entityEmbeddedNestedPostLoad++; + } + } + + @MappedSuperclass + private static class ParentEntity { + @Embedded ParentEmbedded parentEmbedded = new ParentEmbedded(); + @Transient int parentPostLoad = 0; + + String parentEntity = "placeholder"; + + @PostLoad + void parentPostLoad() { + parentPostLoad++; + } + } + + @Embeddable + private static class ParentEmbedded extends ParentEmbeddedParent { + @Transient int parentEmbeddedPostLoad = 0; + + String parentEmbedded = "placeholder"; + + @Embedded ParentEmbeddedNested parentEmbeddedNested = new ParentEmbeddedNested(); + + @PostLoad + void parentEmbeddedPostLoad() { + parentEmbeddedPostLoad++; + } + } + + @Embeddable + private static class ParentEmbeddedNested { + @Transient int parentEmbeddedNestedPostLoad = 0; + + String parentEmbeddedNested = "placeholder"; + + @PostLoad + void parentEmbeddedNestedPostLoad() { + parentEmbeddedNestedPostLoad++; + } + } + + @MappedSuperclass + private static class ParentEmbeddedParent { + @Transient int parentEmbeddedParentPostLoad = 0; + + String parentEmbeddedParent = "placeholder"; + + @PostLoad + void parentEmbeddedParentPostLoad() { + parentEmbeddedParentPostLoad++; + } + } + + @Entity + private static class ViolationEntity { + + @Embedded + EntityEmbedded getEntityEmbedded() { + return new EntityEmbedded(); + } + } +}