diff --git a/src/registrar/tests/test_views_domains_json.py b/src/registrar/tests/test_views_domains_json.py index 70ae23b43..07799104b 100644 --- a/src/registrar/tests/test_views_domains_json.py +++ b/src/registrar/tests/test_views_domains_json.py @@ -1,9 +1,13 @@ from registrar.models import UserDomainRole, Domain, DomainInformation, Portfolio from django.urls import reverse + +from registrar.models.user_portfolio_permission import UserPortfolioPermission +from registrar.models.utility.portfolio_helper import UserPortfolioPermissionChoices, UserPortfolioRoleChoices from .test_views import TestWithUser from django_webtest import WebTest # type: ignore from django.utils.dateparse import parse_date from api.tests.common import less_console_noise_decorator +from waffle.testutils import override_flag class GetDomainsJsonTest(TestWithUser, WebTest): @@ -31,6 +35,7 @@ class GetDomainsJsonTest(TestWithUser, WebTest): def tearDown(self): UserDomainRole.objects.all().delete() + UserPortfolioPermission.objects.all().delete() DomainInformation.objects.all().delete() Portfolio.objects.all().delete() super().tearDown() @@ -115,8 +120,104 @@ class GetDomainsJsonTest(TestWithUser, WebTest): self.assertEqual(svg_icon_expected, svg_icons[i]) @less_console_noise_decorator - def test_get_domains_json_with_portfolio(self): - """Test that an authenticated user gets the list of 2 domains for portfolio.""" + @override_flag("organization_feature", active=True) + def test_get_domains_json_with_portfolio_view_managed_domains(self): + """Test that an authenticated user gets the list of 1 domain for portfolio. The 1 domain + is the domain that they manage within the portfolio.""" + + UserPortfolioPermission.objects.get_or_create( + user=self.user, + portfolio=self.portfolio, + roles=[UserPortfolioRoleChoices.ORGANIZATION_MEMBER], + additional_permissions=[UserPortfolioPermissionChoices.VIEW_MANAGED_DOMAINS], + ) + + response = self.app.get(reverse("get_domains_json"), {"portfolio": self.portfolio.id}) + self.assertEqual(response.status_code, 200) + data = response.json + + # Check pagination info + self.assertEqual(data["page"], 1) + self.assertFalse(data["has_next"]) + self.assertFalse(data["has_previous"]) + self.assertEqual(data["num_pages"], 1) + + # Check the number of domains + self.assertEqual(len(data["domains"]), 1) + + # Expected domains + expected_domains = [self.domain3] + + # Extract fields from response + domain_ids = [domain["id"] for domain in data["domains"]] + names = [domain["name"] for domain in data["domains"]] + expiration_dates = [domain["expiration_date"] for domain in data["domains"]] + states = [domain["state"] for domain in data["domains"]] + state_displays = [domain["state_display"] for domain in data["domains"]] + get_state_help_texts = [domain["get_state_help_text"] for domain in data["domains"]] + action_urls = [domain["action_url"] for domain in data["domains"]] + action_labels = [domain["action_label"] for domain in data["domains"]] + svg_icons = [domain["svg_icon"] for domain in data["domains"]] + + # Check fields for each domain + for i, expected_domain in enumerate(expected_domains): + self.assertEqual(expected_domain.id, domain_ids[i]) + self.assertEqual(expected_domain.name, names[i]) + self.assertEqual(expected_domain.expiration_date, expiration_dates[i]) + self.assertEqual(expected_domain.state, states[i]) + + # Parsing the expiration date from string to date + parsed_expiration_date = parse_date(expiration_dates[i]) + expected_domain.expiration_date = parsed_expiration_date + + # Check state_display and get_state_help_text + self.assertEqual(expected_domain.state_display(), state_displays[i]) + self.assertEqual(expected_domain.get_state_help_text(), get_state_help_texts[i]) + + self.assertEqual(reverse("domain", kwargs={"pk": expected_domain.id}), action_urls[i]) + + # Check action_label + user_domain_role_exists = UserDomainRole.objects.filter( + domain_id=expected_domains[i].id, user=self.user + ).exists() + action_label_expected = ( + "View" + if not user_domain_role_exists + or expected_domains[i].state + in [ + Domain.State.DELETED, + Domain.State.ON_HOLD, + ] + else "Manage" + ) + self.assertEqual(action_label_expected, action_labels[i]) + + # Check svg_icon + svg_icon_expected = ( + "visibility" + if not user_domain_role_exists + or expected_domains[i].state + in [ + Domain.State.DELETED, + Domain.State.ON_HOLD, + ] + else "settings" + ) + self.assertEqual(svg_icon_expected, svg_icons[i]) + + @less_console_noise_decorator + @override_flag("organization_feature", active=True) + def test_get_domains_json_with_portfolio_view_all_domains(self): + """Test that an authenticated user gets the list of 2 domains for portfolio. One is a domain which + they manage within the portfolio. The other is a domain which they don't manage within the + portfolio.""" + + UserPortfolioPermission.objects.get_or_create( + user=self.user, + portfolio=self.portfolio, + roles=[UserPortfolioRoleChoices.ORGANIZATION_MEMBER], + additional_permissions=[UserPortfolioPermissionChoices.VIEW_ALL_DOMAINS], + ) response = self.app.get(reverse("get_domains_json"), {"portfolio": self.portfolio.id}) self.assertEqual(response.status_code, 200) diff --git a/src/registrar/views/domains_json.py b/src/registrar/views/domains_json.py index e72aac5fb..b26e92802 100644 --- a/src/registrar/views/domains_json.py +++ b/src/registrar/views/domains_json.py @@ -50,11 +50,15 @@ def get_domain_ids_from_request(request): """ portfolio = request.GET.get("portfolio") if portfolio: - domain_infos = DomainInformation.objects.filter(portfolio=portfolio) - return domain_infos.values_list("domain_id", flat=True) - else: - user_domain_roles = UserDomainRole.objects.filter(user=request.user) - return user_domain_roles.values_list("domain_id", flat=True) + if request.user.is_org_user(request) and request.user.has_view_all_domains_permission(portfolio): + domain_infos = DomainInformation.objects.filter(portfolio=portfolio) + return domain_infos.values_list("domain_id", flat=True) + else: + domain_info_ids = DomainInformation.objects.filter(portfolio=portfolio).values_list("domain_id", flat=True) + user_domain_roles = UserDomainRole.objects.filter(user=request.user).values_list("domain_id", flat=True) + return domain_info_ids.intersection(user_domain_roles) + user_domain_roles = UserDomainRole.objects.filter(user=request.user) + return user_domain_roles.values_list("domain_id", flat=True) def apply_search(queryset, request):