From 9c787b711ee6508c11e625cb60a3e2ab1fa29159 Mon Sep 17 00:00:00 2001 From: Jvst Me Date: Tue, 3 Mar 2026 00:07:58 +0100 Subject: [PATCH 1/4] Fleet sharing main mechanisms - DB schema for resource exports and imports - Submitting jobs to imported fleets - Viewing imported fleets and instances in API, CLI, UI - Filtering events by imported fleets and instances Currently testable through unit tests and through exports and imports manually created in the DB. --- frontend/src/pages/Fleets/List/hooks.tsx | 1 + .../pages/Instances/List/hooks/useFilters.ts | 1 + frontend/src/types/fleet.d.ts | 1 + frontend/src/types/instance.d.ts | 1 + src/dstack/_internal/cli/commands/fleet.py | 4 +- ...7cca121fcb_add_resource_exports_imports.py | 142 +++++ src/dstack/_internal/server/models.py | 68 +++ src/dstack/_internal/server/routers/fleets.py | 27 +- .../_internal/server/routers/instances.py | 15 +- src/dstack/_internal/server/schemas/fleets.py | 15 + .../_internal/server/schemas/instances.py | 1 + .../_internal/server/security/permissions.py | 68 ++- .../_internal/server/services/events.py | 18 +- .../_internal/server/services/fleets.py | 77 ++- .../_internal/server/services/instances.py | 23 +- .../server/services/jobs/__init__.py | 5 +- .../_internal/server/services/runs/plan.py | 30 +- src/dstack/_internal/server/testing/common.py | 21 + src/dstack/_internal/utils/common.py | 18 +- src/dstack/api/server/_fleets.py | 6 +- .../scheduled_tasks/test_submitted_jobs.py | 105 ++++ .../_internal/server/routers/test_events.py | 228 ++++++++ .../_internal/server/routers/test_fleets.py | 526 ++++++++++++++++++ .../server/routers/test_instances.py | 279 ++++++++++ .../_internal/server/routers/test_runs.py | 52 ++ 25 files changed, 1685 insertions(+), 47 deletions(-) create mode 100644 src/dstack/_internal/server/migrations/versions/2026/03_02_1345_ea7cca121fcb_add_resource_exports_imports.py diff --git a/frontend/src/pages/Fleets/List/hooks.tsx b/frontend/src/pages/Fleets/List/hooks.tsx index 639d7b8683..397a673ddf 100644 --- a/frontend/src/pages/Fleets/List/hooks.tsx +++ b/frontend/src/pages/Fleets/List/hooks.tsx @@ -182,6 +182,7 @@ export const useFilters = (localStorePrefix = 'fleet-list-page') => { return { ...params, only_active: onlyActive, + include_imported: true, } as Partial; }, [propertyFilterQuery, onlyActive]); diff --git a/frontend/src/pages/Instances/List/hooks/useFilters.ts b/frontend/src/pages/Instances/List/hooks/useFilters.ts index 55453c33e4..50741de62e 100644 --- a/frontend/src/pages/Instances/List/hooks/useFilters.ts +++ b/frontend/src/pages/Instances/List/hooks/useFilters.ts @@ -83,6 +83,7 @@ export const useFilters = (localStorePrefix = 'instances-list-page') => { return { ...params, only_active: onlyActive, + include_imported: true, } as Partial; }, [propertyFilterQuery, onlyActive]); diff --git a/frontend/src/types/fleet.d.ts b/frontend/src/types/fleet.d.ts index 2813cd4023..b5050167b3 100644 --- a/frontend/src/types/fleet.d.ts +++ b/frontend/src/types/fleet.d.ts @@ -3,6 +3,7 @@ declare type TSpotPolicy = 'spot' | 'on-demand' | 'auto'; declare type TFleetListRequestParams = TBaseRequestListParams & { project_name?: string; only_active?: boolean; + include_imported?: boolean; }; declare interface ISSHHostParamsRequest { diff --git a/frontend/src/types/instance.d.ts b/frontend/src/types/instance.d.ts index 585f4f5093..555e355dae 100644 --- a/frontend/src/types/instance.d.ts +++ b/frontend/src/types/instance.d.ts @@ -2,6 +2,7 @@ declare type TInstanceListRequestParams = TBaseRequestListParams & { project_names?: string[]; fleet_ids?: string[]; only_active?: boolean; + include_imported?: boolean; }; declare type TInstanceStatus = diff --git a/src/dstack/_internal/cli/commands/fleet.py b/src/dstack/_internal/cli/commands/fleet.py index 130e2c3fcf..1a0ba8335d 100644 --- a/src/dstack/_internal/cli/commands/fleet.py +++ b/src/dstack/_internal/cli/commands/fleet.py @@ -93,7 +93,7 @@ def _command(self, args: argparse.Namespace): args.subfunc(args) def _list(self, args: argparse.Namespace): - fleets = self.api.client.fleets.list(self.api.project) + fleets = self.api.client.fleets.list(self.api.project, include_imported=True) if not args.watch: print_fleets_table(fleets, verbose=args.verbose) return @@ -103,7 +103,7 @@ def _list(self, args: argparse.Namespace): while True: live.update(get_fleets_table(fleets, verbose=args.verbose)) time.sleep(LIVE_TABLE_PROVISION_INTERVAL_SECS) - fleets = self.api.client.fleets.list(self.api.project) + fleets = self.api.client.fleets.list(self.api.project, include_imported=True) except KeyboardInterrupt: pass diff --git a/src/dstack/_internal/server/migrations/versions/2026/03_02_1345_ea7cca121fcb_add_resource_exports_imports.py b/src/dstack/_internal/server/migrations/versions/2026/03_02_1345_ea7cca121fcb_add_resource_exports_imports.py new file mode 100644 index 0000000000..a6bd031335 --- /dev/null +++ b/src/dstack/_internal/server/migrations/versions/2026/03_02_1345_ea7cca121fcb_add_resource_exports_imports.py @@ -0,0 +1,142 @@ +"""Add resource exports imports + +Revision ID: ea7cca121fcb +Revises: 46150101edec +Create Date: 2026-03-02 13:45:57.118841+00:00 + +""" + +import sqlalchemy as sa +import sqlalchemy_utils +from alembic import op + +import dstack._internal.server.models + +# revision identifiers, used by Alembic. +revision = "ea7cca121fcb" +down_revision = "46150101edec" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "resource_exports", + sa.Column("id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=False), + sa.Column("name", sa.String(length=100), nullable=False), + sa.Column( + "project_id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=False + ), + sa.Column("created_at", dstack._internal.server.models.NaiveDateTime(), nullable=False), + sa.ForeignKeyConstraint( + ["project_id"], + ["projects.id"], + name=op.f("fk_resource_exports_project_id_projects"), + ondelete="CASCADE", + ), + sa.PrimaryKeyConstraint("id", name=op.f("pk_resource_exports")), + sa.UniqueConstraint("project_id", "name", name="uq_resource_exports_project_id_name"), + ) + with op.batch_alter_table("resource_exports", schema=None) as batch_op: + batch_op.create_index( + batch_op.f("ix_resource_exports_project_id"), ["project_id"], unique=False + ) + + op.create_table( + "exported_fleets", + sa.Column("id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=False), + sa.Column( + "resource_export_id", + sqlalchemy_utils.types.uuid.UUIDType(binary=False), + nullable=False, + ), + sa.Column("fleet_id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=False), + sa.ForeignKeyConstraint( + ["fleet_id"], + ["fleets.id"], + name=op.f("fk_exported_fleets_fleet_id_fleets"), + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["resource_export_id"], + ["resource_exports.id"], + name=op.f("fk_exported_fleets_resource_export_id_resource_exports"), + ondelete="CASCADE", + ), + sa.PrimaryKeyConstraint("id", name=op.f("pk_exported_fleets")), + sa.UniqueConstraint( + "resource_export_id", "fleet_id", name="uq_exported_fleets_resource_export_id_fleet_id" + ), + ) + with op.batch_alter_table("exported_fleets", schema=None) as batch_op: + batch_op.create_index( + batch_op.f("ix_exported_fleets_fleet_id"), ["fleet_id"], unique=False + ) + batch_op.create_index( + batch_op.f("ix_exported_fleets_resource_export_id"), + ["resource_export_id"], + unique=False, + ) + + op.create_table( + "resource_imports", + sa.Column("id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=False), + sa.Column( + "project_id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=False + ), + sa.Column( + "resource_export_id", + sqlalchemy_utils.types.uuid.UUIDType(binary=False), + nullable=False, + ), + sa.Column("created_at", dstack._internal.server.models.NaiveDateTime(), nullable=False), + sa.ForeignKeyConstraint( + ["project_id"], + ["projects.id"], + name=op.f("fk_resource_imports_project_id_projects"), + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["resource_export_id"], + ["resource_exports.id"], + name=op.f("fk_resource_imports_resource_export_id_resource_exports"), + ondelete="CASCADE", + ), + sa.PrimaryKeyConstraint("id", name=op.f("pk_resource_imports")), + sa.UniqueConstraint( + "project_id", + "resource_export_id", + name="uq_resource_imports_project_id_resource_export_id", + ), + ) + with op.batch_alter_table("resource_imports", schema=None) as batch_op: + batch_op.create_index( + batch_op.f("ix_resource_imports_project_id"), ["project_id"], unique=False + ) + batch_op.create_index( + batch_op.f("ix_resource_imports_resource_export_id"), + ["resource_export_id"], + unique=False, + ) + + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("resource_imports", schema=None) as batch_op: + batch_op.drop_index(batch_op.f("ix_resource_imports_resource_export_id")) + batch_op.drop_index(batch_op.f("ix_resource_imports_project_id")) + + op.drop_table("resource_imports") + with op.batch_alter_table("exported_fleets", schema=None) as batch_op: + batch_op.drop_index(batch_op.f("ix_exported_fleets_resource_export_id")) + batch_op.drop_index(batch_op.f("ix_exported_fleets_fleet_id")) + + op.drop_table("exported_fleets") + with op.batch_alter_table("resource_exports", schema=None) as batch_op: + batch_op.drop_index(batch_op.f("ix_resource_exports_project_id")) + + op.drop_table("resource_exports") + # ### end Alembic commands ### diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py index 15c5488da5..418721e3b1 100644 --- a/src/dstack/_internal/server/models.py +++ b/src/dstack/_internal/server/models.py @@ -978,3 +978,71 @@ class EventTargetModel(BaseModel): ) entity_id: Mapped[uuid.UUID] = mapped_column(UUIDType(binary=False), index=True) entity_name: Mapped[str] = mapped_column(String(200)) + + +class ResourceExportModel(BaseModel): + __tablename__ = "resource_exports" + __table_args__ = ( + UniqueConstraint("project_id", "name", name="uq_resource_exports_project_id_name"), + ) + + id: Mapped[uuid.UUID] = mapped_column( + UUIDType(binary=False), primary_key=True, default=uuid.uuid4 + ) + name: Mapped[str] = mapped_column(String(100)) + project_id: Mapped[uuid.UUID] = mapped_column( + ForeignKey("projects.id", ondelete="CASCADE"), index=True + ) + project: Mapped["ProjectModel"] = relationship() + created_at: Mapped[datetime] = mapped_column(NaiveDateTime, default=get_current_datetime) + resource_imports: Mapped[List["ResourceImportModel"]] = relationship( + back_populates="resource_export" + ) + exported_fleets: Mapped[List["ExportedFleetModel"]] = relationship( + back_populates="resource_export" + ) + + +class ResourceImportModel(BaseModel): + __tablename__ = "resource_imports" + __table_args__ = ( + UniqueConstraint( + "project_id", + "resource_export_id", + name="uq_resource_imports_project_id_resource_export_id", + ), + ) + + id: Mapped[uuid.UUID] = mapped_column( + UUIDType(binary=False), primary_key=True, default=uuid.uuid4 + ) + project_id: Mapped[uuid.UUID] = mapped_column( + ForeignKey("projects.id", ondelete="CASCADE"), index=True + ) + project: Mapped["ProjectModel"] = relationship() + resource_export_id: Mapped[uuid.UUID] = mapped_column( + ForeignKey("resource_exports.id", ondelete="CASCADE"), index=True + ) + resource_export: Mapped["ResourceExportModel"] = relationship() + created_at: Mapped[datetime] = mapped_column(NaiveDateTime, default=get_current_datetime) + + +class ExportedFleetModel(BaseModel): + __tablename__ = "exported_fleets" + __table_args__ = ( + UniqueConstraint( + "resource_export_id", "fleet_id", name="uq_exported_fleets_resource_export_id_fleet_id" + ), + ) + + id: Mapped[uuid.UUID] = mapped_column( + UUIDType(binary=False), primary_key=True, default=uuid.uuid4 + ) + resource_export_id: Mapped[uuid.UUID] = mapped_column( + ForeignKey("resource_exports.id", ondelete="CASCADE"), index=True + ) + resource_export: Mapped["ResourceExportModel"] = relationship() + fleet_id: Mapped[uuid.UUID] = mapped_column( + ForeignKey("fleets.id", ondelete="CASCADE"), index=True + ) + fleet: Mapped["FleetModel"] = relationship() diff --git a/src/dstack/_internal/server/routers/fleets.py b/src/dstack/_internal/server/routers/fleets.py index a436d1123a..cb18db8bbd 100644 --- a/src/dstack/_internal/server/routers/fleets.py +++ b/src/dstack/_internal/server/routers/fleets.py @@ -9,6 +9,7 @@ from dstack._internal.core.models.fleets import Fleet, FleetPlan from dstack._internal.server.compatibility.common import patch_offers_list from dstack._internal.server.db import get_session +from dstack._internal.server.deps import Project from dstack._internal.server.models import ProjectModel, UserModel from dstack._internal.server.schemas.fleets import ( ApplyFleetPlanRequest, @@ -18,8 +19,13 @@ GetFleetPlanRequest, GetFleetRequest, ListFleetsRequest, + ListProjectFleetsRequest, +) +from dstack._internal.server.security.permissions import ( + Authenticated, + ProjectMember, + check_can_access_fleet, ) -from dstack._internal.server.security.permissions import Authenticated, ProjectMember from dstack._internal.server.utils.routers import ( CustomORJSONResponse, get_base_api_additional_responses, @@ -58,6 +64,7 @@ async def list_fleets( user=user, project_name=body.project_name, only_active=body.only_active, + include_imported=body.include_imported, prev_created_at=body.prev_created_at, prev_id=body.prev_id, limit=body.limit, @@ -68,6 +75,7 @@ async def list_fleets( @project_router.post("/list", response_model=List[Fleet]) async def list_project_fleets( + body: Optional[ListProjectFleetsRequest] = None, session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()), ): @@ -76,8 +84,14 @@ async def list_project_fleets( Includes only active fleet instances. To list all fleet instances, use `/api/instances/list`. """ _, project = user_project + if body is None: + body = ListProjectFleetsRequest() return CustomORJSONResponse( - await fleets_services.list_project_fleets(session=session, project=project) + await fleets_services.list_project_fleets( + session=session, + project=project, + include_imported=body.include_imported, + ) ) @@ -85,16 +99,19 @@ async def list_project_fleets( async def get_fleet( body: GetFleetRequest, session: AsyncSession = Depends(get_session), - user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()), + user: UserModel = Depends(Authenticated()), + project: ProjectModel = Depends(Project()), ): """ Returns a fleet given `name` or `id`. If given `name`, does not return deleted fleets. If given `id`, returns deleted fleets. """ - _, project = user_project + await check_can_access_fleet( + session=session, user=user, fleet_project=project, fleet_name_or_id=body.get_name_or_id() + ) fleet = await fleets_services.get_fleet( - session=session, project=project, name=body.name, fleet_id=body.id + session=session, project=project, name_or_id=body.get_name_or_id() ) if fleet is None: raise ResourceNotExistsError() diff --git a/src/dstack/_internal/server/routers/instances.py b/src/dstack/_internal/server/routers/instances.py index b241d7e764..1b64ec8b82 100644 --- a/src/dstack/_internal/server/routers/instances.py +++ b/src/dstack/_internal/server/routers/instances.py @@ -7,6 +7,7 @@ from dstack._internal.core.errors import ResourceNotExistsError from dstack._internal.core.models.instances import Instance from dstack._internal.server.db import get_session +from dstack._internal.server.deps import Project from dstack._internal.server.models import ProjectModel, UserModel from dstack._internal.server.schemas.instances import ( GetInstanceHealthChecksRequest, @@ -14,7 +15,11 @@ GetInstanceRequest, ListInstancesRequest, ) -from dstack._internal.server.security.permissions import Authenticated, ProjectMember +from dstack._internal.server.security.permissions import ( + Authenticated, + ProjectMember, + check_can_access_instance, +) from dstack._internal.server.utils.routers import ( CustomORJSONResponse, get_base_api_additional_responses, @@ -52,6 +57,7 @@ async def list_instances( project_names=body.project_names, fleet_ids=body.fleet_ids, only_active=body.only_active, + include_imported=body.include_imported, prev_created_at=body.prev_created_at, prev_id=body.prev_id, limit=body.limit, @@ -83,12 +89,15 @@ async def get_instance_health_checks( async def get_instance( body: GetInstanceRequest, session: Annotated[AsyncSession, Depends(get_session)], - user_project: Annotated[tuple[UserModel, ProjectModel], Depends(ProjectMember())], + user: Annotated[UserModel, Depends(Authenticated())], + project: Annotated[ProjectModel, Depends(Project())], ): """ Returns an instance given its ID. """ - _, project = user_project + await check_can_access_instance( + session=session, user=user, instance_project=project, instance_id=body.id + ) instance = await instances_services.get_instance( session=session, project=project, instance_id=body.id ) diff --git a/src/dstack/_internal/server/schemas/fleets.py b/src/dstack/_internal/server/schemas/fleets.py index 3df43d12ce..4bb25d50bb 100644 --- a/src/dstack/_internal/server/schemas/fleets.py +++ b/src/dstack/_internal/server/schemas/fleets.py @@ -4,23 +4,38 @@ from pydantic import Field +from dstack._internal.core.errors import ServerClientError from dstack._internal.core.models.common import CoreModel from dstack._internal.core.models.fleets import ApplyFleetPlanInput, FleetSpec +from dstack._internal.utils.common import EntityID, EntityName, EntityNameOrID class ListFleetsRequest(CoreModel): project_name: Optional[str] = None only_active: bool = False + include_imported: bool = False prev_created_at: Optional[datetime] = None prev_id: Optional[UUID] = None limit: int = Field(100, ge=0, le=100) ascending: bool = False +class ListProjectFleetsRequest(CoreModel): + include_imported: bool = False + + class GetFleetRequest(CoreModel): name: Optional[str] id: Optional[UUID] = None + def get_name_or_id(self) -> EntityNameOrID: + if self.id is not None: + return EntityID(id=self.id) + elif self.name is not None: + return EntityName(name=self.name) + else: + raise ServerClientError("name or id must be specified") + class GetFleetPlanRequest(CoreModel): spec: FleetSpec diff --git a/src/dstack/_internal/server/schemas/instances.py b/src/dstack/_internal/server/schemas/instances.py index 120ff161dc..8f87935b92 100644 --- a/src/dstack/_internal/server/schemas/instances.py +++ b/src/dstack/_internal/server/schemas/instances.py @@ -15,6 +15,7 @@ class ListInstancesRequest(CoreModel): project_names: Optional[list[str]] = None fleet_ids: Optional[list[UUID]] = None only_active: bool = False + include_imported: bool = False prev_created_at: Optional[datetime] = None prev_id: Optional[UUID] = None limit: int = 1000 diff --git a/src/dstack/_internal/server/security/permissions.py b/src/dstack/_internal/server/security/permissions.py index 0ecddf1d9e..a76f3659a0 100644 --- a/src/dstack/_internal/server/security/permissions.py +++ b/src/dstack/_internal/server/security/permissions.py @@ -1,13 +1,23 @@ from typing import Annotated, Optional, Tuple +from uuid import UUID from fastapi import Depends, HTTPException, Security from fastapi.security import HTTPBearer from fastapi.security.http import HTTPAuthorizationCredentials +from sqlalchemy import exists, func, select from sqlalchemy.ext.asyncio import AsyncSession from dstack._internal.core.models.users import GlobalRole, ProjectRole from dstack._internal.server.db import get_session -from dstack._internal.server.models import ProjectModel, UserModel +from dstack._internal.server.models import ( + ExportedFleetModel, + FleetModel, + InstanceModel, + MemberModel, + ProjectModel, + ResourceImportModel, + UserModel, +) from dstack._internal.server.services.projects import ( get_project_model_by_name, get_user_project_role, @@ -18,6 +28,7 @@ error_invalid_token, error_not_found, ) +from dstack._internal.utils.common import EntityName, EntityNameOrID class Authenticated: @@ -249,3 +260,58 @@ async def is_project_member(session: AsyncSession, project_name: str, token: str return True except HTTPException: return False + + +async def check_can_access_fleet( + session: AsyncSession, + user: UserModel, + fleet_project: ProjectModel, + fleet_name_or_id: EntityNameOrID, +) -> None: + if ( + user.global_role == GlobalRole.ADMIN + or get_user_project_role(user=user, project=fleet_project) is not None + ): + return + filters = [ + FleetModel.project_id == fleet_project.id, + exists().where( + MemberModel.user_id == user.id, + MemberModel.project_id == ResourceImportModel.project_id, + ResourceImportModel.resource_export_id == ExportedFleetModel.resource_export_id, + ExportedFleetModel.fleet_id == FleetModel.id, + ), + ] + if isinstance(fleet_name_or_id, EntityName): + filters.extend([FleetModel.name == fleet_name_or_id.name, FleetModel.deleted == False]) + else: + filters.append(FleetModel.id == fleet_name_or_id.id) + res = await session.execute(select(func.count()).select_from(FleetModel).where(*filters)) + if res.scalar_one() == 0: + raise error_forbidden() + + +async def check_can_access_instance( + session: AsyncSession, + user: UserModel, + instance_project: ProjectModel, + instance_id: UUID, +) -> None: + if ( + user.global_role == GlobalRole.ADMIN + or get_user_project_role(user=user, project=instance_project) is not None + ): + return + filters = [ + InstanceModel.project_id == instance_project.id, + InstanceModel.id == instance_id, + exists().where( + MemberModel.user_id == user.id, + MemberModel.project_id == ResourceImportModel.project_id, + ResourceImportModel.resource_export_id == ExportedFleetModel.resource_export_id, + ExportedFleetModel.fleet_id == InstanceModel.fleet_id, + ), + ] + res = await session.execute(select(func.count()).select_from(InstanceModel).where(*filters)) + if res.scalar_one() == 0: + raise error_forbidden() diff --git a/src/dstack/_internal/server/services/events.py b/src/dstack/_internal/server/services/events.py index d46b43e201..dd7b33dc7f 100644 --- a/src/dstack/_internal/server/services/events.py +++ b/src/dstack/_internal/server/services/events.py @@ -252,14 +252,14 @@ async def list_events( limit: int, ascending: bool, ) -> list[Event]: - target_filters = [] + target_visibility_filters = [] if user.global_role != GlobalRole.ADMIN: query = select(MemberModel.project_id).where(MemberModel.user_id == user.id) res = await session.execute(query) # In Postgres, fetching project IDs separately is orders of magnitude faster # than using a subquery. project_ids = list(res.unique().scalars().all()) - target_filters.append( + target_visibility_filters.append( or_( EventTargetModel.entity_project_id.in_(project_ids), and_( @@ -269,6 +269,7 @@ async def list_events( ), ) ) + target_filters = [] if target_projects is not None: target_filters.append( and_( @@ -426,6 +427,8 @@ async def list_events( if event_filters: query = query.where(*event_filters) if target_filters: + # Each returned event should reference at least one target the user **wants** to see + # (as defined by user-provided filters). query = query.where( exists().where( and_( @@ -434,6 +437,17 @@ async def list_events( ) ) ) + if target_visibility_filters: + # Each returned event should reference at least one target the user **can** see + # (as defined by project membership). + query = query.where( + exists().where( + and_( + EventTargetModel.event_id == EventModel.id, + *target_visibility_filters, + ) + ) + ) res = await session.execute(query) event_models = res.unique().scalars().all() return list(map(event_model_to_event, event_models)) diff --git a/src/dstack/_internal/server/services/fleets.py b/src/dstack/_internal/server/services/fleets.py index c0ec21aeaa..e74d17d5a0 100644 --- a/src/dstack/_internal/server/services/fleets.py +++ b/src/dstack/_internal/server/services/fleets.py @@ -4,7 +4,7 @@ from functools import wraps from typing import List, Literal, Optional, Tuple, TypeVar, Union -from sqlalchemy import and_, func, or_, select +from sqlalchemy import and_, exists, false, func, or_, select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import aliased, joinedload, selectinload @@ -53,11 +53,13 @@ from dstack._internal.core.services.diff import ModelDiff, copy_model, diff_models from dstack._internal.server.db import get_db, is_db_postgres, is_db_sqlite from dstack._internal.server.models import ( + ExportedFleetModel, FleetModel, InstanceModel, JobModel, MemberModel, ProjectModel, + ResourceImportModel, RunModel, UserModel, ) @@ -82,6 +84,7 @@ ) from dstack._internal.server.services.resources import set_resources_defaults from dstack._internal.utils import random_names +from dstack._internal.utils.common import EntityID, EntityName, EntityNameOrID from dstack._internal.utils.logging import get_logger from dstack._internal.utils.ssh import pkey_from_str @@ -193,6 +196,7 @@ async def list_fleets( user: UserModel, project_name: Optional[str], only_active: bool, + include_imported: bool, prev_created_at: Optional[datetime], prev_id: Optional[uuid.UUID], limit: int, @@ -209,6 +213,7 @@ async def list_fleets( session=session, projects=projects, only_active=only_active, + include_imported=include_imported, prev_created_at=prev_created_at, prev_id=prev_id, limit=limit, @@ -221,13 +226,25 @@ async def list_projects_fleet_models( session: AsyncSession, projects: List[ProjectModel], only_active: bool, + include_imported: bool, prev_created_at: Optional[datetime], prev_id: Optional[uuid.UUID], limit: int, ascending: bool, ) -> List[FleetModel]: filters = [] - filters.append(FleetModel.project_id.in_(p.id for p in projects)) + project_ids = {p.id for p in projects} + is_fleet_imported_subquery = exists().where( + ResourceImportModel.project_id.in_(project_ids), + ResourceImportModel.resource_export_id == ExportedFleetModel.resource_export_id, + ExportedFleetModel.fleet_id == FleetModel.id, + ) + filters.append( + or_( + FleetModel.project_id.in_(project_ids), + is_fleet_imported_subquery if include_imported else false(), + ) + ) if only_active: filters.append(FleetModel.deleted == False) if prev_created_at is not None: @@ -259,7 +276,10 @@ async def list_projects_fleet_models( .where(*filters) .order_by(*order_by) .limit(limit) - .options(selectinload(FleetModel.instances.and_(InstanceModel.deleted == False))) + .options( + joinedload(FleetModel.project).load_only(ProjectModel.name), + selectinload(FleetModel.instances.and_(InstanceModel.deleted == False)), + ) ) fleet_models = list(res.unique().scalars().all()) return fleet_models @@ -269,8 +289,11 @@ async def list_project_fleets( session: AsyncSession, project: ProjectModel, names: Optional[List[str]] = None, + include_imported: bool = False, ) -> List[Fleet]: - fleet_models = await list_project_fleet_models(session=session, project=project, names=names) + fleet_models = await list_project_fleet_models( + session=session, project=project, names=names, include_imported=include_imported + ) return [fleet_model_to_fleet(v) for v in fleet_models] @@ -278,11 +301,21 @@ async def list_project_fleet_models( session: AsyncSession, project: ProjectModel, names: Optional[List[str]] = None, + include_imported: bool = False, include_deleted: bool = False, ) -> List[FleetModel]: - filters = [ - FleetModel.project_id == project.id, - ] + filters = [] + is_fleet_imported_subquery = exists().where( + ResourceImportModel.project_id == project.id, + ResourceImportModel.resource_export_id == ExportedFleetModel.resource_export_id, + ExportedFleetModel.fleet_id == FleetModel.id, + ) + filters.append( + or_( + FleetModel.project_id == project.id, + is_fleet_imported_subquery if include_imported else false(), + ) + ) if names is not None: filters.append(FleetModel.name.in_(names)) if not include_deleted: @@ -290,7 +323,10 @@ async def list_project_fleet_models( res = await session.execute( select(FleetModel) .where(*filters) - .options(selectinload(FleetModel.instances.and_(InstanceModel.deleted == False))) + .options( + joinedload(FleetModel.project).load_only(ProjectModel.name), + selectinload(FleetModel.instances.and_(InstanceModel.deleted == False)), + ) ) return list(res.unique().scalars().all()) @@ -298,20 +334,17 @@ async def list_project_fleet_models( async def get_fleet( session: AsyncSession, project: ProjectModel, - name: Optional[str] = None, - fleet_id: Optional[uuid.UUID] = None, + name_or_id: EntityNameOrID, include_sensitive: bool = False, ) -> Optional[Fleet]: - if fleet_id is not None: + if isinstance(name_or_id, EntityID): fleet_model = await get_project_fleet_model_by_id( - session=session, project=project, fleet_id=fleet_id + session=session, project=project, fleet_id=name_or_id.id ) - elif name is not None: + else: fleet_model = await get_project_fleet_model_by_name( - session=session, project=project, name=name + session=session, project=project, name=name_or_id.name ) - else: - raise ServerClientError("name or id must be specified") if fleet_model is None: return None return fleet_model_to_fleet(fleet_model, include_sensitive=include_sensitive) @@ -329,7 +362,10 @@ async def get_project_fleet_model_by_id( res = await session.execute( select(FleetModel) .where(*filters) - .options(joinedload(FleetModel.instances.and_(InstanceModel.deleted == False))) + .options( + joinedload(FleetModel.instances.and_(InstanceModel.deleted == False)), + joinedload(FleetModel.project).load_only(ProjectModel.name), + ) ) return res.unique().scalar_one_or_none() @@ -349,7 +385,10 @@ async def get_project_fleet_model_by_name( res = await session.execute( select(FleetModel) .where(*filters) - .options(joinedload(FleetModel.instances.and_(InstanceModel.deleted == False))) + .options( + joinedload(FleetModel.instances.and_(InstanceModel.deleted == False)), + joinedload(FleetModel.project).load_only(ProjectModel.name), + ) ) return res.unique().scalar_one_or_none() @@ -379,7 +418,7 @@ async def get_plan( current_fleet = await get_fleet( session=session, project=project, - name=effective_spec.configuration.name, + name_or_id=EntityName(effective_spec.configuration.name), include_sensitive=True, ) if current_fleet is not None: diff --git a/src/dstack/_internal/server/services/instances.py b/src/dstack/_internal/server/services/instances.py index 046f092c03..f58f8705ff 100644 --- a/src/dstack/_internal/server/services/instances.py +++ b/src/dstack/_internal/server/services/instances.py @@ -5,7 +5,7 @@ from typing import Dict, List, Literal, Optional, Union import gpuhunt -from sqlalchemy import and_, or_, select +from sqlalchemy import and_, exists, false, or_, select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import joinedload, load_only @@ -42,10 +42,12 @@ from dstack._internal.core.services.profiles import get_termination from dstack._internal.server import settings as server_settings from dstack._internal.server.models import ( + ExportedFleetModel, FleetModel, InstanceHealthCheckModel, InstanceModel, ProjectModel, + ResourceImportModel, UserModel, ) from dstack._internal.server.schemas.health.dcgm import DCGMHealthResponse @@ -516,13 +518,23 @@ async def list_projects_instance_models( projects: List[ProjectModel], fleet_ids: Optional[Iterable[uuid.UUID]], only_active: bool, + include_imported: bool, prev_created_at: Optional[datetime], prev_id: Optional[uuid.UUID], limit: int, ascending: bool, ) -> List[InstanceModel]: + project_ids = [p.id for p in projects] + is_instance_imported_subquery = exists().where( + ResourceImportModel.project_id.in_(project_ids), + ResourceImportModel.resource_export_id == ExportedFleetModel.resource_export_id, + ExportedFleetModel.fleet_id == InstanceModel.fleet_id, + ) filters: List = [ - InstanceModel.project_id.in_(p.id for p in projects), + or_( + InstanceModel.project_id.in_(p.id for p in projects), + is_instance_imported_subquery if include_imported else false(), + ) ] if fleet_ids is not None: filters.append(InstanceModel.fleet_id.in_(fleet_ids)) @@ -569,7 +581,10 @@ async def list_projects_instance_models( .where(*filters) .order_by(*order_by) .limit(limit) - .options(joinedload(InstanceModel.fleet)) + .options( + joinedload(InstanceModel.fleet), + joinedload(InstanceModel.project).load_only(ProjectModel.name), + ) ) instance_models = list(res.unique().scalars().all()) return instance_models @@ -581,6 +596,7 @@ async def list_user_instances( project_names: Optional[Container[str]], fleet_ids: Optional[Iterable[uuid.UUID]], only_active: bool, + include_imported: bool, prev_created_at: Optional[datetime], prev_id: Optional[uuid.UUID], limit: int, @@ -600,6 +616,7 @@ async def list_user_instances( projects=projects, fleet_ids=fleet_ids, only_active=only_active, + include_imported=include_imported, prev_created_at=prev_created_at, prev_id=prev_id, limit=limit, diff --git a/src/dstack/_internal/server/services/jobs/__init__.py b/src/dstack/_internal/server/services/jobs/__init__.py index eb10bda5c4..bf0f65bb6f 100644 --- a/src/dstack/_internal/server/services/jobs/__init__.py +++ b/src/dstack/_internal/server/services/jobs/__init__.py @@ -273,10 +273,7 @@ def _get_job_configurator( async def stop_runner(session: AsyncSession, job_model: JobModel): res = await session.execute( select(InstanceModel) - .where( - InstanceModel.project_id == job_model.project_id, - InstanceModel.id == job_model.instance_id, - ) + .where(InstanceModel.id == job_model.instance_id) .options(joinedload(InstanceModel.project)) ) instance: Optional[InstanceModel] = res.scalar() diff --git a/src/dstack/_internal/server/services/runs/plan.py b/src/dstack/_internal/server/services/runs/plan.py index 5e3b6e5a02..780fbb495b 100644 --- a/src/dstack/_internal/server/services/runs/plan.py +++ b/src/dstack/_internal/server/services/runs/plan.py @@ -1,9 +1,9 @@ import math from typing import Optional, Union -from sqlalchemy import and_, not_, or_, select +from sqlalchemy import and_, exists, not_, or_, select from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import contains_eager, noload +from sqlalchemy.orm import contains_eager, joinedload, noload from dstack._internal.core.backends.base.backend import Backend from dstack._internal.core.models.fleets import Fleet, InstanceGroupPlacement @@ -21,7 +21,14 @@ RunSpec, ) from dstack._internal.core.models.volumes import Volume -from dstack._internal.server.models import FleetModel, InstanceModel, ProjectModel, RunModel +from dstack._internal.server.models import ( + ExportedFleetModel, + FleetModel, + InstanceModel, + ProjectModel, + ResourceImportModel, + RunModel, +) from dstack._internal.server.services.fleets import ( check_can_create_new_cloud_instance_in_fleet, fleet_model_to_fleet, @@ -206,8 +213,16 @@ async def get_run_candidate_fleet_models_filters( # If another job freed the instance but is still trying to detach volumes, # do not provision on it to prevent attaching volumes that are currently detaching. detaching_instances_ids = await get_instances_ids_with_detaching_volumes(session) + is_fleet_imported_subquery = exists().where( + ResourceImportModel.project_id == project.id, + ResourceImportModel.resource_export_id == ExportedFleetModel.resource_export_id, + ExportedFleetModel.fleet_id == FleetModel.id, + ) fleet_filters = [ - FleetModel.project_id == project.id, + or_( + FleetModel.project_id == project.id, + is_fleet_imported_subquery, + ), FleetModel.deleted == False, ] if run_model is not None and run_model.fleet is not None: @@ -235,7 +250,12 @@ async def select_run_candidate_fleet_models_with_filters( .join(FleetModel.instances) .where(*fleet_filters) .where(*instance_filters) - .options(contains_eager(FleetModel.instances)) + .options( + contains_eager(FleetModel.instances), + joinedload(FleetModel.project) + .load_only(ProjectModel.name) + .joinedload(ProjectModel.backends), + ) .execution_options(populate_existing=True) ) if lock_instances: diff --git a/src/dstack/_internal/server/testing/common.py b/src/dstack/_internal/server/testing/common.py index 6bff65dea3..79bac2153c 100644 --- a/src/dstack/_internal/server/testing/common.py +++ b/src/dstack/_internal/server/testing/common.py @@ -93,6 +93,7 @@ ComputeGroupModel, DecryptedString, EventModel, + ExportedFleetModel, FileArchiveModel, FleetModel, GatewayComputeModel, @@ -107,6 +108,8 @@ ProjectModel, RepoCredsModel, RepoModel, + ResourceExportModel, + ResourceImportModel, RunModel, SecretModel, UserModel, @@ -514,6 +517,24 @@ async def create_compute_group( return compute_group +async def create_resource_export( + session: AsyncSession, + exporter_project: ProjectModel, + importer_projects: list[ProjectModel], + exported_fleets: list[FleetModel], + name: str = "test_resource_export", +) -> ResourceExportModel: + resource_export = ResourceExportModel( + name=name, + project=exporter_project, + resource_imports=[ResourceImportModel(project=project) for project in importer_projects], + exported_fleets=[ExportedFleetModel(fleet=fleet) for fleet in exported_fleets], + ) + session.add(resource_export) + await session.commit() + return resource_export + + async def create_probe( session: AsyncSession, job: JobModel, diff --git a/src/dstack/_internal/utils/common.py b/src/dstack/_internal/utils/common.py index ba139c6bfc..2db91882ff 100644 --- a/src/dstack/_internal/utils/common.py +++ b/src/dstack/_internal/utils/common.py @@ -4,16 +4,32 @@ import re import time from collections.abc import Callable +from dataclasses import dataclass from datetime import datetime, timedelta, timezone from functools import partial from pathlib import Path -from typing import Any, Iterable, List, Optional, TypeVar +from typing import Any, Iterable, List, Optional, TypeVar, Union from urllib.parse import urlparse +from uuid import UUID from typing_extensions import ParamSpec from dstack._internal.core.models.common import Duration + +@dataclass +class EntityName: + name: str + + +@dataclass +class EntityID: + id: UUID + + +EntityNameOrID = Union[EntityName, EntityID] + + P = ParamSpec("P") R = TypeVar("R") diff --git a/src/dstack/api/server/_fleets.py b/src/dstack/api/server/_fleets.py index 9bfb1cb422..95bb22e82d 100644 --- a/src/dstack/api/server/_fleets.py +++ b/src/dstack/api/server/_fleets.py @@ -16,13 +16,15 @@ DeleteFleetsRequest, GetFleetPlanRequest, GetFleetRequest, + ListProjectFleetsRequest, ) from dstack.api.server._group import APIClientGroup class FleetsAPIClient(APIClientGroup): - def list(self, project_name: str) -> List[Fleet]: - resp = self._request(f"/api/project/{project_name}/fleets/list") + def list(self, project_name: str, *, include_imported: bool = False) -> List[Fleet]: + body = ListProjectFleetsRequest(include_imported=include_imported) + resp = self._request(f"/api/project/{project_name}/fleets/list", body=body.json()) return parse_obj_as(List[Fleet.__response__], resp.json()) def get( diff --git a/src/tests/_internal/server/background/scheduled_tasks/test_submitted_jobs.py b/src/tests/_internal/server/background/scheduled_tasks/test_submitted_jobs.py index f33f608c71..96f37101ea 100644 --- a/src/tests/_internal/server/background/scheduled_tasks/test_submitted_jobs.py +++ b/src/tests/_internal/server/background/scheduled_tasks/test_submitted_jobs.py @@ -21,6 +21,7 @@ JobStatus, JobTerminationReason, ) +from dstack._internal.core.models.users import GlobalRole from dstack._internal.core.models.volumes import ( InstanceMountPoint, VolumeAttachmentData, @@ -46,6 +47,7 @@ create_job, create_project, create_repo, + create_resource_export, create_run, create_user, create_volume, @@ -55,6 +57,7 @@ get_job_provisioning_data, get_placement_group_provisioning_data, get_run_spec, + get_ssh_fleet_configuration, get_volume_provisioning_data, ) @@ -365,6 +368,108 @@ async def test_assignes_job_to_instance(self, test_db, session: AsyncSession): job.instance_assigned and job.instance is not None and job.instance.id == instance.id ) + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_assigns_job_to_imported_fleet(self, test_db, session: AsyncSession): + exporter_user = await create_user( + session, name="exporter-user", global_role=GlobalRole.USER + ) + importer_user = await create_user( + session, name="importer_user", global_role=GlobalRole.USER + ) + exporter_project = await create_project( + session, name="exporter-project", owner=exporter_user + ) + importer_project = await create_project( + session, name="importer-project", owner=importer_user + ) + repo = await create_repo(session=session, project_id=importer_project.id) + fleet = await create_fleet( + session=session, + project=exporter_project, + spec=get_fleet_spec(get_ssh_fleet_configuration()), + ) + instance = await create_instance( + session=session, + project=exporter_project, + fleet=fleet, + status=InstanceStatus.IDLE, + ) + run = await create_run( + session=session, + project=importer_project, + repo=repo, + user=importer_user, + ) + job = await create_job( + session=session, + run=run, + instance_assigned=False, + ) + await create_resource_export( + session=session, + exporter_project=exporter_project, + importer_projects=[importer_project], + exported_fleets=[fleet], + ) + await process_submitted_jobs() + await session.refresh(job) + res = await session.execute(select(JobModel).options(joinedload(JobModel.instance))) + job = res.unique().scalar_one() + assert job.status == JobStatus.SUBMITTED + assert ( + job.instance_assigned and job.instance is not None and job.instance.id == instance.id + ) + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_not_assigns_job_to_foreign_fleet_if_not_imported( + self, test_db, session: AsyncSession + ): + exporter_user = await create_user( + session, name="exporter-user", global_role=GlobalRole.USER + ) + importer_user = await create_user( + session, name="importer-user", global_role=GlobalRole.USER + ) + exporter_project = await create_project( + session, name="exporter-project", owner=exporter_user + ) + importer_project = await create_project( + session, name="importer-project", owner=importer_user + ) + repo = await create_repo(session=session, project_id=importer_project.id) + fleet = await create_fleet( + session=session, + project=exporter_project, + spec=get_fleet_spec(get_ssh_fleet_configuration()), + ) + await create_instance( + session=session, + project=exporter_project, + fleet=fleet, + status=InstanceStatus.IDLE, + ) + run = await create_run( + session=session, + project=importer_project, + repo=repo, + user=importer_user, + ) + job = await create_job( + session=session, + run=run, + instance_assigned=False, + ) + await process_submitted_jobs() + await session.refresh(job) + res = await session.execute(select(JobModel).options(joinedload(JobModel.instance))) + job = res.unique().scalar_one() + assert job.status == JobStatus.TERMINATING + assert job.termination_reason == JobTerminationReason.FAILED_TO_START_DUE_TO_NO_CAPACITY + assert not job.instance_assigned + assert job.instance is None + @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) async def test_does_no_reuse_unavailable_instances(self, test_db, session: AsyncSession): diff --git a/src/tests/_internal/server/routers/test_events.py b/src/tests/_internal/server/routers/test_events.py index f31c082d06..e8d43d8f1f 100644 --- a/src/tests/_internal/server/routers/test_events.py +++ b/src/tests/_internal/server/routers/test_events.py @@ -3,6 +3,7 @@ from unittest.mock import patch import pytest +import pytest_asyncio from freezegun import freeze_time from httpx import AsyncClient from sqlalchemy.ext.asyncio import AsyncSession @@ -16,9 +17,12 @@ create_job, create_project, create_repo, + create_resource_export, create_run, create_user, get_auth_headers, + get_fleet_spec, + get_ssh_fleet_configuration, ) pytestmark = [ @@ -1326,3 +1330,227 @@ async def test_limits_events_regardless_number_of_targets( ) resp.raise_for_status() assert len(resp.json()) == 2 + + +class TestListEventsWithExportedFleet: + @pytest_asyncio.fixture + async def exported_fleet_setup(self, session: AsyncSession): + # Create exporter user and project + exporter_user = await create_user( + session, name="exporter-user", global_role=GlobalRole.USER + ) + exporter_project = await create_project( + session, name="exporter-project", owner=exporter_user + ) + await add_project_member( + session=session, + project=exporter_project, + user=exporter_user, + project_role=ProjectRole.USER, + ) + + # Create first importer user and project + importer_user_1 = await create_user( + session, name="importer-user-1", global_role=GlobalRole.USER + ) + importer_project_1 = await create_project( + session, name="importer-project-1", owner=importer_user_1 + ) + await add_project_member( + session=session, + project=importer_project_1, + user=importer_user_1, + project_role=ProjectRole.USER, + ) + + # Create second importer user and project + importer_user_2 = await create_user( + session, name="importer-user-2", global_role=GlobalRole.USER + ) + importer_project_2 = await create_project( + session, name="importer-project-2", owner=importer_user_2 + ) + await add_project_member( + session=session, + project=importer_project_2, + user=importer_user_2, + project_role=ProjectRole.USER, + ) + + # Create fleet and instance + fleet = await create_fleet( + session=session, + project=exporter_project, + spec=get_fleet_spec(get_ssh_fleet_configuration(name="exported-fleet")), + ) + events.emit( + session=session, + message="Fleet created", + actor=events.UserActor.from_user(exporter_user), + targets=[events.Target.from_model(fleet)], + ) + instance = await create_instance( + session=session, project=exporter_project, fleet=fleet, name="exported-fleet-0" + ) + events.emit( + session=session, + message="Instance created", + actor=events.SystemActor(), + targets=[events.Target.from_model(instance)], + ) + + # Create resource export + await create_resource_export( + session=session, + exporter_project=exporter_project, + importer_projects=[importer_project_1, importer_project_2], + exported_fleets=[fleet], + ) + + # Create first importer run and job + importer_run_1 = await create_run( + session=session, + project=importer_project_1, + user=importer_user_1, + repo=await create_repo(session=session, project_id=importer_project_1.id), + run_name="importer-run-1", + ) + events.emit( + session=session, + message="Run created", + actor=events.UserActor.from_user(importer_user_1), + targets=[events.Target.from_model(importer_run_1)], + ) + importer_job_1 = await create_job( + session=session, + run=importer_run_1, + fleet=fleet, + instance=instance, + ) + events.emit( + session=session, + message="Job assigned to instance", + actor=events.SystemActor(), + targets=[events.Target.from_model(importer_job_1), events.Target.from_model(instance)], + ) + + # Create second importer run and job + importer_run_2 = await create_run( + session=session, + project=importer_project_2, + user=importer_user_2, + repo=await create_repo(session=session, project_id=importer_project_2.id), + run_name="importer-run-2", + ) + events.emit( + session=session, + message="Run created", + actor=events.UserActor.from_user(importer_user_2), + targets=[events.Target.from_model(importer_run_2)], + ) + importer_job_2 = await create_job( + session=session, + run=importer_run_2, + fleet=fleet, + instance=instance, + ) + events.emit( + session=session, + message="Job assigned to instance", + actor=events.SystemActor(), + targets=[events.Target.from_model(importer_job_2), events.Target.from_model(instance)], + ) + + await session.commit() + + return { + "exporter_user": exporter_user, + "importer_user_1": importer_user_1, + "importer_user_2": importer_user_2, + "exported_fleet": fleet, + } + + @pytest.mark.parametrize("with_filter", [True, False]) + async def test_exporter_user_sees_all_events_targeting_exported_fleet( + self, + session: AsyncSession, + client: AsyncClient, + exported_fleet_setup: dict, + with_filter: bool, + ) -> None: + filters = {} + if with_filter: + filters = {"within_fleets": [str(exported_fleet_setup["exported_fleet"].id)]} + resp = await client.post( + "/api/events/list", + headers=get_auth_headers(exported_fleet_setup["exporter_user"].token), + json={"ascending": True, **filters}, + ) + resp.raise_for_status() + assert resp.json()[0]["message"] == "Fleet created" + assert resp.json()[1]["message"] == "Instance created" + assert resp.json()[2]["message"] == "Job assigned to instance" + assert {t["name"] for t in resp.json()[2]["targets"]} == { + "exported-fleet-0", + "importer-run-1-0-0", + } + assert resp.json()[3]["message"] == "Job assigned to instance" + assert {t["name"] for t in resp.json()[3]["targets"]} == { + "exported-fleet-0", + "importer-run-2-0-0", + } + assert len(resp.json()) == 4 + + @pytest.mark.parametrize( + ("user_key", "job_name"), + [ + ("importer_user_1", "importer-run-1-0-0"), + ("importer_user_2", "importer-run-2-0-0"), + ], + ) + async def test_importer_user_sees_only_events_about_their_own_run( + self, + session: AsyncSession, + client: AsyncClient, + exported_fleet_setup: dict, + user_key: str, + job_name: str, + ) -> None: + resp = await client.post( + "/api/events/list", + headers=get_auth_headers(exported_fleet_setup[user_key].token), + json={"ascending": True}, + ) + resp.raise_for_status() + assert resp.json()[0]["message"] == "Run created" + assert resp.json()[1]["message"] == "Job assigned to instance" + assert {t["name"] for t in resp.json()[1]["targets"]} == {"exported-fleet-0", job_name} + assert len(resp.json()) == 2 + + @pytest.mark.parametrize( + ("user_key", "job_name"), + [ + ("importer_user_1", "importer-run-1-0-0"), + ("importer_user_2", "importer-run-2-0-0"), + ], + ) + async def test_importer_user_can_filter_by_imported_fleet( + self, + session: AsyncSession, + client: AsyncClient, + exported_fleet_setup: dict, + user_key: str, + job_name: str, + ) -> None: + resp = await client.post( + "/api/events/list", + headers=get_auth_headers(exported_fleet_setup[user_key].token), + json={ + "ascending": True, + "within_fleets": [str(exported_fleet_setup["exported_fleet"].id)], + }, + ) + resp.raise_for_status() + assert resp.json()[0]["message"] == "Job assigned to instance" + assert {t["name"] for t in resp.json()[0]["targets"]} == {"exported-fleet-0", job_name} + assert len(resp.json()) == 1 diff --git a/src/tests/_internal/server/routers/test_fleets.py b/src/tests/_internal/server/routers/test_fleets.py index fed647d2c8..b8cf097e1b 100644 --- a/src/tests/_internal/server/routers/test_fleets.py +++ b/src/tests/_internal/server/routers/test_fleets.py @@ -36,6 +36,7 @@ create_job, create_project, create_repo, + create_resource_export, create_run, create_user, default_permissions_context, @@ -141,6 +142,174 @@ async def test_non_admin_cannot_see_others_projects( assert len(response_json) == 1 assert response_json[0]["project_name"] == "project1" + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + @pytest.mark.parametrize("with_project_name_filter", [True, False]) + async def test_returns_imported_fleet_with_include_imported( + self, test_db, session: AsyncSession, client: AsyncClient, with_project_name_filter: bool + ): + importer_user = await create_user( + session, name="importer-user", global_role=GlobalRole.USER + ) + exporter_project = await create_project(session, name="exporter-project") + importer_project = await create_project( + session, name="importer-project", owner=importer_user + ) + await add_project_member( + session=session, + project=importer_project, + user=importer_user, + project_role=ProjectRole.ADMIN, + ) + fleet = await create_fleet( + session=session, + project=exporter_project, + spec=get_fleet_spec(get_ssh_fleet_configuration(name="exported-fleet")), + ) + instance = await create_instance( + session=session, + project=exporter_project, + fleet=fleet, + ) + await create_resource_export( + session=session, + exporter_project=exporter_project, + importer_projects=[importer_project], + exported_fleets=[fleet], + ) + await create_fleet( + session=session, + project=importer_project, + spec=get_fleet_spec(get_ssh_fleet_configuration(name="local-fleet")), + ) + response = await client.post( + "/api/fleets/list", + headers=get_auth_headers(importer_user.token), + json={ + "include_imported": True, + "project_name": "importer-project" if with_project_name_filter else None, + }, + ) + assert response.status_code == 200 + response_json = response.json() + response_json.sort(key=lambda f: f["name"]) + assert len(response_json) == 2 + assert response_json[0]["name"] == "exported-fleet" + assert response_json[0]["project_name"] == "exporter-project" + assert len(response_json[0]["instances"]) == 1 + assert response_json[0]["instances"][0]["id"] == str(instance.id) + assert response_json[1]["name"] == "local-fleet" + assert response_json[1]["project_name"] == "importer-project" + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_not_returns_imported_fleet_without_include_imported( + self, test_db, session: AsyncSession, client: AsyncClient + ): + importer_user = await create_user( + session, name="importer-user", global_role=GlobalRole.USER + ) + exporter_project = await create_project(session, name="exporter-project") + importer_project = await create_project( + session, name="importer-project", owner=importer_user + ) + await add_project_member( + session=session, + project=importer_project, + user=importer_user, + project_role=ProjectRole.ADMIN, + ) + fleet = await create_fleet( + session=session, + project=exporter_project, + spec=get_fleet_spec(get_ssh_fleet_configuration(name="exported-fleet")), + ) + await create_resource_export( + session=session, + exporter_project=exporter_project, + importer_projects=[importer_project], + exported_fleets=[fleet], + ) + await create_fleet( + session=session, + project=importer_project, + spec=get_fleet_spec(get_ssh_fleet_configuration(name="local-fleet")), + ) + response = await client.post( + "/api/fleets/list", + headers=get_auth_headers(importer_user.token), + json={}, + ) + assert response.status_code == 200 + response_json = response.json() + assert len(response_json) == 1 + assert response_json[0]["name"] == "local-fleet" + assert response_json[0]["project_name"] == "importer-project" + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_returns_imported_fleet_once_when_user_member_of_both_projects( + self, test_db, session: AsyncSession, client: AsyncClient + ): + user = await create_user(session, name="user", global_role=GlobalRole.USER) + exporter_project = await create_project(session, name="exporter-project", owner=user) + importer_project = await create_project(session, name="importer-project", owner=user) + await add_project_member( + session=session, + project=exporter_project, + user=user, + project_role=ProjectRole.USER, + ) + await add_project_member( + session=session, + project=importer_project, + user=user, + project_role=ProjectRole.USER, + ) + fleet = await create_fleet( + session=session, + project=exporter_project, + spec=get_fleet_spec(get_ssh_fleet_configuration(name="shared-fleet")), + ) + instance = await create_instance( + session=session, + project=exporter_project, + fleet=fleet, + ) + await create_resource_export( + session=session, + exporter_project=exporter_project, + importer_projects=[importer_project], + exported_fleets=[fleet], + ) + await create_fleet( + session=session, + project=exporter_project, + spec=get_fleet_spec(get_ssh_fleet_configuration(name="local-exporter-fleet")), + ) + await create_fleet( + session=session, + project=importer_project, + spec=get_fleet_spec(get_ssh_fleet_configuration(name="local-importer-fleet")), + ) + response = await client.post( + "/api/fleets/list", + headers=get_auth_headers(user.token), + json={"include_imported": True}, + ) + assert response.status_code == 200 + response_json = response.json() + response_json.sort(key=lambda f: f["name"]) + assert len(response_json) == 3 + assert response_json[0]["name"] == "local-exporter-fleet" + assert response_json[0]["project_name"] == "exporter-project" + assert response_json[1]["name"] == "local-importer-fleet" + assert response_json[1]["project_name"] == "importer-project" + assert response_json[2]["name"] == "shared-fleet" + assert response_json[2]["project_name"] == "exporter-project" + assert len(response_json[2]["instances"]) == 1 + assert response_json[2]["instances"][0]["id"] == str(instance.id) + class TestListProjectFleets: @pytest.mark.asyncio @@ -182,6 +351,106 @@ async def test_lists_fleets(self, test_db, session: AsyncSession, client: AsyncC } ] + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_returns_imported_fleet_with_include_imported( + self, test_db, session: AsyncSession, client: AsyncClient + ): + importer_user = await create_user( + session, name="importer-user", global_role=GlobalRole.USER + ) + exporter_project = await create_project(session, name="exporter-project") + importer_project = await create_project( + session, name="importer-project", owner=importer_user + ) + await add_project_member( + session=session, + project=importer_project, + user=importer_user, + project_role=ProjectRole.ADMIN, + ) + fleet = await create_fleet( + session=session, + project=exporter_project, + spec=get_fleet_spec(get_ssh_fleet_configuration(name="exported-fleet")), + ) + instance = await create_instance( + session=session, + project=exporter_project, + fleet=fleet, + ) + await create_resource_export( + session=session, + exporter_project=exporter_project, + importer_projects=[importer_project], + exported_fleets=[fleet], + ) + await create_fleet( + session=session, + project=importer_project, + spec=get_fleet_spec(get_ssh_fleet_configuration(name="local-fleet")), + ) + response = await client.post( + f"/api/project/{importer_project.name}/fleets/list", + headers=get_auth_headers(importer_user.token), + json={"include_imported": True}, + ) + assert response.status_code == 200 + response_json = response.json() + response_json.sort(key=lambda f: f["name"]) + assert len(response_json) == 2 + assert response_json[0]["name"] == "exported-fleet" + assert response_json[0]["project_name"] == "exporter-project" + assert len(response_json[0]["instances"]) == 1 + assert response_json[0]["instances"][0]["id"] == str(instance.id) + assert response_json[1]["name"] == "local-fleet" + assert response_json[1]["project_name"] == "importer-project" + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_not_returns_imported_fleet_without_include_imported( + self, test_db, session: AsyncSession, client: AsyncClient + ): + importer_user = await create_user( + session, name="importer-user", global_role=GlobalRole.USER + ) + exporter_project = await create_project(session, name="exporter-project") + importer_project = await create_project( + session, name="importer-project", owner=importer_user + ) + await add_project_member( + session=session, + project=importer_project, + user=importer_user, + project_role=ProjectRole.ADMIN, + ) + fleet = await create_fleet( + session=session, + project=exporter_project, + spec=get_fleet_spec(get_ssh_fleet_configuration(name="exported-fleet")), + ) + await create_resource_export( + session=session, + exporter_project=exporter_project, + importer_projects=[importer_project], + exported_fleets=[fleet], + ) + await create_fleet( + session=session, + project=importer_project, + spec=get_fleet_spec(get_ssh_fleet_configuration(name="local-fleet")), + ) + response = await client.post( + f"/api/project/{importer_project.name}/fleets/list", + headers=get_auth_headers(importer_user.token), + json={}, # No include_imported parameter + ) + assert response.status_code == 200 + response_json = response.json() + assert len(response_json) == 1 + assert response_json[0]["name"] == "local-fleet" + assert response_json[0]["project_name"] == "importer-project" + class TestGetFleet: @pytest.mark.asyncio @@ -371,6 +640,115 @@ async def test_returns_foreign_fleet_to_global_admin( assert response.status_code == 200 assert response.json()["name"] == "test-fleet" + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + @pytest.mark.parametrize( + "by_id", [pytest.param(False, id="by-name"), pytest.param(False, id="by-id")] + ) + async def test_returns_imported_fleet( + self, test_db, session: AsyncSession, client: AsyncClient, by_id: bool + ): + importer_user = await create_user( + session, name="importer-user", global_role=GlobalRole.USER + ) + exporter_project = await create_project(session, name="exporter-project") + importer_project = await create_project( + session, name="importer-project", owner=importer_user + ) + await add_project_member( + session=session, + project=importer_project, + user=importer_user, + project_role=ProjectRole.ADMIN, + ) + fleet = await create_fleet( + session=session, + project=exporter_project, + spec=get_fleet_spec(get_ssh_fleet_configuration(name="exported-fleet")), + ) + instance = await create_instance( + session=session, + project=exporter_project, + fleet=fleet, + ) + await create_resource_export( + session=session, + exporter_project=exporter_project, + importer_projects=[importer_project], + exported_fleets=[fleet], + ) + if by_id: + body = {"id": str(fleet.id)} + else: + body = {"name": "exported-fleet"} + response = await client.post( + "/api/project/exporter-project/fleets/get", + headers=get_auth_headers(importer_user.token), + json=body, + ) + assert response.status_code == 200 + assert response.json()["id"] == str(fleet.id) + assert response.json()["name"] == "exported-fleet" + assert response.json()["project_name"] == "exporter-project" + assert len(response.json()["instances"]) == 1 + assert response.json()["instances"][0]["id"] == str(instance.id) + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + @pytest.mark.parametrize( + "by_id", [pytest.param(False, id="by-name"), pytest.param(False, id="by-id")] + ) + async def test_returns_403_on_foreign_fleet_if_not_imported( + self, test_db, session: AsyncSession, client: AsyncClient, by_id: bool + ): + importer_user = await create_user( + session, name="importer-user", global_role=GlobalRole.USER + ) + not_importer_user = await create_user( + session, name="not-importer-user", global_role=GlobalRole.USER + ) + exporter_project = await create_project( + session, name="exporter-project", owner=importer_user + ) + importer_project = await create_project( + session, name="importer-project", owner=importer_user + ) + not_importer_project = await create_project( + session, name="not-importer-project", owner=not_importer_user + ) + await add_project_member( + session=session, + project=not_importer_project, + user=not_importer_user, + project_role=ProjectRole.USER, + ) + fleet = await create_fleet( + session=session, + project=exporter_project, + spec=get_fleet_spec(get_ssh_fleet_configuration(name="exported-fleet")), + ) + await create_instance( + session=session, + project=exporter_project, + fleet=fleet, + ) + await create_resource_export( + session=session, + exporter_project=exporter_project, + importer_projects=[importer_project], + exported_fleets=[fleet], + ) + if by_id: + body = {"id": str(fleet.id)} + else: + body = {"name": "exported-fleet"} + response = await client.post( + "/api/project/exporter-project/fleets/get", + headers=get_auth_headers(not_importer_user.token), + json=body, + ) + assert response.status_code == 403 + class TestApplyFleetPlan: @pytest.mark.asyncio @@ -918,6 +1296,43 @@ async def test_forbids_if_no_permission_to_manage_ssh_fleets( ) assert response.status_code in [401, 403] + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_importer_member_cannot_apply_plan_on_imported_fleet( + self, test_db, session: AsyncSession, client: AsyncClient + ): + importer_user = await create_user( + session, name="importer-user", global_role=GlobalRole.USER + ) + exporter_project = await create_project(session, name="exporter-project") + importer_project = await create_project( + session, name="importer-project", owner=importer_user + ) + await add_project_member( + session=session, + project=importer_project, + user=importer_user, + project_role=ProjectRole.ADMIN, + ) + spec = get_fleet_spec(get_ssh_fleet_configuration(name="exported-fleet")) + fleet = await create_fleet( + session=session, + project=exporter_project, + spec=spec, + ) + await create_resource_export( + session=session, + exporter_project=exporter_project, + importer_projects=[importer_project], + exported_fleets=[fleet], + ) + response = await client.post( + f"/api/project/{exporter_project.name}/fleets/apply", + headers=get_auth_headers(importer_user.token), + json={"plan": {"spec": spec.dict()}, "force": False}, + ) + assert response.status_code == 403 + class TestDeleteFleets: @pytest.mark.asyncio @@ -1062,6 +1477,42 @@ async def test_forbids_if_no_permission_to_manage_ssh_fleets( ) assert response.status_code in [401, 403] + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_importer_member_cannot_delete_imported_fleet( + self, test_db, session: AsyncSession, client: AsyncClient + ): + importer_user = await create_user( + session, name="importer-user", global_role=GlobalRole.USER + ) + exporter_project = await create_project(session, name="exporter-project") + importer_project = await create_project( + session, name="importer-project", owner=importer_user + ) + await add_project_member( + session=session, + project=importer_project, + user=importer_user, + project_role=ProjectRole.ADMIN, + ) + fleet = await create_fleet( + session=session, + project=exporter_project, + spec=get_fleet_spec(get_ssh_fleet_configuration(name="exported-fleet")), + ) + await create_resource_export( + session=session, + exporter_project=exporter_project, + importer_projects=[importer_project], + exported_fleets=[fleet], + ) + response = await client.post( + f"/api/project/{exporter_project.name}/fleets/delete", + headers=get_auth_headers(importer_user.token), + json={"names": [fleet.name]}, + ) + assert response.status_code == 403 + class TestDeleteFleetInstances: @pytest.mark.asyncio @@ -1188,6 +1639,48 @@ async def test_returns_400_when_fleet_locked( assert fleet.status != FleetStatus.TERMINATING assert instance.status != InstanceStatus.TERMINATING + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_importer_member_cannot_delete_imported_fleet_instances( + self, test_db, session: AsyncSession, client: AsyncClient + ): + importer_user = await create_user( + session, name="importer-user", global_role=GlobalRole.USER + ) + exporter_project = await create_project(session, name="exporter-project") + importer_project = await create_project( + session, name="importer-project", owner=importer_user + ) + await add_project_member( + session=session, + project=importer_project, + user=importer_user, + project_role=ProjectRole.ADMIN, + ) + fleet = await create_fleet( + session=session, + project=exporter_project, + spec=get_fleet_spec(get_ssh_fleet_configuration(name="exported-fleet")), + ) + await create_instance( + session=session, + project=exporter_project, + fleet=fleet, + instance_num=1, + ) + await create_resource_export( + session=session, + exporter_project=exporter_project, + importer_projects=[importer_project], + exported_fleets=[fleet], + ) + response = await client.post( + f"/api/project/{exporter_project.name}/fleets/delete_instances", + headers=get_auth_headers(importer_user.token), + json={"name": fleet.name, "instance_nums": [1]}, + ) + assert response.status_code == 403 + class TestGetPlan: @pytest.mark.asyncio @@ -1384,6 +1877,39 @@ async def test_replaces_no_balance_with_not_available_for_old_clients( assert offers[0]["availability"] == InstanceAvailability.AVAILABLE.value assert offers[1]["availability"] == expected_availability.value + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_importer_member_cannot_get_plan_for_imported_fleet( + self, test_db, session: AsyncSession, client: AsyncClient + ): + importer_user = await create_user( + session, name="importer-user", global_role=GlobalRole.USER + ) + exporter_project = await create_project(session, name="exporter-project") + importer_project = await create_project( + session, name="importer-project", owner=importer_user + ) + await add_project_member( + session=session, + project=importer_project, + user=importer_user, + project_role=ProjectRole.ADMIN, + ) + spec = get_fleet_spec(get_ssh_fleet_configuration(name="exported-fleet")) + fleet = await create_fleet(session=session, project=exporter_project, spec=spec) + await create_resource_export( + session=session, + exporter_project=exporter_project, + importer_projects=[importer_project], + exported_fleets=[fleet], + ) + response = await client.post( + f"/api/project/{exporter_project.name}/fleets/get_plan", + headers=get_auth_headers(importer_user.token), + json={"spec": spec.dict()}, + ) + assert response.status_code == 403 + def _fleet_model_to_json_dict(fleet: FleetModel) -> dict: return json.loads(fleet_model_to_fleet(fleet).json()) diff --git a/src/tests/_internal/server/routers/test_instances.py b/src/tests/_internal/server/routers/test_instances.py index 45363bfd92..424457824b 100644 --- a/src/tests/_internal/server/routers/test_instances.py +++ b/src/tests/_internal/server/routers/test_instances.py @@ -18,10 +18,12 @@ create_instance, create_instance_health_check, create_project, + create_resource_export, create_user, get_auth_headers, get_fleet_configuration, get_fleet_spec, + get_ssh_fleet_configuration, ) @@ -268,6 +270,193 @@ async def test_not_authenticated(self, client: AsyncClient, data) -> None: resp = await client.post("/api/instances/list", json={}) assert resp.status_code in [401, 403] + @pytest.mark.parametrize("with_project_name_filter", [True, False]) + async def test_returns_imported_instances_with_include_imported( + self, session: AsyncSession, client: AsyncClient, with_project_name_filter: bool + ): + importer_user = await create_user( + session, name="importer-user", global_role=GlobalRole.USER + ) + exporter_project = await create_project(session, name="exporter-project") + importer_project = await create_project( + session, name="importer-project", owner=importer_user + ) + await add_project_member( + session=session, + project=importer_project, + user=importer_user, + project_role=ProjectRole.USER, + ) + fleet = await create_fleet( + session=session, + project=exporter_project, + spec=get_fleet_spec(get_ssh_fleet_configuration(name="exported-fleet")), + ) + await create_instance( + session=session, project=exporter_project, fleet=fleet, name="exported-fleet-0" + ) + await create_resource_export( + session=session, + exporter_project=exporter_project, + importer_projects=[importer_project], + exported_fleets=[fleet], + ) + local_fleet = await create_fleet( + session=session, + project=importer_project, + spec=get_fleet_spec(get_ssh_fleet_configuration(name="local-fleet")), + ) + await create_instance( + session=session, project=importer_project, fleet=local_fleet, name="local-fleet-0" + ) + response = await client.post( + "/api/instances/list", + headers=get_auth_headers(importer_user.token), + json={ + "include_imported": True, + "project_names": ["importer-project"] if with_project_name_filter else None, + }, + ) + assert response.status_code == 200 + response_json = response.json() + response_json.sort(key=lambda i: i["name"]) + assert len(response_json) == 2 + assert response_json[0]["name"] == "exported-fleet-0" + assert response_json[0]["project_name"] == "exporter-project" + assert response_json[0]["fleet_name"] == "exported-fleet" + assert response_json[1]["name"] == "local-fleet-0" + assert response_json[1]["project_name"] == "importer-project" + assert response_json[1]["fleet_name"] == "local-fleet" + + async def test_not_returns_imported_instances_without_include_imported( + self, session: AsyncSession, client: AsyncClient + ): + importer_user = await create_user( + session, name="importer-user", global_role=GlobalRole.USER + ) + exporter_project = await create_project(session, name="exporter-project") + importer_project = await create_project( + session, name="importer-project", owner=importer_user + ) + await add_project_member( + session=session, + project=importer_project, + user=importer_user, + project_role=ProjectRole.USER, + ) + fleet = await create_fleet( + session=session, + project=exporter_project, + spec=get_fleet_spec(get_ssh_fleet_configuration(name="exported-fleet")), + ) + await create_instance( + session=session, + project=exporter_project, + fleet=fleet, + name="exported-fleet-0", + ) + await create_resource_export( + session=session, + exporter_project=exporter_project, + importer_projects=[importer_project], + exported_fleets=[fleet], + ) + local_fleet = await create_fleet( + session=session, + project=importer_project, + spec=get_fleet_spec(get_ssh_fleet_configuration(name="local-fleet")), + ) + await create_instance( + session=session, project=importer_project, fleet=local_fleet, name="local-fleet-0" + ) + response = await client.post( + "/api/instances/list", + headers=get_auth_headers(importer_user.token), + json={}, # No include_imported + ) + assert response.status_code == 200 + response_json = response.json() + assert len(response_json) == 1 + assert response_json[0]["name"] == "local-fleet-0" + assert response_json[0]["project_name"] == "importer-project" + assert response_json[0]["fleet_name"] == "local-fleet" + + async def test_returns_imported_instances_once_when_user_member_of_both_projects( + self, session: AsyncSession, client: AsyncClient + ): + user = await create_user(session, name="user", global_role=GlobalRole.USER) + exporter_project = await create_project(session, name="exporter-project", owner=user) + importer_project = await create_project(session, name="importer-project", owner=user) + await add_project_member( + session=session, + project=exporter_project, + user=user, + project_role=ProjectRole.USER, + ) + await add_project_member( + session=session, + project=importer_project, + user=user, + project_role=ProjectRole.USER, + ) + fleet = await create_fleet( + session=session, + project=exporter_project, + spec=get_fleet_spec(get_ssh_fleet_configuration(name="shared-fleet")), + ) + await create_instance( + session=session, + project=exporter_project, + fleet=fleet, + name="shared-fleet-0", + ) + await create_resource_export( + session=session, + exporter_project=exporter_project, + importer_projects=[importer_project], + exported_fleets=[fleet], + ) + local_exporter_fleet = await create_fleet( + session=session, + project=exporter_project, + spec=get_fleet_spec(get_ssh_fleet_configuration(name="local-exporter-fleet")), + ) + await create_instance( + session=session, + project=exporter_project, + fleet=local_exporter_fleet, + name="local-exported-fleet-0", + ) + local_importer_fleet = await create_fleet( + session=session, + project=importer_project, + spec=get_fleet_spec(get_ssh_fleet_configuration(name="local-importer-fleet")), + ) + await create_instance( + session=session, + project=importer_project, + fleet=local_importer_fleet, + name="local-importer-fleet-0", + ) + response = await client.post( + "/api/instances/list", + headers=get_auth_headers(user.token), + json={"include_imported": True}, + ) + assert response.status_code == 200 + response_json = response.json() + response_json.sort(key=lambda i: i["name"]) + assert len(response_json) == 3 + assert response_json[0]["name"] == "local-exported-fleet-0" + assert response_json[0]["project_name"] == "exporter-project" + assert response_json[0]["fleet_name"] == "local-exporter-fleet" + assert response_json[1]["name"] == "local-importer-fleet-0" + assert response_json[1]["project_name"] == "importer-project" + assert response_json[1]["fleet_name"] == "local-importer-fleet" + assert response_json[2]["name"] == "shared-fleet-0" + assert response_json[2]["project_name"] == "exporter-project" + assert response_json[2]["fleet_name"] == "shared-fleet" + @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) @@ -509,3 +698,93 @@ async def test_returns_403_if_not_project_member_and_instance_not_exists( json={"id": str(uuid.uuid4())}, ) assert resp.status_code == 403 + + async def test_returns_imported_instance( + self, test_db, session: AsyncSession, client: AsyncClient + ): + importer_user = await create_user( + session, name="importer-user", global_role=GlobalRole.USER + ) + exporter_project = await create_project(session, name="exporter-project") + importer_project = await create_project( + session, name="importer-project", owner=importer_user + ) + await add_project_member( + session=session, + project=importer_project, + user=importer_user, + project_role=ProjectRole.USER, + ) + fleet = await create_fleet( + session=session, + project=exporter_project, + spec=get_fleet_spec(get_ssh_fleet_configuration(name="exported-fleet")), + ) + instance = await create_instance( + session=session, + project=exporter_project, + fleet=fleet, + ) + await create_resource_export( + session=session, + exporter_project=exporter_project, + importer_projects=[importer_project], + exported_fleets=[fleet], + ) + response = await client.post( + "/api/project/exporter-project/instances/get", + headers=get_auth_headers(importer_user.token), + json={"id": str(instance.id)}, + ) + assert response.status_code == 200 + response_json = response.json() + assert response_json["id"] == str(instance.id) + assert response_json["project_name"] == "exporter-project" + assert response_json["fleet_name"] == "exported-fleet" + + async def test_returns_403_on_foreign_instance_if_not_imported( + self, test_db, session: AsyncSession, client: AsyncClient + ): + importer_user = await create_user( + session, name="importer-user", global_role=GlobalRole.USER + ) + not_importer_user = await create_user( + session, name="not-importer-user", global_role=GlobalRole.USER + ) + exporter_project = await create_project( + session, name="exporter-project", owner=importer_user + ) + importer_project = await create_project( + session, name="importer-project", owner=importer_user + ) + not_importer_project = await create_project( + session, name="not-importer-project", owner=not_importer_user + ) + await add_project_member( + session=session, + project=not_importer_project, + user=not_importer_user, + project_role=ProjectRole.USER, + ) + fleet = await create_fleet( + session=session, + project=exporter_project, + spec=get_fleet_spec(get_ssh_fleet_configuration(name="exported-fleet")), + ) + instance = await create_instance( + session=session, + project=exporter_project, + fleet=fleet, + ) + await create_resource_export( + session=session, + exporter_project=exporter_project, + importer_projects=[importer_project], + exported_fleets=[fleet], + ) + response = await client.post( + "/api/project/exporter-project/instances/get", + headers=get_auth_headers(not_importer_user.token), + json={"id": str(instance.id)}, + ) + assert response.status_code == 403 diff --git a/src/tests/_internal/server/routers/test_runs.py b/src/tests/_internal/server/routers/test_runs.py index 1f6b1ebf3e..c3a77b8bb6 100644 --- a/src/tests/_internal/server/routers/test_runs.py +++ b/src/tests/_internal/server/routers/test_runs.py @@ -64,12 +64,14 @@ create_job, create_project, create_repo, + create_resource_export, create_run, create_user, get_auth_headers, get_fleet_spec, get_job_provisioning_data, get_run_spec, + get_ssh_fleet_configuration, list_events, ) from dstack._internal.server.testing.matchers import SomeUUID4Str @@ -1384,6 +1386,56 @@ async def test_returns_run_plan_instance_volumes( assert response.status_code == 200, response.json() assert response.json() == run_plan_dict + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_returns_run_plan_with_offer_from_imported_fleet( + self, + test_db, + session: AsyncSession, + client: AsyncClient, + ) -> None: + importer_user = await create_user(session, global_role=GlobalRole.USER) + exporter_project = await create_project(session, name="exporter-project") + importer_project = await create_project( + session, name="importer-project", owner=importer_user + ) + await add_project_member( + session=session, + project=importer_project, + user=importer_user, + project_role=ProjectRole.USER, + ) + fleet = await create_fleet( + session=session, + project=exporter_project, + spec=get_fleet_spec(get_ssh_fleet_configuration()), + ) + await create_instance( + session=session, + project=exporter_project, + fleet=fleet, + instance_num=1, + backend=BackendType.REMOTE, + ) + await create_resource_export( + session=session, + exporter_project=exporter_project, + importer_projects=[importer_project], + exported_fleets=[fleet], + ) + + run_spec = {"configuration": {"type": "dev-environment", "ide": "vscode"}} + body = {"run_spec": run_spec} + response = await client.post( + "/api/project/importer-project/runs/get_plan", + headers=get_auth_headers(importer_user.token), + json=body, + ) + assert response.status_code == 200, response.json() + response_json = response.json() + assert response_json["project_name"] == "importer-project" + assert response_json["job_plans"][0]["offers"][0]["backend"] == "remote" + @pytest.mark.parametrize( ("client_version", "expected_availability"), [ From 606e917c74e41f57b285dafa2066c56888288186 Mon Sep 17 00:00:00 2001 From: Jvst Me Date: Wed, 4 Mar 2026 11:11:24 +0100 Subject: [PATCH 2/4] [chore]: Use `project_ids` variable --- src/dstack/_internal/server/services/instances.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dstack/_internal/server/services/instances.py b/src/dstack/_internal/server/services/instances.py index f58f8705ff..9ee18ec81d 100644 --- a/src/dstack/_internal/server/services/instances.py +++ b/src/dstack/_internal/server/services/instances.py @@ -532,7 +532,7 @@ async def list_projects_instance_models( ) filters: List = [ or_( - InstanceModel.project_id.in_(p.id for p in projects), + InstanceModel.project_id.in_(project_ids), is_instance_imported_subquery if include_imported else false(), ) ] From eb236cca559ead1cd330e8956ace4f5ddd4850fc Mon Sep 17 00:00:00 2001 From: Jvst Me Date: Wed, 4 Mar 2026 22:54:54 +0100 Subject: [PATCH 3/4] Add tests for fleets imported twice --- .../_internal/server/routers/test_fleets.py | 98 +++++++++++++++++++ .../server/routers/test_instances.py | 47 +++++++++ 2 files changed, 145 insertions(+) diff --git a/src/tests/_internal/server/routers/test_fleets.py b/src/tests/_internal/server/routers/test_fleets.py index b8cf097e1b..e8a5b0c96c 100644 --- a/src/tests/_internal/server/routers/test_fleets.py +++ b/src/tests/_internal/server/routers/test_fleets.py @@ -310,6 +310,55 @@ async def test_returns_imported_fleet_once_when_user_member_of_both_projects( assert len(response_json[2]["instances"]) == 1 assert response_json[2]["instances"][0]["id"] == str(instance.id) + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_returns_fleet_once_if_imported_twice( + self, test_db, session: AsyncSession, client: AsyncClient + ): + importer_user = await create_user( + session, name="importer-user", global_role=GlobalRole.USER + ) + exporter_project = await create_project(session, name="exporter-project") + importer_project = await create_project( + session, name="importer-project", owner=importer_user + ) + await add_project_member( + session=session, + project=importer_project, + user=importer_user, + project_role=ProjectRole.USER, + ) + fleet = await create_fleet( + session=session, + project=exporter_project, + spec=get_fleet_spec(get_ssh_fleet_configuration(name="exported-fleet")), + ) + instance = await create_instance( + session=session, + project=exporter_project, + fleet=fleet, + ) + for name in ["export-1", "export-2"]: + await create_resource_export( + session=session, + exporter_project=exporter_project, + importer_projects=[importer_project], + exported_fleets=[fleet], + name=name, + ) + response = await client.post( + "/api/fleets/list", + headers=get_auth_headers(importer_user.token), + json={"include_imported": True}, + ) + assert response.status_code == 200 + response_json = response.json() + assert len(response_json) == 1 + assert response_json[0]["name"] == "exported-fleet" + assert response_json[0]["project_name"] == "exporter-project" + assert len(response_json[0]["instances"]) == 1 + assert response_json[0]["instances"][0]["id"] == str(instance.id) + class TestListProjectFleets: @pytest.mark.asyncio @@ -451,6 +500,55 @@ async def test_not_returns_imported_fleet_without_include_imported( assert response_json[0]["name"] == "local-fleet" assert response_json[0]["project_name"] == "importer-project" + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_returns_fleet_once_if_imported_twice( + self, test_db, session: AsyncSession, client: AsyncClient + ): + importer_user = await create_user( + session, name="importer-user", global_role=GlobalRole.USER + ) + exporter_project = await create_project(session, name="exporter-project") + importer_project = await create_project( + session, name="importer-project", owner=importer_user + ) + await add_project_member( + session=session, + project=importer_project, + user=importer_user, + project_role=ProjectRole.USER, + ) + fleet = await create_fleet( + session=session, + project=exporter_project, + spec=get_fleet_spec(get_ssh_fleet_configuration(name="exported-fleet")), + ) + instance = await create_instance( + session=session, + project=exporter_project, + fleet=fleet, + ) + for name in ["export-1", "export-2"]: + await create_resource_export( + session=session, + exporter_project=exporter_project, + importer_projects=[importer_project], + exported_fleets=[fleet], + name=name, + ) + response = await client.post( + f"/api/project/{importer_project.name}/fleets/list", + headers=get_auth_headers(importer_user.token), + json={"include_imported": True}, + ) + assert response.status_code == 200 + response_json = response.json() + assert len(response_json) == 1 + assert response_json[0]["name"] == "exported-fleet" + assert response_json[0]["project_name"] == "exporter-project" + assert len(response_json[0]["instances"]) == 1 + assert response_json[0]["instances"][0]["id"] == str(instance.id) + class TestGetFleet: @pytest.mark.asyncio diff --git a/src/tests/_internal/server/routers/test_instances.py b/src/tests/_internal/server/routers/test_instances.py index 424457824b..c66fd7eb27 100644 --- a/src/tests/_internal/server/routers/test_instances.py +++ b/src/tests/_internal/server/routers/test_instances.py @@ -457,6 +457,53 @@ async def test_returns_imported_instances_once_when_user_member_of_both_projects assert response_json[2]["project_name"] == "exporter-project" assert response_json[2]["fleet_name"] == "shared-fleet" + async def test_returns_instance_once_if_imported_twice( + self, session: AsyncSession, client: AsyncClient + ): + importer_user = await create_user( + session, name="importer-user", global_role=GlobalRole.USER + ) + exporter_project = await create_project(session, name="exporter-project") + importer_project = await create_project( + session, name="importer-project", owner=importer_user + ) + await add_project_member( + session=session, + project=importer_project, + user=importer_user, + project_role=ProjectRole.USER, + ) + fleet = await create_fleet( + session=session, + project=exporter_project, + spec=get_fleet_spec(get_ssh_fleet_configuration(name="exported-fleet")), + ) + await create_instance( + session=session, + project=exporter_project, + fleet=fleet, + name="exported-fleet-0", + ) + for name in ["export-1", "export-2"]: + await create_resource_export( + session=session, + exporter_project=exporter_project, + importer_projects=[importer_project], + exported_fleets=[fleet], + name=name, + ) + response = await client.post( + "/api/instances/list", + headers=get_auth_headers(importer_user.token), + json={"include_imported": True}, + ) + assert response.status_code == 200 + response_json = response.json() + assert len(response_json) == 1 + assert response_json[0]["name"] == "exported-fleet-0" + assert response_json[0]["project_name"] == "exporter-project" + assert response_json[0]["fleet_name"] == "exported-fleet" + @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) From 3c70ccc6b1af75894cc2ff45d5ac5d284c5f4066 Mon Sep 17 00:00:00 2001 From: Jvst Me Date: Wed, 4 Mar 2026 23:27:51 +0100 Subject: [PATCH 4/4] Rename to exports and imports (without "resource") --- ...7cca121fcb_add_resource_exports_imports.py | 142 ------------------ .../03_04_2221_5e8c7a9202bc_add_exports.py | 118 +++++++++++++++ src/dstack/_internal/server/models.py | 40 ++--- .../_internal/server/security/permissions.py | 10 +- .../_internal/server/services/fleets.py | 10 +- .../_internal/server/services/instances.py | 6 +- .../_internal/server/services/runs/plan.py | 6 +- src/dstack/_internal/server/testing/common.py | 18 +-- .../scheduled_tasks/test_submitted_jobs.py | 4 +- .../_internal/server/routers/test_events.py | 6 +- .../_internal/server/routers/test_fleets.py | 28 ++-- .../server/routers/test_instances.py | 14 +- .../_internal/server/routers/test_runs.py | 4 +- 13 files changed, 187 insertions(+), 219 deletions(-) delete mode 100644 src/dstack/_internal/server/migrations/versions/2026/03_02_1345_ea7cca121fcb_add_resource_exports_imports.py create mode 100644 src/dstack/_internal/server/migrations/versions/2026/03_04_2221_5e8c7a9202bc_add_exports.py diff --git a/src/dstack/_internal/server/migrations/versions/2026/03_02_1345_ea7cca121fcb_add_resource_exports_imports.py b/src/dstack/_internal/server/migrations/versions/2026/03_02_1345_ea7cca121fcb_add_resource_exports_imports.py deleted file mode 100644 index a6bd031335..0000000000 --- a/src/dstack/_internal/server/migrations/versions/2026/03_02_1345_ea7cca121fcb_add_resource_exports_imports.py +++ /dev/null @@ -1,142 +0,0 @@ -"""Add resource exports imports - -Revision ID: ea7cca121fcb -Revises: 46150101edec -Create Date: 2026-03-02 13:45:57.118841+00:00 - -""" - -import sqlalchemy as sa -import sqlalchemy_utils -from alembic import op - -import dstack._internal.server.models - -# revision identifiers, used by Alembic. -revision = "ea7cca121fcb" -down_revision = "46150101edec" -branch_labels = None -depends_on = None - - -def upgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.create_table( - "resource_exports", - sa.Column("id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=False), - sa.Column("name", sa.String(length=100), nullable=False), - sa.Column( - "project_id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=False - ), - sa.Column("created_at", dstack._internal.server.models.NaiveDateTime(), nullable=False), - sa.ForeignKeyConstraint( - ["project_id"], - ["projects.id"], - name=op.f("fk_resource_exports_project_id_projects"), - ondelete="CASCADE", - ), - sa.PrimaryKeyConstraint("id", name=op.f("pk_resource_exports")), - sa.UniqueConstraint("project_id", "name", name="uq_resource_exports_project_id_name"), - ) - with op.batch_alter_table("resource_exports", schema=None) as batch_op: - batch_op.create_index( - batch_op.f("ix_resource_exports_project_id"), ["project_id"], unique=False - ) - - op.create_table( - "exported_fleets", - sa.Column("id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=False), - sa.Column( - "resource_export_id", - sqlalchemy_utils.types.uuid.UUIDType(binary=False), - nullable=False, - ), - sa.Column("fleet_id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=False), - sa.ForeignKeyConstraint( - ["fleet_id"], - ["fleets.id"], - name=op.f("fk_exported_fleets_fleet_id_fleets"), - ondelete="CASCADE", - ), - sa.ForeignKeyConstraint( - ["resource_export_id"], - ["resource_exports.id"], - name=op.f("fk_exported_fleets_resource_export_id_resource_exports"), - ondelete="CASCADE", - ), - sa.PrimaryKeyConstraint("id", name=op.f("pk_exported_fleets")), - sa.UniqueConstraint( - "resource_export_id", "fleet_id", name="uq_exported_fleets_resource_export_id_fleet_id" - ), - ) - with op.batch_alter_table("exported_fleets", schema=None) as batch_op: - batch_op.create_index( - batch_op.f("ix_exported_fleets_fleet_id"), ["fleet_id"], unique=False - ) - batch_op.create_index( - batch_op.f("ix_exported_fleets_resource_export_id"), - ["resource_export_id"], - unique=False, - ) - - op.create_table( - "resource_imports", - sa.Column("id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=False), - sa.Column( - "project_id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=False - ), - sa.Column( - "resource_export_id", - sqlalchemy_utils.types.uuid.UUIDType(binary=False), - nullable=False, - ), - sa.Column("created_at", dstack._internal.server.models.NaiveDateTime(), nullable=False), - sa.ForeignKeyConstraint( - ["project_id"], - ["projects.id"], - name=op.f("fk_resource_imports_project_id_projects"), - ondelete="CASCADE", - ), - sa.ForeignKeyConstraint( - ["resource_export_id"], - ["resource_exports.id"], - name=op.f("fk_resource_imports_resource_export_id_resource_exports"), - ondelete="CASCADE", - ), - sa.PrimaryKeyConstraint("id", name=op.f("pk_resource_imports")), - sa.UniqueConstraint( - "project_id", - "resource_export_id", - name="uq_resource_imports_project_id_resource_export_id", - ), - ) - with op.batch_alter_table("resource_imports", schema=None) as batch_op: - batch_op.create_index( - batch_op.f("ix_resource_imports_project_id"), ["project_id"], unique=False - ) - batch_op.create_index( - batch_op.f("ix_resource_imports_resource_export_id"), - ["resource_export_id"], - unique=False, - ) - - # ### end Alembic commands ### - - -def downgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - with op.batch_alter_table("resource_imports", schema=None) as batch_op: - batch_op.drop_index(batch_op.f("ix_resource_imports_resource_export_id")) - batch_op.drop_index(batch_op.f("ix_resource_imports_project_id")) - - op.drop_table("resource_imports") - with op.batch_alter_table("exported_fleets", schema=None) as batch_op: - batch_op.drop_index(batch_op.f("ix_exported_fleets_resource_export_id")) - batch_op.drop_index(batch_op.f("ix_exported_fleets_fleet_id")) - - op.drop_table("exported_fleets") - with op.batch_alter_table("resource_exports", schema=None) as batch_op: - batch_op.drop_index(batch_op.f("ix_resource_exports_project_id")) - - op.drop_table("resource_exports") - # ### end Alembic commands ### diff --git a/src/dstack/_internal/server/migrations/versions/2026/03_04_2221_5e8c7a9202bc_add_exports.py b/src/dstack/_internal/server/migrations/versions/2026/03_04_2221_5e8c7a9202bc_add_exports.py new file mode 100644 index 0000000000..05a022f7ff --- /dev/null +++ b/src/dstack/_internal/server/migrations/versions/2026/03_04_2221_5e8c7a9202bc_add_exports.py @@ -0,0 +1,118 @@ +"""Add exports + +Revision ID: 5e8c7a9202bc +Revises: 46150101edec +Create Date: 2026-03-04 22:21:54.971260+00:00 + +""" + +import sqlalchemy as sa +import sqlalchemy_utils +from alembic import op + +import dstack._internal.server.models + +# revision identifiers, used by Alembic. +revision = "5e8c7a9202bc" +down_revision = "46150101edec" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "exports", + sa.Column("id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=False), + sa.Column("name", sa.String(length=100), nullable=False), + sa.Column( + "project_id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=False + ), + sa.Column("created_at", dstack._internal.server.models.NaiveDateTime(), nullable=False), + sa.ForeignKeyConstraint( + ["project_id"], + ["projects.id"], + name=op.f("fk_exports_project_id_projects"), + ondelete="CASCADE", + ), + sa.PrimaryKeyConstraint("id", name=op.f("pk_exports")), + sa.UniqueConstraint("project_id", "name", name="uq_exports_project_id_name"), + ) + with op.batch_alter_table("exports", schema=None) as batch_op: + batch_op.create_index(batch_op.f("ix_exports_project_id"), ["project_id"], unique=False) + + op.create_table( + "exported_fleets", + sa.Column("id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=False), + sa.Column("export_id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=False), + sa.Column("fleet_id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=False), + sa.ForeignKeyConstraint( + ["export_id"], + ["exports.id"], + name=op.f("fk_exported_fleets_export_id_exports"), + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["fleet_id"], + ["fleets.id"], + name=op.f("fk_exported_fleets_fleet_id_fleets"), + ondelete="CASCADE", + ), + sa.PrimaryKeyConstraint("id", name=op.f("pk_exported_fleets")), + sa.UniqueConstraint("export_id", "fleet_id", name="uq_exported_fleets_export_id_fleet_id"), + ) + with op.batch_alter_table("exported_fleets", schema=None) as batch_op: + batch_op.create_index( + batch_op.f("ix_exported_fleets_export_id"), ["export_id"], unique=False + ) + batch_op.create_index( + batch_op.f("ix_exported_fleets_fleet_id"), ["fleet_id"], unique=False + ) + + op.create_table( + "imports", + sa.Column("id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=False), + sa.Column( + "project_id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=False + ), + sa.Column("export_id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=False), + sa.Column("created_at", dstack._internal.server.models.NaiveDateTime(), nullable=False), + sa.ForeignKeyConstraint( + ["export_id"], + ["exports.id"], + name=op.f("fk_imports_export_id_exports"), + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["project_id"], + ["projects.id"], + name=op.f("fk_imports_project_id_projects"), + ondelete="CASCADE", + ), + sa.PrimaryKeyConstraint("id", name=op.f("pk_imports")), + sa.UniqueConstraint("project_id", "export_id", name="uq_imports_project_id_export_id"), + ) + with op.batch_alter_table("imports", schema=None) as batch_op: + batch_op.create_index(batch_op.f("ix_imports_export_id"), ["export_id"], unique=False) + batch_op.create_index(batch_op.f("ix_imports_project_id"), ["project_id"], unique=False) + + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("imports", schema=None) as batch_op: + batch_op.drop_index(batch_op.f("ix_imports_project_id")) + batch_op.drop_index(batch_op.f("ix_imports_export_id")) + + op.drop_table("imports") + with op.batch_alter_table("exported_fleets", schema=None) as batch_op: + batch_op.drop_index(batch_op.f("ix_exported_fleets_fleet_id")) + batch_op.drop_index(batch_op.f("ix_exported_fleets_export_id")) + + op.drop_table("exported_fleets") + with op.batch_alter_table("exports", schema=None) as batch_op: + batch_op.drop_index(batch_op.f("ix_exports_project_id")) + + op.drop_table("exports") + # ### end Alembic commands ### diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py index 418721e3b1..15801a25df 100644 --- a/src/dstack/_internal/server/models.py +++ b/src/dstack/_internal/server/models.py @@ -980,11 +980,9 @@ class EventTargetModel(BaseModel): entity_name: Mapped[str] = mapped_column(String(200)) -class ResourceExportModel(BaseModel): - __tablename__ = "resource_exports" - __table_args__ = ( - UniqueConstraint("project_id", "name", name="uq_resource_exports_project_id_name"), - ) +class ExportModel(BaseModel): + __tablename__ = "exports" + __table_args__ = (UniqueConstraint("project_id", "name", name="uq_exports_project_id_name"),) id: Mapped[uuid.UUID] = mapped_column( UUIDType(binary=False), primary_key=True, default=uuid.uuid4 @@ -995,21 +993,17 @@ class ResourceExportModel(BaseModel): ) project: Mapped["ProjectModel"] = relationship() created_at: Mapped[datetime] = mapped_column(NaiveDateTime, default=get_current_datetime) - resource_imports: Mapped[List["ResourceImportModel"]] = relationship( - back_populates="resource_export" - ) - exported_fleets: Mapped[List["ExportedFleetModel"]] = relationship( - back_populates="resource_export" - ) + imports: Mapped[List["ImportModel"]] = relationship(back_populates="export") + exported_fleets: Mapped[List["ExportedFleetModel"]] = relationship(back_populates="export") -class ResourceImportModel(BaseModel): - __tablename__ = "resource_imports" +class ImportModel(BaseModel): + __tablename__ = "imports" __table_args__ = ( UniqueConstraint( "project_id", - "resource_export_id", - name="uq_resource_imports_project_id_resource_export_id", + "export_id", + name="uq_imports_project_id_export_id", ), ) @@ -1020,28 +1014,26 @@ class ResourceImportModel(BaseModel): ForeignKey("projects.id", ondelete="CASCADE"), index=True ) project: Mapped["ProjectModel"] = relationship() - resource_export_id: Mapped[uuid.UUID] = mapped_column( - ForeignKey("resource_exports.id", ondelete="CASCADE"), index=True + export_id: Mapped[uuid.UUID] = mapped_column( + ForeignKey("exports.id", ondelete="CASCADE"), index=True ) - resource_export: Mapped["ResourceExportModel"] = relationship() + export: Mapped["ExportModel"] = relationship() created_at: Mapped[datetime] = mapped_column(NaiveDateTime, default=get_current_datetime) class ExportedFleetModel(BaseModel): __tablename__ = "exported_fleets" __table_args__ = ( - UniqueConstraint( - "resource_export_id", "fleet_id", name="uq_exported_fleets_resource_export_id_fleet_id" - ), + UniqueConstraint("export_id", "fleet_id", name="uq_exported_fleets_export_id_fleet_id"), ) id: Mapped[uuid.UUID] = mapped_column( UUIDType(binary=False), primary_key=True, default=uuid.uuid4 ) - resource_export_id: Mapped[uuid.UUID] = mapped_column( - ForeignKey("resource_exports.id", ondelete="CASCADE"), index=True + export_id: Mapped[uuid.UUID] = mapped_column( + ForeignKey("exports.id", ondelete="CASCADE"), index=True ) - resource_export: Mapped["ResourceExportModel"] = relationship() + export: Mapped["ExportModel"] = relationship() fleet_id: Mapped[uuid.UUID] = mapped_column( ForeignKey("fleets.id", ondelete="CASCADE"), index=True ) diff --git a/src/dstack/_internal/server/security/permissions.py b/src/dstack/_internal/server/security/permissions.py index a76f3659a0..107e526d30 100644 --- a/src/dstack/_internal/server/security/permissions.py +++ b/src/dstack/_internal/server/security/permissions.py @@ -12,10 +12,10 @@ from dstack._internal.server.models import ( ExportedFleetModel, FleetModel, + ImportModel, InstanceModel, MemberModel, ProjectModel, - ResourceImportModel, UserModel, ) from dstack._internal.server.services.projects import ( @@ -277,8 +277,8 @@ async def check_can_access_fleet( FleetModel.project_id == fleet_project.id, exists().where( MemberModel.user_id == user.id, - MemberModel.project_id == ResourceImportModel.project_id, - ResourceImportModel.resource_export_id == ExportedFleetModel.resource_export_id, + MemberModel.project_id == ImportModel.project_id, + ImportModel.export_id == ExportedFleetModel.export_id, ExportedFleetModel.fleet_id == FleetModel.id, ), ] @@ -307,8 +307,8 @@ async def check_can_access_instance( InstanceModel.id == instance_id, exists().where( MemberModel.user_id == user.id, - MemberModel.project_id == ResourceImportModel.project_id, - ResourceImportModel.resource_export_id == ExportedFleetModel.resource_export_id, + MemberModel.project_id == ImportModel.project_id, + ImportModel.export_id == ExportedFleetModel.export_id, ExportedFleetModel.fleet_id == InstanceModel.fleet_id, ), ] diff --git a/src/dstack/_internal/server/services/fleets.py b/src/dstack/_internal/server/services/fleets.py index e74d17d5a0..ca5a2e7b4f 100644 --- a/src/dstack/_internal/server/services/fleets.py +++ b/src/dstack/_internal/server/services/fleets.py @@ -55,11 +55,11 @@ from dstack._internal.server.models import ( ExportedFleetModel, FleetModel, + ImportModel, InstanceModel, JobModel, MemberModel, ProjectModel, - ResourceImportModel, RunModel, UserModel, ) @@ -235,8 +235,8 @@ async def list_projects_fleet_models( filters = [] project_ids = {p.id for p in projects} is_fleet_imported_subquery = exists().where( - ResourceImportModel.project_id.in_(project_ids), - ResourceImportModel.resource_export_id == ExportedFleetModel.resource_export_id, + ImportModel.project_id.in_(project_ids), + ImportModel.export_id == ExportedFleetModel.export_id, ExportedFleetModel.fleet_id == FleetModel.id, ) filters.append( @@ -306,8 +306,8 @@ async def list_project_fleet_models( ) -> List[FleetModel]: filters = [] is_fleet_imported_subquery = exists().where( - ResourceImportModel.project_id == project.id, - ResourceImportModel.resource_export_id == ExportedFleetModel.resource_export_id, + ImportModel.project_id == project.id, + ImportModel.export_id == ExportedFleetModel.export_id, ExportedFleetModel.fleet_id == FleetModel.id, ) filters.append( diff --git a/src/dstack/_internal/server/services/instances.py b/src/dstack/_internal/server/services/instances.py index 9ee18ec81d..079faf90c7 100644 --- a/src/dstack/_internal/server/services/instances.py +++ b/src/dstack/_internal/server/services/instances.py @@ -44,10 +44,10 @@ from dstack._internal.server.models import ( ExportedFleetModel, FleetModel, + ImportModel, InstanceHealthCheckModel, InstanceModel, ProjectModel, - ResourceImportModel, UserModel, ) from dstack._internal.server.schemas.health.dcgm import DCGMHealthResponse @@ -526,8 +526,8 @@ async def list_projects_instance_models( ) -> List[InstanceModel]: project_ids = [p.id for p in projects] is_instance_imported_subquery = exists().where( - ResourceImportModel.project_id.in_(project_ids), - ResourceImportModel.resource_export_id == ExportedFleetModel.resource_export_id, + ImportModel.project_id.in_(project_ids), + ImportModel.export_id == ExportedFleetModel.export_id, ExportedFleetModel.fleet_id == InstanceModel.fleet_id, ) filters: List = [ diff --git a/src/dstack/_internal/server/services/runs/plan.py b/src/dstack/_internal/server/services/runs/plan.py index 780fbb495b..4738622a07 100644 --- a/src/dstack/_internal/server/services/runs/plan.py +++ b/src/dstack/_internal/server/services/runs/plan.py @@ -24,9 +24,9 @@ from dstack._internal.server.models import ( ExportedFleetModel, FleetModel, + ImportModel, InstanceModel, ProjectModel, - ResourceImportModel, RunModel, ) from dstack._internal.server.services.fleets import ( @@ -214,8 +214,8 @@ async def get_run_candidate_fleet_models_filters( # do not provision on it to prevent attaching volumes that are currently detaching. detaching_instances_ids = await get_instances_ids_with_detaching_volumes(session) is_fleet_imported_subquery = exists().where( - ResourceImportModel.project_id == project.id, - ResourceImportModel.resource_export_id == ExportedFleetModel.resource_export_id, + ImportModel.project_id == project.id, + ImportModel.export_id == ExportedFleetModel.export_id, ExportedFleetModel.fleet_id == FleetModel.id, ) fleet_filters = [ diff --git a/src/dstack/_internal/server/testing/common.py b/src/dstack/_internal/server/testing/common.py index 79bac2153c..2893418a0d 100644 --- a/src/dstack/_internal/server/testing/common.py +++ b/src/dstack/_internal/server/testing/common.py @@ -94,10 +94,12 @@ DecryptedString, EventModel, ExportedFleetModel, + ExportModel, FileArchiveModel, FleetModel, GatewayComputeModel, GatewayModel, + ImportModel, InstanceHealthCheckModel, InstanceModel, JobMetricsPoint, @@ -108,8 +110,6 @@ ProjectModel, RepoCredsModel, RepoModel, - ResourceExportModel, - ResourceImportModel, RunModel, SecretModel, UserModel, @@ -517,22 +517,22 @@ async def create_compute_group( return compute_group -async def create_resource_export( +async def create_export( session: AsyncSession, exporter_project: ProjectModel, importer_projects: list[ProjectModel], exported_fleets: list[FleetModel], - name: str = "test_resource_export", -) -> ResourceExportModel: - resource_export = ResourceExportModel( + name: str = "test_export", +) -> ExportModel: + export = ExportModel( name=name, project=exporter_project, - resource_imports=[ResourceImportModel(project=project) for project in importer_projects], + imports=[ImportModel(project=project) for project in importer_projects], exported_fleets=[ExportedFleetModel(fleet=fleet) for fleet in exported_fleets], ) - session.add(resource_export) + session.add(export) await session.commit() - return resource_export + return export async def create_probe( diff --git a/src/tests/_internal/server/background/scheduled_tasks/test_submitted_jobs.py b/src/tests/_internal/server/background/scheduled_tasks/test_submitted_jobs.py index 96f37101ea..db75bbf530 100644 --- a/src/tests/_internal/server/background/scheduled_tasks/test_submitted_jobs.py +++ b/src/tests/_internal/server/background/scheduled_tasks/test_submitted_jobs.py @@ -42,12 +42,12 @@ from dstack._internal.server.settings import JobNetworkMode from dstack._internal.server.testing.common import ( ComputeMockSpec, + create_export, create_fleet, create_instance, create_job, create_project, create_repo, - create_resource_export, create_run, create_user, create_volume, @@ -406,7 +406,7 @@ async def test_assigns_job_to_imported_fleet(self, test_db, session: AsyncSessio run=run, instance_assigned=False, ) - await create_resource_export( + await create_export( session=session, exporter_project=exporter_project, importer_projects=[importer_project], diff --git a/src/tests/_internal/server/routers/test_events.py b/src/tests/_internal/server/routers/test_events.py index e8d43d8f1f..cb8e44b85a 100644 --- a/src/tests/_internal/server/routers/test_events.py +++ b/src/tests/_internal/server/routers/test_events.py @@ -12,12 +12,12 @@ from dstack._internal.server.services import events from dstack._internal.server.services.projects import add_project_member from dstack._internal.server.testing.common import ( + create_export, create_fleet, create_instance, create_job, create_project, create_repo, - create_resource_export, create_run, create_user, get_auth_headers, @@ -1399,8 +1399,8 @@ async def exported_fleet_setup(self, session: AsyncSession): targets=[events.Target.from_model(instance)], ) - # Create resource export - await create_resource_export( + # Create export + await create_export( session=session, exporter_project=exporter_project, importer_projects=[importer_project_1, importer_project_2], diff --git a/src/tests/_internal/server/routers/test_fleets.py b/src/tests/_internal/server/routers/test_fleets.py index e8a5b0c96c..12108eed31 100644 --- a/src/tests/_internal/server/routers/test_fleets.py +++ b/src/tests/_internal/server/routers/test_fleets.py @@ -31,12 +31,12 @@ from dstack._internal.server.services.permissions import DefaultPermissions from dstack._internal.server.services.projects import add_project_member from dstack._internal.server.testing.common import ( + create_export, create_fleet, create_instance, create_job, create_project, create_repo, - create_resource_export, create_run, create_user, default_permissions_context, @@ -171,7 +171,7 @@ async def test_returns_imported_fleet_with_include_imported( project=exporter_project, fleet=fleet, ) - await create_resource_export( + await create_export( session=session, exporter_project=exporter_project, importer_projects=[importer_project], @@ -224,7 +224,7 @@ async def test_not_returns_imported_fleet_without_include_imported( project=exporter_project, spec=get_fleet_spec(get_ssh_fleet_configuration(name="exported-fleet")), ) - await create_resource_export( + await create_export( session=session, exporter_project=exporter_project, importer_projects=[importer_project], @@ -276,7 +276,7 @@ async def test_returns_imported_fleet_once_when_user_member_of_both_projects( project=exporter_project, fleet=fleet, ) - await create_resource_export( + await create_export( session=session, exporter_project=exporter_project, importer_projects=[importer_project], @@ -339,7 +339,7 @@ async def test_returns_fleet_once_if_imported_twice( fleet=fleet, ) for name in ["export-1", "export-2"]: - await create_resource_export( + await create_export( session=session, exporter_project=exporter_project, importer_projects=[importer_project], @@ -428,7 +428,7 @@ async def test_returns_imported_fleet_with_include_imported( project=exporter_project, fleet=fleet, ) - await create_resource_export( + await create_export( session=session, exporter_project=exporter_project, importer_projects=[importer_project], @@ -478,7 +478,7 @@ async def test_not_returns_imported_fleet_without_include_imported( project=exporter_project, spec=get_fleet_spec(get_ssh_fleet_configuration(name="exported-fleet")), ) - await create_resource_export( + await create_export( session=session, exporter_project=exporter_project, importer_projects=[importer_project], @@ -529,7 +529,7 @@ async def test_returns_fleet_once_if_imported_twice( fleet=fleet, ) for name in ["export-1", "export-2"]: - await create_resource_export( + await create_export( session=session, exporter_project=exporter_project, importer_projects=[importer_project], @@ -769,7 +769,7 @@ async def test_returns_imported_fleet( project=exporter_project, fleet=fleet, ) - await create_resource_export( + await create_export( session=session, exporter_project=exporter_project, importer_projects=[importer_project], @@ -830,7 +830,7 @@ async def test_returns_403_on_foreign_fleet_if_not_imported( project=exporter_project, fleet=fleet, ) - await create_resource_export( + await create_export( session=session, exporter_project=exporter_project, importer_projects=[importer_project], @@ -1418,7 +1418,7 @@ async def test_importer_member_cannot_apply_plan_on_imported_fleet( project=exporter_project, spec=spec, ) - await create_resource_export( + await create_export( session=session, exporter_project=exporter_project, importer_projects=[importer_project], @@ -1598,7 +1598,7 @@ async def test_importer_member_cannot_delete_imported_fleet( project=exporter_project, spec=get_fleet_spec(get_ssh_fleet_configuration(name="exported-fleet")), ) - await create_resource_export( + await create_export( session=session, exporter_project=exporter_project, importer_projects=[importer_project], @@ -1766,7 +1766,7 @@ async def test_importer_member_cannot_delete_imported_fleet_instances( fleet=fleet, instance_num=1, ) - await create_resource_export( + await create_export( session=session, exporter_project=exporter_project, importer_projects=[importer_project], @@ -1995,7 +1995,7 @@ async def test_importer_member_cannot_get_plan_for_imported_fleet( ) spec = get_fleet_spec(get_ssh_fleet_configuration(name="exported-fleet")) fleet = await create_fleet(session=session, project=exporter_project, spec=spec) - await create_resource_export( + await create_export( session=session, exporter_project=exporter_project, importer_projects=[importer_project], diff --git a/src/tests/_internal/server/routers/test_instances.py b/src/tests/_internal/server/routers/test_instances.py index c66fd7eb27..439538c14c 100644 --- a/src/tests/_internal/server/routers/test_instances.py +++ b/src/tests/_internal/server/routers/test_instances.py @@ -14,11 +14,11 @@ from dstack._internal.server.models import UserModel from dstack._internal.server.services.projects import add_project_member from dstack._internal.server.testing.common import ( + create_export, create_fleet, create_instance, create_instance_health_check, create_project, - create_resource_export, create_user, get_auth_headers, get_fleet_configuration, @@ -295,7 +295,7 @@ async def test_returns_imported_instances_with_include_imported( await create_instance( session=session, project=exporter_project, fleet=fleet, name="exported-fleet-0" ) - await create_resource_export( + await create_export( session=session, exporter_project=exporter_project, importer_projects=[importer_project], @@ -355,7 +355,7 @@ async def test_not_returns_imported_instances_without_include_imported( fleet=fleet, name="exported-fleet-0", ) - await create_resource_export( + await create_export( session=session, exporter_project=exporter_project, importer_projects=[importer_project], @@ -410,7 +410,7 @@ async def test_returns_imported_instances_once_when_user_member_of_both_projects fleet=fleet, name="shared-fleet-0", ) - await create_resource_export( + await create_export( session=session, exporter_project=exporter_project, importer_projects=[importer_project], @@ -485,7 +485,7 @@ async def test_returns_instance_once_if_imported_twice( name="exported-fleet-0", ) for name in ["export-1", "export-2"]: - await create_resource_export( + await create_export( session=session, exporter_project=exporter_project, importer_projects=[importer_project], @@ -772,7 +772,7 @@ async def test_returns_imported_instance( project=exporter_project, fleet=fleet, ) - await create_resource_export( + await create_export( session=session, exporter_project=exporter_project, importer_projects=[importer_project], @@ -823,7 +823,7 @@ async def test_returns_403_on_foreign_instance_if_not_imported( project=exporter_project, fleet=fleet, ) - await create_resource_export( + await create_export( session=session, exporter_project=exporter_project, importer_projects=[importer_project], diff --git a/src/tests/_internal/server/routers/test_runs.py b/src/tests/_internal/server/routers/test_runs.py index c3a77b8bb6..4d6e7aa95d 100644 --- a/src/tests/_internal/server/routers/test_runs.py +++ b/src/tests/_internal/server/routers/test_runs.py @@ -57,6 +57,7 @@ from dstack._internal.server.services.runs.spec import validate_run_spec_and_set_defaults from dstack._internal.server.testing.common import ( create_backend, + create_export, create_fleet, create_gateway, create_gateway_compute, @@ -64,7 +65,6 @@ create_job, create_project, create_repo, - create_resource_export, create_run, create_user, get_auth_headers, @@ -1417,7 +1417,7 @@ async def test_returns_run_plan_with_offer_from_imported_fleet( instance_num=1, backend=BackendType.REMOTE, ) - await create_resource_export( + await create_export( session=session, exporter_project=exporter_project, importer_projects=[importer_project],