From ccf963f9d957c481fcc0738b421692485c5550c3 Mon Sep 17 00:00:00 2001 From: dogboat Date: Mon, 13 Apr 2026 11:12:04 -0400 Subject: [PATCH 01/47] wip --- ...ntity_hash_alter_urlevent_identity_hash.py | 24 +++ dojo/importers/location_manager.py | 156 +++++++++++++++--- dojo/location/models.py | 93 ++++++++++- dojo/url/models.py | 16 +- 4 files changed, 248 insertions(+), 41 deletions(-) create mode 100644 dojo/db_migrations/0264_alter_url_identity_hash_alter_urlevent_identity_hash.py diff --git a/dojo/db_migrations/0264_alter_url_identity_hash_alter_urlevent_identity_hash.py b/dojo/db_migrations/0264_alter_url_identity_hash_alter_urlevent_identity_hash.py new file mode 100644 index 00000000000..725612dbcc7 --- /dev/null +++ b/dojo/db_migrations/0264_alter_url_identity_hash_alter_urlevent_identity_hash.py @@ -0,0 +1,24 @@ +# Generated by Django 5.2.12 on 2026-04-10 14:24 + +import django.core.validators +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('dojo', '0263_language_type_unique_language'), + ] + + operations = [ + migrations.AlterField( + model_name='url', + name='identity_hash', + field=models.CharField(db_index=True, editable=False, help_text='The hash of the location for uniqueness', max_length=64, unique=True, validators=[django.core.validators.MinLengthValidator(64)]), + ), + migrations.AlterField( + model_name='urlevent', + name='identity_hash', + field=models.CharField(editable=False, help_text='The hash of the location for uniqueness', max_length=64, validators=[django.core.validators.MinLengthValidator(64)]), + ), + ] diff --git a/dojo/importers/location_manager.py b/dojo/importers/location_manager.py index 125ff922dc0..a59c7ec9508 100644 --- a/dojo/importers/location_manager.py +++ b/dojo/importers/location_manager.py @@ -1,5 +1,8 @@ +from __future__ import annotations + import logging -from typing import TypeVar +from itertools import groupby +from typing import TYPE_CHECKING, TypeVar 
from django.core.exceptions import ValidationError from django.db.models import QuerySet @@ -7,8 +10,8 @@ from dojo.celery import app from dojo.celery_dispatch import dojo_dispatch_task -from dojo.location.models import AbstractLocation, LocationFindingReference -from dojo.location.status import FindingLocationStatus +from dojo.location.models import AbstractLocation, LocationFindingReference, LocationProductReference +from dojo.location.status import FindingLocationStatus, ProductLocationStatus from dojo.models import ( Dojo_User, Finding, @@ -16,6 +19,9 @@ from dojo.tools.locations import LocationData from dojo.url.models import URL +if TYPE_CHECKING: + from dojo.models import Product + logger = logging.getLogger(__name__) @@ -34,17 +40,23 @@ def get_or_create_location(cls, unsaved_location: AbstractLocation) -> AbstractL logger.debug(f"IMPORT_SCAN: Unsupported location type: {type(unsaved_location)}") return None + @classmethod + def get_supported_location_types(cls) -> dict[str, type[AbstractLocation]]: + """Return a mapping of location type string to AbstractLocation subclass.""" + return {URL.get_location_type(): URL} + @classmethod def make_abstract_locations(cls, locations: list[UnsavedLocation]) -> list[AbstractLocation]: """Converts the list of unsaved locations (AbstractLocation/LocationData objects) to a list of AbstractLocations.""" + supported_types = cls.get_supported_location_types() abstract_locations = [] for location in locations: if isinstance(location, AbstractLocation): abstract_locations.append(location) - elif isinstance(location, LocationData) and location.type == URL.get_location_type(): + elif isinstance(location, LocationData) and (loc_cls := supported_types.get(location.type)): try: - abstract_locations.append(URL.from_location_data(location)) + abstract_locations.append(loc_cls.from_location_data(location)) except (ValidationError, ValueError): logger.debug("Skipping invalid location data: %s", location) else: @@ -52,6 +64,116 @@ def 
make_abstract_locations(cls, locations: list[UnsavedLocation]) -> list[Abstr return abstract_locations + @classmethod + def bulk_get_or_create_locations(cls, locations: list[UnsavedLocation]) -> list[AbstractLocation]: + """Bulk get-or-create a (possibly heterogeneous) list of AbstractLocations.""" + locations = cls.clean_unsaved_locations(locations) + if not locations: + return [] + + # Util method for sorting/keying; returns the (Python) identity of the location entry's Type + def type_id(x: tuple[int, AbstractLocation]) -> int: + return id(type(x[1])) + + saved = [] + # Group by actual AbstractLocation subtype, tracking the original ordering (hence the `enumerate`) + locations_with_idx = sorted(enumerate(locations), key=type_id) + locations_by_type = groupby(locations_with_idx, key=type_id) + for _, grouped_locations_with_idx in locations_by_type: + # Split into two lists: indices and homogenous location types + indices, grouped_locations = zip(*grouped_locations_with_idx) + # Determine the correct AbstractLocation class to use for bulk get/create + loc_cls = type(grouped_locations[0]) + # `.bulk_get_or_create` is expected to return the saved items in the order they were submitted + saved_locations = loc_cls.bulk_get_or_create(grouped_locations) + # Zip 'em back together: associate the saved instance with its original index in the `locations` list + saved.extend((idx, saved_loc) for idx, saved_loc in zip(indices, saved_locations)) + + # Sort by index to return in original order + saved.sort(key=lambda x: x[0]) + return [loc for _, loc in saved] + + @classmethod + def bulk_create_refs( + cls, + locations: list[AbstractLocation], + *, + finding: Finding | None = None, + product: Product | None = None, + ) -> None: + """Bulk create LocationFindingReference and/or LocationProductReference rows. + + Iterates the unsaved/saved pairs once, building both finding and product + refs in a single pass. Skips refs that already exist in the DB. 
+ """ + if not locations: + return + + if not finding and not product: + error_message = "One of 'finding' or 'product' must be provided." + raise ValueError(error_message) + + if finding: + # If associating with a finding, use its product regardless of whatever's set. Keeps in line with the + # original intended purpose: this is a bulk version of Location.(associate_with_finding|associate_with_product) + product = finding.test.engagement.product + + location_ids = [loc.location_id for loc in locations] + + # Pre-fetch existing refs to avoid duplicates + existing_finding_refs = set() + existing_product_refs = set() + if finding is not None: + existing_finding_refs = set( + LocationFindingReference.objects.filter( + location_id__in=location_ids, + finding=finding, + ).values_list("location_id", flat=True) + ) + if product is not None: + existing_product_refs = set( + LocationProductReference.objects.filter( + location_id__in=location_ids, + product=product, + ).values_list("location_id", flat=True) + ) + + new_finding_refs = [] + new_product_refs = [] + # Process locations (unsaved, with possible association data) alongside their corresponding saved versions, + # which do not contain that information. We can do this because the bulk get/create operations are stable. 
+ for location in locations: + assoc = location.get_association_data() + + if finding is not None and location.location_id not in existing_finding_refs: + new_finding_refs.append(LocationFindingReference( + location_id=location.location_id, + finding=finding, + status=FindingLocationStatus.Active, + relationship=assoc.relationship_type, + relationship_data=assoc.relationship_data, + )) + existing_finding_refs.add(location.location_id) + + if product is not None and location.location_id not in existing_product_refs: + new_product_refs.append(LocationProductReference( + location_id=location.location_id, + product=product, + status=ProductLocationStatus.Active, + relationship=assoc.relationship_type, + relationship_data=assoc.relationship_data, + )) + existing_product_refs.add(location.location_id) + + if new_finding_refs: + LocationFindingReference.objects.bulk_create( + new_finding_refs, batch_size=1000, ignore_conflicts=True, + ) + if new_product_refs: + LocationProductReference.objects.bulk_create( + new_product_refs, batch_size=1000, ignore_conflicts=True, + ) + @classmethod def _add_locations_to_unsaved_finding( cls, @@ -59,26 +181,10 @@ def _add_locations_to_unsaved_finding( locations: list[UnsavedLocation], **kwargs: dict, # noqa: ARG003 ) -> None: - """Creates AbstractLocation objects from the given list and links them to the given finding.""" - locations = cls.clean_unsaved_locations(locations) - - logger.debug(f"IMPORT_SCAN: Adding {len(locations)} locations to finding: {finding}") - - # LOCATION LOCATION LOCATION - # TODO: bulk create the finding/product refs... 
- locations_saved = 0 - for unsaved_location in locations: - if saved_location := cls.get_or_create_location(unsaved_location): - locations_saved += 1 - association_data = unsaved_location.get_association_data() - saved_location.location.associate_with_finding( - finding, - status=FindingLocationStatus.Active, - relationship=association_data.relationship_type, - relationship_data=association_data.relationship_data, - ) - - logger.debug(f"IMPORT_SCAN: {locations_saved} locations imported") + """Creates AbstractLocation objects from the given list and links them to the given Finding and its Product.""" + locations = cls.bulk_get_or_create_locations(locations) + cls.bulk_create_refs(locations, finding=finding, product=finding.test.engagement.product) + logger.debug(f"LocationManager: {len(locations)} locations associated with {finding}") @app.task def add_locations_to_unsaved_finding( diff --git a/dojo/location/models.py b/dojo/location/models.py index 3ab313ace87..7a4721b52a0 100644 --- a/dojo/location/models.py +++ b/dojo/location/models.py @@ -1,7 +1,9 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Self, TypeVar +import hashlib +from typing import TYPE_CHECKING, Self, TypeVar, Iterable +from django.core.validators import MinLengthValidator from django.db import transaction from django.db.models import ( CASCADE, @@ -263,10 +265,24 @@ class AbstractLocation(BaseModelWithoutTimeMeta): null=False, related_name="%(class)s", ) + identity_hash = CharField( + null=False, + blank=False, + max_length=64, + editable=False, + unique=True, + db_index=True, + validators=[MinLengthValidator(64)], + help_text=_("The hash of the location for uniqueness"), + ) class Meta: abstract = True + def clean(self, *args: list, **kwargs: dict) -> None: + self.set_identity_hash() + super().clean(*args, **kwargs) + @classmethod def get_location_type(cls) -> str: """Return the type of location (e.g., 'url').""" @@ -287,6 +303,9 @@ def create_location_from_value(value: 
str) -> T: msg = "Subclasses must implement create_location_from_value" raise NotImplementedError(msg) + def set_identity_hash(self): + self.identity_hash = hashlib.blake2b(str(self).encode(), digest_size=32).hexdigest() + def pre_save_logic(self): """Automatically create or update the associated Location.""" location_value = self.get_location_value() @@ -333,6 +352,78 @@ def get_or_create_from_object(cls: T, location: T) -> T: msg = "Subclasses must implement get_or_create_from_object" raise NotImplementedError(msg) + @classmethod + def bulk_get_or_create(cls: type[T], locations: Iterable[T]) -> list[T]: + """Get or create multiple locations in bulk. + + For each location, looks up by identity_hash. Creates missing ones using + bulk_create for both the parent Location rows and the subtype rows. + Returns the full list of saved instances (existing + newly created), + in the same order as the input. Duplicate inputs map to the same saved instance. + """ + if not locations: + return [] + + # Create the list of hashes of the supplied locations; we will also use this to reconstruct the initial ordering + # of locations we return (which would otherwise be lost if duplicates are represented in `locations`). 
+ hashes = [] + for loc in locations: + # Sanity check the given locations list is homogenous + if not isinstance(loc, cls): + error_message = f"Invalid location type; expected {cls} but got {type(loc)}" + raise ValueError(error_message) + # Set .identity_hash if not present + if not loc.identity_hash: + loc.clean() + hashes.append(loc.identity_hash) + + # Look up existing objects, grouping by hash + existing_by_hash = { + obj.identity_hash: obj + for obj in cls.objects.filter(identity_hash__in=hashes).select_related("location") + } + + # Create the list of new locations to create + new_locations = [] + for loc in locations: + if loc.identity_hash not in existing_by_hash: + new_locations.append(loc) + # Mark it so we don't try to create duplicates within the same batch + existing_by_hash[loc.identity_hash] = loc + else: + # Preserve association data from the input onto the existing saved object + saved = existing_by_hash[loc.identity_hash] + if hasattr(loc, "_association_data") and not hasattr(saved, "_association_data"): + saved._association_data = loc._association_data + + # Create 'em + if new_locations: + location_type = cls.get_location_type() + with transaction.atomic(): + # Bulk create parent Locations + parents = [ + Location( + location_type=location_type, + location_value=loc.get_location_value(), + ) + for loc in new_locations + ] + Location.objects.bulk_create(parents, batch_size=1000) + + # Assign Location FKs to the subtypes, then bulk create them. + for loc, parent in zip(new_locations, parents): + loc.location_id = parent.id + loc.location = parent + # Note there is a subtle race condition here, if somehow one of our newly-created locations conflicts + # with an existing one (e.g. from a separate thread that commits while this is running). Setting + # `ignore_conflicts=True` here would prevent this step from raising an IntegrityError, but would leave + # dangling parent Location objects that were created above. 
Rather than performing a cleanup in that + # (unlikely?) case, just allow the transaction to rollback. + cls.objects.bulk_create(new_locations, batch_size=1000) + + # Return in input order (minus dupes) + return [existing_by_hash[h] for h in hashes] + class ReferenceDataMixin(Model): diff --git a/dojo/url/models.py b/dojo/url/models.py index 06f7e2cd008..cc5ae338ddf 100644 --- a/dojo/url/models.py +++ b/dojo/url/models.py @@ -1,6 +1,5 @@ from __future__ import annotations -import hashlib import ipaddress from contextlib import suppress from dataclasses import dataclass @@ -9,7 +8,7 @@ import idna from django.core.exceptions import ValidationError -from django.core.validators import MaxValueValidator, MinLengthValidator, MinValueValidator +from django.core.validators import MaxValueValidator, MinValueValidator from django.db import IntegrityError, transaction from django.db.models import ( BooleanField, @@ -185,15 +184,6 @@ class URL(AbstractLocation): blank=False, help_text="Dictates whether the endpoint was found to have host validation issues during creation", ) - identity_hash = CharField( - null=False, - blank=False, - max_length=64, - editable=False, - unique=True, - validators=[MinLengthValidator(64)], - help_text="The hash of the URL for uniqueness", - ) objects = URLManager().from_queryset(URLQueryset)() @@ -262,7 +252,6 @@ def clean(self, *args: list, **kwargs: dict) -> None: self.clean_path() self.clean_query() self.clean_fragment() - self.set_identity_hash() super().clean(*args, **kwargs) def clean_protocol(self) -> None: @@ -323,9 +312,6 @@ def clean_query(self) -> None: else: self.query = self.replace_null_bytes(self.query.strip().removeprefix("?")) - def set_identity_hash(self): - self.identity_hash = hashlib.blake2b(str(self).encode(), digest_size=32).hexdigest() - def replace_null_bytes(self, value: str) -> str: return value.replace("\x00", "%00") From 695877a7501c0a1b067f5b6ad4df3e843faa5b6c Mon Sep 17 00:00:00 2001 From: dogboat Date: Tue, 14 Apr 
2026 12:54:48 -0400 Subject: [PATCH 02/47] comment --- dojo/location/models.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dojo/location/models.py b/dojo/location/models.py index 7a4721b52a0..4c88ce7deef 100644 --- a/dojo/location/models.py +++ b/dojo/location/models.py @@ -391,7 +391,8 @@ def bulk_get_or_create(cls: type[T], locations: Iterable[T]) -> list[T]: # Mark it so we don't try to create duplicates within the same batch existing_by_hash[loc.identity_hash] = loc else: - # Preserve association data from the input onto the existing saved object + # Preserve association data from the input onto the existing saved object, in case we're associating + # existing locations with findings/products saved = existing_by_hash[loc.identity_hash] if hasattr(loc, "_association_data") and not hasattr(saved, "_association_data"): saved._association_data = loc._association_data From 1b8b02646ff2f86e62e72376bbf7713da83a771c Mon Sep 17 00:00:00 2001 From: dogboat Date: Tue, 14 Apr 2026 12:55:44 -0400 Subject: [PATCH 03/47] simplify --- dojo/importers/location_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dojo/importers/location_manager.py b/dojo/importers/location_manager.py index a59c7ec9508..cafeca5a345 100644 --- a/dojo/importers/location_manager.py +++ b/dojo/importers/location_manager.py @@ -183,7 +183,7 @@ def _add_locations_to_unsaved_finding( ) -> None: """Creates AbstractLocation objects from the given list and links them to the given Finding and its Product.""" locations = cls.bulk_get_or_create_locations(locations) - cls.bulk_create_refs(locations, finding=finding, product=finding.test.engagement.product) + cls.bulk_create_refs(locations, finding=finding) logger.debug(f"LocationManager: {len(locations)} locations associated with {finding}") @app.task From 23c18397ad7cab69c0d5a5dd6c988958676bf2ff Mon Sep 17 00:00:00 2001 From: dogboat Date: Tue, 14 Apr 2026 13:32:40 -0400 Subject: [PATCH 04/47] tests --- 
unittests/test_bulk_locations.py | 290 +++++++++++++++++++++++++++++++ 1 file changed, 290 insertions(+) create mode 100644 unittests/test_bulk_locations.py diff --git a/unittests/test_bulk_locations.py b/unittests/test_bulk_locations.py new file mode 100644 index 00000000000..0a57661a1b1 --- /dev/null +++ b/unittests/test_bulk_locations.py @@ -0,0 +1,290 @@ +"""Tests for bulk location creation and association (open-source, URL-only). + +Covers: +- AbstractLocation.bulk_get_or_create (on URL) +- LocationManager.bulk_get_or_create_locations (URL-only) +- LocationManager.bulk_create_refs (finding + product refs) +- LocationManager._add_locations_to_unsaved_finding (end-to-end) +- Query efficiency +""" + +from unittest.mock import patch + +from django.db import connection +from django.test.utils import CaptureQueriesContext + +from dojo.importers.location_manager import LocationManager +from dojo.location.models import Location, LocationFindingReference, LocationProductReference +from dojo.models import Product, Product_Type, Test_Type +from dojo.tools.locations import LocationAssociationData, LocationData +from dojo.url.models import URL +from unittests.dojo_test_case import DojoTestCase, skip_unless_v3 + + +def _make_url(host, path=""): + url = URL(protocol="https", host=host, path=path) + url.clean() + return url + + +def _make_finding(): + from django.contrib.auth import get_user_model + from dojo.models import Engagement, Finding, Test + + User = get_user_model() + user, _ = User.objects.get_or_create(username="bulk_test_user", defaults={"is_active": True}) + pt, _ = Product_Type.objects.get_or_create(name="Bulk Test Type") + product = Product.objects.create(name="Bulk Test Product", description="test", prod_type=pt) + eng = Engagement.objects.create(product=product, target_start="2026-01-01", target_end="2026-12-31") + tt, _ = Test_Type.objects.get_or_create(name="Bulk Test") + test = Test.objects.create(engagement=eng, test_type=tt, target_start="2026-01-01", 
target_end="2026-12-31") + return Finding.objects.create(test=test, title="Bulk Test Finding", severity="Medium", reporter=user) + + +# --------------------------------------------------------------------------- +# AbstractLocation.bulk_get_or_create (URL) +# --------------------------------------------------------------------------- +@skip_unless_v3 +class TestBulkGetOrCreateURL(DojoTestCase): + + def test_all_new(self): + urls = [_make_url(f"oss-new-{i}.example.com") for i in range(5)] + saved = URL.bulk_get_or_create(urls) + + self.assertEqual(len(saved), 5) + self.assertTrue(all(s.pk is not None for s in saved)) + self.assertTrue(all(s.location_id is not None for s in saved)) + self.assertEqual(URL.objects.filter(pk__in=[s.pk for s in saved]).count(), 5) + + def test_all_existing(self): + originals = [URL.get_or_create_from_object(_make_url(f"oss-existing-{i}.example.com")) for i in range(3)] + + urls = [_make_url(f"oss-existing-{i}.example.com") for i in range(3)] + saved = URL.bulk_get_or_create(urls) + + self.assertEqual(len(saved), 3) + self.assertEqual({s.pk for s in saved}, {o.pk for o in originals}) + + def test_mixed_new_and_existing(self): + existing = URL.get_or_create_from_object(_make_url("oss-mixed-existing.example.com")) + + urls = [ + _make_url("oss-mixed-existing.example.com"), + _make_url("oss-mixed-new.example.com"), + ] + saved = URL.bulk_get_or_create(urls) + + self.assertEqual(len(saved), 2) + self.assertEqual(saved[0].pk, existing.pk) + self.assertIsNotNone(saved[1].pk) + + def test_duplicates_in_input(self): + urls = [ + _make_url("oss-dupe.example.com"), + _make_url("oss-dupe.example.com"), + _make_url("oss-unique.example.com"), + ] + saved = URL.bulk_get_or_create(urls) + + self.assertEqual(len(saved), 3) + self.assertEqual(saved[0].pk, saved[1].pk) + self.assertNotEqual(saved[2].pk, saved[0].pk) + self.assertEqual(URL.objects.filter(host__in=["oss-dupe.example.com", "oss-unique.example.com"]).count(), 2) + + def 
test_preserves_association_data_on_new(self): + url = _make_url("oss-assoc-new.example.com") + url._association_data = LocationAssociationData( + relationship_type="owned_by", + relationship_data={"file_path": "/src/main.py"}, + ) + + saved = URL.bulk_get_or_create([url]) + + self.assertEqual(saved[0].get_association_data().relationship_type, "owned_by") + + def test_copies_association_data_to_existing(self): + URL.get_or_create_from_object(_make_url("oss-assoc-existing.example.com")) + + url = _make_url("oss-assoc-existing.example.com") + url._association_data = LocationAssociationData(relationship_type="used_by") + + saved = URL.bulk_get_or_create([url]) + + self.assertEqual(saved[0].get_association_data().relationship_type, "used_by") + + def test_empty_input(self): + self.assertEqual(URL.bulk_get_or_create([]), []) + + def test_parent_location_created(self): + saved = URL.bulk_get_or_create([_make_url("oss-parent.example.com")]) + + loc = Location.objects.get(pk=saved[0].location_id) + self.assertEqual(loc.location_type, "url") + self.assertIn("oss-parent.example.com", loc.location_value) + + def test_transaction_atomicity(self): + initial_count = Location.objects.count() + urls = [_make_url("oss-atomic.example.com")] + + with patch.object(URL.objects, "bulk_create", side_effect=Exception("boom")): + with self.assertRaisesMessage(Exception, "boom"): + URL.bulk_get_or_create(urls) + + self.assertEqual(Location.objects.count(), initial_count) + + +# --------------------------------------------------------------------------- +# LocationManager.bulk_get_or_create_locations (URL-only) +# --------------------------------------------------------------------------- +@skip_unless_v3 +class TestBulkGetOrCreateLocations(DojoTestCase): + + def test_url_only(self): + urls = [_make_url("oss-loc-mgr.example.com")] + saved = LocationManager.bulk_get_or_create_locations(urls) + + self.assertEqual(len(saved), 1) + self.assertIsInstance(saved[0], URL) + + def 
test_cleans_location_data(self): + loc_data = LocationData(type="url", data={"url": "https://oss-from-data.example.com/api"}) + saved = LocationManager.bulk_get_or_create_locations([loc_data]) + + self.assertEqual(len(saved), 1) + self.assertIsInstance(saved[0], URL) + self.assertEqual(saved[0].host, "oss-from-data.example.com") + + def test_empty_input(self): + self.assertEqual(LocationManager.bulk_get_or_create_locations([]), []) + + +# --------------------------------------------------------------------------- +# LocationManager.bulk_create_refs +# --------------------------------------------------------------------------- +@skip_unless_v3 +class TestBulkCreateRefs(DojoTestCase): + + def test_creates_finding_and_product_refs(self): + finding = _make_finding() + product = finding.test.engagement.product + + saved = URL.bulk_get_or_create([_make_url("oss-refs-both.example.com")]) + LocationManager.bulk_create_refs(saved, finding=finding) + + self.assertTrue(LocationFindingReference.objects.filter( + location_id=saved[0].location_id, finding=finding, + ).exists()) + self.assertTrue(LocationProductReference.objects.filter( + location_id=saved[0].location_id, product=product, + ).exists()) + + def test_creates_product_refs_only(self): + pt, _ = Product_Type.objects.get_or_create(name="Refs Test Type") + product = Product.objects.create(name="Refs Test Product", description="test", prod_type=pt) + + saved = URL.bulk_get_or_create([_make_url("oss-refs-product.example.com")]) + LocationManager.bulk_create_refs(saved, product=product) + + self.assertTrue(LocationProductReference.objects.filter( + location_id=saved[0].location_id, product=product, + ).exists()) + self.assertFalse(LocationFindingReference.objects.filter( + location_id=saved[0].location_id, + ).exists()) + + def test_skips_existing_refs(self): + finding = _make_finding() + saved = URL.bulk_get_or_create([_make_url("oss-refs-existing.example.com")]) + + LocationManager.bulk_create_refs(saved, 
finding=finding) + LocationManager.bulk_create_refs(saved, finding=finding) + + self.assertEqual(LocationFindingReference.objects.filter( + location_id=saved[0].location_id, finding=finding, + ).count(), 1) + + def test_uses_association_data(self): + finding = _make_finding() + url = _make_url("oss-refs-assoc.example.com") + url._association_data = LocationAssociationData( + relationship_type="owned_by", + relationship_data={"file_path": "/app/main.py"}, + ) + saved = URL.bulk_get_or_create([url]) + LocationManager.bulk_create_refs(saved, finding=finding) + + ref = LocationFindingReference.objects.get( + location_id=saved[0].location_id, finding=finding, + ) + self.assertEqual(ref.relationship, "owned_by") + self.assertEqual(ref.relationship_data, {"file_path": "/app/main.py"}) + + def test_raises_without_finding_or_product(self): + saved = URL.bulk_get_or_create([_make_url("oss-refs-error.example.com")]) + with self.assertRaises(ValueError): + LocationManager.bulk_create_refs(saved) + + def test_empty_locations(self): + LocationManager.bulk_create_refs([], finding=_make_finding()) + + def test_finding_implies_product(self): + finding = _make_finding() + product = finding.test.engagement.product + saved = URL.bulk_get_or_create([_make_url("oss-refs-implied.example.com")]) + + LocationManager.bulk_create_refs(saved, finding=finding) + + self.assertTrue(LocationProductReference.objects.filter( + location_id=saved[0].location_id, product=product, + ).exists()) + + +# --------------------------------------------------------------------------- +# End-to-end: _add_locations_to_unsaved_finding +# --------------------------------------------------------------------------- +@skip_unless_v3 +class TestAddLocationsToUnsavedFinding(DojoTestCase): + + def test_full_pipeline(self): + finding = _make_finding() + product = finding.test.engagement.product + + loc_data = [ + LocationData(type="url", data={"url": "https://oss-e2e-1.example.com/api"}), + LocationData(type="url", 
data={"url": "https://oss-e2e-2.example.com/api"}), + ] + + LocationManager._add_locations_to_unsaved_finding(finding, loc_data) + + self.assertEqual(LocationFindingReference.objects.filter(finding=finding).count(), 2) + self.assertEqual(LocationProductReference.objects.filter(product=product).count(), 2) + + def test_empty_locations(self): + finding = _make_finding() + LocationManager._add_locations_to_unsaved_finding(finding, []) + self.assertEqual(LocationFindingReference.objects.filter(finding=finding).count(), 0) + + def test_idempotent(self): + finding = _make_finding() + loc_data = [LocationData(type="url", data={"url": "https://oss-idempotent.example.com"})] + + LocationManager._add_locations_to_unsaved_finding(finding, loc_data) + LocationManager._add_locations_to_unsaved_finding(finding, loc_data) + + self.assertEqual(LocationFindingReference.objects.filter(finding=finding).count(), 1) + + +# --------------------------------------------------------------------------- +# Query efficiency +# --------------------------------------------------------------------------- +@skip_unless_v3 +class TestBulkQueryEfficiency(DojoTestCase): + + def test_bulk_fewer_queries_than_locations(self): + urls = [_make_url(f"oss-perf-{i}.example.com") for i in range(50)] + + with CaptureQueriesContext(connection) as ctx: + URL.bulk_get_or_create(urls) + + # Expected: ~3 queries (SELECT existing, INSERT parents, INSERT subtypes) + self.assertLess(len(ctx.captured_queries), 10) From 479d0de7df42e81e2181c174e77d75a81d5c0694 Mon Sep 17 00:00:00 2001 From: dogboat Date: Tue, 14 Apr 2026 13:47:35 -0400 Subject: [PATCH 05/47] linter fixes --- dojo/importers/location_manager.py | 23 +++++++++++------------ dojo/location/models.py | 24 +++++++++++------------- unittests/test_bulk_locations.py | 18 +++++++++--------- 3 files changed, 31 insertions(+), 34 deletions(-) diff --git a/dojo/importers/location_manager.py b/dojo/importers/location_manager.py index cafeca5a345..a48e39b728f 100644 
--- a/dojo/importers/location_manager.py +++ b/dojo/importers/location_manager.py @@ -2,25 +2,23 @@ import logging from itertools import groupby +from operator import itemgetter from typing import TYPE_CHECKING, TypeVar from django.core.exceptions import ValidationError -from django.db.models import QuerySet from django.utils import timezone from dojo.celery import app from dojo.celery_dispatch import dojo_dispatch_task from dojo.location.models import AbstractLocation, LocationFindingReference, LocationProductReference from dojo.location.status import FindingLocationStatus, ProductLocationStatus -from dojo.models import ( - Dojo_User, - Finding, -) from dojo.tools.locations import LocationData from dojo.url.models import URL if TYPE_CHECKING: - from dojo.models import Product + from django.db.models import QuerySet + + from dojo.models import Dojo_User, Finding, Product logger = logging.getLogger(__name__) @@ -81,16 +79,16 @@ def type_id(x: tuple[int, AbstractLocation]) -> int: locations_by_type = groupby(locations_with_idx, key=type_id) for _, grouped_locations_with_idx in locations_by_type: # Split into two lists: indices and homogenous location types - indices, grouped_locations = zip(*grouped_locations_with_idx) + indices, grouped_locations = zip(*grouped_locations_with_idx, strict=True) # Determine the correct AbstractLocation class to use for bulk get/create loc_cls = type(grouped_locations[0]) # `.bulk_get_or_create` is expected to return the saved items in the order they were submitted saved_locations = loc_cls.bulk_get_or_create(grouped_locations) # Zip 'em back together: associate the saved instance with its original index in the `locations` list - saved.extend((idx, saved_loc) for idx, saved_loc in zip(indices, saved_locations)) + saved.extend((idx, saved_loc) for idx, saved_loc in zip(indices, saved_locations, strict=True)) # Sort by index to return in original order - saved.sort(key=lambda x: x[0]) + saved.sort(key=itemgetter(0)) return [loc for _, 
loc in saved] @classmethod @@ -101,7 +99,8 @@ def bulk_create_refs( finding: Finding | None = None, product: Product | None = None, ) -> None: - """Bulk create LocationFindingReference and/or LocationProductReference rows. + """ + Bulk create LocationFindingReference and/or LocationProductReference rows. Iterates the unsaved/saved pairs once, building both finding and product refs in a single pass. Skips refs that already exist in the DB. @@ -128,14 +127,14 @@ def bulk_create_refs( LocationFindingReference.objects.filter( location_id__in=location_ids, finding=finding, - ).values_list("location_id", flat=True) + ).values_list("location_id", flat=True), ) if product is not None: existing_product_refs = set( LocationProductReference.objects.filter( location_id__in=location_ids, product=product, - ).values_list("location_id", flat=True) + ).values_list("location_id", flat=True), ) new_finding_refs = [] diff --git a/dojo/location/models.py b/dojo/location/models.py index 4c88ce7deef..112d2c8b644 100644 --- a/dojo/location/models.py +++ b/dojo/location/models.py @@ -1,7 +1,7 @@ from __future__ import annotations import hashlib -from typing import TYPE_CHECKING, Self, TypeVar, Iterable +from typing import TYPE_CHECKING, Self from django.core.validators import MinLengthValidator from django.db import transaction @@ -38,6 +38,7 @@ from dojo.tools.locations import LocationAssociationData if TYPE_CHECKING: + from collections.abc import Iterable from datetime import datetime from dojo.tools.locations import LocationData @@ -253,10 +254,6 @@ class Meta: ] -# TypeVar to help linting in AbstractLocation child classes -T = TypeVar("T", bound="AbstractLocation") - - class AbstractLocation(BaseModelWithoutTimeMeta): location = OneToOneField( Location, @@ -295,7 +292,7 @@ def get_location_value(self) -> str: raise NotImplementedError(msg) @staticmethod - def create_location_from_value(value: str) -> T: + def create_location_from_value(value: str) -> Self: """ Dynamically create a 
Location and subclass instance based on location_type and location_value. Uses parse_string_value from the correct subclass. @@ -322,7 +319,7 @@ def pre_save_logic(self): self.location.save(update_fields=["location_type", "location_value"]) @classmethod - def from_location_data(cls: T, location_data: LocationData) -> T: + def from_location_data(cls, location_data: LocationData) -> Self: """ Checks that the given LocationData object represents this type, then calls #_from_location_data_impl() to build one based on its contents. Saving boilerplate checking is all. @@ -333,7 +330,7 @@ def from_location_data(cls: T, location_data: LocationData) -> T: return cls._from_location_data_impl(location_data) @classmethod - def _from_location_data_impl(cls: T, location_data: LocationData) -> T: + def _from_location_data_impl(cls, location_data: LocationData) -> Self: """Given a LocationData object trusted to represent this type, build a Location object from its contents.""" msg = "Subclasses must implement _from_location_data_impl" raise NotImplementedError(msg) @@ -347,14 +344,15 @@ def get_association_data(self) -> LocationAssociationData: return getattr(self, "_association_data", LocationAssociationData()) @classmethod - def get_or_create_from_object(cls: T, location: T) -> T: + def get_or_create_from_object(cls, location: Self) -> Self: """Given an object of this type, this method should get/create the object and return it.""" msg = "Subclasses must implement get_or_create_from_object" raise NotImplementedError(msg) @classmethod - def bulk_get_or_create(cls: type[T], locations: Iterable[T]) -> list[T]: - """Get or create multiple locations in bulk. + def bulk_get_or_create(cls, locations: Iterable[Self]) -> list[Self]: + """ + Get or create multiple locations in bulk. For each location, looks up by identity_hash. Creates missing ones using bulk_create for both the parent Location rows and the subtype rows. 
@@ -371,7 +369,7 @@ def bulk_get_or_create(cls: type[T], locations: Iterable[T]) -> list[T]: # Sanity check the given locations list is homogenous if not isinstance(loc, cls): error_message = f"Invalid location type; expected {cls} but got {type(loc)}" - raise ValueError(error_message) + raise TypeError(error_message) # Set .identity_hash if not present if not loc.identity_hash: loc.clean() @@ -412,7 +410,7 @@ def bulk_get_or_create(cls: type[T], locations: Iterable[T]) -> list[T]: Location.objects.bulk_create(parents, batch_size=1000) # Assign Location FKs to the subtypes, then bulk create them. - for loc, parent in zip(new_locations, parents): + for loc, parent in zip(new_locations, parents, strict=True): loc.location_id = parent.id loc.location = parent # Note there is a subtle race condition here, if somehow one of our newly-created locations conflicts diff --git a/unittests/test_bulk_locations.py b/unittests/test_bulk_locations.py index 0a57661a1b1..20a6ff90d41 100644 --- a/unittests/test_bulk_locations.py +++ b/unittests/test_bulk_locations.py @@ -1,4 +1,5 @@ -"""Tests for bulk location creation and association (open-source, URL-only). +""" +Tests for bulk location creation and association (open-source, URL-only). 
Covers: - AbstractLocation.bulk_get_or_create (on URL) @@ -10,16 +11,19 @@ from unittest.mock import patch +from django.contrib.auth import get_user_model from django.db import connection from django.test.utils import CaptureQueriesContext from dojo.importers.location_manager import LocationManager from dojo.location.models import Location, LocationFindingReference, LocationProductReference -from dojo.models import Product, Product_Type, Test_Type +from dojo.models import Engagement, Finding, Product, Product_Type, Test, Test_Type from dojo.tools.locations import LocationAssociationData, LocationData from dojo.url.models import URL from unittests.dojo_test_case import DojoTestCase, skip_unless_v3 +User = get_user_model() + def _make_url(host, path=""): url = URL(protocol="https", host=host, path=path) @@ -28,10 +32,6 @@ def _make_url(host, path=""): def _make_finding(): - from django.contrib.auth import get_user_model - from dojo.models import Engagement, Finding, Test - - User = get_user_model() user, _ = User.objects.get_or_create(username="bulk_test_user", defaults={"is_active": True}) pt, _ = Product_Type.objects.get_or_create(name="Bulk Test Type") product = Product.objects.create(name="Bulk Test Product", description="test", prod_type=pt) @@ -126,9 +126,9 @@ def test_transaction_atomicity(self): initial_count = Location.objects.count() urls = [_make_url("oss-atomic.example.com")] - with patch.object(URL.objects, "bulk_create", side_effect=Exception("boom")): - with self.assertRaisesMessage(Exception, "boom"): - URL.bulk_get_or_create(urls) + with patch.object(URL.objects, "bulk_create", side_effect=Exception("boom")), \ + self.assertRaisesMessage(Exception, "boom"): + URL.bulk_get_or_create(urls) self.assertEqual(Location.objects.count(), initial_count) From 70c8449213ce46bb77a29d951fa346b2bcd498fa Mon Sep 17 00:00:00 2001 From: dogboat Date: Tue, 14 Apr 2026 13:49:48 -0400 Subject: [PATCH 06/47] comments --- dojo/importers/location_manager.py | 3 --- 1 file 
changed, 3 deletions(-) diff --git a/dojo/importers/location_manager.py b/dojo/importers/location_manager.py index a48e39b728f..82d210a7364 100644 --- a/dojo/importers/location_manager.py +++ b/dojo/importers/location_manager.py @@ -82,7 +82,6 @@ def type_id(x: tuple[int, AbstractLocation]) -> int: indices, grouped_locations = zip(*grouped_locations_with_idx, strict=True) # Determine the correct AbstractLocation class to use for bulk get/create loc_cls = type(grouped_locations[0]) - # `.bulk_get_or_create` is expected to return the saved items in the order they were submitted saved_locations = loc_cls.bulk_get_or_create(grouped_locations) # Zip 'em back together: associate the saved instance with its original index in the `locations` list saved.extend((idx, saved_loc) for idx, saved_loc in zip(indices, saved_locations, strict=True)) @@ -139,8 +138,6 @@ def bulk_create_refs( new_finding_refs = [] new_product_refs = [] - # Process locations (unsaved, with possible association data) alongside their corresponding saved versions, - # which do not contain that information. We can do this because the bulk get/create operations are stable. 
for location in locations: assoc = location.get_association_data() From b193222cccae65a308799f8122324317094f72e2 Mon Sep 17 00:00:00 2001 From: dogboat Date: Tue, 14 Apr 2026 14:50:27 -0400 Subject: [PATCH 07/47] updates --- dojo/importers/location_manager.py | 9 --------- dojo/location/models.py | 2 +- unittests/test_bulk_locations.py | 11 +++++++++-- 3 files changed, 10 insertions(+), 12 deletions(-) diff --git a/dojo/importers/location_manager.py b/dojo/importers/location_manager.py index 82d210a7364..37801d8bb2a 100644 --- a/dojo/importers/location_manager.py +++ b/dojo/importers/location_manager.py @@ -28,16 +28,7 @@ UnsavedLocation = TypeVar("UnsavedLocation", LocationData, AbstractLocation) -# test_notifications.py: Implement Locations class LocationManager: - @classmethod - def get_or_create_location(cls, unsaved_location: AbstractLocation) -> AbstractLocation | None: - """Gets/creates the given AbstractLocation.""" - if isinstance(unsaved_location, URL): - return URL.get_or_create_from_object(unsaved_location) - logger.debug(f"IMPORT_SCAN: Unsupported location type: {type(unsaved_location)}") - return None - @classmethod def get_supported_location_types(cls) -> dict[str, type[AbstractLocation]]: """Return a mapping of location type string to AbstractLocation subclass.""" diff --git a/dojo/location/models.py b/dojo/location/models.py index 112d2c8b644..fa840c23a01 100644 --- a/dojo/location/models.py +++ b/dojo/location/models.py @@ -395,7 +395,7 @@ def bulk_get_or_create(cls, locations: Iterable[Self]) -> list[Self]: if hasattr(loc, "_association_data") and not hasattr(saved, "_association_data"): saved._association_data = loc._association_data - # Create 'em + # Create 'em if new_locations: location_type = cls.get_location_type() with transaction.atomic(): diff --git a/unittests/test_bulk_locations.py b/unittests/test_bulk_locations.py index 20a6ff90d41..fc8eb651778 100644 --- a/unittests/test_bulk_locations.py +++ b/unittests/test_bulk_locations.py @@ 
-14,6 +14,7 @@ from django.contrib.auth import get_user_model from django.db import connection from django.test.utils import CaptureQueriesContext +from django.utils import timezone from dojo.importers.location_manager import LocationManager from dojo.location.models import Location, LocationFindingReference, LocationProductReference @@ -32,12 +33,13 @@ def _make_url(host, path=""): def _make_finding(): + now = timezone.now() user, _ = User.objects.get_or_create(username="bulk_test_user", defaults={"is_active": True}) pt, _ = Product_Type.objects.get_or_create(name="Bulk Test Type") product = Product.objects.create(name="Bulk Test Product", description="test", prod_type=pt) - eng = Engagement.objects.create(product=product, target_start="2026-01-01", target_end="2026-12-31") + eng = Engagement.objects.create(product=product, target_start=now, target_end=now) tt, _ = Test_Type.objects.get_or_create(name="Bulk Test") - test = Test.objects.create(engagement=eng, test_type=tt, target_start="2026-01-01", target_end="2026-12-31") + test = Test.objects.create(engagement=eng, test_type=tt, target_start=now, target_end=now) return Finding.objects.create(test=test, title="Bulk Test Finding", severity="Medium", reporter=user) @@ -139,6 +141,11 @@ def test_transaction_atomicity(self): @skip_unless_v3 class TestBulkGetOrCreateLocations(DojoTestCase): + def test_supported_location_types_includes_url(self): + supported = LocationManager.get_supported_location_types() + self.assertIn("url", supported) + self.assertIs(supported["url"], URL) + def test_url_only(self): urls = [_make_url("oss-loc-mgr.example.com")] saved = LocationManager.bulk_get_or_create_locations(urls) From cc5cf134d1b4ee85224f5e9e86aefb27826e78b3 Mon Sep 17 00:00:00 2001 From: dogboat Date: Tue, 14 Apr 2026 15:21:30 -0400 Subject: [PATCH 08/47] remove celery stuff --- dojo/importers/base_importer.py | 4 +- dojo/importers/default_reimporter.py | 6 +-- dojo/importers/location_manager.py | 66 
+++++----------------------- unittests/test_bulk_locations.py | 12 ++--- 4 files changed, 22 insertions(+), 66 deletions(-) diff --git a/dojo/importers/base_importer.py b/dojo/importers/base_importer.py index c149f4e169d..7d1af7482e4 100644 --- a/dojo/importers/base_importer.py +++ b/dojo/importers/base_importer.py @@ -806,11 +806,11 @@ def process_locations( finding and product """ # Save the unsaved locations - self.location_manager.chunk_locations_and_disperse(finding, finding.unsaved_locations) + self.location_manager.add_locations_to_finding(finding, finding.unsaved_locations) # Check for any that were added in the form if len(locations_to_add) > 0: logger.debug("locations_to_add: %s", locations_to_add) - self.location_manager.chunk_locations_and_disperse(finding, locations_to_add) + self.location_manager.add_locations_to_finding(finding, locations_to_add) # TODO: Delete this after the move to Locations def process_endpoints( diff --git a/dojo/importers/default_reimporter.py b/dojo/importers/default_reimporter.py index 2a22da10a35..da4c48450d4 100644 --- a/dojo/importers/default_reimporter.py +++ b/dojo/importers/default_reimporter.py @@ -821,7 +821,7 @@ def process_matched_mitigated_finding( if settings.V3_FEATURE_LOCATIONS: # Reactivate mitigated locations mitigated_locations = existing_finding.locations.filter(status=FindingLocationStatus.Mitigated) - self.location_manager.chunk_locations_and_reactivate(mitigated_locations) + self.location_manager.reactivate_location_status(mitigated_locations) else: # TODO: Delete this after the move to Locations # Accumulate endpoint statuses for bulk reactivation in persist() @@ -992,9 +992,9 @@ def finding_post_processing( for the purpose of foreign key restrictions """ if settings.V3_FEATURE_LOCATIONS: - self.location_manager.chunk_locations_and_disperse(finding, finding_from_report.unsaved_locations) + self.location_manager.add_locations_to_finding(finding, finding_from_report.unsaved_locations) if 
len(self.endpoints_to_add) > 0: - self.location_manager.chunk_locations_and_disperse(finding, self.endpoints_to_add) + self.location_manager.add_locations_to_finding(finding, self.endpoints_to_add) else: # TODO: Delete this after the move to Locations for endpoint in finding_from_report.unsaved_endpoints: diff --git a/dojo/importers/location_manager.py b/dojo/importers/location_manager.py index 37801d8bb2a..d7af2e58605 100644 --- a/dojo/importers/location_manager.py +++ b/dojo/importers/location_manager.py @@ -8,8 +8,6 @@ from django.core.exceptions import ValidationError from django.utils import timezone -from dojo.celery import app -from dojo.celery_dispatch import dojo_dispatch_task from dojo.location.models import AbstractLocation, LocationFindingReference, LocationProductReference from dojo.location.status import FindingLocationStatus, ProductLocationStatus from dojo.tools.locations import LocationData @@ -162,49 +160,33 @@ def bulk_create_refs( ) @classmethod - def _add_locations_to_unsaved_finding( + def add_locations_to_finding( cls, finding: Finding, locations: list[UnsavedLocation], - **kwargs: dict, # noqa: ARG003 ) -> None: """Creates AbstractLocation objects from the given list and links them to the given Finding and its Product.""" locations = cls.bulk_get_or_create_locations(locations) cls.bulk_create_refs(locations, finding=finding) logger.debug(f"LocationManager: {len(locations)} locations associated with {finding}") - @app.task - def add_locations_to_unsaved_finding( - manager_cls_path: str, # noqa: N805 - finding: Finding, - locations: list[UnsavedLocation], - **kwargs: dict, - ) -> None: - """Celery task that resolves the LocationManager class and delegates to _add_locations_to_unsaved_finding.""" - from django.utils.module_loading import import_string # noqa: PLC0415 - - manager_cls = import_string(manager_cls_path) - manager_cls._add_locations_to_unsaved_finding(finding, locations, **kwargs) - - @app.task + @staticmethod def 
mitigate_location_status( - location_refs: QuerySet[LocationFindingReference], # noqa: N805 + location_refs: QuerySet[LocationFindingReference], user: Dojo_User, - **kwargs: dict, ) -> None: - """Mitigate all given (non-mitigated) location refs""" + """Mitigate all given (non-mitigated) location refs.""" location_refs.exclude(status=FindingLocationStatus.Mitigated).update( auditor=user, audit_time=timezone.now(), status=FindingLocationStatus.Mitigated, ) - @app.task + @staticmethod def reactivate_location_status( - location_refs: QuerySet[LocationFindingReference], # noqa: N805 - **kwargs: dict, + location_refs: QuerySet[LocationFindingReference], ) -> None: - """Reactivate all given (mitigated) locations refs""" + """Reactivate all given (mitigated) location refs.""" location_refs.filter(status=FindingLocationStatus.Mitigated).update( auditor=None, audit_time=timezone.now(), @@ -235,7 +217,7 @@ def update_location_status( user: Dojo_User, **kwargs: dict, ) -> None: - """Update the list of locations from the new finding with the list that is in the old finding""" + """Update the list of locations from the new finding with the list that is in the old finding.""" # New endpoints are already added in serializers.py / views.py (see comment "# for existing findings: make sure endpoints are present or created") # So we only need to mitigate endpoints that are no longer present existing_location_refs: QuerySet[LocationFindingReference] = existing_finding.locations.exclude( @@ -247,7 +229,7 @@ def update_location_status( ) if new_finding.is_mitigated: # New finding is mitigated, so mitigate all existing location refs - self.chunk_locations_and_mitigate(existing_location_refs, user) + self.mitigate_location_status(existing_location_refs, user) else: new_locations_values = [str(location) for location in type(self).clean_unsaved_locations(new_finding.unsaved_locations)] # Reactivate endpoints in the old finding that are in the new finding @@ -255,31 +237,5 @@ def 
update_location_status( # Mitigate endpoints in the existing finding not in the new finding location_refs_to_mitigate = existing_location_refs.exclude(location__location_value__in=new_locations_values) - self.chunk_locations_and_reactivate(location_refs_to_reactivate) - self.chunk_locations_and_mitigate(location_refs_to_mitigate, user) - - def chunk_locations_and_disperse( - self, - finding: Finding, - locations: list[UnsavedLocation], - **kwargs: dict, - ) -> None: - if not locations: - return - cls_path = f"{type(self).__module__}.{type(self).__qualname__}" - dojo_dispatch_task(self.add_locations_to_unsaved_finding, cls_path, finding, locations, sync=True) - - def chunk_locations_and_reactivate( - self, - location_refs: QuerySet[LocationFindingReference], - **kwargs: dict, - ) -> None: - dojo_dispatch_task(self.reactivate_location_status, location_refs, sync=True) - - def chunk_locations_and_mitigate( - self, - location_refs: QuerySet[LocationFindingReference], - user: Dojo_User, - **kwargs: dict, - ) -> None: - dojo_dispatch_task(self.mitigate_location_status, location_refs, user, sync=True) + self.reactivate_location_status(location_refs_to_reactivate) + self.mitigate_location_status(location_refs_to_mitigate, user) diff --git a/unittests/test_bulk_locations.py b/unittests/test_bulk_locations.py index fc8eb651778..5a4a21e4ad8 100644 --- a/unittests/test_bulk_locations.py +++ b/unittests/test_bulk_locations.py @@ -5,7 +5,7 @@ - AbstractLocation.bulk_get_or_create (on URL) - LocationManager.bulk_get_or_create_locations (URL-only) - LocationManager.bulk_create_refs (finding + product refs) -- LocationManager._add_locations_to_unsaved_finding (end-to-end) +- LocationManager.add_locations_to_finding (end-to-end) - Query efficiency """ @@ -247,7 +247,7 @@ def test_finding_implies_product(self): # --------------------------------------------------------------------------- -# End-to-end: _add_locations_to_unsaved_finding +# End-to-end: add_locations_to_finding # 
--------------------------------------------------------------------------- @skip_unless_v3 class TestAddLocationsToUnsavedFinding(DojoTestCase): @@ -261,22 +261,22 @@ def test_full_pipeline(self): LocationData(type="url", data={"url": "https://oss-e2e-2.example.com/api"}), ] - LocationManager._add_locations_to_unsaved_finding(finding, loc_data) + LocationManager.add_locations_to_finding(finding, loc_data) self.assertEqual(LocationFindingReference.objects.filter(finding=finding).count(), 2) self.assertEqual(LocationProductReference.objects.filter(product=product).count(), 2) def test_empty_locations(self): finding = _make_finding() - LocationManager._add_locations_to_unsaved_finding(finding, []) + LocationManager.add_locations_to_finding(finding, []) self.assertEqual(LocationFindingReference.objects.filter(finding=finding).count(), 0) def test_idempotent(self): finding = _make_finding() loc_data = [LocationData(type="url", data={"url": "https://oss-idempotent.example.com"})] - LocationManager._add_locations_to_unsaved_finding(finding, loc_data) - LocationManager._add_locations_to_unsaved_finding(finding, loc_data) + LocationManager.add_locations_to_finding(finding, loc_data) + LocationManager.add_locations_to_finding(finding, loc_data) self.assertEqual(LocationFindingReference.objects.filter(finding=finding).count(), 1) From 10748a96403b5cf92acd1994c8f55b583321d30b Mon Sep 17 00:00:00 2001 From: dogboat Date: Tue, 14 Apr 2026 16:52:48 -0400 Subject: [PATCH 09/47] refactor --- dojo/importers/base_importer.py | 103 +------ dojo/importers/default_importer.py | 27 +- dojo/importers/default_reimporter.py | 83 ++---- dojo/importers/endpoint_manager.py | 42 +++ dojo/importers/location_manager.py | 420 ++++++++++++++++----------- unittests/test_bulk_locations.py | 145 +++++---- 6 files changed, 408 insertions(+), 412 deletions(-) diff --git a/dojo/importers/base_importer.py b/dojo/importers/base_importer.py index 7d1af7482e4..8f504d862df 100644 --- 
a/dojo/importers/base_importer.py +++ b/dojo/importers/base_importer.py @@ -13,8 +13,6 @@ import dojo.finding.helper as finding_helper import dojo.risk_acceptance.helper as ra_helper -from dojo.celery_dispatch import dojo_dispatch_task -from dojo.importers.location_manager import LocationManager, UnsavedLocation from dojo.importers.options import ImporterOptions from dojo.jira_link.helper import is_keep_in_sync_with_jira from dojo.location.models import Location @@ -80,8 +78,6 @@ def __init__( and will raise a `NotImplemented` exception """ ImporterOptions.__init__(self, *args, **kwargs) - if settings.V3_FEATURE_LOCATIONS: - self.location_manager = LocationManager() def check_child_implementation_exception(self): """ @@ -391,36 +387,20 @@ def apply_import_tags( for tag in self.tags: self.add_tags_safe(finding, tag) - if settings.V3_FEATURE_LOCATIONS: - # Add any tags to any locations of the findings imported if necessary - if self.apply_tags_to_endpoints and self.tags: - # Collect all endpoints linked to the affected findings - locations_qs = Location.objects.filter(findings__finding__in=findings_to_tag).distinct() - try: - bulk_add_tags_to_instances( - tag_or_tags=self.tags, - instances=locations_qs, - tag_field_name="tags", - ) - except IntegrityError: - for finding in findings_to_tag: - for location in finding.locations.all(): - for tag in self.tags: - self.add_tags_safe(location.location, tag) - # Add any tags to any endpoints of the findings imported if necessary - elif self.apply_tags_to_endpoints and self.tags: - endpoints_qs = Endpoint.objects.filter(finding__in=findings_to_tag).distinct() + # Add any tags to any locations/endpoints of the findings imported if necessary + if self.apply_tags_to_endpoints and self.tags: + items_qs = self.item_manager.get_items_for_tagging(findings_to_tag) try: bulk_add_tags_to_instances( tag_or_tags=self.tags, - instances=endpoints_qs, + instances=items_qs, tag_field_name="tags", ) except IntegrityError: for finding in 
findings_to_tag: - for endpoint in finding.endpoints.all(): + for item in self.item_manager.get_item_tag_fallback(finding): for tag in self.tags: - self.add_tags_safe(endpoint, tag) + self.add_tags_safe(item, tag) def update_import_history( self, @@ -467,14 +447,8 @@ def update_import_history( import_settings["apply_tags_to_endpoints"] = self.apply_tags_to_endpoints import_settings["group_by"] = self.group_by import_settings["create_finding_groups_for_all_findings"] = self.create_finding_groups_for_all_findings - if settings.V3_FEATURE_LOCATIONS: - # Add the list of locations that were added exclusively at import time - if len(self.endpoints_to_add) > 0: - import_settings["locations"] = [str(location) for location in self.endpoints_to_add] - # TODO: Delete this after the move to Locations - # Add the list of endpoints that were added exclusively at import time - elif len(self.endpoints_to_add) > 0: - import_settings["endpoints"] = [str(endpoint) for endpoint in self.endpoints_to_add] + if len(self.endpoints_to_add) > 0: + import_settings.update(self.item_manager.serialize_extra_items(self.endpoints_to_add)) # Create the test import object test_import = Test_Import.objects.create( test=self.test, @@ -793,53 +767,16 @@ def process_request_response_pairs( burp_rr.clean() burp_rr.save() - def process_locations( + def process_items( self, finding: Finding, - locations_to_add: list[UnsavedLocation], + extra_items_to_add: list | None = None, ) -> None: """ - Process any locations to add to the finding. 
Locations could come from two places - - Directly from the report - - Supplied by the user from the import form - These locations will be processed in to Location objects and associated with the - finding and product - """ - # Save the unsaved locations - self.location_manager.add_locations_to_finding(finding, finding.unsaved_locations) - # Check for any that were added in the form - if len(locations_to_add) > 0: - logger.debug("locations_to_add: %s", locations_to_add) - self.location_manager.add_locations_to_finding(finding, locations_to_add) - - # TODO: Delete this after the move to Locations - def process_endpoints( - self, - finding: Finding, - endpoints_to_add: list[Endpoint], - ) -> None: + Record locations/endpoints from the finding + any form-added extras. + Flushed to DB by item_manager.persist(). """ - Process any endpoints to add to the finding. Endpoints could come from two places - - Directly from the report - - Supplied by the user from the import form - These endpoints will be processed in to endpoints objects and associated with the - finding and and product - """ - if settings.V3_FEATURE_LOCATIONS: - msg = "BaseImporter#process_endpoints() method is deprecated when V3_FEATURE_LOCATIONS is enabled" - raise NotImplementedError(msg) - - # Clean and record unsaved endpoints from the report - self.endpoint_manager.clean_unsaved_endpoints(finding.unsaved_endpoints) - for endpoint in finding.unsaved_endpoints: - key = self.endpoint_manager.record_endpoint(endpoint) - self.endpoint_manager.record_status_for_create(finding, key) - # Record any endpoints added from the form - if len(endpoints_to_add) > 0: - logger.debug("endpoints_to_add: %s", endpoints_to_add) - for endpoint in endpoints_to_add: - key = self.endpoint_manager.record_endpoint(endpoint) - self.endpoint_manager.record_status_for_create(finding, key) + self.item_manager.record_for_finding(finding, extra_items_to_add) def sanitize_vulnerability_ids(self, finding) -> None: """Remove undisired 
vulnerability id values""" @@ -932,19 +869,7 @@ def mitigate_finding( # Remove risk acceptance if present (vulnerability is now fixed) # risk_unaccept will check if finding.risk_accepted is True before proceeding ra_helper.risk_unaccept(self.user, finding, perform_save=False, post_comments=False) - if settings.V3_FEATURE_LOCATIONS: - # Mitigate the location statuses - dojo_dispatch_task( - LocationManager.mitigate_location_status, - finding.locations.all(), - self.user, - kwuser=self.user, - sync=True, - ) - else: - # TODO: Delete this after the move to Locations - # Accumulate endpoint statuses for bulk mitigate in persist() - self.endpoint_manager.record_statuses_to_mitigate(finding.status_finding.all()) + self.item_manager.record_mitigations_for_finding(finding, self.user) # to avoid pushing a finding group multiple times, we push those outside of the loop if finding_groups_enabled and finding.finding_group: # don't try to dedupe findings that we are closing diff --git a/dojo/importers/default_importer.py b/dojo/importers/default_importer.py index e8bd56baf55..8a7e3e0344a 100644 --- a/dojo/importers/default_importer.py +++ b/dojo/importers/default_importer.py @@ -10,6 +10,7 @@ from dojo.finding import helper as finding_helper from dojo.importers.base_importer import BaseImporter, Parser from dojo.importers.endpoint_manager import EndpointManager +from dojo.importers.location_manager import LocationManager from dojo.importers.options import ImporterOptions from dojo.jira_link.helper import is_keep_in_sync_with_jira from dojo.models import ( @@ -58,8 +59,10 @@ def __init__(self, *args, **kwargs): import_type=Test_Import.IMPORT_TYPE, **kwargs, ) - if not settings.V3_FEATURE_LOCATIONS: - self.endpoint_manager = EndpointManager(self.engagement.product) + if settings.V3_FEATURE_LOCATIONS: + self.item_manager = LocationManager(self.engagement.product) + else: + self.item_manager = EndpointManager(self.engagement.product) def create_test( self, @@ -240,13 +243,7 @@ def 
process_findings( ) # Process any request/response pairs self.process_request_response_pairs(finding) - if settings.V3_FEATURE_LOCATIONS: - # Process any locations on the finding, or added on the form - self.process_locations(finding, self.endpoints_to_add) - else: - # TODO: Delete this after the move to Locations - # Process any endpoints on the finding, or added on the form - self.process_endpoints(finding, self.endpoints_to_add) + self.process_items(finding, self.endpoints_to_add) # Parsers must use unsaved_tags to store tags, so we can clean them. # Accumulate for bulk application after the loop (O(unique_tags) instead of O(N·T)). cleaned_tags = clean_tags(finding.unsaved_tags) @@ -267,16 +264,13 @@ def process_findings( logger.debug("process_findings: computed push_to_jira=%s", push_to_jira) batch_finding_ids.append(finding.id) - # If batch is full or we're at the end, persist endpoints and dispatch + # If batch is full or we're at the end, persist locations/endpoints and dispatch if len(batch_finding_ids) >= batch_max_size or is_final_finding: - if not settings.V3_FEATURE_LOCATIONS: - self.endpoint_manager.persist(user=self.user) - + self.item_manager.persist(user=self.user) # Apply parser-supplied tags for this batch before post-processing starts, # so rules/deduplication tasks see the tags already on the findings. 
bulk_apply_parser_tags(findings_with_parser_tags) findings_with_parser_tags.clear() - finding_ids_batch = list(batch_finding_ids) batch_finding_ids.clear() logger.debug("process_findings: dispatching batch with push_to_jira=%s (batch_size=%d, is_final=%s)", @@ -404,9 +398,8 @@ def close_old_findings( finding_groups_enabled=self.findings_groups_enabled, product_grading_option=False, ) - # Persist any accumulated endpoint status mitigations - if not settings.V3_FEATURE_LOCATIONS: - self.endpoint_manager.persist(user=self.user) + # Persist any accumulated location/endpoint status changes + self.item_manager.persist(user=self.user) # push finding groups to jira since we only only want to push whole groups # We dont check if the finding jira sync is applicable quite yet until we can get in the loop # but this is a way to at least make it that far diff --git a/dojo/importers/default_reimporter.py b/dojo/importers/default_reimporter.py index da4c48450d4..187596985b8 100644 --- a/dojo/importers/default_reimporter.py +++ b/dojo/importers/default_reimporter.py @@ -16,9 +16,9 @@ ) from dojo.importers.base_importer import BaseImporter, Parser from dojo.importers.endpoint_manager import EndpointManager +from dojo.importers.location_manager import LocationManager from dojo.importers.options import ImporterOptions from dojo.jira_link.helper import is_keep_in_sync_with_jira -from dojo.location.status import FindingLocationStatus from dojo.models import ( Development_Environment, Finding, @@ -82,8 +82,10 @@ def __init__(self, *args, **kwargs): import_type=Test_Import.REIMPORT_TYPE, **kwargs, ) - if not settings.V3_FEATURE_LOCATIONS: - self.endpoint_manager = EndpointManager(self.test.engagement.product) + if settings.V3_FEATURE_LOCATIONS: + self.item_manager = LocationManager(self.test.engagement.product) + else: + self.item_manager = EndpointManager(self.test.engagement.product) def process_scan( self, @@ -338,13 +340,7 @@ def process_findings( # Set the service supplied at import 
time if self.service is not None: unsaved_finding.service = self.service - if settings.V3_FEATURE_LOCATIONS: - # Clean any locations that are on the finding - self.location_manager.clean_unsaved_locations(unsaved_finding.unsaved_locations) - else: - # TODO: Delete this after the move to Locations - # Clean any endpoints that are on the finding - self.endpoint_manager.clean_unsaved_endpoints(unsaved_finding.unsaved_endpoints) + self.item_manager.clean_unsaved(unsaved_finding) # Calculate the hash code to be used to identify duplicates unsaved_finding.hash_code = self.calculate_unsaved_finding_hash_code(unsaved_finding) deduplicationLogger.debug(f"unsaved finding's hash_code: {unsaved_finding.hash_code}") @@ -380,27 +376,15 @@ def process_findings( continue # Update endpoints on the existing finding with those on the new finding if finding.dynamic_finding: - if settings.V3_FEATURE_LOCATIONS: - logger.debug( - "Re-import found an existing dynamic finding for this new " - "finding. Checking the status of locations", - ) - self.location_manager.update_location_status( - existing_finding, - unsaved_finding, - self.user, - ) - else: - # TODO: Delete this after the move to Locations - logger.debug( - "Re-import found an existing dynamic finding for this new " - "finding. Checking the status of endpoints", - ) - self.endpoint_manager.update_endpoint_status( - existing_finding, - unsaved_finding, - self.user, - ) + logger.debug( + "Re-import found an existing dynamic finding for this new " + "finding. Checking the status of locations/endpoints", + ) + self.item_manager.update_status( + existing_finding, + unsaved_finding, + self.user, + ) else: finding, finding_will_be_grouped = self.process_finding_that_was_not_matched(unsaved_finding) @@ -441,9 +425,7 @@ def process_findings( # - Deduplication batches: optimize bulk operations (larger batches = fewer queries) # They don't need to be aligned since they optimize different operations. 
if len(batch_finding_ids) >= dedupe_batch_max_size or is_final: - if not settings.V3_FEATURE_LOCATIONS: - self.endpoint_manager.persist(user=self.user) - + self.item_manager.persist(user=self.user) # Apply parser-supplied tags for this batch before post-processing starts, # so rules/deduplication tasks see the tags already on the findings. bulk_apply_parser_tags(findings_with_parser_tags) @@ -555,9 +537,8 @@ def close_old_findings( product_grading_option=False, ) mitigated_findings.append(finding) - # Persist any accumulated endpoint status mitigations - if not settings.V3_FEATURE_LOCATIONS: - self.endpoint_manager.persist(user=self.user) + # Persist any accumulated location/endpoint status changes + self.item_manager.persist(user=self.user) # push finding groups to jira since we only only want to push whole groups # We dont check if the finding jira sync is applicable quite yet until we can get in the loop # but this is a way to at least make it that far @@ -818,16 +799,7 @@ def process_matched_mitigated_finding( note = Notes(entry=f"Re-activated by {self.scan_type} re-upload.", author=self.user) note.save() - if settings.V3_FEATURE_LOCATIONS: - # Reactivate mitigated locations - mitigated_locations = existing_finding.locations.filter(status=FindingLocationStatus.Mitigated) - self.location_manager.reactivate_location_status(mitigated_locations) - else: - # TODO: Delete this after the move to Locations - # Accumulate endpoint statuses for bulk reactivation in persist() - self.endpoint_manager.record_statuses_to_reactivate( - self.endpoint_manager.get_non_special_endpoint_statuses(existing_finding), - ) + self.item_manager.record_reactivations_for_finding(existing_finding) existing_finding.notes.add(note) self.reactivated_items.append(existing_finding) # The new finding is active while the existing on is mitigated. 
The existing finding needs to @@ -991,19 +963,10 @@ def finding_post_processing( Save all associated objects to the finding after it has been saved for the purpose of foreign key restrictions """ - if settings.V3_FEATURE_LOCATIONS: - self.location_manager.add_locations_to_finding(finding, finding_from_report.unsaved_locations) - if len(self.endpoints_to_add) > 0: - self.location_manager.add_locations_to_finding(finding, self.endpoints_to_add) - else: - # TODO: Delete this after the move to Locations - for endpoint in finding_from_report.unsaved_endpoints: - key = self.endpoint_manager.record_endpoint(endpoint) - self.endpoint_manager.record_status_for_create(finding, key) - if len(self.endpoints_to_add) > 0: - for endpoint in self.endpoints_to_add: - key = self.endpoint_manager.record_endpoint(endpoint) - self.endpoint_manager.record_status_for_create(finding, key) + # Copy unsaved items from the parser output onto the saved finding so record_for_finding can read them + finding.unsaved_locations = getattr(finding_from_report, "unsaved_locations", []) + finding.unsaved_endpoints = getattr(finding_from_report, "unsaved_endpoints", []) + self.item_manager.record_for_finding(finding, self.endpoints_to_add or None) # For matched/existing findings, do not update tags from the report, # consistent with how other fields are handled on reimport. 
if not is_matched_finding: diff --git a/dojo/importers/endpoint_manager.py b/dojo/importers/endpoint_manager.py index c909f921201..e6657a42ba5 100644 --- a/dojo/importers/endpoint_manager.py +++ b/dojo/importers/endpoint_manager.py @@ -300,3 +300,45 @@ def persist(self, user: Dojo_User | None = None) -> None: batch_size=1000, ) self._statuses_to_reactivate.clear() + + # ------------------------------------------------------------------ + # Unified interface (shared with LocationManager) + # ------------------------------------------------------------------ + + def clean_unsaved(self, finding: Finding) -> None: + """Clean the unsaved endpoints on this finding.""" + self.clean_unsaved_endpoints(finding.unsaved_endpoints) + + def record_for_finding(self, finding: Finding, extra_items: list[Endpoint] | None = None) -> None: + """Record endpoints from the finding + any form-added extras for later batch creation.""" + for endpoint in finding.unsaved_endpoints: + key = self.record_endpoint(endpoint) + self.record_status_for_create(finding, key) + if extra_items: + for endpoint in extra_items: + key = self.record_endpoint(endpoint) + self.record_status_for_create(finding, key) + + def update_status(self, existing_finding: Finding, new_finding: Finding, user: Dojo_User) -> None: + """Accumulate status changes (mitigate/reactivate) based on old vs new finding.""" + self.update_endpoint_status(existing_finding, new_finding, user) + + def record_reactivations_for_finding(self, finding: Finding) -> None: + """Record endpoint statuses on this finding for reactivation.""" + self.record_statuses_to_reactivate(self.get_non_special_endpoint_statuses(finding)) + + def record_mitigations_for_finding(self, finding: Finding, user: Dojo_User | None = None) -> None: + """Record endpoint statuses on this finding for mitigation.""" + self.record_statuses_to_mitigate(finding.status_finding.all()) + + def get_items_for_tagging(self, findings: list[Finding]): + """Return queryset of items to 
apply tags to.""" + return Endpoint.objects.filter(finding__in=findings).distinct() + + def get_item_tag_fallback(self, finding: Finding): + """Return iterable of taggable items for per-instance fallback.""" + return finding.endpoints.all() + + def serialize_extra_items(self, items: list) -> dict: + """Serialize extra items for import history.""" + return {"endpoints": [str(ep) for ep in items]} if items else {} diff --git a/dojo/importers/location_manager.py b/dojo/importers/location_manager.py index d7af2e58605..d025f57ef42 100644 --- a/dojo/importers/location_manager.py +++ b/dojo/importers/location_manager.py @@ -27,11 +27,240 @@ class LocationManager: + + def __init__(self, product: Product) -> None: + self._product = product + self._locations_by_finding: dict[int, tuple[Finding, list[UnsavedLocation]]] = {} + self._product_locations: list[UnsavedLocation] = [] + self._refs_to_mitigate: list[tuple[QuerySet[LocationFindingReference], Dojo_User]] = [] + self._refs_to_reactivate: list[QuerySet[LocationFindingReference]] = [] + + # ------------------------------------------------------------------ + # Accumulation methods (no DB hits) + # ------------------------------------------------------------------ + + def record_locations_for_finding( + self, + finding: Finding, + locations: list[UnsavedLocation], + ) -> None: + """Record locations to be associated with a finding. 
Flushed by persist().""" + if locations: + self._locations_by_finding.setdefault(finding.id, (finding, []))[1].extend(locations) + + def update_location_status( + self, + existing_finding: Finding, + new_finding: Finding, + user: Dojo_User, + ) -> None: + """Accumulate mitigate/reactivate operations for persist().""" + existing_location_refs: QuerySet[LocationFindingReference] = existing_finding.locations.exclude( + status__in=[ + FindingLocationStatus.FalsePositive, + FindingLocationStatus.RiskAccepted, + FindingLocationStatus.OutOfScope, + ], + ) + if new_finding.is_mitigated: + self._refs_to_mitigate.append((existing_location_refs, user)) + else: + new_locations_values = [ + str(location) for location in type(self).clean_unsaved_locations(new_finding.unsaved_locations) + ] + self._refs_to_reactivate.append( + existing_location_refs.filter(location__location_value__in=new_locations_values), + ) + self._refs_to_mitigate.append(( + existing_location_refs.exclude(location__location_value__in=new_locations_values), + user, + )) + + def record_reactivations(self, location_refs: QuerySet[LocationFindingReference]) -> None: + """Record location refs to reactivate. 
Flushed by persist().""" + self._refs_to_reactivate.append(location_refs) + + # ------------------------------------------------------------------ + # Unified interface (shared with EndpointManager) + # ------------------------------------------------------------------ + + def clean_unsaved(self, finding: Finding) -> None: + """Clean the unsaved locations on this finding.""" + type(self).clean_unsaved_locations(finding.unsaved_locations) + + def record_for_finding(self, finding: Finding, extra_items: list[UnsavedLocation] | None = None) -> None: + """Record locations from the finding + any form-added extras for later batch creation.""" + self.record_locations_for_finding(finding, finding.unsaved_locations) + if extra_items: + self.record_locations_for_finding(finding, extra_items) + + def update_status(self, existing_finding: Finding, new_finding: Finding, user: Dojo_User) -> None: + """Accumulate status changes (mitigate/reactivate) based on old vs new finding.""" + self.update_location_status(existing_finding, new_finding, user) + + def record_reactivations_for_finding(self, finding: Finding) -> None: + """Record mitigated location refs on this finding for reactivation.""" + mitigated = finding.locations.filter(status=FindingLocationStatus.Mitigated) + self._refs_to_reactivate.append(mitigated) + + def record_mitigations_for_finding(self, finding: Finding, user: Dojo_User | None = None) -> None: + """Record all location refs on this finding for mitigation.""" + self._refs_to_mitigate.append((finding.locations.all(), user)) + + def get_items_for_tagging(self, findings: list[Finding]): + """Return queryset of items to apply tags to.""" + from dojo.location.models import Location # noqa: PLC0415 + return Location.objects.filter(findings__finding__in=findings).distinct() + + def get_item_tag_fallback(self, finding: Finding): + """Return iterable of taggable items for per-instance fallback.""" + return [ref.location for ref in finding.locations.all()] + + def 
serialize_extra_items(self, items: list) -> dict: + """Serialize extra items for import history.""" + return {"locations": [str(loc) for loc in items]} if items else {} + + # ------------------------------------------------------------------ + # Persist — flush all accumulated operations to DB + # ------------------------------------------------------------------ + + def persist(self, user: Dojo_User | None = None) -> None: + """Flush all accumulated location operations to the database.""" + # Step 1: Collect all locations across all findings, bulk get/create, bulk create refs + if self._locations_by_finding: + all_locations: list[AbstractLocation] = [] + finding_ranges: list[tuple[Finding, int, int]] = [] + + for finding, locations in self._locations_by_finding.values(): + cleaned = type(self).clean_unsaved_locations(locations) + start = len(all_locations) + all_locations.extend(cleaned) + end = len(all_locations) + if start < end: + finding_ranges.append((finding, start, end)) + + if all_locations: + saved = type(self)._bulk_get_or_create_locations(all_locations) + + # Build all refs across all findings in one pass + all_finding_refs = [] + all_product_refs = [] + + # Pre-fetch existing product refs for this product across all locations + all_location_ids = [loc.location_id for loc in saved] + existing_product_refs = set( + LocationProductReference.objects.filter( + location_id__in=all_location_ids, + product=self._product, + ).values_list("location_id", flat=True), + ) + + for finding, start, end in finding_ranges: + finding_locations = saved[start:end] + finding_location_ids = [loc.location_id for loc in finding_locations] + + existing_finding_refs = set( + LocationFindingReference.objects.filter( + location_id__in=finding_location_ids, + finding=finding, + ).values_list("location_id", flat=True), + ) + + for location in finding_locations: + assoc = location.get_association_data() + + if location.location_id not in existing_finding_refs: + 
all_finding_refs.append(LocationFindingReference( + location_id=location.location_id, + finding=finding, + status=FindingLocationStatus.Active, + relationship=assoc.relationship_type, + relationship_data=assoc.relationship_data, + )) + existing_finding_refs.add(location.location_id) + + if location.location_id not in existing_product_refs: + all_product_refs.append(LocationProductReference( + location_id=location.location_id, + product=self._product, + status=ProductLocationStatus.Active, + relationship=assoc.relationship_type, + relationship_data=assoc.relationship_data, + )) + existing_product_refs.add(location.location_id) + + if all_finding_refs: + LocationFindingReference.objects.bulk_create( + all_finding_refs, batch_size=1000, ignore_conflicts=True, + ) + if all_product_refs: + LocationProductReference.objects.bulk_create( + all_product_refs, batch_size=1000, ignore_conflicts=True, + ) + + self._locations_by_finding.clear() + + # Step 1b: Product-level locations (not tied to a finding) + if self._product_locations: + cleaned = type(self).clean_unsaved_locations(self._product_locations) + if cleaned: + saved = type(self)._bulk_get_or_create_locations(cleaned) + location_ids = [loc.location_id for loc in saved] + existing = set( + LocationProductReference.objects.filter( + location_id__in=location_ids, + product=self._product, + ).values_list("location_id", flat=True), + ) + new_refs = [] + for location in saved: + if location.location_id not in existing: + assoc = location.get_association_data() + new_refs.append(LocationProductReference( + location_id=location.location_id, + product=self._product, + status=ProductLocationStatus.Active, + relationship=assoc.relationship_type, + relationship_data=assoc.relationship_data, + )) + existing.add(location.location_id) + if new_refs: + LocationProductReference.objects.bulk_create( + new_refs, batch_size=1000, ignore_conflicts=True, + ) + self._product_locations.clear() + + # Step 2: Mitigate accumulated refs + for 
refs, mitigate_user in self._refs_to_mitigate: + refs.exclude(status=FindingLocationStatus.Mitigated).update( + auditor=mitigate_user, + audit_time=timezone.now(), + status=FindingLocationStatus.Mitigated, + ) + self._refs_to_mitigate.clear() + + # Step 3: Reactivate accumulated refs + for refs in self._refs_to_reactivate: + refs.filter(status=FindingLocationStatus.Mitigated).update( + auditor=None, + audit_time=timezone.now(), + status=FindingLocationStatus.Active, + ) + self._refs_to_reactivate.clear() + + # ------------------------------------------------------------------ + # Type registry + # ------------------------------------------------------------------ + @classmethod def get_supported_location_types(cls) -> dict[str, type[AbstractLocation]]: """Return a mapping of location type string to AbstractLocation subclass.""" return {URL.get_location_type(): URL} + # ------------------------------------------------------------------ + # Cleaning / conversion utilities + # ------------------------------------------------------------------ + @classmethod def make_abstract_locations(cls, locations: list[UnsavedLocation]) -> list[AbstractLocation]: """Converts the list of unsaved locations (AbstractLocation/LocationData objects) to a list of AbstractLocations.""" @@ -52,190 +281,43 @@ def make_abstract_locations(cls, locations: list[UnsavedLocation]) -> list[Abstr return abstract_locations @classmethod - def bulk_get_or_create_locations(cls, locations: list[UnsavedLocation]) -> list[AbstractLocation]: + def clean_unsaved_locations( + cls, + locations: list[UnsavedLocation], + ) -> list[AbstractLocation]: + """ + Convert locations represented as LocationData dataclasses to the appropriate AbstractLocation type, then clean + them. For any endpoints that fail this validation process, log a message that broken locations are being stored. 
+ """ + locations = list(set(cls.make_abstract_locations(locations))) + for location in locations: + try: + location.clean() + except ValidationError as e: + logger.warning("DefectDojo is storing broken locations because cleaning wasn't successful: %s", e) + return locations + + # ------------------------------------------------------------------ + # Bulk internals + # ------------------------------------------------------------------ + + @classmethod + def _bulk_get_or_create_locations(cls, locations: list[AbstractLocation]) -> list[AbstractLocation]: """Bulk get-or-create a (possibly heterogeneous) list of AbstractLocations.""" - locations = cls.clean_unsaved_locations(locations) if not locations: return [] - # Util method for sorting/keying; returns the (Python) identity of the location entry's Type def type_id(x: tuple[int, AbstractLocation]) -> int: return id(type(x[1])) saved = [] - # Group by actual AbstractLocation subtype, tracking the original ordering (hence the `enumerate`) locations_with_idx = sorted(enumerate(locations), key=type_id) locations_by_type = groupby(locations_with_idx, key=type_id) for _, grouped_locations_with_idx in locations_by_type: - # Split into two lists: indices and homogenous location types indices, grouped_locations = zip(*grouped_locations_with_idx, strict=True) - # Determine the correct AbstractLocation class to use for bulk get/create loc_cls = type(grouped_locations[0]) saved_locations = loc_cls.bulk_get_or_create(grouped_locations) - # Zip 'em back together: associate the saved instance with its original index in the `locations` list saved.extend((idx, saved_loc) for idx, saved_loc in zip(indices, saved_locations, strict=True)) - # Sort by index to return in original order saved.sort(key=itemgetter(0)) return [loc for _, loc in saved] - - @classmethod - def bulk_create_refs( - cls, - locations: list[AbstractLocation], - *, - finding: Finding | None = None, - product: Product | None = None, - ) -> None: - """ - Bulk create 
LocationFindingReference and/or LocationProductReference rows. - - Iterates the unsaved/saved pairs once, building both finding and product - refs in a single pass. Skips refs that already exist in the DB. - """ - if not locations: - return - - if not finding and not product: - error_message = "One of 'finding' or 'product' must be provided." - raise ValueError(error_message) - - if finding: - # If associating with a finding, use its product regardless of whatever's set. Keeps in line with the - # original intended purpose: this is a bulk version of Location.(associate_with_finding|associate_with_product) - product = finding.test.engagement.product - - location_ids = [loc.location_id for loc in locations] - - # Pre-fetch existing refs to avoid duplicates - existing_finding_refs = set() - existing_product_refs = set() - if finding is not None: - existing_finding_refs = set( - LocationFindingReference.objects.filter( - location_id__in=location_ids, - finding=finding, - ).values_list("location_id", flat=True), - ) - if product is not None: - existing_product_refs = set( - LocationProductReference.objects.filter( - location_id__in=location_ids, - product=product, - ).values_list("location_id", flat=True), - ) - - new_finding_refs = [] - new_product_refs = [] - for location in locations: - assoc = location.get_association_data() - - if finding is not None and location.location_id not in existing_finding_refs: - new_finding_refs.append(LocationFindingReference( - location_id=location.location_id, - finding=finding, - status=FindingLocationStatus.Active, - relationship=assoc.relationship_type, - relationship_data=assoc.relationship_data, - )) - existing_finding_refs.add(location.location_id) - - if product is not None and location.location_id not in existing_product_refs: - new_product_refs.append(LocationProductReference( - location_id=location.location_id, - product=product, - status=ProductLocationStatus.Active, - relationship=assoc.relationship_type, - 
relationship_data=assoc.relationship_data, - )) - existing_product_refs.add(location.location_id) - - if new_finding_refs: - LocationFindingReference.objects.bulk_create( - new_finding_refs, batch_size=1000, ignore_conflicts=True, - ) - if new_product_refs: - LocationProductReference.objects.bulk_create( - new_product_refs, batch_size=1000, ignore_conflicts=True, - ) - - @classmethod - def add_locations_to_finding( - cls, - finding: Finding, - locations: list[UnsavedLocation], - ) -> None: - """Creates AbstractLocation objects from the given list and links them to the given Finding and its Product.""" - locations = cls.bulk_get_or_create_locations(locations) - cls.bulk_create_refs(locations, finding=finding) - logger.debug(f"LocationManager: {len(locations)} locations associated with {finding}") - - @staticmethod - def mitigate_location_status( - location_refs: QuerySet[LocationFindingReference], - user: Dojo_User, - ) -> None: - """Mitigate all given (non-mitigated) location refs.""" - location_refs.exclude(status=FindingLocationStatus.Mitigated).update( - auditor=user, - audit_time=timezone.now(), - status=FindingLocationStatus.Mitigated, - ) - - @staticmethod - def reactivate_location_status( - location_refs: QuerySet[LocationFindingReference], - ) -> None: - """Reactivate all given (mitigated) location refs.""" - location_refs.filter(status=FindingLocationStatus.Mitigated).update( - auditor=None, - audit_time=timezone.now(), - status=FindingLocationStatus.Active, - ) - - @classmethod - def clean_unsaved_locations( - cls, - locations: list[UnsavedLocation], - ) -> list[AbstractLocation]: - """ - Convert locations represented as LocationData dataclasses to the appropriate AbstractLocation type, then clean - them. For any endpoints that fail this validation process, log a message that broken locations are being stored. 
- """ - locations = list(set(cls.make_abstract_locations(locations))) - for location in locations: - try: - location.clean() - except ValidationError as e: - logger.warning("DefectDojo is storing broken locations because cleaning wasn't successful: %s", e) - return locations - - def update_location_status( - self, - existing_finding: Finding, - new_finding: Finding, - user: Dojo_User, - **kwargs: dict, - ) -> None: - """Update the list of locations from the new finding with the list that is in the old finding.""" - # New endpoints are already added in serializers.py / views.py (see comment "# for existing findings: make sure endpoints are present or created") - # So we only need to mitigate endpoints that are no longer present - existing_location_refs: QuerySet[LocationFindingReference] = existing_finding.locations.exclude( - status__in=[ - FindingLocationStatus.FalsePositive, - FindingLocationStatus.RiskAccepted, - FindingLocationStatus.OutOfScope, - ], - ) - if new_finding.is_mitigated: - # New finding is mitigated, so mitigate all existing location refs - self.mitigate_location_status(existing_location_refs, user) - else: - new_locations_values = [str(location) for location in type(self).clean_unsaved_locations(new_finding.unsaved_locations)] - # Reactivate endpoints in the old finding that are in the new finding - location_refs_to_reactivate = existing_location_refs.filter(location__location_value__in=new_locations_values) - # Mitigate endpoints in the existing finding not in the new finding - location_refs_to_mitigate = existing_location_refs.exclude(location__location_value__in=new_locations_values) - - self.reactivate_location_status(location_refs_to_reactivate) - self.mitigate_location_status(location_refs_to_mitigate, user) diff --git a/unittests/test_bulk_locations.py b/unittests/test_bulk_locations.py index 5a4a21e4ad8..c8867fc15d2 100644 --- a/unittests/test_bulk_locations.py +++ b/unittests/test_bulk_locations.py @@ -3,9 +3,8 @@ Covers: - 
AbstractLocation.bulk_get_or_create (on URL) -- LocationManager.bulk_get_or_create_locations (URL-only) -- LocationManager.bulk_create_refs (finding + product refs) -- LocationManager.add_locations_to_finding (end-to-end) +- LocationManager._bulk_get_or_create_locations (URL-only) +- LocationManager.record_locations_for_finding + persist (accumulator pattern) - Query efficiency """ @@ -32,11 +31,16 @@ def _make_url(host, path=""): return url +_finding_counter = 0 + + def _make_finding(): + global _finding_counter # noqa: PLW0603 + _finding_counter += 1 now = timezone.now() user, _ = User.objects.get_or_create(username="bulk_test_user", defaults={"is_active": True}) pt, _ = Product_Type.objects.get_or_create(name="Bulk Test Type") - product = Product.objects.create(name="Bulk Test Product", description="test", prod_type=pt) + product = Product.objects.create(name=f"Bulk Test Product {_finding_counter}", description="test", prod_type=pt) eng = Engagement.objects.create(product=product, target_start=now, target_end=now) tt, _ = Test_Type.objects.get_or_create(name="Bulk Test") test = Test.objects.create(engagement=eng, test_type=tt, target_start=now, target_end=now) @@ -136,7 +140,7 @@ def test_transaction_atomicity(self): # --------------------------------------------------------------------------- -# LocationManager.bulk_get_or_create_locations (URL-only) +# LocationManager._bulk_get_or_create_locations (URL-only) # --------------------------------------------------------------------------- @skip_unless_v3 class TestBulkGetOrCreateLocations(DojoTestCase): @@ -148,109 +152,64 @@ def test_supported_location_types_includes_url(self): def test_url_only(self): urls = [_make_url("oss-loc-mgr.example.com")] - saved = LocationManager.bulk_get_or_create_locations(urls) + saved = LocationManager._bulk_get_or_create_locations(urls) self.assertEqual(len(saved), 1) self.assertIsInstance(saved[0], URL) - def test_cleans_location_data(self): + def 
test_handles_cleaned_location_data(self): loc_data = LocationData(type="url", data={"url": "https://oss-from-data.example.com/api"}) - saved = LocationManager.bulk_get_or_create_locations([loc_data]) + cleaned = LocationManager.clean_unsaved_locations([loc_data]) + saved = LocationManager._bulk_get_or_create_locations(cleaned) self.assertEqual(len(saved), 1) self.assertIsInstance(saved[0], URL) self.assertEqual(saved[0].host, "oss-from-data.example.com") def test_empty_input(self): - self.assertEqual(LocationManager.bulk_get_or_create_locations([]), []) + self.assertEqual(LocationManager._bulk_get_or_create_locations([]), []) # --------------------------------------------------------------------------- -# LocationManager.bulk_create_refs +# LocationManager.persist — ref creation details # --------------------------------------------------------------------------- @skip_unless_v3 -class TestBulkCreateRefs(DojoTestCase): - - def test_creates_finding_and_product_refs(self): - finding = _make_finding() - product = finding.test.engagement.product - - saved = URL.bulk_get_or_create([_make_url("oss-refs-both.example.com")]) - LocationManager.bulk_create_refs(saved, finding=finding) - - self.assertTrue(LocationFindingReference.objects.filter( - location_id=saved[0].location_id, finding=finding, - ).exists()) - self.assertTrue(LocationProductReference.objects.filter( - location_id=saved[0].location_id, product=product, - ).exists()) - - def test_creates_product_refs_only(self): - pt, _ = Product_Type.objects.get_or_create(name="Refs Test Type") - product = Product.objects.create(name="Refs Test Product", description="test", prod_type=pt) - - saved = URL.bulk_get_or_create([_make_url("oss-refs-product.example.com")]) - LocationManager.bulk_create_refs(saved, product=product) - - self.assertTrue(LocationProductReference.objects.filter( - location_id=saved[0].location_id, product=product, - ).exists()) - self.assertFalse(LocationFindingReference.objects.filter( - 
location_id=saved[0].location_id, - ).exists()) - - def test_skips_existing_refs(self): - finding = _make_finding() - saved = URL.bulk_get_or_create([_make_url("oss-refs-existing.example.com")]) - - LocationManager.bulk_create_refs(saved, finding=finding) - LocationManager.bulk_create_refs(saved, finding=finding) - - self.assertEqual(LocationFindingReference.objects.filter( - location_id=saved[0].location_id, finding=finding, - ).count(), 1) +class TestPersistRefCreation(DojoTestCase): def test_uses_association_data(self): finding = _make_finding() + product = finding.test.engagement.product url = _make_url("oss-refs-assoc.example.com") url._association_data = LocationAssociationData( relationship_type="owned_by", relationship_data={"file_path": "/app/main.py"}, ) - saved = URL.bulk_get_or_create([url]) - LocationManager.bulk_create_refs(saved, finding=finding) - ref = LocationFindingReference.objects.get( - location_id=saved[0].location_id, finding=finding, - ) + mgr = LocationManager(product) + mgr.record_locations_for_finding(finding, [url]) + mgr.persist() + + ref = LocationFindingReference.objects.get(finding=finding) self.assertEqual(ref.relationship, "owned_by") self.assertEqual(ref.relationship_data, {"file_path": "/app/main.py"}) - def test_raises_without_finding_or_product(self): - saved = URL.bulk_get_or_create([_make_url("oss-refs-error.example.com")]) - with self.assertRaises(ValueError): - LocationManager.bulk_create_refs(saved) - - def test_empty_locations(self): - LocationManager.bulk_create_refs([], finding=_make_finding()) - - def test_finding_implies_product(self): - finding = _make_finding() - product = finding.test.engagement.product - saved = URL.bulk_get_or_create([_make_url("oss-refs-implied.example.com")]) + def test_product_only_locations(self): + pt, _ = Product_Type.objects.get_or_create(name="Refs Test Type") + product = Product.objects.create(name="Refs Product Only", description="test", prod_type=pt) - 
LocationManager.bulk_create_refs(saved, finding=finding) + mgr = LocationManager(product) + mgr._product_locations.extend([_make_url("oss-product-only.example.com")]) + mgr.persist() - self.assertTrue(LocationProductReference.objects.filter( - location_id=saved[0].location_id, product=product, - ).exists()) + self.assertTrue(LocationProductReference.objects.filter(product=product).exists()) + self.assertFalse(LocationFindingReference.objects.exists()) # --------------------------------------------------------------------------- -# End-to-end: add_locations_to_finding +# End-to-end: record + persist # --------------------------------------------------------------------------- @skip_unless_v3 -class TestAddLocationsToUnsavedFinding(DojoTestCase): +class TestRecordAndPersist(DojoTestCase): def test_full_pipeline(self): finding = _make_finding() @@ -261,25 +220,57 @@ def test_full_pipeline(self): LocationData(type="url", data={"url": "https://oss-e2e-2.example.com/api"}), ] - LocationManager.add_locations_to_finding(finding, loc_data) + mgr = LocationManager(product) + mgr.record_locations_for_finding(finding, loc_data) + mgr.persist() self.assertEqual(LocationFindingReference.objects.filter(finding=finding).count(), 2) self.assertEqual(LocationProductReference.objects.filter(product=product).count(), 2) def test_empty_locations(self): finding = _make_finding() - LocationManager.add_locations_to_finding(finding, []) + product = finding.test.engagement.product + + mgr = LocationManager(product) + mgr.record_locations_for_finding(finding, []) + mgr.persist() + self.assertEqual(LocationFindingReference.objects.filter(finding=finding).count(), 0) def test_idempotent(self): finding = _make_finding() + product = finding.test.engagement.product loc_data = [LocationData(type="url", data={"url": "https://oss-idempotent.example.com"})] - LocationManager.add_locations_to_finding(finding, loc_data) - LocationManager.add_locations_to_finding(finding, loc_data) + mgr = 
LocationManager(product) + mgr.record_locations_for_finding(finding, loc_data) + mgr.persist() + mgr.record_locations_for_finding(finding, loc_data) + mgr.persist() self.assertEqual(LocationFindingReference.objects.filter(finding=finding).count(), 1) + def test_multiple_findings_single_persist(self): + finding1 = _make_finding() + product = finding1.test.engagement.product + # Create second finding on the same product/engagement/test + finding2 = Finding.objects.create( + test=finding1.test, title="Bulk Test Finding 2", severity="High", reporter=finding1.reporter, + ) + + mgr = LocationManager(product) + mgr.record_locations_for_finding(finding1, [ + LocationData(type="url", data={"url": "https://oss-multi-1.example.com"}), + ]) + mgr.record_locations_for_finding(finding2, [ + LocationData(type="url", data={"url": "https://oss-multi-2.example.com"}), + ]) + mgr.persist() + + self.assertEqual(LocationFindingReference.objects.filter(finding=finding1).count(), 1) + self.assertEqual(LocationFindingReference.objects.filter(finding=finding2).count(), 1) + self.assertEqual(LocationProductReference.objects.filter(product=product).count(), 2) + # --------------------------------------------------------------------------- # Query efficiency From 8acea580c932cd37ac8af7c2f77179097c984cce Mon Sep 17 00:00:00 2001 From: dogboat Date: Wed, 15 Apr 2026 11:03:51 -0400 Subject: [PATCH 10/47] test updates --- dojo/importers/location_manager.py | 13 +++++++++++++ unittests/test_bulk_locations.py | 25 +++++++++++++++++++++++++ 2 files changed, 38 insertions(+) diff --git a/dojo/importers/location_manager.py b/dojo/importers/location_manager.py index d025f57ef42..1f312342254 100644 --- a/dojo/importers/location_manager.py +++ b/dojo/importers/location_manager.py @@ -198,6 +198,14 @@ def persist(self, user: Dojo_User | None = None) -> None: all_product_refs, batch_size=1000, ignore_conflicts=True, ) + # bulk_create bypasses post_save signals, so manually trigger tag inheritance on each 
unique Location + from dojo.tags_signals import inherit_instance_tags # noqa: PLC0415 + seen_location_ids: set[int] = set() + for loc in saved: + if loc.location_id not in seen_location_ids: + seen_location_ids.add(loc.location_id) + inherit_instance_tags(loc.location) + self._locations_by_finding.clear() # Step 1b: Product-level locations (not tied to a finding) @@ -228,6 +236,11 @@ def persist(self, user: Dojo_User | None = None) -> None: LocationProductReference.objects.bulk_create( new_refs, batch_size=1000, ignore_conflicts=True, ) + + # bulk_create bypasses post_save signals; manually trigger tag inheritance + from dojo.tags_signals import inherit_instance_tags # noqa: PLC0415 + for loc in saved: + inherit_instance_tags(loc.location) self._product_locations.clear() # Step 2: Mitigate accumulated refs diff --git a/unittests/test_bulk_locations.py b/unittests/test_bulk_locations.py index c8867fc15d2..fa9d9492a7a 100644 --- a/unittests/test_bulk_locations.py +++ b/unittests/test_bulk_locations.py @@ -286,3 +286,28 @@ def test_bulk_fewer_queries_than_locations(self): # Expected: ~3 queries (SELECT existing, INSERT parents, INSERT subtypes) self.assertLess(len(ctx.captured_queries), 10) + + +# --------------------------------------------------------------------------- +# Tag inheritance after bulk persist +# --------------------------------------------------------------------------- +@skip_unless_v3 +class TestTagInheritanceOnPersist(DojoTestCase): + + def test_locations_inherit_product_tags(self): + """Locations should inherit tags from their associated product after persist.""" + finding = _make_finding() + product = finding.test.engagement.product + # Enable tag inheritance at the product level and add some product tags + product.enable_product_tag_inheritance = True + product.save() + product.tags.add("inherit", "tags", "these") + + loc_data = [LocationData(type="url", data={"url": "https://oss-tag-inherit.example.com"})] + mgr = LocationManager(product) + 
mgr.record_locations_for_finding(finding, loc_data) + mgr.persist() + + loc = Location.objects.get(url__host="oss-tag-inherit.example.com") + inherited = sorted(t.name for t in loc.inherited_tags.all()) + self.assertEqual(inherited, ["inherit", "tags", "these"]) From 3dce1c025624ba137be9c787087c832dbf84423e Mon Sep 17 00:00:00 2001 From: dogboat Date: Wed, 15 Apr 2026 11:04:49 -0400 Subject: [PATCH 11/47] linter --- dojo/importers/location_manager.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/dojo/importers/location_manager.py b/dojo/importers/location_manager.py index 1f312342254..db790f3a296 100644 --- a/dojo/importers/location_manager.py +++ b/dojo/importers/location_manager.py @@ -10,6 +10,7 @@ from dojo.location.models import AbstractLocation, LocationFindingReference, LocationProductReference from dojo.location.status import FindingLocationStatus, ProductLocationStatus +from dojo.tags_signals import inherit_instance_tags from dojo.tools.locations import LocationData from dojo.url.models import URL @@ -199,7 +200,6 @@ def persist(self, user: Dojo_User | None = None) -> None: ) # bulk_create bypasses post_save signals, so manually trigger tag inheritance on each unique Location - from dojo.tags_signals import inherit_instance_tags # noqa: PLC0415 seen_location_ids: set[int] = set() for loc in saved: if loc.location_id not in seen_location_ids: @@ -238,7 +238,6 @@ def persist(self, user: Dojo_User | None = None) -> None: ) # bulk_create bypasses post_save signals; manually trigger tag inheritance - from dojo.tags_signals import inherit_instance_tags # noqa: PLC0415 for loc in saved: inherit_instance_tags(loc.location) self._product_locations.clear() From 2bc419211b17fda2bde7f504b3d97552e405ec2d Mon Sep 17 00:00:00 2001 From: dogboat Date: Wed, 15 Apr 2026 11:29:02 -0400 Subject: [PATCH 12/47] perf test updates --- unittests/test_importers_performance.py | 64 ++++++++++++------------- 1 file changed, 32 insertions(+), 32 deletions(-) diff 
--git a/unittests/test_importers_performance.py b/unittests/test_importers_performance.py index 1a9c9fc137d..7ba2e04ee86 100644 --- a/unittests/test_importers_performance.py +++ b/unittests/test_importers_performance.py @@ -569,14 +569,14 @@ def test_import_reimport_reimport_performance_pghistory_async(self): configure_pghistory_triggers() self._import_reimport_performance( - expected_num_queries1=1191, - expected_num_async_tasks1=6, - expected_num_queries2=716, - expected_num_async_tasks2=17, - expected_num_queries3=346, - expected_num_async_tasks3=16, - expected_num_queries4=212, - expected_num_async_tasks4=6, + expected_num_queries1=165, + expected_num_async_tasks1=1, + expected_num_queries2=149, + expected_num_async_tasks2=1, + expected_num_queries3=60, + expected_num_async_tasks3=1, + expected_num_queries4=100, + expected_num_async_tasks4=0, ) @override_settings(ENABLE_AUDITLOG=True) @@ -593,14 +593,14 @@ def test_import_reimport_reimport_performance_pghistory_no_async(self): testuser.usercontactinfo.save() self._import_reimport_performance( - expected_num_queries1=1200, - expected_num_async_tasks1=6, - expected_num_queries2=725, - expected_num_async_tasks2=17, - expected_num_queries3=355, - expected_num_async_tasks3=16, - expected_num_queries4=212, - expected_num_async_tasks4=6, + expected_num_queries1=174, + expected_num_async_tasks1=1, + expected_num_queries2=158, + expected_num_async_tasks2=1, + expected_num_queries3=69, + expected_num_async_tasks3=1, + expected_num_queries4=100, + expected_num_async_tasks4=0, ) @override_settings(ENABLE_AUDITLOG=True) @@ -618,14 +618,14 @@ def test_import_reimport_reimport_performance_pghistory_no_async_with_product_gr self.system_settings(enable_product_grade=True) self._import_reimport_performance( - expected_num_queries1=1210, - expected_num_async_tasks1=8, - expected_num_queries2=735, - expected_num_async_tasks2=19, - expected_num_queries3=359, - expected_num_async_tasks3=18, - expected_num_queries4=222, - 
expected_num_async_tasks4=8, + expected_num_queries1=184, + expected_num_async_tasks1=3, + expected_num_queries2=168, + expected_num_async_tasks2=3, + expected_num_queries3=73, + expected_num_async_tasks3=3, + expected_num_queries4=110, + expected_num_async_tasks4=2, ) def _deduplication_performance(self, expected_num_queries1, expected_num_async_tasks1, expected_num_queries2, expected_num_async_tasks2, *, check_duplicates=True): @@ -718,10 +718,10 @@ def test_deduplication_performance_pghistory_async(self): self.system_settings(enable_deduplication=True) self._deduplication_performance( - expected_num_queries1=1411, - expected_num_async_tasks1=7, - expected_num_queries2=1016, - expected_num_async_tasks2=7, + expected_num_queries1=101, + expected_num_async_tasks1=1, + expected_num_queries2=92, + expected_num_async_tasks2=1, check_duplicates=False, # Async mode - deduplication happens later ) @@ -738,8 +738,8 @@ def test_deduplication_performance_pghistory_no_async(self): testuser.usercontactinfo.save() self._deduplication_performance( - expected_num_queries1=1420, - expected_num_async_tasks1=7, - expected_num_queries2=1132, - expected_num_async_tasks2=7, + expected_num_queries1=110, + expected_num_async_tasks1=1, + expected_num_queries2=208, + expected_num_async_tasks2=1, ) From 4e8f53f6ab3e092d563fe2592a10f3916f6007b7 Mon Sep 17 00:00:00 2001 From: dogboat Date: Wed, 15 Apr 2026 13:43:43 -0400 Subject: [PATCH 13/47] wip --- dojo/importers/location_manager.py | 231 ++++++++++++++++++++--------- unittests/test_bulk_locations.py | 68 +++++++++ 2 files changed, 233 insertions(+), 66 deletions(-) diff --git a/dojo/importers/location_manager.py b/dojo/importers/location_manager.py index db790f3a296..222be976bf2 100644 --- a/dojo/importers/location_manager.py +++ b/dojo/importers/location_manager.py @@ -13,10 +13,9 @@ from dojo.tags_signals import inherit_instance_tags from dojo.tools.locations import LocationData from dojo.url.models import URL +from dojo.utils import 
get_system_setting if TYPE_CHECKING: - from django.db.models import QuerySet - from dojo.models import Dojo_User, Finding, Product logger = logging.getLogger(__name__) @@ -33,8 +32,32 @@ def __init__(self, product: Product) -> None: self._product = product self._locations_by_finding: dict[int, tuple[Finding, list[UnsavedLocation]]] = {} self._product_locations: list[UnsavedLocation] = [] - self._refs_to_mitigate: list[tuple[QuerySet[LocationFindingReference], Dojo_User]] = [] - self._refs_to_reactivate: list[QuerySet[LocationFindingReference]] = [] + # Status update inputs (deferred). All entries are processed in a single bulk pass by persist(). + # (existing_finding, new_finding, user): classified partial mitigate/reactivate + self._status_updates: list[tuple[Finding, Finding, Dojo_User]] = [] + # finding_id: fully reactivate (all mitigated refs on this finding become active) + self._finding_ids_to_fully_reactivate: list[int] = [] + # (finding_id, user): fully mitigate (all non-special refs on this finding become mitigated by user) + self._finding_ids_to_fully_mitigate: list[tuple[int, Dojo_User | None]] = [] + # Cached result of _should_inherit_product_tags() — lazily computed and reused across persist() calls + self._cached_should_inherit_product_tags: bool | None = None + + def _should_inherit_product_tags(self) -> bool: + """ + Return True if new LocationFindingReference/LocationProductReference creations + should trigger inherit_instance_tags on the affected locations. + + inherit_instance_tags() runs a complex JOIN query per location (via all_related_products()), + which is O(N) per bulk persist. We short-circuit when neither the product nor the system + setting has tag inheritance enabled — in that case, adding a new ref for self._product + cannot change any location's inherited tags. 
+ """ + if self._cached_should_inherit_product_tags is None: + self._cached_should_inherit_product_tags = bool( + getattr(self._product, "enable_product_tag_inheritance", False) + or get_system_setting("enable_product_tag_inheritance"), + ) + return self._cached_should_inherit_product_tags # ------------------------------------------------------------------ # Accumulation methods (no DB hits) @@ -55,31 +78,8 @@ def update_location_status( new_finding: Finding, user: Dojo_User, ) -> None: - """Accumulate mitigate/reactivate operations for persist().""" - existing_location_refs: QuerySet[LocationFindingReference] = existing_finding.locations.exclude( - status__in=[ - FindingLocationStatus.FalsePositive, - FindingLocationStatus.RiskAccepted, - FindingLocationStatus.OutOfScope, - ], - ) - if new_finding.is_mitigated: - self._refs_to_mitigate.append((existing_location_refs, user)) - else: - new_locations_values = [ - str(location) for location in type(self).clean_unsaved_locations(new_finding.unsaved_locations) - ] - self._refs_to_reactivate.append( - existing_location_refs.filter(location__location_value__in=new_locations_values), - ) - self._refs_to_mitigate.append(( - existing_location_refs.exclude(location__location_value__in=new_locations_values), - user, - )) - - def record_reactivations(self, location_refs: QuerySet[LocationFindingReference]) -> None: - """Record location refs to reactivate. Flushed by persist().""" - self._refs_to_reactivate.append(location_refs) + """Defer status update to persist(). 
No DB access at record time.""" + self._status_updates.append((existing_finding, new_finding, user)) # ------------------------------------------------------------------ # Unified interface (shared with EndpointManager) @@ -100,13 +100,12 @@ def update_status(self, existing_finding: Finding, new_finding: Finding, user: D self.update_location_status(existing_finding, new_finding, user) def record_reactivations_for_finding(self, finding: Finding) -> None: - """Record mitigated location refs on this finding for reactivation.""" - mitigated = finding.locations.filter(status=FindingLocationStatus.Mitigated) - self._refs_to_reactivate.append(mitigated) + """Defer reactivation to persist(). No DB access at record time.""" + self._finding_ids_to_fully_reactivate.append(finding.id) def record_mitigations_for_finding(self, finding: Finding, user: Dojo_User | None = None) -> None: - """Record all location refs on this finding for mitigation.""" - self._refs_to_mitigate.append((finding.locations.all(), user)) + """Defer mitigation to persist(). 
No DB access at record time.""" + self._finding_ids_to_fully_mitigate.append((finding.id, user)) def get_items_for_tagging(self, findings: list[Finding]): """Return queryset of items to apply tags to.""" @@ -146,31 +145,36 @@ def persist(self, user: Dojo_User | None = None) -> None: # Build all refs across all findings in one pass all_finding_refs = [] all_product_refs = [] + # Track locations that got new refs — only those need tag inheritance + # (mirrors original post_save signal behavior on LocationFindingReference/LocationProductReference) + locations_needing_inherit: dict[int, AbstractLocation] = {} - # Pre-fetch existing product refs for this product across all locations + # Pre-fetch existing product refs for this product across all locations (one query) all_location_ids = [loc.location_id for loc in saved] - existing_product_refs = set( + existing_product_refs: set[int] = set( LocationProductReference.objects.filter( location_id__in=all_location_ids, product=self._product, ).values_list("location_id", flat=True), ) + # Pre-fetch existing finding refs across ALL findings in one query (avoids N+1) + all_finding_ids = [finding.id for finding, _, _ in finding_ranges] + existing_finding_ref_keys: set[tuple[int, int]] = set( + LocationFindingReference.objects.filter( + location_id__in=all_location_ids, + finding_id__in=all_finding_ids, + ).values_list("finding_id", "location_id"), + ) + for finding, start, end in finding_ranges: finding_locations = saved[start:end] - finding_location_ids = [loc.location_id for loc in finding_locations] - - existing_finding_refs = set( - LocationFindingReference.objects.filter( - location_id__in=finding_location_ids, - finding=finding, - ).values_list("location_id", flat=True), - ) for location in finding_locations: assoc = location.get_association_data() + finding_ref_key = (finding.id, location.location_id) - if location.location_id not in existing_finding_refs: + if finding_ref_key not in existing_finding_ref_keys: 
all_finding_refs.append(LocationFindingReference( location_id=location.location_id, finding=finding, @@ -178,7 +182,8 @@ def persist(self, user: Dojo_User | None = None) -> None: relationship=assoc.relationship_type, relationship_data=assoc.relationship_data, )) - existing_finding_refs.add(location.location_id) + existing_finding_ref_keys.add(finding_ref_key) + locations_needing_inherit[location.location_id] = location if location.location_id not in existing_product_refs: all_product_refs.append(LocationProductReference( @@ -189,6 +194,7 @@ def persist(self, user: Dojo_User | None = None) -> None: relationship_data=assoc.relationship_data, )) existing_product_refs.add(location.location_id) + locations_needing_inherit[location.location_id] = location if all_finding_refs: LocationFindingReference.objects.bulk_create( @@ -199,11 +205,12 @@ def persist(self, user: Dojo_User | None = None) -> None: all_product_refs, batch_size=1000, ignore_conflicts=True, ) - # bulk_create bypasses post_save signals, so manually trigger tag inheritance on each unique Location - seen_location_ids: set[int] = set() - for loc in saved: - if loc.location_id not in seen_location_ids: - seen_location_ids.add(loc.location_id) + # bulk_create bypasses post_save signals; trigger tag inheritance only on locations + # that got new refs (matches original signal-based behavior). Short-circuit if the + # product has no tag inheritance enabled — calling inherit_instance_tags per location + # is expensive (each fires a complex JOIN on Product via all_related_products()). 
+ if self._should_inherit_product_tags(): + for loc in locations_needing_inherit.values(): inherit_instance_tags(loc.location) self._locations_by_finding.clear() @@ -221,6 +228,8 @@ def persist(self, user: Dojo_User | None = None) -> None: ).values_list("location_id", flat=True), ) new_refs = [] + # Track locations that got new refs — only those need tag inheritance + locations_needing_inherit: dict[int, AbstractLocation] = {} for location in saved: if location.location_id not in existing: assoc = location.get_association_data() @@ -232,33 +241,123 @@ def persist(self, user: Dojo_User | None = None) -> None: relationship_data=assoc.relationship_data, )) existing.add(location.location_id) + locations_needing_inherit[location.location_id] = location if new_refs: LocationProductReference.objects.bulk_create( new_refs, batch_size=1000, ignore_conflicts=True, ) - # bulk_create bypasses post_save signals; manually trigger tag inheritance - for loc in saved: - inherit_instance_tags(loc.location) + # bulk_create bypasses post_save signals; trigger tag inheritance only on + # locations that got new product refs (short-circuited if the product has no + # tag inheritance enabled — see _should_inherit_product_tags()) + if self._should_inherit_product_tags(): + for loc in locations_needing_inherit.values(): + inherit_instance_tags(loc.location) self._product_locations.clear() - # Step 2: Mitigate accumulated refs - for refs, mitigate_user in self._refs_to_mitigate: - refs.exclude(status=FindingLocationStatus.Mitigated).update( - auditor=mitigate_user, - audit_time=timezone.now(), - status=FindingLocationStatus.Mitigated, + # Steps 2 & 3: Bulk status updates — classify refs, then execute in minimal queries + self._flush_status_updates() + + def _flush_status_updates(self) -> None: + """ + Resolve all accumulated status-update inputs and execute them as bulk UPDATEs. 
+ + Produces ~3-4 queries total regardless of the number of findings processed: + 1 SELECT to fetch relevant location refs for partial-status updates, + 1 UPDATE for reactivations, + 1 UPDATE per unique mitigation user (typically 1). + """ + # Short-circuit if nothing to do + if not (self._status_updates or self._finding_ids_to_fully_reactivate or self._finding_ids_to_fully_mitigate): + return + + special_statuses = [ + FindingLocationStatus.FalsePositive, + FindingLocationStatus.RiskAccepted, + FindingLocationStatus.OutOfScope, + ] + + # Collect ref IDs to reactivate / mitigate across all accumulated inputs + ref_ids_to_reactivate: set[int] = set() + # Grouped by user since auditor differs per entry + ref_ids_to_mitigate_by_user: dict[Dojo_User | None, set[int]] = {} + + # Partial status updates (from update_location_status): need per-finding classification + if self._status_updates: + finding_ids_for_partial = {upd[0].id for upd in self._status_updates} + # Single fetch of all candidate refs with their location values + refs_by_finding: dict[int, list[LocationFindingReference]] = {} + for ref in ( + LocationFindingReference.objects + .filter(finding_id__in=finding_ids_for_partial) + .exclude(status__in=special_statuses) + .select_related("location") + ): + refs_by_finding.setdefault(ref.finding_id, []).append(ref) + + for existing_finding, new_finding, user in self._status_updates: + finding_refs = refs_by_finding.get(existing_finding.id, []) + if new_finding.is_mitigated: + # All non-special refs on this finding get mitigated + ref_ids_to_mitigate_by_user.setdefault(user, set()).update(r.id for r in finding_refs) + else: + new_loc_values = { + str(loc) for loc in type(self).clean_unsaved_locations(new_finding.unsaved_locations) + } + for ref in finding_refs: + if ref.location.location_value in new_loc_values: + ref_ids_to_reactivate.add(ref.id) + else: + ref_ids_to_mitigate_by_user.setdefault(user, set()).add(ref.id) + + # Full reactivations (from 
record_reactivations_for_finding): all mitigated refs for these findings + if self._finding_ids_to_fully_reactivate: + ref_ids_to_reactivate.update( + LocationFindingReference.objects.filter( + finding_id__in=self._finding_ids_to_fully_reactivate, + status=FindingLocationStatus.Mitigated, + ).values_list("id", flat=True), ) - self._refs_to_mitigate.clear() - # Step 3: Reactivate accumulated refs - for refs in self._refs_to_reactivate: - refs.filter(status=FindingLocationStatus.Mitigated).update( + # Full mitigations (from record_mitigations_for_finding): all non-special refs for these findings, per user + if self._finding_ids_to_fully_mitigate: + # Group finding_ids by user to do one SELECT per user + ids_by_user: dict[Dojo_User | None, list[int]] = {} + for finding_id, user in self._finding_ids_to_fully_mitigate: + ids_by_user.setdefault(user, []).append(finding_id) + for user, finding_ids in ids_by_user.items(): + ref_ids_to_mitigate_by_user.setdefault(user, set()).update( + LocationFindingReference.objects.filter( + finding_id__in=finding_ids, + ).exclude(status__in=special_statuses).values_list("id", flat=True), + ) + + # Execute bulk updates + now = timezone.now() + if ref_ids_to_reactivate: + LocationFindingReference.objects.filter( + id__in=ref_ids_to_reactivate, + status=FindingLocationStatus.Mitigated, + ).update( auditor=None, - audit_time=timezone.now(), + audit_time=now, status=FindingLocationStatus.Active, ) - self._refs_to_reactivate.clear() + + for user, ref_ids in ref_ids_to_mitigate_by_user.items(): + if ref_ids: + LocationFindingReference.objects.filter( + id__in=ref_ids, + ).exclude(status=FindingLocationStatus.Mitigated).update( + auditor=user, + audit_time=now, + status=FindingLocationStatus.Mitigated, + ) + + # Clear accumulators + self._status_updates.clear() + self._finding_ids_to_fully_reactivate.clear() + self._finding_ids_to_fully_mitigate.clear() # ------------------------------------------------------------------ # Type registry diff 
--git a/unittests/test_bulk_locations.py b/unittests/test_bulk_locations.py index fa9d9492a7a..f008bc43b3e 100644 --- a/unittests/test_bulk_locations.py +++ b/unittests/test_bulk_locations.py @@ -17,6 +17,7 @@ from dojo.importers.location_manager import LocationManager from dojo.location.models import Location, LocationFindingReference, LocationProductReference +from dojo.location.status import FindingLocationStatus from dojo.models import Engagement, Finding, Product, Product_Type, Test, Test_Type from dojo.tools.locations import LocationAssociationData, LocationData from dojo.url.models import URL @@ -311,3 +312,70 @@ def test_locations_inherit_product_tags(self): loc = Location.objects.get(url__host="oss-tag-inherit.example.com") inherited = sorted(t.name for t in loc.inherited_tags.all()) self.assertEqual(inherited, ["inherit", "tags", "these"]) + + +# --------------------------------------------------------------------------- +# Status update query efficiency +# --------------------------------------------------------------------------- +@skip_unless_v3 +class TestStatusUpdateQueryEfficiency(DojoTestCase): + + """ + Verify that persist() flushes status updates with a bounded number of queries, + regardless of how many findings were recorded (not O(n)). 
+ """ + + def _setup_findings_with_mitigated_refs(self, count: int): + """Create `count` findings in a single product, each with a mitigated LocationFindingReference.""" + # Single product for all findings + first_finding = _make_finding() + product = first_finding.test.engagement.product + test = first_finding.test + reporter = first_finding.reporter + + findings = [first_finding] + findings.extend( + Finding.objects.create( + test=test, title=f"Status Test Finding {i}", severity="Medium", reporter=reporter, + ) + for i in range(count - 1) + ) + + # Create one mitigated LocationFindingReference per finding + for i, finding in enumerate(findings): + saved = URL.bulk_get_or_create([_make_url(f"oss-status-{i}.example.com")]) + LocationFindingReference.objects.create( + location=saved[0].location, + finding=finding, + status=FindingLocationStatus.Mitigated, + ) + return findings, product + + def test_reactivate_for_many_findings_is_bulk(self): + findings, product = self._setup_findings_with_mitigated_refs(count=20) + mgr = LocationManager(product) + for finding in findings: + mgr.record_reactivations_for_finding(finding) + + with CaptureQueriesContext(connection) as ctx: + mgr.persist() + + # Expected: 1 SELECT (gather ref IDs) + 1 UPDATE (reactivate). Allow tiny overhead. 
+ self.assertLess(len(ctx.captured_queries), 5, ctx.captured_queries) + + def test_update_location_status_for_many_findings_is_bulk(self): + findings, product = self._setup_findings_with_mitigated_refs(count=20) + reporter = findings[0].reporter + mgr = LocationManager(product) + + # Simulate reimport "matched finding" flow: new finding with no unsaved locations => mitigate all + for finding in findings: + new_finding = Finding(title=finding.title, severity=finding.severity, test=finding.test, is_mitigated=True) + new_finding.unsaved_locations = [] + mgr.update_location_status(finding, new_finding, reporter) + + with CaptureQueriesContext(connection) as ctx: + mgr.persist() + + # Expected: 1 SELECT (partial-status fetch) + 1 UPDATE (mitigate for the single user). + self.assertLess(len(ctx.captured_queries), 5, ctx.captured_queries) From e7d912c55da0a0f40201b60a9f4ee7e8b7646eb6 Mon Sep 17 00:00:00 2001 From: dogboat Date: Wed, 15 Apr 2026 14:57:56 -0400 Subject: [PATCH 14/47] testing --- dojo/importers/location_manager.py | 24 +++--- dojo/tags_signals.py | 124 +++++++++++++++++++++++++++++ unittests/test_bulk_locations.py | 67 ++++++++++++++++ 3 files changed, 205 insertions(+), 10 deletions(-) diff --git a/dojo/importers/location_manager.py b/dojo/importers/location_manager.py index 222be976bf2..1b62096ffa3 100644 --- a/dojo/importers/location_manager.py +++ b/dojo/importers/location_manager.py @@ -10,7 +10,7 @@ from dojo.location.models import AbstractLocation, LocationFindingReference, LocationProductReference from dojo.location.status import FindingLocationStatus, ProductLocationStatus -from dojo.tags_signals import inherit_instance_tags +from dojo.tags_signals import bulk_inherit_location_tags from dojo.tools.locations import LocationData from dojo.url.models import URL from dojo.utils import get_system_setting @@ -207,11 +207,13 @@ def persist(self, user: Dojo_User | None = None) -> None: # bulk_create bypasses post_save signals; trigger tag inheritance only 
on locations # that got new refs (matches original signal-based behavior). Short-circuit if the - # product has no tag inheritance enabled — calling inherit_instance_tags per location - # is expensive (each fires a complex JOIN on Product via all_related_products()). - if self._should_inherit_product_tags(): - for loc in locations_needing_inherit.values(): - inherit_instance_tags(loc.location) + # product has no tag inheritance enabled, and use the bulk variant otherwise to + # avoid O(N) expensive JOINs via Location.all_related_products(). + if self._should_inherit_product_tags() and locations_needing_inherit: + bulk_inherit_location_tags( + (loc.location for loc in locations_needing_inherit.values()), + known_product=self._product, + ) self._locations_by_finding.clear() @@ -249,10 +251,12 @@ def persist(self, user: Dojo_User | None = None) -> None: # bulk_create bypasses post_save signals; trigger tag inheritance only on # locations that got new product refs (short-circuited if the product has no - # tag inheritance enabled — see _should_inherit_product_tags()) - if self._should_inherit_product_tags(): - for loc in locations_needing_inherit.values(): - inherit_instance_tags(loc.location) + # tag inheritance enabled — see _should_inherit_product_tags()). + if self._should_inherit_product_tags() and locations_needing_inherit: + bulk_inherit_location_tags( + (loc.location for loc in locations_needing_inherit.values()), + known_product=self._product, + ) self._product_locations.clear() # Steps 2 & 3: Bulk status updates — classify refs, then execute in minimal queries diff --git a/dojo/tags_signals.py b/dojo/tags_signals.py index 0fea7ae8ad5..b4e0bd18756 100644 --- a/dojo/tags_signals.py +++ b/dojo/tags_signals.py @@ -49,6 +49,130 @@ def inherit_instance_tags(instance): instance.inherit_tags(tag_list) +def bulk_inherit_location_tags(locations, *, known_product=None): + """ + Bulk equivalent of calling inherit_instance_tags(loc) for many Locations. 
+ + Uses aggressive prefetching to produce O(1) queries for the "decide what needs + to change" phase, and only runs per-instance mutation queries (~3 each) for + locations that are actually out of sync with their product tags. + + Compared to the per-instance path, this avoids the N expensive JOINs in + Location.all_related_products() (~50ms each). + + Args: + locations: iterable of Location instances to update + known_product: optional hint — if provided, used as the minimum product + set for locations not already associated elsewhere. Not strictly + required for correctness, but lets us skip the fetch-related-products + query in the common case. + + """ + locations = list(locations) + if not locations: + return + + system_wide_inherit = bool(get_system_setting("enable_product_tag_inheritance")) + + # --- Bulk query: map location_id -> set[product_id] for every related product + location_ids = [loc.id for loc in locations] + product_ids_by_location: dict[int, set[int]] = {loc.id: set() for loc in locations} + + # Path 1: via LocationProductReference (direct association) + for loc_id, prod_id in LocationProductReference.objects.filter( + location_id__in=location_ids, + ).values_list("location_id", "product_id"): + product_ids_by_location[loc_id].add(prod_id) + + # Path 2: via LocationFindingReference -> Finding -> Test -> Engagement -> Product + for loc_id, prod_id in ( + LocationFindingReference.objects + .filter(location_id__in=location_ids) + .values_list("location_id", "finding__test__engagement__product_id") + ): + if prod_id is not None: + product_ids_by_location[loc_id].add(prod_id) + + # Seed with known_product so callers don't have to rely on refs being persisted before this call + if known_product is not None: + for loc_id in location_ids: + product_ids_by_location[loc_id].add(known_product.id) + + # --- Bulk query: fetch the unique products with their tags and inheritance flag + all_product_ids = {pid for pids in product_ids_by_location.values() for pid 
in pids} + if not all_product_ids: + return + + products = { + p.id: p + for p in Product.objects.filter(id__in=all_product_ids).prefetch_related("tags") + } + + # Products that contribute to inheritance (either opted in themselves or system-wide on) + contributing_product_ids = { + pid for pid, p in products.items() + if p.enable_product_tag_inheritance or system_wide_inherit + } + if not contributing_product_ids: + # No product with inheritance enabled and system-wide is off → nothing to do + return + + # Pre-compute the tag names each contributing product contributes + tags_by_product: dict[int, set[str]] = { + pid: {t.name for t in products[pid].tags.all()} + for pid in contributing_product_ids + } + + # --- Bulk query: existing inherited_tags per location + inherited_through = Location.inherited_tags.through + inherited_fk = Location.inherited_tags.field.m2m_reverse_field_name() + existing_inherited_by_location: dict[int, set[str]] = {loc.id: set() for loc in locations} + for loc_id, tag_name in inherited_through.objects.filter( + location_id__in=location_ids, + ).values_list("location_id", f"{inherited_fk}__name"): + existing_inherited_by_location[loc_id].add(tag_name) + + # --- Bulk query: existing user tags per location (needed by _manage_inherited_tags) + tags_through = Location.tags.through + tags_fk = Location.tags.field.m2m_reverse_field_name() + existing_tags_by_location: dict[int, list[str]] = {loc.id: [] for loc in locations} + for loc_id, tag_name in tags_through.objects.filter( + location_id__in=location_ids, + ).values_list("location_id", f"{tags_fk}__name"): + existing_tags_by_location[loc_id].append(tag_name) + + # --- Determine which locations are out of sync and call _manage_inherited_tags directly. + # Calling _manage_inherited_tags with pre-computed values skips the expensive + # products_to_inherit_tags_from() JOIN that location.inherit_tags() would run. 
+ # + # Must disconnect make_inherited_tags_sticky while we mutate — otherwise each + # tags.set() / inherited_tags.set() fires m2m_changed, re-enters the whole expensive + # chain per location, and defeats the point of the bulk path. + from dojo.models import _manage_inherited_tags # noqa: PLC0415 circular import + + signals.m2m_changed.disconnect(make_inherited_tags_sticky, sender=tags_through) + signals.m2m_changed.disconnect(make_inherited_tags_sticky, sender=inherited_through) + try: + for location in locations: + target_tag_names: set[str] = set() + for pid in product_ids_by_location[location.id]: + if pid in contributing_product_ids: + target_tag_names |= tags_by_product[pid] + + existing = existing_inherited_by_location[location.id] + if target_tag_names == existing: + continue # Already in sync — skip the expensive mutation path entirely + + _manage_inherited_tags( + location, + list(target_tag_names), + potentially_existing_tags=existing_tags_by_location[location.id], + ) + finally: + signals.m2m_changed.connect(make_inherited_tags_sticky, sender=tags_through) + signals.m2m_changed.connect(make_inherited_tags_sticky, sender=inherited_through) + + def inherit_linked_instance_tags(instance: LocationFindingReference | LocationProductReference): inherit_instance_tags(instance.location) diff --git a/unittests/test_bulk_locations.py b/unittests/test_bulk_locations.py index f008bc43b3e..bbca75b396f 100644 --- a/unittests/test_bulk_locations.py +++ b/unittests/test_bulk_locations.py @@ -313,6 +313,73 @@ def test_locations_inherit_product_tags(self): inherited = sorted(t.name for t in loc.inherited_tags.all()) self.assertEqual(inherited, ["inherit", "tags", "these"]) + def test_bulk_inherit_is_no_op_when_already_in_sync(self): + """Calling persist() again with the same data should not re-inherit (no mutation queries).""" + finding = _make_finding() + product = finding.test.engagement.product + product.enable_product_tag_inheritance = True + product.save() + 
product.tags.add("a", "b") + + loc_data = [LocationData(type="url", data={"url": "https://oss-nosync.example.com"})] + # First import — mutations expected + mgr1 = LocationManager(product) + mgr1.record_locations_for_finding(finding, loc_data) + mgr1.persist() + + # Second import — tags already inherited, should be a fast no-op + mgr2 = LocationManager(product) + mgr2.record_locations_for_finding(finding, loc_data) + with CaptureQueriesContext(connection) as ctx: + mgr2.persist() + + # Verify no INSERT or UPDATE queries fired in the inheritance path + mutation_queries = [q for q in ctx.captured_queries if q["sql"].startswith(("INSERT", "UPDATE"))] + # There may still be refs check INSERTs if we're creating the LocationFindingReference again, + # but inherited_tags mutation should be absent. + for q in mutation_queries: + self.assertNotIn("inherited_tags", q["sql"].lower(), f"Unexpected inherited_tags mutation: {q['sql']}") + + def test_bulk_inherit_already_synced_is_constant_time(self): + """ + The main win from the bulk variant is skipping the per-instance mutation path when + locations are already in sync with their product's tags. This test verifies that + repeated persist() calls don't re-do the expensive tagulous work. 
+ """ + finding = _make_finding() + product = finding.test.engagement.product + product.enable_product_tag_inheritance = True + product.save() + product.tags.add("p-tag-1", "p-tag-2") + + loc_data = [ + LocationData(type="url", data={"url": f"https://oss-sync-{i}.example.com"}) + for i in range(10) + ] + # First import to populate inherited_tags + mgr1 = LocationManager(product) + mgr1.record_locations_for_finding(finding, loc_data) + mgr1.persist() + + # Second import — same data, already in sync; should do zero mutation queries + mgr2 = LocationManager(product) + mgr2.record_locations_for_finding(finding, loc_data) + with CaptureQueriesContext(connection) as ctx: + mgr2.persist() + + # No UPDATEs or INSERTs on inherited_tags / tags through tables should fire + tag_through = Location.tags.through._meta.db_table + inherited_through = Location.inherited_tags.through._meta.db_table + for q in ctx.captured_queries: + sql = q["sql"].lower() + if sql.startswith(("insert", "update", "delete")): + self.assertNotIn( + tag_through.lower(), sql, f"Unexpected tags mutation: {q['sql']}", + ) + self.assertNotIn( + inherited_through.lower(), sql, f"Unexpected inherited_tags mutation: {q['sql']}", + ) + # --------------------------------------------------------------------------- # Status update query efficiency From f8d4bc1cf1dc399241ff6c0c63937d33e411fd2b Mon Sep 17 00:00:00 2001 From: dogboat Date: Wed, 15 Apr 2026 20:32:31 -0400 Subject: [PATCH 15/47] wip --- dojo/tags_signals.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/dojo/tags_signals.py b/dojo/tags_signals.py index b4e0bd18756..76e7314cbec 100644 --- a/dojo/tags_signals.py +++ b/dojo/tags_signals.py @@ -150,8 +150,11 @@ def bulk_inherit_location_tags(locations, *, known_product=None): # chain per location, and defeats the point of the bulk path. 
from dojo.models import _manage_inherited_tags # noqa: PLC0415 circular import - signals.m2m_changed.disconnect(make_inherited_tags_sticky, sender=tags_through) - signals.m2m_changed.disconnect(make_inherited_tags_sticky, sender=inherited_through) + # Only disconnect/reconnect for senders where the signal is actually registered + # (tags.through). inherited_tags.through is not a registered sender — attempting + # to connect it after disconnect() would incorrectly add a new registration, + # causing recursion on subsequent calls. + disconnected = signals.m2m_changed.disconnect(make_inherited_tags_sticky, sender=tags_through) try: for location in locations: target_tag_names: set[str] = set() @@ -169,8 +172,8 @@ def bulk_inherit_location_tags(locations, *, known_product=None): potentially_existing_tags=existing_tags_by_location[location.id], ) finally: - signals.m2m_changed.connect(make_inherited_tags_sticky, sender=tags_through) - signals.m2m_changed.connect(make_inherited_tags_sticky, sender=inherited_through) + if disconnected: + signals.m2m_changed.connect(make_inherited_tags_sticky, sender=tags_through) def inherit_linked_instance_tags(instance: LocationFindingReference | LocationProductReference): From 22bba912bf7c73f418cc9f00f0bd3da58738800b Mon Sep 17 00:00:00 2001 From: dogboat Date: Thu, 16 Apr 2026 07:55:35 -0400 Subject: [PATCH 16/47] perf test updates --- unittests/test_importers_performance.py | 32 ++++++++++++------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/unittests/test_importers_performance.py b/unittests/test_importers_performance.py index 7ba2e04ee86..ac77faed610 100644 --- a/unittests/test_importers_performance.py +++ b/unittests/test_importers_performance.py @@ -569,13 +569,13 @@ def test_import_reimport_reimport_performance_pghistory_async(self): configure_pghistory_triggers() self._import_reimport_performance( - expected_num_queries1=165, + expected_num_queries1=144, expected_num_async_tasks1=1, - 
expected_num_queries2=149, + expected_num_queries2=119, expected_num_async_tasks2=1, - expected_num_queries3=60, + expected_num_queries3=32, expected_num_async_tasks3=1, - expected_num_queries4=100, + expected_num_queries4=96, expected_num_async_tasks4=0, ) @@ -593,13 +593,13 @@ def test_import_reimport_reimport_performance_pghistory_no_async(self): testuser.usercontactinfo.save() self._import_reimport_performance( - expected_num_queries1=174, + expected_num_queries1=153, expected_num_async_tasks1=1, - expected_num_queries2=158, + expected_num_queries2=128, expected_num_async_tasks2=1, - expected_num_queries3=69, + expected_num_queries3=41, expected_num_async_tasks3=1, - expected_num_queries4=100, + expected_num_queries4=96, expected_num_async_tasks4=0, ) @@ -618,13 +618,13 @@ def test_import_reimport_reimport_performance_pghistory_no_async_with_product_gr self.system_settings(enable_product_grade=True) self._import_reimport_performance( - expected_num_queries1=184, + expected_num_queries1=163, expected_num_async_tasks1=3, - expected_num_queries2=168, + expected_num_queries2=138, expected_num_async_tasks2=3, - expected_num_queries3=73, + expected_num_queries3=45, expected_num_async_tasks3=3, - expected_num_queries4=110, + expected_num_queries4=106, expected_num_async_tasks4=2, ) @@ -718,9 +718,9 @@ def test_deduplication_performance_pghistory_async(self): self.system_settings(enable_deduplication=True) self._deduplication_performance( - expected_num_queries1=101, + expected_num_queries1=79, expected_num_async_tasks1=1, - expected_num_queries2=92, + expected_num_queries2=70, expected_num_async_tasks2=1, check_duplicates=False, # Async mode - deduplication happens later ) @@ -738,8 +738,8 @@ def test_deduplication_performance_pghistory_no_async(self): testuser.usercontactinfo.save() self._deduplication_performance( - expected_num_queries1=110, + expected_num_queries1=88, expected_num_async_tasks1=1, - expected_num_queries2=208, + expected_num_queries2=186, 
expected_num_async_tasks2=1, ) From 933c2598ebd5ab62801c80af2224d81de1fc8ba1 Mon Sep 17 00:00:00 2001 From: dogboat Date: Thu, 16 Apr 2026 11:22:10 -0400 Subject: [PATCH 17/47] wip rename --- dojo/importers/base_importer.py | 20 ++++++++++---------- dojo/importers/default_importer.py | 10 +++++----- dojo/importers/default_reimporter.py | 17 ++++++++--------- dojo/importers/endpoint_manager.py | 20 ++++++++++---------- dojo/importers/location_manager.py | 20 ++++++++++---------- 5 files changed, 43 insertions(+), 44 deletions(-) diff --git a/dojo/importers/base_importer.py b/dojo/importers/base_importer.py index 8f504d862df..24eb76e15dc 100644 --- a/dojo/importers/base_importer.py +++ b/dojo/importers/base_importer.py @@ -389,18 +389,18 @@ def apply_import_tags( # Add any tags to any locations/endpoints of the findings imported if necessary if self.apply_tags_to_endpoints and self.tags: - items_qs = self.item_manager.get_items_for_tagging(findings_to_tag) + locations_qs = self.location_manager.get_locations_for_tagging(findings_to_tag) try: bulk_add_tags_to_instances( tag_or_tags=self.tags, - instances=items_qs, + instances=locations_qs, tag_field_name="tags", ) except IntegrityError: for finding in findings_to_tag: - for item in self.item_manager.get_item_tag_fallback(finding): + for location in self.location_manager.get_location_tag_fallback(finding): for tag in self.tags: - self.add_tags_safe(item, tag) + self.add_tags_safe(location, tag) def update_import_history( self, @@ -448,7 +448,7 @@ def update_import_history( import_settings["group_by"] = self.group_by import_settings["create_finding_groups_for_all_findings"] = self.create_finding_groups_for_all_findings if len(self.endpoints_to_add) > 0: - import_settings.update(self.item_manager.serialize_extra_items(self.endpoints_to_add)) + import_settings.update(self.location_manager.serialize_extra_locations(self.endpoints_to_add)) # Create the test import object test_import = Test_Import.objects.create( 
test=self.test, @@ -767,16 +767,16 @@ def process_request_response_pairs( burp_rr.clean() burp_rr.save() - def process_items( + def process_locations( self, finding: Finding, - extra_items_to_add: list | None = None, + extra_locations_to_add: list | None = None, ) -> None: """ Record locations/endpoints from the finding + any form-added extras. - Flushed to DB by item_manager.persist(). + Flushed to DB by location_manager.persist(). """ - self.item_manager.record_for_finding(finding, extra_items_to_add) + self.location_manager.record_for_finding(finding, extra_locations_to_add) def sanitize_vulnerability_ids(self, finding) -> None: """Remove undisired vulnerability id values""" @@ -869,7 +869,7 @@ def mitigate_finding( # Remove risk acceptance if present (vulnerability is now fixed) # risk_unaccept will check if finding.risk_accepted is True before proceeding ra_helper.risk_unaccept(self.user, finding, perform_save=False, post_comments=False) - self.item_manager.record_mitigations_for_finding(finding, self.user) + self.location_manager.record_mitigations_for_finding(finding, self.user) # to avoid pushing a finding group multiple times, we push those outside of the loop if finding_groups_enabled and finding.finding_group: # don't try to dedupe findings that we are closing diff --git a/dojo/importers/default_importer.py b/dojo/importers/default_importer.py index 8a7e3e0344a..e715f712c0d 100644 --- a/dojo/importers/default_importer.py +++ b/dojo/importers/default_importer.py @@ -60,9 +60,9 @@ def __init__(self, *args, **kwargs): **kwargs, ) if settings.V3_FEATURE_LOCATIONS: - self.item_manager = LocationManager(self.engagement.product) + self.location_manager = LocationManager(self.engagement.product) else: - self.item_manager = EndpointManager(self.engagement.product) + self.location_manager = EndpointManager(self.engagement.product) def create_test( self, @@ -243,7 +243,7 @@ def process_findings( ) # Process any request/response pairs 
self.process_request_response_pairs(finding) - self.process_items(finding, self.endpoints_to_add) + self.process_locations(finding, self.endpoints_to_add) # Parsers must use unsaved_tags to store tags, so we can clean them. # Accumulate for bulk application after the loop (O(unique_tags) instead of O(N·T)). cleaned_tags = clean_tags(finding.unsaved_tags) @@ -266,7 +266,7 @@ def process_findings( # If batch is full or we're at the end, persist locations/endpoints and dispatch if len(batch_finding_ids) >= batch_max_size or is_final_finding: - self.item_manager.persist(user=self.user) + self.location_manager.persist(user=self.user) # Apply parser-supplied tags for this batch before post-processing starts, # so rules/deduplication tasks see the tags already on the findings. bulk_apply_parser_tags(findings_with_parser_tags) @@ -399,7 +399,7 @@ def close_old_findings( product_grading_option=False, ) # Persist any accumulated location/endpoint status changes - self.item_manager.persist(user=self.user) + self.location_manager.persist(user=self.user) # push finding groups to jira since we only only want to push whole groups # We dont check if the finding jira sync is applicable quite yet until we can get in the loop # but this is a way to at least make it that far diff --git a/dojo/importers/default_reimporter.py b/dojo/importers/default_reimporter.py index 187596985b8..30213992e04 100644 --- a/dojo/importers/default_reimporter.py +++ b/dojo/importers/default_reimporter.py @@ -83,9 +83,9 @@ def __init__(self, *args, **kwargs): **kwargs, ) if settings.V3_FEATURE_LOCATIONS: - self.item_manager = LocationManager(self.test.engagement.product) + self.location_manager = LocationManager(self.test.engagement.product) else: - self.item_manager = EndpointManager(self.test.engagement.product) + self.location_manager = EndpointManager(self.test.engagement.product) def process_scan( self, @@ -340,7 +340,7 @@ def process_findings( # Set the service supplied at import time if self.service 
is not None: unsaved_finding.service = self.service - self.item_manager.clean_unsaved(unsaved_finding) + self.location_manager.clean_unsaved(unsaved_finding) # Calculate the hash code to be used to identify duplicates unsaved_finding.hash_code = self.calculate_unsaved_finding_hash_code(unsaved_finding) deduplicationLogger.debug(f"unsaved finding's hash_code: {unsaved_finding.hash_code}") @@ -380,7 +380,7 @@ def process_findings( "Re-import found an existing dynamic finding for this new " "finding. Checking the status of locations/endpoints", ) - self.item_manager.update_status( + self.location_manager.update_status( existing_finding, unsaved_finding, self.user, @@ -425,12 +425,11 @@ def process_findings( # - Deduplication batches: optimize bulk operations (larger batches = fewer queries) # They don't need to be aligned since they optimize different operations. if len(batch_finding_ids) >= dedupe_batch_max_size or is_final: - self.item_manager.persist(user=self.user) + self.location_manager.persist(user=self.user) # Apply parser-supplied tags for this batch before post-processing starts, # so rules/deduplication tasks see the tags already on the findings. 
bulk_apply_parser_tags(findings_with_parser_tags) findings_with_parser_tags.clear() - finding_ids_batch = list(batch_finding_ids) batch_finding_ids.clear() dojo_dispatch_task( @@ -538,7 +537,7 @@ def close_old_findings( ) mitigated_findings.append(finding) # Persist any accumulated location/endpoint status changes - self.item_manager.persist(user=self.user) + self.location_manager.persist(user=self.user) # push finding groups to jira since we only only want to push whole groups # We dont check if the finding jira sync is applicable quite yet until we can get in the loop # but this is a way to at least make it that far @@ -799,7 +798,7 @@ def process_matched_mitigated_finding( note = Notes(entry=f"Re-activated by {self.scan_type} re-upload.", author=self.user) note.save() - self.item_manager.record_reactivations_for_finding(existing_finding) + self.location_manager.record_reactivations_for_finding(existing_finding) existing_finding.notes.add(note) self.reactivated_items.append(existing_finding) # The new finding is active while the existing on is mitigated. The existing finding needs to @@ -966,7 +965,7 @@ def finding_post_processing( # Copy unsaved items from the parser output onto the saved finding so record_for_finding can read them finding.unsaved_locations = getattr(finding_from_report, "unsaved_locations", []) finding.unsaved_endpoints = getattr(finding_from_report, "unsaved_endpoints", []) - self.item_manager.record_for_finding(finding, self.endpoints_to_add or None) + self.location_manager.record_for_finding(finding, self.endpoints_to_add or None) # For matched/existing findings, do not update tags from the report, # consistent with how other fields are handled on reimport. 
if not is_matched_finding: diff --git a/dojo/importers/endpoint_manager.py b/dojo/importers/endpoint_manager.py index e6657a42ba5..ecff1821aab 100644 --- a/dojo/importers/endpoint_manager.py +++ b/dojo/importers/endpoint_manager.py @@ -309,13 +309,13 @@ def clean_unsaved(self, finding: Finding) -> None: """Clean the unsaved endpoints on this finding.""" self.clean_unsaved_endpoints(finding.unsaved_endpoints) - def record_for_finding(self, finding: Finding, extra_items: list[Endpoint] | None = None) -> None: + def record_for_finding(self, finding: Finding, extra_locations: list[Endpoint] | None = None) -> None: """Record endpoints from the finding + any form-added extras for later batch creation.""" for endpoint in finding.unsaved_endpoints: key = self.record_endpoint(endpoint) self.record_status_for_create(finding, key) - if extra_items: - for endpoint in extra_items: + if extra_locations: + for endpoint in extra_locations: key = self.record_endpoint(endpoint) self.record_status_for_create(finding, key) @@ -331,14 +331,14 @@ def record_mitigations_for_finding(self, finding: Finding, user: Dojo_User | Non """Record endpoint statuses on this finding for mitigation.""" self.record_statuses_to_mitigate(finding.status_finding.all()) - def get_items_for_tagging(self, findings: list[Finding]): - """Return queryset of items to apply tags to.""" + def get_locations_for_tagging(self, findings: list[Finding]): + """Return queryset of locations to apply tags to.""" return Endpoint.objects.filter(finding__in=findings).distinct() - def get_item_tag_fallback(self, finding: Finding): - """Return iterable of taggable items for per-instance fallback.""" + def get_location_tag_fallback(self, finding: Finding): + """Return iterable of taggable locations for per-instance fallback.""" return finding.endpoints.all() - def serialize_extra_items(self, items: list) -> dict: - """Serialize extra items for import history.""" - return {"endpoints": [str(ep) for ep in items]} if items else {} + 
def serialize_extra_locations(self, locations: list) -> dict: + """Serialize extra locations for import history.""" + return {"endpoints": [str(ep) for ep in locations]} if locations else {} diff --git a/dojo/importers/location_manager.py b/dojo/importers/location_manager.py index 1b62096ffa3..9ec8459382a 100644 --- a/dojo/importers/location_manager.py +++ b/dojo/importers/location_manager.py @@ -89,11 +89,11 @@ def clean_unsaved(self, finding: Finding) -> None: """Clean the unsaved locations on this finding.""" type(self).clean_unsaved_locations(finding.unsaved_locations) - def record_for_finding(self, finding: Finding, extra_items: list[UnsavedLocation] | None = None) -> None: + def record_for_finding(self, finding: Finding, extra_locations: list[UnsavedLocation] | None = None) -> None: """Record locations from the finding + any form-added extras for later batch creation.""" self.record_locations_for_finding(finding, finding.unsaved_locations) - if extra_items: - self.record_locations_for_finding(finding, extra_items) + if extra_locations: + self.record_locations_for_finding(finding, extra_locations) def update_status(self, existing_finding: Finding, new_finding: Finding, user: Dojo_User) -> None: """Accumulate status changes (mitigate/reactivate) based on old vs new finding.""" @@ -107,18 +107,18 @@ def record_mitigations_for_finding(self, finding: Finding, user: Dojo_User | Non """Defer mitigation to persist(). 
No DB access at record time.""" self._finding_ids_to_fully_mitigate.append((finding.id, user)) - def get_items_for_tagging(self, findings: list[Finding]): - """Return queryset of items to apply tags to.""" + def get_locations_for_tagging(self, findings: list[Finding]): + """Return queryset of locations to apply tags to.""" from dojo.location.models import Location # noqa: PLC0415 return Location.objects.filter(findings__finding__in=findings).distinct() - def get_item_tag_fallback(self, finding: Finding): - """Return iterable of taggable items for per-instance fallback.""" + def get_location_tag_fallback(self, finding: Finding): + """Return iterable of taggable locations for per-instance fallback.""" return [ref.location for ref in finding.locations.all()] - def serialize_extra_items(self, items: list) -> dict: - """Serialize extra items for import history.""" - return {"locations": [str(loc) for loc in items]} if items else {} + def serialize_extra_locations(self, locations: list) -> dict: + """Serialize extra locations for import history.""" + return {"locations": [str(loc) for loc in locations]} if locations else {} # ------------------------------------------------------------------ # Persist — flush all accumulated operations to DB From 3920066314fd8caa1226086a4115293a41cc3a65 Mon Sep 17 00:00:00 2001 From: dogboat Date: Thu, 16 Apr 2026 12:26:01 -0400 Subject: [PATCH 18/47] wip --- dojo/importers/base_location_manager.py | 130 ++++++++++++++ dojo/importers/default_importer.py | 8 +- dojo/importers/default_reimporter.py | 8 +- dojo/importers/endpoint_manager.py | 3 +- dojo/importers/location_manager.py | 218 ++++++++++-------------- unittests/test_bulk_locations.py | 11 -- 6 files changed, 225 insertions(+), 153 deletions(-) create mode 100644 dojo/importers/base_location_manager.py diff --git a/dojo/importers/base_location_manager.py b/dojo/importers/base_location_manager.py new file mode 100644 index 00000000000..eaaeea5b6d0 --- /dev/null +++ 
b/dojo/importers/base_location_manager.py @@ -0,0 +1,130 @@ +""" +Base class and handler for location/endpoint managers in the import pipeline. + +BaseLocationManager defines the contract that both LocationManager (V3) and +EndpointManager (legacy) must implement. LocationHandler is the facade that +importers interact with — it picks the appropriate manager based on +V3_FEATURE_LOCATIONS and delegates all calls through the shared interface. + +This structure prevents drift between the two managers: adding an abstract +method to BaseLocationManager forces both to implement it, and callers can +only access methods exposed by LocationHandler. +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from dojo.models import Dojo_User, Finding, Product + + +class BaseLocationManager(ABC): + + """ + Abstract base for import-pipeline managers that handle network identifiers + (locations in V3, endpoints in legacy). + + Subclasses must implement every abstract method. The importer never calls + subclass-specific methods directly — it goes through LocationHandler. 
+ """ + + def __init__(self, product: Product) -> None: + self._product = product + + @abstractmethod + def clean_unsaved(self, finding: Finding) -> None: + """Clean the unsaved locations/endpoints on this finding.""" + + @abstractmethod + def record_for_finding(self, finding: Finding, extra_locations: list | None = None) -> None: + """Record items from the finding + any form-added extras for later batch creation.""" + + @abstractmethod + def update_status(self, existing_finding: Finding, new_finding: Finding, user: Dojo_User) -> None: + """Accumulate status changes (mitigate/reactivate) based on old vs new finding.""" + + @abstractmethod + def record_reactivations_for_finding(self, finding: Finding) -> None: + """Record items on this finding for reactivation.""" + + @abstractmethod + def record_mitigations_for_finding(self, finding: Finding, user: Dojo_User | None = None) -> None: + """Record items on this finding for mitigation.""" + + @abstractmethod + def get_locations_for_tagging(self, findings: list[Finding]): + """Return a queryset of taggable objects linked to the given findings.""" + + @abstractmethod + def get_location_tag_fallback(self, finding: Finding): + """Return an iterable of taggable objects for per-instance tag fallback.""" + + @abstractmethod + def serialize_extra_locations(self, locations: list) -> dict: + """Serialize extra locations/endpoints for import history settings.""" + + @abstractmethod + def persist(self, user: Dojo_User | None = None) -> None: + """Flush all accumulated operations to the database.""" + + +class LocationHandler: + + """ + Facade used by importers. Delegates to the appropriate BaseLocationManager + implementation based on V3_FEATURE_LOCATIONS. + + Callers only see the methods defined here — they cannot reach into the + internal manager to call implementation-specific methods. This prevents + V3-only or endpoint-only code from leaking into shared importer logic. 
+ """ + + def __init__( + self, + product: Product, + *, + v3_manager_class: type[BaseLocationManager] | None = None, + v2_manager_class: type[BaseLocationManager] | None = None, + ) -> None: + from django.conf import settings # noqa: PLC0415 + + from dojo.importers.endpoint_manager import EndpointManager # noqa: PLC0415 + from dojo.importers.location_manager import LocationManager # noqa: PLC0415 + + self._product = product + if settings.V3_FEATURE_LOCATIONS: + cls = v3_manager_class or LocationManager + else: + cls = v2_manager_class or EndpointManager + self._manager: BaseLocationManager = cls(product) + + # --- Delegates (one per BaseLocationManager method) --- + + def clean_unsaved(self, finding: Finding) -> None: + return self._manager.clean_unsaved(finding) + + def record_for_finding(self, finding: Finding, extra_locations: list | None = None) -> None: + return self._manager.record_for_finding(finding, extra_locations) + + def update_status(self, existing_finding: Finding, new_finding: Finding, user: Dojo_User) -> None: + return self._manager.update_status(existing_finding, new_finding, user) + + def record_reactivations_for_finding(self, finding: Finding) -> None: + return self._manager.record_reactivations_for_finding(finding) + + def record_mitigations_for_finding(self, finding: Finding, user: Dojo_User | None = None) -> None: + return self._manager.record_mitigations_for_finding(finding, user) + + def get_locations_for_tagging(self, findings: list[Finding]): + return self._manager.get_locations_for_tagging(findings) + + def get_location_tag_fallback(self, finding: Finding): + return self._manager.get_location_tag_fallback(finding) + + def serialize_extra_locations(self, locations: list) -> dict: + return self._manager.serialize_extra_locations(locations) + + def persist(self, user: Dojo_User | None = None) -> None: + return self._manager.persist(user) diff --git a/dojo/importers/default_importer.py b/dojo/importers/default_importer.py index 
e715f712c0d..483c6f303d8 100644 --- a/dojo/importers/default_importer.py +++ b/dojo/importers/default_importer.py @@ -9,8 +9,7 @@ from dojo.celery_dispatch import dojo_dispatch_task from dojo.finding import helper as finding_helper from dojo.importers.base_importer import BaseImporter, Parser -from dojo.importers.endpoint_manager import EndpointManager -from dojo.importers.location_manager import LocationManager +from dojo.importers.base_location_manager import LocationHandler from dojo.importers.options import ImporterOptions from dojo.jira_link.helper import is_keep_in_sync_with_jira from dojo.models import ( @@ -59,10 +58,7 @@ def __init__(self, *args, **kwargs): import_type=Test_Import.IMPORT_TYPE, **kwargs, ) - if settings.V3_FEATURE_LOCATIONS: - self.location_manager = LocationManager(self.engagement.product) - else: - self.location_manager = EndpointManager(self.engagement.product) + self.location_manager = LocationHandler(self.engagement.product) def create_test( self, diff --git a/dojo/importers/default_reimporter.py b/dojo/importers/default_reimporter.py index 30213992e04..4cbcd0268ed 100644 --- a/dojo/importers/default_reimporter.py +++ b/dojo/importers/default_reimporter.py @@ -15,8 +15,7 @@ find_candidates_for_reimport_legacy, ) from dojo.importers.base_importer import BaseImporter, Parser -from dojo.importers.endpoint_manager import EndpointManager -from dojo.importers.location_manager import LocationManager +from dojo.importers.base_location_manager import LocationHandler from dojo.importers.options import ImporterOptions from dojo.jira_link.helper import is_keep_in_sync_with_jira from dojo.models import ( @@ -82,10 +81,7 @@ def __init__(self, *args, **kwargs): import_type=Test_Import.REIMPORT_TYPE, **kwargs, ) - if settings.V3_FEATURE_LOCATIONS: - self.location_manager = LocationManager(self.test.engagement.product) - else: - self.location_manager = EndpointManager(self.test.engagement.product) + self.location_manager = 
LocationHandler(self.test.engagement.product) def process_scan( self, diff --git a/dojo/importers/endpoint_manager.py b/dojo/importers/endpoint_manager.py index ecff1821aab..976e5b24dfc 100644 --- a/dojo/importers/endpoint_manager.py +++ b/dojo/importers/endpoint_manager.py @@ -6,6 +6,7 @@ from django.utils import timezone from hyperlink._url import SCHEME_PORT_MAP # noqa: PLC2701 +from dojo.importers.base_location_manager import BaseLocationManager from dojo.models import ( Dojo_User, Endpoint, @@ -30,7 +31,7 @@ class EndpointUniqueKey(NamedTuple): # TODO: Delete this after the move to Locations -class EndpointManager: +class EndpointManager(BaseLocationManager): def __init__(self, product: Product) -> None: self._product = product diff --git a/dojo/importers/location_manager.py b/dojo/importers/location_manager.py index 9ec8459382a..ef926f384ef 100644 --- a/dojo/importers/location_manager.py +++ b/dojo/importers/location_manager.py @@ -8,6 +8,7 @@ from django.core.exceptions import ValidationError from django.utils import timezone +from dojo.importers.base_location_manager import BaseLocationManager from dojo.location.models import AbstractLocation, LocationFindingReference, LocationProductReference from dojo.location.status import FindingLocationStatus, ProductLocationStatus from dojo.tags_signals import bulk_inherit_location_tags @@ -26,12 +27,11 @@ UnsavedLocation = TypeVar("UnsavedLocation", LocationData, AbstractLocation) -class LocationManager: +class LocationManager(BaseLocationManager): def __init__(self, product: Product) -> None: - self._product = product + super().__init__(product) self._locations_by_finding: dict[int, tuple[Finding, list[UnsavedLocation]]] = {} - self._product_locations: list[UnsavedLocation] = [] # Status update inputs (deferred). All entries are processed in a single bulk pass by persist(). 
# (existing_finding, new_finding, user): classified partial mitigate/reactivate self._status_updates: list[tuple[Finding, Finding, Dojo_User]] = [] @@ -126,141 +126,101 @@ def serialize_extra_locations(self, locations: list) -> dict: def persist(self, user: Dojo_User | None = None) -> None: """Flush all accumulated location operations to the database.""" - # Step 1: Collect all locations across all findings, bulk get/create, bulk create refs - if self._locations_by_finding: - all_locations: list[AbstractLocation] = [] - finding_ranges: list[tuple[Finding, int, int]] = [] - - for finding, locations in self._locations_by_finding.values(): - cleaned = type(self).clean_unsaved_locations(locations) - start = len(all_locations) - all_locations.extend(cleaned) - end = len(all_locations) - if start < end: - finding_ranges.append((finding, start, end)) - - if all_locations: - saved = type(self)._bulk_get_or_create_locations(all_locations) - - # Build all refs across all findings in one pass - all_finding_refs = [] - all_product_refs = [] - # Track locations that got new refs — only those need tag inheritance - # (mirrors original post_save signal behavior on LocationFindingReference/LocationProductReference) - locations_needing_inherit: dict[int, AbstractLocation] = {} - - # Pre-fetch existing product refs for this product across all locations (one query) - all_location_ids = [loc.location_id for loc in saved] - existing_product_refs: set[int] = set( - LocationProductReference.objects.filter( - location_id__in=all_location_ids, - product=self._product, - ).values_list("location_id", flat=True), - ) + self._persist_finding_locations() + self._flush_status_updates() - # Pre-fetch existing finding refs across ALL findings in one query (avoids N+1) - all_finding_ids = [finding.id for finding, _, _ in finding_ranges] - existing_finding_ref_keys: set[tuple[int, int]] = set( - LocationFindingReference.objects.filter( - location_id__in=all_location_ids, - 
finding_id__in=all_finding_ids, - ).values_list("finding_id", "location_id"), - ) + def _persist_finding_locations(self) -> None: + """Bulk get/create locations and their finding+product refs.""" + if not self._locations_by_finding: + return - for finding, start, end in finding_ranges: - finding_locations = saved[start:end] - - for location in finding_locations: - assoc = location.get_association_data() - finding_ref_key = (finding.id, location.location_id) - - if finding_ref_key not in existing_finding_ref_keys: - all_finding_refs.append(LocationFindingReference( - location_id=location.location_id, - finding=finding, - status=FindingLocationStatus.Active, - relationship=assoc.relationship_type, - relationship_data=assoc.relationship_data, - )) - existing_finding_ref_keys.add(finding_ref_key) - locations_needing_inherit[location.location_id] = location - - if location.location_id not in existing_product_refs: - all_product_refs.append(LocationProductReference( - location_id=location.location_id, - product=self._product, - status=ProductLocationStatus.Active, - relationship=assoc.relationship_type, - relationship_data=assoc.relationship_data, - )) - existing_product_refs.add(location.location_id) - locations_needing_inherit[location.location_id] = location - - if all_finding_refs: - LocationFindingReference.objects.bulk_create( - all_finding_refs, batch_size=1000, ignore_conflicts=True, - ) - if all_product_refs: - LocationProductReference.objects.bulk_create( - all_product_refs, batch_size=1000, ignore_conflicts=True, - ) - - # bulk_create bypasses post_save signals; trigger tag inheritance only on locations - # that got new refs (matches original signal-based behavior). Short-circuit if the - # product has no tag inheritance enabled, and use the bulk variant otherwise to - # avoid O(N) expensive JOINs via Location.all_related_products(). 
- if self._should_inherit_product_tags() and locations_needing_inherit: - bulk_inherit_location_tags( - (loc.location for loc in locations_needing_inherit.values()), - known_product=self._product, - ) - - self._locations_by_finding.clear() - - # Step 1b: Product-level locations (not tied to a finding) - if self._product_locations: - cleaned = type(self).clean_unsaved_locations(self._product_locations) - if cleaned: - saved = type(self)._bulk_get_or_create_locations(cleaned) - location_ids = [loc.location_id for loc in saved] - existing = set( - LocationProductReference.objects.filter( - location_id__in=location_ids, - product=self._product, - ).values_list("location_id", flat=True), - ) - new_refs = [] - # Track locations that got new refs — only those need tag inheritance - locations_needing_inherit: dict[int, AbstractLocation] = {} - for location in saved: - if location.location_id not in existing: - assoc = location.get_association_data() - new_refs.append(LocationProductReference( + all_locations: list[AbstractLocation] = [] + finding_ranges: list[tuple[Finding, int, int]] = [] + + for finding, locations in self._locations_by_finding.values(): + cleaned = type(self).clean_unsaved_locations(locations) + start = len(all_locations) + all_locations.extend(cleaned) + end = len(all_locations) + if start < end: + finding_ranges.append((finding, start, end)) + + if all_locations: + saved = type(self)._bulk_get_or_create_locations(all_locations) + + # Build all refs across all findings in one pass + all_finding_refs = [] + all_product_refs = [] + # Track locations that got new refs — only those need tag inheritance + locations_needing_inherit: dict[int, AbstractLocation] = {} + + # Pre-fetch existing product refs for this product across all locations (one query) + all_location_ids = [loc.location_id for loc in saved] + existing_product_refs: set[int] = set( + LocationProductReference.objects.filter( + location_id__in=all_location_ids, + product=self._product, + 
).values_list("location_id", flat=True), + ) + + # Pre-fetch existing finding refs across ALL findings in one query (avoids N+1) + all_finding_ids = [finding.id for finding, _, _ in finding_ranges] + existing_finding_ref_keys: set[tuple[int, int]] = set( + LocationFindingReference.objects.filter( + location_id__in=all_location_ids, + finding_id__in=all_finding_ids, + ).values_list("finding_id", "location_id"), + ) + + for finding, start, end in finding_ranges: + finding_locations = saved[start:end] + + for location in finding_locations: + assoc = location.get_association_data() + finding_ref_key = (finding.id, location.location_id) + + if finding_ref_key not in existing_finding_ref_keys: + all_finding_refs.append(LocationFindingReference( + location_id=location.location_id, + finding=finding, + status=FindingLocationStatus.Active, + relationship=assoc.relationship_type, + relationship_data=assoc.relationship_data, + )) + existing_finding_ref_keys.add(finding_ref_key) + locations_needing_inherit[location.location_id] = location + + if location.location_id not in existing_product_refs: + all_product_refs.append(LocationProductReference( location_id=location.location_id, product=self._product, status=ProductLocationStatus.Active, relationship=assoc.relationship_type, relationship_data=assoc.relationship_data, )) - existing.add(location.location_id) + existing_product_refs.add(location.location_id) locations_needing_inherit[location.location_id] = location - if new_refs: - LocationProductReference.objects.bulk_create( - new_refs, batch_size=1000, ignore_conflicts=True, - ) - - # bulk_create bypasses post_save signals; trigger tag inheritance only on - # locations that got new product refs (short-circuited if the product has no - # tag inheritance enabled — see _should_inherit_product_tags()). 
- if self._should_inherit_product_tags() and locations_needing_inherit: - bulk_inherit_location_tags( - (loc.location for loc in locations_needing_inherit.values()), - known_product=self._product, - ) - self._product_locations.clear() - - # Steps 2 & 3: Bulk status updates — classify refs, then execute in minimal queries - self._flush_status_updates() + + if all_finding_refs: + LocationFindingReference.objects.bulk_create( + all_finding_refs, batch_size=1000, ignore_conflicts=True, + ) + if all_product_refs: + LocationProductReference.objects.bulk_create( + all_product_refs, batch_size=1000, ignore_conflicts=True, + ) + + # bulk_create bypasses post_save signals; trigger tag inheritance only on locations + # that got new refs (matches original signal-based behavior). Short-circuit if the + # product has no tag inheritance enabled, and use the bulk variant otherwise to + # avoid O(N) expensive JOINs via Location.all_related_products(). + if self._should_inherit_product_tags() and locations_needing_inherit: + bulk_inherit_location_tags( + (loc.location for loc in locations_needing_inherit.values()), + known_product=self._product, + ) + + self._locations_by_finding.clear() def _flush_status_updates(self) -> None: """ diff --git a/unittests/test_bulk_locations.py b/unittests/test_bulk_locations.py index bbca75b396f..30457795366 100644 --- a/unittests/test_bulk_locations.py +++ b/unittests/test_bulk_locations.py @@ -194,17 +194,6 @@ def test_uses_association_data(self): self.assertEqual(ref.relationship, "owned_by") self.assertEqual(ref.relationship_data, {"file_path": "/app/main.py"}) - def test_product_only_locations(self): - pt, _ = Product_Type.objects.get_or_create(name="Refs Test Type") - product = Product.objects.create(name="Refs Product Only", description="test", prod_type=pt) - - mgr = LocationManager(product) - mgr._product_locations.extend([_make_url("oss-product-only.example.com")]) - mgr.persist() - - 
self.assertTrue(LocationProductReference.objects.filter(product=product).exists()) - self.assertFalse(LocationFindingReference.objects.exists()) - # --------------------------------------------------------------------------- # End-to-end: record + persist From e5d321398f75065fc852f73e9e21c5755f615853 Mon Sep 17 00:00:00 2001 From: dogboat Date: Thu, 16 Apr 2026 12:38:28 -0400 Subject: [PATCH 19/47] cleanup --- dojo/importers/location_manager.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/dojo/importers/location_manager.py b/dojo/importers/location_manager.py index ef926f384ef..f5a6ac4f432 100644 --- a/dojo/importers/location_manager.py +++ b/dojo/importers/location_manager.py @@ -87,7 +87,7 @@ def update_location_status( def clean_unsaved(self, finding: Finding) -> None: """Clean the unsaved locations on this finding.""" - type(self).clean_unsaved_locations(finding.unsaved_locations) + self.clean_unsaved_locations(finding.unsaved_locations) def record_for_finding(self, finding: Finding, extra_locations: list[UnsavedLocation] | None = None) -> None: """Record locations from the finding + any form-added extras for later batch creation.""" @@ -138,7 +138,7 @@ def _persist_finding_locations(self) -> None: finding_ranges: list[tuple[Finding, int, int]] = [] for finding, locations in self._locations_by_finding.values(): - cleaned = type(self).clean_unsaved_locations(locations) + cleaned = self.clean_unsaved_locations(locations) start = len(all_locations) all_locations.extend(cleaned) end = len(all_locations) @@ -146,7 +146,7 @@ def _persist_finding_locations(self) -> None: finding_ranges.append((finding, start, end)) if all_locations: - saved = type(self)._bulk_get_or_create_locations(all_locations) + saved = self._bulk_get_or_create_locations(all_locations) # Build all refs across all findings in one pass all_finding_refs = [] @@ -266,7 +266,7 @@ def _flush_status_updates(self) -> None: ref_ids_to_mitigate_by_user.setdefault(user, 
set()).update(r.id for r in finding_refs) else: new_loc_values = { - str(loc) for loc in type(self).clean_unsaved_locations(new_finding.unsaved_locations) + str(loc) for loc in self.clean_unsaved_locations(new_finding.unsaved_locations) } for ref in finding_refs: if ref.location.location_value in new_loc_values: From ce8f9eb960bb86777dd6c31bf83ca4221794f986 Mon Sep 17 00:00:00 2001 From: dogboat Date: Thu, 16 Apr 2026 12:46:20 -0400 Subject: [PATCH 20/47] comments --- dojo/importers/location_manager.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/dojo/importers/location_manager.py b/dojo/importers/location_manager.py index f5a6ac4f432..a4dcfe85e7a 100644 --- a/dojo/importers/location_manager.py +++ b/dojo/importers/location_manager.py @@ -378,21 +378,36 @@ def clean_unsaved_locations( @classmethod def _bulk_get_or_create_locations(cls, locations: list[AbstractLocation]) -> list[AbstractLocation]: - """Bulk get-or-create a (possibly heterogeneous) list of AbstractLocations.""" + """ + Bulk get-or-create a (possibly heterogeneous) list of AbstractLocations. + + The input list may contain a mix of AbstractLocation instances. This method + groups them by concrete type, delegates each group to that type's bulk_get_or_create, + then reassembles results in the original input order. + """ if not locations: return [] + # Keying function: group by the (Python) identity of the concrete class (e.g., URL vs Dependency). + # Using id() because class objects aren't sortable. 
def type_id(x: tuple[int, AbstractLocation]) -> int: return id(type(x[1])) saved = [] + # Sort by type, tracking the original index via enumerate so we can restore order later locations_with_idx = sorted(enumerate(locations), key=type_id) + # Now group by type locations_by_type = groupby(locations_with_idx, key=type_id) for _, grouped_locations_with_idx in locations_by_type: + # Split into parallel lists: original indices and the homogeneous location objects indices, grouped_locations = zip(*grouped_locations_with_idx, strict=True) + # Determine the concrete AbstractLocation subclass (URL, Dependency, etc.) loc_cls = type(grouped_locations[0]) + # Delegate to the per-type bulk_get_or_create on AbstractLocation saved_locations = loc_cls.bulk_get_or_create(grouped_locations) + # Pair each result back with its original index saved.extend((idx, saved_loc) for idx, saved_loc in zip(indices, saved_locations, strict=True)) + # Restore the original input ordering saved.sort(key=itemgetter(0)) return [loc for _, loc in saved] From 8647d36cc9594f8f082a125aae5c10d7b057e2e8 Mon Sep 17 00:00:00 2001 From: dogboat Date: Thu, 16 Apr 2026 12:59:40 -0400 Subject: [PATCH 21/47] rename --- dojo/importers/base_importer.py | 12 ++++++------ dojo/importers/default_importer.py | 6 +++--- dojo/importers/default_reimporter.py | 14 +++++++------- 3 files changed, 16 insertions(+), 16 deletions(-) diff --git a/dojo/importers/base_importer.py b/dojo/importers/base_importer.py index 24eb76e15dc..ff04a5698de 100644 --- a/dojo/importers/base_importer.py +++ b/dojo/importers/base_importer.py @@ -389,7 +389,7 @@ def apply_import_tags( # Add any tags to any locations/endpoints of the findings imported if necessary if self.apply_tags_to_endpoints and self.tags: - locations_qs = self.location_manager.get_locations_for_tagging(findings_to_tag) + locations_qs = self.location_handler.get_locations_for_tagging(findings_to_tag) try: bulk_add_tags_to_instances( tag_or_tags=self.tags, @@ -398,7 +398,7 @@ 
def apply_import_tags( ) except IntegrityError: for finding in findings_to_tag: - for location in self.location_manager.get_location_tag_fallback(finding): + for location in self.location_handler.get_location_tag_fallback(finding): for tag in self.tags: self.add_tags_safe(location, tag) @@ -448,7 +448,7 @@ def update_import_history( import_settings["group_by"] = self.group_by import_settings["create_finding_groups_for_all_findings"] = self.create_finding_groups_for_all_findings if len(self.endpoints_to_add) > 0: - import_settings.update(self.location_manager.serialize_extra_locations(self.endpoints_to_add)) + import_settings.update(self.location_handler.serialize_extra_locations(self.endpoints_to_add)) # Create the test import object test_import = Test_Import.objects.create( test=self.test, @@ -774,9 +774,9 @@ def process_locations( ) -> None: """ Record locations/endpoints from the finding + any form-added extras. - Flushed to DB by location_manager.persist(). + Flushed to DB by location_handler.persist(). 
""" - self.location_manager.record_for_finding(finding, extra_locations_to_add) + self.location_handler.record_for_finding(finding, extra_locations_to_add) def sanitize_vulnerability_ids(self, finding) -> None: """Remove undisired vulnerability id values""" @@ -869,7 +869,7 @@ def mitigate_finding( # Remove risk acceptance if present (vulnerability is now fixed) # risk_unaccept will check if finding.risk_accepted is True before proceeding ra_helper.risk_unaccept(self.user, finding, perform_save=False, post_comments=False) - self.location_manager.record_mitigations_for_finding(finding, self.user) + self.location_handler.record_mitigations_for_finding(finding, self.user) # to avoid pushing a finding group multiple times, we push those outside of the loop if finding_groups_enabled and finding.finding_group: # don't try to dedupe findings that we are closing diff --git a/dojo/importers/default_importer.py b/dojo/importers/default_importer.py index 483c6f303d8..b3226a19c2b 100644 --- a/dojo/importers/default_importer.py +++ b/dojo/importers/default_importer.py @@ -58,7 +58,7 @@ def __init__(self, *args, **kwargs): import_type=Test_Import.IMPORT_TYPE, **kwargs, ) - self.location_manager = LocationHandler(self.engagement.product) + self.location_handler = LocationHandler(self.engagement.product) def create_test( self, @@ -262,7 +262,7 @@ def process_findings( # If batch is full or we're at the end, persist locations/endpoints and dispatch if len(batch_finding_ids) >= batch_max_size or is_final_finding: - self.location_manager.persist(user=self.user) + self.location_handler.persist(user=self.user) # Apply parser-supplied tags for this batch before post-processing starts, # so rules/deduplication tasks see the tags already on the findings. 
bulk_apply_parser_tags(findings_with_parser_tags) @@ -395,7 +395,7 @@ def close_old_findings( product_grading_option=False, ) # Persist any accumulated location/endpoint status changes - self.location_manager.persist(user=self.user) + self.location_handler.persist(user=self.user) # push finding groups to jira since we only only want to push whole groups # We dont check if the finding jira sync is applicable quite yet until we can get in the loop # but this is a way to at least make it that far diff --git a/dojo/importers/default_reimporter.py b/dojo/importers/default_reimporter.py index 4cbcd0268ed..280e4054ea7 100644 --- a/dojo/importers/default_reimporter.py +++ b/dojo/importers/default_reimporter.py @@ -81,7 +81,7 @@ def __init__(self, *args, **kwargs): import_type=Test_Import.REIMPORT_TYPE, **kwargs, ) - self.location_manager = LocationHandler(self.test.engagement.product) + self.location_handler = LocationHandler(self.test.engagement.product) def process_scan( self, @@ -336,7 +336,7 @@ def process_findings( # Set the service supplied at import time if self.service is not None: unsaved_finding.service = self.service - self.location_manager.clean_unsaved(unsaved_finding) + self.location_handler.clean_unsaved(unsaved_finding) # Calculate the hash code to be used to identify duplicates unsaved_finding.hash_code = self.calculate_unsaved_finding_hash_code(unsaved_finding) deduplicationLogger.debug(f"unsaved finding's hash_code: {unsaved_finding.hash_code}") @@ -376,7 +376,7 @@ def process_findings( "Re-import found an existing dynamic finding for this new " "finding. Checking the status of locations/endpoints", ) - self.location_manager.update_status( + self.location_handler.update_status( existing_finding, unsaved_finding, self.user, @@ -421,7 +421,7 @@ def process_findings( # - Deduplication batches: optimize bulk operations (larger batches = fewer queries) # They don't need to be aligned since they optimize different operations. 
if len(batch_finding_ids) >= dedupe_batch_max_size or is_final: - self.location_manager.persist(user=self.user) + self.location_handler.persist(user=self.user) # Apply parser-supplied tags for this batch before post-processing starts, # so rules/deduplication tasks see the tags already on the findings. bulk_apply_parser_tags(findings_with_parser_tags) @@ -533,7 +533,7 @@ def close_old_findings( ) mitigated_findings.append(finding) # Persist any accumulated location/endpoint status changes - self.location_manager.persist(user=self.user) + self.location_handler.persist(user=self.user) # push finding groups to jira since we only only want to push whole groups # We dont check if the finding jira sync is applicable quite yet until we can get in the loop # but this is a way to at least make it that far @@ -794,7 +794,7 @@ def process_matched_mitigated_finding( note = Notes(entry=f"Re-activated by {self.scan_type} re-upload.", author=self.user) note.save() - self.location_manager.record_reactivations_for_finding(existing_finding) + self.location_handler.record_reactivations_for_finding(existing_finding) existing_finding.notes.add(note) self.reactivated_items.append(existing_finding) # The new finding is active while the existing on is mitigated. The existing finding needs to @@ -961,7 +961,7 @@ def finding_post_processing( # Copy unsaved items from the parser output onto the saved finding so record_for_finding can read them finding.unsaved_locations = getattr(finding_from_report, "unsaved_locations", []) finding.unsaved_endpoints = getattr(finding_from_report, "unsaved_endpoints", []) - self.location_manager.record_for_finding(finding, self.endpoints_to_add or None) + self.location_handler.record_for_finding(finding, self.endpoints_to_add or None) # For matched/existing findings, do not update tags from the report, # consistent with how other fields are handled on reimport. 
if not is_matched_finding: From 37bbfb4accff62807e661be8dfb86f24c42c48cf Mon Sep 17 00:00:00 2001 From: dogboat Date: Thu, 16 Apr 2026 14:30:41 -0400 Subject: [PATCH 22/47] lint --- dojo/importers/location_manager.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/dojo/importers/location_manager.py b/dojo/importers/location_manager.py index a4dcfe85e7a..a7dba385450 100644 --- a/dojo/importers/location_manager.py +++ b/dojo/importers/location_manager.py @@ -9,7 +9,7 @@ from django.utils import timezone from dojo.importers.base_location_manager import BaseLocationManager -from dojo.location.models import AbstractLocation, LocationFindingReference, LocationProductReference +from dojo.location.models import AbstractLocation, Location, LocationFindingReference, LocationProductReference from dojo.location.status import FindingLocationStatus, ProductLocationStatus from dojo.tags_signals import bulk_inherit_location_tags from dojo.tools.locations import LocationData @@ -109,7 +109,6 @@ def record_mitigations_for_finding(self, finding: Finding, user: Dojo_User | Non def get_locations_for_tagging(self, findings: list[Finding]): """Return queryset of locations to apply tags to.""" - from dojo.location.models import Location # noqa: PLC0415 return Location.objects.filter(findings__finding__in=findings).distinct() def get_location_tag_fallback(self, finding: Finding): From 164d3956e1557313c6b39c735fcd866a4e1b7caf Mon Sep 17 00:00:00 2001 From: dogboat Date: Thu, 16 Apr 2026 14:40:50 -0400 Subject: [PATCH 23/47] clean on endpoints --- dojo/importers/endpoint_manager.py | 1 + 1 file changed, 1 insertion(+) diff --git a/dojo/importers/endpoint_manager.py b/dojo/importers/endpoint_manager.py index 976e5b24dfc..77e86c724a2 100644 --- a/dojo/importers/endpoint_manager.py +++ b/dojo/importers/endpoint_manager.py @@ -312,6 +312,7 @@ def clean_unsaved(self, finding: Finding) -> None: def record_for_finding(self, finding: Finding, extra_locations: list[Endpoint] | 
None = None) -> None: """Record endpoints from the finding + any form-added extras for later batch creation.""" + self.clean_unsaved_endpoints(finding.unsaved_endpoints) for endpoint in finding.unsaved_endpoints: key = self.record_endpoint(endpoint) self.record_status_for_create(finding, key) From 8ec5559c0f2cb71c04cb17c3401dd5a2906ef405 Mon Sep 17 00:00:00 2001 From: dogboat Date: Thu, 16 Apr 2026 16:58:09 -0400 Subject: [PATCH 24/47] remove cleaning until later --- dojo/url/models.py | 4 +- unittests/test_bulk_locations.py | 67 +++++++++++++++++++++++++++++++- 2 files changed, 67 insertions(+), 4 deletions(-) diff --git a/dojo/url/models.py b/dojo/url/models.py index cc5ae338ddf..41d7bb65858 100644 --- a/dojo/url/models.py +++ b/dojo/url/models.py @@ -359,7 +359,7 @@ def from_parts( query=None, fragment=None, ) -> URL: - url = URL( + return URL( protocol=protocol, user_info=user_info, host=host, @@ -368,8 +368,6 @@ def from_parts( query=query, fragment=fragment, ) - url.clean() - return url @staticmethod def create_location_from_value(value: str) -> URL: diff --git a/unittests/test_bulk_locations.py b/unittests/test_bulk_locations.py index 30457795366..1941faf87ff 100644 --- a/unittests/test_bulk_locations.py +++ b/unittests/test_bulk_locations.py @@ -21,7 +21,7 @@ from dojo.models import Engagement, Finding, Product, Product_Type, Test, Test_Type from dojo.tools.locations import LocationAssociationData, LocationData from dojo.url.models import URL -from unittests.dojo_test_case import DojoTestCase, skip_unless_v3 +from unittests.dojo_test_case import DojoTestCase, skip_unless_v2, skip_unless_v3 User = get_user_model() @@ -261,6 +261,71 @@ def test_multiple_findings_single_persist(self): self.assertEqual(LocationFindingReference.objects.filter(finding=finding2).count(), 1) self.assertEqual(LocationProductReference.objects.filter(product=product).count(), 2) + def test_locations_are_cleaned_during_persist(self): + """ + Verify that locations go through clean() 
normalization when persisted. + + URL.clean() normalizes protocol/host to lowercase and sets default ports. + If clean isn't called, the raw input values would be stored as-is. + """ + finding = _make_finding() + product = finding.test.engagement.product + + # Create a URL with uppercase protocol and host — clean() should normalize these + loc_data = [LocationData(type="url", data={ + "protocol": "HTTPS", + "host": "UPPERCASE.EXAMPLE.COM", + "path": "api/v1", + })] + mgr = LocationManager(product) + mgr.record_locations_for_finding(finding, loc_data) + mgr.persist() + + saved_url = URL.objects.get(host="uppercase.example.com") + # Protocol should be lowercased + self.assertEqual(saved_url.protocol, "https") + # Host should be lowercased + self.assertEqual(saved_url.host, "uppercase.example.com") + # Default HTTPS port should be set + self.assertEqual(saved_url.port, 443) + + +# --------------------------------------------------------------------------- +# EndpointManager: verify clean is called during record_for_finding +# --------------------------------------------------------------------------- +@skip_unless_v2 +class TestEndpointCleanOnRecord(DojoTestCase): + + def test_endpoints_are_cleaned_during_record_for_finding(self): + """ + Verify that EndpointManager.record_for_finding() runs clean() on endpoints. + + Endpoint.clean() validates format (not normalize case). An endpoint with + an invalid protocol should trigger a warning log but still be recorded + (DefectDojo stores broken endpoints with a warning). An endpoint with a + valid protocol should pass through clean() without error. 
+ """ + # Keep imports here for reasy removal of this entire test in the future, once endpoints is gone + from dojo.importers.endpoint_manager import EndpointManager # noqa: PLC0415 + from dojo.models import Endpoint # noqa: PLC0415 + + pt, _ = Product_Type.objects.get_or_create(name="EP Clean Test Type") + product = Product.objects.create(name="EP Clean Test Product", description="t", prod_type=pt) + + mgr = EndpointManager(product) + + finding = _make_finding() + # Valid endpoint + one with empty protocol (clean sets it to None) + ep_valid = Endpoint(protocol="https", host="good.example.com") + ep_empty_proto = Endpoint(protocol="", host="empty-proto.example.com") + finding.unsaved_endpoints = [ep_valid, ep_empty_proto] + + mgr.record_for_finding(finding) + + # clean() should have set empty protocol to None + self.assertEqual(ep_valid.protocol, "https") + self.assertIsNone(ep_empty_proto.protocol) + # --------------------------------------------------------------------------- # Query efficiency From e5cb3598c86a90748f18f932b36e5396090bb005 Mon Sep 17 00:00:00 2001 From: dogboat Date: Thu, 16 Apr 2026 17:51:24 -0400 Subject: [PATCH 25/47] test updates --- unittests/dojo_test_case.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/unittests/dojo_test_case.py b/unittests/dojo_test_case.py index 66add0baa1c..03817ce7625 100644 --- a/unittests/dojo_test_case.py +++ b/unittests/dojo_test_case.py @@ -541,7 +541,10 @@ def get_latest_model(self, model): def get_unsaved_locations(self, finding): if settings.V3_FEATURE_LOCATIONS: - return LocationManager.make_abstract_locations(finding.unsaved_locations) + locations = LocationManager.make_abstract_locations(finding.unsaved_locations) + for loc in locations: + loc.clean() + return locations # TODO: Delete this after the move to Locations return finding.unsaved_endpoints From 9f0bb2eabff48cfa6950f9e61f50b07cd20baf81 Mon Sep 17 00:00:00 2001 From: dogboat Date: Thu, 16 Apr 2026 18:21:20 -0400 Subject: 
[PATCH 26/47] cache cleaned locations for tests --- unittests/dojo_test_case.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/unittests/dojo_test_case.py b/unittests/dojo_test_case.py index 03817ce7625..098aa77376d 100644 --- a/unittests/dojo_test_case.py +++ b/unittests/dojo_test_case.py @@ -540,20 +540,22 @@ def get_latest_model(self, model): return model.objects.order_by("id").last() def get_unsaved_locations(self, finding): - if settings.V3_FEATURE_LOCATIONS: - locations = LocationManager.make_abstract_locations(finding.unsaved_locations) + if not hasattr(finding, "_cached_unsaved_locations"): + if settings.V3_FEATURE_LOCATIONS: + locations = LocationManager.make_abstract_locations(finding.unsaved_locations) + else: + # TODO: Delete this after the move to Locations + locations = finding.unsaved_endpoints for loc in locations: loc.clean() - return locations - # TODO: Delete this after the move to Locations - return finding.unsaved_endpoints + finding._cached_unsaved_locations = locations + return finding._cached_unsaved_locations def validate_locations(self, findings): for finding in findings: - # AND SEVERITY HAHAHAHA self.assertIn(finding.severity, Finding.SEVERITIES) - for location in self.get_unsaved_locations(finding): - location.clean() + # get_unsaved_locations handles conversion + cleaning + caching + self.get_unsaved_locations(finding) class DojoTestCase(TestCase, DojoTestUtilsMixin): From e2089343f4f375a606cc921c6827ba1f885e27ac Mon Sep 17 00:00:00 2001 From: dogboat Date: Fri, 17 Apr 2026 08:11:22 -0400 Subject: [PATCH 27/47] remove unnecessary guard --- dojo/location/models.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/dojo/location/models.py b/dojo/location/models.py index fa840c23a01..ad99cf2dd75 100644 --- a/dojo/location/models.py +++ b/dojo/location/models.py @@ -370,9 +370,6 @@ def bulk_get_or_create(cls, locations: Iterable[Self]) -> list[Self]: if not isinstance(loc, cls): error_message = 
f"Invalid location type; expected {cls} but got {type(loc)}" raise TypeError(error_message) - # Set .identity_hash if not present - if not loc.identity_hash: - loc.clean() hashes.append(loc.identity_hash) # Look up existing objects, grouping by hash From 77261be8dafe983175c1b51ba2adf0ed9b0c108a Mon Sep 17 00:00:00 2001 From: dogboat Date: Fri, 17 Apr 2026 08:12:06 -0400 Subject: [PATCH 28/47] move clean call mmm --- dojo/importers/location_manager.py | 14 +++++--------- dojo/location/models.py | 1 + 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/dojo/importers/location_manager.py b/dojo/importers/location_manager.py index a7dba385450..722e633c56d 100644 --- a/dojo/importers/location_manager.py +++ b/dojo/importers/location_manager.py @@ -360,16 +360,12 @@ def clean_unsaved_locations( locations: list[UnsavedLocation], ) -> list[AbstractLocation]: """ - Convert locations represented as LocationData dataclasses to the appropriate AbstractLocation type, then clean - them. For any endpoints that fail this validation process, log a message that broken locations are being stored. + Convert locations represented as LocationData dataclasses to the appropriate + AbstractLocation type and deduplicate them. Cleaning (validation + normalization) + is deferred to bulk_get_or_create which calls clean() on each location before + DB access. 
""" - locations = list(set(cls.make_abstract_locations(locations))) - for location in locations: - try: - location.clean() - except ValidationError as e: - logger.warning("DefectDojo is storing broken locations because cleaning wasn't successful: %s", e) - return locations + return list(set(cls.make_abstract_locations(locations))) # ------------------------------------------------------------------ # Bulk internals diff --git a/dojo/location/models.py b/dojo/location/models.py index ad99cf2dd75..4e57f22db39 100644 --- a/dojo/location/models.py +++ b/dojo/location/models.py @@ -370,6 +370,7 @@ def bulk_get_or_create(cls, locations: Iterable[Self]) -> list[Self]: if not isinstance(loc, cls): error_message = f"Invalid location type; expected {cls} but got {type(loc)}" raise TypeError(error_message) + loc.clean() hashes.append(loc.identity_hash) # Look up existing objects, grouping by hash From 5d35fa4883def323e2df93531b8af07e01e350c6 Mon Sep 17 00:00:00 2001 From: dogboat Date: Fri, 17 Apr 2026 08:48:08 -0400 Subject: [PATCH 29/47] consolidate --- dojo/location/models.py | 6 ++++++ dojo/url/models.py | 6 ------ 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/dojo/location/models.py b/dojo/location/models.py index 4e57f22db39..9e74a350a25 100644 --- a/dojo/location/models.py +++ b/dojo/location/models.py @@ -276,6 +276,12 @@ class AbstractLocation(BaseModelWithoutTimeMeta): class Meta: abstract = True + def __hash__(self) -> int: + return hash(str(self)) + + def __eq__(self, other: object) -> bool: + return isinstance(other, type(self)) and str(self) == str(other) + def clean(self, *args: list, **kwargs: dict) -> None: self.set_identity_hash() super().clean(*args, **kwargs) diff --git a/dojo/url/models.py b/dojo/url/models.py index 41d7bb65858..6e1358cfac3 100644 --- a/dojo/url/models.py +++ b/dojo/url/models.py @@ -225,12 +225,6 @@ def __str__(self) -> str: return URL.URL_PARSING_CLASS().unparse(self) return self.manual_str() - def __hash__(self) -> int: 
- return hash(str(self)) - - def __eq__(self, other: object) -> bool: - return isinstance(other, URL) and str(self) == str(other) - @classmethod def get_location_type(cls) -> str: return cls.LOCATION_TYPE From a3e950db0533624d72efb7c4033a72ff4429b612 Mon Sep 17 00:00:00 2001 From: dogboat Date: Fri, 17 Apr 2026 10:17:21 -0400 Subject: [PATCH 30/47] restore clean mmm --- dojo/importers/location_manager.py | 14 +++++++++----- dojo/location/models.py | 1 - 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/dojo/importers/location_manager.py b/dojo/importers/location_manager.py index 722e633c56d..a7dba385450 100644 --- a/dojo/importers/location_manager.py +++ b/dojo/importers/location_manager.py @@ -360,12 +360,16 @@ def clean_unsaved_locations( locations: list[UnsavedLocation], ) -> list[AbstractLocation]: """ - Convert locations represented as LocationData dataclasses to the appropriate - AbstractLocation type and deduplicate them. Cleaning (validation + normalization) - is deferred to bulk_get_or_create which calls clean() on each location before - DB access. + Convert locations represented as LocationData dataclasses to the appropriate AbstractLocation type, then clean + them. For any endpoints that fail this validation process, log a message that broken locations are being stored. 
""" - return list(set(cls.make_abstract_locations(locations))) + locations = list(set(cls.make_abstract_locations(locations))) + for location in locations: + try: + location.clean() + except ValidationError as e: + logger.warning("DefectDojo is storing broken locations because cleaning wasn't successful: %s", e) + return locations # ------------------------------------------------------------------ # Bulk internals diff --git a/dojo/location/models.py b/dojo/location/models.py index 9e74a350a25..0dd241bfb57 100644 --- a/dojo/location/models.py +++ b/dojo/location/models.py @@ -376,7 +376,6 @@ def bulk_get_or_create(cls, locations: Iterable[Self]) -> list[Self]: if not isinstance(loc, cls): error_message = f"Invalid location type; expected {cls} but got {type(loc)}" raise TypeError(error_message) - loc.clean() hashes.append(loc.identity_hash) # Look up existing objects, grouping by hash From 8fb952033bfd5993feff9987eeb7df8695041e9d Mon Sep 17 00:00:00 2001 From: dogboat Date: Fri, 17 Apr 2026 10:19:12 -0400 Subject: [PATCH 31/47] fix --- dojo/location/models.py | 1 + 1 file changed, 1 insertion(+) diff --git a/dojo/location/models.py b/dojo/location/models.py index 0dd241bfb57..9e74a350a25 100644 --- a/dojo/location/models.py +++ b/dojo/location/models.py @@ -376,6 +376,7 @@ def bulk_get_or_create(cls, locations: Iterable[Self]) -> list[Self]: if not isinstance(loc, cls): error_message = f"Invalid location type; expected {cls} but got {type(loc)}" raise TypeError(error_message) + loc.clean() hashes.append(loc.identity_hash) # Look up existing objects, grouping by hash From 5f9b22c61dbf39889cac885380fe62f9c7aa36de Mon Sep 17 00:00:00 2001 From: dogboat Date: Fri, 17 Apr 2026 12:19:20 -0400 Subject: [PATCH 32/47] add test --- unittests/test_bulk_locations.py | 47 ++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/unittests/test_bulk_locations.py b/unittests/test_bulk_locations.py index 1941faf87ff..0050baa6a13 100644 --- 
a/unittests/test_bulk_locations.py +++ b/unittests/test_bulk_locations.py @@ -500,3 +500,50 @@ def test_update_location_status_for_many_findings_is_bulk(self): # Expected: 1 SELECT (partial-status fetch) + 1 UPDATE (mitigate for the single user). self.assertLess(len(ctx.captured_queries), 5, ctx.captured_queries) + + def test_partial_status_update_reactivates_matching_mitigates_rest(self): + """ + When a reimported finding is NOT mitigated, locations still present in + the report should be reactivated, and locations absent from the report + should be mitigated. + """ + finding = _make_finding() + product = finding.test.engagement.product + + # Create three locations, all currently mitigated on this finding + url_kept = _make_url("kept.example.com") + url_also_kept = _make_url("also-kept.example.com") + url_gone = _make_url("gone.example.com") + saved = URL.bulk_get_or_create([url_kept, url_also_kept, url_gone]) + + refs = [] + for loc in saved: + refs.append(LocationFindingReference.objects.create( + location=loc.location, + finding=finding, + status=FindingLocationStatus.Mitigated, + )) + + # Simulate a reimport where the new finding is active and only has two of the three locations + new_finding = Finding( + title=finding.title, severity=finding.severity, + test=finding.test, is_mitigated=False, + ) + new_finding.unsaved_locations = [ + LocationData(type="url", data={"url": "https://kept.example.com"}), + LocationData(type="url", data={"url": "https://also-kept.example.com"}), + ] + + mgr = LocationManager(product) + mgr.update_location_status(finding, new_finding, finding.reporter) + mgr.persist() + + # Refresh from DB + for ref in refs: + ref.refresh_from_db() + + # The two locations still in the report should be reactivated + self.assertEqual(refs[0].status, FindingLocationStatus.Active) + self.assertEqual(refs[1].status, FindingLocationStatus.Active) + # The location no longer in the report should be mitigated + self.assertEqual(refs[2].status, 
FindingLocationStatus.Mitigated) From 6ea09f370377979e6466223c499a25b4aa8cd482 Mon Sep 17 00:00:00 2001 From: dogboat Date: Fri, 17 Apr 2026 12:22:51 -0400 Subject: [PATCH 33/47] linter --- unittests/test_bulk_locations.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/unittests/test_bulk_locations.py b/unittests/test_bulk_locations.py index 0050baa6a13..e6f40d1922e 100644 --- a/unittests/test_bulk_locations.py +++ b/unittests/test_bulk_locations.py @@ -516,13 +516,14 @@ def test_partial_status_update_reactivates_matching_mitigates_rest(self): url_gone = _make_url("gone.example.com") saved = URL.bulk_get_or_create([url_kept, url_also_kept, url_gone]) - refs = [] - for loc in saved: - refs.append(LocationFindingReference.objects.create( + refs = [ + LocationFindingReference.objects.create( location=loc.location, finding=finding, status=FindingLocationStatus.Mitigated, - )) + ) + for loc in saved + ] # Simulate a reimport where the new finding is active and only has two of the three locations new_finding = Finding( From 49074e8593d339a0db02b3a2ae6cc7f24a962576 Mon Sep 17 00:00:00 2001 From: dogboat Date: Fri, 17 Apr 2026 12:54:02 -0400 Subject: [PATCH 34/47] refactor --- dojo/importers/location_manager.py | 132 ++++++++++++++++++++++++++++- dojo/tags_signals.py | 127 --------------------------- 2 files changed, 130 insertions(+), 129 deletions(-) diff --git a/dojo/importers/location_manager.py b/dojo/importers/location_manager.py index a7dba385450..3b4a1739483 100644 --- a/dojo/importers/location_manager.py +++ b/dojo/importers/location_manager.py @@ -6,12 +6,13 @@ from typing import TYPE_CHECKING, TypeVar from django.core.exceptions import ValidationError +from django.db.models import signals from django.utils import timezone from dojo.importers.base_location_manager import BaseLocationManager from dojo.location.models import AbstractLocation, Location, LocationFindingReference, LocationProductReference from dojo.location.status import 
FindingLocationStatus, ProductLocationStatus -from dojo.tags_signals import bulk_inherit_location_tags +from dojo.tags_signals import make_inherited_tags_sticky from dojo.tools.locations import LocationData from dojo.url.models import URL from dojo.utils import get_system_setting @@ -214,7 +215,7 @@ def _persist_finding_locations(self) -> None: # product has no tag inheritance enabled, and use the bulk variant otherwise to # avoid O(N) expensive JOINs via Location.all_related_products(). if self._should_inherit_product_tags() and locations_needing_inherit: - bulk_inherit_location_tags( + self._bulk_inherit_tags( (loc.location for loc in locations_needing_inherit.values()), known_product=self._product, ) @@ -410,3 +411,130 @@ def type_id(x: tuple[int, AbstractLocation]) -> int: # Restore the original input ordering saved.sort(key=itemgetter(0)) return [loc for _, loc in saved] + + # ------------------------------------------------------------------ + # Tag inheritance + # ------------------------------------------------------------------ + + @staticmethod + def _bulk_inherit_tags(locations, *, known_product=None): + """ + Bulk equivalent of calling inherit_instance_tags(loc) for many Locations. + + Uses aggressive prefetching to produce O(1) queries for the "decide what needs + to change" phase, and only runs per-instance mutation queries (~3 each) for + locations that are actually out of sync with their product tags. + + Compared to the per-instance path, this avoids the N expensive JOINs in + Location.all_related_products() (~50ms each). + + Args: + locations: iterable of Location instances to update + known_product: optional hint — if provided, used as the minimum product + set for locations not already associated elsewhere. Not strictly + required for correctness, but lets us skip the fetch-related-products + query in the common case. 
+ + """ + from dojo.models import Product, _manage_inherited_tags # noqa: PLC0415 + + locations = list(locations) + if not locations: + return + + system_wide_inherit = bool(get_system_setting("enable_product_tag_inheritance")) + + # --- Bulk query: map location_id -> set[product_id] for every related product + location_ids = [loc.id for loc in locations] + product_ids_by_location: dict[int, set[int]] = {loc.id: set() for loc in locations} + + # Path 1: via LocationProductReference (direct association) + for loc_id, prod_id in LocationProductReference.objects.filter( + location_id__in=location_ids, + ).values_list("location_id", "product_id"): + product_ids_by_location[loc_id].add(prod_id) + + # Path 2: via LocationFindingReference -> Finding -> Test -> Engagement -> Product + for loc_id, prod_id in ( + LocationFindingReference.objects + .filter(location_id__in=location_ids) + .values_list("location_id", "finding__test__engagement__product_id") + ): + if prod_id is not None: + product_ids_by_location[loc_id].add(prod_id) + + # Seed with known_product so callers don't have to rely on refs being persisted before this call + if known_product is not None: + for loc_id in location_ids: + product_ids_by_location[loc_id].add(known_product.id) + + # --- Bulk query: fetch the unique products with their tags and inheritance flag + all_product_ids = {pid for pids in product_ids_by_location.values() for pid in pids} + if not all_product_ids: + return + + products = { + p.id: p + for p in Product.objects.filter(id__in=all_product_ids).prefetch_related("tags") + } + + # Products that contribute to inheritance (either opted in themselves or system-wide on) + contributing_product_ids = { + pid for pid, p in products.items() + if p.enable_product_tag_inheritance or system_wide_inherit + } + if not contributing_product_ids: + return + + # Pre-compute the tag names each contributing product contributes + tags_by_product: dict[int, set[str]] = { + pid: {t.name for t in 
products[pid].tags.all()} + for pid in contributing_product_ids + } + + # --- Bulk query: existing inherited_tags per location + inherited_through = Location.inherited_tags.through + inherited_fk = Location.inherited_tags.field.m2m_reverse_field_name() + existing_inherited_by_location: dict[int, set[str]] = {loc.id: set() for loc in locations} + for loc_id, tag_name in inherited_through.objects.filter( + location_id__in=location_ids, + ).values_list("location_id", f"{inherited_fk}__name"): + existing_inherited_by_location[loc_id].add(tag_name) + + # --- Bulk query: existing user tags per location (needed by _manage_inherited_tags) + tags_through = Location.tags.through + tags_fk = Location.tags.field.m2m_reverse_field_name() + existing_tags_by_location: dict[int, list[str]] = {loc.id: [] for loc in locations} + for loc_id, tag_name in tags_through.objects.filter( + location_id__in=location_ids, + ).values_list("location_id", f"{tags_fk}__name"): + existing_tags_by_location[loc_id].append(tag_name) + + # --- Determine which locations are out of sync and call _manage_inherited_tags directly. + # Must disconnect make_inherited_tags_sticky while we mutate — otherwise each + # tags.set() / inherited_tags.set() fires m2m_changed, re-enters the whole expensive + # chain per location, and defeats the point of the bulk path. + # Only disconnect/reconnect for senders where the signal is actually registered + # (tags.through). inherited_tags.through is not a registered sender — attempting + # to connect it after disconnect() would incorrectly add a new registration, + # causing recursion on subsequent calls. 
+ disconnected = signals.m2m_changed.disconnect(make_inherited_tags_sticky, sender=tags_through) + try: + for location in locations: + target_tag_names: set[str] = set() + for pid in product_ids_by_location[location.id]: + if pid in contributing_product_ids: + target_tag_names |= tags_by_product[pid] + + existing = existing_inherited_by_location[location.id] + if target_tag_names == existing: + continue + + _manage_inherited_tags( + location, + list(target_tag_names), + potentially_existing_tags=existing_tags_by_location[location.id], + ) + finally: + if disconnected: + signals.m2m_changed.connect(make_inherited_tags_sticky, sender=tags_through) diff --git a/dojo/tags_signals.py b/dojo/tags_signals.py index 76e7314cbec..0fea7ae8ad5 100644 --- a/dojo/tags_signals.py +++ b/dojo/tags_signals.py @@ -49,133 +49,6 @@ def inherit_instance_tags(instance): instance.inherit_tags(tag_list) -def bulk_inherit_location_tags(locations, *, known_product=None): - """ - Bulk equivalent of calling inherit_instance_tags(loc) for many Locations. - - Uses aggressive prefetching to produce O(1) queries for the "decide what needs - to change" phase, and only runs per-instance mutation queries (~3 each) for - locations that are actually out of sync with their product tags. - - Compared to the per-instance path, this avoids the N expensive JOINs in - Location.all_related_products() (~50ms each). - - Args: - locations: iterable of Location instances to update - known_product: optional hint — if provided, used as the minimum product - set for locations not already associated elsewhere. Not strictly - required for correctness, but lets us skip the fetch-related-products - query in the common case. 
- - """ - locations = list(locations) - if not locations: - return - - system_wide_inherit = bool(get_system_setting("enable_product_tag_inheritance")) - - # --- Bulk query: map location_id -> set[product_id] for every related product - location_ids = [loc.id for loc in locations] - product_ids_by_location: dict[int, set[int]] = {loc.id: set() for loc in locations} - - # Path 1: via LocationProductReference (direct association) - for loc_id, prod_id in LocationProductReference.objects.filter( - location_id__in=location_ids, - ).values_list("location_id", "product_id"): - product_ids_by_location[loc_id].add(prod_id) - - # Path 2: via LocationFindingReference -> Finding -> Test -> Engagement -> Product - for loc_id, prod_id in ( - LocationFindingReference.objects - .filter(location_id__in=location_ids) - .values_list("location_id", "finding__test__engagement__product_id") - ): - if prod_id is not None: - product_ids_by_location[loc_id].add(prod_id) - - # Seed with known_product so callers don't have to rely on refs being persisted before this call - if known_product is not None: - for loc_id in location_ids: - product_ids_by_location[loc_id].add(known_product.id) - - # --- Bulk query: fetch the unique products with their tags and inheritance flag - all_product_ids = {pid for pids in product_ids_by_location.values() for pid in pids} - if not all_product_ids: - return - - products = { - p.id: p - for p in Product.objects.filter(id__in=all_product_ids).prefetch_related("tags") - } - - # Products that contribute to inheritance (either opted in themselves or system-wide on) - contributing_product_ids = { - pid for pid, p in products.items() - if p.enable_product_tag_inheritance or system_wide_inherit - } - if not contributing_product_ids: - # No product with inheritance enabled and system-wide is off → nothing to do - return - - # Pre-compute the tag names each contributing product contributes - tags_by_product: dict[int, set[str]] = { - pid: {t.name for t in 
products[pid].tags.all()} - for pid in contributing_product_ids - } - - # --- Bulk query: existing inherited_tags per location - inherited_through = Location.inherited_tags.through - inherited_fk = Location.inherited_tags.field.m2m_reverse_field_name() - existing_inherited_by_location: dict[int, set[str]] = {loc.id: set() for loc in locations} - for loc_id, tag_name in inherited_through.objects.filter( - location_id__in=location_ids, - ).values_list("location_id", f"{inherited_fk}__name"): - existing_inherited_by_location[loc_id].add(tag_name) - - # --- Bulk query: existing user tags per location (needed by _manage_inherited_tags) - tags_through = Location.tags.through - tags_fk = Location.tags.field.m2m_reverse_field_name() - existing_tags_by_location: dict[int, list[str]] = {loc.id: [] for loc in locations} - for loc_id, tag_name in tags_through.objects.filter( - location_id__in=location_ids, - ).values_list("location_id", f"{tags_fk}__name"): - existing_tags_by_location[loc_id].append(tag_name) - - # --- Determine which locations are out of sync and call _manage_inherited_tags directly. - # Calling _manage_inherited_tags with pre-computed values skips the expensive - # products_to_inherit_tags_from() JOIN that location.inherit_tags() would run. - # - # Must disconnect make_inherited_tags_sticky while we mutate — otherwise each - # tags.set() / inherited_tags.set() fires m2m_changed, re-enters the whole expensive - # chain per location, and defeats the point of the bulk path. - from dojo.models import _manage_inherited_tags # noqa: PLC0415 circular import - - # Only disconnect/reconnect for senders where the signal is actually registered - # (tags.through). inherited_tags.through is not a registered sender — attempting - # to connect it after disconnect() would incorrectly add a new registration, - # causing recursion on subsequent calls. 
- disconnected = signals.m2m_changed.disconnect(make_inherited_tags_sticky, sender=tags_through) - try: - for location in locations: - target_tag_names: set[str] = set() - for pid in product_ids_by_location[location.id]: - if pid in contributing_product_ids: - target_tag_names |= tags_by_product[pid] - - existing = existing_inherited_by_location[location.id] - if target_tag_names == existing: - continue # Already in sync — skip the expensive mutation path entirely - - _manage_inherited_tags( - location, - list(target_tag_names), - potentially_existing_tags=existing_tags_by_location[location.id], - ) - finally: - if disconnected: - signals.m2m_changed.connect(make_inherited_tags_sticky, sender=tags_through) - - def inherit_linked_instance_tags(instance: LocationFindingReference | LocationProductReference): inherit_instance_tags(instance.location) From ea2210a437b08c3fc6212749b273a0fd620b6a88 Mon Sep 17 00:00:00 2001 From: dogboat Date: Fri, 17 Apr 2026 13:00:09 -0400 Subject: [PATCH 35/47] fixup --- dojo/importers/location_manager.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/dojo/importers/location_manager.py b/dojo/importers/location_manager.py index 3b4a1739483..dd8095f798f 100644 --- a/dojo/importers/location_manager.py +++ b/dojo/importers/location_manager.py @@ -3,7 +3,7 @@ import logging from itertools import groupby from operator import itemgetter -from typing import TYPE_CHECKING, TypeVar +from typing import TYPE_CHECKING from django.core.exceptions import ValidationError from django.db.models import signals @@ -23,9 +23,9 @@ logger = logging.getLogger(__name__) -# TypeVar to represent unsaved locations coming from parsers. These might be existing AbstractLocations (when linking +# Unsaved locations coming from parsers. These might be existing AbstractLocations (when linking # existing endpoints) or LocationData objects sent by the parser. 
-UnsavedLocation = TypeVar("UnsavedLocation", LocationData, AbstractLocation) +UnsavedLocation = LocationData | AbstractLocation class LocationManager(BaseLocationManager): From 34ffb2914c59d51d631f91e69faefd1730d463ff Mon Sep 17 00:00:00 2001 From: dogboat Date: Fri, 17 Apr 2026 13:15:45 -0400 Subject: [PATCH 36/47] refactor --- dojo/importers/location_manager.py | 158 ++++++++++++++++------------- 1 file changed, 86 insertions(+), 72 deletions(-) diff --git a/dojo/importers/location_manager.py b/dojo/importers/location_manager.py index dd8095f798f..5648116deab 100644 --- a/dojo/importers/location_manager.py +++ b/dojo/importers/location_manager.py @@ -33,6 +33,8 @@ class LocationManager(BaseLocationManager): def __init__(self, product: Product) -> None: super().__init__(product) self._locations_by_finding: dict[int, tuple[Finding, list[UnsavedLocation]]] = {} + # Product-only locations (not tied to a finding). Appended to by record_locations_for_product. + self._product_locations: list[UnsavedLocation] = [] # Status update inputs (deferred). All entries are processed in a single bulk pass by persist(). # (existing_finding, new_finding, user): classified partial mitigate/reactivate self._status_updates: list[tuple[Finding, Finding, Dojo_User]] = [] @@ -69,9 +71,10 @@ def record_locations_for_finding( finding: Finding, locations: list[UnsavedLocation], ) -> None: - """Record locations to be associated with a finding. Flushed by persist().""" + """Record locations to be associated with a finding (and its product). 
Flushed by persist().""" if locations: self._locations_by_finding.setdefault(finding.id, (finding, []))[1].extend(locations) + self._product_locations.extend(locations) def update_location_status( self, @@ -126,45 +129,53 @@ def serialize_extra_locations(self, locations: list) -> dict: def persist(self, user: Dojo_User | None = None) -> None: """Flush all accumulated location operations to the database.""" - self._persist_finding_locations() + self._persist_locations() self._flush_status_updates() - def _persist_finding_locations(self) -> None: - """Bulk get/create locations and their finding+product refs.""" - if not self._locations_by_finding: + def _persist_locations(self) -> None: + """Bulk get/create all locations and their finding + product refs.""" + if not self._product_locations: return - all_locations: list[AbstractLocation] = [] - finding_ranges: list[tuple[Finding, int, int]] = [] - + # --- Phase 1: Build finding ranges, then clean all product locations at once --- + # _product_locations contains everything: finding-associated locations are appended by + # record_locations_for_finding, product-only locations by record_locations_for_product. + # Build finding ranges first (indexing into the per-finding sublists), then clean the + # full _product_locations list in one pass. 
+ finding_ranges: list[tuple[Finding, list[UnsavedLocation]]] = [] for finding, locations in self._locations_by_finding.values(): - cleaned = self.clean_unsaved_locations(locations) - start = len(all_locations) - all_locations.extend(cleaned) - end = len(all_locations) - if start < end: - finding_ranges.append((finding, start, end)) - - if all_locations: - saved = self._bulk_get_or_create_locations(all_locations) - - # Build all refs across all findings in one pass - all_finding_refs = [] - all_product_refs = [] - # Track locations that got new refs — only those need tag inheritance - locations_needing_inherit: dict[int, AbstractLocation] = {} - - # Pre-fetch existing product refs for this product across all locations (one query) - all_location_ids = [loc.location_id for loc in saved] - existing_product_refs: set[int] = set( - LocationProductReference.objects.filter( - location_id__in=all_location_ids, - product=self._product, - ).values_list("location_id", flat=True), - ) + if locations: + finding_ranges.append((finding, locations)) + + all_locations = self.clean_unsaved_locations(self._product_locations) + if not all_locations: + self._locations_by_finding.clear() + self._product_locations.clear() + return - # Pre-fetch existing finding refs across ALL findings in one query (avoids N+1) - all_finding_ids = [finding.id for finding, _, _ in finding_ranges] + # --- Phase 2: Bulk get/create --- + saved = self._bulk_get_or_create_locations(all_locations) + + # Build a lookup from identity_hash -> saved location for finding ref creation + saved_by_hash: dict[str, AbstractLocation] = {loc.identity_hash: loc for loc in saved} + + # --- Phase 3: Create refs --- + all_finding_refs = [] + all_product_refs = [] + locations_needing_inherit: dict[int, AbstractLocation] = {} + + # Pre-fetch existing product refs for this product across all locations (one query) + all_location_ids = [loc.location_id for loc in saved] + existing_product_refs: set[int] = set( + 
LocationProductReference.objects.filter( + location_id__in=all_location_ids, + product=self._product, + ).values_list("location_id", flat=True), + ) + + # Pre-fetch existing finding refs in one query (avoids N+1) + if finding_ranges: + all_finding_ids = [finding.id for finding, _ in finding_ranges] existing_finding_ref_keys: set[tuple[int, int]] = set( LocationFindingReference.objects.filter( location_id__in=all_location_ids, @@ -172,55 +183,58 @@ def _persist_finding_locations(self) -> None: ).values_list("finding_id", "location_id"), ) - for finding, start, end in finding_ranges: - finding_locations = saved[start:end] - - for location in finding_locations: - assoc = location.get_association_data() - finding_ref_key = (finding.id, location.location_id) - + for finding, unsaved_locations in finding_ranges: + # Re-clean per-finding locations to get the same AbstractLocations with identity_hashes + for location in self.clean_unsaved_locations(unsaved_locations): + saved_loc = saved_by_hash.get(location.identity_hash) + if saved_loc is None: + continue + finding_ref_key = (finding.id, saved_loc.location_id) if finding_ref_key not in existing_finding_ref_keys: + assoc = saved_loc.get_association_data() all_finding_refs.append(LocationFindingReference( - location_id=location.location_id, + location_id=saved_loc.location_id, finding=finding, status=FindingLocationStatus.Active, relationship=assoc.relationship_type, relationship_data=assoc.relationship_data, )) existing_finding_ref_keys.add(finding_ref_key) - locations_needing_inherit[location.location_id] = location - - if location.location_id not in existing_product_refs: - all_product_refs.append(LocationProductReference( - location_id=location.location_id, - product=self._product, - status=ProductLocationStatus.Active, - relationship=assoc.relationship_type, - relationship_data=assoc.relationship_data, - )) - existing_product_refs.add(location.location_id) - locations_needing_inherit[location.location_id] = location - 
- if all_finding_refs: - LocationFindingReference.objects.bulk_create( - all_finding_refs, batch_size=1000, ignore_conflicts=True, - ) - if all_product_refs: - LocationProductReference.objects.bulk_create( - all_product_refs, batch_size=1000, ignore_conflicts=True, - ) + locations_needing_inherit[saved_loc.location_id] = saved_loc + + # Product refs for all locations + for location in saved: + if location.location_id not in existing_product_refs: + assoc = location.get_association_data() + all_product_refs.append(LocationProductReference( + location_id=location.location_id, + product=self._product, + status=ProductLocationStatus.Active, + relationship=assoc.relationship_type, + relationship_data=assoc.relationship_data, + )) + existing_product_refs.add(location.location_id) + locations_needing_inherit[location.location_id] = location + + # --- Phase 4: Bulk create refs --- + if all_finding_refs: + LocationFindingReference.objects.bulk_create( + all_finding_refs, batch_size=1000, ignore_conflicts=True, + ) + if all_product_refs: + LocationProductReference.objects.bulk_create( + all_product_refs, batch_size=1000, ignore_conflicts=True, + ) - # bulk_create bypasses post_save signals; trigger tag inheritance only on locations - # that got new refs (matches original signal-based behavior). Short-circuit if the - # product has no tag inheritance enabled, and use the bulk variant otherwise to - # avoid O(N) expensive JOINs via Location.all_related_products(). 
- if self._should_inherit_product_tags() and locations_needing_inherit: - self._bulk_inherit_tags( - (loc.location for loc in locations_needing_inherit.values()), - known_product=self._product, - ) + # --- Phase 5: Tag inheritance --- + if self._should_inherit_product_tags() and locations_needing_inherit: + self._bulk_inherit_tags( + (loc.location for loc in locations_needing_inherit.values()), + known_product=self._product, + ) self._locations_by_finding.clear() + self._product_locations.clear() def _flush_status_updates(self) -> None: """ From 7a9b72abf4f07c84823016ca40840a273ecf848d Mon Sep 17 00:00:00 2001 From: dogboat Date: Fri, 17 Apr 2026 15:35:53 -0400 Subject: [PATCH 37/47] cleanup --- dojo/importers/location_manager.py | 94 +++++++++++++++--------------- dojo/location/models.py | 7 +-- 2 files changed, 50 insertions(+), 51 deletions(-) diff --git a/dojo/importers/location_manager.py b/dojo/importers/location_manager.py index 5648116deab..24c6acd6026 100644 --- a/dojo/importers/location_manager.py +++ b/dojo/importers/location_manager.py @@ -32,7 +32,7 @@ class LocationManager(BaseLocationManager): def __init__(self, product: Product) -> None: super().__init__(product) - self._locations_by_finding: dict[int, tuple[Finding, list[UnsavedLocation]]] = {} + self._locations_by_finding: dict[Finding, list[UnsavedLocation]] = {} # Product-only locations (not tied to a finding). Appended to by record_locations_for_product. self._product_locations: list[UnsavedLocation] = [] # Status update inputs (deferred). All entries are processed in a single bulk pass by persist(). @@ -73,7 +73,7 @@ def record_locations_for_finding( ) -> None: """Record locations to be associated with a finding (and its product). 
Flushed by persist().""" if locations: - self._locations_by_finding.setdefault(finding.id, (finding, []))[1].extend(locations) + self._locations_by_finding.setdefault(finding, []).extend(locations) self._product_locations.extend(locations) def update_location_status( @@ -133,49 +133,62 @@ def persist(self, user: Dojo_User | None = None) -> None: self._flush_status_updates() def _persist_locations(self) -> None: - """Bulk get/create all locations and their finding + product refs.""" + """Bulk get/create all locations and their finding/product refs.""" + # _product_locations contains all locations to persist: associate with finding -> associate with product if not self._product_locations: return - # --- Phase 1: Build finding ranges, then clean all product locations at once --- - # _product_locations contains everything: finding-associated locations are appended by - # record_locations_for_finding, product-only locations by record_locations_for_product. - # Build finding ranges first (indexing into the per-finding sublists), then clean the - # full _product_locations list in one pass. 
- finding_ranges: list[tuple[Finding, list[UnsavedLocation]]] = [] - for finding, locations in self._locations_by_finding.values(): - if locations: - finding_ranges.append((finding, locations)) - + # Convert all UnsavedLocation objects (possibly a mix of AbstractLocation and LocationData objects) to cleaned + # concrete location objects all_locations = self.clean_unsaved_locations(self._product_locations) if not all_locations: self._locations_by_finding.clear() self._product_locations.clear() return - # --- Phase 2: Bulk get/create --- + # Bulk persist all locations to the database saved = self._bulk_get_or_create_locations(all_locations) - # Build a lookup from identity_hash -> saved location for finding ref creation - saved_by_hash: dict[str, AbstractLocation] = {loc.identity_hash: loc for loc in saved} + # Build a lookup from (type, identity_hash) -> saved location for finding ref creation. + # identity_hash is only unique per concrete type, so we key by both. + # + # Finding/location mapping was tracked separately in _locations_by_finding which are still the raw + # UnsavedLocation objects; we'll need to line them up with the persisted locations. 
+ saved_by_key: dict[tuple[type, str], AbstractLocation] = { + (type(loc), loc.identity_hash): loc for loc in saved + } - # --- Phase 3: Create refs --- + # Lists for bulk creation all_finding_refs = [] all_product_refs = [] - locations_needing_inherit: dict[int, AbstractLocation] = {} - - # Pre-fetch existing product refs for this product across all locations (one query) + # List of all location IDs, for querying existing refs all_location_ids = [loc.location_id for loc in saved] + + # Determine necessary product refs to create existing_product_refs: set[int] = set( LocationProductReference.objects.filter( location_id__in=all_location_ids, product=self._product, ).values_list("location_id", flat=True), ) + for location in saved: + if location.location_id not in existing_product_refs: + assoc = location.get_association_data() + all_product_refs.append(LocationProductReference( + location_id=location.location_id, + product=self._product, + status=ProductLocationStatus.Active, + relationship=assoc.relationship_type, + relationship_data=assoc.relationship_data, + )) + existing_product_refs.add(location.location_id) - # Pre-fetch existing finding refs in one query (avoids N+1) - if finding_ranges: - all_finding_ids = [finding.id for finding, _ in finding_ranges] + # Determine necessary finding refs to create + if self._locations_by_finding: + all_finding_ids = [finding.id for finding in self._locations_by_finding] + # Strictly speaking this returns more rows than we need (it's the cross of the location/finding lists rather + # than scoped per-finding), but more straightforward than constructing a per-finding lookup. We won't create + # any unwanted associations below anyway. 
existing_finding_ref_keys: set[tuple[int, int]] = set( LocationFindingReference.objects.filter( location_id__in=all_location_ids, @@ -183,12 +196,14 @@ def _persist_locations(self) -> None: ).values_list("finding_id", "location_id"), ) - for finding, unsaved_locations in finding_ranges: - # Re-clean per-finding locations to get the same AbstractLocations with identity_hashes + for finding, unsaved_locations in self._locations_by_finding.items(): + # Clean per-finding UnsavedLocations to get cleaned AbstractLocations with identity_hashes. The + # identity_hash uniquely defines the location per type, so using these we can match up with actual + # persisted locations from above, all of which will be represented in saved_by_key. (Keep in mind, + # _locations_by_finding contains a subset of the locations across all its values in + # _locations_by_finding.) for location in self.clean_unsaved_locations(unsaved_locations): - saved_loc = saved_by_hash.get(location.identity_hash) - if saved_loc is None: - continue + saved_loc = saved_by_key[type(location), location.identity_hash] finding_ref_key = (finding.id, saved_loc.location_id) if finding_ref_key not in existing_finding_ref_keys: assoc = saved_loc.get_association_data() @@ -200,23 +215,8 @@ def _persist_locations(self) -> None: relationship_data=assoc.relationship_data, )) existing_finding_ref_keys.add(finding_ref_key) - locations_needing_inherit[saved_loc.location_id] = saved_loc - - # Product refs for all locations - for location in saved: - if location.location_id not in existing_product_refs: - assoc = location.get_association_data() - all_product_refs.append(LocationProductReference( - location_id=location.location_id, - product=self._product, - status=ProductLocationStatus.Active, - relationship=assoc.relationship_type, - relationship_data=assoc.relationship_data, - )) - existing_product_refs.add(location.location_id) - locations_needing_inherit[location.location_id] = location - # --- Phase 4: Bulk create refs 
--- + # Bulk create references if all_finding_refs: LocationFindingReference.objects.bulk_create( all_finding_refs, batch_size=1000, ignore_conflicts=True, @@ -226,10 +226,10 @@ def _persist_locations(self) -> None: all_product_refs, batch_size=1000, ignore_conflicts=True, ) - # --- Phase 5: Tag inheritance --- - if self._should_inherit_product_tags() and locations_needing_inherit: + # Trigger bulk tag inheritance + if self._should_inherit_product_tags(): self._bulk_inherit_tags( - (loc.location for loc in locations_needing_inherit.values()), + (loc.location for loc in saved), known_product=self._product, ) diff --git a/dojo/location/models.py b/dojo/location/models.py index 9e74a350a25..48ffeb6878a 100644 --- a/dojo/location/models.py +++ b/dojo/location/models.py @@ -412,19 +412,18 @@ def bulk_get_or_create(cls, locations: Iterable[Self]) -> list[Self]: for loc in new_locations ] Location.objects.bulk_create(parents, batch_size=1000) - # Assign Location FKs to the subtypes, then bulk create them. for loc, parent in zip(new_locations, parents, strict=True): loc.location_id = parent.id loc.location = parent - # Note there is a subtle race condition here, if somehow one of our newly-created locations conflicts - # with an existing one (e.g. from a separate thread that commits while this is running). Setting + # Note: there is a subtle potential race condition here, if somehow one of the locations to be created + # has already been created, e.g. by a separate thread that commits while this thread is running. Setting # `ignore_conflicts=True` here would prevent this step from raising an IntegrityError, but would leave # dangling parent Location objects that were created above. Rather than performing a cleanup in that # (unlikely?) case, just allow the transaction to rollback. 
cls.objects.bulk_create(new_locations, batch_size=1000) - # Return in input order (minus dupes) + # Return in input order return [existing_by_hash[h] for h in hashes] From dcb291dc285e3f06f3a36dc56f54f81e7d820d5c Mon Sep 17 00:00:00 2001 From: dogboat Date: Fri, 17 Apr 2026 16:44:34 -0400 Subject: [PATCH 38/47] coments/cleanup --- dojo/importers/location_manager.py | 29 ++++++++++++----------------- 1 file changed, 12 insertions(+), 17 deletions(-) diff --git a/dojo/importers/location_manager.py b/dojo/importers/location_manager.py index 24c6acd6026..fe771f2c516 100644 --- a/dojo/importers/location_manager.py +++ b/dojo/importers/location_manager.py @@ -40,8 +40,9 @@ def __init__(self, product: Product) -> None: self._status_updates: list[tuple[Finding, Finding, Dojo_User]] = [] # finding_id: fully reactivate (all mitigated refs on this finding become active) self._finding_ids_to_fully_reactivate: list[int] = [] - # (finding_id, user): fully mitigate (all non-special refs on this finding become mitigated by user) - self._finding_ids_to_fully_mitigate: list[tuple[int, Dojo_User | None]] = [] + # finding_id -> user: fully mitigate (all non-special refs on this finding become mitigated by user). + # If recorded multiple times for the same finding, last user wins. + self._finding_ids_to_fully_mitigate: dict[int, Dojo_User | None] = {} # Cached result of _should_inherit_product_tags() — lazily computed and reused across persist() calls self._cached_should_inherit_product_tags: bool | None = None @@ -109,7 +110,7 @@ def record_reactivations_for_finding(self, finding: Finding) -> None: def record_mitigations_for_finding(self, finding: Finding, user: Dojo_User | None = None) -> None: """Defer mitigation to persist(). 
No DB access at record time.""" - self._finding_ids_to_fully_mitigate.append((finding.id, user)) + self._finding_ids_to_fully_mitigate[finding.id] = user def get_locations_for_tagging(self, findings: list[Finding]): """Return queryset of locations to apply tags to.""" @@ -128,9 +129,9 @@ def serialize_extra_locations(self, locations: list) -> dict: # ------------------------------------------------------------------ def persist(self, user: Dojo_User | None = None) -> None: - """Flush all accumulated location operations to the database.""" + """Persist all accumulated location operations to the database.""" self._persist_locations() - self._flush_status_updates() + self._persist_status_updates() def _persist_locations(self) -> None: """Bulk get/create all locations and their finding/product refs.""" @@ -233,18 +234,12 @@ def _persist_locations(self) -> None: known_product=self._product, ) + # Clear accumulators self._locations_by_finding.clear() self._product_locations.clear() - def _flush_status_updates(self) -> None: - """ - Resolve all accumulated status-update inputs and execute them as bulk UPDATEs. - - Produces ~3-4 queries total regardless of the number of findings processed: - 1 SELECT to fetch relevant location refs for partial-status updates, - 1 UPDATE for reactivations, - 1 UPDATE per unique mitigation user (typically 1). 
- """ + def _persist_status_updates(self) -> None: + """Bulk persist recorded finding/product ref statuses.""" # Short-circuit if nothing to do if not (self._status_updates or self._finding_ids_to_fully_reactivate or self._finding_ids_to_fully_mitigate): return @@ -288,7 +283,7 @@ def _flush_status_updates(self) -> None: else: ref_ids_to_mitigate_by_user.setdefault(user, set()).add(ref.id) - # Full reactivations (from record_reactivations_for_finding): all mitigated refs for these findings + # Reactivate all mitigated refs for these findings if self._finding_ids_to_fully_reactivate: ref_ids_to_reactivate.update( LocationFindingReference.objects.filter( @@ -297,11 +292,11 @@ def _flush_status_updates(self) -> None: ).values_list("id", flat=True), ) - # Full mitigations (from record_mitigations_for_finding): all non-special refs for these findings, per user + # Mitigate all non-special refs for these findings, per user if self._finding_ids_to_fully_mitigate: # Group finding_ids by user to do one SELECT per user ids_by_user: dict[Dojo_User | None, list[int]] = {} - for finding_id, user in self._finding_ids_to_fully_mitigate: + for finding_id, user in self._finding_ids_to_fully_mitigate.items(): ids_by_user.setdefault(user, []).append(finding_id) for user, finding_ids in ids_by_user.items(): ref_ids_to_mitigate_by_user.setdefault(user, set()).update( From 94a83b549812ce06b3710f223f710f52657b697e Mon Sep 17 00:00:00 2001 From: dogboat Date: Fri, 17 Apr 2026 17:32:59 -0400 Subject: [PATCH 39/47] product ref statuses --- dojo/importers/location_manager.py | 37 +++++++++++- unittests/test_bulk_locations.py | 93 ++++++++++++++++++++++++++++-- 2 files changed, 124 insertions(+), 6 deletions(-) diff --git a/dojo/importers/location_manager.py b/dojo/importers/location_manager.py index fe771f2c516..8332d974e6f 100644 --- a/dojo/importers/location_manager.py +++ b/dojo/importers/location_manager.py @@ -305,8 +305,10 @@ def _persist_status_updates(self) -> None: 
).exclude(status__in=special_statuses).values_list("id", flat=True), ) - # Execute bulk updates + # Execute bulk finding ref updates now = timezone.now() + all_affected_ref_ids: set[int] = set() + if ref_ids_to_reactivate: LocationFindingReference.objects.filter( id__in=ref_ids_to_reactivate, @@ -316,6 +318,7 @@ def _persist_status_updates(self) -> None: audit_time=now, status=FindingLocationStatus.Active, ) + all_affected_ref_ids |= ref_ids_to_reactivate for user, ref_ids in ref_ids_to_mitigate_by_user.items(): if ref_ids: @@ -326,6 +329,38 @@ def _persist_status_updates(self) -> None: audit_time=now, status=FindingLocationStatus.Mitigated, ) + all_affected_ref_ids |= ref_ids + + # Propagate to product refs: if any finding ref is active, product ref is active; otherwise mitigated. + if all_affected_ref_ids: + affected_location_ids = set( + LocationFindingReference.objects.filter( + id__in=all_affected_ref_ids, + ).values_list("location_id", flat=True), + ) + locations_still_active = set( + LocationFindingReference.objects.filter( + location_id__in=affected_location_ids, + finding__test__engagement__product=self._product, + status=FindingLocationStatus.Active, + ).values_list("location_id", flat=True), + ) + locations_now_mitigated = affected_location_ids - locations_still_active + + if locations_still_active: + LocationProductReference.objects.filter( + location_id__in=locations_still_active, + product=self._product, + ).exclude(status=ProductLocationStatus.Active).update( + status=ProductLocationStatus.Active, + ) + if locations_now_mitigated: + LocationProductReference.objects.filter( + location_id__in=locations_now_mitigated, + product=self._product, + ).exclude(status=ProductLocationStatus.Mitigated).update( + status=ProductLocationStatus.Mitigated, + ) # Clear accumulators self._status_updates.clear() diff --git a/unittests/test_bulk_locations.py b/unittests/test_bulk_locations.py index e6f40d1922e..bc1ae7a02c9 100644 --- a/unittests/test_bulk_locations.py 
+++ b/unittests/test_bulk_locations.py @@ -17,7 +17,7 @@ from dojo.importers.location_manager import LocationManager from dojo.location.models import Location, LocationFindingReference, LocationProductReference -from dojo.location.status import FindingLocationStatus +from dojo.location.status import FindingLocationStatus, ProductLocationStatus from dojo.models import Engagement, Finding, Product, Product_Type, Test, Test_Type from dojo.tools.locations import LocationAssociationData, LocationData from dojo.url.models import URL @@ -481,8 +481,9 @@ def test_reactivate_for_many_findings_is_bulk(self): with CaptureQueriesContext(connection) as ctx: mgr.persist() - # Expected: 1 SELECT (gather ref IDs) + 1 UPDATE (reactivate). Allow tiny overhead. - self.assertLess(len(ctx.captured_queries), 5, ctx.captured_queries) + # Expected: 1 SELECT (gather ref IDs) + 1 UPDATE (reactivate) + # + 1 SELECT (affected location_ids) + 1 SELECT (still-active check) + up to 1 UPDATE (product refs) + self.assertLess(len(ctx.captured_queries), 8, ctx.captured_queries) def test_update_location_status_for_many_findings_is_bulk(self): findings, product = self._setup_findings_with_mitigated_refs(count=20) @@ -498,8 +499,9 @@ def test_update_location_status_for_many_findings_is_bulk(self): with CaptureQueriesContext(connection) as ctx: mgr.persist() - # Expected: 1 SELECT (partial-status fetch) + 1 UPDATE (mitigate for the single user). 
- self.assertLess(len(ctx.captured_queries), 5, ctx.captured_queries) + # Expected: 1 SELECT (partial-status fetch) + 1 UPDATE (mitigate) + # + 1 SELECT (affected location_ids) + 1 SELECT (still-active check) + up to 1 UPDATE (product refs) + self.assertLess(len(ctx.captured_queries), 8, ctx.captured_queries) def test_partial_status_update_reactivates_matching_mitigates_rest(self): """ @@ -548,3 +550,84 @@ def test_partial_status_update_reactivates_matching_mitigates_rest(self): self.assertEqual(refs[1].status, FindingLocationStatus.Active) # The location no longer in the report should be mitigated self.assertEqual(refs[2].status, FindingLocationStatus.Mitigated) + + def test_product_ref_mitigated_when_all_finding_refs_mitigated(self): + """When all finding refs for a location are mitigated, the product ref should become mitigated.""" + finding = _make_finding() + product = finding.test.engagement.product + + url = _make_url("product-status-test.example.com") + saved = URL.bulk_get_or_create([url]) + loc = saved[0] + + # Create active finding ref and active product ref + LocationFindingReference.objects.create( + location=loc.location, finding=finding, status=FindingLocationStatus.Active, + ) + product_ref = LocationProductReference.objects.create( + location=loc.location, product=product, status=ProductLocationStatus.Active, + ) + + # Mitigate the finding + mgr = LocationManager(product) + mgr.record_mitigations_for_finding(finding, finding.reporter) + mgr.persist() + + product_ref.refresh_from_db() + self.assertEqual(product_ref.status, ProductLocationStatus.Mitigated) + + def test_product_ref_stays_active_when_some_finding_refs_still_active(self): + """When at least one finding ref is active, the product ref should stay active.""" + finding1 = _make_finding() + product = finding1.test.engagement.product + finding2 = Finding.objects.create( + test=finding1.test, title="Second Finding", severity="Medium", reporter=finding1.reporter, + ) + + url = 
_make_url("shared-location.example.com") + saved = URL.bulk_get_or_create([url]) + loc = saved[0] + + # Two findings share the same location, both active + LocationFindingReference.objects.create( + location=loc.location, finding=finding1, status=FindingLocationStatus.Active, + ) + LocationFindingReference.objects.create( + location=loc.location, finding=finding2, status=FindingLocationStatus.Active, + ) + product_ref = LocationProductReference.objects.create( + location=loc.location, product=product, status=ProductLocationStatus.Active, + ) + + # Mitigate only the first finding — second is still active + mgr = LocationManager(product) + mgr.record_mitigations_for_finding(finding1, finding1.reporter) + mgr.persist() + + product_ref.refresh_from_db() + self.assertEqual(product_ref.status, ProductLocationStatus.Active) + + def test_product_ref_reactivated_when_finding_ref_reactivated(self): + """When a finding ref is reactivated, the product ref should become active.""" + finding = _make_finding() + product = finding.test.engagement.product + + url = _make_url("reactivate-product.example.com") + saved = URL.bulk_get_or_create([url]) + loc = saved[0] + + # Start with everything mitigated + LocationFindingReference.objects.create( + location=loc.location, finding=finding, status=FindingLocationStatus.Mitigated, + ) + product_ref = LocationProductReference.objects.create( + location=loc.location, product=product, status=ProductLocationStatus.Mitigated, + ) + + # Reactivate the finding + mgr = LocationManager(product) + mgr.record_reactivations_for_finding(finding) + mgr.persist() + + product_ref.refresh_from_db() + self.assertEqual(product_ref.status, ProductLocationStatus.Active) From 77ed8a9d7f335b606f9b680ba056f72e1357b606 Mon Sep 17 00:00:00 2001 From: dogboat Date: Fri, 17 Apr 2026 17:36:47 -0400 Subject: [PATCH 40/47] persist in txn --- dojo/importers/location_manager.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git 
a/dojo/importers/location_manager.py b/dojo/importers/location_manager.py index 8332d974e6f..cd6c70260ae 100644 --- a/dojo/importers/location_manager.py +++ b/dojo/importers/location_manager.py @@ -6,6 +6,7 @@ from typing import TYPE_CHECKING from django.core.exceptions import ValidationError +from django.db import transaction from django.db.models import signals from django.utils import timezone @@ -130,8 +131,9 @@ def serialize_extra_locations(self, locations: list) -> dict: def persist(self, user: Dojo_User | None = None) -> None: """Persist all accumulated location operations to the database.""" - self._persist_locations() - self._persist_status_updates() + with transaction.atomic(): + self._persist_locations() + self._persist_status_updates() def _persist_locations(self) -> None: """Bulk get/create all locations and their finding/product refs.""" From 221a525ecdd023cfc8646bd99fbf5fba2945eeff Mon Sep 17 00:00:00 2001 From: dogboat Date: Fri, 17 Apr 2026 17:54:20 -0400 Subject: [PATCH 41/47] refactor --- dojo/importers/base_location_manager.py | 10 +++++----- dojo/importers/default_importer.py | 4 ++-- dojo/importers/default_reimporter.py | 4 ++-- dojo/importers/endpoint_manager.py | 20 ++++++++++---------- dojo/importers/location_manager.py | 6 +++--- 5 files changed, 22 insertions(+), 22 deletions(-) diff --git a/dojo/importers/base_location_manager.py b/dojo/importers/base_location_manager.py index eaaeea5b6d0..df587edaafa 100644 --- a/dojo/importers/base_location_manager.py +++ b/dojo/importers/base_location_manager.py @@ -50,7 +50,7 @@ def record_reactivations_for_finding(self, finding: Finding) -> None: """Record items on this finding for reactivation.""" @abstractmethod - def record_mitigations_for_finding(self, finding: Finding, user: Dojo_User | None = None) -> None: + def record_mitigations_for_finding(self, finding: Finding, user: Dojo_User) -> None: """Record items on this finding for mitigation.""" @abstractmethod @@ -66,7 +66,7 @@ def 
serialize_extra_locations(self, locations: list) -> dict: """Serialize extra locations/endpoints for import history settings.""" @abstractmethod - def persist(self, user: Dojo_User | None = None) -> None: + def persist(self) -> None: """Flush all accumulated operations to the database.""" @@ -114,7 +114,7 @@ def update_status(self, existing_finding: Finding, new_finding: Finding, user: D def record_reactivations_for_finding(self, finding: Finding) -> None: return self._manager.record_reactivations_for_finding(finding) - def record_mitigations_for_finding(self, finding: Finding, user: Dojo_User | None = None) -> None: + def record_mitigations_for_finding(self, finding: Finding, user: Dojo_User) -> None: return self._manager.record_mitigations_for_finding(finding, user) def get_locations_for_tagging(self, findings: list[Finding]): @@ -126,5 +126,5 @@ def get_location_tag_fallback(self, finding: Finding): def serialize_extra_locations(self, locations: list) -> dict: return self._manager.serialize_extra_locations(locations) - def persist(self, user: Dojo_User | None = None) -> None: - return self._manager.persist(user) + def persist(self) -> None: + return self._manager.persist() diff --git a/dojo/importers/default_importer.py b/dojo/importers/default_importer.py index b3226a19c2b..8fb4cdc185a 100644 --- a/dojo/importers/default_importer.py +++ b/dojo/importers/default_importer.py @@ -262,7 +262,7 @@ def process_findings( # If batch is full or we're at the end, persist locations/endpoints and dispatch if len(batch_finding_ids) >= batch_max_size or is_final_finding: - self.location_handler.persist(user=self.user) + self.location_handler.persist() # Apply parser-supplied tags for this batch before post-processing starts, # so rules/deduplication tasks see the tags already on the findings. 
bulk_apply_parser_tags(findings_with_parser_tags) @@ -395,7 +395,7 @@ def close_old_findings( product_grading_option=False, ) # Persist any accumulated location/endpoint status changes - self.location_handler.persist(user=self.user) + self.location_handler.persist() # push finding groups to jira since we only only want to push whole groups # We dont check if the finding jira sync is applicable quite yet until we can get in the loop # but this is a way to at least make it that far diff --git a/dojo/importers/default_reimporter.py b/dojo/importers/default_reimporter.py index 280e4054ea7..b63b7134701 100644 --- a/dojo/importers/default_reimporter.py +++ b/dojo/importers/default_reimporter.py @@ -421,7 +421,7 @@ def process_findings( # - Deduplication batches: optimize bulk operations (larger batches = fewer queries) # They don't need to be aligned since they optimize different operations. if len(batch_finding_ids) >= dedupe_batch_max_size or is_final: - self.location_handler.persist(user=self.user) + self.location_handler.persist() # Apply parser-supplied tags for this batch before post-processing starts, # so rules/deduplication tasks see the tags already on the findings. 
bulk_apply_parser_tags(findings_with_parser_tags) @@ -533,7 +533,7 @@ def close_old_findings( ) mitigated_findings.append(finding) # Persist any accumulated location/endpoint status changes - self.location_handler.persist(user=self.user) + self.location_handler.persist() # push finding groups to jira since we only only want to push whole groups # We dont check if the finding jira sync is applicable quite yet until we can get in the loop # but this is a way to at least make it that far diff --git a/dojo/importers/endpoint_manager.py b/dojo/importers/endpoint_manager.py index 77e86c724a2..21e9673ed8c 100644 --- a/dojo/importers/endpoint_manager.py +++ b/dojo/importers/endpoint_manager.py @@ -37,7 +37,7 @@ def __init__(self, product: Product) -> None: self._product = product self._endpoints_to_create: dict[EndpointUniqueKey, dict] = {} self._statuses_to_create: list[tuple[Finding, EndpointUniqueKey]] = [] - self._statuses_to_mitigate: list[Endpoint_Status] = [] + self._statuses_to_mitigate: list[tuple[Endpoint_Status, Dojo_User | None]] = [] self._statuses_to_reactivate: list[Endpoint_Status] = [] @staticmethod @@ -159,13 +159,13 @@ def update_endpoint_status( new_finding_endpoints_list = new_finding.unsaved_endpoints if new_finding.is_mitigated: # New finding is mitigated, so mitigate all old endpoints - self._statuses_to_mitigate.extend(existing_finding_endpoint_status_list) + self._statuses_to_mitigate.extend((eps, user) for eps in existing_finding_endpoint_status_list) else: # Convert to set for O(1) lookups instead of O(n) linear search new_finding_endpoints_set = set(new_finding_endpoints_list) # Mitigate any endpoints in the old finding not found in the new finding self._statuses_to_mitigate.extend( - eps for eps in existing_finding_endpoint_status_list + (eps, user) for eps in existing_finding_endpoint_status_list if eps.endpoint not in new_finding_endpoints_set ) # Re-activate any endpoints in the old finding that are in the new finding @@ -178,9 +178,9 @@ 
def record_statuses_to_reactivate(self, statuses: list[Endpoint_Status]) -> None """Accumulate endpoint statuses for bulk reactivation in persist().""" self._statuses_to_reactivate.extend(statuses) - def record_statuses_to_mitigate(self, statuses: list[Endpoint_Status]) -> None: + def record_statuses_to_mitigate(self, statuses: list[Endpoint_Status], user: Dojo_User | None = None) -> None: """Accumulate endpoint statuses for bulk mitigation in persist().""" - self._statuses_to_mitigate.extend(statuses) + self._statuses_to_mitigate.extend((eps, user) for eps in statuses) def get_or_create_endpoints(self) -> tuple[dict[EndpointUniqueKey, Endpoint], list[Endpoint]]: """ @@ -239,7 +239,7 @@ def get_or_create_endpoints(self) -> tuple[dict[EndpointUniqueKey, Endpoint], li self._endpoints_to_create.clear() return endpoints_by_key, created - def persist(self, user: Dojo_User | None = None) -> None: + def persist(self) -> None: """ Persist all accumulated endpoint operations to the database. @@ -267,11 +267,11 @@ def persist(self, user: Dojo_User | None = None) -> None: if self._statuses_to_mitigate: now = timezone.now() to_update = [] - for endpoint_status in self._statuses_to_mitigate: + for endpoint_status, mitigated_by in self._statuses_to_mitigate: if endpoint_status.mitigated is False: endpoint_status.mitigated_time = now endpoint_status.last_modified = now - endpoint_status.mitigated_by = user + endpoint_status.mitigated_by = mitigated_by endpoint_status.mitigated = True to_update.append(endpoint_status) if to_update: @@ -329,9 +329,9 @@ def record_reactivations_for_finding(self, finding: Finding) -> None: """Record endpoint statuses on this finding for reactivation.""" self.record_statuses_to_reactivate(self.get_non_special_endpoint_statuses(finding)) - def record_mitigations_for_finding(self, finding: Finding, user: Dojo_User | None = None) -> None: + def record_mitigations_for_finding(self, finding: Finding, user: Dojo_User) -> None: """Record endpoint statuses on 
this finding for mitigation.""" - self.record_statuses_to_mitigate(finding.status_finding.all()) + self.record_statuses_to_mitigate(finding.status_finding.all(), user) def get_locations_for_tagging(self, findings: list[Finding]): """Return queryset of locations to apply tags to.""" diff --git a/dojo/importers/location_manager.py b/dojo/importers/location_manager.py index cd6c70260ae..6cdbaeef2ab 100644 --- a/dojo/importers/location_manager.py +++ b/dojo/importers/location_manager.py @@ -43,7 +43,7 @@ def __init__(self, product: Product) -> None: self._finding_ids_to_fully_reactivate: list[int] = [] # finding_id -> user: fully mitigate (all non-special refs on this finding become mitigated by user). # If recorded multiple times for the same finding, last user wins. - self._finding_ids_to_fully_mitigate: dict[int, Dojo_User | None] = {} + self._finding_ids_to_fully_mitigate: dict[int, Dojo_User] = {} # Cached result of _should_inherit_product_tags() — lazily computed and reused across persist() calls self._cached_should_inherit_product_tags: bool | None = None @@ -109,7 +109,7 @@ def record_reactivations_for_finding(self, finding: Finding) -> None: """Defer reactivation to persist(). No DB access at record time.""" self._finding_ids_to_fully_reactivate.append(finding.id) - def record_mitigations_for_finding(self, finding: Finding, user: Dojo_User | None = None) -> None: + def record_mitigations_for_finding(self, finding: Finding, user: Dojo_User) -> None: """Defer mitigation to persist(). 
No DB access at record time.""" self._finding_ids_to_fully_mitigate[finding.id] = user @@ -129,7 +129,7 @@ def serialize_extra_locations(self, locations: list) -> dict: # Persist — flush all accumulated operations to DB # ------------------------------------------------------------------ - def persist(self, user: Dojo_User | None = None) -> None: + def persist(self) -> None: """Persist all accumulated location operations to the database.""" with transaction.atomic(): self._persist_locations() From e24b4b6755d59e2ba4002e7c52093adb70c7fa0b Mon Sep 17 00:00:00 2001 From: dogboat Date: Fri, 17 Apr 2026 21:46:28 -0400 Subject: [PATCH 42/47] perf test updates --- unittests/test_importers_performance.py | 32 ++++++++++++------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/unittests/test_importers_performance.py b/unittests/test_importers_performance.py index ac77faed610..2ca6d251bf6 100644 --- a/unittests/test_importers_performance.py +++ b/unittests/test_importers_performance.py @@ -569,13 +569,13 @@ def test_import_reimport_reimport_performance_pghistory_async(self): configure_pghistory_triggers() self._import_reimport_performance( - expected_num_queries1=144, + expected_num_queries1=146, expected_num_async_tasks1=1, - expected_num_queries2=119, + expected_num_queries2=124, expected_num_async_tasks2=1, - expected_num_queries3=32, + expected_num_queries3=37, expected_num_async_tasks3=1, - expected_num_queries4=96, + expected_num_queries4=101, expected_num_async_tasks4=0, ) @@ -593,13 +593,13 @@ def test_import_reimport_reimport_performance_pghistory_no_async(self): testuser.usercontactinfo.save() self._import_reimport_performance( - expected_num_queries1=153, + expected_num_queries1=155, expected_num_async_tasks1=1, - expected_num_queries2=128, + expected_num_queries2=133, expected_num_async_tasks2=1, - expected_num_queries3=41, + expected_num_queries3=46, expected_num_async_tasks3=1, - expected_num_queries4=96, + expected_num_queries4=101, 
expected_num_async_tasks4=0, ) @@ -618,13 +618,13 @@ def test_import_reimport_reimport_performance_pghistory_no_async_with_product_gr self.system_settings(enable_product_grade=True) self._import_reimport_performance( - expected_num_queries1=163, + expected_num_queries1=165, expected_num_async_tasks1=3, - expected_num_queries2=138, + expected_num_queries2=143, expected_num_async_tasks2=3, - expected_num_queries3=45, + expected_num_queries3=50, expected_num_async_tasks3=3, - expected_num_queries4=106, + expected_num_queries4=111, expected_num_async_tasks4=2, ) @@ -718,9 +718,9 @@ def test_deduplication_performance_pghistory_async(self): self.system_settings(enable_deduplication=True) self._deduplication_performance( - expected_num_queries1=79, + expected_num_queries1=81, expected_num_async_tasks1=1, - expected_num_queries2=70, + expected_num_queries2=72, expected_num_async_tasks2=1, check_duplicates=False, # Async mode - deduplication happens later ) @@ -738,8 +738,8 @@ def test_deduplication_performance_pghistory_no_async(self): testuser.usercontactinfo.save() self._deduplication_performance( - expected_num_queries1=88, + expected_num_queries1=90, expected_num_async_tasks1=1, - expected_num_queries2=186, + expected_num_queries2=188, expected_num_async_tasks2=1, ) From efc56259afa8f743b658071d830daadbef99564b Mon Sep 17 00:00:00 2001 From: dogboat Date: Tue, 21 Apr 2026 10:55:10 -0400 Subject: [PATCH 43/47] comments --- dojo/importers/location_manager.py | 129 +++++++++++++++++++---------- 1 file changed, 87 insertions(+), 42 deletions(-) diff --git a/dojo/importers/location_manager.py b/dojo/importers/location_manager.py index 6cdbaeef2ab..65622ed2ef3 100644 --- a/dojo/importers/location_manager.py +++ b/dojo/importers/location_manager.py @@ -3,7 +3,7 @@ import logging from itertools import groupby from operator import itemgetter -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, NamedTuple from django.core.exceptions import ValidationError from 
django.db import transaction @@ -29,6 +29,12 @@ UnsavedLocation = LocationData | AbstractLocation +class StatusUpdateEntry(NamedTuple): + existing_finding: Finding + new_finding: Finding + user: Dojo_User + + class LocationManager(BaseLocationManager): def __init__(self, product: Product) -> None: @@ -38,12 +44,12 @@ def __init__(self, product: Product) -> None: self._product_locations: list[UnsavedLocation] = [] # Status update inputs (deferred). All entries are processed in a single bulk pass by persist(). # (existing_finding, new_finding, user): classified partial mitigate/reactivate - self._status_updates: list[tuple[Finding, Finding, Dojo_User]] = [] + self._status_updates: list[StatusUpdateEntry] = [] # finding_id: fully reactivate (all mitigated refs on this finding become active) - self._finding_ids_to_fully_reactivate: list[int] = [] + self._findings_to_reactivate: list[int] = [] # finding_id -> user: fully mitigate (all non-special refs on this finding become mitigated by user). # If recorded multiple times for the same finding, last user wins. - self._finding_ids_to_fully_mitigate: dict[int, Dojo_User] = {} + self._findings_to_mitigate: dict[int, Dojo_User] = {} # Cached result of _should_inherit_product_tags() — lazily computed and reused across persist() calls self._cached_should_inherit_product_tags: bool | None = None @@ -85,7 +91,7 @@ def update_location_status( user: Dojo_User, ) -> None: """Defer status update to persist(). No DB access at record time.""" - self._status_updates.append((existing_finding, new_finding, user)) + self._status_updates.append(StatusUpdateEntry(existing_finding, new_finding, user)) # ------------------------------------------------------------------ # Unified interface (shared with EndpointManager) @@ -107,11 +113,11 @@ def update_status(self, existing_finding: Finding, new_finding: Finding, user: D def record_reactivations_for_finding(self, finding: Finding) -> None: """Defer reactivation to persist(). 
No DB access at record time.""" - self._finding_ids_to_fully_reactivate.append(finding.id) + self._findings_to_reactivate.append(finding.id) def record_mitigations_for_finding(self, finding: Finding, user: Dojo_User) -> None: """Defer mitigation to persist(). No DB access at record time.""" - self._finding_ids_to_fully_mitigate[finding.id] = user + self._findings_to_mitigate[finding.id] = user def get_locations_for_tagging(self, findings: list[Finding]): """Return queryset of locations to apply tags to.""" @@ -241,41 +247,62 @@ def _persist_locations(self) -> None: self._product_locations.clear() def _persist_status_updates(self) -> None: - """Bulk persist recorded finding/product ref statuses.""" + """ + Bulk persist finding/product ref statuses. + + Throughout the (re)import process, we've tracked three types of status changes: locations to mitigate, locations + to reactivate, and locations whose statuses need to be evaluated at this time by comparing locations between + existing findings and new findings. + + To start, this method processes the comparisons between existing/new findings. If the new finding is Mitigated, + then all existing locations are added to the 'to mitigate' set. Otherwise, locations that are in both the new + finding and existing finding are added to the 'to reactivate' set, and locations that are on the existing + finding but not the new finding are added to the 'to mitigate' set. + + Next, all locations in the 'to reactivate' set are bulk set to Active, and all locations in the 'to mitigate' + set are bulk set to Mitigated. + + Finally, product associations are updated: if any location associated with a finding on this product is Active, + the LocationProductReference object is set to Active; otherwise, it is set to Mitigated. 
+ """ # Short-circuit if nothing to do - if not (self._status_updates or self._finding_ids_to_fully_reactivate or self._finding_ids_to_fully_mitigate): + if not (self._status_updates or self._findings_to_reactivate or self._findings_to_mitigate): return + # List of statuses we'll skip processing changes for special_statuses = [ FindingLocationStatus.FalsePositive, FindingLocationStatus.RiskAccepted, FindingLocationStatus.OutOfScope, ] - # Collect ref IDs to reactivate / mitigate across all accumulated inputs + # The set of LocationFindingReference IDs to reactivate ref_ids_to_reactivate: set[int] = set() - # Grouped by user since auditor differs per entry - ref_ids_to_mitigate_by_user: dict[Dojo_User | None, set[int]] = {} + # The set of LocationFindingReference IDs to mitigate, and the user to associate with it + ref_ids_to_mitigate: dict[int, Dojo_User] = {} - # Partial status updates (from update_location_status): need per-finding classification + # Process status updates determined by comparing existing/new findings if self._status_updates: - finding_ids_for_partial = {upd[0].id for upd in self._status_updates} - # Single fetch of all candidate refs with their location values + # Look up all the existing LocationFindingReference objects and store per-Finding + existing_finding_ids = {upd.existing_finding.id for upd in self._status_updates} refs_by_finding: dict[int, list[LocationFindingReference]] = {} for ref in ( LocationFindingReference.objects - .filter(finding_id__in=finding_ids_for_partial) + .filter(finding_id__in=existing_finding_ids) .exclude(status__in=special_statuses) .select_related("location") ): refs_by_finding.setdefault(ref.finding_id, []).append(ref) + # Next: for each StatusUpdateEntry, determine what we should do with the existing refs for existing_finding, new_finding, user in self._status_updates: finding_refs = refs_by_finding.get(existing_finding.id, []) if new_finding.is_mitigated: - # All non-special refs on this finding get mitigated 
- ref_ids_to_mitigate_by_user.setdefault(user, set()).update(r.id for r in finding_refs) + # The new finding is mitigated, so mitigate all existing (non-special) refs + ref_ids_to_mitigate.update({r.id: user for r in finding_refs}) else: + # The new finding is not mitigated; we need to reactivate locations that are in the new finding and + # mitigate statuses that are NOT in the new finding. new_loc_values = { str(loc) for loc in self.clean_unsaved_locations(new_finding.unsaved_locations) } @@ -283,34 +310,35 @@ def _persist_status_updates(self) -> None: if ref.location.location_value in new_loc_values: ref_ids_to_reactivate.add(ref.id) else: - ref_ids_to_mitigate_by_user.setdefault(user, set()).add(ref.id) + ref_ids_to_mitigate[ref.id] = user - # Reactivate all mitigated refs for these findings - if self._finding_ids_to_fully_reactivate: + # Update the "reactivate set" with the IDs of existing LocationFindingReference objects we need to reactivate + if self._findings_to_reactivate: ref_ids_to_reactivate.update( LocationFindingReference.objects.filter( - finding_id__in=self._finding_ids_to_fully_reactivate, + finding_id__in=self._findings_to_reactivate, status=FindingLocationStatus.Mitigated, ).values_list("id", flat=True), ) - # Mitigate all non-special refs for these findings, per user - if self._finding_ids_to_fully_mitigate: - # Group finding_ids by user to do one SELECT per user - ids_by_user: dict[Dojo_User | None, list[int]] = {} - for finding_id, user in self._finding_ids_to_fully_mitigate.items(): - ids_by_user.setdefault(user, []).append(finding_id) - for user, finding_ids in ids_by_user.items(): - ref_ids_to_mitigate_by_user.setdefault(user, set()).update( - LocationFindingReference.objects.filter( - finding_id__in=finding_ids, - ).exclude(status__in=special_statuses).values_list("id", flat=True), - ) - - # Execute bulk finding ref updates + # Update the "mitigate set" with the IDs of existing LocationFindingReference objects we need to mitigate. 
+ # Note we exclude LocationFindingReferences that currently have one of the special statuses. + if self._findings_to_mitigate: + ref_ids_to_mitigate.update({ + ref_id: self._findings_to_mitigate[finding_id] + for ref_id, finding_id in LocationFindingReference.objects.filter( + finding_id__in=self._findings_to_mitigate.keys(), + ).exclude(status__in=special_statuses).values_list("id", "finding_id") + }) + + # Hoorah we finally get around to actually updating stuff now = timezone.now() + # Track all updated LocationFindingReference IDs so we can update the corresponding LocationProductReferences + # as necessary: if any LocationFindingReference is Active, the LocationProductReferences will be set to Active; + # otherwise, they will be set to Mitigated. all_affected_ref_ids: set[int] = set() + # Update Mitigated => Active ("reactivate") if ref_ids_to_reactivate: LocationFindingReference.objects.filter( id__in=ref_ids_to_reactivate, @@ -322,24 +350,36 @@ def _persist_status_updates(self) -> None: ) all_affected_ref_ids |= ref_ids_to_reactivate - for user, ref_ids in ref_ids_to_mitigate_by_user.items(): - if ref_ids: + # Update ~Mitigated => Mitigated + if ref_ids_to_mitigate: + # Flip (ref_id -> user) to (user -> set[ref_id]) for per-user bulk updates + ref_ids_per_user: dict[Dojo_User, set[int]] = {} + for ref_id, user in ref_ids_to_mitigate.items(): + ref_ids_per_user.setdefault(user, set()).add(ref_id) + # Update per user + for user, ref_ids in ref_ids_per_user.items(): LocationFindingReference.objects.filter( id__in=ref_ids, - ).exclude(status=FindingLocationStatus.Mitigated).update( + ).exclude( + status=FindingLocationStatus.Mitigated + ).update( auditor=user, audit_time=now, status=FindingLocationStatus.Mitigated, ) all_affected_ref_ids |= ref_ids - # Propagate to product refs: if any finding ref is active, product ref is active; otherwise mitigated. 
+ # Propagate to product refs: if any finding ref for this location on this product is Active, product ref is + # Active; otherwise Mitigated. if all_affected_ref_ids: + # Grab the location IDs for all the LocationFindingReferences we updated affected_location_ids = set( LocationFindingReference.objects.filter( id__in=all_affected_ref_ids, ).values_list("location_id", flat=True), ) + # Look up all affected LocationFindingReferences that are now Active and associated with this product + # through the "finding.test.engagement.product" chain locations_still_active = set( LocationFindingReference.objects.filter( location_id__in=affected_location_ids, @@ -347,8 +387,11 @@ def _persist_status_updates(self) -> None: status=FindingLocationStatus.Active, ).values_list("location_id", flat=True), ) + # Diff the two; this leaves IDs of locations that should be set to Mitigated at the product level locations_now_mitigated = affected_location_ids - locations_still_active + # Update LocationProductReferences to Active for any locations associated with this product that have an + # Active LocationFindingReference if locations_still_active: LocationProductReference.objects.filter( location_id__in=locations_still_active, @@ -356,6 +399,8 @@ def _persist_status_updates(self) -> None: ).exclude(status=ProductLocationStatus.Active).update( status=ProductLocationStatus.Active, ) + # Update LocationProductReferences to Mitigated for any locations associated with this product that have no + # Active LocationFindingReferences if locations_now_mitigated: LocationProductReference.objects.filter( location_id__in=locations_now_mitigated, @@ -366,8 +411,8 @@ def _persist_status_updates(self) -> None: # Clear accumulators self._status_updates.clear() - self._finding_ids_to_fully_reactivate.clear() - self._finding_ids_to_fully_mitigate.clear() + self._findings_to_reactivate.clear() + self._findings_to_mitigate.clear() # ------------------------------------------------------------------ # Type 
registry From 9dc2728f7f14812dd8405bf110f99fd37e661e3e Mon Sep 17 00:00:00 2001 From: dogboat Date: Tue, 21 Apr 2026 10:56:24 -0400 Subject: [PATCH 44/47] linter --- dojo/importers/location_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dojo/importers/location_manager.py b/dojo/importers/location_manager.py index 65622ed2ef3..ffaa1bd51be 100644 --- a/dojo/importers/location_manager.py +++ b/dojo/importers/location_manager.py @@ -361,7 +361,7 @@ def _persist_status_updates(self) -> None: LocationFindingReference.objects.filter( id__in=ref_ids, ).exclude( - status=FindingLocationStatus.Mitigated + status=FindingLocationStatus.Mitigated, ).update( auditor=user, audit_time=now, From 97e60aa1cf467e459303e597d3909c03c11a362d Mon Sep 17 00:00:00 2001 From: dogboat Date: Tue, 21 Apr 2026 13:38:18 -0400 Subject: [PATCH 45/47] comments --- dojo/importers/location_manager.py | 198 +++++++++++------------------ 1 file changed, 75 insertions(+), 123 deletions(-) diff --git a/dojo/importers/location_manager.py b/dojo/importers/location_manager.py index ffaa1bd51be..84cbc71c218 100644 --- a/dojo/importers/location_manager.py +++ b/dojo/importers/location_manager.py @@ -13,13 +13,16 @@ from dojo.importers.base_location_manager import BaseLocationManager from dojo.location.models import AbstractLocation, Location, LocationFindingReference, LocationProductReference from dojo.location.status import FindingLocationStatus, ProductLocationStatus +from dojo.models import Product, _manage_inherited_tags from dojo.tags_signals import make_inherited_tags_sticky from dojo.tools.locations import LocationData from dojo.url.models import URL from dojo.utils import get_system_setting if TYPE_CHECKING: - from dojo.models import Dojo_User, Finding, Product + from tagulous.models import TagField + + from dojo.models import Dojo_User, Finding logger = logging.getLogger(__name__) @@ -29,6 +32,7 @@ UnsavedLocation = LocationData | AbstractLocation +# Entry for status 
update; status will be determined by comparing locations between existing and new findings class StatusUpdateEntry(NamedTuple): existing_finding: Finding new_finding: Finding @@ -39,36 +43,17 @@ class LocationManager(BaseLocationManager): def __init__(self, product: Product) -> None: super().__init__(product) + # Maps findings to a list of locations self._locations_by_finding: dict[Finding, list[UnsavedLocation]] = {} - # Product-only locations (not tied to a finding). Appended to by record_locations_for_product. + # Product-only locations (not tied to a finding) self._product_locations: list[UnsavedLocation] = [] - # Status update inputs (deferred). All entries are processed in a single bulk pass by persist(). - # (existing_finding, new_finding, user): classified partial mitigate/reactivate + # Status update entries, which we'll use at persist-time to determine Location statuses by comparing + # existing vs new finding entries. self._status_updates: list[StatusUpdateEntry] = [] - # finding_id: fully reactivate (all mitigated refs on this finding become active) - self._findings_to_reactivate: list[int] = [] - # finding_id -> user: fully mitigate (all non-special refs on this finding become mitigated by user). - # If recorded multiple times for the same finding, last user wins. - self._findings_to_mitigate: dict[int, Dojo_User] = {} - # Cached result of _should_inherit_product_tags() — lazily computed and reused across persist() calls - self._cached_should_inherit_product_tags: bool | None = None - - def _should_inherit_product_tags(self) -> bool: - """ - Return True if new LocationFindingReference/LocationProductReference creations - should trigger inherit_instance_tags on the affected locations. - - inherit_instance_tags() runs a complex JOIN query per location (via all_related_products()), - which is O(N) per bulk persist. 
We short-circuit when neither the product nor the system - setting has tag inheritance enabled — in that case, adding a new ref for self._product - cannot change any location's inherited tags. - """ - if self._cached_should_inherit_product_tags is None: - self._cached_should_inherit_product_tags = bool( - getattr(self._product, "enable_product_tag_inheritance", False) - or get_system_setting("enable_product_tag_inheritance"), - ) - return self._cached_should_inherit_product_tags + # IDs of findings whose mitigated refs should all be reactivated + self._refs_to_reactivate: list[int] = [] + # Finding ID -> user: findings whose non-special refs should all be mitigated by that user + self._refs_to_mitigate: dict[int, Dojo_User] = {} # ------------------------------------------------------------------ # Accumulation methods (no DB hits) @@ -113,11 +98,11 @@ def update_status(self, existing_finding: Finding, new_finding: Finding, user: D def record_reactivations_for_finding(self, finding: Finding) -> None: """Defer reactivation to persist(). No DB access at record time.""" - self._findings_to_reactivate.append(finding.id) + self._refs_to_reactivate.append(finding.id) def record_mitigations_for_finding(self, finding: Finding, user: Dojo_User) -> None: """Defer mitigation to persist(). No DB access at record time.""" - self._findings_to_mitigate[finding.id] = user + self._refs_to_mitigate[finding.id] = user def get_locations_for_tagging(self, findings: list[Finding]): """Return queryset of locations to apply tags to.""" @@ -236,11 +221,7 @@ def _persist_locations(self) -> None: ) # Trigger bulk tag inheritance - if self._should_inherit_product_tags(): - self._bulk_inherit_tags( - (loc.location for loc in saved), - known_product=self._product, - ) + self._bulk_inherit_tags(loc.location for loc in saved) # Clear accumulators self._locations_by_finding.clear() @@ -266,7 +247,7 @@ def _persist_status_updates(self) -> None: the LocationProductReference object is set to Active; otherwise, it is set to Mitigated.
""" # Short-circuit if nothing to do - if not (self._status_updates or self._findings_to_reactivate or self._findings_to_mitigate): + if not (self._status_updates or self._refs_to_reactivate or self._refs_to_mitigate): return # List of statuses we'll skip processing changes for @@ -313,21 +294,21 @@ def _persist_status_updates(self) -> None: ref_ids_to_mitigate[ref.id] = user # Update the "reactivate set" with the IDs of existing LocationFindingReference objects we need to reactivate - if self._findings_to_reactivate: + if self._refs_to_reactivate: ref_ids_to_reactivate.update( LocationFindingReference.objects.filter( - finding_id__in=self._findings_to_reactivate, + finding_id__in=self._refs_to_reactivate, status=FindingLocationStatus.Mitigated, ).values_list("id", flat=True), ) # Update the "mitigate set" with the IDs of existing LocationFindingReference objects we need to mitigate. # Note we exclude LocationFindingReferences that currently have one of the special statuses. - if self._findings_to_mitigate: + if self._refs_to_mitigate: ref_ids_to_mitigate.update({ - ref_id: self._findings_to_mitigate[finding_id] + ref_id: self._refs_to_mitigate[finding_id] for ref_id, finding_id in LocationFindingReference.objects.filter( - finding_id__in=self._findings_to_mitigate.keys(), + finding_id__in=self._refs_to_mitigate.keys(), ).exclude(status__in=special_statuses).values_list("id", "finding_id") }) @@ -411,8 +392,8 @@ def _persist_status_updates(self) -> None: # Clear accumulators self._status_updates.clear() - self._findings_to_reactivate.clear() - self._findings_to_mitigate.clear() + self._refs_to_reactivate.clear() + self._refs_to_mitigate.clear() # ------------------------------------------------------------------ # Type registry @@ -507,120 +488,91 @@ def type_id(x: tuple[int, AbstractLocation]) -> int: # Tag inheritance # ------------------------------------------------------------------ - @staticmethod - def _bulk_inherit_tags(locations, *, known_product=None): + 
def _bulk_inherit_tags(self, locations): """ - Bulk equivalent of calling inherit_instance_tags(loc) for many Locations. - - Uses aggressive prefetching to produce O(1) queries for the "decide what needs - to change" phase, and only runs per-instance mutation queries (~3 each) for - locations that are actually out of sync with their product tags. - - Compared to the per-instance path, this avoids the N expensive JOINs in - Location.all_related_products() (~50ms each). - - Args: - locations: iterable of Location instances to update - known_product: optional hint — if provided, used as the minimum product - set for locations not already associated elsewhere. Not strictly - required for correctness, but lets us skip the fetch-related-products - query in the common case. + Bulk equivalent of calling inherit_instance_tags(loc) for many Locations. Actually persisting updates is handled + by a per-location call to _manage_inherited_tags(), but at least determining what the tags are is more efficient + (plus we can skip locations that don't need an update at all). + When tag inheritance is enabled, computes the target inherited tags for each location from all related products + and updates only locations that are out of sync. """ - from dojo.models import Product, _manage_inherited_tags # noqa: PLC0415 - locations = list(locations) if not locations: return + # Check whether tag inheritance is enabled at either the product level or system-wide; quit early if neither + product_inherit = getattr(self._product, "enable_product_tag_inheritance", False) system_wide_inherit = bool(get_system_setting("enable_product_tag_inheritance")) + if not system_wide_inherit and not product_inherit: + return - # --- Bulk query: map location_id -> set[product_id] for every related product + # A location can be shared across multiple products. Its inherited tags should be the union of + # tags from ALL contributing products, not just the one running this import. 
location_ids = [loc.id for loc in locations] product_ids_by_location: dict[int, set[int]] = {loc.id: set() for loc in locations} - # Path 1: via LocationProductReference (direct association) + # Find associations through LocationProductReference entries for loc_id, prod_id in LocationProductReference.objects.filter( location_id__in=location_ids, ).values_list("location_id", "product_id"): product_ids_by_location[loc_id].add(prod_id) - # Path 2: via LocationFindingReference -> Finding -> Test -> Engagement -> Product + # Find associations through LocationFindingReference entries and the finding.test.engagement.product chain. + # This shouldn't add anything new, but just in case. for loc_id, prod_id in ( LocationFindingReference.objects .filter(location_id__in=location_ids) .values_list("location_id", "finding__test__engagement__product_id") ): - if prod_id is not None: - product_ids_by_location[loc_id].add(prod_id) - - # Seed with known_product so callers don't have to rely on refs being persisted before this call - if known_product is not None: - for loc_id in location_ids: - product_ids_by_location[loc_id].add(known_product.id) + product_ids_by_location[loc_id].add(prod_id) - # --- Bulk query: fetch the unique products with their tags and inheritance flag + # Fetch all products that will contribute to tag inheritance, and their tags all_product_ids = {pid for pids in product_ids_by_location.values() for pid in pids} - if not all_product_ids: - return - - products = { - p.id: p - for p in Product.objects.filter(id__in=all_product_ids).prefetch_related("tags") - } - - # Products that contribute to inheritance (either opted in themselves or system-wide on) - contributing_product_ids = { - pid for pid, p in products.items() - if p.enable_product_tag_inheritance or system_wide_inherit - } - if not contributing_product_ids: - return - - # Pre-compute the tag names each contributing product contributes + product_qs = 
Product.objects.filter(id__in=all_product_ids).prefetch_related("tags") + if not system_wide_inherit: + # Product-level inheritance only + product_qs = product_qs.filter(enable_product_tag_inheritance=True) + # Materialize into a dict for ease of use + products: dict[int, Product] = {p.id: p for p in product_qs} + # Get distinct tags, per-product tags_by_product: dict[int, set[str]] = { - pid: {t.name for t in products[pid].tags.all()} - for pid in contributing_product_ids + pid: {t.name for t in p.tags.all()} + for pid, p in products.items() } - # --- Bulk query: existing inherited_tags per location - inherited_through = Location.inherited_tags.through - inherited_fk = Location.inherited_tags.field.m2m_reverse_field_name() - existing_inherited_by_location: dict[int, set[str]] = {loc.id: set() for loc in locations} - for loc_id, tag_name in inherited_through.objects.filter( - location_id__in=location_ids, - ).values_list("location_id", f"{inherited_fk}__name"): - existing_inherited_by_location[loc_id].add(tag_name) - - # --- Bulk query: existing user tags per location (needed by _manage_inherited_tags) - tags_through = Location.tags.through - tags_fk = Location.tags.field.m2m_reverse_field_name() - existing_tags_by_location: dict[int, list[str]] = {loc.id: [] for loc in locations} - for loc_id, tag_name in tags_through.objects.filter( - location_id__in=location_ids, - ).values_list("location_id", f"{tags_fk}__name"): - existing_tags_by_location[loc_id].append(tag_name) - - # --- Determine which locations are out of sync and call _manage_inherited_tags directly. - # Must disconnect make_inherited_tags_sticky while we mutate — otherwise each - # tags.set() / inherited_tags.set() fires m2m_changed, re-enters the whole expensive - # chain per location, and defeats the point of the bulk path. - # Only disconnect/reconnect for senders where the signal is actually registered - # (tags.through). 
inherited_tags.through is not a registered sender — attempting - # to connect it after disconnect() would incorrectly add a new registration, - # causing recursion on subsequent calls. - disconnected = signals.m2m_changed.disconnect(make_inherited_tags_sticky, sender=tags_through) + # Helper method for getting all tags from the given TagField + def _get_tags(tags_field: TagField) -> dict[int, set[str]]: + through_model = tags_field.through + fk_name = tags_field.field.m2m_reverse_field_name() + tags_by_location: dict[int, set[str]] = {loc.id: set() for loc in locations} + for l_id, t_name in through_model.objects.filter( + location_id__in=location_ids, + ).values_list("location_id", f"{fk_name}__name"): + tags_by_location[l_id].add(t_name) + return tags_by_location + + # Gather inherited and 'regular' tags per location + existing_inherited_by_location: dict[int, set[str]] = _get_tags(Location.inherited_tags) + existing_tags_by_location: dict[int, set[str]] = _get_tags(Location.tags) + + # Perform the bulk updates. First, though, disconnect the make_inherited_tags_sticky signal on Location.tags + # while updating, otherwise each (inherited_)tags.set() will trigger, defeating the purpose of this bulk update. 
+ disconnected = signals.m2m_changed.disconnect(make_inherited_tags_sticky, sender=Location.tags.through) try: for location in locations: target_tag_names: set[str] = set() for pid in product_ids_by_location[location.id]: - if pid in contributing_product_ids: + # product_ids_by_location may contain products that shouldn't contribute to tag inheritance (we + # didn't filter either location ref lookup to check), so do a last-minute check here + if pid in products: target_tag_names |= tags_by_product[pid] - existing = existing_inherited_by_location[location.id] - if target_tag_names == existing: + if target_tag_names == existing_inherited_by_location[location.id]: + # The existing set matches the expected set, so nothing more to do for this location continue + # Update tags for this location _manage_inherited_tags( location, list(target_tag_names), @@ -628,4 +580,4 @@ def _bulk_inherit_tags(locations, *, known_product=None): ) finally: if disconnected: - signals.m2m_changed.connect(make_inherited_tags_sticky, sender=tags_through) + signals.m2m_changed.connect(make_inherited_tags_sticky, sender=Location.tags.through) From 4e4009ca551bf6123b0b2af62365cc6335199ad0 Mon Sep 17 00:00:00 2001 From: dogboat Date: Tue, 21 Apr 2026 15:51:11 -0400 Subject: [PATCH 46/47] comment --- dojo/importers/location_manager.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/dojo/importers/location_manager.py b/dojo/importers/location_manager.py index 84cbc71c218..4462ab0a187 100644 --- a/dojo/importers/location_manager.py +++ b/dojo/importers/location_manager.py @@ -133,7 +133,8 @@ def _persist_locations(self) -> None: return # Convert all UnsavedLocation objects (possibly a mix of AbstractLocation and LocationData objects) to cleaned - # concrete location objects + # concrete location objects. _product_locations is the superset of all locations (finding-associated + + # product-only), so cleaning it once covers everything.
all_locations = self.clean_unsaved_locations(self._product_locations) if not all_locations: self._locations_by_finding.clear() @@ -195,7 +196,7 @@ def _persist_locations(self) -> None: # identity_hash uniquely defines the location per type, so using these we can match up with actual # persisted locations from above, all of which will be represented in saved_by_key. (Keep in mind, # _locations_by_finding contains a subset of the locations across all its values in - # _locations_by_finding.) + # _product_locations.) for location in self.clean_unsaved_locations(unsaved_locations): saved_loc = saved_by_key[type(location), location.identity_hash] finding_ref_key = (finding.id, saved_loc.location_id) From 5af3da9edf7802292ffc9002378c2d779db59d9f Mon Sep 17 00:00:00 2001 From: dogboat Date: Tue, 21 Apr 2026 16:55:17 -0400 Subject: [PATCH 47/47] fixup --- dojo/importers/location_manager.py | 34 ++++++++++++------------------ 1 file changed, 14 insertions(+), 20 deletions(-) diff --git a/dojo/importers/location_manager.py b/dojo/importers/location_manager.py index 4462ab0a187..94999538fc1 100644 --- a/dojo/importers/location_manager.py +++ b/dojo/importers/location_manager.py @@ -43,10 +43,10 @@ class LocationManager(BaseLocationManager): def __init__(self, product: Product) -> None: super().__init__(product) - # Maps findings to a list of locations - self._locations_by_finding: dict[Finding, list[UnsavedLocation]] = {} - # Product-only locations (not tied to a finding) - self._product_locations: list[UnsavedLocation] = [] + # Maps findings to a list of cleaned locations + self._locations_by_finding: dict[Finding, list[AbstractLocation]] = {} + # All locations needing product refs (finding-associated + product-only), cleaned at record time + self._product_locations: list[AbstractLocation] = [] # Status update entries, which we'll use at persist-time to determine Location statuses by comparing # existing vs new finding entries. 
self._status_updates: list[StatusUpdateEntry] = [] @@ -66,8 +66,9 @@ def record_locations_for_finding( ) -> None: """Record locations to be associated with a finding (and its product). Flushed by persist().""" if locations: - self._locations_by_finding.setdefault(finding, []).extend(locations) - self._product_locations.extend(locations) + cleaned = self.clean_unsaved_locations(locations) + self._locations_by_finding.setdefault(finding, []).extend(cleaned) + self._product_locations.extend(cleaned) def update_location_status( self, @@ -132,10 +133,9 @@ def _persist_locations(self) -> None: if not self._product_locations: return - # Convert all UnsavedLocation objects (possibly a mix of AbstractLocation and LocationData objects) to cleaned - # concrete location objects. _product_locations is the superset of all locations (finding-associated + - # product-only), so cleaning it once covers everything. - all_locations = self.clean_unsaved_locations(self._product_locations) + # Locations are already cleaned at record time (in record_locations_for_finding). Deduplicate the + # full set — _product_locations is the superset of all locations (finding-associated + product-only). + all_locations = list({(type(loc), loc.identity_hash): loc for loc in self._product_locations}.values()) if not all_locations: self._locations_by_finding.clear() self._product_locations.clear() @@ -146,9 +146,6 @@ def _persist_locations(self) -> None: # Build a lookup from (type, identity_hash) -> saved location for finding ref creation. # identity_hash is only unique per concrete type, so we key by both. - # - # Finding/location mapping was tracked separately in _locations_by_finding which are still the raw - # UnsavedLocation objects; we'll need to line them up with the persisted locations. 
saved_by_key: dict[tuple[type, str], AbstractLocation] = { (type(loc), loc.identity_hash): loc for loc in saved } @@ -191,13 +188,10 @@ def _persist_locations(self) -> None: ).values_list("finding_id", "location_id"), ) - for finding, unsaved_locations in self._locations_by_finding.items(): - # Clean per-finding UnsavedLocations to get cleaned AbstractLocations with identity_hashes. The - # identity_hash uniquely defines the location per type, so using these we can match up with actual - # persisted locations from above, all of which will be represented in saved_by_key. (Keep in mind, - # _locations_by_finding contains a subset of the locations across all its values in - # _product_locations.) - for location in self.clean_unsaved_locations(unsaved_locations): + for finding, cleaned_locations in self._locations_by_finding.items(): + # Locations were already cleaned at record time — identity_hash is set, so we can + # look up the persisted location directly. + for location in cleaned_locations: saved_loc = saved_by_key[type(location), location.identity_hash] finding_ref_key = (finding.id, saved_loc.location_id) if finding_ref_key not in existing_finding_ref_keys: