review feedback: error checking, timeouts, add .gov if missing

This commit is contained in:
Neil Martinsen-Burrell 2022-11-01 10:34:48 -05:00
parent 5d9a469ebd
commit 61b4cbf10b
No known key found for this signature in database
GPG key ID: 6A3C818CC10D0184
3 changed files with 77 additions and 14 deletions

View file

@ -2,11 +2,14 @@
import json import json
from django.core.exceptions import BadRequest
from django.contrib.auth import get_user_model from django.contrib.auth import get_user_model
from django.test import TestCase, RequestFactory from django.test import TestCase, RequestFactory
from ..views import available, _domains, in_domains from ..views import available, _domains, in_domains
API_BASE_PATH = "/api/v1/available/"
class AvailableViewTest(TestCase): class AvailableViewTest(TestCase):
@ -17,7 +20,7 @@ class AvailableViewTest(TestCase):
self.factory = RequestFactory() self.factory = RequestFactory()
def test_view_function(self): def test_view_function(self):
request = self.factory.get("/available/test.gov") request = self.factory.get(API_BASE_PATH + "test.gov")
request.user = self.user request.user = self.user
response = available(request, domain="test.gov") response = available(request, domain="test.gov")
# has the right text in it # has the right text in it
@ -27,7 +30,11 @@ class AvailableViewTest(TestCase):
self.assertIn("available", response_object) self.assertIn("available", response_object)
def test_domain_list(self): def test_domain_list(self):
"""Test the domain list that is returned.""" """Test the domain list that is returned from Github.
This does not mock out the external file, it is actually fetched from
the internet.
"""
domains = _domains() domains = _domains()
self.assertIn("gsa.gov", domains) self.assertIn("gsa.gov", domains)
# entries are all lowercase so GSA.GOV is not in the set # entries are all lowercase so GSA.GOV is not in the set
@ -42,23 +49,44 @@ class AvailableViewTest(TestCase):
self.assertTrue(in_domains("GSA.GOV")) self.assertTrue(in_domains("GSA.GOV"))
# This domain should not have been registered # This domain should not have been registered
self.assertFalse(in_domains("igorville.gov")) self.assertFalse(in_domains("igorville.gov"))
# all the entries have dots
self.assertFalse(in_domains("gsa")) def test_in_domains_dotgov(self):
"""Domain searches work without trailing .gov"""
self.assertTrue(in_domains("gsa"))
# input is lowercased so GSA.GOV should be found
self.assertTrue(in_domains("GSA"))
# This domain should not have been registered
self.assertFalse(in_domains("igorville"))
def test_not_available_domain(self): def test_not_available_domain(self):
"""gsa.gov is not available""" """gsa.gov is not available"""
request = self.factory.get("/available/gsa.gov") request = self.factory.get(API_BASE_PATH + "gsa.gov")
request.user = self.user request.user = self.user
response = available(request, domain="gsa.gov") response = available(request, domain="gsa.gov")
self.assertFalse(json.loads(response.content)["available"]) self.assertFalse(json.loads(response.content)["available"])
def test_available_domain(self): def test_available_domain(self):
"""igorville.gov is still available""" """igorville.gov is still available"""
request = self.factory.get("/available/igorville.gov") request = self.factory.get(API_BASE_PATH + "igorville.gov")
request.user = self.user request.user = self.user
response = available(request, domain="igorville.gov") response = available(request, domain="igorville.gov")
self.assertTrue(json.loads(response.content)["available"]) self.assertTrue(json.loads(response.content)["available"])
def test_available_domain_dotgov(self):
"""igorville.gov is still available even without the .gov suffix"""
request = self.factory.get(API_BASE_PATH + "igorville")
request.user = self.user
response = available(request, domain="igorville")
self.assertTrue(json.loads(response.content)["available"])
def test_error_handling(self):
"""Calling with bad strings raises an error."""
bad_string = "blah!;"
request = self.factory.get(API_BASE_PATH + bad_string)
request.user = self.user
with self.assertRaisesMessage(BadRequest, "Invalid"):
available(request, domain=bad_string)
class AvailableAPITest(TestCase): class AvailableAPITest(TestCase):
@ -69,12 +97,17 @@ class AvailableAPITest(TestCase):
def test_available_get(self): def test_available_get(self):
self.client.force_login(self.user) self.client.force_login(self.user)
response = self.client.get("/available/nonsense") response = self.client.get(API_BASE_PATH + "nonsense")
self.assertContains(response, "available") self.assertContains(response, "available")
response_object = json.loads(response.content) response_object = json.loads(response.content)
self.assertIn("available", response_object) self.assertIn("available", response_object)
def test_available_post(self): def test_available_post(self):
"""Cannot post to the /available/ API endpoint.""" """Cannot post to the /available/ API endpoint."""
response = self.client.post("/available/nonsense") response = self.client.post(API_BASE_PATH + "nonsense")
self.assertEqual(response.status_code, 405) self.assertEqual(response.status_code, 405)
def test_available_bad_input(self):
self.client.force_login(self.user)
response = self.client.get(API_BASE_PATH + "blah!;")
self.assertEqual(response.status_code, 400)

View file

@ -1,6 +1,9 @@
"""Internal API views""" """Internal API views"""
import re
from django.core.exceptions import BadRequest
from django.views.decorators.http import require_http_methods from django.views.decorators.http import require_http_methods
from django.http import JsonResponse from django.http import JsonResponse
@ -13,6 +16,19 @@ from cachetools.func import ttl_cache
DOMAIN_FILE_URL = ( DOMAIN_FILE_URL = (
"https://raw.githubusercontent.com/cisagov/dotgov-data/main/current-full.csv" "https://raw.githubusercontent.com/cisagov/dotgov-data/main/current-full.csv"
) )
# a domain name is alphanumeric or hyphen, up to 63 characters, doesn't
# begin or end with a hyphen, followed by a TLD of 2-6 alphabetic characters
DOMAIN_REGEX = re.compile(r"^(?!-)[A-Za-z0-9-]{1,63}(?<!-)\.[A-Za-z]{2,6}")
def string_could_be_domain(domain):
"""Return True if the string could be a domain name, otherwise False.
TODO: when we have a Domain class, this could be a classmethod there.
"""
if DOMAIN_REGEX.match(domain):
return True
return False
# this file doesn't change that often, nor is it that big, so cache the result # this file doesn't change that often, nor is it that big, so cache the result
@ -24,23 +40,33 @@ def _domains():
Fetch a file from DOMAIN_FILE_URL, parse the CSV for the domain, Fetch a file from DOMAIN_FILE_URL, parse the CSV for the domain,
lowercase everything and return the list. lowercase everything and return the list.
""" """
file_contents = requests.get(DOMAIN_FILE_URL).text # 5 second timeout
file_contents = requests.get(DOMAIN_FILE_URL, timeout=5).text
domains = set() domains = set()
# skip the first line # skip the first line
for line in file_contents.splitlines()[1:]: for line in file_contents.splitlines()[1:]:
# get the domain before the first comma # get the domain before the first comma
domain = line.split(",", 1)[0] domain = line.split(",", 1)[0]
# lowercase everything # sanity-check the string we got from the file here
domains.add(domain.lower()) if string_could_be_domain(domain):
# lowercase everything when we put it in domains
domains.add(domain.lower())
return domains return domains
def in_domains(domain): def in_domains(domain):
"""Return true if the given domain is in the domains list. """Return true if the given domain is in the domains list.
The given domain is lowercased to match against the domains list. The given domain is lowercased to match against the domains list. If the
given domain doesn't end with .gov, ".gov" is added when looking for
a match.
""" """
return domain.lower() in _domains() domain = domain.lower()
if domain.endswith(".gov"):
return domain.lower() in _domains()
else:
# domain search string doesn't end with .gov, add it on here
return (domain + ".gov") in _domains()
@require_http_methods(["GET"]) @require_http_methods(["GET"])
@ -52,5 +78,9 @@ def available(request, domain=""):
Response is a JSON dictionary with the key "available" and value true or Response is a JSON dictionary with the key "available" and value true or
false. false.
""" """
# validate that the given domain could be a domain name and fail early if
# not.
if not (string_could_be_domain(domain) or string_could_be_domain(domain + ".gov")):
raise BadRequest("Invalid request.")
# a domain is available if it is NOT in the list of current domains # a domain is available if it is NOT in the list of current domains
return JsonResponse({"available": not in_domains(domain)}) return JsonResponse({"available": not in_domains(domain)})

View file

@ -27,7 +27,7 @@ urlpatterns = [
path("openid/", include("djangooidc.urls")), path("openid/", include("djangooidc.urls")),
path("register/", application_wizard, name="application"), path("register/", application_wizard, name="application"),
path("register/<step>/", application_wizard, name=APPLICATION_URL_NAME), path("register/<step>/", application_wizard, name=APPLICATION_URL_NAME),
path("available/<domain>", available, name="available"), path("api/v1/available/<domain>", available, name="available"),
] ]
if not settings.DEBUG: if not settings.DEBUG: