diff --git a/core/src/main/java/google/registry/batch/WipeOutCloudSqlAction.java b/core/src/main/java/google/registry/batch/WipeOutCloudSqlAction.java index dc12eadc0..70ad115ff 100644 --- a/core/src/main/java/google/registry/batch/WipeOutCloudSqlAction.java +++ b/core/src/main/java/google/registry/batch/WipeOutCloudSqlAction.java @@ -19,6 +19,7 @@ import static javax.servlet.http.HttpServletResponse.SC_FORBIDDEN; import static javax.servlet.http.HttpServletResponse.SC_INTERNAL_SERVER_ERROR; import static javax.servlet.http.HttpServletResponse.SC_OK; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.common.flogger.FluentLogger; import google.registry.config.RegistryConfig.Config; @@ -28,10 +29,11 @@ import google.registry.request.Response; import google.registry.request.auth.Auth; import google.registry.util.Retrier; import java.sql.Connection; +import java.sql.ResultSet; +import java.sql.SQLException; import java.sql.Statement; import java.util.function.Supplier; import javax.inject.Inject; -import org.flywaydb.core.api.FlywayException; /** * Wipes out all Cloud SQL data in a Nomulus GCP environment. @@ -80,13 +82,13 @@ public class WipeOutCloudSqlAction implements Runnable { try { retrier.callWithRetry( () -> { - try (Connection conn = connectionSupplier.get(); - Statement statement = conn.createStatement()) { - statement.execute("drop owned by schema_deployer;"); + try (Connection conn = connectionSupplier.get()) { + dropAllTables(conn, listTables(conn)); + dropAllSequences(conn, listSequences(conn)); } return null; }, - e -> !(e instanceof FlywayException)); + e -> !(e instanceof SQLException)); response.setStatus(SC_OK); response.setPayload("Wiped out Cloud SQL in " + projectId); } catch (RuntimeException e) { @@ -95,4 +97,69 @@ public class WipeOutCloudSqlAction implements Runnable { response.setPayload("Failed to wipe out Cloud SQL in " + projectId); } } + + /** Returns a list of all tables in the public schema of a Postgresql database. */ + static ImmutableList listTables(Connection connection) throws SQLException { + try (ResultSet resultSet = + connection.getMetaData().getTables(null, null, null, new String[] {"TABLE"})) { + ImmutableList.Builder tables = new ImmutableList.Builder<>(); + while (resultSet.next()) { + String schema = resultSet.getString("TABLE_SCHEM"); + if (schema == null || !schema.equalsIgnoreCase("public")) { + continue; + } + String tableName = resultSet.getString("TABLE_NAME"); + tables.add("public.\"" + tableName + "\""); + } + return tables.build(); + } + } + + static void dropAllTables(Connection conn, ImmutableList tables) throws SQLException { + if (tables.isEmpty()) { + return; + } + + try (Statement statement = conn.createStatement()) { + for (String table : tables) { + statement.addBatch(String.format("DROP TABLE IF EXISTS %s CASCADE;", table)); + } + for (int code : statement.executeBatch()) { + if (code == Statement.EXECUTE_FAILED) { + throw new RuntimeException("Failed to drop some tables. Please check."); + } + } + } + } + + /** Returns a list of all sequences in a Postgresql database. */ + static ImmutableList listSequences(Connection conn) throws SQLException { + try (Statement statement = conn.createStatement(); + ResultSet resultSet = + statement.executeQuery("SELECT c.relname FROM pg_class c WHERE c.relkind = 'S';")) { + ImmutableList.Builder sequences = new ImmutableList.Builder<>(); + while (resultSet.next()) { + sequences.add('\"' + resultSet.getString(1) + '\"'); + } + return sequences.build(); + } + } + + static void dropAllSequences(Connection conn, ImmutableList sequences) + throws SQLException { + if (sequences.isEmpty()) { + return; + } + + try (Statement statement = conn.createStatement()) { + for (String sequence : sequences) { + statement.addBatch(String.format("DROP SEQUENCE IF EXISTS %s CASCADE;", sequence)); + } + for (int code : statement.executeBatch()) { + if (code == Statement.EXECUTE_FAILED) { + throw new RuntimeException("Failed to drop some sequences. Please check."); + } + } + } + } } diff --git a/core/src/test/java/google/registry/batch/WipeOutCloudSqlActionTest.java b/core/src/test/java/google/registry/batch/WipeOutCloudSqlActionTest.java index 31f4fe940..484bbff09 100644 --- a/core/src/test/java/google/registry/batch/WipeOutCloudSqlActionTest.java +++ b/core/src/test/java/google/registry/batch/WipeOutCloudSqlActionTest.java @@ -19,6 +19,7 @@ import static javax.servlet.http.HttpServletResponse.SC_FORBIDDEN; import static javax.servlet.http.HttpServletResponse.SC_INTERNAL_SERVER_ERROR; import static javax.servlet.http.HttpServletResponse.SC_OK; import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.nullable; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.lenient; import static org.mockito.Mockito.times; @@ -32,8 +33,10 @@ import google.registry.testing.FakeResponse; import google.registry.testing.FakeSleeper; import google.registry.util.Retrier; import java.sql.Connection; +import java.sql.DatabaseMetaData; +import java.sql.ResultSet; +import java.sql.SQLException; import java.sql.Statement; -import org.flywaydb.core.api.FlywayException; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; @@ -46,6 +49,8 @@ public class WipeOutCloudSqlActionTest { @Mock private Statement stmt; @Mock private Connection conn; + @Mock private DatabaseMetaData metaData; + @Mock private ResultSet resultSet; private FakeResponse response = new FakeResponse(); private Retrier retrier = new Retrier(new FakeSleeper(new FakeClock()), 2); @@ -53,6 +58,17 @@ public class WipeOutCloudSqlActionTest { @BeforeEach void beforeEach() throws Exception { lenient().when(conn.createStatement()).thenReturn(stmt); + lenient().when(conn.getMetaData()).thenReturn(metaData); + lenient() + .when( + metaData.getTables( + nullable(String.class), + nullable(String.class), + nullable(String.class), + nullable(String[].class))) + .thenReturn(resultSet); + lenient().when(stmt.executeQuery(anyString())).thenReturn(resultSet); + lenient().when(resultSet.next()).thenReturn(false); } @Test @@ -61,7 +77,7 @@ public class WipeOutCloudSqlActionTest { new WipeOutCloudSqlAction("domain-registry-qa", () -> conn, response, retrier); action.run(); assertThat(response.getStatus()).isEqualTo(SC_OK); - verify(stmt, times(1)).execute(anyString()); + verify(stmt, times(1)).executeQuery(anyString()); verify(stmt, times(1)).close(); verifyNoMoreInteractions(stmt); } @@ -77,25 +93,23 @@ public class WipeOutCloudSqlActionTest { @Test void run_nonRetrieableFailure() throws Exception { - doThrow(new FlywayException()).when(stmt).execute(anyString()); + doThrow(new SQLException()).when(conn).getMetaData(); WipeOutCloudSqlAction action = new WipeOutCloudSqlAction("domain-registry-qa", () -> conn, response, retrier); action.run(); assertThat(response.getStatus()).isEqualTo(SC_INTERNAL_SERVER_ERROR); - verify(stmt, times(1)).execute(anyString()); - verify(stmt, times(1)).close(); - verifyNoMoreInteractions(stmt); + verifyNoInteractions(stmt); } @Test void run_retrieableFailure() throws Exception { - when(stmt.execute(anyString())).thenThrow(new RuntimeException()).thenReturn(true); + when(conn.getMetaData()).thenThrow(new RuntimeException()).thenReturn(metaData); WipeOutCloudSqlAction action = new WipeOutCloudSqlAction("domain-registry-qa", () -> conn, response, retrier); action.run(); assertThat(response.getStatus()).isEqualTo(SC_OK); - verify(stmt, times(2)).execute(anyString()); - verify(stmt, times(2)).close(); + verify(stmt, times(1)).executeQuery(anyString()); + verify(stmt, times(1)).close(); verifyNoMoreInteractions(stmt); } } diff --git a/core/src/test/java/google/registry/batch/WipeOutCloudSqlIntegrationTest.java b/core/src/test/java/google/registry/batch/WipeOutCloudSqlIntegrationTest.java new file mode 100644 index 000000000..a3d785db5 --- /dev/null +++ b/core/src/test/java/google/registry/batch/WipeOutCloudSqlIntegrationTest.java @@ -0,0 +1,89 @@ +// Copyright 2021 The Nomulus Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package google.registry.batch; + +import static com.google.common.truth.Truth.assertThat; + +import com.google.common.collect.ImmutableList; +import google.registry.persistence.NomulusPostgreSql; +import java.sql.Connection; +import java.sql.Statement; +import java.util.Properties; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.testcontainers.containers.PostgreSQLContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + +/** Tests the database wipeout mechanism used by {@link WipeOutCloudSqlAction}. */ +@Testcontainers +public class WipeOutCloudSqlIntegrationTest { + + @Container + PostgreSQLContainer container = new PostgreSQLContainer(NomulusPostgreSql.getDockerTag()); + + private Connection getJdbcConnection() throws Exception { + Properties properties = new Properties(); + properties.setProperty("user", container.getUsername()); + properties.setProperty("password", container.getPassword()); + return container.getJdbcDriverInstance().connect(container.getJdbcUrl(), properties); + } + + @BeforeEach + void beforeEach() throws Exception { + try (Connection conn = getJdbcConnection(); + Statement statement = conn.createStatement()) { + statement.addBatch("CREATE TABLE public.\"Domain\" (value int);"); + statement.addBatch("CREATE SEQUENCE public.\"Domain_seq\""); + statement.executeBatch(); + } + } + + @Test + void listTables() throws Exception { + try (Connection conn = getJdbcConnection()) { + ImmutableList tables = WipeOutCloudSqlAction.listTables(conn); + assertThat(tables).containsExactly("public.\"Domain\""); + } + } + + @Test + void dropAllTables() throws Exception { + try (Connection conn = getJdbcConnection()) { + ImmutableList tables = WipeOutCloudSqlAction.listTables(conn); + assertThat(tables).isNotEmpty(); + WipeOutCloudSqlAction.dropAllTables(conn, tables); + assertThat(WipeOutCloudSqlAction.listTables(conn)).isEmpty(); + } + } + + @Test + void listAllSequences() throws Exception { + try (Connection conn = getJdbcConnection()) { + ImmutableList sequences = WipeOutCloudSqlAction.listSequences(conn); + assertThat(sequences).containsExactly("\"Domain_seq\""); + } + } + + @Test + void dropAllSequences() throws Exception { + try (Connection conn = getJdbcConnection()) { + ImmutableList sequences = WipeOutCloudSqlAction.listSequences(conn); + assertThat(sequences).isNotEmpty(); + WipeOutCloudSqlAction.dropAllSequences(conn, sequences); + assertThat(WipeOutCloudSqlAction.listSequences(conn)).isEmpty(); + } + } +}