diff --git a/models/src/agent_control_models/__init__.py b/models/src/agent_control_models/__init__.py index 38c80d63..42286b0c 100644 --- a/models/src/agent_control_models/__init__.py +++ b/models/src/agent_control_models/__init__.py @@ -83,6 +83,8 @@ from .server import ( AgentRef, AgentSummary, + CloneControlRequest, + CloneControlResponse, ConflictMode, ControlSummary, ControlVersionSummary, @@ -95,9 +97,11 @@ ListAgentsResponse, ListControlsResponse, ListControlVersionsResponse, + ListPublishedControlsResponse, PaginationInfo, PatchControlRequest, PatchControlResponse, + PublishedControlSummary, RenderControlTemplateRequest, RenderControlTemplateResponse, StepKey, @@ -165,6 +169,8 @@ # Server models "AgentRef", "AgentSummary", + "CloneControlRequest", + "CloneControlResponse", "ConflictMode", "ControlVersionSummary", "ControlSummary", @@ -177,9 +183,11 @@ "ListAgentsResponse", "ListControlVersionsResponse", "ListControlsResponse", + "ListPublishedControlsResponse", "PaginationInfo", "PatchControlRequest", "PatchControlResponse", + "PublishedControlSummary", "RenderControlTemplateRequest", "RenderControlTemplateResponse", "StepKey", diff --git a/models/src/agent_control_models/errors.py b/models/src/agent_control_models/errors.py index 7edf010e..cc7215c6 100644 --- a/models/src/agent_control_models/errors.py +++ b/models/src/agent_control_models/errors.py @@ -69,6 +69,7 @@ class ErrorCode(StrEnum): CONTROL_NAME_CONFLICT = "CONTROL_NAME_CONFLICT" EVALUATOR_NAME_CONFLICT = "EVALUATOR_NAME_CONFLICT" CONTROL_IN_USE = "CONTROL_IN_USE" + CONTROL_PUBLISHED = "CONTROL_PUBLISHED" CONTROL_TEMPLATE_CONFLICT = "CONTROL_TEMPLATE_CONFLICT" EVALUATOR_IN_USE = "EVALUATOR_IN_USE" SCHEMA_INCOMPATIBLE = "SCHEMA_INCOMPATIBLE" @@ -373,6 +374,7 @@ def make_error_type(error_code: ErrorCode) -> str: ErrorCode.CONTROL_NAME_CONFLICT: "Control Name Already Exists", ErrorCode.EVALUATOR_NAME_CONFLICT: "Evaluator Name Conflict", ErrorCode.CONTROL_IN_USE: "Control In Use", + ErrorCode.CONTROL_PUBLISHED: "Published Control Conflict", ErrorCode.CONTROL_TEMPLATE_CONFLICT: "Control Template Conflict", ErrorCode.EVALUATOR_IN_USE: "Evaluator In Use", ErrorCode.SCHEMA_INCOMPATIBLE: "Schema Incompatible", diff --git a/models/src/agent_control_models/server.py b/models/src/agent_control_models/server.py index 55d19a1a..b83c74c2 100644 --- a/models/src/agent_control_models/server.py +++ b/models/src/agent_control_models/server.py @@ -515,6 +515,45 @@ class ListControlsResponse(BaseModel): pagination: PaginationInfo = Field(..., description="Pagination metadata") +class PublishedControlSummary(BaseModel): + """Summary of a published control in the default store.""" + + id: int = Field(..., description="Control ID") + name: str = Field(..., description="Control name") + description: str | None = Field(None, description="Control description") + enabled: bool = Field(True, description="Whether control is enabled") + execution: str | None = Field(None, description="'server' or 'sdk'") + step_types: list[str] | None = Field(None, description="Step types in scope") + stages: list[str] | None = Field(None, description="Evaluation stages in scope") + tags: list[str] = Field(default_factory=list, description="Control tags") + template_backed: bool = Field( + False, + description="Whether the control was created from a template", + ) + template_rendered: bool | None = Field( + None, + description=( + "Whether a template-backed control has been rendered. " + "True for rendered templates, False for unrendered templates, " + "None for non-template controls." + ), + ) + published_at: str = Field( + ..., + description="ISO 8601 timestamp when the control was published to the default store", + ) + + +class ListPublishedControlsResponse(BaseModel): + """Response for listing controls published in the default store.""" + + controls: list[PublishedControlSummary] = Field( + ..., + description="List of published control summaries", + ) + pagination: PaginationInfo = Field(..., description="Pagination metadata") + + class ControlVersionSummary(BaseModel): """Summary of a single control version.""" @@ -585,3 +624,26 @@ class PatchControlResponse(BaseModel): enabled: bool | None = Field( None, description="Current enabled status (if control has data configured)" ) + + +class CloneControlRequest(BaseModel): + """Request to clone a control.""" + + name: SlugName | None = Field( + None, + description=( + "Optional name for the cloned control. If omitted, the server generates " + "a unique copy name." + ), + ) + + +class CloneControlResponse(BaseModel): + """Response for cloning a control.""" + + control_id: int = Field(..., description="Identifier of the cloned control") + name: str = Field(..., description="Name assigned to the cloned control") + cloned_control_id: int = Field( + ..., + description="Identifier of the source control the clone was created from", + ) diff --git a/server/alembic/versions/7d9c2f1a3b44_control_store_publish_and_clone.py b/server/alembic/versions/7d9c2f1a3b44_control_store_publish_and_clone.py new file mode 100644 index 00000000..3e3cfb8a --- /dev/null +++ b/server/alembic/versions/7d9c2f1a3b44_control_store_publish_and_clone.py @@ -0,0 +1,107 @@ +"""add control store publication tables and clone provenance + +Revision ID: 7d9c2f1a3b44 +Revises: c1e9f9c4a1d2 +Create Date: 2026-04-15 16:30:00.000000 + +""" + +from __future__ import annotations + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision = "7d9c2f1a3b44" +down_revision = "c1e9f9c4a1d2" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.create_table( + "control_stores", + sa.Column("id", sa.Integer(), autoincrement=True, nullable=False), + sa.Column("name", sa.String(length=255), nullable=False), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.text("CURRENT_TIMESTAMP"), + nullable=False, + ), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("name"), + ) + + op.create_table( + "control_stores_controls", + sa.Column("store_id", sa.Integer(), nullable=False), + sa.Column("control_id", sa.Integer(), nullable=False), + sa.Column( + "published_at", + sa.DateTime(timezone=True), + server_default=sa.text("CURRENT_TIMESTAMP"), + nullable=False, + ), + sa.ForeignKeyConstraint(["control_id"], ["controls.id"]), + sa.ForeignKeyConstraint(["store_id"], ["control_stores.id"]), + sa.PrimaryKeyConstraint("store_id", "control_id"), + ) + op.create_index( + "idx_control_stores_controls_store_published", + "control_stores_controls", + ["store_id", "published_at", "control_id"], + unique=False, + ) + op.create_index( + "idx_control_stores_controls_control", + "control_stores_controls", + ["control_id"], + unique=False, + ) + + op.add_column( + "controls", + sa.Column( + "cloned_control_id", + sa.Integer(), + nullable=True, + ), + ) + op.create_foreign_key( + "fk_controls_cloned_control_id_controls", + "controls", + "controls", + ["cloned_control_id"], + ["id"], + ) + + control_stores = sa.table( + "control_stores", + sa.column("id", sa.Integer()), + sa.column("name", sa.String()), + ) + op.get_bind().execute( + sa.insert(control_stores).values(name="default") + ) + + +def downgrade() -> None: + op.drop_constraint( + "fk_controls_cloned_control_id_controls", + "controls", + type_="foreignkey", + ) + op.drop_column("controls", "cloned_control_id") + + op.drop_index( + "idx_control_stores_controls_control", + table_name="control_stores_controls", + ) + op.drop_index( + "idx_control_stores_controls_store_published", + table_name="control_stores_controls", + ) + op.drop_table("control_stores_controls") + op.drop_table("control_stores") diff --git a/server/src/agent_control_server/endpoints/agents.py b/server/src/agent_control_server/endpoints/agents.py index 899a1bac..db1f356d 100644 --- a/server/src/agent_control_server/endpoints/agents.py +++ b/server/src/agent_control_server/endpoints/agents.py @@ -1212,7 +1212,27 @@ async def add_agent_control( """Associate a control directly with an agent (idempotent).""" agent = await _get_agent_or_404(agent_name, db) control_service = ControlService(db) - control = await control_service.get_active_control_or_404(control_id) + control = await control_service.get_active_control_or_404(control_id, for_update=True) + if await control_service.is_control_published(control_id): + raise ConflictError( + error_code=ErrorCode.CONTROL_PUBLISHED, + detail=( + f"Control '{control.name}' is published in the Control Store and " + "cannot be attached directly to an agent" + ), + resource="Control", + resource_id=str(control_id), + hint="Clone the published control first, then attach the clone to the agent.", + errors=[ + ValidationErrorItem( + resource="Control", + field="control_id", + code="published_control_conflict", + message="Published controls must be cloned before agent association.", + value=control_id, + ) + ], + ) validation_errors = _validate_controls_for_agent(agent, [control]) if validation_errors: diff --git a/server/src/agent_control_server/endpoints/controls.py b/server/src/agent_control_server/endpoints/controls.py index c40bb39b..aea5a000 100644 --- a/server/src/agent_control_server/endpoints/controls.py +++ b/server/src/agent_control_server/endpoints/controls.py @@ -1,10 +1,14 @@ import datetime as dt +from typing import cast from agent_control_engine import list_evaluators from agent_control_models import ControlDefinition, TemplateControlInput, UnrenderedTemplateControl from agent_control_models.errors import ErrorCode, ValidationErrorItem from agent_control_models.server import ( AgentRef, + AssocResponse, + CloneControlRequest, + CloneControlResponse, ControlSummary, ControlVersionSummary, CreateControlRequest, @@ -16,9 +20,11 @@ GetControlVersionResponse, ListControlsResponse, ListControlVersionsResponse, + ListPublishedControlsResponse, PaginationInfo, PatchControlRequest, PatchControlResponse, + PublishedControlSummary, RenderControlTemplateRequest, RenderControlTemplateResponse, SetControlDataRequest, @@ -30,7 +36,7 @@ from jsonschema_rs import ValidationError as JSONSchemaValidationError from pydantic import ValidationError from sqlalchemy import select -from sqlalchemy.exc import IntegrityError +from sqlalchemy.exc import IntegrityError, OperationalError, ProgrammingError from sqlalchemy.ext.asyncio import AsyncSession from ..auth import require_admin_key @@ -68,6 +74,7 @@ router = APIRouter(prefix="/controls", tags=["controls"]) template_router = APIRouter(prefix="/control-templates", tags=["controls"]) +store_router = APIRouter(prefix="/control-stores/default/controls", tags=["control-stores"]) _logger = get_logger(__name__) @@ -155,6 +162,93 @@ def _enabled_from_stored_payload(data: object) -> bool: return raw_enabled if type(raw_enabled) is bool else True +def _control_description_from_data(data: dict[str, object]) -> str | None: + """Return a human-friendly control description from stored JSON.""" + template = data.get("template") + template_description = template.get("description") if isinstance(template, dict) else None + description = data.get("description") + return cast(str | None, description or template_description) + + +def _control_scope_from_data(data: dict[str, object]) -> dict[str, object]: + """Return the stored scope object when present.""" + raw_scope = data.get("scope") + return raw_scope if isinstance(raw_scope, dict) else {} + + +def _build_control_summary( + control_id: int, + name: str, + data: dict[str, object], + *, + usage: AgentRef | None, + used_by_agents_count: int, +) -> ControlSummary: + """Build the admin control-browse summary from stored JSON.""" + scope = _control_scope_from_data(data) + return ControlSummary( + id=control_id, + name=name, + description=_control_description_from_data(data), + enabled=_enabled_from_stored_payload(data), + execution=cast(str | None, data.get("execution")), + step_types=cast(list[str] | None, scope.get("step_types")), + stages=cast(list[str] | None, scope.get("stages")), + tags=cast(list[str], data.get("tags", [])), + template_backed="template" in data, + template_rendered=("condition" in data if "template" in data else None), + used_by_agent=usage, + used_by_agents_count=used_by_agents_count, + ) + + +def _build_published_control_summary( + control_id: int, + name: str, + data: dict[str, object], + *, + published_at: dt.datetime, +) -> PublishedControlSummary: + """Build the published-store summary from stored JSON.""" + scope = _control_scope_from_data(data) + return PublishedControlSummary( + id=control_id, + name=name, + description=_control_description_from_data(data), + enabled=_enabled_from_stored_payload(data), + execution=cast(str | None, data.get("execution")), + step_types=cast(list[str] | None, scope.get("step_types")), + stages=cast(list[str] | None, scope.get("stages")), + tags=cast(list[str], data.get("tags", [])), + template_backed="template" in data, + template_rendered=("condition" in data if "template" in data else None), + published_at=published_at.isoformat(), + ) + + +def _published_control_conflict(control_id: int, *, action: str) -> ConflictError: + """Return the standard conflict for published controls on runtime paths.""" + return ConflictError( + error_code=ErrorCode.CONTROL_PUBLISHED, + detail=( + f"Control with ID '{control_id}' is published in the Control Store and " + f"cannot be used for runtime {action}" + ), + resource="Control", + resource_id=str(control_id), + hint="Clone the published control first, then associate the clone.", + errors=[ + ValidationErrorItem( + resource="Control", + field="control_id", + code="published_control_conflict", + message="Published controls must be cloned before runtime association.", + value=control_id, + ) + ], + ) + + def _template_backed_raw_update_conflict(control_id: int) -> ConflictError: """Return the v1 conflict raised when raw data updates target template-backed controls.""" return ConflictError( @@ -839,28 +933,14 @@ async def list_controls( # Build summaries (filtering already done at DB level) summaries: list[ControlSummary] = [] for ctrl in page.controls: - # Extract summary fields from JSONB data data = ctrl.data or {} - scope = data.get("scope") or {} usage = usage_by_control_id.get(ctrl.id) summaries.append( - ControlSummary( - id=ctrl.id, - name=ctrl.name, - description=( - data.get("description") - or (data.get("template") or {}).get("description") - ), - enabled=data.get("enabled", True), - execution=data.get("execution"), - step_types=scope.get("step_types"), - stages=scope.get("stages"), - tags=data.get("tags", []), - template_backed="template" in data, - template_rendered=( - "condition" in data if "template" in data else None - ), - used_by_agent=( + _build_control_summary( + ctrl.id, + ctrl.name, + data, + usage=( AgentRef(agent_name=usage.representative_agent_name) if usage is not None and usage.representative_agent_name is not None else None @@ -880,6 +960,271 @@ async def list_controls( ) +@store_router.post( + "/{control_id}", + dependencies=[Depends(require_admin_key)], + response_model=AssocResponse, + summary="Publish a control to the default store", + response_description="Success confirmation", +) +async def publish_control( + control_id: int, + db: AsyncSession = Depends(get_async_db), +) -> AssocResponse: + """Publish an active, valid, runtime-unassociated control to the default store.""" + control_service = ControlService(db) + control = await control_service.get_active_control_or_404(control_id, for_update=True) + control_name = control.name + _parse_stored_control_data( + control.data, + control_name=control_name, + control_id=control.id, + ) + + associations = await control_service.list_control_associations(control_id) + if associations.policy_ids or associations.agent_names: + raise ConflictError( + error_code=ErrorCode.CONTROL_IN_USE, + detail=( + f"Control '{control.name}' is already associated with " + f"{len(associations.policy_ids)} policy/policies and " + f"{len(associations.agent_names)} agent(s)" + ), + resource="Control", + resource_id=str(control_id), + hint=( + "Published controls must not have runtime associations. " + "Clone them for runtime use." + ), + ) + + try: + await control_service.publish_control(control_id) + await db.commit() + except Exception: + await db.rollback() + _logger.error( + "Failed to publish control '%s' (%s) to the default store", + control_name, + control_id, + exc_info=True, + ) + raise DatabaseError( + detail=f"Failed to publish control '{control_name}': database error", + resource="Control", + operation="publish", + ) + + return AssocResponse(success=True) + + +@store_router.delete( + "/{control_id}", + dependencies=[Depends(require_admin_key)], + response_model=AssocResponse, + summary="Unpublish a control from the default store", + response_description="Success confirmation", +) +async def unpublish_control( + control_id: int, + db: AsyncSession = Depends(get_async_db), +) -> AssocResponse: + """Remove a control from the default store idempotently.""" + control_service = ControlService(db) + control = await control_service.get_active_control_or_404(control_id, for_update=True) + control_name = control.name + + try: + await control_service.unpublish_control(control_id) + await db.commit() + except Exception: + await db.rollback() + _logger.error( + "Failed to unpublish control '%s' (%s) from the default store", + control_name, + control_id, + exc_info=True, + ) + raise DatabaseError( + detail=f"Failed to unpublish control '{control_name}': database error", + resource="Control", + operation="unpublish", + ) + + return AssocResponse(success=True) + + +@store_router.get( + "", + dependencies=[Depends(require_admin_key)], + response_model=ListPublishedControlsResponse, + summary="Browse controls published in the default store", + response_description="Paginated published control summaries", +) +async def list_published_controls( + cursor: str | None = Query( + None, + description="Opaque cursor from the previous page", + ), + limit: int = Query(_DEFAULT_PAGINATION_LIMIT, ge=1, le=_MAX_PAGINATION_LIMIT), + name: str | None = Query(None, description="Filter by name (partial, case-insensitive)"), + enabled: bool | None = Query(None, description="Filter by enabled status"), + tag: str | None = Query(None, description="Filter by tag"), + db: AsyncSession = Depends(get_async_db), +) -> ListPublishedControlsResponse: + """List default-store controls ordered by publication time descending.""" + try: + page = await ControlService(db).list_published_controls_page( + cursor=cursor, + limit=limit, + name=name, + enabled=enabled, + tag=tag, + ) + except (OperationalError, ProgrammingError, RuntimeError): + _logger.error( + "Failed to list published controls from the default store", + exc_info=True, + ) + raise DatabaseError( + detail="Failed to list published controls: database error", + resource="ControlStore", + operation="list", + ) + + return ListPublishedControlsResponse( + controls=[ + _build_published_control_summary( + published.control.id, + published.control.name, + published.control.data, + published_at=published.published_at, + ) + for published in page.controls + ], + pagination=PaginationInfo( + limit=limit, + total=page.total, + next_cursor=page.next_cursor, + has_more=page.has_more, + ), + ) + + +@router.post( + "/{control_id}/clone", + dependencies=[Depends(require_admin_key)], + response_model=CloneControlResponse, + summary="Clone a control", + response_description="Cloned control metadata", +) +async def clone_control( + control_id: int, + request: CloneControlRequest | None = None, + db: AsyncSession = Depends(get_async_db), +) -> CloneControlResponse: + """Clone an active control into a new independent control with provenance.""" + max_auto_name_retries = 5 + control_service = ControlService(db) + source_control = await control_service.get_active_control_or_404(control_id, for_update=True) + _parse_stored_control_data( + source_control.data, + control_name=source_control.name, + control_id=source_control.id, + ) + + requested_name = request.name if request is not None else None + if ( + requested_name is not None + and await control_service.active_control_name_exists(requested_name) + ): + raise ConflictError( + error_code=ErrorCode.CONTROL_NAME_CONFLICT, + detail=f"Control with name '{requested_name}' already exists", + resource="Control", + resource_id=requested_name, + hint="Choose a different name or update the existing control.", + ) + + source_name = source_control.name + auto_name_attempts = 0 + + while True: + target_name = ( + requested_name + if requested_name is not None + else await control_service.generate_unique_clone_name(source_control.name) + ) + try: + cloned_control = await control_service.clone_control( + source_control=source_control, + name=target_name, + ) + await db.commit() + break + except IntegrityError as exc: + await db.rollback() + if _is_control_name_conflict(exc): + if requested_name is not None: + raise ConflictError( + error_code=ErrorCode.CONTROL_NAME_CONFLICT, + detail=f"Control with name '{target_name}' already exists", + resource="Control", + resource_id=target_name, + hint="Choose a different name or update the existing control.", + ) + auto_name_attempts += 1 + if auto_name_attempts >= max_auto_name_retries: + raise ConflictError( + error_code=ErrorCode.CONTROL_NAME_CONFLICT, + detail=f"Failed to generate a unique clone name for '{source_name}'", + resource="Control", + resource_id=str(control_id), + hint="Retry the clone request or provide an explicit name.", + ) + source_control = await control_service.get_active_control_or_404( + control_id, + for_update=True, + ) + _parse_stored_control_data( + source_control.data, + control_name=source_control.name, + control_id=source_control.id, + ) + source_name = source_control.name + continue + _logger.error( + "Failed to clone control '%s' (%s) due to integrity error", + source_name, + control_id, + exc_info=True, + ) + raise DatabaseError( + detail=f"Failed to clone control '{source_name}': database error", + resource="Control", + operation="clone", + ) + except Exception: + await db.rollback() + _logger.error( + "Failed to clone control '%s' (%s)", + source_name, + control_id, + exc_info=True, + ) + raise DatabaseError( + detail=f"Failed to clone control '{source_name}': database error", + resource="Control", + operation="clone", + ) + + return CloneControlResponse( + control_id=cloned_control.id, + name=cloned_control.name, + cloned_control_id=source_control.id, + ) + + @router.delete( "/{control_id}", dependencies=[Depends(require_admin_key)], @@ -975,6 +1320,7 @@ async def delete_control( control_service.mark_control_deleted(control, deleted_at=dt.datetime.now(dt.UTC)) control_name = control.name try: + await control_service.remove_all_store_publications(control_id) await control_service.create_version( control, event_type="deleted", diff --git a/server/src/agent_control_server/endpoints/policies.py b/server/src/agent_control_server/endpoints/policies.py index 7b8b2ef9..d1266c8b 100644 --- a/server/src/agent_control_server/endpoints/policies.py +++ b/server/src/agent_control_server/endpoints/policies.py @@ -1,4 +1,4 @@ -from agent_control_models.errors import ErrorCode +from agent_control_models.errors import ErrorCode, ValidationErrorItem from agent_control_models.server import ( AssocResponse, CreatePolicyRequest, @@ -119,7 +119,27 @@ async def add_control_to_policy( ) control_service = ControlService(db) - control = await control_service.get_active_control_or_404(control_id) + control = await control_service.get_active_control_or_404(control_id, for_update=True) + if await control_service.is_control_published(control_id): + raise ConflictError( + error_code=ErrorCode.CONTROL_PUBLISHED, + detail=( + f"Control '{control.name}' is published in the Control Store and " + "cannot be attached directly to a policy" + ), + resource="Control", + resource_id=str(control_id), + hint="Clone the published control first, then attach the clone to the policy.", + errors=[ + ValidationErrorItem( + resource="Control", + field="control_id", + code="published_control_conflict", + message="Published controls must be cloned before policy association.", + value=control_id, + ) + ], + ) # Add association using INSERT ... ON CONFLICT DO NOTHING for idempotency try: diff --git a/server/src/agent_control_server/main.py b/server/src/agent_control_server/main.py index 7f3ac718..868f6524 100644 --- a/server/src/agent_control_server/main.py +++ b/server/src/agent_control_server/main.py @@ -20,6 +20,7 @@ from .db import AsyncSessionLocal from .endpoints.agents import router as agent_router from .endpoints.controls import router as control_router +from .endpoints.controls import store_router as control_store_router from .endpoints.controls import template_router as control_template_router from .endpoints.evaluation import router as evaluation_router from .endpoints.evaluators import router as evaluator_router @@ -208,6 +209,11 @@ async def attach_version_header(request, call_next): # type: ignore[no-untyped- prefix=api_v1_prefix, dependencies=[Depends(require_api_key)], ) +app.include_router( + control_store_router, + prefix=api_v1_prefix, + dependencies=[Depends(require_api_key)], +) app.include_router( control_template_router, prefix=api_v1_prefix, diff --git a/server/src/agent_control_server/models.py b/server/src/agent_control_server/models.py index 583e6181..7227cf90 100644 --- a/server/src/agent_control_server/models.py +++ b/server/src/agent_control_server/models.py @@ -56,6 +56,27 @@ class AgentData(BaseModel): Column("control_id", ForeignKey("controls.id"), primary_key=True, index=True), ) +# Association table for ControlStore <> Control publication relationship +control_stores_controls: Table = Table( + "control_stores_controls", + Base.metadata, + Column("store_id", ForeignKey("control_stores.id"), primary_key=True, index=True), + Column("control_id", ForeignKey("controls.id"), primary_key=True, index=True), + Column( + "published_at", + DateTime(timezone=True), + server_default=text("CURRENT_TIMESTAMP"), + nullable=False, + ), + Index( + "idx_control_stores_controls_store_published", + "store_id", + "published_at", + "control_id", + ), + Index("idx_control_stores_controls_control", "control_id"), +) + class Policy(Base): __tablename__ = "policies" @@ -71,6 +92,23 @@ class Policy(Base): ) +class ControlStore(Base): + __tablename__ = "control_stores" + + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + name: Mapped[str] = mapped_column(String(255), nullable=False, unique=True) + created_at: Mapped[dt.datetime] = mapped_column( + DateTime(timezone=True), + server_default=text("CURRENT_TIMESTAMP"), + nullable=False, + ) + controls: Mapped[list["Control"]] = relationship( + "Control", + secondary=lambda: control_stores_controls, + back_populates="stores", + ) + + class Control(Base): __tablename__ = "controls" __table_args__ = ( @@ -85,6 +123,11 @@ class Control(Base): id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) name: Mapped[str] = mapped_column(String(255), nullable=False) + cloned_control_id: Mapped[int | None] = mapped_column( + Integer, + ForeignKey("controls.id"), + nullable=True, + ) # JSONB payload describing control specifics data: Mapped[dict[str, Any]] = mapped_column( JSONB, server_default=text("'{}'::jsonb"), nullable=False @@ -100,6 +143,11 @@ class Control(Base): agents: Mapped[list["Agent"]] = relationship( "Agent", secondary=lambda: agent_controls, back_populates="controls" ) + stores: Mapped[list["ControlStore"]] = relationship( + "ControlStore", + secondary=lambda: control_stores_controls, + back_populates="controls", + ) class ControlVersion(Base): diff --git a/server/src/agent_control_server/services/controls.py b/server/src/agent_control_server/services/controls.py index 73dee6e8..77e9020b 100644 --- a/server/src/agent_control_server/services/controls.py +++ b/server/src/agent_control_server/services/controls.py @@ -2,6 +2,7 @@ import datetime as dt from collections.abc import Sequence +from copy import deepcopy from dataclasses import dataclass from typing import Any, Literal, cast @@ -13,13 +14,35 @@ from agent_control_models.errors import ErrorCode, ValidationErrorItem from agent_control_models.policy import Control as APIControl from pydantic import ValidationError -from sqlalchemy import Integer, String, delete, func, literal, or_, select, union, union_all +from sqlalchemy import ( + Integer, + String, + and_, + delete, + func, + insert, + literal, + or_, + select, + union, + union_all, +) from sqlalchemy.dialects.postgresql import insert as pg_insert +from sqlalchemy.exc import OperationalError, ProgrammingError from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import load_only from sqlalchemy.sql import Select from ..errors import APIValidationError, NotFoundError -from ..models import Control, ControlVersion, agent_controls, agent_policies, policy_controls +from ..models import ( + Control, + ControlStore, + ControlVersion, + agent_controls, + agent_policies, + control_stores_controls, + policy_controls, +) from .control_definitions import ( parse_control_definition_or_api_error, parse_runtime_control_definition_or_api_error, @@ -29,6 +52,9 @@ type AgentControlRenderedState = Literal["rendered", "unrendered", "all"] type AgentControlEnabledState = Literal["enabled", "disabled", "all"] +_DEFAULT_CONTROL_STORE_NAME = "default" +_MAX_CONTROL_NAME_LENGTH = 255 + @dataclass(frozen=True) class RuntimeControl: @@ -59,6 +85,24 @@ class ControlListPage: next_cursor: str | None +@dataclass(frozen=True) +class PublishedControlRow: + """Published control row plus store publication metadata.""" + + control: Control + published_at: dt.datetime + + +@dataclass(frozen=True) +class PublishedControlPage: + """Paginated published-control results for the default store.""" + + controls: list[PublishedControlRow] + total: int + has_more: bool + next_cursor: str | None + + @dataclass(frozen=True) class ControlUsage: """Usage attribution summary for a listed control.""" @@ -89,11 +133,33 @@ class ControlService: def __init__(self, db: AsyncSession) -> None: self._db = db - def create_control(self, *, name: str, data: dict[str, Any]) -> Control: - """Create a new pending control row.""" - control = Control(name=name, data=data) - self._db.add(control) - return control + @staticmethod + def _with_control_runtime_columns(stmt: Select[Any]) -> Select[Any]: + """Restrict control reads to columns that exist before Phase 3.""" + return stmt.options( + load_only( + Control.id, + Control.name, + Control.data, + Control.deleted_at, + ) + ) + + def create_control( + self, + *, + name: str, + data: dict[str, Any], + cloned_control_id: int | None = None, + ) -> Control: + """Create a pending control object to be inserted with its first version.""" + control_kwargs: dict[str, Any] = { + "name": name, + "data": deepcopy(data), + } + if cloned_control_id is not None: + control_kwargs["cloned_control_id"] = cloned_control_id + return Control(**control_kwargs) @staticmethod def rename_control(control: Control, *, name: str) -> None: @@ -124,7 +190,9 @@ async def get_control_or_404( for_update: bool = False, ) -> Control: """Load any control row, including soft-deleted controls.""" - stmt = select(Control).where(Control.id == control_id) + stmt = self._with_control_runtime_columns( + select(Control).where(Control.id == control_id) + ) if for_update: stmt = stmt.with_for_update() result = await self._db.execute(stmt) @@ -146,7 +214,9 @@ async def get_active_control_or_404( for_update: bool = False, ) -> Control: """Load an active control row or raise CONTROL_NOT_FOUND.""" - stmt = select(Control).where(Control.id == control_id, Control.deleted_at.is_(None)) + stmt = self._with_control_runtime_columns( + select(Control).where(Control.id == control_id, Control.deleted_at.is_(None)) + ) if for_update: stmt = stmt.with_for_update() result = await self._db.execute(stmt) @@ -174,6 +244,94 @@ async def active_control_name_exists( result = await self._db.execute(stmt) return result.first() is not None + async def generate_unique_clone_name(self, source_name: str) -> str: + """Return a unique active-control name for a cloned control.""" + copy_index = 1 + while True: + suffix = "-copy" if copy_index == 1 else f"-copy-{copy_index}" + max_base_length = max(1, _MAX_CONTROL_NAME_LENGTH - len(suffix)) + candidate = f"{source_name[:max_base_length]}{suffix}" + if not await self.active_control_name_exists(candidate): + return candidate + copy_index += 1 + + async def clone_control( + self, + *, + source_control: Control, + name: str | None = None, + ) -> Control: + """Create a cloned control with provenance and initial version history.""" + target_name = name or await self.generate_unique_clone_name(source_control.name) + source_version_num = await self._latest_version_num(source_control.id) + cloned_control = self.create_control( + name=target_name, + data=source_control.data, + cloned_control_id=source_control.id, + ) + await self.create_version( + cloned_control, + event_type="cloned", + note=( + f"Cloned from '{source_control.name}' " + f"(id:{source_control.id}) at version {source_version_num}" + ), + ) + return cloned_control + + async def is_control_published(self, control_id: int) -> bool: + """Return whether a control is currently published in the default store.""" + default_store_id = await self._get_default_store_id(required=False) + if default_store_id is None: + return False + + try: + async with self._db.begin_nested(): + result = await self._db.execute( + select(control_stores_controls.c.control_id).where( + control_stores_controls.c.store_id == default_store_id, + control_stores_controls.c.control_id == control_id, + ) + ) + except (OperationalError, ProgrammingError) as exc: + if _is_missing_control_store_schema_error(exc): + return False + raise + return result.first() is not None + + async def publish_control(self, control_id: int) -> None: + """Publish a control to the default store idempotently.""" + default_store_id = await self._get_default_store_id() + await self._db.execute( + pg_insert(control_stores_controls) + .values(store_id=default_store_id, control_id=control_id) + .on_conflict_do_nothing() + ) + + async def unpublish_control(self, control_id: int) -> None: + """Remove a control from the default store idempotently.""" + default_store_id = await self._get_default_store_id() + await self._db.execute( + delete(control_stores_controls).where( + control_stores_controls.c.store_id == default_store_id, + control_stores_controls.c.control_id == control_id, + ) + ) + + async def remove_all_store_publications(self, control_id: int) -> None: + """Remove all store publication rows for a control.""" + try: + async with self._db.begin_nested(): + await self._db.execute( + delete(control_stores_controls).where( + control_stores_controls.c.control_id == control_id + ) + ) + except (OperationalError, ProgrammingError) as exc: + if _is_missing_control_store_schema_error(exc): + return + raise + async def create_version( self, control: Control, @@ -183,14 +341,17 @@ async def create_version( ) -> ControlVersion: """Append a new immutable version row for the current control state.""" await self._db.flush() + if control.id is None: + await self._insert_control_row(control) await self._lock_control_row(control.id) + cloned_control_id = await self._get_snapshot_cloned_control_id(control) next_version_num = await self._next_version_num(control.id) version = ControlVersion( control_id=control.id, version_num=next_version_num, event_type=event_type, - snapshot=self._build_snapshot(control), + snapshot=self._build_snapshot(control, cloned_control_id=cloned_control_id), note=note, ) self._db.add(version) @@ -265,7 +426,7 @@ async def get_version_or_404(self, control_id: int, version_num: int) -> Control async def list_controls_for_policy(self, policy_id: int) -> list[Control]: """Return DB control rows directly associated with a policy.""" - stmt = ( + stmt = self._with_control_runtime_columns( select(Control) .join(policy_controls, Control.id == policy_controls.c.control_id) .where(policy_controls.c.policy_id == policy_id, Control.deleted_at.is_(None)) @@ -370,7 +531,9 @@ async def list_controls_page( tag: str | None, ) -> ControlListPage: """Return paginated active controls for the browse endpoint.""" - query = select(Control).where(Control.deleted_at.is_(None)).order_by(Control.id.desc()) + query = self._with_control_runtime_columns( + select(Control).where(Control.deleted_at.is_(None)).order_by(Control.id.desc()) + ) query = self._apply_control_list_filters( query, name=name, @@ -416,6 +579,94 @@ async def list_controls_page( next_cursor=next_cursor, ) + async def list_published_controls_page( + self, + *, + cursor: str | None, + limit: int, + name: str | None, + enabled: bool | None, + tag: str | None, + ) -> PublishedControlPage: + """Return paginated published controls from the default store.""" + default_store_id = await self._get_default_store_id() + query = ( + select(Control, control_stores_controls.c.published_at) + .join(control_stores_controls, Control.id == control_stores_controls.c.control_id) + .where( + control_stores_controls.c.store_id == default_store_id, + Control.deleted_at.is_(None), + ) + .order_by( + control_stores_controls.c.published_at.desc(), + control_stores_controls.c.control_id.desc(), + ) + ) + query = self._apply_published_control_filters( + query, + name=name, + enabled=enabled, + tag=tag, + ) + + if cursor is not None: + cursor_published_at, cursor_control_id = _parse_published_control_cursor(cursor) + query = query.where( + or_( + control_stores_controls.c.published_at < cursor_published_at, + and_( + control_stores_controls.c.published_at == cursor_published_at, + control_stores_controls.c.control_id < cursor_control_id, + ), + ) + ) + + result = await self._db.execute(query.limit(limit + 1)) + rows = result.all() + controls = [ + PublishedControlRow( + control=cast(Control, row[0]), + published_at=cast(dt.datetime, row[1]), + ) + for row in rows + ] + + total_query = ( + select(func.count()) + .select_from(control_stores_controls) + .join(Control, Control.id == control_stores_controls.c.control_id) + .where( + control_stores_controls.c.store_id == default_store_id, + Control.deleted_at.is_(None), + ) + ) + total_query = self._apply_published_control_filters( + total_query, + name=name, + enabled=enabled, + tag=tag, + ) + total_result = await self._db.execute(total_query) + total = cast(int, total_result.scalar_one()) + + has_more = len(controls) > limit + if has_more: + controls = controls[:-1] + + next_cursor: str | None = None + if has_more and controls: + next_cursor = _build_published_control_cursor( + controls[-1].published_at, + controls[-1].control.id, + ) + + return PublishedControlPage( + controls=controls, + total=total, + has_more=has_more, + next_cursor=next_cursor, + ) + async def list_control_usage(self, control_ids: Sequence[int]) -> dict[int, ControlUsage]: """Return representative agent usage and usage counts for the provided controls.""" if not control_ids: @@ -593,6 +844,48 @@ async def remove_all_control_associations(self, control_id: int) -> ControlAssoc ) return associations + async def _get_default_store_id(self, *, required: bool = True) -> int | None: + """Return the seeded default-store ID.""" + try: + if required: + result = await self._db.execute( + select(ControlStore.id).where(ControlStore.name == _DEFAULT_CONTROL_STORE_NAME) + ) + else: + async with self._db.begin_nested(): + result = await self._db.execute( + select(ControlStore.id).where( + ControlStore.name == _DEFAULT_CONTROL_STORE_NAME + ) + ) + except (OperationalError, ProgrammingError) as exc: + if _is_missing_control_store_schema_error(exc): + if not required: + return None + raise RuntimeError( + "Default control store is unavailable; run the Phase 3 migration " + "before using control-store endpoints." + ) from exc + raise + store_id = result.scalar_one_or_none() + if store_id is None: + if not required: + return None + raise RuntimeError( + "Default control store is missing; run the Phase 3 migration before using " + "control-store endpoints." + ) + return cast(int, store_id) + + async def _latest_version_num(self, control_id: int) -> int: + """Return the latest version number currently stored for a control.""" + result = await self._db.execute( + select(func.coalesce(func.max(ControlVersion.version_num), 0)).where( + ControlVersion.control_id == control_id + ) + ) + return cast(int, result.scalar_one()) + async def _next_version_num(self, control_id: int) -> int: """Compute the next monotonically increasing version number for a control.""" result = await self._db.execute( @@ -608,6 +901,39 @@ async def _lock_control_row(self, control_id: int) -> None: select(Control.id).where(Control.id == control_id).with_for_update() ) + async def _insert_control_row(self, control: Control) -> None: + """Insert a new control row without requiring post-Phase-3 columns.""" + insert_values: dict[str, Any] = { + "name": control.name, + "data": deepcopy(control.data), + } + if control.deleted_at is not None: + insert_values["deleted_at"] = control.deleted_at + cloned_control_id = cast(int | None, control.__dict__.get("cloned_control_id")) + if cloned_control_id is not None: + insert_values["cloned_control_id"] = cloned_control_id + + result = await self._db.execute( + insert(Control.__table__).values(**insert_values).returning(Control.id) + ) + control.id = cast(int, result.scalar_one()) + + async def _get_snapshot_cloned_control_id(self, control: Control) -> int | None: + """Load clone provenance for version snapshots with rollout-safe fallback.""" + if "cloned_control_id" in control.__dict__: + return cast(int | None, control.__dict__["cloned_control_id"]) + + try: + async with self._db.begin_nested(): + result = await self._db.execute( + select(Control.cloned_control_id).where(Control.id == control.id) + ) + except (OperationalError, ProgrammingError) as exc: + if _is_missing_cloned_control_id_schema_error(exc): + return None + raise + return cast(int | None, result.scalar_one()) + async def _list_db_controls_for_agent(self, agent_name: str) -> Sequence[Control]: """Return DB control rows associated with an agent.""" policy_control_ids = ( @@ -624,7 +950,7 @@ async def _list_db_controls_for_agent(self, agent_name: str) -> Sequence[Control ) control_ids_subquery = union(policy_control_ids, direct_control_ids).subquery() - stmt = ( + stmt = self._with_control_runtime_columns( select(Control) .join(control_ids_subquery, Control.id == control_ids_subquery.c.control_id) .where(Control.deleted_at.is_(None)) @@ -696,11 +1022,44 @@ def _apply_control_list_filters( return stmt + def _apply_published_control_filters( + self, + stmt: Select[Any], + *, + name: str | None, + enabled: bool | None, + tag: str | None, + ) -> Select[Any]: + """Apply published-store browse filters to a query.""" + if name is not None: + stmt = stmt.where( + Control.name.ilike(f"%{escape_like_pattern(name)}%", escape="\\") + ) + + if enabled is not None: + if enabled: + stmt = stmt.where( + or_( + Control.data["enabled"].astext == "true", + ~Control.data.has_key("enabled"), + ) + ) + else: + stmt = stmt.where(Control.data["enabled"].astext == "false") + + if tag is not None: + stmt = stmt.where(Control.data["tags"].contains([tag])) + + return stmt + @staticmethod - def _build_snapshot(control: Control) -> dict[str, Any]: + def _build_snapshot( + control: Control, + *, + cloned_control_id: int | None, + ) -> dict[str, Any]: """Serialize the persisted control state stored in version history.""" deleted_at = control.deleted_at.isoformat() if control.deleted_at is not None else None - cloned_control_id = cast(int | None, getattr(control, "cloned_control_id", None)) return { "name": control.name, "data": control.data, @@ -718,6 +1077,35 @@ def _is_unrendered_template_payload(data: object) -> bool: ) +def _build_published_control_cursor(published_at: dt.datetime, control_id: int) -> str: + """Encode the published-control sort key into an opaque cursor string.""" + return f"{published_at.isoformat()}::{control_id}" + + +def _parse_published_control_cursor(cursor: str) -> tuple[dt.datetime, int]: + """Decode a published-control cursor or raise the standard invalid-cursor error.""" + try: + published_at_raw, control_id_raw = cursor.rsplit("::", 1) + published_at = dt.datetime.fromisoformat(published_at_raw) + control_id = int(control_id_raw) + except (TypeError, ValueError) as exc: + raise APIValidationError( + error_code=ErrorCode.VALIDATION_ERROR, + detail="Published-control cursor is invalid", + resource="ControlStore", + errors=[ + ValidationErrorItem( + resource="ControlStore", + field="cursor", + code="invalid_cursor", + message="Cursor must come from a previous published-controls response.", + value=cursor, + ) + ], + ) from exc + return published_at, control_id + + def _parse_unrendered_template_or_api_error(control: Control) -> UnrenderedTemplateControl: """Parse an unrendered template control or raise the standard corrupted-data error.""" try: @@ -790,3 +1178,23 @@ def _matches_enabled_state( if enabled_state == "enabled": return is_enabled return not is_enabled + + +def _is_missing_control_store_schema_error( + error: OperationalError | ProgrammingError, +) -> bool: + """Return whether an error came from the pre-Phase-3 store schema being absent.""" + error_text = " ".join(part for part in (str(error.orig), str(error)) if part).lower() + if "control_stores" not in error_text and "control_stores_controls" not in error_text: + return False + return "does not exist" in error_text or "no such table" in error_text + + +def _is_missing_cloned_control_id_schema_error( + error: OperationalError | ProgrammingError, +) -> bool: + """Return whether an error came from pre-Phase-3 clone provenance schema.""" + error_text = " ".join(part for part in (str(error.orig), str(error)) if part).lower() + if "cloned_control_id" not in error_text: + return False + return "does not exist" in error_text or "no such column" in error_text diff --git a/server/tests/conftest.py b/server/tests/conftest.py index 5920344f..8d8fdf43 100644 --- a/server/tests/conftest.py +++ b/server/tests/conftest.py @@ -1,14 +1,13 @@ import pytest +from agent_control_engine import discover_evaluators from fastapi.testclient import TestClient from sqlalchemy import MetaData, create_engine, inspect, text -from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine -from agent_control_engine import discover_evaluators from agent_control_server.config import auth_settings, db_config from agent_control_server.db import Base from agent_control_server.main import app as fastapi_app - -import agent_control_server.models # ensure models are imported so tables are registered +from agent_control_server.models import ControlStore # Discover evaluators at test session start discover_evaluators() @@ -48,6 +47,20 @@ def _truncate_all_tables() -> None: conn.execute(table.delete()) +def _seed_default_control_store() -> None: + """Seed the default control store for tests that create tables directly.""" + with engine.begin() as conn: + schema = "public" if conn.dialect.name == "postgresql" else None + if "control_stores" not in inspect(conn).get_table_names(schema=schema): + return + + existing = conn.execute( + text("SELECT id FROM control_stores WHERE name = 'default' LIMIT 1") + ).scalar() + if existing is None: + conn.execute(ControlStore.__table__.insert().values(name="default")) + + @pytest.fixture(scope="session") def db_engine(): """Provide the sqlalchemy engine for tests.""" @@ -79,10 +92,16 @@ def db_schema() -> None: admin_engine.dispose() # Recreate tables for tests in the configured database. - reflected_metadata = MetaData() - reflected_metadata.reflect(bind=engine) - reflected_metadata.drop_all(bind=engine) + if engine.dialect.name == "postgresql": + with engine.begin() as conn: + conn.execute(text("DROP SCHEMA IF EXISTS public CASCADE")) + conn.execute(text("CREATE SCHEMA public")) + else: + reflected_metadata = MetaData() + reflected_metadata.reflect(bind=engine) + reflected_metadata.drop_all(bind=engine) Base.metadata.create_all(bind=engine) + _seed_default_control_store() yield @@ -93,7 +112,12 @@ def setup_auth(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr(auth_settings, "api_keys", TEST_API_KEY) monkeypatch.setattr(auth_settings, "admin_api_keys", TEST_ADMIN_API_KEY) # Clear cached properties so they recompute with monkeypatched values - for attr in ("_parsed_api_keys", "_parsed_admin_api_keys", "_all_valid_keys", "_all_admin_keys"): + for attr in ( + "_parsed_api_keys", + "_parsed_admin_api_keys", + "_all_valid_keys", + "_all_admin_keys", + ): auth_settings.__dict__.pop(attr, None) @@ -136,6 +160,7 @@ def unauthenticated_client(app: object) -> TestClient: @pytest.fixture(autouse=True) def clean_db(): _truncate_all_tables() + _seed_default_control_store() yield diff --git a/server/tests/test_control_phase3_alembic_migration.py b/server/tests/test_control_phase3_alembic_migration.py new file mode 100644 index 00000000..a22e16e2 --- /dev/null +++ b/server/tests/test_control_phase3_alembic_migration.py @@ -0,0 +1,343 @@ +"""Alembic coverage for Phase 3 control-store schema changes.""" + +from __future__ import annotations + +import asyncio +import json +import uuid +from pathlib import Path + +import pytest +from alembic.config import Config +from fastapi import FastAPI +from fastapi.testclient import TestClient +from sqlalchemy import create_engine, inspect, text +from sqlalchemy.engine import Engine, make_url +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine + +from agent_control_server.config import db_config +from agent_control_server.db import get_async_db +from alembic import command + +from .conftest import TEST_ADMIN_API_KEY +from .utils import VALID_CONTROL_PAYLOAD + +SERVER_DIR = Path(__file__).resolve().parents[1] +PRE_MIGRATION_REVISION = "c1e9f9c4a1d2" +MIGRATION_REVISION = "7d9c2f1a3b44" +_BASE_DB_URL = make_url(db_config.get_url()) + +pytestmark = pytest.mark.skipif( + _BASE_DB_URL.get_backend_name() != "postgresql", + reason="Phase 3 Alembic migration tests require PostgreSQL.", +) + + +def _insert_control( + engine: Engine, + *, + name: str, + data: dict[str, object] | None = None, +) -> int: + payload = data if data is not None else {"description": "pre-phase3"} + with engine.begin() as conn: + return int( + conn.execute( + text( + """ + INSERT INTO controls (name, data) + VALUES (:name, CAST(:data AS JSONB)) + RETURNING id + """ + ), + {"name": name, "data": json.dumps(payload)}, + ).scalar_one() + ) + + +def _insert_policy(engine: Engine, *, name: str) -> int: + with engine.begin() as conn: + return int( + conn.execute( + text( + """ + INSERT INTO policies (name) + VALUES (:name) + RETURNING id + """ + ), + {"name": name}, + ).scalar_one() + ) + + +def _insert_agent(engine: Engine, *, name: str) -> str: + with engine.begin() as conn: + conn.execute( + text( + """ + INSERT INTO agents (name, data) + VALUES (:name, CAST(:data AS JSONB)) + """ + ), + { + "name": name, + "data": json.dumps( + { + "agent_metadata": {}, + "steps": [], + "evaluators": [], + } + ), + }, + ) + return name + + +@pytest.fixture +def temp_db_url() -> str: + temp_db_name = f"agent_control_phase3_{uuid.uuid4().hex[:12]}" + admin_url = _BASE_DB_URL.set(database="postgres").render_as_string(hide_password=False) + target_url = _BASE_DB_URL.set(database=temp_db_name).render_as_string(hide_password=False) + + admin_engine = create_engine(admin_url, isolation_level="AUTOCOMMIT") + with admin_engine.connect() as conn: + conn.execute(text(f'CREATE DATABASE "{temp_db_name}"')) + admin_engine.dispose() + + try: + yield target_url + finally: + cleanup_engine = create_engine(admin_url, isolation_level="AUTOCOMMIT") + with cleanup_engine.connect() as conn: + conn.execute( + text( + """ + SELECT pg_terminate_backend(pid) + FROM pg_stat_activity + WHERE datname = :db_name AND pid <> pg_backend_pid() + """ + ), + {"db_name": temp_db_name}, + ) + conn.execute(text(f'DROP DATABASE IF EXISTS "{temp_db_name}"')) + cleanup_engine.dispose() + + +@pytest.fixture +def alembic_config(temp_db_url: str) -> Config: + cfg = Config(str(SERVER_DIR / "alembic.ini")) + cfg.set_main_option("script_location", str(SERVER_DIR / "alembic")) + cfg.set_main_option("sqlalchemy.url", temp_db_url) + return cfg + + +@pytest.fixture +def temp_engine(temp_db_url: str) -> Engine: + engine = create_engine(temp_db_url, future=True) + try: + yield engine + finally: + engine.dispose() + + +@pytest.fixture +def upgrade_to(alembic_config: Config): + def _upgrade(revision: str, *, sql: bool = False) -> None: + command.upgrade(alembic_config, revision, sql=sql) + + return _upgrade + + +def test_upgrade_seeds_default_store_and_adds_clone_provenance( + upgrade_to, + temp_engine: Engine, +) -> None: + upgrade_to(PRE_MIGRATION_REVISION) + control_id = _insert_control(temp_engine, name="pre-phase3-control") + + upgrade_to(MIGRATION_REVISION) + + with temp_engine.begin() as conn: + stores = conn.execute( + text("SELECT id, name FROM control_stores ORDER BY id") + ).mappings().all() + control = conn.execute( + text( + """ + SELECT id, cloned_control_id + FROM controls + WHERE id = :control_id + """ + ), + {"control_id": control_id}, + ).mappings().one() + + assert stores == [{"id": 1, "name": "default"}] + assert control["cloned_control_id"] is None + assert "control_stores_controls" in inspect(temp_engine).get_table_names() + + +def test_upgrade_advances_control_store_identity_after_default_seed( + upgrade_to, + temp_engine: Engine, +) -> None: + upgrade_to(PRE_MIGRATION_REVISION) + upgrade_to(MIGRATION_REVISION) + + with temp_engine.begin() as conn: + next_store_id = int( + conn.execute( + text( + """ + INSERT INTO control_stores (name) + VALUES ('post-seed-store') + RETURNING id + """ + ) + ).scalar_one() + ) + + assert next_store_id > 1 + + +def test_pre_phase3_runtime_control_endpoints_remain_usable_during_rollout( + app: FastAPI, + upgrade_to, + temp_db_url: str, + temp_engine: Engine, +) -> None: + # Given: a database upgraded only to the pre-Phase-3 schema + upgrade_to(PRE_MIGRATION_REVISION) + policy_control_id = _insert_control( + temp_engine, + name="pre-phase3-policy-control", + data=VALID_CONTROL_PAYLOAD, + ) + agent_control_id = _insert_control( + temp_engine, + name="pre-phase3-agent-control", + data=VALID_CONTROL_PAYLOAD, + ) + delete_control_id = _insert_control( + temp_engine, + name="pre-phase3-delete-control", + data=VALID_CONTROL_PAYLOAD, + ) + policy_id = _insert_policy(temp_engine, name="pre-phase3-policy") + agent_name = _insert_agent(temp_engine, name="pre-phase3-agent") + + async_engine = create_async_engine(temp_db_url, echo=False) + session_factory = async_sessionmaker( + bind=async_engine, + autoflush=False, + expire_on_commit=False, + class_=AsyncSession, + ) + + async def _override_get_async_db(): + async with session_factory() as session: + yield session + + app.dependency_overrides[get_async_db] = _override_get_async_db + try: + with TestClient( + app, + raise_server_exceptions=True, + headers={"X-API-Key": TEST_ADMIN_API_KEY}, + ) as client: + # When: using existing runtime endpoints against legacy rows before store tables exist + create_response = client.put( + "/api/v1/controls", + json={ + "name": "pre-phase3-created-control", + "data": VALID_CONTROL_PAYLOAD, + }, + ) + policy_assoc = client.post( + f"/api/v1/policies/{policy_id}/controls/{policy_control_id}" + ) + + agent_assoc = client.post( + f"/api/v1/agents/{agent_name}/controls/{agent_control_id}" + ) + + controls_response = client.get("/api/v1/controls") + agent_controls_response = client.get(f"/api/v1/agents/{agent_name}/controls") + delete_response = client.delete(f"/api/v1/controls/{delete_control_id}") + + # Then: the legacy runtime endpoints and read paths still succeed + # without control-store tables + assert create_response.status_code == 200, create_response.text + assert policy_assoc.status_code == 200, policy_assoc.text + assert agent_assoc.status_code == 200, agent_assoc.text + assert controls_response.status_code == 200, controls_response.text + assert agent_controls_response.status_code == 200, agent_controls_response.text + assert delete_response.status_code == 200, delete_response.text + assert { + control["name"] for control in controls_response.json()["controls"] + } >= { + "pre-phase3-created-control", + "pre-phase3-policy-control", + "pre-phase3-agent-control", + } + assert [ + control["id"] for control in agent_controls_response.json()["controls"] + ] == [agent_control_id] + finally: + app.dependency_overrides.pop(get_async_db, None) + asyncio.run(async_engine.dispose()) + + +def test_pre_phase3_control_store_endpoints_fail_gracefully( + app: FastAPI, + upgrade_to, + temp_db_url: str, + temp_engine: Engine, +) -> None: + # Given: a database that has not yet received the Phase 3 store schema + upgrade_to(PRE_MIGRATION_REVISION) + control_id = _insert_control( + temp_engine, + name="pre-phase3-store-control", + data=VALID_CONTROL_PAYLOAD, + ) + + async_engine = create_async_engine(temp_db_url, echo=False) + session_factory = async_sessionmaker( + bind=async_engine, + autoflush=False, + expire_on_commit=False, + class_=AsyncSession, + ) + + async def _override_get_async_db(): + async with session_factory() as session: + yield session + + app.dependency_overrides[get_async_db] = _override_get_async_db + try: + with TestClient( + app, + raise_server_exceptions=True, + headers={"X-API-Key": TEST_ADMIN_API_KEY}, + ) as client: + # When: calling new control-store endpoints before the migration lands + publish_response = client.post( + f"/api/v1/control-stores/default/controls/{control_id}" + ) + unpublish_response = client.delete( + f"/api/v1/control-stores/default/controls/{control_id}" + ) + list_response = client.get("/api/v1/control-stores/default/controls") + + # Then: the endpoints fail with the standard database error envelope + assert publish_response.status_code == 500, publish_response.text + assert publish_response.json()["error_code"] == "DATABASE_ERROR" + assert unpublish_response.status_code == 500, unpublish_response.text + assert unpublish_response.json()["error_code"] == "DATABASE_ERROR" + assert list_response.status_code == 500, list_response.text + assert list_response.json()["error_code"] == "DATABASE_ERROR" + finally: + app.dependency_overrides.pop(get_async_db, None) + asyncio.run(async_engine.dispose()) diff --git a/server/tests/test_control_store.py b/server/tests/test_control_store.py new file mode 100644 index 00000000..a36a1f18 --- /dev/null +++ b/server/tests/test_control_store.py @@ -0,0 +1,1188 @@ +from __future__ import annotations + +import asyncio +import datetime as dt +import uuid +from collections.abc import AsyncGenerator +from copy import deepcopy +from types import SimpleNamespace +from typing import Any +from unittest.mock import AsyncMock, MagicMock + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient +from sqlalchemy import insert, select, update +from sqlalchemy.exc import IntegrityError +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import Session + +from agent_control_server.db import get_async_db +from agent_control_server.models import ( + Control, + ControlStore, + ControlVersion, + Policy, + control_stores_controls, + policy_controls, +) +from agent_control_server.services.controls import ControlService + +from .conftest import AsyncSessionTest, engine +from .utils import VALID_CONTROL_PAYLOAD + + +def _make_integrity_error(constraint_name: str) -> IntegrityError: + diag = SimpleNamespace(constraint_name=constraint_name) + orig = Exception(f'duplicate key value violates unique constraint "{constraint_name}"') + setattr(orig, "diag", diag) + return IntegrityError("statement", {}, orig) + + +def _unrendered_template_payload() -> dict[str, Any]: + return { + "template": { + "description": "Regex denial template", + "parameters": { + "pattern": { + "type": "regex_re2", + "label": "Pattern", + }, + }, + "definition_template": { + "description": "Template-backed control", + "execution": "server", + "scope": {"step_types": ["llm"], "stages": ["pre"]}, + "condition": { + "selector": {"path": "input"}, + "evaluator": { + "name": "regex", + "config": {"pattern": {"$param": "pattern"}}, + }, + }, + "action": {"decision": "deny"}, + }, + }, + "template_values": {}, + } + + +def _ensure_default_store() -> int: + with Session(engine) as session: + existing = session.scalar(select(ControlStore.id).where(ControlStore.name == "default")) + if existing is not None: + return int(existing) + + store = ControlStore(name="default") + session.add(store) + session.commit() + return int(store.id) + + +def _create_control( + client: TestClient, + *, + name: str | None = None, + data: dict[str, object] | None = None, +) -> tuple[int, str]: + control_name = name or f"control-{uuid.uuid4()}" + payload = deepcopy(data) if data is not None else deepcopy(VALID_CONTROL_PAYLOAD) + response = client.put("/api/v1/controls", json={"name": control_name, "data": payload}) + assert response.status_code == 200, response.text + return int(response.json()["control_id"]), control_name + + +def _create_policy(client: TestClient, *, name: str | None = None) -> int: + response = client.put( + "/api/v1/policies", + json={"name": name or f"policy-{uuid.uuid4()}"}, + ) + assert response.status_code == 200, response.text + return int(response.json()["policy_id"]) + + +def _create_agent(client: TestClient, *, name: str | None = None) -> str: + agent_name = name or f"agent-{uuid.uuid4().hex[:12]}" + response = client.post( + "/api/v1/agents/initAgent", + json={"agent": {"agent_name": agent_name}, "steps": []}, + ) + assert response.status_code == 200, response.text + return agent_name + + +async def _create_versioned_control( + *, + name: str | None = None, + data: dict[str, object] | None = None, +) -> tuple[int, str]: + control_name = name or f"control-{uuid.uuid4()}" + payload = deepcopy(data) if data is not None else deepcopy(VALID_CONTROL_PAYLOAD) + + async with AsyncSessionTest() as session: + service = ControlService(session) + control = service.create_control(name=control_name, data=payload) + await service.create_version( + control, + event_type="created", + note="Initial creation", + ) + await session.commit() + return control.id, control_name + + +async def _create_policy_row(*, name: str | None = None) -> int: + async with AsyncSessionTest() as session: + policy = Policy(name=name or f"policy-{uuid.uuid4()}") + session.add(policy) + await session.commit() + return policy.id + + +def _insert_raw_control( + *, + name: str | None = None, + data: dict[str, object] | None = None, +) -> tuple[int, str]: + control_name = name or f"control-{uuid.uuid4()}" + control = Control( + name=control_name, + data=deepcopy(data) if data is not None else {}, + ) + with Session(engine) as session: + session.add(control) + session.commit() + session.refresh(control) + return int(control.id), control_name + + +def _fetch_control(control_id: int) -> Control | None: + with Session(engine) as session: + return session.scalars(select(Control).where(Control.id == control_id)).first() + + +def _fetch_versions(control_id: int) -> list[ControlVersion]: + with Session(engine) as session: + return list( + session.scalars( + select(ControlVersion) + .where(ControlVersion.control_id == control_id) + .order_by(ControlVersion.version_num) + ).all() + ) + + +def _published_rows(control_id: int | None = None) -> list[tuple[int, int, dt.datetime]]: + with Session(engine) as session: + stmt = select( + control_stores_controls.c.store_id, + control_stores_controls.c.control_id, + control_stores_controls.c.published_at, + ).order_by(control_stores_controls.c.control_id) + if control_id is not None: + stmt = stmt.where(control_stores_controls.c.control_id == control_id) + return [ + (int(store_id), int(published_control_id), published_at) + for store_id, published_control_id, published_at in session.execute(stmt).all() + ] + + +def _set_published_at(control_id: int, published_at: dt.datetime) -> None: + with Session(engine) as session: + session.execute( + update(control_stores_controls) + .where(control_stores_controls.c.control_id == control_id) + .values(published_at=published_at) + ) + session.commit() + + +def test_publish_control_is_idempotent(client: TestClient) -> None: + _ensure_default_store() + control_id, _ = _create_control(client) + + first = client.post(f"/api/v1/control-stores/default/controls/{control_id}") + second = client.post(f"/api/v1/control-stores/default/controls/{control_id}") + + assert first.status_code == 200, first.text + assert second.status_code == 200, second.text + assert _published_rows(control_id) == [_published_rows(control_id)[0]] + + +def test_publish_control_rejects_missing_and_deleted_controls(client: TestClient) -> None: + _ensure_default_store() + missing = client.post("/api/v1/control-stores/default/controls/99999") + assert missing.status_code == 404 + assert missing.json()["error_code"] == "CONTROL_NOT_FOUND" + + control_id, _ = _create_control(client) + delete_response = client.delete(f"/api/v1/controls/{control_id}") + assert delete_response.status_code == 200, delete_response.text + + deleted = client.post(f"/api/v1/control-stores/default/controls/{control_id}") + assert deleted.status_code == 404 + assert deleted.json()["error_code"] == "CONTROL_NOT_FOUND" + + +def test_publish_control_accepts_rendered_and_unrendered_controls(client: TestClient) -> None: + _ensure_default_store() + rendered_id, _ = _create_control(client) + unrendered_id, _ = _create_control( + client, + data=_unrendered_template_payload(), + ) + + rendered_response = client.post(f"/api/v1/control-stores/default/controls/{rendered_id}") + unrendered_response = client.post(f"/api/v1/control-stores/default/controls/{unrendered_id}") + + assert rendered_response.status_code == 200, rendered_response.text + assert unrendered_response.status_code == 200, unrendered_response.text + assert [row[1] for row in _published_rows()] == [rendered_id, unrendered_id] + + +def test_publish_control_rejects_runtime_associations(client: TestClient) -> None: + _ensure_default_store() + + policy_control_id, _ = _create_control(client) + policy_id = _create_policy(client) + policy_assoc = client.post(f"/api/v1/policies/{policy_id}/controls/{policy_control_id}") + assert policy_assoc.status_code == 200, policy_assoc.text + + policy_publish = client.post(f"/api/v1/control-stores/default/controls/{policy_control_id}") + assert policy_publish.status_code == 409 + assert policy_publish.json()["error_code"] == "CONTROL_IN_USE" + + agent_control_id, _ = _create_control(client) + agent_name = _create_agent(client) + agent_assoc = client.post(f"/api/v1/agents/{agent_name}/controls/{agent_control_id}") + assert agent_assoc.status_code == 200, agent_assoc.text + + agent_publish = client.post(f"/api/v1/control-stores/default/controls/{agent_control_id}") + assert agent_publish.status_code == 409 + assert agent_publish.json()["error_code"] == "CONTROL_IN_USE" + + +def test_unpublish_control_removes_publication_state(client: TestClient) -> None: + # Given: a published control in the default store + _ensure_default_store() + control_id, _ = _create_control(client) + publish_response = client.post(f"/api/v1/control-stores/default/controls/{control_id}") + assert publish_response.status_code == 200, publish_response.text + + # When: unpublishing it twice + first = client.delete(f"/api/v1/control-stores/default/controls/{control_id}") + second = client.delete(f"/api/v1/control-stores/default/controls/{control_id}") + + # Then: the endpoint is idempotent and leaves no publication rows behind + assert first.status_code == 200, first.text + assert second.status_code == 200, second.text + assert _published_rows(control_id) == [] + + +def test_unpublish_control_rejects_missing_and_deleted_controls(client: TestClient) -> None: + # Given: the default store exists + _ensure_default_store() + + # When: unpublishing a control that does not exist + missing = client.delete("/api/v1/control-stores/default/controls/99999") + + # Then: the API reports the control as missing + assert missing.status_code == 404 + assert missing.json()["error_code"] == "CONTROL_NOT_FOUND" + + # Given: a control that has already been soft-deleted + control_id, _ = _create_control(client) + delete_response = client.delete(f"/api/v1/controls/{control_id}") + assert delete_response.status_code == 200, delete_response.text + + # When: unpublishing the soft-deleted control + deleted = client.delete(f"/api/v1/control-stores/default/controls/{control_id}") + + # Then: the deleted control is treated as not found + assert deleted.status_code == 404 + assert deleted.json()["error_code"] == "CONTROL_NOT_FOUND" + + +def test_list_published_controls_uses_cursor_pagination_and_name_filter( + client: TestClient, +) -> None: + _ensure_default_store() + alpha_id, _ = _create_control(client, name="AlphaControl") + beta_id, _ = _create_control(client, name="BetaDetector") + gamma_id, _ = _create_control(client, name="GammaControl") + + for control_id in (alpha_id, beta_id, gamma_id): + response = client.post(f"/api/v1/control-stores/default/controls/{control_id}") + assert response.status_code == 200, response.text + + _set_published_at(alpha_id, dt.datetime(2026, 4, 15, 10, 0, tzinfo=dt.UTC)) + _set_published_at(beta_id, dt.datetime(2026, 4, 15, 11, 0, tzinfo=dt.UTC)) + _set_published_at(gamma_id, dt.datetime(2026, 4, 15, 12, 0, tzinfo=dt.UTC)) + + first_page = client.get("/api/v1/control-stores/default/controls", params={"limit": 2}) + assert first_page.status_code == 200, first_page.text + first_body = first_page.json() + + assert [item["id"] for item in first_body["controls"]] == [gamma_id, beta_id] + assert first_body["pagination"]["has_more"] is True + assert first_body["pagination"]["next_cursor"] is not None + + second_page = client.get( + "/api/v1/control-stores/default/controls", + params={"limit": 2, "cursor": first_body["pagination"]["next_cursor"]}, + ) + assert second_page.status_code == 200, second_page.text + second_body = second_page.json() + + assert [item["id"] for item in second_body["controls"]] == [alpha_id] + assert second_body["pagination"]["has_more"] is False + assert second_body["pagination"]["next_cursor"] is None + + filtered = client.get( + "/api/v1/control-stores/default/controls", + params={"name": "detec"}, + ) + assert filtered.status_code == 200, filtered.text + assert [item["id"] for item in filtered.json()["controls"]] == [beta_id] + + +def test_list_published_controls_filters_by_tag_and_enabled(client: TestClient) -> None: + _ensure_default_store() + enabled_payload = deepcopy(VALID_CONTROL_PAYLOAD) + enabled_payload["tags"] = ["pci"] + enabled_id, _ = _create_control(client, name="enabled-control", data=enabled_payload) + + disabled_payload = deepcopy(VALID_CONTROL_PAYLOAD) + disabled_payload["enabled"] = False + disabled_payload["tags"] = ["pci"] + disabled_id, _ = _create_control(client, name="disabled-control", data=disabled_payload) + + other_payload = deepcopy(VALID_CONTROL_PAYLOAD) + other_payload["tags"] = ["other"] + other_id, _ = _create_control(client, name="other-control", data=other_payload) + + for control_id in (enabled_id, disabled_id, other_id): + response = client.post(f"/api/v1/control-stores/default/controls/{control_id}") + assert response.status_code == 200, response.text + + tag_filtered = client.get("/api/v1/control-stores/default/controls", params={"tag": "pci"}) + assert tag_filtered.status_code == 200, tag_filtered.text + assert {item["id"] for item in tag_filtered.json()["controls"]} == {enabled_id, disabled_id} + + disabled_filtered = client.get( + "/api/v1/control-stores/default/controls", + params={"enabled": "false"}, + ) + assert disabled_filtered.status_code == 200, disabled_filtered.text + assert [item["id"] for item in disabled_filtered.json()["controls"]] == [disabled_id] + + +def test_list_published_controls_cursor_survives_unpublished_cursor_row( + client: TestClient, +) -> None: + _ensure_default_store() + first_id, _ = _create_control(client, name="first-control") + second_id, _ = _create_control(client, name="second-control") + + for control_id in (first_id, second_id): + response = client.post(f"/api/v1/control-stores/default/controls/{control_id}") + assert response.status_code == 200, response.text + + _set_published_at(first_id, dt.datetime(2026, 4, 15, 10, 0, tzinfo=dt.UTC)) + _set_published_at(second_id, dt.datetime(2026, 4, 15, 11, 0, tzinfo=dt.UTC)) + + first_page = client.get("/api/v1/control-stores/default/controls", params={"limit": 1}) + assert first_page.status_code == 200, first_page.text + next_cursor = first_page.json()["pagination"]["next_cursor"] + + unpublish_response = client.delete(f"/api/v1/control-stores/default/controls/{second_id}") + assert unpublish_response.status_code == 200, unpublish_response.text + + next_page = client.get( + "/api/v1/control-stores/default/controls", + params={"limit": 1, "cursor": next_cursor}, + ) + assert next_page.status_code == 200, next_page.text + assert [item["id"] for item in next_page.json()["controls"]] == [first_id] + + +def test_list_published_controls_cursor_survives_unpublish_and_republish( + client: TestClient, +) -> None: + _ensure_default_store() + alpha_id, _ = _create_control(client, name="AlphaControl") + beta_id, _ = _create_control(client, name="BetaDetector") + gamma_id, _ = _create_control(client, name="GammaControl") + + for control_id in (alpha_id, beta_id, gamma_id): + response = client.post(f"/api/v1/control-stores/default/controls/{control_id}") + assert response.status_code == 200, response.text + + _set_published_at(alpha_id, dt.datetime(2026, 4, 15, 10, 0, tzinfo=dt.UTC)) + _set_published_at(beta_id, dt.datetime(2026, 4, 15, 11, 0, tzinfo=dt.UTC)) + _set_published_at(gamma_id, dt.datetime(2026, 4, 15, 12, 0, tzinfo=dt.UTC)) + + first_page = client.get("/api/v1/control-stores/default/controls", params={"limit": 2}) + assert first_page.status_code == 200, first_page.text + first_body = first_page.json() + assert [item["id"] for item in first_body["controls"]] == [gamma_id, beta_id] + + unpublish_response = client.delete(f"/api/v1/control-stores/default/controls/{beta_id}") + assert unpublish_response.status_code == 200, unpublish_response.text + republish_response = client.post(f"/api/v1/control-stores/default/controls/{beta_id}") + assert republish_response.status_code == 200, republish_response.text + _set_published_at(beta_id, dt.datetime(2026, 4, 15, 13, 0, tzinfo=dt.UTC)) + + next_page = client.get( + "/api/v1/control-stores/default/controls", + params={"limit": 2, "cursor": first_body["pagination"]["next_cursor"]}, + ) + assert next_page.status_code == 200, next_page.text + assert [item["id"] for item in next_page.json()["controls"]] == [alpha_id] + + +def test_list_published_controls_rejects_malformed_cursor(client: TestClient) -> None: + _ensure_default_store() + control_id, _ = _create_control(client, name="cursor-target") + response = client.post(f"/api/v1/control-stores/default/controls/{control_id}") + assert response.status_code == 200, response.text + + malformed_page = client.get( + "/api/v1/control-stores/default/controls", + params={"limit": 1, "cursor": "not-a-valid-cursor"}, + ) + assert malformed_page.status_code == 422 + assert malformed_page.json()["error_code"] == "VALIDATION_ERROR" + + +def test_list_published_controls_uses_control_id_tie_breaker_for_equal_timestamps( + client: TestClient, +) -> None: + # Given: three published controls with the exact same publication timestamp + _ensure_default_store() + control_ids = [ + _create_control(client, name=f"tie-break-{index}")[0] + for index in range(3) + ] + for control_id in control_ids: + response = client.post(f"/api/v1/control-stores/default/controls/{control_id}") + assert response.status_code == 200, response.text + _set_published_at(control_id, dt.datetime(2026, 4, 15, 12, 0, tzinfo=dt.UTC)) + + expected_order = sorted(control_ids, reverse=True) + + # When: requesting the first page + first_page = client.get("/api/v1/control-stores/default/controls", params={"limit": 2}) + + # Then: equal timestamps fall back to control_id descending order + assert first_page.status_code == 200, first_page.text + first_body = first_page.json() + assert [item["id"] for item in first_body["controls"]] == expected_order[:2] + assert first_body["pagination"]["next_cursor"] is not None + + # When: requesting the next page from that cursor + second_page = client.get( + "/api/v1/control-stores/default/controls", + params={"limit": 2, "cursor": first_body["pagination"]["next_cursor"]}, + ) + + # Then: the remaining lower-id control is returned + assert second_page.status_code == 200, second_page.text + assert [item["id"] for item in second_page.json()["controls"]] == expected_order[2:] + + +@pytest.mark.asyncio +@pytest.mark.skipif( + engine.dialect.name != "postgresql", + reason="Control-store concurrency coverage requires PostgreSQL row locking semantics", +) +async def test_publish_waits_for_policy_association_and_preserves_catalog_invariant() -> None: + control_id, _ = await _create_versioned_control() + policy_id = await _create_policy_row() + + association_has_lock = asyncio.Event() + publish_started = asyncio.Event() + release_association = asyncio.Event() + + async def associate_control() -> None: + async with AsyncSessionTest() as session: + service = ControlService(session) + await service.get_active_control_or_404(control_id, for_update=True) + assert not await service.is_control_published(control_id) + association_has_lock.set() + await release_association.wait() + await service.add_control_to_policy(policy_id=policy_id, control_id=control_id) + await session.commit() + + async def publish_control() -> None: + async with AsyncSessionTest() as session: + service = ControlService(session) + publish_started.set() + await service.get_active_control_or_404(control_id, for_update=True) + associations = await service.list_control_associations(control_id) + if not associations.policy_ids and not associations.agent_names: + await service.publish_control(control_id) + await session.commit() + + association_task = asyncio.create_task(associate_control()) + await association_has_lock.wait() + publish_task = asyncio.create_task(publish_control()) + await publish_started.wait() + release_association.set() + await asyncio.gather(association_task, publish_task) + + with Session(engine) as session: + policy_link = session.execute( + select(policy_controls.c.control_id).where( + policy_controls.c.policy_id == policy_id, + policy_controls.c.control_id == control_id, + ) + ).first() + + assert policy_link is not None + assert _published_rows(control_id) == [] + + +@pytest.mark.asyncio +@pytest.mark.skipif( + engine.dialect.name != "postgresql", + reason="Control-store concurrency coverage requires PostgreSQL row locking semantics", +) +async def test_policy_association_waits_for_publish_and_preserves_catalog_invariant() -> None: + control_id, _ = await _create_versioned_control() + policy_id = await _create_policy_row() + + publish_has_lock = asyncio.Event() + association_started = asyncio.Event() + release_publish = asyncio.Event() + + async def publish_control() -> None: + async with AsyncSessionTest() as session: + service = ControlService(session) + await service.get_active_control_or_404(control_id, for_update=True) + associations = await service.list_control_associations(control_id) + assert associations.policy_ids == [] + publish_has_lock.set() + await release_publish.wait() + await service.publish_control(control_id) + await session.commit() + + async def associate_control() -> None: + async with AsyncSessionTest() as session: + service = ControlService(session) + association_started.set() + await service.get_active_control_or_404(control_id, for_update=True) + if not await service.is_control_published(control_id): + await service.add_control_to_policy(policy_id=policy_id, control_id=control_id) + await session.commit() + + publish_task = asyncio.create_task(publish_control()) + await publish_has_lock.wait() + association_task = asyncio.create_task(associate_control()) + await association_started.wait() + release_publish.set() + await asyncio.gather(publish_task, association_task) + + with Session(engine) as session: + policy_link = session.execute( + select(policy_controls.c.control_id).where( + policy_controls.c.policy_id == policy_id, + policy_controls.c.control_id == control_id, + ) + ).first() + + assert policy_link is None + assert len(_published_rows(control_id)) == 1 + + +@pytest.mark.asyncio +@pytest.mark.skipif( + engine.dialect.name != "postgresql", + reason="Control-store concurrency coverage requires PostgreSQL row locking semantics", +) +async def test_unpublish_waits_for_publish_and_applies_afterward() -> None: + # Given: a publish request that acquires the control-row lock first + control_id, _ = await _create_versioned_control() + publish_has_lock = asyncio.Event() + unpublish_started = asyncio.Event() + release_publish = asyncio.Event() + + async def publish_control() -> None: + async with AsyncSessionTest() as session: + service = ControlService(session) + await service.get_active_control_or_404(control_id, for_update=True) + publish_has_lock.set() + await release_publish.wait() + await service.publish_control(control_id) + await session.commit() + + async def unpublish_control() -> None: + async with AsyncSessionTest() as session: + service = ControlService(session) + unpublish_started.set() + await service.get_active_control_or_404(control_id, for_update=True) + await service.unpublish_control(control_id) + await session.commit() + + # When: unpublish starts while publish still holds the lock + publish_task = asyncio.create_task(publish_control()) + await publish_has_lock.wait() + unpublish_task = asyncio.create_task(unpublish_control()) + await unpublish_started.wait() + release_publish.set() + await asyncio.gather(publish_task, unpublish_task) + + # Then: the later unpublish takes effect after publish commits + assert _published_rows(control_id) == [] + + +def test_publish_control_rejects_corrupted_stored_data(client: TestClient) -> None: + # Given: a control row whose stored JSON is not a valid control definition + _ensure_default_store() + control_id, _ = _insert_raw_control(data={"description": "broken"}) + + # When: publishing the corrupted control + response = client.post(f"/api/v1/control-stores/default/controls/{control_id}") + + # Then: the endpoint reports corrupted stored data rather than publishing it + assert response.status_code == 422 + assert response.json()["error_code"] == "CORRUPTED_DATA" + + +def test_default_store_seed_keeps_create_all_identity_usable() -> None: + # Given: the session-level test schema has already seeded the default store + _ensure_default_store() + + with Session(engine) as session: + # When: inserting another control store directly through the ORM + store = ControlStore(name="secondary-store") + session.add(store) + session.commit() + + # Then: the identity sequence advances past the seeded default row + assert store.id is not None + assert store.id > 1 + + +def test_clone_control_creates_independent_control_with_provenance_and_version( + client: TestClient, +) -> None: + source_id, source_name = _create_control(client) + + clone_response = client.post( + f"/api/v1/controls/{source_id}/clone", + json={"name": f"{source_name}-clone"}, + ) + assert clone_response.status_code == 200, clone_response.text + clone_id = int(clone_response.json()["control_id"]) + + clone = _fetch_control(clone_id) + source = _fetch_control(source_id) + assert clone is not None + assert source is not None + assert clone.cloned_control_id == source_id + assert clone.data == source.data + + updated_payload = deepcopy(VALID_CONTROL_PAYLOAD) + updated_payload["description"] = "Clone only" + update_response = client.put( + f"/api/v1/controls/{clone_id}/data", + json={"data": updated_payload}, + ) + assert update_response.status_code == 200, update_response.text + + refreshed_clone = _fetch_control(clone_id) + refreshed_source = _fetch_control(source_id) + assert refreshed_clone is not None + assert refreshed_source is not None + assert refreshed_clone.data["description"] == "Clone only" + assert refreshed_source.data["description"] == VALID_CONTROL_PAYLOAD["description"] + + versions = _fetch_versions(clone_id) + assert [version.version_num for version in versions] == [1, 2] + assert versions[0].event_type == "cloned" + assert versions[0].snapshot["cloned_control_id"] == source_id + assert versions[0].note == f"Cloned from '{source_name}' (id:{source_id}) at version 1" + + +def test_clone_control_without_name_generates_unique_copy_name(client: TestClient) -> None: + source_id, source_name = _create_control(client, name="PII-Detection") + _create_control(client, name=f"{source_name}-copy") + + clone_response = client.post(f"/api/v1/controls/{source_id}/clone") + assert clone_response.status_code == 200, clone_response.text + assert clone_response.json()["name"] == f"{source_name}-copy-2" + + +def test_clone_control_without_name_truncates_generated_copy_name(client: TestClient) -> None: + # Given: a control name already at the maximum allowed length + source_name = "x" * 255 + source_id, _ = _create_control(client, name=source_name) + + # When: cloning without an explicit name + clone_response = client.post(f"/api/v1/controls/{source_id}/clone") + + # Then: the generated copy name stays within the control name limit + assert clone_response.status_code == 200, clone_response.text + cloned_name = clone_response.json()["name"] + assert len(cloned_name) == 255 + assert cloned_name.endswith("-copy") + assert cloned_name == f"{source_name[:250]}-copy" + + +def test_clone_control_rejects_requested_name_conflict(client: TestClient) -> None: + # Given: a source control and a different active control already using the requested clone name + source_id, _ = _create_control(client, name="source-control") + _, existing_name = _create_control(client, name="existing-clone-name") + + # When: cloning into that existing name + clone_response = client.post( + f"/api/v1/controls/{source_id}/clone", + json={"name": existing_name}, + ) + + # Then: the API reports a control-name conflict + assert clone_response.status_code == 409 + assert clone_response.json()["error_code"] == "CONTROL_NAME_CONFLICT" + + +def test_clone_control_rejects_corrupted_source_data(client: TestClient) -> None: + # Given: a source control whose stored JSON is corrupted + source_id, _ = _insert_raw_control(data={"description": "broken"}) + + # When: cloning that control + clone_response = client.post(f"/api/v1/controls/{source_id}/clone") + + # Then: the endpoint rejects the corrupted source row + assert clone_response.status_code == 422 + assert clone_response.json()["error_code"] == "CORRUPTED_DATA" + + +def test_clone_control_records_latest_source_version_in_provenance_note( + client: TestClient, +) -> None: + # Given: a source control with multiple recorded versions + source_id, source_name = _create_control(client) + updated_payload = deepcopy(VALID_CONTROL_PAYLOAD) + updated_payload["description"] = "Version two" + update_response = client.put( + f"/api/v1/controls/{source_id}/data", + json={"data": updated_payload}, + ) + assert update_response.status_code == 200, update_response.text + patch_response = client.patch( + f"/api/v1/controls/{source_id}", + json={"enabled": False}, + ) + assert patch_response.status_code == 200, patch_response.text + + # When: cloning the latest source state + clone_response = client.post(f"/api/v1/controls/{source_id}/clone") + + # Then: the clone provenance points at the latest source version number + assert clone_response.status_code == 200, clone_response.text + clone_id = int(clone_response.json()["control_id"]) + clone_versions = _fetch_versions(clone_id) + assert clone_versions[0].event_type == "cloned" + assert clone_versions[0].note == f"Cloned from '{source_name}' (id:{source_id}) at version 3" + + +def test_clone_control_preserves_unrendered_template_shape(client: TestClient) -> None: + source_id, source_name = _create_control( + client, + name="template-control", + data=_unrendered_template_payload(), + ) + + clone_response = client.post( + f"/api/v1/controls/{source_id}/clone", + json={"name": f"{source_name}-copy"}, + ) + assert clone_response.status_code == 200, clone_response.text + clone_id = int(clone_response.json()["control_id"]) + + get_response = client.get(f"/api/v1/controls/{clone_id}/data") + assert get_response.status_code == 200, get_response.text + data = get_response.json()["data"] + + assert data["enabled"] is False + assert "template" in data + assert "condition" not in data + assert _fetch_control(clone_id) is not None + assert _fetch_control(clone_id).cloned_control_id == source_id # type: ignore[union-attr] + + +def test_clone_control_rejects_deleted_source(client: TestClient) -> None: + source_id, _ = _create_control(client) + delete_response = client.delete(f"/api/v1/controls/{source_id}") + assert delete_response.status_code == 200, delete_response.text + + clone_response = client.post(f"/api/v1/controls/{source_id}/clone") + assert clone_response.status_code == 404 + assert clone_response.json()["error_code"] == "CONTROL_NOT_FOUND" + + +def test_add_agent_control_rejects_published_control(client: TestClient) -> None: + _ensure_default_store() + control_id, _ = _create_control(client) + publish_response = client.post(f"/api/v1/control-stores/default/controls/{control_id}") + assert publish_response.status_code == 200, publish_response.text + + agent_name = _create_agent(client) + response = client.post(f"/api/v1/agents/{agent_name}/controls/{control_id}") + + assert response.status_code == 409 + assert response.json()["error_code"] == "CONTROL_PUBLISHED" + + +def test_add_control_to_policy_rejects_published_control(client: TestClient) -> None: + _ensure_default_store() + control_id, _ = _create_control(client) + publish_response = client.post(f"/api/v1/control-stores/default/controls/{control_id}") + assert publish_response.status_code == 200, publish_response.text + + policy_id = _create_policy(client) + response = client.post(f"/api/v1/policies/{policy_id}/controls/{control_id}") + + assert response.status_code == 409 + assert response.json()["error_code"] == "CONTROL_PUBLISHED" + + +def test_delete_control_removes_store_publication_rows(client: TestClient) -> None: + _ensure_default_store() + control_id, _ = _create_control(client) + publish_response = client.post(f"/api/v1/control-stores/default/controls/{control_id}") + assert publish_response.status_code == 200, publish_response.text + + delete_response = client.delete(f"/api/v1/controls/{control_id}") + + assert delete_response.status_code == 200, delete_response.text + assert _published_rows(control_id) == [] + + +def test_delete_control_with_runtime_association_still_requires_force_even_if_published( + client: TestClient, +) -> None: + store_id = _ensure_default_store() + control_id, _ = _create_control(client) + policy_id = _create_policy(client) + assoc_response = client.post(f"/api/v1/policies/{policy_id}/controls/{control_id}") + assert assoc_response.status_code == 200, assoc_response.text + + with Session(engine) as session: + session.execute( + insert(control_stores_controls).values(store_id=store_id, control_id=control_id) + ) + session.commit() + + blocked = client.delete(f"/api/v1/controls/{control_id}") + assert blocked.status_code == 409 + assert blocked.json()["error_code"] == "CONTROL_IN_USE" + assert _published_rows(control_id) != [] + + forced = client.delete(f"/api/v1/controls/{control_id}?force=true") + assert forced.status_code == 200, forced.text + assert _published_rows(control_id) == [] + + +def test_browse_published_controls_marks_unrendered_templates(client: TestClient) -> None: + _ensure_default_store() + control_id, _ = _create_control( + client, + name="template-published", + data=_unrendered_template_payload(), + ) + publish_response = client.post(f"/api/v1/control-stores/default/controls/{control_id}") + assert publish_response.status_code == 200, publish_response.text + + response = client.get("/api/v1/control-stores/default/controls") + assert response.status_code == 200, response.text + summary = response.json()["controls"][0] + + assert summary["id"] == control_id + assert summary["template_backed"] is True + assert summary["template_rendered"] is False + assert summary["enabled"] is False + + +def test_publish_control_database_error_returns_500( + app: FastAPI, + client: TestClient, +) -> None: + # Given: a valid control lookup path whose commit fails during publish + control = Control( + id=123, + name="publish-db-error", + data=deepcopy(VALID_CONTROL_PAYLOAD), + deleted_at=None, + ) + control_result = MagicMock() + control_result.scalars.return_value.first.return_value = control + associations_result = MagicMock() + associations_result.all.return_value = [] + store_result = MagicMock() + store_result.scalar_one_or_none.return_value = 1 + publish_result = MagicMock() + mock_session = AsyncMock(spec=AsyncSession) + mock_session.execute = AsyncMock( + side_effect=[ + control_result, + associations_result, + store_result, + publish_result, + ] + ) + mock_session.commit.side_effect = Exception("Database error") + mock_session.rollback = AsyncMock() + + async def mock_db() -> AsyncGenerator[AsyncSession, None]: + yield mock_session + + app.dependency_overrides[get_async_db] = mock_db + try: + # When: publishing the control + response = client.post("/api/v1/control-stores/default/controls/123") + finally: + app.dependency_overrides.clear() + + # Then: the endpoint rolls back and reports a database error + assert response.status_code == 500 + assert response.json()["error_code"] == "DATABASE_ERROR" + assert mock_session.rollback.await_count == 1 + lock_stmt = mock_session.execute.await_args_list[0].args[0] + assert getattr(lock_stmt, "_for_update_arg", None) is not None + + +def test_unpublish_control_database_error_returns_500( + app: FastAPI, + client: TestClient, +) -> None: + # Given: a valid unpublish path whose commit fails after removing the publication row + control = Control( + id=123, + name="unpublish-db-error", + data=deepcopy(VALID_CONTROL_PAYLOAD), + deleted_at=None, + ) + control_result = MagicMock() + control_result.scalars.return_value.first.return_value = control + store_result = MagicMock() + store_result.scalar_one_or_none.return_value = 1 + delete_result = MagicMock() + mock_session = AsyncMock(spec=AsyncSession) + mock_session.execute = AsyncMock( + side_effect=[ + control_result, + store_result, + delete_result, + ] + ) + mock_session.commit.side_effect = Exception("Database error") + mock_session.rollback = AsyncMock() + + async def mock_db() -> AsyncGenerator[AsyncSession, None]: + yield mock_session + + app.dependency_overrides[get_async_db] = mock_db + try: + # When: unpublishing the control + response = client.delete("/api/v1/control-stores/default/controls/123") + finally: + app.dependency_overrides.clear() + + # Then: the endpoint rolls back and reports a database error + assert response.status_code == 500 + assert response.json()["error_code"] == "DATABASE_ERROR" + assert mock_session.rollback.await_count == 1 + + +def test_clone_control_integrity_name_conflict_returns_conflict( + app: FastAPI, + client: TestClient, +) -> None: + # Given: a clone request that races with another writer on the target control name + source_control = Control( + id=123, + name="clone-source", + data=deepcopy(VALID_CONTROL_PAYLOAD), + deleted_at=None, + ) + control_result = MagicMock() + control_result.scalars.return_value.first.return_value = source_control + name_lookup_result = MagicMock() + name_lookup_result.first.return_value = None + source_version_result = MagicMock() + source_version_result.scalar_one.return_value = 1 + insert_result = MagicMock() + insert_result.scalar_one.return_value = 201 + lock_result = MagicMock() + clone_version_result = MagicMock() + clone_version_result.scalar_one.return_value = 1 + mock_session = AsyncMock(spec=AsyncSession) + mock_session.execute = AsyncMock( + side_effect=[ + control_result, + name_lookup_result, + source_version_result, + insert_result, + lock_result, + clone_version_result, + ] + ) + mock_session.flush = AsyncMock() + mock_session.add = MagicMock() + mock_session.commit.side_effect = _make_integrity_error("idx_controls_name_active") + mock_session.rollback = AsyncMock() + + async def mock_db() -> AsyncGenerator[AsyncSession, None]: + yield mock_session + + app.dependency_overrides[get_async_db] = mock_db + try: + # When: cloning into a name claimed concurrently by another transaction + response = client.post( + "/api/v1/controls/123/clone", + json={"name": "race-target"}, + ) + finally: + app.dependency_overrides.clear() + + # Then: the endpoint maps the integrity failure to a control-name conflict + assert response.status_code == 409 + assert response.json()["error_code"] == "CONTROL_NAME_CONFLICT" + assert mock_session.rollback.await_count == 1 + + +def test_clone_control_without_name_retries_generated_name_conflict( + app: FastAPI, + client: TestClient, +) -> None: + # Given: an auto-generated clone name that collides on the first commit attempt + source_control = Control( + id=123, + name="source-control", + data=deepcopy(VALID_CONTROL_PAYLOAD), + deleted_at=None, + ) + control_result = MagicMock() + control_result.scalars.return_value.first.return_value = source_control + first_name_lookup_result = MagicMock() + first_name_lookup_result.first.return_value = None + first_source_version_result = MagicMock() + first_source_version_result.scalar_one.return_value = 3 + first_insert_result = MagicMock() + first_insert_result.scalar_one.return_value = 201 + first_lock_result = MagicMock() + first_clone_version_result = MagicMock() + first_clone_version_result.scalar_one.return_value = 1 + retry_control_result = MagicMock() + retry_control_result.scalars.return_value.first.return_value = source_control + retry_first_name_lookup_result = MagicMock() + retry_first_name_lookup_result.first.return_value = object() + retry_second_name_lookup_result = MagicMock() + retry_second_name_lookup_result.first.return_value = None + retry_source_version_result = MagicMock() + retry_source_version_result.scalar_one.return_value = 3 + retry_insert_result = MagicMock() + retry_insert_result.scalar_one.return_value = 202 + retry_lock_result = MagicMock() + retry_clone_version_result = MagicMock() + retry_clone_version_result.scalar_one.return_value = 1 + + mock_session = AsyncMock(spec=AsyncSession) + mock_session.execute = AsyncMock( + side_effect=[ + control_result, + first_name_lookup_result, + first_source_version_result, + first_insert_result, + first_lock_result, + first_clone_version_result, + retry_control_result, + retry_first_name_lookup_result, + retry_second_name_lookup_result, + retry_source_version_result, + retry_insert_result, + retry_lock_result, + retry_clone_version_result, + ] + ) + mock_session.add = MagicMock() + mock_session.flush = AsyncMock() + mock_session.commit.side_effect = [ + _make_integrity_error("idx_controls_name_active"), + None, + ] + mock_session.rollback = AsyncMock() + + async def mock_db() -> AsyncGenerator[AsyncSession, None]: + yield mock_session + + app.dependency_overrides[get_async_db] = mock_db + try: + # When: cloning without an explicit name + response = client.post("/api/v1/controls/123/clone") + finally: + app.dependency_overrides.clear() + + # Then: the endpoint retries and returns the next unique generated name + assert response.status_code == 200, response.text + assert response.json()["name"] == "source-control-copy-2" + assert response.json()["control_id"] == 202 + assert mock_session.rollback.await_count == 1 + + +def test_clone_control_non_name_integrity_error_returns_500( + app: FastAPI, + client: TestClient, +) -> None: + # Given: a clone request that hits a non-name integrity failure during commit + source_control = Control( + id=123, + name="clone-source", + data=deepcopy(VALID_CONTROL_PAYLOAD), + deleted_at=None, + ) + control_result = MagicMock() + control_result.scalars.return_value.first.return_value = source_control + name_lookup_result = MagicMock() + name_lookup_result.first.return_value = None + source_version_result = MagicMock() + source_version_result.scalar_one.return_value = 1 + insert_result = MagicMock() + insert_result.scalar_one.return_value = 201 + lock_result = MagicMock() + clone_version_result = MagicMock() + clone_version_result.scalar_one.return_value = 1 + mock_session = AsyncMock(spec=AsyncSession) + mock_session.execute = AsyncMock( + side_effect=[ + control_result, + name_lookup_result, + source_version_result, + insert_result, + lock_result, + clone_version_result, + ] + ) + mock_session.flush = AsyncMock() + mock_session.add = MagicMock() + mock_session.commit.side_effect = _make_integrity_error( + "uq_control_versions_control_version" + ) + mock_session.rollback = AsyncMock() + + async def mock_db() -> AsyncGenerator[AsyncSession, None]: + yield mock_session + + app.dependency_overrides[get_async_db] = mock_db + try: + # When: cloning the control + response = client.post("/api/v1/controls/123/clone", json={"name": "clone-target"}) + finally: + app.dependency_overrides.clear() + + # Then: the endpoint reports a database error instead of a name conflict + assert response.status_code == 500 + assert response.json()["error_code"] == "DATABASE_ERROR" + assert mock_session.rollback.await_count == 1 diff --git a/server/tests/test_control_versions.py b/server/tests/test_control_versions.py index f387a1f6..f8c255bc 100644 --- a/server/tests/test_control_versions.py +++ b/server/tests/test_control_versions.py @@ -149,6 +149,37 @@ def test_delete_control_force_creates_deleted_version_row(client: TestClient) -> assert latest.snapshot["deleted_at"] is not None +def test_cloned_control_versions_preserve_clone_provenance_after_update_and_delete( + client: TestClient, +) -> None: + # Given: a cloned control with source provenance + source_id, source_name = _create_control(client) + clone_response = client.post( + f"/api/v1/controls/{source_id}/clone", + json={"name": f"{source_name}-clone"}, + ) + assert clone_response.status_code == 200, clone_response.text + clone_id = int(clone_response.json()["control_id"]) + + updated_payload = deepcopy(VALID_CONTROL_PAYLOAD) + updated_payload["description"] = "Clone-only update" + + # When: updating and then deleting the cloned control + update_response = client.put( + f"/api/v1/controls/{clone_id}/data", + json={"data": updated_payload}, + ) + assert update_response.status_code == 200, update_response.text + delete_response = client.delete(f"/api/v1/controls/{clone_id}") + assert delete_response.status_code == 200, delete_response.text + + # Then: every later version keeps the original clone provenance + versions = _fetch_versions(clone_id) + assert [version.version_num for version in versions] == [1, 2, 3] + assert [version.event_type for version in versions] == ["cloned", "updated", "deleted"] + assert all(version.snapshot["cloned_control_id"] == source_id for version in versions) + + def test_list_control_versions_paginates_newest_first_without_snapshot( client: TestClient, ) -> None: diff --git a/server/tests/test_controls_additional.py b/server/tests/test_controls_additional.py index 361c7fd1..323cda60 100644 --- a/server/tests/test_controls_additional.py +++ b/server/tests/test_controls_additional.py @@ -8,19 +8,18 @@ from unittest.mock import AsyncMock, MagicMock import pytest +from agent_control_evaluators import RegexEvaluatorConfig +from agent_control_models import ConditionNode from fastapi.testclient import TestClient from sqlalchemy import text from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Session -from agent_control_models import ConditionNode from agent_control_server.db import get_async_db -from agent_control_server.models import Control - -from agent_control_evaluators import RegexEvaluatorConfig from agent_control_server.endpoints import controls as controls_module from agent_control_server.main import app +from agent_control_server.models import Control from .conftest import engine from .utils import VALID_CONTROL_PAYLOAD @@ -96,6 +95,7 @@ def test_patch_control_rename_integrity_error_returns_conflict(client: TestClien name="old-control", data=deepcopy(VALID_CONTROL_PAYLOAD), deleted_at=None, + cloned_control_id=None, ) async def mock_db_integrity_error() -> AsyncGenerator[AsyncSession, None]: @@ -141,6 +141,7 @@ def test_patch_control_non_name_integrity_error_returns_500(client: TestClient) name="old-control", data=deepcopy(VALID_CONTROL_PAYLOAD), deleted_at=None, + cloned_control_id=None, ) async def mock_db_integrity_error() -> AsyncGenerator[AsyncSession, None]: @@ -314,7 +315,7 @@ def test_patch_control_rename_with_spaces_rejected(client: TestClient) -> None: def test_create_control_trimmed_name_stored(client: TestClient) -> None: - """Control names are canonicalized at the API boundary: leading/trailing whitespace is trimmed.""" + """Control names are canonicalized: leading and trailing whitespace is trimmed.""" resp = client.put( "/api/v1/controls", json={"name": " trimmed-control ", "data": VALID_CONTROL_PAYLOAD}, @@ -327,7 +328,7 @@ def test_create_control_trimmed_name_stored(client: TestClient) -> None: def test_patch_control_trimmed_name_stored(client: TestClient) -> None: - """PATCH control name is canonicalized at the API boundary: leading/trailing whitespace is trimmed.""" + """PATCH control names are canonicalized before persistence.""" control_id, _ = _create_control(client) resp = client.patch( f"/api/v1/controls/{control_id}", @@ -462,7 +463,10 @@ def test_list_controls_enabled_true_includes_missing_enabled(client: TestClient) # Given: controls with enabled true, enabled false, and missing enabled control_true_id, control_true_name = _create_control(client, name=f"Enabled-{uuid.uuid4()}") control_false_id, control_false_name = _create_control(client, name=f"Disabled-{uuid.uuid4()}") - control_missing_id, control_missing_name = _create_control(client, name=f"Missing-{uuid.uuid4()}") + control_missing_id, control_missing_name = _create_control( + client, + name=f"Missing-{uuid.uuid4()}", + ) data_true = deepcopy(VALID_CONTROL_PAYLOAD) data_true["enabled"] = True @@ -693,7 +697,10 @@ def test_create_control_allows_reusing_soft_deleted_name(client: TestClient) -> assert delete_resp.status_code == 200 # When: creating a new control with the same name - recreate_resp = client.put("/api/v1/controls", json={"name": name, "data": VALID_CONTROL_PAYLOAD}) + recreate_resp = client.put( + "/api/v1/controls", + json={"name": name, "data": VALID_CONTROL_PAYLOAD}, + ) # Then: creation succeeds because uniqueness only applies to active rows assert recreate_resp.status_code == 200, recreate_resp.text @@ -789,11 +796,10 @@ def test_set_control_data_agent_scoped_agent_not_found(client: TestClient) -> No def test_set_control_data_agent_scoped_evaluator_missing(client: TestClient) -> None: # Given: an agent without the referenced evaluator agent_name = f"agent-{uuid.uuid4().hex[:12]}" - agent_name = agent_name resp = client.post( "/api/v1/agents/initAgent", json={ - "agent": {"agent_name": agent_name, "agent_name": agent_name}, + "agent": {"agent_name": agent_name}, "steps": [], "evaluators": [], }, @@ -802,7 +808,10 @@ def test_set_control_data_agent_scoped_evaluator_missing(client: TestClient) -> control_id, _ = _create_control(client) payload = deepcopy(VALID_CONTROL_PAYLOAD) - payload["condition"]["evaluator"] = {"name": f"{agent_name}:missing", "config": {"pattern": "x"}} + payload["condition"]["evaluator"] = { + "name": f"{agent_name}:missing", + "config": {"pattern": "x"}, + } # When: setting data with evaluator not registered on agent resp = client.put(f"/api/v1/controls/{control_id}/data", json={"data": payload}) @@ -811,17 +820,19 @@ def test_set_control_data_agent_scoped_evaluator_missing(client: TestClient) -> assert resp.status_code == 422 body = resp.json() assert body["error_code"] == "EVALUATOR_NOT_FOUND" - assert any(err.get("field") == "data.condition.evaluator.name" for err in body.get("errors", [])) + assert any( + err.get("field") == "data.condition.evaluator.name" + for err in body.get("errors", []) + ) def test_set_control_data_agent_scoped_invalid_schema(client: TestClient) -> None: # Given: an agent with evaluator schema requiring "pattern" agent_name = f"agent-{uuid.uuid4().hex[:12]}" - agent_name = agent_name resp = client.post( "/api/v1/agents/initAgent", json={ - "agent": {"agent_name": agent_name, "agent_name": agent_name}, + "agent": {"agent_name": agent_name}, "steps": [], "evaluators": [ { @@ -849,7 +860,10 @@ def test_set_control_data_agent_scoped_invalid_schema(client: TestClient) -> Non assert resp.status_code == 422 body = resp.json() assert body["error_code"] == "INVALID_CONFIG" - assert any(err.get("field") == "data.condition.evaluator.config" for err in body.get("errors", [])) + assert any( + err.get("field") == "data.condition.evaluator.config" + for err in body.get("errors", []) + ) def test_patch_control_updates_name_and_enabled(client: TestClient) -> None: @@ -907,11 +921,10 @@ def test_set_control_data_agent_scoped_corrupted_agent_data_returns_422( ) -> None: # Given: an agent whose stored data is corrupted agent_name = f"agent-{uuid.uuid4().hex[:12]}" - agent_name = agent_name resp = client.post( "/api/v1/agents/initAgent", json={ - "agent": {"agent_name": agent_name, "agent_name": agent_name}, + "agent": {"agent_name": agent_name}, "steps": [], "evaluators": [{"name": "custom", "config_schema": {"type": "object"}}], }, diff --git a/server/tests/test_policies.py b/server/tests/test_policies.py index 623142b7..570e375a 100644 --- a/server/tests/test_policies.py +++ b/server/tests/test_policies.py @@ -53,10 +53,10 @@ def test_policy_add_control_and_list(client: TestClient) -> None: assert r.json()["success"] is True # When: listing policy controls - l = client.get(f"/api/v1/policies/{policy_id}/controls") + list_response = client.get(f"/api/v1/policies/{policy_id}/controls") # Then: the control id is included - assert l.status_code == 200 - assert control_id in l.json()["control_ids"] + assert list_response.status_code == 200 + assert control_id in list_response.json()["control_ids"] def test_policy_add_control_idempotent(client: TestClient) -> None: @@ -72,9 +72,9 @@ def test_policy_add_control_idempotent(client: TestClient) -> None: assert r.json()["success"] is True # Then: listing still shows it once (set semantics by ids) - l = client.get(f"/api/v1/policies/{policy_id}/controls") - assert l.status_code == 200 - ids = l.json()["control_ids"] + list_response = client.get(f"/api/v1/policies/{policy_id}/controls") + assert list_response.status_code == 200 + ids = list_response.json()["control_ids"] assert ids.count(control_id) == 1 @@ -91,10 +91,10 @@ def test_policy_remove_control(client: TestClient) -> None: assert d.json()["success"] is True # When: listing controls - l = client.get(f"/api/v1/policies/{policy_id}/controls") + list_response = client.get(f"/api/v1/policies/{policy_id}/controls") # Then: the control is not present - assert l.status_code == 200 - assert control_id not in l.json()["control_ids"] + assert list_response.status_code == 200 + assert control_id not in list_response.json()["control_ids"] def test_policy_remove_control_idempotent_when_not_associated(client: TestClient) -> None: @@ -210,11 +210,21 @@ def test_policy_add_control_db_error_returns_500( policy_result.scalars.return_value.first.return_value = policy control_result = MagicMock() control_result.scalars.return_value.first.return_value = control + store_result = MagicMock() + store_result.scalar_one_or_none.return_value = 1 + publication_result = MagicMock() + publication_result.first.return_value = None async def mock_db() -> AsyncGenerator[AsyncSession, None]: mock_session = AsyncMock(spec=AsyncSession) mock_session.execute = AsyncMock( - side_effect=[policy_result, control_result, MagicMock()] + side_effect=[ + policy_result, + control_result, + store_result, + publication_result, + MagicMock(), + ] ) mock_session.commit.side_effect = Exception("db error") mock_session.rollback = AsyncMock() diff --git a/server/tests/test_services_controls.py b/server/tests/test_services_controls.py index 705054f0..a8346fbe 100644 --- a/server/tests/test_services_controls.py +++ b/server/tests/test_services_controls.py @@ -9,6 +9,7 @@ import pytest from agent_control_models.errors import ErrorCode from sqlalchemy import insert, select +from sqlalchemy.exc import ProgrammingError from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Session @@ -104,6 +105,14 @@ def _fetch_all_versions() -> list[ControlVersion]: return list(session.scalars(select(ControlVersion)).all()) +def _missing_store_schema_error() -> ProgrammingError: + return ProgrammingError( + "SELECT * FROM control_stores", + {}, + Exception('relation "control_stores" does not exist'), + ) + + @pytest.mark.asyncio async def test_create_version_locks_control_row_before_allocating_version_number() -> None: # Given: a control service with a mocked session @@ -121,6 +130,7 @@ async def test_create_version_locks_control_row_before_allocating_version_number name=f"control-{uuid.uuid4()}", data=VALID_CONTROL_PAYLOAD, deleted_at=None, + cloned_control_id=None, ) # When: creating a new version row @@ -134,6 +144,29 @@ async def test_create_version_locks_control_row_before_allocating_version_number assert version.version_num == 4 +@pytest.mark.asyncio +async def test_is_control_published_returns_false_when_store_schema_is_absent() -> None: + # Given: a pre-migration database session where control-store tables do not exist yet + mock_session = AsyncMock(spec=AsyncSession) + mock_session.execute = AsyncMock(side_effect=[_missing_store_schema_error()]) + + # When: checking whether a control is published + is_published = await ControlService(mock_session).is_control_published(123) + + # Then: the compatibility guard treats it as unpublished + assert is_published is False + + +@pytest.mark.asyncio +async def test_remove_all_store_publications_noops_when_store_schema_is_absent() -> None: + # Given: a pre-migration database session where publication rows cannot exist yet + mock_session = AsyncMock(spec=AsyncSession) + mock_session.execute = AsyncMock(side_effect=[_missing_store_schema_error()]) + + # When/Then: removing publications does not raise + await ControlService(mock_session).remove_all_store_publications(123) + + @pytest.mark.asyncio async def test_create_control_transaction_rollback_does_not_persist_control_or_version() -> None: # Given: a new control plus its initial version inside an open transaction