diff --git a/migrations/versions/2025_03_05_2126-e4c05d7591a8_add_installation_table.py b/migrations/versions/2025_03_05_2126-e4c05d7591a8_add_installation_table.py index 775e3967..9e2b6c13 100644 --- a/migrations/versions/2025_03_05_2126-e4c05d7591a8_add_installation_table.py +++ b/migrations/versions/2025_03_05_2126-e4c05d7591a8_add_installation_table.py @@ -9,8 +9,6 @@ from typing import Sequence, Union from alembic import op -import sqlalchemy as sa - # revision identifiers, used by Alembic. revision: str = "e4c05d7591a8" diff --git a/src/codegate/cli.py b/src/codegate/cli.py index 1ae3f9c2..5c08821c 100644 --- a/src/codegate/cli.py +++ b/src/codegate/cli.py @@ -16,8 +16,8 @@ from codegate.config import Config, ConfigurationError from codegate.db.connection import ( init_db_sync, - init_session_if_not_exists, init_instance, + init_session_if_not_exists, ) from codegate.pipeline.factory import PipelineFactory from codegate.pipeline.sensitive_data.manager import SensitiveDataManager diff --git a/src/codegate/db/connection.py b/src/codegate/db/connection.py index 3f439aea..d417da55 100644 --- a/src/codegate/db/connection.py +++ b/src/codegate/db/connection.py @@ -617,7 +617,7 @@ async def init_instance(self) -> None: await self._execute_with_no_return(sql, instance.model_dump()) except IntegrityError as e: logger.debug(f"Exception type: {type(e)}") - raise AlreadyExistsError(f"Instance already initialized.") + raise AlreadyExistsError("Instance already initialized.") class DbReader(DbCodeGate): diff --git a/src/codegate/muxing/router.py b/src/codegate/muxing/router.py index bfa9c663..26002441 100644 --- a/src/codegate/muxing/router.py +++ b/src/codegate/muxing/router.py @@ -10,7 +10,7 @@ from codegate.muxing.adapter import BodyAdapter, ResponseAdapter from codegate.providers.fim_analyzer import FIMAnalyzer from codegate.providers.registry import ProviderRegistry -from codegate.workspaces.crud import WorkspaceCrud +from codegate.workspaces.crud import WorkspaceCrud, WorkspaceDoesNotExistError logger = structlog.get_logger("codegate") @@ -40,23 +40,47 @@ def _ensure_path_starts_with_slash(self, path: str) -> str: return path if path.startswith("/") else f"/{path}" async def _get_model_route( - self, thing_to_match: mux_models.ThingToMatchMux + self, thing_to_match: mux_models.ThingToMatchMux, workspace_name: Optional[str] = None ) -> Optional[rulematcher.ModelRoute]: """ Get the model route for the given things_to_match. + + If workspace_name is provided and exists, use that workspace. + Otherwise, use the active workspace. """ - mux_registry = await rulematcher.get_muxing_rules_registry() try: - # Try to get a model route for the active workspace - model_route = await mux_registry.get_match_for_active_workspace(thing_to_match) - return model_route + mux_registry = await rulematcher.get_muxing_rules_registry() + relevant_workspace = await self._get_relevant_workspace_name( + mux_registry, workspace_name + ) + return await mux_registry.get_match_for_workspace(relevant_workspace, thing_to_match) except rulematcher.MuxMatchingError as e: logger.exception(f"Error matching rule and getting model route: {e}") raise HTTPException(detail=str(e), status_code=404) except Exception as e: - logger.exception(f"Error getting active workspace muxes: {e}") + logger.exception(f"Error getting workspace muxes: {e}") raise HTTPException(detail=str(e), status_code=404) + async def _get_relevant_workspace_name( + self, mreg: rulematcher.MuxingRulesinWorkspaces, workspace_name: Optional[str] + ) -> str: + if not workspace_name: + # No workspace specified, use active workspace + return mreg.get_active_workspace() + + try: + # Verify the requested workspace exists + # TODO: We should have an in-memory cache of the workspaces + await self._ws_crud.get_workspace_by_name(workspace_name) + logger.debug(f"Using workspace from X-CodeGate-Workspace header: {workspace_name}") + return workspace_name + except WorkspaceDoesNotExistError: + # Workspace doesn't exist, fall back to active workspace + logger.warning( + f"Workspace {workspace_name} does not exist, falling back to active workspace" + ) + return mreg.get_active_workspace() + def _setup_routes(self): @self.router.post(f"/{self.route_name}/{{rest_of_path:path}}") @@ -68,7 +92,7 @@ async def route_to_dest_provider( """ Route the request to the correct destination provider. - 1. Get destination provider from DB and active workspace. + 1. Get destination provider from DB and workspace (from header or active). 2. Map the request body to the destination provider format. 3. Run pipeline. Selecting the correct destination provider. 4. Transmit the response back to the client in OpenAI format. @@ -78,14 +102,17 @@ async def route_to_dest_provider( data = json.loads(body) is_fim_request = FIMAnalyzer.is_fim_request(rest_of_path, data) - # 1. Get destination provider from DB and active workspace. + # Check if X-CodeGate-Workspace header is present + workspace_header = request.headers.get("X-CodeGate-Workspace") + + # 1. Get destination provider from DB and workspace (from header or active). thing_to_match = mux_models.ThingToMatchMux( body=data, url_request_path=rest_of_path, is_fim_request=is_fim_request, client_type=request.state.detected_client, ) - model_route = await self._get_model_route(thing_to_match) + model_route = await self._get_model_route(thing_to_match, workspace_header) if not model_route: raise HTTPException( detail="No matching rule found for the active workspace", status_code=404 diff --git a/src/codegate/muxing/rulematcher.py b/src/codegate/muxing/rulematcher.py index d41eb2ce..723e253d 100644 --- a/src/codegate/muxing/rulematcher.py +++ b/src/codegate/muxing/rulematcher.py @@ -25,7 +25,7 @@ class MuxMatchingError(Exception): pass -async def get_muxing_rules_registry(): +async def get_muxing_rules_registry() -> "MuxingRulesinWorkspaces": """Returns a singleton instance of the muxing rules registry.""" global _muxrules_sgtn @@ -199,23 +199,27 @@ async def set_active_workspace(self, workspace_name: str) -> None: """Set the active workspace.""" self._active_workspace = workspace_name + def get_active_workspace(self) -> str: + """Get the active workspace.""" + return self._active_workspace + async def get_registries(self) -> List[str]: """Get the list of workspaces.""" async with self._lock: return list(self._ws_rules.keys()) - async def get_match_for_active_workspace( - self, thing_to_match: mux_models.ThingToMatchMux + async def get_match_for_workspace( + self, workspace_name: str, thing_to_match: mux_models.ThingToMatchMux ) -> Optional[ModelRoute]: - """Get the first match for the given thing_to_match.""" + """Get the first match for the given thing_to_match in the specified workspace.""" # We iterate over all the rules and return the first match # Since we already do a deepcopy in __getitem__, we don't need to lock here try: - rules = await self.get_ws_rules(self._active_workspace) + rules = await self.get_ws_rules(workspace_name) for rule in rules: if rule.match(thing_to_match): return rule.destination() return None except KeyError: - raise RuntimeError("No rules found for the active workspace") + raise RuntimeError(f"No rules found for workspace {workspace_name}")