diff --git a/core/src/main/java/google/registry/beam/rde/RdePipeline.java b/core/src/main/java/google/registry/beam/rde/RdePipeline.java index 549e30c13..c30088ea4 100644 --- a/core/src/main/java/google/registry/beam/rde/RdePipeline.java +++ b/core/src/main/java/google/registry/beam/rde/RdePipeline.java @@ -27,6 +27,10 @@ import static google.registry.beam.rde.RdePipeline.TupleTags.REVISION_ID; import static google.registry.beam.rde.RdePipeline.TupleTags.SUPERORDINATE_DOMAINS; import static google.registry.model.reporting.HistoryEntryDao.RESOURCE_TYPES_TO_HISTORY_TYPES; import static google.registry.persistence.transaction.TransactionManagerFactory.tm; +import static google.registry.util.SafeSerializationUtils.safeDeserializeCollection; +import static google.registry.util.SafeSerializationUtils.serializeCollection; +import static google.registry.util.SerializeUtils.decodeBase64; +import static google.registry.util.SerializeUtils.encodeBase64; import static org.apache.beam.sdk.values.TypeDescriptors.kvs; import com.google.common.collect.ImmutableList; @@ -65,11 +69,7 @@ import google.registry.rde.PendingDeposit.PendingDepositCoder; import google.registry.rde.RdeMarshaller; import google.registry.util.UtilsModule; import google.registry.xml.ValidationMode; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; import java.io.IOException; -import java.io.ObjectInputStream; -import java.io.ObjectOutputStream; import java.io.Serializable; import java.util.HashSet; import javax.inject.Inject; @@ -658,14 +658,8 @@ public class RdePipeline implements Serializable { */ @SuppressWarnings("unchecked") static ImmutableSet decodePendingDeposits(String encodedPendingDeposits) { - try (ObjectInputStream ois = - new ObjectInputStream( - new ByteArrayInputStream( - BaseEncoding.base64Url().omitPadding().decode(encodedPendingDeposits)))) { - return (ImmutableSet) ois.readObject(); - } catch (IOException | ClassNotFoundException e) { - throw new IllegalArgumentException("Unable to parse encoded pending deposit map.", e); - } + return ImmutableSet.copyOf( + safeDeserializeCollection(PendingDeposit.class, decodeBase64(encodedPendingDeposits))); } /** @@ -674,12 +668,7 @@ public class RdePipeline implements Serializable { */ public static String encodePendingDeposits(ImmutableSet pendingDeposits) throws IOException { - try (ByteArrayOutputStream baos = new ByteArrayOutputStream()) { - ObjectOutputStream oos = new ObjectOutputStream(baos); - oos.writeObject(pendingDeposits); - oos.flush(); - return BaseEncoding.base64Url().omitPadding().encode(baos.toByteArray()); - } + return encodeBase64(serializeCollection(pendingDeposits)); } public static void main(String[] args) throws IOException, ClassNotFoundException { diff --git a/core/src/main/java/google/registry/persistence/VKey.java b/core/src/main/java/google/registry/persistence/VKey.java index c2222e4a3..b2c9fb83b 100644 --- a/core/src/main/java/google/registry/persistence/VKey.java +++ b/core/src/main/java/google/registry/persistence/VKey.java @@ -16,6 +16,8 @@ package google.registry.persistence; import static com.google.common.collect.ImmutableMap.toImmutableMap; import static google.registry.util.PreconditionsUtils.checkArgumentNotNull; +import static google.registry.util.SafeSerializationUtils.safeDeserialize; +import static google.registry.util.SerializeUtils.decodeBase64; import static java.util.function.Function.identity; import com.google.common.base.Joiner; @@ -97,7 +99,7 @@ public class VKey extends ImmutableObject implements Serializable { throw new IllegalArgumentException( String.format("\"%s\" missing from the string: %s", LOOKUP_KEY, keyString)); } - return VKey.create(classType, SerializeUtils.parse(Serializable.class, kvs.get(LOOKUP_KEY))); + return VKey.create(classType, safeDeserialize(decodeBase64(kvs.get(LOOKUP_KEY)))); } /** Returns the type of the entity. */ diff --git a/core/src/main/java/google/registry/rde/PendingDeposit.java b/core/src/main/java/google/registry/rde/PendingDeposit.java index 1ba7e2a61..dbd97962d 100644 --- a/core/src/main/java/google/registry/rde/PendingDeposit.java +++ b/core/src/main/java/google/registry/rde/PendingDeposit.java @@ -14,18 +14,23 @@ package google.registry.rde; +import static com.google.common.base.Preconditions.checkState; + import com.google.auto.value.AutoValue; import google.registry.model.common.Cursor.CursorType; import google.registry.model.rde.RdeMode; import java.io.IOException; import java.io.InputStream; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.io.ObjectStreamException; import java.io.OutputStream; import java.io.Serializable; +import java.util.Optional; import javax.annotation.Nullable; import org.apache.beam.sdk.coders.AtomicCoder; import org.apache.beam.sdk.coders.BooleanCoder; import org.apache.beam.sdk.coders.NullableCoder; -import org.apache.beam.sdk.coders.SerializableCoder; import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.coders.VarIntCoder; import org.joda.time.DateTime; @@ -35,6 +40,12 @@ import org.joda.time.Duration; * Container representing a single RDE or BRDA XML escrow deposit that needs to be created. * *

There are some {@code @Nullable} fields here because Optionals aren't Serializable. + * + *

Note that this class is serialized in two ways: by Beam pipelines using custom serialization + * mechanism and the {@code Coder} API, and by Java serialization when passed as command-line + * arguments (see {@code RdePipeline#decodePendingDeposits}). The latter requires safe + * deserialization because the data crosses credential boundaries (See {@code + * SafeObjectInputStream}). */ @AutoValue public abstract class PendingDeposit implements Serializable { @@ -95,11 +106,61 @@ public abstract class PendingDeposit implements Serializable { PendingDeposit() {} + /** + * Specifies that {@link SerializedForm} be used for {@code SafeObjectInputStream}-compatible + * custom-serialization of {@link AutoValue_PendingDeposit the AutoValue implementation class}. + * + *

This method is package-protected so that the AutoValue implementation class inherits this + * behavior. + * + *

This method leverages {@link PendingDepositCoder} to serializes an instance. However, it is + * not invoked in Beam pipelines. + */ + Object writeReplace() throws ObjectStreamException { + return new SerializedForm(this); + } + + /** + * Proxy for custom-serialization of {@link PendingDeposit}. This is necessary because the actual + * class to be (de)serialized is the generated AutoValue implementation. See also {@link + * #writeReplace}. + * + *

This class leverages {@link PendingDepositCoder} to safely deserializes an instance. + * However, it is not used in Beam pipelines. + */ + private static class SerializedForm implements Serializable { + + private static final long serialVersionUID = 3141095605225904433L; + + private PendingDeposit value; + + private SerializedForm(PendingDeposit value) { + this.value = value; + } + + private void writeObject(ObjectOutputStream os) throws IOException { + checkState(value != null, "Non-null value expected for serialization."); + PendingDepositCoder.INSTANCE.encode(value, os); + } + + private void readObject(ObjectInputStream is) throws IOException, ClassNotFoundException { + checkState(value == null, "Non-null value unexpected for deserialization."); + this.value = PendingDepositCoder.INSTANCE.decode(is); + } + + @SuppressWarnings("unused") + private Object readResolve() throws ObjectStreamException { + return this.value; + } + } + /** * A deterministic coder for {@link PendingDeposit} used during a GroupBy transform. * - *

We cannot use a {@link SerializableCoder} directly because it does not guarantee - * determinism, which is required by GroupBy. + *

We cannot use a {@code SerializableCoder} directly for two reasons: the default + * serialization does not guarantee determinism, which is required by GroupBy in Beam; and the + * default deserialization is not robust against deserialization-based attacks (See {@code + * SafeObjectInputStream} for more information). */ public static class PendingDepositCoder extends AtomicCoder { @@ -117,10 +178,15 @@ public abstract class PendingDeposit implements Serializable { public void encode(PendingDeposit value, OutputStream outStream) throws IOException { BooleanCoder.of().encode(value.manual(), outStream); StringUtf8Coder.of().encode(value.tld(), outStream); - SerializableCoder.of(DateTime.class).encode(value.watermark(), outStream); - SerializableCoder.of(RdeMode.class).encode(value.mode(), outStream); - NullableCoder.of(SerializableCoder.of(CursorType.class)).encode(value.cursor(), outStream); - NullableCoder.of(SerializableCoder.of(Duration.class)).encode(value.interval(), outStream); + StringUtf8Coder.of().encode(value.watermark().toString(), outStream); + StringUtf8Coder.of().encode(value.mode().name(), outStream); + NullableCoder.of(StringUtf8Coder.of()) + .encode( + Optional.ofNullable(value.cursor()).map(CursorType::name).orElse(null), outStream); + NullableCoder.of(StringUtf8Coder.of()) + .encode( + Optional.ofNullable(value.interval()).map(Duration::toString).orElse(null), + outStream); NullableCoder.of(StringUtf8Coder.of()).encode(value.directoryWithTrailingSlash(), outStream); NullableCoder.of(VarIntCoder.of()).encode(value.revision(), outStream); } @@ -130,10 +196,14 @@ public abstract class PendingDeposit implements Serializable { return new AutoValue_PendingDeposit( BooleanCoder.of().decode(inStream), StringUtf8Coder.of().decode(inStream), - SerializableCoder.of(DateTime.class).decode(inStream), - SerializableCoder.of(RdeMode.class).decode(inStream), - NullableCoder.of(SerializableCoder.of(CursorType.class)).decode(inStream), - NullableCoder.of(SerializableCoder.of(Duration.class)).decode(inStream), + DateTime.parse(StringUtf8Coder.of().decode(inStream)), + RdeMode.valueOf(StringUtf8Coder.of().decode(inStream)), + Optional.ofNullable(NullableCoder.of(StringUtf8Coder.of()).decode(inStream)) + .map(CursorType::valueOf) + .orElse(null), + Optional.ofNullable(NullableCoder.of(StringUtf8Coder.of()).decode(inStream)) + .map(Duration::parse) + .orElse(null), NullableCoder.of(StringUtf8Coder.of()).decode(inStream), NullableCoder.of(VarIntCoder.of()).decode(inStream)); } diff --git a/core/src/main/java/google/registry/xjc/JaxbFragment.java b/core/src/main/java/google/registry/xjc/JaxbFragment.java deleted file mode 100644 index 44fef7966..000000000 --- a/core/src/main/java/google/registry/xjc/JaxbFragment.java +++ /dev/null @@ -1,105 +0,0 @@ -// Copyright 2017 The Nomulus Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package google.registry.xjc; - -import static java.nio.charset.StandardCharsets.UTF_8; - -import google.registry.xml.XmlException; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.ObjectInputStream; -import java.io.ObjectOutputStream; -import java.io.Serializable; - -/** - * JAXB element wrapper for java object serialization. - * - * Instances of {@link JaxbFragment} wrap a non-serializable JAXB element instance, and provide - * hooks into the java object serialization process that allow the elements to be safely - * marshalled and unmarshalled using {@link ObjectOutputStream} and {@link ObjectInputStream}, - * respectively. - */ -public class JaxbFragment implements Serializable { - - private static final long serialVersionUID = 5651243983008818813L; - - private T instance; - - /** Stores a JAXB element in a {@link JaxbFragment} */ - public static JaxbFragment create(T object) { - JaxbFragment fragment = new JaxbFragment<>(); - fragment.instance = object; - return fragment; - } - - /** Serializes a JAXB element into xml bytes. */ - private static byte[] freezeInstance(T instance) throws IOException { - try { - ByteArrayOutputStream bout = new ByteArrayOutputStream(); - XjcXmlTransformer.marshalLenient(instance, bout, UTF_8); - return bout.toByteArray(); - } catch (XmlException e) { - throw new IOException(e); - } - } - - /** Deserializes a JAXB element from xml bytes. */ - private static T unfreezeInstance(byte[] instanceData, Class instanceType) - throws IOException { - try { - ByteArrayInputStream bin = new ByteArrayInputStream(instanceData); - return XjcXmlTransformer.unmarshal(instanceType, bin); - } catch (XmlException e) { - throw new IOException(e); - } - } - - /** - * Retrieves the JAXB element that is wrapped by this fragment. - */ - public T getInstance() { - return instance; - } - - @Override - public String toString() { - try { - return new String(freezeInstance(instance), UTF_8); - } catch (IOException e) { - throw new RuntimeException(e); - } - } - - private void writeObject(ObjectOutputStream out) throws IOException { - // write instanceType, then instanceData - out.writeObject(instance.getClass()); - out.writeObject(freezeInstance(instance)); - } - - @SuppressWarnings("unchecked") - private void readObject(ObjectInputStream in) throws IOException { - // read instanceType, then instanceData - Class instanceType; - byte[] instanceData; - try { - instanceType = (Class) in.readObject(); - instanceData = (byte[]) in.readObject(); - } catch (ClassNotFoundException e) { - throw new RuntimeException(e); - } - instance = unfreezeInstance(instanceData, instanceType); - } -} diff --git a/core/src/test/java/google/registry/rde/PendingDepositTest.java b/core/src/test/java/google/registry/rde/PendingDepositTest.java new file mode 100644 index 000000000..2f31d942b --- /dev/null +++ b/core/src/test/java/google/registry/rde/PendingDepositTest.java @@ -0,0 +1,59 @@ +// Copyright 2023 The Nomulus Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package google.registry.rde; + +import static com.google.common.truth.Truth.assertThat; +import static google.registry.model.common.Cursor.CursorType.RDE_STAGING; +import static google.registry.model.rde.RdeMode.FULL; +import static google.registry.util.SafeSerializationUtils.safeDeserialize; +import static google.registry.util.SerializeUtils.deserialize; +import static google.registry.util.SerializeUtils.serialize; + +import org.joda.time.DateTime; +import org.joda.time.Duration; +import org.junit.jupiter.api.Test; + +/** Unit tests for {@link PendingDeposit}. */ +public class PendingDepositTest { + private final DateTime now = DateTime.parse("2000-01-01TZ"); + + PendingDeposit pendingDeposit = + PendingDeposit.create("soy", now, FULL, RDE_STAGING, Duration.standardDays(1)); + + PendingDeposit manualPendingDeposit = + PendingDeposit.createInManualOperation("soy", now, FULL, "/", null); + + @Test + void deserialize_normalDeposit_success() { + assertThat(deserialize(PendingDeposit.class, serialize(pendingDeposit))) + .isEqualTo(pendingDeposit); + } + + @Test + void deserialize_manualDeposit_success() { + assertThat(deserialize(PendingDeposit.class, serialize(manualPendingDeposit))) + .isEqualTo(manualPendingDeposit); + } + + @Test + void safeDeserialize_normalDeposit_success() { + assertThat(safeDeserialize(serialize(pendingDeposit))).isEqualTo(pendingDeposit); + } + + @Test + void safeDeserialize_manualDeposit_success() { + assertThat(safeDeserialize(serialize(manualPendingDeposit))).isEqualTo(manualPendingDeposit); + } +} diff --git a/core/src/test/java/google/registry/xjc/JaxbFragmentTest.java b/core/src/test/java/google/registry/xjc/JaxbFragmentTest.java deleted file mode 100644 index bc1059fd0..000000000 --- a/core/src/test/java/google/registry/xjc/JaxbFragmentTest.java +++ /dev/null @@ -1,56 +0,0 @@ -// Copyright 2017 The Nomulus Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package google.registry.xjc; - -import static com.google.common.truth.Truth.assertThat; -import static google.registry.testing.TestDataHelper.loadFile; -import static java.nio.charset.StandardCharsets.UTF_8; - -import google.registry.xjc.rdehost.XjcRdeHostElement; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.InputStream; -import java.io.ObjectInputStream; -import java.io.ObjectOutputStream; -import org.junit.jupiter.api.Test; - -/** Unit tests for {@link JaxbFragment}. */ -class JaxbFragmentTest { - - private static final String HOST_FRAGMENT = loadFile(XjcObjectTest.class, "host_fragment.xml"); - - /** Verifies that a {@link JaxbFragment} can be serialized and deserialized successfully. */ - @SuppressWarnings("unchecked") - @Test - void testJavaSerialization() throws Exception { - // Load rdeHost xml fragment into a jaxb object, wrap it, marshal, unmarshal, verify host. - // The resulting host name should be "ns1.example1.test", from the original xml fragment. - try (InputStream source = new ByteArrayInputStream(HOST_FRAGMENT.getBytes(UTF_8))) { - // Load xml - JaxbFragment hostFragment = - JaxbFragment.create(XjcXmlTransformer.unmarshal(XjcRdeHostElement.class, source)); - // Marshal - ByteArrayOutputStream bout = new ByteArrayOutputStream(); - new ObjectOutputStream(bout).writeObject(hostFragment); - // Unmarshal - ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(bout.toByteArray())); - JaxbFragment restoredHostFragment = - (JaxbFragment) in.readObject(); - // Verify host name - assertThat(restoredHostFragment.getInstance().getValue().getName()) - .isEqualTo("ns1.example1.test"); - } - } -} diff --git a/util/src/main/java/google/registry/util/SafeObjectInputStream.java b/util/src/main/java/google/registry/util/SafeObjectInputStream.java new file mode 100644 index 000000000..d07a68803 --- /dev/null +++ b/util/src/main/java/google/registry/util/SafeObjectInputStream.java @@ -0,0 +1,108 @@ +// Copyright 2023 The Nomulus Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package google.registry.util; + +import static com.google.common.collect.ImmutableSet.toImmutableSet; + +import com.google.common.collect.ImmutableSet; +import java.io.IOException; +import java.io.InputStream; +import java.io.ObjectInputStream; +import java.io.ObjectStreamClass; +import java.util.Collection; +import java.util.Map; + +/** + * Safely deserializes Nomulus http request parameters. + * + *

Serialized Java objects may be passed between Nomulus components that hold different + * credentials. Deserialization of such objects should be protected against attacks through + * compromised accounts. + * + *

This class protects against three types of attacks by restricting the classes used for + * serialization: + * + *

    + *
  • Remote code execution by referencing bad classes in compromised jars. When a class with + * malicious code in the static initialization block or the deserialization code path (e.g., + * the {@code readObject} method) is deserialized, such code will be executed. For Nomulus, + * this risk comes from third-party dependencies. To counter this risk, this class only allows + * Nomulus (google.registry.**) classes and specific core Java classes, and forbid others + * including third-party dependencies. (As a side note, this class does not use allow lists + * for Nomulus or third-party classes because it is infeasible in practice. Super classes of + * the instance being deserialized must be resolved, and therefore must be on the allow list; + * same for the field types of the instance. The allow list for the Joda {@code DateTime} + * class alone would have more than 10 classes. Generated classes, e.g., by AutoValue, present + * another problem: their real names are not meant to be a concern to the user). + *
  • CPU-targeting denial-of-service attacks. Containers and arrays may be used to construct + * object graphs that require enormous amount of computation during deserialization and/or + * during invocations of methods such as {@code hashCode} or {@code equals}, taking minutes or + * even hours to complete. See + * here for an example of such object graphs. To counter this risk, this class forbids + * lists, maps, and arrays for deserialization. + *
  • Memory-targeting denial-of-service attacks. By forbidding container and arrays, this class + * also prevents some memory-targeting attacks, e.g., using wire format that claims to be an + * array of a huge size, causing the JVM to preallocate excessive amount of memory and + * triggering the {@code OutOfMemoryError}. This is actually a small risk for Nomulus, since + * the impact of each error is limited to a single (spurious) request. + *
+ * + *

Nomulus classes with fields of array, container, or third-party Java types must implement + * their own serialization/deserialization methods to be safely deserialized. For the common use + * case of passing a collection of `safe` objects, {@link + * SafeSerializationUtils#serializeCollection} and {@link + * SafeSerializationUtils#safeDeserializeCollection} may be used. + */ +public final class SafeObjectInputStream extends ObjectInputStream { + + /** + * Core Java classes allowed in deserialization. Add new classes as needed but do not add + * third-party classes. + */ + private static final ImmutableSet ALLOWED_CORE_JAVA_CLASSES = + ImmutableSet.of(String.class, Byte.class, Short.class, Integer.class, Long.class).stream() + .map(Class::getName) + .collect(toImmutableSet()); + + public SafeObjectInputStream(InputStream in) throws IOException { + super(in); + } + + @Override + protected Class resolveClass(ObjectStreamClass desc) + throws ClassNotFoundException, IOException { + String clazz = desc.getName(); + if (isNomulusClass(clazz) || ALLOWED_CORE_JAVA_CLASSES.contains(clazz)) { + return checkNotArrayOrContainer(super.resolveClass(desc)); + } + throw new ClassNotFoundException(clazz + " not found or not allowed in deserialization."); + } + + private Class checkNotArrayOrContainer(Class clazz) throws ClassNotFoundException { + if (isContainer(clazz) || clazz.isArray()) { + throw new ClassNotFoundException(clazz.getName() + " not allowed as non-root object."); + } + return clazz; + } + + private boolean isNomulusClass(String clazz) { + return clazz.startsWith("google.registry."); + } + + private boolean isContainer(Class clazz) { + return Collection.class.isAssignableFrom(clazz) || Map.class.isAssignableFrom(clazz); + } +} diff --git a/util/src/main/java/google/registry/util/SafeSerializationUtils.java b/util/src/main/java/google/registry/util/SafeSerializationUtils.java new file mode 100644 index 000000000..9c980558b --- /dev/null +++ b/util/src/main/java/google/registry/util/SafeSerializationUtils.java @@ -0,0 +1,104 @@ +// Copyright 2023 The Nomulus Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package google.registry.util; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.common.collect.ImmutableList; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.io.Serializable; +import java.util.Arrays; +import java.util.Collection; +import javax.annotation.Nullable; + +/** + * Helpers for using {@link SafeObjectInputStream}. + * + *

Please refer to {@code SafeObjectInputStream} for more information. + */ +public final class SafeSerializationUtils { + + private SafeSerializationUtils() {} + + /** + * Maximum number of elements allowed in a serialized collection. + * + *

This value is sufficient for parameters embedded in a {@code URL} to typical cloud services. + * E.g., as of Fall 2023, AWS limits request line size to 16KB and GCP limits total header size to + * 64KB. + */ + public static final int MAX_COLLECTION_SIZE = 32768; + + /** + * Serializes a collection of objects that can be safely deserialized using {@link + * #safeDeserializeCollection}. + * + *

If any element of the collection cannot be safely-deserialized, deserialization will fail. + */ + public static byte[] serializeCollection(Collection collection) { + checkNotNull(collection, "collection"); + ByteArrayOutputStream bos = new ByteArrayOutputStream(); + try (ObjectOutputStream os = new ObjectOutputStream(bos)) { + os.writeInt(collection.size()); + for (Object obj : collection) { + os.writeObject(obj); + } + } catch (IOException e) { + throw new RuntimeException("Failed to serialize: " + collection, e); + } + return bos.toByteArray(); + } + + /** Safely deserializes an object using {@link SafeObjectInputStream}. */ + @Nullable + public static Serializable safeDeserialize(@Nullable byte[] bytes) { + if (bytes == null) { + return null; + } + try (ObjectInputStream is = new SafeObjectInputStream(new ByteArrayInputStream(bytes))) { + Serializable ret = (Serializable) is.readObject(); + return ret; + } catch (IOException | ClassNotFoundException e) { + throw new IllegalArgumentException("Failed to deserialize: " + Arrays.toString(bytes), e); + } + } + + /** + * Safely deserializes a collection of objects previously serialized with {@link + * #serializeCollection}. + */ + public static ImmutableList safeDeserializeCollection(Class elementType, byte[] bytes) { + checkNotNull(bytes, "Serialized list must not be null."); + try (ObjectInputStream is = new SafeObjectInputStream(new ByteArrayInputStream(bytes))) { + int size = is.readInt(); + checkArgument(size >= 0, "Malformed data: negative collection size."); + if (size > MAX_COLLECTION_SIZE) { + throw new IllegalArgumentException("Too many elements in collection: " + size); + } + ImmutableList.Builder builder = new ImmutableList.Builder<>(); + for (int i = 0; i < size; i++) { + builder.add(elementType.cast(is.readObject())); + } + return builder.build(); + } catch (IOException | ClassNotFoundException | ClassCastException e) { + throw new IllegalArgumentException("Failed to deserialize: " + Arrays.toString(bytes), e); + } + } +} diff --git a/util/src/main/java/google/registry/util/SerializeUtils.java b/util/src/main/java/google/registry/util/SerializeUtils.java index 6ed3cf1a8..130ad3212 100644 --- a/util/src/main/java/google/registry/util/SerializeUtils.java +++ b/util/src/main/java/google/registry/util/SerializeUtils.java @@ -74,10 +74,20 @@ public final class SerializeUtils { private SerializeUtils() {} + /** Encodes a byte array as a URL-safe string. */ + public static String encodeBase64(byte[] bytes) { + return Base64.encodeBase64URLSafeString(bytes); + } + + /** Turns a string encoded by {@link #encodeBase64} back into a byte array. */ + public static byte[] decodeBase64(String objectString) { + return Base64.decodeBase64(objectString); + } + /** Turns an object into an encoded string that can be used safely as a URI query parameter. */ public static String stringify(Serializable object) { checkNotNull(object, "Object cannot be null"); - return Base64.encodeBase64URLSafeString(SerializeUtils.serialize(object)); + return encodeBase64(SerializeUtils.serialize(object)); } /** Turns a string encoded by stringify() into an object. */ @@ -86,6 +96,6 @@ public final class SerializeUtils { checkNotNull(type, "Class type is not specified"); checkNotNull(objectString, "Object string cannot be null"); - return SerializeUtils.deserialize(type, Base64.decodeBase64(objectString)); + return SerializeUtils.deserialize(type, decodeBase64(objectString)); } } diff --git a/util/src/test/java/google/registry/util/SafeObjectInputStreamTest.java b/util/src/test/java/google/registry/util/SafeObjectInputStreamTest.java new file mode 100644 index 000000000..66e606438 --- /dev/null +++ b/util/src/test/java/google/registry/util/SafeObjectInputStreamTest.java @@ -0,0 +1,126 @@ +// Copyright 2023 The Nomulus Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package google.registry.util; + +import static com.google.common.collect.Lists.newArrayList; +import static com.google.common.truth.Truth.assertThat; +import static google.registry.util.SerializeUtils.serialize; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import com.google.common.base.Objects; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Maps; +import java.io.ByteArrayInputStream; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.HashMap; +import org.joda.time.Duration; +import org.junit.jupiter.api.Test; + +/** Unit tests for {@link SafeObjectInputStream}. */ +public class SafeObjectInputStreamTest { + + @Test + void javaUnitarySuccess() throws Exception { + String orig = "some string"; + try (SafeObjectInputStream sois = + new SafeObjectInputStream(new ByteArrayInputStream(serialize(orig)))) { + assertThat(sois.readObject()).isEqualTo(orig); + } + } + + @Test + void javaCollectionFailure() throws Exception { + ArrayList orig = newArrayList("a"); + try (SafeObjectInputStream sois = + new SafeObjectInputStream(new ByteArrayInputStream(serialize(orig)))) { + assertThrows(ClassNotFoundException.class, () -> sois.readObject()); + } + } + + @Test + void javaMapFailure() throws Exception { + HashMap orig = Maps.newHashMap(); + try (SafeObjectInputStream sois = + new SafeObjectInputStream(new ByteArrayInputStream(serialize(orig)))) { + assertThrows(ClassNotFoundException.class, () -> sois.readObject()); + } + } + + @Test + void javaArrayFailure() throws Exception { + int[] orig = new int[] {1}; + try (SafeObjectInputStream sois = + new SafeObjectInputStream(new ByteArrayInputStream(serialize(orig)))) { + // For array, the parent class converts ClassNotFoundException in an undocumented way. Safer + // to catch Exception than the one thrown by the current JVM. + assertThrows(Exception.class, () -> sois.readObject()); + } + } + + @Test + void nonJavaNonNomulusUnitaryFailure() throws Exception { + Serializable orig = Duration.millis(1); + try (SafeObjectInputStream sois = + new SafeObjectInputStream(new ByteArrayInputStream(serialize(orig)))) { + assertThrows(ClassNotFoundException.class, () -> sois.readObject()); + } + } + + @Test + void nonJavaCollectionFailure() throws Exception { + ImmutableList orig = ImmutableList.of("a"); + try (SafeObjectInputStream sois = + new SafeObjectInputStream(new ByteArrayInputStream(serialize(orig)))) { + assertThrows(ClassNotFoundException.class, () -> sois.readObject()); + } + } + + @Test + void nomulusEntitySuccess() throws Exception { + NomulusEntity orig = new NomulusEntity(1); + byte[] serialized = serialize(orig); + try (SafeObjectInputStream sois = + new SafeObjectInputStream(new ByteArrayInputStream(serialized))) { + Object deserialized = sois.readObject(); + assertThat(deserialized).isEqualTo(orig); + } + } + + static class NomulusEntity implements Serializable { + Integer value; + + NomulusEntity(int value) { + this.value = value; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof NomulusEntity)) { + return false; + } + NomulusEntity that = (NomulusEntity) o; + return Objects.equal(value, that.value); + } + + @Override + public int hashCode() { + return Objects.hashCode(value); + } + } +} diff --git a/util/src/test/java/google/registry/util/SafeSerializationUtilsTest.java b/util/src/test/java/google/registry/util/SafeSerializationUtilsTest.java new file mode 100644 index 000000000..47fc93f89 --- /dev/null +++ b/util/src/test/java/google/registry/util/SafeSerializationUtilsTest.java @@ -0,0 +1,110 @@ +// Copyright 2023 The Nomulus Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package google.registry.util; + +import static com.google.common.collect.Lists.newArrayList; +import static com.google.common.truth.Truth.assertThat; +import static google.registry.util.SafeSerializationUtils.safeDeserialize; +import static google.registry.util.SafeSerializationUtils.safeDeserializeCollection; +import static google.registry.util.SafeSerializationUtils.serializeCollection; +import static google.registry.util.SerializeUtils.serialize; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import java.util.ArrayList; +import java.util.Arrays; +import org.junit.jupiter.api.Test; + +/** Unit tests for {@link SafeSerializationUtils}. */ +public class SafeSerializationUtilsTest { + + @Test + void deserialize_array_failure() { + assertThat( + assertThrows( + IllegalArgumentException.class, () -> safeDeserialize(serialize(new byte[0])))) + .hasMessageThat() + .contains("Failed to deserialize:"); + } + + @Test + void deserialize_null_success() { + assertThat(safeDeserialize(serialize(null))).isNull(); + } + + @Test + void deserialize_map_failure() { + assertThat( + assertThrows( + IllegalArgumentException.class, + () -> safeDeserialize(serialize(ImmutableMap.of())))) + .hasMessageThat() + .contains("Failed to deserialize:"); + } + + @Test + void serializeDeserialize_null_success() { + assertThat(safeDeserialize(null)).isNull(); + } + + @Test + void serializeDeserialize_notCollection_success() { + Integer orig = 1; + assertThat(safeDeserialize(serialize(orig))).isEqualTo(orig); + } + + @Test + void serializeDeserializeCollection_success() { + ArrayList orig = newArrayList(1, 2, 3); + ImmutableList deserialized = + safeDeserializeCollection(Integer.class, serializeCollection(orig)); + assertThat(deserialized).isEqualTo(orig); + } + + @Test + void serializeDeserializeCollection_withMaxSize_success() { + Integer[] array = new Integer[SafeSerializationUtils.MAX_COLLECTION_SIZE]; + Arrays.fill(array, 1); + ArrayList orig = newArrayList(array); + assertThat(safeDeserializeCollection(Integer.class, serializeCollection(orig))).isEqualTo(orig); + } + + @Test + void serializeDeserializeCollection_tooLarge_Failure() { + Integer[] array = new Integer[SafeSerializationUtils.MAX_COLLECTION_SIZE + 1]; + Arrays.fill(array, 1); + ArrayList orig = newArrayList(array); + assertThat( + assertThrows( + IllegalArgumentException.class, + () -> safeDeserializeCollection(Integer.class, serializeCollection(orig)))) + .hasMessageThat() + .contains("Too many elements"); + } + + @Test + void serializeDeserializeCollection_wrong_elementType_success() { + ArrayList orig = newArrayList(1, 2, 3); + assertThrows( + IllegalArgumentException.class, + () -> safeDeserializeCollection(Long.class, serializeCollection(orig))); + } + + @Test + void deserializeCollection_null_failure() { + assertThrows(NullPointerException.class, () -> safeDeserializeCollection(Integer.class, null)); + } +}