Close input channel in LevelDbLogReader (#594)

* Close input channel in LevelDbLogReader

Input channel should be closed when all data has been read.
This commit is contained in:
Weimin Yu 2020-05-20 12:54:13 -04:00 committed by GitHub
parent 3947ac6ef7
commit ca2edb6a17
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 14 additions and 6 deletions

View file

@ -34,6 +34,8 @@ import java.util.Optional;
/** /**
* Iterator that incrementally parses binary data in LevelDb format into records. * Iterator that incrementally parses binary data in LevelDb format into records.
* *
* <p>The input source is automatically closed when all data have been read.
*
* <p>See <a * <p>See <a
* href="https://github.com/google/leveldb/blob/master/doc/log_format.md">log_format.md</a> for the * href="https://github.com/google/leveldb/blob/master/doc/log_format.md">log_format.md</a> for the
* leveldb log format specification.</a> * leveldb log format specification.</a>
@ -92,11 +94,12 @@ public final class LevelDbLogReader implements Iterator<byte[]> {
*/ */
// TODO(weiminyu): use ByteBuffer directly. // TODO(weiminyu): use ByteBuffer directly.
private Optional<byte[]> readFromChannel() throws IOException { private Optional<byte[]> readFromChannel() throws IOException {
while (true) { while (channel.isOpen()) {
int bytesRead = channel.read(byteBuffer); int bytesRead = channel.read(byteBuffer);
if (!byteBuffer.hasRemaining() || bytesRead < 0) { if (!byteBuffer.hasRemaining() || bytesRead < 0) {
byteBuffer.flip(); byteBuffer.flip();
if (!byteBuffer.hasRemaining()) { if (!byteBuffer.hasRemaining()) {
channel.close();
return Optional.empty(); return Optional.empty();
} }
byte[] result = new byte[byteBuffer.remaining()]; byte[] result = new byte[byteBuffer.remaining()];
@ -105,6 +108,7 @@ public final class LevelDbLogReader implements Iterator<byte[]> {
return Optional.of(result); return Optional.of(result);
} }
} }
return Optional.empty();
} }
/** Read a complete block, which must be exactly 32 KB. */ /** Read a complete block, which must be exactly 32 KB. */

View file

@ -18,6 +18,9 @@ import static com.google.common.truth.Truth.assertThat;
import static google.registry.tools.LevelDbUtil.MAX_RECORD; import static google.registry.tools.LevelDbUtil.MAX_RECORD;
import static google.registry.tools.LevelDbUtil.addRecord; import static google.registry.tools.LevelDbUtil.addRecord;
import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList;
import com.google.common.primitives.Bytes; import com.google.common.primitives.Bytes;
@ -93,7 +96,7 @@ public final class LevelDbLogReaderTest {
} }
@Test @Test
void read_noData() { void read_noData() throws IOException {
assertThat(readIncrementally(new byte[0])).isEmpty(); assertThat(readIncrementally(new byte[0])).isEmpty();
} }
@ -119,10 +122,11 @@ public final class LevelDbLogReaderTest {
} }
@SafeVarargs @SafeVarargs
private static ImmutableList<byte[]> readIncrementally(byte[]... blocks) { private static ImmutableList<byte[]> readIncrementally(byte[]... blocks) throws IOException {
LevelDbLogReader recordReader = ByteArrayInputStream source = spy(new ByteArrayInputStream(Bytes.concat(blocks)));
LevelDbLogReader.from(new ByteArrayInputStream(Bytes.concat(blocks))); ImmutableList<byte[]> records = ImmutableList.copyOf(LevelDbLogReader.from(source));
return ImmutableList.copyOf(recordReader); verify(source, times(1)).close();
return records;
} }
/** Aggregates the bytes of a test block with the record count. */ /** Aggregates the bytes of a test block with the record count. */