Basic setup stuff

This commit is contained in:
zandercymatics 2024-05-09 11:42:18 -06:00
parent c9a735bf6a
commit d268ef54b1
No known key found for this signature in database
GPG key ID: FF4636ABEC9682B7
8 changed files with 110 additions and 13 deletions

View file

@ -21,10 +21,13 @@ class OpenIdConnectBackend(ModelBackend):
""" """
def authenticate(self, request, **kwargs): def authenticate(self, request, **kwargs):
"""Returns a tuple of (User, is_new_user)"""
logger.debug("kwargs %s" % kwargs) logger.debug("kwargs %s" % kwargs)
user = None user = None
is_new_user = True
if not kwargs or "sub" not in kwargs.keys(): if not kwargs or "sub" not in kwargs.keys():
return user return user, is_new_user
UserModel = get_user_model() UserModel = get_user_model()
username = self.clean_username(kwargs["sub"]) username = self.clean_username(kwargs["sub"])
@ -48,6 +51,7 @@ class OpenIdConnectBackend(ModelBackend):
} }
user, created = UserModel.objects.get_or_create(**args) user, created = UserModel.objects.get_or_create(**args)
is_new_user = created
if not created: if not created:
# If user exists, update existing user # If user exists, update existing user
@ -59,10 +63,10 @@ class OpenIdConnectBackend(ModelBackend):
try: try:
user = UserModel.objects.get_by_natural_key(username) user = UserModel.objects.get_by_natural_key(username)
except UserModel.DoesNotExist: except UserModel.DoesNotExist:
return None return None, is_new_user
# run this callback for a each login # run this callback for a each login
user.on_each_login() user.on_each_login()
return user return user, is_new_user
def update_existing_user(self, user, kwargs): def update_existing_user(self, user, kwargs):
""" """

View file

@ -21,7 +21,7 @@ class OpenIdConnectBackendTestCase(TestCase):
"""Test that authenticate creates a new user if it does not find """Test that authenticate creates a new user if it does not find
existing user""" existing user"""
# Ensure that the authenticate method creates a new user # Ensure that the authenticate method creates a new user
user = self.backend.authenticate(request=None, **self.kwargs) user, _ = self.backend.authenticate(request=None, **self.kwargs)
self.assertIsNotNone(user) self.assertIsNotNone(user)
self.assertIsInstance(user, User) self.assertIsInstance(user, User)
self.assertEqual(user.username, "test_user") self.assertEqual(user.username, "test_user")
@ -39,7 +39,7 @@ class OpenIdConnectBackendTestCase(TestCase):
existing_user = User.objects.create_user(username="test_user") existing_user = User.objects.create_user(username="test_user")
# Ensure that the authenticate method updates the existing user # Ensure that the authenticate method updates the existing user
user = self.backend.authenticate(request=None, **self.kwargs) user, _ = self.backend.authenticate(request=None, **self.kwargs)
self.assertIsNotNone(user) self.assertIsNotNone(user)
self.assertIsInstance(user, User) self.assertIsInstance(user, User)
self.assertEqual(user, existing_user) # The same user instance should be returned self.assertEqual(user, existing_user) # The same user instance should be returned
@ -68,7 +68,7 @@ class OpenIdConnectBackendTestCase(TestCase):
# Ensure that the authenticate method updates the existing user # Ensure that the authenticate method updates the existing user
# and preserves existing first and last names # and preserves existing first and last names
user = self.backend.authenticate(request=None, **self.kwargs) user, _ = self.backend.authenticate(request=None, **self.kwargs)
self.assertIsNotNone(user) self.assertIsNotNone(user)
self.assertIsInstance(user, User) self.assertIsInstance(user, User)
self.assertEqual(user, existing_user) # The same user instance should be returned self.assertEqual(user, existing_user) # The same user instance should be returned
@ -89,7 +89,7 @@ class OpenIdConnectBackendTestCase(TestCase):
# Ensure that the authenticate method updates the existing user # Ensure that the authenticate method updates the existing user
# and preserves existing first and last names # and preserves existing first and last names
user = self.backend.authenticate(request=None, **self.kwargs) user, _ = self.backend.authenticate(request=None, **self.kwargs)
self.assertIsNotNone(user) self.assertIsNotNone(user)
self.assertIsInstance(user, User) self.assertIsInstance(user, User)
self.assertEqual(user, existing_user) # The same user instance should be returned self.assertEqual(user, existing_user) # The same user instance should be returned
@ -103,5 +103,5 @@ class OpenIdConnectBackendTestCase(TestCase):
def test_authenticate_with_unknown_user(self): def test_authenticate_with_unknown_user(self):
"""Test that authenticate returns None when no kwargs are supplied""" """Test that authenticate returns None when no kwargs are supplied"""
# Ensure that the authenticate method handles the case when the user is not found # Ensure that the authenticate method handles the case when the user is not found
user = self.backend.authenticate(request=None, **{}) user, _ = self.backend.authenticate(request=None, **{})
self.assertIsNone(user) self.assertIsNone(user)

View file

@ -85,6 +85,7 @@ def login_callback(request):
"""Analyze the token returned by the authentication provider (OP).""" """Analyze the token returned by the authentication provider (OP)."""
global CLIENT global CLIENT
try: try:
request.session["is_new_user"] = False
# If the CLIENT is none, attempt to reinitialize before handling the request # If the CLIENT is none, attempt to reinitialize before handling the request
if _client_is_none(): if _client_is_none():
logger.debug("OIDC client is None, attempting to initialize") logger.debug("OIDC client is None, attempting to initialize")
@ -97,9 +98,9 @@ def login_callback(request):
# add acr_value to request.session # add acr_value to request.session
request.session["acr_value"] = CLIENT.get_step_up_acr_value() request.session["acr_value"] = CLIENT.get_step_up_acr_value()
return CLIENT.create_authn_request(request.session) return CLIENT.create_authn_request(request.session)
user = authenticate(request=request, **userinfo) user, is_new_user = authenticate(request=request, **userinfo)
if user: if user:
should_update_user = False
# Fixture users kind of exist in a superposition of verification types, # Fixture users kind of exist in a superposition of verification types,
# because while the system "verified" them, if they login, # because while the system "verified" them, if they login,
# we don't know how the user themselves was verified through login.gov until # we don't know how the user themselves was verified through login.gov until
@ -110,9 +111,17 @@ def login_callback(request):
# Set the verification type if it doesn't already exist or if its a fixture user # Set the verification type if it doesn't already exist or if its a fixture user
if not user.verification_type or is_fixture_user: if not user.verification_type or is_fixture_user:
user.set_user_verification_type() user.set_user_verification_type()
should_update_user = True
if is_new_user:
user.finished_setup = False
should_update_user = True
if should_update_user:
user.save() user.save()
login(request, user) login(request, user)
logger.info("Successfully logged in user %s" % user) logger.info("Successfully logged in user %s" % user)
# Clear the flag if the exception is not caught # Clear the flag if the exception is not caught

View file

@ -0,0 +1,18 @@
# Generated by Django 4.2.10 on 2024-05-09 17:42
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("registrar", "0093_alter_publiccontact_unique_together"),
]
operations = [
migrations.AddField(
model_name="user",
name="finished_setup",
field=models.BooleanField(default=True),
),
]

View file

@ -80,6 +80,13 @@ class User(AbstractUser):
help_text="The means through which this user was verified", help_text="The means through which this user was verified",
) )
# Tracks if the user finished their profile setup or not. This is so
# we can globally enforce that new users provide additional context before proceeding.
finished_setup = models.BooleanField(
# Default to true so we don't impact existing users. We set this to false downstream.
default=True
)
def __str__(self): def __str__(self):
# this info is pulled from Login.gov # this info is pulled from Login.gov
if self.first_name or self.last_name: if self.first_name or self.last_name:

View file

@ -14,7 +14,7 @@ from registrar.models.contact import Contact
from registrar.models.user import User from registrar.models.user import User
from registrar.utility import StrEnum from registrar.utility import StrEnum
from registrar.views.utility import StepsHelper from registrar.views.utility import StepsHelper
from registrar.views.utility.permission_views import DomainRequestPermissionDeleteView from registrar.views.utility.permission_views import DomainRequestPermissionDeleteView, ContactPermissionView
from .utility import ( from .utility import (
DomainRequestPermissionView, DomainRequestPermissionView,
@ -819,3 +819,10 @@ class DomainRequestDeleteView(DomainRequestPermissionDeleteView):
duplicates = [item for item, count in object_dict.items() if count > 1] duplicates = [item for item, count in object_dict.items() if count > 1]
return duplicates return duplicates
class FinishContactProfileSetupView(ContactPermissionView):
"""This view forces the user into providing additional details that
we may have missed from Login.gov"""
template_name = "domain_request_your_contact.html"
forms = [forms.YourContactForm]

View file

@ -8,6 +8,7 @@ from registrar.models import (
DomainInvitation, DomainInvitation,
DomainInformation, DomainInformation,
UserDomainRole, UserDomainRole,
Contact,
) )
import logging import logging
@ -324,6 +325,38 @@ class UserDeleteDomainRolePermission(PermissionsLoginMixin):
return True return True
class ContactPermission(PermissionsLoginMixin):
"""Permission mixin for UserDomainRole if user
has access, otherwise 403"""
def has_permission(self):
"""Check if this user has access to this domain request.
The user is in self.request.user and the domain needs to be looked
up from the domain's primary key in self.kwargs["pk"]
"""
# Check if the user is authenticated
if not self.request.user.is_authenticated:
return False
user_pk = self.kwargs["pk"]
# Check if the user has an associated contact
associated_contacts = Contact.objects.filter(user=user_pk)
associated_contacts_length = len(associated_contacts)
if associated_contacts_length == 0:
# This means that the user trying to access this page
# is a different user than the contact holder.
return False
elif associated_contacts_length > 1:
# TODO - change this
raise ValueError("User has multiple connected contacts")
else:
return True
class DomainRequestPermissionWithdraw(PermissionsLoginMixin): class DomainRequestPermissionWithdraw(PermissionsLoginMixin):
"""Permission mixin that redirects to withdraw action on domain request """Permission mixin that redirects to withdraw action on domain request
if user has access, otherwise 403""" if user has access, otherwise 403"""

View file

@ -3,8 +3,7 @@
import abc # abstract base class import abc # abstract base class
from django.views.generic import DetailView, DeleteView, TemplateView from django.views.generic import DetailView, DeleteView, TemplateView
from registrar.models import Domain, DomainRequest, DomainInvitation from registrar.models import Domain, DomainRequest, DomainInvitation, UserDomainRole, Contact
from registrar.models.user_domain_role import UserDomainRole
from .mixins import ( from .mixins import (
DomainPermission, DomainPermission,
@ -13,6 +12,7 @@ from .mixins import (
DomainInvitationPermission, DomainInvitationPermission,
DomainRequestWizardPermission, DomainRequestWizardPermission,
UserDeleteDomainRolePermission, UserDeleteDomainRolePermission,
ContactPermission,
) )
import logging import logging
@ -142,3 +142,22 @@ class UserDomainRolePermissionDeleteView(UserDeleteDomainRolePermission, DeleteV
# variable name in template context for the model object # variable name in template context for the model object
context_object_name = "userdomainrole" context_object_name = "userdomainrole"
class ContactPermissionView(ContactPermission, DetailView, abc.ABC):
"""Abstract base view for domain requests that enforces permissions
This abstract view cannot be instantiated. Actual views must specify
`template_name`.
"""
# DetailView property for what model this is viewing
model = Contact
# variable name in template context for the model object
context_object_name = "Contact"
# Abstract property enforces NotImplementedError on an attribute.
@property
@abc.abstractmethod
def template_name(self):
raise NotImplementedError