From 46fdf2c996c9eb311dc9b3240b1fcce25c7ac609 Mon Sep 17 00:00:00 2001 From: Weimin Yu Date: Wed, 20 Sep 2023 16:56:56 -0400 Subject: [PATCH] Defend against deserialization-based attacks (#2150) * Defend against deserialization-based attacks Added the `SafeObjectInputStream` class that defends attacks using malformed serialized data, including remote code execution and denial-of-service attacks. Started using the new class to handle EPP resource VKeys and PendingDeposits, which are passed across credential-boundaries: between TaskQueue and AppEngine server, and between AppEngine server and the RDE pipeline on GCE. Note that the wireformat of VKeys do not change, therefore existing tasks sitting in the TaskQueue are not affected. Also removed an unused class: JaxbFragment. --- .../google/registry/beam/rde/RdePipeline.java | 25 +--- .../google/registry/persistence/VKey.java | 4 +- .../google/registry/rde/PendingDeposit.java | 92 +++++++++++-- .../google/registry/xjc/JaxbFragment.java | 105 --------------- .../registry/rde/PendingDepositTest.java | 59 ++++++++ .../google/registry/xjc/JaxbFragmentTest.java | 56 -------- .../registry/util/SafeObjectInputStream.java | 108 +++++++++++++++ .../registry/util/SafeSerializationUtils.java | 104 +++++++++++++++ .../google/registry/util/SerializeUtils.java | 14 +- .../util/SafeObjectInputStreamTest.java | 126 ++++++++++++++++++ .../util/SafeSerializationUtilsTest.java | 110 +++++++++++++++ 11 files changed, 610 insertions(+), 193 deletions(-) delete mode 100644 core/src/main/java/google/registry/xjc/JaxbFragment.java create mode 100644 core/src/test/java/google/registry/rde/PendingDepositTest.java delete mode 100644 core/src/test/java/google/registry/xjc/JaxbFragmentTest.java create mode 100644 util/src/main/java/google/registry/util/SafeObjectInputStream.java create mode 100644 util/src/main/java/google/registry/util/SafeSerializationUtils.java create mode 100644 util/src/test/java/google/registry/util/SafeObjectInputStreamTest.java create mode 100644 util/src/test/java/google/registry/util/SafeSerializationUtilsTest.java diff --git a/core/src/main/java/google/registry/beam/rde/RdePipeline.java b/core/src/main/java/google/registry/beam/rde/RdePipeline.java index 549e30c13..c30088ea4 100644 --- a/core/src/main/java/google/registry/beam/rde/RdePipeline.java +++ b/core/src/main/java/google/registry/beam/rde/RdePipeline.java @@ -27,6 +27,10 @@ import static google.registry.beam.rde.RdePipeline.TupleTags.REVISION_ID; import static google.registry.beam.rde.RdePipeline.TupleTags.SUPERORDINATE_DOMAINS; import static google.registry.model.reporting.HistoryEntryDao.RESOURCE_TYPES_TO_HISTORY_TYPES; import static google.registry.persistence.transaction.TransactionManagerFactory.tm; +import static google.registry.util.SafeSerializationUtils.safeDeserializeCollection; +import static google.registry.util.SafeSerializationUtils.serializeCollection; +import static google.registry.util.SerializeUtils.decodeBase64; +import static google.registry.util.SerializeUtils.encodeBase64; import static org.apache.beam.sdk.values.TypeDescriptors.kvs; import com.google.common.collect.ImmutableList; @@ -65,11 +69,7 @@ import google.registry.rde.PendingDeposit.PendingDepositCoder; import google.registry.rde.RdeMarshaller; import google.registry.util.UtilsModule; import google.registry.xml.ValidationMode; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; import java.io.IOException; -import java.io.ObjectInputStream; -import java.io.ObjectOutputStream; import java.io.Serializable; import java.util.HashSet; import javax.inject.Inject; @@ -658,14 +658,8 @@ public class RdePipeline implements Serializable { */ @SuppressWarnings("unchecked") static ImmutableSet decodePendingDeposits(String encodedPendingDeposits) { - try (ObjectInputStream ois = - new ObjectInputStream( - new ByteArrayInputStream( - BaseEncoding.base64Url().omitPadding().decode(encodedPendingDeposits)))) { - return (ImmutableSet) ois.readObject(); - } catch (IOException | ClassNotFoundException e) { - throw new IllegalArgumentException("Unable to parse encoded pending deposit map.", e); - } + return ImmutableSet.copyOf( + safeDeserializeCollection(PendingDeposit.class, decodeBase64(encodedPendingDeposits))); } /** @@ -674,12 +668,7 @@ public class RdePipeline implements Serializable { */ public static String encodePendingDeposits(ImmutableSet pendingDeposits) throws IOException { - try (ByteArrayOutputStream baos = new ByteArrayOutputStream()) { - ObjectOutputStream oos = new ObjectOutputStream(baos); - oos.writeObject(pendingDeposits); - oos.flush(); - return BaseEncoding.base64Url().omitPadding().encode(baos.toByteArray()); - } + return encodeBase64(serializeCollection(pendingDeposits)); } public static void main(String[] args) throws IOException, ClassNotFoundException { diff --git a/core/src/main/java/google/registry/persistence/VKey.java b/core/src/main/java/google/registry/persistence/VKey.java index c2222e4a3..b2c9fb83b 100644 --- a/core/src/main/java/google/registry/persistence/VKey.java +++ b/core/src/main/java/google/registry/persistence/VKey.java @@ -16,6 +16,8 @@ package google.registry.persistence; import static com.google.common.collect.ImmutableMap.toImmutableMap; import static google.registry.util.PreconditionsUtils.checkArgumentNotNull; +import static google.registry.util.SafeSerializationUtils.safeDeserialize; +import static google.registry.util.SerializeUtils.decodeBase64; import static java.util.function.Function.identity; import com.google.common.base.Joiner; @@ -97,7 +99,7 @@ public class VKey extends ImmutableObject implements Serializable { throw new IllegalArgumentException( String.format("\"%s\" missing from the string: %s", LOOKUP_KEY, keyString)); } - return VKey.create(classType, SerializeUtils.parse(Serializable.class, kvs.get(LOOKUP_KEY))); + return VKey.create(classType, safeDeserialize(decodeBase64(kvs.get(LOOKUP_KEY)))); } /** Returns the type of the entity. */ diff --git a/core/src/main/java/google/registry/rde/PendingDeposit.java b/core/src/main/java/google/registry/rde/PendingDeposit.java index 1ba7e2a61..dbd97962d 100644 --- a/core/src/main/java/google/registry/rde/PendingDeposit.java +++ b/core/src/main/java/google/registry/rde/PendingDeposit.java @@ -14,18 +14,23 @@ package google.registry.rde; +import static com.google.common.base.Preconditions.checkState; + import com.google.auto.value.AutoValue; import google.registry.model.common.Cursor.CursorType; import google.registry.model.rde.RdeMode; import java.io.IOException; import java.io.InputStream; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.io.ObjectStreamException; import java.io.OutputStream; import java.io.Serializable; +import java.util.Optional; import javax.annotation.Nullable; import org.apache.beam.sdk.coders.AtomicCoder; import org.apache.beam.sdk.coders.BooleanCoder; import org.apache.beam.sdk.coders.NullableCoder; -import org.apache.beam.sdk.coders.SerializableCoder; import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.coders.VarIntCoder; import org.joda.time.DateTime; @@ -35,6 +40,12 @@ import org.joda.time.Duration; * Container representing a single RDE or BRDA XML escrow deposit that needs to be created. * *

There are some {@code @Nullable} fields here because Optionals aren't Serializable. + * + *

Note that this class is serialized in two ways: by Beam pipelines using custom serialization + * mechanism and the {@code Coder} API, and by Java serialization when passed as command-line + * arguments (see {@code RdePipeline#decodePendingDeposits}). The latter requires safe + * deserialization because the data crosses credential boundaries (See {@code + * SafeObjectInputStream}). */ @AutoValue public abstract class PendingDeposit implements Serializable { @@ -95,11 +106,61 @@ public abstract class PendingDeposit implements Serializable { PendingDeposit() {} + /** + * Specifies that {@link SerializedForm} be used for {@code SafeObjectInputStream}-compatible + * custom-serialization of {@link AutoValue_PendingDeposit the AutoValue implementation class}. + * + *

This method is package-protected so that the AutoValue implementation class inherits this + * behavior. + * + *

This method leverages {@link PendingDepositCoder} to serializes an instance. However, it is + * not invoked in Beam pipelines. + */ + Object writeReplace() throws ObjectStreamException { + return new SerializedForm(this); + } + + /** + * Proxy for custom-serialization of {@link PendingDeposit}. This is necessary because the actual + * class to be (de)serialized is the generated AutoValue implementation. See also {@link + * #writeReplace}. + * + *

This class leverages {@link PendingDepositCoder} to safely deserializes an instance. + * However, it is not used in Beam pipelines. + */ + private static class SerializedForm implements Serializable { + + private static final long serialVersionUID = 3141095605225904433L; + + private PendingDeposit value; + + private SerializedForm(PendingDeposit value) { + this.value = value; + } + + private void writeObject(ObjectOutputStream os) throws IOException { + checkState(value != null, "Non-null value expected for serialization."); + PendingDepositCoder.INSTANCE.encode(value, os); + } + + private void readObject(ObjectInputStream is) throws IOException, ClassNotFoundException { + checkState(value == null, "Non-null value unexpected for deserialization."); + this.value = PendingDepositCoder.INSTANCE.decode(is); + } + + @SuppressWarnings("unused") + private Object readResolve() throws ObjectStreamException { + return this.value; + } + } + /** * A deterministic coder for {@link PendingDeposit} used during a GroupBy transform. * - *

We cannot use a {@link SerializableCoder} directly because it does not guarantee - * determinism, which is required by GroupBy. + *

We cannot use a {@code SerializableCoder} directly for two reasons: the default + * serialization does not guarantee determinism, which is required by GroupBy in Beam; and the + * default deserialization is not robust against deserialization-based attacks (See {@code + * SafeObjectInputStream} for more information). */ public static class PendingDepositCoder extends AtomicCoder { @@ -117,10 +178,15 @@ public abstract class PendingDeposit implements Serializable { public void encode(PendingDeposit value, OutputStream outStream) throws IOException { BooleanCoder.of().encode(value.manual(), outStream); StringUtf8Coder.of().encode(value.tld(), outStream); - SerializableCoder.of(DateTime.class).encode(value.watermark(), outStream); - SerializableCoder.of(RdeMode.class).encode(value.mode(), outStream); - NullableCoder.of(SerializableCoder.of(CursorType.class)).encode(value.cursor(), outStream); - NullableCoder.of(SerializableCoder.of(Duration.class)).encode(value.interval(), outStream); + StringUtf8Coder.of().encode(value.watermark().toString(), outStream); + StringUtf8Coder.of().encode(value.mode().name(), outStream); + NullableCoder.of(StringUtf8Coder.of()) + .encode( + Optional.ofNullable(value.cursor()).map(CursorType::name).orElse(null), outStream); + NullableCoder.of(StringUtf8Coder.of()) + .encode( + Optional.ofNullable(value.interval()).map(Duration::toString).orElse(null), + outStream); NullableCoder.of(StringUtf8Coder.of()).encode(value.directoryWithTrailingSlash(), outStream); NullableCoder.of(VarIntCoder.of()).encode(value.revision(), outStream); } @@ -130,10 +196,14 @@ public abstract class PendingDeposit implements Serializable { return new AutoValue_PendingDeposit( BooleanCoder.of().decode(inStream), StringUtf8Coder.of().decode(inStream), - SerializableCoder.of(DateTime.class).decode(inStream), - SerializableCoder.of(RdeMode.class).decode(inStream), - NullableCoder.of(SerializableCoder.of(CursorType.class)).decode(inStream), - NullableCoder.of(SerializableCoder.of(Duration.class)).decode(inStream), + DateTime.parse(StringUtf8Coder.of().decode(inStream)), + RdeMode.valueOf(StringUtf8Coder.of().decode(inStream)), + Optional.ofNullable(NullableCoder.of(StringUtf8Coder.of()).decode(inStream)) + .map(CursorType::valueOf) + .orElse(null), + Optional.ofNullable(NullableCoder.of(StringUtf8Coder.of()).decode(inStream)) + .map(Duration::parse) + .orElse(null), NullableCoder.of(StringUtf8Coder.of()).decode(inStream), NullableCoder.of(VarIntCoder.of()).decode(inStream)); } diff --git a/core/src/main/java/google/registry/xjc/JaxbFragment.java b/core/src/main/java/google/registry/xjc/JaxbFragment.java deleted file mode 100644 index 44fef7966..000000000 --- a/core/src/main/java/google/registry/xjc/JaxbFragment.java +++ /dev/null @@ -1,105 +0,0 @@ -// Copyright 2017 The Nomulus Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package google.registry.xjc; - -import static java.nio.charset.StandardCharsets.UTF_8; - -import google.registry.xml.XmlException; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.ObjectInputStream; -import java.io.ObjectOutputStream; -import java.io.Serializable; - -/** - * JAXB element wrapper for java object serialization. - * - * Instances of {@link JaxbFragment} wrap a non-serializable JAXB element instance, and provide - * hooks into the java object serialization process that allow the elements to be safely - * marshalled and unmarshalled using {@link ObjectOutputStream} and {@link ObjectInputStream}, - * respectively. - */ -public class JaxbFragment implements Serializable { - - private static final long serialVersionUID = 5651243983008818813L; - - private T instance; - - /** Stores a JAXB element in a {@link JaxbFragment} */ - public static JaxbFragment create(T object) { - JaxbFragment fragment = new JaxbFragment<>(); - fragment.instance = object; - return fragment; - } - - /** Serializes a JAXB element into xml bytes. */ - private static byte[] freezeInstance(T instance) throws IOException { - try { - ByteArrayOutputStream bout = new ByteArrayOutputStream(); - XjcXmlTransformer.marshalLenient(instance, bout, UTF_8); - return bout.toByteArray(); - } catch (XmlException e) { - throw new IOException(e); - } - } - - /** Deserializes a JAXB element from xml bytes. */ - private static T unfreezeInstance(byte[] instanceData, Class instanceType) - throws IOException { - try { - ByteArrayInputStream bin = new ByteArrayInputStream(instanceData); - return XjcXmlTransformer.unmarshal(instanceType, bin); - } catch (XmlException e) { - throw new IOException(e); - } - } - - /** - * Retrieves the JAXB element that is wrapped by this fragment. - */ - public T getInstance() { - return instance; - } - - @Override - public String toString() { - try { - return new String(freezeInstance(instance), UTF_8); - } catch (IOException e) { - throw new RuntimeException(e); - } - } - - private void writeObject(ObjectOutputStream out) throws IOException { - // write instanceType, then instanceData - out.writeObject(instance.getClass()); - out.writeObject(freezeInstance(instance)); - } - - @SuppressWarnings("unchecked") - private void readObject(ObjectInputStream in) throws IOException { - // read instanceType, then instanceData - Class instanceType; - byte[] instanceData; - try { - instanceType = (Class) in.readObject(); - instanceData = (byte[]) in.readObject(); - } catch (ClassNotFoundException e) { - throw new RuntimeException(e); - } - instance = unfreezeInstance(instanceData, instanceType); - } -} diff --git a/core/src/test/java/google/registry/rde/PendingDepositTest.java b/core/src/test/java/google/registry/rde/PendingDepositTest.java new file mode 100644 index 000000000..2f31d942b --- /dev/null +++ b/core/src/test/java/google/registry/rde/PendingDepositTest.java @@ -0,0 +1,59 @@ +// Copyright 2023 The Nomulus Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package google.registry.rde; + +import static com.google.common.truth.Truth.assertThat; +import static google.registry.model.common.Cursor.CursorType.RDE_STAGING; +import static google.registry.model.rde.RdeMode.FULL; +import static google.registry.util.SafeSerializationUtils.safeDeserialize; +import static google.registry.util.SerializeUtils.deserialize; +import static google.registry.util.SerializeUtils.serialize; + +import org.joda.time.DateTime; +import org.joda.time.Duration; +import org.junit.jupiter.api.Test; + +/** Unit tests for {@link PendingDeposit}. */ +public class PendingDepositTest { + private final DateTime now = DateTime.parse("2000-01-01TZ"); + + PendingDeposit pendingDeposit = + PendingDeposit.create("soy", now, FULL, RDE_STAGING, Duration.standardDays(1)); + + PendingDeposit manualPendingDeposit = + PendingDeposit.createInManualOperation("soy", now, FULL, "/", null); + + @Test + void deserialize_normalDeposit_success() { + assertThat(deserialize(PendingDeposit.class, serialize(pendingDeposit))) + .isEqualTo(pendingDeposit); + } + + @Test + void deserialize_manualDeposit_success() { + assertThat(deserialize(PendingDeposit.class, serialize(manualPendingDeposit))) + .isEqualTo(manualPendingDeposit); + } + + @Test + void safeDeserialize_normalDeposit_success() { + assertThat(safeDeserialize(serialize(pendingDeposit))).isEqualTo(pendingDeposit); + } + + @Test + void safeDeserialize_manualDeposit_success() { + assertThat(safeDeserialize(serialize(manualPendingDeposit))).isEqualTo(manualPendingDeposit); + } +} diff --git a/core/src/test/java/google/registry/xjc/JaxbFragmentTest.java b/core/src/test/java/google/registry/xjc/JaxbFragmentTest.java deleted file mode 100644 index bc1059fd0..000000000 --- a/core/src/test/java/google/registry/xjc/JaxbFragmentTest.java +++ /dev/null @@ -1,56 +0,0 @@ -// Copyright 2017 The Nomulus Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package google.registry.xjc; - -import static com.google.common.truth.Truth.assertThat; -import static google.registry.testing.TestDataHelper.loadFile; -import static java.nio.charset.StandardCharsets.UTF_8; - -import google.registry.xjc.rdehost.XjcRdeHostElement; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.InputStream; -import java.io.ObjectInputStream; -import java.io.ObjectOutputStream; -import org.junit.jupiter.api.Test; - -/** Unit tests for {@link JaxbFragment}. */ -class JaxbFragmentTest { - - private static final String HOST_FRAGMENT = loadFile(XjcObjectTest.class, "host_fragment.xml"); - - /** Verifies that a {@link JaxbFragment} can be serialized and deserialized successfully. */ - @SuppressWarnings("unchecked") - @Test - void testJavaSerialization() throws Exception { - // Load rdeHost xml fragment into a jaxb object, wrap it, marshal, unmarshal, verify host. - // The resulting host name should be "ns1.example1.test", from the original xml fragment. - try (InputStream source = new ByteArrayInputStream(HOST_FRAGMENT.getBytes(UTF_8))) { - // Load xml - JaxbFragment hostFragment = - JaxbFragment.create(XjcXmlTransformer.unmarshal(XjcRdeHostElement.class, source)); - // Marshal - ByteArrayOutputStream bout = new ByteArrayOutputStream(); - new ObjectOutputStream(bout).writeObject(hostFragment); - // Unmarshal - ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(bout.toByteArray())); - JaxbFragment restoredHostFragment = - (JaxbFragment) in.readObject(); - // Verify host name - assertThat(restoredHostFragment.getInstance().getValue().getName()) - .isEqualTo("ns1.example1.test"); - } - } -} diff --git a/util/src/main/java/google/registry/util/SafeObjectInputStream.java b/util/src/main/java/google/registry/util/SafeObjectInputStream.java new file mode 100644 index 000000000..d07a68803 --- /dev/null +++ b/util/src/main/java/google/registry/util/SafeObjectInputStream.java @@ -0,0 +1,108 @@ +// Copyright 2023 The Nomulus Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package google.registry.util; + +import static com.google.common.collect.ImmutableSet.toImmutableSet; + +import com.google.common.collect.ImmutableSet; +import java.io.IOException; +import java.io.InputStream; +import java.io.ObjectInputStream; +import java.io.ObjectStreamClass; +import java.util.Collection; +import java.util.Map; + +/** + * Safely deserializes Nomulus http request parameters. + * + *

Serialized Java objects may be passed between Nomulus components that hold different + * credentials. Deserialization of such objects should be protected against attacks through + * compromised accounts. + * + *

This class protects against three types of attacks by restricting the classes used for + * serialization: + * + *

    + *
  • Remote code execution by referencing bad classes in compromised jars. When a class with + * malicious code in the static initialization block or the deserialization code path (e.g., + * the {@code readObject} method) is deserialized, such code will be executed. For Nomulus, + * this risk comes from third-party dependencies. To counter this risk, this class only allows + * Nomulus (google.registry.**) classes and specific core Java classes, and forbid others + * including third-party dependencies. (As a side note, this class does not use allow lists + * for Nomulus or third-party classes because it is infeasible in practice. Super classes of + * the instance being deserialized must be resolved, and therefore must be on the allow list; + * same for the field types of the instance. The allow list for the Joda {@code DateTime} + * class alone would have more than 10 classes. Generated classes, e.g., by AutoValue, present + * another problem: their real names are not meant to be a concern to the user). + *
  • CPU-targeting denial-of-service attacks. Containers and arrays may be used to construct + * object graphs that require enormous amount of computation during deserialization and/or + * during invocations of methods such as {@code hashCode} or {@code equals}, taking minutes or + * even hours to complete. See + * here for an example of such object graphs. To counter this risk, this class forbids + * lists, maps, and arrays for deserialization. + *
  • Memory-targeting denial-of-service attacks. By forbidding container and arrays, this class + * also prevents some memory-targeting attacks, e.g., using wire format that claims to be an + * array of a huge size, causing the JVM to preallocate excessive amount of memory and + * triggering the {@code OutOfMemoryError}. This is actually a small risk for Nomulus, since + * the impact of each error is limited to a single (spurious) request. + *
+ * + *

Nomulus classes with fields of array, container, or third-party Java types must implement + * their own serialization/deserialization methods to be safely deserialized. For the common use + * case of passing a collection of `safe` objects, {@link + * SafeSerializationUtils#serializeCollection} and {@link + * SafeSerializationUtils#safeDeserializeCollection} may be used. + */ +public final class SafeObjectInputStream extends ObjectInputStream { + + /** + * Core Java classes allowed in deserialization. Add new classes as needed but do not add + * third-party classes. + */ + private static final ImmutableSet ALLOWED_CORE_JAVA_CLASSES = + ImmutableSet.of(String.class, Byte.class, Short.class, Integer.class, Long.class).stream() + .map(Class::getName) + .collect(toImmutableSet()); + + public SafeObjectInputStream(InputStream in) throws IOException { + super(in); + } + + @Override + protected Class resolveClass(ObjectStreamClass desc) + throws ClassNotFoundException, IOException { + String clazz = desc.getName(); + if (isNomulusClass(clazz) || ALLOWED_CORE_JAVA_CLASSES.contains(clazz)) { + return checkNotArrayOrContainer(super.resolveClass(desc)); + } + throw new ClassNotFoundException(clazz + " not found or not allowed in deserialization."); + } + + private Class checkNotArrayOrContainer(Class clazz) throws ClassNotFoundException { + if (isContainer(clazz) || clazz.isArray()) { + throw new ClassNotFoundException(clazz.getName() + " not allowed as non-root object."); + } + return clazz; + } + + private boolean isNomulusClass(String clazz) { + return clazz.startsWith("google.registry."); + } + + private boolean isContainer(Class clazz) { + return Collection.class.isAssignableFrom(clazz) || Map.class.isAssignableFrom(clazz); + } +} diff --git a/util/src/main/java/google/registry/util/SafeSerializationUtils.java b/util/src/main/java/google/registry/util/SafeSerializationUtils.java new file mode 100644 index 000000000..9c980558b --- /dev/null +++ b/util/src/main/java/google/registry/util/SafeSerializationUtils.java @@ -0,0 +1,104 @@ +// Copyright 2023 The Nomulus Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package google.registry.util; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.common.collect.ImmutableList; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.io.Serializable; +import java.util.Arrays; +import java.util.Collection; +import javax.annotation.Nullable; + +/** + * Helpers for using {@link SafeObjectInputStream}. + * + *

Please refer to {@code SafeObjectInputStream} for more information. + */ +public final class SafeSerializationUtils { + + private SafeSerializationUtils() {} + + /** + * Maximum number of elements allowed in a serialized collection. + * + *

This value is sufficient for parameters embedded in a {@code URL} to typical cloud services. + * E.g., as of Fall 2023, AWS limits request line size to 16KB and GCP limits total header size to + * 64KB. + */ + public static final int MAX_COLLECTION_SIZE = 32768; + + /** + * Serializes a collection of objects that can be safely deserialized using {@link + * #safeDeserializeCollection}. + * + *

If any element of the collection cannot be safely-deserialized, deserialization will fail. + */ + public static byte[] serializeCollection(Collection collection) { + checkNotNull(collection, "collection"); + ByteArrayOutputStream bos = new ByteArrayOutputStream(); + try (ObjectOutputStream os = new ObjectOutputStream(bos)) { + os.writeInt(collection.size()); + for (Object obj : collection) { + os.writeObject(obj); + } + } catch (IOException e) { + throw new RuntimeException("Failed to serialize: " + collection, e); + } + return bos.toByteArray(); + } + + /** Safely deserializes an object using {@link SafeObjectInputStream}. */ + @Nullable + public static Serializable safeDeserialize(@Nullable byte[] bytes) { + if (bytes == null) { + return null; + } + try (ObjectInputStream is = new SafeObjectInputStream(new ByteArrayInputStream(bytes))) { + Serializable ret = (Serializable) is.readObject(); + return ret; + } catch (IOException | ClassNotFoundException e) { + throw new IllegalArgumentException("Failed to deserialize: " + Arrays.toString(bytes), e); + } + } + + /** + * Safely deserializes a collection of objects previously serialized with {@link + * #serializeCollection}. + */ + public static ImmutableList safeDeserializeCollection(Class elementType, byte[] bytes) { + checkNotNull(bytes, "Serialized list must not be null."); + try (ObjectInputStream is = new SafeObjectInputStream(new ByteArrayInputStream(bytes))) { + int size = is.readInt(); + checkArgument(size >= 0, "Malformed data: negative collection size."); + if (size > MAX_COLLECTION_SIZE) { + throw new IllegalArgumentException("Too many elements in collection: " + size); + } + ImmutableList.Builder builder = new ImmutableList.Builder<>(); + for (int i = 0; i < size; i++) { + builder.add(elementType.cast(is.readObject())); + } + return builder.build(); + } catch (IOException | ClassNotFoundException | ClassCastException e) { + throw new IllegalArgumentException("Failed to deserialize: " + Arrays.toString(bytes), e); + } + } +} diff --git a/util/src/main/java/google/registry/util/SerializeUtils.java b/util/src/main/java/google/registry/util/SerializeUtils.java index 6ed3cf1a8..130ad3212 100644 --- a/util/src/main/java/google/registry/util/SerializeUtils.java +++ b/util/src/main/java/google/registry/util/SerializeUtils.java @@ -74,10 +74,20 @@ public final class SerializeUtils { private SerializeUtils() {} + /** Encodes a byte array as a URL-safe string. */ + public static String encodeBase64(byte[] bytes) { + return Base64.encodeBase64URLSafeString(bytes); + } + + /** Turns a string encoded by {@link #encodeBase64} back into a byte array. */ + public static byte[] decodeBase64(String objectString) { + return Base64.decodeBase64(objectString); + } + /** Turns an object into an encoded string that can be used safely as a URI query parameter. */ public static String stringify(Serializable object) { checkNotNull(object, "Object cannot be null"); - return Base64.encodeBase64URLSafeString(SerializeUtils.serialize(object)); + return encodeBase64(SerializeUtils.serialize(object)); } /** Turns a string encoded by stringify() into an object. */ @@ -86,6 +96,6 @@ public final class SerializeUtils { checkNotNull(type, "Class type is not specified"); checkNotNull(objectString, "Object string cannot be null"); - return SerializeUtils.deserialize(type, Base64.decodeBase64(objectString)); + return SerializeUtils.deserialize(type, decodeBase64(objectString)); } } diff --git a/util/src/test/java/google/registry/util/SafeObjectInputStreamTest.java b/util/src/test/java/google/registry/util/SafeObjectInputStreamTest.java new file mode 100644 index 000000000..66e606438 --- /dev/null +++ b/util/src/test/java/google/registry/util/SafeObjectInputStreamTest.java @@ -0,0 +1,126 @@ +// Copyright 2023 The Nomulus Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package google.registry.util; + +import static com.google.common.collect.Lists.newArrayList; +import static com.google.common.truth.Truth.assertThat; +import static google.registry.util.SerializeUtils.serialize; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import com.google.common.base.Objects; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Maps; +import java.io.ByteArrayInputStream; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.HashMap; +import org.joda.time.Duration; +import org.junit.jupiter.api.Test; + +/** Unit tests for {@link SafeObjectInputStream}. */ +public class SafeObjectInputStreamTest { + + @Test + void javaUnitarySuccess() throws Exception { + String orig = "some string"; + try (SafeObjectInputStream sois = + new SafeObjectInputStream(new ByteArrayInputStream(serialize(orig)))) { + assertThat(sois.readObject()).isEqualTo(orig); + } + } + + @Test + void javaCollectionFailure() throws Exception { + ArrayList orig = newArrayList("a"); + try (SafeObjectInputStream sois = + new SafeObjectInputStream(new ByteArrayInputStream(serialize(orig)))) { + assertThrows(ClassNotFoundException.class, () -> sois.readObject()); + } + } + + @Test + void javaMapFailure() throws Exception { + HashMap orig = Maps.newHashMap(); + try (SafeObjectInputStream sois = + new SafeObjectInputStream(new ByteArrayInputStream(serialize(orig)))) { + assertThrows(ClassNotFoundException.class, () -> sois.readObject()); + } + } + + @Test + void javaArrayFailure() throws Exception { + int[] orig = new int[] {1}; + try (SafeObjectInputStream sois = + new SafeObjectInputStream(new ByteArrayInputStream(serialize(orig)))) { + // For array, the parent class converts ClassNotFoundException in an undocumented way. Safer + // to catch Exception than the one thrown by the current JVM. + assertThrows(Exception.class, () -> sois.readObject()); + } + } + + @Test + void nonJavaNonNomulusUnitaryFailure() throws Exception { + Serializable orig = Duration.millis(1); + try (SafeObjectInputStream sois = + new SafeObjectInputStream(new ByteArrayInputStream(serialize(orig)))) { + assertThrows(ClassNotFoundException.class, () -> sois.readObject()); + } + } + + @Test + void nonJavaCollectionFailure() throws Exception { + ImmutableList orig = ImmutableList.of("a"); + try (SafeObjectInputStream sois = + new SafeObjectInputStream(new ByteArrayInputStream(serialize(orig)))) { + assertThrows(ClassNotFoundException.class, () -> sois.readObject()); + } + } + + @Test + void nomulusEntitySuccess() throws Exception { + NomulusEntity orig = new NomulusEntity(1); + byte[] serialized = serialize(orig); + try (SafeObjectInputStream sois = + new SafeObjectInputStream(new ByteArrayInputStream(serialized))) { + Object deserialized = sois.readObject(); + assertThat(deserialized).isEqualTo(orig); + } + } + + static class NomulusEntity implements Serializable { + Integer value; + + NomulusEntity(int value) { + this.value = value; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof NomulusEntity)) { + return false; + } + NomulusEntity that = (NomulusEntity) o; + return Objects.equal(value, that.value); + } + + @Override + public int hashCode() { + return Objects.hashCode(value); + } + } +} diff --git a/util/src/test/java/google/registry/util/SafeSerializationUtilsTest.java b/util/src/test/java/google/registry/util/SafeSerializationUtilsTest.java new file mode 100644 index 000000000..47fc93f89 --- /dev/null +++ b/util/src/test/java/google/registry/util/SafeSerializationUtilsTest.java @@ -0,0 +1,110 @@ +// Copyright 2023 The Nomulus Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package google.registry.util; + +import static com.google.common.collect.Lists.newArrayList; +import static com.google.common.truth.Truth.assertThat; +import static google.registry.util.SafeSerializationUtils.safeDeserialize; +import static google.registry.util.SafeSerializationUtils.safeDeserializeCollection; +import static google.registry.util.SafeSerializationUtils.serializeCollection; +import static google.registry.util.SerializeUtils.serialize; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import java.util.ArrayList; +import java.util.Arrays; +import org.junit.jupiter.api.Test; + +/** Unit tests for {@link SafeSerializationUtils}. */ +public class SafeSerializationUtilsTest { + + @Test + void deserialize_array_failure() { + assertThat( + assertThrows( + IllegalArgumentException.class, () -> safeDeserialize(serialize(new byte[0])))) + .hasMessageThat() + .contains("Failed to deserialize:"); + } + + @Test + void deserialize_null_success() { + assertThat(safeDeserialize(serialize(null))).isNull(); + } + + @Test + void deserialize_map_failure() { + assertThat( + assertThrows( + IllegalArgumentException.class, + () -> safeDeserialize(serialize(ImmutableMap.of())))) + .hasMessageThat() + .contains("Failed to deserialize:"); + } + + @Test + void serializeDeserialize_null_success() { + assertThat(safeDeserialize(null)).isNull(); + } + + @Test + void serializeDeserialize_notCollection_success() { + Integer orig = 1; + assertThat(safeDeserialize(serialize(orig))).isEqualTo(orig); + } + + @Test + void serializeDeserializeCollection_success() { + ArrayList orig = newArrayList(1, 2, 3); + ImmutableList deserialized = + safeDeserializeCollection(Integer.class, serializeCollection(orig)); + assertThat(deserialized).isEqualTo(orig); + } + + @Test + void serializeDeserializeCollection_withMaxSize_success() { + Integer[] array = new Integer[SafeSerializationUtils.MAX_COLLECTION_SIZE]; + Arrays.fill(array, 1); + ArrayList orig = newArrayList(array); + assertThat(safeDeserializeCollection(Integer.class, serializeCollection(orig))).isEqualTo(orig); + } + + @Test + void serializeDeserializeCollection_tooLarge_Failure() { + Integer[] array = new Integer[SafeSerializationUtils.MAX_COLLECTION_SIZE + 1]; + Arrays.fill(array, 1); + ArrayList orig = newArrayList(array); + assertThat( + assertThrows( + IllegalArgumentException.class, + () -> safeDeserializeCollection(Integer.class, serializeCollection(orig)))) + .hasMessageThat() + .contains("Too many elements"); + } + + @Test + void serializeDeserializeCollection_wrong_elementType_success() { + ArrayList orig = newArrayList(1, 2, 3); + assertThrows( + IllegalArgumentException.class, + () -> safeDeserializeCollection(Long.class, serializeCollection(orig))); + } + + @Test + void deserializeCollection_null_failure() { + assertThrows(NullPointerException.class, () -> safeDeserializeCollection(Integer.class, null)); + } +}