diff --git a/java/google/registry/keyring/kms/KmsConnectionImpl.java b/java/google/registry/keyring/kms/KmsConnectionImpl.java index 7db18aeaf..14a23625a 100644 --- a/java/google/registry/keyring/kms/KmsConnectionImpl.java +++ b/java/google/registry/keyring/kms/KmsConnectionImpl.java @@ -27,7 +27,9 @@ import com.google.api.services.cloudkms.v1beta1.model.KeyRing; import com.google.api.services.cloudkms.v1beta1.model.UpdateCryptoKeyPrimaryVersionRequest; import google.registry.config.RegistryConfig.Config; import google.registry.keyring.api.KeyringException; +import google.registry.util.Retrier; import java.io.IOException; +import java.util.concurrent.Callable; import javax.inject.Inject; /** The {@link KmsConnection} which talks to Cloud KMS. */ @@ -41,14 +43,17 @@ class KmsConnectionImpl implements KmsConnection { private final CloudKMS kms; private final String kmsKeyRingName; private final String projectId; + private final Retrier retrier; @Inject KmsConnectionImpl( @Config("cloudKmsProjectId") String projectId, @Config("cloudKmsKeyRing") String kmsKeyringName, + Retrier retrier, CloudKMS kms) { this.projectId = projectId; this.kmsKeyRingName = kmsKeyringName; + this.retrier = retrier; this.kms = kms; } @@ -129,23 +134,34 @@ class KmsConnectionImpl implements KmsConnection { } @Override - public byte[] decrypt(String cryptoKeyName, String encodedCiphertext) { + public byte[] decrypt(final String cryptoKeyName, final String encodedCiphertext) { try { - return kms.projects() - .locations() - .keyRings() - .cryptoKeys() - .decrypt( - getCryptoKeyName(projectId, kmsKeyRingName, cryptoKeyName), - new DecryptRequest().setCiphertext(encodedCiphertext)) - .execute() - .decodePlaintext(); - } catch (IOException e) { + return retrier.callWithRetry( + new Callable() { + @Override + public byte[] call() throws IOException { + return attemptDecrypt(cryptoKeyName, encodedCiphertext); + } + }, + IOException.class); + } catch (RuntimeException e) { throw new KeyringException( String.format("CloudKMS decrypt operation failed for secret %s", cryptoKeyName), e); } } + private byte[] attemptDecrypt(String cryptoKeyName, String encodedCiphertext) throws IOException{ + return kms.projects() + .locations() + .keyRings() + .cryptoKeys() + .decrypt( + getCryptoKeyName(projectId, kmsKeyRingName, cryptoKeyName), + new DecryptRequest().setCiphertext(encodedCiphertext)) + .execute() + .decodePlaintext(); + } + private static String getLocationName(String projectId) { return String.format(KMS_LOCATION_FORMAT, projectId); } diff --git a/javatests/google/registry/keyring/kms/BUILD b/javatests/google/registry/keyring/kms/BUILD index 01d0baef2..5df9acd5d 100644 --- a/javatests/google/registry/keyring/kms/BUILD +++ b/javatests/google/registry/keyring/kms/BUILD @@ -18,6 +18,7 @@ java_library( "//java/google/registry/keyring/api", "//java/google/registry/keyring/kms", "//java/google/registry/model", + "//java/google/registry/util", "//javatests/google/registry/testing", "//third_party/java/objectify:objectify-v4_1", "@com_google_api_client", diff --git a/javatests/google/registry/keyring/kms/KmsConnectionImplTest.java b/javatests/google/registry/keyring/kms/KmsConnectionImplTest.java index 652ba281d..6f034e71a 100644 --- a/javatests/google/registry/keyring/kms/KmsConnectionImplTest.java +++ b/javatests/google/registry/keyring/kms/KmsConnectionImplTest.java @@ -34,6 +34,9 @@ import com.google.api.services.cloudkms.v1beta1.model.EncryptRequest; import com.google.api.services.cloudkms.v1beta1.model.EncryptResponse; import com.google.api.services.cloudkms.v1beta1.model.KeyRing; import com.google.api.services.cloudkms.v1beta1.model.UpdateCryptoKeyPrimaryVersionRequest; +import google.registry.testing.FakeClock; +import google.registry.testing.FakeSleeper; +import google.registry.util.Retrier; import java.io.ByteArrayInputStream; import org.junit.Before; import org.junit.Test; @@ -69,6 +72,8 @@ public class KmsConnectionImplTest { @Mock private CloudKMS.Projects.Locations.KeyRings.CryptoKeys.Encrypt kmsCryptoKeysEncrypt; @Mock private CloudKMS.Projects.Locations.KeyRings.CryptoKeys.Decrypt kmsCryptoKeysDecrypt; + private final Retrier retrier = new Retrier(new FakeSleeper(new FakeClock()), 3); + @Captor private ArgumentCaptor keyRing; @Captor private ArgumentCaptor cryptoKey; @Captor private ArgumentCaptor cryptoKeyVersion; @@ -116,7 +121,7 @@ public class KmsConnectionImplTest { public void test_encrypt_createsKeyRingIfNotFound() throws Exception { when(kmsKeyRingsGet.execute()).thenThrow(createNotFoundException()); - new KmsConnectionImpl("foo", "bar", kms).encrypt("key", "moo".getBytes(UTF_8)); + new KmsConnectionImpl("foo", "bar", retrier, kms).encrypt("key", "moo".getBytes(UTF_8)); verify(kmsKeyRings).create(locationName.capture(), keyRing.capture()); assertThat(locationName.getValue()).isEqualTo("projects/foo/locations/global"); @@ -135,7 +140,7 @@ public class KmsConnectionImplTest { public void test_encrypt_newCryptoKey() throws Exception { when(kmsCryptoKeysGet.execute()).thenThrow(createNotFoundException()); - new KmsConnectionImpl("foo", "bar", kms).encrypt("key", "moo".getBytes(UTF_8)); + new KmsConnectionImpl("foo", "bar", retrier, kms).encrypt("key", "moo".getBytes(UTF_8)); verify(kmsCryptoKeys).create(keyRingName.capture(), cryptoKey.capture()); assertThat(keyRingName.getValue()).isEqualTo("projects/foo/locations/global/keyRings/bar"); @@ -154,7 +159,7 @@ public class KmsConnectionImplTest { @Test public void test_encrypt() throws Exception { - new KmsConnectionImpl("foo", "bar", kms).encrypt("key", "moo".getBytes(UTF_8)); + new KmsConnectionImpl("foo", "bar", retrier, kms).encrypt("key", "moo".getBytes(UTF_8)); verify(kmsCryptoKeyVersions).create(cryptoKeyName.capture(), cryptoKeyVersion.capture()); @@ -182,7 +187,7 @@ public class KmsConnectionImplTest { when(kmsCryptoKeysDecrypt.execute()) .thenReturn(new DecryptResponse().encodePlaintext("moo".getBytes(UTF_8))); - byte[] plaintext = new KmsConnectionImpl("foo", "bar", kms).decrypt("key", "blah"); + byte[] plaintext = new KmsConnectionImpl("foo", "bar", retrier, kms).decrypt("key", "blah"); verify(kmsCryptoKeys).decrypt(cryptoKeyName.capture(), decryptRequest.capture()); assertThat(cryptoKeyName.getValue())