diff --git a/java/google/registry/tools/CurlCommand.java b/java/google/registry/tools/CurlCommand.java index 3804b8d48..9e0e744e5 100644 --- a/java/google/registry/tools/CurlCommand.java +++ b/java/google/registry/tools/CurlCommand.java @@ -14,11 +14,15 @@ package google.registry.tools; +import static com.google.common.base.Preconditions.checkArgument; import static java.nio.charset.StandardCharsets.UTF_8; +import com.beust.jcommander.IStringConverter; import com.beust.jcommander.Parameter; import com.beust.jcommander.Parameters; +import com.beust.jcommander.converters.IParameterSplitter; import com.google.common.base.Joiner; +import com.google.common.base.Splitter; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.net.MediaType; @@ -50,6 +54,7 @@ class CurlCommand implements CommandWithConnection { @Parameter( names = {"-t", "--content-type"}, + converter = MediaTypeConverter.class, description = "Media type of the request body (for a POST request. Must be combined with --body)") private MediaType mimeType = MediaType.PLAIN_TEXT_UTF_8; @@ -58,6 +63,7 @@ class CurlCommand implements CommandWithConnection { // GET...) @Parameter( names = {"-d", "--data"}, + splitter = NoSplittingSplitter.class, description = "Body for a post request. If specified, a POST request is sent. If " + "absent, a GET request is sent.") @@ -95,4 +101,20 @@ class CurlCommand implements CommandWithConnection { Joiner.on("&").join(data).getBytes(UTF_8)); System.out.println(response); } + + public static class MediaTypeConverter implements IStringConverter { + @Override + public MediaType convert(String mediaType) { + List parts = Splitter.on('/').splitToList(mediaType); + checkArgument(parts.size() == 2, "invalid MediaType '%s'", mediaType); + return MediaType.create(parts.get(0), parts.get(1)).withCharset(UTF_8); + } + } + + public static class NoSplittingSplitter implements IParameterSplitter { + @Override + public List split(String value) { + return ImmutableList.of(value); + } + } } diff --git a/javatests/google/registry/tools/CurlCommandTest.java b/javatests/google/registry/tools/CurlCommandTest.java index 84144524d..4275620c1 100644 --- a/javatests/google/registry/tools/CurlCommandTest.java +++ b/javatests/google/registry/tools/CurlCommandTest.java @@ -79,6 +79,37 @@ public class CurlCommandTest extends CommandTestCase { eq("some data".getBytes(UTF_8))); } + @Test + public void testPostInvocation_withContentType() throws Exception { + runCommand( + "--path=/foo/bar?a=1&b=2", + "--data=some data", + "--service=DEFAULT", + "--content-type=application/json"); + verify(connection).withService(DEFAULT); + verifyNoMoreInteractions(connection); + verify(connectionForService) + .sendPostRequest( + eq("/foo/bar?a=1&b=2"), + eq(ImmutableMap.of()), + eq(MediaType.JSON_UTF_8), + eq("some data".getBytes(UTF_8))); + } + + @Test + public void testPostInvocation_badContentType() throws Exception { + assertThrows( + IllegalArgumentException.class, + () -> + runCommand( + "--path=/foo/bar?a=1&b=2", + "--data=some data", + "--service=DEFAULT", + "--content-type=bad")); + verifyNoMoreInteractions(connection); + verifyNoMoreInteractions(connectionForService); + } + @Test public void testMultiDataPost() throws Exception { runCommand( @@ -93,6 +124,20 @@ public class CurlCommandTest extends CommandTestCase { eq("first=100&second=200".getBytes(UTF_8))); } + @Test + public void testDataDoesntSplit() throws Exception { + runCommand( + "--path=/foo/bar?a=1&b=2", "--data=one,two", "--service=PUBAPI"); + verify(connection).withService(PUBAPI); + verifyNoMoreInteractions(connection); + verify(connectionForService) + .sendPostRequest( + eq("/foo/bar?a=1&b=2"), + eq(ImmutableMap.of()), + eq(MediaType.PLAIN_TEXT_UTF_8), + eq("one,two".getBytes(UTF_8))); + } + @Test public void testExplicitPostInvocation() throws Exception { runCommand("--path=/foo/bar?a=1&b=2", "--request=POST", "--service=TOOLS");