diff --git a/aikido_zen/background_process/aikido_background_process.py b/aikido_zen/background_process/aikido_background_process.py index da00af662..b4b1b3f37 100644 --- a/aikido_zen/background_process/aikido_background_process.py +++ b/aikido_zen/background_process/aikido_background_process.py @@ -102,13 +102,7 @@ def send_to_connection_manager(self, event_scheduler): EMPTY_QUEUE_INTERVAL, 1, self.send_to_connection_manager, (event_scheduler,) ) while not self.queue.empty(): - queue_attack_item = self.queue.get() - self.connection_manager.on_detected_attack( - attack=queue_attack_item[0], - context=queue_attack_item[1], - blocked=queue_attack_item[2], - stack=queue_attack_item[3], - ) + self.connection_manager.report_api_event(self.queue.get()) def add_exit_handlers(): diff --git a/aikido_zen/background_process/cloud_connection_manager/__init__.py b/aikido_zen/background_process/cloud_connection_manager/__init__.py index a61fb9ba0..ea92f5074 100644 --- a/aikido_zen/background_process/cloud_connection_manager/__init__.py +++ b/aikido_zen/background_process/cloud_connection_manager/__init__.py @@ -10,12 +10,12 @@ from aikido_zen.storage.users import Users from aikido_zen.storage.hostnames import Hostnames from ..realtime.start_polling_for_changes import start_polling_for_changes +from ...helpers.get_current_unixtime_ms import get_unixtime_ms from ...storage.ai_statistics import AIStatistics from ...storage.firewall_lists import FirewallLists from ...storage.statistics import Statistics # Import functions : -from .on_detected_attack import on_detected_attack from .get_manager_info import get_manager_info from .update_service_config import update_service_config from .on_start import on_start @@ -57,7 +57,7 @@ def __init__(self, block, api, token, serverless): def start(self, event_scheduler): """Send out start event and add heartbeats""" - res = self.on_start() + res = on_start(self) if res.get("error", None) == "invalid_token": logger.info( "Token was invalid, not starting heartbeats and realtime polling." @@ -81,22 +81,10 @@ def report_initial_stats(self): if should_report_initial_stats: self.send_heartbeat() - def on_detected_attack(self, attack, context, blocked, stack): - """This will send something to the API when an attack is detected""" - return on_detected_attack(self, attack, context, blocked, stack) - - def on_start(self): - """This will send out an Event signalling the start to the server""" - return on_start(self) - def send_heartbeat(self): """This will send a heartbeat to the server""" return send_heartbeat(self) - def get_manager_info(self): - """This returns info about the connection_manager""" - return get_manager_info(self) - def update_service_config(self, res): """Update configuration based on the server's response""" return update_service_config(self, res) @@ -104,3 +92,26 @@ def update_service_config(self, res): def update_firewall_lists(self): """Will update service config with blocklist of IP addresses""" return update_firewall_lists(self) + + def report_api_event(self, event): + if not self.token: + return {"success": False, "error": "invalid_token"} + try: + payload = { + "time": get_unixtime_ms(), + "agent": get_manager_info(self), + } + payload.update(event) # Merge default fields with event fields + + result = self.api.report(self.token, payload, self.timeout_in_sec) + if not result.get("success", True): + logger.error( + "CloudConnectionManager: Reporting to api failed, error=%s", + result.get("error", "unknown"), + ) + return result + except Exception as e: + logger.debug(e) + logger.error( + "CloudConnectionManager: Reporting to api failed, unexpected error (see debug logs)" + ) diff --git a/aikido_zen/background_process/cloud_connection_manager/on_detected_attack.py b/aikido_zen/background_process/cloud_connection_manager/on_detected_attack.py deleted file mode 100644 index 3d04f2bf4..000000000 --- a/aikido_zen/background_process/cloud_connection_manager/on_detected_attack.py +++ /dev/null @@ -1,54 +0,0 @@ -"""Mainly exports on_detected_attack""" - -import json - -from aikido_zen.helpers.get_current_unixtime_ms import get_unixtime_ms -from aikido_zen.helpers.logging import logger -from aikido_zen.helpers.limit_length_metadata import limit_length_metadata -from aikido_zen.helpers.serialize_to_json import serialize_to_json - - -def on_detected_attack(connection_manager, attack, context, blocked, stack): - """ - This will send something to the API when an attack is detected - """ - if not connection_manager.token: - return - - try: - attack["user"] = getattr(context, "user", None) - attack["payload"] = json.dumps(attack["payload"])[:4096] - attack["metadata"] = limit_length_metadata(attack["metadata"], 4096) - attack["blocked"] = blocked - attack["stack"] = stack - - payload = { - "type": "detected_attack", - "time": get_unixtime_ms(), - "agent": connection_manager.get_manager_info(), - "attack": attack, - "request": extract_request_if_possible(context), - } - logger.debug(serialize_to_json(payload)) - result = connection_manager.api.report( - connection_manager.token, - payload, - connection_manager.timeout_in_sec, - ) - logger.debug("Result : %s", result) - except Exception as e: - logger.debug(e) - logger.info("Failed to report an attack") - - -def extract_request_if_possible(context): - if not context: - return None - return { - "method": getattr(context, "method", None), - "url": getattr(context, "url", None), - "ipAddress": getattr(context, "remote_address", None), - "source": getattr(context, "source", None), - "route": getattr(context, "route", None), - "userAgent": context.get_user_agent(), - } diff --git a/aikido_zen/background_process/cloud_connection_manager/on_start.py b/aikido_zen/background_process/cloud_connection_manager/on_start.py index 8414821ce..6407ee7cf 100644 --- a/aikido_zen/background_process/cloud_connection_manager/on_start.py +++ b/aikido_zen/background_process/cloud_connection_manager/on_start.py @@ -5,26 +5,16 @@ def on_start(connection_manager): - """ - This will send out an Event signalling the start to the server - """ - if not connection_manager.token: - return - res = connection_manager.api.report( - connection_manager.token, - { - "type": "started", - "time": get_unixtime_ms(), - "agent": connection_manager.get_manager_info(), - }, - connection_manager.timeout_in_sec, - ) + event = {"type": "started"} + res = connection_manager.report_api_event(event) + if not res.get("success", True): - # Update config time even in failure : - connection_manager.conf.last_updated_at = get_unixtime_ms() - logger.error("Failed to communicate with Aikido Server : %s", res["error"]) + connection_manager.conf.last_updated_at = ( + get_unixtime_ms() + ) # Update config time even in failure else: connection_manager.update_service_config(res) connection_manager.update_firewall_lists() logger.info("Established connection with Aikido Server") + return res diff --git a/aikido_zen/background_process/cloud_connection_manager/on_start_test.py b/aikido_zen/background_process/cloud_connection_manager/on_start_test.py index 4dffb92cb..fd9580d4e 100644 --- a/aikido_zen/background_process/cloud_connection_manager/on_start_test.py +++ b/aikido_zen/background_process/cloud_connection_manager/on_start_test.py @@ -8,7 +8,7 @@ def mock_connection_manager(): connection_manager = MagicMock() connection_manager.token = "test_token" connection_manager.timeout_in_sec = 5 - connection_manager.api.report = MagicMock(return_value={"success": True}) + connection_manager.report_api_event = MagicMock(return_value={"success": True}) connection_manager.get_manager_info = lambda: {} connection_manager.update_service_config = MagicMock() return connection_manager @@ -19,7 +19,7 @@ def test_on_start_no_token(): connection_manager = MagicMock() connection_manager.token = None on_start(connection_manager) - connection_manager.api.report.assert_not_called() + connection_manager.report_api_event.assert_called() def test_on_start_success(mock_connection_manager, caplog): @@ -27,7 +27,7 @@ def test_on_start_success(mock_connection_manager, caplog): on_start(mock_connection_manager) # Check that the API report method was called - mock_connection_manager.api.report.assert_called_once() + mock_connection_manager.report_api_event.assert_called_once() # Check that the service config was updated mock_connection_manager.update_service_config.assert_called_once() @@ -38,7 +38,7 @@ def test_on_start_success(mock_connection_manager, caplog): def test_on_start_failure(mock_connection_manager, caplog): """Test that an error is logged when the API call fails.""" - mock_connection_manager.api.report.return_value = { + mock_connection_manager.report_api_event.return_value = { "success": False, "error": "Some error", } @@ -46,10 +46,7 @@ def test_on_start_failure(mock_connection_manager, caplog): on_start(mock_connection_manager) # Check that the API report method was called - mock_connection_manager.api.report.assert_called_once() + mock_connection_manager.report_api_event.assert_called_once() # Check that the service config was not updated mock_connection_manager.update_service_config.assert_not_called() - - # Check that the error log was called - assert "Failed to communicate with Aikido Server : Some error" in caplog.text diff --git a/aikido_zen/background_process/cloud_connection_manager/send_heartbeat.py b/aikido_zen/background_process/cloud_connection_manager/send_heartbeat.py index 9a1ff9885..18fa732fa 100644 --- a/aikido_zen/background_process/cloud_connection_manager/send_heartbeat.py +++ b/aikido_zen/background_process/cloud_connection_manager/send_heartbeat.py @@ -2,7 +2,6 @@ from aikido_zen.background_process.packages import PackagesStore from aikido_zen.helpers.logging import logger -from aikido_zen.helpers.get_current_unixtime_ms import get_unixtime_ms def send_heartbeat(connection_manager): @@ -26,12 +25,9 @@ def send_heartbeat(connection_manager): connection_manager.ai_stats.clear() PackagesStore.clear() - res = connection_manager.api.report( - connection_manager.token, + res = connection_manager.report_api_event( { "type": "heartbeat", - "time": get_unixtime_ms(), - "agent": connection_manager.get_manager_info(), "stats": stats, "ai": ai_stats, "hostnames": outgoing_domains, @@ -39,7 +35,6 @@ def send_heartbeat(connection_manager): "routes": routes, "users": users, "middlewareInstalled": connection_manager.middleware_installed, - }, - connection_manager.timeout_in_sec, + } ) connection_manager.update_service_config(res) diff --git a/aikido_zen/background_process/cloud_connection_manager/update_service_config.py b/aikido_zen/background_process/cloud_connection_manager/update_service_config.py index 7d7f729e8..42ce97ca1 100644 --- a/aikido_zen/background_process/cloud_connection_manager/update_service_config.py +++ b/aikido_zen/background_process/cloud_connection_manager/update_service_config.py @@ -9,8 +9,8 @@ def update_service_config(connection_manager, res): Update configuration based on the server's response """ if res.get("success", False) is False: - logger.debug(res) return + if "block" in res.keys() and res["block"] != connection_manager.block: logger.debug("Updating blocking, setting blocking to : %s", res["block"]) connection_manager.block = bool(res["block"]) diff --git a/aikido_zen/background_process/commands/__init__.py b/aikido_zen/background_process/commands/__init__.py index 151d25621..5b36b7580 100644 --- a/aikido_zen/background_process/commands/__init__.py +++ b/aikido_zen/background_process/commands/__init__.py @@ -1,7 +1,6 @@ -"""Commands __init__.py file""" - from aikido_zen.helpers.logging import logger -from .attack import process_attack +from aikido_zen.helpers.ipc.command_types import CommandContext +from .put_event import PutEventCommand from .check_firewall_lists import process_check_firewall_lists from .read_property import process_read_property from .should_ratelimit import process_should_ratelimit @@ -9,26 +8,27 @@ from .sync_data import process_sync_data commands_map = { - # This maps to a tuple : (function, returns_data?) - # Commands that don't return data : - "ATTACK": (process_attack, False), - # Commands that return data : - "SYNC_DATA": (process_sync_data, True), - "READ_PROPERTY": (process_read_property, True), - "SHOULD_RATELIMIT": (process_should_ratelimit, True), - "PING": (process_ping, True), - "CHECK_FIREWALL_LISTS": (process_check_firewall_lists, True), + "SYNC_DATA": process_sync_data, + "READ_PROPERTY": process_read_property, + "SHOULD_RATELIMIT": process_should_ratelimit, + "PING": process_ping, + "CHECK_FIREWALL_LISTS": process_check_firewall_lists, } +modern_commands = [PutEventCommand] + def process_incoming_command(connection_manager, obj, conn, queue): - """Processes an incoming command""" - action = obj[0] - data = obj[1] - if action in commands_map: - func, returns_data = commands_map[action] - if returns_data: - return conn.send(func(connection_manager, data, queue)) - func(connection_manager, data, queue) - else: - logger.debug("Command : `%s` not found, aborting", action) + inbound_identifier = obj[0] + inbound_request = obj[1] + if inbound_identifier in commands_map: + func = commands_map[inbound_identifier] + return conn.send(func(connection_manager, inbound_request)) + + for cmd in modern_commands: + if cmd.identifier() == inbound_identifier: + cmd.run(CommandContext(connection_manager, queue, conn), inbound_request) + return None + + logger.debug("Command : `%s` not found - did not execute", inbound_identifier) + return None diff --git a/aikido_zen/background_process/commands/attack.py b/aikido_zen/background_process/commands/attack.py deleted file mode 100644 index 92308e989..000000000 --- a/aikido_zen/background_process/commands/attack.py +++ /dev/null @@ -1,9 +0,0 @@ -"""Main export is process_attack""" - - -def process_attack(connection_manager, data, queue): - """ - Adds ATTACK data object to queue - Expected data object : [injection_results, context, blocked_or_not, stacktrace] - """ - queue.put(data) diff --git a/aikido_zen/background_process/commands/attack_test.py b/aikido_zen/background_process/commands/attack_test.py deleted file mode 100644 index b5cb8e81d..000000000 --- a/aikido_zen/background_process/commands/attack_test.py +++ /dev/null @@ -1,69 +0,0 @@ -import pytest -from queue import Queue -from unittest.mock import MagicMock -from .attack import process_attack - - -class MockCloudConnectionManager: - def __init__(self): - self.statistics = MagicMock() - - -def test_process_attack_adds_data_to_queue(): - queue = Queue() - connection_manager = MockCloudConnectionManager() - data = ("injection_results", "context", True, "stacktrace") # Example data - process_attack(connection_manager, data, queue) - - # Check if the data is added to the queue - assert not queue.empty() - assert queue.get() == data - - -def test_process_attack_statistics_not_called_when_disabled(): - queue = Queue() - connection_manager = MockCloudConnectionManager() - connection_manager.statistics = None # Disable statistics - data = ("injection_results", "context", True, "stacktrace") # Example data - process_attack(connection_manager, data, queue) - - # Check if on_detected_attack was not called - assert ( - connection_manager.statistics is None - or not connection_manager.statistics.on_detected_attack.called - ) - - -def test_process_attack_multiple_calls(): - queue = Queue() - connection_manager = MockCloudConnectionManager() - data1 = ("injection_results_1", "context_1", True, "stacktrace_1") - data2 = ("injection_results_2", "context_2", False, "stacktrace_2") - - process_attack(connection_manager, data1, queue) - process_attack(connection_manager, data2, queue) - - # Check if both data items are added to the queue - assert queue.qsize() == 2 - assert queue.get() == data1 - assert queue.get() == data2 - - -def test_process_attack_with_different_data_formats(): - queue = Queue() - connection_manager = MockCloudConnectionManager() - - # Test with different types of data - data1 = ("injection_results", "context", True, "stacktrace") - data2 = ("injection_results", "context", False, "stacktrace") - data3 = ("injection_results", "context", None, "stacktrace") - - process_attack(connection_manager, data1, queue) - process_attack(connection_manager, data2, queue) - process_attack(connection_manager, data3, queue) - - # Check if all data items are added to the queue - assert queue.qsize() == 3 - assert queue.get() == data1 - assert queue.get() == data2 - assert queue.get() == data3 diff --git a/aikido_zen/background_process/commands/check_firewall_lists.py b/aikido_zen/background_process/commands/check_firewall_lists.py index 390544a8a..acf323043 100644 --- a/aikido_zen/background_process/commands/check_firewall_lists.py +++ b/aikido_zen/background_process/commands/check_firewall_lists.py @@ -1,13 +1,7 @@ """Exports process_check_firewall_lists""" -from aikido_zen.background_process.cloud_connection_manager import ( - CloudConnectionManager, -) - -def process_check_firewall_lists( - connection_manager: CloudConnectionManager, data, conn, queue=None -): +def process_check_firewall_lists(connection_manager, data): """ Checks whether an IP is blocked data: {"ip": string, "user-agent": string} diff --git a/aikido_zen/background_process/commands/ping.py b/aikido_zen/background_process/commands/ping.py index 9245e5fd1..9363eab0c 100644 --- a/aikido_zen/background_process/commands/ping.py +++ b/aikido_zen/background_process/commands/ping.py @@ -1,6 +1,6 @@ """exports `process_ping`""" -def process_ping(connection_manager, data, queue=None): +def process_ping(connection_manager, data): """when main process quits , or during testing etc""" return "Received" diff --git a/aikido_zen/background_process/commands/put_event.py b/aikido_zen/background_process/commands/put_event.py new file mode 100644 index 000000000..a2e51a94b --- /dev/null +++ b/aikido_zen/background_process/commands/put_event.py @@ -0,0 +1,27 @@ +from aikido_zen.helpers.ipc.command_types import Command, CommandContext, Payload + + +class PutEventReq: + def __init__(self, event): + # Event is a dictionary containing data that is going to be reported to core + # "time" and "agent" fields are added by default from the CloudConnectionManager + self.event = event + + +class PutEventCommand(Command): + @classmethod + def identifier(cls) -> str: + return "put_event" + + @classmethod + def returns_data(cls) -> bool: + return False + + @classmethod + def run(cls, context: CommandContext, request: PutEventReq): + # Events sent here get put in the event queue so that they are processed in the background + context.queue.put(request.event) + + @classmethod + def generate(cls, request) -> Payload: + return Payload(cls, PutEventReq(request)) diff --git a/aikido_zen/background_process/commands/read_property.py b/aikido_zen/background_process/commands/read_property.py index b70e6f31a..1bff5e395 100644 --- a/aikido_zen/background_process/commands/read_property.py +++ b/aikido_zen/background_process/commands/read_property.py @@ -3,7 +3,7 @@ from aikido_zen.helpers.logging import logger -def process_read_property(connection_manager, data, queue=None): +def process_read_property(connection_manager, data): """ Takes in one arg : name of property on connection_manager, tries to read it. Meant to get config props diff --git a/aikido_zen/background_process/commands/should_ratelimit.py b/aikido_zen/background_process/commands/should_ratelimit.py index 40ba02979..d51f99ce9 100644 --- a/aikido_zen/background_process/commands/should_ratelimit.py +++ b/aikido_zen/background_process/commands/should_ratelimit.py @@ -3,7 +3,7 @@ import aikido_zen.ratelimiting as ratelimiting -def process_should_ratelimit(connection_manager, data, queue=None): +def process_should_ratelimit(connection_manager, data): """ Called to check if the context passed along as data should be rate limited data object should be a dict including route_metadata, remote_address and user diff --git a/aikido_zen/background_process/commands/sync_data.py b/aikido_zen/background_process/commands/sync_data.py index 5fedcc8e1..4c5a95012 100644 --- a/aikido_zen/background_process/commands/sync_data.py +++ b/aikido_zen/background_process/commands/sync_data.py @@ -4,7 +4,7 @@ from aikido_zen.background_process.packages import PackagesStore -def process_sync_data(connection_manager, data, conn, queue=None): +def process_sync_data(connection_manager, data): """ Synchronizes data between the thread-local cache (with a TTL of usually 1 minute) and the background thread. Which data gets synced? diff --git a/aikido_zen/background_process/commands/sync_data_test.py b/aikido_zen/background_process/commands/sync_data_test.py index 286386cbd..32dad72dd 100644 --- a/aikido_zen/background_process/commands/sync_data_test.py +++ b/aikido_zen/background_process/commands/sync_data_test.py @@ -72,7 +72,7 @@ def test_process_sync_data_initialization(setup_connection_manager): ], } - result = process_sync_data(connection_manager, data, None) + result = process_sync_data(connection_manager, data) # Check that routes were initialized correctly assert len(connection_manager.routes) == 2 @@ -147,7 +147,7 @@ def test_process_sync_data_with_last_updated_at_below_zero(setup_connection_mana "middleware_installed": True, } - result = process_sync_data(connection_manager, data, None) + result = process_sync_data(connection_manager, data) # Check that routes were initialized correctly assert len(connection_manager.routes) == 2 @@ -214,7 +214,7 @@ def test_process_sync_data_existing_route_and_hostnames(setup_connection_manager } # First call to initialize the route - process_sync_data(connection_manager, data, None) + process_sync_data(connection_manager, data) # Second call to update the existing route data_update = { @@ -241,7 +241,7 @@ def test_process_sync_data_existing_route_and_hostnames(setup_connection_manager }, } - result = process_sync_data(connection_manager, data_update, None) + result = process_sync_data(connection_manager, data_update) # Check that the hit count was updated correctly assert ( @@ -275,7 +275,7 @@ def test_process_sync_data_no_routes(setup_connection_manager): connection_manager = setup_connection_manager data = {"current_routes": {}, "reqs": 0} # No requests to add - result = process_sync_data(connection_manager, data, None) + result = process_sync_data(connection_manager, data) # Check that no routes were initialized assert len(connection_manager.routes) == 0 diff --git a/aikido_zen/helpers/create_detected_attack_api_event.py b/aikido_zen/helpers/create_detected_attack_api_event.py new file mode 100644 index 000000000..c236092fc --- /dev/null +++ b/aikido_zen/helpers/create_detected_attack_api_event.py @@ -0,0 +1,36 @@ +import json + +from aikido_zen.helpers.limit_length_metadata import limit_length_metadata +from aikido_zen.helpers.logging import logger + + +def create_detected_attack_api_event(attack, context, blocked, stack): + try: + return { + "type": "detected_attack", + "attack": { + **attack, + "user": getattr(context, "user", None), + "payload": json.dumps(attack["payload"])[:4096], + "metadata": limit_length_metadata(attack["metadata"], 4096), + "blocked": blocked, + "stack": stack, + }, + "request": extract_request_if_possible(context), + } + except Exception as e: + logger.error("Failed to create detected_attack API event: %s", str(e)) + return None + + +def extract_request_if_possible(context): + if not context: + return None + return { + "method": getattr(context, "method", None), + "url": getattr(context, "url", None), + "ipAddress": getattr(context, "remote_address", None), + "source": getattr(context, "source", None), + "route": getattr(context, "route", None), + "userAgent": context.get_user_agent(), + } diff --git a/aikido_zen/background_process/cloud_connection_manager/on_detected_attack_test.py b/aikido_zen/helpers/create_detected_attack_api_event_test.py similarity index 51% rename from aikido_zen/background_process/cloud_connection_manager/on_detected_attack_test.py rename to aikido_zen/helpers/create_detected_attack_api_event_test.py index ae90f0ead..9918dabdc 100644 --- a/aikido_zen/background_process/cloud_connection_manager/on_detected_attack_test.py +++ b/aikido_zen/helpers/create_detected_attack_api_event_test.py @@ -1,111 +1,77 @@ import pytest from unittest.mock import MagicMock -from .on_detected_attack import on_detected_attack -from ...context import Context +from .create_detected_attack_api_event import create_detected_attack_api_event +from aikido_zen.context import Context import aikido_zen.test_utils as test_utils -@pytest.fixture -def mock_connection_manager(): - connection_manager = MagicMock() - connection_manager.token = "test_token" - connection_manager.block = True - connection_manager.timeout_in_sec = 5 - connection_manager.api.report = MagicMock(return_value={"status": "success"}) - connection_manager.get_manager_info = lambda: {} - return connection_manager - - -def test_on_detected_attack_no_token(): - connection_manager = MagicMock() - connection_manager.token = None - - on_detected_attack( - connection_manager, - attack={}, - context=test_utils.generate_context(), - blocked=False, - stack=None, - ) - - connection_manager.api.report.assert_not_called() - - -def test_on_detected_attack_with_long_payload(mock_connection_manager): +def test_create_attack_event_with_long_payload(): long_payload = "x" * 5000 # Create a payload longer than 4096 characters attack = { "payload": long_payload, "metadata": {"test": "1"}, } - on_detected_attack( - mock_connection_manager, + event = create_detected_attack_api_event( attack=attack, context=test_utils.generate_context(), blocked=False, stack=None, ) - assert len(attack["payload"]) == 4096 # Ensure payload is truncated - mock_connection_manager.api.report.assert_called_once() + assert len(event["attack"]["payload"]) == 4096 # Ensure payload is truncated -def test_on_detected_attack_with_long_metadata(mock_connection_manager): +def test_create_attack_event_with_long_metadata(): long_metadata = "x" * 5000 # Create metadata longer than 4096 characters attack = { "payload": {}, "metadata": {"test": long_metadata}, } - on_detected_attack( - mock_connection_manager, + event = create_detected_attack_api_event( attack=attack, context=test_utils.generate_context(), blocked=False, stack=None, ) - assert attack["metadata"]["test"] == long_metadata[:4096] - mock_connection_manager.api.report.assert_called_once() + assert event["attack"]["metadata"]["test"] == long_metadata[:4096] -def test_on_detected_attack_success(mock_connection_manager): +def test_create_attack_event_success(): attack = { "payload": {"key": "value"}, "metadata": {}, } - on_detected_attack( - mock_connection_manager, + event = create_detected_attack_api_event( attack=attack, context=test_utils.generate_context(), blocked=False, stack=None, ) - assert mock_connection_manager.api.report.call_count == 1 - - -def test_on_detected_attack_exception_handling(mock_connection_manager, caplog): - attack = { - "payload": {"key": "value"}, - "metadata": {"key": "value"}, + assert event == { + "attack": { + "blocked": False, + "metadata": {}, + "payload": '{"key": "value"}', + "stack": None, + "user": None, + }, + "request": { + "ipAddress": "1.1.1.1", + "method": "POST", + "route": "/", + "source": "flask", + "url": "http://localhost:8080/", + "userAgent": None, + }, + "type": "detected_attack", } - # Simulate an exception during the API call - mock_connection_manager.api.report.side_effect = Exception("API error") - - on_detected_attack( - mock_connection_manager, - attack=attack, - context=test_utils.generate_context(), - blocked=False, - stack=None, - ) - - assert "Failed to report an attack" in caplog.text - -def test_on_detected_attack_with_blocked_and_stack(mock_connection_manager): +def test_create_attack_event_with_blocked_and_stack(): attack = { "payload": {"key": "value"}, "metadata": {}, @@ -113,8 +79,7 @@ def test_on_detected_attack_with_blocked_and_stack(mock_connection_manager): blocked = True stack = "sample stack trace" - on_detected_attack( - mock_connection_manager, + event = create_detected_attack_api_event( attack=attack, context=test_utils.generate_context(), blocked=blocked, @@ -122,19 +87,17 @@ def test_on_detected_attack_with_blocked_and_stack(mock_connection_manager): ) # Check that the attack dictionary has the blocked and stack fields set - assert attack["blocked"] is True - assert attack["stack"] == stack - assert mock_connection_manager.api.report.call_count == 1 + assert event["attack"]["blocked"] is True + assert event["attack"]["stack"] == stack -def test_on_detected_attack_request_data_and_attack_data(mock_connection_manager): +def test_create_attack_event_request_data_and_attack_data(): attack = { "payload": {"key": "value"}, "metadata": {"test": "true"}, } - on_detected_attack( - mock_connection_manager, + event = create_detected_attack_api_event( attack=attack, context=test_utils.generate_context( method="GET", @@ -147,9 +110,6 @@ def test_on_detected_attack_request_data_and_attack_data(mock_connection_manager stack=None, ) - # Extract the call arguments for the report method - _, event, _ = mock_connection_manager.api.report.call_args[0] - # Verify the request attribute in the payload request_data = event["request"] @@ -170,44 +130,36 @@ def test_on_detected_attack_request_data_and_attack_data(mock_connection_manager assert attack_data["user"] is None -def test_on_detected_attack_with_user(mock_connection_manager): +def test_create_attack_event_with_user(): attack = { "payload": {"key": "value"}, "metadata": {}, } - on_detected_attack( - mock_connection_manager, + event = create_detected_attack_api_event( attack=attack, context=test_utils.generate_context(user="test_user"), blocked=False, stack=None, ) - # Extract the call arguments for the report method - _, event, _ = mock_connection_manager.api.report.call_args[0] - # Verify the user is included in the attack data assert event["attack"]["user"] == "test_user" -def test_on_detected_attack_no_context_and_attack_data(mock_connection_manager): +def test_create_attack_event_no_context_and_attack_data(): attack = { "payload": {"key": "value"}, "metadata": {"test": "true"}, } - on_detected_attack( - mock_connection_manager, + event = create_detected_attack_api_event( attack=attack, context=None, blocked=False, stack=None, ) - # Extract the call arguments for the report method - _, event, _ = mock_connection_manager.api.report.call_args[0] - # Verify the request attribute in the payload request_data = event["request"] attack_data = event["attack"] diff --git a/aikido_zen/helpers/ipc/__init__.py b/aikido_zen/helpers/ipc/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/aikido_zen/helpers/ipc/command_types.py b/aikido_zen/helpers/ipc/command_types.py new file mode 100644 index 000000000..dbbace9f3 --- /dev/null +++ b/aikido_zen/helpers/ipc/command_types.py @@ -0,0 +1,40 @@ +from abc import ABC, abstractmethod + + +class CommandContext: + def __init__(self, connection_manager, queue, connection): + self.connection_manager = connection_manager + self.queue = queue + self.connection = connection + + def send(self, response): + self.connection.send(response) + + +class Payload: + def __init__(self, command, request): + self.identifier = command.identifier() + self.returns_data = command.returns_data() + self.request = request + + +class Command(ABC): + @classmethod + @abstractmethod + def identifier(cls) -> str: + pass + + @classmethod + @abstractmethod + def returns_data(cls) -> bool: + pass + + @classmethod + @abstractmethod + def run(cls, context: CommandContext, request): + pass + + @classmethod + @abstractmethod + def generate(cls, request) -> Payload: + pass diff --git a/aikido_zen/helpers/ipc/send_payload.py b/aikido_zen/helpers/ipc/send_payload.py new file mode 100644 index 000000000..563ff337c --- /dev/null +++ b/aikido_zen/helpers/ipc/send_payload.py @@ -0,0 +1,8 @@ +from aikido_zen.background_process import AikidoIPCCommunications +from aikido_zen.helpers.ipc.command_types import Payload + + +def send_payload(comms: AikidoIPCCommunications, payload: Payload): + return comms.send_data_to_bg_process( + payload.identifier, payload.request, payload.returns_data + ) diff --git a/aikido_zen/helpers/net/__init__.py b/aikido_zen/helpers/net/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/aikido_zen/vulnerabilities/ssrf/is_private_ip.py b/aikido_zen/helpers/net/is_private_ip.py similarity index 100% rename from aikido_zen/vulnerabilities/ssrf/is_private_ip.py rename to aikido_zen/helpers/net/is_private_ip.py diff --git a/aikido_zen/vulnerabilities/ssrf/is_private_ip_test.py b/aikido_zen/helpers/net/is_private_ip_test.py similarity index 96% rename from aikido_zen/vulnerabilities/ssrf/is_private_ip_test.py rename to aikido_zen/helpers/net/is_private_ip_test.py index db1467d93..504cc0251 100644 --- a/aikido_zen/vulnerabilities/ssrf/is_private_ip_test.py +++ b/aikido_zen/helpers/net/is_private_ip_test.py @@ -1,5 +1,4 @@ -import pytest -from .is_private_ip import is_private_ip +from aikido_zen.helpers.net.is_private_ip import is_private_ip # Test cases for is_private_ip diff --git a/aikido_zen/sources/functions/request_handler_test.py b/aikido_zen/sources/functions/request_handler_test.py index db76717f3..9af37f31f 100644 --- a/aikido_zen/sources/functions/request_handler_test.py +++ b/aikido_zen/sources/functions/request_handler_test.py @@ -42,7 +42,7 @@ def __init__(self): def send_data_to_bg_process(self, action, obj, receive=False, timeout_in_sec=0.1): if action != "CHECK_FIREWALL_LISTS": return {"success": False} - res = process_check_firewall_lists(self.conn_manager, obj, None, None) + res = process_check_firewall_lists(self.conn_manager, obj) return { "success": True, "data": res, diff --git a/aikido_zen/storage/firewall_lists.py b/aikido_zen/storage/firewall_lists.py index 68293a6b7..9381e82c7 100644 --- a/aikido_zen/storage/firewall_lists.py +++ b/aikido_zen/storage/firewall_lists.py @@ -1,7 +1,7 @@ from aikido_zen.helpers.ip_matcher import IPMatcher import regex as re -from aikido_zen.vulnerabilities.ssrf.is_private_ip import is_private_ip +from aikido_zen.helpers.net.is_private_ip import is_private_ip class FirewallLists: diff --git a/aikido_zen/vulnerabilities/__init__.py b/aikido_zen/vulnerabilities/__init__.py index bf37a49f7..5f428c7eb 100644 --- a/aikido_zen/vulnerabilities/__init__.py +++ b/aikido_zen/vulnerabilities/__init__.py @@ -29,6 +29,9 @@ from .path_traversal.check_context_for_path_traversal import ( check_context_for_path_traversal, ) +from ..background_process.commands import PutEventCommand +from ..helpers.create_detected_attack_api_event import create_detected_attack_api_event +from ..helpers.ipc.send_payload import send_payload def run_vulnerability_scan(kind, op, args): @@ -97,17 +100,17 @@ def run_vulnerability_scan(kind, op, args): logger.debug("Exception occurred in run_vulnerability_scan : %s", e) if injection_results: - logger.debug("Injection results : %s", serialize_to_json(injection_results)) - blocked = is_blocking_enabled() operation = injection_results["operation"] thread_cache.stats.on_detected_attack(blocked, operation) stack = get_clean_stacktrace() - if comms: - comms.send_data_to_bg_process( - "ATTACK", (injection_results, context, blocked, stack) - ) + event = create_detected_attack_api_event( + injection_results, context, blocked, stack + ) + logger.debug("Attack: %s", serialize_to_json(event)[:5000]) + if comms and event: + send_payload(comms, PutEventCommand.generate(event)) if blocked: raise error_type(*error_args) diff --git a/aikido_zen/vulnerabilities/init_test.py b/aikido_zen/vulnerabilities/init_test.py index da7b13e1c..5b7a811aa 100644 --- a/aikido_zen/vulnerabilities/init_test.py +++ b/aikido_zen/vulnerabilities/init_test.py @@ -144,13 +144,29 @@ def test_sql_injection_with_comms(caplog, get_context, monkeypatch): ) mock_comms.send_data_to_bg_process.assert_called_once() call_args = mock_comms.send_data_to_bg_process.call_args[0] - assert call_args[0] == "ATTACK" - assert call_args[1][0]["kind"] == "sql_injection" - assert ( - call_args[1][0]["metadata"]["sql"] - == "INSERT * INTO VALUES ('doggoss2', TRUE);" - ) - assert call_args[1][0]["metadata"]["dialect"] == "mysql" + assert call_args[0] == "put_event" + assert call_args[1].event["request"] == { + "ipAddress": "198.51.100.23", + "method": "GET", + "route": "/hello", + "source": "flask", + "url": "http://localhost:8080/hello", + "userAgent": None, + } + del call_args[1].event["attack"]["stack"] # Hard to test + assert call_args[1].event["attack"] == { + "blocked": True, + "kind": "sql_injection", + "metadata": { + "dialect": "mysql", + "sql": "INSERT * INTO VALUES ('doggoss2', TRUE);", + }, + "operation": "test_op", + "pathToPayload": ".test_input_sql", + "payload": '"doggoss2\', TRUE"', + "source": "body", + "user": None, + } def test_ssrf_vulnerability_scan_adds_hostname(get_context): diff --git a/aikido_zen/vulnerabilities/ssrf/inspect_getaddrinfo_result.py b/aikido_zen/vulnerabilities/ssrf/inspect_getaddrinfo_result.py index d20e52d9d..67e52a5bb 100644 --- a/aikido_zen/vulnerabilities/ssrf/inspect_getaddrinfo_result.py +++ b/aikido_zen/vulnerabilities/ssrf/inspect_getaddrinfo_result.py @@ -7,7 +7,7 @@ from aikido_zen.helpers.logging import logger from aikido_zen.thread.thread_cache import get_cache from .imds import resolves_to_imds_ip -from .is_private_ip import is_private_ip +from aikido_zen.helpers.net.is_private_ip import is_private_ip from .find_hostname_in_context import find_hostname_in_context from .extract_ip_array_from_results import extract_ip_array_from_results from .is_redirect_to_private_ip import is_redirect_to_private_ip diff --git a/end2end/django_mysql_test.py b/end2end/django_mysql_test.py index 2ffed17e7..ed791dddd 100644 --- a/end2end/django_mysql_test.py +++ b/end2end/django_mysql_test.py @@ -104,7 +104,7 @@ def test_initial_heartbeat(): "method": "POST", "path": "/app/create" }], - packages={'wrapt', 'asgiref', 'aikido_zen', 'django', 'sqlparse', 'mysqlclient'} + packages={'wrapt', 'asgiref', 'aikido_zen', 'django', 'sqlparse', 'mysqlclient', 'regex'} ) req_stats = heartbeat_events[0]["stats"]["requests"] assert req_stats["aborted"] == 0