diff --git a/AGENTS.md b/AGENTS.md
index eb348b291b..336b97bb5b 100644
--- a/AGENTS.md
+++ b/AGENTS.md
@@ -18,6 +18,7 @@
 - Python targets 3.9+ with 4-space indentation and max line length of 99 (see `ruff.toml`; `E501` is ignored but keep lines readable).
 - Imports are sorted via Ruff’s isort settings (`dstack` treated as first-party).
 - Keep primary/public functions before local helper functions in a module section.
+- Keep private classes, exceptions, and similar implementation-specific types close to the private functions that use them unless they are shared more broadly in the module.
 - Prefer pydantic-style models in `core/models`.
 - Tests use `test_*.py` modules and `test_*` functions; fixtures live near usage.
diff --git a/pyproject.toml b/pyproject.toml
index 259cbf7b25..8c53b2e166 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -106,6 +106,7 @@ include = [
     "src/dstack/_internal/core/backends/runpod",
     "src/dstack/_internal/cli/services/configurators",
     "src/dstack/_internal/cli/commands",
+    "src/tests/_internal/server/background/pipeline_tasks",
 ]
 ignore = [
     "src/dstack/_internal/server/migrations/versions",
diff --git a/src/dstack/_internal/core/errors.py b/src/dstack/_internal/core/errors.py
index 0bfd5f6f33..0d4262fe9b 100644
--- a/src/dstack/_internal/core/errors.py
+++ b/src/dstack/_internal/core/errors.py
@@ -136,6 +136,10 @@ class ConfigurationError(DstackError):
     pass
 
 
+class SSHProvisioningError(DstackError):
+    pass
+
+
 class SSHError(DstackError):
     pass
 
diff --git a/src/dstack/_internal/server/background/pipeline_tasks/__init__.py b/src/dstack/_internal/server/background/pipeline_tasks/__init__.py
index 6b3762419f..556e13daaf 100644
--- a/src/dstack/_internal/server/background/pipeline_tasks/__init__.py
+++ b/src/dstack/_internal/server/background/pipeline_tasks/__init__.py
@@ -4,6 +4,7 @@
 from dstack._internal.server.background.pipeline_tasks.compute_groups import ComputeGroupPipeline
 from dstack._internal.server.background.pipeline_tasks.fleets import FleetPipeline
 from dstack._internal.server.background.pipeline_tasks.gateways import GatewayPipeline
+from dstack._internal.server.background.pipeline_tasks.instances import InstancePipeline
 from dstack._internal.server.background.pipeline_tasks.placement_groups import (
     PlacementGroupPipeline,
 )
@@ -19,6 +20,7 @@ def __init__(self) -> None:
             ComputeGroupPipeline(),
             FleetPipeline(),
             GatewayPipeline(),
+            InstancePipeline(),
             PlacementGroupPipeline(),
             VolumePipeline(),
         ]
diff --git a/src/dstack/_internal/server/background/pipeline_tasks/base.py b/src/dstack/_internal/server/background/pipeline_tasks/base.py
index aa5af9a4a3..76073b7893 100644
--- a/src/dstack/_internal/server/background/pipeline_tasks/base.py
+++ b/src/dstack/_internal/server/background/pipeline_tasks/base.py
@@ -1,6 +1,8 @@
 import asyncio
+import logging
 import math
 import random
+import time
 import uuid
 from abc import ABC, abstractmethod
 from collections.abc import Iterable, Sequence
@@ -331,6 +333,7 @@ async def start(self):
         self._running = True
         while self._running:
             item = await self._queue.get()
+            start_time = time.time()
             logger.debug("Processing %s item %s", item.__tablename__, item.id)
             try:
                 await self.process(item)
@@ -338,7 +341,12 @@ async def start(self):
                 logger.exception("Unexpected exception when processing item")
             finally:
                 await self._heartbeater.untrack(item)
-                logger.debug("Processed %s item %s", item.__tablename__, item.id)
+                logger.debug(
+                    "Processed %s item %s in %.3fs",
+                    item.__tablename__,
+                    item.id,
+                    time.time() - start_time,
+                )
 
     def stop(self):
         self._running = False
@@ -416,3 +424,40 @@ def resolve_now_placeholders(update_values: _ResolveNowInput, now: datetime):
     for key, value in update_values.items():
         if value is NOW_PLACEHOLDER:
             update_values[key] = now
+
+
+def log_lock_token_mismatch(
+    logger: logging.Logger,
+    item: PipelineItem,
+    action: str = "process",
+) -> None:
+    logger.warning(
+        "Failed to %s %s item %s: lock_token mismatch."
+        " The item is expected to be processed and updated on another fetch iteration.",
+        action,
+        item.__tablename__,
+        item.id,
+    )
+
+
+def log_lock_token_changed_after_processing(
+    logger: logging.Logger,
+    item: PipelineItem,
+    action: str = "update",
+    expected_outcome: str = "updated",
+) -> None:
+    logger.warning(
+        "Failed to %s %s item %s after processing: lock_token changed."
+        " The item is expected to be processed and %s on another fetch iteration.",
+        action,
+        item.__tablename__,
+        item.id,
+        expected_outcome,
+    )
+
+
+def log_lock_token_changed_on_reset(logger: logging.Logger) -> None:
+    logger.warning(
+        "Failed to reset lock: lock_token changed."
+        " The item is expected to be processed and updated on another fetch iteration."
+    )
diff --git a/src/dstack/_internal/server/background/pipeline_tasks/compute_groups.py b/src/dstack/_internal/server/background/pipeline_tasks/compute_groups.py
index 0ee2975eb2..69ce3e7998 100644
--- a/src/dstack/_internal/server/background/pipeline_tasks/compute_groups.py
+++ b/src/dstack/_internal/server/background/pipeline_tasks/compute_groups.py
@@ -20,6 +20,8 @@
     PipelineItem,
     UpdateMapDateTime,
     Worker,
+    log_lock_token_changed_after_processing,
+    log_lock_token_mismatch,
     resolve_now_placeholders,
     set_processed_update_map_fields,
     set_unlock_update_map_fields,
@@ -194,12 +196,7 @@ async def process(self, item: PipelineItem):
             )
             compute_group_model = res.unique().scalar_one_or_none()
             if compute_group_model is None:
-                logger.warning(
-                    "Failed to process %s item %s: lock_token mismatch."
-                    " The item is expected to be processed and updated on another fetch iteration.",
-                    item.__tablename__,
-                    item.id,
-                )
+                log_lock_token_mismatch(logger, item)
                 return
 
         result = _TerminateResult()
@@ -228,12 +225,7 @@ async def process(self, item: PipelineItem):
             )
             updated_ids = list(res.scalars().all())
             if len(updated_ids) == 0:
-                logger.warning(
-                    "Failed to update %s item %s after processing: lock_token changed."
- " The item is expected to be processed and updated on another fetch iteration.", - item.__tablename__, - item.id, - ) + log_lock_token_changed_after_processing(logger, item) return if not result.instances_update_map: return @@ -249,6 +241,8 @@ async def process(self, item: PipelineItem): instance_model=instance_model, old_status=instance_model.status, new_status=InstanceStatus.TERMINATED, + termination_reason=instance_model.termination_reason, + termination_reason_message=instance_model.termination_reason_message, ) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/fleets.py b/src/dstack/_internal/server/background/pipeline_tasks/fleets.py index 55ffcd7f94..2a63e21bd5 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/fleets.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/fleets.py @@ -2,13 +2,17 @@ import uuid from dataclasses import dataclass, field from datetime import timedelta -from typing import Sequence, TypedDict +from typing import Optional, Sequence, TypedDict from sqlalchemy import or_, select, update from sqlalchemy.ext.asyncio.session import AsyncSession from sqlalchemy.orm import joinedload, load_only, selectinload -from dstack._internal.core.models.fleets import FleetSpec, FleetStatus +from dstack._internal.core.models.fleets import ( + FleetSpec, + FleetStatus, + InstanceGroupPlacement, +) from dstack._internal.core.models.instances import InstanceStatus, InstanceTerminationReason from dstack._internal.core.models.runs import RunStatus from dstack._internal.server.background.pipeline_tasks.base import ( @@ -20,6 +24,9 @@ PipelineItem, UpdateMapDateTime, Worker, + log_lock_token_changed_after_processing, + log_lock_token_changed_on_reset, + log_lock_token_mismatch, resolve_now_placeholders, set_processed_update_map_fields, set_unlock_update_map_fields, @@ -210,77 +217,35 @@ async def process(self, item: PipelineItem): ) fleet_model = res.unique().scalar_one_or_none() if fleet_model is None: - logger.warning( - "Failed to process %s item %s: lock_token mismatch." - " The item is expected to be processed and updated on another fetch iteration.", - item.__tablename__, - item.id, - ) + log_lock_token_mismatch(logger, item) return - instance_lock, _ = get_locker(get_db().dialect_name).get_lockset( - InstanceModel.__tablename__ - ) - async with instance_lock: - res = await session.execute( - select(InstanceModel) - .where( - InstanceModel.fleet_id == item.id, - InstanceModel.deleted == False, - # TODO: Lock instance models in the DB - # or_( - # InstanceModel.lock_expires_at.is_(None), - # InstanceModel.lock_expires_at < get_current_datetime(), - # ), - # or_( - # InstanceModel.lock_owner.is_(None), - # InstanceModel.lock_owner == FleetPipeline.__name__, - # ), - ) - .with_for_update(skip_locked=True, key_share=True) + # Lock instance only if consolidation is needed. + locked_instance_ids: set[uuid.UUID] = set() + consolidation_fleet_spec = _get_fleet_spec_if_ready_for_consolidation(fleet_model) + consolidation_instances = None + if consolidation_fleet_spec is not None: + consolidation_instances = await _lock_fleet_instances_for_consolidation( + session=session, + item=item, ) - locked_instance_models = res.scalars().all() - if len(fleet_model.instances) != len(locked_instance_models): - logger.debug( - "Failed to lock fleet %s instances. 
The fleet will be processed later.", - item.id, - ) - now = get_current_datetime() - # Keep `lock_owner` so that `InstancePipeline` sees that the fleet is being locked - # but unset `lock_expires_at` to process the item again ASAP (after `min_processing_interval`). - # Unset `lock_token` so that heartbeater can no longer update the item. - res = await session.execute( - update(FleetModel) - .where( - FleetModel.id == item.id, - FleetModel.lock_token == item.lock_token, - ) - .values( - lock_expires_at=None, - lock_token=None, - last_processed_at=now, - ) - ) - if res.rowcount == 0: # pyright: ignore[reportAttributeAccessIssue] - logger.warning( - "Failed to reset lock: lock_token changed." - " The item is expected to be processed and updated on another fetch iteration." - ) + if consolidation_instances is None: return + locked_instance_ids = {instance.id for instance in consolidation_instances} - # TODO: Lock instance models in the DB - # for instance_model in locked_instance_models: - # instance_model.lock_expires_at = item.lock_expires_at - # instance_model.lock_token = item.lock_token - # instance_model.lock_owner = FleetPipeline.__name__ - # await session.commit() - - result = await _process_fleet(fleet_model) + result = await _process_fleet( + fleet_model, + consolidation_fleet_spec=consolidation_fleet_spec, + consolidation_instances=consolidation_instances, + ) fleet_update_map = _FleetUpdateMap() fleet_update_map.update(result.fleet_update_map) set_processed_update_map_fields(fleet_update_map) set_unlock_update_map_fields(fleet_update_map) - instance_update_rows = _build_instance_update_rows(result.instance_id_to_update_map) + instance_update_rows = _build_instance_update_rows( + result.instance_id_to_update_map, + unlock_instance_ids=locked_instance_ids, + ) async with get_session_ctx() as session: now = get_current_datetime() @@ -297,12 +262,13 @@ async def process(self, item: PipelineItem): ) updated_ids = list(res.scalars().all()) if len(updated_ids) == 0: - logger.warning( - "Failed to update %s item %s after processing: lock_token changed." - " The item is expected to be processed and updated on another fetch iteration.", - item.__tablename__, - item.id, - ) + log_lock_token_changed_after_processing(logger, item) + if locked_instance_ids: + await _unlock_fleet_locked_instances( + session=session, + item=item, + locked_instance_ids=locked_instance_ids, + ) # TODO: Clean up fleet. 
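# --- Editor's illustrative sketch, not part of this patch ---
# Every pipeline touched by this change finalizes items with the same optimistic-locking
# shape: the UPDATE re-checks `lock_token`, and an empty RETURNING result means another
# fetch iteration took the item over, which is reported via the shared log helpers added
# to base.py. The `demo_items` table and `build_finalize_stmt` below are made-up stand-ins
# that only show the statement shape, not dstack's actual models.
import uuid

from sqlalchemy import column, table, update

demo_items = table(
    "demo_items",
    column("id"),
    column("status"),
    column("lock_token"),
    column("lock_expires_at"),
)


def build_finalize_stmt(item_id: uuid.UUID, lock_token: uuid.UUID):
    # Compare-and-set: only a row still carrying our token is updated. Callers treat an
    # empty RETURNING result as "lock_token changed", call
    # log_lock_token_changed_after_processing(), and let a later fetch iteration retry.
    return (
        update(demo_items)
        .where(demo_items.c.id == item_id, demo_items.c.lock_token == lock_token)
        .values(status="processed", lock_token=None, lock_expires_at=None)
        .returning(demo_items.c.id)
    )


print(build_finalize_stmt(uuid.uuid4(), uuid.uuid4()))
# --- end sketch ---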
return @@ -314,14 +280,14 @@ async def process(self, item: PipelineItem): ) if instance_update_rows: await session.execute( - update(InstanceModel).execution_options(synchronize_session=False), + update(InstanceModel), instance_update_rows, ) - if result.new_instances_count > 0: + if len(result.new_instance_creates) > 0: await _create_missing_fleet_instances( session=session, fleet_model=fleet_model, - new_instances_count=result.new_instances_count, + new_instance_creates=result.new_instance_creates, ) emit_fleet_status_change_event( session=session, @@ -339,9 +305,10 @@ class _FleetUpdateMap(ItemUpdateMap, total=False): deleted_at: UpdateMapDateTime consolidation_attempt: int last_consolidated_at: UpdateMapDateTime + current_master_instance_id: Optional[uuid.UUID] -class _InstanceUpdateMap(TypedDict, total=False): +class _InstanceUpdateMap(ItemUpdateMap, total=False): status: InstanceStatus termination_reason: InstanceTerminationReason termination_reason_message: str @@ -351,51 +318,154 @@ class _InstanceUpdateMap(TypedDict, total=False): id: uuid.UUID +def _get_fleet_spec_if_ready_for_consolidation(fleet_model: FleetModel) -> Optional[FleetSpec]: + if fleet_model.status == FleetStatus.TERMINATING: + return None + consolidation_fleet_spec = get_fleet_spec(fleet_model) + if ( + consolidation_fleet_spec.configuration.nodes is None + or consolidation_fleet_spec.autocreated + ): + return None + if not _is_fleet_ready_for_consolidation(fleet_model): + return None + return consolidation_fleet_spec + + +async def _lock_fleet_instances_for_consolidation( + session: AsyncSession, + item: PipelineItem, +) -> Optional[list[InstanceModel]]: + instance_lock, _ = get_locker(get_db().dialect_name).get_lockset(InstanceModel.__tablename__) + async with instance_lock: + res = await session.execute( + select(InstanceModel) + .where( + InstanceModel.fleet_id == item.id, + InstanceModel.deleted == False, + or_( + InstanceModel.lock_expires_at.is_(None), + InstanceModel.lock_expires_at < get_current_datetime(), + ), + or_( + InstanceModel.lock_owner.is_(None), + InstanceModel.lock_owner == FleetPipeline.__name__, + ), + ) + .with_for_update(skip_locked=True, key_share=True) + ) + locked_instance_models = list(res.scalars().all()) + locked_instance_ids = {instance_model.id for instance_model in locked_instance_models} + + res = await session.execute( + select(InstanceModel.id).where( + InstanceModel.fleet_id == item.id, + InstanceModel.deleted == False, + ) + ) + current_instance_ids = set(res.scalars().all()) + if current_instance_ids != locked_instance_ids: + logger.debug( + "Failed to lock fleet %s instances. The fleet will be processed later.", + item.id, + ) + # Keep `lock_owner` so that `InstancePipeline` sees that the fleet is being locked + # but unset `lock_expires_at` to process the item again ASAP (after `min_processing_interval`). + # Unset `lock_token` so that heartbeater can no longer update the item. 
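# --- Editor's illustrative sketch, not part of this patch ---
# The three lock fields play different roles, which is why the reset described in the
# comments above clears them selectively. `Lock` and its field names are illustrative
# stand-ins; the real state lives on FleetModel/InstanceModel columns.
import uuid
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import Optional


@dataclass
class Lock:
    owner: Optional[str] = None  # which pipeline holds (or intends to hold) the item
    expires_at: Optional[datetime] = None  # None or in the past => item can be fetched again
    token: Optional[uuid.UUID] = None  # compare-and-set token for updates and heartbeats


def reset_after_failed_instance_lock(lock: Lock) -> None:
    # Keep `owner` so InstancePipeline still sees the fleet's claim, drop `expires_at`
    # so the fleet is re-fetched ASAP, drop `token` so the heartbeater stops renewing.
    lock.expires_at = None
    lock.token = None


lock = Lock(
    owner="FleetPipeline",
    expires_at=datetime.now(timezone.utc) + timedelta(seconds=30),
    token=uuid.uuid4(),
)
reset_after_failed_instance_lock(lock)
assert lock.owner == "FleetPipeline" and lock.token is None and lock.expires_at is None
# --- end sketch ---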
+ res = await session.execute( + update(FleetModel) + .where( + FleetModel.id == item.id, + FleetModel.lock_token == item.lock_token, + ) + .values( + lock_expires_at=None, + lock_token=None, + last_processed_at=get_current_datetime(), + ) + .returning(FleetModel.id) + ) + updated_ids = list(res.scalars().all()) + if len(updated_ids) == 0: + log_lock_token_changed_on_reset(logger) + return None + + for instance_model in locked_instance_models: + instance_model.lock_expires_at = item.lock_expires_at + instance_model.lock_token = item.lock_token + instance_model.lock_owner = FleetPipeline.__name__ + await session.commit() + return locked_instance_models + + @dataclass class _ProcessResult: fleet_update_map: _FleetUpdateMap = field(default_factory=_FleetUpdateMap) instance_id_to_update_map: dict[uuid.UUID, _InstanceUpdateMap] = field(default_factory=dict) - new_instances_count: int = 0 + new_instance_creates: list["_NewInstanceCreate"] = field(default_factory=list) + + +class _NewInstanceCreate(TypedDict): + id: uuid.UUID + instance_num: int @dataclass class _MaintainNodesResult: instance_id_to_update_map: dict[uuid.UUID, _InstanceUpdateMap] = field(default_factory=dict) - new_instances_count: int = 0 + new_instance_creates: list[_NewInstanceCreate] = field(default_factory=list) changes_required: bool = False @property def has_changes(self) -> bool: - return len(self.instance_id_to_update_map) > 0 or self.new_instances_count > 0 + return len(self.instance_id_to_update_map) > 0 or len(self.new_instance_creates) > 0 -async def _process_fleet(fleet_model: FleetModel) -> _ProcessResult: - result = _consolidate_fleet_state_with_spec(fleet_model) - if result.new_instances_count > 0: - # Avoid deleting fleets that are about to provision new instances. - return result - delete = _should_delete_fleet(fleet_model) - if delete: +async def _process_fleet( + fleet_model: FleetModel, + consolidation_fleet_spec: Optional[FleetSpec] = None, + consolidation_instances: Optional[Sequence[InstanceModel]] = None, +) -> _ProcessResult: + result = _ProcessResult() + effective_instances = list(consolidation_instances or fleet_model.instances) + if consolidation_fleet_spec is not None: + result = _consolidate_fleet_state_with_spec( + fleet_model, + consolidation_fleet_spec=consolidation_fleet_spec, + consolidation_instances=effective_instances, + ) + if len(result.new_instance_creates) == 0 and _should_delete_fleet(fleet_model): result.fleet_update_map["status"] = FleetStatus.TERMINATED result.fleet_update_map["deleted"] = True result.fleet_update_map["deleted_at"] = NOW_PLACEHOLDER + _set_fail_instances_on_master_bootstrap_failure( + fleet_model=fleet_model, + instance_models=effective_instances, + instance_id_to_update_map=result.instance_id_to_update_map, + ) + _set_current_master_instance_id( + fleet_model=fleet_model, + fleet_update_map=result.fleet_update_map, + instance_models=effective_instances, + instance_id_to_update_map=result.instance_id_to_update_map, + new_instance_creates=result.new_instance_creates, + ) return result -def _consolidate_fleet_state_with_spec(fleet_model: FleetModel) -> _ProcessResult: +def _consolidate_fleet_state_with_spec( + fleet_model: FleetModel, + consolidation_fleet_spec: FleetSpec, + consolidation_instances: Sequence[InstanceModel], +) -> _ProcessResult: result = _ProcessResult() - if fleet_model.status == FleetStatus.TERMINATING: - return result - fleet_spec = get_fleet_spec(fleet_model) - if fleet_spec.configuration.nodes is None or fleet_spec.autocreated: - # Only explicitly 
created cloud fleets are consolidated. - return result - if not _is_fleet_ready_for_consolidation(fleet_model): - return result - maintain_nodes_result = _maintain_fleet_nodes_in_min_max_range(fleet_model, fleet_spec) + maintain_nodes_result = _maintain_fleet_nodes_in_min_max_range( + instances=consolidation_instances, + fleet_spec=consolidation_fleet_spec, + ) if maintain_nodes_result.has_changes: result.instance_id_to_update_map = maintain_nodes_result.instance_id_to_update_map - result.new_instances_count = maintain_nodes_result.new_instances_count + result.new_instance_creates = maintain_nodes_result.new_instance_creates if maintain_nodes_result.changes_required: result.fleet_update_map["consolidation_attempt"] = fleet_model.consolidation_attempt + 1 else: @@ -431,7 +501,7 @@ def _get_consolidation_retry_delay(consolidation_attempt: int) -> timedelta: def _maintain_fleet_nodes_in_min_max_range( - fleet_model: FleetModel, + instances: Sequence[InstanceModel], fleet_spec: FleetSpec, ) -> _MaintainNodesResult: """ @@ -439,7 +509,7 @@ def _maintain_fleet_nodes_in_min_max_range( """ assert fleet_spec.configuration.nodes is not None result = _MaintainNodesResult() - for instance in fleet_model.instances: + for instance in instances: # Delete terminated but not deleted instances since # they are going to be replaced with new pending instances. if instance.status == InstanceStatus.TERMINATED and not instance.deleted: @@ -449,13 +519,19 @@ def _maintain_fleet_nodes_in_min_max_range( "deleted_at": NOW_PLACEHOLDER, } active_instances = [ - i for i in fleet_model.instances if i.status != InstanceStatus.TERMINATED and not i.deleted + i for i in instances if i.status != InstanceStatus.TERMINATED and not i.deleted ] active_instances_num = len(active_instances) if active_instances_num < fleet_spec.configuration.nodes.min: result.changes_required = True nodes_missing = fleet_spec.configuration.nodes.min - active_instances_num - result.new_instances_count = nodes_missing + taken_instance_nums = {instance.instance_num for instance in active_instances} + for _ in range(nodes_missing): + instance_num = get_next_instance_num(taken_instance_nums) + taken_instance_nums.add(instance_num) + result.new_instance_creates.append( + _NewInstanceCreate(id=uuid.uuid4(), instance_num=instance_num) + ) return result if ( fleet_spec.configuration.nodes.max is None @@ -467,7 +543,7 @@ def _maintain_fleet_nodes_in_min_max_range( # or if nodes.max is updated. 
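# --- Editor's illustrative sketch, not part of this patch ---
# The min-nodes branch above now pre-assigns both the id and the instance_num of each
# planned instance, so _create_missing_fleet_instances no longer re-queries taken numbers.
# `next_free_num` mimics what get_next_instance_num() is assumed to do (smallest number
# not yet taken); the toy data below is invented for illustration only.
def next_free_num(taken: set) -> int:
    n = 0
    while n in taken:
        n += 1
    return n


taken_nums = {0, 1, 3}  # instance_nums of the active (non-terminated) instances
planned_nums = []
for _ in range(2):  # two nodes missing to reach nodes.min
    num = next_free_num(taken_nums)
    taken_nums.add(num)
    planned_nums.append(num)
assert planned_nums == [2, 4]
# --- end sketch ---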
result.changes_required = True nodes_redundant = active_instances_num - fleet_spec.configuration.nodes.max - for instance in fleet_model.instances: + for instance in instances: if nodes_redundant == 0: break if instance.status == InstanceStatus.IDLE: @@ -506,42 +582,59 @@ def _should_delete_fleet(fleet_model: FleetModel) -> bool: def _build_instance_update_rows( instance_id_to_update_map: dict[uuid.UUID, _InstanceUpdateMap], + unlock_instance_ids: set[uuid.UUID], ) -> list[_InstanceUpdateMap]: instance_update_rows = [] - for instance_id, instance_update_map in instance_id_to_update_map.items(): + for instance_id in sorted(instance_id_to_update_map.keys() | unlock_instance_ids): + instance_update_map = instance_id_to_update_map.get(instance_id) update_row = _InstanceUpdateMap() - update_row.update(instance_update_map) + if instance_update_map is not None: + update_row.update(instance_update_map) + if instance_id in unlock_instance_ids: + set_unlock_update_map_fields(update_row) update_row["id"] = instance_id set_processed_update_map_fields(update_row) instance_update_rows.append(update_row) return instance_update_rows +async def _unlock_fleet_locked_instances( + session: AsyncSession, + item: PipelineItem, + locked_instance_ids: set[uuid.UUID], +) -> None: + await session.execute( + update(InstanceModel) + .where( + InstanceModel.id.in_(locked_instance_ids), + InstanceModel.lock_token == item.lock_token, + InstanceModel.lock_owner == FleetPipeline.__name__, + ) + .values( + lock_expires_at=None, + lock_token=None, + lock_owner=None, + ) + ) + + async def _create_missing_fleet_instances( session: AsyncSession, fleet_model: FleetModel, - new_instances_count: int, + new_instance_creates: Sequence[_NewInstanceCreate], ): fleet_spec = get_fleet_spec(fleet_model) - res = await session.execute( - select(InstanceModel.instance_num).where( - InstanceModel.fleet_id == fleet_model.id, - InstanceModel.deleted == False, - ) - ) - taken_instance_nums = set(res.scalars().all()) - for _ in range(new_instances_count): - instance_num = get_next_instance_num(taken_instance_nums) + for new_instance_create in new_instance_creates: instance_model = create_fleet_instance_model( session=session, project=fleet_model.project, # TODO: Store fleet.user and pass it instead of the project owner. username=fleet_model.project.owner.name, spec=fleet_spec, - instance_num=instance_num, + instance_num=new_instance_create["instance_num"], + instance_id=new_instance_create["id"], ) instance_model.fleet_id = fleet_model.id - taken_instance_nums.add(instance_num) events.emit( session=session, message=( @@ -553,6 +646,173 @@ async def _create_missing_fleet_instances( ) logger.info( "Added %d instances to fleet %s", - new_instances_count, + len(new_instance_creates), fleet_model.name, ) + + +def _set_fail_instances_on_master_bootstrap_failure( + fleet_model: FleetModel, + instance_models: Sequence[InstanceModel], + instance_id_to_update_map: dict[uuid.UUID, _InstanceUpdateMap], +) -> None: + """ + Terminates instances with MASTER_FAILED if the master dies with NO_OFFERS in a cluster with node.min == 0. + This is needed to avoid master re-election loop and fail fast. 
+ """ + fleet_spec = get_fleet_spec(fleet_model) + if ( + not _is_cloud_cluster_fleet_spec(fleet_spec) + or fleet_spec.configuration.nodes is None + or fleet_spec.configuration.nodes.min != 0 + or fleet_model.current_master_instance_id is None + ): + return + + current_master_instance_model = None + for instance_model in instance_models: + if instance_model.id == fleet_model.current_master_instance_id: + current_master_instance_model = instance_model + break + if current_master_instance_model is None: + return + + if ( + current_master_instance_model.status != InstanceStatus.TERMINATED + or current_master_instance_model.termination_reason != InstanceTerminationReason.NO_OFFERS + ): + return + + surviving_instance_models = _get_surviving_instance_models_after_updates( + instance_models=instance_models, + instance_id_to_update_map=instance_id_to_update_map, + ) + if any( + instance_model.status not in InstanceStatus.finished_statuses() + and instance_model.job_provisioning_data is not None + for instance_model in surviving_instance_models + ): + # It should not be possible to provision non-master instances ahead of master + # but we still safe-guard against the case when there can be other instances provisioned. + return + + for instance_model in surviving_instance_models: + if ( + instance_model.id == current_master_instance_model.id + or instance_model.status in InstanceStatus.finished_statuses() + ): + continue + update_map = instance_id_to_update_map.setdefault(instance_model.id, _InstanceUpdateMap()) + update_map["status"] = InstanceStatus.TERMINATED + update_map["termination_reason"] = InstanceTerminationReason.MASTER_FAILED + + +def _set_current_master_instance_id( + fleet_model: FleetModel, + fleet_update_map: _FleetUpdateMap, + instance_models: Sequence[InstanceModel], + instance_id_to_update_map: dict[uuid.UUID, _InstanceUpdateMap], + new_instance_creates: Sequence[_NewInstanceCreate], +) -> None: + """ + Sets `current_master_instance_id` for `fleet_model`. + Master instance can be changed if the previous master is gone. + If there are no active instances, newly selected master may change backend/region/az/placement. 
+ """ + fleet_spec = get_fleet_spec(fleet_model) + if not _is_cloud_cluster_fleet_spec(fleet_spec): + fleet_update_map["current_master_instance_id"] = None + return + surviving_instance_models = _get_surviving_instance_models_after_updates( + instance_models=instance_models, + instance_id_to_update_map=instance_id_to_update_map, + ) + current_master_instance_id = _select_current_master_instance_id( + current_master_instance_id=fleet_model.current_master_instance_id, + surviving_instance_models=surviving_instance_models, + instance_id_to_update_map=instance_id_to_update_map, + new_instance_creates=new_instance_creates, + ) + fleet_update_map["current_master_instance_id"] = current_master_instance_id + + +def _get_surviving_instance_models_after_updates( + instance_models: Sequence[InstanceModel], + instance_id_to_update_map: dict[uuid.UUID, _InstanceUpdateMap], +) -> list[InstanceModel]: + surviving_instance_models = [] + for instance_model in sorted(instance_models, key=lambda i: (i.instance_num, i.created_at)): + instance_update_map = instance_id_to_update_map.get(instance_model.id) + if instance_update_map is not None and instance_update_map.get("deleted"): + continue + surviving_instance_models.append(instance_model) + return surviving_instance_models + + +def _select_current_master_instance_id( + current_master_instance_id: Optional[uuid.UUID], + surviving_instance_models: Sequence[InstanceModel], + instance_id_to_update_map: dict[uuid.UUID, _InstanceUpdateMap], + new_instance_creates: Sequence[_NewInstanceCreate], +) -> Optional[uuid.UUID]: + # Keep the current master stable while it is still alive so InstancePipeline + # does not see fleet-wide election churn between provisioning attempts. + if current_master_instance_id is not None: + for instance_model in surviving_instance_models: + if ( + instance_model.id == current_master_instance_id + and _get_effective_instance_status( + instance_model, + instance_id_to_update_map=instance_id_to_update_map, + ) + not in InstanceStatus.finished_statuses() + ): + return instance_model.id + + # If the old master is gone, prefer a surviving provisioned instance so we + # keep following an already-established cluster placement decision. + for instance_model in surviving_instance_models: + if ( + _get_effective_instance_status( + instance_model, + instance_id_to_update_map=instance_id_to_update_map, + ) + not in InstanceStatus.finished_statuses() + and instance_model.job_provisioning_data is not None + ): + return instance_model.id + + # Prefer existing surviving instances over freshly planned replacements to + # avoid election churn during min-nodes backfill. 
+ for instance_model in surviving_instance_models: + if ( + _get_effective_instance_status( + instance_model, + instance_id_to_update_map=instance_id_to_update_map, + ) + not in InstanceStatus.finished_statuses() + ): + return instance_model.id + + for new_instance_create in sorted(new_instance_creates, key=lambda i: i["instance_num"]): + return new_instance_create["id"] + + return None + + +def _get_effective_instance_status( + instance_model: InstanceModel, + instance_id_to_update_map: dict[uuid.UUID, _InstanceUpdateMap], +) -> InstanceStatus: + update_map = instance_id_to_update_map.get(instance_model.id) + if update_map is None: + return instance_model.status + return update_map.get("status", instance_model.status) + + +def _is_cloud_cluster_fleet_spec(fleet_spec: FleetSpec) -> bool: + configuration = fleet_spec.configuration + return ( + configuration.placement == InstanceGroupPlacement.CLUSTER + and configuration.ssh_config is None + ) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/gateways.py b/src/dstack/_internal/server/background/pipeline_tasks/gateways.py index 2d5f0a947b..81ba2ae708 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/gateways.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/gateways.py @@ -18,6 +18,8 @@ Pipeline, PipelineItem, Worker, + log_lock_token_changed_after_processing, + log_lock_token_mismatch, resolve_now_placeholders, set_processed_update_map_fields, set_unlock_update_map_fields, @@ -219,12 +221,7 @@ async def _process_submitted_item(item: GatewayPipelineItem): ) gateway_model = res.unique().scalar_one_or_none() if gateway_model is None: - logger.warning( - "Failed to process %s item %s: lock_token mismatch." - " The item is expected to be processed and updated on another fetch iteration.", - item.__tablename__, - item.id, - ) + log_lock_token_mismatch(logger, item) return result = await _process_submitted_gateway(gateway_model) @@ -251,12 +248,7 @@ async def _process_submitted_item(item: GatewayPipelineItem): ) updated_ids = list(res.scalars().all()) if len(updated_ids) == 0: - logger.warning( - "Failed to update %s item %s after processing: lock_token changed." - " The item is expected to be processed and updated on another fetch iteration.", - item.__tablename__, - item.id, - ) + log_lock_token_changed_after_processing(logger, item) # TODO: Clean up gateway_compute_model. return emit_gateway_status_change_event( @@ -345,12 +337,7 @@ async def _process_provisioning_item(item: GatewayPipelineItem): ) gateway_model = res.unique().scalar_one_or_none() if gateway_model is None: - logger.warning( - "Failed to process %s item %s: lock_token mismatch." - " The item is expected to be processed and updated on another fetch iteration.", - item.__tablename__, - item.id, - ) + log_lock_token_mismatch(logger, item) return result = await _process_provisioning_gateway(gateway_model) @@ -372,12 +359,7 @@ async def _process_provisioning_item(item: GatewayPipelineItem): ) updated_ids = list(res.scalars().all()) if len(updated_ids) == 0: - logger.warning( - "Failed to update %s item %s after processing: lock_token changed." 
- " The item is expected to be processed and updated on another fetch iteration.", - item.__tablename__, - item.id, - ) + log_lock_token_changed_after_processing(logger, item) return emit_gateway_status_change_event( session=session, @@ -464,12 +446,7 @@ async def _process_to_be_deleted_item(item: GatewayPipelineItem): ) gateway_model = res.unique().scalar_one_or_none() if gateway_model is None: - logger.warning( - "Failed to process %s item %s: lock_token mismatch." - " The item is expected to be processed and updated on another fetch iteration.", - item.__tablename__, - item.id, - ) + log_lock_token_mismatch(logger, item) return result = await _process_to_be_deleted_gateway(gateway_model) @@ -485,11 +462,11 @@ async def _process_to_be_deleted_item(item: GatewayPipelineItem): ) deleted_ids = list(res.scalars().all()) if len(deleted_ids) == 0: - logger.warning( - "Failed to delete %s item %s after processing: lock_token changed." - " The item is expected to be processed and deleted on another fetch iteration.", - item.__tablename__, - item.id, + log_lock_token_changed_after_processing( + logger, + item, + action="delete", + expected_outcome="deleted", ) return events.emit( @@ -514,12 +491,7 @@ async def _process_to_be_deleted_item(item: GatewayPipelineItem): ) updated_ids = list(res.scalars().all()) if len(updated_ids) == 0: - logger.warning( - "Failed to update %s item %s after processing: lock_token changed." - " The item is expected to be processed and updated on another fetch iteration.", - item.__tablename__, - item.id, - ) + log_lock_token_changed_after_processing(logger, item) return if result.gateway_compute_update_map: diff --git a/src/dstack/_internal/server/background/pipeline_tasks/instances/__init__.py b/src/dstack/_internal/server/background/pipeline_tasks/instances/__init__.py new file mode 100644 index 0000000000..b5289e05e9 --- /dev/null +++ b/src/dstack/_internal/server/background/pipeline_tasks/instances/__init__.py @@ -0,0 +1,476 @@ +import asyncio +import uuid +from dataclasses import dataclass +from datetime import timedelta +from typing import Optional, Sequence + +from sqlalchemy import and_, not_, or_, select, update +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import joinedload, load_only + +from dstack._internal.core.models.health import HealthStatus +from dstack._internal.core.models.instances import InstanceStatus +from dstack._internal.server.background.pipeline_tasks.base import ( + Fetcher, + Heartbeater, + Pipeline, + PipelineItem, + Worker, + log_lock_token_changed_after_processing, + log_lock_token_mismatch, + resolve_now_placeholders, + set_processed_update_map_fields, + set_unlock_update_map_fields, +) +from dstack._internal.server.background.pipeline_tasks.instances.check import ( + check_instance, + process_idle_timeout, +) +from dstack._internal.server.background.pipeline_tasks.instances.cloud_provisioning import ( + create_cloud_instance, +) +from dstack._internal.server.background.pipeline_tasks.instances.common import ( + ProcessResult, +) +from dstack._internal.server.background.pipeline_tasks.instances.ssh_deploy import ( + add_ssh_instance, +) +from dstack._internal.server.background.pipeline_tasks.instances.termination import ( + terminate_instance, +) +from dstack._internal.server.db import get_db, get_session_ctx +from dstack._internal.server.models import ( + InstanceHealthCheckModel, + InstanceModel, + JobModel, + ProjectModel, +) +from dstack._internal.server.services import events +from 
dstack._internal.server.services.instances import ( + emit_instance_status_change_event, + is_ssh_instance, +) +from dstack._internal.server.services.locking import get_locker +from dstack._internal.server.services.placement import ( + schedule_fleet_placement_groups_deletion, +) +from dstack._internal.server.utils import sentry_utils +from dstack._internal.utils.common import get_current_datetime +from dstack._internal.utils.logging import get_logger + +logger = get_logger(__name__) + + +@dataclass +class InstancePipelineItem(PipelineItem): + status: InstanceStatus + + +class InstancePipeline(Pipeline[InstancePipelineItem]): + def __init__( + self, + workers_num: int = 20, + queue_lower_limit_factor: float = 0.5, + queue_upper_limit_factor: float = 2.0, + min_processing_interval: timedelta = timedelta(seconds=15), + lock_timeout: timedelta = timedelta(seconds=30), + heartbeat_trigger: timedelta = timedelta(seconds=15), + ) -> None: + super().__init__( + workers_num=workers_num, + queue_lower_limit_factor=queue_lower_limit_factor, + queue_upper_limit_factor=queue_upper_limit_factor, + min_processing_interval=min_processing_interval, + lock_timeout=lock_timeout, + heartbeat_trigger=heartbeat_trigger, + ) + self.__heartbeater = Heartbeater[InstancePipelineItem]( + model_type=InstanceModel, + lock_timeout=self._lock_timeout, + heartbeat_trigger=self._heartbeat_trigger, + ) + self.__fetcher = InstanceFetcher( + queue=self._queue, + queue_desired_minsize=self._queue_desired_minsize, + min_processing_interval=self._min_processing_interval, + lock_timeout=self._lock_timeout, + heartbeater=self._heartbeater, + ) + self.__workers = [ + InstanceWorker(queue=self._queue, heartbeater=self._heartbeater) + for _ in range(self._workers_num) + ] + + @property + def hint_fetch_model_name(self) -> str: + return InstanceModel.__name__ + + @property + def _heartbeater(self) -> Heartbeater[InstancePipelineItem]: + return self.__heartbeater + + @property + def _fetcher(self) -> Fetcher[InstancePipelineItem]: + return self.__fetcher + + @property + def _workers(self) -> Sequence["InstanceWorker"]: + return self.__workers + + +class InstanceFetcher(Fetcher[InstancePipelineItem]): + def __init__( + self, + queue: asyncio.Queue[InstancePipelineItem], + queue_desired_minsize: int, + min_processing_interval: timedelta, + lock_timeout: timedelta, + heartbeater: Heartbeater[InstancePipelineItem], + queue_check_delay: float = 1.0, + ) -> None: + super().__init__( + queue=queue, + queue_desired_minsize=queue_desired_minsize, + min_processing_interval=min_processing_interval, + lock_timeout=lock_timeout, + heartbeater=heartbeater, + queue_check_delay=queue_check_delay, + ) + + @sentry_utils.instrument_named_task("pipeline_tasks.InstanceFetcher.fetch") + async def fetch(self, limit: int) -> list[InstancePipelineItem]: + instance_lock, _ = get_locker(get_db().dialect_name).get_lockset( + InstanceModel.__tablename__ + ) + async with instance_lock: + async with get_session_ctx() as session: + now = get_current_datetime() + res = await session.execute( + select(InstanceModel) + .where( + InstanceModel.status.in_( + [ + InstanceStatus.PENDING, + InstanceStatus.PROVISIONING, + InstanceStatus.BUSY, + InstanceStatus.IDLE, + InstanceStatus.TERMINATING, + ] + ), + not_( + and_( + InstanceModel.status == InstanceStatus.TERMINATING, + InstanceModel.compute_group_id.is_not(None), + ) + ), + InstanceModel.deleted == False, + or_( + InstanceModel.last_processed_at <= now - self._min_processing_interval, + InstanceModel.last_processed_at 
== InstanceModel.created_at, + ), + or_( + InstanceModel.lock_expires_at.is_(None), + InstanceModel.lock_expires_at < now, + ), + or_( + InstanceModel.lock_owner.is_(None), + InstanceModel.lock_owner == InstancePipeline.__name__, + ), + ) + .order_by(InstanceModel.last_processed_at.asc()) + .limit(limit) + .with_for_update(skip_locked=True, key_share=True, of=InstanceModel) + .options( + load_only( + InstanceModel.id, + InstanceModel.lock_token, + InstanceModel.lock_expires_at, + InstanceModel.status, + ) + ) + ) + instance_models = list(res.scalars().all()) + lock_expires_at = get_current_datetime() + self._lock_timeout + lock_token = uuid.uuid4() + items = [] + for instance_model in instance_models: + prev_lock_expired = instance_model.lock_expires_at is not None + instance_model.lock_expires_at = lock_expires_at + instance_model.lock_token = lock_token + instance_model.lock_owner = InstancePipeline.__name__ + items.append( + InstancePipelineItem( + __tablename__=InstanceModel.__tablename__, + id=instance_model.id, + lock_expires_at=lock_expires_at, + lock_token=lock_token, + prev_lock_expired=prev_lock_expired, + status=instance_model.status, + ) + ) + await session.commit() + return items + + +class InstanceWorker(Worker[InstancePipelineItem]): + def __init__( + self, + queue: asyncio.Queue[InstancePipelineItem], + heartbeater: Heartbeater[InstancePipelineItem], + ) -> None: + super().__init__( + queue=queue, + heartbeater=heartbeater, + ) + + @sentry_utils.instrument_named_task("pipeline_tasks.InstanceWorker.process") + async def process(self, item: InstancePipelineItem): + process_context: Optional[_ProcessContext] = None + if item.status == InstanceStatus.PENDING: + process_context = await _process_pending_item(item) + elif item.status == InstanceStatus.PROVISIONING: + process_context = await _process_provisioning_item(item) + elif item.status == InstanceStatus.IDLE: + process_context = await _process_idle_item(item) + elif item.status == InstanceStatus.BUSY: + process_context = await _process_busy_item(item) + elif item.status == InstanceStatus.TERMINATING: + process_context = await _process_terminating_item(item) + if process_context is None: + return + set_processed_update_map_fields(process_context.result.instance_update_map) + set_unlock_update_map_fields(process_context.result.instance_update_map) + await _apply_process_result( + item=item, + instance_model=process_context.instance_model, + result=process_context.result, + ) + + +@dataclass +class _ProcessContext: + instance_model: InstanceModel + result: ProcessResult + + +async def _process_pending_item(item: InstancePipelineItem) -> Optional[_ProcessContext]: + async with get_session_ctx() as session: + instance_model = await _refetch_locked_instance_for_pending_or_terminating( + session=session, + item=item, + ) + if instance_model is None: + log_lock_token_mismatch(logger, item) + return None + if is_ssh_instance(instance_model): + result = await add_ssh_instance(instance_model) + else: + result = await create_cloud_instance(instance_model) + return _ProcessContext(instance_model=instance_model, result=result) + + +async def _process_provisioning_item(item: InstancePipelineItem) -> Optional[_ProcessContext]: + async with get_session_ctx() as session: + instance_model = await _refetch_locked_instance_for_check(session=session, item=item) + if instance_model is None: + log_lock_token_mismatch(logger, item) + return None + result = await check_instance(instance_model) + return _ProcessContext(instance_model=instance_model, 
result=result) + + +async def _process_idle_item(item: InstancePipelineItem) -> Optional[_ProcessContext]: + async with get_session_ctx() as session: + instance_model = await _refetch_locked_instance_for_idle(session=session, item=item) + if instance_model is None: + log_lock_token_mismatch(logger, item) + return None + idle_result = await process_idle_timeout( + session=session, + instance_model=instance_model, + ) + if idle_result is not None: + return _ProcessContext(instance_model=instance_model, result=idle_result) + result = await check_instance(instance_model) + return _ProcessContext(instance_model=instance_model, result=result) + + +async def _process_busy_item(item: InstancePipelineItem) -> Optional[_ProcessContext]: + async with get_session_ctx() as session: + instance_model = await _refetch_locked_instance_for_check(session=session, item=item) + if instance_model is None: + log_lock_token_mismatch(logger, item) + return None + result = await check_instance(instance_model) + return _ProcessContext(instance_model=instance_model, result=result) + + +async def _process_terminating_item(item: InstancePipelineItem) -> Optional[_ProcessContext]: + async with get_session_ctx() as session: + instance_model = await _refetch_locked_instance_for_pending_or_terminating( + session=session, + item=item, + ) + if instance_model is None: + log_lock_token_mismatch(logger, item) + return None + result = await terminate_instance(instance_model) + return _ProcessContext(instance_model=instance_model, result=result) + + +async def _refetch_locked_instance_for_pending_or_terminating( + session: AsyncSession, item: InstancePipelineItem +) -> Optional[InstanceModel]: + res = await session.execute( + select(InstanceModel) + .where( + InstanceModel.id == item.id, + InstanceModel.lock_token == item.lock_token, + ) + .options(joinedload(InstanceModel.project).joinedload(ProjectModel.backends)) + .options(joinedload(InstanceModel.jobs).load_only(JobModel.id, JobModel.status)) + .options(joinedload(InstanceModel.fleet)) + ) + return res.unique().scalar_one_or_none() + + +async def _refetch_locked_instance_for_idle( + session: AsyncSession, item: InstancePipelineItem +) -> Optional[InstanceModel]: + res = await session.execute( + select(InstanceModel) + .where( + InstanceModel.id == item.id, + InstanceModel.lock_token == item.lock_token, + ) + .options(joinedload(InstanceModel.project)) + .options(joinedload(InstanceModel.jobs).load_only(JobModel.id, JobModel.status)) + .options(joinedload(InstanceModel.fleet)) + ) + return res.unique().scalar_one_or_none() + + +async def _refetch_locked_instance_for_check( + session: AsyncSession, item: InstancePipelineItem +) -> Optional[InstanceModel]: + res = await session.execute( + select(InstanceModel) + .where( + InstanceModel.id == item.id, + InstanceModel.lock_token == item.lock_token, + ) + .options( + joinedload(InstanceModel.project).load_only( + ProjectModel.id, + ProjectModel.ssh_public_key, + ProjectModel.ssh_private_key, + ) + ) + .options(joinedload(InstanceModel.jobs).load_only(JobModel.id, JobModel.status)) + ) + return res.unique().scalar_one_or_none() + + +async def _apply_process_result( + item: InstancePipelineItem, + instance_model: InstanceModel, + result: ProcessResult, +) -> None: + async with get_session_ctx() as session: + if result.health_check_create is not None: + session.add(InstanceHealthCheckModel(**result.health_check_create)) + if result.new_placement_group_models: + session.add_all(result.new_placement_group_models) + if 
result.health_check_create is not None or result.new_placement_group_models: + await session.flush() + + now = get_current_datetime() + resolve_now_placeholders(result.instance_update_map, now=now) + + res = await session.execute( + update(InstanceModel) + .where( + InstanceModel.id == item.id, + InstanceModel.lock_token == item.lock_token, + ) + .values(**result.instance_update_map) + .returning(InstanceModel.id) + ) + updated_ids = list(res.scalars().all()) + if len(updated_ids) == 0: + log_lock_token_changed_after_processing(logger, item) + await session.rollback() + return + + if result.schedule_pg_deletion_fleet_id is not None: + await schedule_fleet_placement_groups_deletion( + session=session, + fleet_id=result.schedule_pg_deletion_fleet_id, + except_placement_group_ids=( + () + if result.schedule_pg_deletion_except_id is None + else (result.schedule_pg_deletion_except_id,) + ), + ) + + emit_instance_status_change_event( + session=session, + instance_model=instance_model, + old_status=instance_model.status, + new_status=result.instance_update_map.get("status", instance_model.status), + termination_reason=result.instance_update_map.get( + "termination_reason", instance_model.termination_reason + ), + termination_reason_message=result.instance_update_map.get( + "termination_reason_message", + instance_model.termination_reason_message, + ), + ) + _emit_instance_health_change_event( + session=session, + instance_model=instance_model, + old_health=instance_model.health, + new_health=result.instance_update_map.get("health", instance_model.health), + ) + _emit_instance_reachability_change_event( + session=session, + instance_model=instance_model, + old_status=instance_model.status, + old_unreachable=instance_model.unreachable, + new_unreachable=result.instance_update_map.get( + "unreachable", instance_model.unreachable + ), + ) + + +def _emit_instance_health_change_event( + session: AsyncSession, + instance_model: InstanceModel, + old_health: HealthStatus, + new_health: HealthStatus, +) -> None: + if old_health == new_health: + return + events.emit( + session, + f"Instance health changed {old_health.upper()} -> {new_health.upper()}", + actor=events.SystemActor(), + targets=[events.Target.from_model(instance_model)], + ) + + +def _emit_instance_reachability_change_event( + session: AsyncSession, + instance_model: InstanceModel, + old_status: InstanceStatus, + old_unreachable: bool, + new_unreachable: bool, +) -> None: + if not old_status.is_available() or old_unreachable == new_unreachable: + return + events.emit( + session, + "Instance became unreachable" if new_unreachable else "Instance became reachable", + actor=events.SystemActor(), + targets=[events.Target.from_model(instance_model)], + ) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/instances/check.py b/src/dstack/_internal/server/background/pipeline_tasks/instances/check.py new file mode 100644 index 0000000000..d23d536cd1 --- /dev/null +++ b/src/dstack/_internal/server/background/pipeline_tasks/instances/check.py @@ -0,0 +1,568 @@ +import logging +import uuid +from datetime import timedelta +from typing import Optional + +import gpuhunt +import requests +from sqlalchemy import func, select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import joinedload + +from dstack._internal.core.backends.base.backend import Backend +from dstack._internal.core.backends.base.compute import ( + get_dstack_runner_download_url, + get_dstack_runner_version, + get_dstack_shim_download_url, + 
get_dstack_shim_version, +) +from dstack._internal.core.consts import DSTACK_SHIM_HTTP_PORT +from dstack._internal.core.errors import ProvisioningError +from dstack._internal.core.models.backends.base import BackendType +from dstack._internal.core.models.health import HealthStatus +from dstack._internal.core.models.instances import InstanceStatus, InstanceTerminationReason +from dstack._internal.core.models.profiles import TerminationPolicy +from dstack._internal.core.models.runs import JobProvisioningData +from dstack._internal.server import settings as server_settings +from dstack._internal.server.background.pipeline_tasks.instances.common import ( + TERMINATION_DEADLINE_OFFSET, + HealthCheckCreate, + ProcessResult, + can_terminate_fleet_instances_on_idle_duration, + get_instance_idle_duration, + get_provisioning_deadline, + set_health_update, + set_status_update, + set_unreachable_update, +) +from dstack._internal.server.db import get_session_ctx +from dstack._internal.server.models import InstanceHealthCheckModel, InstanceModel, ProjectModel +from dstack._internal.server.schemas.instances import InstanceCheck +from dstack._internal.server.schemas.runner import ( + ComponentInfo, + ComponentStatus, + InstanceHealthResponse, +) +from dstack._internal.server.services import backends as backends_services +from dstack._internal.server.services.instances import ( + get_instance_provisioning_data, + get_instance_ssh_private_keys, + is_ssh_instance, + remove_dangling_tasks_from_instance, +) +from dstack._internal.server.services.logging import fmt +from dstack._internal.server.services.runner import client as runner_client +from dstack._internal.server.services.runner.ssh import runner_ssh_tunnel +from dstack._internal.utils.common import get_current_datetime, get_or_error, run_async +from dstack._internal.utils.logging import get_logger + +logger = get_logger(__name__) + + +async def process_idle_timeout( + session: AsyncSession, + instance_model: InstanceModel, +) -> Optional[ProcessResult]: + if not ( + instance_model.status == InstanceStatus.IDLE + and instance_model.termination_policy == TerminationPolicy.DESTROY_AFTER_IDLE + and not instance_model.jobs + ): + return None + # Do not terminate instances on idle duration if fleet is already at `nodes.min`. + # This is an optimization to avoid terminate-create loop. + # There may be race conditions since we don't take the fleet lock. + # That's ok: in the worst case we go below `nodes.min`, but + # the fleet consolidation logic will provision new nodes. 
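# --- Editor's illustrative sketch, not part of this patch ---
# The idle-timeout decision with plain values. `keep_for_min_nodes` stands in for the
# nodes.min guard described in the comments above, and `idle_for` for
# get_instance_idle_duration(); the real check also requires
# TerminationPolicy.DESTROY_AFTER_IDLE and no attached jobs.
from datetime import timedelta


def should_terminate_idle(
    idle_for: timedelta, termination_idle_time: int, keep_for_min_nodes: bool
) -> bool:
    if keep_for_min_nodes:
        # Terminating here would only make the fleet consolidation re-provision a node.
        return False
    return idle_for > timedelta(seconds=termination_idle_time)


assert should_terminate_idle(timedelta(minutes=10), 300, keep_for_min_nodes=False)
assert not should_terminate_idle(timedelta(minutes=10), 300, keep_for_min_nodes=True)
# --- end sketch ---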
+ if ( + instance_model.fleet is not None + and not await can_terminate_fleet_instances_on_idle_duration( + session=session, + fleet_model=instance_model.fleet, + ) + ): + return None + + idle_duration = get_instance_idle_duration(instance_model) + if idle_duration <= timedelta(seconds=instance_model.termination_idle_time): + return None + + result = ProcessResult() + set_status_update( + update_map=result.instance_update_map, + instance_model=instance_model, + new_status=InstanceStatus.TERMINATING, + termination_reason=InstanceTerminationReason.IDLE_TIMEOUT, + termination_reason_message=f"Instance idle for {idle_duration.seconds}s", + ) + return result + + +async def check_instance(instance_model: InstanceModel) -> ProcessResult: + result = ProcessResult() + if ( + instance_model.status == InstanceStatus.BUSY + and instance_model.jobs + and all(job.status.is_finished() for job in instance_model.jobs) + ): + # A busy instance could have no active jobs due to this bug: + # https://github.com/dstackai/dstack/issues/2068 + set_status_update( + update_map=result.instance_update_map, + instance_model=instance_model, + new_status=InstanceStatus.TERMINATING, + termination_reason=InstanceTerminationReason.JOB_FINISHED, + ) + logger.warning( + "Detected busy instance %s with finished job. Marked as TERMINATING", + instance_model.name, + extra={ + "instance_name": instance_model.name, + "instance_status": instance_model.status.value, + }, + ) + return result + + job_provisioning_data = get_or_error(get_instance_provisioning_data(instance_model)) + if job_provisioning_data.hostname is None: + return await _process_wait_for_instance_provisioning_data( + instance_model=instance_model, + job_provisioning_data=job_provisioning_data, + ) + + if not job_provisioning_data.dockerized: + if instance_model.status == InstanceStatus.PROVISIONING: + set_status_update( + update_map=result.instance_update_map, + instance_model=instance_model, + new_status=InstanceStatus.BUSY, + ) + return result + + check_instance_health = await _should_check_instance_health(instance_model.id) + instance_check = await _run_instance_check( + instance_model=instance_model, + job_provisioning_data=job_provisioning_data, + check_instance_health=check_instance_health, + ) + health_status = _get_health_status_for_instance_check( + instance_model=instance_model, + instance_check=instance_check, + check_instance_health=check_instance_health, + ) + _log_instance_check_result( + instance_model=instance_model, + instance_check=instance_check, + health_status=health_status, + check_instance_health=check_instance_health, + ) + + if instance_check.has_health_checks(): + # ensured by has_health_checks() + assert instance_check.health_response is not None + result.health_check_create = HealthCheckCreate( + instance_id=instance_model.id, + collected_at=get_current_datetime(), + status=health_status, + response=instance_check.health_response.json(), + ) + + set_health_update( + update_map=result.instance_update_map, + instance_model=instance_model, + health=health_status, + ) + set_unreachable_update( + update_map=result.instance_update_map, + instance_model=instance_model, + unreachable=not instance_check.reachable, + ) + + if instance_check.reachable: + result.instance_update_map["termination_deadline"] = None + if instance_model.status == InstanceStatus.PROVISIONING: + set_status_update( + update_map=result.instance_update_map, + instance_model=instance_model, + new_status=InstanceStatus.IDLE if not instance_model.jobs else InstanceStatus.BUSY, + 
) + return result + + now = get_current_datetime() + if not is_ssh_instance(instance_model) and instance_model.termination_deadline is None: + result.instance_update_map["termination_deadline"] = now + TERMINATION_DEADLINE_OFFSET + + if ( + instance_model.status == InstanceStatus.PROVISIONING + and instance_model.started_at is not None + ): + provisioning_deadline = get_provisioning_deadline( + instance_model=instance_model, + job_provisioning_data=job_provisioning_data, + ) + if now > provisioning_deadline: + set_status_update( + update_map=result.instance_update_map, + instance_model=instance_model, + new_status=InstanceStatus.TERMINATING, + termination_reason=InstanceTerminationReason.PROVISIONING_TIMEOUT, + termination_reason_message="Instance did not become reachable in time", + ) + elif instance_model.status.is_available(): + deadline = instance_model.termination_deadline + if deadline is not None and now > deadline: + set_status_update( + update_map=result.instance_update_map, + instance_model=instance_model, + new_status=InstanceStatus.TERMINATING, + termination_reason=InstanceTerminationReason.UNREACHABLE, + ) + return result + + +async def _should_check_instance_health(instance_id) -> bool: + health_check_cutoff = get_current_datetime() - timedelta( + seconds=server_settings.SERVER_INSTANCE_HEALTH_MIN_COLLECT_INTERVAL_SECONDS + ) + async with get_session_ctx() as session: + res = await session.execute( + select(func.count(1)).where( + InstanceHealthCheckModel.instance_id == instance_id, + InstanceHealthCheckModel.collected_at > health_check_cutoff, + ) + ) + return res.scalar_one() == 0 + + +async def _run_instance_check( + instance_model: InstanceModel, + job_provisioning_data: JobProvisioningData, + check_instance_health: bool, +) -> InstanceCheck: + ssh_private_keys = get_instance_ssh_private_keys(instance_model) + instance_check = await run_async( + _check_instance_inner, + ssh_private_keys, + job_provisioning_data, + None, + instance=instance_model, + check_instance_health=check_instance_health, + ) + # May return False if fails to establish ssh connection. + if instance_check is False: + return InstanceCheck(reachable=False, message="SSH or tunnel error") + return instance_check + + +def _get_health_status_for_instance_check( + instance_model: InstanceModel, + instance_check: InstanceCheck, + check_instance_health: bool, +) -> HealthStatus: + if instance_check.reachable and check_instance_health: + return instance_check.get_health_status() + # Keep previous health status. 
+ return instance_model.health + + +def _log_instance_check_result( + instance_model: InstanceModel, + instance_check: InstanceCheck, + health_status: HealthStatus, + check_instance_health: bool, +) -> None: + loglevel = logging.DEBUG + if not instance_check.reachable and instance_model.status.is_available(): + loglevel = logging.WARNING + elif check_instance_health and not health_status.is_healthy(): + loglevel = logging.WARNING + logger.log( + loglevel, + "Instance %s check: reachable=%s health_status=%s message=%r", + instance_model.name, + instance_check.reachable, + health_status.name, + instance_check.message, + extra={"instance_name": instance_model.name, "health_status": health_status}, + ) + + +async def _process_wait_for_instance_provisioning_data( + instance_model: InstanceModel, + job_provisioning_data: JobProvisioningData, +) -> ProcessResult: + result = ProcessResult() + logger.debug("Waiting for instance %s to become running", instance_model.name) + provisioning_deadline = get_provisioning_deadline( + instance_model=instance_model, + job_provisioning_data=job_provisioning_data, + ) + if get_current_datetime() > provisioning_deadline: + set_status_update( + update_map=result.instance_update_map, + instance_model=instance_model, + new_status=InstanceStatus.TERMINATING, + termination_reason=InstanceTerminationReason.PROVISIONING_TIMEOUT, + termination_reason_message="Backend did not complete provisioning in time", + ) + return result + + backend = await _get_backend_for_provisioning_wait( + project_id=instance_model.project_id, + backend_type=job_provisioning_data.backend, + ) + if backend is None: + logger.warning( + "Instance %s failed because instance's backend is not available", + instance_model.name, + ) + set_status_update( + update_map=result.instance_update_map, + instance_model=instance_model, + new_status=InstanceStatus.TERMINATING, + termination_reason=InstanceTerminationReason.ERROR, + termination_reason_message="Backend not available", + ) + return result + + try: + await run_async( + backend.compute().update_provisioning_data, + job_provisioning_data, + instance_model.project.ssh_public_key, + instance_model.project.ssh_private_key, + ) + result.instance_update_map["job_provisioning_data"] = job_provisioning_data.json() + except ProvisioningError as exc: + logger.warning( + "Error while waiting for instance %s to become running: %s", + instance_model.name, + repr(exc), + ) + set_status_update( + update_map=result.instance_update_map, + instance_model=instance_model, + new_status=InstanceStatus.TERMINATING, + termination_reason=InstanceTerminationReason.ERROR, + termination_reason_message="Error while waiting for instance to become running", + ) + except Exception: + logger.exception( + "Got exception when updating instance %s provisioning data", + instance_model.name, + ) + return result + + +async def _get_backend_for_provisioning_wait( + project_id: uuid.UUID, + backend_type: BackendType, +) -> Optional[Backend]: + async with get_session_ctx() as session: + res = await session.execute( + select(ProjectModel) + .where(ProjectModel.id == project_id) + .options(joinedload(ProjectModel.backends)) + ) + project_model = res.unique().scalar_one_or_none() + if project_model is None: + return None + return await backends_services.get_project_backend_by_type( + project=project_model, + backend_type=backend_type, + ) + + +@runner_ssh_tunnel(ports=[DSTACK_SHIM_HTTP_PORT], retries=1) +def _check_instance_inner( + ports: dict[int, int], + *, + instance: InstanceModel, + 
check_instance_health: bool = False, +) -> InstanceCheck: + instance_health_response: Optional[InstanceHealthResponse] = None + shim_client = runner_client.ShimClient(port=ports[DSTACK_SHIM_HTTP_PORT]) + method = shim_client.healthcheck + try: + healthcheck_response = method(unmask_exceptions=True) + if check_instance_health: + method = shim_client.get_instance_health + instance_health_response = method() + except requests.RequestException as exc: + template = "shim.%s(): request error: %s" + args = (method.__func__.__name__, exc) + logger.debug(template, *args) + return InstanceCheck(reachable=False, message=template % args) + except Exception as exc: + template = "shim.%s(): unexpected exception %s: %s" + args = (method.__func__.__name__, exc.__class__.__name__, exc) + logger.exception(template, *args) + return InstanceCheck(reachable=False, message=template % args) + + try: + remove_dangling_tasks_from_instance(shim_client, instance) + except Exception as exc: + logger.exception("%s: error removing dangling tasks: %s", fmt(instance), exc) + + # There should be no shim API calls after this function call since it can request shim restart. + _maybe_install_components(instance, shim_client) + return runner_client.healthcheck_response_to_instance_check( + healthcheck_response, + instance_health_response, + ) + + +def _maybe_install_components( + instance_model: InstanceModel, + shim_client: runner_client.ShimClient, +) -> None: + try: + components = shim_client.get_components() + except requests.RequestException as exc: + logger.warning( + "Instance %s: shim.get_components(): request error: %s", instance_model.name, exc + ) + return + if components is None: + logger.debug("Instance %s: no components info", instance_model.name) + return + + installed_shim_version: Optional[str] = None + installation_requested = False + + if (runner_info := components.runner) is not None: + installation_requested |= _maybe_install_runner(instance_model, shim_client, runner_info) + else: + logger.debug("Instance %s: no runner info", instance_model.name) + + if (shim_info := components.shim) is not None: + if shim_info.status == ComponentStatus.INSTALLED: + installed_shim_version = shim_info.version + installation_requested |= _maybe_install_shim(instance_model, shim_client, shim_info) + else: + logger.debug("Instance %s: no shim info", instance_model.name) + + # old shim without `dstack-shim` component and `/api/shutdown` support + # or the same version is already running + # or we just requested installation of at least one component + # or at least one component is already being installed + # or at least one shim task won't survive restart + running_shim_version = shim_client.get_version_string() + if ( + installed_shim_version is None + or installed_shim_version == running_shim_version + or installation_requested + or any(component.status == ComponentStatus.INSTALLING for component in components) + or not shim_client.is_safe_to_restart() + ): + return + + if shim_client.shutdown(force=False): + logger.debug( + "Instance %s: restarting shim %s -> %s", + instance_model.name, + running_shim_version, + installed_shim_version, + ) + else: + logger.debug("Instance %s: cannot restart shim", instance_model.name) + + +def _maybe_install_runner( + instance_model: InstanceModel, + shim_client: runner_client.ShimClient, + runner_info: ComponentInfo, +) -> bool: + # For developers: + # * To install the latest dev build for the current branch from the CI, + # set DSTACK_USE_LATEST_FROM_BRANCH=1. 
+ # * To provide your own build, set DSTACK_RUNNER_VERSION_URL and DSTACK_RUNNER_DOWNLOAD_URL. + expected_version = get_dstack_runner_version() + if expected_version is None: + return False + + installed_version = runner_info.version + logger.debug( + "Instance %s: runner status=%s installed_version=%s", + instance_model.name, + runner_info.status.value, + installed_version or "(no version)", + ) + if runner_info.status == ComponentStatus.INSTALLING: + logger.debug("Instance %s: runner is already being installed", instance_model.name) + return False + if installed_version and installed_version == expected_version: + logger.debug("Instance %s: expected runner version already installed", instance_model.name) + return False + + url = get_dstack_runner_download_url( + arch=_get_instance_cpu_arch(instance_model), + version=expected_version, + ) + logger.debug( + "Instance %s: installing runner %s -> %s from %s", + instance_model.name, + installed_version or "(no version)", + expected_version, + url, + ) + try: + shim_client.install_runner(url) + return True + except requests.RequestException as exc: + logger.warning("Instance %s: shim.install_runner(): %s", instance_model.name, exc) + return False + + +def _maybe_install_shim( + instance_model: InstanceModel, + shim_client: runner_client.ShimClient, + shim_info: ComponentInfo, +) -> bool: + # For developers: + # * To install the latest dev build for the current branch from the CI, + # set DSTACK_USE_LATEST_FROM_BRANCH=1. + # * To provide your own build, set DSTACK_SHIM_VERSION_URL and DSTACK_SHIM_DOWNLOAD_URL. + expected_version = get_dstack_shim_version() + if expected_version is None: + return False + + installed_version = shim_info.version + logger.debug( + "Instance %s: shim status=%s installed_version=%s running_version=%s", + instance_model.name, + shim_info.status.value, + installed_version or "(no version)", + shim_client.get_version_string(), + ) + if shim_info.status == ComponentStatus.INSTALLING: + logger.debug("Instance %s: shim is already being installed", instance_model.name) + return False + if installed_version and installed_version == expected_version: + logger.debug("Instance %s: expected shim version already installed", instance_model.name) + return False + + url = get_dstack_shim_download_url( + arch=_get_instance_cpu_arch(instance_model), + version=expected_version, + ) + logger.debug( + "Instance %s: installing shim %s -> %s from %s", + instance_model.name, + installed_version or "(no version)", + expected_version, + url, + ) + try: + shim_client.install_shim(url) + return True + except requests.RequestException as exc: + logger.warning("Instance %s: shim.install_shim(): %s", instance_model.name, exc) + return False + + +def _get_instance_cpu_arch(instance_model: InstanceModel) -> Optional[gpuhunt.CPUArchitecture]: + job_provisioning_data = get_instance_provisioning_data(instance_model) + if job_provisioning_data is None: + return None + return job_provisioning_data.instance_type.resources.cpu_arch diff --git a/src/dstack/_internal/server/background/pipeline_tasks/instances/cloud_provisioning.py b/src/dstack/_internal/server/background/pipeline_tasks/instances/cloud_provisioning.py new file mode 100644 index 0000000000..4d2cbd8696 --- /dev/null +++ b/src/dstack/_internal/server/background/pipeline_tasks/instances/cloud_provisioning.py @@ -0,0 +1,421 @@ +import uuid +from dataclasses import dataclass +from typing import Optional + +from pydantic import ValidationError +from sqlalchemy import select +from 
sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import load_only +from sqlalchemy.orm.attributes import set_committed_value + +from dstack._internal.core.backends.base.compute import ( + ComputeWithCreateInstanceSupport, + ComputeWithPlacementGroupSupport, + generate_unique_placement_group_name, +) +from dstack._internal.core.backends.features import ( + BACKENDS_WITH_CREATE_INSTANCE_SUPPORT, + BACKENDS_WITH_PLACEMENT_GROUPS_SUPPORT, +) +from dstack._internal.core.errors import ( + BackendError, + PlacementGroupNotSupportedError, +) +from dstack._internal.core.models.instances import ( + InstanceOfferWithAvailability, + InstanceStatus, + InstanceTerminationReason, +) +from dstack._internal.core.models.placement import PlacementGroupConfiguration, PlacementStrategy +from dstack._internal.core.models.runs import JobProvisioningData +from dstack._internal.server import settings as server_settings +from dstack._internal.server.background.pipeline_tasks.base import NOW_PLACEHOLDER +from dstack._internal.server.background.pipeline_tasks.instances.common import ( + ProcessResult, + set_status_update, +) +from dstack._internal.server.db import get_session_ctx +from dstack._internal.server.models import FleetModel, InstanceModel, PlacementGroupModel +from dstack._internal.server.services.fleets import get_create_instance_offers, is_cloud_cluster +from dstack._internal.server.services.instances import ( + get_instance_configuration, + get_instance_profile, + get_instance_provisioning_data, + get_instance_requirements, +) +from dstack._internal.server.services.logging import fmt +from dstack._internal.server.services.offers import get_instance_offer_with_restricted_az +from dstack._internal.server.services.placement import ( + get_fleet_placement_group_models, + placement_group_model_to_placement_group, + placement_group_model_to_placement_group_optional, +) +from dstack._internal.utils.common import get_or_error, run_async +from dstack._internal.utils.logging import get_logger + +logger = get_logger(__name__) + + +@dataclass +class _ClusterMasterContext: + current_master_instance_model: InstanceModel + is_current_instance_master: bool + master_job_provisioning_data: Optional[JobProvisioningData] + + +async def create_cloud_instance(instance_model: InstanceModel) -> ProcessResult: + result = ProcessResult() + + try: + instance_configuration = get_instance_configuration(instance_model) + profile = get_instance_profile(instance_model) + requirements = get_instance_requirements(instance_model) + except ValidationError as exc: + logger.exception( + "%s: error parsing profile, requirements or instance configuration", + fmt(instance_model), + ) + set_status_update( + update_map=result.instance_update_map, + instance_model=instance_model, + new_status=InstanceStatus.TERMINATED, + termination_reason=InstanceTerminationReason.ERROR, + termination_reason_message=( + f"Error to parse profile, requirements or instance_configuration: {exc}" + ), + ) + return result + + cluster_context = None + placement_group_models: list[PlacementGroupModel] = [] + placement_group_model = None + master_job_provisioning_data = None + if instance_model.fleet is not None and is_cloud_cluster(instance_model.fleet): + cluster_context = await _get_cluster_master_context(instance_model) + if cluster_context is None: + # Waiting for the master + return result + placement_group_models, placement_group_model = await _get_cluster_placement_context( + instance_model=instance_model, + cluster_context=cluster_context, + ) + 
master_job_provisioning_data = cluster_context.master_job_provisioning_data + + offers = await get_create_instance_offers( + project=instance_model.project, + profile=profile, + requirements=requirements, + fleet_model=instance_model.fleet, + placement_group=placement_group_model_to_placement_group_optional(placement_group_model), + blocks="auto" if instance_model.total_blocks is None else instance_model.total_blocks, + exclude_not_available=True, + master_job_provisioning_data=master_job_provisioning_data, + infer_master_job_provisioning_data_from_fleet_instances=False, + ) + + # Limit number of offers tried to prevent long-running processing in case all offers fail. + for backend, instance_offer in offers[: server_settings.MAX_OFFERS_TRIED]: + if instance_offer.backend not in BACKENDS_WITH_CREATE_INSTANCE_SUPPORT: + continue + compute = backend.compute() + assert isinstance(compute, ComputeWithCreateInstanceSupport) + if master_job_provisioning_data is not None: + # `get_create_instance_offers()` already restricts backend and region from the master. + # Availability zone still has to be narrowed per offer. + instance_offer = get_instance_offer_with_restricted_az( + instance_offer=instance_offer, + master_job_provisioning_data=master_job_provisioning_data, + ) + if ( + cluster_context is not None + and cluster_context.is_current_instance_master + and instance_offer.backend in BACKENDS_WITH_PLACEMENT_GROUPS_SUPPORT + and isinstance(compute, ComputeWithPlacementGroupSupport) + and ( + compute.are_placement_groups_compatible_with_reservations(instance_offer.backend) + or instance_configuration.reservation is None + ) + ): + ( + placement_group_model, + created_placement_group_model, + ) = await _find_or_create_suitable_placement_group_model( + instance_model=instance_model, + placement_group_models=placement_group_models, + instance_offer=instance_offer, + compute=compute, + ) + if placement_group_model is None: + continue + if created_placement_group_model: + placement_group_models.append(placement_group_model) + result.new_placement_group_models.append(placement_group_model) + + logger.debug( + "Trying %s in %s/%s for $%0.4f per hour", + instance_offer.instance.name, + instance_offer.backend.value, + instance_offer.region, + instance_offer.price, + ) + try: + job_provisioning_data = await run_async( + compute.create_instance, + instance_offer, + instance_configuration, + placement_group_model_to_placement_group_optional(placement_group_model), + ) + except BackendError as exc: + logger.warning( + "%s launch in %s/%s failed: %s", + instance_offer.instance.name, + instance_offer.backend.value, + instance_offer.region, + repr(exc), + extra={"instance_name": instance_model.name}, + ) + continue + except Exception: + logger.exception( + "Got exception when launching %s in %s/%s", + instance_offer.instance.name, + instance_offer.backend.value, + instance_offer.region, + ) + continue + + set_status_update( + update_map=result.instance_update_map, + instance_model=instance_model, + new_status=InstanceStatus.PROVISIONING, + ) + result.instance_update_map["backend"] = backend.TYPE + result.instance_update_map["region"] = instance_offer.region + result.instance_update_map["price"] = instance_offer.price + result.instance_update_map["instance_configuration"] = instance_configuration.json() + result.instance_update_map["job_provisioning_data"] = job_provisioning_data.json() + result.instance_update_map["offer"] = instance_offer.json() + result.instance_update_map["total_blocks"] = 
instance_offer.total_blocks + result.instance_update_map["started_at"] = NOW_PLACEHOLDER + + if ( + instance_model.fleet_id is not None + and cluster_context is not None + and cluster_context.is_current_instance_master + ): + # Clean up placement groups that did not end up being used. + result.schedule_pg_deletion_fleet_id = instance_model.fleet_id + if placement_group_model is not None: + result.schedule_pg_deletion_except_id = placement_group_model.id + return result + + set_status_update( + update_map=result.instance_update_map, + instance_model=instance_model, + new_status=InstanceStatus.TERMINATED, + termination_reason=InstanceTerminationReason.NO_OFFERS, + termination_reason_message="All offers failed" if offers else "No offers found", + ) + return result + + +async def _get_cluster_master_context( + instance_model: InstanceModel, +) -> Optional[_ClusterMasterContext]: + assert instance_model.fleet is not None and is_cloud_cluster(instance_model.fleet) + assert instance_model.fleet_id is not None + async with get_session_ctx() as session: + current_master_instance_model = await _load_current_master_instance( + session=session, + fleet_id=instance_model.fleet_id, + ) + if current_master_instance_model is None: + logger.debug( + "%s: waiting for fleet pipeline to elect current cluster master", + fmt(instance_model), + ) + return None + + is_current_instance_master = current_master_instance_model.id == instance_model.id + master_job_provisioning_data = None + if not is_current_instance_master: + if ( + current_master_instance_model.deleted + or current_master_instance_model.status == InstanceStatus.TERMINATED + ): + logger.debug( + "%s: waiting for fleet pipeline to replace current master %s", + fmt(instance_model), + current_master_instance_model.id, + ) + return None + master_job_provisioning_data = get_instance_provisioning_data( + current_master_instance_model + ) + if master_job_provisioning_data is None: + logger.debug( + "%s: waiting for current master %s to determine cluster placement", + fmt(instance_model), + current_master_instance_model.id, + ) + return None + + return _ClusterMasterContext( + current_master_instance_model=current_master_instance_model, + is_current_instance_master=is_current_instance_master, + master_job_provisioning_data=master_job_provisioning_data, + ) + + +async def _get_cluster_placement_context( + instance_model: InstanceModel, + cluster_context: _ClusterMasterContext, +) -> tuple[list[PlacementGroupModel], Optional[PlacementGroupModel]]: + assert instance_model.fleet is not None and is_cloud_cluster(instance_model.fleet) + assert instance_model.fleet_id is not None + async with get_session_ctx() as session: + placement_group_models = await get_fleet_placement_group_models( + session=session, + fleet_id=instance_model.fleet_id, + ) + placement_group_model = None + if not cluster_context.is_current_instance_master: + # Non-master instances only reuse the placement group chosen by the + # current master. They never create a new placement group themselves. 
+ placement_group_model = _get_current_master_placement_group_model( + placement_group_models=placement_group_models, + fleet_id=instance_model.fleet_id, + ) + if placement_group_model is not None: + _populate_current_master_placement_group_relations( + placement_group_model=placement_group_model, + instance_model=instance_model, + ) + return placement_group_models, placement_group_model + + +async def _load_current_master_instance( + session: AsyncSession, + fleet_id: uuid.UUID, +) -> Optional[InstanceModel]: + res = await session.execute( + select(FleetModel.current_master_instance_id).where(FleetModel.id == fleet_id) + ) + current_master_instance_id = res.scalar_one_or_none() + if current_master_instance_id is None: + return None + res = await session.execute( + select(InstanceModel) + .where( + InstanceModel.id == current_master_instance_id, + ) + .options( + load_only( + InstanceModel.id, + InstanceModel.deleted, + InstanceModel.status, + InstanceModel.job_provisioning_data, + ) + ) + ) + return res.scalar_one_or_none() + + +def _get_current_master_placement_group_model( + placement_group_models: list[PlacementGroupModel], + fleet_id: uuid.UUID, +) -> Optional[PlacementGroupModel]: + if not placement_group_models: + return None + if len(placement_group_models) > 1: + logger.error( + "Expected 0 or 1 placement groups associated with fleet master %s, found %s." + " Using the first placement group for this provisioning attempt.", + fleet_id, + len(placement_group_models), + ) + return placement_group_models[0] + + +def _populate_current_master_placement_group_relations( + placement_group_model: PlacementGroupModel, + instance_model: InstanceModel, +) -> None: + # Placement groups are loaded in a separate session from the instance worker. + # Reattach the already-known project/fleet objects so later detached access + # can still build a PlacementGroup value object without lazy loading. 
+ set_committed_value(placement_group_model, "project", instance_model.project) + if instance_model.fleet is not None: + set_committed_value(placement_group_model, "fleet", instance_model.fleet) + + +async def _find_or_create_suitable_placement_group_model( + instance_model: InstanceModel, + placement_group_models: list[PlacementGroupModel], + instance_offer: InstanceOfferWithAvailability, + compute: ComputeWithPlacementGroupSupport, +) -> tuple[Optional[PlacementGroupModel], bool]: + for placement_group_model in placement_group_models: + if compute.is_suitable_placement_group( + placement_group_model_to_placement_group(placement_group_model), + instance_offer, + ): + return placement_group_model, False + + assert instance_model.fleet is not None + placement_group_id = uuid.uuid4() + placement_group_name = generate_unique_placement_group_name( + project_name=instance_model.project.name, + fleet_name=instance_model.fleet.name, + ) + placement_group_model = PlacementGroupModel( + id=placement_group_id, + name=placement_group_name, + project=instance_model.project, + fleet=get_or_error(instance_model.fleet), + configuration=PlacementGroupConfiguration( + backend=instance_offer.backend, + region=instance_offer.region, + placement_strategy=PlacementStrategy.CLUSTER, + ).json(), + ) + placement_group = placement_group_model_to_placement_group(placement_group_model) + logger.debug( + "Creating placement group %s in %s/%s", + placement_group.name, + placement_group.configuration.backend.value, + placement_group.configuration.region, + ) + try: + provisioning_data = await run_async( + compute.create_placement_group, + placement_group, + instance_offer, + ) + except PlacementGroupNotSupportedError: + logger.debug( + "Skipping offer %s because placement group not supported", + instance_offer.instance.name, + ) + return None, False + except BackendError as exc: + logger.warning( + "Failed to create placement group %s in %s/%s: %r", + placement_group.name, + placement_group.configuration.backend.value, + placement_group.configuration.region, + exc, + ) + return None, False + except Exception: + logger.exception( + "Got exception when creating placement group %s in %s/%s", + placement_group.name, + placement_group.configuration.backend.value, + placement_group.configuration.region, + ) + return None, False + + placement_group.provisioning_data = provisioning_data + placement_group_model.provisioning_data = provisioning_data.json() + return placement_group_model, True diff --git a/src/dstack/_internal/server/background/pipeline_tasks/instances/common.py b/src/dstack/_internal/server/background/pipeline_tasks/instances/common.py new file mode 100644 index 0000000000..34e80311fd --- /dev/null +++ b/src/dstack/_internal/server/background/pipeline_tasks/instances/common.py @@ -0,0 +1,177 @@ +import datetime +import uuid +from dataclasses import dataclass, field +from datetime import timedelta +from typing import Optional, TypedDict, Union + +from paramiko.pkey import PKey +from sqlalchemy import func, select +from sqlalchemy.ext.asyncio import AsyncSession + +from dstack._internal.core.models.backends.base import BackendType +from dstack._internal.core.models.health import HealthStatus +from dstack._internal.core.models.instances import ( + InstanceStatus, + InstanceTerminationReason, + SSHKey, +) +from dstack._internal.core.models.runs import JobProvisioningData +from dstack._internal.server.background.pipeline_tasks.base import ( + ItemUpdateMap, + UpdateMapDateTime, +) +from 
dstack._internal.server.background.scheduled_tasks.common import get_provisioning_timeout +from dstack._internal.server.models import FleetModel, InstanceModel, PlacementGroupModel +from dstack._internal.server.services.fleets import get_fleet_spec +from dstack._internal.utils.common import UNSET, Unset, get_current_datetime +from dstack._internal.utils.ssh import pkey_from_str + +TERMINATION_DEADLINE_OFFSET = timedelta(minutes=20) +TERMINATION_RETRY_TIMEOUT = timedelta(seconds=30) +TERMINATION_RETRY_MAX_DURATION = timedelta(minutes=15) +PROVISIONING_TIMEOUT_SECONDS = 10 * 60 # 10 minutes in seconds + + +class InstanceUpdateMap(ItemUpdateMap, total=False): + status: InstanceStatus + unreachable: bool + started_at: UpdateMapDateTime + finished_at: UpdateMapDateTime + instance_configuration: str + termination_deadline: Optional[datetime.datetime] + termination_reason: Optional[InstanceTerminationReason] + termination_reason_message: Optional[str] + health: HealthStatus + first_termination_retry_at: UpdateMapDateTime + last_termination_retry_at: UpdateMapDateTime + backend: BackendType + backend_data: Optional[str] + offer: str + region: str + price: float + job_provisioning_data: str + total_blocks: int + busy_blocks: int + deleted: bool + deleted_at: UpdateMapDateTime + + +class HealthCheckCreate(TypedDict): + instance_id: uuid.UUID + collected_at: datetime.datetime + status: HealthStatus + response: str + + +@dataclass +class ProcessResult: + instance_update_map: InstanceUpdateMap = field(default_factory=InstanceUpdateMap) + health_check_create: Optional[HealthCheckCreate] = None + new_placement_group_models: list[PlacementGroupModel] = field(default_factory=list) + schedule_pg_deletion_fleet_id: Optional[uuid.UUID] = None + schedule_pg_deletion_except_id: Optional[uuid.UUID] = None + + +async def can_terminate_fleet_instances_on_idle_duration( + session: AsyncSession, + fleet_model: FleetModel, +) -> bool: + fleet_spec = get_fleet_spec(fleet_model) + if fleet_spec.configuration.nodes is None or fleet_spec.autocreated: + return True + res = await session.execute( + select(func.count(1)).where( + InstanceModel.fleet_id == fleet_model.id, + InstanceModel.deleted == False, + InstanceModel.status.not_in(InstanceStatus.finished_statuses()), + ) + ) + return res.scalar_one() > fleet_spec.configuration.nodes.min + + +def get_instance_idle_duration(instance_model: InstanceModel) -> datetime.timedelta: + last_time = instance_model.created_at + if instance_model.last_job_processed_at is not None: + last_time = instance_model.last_job_processed_at + return get_current_datetime() - last_time + + +def get_provisioning_deadline( + instance_model: InstanceModel, + job_provisioning_data: JobProvisioningData, +) -> datetime.datetime: + assert instance_model.started_at is not None + timeout_interval = get_provisioning_timeout( + backend_type=job_provisioning_data.get_base_backend(), + instance_type_name=job_provisioning_data.instance_type.name, + ) + return instance_model.started_at + timeout_interval + + +def next_termination_retry_at(last_termination_retry_at: datetime.datetime) -> datetime.datetime: + return last_termination_retry_at + TERMINATION_RETRY_TIMEOUT + + +def get_termination_deadline(first_termination_retry_at: datetime.datetime) -> datetime.datetime: + return first_termination_retry_at + TERMINATION_RETRY_MAX_DURATION + + +def ssh_keys_to_pkeys(ssh_keys: list[SSHKey]) -> list[PKey]: + return [pkey_from_str(ssh_key.private) for ssh_key in ssh_keys if ssh_key.private is not None] + + +def 
set_status_update( + update_map: InstanceUpdateMap, + instance_model: InstanceModel, + new_status: InstanceStatus, + termination_reason: Union[Optional[InstanceTerminationReason], Unset] = UNSET, + termination_reason_message: Union[Optional[str], Unset] = UNSET, +) -> bool: + old_status = instance_model.status + changed = False + if old_status == new_status: + if not isinstance(termination_reason, Unset): + update_map["termination_reason"] = termination_reason + changed = True + if not isinstance(termination_reason_message, Unset): + update_map["termination_reason_message"] = termination_reason_message + changed = True + return changed + + effective_termination_reason = instance_model.termination_reason + if not isinstance(termination_reason, Unset): + effective_termination_reason = termination_reason + update_map["termination_reason"] = effective_termination_reason + changed = True + + effective_termination_reason_message = instance_model.termination_reason_message + if not isinstance(termination_reason_message, Unset): + effective_termination_reason_message = termination_reason_message + update_map["termination_reason_message"] = effective_termination_reason_message + changed = True + + update_map["status"] = new_status + changed = True + return changed + + +def set_health_update( + update_map: InstanceUpdateMap, + instance_model: InstanceModel, + health: HealthStatus, +) -> bool: + if instance_model.health == health: + return False + update_map["health"] = health + return True + + +def set_unreachable_update( + update_map: InstanceUpdateMap, + instance_model: InstanceModel, + unreachable: bool, +) -> bool: + if not instance_model.status.is_available() or instance_model.unreachable == unreachable: + return False + update_map["unreachable"] = unreachable + return True diff --git a/src/dstack/_internal/server/background/pipeline_tasks/instances/ssh_deploy.py b/src/dstack/_internal/server/background/pipeline_tasks/instances/ssh_deploy.py new file mode 100644 index 0000000000..b4e3e1122a --- /dev/null +++ b/src/dstack/_internal/server/background/pipeline_tasks/instances/ssh_deploy.py @@ -0,0 +1,302 @@ +import asyncio +from datetime import timedelta +from typing import Any, Optional + +from paramiko.pkey import PKey +from paramiko.ssh_exception import PasswordRequiredException +from pydantic import ValidationError + +from dstack._internal import settings +from dstack._internal.core.backends.base.compute import ( + GoArchType, + get_dstack_runner_binary_path, + get_dstack_shim_binary_path, + get_dstack_working_dir, + get_shim_env, + get_shim_pre_start_commands, +) +from dstack._internal.core.errors import SSHProvisioningError +from dstack._internal.core.models.backends.base import BackendType +from dstack._internal.core.models.instances import ( + InstanceAvailability, + InstanceOfferWithAvailability, + InstanceRuntime, + InstanceStatus, + InstanceTerminationReason, + RemoteConnectionInfo, +) +from dstack._internal.core.models.runs import JobProvisioningData +from dstack._internal.server.background.pipeline_tasks.base import NOW_PLACEHOLDER +from dstack._internal.server.background.pipeline_tasks.instances.common import ( + PROVISIONING_TIMEOUT_SECONDS, + ProcessResult, + set_status_update, + ssh_keys_to_pkeys, +) +from dstack._internal.server.models import InstanceModel +from dstack._internal.server.schemas.instances import InstanceCheck +from dstack._internal.server.schemas.runner import HealthcheckResponse +from dstack._internal.server.services.instances import 
get_instance_remote_connection_info +from dstack._internal.server.services.logging import fmt +from dstack._internal.server.services.offers import is_divisible_into_blocks +from dstack._internal.server.services.runner import client as runner_client +from dstack._internal.server.services.ssh_fleets.provisioning import ( + detect_cpu_arch, + get_host_info, + get_paramiko_connection, + get_shim_healthcheck, + host_info_to_instance_type, + remove_dstack_runner_if_exists, + remove_host_info_if_exists, + run_pre_start_commands, + run_shim_as_systemd_service, + upload_envs, +) +from dstack._internal.utils.common import get_current_datetime, run_async +from dstack._internal.utils.logging import get_logger +from dstack._internal.utils.network import get_ip_from_network, is_ip_among_addresses + +logger = get_logger(__name__) + + +async def add_ssh_instance(instance_model: InstanceModel) -> ProcessResult: + result = ProcessResult() + logger.info("Adding ssh instance %s...", instance_model.name) + + retry_duration_deadline = instance_model.created_at + timedelta( + seconds=PROVISIONING_TIMEOUT_SECONDS + ) + if retry_duration_deadline < get_current_datetime(): + set_status_update( + update_map=result.instance_update_map, + instance_model=instance_model, + new_status=InstanceStatus.TERMINATED, + termination_reason=InstanceTerminationReason.PROVISIONING_TIMEOUT, + termination_reason_message=( + f"Failed to add SSH instance in {PROVISIONING_TIMEOUT_SECONDS}s" + ), + ) + return result + + remote_details = get_instance_remote_connection_info(instance_model) + assert remote_details is not None + + try: + pkeys = ssh_keys_to_pkeys(remote_details.ssh_keys) + ssh_proxy_pkeys = None + if remote_details.ssh_proxy_keys is not None: + ssh_proxy_pkeys = ssh_keys_to_pkeys(remote_details.ssh_proxy_keys) + except (ValueError, PasswordRequiredException): + set_status_update( + update_map=result.instance_update_map, + instance_model=instance_model, + new_status=InstanceStatus.TERMINATED, + termination_reason=InstanceTerminationReason.ERROR, + termination_reason_message="Unsupported private SSH key type", + ) + return result + + authorized_keys = [pkey.public.strip() for pkey in remote_details.ssh_keys] + authorized_keys.append(instance_model.project.ssh_public_key.strip()) + + try: + future = run_async( + _deploy_instance, + remote_details, + pkeys, + ssh_proxy_pkeys, + authorized_keys, + ) + health, host_info, arch = await asyncio.wait_for(future, timeout=20 * 60) + except (asyncio.TimeoutError, TimeoutError) as exc: + logger.warning( + "%s: deploy timeout when adding SSH instance: %s", + fmt(instance_model), + repr(exc), + ) + return result + except SSHProvisioningError as exc: + logger.warning( + "%s: provisioning error when adding SSH instance: %s", + fmt(instance_model), + repr(exc), + ) + return result + except Exception: + logger.exception("%s: unexpected error when adding SSH instance", fmt(instance_model)) + set_status_update( + update_map=result.instance_update_map, + instance_model=instance_model, + new_status=InstanceStatus.TERMINATED, + termination_reason=InstanceTerminationReason.ERROR, + termination_reason_message="Unexpected error when adding SSH instance", + ) + return result + + instance_type = host_info_to_instance_type(host_info, arch) + try: + instance_network, internal_ip = _resolve_ssh_instance_network(instance_model, host_info) + except _SSHInstanceNetworkResolutionError as exc: + set_status_update( + update_map=result.instance_update_map, + instance_model=instance_model, + 
new_status=InstanceStatus.TERMINATED, + termination_reason=InstanceTerminationReason.ERROR, + termination_reason_message=str(exc), + ) + return result + + divisible, blocks = is_divisible_into_blocks( + cpu_count=instance_type.resources.cpus, + gpu_count=len(instance_type.resources.gpus), + blocks="auto" if instance_model.total_blocks is None else instance_model.total_blocks, + ) + if not divisible: + set_status_update( + update_map=result.instance_update_map, + instance_model=instance_model, + new_status=InstanceStatus.TERMINATED, + termination_reason=InstanceTerminationReason.ERROR, + termination_reason_message="Cannot split into blocks", + ) + return result + + region = instance_model.region + assert region is not None + job_provisioning_data = JobProvisioningData( + backend=BackendType.REMOTE, + instance_type=instance_type, + instance_id="instance_id", + hostname=remote_details.host, + region=region, + price=0, + internal_ip=internal_ip, + instance_network=instance_network, + username=remote_details.ssh_user, + ssh_port=remote_details.port, + dockerized=True, + backend_data=None, + ssh_proxy=remote_details.ssh_proxy, + ) + instance_offer = InstanceOfferWithAvailability( + backend=BackendType.REMOTE, + instance=instance_type, + region=region, + price=0, + availability=InstanceAvailability.AVAILABLE, + instance_runtime=InstanceRuntime.SHIM, + ) + + set_status_update( + update_map=result.instance_update_map, + instance_model=instance_model, + new_status=InstanceStatus.IDLE if health else InstanceStatus.PROVISIONING, + ) + result.instance_update_map["backend"] = BackendType.REMOTE + result.instance_update_map["price"] = 0 + result.instance_update_map["offer"] = instance_offer.json() + result.instance_update_map["job_provisioning_data"] = job_provisioning_data.json() + result.instance_update_map["started_at"] = NOW_PLACEHOLDER + result.instance_update_map["total_blocks"] = blocks + return result + + +class _SSHInstanceNetworkResolutionError(Exception): + pass + + +def _resolve_ssh_instance_network( + instance_model: InstanceModel, + host_info: dict[str, Any], +) -> tuple[Optional[str], Optional[str]]: + instance_network = None + internal_ip = None + try: + default_job_provisioning_data = JobProvisioningData.__response__.parse_raw( + instance_model.job_provisioning_data + ) + instance_network = default_job_provisioning_data.instance_network + internal_ip = default_job_provisioning_data.internal_ip + except ValidationError: + pass + + host_network_addresses = host_info.get("addresses", []) + if internal_ip is None: + internal_ip = get_ip_from_network( + network=instance_network, + addresses=host_network_addresses, + ) + if instance_network is not None and internal_ip is None: + raise _SSHInstanceNetworkResolutionError( + "Failed to locate internal IP address on the given network" + ) + if internal_ip is not None and not is_ip_among_addresses( + ip_address=internal_ip, + addresses=host_network_addresses, + ): + raise _SSHInstanceNetworkResolutionError( + "Specified internal IP not found among instance interfaces" + ) + return instance_network, internal_ip + + +def _deploy_instance( + remote_details: RemoteConnectionInfo, + pkeys: list[PKey], + ssh_proxy_pkeys: Optional[list[PKey]], + authorized_keys: list[str], +) -> tuple[InstanceCheck, dict[str, Any], GoArchType]: + with get_paramiko_connection( + remote_details.ssh_user, + remote_details.host, + remote_details.port, + pkeys, + remote_details.ssh_proxy, + ssh_proxy_pkeys, + ) as client: + logger.debug("Connected to %s %s", 
remote_details.ssh_user, remote_details.host) + + arch = detect_cpu_arch(client) + logger.debug("%s: CPU arch is %s", remote_details.host, arch) + + # Execute pre start commands + shim_pre_start_commands = get_shim_pre_start_commands(arch=arch) + run_pre_start_commands(client, shim_pre_start_commands, authorized_keys) + logger.debug("The script for installing dstack has been executed") + + # Upload envs + shim_envs = get_shim_env(arch=arch) + try: + fleet_configuration_envs = remote_details.env.as_dict() + except ValueError as exc: + raise SSHProvisioningError(f"Invalid Env: {exc}") from exc + shim_envs.update(fleet_configuration_envs) + dstack_working_dir = get_dstack_working_dir() + dstack_shim_binary_path = get_dstack_shim_binary_path() + dstack_runner_binary_path = get_dstack_runner_binary_path() + upload_envs(client, dstack_working_dir, shim_envs) + logger.debug("The dstack-shim environment variables have been installed") + + # Ensure we have fresh versions of host info.json and dstack-runner + remove_host_info_if_exists(client, dstack_working_dir) + remove_dstack_runner_if_exists(client, dstack_runner_binary_path) + + # Run dstack-shim as a systemd service + run_shim_as_systemd_service( + client=client, + binary_path=dstack_shim_binary_path, + working_dir=dstack_working_dir, + dev=settings.DSTACK_VERSION is None, + ) + + # Get host info + host_info = get_host_info(client, dstack_working_dir) + logger.debug("Received a host_info %s", host_info) + + healthcheck_out = get_shim_healthcheck(client) + try: + healthcheck = HealthcheckResponse.__response__.parse_raw(healthcheck_out) + except ValueError as exc: + raise SSHProvisioningError(f"Cannot parse HealthcheckResponse: {exc}") from exc + instance_check = runner_client.healthcheck_response_to_instance_check(healthcheck) + return instance_check, host_info, arch diff --git a/src/dstack/_internal/server/background/pipeline_tasks/instances/termination.py b/src/dstack/_internal/server/background/pipeline_tasks/instances/termination.py new file mode 100644 index 0000000000..eb1f3c8a39 --- /dev/null +++ b/src/dstack/_internal/server/background/pipeline_tasks/instances/termination.py @@ -0,0 +1,88 @@ +from dstack._internal.core.errors import BackendError, NotYetTerminated +from dstack._internal.core.models.backends.base import BackendType +from dstack._internal.core.models.instances import InstanceStatus +from dstack._internal.server.background.pipeline_tasks.base import NOW_PLACEHOLDER +from dstack._internal.server.background.pipeline_tasks.instances.common import ( + ProcessResult, + get_termination_deadline, + next_termination_retry_at, + set_status_update, +) +from dstack._internal.server.models import InstanceModel +from dstack._internal.server.services import backends as backends_services +from dstack._internal.server.services.instances import get_instance_provisioning_data +from dstack._internal.utils.common import get_current_datetime, run_async +from dstack._internal.utils.logging import get_logger + +logger = get_logger(__name__) + + +async def terminate_instance(instance_model: InstanceModel) -> ProcessResult: + result = ProcessResult() + now = get_current_datetime() + if ( + instance_model.last_termination_retry_at is not None + and next_termination_retry_at(instance_model.last_termination_retry_at) > now + ): + return result + + job_provisioning_data = get_instance_provisioning_data(instance_model) + if job_provisioning_data is not None and job_provisioning_data.backend != BackendType.REMOTE: + backend = await 
backends_services.get_project_backend_by_type( + project=instance_model.project, + backend_type=job_provisioning_data.backend, + ) + if backend is None: + logger.error( + "Failed to terminate instance %s. Backend %s not available.", + instance_model.name, + job_provisioning_data.backend, + ) + else: + logger.debug("Terminating runner instance %s", job_provisioning_data.hostname) + try: + await run_async( + backend.compute().terminate_instance, + job_provisioning_data.instance_id, + job_provisioning_data.region, + job_provisioning_data.backend_data, + ) + except Exception as exc: + first_retry_at = instance_model.first_termination_retry_at + if first_retry_at is None: + first_retry_at = now + result.instance_update_map["first_termination_retry_at"] = NOW_PLACEHOLDER + result.instance_update_map["last_termination_retry_at"] = NOW_PLACEHOLDER + if next_termination_retry_at(now) < get_termination_deadline(first_retry_at): + if isinstance(exc, NotYetTerminated): + logger.debug( + "Instance %s termination in progress: %s", + instance_model.name, + exc, + ) + else: + logger.warning( + "Failed to terminate instance %s. Will retry. Error: %r", + instance_model.name, + exc, + exc_info=not isinstance(exc, BackendError), + ) + return result + logger.error( + "Failed all attempts to terminate instance %s." + " Please terminate the instance manually to avoid unexpected charges." + " Error: %r", + instance_model.name, + exc, + exc_info=not isinstance(exc, BackendError), + ) + + result.instance_update_map["deleted"] = True + result.instance_update_map["deleted_at"] = NOW_PLACEHOLDER + result.instance_update_map["finished_at"] = NOW_PLACEHOLDER + set_status_update( + update_map=result.instance_update_map, + instance_model=instance_model, + new_status=InstanceStatus.TERMINATED, + ) + return result diff --git a/src/dstack/_internal/server/background/pipeline_tasks/placement_groups.py b/src/dstack/_internal/server/background/pipeline_tasks/placement_groups.py index 703cfe1548..552ae00dc8 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/placement_groups.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/placement_groups.py @@ -18,6 +18,8 @@ PipelineItem, UpdateMapDateTime, Worker, + log_lock_token_changed_after_processing, + log_lock_token_mismatch, resolve_now_placeholders, set_processed_update_map_fields, set_unlock_update_map_fields, @@ -189,12 +191,7 @@ async def process(self, item: PipelineItem): ) placement_group_model = res.unique().scalar_one_or_none() if placement_group_model is None: - logger.warning( - "Failed to process %s item %s: lock_token mismatch." - " The item is expected to be processed and updated on another fetch iteration.", - item.__tablename__, - item.id, - ) + log_lock_token_mismatch(logger, item) return result = await _delete_placement_group(placement_group_model) @@ -217,12 +214,7 @@ async def process(self, item: PipelineItem): ) updated_ids = list(res.scalars().all()) if len(updated_ids) == 0: - logger.warning( - "Failed to update %s item %s after processing: lock_token changed." 
- " The item is expected to be processed and updated on another fetch iteration.", - item.__tablename__, - item.id, - ) + log_lock_token_changed_after_processing(logger, item) class _PlacementGroupUpdateMap(ItemUpdateMap, total=False): diff --git a/src/dstack/_internal/server/background/pipeline_tasks/volumes.py b/src/dstack/_internal/server/background/pipeline_tasks/volumes.py index c7a8f5761a..81d94c361b 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/volumes.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/volumes.py @@ -19,6 +19,8 @@ PipelineItem, UpdateMapDateTime, Worker, + log_lock_token_changed_after_processing, + log_lock_token_mismatch, resolve_now_placeholders, set_processed_update_map_fields, set_unlock_update_map_fields, @@ -204,8 +206,6 @@ async def process(self, item: VolumePipelineItem): await _process_to_be_deleted_item(item) elif item.status == VolumeStatus.SUBMITTED: await _process_submitted_item(item) - elif item.status == VolumeStatus.ACTIVE: - pass async def _process_submitted_item(item: VolumePipelineItem): @@ -227,12 +227,7 @@ async def _process_submitted_item(item: VolumePipelineItem): ) volume_model = res.unique().scalar_one_or_none() if volume_model is None: - logger.warning( - "Failed to process %s item %s: lock_token mismatch." - " The item is expected to be processed and updated on another fetch iteration.", - item.__tablename__, - item.id, - ) + log_lock_token_mismatch(logger, item) return result = await _process_submitted_volume(volume_model) @@ -253,12 +248,7 @@ async def _process_submitted_item(item: VolumePipelineItem): ) updated_ids = list(res.scalars().all()) if len(updated_ids) == 0: - logger.warning( - "Failed to update %s item %s after processing: lock_token changed." - " The item is expected to be processed and updated on another fetch iteration.", - item.__tablename__, - item.id, - ) + log_lock_token_changed_after_processing(logger, item) # TODO: Clean up volume. return emit_volume_status_change_event( @@ -369,12 +359,7 @@ async def _process_to_be_deleted_item(item: VolumePipelineItem): ) volume_model = res.unique().scalar_one_or_none() if volume_model is None: - logger.warning( - "Failed to process %s item %s: lock_token mismatch." - " The item is expected to be processed and updated on another fetch iteration.", - item.__tablename__, - item.id, - ) + log_lock_token_mismatch(logger, item) return result = await _process_to_be_deleted_volume(volume_model) @@ -396,12 +381,7 @@ async def _process_to_be_deleted_item(item: VolumePipelineItem): ) updated_ids = list(res.scalars().all()) if len(updated_ids) == 0: - logger.warning( - "Failed to update %s item %s after processing: lock_token changed." 
- " The item is expected to be processed and updated on another fetch iteration.", - item.__tablename__, - item.id, - ) + log_lock_token_changed_after_processing(logger, item) return events.emit( session, diff --git a/src/dstack/_internal/server/background/scheduled_tasks/__init__.py b/src/dstack/_internal/server/background/scheduled_tasks/__init__.py index 9c7cd6ac1a..2994fca37c 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/__init__.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/__init__.py @@ -14,8 +14,10 @@ from dstack._internal.server.background.scheduled_tasks.idle_volumes import ( process_idle_volumes, ) +from dstack._internal.server.background.scheduled_tasks.instance_healthchecks import ( + delete_instance_healthchecks, +) from dstack._internal.server.background.scheduled_tasks.instances import ( - delete_instance_health_checks, process_instances, ) from dstack._internal.server.background.scheduled_tasks.metrics import ( @@ -93,16 +95,16 @@ def start_scheduled_tasks() -> AsyncIOScheduler: _scheduler.add_job(collect_metrics, IntervalTrigger(seconds=10), max_instances=1) _scheduler.add_job(delete_metrics, IntervalTrigger(minutes=5), max_instances=1) _scheduler.add_job(delete_events, IntervalTrigger(minutes=7), max_instances=1) + _scheduler.add_job(process_gateways_connections, IntervalTrigger(seconds=15)) + _scheduler.add_job( + process_idle_volumes, IntervalTrigger(seconds=60, jitter=10), max_instances=1 + ) + _scheduler.add_job(delete_instance_healthchecks, IntervalTrigger(minutes=5), max_instances=1) if settings.ENABLE_PROMETHEUS_METRICS: _scheduler.add_job( collect_prometheus_metrics, IntervalTrigger(seconds=10), max_instances=1 ) _scheduler.add_job(delete_prometheus_metrics, IntervalTrigger(minutes=5), max_instances=1) - _scheduler.add_job(process_gateways_connections, IntervalTrigger(seconds=15)) - _scheduler.add_job( - process_idle_volumes, IntervalTrigger(seconds=60, jitter=10), max_instances=1 - ) - _scheduler.add_job(delete_instance_health_checks, IntervalTrigger(minutes=5), max_instances=1) if not FeatureFlags.PIPELINE_PROCESSING_ENABLED: _scheduler.add_job( process_fleets, @@ -144,13 +146,13 @@ def start_scheduled_tasks() -> AsyncIOScheduler: kwargs={"batch_size": 5}, max_instances=2 if replica == 0 else 1, ) - _scheduler.add_job( - process_instances, - IntervalTrigger(seconds=4, jitter=2), - kwargs={"batch_size": 5}, - max_instances=2 if replica == 0 else 1, - ) if not FeatureFlags.PIPELINE_PROCESSING_ENABLED: + _scheduler.add_job( + process_instances, + IntervalTrigger(seconds=4, jitter=2), + kwargs={"batch_size": 5}, + max_instances=2 if replica == 0 else 1, + ) _scheduler.add_job( process_compute_groups, IntervalTrigger(seconds=15, jitter=2), diff --git a/src/dstack/_internal/server/background/scheduled_tasks/instance_healthchecks.py b/src/dstack/_internal/server/background/scheduled_tasks/instance_healthchecks.py new file mode 100644 index 0000000000..41e83c71aa --- /dev/null +++ b/src/dstack/_internal/server/background/scheduled_tasks/instance_healthchecks.py @@ -0,0 +1,20 @@ +from datetime import timedelta + +from sqlalchemy import delete + +from dstack._internal.server import settings +from dstack._internal.server.db import get_session_ctx +from dstack._internal.server.models import InstanceHealthCheckModel +from dstack._internal.server.utils import sentry_utils +from dstack._internal.utils.common import get_current_datetime + + +@sentry_utils.instrument_scheduled_task +async def delete_instance_healthchecks(): + now = 
get_current_datetime() + cutoff = now - timedelta(seconds=settings.SERVER_INSTANCE_HEALTH_TTL_SECONDS) + async with get_session_ctx() as session: + await session.execute( + delete(InstanceHealthCheckModel).where(InstanceHealthCheckModel.collected_at < cutoff) + ) + await session.commit() diff --git a/src/dstack/_internal/server/background/scheduled_tasks/instances.py b/src/dstack/_internal/server/background/scheduled_tasks/instances.py index e5ecba5278..1857e0ad09 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/instances.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/instances.py @@ -9,7 +9,7 @@ from paramiko.pkey import PKey from paramiko.ssh_exception import PasswordRequiredException from pydantic import ValidationError -from sqlalchemy import and_, delete, func, not_, select +from sqlalchemy import and_, func, not_, select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import joinedload @@ -33,12 +33,11 @@ BACKENDS_WITH_PLACEMENT_GROUPS_SUPPORT, ) from dstack._internal.core.consts import DSTACK_SHIM_HTTP_PORT - -# FIXME: ProvisioningError is a subclass of ComputeError and should not be used outside of Compute from dstack._internal.core.errors import ( BackendError, NotYetTerminated, ProvisioningError, + SSHProvisioningError, ) from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.fleets import InstanceGroupPlacement @@ -108,8 +107,7 @@ ) from dstack._internal.server.services.runner import client as runner_client from dstack._internal.server.services.runner.ssh import runner_ssh_tunnel -from dstack._internal.server.utils import sentry_utils -from dstack._internal.server.utils.provisioning import ( +from dstack._internal.server.services.ssh_fleets.provisioning import ( detect_cpu_arch, get_host_info, get_paramiko_connection, @@ -121,6 +119,7 @@ run_shim_as_systemd_service, upload_envs, ) +from dstack._internal.server.utils import sentry_utils from dstack._internal.utils.common import ( get_current_datetime, get_or_error, @@ -152,17 +151,6 @@ async def process_instances(batch_size: int = 1): await asyncio.gather(*tasks) -@sentry_utils.instrument_scheduled_task -async def delete_instance_health_checks(): - now = get_current_datetime() - cutoff = now - timedelta(seconds=server_settings.SERVER_INSTANCE_HEALTH_TTL_SECONDS) - async with get_session_ctx() as session: - await session.execute( - delete(InstanceHealthCheckModel).where(InstanceHealthCheckModel.collected_at < cutoff) - ) - await session.commit() - - @sentry_utils.instrument_scheduled_task async def _process_next_instance(): lock, lockset = get_locker(get_db().dialect_name).get_lockset(InstanceModel.__tablename__) @@ -211,63 +199,81 @@ async def _process_next_instance(): async def _process_instance(session: AsyncSession, instance: InstanceModel): logger.debug("%s: processing instance, status: %s", fmt(instance), instance.status.upper()) - # Refetch to load related attributes. - # Load related attributes only for statuses that always need them. 
- if instance.status in ( - InstanceStatus.PENDING, - InstanceStatus.TERMINATING, - ): - res = await session.execute( - select(InstanceModel) - .where(InstanceModel.id == instance.id) - .options(joinedload(InstanceModel.project).joinedload(ProjectModel.backends)) - .options(joinedload(InstanceModel.jobs).load_only(JobModel.id, JobModel.status)) - .options( - joinedload(InstanceModel.fleet).joinedload( - FleetModel.instances.and_(InstanceModel.deleted == False) - ), - ) - .execution_options(populate_existing=True) - ) - instance = res.unique().scalar_one() + if instance.status == InstanceStatus.PENDING: + await _process_pending_instance(session, instance) + elif instance.status == InstanceStatus.PROVISIONING: + await _process_provisioning_instance(session, instance) elif instance.status == InstanceStatus.IDLE: - res = await session.execute( - select(InstanceModel) - .where(InstanceModel.id == instance.id) - .options(joinedload(InstanceModel.project)) - .options(joinedload(InstanceModel.jobs).load_only(JobModel.id, JobModel.status)) - .options( - joinedload(InstanceModel.fleet).joinedload( - FleetModel.instances.and_(InstanceModel.deleted == False) - ), + await _process_idle_instance(session, instance) + elif instance.status == InstanceStatus.BUSY: + await _process_busy_instance(session, instance) + elif instance.status == InstanceStatus.TERMINATING: + await _process_terminating_instance(session, instance) + + instance.last_processed_at = get_current_datetime() + await session.commit() + + +async def _process_pending_instance(session: AsyncSession, instance: InstanceModel): + instance = await _refetch_instance_for_pending_or_terminating(session, instance.id) + if is_ssh_instance(instance): + await _add_remote(session, instance) + else: + await _create_instance(session=session, instance=instance) + + +async def _process_provisioning_instance(session: AsyncSession, instance: InstanceModel): + await _check_instance(session, instance) + + +async def _process_idle_instance(session: AsyncSession, instance: InstanceModel): + instance = await _refetch_instance_for_idle(session, instance.id) + idle_duration_expired = _check_and_mark_terminating_if_idle_duration_expired(session, instance) + if not idle_duration_expired: + await _check_instance(session, instance) + + +async def _process_busy_instance(session: AsyncSession, instance: InstanceModel): + await _check_instance(session, instance) + + +async def _process_terminating_instance(session: AsyncSession, instance: InstanceModel): + instance = await _refetch_instance_for_pending_or_terminating(session, instance.id) + await _terminate(session, instance) + + +async def _refetch_instance_for_pending_or_terminating( + session: AsyncSession, instance_id +) -> InstanceModel: + res = await session.execute( + select(InstanceModel) + .where(InstanceModel.id == instance_id) + .options(joinedload(InstanceModel.project).joinedload(ProjectModel.backends)) + .options(joinedload(InstanceModel.jobs).load_only(JobModel.id, JobModel.status)) + .options( + joinedload(InstanceModel.fleet).joinedload( + FleetModel.instances.and_(InstanceModel.deleted == False) ) - .execution_options(populate_existing=True) ) - instance = res.unique().scalar_one() + .execution_options(populate_existing=True) + ) + return res.unique().scalar_one() - if instance.status == InstanceStatus.PENDING: - if is_ssh_instance(instance): - await _add_remote(session, instance) - else: - await _create_instance( - session=session, - instance=instance, + +async def _refetch_instance_for_idle(session: 
AsyncSession, instance_id) -> InstanceModel: + res = await session.execute( + select(InstanceModel) + .where(InstanceModel.id == instance_id) + .options(joinedload(InstanceModel.project)) + .options(joinedload(InstanceModel.jobs).load_only(JobModel.id, JobModel.status)) + .options( + joinedload(InstanceModel.fleet).joinedload( + FleetModel.instances.and_(InstanceModel.deleted == False) ) - elif instance.status in ( - InstanceStatus.PROVISIONING, - InstanceStatus.IDLE, - InstanceStatus.BUSY, - ): - idle_duration_expired = _check_and_mark_terminating_if_idle_duration_expired( - session, instance ) - if not idle_duration_expired: - await _check_instance(session, instance) - elif instance.status == InstanceStatus.TERMINATING: - await _terminate(session, instance) - - instance.last_processed_at = get_current_datetime() - await session.commit() + .execution_options(populate_existing=True) + ) + return res.unique().scalar_one() def _check_and_mark_terminating_if_idle_duration_expired( @@ -324,76 +330,61 @@ async def _add_remote(session: AsyncSession, instance: InstanceModel) -> None: switch_instance_status(session, instance, InstanceStatus.TERMINATED) return + remote_details = get_instance_remote_connection_info(instance) + assert remote_details is not None + try: - remote_details = get_instance_remote_connection_info(instance) - assert remote_details is not None - # Prepare connection key - try: - pkeys = _ssh_keys_to_pkeys(remote_details.ssh_keys) - if remote_details.ssh_proxy_keys is not None: - ssh_proxy_pkeys = _ssh_keys_to_pkeys(remote_details.ssh_proxy_keys) - else: - ssh_proxy_pkeys = None - except (ValueError, PasswordRequiredException): - instance.termination_reason = InstanceTerminationReason.ERROR - instance.termination_reason_message = "Unsupported private SSH key type" - switch_instance_status(session, instance, InstanceStatus.TERMINATED) - return - - authorized_keys = [pk.public.strip() for pk in remote_details.ssh_keys] - authorized_keys.append(instance.project.ssh_public_key.strip()) + pkeys = _ssh_keys_to_pkeys(remote_details.ssh_keys) + if remote_details.ssh_proxy_keys is not None: + ssh_proxy_pkeys = _ssh_keys_to_pkeys(remote_details.ssh_proxy_keys) + else: + ssh_proxy_pkeys = None + except (ValueError, PasswordRequiredException): + instance.termination_reason = InstanceTerminationReason.ERROR + instance.termination_reason_message = "Unsupported private SSH key type" + switch_instance_status(session, instance, InstanceStatus.TERMINATED) + return - try: - future = run_async( - _deploy_instance, remote_details, pkeys, ssh_proxy_pkeys, authorized_keys - ) - deploy_timeout = 20 * 60 # 20 minutes - result = await asyncio.wait_for(future, timeout=deploy_timeout) - health, host_info, arch = result - except (asyncio.TimeoutError, TimeoutError) as e: - raise ProvisioningError(f"Deploy timeout: {e}") from e - except Exception as e: - raise ProvisioningError(f"Deploy instance raised an error: {e}") from e - except ProvisioningError as e: + authorized_keys = [pk.public.strip() for pk in remote_details.ssh_keys] + authorized_keys.append(instance.project.ssh_public_key.strip()) + + try: + future = run_async( + _deploy_instance, remote_details, pkeys, ssh_proxy_pkeys, authorized_keys + ) + deploy_timeout = 20 * 60 # 20 minutes + health, host_info, arch = await asyncio.wait_for(future, timeout=deploy_timeout) + except (asyncio.TimeoutError, TimeoutError) as e: logger.warning( - "Provisioning instance %s could not be completed because of the error: %s", - instance.name, - e, + "%s: deploy 
timeout when adding SSH instance: %s", + fmt(instance), + repr(e), + ) + # Stays in PENDING, may retry later + return + except SSHProvisioningError as e: + logger.warning( + "%s: provisioning error when adding SSH instance: %s", + fmt(instance), + repr(e), ) # Stays in PENDING, may retry later return + except Exception: + logger.exception("%s: unexpected error when adding SSH instance", fmt(instance)) + instance.termination_reason = InstanceTerminationReason.ERROR + instance.termination_reason_message = "Unexpected error when adding SSH instance" + switch_instance_status(session, instance, InstanceStatus.TERMINATED) + return instance_type = host_info_to_instance_type(host_info, arch) - instance_network = None - internal_ip = None try: - default_jpd = JobProvisioningData.__response__.parse_raw(instance.job_provisioning_data) - instance_network = default_jpd.instance_network - internal_ip = default_jpd.internal_ip - except ValidationError: - pass - - host_network_addresses = host_info.get("addresses", []) - if internal_ip is None: - internal_ip = get_ip_from_network( - network=instance_network, - addresses=host_network_addresses, - ) - if instance_network is not None and internal_ip is None: + instance_network, internal_ip = _resolve_ssh_instance_network(instance, host_info) + except _SSHInstanceNetworkResolutionError as e: instance.termination_reason = InstanceTerminationReason.ERROR - instance.termination_reason_message = ( - "Failed to locate internal IP address on the given network" - ) + instance.termination_reason_message = str(e) switch_instance_status(session, instance, InstanceStatus.TERMINATED) return - if internal_ip is not None: - if not is_ip_among_addresses(ip_address=internal_ip, addresses=host_network_addresses): - instance.termination_reason = InstanceTerminationReason.ERROR - instance.termination_reason_message = ( - "Specified internal IP not found among instance interfaces" - ) - switch_instance_status(session, instance, InstanceStatus.TERMINATED) - return divisible, blocks = is_divisible_into_blocks( cpu_count=instance_type.resources.cpus, @@ -444,6 +435,41 @@ async def _add_remote(session: AsyncSession, instance: InstanceModel) -> None: instance.started_at = get_current_datetime() +class _SSHInstanceNetworkResolutionError(Exception): + pass + + +def _resolve_ssh_instance_network( + instance: InstanceModel, host_info: dict[str, Any] +) -> tuple[Optional[str], Optional[str]]: + instance_network = None + internal_ip = None + try: + default_jpd = JobProvisioningData.__response__.parse_raw(instance.job_provisioning_data) + instance_network = default_jpd.instance_network + internal_ip = default_jpd.internal_ip + except ValidationError: + pass + + host_network_addresses = host_info.get("addresses", []) + if internal_ip is None: + internal_ip = get_ip_from_network( + network=instance_network, + addresses=host_network_addresses, + ) + if instance_network is not None and internal_ip is None: + raise _SSHInstanceNetworkResolutionError( + "Failed to locate internal IP address on the given network" + ) + if internal_ip is not None and not is_ip_among_addresses( + ip_address=internal_ip, addresses=host_network_addresses + ): + raise _SSHInstanceNetworkResolutionError( + "Specified internal IP not found among instance interfaces" + ) + return instance_network, internal_ip + + def _deploy_instance( remote_details: RemoteConnectionInfo, pkeys: list[PKey], @@ -473,7 +499,7 @@ def _deploy_instance( try: fleet_configuration_envs = remote_details.env.as_dict() except ValueError as e: - 
raise ProvisioningError(f"Invalid Env: {e}") from e + raise SSHProvisioningError(f"Invalid Env: {e}") from e shim_envs.update(fleet_configuration_envs) dstack_working_dir = get_dstack_working_dir() dstack_shim_binary_path = get_dstack_shim_binary_path() @@ -501,7 +527,7 @@ def _deploy_instance( try: healthcheck = HealthcheckResponse.__response__.parse_raw(healthcheck_out) except ValueError as e: - raise ProvisioningError(f"Cannot parse HealthcheckResponse: {e}") from e + raise SSHProvisioningError(f"Cannot parse HealthcheckResponse: {e}") from e instance_check = runner_client.healthcheck_response_to_instance_check(healthcheck) return instance_check, host_info, arch @@ -646,6 +672,7 @@ async def _create_instance(session: AsyncSession, instance: InstanceModel) -> No if instance.fleet and instance.id == master_instance.id and is_cloud_cluster(instance.fleet): # Do not attempt to deploy other instances, as they won't determine the correct cluster # backend, region, and placement group without a successfully deployed master instance + # FIXME: Race condition with siblings processed concurrently. for sibling_instance in instance.fleet.instances: if sibling_instance.id == instance.id: continue @@ -707,50 +734,22 @@ async def _check_instance(session: AsyncSession, instance: InstanceModel) -> Non switch_instance_status(session, instance, InstanceStatus.BUSY) return - ssh_private_keys = get_instance_ssh_private_keys(instance) - - health_check_cutoff = get_current_datetime() - timedelta( - seconds=server_settings.SERVER_INSTANCE_HEALTH_MIN_COLLECT_INTERVAL_SECONDS - ) - res = await session.execute( - select(func.count(1)).where( - InstanceHealthCheckModel.instance_id == instance.id, - InstanceHealthCheckModel.collected_at > health_check_cutoff, - ) + check_instance_health = await _should_check_instance_health(session, instance) + instance_check = await _run_instance_check( + instance=instance, + job_provisioning_data=job_provisioning_data, + check_instance_health=check_instance_health, ) - check_instance_health = res.scalar_one() == 0 - - # May return False if fails to establish ssh connection - instance_check = await run_async( - _check_instance_inner, - ssh_private_keys, - job_provisioning_data, - None, + health_status = _get_health_status_for_instance_check( instance=instance, + instance_check=instance_check, check_instance_health=check_instance_health, ) - if instance_check is False: - instance_check = InstanceCheck(reachable=False, message="SSH or tunnel error") - - if instance_check.reachable and check_instance_health: - health_status = instance_check.get_health_status() - else: - # Keep previous health status - health_status = instance.health - - loglevel = logging.DEBUG - if not instance_check.reachable and instance.status.is_available(): - loglevel = logging.WARNING - elif check_instance_health and not health_status.is_healthy(): - loglevel = logging.WARNING - logger.log( - loglevel, - "Instance %s check: reachable=%s health_status=%s message=%r", - instance.name, - instance_check.reachable, - health_status.name, - instance_check.message, - extra={"instance_name": instance.name, "health_status": health_status}, + _log_instance_check_result( + instance=instance, + instance_check=instance_check, + health_status=health_status, + check_instance_health=check_instance_health, ) if instance_check.has_health_checks(): @@ -797,6 +796,73 @@ async def _check_instance(session: AsyncSession, instance: InstanceModel) -> Non switch_instance_status(session, instance, InstanceStatus.TERMINATING) +async def 
_should_check_instance_health(session: AsyncSession, instance: InstanceModel) -> bool: + health_check_cutoff = get_current_datetime() - timedelta( + seconds=server_settings.SERVER_INSTANCE_HEALTH_MIN_COLLECT_INTERVAL_SECONDS + ) + result = await session.execute( + select(func.count(1)).where( + InstanceHealthCheckModel.instance_id == instance.id, + InstanceHealthCheckModel.collected_at > health_check_cutoff, + ) + ) + return result.scalar_one() == 0 + + +async def _run_instance_check( + instance: InstanceModel, + job_provisioning_data: JobProvisioningData, + check_instance_health: bool, +) -> InstanceCheck: + ssh_private_keys = get_instance_ssh_private_keys(instance) + + # May return False if fails to establish ssh connection + instance_check = await run_async( + _check_instance_inner, + ssh_private_keys, + job_provisioning_data, + None, + instance=instance, + check_instance_health=check_instance_health, + ) + if instance_check is False: + return InstanceCheck(reachable=False, message="SSH or tunnel error") + return instance_check + + +def _get_health_status_for_instance_check( + instance: InstanceModel, + instance_check: InstanceCheck, + check_instance_health: bool, +) -> HealthStatus: + if instance_check.reachable and check_instance_health: + return instance_check.get_health_status() + # Keep previous health status + return instance.health + + +def _log_instance_check_result( + instance: InstanceModel, + instance_check: InstanceCheck, + health_status: HealthStatus, + check_instance_health: bool, +) -> None: + loglevel = logging.DEBUG + if not instance_check.reachable and instance.status.is_available(): + loglevel = logging.WARNING + elif check_instance_health and not health_status.is_healthy(): + loglevel = logging.WARNING + logger.log( + loglevel, + "Instance %s check: reachable=%s health_status=%s message=%r", + instance.name, + instance_check.reachable, + health_status.name, + instance_check.message, + extra={"instance_name": instance.name, "health_status": health_status}, + ) + + async def _wait_for_instance_provisioning_data( session: AsyncSession, project: ProjectModel, @@ -1134,7 +1200,8 @@ def _get_termination_deadline(instance: InstanceModel) -> datetime.datetime: def _need_to_wait_fleet_provisioning( - instance: InstanceModel, master_instance: InstanceModel + instance: InstanceModel, + master_instance: InstanceModel, ) -> bool: # Cluster cloud instances should wait for the first fleet instance to be provisioned # so that they are provisioned in the same backend/region diff --git a/src/dstack/_internal/server/background/scheduled_tasks/submitted_jobs.py b/src/dstack/_internal/server/background/scheduled_tasks/submitted_jobs.py index 151f07deeb..729ded205c 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/submitted_jobs.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/submitted_jobs.py @@ -574,6 +574,10 @@ async def _fetch_fleet_with_master_instance_provisioning_data( fleet_model: Optional[FleetModel], job: Job, ) -> Optional[JobProvisioningData]: + # TODO: When submitted-jobs provisioning moves to pipelines, stop inferring the + # cluster master from loaded fleet instances here. Resolve the current master via + # FleetModel.current_master_instance_id so jobs follow the same master election + # as FleetPipeline/InstancePipeline. 
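    # An illustrative sketch of that future lookup (not part of this change; the helper
    # and local names below are assumptions):
    #
    #   master_id = fleet_model.current_master_instance_id
    #   master_instance = next((i for i in fleet_model.instances if i.id == master_id), None)
    #   if master_instance is not None:
    #       master_instance_provisioning_data = get_instance_provisioning_data(master_instance)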
master_instance_provisioning_data = None if is_master_job(job) and fleet_model is not None: fleet = fleet_model_to_fleet(fleet_model) diff --git a/src/dstack/_internal/server/background/scheduled_tasks/terminating_jobs.py b/src/dstack/_internal/server/background/scheduled_tasks/terminating_jobs.py index 3749076c1a..27163b53d9 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/terminating_jobs.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/terminating_jobs.py @@ -66,6 +66,7 @@ async def _process_next_terminating_job(): .where( InstanceModel.id == job_model.used_instance_id, InstanceModel.id.not_in(instance_lockset), + InstanceModel.lock_expires_at.is_(None), ) .with_for_update(skip_locked=True, key_share=True) ) diff --git a/src/dstack/_internal/server/migrations/versions/2026/03_05_0547_8e8647f20aa4_add_instancemodel_pipeline_columns.py b/src/dstack/_internal/server/migrations/versions/2026/03_05_0547_8e8647f20aa4_add_instancemodel_pipeline_columns.py new file mode 100644 index 0000000000..f1c2b1217a --- /dev/null +++ b/src/dstack/_internal/server/migrations/versions/2026/03_05_0547_8e8647f20aa4_add_instancemodel_pipeline_columns.py @@ -0,0 +1,47 @@ +"""Add InstanceModel pipeline columns + +Revision ID: 8e8647f20aa4 +Revises: 5e8c7a9202bc +Create Date: 2026-03-05 05:47:39.307013+00:00 + +""" + +import sqlalchemy as sa +import sqlalchemy_utils +from alembic import op + +import dstack._internal.server.models + +# revision identifiers, used by Alembic. +revision = "8e8647f20aa4" +down_revision = "5e8c7a9202bc" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("instances", schema=None) as batch_op: + batch_op.add_column( + sa.Column( + "lock_expires_at", dstack._internal.server.models.NaiveDateTime(), nullable=True + ) + ) + batch_op.add_column( + sa.Column( + "lock_token", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=True + ) + ) + batch_op.add_column(sa.Column("lock_owner", sa.String(length=100), nullable=True)) + + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("instances", schema=None) as batch_op: + batch_op.drop_column("lock_owner") + batch_op.drop_column("lock_token") + batch_op.drop_column("lock_expires_at") + + # ### end Alembic commands ### diff --git a/src/dstack/_internal/server/migrations/versions/2026/03_05_0751_297c68450cc8_add_ix_instances_pipeline_fetch_q_index.py b/src/dstack/_internal/server/migrations/versions/2026/03_05_0751_297c68450cc8_add_ix_instances_pipeline_fetch_q_index.py new file mode 100644 index 0000000000..e629de0950 --- /dev/null +++ b/src/dstack/_internal/server/migrations/versions/2026/03_05_0751_297c68450cc8_add_ix_instances_pipeline_fetch_q_index.py @@ -0,0 +1,49 @@ +"""Add ix_instances_pipeline_fetch_q index + +Revision ID: 297c68450cc8 +Revises: 8e8647f20aa4 +Create Date: 2026-03-05 07:51:02.855596+00:00 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "297c68450cc8" +down_revision = "8e8647f20aa4" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + with op.get_context().autocommit_block(): + op.drop_index( + "ix_instances_pipeline_fetch_q", + table_name="instances", + if_exists=True, + postgresql_concurrently=True, + ) + op.create_index( + "ix_instances_pipeline_fetch_q", + "instances", + [sa.literal_column("last_processed_at ASC")], + unique=False, + sqlite_where=sa.text("deleted = 0"), + postgresql_where=sa.text("deleted IS FALSE"), + postgresql_concurrently=True, + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.get_context().autocommit_block(): + op.drop_index( + "ix_instances_pipeline_fetch_q", + table_name="instances", + if_exists=True, + postgresql_concurrently=True, + ) + # ### end Alembic commands ### diff --git a/src/dstack/_internal/server/migrations/versions/2026/03_05_1015_9cb8e4e4d986_add_fleet_current_master_instance.py b/src/dstack/_internal/server/migrations/versions/2026/03_05_1015_9cb8e4e4d986_add_fleet_current_master_instance.py new file mode 100644 index 0000000000..2049236267 --- /dev/null +++ b/src/dstack/_internal/server/migrations/versions/2026/03_05_1015_9cb8e4e4d986_add_fleet_current_master_instance.py @@ -0,0 +1,37 @@ +"""Add FleetModel current master instance + +Revision ID: 9cb8e4e4d986 +Revises: 297c68450cc8 +Create Date: 2026-03-05 10:15:00.000000+00:00 + +""" + +import sqlalchemy as sa +import sqlalchemy_utils +from alembic import op + +# revision identifiers, used by Alembic. +revision = "9cb8e4e4d986" +down_revision = "297c68450cc8" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("fleets", schema=None) as batch_op: + batch_op.add_column( + sa.Column( + "current_master_instance_id", + sqlalchemy_utils.types.uuid.UUIDType(binary=False), + nullable=True, + ) + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("fleets", schema=None) as batch_op: + batch_op.drop_column("current_master_instance_id") + # ### end Alembic commands ### diff --git a/src/dstack/_internal/server/migrations/versions/2026/03_05_1045_c7b0a8e57294_add_ix_fleets_current_master_instance_id.py b/src/dstack/_internal/server/migrations/versions/2026/03_05_1045_c7b0a8e57294_add_ix_fleets_current_master_instance_id.py new file mode 100644 index 0000000000..e1cb938750 --- /dev/null +++ b/src/dstack/_internal/server/migrations/versions/2026/03_05_1045_c7b0a8e57294_add_ix_fleets_current_master_instance_id.py @@ -0,0 +1,42 @@ +"""Add ix_fleets_current_master_instance_id index + +Revision ID: c7b0a8e57294 +Revises: 9cb8e4e4d986 +Create Date: 2026-03-05 10:45:00.000000+00:00 + +""" + +from alembic import op + +# revision identifiers, used by Alembic. 
+revision = "c7b0a8e57294" +down_revision = "9cb8e4e4d986" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + with op.get_context().autocommit_block(): + op.drop_index( + "ix_fleets_current_master_instance_id", + table_name="fleets", + if_exists=True, + postgresql_concurrently=True, + ) + op.create_index( + "ix_fleets_current_master_instance_id", + "fleets", + ["current_master_instance_id"], + unique=False, + postgresql_concurrently=True, + ) + + +def downgrade() -> None: + with op.get_context().autocommit_block(): + op.drop_index( + "ix_fleets_current_master_instance_id", + table_name="fleets", + if_exists=True, + postgresql_concurrently=True, + ) diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py index 15801a25df..d1a30b941b 100644 --- a/src/dstack/_internal/server/models.py +++ b/src/dstack/_internal/server/models.py @@ -602,7 +602,14 @@ class FleetModel(PipelineModelMixin, BaseModel): runs: Mapped[List["RunModel"]] = relationship(back_populates="fleet") jobs: Mapped[List["JobModel"]] = relationship(back_populates="fleet") - instances: Mapped[List["InstanceModel"]] = relationship(back_populates="fleet") + instances: Mapped[List["InstanceModel"]] = relationship( + back_populates="fleet", + foreign_keys="InstanceModel.fleet_id", + ) + + current_master_instance_id: Mapped[Optional[uuid.UUID]] = mapped_column( + UUIDType(binary=False), index=True + ) # `consolidation_attempt` counts how many times in a row fleet needed consolidation. # Allows increasing delays between attempts. @@ -619,7 +626,7 @@ class FleetModel(PipelineModelMixin, BaseModel): ) -class InstanceModel(BaseModel): +class InstanceModel(PipelineModelMixin, BaseModel): __tablename__ = "instances" id: Mapped[uuid.UUID] = mapped_column( @@ -647,7 +654,10 @@ class InstanceModel(BaseModel): pool: Mapped[Optional["PoolModel"]] = relationship(back_populates="instances") fleet_id: Mapped[Optional[uuid.UUID]] = mapped_column(ForeignKey("fleets.id"), index=True) - fleet: Mapped[Optional["FleetModel"]] = relationship(back_populates="instances") + fleet: Mapped[Optional["FleetModel"]] = relationship( + back_populates="instances", + foreign_keys=[fleet_id], + ) compute_group_id: Mapped[Optional[uuid.UUID]] = mapped_column(ForeignKey("compute_groups.id")) compute_group: Mapped[Optional["ComputeGroupModel"]] = relationship(back_populates="instances") @@ -727,6 +737,15 @@ class InstanceModel(BaseModel): cascade="save-update, merge, delete-orphan, delete", ) + __table_args__ = ( + Index( + "ix_instances_pipeline_fetch_q", + last_processed_at.asc(), + postgresql_where=deleted == false(), + sqlite_where=deleted == false(), + ), + ) + class InstanceHealthCheckModel(BaseModel): __tablename__ = "instance_health_checks" diff --git a/src/dstack/_internal/server/routers/fleets.py b/src/dstack/_internal/server/routers/fleets.py index cb18db8bbd..58c87d653b 100644 --- a/src/dstack/_internal/server/routers/fleets.py +++ b/src/dstack/_internal/server/routers/fleets.py @@ -26,6 +26,7 @@ ProjectMember, check_can_access_fleet, ) +from dstack._internal.server.services.pipelines import PipelineHinterProtocol, get_pipeline_hinter from dstack._internal.server.utils.routers import ( CustomORJSONResponse, get_base_api_additional_responses, @@ -144,6 +145,7 @@ async def apply_plan( body: ApplyFleetPlanRequest, session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()), + pipeline_hinter: PipelineHinterProtocol = Depends(get_pipeline_hinter), ): """ 
Creates a new fleet or updates an existing fleet. @@ -158,6 +160,7 @@ async def apply_plan( project=project, plan=body.plan, force=body.force, + pipeline_hinter=pipeline_hinter, ) ) @@ -167,6 +170,7 @@ async def create_fleet( body: CreateFleetRequest, session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()), + pipeline_hinter: PipelineHinterProtocol = Depends(get_pipeline_hinter), ): """ Creates a fleet given a fleet configuration. @@ -178,6 +182,7 @@ async def create_fleet( project=project, user=user, spec=body.spec, + pipeline_hinter=pipeline_hinter, ) ) diff --git a/src/dstack/_internal/server/services/fleets.py b/src/dstack/_internal/server/services/fleets.py index ca5a2e7b4f..183e81b208 100644 --- a/src/dstack/_internal/server/services/fleets.py +++ b/src/dstack/_internal/server/services/fleets.py @@ -1,3 +1,4 @@ +import asyncio import uuid from collections.abc import Callable from datetime import datetime @@ -51,7 +52,7 @@ from dstack._internal.core.models.users import GlobalRole from dstack._internal.core.services import validate_dstack_resource_name from dstack._internal.core.services.diff import ModelDiff, copy_model, diff_models -from dstack._internal.server.db import get_db, is_db_postgres, is_db_sqlite +from dstack._internal.server.db import get_db, is_db_postgres, is_db_sqlite, sqlite_commit from dstack._internal.server.models import ( ExportedFleetModel, FleetModel, @@ -75,6 +76,7 @@ get_locker, string_to_lock_id, ) +from dstack._internal.server.services.pipelines import PipelineHinterProtocol from dstack._internal.server.services.plugins import apply_plugin_policies from dstack._internal.server.services.projects import ( get_member, @@ -84,7 +86,12 @@ ) from dstack._internal.server.services.resources import set_resources_defaults from dstack._internal.utils import random_names -from dstack._internal.utils.common import EntityID, EntityName, EntityNameOrID +from dstack._internal.utils.common import ( + EntityID, + EntityName, + EntityNameOrID, + get_current_datetime, +) from dstack._internal.utils.logging import get_logger from dstack._internal.utils.ssh import pkey_from_str @@ -465,19 +472,26 @@ async def get_create_instance_offers( fleet_model: Optional[FleetModel] = None, blocks: Union[int, Literal["auto"]] = 1, exclude_not_available: bool = False, + master_job_provisioning_data: Optional[JobProvisioningData] = None, + infer_master_job_provisioning_data_from_fleet_instances: bool = True, ) -> List[Tuple[Backend, InstanceOfferWithAvailability]]: multinode = False - master_job_provisioning_data = None if fleet_spec is not None: multinode = fleet_spec.configuration.placement == InstanceGroupPlacement.CLUSTER if fleet_model is not None: - fleet = fleet_model_to_fleet(fleet_model) - multinode = fleet.spec.configuration.placement == InstanceGroupPlacement.CLUSTER - for instance in fleet_model.instances: - jpd = instances_services.get_instance_provisioning_data(instance) - if jpd is not None: - master_job_provisioning_data = jpd - break + fleet_spec_from_model = get_fleet_spec(fleet_model) + multinode = fleet_spec_from_model.configuration.placement == InstanceGroupPlacement.CLUSTER + # The caller may override the current cluster master explicitly instead + # of inferring placement restrictions from the loaded fleet instances. 
+ if ( + master_job_provisioning_data is None + and infer_master_job_provisioning_data_from_fleet_instances + ): + for instance in fleet_model.instances: + jpd = instances_services.get_instance_provisioning_data(instance) + if jpd is not None: + master_job_provisioning_data = jpd + break offers = await offers_services.get_offers_by_requirements( project=project, @@ -503,6 +517,7 @@ async def apply_plan( project: ProjectModel, plan: ApplyFleetPlanInput, force: bool, + pipeline_hinter: PipelineHinterProtocol, ) -> Fleet: spec = await apply_plugin_policies( user=user.name, @@ -523,6 +538,7 @@ async def apply_plan( project=project, user=user, spec=spec, + pipeline_hinter=pipeline_hinter, ) fleet_model = await get_project_fleet_model_by_name( @@ -536,6 +552,7 @@ async def apply_plan( project=project, user=user, spec=spec, + pipeline_hinter=pipeline_hinter, ) instances_ids = sorted(i.id for i in fleet_model.instances if not i.deleted) @@ -546,6 +563,8 @@ async def apply_plan( ): # Refetch after lock # TODO: Lock instances with FOR UPDATE? + # We do not respect InstanceModel.lock_* fields here because FleetPipeline does not update SSH instances. + # TODO: Respect InstanceModel.lock_* fields if FleetPipeline and apply update the same instances. res = await session.execute( select(FleetModel) .where( @@ -591,6 +610,7 @@ async def apply_plan( project=project, user=user, spec=spec, + pipeline_hinter=pipeline_hinter, ) @@ -599,6 +619,7 @@ async def create_fleet( project: ProjectModel, user: UserModel, spec: FleetSpec, + pipeline_hinter: PipelineHinterProtocol, ) -> Fleet: spec = await apply_plugin_policies( user=user.name, @@ -612,7 +633,9 @@ async def create_fleet( if spec.configuration.ssh_config is not None: _check_can_manage_ssh_fleets(user=user, project=project) - return await _create_fleet(session=session, project=project, user=user, spec=spec) + return await _create_fleet( + session=session, project=project, user=user, spec=spec, pipeline_hinter=pipeline_hinter + ) def create_fleet_instance_model( @@ -621,6 +644,7 @@ def create_fleet_instance_model( username: str, spec: FleetSpec, instance_num: int, + instance_id: Optional[uuid.UUID] = None, ) -> InstanceModel: profile = spec.merged_profile requirements = get_fleet_requirements(spec) @@ -632,6 +656,7 @@ def create_fleet_instance_model( requirements=requirements, instance_name=f"{spec.configuration.name}-{instance_num}", instance_num=instance_num, + instance_id=instance_id, reservation=spec.merged_profile.reservation, blocks=spec.configuration.blocks, tags=spec.configuration.tags, @@ -716,7 +741,7 @@ async def delete_fleets( .order_by(FleetModel.id) ) fleets_ids = list(res.scalars().unique().all()) - res = await session.execute( + stmt = ( select(InstanceModel.id) .where( InstanceModel.fleet_id.in_(fleets_ids), @@ -724,60 +749,73 @@ async def delete_fleets( ) .order_by(InstanceModel.id) ) + if instance_nums is not None: + stmt = stmt.where(InstanceModel.instance_num.in_(instance_nums)) + res = await session.execute(stmt) instances_ids = list(res.scalars().unique().all()) - if is_db_sqlite(): - # Start new transaction to see committed changes after lock - await session.commit() + await sqlite_commit(session) async with ( get_locker(get_db().dialect_name).lock_ctx(FleetModel.__tablename__, fleets_ids), get_locker(get_db().dialect_name).lock_ctx(InstanceModel.__tablename__, instances_ids), ): - # Refetch after lock. - # TODO: Do not lock fleet when deleting only instances. 
- res = await session.execute( - select(FleetModel) - .where( - FleetModel.project_id == project.id, - FleetModel.id.in_(fleets_ids), - FleetModel.deleted == False, - FleetModel.lock_expires_at.is_(None), - ) - .options( - selectinload(FleetModel.instances.and_(InstanceModel.id.in_(instances_ids))) - .selectinload(InstanceModel.jobs) - .load_only(JobModel.id) - ) - .options( - selectinload( - FleetModel.runs.and_(RunModel.status.not_in(RunStatus.finished_statuses())) - ).load_only(RunModel.status) + # Retry locking fleets to increase lock acquisition chances. + # This hack is needed until requests are queued. + fleet_models = [] + for i in range(10): + res = await session.execute( + select(FleetModel) + .where( + FleetModel.project_id == project.id, + FleetModel.id.in_(fleets_ids), + FleetModel.deleted == False, + FleetModel.lock_expires_at.is_(None), + ) + .options( + selectinload(FleetModel.instances.and_(InstanceModel.id.in_(instances_ids))) + .selectinload(InstanceModel.jobs) + .load_only(JobModel.id) + ) + .options( + selectinload( + FleetModel.runs.and_(RunModel.status.not_in(RunStatus.finished_statuses())) + ).load_only(RunModel.status) + ) + .order_by(FleetModel.id) # take locks in order + .with_for_update(key_share=True, of=FleetModel) + .execution_options(populate_existing=True) ) - .execution_options(populate_existing=True) - .order_by(FleetModel.id) # take locks in order - .with_for_update(key_share=True, of=FleetModel) - ) - fleet_models = res.scalars().unique().all() + fleet_models = res.scalars().unique().all() + if len(fleet_models) == len(fleets_ids): + break + await asyncio.sleep(0.5) if len(fleet_models) != len(fleets_ids): - # TODO: Make the endpoint fully async so we don't need to lock and error: - # put the request in queue and process in the background. + # TODO: Make the endpoint fully async so we don't need to lock and error. msg = ( "Failed to delete fleets: fleets are being processed currently. Try again later." if instance_nums is None else "Failed to delete fleet instances: fleets are being processed currently. Try again later." ) raise ServerClientError(msg) - res = await session.execute( - select(InstanceModel.id) - .where( - InstanceModel.id.in_(instances_ids), - InstanceModel.deleted == False, + # Retry locking instances to increase lock acquisition chances. + # This hack is needed until requests are queued. + instances_left_to_lock = set(instances_ids) + for i in range(10): + res = await session.execute( + select(InstanceModel.id) + .where( + InstanceModel.id.in_(instances_left_to_lock), + InstanceModel.deleted == False, + InstanceModel.lock_expires_at.is_(None), + ) + .order_by(InstanceModel.id) # take locks in order + .with_for_update(key_share=True, of=InstanceModel) + .execution_options(populate_existing=True) ) - .order_by(InstanceModel.id) # take locks in order - .with_for_update(key_share=True, of=InstanceModel) - .execution_options(populate_existing=True) - ) - instance_models_ids = list(res.scalars().unique().all()) - if len(instance_models_ids) != len(instances_ids): + instances_left_to_lock.difference_update(res.scalars().unique().all()) + if len(instances_left_to_lock) == 0: + break + await asyncio.sleep(0.5) + if len(instances_left_to_lock) > 0: msg = ( "Failed to delete fleets: fleet instances are being processed currently. Try again later." 
if instance_nums is None @@ -785,8 +823,8 @@ async def delete_fleets( ) raise ServerClientError(msg) for fleet_model in fleet_models: - fleet = fleet_model_to_fleet(fleet_model) - if fleet.spec.configuration.ssh_config is not None: + fleet_spec = get_fleet_spec(fleet_model) + if fleet_spec.configuration.ssh_config is not None: _check_can_manage_ssh_fleets(user=user, project=project) if instance_nums is None: logger.info("Deleting fleets: %s", [f.name for f in fleet_models]) @@ -867,10 +905,10 @@ def is_fleet_empty(fleet_model: FleetModel) -> bool: def is_cloud_cluster(fleet_model: FleetModel) -> bool: - fleet = fleet_model_to_fleet(fleet_model) + fleet_spec = get_fleet_spec(fleet_model) return ( - fleet.spec.configuration.placement == InstanceGroupPlacement.CLUSTER - and fleet.spec.configuration.ssh_config is None + fleet_spec.configuration.placement == InstanceGroupPlacement.CLUSTER + and fleet_spec.configuration.ssh_config is None ) @@ -905,6 +943,9 @@ def get_fleet_master_instance_provisioning_data( ) -> Optional[JobProvisioningData]: master_instance_provisioning_data = None if fleet_spec.configuration.placement == InstanceGroupPlacement.CLUSTER: + # TODO: This legacy helper infers the cluster master from fleet instances. + # Pipeline-based provisioning should use FleetModel.current_master_instance_id + # instead of relying on instance ordering in the loaded relationship. # Offers for master jobs must be in the same cluster as existing instances. fleet_instance_models = [im for im in fleet_model.instances if not im.deleted] if len(fleet_instance_models) > 0: @@ -940,6 +981,7 @@ async def _create_fleet( project: ProjectModel, user: UserModel, spec: FleetSpec, + pipeline_hinter: PipelineHinterProtocol, ) -> Fleet: lock_namespace = f"fleet_names_{project.name}" if is_db_sqlite(): @@ -962,6 +1004,7 @@ async def _create_fleet( else: spec.configuration.name = await generate_fleet_name(session=session, project=project) + now = get_current_datetime() fleet_model = FleetModel( id=uuid.uuid4(), name=spec.configuration.name, @@ -969,6 +1012,8 @@ async def _create_fleet( status=FleetStatus.ACTIVE, spec=spec.json(), instances=[], + created_at=now, + last_processed_at=now, ) session.add(fleet_model) events.emit( @@ -1021,6 +1066,9 @@ async def _create_fleet( ) fleet_model.instances.append(instance_model) await session.commit() + if spec.configuration.ssh_config is None: + pipeline_hinter.hint_fetch(FleetModel.__name__) + pipeline_hinter.hint_fetch(InstanceModel.__name__) return fleet_model_to_fleet(fleet_model) diff --git a/src/dstack/_internal/server/services/gateways/__init__.py b/src/dstack/_internal/server/services/gateways/__init__.py index ddc3d64c44..b4dfef083f 100644 --- a/src/dstack/_internal/server/services/gateways/__init__.py +++ b/src/dstack/_internal/server/services/gateways/__init__.py @@ -341,23 +341,28 @@ async def _delete_gateways_pipeline( async with get_locker(get_db().dialect_name).lock_ctx( GatewayModel.__tablename__, gateways_ids ): - # Refetch after lock - res = await session.execute( - select(GatewayModel) - .where( - GatewayModel.id.in_(gateways_ids), - GatewayModel.project_id == project.id, - GatewayModel.lock_expires_at.is_(None), + # Retry locking gateways to increase lock acquisition chances. + # This hack is needed until requests are queued. 
+ gateway_models = [] + for i in range(10): + res = await session.execute( + select(GatewayModel) + .where( + GatewayModel.id.in_(gateways_ids), + GatewayModel.project_id == project.id, + GatewayModel.lock_expires_at.is_(None), + ) + .options(joinedload(GatewayModel.backend).load_only(BackendModel.type)) + .order_by(GatewayModel.id) # take locks in order + .with_for_update(key_share=True, of=GatewayModel) + .execution_options(populate_existing=True) ) - .options(joinedload(GatewayModel.backend).load_only(BackendModel.type)) - .order_by(GatewayModel.id) # take locks in order - .with_for_update(key_share=True, nowait=True, of=GatewayModel) - .execution_options(populate_existing=True) - ) - gateway_models = res.scalars().all() + gateway_models = res.scalars().all() + if len(gateway_models) == len(gateways_ids): + break + await asyncio.sleep(0.5) if len(gateway_models) != len(gateways_ids): - # TODO: Make the endpoint fully async so we don't need to lock and error: - # put the request in queue and process in the background. + # TODO: Make the endpoint fully async so we don't need to lock and error. raise ServerClientError( "Failed to delete gateways: gateways are being processed currently. Try again later." ) diff --git a/src/dstack/_internal/server/services/instances.py b/src/dstack/_internal/server/services/instances.py index 079faf90c7..e07bce938b 100644 --- a/src/dstack/_internal/server/services/instances.py +++ b/src/dstack/_internal/server/services/instances.py @@ -90,6 +90,8 @@ def switch_instance_status( instance_model=instance_model, old_status=old_status, new_status=new_status, + termination_reason=instance_model.termination_reason, + termination_reason_message=instance_model.termination_reason_message, actor=actor, ) @@ -99,20 +101,26 @@ def emit_instance_status_change_event( instance_model: InstanceModel, old_status: InstanceStatus, new_status: InstanceStatus, + termination_reason: Optional[InstanceTerminationReason], + termination_reason_message: Optional[str], actor: events.AnyActor = events.SystemActor(), ) -> None: if old_status == new_status: return msg = get_instance_status_change_message( - instance_model=instance_model, old_status=old_status, new_status=new_status, + termination_reason=termination_reason, + termination_reason_message=termination_reason_message, ) events.emit(session, msg, actor=actor, targets=[events.Target.from_model(instance_model)]) def get_instance_status_change_message( - instance_model: InstanceModel, old_status: InstanceStatus, new_status: InstanceStatus + old_status: InstanceStatus, + new_status: InstanceStatus, + termination_reason: Optional[InstanceTerminationReason], + termination_reason_message: Optional[str], ) -> str: msg = f"Instance status changed {old_status.upper()} -> {new_status.upper()}" if ( @@ -120,20 +128,20 @@ def get_instance_status_change_message( or new_status == InstanceStatus.TERMINATED and old_status != InstanceStatus.TERMINATING ): - if instance_model.termination_reason is None: + if termination_reason is None: raise ValueError( f"termination_reason must be set when switching to {new_status.upper()} status" ) if ( - instance_model.termination_reason == InstanceTerminationReason.ERROR - and not instance_model.termination_reason_message + termination_reason == InstanceTerminationReason.ERROR + and not termination_reason_message ): raise ValueError( "termination_reason_message must be set when termination_reason is ERROR" ) - msg += f". 
Termination reason: {instance_model.termination_reason.upper()}" - if instance_model.termination_reason_message: - msg += f" ({instance_model.termination_reason_message})" + msg += f". Termination reason: {termination_reason.upper()}" + if termination_reason_message: + msg += f" ({termination_reason_message})" return msg @@ -651,11 +659,13 @@ def create_instance_model( reservation: Optional[str], blocks: Union[Literal["auto"], int], tags: Optional[Dict[str, str]], + instance_id: Optional[uuid.UUID] = None, ) -> InstanceModel: termination_policy, termination_idle_time = get_termination( profile, DEFAULT_FLEET_TERMINATION_IDLE_TIME ) - instance_id = uuid.uuid4() + if instance_id is None: + instance_id = uuid.uuid4() project_ssh_key = SSHKey( public=project.ssh_public_key.strip(), private=project.ssh_private_key.strip(), @@ -669,12 +679,14 @@ def create_instance_model( reservation=reservation, tags=tags, ) + now = common_utils.get_current_datetime() instance = InstanceModel( id=instance_id, name=instance_name, instance_num=instance_num, project=project, - created_at=common_utils.get_current_datetime(), + created_at=now, + last_processed_at=now, status=InstanceStatus.PENDING, unreachable=False, profile=profile.json(), diff --git a/src/dstack/_internal/server/services/runs/plan.py b/src/dstack/_internal/server/services/runs/plan.py index 4738622a07..9694cccd71 100644 --- a/src/dstack/_internal/server/services/runs/plan.py +++ b/src/dstack/_internal/server/services/runs/plan.py @@ -259,9 +259,12 @@ async def select_run_candidate_fleet_models_with_filters( .execution_options(populate_existing=True) ) if lock_instances: - stmt = stmt.order_by(InstanceModel.id).with_for_update( # take locks in order - key_share=True, of=InstanceModel - ) + # Skip locked instances since waiting for all the instances to unlock may take indefinite time. + # TODO: Switch to optimistic locking – implement select-lock-reselect loop. 
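        # An illustrative shape for that select-lock-reselect loop (assumption only; the
        # helper names are hypothetical, not existing functions):
        #
        #   candidate_ids = await select_candidate_instance_ids(session)  # plain SELECT, no row locks
        #   locked_ids = await lock_instances(session, candidate_ids)     # FOR UPDATE ... SKIP LOCKED
        #   usable_ids = await reselect_unlocked(session, locked_ids)     # drop rows whose
        #                                                                 # lock_expires_at was set meanwhile
        #   ...retry the remainder or give up after a few attempts.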
+ stmt = stmt.where(InstanceModel.lock_expires_at.is_(None)) + stmt = stmt.order_by( + InstanceModel.id # take locks in order + ).with_for_update(skip_locked=True, key_share=True, of=InstanceModel) res = await session.execute(stmt) fleet_models_with_instances = list(res.unique().scalars().all()) fleet_models_with_instances_ids = [f.id for f in fleet_models_with_instances] diff --git a/src/dstack/_internal/server/services/ssh_fleets/__init__.py b/src/dstack/_internal/server/services/ssh_fleets/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/dstack/_internal/server/utils/provisioning.py b/src/dstack/_internal/server/services/ssh_fleets/provisioning.py similarity index 86% rename from src/dstack/_internal/server/utils/provisioning.py rename to src/dstack/_internal/server/services/ssh_fleets/provisioning.py index fcbe3bf086..3a7c21e6dd 100644 --- a/src/dstack/_internal/server/utils/provisioning.py +++ b/src/dstack/_internal/server/services/ssh_fleets/provisioning.py @@ -14,9 +14,7 @@ normalize_arch, ) from dstack._internal.core.consts import DSTACK_SHIM_HTTP_PORT - -# FIXME: ProvisioningError is a subclass of ComputeError and should not be used outside of Compute -from dstack._internal.core.errors import ProvisioningError +from dstack._internal.core.errors import SSHProvisioningError from dstack._internal.core.models.instances import ( Disk, Gpu, @@ -46,15 +44,15 @@ def detect_cpu_arch(client: paramiko.SSHClient) -> GoArchType: try: _, stdout, stderr = client.exec_command(cmd, timeout=20) except (paramiko.SSHException, OSError) as e: - raise ProvisioningError(f"detect_cpu_arch: {e}") from e + raise SSHProvisioningError(f"detect_cpu_arch: {e}") from e out = stdout.read().strip().decode() err = stderr.read().strip().decode() if err: - raise ProvisioningError(f"detect_cpu_arch: {cmd} failed, stdout: {out}, stderr: {err}") + raise SSHProvisioningError(f"detect_cpu_arch: {cmd} failed, stdout: {out}, stderr: {err}") try: return normalize_arch(out) except ValueError as e: - raise ProvisioningError(f"detect_cpu_arch: failed to normalize arch: {e}") from e + raise SSHProvisioningError(f"detect_cpu_arch: failed to normalize arch: {e}") from e def sftp_upload(client: paramiko.SSHClient, path: str, body: str) -> None: @@ -66,7 +64,7 @@ def sftp_upload(client: paramiko.SSHClient, path: str, body: str) -> None: sftp.putfo(io.BytesIO(body.encode()), path) sftp.close() except (paramiko.SSHException, OSError) as e: - raise ProvisioningError(f"sft_upload failed: {e}") from e + raise SSHProvisioningError(f"sft_upload failed: {e}") from e def upload_envs(client: paramiko.SSHClient, working_dir: str, envs: Dict[str, str]) -> None: @@ -80,11 +78,11 @@ def upload_envs(client: paramiko.SSHClient, working_dir: str, envs: Dict[str, st out = stdout.read().strip().decode() err = stderr.read().strip().decode() if out or err: - raise ProvisioningError( + raise SSHProvisioningError( f"The command 'upload_envs' didn't work. stdout: {out}, stderr: {err}" ) except (paramiko.SSHException, OSError) as e: - raise ProvisioningError(f"upload_envs failed: {e}") from e + raise SSHProvisioningError(f"upload_envs failed: {e}") from e def run_pre_start_commands( @@ -98,11 +96,11 @@ def run_pre_start_commands( out = stdout.read().strip().decode() err = stderr.read().strip().decode() if out or err: - raise ProvisioningError( + raise SSHProvisioningError( f"The command 'authorized_keys' didn't work. 
stdout: {out}, stderr: {err}" ) except (paramiko.SSHException, OSError) as e: - raise ProvisioningError(f"upload authorized_keys failed: {e}") from e + raise SSHProvisioningError(f"upload authorized_keys failed: {e}") from e script = " && ".join(shim_pre_start_commands) try: @@ -110,11 +108,11 @@ def run_pre_start_commands( out = stdout.read().strip().decode() err = stderr.read().strip().decode() if out or err: - raise ProvisioningError( + raise SSHProvisioningError( f"The command 'run_pre_start_commands' didn't work. stdout: {out}, stderr: {err}" ) except (paramiko.SSHException, OSError) as e: - raise ProvisioningError(f"run_pre-start_commands failed: {e}") from e + raise SSHProvisioningError(f"run_pre-start_commands failed: {e}") from e def run_shim_as_systemd_service( @@ -158,11 +156,11 @@ def run_shim_as_systemd_service( out = stdout.read().strip().decode() err = stderr.read().strip().decode() if out or err: - raise ProvisioningError( + raise SSHProvisioningError( f"The command 'run_shim_as_systemd_service' didn't work. stdout: {out}, stderr: {err}" ) except (paramiko.SSHException, OSError) as e: - raise ProvisioningError(f"run_shim_as_systemd failed: {e}") from e + raise SSHProvisioningError(f"run_shim_as_systemd failed: {e}") from e def check_dstack_shim_service(client: paramiko.SSHClient): @@ -170,12 +168,12 @@ def check_dstack_shim_service(client: paramiko.SSHClient): _, stdout, _ = client.exec_command("sudo systemctl status dstack-shim.service", timeout=10) status = stdout.read() except (paramiko.SSHException, OSError) as e: - raise ProvisioningError(f"Checking dstack-shim.service status failed: {e}") from e + raise SSHProvisioningError(f"Checking dstack-shim.service status failed: {e}") from e for raw_line in status.splitlines(): line = raw_line.decode() if line.strip().startswith("Active: failed"): - raise ProvisioningError(f"The dstack-shim service doesn't start: {line.strip()}") + raise SSHProvisioningError(f"The dstack-shim service doesn't start: {line.strip()}") def remove_host_info_if_exists(client: paramiko.SSHClient, working_dir: str) -> None: @@ -188,7 +186,7 @@ def remove_host_info_if_exists(client: paramiko.SSHClient, working_dir: str) -> if err: logger.debug(f"{HOST_INFO_FILE} hasn't been removed: %s", err) except (paramiko.SSHException, OSError) as e: - raise ProvisioningError(f"remove_host_info_if_exists failed: {e}") + raise SSHProvisioningError(f"remove_host_info_if_exists failed: {e}") def remove_dstack_runner_if_exists(client: paramiko.SSHClient, path: str) -> None: @@ -198,7 +196,7 @@ def remove_dstack_runner_if_exists(client: paramiko.SSHClient, path: str) -> Non if err: logger.debug(f"{path} hasn't been removed: %s", err) except (paramiko.SSHException, OSError) as e: - raise ProvisioningError(f"remove_dstack_runner_if_exists failed: {e}") + raise SSHProvisioningError(f"remove_dstack_runner_if_exists failed: {e}") def get_host_info(client: paramiko.SSHClient, working_dir: str) -> Dict[str, Any]: @@ -224,11 +222,11 @@ def get_host_info(client: paramiko.SSHClient, working_dir: str) -> Dict[str, Any return host_info except ValueError: # JSON parse error check_dstack_shim_service(client) - raise ProvisioningError("Cannot parse host_info") + raise SSHProvisioningError("Cannot parse host_info") time.sleep(iter_delay) else: check_dstack_shim_service(client) - raise ProvisioningError("Cannot get host_info") + raise SSHProvisioningError("Cannot get host_info") def get_shim_healthcheck(client: paramiko.SSHClient) -> str: @@ -240,7 +238,7 @@ def 
get_shim_healthcheck(client: paramiko.SSHClient) -> str: return healthcheck logger.debug("healthcheck is empty. retry") time.sleep(iter_delay) - raise ProvisioningError("Cannot get HealthcheckResponse") + raise SSHProvisioningError("Cannot get HealthcheckResponse") def _get_shim_healthcheck(client: paramiko.SSHClient) -> Optional[str]: @@ -251,9 +249,11 @@ def _get_shim_healthcheck(client: paramiko.SSHClient) -> Optional[str]: out = stdout.read().strip().decode() err = stderr.read().strip().decode() except (paramiko.SSHException, OSError) as e: - raise ProvisioningError(f"get_shim_healthcheck failed: {e}") from e + raise SSHProvisioningError(f"get_shim_healthcheck failed: {e}") from e if err: - raise ProvisioningError(f"get_shim_healthcheck didn't work. stdout: {out}, stderr: {err}") + raise SSHProvisioningError( + f"get_shim_healthcheck didn't work. stdout: {out}, stderr: {err}" + ) if not out: return None return out @@ -306,7 +306,7 @@ def get_paramiko_connection( ) -> Generator[paramiko.SSHClient, None, None]: if proxy is not None: if proxy_pkeys is None: - raise ProvisioningError("Missing proxy private keys") + raise SSHProvisioningError("Missing proxy private keys") proxy_ctx = get_paramiko_connection( proxy.username, proxy.hostname, proxy.port, proxy_pkeys ) @@ -321,7 +321,7 @@ def get_paramiko_connection( try: proxy_channel = transport.open_channel("direct-tcpip", (host, port), ("", 0)) except (paramiko.SSHException, OSError) as e: - raise ProvisioningError(f"Proxy channel failed: {e}") from e + raise SSHProvisioningError(f"Proxy channel failed: {e}") from e client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) for pkey in pkeys: logger.debug("Try to connect to %s with key %s", conn_url, pkey.fingerprint) @@ -333,7 +333,7 @@ def get_paramiko_connection( f'Authentication failed to connect to "{conn_url}" and {pkey.fingerprint}' ) keys_fp = ", ".join(f"{pk.fingerprint!r}" for pk in pkeys) - raise ProvisioningError( + raise SSHProvisioningError( f"SSH connection to the {conn_url} with keys [{keys_fp}] was unsuccessful" ) @@ -347,7 +347,7 @@ def _paramiko_connect( channel: Optional[paramiko.Channel] = None, ) -> bool: """ - Returns `True` if connected, `False` if auth failed, and raises `ProvisioningError` + Returns `True` if connected, `False` if auth failed, and raises `SSHProvisioningError` on other errors. """ try: @@ -365,4 +365,4 @@ def _paramiko_connect( except paramiko.AuthenticationException: return False except (paramiko.SSHException, OSError) as e: - raise ProvisioningError(f"Connect failed: {e}") from e + raise SSHProvisioningError(f"Connect failed: {e}") from e diff --git a/src/dstack/_internal/server/services/volumes.py b/src/dstack/_internal/server/services/volumes.py index 1c846c724f..ac2f88a5d1 100644 --- a/src/dstack/_internal/server/services/volumes.py +++ b/src/dstack/_internal/server/services/volumes.py @@ -1,3 +1,4 @@ +import asyncio import uuid from datetime import datetime, timedelta from typing import List, Optional @@ -353,24 +354,29 @@ async def _delete_volumes_pipeline( await session.commit() logger.info("Deleting volumes: %s", [v.name for v in volume_models]) async with get_locker(get_db().dialect_name).lock_ctx(VolumeModel.__tablename__, volumes_ids): - # Refetch after lock - res = await session.execute( - select(VolumeModel) - .where( - VolumeModel.project_id == project.id, - VolumeModel.id.in_(volumes_ids), - VolumeModel.deleted == False, - VolumeModel.lock_expires_at.is_(None), + # Retry locking volumes to increase lock acquisition chances. 
+ # This hack is needed until requests are queued. + volume_models = [] + for i in range(10): + res = await session.execute( + select(VolumeModel) + .where( + VolumeModel.project_id == project.id, + VolumeModel.id.in_(volumes_ids), + VolumeModel.deleted == False, + VolumeModel.lock_expires_at.is_(None), + ) + .options(selectinload(VolumeModel.attachments)) + .order_by(VolumeModel.id) # take locks in order + .with_for_update(key_share=True, of=VolumeModel) + .execution_options(populate_existing=True) ) - .options(selectinload(VolumeModel.attachments)) - .execution_options(populate_existing=True) - .order_by(VolumeModel.id) # take locks in order - .with_for_update(key_share=True, of=VolumeModel) - ) - volume_models = res.scalars().unique().all() + volume_models = res.scalars().unique().all() + if len(volume_models) == len(volumes_ids): + break + await asyncio.sleep(0.5) if len(volume_models) != len(volumes_ids): - # TODO: Make the endpoint fully async so we don't need to lock and error: - # put the request in queue and process in the background. + # TODO: Make the endpoint fully async so we don't need to lock and error. raise ServerClientError( "Failed to delete volumes: volumes are being processed currently. Try again later." ) diff --git a/src/dstack/_internal/utils/common.py b/src/dstack/_internal/utils/common.py index 2db91882ff..c761bfcc28 100644 --- a/src/dstack/_internal/utils/common.py +++ b/src/dstack/_internal/utils/common.py @@ -8,7 +8,7 @@ from datetime import datetime, timedelta, timezone from functools import partial from pathlib import Path -from typing import Any, Iterable, List, Optional, TypeVar, Union +from typing import Any, Final, Iterable, List, Optional, TypeVar, Union from urllib.parse import urlparse from uuid import UUID @@ -17,6 +17,17 @@ from dstack._internal.core.models.common import Duration +class Unset: + pass + + +UNSET: Final = Unset() +""" +Use `UNSET` as kwargs default value to distinguish between +specified and non-specified `Optional` values. 
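
An illustrative sketch of the intended pattern (the method below is hypothetical):

    def rename(self, name: Union[Optional[str], Unset] = UNSET) -> None:
        if name is not UNSET:
            # `None` explicitly clears the name; `UNSET` leaves it unchanged.
            self.name = name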
+""" + + @dataclass class EntityName: name: str diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_compute_groups.py b/src/tests/_internal/server/background/pipeline_tasks/test_compute_groups.py index 6d24669f7c..776240fc47 100644 --- a/src/tests/_internal/server/background/pipeline_tasks/test_compute_groups.py +++ b/src/tests/_internal/server/background/pipeline_tasks/test_compute_groups.py @@ -1,5 +1,6 @@ +import asyncio import uuid -from datetime import datetime, timezone +from datetime import datetime, timedelta, timezone from unittest.mock import Mock, patch import pytest @@ -9,7 +10,11 @@ from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.compute_groups import ComputeGroupStatus from dstack._internal.server.background.pipeline_tasks.base import PipelineItem -from dstack._internal.server.background.pipeline_tasks.compute_groups import ComputeGroupWorker +from dstack._internal.server.background.pipeline_tasks.compute_groups import ( + ComputeGroupFetcher, + ComputeGroupPipeline, + ComputeGroupWorker, +) from dstack._internal.server.models import ComputeGroupModel from dstack._internal.server.testing.common import ( ComputeMockSpec, @@ -17,6 +22,7 @@ create_fleet, create_project, ) +from dstack._internal.utils.common import get_current_datetime @pytest.fixture @@ -24,6 +30,17 @@ def worker() -> ComputeGroupWorker: return ComputeGroupWorker(queue=Mock(), heartbeater=Mock()) +@pytest.fixture +def fetcher() -> ComputeGroupFetcher: + return ComputeGroupFetcher( + queue=asyncio.Queue(), + queue_desired_minsize=1, + min_processing_interval=timedelta(seconds=15), + lock_timeout=timedelta(seconds=30), + heartbeater=Mock(), + ) + + def _compute_group_to_pipeline_item(compute_group: ComputeGroupModel) -> PipelineItem: assert compute_group.lock_token is not None assert compute_group.lock_expires_at is not None @@ -36,9 +53,104 @@ def _compute_group_to_pipeline_item(compute_group: ComputeGroupModel) -> Pipelin ) +@pytest.mark.asyncio +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) +class TestComputeGroupFetcher: + async def test_fetch_selects_eligible_compute_groups_and_sets_lock_fields( + self, test_db, session: AsyncSession, fetcher: ComputeGroupFetcher + ): + project = await create_project(session) + fleet = await create_fleet(session=session, project=project) + now = get_current_datetime() + stale = now - timedelta(minutes=1) + + eligible = await create_compute_group( + session=session, + project=project, + fleet=fleet, + last_processed_at=stale - timedelta(seconds=2), + ) + finished = await create_compute_group( + session=session, + project=project, + fleet=fleet, + status=ComputeGroupStatus.TERMINATED, + last_processed_at=stale - timedelta(seconds=1), + ) + recent = await create_compute_group( + session=session, + project=project, + fleet=fleet, + last_processed_at=now, + ) + locked = await create_compute_group( + session=session, + project=project, + fleet=fleet, + last_processed_at=stale, + ) + locked.lock_expires_at = now + timedelta(minutes=1) + locked.lock_token = uuid.uuid4() + locked.lock_owner = "OtherPipeline" + await session.commit() + + items = await fetcher.fetch(limit=10) + + assert [item.id for item in items] == [eligible.id] + + for compute_group in [eligible, finished, recent, locked]: + await session.refresh(compute_group) + + assert eligible.lock_owner == ComputeGroupPipeline.__name__ + assert eligible.lock_expires_at is not None + assert eligible.lock_token is not None + + 
assert finished.lock_owner is None + assert recent.lock_owner is None + assert locked.lock_owner == "OtherPipeline" + + async def test_fetch_returns_oldest_compute_groups_first_up_to_limit( + self, test_db, session: AsyncSession, fetcher: ComputeGroupFetcher + ): + project = await create_project(session) + fleet = await create_fleet(session=session, project=project) + now = get_current_datetime() + + oldest = await create_compute_group( + session=session, + project=project, + fleet=fleet, + last_processed_at=now - timedelta(minutes=3), + ) + middle = await create_compute_group( + session=session, + project=project, + fleet=fleet, + last_processed_at=now - timedelta(minutes=2), + ) + newest = await create_compute_group( + session=session, + project=project, + fleet=fleet, + last_processed_at=now - timedelta(minutes=1), + ) + + items = await fetcher.fetch(limit=2) + + assert [item.id for item in items] == [oldest.id, middle.id] + + await session.refresh(oldest) + await session.refresh(middle) + await session.refresh(newest) + + assert oldest.lock_owner == ComputeGroupPipeline.__name__ + assert middle.lock_owner == ComputeGroupPipeline.__name__ + assert newest.lock_owner is None + + +@pytest.mark.asyncio +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) class TestComputeGroupWorker: - @pytest.mark.asyncio - @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) async def test_terminates_compute_group( self, test_db, session: AsyncSession, worker: ComputeGroupWorker ): @@ -64,8 +176,6 @@ async def test_terminates_compute_group( assert compute_group.status == ComputeGroupStatus.TERMINATED assert compute_group.deleted - @pytest.mark.asyncio - @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) async def test_retries_compute_group_termination( self, test_db, session: AsyncSession, worker: ComputeGroupWorker ): diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_fleets.py b/src/tests/_internal/server/background/pipeline_tasks/test_fleets.py index 746ddf2ea4..d2b53226e3 100644 --- a/src/tests/_internal/server/background/pipeline_tasks/test_fleets.py +++ b/src/tests/_internal/server/background/pipeline_tasks/test_fleets.py @@ -1,17 +1,25 @@ +import asyncio import uuid -from datetime import datetime, timezone -from unittest.mock import Mock +from datetime import datetime, timedelta, timezone +from unittest.mock import AsyncMock, Mock, patch import pytest from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from dstack._internal.core.models.fleets import FleetNodesSpec, FleetStatus -from dstack._internal.core.models.instances import InstanceStatus +from dstack._internal.core.models.fleets import ( + FleetNodesSpec, + FleetStatus, + InstanceGroupPlacement, +) +from dstack._internal.core.models.instances import InstanceStatus, InstanceTerminationReason from dstack._internal.core.models.runs import RunStatus from dstack._internal.core.models.users import GlobalRole, ProjectRole +from dstack._internal.server.background.pipeline_tasks import fleets as fleets_pipeline from dstack._internal.server.background.pipeline_tasks.base import PipelineItem from dstack._internal.server.background.pipeline_tasks.fleets import ( + FleetFetcher, + FleetPipeline, FleetWorker, ) from dstack._internal.server.models import FleetModel, InstanceModel @@ -24,8 +32,12 @@ create_repo, create_run, create_user, + get_fleet_configuration, get_fleet_spec, + get_job_provisioning_data, + get_ssh_fleet_configuration, ) +from 
dstack._internal.utils.common import get_current_datetime @pytest.fixture @@ -33,6 +45,17 @@ def worker() -> FleetWorker: return FleetWorker(queue=Mock(), heartbeater=Mock()) +@pytest.fixture +def fetcher() -> FleetFetcher: + return FleetFetcher( + queue=asyncio.Queue(), + queue_desired_minsize=1, + min_processing_interval=timedelta(seconds=60), + lock_timeout=timedelta(seconds=20), + heartbeater=Mock(), + ) + + def _fleet_to_pipeline_item(fleet: FleetModel) -> PipelineItem: assert fleet.lock_token is not None assert fleet.lock_expires_at is not None @@ -45,9 +68,735 @@ def _fleet_to_pipeline_item(fleet: FleetModel) -> PipelineItem: ) +async def _lock_fleet_for_processing(session: AsyncSession, fleet: FleetModel) -> None: + fleet.lock_token = uuid.uuid4() + fleet.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) + await session.commit() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) +class TestFleetFetcher: + async def test_fetch_selects_eligible_fleets_and_sets_lock_fields( + self, test_db, session: AsyncSession, fetcher: FleetFetcher + ): + project = await create_project(session) + now = get_current_datetime() + + stale = await create_fleet( + session=session, + project=project, + last_processed_at=now - timedelta(minutes=3), + ) + just_created = await create_fleet( + session=session, + project=project, + created_at=now, + last_processed_at=now, + name="just-created", + ) + deleted = await create_fleet( + session=session, + project=project, + deleted=True, + name="deleted", + last_processed_at=now - timedelta(minutes=2), + ) + recent = await create_fleet( + session=session, + project=project, + created_at=now - timedelta(minutes=2), + last_processed_at=now, + name="recent", + ) + locked = await create_fleet( + session=session, + project=project, + name="locked", + last_processed_at=now - timedelta(minutes=1, seconds=1), + ) + locked.lock_expires_at = now + timedelta(minutes=1) + locked.lock_token = uuid.uuid4() + locked.lock_owner = "OtherPipeline" + await session.commit() + + items = await fetcher.fetch(limit=10) + + assert {item.id for item in items} == {stale.id, just_created.id} + + for fleet in [stale, just_created, deleted, recent, locked]: + await session.refresh(fleet) + + assert stale.lock_owner == FleetPipeline.__name__ + assert just_created.lock_owner == FleetPipeline.__name__ + assert stale.lock_expires_at is not None + assert just_created.lock_expires_at is not None + assert stale.lock_token is not None + assert just_created.lock_token is not None + assert len({stale.lock_token, just_created.lock_token}) == 1 + + assert deleted.lock_owner is None + assert recent.lock_owner is None + assert locked.lock_owner == "OtherPipeline" + + async def test_fetch_returns_oldest_fleets_first_up_to_limit( + self, test_db, session: AsyncSession, fetcher: FleetFetcher + ): + project = await create_project(session) + now = get_current_datetime() + + oldest = await create_fleet( + session=session, + project=project, + name="oldest", + last_processed_at=now - timedelta(minutes=4), + ) + middle = await create_fleet( + session=session, + project=project, + name="middle", + last_processed_at=now - timedelta(minutes=3), + ) + newest = await create_fleet( + session=session, + project=project, + name="newest", + last_processed_at=now - timedelta(minutes=2), + ) + + items = await fetcher.fetch(limit=2) + + assert [item.id for item in items] == [oldest.id, middle.id] + + await session.refresh(oldest) + await session.refresh(middle) 
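+ # The remaining fleet is refreshed too so its unchanged lock state can be asserted below.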
+ await session.refresh(newest) + + assert oldest.lock_owner == FleetPipeline.__name__ + assert middle.lock_owner == FleetPipeline.__name__ + assert newest.lock_owner is None + + @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) class TestFleetWorker: + async def test_skips_instance_locking_for_ssh_fleet( + self, test_db, session: AsyncSession, worker: FleetWorker + ): + project = await create_project(session) + fleet = await create_fleet( + session=session, + project=project, + spec=get_fleet_spec(conf=get_ssh_fleet_configuration()), + ) + instance = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.IDLE, + ) + original_last_processed_at = fleet.last_processed_at + fleet.lock_token = uuid.uuid4() + fleet.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) + instance.lock_token = uuid.uuid4() + instance.lock_expires_at = datetime(2025, 1, 2, 3, 5, tzinfo=timezone.utc) + instance.lock_owner = "OtherPipeline" + await session.commit() + + await worker.process(_fleet_to_pipeline_item(fleet)) + + await session.refresh(fleet) + await session.refresh(instance) + assert not fleet.deleted + assert fleet.lock_owner is None + assert fleet.lock_token is None + assert fleet.lock_expires_at is None + assert fleet.last_processed_at > original_last_processed_at + assert instance.lock_owner == "OtherPipeline" + + async def test_skips_instance_locking_when_fleet_is_not_ready_for_consolidation( + self, test_db, session: AsyncSession, worker: FleetWorker + ): + project = await create_project(session) + spec = get_fleet_spec() + fleet = await create_fleet( + session=session, + project=project, + spec=spec, + ) + instance = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.IDLE, + ) + original_last_processed_at = fleet.last_processed_at + original_last_consolidated_at = datetime.now(timezone.utc) + fleet.consolidation_attempt = 1 + fleet.last_consolidated_at = original_last_consolidated_at + fleet.lock_token = uuid.uuid4() + fleet.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) + instance.lock_token = uuid.uuid4() + instance.lock_expires_at = datetime(2025, 1, 2, 3, 5, tzinfo=timezone.utc) + instance.lock_owner = "OtherPipeline" + await session.commit() + + await worker.process(_fleet_to_pipeline_item(fleet)) + + await session.refresh(fleet) + await session.refresh(instance) + assert not fleet.deleted + assert fleet.consolidation_attempt == 1 + assert fleet.last_consolidated_at == original_last_consolidated_at + assert fleet.lock_owner is None + assert fleet.lock_token is None + assert fleet.lock_expires_at is None + assert fleet.last_processed_at > original_last_processed_at + assert instance.lock_owner == "OtherPipeline" + + async def test_resets_fleet_lock_when_not_all_instances_can_be_locked( + self, test_db, session: AsyncSession, worker: FleetWorker + ): + project = await create_project(session) + spec = get_fleet_spec() + fleet = await create_fleet( + session=session, + project=project, + spec=spec, + ) + await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.IDLE, + instance_num=0, + ) + locked_elsewhere = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.IDLE, + instance_num=1, + ) + original_last_processed_at = fleet.last_processed_at + fleet.lock_token = uuid.uuid4() + fleet.lock_expires_at = datetime(2025, 1, 2, 3, 4, 
tzinfo=timezone.utc) + fleet.lock_owner = FleetPipeline.__name__ + locked_elsewhere.lock_token = uuid.uuid4() + locked_elsewhere.lock_expires_at = datetime(2025, 1, 2, 3, 5, tzinfo=timezone.utc) + locked_elsewhere.lock_owner = "OtherPipeline" + await session.commit() + + await worker.process(_fleet_to_pipeline_item(fleet)) + + await session.refresh(fleet) + await session.refresh(locked_elsewhere) + assert fleet.lock_owner == FleetPipeline.__name__ + assert fleet.lock_token is None + assert fleet.lock_expires_at is None + assert fleet.last_processed_at > original_last_processed_at + assert locked_elsewhere.lock_owner == "OtherPipeline" + + async def test_unlocks_instances_after_consolidation( + self, test_db, session: AsyncSession, worker: FleetWorker + ): + project = await create_project(session) + spec = get_fleet_spec() + spec.configuration.nodes = FleetNodesSpec(min=1, target=1, max=1) + fleet = await create_fleet( + session=session, + project=project, + spec=spec, + ) + instance = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.IDLE, + instance_num=0, + ) + await _lock_fleet_for_processing(session, fleet) + + await worker.process(_fleet_to_pipeline_item(fleet)) + + await session.refresh(instance) + assert instance.lock_owner is None + assert instance.lock_token is None + assert instance.lock_expires_at is None + + async def test_unlocks_instances_when_fleet_lock_token_changes_after_processing( + self, test_db, session: AsyncSession, worker: FleetWorker + ): + project = await create_project(session) + spec = get_fleet_spec() + spec.configuration.nodes = FleetNodesSpec(min=1, target=1, max=1) + fleet = await create_fleet( + session=session, + project=project, + spec=spec, + ) + instance = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.IDLE, + instance_num=0, + ) + await _lock_fleet_for_processing(session, fleet) + + async def mock_process_fleet(*args, **kwargs): + fleet_model = args[0] + fleet_model.lock_token = uuid.uuid4() + return fleets_pipeline._ProcessResult() + + with patch.object( + fleets_pipeline, + "_process_fleet", + AsyncMock(side_effect=mock_process_fleet), + ): + await worker.process(_fleet_to_pipeline_item(fleet)) + + await session.refresh(instance) + assert instance.lock_owner is None + assert instance.lock_token is None + assert instance.lock_expires_at is None + + async def test_syncs_initial_current_master_for_cluster_fleet( + self, test_db, session: AsyncSession, worker: FleetWorker + ): + project = await create_project(session) + fleet = await create_fleet( + session=session, + project=project, + spec=get_fleet_spec( + conf=get_fleet_configuration( + placement=InstanceGroupPlacement.CLUSTER, + nodes=FleetNodesSpec(min=2, target=2, max=2), + ) + ), + ) + first_instance = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.PENDING, + job_provisioning_data=None, + offer=None, + instance_num=0, + ) + await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.PENDING, + job_provisioning_data=None, + offer=None, + instance_num=1, + ) + await _lock_fleet_for_processing(session, fleet) + + await worker.process(_fleet_to_pipeline_item(fleet)) + + await session.refresh(fleet) + assert fleet.current_master_instance_id == first_instance.id + + async def test_keeps_current_master_when_it_is_still_active( + self, test_db, session: AsyncSession, worker: FleetWorker + ): + project = await 
create_project(session) + fleet = await create_fleet( + session=session, + project=project, + spec=get_fleet_spec( + conf=get_fleet_configuration( + placement=InstanceGroupPlacement.CLUSTER, + nodes=FleetNodesSpec(min=2, target=2, max=2), + ) + ), + ) + await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.PENDING, + job_provisioning_data=None, + offer=None, + instance_num=0, + ) + current_master = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.PROVISIONING, + job_provisioning_data=get_job_provisioning_data(), + instance_num=1, + ) + fleet.current_master_instance_id = current_master.id + await _lock_fleet_for_processing(session, fleet) + + await worker.process(_fleet_to_pipeline_item(fleet)) + + await session.refresh(fleet) + assert fleet.current_master_instance_id == current_master.id + + async def test_promotes_provisioned_survivor_when_current_master_terminated( + self, test_db, session: AsyncSession, worker: FleetWorker + ): + project = await create_project(session) + fleet = await create_fleet( + session=session, + project=project, + spec=get_fleet_spec( + conf=get_fleet_configuration( + placement=InstanceGroupPlacement.CLUSTER, + nodes=FleetNodesSpec(min=1, target=1, max=2), + ) + ), + ) + terminated_master = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.TERMINATED, + job_provisioning_data=None, + offer=None, + instance_num=0, + ) + provisioned_survivor = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.IDLE, + job_provisioning_data=get_job_provisioning_data(), + instance_num=1, + ) + fleet.current_master_instance_id = terminated_master.id + await _lock_fleet_for_processing(session, fleet) + + await worker.process(_fleet_to_pipeline_item(fleet)) + + await session.refresh(fleet) + await session.refresh(terminated_master) + assert terminated_master.deleted + assert fleet.current_master_instance_id == provisioned_survivor.id + + async def test_promotes_next_bootstrap_candidate_when_current_master_terminated( + self, test_db, session: AsyncSession, worker: FleetWorker + ): + project = await create_project(session) + fleet = await create_fleet( + session=session, + project=project, + spec=get_fleet_spec( + conf=get_fleet_configuration( + placement=InstanceGroupPlacement.CLUSTER, + nodes=FleetNodesSpec(min=1, target=1, max=2), + ) + ), + ) + terminated_master = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.TERMINATED, + job_provisioning_data=None, + offer=None, + instance_num=0, + ) + next_candidate = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.PENDING, + job_provisioning_data=None, + offer=None, + instance_num=1, + ) + fleet.current_master_instance_id = terminated_master.id + await _lock_fleet_for_processing(session, fleet) + + await worker.process(_fleet_to_pipeline_item(fleet)) + + await session.refresh(fleet) + assert fleet.current_master_instance_id == next_candidate.id + + async def test_does_not_elect_terminating_bootstrap_candidate_as_master( + self, test_db, session: AsyncSession, worker: FleetWorker + ): + project = await create_project(session) + fleet = await create_fleet( + session=session, + project=project, + spec=get_fleet_spec( + conf=get_fleet_configuration( + placement=InstanceGroupPlacement.CLUSTER, + nodes=FleetNodesSpec(min=1, target=1, max=3), + ) + ), + ) + 
terminated_master = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.TERMINATED, + job_provisioning_data=None, + offer=None, + instance_num=0, + ) + await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.TERMINATING, + job_provisioning_data=None, + offer=None, + instance_num=1, + ) + pending_candidate = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.PENDING, + job_provisioning_data=None, + offer=None, + instance_num=2, + ) + fleet.current_master_instance_id = terminated_master.id + await _lock_fleet_for_processing(session, fleet) + + await worker.process(_fleet_to_pipeline_item(fleet)) + + await session.refresh(fleet) + assert fleet.current_master_instance_id == pending_candidate.id + + async def test_clears_current_master_for_non_cluster_fleet( + self, test_db, session: AsyncSession, worker: FleetWorker + ): + project = await create_project(session) + fleet = await create_fleet( + session=session, + project=project, + spec=get_fleet_spec(), + ) + instance = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.IDLE, + ) + fleet.current_master_instance_id = instance.id + await _lock_fleet_for_processing(session, fleet) + + await worker.process(_fleet_to_pipeline_item(fleet)) + + await session.refresh(fleet) + assert fleet.current_master_instance_id is None + + async def test_syncs_current_master_after_creating_missing_instances( + self, test_db, session: AsyncSession, worker: FleetWorker + ): + project = await create_project(session) + fleet = await create_fleet( + session=session, + project=project, + spec=get_fleet_spec( + conf=get_fleet_configuration( + placement=InstanceGroupPlacement.CLUSTER, + nodes=FleetNodesSpec(min=2, target=2, max=2), + ) + ), + ) + await _lock_fleet_for_processing(session, fleet) + + await worker.process(_fleet_to_pipeline_item(fleet)) + + await session.refresh(fleet) + instances = ( + ( + await session.execute( + select(InstanceModel) + .where(InstanceModel.fleet_id == fleet.id, InstanceModel.deleted == False) + .order_by(InstanceModel.instance_num, InstanceModel.created_at) + ) + ) + .scalars() + .all() + ) + assert len(instances) == 2 + assert fleet.current_master_instance_id == instances[0].id + + async def test_prefers_surviving_instance_over_new_replacement_for_master_election( + self, test_db, session: AsyncSession, worker: FleetWorker + ): + project = await create_project(session) + fleet = await create_fleet( + session=session, + project=project, + spec=get_fleet_spec( + conf=get_fleet_configuration( + placement=InstanceGroupPlacement.CLUSTER, + nodes=FleetNodesSpec(min=2, target=2, max=2), + ) + ), + ) + terminated_master = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.TERMINATED, + job_provisioning_data=None, + offer=None, + instance_num=0, + ) + surviving_instance = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.PENDING, + job_provisioning_data=None, + offer=None, + instance_num=1, + ) + fleet.current_master_instance_id = terminated_master.id + await _lock_fleet_for_processing(session, fleet) + + await worker.process(_fleet_to_pipeline_item(fleet)) + + await session.refresh(fleet) + await session.refresh(terminated_master) + await session.refresh(surviving_instance) + non_deleted_instances = ( + ( + await session.execute( + select(InstanceModel) + 
.where(InstanceModel.fleet_id == fleet.id, InstanceModel.deleted == False) + .order_by(InstanceModel.instance_num, InstanceModel.created_at) + ) + ) + .scalars() + .all() + ) + + assert terminated_master.deleted + assert fleet.current_master_instance_id == surviving_instance.id + assert len(non_deleted_instances) == 2 + assert any( + instance.id != surviving_instance.id and instance.instance_num == 0 + for instance in non_deleted_instances + ) + + async def test_min_zero_failed_master_terminates_unprovisioned_siblings( + self, test_db, session: AsyncSession, worker: FleetWorker + ): + project = await create_project(session) + fleet = await create_fleet( + session=session, + project=project, + spec=get_fleet_spec( + conf=get_fleet_configuration( + placement=InstanceGroupPlacement.CLUSTER, + nodes=FleetNodesSpec(min=0, target=3, max=3), + ) + ), + ) + failed_master = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.TERMINATED, + job_provisioning_data=None, + offer=None, + instance_num=0, + ) + failed_master.termination_reason = InstanceTerminationReason.NO_OFFERS + sibling1 = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.PENDING, + job_provisioning_data=None, + offer=None, + instance_num=1, + ) + sibling2 = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.PENDING, + job_provisioning_data=None, + offer=None, + instance_num=2, + ) + fleet.current_master_instance_id = failed_master.id + await _lock_fleet_for_processing(session, fleet) + + await worker.process(_fleet_to_pipeline_item(fleet)) + + await session.refresh(fleet) + await session.refresh(failed_master) + await session.refresh(sibling1) + await session.refresh(sibling2) + assert failed_master.deleted + assert sibling1.status == InstanceStatus.TERMINATED + assert sibling2.status == InstanceStatus.TERMINATED + assert sibling1.termination_reason == InstanceTerminationReason.MASTER_FAILED + assert sibling2.termination_reason == InstanceTerminationReason.MASTER_FAILED + assert fleet.current_master_instance_id is None + + async def test_min_zero_failed_master_preserves_provisioned_survivor( + self, test_db, session: AsyncSession, worker: FleetWorker + ): + project = await create_project(session) + fleet = await create_fleet( + session=session, + project=project, + spec=get_fleet_spec( + conf=get_fleet_configuration( + placement=InstanceGroupPlacement.CLUSTER, + nodes=FleetNodesSpec(min=0, target=2, max=2), + ) + ), + ) + failed_master = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.TERMINATED, + job_provisioning_data=None, + offer=None, + instance_num=0, + ) + failed_master.termination_reason = InstanceTerminationReason.NO_OFFERS + provisioned_survivor = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.IDLE, + job_provisioning_data=get_job_provisioning_data(), + instance_num=1, + ) + pending_sibling = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.PENDING, + job_provisioning_data=None, + offer=None, + instance_num=2, + ) + fleet.current_master_instance_id = failed_master.id + await _lock_fleet_for_processing(session, fleet) + + await worker.process(_fleet_to_pipeline_item(fleet)) + + await session.refresh(fleet) + await session.refresh(provisioned_survivor) + await session.refresh(pending_sibling) + assert provisioned_survivor.status 
== InstanceStatus.IDLE + assert pending_sibling.status == InstanceStatus.PENDING + assert pending_sibling.termination_reason is None + assert fleet.current_master_instance_id == provisioned_survivor.id + async def test_deletes_empty_autocreated_fleet( self, test_db, session: AsyncSession, worker: FleetWorker ): @@ -392,7 +1141,6 @@ async def test_consolidation_attempt_resets_when_no_changes( ) assert len(instances) == 1 assert fleet.consolidation_attempt == 0 - assert ( - fleet.last_consolidated_at is not None - and fleet.last_consolidated_at > previous_last_consolidated_at - ) + last_consolidated_at = fleet.last_consolidated_at + assert last_consolidated_at + assert last_consolidated_at > previous_last_consolidated_at diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_gateways.py b/src/tests/_internal/server/background/pipeline_tasks/test_gateways.py index 59cbd370e9..a1d7b360f9 100644 --- a/src/tests/_internal/server/background/pipeline_tasks/test_gateways.py +++ b/src/tests/_internal/server/background/pipeline_tasks/test_gateways.py @@ -1,5 +1,6 @@ +import asyncio import uuid -from datetime import datetime, timezone +from datetime import datetime, timedelta, timezone from unittest.mock import MagicMock, Mock, patch import pytest @@ -10,6 +11,8 @@ from dstack._internal.core.errors import BackendError from dstack._internal.core.models.gateways import GatewayProvisioningData, GatewayStatus from dstack._internal.server.background.pipeline_tasks.gateways import ( + GatewayFetcher, + GatewayPipeline, GatewayPipelineItem, GatewayWorker, ) @@ -23,6 +26,7 @@ create_project, list_events, ) +from dstack._internal.utils.common import get_current_datetime @pytest.fixture @@ -30,6 +34,17 @@ def worker() -> GatewayWorker: return GatewayWorker(queue=Mock(), heartbeater=Mock()) +@pytest.fixture +def fetcher() -> GatewayFetcher: + return GatewayFetcher( + queue=asyncio.Queue(), + queue_desired_minsize=1, + min_processing_interval=timedelta(seconds=15), + lock_timeout=timedelta(seconds=30), + heartbeater=Mock(), + ) + + def _gateway_to_pipeline_item(gateway_model: GatewayModel) -> GatewayPipelineItem: assert gateway_model.lock_token is not None assert gateway_model.lock_expires_at is not None @@ -44,6 +59,167 @@ def _gateway_to_pipeline_item(gateway_model: GatewayModel) -> GatewayPipelineIte ) +@pytest.mark.asyncio +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) +class TestGatewayFetcher: + async def test_fetch_selects_eligible_gateways_and_sets_lock_fields( + self, test_db, session: AsyncSession, fetcher: GatewayFetcher + ): + project = await create_project(session=session) + backend = await create_backend(session=session, project_id=project.id) + now = get_current_datetime() + stale = now - timedelta(minutes=1) + + submitted = await create_gateway( + session=session, + project_id=project.id, + backend_id=backend.id, + name="submitted", + status=GatewayStatus.SUBMITTED, + last_processed_at=stale - timedelta(seconds=3), + ) + provisioning = await create_gateway( + session=session, + project_id=project.id, + backend_id=backend.id, + name="provisioning", + status=GatewayStatus.PROVISIONING, + last_processed_at=stale - timedelta(seconds=2), + ) + to_be_deleted = await create_gateway( + session=session, + project_id=project.id, + backend_id=backend.id, + name="to-be-deleted", + status=GatewayStatus.RUNNING, + last_processed_at=stale - timedelta(seconds=1), + ) + to_be_deleted.to_be_deleted = True + + just_created = await create_gateway( + session=session, + 
project_id=project.id, + backend_id=backend.id, + name="just-created", + status=GatewayStatus.SUBMITTED, + last_processed_at=now, + ) + just_created.created_at = now + just_created.last_processed_at = now + + ineligible_status = await create_gateway( + session=session, + project_id=project.id, + backend_id=backend.id, + name="ineligible-status", + status=GatewayStatus.RUNNING, + last_processed_at=stale, + ) + recent = await create_gateway( + session=session, + project_id=project.id, + backend_id=backend.id, + name="recent", + status=GatewayStatus.SUBMITTED, + last_processed_at=now, + ) + recent.created_at = now - timedelta(minutes=2) + recent.last_processed_at = now + + locked = await create_gateway( + session=session, + project_id=project.id, + backend_id=backend.id, + name="locked", + status=GatewayStatus.SUBMITTED, + last_processed_at=stale + timedelta(seconds=1), + ) + locked.lock_expires_at = now + timedelta(minutes=1) + locked.lock_token = uuid.uuid4() + locked.lock_owner = "OtherPipeline" + await session.commit() + + items = await fetcher.fetch(limit=10) + + assert {item.id for item in items} == { + submitted.id, + provisioning.id, + to_be_deleted.id, + just_created.id, + } + assert {(item.id, item.status, item.to_be_deleted) for item in items} == { + (submitted.id, GatewayStatus.SUBMITTED, False), + (provisioning.id, GatewayStatus.PROVISIONING, False), + (to_be_deleted.id, GatewayStatus.RUNNING, True), + (just_created.id, GatewayStatus.SUBMITTED, False), + } + + for gateway in [ + submitted, + provisioning, + to_be_deleted, + just_created, + ineligible_status, + recent, + locked, + ]: + await session.refresh(gateway) + + fetched_gateways = [submitted, provisioning, to_be_deleted, just_created] + assert all(gateway.lock_owner == GatewayPipeline.__name__ for gateway in fetched_gateways) + assert all(gateway.lock_expires_at is not None for gateway in fetched_gateways) + assert all(gateway.lock_token is not None for gateway in fetched_gateways) + assert len({gateway.lock_token for gateway in fetched_gateways}) == 1 + + assert ineligible_status.lock_owner is None + assert recent.lock_owner is None + assert locked.lock_owner == "OtherPipeline" + + async def test_fetch_returns_oldest_gateways_first_up_to_limit( + self, test_db, session: AsyncSession, fetcher: GatewayFetcher + ): + project = await create_project(session=session) + backend = await create_backend(session=session, project_id=project.id) + now = get_current_datetime() + + oldest = await create_gateway( + session=session, + project_id=project.id, + backend_id=backend.id, + name="oldest", + status=GatewayStatus.SUBMITTED, + last_processed_at=now - timedelta(minutes=3), + ) + middle = await create_gateway( + session=session, + project_id=project.id, + backend_id=backend.id, + name="middle", + status=GatewayStatus.PROVISIONING, + last_processed_at=now - timedelta(minutes=2), + ) + newest = await create_gateway( + session=session, + project_id=project.id, + backend_id=backend.id, + name="newest", + status=GatewayStatus.SUBMITTED, + last_processed_at=now - timedelta(minutes=1), + ) + + items = await fetcher.fetch(limit=2) + + assert [item.id for item in items] == [oldest.id, middle.id] + + await session.refresh(oldest) + await session.refresh(middle) + await session.refresh(newest) + + assert oldest.lock_owner == GatewayPipeline.__name__ + assert middle.lock_owner == GatewayPipeline.__name__ + assert newest.lock_owner is None + + @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) class 
TestGatewayWorkerSubmitted: diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_instances/__init__.py b/src/tests/_internal/server/background/pipeline_tasks/test_instances/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_instances/conftest.py b/src/tests/_internal/server/background/pipeline_tasks/test_instances/conftest.py new file mode 100644 index 0000000000..f7600e0ba6 --- /dev/null +++ b/src/tests/_internal/server/background/pipeline_tasks/test_instances/conftest.py @@ -0,0 +1,58 @@ +import asyncio +import datetime as dt +from unittest.mock import Mock + +import pytest + +from dstack._internal.core.backends.base.compute import GoArchType +from dstack._internal.server.background.pipeline_tasks.instances import ( + InstanceFetcher, + InstanceWorker, +) +from dstack._internal.server.background.pipeline_tasks.instances import ( + ssh_deploy as instances_ssh_deploy, +) +from dstack._internal.server.schemas.instances import InstanceCheck + + +@pytest.fixture +def fetcher() -> InstanceFetcher: + return InstanceFetcher( + queue=asyncio.Queue(), + queue_desired_minsize=1, + min_processing_interval=dt.timedelta(seconds=10), + lock_timeout=dt.timedelta(seconds=30), + heartbeater=Mock(), + ) + + +@pytest.fixture +def worker() -> InstanceWorker: + return InstanceWorker(queue=asyncio.Queue(), heartbeater=Mock()) + + +@pytest.fixture +def host_info() -> dict: + return { + "gpu_vendor": "nvidia", + "gpu_name": "T4", + "gpu_memory": 16384, + "gpu_count": 1, + "addresses": ["192.168.100.100/24"], + "disk_size": 260976517120, + "cpus": 32, + "memory": 33544130560, + } + + +@pytest.fixture +def deploy_instance_mock(monkeypatch: pytest.MonkeyPatch, host_info: dict) -> Mock: + mock = Mock( + return_value=( + InstanceCheck(reachable=True), + host_info, + GoArchType.AMD64, + ) + ) + monkeypatch.setattr(instances_ssh_deploy, "_deploy_instance", mock) + return mock diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_instances/helpers.py b/src/tests/_internal/server/background/pipeline_tasks/test_instances/helpers.py new file mode 100644 index 0000000000..81eb0fde5c --- /dev/null +++ b/src/tests/_internal/server/background/pipeline_tasks/test_instances/helpers.py @@ -0,0 +1,40 @@ +import datetime as dt +import uuid + +from sqlalchemy.ext.asyncio import AsyncSession + +from dstack._internal.server.background.pipeline_tasks.instances import ( + InstancePipeline, + InstancePipelineItem, + InstanceWorker, +) +from dstack._internal.server.models import InstanceModel + +LOCK_EXPIRES_AT = dt.datetime(2025, 1, 2, 3, 4, tzinfo=dt.timezone.utc) + + +def instance_to_pipeline_item(instance_model: InstanceModel) -> InstancePipelineItem: + assert instance_model.lock_token is not None + assert instance_model.lock_expires_at is not None + return InstancePipelineItem( + __tablename__=instance_model.__tablename__, + id=instance_model.id, + lock_token=instance_model.lock_token, + lock_expires_at=instance_model.lock_expires_at, + prev_lock_expired=False, + status=instance_model.status, + ) + + +def lock_instance(instance_model: InstanceModel) -> None: + instance_model.lock_token = uuid.uuid4() + instance_model.lock_expires_at = LOCK_EXPIRES_AT + instance_model.lock_owner = InstancePipeline.__name__ + + +async def process_instance( + session: AsyncSession, worker: InstanceWorker, instance_model: InstanceModel +) -> None: + lock_instance(instance_model) + await session.commit() + await 
worker.process(instance_to_pipeline_item(instance_model)) diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_instances/test_check.py b/src/tests/_internal/server/background/pipeline_tasks/test_instances/test_check.py new file mode 100644 index 0000000000..b555556881 --- /dev/null +++ b/src/tests/_internal/server/background/pipeline_tasks/test_instances/test_check.py @@ -0,0 +1,944 @@ +import datetime as dt +import logging +from unittest.mock import Mock + +import pytest +import pytest_asyncio +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from dstack._internal.core.models.fleets import FleetNodesSpec +from dstack._internal.core.models.health import HealthStatus +from dstack._internal.core.models.instances import InstanceStatus, InstanceTerminationReason +from dstack._internal.core.models.profiles import TerminationPolicy +from dstack._internal.core.models.runs import JobStatus +from dstack._internal.server.background.pipeline_tasks.instances import InstanceWorker +from dstack._internal.server.background.pipeline_tasks.instances import check as instances_check +from dstack._internal.server.models import InstanceHealthCheckModel, InstanceModel +from dstack._internal.server.schemas.health.dcgm import DCGMHealthResponse, DCGMHealthResult +from dstack._internal.server.schemas.instances import InstanceCheck +from dstack._internal.server.schemas.runner import ( + ComponentInfo, + ComponentName, + ComponentStatus, + HealthcheckResponse, + InstanceHealthResponse, + TaskListResponse, +) +from dstack._internal.server.services.runner.client import ComponentList, ShimClient +from dstack._internal.server.testing.common import ( + create_fleet, + create_instance, + create_job, + create_project, + create_repo, + create_run, + create_user, + get_fleet_configuration, + get_fleet_spec, + get_remote_connection_info, + list_events, +) +from dstack._internal.utils.common import get_current_datetime +from tests._internal.server.background.pipeline_tasks.test_instances.helpers import ( + process_instance, +) + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("image_config_mock") +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) +class TestCheckInstance: + async def test_check_shim_transitions_provisioning_on_ready( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + monkeypatch: pytest.MonkeyPatch, + ): + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.PROVISIONING, + ) + instance.termination_deadline = get_current_datetime() + dt.timedelta(days=1) + await session.commit() + + monkeypatch.setattr( + instances_check, + "_check_instance_inner", + Mock(return_value=InstanceCheck(reachable=True)), + ) + await process_instance(session, worker, instance) + + await session.refresh(instance) + + assert instance.status == InstanceStatus.IDLE + assert instance.termination_deadline is None + + async def test_check_shim_transitions_provisioning_on_terminating( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + monkeypatch: pytest.MonkeyPatch, + ): + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.PROVISIONING, + ) + instance.started_at = get_current_datetime() + dt.timedelta(minutes=-20) + await session.commit() + + monkeypatch.setattr( + instances_check, + "_check_instance_inner", + 
Mock(return_value=InstanceCheck(reachable=False, message="Shim problem")), + ) + await process_instance(session, worker, instance) + + await session.refresh(instance) + + assert instance.status == InstanceStatus.TERMINATING + assert instance.termination_deadline is not None + + async def test_check_shim_transitions_provisioning_on_busy( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + monkeypatch: pytest.MonkeyPatch, + ): + user = await create_user(session=session) + project = await create_project(session=session, owner=user) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run(session=session, project=project, repo=repo, user=user) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.PROVISIONING, + ) + instance.termination_deadline = get_current_datetime().replace( + tzinfo=dt.timezone.utc + ) + dt.timedelta(days=1) + job = await create_job( + session=session, + run=run, + status=JobStatus.SUBMITTED, + instance=instance, + ) + await session.commit() + + monkeypatch.setattr( + instances_check, + "_check_instance_inner", + Mock(return_value=InstanceCheck(reachable=True)), + ) + await process_instance(session, worker, instance) + + await session.refresh(instance) + await session.refresh(job) + + assert instance.status == InstanceStatus.BUSY + assert instance.termination_deadline is None + assert job.instance == instance + + async def test_check_shim_start_termination_deadline( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + monkeypatch: pytest.MonkeyPatch, + ): + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.IDLE, + unreachable=False, + ) + + monkeypatch.setattr( + instances_check, + "_check_instance_inner", + Mock(return_value=InstanceCheck(reachable=False, message="SSH connection fail")), + ) + await process_instance(session, worker, instance) + + await session.refresh(instance) + + assert instance.status == InstanceStatus.IDLE + assert instance.unreachable is True + assert instance.termination_deadline is not None + assert instance.termination_deadline.replace( + tzinfo=dt.timezone.utc + ) > get_current_datetime() + dt.timedelta(minutes=19) + + async def test_check_shim_does_not_start_termination_deadline_with_ssh_instance( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + monkeypatch: pytest.MonkeyPatch, + ): + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.IDLE, + unreachable=False, + remote_connection_info=get_remote_connection_info(), + ) + + monkeypatch.setattr( + instances_check, + "_check_instance_inner", + Mock(return_value=InstanceCheck(reachable=False, message="SSH connection fail")), + ) + await process_instance(session, worker, instance) + + await session.refresh(instance) + + assert instance.status == InstanceStatus.IDLE + assert instance.unreachable is True + assert instance.termination_deadline is None + + async def test_check_shim_stop_termination_deadline( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + monkeypatch: pytest.MonkeyPatch, + ): + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.IDLE, + ) + instance.termination_deadline = get_current_datetime() + dt.timedelta(minutes=19) + await session.commit() 
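+ # A successful (reachable) check should clear the previously set termination deadline.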
+ + monkeypatch.setattr( + instances_check, + "_check_instance_inner", + Mock(return_value=InstanceCheck(reachable=True)), + ) + await process_instance(session, worker, instance) + + await session.refresh(instance) + + assert instance.status == InstanceStatus.IDLE + assert instance.termination_deadline is None + + async def test_check_shim_terminate_instance_by_deadline( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + monkeypatch: pytest.MonkeyPatch, + ): + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.IDLE, + ) + termination_deadline_time = get_current_datetime() + dt.timedelta(minutes=-19) + instance.termination_deadline = termination_deadline_time + await session.commit() + + monkeypatch.setattr( + instances_check, + "_check_instance_inner", + Mock(return_value=InstanceCheck(reachable=False, message="Not ok")), + ) + await process_instance(session, worker, instance) + + await session.refresh(instance) + + assert instance.status == InstanceStatus.TERMINATING + assert instance.termination_deadline == termination_deadline_time + assert instance.termination_reason == InstanceTerminationReason.UNREACHABLE + + @pytest.mark.parametrize( + ["termination_policy", "has_job"], + [ + pytest.param(TerminationPolicy.DESTROY_AFTER_IDLE, False, id="destroy-no-job"), + pytest.param(TerminationPolicy.DESTROY_AFTER_IDLE, True, id="destroy-with-job"), + pytest.param(TerminationPolicy.DONT_DESTROY, False, id="dont-destroy-no-job"), + pytest.param(TerminationPolicy.DONT_DESTROY, True, id="dont-destroy-with-job"), + ], + ) + async def test_check_shim_process_unreachable_state( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + monkeypatch: pytest.MonkeyPatch, + termination_policy: TerminationPolicy, + has_job: bool, + ): + project = await create_project(session=session) + if has_job: + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run(session=session, project=project, repo=repo, user=user) + job = await create_job( + session=session, + run=run, + status=JobStatus.SUBMITTED, + ) + else: + job = None + instance = await create_instance( + session=session, + project=project, + created_at=get_current_datetime(), + termination_policy=termination_policy, + status=InstanceStatus.IDLE, + unreachable=True, + job=job, + ) + + monkeypatch.setattr( + instances_check, + "_check_instance_inner", + Mock(return_value=InstanceCheck(reachable=True)), + ) + await process_instance(session, worker, instance) + + await session.refresh(instance) + events = await list_events(session) + + assert instance.status == InstanceStatus.IDLE + assert instance.unreachable is False + assert len(events) == 1 + assert events[0].message == "Instance became reachable" + + @pytest.mark.parametrize("health_status", [HealthStatus.HEALTHY, HealthStatus.FAILURE]) + async def test_check_shim_switch_to_unreachable_state( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + monkeypatch: pytest.MonkeyPatch, + health_status: HealthStatus, + ): + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.IDLE, + unreachable=False, + health_status=health_status, + ) + + monkeypatch.setattr( + instances_check, + "_check_instance_inner", + Mock(return_value=InstanceCheck(reachable=False)), + ) + await process_instance(session, 
worker, instance) + + await session.refresh(instance) + events = await list_events(session) + + assert instance.status == InstanceStatus.IDLE + assert instance.unreachable is True + assert instance.health == health_status + assert len(events) == 1 + assert events[0].message == "Instance became unreachable" + + async def test_check_shim_check_instance_health( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + monkeypatch: pytest.MonkeyPatch, + ): + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.IDLE, + unreachable=False, + health_status=HealthStatus.HEALTHY, + ) + health_response = InstanceHealthResponse( + dcgm=DCGMHealthResponse( + overall_health=DCGMHealthResult.DCGM_HEALTH_RESULT_WARN, + incidents=[], + ) + ) + + monkeypatch.setattr( + instances_check, + "_check_instance_inner", + Mock( + return_value=InstanceCheck( + reachable=True, + health_response=health_response, + ) + ), + ) + await process_instance(session, worker, instance) + + await session.refresh(instance) + events = await list_events(session) + + assert instance.status == InstanceStatus.IDLE + assert instance.unreachable is False + assert instance.health == HealthStatus.WARNING + assert len(events) == 1 + assert events[0].message == "Instance health changed HEALTHY -> WARNING" + + res = await session.execute(select(InstanceHealthCheckModel)) + health_check = res.scalars().one() + assert health_check.status == HealthStatus.WARNING + assert health_check.response == health_response.json() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) +class TestProcessIdleTimeout: + async def test_does_not_terminate_by_idle_timeout_when_fleet_at_min_nodes( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + monkeypatch: pytest.MonkeyPatch, + ): + project = await create_project(session=session) + fleet = await create_fleet( + session=session, + project=project, + spec=get_fleet_spec( + get_fleet_configuration(nodes=FleetNodesSpec(min=1, target=1, max=1)) + ), + ) + instance = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.IDLE, + ) + instance.termination_idle_time = 300 + instance.termination_policy = TerminationPolicy.DESTROY_AFTER_IDLE + instance.last_job_processed_at = get_current_datetime() + dt.timedelta(minutes=-19) + await session.commit() + + monkeypatch.setattr( + instances_check, + "_check_instance_inner", + Mock(return_value=InstanceCheck(reachable=True)), + ) + + await process_instance(session, worker, instance) + await session.refresh(instance) + + assert instance.status == InstanceStatus.IDLE + assert instance.termination_reason is None + + async def test_terminates_by_idle_timeout_when_fleet_above_min_nodes( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + ): + project = await create_project(session=session) + fleet = await create_fleet( + session=session, + project=project, + spec=get_fleet_spec( + get_fleet_configuration(nodes=FleetNodesSpec(min=1, target=2, max=2)) + ), + ) + instance = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.IDLE, + ) + await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.IDLE, + ) + instance.termination_idle_time = 300 + instance.termination_policy = TerminationPolicy.DESTROY_AFTER_IDLE + instance.last_job_processed_at = 
get_current_datetime() + dt.timedelta(minutes=-19) + await session.commit() + + await process_instance(session, worker, instance) + await session.refresh(instance) + + assert instance.status == InstanceStatus.TERMINATING + assert instance.termination_reason == InstanceTerminationReason.IDLE_TIMEOUT + + async def test_terminate_by_idle_timeout( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + ): + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.IDLE, + ) + instance.termination_idle_time = 300 + instance.termination_policy = TerminationPolicy.DESTROY_AFTER_IDLE + instance.last_job_processed_at = get_current_datetime() + dt.timedelta(minutes=-19) + await session.commit() + + await process_instance(session, worker, instance) + await session.refresh(instance) + + assert instance.status == InstanceStatus.TERMINATING + assert instance.termination_reason == InstanceTerminationReason.IDLE_TIMEOUT + + +@pytest.mark.asyncio +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) +class BaseTestMaybeInstallComponents: + EXPECTED_VERSION = "0.20.1" + + @pytest_asyncio.fixture + async def instance(self, session: AsyncSession) -> InstanceModel: + project = await create_project(session=session) + return await create_instance( + session=session, + project=project, + status=InstanceStatus.BUSY, + ) + + @pytest.fixture + def component_list(self) -> ComponentList: + return ComponentList() + + @pytest.fixture + def debug_task_log(self, caplog: pytest.LogCaptureFixture) -> pytest.LogCaptureFixture: + caplog.set_level(level=logging.DEBUG, logger=instances_check.__name__) + return caplog + + @pytest.fixture + def shim_client_mock( + self, + monkeypatch: pytest.MonkeyPatch, + component_list: ComponentList, + ) -> Mock: + mock = Mock(spec_set=ShimClient) + mock.healthcheck.return_value = HealthcheckResponse( + service="dstack-shim", + version=self.EXPECTED_VERSION, + ) + mock.get_instance_health.return_value = InstanceHealthResponse() + mock.get_components.return_value = component_list + mock.list_tasks.return_value = TaskListResponse(tasks=[]) + mock.is_safe_to_restart.return_value = False + monkeypatch.setattr( + "dstack._internal.server.services.runner.client.ShimClient", + Mock(return_value=mock), + ) + return mock + + +@pytest.mark.usefixtures("get_dstack_runner_version_mock") +class TestMaybeInstallRunner(BaseTestMaybeInstallComponents): + @pytest.fixture + def component_list(self) -> ComponentList: + components = ComponentList() + components.add( + ComponentInfo( + name=ComponentName.RUNNER, + version=self.EXPECTED_VERSION, + status=ComponentStatus.INSTALLED, + ), + ) + return components + + @pytest.fixture + def get_dstack_runner_version_mock(self, monkeypatch: pytest.MonkeyPatch) -> Mock: + mock = Mock(return_value=self.EXPECTED_VERSION) + monkeypatch.setattr(instances_check, "get_dstack_runner_version", mock) + return mock + + @pytest.fixture + def get_dstack_runner_download_url_mock(self, monkeypatch: pytest.MonkeyPatch) -> Mock: + mock = Mock(return_value="https://example.com/runner") + monkeypatch.setattr(instances_check, "get_dstack_runner_download_url", mock) + return mock + + async def test_cannot_determine_expected_version( + self, + test_db, + instance: InstanceModel, + debug_task_log: pytest.LogCaptureFixture, + shim_client_mock: Mock, + get_dstack_runner_version_mock: Mock, + ): + get_dstack_runner_version_mock.return_value = None + + 
+        instances_check._maybe_install_components(instance, shim_client_mock)
+
+        shim_client_mock.get_components.assert_called_once()
+        shim_client_mock.install_runner.assert_not_called()
+
+    async def test_expected_version_already_installed(
+        self,
+        test_db,
+        instance: InstanceModel,
+        debug_task_log: pytest.LogCaptureFixture,
+        shim_client_mock: Mock,
+    ):
+        shim_client_mock.get_components.return_value.runner.version = self.EXPECTED_VERSION
+
+        instances_check._maybe_install_components(instance, shim_client_mock)
+
+        assert "expected runner version already installed" in debug_task_log.text
+        shim_client_mock.get_components.assert_called_once()
+        shim_client_mock.install_runner.assert_not_called()
+
+    @pytest.mark.parametrize("status", [ComponentStatus.NOT_INSTALLED, ComponentStatus.ERROR])
+    async def test_install_not_installed_or_error(
+        self,
+        test_db,
+        instance: InstanceModel,
+        debug_task_log: pytest.LogCaptureFixture,
+        shim_client_mock: Mock,
+        get_dstack_runner_download_url_mock: Mock,
+        status: ComponentStatus,
+    ):
+        shim_client_mock.get_components.return_value.runner.version = ""
+        shim_client_mock.get_components.return_value.runner.status = status
+
+        instances_check._maybe_install_components(instance, shim_client_mock)
+
+        assert f"installing runner (no version) -> {self.EXPECTED_VERSION}" in debug_task_log.text
+        get_dstack_runner_download_url_mock.assert_called_once_with(
+            arch=None,
+            version=self.EXPECTED_VERSION,
+        )
+        shim_client_mock.get_components.assert_called_once()
+        shim_client_mock.install_runner.assert_called_once_with(
+            get_dstack_runner_download_url_mock.return_value
+        )
+
+    @pytest.mark.parametrize("installed_version", ["0.19.40", "0.21.0", "dev"])
+    async def test_install_installed(
+        self,
+        test_db,
+        instance: InstanceModel,
+        debug_task_log: pytest.LogCaptureFixture,
+        shim_client_mock: Mock,
+        get_dstack_runner_download_url_mock: Mock,
+        installed_version: str,
+    ):
+        shim_client_mock.get_components.return_value.runner.version = installed_version
+
+        instances_check._maybe_install_components(instance, shim_client_mock)
+
+        assert (
+            f"installing runner {installed_version} -> {self.EXPECTED_VERSION}"
+            in debug_task_log.text
+        )
+        get_dstack_runner_download_url_mock.assert_called_once_with(
+            arch=None,
+            version=self.EXPECTED_VERSION,
+        )
+        shim_client_mock.get_components.assert_called_once()
+        shim_client_mock.install_runner.assert_called_once_with(
+            get_dstack_runner_download_url_mock.return_value
+        )
+
+    async def test_already_installing(
+        self,
+        test_db,
+        instance: InstanceModel,
+        debug_task_log: pytest.LogCaptureFixture,
+        shim_client_mock: Mock,
+    ):
+        shim_client_mock.get_components.return_value.runner.version = "dev"
+        shim_client_mock.get_components.return_value.runner.status = ComponentStatus.INSTALLING
+
+        instances_check._maybe_install_components(instance, shim_client_mock)
+
+        assert "runner is already being installed" in debug_task_log.text
+        shim_client_mock.get_components.assert_called_once()
+        shim_client_mock.install_runner.assert_not_called()
+
+
+@pytest.mark.usefixtures("get_dstack_shim_version_mock")
+class TestMaybeInstallShim(BaseTestMaybeInstallComponents):
+    @pytest.fixture
+    def component_list(self) -> ComponentList:
+        components = ComponentList()
+        components.add(
+            ComponentInfo(
+                name=ComponentName.SHIM,
+                version=self.EXPECTED_VERSION,
+                status=ComponentStatus.INSTALLED,
+            ),
+        )
+        return components
+
+    @pytest.fixture
+    def get_dstack_shim_version_mock(self, monkeypatch: pytest.MonkeyPatch) -> Mock:
+        mock = Mock(return_value=self.EXPECTED_VERSION)
+        monkeypatch.setattr(instances_check, "get_dstack_shim_version", mock)
+        return mock
+
+    @pytest.fixture
+    def get_dstack_shim_download_url_mock(self, monkeypatch: pytest.MonkeyPatch) -> Mock:
+        mock = Mock(return_value="https://example.com/shim")
+        monkeypatch.setattr(instances_check, "get_dstack_shim_download_url", mock)
+        return mock
+
+    async def test_cannot_determine_expected_version(
+        self,
+        test_db,
+        instance: InstanceModel,
+        debug_task_log: pytest.LogCaptureFixture,
+        shim_client_mock: Mock,
+        get_dstack_shim_version_mock: Mock,
+    ):
+        get_dstack_shim_version_mock.return_value = None
+
+        instances_check._maybe_install_components(instance, shim_client_mock)
+
+        shim_client_mock.get_components.assert_called_once()
+        shim_client_mock.install_shim.assert_not_called()
+
+    async def test_expected_version_already_installed(
+        self,
+        test_db,
+        instance: InstanceModel,
+        debug_task_log: pytest.LogCaptureFixture,
+        shim_client_mock: Mock,
+    ):
+        shim_client_mock.get_components.return_value.shim.version = self.EXPECTED_VERSION
+
+        instances_check._maybe_install_components(instance, shim_client_mock)
+
+        assert "expected shim version already installed" in debug_task_log.text
+        shim_client_mock.get_components.assert_called_once()
+        shim_client_mock.install_shim.assert_not_called()
+
+    @pytest.mark.parametrize("status", [ComponentStatus.NOT_INSTALLED, ComponentStatus.ERROR])
+    async def test_install_not_installed_or_error(
+        self,
+        test_db,
+        instance: InstanceModel,
+        debug_task_log: pytest.LogCaptureFixture,
+        shim_client_mock: Mock,
+        get_dstack_shim_download_url_mock: Mock,
+        status: ComponentStatus,
+    ):
+        shim_client_mock.get_components.return_value.shim.version = ""
+        shim_client_mock.get_components.return_value.shim.status = status
+
+        instances_check._maybe_install_components(instance, shim_client_mock)
+
+        assert f"installing shim (no version) -> {self.EXPECTED_VERSION}" in debug_task_log.text
+        get_dstack_shim_download_url_mock.assert_called_once_with(
+            arch=None,
+            version=self.EXPECTED_VERSION,
+        )
+        shim_client_mock.get_components.assert_called_once()
+        shim_client_mock.install_shim.assert_called_once_with(
+            get_dstack_shim_download_url_mock.return_value
+        )
+
+    @pytest.mark.parametrize("installed_version", ["0.19.40", "0.21.0", "dev"])
+    async def test_install_installed(
+        self,
+        test_db,
+        instance: InstanceModel,
+        debug_task_log: pytest.LogCaptureFixture,
+        shim_client_mock: Mock,
+        get_dstack_shim_download_url_mock: Mock,
+        installed_version: str,
+    ):
+        shim_client_mock.get_components.return_value.shim.version = installed_version
+
+        instances_check._maybe_install_components(instance, shim_client_mock)
+
+        assert (
+            f"installing shim {installed_version} -> {self.EXPECTED_VERSION}"
+            in debug_task_log.text
+        )
+        get_dstack_shim_download_url_mock.assert_called_once_with(
+            arch=None,
+            version=self.EXPECTED_VERSION,
+        )
+        shim_client_mock.get_components.assert_called_once()
+        shim_client_mock.install_shim.assert_called_once_with(
+            get_dstack_shim_download_url_mock.return_value
+        )
+
+    async def test_already_installing(
+        self,
+        test_db,
+        instance: InstanceModel,
+        debug_task_log: pytest.LogCaptureFixture,
+        shim_client_mock: Mock,
+    ):
+        shim_client_mock.get_components.return_value.shim.version = "dev"
+        shim_client_mock.get_components.return_value.shim.status = ComponentStatus.INSTALLING
+
+        instances_check._maybe_install_components(instance, shim_client_mock)
+
+        assert "shim is already being installed" in debug_task_log.text
+        shim_client_mock.get_components.assert_called_once()
+        shim_client_mock.install_shim.assert_not_called()
+
+
+@pytest.mark.usefixtures("maybe_install_runner_mock", "maybe_install_shim_mock")
+class TestMaybeRestartShim(BaseTestMaybeInstallComponents):
+    @pytest.fixture
+    def component_list(self) -> ComponentList:
+        components = ComponentList()
+        components.add(
+            ComponentInfo(
+                name=ComponentName.RUNNER,
+                version=self.EXPECTED_VERSION,
+                status=ComponentStatus.INSTALLED,
+            ),
+        )
+        components.add(
+            ComponentInfo(
+                name=ComponentName.SHIM,
+                version=self.EXPECTED_VERSION,
+                status=ComponentStatus.INSTALLED,
+            ),
+        )
+        return components
+
+    @pytest.fixture
+    def maybe_install_runner_mock(self, monkeypatch: pytest.MonkeyPatch) -> Mock:
+        mock = Mock(return_value=False)
+        monkeypatch.setattr(instances_check, "_maybe_install_runner", mock)
+        return mock
+
+    @pytest.fixture
+    def maybe_install_shim_mock(self, monkeypatch: pytest.MonkeyPatch) -> Mock:
+        mock = Mock(return_value=False)
+        monkeypatch.setattr(instances_check, "_maybe_install_shim", mock)
+        return mock
+
+    async def test_up_to_date(self, test_db, instance: InstanceModel, shim_client_mock: Mock):
+        shim_client_mock.get_version_string.return_value = self.EXPECTED_VERSION
+        shim_client_mock.is_safe_to_restart.return_value = True
+
+        instances_check._maybe_install_components(instance, shim_client_mock)
+
+        shim_client_mock.get_components.assert_called_once()
+        shim_client_mock.shutdown.assert_not_called()
+
+    async def test_no_shim_component_info(
+        self, test_db, instance: InstanceModel, shim_client_mock: Mock
+    ):
+        shim_client_mock.get_components.return_value = ComponentList()
+        shim_client_mock.get_version_string.return_value = "outdated"
+        shim_client_mock.is_safe_to_restart.return_value = True
+
+        instances_check._maybe_install_components(instance, shim_client_mock)
+
+        shim_client_mock.get_components.assert_called_once()
+        shim_client_mock.shutdown.assert_not_called()
+
+    async def test_outdated_shutdown_requested(
+        self, test_db, instance: InstanceModel, shim_client_mock: Mock
+    ):
+        shim_client_mock.get_version_string.return_value = "outdated"
+        shim_client_mock.is_safe_to_restart.return_value = True
+
+        instances_check._maybe_install_components(instance, shim_client_mock)
+
+        shim_client_mock.get_components.assert_called_once()
+        shim_client_mock.shutdown.assert_called_once_with(force=False)
+
+    async def test_outdated_but_task_wont_survive_restart(
+        self, test_db, instance: InstanceModel, shim_client_mock: Mock
+    ):
+        shim_client_mock.get_version_string.return_value = "outdated"
+        shim_client_mock.is_safe_to_restart.return_value = False
+
+        instances_check._maybe_install_components(instance, shim_client_mock)
+
+        shim_client_mock.get_components.assert_called_once()
+        shim_client_mock.shutdown.assert_not_called()
+
+    async def test_outdated_but_runner_installation_in_progress(
+        self,
+        test_db,
+        instance: InstanceModel,
+        shim_client_mock: Mock,
+        component_list: ComponentList,
+    ):
+        shim_client_mock.get_version_string.return_value = "outdated"
+        shim_client_mock.is_safe_to_restart.return_value = True
+        runner_info = component_list.runner
+        assert runner_info is not None
+        runner_info.status = ComponentStatus.INSTALLING
+
+        instances_check._maybe_install_components(instance, shim_client_mock)
+
+        shim_client_mock.get_components.assert_called_once()
+        shim_client_mock.shutdown.assert_not_called()
+
+    async def test_outdated_but_shim_installation_in_progress(
+        self,
+        test_db,
+        instance: InstanceModel,
+        shim_client_mock: Mock,
+        component_list: ComponentList,
+    ):
+        shim_client_mock.get_version_string.return_value = "outdated"
+        shim_client_mock.is_safe_to_restart.return_value = True
+        shim_info = component_list.shim
+        assert shim_info is not None
+        shim_info.status = ComponentStatus.INSTALLING
+
+        instances_check._maybe_install_components(instance, shim_client_mock)
+
+        shim_client_mock.get_components.assert_called_once()
+        shim_client_mock.shutdown.assert_not_called()
+
+    async def test_outdated_but_runner_installation_requested(
+        self,
+        test_db,
+        instance: InstanceModel,
+        shim_client_mock: Mock,
+        maybe_install_runner_mock: Mock,
+    ):
+        shim_client_mock.get_version_string.return_value = "outdated"
+        shim_client_mock.is_safe_to_restart.return_value = True
+        maybe_install_runner_mock.return_value = True
+
+        instances_check._maybe_install_components(instance, shim_client_mock)
+
+        shim_client_mock.get_components.assert_called_once()
+        shim_client_mock.shutdown.assert_not_called()
+
+    async def test_outdated_but_shim_installation_requested(
+        self,
+        test_db,
+        instance: InstanceModel,
+        shim_client_mock: Mock,
+        maybe_install_shim_mock: Mock,
+    ):
+        shim_client_mock.get_version_string.return_value = "outdated"
+        shim_client_mock.is_safe_to_restart.return_value = True
+        maybe_install_shim_mock.return_value = True
+
+        instances_check._maybe_install_components(instance, shim_client_mock)
+
+        shim_client_mock.get_components.assert_called_once()
+        shim_client_mock.shutdown.assert_not_called()
diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_instances/test_cloud_provisioning.py b/src/tests/_internal/server/background/pipeline_tasks/test_instances/test_cloud_provisioning.py
new file mode 100644
index 0000000000..afcb75336b
--- /dev/null
+++ b/src/tests/_internal/server/background/pipeline_tasks/test_instances/test_cloud_provisioning.py
@@ -0,0 +1,872 @@
+from typing import Optional
+from unittest.mock import Mock, patch
+
+import gpuhunt
+import pytest
+from sqlalchemy import select
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from dstack._internal.core.errors import NoCapacityError, ProvisioningError
+from dstack._internal.core.models.backends.base import BackendType
+from dstack._internal.core.models.fleets import FleetNodesSpec, InstanceGroupPlacement
+from dstack._internal.core.models.instances import (
+    Gpu,
+    InstanceAvailability,
+    InstanceOffer,
+    InstanceOfferWithAvailability,
+    InstanceStatus,
+    InstanceTerminationReason,
+    InstanceType,
+    Resources,
+)
+from dstack._internal.core.models.placement import PlacementGroup, PlacementGroupProvisioningData
+from dstack._internal.core.models.runs import JobProvisioningData
+from dstack._internal.server.background.pipeline_tasks.instances import InstanceWorker
+from dstack._internal.server.models import PlacementGroupModel
+from dstack._internal.server.testing.common import (
+    ComputeMockSpec,
+    create_fleet,
+    create_instance,
+    create_placement_group,
+    create_project,
+    get_fleet_configuration,
+    get_fleet_spec,
+    get_instance_offer_with_availability,
+    get_job_provisioning_data,
+    get_placement_group_provisioning_data,
+)
+from tests._internal.server.background.pipeline_tasks.test_instances.helpers import (
+    instance_to_pipeline_item,
+    lock_instance,
+    process_instance,
+)
+
+
+async def _set_current_master_instance(session: AsyncSession, fleet, instance) -> None:
+    fleet.current_master_instance_id = None if instance is None else instance.id
+    await session.commit()
+
+
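+# The helpers imported above are shared across the instance pipeline tests. Judging by their
+# usage in this file, `process_instance` locks the instance, wraps it in a pipeline item, and
+# runs `worker.process()` on it in one step, while tests that need finer-grained control call
+# `lock_instance`, `instance_to_pipeline_item`, and `worker.process` separately.
+# `_set_current_master_instance` assigns `fleet.current_master_instance_id` directly, so the
+# tests can control master assignment without relying on whatever normally sets it.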
+@pytest.mark.asyncio +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) +class TestCloudProvisioning: + @pytest.mark.parametrize( + ["cpus", "gpus", "requested_blocks", "expected_blocks"], + [ + pytest.param(32, 8, 1, 1, id="gpu-instance-no-blocks"), + pytest.param(32, 8, 2, 2, id="gpu-instance-four-gpu-per-block"), + pytest.param(32, 8, 4, 4, id="gpu-instance-two-gpus-per-block"), + pytest.param(32, 8, None, 8, id="gpu-instance-auto-max-gpu"), + pytest.param(4, 8, None, 4, id="gpu-instance-auto-max-cpu"), + pytest.param(8, 8, None, 8, id="gpu-instance-auto-max-cpu-and-gpu"), + pytest.param(32, 0, 1, 1, id="cpu-instance-no-blocks"), + pytest.param(32, 0, 2, 2, id="cpu-instance-four-cpu-per-block"), + pytest.param(32, 0, 4, 4, id="cpu-instance-two-cpus-per-block"), + pytest.param(32, 0, None, 32, id="cpu-instance-auto-max-cpu"), + ], + ) + async def test_creates_instance( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + cpus: int, + gpus: int, + requested_blocks: Optional[int], + expected_blocks: int, + ): + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.PENDING, + total_blocks=requested_blocks, + busy_blocks=0, + ) + with patch("dstack._internal.server.services.backends.get_project_backends") as m: + backend_mock = Mock() + m.return_value = [backend_mock] + backend_mock.TYPE = BackendType.AWS + gpu = Gpu(name="T4", memory_mib=16384, vendor=gpuhunt.AcceleratorVendor.NVIDIA) + offer = InstanceOfferWithAvailability( + backend=BackendType.AWS, + instance=InstanceType( + name="instance", + resources=Resources( + cpus=cpus, + memory_mib=131072, + spot=False, + gpus=[gpu] * gpus, + ), + ), + region="us", + price=1.0, + availability=InstanceAvailability.AVAILABLE, + total_blocks=expected_blocks, + ) + backend_mock.compute.return_value = Mock(spec=ComputeMockSpec) + backend_mock.compute.return_value.get_offers.return_value = [offer] + backend_mock.compute.return_value.create_instance.return_value = JobProvisioningData( + backend=offer.backend, + instance_type=offer.instance, + instance_id="instance_id", + hostname="1.1.1.1", + internal_ip=None, + region=offer.region, + price=offer.price, + username="ubuntu", + ssh_port=22, + ssh_proxy=None, + dockerized=True, + backend_data=None, + ) + + await process_instance(session, worker, instance) + + await session.refresh(instance) + assert instance.status == InstanceStatus.PROVISIONING + assert instance.total_blocks == expected_blocks + assert instance.busy_blocks == 0 + + @pytest.mark.parametrize("err", [RuntimeError("Unexpected"), ProvisioningError("Expected")]) + async def test_tries_second_offer_if_first_fails( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + err: Exception, + ): + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.PENDING, + ) + aws_mock = Mock() + aws_mock.TYPE = BackendType.AWS + offer = get_instance_offer_with_availability(backend=BackendType.AWS, price=1.0) + aws_mock.compute.return_value = Mock(spec=ComputeMockSpec) + aws_mock.compute.return_value.get_offers.return_value = [offer] + aws_mock.compute.return_value.create_instance.side_effect = err + gcp_mock = Mock() + gcp_mock.TYPE = BackendType.GCP + offer = get_instance_offer_with_availability(backend=BackendType.GCP, price=2.0) + gcp_mock.compute.return_value = Mock(spec=ComputeMockSpec) + 
gcp_mock.compute.return_value.get_offers.return_value = [offer] + gcp_mock.compute.return_value.create_instance.return_value = get_job_provisioning_data( + backend=offer.backend, + region=offer.region, + price=offer.price, + ) + with patch("dstack._internal.server.services.backends.get_project_backends") as m: + m.return_value = [aws_mock, gcp_mock] + await process_instance(session, worker, instance) + + await session.refresh(instance) + assert instance.status == InstanceStatus.PROVISIONING + aws_mock.compute.return_value.create_instance.assert_called_once() + assert instance.backend == BackendType.GCP + + @pytest.mark.parametrize("err", [RuntimeError("Unexpected"), ProvisioningError("Expected")]) + async def test_fails_if_all_offers_fail( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + err: Exception, + ): + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.PENDING, + ) + aws_mock = Mock() + aws_mock.TYPE = BackendType.AWS + offer = get_instance_offer_with_availability(backend=BackendType.AWS, price=1.0) + aws_mock.compute.return_value = Mock(spec=ComputeMockSpec) + aws_mock.compute.return_value.get_offers.return_value = [offer] + aws_mock.compute.return_value.create_instance.side_effect = err + with patch("dstack._internal.server.services.backends.get_project_backends") as m: + m.return_value = [aws_mock] + await process_instance(session, worker, instance) + + await session.refresh(instance) + assert instance.status == InstanceStatus.TERMINATED + assert instance.termination_reason == InstanceTerminationReason.NO_OFFERS + + async def test_fails_if_no_offers( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + ): + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.PENDING, + ) + with patch("dstack._internal.server.services.backends.get_project_backends") as m: + m.return_value = [] + await process_instance(session, worker, instance) + + await session.refresh(instance) + assert instance.status == InstanceStatus.TERMINATED + assert instance.termination_reason == InstanceTerminationReason.NO_OFFERS + + async def test_waits_when_fleet_has_no_current_master( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + ): + project = await create_project(session=session) + fleet = await create_fleet( + session, + project, + spec=get_fleet_spec( + conf=get_fleet_configuration( + placement=InstanceGroupPlacement.CLUSTER, + nodes=FleetNodesSpec(min=2, target=2, max=2), + ) + ), + ) + instance = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.PENDING, + offer=None, + job_provisioning_data=None, + instance_num=0, + ) + + backend_mock = Mock() + backend_mock.TYPE = BackendType.AWS + backend_mock.compute.return_value = Mock(spec=ComputeMockSpec) + with patch("dstack._internal.server.services.backends.get_project_backends") as m: + m.return_value = [backend_mock] + await process_instance(session, worker, instance) + + await session.refresh(instance) + assert instance.status == InstanceStatus.PENDING + assert backend_mock.compute.return_value.create_instance.call_count == 0 + + async def test_waits_for_current_master_to_determine_cluster_placement( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + ): + project = await create_project(session=session) + fleet = await create_fleet( + session, 
+ project, + spec=get_fleet_spec( + conf=get_fleet_configuration( + placement=InstanceGroupPlacement.CLUSTER, + nodes=FleetNodesSpec(min=2, target=2, max=2), + ) + ), + ) + master_instance = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.PENDING, + offer=None, + job_provisioning_data=None, + instance_num=0, + ) + sibling_instance = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.PENDING, + offer=None, + job_provisioning_data=None, + instance_num=1, + ) + await _set_current_master_instance(session, fleet, master_instance) + + backend_mock = Mock() + backend_mock.TYPE = BackendType.AWS + backend_mock.compute.return_value = Mock(spec=ComputeMockSpec) + with patch("dstack._internal.server.services.backends.get_project_backends") as m: + m.return_value = [backend_mock] + await process_instance(session, worker, sibling_instance) + + await session.refresh(master_instance) + await session.refresh(sibling_instance) + assert master_instance.status == InstanceStatus.PENDING + assert sibling_instance.status == InstanceStatus.PENDING + assert backend_mock.compute.return_value.create_instance.call_count == 0 + + async def test_failed_master_does_not_provision_stale_sibling_until_fleet_reassigns_it( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + ): + project = await create_project(session=session) + fleet = await create_fleet( + session, + project, + spec=get_fleet_spec( + conf=get_fleet_configuration( + placement=InstanceGroupPlacement.CLUSTER, + nodes=FleetNodesSpec(min=2, target=2, max=2), + ) + ), + ) + master_instance = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.PENDING, + offer=None, + job_provisioning_data=None, + instance_num=0, + ) + sibling_instance = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.PENDING, + offer=None, + job_provisioning_data=None, + instance_num=1, + ) + await _set_current_master_instance(session, fleet, master_instance) + + lock_instance(master_instance) + lock_instance(sibling_instance) + await session.commit() + master_item = instance_to_pipeline_item(master_instance) + sibling_item = instance_to_pipeline_item(sibling_instance) + + with patch("dstack._internal.server.services.backends.get_project_backends") as m: + m.return_value = [] + await worker.process(master_item) + + await session.refresh(master_instance) + await session.refresh(sibling_instance) + assert master_instance.status == InstanceStatus.TERMINATED + assert master_instance.termination_reason == InstanceTerminationReason.NO_OFFERS + assert sibling_instance.status == InstanceStatus.PENDING + + gcp_mock = Mock() + gcp_mock.TYPE = BackendType.GCP + gcp_mock.compute.return_value = Mock(spec=ComputeMockSpec) + gcp_mock.compute.return_value.get_offers.return_value = [ + get_instance_offer_with_availability(backend=BackendType.GCP, region="us-central1") + ] + gcp_mock.compute.return_value.create_instance.return_value = get_job_provisioning_data( + backend=BackendType.GCP, + region="us-central1", + ) + aws_mock = Mock() + aws_mock.TYPE = BackendType.AWS + aws_mock.compute.return_value = Mock(spec=ComputeMockSpec) + aws_mock.compute.return_value.get_offers.return_value = [ + get_instance_offer_with_availability(backend=BackendType.AWS, region="us-east-1") + ] + aws_mock.compute.return_value.create_placement_group.return_value = ( + get_placement_group_provisioning_data() + ) + 
aws_mock.compute.return_value.create_instance.return_value = get_job_provisioning_data( + backend=BackendType.AWS, + region="us-east-1", + ) + + with patch("dstack._internal.server.services.backends.get_project_backends") as m: + m.return_value = [gcp_mock, aws_mock] + await worker.process(sibling_item) + + await session.refresh(sibling_instance) + assert sibling_instance.status == InstanceStatus.PENDING + assert gcp_mock.compute.return_value.get_offers.call_count == 0 + assert gcp_mock.compute.return_value.create_instance.call_count == 0 + assert aws_mock.compute.return_value.create_instance.call_count == 0 + + await _set_current_master_instance(session, fleet, sibling_instance) + promoted_backend_mock = Mock() + promoted_backend_mock.TYPE = BackendType.AWS + promoted_backend_mock.compute.return_value = Mock(spec=ComputeMockSpec) + promoted_backend_mock.compute.return_value.get_offers.return_value = [ + get_instance_offer_with_availability(backend=BackendType.AWS, region="us-east-1") + ] + promoted_backend_mock.compute.return_value.create_placement_group.return_value = ( + get_placement_group_provisioning_data() + ) + promoted_backend_mock.compute.return_value.create_instance.return_value = ( + get_job_provisioning_data( + backend=BackendType.AWS, + region="us-east-1", + ) + ) + with patch("dstack._internal.server.services.backends.get_project_backends") as m: + m.return_value = [promoted_backend_mock] + await process_instance(session, worker, sibling_instance) + + await session.refresh(sibling_instance) + assert sibling_instance.status == InstanceStatus.PROVISIONING + assert sibling_instance.backend == BackendType.AWS + assert sibling_instance.region == "us-east-1" + assert promoted_backend_mock.compute.return_value.create_instance.call_count == 1 + + async def test_follows_current_master_backend_and_region_constraints( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + ): + project = await create_project(session=session) + fleet = await create_fleet( + session, + project, + spec=get_fleet_spec( + conf=get_fleet_configuration( + placement=InstanceGroupPlacement.CLUSTER, + nodes=FleetNodesSpec(min=2, target=2, max=2), + ) + ), + ) + master_instance = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.IDLE, + job_provisioning_data=get_job_provisioning_data( + backend=BackendType.AWS, + region="us-east-1", + ), + instance_num=0, + ) + sibling_instance = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.PENDING, + offer=None, + job_provisioning_data=None, + instance_num=1, + ) + await _set_current_master_instance(session, fleet, master_instance) + + gcp_mock = Mock() + gcp_mock.TYPE = BackendType.GCP + gcp_mock.compute.return_value = Mock(spec=ComputeMockSpec) + gcp_mock.compute.return_value.get_offers.return_value = [ + get_instance_offer_with_availability(backend=BackendType.GCP, region="us-central1") + ] + gcp_mock.compute.return_value.create_instance.return_value = get_job_provisioning_data( + backend=BackendType.GCP, + region="us-central1", + ) + aws_mock = Mock() + aws_mock.TYPE = BackendType.AWS + aws_mock.compute.return_value = Mock(spec=ComputeMockSpec) + aws_mock.compute.return_value.get_offers.return_value = [ + get_instance_offer_with_availability(backend=BackendType.AWS, region="us-east-1") + ] + aws_mock.compute.return_value.create_instance.return_value = get_job_provisioning_data( + backend=BackendType.AWS, + region="us-east-1", + ) + with 
patch("dstack._internal.server.services.backends.get_project_backends") as m: + m.return_value = [gcp_mock, aws_mock] + await process_instance(session, worker, sibling_instance) + + await session.refresh(sibling_instance) + assert sibling_instance.status == InstanceStatus.PROVISIONING + assert sibling_instance.backend == BackendType.AWS + assert sibling_instance.region == "us-east-1" + assert gcp_mock.compute.return_value.get_offers.call_count == 0 + assert gcp_mock.compute.return_value.create_instance.call_count == 0 + assert aws_mock.compute.return_value.create_instance.call_count == 1 + + async def test_non_master_does_not_create_new_placement_group_without_master_pg( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + ): + project = await create_project(session=session) + fleet = await create_fleet( + session, + project, + spec=get_fleet_spec( + conf=get_fleet_configuration( + placement=InstanceGroupPlacement.CLUSTER, + nodes=FleetNodesSpec(min=2, target=2, max=2), + ) + ), + ) + master_instance = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.IDLE, + job_provisioning_data=get_job_provisioning_data( + backend=BackendType.AWS, + region="us-east-1", + ), + instance_num=0, + ) + sibling_instance = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.PENDING, + offer=None, + job_provisioning_data=None, + instance_num=1, + ) + await _set_current_master_instance(session, fleet, master_instance) + + backend_mock = Mock() + backend_mock.TYPE = BackendType.AWS + backend_mock.compute.return_value = Mock(spec=ComputeMockSpec) + backend_mock.compute.return_value.get_offers.return_value = [ + get_instance_offer_with_availability(backend=BackendType.AWS, region="us-east-1") + ] + backend_mock.compute.return_value.is_suitable_placement_group.return_value = True + backend_mock.compute.return_value.create_instance.return_value = get_job_provisioning_data( + backend=BackendType.AWS, + region="us-east-1", + ) + with patch("dstack._internal.server.services.backends.get_project_backends") as m: + m.return_value = [backend_mock] + await process_instance(session, worker, sibling_instance) + + await session.refresh(sibling_instance) + assert sibling_instance.status == InstanceStatus.PROVISIONING + assert backend_mock.compute.return_value.create_placement_group.call_count == 0 + placement_groups = (await session.execute(select(PlacementGroupModel))).scalars().all() + assert len(placement_groups) == 0 + + async def test_non_master_reuses_existing_current_master_placement_group( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + ): + project = await create_project(session=session) + fleet = await create_fleet( + session, + project, + spec=get_fleet_spec( + conf=get_fleet_configuration( + placement=InstanceGroupPlacement.CLUSTER, + nodes=FleetNodesSpec(min=3, target=3, max=3), + ) + ), + ) + master_instance = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.IDLE, + job_provisioning_data=get_job_provisioning_data( + backend=BackendType.AWS, + region="us-east-1", + ), + instance_num=0, + ) + current_master_pg = await create_placement_group( + session=session, + project=project, + fleet=fleet, + ) + sibling_instance = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.PENDING, + offer=None, + job_provisioning_data=None, + instance_num=1, + ) + await 
_set_current_master_instance(session, fleet, master_instance) + + backend_mock = Mock() + backend_mock.TYPE = BackendType.AWS + backend_mock.compute.return_value = Mock(spec=ComputeMockSpec) + backend_mock.compute.return_value.get_offers.return_value = [ + get_instance_offer_with_availability(backend=BackendType.AWS, region="us-east-1") + ] + backend_mock.compute.return_value.is_suitable_placement_group.return_value = True + backend_mock.compute.return_value.create_instance.return_value = get_job_provisioning_data( + backend=BackendType.AWS, + region="us-east-1", + ) + with patch("dstack._internal.server.services.backends.get_project_backends") as m: + m.return_value = [backend_mock] + await process_instance(session, worker, sibling_instance) + + await session.refresh(sibling_instance) + assert sibling_instance.status == InstanceStatus.PROVISIONING + assert backend_mock.compute.return_value.create_placement_group.call_count == 0 + create_call = backend_mock.compute.return_value.create_instance.call_args + assert create_call is not None + assert create_call.args[2] is not None + assert create_call.args[2].name == current_master_pg.name + placement_groups = (await session.execute(select(PlacementGroupModel))).scalars().all() + assert len(placement_groups) == 1 + + async def test_allows_parallel_processing_after_master_is_provisioned( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + ): + project = await create_project(session=session) + fleet = await create_fleet( + session, + project, + spec=get_fleet_spec( + conf=get_fleet_configuration( + placement=InstanceGroupPlacement.CLUSTER, + nodes=FleetNodesSpec(min=3, target=3, max=3), + ) + ), + ) + master_instance = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.IDLE, + job_provisioning_data=get_job_provisioning_data( + backend=BackendType.AWS, + region="us-east-1", + ), + instance_num=0, + ) + later_instance = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.PENDING, + offer=None, + job_provisioning_data=None, + instance_num=2, + ) + earlier_instance = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.PENDING, + offer=None, + job_provisioning_data=None, + instance_num=1, + ) + await _set_current_master_instance(session, fleet, master_instance) + + backend_mock = Mock() + backend_mock.TYPE = BackendType.AWS + backend_mock.compute.return_value = Mock(spec=ComputeMockSpec) + backend_mock.compute.return_value.get_offers.return_value = [ + get_instance_offer_with_availability(backend=BackendType.AWS, region="us-east-1") + ] + backend_mock.compute.return_value.create_instance.return_value = get_job_provisioning_data( + backend=BackendType.AWS, + region="us-east-1", + ) + with patch("dstack._internal.server.services.backends.get_project_backends") as m: + m.return_value = [backend_mock] + await process_instance(session, worker, later_instance) + assert backend_mock.compute.return_value.create_instance.call_count == 1 + await process_instance(session, worker, earlier_instance) + + await session.refresh(later_instance) + await session.refresh(earlier_instance) + assert later_instance.status == InstanceStatus.PROVISIONING + assert earlier_instance.status == InstanceStatus.PROVISIONING + assert backend_mock.compute.return_value.create_instance.call_count == 2 + + @pytest.mark.parametrize( + ("placement", "should_create"), + [ + pytest.param(InstanceGroupPlacement.CLUSTER, True, 
id="placement-cluster"), + pytest.param(None, False, id="no-placement"), + ], + ) + async def test_create_placement_group_if_placement_cluster( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + placement: Optional[InstanceGroupPlacement], + should_create: bool, + ) -> None: + project = await create_project(session=session) + fleet = await create_fleet( + session, + project, + spec=get_fleet_spec( + conf=get_fleet_configuration( + placement=placement, nodes=FleetNodesSpec(min=1, target=1, max=1) + ) + ), + ) + instance = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.PENDING, + offer=None, + job_provisioning_data=None, + ) + if placement == InstanceGroupPlacement.CLUSTER: + await _set_current_master_instance(session, fleet, instance) + backend_mock = Mock() + backend_mock.TYPE = BackendType.AWS + backend_mock.compute.return_value = Mock(spec=ComputeMockSpec) + backend_mock.compute.return_value.get_offers.return_value = [ + get_instance_offer_with_availability() + ] + backend_mock.compute.return_value.create_instance.return_value = ( + get_job_provisioning_data() + ) + backend_mock.compute.return_value.create_placement_group.return_value = ( + get_placement_group_provisioning_data() + ) + with patch("dstack._internal.server.services.backends.get_project_backends") as m: + m.return_value = [backend_mock] + await process_instance(session, worker, instance) + + await session.refresh(instance) + assert instance.status == InstanceStatus.PROVISIONING + placement_groups = (await session.execute(select(PlacementGroupModel))).scalars().all() + if should_create: + assert backend_mock.compute.return_value.create_placement_group.call_count == 1 + assert len(placement_groups) == 1 + else: + assert backend_mock.compute.return_value.create_placement_group.call_count == 0 + assert len(placement_groups) == 0 + + @pytest.mark.parametrize("can_reuse", [True, False]) + async def test_reuses_placement_group_between_offers_if_the_group_is_suitable( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + can_reuse: bool, + ) -> None: + project = await create_project(session=session) + fleet = await create_fleet( + session, + project, + spec=get_fleet_spec( + conf=get_fleet_configuration( + placement=InstanceGroupPlacement.CLUSTER, + nodes=FleetNodesSpec(min=1, target=1, max=1), + ) + ), + ) + instance = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.PENDING, + offer=None, + job_provisioning_data=None, + ) + await _set_current_master_instance(session, fleet, instance) + backend_mock = Mock() + backend_mock.TYPE = BackendType.AWS + backend_mock.compute.return_value = Mock(spec=ComputeMockSpec) + backend_mock.compute.return_value.get_offers.return_value = [ + get_instance_offer_with_availability(instance_type="bad-offer-1"), + get_instance_offer_with_availability(instance_type="bad-offer-2"), + get_instance_offer_with_availability(instance_type="good-offer"), + ] + + def create_instance_method( + instance_offer: InstanceOfferWithAvailability, *args, **kwargs + ) -> JobProvisioningData: + if instance_offer.instance.name == "good-offer": + return get_job_provisioning_data() + raise NoCapacityError() + + backend_mock.compute.return_value.create_instance = create_instance_method + backend_mock.compute.return_value.create_placement_group.return_value = ( + get_placement_group_provisioning_data() + ) + backend_mock.compute.return_value.is_suitable_placement_group.return_value = 
can_reuse + with patch("dstack._internal.server.services.backends.get_project_backends") as m: + m.return_value = [backend_mock] + await process_instance(session, worker, instance) + + await session.refresh(instance) + assert instance.status == InstanceStatus.PROVISIONING + placement_groups = (await session.execute(select(PlacementGroupModel))).scalars().all() + if can_reuse: + assert backend_mock.compute.return_value.create_placement_group.call_count == 1 + assert len(placement_groups) == 1 + else: + assert backend_mock.compute.return_value.create_placement_group.call_count == 3 + assert len(placement_groups) == 3 + to_be_deleted_count = sum(pg.fleet_deleted for pg in placement_groups) + assert to_be_deleted_count == 2 + + @pytest.mark.parametrize("err", [NoCapacityError(), RuntimeError()]) + async def test_handles_create_placement_group_errors( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + err: Exception, + ) -> None: + project = await create_project(session=session) + fleet = await create_fleet( + session, + project, + spec=get_fleet_spec( + conf=get_fleet_configuration( + placement=InstanceGroupPlacement.CLUSTER, + nodes=FleetNodesSpec(min=1, target=1, max=1), + ) + ), + ) + instance = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.PENDING, + offer=None, + job_provisioning_data=None, + ) + await _set_current_master_instance(session, fleet, instance) + backend_mock = Mock() + backend_mock.TYPE = BackendType.AWS + backend_mock.compute.return_value = Mock(spec=ComputeMockSpec) + backend_mock.compute.return_value.get_offers.return_value = [ + get_instance_offer_with_availability(instance_type="bad-offer"), + get_instance_offer_with_availability(instance_type="good-offer"), + ] + backend_mock.compute.return_value.create_instance.return_value = ( + get_job_provisioning_data() + ) + + def create_placement_group_method( + placement_group: PlacementGroup, master_instance_offer: InstanceOffer + ) -> PlacementGroupProvisioningData: + if master_instance_offer.instance.name == "good-offer": + return get_placement_group_provisioning_data() + raise err + + backend_mock.compute.return_value.create_placement_group = create_placement_group_method + with patch("dstack._internal.server.services.backends.get_project_backends") as m: + m.return_value = [backend_mock] + await process_instance(session, worker, instance) + + await session.refresh(instance) + assert instance.status == InstanceStatus.PROVISIONING + assert instance.offer + assert "good-offer" in instance.offer + assert "bad-offer" not in instance.offer + placement_groups = (await session.execute(select(PlacementGroupModel))).scalars().all() + assert len(placement_groups) == 1 diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_instances/test_pipeline.py b/src/tests/_internal/server/background/pipeline_tasks/test_instances/test_pipeline.py new file mode 100644 index 0000000000..012c7fdb38 --- /dev/null +++ b/src/tests/_internal/server/background/pipeline_tasks/test_instances/test_pipeline.py @@ -0,0 +1,253 @@ +import datetime as dt +import uuid +from unittest.mock import Mock + +import pytest +from sqlalchemy.ext.asyncio import AsyncSession + +from dstack._internal.core.models.instances import InstanceStatus +from dstack._internal.server.background.pipeline_tasks.instances import ( + InstanceFetcher, + InstancePipeline, + InstanceWorker, +) +from dstack._internal.server.background.pipeline_tasks.instances import check as instances_check +from 
dstack._internal.server.schemas.instances import InstanceCheck +from dstack._internal.server.testing.common import ( + create_compute_group, + create_fleet, + create_instance, + create_project, +) +from dstack._internal.utils.common import get_current_datetime +from tests._internal.server.background.pipeline_tasks.test_instances.helpers import ( + instance_to_pipeline_item, + lock_instance, + process_instance, +) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) +class TestInstanceFetcher: + async def test_fetch_selects_eligible_instances_and_sets_lock_fields( + self, test_db, session: AsyncSession, fetcher: InstanceFetcher + ): + project = await create_project(session=session) + fleet = await create_fleet(session=session, project=project) + compute_group = await create_compute_group(session=session, project=project, fleet=fleet) + now = get_current_datetime() + stale = now - dt.timedelta(minutes=1) + + pending = await create_instance( + session=session, + project=project, + status=InstanceStatus.PENDING, + last_processed_at=stale - dt.timedelta(seconds=5), + ) + provisioning = await create_instance( + session=session, + project=project, + status=InstanceStatus.PROVISIONING, + name="provisioning", + last_processed_at=stale - dt.timedelta(seconds=4), + ) + busy = await create_instance( + session=session, + project=project, + status=InstanceStatus.BUSY, + name="busy", + last_processed_at=stale - dt.timedelta(seconds=3), + ) + idle = await create_instance( + session=session, + project=project, + status=InstanceStatus.IDLE, + name="idle", + last_processed_at=stale - dt.timedelta(seconds=2), + ) + terminating = await create_instance( + session=session, + project=project, + status=InstanceStatus.TERMINATING, + name="terminating", + last_processed_at=stale - dt.timedelta(seconds=1), + ) + + deleted = await create_instance( + session=session, + project=project, + status=InstanceStatus.IDLE, + name="deleted", + last_processed_at=stale, + ) + deleted.deleted = True + + recent = await create_instance( + session=session, + project=project, + status=InstanceStatus.IDLE, + name="recent", + last_processed_at=now, + ) + + terminating_compute_group = await create_instance( + session=session, + project=project, + status=InstanceStatus.TERMINATING, + name="terminating-compute-group", + last_processed_at=stale + dt.timedelta(seconds=1), + ) + terminating_compute_group.compute_group = compute_group + + locked = await create_instance( + session=session, + project=project, + status=InstanceStatus.IDLE, + name="locked", + last_processed_at=stale + dt.timedelta(seconds=2), + ) + locked.lock_expires_at = now + dt.timedelta(minutes=1) + locked.lock_token = uuid.uuid4() + locked.lock_owner = "OtherPipeline" + + await session.commit() + + items = await fetcher.fetch(limit=10) + + assert {item.id for item in items} == { + pending.id, + provisioning.id, + busy.id, + idle.id, + terminating.id, + } + assert {item.status for item in items} == { + InstanceStatus.PENDING, + InstanceStatus.PROVISIONING, + InstanceStatus.BUSY, + InstanceStatus.IDLE, + InstanceStatus.TERMINATING, + } + + for instance in [ + pending, + provisioning, + busy, + idle, + terminating, + deleted, + recent, + terminating_compute_group, + locked, + ]: + await session.refresh(instance) + + expected_lock_owner = InstancePipeline.__name__ + fetched_instances = [pending, provisioning, busy, idle, terminating] + assert all(instance.lock_owner == expected_lock_owner for instance in fetched_instances) + 
assert all(instance.lock_expires_at is not None for instance in fetched_instances) + assert all(instance.lock_token is not None for instance in fetched_instances) + assert len({instance.lock_token for instance in fetched_instances}) == 1 + + assert deleted.lock_owner is None + assert recent.lock_owner is None + assert terminating_compute_group.lock_owner is None + assert locked.lock_owner == "OtherPipeline" + + async def test_fetch_respects_order_and_limit( + self, test_db, session: AsyncSession, fetcher: InstanceFetcher + ): + project = await create_project(session=session) + now = get_current_datetime() + + oldest = await create_instance( + session=session, + project=project, + name="oldest", + last_processed_at=now - dt.timedelta(minutes=3), + ) + middle = await create_instance( + session=session, + project=project, + name="middle", + last_processed_at=now - dt.timedelta(minutes=2), + ) + newest = await create_instance( + session=session, + project=project, + name="newest", + last_processed_at=now - dt.timedelta(minutes=1), + ) + + items = await fetcher.fetch(limit=2) + + assert [item.id for item in items] == [oldest.id, middle.id] + + await session.refresh(oldest) + await session.refresh(middle) + await session.refresh(newest) + + assert oldest.lock_owner == InstancePipeline.__name__ + assert middle.lock_owner == InstancePipeline.__name__ + assert newest.lock_owner is None + + +@pytest.mark.asyncio +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) +class TestInstanceWorker: + async def test_process_skips_when_lock_token_changes( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + ): + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.IDLE, + ) + + lock_instance(instance) + await session.commit() + item = instance_to_pipeline_item(instance) + new_lock_token = uuid.uuid4() + instance.lock_token = new_lock_token + await session.commit() + + await worker.process(item) + await session.refresh(instance) + + assert instance.lock_token == new_lock_token + assert instance.lock_owner == InstancePipeline.__name__ + + async def test_process_unlocks_and_updates_last_processed_at_after_check( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + monkeypatch: pytest.MonkeyPatch, + ): + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.PROVISIONING, + ) + before_processed_at = instance.last_processed_at + + monkeypatch.setattr( + instances_check, + "_check_instance_inner", + Mock(return_value=InstanceCheck(reachable=True)), + ) + await process_instance(session, worker, instance) + + await session.refresh(instance) + + assert instance.status == InstanceStatus.IDLE + assert instance.lock_expires_at is None + assert instance.lock_token is None + assert instance.lock_owner is None + assert instance.last_processed_at > before_processed_at diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_instances/test_ssh_deploy.py b/src/tests/_internal/server/background/pipeline_tasks/test_instances/test_ssh_deploy.py new file mode 100644 index 0000000000..c103458ed4 --- /dev/null +++ b/src/tests/_internal/server/background/pipeline_tasks/test_instances/test_ssh_deploy.py @@ -0,0 +1,248 @@ +import datetime as dt +from typing import Optional +from unittest.mock import Mock + +import pytest +from sqlalchemy.ext.asyncio import AsyncSession + +from 
dstack._internal.core.errors import SSHProvisioningError +from dstack._internal.core.models.backends.base import BackendType +from dstack._internal.core.models.instances import InstanceStatus, InstanceTerminationReason +from dstack._internal.server.background.pipeline_tasks.instances import InstanceWorker +from dstack._internal.server.background.pipeline_tasks.instances import ( + ssh_deploy as instances_ssh_deploy, +) +from dstack._internal.server.testing.common import ( + create_instance, + create_project, + get_job_provisioning_data, + get_remote_connection_info, +) +from dstack._internal.utils.common import get_current_datetime +from tests._internal.server.background.pipeline_tasks.test_instances.helpers import ( + process_instance, +) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) +class TestSSHDeploy: + async def test_pending_ssh_instance_terminates_on_provision_timeout( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + ): + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.PENDING, + created_at=get_current_datetime() - dt.timedelta(days=100), + remote_connection_info=get_remote_connection_info(), + ) + await session.commit() + + await process_instance(session, worker, instance) + + await session.refresh(instance) + assert instance.status == InstanceStatus.TERMINATED + assert instance.termination_reason == InstanceTerminationReason.PROVISIONING_TIMEOUT + + @pytest.mark.parametrize( + ["cpus", "gpus", "requested_blocks", "expected_blocks"], + [ + pytest.param(32, 8, 1, 1, id="gpu-instance-no-blocks"), + pytest.param(32, 8, 2, 2, id="gpu-instance-four-gpu-per-block"), + pytest.param(32, 8, 4, 4, id="gpu-instance-two-gpus-per-block"), + pytest.param(32, 8, None, 8, id="gpu-instance-auto-max-gpu"), + pytest.param(4, 8, None, 4, id="gpu-instance-auto-max-cpu"), + pytest.param(8, 8, None, 8, id="gpu-instance-auto-max-cpu-and-gpu"), + pytest.param(32, 0, 1, 1, id="cpu-instance-no-blocks"), + pytest.param(32, 0, 2, 2, id="cpu-instance-four-cpu-per-block"), + pytest.param(32, 0, 4, 4, id="cpu-instance-two-cpus-per-block"), + pytest.param(32, 0, None, 32, id="cpu-instance-auto-max-cpu"), + ], + ) + async def test_adds_ssh_instance( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + host_info: dict, + deploy_instance_mock: Mock, + cpus: int, + gpus: int, + requested_blocks: Optional[int], + expected_blocks: int, + ): + host_info["cpus"] = cpus + host_info["gpu_count"] = gpus + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.PENDING, + created_at=get_current_datetime(), + remote_connection_info=get_remote_connection_info(), + total_blocks=requested_blocks, + busy_blocks=0, + ) + await session.commit() + + await process_instance(session, worker, instance) + + await session.refresh(instance) + assert instance.status == InstanceStatus.IDLE + assert instance.total_blocks == expected_blocks + assert instance.busy_blocks == 0 + deploy_instance_mock.assert_called_once() + + async def test_retries_ssh_instance_if_provisioning_fails( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + deploy_instance_mock: Mock, + ): + deploy_instance_mock.side_effect = SSHProvisioningError("Expected") + project = await create_project(session=session) + instance = await create_instance( + session=session, + 
project=project, + status=InstanceStatus.PENDING, + created_at=get_current_datetime(), + remote_connection_info=get_remote_connection_info(), + ) + await session.commit() + + await process_instance(session, worker, instance) + + await session.refresh(instance) + assert instance.status == InstanceStatus.PENDING + assert instance.termination_reason is None + + async def test_terminates_ssh_instance_if_deploy_fails_unexpectedly( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + deploy_instance_mock: Mock, + ): + deploy_instance_mock.side_effect = RuntimeError("Unexpected") + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.PENDING, + created_at=get_current_datetime(), + remote_connection_info=get_remote_connection_info(), + ) + await session.commit() + + await process_instance(session, worker, instance) + + await session.refresh(instance) + assert instance.status == InstanceStatus.TERMINATED + assert instance.termination_reason == InstanceTerminationReason.ERROR + assert instance.termination_reason_message == "Unexpected error when adding SSH instance" + + async def test_terminates_ssh_instance_if_key_is_invalid( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + monkeypatch: pytest.MonkeyPatch, + ): + monkeypatch.setattr( + instances_ssh_deploy, + "ssh_keys_to_pkeys", + Mock(side_effect=ValueError("Bad key")), + ) + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.PENDING, + created_at=get_current_datetime(), + remote_connection_info=get_remote_connection_info(), + ) + await session.commit() + + await process_instance(session, worker, instance) + + await session.refresh(instance) + assert instance.status == InstanceStatus.TERMINATED + assert instance.termination_reason == InstanceTerminationReason.ERROR + assert instance.termination_reason_message == "Unsupported private SSH key type" + + async def test_terminates_ssh_instance_if_internal_ip_cannot_be_resolved_from_network( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + host_info: dict, + deploy_instance_mock: Mock, + ): + host_info["addresses"] = ["192.168.100.100/24"] + project = await create_project(session=session) + job_provisioning_data = get_job_provisioning_data( + dockerized=True, + backend=BackendType.REMOTE, + internal_ip=None, + ) + job_provisioning_data.instance_network = "10.0.0.0/24" + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.PENDING, + created_at=get_current_datetime(), + remote_connection_info=get_remote_connection_info(), + job_provisioning_data=job_provisioning_data, + ) + await session.commit() + + await process_instance(session, worker, instance) + + await session.refresh(instance) + assert instance.status == InstanceStatus.TERMINATED + assert instance.termination_reason == InstanceTerminationReason.ERROR + assert ( + instance.termination_reason_message + == "Failed to locate internal IP address on the given network" + ) + + async def test_terminates_ssh_instance_if_internal_ip_is_not_in_host_interfaces( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + host_info: dict, + deploy_instance_mock: Mock, + ): + host_info["addresses"] = ["192.168.100.100/24"] + project = await create_project(session=session) + job_provisioning_data = get_job_provisioning_data( + dockerized=True, + 
backend=BackendType.REMOTE, + internal_ip="10.0.0.20", + ) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.PENDING, + created_at=get_current_datetime(), + remote_connection_info=get_remote_connection_info(), + job_provisioning_data=job_provisioning_data, + ) + await session.commit() + + await process_instance(session, worker, instance) + + await session.refresh(instance) + assert instance.status == InstanceStatus.TERMINATED + assert instance.termination_reason == InstanceTerminationReason.ERROR + assert ( + instance.termination_reason_message + == "Specified internal IP not found among instance interfaces" + ) diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_instances/test_termination.py b/src/tests/_internal/server/background/pipeline_tasks/test_instances/test_termination.py new file mode 100644 index 0000000000..b9da58fc11 --- /dev/null +++ b/src/tests/_internal/server/background/pipeline_tasks/test_instances/test_termination.py @@ -0,0 +1,219 @@ +import datetime as dt +from contextlib import contextmanager +from typing import Optional +from unittest.mock import AsyncMock, Mock, patch + +import pytest +from freezegun import freeze_time +from sqlalchemy.ext.asyncio import AsyncSession + +from dstack._internal.core.errors import BackendError, NotYetTerminated +from dstack._internal.core.models.backends.base import BackendType +from dstack._internal.core.models.instances import InstanceStatus, InstanceTerminationReason +from dstack._internal.server.background.pipeline_tasks.instances import InstanceWorker +from dstack._internal.server.background.pipeline_tasks.instances import ( + termination as instances_termination, +) +from dstack._internal.server.testing.common import create_instance, create_project +from tests._internal.server.background.pipeline_tasks.test_instances.helpers import ( + instance_to_pipeline_item, + lock_instance, + process_instance, +) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) +class TestTermination: + @staticmethod + @contextmanager + def mock_terminate_in_backend(error: Optional[Exception] = None): + backend = Mock() + backend.TYPE = BackendType.VERDA + terminate_instance = backend.compute.return_value.terminate_instance + if error is not None: + terminate_instance.side_effect = error + with patch.object( + instances_termination.backends_services, + "get_project_backend_by_type", + AsyncMock(return_value=backend), + ): + yield terminate_instance + + async def test_terminate( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + ): + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.TERMINATING, + ) + instance.termination_reason = InstanceTerminationReason.IDLE_TIMEOUT + instance.last_job_processed_at = dt.datetime.now(dt.timezone.utc) + dt.timedelta( + minutes=-19 + ) + await session.commit() + + with self.mock_terminate_in_backend() as mock: + await process_instance(session, worker, instance) + mock.assert_called_once() + + await session.refresh(instance) + + assert instance.status == InstanceStatus.TERMINATED + assert instance.termination_reason == InstanceTerminationReason.IDLE_TIMEOUT + assert instance.deleted is True + assert instance.deleted_at is not None + assert instance.finished_at is not None + + async def test_terminates_terminating_deleted_instance( + self, + test_db, + session: AsyncSession, + worker: 
InstanceWorker, + ): + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.TERMINATING, + ) + lock_instance(instance) + await session.commit() + item = instance_to_pipeline_item(instance) + instance.deleted = True + instance.termination_reason = InstanceTerminationReason.IDLE_TIMEOUT + instance.last_job_processed_at = instance.deleted_at = dt.datetime.now( + dt.timezone.utc + ) + dt.timedelta(minutes=-19) + await session.commit() + + with self.mock_terminate_in_backend() as mock: + await worker.process(item) + mock.assert_called_once() + + await session.refresh(instance) + + assert instance.status == InstanceStatus.TERMINATED + assert instance.deleted is True + assert instance.deleted_at is not None + assert instance.finished_at is not None + + @pytest.mark.parametrize( + "error", [BackendError("err"), RuntimeError("err"), NotYetTerminated("")] + ) + async def test_terminate_retry( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + error: Exception, + ): + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.TERMINATING, + ) + instance.termination_reason = InstanceTerminationReason.IDLE_TIMEOUT + initial_time = dt.datetime(2025, 1, 1, tzinfo=dt.timezone.utc) + instance.last_job_processed_at = initial_time + instance.last_processed_at = initial_time - dt.timedelta(minutes=1) + await session.commit() + + with ( + freeze_time(initial_time + dt.timedelta(minutes=1)), + self.mock_terminate_in_backend(error=error) as mock, + ): + await process_instance(session, worker, instance) + mock.assert_called_once() + await session.refresh(instance) + assert instance.status == InstanceStatus.TERMINATING + + with ( + freeze_time(initial_time + dt.timedelta(minutes=2)), + self.mock_terminate_in_backend(error=None) as mock, + ): + await process_instance(session, worker, instance) + mock.assert_called_once() + await session.refresh(instance) + assert instance.status == InstanceStatus.TERMINATED + + async def test_terminate_not_retries_if_too_early( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + ): + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.TERMINATING, + ) + instance.termination_reason = InstanceTerminationReason.IDLE_TIMEOUT + initial_time = dt.datetime(2025, 1, 1, tzinfo=dt.timezone.utc) + instance.last_job_processed_at = initial_time + instance.last_processed_at = initial_time - dt.timedelta(minutes=1) + await session.commit() + + with ( + freeze_time(initial_time + dt.timedelta(minutes=1)), + self.mock_terminate_in_backend(error=BackendError("err")) as mock, + ): + await process_instance(session, worker, instance) + mock.assert_called_once() + await session.refresh(instance) + assert instance.status == InstanceStatus.TERMINATING + + instance.last_processed_at = initial_time + await session.commit() + + with ( + freeze_time(initial_time + dt.timedelta(minutes=1, seconds=11)), + self.mock_terminate_in_backend(error=None) as mock, + ): + await process_instance(session, worker, instance) + mock.assert_not_called() + await session.refresh(instance) + assert instance.status == InstanceStatus.TERMINATING + + async def test_terminate_on_termination_deadline( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + ): + project = await create_project(session=session) 
+        instance = await create_instance(
+            session=session,
+            project=project,
+            status=InstanceStatus.TERMINATING,
+        )
+        instance.termination_reason = InstanceTerminationReason.IDLE_TIMEOUT
+        initial_time = dt.datetime(2025, 1, 1, tzinfo=dt.timezone.utc)
+        instance.last_job_processed_at = initial_time
+        instance.last_processed_at = initial_time - dt.timedelta(minutes=1)
+        await session.commit()
+
+        with (
+            freeze_time(initial_time + dt.timedelta(minutes=1)),
+            self.mock_terminate_in_backend(error=BackendError("err")) as mock,
+        ):
+            await process_instance(session, worker, instance)
+        mock.assert_called_once()
+        await session.refresh(instance)
+        assert instance.status == InstanceStatus.TERMINATING
+
+        with (
+            freeze_time(initial_time + dt.timedelta(minutes=15, seconds=55)),
+            self.mock_terminate_in_backend(error=None) as mock,
+        ):
+            await process_instance(session, worker, instance)
+        mock.assert_called_once()
+        await session.refresh(instance)
+        assert instance.status == InstanceStatus.TERMINATED
diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_placement_groups.py b/src/tests/_internal/server/background/pipeline_tasks/test_placement_groups.py
index c23d5e604d..90c8e75194 100644
--- a/src/tests/_internal/server/background/pipeline_tasks/test_placement_groups.py
+++ b/src/tests/_internal/server/background/pipeline_tasks/test_placement_groups.py
@@ -1,5 +1,6 @@
+import asyncio
 import uuid
-from datetime import datetime, timezone
+from datetime import datetime, timedelta, timezone
 from unittest.mock import Mock, patch
 
 import pytest
@@ -7,7 +8,11 @@
 
 from dstack._internal.core.errors import PlacementGroupInUseError
 from dstack._internal.server.background.pipeline_tasks.base import PipelineItem
-from dstack._internal.server.background.pipeline_tasks.placement_groups import PlacementGroupWorker
+from dstack._internal.server.background.pipeline_tasks.placement_groups import (
+    PlacementGroupFetcher,
+    PlacementGroupPipeline,
+    PlacementGroupWorker,
+)
 from dstack._internal.server.models import PlacementGroupModel
 from dstack._internal.server.testing.common import (
     ComputeMockSpec,
@@ -15,6 +20,7 @@
     create_placement_group,
     create_project,
 )
+from dstack._internal.utils.common import get_current_datetime
 
 
 @pytest.fixture
@@ -22,6 +28,17 @@ def worker() -> PlacementGroupWorker:
     return PlacementGroupWorker(queue=Mock(), heartbeater=Mock())
 
 
+@pytest.fixture
+def fetcher() -> PlacementGroupFetcher:
+    return PlacementGroupFetcher(
+        queue=asyncio.Queue(),
+        queue_desired_minsize=1,
+        min_processing_interval=timedelta(seconds=15),
+        lock_timeout=timedelta(seconds=30),
+        heartbeater=Mock(),
+    )
+
+
 def _placement_group_to_pipeline_item(placement_group: PlacementGroupModel) -> PipelineItem:
     assert placement_group.lock_token is not None
     assert placement_group.lock_expires_at is not None
@@ -34,9 +51,133 @@ def _placement_group_to_pipeline_item(placement_group: PlacementGroupModel) -> P
     )
 
 
+@pytest.mark.asyncio
+@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
+class TestPlacementGroupFetcher:
+    async def test_fetch_selects_eligible_placement_groups_and_sets_lock_fields(
+        self, test_db, session: AsyncSession, fetcher: PlacementGroupFetcher
+    ):
+        project = await create_project(session)
+        fleet = await create_fleet(session=session, project=project)
+        now = get_current_datetime()
+        stale = now - timedelta(minutes=1)
+
+        eligible = await create_placement_group(
+            session=session,
+            project=project,
+            fleet=fleet,
+            fleet_deleted=True,
+        )
+        eligible.last_processed_at = stale - timedelta(seconds=2)
+
+        fleet_not_deleted = await create_placement_group(
+            session=session,
+            project=project,
+            fleet=fleet,
+            name="fleet-not-deleted",
+            fleet_deleted=False,
+        )
+        fleet_not_deleted.last_processed_at = stale - timedelta(seconds=1)
+
+        deleted = await create_placement_group(
+            session=session,
+            project=project,
+            fleet=fleet,
+            name="deleted",
+            fleet_deleted=True,
+            deleted=True,
+        )
+        deleted.last_processed_at = stale
+
+        recent = await create_placement_group(
+            session=session,
+            project=project,
+            fleet=fleet,
+            name="recent",
+            fleet_deleted=True,
+        )
+        recent.last_processed_at = now
+
+        locked = await create_placement_group(
+            session=session,
+            project=project,
+            fleet=fleet,
+            name="locked",
+            fleet_deleted=True,
+        )
+        locked.last_processed_at = stale + timedelta(seconds=1)
+        locked.lock_expires_at = now + timedelta(minutes=1)
+        locked.lock_token = uuid.uuid4()
+        locked.lock_owner = "OtherPipeline"
+        await session.commit()
+
+        items = await fetcher.fetch(limit=10)
+
+        assert [item.id for item in items] == [eligible.id]
+
+        for placement_group in [eligible, fleet_not_deleted, deleted, recent, locked]:
+            await session.refresh(placement_group)
+
+        assert eligible.lock_owner == PlacementGroupPipeline.__name__
+        assert eligible.lock_expires_at is not None
+        assert eligible.lock_token is not None
+
+        assert fleet_not_deleted.lock_owner is None
+        assert deleted.lock_owner is None
+        assert recent.lock_owner is None
+        assert locked.lock_owner == "OtherPipeline"
+
+    async def test_fetch_returns_oldest_placement_groups_first_up_to_limit(
+        self, test_db, session: AsyncSession, fetcher: PlacementGroupFetcher
+    ):
+        project = await create_project(session)
+        fleet = await create_fleet(session=session, project=project)
+        now = get_current_datetime()
+
+        oldest = await create_placement_group(
+            session=session,
+            project=project,
+            fleet=fleet,
+            name="oldest",
+            fleet_deleted=True,
+        )
+        oldest.last_processed_at = now - timedelta(minutes=3)
+
+        middle = await create_placement_group(
+            session=session,
+            project=project,
+            fleet=fleet,
+            name="middle",
+            fleet_deleted=True,
+        )
+        middle.last_processed_at = now - timedelta(minutes=2)
+
+        newest = await create_placement_group(
+            session=session,
+            project=project,
+            fleet=fleet,
+            name="newest",
+            fleet_deleted=True,
+        )
+        newest.last_processed_at = now - timedelta(minutes=1)
+        await session.commit()
+
+        items = await fetcher.fetch(limit=2)
+
+        assert [item.id for item in items] == [oldest.id, middle.id]
+
+        await session.refresh(oldest)
+        await session.refresh(middle)
+        await session.refresh(newest)
+
+        assert oldest.lock_owner == PlacementGroupPipeline.__name__
+        assert middle.lock_owner == PlacementGroupPipeline.__name__
+        assert newest.lock_owner is None
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
 class TestPlacementGroupWorker:
-    @pytest.mark.asyncio
-    @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
     async def test_deletes_placement_group(
         self, test_db, session: AsyncSession, worker: PlacementGroupWorker
     ):
@@ -64,8 +205,6 @@ async def test_deletes_placement_group(
         await session.refresh(placement_group)
         assert placement_group.deleted
 
-    @pytest.mark.asyncio
-    @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
     async def test_retries_placement_group_deletion_if_still_in_use(
         self, test_db, session: AsyncSession, worker: PlacementGroupWorker
     ):
diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_volumes.py b/src/tests/_internal/server/background/pipeline_tasks/test_volumes.py
index 4d22c59b97..63dfaaa45a 100644
--- a/src/tests/_internal/server/background/pipeline_tasks/test_volumes.py
+++ b/src/tests/_internal/server/background/pipeline_tasks/test_volumes.py
@@ -1,5 +1,6 @@
+import asyncio
 import uuid
-from datetime import datetime, timezone
+from datetime import datetime, timedelta, timezone
 from unittest.mock import Mock, patch
 
 import pytest
@@ -9,6 +10,8 @@
 from dstack._internal.core.models.backends.base import BackendType
 from dstack._internal.core.models.volumes import VolumeProvisioningData, VolumeStatus
 from dstack._internal.server.background.pipeline_tasks.volumes import (
+    VolumeFetcher,
+    VolumePipeline,
     VolumePipelineItem,
     VolumeWorker,
 )
@@ -22,6 +25,7 @@
     get_volume_provisioning_data,
     list_events,
 )
+from dstack._internal.utils.common import get_current_datetime
 
 
 @pytest.fixture
@@ -29,6 +33,17 @@ def worker() -> VolumeWorker:
     return VolumeWorker(queue=Mock(), heartbeater=Mock())
 
 
+@pytest.fixture
+def fetcher() -> VolumeFetcher:
+    return VolumeFetcher(
+        queue=asyncio.Queue(),
+        queue_desired_minsize=1,
+        min_processing_interval=timedelta(seconds=15),
+        lock_timeout=timedelta(seconds=30),
+        heartbeater=Mock(),
+    )
+
+
 def _volume_to_pipeline_item(volume_model: VolumeModel) -> VolumePipelineItem:
     assert volume_model.lock_token is not None
     assert volume_model.lock_expires_at is not None
@@ -43,6 +58,145 @@ def _volume_to_pipeline_item(volume_model: VolumeModel) -> VolumePipelineItem:
     )
 
 
+@pytest.mark.asyncio
+@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
+class TestVolumeFetcher:
+    async def test_fetch_selects_eligible_volumes_and_sets_lock_fields(
+        self, test_db, session: AsyncSession, fetcher: VolumeFetcher
+    ):
+        project = await create_project(session=session)
+        user = await create_user(session=session)
+        now = get_current_datetime()
+        stale = now - timedelta(minutes=1)
+
+        submitted = await create_volume(
+            session=session,
+            project=project,
+            user=user,
+            status=VolumeStatus.SUBMITTED,
+            created_at=stale - timedelta(minutes=1),
+            last_processed_at=stale - timedelta(seconds=2),
+        )
+        to_be_deleted = await create_volume(
+            session=session,
+            project=project,
+            user=user,
+            status=VolumeStatus.ACTIVE,
+            created_at=stale - timedelta(minutes=1),
+            last_processed_at=stale - timedelta(seconds=1),
+        )
+        to_be_deleted.to_be_deleted = True
+
+        just_created = await create_volume(
+            session=session,
+            project=project,
+            user=user,
+            status=VolumeStatus.SUBMITTED,
+            created_at=now,
+            last_processed_at=now,
+        )
+
+        deleted = await create_volume(
+            session=session,
+            project=project,
+            user=user,
+            status=VolumeStatus.SUBMITTED,
+            created_at=stale - timedelta(minutes=1),
+            last_processed_at=stale,
+            deleted_at=stale,
+        )
+        recent = await create_volume(
+            session=session,
+            project=project,
+            user=user,
+            status=VolumeStatus.SUBMITTED,
+            created_at=now - timedelta(minutes=2),
+            last_processed_at=now,
+        )
+        locked = await create_volume(
+            session=session,
+            project=project,
+            user=user,
+            status=VolumeStatus.SUBMITTED,
+            created_at=stale - timedelta(minutes=1),
+            last_processed_at=stale + timedelta(seconds=1),
+        )
+        locked.lock_expires_at = now + timedelta(minutes=1)
+        locked.lock_token = uuid.uuid4()
+        locked.lock_owner = "OtherPipeline"
+        await session.commit()
+
+        items = await fetcher.fetch(limit=10)
+
+        assert {item.id for item in items} == {
+            submitted.id,
+            to_be_deleted.id,
+            just_created.id,
+        }
+        assert {(item.id, item.status, item.to_be_deleted) for item in items} == {
+            (submitted.id, VolumeStatus.SUBMITTED, False),
+            (to_be_deleted.id, VolumeStatus.ACTIVE, True),
+            (just_created.id, VolumeStatus.SUBMITTED, False),
+        }
+
+        for volume in [submitted, to_be_deleted, just_created, deleted, recent, locked]:
+            await session.refresh(volume)
+
+        fetched_volumes = [submitted, to_be_deleted, just_created]
+        assert all(volume.lock_owner == VolumePipeline.__name__ for volume in fetched_volumes)
+        assert all(volume.lock_expires_at is not None for volume in fetched_volumes)
+        assert all(volume.lock_token is not None for volume in fetched_volumes)
+        assert len({volume.lock_token for volume in fetched_volumes}) == 1
+
+        assert deleted.lock_owner is None
+        assert recent.lock_owner is None
+        assert locked.lock_owner == "OtherPipeline"
+
+    async def test_fetch_returns_oldest_volumes_first_up_to_limit(
+        self, test_db, session: AsyncSession, fetcher: VolumeFetcher
+    ):
+        project = await create_project(session=session)
+        user = await create_user(session=session)
+        now = get_current_datetime()
+
+        oldest = await create_volume(
+            session=session,
+            project=project,
+            user=user,
+            status=VolumeStatus.SUBMITTED,
+            created_at=now - timedelta(minutes=4),
+            last_processed_at=now - timedelta(minutes=3),
+        )
+        middle = await create_volume(
+            session=session,
+            project=project,
+            user=user,
+            status=VolumeStatus.SUBMITTED,
+            created_at=now - timedelta(minutes=3),
+            last_processed_at=now - timedelta(minutes=2),
+        )
+        newest = await create_volume(
+            session=session,
+            project=project,
+            user=user,
+            status=VolumeStatus.SUBMITTED,
+            created_at=now - timedelta(minutes=2),
+            last_processed_at=now - timedelta(minutes=1),
+        )
+
+        items = await fetcher.fetch(limit=2)
+
+        assert [item.id for item in items] == [oldest.id, middle.id]
+
+        await session.refresh(oldest)
+        await session.refresh(middle)
+        await session.refresh(newest)
+
+        assert oldest.lock_owner == VolumePipeline.__name__
+        assert middle.lock_owner == VolumePipeline.__name__
+        assert newest.lock_owner is None
+
+
 @pytest.mark.asyncio
 @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
 class TestVolumeWorkerSubmitted:
diff --git a/src/tests/_internal/server/background/scheduled_tasks/test_instance_healthchecks.py b/src/tests/_internal/server/background/scheduled_tasks/test_instance_healthchecks.py
new file mode 100644
index 0000000000..06ea5ab5ac
--- /dev/null
+++ b/src/tests/_internal/server/background/scheduled_tasks/test_instance_healthchecks.py
@@ -0,0 +1,49 @@
+from datetime import timedelta
+
+import pytest
+from sqlalchemy import select
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from dstack._internal.server.background.scheduled_tasks.instance_healthchecks import (
+    delete_instance_healthchecks,
+)
+from dstack._internal.server.models import InstanceHealthCheckModel, InstanceStatus
+from dstack._internal.server.testing.common import (
+    create_instance,
+    create_instance_health_check,
+    create_project,
+)
+from dstack._internal.utils.common import get_current_datetime
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
+@pytest.mark.usefixtures("test_db", "image_config_mock")
+class TestDeleteInstanceHealthChecks:
+    async def test_deletes_instance_health_checks(
+        self, monkeypatch: pytest.MonkeyPatch, session: AsyncSession
+    ):
+        project = await create_project(session=session)
+        instance = await create_instance(
+            session=session, project=project, status=InstanceStatus.IDLE
+        )
+        # 30 minutes
+        monkeypatch.setattr(
+            "dstack._internal.server.settings.SERVER_INSTANCE_HEALTH_TTL_SECONDS", 1800
+        )
+        now = get_current_datetime()
+        # old check
+        await create_instance_health_check(
+            session=session, instance=instance, collected_at=now - timedelta(minutes=40)
+        )
+        # recent check
+        check = await create_instance_health_check(
+            session=session, instance=instance, collected_at=now - timedelta(minutes=20)
+        )
+
+        await delete_instance_healthchecks()
+
+        res = await session.execute(select(InstanceHealthCheckModel))
+        all_checks = res.scalars().all()
+        assert len(all_checks) == 1
+        assert all_checks[0] == check
diff --git a/src/tests/_internal/server/background/scheduled_tasks/test_instances.py b/src/tests/_internal/server/background/scheduled_tasks/test_instances.py
index 1b9789953e..88e4acc949 100644
--- a/src/tests/_internal/server/background/scheduled_tasks/test_instances.py
+++ b/src/tests/_internal/server/background/scheduled_tasks/test_instances.py
@@ -19,6 +19,7 @@
     NoCapacityError,
     NotYetTerminated,
     ProvisioningError,
+    SSHProvisioningError,
 )
 from dstack._internal.core.models.backends.base import BackendType
 from dstack._internal.core.models.fleets import InstanceGroupPlacement
@@ -40,7 +41,6 @@
     JobStatus,
 )
 from dstack._internal.server.background.scheduled_tasks.instances import (
-    delete_instance_health_checks,
     process_instances,
 )
 from dstack._internal.server.models import (
@@ -65,7 +65,6 @@
     ComputeMockSpec,
     create_fleet,
     create_instance,
-    create_instance_health_check,
     create_job,
     create_project,
     create_repo,
@@ -1206,38 +1205,141 @@ async def test_adds_ssh_instance(
         assert instance.total_blocks == expected_blocks
         assert instance.busy_blocks == 0
 
+    async def test_retries_ssh_instance_if_provisioning_fails(
+        self,
+        session: AsyncSession,
+        deploy_instance_mock: Mock,
+    ):
+        deploy_instance_mock.side_effect = SSHProvisioningError("Expected")
+        project = await create_project(session=session)
+        instance = await create_instance(
+            session=session,
+            project=project,
+            status=InstanceStatus.PENDING,
+            created_at=get_current_datetime(),
+            remote_connection_info=get_remote_connection_info(),
+        )
+        await session.commit()
 
-@pytest.mark.asyncio
-@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
-@pytest.mark.usefixtures("test_db", "image_config_mock")
-class TestDeleteInstanceHealthChecks:
-    async def test_deletes_instance_health_checks(
-        self, monkeypatch: pytest.MonkeyPatch, session: AsyncSession
+        await process_instances()
+
+        await session.refresh(instance)
+        assert instance.status == InstanceStatus.PENDING
+        assert instance.termination_reason is None
+
+    async def test_terminates_ssh_instance_if_deploy_fails_unexpectedly(
+        self,
+        session: AsyncSession,
+        deploy_instance_mock: Mock,
     ):
+        deploy_instance_mock.side_effect = RuntimeError("Unexpected")
         project = await create_project(session=session)
         instance = await create_instance(
-            session=session, project=project, status=InstanceStatus.IDLE
+            session=session,
+            project=project,
+            status=InstanceStatus.PENDING,
+            created_at=get_current_datetime(),
+            remote_connection_info=get_remote_connection_info(),
         )
-        # 30 minutes
+        await session.commit()
+
+        await process_instances()
+
+        await session.refresh(instance)
+        assert instance.status == InstanceStatus.TERMINATED
+        assert instance.termination_reason == InstanceTerminationReason.ERROR
+        assert instance.termination_reason_message == "Unexpected error when adding SSH instance"
+
+    async def test_terminates_ssh_instance_if_key_is_invalid(
+        self,
+        session: AsyncSession,
+        monkeypatch: pytest.MonkeyPatch,
+    ):
         monkeypatch.setattr(
-            "dstack._internal.server.settings.SERVER_INSTANCE_HEALTH_TTL_SECONDS", 1800
+            "dstack._internal.server.background.scheduled_tasks.instances._ssh_keys_to_pkeys",
+            Mock(side_effect=ValueError("Bad key")),
+        )
+        project = await create_project(session=session)
+        instance = await create_instance(
+            session=session,
+            project=project,
+            status=InstanceStatus.PENDING,
+            created_at=get_current_datetime(),
+            remote_connection_info=get_remote_connection_info(),
+        )
+        await session.commit()
+
+        await process_instances()
+
+        await session.refresh(instance)
+        assert instance.status == InstanceStatus.TERMINATED
+        assert instance.termination_reason == InstanceTerminationReason.ERROR
+        assert instance.termination_reason_message == "Unsupported private SSH key type"
+
+    async def test_terminates_ssh_instance_if_internal_ip_cannot_be_resolved_from_network(
+        self,
+        session: AsyncSession,
+        host_info: dict,
+    ):
+        host_info["addresses"] = ["192.168.100.100/24"]
+        project = await create_project(session=session)
+        job_provisioning_data = get_job_provisioning_data(
+            dockerized=True,
+            backend=BackendType.REMOTE,
+            internal_ip=None,
+        )
+        job_provisioning_data.instance_network = "10.0.0.0/24"
+        instance = await create_instance(
+            session=session,
+            project=project,
+            status=InstanceStatus.PENDING,
+            created_at=get_current_datetime(),
+            remote_connection_info=get_remote_connection_info(),
+            job_provisioning_data=job_provisioning_data,
+        )
+        await session.commit()
+
+        await process_instances()
+
+        await session.refresh(instance)
+        assert instance.status == InstanceStatus.TERMINATED
+        assert instance.termination_reason == InstanceTerminationReason.ERROR
+        assert (
+            instance.termination_reason_message
+            == "Failed to locate internal IP address on the given network"
         )
-        now = get_current_datetime()
-        # old check
-        await create_instance_health_check(
-            session=session, instance=instance, collected_at=now - dt.timedelta(minutes=40)
+
+    async def test_terminates_ssh_instance_if_internal_ip_is_not_in_host_interfaces(
+        self,
+        session: AsyncSession,
+        host_info: dict,
+    ):
+        host_info["addresses"] = ["192.168.100.100/24"]
+        project = await create_project(session=session)
+        job_provisioning_data = get_job_provisioning_data(
+            dockerized=True,
+            backend=BackendType.REMOTE,
+            internal_ip="10.0.0.20",
         )
-        # recent check
-        check = await create_instance_health_check(
-            session=session, instance=instance, collected_at=now - dt.timedelta(minutes=20)
+        instance = await create_instance(
+            session=session,
+            project=project,
+            status=InstanceStatus.PENDING,
+            created_at=get_current_datetime(),
+            remote_connection_info=get_remote_connection_info(),
+            job_provisioning_data=job_provisioning_data,
        )
+        await session.commit()
 
-        await delete_instance_health_checks()
+        await process_instances()
 
-        res = await session.execute(select(InstanceHealthCheckModel))
-        all_checks = res.scalars().all()
-        assert len(all_checks) == 1
-        assert all_checks[0] == check
+        await session.refresh(instance)
+        assert instance.status == InstanceStatus.TERMINATED
+        assert instance.termination_reason == InstanceTerminationReason.ERROR
+        assert (
+            instance.termination_reason_message
+            == "Specified internal IP not found among instance interfaces"
+        )
 
 
 @pytest.mark.asyncio
diff --git a/src/tests/_internal/server/routers/test_fleets.py b/src/tests/_internal/server/routers/test_fleets.py
index 12108eed31..d14e74e80d 100644
--- a/src/tests/_internal/server/routers/test_fleets.py
+++ b/src/tests/_internal/server/routers/test_fleets.py
@@ -1659,6 +1659,84 @@ async def test_terminates_fleet_instances(
         assert instance2.status != InstanceStatus.TERMINATING
         assert fleet.status != FleetStatus.TERMINATING
 
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
+    async def test_ignores_lock_on_non_selected_instances(
+        self, test_db, session: AsyncSession, client: AsyncClient
+    ):
+        user = await create_user(session, global_role=GlobalRole.USER)
+        project = await create_project(session)
+        await add_project_member(
+            session=session, project=project, user=user, project_role=ProjectRole.USER
+        )
+        fleet = await create_fleet(session=session, project=project)
+        instance1 = await create_instance(
+            session=session,
+            project=project,
+            instance_num=1,
+        )
+        instance2 = await create_instance(
+            session=session,
+            project=project,
+            instance_num=2,
+        )
+        fleet.instances.append(instance1)
+        fleet.instances.append(instance2)
+        instance2.lock_expires_at = datetime(2023, 1, 2, 3, 5, tzinfo=timezone.utc)
+        await session.commit()
+
+        response = await client.post(
+            f"/api/project/{project.name}/fleets/delete_instances",
+            headers=get_auth_headers(user.token),
+            json={"name": fleet.name, "instance_nums": [1]},
+        )
+        assert response.status_code == 200
+        await session.refresh(fleet)
+        await session.refresh(instance1)
+        await session.refresh(instance2)
+        assert instance1.status == InstanceStatus.TERMINATING
+        assert instance2.status != InstanceStatus.TERMINATING
+        assert fleet.status != FleetStatus.TERMINATING
+
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
+    async def test_returns_400_when_selected_instance_locked(
+        self, test_db, session: AsyncSession, client: AsyncClient
+    ):
+        user = await create_user(session, global_role=GlobalRole.USER)
+        project = await create_project(session)
+        await add_project_member(
+            session=session, project=project, user=user, project_role=ProjectRole.USER
+        )
+        fleet = await create_fleet(session=session, project=project)
+        instance1 = await create_instance(
+            session=session,
+            project=project,
+            instance_num=1,
+        )
+        instance2 = await create_instance(
+            session=session,
+            project=project,
+            instance_num=2,
+        )
+        fleet.instances.append(instance1)
+        fleet.instances.append(instance2)
+        instance1.lock_expires_at = datetime(2023, 1, 2, 3, 5, tzinfo=timezone.utc)
+        await session.commit()
+
+        response = await client.post(
+            f"/api/project/{project.name}/fleets/delete_instances",
+            headers=get_auth_headers(user.token),
+            json={"name": fleet.name, "instance_nums": [1]},
+        )
+        assert response.status_code == 400
+        await session.refresh(fleet)
+        await session.refresh(instance1)
+        await session.refresh(instance2)
+        assert instance1.status != InstanceStatus.TERMINATING
+        assert instance2.status != InstanceStatus.TERMINATING
+        assert fleet.status != FleetStatus.TERMINATING
+
     @pytest.mark.asyncio
     @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
     async def test_returns_400_when_deleting_busy_instances(
diff --git a/src/tests/_internal/server/services/test_instances.py b/src/tests/_internal/server/services/test_instances.py
index 4883e309cc..0eef682003 100644
--- a/src/tests/_internal/server/services/test_instances.py
+++ b/src/tests/_internal/server/services/test_instances.py
@@ -1,4 +1,5 @@
 import uuid
+from unittest.mock import Mock, call
 
 import pytest
 from sqlalchemy.ext.asyncio import AsyncSession
@@ -14,10 +15,16 @@
     Resources,
 )
 from dstack._internal.core.models.profiles import Profile
+from dstack._internal.core.models.runs import JobStatus
 from dstack._internal.server.models import InstanceModel
+from dstack._internal.server.schemas.runner import TaskListItem, TaskListResponse, TaskStatus
+from dstack._internal.server.services.runner.client import ShimClient
 from dstack._internal.server.testing.common import (
     create_instance,
+    create_job,
     create_project,
+    create_repo,
+    create_run,
     create_user,
     get_volume,
     get_volume_configuration,
@@ -155,6 +162,117 @@ async def test_returns_volume_instances(self, test_db, session: AsyncSession):
         assert res == [runpod_instance2]
 
 
+@pytest.mark.asyncio
+@pytest.mark.usefixtures("image_config_mock")
+@pytest.mark.usefixtures("turn_off_keep_shim_tasks_setting")
+@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
+class TestRemoveDanglingTasks:
+    @pytest.fixture
+    def turn_off_keep_shim_tasks_setting(self, monkeypatch: pytest.MonkeyPatch) -> None:
+        monkeypatch.setattr("dstack._internal.server.settings.SERVER_KEEP_SHIM_TASKS", False)
+
+    async def test_terminates_and_removes_dangling_tasks(
+        self, test_db, session: AsyncSession
+    ) -> None:
+        user = await create_user(session=session)
+        project = await create_project(session=session)
+        instance = await create_instance(
+            session=session,
+            project=project,
+            status=InstanceStatus.BUSY,
+        )
+        repo = await create_repo(session=session, project_id=project.id)
+        run = await create_run(
+            session=session,
+            project=project,
+            repo=repo,
+            user=user,
+        )
+        job = await create_job(
+            session=session,
+            run=run,
+            status=JobStatus.RUNNING,
+            instance=instance,
+        )
+        dangling_task_id_1 = "fe138b77-d0b1-49d3-8c9f-2dfe78ece727"
+        dangling_task_id_2 = "8b016a75-41de-44f1-91ff-c9b63d2caa1d"
+        shim_client_mock = Mock(spec_set=ShimClient)
+        shim_client_mock.is_api_v2_supported.return_value = True
+        shim_client_mock.list_tasks.return_value = TaskListResponse(
+            tasks=[
+                TaskListItem(id=str(job.id), status=TaskStatus.RUNNING),
+                TaskListItem(id=dangling_task_id_1, status=TaskStatus.RUNNING),
+                TaskListItem(id=dangling_task_id_2, status=TaskStatus.TERMINATED),
+            ]
+        )
+        await session.refresh(instance, attribute_names=["jobs"])
+
+        instances_services.remove_dangling_tasks_from_instance(shim_client_mock, instance)
+
+        await session.refresh(instance)
+        assert instance.status == InstanceStatus.BUSY
+
+        shim_client_mock.terminate_task.assert_called_once_with(
+            task_id=dangling_task_id_1,
+            reason=None,
+            message=None,
+            timeout=0,
+        )
+        assert shim_client_mock.remove_task.call_count == 2
+        shim_client_mock.remove_task.assert_has_calls(
+            [call(task_id=dangling_task_id_1), call(task_id=dangling_task_id_2)]
+        )
+
+    async def test_terminates_and_removes_dangling_tasks_legacy_shim(
+        self, test_db, session: AsyncSession
+    ) -> None:
+        user = await create_user(session=session)
+        project = await create_project(session=session)
+        instance = await create_instance(
+            session=session,
+            project=project,
+            status=InstanceStatus.BUSY,
+        )
+        repo = await create_repo(session=session, project_id=project.id)
+        run = await create_run(
+            session=session,
+            project=project,
+            repo=repo,
+            user=user,
+        )
+        job = await create_job(
+            session=session,
+            run=run,
+            status=JobStatus.RUNNING,
+            instance=instance,
+        )
+        dangling_task_id_1 = "fe138b77-d0b1-49d3-8c9f-2dfe78ece727"
+        dangling_task_id_2 = "8b016a75-41de-44f1-91ff-c9b63d2caa1d"
+        shim_client_mock = Mock(spec_set=ShimClient)
+        shim_client_mock.is_api_v2_supported.return_value = True
+        shim_client_mock.list_tasks.return_value = TaskListResponse(
+            ids=[str(job.id), dangling_task_id_1, dangling_task_id_2]
+        )
+        await session.refresh(instance, attribute_names=["jobs"])
+
+        instances_services.remove_dangling_tasks_from_instance(shim_client_mock, instance)
+
+        await session.refresh(instance)
+        assert instance.status == InstanceStatus.BUSY
+
+        assert shim_client_mock.terminate_task.call_count == 2
+        shim_client_mock.terminate_task.assert_has_calls(
+            [
+                call(task_id=dangling_task_id_1, reason=None, message=None, timeout=0),
+                call(task_id=dangling_task_id_2, reason=None, message=None, timeout=0),
+            ]
+        )
+        assert shim_client_mock.remove_task.call_count == 2
+        shim_client_mock.remove_task.assert_has_calls(
+            [call(task_id=dangling_task_id_1), call(task_id=dangling_task_id_2)]
+        )
+
+
 class TestInstanceModelToInstance:
     @pytest.mark.asyncio
     @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)