Adds DnsWriter that implements DNS UPDATE protocol

* DnsUpdateWriter publishes changes to NS, DS, A, AAAA records
  for domains/hosts as appropriate using RFC 2136 DNS UPDATE protocol
* Static configuration separate from RegistryConfig
* Include dnsjava library as new third party dependency
  to generate DNS protocol messages
* Expose /_dr/task/writeDns in RegistryTestServer
* Currently not included in BackendComponent
This commit is contained in:
Hans Ridder 2016-04-06 08:56:54 -07:00
parent 954d7e1e8f
commit 20f214b9d0
14 changed files with 1006 additions and 2 deletions

View file

@ -0,0 +1,207 @@
// Copyright 2016 The Domain Registry Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.domain.registry.dns.writer.dnsupdate;
import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.fail;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import com.google.common.base.VerifyException;
import org.joda.time.Duration;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.runners.MockitoJUnitRunner;
import org.xbill.DNS.ARecord;
import org.xbill.DNS.DClass;
import org.xbill.DNS.Flags;
import org.xbill.DNS.Message;
import org.xbill.DNS.Name;
import org.xbill.DNS.Opcode;
import org.xbill.DNS.Rcode;
import org.xbill.DNS.Record;
import org.xbill.DNS.Type;
import org.xbill.DNS.Update;
import org.xbill.DNS.utils.base16;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.net.InetAddress;
import java.net.Socket;
import java.net.SocketTimeoutException;
import java.nio.ByteBuffer;
import java.util.Arrays;
import javax.net.SocketFactory;
/** Unit tests for {@link DnsMessageTransport}. */
@RunWith(MockitoJUnitRunner.class)
public class DnsMessageTransportTest {
private static final String UPDATE_HOST = "127.0.0.1";
@Mock private SocketFactory mockFactory;
@Mock private Socket mockSocket;
private Message simpleQuery;
private Message expectedResponse;
private DnsMessageTransport resolver;
@Rule public ExpectedException thrown = ExpectedException.none();
@Before
public void before() throws Exception {
simpleQuery =
Message.newQuery(Record.newRecord(Name.fromString("example.com."), Type.A, DClass.IN));
expectedResponse = responseMessageWithCode(simpleQuery, Rcode.NOERROR);
when(mockFactory.createSocket(InetAddress.getByName(UPDATE_HOST), DnsMessageTransport.DNS_PORT))
.thenReturn(mockSocket);
resolver = new DnsMessageTransport(mockFactory, UPDATE_HOST, Duration.ZERO);
}
@Test
public void sentMessageHasCorrectLengthAndContent() throws Exception {
ByteArrayInputStream inputStream =
new ByteArrayInputStream(messageToBytesWithLength(expectedResponse));
when(mockSocket.getInputStream()).thenReturn(inputStream);
ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
when(mockSocket.getOutputStream()).thenReturn(outputStream);
resolver.send(simpleQuery);
ByteBuffer sentMessage = ByteBuffer.wrap(outputStream.toByteArray());
int messageLength = sentMessage.getShort();
byte[] messageData = new byte[messageLength];
sentMessage.get(messageData);
assertThat(messageLength).isEqualTo(simpleQuery.toWire().length);
assertThat(base16.toString(messageData)).isEqualTo(base16.toString(simpleQuery.toWire()));
}
@Test
public void receivedMessageWithLengthHasCorrectContent() throws Exception {
ByteArrayInputStream inputStream =
new ByteArrayInputStream(messageToBytesWithLength(expectedResponse));
when(mockSocket.getInputStream()).thenReturn(inputStream);
when(mockSocket.getOutputStream()).thenReturn(new ByteArrayOutputStream());
Message actualResponse = resolver.send(simpleQuery);
assertThat(base16.toString(actualResponse.toWire()))
.isEqualTo(base16.toString(expectedResponse.toWire()));
}
@Test
public void eofReceivingResponse() throws Exception {
byte[] messageBytes = messageToBytesWithLength(expectedResponse);
ByteArrayInputStream inputStream =
new ByteArrayInputStream(Arrays.copyOf(messageBytes, messageBytes.length - 1));
when(mockSocket.getInputStream()).thenReturn(inputStream);
when(mockSocket.getOutputStream()).thenReturn(new ByteArrayOutputStream());
thrown.expect(EOFException.class);
Message expectedQuery = new Message();
resolver.send(expectedQuery);
}
@Test
public void timeoutReceivingResponse() throws Exception {
InputStream mockInputStream = mock(InputStream.class);
when(mockInputStream.read()).thenThrow(new SocketTimeoutException("testing"));
when(mockSocket.getInputStream()).thenReturn(mockInputStream);
when(mockSocket.getOutputStream()).thenReturn(new ByteArrayOutputStream());
Duration testTimeout = Duration.standardSeconds(1);
DnsMessageTransport resolver = new DnsMessageTransport(mockFactory, UPDATE_HOST, testTimeout);
Message expectedQuery = new Message();
try {
resolver.send(expectedQuery);
fail("exception expected");
} catch (SocketTimeoutException e) {
verify(mockSocket).setSoTimeout((int) testTimeout.getMillis());
}
}
@Test
public void sentMessageTooLongThrowsException() throws Exception {
Update oversize = new Update(Name.fromString("tld", Name.root));
for (int i = 0; i < 2000; i++) {
oversize.add(
ARecord.newRecord(
Name.fromString("test-extremely-long-name-" + i + ".tld", Name.root),
Type.A,
DClass.IN));
}
ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
when(mockSocket.getOutputStream()).thenReturn(outputStream);
thrown.expect(IllegalArgumentException.class);
thrown.expectMessage("message larger than maximum");
resolver.send(oversize);
}
@Test
public void responseIdMismatchThrowsExeption() throws Exception {
expectedResponse.getHeader().setID(1 + simpleQuery.getHeader().getID());
when(mockSocket.getInputStream())
.thenReturn(new ByteArrayInputStream(messageToBytesWithLength(expectedResponse)));
when(mockSocket.getOutputStream()).thenReturn(new ByteArrayOutputStream());
thrown.expect(VerifyException.class);
thrown.expectMessage(
"response ID "
+ expectedResponse.getHeader().getID()
+ " does not match query ID "
+ simpleQuery.getHeader().getID());
resolver.send(simpleQuery);
}
@Test
public void responseOpcodeMismatchThrowsException() throws Exception {
simpleQuery.getHeader().setOpcode(Opcode.QUERY);
expectedResponse.getHeader().setOpcode(Opcode.STATUS);
when(mockSocket.getInputStream())
.thenReturn(new ByteArrayInputStream(messageToBytesWithLength(expectedResponse)));
when(mockSocket.getOutputStream()).thenReturn(new ByteArrayOutputStream());
thrown.expect(VerifyException.class);
thrown.expectMessage("response opcode 'STATUS' does not match query opcode 'QUERY'");
resolver.send(simpleQuery);
}
private Message responseMessageWithCode(Message query, int responseCode) {
Message message = new Message(query.getHeader().getID());
message.getHeader().setOpcode(query.getHeader().getOpcode());
message.getHeader().setFlag(Flags.QR);
message.getHeader().setRcode(responseCode);
return message;
}
private byte[] messageToBytesWithLength(Message message) throws IOException {
byte[] bytes = message.toWire();
ByteBuffer buffer =
ByteBuffer.allocate(bytes.length + DnsMessageTransport.MESSAGE_LENGTH_FIELD_BYTES);
buffer.putShort((short) bytes.length);
buffer.put(bytes);
return buffer.array();
}
}

View file

@ -0,0 +1,302 @@
// Copyright 2016 The Domain Registry Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.domain.registry.dns.writer.dnsupdate;
import static com.google.common.io.BaseEncoding.base16;
import static com.google.common.truth.Truth.assertThat;
import static com.google.domain.registry.testing.DatastoreHelper.createTld;
import static com.google.domain.registry.testing.DatastoreHelper.persistActiveDomain;
import static com.google.domain.registry.testing.DatastoreHelper.persistActiveHost;
import static com.google.domain.registry.testing.DatastoreHelper.persistActiveSubordinateHost;
import static com.google.domain.registry.testing.DatastoreHelper.persistDeletedDomain;
import static com.google.domain.registry.testing.DatastoreHelper.persistDeletedHost;
import static com.google.domain.registry.testing.DatastoreHelper.persistResource;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import com.google.common.base.VerifyException;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.net.InetAddresses;
import com.google.domain.registry.model.domain.DomainResource;
import com.google.domain.registry.model.domain.ReferenceUnion;
import com.google.domain.registry.model.domain.secdns.DelegationSignerData;
import com.google.domain.registry.model.eppcommon.StatusValue;
import com.google.domain.registry.model.host.HostResource;
import com.google.domain.registry.model.ofy.Ofy;
import com.google.domain.registry.testing.AppEngineRule;
import com.google.domain.registry.testing.FakeClock;
import com.google.domain.registry.testing.InjectRule;
import org.joda.time.DateTime;
import org.joda.time.Duration;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.Mock;
import org.mockito.runners.MockitoJUnitRunner;
import org.xbill.DNS.Flags;
import org.xbill.DNS.Message;
import org.xbill.DNS.Opcode;
import org.xbill.DNS.RRset;
import org.xbill.DNS.Rcode;
import org.xbill.DNS.Record;
import org.xbill.DNS.Section;
import org.xbill.DNS.Type;
import org.xbill.DNS.Update;
import java.util.ArrayList;
import java.util.Iterator;
import junit.framework.AssertionFailedError;
/** Unit tests for {@link DnsUpdateWriter}. */
@RunWith(MockitoJUnitRunner.class)
public class DnsUpdateWriterTest {
@Rule
public final AppEngineRule appEngine =
AppEngineRule.builder().withDatastore().withTaskQueue().build();
@Rule public ExpectedException thrown = ExpectedException.none();
@Rule public final InjectRule inject = new InjectRule();
private final FakeClock clock = new FakeClock(DateTime.parse("1971-01-01TZ"));
@Mock private DnsMessageTransport mockResolver;
@Captor private ArgumentCaptor<Update> updateCaptor;
private DelegationSignerData testSignerData =
DelegationSignerData.create(1, 3, 1, base16().decode("0123456789ABCDEF"));
private DnsUpdateWriter writer;
@Before
public void setUp() throws Exception {
inject.setStaticField(Ofy.class, "clock", clock);
createTld("tld");
when(mockResolver.send(any(Update.class))).thenReturn(messageWithResponseCode(Rcode.NOERROR));
writer = new DnsUpdateWriter(Duration.ZERO, mockResolver, clock);
}
@Test
public void publishDomainCreatePublishesNameServers() throws Exception {
HostResource host1 = persistActiveHost("ns1.example.tld");
HostResource host2 = persistActiveHost("ns2.example.tld");
DomainResource domain =
persistActiveDomain("example.tld")
.asBuilder()
.setNameservers(
ImmutableSet.of(ReferenceUnion.create(host1), ReferenceUnion.create(host2)))
.build();
persistResource(domain);
writer.publishDomain("example.tld");
verify(mockResolver).send(updateCaptor.capture());
Update update = updateCaptor.getValue();
assertThatUpdatedZoneIs(update, "tld.");
assertThatUpdateDeletes(update, "example.tld.", Type.ANY);
assertThatUpdateAdds(update, "example.tld.", Type.NS, "ns1.example.tld.", "ns2.example.tld.");
assertThatTotalUpdateSetsIs(update, 2); // The delete and NS sets
}
@Test
public void publishDomainCreatePublishesDelegationSigner() throws Exception {
DomainResource domain =
persistActiveDomain("example.tld")
.asBuilder()
.setNameservers(
ImmutableSet.of(ReferenceUnion.create(persistActiveHost("ns1.example.tld"))))
.setDsData(ImmutableSet.of(testSignerData))
.build();
persistResource(domain);
writer.publishDomain("example.tld");
verify(mockResolver).send(updateCaptor.capture());
Update update = updateCaptor.getValue();
assertThatUpdatedZoneIs(update, "tld.");
assertThatUpdateDeletes(update, "example.tld.", Type.ANY);
assertThatUpdateAdds(update, "example.tld.", Type.NS, "ns1.example.tld.");
assertThatUpdateAdds(update, "example.tld.", Type.DS, "1 3 1 0123456789ABCDEF");
assertThatTotalUpdateSetsIs(update, 3); // The delete, the NS, and DS sets
}
@Test
public void publishDomainWhenNotActiveRemovesDnsRecords() throws Exception {
DomainResource domain =
persistActiveDomain("example.tld")
.asBuilder()
.addStatusValue(StatusValue.SERVER_HOLD)
.setNameservers(
ImmutableSet.of(ReferenceUnion.create(persistActiveHost("ns1.example.tld"))))
.build();
persistResource(domain);
writer.publishDomain("example.tld");
verify(mockResolver).send(updateCaptor.capture());
Update update = updateCaptor.getValue();
assertThatUpdatedZoneIs(update, "tld.");
assertThatUpdateDeletes(update, "example.tld.", Type.ANY);
assertThatTotalUpdateSetsIs(update, 1); // Just the delete set
}
@Test
public void publishDomainDeleteRemovesDnsRecords() throws Exception {
persistDeletedDomain("example.tld", clock.nowUtc());
writer.publishDomain("example.tld");
verify(mockResolver).send(updateCaptor.capture());
Update update = updateCaptor.getValue();
assertThatUpdatedZoneIs(update, "tld.");
assertThatUpdateDeletes(update, "example.tld.", Type.ANY);
assertThatTotalUpdateSetsIs(update, 1); // Just the delete set
}
@Test
public void publishHostCreatePublishesAddressRecords() throws Exception {
HostResource host =
persistActiveSubordinateHost("ns1.example.tld", persistActiveDomain("example.tld"))
.asBuilder()
.setInetAddresses(
ImmutableSet.of(
InetAddresses.forString("10.0.0.1"),
InetAddresses.forString("10.1.0.1"),
InetAddresses.forString("fd0e:a5c8:6dfb:6a5e:0:0:0:1")))
.build();
persistResource(host);
writer.publishHost("ns1.example.tld");
verify(mockResolver).send(updateCaptor.capture());
Update update = updateCaptor.getValue();
assertThatUpdatedZoneIs(update, "tld.");
assertThatUpdateDeletes(update, "ns1.example.tld.", Type.ANY);
assertThatUpdateAdds(update, "ns1.example.tld.", Type.A, "10.0.0.1", "10.1.0.1");
assertThatUpdateAdds(update, "ns1.example.tld.", Type.AAAA, "fd0e:a5c8:6dfb:6a5e:0:0:0:1");
assertThatTotalUpdateSetsIs(update, 3); // The delete, the A, and AAAA sets
}
@Test
public void publishHostDeleteRemovesDnsRecords() throws Exception {
persistDeletedHost("ns1.example.tld", clock.nowUtc());
writer.publishHost("ns1.example.tld");
verify(mockResolver).send(updateCaptor.capture());
Update update = updateCaptor.getValue();
assertThatUpdatedZoneIs(update, "tld.");
assertThatUpdateDeletes(update, "ns1.example.tld.", Type.ANY);
assertThatTotalUpdateSetsIs(update, 1); // Just the delete set
}
@Test
public void publishDomainFailsWhenDnsUpdateReturnsError() throws Exception {
DomainResource domain =
persistActiveDomain("example.tld")
.asBuilder()
.setNameservers(
ImmutableSet.of(ReferenceUnion.create(persistActiveHost("ns1.example.tld"))))
.build();
persistResource(domain);
when(mockResolver.send(any(Message.class))).thenReturn(messageWithResponseCode(Rcode.SERVFAIL));
thrown.expect(VerifyException.class);
thrown.expectMessage("SERVFAIL");
writer.publishDomain("example.tld");
}
@Test
public void publishHostFailsWhenDnsUpdateReturnsError() throws Exception {
HostResource host =
persistActiveSubordinateHost("ns1.example.tld", persistActiveDomain("example.tld"))
.asBuilder()
.setInetAddresses(ImmutableSet.of(InetAddresses.forString("10.0.0.1")))
.build();
persistResource(host);
when(mockResolver.send(any(Message.class))).thenReturn(messageWithResponseCode(Rcode.SERVFAIL));
thrown.expect(VerifyException.class);
thrown.expectMessage("SERVFAIL");
writer.publishHost("ns1.example.tld");
}
private void assertThatUpdatedZoneIs(Update update, String zoneName) {
Record[] zoneRecords = update.getSectionArray(Section.ZONE);
assertThat(zoneRecords[0].getName().toString()).isEqualTo(zoneName);
}
private void assertThatTotalUpdateSetsIs(Update update, int count) {
assertThat(update.getSectionRRsets(Section.UPDATE)).hasLength(count);
}
private void assertThatUpdateDeletes(Update update, String resourceName, int recordType) {
ImmutableList<Record> deleted = findUpdateRecords(update, resourceName, recordType);
// There's only an empty (i.e. "delete") record.
assertThat(deleted.get(0).rdataToString()).hasLength(0);
assertThat(deleted).hasSize(1);
}
private void assertThatUpdateAdds(
Update update, String resourceName, int recordType, String... resourceData) {
ArrayList<String> expectedData = new ArrayList<>();
for (String resourceDatum : resourceData) {
expectedData.add(resourceDatum.toLowerCase());
}
ArrayList<String> actualData = new ArrayList<>();
for (Record record : findUpdateRecords(update, resourceName, recordType)) {
actualData.add(record.rdataToString().toLowerCase());
}
assertThat(actualData).containsExactlyElementsIn(expectedData);
}
private ImmutableList<Record> findUpdateRecords(
Update update, String resourceName, int recordType) {
for (RRset set : update.getSectionRRsets(Section.UPDATE)) {
if (set.getName().toString().equals(resourceName) && set.getType() == recordType) {
return fixIterator(Record.class, set.rrs());
}
}
throw new AssertionFailedError(
"no record set found for resource '"
+ resourceName
+ "', type '"
+ Type.string(recordType)
+ "'");
}
@SuppressWarnings({"unchecked", "unused"})
private static <T> ImmutableList<T> fixIterator(Class<T> clazz, final Iterator<?> iterator) {
return ImmutableList.copyOf((Iterator<T>) iterator);
}
private Message messageWithResponseCode(int responseCode) {
Message message = new Message();
message.getHeader().setOpcode(Opcode.UPDATE);
message.getHeader().setFlag(Flags.QR);
message.getHeader().setRcode(responseCode);
return message;
}
}

View file

@ -68,6 +68,10 @@ public final class RegistryTestServer {
route("/_dr/task/nordnVerify",
com.google.domain.registry.module.backend.BackendServlet.class),
// Process DNS pull queue
route("/_dr/task/writeDns",
com.google.domain.registry.module.backend.BackendServlet.class),
// Registrar Console
route("/registrar", com.google.domain.registry.module.frontend.FrontendServlet.class),
route("/registrar-settings",