diff --git a/dojo/db_migrations/0264_alter_url_identity_hash_alter_urlevent_identity_hash.py b/dojo/db_migrations/0264_alter_url_identity_hash_alter_urlevent_identity_hash.py new file mode 100644 index 00000000000..725612dbcc7 --- /dev/null +++ b/dojo/db_migrations/0264_alter_url_identity_hash_alter_urlevent_identity_hash.py @@ -0,0 +1,24 @@ +# Generated by Django 5.2.12 on 2026-04-10 14:24 + +import django.core.validators +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('dojo', '0263_language_type_unique_language'), + ] + + operations = [ + migrations.AlterField( + model_name='url', + name='identity_hash', + field=models.CharField(db_index=True, editable=False, help_text='The hash of the location for uniqueness', max_length=64, unique=True, validators=[django.core.validators.MinLengthValidator(64)]), + ), + migrations.AlterField( + model_name='urlevent', + name='identity_hash', + field=models.CharField(editable=False, help_text='The hash of the location for uniqueness', max_length=64, validators=[django.core.validators.MinLengthValidator(64)]), + ), + ] diff --git a/dojo/importers/base_importer.py b/dojo/importers/base_importer.py index c149f4e169d..ff04a5698de 100644 --- a/dojo/importers/base_importer.py +++ b/dojo/importers/base_importer.py @@ -13,8 +13,6 @@ import dojo.finding.helper as finding_helper import dojo.risk_acceptance.helper as ra_helper -from dojo.celery_dispatch import dojo_dispatch_task -from dojo.importers.location_manager import LocationManager, UnsavedLocation from dojo.importers.options import ImporterOptions from dojo.jira_link.helper import is_keep_in_sync_with_jira from dojo.location.models import Location @@ -80,8 +78,6 @@ def __init__( and will raise a `NotImplemented` exception """ ImporterOptions.__init__(self, *args, **kwargs) - if settings.V3_FEATURE_LOCATIONS: - self.location_manager = LocationManager() def check_child_implementation_exception(self): """ @@ -391,36 +387,20 @@ def apply_import_tags( for tag in self.tags: self.add_tags_safe(finding, tag) - if settings.V3_FEATURE_LOCATIONS: - # Add any tags to any locations of the findings imported if necessary - if self.apply_tags_to_endpoints and self.tags: - # Collect all endpoints linked to the affected findings - locations_qs = Location.objects.filter(findings__finding__in=findings_to_tag).distinct() - try: - bulk_add_tags_to_instances( - tag_or_tags=self.tags, - instances=locations_qs, - tag_field_name="tags", - ) - except IntegrityError: - for finding in findings_to_tag: - for location in finding.locations.all(): - for tag in self.tags: - self.add_tags_safe(location.location, tag) - # Add any tags to any endpoints of the findings imported if necessary - elif self.apply_tags_to_endpoints and self.tags: - endpoints_qs = Endpoint.objects.filter(finding__in=findings_to_tag).distinct() + # Add any tags to any locations/endpoints of the findings imported if necessary + if self.apply_tags_to_endpoints and self.tags: + locations_qs = self.location_handler.get_locations_for_tagging(findings_to_tag) try: bulk_add_tags_to_instances( tag_or_tags=self.tags, - instances=endpoints_qs, + instances=locations_qs, tag_field_name="tags", ) except IntegrityError: for finding in findings_to_tag: - for endpoint in finding.endpoints.all(): + for location in self.location_handler.get_location_tag_fallback(finding): for tag in self.tags: - self.add_tags_safe(endpoint, tag) + self.add_tags_safe(location, tag) def update_import_history( self, @@ -467,14 +447,8 @@ def update_import_history( import_settings["apply_tags_to_endpoints"] = self.apply_tags_to_endpoints import_settings["group_by"] = self.group_by import_settings["create_finding_groups_for_all_findings"] = self.create_finding_groups_for_all_findings - if settings.V3_FEATURE_LOCATIONS: - # Add the list of locations that were added exclusively at import time - if len(self.endpoints_to_add) > 0: - import_settings["locations"] = [str(location) for location in self.endpoints_to_add] - # TODO: Delete this after the move to Locations - # Add the list of endpoints that were added exclusively at import time - elif len(self.endpoints_to_add) > 0: - import_settings["endpoints"] = [str(endpoint) for endpoint in self.endpoints_to_add] + if len(self.endpoints_to_add) > 0: + import_settings.update(self.location_handler.serialize_extra_locations(self.endpoints_to_add)) # Create the test import object test_import = Test_Import.objects.create( test=self.test, @@ -796,50 +770,13 @@ def process_request_response_pairs( def process_locations( self, finding: Finding, - locations_to_add: list[UnsavedLocation], + extra_locations_to_add: list | None = None, ) -> None: """ - Process any locations to add to the finding. Locations could come from two places - - Directly from the report - - Supplied by the user from the import form - These locations will be processed in to Location objects and associated with the - finding and product - """ - # Save the unsaved locations - self.location_manager.chunk_locations_and_disperse(finding, finding.unsaved_locations) - # Check for any that were added in the form - if len(locations_to_add) > 0: - logger.debug("locations_to_add: %s", locations_to_add) - self.location_manager.chunk_locations_and_disperse(finding, locations_to_add) - - # TODO: Delete this after the move to Locations - def process_endpoints( - self, - finding: Finding, - endpoints_to_add: list[Endpoint], - ) -> None: + Record locations/endpoints from the finding + any form-added extras. + Flushed to DB by location_handler.persist(). """ - Process any endpoints to add to the finding. Endpoints could come from two places - - Directly from the report - - Supplied by the user from the import form - These endpoints will be processed in to endpoints objects and associated with the - finding and and product - """ - if settings.V3_FEATURE_LOCATIONS: - msg = "BaseImporter#process_endpoints() method is deprecated when V3_FEATURE_LOCATIONS is enabled" - raise NotImplementedError(msg) - - # Clean and record unsaved endpoints from the report - self.endpoint_manager.clean_unsaved_endpoints(finding.unsaved_endpoints) - for endpoint in finding.unsaved_endpoints: - key = self.endpoint_manager.record_endpoint(endpoint) - self.endpoint_manager.record_status_for_create(finding, key) - # Record any endpoints added from the form - if len(endpoints_to_add) > 0: - logger.debug("endpoints_to_add: %s", endpoints_to_add) - for endpoint in endpoints_to_add: - key = self.endpoint_manager.record_endpoint(endpoint) - self.endpoint_manager.record_status_for_create(finding, key) + self.location_handler.record_for_finding(finding, extra_locations_to_add) def sanitize_vulnerability_ids(self, finding) -> None: """Remove undisired vulnerability id values""" @@ -932,19 +869,7 @@ def mitigate_finding( # Remove risk acceptance if present (vulnerability is now fixed) # risk_unaccept will check if finding.risk_accepted is True before proceeding ra_helper.risk_unaccept(self.user, finding, perform_save=False, post_comments=False) - if settings.V3_FEATURE_LOCATIONS: - # Mitigate the location statuses - dojo_dispatch_task( - LocationManager.mitigate_location_status, - finding.locations.all(), - self.user, - kwuser=self.user, - sync=True, - ) - else: - # TODO: Delete this after the move to Locations - # Accumulate endpoint statuses for bulk mitigate in persist() - self.endpoint_manager.record_statuses_to_mitigate(finding.status_finding.all()) + self.location_handler.record_mitigations_for_finding(finding, self.user) # to avoid pushing a finding group multiple times, we push those outside of the loop if finding_groups_enabled and finding.finding_group: # don't try to dedupe findings that we are closing diff --git a/dojo/importers/base_location_manager.py b/dojo/importers/base_location_manager.py new file mode 100644 index 00000000000..df587edaafa --- /dev/null +++ b/dojo/importers/base_location_manager.py @@ -0,0 +1,130 @@ +""" +Base class and handler for location/endpoint managers in the import pipeline. + +BaseLocationManager defines the contract that both LocationManager (V3) and +EndpointManager (legacy) must implement. LocationHandler is the facade that +importers interact with — it picks the appropriate manager based on +V3_FEATURE_LOCATIONS and delegates all calls through the shared interface. + +This structure prevents drift between the two managers: adding an abstract +method to BaseLocationManager forces both to implement it, and callers can +only access methods exposed by LocationHandler. +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from dojo.models import Dojo_User, Finding, Product + + +class BaseLocationManager(ABC): + + """ + Abstract base for import-pipeline managers that handle network identifiers + (locations in V3, endpoints in legacy). + + Subclasses must implement every abstract method. The importer never calls + subclass-specific methods directly — it goes through LocationHandler. + """ + + def __init__(self, product: Product) -> None: + self._product = product + + @abstractmethod + def clean_unsaved(self, finding: Finding) -> None: + """Clean the unsaved locations/endpoints on this finding.""" + + @abstractmethod + def record_for_finding(self, finding: Finding, extra_locations: list | None = None) -> None: + """Record items from the finding + any form-added extras for later batch creation.""" + + @abstractmethod + def update_status(self, existing_finding: Finding, new_finding: Finding, user: Dojo_User) -> None: + """Accumulate status changes (mitigate/reactivate) based on old vs new finding.""" + + @abstractmethod + def record_reactivations_for_finding(self, finding: Finding) -> None: + """Record items on this finding for reactivation.""" + + @abstractmethod + def record_mitigations_for_finding(self, finding: Finding, user: Dojo_User) -> None: + """Record items on this finding for mitigation.""" + + @abstractmethod + def get_locations_for_tagging(self, findings: list[Finding]): + """Return a queryset of taggable objects linked to the given findings.""" + + @abstractmethod + def get_location_tag_fallback(self, finding: Finding): + """Return an iterable of taggable objects for per-instance tag fallback.""" + + @abstractmethod + def serialize_extra_locations(self, locations: list) -> dict: + """Serialize extra locations/endpoints for import history settings.""" + + @abstractmethod + def persist(self) -> None: + """Flush all accumulated operations to the database.""" + + +class LocationHandler: + + """ + Facade used by importers. Delegates to the appropriate BaseLocationManager + implementation based on V3_FEATURE_LOCATIONS. + + Callers only see the methods defined here — they cannot reach into the + internal manager to call implementation-specific methods. This prevents + V3-only or endpoint-only code from leaking into shared importer logic. + """ + + def __init__( + self, + product: Product, + *, + v3_manager_class: type[BaseLocationManager] | None = None, + v2_manager_class: type[BaseLocationManager] | None = None, + ) -> None: + from django.conf import settings # noqa: PLC0415 + + from dojo.importers.endpoint_manager import EndpointManager # noqa: PLC0415 + from dojo.importers.location_manager import LocationManager # noqa: PLC0415 + + self._product = product + if settings.V3_FEATURE_LOCATIONS: + cls = v3_manager_class or LocationManager + else: + cls = v2_manager_class or EndpointManager + self._manager: BaseLocationManager = cls(product) + + # --- Delegates (one per BaseLocationManager method) --- + + def clean_unsaved(self, finding: Finding) -> None: + return self._manager.clean_unsaved(finding) + + def record_for_finding(self, finding: Finding, extra_locations: list | None = None) -> None: + return self._manager.record_for_finding(finding, extra_locations) + + def update_status(self, existing_finding: Finding, new_finding: Finding, user: Dojo_User) -> None: + return self._manager.update_status(existing_finding, new_finding, user) + + def record_reactivations_for_finding(self, finding: Finding) -> None: + return self._manager.record_reactivations_for_finding(finding) + + def record_mitigations_for_finding(self, finding: Finding, user: Dojo_User) -> None: + return self._manager.record_mitigations_for_finding(finding, user) + + def get_locations_for_tagging(self, findings: list[Finding]): + return self._manager.get_locations_for_tagging(findings) + + def get_location_tag_fallback(self, finding: Finding): + return self._manager.get_location_tag_fallback(finding) + + def serialize_extra_locations(self, locations: list) -> dict: + return self._manager.serialize_extra_locations(locations) + + def persist(self) -> None: + return self._manager.persist() diff --git a/dojo/importers/default_importer.py b/dojo/importers/default_importer.py index e8bd56baf55..8fb4cdc185a 100644 --- a/dojo/importers/default_importer.py +++ b/dojo/importers/default_importer.py @@ -9,7 +9,7 @@ from dojo.celery_dispatch import dojo_dispatch_task from dojo.finding import helper as finding_helper from dojo.importers.base_importer import BaseImporter, Parser -from dojo.importers.endpoint_manager import EndpointManager +from dojo.importers.base_location_manager import LocationHandler from dojo.importers.options import ImporterOptions from dojo.jira_link.helper import is_keep_in_sync_with_jira from dojo.models import ( @@ -58,8 +58,7 @@ def __init__(self, *args, **kwargs): import_type=Test_Import.IMPORT_TYPE, **kwargs, ) - if not settings.V3_FEATURE_LOCATIONS: - self.endpoint_manager = EndpointManager(self.engagement.product) + self.location_handler = LocationHandler(self.engagement.product) def create_test( self, @@ -240,13 +239,7 @@ def process_findings( ) # Process any request/response pairs self.process_request_response_pairs(finding) - if settings.V3_FEATURE_LOCATIONS: - # Process any locations on the finding, or added on the form - self.process_locations(finding, self.endpoints_to_add) - else: - # TODO: Delete this after the move to Locations - # Process any endpoints on the finding, or added on the form - self.process_endpoints(finding, self.endpoints_to_add) + self.process_locations(finding, self.endpoints_to_add) # Parsers must use unsaved_tags to store tags, so we can clean them. # Accumulate for bulk application after the loop (O(unique_tags) instead of O(N·T)). cleaned_tags = clean_tags(finding.unsaved_tags) @@ -267,16 +260,13 @@ def process_findings( logger.debug("process_findings: computed push_to_jira=%s", push_to_jira) batch_finding_ids.append(finding.id) - # If batch is full or we're at the end, persist endpoints and dispatch + # If batch is full or we're at the end, persist locations/endpoints and dispatch if len(batch_finding_ids) >= batch_max_size or is_final_finding: - if not settings.V3_FEATURE_LOCATIONS: - self.endpoint_manager.persist(user=self.user) - + self.location_handler.persist() # Apply parser-supplied tags for this batch before post-processing starts, # so rules/deduplication tasks see the tags already on the findings. bulk_apply_parser_tags(findings_with_parser_tags) findings_with_parser_tags.clear() - finding_ids_batch = list(batch_finding_ids) batch_finding_ids.clear() logger.debug("process_findings: dispatching batch with push_to_jira=%s (batch_size=%d, is_final=%s)", @@ -404,9 +394,8 @@ def close_old_findings( finding_groups_enabled=self.findings_groups_enabled, product_grading_option=False, ) - # Persist any accumulated endpoint status mitigations - if not settings.V3_FEATURE_LOCATIONS: - self.endpoint_manager.persist(user=self.user) + # Persist any accumulated location/endpoint status changes + self.location_handler.persist() # push finding groups to jira since we only only want to push whole groups # We dont check if the finding jira sync is applicable quite yet until we can get in the loop # but this is a way to at least make it that far diff --git a/dojo/importers/default_reimporter.py b/dojo/importers/default_reimporter.py index 2a22da10a35..b63b7134701 100644 --- a/dojo/importers/default_reimporter.py +++ b/dojo/importers/default_reimporter.py @@ -15,10 +15,9 @@ find_candidates_for_reimport_legacy, ) from dojo.importers.base_importer import BaseImporter, Parser -from dojo.importers.endpoint_manager import EndpointManager +from dojo.importers.base_location_manager import LocationHandler from dojo.importers.options import ImporterOptions from dojo.jira_link.helper import is_keep_in_sync_with_jira -from dojo.location.status import FindingLocationStatus from dojo.models import ( Development_Environment, Finding, @@ -82,8 +81,7 @@ def __init__(self, *args, **kwargs): import_type=Test_Import.REIMPORT_TYPE, **kwargs, ) - if not settings.V3_FEATURE_LOCATIONS: - self.endpoint_manager = EndpointManager(self.test.engagement.product) + self.location_handler = LocationHandler(self.test.engagement.product) def process_scan( self, @@ -338,13 +336,7 @@ def process_findings( # Set the service supplied at import time if self.service is not None: unsaved_finding.service = self.service - if settings.V3_FEATURE_LOCATIONS: - # Clean any locations that are on the finding - self.location_manager.clean_unsaved_locations(unsaved_finding.unsaved_locations) - else: - # TODO: Delete this after the move to Locations - # Clean any endpoints that are on the finding - self.endpoint_manager.clean_unsaved_endpoints(unsaved_finding.unsaved_endpoints) + self.location_handler.clean_unsaved(unsaved_finding) # Calculate the hash code to be used to identify duplicates unsaved_finding.hash_code = self.calculate_unsaved_finding_hash_code(unsaved_finding) deduplicationLogger.debug(f"unsaved finding's hash_code: {unsaved_finding.hash_code}") @@ -380,27 +372,15 @@ def process_findings( continue # Update endpoints on the existing finding with those on the new finding if finding.dynamic_finding: - if settings.V3_FEATURE_LOCATIONS: - logger.debug( - "Re-import found an existing dynamic finding for this new " - "finding. Checking the status of locations", - ) - self.location_manager.update_location_status( - existing_finding, - unsaved_finding, - self.user, - ) - else: - # TODO: Delete this after the move to Locations - logger.debug( - "Re-import found an existing dynamic finding for this new " - "finding. Checking the status of endpoints", - ) - self.endpoint_manager.update_endpoint_status( - existing_finding, - unsaved_finding, - self.user, - ) + logger.debug( + "Re-import found an existing dynamic finding for this new " + "finding. Checking the status of locations/endpoints", + ) + self.location_handler.update_status( + existing_finding, + unsaved_finding, + self.user, + ) else: finding, finding_will_be_grouped = self.process_finding_that_was_not_matched(unsaved_finding) @@ -441,14 +421,11 @@ def process_findings( # - Deduplication batches: optimize bulk operations (larger batches = fewer queries) # They don't need to be aligned since they optimize different operations. if len(batch_finding_ids) >= dedupe_batch_max_size or is_final: - if not settings.V3_FEATURE_LOCATIONS: - self.endpoint_manager.persist(user=self.user) - + self.location_handler.persist() # Apply parser-supplied tags for this batch before post-processing starts, # so rules/deduplication tasks see the tags already on the findings. bulk_apply_parser_tags(findings_with_parser_tags) findings_with_parser_tags.clear() - finding_ids_batch = list(batch_finding_ids) batch_finding_ids.clear() dojo_dispatch_task( @@ -555,9 +532,8 @@ def close_old_findings( product_grading_option=False, ) mitigated_findings.append(finding) - # Persist any accumulated endpoint status mitigations - if not settings.V3_FEATURE_LOCATIONS: - self.endpoint_manager.persist(user=self.user) + # Persist any accumulated location/endpoint status changes + self.location_handler.persist() # push finding groups to jira since we only only want to push whole groups # We dont check if the finding jira sync is applicable quite yet until we can get in the loop # but this is a way to at least make it that far @@ -818,16 +794,7 @@ def process_matched_mitigated_finding( note = Notes(entry=f"Re-activated by {self.scan_type} re-upload.", author=self.user) note.save() - if settings.V3_FEATURE_LOCATIONS: - # Reactivate mitigated locations - mitigated_locations = existing_finding.locations.filter(status=FindingLocationStatus.Mitigated) - self.location_manager.chunk_locations_and_reactivate(mitigated_locations) - else: - # TODO: Delete this after the move to Locations - # Accumulate endpoint statuses for bulk reactivation in persist() - self.endpoint_manager.record_statuses_to_reactivate( - self.endpoint_manager.get_non_special_endpoint_statuses(existing_finding), - ) + self.location_handler.record_reactivations_for_finding(existing_finding) existing_finding.notes.add(note) self.reactivated_items.append(existing_finding) # The new finding is active while the existing on is mitigated. The existing finding needs to @@ -991,19 +958,10 @@ def finding_post_processing( Save all associated objects to the finding after it has been saved for the purpose of foreign key restrictions """ - if settings.V3_FEATURE_LOCATIONS: - self.location_manager.chunk_locations_and_disperse(finding, finding_from_report.unsaved_locations) - if len(self.endpoints_to_add) > 0: - self.location_manager.chunk_locations_and_disperse(finding, self.endpoints_to_add) - else: - # TODO: Delete this after the move to Locations - for endpoint in finding_from_report.unsaved_endpoints: - key = self.endpoint_manager.record_endpoint(endpoint) - self.endpoint_manager.record_status_for_create(finding, key) - if len(self.endpoints_to_add) > 0: - for endpoint in self.endpoints_to_add: - key = self.endpoint_manager.record_endpoint(endpoint) - self.endpoint_manager.record_status_for_create(finding, key) + # Copy unsaved items from the parser output onto the saved finding so record_for_finding can read them + finding.unsaved_locations = getattr(finding_from_report, "unsaved_locations", []) + finding.unsaved_endpoints = getattr(finding_from_report, "unsaved_endpoints", []) + self.location_handler.record_for_finding(finding, self.endpoints_to_add or None) # For matched/existing findings, do not update tags from the report, # consistent with how other fields are handled on reimport. if not is_matched_finding: diff --git a/dojo/importers/endpoint_manager.py b/dojo/importers/endpoint_manager.py index c909f921201..21e9673ed8c 100644 --- a/dojo/importers/endpoint_manager.py +++ b/dojo/importers/endpoint_manager.py @@ -6,6 +6,7 @@ from django.utils import timezone from hyperlink._url import SCHEME_PORT_MAP # noqa: PLC2701 +from dojo.importers.base_location_manager import BaseLocationManager from dojo.models import ( Dojo_User, Endpoint, @@ -30,13 +31,13 @@ class EndpointUniqueKey(NamedTuple): # TODO: Delete this after the move to Locations -class EndpointManager: +class EndpointManager(BaseLocationManager): def __init__(self, product: Product) -> None: self._product = product self._endpoints_to_create: dict[EndpointUniqueKey, dict] = {} self._statuses_to_create: list[tuple[Finding, EndpointUniqueKey]] = [] - self._statuses_to_mitigate: list[Endpoint_Status] = [] + self._statuses_to_mitigate: list[tuple[Endpoint_Status, Dojo_User | None]] = [] self._statuses_to_reactivate: list[Endpoint_Status] = [] @staticmethod @@ -158,13 +159,13 @@ def update_endpoint_status( new_finding_endpoints_list = new_finding.unsaved_endpoints if new_finding.is_mitigated: # New finding is mitigated, so mitigate all old endpoints - self._statuses_to_mitigate.extend(existing_finding_endpoint_status_list) + self._statuses_to_mitigate.extend((eps, user) for eps in existing_finding_endpoint_status_list) else: # Convert to set for O(1) lookups instead of O(n) linear search new_finding_endpoints_set = set(new_finding_endpoints_list) # Mitigate any endpoints in the old finding not found in the new finding self._statuses_to_mitigate.extend( - eps for eps in existing_finding_endpoint_status_list + (eps, user) for eps in existing_finding_endpoint_status_list if eps.endpoint not in new_finding_endpoints_set ) # Re-activate any endpoints in the old finding that are in the new finding @@ -177,9 +178,9 @@ def record_statuses_to_reactivate(self, statuses: list[Endpoint_Status]) -> None """Accumulate endpoint statuses for bulk reactivation in persist().""" self._statuses_to_reactivate.extend(statuses) - def record_statuses_to_mitigate(self, statuses: list[Endpoint_Status]) -> None: + def record_statuses_to_mitigate(self, statuses: list[Endpoint_Status], user: Dojo_User | None = None) -> None: """Accumulate endpoint statuses for bulk mitigation in persist().""" - self._statuses_to_mitigate.extend(statuses) + self._statuses_to_mitigate.extend((eps, user) for eps in statuses) def get_or_create_endpoints(self) -> tuple[dict[EndpointUniqueKey, Endpoint], list[Endpoint]]: """ @@ -238,7 +239,7 @@ def get_or_create_endpoints(self) -> tuple[dict[EndpointUniqueKey, Endpoint], li self._endpoints_to_create.clear() return endpoints_by_key, created - def persist(self, user: Dojo_User | None = None) -> None: + def persist(self) -> None: """ Persist all accumulated endpoint operations to the database. @@ -266,11 +267,11 @@ def persist(self, user: Dojo_User | None = None) -> None: if self._statuses_to_mitigate: now = timezone.now() to_update = [] - for endpoint_status in self._statuses_to_mitigate: + for endpoint_status, mitigated_by in self._statuses_to_mitigate: if endpoint_status.mitigated is False: endpoint_status.mitigated_time = now endpoint_status.last_modified = now - endpoint_status.mitigated_by = user + endpoint_status.mitigated_by = mitigated_by endpoint_status.mitigated = True to_update.append(endpoint_status) if to_update: @@ -300,3 +301,46 @@ def persist(self, user: Dojo_User | None = None) -> None: batch_size=1000, ) self._statuses_to_reactivate.clear() + + # ------------------------------------------------------------------ + # Unified interface (shared with LocationManager) + # ------------------------------------------------------------------ + + def clean_unsaved(self, finding: Finding) -> None: + """Clean the unsaved endpoints on this finding.""" + self.clean_unsaved_endpoints(finding.unsaved_endpoints) + + def record_for_finding(self, finding: Finding, extra_locations: list[Endpoint] | None = None) -> None: + """Record endpoints from the finding + any form-added extras for later batch creation.""" + self.clean_unsaved_endpoints(finding.unsaved_endpoints) + for endpoint in finding.unsaved_endpoints: + key = self.record_endpoint(endpoint) + self.record_status_for_create(finding, key) + if extra_locations: + for endpoint in extra_locations: + key = self.record_endpoint(endpoint) + self.record_status_for_create(finding, key) + + def update_status(self, existing_finding: Finding, new_finding: Finding, user: Dojo_User) -> None: + """Accumulate status changes (mitigate/reactivate) based on old vs new finding.""" + self.update_endpoint_status(existing_finding, new_finding, user) + + def record_reactivations_for_finding(self, finding: Finding) -> None: + """Record endpoint statuses on this finding for reactivation.""" + self.record_statuses_to_reactivate(self.get_non_special_endpoint_statuses(finding)) + + def record_mitigations_for_finding(self, finding: Finding, user: Dojo_User) -> None: + """Record endpoint statuses on this finding for mitigation.""" + self.record_statuses_to_mitigate(finding.status_finding.all(), user) + + def get_locations_for_tagging(self, findings: list[Finding]): + """Return queryset of locations to apply tags to.""" + return Endpoint.objects.filter(finding__in=findings).distinct() + + def get_location_tag_fallback(self, finding: Finding): + """Return iterable of taggable locations for per-instance fallback.""" + return finding.endpoints.all() + + def serialize_extra_locations(self, locations: list) -> dict: + """Serialize extra locations for import history.""" + return {"endpoints": [str(ep) for ep in locations]} if locations else {} diff --git a/dojo/importers/location_manager.py b/dojo/importers/location_manager.py index 125ff922dc0..94999538fc1 100644 --- a/dojo/importers/location_manager.py +++ b/dojo/importers/location_manager.py @@ -1,50 +1,420 @@ +from __future__ import annotations + import logging -from typing import TypeVar +from itertools import groupby +from operator import itemgetter +from typing import TYPE_CHECKING, NamedTuple from django.core.exceptions import ValidationError -from django.db.models import QuerySet +from django.db import transaction +from django.db.models import signals from django.utils import timezone -from dojo.celery import app -from dojo.celery_dispatch import dojo_dispatch_task -from dojo.location.models import AbstractLocation, LocationFindingReference -from dojo.location.status import FindingLocationStatus -from dojo.models import ( - Dojo_User, - Finding, -) +from dojo.importers.base_location_manager import BaseLocationManager +from dojo.location.models import AbstractLocation, Location, LocationFindingReference, LocationProductReference +from dojo.location.status import FindingLocationStatus, ProductLocationStatus +from dojo.models import Product, _manage_inherited_tags +from dojo.tags_signals import make_inherited_tags_sticky from dojo.tools.locations import LocationData from dojo.url.models import URL +from dojo.utils import get_system_setting + +if TYPE_CHECKING: + from tagulous.models import TagField + + from dojo.models import Dojo_User, Finding logger = logging.getLogger(__name__) -# TypeVar to represent unsaved locations coming from parsers. These might be existing AbstractLocations (when linking +# Unsaved locations coming from parsers. These might be existing AbstractLocations (when linking # existing endpoints) or LocationData objects sent by the parser. -UnsavedLocation = TypeVar("UnsavedLocation", LocationData, AbstractLocation) +UnsavedLocation = LocationData | AbstractLocation + + +# Entry for status update; status will be determined by comparing locations between existing and new findings +class StatusUpdateEntry(NamedTuple): + existing_finding: Finding + new_finding: Finding + user: Dojo_User + + +class LocationManager(BaseLocationManager): + + def __init__(self, product: Product) -> None: + super().__init__(product) + # Maps findings to a list of cleaned locations + self._locations_by_finding: dict[Finding, list[AbstractLocation]] = {} + # All locations needing product refs (finding-associated + product-only), cleaned at record time + self._product_locations: list[AbstractLocation] = [] + # Status update entries, which we'll use at persist-time to determine Location statuses by comparing + # existing vs new finding entries. + self._status_updates: list[StatusUpdateEntry] = [] + # IDs of finding refs to reactivate + self._refs_to_reactivate: list[int] = [] + # IDs of finding refs to mitigate, by the associated user + self._refs_to_mitigate: dict[int, Dojo_User] = {} + + # ------------------------------------------------------------------ + # Accumulation methods (no DB hits) + # ------------------------------------------------------------------ + + def record_locations_for_finding( + self, + finding: Finding, + locations: list[UnsavedLocation], + ) -> None: + """Record locations to be associated with a finding (and its product). Flushed by persist().""" + if locations: + cleaned = self.clean_unsaved_locations(locations) + self._locations_by_finding.setdefault(finding, []).extend(cleaned) + self._product_locations.extend(cleaned) + + def update_location_status( + self, + existing_finding: Finding, + new_finding: Finding, + user: Dojo_User, + ) -> None: + """Defer status update to persist(). No DB access at record time.""" + self._status_updates.append(StatusUpdateEntry(existing_finding, new_finding, user)) + + # ------------------------------------------------------------------ + # Unified interface (shared with EndpointManager) + # ------------------------------------------------------------------ + + def clean_unsaved(self, finding: Finding) -> None: + """Clean the unsaved locations on this finding.""" + self.clean_unsaved_locations(finding.unsaved_locations) + + def record_for_finding(self, finding: Finding, extra_locations: list[UnsavedLocation] | None = None) -> None: + """Record locations from the finding + any form-added extras for later batch creation.""" + self.record_locations_for_finding(finding, finding.unsaved_locations) + if extra_locations: + self.record_locations_for_finding(finding, extra_locations) + + def update_status(self, existing_finding: Finding, new_finding: Finding, user: Dojo_User) -> None: + """Accumulate status changes (mitigate/reactivate) based on old vs new finding.""" + self.update_location_status(existing_finding, new_finding, user) + + def record_reactivations_for_finding(self, finding: Finding) -> None: + """Defer reactivation to persist(). No DB access at record time.""" + self._refs_to_reactivate.append(finding.id) + + def record_mitigations_for_finding(self, finding: Finding, user: Dojo_User) -> None: + """Defer mitigation to persist(). No DB access at record time.""" + self._refs_to_mitigate[finding.id] = user + + def get_locations_for_tagging(self, findings: list[Finding]): + """Return queryset of locations to apply tags to.""" + return Location.objects.filter(findings__finding__in=findings).distinct() + + def get_location_tag_fallback(self, finding: Finding): + """Return iterable of taggable locations for per-instance fallback.""" + return [ref.location for ref in finding.locations.all()] + + def serialize_extra_locations(self, locations: list) -> dict: + """Serialize extra locations for import history.""" + return {"locations": [str(loc) for loc in locations]} if locations else {} + + # ------------------------------------------------------------------ + # Persist — flush all accumulated operations to DB + # ------------------------------------------------------------------ + + def persist(self) -> None: + """Persist all accumulated location operations to the database.""" + with transaction.atomic(): + self._persist_locations() + self._persist_status_updates() + + def _persist_locations(self) -> None: + """Bulk get/create all locations and their finding/product refs.""" + # _product_locations contains all locations to persist: associate with finding -> associate with product + if not self._product_locations: + return + + # Locations are already cleaned at record time (in record_locations_for_finding). Deduplicate the + # full set — _product_locations is the superset of all locations (finding-associated + product-only). + all_locations = list({(type(loc), loc.identity_hash): loc for loc in self._product_locations}.values()) + if not all_locations: + self._locations_by_finding.clear() + self._product_locations.clear() + return + + # Bulk persist all locations to the database + saved = self._bulk_get_or_create_locations(all_locations) + + # Build a lookup from (type, identity_hash) -> saved location for finding ref creation. + # identity_hash is only unique per concrete type, so we key by both. + saved_by_key: dict[tuple[type, str], AbstractLocation] = { + (type(loc), loc.identity_hash): loc for loc in saved + } + + # Lists for bulk creation + all_finding_refs = [] + all_product_refs = [] + # List of all location IDs, for querying existing refs + all_location_ids = [loc.location_id for loc in saved] + + # Determine necessary product refs to create + existing_product_refs: set[int] = set( + LocationProductReference.objects.filter( + location_id__in=all_location_ids, + product=self._product, + ).values_list("location_id", flat=True), + ) + for location in saved: + if location.location_id not in existing_product_refs: + assoc = location.get_association_data() + all_product_refs.append(LocationProductReference( + location_id=location.location_id, + product=self._product, + status=ProductLocationStatus.Active, + relationship=assoc.relationship_type, + relationship_data=assoc.relationship_data, + )) + existing_product_refs.add(location.location_id) + + # Determine necessary finding refs to create + if self._locations_by_finding: + all_finding_ids = [finding.id for finding in self._locations_by_finding] + # Strictly speaking this returns more rows than we need (it's the cross of the location/finding lists rather + # than scoped per-finding), but more straightforward than constructing a per-finding lookup. We won't create + # any unwanted associations below anyway. + existing_finding_ref_keys: set[tuple[int, int]] = set( + LocationFindingReference.objects.filter( + location_id__in=all_location_ids, + finding_id__in=all_finding_ids, + ).values_list("finding_id", "location_id"), + ) + + for finding, cleaned_locations in self._locations_by_finding.items(): + # Locations were already cleaned at record time — identity_hash is set, so we can + # look up the persisted location directly. + for location in cleaned_locations: + saved_loc = saved_by_key[type(location), location.identity_hash] + finding_ref_key = (finding.id, saved_loc.location_id) + if finding_ref_key not in existing_finding_ref_keys: + assoc = saved_loc.get_association_data() + all_finding_refs.append(LocationFindingReference( + location_id=saved_loc.location_id, + finding=finding, + status=FindingLocationStatus.Active, + relationship=assoc.relationship_type, + relationship_data=assoc.relationship_data, + )) + existing_finding_ref_keys.add(finding_ref_key) + + # Bulk create references + if all_finding_refs: + LocationFindingReference.objects.bulk_create( + all_finding_refs, batch_size=1000, ignore_conflicts=True, + ) + if all_product_refs: + LocationProductReference.objects.bulk_create( + all_product_refs, batch_size=1000, ignore_conflicts=True, + ) + + # Trigger bulk tag inheritance + self._bulk_inherit_tags(loc.location for loc in saved) + + # Clear accumulators + self._locations_by_finding.clear() + self._product_locations.clear() + + def _persist_status_updates(self) -> None: + """ + Bulk persist finding/product ref statuses. + + Throughout the (re)import process, we've tracked three types of status changes: locations to mitigate, locations + to reactivate, and locations whose statuses need to be evaluated at this time by comparing locations between + existing findings and new findings. + + To start, this method processes the comparisons between existing/new findings. If the new finding is Mitigated, + then all existing locations are added to the 'to mitigate' set. Otherwise, locations that are in both the new + finding and existing finding are added to the 'to reactivate' set, and locations that are on the existing + finding but not the new finding are added to the 'to mitigate' set. + + Next, all locations in the 'to reactivate' set are bulk set to Active, and all locations in the 'to mitigate' + set are bulk set to Mitigated. + + Finally, product associations are updated: if any location associated with a finding on this product is Active, + the LocationProductReference object is set to Active; otherwise, it is set to Mitigated. + """ + # Short-circuit if nothing to do + if not (self._status_updates or self._refs_to_reactivate or self._refs_to_mitigate): + return + + # List of statuses we'll skip processing changes for + special_statuses = [ + FindingLocationStatus.FalsePositive, + FindingLocationStatus.RiskAccepted, + FindingLocationStatus.OutOfScope, + ] + + # The set of LocationFindingReference IDs to reactivate + ref_ids_to_reactivate: set[int] = set() + # The set of LocationFindingReference IDs to mitigate, and the user to associate with it + ref_ids_to_mitigate: dict[int, Dojo_User] = {} + + # Process status updates determined by comparing existing/new findings + if self._status_updates: + # Look up all the existing LocationFindingReference objects and store per-Finding + existing_finding_ids = {upd.existing_finding.id for upd in self._status_updates} + refs_by_finding: dict[int, list[LocationFindingReference]] = {} + for ref in ( + LocationFindingReference.objects + .filter(finding_id__in=existing_finding_ids) + .exclude(status__in=special_statuses) + .select_related("location") + ): + refs_by_finding.setdefault(ref.finding_id, []).append(ref) + + # Next: for each StatusUpdateEntry, determine what we should do with the existing refs + for existing_finding, new_finding, user in self._status_updates: + finding_refs = refs_by_finding.get(existing_finding.id, []) + if new_finding.is_mitigated: + # The new finding is mitigated, so mitigate all existing (non-special) refs + ref_ids_to_mitigate.update({r.id: user for r in finding_refs}) + else: + # The new finding is not mitigated; we need to reactivate locations that are in the new finding and + # mitigate statuses that are NOT in the new finding. + new_loc_values = { + str(loc) for loc in self.clean_unsaved_locations(new_finding.unsaved_locations) + } + for ref in finding_refs: + if ref.location.location_value in new_loc_values: + ref_ids_to_reactivate.add(ref.id) + else: + ref_ids_to_mitigate[ref.id] = user + + # Update the "reactivate set" with the IDs of existing LocationFindingReference objects we need to reactivate + if self._refs_to_reactivate: + ref_ids_to_reactivate.update( + LocationFindingReference.objects.filter( + finding_id__in=self._refs_to_reactivate, + status=FindingLocationStatus.Mitigated, + ).values_list("id", flat=True), + ) + + # Update the "mitigate set" with the IDs of existing LocationFindingReference objects we need to mitigate. + # Note we exclude LocationFindingReferences that currently have one of the special statuses. + if self._refs_to_mitigate: + ref_ids_to_mitigate.update({ + ref_id: self._refs_to_mitigate[finding_id] + for ref_id, finding_id in LocationFindingReference.objects.filter( + finding_id__in=self._refs_to_mitigate.keys(), + ).exclude(status__in=special_statuses).values_list("id", "finding_id") + }) + + # Hoorah we finally get around to actually updating stuff + now = timezone.now() + # Track all updated LocationFindingReference IDs so we can update the corresponding LocationProductReferences + # as necessary: if any LocationFindingReference is Active, the LocationProductReferences will be set to Active; + # otherwise, they will be set to Mitigated. + all_affected_ref_ids: set[int] = set() + + # Update Mitigated => Active ("reactivate") + if ref_ids_to_reactivate: + LocationFindingReference.objects.filter( + id__in=ref_ids_to_reactivate, + status=FindingLocationStatus.Mitigated, + ).update( + auditor=None, + audit_time=now, + status=FindingLocationStatus.Active, + ) + all_affected_ref_ids |= ref_ids_to_reactivate + + # Update ~Mitigated => Mitigated + if ref_ids_to_mitigate: + # Flip (ref_id -> user) to (user -> set[ref_id]) for per-user bulk updates + ref_ids_per_user: dict[Dojo_User, set[int]] = {} + for ref_id, user in ref_ids_to_mitigate.items(): + ref_ids_per_user.setdefault(user, set()).add(ref_id) + # Update per user + for user, ref_ids in ref_ids_per_user.items(): + LocationFindingReference.objects.filter( + id__in=ref_ids, + ).exclude( + status=FindingLocationStatus.Mitigated, + ).update( + auditor=user, + audit_time=now, + status=FindingLocationStatus.Mitigated, + ) + all_affected_ref_ids |= ref_ids + + # Propagate to product refs: if any finding ref for this location on this product is Active, product ref is + # Active; otherwise Mitigated. + if all_affected_ref_ids: + # Grab the location IDs for all the LocationFindingReferences we updated + affected_location_ids = set( + LocationFindingReference.objects.filter( + id__in=all_affected_ref_ids, + ).values_list("location_id", flat=True), + ) + # Look up all affected LocationFindingReferences that are now Active and associated with this product + # through the "finding.test.engagement.product" chain + locations_still_active = set( + LocationFindingReference.objects.filter( + location_id__in=affected_location_ids, + finding__test__engagement__product=self._product, + status=FindingLocationStatus.Active, + ).values_list("location_id", flat=True), + ) + # Diff the two; this leaves IDs of locations that should be set to Mitigated at the product level + locations_now_mitigated = affected_location_ids - locations_still_active + # Update LocationProductReferences to Active for any locations associated with this product that have an + # Active LocationFindingReference + if locations_still_active: + LocationProductReference.objects.filter( + location_id__in=locations_still_active, + product=self._product, + ).exclude(status=ProductLocationStatus.Active).update( + status=ProductLocationStatus.Active, + ) + # Update LocationProductReferences to Mitigated for any locations associated with this product that have no + # Active LocationFindingReferences + if locations_now_mitigated: + LocationProductReference.objects.filter( + location_id__in=locations_now_mitigated, + product=self._product, + ).exclude(status=ProductLocationStatus.Mitigated).update( + status=ProductLocationStatus.Mitigated, + ) + + # Clear accumulators + self._status_updates.clear() + self._refs_to_reactivate.clear() + self._refs_to_mitigate.clear() + + # ------------------------------------------------------------------ + # Type registry + # ------------------------------------------------------------------ -# test_notifications.py: Implement Locations -class LocationManager: @classmethod - def get_or_create_location(cls, unsaved_location: AbstractLocation) -> AbstractLocation | None: - """Gets/creates the given AbstractLocation.""" - if isinstance(unsaved_location, URL): - return URL.get_or_create_from_object(unsaved_location) - logger.debug(f"IMPORT_SCAN: Unsupported location type: {type(unsaved_location)}") - return None + def get_supported_location_types(cls) -> dict[str, type[AbstractLocation]]: + """Return a mapping of location type string to AbstractLocation subclass.""" + return {URL.get_location_type(): URL} + + # ------------------------------------------------------------------ + # Cleaning / conversion utilities + # ------------------------------------------------------------------ @classmethod def make_abstract_locations(cls, locations: list[UnsavedLocation]) -> list[AbstractLocation]: """Converts the list of unsaved locations (AbstractLocation/LocationData objects) to a list of AbstractLocations.""" + supported_types = cls.get_supported_location_types() abstract_locations = [] for location in locations: if isinstance(location, AbstractLocation): abstract_locations.append(location) - elif isinstance(location, LocationData) and location.type == URL.get_location_type(): + elif isinstance(location, LocationData) and (loc_cls := supported_types.get(location.type)): try: - abstract_locations.append(URL.from_location_data(location)) + abstract_locations.append(loc_cls.from_location_data(location)) except (ValidationError, ValueError): logger.debug("Skipping invalid location data: %s", location) else: @@ -52,72 +422,6 @@ def make_abstract_locations(cls, locations: list[UnsavedLocation]) -> list[Abstr return abstract_locations - @classmethod - def _add_locations_to_unsaved_finding( - cls, - finding: Finding, - locations: list[UnsavedLocation], - **kwargs: dict, # noqa: ARG003 - ) -> None: - """Creates AbstractLocation objects from the given list and links them to the given finding.""" - locations = cls.clean_unsaved_locations(locations) - - logger.debug(f"IMPORT_SCAN: Adding {len(locations)} locations to finding: {finding}") - - # LOCATION LOCATION LOCATION - # TODO: bulk create the finding/product refs... - locations_saved = 0 - for unsaved_location in locations: - if saved_location := cls.get_or_create_location(unsaved_location): - locations_saved += 1 - association_data = unsaved_location.get_association_data() - saved_location.location.associate_with_finding( - finding, - status=FindingLocationStatus.Active, - relationship=association_data.relationship_type, - relationship_data=association_data.relationship_data, - ) - - logger.debug(f"IMPORT_SCAN: {locations_saved} locations imported") - - @app.task - def add_locations_to_unsaved_finding( - manager_cls_path: str, # noqa: N805 - finding: Finding, - locations: list[UnsavedLocation], - **kwargs: dict, - ) -> None: - """Celery task that resolves the LocationManager class and delegates to _add_locations_to_unsaved_finding.""" - from django.utils.module_loading import import_string # noqa: PLC0415 - - manager_cls = import_string(manager_cls_path) - manager_cls._add_locations_to_unsaved_finding(finding, locations, **kwargs) - - @app.task - def mitigate_location_status( - location_refs: QuerySet[LocationFindingReference], # noqa: N805 - user: Dojo_User, - **kwargs: dict, - ) -> None: - """Mitigate all given (non-mitigated) location refs""" - location_refs.exclude(status=FindingLocationStatus.Mitigated).update( - auditor=user, - audit_time=timezone.now(), - status=FindingLocationStatus.Mitigated, - ) - - @app.task - def reactivate_location_status( - location_refs: QuerySet[LocationFindingReference], # noqa: N805 - **kwargs: dict, - ) -> None: - """Reactivate all given (mitigated) locations refs""" - location_refs.filter(status=FindingLocationStatus.Mitigated).update( - auditor=None, - audit_time=timezone.now(), - status=FindingLocationStatus.Active, - ) - @classmethod def clean_unsaved_locations( cls, @@ -135,58 +439,140 @@ def clean_unsaved_locations( logger.warning("DefectDojo is storing broken locations because cleaning wasn't successful: %s", e) return locations - def update_location_status( - self, - existing_finding: Finding, - new_finding: Finding, - user: Dojo_User, - **kwargs: dict, - ) -> None: - """Update the list of locations from the new finding with the list that is in the old finding""" - # New endpoints are already added in serializers.py / views.py (see comment "# for existing findings: make sure endpoints are present or created") - # So we only need to mitigate endpoints that are no longer present - existing_location_refs: QuerySet[LocationFindingReference] = existing_finding.locations.exclude( - status__in=[ - FindingLocationStatus.FalsePositive, - FindingLocationStatus.RiskAccepted, - FindingLocationStatus.OutOfScope, - ], - ) - if new_finding.is_mitigated: - # New finding is mitigated, so mitigate all existing location refs - self.chunk_locations_and_mitigate(existing_location_refs, user) - else: - new_locations_values = [str(location) for location in type(self).clean_unsaved_locations(new_finding.unsaved_locations)] - # Reactivate endpoints in the old finding that are in the new finding - location_refs_to_reactivate = existing_location_refs.filter(location__location_value__in=new_locations_values) - # Mitigate endpoints in the existing finding not in the new finding - location_refs_to_mitigate = existing_location_refs.exclude(location__location_value__in=new_locations_values) - - self.chunk_locations_and_reactivate(location_refs_to_reactivate) - self.chunk_locations_and_mitigate(location_refs_to_mitigate, user) - - def chunk_locations_and_disperse( - self, - finding: Finding, - locations: list[UnsavedLocation], - **kwargs: dict, - ) -> None: + # ------------------------------------------------------------------ + # Bulk internals + # ------------------------------------------------------------------ + + @classmethod + def _bulk_get_or_create_locations(cls, locations: list[AbstractLocation]) -> list[AbstractLocation]: + """ + Bulk get-or-create a (possibly heterogeneous) list of AbstractLocations. + + The input list may contain a mix of AbstractLocation instances. This method + groups them by concrete type, delegates each group to that type's bulk_get_or_create, + then reassembles results in the original input order. + """ + if not locations: + return [] + + # Keying function: group by the (Python) identity of the concrete class (e.g., URL vs Dependency). + # Using id() because class objects aren't sortable. + def type_id(x: tuple[int, AbstractLocation]) -> int: + return id(type(x[1])) + + saved = [] + # Sort by type, tracking the original index via enumerate so we can restore order later + locations_with_idx = sorted(enumerate(locations), key=type_id) + # Now group by type + locations_by_type = groupby(locations_with_idx, key=type_id) + for _, grouped_locations_with_idx in locations_by_type: + # Split into parallel lists: original indices and the homogeneous location objects + indices, grouped_locations = zip(*grouped_locations_with_idx, strict=True) + # Determine the concrete AbstractLocation subclass (URL, Dependency, etc.) + loc_cls = type(grouped_locations[0]) + # Delegate to the per-type bulk_get_or_create on AbstractLocation + saved_locations = loc_cls.bulk_get_or_create(grouped_locations) + # Pair each result back with its original index + saved.extend((idx, saved_loc) for idx, saved_loc in zip(indices, saved_locations, strict=True)) + + # Restore the original input ordering + saved.sort(key=itemgetter(0)) + return [loc for _, loc in saved] + + # ------------------------------------------------------------------ + # Tag inheritance + # ------------------------------------------------------------------ + + def _bulk_inherit_tags(self, locations): + """ + Bulk equivalent of calling inherit_instance_tags(loc) for many Locations. Actually persisting updates is handled + by a per-location call to _manage_inherited_tags(), but at least determining what the tags are is more efficient + (plus we can skip locations that don't need an update at all). + + When tag inheritance is enabled, computes the target inherited tags for each location from all related products + and updates only locations that are out of sync. + """ + locations = list(locations) if not locations: return - cls_path = f"{type(self).__module__}.{type(self).__qualname__}" - dojo_dispatch_task(self.add_locations_to_unsaved_finding, cls_path, finding, locations, sync=True) - def chunk_locations_and_reactivate( - self, - location_refs: QuerySet[LocationFindingReference], - **kwargs: dict, - ) -> None: - dojo_dispatch_task(self.reactivate_location_status, location_refs, sync=True) + # Check whether tag inheritance is enabled at either the product level or system-wide; quit early if neither + product_inherit = getattr(self._product, "enable_product_tag_inheritance", False) + system_wide_inherit = bool(get_system_setting("enable_product_tag_inheritance")) + if not system_wide_inherit and not product_inherit: + return - def chunk_locations_and_mitigate( - self, - location_refs: QuerySet[LocationFindingReference], - user: Dojo_User, - **kwargs: dict, - ) -> None: - dojo_dispatch_task(self.mitigate_location_status, location_refs, user, sync=True) + # A location can be shared across multiple products. Its inherited tags should be the union of + # tags from ALL contributing products, not just the one running this import. + location_ids = [loc.id for loc in locations] + product_ids_by_location: dict[int, set[int]] = {loc.id: set() for loc in locations} + + # Find associations through LocationProductReference entries + for loc_id, prod_id in LocationProductReference.objects.filter( + location_id__in=location_ids, + ).values_list("location_id", "product_id"): + product_ids_by_location[loc_id].add(prod_id) + + # Find associations through LocationFindingReference entries and the finding.test.engagement.product chain. + # This shouldn't add anything new, but just in case. + for loc_id, prod_id in ( + LocationFindingReference.objects + .filter(location_id__in=location_ids) + .values_list("location_id", "finding__test__engagement__product_id") + ): + product_ids_by_location[loc_id].add(prod_id) + + # Fetch all products that will contribute to tag inheritance, and their tags + all_product_ids = {pid for pids in product_ids_by_location.values() for pid in pids} + product_qs = Product.objects.filter(id__in=all_product_ids).prefetch_related("tags") + if not system_wide_inherit: + # Product-level inheritance only + product_qs = product_qs.filter(enable_product_tag_inheritance=True) + # Materialize into a dict for ease of use + products: dict[int, Product] = {p.id: p for p in product_qs} + # Get distinct tags, per-product + tags_by_product: dict[int, set[str]] = { + pid: {t.name for t in p.tags.all()} + for pid, p in products.items() + } + + # Helper method for getting all tags from the given TagField + def _get_tags(tags_field: TagField) -> dict[int, set[str]]: + through_model = tags_field.through + fk_name = tags_field.field.m2m_reverse_field_name() + tags_by_location: dict[int, set[str]] = {loc.id: set() for loc in locations} + for l_id, t_name in through_model.objects.filter( + location_id__in=location_ids, + ).values_list("location_id", f"{fk_name}__name"): + tags_by_location[l_id].add(t_name) + return tags_by_location + + # Gather inherited and 'regular' tags per location + existing_inherited_by_location: dict[int, set[str]] = _get_tags(Location.inherited_tags) + existing_tags_by_location: dict[int, set[str]] = _get_tags(Location.tags) + + # Perform the bulk updates. First, though, disconnect the make_inherited_tags_sticky signal on Location.tags + # while updating, otherwise each (inherited_)tags.set() will trigger, defeating the purpose of this bulk update. + disconnected = signals.m2m_changed.disconnect(make_inherited_tags_sticky, sender=Location.tags.through) + try: + for location in locations: + target_tag_names: set[str] = set() + for pid in product_ids_by_location[location.id]: + # product_ids_by_location may contain products that shouldn't to contribute to tag inheritance (we + # didn't filter either location ref lookups to check), so do a last-minute check here + if pid in products: + target_tag_names |= tags_by_product[pid] + + if target_tag_names == existing_inherited_by_location[location.id]: + # The existing set matches the expected set, so nothing more to do for this location + continue + + # Update tags for this location + _manage_inherited_tags( + location, + list(target_tag_names), + potentially_existing_tags=existing_tags_by_location[location.id], + ) + finally: + if disconnected: + signals.m2m_changed.connect(make_inherited_tags_sticky, sender=Location.tags.through) diff --git a/dojo/location/models.py b/dojo/location/models.py index 3ab313ace87..48ffeb6878a 100644 --- a/dojo/location/models.py +++ b/dojo/location/models.py @@ -1,7 +1,9 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Self, TypeVar +import hashlib +from typing import TYPE_CHECKING, Self +from django.core.validators import MinLengthValidator from django.db import transaction from django.db.models import ( CASCADE, @@ -36,6 +38,7 @@ from dojo.tools.locations import LocationAssociationData if TYPE_CHECKING: + from collections.abc import Iterable from datetime import datetime from dojo.tools.locations import LocationData @@ -251,10 +254,6 @@ class Meta: ] -# TypeVar to help linting in AbstractLocation child classes -T = TypeVar("T", bound="AbstractLocation") - - class AbstractLocation(BaseModelWithoutTimeMeta): location = OneToOneField( Location, @@ -263,10 +262,30 @@ class AbstractLocation(BaseModelWithoutTimeMeta): null=False, related_name="%(class)s", ) + identity_hash = CharField( + null=False, + blank=False, + max_length=64, + editable=False, + unique=True, + db_index=True, + validators=[MinLengthValidator(64)], + help_text=_("The hash of the location for uniqueness"), + ) class Meta: abstract = True + def __hash__(self) -> int: + return hash(str(self)) + + def __eq__(self, other: object) -> bool: + return isinstance(other, type(self)) and str(self) == str(other) + + def clean(self, *args: list, **kwargs: dict) -> None: + self.set_identity_hash() + super().clean(*args, **kwargs) + @classmethod def get_location_type(cls) -> str: """Return the type of location (e.g., 'url').""" @@ -279,7 +298,7 @@ def get_location_value(self) -> str: raise NotImplementedError(msg) @staticmethod - def create_location_from_value(value: str) -> T: + def create_location_from_value(value: str) -> Self: """ Dynamically create a Location and subclass instance based on location_type and location_value. Uses parse_string_value from the correct subclass. @@ -287,6 +306,9 @@ def create_location_from_value(value: str) -> T: msg = "Subclasses must implement create_location_from_value" raise NotImplementedError(msg) + def set_identity_hash(self): + self.identity_hash = hashlib.blake2b(str(self).encode(), digest_size=32).hexdigest() + def pre_save_logic(self): """Automatically create or update the associated Location.""" location_value = self.get_location_value() @@ -303,7 +325,7 @@ def pre_save_logic(self): self.location.save(update_fields=["location_type", "location_value"]) @classmethod - def from_location_data(cls: T, location_data: LocationData) -> T: + def from_location_data(cls, location_data: LocationData) -> Self: """ Checks that the given LocationData object represents this type, then calls #_from_location_data_impl() to build one based on its contents. Saving boilerplate checking is all. @@ -314,7 +336,7 @@ def from_location_data(cls: T, location_data: LocationData) -> T: return cls._from_location_data_impl(location_data) @classmethod - def _from_location_data_impl(cls: T, location_data: LocationData) -> T: + def _from_location_data_impl(cls, location_data: LocationData) -> Self: """Given a LocationData object trusted to represent this type, build a Location object from its contents.""" msg = "Subclasses must implement _from_location_data_impl" raise NotImplementedError(msg) @@ -328,11 +350,82 @@ def get_association_data(self) -> LocationAssociationData: return getattr(self, "_association_data", LocationAssociationData()) @classmethod - def get_or_create_from_object(cls: T, location: T) -> T: + def get_or_create_from_object(cls, location: Self) -> Self: """Given an object of this type, this method should get/create the object and return it.""" msg = "Subclasses must implement get_or_create_from_object" raise NotImplementedError(msg) + @classmethod + def bulk_get_or_create(cls, locations: Iterable[Self]) -> list[Self]: + """ + Get or create multiple locations in bulk. + + For each location, looks up by identity_hash. Creates missing ones using + bulk_create for both the parent Location rows and the subtype rows. + Returns the full list of saved instances (existing + newly created), + in the same order as the input. Duplicate inputs map to the same saved instance. + """ + if not locations: + return [] + + # Create the list of hashes of the supplied locations; we will also use this to reconstruct the initial ordering + # of locations we return (which would otherwise be lost if duplicates are represented in `locations`). + hashes = [] + for loc in locations: + # Sanity check the given locations list is homogenous + if not isinstance(loc, cls): + error_message = f"Invalid location type; expected {cls} but got {type(loc)}" + raise TypeError(error_message) + loc.clean() + hashes.append(loc.identity_hash) + + # Look up existing objects, grouping by hash + existing_by_hash = { + obj.identity_hash: obj + for obj in cls.objects.filter(identity_hash__in=hashes).select_related("location") + } + + # Create the list of new locations to create + new_locations = [] + for loc in locations: + if loc.identity_hash not in existing_by_hash: + new_locations.append(loc) + # Mark it so we don't try to create duplicates within the same batch + existing_by_hash[loc.identity_hash] = loc + else: + # Preserve association data from the input onto the existing saved object, in case we're associating + # existing locations with findings/products + saved = existing_by_hash[loc.identity_hash] + if hasattr(loc, "_association_data") and not hasattr(saved, "_association_data"): + saved._association_data = loc._association_data + + # Create 'em + if new_locations: + location_type = cls.get_location_type() + with transaction.atomic(): + # Bulk create parent Locations + parents = [ + Location( + location_type=location_type, + location_value=loc.get_location_value(), + ) + for loc in new_locations + ] + Location.objects.bulk_create(parents, batch_size=1000) + # Assign Location FKs to the subtypes, then bulk create them. + for loc, parent in zip(new_locations, parents, strict=True): + loc.location_id = parent.id + loc.location = parent + # Note: there is a subtle potential race condition here, if somehow one of the locations to be created + # has already been created, e.g. by a separate thread that commits while this thread is running. Setting + # `ignore_conflicts=True` here would prevent this step from raising an IntegrityError, but would leave + # dangling parent Location objects that were created above. Rather than performing a cleanup in that + # (unlikely?) case, just allow the transaction to rollback. + cls.objects.bulk_create(new_locations, batch_size=1000) + + # Return in input order + return [existing_by_hash[h] for h in hashes] + class ReferenceDataMixin(Model): diff --git a/dojo/url/models.py b/dojo/url/models.py index 06f7e2cd008..6e1358cfac3 100644 --- a/dojo/url/models.py +++ b/dojo/url/models.py @@ -1,6 +1,5 @@ from __future__ import annotations -import hashlib import ipaddress from contextlib import suppress from dataclasses import dataclass @@ -9,7 +8,7 @@ import idna from django.core.exceptions import ValidationError -from django.core.validators import MaxValueValidator, MinLengthValidator, MinValueValidator +from django.core.validators import MaxValueValidator, MinValueValidator from django.db import IntegrityError, transaction from django.db.models import ( BooleanField, @@ -185,15 +184,6 @@ class URL(AbstractLocation): blank=False, help_text="Dictates whether the endpoint was found to have host validation issues during creation", ) - identity_hash = CharField( - null=False, - blank=False, - max_length=64, - editable=False, - unique=True, - validators=[MinLengthValidator(64)], - help_text="The hash of the URL for uniqueness", - ) objects = URLManager().from_queryset(URLQueryset)() @@ -235,12 +225,6 @@ def __str__(self) -> str: return URL.URL_PARSING_CLASS().unparse(self) return self.manual_str() - def __hash__(self) -> int: - return hash(str(self)) - - def __eq__(self, other: object) -> bool: - return isinstance(other, URL) and str(self) == str(other) - @classmethod def get_location_type(cls) -> str: return cls.LOCATION_TYPE @@ -262,7 +246,6 @@ def clean(self, *args: list, **kwargs: dict) -> None: self.clean_path() self.clean_query() self.clean_fragment() - self.set_identity_hash() super().clean(*args, **kwargs) def clean_protocol(self) -> None: @@ -323,9 +306,6 @@ def clean_query(self) -> None: else: self.query = self.replace_null_bytes(self.query.strip().removeprefix("?")) - def set_identity_hash(self): - self.identity_hash = hashlib.blake2b(str(self).encode(), digest_size=32).hexdigest() - def replace_null_bytes(self, value: str) -> str: return value.replace("\x00", "%00") @@ -373,7 +353,7 @@ def from_parts( query=None, fragment=None, ) -> URL: - url = URL( + return URL( protocol=protocol, user_info=user_info, host=host, @@ -382,8 +362,6 @@ def from_parts( query=query, fragment=fragment, ) - url.clean() - return url @staticmethod def create_location_from_value(value: str) -> URL: diff --git a/unittests/dojo_test_case.py b/unittests/dojo_test_case.py index 66add0baa1c..098aa77376d 100644 --- a/unittests/dojo_test_case.py +++ b/unittests/dojo_test_case.py @@ -540,17 +540,22 @@ def get_latest_model(self, model): return model.objects.order_by("id").last() def get_unsaved_locations(self, finding): - if settings.V3_FEATURE_LOCATIONS: - return LocationManager.make_abstract_locations(finding.unsaved_locations) - # TODO: Delete this after the move to Locations - return finding.unsaved_endpoints + if not hasattr(finding, "_cached_unsaved_locations"): + if settings.V3_FEATURE_LOCATIONS: + locations = LocationManager.make_abstract_locations(finding.unsaved_locations) + else: + # TODO: Delete this after the move to Locations + locations = finding.unsaved_endpoints + for loc in locations: + loc.clean() + finding._cached_unsaved_locations = locations + return finding._cached_unsaved_locations def validate_locations(self, findings): for finding in findings: - # AND SEVERITY HAHAHAHA self.assertIn(finding.severity, Finding.SEVERITIES) - for location in self.get_unsaved_locations(finding): - location.clean() + # get_unsaved_locations handles conversion + cleaning + caching + self.get_unsaved_locations(finding) class DojoTestCase(TestCase, DojoTestUtilsMixin): diff --git a/unittests/test_bulk_locations.py b/unittests/test_bulk_locations.py new file mode 100644 index 00000000000..bc1ae7a02c9 --- /dev/null +++ b/unittests/test_bulk_locations.py @@ -0,0 +1,633 @@ +""" +Tests for bulk location creation and association (open-source, URL-only). + +Covers: +- AbstractLocation.bulk_get_or_create (on URL) +- LocationManager._bulk_get_or_create_locations (URL-only) +- LocationManager.record_locations_for_finding + persist (accumulator pattern) +- Query efficiency +""" + +from unittest.mock import patch + +from django.contrib.auth import get_user_model +from django.db import connection +from django.test.utils import CaptureQueriesContext +from django.utils import timezone + +from dojo.importers.location_manager import LocationManager +from dojo.location.models import Location, LocationFindingReference, LocationProductReference +from dojo.location.status import FindingLocationStatus, ProductLocationStatus +from dojo.models import Engagement, Finding, Product, Product_Type, Test, Test_Type +from dojo.tools.locations import LocationAssociationData, LocationData +from dojo.url.models import URL +from unittests.dojo_test_case import DojoTestCase, skip_unless_v2, skip_unless_v3 + +User = get_user_model() + + +def _make_url(host, path=""): + url = URL(protocol="https", host=host, path=path) + url.clean() + return url + + +_finding_counter = 0 + + +def _make_finding(): + global _finding_counter # noqa: PLW0603 + _finding_counter += 1 + now = timezone.now() + user, _ = User.objects.get_or_create(username="bulk_test_user", defaults={"is_active": True}) + pt, _ = Product_Type.objects.get_or_create(name="Bulk Test Type") + product = Product.objects.create(name=f"Bulk Test Product {_finding_counter}", description="test", prod_type=pt) + eng = Engagement.objects.create(product=product, target_start=now, target_end=now) + tt, _ = Test_Type.objects.get_or_create(name="Bulk Test") + test = Test.objects.create(engagement=eng, test_type=tt, target_start=now, target_end=now) + return Finding.objects.create(test=test, title="Bulk Test Finding", severity="Medium", reporter=user) + + +# --------------------------------------------------------------------------- +# AbstractLocation.bulk_get_or_create (URL) +# --------------------------------------------------------------------------- +@skip_unless_v3 +class TestBulkGetOrCreateURL(DojoTestCase): + + def test_all_new(self): + urls = [_make_url(f"oss-new-{i}.example.com") for i in range(5)] + saved = URL.bulk_get_or_create(urls) + + self.assertEqual(len(saved), 5) + self.assertTrue(all(s.pk is not None for s in saved)) + self.assertTrue(all(s.location_id is not None for s in saved)) + self.assertEqual(URL.objects.filter(pk__in=[s.pk for s in saved]).count(), 5) + + def test_all_existing(self): + originals = [URL.get_or_create_from_object(_make_url(f"oss-existing-{i}.example.com")) for i in range(3)] + + urls = [_make_url(f"oss-existing-{i}.example.com") for i in range(3)] + saved = URL.bulk_get_or_create(urls) + + self.assertEqual(len(saved), 3) + self.assertEqual({s.pk for s in saved}, {o.pk for o in originals}) + + def test_mixed_new_and_existing(self): + existing = URL.get_or_create_from_object(_make_url("oss-mixed-existing.example.com")) + + urls = [ + _make_url("oss-mixed-existing.example.com"), + _make_url("oss-mixed-new.example.com"), + ] + saved = URL.bulk_get_or_create(urls) + + self.assertEqual(len(saved), 2) + self.assertEqual(saved[0].pk, existing.pk) + self.assertIsNotNone(saved[1].pk) + + def test_duplicates_in_input(self): + urls = [ + _make_url("oss-dupe.example.com"), + _make_url("oss-dupe.example.com"), + _make_url("oss-unique.example.com"), + ] + saved = URL.bulk_get_or_create(urls) + + self.assertEqual(len(saved), 3) + self.assertEqual(saved[0].pk, saved[1].pk) + self.assertNotEqual(saved[2].pk, saved[0].pk) + self.assertEqual(URL.objects.filter(host__in=["oss-dupe.example.com", "oss-unique.example.com"]).count(), 2) + + def test_preserves_association_data_on_new(self): + url = _make_url("oss-assoc-new.example.com") + url._association_data = LocationAssociationData( + relationship_type="owned_by", + relationship_data={"file_path": "/src/main.py"}, + ) + + saved = URL.bulk_get_or_create([url]) + + self.assertEqual(saved[0].get_association_data().relationship_type, "owned_by") + + def test_copies_association_data_to_existing(self): + URL.get_or_create_from_object(_make_url("oss-assoc-existing.example.com")) + + url = _make_url("oss-assoc-existing.example.com") + url._association_data = LocationAssociationData(relationship_type="used_by") + + saved = URL.bulk_get_or_create([url]) + + self.assertEqual(saved[0].get_association_data().relationship_type, "used_by") + + def test_empty_input(self): + self.assertEqual(URL.bulk_get_or_create([]), []) + + def test_parent_location_created(self): + saved = URL.bulk_get_or_create([_make_url("oss-parent.example.com")]) + + loc = Location.objects.get(pk=saved[0].location_id) + self.assertEqual(loc.location_type, "url") + self.assertIn("oss-parent.example.com", loc.location_value) + + def test_transaction_atomicity(self): + initial_count = Location.objects.count() + urls = [_make_url("oss-atomic.example.com")] + + with patch.object(URL.objects, "bulk_create", side_effect=Exception("boom")), \ + self.assertRaisesMessage(Exception, "boom"): + URL.bulk_get_or_create(urls) + + self.assertEqual(Location.objects.count(), initial_count) + + +# --------------------------------------------------------------------------- +# LocationManager._bulk_get_or_create_locations (URL-only) +# --------------------------------------------------------------------------- +@skip_unless_v3 +class TestBulkGetOrCreateLocations(DojoTestCase): + + def test_supported_location_types_includes_url(self): + supported = LocationManager.get_supported_location_types() + self.assertIn("url", supported) + self.assertIs(supported["url"], URL) + + def test_url_only(self): + urls = [_make_url("oss-loc-mgr.example.com")] + saved = LocationManager._bulk_get_or_create_locations(urls) + + self.assertEqual(len(saved), 1) + self.assertIsInstance(saved[0], URL) + + def test_handles_cleaned_location_data(self): + loc_data = LocationData(type="url", data={"url": "https://oss-from-data.example.com/api"}) + cleaned = LocationManager.clean_unsaved_locations([loc_data]) + saved = LocationManager._bulk_get_or_create_locations(cleaned) + + self.assertEqual(len(saved), 1) + self.assertIsInstance(saved[0], URL) + self.assertEqual(saved[0].host, "oss-from-data.example.com") + + def test_empty_input(self): + self.assertEqual(LocationManager._bulk_get_or_create_locations([]), []) + + +# --------------------------------------------------------------------------- +# LocationManager.persist — ref creation details +# --------------------------------------------------------------------------- +@skip_unless_v3 +class TestPersistRefCreation(DojoTestCase): + + def test_uses_association_data(self): + finding = _make_finding() + product = finding.test.engagement.product + url = _make_url("oss-refs-assoc.example.com") + url._association_data = LocationAssociationData( + relationship_type="owned_by", + relationship_data={"file_path": "/app/main.py"}, + ) + + mgr = LocationManager(product) + mgr.record_locations_for_finding(finding, [url]) + mgr.persist() + + ref = LocationFindingReference.objects.get(finding=finding) + self.assertEqual(ref.relationship, "owned_by") + self.assertEqual(ref.relationship_data, {"file_path": "/app/main.py"}) + + +# --------------------------------------------------------------------------- +# End-to-end: record + persist +# --------------------------------------------------------------------------- +@skip_unless_v3 +class TestRecordAndPersist(DojoTestCase): + + def test_full_pipeline(self): + finding = _make_finding() + product = finding.test.engagement.product + + loc_data = [ + LocationData(type="url", data={"url": "https://oss-e2e-1.example.com/api"}), + LocationData(type="url", data={"url": "https://oss-e2e-2.example.com/api"}), + ] + + mgr = LocationManager(product) + mgr.record_locations_for_finding(finding, loc_data) + mgr.persist() + + self.assertEqual(LocationFindingReference.objects.filter(finding=finding).count(), 2) + self.assertEqual(LocationProductReference.objects.filter(product=product).count(), 2) + + def test_empty_locations(self): + finding = _make_finding() + product = finding.test.engagement.product + + mgr = LocationManager(product) + mgr.record_locations_for_finding(finding, []) + mgr.persist() + + self.assertEqual(LocationFindingReference.objects.filter(finding=finding).count(), 0) + + def test_idempotent(self): + finding = _make_finding() + product = finding.test.engagement.product + loc_data = [LocationData(type="url", data={"url": "https://oss-idempotent.example.com"})] + + mgr = LocationManager(product) + mgr.record_locations_for_finding(finding, loc_data) + mgr.persist() + mgr.record_locations_for_finding(finding, loc_data) + mgr.persist() + + self.assertEqual(LocationFindingReference.objects.filter(finding=finding).count(), 1) + + def test_multiple_findings_single_persist(self): + finding1 = _make_finding() + product = finding1.test.engagement.product + # Create second finding on the same product/engagement/test + finding2 = Finding.objects.create( + test=finding1.test, title="Bulk Test Finding 2", severity="High", reporter=finding1.reporter, + ) + + mgr = LocationManager(product) + mgr.record_locations_for_finding(finding1, [ + LocationData(type="url", data={"url": "https://oss-multi-1.example.com"}), + ]) + mgr.record_locations_for_finding(finding2, [ + LocationData(type="url", data={"url": "https://oss-multi-2.example.com"}), + ]) + mgr.persist() + + self.assertEqual(LocationFindingReference.objects.filter(finding=finding1).count(), 1) + self.assertEqual(LocationFindingReference.objects.filter(finding=finding2).count(), 1) + self.assertEqual(LocationProductReference.objects.filter(product=product).count(), 2) + + def test_locations_are_cleaned_during_persist(self): + """ + Verify that locations go through clean() normalization when persisted. + + URL.clean() normalizes protocol/host to lowercase and sets default ports. + If clean isn't called, the raw input values would be stored as-is. + """ + finding = _make_finding() + product = finding.test.engagement.product + + # Create a URL with uppercase protocol and host — clean() should normalize these + loc_data = [LocationData(type="url", data={ + "protocol": "HTTPS", + "host": "UPPERCASE.EXAMPLE.COM", + "path": "api/v1", + })] + mgr = LocationManager(product) + mgr.record_locations_for_finding(finding, loc_data) + mgr.persist() + + saved_url = URL.objects.get(host="uppercase.example.com") + # Protocol should be lowercased + self.assertEqual(saved_url.protocol, "https") + # Host should be lowercased + self.assertEqual(saved_url.host, "uppercase.example.com") + # Default HTTPS port should be set + self.assertEqual(saved_url.port, 443) + + +# --------------------------------------------------------------------------- +# EndpointManager: verify clean is called during record_for_finding +# --------------------------------------------------------------------------- +@skip_unless_v2 +class TestEndpointCleanOnRecord(DojoTestCase): + + def test_endpoints_are_cleaned_during_record_for_finding(self): + """ + Verify that EndpointManager.record_for_finding() runs clean() on endpoints. + + Endpoint.clean() validates format (not normalize case). An endpoint with + an invalid protocol should trigger a warning log but still be recorded + (DefectDojo stores broken endpoints with a warning). An endpoint with a + valid protocol should pass through clean() without error. + """ + # Keep imports here for reasy removal of this entire test in the future, once endpoints is gone + from dojo.importers.endpoint_manager import EndpointManager # noqa: PLC0415 + from dojo.models import Endpoint # noqa: PLC0415 + + pt, _ = Product_Type.objects.get_or_create(name="EP Clean Test Type") + product = Product.objects.create(name="EP Clean Test Product", description="t", prod_type=pt) + + mgr = EndpointManager(product) + + finding = _make_finding() + # Valid endpoint + one with empty protocol (clean sets it to None) + ep_valid = Endpoint(protocol="https", host="good.example.com") + ep_empty_proto = Endpoint(protocol="", host="empty-proto.example.com") + finding.unsaved_endpoints = [ep_valid, ep_empty_proto] + + mgr.record_for_finding(finding) + + # clean() should have set empty protocol to None + self.assertEqual(ep_valid.protocol, "https") + self.assertIsNone(ep_empty_proto.protocol) + + +# --------------------------------------------------------------------------- +# Query efficiency +# --------------------------------------------------------------------------- +@skip_unless_v3 +class TestBulkQueryEfficiency(DojoTestCase): + + def test_bulk_fewer_queries_than_locations(self): + urls = [_make_url(f"oss-perf-{i}.example.com") for i in range(50)] + + with CaptureQueriesContext(connection) as ctx: + URL.bulk_get_or_create(urls) + + # Expected: ~3 queries (SELECT existing, INSERT parents, INSERT subtypes) + self.assertLess(len(ctx.captured_queries), 10) + + +# --------------------------------------------------------------------------- +# Tag inheritance after bulk persist +# --------------------------------------------------------------------------- +@skip_unless_v3 +class TestTagInheritanceOnPersist(DojoTestCase): + + def test_locations_inherit_product_tags(self): + """Locations should inherit tags from their associated product after persist.""" + finding = _make_finding() + product = finding.test.engagement.product + # Enable tag inheritance at the product level and add some product tags + product.enable_product_tag_inheritance = True + product.save() + product.tags.add("inherit", "tags", "these") + + loc_data = [LocationData(type="url", data={"url": "https://oss-tag-inherit.example.com"})] + mgr = LocationManager(product) + mgr.record_locations_for_finding(finding, loc_data) + mgr.persist() + + loc = Location.objects.get(url__host="oss-tag-inherit.example.com") + inherited = sorted(t.name for t in loc.inherited_tags.all()) + self.assertEqual(inherited, ["inherit", "tags", "these"]) + + def test_bulk_inherit_is_no_op_when_already_in_sync(self): + """Calling persist() again with the same data should not re-inherit (no mutation queries).""" + finding = _make_finding() + product = finding.test.engagement.product + product.enable_product_tag_inheritance = True + product.save() + product.tags.add("a", "b") + + loc_data = [LocationData(type="url", data={"url": "https://oss-nosync.example.com"})] + # First import — mutations expected + mgr1 = LocationManager(product) + mgr1.record_locations_for_finding(finding, loc_data) + mgr1.persist() + + # Second import — tags already inherited, should be a fast no-op + mgr2 = LocationManager(product) + mgr2.record_locations_for_finding(finding, loc_data) + with CaptureQueriesContext(connection) as ctx: + mgr2.persist() + + # Verify no INSERT or UPDATE queries fired in the inheritance path + mutation_queries = [q for q in ctx.captured_queries if q["sql"].startswith(("INSERT", "UPDATE"))] + # There may still be refs check INSERTs if we're creating the LocationFindingReference again, + # but inherited_tags mutation should be absent. + for q in mutation_queries: + self.assertNotIn("inherited_tags", q["sql"].lower(), f"Unexpected inherited_tags mutation: {q['sql']}") + + def test_bulk_inherit_already_synced_is_constant_time(self): + """ + The main win from the bulk variant is skipping the per-instance mutation path when + locations are already in sync with their product's tags. This test verifies that + repeated persist() calls don't re-do the expensive tagulous work. + """ + finding = _make_finding() + product = finding.test.engagement.product + product.enable_product_tag_inheritance = True + product.save() + product.tags.add("p-tag-1", "p-tag-2") + + loc_data = [ + LocationData(type="url", data={"url": f"https://oss-sync-{i}.example.com"}) + for i in range(10) + ] + # First import to populate inherited_tags + mgr1 = LocationManager(product) + mgr1.record_locations_for_finding(finding, loc_data) + mgr1.persist() + + # Second import — same data, already in sync; should do zero mutation queries + mgr2 = LocationManager(product) + mgr2.record_locations_for_finding(finding, loc_data) + with CaptureQueriesContext(connection) as ctx: + mgr2.persist() + + # No UPDATEs or INSERTs on inherited_tags / tags through tables should fire + tag_through = Location.tags.through._meta.db_table + inherited_through = Location.inherited_tags.through._meta.db_table + for q in ctx.captured_queries: + sql = q["sql"].lower() + if sql.startswith(("insert", "update", "delete")): + self.assertNotIn( + tag_through.lower(), sql, f"Unexpected tags mutation: {q['sql']}", + ) + self.assertNotIn( + inherited_through.lower(), sql, f"Unexpected inherited_tags mutation: {q['sql']}", + ) + + +# --------------------------------------------------------------------------- +# Status update query efficiency +# --------------------------------------------------------------------------- +@skip_unless_v3 +class TestStatusUpdateQueryEfficiency(DojoTestCase): + + """ + Verify that persist() flushes status updates with a bounded number of queries, + regardless of how many findings were recorded (not O(n)). + """ + + def _setup_findings_with_mitigated_refs(self, count: int): + """Create `count` findings in a single product, each with a mitigated LocationFindingReference.""" + # Single product for all findings + first_finding = _make_finding() + product = first_finding.test.engagement.product + test = first_finding.test + reporter = first_finding.reporter + + findings = [first_finding] + findings.extend( + Finding.objects.create( + test=test, title=f"Status Test Finding {i}", severity="Medium", reporter=reporter, + ) + for i in range(count - 1) + ) + + # Create one mitigated LocationFindingReference per finding + for i, finding in enumerate(findings): + saved = URL.bulk_get_or_create([_make_url(f"oss-status-{i}.example.com")]) + LocationFindingReference.objects.create( + location=saved[0].location, + finding=finding, + status=FindingLocationStatus.Mitigated, + ) + return findings, product + + def test_reactivate_for_many_findings_is_bulk(self): + findings, product = self._setup_findings_with_mitigated_refs(count=20) + mgr = LocationManager(product) + for finding in findings: + mgr.record_reactivations_for_finding(finding) + + with CaptureQueriesContext(connection) as ctx: + mgr.persist() + + # Expected: 1 SELECT (gather ref IDs) + 1 UPDATE (reactivate) + # + 1 SELECT (affected location_ids) + 1 SELECT (still-active check) + up to 1 UPDATE (product refs) + self.assertLess(len(ctx.captured_queries), 8, ctx.captured_queries) + + def test_update_location_status_for_many_findings_is_bulk(self): + findings, product = self._setup_findings_with_mitigated_refs(count=20) + reporter = findings[0].reporter + mgr = LocationManager(product) + + # Simulate reimport "matched finding" flow: new finding with no unsaved locations => mitigate all + for finding in findings: + new_finding = Finding(title=finding.title, severity=finding.severity, test=finding.test, is_mitigated=True) + new_finding.unsaved_locations = [] + mgr.update_location_status(finding, new_finding, reporter) + + with CaptureQueriesContext(connection) as ctx: + mgr.persist() + + # Expected: 1 SELECT (partial-status fetch) + 1 UPDATE (mitigate) + # + 1 SELECT (affected location_ids) + 1 SELECT (still-active check) + up to 1 UPDATE (product refs) + self.assertLess(len(ctx.captured_queries), 8, ctx.captured_queries) + + def test_partial_status_update_reactivates_matching_mitigates_rest(self): + """ + When a reimported finding is NOT mitigated, locations still present in + the report should be reactivated, and locations absent from the report + should be mitigated. + """ + finding = _make_finding() + product = finding.test.engagement.product + + # Create three locations, all currently mitigated on this finding + url_kept = _make_url("kept.example.com") + url_also_kept = _make_url("also-kept.example.com") + url_gone = _make_url("gone.example.com") + saved = URL.bulk_get_or_create([url_kept, url_also_kept, url_gone]) + + refs = [ + LocationFindingReference.objects.create( + location=loc.location, + finding=finding, + status=FindingLocationStatus.Mitigated, + ) + for loc in saved + ] + + # Simulate a reimport where the new finding is active and only has two of the three locations + new_finding = Finding( + title=finding.title, severity=finding.severity, + test=finding.test, is_mitigated=False, + ) + new_finding.unsaved_locations = [ + LocationData(type="url", data={"url": "https://kept.example.com"}), + LocationData(type="url", data={"url": "https://also-kept.example.com"}), + ] + + mgr = LocationManager(product) + mgr.update_location_status(finding, new_finding, finding.reporter) + mgr.persist() + + # Refresh from DB + for ref in refs: + ref.refresh_from_db() + + # The two locations still in the report should be reactivated + self.assertEqual(refs[0].status, FindingLocationStatus.Active) + self.assertEqual(refs[1].status, FindingLocationStatus.Active) + # The location no longer in the report should be mitigated + self.assertEqual(refs[2].status, FindingLocationStatus.Mitigated) + + def test_product_ref_mitigated_when_all_finding_refs_mitigated(self): + """When all finding refs for a location are mitigated, the product ref should become mitigated.""" + finding = _make_finding() + product = finding.test.engagement.product + + url = _make_url("product-status-test.example.com") + saved = URL.bulk_get_or_create([url]) + loc = saved[0] + + # Create active finding ref and active product ref + LocationFindingReference.objects.create( + location=loc.location, finding=finding, status=FindingLocationStatus.Active, + ) + product_ref = LocationProductReference.objects.create( + location=loc.location, product=product, status=ProductLocationStatus.Active, + ) + + # Mitigate the finding + mgr = LocationManager(product) + mgr.record_mitigations_for_finding(finding, finding.reporter) + mgr.persist() + + product_ref.refresh_from_db() + self.assertEqual(product_ref.status, ProductLocationStatus.Mitigated) + + def test_product_ref_stays_active_when_some_finding_refs_still_active(self): + """When at least one finding ref is active, the product ref should stay active.""" + finding1 = _make_finding() + product = finding1.test.engagement.product + finding2 = Finding.objects.create( + test=finding1.test, title="Second Finding", severity="Medium", reporter=finding1.reporter, + ) + + url = _make_url("shared-location.example.com") + saved = URL.bulk_get_or_create([url]) + loc = saved[0] + + # Two findings share the same location, both active + LocationFindingReference.objects.create( + location=loc.location, finding=finding1, status=FindingLocationStatus.Active, + ) + LocationFindingReference.objects.create( + location=loc.location, finding=finding2, status=FindingLocationStatus.Active, + ) + product_ref = LocationProductReference.objects.create( + location=loc.location, product=product, status=ProductLocationStatus.Active, + ) + + # Mitigate only the first finding — second is still active + mgr = LocationManager(product) + mgr.record_mitigations_for_finding(finding1, finding1.reporter) + mgr.persist() + + product_ref.refresh_from_db() + self.assertEqual(product_ref.status, ProductLocationStatus.Active) + + def test_product_ref_reactivated_when_finding_ref_reactivated(self): + """When a finding ref is reactivated, the product ref should become active.""" + finding = _make_finding() + product = finding.test.engagement.product + + url = _make_url("reactivate-product.example.com") + saved = URL.bulk_get_or_create([url]) + loc = saved[0] + + # Start with everything mitigated + LocationFindingReference.objects.create( + location=loc.location, finding=finding, status=FindingLocationStatus.Mitigated, + ) + product_ref = LocationProductReference.objects.create( + location=loc.location, product=product, status=ProductLocationStatus.Mitigated, + ) + + # Reactivate the finding + mgr = LocationManager(product) + mgr.record_reactivations_for_finding(finding) + mgr.persist() + + product_ref.refresh_from_db() + self.assertEqual(product_ref.status, ProductLocationStatus.Active) diff --git a/unittests/test_importers_performance.py b/unittests/test_importers_performance.py index 1a9c9fc137d..2ca6d251bf6 100644 --- a/unittests/test_importers_performance.py +++ b/unittests/test_importers_performance.py @@ -569,14 +569,14 @@ def test_import_reimport_reimport_performance_pghistory_async(self): configure_pghistory_triggers() self._import_reimport_performance( - expected_num_queries1=1191, - expected_num_async_tasks1=6, - expected_num_queries2=716, - expected_num_async_tasks2=17, - expected_num_queries3=346, - expected_num_async_tasks3=16, - expected_num_queries4=212, - expected_num_async_tasks4=6, + expected_num_queries1=146, + expected_num_async_tasks1=1, + expected_num_queries2=124, + expected_num_async_tasks2=1, + expected_num_queries3=37, + expected_num_async_tasks3=1, + expected_num_queries4=101, + expected_num_async_tasks4=0, ) @override_settings(ENABLE_AUDITLOG=True) @@ -593,14 +593,14 @@ def test_import_reimport_reimport_performance_pghistory_no_async(self): testuser.usercontactinfo.save() self._import_reimport_performance( - expected_num_queries1=1200, - expected_num_async_tasks1=6, - expected_num_queries2=725, - expected_num_async_tasks2=17, - expected_num_queries3=355, - expected_num_async_tasks3=16, - expected_num_queries4=212, - expected_num_async_tasks4=6, + expected_num_queries1=155, + expected_num_async_tasks1=1, + expected_num_queries2=133, + expected_num_async_tasks2=1, + expected_num_queries3=46, + expected_num_async_tasks3=1, + expected_num_queries4=101, + expected_num_async_tasks4=0, ) @override_settings(ENABLE_AUDITLOG=True) @@ -618,14 +618,14 @@ def test_import_reimport_reimport_performance_pghistory_no_async_with_product_gr self.system_settings(enable_product_grade=True) self._import_reimport_performance( - expected_num_queries1=1210, - expected_num_async_tasks1=8, - expected_num_queries2=735, - expected_num_async_tasks2=19, - expected_num_queries3=359, - expected_num_async_tasks3=18, - expected_num_queries4=222, - expected_num_async_tasks4=8, + expected_num_queries1=165, + expected_num_async_tasks1=3, + expected_num_queries2=143, + expected_num_async_tasks2=3, + expected_num_queries3=50, + expected_num_async_tasks3=3, + expected_num_queries4=111, + expected_num_async_tasks4=2, ) def _deduplication_performance(self, expected_num_queries1, expected_num_async_tasks1, expected_num_queries2, expected_num_async_tasks2, *, check_duplicates=True): @@ -718,10 +718,10 @@ def test_deduplication_performance_pghistory_async(self): self.system_settings(enable_deduplication=True) self._deduplication_performance( - expected_num_queries1=1411, - expected_num_async_tasks1=7, - expected_num_queries2=1016, - expected_num_async_tasks2=7, + expected_num_queries1=81, + expected_num_async_tasks1=1, + expected_num_queries2=72, + expected_num_async_tasks2=1, check_duplicates=False, # Async mode - deduplication happens later ) @@ -738,8 +738,8 @@ def test_deduplication_performance_pghistory_no_async(self): testuser.usercontactinfo.save() self._deduplication_performance( - expected_num_queries1=1420, - expected_num_async_tasks1=7, - expected_num_queries2=1132, - expected_num_async_tasks2=7, + expected_num_queries1=90, + expected_num_async_tasks1=1, + expected_num_queries2=188, + expected_num_async_tasks2=1, )