diff --git a/core/src/main/java/google/registry/tools/LevelDbLogReader.java b/core/src/main/java/google/registry/tools/LevelDbLogReader.java index 8f6b26a7b..c88e55c82 100644 --- a/core/src/main/java/google/registry/tools/LevelDbLogReader.java +++ b/core/src/main/java/google/registry/tools/LevelDbLogReader.java @@ -34,6 +34,8 @@ import java.util.Optional; /** * Iterator that incrementally parses binary data in LevelDb format into records. * + *

The input source is automatically closed when all data have been read. + * *

See log_format.md for the * leveldb log format specification. @@ -92,11 +94,12 @@ public final class LevelDbLogReader implements Iterator { */ // TODO(weiminyu): use ByteBuffer directly. private Optional readFromChannel() throws IOException { - while (true) { + while (channel.isOpen()) { int bytesRead = channel.read(byteBuffer); if (!byteBuffer.hasRemaining() || bytesRead < 0) { byteBuffer.flip(); if (!byteBuffer.hasRemaining()) { + channel.close(); return Optional.empty(); } byte[] result = new byte[byteBuffer.remaining()]; @@ -105,6 +108,7 @@ public final class LevelDbLogReader implements Iterator { return Optional.of(result); } } + return Optional.empty(); } /** Read a complete block, which must be exactly 32 KB. */ diff --git a/core/src/test/java/google/registry/tools/LevelDbLogReaderTest.java b/core/src/test/java/google/registry/tools/LevelDbLogReaderTest.java index 1be5cb562..e2b6dc84d 100644 --- a/core/src/test/java/google/registry/tools/LevelDbLogReaderTest.java +++ b/core/src/test/java/google/registry/tools/LevelDbLogReaderTest.java @@ -18,6 +18,9 @@ import static com.google.common.truth.Truth.assertThat; import static google.registry.tools.LevelDbUtil.MAX_RECORD; import static google.registry.tools.LevelDbUtil.addRecord; import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; import com.google.common.collect.ImmutableList; import com.google.common.primitives.Bytes; @@ -93,7 +96,7 @@ public final class LevelDbLogReaderTest { } @Test - void read_noData() { + void read_noData() throws IOException { assertThat(readIncrementally(new byte[0])).isEmpty(); } @@ -119,10 +122,11 @@ public final class LevelDbLogReaderTest { } @SafeVarargs - private static ImmutableList readIncrementally(byte[]... blocks) { - LevelDbLogReader recordReader = - LevelDbLogReader.from(new ByteArrayInputStream(Bytes.concat(blocks))); - return ImmutableList.copyOf(recordReader); + private static ImmutableList readIncrementally(byte[]... blocks) throws IOException { + ByteArrayInputStream source = spy(new ByteArrayInputStream(Bytes.concat(blocks))); + ImmutableList records = ImmutableList.copyOf(LevelDbLogReader.from(source)); + verify(source, times(1)).close(); + return records; } /** Aggregates the bytes of a test block with the record count. */