mirror of
https://github.com/google/nomulus.git
synced 2025-05-13 07:57:13 +02:00
More information: https://github.com/junit-team/junit5/issues/531 ------------- Created by MOE: https://github.com/google/moe MOE_MIGRATED_REVID=187034408
190 lines
7.7 KiB
Java
190 lines
7.7 KiB
Java
// Copyright 2017 The Nomulus Authors. All Rights Reserved.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package google.registry.dns.writer.dnsupdate;
|
|
|
|
import static com.google.common.io.BaseEncoding.base16;
|
|
import static com.google.common.truth.Truth.assertThat;
|
|
import static google.registry.testing.JUnitBackports.assertThrows;
|
|
import static org.mockito.Mockito.mock;
|
|
import static org.mockito.Mockito.verify;
|
|
import static org.mockito.Mockito.when;
|
|
|
|
import com.google.common.base.VerifyException;
|
|
import java.io.ByteArrayInputStream;
|
|
import java.io.ByteArrayOutputStream;
|
|
import java.io.EOFException;
|
|
import java.io.InputStream;
|
|
import java.net.InetAddress;
|
|
import java.net.Socket;
|
|
import java.net.SocketTimeoutException;
|
|
import java.nio.ByteBuffer;
|
|
import java.util.Arrays;
|
|
import javax.net.SocketFactory;
|
|
import org.joda.time.Duration;
|
|
import org.junit.Before;
|
|
import org.junit.Test;
|
|
import org.junit.runner.RunWith;
|
|
import org.junit.runners.JUnit4;
|
|
import org.xbill.DNS.ARecord;
|
|
import org.xbill.DNS.DClass;
|
|
import org.xbill.DNS.Flags;
|
|
import org.xbill.DNS.Message;
|
|
import org.xbill.DNS.Name;
|
|
import org.xbill.DNS.Opcode;
|
|
import org.xbill.DNS.Rcode;
|
|
import org.xbill.DNS.Record;
|
|
import org.xbill.DNS.Type;
|
|
import org.xbill.DNS.Update;
|
|
|
|
/** Unit tests for {@link DnsMessageTransport}. */
|
|
@RunWith(JUnit4.class)
|
|
public class DnsMessageTransportTest {
|
|
|
|
private static final String UPDATE_HOST = "127.0.0.1";
|
|
|
|
private final SocketFactory mockFactory = mock(SocketFactory.class);
|
|
private final Socket mockSocket = mock(Socket.class);
|
|
|
|
private Message simpleQuery;
|
|
private Message expectedResponse;
|
|
private DnsMessageTransport resolver;
|
|
@Before
|
|
public void before() throws Exception {
|
|
simpleQuery =
|
|
Message.newQuery(Record.newRecord(Name.fromString("example.com."), Type.A, DClass.IN));
|
|
expectedResponse = responseMessageWithCode(simpleQuery, Rcode.NOERROR);
|
|
when(mockFactory.createSocket(InetAddress.getByName(UPDATE_HOST), DnsMessageTransport.DNS_PORT))
|
|
.thenReturn(mockSocket);
|
|
resolver = new DnsMessageTransport(mockFactory, UPDATE_HOST, Duration.ZERO);
|
|
}
|
|
|
|
@Test
|
|
public void testSentMessageHasCorrectLengthAndContent() throws Exception {
|
|
ByteArrayInputStream inputStream =
|
|
new ByteArrayInputStream(messageToBytesWithLength(expectedResponse));
|
|
when(mockSocket.getInputStream()).thenReturn(inputStream);
|
|
ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
|
|
when(mockSocket.getOutputStream()).thenReturn(outputStream);
|
|
|
|
resolver.send(simpleQuery);
|
|
|
|
ByteBuffer sentMessage = ByteBuffer.wrap(outputStream.toByteArray());
|
|
int messageLength = sentMessage.getShort();
|
|
byte[] messageData = new byte[messageLength];
|
|
sentMessage.get(messageData);
|
|
assertThat(messageLength).isEqualTo(simpleQuery.toWire().length);
|
|
assertThat(base16().encode(messageData)).isEqualTo(base16().encode(simpleQuery.toWire()));
|
|
}
|
|
|
|
@Test
|
|
public void testReceivedMessageWithLengthHasCorrectContent() throws Exception {
|
|
ByteArrayInputStream inputStream =
|
|
new ByteArrayInputStream(messageToBytesWithLength(expectedResponse));
|
|
when(mockSocket.getInputStream()).thenReturn(inputStream);
|
|
when(mockSocket.getOutputStream()).thenReturn(new ByteArrayOutputStream());
|
|
|
|
Message actualResponse = resolver.send(simpleQuery);
|
|
|
|
assertThat(base16().encode(actualResponse.toWire()))
|
|
.isEqualTo(base16().encode(expectedResponse.toWire()));
|
|
}
|
|
|
|
@Test
|
|
public void testEofReceivingResponse() throws Exception {
|
|
byte[] messageBytes = messageToBytesWithLength(expectedResponse);
|
|
ByteArrayInputStream inputStream =
|
|
new ByteArrayInputStream(Arrays.copyOf(messageBytes, messageBytes.length - 1));
|
|
when(mockSocket.getInputStream()).thenReturn(inputStream);
|
|
when(mockSocket.getOutputStream()).thenReturn(new ByteArrayOutputStream());
|
|
assertThrows(EOFException.class, () -> resolver.send(new Message()));
|
|
}
|
|
|
|
@Test
|
|
public void testTimeoutReceivingResponse() throws Exception {
|
|
InputStream mockInputStream = mock(InputStream.class);
|
|
when(mockInputStream.read()).thenThrow(new SocketTimeoutException("testing"));
|
|
when(mockSocket.getInputStream()).thenReturn(mockInputStream);
|
|
when(mockSocket.getOutputStream()).thenReturn(new ByteArrayOutputStream());
|
|
|
|
Duration testTimeout = Duration.standardSeconds(1);
|
|
DnsMessageTransport resolver = new DnsMessageTransport(mockFactory, UPDATE_HOST, testTimeout);
|
|
Message expectedQuery = new Message();
|
|
assertThrows(SocketTimeoutException.class, () -> resolver.send(expectedQuery));
|
|
verify(mockSocket).setSoTimeout((int) testTimeout.getMillis());
|
|
}
|
|
|
|
@Test
|
|
public void testSentMessageTooLongThrowsException() throws Exception {
|
|
Update oversize = new Update(Name.fromString("tld", Name.root));
|
|
for (int i = 0; i < 2000; i++) {
|
|
oversize.add(
|
|
ARecord.newRecord(
|
|
Name.fromString("test-extremely-long-name-" + i + ".tld", Name.root),
|
|
Type.A,
|
|
DClass.IN));
|
|
}
|
|
ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
|
|
when(mockSocket.getOutputStream()).thenReturn(outputStream);
|
|
IllegalArgumentException thrown =
|
|
assertThrows(IllegalArgumentException.class, () -> resolver.send(oversize));
|
|
assertThat(thrown).hasMessageThat().contains("message larger than maximum");
|
|
}
|
|
|
|
@Test
|
|
public void testResponseIdMismatchThrowsExeption() throws Exception {
|
|
expectedResponse.getHeader().setID(1 + simpleQuery.getHeader().getID());
|
|
when(mockSocket.getInputStream())
|
|
.thenReturn(new ByteArrayInputStream(messageToBytesWithLength(expectedResponse)));
|
|
when(mockSocket.getOutputStream()).thenReturn(new ByteArrayOutputStream());
|
|
VerifyException thrown = assertThrows(VerifyException.class, () -> resolver.send(simpleQuery));
|
|
assertThat(thrown)
|
|
.hasMessageThat()
|
|
.contains(
|
|
"response ID "
|
|
+ expectedResponse.getHeader().getID()
|
|
+ " does not match query ID "
|
|
+ simpleQuery.getHeader().getID());
|
|
}
|
|
|
|
@Test
|
|
public void testResponseOpcodeMismatchThrowsException() throws Exception {
|
|
simpleQuery.getHeader().setOpcode(Opcode.QUERY);
|
|
expectedResponse.getHeader().setOpcode(Opcode.STATUS);
|
|
when(mockSocket.getInputStream())
|
|
.thenReturn(new ByteArrayInputStream(messageToBytesWithLength(expectedResponse)));
|
|
when(mockSocket.getOutputStream()).thenReturn(new ByteArrayOutputStream());
|
|
VerifyException thrown = assertThrows(VerifyException.class, () -> resolver.send(simpleQuery));
|
|
assertThat(thrown)
|
|
.hasMessageThat()
|
|
.contains("response opcode 'STATUS' does not match query opcode 'QUERY'");
|
|
}
|
|
|
|
private Message responseMessageWithCode(Message query, int responseCode) {
|
|
Message message = new Message(query.getHeader().getID());
|
|
message.getHeader().setOpcode(query.getHeader().getOpcode());
|
|
message.getHeader().setFlag(Flags.QR);
|
|
message.getHeader().setRcode(responseCode);
|
|
return message;
|
|
}
|
|
|
|
private byte[] messageToBytesWithLength(Message message) {
|
|
byte[] bytes = message.toWire();
|
|
ByteBuffer buffer =
|
|
ByteBuffer.allocate(bytes.length + DnsMessageTransport.MESSAGE_LENGTH_FIELD_BYTES);
|
|
buffer.putShort((short) bytes.length);
|
|
buffer.put(bytes);
|
|
return buffer.array();
|
|
}
|
|
}
|