diff --git a/java/google/registry/util/UrlFetchUtils.java b/java/google/registry/util/UrlFetchUtils.java index c7064c3fb..e40637d72 100644 --- a/java/google/registry/util/UrlFetchUtils.java +++ b/java/google/registry/util/UrlFetchUtils.java @@ -14,7 +14,7 @@ package google.registry.util; -import static com.google.common.io.BaseEncoding.base32; +import static com.google.common.base.Preconditions.checkState; import static com.google.common.io.BaseEncoding.base64; import static com.google.common.net.HttpHeaders.AUTHORIZATION; import static com.google.common.net.HttpHeaders.CONTENT_DISPOSITION; @@ -28,16 +28,15 @@ import com.google.appengine.api.urlfetch.HTTPRequest; import com.google.appengine.api.urlfetch.HTTPResponse; import com.google.common.base.Ascii; import com.google.common.base.Optional; +import com.google.common.base.Strings; import com.google.common.net.MediaType; -import java.security.NoSuchAlgorithmException; -import java.security.ProviderException; -import java.security.SecureRandom; +import java.util.Random; /** Helper methods for the App Engine URL fetch service. */ public final class UrlFetchUtils { @NonFinalForTesting - private static SecureRandom secureRandom = initSecureRandom(); + private static Random random = new Random(); /** Returns value of first header matching {@code name}. */ public static Optional getHeaderFirst(HTTPResponse rsp, String name) { @@ -66,9 +65,12 @@ public final class UrlFetchUtils { * * @see RFC2388 - Returning Values from Forms */ - public static void setPayloadMultipart( - HTTPRequest request, String name, String filename, MediaType contentType, T data) { + public static void setPayloadMultipart( + HTTPRequest request, String name, String filename, MediaType contentType, String data) { String boundary = createMultipartBoundary(); + checkState( + !data.contains(boundary), + "Multipart data contains autogenerated boundary: %s", boundary); StringBuilder multipart = new StringBuilder(); multipart.append(format("--%s\r\n", boundary)); multipart.append(format("%s: form-data; name=\"%s\"; filename=\"%s\"\r\n", @@ -79,23 +81,19 @@ public final class UrlFetchUtils { multipart.append("\r\n"); multipart.append(format("--%s--", boundary)); byte[] payload = multipart.toString().getBytes(UTF_8); - request.addHeader(new HTTPHeader(CONTENT_TYPE, "multipart/form-data; boundary=" + boundary)); + request.addHeader( + new HTTPHeader(CONTENT_TYPE, format("multipart/form-data; boundary=\"%s\"", boundary))); request.addHeader(new HTTPHeader(CONTENT_LENGTH, Integer.toString(payload.length))); request.setPayload(payload); } private static String createMultipartBoundary() { - byte[] rand = new byte[5]; // Avoid base32 padding since `5 * 8 % log2(32) == 0` - secureRandom.nextBytes(rand); - return "------------------------------" + base32().encode(rand); - } - - private static SecureRandom initSecureRandom() { - try { - return SecureRandom.getInstance("NativePRNG"); - } catch (NoSuchAlgorithmException e) { - throw new ProviderException(e); - } + // Generate 192 random bits (24 bytes) to produce 192/log_2(64) = 192/6 = 32 base64 digits. + byte[] rand = new byte[24]; + random.nextBytes(rand); + // Boundary strings can be up to 70 characters long, so use 30 hyphens plus 32 random digits. + // See https://tools.ietf.org/html/rfc2046#section-5.1.1 + return Strings.repeat("-", 30) + base64().encode(rand); } /** Sets the HTTP Basic Authentication header on an {@link HTTPRequest}. */ diff --git a/javatests/google/registry/util/UrlFetchUtilsTest.java b/javatests/google/registry/util/UrlFetchUtilsTest.java index cb4b337e0..21af3d01a 100644 --- a/javatests/google/registry/util/UrlFetchUtilsTest.java +++ b/javatests/google/registry/util/UrlFetchUtilsTest.java @@ -17,27 +17,30 @@ package google.registry.util; import static com.google.common.net.HttpHeaders.CONTENT_LENGTH; import static com.google.common.net.HttpHeaders.CONTENT_TYPE; import static com.google.common.net.MediaType.CSV_UTF_8; +import static com.google.common.truth.Truth.assertThat; import static google.registry.util.UrlFetchUtils.setPayloadMultipart; import static java.nio.charset.StandardCharsets.UTF_8; import static org.mockito.Matchers.any; -import static org.mockito.Matchers.argThat; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; import com.google.appengine.api.urlfetch.HTTPHeader; import com.google.appengine.api.urlfetch.HTTPRequest; import google.registry.testing.AppEngineRule; +import google.registry.testing.ExceptionRule; import google.registry.testing.InjectRule; -import java.security.SecureRandom; import java.util.Arrays; +import java.util.List; +import java.util.Random; import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; -import org.mockito.ArgumentMatcher; +import org.mockito.ArgumentCaptor; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; @@ -49,55 +52,57 @@ public class UrlFetchUtilsTest { public final AppEngineRule appEngine = AppEngineRule.builder() .build(); + @Rule + public final ExceptionRule thrown = new ExceptionRule(); + @Rule public final InjectRule inject = new InjectRule(); @Before public void setupRandomZeroes() throws Exception { - SecureRandom secureRandom = mock(SecureRandom.class); - inject.setStaticField(UrlFetchUtils.class, "secureRandom", secureRandom); + Random random = mock(Random.class); + inject.setStaticField(UrlFetchUtils.class, "random", random); doAnswer(new Answer() { @Override public Void answer(InvocationOnMock info) throws Throwable { - byte[] bytes = (byte[]) info.getArguments()[0]; - Arrays.fill(bytes, (byte) 0); + Arrays.fill((byte[]) info.getArguments()[0], (byte) 0); return null; - }}).when(secureRandom).nextBytes(any(byte[].class)); + }}).when(random).nextBytes(any(byte[].class)); } @Test public void testSetPayloadMultipart() throws Exception { - String payload = "--------------------------------AAAAAAAA\r\n" + HTTPRequest request = mock(HTTPRequest.class); + setPayloadMultipart( + request, "lol", "cat", CSV_UTF_8, "The nice people at the store say hello. ヘ(◕。◕ヘ)"); + ArgumentCaptor headerCaptor = ArgumentCaptor.forClass(HTTPHeader.class); + verify(request, times(2)).addHeader(headerCaptor.capture()); + List addedHeaders = headerCaptor.getAllValues(); + assertThat(addedHeaders.get(0).getName()).isEqualTo(CONTENT_TYPE); + assertThat(addedHeaders.get(0).getValue()) + .isEqualTo( + "multipart/form-data; " + + "boundary=\"------------------------------AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\""); + assertThat(addedHeaders.get(1).getName()).isEqualTo(CONTENT_LENGTH); + assertThat(addedHeaders.get(1).getValue()).isEqualTo("292"); + String payload = "--------------------------------AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\r\n" + "Content-Disposition: form-data; name=\"lol\"; filename=\"cat\"\r\n" + "Content-Type: text/csv; charset=utf-8\r\n" + "\r\n" + "The nice people at the store say hello. ヘ(◕。◕ヘ)\r\n" - + "--------------------------------AAAAAAAA--"; - HTTPRequest request = mock(HTTPRequest.class); - setPayloadMultipart( - request, "lol", "cat", CSV_UTF_8, "The nice people at the store say hello. ヘ(◕。◕ヘ)"); - verify(request).addHeader(argThat(new HTTPHeaderMatcher( - CONTENT_TYPE, "multipart/form-data; boundary=------------------------------AAAAAAAA"))); - verify(request).addHeader(argThat(new HTTPHeaderMatcher(CONTENT_LENGTH, "244"))); + + "--------------------------------AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA--"; verify(request).setPayload(payload.getBytes(UTF_8)); verifyNoMoreInteractions(request); } - /** Mockito matcher for {@link HTTPHeader}. */ - public static class HTTPHeaderMatcher extends ArgumentMatcher { - private final String name; - private final String value; - - public HTTPHeaderMatcher(String name, String value) { - this.name = name; - this.value = value; - } - - @Override - public boolean matches(Object arg) { - HTTPHeader header = (HTTPHeader) arg; - return name.equals(header.getName()) - && value.equals(header.getValue()); - } + @Test + public void testSetPayloadMultipart_boundaryInPayload() throws Exception { + HTTPRequest request = mock(HTTPRequest.class); + String payload = "I screamed------------------------------AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAHHH"; + thrown.expect( + IllegalStateException.class, + "Multipart data contains autogenerated boundary: " + + "------------------------------AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"); + setPayloadMultipart(request, "lol", "cat", CSV_UTF_8, payload); } }