From 6615ff6e46b09314f81fc2427c3326a86a6f9846 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Mon, 2 Mar 2026 12:02:45 +0500 Subject: [PATCH 01/51] Move delete_instance_health_checks to a separate module --- .../background/scheduled_tasks/__init__.py | 14 +++--- .../scheduled_tasks/instance_healthchecks.py | 20 ++++++++ .../background/scheduled_tasks/instances.py | 13 +---- .../_internal/server/services/fleets.py | 6 +-- .../server/services/ssh_fleets/__init__.py | 0 .../test_instance_healthchecks.py | 49 +++++++++++++++++++ .../scheduled_tasks/test_instances.py | 35 ------------- 7 files changed, 80 insertions(+), 57 deletions(-) create mode 100644 src/dstack/_internal/server/background/scheduled_tasks/instance_healthchecks.py create mode 100644 src/dstack/_internal/server/services/ssh_fleets/__init__.py create mode 100644 src/tests/_internal/server/background/scheduled_tasks/test_instance_healthchecks.py diff --git a/src/dstack/_internal/server/background/scheduled_tasks/__init__.py b/src/dstack/_internal/server/background/scheduled_tasks/__init__.py index 9c7cd6ac1a..6b7f6f3389 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/__init__.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/__init__.py @@ -14,8 +14,10 @@ from dstack._internal.server.background.scheduled_tasks.idle_volumes import ( process_idle_volumes, ) +from dstack._internal.server.background.scheduled_tasks.instance_healthchecks import ( + delete_instance_healthchecks, +) from dstack._internal.server.background.scheduled_tasks.instances import ( - delete_instance_health_checks, process_instances, ) from dstack._internal.server.background.scheduled_tasks.metrics import ( @@ -93,16 +95,16 @@ def start_scheduled_tasks() -> AsyncIOScheduler: _scheduler.add_job(collect_metrics, IntervalTrigger(seconds=10), max_instances=1) _scheduler.add_job(delete_metrics, IntervalTrigger(minutes=5), max_instances=1) _scheduler.add_job(delete_events, IntervalTrigger(minutes=7), 
max_instances=1) + _scheduler.add_job(process_gateways_connections, IntervalTrigger(seconds=15)) + _scheduler.add_job( + process_idle_volumes, IntervalTrigger(seconds=60, jitter=10), max_instances=1 + ) + _scheduler.add_job(delete_instance_healthchecks, IntervalTrigger(minutes=5), max_instances=1) if settings.ENABLE_PROMETHEUS_METRICS: _scheduler.add_job( collect_prometheus_metrics, IntervalTrigger(seconds=10), max_instances=1 ) _scheduler.add_job(delete_prometheus_metrics, IntervalTrigger(minutes=5), max_instances=1) - _scheduler.add_job(process_gateways_connections, IntervalTrigger(seconds=15)) - _scheduler.add_job( - process_idle_volumes, IntervalTrigger(seconds=60, jitter=10), max_instances=1 - ) - _scheduler.add_job(delete_instance_health_checks, IntervalTrigger(minutes=5), max_instances=1) if not FeatureFlags.PIPELINE_PROCESSING_ENABLED: _scheduler.add_job( process_fleets, diff --git a/src/dstack/_internal/server/background/scheduled_tasks/instance_healthchecks.py b/src/dstack/_internal/server/background/scheduled_tasks/instance_healthchecks.py new file mode 100644 index 0000000000..41e83c71aa --- /dev/null +++ b/src/dstack/_internal/server/background/scheduled_tasks/instance_healthchecks.py @@ -0,0 +1,20 @@ +from datetime import timedelta + +from sqlalchemy import delete + +from dstack._internal.server import settings +from dstack._internal.server.db import get_session_ctx +from dstack._internal.server.models import InstanceHealthCheckModel +from dstack._internal.server.utils import sentry_utils +from dstack._internal.utils.common import get_current_datetime + + +@sentry_utils.instrument_scheduled_task +async def delete_instance_healthchecks(): + now = get_current_datetime() + cutoff = now - timedelta(seconds=settings.SERVER_INSTANCE_HEALTH_TTL_SECONDS) + async with get_session_ctx() as session: + await session.execute( + delete(InstanceHealthCheckModel).where(InstanceHealthCheckModel.collected_at < cutoff) + ) + await session.commit() diff --git 
a/src/dstack/_internal/server/background/scheduled_tasks/instances.py b/src/dstack/_internal/server/background/scheduled_tasks/instances.py index e5ecba5278..a14217c726 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/instances.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/instances.py @@ -9,7 +9,7 @@ from paramiko.pkey import PKey from paramiko.ssh_exception import PasswordRequiredException from pydantic import ValidationError -from sqlalchemy import and_, delete, func, not_, select +from sqlalchemy import and_, func, not_, select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import joinedload @@ -152,17 +152,6 @@ async def process_instances(batch_size: int = 1): await asyncio.gather(*tasks) -@sentry_utils.instrument_scheduled_task -async def delete_instance_health_checks(): - now = get_current_datetime() - cutoff = now - timedelta(seconds=server_settings.SERVER_INSTANCE_HEALTH_TTL_SECONDS) - async with get_session_ctx() as session: - await session.execute( - delete(InstanceHealthCheckModel).where(InstanceHealthCheckModel.collected_at < cutoff) - ) - await session.commit() - - @sentry_utils.instrument_scheduled_task async def _process_next_instance(): lock, lockset = get_locker(get_db().dialect_name).get_lockset(InstanceModel.__tablename__) diff --git a/src/dstack/_internal/server/services/fleets.py b/src/dstack/_internal/server/services/fleets.py index c0ec21aeaa..00ddb0e2ed 100644 --- a/src/dstack/_internal/server/services/fleets.py +++ b/src/dstack/_internal/server/services/fleets.py @@ -51,7 +51,7 @@ from dstack._internal.core.models.users import GlobalRole from dstack._internal.core.services import validate_dstack_resource_name from dstack._internal.core.services.diff import ModelDiff, copy_model, diff_models -from dstack._internal.server.db import get_db, is_db_postgres, is_db_sqlite +from dstack._internal.server.db import get_db, is_db_postgres, is_db_sqlite, sqlite_commit from 
dstack._internal.server.models import ( FleetModel, InstanceModel, @@ -686,9 +686,7 @@ async def delete_fleets( .order_by(InstanceModel.id) ) instances_ids = list(res.scalars().unique().all()) - if is_db_sqlite(): - # Start new transaction to see committed changes after lock - await session.commit() + await sqlite_commit() async with ( get_locker(get_db().dialect_name).lock_ctx(FleetModel.__tablename__, fleets_ids), get_locker(get_db().dialect_name).lock_ctx(InstanceModel.__tablename__, instances_ids), diff --git a/src/dstack/_internal/server/services/ssh_fleets/__init__.py b/src/dstack/_internal/server/services/ssh_fleets/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/tests/_internal/server/background/scheduled_tasks/test_instance_healthchecks.py b/src/tests/_internal/server/background/scheduled_tasks/test_instance_healthchecks.py new file mode 100644 index 0000000000..06ea5ab5ac --- /dev/null +++ b/src/tests/_internal/server/background/scheduled_tasks/test_instance_healthchecks.py @@ -0,0 +1,49 @@ +from datetime import timedelta + +import pytest +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from dstack._internal.server.background.scheduled_tasks.instance_healthchecks import ( + delete_instance_healthchecks, +) +from dstack._internal.server.models import InstanceHealthCheckModel, InstanceStatus +from dstack._internal.server.testing.common import ( + create_instance, + create_instance_health_check, + create_project, +) +from dstack._internal.utils.common import get_current_datetime + + +@pytest.mark.asyncio +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) +@pytest.mark.usefixtures("test_db", "image_config_mock") +class TestDeleteInstanceHealthChecks: + async def test_deletes_instance_health_checks( + self, monkeypatch: pytest.MonkeyPatch, session: AsyncSession + ): + project = await create_project(session=session) + instance = await create_instance( + session=session, 
project=project, status=InstanceStatus.IDLE + ) + # 30 minutes + monkeypatch.setattr( + "dstack._internal.server.settings.SERVER_INSTANCE_HEALTH_TTL_SECONDS", 1800 + ) + now = get_current_datetime() + # old check + await create_instance_health_check( + session=session, instance=instance, collected_at=now - timedelta(minutes=40) + ) + # recent check + check = await create_instance_health_check( + session=session, instance=instance, collected_at=now - timedelta(minutes=20) + ) + + await delete_instance_healthchecks() + + res = await session.execute(select(InstanceHealthCheckModel)) + all_checks = res.scalars().all() + assert len(all_checks) == 1 + assert all_checks[0] == check diff --git a/src/tests/_internal/server/background/scheduled_tasks/test_instances.py b/src/tests/_internal/server/background/scheduled_tasks/test_instances.py index 1b9789953e..c9bcc9e7b0 100644 --- a/src/tests/_internal/server/background/scheduled_tasks/test_instances.py +++ b/src/tests/_internal/server/background/scheduled_tasks/test_instances.py @@ -40,7 +40,6 @@ JobStatus, ) from dstack._internal.server.background.scheduled_tasks.instances import ( - delete_instance_health_checks, process_instances, ) from dstack._internal.server.models import ( @@ -65,7 +64,6 @@ ComputeMockSpec, create_fleet, create_instance, - create_instance_health_check, create_job, create_project, create_repo, @@ -1207,39 +1205,6 @@ async def test_adds_ssh_instance( assert instance.busy_blocks == 0 -@pytest.mark.asyncio -@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) -@pytest.mark.usefixtures("test_db", "image_config_mock") -class TestDeleteInstanceHealthChecks: - async def test_deletes_instance_health_checks( - self, monkeypatch: pytest.MonkeyPatch, session: AsyncSession - ): - project = await create_project(session=session) - instance = await create_instance( - session=session, project=project, status=InstanceStatus.IDLE - ) - # 30 minutes - monkeypatch.setattr( - 
"dstack._internal.server.settings.SERVER_INSTANCE_HEALTH_TTL_SECONDS", 1800 - ) - now = get_current_datetime() - # old check - await create_instance_health_check( - session=session, instance=instance, collected_at=now - dt.timedelta(minutes=40) - ) - # recent check - check = await create_instance_health_check( - session=session, instance=instance, collected_at=now - dt.timedelta(minutes=20) - ) - - await delete_instance_health_checks() - - res = await session.execute(select(InstanceHealthCheckModel)) - all_checks = res.scalars().all() - assert len(all_checks) == 1 - assert all_checks[0] == check - - @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) @pytest.mark.usefixtures("test_db", "instance", "ssh_tunnel_mock", "shim_client_mock") From 85f8866bf2aa2763e9b4afc53fa61b68b15e017c Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Mon, 2 Mar 2026 12:22:30 +0500 Subject: [PATCH 02/51] Move utils/provisioning.py to ssh_fleets/provisioning.py --- .../_internal/server/background/scheduled_tasks/instances.py | 4 ++-- .../server/{utils => services/ssh_fleets}/provisioning.py | 0 2 files changed, 2 insertions(+), 2 deletions(-) rename src/dstack/_internal/server/{utils => services/ssh_fleets}/provisioning.py (100%) diff --git a/src/dstack/_internal/server/background/scheduled_tasks/instances.py b/src/dstack/_internal/server/background/scheduled_tasks/instances.py index a14217c726..636d05d2f2 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/instances.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/instances.py @@ -108,8 +108,7 @@ ) from dstack._internal.server.services.runner import client as runner_client from dstack._internal.server.services.runner.ssh import runner_ssh_tunnel -from dstack._internal.server.utils import sentry_utils -from dstack._internal.server.utils.provisioning import ( +from dstack._internal.server.services.ssh_fleets.provisioning import ( detect_cpu_arch, get_host_info, 
get_paramiko_connection, @@ -121,6 +120,7 @@ run_shim_as_systemd_service, upload_envs, ) +from dstack._internal.server.utils import sentry_utils from dstack._internal.utils.common import ( get_current_datetime, get_or_error, diff --git a/src/dstack/_internal/server/utils/provisioning.py b/src/dstack/_internal/server/services/ssh_fleets/provisioning.py similarity index 100% rename from src/dstack/_internal/server/utils/provisioning.py rename to src/dstack/_internal/server/services/ssh_fleets/provisioning.py From 0c74816370de5e2815c50b9e49f39f02a2cbfe4a Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Mon, 2 Mar 2026 12:32:34 +0500 Subject: [PATCH 03/51] Use SSHProvisioningError for ssh instances errors --- src/dstack/_internal/core/errors.py | 4 ++ .../background/scheduled_tasks/instances.py | 13 ++--- .../services/ssh_fleets/provisioning.py | 58 +++++++++---------- .../scheduled_tasks/test_instances.py | 23 ++++++++ 4 files changed, 62 insertions(+), 36 deletions(-) diff --git a/src/dstack/_internal/core/errors.py b/src/dstack/_internal/core/errors.py index 0bfd5f6f33..0d4262fe9b 100644 --- a/src/dstack/_internal/core/errors.py +++ b/src/dstack/_internal/core/errors.py @@ -136,6 +136,10 @@ class ConfigurationError(DstackError): pass +class SSHProvisioningError(DstackError): + pass + + class SSHError(DstackError): pass diff --git a/src/dstack/_internal/server/background/scheduled_tasks/instances.py b/src/dstack/_internal/server/background/scheduled_tasks/instances.py index 636d05d2f2..6b671d18af 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/instances.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/instances.py @@ -33,12 +33,11 @@ BACKENDS_WITH_PLACEMENT_GROUPS_SUPPORT, ) from dstack._internal.core.consts import DSTACK_SHIM_HTTP_PORT - -# FIXME: ProvisioningError is a subclass of ComputeError and should not be used outside of Compute from dstack._internal.core.errors import ( BackendError, NotYetTerminated, 
ProvisioningError, + SSHProvisioningError, ) from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.fleets import InstanceGroupPlacement @@ -340,10 +339,10 @@ async def _add_remote(session: AsyncSession, instance: InstanceModel) -> None: result = await asyncio.wait_for(future, timeout=deploy_timeout) health, host_info, arch = result except (asyncio.TimeoutError, TimeoutError) as e: - raise ProvisioningError(f"Deploy timeout: {e}") from e + raise SSHProvisioningError(f"Deploy timeout: {e}") from e except Exception as e: - raise ProvisioningError(f"Deploy instance raised an error: {e}") from e - except ProvisioningError as e: + raise SSHProvisioningError(f"Deploy instance raised an error: {e}") from e + except SSHProvisioningError as e: logger.warning( "Provisioning instance %s could not be completed because of the error: %s", instance.name, @@ -462,7 +461,7 @@ def _deploy_instance( try: fleet_configuration_envs = remote_details.env.as_dict() except ValueError as e: - raise ProvisioningError(f"Invalid Env: {e}") from e + raise SSHProvisioningError(f"Invalid Env: {e}") from e shim_envs.update(fleet_configuration_envs) dstack_working_dir = get_dstack_working_dir() dstack_shim_binary_path = get_dstack_shim_binary_path() @@ -490,7 +489,7 @@ def _deploy_instance( try: healthcheck = HealthcheckResponse.__response__.parse_raw(healthcheck_out) except ValueError as e: - raise ProvisioningError(f"Cannot parse HealthcheckResponse: {e}") from e + raise SSHProvisioningError(f"Cannot parse HealthcheckResponse: {e}") from e instance_check = runner_client.healthcheck_response_to_instance_check(healthcheck) return instance_check, host_info, arch diff --git a/src/dstack/_internal/server/services/ssh_fleets/provisioning.py b/src/dstack/_internal/server/services/ssh_fleets/provisioning.py index fcbe3bf086..3a7c21e6dd 100644 --- a/src/dstack/_internal/server/services/ssh_fleets/provisioning.py +++ 
b/src/dstack/_internal/server/services/ssh_fleets/provisioning.py @@ -14,9 +14,7 @@ normalize_arch, ) from dstack._internal.core.consts import DSTACK_SHIM_HTTP_PORT - -# FIXME: ProvisioningError is a subclass of ComputeError and should not be used outside of Compute -from dstack._internal.core.errors import ProvisioningError +from dstack._internal.core.errors import SSHProvisioningError from dstack._internal.core.models.instances import ( Disk, Gpu, @@ -46,15 +44,15 @@ def detect_cpu_arch(client: paramiko.SSHClient) -> GoArchType: try: _, stdout, stderr = client.exec_command(cmd, timeout=20) except (paramiko.SSHException, OSError) as e: - raise ProvisioningError(f"detect_cpu_arch: {e}") from e + raise SSHProvisioningError(f"detect_cpu_arch: {e}") from e out = stdout.read().strip().decode() err = stderr.read().strip().decode() if err: - raise ProvisioningError(f"detect_cpu_arch: {cmd} failed, stdout: {out}, stderr: {err}") + raise SSHProvisioningError(f"detect_cpu_arch: {cmd} failed, stdout: {out}, stderr: {err}") try: return normalize_arch(out) except ValueError as e: - raise ProvisioningError(f"detect_cpu_arch: failed to normalize arch: {e}") from e + raise SSHProvisioningError(f"detect_cpu_arch: failed to normalize arch: {e}") from e def sftp_upload(client: paramiko.SSHClient, path: str, body: str) -> None: @@ -66,7 +64,7 @@ def sftp_upload(client: paramiko.SSHClient, path: str, body: str) -> None: sftp.putfo(io.BytesIO(body.encode()), path) sftp.close() except (paramiko.SSHException, OSError) as e: - raise ProvisioningError(f"sft_upload failed: {e}") from e + raise SSHProvisioningError(f"sft_upload failed: {e}") from e def upload_envs(client: paramiko.SSHClient, working_dir: str, envs: Dict[str, str]) -> None: @@ -80,11 +78,11 @@ def upload_envs(client: paramiko.SSHClient, working_dir: str, envs: Dict[str, st out = stdout.read().strip().decode() err = stderr.read().strip().decode() if out or err: - raise ProvisioningError( + raise SSHProvisioningError( f"The 
command 'upload_envs' didn't work. stdout: {out}, stderr: {err}" ) except (paramiko.SSHException, OSError) as e: - raise ProvisioningError(f"upload_envs failed: {e}") from e + raise SSHProvisioningError(f"upload_envs failed: {e}") from e def run_pre_start_commands( @@ -98,11 +96,11 @@ def run_pre_start_commands( out = stdout.read().strip().decode() err = stderr.read().strip().decode() if out or err: - raise ProvisioningError( + raise SSHProvisioningError( f"The command 'authorized_keys' didn't work. stdout: {out}, stderr: {err}" ) except (paramiko.SSHException, OSError) as e: - raise ProvisioningError(f"upload authorized_keys failed: {e}") from e + raise SSHProvisioningError(f"upload authorized_keys failed: {e}") from e script = " && ".join(shim_pre_start_commands) try: @@ -110,11 +108,11 @@ def run_pre_start_commands( out = stdout.read().strip().decode() err = stderr.read().strip().decode() if out or err: - raise ProvisioningError( + raise SSHProvisioningError( f"The command 'run_pre_start_commands' didn't work. stdout: {out}, stderr: {err}" ) except (paramiko.SSHException, OSError) as e: - raise ProvisioningError(f"run_pre-start_commands failed: {e}") from e + raise SSHProvisioningError(f"run_pre-start_commands failed: {e}") from e def run_shim_as_systemd_service( @@ -158,11 +156,11 @@ def run_shim_as_systemd_service( out = stdout.read().strip().decode() err = stderr.read().strip().decode() if out or err: - raise ProvisioningError( + raise SSHProvisioningError( f"The command 'run_shim_as_systemd_service' didn't work. 
stdout: {out}, stderr: {err}" ) except (paramiko.SSHException, OSError) as e: - raise ProvisioningError(f"run_shim_as_systemd failed: {e}") from e + raise SSHProvisioningError(f"run_shim_as_systemd failed: {e}") from e def check_dstack_shim_service(client: paramiko.SSHClient): @@ -170,12 +168,12 @@ def check_dstack_shim_service(client: paramiko.SSHClient): _, stdout, _ = client.exec_command("sudo systemctl status dstack-shim.service", timeout=10) status = stdout.read() except (paramiko.SSHException, OSError) as e: - raise ProvisioningError(f"Checking dstack-shim.service status failed: {e}") from e + raise SSHProvisioningError(f"Checking dstack-shim.service status failed: {e}") from e for raw_line in status.splitlines(): line = raw_line.decode() if line.strip().startswith("Active: failed"): - raise ProvisioningError(f"The dstack-shim service doesn't start: {line.strip()}") + raise SSHProvisioningError(f"The dstack-shim service doesn't start: {line.strip()}") def remove_host_info_if_exists(client: paramiko.SSHClient, working_dir: str) -> None: @@ -188,7 +186,7 @@ def remove_host_info_if_exists(client: paramiko.SSHClient, working_dir: str) -> if err: logger.debug(f"{HOST_INFO_FILE} hasn't been removed: %s", err) except (paramiko.SSHException, OSError) as e: - raise ProvisioningError(f"remove_host_info_if_exists failed: {e}") + raise SSHProvisioningError(f"remove_host_info_if_exists failed: {e}") def remove_dstack_runner_if_exists(client: paramiko.SSHClient, path: str) -> None: @@ -198,7 +196,7 @@ def remove_dstack_runner_if_exists(client: paramiko.SSHClient, path: str) -> Non if err: logger.debug(f"{path} hasn't been removed: %s", err) except (paramiko.SSHException, OSError) as e: - raise ProvisioningError(f"remove_dstack_runner_if_exists failed: {e}") + raise SSHProvisioningError(f"remove_dstack_runner_if_exists failed: {e}") def get_host_info(client: paramiko.SSHClient, working_dir: str) -> Dict[str, Any]: @@ -224,11 +222,11 @@ def get_host_info(client: 
paramiko.SSHClient, working_dir: str) -> Dict[str, Any return host_info except ValueError: # JSON parse error check_dstack_shim_service(client) - raise ProvisioningError("Cannot parse host_info") + raise SSHProvisioningError("Cannot parse host_info") time.sleep(iter_delay) else: check_dstack_shim_service(client) - raise ProvisioningError("Cannot get host_info") + raise SSHProvisioningError("Cannot get host_info") def get_shim_healthcheck(client: paramiko.SSHClient) -> str: @@ -240,7 +238,7 @@ def get_shim_healthcheck(client: paramiko.SSHClient) -> str: return healthcheck logger.debug("healthcheck is empty. retry") time.sleep(iter_delay) - raise ProvisioningError("Cannot get HealthcheckResponse") + raise SSHProvisioningError("Cannot get HealthcheckResponse") def _get_shim_healthcheck(client: paramiko.SSHClient) -> Optional[str]: @@ -251,9 +249,11 @@ def _get_shim_healthcheck(client: paramiko.SSHClient) -> Optional[str]: out = stdout.read().strip().decode() err = stderr.read().strip().decode() except (paramiko.SSHException, OSError) as e: - raise ProvisioningError(f"get_shim_healthcheck failed: {e}") from e + raise SSHProvisioningError(f"get_shim_healthcheck failed: {e}") from e if err: - raise ProvisioningError(f"get_shim_healthcheck didn't work. stdout: {out}, stderr: {err}") + raise SSHProvisioningError( + f"get_shim_healthcheck didn't work. 
stdout: {out}, stderr: {err}" + ) if not out: return None return out @@ -306,7 +306,7 @@ def get_paramiko_connection( ) -> Generator[paramiko.SSHClient, None, None]: if proxy is not None: if proxy_pkeys is None: - raise ProvisioningError("Missing proxy private keys") + raise SSHProvisioningError("Missing proxy private keys") proxy_ctx = get_paramiko_connection( proxy.username, proxy.hostname, proxy.port, proxy_pkeys ) @@ -321,7 +321,7 @@ def get_paramiko_connection( try: proxy_channel = transport.open_channel("direct-tcpip", (host, port), ("", 0)) except (paramiko.SSHException, OSError) as e: - raise ProvisioningError(f"Proxy channel failed: {e}") from e + raise SSHProvisioningError(f"Proxy channel failed: {e}") from e client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) for pkey in pkeys: logger.debug("Try to connect to %s with key %s", conn_url, pkey.fingerprint) @@ -333,7 +333,7 @@ def get_paramiko_connection( f'Authentication failed to connect to "{conn_url}" and {pkey.fingerprint}' ) keys_fp = ", ".join(f"{pk.fingerprint!r}" for pk in pkeys) - raise ProvisioningError( + raise SSHProvisioningError( f"SSH connection to the {conn_url} with keys [{keys_fp}] was unsuccessful" ) @@ -347,7 +347,7 @@ def _paramiko_connect( channel: Optional[paramiko.Channel] = None, ) -> bool: """ - Returns `True` if connected, `False` if auth failed, and raises `ProvisioningError` + Returns `True` if connected, `False` if auth failed, and raises `SSHProvisioningError` on other errors. 
""" try: @@ -365,4 +365,4 @@ def _paramiko_connect( except paramiko.AuthenticationException: return False except (paramiko.SSHException, OSError) as e: - raise ProvisioningError(f"Connect failed: {e}") from e + raise SSHProvisioningError(f"Connect failed: {e}") from e diff --git a/src/tests/_internal/server/background/scheduled_tasks/test_instances.py b/src/tests/_internal/server/background/scheduled_tasks/test_instances.py index c9bcc9e7b0..448184f5a7 100644 --- a/src/tests/_internal/server/background/scheduled_tasks/test_instances.py +++ b/src/tests/_internal/server/background/scheduled_tasks/test_instances.py @@ -19,6 +19,7 @@ NoCapacityError, NotYetTerminated, ProvisioningError, + SSHProvisioningError, ) from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.fleets import InstanceGroupPlacement @@ -1204,6 +1205,28 @@ async def test_adds_ssh_instance( assert instance.total_blocks == expected_blocks assert instance.busy_blocks == 0 + async def test_retries_ssh_instance_if_provisioning_fails( + self, + session: AsyncSession, + deploy_instance_mock: Mock, + ): + deploy_instance_mock.side_effect = SSHProvisioningError("Expected") + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.PENDING, + created_at=get_current_datetime(), + remote_connection_info=get_remote_connection_info(), + ) + await session.commit() + + await process_instances() + + await session.refresh(instance) + assert instance.status == InstanceStatus.PENDING + assert instance.termination_reason is None + @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) From 81ac9a2ce01f4c90a73953e954b6280f3900d6ab Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Mon, 2 Mar 2026 12:57:27 +0500 Subject: [PATCH 04/51] Fix _add_remote() nested try-excepts --- .../background/scheduled_tasks/instances.py | 68 +++++++++++-------- 
.../scheduled_tasks/test_instances.py | 49 +++++++++++++ 2 files changed, 87 insertions(+), 30 deletions(-) diff --git a/src/dstack/_internal/server/background/scheduled_tasks/instances.py b/src/dstack/_internal/server/background/scheduled_tasks/instances.py index 6b671d18af..39dc76aa27 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/instances.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/instances.py @@ -312,44 +312,52 @@ async def _add_remote(session: AsyncSession, instance: InstanceModel) -> None: switch_instance_status(session, instance, InstanceStatus.TERMINATED) return + remote_details = get_instance_remote_connection_info(instance) + assert remote_details is not None + try: - remote_details = get_instance_remote_connection_info(instance) - assert remote_details is not None - # Prepare connection key - try: - pkeys = _ssh_keys_to_pkeys(remote_details.ssh_keys) - if remote_details.ssh_proxy_keys is not None: - ssh_proxy_pkeys = _ssh_keys_to_pkeys(remote_details.ssh_proxy_keys) - else: - ssh_proxy_pkeys = None - except (ValueError, PasswordRequiredException): - instance.termination_reason = InstanceTerminationReason.ERROR - instance.termination_reason_message = "Unsupported private SSH key type" - switch_instance_status(session, instance, InstanceStatus.TERMINATED) - return + pkeys = _ssh_keys_to_pkeys(remote_details.ssh_keys) + if remote_details.ssh_proxy_keys is not None: + ssh_proxy_pkeys = _ssh_keys_to_pkeys(remote_details.ssh_proxy_keys) + else: + ssh_proxy_pkeys = None + except (ValueError, PasswordRequiredException): + instance.termination_reason = InstanceTerminationReason.ERROR + instance.termination_reason_message = "Unsupported private SSH key type" + switch_instance_status(session, instance, InstanceStatus.TERMINATED) + return - authorized_keys = [pk.public.strip() for pk in remote_details.ssh_keys] - authorized_keys.append(instance.project.ssh_public_key.strip()) + authorized_keys = [pk.public.strip() for pk in 
remote_details.ssh_keys] + authorized_keys.append(instance.project.ssh_public_key.strip()) - try: - future = run_async( - _deploy_instance, remote_details, pkeys, ssh_proxy_pkeys, authorized_keys - ) - deploy_timeout = 20 * 60 # 20 minutes - result = await asyncio.wait_for(future, timeout=deploy_timeout) - health, host_info, arch = result - except (asyncio.TimeoutError, TimeoutError) as e: - raise SSHProvisioningError(f"Deploy timeout: {e}") from e - except Exception as e: - raise SSHProvisioningError(f"Deploy instance raised an error: {e}") from e + try: + future = run_async( + _deploy_instance, remote_details, pkeys, ssh_proxy_pkeys, authorized_keys + ) + deploy_timeout = 20 * 60 # 20 minutes + health, host_info, arch = await asyncio.wait_for(future, timeout=deploy_timeout) + except (asyncio.TimeoutError, TimeoutError) as e: + logger.warning( + "%s: deploy timeout when adding SSH instance: %s", + fmt(instance), + repr(e), + ) + # Stays in PENDING, may retry later + return except SSHProvisioningError as e: logger.warning( - "Provisioning instance %s could not be completed because of the error: %s", - instance.name, - e, + "%s: provisioning error when adding SSH instance: %s", + fmt(instance), + repr(e), ) # Stays in PENDING, may retry later return + except Exception: + logger.exception("%s: unexpected error when adding SSH instance", fmt(instance)) + instance.termination_reason = InstanceTerminationReason.ERROR + instance.termination_reason_message = "Unexpected error when adding SSH instance" + switch_instance_status(session, instance, InstanceStatus.TERMINATED) + return instance_type = host_info_to_instance_type(host_info, arch) instance_network = None diff --git a/src/tests/_internal/server/background/scheduled_tasks/test_instances.py b/src/tests/_internal/server/background/scheduled_tasks/test_instances.py index 448184f5a7..835b9842c3 100644 --- a/src/tests/_internal/server/background/scheduled_tasks/test_instances.py +++ 
b/src/tests/_internal/server/background/scheduled_tasks/test_instances.py @@ -1227,6 +1227,55 @@ async def test_retries_ssh_instance_if_provisioning_fails( assert instance.status == InstanceStatus.PENDING assert instance.termination_reason is None + async def test_terminates_ssh_instance_if_deploy_fails_unexpectedly( + self, + session: AsyncSession, + deploy_instance_mock: Mock, + ): + deploy_instance_mock.side_effect = RuntimeError("Unexpected") + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.PENDING, + created_at=get_current_datetime(), + remote_connection_info=get_remote_connection_info(), + ) + await session.commit() + + await process_instances() + + await session.refresh(instance) + assert instance.status == InstanceStatus.TERMINATED + assert instance.termination_reason == InstanceTerminationReason.ERROR + assert instance.termination_reason_message == "Unexpected error when adding SSH instance" + + async def test_terminates_ssh_instance_if_key_is_invalid( + self, + session: AsyncSession, + monkeypatch: pytest.MonkeyPatch, + ): + monkeypatch.setattr( + "dstack._internal.server.background.scheduled_tasks.instances._ssh_keys_to_pkeys", + Mock(side_effect=ValueError("Bad key")), + ) + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.PENDING, + created_at=get_current_datetime(), + remote_connection_info=get_remote_connection_info(), + ) + await session.commit() + + await process_instances() + + await session.refresh(instance) + assert instance.status == InstanceStatus.TERMINATED + assert instance.termination_reason == InstanceTerminationReason.ERROR + assert instance.termination_reason_message == "Unsupported private SSH key type" + @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) From 18921ed39729d1a0fa616fd8d2fdc249908bfc61 
Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Mon, 2 Mar 2026 13:24:05 +0500 Subject: [PATCH 05/51] Refactor _resolve_ssh_instance_network --- AGENTS.md | 1 + .../background/scheduled_tasks/instances.py | 64 ++++++++++-------- .../scheduled_tasks/test_instances.py | 65 +++++++++++++++++++ 3 files changed, 104 insertions(+), 26 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index eb348b291b..336b97bb5b 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -18,6 +18,7 @@ - Python targets 3.9+ with 4-space indentation and max line length of 99 (see `ruff.toml`; `E501` is ignored but keep lines readable). - Imports are sorted via Ruff’s isort settings (`dstack` treated as first-party). - Keep primary/public functions before local helper functions in a module section. +- Keep private classes, exceptions, and similar implementation-specific types close to the private functions that use them unless they are shared more broadly in the module. - Prefer pydantic-style models in `core/models`. - Tests use `test_*.py` modules and `test_*` functions; fixtures live near usage. 
diff --git a/src/dstack/_internal/server/background/scheduled_tasks/instances.py b/src/dstack/_internal/server/background/scheduled_tasks/instances.py index 39dc76aa27..a8b0e1125e 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/instances.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/instances.py @@ -360,36 +360,13 @@ async def _add_remote(session: AsyncSession, instance: InstanceModel) -> None: return instance_type = host_info_to_instance_type(host_info, arch) - instance_network = None - internal_ip = None try: - default_jpd = JobProvisioningData.__response__.parse_raw(instance.job_provisioning_data) - instance_network = default_jpd.instance_network - internal_ip = default_jpd.internal_ip - except ValidationError: - pass - - host_network_addresses = host_info.get("addresses", []) - if internal_ip is None: - internal_ip = get_ip_from_network( - network=instance_network, - addresses=host_network_addresses, - ) - if instance_network is not None and internal_ip is None: + instance_network, internal_ip = _resolve_ssh_instance_network(instance, host_info) + except _SSHInstanceNetworkResolutionError as e: instance.termination_reason = InstanceTerminationReason.ERROR - instance.termination_reason_message = ( - "Failed to locate internal IP address on the given network" - ) + instance.termination_reason_message = str(e) switch_instance_status(session, instance, InstanceStatus.TERMINATED) return - if internal_ip is not None: - if not is_ip_among_addresses(ip_address=internal_ip, addresses=host_network_addresses): - instance.termination_reason = InstanceTerminationReason.ERROR - instance.termination_reason_message = ( - "Specified internal IP not found among instance interfaces" - ) - switch_instance_status(session, instance, InstanceStatus.TERMINATED) - return divisible, blocks = is_divisible_into_blocks( cpu_count=instance_type.resources.cpus, @@ -440,6 +417,41 @@ async def _add_remote(session: AsyncSession, instance: InstanceModel) -> 
None: instance.started_at = get_current_datetime() +class _SSHInstanceNetworkResolutionError(Exception): + pass + + +def _resolve_ssh_instance_network( + instance: InstanceModel, host_info: dict[str, Any] +) -> tuple[Optional[str], Optional[str]]: + instance_network = None + internal_ip = None + try: + default_jpd = JobProvisioningData.__response__.parse_raw(instance.job_provisioning_data) + instance_network = default_jpd.instance_network + internal_ip = default_jpd.internal_ip + except ValidationError: + pass + + host_network_addresses = host_info.get("addresses", []) + if internal_ip is None: + internal_ip = get_ip_from_network( + network=instance_network, + addresses=host_network_addresses, + ) + if instance_network is not None and internal_ip is None: + raise _SSHInstanceNetworkResolutionError( + "Failed to locate internal IP address on the given network" + ) + if internal_ip is not None and not is_ip_among_addresses( + ip_address=internal_ip, addresses=host_network_addresses + ): + raise _SSHInstanceNetworkResolutionError( + "Specified internal IP not found among instance interfaces" + ) + return instance_network, internal_ip + + def _deploy_instance( remote_details: RemoteConnectionInfo, pkeys: list[PKey], diff --git a/src/tests/_internal/server/background/scheduled_tasks/test_instances.py b/src/tests/_internal/server/background/scheduled_tasks/test_instances.py index 835b9842c3..88e4acc949 100644 --- a/src/tests/_internal/server/background/scheduled_tasks/test_instances.py +++ b/src/tests/_internal/server/background/scheduled_tasks/test_instances.py @@ -1276,6 +1276,71 @@ async def test_terminates_ssh_instance_if_key_is_invalid( assert instance.termination_reason == InstanceTerminationReason.ERROR assert instance.termination_reason_message == "Unsupported private SSH key type" + async def test_terminates_ssh_instance_if_internal_ip_cannot_be_resolved_from_network( + self, + session: AsyncSession, + host_info: dict, + ): + host_info["addresses"] = 
["192.168.100.100/24"] + project = await create_project(session=session) + job_provisioning_data = get_job_provisioning_data( + dockerized=True, + backend=BackendType.REMOTE, + internal_ip=None, + ) + job_provisioning_data.instance_network = "10.0.0.0/24" + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.PENDING, + created_at=get_current_datetime(), + remote_connection_info=get_remote_connection_info(), + job_provisioning_data=job_provisioning_data, + ) + await session.commit() + + await process_instances() + + await session.refresh(instance) + assert instance.status == InstanceStatus.TERMINATED + assert instance.termination_reason == InstanceTerminationReason.ERROR + assert ( + instance.termination_reason_message + == "Failed to locate internal IP address on the given network" + ) + + async def test_terminates_ssh_instance_if_internal_ip_is_not_in_host_interfaces( + self, + session: AsyncSession, + host_info: dict, + ): + host_info["addresses"] = ["192.168.100.100/24"] + project = await create_project(session=session) + job_provisioning_data = get_job_provisioning_data( + dockerized=True, + backend=BackendType.REMOTE, + internal_ip="10.0.0.20", + ) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.PENDING, + created_at=get_current_datetime(), + remote_connection_info=get_remote_connection_info(), + job_provisioning_data=job_provisioning_data, + ) + await session.commit() + + await process_instances() + + await session.refresh(instance) + assert instance.status == InstanceStatus.TERMINATED + assert instance.termination_reason == InstanceTerminationReason.ERROR + assert ( + instance.termination_reason_message + == "Specified internal IP not found among instance interfaces" + ) + @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) From 20d1e08ce3160c3c9eb35d9251049d42f350d73c Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: 
Mon, 2 Mar 2026 14:30:56 +0500 Subject: [PATCH 06/51] Refactor _process_instance() into thin dispatcher --- .../background/scheduled_tasks/instances.py | 123 ++++++++++-------- 1 file changed, 71 insertions(+), 52 deletions(-) diff --git a/src/dstack/_internal/server/background/scheduled_tasks/instances.py b/src/dstack/_internal/server/background/scheduled_tasks/instances.py index a8b0e1125e..3fc3174565 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/instances.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/instances.py @@ -199,63 +199,81 @@ async def _process_next_instance(): async def _process_instance(session: AsyncSession, instance: InstanceModel): logger.debug("%s: processing instance, status: %s", fmt(instance), instance.status.upper()) - # Refetch to load related attributes. - # Load related attributes only for statuses that always need them. - if instance.status in ( - InstanceStatus.PENDING, - InstanceStatus.TERMINATING, - ): - res = await session.execute( - select(InstanceModel) - .where(InstanceModel.id == instance.id) - .options(joinedload(InstanceModel.project).joinedload(ProjectModel.backends)) - .options(joinedload(InstanceModel.jobs).load_only(JobModel.id, JobModel.status)) - .options( - joinedload(InstanceModel.fleet).joinedload( - FleetModel.instances.and_(InstanceModel.deleted == False) - ), - ) - .execution_options(populate_existing=True) - ) - instance = res.unique().scalar_one() + if instance.status == InstanceStatus.PENDING: + await _process_pending_instance(session, instance) + elif instance.status == InstanceStatus.PROVISIONING: + await _process_provisioning_instance(session, instance) elif instance.status == InstanceStatus.IDLE: - res = await session.execute( - select(InstanceModel) - .where(InstanceModel.id == instance.id) - .options(joinedload(InstanceModel.project)) - .options(joinedload(InstanceModel.jobs).load_only(JobModel.id, JobModel.status)) - .options( - 
joinedload(InstanceModel.fleet).joinedload( - FleetModel.instances.and_(InstanceModel.deleted == False) - ), + await _process_idle_instance(session, instance) + elif instance.status == InstanceStatus.BUSY: + await _process_busy_instance(session, instance) + elif instance.status == InstanceStatus.TERMINATING: + await _process_terminating_instance(session, instance) + + instance.last_processed_at = get_current_datetime() + await session.commit() + + +async def _process_pending_instance(session: AsyncSession, instance: InstanceModel): + instance = await _refetch_instance_for_pending_or_terminating(session, instance.id) + if is_ssh_instance(instance): + await _add_remote(session, instance) + else: + await _create_instance(session=session, instance=instance) + + +async def _process_provisioning_instance(session: AsyncSession, instance: InstanceModel): + await _check_instance(session, instance) + + +async def _process_idle_instance(session: AsyncSession, instance: InstanceModel): + instance = await _refetch_instance_for_idle(session, instance.id) + idle_duration_expired = _check_and_mark_terminating_if_idle_duration_expired(session, instance) + if not idle_duration_expired: + await _check_instance(session, instance) + + +async def _process_busy_instance(session: AsyncSession, instance: InstanceModel): + await _check_instance(session, instance) + + +async def _process_terminating_instance(session: AsyncSession, instance: InstanceModel): + instance = await _refetch_instance_for_pending_or_terminating(session, instance.id) + await _terminate(session, instance) + + +async def _refetch_instance_for_pending_or_terminating( + session: AsyncSession, instance_id +) -> InstanceModel: + res = await session.execute( + select(InstanceModel) + .where(InstanceModel.id == instance_id) + .options(joinedload(InstanceModel.project).joinedload(ProjectModel.backends)) + .options(joinedload(InstanceModel.jobs).load_only(JobModel.id, JobModel.status)) + .options( + 
joinedload(InstanceModel.fleet).joinedload( + FleetModel.instances.and_(InstanceModel.deleted == False) ) - .execution_options(populate_existing=True) ) - instance = res.unique().scalar_one() + .execution_options(populate_existing=True) + ) + return res.unique().scalar_one() - if instance.status == InstanceStatus.PENDING: - if is_ssh_instance(instance): - await _add_remote(session, instance) - else: - await _create_instance( - session=session, - instance=instance, + +async def _refetch_instance_for_idle(session: AsyncSession, instance_id) -> InstanceModel: + res = await session.execute( + select(InstanceModel) + .where(InstanceModel.id == instance_id) + .options(joinedload(InstanceModel.project)) + .options(joinedload(InstanceModel.jobs).load_only(JobModel.id, JobModel.status)) + .options( + joinedload(InstanceModel.fleet).joinedload( + FleetModel.instances.and_(InstanceModel.deleted == False) ) - elif instance.status in ( - InstanceStatus.PROVISIONING, - InstanceStatus.IDLE, - InstanceStatus.BUSY, - ): - idle_duration_expired = _check_and_mark_terminating_if_idle_duration_expired( - session, instance ) - if not idle_duration_expired: - await _check_instance(session, instance) - elif instance.status == InstanceStatus.TERMINATING: - await _terminate(session, instance) - - instance.last_processed_at = get_current_datetime() - await session.commit() + .execution_options(populate_existing=True) + ) + return res.unique().scalar_one() def _check_and_mark_terminating_if_idle_duration_expired( @@ -1142,7 +1160,8 @@ def _get_termination_deadline(instance: InstanceModel) -> datetime.datetime: def _need_to_wait_fleet_provisioning( - instance: InstanceModel, master_instance: InstanceModel + instance: InstanceModel, + master_instance: InstanceModel, ) -> bool: # Cluster cloud instances should wait for the first fleet instance to be provisioned # so that they are provisioned in the same backend/region From c16a502ecd6e6482e0279bcb9013545076287eb8 Mon Sep 17 00:00:00 2001 From: 
Victor Skvortsov Date: Mon, 2 Mar 2026 14:42:31 +0500 Subject: [PATCH 07/51] Refactor instance check code --- .../background/scheduled_tasks/instances.py | 119 ++++++++++++------ 1 file changed, 79 insertions(+), 40 deletions(-) diff --git a/src/dstack/_internal/server/background/scheduled_tasks/instances.py b/src/dstack/_internal/server/background/scheduled_tasks/instances.py index 3fc3174565..b3de9cb305 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/instances.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/instances.py @@ -733,50 +733,22 @@ async def _check_instance(session: AsyncSession, instance: InstanceModel) -> Non switch_instance_status(session, instance, InstanceStatus.BUSY) return - ssh_private_keys = get_instance_ssh_private_keys(instance) - - health_check_cutoff = get_current_datetime() - timedelta( - seconds=server_settings.SERVER_INSTANCE_HEALTH_MIN_COLLECT_INTERVAL_SECONDS - ) - res = await session.execute( - select(func.count(1)).where( - InstanceHealthCheckModel.instance_id == instance.id, - InstanceHealthCheckModel.collected_at > health_check_cutoff, - ) + check_instance_health = await _should_check_instance_health(session, instance) + instance_check = await _run_instance_check( + instance=instance, + job_provisioning_data=job_provisioning_data, + check_instance_health=check_instance_health, ) - check_instance_health = res.scalar_one() == 0 - - # May return False if fails to establish ssh connection - instance_check = await run_async( - _check_instance_inner, - ssh_private_keys, - job_provisioning_data, - None, + health_status = _get_health_status_for_instance_check( instance=instance, + instance_check=instance_check, check_instance_health=check_instance_health, ) - if instance_check is False: - instance_check = InstanceCheck(reachable=False, message="SSH or tunnel error") - - if instance_check.reachable and check_instance_health: - health_status = instance_check.get_health_status() - else: - # Keep previous 
health status - health_status = instance.health - - loglevel = logging.DEBUG - if not instance_check.reachable and instance.status.is_available(): - loglevel = logging.WARNING - elif check_instance_health and not health_status.is_healthy(): - loglevel = logging.WARNING - logger.log( - loglevel, - "Instance %s check: reachable=%s health_status=%s message=%r", - instance.name, - instance_check.reachable, - health_status.name, - instance_check.message, - extra={"instance_name": instance.name, "health_status": health_status}, + _log_instance_check_result( + instance=instance, + instance_check=instance_check, + health_status=health_status, + check_instance_health=check_instance_health, ) if instance_check.has_health_checks(): @@ -823,6 +795,73 @@ async def _check_instance(session: AsyncSession, instance: InstanceModel) -> Non switch_instance_status(session, instance, InstanceStatus.TERMINATING) +async def _should_check_instance_health(session: AsyncSession, instance: InstanceModel) -> bool: + health_check_cutoff = get_current_datetime() - timedelta( + seconds=server_settings.SERVER_INSTANCE_HEALTH_MIN_COLLECT_INTERVAL_SECONDS + ) + result = await session.execute( + select(func.count(1)).where( + InstanceHealthCheckModel.instance_id == instance.id, + InstanceHealthCheckModel.collected_at > health_check_cutoff, + ) + ) + return result.scalar_one() == 0 + + +async def _run_instance_check( + instance: InstanceModel, + job_provisioning_data: JobProvisioningData, + check_instance_health: bool, +) -> InstanceCheck: + ssh_private_keys = get_instance_ssh_private_keys(instance) + + # May return False if fails to establish ssh connection + instance_check = await run_async( + _check_instance_inner, + ssh_private_keys, + job_provisioning_data, + None, + instance=instance, + check_instance_health=check_instance_health, + ) + if instance_check is False: + return InstanceCheck(reachable=False, message="SSH or tunnel error") + return instance_check + + +def 
_get_health_status_for_instance_check( + instance: InstanceModel, + instance_check: InstanceCheck, + check_instance_health: bool, +) -> HealthStatus: + if instance_check.reachable and check_instance_health: + return instance_check.get_health_status() + # Keep previous health status + return instance.health + + +def _log_instance_check_result( + instance: InstanceModel, + instance_check: InstanceCheck, + health_status: HealthStatus, + check_instance_health: bool, +) -> None: + loglevel = logging.DEBUG + if not instance_check.reachable and instance.status.is_available(): + loglevel = logging.WARNING + elif check_instance_health and not health_status.is_healthy(): + loglevel = logging.WARNING + logger.log( + loglevel, + "Instance %s check: reachable=%s health_status=%s message=%r", + instance.name, + instance_check.reachable, + health_status.name, + instance_check.message, + extra={"instance_name": instance.name, "health_status": health_status}, + ) + + async def _wait_for_instance_provisioning_data( session: AsyncSession, project: ProjectModel, From 66f97222c0c3fe78eb8a7421165639c1e8b7d1e7 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Mon, 2 Mar 2026 15:43:46 +0500 Subject: [PATCH 08/51] Add fetchers tests --- .../pipeline_tasks/test_compute_groups.py | 116 ++++++++++- .../background/pipeline_tasks/test_fleets.py | 120 ++++++++++- .../pipeline_tasks/test_gateways.py | 178 +++++++++++++++- .../pipeline_tasks/test_instances.py | 197 ++++++++++++++++++ .../pipeline_tasks/test_placement_groups.py | 145 ++++++++++++- .../background/pipeline_tasks/test_volumes.py | 156 +++++++++++++- 6 files changed, 905 insertions(+), 7 deletions(-) create mode 100644 src/tests/_internal/server/background/pipeline_tasks/test_instances.py diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_compute_groups.py b/src/tests/_internal/server/background/pipeline_tasks/test_compute_groups.py index 6d24669f7c..cfc1c48d33 100644 --- 
a/src/tests/_internal/server/background/pipeline_tasks/test_compute_groups.py +++ b/src/tests/_internal/server/background/pipeline_tasks/test_compute_groups.py @@ -1,5 +1,6 @@ +import asyncio import uuid -from datetime import datetime, timezone +from datetime import datetime, timedelta, timezone from unittest.mock import Mock, patch import pytest @@ -9,7 +10,11 @@ from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.compute_groups import ComputeGroupStatus from dstack._internal.server.background.pipeline_tasks.base import PipelineItem -from dstack._internal.server.background.pipeline_tasks.compute_groups import ComputeGroupWorker +from dstack._internal.server.background.pipeline_tasks.compute_groups import ( + ComputeGroupFetcher, + ComputeGroupPipeline, + ComputeGroupWorker, +) from dstack._internal.server.models import ComputeGroupModel from dstack._internal.server.testing.common import ( ComputeMockSpec, @@ -17,6 +22,7 @@ create_fleet, create_project, ) +from dstack._internal.utils.common import get_current_datetime @pytest.fixture @@ -24,6 +30,17 @@ def worker() -> ComputeGroupWorker: return ComputeGroupWorker(queue=Mock(), heartbeater=Mock()) +@pytest.fixture +def fetcher() -> ComputeGroupFetcher: + return ComputeGroupFetcher( + queue=asyncio.Queue(), + queue_desired_minsize=1, + min_processing_interval=timedelta(seconds=15), + lock_timeout=timedelta(seconds=30), + heartbeater=Mock(), + ) + + def _compute_group_to_pipeline_item(compute_group: ComputeGroupModel) -> PipelineItem: assert compute_group.lock_token is not None assert compute_group.lock_expires_at is not None @@ -36,6 +53,101 @@ def _compute_group_to_pipeline_item(compute_group: ComputeGroupModel) -> Pipelin ) +@pytest.mark.asyncio +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) +class TestComputeGroupFetcher: + async def test_fetch_selects_eligible_compute_groups_and_sets_lock_fields( + self, test_db, session: AsyncSession, 
fetcher: ComputeGroupFetcher + ): + project = await create_project(session) + fleet = await create_fleet(session=session, project=project) + now = get_current_datetime() + stale = now - timedelta(minutes=1) + + eligible = await create_compute_group( + session=session, + project=project, + fleet=fleet, + last_processed_at=stale - timedelta(seconds=2), + ) + finished = await create_compute_group( + session=session, + project=project, + fleet=fleet, + status=ComputeGroupStatus.TERMINATED, + last_processed_at=stale - timedelta(seconds=1), + ) + recent = await create_compute_group( + session=session, + project=project, + fleet=fleet, + last_processed_at=now, + ) + locked = await create_compute_group( + session=session, + project=project, + fleet=fleet, + last_processed_at=stale, + ) + locked.lock_expires_at = now + timedelta(minutes=1) + locked.lock_token = uuid.uuid4() + locked.lock_owner = "OtherPipeline" + await session.commit() + + items = await fetcher.fetch(limit=10) + + assert [item.id for item in items] == [eligible.id] + + for compute_group in [eligible, finished, recent, locked]: + await session.refresh(compute_group) + + assert eligible.lock_owner == ComputeGroupPipeline.__name__ + assert eligible.lock_expires_at is not None + assert eligible.lock_token is not None + + assert finished.lock_owner is None + assert recent.lock_owner is None + assert locked.lock_owner == "OtherPipeline" + + async def test_fetch_returns_oldest_compute_groups_first_up_to_limit( + self, test_db, session: AsyncSession, fetcher: ComputeGroupFetcher + ): + project = await create_project(session) + fleet = await create_fleet(session=session, project=project) + now = get_current_datetime() + + oldest = await create_compute_group( + session=session, + project=project, + fleet=fleet, + last_processed_at=now - timedelta(minutes=3), + ) + middle = await create_compute_group( + session=session, + project=project, + fleet=fleet, + last_processed_at=now - timedelta(minutes=2), + ) + newest = 
await create_compute_group( + session=session, + project=project, + fleet=fleet, + last_processed_at=now - timedelta(minutes=1), + ) + + items = await fetcher.fetch(limit=2) + + assert [item.id for item in items] == [oldest.id, middle.id] + + await session.refresh(oldest) + await session.refresh(middle) + await session.refresh(newest) + + assert oldest.lock_owner == ComputeGroupPipeline.__name__ + assert middle.lock_owner == ComputeGroupPipeline.__name__ + assert newest.lock_owner is None + + class TestComputeGroupWorker: @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_fleets.py b/src/tests/_internal/server/background/pipeline_tasks/test_fleets.py index 746ddf2ea4..90441a30ef 100644 --- a/src/tests/_internal/server/background/pipeline_tasks/test_fleets.py +++ b/src/tests/_internal/server/background/pipeline_tasks/test_fleets.py @@ -1,5 +1,6 @@ +import asyncio import uuid -from datetime import datetime, timezone +from datetime import datetime, timedelta, timezone from unittest.mock import Mock import pytest @@ -12,6 +13,8 @@ from dstack._internal.core.models.users import GlobalRole, ProjectRole from dstack._internal.server.background.pipeline_tasks.base import PipelineItem from dstack._internal.server.background.pipeline_tasks.fleets import ( + FleetFetcher, + FleetPipeline, FleetWorker, ) from dstack._internal.server.models import FleetModel, InstanceModel @@ -26,6 +29,7 @@ create_user, get_fleet_spec, ) +from dstack._internal.utils.common import get_current_datetime @pytest.fixture @@ -33,6 +37,17 @@ def worker() -> FleetWorker: return FleetWorker(queue=Mock(), heartbeater=Mock()) +@pytest.fixture +def fetcher() -> FleetFetcher: + return FleetFetcher( + queue=asyncio.Queue(), + queue_desired_minsize=1, + min_processing_interval=timedelta(seconds=60), + lock_timeout=timedelta(seconds=20), + heartbeater=Mock(), + ) + + def 
_fleet_to_pipeline_item(fleet: FleetModel) -> PipelineItem: assert fleet.lock_token is not None assert fleet.lock_expires_at is not None @@ -45,6 +60,109 @@ def _fleet_to_pipeline_item(fleet: FleetModel) -> PipelineItem: ) +@pytest.mark.asyncio +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) +class TestFleetFetcher: + async def test_fetch_selects_eligible_fleets_and_sets_lock_fields( + self, test_db, session: AsyncSession, fetcher: FleetFetcher + ): + project = await create_project(session) + now = get_current_datetime() + + stale = await create_fleet( + session=session, + project=project, + last_processed_at=now - timedelta(minutes=3), + ) + just_created = await create_fleet( + session=session, + project=project, + created_at=now, + last_processed_at=now, + name="just-created", + ) + deleted = await create_fleet( + session=session, + project=project, + deleted=True, + name="deleted", + last_processed_at=now - timedelta(minutes=2), + ) + recent = await create_fleet( + session=session, + project=project, + created_at=now - timedelta(minutes=2), + last_processed_at=now, + name="recent", + ) + locked = await create_fleet( + session=session, + project=project, + name="locked", + last_processed_at=now - timedelta(minutes=1, seconds=1), + ) + locked.lock_expires_at = now + timedelta(minutes=1) + locked.lock_token = uuid.uuid4() + locked.lock_owner = "OtherPipeline" + await session.commit() + + items = await fetcher.fetch(limit=10) + + assert {item.id for item in items} == {stale.id, just_created.id} + + for fleet in [stale, just_created, deleted, recent, locked]: + await session.refresh(fleet) + + assert stale.lock_owner == FleetPipeline.__name__ + assert just_created.lock_owner == FleetPipeline.__name__ + assert stale.lock_expires_at is not None + assert just_created.lock_expires_at is not None + assert stale.lock_token is not None + assert just_created.lock_token is not None + assert len({stale.lock_token, just_created.lock_token}) == 1 + + 
assert deleted.lock_owner is None + assert recent.lock_owner is None + assert locked.lock_owner == "OtherPipeline" + + async def test_fetch_returns_oldest_fleets_first_up_to_limit( + self, test_db, session: AsyncSession, fetcher: FleetFetcher + ): + project = await create_project(session) + now = get_current_datetime() + + oldest = await create_fleet( + session=session, + project=project, + name="oldest", + last_processed_at=now - timedelta(minutes=4), + ) + middle = await create_fleet( + session=session, + project=project, + name="middle", + last_processed_at=now - timedelta(minutes=3), + ) + newest = await create_fleet( + session=session, + project=project, + name="newest", + last_processed_at=now - timedelta(minutes=2), + ) + + items = await fetcher.fetch(limit=2) + + assert [item.id for item in items] == [oldest.id, middle.id] + + await session.refresh(oldest) + await session.refresh(middle) + await session.refresh(newest) + + assert oldest.lock_owner == FleetPipeline.__name__ + assert middle.lock_owner == FleetPipeline.__name__ + assert newest.lock_owner is None + + @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) class TestFleetWorker: diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_gateways.py b/src/tests/_internal/server/background/pipeline_tasks/test_gateways.py index 59cbd370e9..a1d7b360f9 100644 --- a/src/tests/_internal/server/background/pipeline_tasks/test_gateways.py +++ b/src/tests/_internal/server/background/pipeline_tasks/test_gateways.py @@ -1,5 +1,6 @@ +import asyncio import uuid -from datetime import datetime, timezone +from datetime import datetime, timedelta, timezone from unittest.mock import MagicMock, Mock, patch import pytest @@ -10,6 +11,8 @@ from dstack._internal.core.errors import BackendError from dstack._internal.core.models.gateways import GatewayProvisioningData, GatewayStatus from dstack._internal.server.background.pipeline_tasks.gateways import ( + 
GatewayFetcher, + GatewayPipeline, GatewayPipelineItem, GatewayWorker, ) @@ -23,6 +26,7 @@ create_project, list_events, ) +from dstack._internal.utils.common import get_current_datetime @pytest.fixture @@ -30,6 +34,17 @@ def worker() -> GatewayWorker: return GatewayWorker(queue=Mock(), heartbeater=Mock()) +@pytest.fixture +def fetcher() -> GatewayFetcher: + return GatewayFetcher( + queue=asyncio.Queue(), + queue_desired_minsize=1, + min_processing_interval=timedelta(seconds=15), + lock_timeout=timedelta(seconds=30), + heartbeater=Mock(), + ) + + def _gateway_to_pipeline_item(gateway_model: GatewayModel) -> GatewayPipelineItem: assert gateway_model.lock_token is not None assert gateway_model.lock_expires_at is not None @@ -44,6 +59,167 @@ def _gateway_to_pipeline_item(gateway_model: GatewayModel) -> GatewayPipelineIte ) +@pytest.mark.asyncio +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) +class TestGatewayFetcher: + async def test_fetch_selects_eligible_gateways_and_sets_lock_fields( + self, test_db, session: AsyncSession, fetcher: GatewayFetcher + ): + project = await create_project(session=session) + backend = await create_backend(session=session, project_id=project.id) + now = get_current_datetime() + stale = now - timedelta(minutes=1) + + submitted = await create_gateway( + session=session, + project_id=project.id, + backend_id=backend.id, + name="submitted", + status=GatewayStatus.SUBMITTED, + last_processed_at=stale - timedelta(seconds=3), + ) + provisioning = await create_gateway( + session=session, + project_id=project.id, + backend_id=backend.id, + name="provisioning", + status=GatewayStatus.PROVISIONING, + last_processed_at=stale - timedelta(seconds=2), + ) + to_be_deleted = await create_gateway( + session=session, + project_id=project.id, + backend_id=backend.id, + name="to-be-deleted", + status=GatewayStatus.RUNNING, + last_processed_at=stale - timedelta(seconds=1), + ) + to_be_deleted.to_be_deleted = True + + just_created = 
await create_gateway( + session=session, + project_id=project.id, + backend_id=backend.id, + name="just-created", + status=GatewayStatus.SUBMITTED, + last_processed_at=now, + ) + just_created.created_at = now + just_created.last_processed_at = now + + ineligible_status = await create_gateway( + session=session, + project_id=project.id, + backend_id=backend.id, + name="ineligible-status", + status=GatewayStatus.RUNNING, + last_processed_at=stale, + ) + recent = await create_gateway( + session=session, + project_id=project.id, + backend_id=backend.id, + name="recent", + status=GatewayStatus.SUBMITTED, + last_processed_at=now, + ) + recent.created_at = now - timedelta(minutes=2) + recent.last_processed_at = now + + locked = await create_gateway( + session=session, + project_id=project.id, + backend_id=backend.id, + name="locked", + status=GatewayStatus.SUBMITTED, + last_processed_at=stale + timedelta(seconds=1), + ) + locked.lock_expires_at = now + timedelta(minutes=1) + locked.lock_token = uuid.uuid4() + locked.lock_owner = "OtherPipeline" + await session.commit() + + items = await fetcher.fetch(limit=10) + + assert {item.id for item in items} == { + submitted.id, + provisioning.id, + to_be_deleted.id, + just_created.id, + } + assert {(item.id, item.status, item.to_be_deleted) for item in items} == { + (submitted.id, GatewayStatus.SUBMITTED, False), + (provisioning.id, GatewayStatus.PROVISIONING, False), + (to_be_deleted.id, GatewayStatus.RUNNING, True), + (just_created.id, GatewayStatus.SUBMITTED, False), + } + + for gateway in [ + submitted, + provisioning, + to_be_deleted, + just_created, + ineligible_status, + recent, + locked, + ]: + await session.refresh(gateway) + + fetched_gateways = [submitted, provisioning, to_be_deleted, just_created] + assert all(gateway.lock_owner == GatewayPipeline.__name__ for gateway in fetched_gateways) + assert all(gateway.lock_expires_at is not None for gateway in fetched_gateways) + assert all(gateway.lock_token is not None for 
gateway in fetched_gateways) + assert len({gateway.lock_token for gateway in fetched_gateways}) == 1 + + assert ineligible_status.lock_owner is None + assert recent.lock_owner is None + assert locked.lock_owner == "OtherPipeline" + + async def test_fetch_returns_oldest_gateways_first_up_to_limit( + self, test_db, session: AsyncSession, fetcher: GatewayFetcher + ): + project = await create_project(session=session) + backend = await create_backend(session=session, project_id=project.id) + now = get_current_datetime() + + oldest = await create_gateway( + session=session, + project_id=project.id, + backend_id=backend.id, + name="oldest", + status=GatewayStatus.SUBMITTED, + last_processed_at=now - timedelta(minutes=3), + ) + middle = await create_gateway( + session=session, + project_id=project.id, + backend_id=backend.id, + name="middle", + status=GatewayStatus.PROVISIONING, + last_processed_at=now - timedelta(minutes=2), + ) + newest = await create_gateway( + session=session, + project_id=project.id, + backend_id=backend.id, + name="newest", + status=GatewayStatus.SUBMITTED, + last_processed_at=now - timedelta(minutes=1), + ) + + items = await fetcher.fetch(limit=2) + + assert [item.id for item in items] == [oldest.id, middle.id] + + await session.refresh(oldest) + await session.refresh(middle) + await session.refresh(newest) + + assert oldest.lock_owner == GatewayPipeline.__name__ + assert middle.lock_owner == GatewayPipeline.__name__ + assert newest.lock_owner is None + + @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) class TestGatewayWorkerSubmitted: diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_instances.py b/src/tests/_internal/server/background/pipeline_tasks/test_instances.py new file mode 100644 index 0000000000..a6e07fe152 --- /dev/null +++ b/src/tests/_internal/server/background/pipeline_tasks/test_instances.py @@ -0,0 +1,197 @@ +import asyncio +import uuid +from datetime import 
timedelta +from unittest.mock import Mock + +import pytest +from sqlalchemy.ext.asyncio import AsyncSession + +from dstack._internal.core.models.instances import InstanceStatus +from dstack._internal.server.background.pipeline_tasks.instances import ( + InstanceFetcher, + InstancePipeline, +) +from dstack._internal.server.testing.common import ( + create_compute_group, + create_fleet, + create_instance, + create_project, +) +from dstack._internal.utils.common import get_current_datetime + + +@pytest.fixture +def fetcher() -> InstanceFetcher: + return InstanceFetcher( + queue=asyncio.Queue(), + queue_desired_minsize=1, + min_processing_interval=timedelta(seconds=10), + lock_timeout=timedelta(seconds=30), + heartbeater=Mock(), + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) +class TestInstanceFetcher: + async def test_fetch_selects_eligible_instances_and_sets_lock_fields( + self, test_db, session: AsyncSession, fetcher: InstanceFetcher + ): + project = await create_project(session=session) + fleet = await create_fleet(session=session, project=project) + compute_group = await create_compute_group(session=session, project=project, fleet=fleet) + now = get_current_datetime() + stale = now - timedelta(minutes=1) + + pending = await create_instance( + session=session, + project=project, + status=InstanceStatus.PENDING, + last_processed_at=stale - timedelta(seconds=5), + ) + provisioning = await create_instance( + session=session, + project=project, + status=InstanceStatus.PROVISIONING, + name="provisioning", + last_processed_at=stale - timedelta(seconds=4), + ) + busy = await create_instance( + session=session, + project=project, + status=InstanceStatus.BUSY, + name="busy", + last_processed_at=stale - timedelta(seconds=3), + ) + idle = await create_instance( + session=session, + project=project, + status=InstanceStatus.IDLE, + name="idle", + last_processed_at=stale - timedelta(seconds=2), + ) + terminating = await 
create_instance( + session=session, + project=project, + status=InstanceStatus.TERMINATING, + name="terminating", + last_processed_at=stale - timedelta(seconds=1), + ) + + deleted = await create_instance( + session=session, + project=project, + status=InstanceStatus.IDLE, + name="deleted", + last_processed_at=stale, + ) + deleted.deleted = True + + recent = await create_instance( + session=session, + project=project, + status=InstanceStatus.IDLE, + name="recent", + last_processed_at=now, + ) + + terminating_compute_group = await create_instance( + session=session, + project=project, + status=InstanceStatus.TERMINATING, + name="terminating-compute-group", + last_processed_at=stale + timedelta(seconds=1), + ) + terminating_compute_group.compute_group = compute_group + + locked = await create_instance( + session=session, + project=project, + status=InstanceStatus.IDLE, + name="locked", + last_processed_at=stale + timedelta(seconds=2), + ) + locked.lock_expires_at = now + timedelta(minutes=1) + locked.lock_token = uuid.uuid4() + locked.lock_owner = "OtherPipeline" + + await session.commit() + + items = await fetcher.fetch(limit=10) + + assert {item.id for item in items} == { + pending.id, + provisioning.id, + busy.id, + idle.id, + terminating.id, + } + assert {item.status for item in items} == { + InstanceStatus.PENDING, + InstanceStatus.PROVISIONING, + InstanceStatus.BUSY, + InstanceStatus.IDLE, + InstanceStatus.TERMINATING, + } + + for instance in [ + pending, + provisioning, + busy, + idle, + terminating, + deleted, + recent, + terminating_compute_group, + locked, + ]: + await session.refresh(instance) + + expected_lock_owner = InstancePipeline.__name__ + fetched_instances = [pending, provisioning, busy, idle, terminating] + assert all(instance.lock_owner == expected_lock_owner for instance in fetched_instances) + assert all(instance.lock_expires_at is not None for instance in fetched_instances) + assert all(instance.lock_token is not None for instance in 
fetched_instances) + assert len({instance.lock_token for instance in fetched_instances}) == 1 + + assert deleted.lock_owner is None + assert recent.lock_owner is None + assert terminating_compute_group.lock_owner is None + assert locked.lock_owner == "OtherPipeline" + + async def test_fetch_respects_order_and_limit( + self, test_db, session: AsyncSession, fetcher: InstanceFetcher + ): + project = await create_project(session=session) + now = get_current_datetime() + + oldest = await create_instance( + session=session, + project=project, + name="oldest", + last_processed_at=now - timedelta(minutes=3), + ) + middle = await create_instance( + session=session, + project=project, + name="middle", + last_processed_at=now - timedelta(minutes=2), + ) + newest = await create_instance( + session=session, + project=project, + name="newest", + last_processed_at=now - timedelta(minutes=1), + ) + + items = await fetcher.fetch(limit=2) + + assert [item.id for item in items] == [oldest.id, middle.id] + + await session.refresh(oldest) + await session.refresh(middle) + await session.refresh(newest) + + assert oldest.lock_owner == InstancePipeline.__name__ + assert middle.lock_owner == InstancePipeline.__name__ + assert newest.lock_owner is None diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_placement_groups.py b/src/tests/_internal/server/background/pipeline_tasks/test_placement_groups.py index c23d5e604d..6fa0ea7682 100644 --- a/src/tests/_internal/server/background/pipeline_tasks/test_placement_groups.py +++ b/src/tests/_internal/server/background/pipeline_tasks/test_placement_groups.py @@ -1,5 +1,6 @@ +import asyncio import uuid -from datetime import datetime, timezone +from datetime import datetime, timedelta, timezone from unittest.mock import Mock, patch import pytest @@ -7,7 +8,11 @@ from dstack._internal.core.errors import PlacementGroupInUseError from dstack._internal.server.background.pipeline_tasks.base import PipelineItem -from 
dstack._internal.server.background.pipeline_tasks.placement_groups import PlacementGroupWorker +from dstack._internal.server.background.pipeline_tasks.placement_groups import ( + PlacementGroupFetcher, + PlacementGroupPipeline, + PlacementGroupWorker, +) from dstack._internal.server.models import PlacementGroupModel from dstack._internal.server.testing.common import ( ComputeMockSpec, @@ -15,6 +20,7 @@ create_placement_group, create_project, ) +from dstack._internal.utils.common import get_current_datetime @pytest.fixture @@ -22,6 +28,17 @@ def worker() -> PlacementGroupWorker: return PlacementGroupWorker(queue=Mock(), heartbeater=Mock()) +@pytest.fixture +def fetcher() -> PlacementGroupFetcher: + return PlacementGroupFetcher( + queue=asyncio.Queue(), + queue_desired_minsize=1, + min_processing_interval=timedelta(seconds=15), + lock_timeout=timedelta(seconds=30), + heartbeater=Mock(), + ) + + def _placement_group_to_pipeline_item(placement_group: PlacementGroupModel) -> PipelineItem: assert placement_group.lock_token is not None assert placement_group.lock_expires_at is not None @@ -34,6 +51,130 @@ def _placement_group_to_pipeline_item(placement_group: PlacementGroupModel) -> P ) +@pytest.mark.asyncio +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) +class TestPlacementGroupFetcher: + async def test_fetch_selects_eligible_placement_groups_and_sets_lock_fields( + self, test_db, session: AsyncSession, fetcher: PlacementGroupFetcher + ): + project = await create_project(session) + fleet = await create_fleet(session=session, project=project) + now = get_current_datetime() + stale = now - timedelta(minutes=1) + + eligible = await create_placement_group( + session=session, + project=project, + fleet=fleet, + fleet_deleted=True, + ) + eligible.last_processed_at = stale - timedelta(seconds=2) + + fleet_not_deleted = await create_placement_group( + session=session, + project=project, + fleet=fleet, + name="fleet-not-deleted", + 
fleet_deleted=False, + ) + fleet_not_deleted.last_processed_at = stale - timedelta(seconds=1) + + deleted = await create_placement_group( + session=session, + project=project, + fleet=fleet, + name="deleted", + fleet_deleted=True, + deleted=True, + ) + deleted.last_processed_at = stale + + recent = await create_placement_group( + session=session, + project=project, + fleet=fleet, + name="recent", + fleet_deleted=True, + ) + recent.last_processed_at = now + + locked = await create_placement_group( + session=session, + project=project, + fleet=fleet, + name="locked", + fleet_deleted=True, + ) + locked.last_processed_at = stale + timedelta(seconds=1) + locked.lock_expires_at = now + timedelta(minutes=1) + locked.lock_token = uuid.uuid4() + locked.lock_owner = "OtherPipeline" + await session.commit() + + items = await fetcher.fetch(limit=10) + + assert [item.id for item in items] == [eligible.id] + + for placement_group in [eligible, fleet_not_deleted, deleted, recent, locked]: + await session.refresh(placement_group) + + assert eligible.lock_owner == PlacementGroupPipeline.__name__ + assert eligible.lock_expires_at is not None + assert eligible.lock_token is not None + + assert fleet_not_deleted.lock_owner is None + assert deleted.lock_owner is None + assert recent.lock_owner is None + assert locked.lock_owner == "OtherPipeline" + + async def test_fetch_returns_oldest_placement_groups_first_up_to_limit( + self, test_db, session: AsyncSession, fetcher: PlacementGroupFetcher + ): + project = await create_project(session) + fleet = await create_fleet(session=session, project=project) + now = get_current_datetime() + + oldest = await create_placement_group( + session=session, + project=project, + fleet=fleet, + name="oldest", + fleet_deleted=True, + ) + oldest.last_processed_at = now - timedelta(minutes=3) + + middle = await create_placement_group( + session=session, + project=project, + fleet=fleet, + name="middle", + fleet_deleted=True, + ) + middle.last_processed_at = 
now - timedelta(minutes=2) + + newest = await create_placement_group( + session=session, + project=project, + fleet=fleet, + name="newest", + fleet_deleted=True, + ) + newest.last_processed_at = now - timedelta(minutes=1) + await session.commit() + + items = await fetcher.fetch(limit=2) + + assert [item.id for item in items] == [oldest.id, middle.id] + + await session.refresh(oldest) + await session.refresh(middle) + await session.refresh(newest) + + assert oldest.lock_owner == PlacementGroupPipeline.__name__ + assert middle.lock_owner == PlacementGroupPipeline.__name__ + assert newest.lock_owner is None + + class TestPlacementGroupWorker: @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_volumes.py b/src/tests/_internal/server/background/pipeline_tasks/test_volumes.py index 4d22c59b97..63dfaaa45a 100644 --- a/src/tests/_internal/server/background/pipeline_tasks/test_volumes.py +++ b/src/tests/_internal/server/background/pipeline_tasks/test_volumes.py @@ -1,5 +1,6 @@ +import asyncio import uuid -from datetime import datetime, timezone +from datetime import datetime, timedelta, timezone from unittest.mock import Mock, patch import pytest @@ -9,6 +10,8 @@ from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.volumes import VolumeProvisioningData, VolumeStatus from dstack._internal.server.background.pipeline_tasks.volumes import ( + VolumeFetcher, + VolumePipeline, VolumePipelineItem, VolumeWorker, ) @@ -22,6 +25,7 @@ get_volume_provisioning_data, list_events, ) +from dstack._internal.utils.common import get_current_datetime @pytest.fixture @@ -29,6 +33,17 @@ def worker() -> VolumeWorker: return VolumeWorker(queue=Mock(), heartbeater=Mock()) +@pytest.fixture +def fetcher() -> VolumeFetcher: + return VolumeFetcher( + queue=asyncio.Queue(), + queue_desired_minsize=1, + 
min_processing_interval=timedelta(seconds=15), + lock_timeout=timedelta(seconds=30), + heartbeater=Mock(), + ) + + def _volume_to_pipeline_item(volume_model: VolumeModel) -> VolumePipelineItem: assert volume_model.lock_token is not None assert volume_model.lock_expires_at is not None @@ -43,6 +58,145 @@ def _volume_to_pipeline_item(volume_model: VolumeModel) -> VolumePipelineItem: ) +@pytest.mark.asyncio +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) +class TestVolumeFetcher: + async def test_fetch_selects_eligible_volumes_and_sets_lock_fields( + self, test_db, session: AsyncSession, fetcher: VolumeFetcher + ): + project = await create_project(session=session) + user = await create_user(session=session) + now = get_current_datetime() + stale = now - timedelta(minutes=1) + + submitted = await create_volume( + session=session, + project=project, + user=user, + status=VolumeStatus.SUBMITTED, + created_at=stale - timedelta(minutes=1), + last_processed_at=stale - timedelta(seconds=2), + ) + to_be_deleted = await create_volume( + session=session, + project=project, + user=user, + status=VolumeStatus.ACTIVE, + created_at=stale - timedelta(minutes=1), + last_processed_at=stale - timedelta(seconds=1), + ) + to_be_deleted.to_be_deleted = True + + just_created = await create_volume( + session=session, + project=project, + user=user, + status=VolumeStatus.SUBMITTED, + created_at=now, + last_processed_at=now, + ) + + deleted = await create_volume( + session=session, + project=project, + user=user, + status=VolumeStatus.SUBMITTED, + created_at=stale - timedelta(minutes=1), + last_processed_at=stale, + deleted_at=stale, + ) + recent = await create_volume( + session=session, + project=project, + user=user, + status=VolumeStatus.SUBMITTED, + created_at=now - timedelta(minutes=2), + last_processed_at=now, + ) + locked = await create_volume( + session=session, + project=project, + user=user, + status=VolumeStatus.SUBMITTED, + created_at=stale - 
timedelta(minutes=1), + last_processed_at=stale + timedelta(seconds=1), + ) + locked.lock_expires_at = now + timedelta(minutes=1) + locked.lock_token = uuid.uuid4() + locked.lock_owner = "OtherPipeline" + await session.commit() + + items = await fetcher.fetch(limit=10) + + assert {item.id for item in items} == { + submitted.id, + to_be_deleted.id, + just_created.id, + } + assert {(item.id, item.status, item.to_be_deleted) for item in items} == { + (submitted.id, VolumeStatus.SUBMITTED, False), + (to_be_deleted.id, VolumeStatus.ACTIVE, True), + (just_created.id, VolumeStatus.SUBMITTED, False), + } + + for volume in [submitted, to_be_deleted, just_created, deleted, recent, locked]: + await session.refresh(volume) + + fetched_volumes = [submitted, to_be_deleted, just_created] + assert all(volume.lock_owner == VolumePipeline.__name__ for volume in fetched_volumes) + assert all(volume.lock_expires_at is not None for volume in fetched_volumes) + assert all(volume.lock_token is not None for volume in fetched_volumes) + assert len({volume.lock_token for volume in fetched_volumes}) == 1 + + assert deleted.lock_owner is None + assert recent.lock_owner is None + assert locked.lock_owner == "OtherPipeline" + + async def test_fetch_returns_oldest_volumes_first_up_to_limit( + self, test_db, session: AsyncSession, fetcher: VolumeFetcher + ): + project = await create_project(session=session) + user = await create_user(session=session) + now = get_current_datetime() + + oldest = await create_volume( + session=session, + project=project, + user=user, + status=VolumeStatus.SUBMITTED, + created_at=now - timedelta(minutes=4), + last_processed_at=now - timedelta(minutes=3), + ) + middle = await create_volume( + session=session, + project=project, + user=user, + status=VolumeStatus.SUBMITTED, + created_at=now - timedelta(minutes=3), + last_processed_at=now - timedelta(minutes=2), + ) + newest = await create_volume( + session=session, + project=project, + user=user, + 
status=VolumeStatus.SUBMITTED, + created_at=now - timedelta(minutes=2), + last_processed_at=now - timedelta(minutes=1), + ) + + items = await fetcher.fetch(limit=2) + + assert [item.id for item in items] == [oldest.id, middle.id] + + await session.refresh(oldest) + await session.refresh(middle) + await session.refresh(newest) + + assert oldest.lock_owner == VolumePipeline.__name__ + assert middle.lock_owner == VolumePipeline.__name__ + assert newest.lock_owner is None + + @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) class TestVolumeWorkerSubmitted: From f781cc3b497ce9080c0db26efdb00342f7788293 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Tue, 3 Mar 2026 10:39:09 +0500 Subject: [PATCH 09/51] Add TestInstanceWorker --- .../pipeline_tasks/test_instances.py | 1976 ++++++++++++++++- 1 file changed, 1959 insertions(+), 17 deletions(-) diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_instances.py b/src/tests/_internal/server/background/pipeline_tasks/test_instances.py index a6e07fe152..8a8c3b87ce 100644 --- a/src/tests/_internal/server/background/pipeline_tasks/test_instances.py +++ b/src/tests/_internal/server/background/pipeline_tasks/test_instances.py @@ -1,36 +1,156 @@ import asyncio +import datetime as dt +import logging import uuid -from datetime import timedelta -from unittest.mock import Mock +from collections import defaultdict +from contextlib import contextmanager +from typing import Optional +from unittest.mock import AsyncMock, Mock, call, patch +import gpuhunt import pytest +import pytest_asyncio +from freezegun import freeze_time +from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from dstack._internal.core.models.instances import InstanceStatus +from dstack._internal.core.backends.base.compute import GoArchType +from dstack._internal.core.errors import ( + BackendError, + NoCapacityError, + NotYetTerminated, + ProvisioningError, + 
SSHProvisioningError, +) +from dstack._internal.core.models.backends.base import BackendType +from dstack._internal.core.models.fleets import InstanceGroupPlacement +from dstack._internal.core.models.health import HealthStatus +from dstack._internal.core.models.instances import ( + Gpu, + InstanceAvailability, + InstanceOffer, + InstanceOfferWithAvailability, + InstanceStatus, + InstanceTerminationReason, + InstanceType, + Resources, +) +from dstack._internal.core.models.placement import PlacementGroup, PlacementGroupProvisioningData +from dstack._internal.core.models.profiles import TerminationPolicy +from dstack._internal.core.models.runs import JobProvisioningData, JobStatus +from dstack._internal.server.background.pipeline_tasks import instances as instances_pipeline from dstack._internal.server.background.pipeline_tasks.instances import ( InstanceFetcher, InstancePipeline, + InstancePipelineItem, + InstanceWorker, +) +from dstack._internal.server.models import ( + InstanceHealthCheckModel, + InstanceModel, + PlacementGroupModel, +) +from dstack._internal.server.schemas.health.dcgm import DCGMHealthResponse, DCGMHealthResult +from dstack._internal.server.schemas.instances import InstanceCheck +from dstack._internal.server.schemas.runner import ( + ComponentInfo, + ComponentName, + ComponentStatus, + HealthcheckResponse, + InstanceHealthResponse, + TaskListItem, + TaskListResponse, + TaskStatus, ) +from dstack._internal.server.services.runner.client import ComponentList, ShimClient from dstack._internal.server.testing.common import ( + ComputeMockSpec, create_compute_group, create_fleet, create_instance, + create_job, create_project, + create_repo, + create_run, + create_user, + get_fleet_configuration, + get_fleet_spec, + get_instance_offer_with_availability, + get_job_provisioning_data, + get_placement_group_provisioning_data, + get_remote_connection_info, + list_events, ) from dstack._internal.utils.common import get_current_datetime +pytestmark = 
pytest.mark.usefixtures("image_config_mock") +LOCK_EXPIRES_AT = dt.datetime(2025, 1, 2, 3, 4, tzinfo=dt.timezone.utc) + @pytest.fixture def fetcher() -> InstanceFetcher: return InstanceFetcher( queue=asyncio.Queue(), queue_desired_minsize=1, - min_processing_interval=timedelta(seconds=10), - lock_timeout=timedelta(seconds=30), + min_processing_interval=dt.timedelta(seconds=10), + lock_timeout=dt.timedelta(seconds=30), heartbeater=Mock(), ) +@pytest.fixture +def worker() -> InstanceWorker: + return InstanceWorker(queue=asyncio.Queue(), heartbeater=Mock()) + + +@pytest.fixture +def host_info() -> dict: + return { + "gpu_vendor": "nvidia", + "gpu_name": "T4", + "gpu_memory": 16384, + "gpu_count": 1, + "addresses": ["192.168.100.100/24"], + "disk_size": 260976517120, + "cpus": 32, + "memory": 33544130560, + } + + +@pytest.fixture +def deploy_instance_mock(monkeypatch: pytest.MonkeyPatch, host_info: dict) -> Mock: + mock = Mock(return_value=(InstanceCheck(reachable=True), host_info, GoArchType.AMD64)) + monkeypatch.setattr(instances_pipeline, "_deploy_instance", mock) + return mock + + +def _instance_to_pipeline_item(instance_model: InstanceModel) -> InstancePipelineItem: + assert instance_model.lock_token is not None + assert instance_model.lock_expires_at is not None + return InstancePipelineItem( + __tablename__=instance_model.__tablename__, + id=instance_model.id, + lock_token=instance_model.lock_token, + lock_expires_at=instance_model.lock_expires_at, + prev_lock_expired=False, + status=instance_model.status, + ) + + +def _lock_instance(instance_model: InstanceModel) -> None: + instance_model.lock_token = uuid.uuid4() + instance_model.lock_expires_at = LOCK_EXPIRES_AT + instance_model.lock_owner = InstancePipeline.__name__ + + +async def _process_instance( + session: AsyncSession, worker: InstanceWorker, instance_model: InstanceModel +) -> None: + _lock_instance(instance_model) + await session.commit() + await 
worker.process(_instance_to_pipeline_item(instance_model)) + + @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) class TestInstanceFetcher: @@ -41,41 +161,41 @@ async def test_fetch_selects_eligible_instances_and_sets_lock_fields( fleet = await create_fleet(session=session, project=project) compute_group = await create_compute_group(session=session, project=project, fleet=fleet) now = get_current_datetime() - stale = now - timedelta(minutes=1) + stale = now - dt.timedelta(minutes=1) pending = await create_instance( session=session, project=project, status=InstanceStatus.PENDING, - last_processed_at=stale - timedelta(seconds=5), + last_processed_at=stale - dt.timedelta(seconds=5), ) provisioning = await create_instance( session=session, project=project, status=InstanceStatus.PROVISIONING, name="provisioning", - last_processed_at=stale - timedelta(seconds=4), + last_processed_at=stale - dt.timedelta(seconds=4), ) busy = await create_instance( session=session, project=project, status=InstanceStatus.BUSY, name="busy", - last_processed_at=stale - timedelta(seconds=3), + last_processed_at=stale - dt.timedelta(seconds=3), ) idle = await create_instance( session=session, project=project, status=InstanceStatus.IDLE, name="idle", - last_processed_at=stale - timedelta(seconds=2), + last_processed_at=stale - dt.timedelta(seconds=2), ) terminating = await create_instance( session=session, project=project, status=InstanceStatus.TERMINATING, name="terminating", - last_processed_at=stale - timedelta(seconds=1), + last_processed_at=stale - dt.timedelta(seconds=1), ) deleted = await create_instance( @@ -100,7 +220,7 @@ async def test_fetch_selects_eligible_instances_and_sets_lock_fields( project=project, status=InstanceStatus.TERMINATING, name="terminating-compute-group", - last_processed_at=stale + timedelta(seconds=1), + last_processed_at=stale + dt.timedelta(seconds=1), ) terminating_compute_group.compute_group = compute_group @@ 
-109,9 +229,9 @@ async def test_fetch_selects_eligible_instances_and_sets_lock_fields( project=project, status=InstanceStatus.IDLE, name="locked", - last_processed_at=stale + timedelta(seconds=2), + last_processed_at=stale + dt.timedelta(seconds=2), ) - locked.lock_expires_at = now + timedelta(minutes=1) + locked.lock_expires_at = now + dt.timedelta(minutes=1) locked.lock_token = uuid.uuid4() locked.lock_owner = "OtherPipeline" @@ -169,19 +289,19 @@ async def test_fetch_respects_order_and_limit( session=session, project=project, name="oldest", - last_processed_at=now - timedelta(minutes=3), + last_processed_at=now - dt.timedelta(minutes=3), ) middle = await create_instance( session=session, project=project, name="middle", - last_processed_at=now - timedelta(minutes=2), + last_processed_at=now - dt.timedelta(minutes=2), ) newest = await create_instance( session=session, project=project, name="newest", - last_processed_at=now - timedelta(minutes=1), + last_processed_at=now - dt.timedelta(minutes=1), ) items = await fetcher.fetch(limit=2) @@ -195,3 +315,1825 @@ async def test_fetch_respects_order_and_limit( assert oldest.lock_owner == InstancePipeline.__name__ assert middle.lock_owner == InstancePipeline.__name__ assert newest.lock_owner is None + + +@pytest.mark.asyncio +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) +class TestInstanceWorker: + @staticmethod + @contextmanager + def mock_terminate_in_backend(error: Optional[Exception] = None): + backend = Mock() + backend.TYPE = BackendType.VERDA + terminate_instance = backend.compute.return_value.terminate_instance + if error is not None: + terminate_instance.side_effect = error + with patch.object( + instances_pipeline.backends_services, + "get_project_backend_by_type", + AsyncMock(return_value=backend), + ): + yield terminate_instance + + async def test_process_skips_when_lock_token_changes( + self, + test_db, + session: AsyncSession, + fetcher: InstanceFetcher, + worker: 
InstanceWorker, + ): + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.IDLE, + ) + + _lock_instance(instance) + await session.commit() + item = _instance_to_pipeline_item(instance) + new_lock_token = uuid.uuid4() + instance.lock_token = new_lock_token + await session.commit() + + await worker.process(item) + await session.refresh(instance) + + assert instance.lock_token == new_lock_token + assert instance.lock_owner == InstancePipeline.__name__ + + async def test_process_unlocks_and_updates_last_processed_at_after_check( + self, + test_db, + session: AsyncSession, + fetcher: InstanceFetcher, + worker: InstanceWorker, + monkeypatch: pytest.MonkeyPatch, + ): + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.PROVISIONING, + ) + before_processed_at = instance.last_processed_at + + monkeypatch.setattr( + instances_pipeline, + "_check_instance_inner", + Mock(return_value=InstanceCheck(reachable=True)), + ) + await _process_instance(session, worker, instance) + + await session.refresh(instance) + + assert instance.status == InstanceStatus.IDLE + assert instance.lock_expires_at is None + assert instance.lock_token is None + assert instance.lock_owner is None + assert instance.last_processed_at > before_processed_at + + async def test_check_shim_transitions_provisioning_on_ready( + self, + test_db, + session: AsyncSession, + fetcher: InstanceFetcher, + worker: InstanceWorker, + monkeypatch: pytest.MonkeyPatch, + ): + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.PROVISIONING, + ) + instance.termination_deadline = get_current_datetime() + dt.timedelta(days=1) + await session.commit() + + monkeypatch.setattr( + instances_pipeline, + "_check_instance_inner", + 
Mock(return_value=InstanceCheck(reachable=True)), + ) + await _process_instance(session, worker, instance) + + await session.refresh(instance) + + assert instance.status == InstanceStatus.IDLE + assert instance.termination_deadline is None + + async def test_check_shim_transitions_provisioning_on_terminating( + self, + test_db, + session: AsyncSession, + fetcher: InstanceFetcher, + worker: InstanceWorker, + monkeypatch: pytest.MonkeyPatch, + ): + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.PROVISIONING, + ) + instance.started_at = get_current_datetime() + dt.timedelta(minutes=-20) + await session.commit() + + monkeypatch.setattr( + instances_pipeline, + "_check_instance_inner", + Mock(return_value=InstanceCheck(reachable=False, message="Shim problem")), + ) + await _process_instance(session, worker, instance) + + await session.refresh(instance) + + assert instance.status == InstanceStatus.TERMINATING + assert instance.termination_deadline is not None + + async def test_check_shim_transitions_provisioning_on_busy( + self, + test_db, + session: AsyncSession, + fetcher: InstanceFetcher, + worker: InstanceWorker, + monkeypatch: pytest.MonkeyPatch, + ): + user = await create_user(session=session) + project = await create_project(session=session, owner=user) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run(session=session, project=project, repo=repo, user=user) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.PROVISIONING, + ) + instance.termination_deadline = get_current_datetime().replace( + tzinfo=dt.timezone.utc + ) + dt.timedelta(days=1) + job = await create_job( + session=session, + run=run, + status=JobStatus.SUBMITTED, + instance=instance, + ) + await session.commit() + + monkeypatch.setattr( + instances_pipeline, + "_check_instance_inner", + 
Mock(return_value=InstanceCheck(reachable=True)), + ) + await _process_instance(session, worker, instance) + + await session.refresh(instance) + await session.refresh(job) + + assert instance.status == InstanceStatus.BUSY + assert instance.termination_deadline is None + assert job.instance == instance + + async def test_check_shim_start_termination_deadline( + self, + test_db, + session: AsyncSession, + fetcher: InstanceFetcher, + worker: InstanceWorker, + monkeypatch: pytest.MonkeyPatch, + ): + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.IDLE, + unreachable=False, + ) + + monkeypatch.setattr( + instances_pipeline, + "_check_instance_inner", + Mock(return_value=InstanceCheck(reachable=False, message="SSH connection fail")), + ) + await _process_instance(session, worker, instance) + + await session.refresh(instance) + + assert instance.status == InstanceStatus.IDLE + assert instance.unreachable is True + assert instance.termination_deadline is not None + assert instance.termination_deadline.replace( + tzinfo=dt.timezone.utc + ) > get_current_datetime() + dt.timedelta(minutes=19) + + async def test_check_shim_does_not_start_termination_deadline_with_ssh_instance( + self, + test_db, + session: AsyncSession, + fetcher: InstanceFetcher, + worker: InstanceWorker, + monkeypatch: pytest.MonkeyPatch, + ): + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.IDLE, + unreachable=False, + remote_connection_info=get_remote_connection_info(), + ) + + monkeypatch.setattr( + instances_pipeline, + "_check_instance_inner", + Mock(return_value=InstanceCheck(reachable=False, message="SSH connection fail")), + ) + await _process_instance(session, worker, instance) + + await session.refresh(instance) + + assert instance.status == InstanceStatus.IDLE + assert instance.unreachable is True + 
assert instance.termination_deadline is None + + async def test_check_shim_stop_termination_deadline( + self, + test_db, + session: AsyncSession, + fetcher: InstanceFetcher, + worker: InstanceWorker, + monkeypatch: pytest.MonkeyPatch, + ): + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.IDLE, + ) + instance.termination_deadline = get_current_datetime() + dt.timedelta(minutes=19) + await session.commit() + + monkeypatch.setattr( + instances_pipeline, + "_check_instance_inner", + Mock(return_value=InstanceCheck(reachable=True)), + ) + await _process_instance(session, worker, instance) + + await session.refresh(instance) + + assert instance.status == InstanceStatus.IDLE + assert instance.termination_deadline is None + + async def test_check_shim_terminate_instance_by_deadline( + self, + test_db, + session: AsyncSession, + fetcher: InstanceFetcher, + worker: InstanceWorker, + monkeypatch: pytest.MonkeyPatch, + ): + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.IDLE, + ) + termination_deadline_time = get_current_datetime() + dt.timedelta(minutes=-19) + instance.termination_deadline = termination_deadline_time + await session.commit() + + monkeypatch.setattr( + instances_pipeline, + "_check_instance_inner", + Mock(return_value=InstanceCheck(reachable=False, message="Not ok")), + ) + await _process_instance(session, worker, instance) + + await session.refresh(instance) + + assert instance.status == InstanceStatus.TERMINATING + assert instance.termination_deadline == termination_deadline_time + assert instance.termination_reason == InstanceTerminationReason.UNREACHABLE + + @pytest.mark.parametrize( + ["termination_policy", "has_job"], + [ + pytest.param(TerminationPolicy.DESTROY_AFTER_IDLE, False, id="destroy-no-job"), + pytest.param(TerminationPolicy.DESTROY_AFTER_IDLE, True, 
id="destroy-with-job"), + pytest.param(TerminationPolicy.DONT_DESTROY, False, id="dont-destroy-no-job"), + pytest.param(TerminationPolicy.DONT_DESTROY, True, id="dont-destroy-with-job"), + ], + ) + async def test_check_shim_process_unreachable_state( + self, + test_db, + session: AsyncSession, + fetcher: InstanceFetcher, + worker: InstanceWorker, + monkeypatch: pytest.MonkeyPatch, + termination_policy: TerminationPolicy, + has_job: bool, + ): + project = await create_project(session=session) + if has_job: + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run(session=session, project=project, repo=repo, user=user) + job = await create_job( + session=session, + run=run, + status=JobStatus.SUBMITTED, + ) + else: + job = None + instance = await create_instance( + session=session, + project=project, + created_at=get_current_datetime(), + termination_policy=termination_policy, + status=InstanceStatus.IDLE, + unreachable=True, + job=job, + ) + + monkeypatch.setattr( + instances_pipeline, + "_check_instance_inner", + Mock(return_value=InstanceCheck(reachable=True)), + ) + await _process_instance(session, worker, instance) + + await session.refresh(instance) + events = await list_events(session) + + assert instance.status == InstanceStatus.IDLE + assert instance.unreachable is False + assert len(events) == 1 + assert events[0].message == "Instance became reachable" + + @pytest.mark.parametrize("health_status", [HealthStatus.HEALTHY, HealthStatus.FAILURE]) + async def test_check_shim_switch_to_unreachable_state( + self, + test_db, + session: AsyncSession, + fetcher: InstanceFetcher, + worker: InstanceWorker, + monkeypatch: pytest.MonkeyPatch, + health_status: HealthStatus, + ): + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.IDLE, + unreachable=False, + health_status=health_status, + ) + + 
monkeypatch.setattr( + instances_pipeline, + "_check_instance_inner", + Mock(return_value=InstanceCheck(reachable=False)), + ) + await _process_instance(session, worker, instance) + + await session.refresh(instance) + events = await list_events(session) + + assert instance.status == InstanceStatus.IDLE + assert instance.unreachable is True + assert instance.health == health_status + assert len(events) == 1 + assert events[0].message == "Instance became unreachable" + + async def test_check_shim_check_instance_health( + self, + test_db, + session: AsyncSession, + fetcher: InstanceFetcher, + worker: InstanceWorker, + monkeypatch: pytest.MonkeyPatch, + ): + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.IDLE, + unreachable=False, + health_status=HealthStatus.HEALTHY, + ) + health_response = InstanceHealthResponse( + dcgm=DCGMHealthResponse( + overall_health=DCGMHealthResult.DCGM_HEALTH_RESULT_WARN, + incidents=[], + ) + ) + + monkeypatch.setattr( + instances_pipeline, + "_check_instance_inner", + Mock( + return_value=InstanceCheck( + reachable=True, + health_response=health_response, + ) + ), + ) + await _process_instance(session, worker, instance) + + await session.refresh(instance) + events = await list_events(session) + + assert instance.status == InstanceStatus.IDLE + assert instance.unreachable is False + assert instance.health == HealthStatus.WARNING + assert len(events) == 1 + assert events[0].message == "Instance health changed HEALTHY -> WARNING" + + res = await session.execute(select(InstanceHealthCheckModel)) + health_check = res.scalars().one() + assert health_check.status == HealthStatus.WARNING + assert health_check.response == health_response.json() + + async def test_terminate_by_idle_timeout( + self, + test_db, + session: AsyncSession, + fetcher: InstanceFetcher, + worker: InstanceWorker, + ): + project = await create_project(session=session) + instance 
= await create_instance( + session=session, + project=project, + status=InstanceStatus.IDLE, + ) + instance.termination_idle_time = 300 + instance.termination_policy = TerminationPolicy.DESTROY_AFTER_IDLE + instance.last_job_processed_at = get_current_datetime() + dt.timedelta(minutes=-19) + await session.commit() + + await _process_instance(session, worker, instance) + await session.refresh(instance) + + assert instance.status == InstanceStatus.TERMINATING + assert instance.termination_reason == InstanceTerminationReason.IDLE_TIMEOUT + + async def test_pending_ssh_instance_terminates_on_provision_timeout( + self, + test_db, + session: AsyncSession, + fetcher: InstanceFetcher, + worker: InstanceWorker, + ): + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.PENDING, + created_at=get_current_datetime() - dt.timedelta(days=100), + remote_connection_info=get_remote_connection_info(), + ) + await session.commit() + + await _process_instance(session, worker, instance) + + await session.refresh(instance) + assert instance.status == InstanceStatus.TERMINATED + assert instance.termination_reason == InstanceTerminationReason.PROVISIONING_TIMEOUT + + async def test_terminate( + self, + test_db, + session: AsyncSession, + fetcher: InstanceFetcher, + worker: InstanceWorker, + ): + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.TERMINATING, + ) + instance.termination_reason = InstanceTerminationReason.IDLE_TIMEOUT + instance.last_job_processed_at = get_current_datetime() + dt.timedelta(minutes=-19) + await session.commit() + + with self.mock_terminate_in_backend() as mock: + await _process_instance(session, worker, instance) + mock.assert_called_once() + + await session.refresh(instance) + + assert instance.status == InstanceStatus.TERMINATED + assert instance.termination_reason == 
InstanceTerminationReason.IDLE_TIMEOUT + assert instance.deleted is True + assert instance.deleted_at is not None + assert instance.finished_at is not None + + async def test_terminates_terminating_deleted_instance( + self, + test_db, + session: AsyncSession, + fetcher: InstanceFetcher, + worker: InstanceWorker, + ): + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.TERMINATING, + ) + _lock_instance(instance) + await session.commit() + item = _instance_to_pipeline_item(instance) + instance.deleted = True + instance.termination_reason = InstanceTerminationReason.IDLE_TIMEOUT + instance.last_job_processed_at = instance.deleted_at = ( + get_current_datetime() + dt.timedelta(minutes=-19) + ) + await session.commit() + + with self.mock_terminate_in_backend() as mock: + await worker.process(item) + mock.assert_called_once() + + await session.refresh(instance) + + assert instance.status == InstanceStatus.TERMINATED + assert instance.deleted is True + assert instance.deleted_at is not None + assert instance.finished_at is not None + + @pytest.mark.parametrize( + "error", [BackendError("err"), RuntimeError("err"), NotYetTerminated("")] + ) + async def test_terminate_retry( + self, + test_db, + session: AsyncSession, + fetcher: InstanceFetcher, + worker: InstanceWorker, + error: Exception, + ): + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.TERMINATING, + ) + instance.termination_reason = InstanceTerminationReason.IDLE_TIMEOUT + initial_time = dt.datetime(2025, 1, 1, tzinfo=dt.timezone.utc) + instance.last_job_processed_at = initial_time + instance.last_processed_at = initial_time - dt.timedelta(minutes=1) + await session.commit() + + with ( + freeze_time(initial_time + dt.timedelta(minutes=1)), + self.mock_terminate_in_backend(error=error) as mock, + ): + await 
_process_instance(session, worker, instance) + mock.assert_called_once() + await session.refresh(instance) + assert instance.status == InstanceStatus.TERMINATING + + with ( + freeze_time(initial_time + dt.timedelta(minutes=2)), + self.mock_terminate_in_backend(error=None) as mock, + ): + await _process_instance(session, worker, instance) + mock.assert_called_once() + await session.refresh(instance) + assert instance.status == InstanceStatus.TERMINATED + + async def test_terminate_not_retries_if_too_early( + self, + test_db, + session: AsyncSession, + fetcher: InstanceFetcher, + worker: InstanceWorker, + ): + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.TERMINATING, + ) + instance.termination_reason = InstanceTerminationReason.IDLE_TIMEOUT + initial_time = dt.datetime(2025, 1, 1, tzinfo=dt.timezone.utc) + instance.last_job_processed_at = initial_time + instance.last_processed_at = initial_time - dt.timedelta(minutes=1) + await session.commit() + + with ( + freeze_time(initial_time + dt.timedelta(minutes=1)), + self.mock_terminate_in_backend(error=BackendError("err")) as mock, + ): + await _process_instance(session, worker, instance) + mock.assert_called_once() + await session.refresh(instance) + assert instance.status == InstanceStatus.TERMINATING + + instance.last_processed_at = initial_time + await session.commit() + + with ( + freeze_time(initial_time + dt.timedelta(minutes=1, seconds=11)), + self.mock_terminate_in_backend(error=None) as mock, + ): + await _process_instance(session, worker, instance) + mock.assert_not_called() + await session.refresh(instance) + assert instance.status == InstanceStatus.TERMINATING + + async def test_terminate_on_termination_deadline( + self, + test_db, + session: AsyncSession, + fetcher: InstanceFetcher, + worker: InstanceWorker, + ): + project = await create_project(session=session) + instance = await create_instance( + 
session=session, + project=project, + status=InstanceStatus.TERMINATING, + ) + instance.termination_reason = InstanceTerminationReason.IDLE_TIMEOUT + initial_time = dt.datetime(2025, 1, 1, tzinfo=dt.timezone.utc) + instance.last_job_processed_at = initial_time + instance.last_processed_at = initial_time - dt.timedelta(minutes=1) + await session.commit() + + with ( + freeze_time(initial_time + dt.timedelta(minutes=1)), + self.mock_terminate_in_backend(error=BackendError("err")) as mock, + ): + await _process_instance(session, worker, instance) + mock.assert_called_once() + await session.refresh(instance) + assert instance.status == InstanceStatus.TERMINATING + + with ( + freeze_time(initial_time + dt.timedelta(minutes=15, seconds=55)), + self.mock_terminate_in_backend(error=None) as mock, + ): + await _process_instance(session, worker, instance) + mock.assert_called_once() + await session.refresh(instance) + assert instance.status == InstanceStatus.TERMINATED + + @pytest.mark.parametrize( + ["cpus", "gpus", "requested_blocks", "expected_blocks"], + [ + pytest.param(32, 8, 1, 1, id="gpu-instance-no-blocks"), + pytest.param(32, 8, 2, 2, id="gpu-instance-four-gpu-per-block"), + pytest.param(32, 8, 4, 4, id="gpu-instance-two-gpus-per-block"), + pytest.param(32, 8, None, 8, id="gpu-instance-auto-max-gpu"), + pytest.param(4, 8, None, 4, id="gpu-instance-auto-max-cpu"), + pytest.param(8, 8, None, 8, id="gpu-instance-auto-max-cpu-and-gpu"), + pytest.param(32, 0, 1, 1, id="cpu-instance-no-blocks"), + pytest.param(32, 0, 2, 2, id="cpu-instance-four-cpu-per-block"), + pytest.param(32, 0, 4, 4, id="cpu-instance-two-cpus-per-block"), + pytest.param(32, 0, None, 32, id="cpu-instance-auto-max-cpu"), + ], + ) + async def test_creates_instance( + self, + test_db, + session: AsyncSession, + fetcher: InstanceFetcher, + worker: InstanceWorker, + cpus: int, + gpus: int, + requested_blocks: Optional[int], + expected_blocks: int, + ): + project = await create_project(session=session) + 
instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.PENDING, + total_blocks=requested_blocks, + busy_blocks=0, + ) + with patch("dstack._internal.server.services.backends.get_project_backends") as m: + backend_mock = Mock() + m.return_value = [backend_mock] + backend_mock.TYPE = BackendType.AWS + gpu = Gpu(name="T4", memory_mib=16384, vendor=gpuhunt.AcceleratorVendor.NVIDIA) + offer = InstanceOfferWithAvailability( + backend=BackendType.AWS, + instance=InstanceType( + name="instance", + resources=Resources( + cpus=cpus, + memory_mib=131072, + spot=False, + gpus=[gpu] * gpus, + ), + ), + region="us", + price=1.0, + availability=InstanceAvailability.AVAILABLE, + total_blocks=expected_blocks, + ) + backend_mock.compute.return_value = Mock(spec=ComputeMockSpec) + backend_mock.compute.return_value.get_offers.return_value = [offer] + backend_mock.compute.return_value.create_instance.return_value = JobProvisioningData( + backend=offer.backend, + instance_type=offer.instance, + instance_id="instance_id", + hostname="1.1.1.1", + internal_ip=None, + region=offer.region, + price=offer.price, + username="ubuntu", + ssh_port=22, + ssh_proxy=None, + dockerized=True, + backend_data=None, + ) + + await _process_instance(session, worker, instance) + + await session.refresh(instance) + assert instance.status == InstanceStatus.PROVISIONING + assert instance.total_blocks == expected_blocks + assert instance.busy_blocks == 0 + + @pytest.mark.parametrize("err", [RuntimeError("Unexpected"), ProvisioningError("Expected")]) + async def test_tries_second_offer_if_first_fails( + self, + test_db, + session: AsyncSession, + fetcher: InstanceFetcher, + worker: InstanceWorker, + err: Exception, + ): + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.PENDING, + ) + aws_mock = Mock() + aws_mock.TYPE = BackendType.AWS + offer = 
get_instance_offer_with_availability(backend=BackendType.AWS, price=1.0) + aws_mock.compute.return_value = Mock(spec=ComputeMockSpec) + aws_mock.compute.return_value.get_offers.return_value = [offer] + aws_mock.compute.return_value.create_instance.side_effect = err + gcp_mock = Mock() + gcp_mock.TYPE = BackendType.GCP + offer = get_instance_offer_with_availability(backend=BackendType.GCP, price=2.0) + gcp_mock.compute.return_value = Mock(spec=ComputeMockSpec) + gcp_mock.compute.return_value.get_offers.return_value = [offer] + gcp_mock.compute.return_value.create_instance.return_value = get_job_provisioning_data( + backend=offer.backend, + region=offer.region, + price=offer.price, + ) + with patch("dstack._internal.server.services.backends.get_project_backends") as m: + m.return_value = [aws_mock, gcp_mock] + await _process_instance(session, worker, instance) + + await session.refresh(instance) + assert instance.status == InstanceStatus.PROVISIONING + aws_mock.compute.return_value.create_instance.assert_called_once() + assert instance.backend == BackendType.GCP + + @pytest.mark.parametrize("err", [RuntimeError("Unexpected"), ProvisioningError("Expected")]) + async def test_fails_if_all_offers_fail( + self, + test_db, + session: AsyncSession, + fetcher: InstanceFetcher, + worker: InstanceWorker, + err: Exception, + ): + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.PENDING, + ) + aws_mock = Mock() + aws_mock.TYPE = BackendType.AWS + offer = get_instance_offer_with_availability(backend=BackendType.AWS, price=1.0) + aws_mock.compute.return_value = Mock(spec=ComputeMockSpec) + aws_mock.compute.return_value.get_offers.return_value = [offer] + aws_mock.compute.return_value.create_instance.side_effect = err + with patch("dstack._internal.server.services.backends.get_project_backends") as m: + m.return_value = [aws_mock] + await _process_instance(session, worker, instance) + 
+ await session.refresh(instance) + assert instance.status == InstanceStatus.TERMINATED + assert instance.termination_reason == InstanceTerminationReason.NO_OFFERS + + async def test_fails_if_no_offers( + self, + test_db, + session: AsyncSession, + fetcher: InstanceFetcher, + worker: InstanceWorker, + ): + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.PENDING, + ) + with patch("dstack._internal.server.services.backends.get_project_backends") as m: + m.return_value = [] + await _process_instance(session, worker, instance) + + await session.refresh(instance) + assert instance.status == InstanceStatus.TERMINATED + assert instance.termination_reason == InstanceTerminationReason.NO_OFFERS + + @pytest.mark.parametrize( + ("placement", "expected_termination_reasons"), + [ + pytest.param( + InstanceGroupPlacement.CLUSTER, + { + InstanceTerminationReason.NO_OFFERS: 1, + InstanceTerminationReason.MASTER_FAILED: 3, + }, + id="cluster", + ), + pytest.param( + None, + {InstanceTerminationReason.NO_OFFERS: 4}, + id="non-cluster", + ), + ], + ) + async def test_terminates_cluster_instances_if_master_not_created( + self, + test_db, + session: AsyncSession, + fetcher: InstanceFetcher, + worker: InstanceWorker, + placement: Optional[InstanceGroupPlacement], + expected_termination_reasons: dict[str, int], + ): + project = await create_project(session=session) + fleet = await create_fleet( + session, + project, + spec=get_fleet_spec(conf=get_fleet_configuration(placement=placement, nodes=4)), + ) + instances = [ + await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.PENDING, + offer=None, + job_provisioning_data=None, + instance_num=index, + created_at=get_current_datetime() + dt.timedelta(seconds=index), + ) + for index in range(4) + ] + with patch("dstack._internal.server.services.backends.get_project_backends") as m: + m.return_value = 
[] + for instance in sorted(instances, key=lambda i: (i.instance_num, i.created_at)): + await _process_instance(session, worker, instance) + + termination_reasons = defaultdict(int) + for instance in instances: + await session.refresh(instance) + assert instance.status == InstanceStatus.TERMINATED + termination_reasons[instance.termination_reason] += 1 + assert termination_reasons == expected_termination_reasons + + @pytest.mark.parametrize( + ("placement", "should_create"), + [ + pytest.param(InstanceGroupPlacement.CLUSTER, True, id="placement-cluster"), + pytest.param(None, False, id="no-placement"), + ], + ) + async def test_create_placement_group_if_placement_cluster( + self, + test_db, + session: AsyncSession, + fetcher: InstanceFetcher, + worker: InstanceWorker, + placement: Optional[InstanceGroupPlacement], + should_create: bool, + ) -> None: + project = await create_project(session=session) + fleet = await create_fleet( + session, + project, + spec=get_fleet_spec(conf=get_fleet_configuration(placement=placement, nodes=1)), + ) + instance = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.PENDING, + offer=None, + job_provisioning_data=None, + ) + backend_mock = Mock() + backend_mock.TYPE = BackendType.AWS + backend_mock.compute.return_value = Mock(spec=ComputeMockSpec) + backend_mock.compute.return_value.get_offers.return_value = [ + get_instance_offer_with_availability() + ] + backend_mock.compute.return_value.create_instance.return_value = ( + get_job_provisioning_data() + ) + backend_mock.compute.return_value.create_placement_group.return_value = ( + get_placement_group_provisioning_data() + ) + with patch("dstack._internal.server.services.backends.get_project_backends") as m: + m.return_value = [backend_mock] + await _process_instance(session, worker, instance) + + await session.refresh(instance) + assert instance.status == InstanceStatus.PROVISIONING + placement_groups = (await 
session.execute(select(PlacementGroupModel))).scalars().all() + if should_create: + assert backend_mock.compute.return_value.create_placement_group.call_count == 1 + assert len(placement_groups) == 1 + else: + assert backend_mock.compute.return_value.create_placement_group.call_count == 0 + assert len(placement_groups) == 0 + + @pytest.mark.parametrize("can_reuse", [True, False]) + async def test_reuses_placement_group_between_offers_if_the_group_is_suitable( + self, + test_db, + session: AsyncSession, + fetcher: InstanceFetcher, + worker: InstanceWorker, + can_reuse: bool, + ) -> None: + project = await create_project(session=session) + fleet = await create_fleet( + session, + project, + spec=get_fleet_spec( + conf=get_fleet_configuration(placement=InstanceGroupPlacement.CLUSTER, nodes=1) + ), + ) + instance = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.PENDING, + offer=None, + job_provisioning_data=None, + ) + backend_mock = Mock() + backend_mock.TYPE = BackendType.AWS + backend_mock.compute.return_value = Mock(spec=ComputeMockSpec) + backend_mock.compute.return_value.get_offers.return_value = [ + get_instance_offer_with_availability(instance_type="bad-offer-1"), + get_instance_offer_with_availability(instance_type="bad-offer-2"), + get_instance_offer_with_availability(instance_type="good-offer"), + ] + + def create_instance_method( + instance_offer: InstanceOfferWithAvailability, *args, **kwargs + ) -> JobProvisioningData: + if instance_offer.instance.name == "good-offer": + return get_job_provisioning_data() + raise NoCapacityError() + + backend_mock.compute.return_value.create_instance = create_instance_method + backend_mock.compute.return_value.create_placement_group.return_value = ( + get_placement_group_provisioning_data() + ) + backend_mock.compute.return_value.is_suitable_placement_group.return_value = can_reuse + with patch("dstack._internal.server.services.backends.get_project_backends") as m: + 
m.return_value = [backend_mock] + await _process_instance(session, worker, instance) + + await session.refresh(instance) + assert instance.status == InstanceStatus.PROVISIONING + placement_groups = (await session.execute(select(PlacementGroupModel))).scalars().all() + if can_reuse: + assert backend_mock.compute.return_value.create_placement_group.call_count == 1 + assert len(placement_groups) == 1 + else: + assert backend_mock.compute.return_value.create_placement_group.call_count == 3 + assert len(placement_groups) == 3 + to_be_deleted_count = sum(pg.fleet_deleted for pg in placement_groups) + assert to_be_deleted_count == 2 + + @pytest.mark.parametrize("err", [NoCapacityError(), RuntimeError()]) + async def test_handles_create_placement_group_errors( + self, + test_db, + session: AsyncSession, + fetcher: InstanceFetcher, + worker: InstanceWorker, + err: Exception, + ) -> None: + project = await create_project(session=session) + fleet = await create_fleet( + session, + project, + spec=get_fleet_spec( + conf=get_fleet_configuration(placement=InstanceGroupPlacement.CLUSTER, nodes=1) + ), + ) + instance = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.PENDING, + offer=None, + job_provisioning_data=None, + ) + backend_mock = Mock() + backend_mock.TYPE = BackendType.AWS + backend_mock.compute.return_value = Mock(spec=ComputeMockSpec) + backend_mock.compute.return_value.get_offers.return_value = [ + get_instance_offer_with_availability(instance_type="bad-offer"), + get_instance_offer_with_availability(instance_type="good-offer"), + ] + backend_mock.compute.return_value.create_instance.return_value = ( + get_job_provisioning_data() + ) + + def create_placement_group_method( + placement_group: PlacementGroup, master_instance_offer: InstanceOffer + ) -> PlacementGroupProvisioningData: + if master_instance_offer.instance.name == "good-offer": + return get_placement_group_provisioning_data() + raise err + + 
backend_mock.compute.return_value.create_placement_group = create_placement_group_method + with patch("dstack._internal.server.services.backends.get_project_backends") as m: + m.return_value = [backend_mock] + await _process_instance(session, worker, instance) + + await session.refresh(instance) + assert instance.status == InstanceStatus.PROVISIONING + assert instance.offer + assert "good-offer" in instance.offer + assert "bad-offer" not in instance.offer + placement_groups = (await session.execute(select(PlacementGroupModel))).scalars().all() + assert len(placement_groups) == 1 + + @pytest.mark.parametrize( + ["cpus", "gpus", "requested_blocks", "expected_blocks"], + [ + pytest.param(32, 8, 1, 1, id="gpu-instance-no-blocks"), + pytest.param(32, 8, 2, 2, id="gpu-instance-four-gpu-per-block"), + pytest.param(32, 8, 4, 4, id="gpu-instance-two-gpus-per-block"), + pytest.param(32, 8, None, 8, id="gpu-instance-auto-max-gpu"), + pytest.param(4, 8, None, 4, id="gpu-instance-auto-max-cpu"), + pytest.param(8, 8, None, 8, id="gpu-instance-auto-max-cpu-and-gpu"), + pytest.param(32, 0, 1, 1, id="cpu-instance-no-blocks"), + pytest.param(32, 0, 2, 2, id="cpu-instance-four-cpu-per-block"), + pytest.param(32, 0, 4, 4, id="cpu-instance-two-cpus-per-block"), + pytest.param(32, 0, None, 32, id="cpu-instance-auto-max-cpu"), + ], + ) + async def test_adds_ssh_instance( + self, + test_db, + session: AsyncSession, + fetcher: InstanceFetcher, + worker: InstanceWorker, + host_info: dict, + deploy_instance_mock: Mock, + cpus: int, + gpus: int, + requested_blocks: Optional[int], + expected_blocks: int, + ): + host_info["cpus"] = cpus + host_info["gpu_count"] = gpus + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.PENDING, + created_at=get_current_datetime(), + remote_connection_info=get_remote_connection_info(), + total_blocks=requested_blocks, + busy_blocks=0, + ) + await session.commit() + + 
await _process_instance(session, worker, instance) + + await session.refresh(instance) + assert instance.status == InstanceStatus.IDLE + assert instance.total_blocks == expected_blocks + assert instance.busy_blocks == 0 + deploy_instance_mock.assert_called_once() + + async def test_retries_ssh_instance_if_provisioning_fails( + self, + test_db, + session: AsyncSession, + fetcher: InstanceFetcher, + worker: InstanceWorker, + deploy_instance_mock: Mock, + ): + deploy_instance_mock.side_effect = SSHProvisioningError("Expected") + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.PENDING, + created_at=get_current_datetime(), + remote_connection_info=get_remote_connection_info(), + ) + await session.commit() + + await _process_instance(session, worker, instance) + + await session.refresh(instance) + assert instance.status == InstanceStatus.PENDING + assert instance.termination_reason is None + + async def test_terminates_ssh_instance_if_deploy_fails_unexpectedly( + self, + test_db, + session: AsyncSession, + fetcher: InstanceFetcher, + worker: InstanceWorker, + deploy_instance_mock: Mock, + ): + deploy_instance_mock.side_effect = RuntimeError("Unexpected") + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.PENDING, + created_at=get_current_datetime(), + remote_connection_info=get_remote_connection_info(), + ) + await session.commit() + + await _process_instance(session, worker, instance) + + await session.refresh(instance) + assert instance.status == InstanceStatus.TERMINATED + assert instance.termination_reason == InstanceTerminationReason.ERROR + assert instance.termination_reason_message == "Unexpected error when adding SSH instance" + + async def test_terminates_ssh_instance_if_key_is_invalid( + self, + test_db, + session: AsyncSession, + fetcher: InstanceFetcher, + worker: 
InstanceWorker, + monkeypatch: pytest.MonkeyPatch, + ): + monkeypatch.setattr( + instances_pipeline, + "_ssh_keys_to_pkeys", + Mock(side_effect=ValueError("Bad key")), + ) + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.PENDING, + created_at=get_current_datetime(), + remote_connection_info=get_remote_connection_info(), + ) + await session.commit() + + await _process_instance(session, worker, instance) + + await session.refresh(instance) + assert instance.status == InstanceStatus.TERMINATED + assert instance.termination_reason == InstanceTerminationReason.ERROR + assert instance.termination_reason_message == "Unsupported private SSH key type" + + async def test_terminates_ssh_instance_if_internal_ip_cannot_be_resolved_from_network( + self, + test_db, + session: AsyncSession, + fetcher: InstanceFetcher, + worker: InstanceWorker, + host_info: dict, + deploy_instance_mock: Mock, + ): + host_info["addresses"] = ["192.168.100.100/24"] + project = await create_project(session=session) + job_provisioning_data = get_job_provisioning_data( + dockerized=True, + backend=BackendType.REMOTE, + internal_ip=None, + ) + job_provisioning_data.instance_network = "10.0.0.0/24" + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.PENDING, + created_at=get_current_datetime(), + remote_connection_info=get_remote_connection_info(), + job_provisioning_data=job_provisioning_data, + ) + await session.commit() + + await _process_instance(session, worker, instance) + + await session.refresh(instance) + assert instance.status == InstanceStatus.TERMINATED + assert instance.termination_reason == InstanceTerminationReason.ERROR + assert ( + instance.termination_reason_message + == "Failed to locate internal IP address on the given network" + ) + + async def test_terminates_ssh_instance_if_internal_ip_is_not_in_host_interfaces( + self, + test_db, + 
session: AsyncSession, + fetcher: InstanceFetcher, + worker: InstanceWorker, + host_info: dict, + deploy_instance_mock: Mock, + ): + host_info["addresses"] = ["192.168.100.100/24"] + project = await create_project(session=session) + job_provisioning_data = get_job_provisioning_data( + dockerized=True, + backend=BackendType.REMOTE, + internal_ip="10.0.0.20", + ) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.PENDING, + created_at=get_current_datetime(), + remote_connection_info=get_remote_connection_info(), + job_provisioning_data=job_provisioning_data, + ) + await session.commit() + + await _process_instance(session, worker, instance) + + await session.refresh(instance) + assert instance.status == InstanceStatus.TERMINATED + assert instance.termination_reason == InstanceTerminationReason.ERROR + assert ( + instance.termination_reason_message + == "Specified internal IP not found among instance interfaces" + ) + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("turn_off_keep_shim_tasks_setting") +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) +class TestRemoveDanglingTasks: + @pytest.fixture + def turn_off_keep_shim_tasks_setting(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("dstack._internal.server.settings.SERVER_KEEP_SHIM_TASKS", False) + + async def test_terminates_and_removes_dangling_tasks(self, test_db, session: AsyncSession): + user = await create_user(session=session) + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.BUSY, + ) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + ) + job = await create_job( + session=session, + run=run, + status=JobStatus.RUNNING, + instance=instance, + ) + dangling_task_id_1 = "fe138b77-d0b1-49d3-8c9f-2dfe78ece727" + 
dangling_task_id_2 = "8b016a75-41de-44f1-91ff-c9b63d2caa1d" + shim_client_mock = Mock(spec_set=ShimClient) + shim_client_mock.is_api_v2_supported.return_value = True + shim_client_mock.list_tasks.return_value = TaskListResponse( + tasks=[ + TaskListItem(id=str(job.id), status=TaskStatus.RUNNING), + TaskListItem(id=dangling_task_id_1, status=TaskStatus.RUNNING), + TaskListItem(id=dangling_task_id_2, status=TaskStatus.TERMINATED), + ] + ) + await session.refresh(instance, attribute_names=["jobs"]) + + instances_pipeline.remove_dangling_tasks_from_instance(shim_client_mock, instance) + + await session.refresh(instance) + assert instance.status == InstanceStatus.BUSY + + shim_client_mock.terminate_task.assert_called_once_with( + task_id=dangling_task_id_1, + reason=None, + message=None, + timeout=0, + ) + assert shim_client_mock.remove_task.call_count == 2 + shim_client_mock.remove_task.assert_has_calls( + [call(task_id=dangling_task_id_1), call(task_id=dangling_task_id_2)] + ) + + async def test_terminates_and_removes_dangling_tasks_legacy_shim( + self, test_db, session: AsyncSession + ): + user = await create_user(session=session) + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.BUSY, + ) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + ) + job = await create_job( + session=session, + run=run, + status=JobStatus.RUNNING, + instance=instance, + ) + dangling_task_id_1 = "fe138b77-d0b1-49d3-8c9f-2dfe78ece727" + dangling_task_id_2 = "8b016a75-41de-44f1-91ff-c9b63d2caa1d" + shim_client_mock = Mock(spec_set=ShimClient) + shim_client_mock.is_api_v2_supported.return_value = True + shim_client_mock.list_tasks.return_value = TaskListResponse( + ids=[str(job.id), dangling_task_id_1, dangling_task_id_2] + ) + await session.refresh(instance, attribute_names=["jobs"]) + + 
instances_pipeline.remove_dangling_tasks_from_instance(shim_client_mock, instance) + + await session.refresh(instance) + assert instance.status == InstanceStatus.BUSY + + assert shim_client_mock.terminate_task.call_count == 2 + shim_client_mock.terminate_task.assert_has_calls( + [ + call(task_id=dangling_task_id_1, reason=None, message=None, timeout=0), + call(task_id=dangling_task_id_2, reason=None, message=None, timeout=0), + ] + ) + assert shim_client_mock.remove_task.call_count == 2 + shim_client_mock.remove_task.assert_has_calls( + [call(task_id=dangling_task_id_1), call(task_id=dangling_task_id_2)] + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) +class BaseTestMaybeInstallComponents: + EXPECTED_VERSION = "0.20.1" + + @pytest_asyncio.fixture + async def instance(self, session: AsyncSession) -> InstanceModel: + project = await create_project(session=session) + return await create_instance( + session=session, + project=project, + status=InstanceStatus.BUSY, + ) + + @pytest.fixture + def component_list(self) -> ComponentList: + return ComponentList() + + @pytest.fixture + def debug_task_log(self, caplog: pytest.LogCaptureFixture) -> pytest.LogCaptureFixture: + caplog.set_level(level=logging.DEBUG, logger=instances_pipeline.__name__) + return caplog + + @pytest.fixture + def shim_client_mock( + self, + monkeypatch: pytest.MonkeyPatch, + component_list: ComponentList, + ) -> Mock: + mock = Mock(spec_set=ShimClient) + mock.healthcheck.return_value = HealthcheckResponse( + service="dstack-shim", + version=self.EXPECTED_VERSION, + ) + mock.get_instance_health.return_value = InstanceHealthResponse() + mock.get_components.return_value = component_list + mock.list_tasks.return_value = TaskListResponse(tasks=[]) + mock.is_safe_to_restart.return_value = False + monkeypatch.setattr( + "dstack._internal.server.services.runner.client.ShimClient", + Mock(return_value=mock), + ) + return mock + + 
+@pytest.mark.usefixtures("get_dstack_runner_version_mock") +class TestMaybeInstallRunner(BaseTestMaybeInstallComponents): + @pytest.fixture + def component_list(self) -> ComponentList: + components = ComponentList() + components.add( + ComponentInfo( + name=ComponentName.RUNNER, + version=self.EXPECTED_VERSION, + status=ComponentStatus.INSTALLED, + ), + ) + return components + + @pytest.fixture + def get_dstack_runner_version_mock(self, monkeypatch: pytest.MonkeyPatch) -> Mock: + mock = Mock(return_value=self.EXPECTED_VERSION) + monkeypatch.setattr(instances_pipeline, "get_dstack_runner_version", mock) + return mock + + @pytest.fixture + def get_dstack_runner_download_url_mock(self, monkeypatch: pytest.MonkeyPatch) -> Mock: + mock = Mock(return_value="https://example.com/runner") + monkeypatch.setattr(instances_pipeline, "get_dstack_runner_download_url", mock) + return mock + + async def test_cannot_determine_expected_version( + self, + test_db, + instance: InstanceModel, + debug_task_log: pytest.LogCaptureFixture, + shim_client_mock: Mock, + get_dstack_runner_version_mock: Mock, + ): + get_dstack_runner_version_mock.return_value = None + + instances_pipeline._maybe_install_components(instance, shim_client_mock) + + assert "Cannot determine the expected runner version" in debug_task_log.text + shim_client_mock.get_components.assert_called_once() + shim_client_mock.install_runner.assert_not_called() + + async def test_expected_version_already_installed( + self, + test_db, + instance: InstanceModel, + debug_task_log: pytest.LogCaptureFixture, + shim_client_mock: Mock, + ): + shim_client_mock.get_components.return_value.runner.version = self.EXPECTED_VERSION + + instances_pipeline._maybe_install_components(instance, shim_client_mock) + + assert "expected runner version already installed" in debug_task_log.text + shim_client_mock.get_components.assert_called_once() + shim_client_mock.install_runner.assert_not_called() + + @pytest.mark.parametrize("status", 
[ComponentStatus.NOT_INSTALLED, ComponentStatus.ERROR]) + async def test_install_not_installed_or_error( + self, + test_db, + instance: InstanceModel, + debug_task_log: pytest.LogCaptureFixture, + shim_client_mock: Mock, + get_dstack_runner_download_url_mock: Mock, + status: ComponentStatus, + ): + shim_client_mock.get_components.return_value.runner.version = "" + shim_client_mock.get_components.return_value.runner.status = status + + instances_pipeline._maybe_install_components(instance, shim_client_mock) + + assert f"installing runner (no version) -> {self.EXPECTED_VERSION}" in debug_task_log.text + get_dstack_runner_download_url_mock.assert_called_once_with( + arch=None, + version=self.EXPECTED_VERSION, + ) + shim_client_mock.get_components.assert_called_once() + shim_client_mock.install_runner.assert_called_once_with( + get_dstack_runner_download_url_mock.return_value + ) + + @pytest.mark.parametrize("installed_version", ["0.19.40", "0.21.0", "dev"]) + async def test_install_installed( + self, + test_db, + instance: InstanceModel, + debug_task_log: pytest.LogCaptureFixture, + shim_client_mock: Mock, + get_dstack_runner_download_url_mock: Mock, + installed_version: str, + ): + shim_client_mock.get_components.return_value.runner.version = installed_version + + instances_pipeline._maybe_install_components(instance, shim_client_mock) + + assert ( + f"installing runner {installed_version} -> {self.EXPECTED_VERSION}" + in debug_task_log.text + ) + get_dstack_runner_download_url_mock.assert_called_once_with( + arch=None, + version=self.EXPECTED_VERSION, + ) + shim_client_mock.get_components.assert_called_once() + shim_client_mock.install_runner.assert_called_once_with( + get_dstack_runner_download_url_mock.return_value + ) + + async def test_already_installing( + self, + test_db, + instance: InstanceModel, + debug_task_log: pytest.LogCaptureFixture, + shim_client_mock: Mock, + ): + shim_client_mock.get_components.return_value.runner.version = "dev" + 
shim_client_mock.get_components.return_value.runner.status = ComponentStatus.INSTALLING + + instances_pipeline._maybe_install_components(instance, shim_client_mock) + + assert "runner is already being installed" in debug_task_log.text + shim_client_mock.get_components.assert_called_once() + shim_client_mock.install_runner.assert_not_called() + + +@pytest.mark.usefixtures("get_dstack_shim_version_mock") +class TestMaybeInstallShim(BaseTestMaybeInstallComponents): + @pytest.fixture + def component_list(self) -> ComponentList: + components = ComponentList() + components.add( + ComponentInfo( + name=ComponentName.SHIM, + version=self.EXPECTED_VERSION, + status=ComponentStatus.INSTALLED, + ), + ) + return components + + @pytest.fixture + def get_dstack_shim_version_mock(self, monkeypatch: pytest.MonkeyPatch) -> Mock: + mock = Mock(return_value=self.EXPECTED_VERSION) + monkeypatch.setattr(instances_pipeline, "get_dstack_shim_version", mock) + return mock + + @pytest.fixture + def get_dstack_shim_download_url_mock(self, monkeypatch: pytest.MonkeyPatch) -> Mock: + mock = Mock(return_value="https://example.com/shim") + monkeypatch.setattr(instances_pipeline, "get_dstack_shim_download_url", mock) + return mock + + async def test_cannot_determine_expected_version( + self, + test_db, + instance: InstanceModel, + debug_task_log: pytest.LogCaptureFixture, + shim_client_mock: Mock, + get_dstack_shim_version_mock: Mock, + ): + get_dstack_shim_version_mock.return_value = None + + instances_pipeline._maybe_install_components(instance, shim_client_mock) + + assert "Cannot determine the expected shim version" in debug_task_log.text + shim_client_mock.get_components.assert_called_once() + shim_client_mock.install_shim.assert_not_called() + + async def test_expected_version_already_installed( + self, + test_db, + instance: InstanceModel, + debug_task_log: pytest.LogCaptureFixture, + shim_client_mock: Mock, + ): + shim_client_mock.get_components.return_value.shim.version = 
self.EXPECTED_VERSION + + instances_pipeline._maybe_install_components(instance, shim_client_mock) + + assert "expected shim version already installed" in debug_task_log.text + shim_client_mock.get_components.assert_called_once() + shim_client_mock.install_shim.assert_not_called() + + @pytest.mark.parametrize("status", [ComponentStatus.NOT_INSTALLED, ComponentStatus.ERROR]) + async def test_install_not_installed_or_error( + self, + test_db, + instance: InstanceModel, + debug_task_log: pytest.LogCaptureFixture, + shim_client_mock: Mock, + get_dstack_shim_download_url_mock: Mock, + status: ComponentStatus, + ): + shim_client_mock.get_components.return_value.shim.version = "" + shim_client_mock.get_components.return_value.shim.status = status + + instances_pipeline._maybe_install_components(instance, shim_client_mock) + + assert f"installing shim (no version) -> {self.EXPECTED_VERSION}" in debug_task_log.text + get_dstack_shim_download_url_mock.assert_called_once_with( + arch=None, + version=self.EXPECTED_VERSION, + ) + shim_client_mock.get_components.assert_called_once() + shim_client_mock.install_shim.assert_called_once_with( + get_dstack_shim_download_url_mock.return_value + ) + + @pytest.mark.parametrize("installed_version", ["0.19.40", "0.21.0", "dev"]) + async def test_install_installed( + self, + test_db, + instance: InstanceModel, + debug_task_log: pytest.LogCaptureFixture, + shim_client_mock: Mock, + get_dstack_shim_download_url_mock: Mock, + installed_version: str, + ): + shim_client_mock.get_components.return_value.shim.version = installed_version + + instances_pipeline._maybe_install_components(instance, shim_client_mock) + + assert ( + f"installing shim {installed_version} -> {self.EXPECTED_VERSION}" + in debug_task_log.text + ) + get_dstack_shim_download_url_mock.assert_called_once_with( + arch=None, + version=self.EXPECTED_VERSION, + ) + shim_client_mock.get_components.assert_called_once() + shim_client_mock.install_shim.assert_called_once_with( + 
get_dstack_shim_download_url_mock.return_value + ) + + async def test_already_installing( + self, + test_db, + instance: InstanceModel, + debug_task_log: pytest.LogCaptureFixture, + shim_client_mock: Mock, + ): + shim_client_mock.get_components.return_value.shim.version = "dev" + shim_client_mock.get_components.return_value.shim.status = ComponentStatus.INSTALLING + + instances_pipeline._maybe_install_components(instance, shim_client_mock) + + assert "shim is already being installed" in debug_task_log.text + shim_client_mock.get_components.assert_called_once() + shim_client_mock.install_shim.assert_not_called() + + +@pytest.mark.usefixtures("maybe_install_runner_mock", "maybe_install_shim_mock") +class TestMaybeRestartShim(BaseTestMaybeInstallComponents): + @pytest.fixture + def component_list(self) -> ComponentList: + components = ComponentList() + components.add( + ComponentInfo( + name=ComponentName.RUNNER, + version=self.EXPECTED_VERSION, + status=ComponentStatus.INSTALLED, + ), + ) + components.add( + ComponentInfo( + name=ComponentName.SHIM, + version=self.EXPECTED_VERSION, + status=ComponentStatus.INSTALLED, + ), + ) + return components + + @pytest.fixture + def maybe_install_runner_mock(self, monkeypatch: pytest.MonkeyPatch) -> Mock: + mock = Mock(return_value=False) + monkeypatch.setattr(instances_pipeline, "_maybe_install_runner", mock) + return mock + + @pytest.fixture + def maybe_install_shim_mock(self, monkeypatch: pytest.MonkeyPatch) -> Mock: + mock = Mock(return_value=False) + monkeypatch.setattr(instances_pipeline, "_maybe_install_shim", mock) + return mock + + async def test_up_to_date(self, test_db, instance: InstanceModel, shim_client_mock: Mock): + shim_client_mock.get_version_string.return_value = self.EXPECTED_VERSION + shim_client_mock.is_safe_to_restart.return_value = True + + instances_pipeline._maybe_install_components(instance, shim_client_mock) + + shim_client_mock.get_components.assert_called_once() + 
shim_client_mock.shutdown.assert_not_called() + + async def test_no_shim_component_info( + self, test_db, instance: InstanceModel, shim_client_mock: Mock + ): + shim_client_mock.get_components.return_value = ComponentList() + shim_client_mock.get_version_string.return_value = "outdated" + shim_client_mock.is_safe_to_restart.return_value = True + + instances_pipeline._maybe_install_components(instance, shim_client_mock) + + shim_client_mock.get_components.assert_called_once() + shim_client_mock.shutdown.assert_not_called() + + async def test_outdated_shutdown_requested( + self, test_db, instance: InstanceModel, shim_client_mock: Mock + ): + shim_client_mock.get_version_string.return_value = "outdated" + shim_client_mock.is_safe_to_restart.return_value = True + + instances_pipeline._maybe_install_components(instance, shim_client_mock) + + shim_client_mock.get_components.assert_called_once() + shim_client_mock.shutdown.assert_called_once_with(force=False) + + async def test_outdated_but_task_wont_survive_restart( + self, test_db, instance: InstanceModel, shim_client_mock: Mock + ): + shim_client_mock.get_version_string.return_value = "outdated" + shim_client_mock.is_safe_to_restart.return_value = False + + instances_pipeline._maybe_install_components(instance, shim_client_mock) + + shim_client_mock.get_components.assert_called_once() + shim_client_mock.shutdown.assert_not_called() + + async def test_outdated_but_runner_installation_in_progress( + self, + test_db, + instance: InstanceModel, + shim_client_mock: Mock, + component_list: ComponentList, + ): + shim_client_mock.get_version_string.return_value = "outdated" + shim_client_mock.is_safe_to_restart.return_value = True + runner_info = component_list.runner + assert runner_info is not None + runner_info.status = ComponentStatus.INSTALLING + + instances_pipeline._maybe_install_components(instance, shim_client_mock) + + shim_client_mock.get_components.assert_called_once() + 
shim_client_mock.shutdown.assert_not_called() + + async def test_outdated_but_shim_installation_in_progress( + self, + test_db, + instance: InstanceModel, + shim_client_mock: Mock, + component_list: ComponentList, + ): + shim_client_mock.get_version_string.return_value = "outdated" + shim_client_mock.is_safe_to_restart.return_value = True + shim_info = component_list.shim + assert shim_info is not None + shim_info.status = ComponentStatus.INSTALLING + + instances_pipeline._maybe_install_components(instance, shim_client_mock) + + shim_client_mock.get_components.assert_called_once() + shim_client_mock.shutdown.assert_not_called() + + async def test_outdated_but_runner_installation_requested( + self, + test_db, + instance: InstanceModel, + shim_client_mock: Mock, + maybe_install_runner_mock: Mock, + ): + shim_client_mock.get_version_string.return_value = "outdated" + shim_client_mock.is_safe_to_restart.return_value = True + maybe_install_runner_mock.return_value = True + + instances_pipeline._maybe_install_components(instance, shim_client_mock) + + shim_client_mock.get_components.assert_called_once() + shim_client_mock.shutdown.assert_not_called() + + async def test_outdated_but_shim_installation_requested( + self, + test_db, + instance: InstanceModel, + shim_client_mock: Mock, + maybe_install_shim_mock: Mock, + ): + shim_client_mock.get_version_string.return_value = "outdated" + shim_client_mock.is_safe_to_restart.return_value = True + maybe_install_shim_mock.return_value = True + + instances_pipeline._maybe_install_components(instance, shim_client_mock) + + shim_client_mock.get_components.assert_called_once() + shim_client_mock.shutdown.assert_not_called() From a11567b662a56b6d55774f3d4defbf1cb33e008b Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Tue, 3 Mar 2026 11:00:29 +0500 Subject: [PATCH 10/51] Run pyright for pipeline tests --- pyproject.toml | 1 + .../_internal/server/services/fleets.py | 2 +- .../background/pipeline_tasks/test_fleets.py | 7 +++--- 
.../pipeline_tasks/test_instances.py | 24 +++++++++++++++---- 4 files changed, 24 insertions(+), 10 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 3336fc5423..d36f23d41c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -106,6 +106,7 @@ include = [ "src/dstack/_internal/core/backends/runpod", "src/dstack/_internal/cli/services/configurators", "src/dstack/_internal/cli/commands", + "src/tests/_internal/server/background/pipeline_tasks", ] ignore = [ "src/dstack/_internal/server/migrations/versions", diff --git a/src/dstack/_internal/server/services/fleets.py b/src/dstack/_internal/server/services/fleets.py index 00ddb0e2ed..380052d78b 100644 --- a/src/dstack/_internal/server/services/fleets.py +++ b/src/dstack/_internal/server/services/fleets.py @@ -686,7 +686,7 @@ async def delete_fleets( .order_by(InstanceModel.id) ) instances_ids = list(res.scalars().unique().all()) - await sqlite_commit() + await sqlite_commit(session) async with ( get_locker(get_db().dialect_name).lock_ctx(FleetModel.__tablename__, fleets_ids), get_locker(get_db().dialect_name).lock_ctx(InstanceModel.__tablename__, instances_ids), diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_fleets.py b/src/tests/_internal/server/background/pipeline_tasks/test_fleets.py index 90441a30ef..71807b8042 100644 --- a/src/tests/_internal/server/background/pipeline_tasks/test_fleets.py +++ b/src/tests/_internal/server/background/pipeline_tasks/test_fleets.py @@ -510,7 +510,6 @@ async def test_consolidation_attempt_resets_when_no_changes( ) assert len(instances) == 1 assert fleet.consolidation_attempt == 0 - assert ( - fleet.last_consolidated_at is not None - and fleet.last_consolidated_at > previous_last_consolidated_at - ) + last_consolidated_at = fleet.last_consolidated_at + assert last_consolidated_at + assert last_consolidated_at > previous_last_consolidated_at diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_instances.py 
b/src/tests/_internal/server/background/pipeline_tasks/test_instances.py index 8a8c3b87ce..ae2882e78d 100644 --- a/src/tests/_internal/server/background/pipeline_tasks/test_instances.py +++ b/src/tests/_internal/server/background/pipeline_tasks/test_instances.py @@ -23,7 +23,7 @@ SSHProvisioningError, ) from dstack._internal.core.models.backends.base import BackendType -from dstack._internal.core.models.fleets import InstanceGroupPlacement +from dstack._internal.core.models.fleets import FleetNodesSpec, InstanceGroupPlacement from dstack._internal.core.models.health import HealthStatus from dstack._internal.core.models.instances import ( Gpu, @@ -1183,7 +1183,11 @@ async def test_terminates_cluster_instances_if_master_not_created( fleet = await create_fleet( session, project, - spec=get_fleet_spec(conf=get_fleet_configuration(placement=placement, nodes=4)), + spec=get_fleet_spec( + conf=get_fleet_configuration( + placement=placement, nodes=FleetNodesSpec(min=4, target=4, max=4) + ) + ), ) instances = [ await create_instance( @@ -1230,7 +1234,11 @@ async def test_create_placement_group_if_placement_cluster( fleet = await create_fleet( session, project, - spec=get_fleet_spec(conf=get_fleet_configuration(placement=placement, nodes=1)), + spec=get_fleet_spec( + conf=get_fleet_configuration( + placement=placement, nodes=FleetNodesSpec(min=1, target=1, max=1) + ) + ), ) instance = await create_instance( session=session, @@ -1280,7 +1288,10 @@ async def test_reuses_placement_group_between_offers_if_the_group_is_suitable( session, project, spec=get_fleet_spec( - conf=get_fleet_configuration(placement=InstanceGroupPlacement.CLUSTER, nodes=1) + conf=get_fleet_configuration( + placement=InstanceGroupPlacement.CLUSTER, + nodes=FleetNodesSpec(min=1, target=1, max=1), + ) ), ) instance = await create_instance( @@ -1342,7 +1353,10 @@ async def test_handles_create_placement_group_errors( session, project, spec=get_fleet_spec( - 
conf=get_fleet_configuration(placement=InstanceGroupPlacement.CLUSTER, nodes=1) + conf=get_fleet_configuration( + placement=InstanceGroupPlacement.CLUSTER, + nodes=FleetNodesSpec(min=1, target=1, max=1), + ) ), ) instance = await create_instance( From d4d3147ad60a3a95096278405e5a986a259859cb Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Tue, 3 Mar 2026 11:01:09 +0500 Subject: [PATCH 11/51] WIP: InstanceWorker --- .../background/pipeline_tasks/instances.py | 1917 +++++++++++++++++ src/dstack/_internal/server/models.py | 2 +- 2 files changed, 1918 insertions(+), 1 deletion(-) create mode 100644 src/dstack/_internal/server/background/pipeline_tasks/instances.py diff --git a/src/dstack/_internal/server/background/pipeline_tasks/instances.py b/src/dstack/_internal/server/background/pipeline_tasks/instances.py new file mode 100644 index 0000000000..23f00eff16 --- /dev/null +++ b/src/dstack/_internal/server/background/pipeline_tasks/instances.py @@ -0,0 +1,1917 @@ +import asyncio +import datetime +import logging +import uuid +from dataclasses import dataclass, field +from datetime import timedelta +from typing import Any, Dict, Optional, Sequence, TypedDict, Union, cast + +import gpuhunt +import requests +from paramiko.pkey import PKey +from paramiko.ssh_exception import PasswordRequiredException +from pydantic import ValidationError +from sqlalchemy import and_, func, not_, or_, select, update +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import joinedload, load_only + +from dstack._internal import settings +from dstack._internal.core.backends.base.compute import ( + ComputeWithCreateInstanceSupport, + ComputeWithPlacementGroupSupport, + GoArchType, + generate_unique_placement_group_name, + get_dstack_runner_binary_path, + get_dstack_runner_download_url, + get_dstack_runner_version, + get_dstack_shim_binary_path, + get_dstack_shim_download_url, + get_dstack_shim_version, + get_dstack_working_dir, + get_shim_env, + 
get_shim_pre_start_commands, +) +from dstack._internal.core.backends.features import ( + BACKENDS_WITH_CREATE_INSTANCE_SUPPORT, + BACKENDS_WITH_PLACEMENT_GROUPS_SUPPORT, +) +from dstack._internal.core.consts import DSTACK_SHIM_HTTP_PORT +from dstack._internal.core.errors import ( + BackendError, + NotYetTerminated, + PlacementGroupNotSupportedError, + ProvisioningError, + SSHProvisioningError, +) +from dstack._internal.core.models.backends.base import BackendType +from dstack._internal.core.models.events import EventTargetType +from dstack._internal.core.models.fleets import InstanceGroupPlacement +from dstack._internal.core.models.health import HealthStatus +from dstack._internal.core.models.instances import ( + InstanceAvailability, + InstanceOfferWithAvailability, + InstanceRuntime, + InstanceStatus, + InstanceTerminationReason, + RemoteConnectionInfo, + SSHKey, +) +from dstack._internal.core.models.placement import ( + PlacementGroup, + PlacementGroupConfiguration, + PlacementStrategy, +) +from dstack._internal.core.models.profiles import TerminationPolicy +from dstack._internal.core.models.runs import JobProvisioningData +from dstack._internal.server import settings as server_settings +from dstack._internal.server.background.pipeline_tasks.base import ( + NOW_PLACEHOLDER, + Fetcher, + Heartbeater, + ItemUpdateMap, + Pipeline, + PipelineItem, + UpdateMapDateTime, + Worker, + resolve_now_placeholders, + set_processed_update_map_fields, + set_unlock_update_map_fields, +) +from dstack._internal.server.background.scheduled_tasks.common import get_provisioning_timeout +from dstack._internal.server.db import get_db, get_session_ctx +from dstack._internal.server.models import ( + FleetModel, + InstanceHealthCheckModel, + InstanceModel, + JobModel, + PlacementGroupModel, + ProjectModel, +) +from dstack._internal.server.schemas.instances import InstanceCheck +from dstack._internal.server.schemas.runner import ( + ComponentInfo, + ComponentStatus, + HealthcheckResponse, 
+ InstanceHealthResponse, +) +from dstack._internal.server.services import backends as backends_services +from dstack._internal.server.services import events +from dstack._internal.server.services.fleets import ( + fleet_model_to_fleet, + get_create_instance_offers, + is_cloud_cluster, +) +from dstack._internal.server.services.instances import ( + get_instance_configuration, + get_instance_profile, + get_instance_provisioning_data, + get_instance_remote_connection_info, + get_instance_requirements, + get_instance_ssh_private_keys, + is_ssh_instance, + remove_dangling_tasks_from_instance, +) +from dstack._internal.server.services.locking import get_locker +from dstack._internal.server.services.logging import fmt +from dstack._internal.server.services.offers import ( + get_instance_offer_with_restricted_az, + is_divisible_into_blocks, +) +from dstack._internal.server.services.placement import ( + placement_group_model_to_placement_group, + schedule_fleet_placement_groups_deletion, +) +from dstack._internal.server.services.runner import client as runner_client +from dstack._internal.server.services.runner.ssh import runner_ssh_tunnel +from dstack._internal.server.services.ssh_fleets.provisioning import ( + detect_cpu_arch, + get_host_info, + get_paramiko_connection, + get_shim_healthcheck, + host_info_to_instance_type, + remove_dstack_runner_if_exists, + remove_host_info_if_exists, + run_pre_start_commands, + run_shim_as_systemd_service, + upload_envs, +) +from dstack._internal.server.utils import sentry_utils +from dstack._internal.utils.common import get_current_datetime, get_or_error, run_async +from dstack._internal.utils.logging import get_logger +from dstack._internal.utils.network import get_ip_from_network, is_ip_among_addresses +from dstack._internal.utils.ssh import pkey_from_str + +logger = get_logger(__name__) + +TERMINATION_DEADLINE_OFFSET = timedelta(minutes=20) +TERMINATION_RETRY_TIMEOUT = timedelta(seconds=30) +TERMINATION_RETRY_MAX_DURATION = 
timedelta(minutes=15) +PROVISIONING_TIMEOUT_SECONDS = 10 * 60 # 10 minutes in seconds + +_UNSET = object() + + +@dataclass +class InstancePipelineItem(PipelineItem): + status: InstanceStatus + + +class InstancePipeline(Pipeline[InstancePipelineItem]): + def __init__( + self, + workers_num: int = 10, + queue_lower_limit_factor: float = 0.5, + queue_upper_limit_factor: float = 2.0, + min_processing_interval: timedelta = timedelta(seconds=10), + lock_timeout: timedelta = timedelta(seconds=30), + heartbeat_trigger: timedelta = timedelta(seconds=15), + ) -> None: + super().__init__( + workers_num=workers_num, + queue_lower_limit_factor=queue_lower_limit_factor, + queue_upper_limit_factor=queue_upper_limit_factor, + min_processing_interval=min_processing_interval, + lock_timeout=lock_timeout, + heartbeat_trigger=heartbeat_trigger, + ) + self.__heartbeater = Heartbeater[InstancePipelineItem]( + model_type=InstanceModel, + lock_timeout=self._lock_timeout, + heartbeat_trigger=self._heartbeat_trigger, + ) + self.__fetcher = InstanceFetcher( + queue=self._queue, + queue_desired_minsize=self._queue_desired_minsize, + min_processing_interval=self._min_processing_interval, + lock_timeout=self._lock_timeout, + heartbeater=self._heartbeater, + ) + self.__workers = [ + InstanceWorker(queue=self._queue, heartbeater=self._heartbeater) + for _ in range(self._workers_num) + ] + + @property + def hint_fetch_model_name(self) -> str: + return InstanceModel.__name__ + + @property + def _heartbeater(self) -> Heartbeater[InstancePipelineItem]: + return self.__heartbeater + + @property + def _fetcher(self) -> Fetcher[InstancePipelineItem]: + return self.__fetcher + + @property + def _workers(self) -> Sequence["InstanceWorker"]: + return self.__workers + + +class InstanceFetcher(Fetcher[InstancePipelineItem]): + def __init__( + self, + queue: asyncio.Queue[InstancePipelineItem], + queue_desired_minsize: int, + min_processing_interval: timedelta, + lock_timeout: timedelta, + heartbeater: 
Heartbeater[InstancePipelineItem], + queue_check_delay: float = 1.0, + ) -> None: + super().__init__( + queue=queue, + queue_desired_minsize=queue_desired_minsize, + min_processing_interval=min_processing_interval, + lock_timeout=lock_timeout, + heartbeater=heartbeater, + queue_check_delay=queue_check_delay, + ) + + @sentry_utils.instrument_named_task("pipeline_tasks.InstanceFetcher.fetch") + async def fetch(self, limit: int) -> list[InstancePipelineItem]: + instance_lock, _ = get_locker(get_db().dialect_name).get_lockset( + InstanceModel.__tablename__ + ) + async with instance_lock: + async with get_session_ctx() as session: + now = get_current_datetime() + res = await session.execute( + select(InstanceModel) + .where( + InstanceModel.status.in_( + [ + InstanceStatus.PENDING, + InstanceStatus.PROVISIONING, + InstanceStatus.BUSY, + InstanceStatus.IDLE, + InstanceStatus.TERMINATING, + ] + ), + not_( + and_( + InstanceModel.status == InstanceStatus.TERMINATING, + InstanceModel.compute_group_id.is_not(None), + ) + ), + InstanceModel.deleted == False, + InstanceModel.last_processed_at <= now - self._min_processing_interval, + or_( + InstanceModel.lock_expires_at.is_(None), + InstanceModel.lock_expires_at < now, + ), + or_( + InstanceModel.lock_owner.is_(None), + InstanceModel.lock_owner == InstancePipeline.__name__, + ), + ) + .order_by(InstanceModel.last_processed_at.asc()) + .limit(limit) + .with_for_update(skip_locked=True, key_share=True, of=InstanceModel) + .options( + load_only( + InstanceModel.id, + InstanceModel.lock_token, + InstanceModel.lock_expires_at, + InstanceModel.status, + ) + ) + ) + instance_models = list(res.scalars().all()) + lock_expires_at = get_current_datetime() + self._lock_timeout + lock_token = uuid.uuid4() + items = [] + for instance_model in instance_models: + prev_lock_expired = instance_model.lock_expires_at is not None + instance_model.lock_expires_at = lock_expires_at + instance_model.lock_token = lock_token + instance_model.lock_owner 
= InstancePipeline.__name__  # NOTE(review): continuation of a statement begun above this chunk
            items.append(
                InstancePipelineItem(
                    __tablename__=InstanceModel.__tablename__,
                    id=instance_model.id,
                    lock_expires_at=lock_expires_at,
                    lock_token=lock_token,
                    prev_lock_expired=prev_lock_expired,
                    status=instance_model.status,
                )
            )
        await session.commit()
        return items


class InstanceWorker(Worker[InstancePipelineItem]):
    """Pipeline worker that processes one locked instance item at a time.

    The item carries a lock token; every stage re-fetches the instance and
    verifies the token so that a stolen/expired lock aborts processing.
    """

    def __init__(
        self,
        queue: asyncio.Queue[InstancePipelineItem],
        heartbeater: Heartbeater[InstancePipelineItem],
    ) -> None:
        super().__init__(
            queue=queue,
            heartbeater=heartbeater,
        )

    @sentry_utils.instrument_named_task("pipeline_tasks.InstanceWorker.process")
    async def process(self, item: InstancePipelineItem):
        """Dispatch the item to the status-specific processor and apply the result.

        A short session is used only to confirm the lock and read the current
        status; each per-status processor opens its own session afterwards.
        """
        async with get_session_ctx() as session:
            instance_model = await _refetch_locked_instance_status(session=session, item=item)
            if instance_model is None:
                # Lock token no longer matches (or instance gone) — drop the item.
                _log_lock_token_mismatch(item, action="process")
                return
            status = instance_model.status

        # NOTE(review): dispatch reconstructed as outside the session context —
        # each _process_*_item opens its own session; confirm against upstream.
        result: Optional[_ProcessResult] = None
        if status == InstanceStatus.PENDING:
            result = await _process_pending_item(item)
        elif status == InstanceStatus.PROVISIONING:
            result = await _process_provisioning_item(item)
        elif status == InstanceStatus.IDLE:
            result = await _process_idle_item(item)
        elif status == InstanceStatus.BUSY:
            result = await _process_busy_item(item)
        elif status == InstanceStatus.TERMINATING:
            result = await _process_terminating_item(item)

        if result is None:
            # Processor declined (e.g. lock lost on re-fetch) — nothing to apply.
            return

        set_processed_update_map_fields(result.instance_update_map)
        set_unlock_update_map_fields(result.instance_update_map)
        await _apply_process_result(item=item, result=result)


class _InstanceUpdateMap(ItemUpdateMap, total=False):
    # Column-name → new-value map applied to the processed InstanceModel row.
    # total=False: only the keys a processor actually sets are written back.
    status: InstanceStatus
    unreachable: bool
    started_at: UpdateMapDateTime
    finished_at: UpdateMapDateTime
    instance_configuration: str
    termination_deadline: Optional[datetime.datetime]
    termination_reason: Optional[InstanceTerminationReason]
    termination_reason_message: Optional[str]
    health: HealthStatus
    first_termination_retry_at: UpdateMapDateTime
    last_termination_retry_at: UpdateMapDateTime
    backend: BackendType
    backend_data: Optional[str]
    offer: str  # serialized InstanceOfferWithAvailability (JSON)
    region: str
    price: float
    job_provisioning_data: str  # serialized JobProvisioningData (JSON)
    total_blocks: int
    busy_blocks: int
    deleted: bool
    deleted_at: UpdateMapDateTime


class _SiblingInstanceUpdateMap(TypedDict, total=False):
    # Update row for a sibling fleet instance (used when the master fails).
    id: uuid.UUID
    status: InstanceStatus
    termination_reason: Optional[InstanceTerminationReason]
    termination_reason_message: Optional[str]


class _HealthCheckCreate(TypedDict):
    # Insert payload for a new InstanceHealthCheckModel row.
    instance_id: uuid.UUID
    collected_at: datetime.datetime
    status: HealthStatus
    response: str  # serialized health response (JSON)


class _PlacementGroupCreate(TypedDict):
    # Insert payload for a new PlacementGroupModel row.
    id: uuid.UUID
    name: str
    project_id: uuid.UUID
    fleet_id: uuid.UUID
    configuration: str  # serialized PlacementGroupConfiguration (JSON)
    provisioning_data: str  # serialized provisioning data (JSON)


@dataclass
class _DeferredEvent:
    # An event message recorded during processing and emitted when the
    # result is applied (after the instance row update commits).
    message: str
    project_id: uuid.UUID
    instance_id: uuid.UUID
    instance_name: str


@dataclass
class _PlacementGroupState:
    # In-memory view of a placement group; create_payload is set only for
    # groups created during this processing pass and not yet persisted.
    id: uuid.UUID
    placement_group: PlacementGroup
    create_payload: Optional[_PlacementGroupCreate] = None


@dataclass
class _ProcessResult:
    """Everything a per-status processor wants persisted, applied atomically later."""

    instance_update_map: _InstanceUpdateMap = field(default_factory=_InstanceUpdateMap)
    sibling_update_rows: list[_SiblingInstanceUpdateMap] = field(default_factory=list)
    deferred_events: list[_DeferredEvent] = field(default_factory=list)
    health_check_create: Optional[_HealthCheckCreate] = None
    placement_group_creates: list[_PlacementGroupCreate] = field(default_factory=list)
    # Fleet whose other placement groups should be scheduled for deletion…
    schedule_pg_deletion_fleet_id: Optional[uuid.UUID] = None
    # …except these ids (the group actually used for provisioning).
    schedule_pg_deletion_except_ids: tuple[uuid.UUID, ...] = ()
async def _process_pending_item(item: InstancePipelineItem) -> Optional[_ProcessResult]:
    """Handle a PENDING instance: add an SSH host or create a cloud instance.

    Returns None if the lock token no longer matches on re-fetch.
    """
    async with get_session_ctx() as session:
        instance_model = await _refetch_locked_instance_for_pending_or_terminating(
            session=session,
            item=item,
        )
        if instance_model is None:
            _log_lock_token_mismatch(item, action="process")
            return None
        if is_ssh_instance(instance_model):
            return await _process_add_remote(instance_model)
        return await _process_create_instance(instance_model)


async def _process_provisioning_item(item: InstancePipelineItem) -> Optional[_ProcessResult]:
    """Handle a PROVISIONING instance: run the reachability/health check."""
    async with get_session_ctx() as session:
        instance_model = await _refetch_locked_instance_for_check(session=session, item=item)
        if instance_model is None:
            _log_lock_token_mismatch(item, action="process")
            return None
        return await _process_instance_check(instance_model)


async def _process_idle_item(item: InstancePipelineItem) -> Optional[_ProcessResult]:
    """Handle an IDLE instance: terminate on idle timeout, otherwise check it."""
    async with get_session_ctx() as session:
        instance_model = await _refetch_locked_instance_for_idle(session=session, item=item)
        if instance_model is None:
            _log_lock_token_mismatch(item, action="process")
            return None
        # Idle-timeout termination takes precedence over a regular check.
        idle_result = _process_idle_timeout(instance_model)
        if idle_result is not None:
            return idle_result
        return await _process_instance_check(instance_model)


async def _process_busy_item(item: InstancePipelineItem) -> Optional[_ProcessResult]:
    """Handle a BUSY instance: run the reachability/health check."""
    async with get_session_ctx() as session:
        instance_model = await _refetch_locked_instance_for_check(session=session, item=item)
        if instance_model is None:
            _log_lock_token_mismatch(item, action="process")
            return None
        return await _process_instance_check(instance_model)


async def _process_terminating_item(item: InstancePipelineItem) -> Optional[_ProcessResult]:
    """Handle a TERMINATING instance: perform the actual termination."""
    async with get_session_ctx() as session:
        instance_model = await _refetch_locked_instance_for_pending_or_terminating(
            session=session,
            item=item,
        )
        if instance_model is None:
            _log_lock_token_mismatch(item, action="process")
            return None
        return await _process_terminate(instance_model)


async def _refetch_locked_instance_status(
    session: AsyncSession,
    item: InstancePipelineItem,
) -> Optional[InstanceModel]:
    """Re-fetch only id+status of the instance, guarded by the lock token.

    Returns None when the row is gone or the lock token changed.
    """
    res = await session.execute(
        select(InstanceModel)
        .where(
            InstanceModel.id == item.id,
            InstanceModel.lock_token == item.lock_token,
        )
        .options(load_only(InstanceModel.id, InstanceModel.status))
    )
    return res.scalar_one_or_none()


async def _refetch_locked_instance_for_pending_or_terminating(
    session: AsyncSession,
    item: InstancePipelineItem,
) -> Optional[InstanceModel]:
    """Re-fetch the locked instance with the relationships needed by the
    pending/terminating processors: project backends, jobs (id+status),
    the fleet, and all non-deleted fleet siblings with their project/fleet.
    """
    res = await session.execute(
        select(InstanceModel)
        .where(
            InstanceModel.id == item.id,
            InstanceModel.lock_token == item.lock_token,
        )
        .options(joinedload(InstanceModel.project).joinedload(ProjectModel.backends))
        .options(joinedload(InstanceModel.jobs).load_only(JobModel.id, JobModel.status))
        .options(
            joinedload(InstanceModel.fleet).joinedload(FleetModel.project),
        )
        .options(
            joinedload(InstanceModel.fleet)
            .joinedload(FleetModel.instances.and_(InstanceModel.deleted == False))
            .joinedload(InstanceModel.project)
        )
        .options(
            joinedload(InstanceModel.fleet)
            .joinedload(FleetModel.instances.and_(InstanceModel.deleted == False))
            .joinedload(InstanceModel.fleet)
        )
        # populate_existing: refresh any already-cached rows in this session.
        .execution_options(populate_existing=True)
    )
    return res.unique().scalar_one_or_none()


async def _refetch_locked_instance_for_idle(
    session: AsyncSession,
    item: InstancePipelineItem,
) -> Optional[InstanceModel]:
    """Re-fetch the locked instance with what the idle processor needs:
    project, jobs (id+status), and non-deleted fleet siblings (for nodes.min).
    """
    res = await session.execute(
        select(InstanceModel)
        .where(
            InstanceModel.id == item.id,
            InstanceModel.lock_token == item.lock_token,
        )
        .options(joinedload(InstanceModel.project))
        .options(joinedload(InstanceModel.jobs).load_only(JobModel.id, JobModel.status))
        .options(
            joinedload(InstanceModel.fleet).joinedload(
                FleetModel.instances.and_(InstanceModel.deleted == False)
            )
        )
        .execution_options(populate_existing=True)
    )
    return res.unique().scalar_one_or_none()
async def _refetch_locked_instance_for_check(
    session: AsyncSession,
    item: InstancePipelineItem,
) -> Optional[InstanceModel]:
    """Re-fetch the locked instance for the health-check path: project
    backends and jobs (id+status) only — no fleet siblings needed.
    """
    res = await session.execute(
        select(InstanceModel)
        .where(
            InstanceModel.id == item.id,
            InstanceModel.lock_token == item.lock_token,
        )
        .options(joinedload(InstanceModel.project).joinedload(ProjectModel.backends))
        .options(joinedload(InstanceModel.jobs).load_only(JobModel.id, JobModel.status))
        .execution_options(populate_existing=True)
    )
    return res.unique().scalar_one_or_none()


def _process_idle_timeout(instance_model: InstanceModel) -> Optional[_ProcessResult]:
    """Return a TERMINATING result if the idle instance exceeded its idle TTL.

    Returns None when the instance should not be idle-terminated (wrong
    status/policy, has jobs, fleet at nodes.min, or TTL not yet exceeded).
    """
    if not (
        instance_model.status == InstanceStatus.IDLE
        and instance_model.termination_policy == TerminationPolicy.DESTROY_AFTER_IDLE
        and not instance_model.jobs
    ):
        return None
    if instance_model.fleet is not None and not _can_terminate_fleet_instances_on_idle_duration(
        instance_model.fleet
    ):
        logger.debug(
            "Skipping instance %s termination on idle duration. Fleet is already at `nodes.min`.",
            instance_model.name,
        )
        return None

    idle_duration = _get_instance_idle_duration(instance_model)
    if idle_duration <= datetime.timedelta(seconds=instance_model.termination_idle_time):
        return None

    result = _ProcessResult()
    _set_status_update(
        update_map=result.instance_update_map,
        deferred_events=result.deferred_events,
        instance_model=instance_model,
        new_status=InstanceStatus.TERMINATING,
        termination_reason=InstanceTerminationReason.IDLE_TIMEOUT,
        termination_reason_message=f"Instance idle for {idle_duration.seconds}s",
    )
    return result


def _can_terminate_fleet_instances_on_idle_duration(fleet_model: FleetModel) -> bool:
    """True unless terminating one more instance would drop the fleet below
    its configured `nodes.min` (autocreated fleets are always terminable).
    """
    fleet = fleet_model_to_fleet(fleet_model)
    if fleet.spec.configuration.nodes is None or fleet.spec.autocreated:
        return True
    active_instances = [
        instance for instance in fleet_model.instances if instance.status.is_active()
    ]
    return len(active_instances) > fleet.spec.configuration.nodes.min


async def _process_add_remote(instance_model: InstanceModel) -> _ProcessResult:
    """Provision an SSH (remote) instance: connect, deploy dstack-shim,
    collect host info, and record the resulting offer/provisioning data.

    Failures either terminate the instance (bad keys, timeout budget spent,
    unsupported block split, unexpected error) or return an empty result so
    the attempt is retried on the next pass (deploy timeout, SSH errors).
    """
    result = _ProcessResult()
    logger.info("Adding ssh instance %s...", instance_model.name)

    # Overall retry budget, measured from instance creation.
    retry_duration_deadline = instance_model.created_at + timedelta(
        seconds=PROVISIONING_TIMEOUT_SECONDS
    )
    if retry_duration_deadline < get_current_datetime():
        _set_status_update(
            update_map=result.instance_update_map,
            deferred_events=result.deferred_events,
            instance_model=instance_model,
            new_status=InstanceStatus.TERMINATED,
            termination_reason=InstanceTerminationReason.PROVISIONING_TIMEOUT,
            termination_reason_message=(
                f"Failed to add SSH instance in {PROVISIONING_TIMEOUT_SECONDS}s"
            ),
        )
        return result

    remote_details = get_instance_remote_connection_info(instance_model)
    assert remote_details is not None

    try:
        pkeys = _ssh_keys_to_pkeys(remote_details.ssh_keys)
        ssh_proxy_pkeys = None
        if remote_details.ssh_proxy_keys is not None:
            ssh_proxy_pkeys = _ssh_keys_to_pkeys(remote_details.ssh_proxy_keys)
    except (ValueError, PasswordRequiredException):
        # Key could not be parsed (or is passphrase-protected) — unrecoverable.
        _set_status_update(
            update_map=result.instance_update_map,
            deferred_events=result.deferred_events,
            instance_model=instance_model,
            new_status=InstanceStatus.TERMINATED,
            termination_reason=InstanceTerminationReason.ERROR,
            termination_reason_message="Unsupported private SSH key type",
        )
        return result

    authorized_keys = [pkey.public.strip() for pkey in remote_details.ssh_keys]
    authorized_keys.append(instance_model.project.ssh_public_key.strip())

    try:
        future = run_async(
            _deploy_instance,
            remote_details,
            pkeys,
            ssh_proxy_pkeys,
            authorized_keys,
        )
        deploy_timeout = 20 * 60  # 20 minutes for the whole deploy
        health, host_info, arch = await asyncio.wait_for(future, timeout=deploy_timeout)
    except (asyncio.TimeoutError, TimeoutError) as exc:
        # Transient: leave status unchanged so the deploy is retried.
        logger.warning(
            "%s: deploy timeout when adding SSH instance: %s",
            fmt(instance_model),
            repr(exc),
        )
        return result
    except SSHProvisioningError as exc:
        # Transient: provisioning-level failure, retried on the next pass.
        logger.warning(
            "%s: provisioning error when adding SSH instance: %s",
            fmt(instance_model),
            repr(exc),
        )
        return result
    except Exception:
        logger.exception("%s: unexpected error when adding SSH instance", fmt(instance_model))
        _set_status_update(
            update_map=result.instance_update_map,
            deferred_events=result.deferred_events,
            instance_model=instance_model,
            new_status=InstanceStatus.TERMINATED,
            termination_reason=InstanceTerminationReason.ERROR,
            termination_reason_message="Unexpected error when adding SSH instance",
        )
        return result

    instance_type = host_info_to_instance_type(host_info, arch)
    try:
        instance_network, internal_ip = _resolve_ssh_instance_network(instance_model, host_info)
    except _SSHInstanceNetworkResolutionError as exc:
        _set_status_update(
            update_map=result.instance_update_map,
            deferred_events=result.deferred_events,
            instance_model=instance_model,
            new_status=InstanceStatus.TERMINATED,
            termination_reason=InstanceTerminationReason.ERROR,
            termination_reason_message=str(exc),
        )
        return result

    divisible, blocks = is_divisible_into_blocks(
        cpu_count=instance_type.resources.cpus,
        gpu_count=len(instance_type.resources.gpus),
        blocks="auto" if instance_model.total_blocks is None else instance_model.total_blocks,
    )
    if not divisible:
        _set_status_update(
            update_map=result.instance_update_map,
            deferred_events=result.deferred_events,
            instance_model=instance_model,
            new_status=InstanceStatus.TERMINATED,
            termination_reason=InstanceTerminationReason.ERROR,
            termination_reason_message="Cannot split into blocks",
        )
        return result

    region = instance_model.region
    assert region is not None
    job_provisioning_data = JobProvisioningData(
        backend=BackendType.REMOTE,
        instance_type=instance_type,
        instance_id="instance_id",  # placeholder id for remote instances
        hostname=remote_details.host,
        region=region,
        price=0,
        internal_ip=internal_ip,
        instance_network=instance_network,
        username=remote_details.ssh_user,
        ssh_port=remote_details.port,
        dockerized=True,
        backend_data=None,
        ssh_proxy=remote_details.ssh_proxy,
    )
    instance_offer = InstanceOfferWithAvailability(
        backend=BackendType.REMOTE,
        instance=instance_type,
        region=region,
        price=0,
        availability=InstanceAvailability.AVAILABLE,
        instance_runtime=InstanceRuntime.SHIM,
    )

    # Healthy shim → IDLE immediately; otherwise keep PROVISIONING until a
    # later check observes it healthy.
    _set_status_update(
        update_map=result.instance_update_map,
        deferred_events=result.deferred_events,
        instance_model=instance_model,
        new_status=InstanceStatus.IDLE if health else InstanceStatus.PROVISIONING,
    )
    result.instance_update_map["backend"] = BackendType.REMOTE
    result.instance_update_map["price"] = 0
    result.instance_update_map["offer"] = instance_offer.json()
    result.instance_update_map["job_provisioning_data"] = job_provisioning_data.json()
    result.instance_update_map["started_at"] = NOW_PLACEHOLDER
    result.instance_update_map["total_blocks"] = blocks
    return result
class _SSHInstanceNetworkResolutionError(Exception):
    # Raised when an SSH instance's internal IP / network cannot be resolved
    # from the host's reported addresses.
    pass


def _resolve_ssh_instance_network(
    instance_model: InstanceModel,
    host_info: dict[str, Any],
) -> tuple[Optional[str], Optional[str]]:
    """Resolve (instance_network, internal_ip) for an SSH instance.

    Starts from any previously stored JobProvisioningData (ignored if it
    fails to parse), then derives the internal IP from the host's reported
    addresses. Raises _SSHInstanceNetworkResolutionError when a configured
    network/IP cannot be matched against the host interfaces.
    """
    instance_network = None
    internal_ip = None
    try:
        default_job_provisioning_data = JobProvisioningData.__response__.parse_raw(
            instance_model.job_provisioning_data
        )
        instance_network = default_job_provisioning_data.instance_network
        internal_ip = default_job_provisioning_data.internal_ip
    except ValidationError:
        # No (or unparsable) stored provisioning data — fall through.
        pass

    host_network_addresses = host_info.get("addresses", [])
    if internal_ip is None:
        internal_ip = get_ip_from_network(
            network=instance_network,
            addresses=host_network_addresses,
        )
    if instance_network is not None and internal_ip is None:
        raise _SSHInstanceNetworkResolutionError(
            "Failed to locate internal IP address on the given network"
        )
    if internal_ip is not None and not is_ip_among_addresses(
        ip_address=internal_ip,
        addresses=host_network_addresses,
    ):
        raise _SSHInstanceNetworkResolutionError(
            "Specified internal IP not found among instance interfaces"
        )
    return instance_network, internal_ip


def _deploy_instance(
    remote_details: RemoteConnectionInfo,
    pkeys: list[PKey],
    ssh_proxy_pkeys: Optional[list[PKey]],
    authorized_keys: list[str],
) -> tuple[InstanceCheck, dict[str, Any], GoArchType]:
    """Deploy dstack-shim onto a remote host over SSH (blocking; run in a thread).

    Returns (instance_check, host_info, cpu_arch). Raises SSHProvisioningError
    on invalid fleet Env or an unparsable shim healthcheck response.
    """
    with get_paramiko_connection(
        remote_details.ssh_user,
        remote_details.host,
        remote_details.port,
        pkeys,
        remote_details.ssh_proxy,
        ssh_proxy_pkeys,
    ) as client:
        logger.info("Connected to %s %s", remote_details.ssh_user, remote_details.host)

        arch = detect_cpu_arch(client)
        logger.info("%s: CPU arch is %s", remote_details.host, arch)

        shim_pre_start_commands = get_shim_pre_start_commands(arch=arch)
        run_pre_start_commands(client, shim_pre_start_commands, authorized_keys)
        logger.debug("The script for installing dstack has been executed")

        shim_envs = get_shim_env(arch=arch)
        try:
            fleet_configuration_envs = remote_details.env.as_dict()
        except ValueError as exc:
            raise SSHProvisioningError(f"Invalid Env: {exc}") from exc
        shim_envs.update(fleet_configuration_envs)
        dstack_working_dir = get_dstack_working_dir()
        dstack_shim_binary_path = get_dstack_shim_binary_path()
        dstack_runner_binary_path = get_dstack_runner_binary_path()
        upload_envs(client, dstack_working_dir, shim_envs)
        logger.debug("The dstack-shim environment variables have been installed")

        # Clear stale artifacts from any previous deployment.
        remove_host_info_if_exists(client, dstack_working_dir)
        remove_dstack_runner_if_exists(client, dstack_runner_binary_path)

        run_shim_as_systemd_service(
            client=client,
            binary_path=dstack_shim_binary_path,
            working_dir=dstack_working_dir,
            dev=settings.DSTACK_VERSION is None,
        )

        host_info = get_host_info(client, dstack_working_dir)
        logger.debug("Received a host_info %s", host_info)

        healthcheck_out = get_shim_healthcheck(client)
        try:
            healthcheck = HealthcheckResponse.__response__.parse_raw(healthcheck_out)
        except ValueError as exc:
            raise SSHProvisioningError(f"Cannot parse HealthcheckResponse: {exc}") from exc
        instance_check = runner_client.healthcheck_response_to_instance_check(healthcheck)
        return instance_check, host_info, arch


async def _process_create_instance(instance_model: InstanceModel) -> _ProcessResult:
    """Try offers in order and create a cloud instance for a PENDING item.

    Handles cluster placement groups for the fleet master, records the
    selected offer/provisioning data on success, and terminates the
    instance (plus cluster siblings if it is the master) when no offer works.
    """
    result = _ProcessResult()
    master_instance_model = _get_fleet_master_instance(instance_model)
    if _need_to_wait_fleet_provisioning(instance_model, master_instance_model):
        # Non-master cluster members wait until the master is provisioned.
        logger.debug(
            "%s: waiting for the first instance in the fleet to be provisioned",
            fmt(instance_model),
        )
        return result

    try:
        instance_configuration = get_instance_configuration(instance_model)
        profile = get_instance_profile(instance_model)
        requirements = get_instance_requirements(instance_model)
    except ValidationError as exc:
        logger.exception(
            "%s: error parsing profile, requirements or instance configuration",
            fmt(instance_model),
        )
        _set_status_update(
            update_map=result.instance_update_map,
            deferred_events=result.deferred_events,
            instance_model=instance_model,
            new_status=InstanceStatus.TERMINATED,
            termination_reason=InstanceTerminationReason.ERROR,
            termination_reason_message=(
                f"Error to parse profile, requirements or instance_configuration: {exc}"
            ),
        )
        return result

    placement_group_states = await _get_fleet_placement_group_states(instance_model.fleet_id)
    placement_group_state = _get_placement_group_state_for_instance(
        placement_group_states=placement_group_states,
        instance_model=instance_model,
        master_instance_model=master_instance_model,
    )
    offers = await get_create_instance_offers(
        project=instance_model.project,
        profile=profile,
        requirements=requirements,
        fleet_model=instance_model.fleet,
        placement_group=(
            placement_group_state.placement_group if placement_group_state is not None else None
        ),
        blocks="auto" if instance_model.total_blocks is None else instance_model.total_blocks,
        exclude_not_available=True,
    )

    # Track which placement groups are already known/queued so a group
    # created for one offer is not queued for insertion twice.
    seen_placement_group_ids = {state.id for state in placement_group_states}
    for backend, instance_offer in offers[: server_settings.MAX_OFFERS_TRIED]:
        if instance_offer.backend not in BACKENDS_WITH_CREATE_INSTANCE_SUPPORT:
            continue
        compute = backend.compute()
        assert isinstance(compute, ComputeWithCreateInstanceSupport)
        selected_offer = _get_instance_offer_for_instance(
            instance_offer=instance_offer,
            instance_model=instance_model,
            master_instance_model=master_instance_model,
        )
        selected_placement_group_state = placement_group_state
        # Only the cluster master creates/selects placement groups.
        if (
            instance_model.fleet is not None
            and is_cloud_cluster(instance_model.fleet)
            and instance_model.id == master_instance_model.id
            and selected_offer.backend in BACKENDS_WITH_PLACEMENT_GROUPS_SUPPORT
            and isinstance(compute, ComputeWithPlacementGroupSupport)
            and (
                compute.are_placement_groups_compatible_with_reservations(selected_offer.backend)
                or instance_configuration.reservation is None
            )
        ):
            selected_placement_group_state = await _find_or_create_suitable_placement_group_state(
                instance_model=instance_model,
                placement_group_states=placement_group_states,
                instance_offer=selected_offer,
                compute=compute,
            )
            if selected_placement_group_state is None:
                # Could not get a suitable group for this offer — try the next one.
                continue
            if (
                selected_placement_group_state.create_payload is not None
                and selected_placement_group_state.id not in seen_placement_group_ids
            ):
                seen_placement_group_ids.add(selected_placement_group_state.id)
                placement_group_states.append(selected_placement_group_state)
                result.placement_group_creates.append(
                    selected_placement_group_state.create_payload
                )

        logger.debug(
            "Trying %s in %s/%s for $%0.4f per hour",
            selected_offer.instance.name,
            selected_offer.backend.value,
            selected_offer.region,
            selected_offer.price,
        )
        try:
            job_provisioning_data = await run_async(
                compute.create_instance,
                selected_offer,
                instance_configuration,
                selected_placement_group_state.placement_group
                if selected_placement_group_state is not None
                else None,
            )
        except BackendError as exc:
            logger.warning(
                "%s launch in %s/%s failed: %s",
                selected_offer.instance.name,
                selected_offer.backend.value,
                selected_offer.region,
                repr(exc),
                extra={"instance_name": instance_model.name},
            )
            continue
        except Exception:
            logger.exception(
                "Got exception when launching %s in %s/%s",
                selected_offer.instance.name,
                selected_offer.backend.value,
                selected_offer.region,
            )
            continue

        # Success: record PROVISIONING plus the selected offer details.
        _set_status_update(
            update_map=result.instance_update_map,
            deferred_events=result.deferred_events,
            instance_model=instance_model,
            new_status=InstanceStatus.PROVISIONING,
        )
        result.instance_update_map["backend"] = backend.TYPE
        result.instance_update_map["region"] = selected_offer.region
        result.instance_update_map["price"] = selected_offer.price
        result.instance_update_map["instance_configuration"] = instance_configuration.json()
        result.instance_update_map["job_provisioning_data"] = job_provisioning_data.json()
        result.instance_update_map["offer"] = selected_offer.json()
        result.instance_update_map["total_blocks"] = selected_offer.total_blocks
        result.instance_update_map["started_at"] = NOW_PLACEHOLDER

        if instance_model.fleet_id is not None and instance_model.id == master_instance_model.id:
            # Clean up unused placement groups, keeping the one we used.
            result.schedule_pg_deletion_fleet_id = instance_model.fleet_id
            if selected_placement_group_state is not None:
                result.schedule_pg_deletion_except_ids = (selected_placement_group_state.id,)
        return result

    # No offer succeeded (or none found) — terminate this instance.
    _set_status_update(
        update_map=result.instance_update_map,
        deferred_events=result.deferred_events,
        instance_model=instance_model,
        new_status=InstanceStatus.TERMINATED,
        termination_reason=InstanceTerminationReason.NO_OFFERS,
        termination_reason_message="All offers failed" if offers else "No offers found",
    )
    if (
        instance_model.fleet is not None
        and instance_model.id == master_instance_model.id
        and is_cloud_cluster(instance_model.fleet)
    ):
        # A cluster cannot exist without its master — terminate siblings too.
        for sibling_instance_model in instance_model.fleet.instances:
            if sibling_instance_model.id == instance_model.id:
                continue
            sibling_update_map = _SiblingInstanceUpdateMap(id=sibling_instance_model.id)
            _set_status_update(
                update_map=sibling_update_map,
                deferred_events=result.deferred_events,
                instance_model=sibling_instance_model,
                new_status=InstanceStatus.TERMINATED,
                termination_reason=InstanceTerminationReason.MASTER_FAILED,
            )
            # Only persist rows where _set_status_update added fields beyond id.
            if len(sibling_update_map) > 1:
                result.sibling_update_rows.append(sibling_update_map)
    return result


def _get_fleet_master_instance(instance_model: InstanceModel) -> InstanceModel:
    """Return the fleet's master: the instance with the lowest
    (instance_num, created_at). A fleetless instance is its own master.
    """
    if instance_model.fleet is None:
        return instance_model
    fleet_instances = list(instance_model.fleet.instances)
    # The loaded siblings exclude deleted rows and may not include this
    # instance itself — make sure it participates in the min().
    if all(fleet_instance.id != instance_model.id for fleet_instance in fleet_instances):
        fleet_instances.append(instance_model)
    return min(
        fleet_instances,
        key=lambda fleet_instance: (fleet_instance.instance_num, fleet_instance.created_at),
    )
async def _get_fleet_placement_group_states(
    fleet_id: Optional[uuid.UUID],
) -> list[_PlacementGroupState]:
    """Load the fleet's live (not deleted, fleet not deleted) placement groups."""
    if fleet_id is None:
        return []
    async with get_session_ctx() as session:
        res = await session.execute(
            select(PlacementGroupModel)
            .where(
                PlacementGroupModel.fleet_id == fleet_id,
                PlacementGroupModel.deleted == False,
                PlacementGroupModel.fleet_deleted == False,
            )
            .options(joinedload(PlacementGroupModel.project))
        )
        placement_group_models = list(res.unique().scalars().all())
    # NOTE(review): conversion reconstructed as outside the session context —
    # the models are fully loaded at this point; confirm against upstream.
    return [
        _PlacementGroupState(
            id=placement_group_model.id,
            placement_group=placement_group_model_to_placement_group(placement_group_model),
        )
        for placement_group_model in placement_group_models
    ]


def _get_placement_group_state_for_instance(
    placement_group_states: list[_PlacementGroupState],
    instance_model: InstanceModel,
    master_instance_model: InstanceModel,
) -> Optional[_PlacementGroupState]:
    """Pick the placement group a non-master instance should join.

    The master returns None (it selects/creates its own group later).
    More than one group is unexpected and logged; the first is used.
    """
    if instance_model.id == master_instance_model.id:
        return None
    if len(placement_group_states) > 1:
        logger.error(
            (
                "Expected 0 or 1 placement groups associated with fleet %s, found %s."
                " An incorrect placement group might have been selected for instance %s"
            ),
            instance_model.fleet_id,
            len(placement_group_states),
            instance_model.name,
        )
    if placement_group_states:
        return placement_group_states[0]
    return None


async def _find_or_create_suitable_placement_group_state(
    instance_model: InstanceModel,
    placement_group_states: list[_PlacementGroupState],
    instance_offer: InstanceOfferWithAvailability,
    compute: ComputeWithPlacementGroupSupport,
) -> Optional[_PlacementGroupState]:
    """Return an existing placement group suitable for the offer, or create one.

    A newly created group carries a create_payload so the caller can persist
    it. Returns None when the offer cannot get a group (unsupported by the
    backend, backend error, or unexpected failure) — caller skips the offer.
    """
    for placement_group_state in placement_group_states:
        if compute.is_suitable_placement_group(
            placement_group_state.placement_group,
            instance_offer,
        ):
            return placement_group_state

    assert instance_model.fleet is not None
    placement_group_id = uuid.uuid4()
    placement_group_name = generate_unique_placement_group_name(
        project_name=instance_model.project.name,
        fleet_name=instance_model.fleet.name,
    )
    placement_group = PlacementGroup(
        name=placement_group_name,
        project_name=instance_model.project.name,
        configuration=PlacementGroupConfiguration(
            backend=instance_offer.backend,
            region=instance_offer.region,
            placement_strategy=PlacementStrategy.CLUSTER,
        ),
        provisioning_data=None,
    )
    logger.debug(
        "Creating placement group %s in %s/%s",
        placement_group.name,
        placement_group.configuration.backend.value,
        placement_group.configuration.region,
    )
    try:
        provisioning_data = await run_async(
            compute.create_placement_group,
            placement_group,
            instance_offer,
        )
    except PlacementGroupNotSupportedError:
        logger.debug(
            "Skipping offer %s because placement group not supported",
            instance_offer.instance.name,
        )
        return None
    except BackendError as exc:
        logger.warning(
            "Failed to create placement group %s in %s/%s: %r",
            placement_group.name,
            placement_group.configuration.backend.value,
            placement_group.configuration.region,
            exc,
        )
        return None
    except Exception:
        logger.exception(
            "Got exception when creating placement group %s in %s/%s",
            placement_group.name,
            placement_group.configuration.backend.value,
            placement_group.configuration.region,
        )
        return None

    logger.info(
        "Created placement group %s in %s/%s",
        placement_group.name,
        placement_group.configuration.backend.value,
        placement_group.configuration.region,
    )
    placement_group.provisioning_data = provisioning_data
    return _PlacementGroupState(
        id=placement_group_id,
        placement_group=placement_group,
        create_payload=_PlacementGroupCreate(
            id=placement_group_id,
            name=placement_group.name,
            project_id=instance_model.project_id,
            fleet_id=get_or_error(instance_model.fleet_id),
            configuration=placement_group.configuration.json(),
            provisioning_data=provisioning_data.json(),
        ),
    )
async def _process_instance_check(instance_model: InstanceModel) -> _ProcessResult:
    """Check a provisioned instance and compute the resulting state updates.

    Handles, in order: a busy instance whose jobs all finished (→ TERMINATING),
    provisioning data not yet complete (delegate to the wait processor),
    non-dockerized instances (no shim to check), and finally an SSH
    reachability/health check with deadline-based termination when unreachable.
    """
    result = _ProcessResult()
    if (
        instance_model.status == InstanceStatus.BUSY
        and instance_model.jobs
        and all(job.status.is_finished() for job in instance_model.jobs)
    ):
        _set_status_update(
            update_map=result.instance_update_map,
            deferred_events=result.deferred_events,
            instance_model=instance_model,
            new_status=InstanceStatus.TERMINATING,
            termination_reason=InstanceTerminationReason.JOB_FINISHED,
        )
        logger.warning(
            "Detected busy instance %s with finished job. Marked as TERMINATING",
            instance_model.name,
            extra={
                "instance_name": instance_model.name,
                "instance_status": instance_model.status.value,
            },
        )
        return result

    job_provisioning_data = get_or_error(get_instance_provisioning_data(instance_model))
    if job_provisioning_data.hostname is None:
        # Backend has not reported a hostname yet — keep waiting.
        return await _process_wait_for_instance_provisioning_data(
            instance_model=instance_model,
            job_provisioning_data=job_provisioning_data,
        )

    if not job_provisioning_data.dockerized:
        # No shim on non-dockerized instances — nothing to check; just flip
        # PROVISIONING → BUSY once the hostname is known.
        if instance_model.status == InstanceStatus.PROVISIONING:
            _set_status_update(
                update_map=result.instance_update_map,
                deferred_events=result.deferred_events,
                instance_model=instance_model,
                new_status=InstanceStatus.BUSY,
            )
        return result

    # Rate-limit full health collection; reachability is always checked.
    check_instance_health = await _should_check_instance_health(instance_model.id)
    instance_check = await _run_instance_check(
        instance_model=instance_model,
        job_provisioning_data=job_provisioning_data,
        check_instance_health=check_instance_health,
    )
    health_status = _get_health_status_for_instance_check(
        instance_model=instance_model,
        instance_check=instance_check,
        check_instance_health=check_instance_health,
    )
    _log_instance_check_result(
        instance_model=instance_model,
        instance_check=instance_check,
        health_status=health_status,
        check_instance_health=check_instance_health,
    )

    if instance_check.has_health_checks():
        assert instance_check.health_response is not None
        result.health_check_create = _HealthCheckCreate(
            instance_id=instance_model.id,
            collected_at=get_current_datetime(),
            status=health_status,
            response=instance_check.health_response.json(),
        )

    _set_health_update(
        update_map=result.instance_update_map,
        deferred_events=result.deferred_events,
        instance_model=instance_model,
        health=health_status,
    )
    _set_unreachable_update(
        update_map=result.instance_update_map,
        deferred_events=result.deferred_events,
        instance_model=instance_model,
        unreachable=not instance_check.reachable,
    )

    if instance_check.reachable:
        # Reachable again — clear any pending unreachability deadline.
        result.instance_update_map["termination_deadline"] = None
        if instance_model.status == InstanceStatus.PROVISIONING:
            _set_status_update(
                update_map=result.instance_update_map,
                deferred_events=result.deferred_events,
                instance_model=instance_model,
                new_status=InstanceStatus.IDLE if not instance_model.jobs else InstanceStatus.BUSY,
            )
        return result

    # Unreachable: start (or enforce) the termination deadline.
    now = get_current_datetime()
    if not is_ssh_instance(instance_model) and instance_model.termination_deadline is None:
        result.instance_update_map["termination_deadline"] = now + TERMINATION_DEADLINE_OFFSET

    if (
        instance_model.status == InstanceStatus.PROVISIONING
        and instance_model.started_at is not None
    ):
        provisioning_deadline = _get_provisioning_deadline(
            instance_model=instance_model,
            job_provisioning_data=job_provisioning_data,
        )
        if now > provisioning_deadline:
            _set_status_update(
                update_map=result.instance_update_map,
                deferred_events=result.deferred_events,
                instance_model=instance_model,
                new_status=InstanceStatus.TERMINATING,
                termination_reason=InstanceTerminationReason.PROVISIONING_TIMEOUT,
                termination_reason_message="Instance did not become reachable in time",
            )
    elif instance_model.status.is_available():
        deadline = instance_model.termination_deadline
        if deadline is not None and now > deadline:
            _set_status_update(
                update_map=result.instance_update_map,
                deferred_events=result.deferred_events,
                instance_model=instance_model,
                new_status=InstanceStatus.TERMINATING,
                termination_reason=InstanceTerminationReason.UNREACHABLE,
            )
    return result


async def _should_check_instance_health(instance_id: uuid.UUID) -> bool:
    """True when no health check was collected for the instance within the
    configured minimum collection interval (i.e. it is due for one).
    """
    health_check_cutoff = get_current_datetime() - timedelta(
        seconds=server_settings.SERVER_INSTANCE_HEALTH_MIN_COLLECT_INTERVAL_SECONDS
    )
    async with get_session_ctx() as session:
        res = await session.execute(
            select(func.count(1)).where(
                InstanceHealthCheckModel.instance_id == instance_id,
                InstanceHealthCheckModel.collected_at > health_check_cutoff,
            )
        )
        return res.scalar_one() == 0
async def _run_instance_check(
    instance_model: InstanceModel,
    job_provisioning_data: JobProvisioningData,
    check_instance_health: bool,
) -> InstanceCheck:
    """Run the (blocking) SSH-based instance check in a worker thread.

    A False return from the inner check means the SSH tunnel itself failed;
    it is normalized to an unreachable InstanceCheck.
    """
    ssh_private_keys = get_instance_ssh_private_keys(instance_model)
    instance_check = await run_async(
        _check_instance_inner,
        ssh_private_keys,
        job_provisioning_data,
        None,
        instance=instance_model,
        check_instance_health=check_instance_health,
    )
    if instance_check is False:
        return InstanceCheck(reachable=False, message="SSH or tunnel error")
    return instance_check


def _get_health_status_for_instance_check(
    instance_model: InstanceModel,
    instance_check: InstanceCheck,
    check_instance_health: bool,
) -> HealthStatus:
    """Fresh health status if one was collected this round; otherwise keep
    the instance's last known health.
    """
    if instance_check.reachable and check_instance_health:
        return instance_check.get_health_status()
    return instance_model.health


def _log_instance_check_result(
    instance_model: InstanceModel,
    instance_check: InstanceCheck,
    health_status: HealthStatus,
    check_instance_health: bool,
) -> None:
    """Log the check outcome; escalate to WARNING when an available instance
    is unreachable or a freshly collected health status is unhealthy.
    """
    loglevel = logging.DEBUG
    if not instance_check.reachable and instance_model.status.is_available():
        loglevel = logging.WARNING
    elif check_instance_health and not health_status.is_healthy():
        loglevel = logging.WARNING
    logger.log(
        loglevel,
        "Instance %s check: reachable=%s health_status=%s message=%r",
        instance_model.name,
        instance_check.reachable,
        health_status.name,
        instance_check.message,
        extra={"instance_name": instance_model.name, "health_status": health_status},
    )


async def _process_wait_for_instance_provisioning_data(
    instance_model: InstanceModel,
    job_provisioning_data: JobProvisioningData,
) -> _ProcessResult:
    """Poll the backend for missing provisioning data (no hostname yet).

    Terminates the instance on provisioning deadline, missing backend, or a
    ProvisioningError; unexpected errors are logged and retried next pass.
    """
    result = _ProcessResult()
    logger.debug("Waiting for instance %s to become running", instance_model.name)
    provisioning_deadline = _get_provisioning_deadline(
        instance_model=instance_model,
        job_provisioning_data=job_provisioning_data,
    )
    if get_current_datetime() > provisioning_deadline:
        _set_status_update(
            update_map=result.instance_update_map,
            deferred_events=result.deferred_events,
            instance_model=instance_model,
            new_status=InstanceStatus.TERMINATING,
            termination_reason=InstanceTerminationReason.PROVISIONING_TIMEOUT,
            termination_reason_message="Backend did not complete provisioning in time",
        )
        return result

    backend = await backends_services.get_project_backend_by_type(
        project=instance_model.project,
        backend_type=job_provisioning_data.backend,
    )
    if backend is None:
        logger.warning(
            "Instance %s failed because instance's backend is not available",
            instance_model.name,
        )
        _set_status_update(
            update_map=result.instance_update_map,
            deferred_events=result.deferred_events,
            instance_model=instance_model,
            new_status=InstanceStatus.TERMINATING,
            termination_reason=InstanceTerminationReason.ERROR,
            termination_reason_message="Backend not available",
        )
        return result

    try:
        # update_provisioning_data mutates job_provisioning_data in place;
        # persist the (possibly updated) serialized form on success.
        await run_async(
            backend.compute().update_provisioning_data,
            job_provisioning_data,
            instance_model.project.ssh_public_key,
            instance_model.project.ssh_private_key,
        )
        result.instance_update_map["job_provisioning_data"] = job_provisioning_data.json()
    except ProvisioningError as exc:
        logger.warning(
            "Error while waiting for instance %s to become running: %s",
            instance_model.name,
            repr(exc),
        )
        _set_status_update(
            update_map=result.instance_update_map,
            deferred_events=result.deferred_events,
            instance_model=instance_model,
            new_status=InstanceStatus.TERMINATING,
            termination_reason=InstanceTerminationReason.ERROR,
            termination_reason_message="Error while waiting for instance to become running",
        )
    except Exception:
        # Unexpected failure: leave status unchanged so the wait is retried.
        logger.exception(
            "Got exception when updating instance %s provisioning data",
            instance_model.name,
        )
    return result
+ +@runner_ssh_tunnel(ports=[DSTACK_SHIM_HTTP_PORT], retries=1) +def _check_instance_inner( + ports: Dict[int, int], + *, + instance: InstanceModel, + check_instance_health: bool = False, +) -> InstanceCheck: + instance_health_response: Optional[InstanceHealthResponse] = None + shim_client = runner_client.ShimClient(port=ports[DSTACK_SHIM_HTTP_PORT]) + method = shim_client.healthcheck + try: + healthcheck_response = method(unmask_exceptions=True) + if check_instance_health: + method = shim_client.get_instance_health + instance_health_response = method() + except requests.RequestException as exc: + template = "shim.%s(): request error: %s" + args = (method.__func__.__name__, exc) + logger.debug(template, *args) + return InstanceCheck(reachable=False, message=template % args) + except Exception as exc: + template = "shim.%s(): unexpected exception %s: %s" + args = (method.__func__.__name__, exc.__class__.__name__, exc) + logger.exception(template, *args) + return InstanceCheck(reachable=False, message=template % args) + + try: + remove_dangling_tasks_from_instance(shim_client, instance) + except Exception as exc: + logger.exception("%s: error removing dangling tasks: %s", fmt(instance), exc) + + _maybe_install_components(instance, shim_client) + return runner_client.healthcheck_response_to_instance_check( + healthcheck_response, + instance_health_response, + ) + + +def _maybe_install_components( + instance_model: InstanceModel, + shim_client: runner_client.ShimClient, +) -> None: + try: + components = shim_client.get_components() + except requests.RequestException as exc: + logger.warning( + "Instance %s: shim.get_components(): request error: %s", instance_model.name, exc + ) + return + if components is None: + logger.debug("Instance %s: no components info", instance_model.name) + return + + installed_shim_version: Optional[str] = None + installation_requested = False + + if (runner_info := components.runner) is not None: + installation_requested |= 
_maybe_install_runner(instance_model, shim_client, runner_info) + else: + logger.debug("Instance %s: no runner info", instance_model.name) + + if (shim_info := components.shim) is not None: + if shim_info.status == ComponentStatus.INSTALLED: + installed_shim_version = shim_info.version + installation_requested |= _maybe_install_shim(instance_model, shim_client, shim_info) + else: + logger.debug("Instance %s: no shim info", instance_model.name) + + running_shim_version = shim_client.get_version_string() + if ( + installed_shim_version is None + or installed_shim_version == running_shim_version + or installation_requested + or any(component.status == ComponentStatus.INSTALLING for component in components) + or not shim_client.is_safe_to_restart() + ): + return + + if shim_client.shutdown(force=False): + logger.debug( + "Instance %s: restarting shim %s -> %s", + instance_model.name, + running_shim_version, + installed_shim_version, + ) + else: + logger.debug("Instance %s: cannot restart shim", instance_model.name) + + +def _maybe_install_runner( + instance_model: InstanceModel, + shim_client: runner_client.ShimClient, + runner_info: ComponentInfo, +) -> bool: + expected_version = get_dstack_runner_version() + if expected_version is None: + logger.debug("Cannot determine the expected runner version") + return False + + installed_version = runner_info.version + logger.debug( + "Instance %s: runner status=%s installed_version=%s", + instance_model.name, + runner_info.status.value, + installed_version or "(no version)", + ) + if runner_info.status == ComponentStatus.INSTALLING: + logger.debug("Instance %s: runner is already being installed", instance_model.name) + return False + if installed_version and installed_version == expected_version: + logger.debug("Instance %s: expected runner version already installed", instance_model.name) + return False + + url = get_dstack_runner_download_url( + arch=_get_instance_cpu_arch(instance_model), + version=expected_version, + ) + 
logger.debug( + "Instance %s: installing runner %s -> %s from %s", + instance_model.name, + installed_version or "(no version)", + expected_version, + url, + ) + try: + shim_client.install_runner(url) + return True + except requests.RequestException as exc: + logger.warning("Instance %s: shim.install_runner(): %s", instance_model.name, exc) + return False + + +def _maybe_install_shim( + instance_model: InstanceModel, + shim_client: runner_client.ShimClient, + shim_info: ComponentInfo, +) -> bool: + expected_version = get_dstack_shim_version() + if expected_version is None: + logger.debug("Cannot determine the expected shim version") + return False + + installed_version = shim_info.version + logger.debug( + "Instance %s: shim status=%s installed_version=%s running_version=%s", + instance_model.name, + shim_info.status.value, + installed_version or "(no version)", + shim_client.get_version_string(), + ) + if shim_info.status == ComponentStatus.INSTALLING: + logger.debug("Instance %s: shim is already being installed", instance_model.name) + return False + if installed_version and installed_version == expected_version: + logger.debug("Instance %s: expected shim version already installed", instance_model.name) + return False + + url = get_dstack_shim_download_url( + arch=_get_instance_cpu_arch(instance_model), + version=expected_version, + ) + logger.debug( + "Instance %s: installing shim %s -> %s from %s", + instance_model.name, + installed_version or "(no version)", + expected_version, + url, + ) + try: + shim_client.install_shim(url) + return True + except requests.RequestException as exc: + logger.warning("Instance %s: shim.install_shim(): %s", instance_model.name, exc) + return False + + +def _get_instance_cpu_arch(instance_model: InstanceModel) -> Optional[gpuhunt.CPUArchitecture]: + job_provisioning_data = get_instance_provisioning_data(instance_model) + if job_provisioning_data is None: + return None + return 
job_provisioning_data.instance_type.resources.cpu_arch + + +async def _process_terminate(instance_model: InstanceModel) -> _ProcessResult: + result = _ProcessResult() + now = get_current_datetime() + if ( + instance_model.last_termination_retry_at is not None + and _next_termination_retry_at(instance_model.last_termination_retry_at) > now + ): + return result + + job_provisioning_data = get_instance_provisioning_data(instance_model) + if job_provisioning_data is not None and job_provisioning_data.backend != BackendType.REMOTE: + backend = await backends_services.get_project_backend_by_type( + project=instance_model.project, + backend_type=job_provisioning_data.backend, + ) + if backend is None: + logger.error( + "Failed to terminate instance %s. Backend %s not available.", + instance_model.name, + job_provisioning_data.backend, + ) + else: + logger.debug("Terminating runner instance %s", job_provisioning_data.hostname) + try: + await run_async( + backend.compute().terminate_instance, + job_provisioning_data.instance_id, + job_provisioning_data.region, + job_provisioning_data.backend_data, + ) + except Exception as exc: + first_retry_at = instance_model.first_termination_retry_at + if first_retry_at is None: + first_retry_at = now + result.instance_update_map["first_termination_retry_at"] = NOW_PLACEHOLDER + result.instance_update_map["last_termination_retry_at"] = NOW_PLACEHOLDER + if _next_termination_retry_at(now) < _get_termination_deadline(first_retry_at): + if isinstance(exc, NotYetTerminated): + logger.debug( + "Instance %s termination in progress: %s", + instance_model.name, + exc, + ) + else: + logger.warning( + "Failed to terminate instance %s. Will retry. Error: %r", + instance_model.name, + exc, + exc_info=not isinstance(exc, BackendError), + ) + return result + logger.error( + "Failed all attempts to terminate instance %s." + " Please terminate the instance manually to avoid unexpected charges." 
+ " Error: %r", + instance_model.name, + exc, + exc_info=not isinstance(exc, BackendError), + ) + + result.instance_update_map["deleted"] = True + result.instance_update_map["deleted_at"] = NOW_PLACEHOLDER + result.instance_update_map["finished_at"] = NOW_PLACEHOLDER + _set_status_update( + update_map=result.instance_update_map, + deferred_events=result.deferred_events, + instance_model=instance_model, + new_status=InstanceStatus.TERMINATED, + ) + return result + + +def _next_termination_retry_at(last_termination_retry_at: datetime.datetime) -> datetime.datetime: + return last_termination_retry_at + TERMINATION_RETRY_TIMEOUT + + +def _get_termination_deadline(first_termination_retry_at: datetime.datetime) -> datetime.datetime: + return first_termination_retry_at + TERMINATION_RETRY_MAX_DURATION + + +def _need_to_wait_fleet_provisioning( + instance_model: InstanceModel, + master_instance_model: InstanceModel, +) -> bool: + if instance_model.fleet is None: + return False + if ( + instance_model.id == master_instance_model.id + or master_instance_model.job_provisioning_data is not None + or master_instance_model.status == InstanceStatus.TERMINATED + ): + return False + return is_cloud_cluster(instance_model.fleet) + + +def _get_instance_offer_for_instance( + instance_offer: InstanceOfferWithAvailability, + instance_model: InstanceModel, + master_instance_model: InstanceModel, +) -> InstanceOfferWithAvailability: + if instance_model.fleet is None: + return instance_offer + fleet = fleet_model_to_fleet(instance_model.fleet) + if fleet.spec.configuration.placement == InstanceGroupPlacement.CLUSTER: + master_job_provisioning_data = get_instance_provisioning_data(master_instance_model) + return get_instance_offer_with_restricted_az( + instance_offer=instance_offer, + master_job_provisioning_data=master_job_provisioning_data, + ) + return instance_offer + + +def _get_instance_idle_duration(instance_model: InstanceModel) -> datetime.timedelta: + last_time = 
instance_model.created_at + if instance_model.last_job_processed_at is not None: + last_time = instance_model.last_job_processed_at + return get_current_datetime() - last_time + + +def _get_provisioning_deadline( + instance_model: InstanceModel, + job_provisioning_data: JobProvisioningData, +) -> datetime.datetime: + assert instance_model.started_at is not None + timeout_interval = get_provisioning_timeout( + backend_type=job_provisioning_data.get_base_backend(), + instance_type_name=job_provisioning_data.instance_type.name, + ) + return instance_model.started_at + timeout_interval + + +def _ssh_keys_to_pkeys(ssh_keys: list[SSHKey]) -> list[PKey]: + return [pkey_from_str(ssh_key.private) for ssh_key in ssh_keys if ssh_key.private is not None] + + +def _set_status_update( + update_map: Union[_InstanceUpdateMap, _SiblingInstanceUpdateMap], + deferred_events: list[_DeferredEvent], + instance_model: InstanceModel, + new_status: InstanceStatus, + termination_reason: object = _UNSET, + termination_reason_message: object = _UNSET, +) -> None: + old_status = instance_model.status + if old_status == new_status: + if termination_reason is not _UNSET: + update_map["termination_reason"] = cast( + Optional[InstanceTerminationReason], termination_reason + ) + if termination_reason_message is not _UNSET: + update_map["termination_reason_message"] = cast( + Optional[str], termination_reason_message + ) + return + + effective_termination_reason = instance_model.termination_reason + if termination_reason is not _UNSET: + effective_termination_reason = cast( + Optional[InstanceTerminationReason], termination_reason + ) + update_map["termination_reason"] = effective_termination_reason + + effective_termination_reason_message = instance_model.termination_reason_message + if termination_reason_message is not _UNSET: + effective_termination_reason_message = cast(Optional[str], termination_reason_message) + update_map["termination_reason_message"] = effective_termination_reason_message + 
+ update_map["status"] = new_status + deferred_events.append( + _DeferredEvent( + message=_format_status_change_message( + old_status=old_status, + new_status=new_status, + termination_reason=effective_termination_reason, + termination_reason_message=effective_termination_reason_message, + ), + project_id=instance_model.project_id, + instance_id=instance_model.id, + instance_name=instance_model.name, + ) + ) + + +def _set_health_update( + update_map: _InstanceUpdateMap, + deferred_events: list[_DeferredEvent], + instance_model: InstanceModel, + health: HealthStatus, +) -> None: + if instance_model.health == health: + return + update_map["health"] = health + deferred_events.append( + _DeferredEvent( + message=f"Instance health changed {instance_model.health.upper()} -> {health.upper()}", + project_id=instance_model.project_id, + instance_id=instance_model.id, + instance_name=instance_model.name, + ) + ) + + +def _set_unreachable_update( + update_map: _InstanceUpdateMap, + deferred_events: list[_DeferredEvent], + instance_model: InstanceModel, + unreachable: bool, +) -> None: + if not instance_model.status.is_available() or instance_model.unreachable == unreachable: + return + update_map["unreachable"] = unreachable + deferred_events.append( + _DeferredEvent( + message="Instance became unreachable" if unreachable else "Instance became reachable", + project_id=instance_model.project_id, + instance_id=instance_model.id, + instance_name=instance_model.name, + ) + ) + + +def _format_status_change_message( + old_status: InstanceStatus, + new_status: InstanceStatus, + termination_reason: Optional[InstanceTerminationReason], + termination_reason_message: Optional[str], +) -> str: + message = f"Instance status changed {old_status.upper()} -> {new_status.upper()}" + if new_status == InstanceStatus.TERMINATING or ( + new_status == InstanceStatus.TERMINATED and old_status != InstanceStatus.TERMINATING + ): + if termination_reason is None: + raise ValueError( + 
f"termination_reason must be set when switching to {new_status.upper()} status" + ) + if ( + termination_reason == InstanceTerminationReason.ERROR + and not termination_reason_message + ): + raise ValueError( + "termination_reason_message must be set when termination_reason is ERROR" + ) + message += f". Termination reason: {termination_reason.upper()}" + if termination_reason_message: + message += f" ({termination_reason_message})" + return message + + +async def _apply_process_result(item: InstancePipelineItem, result: _ProcessResult) -> None: + async with get_session_ctx() as session: + for placement_group_create in result.placement_group_creates: + session.add(PlacementGroupModel(**placement_group_create)) + if result.health_check_create is not None: + session.add(InstanceHealthCheckModel(**result.health_check_create)) + await session.flush() + + now = get_current_datetime() + resolve_now_placeholders(result.instance_update_map, now=now) + resolve_now_placeholders(result.sibling_update_rows, now=now) + res = await session.execute( + update(InstanceModel) + .where( + InstanceModel.id == item.id, + InstanceModel.lock_token == item.lock_token, + ) + .values(**result.instance_update_map) + .returning(InstanceModel.id) + ) + updated_ids = list(res.scalars().all()) + if len(updated_ids) == 0: + _log_lock_token_mismatch(item, action="update after processing") + await session.rollback() + return + + if result.sibling_update_rows: + await session.execute( + update(InstanceModel).execution_options(synchronize_session=False), + result.sibling_update_rows, + ) + + if result.schedule_pg_deletion_fleet_id is not None: + await schedule_fleet_placement_groups_deletion( + session=session, + fleet_id=result.schedule_pg_deletion_fleet_id, + except_placement_group_ids=result.schedule_pg_deletion_except_ids, + ) + + for deferred_event in result.deferred_events: + events.emit( + session=session, + message=deferred_event.message, + actor=events.SystemActor(), + targets=[ + 
events.Target( + type=EventTargetType.INSTANCE, + project_id=deferred_event.project_id, + id=deferred_event.instance_id, + name=deferred_event.instance_name, + ) + ], + ) + await session.commit() + + +def _log_lock_token_mismatch(item: InstancePipelineItem, action: str) -> None: + logger.warning( + "Failed to %s %s item %s: lock_token mismatch." + " The item is expected to be processed and updated on another fetch iteration.", + action, + item.__tablename__, + item.id, + ) diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py index 15c5488da5..843b1d798b 100644 --- a/src/dstack/_internal/server/models.py +++ b/src/dstack/_internal/server/models.py @@ -619,7 +619,7 @@ class FleetModel(PipelineModelMixin, BaseModel): ) -class InstanceModel(BaseModel): +class InstanceModel(PipelineModelMixin, BaseModel): __tablename__ = "instances" id: Mapped[uuid.UUID] = mapped_column( From bb112371098c5fc8a016577744c4f38335578840 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Tue, 3 Mar 2026 11:08:47 +0500 Subject: [PATCH 12/51] Fix volumes pipeline processing active --- .../_internal/server/background/pipeline_tasks/volumes.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/volumes.py b/src/dstack/_internal/server/background/pipeline_tasks/volumes.py index c7a8f5761a..16a0633af2 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/volumes.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/volumes.py @@ -204,8 +204,6 @@ async def process(self, item: VolumePipelineItem): await _process_to_be_deleted_item(item) elif item.status == VolumeStatus.SUBMITTED: await _process_submitted_item(item) - elif item.status == VolumeStatus.ACTIVE: - pass async def _process_submitted_item(item: VolumePipelineItem): From 3bf900623ceddeb46fd8ee7b9a41788ca9b92e69 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Tue, 3 Mar 2026 11:23:05 +0500 Subject: [PATCH 13/51] Refactor 
log_lock_token --- .../server/background/pipeline_tasks/base.py | 38 +++++++++++++ .../pipeline_tasks/compute_groups.py | 16 ++---- .../background/pipeline_tasks/fleets.py | 22 +++----- .../background/pipeline_tasks/gateways.py | 54 +++++-------------- .../background/pipeline_tasks/instances.py | 39 +++++--------- .../pipeline_tasks/placement_groups.py | 16 ++---- .../background/pipeline_tasks/volumes.py | 30 +++-------- 7 files changed, 83 insertions(+), 132 deletions(-) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/base.py b/src/dstack/_internal/server/background/pipeline_tasks/base.py index aa5af9a4a3..a68ad0ff97 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/base.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/base.py @@ -1,4 +1,5 @@ import asyncio +import logging import math import random import uuid @@ -416,3 +417,40 @@ def resolve_now_placeholders(update_values: _ResolveNowInput, now: datetime): for key, value in update_values.items(): if value is NOW_PLACEHOLDER: update_values[key] = now + + +def log_lock_token_mismatch( + logger: logging.Logger, + item: PipelineItem, + action: str = "process", +) -> None: + logger.warning( + "Failed to %s %s item %s: lock_token mismatch." + " The item is expected to be processed and updated on another fetch iteration.", + action, + item.__tablename__, + item.id, + ) + + +def log_lock_token_changed_after_processing( + logger: logging.Logger, + item: PipelineItem, + action: str = "update", + expected_outcome: str = "updated", +) -> None: + logger.warning( + "Failed to %s %s item %s after processing: lock_token changed." + " The item is expected to be processed and %s on another fetch iteration.", + action, + item.__tablename__, + item.id, + expected_outcome, + ) + + +def log_lock_token_changed_on_reset(logger: logging.Logger) -> None: + logger.warning( + "Failed to reset lock: lock_token changed." 
+ " The item is expected to be processed and updated on another fetch iteration." + ) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/compute_groups.py b/src/dstack/_internal/server/background/pipeline_tasks/compute_groups.py index 0ee2975eb2..fd4be167cb 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/compute_groups.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/compute_groups.py @@ -20,6 +20,8 @@ PipelineItem, UpdateMapDateTime, Worker, + log_lock_token_changed_after_processing, + log_lock_token_mismatch, resolve_now_placeholders, set_processed_update_map_fields, set_unlock_update_map_fields, @@ -194,12 +196,7 @@ async def process(self, item: PipelineItem): ) compute_group_model = res.unique().scalar_one_or_none() if compute_group_model is None: - logger.warning( - "Failed to process %s item %s: lock_token mismatch." - " The item is expected to be processed and updated on another fetch iteration.", - item.__tablename__, - item.id, - ) + log_lock_token_mismatch(logger, item) return result = _TerminateResult() @@ -228,12 +225,7 @@ async def process(self, item: PipelineItem): ) updated_ids = list(res.scalars().all()) if len(updated_ids) == 0: - logger.warning( - "Failed to update %s item %s after processing: lock_token changed." 
- " The item is expected to be processed and updated on another fetch iteration.", - item.__tablename__, - item.id, - ) + log_lock_token_changed_after_processing(logger, item) return if not result.instances_update_map: return diff --git a/src/dstack/_internal/server/background/pipeline_tasks/fleets.py b/src/dstack/_internal/server/background/pipeline_tasks/fleets.py index 55ffcd7f94..6d6295de5e 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/fleets.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/fleets.py @@ -20,6 +20,9 @@ PipelineItem, UpdateMapDateTime, Worker, + log_lock_token_changed_after_processing, + log_lock_token_changed_on_reset, + log_lock_token_mismatch, resolve_now_placeholders, set_processed_update_map_fields, set_unlock_update_map_fields, @@ -210,12 +213,7 @@ async def process(self, item: PipelineItem): ) fleet_model = res.unique().scalar_one_or_none() if fleet_model is None: - logger.warning( - "Failed to process %s item %s: lock_token mismatch." - " The item is expected to be processed and updated on another fetch iteration.", - item.__tablename__, - item.id, - ) + log_lock_token_mismatch(logger, item) return instance_lock, _ = get_locker(get_db().dialect_name).get_lockset( @@ -262,10 +260,7 @@ async def process(self, item: PipelineItem): ) ) if res.rowcount == 0: # pyright: ignore[reportAttributeAccessIssue] - logger.warning( - "Failed to reset lock: lock_token changed." - " The item is expected to be processed and updated on another fetch iteration." - ) + log_lock_token_changed_on_reset(logger) return # TODO: Lock instance models in the DB @@ -297,12 +292,7 @@ async def process(self, item: PipelineItem): ) updated_ids = list(res.scalars().all()) if len(updated_ids) == 0: - logger.warning( - "Failed to update %s item %s after processing: lock_token changed." 
- " The item is expected to be processed and updated on another fetch iteration.", - item.__tablename__, - item.id, - ) + log_lock_token_changed_after_processing(logger, item) # TODO: Clean up fleet. return diff --git a/src/dstack/_internal/server/background/pipeline_tasks/gateways.py b/src/dstack/_internal/server/background/pipeline_tasks/gateways.py index 2d5f0a947b..81ba2ae708 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/gateways.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/gateways.py @@ -18,6 +18,8 @@ Pipeline, PipelineItem, Worker, + log_lock_token_changed_after_processing, + log_lock_token_mismatch, resolve_now_placeholders, set_processed_update_map_fields, set_unlock_update_map_fields, @@ -219,12 +221,7 @@ async def _process_submitted_item(item: GatewayPipelineItem): ) gateway_model = res.unique().scalar_one_or_none() if gateway_model is None: - logger.warning( - "Failed to process %s item %s: lock_token mismatch." - " The item is expected to be processed and updated on another fetch iteration.", - item.__tablename__, - item.id, - ) + log_lock_token_mismatch(logger, item) return result = await _process_submitted_gateway(gateway_model) @@ -251,12 +248,7 @@ async def _process_submitted_item(item: GatewayPipelineItem): ) updated_ids = list(res.scalars().all()) if len(updated_ids) == 0: - logger.warning( - "Failed to update %s item %s after processing: lock_token changed." - " The item is expected to be processed and updated on another fetch iteration.", - item.__tablename__, - item.id, - ) + log_lock_token_changed_after_processing(logger, item) # TODO: Clean up gateway_compute_model. return emit_gateway_status_change_event( @@ -345,12 +337,7 @@ async def _process_provisioning_item(item: GatewayPipelineItem): ) gateway_model = res.unique().scalar_one_or_none() if gateway_model is None: - logger.warning( - "Failed to process %s item %s: lock_token mismatch." 
- " The item is expected to be processed and updated on another fetch iteration.", - item.__tablename__, - item.id, - ) + log_lock_token_mismatch(logger, item) return result = await _process_provisioning_gateway(gateway_model) @@ -372,12 +359,7 @@ async def _process_provisioning_item(item: GatewayPipelineItem): ) updated_ids = list(res.scalars().all()) if len(updated_ids) == 0: - logger.warning( - "Failed to update %s item %s after processing: lock_token changed." - " The item is expected to be processed and updated on another fetch iteration.", - item.__tablename__, - item.id, - ) + log_lock_token_changed_after_processing(logger, item) return emit_gateway_status_change_event( session=session, @@ -464,12 +446,7 @@ async def _process_to_be_deleted_item(item: GatewayPipelineItem): ) gateway_model = res.unique().scalar_one_or_none() if gateway_model is None: - logger.warning( - "Failed to process %s item %s: lock_token mismatch." - " The item is expected to be processed and updated on another fetch iteration.", - item.__tablename__, - item.id, - ) + log_lock_token_mismatch(logger, item) return result = await _process_to_be_deleted_gateway(gateway_model) @@ -485,11 +462,11 @@ async def _process_to_be_deleted_item(item: GatewayPipelineItem): ) deleted_ids = list(res.scalars().all()) if len(deleted_ids) == 0: - logger.warning( - "Failed to delete %s item %s after processing: lock_token changed." - " The item is expected to be processed and deleted on another fetch iteration.", - item.__tablename__, - item.id, + log_lock_token_changed_after_processing( + logger, + item, + action="delete", + expected_outcome="deleted", ) return events.emit( @@ -514,12 +491,7 @@ async def _process_to_be_deleted_item(item: GatewayPipelineItem): ) updated_ids = list(res.scalars().all()) if len(updated_ids) == 0: - logger.warning( - "Failed to update %s item %s after processing: lock_token changed." 
- " The item is expected to be processed and updated on another fetch iteration.", - item.__tablename__, - item.id, - ) + log_lock_token_changed_after_processing(logger, item) return if result.gateway_compute_update_map: diff --git a/src/dstack/_internal/server/background/pipeline_tasks/instances.py b/src/dstack/_internal/server/background/pipeline_tasks/instances.py index 23f00eff16..ffc9b92bc3 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/instances.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/instances.py @@ -73,6 +73,8 @@ PipelineItem, UpdateMapDateTime, Worker, + log_lock_token_changed_after_processing, + log_lock_token_mismatch, resolve_now_placeholders, set_processed_update_map_fields, set_unlock_update_map_fields, @@ -315,7 +317,7 @@ async def process(self, item: InstancePipelineItem): async with get_session_ctx() as session: instance_model = await _refetch_locked_instance_status(session=session, item=item) if instance_model is None: - _log_lock_token_mismatch(item, action="process") + log_lock_token_mismatch(logger, item) return status = instance_model.status @@ -330,10 +332,9 @@ async def process(self, item: InstancePipelineItem): result = await _process_busy_item(item) elif status == InstanceStatus.TERMINATING: result = await _process_terminating_item(item) - if result is None: + # FIXME: Item won't be unlocked!!! 
return - set_processed_update_map_fields(result.instance_update_map) set_unlock_update_map_fields(result.instance_update_map) await _apply_process_result(item=item, result=result) @@ -419,7 +420,7 @@ async def _process_pending_item(item: InstancePipelineItem) -> Optional[_Process item=item, ) if instance_model is None: - _log_lock_token_mismatch(item, action="process") + log_lock_token_mismatch(logger, item) return None if is_ssh_instance(instance_model): return await _process_add_remote(instance_model) @@ -430,7 +431,7 @@ async def _process_provisioning_item(item: InstancePipelineItem) -> Optional[_Pr async with get_session_ctx() as session: instance_model = await _refetch_locked_instance_for_check(session=session, item=item) if instance_model is None: - _log_lock_token_mismatch(item, action="process") + log_lock_token_mismatch(logger, item) return None return await _process_instance_check(instance_model) @@ -439,7 +440,7 @@ async def _process_idle_item(item: InstancePipelineItem) -> Optional[_ProcessRes async with get_session_ctx() as session: instance_model = await _refetch_locked_instance_for_idle(session=session, item=item) if instance_model is None: - _log_lock_token_mismatch(item, action="process") + log_lock_token_mismatch(logger, item) return None idle_result = _process_idle_timeout(instance_model) if idle_result is not None: @@ -451,7 +452,7 @@ async def _process_busy_item(item: InstancePipelineItem) -> Optional[_ProcessRes async with get_session_ctx() as session: instance_model = await _refetch_locked_instance_for_check(session=session, item=item) if instance_model is None: - _log_lock_token_mismatch(item, action="process") + log_lock_token_mismatch(logger, item) return None return await _process_instance_check(instance_model) @@ -463,7 +464,7 @@ async def _process_terminating_item(item: InstancePipelineItem) -> Optional[_Pro item=item, ) if instance_model is None: - _log_lock_token_mismatch(item, action="process") + log_lock_token_mismatch(logger, item) 
return None return await _process_terminate(instance_model) @@ -796,10 +797,10 @@ def _deploy_instance( remote_details.ssh_proxy, ssh_proxy_pkeys, ) as client: - logger.info("Connected to %s %s", remote_details.ssh_user, remote_details.host) + logger.debug("Connected to %s %s", remote_details.ssh_user, remote_details.host) arch = detect_cpu_arch(client) - logger.info("%s: CPU arch is %s", remote_details.host, arch) + logger.debug("%s: CPU arch is %s", remote_details.host, arch) shim_pre_start_commands = get_shim_pre_start_commands(arch=arch) run_pre_start_commands(client, shim_pre_start_commands, authorized_keys) @@ -1138,12 +1139,6 @@ async def _find_or_create_suitable_placement_group_state( ) return None - logger.info( - "Created placement group %s in %s/%s", - placement_group.name, - placement_group.configuration.backend.value, - placement_group.configuration.region, - ) placement_group.provisioning_data = provisioning_data return _PlacementGroupState( id=placement_group_id, @@ -1873,7 +1868,7 @@ async def _apply_process_result(item: InstancePipelineItem, result: _ProcessResu ) updated_ids = list(res.scalars().all()) if len(updated_ids) == 0: - _log_lock_token_mismatch(item, action="update after processing") + log_lock_token_changed_after_processing(logger, item) await session.rollback() return @@ -1905,13 +1900,3 @@ async def _apply_process_result(item: InstancePipelineItem, result: _ProcessResu ], ) await session.commit() - - -def _log_lock_token_mismatch(item: InstancePipelineItem, action: str) -> None: - logger.warning( - "Failed to %s %s item %s: lock_token mismatch." 
- " The item is expected to be processed and updated on another fetch iteration.", - action, - item.__tablename__, - item.id, - ) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/placement_groups.py b/src/dstack/_internal/server/background/pipeline_tasks/placement_groups.py index 703cfe1548..552ae00dc8 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/placement_groups.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/placement_groups.py @@ -18,6 +18,8 @@ PipelineItem, UpdateMapDateTime, Worker, + log_lock_token_changed_after_processing, + log_lock_token_mismatch, resolve_now_placeholders, set_processed_update_map_fields, set_unlock_update_map_fields, @@ -189,12 +191,7 @@ async def process(self, item: PipelineItem): ) placement_group_model = res.unique().scalar_one_or_none() if placement_group_model is None: - logger.warning( - "Failed to process %s item %s: lock_token mismatch." - " The item is expected to be processed and updated on another fetch iteration.", - item.__tablename__, - item.id, - ) + log_lock_token_mismatch(logger, item) return result = await _delete_placement_group(placement_group_model) @@ -217,12 +214,7 @@ async def process(self, item: PipelineItem): ) updated_ids = list(res.scalars().all()) if len(updated_ids) == 0: - logger.warning( - "Failed to update %s item %s after processing: lock_token changed." 
- " The item is expected to be processed and updated on another fetch iteration.", - item.__tablename__, - item.id, - ) + log_lock_token_changed_after_processing(logger, item) class _PlacementGroupUpdateMap(ItemUpdateMap, total=False): diff --git a/src/dstack/_internal/server/background/pipeline_tasks/volumes.py b/src/dstack/_internal/server/background/pipeline_tasks/volumes.py index 16a0633af2..81d94c361b 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/volumes.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/volumes.py @@ -19,6 +19,8 @@ PipelineItem, UpdateMapDateTime, Worker, + log_lock_token_changed_after_processing, + log_lock_token_mismatch, resolve_now_placeholders, set_processed_update_map_fields, set_unlock_update_map_fields, @@ -225,12 +227,7 @@ async def _process_submitted_item(item: VolumePipelineItem): ) volume_model = res.unique().scalar_one_or_none() if volume_model is None: - logger.warning( - "Failed to process %s item %s: lock_token mismatch." - " The item is expected to be processed and updated on another fetch iteration.", - item.__tablename__, - item.id, - ) + log_lock_token_mismatch(logger, item) return result = await _process_submitted_volume(volume_model) @@ -251,12 +248,7 @@ async def _process_submitted_item(item: VolumePipelineItem): ) updated_ids = list(res.scalars().all()) if len(updated_ids) == 0: - logger.warning( - "Failed to update %s item %s after processing: lock_token changed." - " The item is expected to be processed and updated on another fetch iteration.", - item.__tablename__, - item.id, - ) + log_lock_token_changed_after_processing(logger, item) # TODO: Clean up volume. return emit_volume_status_change_event( @@ -367,12 +359,7 @@ async def _process_to_be_deleted_item(item: VolumePipelineItem): ) volume_model = res.unique().scalar_one_or_none() if volume_model is None: - logger.warning( - "Failed to process %s item %s: lock_token mismatch." 
- " The item is expected to be processed and updated on another fetch iteration.", - item.__tablename__, - item.id, - ) + log_lock_token_mismatch(logger, item) return result = await _process_to_be_deleted_volume(volume_model) @@ -394,12 +381,7 @@ async def _process_to_be_deleted_item(item: VolumePipelineItem): ) updated_ids = list(res.scalars().all()) if len(updated_ids) == 0: - logger.warning( - "Failed to update %s item %s after processing: lock_token changed." - " The item is expected to be processed and updated on another fetch iteration.", - item.__tablename__, - item.id, - ) + log_lock_token_changed_after_processing(logger, item) return events.emit( session, From b508c3482e00b4bbebe33caeff420253d0a88479 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Tue, 3 Mar 2026 12:04:08 +0500 Subject: [PATCH 14/51] Build instance events from update map --- .../pipeline_tasks/compute_groups.py | 2 + .../background/pipeline_tasks/instances.py | 233 ++++++++++++------ .../_internal/server/services/instances.py | 24 +- 3 files changed, 176 insertions(+), 83 deletions(-) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/compute_groups.py b/src/dstack/_internal/server/background/pipeline_tasks/compute_groups.py index fd4be167cb..69ce3e7998 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/compute_groups.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/compute_groups.py @@ -241,6 +241,8 @@ async def process(self, item: PipelineItem): instance_model=instance_model, old_status=instance_model.status, new_status=InstanceStatus.TERMINATED, + termination_reason=instance_model.termination_reason, + termination_reason_message=instance_model.termination_reason_message, ) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/instances.py b/src/dstack/_internal/server/background/pipeline_tasks/instances.py index ffc9b92bc3..465622ff23 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/instances.py +++ 
b/src/dstack/_internal/server/background/pipeline_tasks/instances.py @@ -104,12 +104,14 @@ is_cloud_cluster, ) from dstack._internal.server.services.instances import ( + emit_instance_status_change_event, get_instance_configuration, get_instance_profile, get_instance_provisioning_data, get_instance_remote_connection_info, get_instance_requirements, get_instance_ssh_private_keys, + get_instance_status_change_message, is_ssh_instance, remove_dangling_tasks_from_instance, ) @@ -388,7 +390,7 @@ class _PlacementGroupCreate(TypedDict): @dataclass -class _DeferredEvent: +class _SiblingDeferredEvent: message: str project_id: uuid.UUID instance_id: uuid.UUID @@ -406,7 +408,7 @@ class _PlacementGroupState: class _ProcessResult: instance_update_map: _InstanceUpdateMap = field(default_factory=_InstanceUpdateMap) sibling_update_rows: list[_SiblingInstanceUpdateMap] = field(default_factory=list) - deferred_events: list[_DeferredEvent] = field(default_factory=list) + sibling_deferred_events: list[_SiblingDeferredEvent] = field(default_factory=list) health_check_create: Optional[_HealthCheckCreate] = None placement_group_creates: list[_PlacementGroupCreate] = field(default_factory=list) schedule_pg_deletion_fleet_id: Optional[uuid.UUID] = None @@ -576,7 +578,6 @@ def _process_idle_timeout(instance_model: InstanceModel) -> Optional[_ProcessRes result = _ProcessResult() _set_status_update( update_map=result.instance_update_map, - deferred_events=result.deferred_events, instance_model=instance_model, new_status=InstanceStatus.TERMINATING, termination_reason=InstanceTerminationReason.IDLE_TIMEOUT, @@ -605,7 +606,6 @@ async def _process_add_remote(instance_model: InstanceModel) -> _ProcessResult: if retry_duration_deadline < get_current_datetime(): _set_status_update( update_map=result.instance_update_map, - deferred_events=result.deferred_events, instance_model=instance_model, new_status=InstanceStatus.TERMINATED, termination_reason=InstanceTerminationReason.PROVISIONING_TIMEOUT, @@ 
-626,7 +626,6 @@ async def _process_add_remote(instance_model: InstanceModel) -> _ProcessResult: except (ValueError, PasswordRequiredException): _set_status_update( update_map=result.instance_update_map, - deferred_events=result.deferred_events, instance_model=instance_model, new_status=InstanceStatus.TERMINATED, termination_reason=InstanceTerminationReason.ERROR, @@ -665,7 +664,6 @@ async def _process_add_remote(instance_model: InstanceModel) -> _ProcessResult: logger.exception("%s: unexpected error when adding SSH instance", fmt(instance_model)) _set_status_update( update_map=result.instance_update_map, - deferred_events=result.deferred_events, instance_model=instance_model, new_status=InstanceStatus.TERMINATED, termination_reason=InstanceTerminationReason.ERROR, @@ -679,7 +677,6 @@ async def _process_add_remote(instance_model: InstanceModel) -> _ProcessResult: except _SSHInstanceNetworkResolutionError as exc: _set_status_update( update_map=result.instance_update_map, - deferred_events=result.deferred_events, instance_model=instance_model, new_status=InstanceStatus.TERMINATED, termination_reason=InstanceTerminationReason.ERROR, @@ -695,7 +692,6 @@ async def _process_add_remote(instance_model: InstanceModel) -> _ProcessResult: if not divisible: _set_status_update( update_map=result.instance_update_map, - deferred_events=result.deferred_events, instance_model=instance_model, new_status=InstanceStatus.TERMINATED, termination_reason=InstanceTerminationReason.ERROR, @@ -731,7 +727,6 @@ async def _process_add_remote(instance_model: InstanceModel) -> _ProcessResult: _set_status_update( update_map=result.instance_update_map, - deferred_events=result.deferred_events, instance_model=instance_model, new_status=InstanceStatus.IDLE if health else InstanceStatus.PROVISIONING, ) @@ -861,7 +856,6 @@ async def _process_create_instance(instance_model: InstanceModel) -> _ProcessRes ) _set_status_update( update_map=result.instance_update_map, - 
deferred_events=result.deferred_events, instance_model=instance_model, new_status=InstanceStatus.TERMINATED, termination_reason=InstanceTerminationReason.ERROR, @@ -967,7 +961,6 @@ async def _process_create_instance(instance_model: InstanceModel) -> _ProcessRes _set_status_update( update_map=result.instance_update_map, - deferred_events=result.deferred_events, instance_model=instance_model, new_status=InstanceStatus.PROVISIONING, ) @@ -988,7 +981,6 @@ async def _process_create_instance(instance_model: InstanceModel) -> _ProcessRes _set_status_update( update_map=result.instance_update_map, - deferred_events=result.deferred_events, instance_model=instance_model, new_status=InstanceStatus.TERMINATED, termination_reason=InstanceTerminationReason.NO_OFFERS, @@ -1005,13 +997,24 @@ async def _process_create_instance(instance_model: InstanceModel) -> _ProcessRes sibling_update_map = _SiblingInstanceUpdateMap(id=sibling_instance_model.id) _set_status_update( update_map=sibling_update_map, - deferred_events=result.deferred_events, instance_model=sibling_instance_model, new_status=InstanceStatus.TERMINATED, termination_reason=InstanceTerminationReason.MASTER_FAILED, ) if len(sibling_update_map) > 1: result.sibling_update_rows.append(sibling_update_map) + _append_sibling_status_event( + deferred_events=result.sibling_deferred_events, + instance_model=sibling_instance_model, + new_status=InstanceStatus.TERMINATED, + termination_reason=cast( + Optional[InstanceTerminationReason], + sibling_update_map.get("termination_reason"), + ), + termination_reason_message=cast( + Optional[str], sibling_update_map.get("termination_reason_message") + ), + ) return result @@ -1163,7 +1166,6 @@ async def _process_instance_check(instance_model: InstanceModel) -> _ProcessResu ): _set_status_update( update_map=result.instance_update_map, - deferred_events=result.deferred_events, instance_model=instance_model, new_status=InstanceStatus.TERMINATING, 
termination_reason=InstanceTerminationReason.JOB_FINISHED, @@ -1189,7 +1191,6 @@ async def _process_instance_check(instance_model: InstanceModel) -> _ProcessResu if instance_model.status == InstanceStatus.PROVISIONING: _set_status_update( update_map=result.instance_update_map, - deferred_events=result.deferred_events, instance_model=instance_model, new_status=InstanceStatus.BUSY, ) @@ -1224,13 +1225,11 @@ async def _process_instance_check(instance_model: InstanceModel) -> _ProcessResu _set_health_update( update_map=result.instance_update_map, - deferred_events=result.deferred_events, instance_model=instance_model, health=health_status, ) _set_unreachable_update( update_map=result.instance_update_map, - deferred_events=result.deferred_events, instance_model=instance_model, unreachable=not instance_check.reachable, ) @@ -1240,7 +1239,6 @@ async def _process_instance_check(instance_model: InstanceModel) -> _ProcessResu if instance_model.status == InstanceStatus.PROVISIONING: _set_status_update( update_map=result.instance_update_map, - deferred_events=result.deferred_events, instance_model=instance_model, new_status=InstanceStatus.IDLE if not instance_model.jobs else InstanceStatus.BUSY, ) @@ -1261,7 +1259,6 @@ async def _process_instance_check(instance_model: InstanceModel) -> _ProcessResu if now > provisioning_deadline: _set_status_update( update_map=result.instance_update_map, - deferred_events=result.deferred_events, instance_model=instance_model, new_status=InstanceStatus.TERMINATING, termination_reason=InstanceTerminationReason.PROVISIONING_TIMEOUT, @@ -1272,7 +1269,6 @@ async def _process_instance_check(instance_model: InstanceModel) -> _ProcessResu if deadline is not None and now > deadline: _set_status_update( update_map=result.instance_update_map, - deferred_events=result.deferred_events, instance_model=instance_model, new_status=InstanceStatus.TERMINATING, termination_reason=InstanceTerminationReason.UNREACHABLE, @@ -1358,7 +1354,6 @@ async def 
_process_wait_for_instance_provisioning_data( if get_current_datetime() > provisioning_deadline: _set_status_update( update_map=result.instance_update_map, - deferred_events=result.deferred_events, instance_model=instance_model, new_status=InstanceStatus.TERMINATING, termination_reason=InstanceTerminationReason.PROVISIONING_TIMEOUT, @@ -1377,7 +1372,6 @@ async def _process_wait_for_instance_provisioning_data( ) _set_status_update( update_map=result.instance_update_map, - deferred_events=result.deferred_events, instance_model=instance_model, new_status=InstanceStatus.TERMINATING, termination_reason=InstanceTerminationReason.ERROR, @@ -1401,7 +1395,6 @@ async def _process_wait_for_instance_provisioning_data( ) _set_status_update( update_map=result.instance_update_map, - deferred_events=result.deferred_events, instance_model=instance_model, new_status=InstanceStatus.TERMINATING, termination_reason=InstanceTerminationReason.ERROR, @@ -1663,7 +1656,6 @@ async def _process_terminate(instance_model: InstanceModel) -> _ProcessResult: result.instance_update_map["finished_at"] = NOW_PLACEHOLDER _set_status_update( update_map=result.instance_update_map, - deferred_events=result.deferred_events, instance_model=instance_model, new_status=InstanceStatus.TERMINATED, ) @@ -1735,7 +1727,6 @@ def _ssh_keys_to_pkeys(ssh_keys: list[SSHKey]) -> list[PKey]: def _set_status_update( update_map: Union[_InstanceUpdateMap, _SiblingInstanceUpdateMap], - deferred_events: list[_DeferredEvent], instance_model: InstanceModel, new_status: InstanceStatus, termination_reason: object = _UNSET, @@ -1766,52 +1757,45 @@ def _set_status_update( update_map["termination_reason_message"] = effective_termination_reason_message update_map["status"] = new_status - deferred_events.append( - _DeferredEvent( - message=_format_status_change_message( - old_status=old_status, - new_status=new_status, - termination_reason=effective_termination_reason, - termination_reason_message=effective_termination_reason_message, 
- ), - project_id=instance_model.project_id, - instance_id=instance_model.id, - instance_name=instance_model.name, - ) - ) def _set_health_update( update_map: _InstanceUpdateMap, - deferred_events: list[_DeferredEvent], instance_model: InstanceModel, health: HealthStatus, ) -> None: if instance_model.health == health: return update_map["health"] = health - deferred_events.append( - _DeferredEvent( - message=f"Instance health changed {instance_model.health.upper()} -> {health.upper()}", - project_id=instance_model.project_id, - instance_id=instance_model.id, - instance_name=instance_model.name, - ) - ) def _set_unreachable_update( update_map: _InstanceUpdateMap, - deferred_events: list[_DeferredEvent], instance_model: InstanceModel, unreachable: bool, ) -> None: if not instance_model.status.is_available() or instance_model.unreachable == unreachable: return update_map["unreachable"] = unreachable + + +def _append_sibling_status_event( + deferred_events: list[_SiblingDeferredEvent], + instance_model: InstanceModel, + new_status: InstanceStatus, + termination_reason: Optional[InstanceTerminationReason], + termination_reason_message: Optional[str], +) -> None: + if instance_model.status == new_status: + return deferred_events.append( - _DeferredEvent( - message="Instance became unreachable" if unreachable else "Instance became reachable", + _SiblingDeferredEvent( + message=get_instance_status_change_message( + old_status=instance_model.status, + new_status=new_status, + termination_reason=termination_reason, + termination_reason_message=termination_reason_message, + ), project_id=instance_model.project_id, instance_id=instance_model.id, instance_name=instance_model.name, @@ -1819,35 +1803,105 @@ def _set_unreachable_update( ) -def _format_status_change_message( +def _get_effective_instance_status( + instance_model: InstanceModel, + update_map: _InstanceUpdateMap, +) -> InstanceStatus: + return cast(InstanceStatus, update_map.get("status", instance_model.status)) + + 
+def _get_effective_instance_termination_reason( + instance_model: InstanceModel, + update_map: _InstanceUpdateMap, +) -> Optional[InstanceTerminationReason]: + return cast( + Optional[InstanceTerminationReason], + update_map.get("termination_reason", instance_model.termination_reason), + ) + + +def _get_effective_instance_termination_reason_message( + instance_model: InstanceModel, + update_map: _InstanceUpdateMap, +) -> Optional[str]: + return cast( + Optional[str], + update_map.get("termination_reason_message", instance_model.termination_reason_message), + ) + + +def _get_effective_instance_health( + instance_model: InstanceModel, + update_map: _InstanceUpdateMap, +) -> HealthStatus: + return cast(HealthStatus, update_map.get("health", instance_model.health)) + + +def _get_effective_instance_unreachable( + instance_model: InstanceModel, + update_map: _InstanceUpdateMap, +) -> bool: + return cast(bool, update_map.get("unreachable", instance_model.unreachable)) + + +def _emit_instance_health_change_event( + session: AsyncSession, + instance_model: InstanceModel, + old_health: HealthStatus, + new_health: HealthStatus, +) -> None: + if old_health == new_health: + return + events.emit( + session=session, + message=f"Instance health changed {old_health.upper()} -> {new_health.upper()}", + actor=events.SystemActor(), + targets=[events.Target.from_model(instance_model)], + ) + + +def _emit_instance_reachability_change_event( + session: AsyncSession, + instance_model: InstanceModel, old_status: InstanceStatus, - new_status: InstanceStatus, - termination_reason: Optional[InstanceTerminationReason], - termination_reason_message: Optional[str], -) -> str: - message = f"Instance status changed {old_status.upper()} -> {new_status.upper()}" - if new_status == InstanceStatus.TERMINATING or ( - new_status == InstanceStatus.TERMINATED and old_status != InstanceStatus.TERMINATING - ): - if termination_reason is None: - raise ValueError( - f"termination_reason must be set when 
switching to {new_status.upper()} status" - ) - if ( - termination_reason == InstanceTerminationReason.ERROR - and not termination_reason_message - ): - raise ValueError( - "termination_reason_message must be set when termination_reason is ERROR" - ) - message += f". Termination reason: {termination_reason.upper()}" - if termination_reason_message: - message += f" ({termination_reason_message})" - return message + old_unreachable: bool, + new_unreachable: bool, +) -> None: + if not old_status.is_available() or old_unreachable == new_unreachable: + return + events.emit( + session=session, + message="Instance became unreachable" if new_unreachable else "Instance became reachable", + actor=events.SystemActor(), + targets=[events.Target.from_model(instance_model)], + ) async def _apply_process_result(item: InstancePipelineItem, result: _ProcessResult) -> None: async with get_session_ctx() as session: + res = await session.execute( + select(InstanceModel) + .where( + InstanceModel.id == item.id, + InstanceModel.lock_token == item.lock_token, + ) + .options( + load_only( + InstanceModel.id, + InstanceModel.project_id, + InstanceModel.name, + InstanceModel.status, + InstanceModel.termination_reason, + InstanceModel.termination_reason_message, + InstanceModel.health, + InstanceModel.unreachable, + ) + ) + ) + instance_model = res.scalar_one_or_none() + if instance_model is None: + log_lock_token_mismatch(logger, item) + return for placement_group_create in result.placement_group_creates: session.add(PlacementGroupModel(**placement_group_create)) if result.health_check_create is not None: @@ -1859,6 +1913,7 @@ async def _apply_process_result(item: InstancePipelineItem, result: _ProcessResu resolve_now_placeholders(result.sibling_update_rows, now=now) res = await session.execute( update(InstanceModel) + .execution_options(synchronize_session=False) .where( InstanceModel.id == item.id, InstanceModel.lock_token == item.lock_token, @@ -1885,7 +1940,35 @@ async def 
_apply_process_result(item: InstancePipelineItem, result: _ProcessResu except_placement_group_ids=result.schedule_pg_deletion_except_ids, ) - for deferred_event in result.deferred_events: + emit_instance_status_change_event( + session=session, + instance_model=instance_model, + old_status=instance_model.status, + new_status=_get_effective_instance_status(instance_model, result.instance_update_map), + termination_reason=_get_effective_instance_termination_reason( + instance_model, result.instance_update_map + ), + termination_reason_message=_get_effective_instance_termination_reason_message( + instance_model, result.instance_update_map + ), + ) + _emit_instance_health_change_event( + session=session, + instance_model=instance_model, + old_health=instance_model.health, + new_health=_get_effective_instance_health(instance_model, result.instance_update_map), + ) + _emit_instance_reachability_change_event( + session=session, + instance_model=instance_model, + old_status=instance_model.status, + old_unreachable=instance_model.unreachable, + new_unreachable=_get_effective_instance_unreachable( + instance_model, result.instance_update_map + ), + ) + + for deferred_event in result.sibling_deferred_events: events.emit( session=session, message=deferred_event.message, diff --git a/src/dstack/_internal/server/services/instances.py b/src/dstack/_internal/server/services/instances.py index 046f092c03..051463c57d 100644 --- a/src/dstack/_internal/server/services/instances.py +++ b/src/dstack/_internal/server/services/instances.py @@ -88,6 +88,8 @@ def switch_instance_status( instance_model=instance_model, old_status=old_status, new_status=new_status, + termination_reason=instance_model.termination_reason, + termination_reason_message=instance_model.termination_reason_message, actor=actor, ) @@ -97,20 +99,26 @@ def emit_instance_status_change_event( instance_model: InstanceModel, old_status: InstanceStatus, new_status: InstanceStatus, + termination_reason: 
Optional[InstanceTerminationReason], + termination_reason_message: Optional[str], actor: events.AnyActor = events.SystemActor(), ) -> None: if old_status == new_status: return msg = get_instance_status_change_message( - instance_model=instance_model, old_status=old_status, new_status=new_status, + termination_reason=termination_reason, + termination_reason_message=termination_reason_message, ) events.emit(session, msg, actor=actor, targets=[events.Target.from_model(instance_model)]) def get_instance_status_change_message( - instance_model: InstanceModel, old_status: InstanceStatus, new_status: InstanceStatus + old_status: InstanceStatus, + new_status: InstanceStatus, + termination_reason: Optional[InstanceTerminationReason], + termination_reason_message: Optional[str], ) -> str: msg = f"Instance status changed {old_status.upper()} -> {new_status.upper()}" if ( @@ -118,20 +126,20 @@ def get_instance_status_change_message( or new_status == InstanceStatus.TERMINATED and old_status != InstanceStatus.TERMINATING ): - if instance_model.termination_reason is None: + if termination_reason is None: raise ValueError( f"termination_reason must be set when switching to {new_status.upper()} status" ) if ( - instance_model.termination_reason == InstanceTerminationReason.ERROR - and not instance_model.termination_reason_message + termination_reason == InstanceTerminationReason.ERROR + and not termination_reason_message ): raise ValueError( "termination_reason_message must be set when termination_reason is ERROR" ) - msg += f". Termination reason: {instance_model.termination_reason.upper()}" - if instance_model.termination_reason_message: - msg += f" ({instance_model.termination_reason_message})" + msg += f". 
Termination reason: {termination_reason.upper()}" + if termination_reason_message: + msg += f" ({termination_reason_message})" return msg From 3a0b7ddde0e8abe8dfdcb7eafaa7467fac34fbd8 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Tue, 3 Mar 2026 12:27:38 +0500 Subject: [PATCH 15/51] Rename --- .../background/pipeline_tasks/instances.py | 172 +++++++++--------- 1 file changed, 86 insertions(+), 86 deletions(-) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/instances.py b/src/dstack/_internal/server/background/pipeline_tasks/instances.py index 465622ff23..b3b46a8afe 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/instances.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/instances.py @@ -425,8 +425,8 @@ async def _process_pending_item(item: InstancePipelineItem) -> Optional[_Process log_lock_token_mismatch(logger, item) return None if is_ssh_instance(instance_model): - return await _process_add_remote(instance_model) - return await _process_create_instance(instance_model) + return await _add_ssh_instance(instance_model) + return await _create_cloud_instance(instance_model) async def _process_provisioning_item(item: InstancePipelineItem) -> Optional[_ProcessResult]: @@ -435,7 +435,7 @@ async def _process_provisioning_item(item: InstancePipelineItem) -> Optional[_Pr if instance_model is None: log_lock_token_mismatch(logger, item) return None - return await _process_instance_check(instance_model) + return await _check_instance(instance_model) async def _process_idle_item(item: InstancePipelineItem) -> Optional[_ProcessResult]: @@ -447,7 +447,7 @@ async def _process_idle_item(item: InstancePipelineItem) -> Optional[_ProcessRes idle_result = _process_idle_timeout(instance_model) if idle_result is not None: return idle_result - return await _process_instance_check(instance_model) + return await _check_instance(instance_model) async def _process_busy_item(item: InstancePipelineItem) -> Optional[_ProcessResult]: @@ 
-456,7 +456,7 @@ async def _process_busy_item(item: InstancePipelineItem) -> Optional[_ProcessRes if instance_model is None: log_lock_token_mismatch(logger, item) return None - return await _process_instance_check(instance_model) + return await _check_instance(instance_model) async def _process_terminating_item(item: InstancePipelineItem) -> Optional[_ProcessResult]: @@ -468,7 +468,7 @@ async def _process_terminating_item(item: InstancePipelineItem) -> Optional[_Pro if instance_model is None: log_lock_token_mismatch(logger, item) return None - return await _process_terminate(instance_model) + return await _terminate_instance(instance_model) async def _refetch_locked_instance_status( @@ -596,7 +596,7 @@ def _can_terminate_fleet_instances_on_idle_duration(fleet_model: FleetModel) -> return len(active_instances) > fleet.spec.configuration.nodes.min -async def _process_add_remote(instance_model: InstanceModel) -> _ProcessResult: +async def _add_ssh_instance(instance_model: InstanceModel) -> _ProcessResult: result = _ProcessResult() logger.info("Adding ssh instance %s...", instance_model.name) @@ -638,7 +638,7 @@ async def _process_add_remote(instance_model: InstanceModel) -> _ProcessResult: try: future = run_async( - _deploy_instance, + _deploy_ssh_instance, remote_details, pkeys, ssh_proxy_pkeys, @@ -778,7 +778,7 @@ def _resolve_ssh_instance_network( return instance_network, internal_ip -def _deploy_instance( +def _deploy_ssh_instance( remote_details: RemoteConnectionInfo, pkeys: list[PKey], ssh_proxy_pkeys: Optional[list[PKey]], @@ -835,7 +835,7 @@ def _deploy_instance( return instance_check, host_info, arch -async def _process_create_instance(instance_model: InstanceModel) -> _ProcessResult: +async def _create_cloud_instance(instance_model: InstanceModel) -> _ProcessResult: result = _ProcessResult() master_instance_model = _get_fleet_master_instance(instance_model) if _need_to_wait_fleet_provisioning(instance_model, master_instance_model): @@ -1157,7 +1157,7 @@ 
async def _find_or_create_suitable_placement_group_state( ) -async def _process_instance_check(instance_model: InstanceModel) -> _ProcessResult: +async def _check_instance(instance_model: InstanceModel) -> _ProcessResult: result = _ProcessResult() if ( instance_model.status == InstanceStatus.BUSY @@ -1591,7 +1591,7 @@ def _get_instance_cpu_arch(instance_model: InstanceModel) -> Optional[gpuhunt.CP return job_provisioning_data.instance_type.resources.cpu_arch -async def _process_terminate(instance_model: InstanceModel) -> _ProcessResult: +async def _terminate_instance(instance_model: InstanceModel) -> _ProcessResult: result = _ProcessResult() now = get_current_datetime() if ( @@ -1803,80 +1803,6 @@ def _append_sibling_status_event( ) -def _get_effective_instance_status( - instance_model: InstanceModel, - update_map: _InstanceUpdateMap, -) -> InstanceStatus: - return cast(InstanceStatus, update_map.get("status", instance_model.status)) - - -def _get_effective_instance_termination_reason( - instance_model: InstanceModel, - update_map: _InstanceUpdateMap, -) -> Optional[InstanceTerminationReason]: - return cast( - Optional[InstanceTerminationReason], - update_map.get("termination_reason", instance_model.termination_reason), - ) - - -def _get_effective_instance_termination_reason_message( - instance_model: InstanceModel, - update_map: _InstanceUpdateMap, -) -> Optional[str]: - return cast( - Optional[str], - update_map.get("termination_reason_message", instance_model.termination_reason_message), - ) - - -def _get_effective_instance_health( - instance_model: InstanceModel, - update_map: _InstanceUpdateMap, -) -> HealthStatus: - return cast(HealthStatus, update_map.get("health", instance_model.health)) - - -def _get_effective_instance_unreachable( - instance_model: InstanceModel, - update_map: _InstanceUpdateMap, -) -> bool: - return cast(bool, update_map.get("unreachable", instance_model.unreachable)) - - -def _emit_instance_health_change_event( - session: AsyncSession, 
- instance_model: InstanceModel, - old_health: HealthStatus, - new_health: HealthStatus, -) -> None: - if old_health == new_health: - return - events.emit( - session=session, - message=f"Instance health changed {old_health.upper()} -> {new_health.upper()}", - actor=events.SystemActor(), - targets=[events.Target.from_model(instance_model)], - ) - - -def _emit_instance_reachability_change_event( - session: AsyncSession, - instance_model: InstanceModel, - old_status: InstanceStatus, - old_unreachable: bool, - new_unreachable: bool, -) -> None: - if not old_status.is_available() or old_unreachable == new_unreachable: - return - events.emit( - session=session, - message="Instance became unreachable" if new_unreachable else "Instance became reachable", - actor=events.SystemActor(), - targets=[events.Target.from_model(instance_model)], - ) - - async def _apply_process_result(item: InstancePipelineItem, result: _ProcessResult) -> None: async with get_session_ctx() as session: res = await session.execute( @@ -1983,3 +1909,77 @@ async def _apply_process_result(item: InstancePipelineItem, result: _ProcessResu ], ) await session.commit() + + +def _get_effective_instance_status( + instance_model: InstanceModel, + update_map: _InstanceUpdateMap, +) -> InstanceStatus: + return cast(InstanceStatus, update_map.get("status", instance_model.status)) + + +def _get_effective_instance_termination_reason( + instance_model: InstanceModel, + update_map: _InstanceUpdateMap, +) -> Optional[InstanceTerminationReason]: + return cast( + Optional[InstanceTerminationReason], + update_map.get("termination_reason", instance_model.termination_reason), + ) + + +def _get_effective_instance_termination_reason_message( + instance_model: InstanceModel, + update_map: _InstanceUpdateMap, +) -> Optional[str]: + return cast( + Optional[str], + update_map.get("termination_reason_message", instance_model.termination_reason_message), + ) + + +def _get_effective_instance_health( + instance_model: InstanceModel, 
+ update_map: _InstanceUpdateMap, +) -> HealthStatus: + return cast(HealthStatus, update_map.get("health", instance_model.health)) + + +def _get_effective_instance_unreachable( + instance_model: InstanceModel, + update_map: _InstanceUpdateMap, +) -> bool: + return cast(bool, update_map.get("unreachable", instance_model.unreachable)) + + +def _emit_instance_health_change_event( + session: AsyncSession, + instance_model: InstanceModel, + old_health: HealthStatus, + new_health: HealthStatus, +) -> None: + if old_health == new_health: + return + events.emit( + session=session, + message=f"Instance health changed {old_health.upper()} -> {new_health.upper()}", + actor=events.SystemActor(), + targets=[events.Target.from_model(instance_model)], + ) + + +def _emit_instance_reachability_change_event( + session: AsyncSession, + instance_model: InstanceModel, + old_status: InstanceStatus, + old_unreachable: bool, + new_unreachable: bool, +) -> None: + if not old_status.is_available() or old_unreachable == new_unreachable: + return + events.emit( + session=session, + message="Instance became unreachable" if new_unreachable else "Instance became reachable", + actor=events.SystemActor(), + targets=[events.Target.from_model(instance_model)], + ) From 1ae0aae2627f84b9d2f30823eb05cd425087788d Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Tue, 3 Mar 2026 14:06:14 +0500 Subject: [PATCH 16/51] Refactor instance pipeline into modules --- .../background/pipeline_tasks/instances.py | 1985 --------------- .../pipeline_tasks/instances/__init__.py | 580 +++++ .../pipeline_tasks/instances/check.py | 519 ++++ .../instances/cloud_provisioning.py | 374 +++ .../pipeline_tasks/instances/common.py | 268 ++ .../pipeline_tasks/instances/ssh_deploy.py | 297 +++ .../pipeline_tasks/instances/termination.py | 88 + .../pipeline_tasks/test_instances.py | 2153 ----------------- .../pipeline_tasks/test_instances/__init__.py | 0 .../pipeline_tasks/test_instances/conftest.py | 58 + 
.../pipeline_tasks/test_instances/helpers.py | 40 + .../test_instances/test_check.py | 863 +++++++ .../test_instances/test_cloud_provisioning.py | 452 ++++ .../test_instances/test_pipeline.py | 253 ++ .../test_instances/test_ssh_deploy.py | 248 ++ .../test_instances/test_termination.py | 219 ++ .../server/services/test_instances.py | 118 + 17 files changed, 4377 insertions(+), 4138 deletions(-) delete mode 100644 src/dstack/_internal/server/background/pipeline_tasks/instances.py create mode 100644 src/dstack/_internal/server/background/pipeline_tasks/instances/__init__.py create mode 100644 src/dstack/_internal/server/background/pipeline_tasks/instances/check.py create mode 100644 src/dstack/_internal/server/background/pipeline_tasks/instances/cloud_provisioning.py create mode 100644 src/dstack/_internal/server/background/pipeline_tasks/instances/common.py create mode 100644 src/dstack/_internal/server/background/pipeline_tasks/instances/ssh_deploy.py create mode 100644 src/dstack/_internal/server/background/pipeline_tasks/instances/termination.py delete mode 100644 src/tests/_internal/server/background/pipeline_tasks/test_instances.py create mode 100644 src/tests/_internal/server/background/pipeline_tasks/test_instances/__init__.py create mode 100644 src/tests/_internal/server/background/pipeline_tasks/test_instances/conftest.py create mode 100644 src/tests/_internal/server/background/pipeline_tasks/test_instances/helpers.py create mode 100644 src/tests/_internal/server/background/pipeline_tasks/test_instances/test_check.py create mode 100644 src/tests/_internal/server/background/pipeline_tasks/test_instances/test_cloud_provisioning.py create mode 100644 src/tests/_internal/server/background/pipeline_tasks/test_instances/test_pipeline.py create mode 100644 src/tests/_internal/server/background/pipeline_tasks/test_instances/test_ssh_deploy.py create mode 100644 src/tests/_internal/server/background/pipeline_tasks/test_instances/test_termination.py diff --git 
a/src/dstack/_internal/server/background/pipeline_tasks/instances.py b/src/dstack/_internal/server/background/pipeline_tasks/instances.py deleted file mode 100644 index b3b46a8afe..0000000000 --- a/src/dstack/_internal/server/background/pipeline_tasks/instances.py +++ /dev/null @@ -1,1985 +0,0 @@ -import asyncio -import datetime -import logging -import uuid -from dataclasses import dataclass, field -from datetime import timedelta -from typing import Any, Dict, Optional, Sequence, TypedDict, Union, cast - -import gpuhunt -import requests -from paramiko.pkey import PKey -from paramiko.ssh_exception import PasswordRequiredException -from pydantic import ValidationError -from sqlalchemy import and_, func, not_, or_, select, update -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import joinedload, load_only - -from dstack._internal import settings -from dstack._internal.core.backends.base.compute import ( - ComputeWithCreateInstanceSupport, - ComputeWithPlacementGroupSupport, - GoArchType, - generate_unique_placement_group_name, - get_dstack_runner_binary_path, - get_dstack_runner_download_url, - get_dstack_runner_version, - get_dstack_shim_binary_path, - get_dstack_shim_download_url, - get_dstack_shim_version, - get_dstack_working_dir, - get_shim_env, - get_shim_pre_start_commands, -) -from dstack._internal.core.backends.features import ( - BACKENDS_WITH_CREATE_INSTANCE_SUPPORT, - BACKENDS_WITH_PLACEMENT_GROUPS_SUPPORT, -) -from dstack._internal.core.consts import DSTACK_SHIM_HTTP_PORT -from dstack._internal.core.errors import ( - BackendError, - NotYetTerminated, - PlacementGroupNotSupportedError, - ProvisioningError, - SSHProvisioningError, -) -from dstack._internal.core.models.backends.base import BackendType -from dstack._internal.core.models.events import EventTargetType -from dstack._internal.core.models.fleets import InstanceGroupPlacement -from dstack._internal.core.models.health import HealthStatus -from 
dstack._internal.core.models.instances import ( - InstanceAvailability, - InstanceOfferWithAvailability, - InstanceRuntime, - InstanceStatus, - InstanceTerminationReason, - RemoteConnectionInfo, - SSHKey, -) -from dstack._internal.core.models.placement import ( - PlacementGroup, - PlacementGroupConfiguration, - PlacementStrategy, -) -from dstack._internal.core.models.profiles import TerminationPolicy -from dstack._internal.core.models.runs import JobProvisioningData -from dstack._internal.server import settings as server_settings -from dstack._internal.server.background.pipeline_tasks.base import ( - NOW_PLACEHOLDER, - Fetcher, - Heartbeater, - ItemUpdateMap, - Pipeline, - PipelineItem, - UpdateMapDateTime, - Worker, - log_lock_token_changed_after_processing, - log_lock_token_mismatch, - resolve_now_placeholders, - set_processed_update_map_fields, - set_unlock_update_map_fields, -) -from dstack._internal.server.background.scheduled_tasks.common import get_provisioning_timeout -from dstack._internal.server.db import get_db, get_session_ctx -from dstack._internal.server.models import ( - FleetModel, - InstanceHealthCheckModel, - InstanceModel, - JobModel, - PlacementGroupModel, - ProjectModel, -) -from dstack._internal.server.schemas.instances import InstanceCheck -from dstack._internal.server.schemas.runner import ( - ComponentInfo, - ComponentStatus, - HealthcheckResponse, - InstanceHealthResponse, -) -from dstack._internal.server.services import backends as backends_services -from dstack._internal.server.services import events -from dstack._internal.server.services.fleets import ( - fleet_model_to_fleet, - get_create_instance_offers, - is_cloud_cluster, -) -from dstack._internal.server.services.instances import ( - emit_instance_status_change_event, - get_instance_configuration, - get_instance_profile, - get_instance_provisioning_data, - get_instance_remote_connection_info, - get_instance_requirements, - get_instance_ssh_private_keys, - 
get_instance_status_change_message, - is_ssh_instance, - remove_dangling_tasks_from_instance, -) -from dstack._internal.server.services.locking import get_locker -from dstack._internal.server.services.logging import fmt -from dstack._internal.server.services.offers import ( - get_instance_offer_with_restricted_az, - is_divisible_into_blocks, -) -from dstack._internal.server.services.placement import ( - placement_group_model_to_placement_group, - schedule_fleet_placement_groups_deletion, -) -from dstack._internal.server.services.runner import client as runner_client -from dstack._internal.server.services.runner.ssh import runner_ssh_tunnel -from dstack._internal.server.services.ssh_fleets.provisioning import ( - detect_cpu_arch, - get_host_info, - get_paramiko_connection, - get_shim_healthcheck, - host_info_to_instance_type, - remove_dstack_runner_if_exists, - remove_host_info_if_exists, - run_pre_start_commands, - run_shim_as_systemd_service, - upload_envs, -) -from dstack._internal.server.utils import sentry_utils -from dstack._internal.utils.common import get_current_datetime, get_or_error, run_async -from dstack._internal.utils.logging import get_logger -from dstack._internal.utils.network import get_ip_from_network, is_ip_among_addresses -from dstack._internal.utils.ssh import pkey_from_str - -logger = get_logger(__name__) - -TERMINATION_DEADLINE_OFFSET = timedelta(minutes=20) -TERMINATION_RETRY_TIMEOUT = timedelta(seconds=30) -TERMINATION_RETRY_MAX_DURATION = timedelta(minutes=15) -PROVISIONING_TIMEOUT_SECONDS = 10 * 60 # 10 minutes in seconds - -_UNSET = object() - - -@dataclass -class InstancePipelineItem(PipelineItem): - status: InstanceStatus - - -class InstancePipeline(Pipeline[InstancePipelineItem]): - def __init__( - self, - workers_num: int = 10, - queue_lower_limit_factor: float = 0.5, - queue_upper_limit_factor: float = 2.0, - min_processing_interval: timedelta = timedelta(seconds=10), - lock_timeout: timedelta = timedelta(seconds=30), - 
heartbeat_trigger: timedelta = timedelta(seconds=15), - ) -> None: - super().__init__( - workers_num=workers_num, - queue_lower_limit_factor=queue_lower_limit_factor, - queue_upper_limit_factor=queue_upper_limit_factor, - min_processing_interval=min_processing_interval, - lock_timeout=lock_timeout, - heartbeat_trigger=heartbeat_trigger, - ) - self.__heartbeater = Heartbeater[InstancePipelineItem]( - model_type=InstanceModel, - lock_timeout=self._lock_timeout, - heartbeat_trigger=self._heartbeat_trigger, - ) - self.__fetcher = InstanceFetcher( - queue=self._queue, - queue_desired_minsize=self._queue_desired_minsize, - min_processing_interval=self._min_processing_interval, - lock_timeout=self._lock_timeout, - heartbeater=self._heartbeater, - ) - self.__workers = [ - InstanceWorker(queue=self._queue, heartbeater=self._heartbeater) - for _ in range(self._workers_num) - ] - - @property - def hint_fetch_model_name(self) -> str: - return InstanceModel.__name__ - - @property - def _heartbeater(self) -> Heartbeater[InstancePipelineItem]: - return self.__heartbeater - - @property - def _fetcher(self) -> Fetcher[InstancePipelineItem]: - return self.__fetcher - - @property - def _workers(self) -> Sequence["InstanceWorker"]: - return self.__workers - - -class InstanceFetcher(Fetcher[InstancePipelineItem]): - def __init__( - self, - queue: asyncio.Queue[InstancePipelineItem], - queue_desired_minsize: int, - min_processing_interval: timedelta, - lock_timeout: timedelta, - heartbeater: Heartbeater[InstancePipelineItem], - queue_check_delay: float = 1.0, - ) -> None: - super().__init__( - queue=queue, - queue_desired_minsize=queue_desired_minsize, - min_processing_interval=min_processing_interval, - lock_timeout=lock_timeout, - heartbeater=heartbeater, - queue_check_delay=queue_check_delay, - ) - - @sentry_utils.instrument_named_task("pipeline_tasks.InstanceFetcher.fetch") - async def fetch(self, limit: int) -> list[InstancePipelineItem]: - instance_lock, _ = 
get_locker(get_db().dialect_name).get_lockset( - InstanceModel.__tablename__ - ) - async with instance_lock: - async with get_session_ctx() as session: - now = get_current_datetime() - res = await session.execute( - select(InstanceModel) - .where( - InstanceModel.status.in_( - [ - InstanceStatus.PENDING, - InstanceStatus.PROVISIONING, - InstanceStatus.BUSY, - InstanceStatus.IDLE, - InstanceStatus.TERMINATING, - ] - ), - not_( - and_( - InstanceModel.status == InstanceStatus.TERMINATING, - InstanceModel.compute_group_id.is_not(None), - ) - ), - InstanceModel.deleted == False, - InstanceModel.last_processed_at <= now - self._min_processing_interval, - or_( - InstanceModel.lock_expires_at.is_(None), - InstanceModel.lock_expires_at < now, - ), - or_( - InstanceModel.lock_owner.is_(None), - InstanceModel.lock_owner == InstancePipeline.__name__, - ), - ) - .order_by(InstanceModel.last_processed_at.asc()) - .limit(limit) - .with_for_update(skip_locked=True, key_share=True, of=InstanceModel) - .options( - load_only( - InstanceModel.id, - InstanceModel.lock_token, - InstanceModel.lock_expires_at, - InstanceModel.status, - ) - ) - ) - instance_models = list(res.scalars().all()) - lock_expires_at = get_current_datetime() + self._lock_timeout - lock_token = uuid.uuid4() - items = [] - for instance_model in instance_models: - prev_lock_expired = instance_model.lock_expires_at is not None - instance_model.lock_expires_at = lock_expires_at - instance_model.lock_token = lock_token - instance_model.lock_owner = InstancePipeline.__name__ - items.append( - InstancePipelineItem( - __tablename__=InstanceModel.__tablename__, - id=instance_model.id, - lock_expires_at=lock_expires_at, - lock_token=lock_token, - prev_lock_expired=prev_lock_expired, - status=instance_model.status, - ) - ) - await session.commit() - return items - - -class InstanceWorker(Worker[InstancePipelineItem]): - def __init__( - self, - queue: asyncio.Queue[InstancePipelineItem], - heartbeater: 
Heartbeater[InstancePipelineItem], - ) -> None: - super().__init__( - queue=queue, - heartbeater=heartbeater, - ) - - @sentry_utils.instrument_named_task("pipeline_tasks.InstanceWorker.process") - async def process(self, item: InstancePipelineItem): - async with get_session_ctx() as session: - instance_model = await _refetch_locked_instance_status(session=session, item=item) - if instance_model is None: - log_lock_token_mismatch(logger, item) - return - status = instance_model.status - - result: Optional[_ProcessResult] = None - if status == InstanceStatus.PENDING: - result = await _process_pending_item(item) - elif status == InstanceStatus.PROVISIONING: - result = await _process_provisioning_item(item) - elif status == InstanceStatus.IDLE: - result = await _process_idle_item(item) - elif status == InstanceStatus.BUSY: - result = await _process_busy_item(item) - elif status == InstanceStatus.TERMINATING: - result = await _process_terminating_item(item) - if result is None: - # FIXME: Item won't be unlocked!!! 
- return - set_processed_update_map_fields(result.instance_update_map) - set_unlock_update_map_fields(result.instance_update_map) - await _apply_process_result(item=item, result=result) - - -class _InstanceUpdateMap(ItemUpdateMap, total=False): - status: InstanceStatus - unreachable: bool - started_at: UpdateMapDateTime - finished_at: UpdateMapDateTime - instance_configuration: str - termination_deadline: Optional[datetime.datetime] - termination_reason: Optional[InstanceTerminationReason] - termination_reason_message: Optional[str] - health: HealthStatus - first_termination_retry_at: UpdateMapDateTime - last_termination_retry_at: UpdateMapDateTime - backend: BackendType - backend_data: Optional[str] - offer: str - region: str - price: float - job_provisioning_data: str - total_blocks: int - busy_blocks: int - deleted: bool - deleted_at: UpdateMapDateTime - - -class _SiblingInstanceUpdateMap(TypedDict, total=False): - id: uuid.UUID - status: InstanceStatus - termination_reason: Optional[InstanceTerminationReason] - termination_reason_message: Optional[str] - - -class _HealthCheckCreate(TypedDict): - instance_id: uuid.UUID - collected_at: datetime.datetime - status: HealthStatus - response: str - - -class _PlacementGroupCreate(TypedDict): - id: uuid.UUID - name: str - project_id: uuid.UUID - fleet_id: uuid.UUID - configuration: str - provisioning_data: str - - -@dataclass -class _SiblingDeferredEvent: - message: str - project_id: uuid.UUID - instance_id: uuid.UUID - instance_name: str - - -@dataclass -class _PlacementGroupState: - id: uuid.UUID - placement_group: PlacementGroup - create_payload: Optional[_PlacementGroupCreate] = None - - -@dataclass -class _ProcessResult: - instance_update_map: _InstanceUpdateMap = field(default_factory=_InstanceUpdateMap) - sibling_update_rows: list[_SiblingInstanceUpdateMap] = field(default_factory=list) - sibling_deferred_events: list[_SiblingDeferredEvent] = field(default_factory=list) - health_check_create: 
Optional[_HealthCheckCreate] = None - placement_group_creates: list[_PlacementGroupCreate] = field(default_factory=list) - schedule_pg_deletion_fleet_id: Optional[uuid.UUID] = None - schedule_pg_deletion_except_ids: tuple[uuid.UUID, ...] = () - - -async def _process_pending_item(item: InstancePipelineItem) -> Optional[_ProcessResult]: - async with get_session_ctx() as session: - instance_model = await _refetch_locked_instance_for_pending_or_terminating( - session=session, - item=item, - ) - if instance_model is None: - log_lock_token_mismatch(logger, item) - return None - if is_ssh_instance(instance_model): - return await _add_ssh_instance(instance_model) - return await _create_cloud_instance(instance_model) - - -async def _process_provisioning_item(item: InstancePipelineItem) -> Optional[_ProcessResult]: - async with get_session_ctx() as session: - instance_model = await _refetch_locked_instance_for_check(session=session, item=item) - if instance_model is None: - log_lock_token_mismatch(logger, item) - return None - return await _check_instance(instance_model) - - -async def _process_idle_item(item: InstancePipelineItem) -> Optional[_ProcessResult]: - async with get_session_ctx() as session: - instance_model = await _refetch_locked_instance_for_idle(session=session, item=item) - if instance_model is None: - log_lock_token_mismatch(logger, item) - return None - idle_result = _process_idle_timeout(instance_model) - if idle_result is not None: - return idle_result - return await _check_instance(instance_model) - - -async def _process_busy_item(item: InstancePipelineItem) -> Optional[_ProcessResult]: - async with get_session_ctx() as session: - instance_model = await _refetch_locked_instance_for_check(session=session, item=item) - if instance_model is None: - log_lock_token_mismatch(logger, item) - return None - return await _check_instance(instance_model) - - -async def _process_terminating_item(item: InstancePipelineItem) -> Optional[_ProcessResult]: - async with 
get_session_ctx() as session: - instance_model = await _refetch_locked_instance_for_pending_or_terminating( - session=session, - item=item, - ) - if instance_model is None: - log_lock_token_mismatch(logger, item) - return None - return await _terminate_instance(instance_model) - - -async def _refetch_locked_instance_status( - session: AsyncSession, - item: InstancePipelineItem, -) -> Optional[InstanceModel]: - res = await session.execute( - select(InstanceModel) - .where( - InstanceModel.id == item.id, - InstanceModel.lock_token == item.lock_token, - ) - .options(load_only(InstanceModel.id, InstanceModel.status)) - ) - return res.scalar_one_or_none() - - -async def _refetch_locked_instance_for_pending_or_terminating( - session: AsyncSession, - item: InstancePipelineItem, -) -> Optional[InstanceModel]: - res = await session.execute( - select(InstanceModel) - .where( - InstanceModel.id == item.id, - InstanceModel.lock_token == item.lock_token, - ) - .options(joinedload(InstanceModel.project).joinedload(ProjectModel.backends)) - .options(joinedload(InstanceModel.jobs).load_only(JobModel.id, JobModel.status)) - .options( - joinedload(InstanceModel.fleet).joinedload(FleetModel.project), - ) - .options( - joinedload(InstanceModel.fleet) - .joinedload(FleetModel.instances.and_(InstanceModel.deleted == False)) - .joinedload(InstanceModel.project) - ) - .options( - joinedload(InstanceModel.fleet) - .joinedload(FleetModel.instances.and_(InstanceModel.deleted == False)) - .joinedload(InstanceModel.fleet) - ) - .execution_options(populate_existing=True) - ) - return res.unique().scalar_one_or_none() - - -async def _refetch_locked_instance_for_idle( - session: AsyncSession, - item: InstancePipelineItem, -) -> Optional[InstanceModel]: - res = await session.execute( - select(InstanceModel) - .where( - InstanceModel.id == item.id, - InstanceModel.lock_token == item.lock_token, - ) - .options(joinedload(InstanceModel.project)) - 
.options(joinedload(InstanceModel.jobs).load_only(JobModel.id, JobModel.status)) - .options( - joinedload(InstanceModel.fleet).joinedload( - FleetModel.instances.and_(InstanceModel.deleted == False) - ) - ) - .execution_options(populate_existing=True) - ) - return res.unique().scalar_one_or_none() - - -async def _refetch_locked_instance_for_check( - session: AsyncSession, - item: InstancePipelineItem, -) -> Optional[InstanceModel]: - res = await session.execute( - select(InstanceModel) - .where( - InstanceModel.id == item.id, - InstanceModel.lock_token == item.lock_token, - ) - .options(joinedload(InstanceModel.project).joinedload(ProjectModel.backends)) - .options(joinedload(InstanceModel.jobs).load_only(JobModel.id, JobModel.status)) - .execution_options(populate_existing=True) - ) - return res.unique().scalar_one_or_none() - - -def _process_idle_timeout(instance_model: InstanceModel) -> Optional[_ProcessResult]: - if not ( - instance_model.status == InstanceStatus.IDLE - and instance_model.termination_policy == TerminationPolicy.DESTROY_AFTER_IDLE - and not instance_model.jobs - ): - return None - if instance_model.fleet is not None and not _can_terminate_fleet_instances_on_idle_duration( - instance_model.fleet - ): - logger.debug( - "Skipping instance %s termination on idle duration. 
Fleet is already at `nodes.min`.", - instance_model.name, - ) - return None - - idle_duration = _get_instance_idle_duration(instance_model) - if idle_duration <= datetime.timedelta(seconds=instance_model.termination_idle_time): - return None - - result = _ProcessResult() - _set_status_update( - update_map=result.instance_update_map, - instance_model=instance_model, - new_status=InstanceStatus.TERMINATING, - termination_reason=InstanceTerminationReason.IDLE_TIMEOUT, - termination_reason_message=f"Instance idle for {idle_duration.seconds}s", - ) - return result - - -def _can_terminate_fleet_instances_on_idle_duration(fleet_model: FleetModel) -> bool: - fleet = fleet_model_to_fleet(fleet_model) - if fleet.spec.configuration.nodes is None or fleet.spec.autocreated: - return True - active_instances = [ - instance for instance in fleet_model.instances if instance.status.is_active() - ] - return len(active_instances) > fleet.spec.configuration.nodes.min - - -async def _add_ssh_instance(instance_model: InstanceModel) -> _ProcessResult: - result = _ProcessResult() - logger.info("Adding ssh instance %s...", instance_model.name) - - retry_duration_deadline = instance_model.created_at + timedelta( - seconds=PROVISIONING_TIMEOUT_SECONDS - ) - if retry_duration_deadline < get_current_datetime(): - _set_status_update( - update_map=result.instance_update_map, - instance_model=instance_model, - new_status=InstanceStatus.TERMINATED, - termination_reason=InstanceTerminationReason.PROVISIONING_TIMEOUT, - termination_reason_message=( - f"Failed to add SSH instance in {PROVISIONING_TIMEOUT_SECONDS}s" - ), - ) - return result - - remote_details = get_instance_remote_connection_info(instance_model) - assert remote_details is not None - - try: - pkeys = _ssh_keys_to_pkeys(remote_details.ssh_keys) - ssh_proxy_pkeys = None - if remote_details.ssh_proxy_keys is not None: - ssh_proxy_pkeys = _ssh_keys_to_pkeys(remote_details.ssh_proxy_keys) - except (ValueError, PasswordRequiredException): - 
_set_status_update( - update_map=result.instance_update_map, - instance_model=instance_model, - new_status=InstanceStatus.TERMINATED, - termination_reason=InstanceTerminationReason.ERROR, - termination_reason_message="Unsupported private SSH key type", - ) - return result - - authorized_keys = [pkey.public.strip() for pkey in remote_details.ssh_keys] - authorized_keys.append(instance_model.project.ssh_public_key.strip()) - - try: - future = run_async( - _deploy_ssh_instance, - remote_details, - pkeys, - ssh_proxy_pkeys, - authorized_keys, - ) - deploy_timeout = 20 * 60 - health, host_info, arch = await asyncio.wait_for(future, timeout=deploy_timeout) - except (asyncio.TimeoutError, TimeoutError) as exc: - logger.warning( - "%s: deploy timeout when adding SSH instance: %s", - fmt(instance_model), - repr(exc), - ) - return result - except SSHProvisioningError as exc: - logger.warning( - "%s: provisioning error when adding SSH instance: %s", - fmt(instance_model), - repr(exc), - ) - return result - except Exception: - logger.exception("%s: unexpected error when adding SSH instance", fmt(instance_model)) - _set_status_update( - update_map=result.instance_update_map, - instance_model=instance_model, - new_status=InstanceStatus.TERMINATED, - termination_reason=InstanceTerminationReason.ERROR, - termination_reason_message="Unexpected error when adding SSH instance", - ) - return result - - instance_type = host_info_to_instance_type(host_info, arch) - try: - instance_network, internal_ip = _resolve_ssh_instance_network(instance_model, host_info) - except _SSHInstanceNetworkResolutionError as exc: - _set_status_update( - update_map=result.instance_update_map, - instance_model=instance_model, - new_status=InstanceStatus.TERMINATED, - termination_reason=InstanceTerminationReason.ERROR, - termination_reason_message=str(exc), - ) - return result - - divisible, blocks = is_divisible_into_blocks( - cpu_count=instance_type.resources.cpus, - 
gpu_count=len(instance_type.resources.gpus), - blocks="auto" if instance_model.total_blocks is None else instance_model.total_blocks, - ) - if not divisible: - _set_status_update( - update_map=result.instance_update_map, - instance_model=instance_model, - new_status=InstanceStatus.TERMINATED, - termination_reason=InstanceTerminationReason.ERROR, - termination_reason_message="Cannot split into blocks", - ) - return result - - region = instance_model.region - assert region is not None - job_provisioning_data = JobProvisioningData( - backend=BackendType.REMOTE, - instance_type=instance_type, - instance_id="instance_id", - hostname=remote_details.host, - region=region, - price=0, - internal_ip=internal_ip, - instance_network=instance_network, - username=remote_details.ssh_user, - ssh_port=remote_details.port, - dockerized=True, - backend_data=None, - ssh_proxy=remote_details.ssh_proxy, - ) - instance_offer = InstanceOfferWithAvailability( - backend=BackendType.REMOTE, - instance=instance_type, - region=region, - price=0, - availability=InstanceAvailability.AVAILABLE, - instance_runtime=InstanceRuntime.SHIM, - ) - - _set_status_update( - update_map=result.instance_update_map, - instance_model=instance_model, - new_status=InstanceStatus.IDLE if health else InstanceStatus.PROVISIONING, - ) - result.instance_update_map["backend"] = BackendType.REMOTE - result.instance_update_map["price"] = 0 - result.instance_update_map["offer"] = instance_offer.json() - result.instance_update_map["job_provisioning_data"] = job_provisioning_data.json() - result.instance_update_map["started_at"] = NOW_PLACEHOLDER - result.instance_update_map["total_blocks"] = blocks - return result - - -class _SSHInstanceNetworkResolutionError(Exception): - pass - - -def _resolve_ssh_instance_network( - instance_model: InstanceModel, - host_info: dict[str, Any], -) -> tuple[Optional[str], Optional[str]]: - instance_network = None - internal_ip = None - try: - default_job_provisioning_data = 
JobProvisioningData.__response__.parse_raw( - instance_model.job_provisioning_data - ) - instance_network = default_job_provisioning_data.instance_network - internal_ip = default_job_provisioning_data.internal_ip - except ValidationError: - pass - - host_network_addresses = host_info.get("addresses", []) - if internal_ip is None: - internal_ip = get_ip_from_network( - network=instance_network, - addresses=host_network_addresses, - ) - if instance_network is not None and internal_ip is None: - raise _SSHInstanceNetworkResolutionError( - "Failed to locate internal IP address on the given network" - ) - if internal_ip is not None and not is_ip_among_addresses( - ip_address=internal_ip, - addresses=host_network_addresses, - ): - raise _SSHInstanceNetworkResolutionError( - "Specified internal IP not found among instance interfaces" - ) - return instance_network, internal_ip - - -def _deploy_ssh_instance( - remote_details: RemoteConnectionInfo, - pkeys: list[PKey], - ssh_proxy_pkeys: Optional[list[PKey]], - authorized_keys: list[str], -) -> tuple[InstanceCheck, dict[str, Any], GoArchType]: - with get_paramiko_connection( - remote_details.ssh_user, - remote_details.host, - remote_details.port, - pkeys, - remote_details.ssh_proxy, - ssh_proxy_pkeys, - ) as client: - logger.debug("Connected to %s %s", remote_details.ssh_user, remote_details.host) - - arch = detect_cpu_arch(client) - logger.debug("%s: CPU arch is %s", remote_details.host, arch) - - shim_pre_start_commands = get_shim_pre_start_commands(arch=arch) - run_pre_start_commands(client, shim_pre_start_commands, authorized_keys) - logger.debug("The script for installing dstack has been executed") - - shim_envs = get_shim_env(arch=arch) - try: - fleet_configuration_envs = remote_details.env.as_dict() - except ValueError as exc: - raise SSHProvisioningError(f"Invalid Env: {exc}") from exc - shim_envs.update(fleet_configuration_envs) - dstack_working_dir = get_dstack_working_dir() - dstack_shim_binary_path = 
get_dstack_shim_binary_path() - dstack_runner_binary_path = get_dstack_runner_binary_path() - upload_envs(client, dstack_working_dir, shim_envs) - logger.debug("The dstack-shim environment variables have been installed") - - remove_host_info_if_exists(client, dstack_working_dir) - remove_dstack_runner_if_exists(client, dstack_runner_binary_path) - - run_shim_as_systemd_service( - client=client, - binary_path=dstack_shim_binary_path, - working_dir=dstack_working_dir, - dev=settings.DSTACK_VERSION is None, - ) - - host_info = get_host_info(client, dstack_working_dir) - logger.debug("Received a host_info %s", host_info) - - healthcheck_out = get_shim_healthcheck(client) - try: - healthcheck = HealthcheckResponse.__response__.parse_raw(healthcheck_out) - except ValueError as exc: - raise SSHProvisioningError(f"Cannot parse HealthcheckResponse: {exc}") from exc - instance_check = runner_client.healthcheck_response_to_instance_check(healthcheck) - return instance_check, host_info, arch - - -async def _create_cloud_instance(instance_model: InstanceModel) -> _ProcessResult: - result = _ProcessResult() - master_instance_model = _get_fleet_master_instance(instance_model) - if _need_to_wait_fleet_provisioning(instance_model, master_instance_model): - logger.debug( - "%s: waiting for the first instance in the fleet to be provisioned", - fmt(instance_model), - ) - return result - - try: - instance_configuration = get_instance_configuration(instance_model) - profile = get_instance_profile(instance_model) - requirements = get_instance_requirements(instance_model) - except ValidationError as exc: - logger.exception( - "%s: error parsing profile, requirements or instance configuration", - fmt(instance_model), - ) - _set_status_update( - update_map=result.instance_update_map, - instance_model=instance_model, - new_status=InstanceStatus.TERMINATED, - termination_reason=InstanceTerminationReason.ERROR, - termination_reason_message=( - f"Error to parse profile, requirements or 
instance_configuration: {exc}" - ), - ) - return result - - placement_group_states = await _get_fleet_placement_group_states(instance_model.fleet_id) - placement_group_state = _get_placement_group_state_for_instance( - placement_group_states=placement_group_states, - instance_model=instance_model, - master_instance_model=master_instance_model, - ) - offers = await get_create_instance_offers( - project=instance_model.project, - profile=profile, - requirements=requirements, - fleet_model=instance_model.fleet, - placement_group=( - placement_group_state.placement_group if placement_group_state is not None else None - ), - blocks="auto" if instance_model.total_blocks is None else instance_model.total_blocks, - exclude_not_available=True, - ) - - seen_placement_group_ids = {state.id for state in placement_group_states} - for backend, instance_offer in offers[: server_settings.MAX_OFFERS_TRIED]: - if instance_offer.backend not in BACKENDS_WITH_CREATE_INSTANCE_SUPPORT: - continue - compute = backend.compute() - assert isinstance(compute, ComputeWithCreateInstanceSupport) - selected_offer = _get_instance_offer_for_instance( - instance_offer=instance_offer, - instance_model=instance_model, - master_instance_model=master_instance_model, - ) - selected_placement_group_state = placement_group_state - if ( - instance_model.fleet is not None - and is_cloud_cluster(instance_model.fleet) - and instance_model.id == master_instance_model.id - and selected_offer.backend in BACKENDS_WITH_PLACEMENT_GROUPS_SUPPORT - and isinstance(compute, ComputeWithPlacementGroupSupport) - and ( - compute.are_placement_groups_compatible_with_reservations(selected_offer.backend) - or instance_configuration.reservation is None - ) - ): - selected_placement_group_state = await _find_or_create_suitable_placement_group_state( - instance_model=instance_model, - placement_group_states=placement_group_states, - instance_offer=selected_offer, - compute=compute, - ) - if selected_placement_group_state is None: 
- continue - if ( - selected_placement_group_state.create_payload is not None - and selected_placement_group_state.id not in seen_placement_group_ids - ): - seen_placement_group_ids.add(selected_placement_group_state.id) - placement_group_states.append(selected_placement_group_state) - result.placement_group_creates.append( - selected_placement_group_state.create_payload - ) - - logger.debug( - "Trying %s in %s/%s for $%0.4f per hour", - selected_offer.instance.name, - selected_offer.backend.value, - selected_offer.region, - selected_offer.price, - ) - try: - job_provisioning_data = await run_async( - compute.create_instance, - selected_offer, - instance_configuration, - selected_placement_group_state.placement_group - if selected_placement_group_state is not None - else None, - ) - except BackendError as exc: - logger.warning( - "%s launch in %s/%s failed: %s", - selected_offer.instance.name, - selected_offer.backend.value, - selected_offer.region, - repr(exc), - extra={"instance_name": instance_model.name}, - ) - continue - except Exception: - logger.exception( - "Got exception when launching %s in %s/%s", - selected_offer.instance.name, - selected_offer.backend.value, - selected_offer.region, - ) - continue - - _set_status_update( - update_map=result.instance_update_map, - instance_model=instance_model, - new_status=InstanceStatus.PROVISIONING, - ) - result.instance_update_map["backend"] = backend.TYPE - result.instance_update_map["region"] = selected_offer.region - result.instance_update_map["price"] = selected_offer.price - result.instance_update_map["instance_configuration"] = instance_configuration.json() - result.instance_update_map["job_provisioning_data"] = job_provisioning_data.json() - result.instance_update_map["offer"] = selected_offer.json() - result.instance_update_map["total_blocks"] = selected_offer.total_blocks - result.instance_update_map["started_at"] = NOW_PLACEHOLDER - - if instance_model.fleet_id is not None and instance_model.id == 
master_instance_model.id: - result.schedule_pg_deletion_fleet_id = instance_model.fleet_id - if selected_placement_group_state is not None: - result.schedule_pg_deletion_except_ids = (selected_placement_group_state.id,) - return result - - _set_status_update( - update_map=result.instance_update_map, - instance_model=instance_model, - new_status=InstanceStatus.TERMINATED, - termination_reason=InstanceTerminationReason.NO_OFFERS, - termination_reason_message="All offers failed" if offers else "No offers found", - ) - if ( - instance_model.fleet is not None - and instance_model.id == master_instance_model.id - and is_cloud_cluster(instance_model.fleet) - ): - for sibling_instance_model in instance_model.fleet.instances: - if sibling_instance_model.id == instance_model.id: - continue - sibling_update_map = _SiblingInstanceUpdateMap(id=sibling_instance_model.id) - _set_status_update( - update_map=sibling_update_map, - instance_model=sibling_instance_model, - new_status=InstanceStatus.TERMINATED, - termination_reason=InstanceTerminationReason.MASTER_FAILED, - ) - if len(sibling_update_map) > 1: - result.sibling_update_rows.append(sibling_update_map) - _append_sibling_status_event( - deferred_events=result.sibling_deferred_events, - instance_model=sibling_instance_model, - new_status=InstanceStatus.TERMINATED, - termination_reason=cast( - Optional[InstanceTerminationReason], - sibling_update_map.get("termination_reason"), - ), - termination_reason_message=cast( - Optional[str], sibling_update_map.get("termination_reason_message") - ), - ) - return result - - -def _get_fleet_master_instance(instance_model: InstanceModel) -> InstanceModel: - if instance_model.fleet is None: - return instance_model - fleet_instances = list(instance_model.fleet.instances) - if all(fleet_instance.id != instance_model.id for fleet_instance in fleet_instances): - fleet_instances.append(instance_model) - return min( - fleet_instances, - key=lambda fleet_instance: (fleet_instance.instance_num, 
fleet_instance.created_at), - ) - - -async def _get_fleet_placement_group_states( - fleet_id: Optional[uuid.UUID], -) -> list[_PlacementGroupState]: - if fleet_id is None: - return [] - async with get_session_ctx() as session: - res = await session.execute( - select(PlacementGroupModel) - .where( - PlacementGroupModel.fleet_id == fleet_id, - PlacementGroupModel.deleted == False, - PlacementGroupModel.fleet_deleted == False, - ) - .options(joinedload(PlacementGroupModel.project)) - ) - placement_group_models = list(res.unique().scalars().all()) - return [ - _PlacementGroupState( - id=placement_group_model.id, - placement_group=placement_group_model_to_placement_group(placement_group_model), - ) - for placement_group_model in placement_group_models - ] - - -def _get_placement_group_state_for_instance( - placement_group_states: list[_PlacementGroupState], - instance_model: InstanceModel, - master_instance_model: InstanceModel, -) -> Optional[_PlacementGroupState]: - if instance_model.id == master_instance_model.id: - return None - if len(placement_group_states) > 1: - logger.error( - ( - "Expected 0 or 1 placement groups associated with fleet %s, found %s." 
- " An incorrect placement group might have been selected for instance %s" - ), - instance_model.fleet_id, - len(placement_group_states), - instance_model.name, - ) - if placement_group_states: - return placement_group_states[0] - return None - - -async def _find_or_create_suitable_placement_group_state( - instance_model: InstanceModel, - placement_group_states: list[_PlacementGroupState], - instance_offer: InstanceOfferWithAvailability, - compute: ComputeWithPlacementGroupSupport, -) -> Optional[_PlacementGroupState]: - for placement_group_state in placement_group_states: - if compute.is_suitable_placement_group( - placement_group_state.placement_group, - instance_offer, - ): - return placement_group_state - - assert instance_model.fleet is not None - placement_group_id = uuid.uuid4() - placement_group_name = generate_unique_placement_group_name( - project_name=instance_model.project.name, - fleet_name=instance_model.fleet.name, - ) - placement_group = PlacementGroup( - name=placement_group_name, - project_name=instance_model.project.name, - configuration=PlacementGroupConfiguration( - backend=instance_offer.backend, - region=instance_offer.region, - placement_strategy=PlacementStrategy.CLUSTER, - ), - provisioning_data=None, - ) - logger.debug( - "Creating placement group %s in %s/%s", - placement_group.name, - placement_group.configuration.backend.value, - placement_group.configuration.region, - ) - try: - provisioning_data = await run_async( - compute.create_placement_group, - placement_group, - instance_offer, - ) - except PlacementGroupNotSupportedError: - logger.debug( - "Skipping offer %s because placement group not supported", - instance_offer.instance.name, - ) - return None - except BackendError as exc: - logger.warning( - "Failed to create placement group %s in %s/%s: %r", - placement_group.name, - placement_group.configuration.backend.value, - placement_group.configuration.region, - exc, - ) - return None - except Exception: - logger.exception( - "Got 
exception when creating placement group %s in %s/%s", - placement_group.name, - placement_group.configuration.backend.value, - placement_group.configuration.region, - ) - return None - - placement_group.provisioning_data = provisioning_data - return _PlacementGroupState( - id=placement_group_id, - placement_group=placement_group, - create_payload=_PlacementGroupCreate( - id=placement_group_id, - name=placement_group.name, - project_id=instance_model.project_id, - fleet_id=get_or_error(instance_model.fleet_id), - configuration=placement_group.configuration.json(), - provisioning_data=provisioning_data.json(), - ), - ) - - -async def _check_instance(instance_model: InstanceModel) -> _ProcessResult: - result = _ProcessResult() - if ( - instance_model.status == InstanceStatus.BUSY - and instance_model.jobs - and all(job.status.is_finished() for job in instance_model.jobs) - ): - _set_status_update( - update_map=result.instance_update_map, - instance_model=instance_model, - new_status=InstanceStatus.TERMINATING, - termination_reason=InstanceTerminationReason.JOB_FINISHED, - ) - logger.warning( - "Detected busy instance %s with finished job. 
Marked as TERMINATING", - instance_model.name, - extra={ - "instance_name": instance_model.name, - "instance_status": instance_model.status.value, - }, - ) - return result - - job_provisioning_data = get_or_error(get_instance_provisioning_data(instance_model)) - if job_provisioning_data.hostname is None: - return await _process_wait_for_instance_provisioning_data( - instance_model=instance_model, - job_provisioning_data=job_provisioning_data, - ) - - if not job_provisioning_data.dockerized: - if instance_model.status == InstanceStatus.PROVISIONING: - _set_status_update( - update_map=result.instance_update_map, - instance_model=instance_model, - new_status=InstanceStatus.BUSY, - ) - return result - - check_instance_health = await _should_check_instance_health(instance_model.id) - instance_check = await _run_instance_check( - instance_model=instance_model, - job_provisioning_data=job_provisioning_data, - check_instance_health=check_instance_health, - ) - health_status = _get_health_status_for_instance_check( - instance_model=instance_model, - instance_check=instance_check, - check_instance_health=check_instance_health, - ) - _log_instance_check_result( - instance_model=instance_model, - instance_check=instance_check, - health_status=health_status, - check_instance_health=check_instance_health, - ) - - if instance_check.has_health_checks(): - assert instance_check.health_response is not None - result.health_check_create = _HealthCheckCreate( - instance_id=instance_model.id, - collected_at=get_current_datetime(), - status=health_status, - response=instance_check.health_response.json(), - ) - - _set_health_update( - update_map=result.instance_update_map, - instance_model=instance_model, - health=health_status, - ) - _set_unreachable_update( - update_map=result.instance_update_map, - instance_model=instance_model, - unreachable=not instance_check.reachable, - ) - - if instance_check.reachable: - result.instance_update_map["termination_deadline"] = None - if 
instance_model.status == InstanceStatus.PROVISIONING: - _set_status_update( - update_map=result.instance_update_map, - instance_model=instance_model, - new_status=InstanceStatus.IDLE if not instance_model.jobs else InstanceStatus.BUSY, - ) - return result - - now = get_current_datetime() - if not is_ssh_instance(instance_model) and instance_model.termination_deadline is None: - result.instance_update_map["termination_deadline"] = now + TERMINATION_DEADLINE_OFFSET - - if ( - instance_model.status == InstanceStatus.PROVISIONING - and instance_model.started_at is not None - ): - provisioning_deadline = _get_provisioning_deadline( - instance_model=instance_model, - job_provisioning_data=job_provisioning_data, - ) - if now > provisioning_deadline: - _set_status_update( - update_map=result.instance_update_map, - instance_model=instance_model, - new_status=InstanceStatus.TERMINATING, - termination_reason=InstanceTerminationReason.PROVISIONING_TIMEOUT, - termination_reason_message="Instance did not become reachable in time", - ) - elif instance_model.status.is_available(): - deadline = instance_model.termination_deadline - if deadline is not None and now > deadline: - _set_status_update( - update_map=result.instance_update_map, - instance_model=instance_model, - new_status=InstanceStatus.TERMINATING, - termination_reason=InstanceTerminationReason.UNREACHABLE, - ) - return result - - -async def _should_check_instance_health(instance_id: uuid.UUID) -> bool: - health_check_cutoff = get_current_datetime() - timedelta( - seconds=server_settings.SERVER_INSTANCE_HEALTH_MIN_COLLECT_INTERVAL_SECONDS - ) - async with get_session_ctx() as session: - res = await session.execute( - select(func.count(1)).where( - InstanceHealthCheckModel.instance_id == instance_id, - InstanceHealthCheckModel.collected_at > health_check_cutoff, - ) - ) - return res.scalar_one() == 0 - - -async def _run_instance_check( - instance_model: InstanceModel, - job_provisioning_data: JobProvisioningData, - 
check_instance_health: bool, -) -> InstanceCheck: - ssh_private_keys = get_instance_ssh_private_keys(instance_model) - instance_check = await run_async( - _check_instance_inner, - ssh_private_keys, - job_provisioning_data, - None, - instance=instance_model, - check_instance_health=check_instance_health, - ) - if instance_check is False: - return InstanceCheck(reachable=False, message="SSH or tunnel error") - return instance_check - - -def _get_health_status_for_instance_check( - instance_model: InstanceModel, - instance_check: InstanceCheck, - check_instance_health: bool, -) -> HealthStatus: - if instance_check.reachable and check_instance_health: - return instance_check.get_health_status() - return instance_model.health - - -def _log_instance_check_result( - instance_model: InstanceModel, - instance_check: InstanceCheck, - health_status: HealthStatus, - check_instance_health: bool, -) -> None: - loglevel = logging.DEBUG - if not instance_check.reachable and instance_model.status.is_available(): - loglevel = logging.WARNING - elif check_instance_health and not health_status.is_healthy(): - loglevel = logging.WARNING - logger.log( - loglevel, - "Instance %s check: reachable=%s health_status=%s message=%r", - instance_model.name, - instance_check.reachable, - health_status.name, - instance_check.message, - extra={"instance_name": instance_model.name, "health_status": health_status}, - ) - - -async def _process_wait_for_instance_provisioning_data( - instance_model: InstanceModel, - job_provisioning_data: JobProvisioningData, -) -> _ProcessResult: - result = _ProcessResult() - logger.debug("Waiting for instance %s to become running", instance_model.name) - provisioning_deadline = _get_provisioning_deadline( - instance_model=instance_model, - job_provisioning_data=job_provisioning_data, - ) - if get_current_datetime() > provisioning_deadline: - _set_status_update( - update_map=result.instance_update_map, - instance_model=instance_model, - 
new_status=InstanceStatus.TERMINATING, - termination_reason=InstanceTerminationReason.PROVISIONING_TIMEOUT, - termination_reason_message="Backend did not complete provisioning in time", - ) - return result - - backend = await backends_services.get_project_backend_by_type( - project=instance_model.project, - backend_type=job_provisioning_data.backend, - ) - if backend is None: - logger.warning( - "Instance %s failed because instance's backend is not available", - instance_model.name, - ) - _set_status_update( - update_map=result.instance_update_map, - instance_model=instance_model, - new_status=InstanceStatus.TERMINATING, - termination_reason=InstanceTerminationReason.ERROR, - termination_reason_message="Backend not available", - ) - return result - - try: - await run_async( - backend.compute().update_provisioning_data, - job_provisioning_data, - instance_model.project.ssh_public_key, - instance_model.project.ssh_private_key, - ) - result.instance_update_map["job_provisioning_data"] = job_provisioning_data.json() - except ProvisioningError as exc: - logger.warning( - "Error while waiting for instance %s to become running: %s", - instance_model.name, - repr(exc), - ) - _set_status_update( - update_map=result.instance_update_map, - instance_model=instance_model, - new_status=InstanceStatus.TERMINATING, - termination_reason=InstanceTerminationReason.ERROR, - termination_reason_message="Error while waiting for instance to become running", - ) - except Exception: - logger.exception( - "Got exception when updating instance %s provisioning data", - instance_model.name, - ) - return result - - -@runner_ssh_tunnel(ports=[DSTACK_SHIM_HTTP_PORT], retries=1) -def _check_instance_inner( - ports: Dict[int, int], - *, - instance: InstanceModel, - check_instance_health: bool = False, -) -> InstanceCheck: - instance_health_response: Optional[InstanceHealthResponse] = None - shim_client = runner_client.ShimClient(port=ports[DSTACK_SHIM_HTTP_PORT]) - method = shim_client.healthcheck - 
try: - healthcheck_response = method(unmask_exceptions=True) - if check_instance_health: - method = shim_client.get_instance_health - instance_health_response = method() - except requests.RequestException as exc: - template = "shim.%s(): request error: %s" - args = (method.__func__.__name__, exc) - logger.debug(template, *args) - return InstanceCheck(reachable=False, message=template % args) - except Exception as exc: - template = "shim.%s(): unexpected exception %s: %s" - args = (method.__func__.__name__, exc.__class__.__name__, exc) - logger.exception(template, *args) - return InstanceCheck(reachable=False, message=template % args) - - try: - remove_dangling_tasks_from_instance(shim_client, instance) - except Exception as exc: - logger.exception("%s: error removing dangling tasks: %s", fmt(instance), exc) - - _maybe_install_components(instance, shim_client) - return runner_client.healthcheck_response_to_instance_check( - healthcheck_response, - instance_health_response, - ) - - -def _maybe_install_components( - instance_model: InstanceModel, - shim_client: runner_client.ShimClient, -) -> None: - try: - components = shim_client.get_components() - except requests.RequestException as exc: - logger.warning( - "Instance %s: shim.get_components(): request error: %s", instance_model.name, exc - ) - return - if components is None: - logger.debug("Instance %s: no components info", instance_model.name) - return - - installed_shim_version: Optional[str] = None - installation_requested = False - - if (runner_info := components.runner) is not None: - installation_requested |= _maybe_install_runner(instance_model, shim_client, runner_info) - else: - logger.debug("Instance %s: no runner info", instance_model.name) - - if (shim_info := components.shim) is not None: - if shim_info.status == ComponentStatus.INSTALLED: - installed_shim_version = shim_info.version - installation_requested |= _maybe_install_shim(instance_model, shim_client, shim_info) - else: - logger.debug("Instance 
%s: no shim info", instance_model.name) - - running_shim_version = shim_client.get_version_string() - if ( - installed_shim_version is None - or installed_shim_version == running_shim_version - or installation_requested - or any(component.status == ComponentStatus.INSTALLING for component in components) - or not shim_client.is_safe_to_restart() - ): - return - - if shim_client.shutdown(force=False): - logger.debug( - "Instance %s: restarting shim %s -> %s", - instance_model.name, - running_shim_version, - installed_shim_version, - ) - else: - logger.debug("Instance %s: cannot restart shim", instance_model.name) - - -def _maybe_install_runner( - instance_model: InstanceModel, - shim_client: runner_client.ShimClient, - runner_info: ComponentInfo, -) -> bool: - expected_version = get_dstack_runner_version() - if expected_version is None: - logger.debug("Cannot determine the expected runner version") - return False - - installed_version = runner_info.version - logger.debug( - "Instance %s: runner status=%s installed_version=%s", - instance_model.name, - runner_info.status.value, - installed_version or "(no version)", - ) - if runner_info.status == ComponentStatus.INSTALLING: - logger.debug("Instance %s: runner is already being installed", instance_model.name) - return False - if installed_version and installed_version == expected_version: - logger.debug("Instance %s: expected runner version already installed", instance_model.name) - return False - - url = get_dstack_runner_download_url( - arch=_get_instance_cpu_arch(instance_model), - version=expected_version, - ) - logger.debug( - "Instance %s: installing runner %s -> %s from %s", - instance_model.name, - installed_version or "(no version)", - expected_version, - url, - ) - try: - shim_client.install_runner(url) - return True - except requests.RequestException as exc: - logger.warning("Instance %s: shim.install_runner(): %s", instance_model.name, exc) - return False - - -def _maybe_install_shim( - instance_model: 
InstanceModel, - shim_client: runner_client.ShimClient, - shim_info: ComponentInfo, -) -> bool: - expected_version = get_dstack_shim_version() - if expected_version is None: - logger.debug("Cannot determine the expected shim version") - return False - - installed_version = shim_info.version - logger.debug( - "Instance %s: shim status=%s installed_version=%s running_version=%s", - instance_model.name, - shim_info.status.value, - installed_version or "(no version)", - shim_client.get_version_string(), - ) - if shim_info.status == ComponentStatus.INSTALLING: - logger.debug("Instance %s: shim is already being installed", instance_model.name) - return False - if installed_version and installed_version == expected_version: - logger.debug("Instance %s: expected shim version already installed", instance_model.name) - return False - - url = get_dstack_shim_download_url( - arch=_get_instance_cpu_arch(instance_model), - version=expected_version, - ) - logger.debug( - "Instance %s: installing shim %s -> %s from %s", - instance_model.name, - installed_version or "(no version)", - expected_version, - url, - ) - try: - shim_client.install_shim(url) - return True - except requests.RequestException as exc: - logger.warning("Instance %s: shim.install_shim(): %s", instance_model.name, exc) - return False - - -def _get_instance_cpu_arch(instance_model: InstanceModel) -> Optional[gpuhunt.CPUArchitecture]: - job_provisioning_data = get_instance_provisioning_data(instance_model) - if job_provisioning_data is None: - return None - return job_provisioning_data.instance_type.resources.cpu_arch - - -async def _terminate_instance(instance_model: InstanceModel) -> _ProcessResult: - result = _ProcessResult() - now = get_current_datetime() - if ( - instance_model.last_termination_retry_at is not None - and _next_termination_retry_at(instance_model.last_termination_retry_at) > now - ): - return result - - job_provisioning_data = get_instance_provisioning_data(instance_model) - if 
job_provisioning_data is not None and job_provisioning_data.backend != BackendType.REMOTE: - backend = await backends_services.get_project_backend_by_type( - project=instance_model.project, - backend_type=job_provisioning_data.backend, - ) - if backend is None: - logger.error( - "Failed to terminate instance %s. Backend %s not available.", - instance_model.name, - job_provisioning_data.backend, - ) - else: - logger.debug("Terminating runner instance %s", job_provisioning_data.hostname) - try: - await run_async( - backend.compute().terminate_instance, - job_provisioning_data.instance_id, - job_provisioning_data.region, - job_provisioning_data.backend_data, - ) - except Exception as exc: - first_retry_at = instance_model.first_termination_retry_at - if first_retry_at is None: - first_retry_at = now - result.instance_update_map["first_termination_retry_at"] = NOW_PLACEHOLDER - result.instance_update_map["last_termination_retry_at"] = NOW_PLACEHOLDER - if _next_termination_retry_at(now) < _get_termination_deadline(first_retry_at): - if isinstance(exc, NotYetTerminated): - logger.debug( - "Instance %s termination in progress: %s", - instance_model.name, - exc, - ) - else: - logger.warning( - "Failed to terminate instance %s. Will retry. Error: %r", - instance_model.name, - exc, - exc_info=not isinstance(exc, BackendError), - ) - return result - logger.error( - "Failed all attempts to terminate instance %s." - " Please terminate the instance manually to avoid unexpected charges." 
- " Error: %r", - instance_model.name, - exc, - exc_info=not isinstance(exc, BackendError), - ) - - result.instance_update_map["deleted"] = True - result.instance_update_map["deleted_at"] = NOW_PLACEHOLDER - result.instance_update_map["finished_at"] = NOW_PLACEHOLDER - _set_status_update( - update_map=result.instance_update_map, - instance_model=instance_model, - new_status=InstanceStatus.TERMINATED, - ) - return result - - -def _next_termination_retry_at(last_termination_retry_at: datetime.datetime) -> datetime.datetime: - return last_termination_retry_at + TERMINATION_RETRY_TIMEOUT - - -def _get_termination_deadline(first_termination_retry_at: datetime.datetime) -> datetime.datetime: - return first_termination_retry_at + TERMINATION_RETRY_MAX_DURATION - - -def _need_to_wait_fleet_provisioning( - instance_model: InstanceModel, - master_instance_model: InstanceModel, -) -> bool: - if instance_model.fleet is None: - return False - if ( - instance_model.id == master_instance_model.id - or master_instance_model.job_provisioning_data is not None - or master_instance_model.status == InstanceStatus.TERMINATED - ): - return False - return is_cloud_cluster(instance_model.fleet) - - -def _get_instance_offer_for_instance( - instance_offer: InstanceOfferWithAvailability, - instance_model: InstanceModel, - master_instance_model: InstanceModel, -) -> InstanceOfferWithAvailability: - if instance_model.fleet is None: - return instance_offer - fleet = fleet_model_to_fleet(instance_model.fleet) - if fleet.spec.configuration.placement == InstanceGroupPlacement.CLUSTER: - master_job_provisioning_data = get_instance_provisioning_data(master_instance_model) - return get_instance_offer_with_restricted_az( - instance_offer=instance_offer, - master_job_provisioning_data=master_job_provisioning_data, - ) - return instance_offer - - -def _get_instance_idle_duration(instance_model: InstanceModel) -> datetime.timedelta: - last_time = instance_model.created_at - if 
instance_model.last_job_processed_at is not None: - last_time = instance_model.last_job_processed_at - return get_current_datetime() - last_time - - -def _get_provisioning_deadline( - instance_model: InstanceModel, - job_provisioning_data: JobProvisioningData, -) -> datetime.datetime: - assert instance_model.started_at is not None - timeout_interval = get_provisioning_timeout( - backend_type=job_provisioning_data.get_base_backend(), - instance_type_name=job_provisioning_data.instance_type.name, - ) - return instance_model.started_at + timeout_interval - - -def _ssh_keys_to_pkeys(ssh_keys: list[SSHKey]) -> list[PKey]: - return [pkey_from_str(ssh_key.private) for ssh_key in ssh_keys if ssh_key.private is not None] - - -def _set_status_update( - update_map: Union[_InstanceUpdateMap, _SiblingInstanceUpdateMap], - instance_model: InstanceModel, - new_status: InstanceStatus, - termination_reason: object = _UNSET, - termination_reason_message: object = _UNSET, -) -> None: - old_status = instance_model.status - if old_status == new_status: - if termination_reason is not _UNSET: - update_map["termination_reason"] = cast( - Optional[InstanceTerminationReason], termination_reason - ) - if termination_reason_message is not _UNSET: - update_map["termination_reason_message"] = cast( - Optional[str], termination_reason_message - ) - return - - effective_termination_reason = instance_model.termination_reason - if termination_reason is not _UNSET: - effective_termination_reason = cast( - Optional[InstanceTerminationReason], termination_reason - ) - update_map["termination_reason"] = effective_termination_reason - - effective_termination_reason_message = instance_model.termination_reason_message - if termination_reason_message is not _UNSET: - effective_termination_reason_message = cast(Optional[str], termination_reason_message) - update_map["termination_reason_message"] = effective_termination_reason_message - - update_map["status"] = new_status - - -def _set_health_update( - 
update_map: _InstanceUpdateMap, - instance_model: InstanceModel, - health: HealthStatus, -) -> None: - if instance_model.health == health: - return - update_map["health"] = health - - -def _set_unreachable_update( - update_map: _InstanceUpdateMap, - instance_model: InstanceModel, - unreachable: bool, -) -> None: - if not instance_model.status.is_available() or instance_model.unreachable == unreachable: - return - update_map["unreachable"] = unreachable - - -def _append_sibling_status_event( - deferred_events: list[_SiblingDeferredEvent], - instance_model: InstanceModel, - new_status: InstanceStatus, - termination_reason: Optional[InstanceTerminationReason], - termination_reason_message: Optional[str], -) -> None: - if instance_model.status == new_status: - return - deferred_events.append( - _SiblingDeferredEvent( - message=get_instance_status_change_message( - old_status=instance_model.status, - new_status=new_status, - termination_reason=termination_reason, - termination_reason_message=termination_reason_message, - ), - project_id=instance_model.project_id, - instance_id=instance_model.id, - instance_name=instance_model.name, - ) - ) - - -async def _apply_process_result(item: InstancePipelineItem, result: _ProcessResult) -> None: - async with get_session_ctx() as session: - res = await session.execute( - select(InstanceModel) - .where( - InstanceModel.id == item.id, - InstanceModel.lock_token == item.lock_token, - ) - .options( - load_only( - InstanceModel.id, - InstanceModel.project_id, - InstanceModel.name, - InstanceModel.status, - InstanceModel.termination_reason, - InstanceModel.termination_reason_message, - InstanceModel.health, - InstanceModel.unreachable, - ) - ) - ) - instance_model = res.scalar_one_or_none() - if instance_model is None: - log_lock_token_mismatch(logger, item) - return - for placement_group_create in result.placement_group_creates: - session.add(PlacementGroupModel(**placement_group_create)) - if result.health_check_create is not None: - 
session.add(InstanceHealthCheckModel(**result.health_check_create)) - await session.flush() - - now = get_current_datetime() - resolve_now_placeholders(result.instance_update_map, now=now) - resolve_now_placeholders(result.sibling_update_rows, now=now) - res = await session.execute( - update(InstanceModel) - .execution_options(synchronize_session=False) - .where( - InstanceModel.id == item.id, - InstanceModel.lock_token == item.lock_token, - ) - .values(**result.instance_update_map) - .returning(InstanceModel.id) - ) - updated_ids = list(res.scalars().all()) - if len(updated_ids) == 0: - log_lock_token_changed_after_processing(logger, item) - await session.rollback() - return - - if result.sibling_update_rows: - await session.execute( - update(InstanceModel).execution_options(synchronize_session=False), - result.sibling_update_rows, - ) - - if result.schedule_pg_deletion_fleet_id is not None: - await schedule_fleet_placement_groups_deletion( - session=session, - fleet_id=result.schedule_pg_deletion_fleet_id, - except_placement_group_ids=result.schedule_pg_deletion_except_ids, - ) - - emit_instance_status_change_event( - session=session, - instance_model=instance_model, - old_status=instance_model.status, - new_status=_get_effective_instance_status(instance_model, result.instance_update_map), - termination_reason=_get_effective_instance_termination_reason( - instance_model, result.instance_update_map - ), - termination_reason_message=_get_effective_instance_termination_reason_message( - instance_model, result.instance_update_map - ), - ) - _emit_instance_health_change_event( - session=session, - instance_model=instance_model, - old_health=instance_model.health, - new_health=_get_effective_instance_health(instance_model, result.instance_update_map), - ) - _emit_instance_reachability_change_event( - session=session, - instance_model=instance_model, - old_status=instance_model.status, - old_unreachable=instance_model.unreachable, - 
new_unreachable=_get_effective_instance_unreachable( - instance_model, result.instance_update_map - ), - ) - - for deferred_event in result.sibling_deferred_events: - events.emit( - session=session, - message=deferred_event.message, - actor=events.SystemActor(), - targets=[ - events.Target( - type=EventTargetType.INSTANCE, - project_id=deferred_event.project_id, - id=deferred_event.instance_id, - name=deferred_event.instance_name, - ) - ], - ) - await session.commit() - - -def _get_effective_instance_status( - instance_model: InstanceModel, - update_map: _InstanceUpdateMap, -) -> InstanceStatus: - return cast(InstanceStatus, update_map.get("status", instance_model.status)) - - -def _get_effective_instance_termination_reason( - instance_model: InstanceModel, - update_map: _InstanceUpdateMap, -) -> Optional[InstanceTerminationReason]: - return cast( - Optional[InstanceTerminationReason], - update_map.get("termination_reason", instance_model.termination_reason), - ) - - -def _get_effective_instance_termination_reason_message( - instance_model: InstanceModel, - update_map: _InstanceUpdateMap, -) -> Optional[str]: - return cast( - Optional[str], - update_map.get("termination_reason_message", instance_model.termination_reason_message), - ) - - -def _get_effective_instance_health( - instance_model: InstanceModel, - update_map: _InstanceUpdateMap, -) -> HealthStatus: - return cast(HealthStatus, update_map.get("health", instance_model.health)) - - -def _get_effective_instance_unreachable( - instance_model: InstanceModel, - update_map: _InstanceUpdateMap, -) -> bool: - return cast(bool, update_map.get("unreachable", instance_model.unreachable)) - - -def _emit_instance_health_change_event( - session: AsyncSession, - instance_model: InstanceModel, - old_health: HealthStatus, - new_health: HealthStatus, -) -> None: - if old_health == new_health: - return - events.emit( - session=session, - message=f"Instance health changed {old_health.upper()} -> {new_health.upper()}", - 
actor=events.SystemActor(), - targets=[events.Target.from_model(instance_model)], - ) - - -def _emit_instance_reachability_change_event( - session: AsyncSession, - instance_model: InstanceModel, - old_status: InstanceStatus, - old_unreachable: bool, - new_unreachable: bool, -) -> None: - if not old_status.is_available() or old_unreachable == new_unreachable: - return - events.emit( - session=session, - message="Instance became unreachable" if new_unreachable else "Instance became reachable", - actor=events.SystemActor(), - targets=[events.Target.from_model(instance_model)], - ) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/instances/__init__.py b/src/dstack/_internal/server/background/pipeline_tasks/instances/__init__.py new file mode 100644 index 0000000000..d3ca224cea --- /dev/null +++ b/src/dstack/_internal/server/background/pipeline_tasks/instances/__init__.py @@ -0,0 +1,580 @@ +import asyncio +import uuid +from dataclasses import dataclass +from datetime import timedelta +from typing import Optional, Sequence + +from sqlalchemy import and_, not_, or_, select, update +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import joinedload, load_only +from sqlalchemy.orm.attributes import set_committed_value + +from dstack._internal.core.models.health import HealthStatus +from dstack._internal.core.models.instances import InstanceStatus +from dstack._internal.server.background.pipeline_tasks.base import ( + Fetcher, + Heartbeater, + Pipeline, + PipelineItem, + Worker, + log_lock_token_changed_after_processing, + log_lock_token_mismatch, + resolve_now_placeholders, + set_processed_update_map_fields, + set_unlock_update_map_fields, +) +from dstack._internal.server.background.pipeline_tasks.instances.check import ( + check_instance, + process_idle_timeout, +) +from dstack._internal.server.background.pipeline_tasks.instances.cloud_provisioning import ( + create_cloud_instance, +) +from 
dstack._internal.server.background.pipeline_tasks.instances.common import ( + InstanceUpdateMap, + ProcessResult, +) +from dstack._internal.server.background.pipeline_tasks.instances.ssh_deploy import ( + add_ssh_instance, +) +from dstack._internal.server.background.pipeline_tasks.instances.termination import ( + terminate_instance, +) +from dstack._internal.server.db import get_db, get_session_ctx +from dstack._internal.server.models import ( + FleetModel, + InstanceHealthCheckModel, + InstanceModel, + JobModel, + PlacementGroupModel, + ProjectModel, +) +from dstack._internal.server.services import events +from dstack._internal.server.services.instances import ( + emit_instance_status_change_event, + is_ssh_instance, +) +from dstack._internal.server.services.locking import get_locker +from dstack._internal.server.services.placement import ( + schedule_fleet_placement_groups_deletion, +) +from dstack._internal.server.utils import sentry_utils +from dstack._internal.utils.common import get_current_datetime +from dstack._internal.utils.logging import get_logger + +logger = get_logger(__name__) + + +@dataclass +class InstancePipelineItem(PipelineItem): + status: InstanceStatus + + +class InstancePipeline(Pipeline[InstancePipelineItem]): + def __init__( + self, + workers_num: int = 10, + queue_lower_limit_factor: float = 0.5, + queue_upper_limit_factor: float = 2.0, + min_processing_interval: timedelta = timedelta(seconds=10), + lock_timeout: timedelta = timedelta(seconds=30), + heartbeat_trigger: timedelta = timedelta(seconds=15), + ) -> None: + super().__init__( + workers_num=workers_num, + queue_lower_limit_factor=queue_lower_limit_factor, + queue_upper_limit_factor=queue_upper_limit_factor, + min_processing_interval=min_processing_interval, + lock_timeout=lock_timeout, + heartbeat_trigger=heartbeat_trigger, + ) + self.__heartbeater = Heartbeater[InstancePipelineItem]( + model_type=InstanceModel, + lock_timeout=self._lock_timeout, + 
heartbeat_trigger=self._heartbeat_trigger, + ) + self.__fetcher = InstanceFetcher( + queue=self._queue, + queue_desired_minsize=self._queue_desired_minsize, + min_processing_interval=self._min_processing_interval, + lock_timeout=self._lock_timeout, + heartbeater=self._heartbeater, + ) + self.__workers = [ + InstanceWorker(queue=self._queue, heartbeater=self._heartbeater) + for _ in range(self._workers_num) + ] + + @property + def hint_fetch_model_name(self) -> str: + return InstanceModel.__name__ + + @property + def _heartbeater(self) -> Heartbeater[InstancePipelineItem]: + return self.__heartbeater + + @property + def _fetcher(self) -> Fetcher[InstancePipelineItem]: + return self.__fetcher + + @property + def _workers(self) -> Sequence["InstanceWorker"]: + return self.__workers + + +class InstanceFetcher(Fetcher[InstancePipelineItem]): + def __init__( + self, + queue: asyncio.Queue[InstancePipelineItem], + queue_desired_minsize: int, + min_processing_interval: timedelta, + lock_timeout: timedelta, + heartbeater: Heartbeater[InstancePipelineItem], + queue_check_delay: float = 1.0, + ) -> None: + super().__init__( + queue=queue, + queue_desired_minsize=queue_desired_minsize, + min_processing_interval=min_processing_interval, + lock_timeout=lock_timeout, + heartbeater=heartbeater, + queue_check_delay=queue_check_delay, + ) + + @sentry_utils.instrument_named_task("pipeline_tasks.InstanceFetcher.fetch") + async def fetch(self, limit: int) -> list[InstancePipelineItem]: + instance_lock, _ = get_locker(get_db().dialect_name).get_lockset( + InstanceModel.__tablename__ + ) + async with instance_lock: + async with get_session_ctx() as session: + now = get_current_datetime() + res = await session.execute( + select(InstanceModel) + .where( + InstanceModel.status.in_( + [ + InstanceStatus.PENDING, + InstanceStatus.PROVISIONING, + InstanceStatus.BUSY, + InstanceStatus.IDLE, + InstanceStatus.TERMINATING, + ] + ), + not_( + and_( + InstanceModel.status == 
InstanceStatus.TERMINATING, + InstanceModel.compute_group_id.is_not(None), + ) + ), + InstanceModel.deleted == False, + InstanceModel.last_processed_at <= now - self._min_processing_interval, + or_( + InstanceModel.lock_expires_at.is_(None), + InstanceModel.lock_expires_at < now, + ), + or_( + InstanceModel.lock_owner.is_(None), + InstanceModel.lock_owner == InstancePipeline.__name__, + ), + ) + .order_by(InstanceModel.last_processed_at.asc()) + .limit(limit) + .with_for_update(skip_locked=True, key_share=True, of=InstanceModel) + .options( + load_only( + InstanceModel.id, + InstanceModel.lock_token, + InstanceModel.lock_expires_at, + InstanceModel.status, + ) + ) + ) + instance_models = list(res.scalars().all()) + lock_expires_at = get_current_datetime() + self._lock_timeout + lock_token = uuid.uuid4() + items = [] + for instance_model in instance_models: + prev_lock_expired = instance_model.lock_expires_at is not None + instance_model.lock_expires_at = lock_expires_at + instance_model.lock_token = lock_token + instance_model.lock_owner = InstancePipeline.__name__ + items.append( + InstancePipelineItem( + __tablename__=InstanceModel.__tablename__, + id=instance_model.id, + lock_expires_at=lock_expires_at, + lock_token=lock_token, + prev_lock_expired=prev_lock_expired, + status=instance_model.status, + ) + ) + await session.commit() + return items + + +class InstanceWorker(Worker[InstancePipelineItem]): + def __init__( + self, + queue: asyncio.Queue[InstancePipelineItem], + heartbeater: Heartbeater[InstancePipelineItem], + ) -> None: + super().__init__( + queue=queue, + heartbeater=heartbeater, + ) + + @sentry_utils.instrument_named_task("pipeline_tasks.InstanceWorker.process") + async def process(self, item: InstancePipelineItem): + async with get_session_ctx() as session: + instance_model = await _refetch_locked_instance_status(session=session, item=item) + if instance_model is None: + log_lock_token_mismatch(logger, item) + return + status = 
instance_model.status + + result: Optional[ProcessResult] = None + if status == InstanceStatus.PENDING: + result = await _process_pending_item(item) + elif status == InstanceStatus.PROVISIONING: + result = await _process_provisioning_item(item) + elif status == InstanceStatus.IDLE: + result = await _process_idle_item(item) + elif status == InstanceStatus.BUSY: + result = await _process_busy_item(item) + elif status == InstanceStatus.TERMINATING: + result = await _process_terminating_item(item) + if result is None: + return + set_processed_update_map_fields(result.instance_update_map) + set_unlock_update_map_fields(result.instance_update_map) + await _apply_process_result(item=item, result=result) + + +async def _process_pending_item(item: InstancePipelineItem) -> Optional[ProcessResult]: + async with get_session_ctx() as session: + instance_model = await _refetch_locked_instance_for_pending_or_terminating( + session=session, + item=item, + ) + if instance_model is None: + log_lock_token_mismatch(logger, item) + return None + if is_ssh_instance(instance_model): + return await add_ssh_instance(instance_model) + return await create_cloud_instance(instance_model) + + +async def _process_provisioning_item(item: InstancePipelineItem) -> Optional[ProcessResult]: + async with get_session_ctx() as session: + instance_model = await _refetch_locked_instance_for_check(session=session, item=item) + if instance_model is None: + log_lock_token_mismatch(logger, item) + return None + return await check_instance(instance_model) + + +async def _process_idle_item(item: InstancePipelineItem) -> Optional[ProcessResult]: + async with get_session_ctx() as session: + instance_model = await _refetch_locked_instance_for_idle(session=session, item=item) + if instance_model is None: + log_lock_token_mismatch(logger, item) + return None + idle_result = process_idle_timeout(instance_model) + if idle_result is not None: + return idle_result + return await check_instance(instance_model) + + +async 
def _process_busy_item(item: InstancePipelineItem) -> Optional[ProcessResult]: + async with get_session_ctx() as session: + instance_model = await _refetch_locked_instance_for_check(session=session, item=item) + if instance_model is None: + log_lock_token_mismatch(logger, item) + return None + return await check_instance(instance_model) + + +async def _process_terminating_item(item: InstancePipelineItem) -> Optional[ProcessResult]: + async with get_session_ctx() as session: + instance_model = await _refetch_locked_instance_for_pending_or_terminating( + session=session, + item=item, + ) + if instance_model is None: + log_lock_token_mismatch(logger, item) + return None + return await terminate_instance(instance_model) + + +async def _refetch_locked_instance_status( + session: AsyncSession, item: InstancePipelineItem +) -> Optional[InstanceModel]: + res = await session.execute( + select(InstanceModel) + .where( + InstanceModel.id == item.id, + InstanceModel.lock_token == item.lock_token, + ) + .options(load_only(InstanceModel.status)) + ) + return res.scalar_one_or_none() + + +async def _refetch_locked_instance_for_pending_or_terminating( + session: AsyncSession, item: InstancePipelineItem +) -> Optional[InstanceModel]: + res = await session.execute( + select(InstanceModel) + .where( + InstanceModel.id == item.id, + InstanceModel.lock_token == item.lock_token, + ) + .options(joinedload(InstanceModel.project).joinedload(ProjectModel.backends)) + .options(joinedload(InstanceModel.jobs).load_only(JobModel.id, JobModel.status)) + .options( + joinedload(InstanceModel.fleet).joinedload( + FleetModel.instances.and_(InstanceModel.deleted == False) + ) + ) + ) + instance_model = res.unique().scalar_one_or_none() + if instance_model is not None: + # Pending/terminating processing runs on detached objects and later traverses + # `fleet.project`, sibling `project`, and sibling `fleet`. 
Populate those attrs from + # already known objects so detached access works without adding extra joins. + _populate_pending_or_terminating_detached_relations(instance_model) + return instance_model + + +def _populate_pending_or_terminating_detached_relations( + instance_model: InstanceModel, +) -> None: + project = instance_model.project + fleet = instance_model.fleet + if fleet is None: + return + set_committed_value(fleet, "project", project) + for sibling_instance_model in fleet.instances: + set_committed_value(sibling_instance_model, "project", project) + set_committed_value(sibling_instance_model, "fleet", fleet) + + +async def _refetch_locked_instance_for_idle( + session: AsyncSession, item: InstancePipelineItem +) -> Optional[InstanceModel]: + res = await session.execute( + select(InstanceModel) + .where( + InstanceModel.id == item.id, + InstanceModel.lock_token == item.lock_token, + ) + .options(joinedload(InstanceModel.project)) + .options(joinedload(InstanceModel.jobs).load_only(JobModel.id, JobModel.status)) + .options( + joinedload(InstanceModel.fleet).joinedload( + FleetModel.instances.and_(InstanceModel.deleted == False) + ) + ) + ) + return res.unique().scalar_one_or_none() + + +async def _refetch_locked_instance_for_check( + session: AsyncSession, item: InstancePipelineItem +) -> Optional[InstanceModel]: + res = await session.execute( + select(InstanceModel) + .where( + InstanceModel.id == item.id, + InstanceModel.lock_token == item.lock_token, + ) + .options(joinedload(InstanceModel.project).joinedload(ProjectModel.backends)) + .options(joinedload(InstanceModel.jobs).load_only(JobModel.id, JobModel.status)) + ) + return res.unique().scalar_one_or_none() + + +def _get_effective_instance_status( + instance_model: InstanceModel, update_map: InstanceUpdateMap +) -> InstanceStatus: + return update_map.get("status", instance_model.status) + + +def _get_effective_instance_termination_reason( + instance_model: InstanceModel, update_map: InstanceUpdateMap 
+): + return update_map.get("termination_reason", instance_model.termination_reason) + + +def _get_effective_instance_termination_reason_message( + instance_model: InstanceModel, update_map: InstanceUpdateMap +): + return update_map.get("termination_reason_message", instance_model.termination_reason_message) + + +def _get_effective_instance_health( + instance_model: InstanceModel, update_map: InstanceUpdateMap +) -> HealthStatus: + return update_map.get("health", instance_model.health) + + +def _get_effective_instance_unreachable( + instance_model: InstanceModel, update_map: InstanceUpdateMap +) -> bool: + return update_map.get("unreachable", instance_model.unreachable) + + +def _emit_instance_health_change_event( + session: AsyncSession, + instance_model: InstanceModel, + old_health: HealthStatus, + new_health: HealthStatus, +) -> None: + if old_health == new_health: + return + events.emit( + session, + f"Instance health changed {old_health.upper()} -> {new_health.upper()}", + actor=events.SystemActor(), + targets=[events.Target.from_model(instance_model)], + ) + + +def _emit_instance_reachability_change_event( + session: AsyncSession, + instance_model: InstanceModel, + old_status: InstanceStatus, + old_unreachable: bool, + new_unreachable: bool, +) -> None: + if not old_status.is_available() or old_unreachable == new_unreachable: + return + events.emit( + session, + "Instance became unreachable" if new_unreachable else "Instance became reachable", + actor=events.SystemActor(), + targets=[events.Target.from_model(instance_model)], + ) + + +async def _apply_process_result(item: InstancePipelineItem, result: ProcessResult) -> None: + async with get_session_ctx() as session: + res = await session.execute( + select(InstanceModel) + .where( + InstanceModel.id == item.id, + InstanceModel.lock_token == item.lock_token, + ) + .options( + load_only( + InstanceModel.id, + InstanceModel.project_id, + InstanceModel.name, + InstanceModel.status, + InstanceModel.health, + 
InstanceModel.unreachable, + InstanceModel.termination_reason, + InstanceModel.termination_reason_message, + InstanceModel.lock_token, + ) + ) + ) + instance_model = res.scalar_one_or_none() + if instance_model is None: + log_lock_token_changed_after_processing(logger, item) + return + + if result.health_check_create is not None: + session.add(InstanceHealthCheckModel(**result.health_check_create)) + if result.placement_group_creates: + session.add_all( + PlacementGroupModel(**placement_group_create) + for placement_group_create in result.placement_group_creates + ) + if result.health_check_create is not None or result.placement_group_creates: + await session.flush() + + now = get_current_datetime() + resolve_now_placeholders(result.instance_update_map, now=now) + resolve_now_placeholders(result.sibling_update_rows, now=now) + + res = await session.execute( + update(InstanceModel) + .where( + InstanceModel.id == item.id, + InstanceModel.lock_token == item.lock_token, + ) + .values(**result.instance_update_map) + .execution_options(synchronize_session=False) + .returning(InstanceModel.id) + ) + updated_ids = list(res.scalars().all()) + if len(updated_ids) == 0: + await session.rollback() + log_lock_token_changed_after_processing(logger, item) + return + + for sibling_update_row in result.sibling_update_rows: + sibling_id = sibling_update_row.get("id") + if sibling_id is None: + continue + sibling_values = { + key: value for key, value in sibling_update_row.items() if key != "id" + } + if sibling_values: + await session.execute( + update(InstanceModel) + .where(InstanceModel.id == sibling_id) + .values(**sibling_values) + ) + + if result.schedule_pg_deletion_fleet_id is not None: + await schedule_fleet_placement_groups_deletion( + session=session, + fleet_id=result.schedule_pg_deletion_fleet_id, + except_placement_group_ids=result.schedule_pg_deletion_except_ids, + ) + + emit_instance_status_change_event( + session=session, + instance_model=instance_model, + 
old_status=instance_model.status, + new_status=_get_effective_instance_status(instance_model, result.instance_update_map), + termination_reason=_get_effective_instance_termination_reason( + instance_model, result.instance_update_map + ), + termination_reason_message=_get_effective_instance_termination_reason_message( + instance_model, result.instance_update_map + ), + ) + _emit_instance_health_change_event( + session=session, + instance_model=instance_model, + old_health=instance_model.health, + new_health=_get_effective_instance_health(instance_model, result.instance_update_map), + ) + _emit_instance_reachability_change_event( + session=session, + instance_model=instance_model, + old_status=instance_model.status, + old_unreachable=instance_model.unreachable, + new_unreachable=_get_effective_instance_unreachable( + instance_model, result.instance_update_map + ), + ) + + for sibling_deferred_event in result.sibling_deferred_events: + events.emit( + session, + sibling_deferred_event.message, + actor=events.SystemActor(), + targets=[ + events.Target( + type=events.EventTargetType.INSTANCE, + project_id=sibling_deferred_event.project_id, + id=sibling_deferred_event.instance_id, + name=sibling_deferred_event.instance_name, + ) + ], + ) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/instances/check.py b/src/dstack/_internal/server/background/pipeline_tasks/instances/check.py new file mode 100644 index 0000000000..6cf75e827a --- /dev/null +++ b/src/dstack/_internal/server/background/pipeline_tasks/instances/check.py @@ -0,0 +1,519 @@ +import logging +from datetime import timedelta +from typing import Dict, Optional + +import gpuhunt +import requests +from sqlalchemy import func, select + +from dstack._internal.core.backends.base.compute import ( + get_dstack_runner_download_url, + get_dstack_runner_version, + get_dstack_shim_download_url, + get_dstack_shim_version, +) +from dstack._internal.core.consts import DSTACK_SHIM_HTTP_PORT +from 
dstack._internal.core.errors import ProvisioningError +from dstack._internal.core.models.health import HealthStatus +from dstack._internal.core.models.instances import InstanceStatus, InstanceTerminationReason +from dstack._internal.core.models.profiles import TerminationPolicy +from dstack._internal.core.models.runs import JobProvisioningData +from dstack._internal.server import settings as server_settings +from dstack._internal.server.background.pipeline_tasks.instances.common import ( + TERMINATION_DEADLINE_OFFSET, + HealthCheckCreate, + ProcessResult, + can_terminate_fleet_instances_on_idle_duration, + get_instance_idle_duration, + get_provisioning_deadline, + set_health_update, + set_status_update, + set_unreachable_update, +) +from dstack._internal.server.db import get_session_ctx +from dstack._internal.server.models import InstanceHealthCheckModel, InstanceModel +from dstack._internal.server.schemas.instances import InstanceCheck +from dstack._internal.server.schemas.runner import ( + ComponentInfo, + ComponentStatus, + InstanceHealthResponse, +) +from dstack._internal.server.services import backends as backends_services +from dstack._internal.server.services.instances import ( + get_instance_provisioning_data, + get_instance_ssh_private_keys, + is_ssh_instance, + remove_dangling_tasks_from_instance, +) +from dstack._internal.server.services.logging import fmt +from dstack._internal.server.services.runner import client as runner_client +from dstack._internal.server.services.runner.ssh import runner_ssh_tunnel +from dstack._internal.utils.common import get_current_datetime, get_or_error, run_async +from dstack._internal.utils.logging import get_logger + +logger = get_logger(__name__) + + +def process_idle_timeout(instance_model: InstanceModel) -> Optional[ProcessResult]: + if not ( + instance_model.status == InstanceStatus.IDLE + and instance_model.termination_policy == TerminationPolicy.DESTROY_AFTER_IDLE + and not instance_model.jobs + ): + return None + if 
instance_model.fleet is not None and not can_terminate_fleet_instances_on_idle_duration( + instance_model.fleet + ): + logger.debug( + "Skipping instance %s termination on idle duration. Fleet is already at `nodes.min`.", + instance_model.name, + ) + return None + + idle_duration = get_instance_idle_duration(instance_model) + if idle_duration <= timedelta(seconds=instance_model.termination_idle_time): + return None + + result = ProcessResult() + set_status_update( + update_map=result.instance_update_map, + instance_model=instance_model, + new_status=InstanceStatus.TERMINATING, + termination_reason=InstanceTerminationReason.IDLE_TIMEOUT, + termination_reason_message=f"Instance idle for {idle_duration.seconds}s", + ) + return result + + +async def check_instance(instance_model: InstanceModel) -> ProcessResult: + result = ProcessResult() + if ( + instance_model.status == InstanceStatus.BUSY + and instance_model.jobs + and all(job.status.is_finished() for job in instance_model.jobs) + ): + set_status_update( + update_map=result.instance_update_map, + instance_model=instance_model, + new_status=InstanceStatus.TERMINATING, + termination_reason=InstanceTerminationReason.JOB_FINISHED, + ) + logger.warning( + "Detected busy instance %s with finished job. 
Marked as TERMINATING", + instance_model.name, + extra={ + "instance_name": instance_model.name, + "instance_status": instance_model.status.value, + }, + ) + return result + + job_provisioning_data = get_or_error(get_instance_provisioning_data(instance_model)) + if job_provisioning_data.hostname is None: + return await _process_wait_for_instance_provisioning_data( + instance_model=instance_model, + job_provisioning_data=job_provisioning_data, + ) + + if not job_provisioning_data.dockerized: + if instance_model.status == InstanceStatus.PROVISIONING: + set_status_update( + update_map=result.instance_update_map, + instance_model=instance_model, + new_status=InstanceStatus.BUSY, + ) + return result + + check_instance_health = await _should_check_instance_health(instance_model.id) + instance_check = await _run_instance_check( + instance_model=instance_model, + job_provisioning_data=job_provisioning_data, + check_instance_health=check_instance_health, + ) + health_status = _get_health_status_for_instance_check( + instance_model=instance_model, + instance_check=instance_check, + check_instance_health=check_instance_health, + ) + _log_instance_check_result( + instance_model=instance_model, + instance_check=instance_check, + health_status=health_status, + check_instance_health=check_instance_health, + ) + + if instance_check.has_health_checks(): + assert instance_check.health_response is not None + result.health_check_create = HealthCheckCreate( + instance_id=instance_model.id, + collected_at=get_current_datetime(), + status=health_status, + response=instance_check.health_response.json(), + ) + + set_health_update( + update_map=result.instance_update_map, + instance_model=instance_model, + health=health_status, + ) + set_unreachable_update( + update_map=result.instance_update_map, + instance_model=instance_model, + unreachable=not instance_check.reachable, + ) + + if instance_check.reachable: + result.instance_update_map["termination_deadline"] = None + if 
async def _should_check_instance_health(instance_id) -> bool:
    """Return True when no health check was collected for this instance recently.

    "Recently" means within SERVER_INSTANCE_HEALTH_MIN_COLLECT_INTERVAL_SECONDS,
    which throttles how often the health-check endpoint is queried.
    """
    collect_interval = timedelta(
        seconds=server_settings.SERVER_INSTANCE_HEALTH_MIN_COLLECT_INTERVAL_SECONDS
    )
    cutoff = get_current_datetime() - collect_interval
    async with get_session_ctx() as session:
        recent_count_res = await session.execute(
            select(func.count(1)).where(
                InstanceHealthCheckModel.instance_id == instance_id,
                InstanceHealthCheckModel.collected_at > cutoff,
            )
        )
        # No recent rows -> time to collect a fresh health check.
        return recent_count_res.scalar_one() == 0


async def _run_instance_check(
    instance_model: InstanceModel,
    job_provisioning_data: JobProvisioningData,
    check_instance_health: bool,
) -> InstanceCheck:
    """Probe the instance via SSH and return the check outcome.

    `_check_instance_inner` is wrapped by `runner_ssh_tunnel`, which returns
    False when the tunnel itself cannot be established; that case is mapped to
    an unreachable InstanceCheck here.
    """
    private_keys = get_instance_ssh_private_keys(instance_model)
    outcome = await run_async(
        _check_instance_inner,
        private_keys,
        job_provisioning_data,
        None,
        instance=instance_model,
        check_instance_health=check_instance_health,
    )
    if outcome is False:
        return InstanceCheck(reachable=False, message="SSH or tunnel error")
    return outcome


def _get_health_status_for_instance_check(
    instance_model: InstanceModel,
    instance_check: InstanceCheck,
    check_instance_health: bool,
) -> HealthStatus:
    """Derive the instance health from a check result.

    A fresh status is only available when the instance was reachable AND the
    health check was actually requested; otherwise keep the stored value.
    """
    if not (instance_check.reachable and check_instance_health):
        return instance_model.health
    return instance_check.get_health_status()


def _log_instance_check_result(
    instance_model: InstanceModel,
    instance_check: InstanceCheck,
    health_status: HealthStatus,
    check_instance_health: bool,
) -> None:
    """Log the check outcome, escalating to WARNING on reachability/health problems."""
    unreachable_while_available = (
        not instance_check.reachable and instance_model.status.is_available()
    )
    checked_and_unhealthy = check_instance_health and not health_status.is_healthy()
    level = (
        logging.WARNING
        if unreachable_while_available or checked_and_unhealthy
        else logging.DEBUG
    )
    logger.log(
        level,
        "Instance %s check: reachable=%s health_status=%s message=%r",
        instance_model.name,
        instance_check.reachable,
        health_status.name,
        instance_check.message,
        extra={"instance_name": instance_model.name, "health_status": health_status},
    )


async def _process_wait_for_instance_provisioning_data(
    instance_model: InstanceModel,
    job_provisioning_data: JobProvisioningData,
) -> ProcessResult:
    """Wait for the backend to fill in the instance's provisioning data.

    Terminates the instance when the provisioning deadline passes or the
    backend is no longer available; transient update errors are logged and the
    wait continues on the next processing round.
    """
    result = ProcessResult()
    logger.debug("Waiting for instance %s to become running", instance_model.name)
    deadline = get_provisioning_deadline(
        instance_model=instance_model,
        job_provisioning_data=job_provisioning_data,
    )
    if get_current_datetime() > deadline:
        set_status_update(
            update_map=result.instance_update_map,
            instance_model=instance_model,
            new_status=InstanceStatus.TERMINATING,
            termination_reason=InstanceTerminationReason.PROVISIONING_TIMEOUT,
            termination_reason_message="Backend did not complete provisioning in time",
        )
        return result

    backend = await backends_services.get_project_backend_by_type(
        project=instance_model.project,
        backend_type=job_provisioning_data.backend,
    )
    if backend is None:
        logger.warning(
            "Instance %s failed because instance's backend is not available",
            instance_model.name,
        )
        set_status_update(
            update_map=result.instance_update_map,
            instance_model=instance_model,
            new_status=InstanceStatus.TERMINATING,
            termination_reason=InstanceTerminationReason.ERROR,
            termination_reason_message="Backend not available",
        )
        return result

    try:
        # NOTE(review): update_provisioning_data presumably mutates
        # job_provisioning_data in place — the updated JSON is stored below.
        await run_async(
            backend.compute().update_provisioning_data,
            job_provisioning_data,
            instance_model.project.ssh_public_key,
            instance_model.project.ssh_private_key,
        )
        result.instance_update_map["job_provisioning_data"] = job_provisioning_data.json()
    except ProvisioningError as exc:
        logger.warning(
            "Error while waiting for instance %s to become running: %s",
            instance_model.name,
            repr(exc),
        )
        set_status_update(
            update_map=result.instance_update_map,
            instance_model=instance_model,
            new_status=InstanceStatus.TERMINATING,
            termination_reason=InstanceTerminationReason.ERROR,
            termination_reason_message="Error while waiting for instance to become running",
        )
    except Exception:
        logger.exception(
            "Got exception when updating instance %s provisioning data",
            instance_model.name,
        )
    return result
@runner_ssh_tunnel(ports=[DSTACK_SHIM_HTTP_PORT], retries=1)
def _check_instance_inner(
    ports: Dict[int, int],
    *,
    instance: InstanceModel,
    check_instance_health: bool = False,
) -> InstanceCheck:
    """Query dstack-shim through the established SSH tunnel.

    Builds an InstanceCheck from the shim healthcheck response (and, when
    requested, the instance health endpoint). Any request failure marks the
    instance unreachable. The `runner_ssh_tunnel` wrapper returns False when
    the tunnel cannot be opened at all.
    """
    health_response: Optional[InstanceHealthResponse] = None
    shim_client = runner_client.ShimClient(port=ports[DSTACK_SHIM_HTTP_PORT])
    # Track the bound method being called so error messages name the failing endpoint.
    method = shim_client.healthcheck
    try:
        shim_response = method(unmask_exceptions=True)
        if check_instance_health:
            method = shim_client.get_instance_health
            health_response = method()
    except requests.RequestException as exc:
        template = "shim.%s(): request error: %s"
        args = (method.__func__.__name__, exc)
        logger.debug(template, *args)
        return InstanceCheck(reachable=False, message=template % args)
    except Exception as exc:
        template = "shim.%s(): unexpected exception %s: %s"
        args = (method.__func__.__name__, exc.__class__.__name__, exc)
        logger.exception(template, *args)
        return InstanceCheck(reachable=False, message=template % args)

    # Best effort: a failure here must not mark the instance unreachable.
    try:
        remove_dangling_tasks_from_instance(shim_client, instance)
    except Exception as exc:
        logger.exception("%s: error removing dangling tasks: %s", fmt(instance), exc)

    _maybe_install_components(instance, shim_client)
    return runner_client.healthcheck_response_to_instance_check(
        shim_response,
        health_response,
    )


def _maybe_install_components(
    instance_model: InstanceModel,
    shim_client: runner_client.ShimClient,
) -> None:
    """Trigger runner/shim (re)installation when versions are out of date.

    After requesting any installs, optionally restarts the shim so that an
    already-installed newer shim version takes effect — but only when nothing
    is still installing and the shim reports it is safe to restart.
    """
    try:
        components = shim_client.get_components()
    except requests.RequestException as exc:
        logger.warning(
            "Instance %s: shim.get_components(): request error: %s", instance_model.name, exc
        )
        return
    if components is None:
        logger.debug("Instance %s: no components info", instance_model.name)
        return

    installed_shim_version: Optional[str] = None
    installation_requested = False

    runner_info = components.runner
    if runner_info is None:
        logger.debug("Instance %s: no runner info", instance_model.name)
    else:
        installation_requested |= _maybe_install_runner(instance_model, shim_client, runner_info)

    shim_info = components.shim
    if shim_info is None:
        logger.debug("Instance %s: no shim info", instance_model.name)
    else:
        if shim_info.status == ComponentStatus.INSTALLED:
            installed_shim_version = shim_info.version
        installation_requested |= _maybe_install_shim(instance_model, shim_client, shim_info)

    running_shim_version = shim_client.get_version_string()
    restart_pointless_or_unsafe = (
        installed_shim_version is None
        or installed_shim_version == running_shim_version
        or installation_requested
        or any(component.status == ComponentStatus.INSTALLING for component in components)
        or not shim_client.is_safe_to_restart()
    )
    if restart_pointless_or_unsafe:
        return

    # A graceful shutdown lets the supervisor restart the newly installed shim.
    if shim_client.shutdown(force=False):
        logger.debug(
            "Instance %s: restarting shim %s -> %s",
            instance_model.name,
            running_shim_version,
            installed_shim_version,
        )
    else:
        logger.debug("Instance %s: cannot restart shim", instance_model.name)
def _maybe_install_runner(
    instance_model: InstanceModel,
    shim_client: runner_client.ShimClient,
    runner_info: ComponentInfo,
) -> bool:
    """Ask the shim to install the expected runner version when needed.

    Returns True when an installation was requested, False otherwise.
    """
    target_version = get_dstack_runner_version()
    if target_version is None:
        logger.debug("Cannot determine the expected runner version")
        return False

    current_version = runner_info.version
    logger.debug(
        "Instance %s: runner status=%s installed_version=%s",
        instance_model.name,
        runner_info.status.value,
        current_version or "(no version)",
    )
    if runner_info.status == ComponentStatus.INSTALLING:
        logger.debug("Instance %s: runner is already being installed", instance_model.name)
        return False
    if current_version and current_version == target_version:
        logger.debug("Instance %s: expected runner version already installed", instance_model.name)
        return False

    download_url = get_dstack_runner_download_url(
        arch=_get_instance_cpu_arch(instance_model),
        version=target_version,
    )
    logger.debug(
        "Instance %s: installing runner %s -> %s from %s",
        instance_model.name,
        current_version or "(no version)",
        target_version,
        download_url,
    )
    try:
        shim_client.install_runner(download_url)
    except requests.RequestException as exc:
        logger.warning("Instance %s: shim.install_runner(): %s", instance_model.name, exc)
        return False
    return True


def _maybe_install_shim(
    instance_model: InstanceModel,
    shim_client: runner_client.ShimClient,
    shim_info: ComponentInfo,
) -> bool:
    """Ask the shim to download/install the expected shim version when needed.

    Returns True when an installation was requested, False otherwise.
    """
    target_version = get_dstack_shim_version()
    if target_version is None:
        logger.debug("Cannot determine the expected shim version")
        return False

    current_version = shim_info.version
    logger.debug(
        "Instance %s: shim status=%s installed_version=%s running_version=%s",
        instance_model.name,
        shim_info.status.value,
        current_version or "(no version)",
        shim_client.get_version_string(),
    )
    if shim_info.status == ComponentStatus.INSTALLING:
        logger.debug("Instance %s: shim is already being installed", instance_model.name)
        return False
    if current_version and current_version == target_version:
        logger.debug("Instance %s: expected shim version already installed", instance_model.name)
        return False

    download_url = get_dstack_shim_download_url(
        arch=_get_instance_cpu_arch(instance_model),
        version=target_version,
    )
    logger.debug(
        "Instance %s: installing shim %s -> %s from %s",
        instance_model.name,
        current_version or "(no version)",
        target_version,
        download_url,
    )
    try:
        shim_client.install_shim(download_url)
    except requests.RequestException as exc:
        logger.warning("Instance %s: shim.install_shim(): %s", instance_model.name, exc)
        return False
    return True


def _get_instance_cpu_arch(instance_model: InstanceModel) -> Optional[gpuhunt.CPUArchitecture]:
    """CPU architecture from the instance's provisioning data, if known."""
    jpd = get_instance_provisioning_data(instance_model)
    return None if jpd is None else jpd.instance_type.resources.cpu_arch
dataclass +from typing import Optional, cast + +from pydantic import ValidationError +from sqlalchemy import select +from sqlalchemy.orm import joinedload + +from dstack._internal.core.backends.base.compute import ( + ComputeWithCreateInstanceSupport, + ComputeWithPlacementGroupSupport, + generate_unique_placement_group_name, +) +from dstack._internal.core.backends.features import ( + BACKENDS_WITH_CREATE_INSTANCE_SUPPORT, + BACKENDS_WITH_PLACEMENT_GROUPS_SUPPORT, +) +from dstack._internal.core.errors import ( + BackendError, + PlacementGroupNotSupportedError, +) +from dstack._internal.core.models.instances import ( + InstanceOfferWithAvailability, + InstanceStatus, + InstanceTerminationReason, +) +from dstack._internal.core.models.placement import ( + PlacementGroup, + PlacementGroupConfiguration, + PlacementStrategy, +) +from dstack._internal.server import settings as server_settings +from dstack._internal.server.background.pipeline_tasks.base import NOW_PLACEHOLDER +from dstack._internal.server.background.pipeline_tasks.instances.common import ( + PlacementGroupCreate, + ProcessResult, + SiblingInstanceUpdateMap, + append_sibling_status_event, + get_fleet_master_instance, + get_instance_offer_for_instance, + need_to_wait_fleet_provisioning, + set_status_update, +) +from dstack._internal.server.db import get_session_ctx +from dstack._internal.server.models import InstanceModel, PlacementGroupModel +from dstack._internal.server.services.fleets import get_create_instance_offers, is_cloud_cluster +from dstack._internal.server.services.instances import ( + get_instance_configuration, + get_instance_profile, + get_instance_requirements, +) +from dstack._internal.server.services.logging import fmt +from dstack._internal.server.services.placement import placement_group_model_to_placement_group +from dstack._internal.utils.common import get_or_error, run_async +from dstack._internal.utils.logging import get_logger + +logger = get_logger(__name__) + + +@dataclass +class 
@dataclass
class _PlacementGroupState:
    # In-memory view of a placement group row, plus an optional pending insert
    # payload when the group was created during this processing round.
    id: uuid.UUID
    placement_group: PlacementGroup
    create_payload: Optional[PlacementGroupCreate] = None


async def create_cloud_instance(instance_model: InstanceModel) -> ProcessResult:
    """Try to provision a cloud instance for `instance_model`.

    Walks the available offers (up to MAX_OFFERS_TRIED), creating a placement
    group first when the fleet is a cloud cluster and the offer's backend
    supports them. On the first successful backend `create_instance` call the
    instance moves to PROVISIONING; when no offer works it is TERMINATED
    (and, for a failed cluster master, all sibling instances with it).
    """
    result = ProcessResult()
    master_instance_model = get_fleet_master_instance(instance_model)
    if need_to_wait_fleet_provisioning(instance_model, master_instance_model):
        logger.debug(
            "%s: waiting for the first instance in the fleet to be provisioned",
            fmt(instance_model),
        )
        return result

    try:
        instance_configuration = get_instance_configuration(instance_model)
        profile = get_instance_profile(instance_model)
        requirements = get_instance_requirements(instance_model)
    except ValidationError as exc:
        logger.exception(
            "%s: error parsing profile, requirements or instance configuration",
            fmt(instance_model),
        )
        set_status_update(
            update_map=result.instance_update_map,
            instance_model=instance_model,
            new_status=InstanceStatus.TERMINATED,
            termination_reason=InstanceTerminationReason.ERROR,
            # FIX: the previous message had broken grammar ("Error to parse ...")
            # and was inconsistent with the log message above.
            termination_reason_message=(
                f"Failed to parse profile, requirements or instance configuration: {exc}"
            ),
        )
        return result

    placement_group_states = await _get_fleet_placement_group_states(instance_model.fleet_id)
    placement_group_state = _get_placement_group_state_for_instance(
        placement_group_states=placement_group_states,
        instance_model=instance_model,
        master_instance_model=master_instance_model,
    )
    offers = await get_create_instance_offers(
        project=instance_model.project,
        profile=profile,
        requirements=requirements,
        fleet_model=instance_model.fleet,
        placement_group=(
            placement_group_state.placement_group if placement_group_state is not None else None
        ),
        blocks="auto" if instance_model.total_blocks is None else instance_model.total_blocks,
        exclude_not_available=True,
    )

    # Track group ids already known/created so each pending insert is recorded once.
    seen_placement_group_ids = {state.id for state in placement_group_states}
    for backend, instance_offer in offers[: server_settings.MAX_OFFERS_TRIED]:
        if instance_offer.backend not in BACKENDS_WITH_CREATE_INSTANCE_SUPPORT:
            continue
        compute = backend.compute()
        assert isinstance(compute, ComputeWithCreateInstanceSupport)
        selected_offer = get_instance_offer_for_instance(
            instance_offer=instance_offer,
            instance_model=instance_model,
            master_instance_model=master_instance_model,
        )
        selected_placement_group_state = placement_group_state
        # Only the cluster master creates placement groups; other instances
        # reuse the group chosen above.
        if (
            instance_model.fleet is not None
            and is_cloud_cluster(instance_model.fleet)
            and instance_model.id == master_instance_model.id
            and selected_offer.backend in BACKENDS_WITH_PLACEMENT_GROUPS_SUPPORT
            and isinstance(compute, ComputeWithPlacementGroupSupport)
            and (
                compute.are_placement_groups_compatible_with_reservations(selected_offer.backend)
                or instance_configuration.reservation is None
            )
        ):
            selected_placement_group_state = await _find_or_create_suitable_placement_group_state(
                instance_model=instance_model,
                placement_group_states=placement_group_states,
                instance_offer=selected_offer,
                compute=compute,
            )
            if selected_placement_group_state is None:
                # Placement group unavailable for this offer; try the next one.
                continue
            if (
                selected_placement_group_state.create_payload is not None
                and selected_placement_group_state.id not in seen_placement_group_ids
            ):
                seen_placement_group_ids.add(selected_placement_group_state.id)
                placement_group_states.append(selected_placement_group_state)
                result.placement_group_creates.append(
                    selected_placement_group_state.create_payload
                )

        logger.debug(
            "Trying %s in %s/%s for $%0.4f per hour",
            selected_offer.instance.name,
            selected_offer.backend.value,
            selected_offer.region,
            selected_offer.price,
        )
        try:
            job_provisioning_data = await run_async(
                compute.create_instance,
                selected_offer,
                instance_configuration,
                selected_placement_group_state.placement_group
                if selected_placement_group_state is not None
                else None,
            )
        except BackendError as exc:
            logger.warning(
                "%s launch in %s/%s failed: %s",
                selected_offer.instance.name,
                selected_offer.backend.value,
                selected_offer.region,
                repr(exc),
                extra={"instance_name": instance_model.name},
            )
            continue
        except Exception:
            logger.exception(
                "Got exception when launching %s in %s/%s",
                selected_offer.instance.name,
                selected_offer.backend.value,
                selected_offer.region,
            )
            continue

        # Success: record everything the provisioning step produced.
        set_status_update(
            update_map=result.instance_update_map,
            instance_model=instance_model,
            new_status=InstanceStatus.PROVISIONING,
        )
        result.instance_update_map["backend"] = backend.TYPE
        result.instance_update_map["region"] = selected_offer.region
        result.instance_update_map["price"] = selected_offer.price
        result.instance_update_map["instance_configuration"] = instance_configuration.json()
        result.instance_update_map["job_provisioning_data"] = job_provisioning_data.json()
        result.instance_update_map["offer"] = selected_offer.json()
        result.instance_update_map["total_blocks"] = selected_offer.total_blocks
        result.instance_update_map["started_at"] = NOW_PLACEHOLDER

        if instance_model.fleet_id is not None and instance_model.id == master_instance_model.id:
            # The master keeps only the group it actually used; others are scheduled
            # for deletion.
            result.schedule_pg_deletion_fleet_id = instance_model.fleet_id
            if selected_placement_group_state is not None:
                result.schedule_pg_deletion_except_ids = (selected_placement_group_state.id,)
        return result

    # No offer succeeded.
    set_status_update(
        update_map=result.instance_update_map,
        instance_model=instance_model,
        new_status=InstanceStatus.TERMINATED,
        termination_reason=InstanceTerminationReason.NO_OFFERS,
        termination_reason_message="All offers failed" if offers else "No offers found",
    )
    if (
        instance_model.fleet is not None
        and instance_model.id == master_instance_model.id
        and is_cloud_cluster(instance_model.fleet)
    ):
        # A cluster cannot exist without its master: terminate the siblings too.
        for sibling_instance_model in instance_model.fleet.instances:
            if sibling_instance_model.id == instance_model.id:
                continue
            sibling_update_map = SiblingInstanceUpdateMap(id=sibling_instance_model.id)
            set_status_update(
                update_map=sibling_update_map,
                instance_model=sibling_instance_model,
                new_status=InstanceStatus.TERMINATED,
                termination_reason=InstanceTerminationReason.MASTER_FAILED,
            )
            if len(sibling_update_map) > 1:
                result.sibling_update_rows.append(sibling_update_map)
                append_sibling_status_event(
                    deferred_events=result.sibling_deferred_events,
                    instance_model=sibling_instance_model,
                    new_status=InstanceStatus.TERMINATED,
                    termination_reason=cast(
                        Optional[InstanceTerminationReason],
                        sibling_update_map.get("termination_reason"),
                    ),
                    termination_reason_message=cast(
                        Optional[str], sibling_update_map.get("termination_reason_message")
                    ),
                )
    return result


async def _get_fleet_placement_group_states(
    fleet_id: Optional[uuid.UUID],
) -> list[_PlacementGroupState]:
    """Load the fleet's live (not deleted) placement groups from the DB."""
    if fleet_id is None:
        return []
    async with get_session_ctx() as session:
        res = await session.execute(
            select(PlacementGroupModel)
            .where(
                PlacementGroupModel.fleet_id == fleet_id,
                PlacementGroupModel.deleted == False,
                PlacementGroupModel.fleet_deleted == False,
            )
            .options(joinedload(PlacementGroupModel.project))
        )
        placement_group_models = list(res.unique().scalars().all())
        return [
            _PlacementGroupState(
                id=placement_group_model.id,
                placement_group=placement_group_model_to_placement_group(placement_group_model),
            )
            for placement_group_model in placement_group_models
        ]


def _get_placement_group_state_for_instance(
    placement_group_states: list[_PlacementGroupState],
    instance_model: InstanceModel,
    master_instance_model: InstanceModel,
) -> Optional[_PlacementGroupState]:
    """Pick the placement group a non-master instance should launch into.

    The master selects (or creates) its group later, per offer, so it gets
    None here. A fleet is expected to have at most one live group.
    """
    if instance_model.id == master_instance_model.id:
        return None
    if len(placement_group_states) > 1:
        logger.error(
            (
                "Expected 0 or 1 placement groups associated with fleet %s, found %s."
                " An incorrect placement group might have been selected for instance %s"
            ),
            instance_model.fleet_id,
            len(placement_group_states),
            instance_model.name,
        )
    if placement_group_states:
        return placement_group_states[0]
    return None
async def _find_or_create_suitable_placement_group_state(
    instance_model: InstanceModel,
    placement_group_states: list[_PlacementGroupState],
    instance_offer: InstanceOfferWithAvailability,
    compute: ComputeWithPlacementGroupSupport,
) -> Optional[_PlacementGroupState]:
    """Return an existing placement group suitable for the offer, or create one.

    Returns None when this offer should be skipped: placement groups are not
    supported for it, or creating the group failed.
    """
    # Reuse a known group first.
    for candidate in placement_group_states:
        if compute.is_suitable_placement_group(candidate.placement_group, instance_offer):
            return candidate

    assert instance_model.fleet is not None
    new_group_id = uuid.uuid4()
    group_name = generate_unique_placement_group_name(
        project_name=instance_model.project.name,
        fleet_name=instance_model.fleet.name,
    )
    new_group = PlacementGroup(
        name=group_name,
        project_name=instance_model.project.name,
        configuration=PlacementGroupConfiguration(
            backend=instance_offer.backend,
            region=instance_offer.region,
            placement_strategy=PlacementStrategy.CLUSTER,
        ),
        provisioning_data=None,
    )
    logger.debug(
        "Creating placement group %s in %s/%s",
        new_group.name,
        new_group.configuration.backend.value,
        new_group.configuration.region,
    )
    try:
        provisioning_data = await run_async(
            compute.create_placement_group,
            new_group,
            instance_offer,
        )
    except PlacementGroupNotSupportedError:
        logger.debug(
            "Skipping offer %s because placement group not supported",
            instance_offer.instance.name,
        )
        return None
    except BackendError as exc:
        logger.warning(
            "Failed to create placement group %s in %s/%s: %r",
            new_group.name,
            new_group.configuration.backend.value,
            new_group.configuration.region,
            exc,
        )
        return None
    except Exception:
        logger.exception(
            "Got exception when creating placement group %s in %s/%s",
            new_group.name,
            new_group.configuration.backend.value,
            new_group.configuration.region,
        )
        return None

    new_group.provisioning_data = provisioning_data
    return _PlacementGroupState(
        id=new_group_id,
        placement_group=new_group,
        create_payload=PlacementGroupCreate(
            id=new_group_id,
            name=new_group.name,
            project_id=instance_model.project_id,
            fleet_id=get_or_error(instance_model.fleet_id),
            configuration=new_group.configuration.json(),
            provisioning_data=provisioning_data.json(),
        ),
    )
TERMINATION_DEADLINE_OFFSET = timedelta(minutes=20)
TERMINATION_RETRY_TIMEOUT = timedelta(seconds=30)
TERMINATION_RETRY_MAX_DURATION = timedelta(minutes=15)
PROVISIONING_TIMEOUT_SECONDS = 10 * 60  # 10 minutes in seconds

# Sentinel distinguishing "argument not supplied" from an explicit None.
_UNSET = object()


class InstanceUpdateMap(ItemUpdateMap, total=False):
    # Column -> new value map applied to an InstanceModel row.
    status: InstanceStatus
    unreachable: bool
    started_at: UpdateMapDateTime
    finished_at: UpdateMapDateTime
    instance_configuration: str
    termination_deadline: Optional[datetime.datetime]
    termination_reason: Optional[InstanceTerminationReason]
    termination_reason_message: Optional[str]
    health: HealthStatus
    first_termination_retry_at: UpdateMapDateTime
    last_termination_retry_at: UpdateMapDateTime
    backend: BackendType
    backend_data: Optional[str]
    offer: str
    region: str
    price: float
    job_provisioning_data: str
    total_blocks: int
    busy_blocks: int
    deleted: bool
    deleted_at: UpdateMapDateTime


class SiblingInstanceUpdateMap(TypedDict, total=False):
    # Update map for a sibling instance in the same fleet, keyed by its id.
    id: uuid.UUID
    status: InstanceStatus
    termination_reason: Optional[InstanceTerminationReason]
    termination_reason_message: Optional[str]


class HealthCheckCreate(TypedDict):
    # Row payload for inserting an instance health check record.
    instance_id: uuid.UUID
    collected_at: datetime.datetime
    status: HealthStatus
    response: str


class PlacementGroupCreate(TypedDict):
    # Row payload for inserting a placement group record.
    id: uuid.UUID
    name: str
    project_id: uuid.UUID
    fleet_id: uuid.UUID
    configuration: str
    provisioning_data: str


@dataclass
class SiblingDeferredEvent:
    # Status-change event for a sibling instance, emitted after the DB update.
    message: str
    project_id: uuid.UUID
    instance_id: uuid.UUID
    instance_name: str


@dataclass
class ProcessResult:
    """Aggregated outcome of one instance-processing step.

    Processing functions fill this in; the caller applies it to the DB.
    """

    instance_update_map: InstanceUpdateMap = field(default_factory=InstanceUpdateMap)
    sibling_update_rows: list[SiblingInstanceUpdateMap] = field(default_factory=list)
    sibling_deferred_events: list[SiblingDeferredEvent] = field(default_factory=list)
    health_check_create: Optional[HealthCheckCreate] = None
    placement_group_creates: list[PlacementGroupCreate] = field(default_factory=list)
    schedule_pg_deletion_fleet_id: Optional[uuid.UUID] = None
    schedule_pg_deletion_except_ids: tuple[uuid.UUID, ...] = ()
def can_terminate_fleet_instances_on_idle_duration(fleet_model: FleetModel) -> bool:
    """Tell whether idle instances of the fleet may be terminated.

    Fleets with a fixed `nodes` range must keep at least `nodes.min` active
    instances; autocreated fleets and fleets without `nodes` have no floor.
    """
    fleet = fleet_model_to_fleet(fleet_model)
    if fleet.spec.configuration.nodes is None or fleet.spec.autocreated:
        return True
    active = [i for i in fleet_model.instances if i.status.is_active()]
    return len(active) > fleet.spec.configuration.nodes.min


def get_fleet_master_instance(instance_model: InstanceModel) -> InstanceModel:
    """Return the fleet master: the lowest (instance_num, created_at) instance.

    Falls back to the instance itself when it has no fleet. The instance is
    considered alongside the fleet's instances even when it is not (yet)
    present in the fleet's instance list.
    """
    if instance_model.fleet is None:
        return instance_model
    candidates = list(instance_model.fleet.instances)
    if all(candidate.id != instance_model.id for candidate in candidates):
        candidates.append(instance_model)
    return min(candidates, key=lambda c: (c.instance_num, c.created_at))


def need_to_wait_fleet_provisioning(
    instance_model: InstanceModel,
    master_instance_model: InstanceModel,
) -> bool:
    """Whether this instance must wait for the cluster master to provision first."""
    if instance_model.fleet is None:
        return False
    master_done_or_failed = (
        instance_model.id == master_instance_model.id
        or master_instance_model.job_provisioning_data is not None
        or master_instance_model.status == InstanceStatus.TERMINATED
    )
    if master_done_or_failed:
        return False
    # Only cloud clusters provision the master first.
    return is_cloud_cluster(instance_model.fleet)


def get_instance_offer_for_instance(
    instance_offer: InstanceOfferWithAvailability,
    instance_model: InstanceModel,
    master_instance_model: InstanceModel,
) -> InstanceOfferWithAvailability:
    """For cluster fleets, restrict the offer to the master's availability zone."""
    if instance_model.fleet is None:
        return instance_offer
    fleet = fleet_model_to_fleet(instance_model.fleet)
    if fleet.spec.configuration.placement != InstanceGroupPlacement.CLUSTER:
        return instance_offer
    master_jpd = get_instance_provisioning_data(master_instance_model)
    return get_instance_offer_with_restricted_az(
        instance_offer=instance_offer,
        master_job_provisioning_data=master_jpd,
    )


def get_instance_idle_duration(instance_model: InstanceModel) -> datetime.timedelta:
    """Time since the instance last processed a job (or since its creation)."""
    reference = instance_model.created_at
    if instance_model.last_job_processed_at is not None:
        reference = instance_model.last_job_processed_at
    return get_current_datetime() - reference


def get_provisioning_deadline(
    instance_model: InstanceModel,
    job_provisioning_data: JobProvisioningData,
) -> datetime.datetime:
    """Deadline by which the instance must finish provisioning."""
    assert instance_model.started_at is not None
    timeout = get_provisioning_timeout(
        backend_type=job_provisioning_data.get_base_backend(),
        instance_type_name=job_provisioning_data.instance_type.name,
    )
    return instance_model.started_at + timeout


def next_termination_retry_at(last_termination_retry_at: datetime.datetime) -> datetime.datetime:
    """When the next termination attempt should be made."""
    return last_termination_retry_at + TERMINATION_RETRY_TIMEOUT


def get_termination_deadline(first_termination_retry_at: datetime.datetime) -> datetime.datetime:
    """After this moment, termination retries are abandoned."""
    return first_termination_retry_at + TERMINATION_RETRY_MAX_DURATION


def ssh_keys_to_pkeys(ssh_keys: list[SSHKey]) -> list[PKey]:
    """Parse the private parts of the given SSH keys into paramiko PKeys."""
    return [
        pkey_from_str(ssh_key.private)
        for ssh_key in ssh_keys
        if ssh_key.private is not None
    ]
def set_status_update(
    update_map: Union[InstanceUpdateMap, SiblingInstanceUpdateMap],
    instance_model: InstanceModel,
    new_status: InstanceStatus,
    termination_reason: object = _UNSET,
    termination_reason_message: object = _UNSET,
) -> None:
    """Record a status change (and termination details) in `update_map`.

    When the status does not actually change, only explicitly supplied
    termination fields are recorded. When it does change, both termination
    fields are recorded — falling back to the instance's current values for
    any not supplied — along with the new status.
    """
    if instance_model.status == new_status:
        # No transition: record only what the caller explicitly passed.
        if termination_reason is not _UNSET:
            update_map["termination_reason"] = cast(
                Optional[InstanceTerminationReason], termination_reason
            )
        if termination_reason_message is not _UNSET:
            update_map["termination_reason_message"] = cast(
                Optional[str], termination_reason_message
            )
        return

    reason = instance_model.termination_reason
    if termination_reason is not _UNSET:
        reason = cast(Optional[InstanceTerminationReason], termination_reason)
    update_map["termination_reason"] = reason

    message = instance_model.termination_reason_message
    if termination_reason_message is not _UNSET:
        message = cast(Optional[str], termination_reason_message)
    update_map["termination_reason_message"] = message

    update_map["status"] = new_status


def set_health_update(
    update_map: InstanceUpdateMap,
    instance_model: InstanceModel,
    health: HealthStatus,
) -> None:
    """Record a health change in `update_map`, skipping no-op updates."""
    if instance_model.health != health:
        update_map["health"] = health


def set_unreachable_update(
    update_map: InstanceUpdateMap,
    instance_model: InstanceModel,
    unreachable: bool,
) -> None:
    """Record a reachability change in `update_map` for available instances only."""
    if instance_model.status.is_available() and instance_model.unreachable != unreachable:
        update_map["unreachable"] = unreachable


def append_sibling_status_event(
    deferred_events: list[SiblingDeferredEvent],
    instance_model: InstanceModel,
    new_status: InstanceStatus,
    termination_reason: Optional[InstanceTerminationReason],
    termination_reason_message: Optional[str],
) -> None:
    """Queue a status-change event for a sibling instance (no-op when unchanged)."""
    if instance_model.status == new_status:
        return
    deferred_events.append(
        SiblingDeferredEvent(
            message=get_instance_status_change_message(
                old_status=instance_model.status,
                new_status=new_status,
                termination_reason=termination_reason,
                termination_reason_message=termination_reason_message,
            ),
            project_id=instance_model.project_id,
            instance_id=instance_model.id,
            instance_name=instance_model.name,
        )
    )
a/src/dstack/_internal/server/background/pipeline_tasks/instances/ssh_deploy.py b/src/dstack/_internal/server/background/pipeline_tasks/instances/ssh_deploy.py new file mode 100644 index 0000000000..312f7977ef --- /dev/null +++ b/src/dstack/_internal/server/background/pipeline_tasks/instances/ssh_deploy.py @@ -0,0 +1,297 @@ +import asyncio +from datetime import timedelta +from typing import Any, Optional + +from paramiko.pkey import PKey +from paramiko.ssh_exception import PasswordRequiredException +from pydantic import ValidationError + +from dstack._internal import settings +from dstack._internal.core.backends.base.compute import ( + GoArchType, + get_dstack_runner_binary_path, + get_dstack_shim_binary_path, + get_dstack_working_dir, + get_shim_env, + get_shim_pre_start_commands, +) +from dstack._internal.core.errors import SSHProvisioningError +from dstack._internal.core.models.backends.base import BackendType +from dstack._internal.core.models.instances import ( + InstanceAvailability, + InstanceOfferWithAvailability, + InstanceRuntime, + InstanceStatus, + InstanceTerminationReason, + RemoteConnectionInfo, +) +from dstack._internal.core.models.runs import JobProvisioningData +from dstack._internal.server.background.pipeline_tasks.base import NOW_PLACEHOLDER +from dstack._internal.server.background.pipeline_tasks.instances.common import ( + PROVISIONING_TIMEOUT_SECONDS, + ProcessResult, + set_status_update, + ssh_keys_to_pkeys, +) +from dstack._internal.server.models import InstanceModel +from dstack._internal.server.schemas.instances import InstanceCheck +from dstack._internal.server.schemas.runner import HealthcheckResponse +from dstack._internal.server.services.instances import get_instance_remote_connection_info +from dstack._internal.server.services.logging import fmt +from dstack._internal.server.services.offers import is_divisible_into_blocks +from dstack._internal.server.services.runner import client as runner_client +from 
dstack._internal.server.services.ssh_fleets.provisioning import ( + detect_cpu_arch, + get_host_info, + get_paramiko_connection, + get_shim_healthcheck, + host_info_to_instance_type, + remove_dstack_runner_if_exists, + remove_host_info_if_exists, + run_pre_start_commands, + run_shim_as_systemd_service, + upload_envs, +) +from dstack._internal.utils.common import get_current_datetime, run_async +from dstack._internal.utils.logging import get_logger +from dstack._internal.utils.network import get_ip_from_network, is_ip_among_addresses + +logger = get_logger(__name__) + + +async def add_ssh_instance(instance_model: InstanceModel) -> ProcessResult: + result = ProcessResult() + logger.info("Adding ssh instance %s...", instance_model.name) + + retry_duration_deadline = instance_model.created_at + timedelta( + seconds=PROVISIONING_TIMEOUT_SECONDS + ) + if retry_duration_deadline < get_current_datetime(): + set_status_update( + update_map=result.instance_update_map, + instance_model=instance_model, + new_status=InstanceStatus.TERMINATED, + termination_reason=InstanceTerminationReason.PROVISIONING_TIMEOUT, + termination_reason_message=( + f"Failed to add SSH instance in {PROVISIONING_TIMEOUT_SECONDS}s" + ), + ) + return result + + remote_details = get_instance_remote_connection_info(instance_model) + assert remote_details is not None + + try: + pkeys = ssh_keys_to_pkeys(remote_details.ssh_keys) + ssh_proxy_pkeys = None + if remote_details.ssh_proxy_keys is not None: + ssh_proxy_pkeys = ssh_keys_to_pkeys(remote_details.ssh_proxy_keys) + except (ValueError, PasswordRequiredException): + set_status_update( + update_map=result.instance_update_map, + instance_model=instance_model, + new_status=InstanceStatus.TERMINATED, + termination_reason=InstanceTerminationReason.ERROR, + termination_reason_message="Unsupported private SSH key type", + ) + return result + + authorized_keys = [pkey.public.strip() for pkey in remote_details.ssh_keys] + 
authorized_keys.append(instance_model.project.ssh_public_key.strip()) + + try: + future = run_async( + _deploy_instance, + remote_details, + pkeys, + ssh_proxy_pkeys, + authorized_keys, + ) + health, host_info, arch = await asyncio.wait_for(future, timeout=20 * 60) + except (asyncio.TimeoutError, TimeoutError) as exc: + logger.warning( + "%s: deploy timeout when adding SSH instance: %s", + fmt(instance_model), + repr(exc), + ) + return result + except SSHProvisioningError as exc: + logger.warning( + "%s: provisioning error when adding SSH instance: %s", + fmt(instance_model), + repr(exc), + ) + return result + except Exception: + logger.exception("%s: unexpected error when adding SSH instance", fmt(instance_model)) + set_status_update( + update_map=result.instance_update_map, + instance_model=instance_model, + new_status=InstanceStatus.TERMINATED, + termination_reason=InstanceTerminationReason.ERROR, + termination_reason_message="Unexpected error when adding SSH instance", + ) + return result + + instance_type = host_info_to_instance_type(host_info, arch) + try: + instance_network, internal_ip = _resolve_ssh_instance_network(instance_model, host_info) + except _SSHInstanceNetworkResolutionError as exc: + set_status_update( + update_map=result.instance_update_map, + instance_model=instance_model, + new_status=InstanceStatus.TERMINATED, + termination_reason=InstanceTerminationReason.ERROR, + termination_reason_message=str(exc), + ) + return result + + divisible, blocks = is_divisible_into_blocks( + cpu_count=instance_type.resources.cpus, + gpu_count=len(instance_type.resources.gpus), + blocks="auto" if instance_model.total_blocks is None else instance_model.total_blocks, + ) + if not divisible: + set_status_update( + update_map=result.instance_update_map, + instance_model=instance_model, + new_status=InstanceStatus.TERMINATED, + termination_reason=InstanceTerminationReason.ERROR, + termination_reason_message="Cannot split into blocks", + ) + return result + + region 
= instance_model.region + assert region is not None + job_provisioning_data = JobProvisioningData( + backend=BackendType.REMOTE, + instance_type=instance_type, + instance_id="instance_id", + hostname=remote_details.host, + region=region, + price=0, + internal_ip=internal_ip, + instance_network=instance_network, + username=remote_details.ssh_user, + ssh_port=remote_details.port, + dockerized=True, + backend_data=None, + ssh_proxy=remote_details.ssh_proxy, + ) + instance_offer = InstanceOfferWithAvailability( + backend=BackendType.REMOTE, + instance=instance_type, + region=region, + price=0, + availability=InstanceAvailability.AVAILABLE, + instance_runtime=InstanceRuntime.SHIM, + ) + + set_status_update( + update_map=result.instance_update_map, + instance_model=instance_model, + new_status=InstanceStatus.IDLE if health else InstanceStatus.PROVISIONING, + ) + result.instance_update_map["backend"] = BackendType.REMOTE + result.instance_update_map["price"] = 0 + result.instance_update_map["offer"] = instance_offer.json() + result.instance_update_map["job_provisioning_data"] = job_provisioning_data.json() + result.instance_update_map["started_at"] = NOW_PLACEHOLDER + result.instance_update_map["total_blocks"] = blocks + return result + + +class _SSHInstanceNetworkResolutionError(Exception): + pass + + +def _resolve_ssh_instance_network( + instance_model: InstanceModel, + host_info: dict[str, Any], +) -> tuple[Optional[str], Optional[str]]: + instance_network = None + internal_ip = None + try: + default_job_provisioning_data = JobProvisioningData.__response__.parse_raw( + instance_model.job_provisioning_data + ) + instance_network = default_job_provisioning_data.instance_network + internal_ip = default_job_provisioning_data.internal_ip + except ValidationError: + pass + + host_network_addresses = host_info.get("addresses", []) + if internal_ip is None: + internal_ip = get_ip_from_network( + network=instance_network, + addresses=host_network_addresses, + ) + if 
instance_network is not None and internal_ip is None: + raise _SSHInstanceNetworkResolutionError( + "Failed to locate internal IP address on the given network" + ) + if internal_ip is not None and not is_ip_among_addresses( + ip_address=internal_ip, + addresses=host_network_addresses, + ): + raise _SSHInstanceNetworkResolutionError( + "Specified internal IP not found among instance interfaces" + ) + return instance_network, internal_ip + + +def _deploy_instance( + remote_details: RemoteConnectionInfo, + pkeys: list[PKey], + ssh_proxy_pkeys: Optional[list[PKey]], + authorized_keys: list[str], +) -> tuple[InstanceCheck, dict[str, Any], GoArchType]: + with get_paramiko_connection( + remote_details.ssh_user, + remote_details.host, + remote_details.port, + pkeys, + remote_details.ssh_proxy, + ssh_proxy_pkeys, + ) as client: + logger.debug("Connected to %s %s", remote_details.ssh_user, remote_details.host) + + arch = detect_cpu_arch(client) + logger.debug("%s: CPU arch is %s", remote_details.host, arch) + + shim_pre_start_commands = get_shim_pre_start_commands(arch=arch) + run_pre_start_commands(client, shim_pre_start_commands, authorized_keys) + logger.debug("The script for installing dstack has been executed") + + shim_envs = get_shim_env(arch=arch) + try: + fleet_configuration_envs = remote_details.env.as_dict() + except ValueError as exc: + raise SSHProvisioningError(f"Invalid Env: {exc}") from exc + shim_envs.update(fleet_configuration_envs) + dstack_working_dir = get_dstack_working_dir() + dstack_shim_binary_path = get_dstack_shim_binary_path() + dstack_runner_binary_path = get_dstack_runner_binary_path() + upload_envs(client, dstack_working_dir, shim_envs) + logger.debug("The dstack-shim environment variables have been installed") + + remove_host_info_if_exists(client, dstack_working_dir) + remove_dstack_runner_if_exists(client, dstack_runner_binary_path) + + run_shim_as_systemd_service( + client=client, + binary_path=dstack_shim_binary_path, + 
working_dir=dstack_working_dir, + dev=settings.DSTACK_VERSION is None, + ) + + host_info = get_host_info(client, dstack_working_dir) + logger.debug("Received a host_info %s", host_info) + + healthcheck_out = get_shim_healthcheck(client) + try: + healthcheck = HealthcheckResponse.__response__.parse_raw(healthcheck_out) + except ValueError as exc: + raise SSHProvisioningError(f"Cannot parse HealthcheckResponse: {exc}") from exc + instance_check = runner_client.healthcheck_response_to_instance_check(healthcheck) + return instance_check, host_info, arch diff --git a/src/dstack/_internal/server/background/pipeline_tasks/instances/termination.py b/src/dstack/_internal/server/background/pipeline_tasks/instances/termination.py new file mode 100644 index 0000000000..eb1f3c8a39 --- /dev/null +++ b/src/dstack/_internal/server/background/pipeline_tasks/instances/termination.py @@ -0,0 +1,88 @@ +from dstack._internal.core.errors import BackendError, NotYetTerminated +from dstack._internal.core.models.backends.base import BackendType +from dstack._internal.core.models.instances import InstanceStatus +from dstack._internal.server.background.pipeline_tasks.base import NOW_PLACEHOLDER +from dstack._internal.server.background.pipeline_tasks.instances.common import ( + ProcessResult, + get_termination_deadline, + next_termination_retry_at, + set_status_update, +) +from dstack._internal.server.models import InstanceModel +from dstack._internal.server.services import backends as backends_services +from dstack._internal.server.services.instances import get_instance_provisioning_data +from dstack._internal.utils.common import get_current_datetime, run_async +from dstack._internal.utils.logging import get_logger + +logger = get_logger(__name__) + + +async def terminate_instance(instance_model: InstanceModel) -> ProcessResult: + result = ProcessResult() + now = get_current_datetime() + if ( + instance_model.last_termination_retry_at is not None + and 
next_termination_retry_at(instance_model.last_termination_retry_at) > now + ): + return result + + job_provisioning_data = get_instance_provisioning_data(instance_model) + if job_provisioning_data is not None and job_provisioning_data.backend != BackendType.REMOTE: + backend = await backends_services.get_project_backend_by_type( + project=instance_model.project, + backend_type=job_provisioning_data.backend, + ) + if backend is None: + logger.error( + "Failed to terminate instance %s. Backend %s not available.", + instance_model.name, + job_provisioning_data.backend, + ) + else: + logger.debug("Terminating runner instance %s", job_provisioning_data.hostname) + try: + await run_async( + backend.compute().terminate_instance, + job_provisioning_data.instance_id, + job_provisioning_data.region, + job_provisioning_data.backend_data, + ) + except Exception as exc: + first_retry_at = instance_model.first_termination_retry_at + if first_retry_at is None: + first_retry_at = now + result.instance_update_map["first_termination_retry_at"] = NOW_PLACEHOLDER + result.instance_update_map["last_termination_retry_at"] = NOW_PLACEHOLDER + if next_termination_retry_at(now) < get_termination_deadline(first_retry_at): + if isinstance(exc, NotYetTerminated): + logger.debug( + "Instance %s termination in progress: %s", + instance_model.name, + exc, + ) + else: + logger.warning( + "Failed to terminate instance %s. Will retry. Error: %r", + instance_model.name, + exc, + exc_info=not isinstance(exc, BackendError), + ) + return result + logger.error( + "Failed all attempts to terminate instance %s." + " Please terminate the instance manually to avoid unexpected charges." 
+ " Error: %r", + instance_model.name, + exc, + exc_info=not isinstance(exc, BackendError), + ) + + result.instance_update_map["deleted"] = True + result.instance_update_map["deleted_at"] = NOW_PLACEHOLDER + result.instance_update_map["finished_at"] = NOW_PLACEHOLDER + set_status_update( + update_map=result.instance_update_map, + instance_model=instance_model, + new_status=InstanceStatus.TERMINATED, + ) + return result diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_instances.py b/src/tests/_internal/server/background/pipeline_tasks/test_instances.py deleted file mode 100644 index ae2882e78d..0000000000 --- a/src/tests/_internal/server/background/pipeline_tasks/test_instances.py +++ /dev/null @@ -1,2153 +0,0 @@ -import asyncio -import datetime as dt -import logging -import uuid -from collections import defaultdict -from contextlib import contextmanager -from typing import Optional -from unittest.mock import AsyncMock, Mock, call, patch - -import gpuhunt -import pytest -import pytest_asyncio -from freezegun import freeze_time -from sqlalchemy import select -from sqlalchemy.ext.asyncio import AsyncSession - -from dstack._internal.core.backends.base.compute import GoArchType -from dstack._internal.core.errors import ( - BackendError, - NoCapacityError, - NotYetTerminated, - ProvisioningError, - SSHProvisioningError, -) -from dstack._internal.core.models.backends.base import BackendType -from dstack._internal.core.models.fleets import FleetNodesSpec, InstanceGroupPlacement -from dstack._internal.core.models.health import HealthStatus -from dstack._internal.core.models.instances import ( - Gpu, - InstanceAvailability, - InstanceOffer, - InstanceOfferWithAvailability, - InstanceStatus, - InstanceTerminationReason, - InstanceType, - Resources, -) -from dstack._internal.core.models.placement import PlacementGroup, PlacementGroupProvisioningData -from dstack._internal.core.models.profiles import TerminationPolicy -from 
dstack._internal.core.models.runs import JobProvisioningData, JobStatus -from dstack._internal.server.background.pipeline_tasks import instances as instances_pipeline -from dstack._internal.server.background.pipeline_tasks.instances import ( - InstanceFetcher, - InstancePipeline, - InstancePipelineItem, - InstanceWorker, -) -from dstack._internal.server.models import ( - InstanceHealthCheckModel, - InstanceModel, - PlacementGroupModel, -) -from dstack._internal.server.schemas.health.dcgm import DCGMHealthResponse, DCGMHealthResult -from dstack._internal.server.schemas.instances import InstanceCheck -from dstack._internal.server.schemas.runner import ( - ComponentInfo, - ComponentName, - ComponentStatus, - HealthcheckResponse, - InstanceHealthResponse, - TaskListItem, - TaskListResponse, - TaskStatus, -) -from dstack._internal.server.services.runner.client import ComponentList, ShimClient -from dstack._internal.server.testing.common import ( - ComputeMockSpec, - create_compute_group, - create_fleet, - create_instance, - create_job, - create_project, - create_repo, - create_run, - create_user, - get_fleet_configuration, - get_fleet_spec, - get_instance_offer_with_availability, - get_job_provisioning_data, - get_placement_group_provisioning_data, - get_remote_connection_info, - list_events, -) -from dstack._internal.utils.common import get_current_datetime - -pytestmark = pytest.mark.usefixtures("image_config_mock") -LOCK_EXPIRES_AT = dt.datetime(2025, 1, 2, 3, 4, tzinfo=dt.timezone.utc) - - -@pytest.fixture -def fetcher() -> InstanceFetcher: - return InstanceFetcher( - queue=asyncio.Queue(), - queue_desired_minsize=1, - min_processing_interval=dt.timedelta(seconds=10), - lock_timeout=dt.timedelta(seconds=30), - heartbeater=Mock(), - ) - - -@pytest.fixture -def worker() -> InstanceWorker: - return InstanceWorker(queue=asyncio.Queue(), heartbeater=Mock()) - - -@pytest.fixture -def host_info() -> dict: - return { - "gpu_vendor": "nvidia", - "gpu_name": "T4", - 
"gpu_memory": 16384, - "gpu_count": 1, - "addresses": ["192.168.100.100/24"], - "disk_size": 260976517120, - "cpus": 32, - "memory": 33544130560, - } - - -@pytest.fixture -def deploy_instance_mock(monkeypatch: pytest.MonkeyPatch, host_info: dict) -> Mock: - mock = Mock(return_value=(InstanceCheck(reachable=True), host_info, GoArchType.AMD64)) - monkeypatch.setattr(instances_pipeline, "_deploy_instance", mock) - return mock - - -def _instance_to_pipeline_item(instance_model: InstanceModel) -> InstancePipelineItem: - assert instance_model.lock_token is not None - assert instance_model.lock_expires_at is not None - return InstancePipelineItem( - __tablename__=instance_model.__tablename__, - id=instance_model.id, - lock_token=instance_model.lock_token, - lock_expires_at=instance_model.lock_expires_at, - prev_lock_expired=False, - status=instance_model.status, - ) - - -def _lock_instance(instance_model: InstanceModel) -> None: - instance_model.lock_token = uuid.uuid4() - instance_model.lock_expires_at = LOCK_EXPIRES_AT - instance_model.lock_owner = InstancePipeline.__name__ - - -async def _process_instance( - session: AsyncSession, worker: InstanceWorker, instance_model: InstanceModel -) -> None: - _lock_instance(instance_model) - await session.commit() - await worker.process(_instance_to_pipeline_item(instance_model)) - - -@pytest.mark.asyncio -@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) -class TestInstanceFetcher: - async def test_fetch_selects_eligible_instances_and_sets_lock_fields( - self, test_db, session: AsyncSession, fetcher: InstanceFetcher - ): - project = await create_project(session=session) - fleet = await create_fleet(session=session, project=project) - compute_group = await create_compute_group(session=session, project=project, fleet=fleet) - now = get_current_datetime() - stale = now - dt.timedelta(minutes=1) - - pending = await create_instance( - session=session, - project=project, - status=InstanceStatus.PENDING, - 
last_processed_at=stale - dt.timedelta(seconds=5), - ) - provisioning = await create_instance( - session=session, - project=project, - status=InstanceStatus.PROVISIONING, - name="provisioning", - last_processed_at=stale - dt.timedelta(seconds=4), - ) - busy = await create_instance( - session=session, - project=project, - status=InstanceStatus.BUSY, - name="busy", - last_processed_at=stale - dt.timedelta(seconds=3), - ) - idle = await create_instance( - session=session, - project=project, - status=InstanceStatus.IDLE, - name="idle", - last_processed_at=stale - dt.timedelta(seconds=2), - ) - terminating = await create_instance( - session=session, - project=project, - status=InstanceStatus.TERMINATING, - name="terminating", - last_processed_at=stale - dt.timedelta(seconds=1), - ) - - deleted = await create_instance( - session=session, - project=project, - status=InstanceStatus.IDLE, - name="deleted", - last_processed_at=stale, - ) - deleted.deleted = True - - recent = await create_instance( - session=session, - project=project, - status=InstanceStatus.IDLE, - name="recent", - last_processed_at=now, - ) - - terminating_compute_group = await create_instance( - session=session, - project=project, - status=InstanceStatus.TERMINATING, - name="terminating-compute-group", - last_processed_at=stale + dt.timedelta(seconds=1), - ) - terminating_compute_group.compute_group = compute_group - - locked = await create_instance( - session=session, - project=project, - status=InstanceStatus.IDLE, - name="locked", - last_processed_at=stale + dt.timedelta(seconds=2), - ) - locked.lock_expires_at = now + dt.timedelta(minutes=1) - locked.lock_token = uuid.uuid4() - locked.lock_owner = "OtherPipeline" - - await session.commit() - - items = await fetcher.fetch(limit=10) - - assert {item.id for item in items} == { - pending.id, - provisioning.id, - busy.id, - idle.id, - terminating.id, - } - assert {item.status for item in items} == { - InstanceStatus.PENDING, - InstanceStatus.PROVISIONING, 
- InstanceStatus.BUSY, - InstanceStatus.IDLE, - InstanceStatus.TERMINATING, - } - - for instance in [ - pending, - provisioning, - busy, - idle, - terminating, - deleted, - recent, - terminating_compute_group, - locked, - ]: - await session.refresh(instance) - - expected_lock_owner = InstancePipeline.__name__ - fetched_instances = [pending, provisioning, busy, idle, terminating] - assert all(instance.lock_owner == expected_lock_owner for instance in fetched_instances) - assert all(instance.lock_expires_at is not None for instance in fetched_instances) - assert all(instance.lock_token is not None for instance in fetched_instances) - assert len({instance.lock_token for instance in fetched_instances}) == 1 - - assert deleted.lock_owner is None - assert recent.lock_owner is None - assert terminating_compute_group.lock_owner is None - assert locked.lock_owner == "OtherPipeline" - - async def test_fetch_respects_order_and_limit( - self, test_db, session: AsyncSession, fetcher: InstanceFetcher - ): - project = await create_project(session=session) - now = get_current_datetime() - - oldest = await create_instance( - session=session, - project=project, - name="oldest", - last_processed_at=now - dt.timedelta(minutes=3), - ) - middle = await create_instance( - session=session, - project=project, - name="middle", - last_processed_at=now - dt.timedelta(minutes=2), - ) - newest = await create_instance( - session=session, - project=project, - name="newest", - last_processed_at=now - dt.timedelta(minutes=1), - ) - - items = await fetcher.fetch(limit=2) - - assert [item.id for item in items] == [oldest.id, middle.id] - - await session.refresh(oldest) - await session.refresh(middle) - await session.refresh(newest) - - assert oldest.lock_owner == InstancePipeline.__name__ - assert middle.lock_owner == InstancePipeline.__name__ - assert newest.lock_owner is None - - -@pytest.mark.asyncio -@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) -class 
TestInstanceWorker: - @staticmethod - @contextmanager - def mock_terminate_in_backend(error: Optional[Exception] = None): - backend = Mock() - backend.TYPE = BackendType.VERDA - terminate_instance = backend.compute.return_value.terminate_instance - if error is not None: - terminate_instance.side_effect = error - with patch.object( - instances_pipeline.backends_services, - "get_project_backend_by_type", - AsyncMock(return_value=backend), - ): - yield terminate_instance - - async def test_process_skips_when_lock_token_changes( - self, - test_db, - session: AsyncSession, - fetcher: InstanceFetcher, - worker: InstanceWorker, - ): - project = await create_project(session=session) - instance = await create_instance( - session=session, - project=project, - status=InstanceStatus.IDLE, - ) - - _lock_instance(instance) - await session.commit() - item = _instance_to_pipeline_item(instance) - new_lock_token = uuid.uuid4() - instance.lock_token = new_lock_token - await session.commit() - - await worker.process(item) - await session.refresh(instance) - - assert instance.lock_token == new_lock_token - assert instance.lock_owner == InstancePipeline.__name__ - - async def test_process_unlocks_and_updates_last_processed_at_after_check( - self, - test_db, - session: AsyncSession, - fetcher: InstanceFetcher, - worker: InstanceWorker, - monkeypatch: pytest.MonkeyPatch, - ): - project = await create_project(session=session) - instance = await create_instance( - session=session, - project=project, - status=InstanceStatus.PROVISIONING, - ) - before_processed_at = instance.last_processed_at - - monkeypatch.setattr( - instances_pipeline, - "_check_instance_inner", - Mock(return_value=InstanceCheck(reachable=True)), - ) - await _process_instance(session, worker, instance) - - await session.refresh(instance) - - assert instance.status == InstanceStatus.IDLE - assert instance.lock_expires_at is None - assert instance.lock_token is None - assert instance.lock_owner is None - assert 
instance.last_processed_at > before_processed_at - - async def test_check_shim_transitions_provisioning_on_ready( - self, - test_db, - session: AsyncSession, - fetcher: InstanceFetcher, - worker: InstanceWorker, - monkeypatch: pytest.MonkeyPatch, - ): - project = await create_project(session=session) - instance = await create_instance( - session=session, - project=project, - status=InstanceStatus.PROVISIONING, - ) - instance.termination_deadline = get_current_datetime() + dt.timedelta(days=1) - await session.commit() - - monkeypatch.setattr( - instances_pipeline, - "_check_instance_inner", - Mock(return_value=InstanceCheck(reachable=True)), - ) - await _process_instance(session, worker, instance) - - await session.refresh(instance) - - assert instance.status == InstanceStatus.IDLE - assert instance.termination_deadline is None - - async def test_check_shim_transitions_provisioning_on_terminating( - self, - test_db, - session: AsyncSession, - fetcher: InstanceFetcher, - worker: InstanceWorker, - monkeypatch: pytest.MonkeyPatch, - ): - project = await create_project(session=session) - instance = await create_instance( - session=session, - project=project, - status=InstanceStatus.PROVISIONING, - ) - instance.started_at = get_current_datetime() + dt.timedelta(minutes=-20) - await session.commit() - - monkeypatch.setattr( - instances_pipeline, - "_check_instance_inner", - Mock(return_value=InstanceCheck(reachable=False, message="Shim problem")), - ) - await _process_instance(session, worker, instance) - - await session.refresh(instance) - - assert instance.status == InstanceStatus.TERMINATING - assert instance.termination_deadline is not None - - async def test_check_shim_transitions_provisioning_on_busy( - self, - test_db, - session: AsyncSession, - fetcher: InstanceFetcher, - worker: InstanceWorker, - monkeypatch: pytest.MonkeyPatch, - ): - user = await create_user(session=session) - project = await create_project(session=session, owner=user) - repo = await 
create_repo(session=session, project_id=project.id) - run = await create_run(session=session, project=project, repo=repo, user=user) - instance = await create_instance( - session=session, - project=project, - status=InstanceStatus.PROVISIONING, - ) - instance.termination_deadline = get_current_datetime().replace( - tzinfo=dt.timezone.utc - ) + dt.timedelta(days=1) - job = await create_job( - session=session, - run=run, - status=JobStatus.SUBMITTED, - instance=instance, - ) - await session.commit() - - monkeypatch.setattr( - instances_pipeline, - "_check_instance_inner", - Mock(return_value=InstanceCheck(reachable=True)), - ) - await _process_instance(session, worker, instance) - - await session.refresh(instance) - await session.refresh(job) - - assert instance.status == InstanceStatus.BUSY - assert instance.termination_deadline is None - assert job.instance == instance - - async def test_check_shim_start_termination_deadline( - self, - test_db, - session: AsyncSession, - fetcher: InstanceFetcher, - worker: InstanceWorker, - monkeypatch: pytest.MonkeyPatch, - ): - project = await create_project(session=session) - instance = await create_instance( - session=session, - project=project, - status=InstanceStatus.IDLE, - unreachable=False, - ) - - monkeypatch.setattr( - instances_pipeline, - "_check_instance_inner", - Mock(return_value=InstanceCheck(reachable=False, message="SSH connection fail")), - ) - await _process_instance(session, worker, instance) - - await session.refresh(instance) - - assert instance.status == InstanceStatus.IDLE - assert instance.unreachable is True - assert instance.termination_deadline is not None - assert instance.termination_deadline.replace( - tzinfo=dt.timezone.utc - ) > get_current_datetime() + dt.timedelta(minutes=19) - - async def test_check_shim_does_not_start_termination_deadline_with_ssh_instance( - self, - test_db, - session: AsyncSession, - fetcher: InstanceFetcher, - worker: InstanceWorker, - monkeypatch: pytest.MonkeyPatch, - ): 
- project = await create_project(session=session) - instance = await create_instance( - session=session, - project=project, - status=InstanceStatus.IDLE, - unreachable=False, - remote_connection_info=get_remote_connection_info(), - ) - - monkeypatch.setattr( - instances_pipeline, - "_check_instance_inner", - Mock(return_value=InstanceCheck(reachable=False, message="SSH connection fail")), - ) - await _process_instance(session, worker, instance) - - await session.refresh(instance) - - assert instance.status == InstanceStatus.IDLE - assert instance.unreachable is True - assert instance.termination_deadline is None - - async def test_check_shim_stop_termination_deadline( - self, - test_db, - session: AsyncSession, - fetcher: InstanceFetcher, - worker: InstanceWorker, - monkeypatch: pytest.MonkeyPatch, - ): - project = await create_project(session=session) - instance = await create_instance( - session=session, - project=project, - status=InstanceStatus.IDLE, - ) - instance.termination_deadline = get_current_datetime() + dt.timedelta(minutes=19) - await session.commit() - - monkeypatch.setattr( - instances_pipeline, - "_check_instance_inner", - Mock(return_value=InstanceCheck(reachable=True)), - ) - await _process_instance(session, worker, instance) - - await session.refresh(instance) - - assert instance.status == InstanceStatus.IDLE - assert instance.termination_deadline is None - - async def test_check_shim_terminate_instance_by_deadline( - self, - test_db, - session: AsyncSession, - fetcher: InstanceFetcher, - worker: InstanceWorker, - monkeypatch: pytest.MonkeyPatch, - ): - project = await create_project(session=session) - instance = await create_instance( - session=session, - project=project, - status=InstanceStatus.IDLE, - ) - termination_deadline_time = get_current_datetime() + dt.timedelta(minutes=-19) - instance.termination_deadline = termination_deadline_time - await session.commit() - - monkeypatch.setattr( - instances_pipeline, - "_check_instance_inner", - 
Mock(return_value=InstanceCheck(reachable=False, message="Not ok")), - ) - await _process_instance(session, worker, instance) - - await session.refresh(instance) - - assert instance.status == InstanceStatus.TERMINATING - assert instance.termination_deadline == termination_deadline_time - assert instance.termination_reason == InstanceTerminationReason.UNREACHABLE - - @pytest.mark.parametrize( - ["termination_policy", "has_job"], - [ - pytest.param(TerminationPolicy.DESTROY_AFTER_IDLE, False, id="destroy-no-job"), - pytest.param(TerminationPolicy.DESTROY_AFTER_IDLE, True, id="destroy-with-job"), - pytest.param(TerminationPolicy.DONT_DESTROY, False, id="dont-destroy-no-job"), - pytest.param(TerminationPolicy.DONT_DESTROY, True, id="dont-destroy-with-job"), - ], - ) - async def test_check_shim_process_unreachable_state( - self, - test_db, - session: AsyncSession, - fetcher: InstanceFetcher, - worker: InstanceWorker, - monkeypatch: pytest.MonkeyPatch, - termination_policy: TerminationPolicy, - has_job: bool, - ): - project = await create_project(session=session) - if has_job: - user = await create_user(session=session) - repo = await create_repo(session=session, project_id=project.id) - run = await create_run(session=session, project=project, repo=repo, user=user) - job = await create_job( - session=session, - run=run, - status=JobStatus.SUBMITTED, - ) - else: - job = None - instance = await create_instance( - session=session, - project=project, - created_at=get_current_datetime(), - termination_policy=termination_policy, - status=InstanceStatus.IDLE, - unreachable=True, - job=job, - ) - - monkeypatch.setattr( - instances_pipeline, - "_check_instance_inner", - Mock(return_value=InstanceCheck(reachable=True)), - ) - await _process_instance(session, worker, instance) - - await session.refresh(instance) - events = await list_events(session) - - assert instance.status == InstanceStatus.IDLE - assert instance.unreachable is False - assert len(events) == 1 - assert 
events[0].message == "Instance became reachable" - - @pytest.mark.parametrize("health_status", [HealthStatus.HEALTHY, HealthStatus.FAILURE]) - async def test_check_shim_switch_to_unreachable_state( - self, - test_db, - session: AsyncSession, - fetcher: InstanceFetcher, - worker: InstanceWorker, - monkeypatch: pytest.MonkeyPatch, - health_status: HealthStatus, - ): - project = await create_project(session=session) - instance = await create_instance( - session=session, - project=project, - status=InstanceStatus.IDLE, - unreachable=False, - health_status=health_status, - ) - - monkeypatch.setattr( - instances_pipeline, - "_check_instance_inner", - Mock(return_value=InstanceCheck(reachable=False)), - ) - await _process_instance(session, worker, instance) - - await session.refresh(instance) - events = await list_events(session) - - assert instance.status == InstanceStatus.IDLE - assert instance.unreachable is True - assert instance.health == health_status - assert len(events) == 1 - assert events[0].message == "Instance became unreachable" - - async def test_check_shim_check_instance_health( - self, - test_db, - session: AsyncSession, - fetcher: InstanceFetcher, - worker: InstanceWorker, - monkeypatch: pytest.MonkeyPatch, - ): - project = await create_project(session=session) - instance = await create_instance( - session=session, - project=project, - status=InstanceStatus.IDLE, - unreachable=False, - health_status=HealthStatus.HEALTHY, - ) - health_response = InstanceHealthResponse( - dcgm=DCGMHealthResponse( - overall_health=DCGMHealthResult.DCGM_HEALTH_RESULT_WARN, - incidents=[], - ) - ) - - monkeypatch.setattr( - instances_pipeline, - "_check_instance_inner", - Mock( - return_value=InstanceCheck( - reachable=True, - health_response=health_response, - ) - ), - ) - await _process_instance(session, worker, instance) - - await session.refresh(instance) - events = await list_events(session) - - assert instance.status == InstanceStatus.IDLE - assert instance.unreachable 
is False - assert instance.health == HealthStatus.WARNING - assert len(events) == 1 - assert events[0].message == "Instance health changed HEALTHY -> WARNING" - - res = await session.execute(select(InstanceHealthCheckModel)) - health_check = res.scalars().one() - assert health_check.status == HealthStatus.WARNING - assert health_check.response == health_response.json() - - async def test_terminate_by_idle_timeout( - self, - test_db, - session: AsyncSession, - fetcher: InstanceFetcher, - worker: InstanceWorker, - ): - project = await create_project(session=session) - instance = await create_instance( - session=session, - project=project, - status=InstanceStatus.IDLE, - ) - instance.termination_idle_time = 300 - instance.termination_policy = TerminationPolicy.DESTROY_AFTER_IDLE - instance.last_job_processed_at = get_current_datetime() + dt.timedelta(minutes=-19) - await session.commit() - - await _process_instance(session, worker, instance) - await session.refresh(instance) - - assert instance.status == InstanceStatus.TERMINATING - assert instance.termination_reason == InstanceTerminationReason.IDLE_TIMEOUT - - async def test_pending_ssh_instance_terminates_on_provision_timeout( - self, - test_db, - session: AsyncSession, - fetcher: InstanceFetcher, - worker: InstanceWorker, - ): - project = await create_project(session=session) - instance = await create_instance( - session=session, - project=project, - status=InstanceStatus.PENDING, - created_at=get_current_datetime() - dt.timedelta(days=100), - remote_connection_info=get_remote_connection_info(), - ) - await session.commit() - - await _process_instance(session, worker, instance) - - await session.refresh(instance) - assert instance.status == InstanceStatus.TERMINATED - assert instance.termination_reason == InstanceTerminationReason.PROVISIONING_TIMEOUT - - async def test_terminate( - self, - test_db, - session: AsyncSession, - fetcher: InstanceFetcher, - worker: InstanceWorker, - ): - project = await 
create_project(session=session) - instance = await create_instance( - session=session, - project=project, - status=InstanceStatus.TERMINATING, - ) - instance.termination_reason = InstanceTerminationReason.IDLE_TIMEOUT - instance.last_job_processed_at = get_current_datetime() + dt.timedelta(minutes=-19) - await session.commit() - - with self.mock_terminate_in_backend() as mock: - await _process_instance(session, worker, instance) - mock.assert_called_once() - - await session.refresh(instance) - - assert instance.status == InstanceStatus.TERMINATED - assert instance.termination_reason == InstanceTerminationReason.IDLE_TIMEOUT - assert instance.deleted is True - assert instance.deleted_at is not None - assert instance.finished_at is not None - - async def test_terminates_terminating_deleted_instance( - self, - test_db, - session: AsyncSession, - fetcher: InstanceFetcher, - worker: InstanceWorker, - ): - project = await create_project(session=session) - instance = await create_instance( - session=session, - project=project, - status=InstanceStatus.TERMINATING, - ) - _lock_instance(instance) - await session.commit() - item = _instance_to_pipeline_item(instance) - instance.deleted = True - instance.termination_reason = InstanceTerminationReason.IDLE_TIMEOUT - instance.last_job_processed_at = instance.deleted_at = ( - get_current_datetime() + dt.timedelta(minutes=-19) - ) - await session.commit() - - with self.mock_terminate_in_backend() as mock: - await worker.process(item) - mock.assert_called_once() - - await session.refresh(instance) - - assert instance.status == InstanceStatus.TERMINATED - assert instance.deleted is True - assert instance.deleted_at is not None - assert instance.finished_at is not None - - @pytest.mark.parametrize( - "error", [BackendError("err"), RuntimeError("err"), NotYetTerminated("")] - ) - async def test_terminate_retry( - self, - test_db, - session: AsyncSession, - fetcher: InstanceFetcher, - worker: InstanceWorker, - error: Exception, - ): - 
project = await create_project(session=session) - instance = await create_instance( - session=session, - project=project, - status=InstanceStatus.TERMINATING, - ) - instance.termination_reason = InstanceTerminationReason.IDLE_TIMEOUT - initial_time = dt.datetime(2025, 1, 1, tzinfo=dt.timezone.utc) - instance.last_job_processed_at = initial_time - instance.last_processed_at = initial_time - dt.timedelta(minutes=1) - await session.commit() - - with ( - freeze_time(initial_time + dt.timedelta(minutes=1)), - self.mock_terminate_in_backend(error=error) as mock, - ): - await _process_instance(session, worker, instance) - mock.assert_called_once() - await session.refresh(instance) - assert instance.status == InstanceStatus.TERMINATING - - with ( - freeze_time(initial_time + dt.timedelta(minutes=2)), - self.mock_terminate_in_backend(error=None) as mock, - ): - await _process_instance(session, worker, instance) - mock.assert_called_once() - await session.refresh(instance) - assert instance.status == InstanceStatus.TERMINATED - - async def test_terminate_not_retries_if_too_early( - self, - test_db, - session: AsyncSession, - fetcher: InstanceFetcher, - worker: InstanceWorker, - ): - project = await create_project(session=session) - instance = await create_instance( - session=session, - project=project, - status=InstanceStatus.TERMINATING, - ) - instance.termination_reason = InstanceTerminationReason.IDLE_TIMEOUT - initial_time = dt.datetime(2025, 1, 1, tzinfo=dt.timezone.utc) - instance.last_job_processed_at = initial_time - instance.last_processed_at = initial_time - dt.timedelta(minutes=1) - await session.commit() - - with ( - freeze_time(initial_time + dt.timedelta(minutes=1)), - self.mock_terminate_in_backend(error=BackendError("err")) as mock, - ): - await _process_instance(session, worker, instance) - mock.assert_called_once() - await session.refresh(instance) - assert instance.status == InstanceStatus.TERMINATING - - instance.last_processed_at = initial_time - await 
session.commit() - - with ( - freeze_time(initial_time + dt.timedelta(minutes=1, seconds=11)), - self.mock_terminate_in_backend(error=None) as mock, - ): - await _process_instance(session, worker, instance) - mock.assert_not_called() - await session.refresh(instance) - assert instance.status == InstanceStatus.TERMINATING - - async def test_terminate_on_termination_deadline( - self, - test_db, - session: AsyncSession, - fetcher: InstanceFetcher, - worker: InstanceWorker, - ): - project = await create_project(session=session) - instance = await create_instance( - session=session, - project=project, - status=InstanceStatus.TERMINATING, - ) - instance.termination_reason = InstanceTerminationReason.IDLE_TIMEOUT - initial_time = dt.datetime(2025, 1, 1, tzinfo=dt.timezone.utc) - instance.last_job_processed_at = initial_time - instance.last_processed_at = initial_time - dt.timedelta(minutes=1) - await session.commit() - - with ( - freeze_time(initial_time + dt.timedelta(minutes=1)), - self.mock_terminate_in_backend(error=BackendError("err")) as mock, - ): - await _process_instance(session, worker, instance) - mock.assert_called_once() - await session.refresh(instance) - assert instance.status == InstanceStatus.TERMINATING - - with ( - freeze_time(initial_time + dt.timedelta(minutes=15, seconds=55)), - self.mock_terminate_in_backend(error=None) as mock, - ): - await _process_instance(session, worker, instance) - mock.assert_called_once() - await session.refresh(instance) - assert instance.status == InstanceStatus.TERMINATED - - @pytest.mark.parametrize( - ["cpus", "gpus", "requested_blocks", "expected_blocks"], - [ - pytest.param(32, 8, 1, 1, id="gpu-instance-no-blocks"), - pytest.param(32, 8, 2, 2, id="gpu-instance-four-gpu-per-block"), - pytest.param(32, 8, 4, 4, id="gpu-instance-two-gpus-per-block"), - pytest.param(32, 8, None, 8, id="gpu-instance-auto-max-gpu"), - pytest.param(4, 8, None, 4, id="gpu-instance-auto-max-cpu"), - pytest.param(8, 8, None, 8, 
id="gpu-instance-auto-max-cpu-and-gpu"), - pytest.param(32, 0, 1, 1, id="cpu-instance-no-blocks"), - pytest.param(32, 0, 2, 2, id="cpu-instance-four-cpu-per-block"), - pytest.param(32, 0, 4, 4, id="cpu-instance-two-cpus-per-block"), - pytest.param(32, 0, None, 32, id="cpu-instance-auto-max-cpu"), - ], - ) - async def test_creates_instance( - self, - test_db, - session: AsyncSession, - fetcher: InstanceFetcher, - worker: InstanceWorker, - cpus: int, - gpus: int, - requested_blocks: Optional[int], - expected_blocks: int, - ): - project = await create_project(session=session) - instance = await create_instance( - session=session, - project=project, - status=InstanceStatus.PENDING, - total_blocks=requested_blocks, - busy_blocks=0, - ) - with patch("dstack._internal.server.services.backends.get_project_backends") as m: - backend_mock = Mock() - m.return_value = [backend_mock] - backend_mock.TYPE = BackendType.AWS - gpu = Gpu(name="T4", memory_mib=16384, vendor=gpuhunt.AcceleratorVendor.NVIDIA) - offer = InstanceOfferWithAvailability( - backend=BackendType.AWS, - instance=InstanceType( - name="instance", - resources=Resources( - cpus=cpus, - memory_mib=131072, - spot=False, - gpus=[gpu] * gpus, - ), - ), - region="us", - price=1.0, - availability=InstanceAvailability.AVAILABLE, - total_blocks=expected_blocks, - ) - backend_mock.compute.return_value = Mock(spec=ComputeMockSpec) - backend_mock.compute.return_value.get_offers.return_value = [offer] - backend_mock.compute.return_value.create_instance.return_value = JobProvisioningData( - backend=offer.backend, - instance_type=offer.instance, - instance_id="instance_id", - hostname="1.1.1.1", - internal_ip=None, - region=offer.region, - price=offer.price, - username="ubuntu", - ssh_port=22, - ssh_proxy=None, - dockerized=True, - backend_data=None, - ) - - await _process_instance(session, worker, instance) - - await session.refresh(instance) - assert instance.status == InstanceStatus.PROVISIONING - assert instance.total_blocks 
== expected_blocks - assert instance.busy_blocks == 0 - - @pytest.mark.parametrize("err", [RuntimeError("Unexpected"), ProvisioningError("Expected")]) - async def test_tries_second_offer_if_first_fails( - self, - test_db, - session: AsyncSession, - fetcher: InstanceFetcher, - worker: InstanceWorker, - err: Exception, - ): - project = await create_project(session=session) - instance = await create_instance( - session=session, - project=project, - status=InstanceStatus.PENDING, - ) - aws_mock = Mock() - aws_mock.TYPE = BackendType.AWS - offer = get_instance_offer_with_availability(backend=BackendType.AWS, price=1.0) - aws_mock.compute.return_value = Mock(spec=ComputeMockSpec) - aws_mock.compute.return_value.get_offers.return_value = [offer] - aws_mock.compute.return_value.create_instance.side_effect = err - gcp_mock = Mock() - gcp_mock.TYPE = BackendType.GCP - offer = get_instance_offer_with_availability(backend=BackendType.GCP, price=2.0) - gcp_mock.compute.return_value = Mock(spec=ComputeMockSpec) - gcp_mock.compute.return_value.get_offers.return_value = [offer] - gcp_mock.compute.return_value.create_instance.return_value = get_job_provisioning_data( - backend=offer.backend, - region=offer.region, - price=offer.price, - ) - with patch("dstack._internal.server.services.backends.get_project_backends") as m: - m.return_value = [aws_mock, gcp_mock] - await _process_instance(session, worker, instance) - - await session.refresh(instance) - assert instance.status == InstanceStatus.PROVISIONING - aws_mock.compute.return_value.create_instance.assert_called_once() - assert instance.backend == BackendType.GCP - - @pytest.mark.parametrize("err", [RuntimeError("Unexpected"), ProvisioningError("Expected")]) - async def test_fails_if_all_offers_fail( - self, - test_db, - session: AsyncSession, - fetcher: InstanceFetcher, - worker: InstanceWorker, - err: Exception, - ): - project = await create_project(session=session) - instance = await create_instance( - session=session, - 
project=project, - status=InstanceStatus.PENDING, - ) - aws_mock = Mock() - aws_mock.TYPE = BackendType.AWS - offer = get_instance_offer_with_availability(backend=BackendType.AWS, price=1.0) - aws_mock.compute.return_value = Mock(spec=ComputeMockSpec) - aws_mock.compute.return_value.get_offers.return_value = [offer] - aws_mock.compute.return_value.create_instance.side_effect = err - with patch("dstack._internal.server.services.backends.get_project_backends") as m: - m.return_value = [aws_mock] - await _process_instance(session, worker, instance) - - await session.refresh(instance) - assert instance.status == InstanceStatus.TERMINATED - assert instance.termination_reason == InstanceTerminationReason.NO_OFFERS - - async def test_fails_if_no_offers( - self, - test_db, - session: AsyncSession, - fetcher: InstanceFetcher, - worker: InstanceWorker, - ): - project = await create_project(session=session) - instance = await create_instance( - session=session, - project=project, - status=InstanceStatus.PENDING, - ) - with patch("dstack._internal.server.services.backends.get_project_backends") as m: - m.return_value = [] - await _process_instance(session, worker, instance) - - await session.refresh(instance) - assert instance.status == InstanceStatus.TERMINATED - assert instance.termination_reason == InstanceTerminationReason.NO_OFFERS - - @pytest.mark.parametrize( - ("placement", "expected_termination_reasons"), - [ - pytest.param( - InstanceGroupPlacement.CLUSTER, - { - InstanceTerminationReason.NO_OFFERS: 1, - InstanceTerminationReason.MASTER_FAILED: 3, - }, - id="cluster", - ), - pytest.param( - None, - {InstanceTerminationReason.NO_OFFERS: 4}, - id="non-cluster", - ), - ], - ) - async def test_terminates_cluster_instances_if_master_not_created( - self, - test_db, - session: AsyncSession, - fetcher: InstanceFetcher, - worker: InstanceWorker, - placement: Optional[InstanceGroupPlacement], - expected_termination_reasons: dict[str, int], - ): - project = await 
create_project(session=session) - fleet = await create_fleet( - session, - project, - spec=get_fleet_spec( - conf=get_fleet_configuration( - placement=placement, nodes=FleetNodesSpec(min=4, target=4, max=4) - ) - ), - ) - instances = [ - await create_instance( - session=session, - project=project, - fleet=fleet, - status=InstanceStatus.PENDING, - offer=None, - job_provisioning_data=None, - instance_num=index, - created_at=get_current_datetime() + dt.timedelta(seconds=index), - ) - for index in range(4) - ] - with patch("dstack._internal.server.services.backends.get_project_backends") as m: - m.return_value = [] - for instance in sorted(instances, key=lambda i: (i.instance_num, i.created_at)): - await _process_instance(session, worker, instance) - - termination_reasons = defaultdict(int) - for instance in instances: - await session.refresh(instance) - assert instance.status == InstanceStatus.TERMINATED - termination_reasons[instance.termination_reason] += 1 - assert termination_reasons == expected_termination_reasons - - @pytest.mark.parametrize( - ("placement", "should_create"), - [ - pytest.param(InstanceGroupPlacement.CLUSTER, True, id="placement-cluster"), - pytest.param(None, False, id="no-placement"), - ], - ) - async def test_create_placement_group_if_placement_cluster( - self, - test_db, - session: AsyncSession, - fetcher: InstanceFetcher, - worker: InstanceWorker, - placement: Optional[InstanceGroupPlacement], - should_create: bool, - ) -> None: - project = await create_project(session=session) - fleet = await create_fleet( - session, - project, - spec=get_fleet_spec( - conf=get_fleet_configuration( - placement=placement, nodes=FleetNodesSpec(min=1, target=1, max=1) - ) - ), - ) - instance = await create_instance( - session=session, - project=project, - fleet=fleet, - status=InstanceStatus.PENDING, - offer=None, - job_provisioning_data=None, - ) - backend_mock = Mock() - backend_mock.TYPE = BackendType.AWS - backend_mock.compute.return_value = 
Mock(spec=ComputeMockSpec) - backend_mock.compute.return_value.get_offers.return_value = [ - get_instance_offer_with_availability() - ] - backend_mock.compute.return_value.create_instance.return_value = ( - get_job_provisioning_data() - ) - backend_mock.compute.return_value.create_placement_group.return_value = ( - get_placement_group_provisioning_data() - ) - with patch("dstack._internal.server.services.backends.get_project_backends") as m: - m.return_value = [backend_mock] - await _process_instance(session, worker, instance) - - await session.refresh(instance) - assert instance.status == InstanceStatus.PROVISIONING - placement_groups = (await session.execute(select(PlacementGroupModel))).scalars().all() - if should_create: - assert backend_mock.compute.return_value.create_placement_group.call_count == 1 - assert len(placement_groups) == 1 - else: - assert backend_mock.compute.return_value.create_placement_group.call_count == 0 - assert len(placement_groups) == 0 - - @pytest.mark.parametrize("can_reuse", [True, False]) - async def test_reuses_placement_group_between_offers_if_the_group_is_suitable( - self, - test_db, - session: AsyncSession, - fetcher: InstanceFetcher, - worker: InstanceWorker, - can_reuse: bool, - ) -> None: - project = await create_project(session=session) - fleet = await create_fleet( - session, - project, - spec=get_fleet_spec( - conf=get_fleet_configuration( - placement=InstanceGroupPlacement.CLUSTER, - nodes=FleetNodesSpec(min=1, target=1, max=1), - ) - ), - ) - instance = await create_instance( - session=session, - project=project, - fleet=fleet, - status=InstanceStatus.PENDING, - offer=None, - job_provisioning_data=None, - ) - backend_mock = Mock() - backend_mock.TYPE = BackendType.AWS - backend_mock.compute.return_value = Mock(spec=ComputeMockSpec) - backend_mock.compute.return_value.get_offers.return_value = [ - get_instance_offer_with_availability(instance_type="bad-offer-1"), - 
get_instance_offer_with_availability(instance_type="bad-offer-2"), - get_instance_offer_with_availability(instance_type="good-offer"), - ] - - def create_instance_method( - instance_offer: InstanceOfferWithAvailability, *args, **kwargs - ) -> JobProvisioningData: - if instance_offer.instance.name == "good-offer": - return get_job_provisioning_data() - raise NoCapacityError() - - backend_mock.compute.return_value.create_instance = create_instance_method - backend_mock.compute.return_value.create_placement_group.return_value = ( - get_placement_group_provisioning_data() - ) - backend_mock.compute.return_value.is_suitable_placement_group.return_value = can_reuse - with patch("dstack._internal.server.services.backends.get_project_backends") as m: - m.return_value = [backend_mock] - await _process_instance(session, worker, instance) - - await session.refresh(instance) - assert instance.status == InstanceStatus.PROVISIONING - placement_groups = (await session.execute(select(PlacementGroupModel))).scalars().all() - if can_reuse: - assert backend_mock.compute.return_value.create_placement_group.call_count == 1 - assert len(placement_groups) == 1 - else: - assert backend_mock.compute.return_value.create_placement_group.call_count == 3 - assert len(placement_groups) == 3 - to_be_deleted_count = sum(pg.fleet_deleted for pg in placement_groups) - assert to_be_deleted_count == 2 - - @pytest.mark.parametrize("err", [NoCapacityError(), RuntimeError()]) - async def test_handles_create_placement_group_errors( - self, - test_db, - session: AsyncSession, - fetcher: InstanceFetcher, - worker: InstanceWorker, - err: Exception, - ) -> None: - project = await create_project(session=session) - fleet = await create_fleet( - session, - project, - spec=get_fleet_spec( - conf=get_fleet_configuration( - placement=InstanceGroupPlacement.CLUSTER, - nodes=FleetNodesSpec(min=1, target=1, max=1), - ) - ), - ) - instance = await create_instance( - session=session, - project=project, - fleet=fleet, - 
status=InstanceStatus.PENDING, - offer=None, - job_provisioning_data=None, - ) - backend_mock = Mock() - backend_mock.TYPE = BackendType.AWS - backend_mock.compute.return_value = Mock(spec=ComputeMockSpec) - backend_mock.compute.return_value.get_offers.return_value = [ - get_instance_offer_with_availability(instance_type="bad-offer"), - get_instance_offer_with_availability(instance_type="good-offer"), - ] - backend_mock.compute.return_value.create_instance.return_value = ( - get_job_provisioning_data() - ) - - def create_placement_group_method( - placement_group: PlacementGroup, master_instance_offer: InstanceOffer - ) -> PlacementGroupProvisioningData: - if master_instance_offer.instance.name == "good-offer": - return get_placement_group_provisioning_data() - raise err - - backend_mock.compute.return_value.create_placement_group = create_placement_group_method - with patch("dstack._internal.server.services.backends.get_project_backends") as m: - m.return_value = [backend_mock] - await _process_instance(session, worker, instance) - - await session.refresh(instance) - assert instance.status == InstanceStatus.PROVISIONING - assert instance.offer - assert "good-offer" in instance.offer - assert "bad-offer" not in instance.offer - placement_groups = (await session.execute(select(PlacementGroupModel))).scalars().all() - assert len(placement_groups) == 1 - - @pytest.mark.parametrize( - ["cpus", "gpus", "requested_blocks", "expected_blocks"], - [ - pytest.param(32, 8, 1, 1, id="gpu-instance-no-blocks"), - pytest.param(32, 8, 2, 2, id="gpu-instance-four-gpu-per-block"), - pytest.param(32, 8, 4, 4, id="gpu-instance-two-gpus-per-block"), - pytest.param(32, 8, None, 8, id="gpu-instance-auto-max-gpu"), - pytest.param(4, 8, None, 4, id="gpu-instance-auto-max-cpu"), - pytest.param(8, 8, None, 8, id="gpu-instance-auto-max-cpu-and-gpu"), - pytest.param(32, 0, 1, 1, id="cpu-instance-no-blocks"), - pytest.param(32, 0, 2, 2, id="cpu-instance-four-cpu-per-block"), - pytest.param(32, 
0, 4, 4, id="cpu-instance-two-cpus-per-block"), - pytest.param(32, 0, None, 32, id="cpu-instance-auto-max-cpu"), - ], - ) - async def test_adds_ssh_instance( - self, - test_db, - session: AsyncSession, - fetcher: InstanceFetcher, - worker: InstanceWorker, - host_info: dict, - deploy_instance_mock: Mock, - cpus: int, - gpus: int, - requested_blocks: Optional[int], - expected_blocks: int, - ): - host_info["cpus"] = cpus - host_info["gpu_count"] = gpus - project = await create_project(session=session) - instance = await create_instance( - session=session, - project=project, - status=InstanceStatus.PENDING, - created_at=get_current_datetime(), - remote_connection_info=get_remote_connection_info(), - total_blocks=requested_blocks, - busy_blocks=0, - ) - await session.commit() - - await _process_instance(session, worker, instance) - - await session.refresh(instance) - assert instance.status == InstanceStatus.IDLE - assert instance.total_blocks == expected_blocks - assert instance.busy_blocks == 0 - deploy_instance_mock.assert_called_once() - - async def test_retries_ssh_instance_if_provisioning_fails( - self, - test_db, - session: AsyncSession, - fetcher: InstanceFetcher, - worker: InstanceWorker, - deploy_instance_mock: Mock, - ): - deploy_instance_mock.side_effect = SSHProvisioningError("Expected") - project = await create_project(session=session) - instance = await create_instance( - session=session, - project=project, - status=InstanceStatus.PENDING, - created_at=get_current_datetime(), - remote_connection_info=get_remote_connection_info(), - ) - await session.commit() - - await _process_instance(session, worker, instance) - - await session.refresh(instance) - assert instance.status == InstanceStatus.PENDING - assert instance.termination_reason is None - - async def test_terminates_ssh_instance_if_deploy_fails_unexpectedly( - self, - test_db, - session: AsyncSession, - fetcher: InstanceFetcher, - worker: InstanceWorker, - deploy_instance_mock: Mock, - ): - 
deploy_instance_mock.side_effect = RuntimeError("Unexpected") - project = await create_project(session=session) - instance = await create_instance( - session=session, - project=project, - status=InstanceStatus.PENDING, - created_at=get_current_datetime(), - remote_connection_info=get_remote_connection_info(), - ) - await session.commit() - - await _process_instance(session, worker, instance) - - await session.refresh(instance) - assert instance.status == InstanceStatus.TERMINATED - assert instance.termination_reason == InstanceTerminationReason.ERROR - assert instance.termination_reason_message == "Unexpected error when adding SSH instance" - - async def test_terminates_ssh_instance_if_key_is_invalid( - self, - test_db, - session: AsyncSession, - fetcher: InstanceFetcher, - worker: InstanceWorker, - monkeypatch: pytest.MonkeyPatch, - ): - monkeypatch.setattr( - instances_pipeline, - "_ssh_keys_to_pkeys", - Mock(side_effect=ValueError("Bad key")), - ) - project = await create_project(session=session) - instance = await create_instance( - session=session, - project=project, - status=InstanceStatus.PENDING, - created_at=get_current_datetime(), - remote_connection_info=get_remote_connection_info(), - ) - await session.commit() - - await _process_instance(session, worker, instance) - - await session.refresh(instance) - assert instance.status == InstanceStatus.TERMINATED - assert instance.termination_reason == InstanceTerminationReason.ERROR - assert instance.termination_reason_message == "Unsupported private SSH key type" - - async def test_terminates_ssh_instance_if_internal_ip_cannot_be_resolved_from_network( - self, - test_db, - session: AsyncSession, - fetcher: InstanceFetcher, - worker: InstanceWorker, - host_info: dict, - deploy_instance_mock: Mock, - ): - host_info["addresses"] = ["192.168.100.100/24"] - project = await create_project(session=session) - job_provisioning_data = get_job_provisioning_data( - dockerized=True, - backend=BackendType.REMOTE, - 
internal_ip=None, - ) - job_provisioning_data.instance_network = "10.0.0.0/24" - instance = await create_instance( - session=session, - project=project, - status=InstanceStatus.PENDING, - created_at=get_current_datetime(), - remote_connection_info=get_remote_connection_info(), - job_provisioning_data=job_provisioning_data, - ) - await session.commit() - - await _process_instance(session, worker, instance) - - await session.refresh(instance) - assert instance.status == InstanceStatus.TERMINATED - assert instance.termination_reason == InstanceTerminationReason.ERROR - assert ( - instance.termination_reason_message - == "Failed to locate internal IP address on the given network" - ) - - async def test_terminates_ssh_instance_if_internal_ip_is_not_in_host_interfaces( - self, - test_db, - session: AsyncSession, - fetcher: InstanceFetcher, - worker: InstanceWorker, - host_info: dict, - deploy_instance_mock: Mock, - ): - host_info["addresses"] = ["192.168.100.100/24"] - project = await create_project(session=session) - job_provisioning_data = get_job_provisioning_data( - dockerized=True, - backend=BackendType.REMOTE, - internal_ip="10.0.0.20", - ) - instance = await create_instance( - session=session, - project=project, - status=InstanceStatus.PENDING, - created_at=get_current_datetime(), - remote_connection_info=get_remote_connection_info(), - job_provisioning_data=job_provisioning_data, - ) - await session.commit() - - await _process_instance(session, worker, instance) - - await session.refresh(instance) - assert instance.status == InstanceStatus.TERMINATED - assert instance.termination_reason == InstanceTerminationReason.ERROR - assert ( - instance.termination_reason_message - == "Specified internal IP not found among instance interfaces" - ) - - -@pytest.mark.asyncio -@pytest.mark.usefixtures("turn_off_keep_shim_tasks_setting") -@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) -class TestRemoveDanglingTasks: - @pytest.fixture - def 
turn_off_keep_shim_tasks_setting(self, monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setattr("dstack._internal.server.settings.SERVER_KEEP_SHIM_TASKS", False) - - async def test_terminates_and_removes_dangling_tasks(self, test_db, session: AsyncSession): - user = await create_user(session=session) - project = await create_project(session=session) - instance = await create_instance( - session=session, - project=project, - status=InstanceStatus.BUSY, - ) - repo = await create_repo(session=session, project_id=project.id) - run = await create_run( - session=session, - project=project, - repo=repo, - user=user, - ) - job = await create_job( - session=session, - run=run, - status=JobStatus.RUNNING, - instance=instance, - ) - dangling_task_id_1 = "fe138b77-d0b1-49d3-8c9f-2dfe78ece727" - dangling_task_id_2 = "8b016a75-41de-44f1-91ff-c9b63d2caa1d" - shim_client_mock = Mock(spec_set=ShimClient) - shim_client_mock.is_api_v2_supported.return_value = True - shim_client_mock.list_tasks.return_value = TaskListResponse( - tasks=[ - TaskListItem(id=str(job.id), status=TaskStatus.RUNNING), - TaskListItem(id=dangling_task_id_1, status=TaskStatus.RUNNING), - TaskListItem(id=dangling_task_id_2, status=TaskStatus.TERMINATED), - ] - ) - await session.refresh(instance, attribute_names=["jobs"]) - - instances_pipeline.remove_dangling_tasks_from_instance(shim_client_mock, instance) - - await session.refresh(instance) - assert instance.status == InstanceStatus.BUSY - - shim_client_mock.terminate_task.assert_called_once_with( - task_id=dangling_task_id_1, - reason=None, - message=None, - timeout=0, - ) - assert shim_client_mock.remove_task.call_count == 2 - shim_client_mock.remove_task.assert_has_calls( - [call(task_id=dangling_task_id_1), call(task_id=dangling_task_id_2)] - ) - - async def test_terminates_and_removes_dangling_tasks_legacy_shim( - self, test_db, session: AsyncSession - ): - user = await create_user(session=session) - project = await create_project(session=session) - 
instance = await create_instance( - session=session, - project=project, - status=InstanceStatus.BUSY, - ) - repo = await create_repo(session=session, project_id=project.id) - run = await create_run( - session=session, - project=project, - repo=repo, - user=user, - ) - job = await create_job( - session=session, - run=run, - status=JobStatus.RUNNING, - instance=instance, - ) - dangling_task_id_1 = "fe138b77-d0b1-49d3-8c9f-2dfe78ece727" - dangling_task_id_2 = "8b016a75-41de-44f1-91ff-c9b63d2caa1d" - shim_client_mock = Mock(spec_set=ShimClient) - shim_client_mock.is_api_v2_supported.return_value = True - shim_client_mock.list_tasks.return_value = TaskListResponse( - ids=[str(job.id), dangling_task_id_1, dangling_task_id_2] - ) - await session.refresh(instance, attribute_names=["jobs"]) - - instances_pipeline.remove_dangling_tasks_from_instance(shim_client_mock, instance) - - await session.refresh(instance) - assert instance.status == InstanceStatus.BUSY - - assert shim_client_mock.terminate_task.call_count == 2 - shim_client_mock.terminate_task.assert_has_calls( - [ - call(task_id=dangling_task_id_1, reason=None, message=None, timeout=0), - call(task_id=dangling_task_id_2, reason=None, message=None, timeout=0), - ] - ) - assert shim_client_mock.remove_task.call_count == 2 - shim_client_mock.remove_task.assert_has_calls( - [call(task_id=dangling_task_id_1), call(task_id=dangling_task_id_2)] - ) - - -@pytest.mark.asyncio -@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) -class BaseTestMaybeInstallComponents: - EXPECTED_VERSION = "0.20.1" - - @pytest_asyncio.fixture - async def instance(self, session: AsyncSession) -> InstanceModel: - project = await create_project(session=session) - return await create_instance( - session=session, - project=project, - status=InstanceStatus.BUSY, - ) - - @pytest.fixture - def component_list(self) -> ComponentList: - return ComponentList() - - @pytest.fixture - def debug_task_log(self, caplog: 
pytest.LogCaptureFixture) -> pytest.LogCaptureFixture: - caplog.set_level(level=logging.DEBUG, logger=instances_pipeline.__name__) - return caplog - - @pytest.fixture - def shim_client_mock( - self, - monkeypatch: pytest.MonkeyPatch, - component_list: ComponentList, - ) -> Mock: - mock = Mock(spec_set=ShimClient) - mock.healthcheck.return_value = HealthcheckResponse( - service="dstack-shim", - version=self.EXPECTED_VERSION, - ) - mock.get_instance_health.return_value = InstanceHealthResponse() - mock.get_components.return_value = component_list - mock.list_tasks.return_value = TaskListResponse(tasks=[]) - mock.is_safe_to_restart.return_value = False - monkeypatch.setattr( - "dstack._internal.server.services.runner.client.ShimClient", - Mock(return_value=mock), - ) - return mock - - -@pytest.mark.usefixtures("get_dstack_runner_version_mock") -class TestMaybeInstallRunner(BaseTestMaybeInstallComponents): - @pytest.fixture - def component_list(self) -> ComponentList: - components = ComponentList() - components.add( - ComponentInfo( - name=ComponentName.RUNNER, - version=self.EXPECTED_VERSION, - status=ComponentStatus.INSTALLED, - ), - ) - return components - - @pytest.fixture - def get_dstack_runner_version_mock(self, monkeypatch: pytest.MonkeyPatch) -> Mock: - mock = Mock(return_value=self.EXPECTED_VERSION) - monkeypatch.setattr(instances_pipeline, "get_dstack_runner_version", mock) - return mock - - @pytest.fixture - def get_dstack_runner_download_url_mock(self, monkeypatch: pytest.MonkeyPatch) -> Mock: - mock = Mock(return_value="https://example.com/runner") - monkeypatch.setattr(instances_pipeline, "get_dstack_runner_download_url", mock) - return mock - - async def test_cannot_determine_expected_version( - self, - test_db, - instance: InstanceModel, - debug_task_log: pytest.LogCaptureFixture, - shim_client_mock: Mock, - get_dstack_runner_version_mock: Mock, - ): - get_dstack_runner_version_mock.return_value = None - - 
instances_pipeline._maybe_install_components(instance, shim_client_mock) - - assert "Cannot determine the expected runner version" in debug_task_log.text - shim_client_mock.get_components.assert_called_once() - shim_client_mock.install_runner.assert_not_called() - - async def test_expected_version_already_installed( - self, - test_db, - instance: InstanceModel, - debug_task_log: pytest.LogCaptureFixture, - shim_client_mock: Mock, - ): - shim_client_mock.get_components.return_value.runner.version = self.EXPECTED_VERSION - - instances_pipeline._maybe_install_components(instance, shim_client_mock) - - assert "expected runner version already installed" in debug_task_log.text - shim_client_mock.get_components.assert_called_once() - shim_client_mock.install_runner.assert_not_called() - - @pytest.mark.parametrize("status", [ComponentStatus.NOT_INSTALLED, ComponentStatus.ERROR]) - async def test_install_not_installed_or_error( - self, - test_db, - instance: InstanceModel, - debug_task_log: pytest.LogCaptureFixture, - shim_client_mock: Mock, - get_dstack_runner_download_url_mock: Mock, - status: ComponentStatus, - ): - shim_client_mock.get_components.return_value.runner.version = "" - shim_client_mock.get_components.return_value.runner.status = status - - instances_pipeline._maybe_install_components(instance, shim_client_mock) - - assert f"installing runner (no version) -> {self.EXPECTED_VERSION}" in debug_task_log.text - get_dstack_runner_download_url_mock.assert_called_once_with( - arch=None, - version=self.EXPECTED_VERSION, - ) - shim_client_mock.get_components.assert_called_once() - shim_client_mock.install_runner.assert_called_once_with( - get_dstack_runner_download_url_mock.return_value - ) - - @pytest.mark.parametrize("installed_version", ["0.19.40", "0.21.0", "dev"]) - async def test_install_installed( - self, - test_db, - instance: InstanceModel, - debug_task_log: pytest.LogCaptureFixture, - shim_client_mock: Mock, - get_dstack_runner_download_url_mock: Mock, - 
installed_version: str, - ): - shim_client_mock.get_components.return_value.runner.version = installed_version - - instances_pipeline._maybe_install_components(instance, shim_client_mock) - - assert ( - f"installing runner {installed_version} -> {self.EXPECTED_VERSION}" - in debug_task_log.text - ) - get_dstack_runner_download_url_mock.assert_called_once_with( - arch=None, - version=self.EXPECTED_VERSION, - ) - shim_client_mock.get_components.assert_called_once() - shim_client_mock.install_runner.assert_called_once_with( - get_dstack_runner_download_url_mock.return_value - ) - - async def test_already_installing( - self, - test_db, - instance: InstanceModel, - debug_task_log: pytest.LogCaptureFixture, - shim_client_mock: Mock, - ): - shim_client_mock.get_components.return_value.runner.version = "dev" - shim_client_mock.get_components.return_value.runner.status = ComponentStatus.INSTALLING - - instances_pipeline._maybe_install_components(instance, shim_client_mock) - - assert "runner is already being installed" in debug_task_log.text - shim_client_mock.get_components.assert_called_once() - shim_client_mock.install_runner.assert_not_called() - - -@pytest.mark.usefixtures("get_dstack_shim_version_mock") -class TestMaybeInstallShim(BaseTestMaybeInstallComponents): - @pytest.fixture - def component_list(self) -> ComponentList: - components = ComponentList() - components.add( - ComponentInfo( - name=ComponentName.SHIM, - version=self.EXPECTED_VERSION, - status=ComponentStatus.INSTALLED, - ), - ) - return components - - @pytest.fixture - def get_dstack_shim_version_mock(self, monkeypatch: pytest.MonkeyPatch) -> Mock: - mock = Mock(return_value=self.EXPECTED_VERSION) - monkeypatch.setattr(instances_pipeline, "get_dstack_shim_version", mock) - return mock - - @pytest.fixture - def get_dstack_shim_download_url_mock(self, monkeypatch: pytest.MonkeyPatch) -> Mock: - mock = Mock(return_value="https://example.com/shim") - monkeypatch.setattr(instances_pipeline, 
"get_dstack_shim_download_url", mock) - return mock - - async def test_cannot_determine_expected_version( - self, - test_db, - instance: InstanceModel, - debug_task_log: pytest.LogCaptureFixture, - shim_client_mock: Mock, - get_dstack_shim_version_mock: Mock, - ): - get_dstack_shim_version_mock.return_value = None - - instances_pipeline._maybe_install_components(instance, shim_client_mock) - - assert "Cannot determine the expected shim version" in debug_task_log.text - shim_client_mock.get_components.assert_called_once() - shim_client_mock.install_shim.assert_not_called() - - async def test_expected_version_already_installed( - self, - test_db, - instance: InstanceModel, - debug_task_log: pytest.LogCaptureFixture, - shim_client_mock: Mock, - ): - shim_client_mock.get_components.return_value.shim.version = self.EXPECTED_VERSION - - instances_pipeline._maybe_install_components(instance, shim_client_mock) - - assert "expected shim version already installed" in debug_task_log.text - shim_client_mock.get_components.assert_called_once() - shim_client_mock.install_shim.assert_not_called() - - @pytest.mark.parametrize("status", [ComponentStatus.NOT_INSTALLED, ComponentStatus.ERROR]) - async def test_install_not_installed_or_error( - self, - test_db, - instance: InstanceModel, - debug_task_log: pytest.LogCaptureFixture, - shim_client_mock: Mock, - get_dstack_shim_download_url_mock: Mock, - status: ComponentStatus, - ): - shim_client_mock.get_components.return_value.shim.version = "" - shim_client_mock.get_components.return_value.shim.status = status - - instances_pipeline._maybe_install_components(instance, shim_client_mock) - - assert f"installing shim (no version) -> {self.EXPECTED_VERSION}" in debug_task_log.text - get_dstack_shim_download_url_mock.assert_called_once_with( - arch=None, - version=self.EXPECTED_VERSION, - ) - shim_client_mock.get_components.assert_called_once() - shim_client_mock.install_shim.assert_called_once_with( - 
get_dstack_shim_download_url_mock.return_value - ) - - @pytest.mark.parametrize("installed_version", ["0.19.40", "0.21.0", "dev"]) - async def test_install_installed( - self, - test_db, - instance: InstanceModel, - debug_task_log: pytest.LogCaptureFixture, - shim_client_mock: Mock, - get_dstack_shim_download_url_mock: Mock, - installed_version: str, - ): - shim_client_mock.get_components.return_value.shim.version = installed_version - - instances_pipeline._maybe_install_components(instance, shim_client_mock) - - assert ( - f"installing shim {installed_version} -> {self.EXPECTED_VERSION}" - in debug_task_log.text - ) - get_dstack_shim_download_url_mock.assert_called_once_with( - arch=None, - version=self.EXPECTED_VERSION, - ) - shim_client_mock.get_components.assert_called_once() - shim_client_mock.install_shim.assert_called_once_with( - get_dstack_shim_download_url_mock.return_value - ) - - async def test_already_installing( - self, - test_db, - instance: InstanceModel, - debug_task_log: pytest.LogCaptureFixture, - shim_client_mock: Mock, - ): - shim_client_mock.get_components.return_value.shim.version = "dev" - shim_client_mock.get_components.return_value.shim.status = ComponentStatus.INSTALLING - - instances_pipeline._maybe_install_components(instance, shim_client_mock) - - assert "shim is already being installed" in debug_task_log.text - shim_client_mock.get_components.assert_called_once() - shim_client_mock.install_shim.assert_not_called() - - -@pytest.mark.usefixtures("maybe_install_runner_mock", "maybe_install_shim_mock") -class TestMaybeRestartShim(BaseTestMaybeInstallComponents): - @pytest.fixture - def component_list(self) -> ComponentList: - components = ComponentList() - components.add( - ComponentInfo( - name=ComponentName.RUNNER, - version=self.EXPECTED_VERSION, - status=ComponentStatus.INSTALLED, - ), - ) - components.add( - ComponentInfo( - name=ComponentName.SHIM, - version=self.EXPECTED_VERSION, - status=ComponentStatus.INSTALLED, - ), - ) - return 
components - - @pytest.fixture - def maybe_install_runner_mock(self, monkeypatch: pytest.MonkeyPatch) -> Mock: - mock = Mock(return_value=False) - monkeypatch.setattr(instances_pipeline, "_maybe_install_runner", mock) - return mock - - @pytest.fixture - def maybe_install_shim_mock(self, monkeypatch: pytest.MonkeyPatch) -> Mock: - mock = Mock(return_value=False) - monkeypatch.setattr(instances_pipeline, "_maybe_install_shim", mock) - return mock - - async def test_up_to_date(self, test_db, instance: InstanceModel, shim_client_mock: Mock): - shim_client_mock.get_version_string.return_value = self.EXPECTED_VERSION - shim_client_mock.is_safe_to_restart.return_value = True - - instances_pipeline._maybe_install_components(instance, shim_client_mock) - - shim_client_mock.get_components.assert_called_once() - shim_client_mock.shutdown.assert_not_called() - - async def test_no_shim_component_info( - self, test_db, instance: InstanceModel, shim_client_mock: Mock - ): - shim_client_mock.get_components.return_value = ComponentList() - shim_client_mock.get_version_string.return_value = "outdated" - shim_client_mock.is_safe_to_restart.return_value = True - - instances_pipeline._maybe_install_components(instance, shim_client_mock) - - shim_client_mock.get_components.assert_called_once() - shim_client_mock.shutdown.assert_not_called() - - async def test_outdated_shutdown_requested( - self, test_db, instance: InstanceModel, shim_client_mock: Mock - ): - shim_client_mock.get_version_string.return_value = "outdated" - shim_client_mock.is_safe_to_restart.return_value = True - - instances_pipeline._maybe_install_components(instance, shim_client_mock) - - shim_client_mock.get_components.assert_called_once() - shim_client_mock.shutdown.assert_called_once_with(force=False) - - async def test_outdated_but_task_wont_survive_restart( - self, test_db, instance: InstanceModel, shim_client_mock: Mock - ): - shim_client_mock.get_version_string.return_value = "outdated" - 
shim_client_mock.is_safe_to_restart.return_value = False - - instances_pipeline._maybe_install_components(instance, shim_client_mock) - - shim_client_mock.get_components.assert_called_once() - shim_client_mock.shutdown.assert_not_called() - - async def test_outdated_but_runner_installation_in_progress( - self, - test_db, - instance: InstanceModel, - shim_client_mock: Mock, - component_list: ComponentList, - ): - shim_client_mock.get_version_string.return_value = "outdated" - shim_client_mock.is_safe_to_restart.return_value = True - runner_info = component_list.runner - assert runner_info is not None - runner_info.status = ComponentStatus.INSTALLING - - instances_pipeline._maybe_install_components(instance, shim_client_mock) - - shim_client_mock.get_components.assert_called_once() - shim_client_mock.shutdown.assert_not_called() - - async def test_outdated_but_shim_installation_in_progress( - self, - test_db, - instance: InstanceModel, - shim_client_mock: Mock, - component_list: ComponentList, - ): - shim_client_mock.get_version_string.return_value = "outdated" - shim_client_mock.is_safe_to_restart.return_value = True - shim_info = component_list.shim - assert shim_info is not None - shim_info.status = ComponentStatus.INSTALLING - - instances_pipeline._maybe_install_components(instance, shim_client_mock) - - shim_client_mock.get_components.assert_called_once() - shim_client_mock.shutdown.assert_not_called() - - async def test_outdated_but_runner_installation_requested( - self, - test_db, - instance: InstanceModel, - shim_client_mock: Mock, - maybe_install_runner_mock: Mock, - ): - shim_client_mock.get_version_string.return_value = "outdated" - shim_client_mock.is_safe_to_restart.return_value = True - maybe_install_runner_mock.return_value = True - - instances_pipeline._maybe_install_components(instance, shim_client_mock) - - shim_client_mock.get_components.assert_called_once() - shim_client_mock.shutdown.assert_not_called() - - async def 
test_outdated_but_shim_installation_requested( - self, - test_db, - instance: InstanceModel, - shim_client_mock: Mock, - maybe_install_shim_mock: Mock, - ): - shim_client_mock.get_version_string.return_value = "outdated" - shim_client_mock.is_safe_to_restart.return_value = True - maybe_install_shim_mock.return_value = True - - instances_pipeline._maybe_install_components(instance, shim_client_mock) - - shim_client_mock.get_components.assert_called_once() - shim_client_mock.shutdown.assert_not_called() diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_instances/__init__.py b/src/tests/_internal/server/background/pipeline_tasks/test_instances/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_instances/conftest.py b/src/tests/_internal/server/background/pipeline_tasks/test_instances/conftest.py new file mode 100644 index 0000000000..f7600e0ba6 --- /dev/null +++ b/src/tests/_internal/server/background/pipeline_tasks/test_instances/conftest.py @@ -0,0 +1,58 @@ +import asyncio +import datetime as dt +from unittest.mock import Mock + +import pytest + +from dstack._internal.core.backends.base.compute import GoArchType +from dstack._internal.server.background.pipeline_tasks.instances import ( + InstanceFetcher, + InstanceWorker, +) +from dstack._internal.server.background.pipeline_tasks.instances import ( + ssh_deploy as instances_ssh_deploy, +) +from dstack._internal.server.schemas.instances import InstanceCheck + + +@pytest.fixture +def fetcher() -> InstanceFetcher: + return InstanceFetcher( + queue=asyncio.Queue(), + queue_desired_minsize=1, + min_processing_interval=dt.timedelta(seconds=10), + lock_timeout=dt.timedelta(seconds=30), + heartbeater=Mock(), + ) + + +@pytest.fixture +def worker() -> InstanceWorker: + return InstanceWorker(queue=asyncio.Queue(), heartbeater=Mock()) + + +@pytest.fixture +def host_info() -> dict: + return { + "gpu_vendor": "nvidia", + 
"gpu_name": "T4", + "gpu_memory": 16384, + "gpu_count": 1, + "addresses": ["192.168.100.100/24"], + "disk_size": 260976517120, + "cpus": 32, + "memory": 33544130560, + } + + +@pytest.fixture +def deploy_instance_mock(monkeypatch: pytest.MonkeyPatch, host_info: dict) -> Mock: + mock = Mock( + return_value=( + InstanceCheck(reachable=True), + host_info, + GoArchType.AMD64, + ) + ) + monkeypatch.setattr(instances_ssh_deploy, "_deploy_instance", mock) + return mock diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_instances/helpers.py b/src/tests/_internal/server/background/pipeline_tasks/test_instances/helpers.py new file mode 100644 index 0000000000..81eb0fde5c --- /dev/null +++ b/src/tests/_internal/server/background/pipeline_tasks/test_instances/helpers.py @@ -0,0 +1,40 @@ +import datetime as dt +import uuid + +from sqlalchemy.ext.asyncio import AsyncSession + +from dstack._internal.server.background.pipeline_tasks.instances import ( + InstancePipeline, + InstancePipelineItem, + InstanceWorker, +) +from dstack._internal.server.models import InstanceModel + +LOCK_EXPIRES_AT = dt.datetime(2025, 1, 2, 3, 4, tzinfo=dt.timezone.utc) + + +def instance_to_pipeline_item(instance_model: InstanceModel) -> InstancePipelineItem: + assert instance_model.lock_token is not None + assert instance_model.lock_expires_at is not None + return InstancePipelineItem( + __tablename__=instance_model.__tablename__, + id=instance_model.id, + lock_token=instance_model.lock_token, + lock_expires_at=instance_model.lock_expires_at, + prev_lock_expired=False, + status=instance_model.status, + ) + + +def lock_instance(instance_model: InstanceModel) -> None: + instance_model.lock_token = uuid.uuid4() + instance_model.lock_expires_at = LOCK_EXPIRES_AT + instance_model.lock_owner = InstancePipeline.__name__ + + +async def process_instance( + session: AsyncSession, worker: InstanceWorker, instance_model: InstanceModel +) -> None: + lock_instance(instance_model) + await 
session.commit() + await worker.process(instance_to_pipeline_item(instance_model)) diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_instances/test_check.py b/src/tests/_internal/server/background/pipeline_tasks/test_instances/test_check.py new file mode 100644 index 0000000000..5cfc8f887a --- /dev/null +++ b/src/tests/_internal/server/background/pipeline_tasks/test_instances/test_check.py @@ -0,0 +1,863 @@ +import datetime as dt +import logging +from unittest.mock import Mock + +import pytest +import pytest_asyncio +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from dstack._internal.core.models.health import HealthStatus +from dstack._internal.core.models.instances import InstanceStatus, InstanceTerminationReason +from dstack._internal.core.models.profiles import TerminationPolicy +from dstack._internal.core.models.runs import JobStatus +from dstack._internal.server.background.pipeline_tasks.instances import InstanceWorker +from dstack._internal.server.background.pipeline_tasks.instances import check as instances_check +from dstack._internal.server.models import InstanceHealthCheckModel, InstanceModel +from dstack._internal.server.schemas.health.dcgm import DCGMHealthResponse, DCGMHealthResult +from dstack._internal.server.schemas.instances import InstanceCheck +from dstack._internal.server.schemas.runner import ( + ComponentInfo, + ComponentName, + ComponentStatus, + HealthcheckResponse, + InstanceHealthResponse, + TaskListResponse, +) +from dstack._internal.server.services.runner.client import ComponentList, ShimClient +from dstack._internal.server.testing.common import ( + create_instance, + create_job, + create_project, + create_repo, + create_run, + create_user, + get_remote_connection_info, + list_events, +) +from dstack._internal.utils.common import get_current_datetime +from tests._internal.server.background.pipeline_tasks.test_instances.helpers import ( + process_instance, +) + + 
+@pytest.mark.asyncio +@pytest.mark.usefixtures("image_config_mock") +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) +class TestInstanceCheck: + async def test_check_shim_transitions_provisioning_on_ready( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + monkeypatch: pytest.MonkeyPatch, + ): + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.PROVISIONING, + ) + instance.termination_deadline = get_current_datetime() + dt.timedelta(days=1) + await session.commit() + + monkeypatch.setattr( + instances_check, + "_check_instance_inner", + Mock(return_value=InstanceCheck(reachable=True)), + ) + await process_instance(session, worker, instance) + + await session.refresh(instance) + + assert instance.status == InstanceStatus.IDLE + assert instance.termination_deadline is None + + async def test_check_shim_transitions_provisioning_on_terminating( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + monkeypatch: pytest.MonkeyPatch, + ): + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.PROVISIONING, + ) + instance.started_at = get_current_datetime() + dt.timedelta(minutes=-20) + await session.commit() + + monkeypatch.setattr( + instances_check, + "_check_instance_inner", + Mock(return_value=InstanceCheck(reachable=False, message="Shim problem")), + ) + await process_instance(session, worker, instance) + + await session.refresh(instance) + + assert instance.status == InstanceStatus.TERMINATING + assert instance.termination_deadline is not None + + async def test_check_shim_transitions_provisioning_on_busy( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + monkeypatch: pytest.MonkeyPatch, + ): + user = await create_user(session=session) + project = await create_project(session=session, 
owner=user) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run(session=session, project=project, repo=repo, user=user) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.PROVISIONING, + ) + instance.termination_deadline = get_current_datetime().replace( + tzinfo=dt.timezone.utc + ) + dt.timedelta(days=1) + job = await create_job( + session=session, + run=run, + status=JobStatus.SUBMITTED, + instance=instance, + ) + await session.commit() + + monkeypatch.setattr( + instances_check, + "_check_instance_inner", + Mock(return_value=InstanceCheck(reachable=True)), + ) + await process_instance(session, worker, instance) + + await session.refresh(instance) + await session.refresh(job) + + assert instance.status == InstanceStatus.BUSY + assert instance.termination_deadline is None + assert job.instance == instance + + async def test_check_shim_start_termination_deadline( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + monkeypatch: pytest.MonkeyPatch, + ): + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.IDLE, + unreachable=False, + ) + + monkeypatch.setattr( + instances_check, + "_check_instance_inner", + Mock(return_value=InstanceCheck(reachable=False, message="SSH connection fail")), + ) + await process_instance(session, worker, instance) + + await session.refresh(instance) + + assert instance.status == InstanceStatus.IDLE + assert instance.unreachable is True + assert instance.termination_deadline is not None + assert instance.termination_deadline.replace( + tzinfo=dt.timezone.utc + ) > get_current_datetime() + dt.timedelta(minutes=19) + + async def test_check_shim_does_not_start_termination_deadline_with_ssh_instance( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + monkeypatch: pytest.MonkeyPatch, + ): + project = await 
create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.IDLE, + unreachable=False, + remote_connection_info=get_remote_connection_info(), + ) + + monkeypatch.setattr( + instances_check, + "_check_instance_inner", + Mock(return_value=InstanceCheck(reachable=False, message="SSH connection fail")), + ) + await process_instance(session, worker, instance) + + await session.refresh(instance) + + assert instance.status == InstanceStatus.IDLE + assert instance.unreachable is True + assert instance.termination_deadline is None + + async def test_check_shim_stop_termination_deadline( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + monkeypatch: pytest.MonkeyPatch, + ): + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.IDLE, + ) + instance.termination_deadline = get_current_datetime() + dt.timedelta(minutes=19) + await session.commit() + + monkeypatch.setattr( + instances_check, + "_check_instance_inner", + Mock(return_value=InstanceCheck(reachable=True)), + ) + await process_instance(session, worker, instance) + + await session.refresh(instance) + + assert instance.status == InstanceStatus.IDLE + assert instance.termination_deadline is None + + async def test_check_shim_terminate_instance_by_deadline( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + monkeypatch: pytest.MonkeyPatch, + ): + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.IDLE, + ) + termination_deadline_time = get_current_datetime() + dt.timedelta(minutes=-19) + instance.termination_deadline = termination_deadline_time + await session.commit() + + monkeypatch.setattr( + instances_check, + "_check_instance_inner", + Mock(return_value=InstanceCheck(reachable=False, message="Not ok")), + ) + await 
process_instance(session, worker, instance) + + await session.refresh(instance) + + assert instance.status == InstanceStatus.TERMINATING + assert instance.termination_deadline == termination_deadline_time + assert instance.termination_reason == InstanceTerminationReason.UNREACHABLE + + @pytest.mark.parametrize( + ["termination_policy", "has_job"], + [ + pytest.param(TerminationPolicy.DESTROY_AFTER_IDLE, False, id="destroy-no-job"), + pytest.param(TerminationPolicy.DESTROY_AFTER_IDLE, True, id="destroy-with-job"), + pytest.param(TerminationPolicy.DONT_DESTROY, False, id="dont-destroy-no-job"), + pytest.param(TerminationPolicy.DONT_DESTROY, True, id="dont-destroy-with-job"), + ], + ) + async def test_check_shim_process_unreachable_state( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + monkeypatch: pytest.MonkeyPatch, + termination_policy: TerminationPolicy, + has_job: bool, + ): + project = await create_project(session=session) + if has_job: + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run(session=session, project=project, repo=repo, user=user) + job = await create_job( + session=session, + run=run, + status=JobStatus.SUBMITTED, + ) + else: + job = None + instance = await create_instance( + session=session, + project=project, + created_at=get_current_datetime(), + termination_policy=termination_policy, + status=InstanceStatus.IDLE, + unreachable=True, + job=job, + ) + + monkeypatch.setattr( + instances_check, + "_check_instance_inner", + Mock(return_value=InstanceCheck(reachable=True)), + ) + await process_instance(session, worker, instance) + + await session.refresh(instance) + events = await list_events(session) + + assert instance.status == InstanceStatus.IDLE + assert instance.unreachable is False + assert len(events) == 1 + assert events[0].message == "Instance became reachable" + + @pytest.mark.parametrize("health_status", [HealthStatus.HEALTHY, 
HealthStatus.FAILURE]) + async def test_check_shim_switch_to_unreachable_state( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + monkeypatch: pytest.MonkeyPatch, + health_status: HealthStatus, + ): + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.IDLE, + unreachable=False, + health_status=health_status, + ) + + monkeypatch.setattr( + instances_check, + "_check_instance_inner", + Mock(return_value=InstanceCheck(reachable=False)), + ) + await process_instance(session, worker, instance) + + await session.refresh(instance) + events = await list_events(session) + + assert instance.status == InstanceStatus.IDLE + assert instance.unreachable is True + assert instance.health == health_status + assert len(events) == 1 + assert events[0].message == "Instance became unreachable" + + async def test_check_shim_check_instance_health( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + monkeypatch: pytest.MonkeyPatch, + ): + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.IDLE, + unreachable=False, + health_status=HealthStatus.HEALTHY, + ) + health_response = InstanceHealthResponse( + dcgm=DCGMHealthResponse( + overall_health=DCGMHealthResult.DCGM_HEALTH_RESULT_WARN, + incidents=[], + ) + ) + + monkeypatch.setattr( + instances_check, + "_check_instance_inner", + Mock( + return_value=InstanceCheck( + reachable=True, + health_response=health_response, + ) + ), + ) + await process_instance(session, worker, instance) + + await session.refresh(instance) + events = await list_events(session) + + assert instance.status == InstanceStatus.IDLE + assert instance.unreachable is False + assert instance.health == HealthStatus.WARNING + assert len(events) == 1 + assert events[0].message == "Instance health changed HEALTHY -> WARNING" + + res = await 
session.execute(select(InstanceHealthCheckModel)) + health_check = res.scalars().one() + assert health_check.status == HealthStatus.WARNING + assert health_check.response == health_response.json() + + async def test_terminate_by_idle_timeout( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + ): + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.IDLE, + ) + instance.termination_idle_time = 300 + instance.termination_policy = TerminationPolicy.DESTROY_AFTER_IDLE + instance.last_job_processed_at = get_current_datetime() + dt.timedelta(minutes=-19) + await session.commit() + + await process_instance(session, worker, instance) + await session.refresh(instance) + + assert instance.status == InstanceStatus.TERMINATING + assert instance.termination_reason == InstanceTerminationReason.IDLE_TIMEOUT + + +@pytest.mark.asyncio +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) +class BaseTestMaybeInstallComponents: + EXPECTED_VERSION = "0.20.1" + + @pytest_asyncio.fixture + async def instance(self, session: AsyncSession) -> InstanceModel: + project = await create_project(session=session) + return await create_instance( + session=session, + project=project, + status=InstanceStatus.BUSY, + ) + + @pytest.fixture + def component_list(self) -> ComponentList: + return ComponentList() + + @pytest.fixture + def debug_task_log(self, caplog: pytest.LogCaptureFixture) -> pytest.LogCaptureFixture: + caplog.set_level(level=logging.DEBUG, logger=instances_check.__name__) + return caplog + + @pytest.fixture + def shim_client_mock( + self, + monkeypatch: pytest.MonkeyPatch, + component_list: ComponentList, + ) -> Mock: + mock = Mock(spec_set=ShimClient) + mock.healthcheck.return_value = HealthcheckResponse( + service="dstack-shim", + version=self.EXPECTED_VERSION, + ) + mock.get_instance_health.return_value = InstanceHealthResponse() + 
mock.get_components.return_value = component_list + mock.list_tasks.return_value = TaskListResponse(tasks=[]) + mock.is_safe_to_restart.return_value = False + monkeypatch.setattr( + "dstack._internal.server.services.runner.client.ShimClient", + Mock(return_value=mock), + ) + return mock + + +@pytest.mark.usefixtures("get_dstack_runner_version_mock") +class TestMaybeInstallRunner(BaseTestMaybeInstallComponents): + @pytest.fixture + def component_list(self) -> ComponentList: + components = ComponentList() + components.add( + ComponentInfo( + name=ComponentName.RUNNER, + version=self.EXPECTED_VERSION, + status=ComponentStatus.INSTALLED, + ), + ) + return components + + @pytest.fixture + def get_dstack_runner_version_mock(self, monkeypatch: pytest.MonkeyPatch) -> Mock: + mock = Mock(return_value=self.EXPECTED_VERSION) + monkeypatch.setattr(instances_check, "get_dstack_runner_version", mock) + return mock + + @pytest.fixture + def get_dstack_runner_download_url_mock(self, monkeypatch: pytest.MonkeyPatch) -> Mock: + mock = Mock(return_value="https://example.com/runner") + monkeypatch.setattr(instances_check, "get_dstack_runner_download_url", mock) + return mock + + async def test_cannot_determine_expected_version( + self, + test_db, + instance: InstanceModel, + debug_task_log: pytest.LogCaptureFixture, + shim_client_mock: Mock, + get_dstack_runner_version_mock: Mock, + ): + get_dstack_runner_version_mock.return_value = None + + instances_check._maybe_install_components(instance, shim_client_mock) + + assert "Cannot determine the expected runner version" in debug_task_log.text + shim_client_mock.get_components.assert_called_once() + shim_client_mock.install_runner.assert_not_called() + + async def test_expected_version_already_installed( + self, + test_db, + instance: InstanceModel, + debug_task_log: pytest.LogCaptureFixture, + shim_client_mock: Mock, + ): + shim_client_mock.get_components.return_value.runner.version = self.EXPECTED_VERSION + + 
instances_check._maybe_install_components(instance, shim_client_mock) + + assert "expected runner version already installed" in debug_task_log.text + shim_client_mock.get_components.assert_called_once() + shim_client_mock.install_runner.assert_not_called() + + @pytest.mark.parametrize("status", [ComponentStatus.NOT_INSTALLED, ComponentStatus.ERROR]) + async def test_install_not_installed_or_error( + self, + test_db, + instance: InstanceModel, + debug_task_log: pytest.LogCaptureFixture, + shim_client_mock: Mock, + get_dstack_runner_download_url_mock: Mock, + status: ComponentStatus, + ): + shim_client_mock.get_components.return_value.runner.version = "" + shim_client_mock.get_components.return_value.runner.status = status + + instances_check._maybe_install_components(instance, shim_client_mock) + + assert f"installing runner (no version) -> {self.EXPECTED_VERSION}" in debug_task_log.text + get_dstack_runner_download_url_mock.assert_called_once_with( + arch=None, + version=self.EXPECTED_VERSION, + ) + shim_client_mock.get_components.assert_called_once() + shim_client_mock.install_runner.assert_called_once_with( + get_dstack_runner_download_url_mock.return_value + ) + + @pytest.mark.parametrize("installed_version", ["0.19.40", "0.21.0", "dev"]) + async def test_install_installed( + self, + test_db, + instance: InstanceModel, + debug_task_log: pytest.LogCaptureFixture, + shim_client_mock: Mock, + get_dstack_runner_download_url_mock: Mock, + installed_version: str, + ): + shim_client_mock.get_components.return_value.runner.version = installed_version + + instances_check._maybe_install_components(instance, shim_client_mock) + + assert ( + f"installing runner {installed_version} -> {self.EXPECTED_VERSION}" + in debug_task_log.text + ) + get_dstack_runner_download_url_mock.assert_called_once_with( + arch=None, + version=self.EXPECTED_VERSION, + ) + shim_client_mock.get_components.assert_called_once() + shim_client_mock.install_runner.assert_called_once_with( + 
get_dstack_runner_download_url_mock.return_value + ) + + async def test_already_installing( + self, + test_db, + instance: InstanceModel, + debug_task_log: pytest.LogCaptureFixture, + shim_client_mock: Mock, + ): + shim_client_mock.get_components.return_value.runner.version = "dev" + shim_client_mock.get_components.return_value.runner.status = ComponentStatus.INSTALLING + + instances_check._maybe_install_components(instance, shim_client_mock) + + assert "runner is already being installed" in debug_task_log.text + shim_client_mock.get_components.assert_called_once() + shim_client_mock.install_runner.assert_not_called() + + +@pytest.mark.usefixtures("get_dstack_shim_version_mock") +class TestMaybeInstallShim(BaseTestMaybeInstallComponents): + @pytest.fixture + def component_list(self) -> ComponentList: + components = ComponentList() + components.add( + ComponentInfo( + name=ComponentName.SHIM, + version=self.EXPECTED_VERSION, + status=ComponentStatus.INSTALLED, + ), + ) + return components + + @pytest.fixture + def get_dstack_shim_version_mock(self, monkeypatch: pytest.MonkeyPatch) -> Mock: + mock = Mock(return_value=self.EXPECTED_VERSION) + monkeypatch.setattr(instances_check, "get_dstack_shim_version", mock) + return mock + + @pytest.fixture + def get_dstack_shim_download_url_mock(self, monkeypatch: pytest.MonkeyPatch) -> Mock: + mock = Mock(return_value="https://example.com/shim") + monkeypatch.setattr(instances_check, "get_dstack_shim_download_url", mock) + return mock + + async def test_cannot_determine_expected_version( + self, + test_db, + instance: InstanceModel, + debug_task_log: pytest.LogCaptureFixture, + shim_client_mock: Mock, + get_dstack_shim_version_mock: Mock, + ): + get_dstack_shim_version_mock.return_value = None + + instances_check._maybe_install_components(instance, shim_client_mock) + + assert "Cannot determine the expected shim version" in debug_task_log.text + shim_client_mock.get_components.assert_called_once() + 
shim_client_mock.install_shim.assert_not_called() + + async def test_expected_version_already_installed( + self, + test_db, + instance: InstanceModel, + debug_task_log: pytest.LogCaptureFixture, + shim_client_mock: Mock, + ): + shim_client_mock.get_components.return_value.shim.version = self.EXPECTED_VERSION + + instances_check._maybe_install_components(instance, shim_client_mock) + + assert "expected shim version already installed" in debug_task_log.text + shim_client_mock.get_components.assert_called_once() + shim_client_mock.install_shim.assert_not_called() + + @pytest.mark.parametrize("status", [ComponentStatus.NOT_INSTALLED, ComponentStatus.ERROR]) + async def test_install_not_installed_or_error( + self, + test_db, + instance: InstanceModel, + debug_task_log: pytest.LogCaptureFixture, + shim_client_mock: Mock, + get_dstack_shim_download_url_mock: Mock, + status: ComponentStatus, + ): + shim_client_mock.get_components.return_value.shim.version = "" + shim_client_mock.get_components.return_value.shim.status = status + + instances_check._maybe_install_components(instance, shim_client_mock) + + assert f"installing shim (no version) -> {self.EXPECTED_VERSION}" in debug_task_log.text + get_dstack_shim_download_url_mock.assert_called_once_with( + arch=None, + version=self.EXPECTED_VERSION, + ) + shim_client_mock.get_components.assert_called_once() + shim_client_mock.install_shim.assert_called_once_with( + get_dstack_shim_download_url_mock.return_value + ) + + @pytest.mark.parametrize("installed_version", ["0.19.40", "0.21.0", "dev"]) + async def test_install_installed( + self, + test_db, + instance: InstanceModel, + debug_task_log: pytest.LogCaptureFixture, + shim_client_mock: Mock, + get_dstack_shim_download_url_mock: Mock, + installed_version: str, + ): + shim_client_mock.get_components.return_value.shim.version = installed_version + + instances_check._maybe_install_components(instance, shim_client_mock) + + assert ( + f"installing shim {installed_version} -> 
{self.EXPECTED_VERSION}" + in debug_task_log.text + ) + get_dstack_shim_download_url_mock.assert_called_once_with( + arch=None, + version=self.EXPECTED_VERSION, + ) + shim_client_mock.get_components.assert_called_once() + shim_client_mock.install_shim.assert_called_once_with( + get_dstack_shim_download_url_mock.return_value + ) + + async def test_already_installing( + self, + test_db, + instance: InstanceModel, + debug_task_log: pytest.LogCaptureFixture, + shim_client_mock: Mock, + ): + shim_client_mock.get_components.return_value.shim.version = "dev" + shim_client_mock.get_components.return_value.shim.status = ComponentStatus.INSTALLING + + instances_check._maybe_install_components(instance, shim_client_mock) + + assert "shim is already being installed" in debug_task_log.text + shim_client_mock.get_components.assert_called_once() + shim_client_mock.install_shim.assert_not_called() + + +@pytest.mark.usefixtures("maybe_install_runner_mock", "maybe_install_shim_mock") +class TestMaybeRestartShim(BaseTestMaybeInstallComponents): + @pytest.fixture + def component_list(self) -> ComponentList: + components = ComponentList() + components.add( + ComponentInfo( + name=ComponentName.RUNNER, + version=self.EXPECTED_VERSION, + status=ComponentStatus.INSTALLED, + ), + ) + components.add( + ComponentInfo( + name=ComponentName.SHIM, + version=self.EXPECTED_VERSION, + status=ComponentStatus.INSTALLED, + ), + ) + return components + + @pytest.fixture + def maybe_install_runner_mock(self, monkeypatch: pytest.MonkeyPatch) -> Mock: + mock = Mock(return_value=False) + monkeypatch.setattr(instances_check, "_maybe_install_runner", mock) + return mock + + @pytest.fixture + def maybe_install_shim_mock(self, monkeypatch: pytest.MonkeyPatch) -> Mock: + mock = Mock(return_value=False) + monkeypatch.setattr(instances_check, "_maybe_install_shim", mock) + return mock + + async def test_up_to_date(self, test_db, instance: InstanceModel, shim_client_mock: Mock): + 
shim_client_mock.get_version_string.return_value = self.EXPECTED_VERSION + shim_client_mock.is_safe_to_restart.return_value = True + + instances_check._maybe_install_components(instance, shim_client_mock) + + shim_client_mock.get_components.assert_called_once() + shim_client_mock.shutdown.assert_not_called() + + async def test_no_shim_component_info( + self, test_db, instance: InstanceModel, shim_client_mock: Mock + ): + shim_client_mock.get_components.return_value = ComponentList() + shim_client_mock.get_version_string.return_value = "outdated" + shim_client_mock.is_safe_to_restart.return_value = True + + instances_check._maybe_install_components(instance, shim_client_mock) + + shim_client_mock.get_components.assert_called_once() + shim_client_mock.shutdown.assert_not_called() + + async def test_outdated_shutdown_requested( + self, test_db, instance: InstanceModel, shim_client_mock: Mock + ): + shim_client_mock.get_version_string.return_value = "outdated" + shim_client_mock.is_safe_to_restart.return_value = True + + instances_check._maybe_install_components(instance, shim_client_mock) + + shim_client_mock.get_components.assert_called_once() + shim_client_mock.shutdown.assert_called_once_with(force=False) + + async def test_outdated_but_task_wont_survive_restart( + self, test_db, instance: InstanceModel, shim_client_mock: Mock + ): + shim_client_mock.get_version_string.return_value = "outdated" + shim_client_mock.is_safe_to_restart.return_value = False + + instances_check._maybe_install_components(instance, shim_client_mock) + + shim_client_mock.get_components.assert_called_once() + shim_client_mock.shutdown.assert_not_called() + + async def test_outdated_but_runner_installation_in_progress( + self, + test_db, + instance: InstanceModel, + shim_client_mock: Mock, + component_list: ComponentList, + ): + shim_client_mock.get_version_string.return_value = "outdated" + shim_client_mock.is_safe_to_restart.return_value = True + runner_info = component_list.runner + assert 
runner_info is not None + runner_info.status = ComponentStatus.INSTALLING + + instances_check._maybe_install_components(instance, shim_client_mock) + + shim_client_mock.get_components.assert_called_once() + shim_client_mock.shutdown.assert_not_called() + + async def test_outdated_but_shim_installation_in_progress( + self, + test_db, + instance: InstanceModel, + shim_client_mock: Mock, + component_list: ComponentList, + ): + shim_client_mock.get_version_string.return_value = "outdated" + shim_client_mock.is_safe_to_restart.return_value = True + shim_info = component_list.shim + assert shim_info is not None + shim_info.status = ComponentStatus.INSTALLING + + instances_check._maybe_install_components(instance, shim_client_mock) + + shim_client_mock.get_components.assert_called_once() + shim_client_mock.shutdown.assert_not_called() + + async def test_outdated_but_runner_installation_requested( + self, + test_db, + instance: InstanceModel, + shim_client_mock: Mock, + maybe_install_runner_mock: Mock, + ): + shim_client_mock.get_version_string.return_value = "outdated" + shim_client_mock.is_safe_to_restart.return_value = True + maybe_install_runner_mock.return_value = True + + instances_check._maybe_install_components(instance, shim_client_mock) + + shim_client_mock.get_components.assert_called_once() + shim_client_mock.shutdown.assert_not_called() + + async def test_outdated_but_shim_installation_requested( + self, + test_db, + instance: InstanceModel, + shim_client_mock: Mock, + maybe_install_shim_mock: Mock, + ): + shim_client_mock.get_version_string.return_value = "outdated" + shim_client_mock.is_safe_to_restart.return_value = True + maybe_install_shim_mock.return_value = True + + instances_check._maybe_install_components(instance, shim_client_mock) + + shim_client_mock.get_components.assert_called_once() + shim_client_mock.shutdown.assert_not_called() diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_instances/test_cloud_provisioning.py 
b/src/tests/_internal/server/background/pipeline_tasks/test_instances/test_cloud_provisioning.py new file mode 100644 index 0000000000..ff76033e0a --- /dev/null +++ b/src/tests/_internal/server/background/pipeline_tasks/test_instances/test_cloud_provisioning.py @@ -0,0 +1,452 @@ +import datetime as dt +from collections import defaultdict +from typing import Optional +from unittest.mock import Mock, patch + +import gpuhunt +import pytest +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from dstack._internal.core.errors import NoCapacityError, ProvisioningError +from dstack._internal.core.models.backends.base import BackendType +from dstack._internal.core.models.fleets import FleetNodesSpec, InstanceGroupPlacement +from dstack._internal.core.models.instances import ( + Gpu, + InstanceAvailability, + InstanceOffer, + InstanceOfferWithAvailability, + InstanceStatus, + InstanceTerminationReason, + InstanceType, + Resources, +) +from dstack._internal.core.models.placement import PlacementGroup, PlacementGroupProvisioningData +from dstack._internal.core.models.runs import JobProvisioningData +from dstack._internal.server.background.pipeline_tasks.instances import InstanceWorker +from dstack._internal.server.models import PlacementGroupModel +from dstack._internal.server.testing.common import ( + ComputeMockSpec, + create_fleet, + create_instance, + create_project, + get_fleet_configuration, + get_fleet_spec, + get_instance_offer_with_availability, + get_job_provisioning_data, + get_placement_group_provisioning_data, +) +from tests._internal.server.background.pipeline_tasks.test_instances.helpers import ( + process_instance, +) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) +class TestCloudProvisioning: + @pytest.mark.parametrize( + ["cpus", "gpus", "requested_blocks", "expected_blocks"], + [ + pytest.param(32, 8, 1, 1, id="gpu-instance-no-blocks"), + pytest.param(32, 8, 2, 2, 
id="gpu-instance-four-gpu-per-block"), + pytest.param(32, 8, 4, 4, id="gpu-instance-two-gpus-per-block"), + pytest.param(32, 8, None, 8, id="gpu-instance-auto-max-gpu"), + pytest.param(4, 8, None, 4, id="gpu-instance-auto-max-cpu"), + pytest.param(8, 8, None, 8, id="gpu-instance-auto-max-cpu-and-gpu"), + pytest.param(32, 0, 1, 1, id="cpu-instance-no-blocks"), + pytest.param(32, 0, 2, 2, id="cpu-instance-four-cpu-per-block"), + pytest.param(32, 0, 4, 4, id="cpu-instance-two-cpus-per-block"), + pytest.param(32, 0, None, 32, id="cpu-instance-auto-max-cpu"), + ], + ) + async def test_creates_instance( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + cpus: int, + gpus: int, + requested_blocks: Optional[int], + expected_blocks: int, + ): + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.PENDING, + total_blocks=requested_blocks, + busy_blocks=0, + ) + with patch("dstack._internal.server.services.backends.get_project_backends") as m: + backend_mock = Mock() + m.return_value = [backend_mock] + backend_mock.TYPE = BackendType.AWS + gpu = Gpu(name="T4", memory_mib=16384, vendor=gpuhunt.AcceleratorVendor.NVIDIA) + offer = InstanceOfferWithAvailability( + backend=BackendType.AWS, + instance=InstanceType( + name="instance", + resources=Resources( + cpus=cpus, + memory_mib=131072, + spot=False, + gpus=[gpu] * gpus, + ), + ), + region="us", + price=1.0, + availability=InstanceAvailability.AVAILABLE, + total_blocks=expected_blocks, + ) + backend_mock.compute.return_value = Mock(spec=ComputeMockSpec) + backend_mock.compute.return_value.get_offers.return_value = [offer] + backend_mock.compute.return_value.create_instance.return_value = JobProvisioningData( + backend=offer.backend, + instance_type=offer.instance, + instance_id="instance_id", + hostname="1.1.1.1", + internal_ip=None, + region=offer.region, + price=offer.price, + username="ubuntu", + ssh_port=22, + 
ssh_proxy=None, + dockerized=True, + backend_data=None, + ) + + await process_instance(session, worker, instance) + + await session.refresh(instance) + assert instance.status == InstanceStatus.PROVISIONING + assert instance.total_blocks == expected_blocks + assert instance.busy_blocks == 0 + + @pytest.mark.parametrize("err", [RuntimeError("Unexpected"), ProvisioningError("Expected")]) + async def test_tries_second_offer_if_first_fails( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + err: Exception, + ): + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.PENDING, + ) + aws_mock = Mock() + aws_mock.TYPE = BackendType.AWS + offer = get_instance_offer_with_availability(backend=BackendType.AWS, price=1.0) + aws_mock.compute.return_value = Mock(spec=ComputeMockSpec) + aws_mock.compute.return_value.get_offers.return_value = [offer] + aws_mock.compute.return_value.create_instance.side_effect = err + gcp_mock = Mock() + gcp_mock.TYPE = BackendType.GCP + offer = get_instance_offer_with_availability(backend=BackendType.GCP, price=2.0) + gcp_mock.compute.return_value = Mock(spec=ComputeMockSpec) + gcp_mock.compute.return_value.get_offers.return_value = [offer] + gcp_mock.compute.return_value.create_instance.return_value = get_job_provisioning_data( + backend=offer.backend, + region=offer.region, + price=offer.price, + ) + with patch("dstack._internal.server.services.backends.get_project_backends") as m: + m.return_value = [aws_mock, gcp_mock] + await process_instance(session, worker, instance) + + await session.refresh(instance) + assert instance.status == InstanceStatus.PROVISIONING + aws_mock.compute.return_value.create_instance.assert_called_once() + assert instance.backend == BackendType.GCP + + @pytest.mark.parametrize("err", [RuntimeError("Unexpected"), ProvisioningError("Expected")]) + async def test_fails_if_all_offers_fail( + self, + test_db, + 
session: AsyncSession, + worker: InstanceWorker, + err: Exception, + ): + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.PENDING, + ) + aws_mock = Mock() + aws_mock.TYPE = BackendType.AWS + offer = get_instance_offer_with_availability(backend=BackendType.AWS, price=1.0) + aws_mock.compute.return_value = Mock(spec=ComputeMockSpec) + aws_mock.compute.return_value.get_offers.return_value = [offer] + aws_mock.compute.return_value.create_instance.side_effect = err + with patch("dstack._internal.server.services.backends.get_project_backends") as m: + m.return_value = [aws_mock] + await process_instance(session, worker, instance) + + await session.refresh(instance) + assert instance.status == InstanceStatus.TERMINATED + assert instance.termination_reason == InstanceTerminationReason.NO_OFFERS + + async def test_fails_if_no_offers( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + ): + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.PENDING, + ) + with patch("dstack._internal.server.services.backends.get_project_backends") as m: + m.return_value = [] + await process_instance(session, worker, instance) + + await session.refresh(instance) + assert instance.status == InstanceStatus.TERMINATED + assert instance.termination_reason == InstanceTerminationReason.NO_OFFERS + + @pytest.mark.parametrize( + ("placement", "expected_termination_reasons"), + [ + pytest.param( + InstanceGroupPlacement.CLUSTER, + { + InstanceTerminationReason.NO_OFFERS: 1, + InstanceTerminationReason.MASTER_FAILED: 3, + }, + id="cluster", + ), + pytest.param( + None, + {InstanceTerminationReason.NO_OFFERS: 4}, + id="non-cluster", + ), + ], + ) + async def test_terminates_cluster_instances_if_master_not_created( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + 
placement: Optional[InstanceGroupPlacement], + expected_termination_reasons: dict[str, int], + ): + project = await create_project(session=session) + fleet = await create_fleet( + session, + project, + spec=get_fleet_spec( + conf=get_fleet_configuration( + placement=placement, nodes=FleetNodesSpec(min=4, target=4, max=4) + ) + ), + ) + instances = [ + await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.PENDING, + offer=None, + job_provisioning_data=None, + instance_num=index, + created_at=dt.datetime.now(dt.timezone.utc) + dt.timedelta(seconds=index), + ) + for index in range(4) + ] + with patch("dstack._internal.server.services.backends.get_project_backends") as m: + m.return_value = [] + for instance in sorted(instances, key=lambda i: (i.instance_num, i.created_at)): + await process_instance(session, worker, instance) + + termination_reasons = defaultdict(int) + for instance in instances: + await session.refresh(instance) + assert instance.status == InstanceStatus.TERMINATED + termination_reasons[instance.termination_reason] += 1 + assert termination_reasons == expected_termination_reasons + + @pytest.mark.parametrize( + ("placement", "should_create"), + [ + pytest.param(InstanceGroupPlacement.CLUSTER, True, id="placement-cluster"), + pytest.param(None, False, id="no-placement"), + ], + ) + async def test_create_placement_group_if_placement_cluster( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + placement: Optional[InstanceGroupPlacement], + should_create: bool, + ) -> None: + project = await create_project(session=session) + fleet = await create_fleet( + session, + project, + spec=get_fleet_spec( + conf=get_fleet_configuration( + placement=placement, nodes=FleetNodesSpec(min=1, target=1, max=1) + ) + ), + ) + instance = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.PENDING, + offer=None, + job_provisioning_data=None, + ) + backend_mock = 
Mock() + backend_mock.TYPE = BackendType.AWS + backend_mock.compute.return_value = Mock(spec=ComputeMockSpec) + backend_mock.compute.return_value.get_offers.return_value = [ + get_instance_offer_with_availability() + ] + backend_mock.compute.return_value.create_instance.return_value = ( + get_job_provisioning_data() + ) + backend_mock.compute.return_value.create_placement_group.return_value = ( + get_placement_group_provisioning_data() + ) + with patch("dstack._internal.server.services.backends.get_project_backends") as m: + m.return_value = [backend_mock] + await process_instance(session, worker, instance) + + await session.refresh(instance) + assert instance.status == InstanceStatus.PROVISIONING + placement_groups = (await session.execute(select(PlacementGroupModel))).scalars().all() + if should_create: + assert backend_mock.compute.return_value.create_placement_group.call_count == 1 + assert len(placement_groups) == 1 + else: + assert backend_mock.compute.return_value.create_placement_group.call_count == 0 + assert len(placement_groups) == 0 + + @pytest.mark.parametrize("can_reuse", [True, False]) + async def test_reuses_placement_group_between_offers_if_the_group_is_suitable( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + can_reuse: bool, + ) -> None: + project = await create_project(session=session) + fleet = await create_fleet( + session, + project, + spec=get_fleet_spec( + conf=get_fleet_configuration( + placement=InstanceGroupPlacement.CLUSTER, + nodes=FleetNodesSpec(min=1, target=1, max=1), + ) + ), + ) + instance = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.PENDING, + offer=None, + job_provisioning_data=None, + ) + backend_mock = Mock() + backend_mock.TYPE = BackendType.AWS + backend_mock.compute.return_value = Mock(spec=ComputeMockSpec) + backend_mock.compute.return_value.get_offers.return_value = [ + get_instance_offer_with_availability(instance_type="bad-offer-1"), + 
get_instance_offer_with_availability(instance_type="bad-offer-2"), + get_instance_offer_with_availability(instance_type="good-offer"), + ] + + def create_instance_method( + instance_offer: InstanceOfferWithAvailability, *args, **kwargs + ) -> JobProvisioningData: + if instance_offer.instance.name == "good-offer": + return get_job_provisioning_data() + raise NoCapacityError() + + backend_mock.compute.return_value.create_instance = create_instance_method + backend_mock.compute.return_value.create_placement_group.return_value = ( + get_placement_group_provisioning_data() + ) + backend_mock.compute.return_value.is_suitable_placement_group.return_value = can_reuse + with patch("dstack._internal.server.services.backends.get_project_backends") as m: + m.return_value = [backend_mock] + await process_instance(session, worker, instance) + + await session.refresh(instance) + assert instance.status == InstanceStatus.PROVISIONING + placement_groups = (await session.execute(select(PlacementGroupModel))).scalars().all() + if can_reuse: + assert backend_mock.compute.return_value.create_placement_group.call_count == 1 + assert len(placement_groups) == 1 + else: + assert backend_mock.compute.return_value.create_placement_group.call_count == 3 + assert len(placement_groups) == 3 + to_be_deleted_count = sum(pg.fleet_deleted for pg in placement_groups) + assert to_be_deleted_count == 2 + + @pytest.mark.parametrize("err", [NoCapacityError(), RuntimeError()]) + async def test_handles_create_placement_group_errors( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + err: Exception, + ) -> None: + project = await create_project(session=session) + fleet = await create_fleet( + session, + project, + spec=get_fleet_spec( + conf=get_fleet_configuration( + placement=InstanceGroupPlacement.CLUSTER, + nodes=FleetNodesSpec(min=1, target=1, max=1), + ) + ), + ) + instance = await create_instance( + session=session, + project=project, + fleet=fleet, + 
status=InstanceStatus.PENDING, + offer=None, + job_provisioning_data=None, + ) + backend_mock = Mock() + backend_mock.TYPE = BackendType.AWS + backend_mock.compute.return_value = Mock(spec=ComputeMockSpec) + backend_mock.compute.return_value.get_offers.return_value = [ + get_instance_offer_with_availability(instance_type="bad-offer"), + get_instance_offer_with_availability(instance_type="good-offer"), + ] + backend_mock.compute.return_value.create_instance.return_value = ( + get_job_provisioning_data() + ) + + def create_placement_group_method( + placement_group: PlacementGroup, master_instance_offer: InstanceOffer + ) -> PlacementGroupProvisioningData: + if master_instance_offer.instance.name == "good-offer": + return get_placement_group_provisioning_data() + raise err + + backend_mock.compute.return_value.create_placement_group = create_placement_group_method + with patch("dstack._internal.server.services.backends.get_project_backends") as m: + m.return_value = [backend_mock] + await process_instance(session, worker, instance) + + await session.refresh(instance) + assert instance.status == InstanceStatus.PROVISIONING + assert instance.offer + assert "good-offer" in instance.offer + assert "bad-offer" not in instance.offer + placement_groups = (await session.execute(select(PlacementGroupModel))).scalars().all() + assert len(placement_groups) == 1 diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_instances/test_pipeline.py b/src/tests/_internal/server/background/pipeline_tasks/test_instances/test_pipeline.py new file mode 100644 index 0000000000..012c7fdb38 --- /dev/null +++ b/src/tests/_internal/server/background/pipeline_tasks/test_instances/test_pipeline.py @@ -0,0 +1,253 @@ +import datetime as dt +import uuid +from unittest.mock import Mock + +import pytest +from sqlalchemy.ext.asyncio import AsyncSession + +from dstack._internal.core.models.instances import InstanceStatus +from dstack._internal.server.background.pipeline_tasks.instances 
import ( + InstanceFetcher, + InstancePipeline, + InstanceWorker, +) +from dstack._internal.server.background.pipeline_tasks.instances import check as instances_check +from dstack._internal.server.schemas.instances import InstanceCheck +from dstack._internal.server.testing.common import ( + create_compute_group, + create_fleet, + create_instance, + create_project, +) +from dstack._internal.utils.common import get_current_datetime +from tests._internal.server.background.pipeline_tasks.test_instances.helpers import ( + instance_to_pipeline_item, + lock_instance, + process_instance, +) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) +class TestInstanceFetcher: + async def test_fetch_selects_eligible_instances_and_sets_lock_fields( + self, test_db, session: AsyncSession, fetcher: InstanceFetcher + ): + project = await create_project(session=session) + fleet = await create_fleet(session=session, project=project) + compute_group = await create_compute_group(session=session, project=project, fleet=fleet) + now = get_current_datetime() + stale = now - dt.timedelta(minutes=1) + + pending = await create_instance( + session=session, + project=project, + status=InstanceStatus.PENDING, + last_processed_at=stale - dt.timedelta(seconds=5), + ) + provisioning = await create_instance( + session=session, + project=project, + status=InstanceStatus.PROVISIONING, + name="provisioning", + last_processed_at=stale - dt.timedelta(seconds=4), + ) + busy = await create_instance( + session=session, + project=project, + status=InstanceStatus.BUSY, + name="busy", + last_processed_at=stale - dt.timedelta(seconds=3), + ) + idle = await create_instance( + session=session, + project=project, + status=InstanceStatus.IDLE, + name="idle", + last_processed_at=stale - dt.timedelta(seconds=2), + ) + terminating = await create_instance( + session=session, + project=project, + status=InstanceStatus.TERMINATING, + name="terminating", + 
last_processed_at=stale - dt.timedelta(seconds=1), + ) + + deleted = await create_instance( + session=session, + project=project, + status=InstanceStatus.IDLE, + name="deleted", + last_processed_at=stale, + ) + deleted.deleted = True + + recent = await create_instance( + session=session, + project=project, + status=InstanceStatus.IDLE, + name="recent", + last_processed_at=now, + ) + + terminating_compute_group = await create_instance( + session=session, + project=project, + status=InstanceStatus.TERMINATING, + name="terminating-compute-group", + last_processed_at=stale + dt.timedelta(seconds=1), + ) + terminating_compute_group.compute_group = compute_group + + locked = await create_instance( + session=session, + project=project, + status=InstanceStatus.IDLE, + name="locked", + last_processed_at=stale + dt.timedelta(seconds=2), + ) + locked.lock_expires_at = now + dt.timedelta(minutes=1) + locked.lock_token = uuid.uuid4() + locked.lock_owner = "OtherPipeline" + + await session.commit() + + items = await fetcher.fetch(limit=10) + + assert {item.id for item in items} == { + pending.id, + provisioning.id, + busy.id, + idle.id, + terminating.id, + } + assert {item.status for item in items} == { + InstanceStatus.PENDING, + InstanceStatus.PROVISIONING, + InstanceStatus.BUSY, + InstanceStatus.IDLE, + InstanceStatus.TERMINATING, + } + + for instance in [ + pending, + provisioning, + busy, + idle, + terminating, + deleted, + recent, + terminating_compute_group, + locked, + ]: + await session.refresh(instance) + + expected_lock_owner = InstancePipeline.__name__ + fetched_instances = [pending, provisioning, busy, idle, terminating] + assert all(instance.lock_owner == expected_lock_owner for instance in fetched_instances) + assert all(instance.lock_expires_at is not None for instance in fetched_instances) + assert all(instance.lock_token is not None for instance in fetched_instances) + assert len({instance.lock_token for instance in fetched_instances}) == 1 + + assert 
deleted.lock_owner is None + assert recent.lock_owner is None + assert terminating_compute_group.lock_owner is None + assert locked.lock_owner == "OtherPipeline" + + async def test_fetch_respects_order_and_limit( + self, test_db, session: AsyncSession, fetcher: InstanceFetcher + ): + project = await create_project(session=session) + now = get_current_datetime() + + oldest = await create_instance( + session=session, + project=project, + name="oldest", + last_processed_at=now - dt.timedelta(minutes=3), + ) + middle = await create_instance( + session=session, + project=project, + name="middle", + last_processed_at=now - dt.timedelta(minutes=2), + ) + newest = await create_instance( + session=session, + project=project, + name="newest", + last_processed_at=now - dt.timedelta(minutes=1), + ) + + items = await fetcher.fetch(limit=2) + + assert [item.id for item in items] == [oldest.id, middle.id] + + await session.refresh(oldest) + await session.refresh(middle) + await session.refresh(newest) + + assert oldest.lock_owner == InstancePipeline.__name__ + assert middle.lock_owner == InstancePipeline.__name__ + assert newest.lock_owner is None + + +@pytest.mark.asyncio +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) +class TestInstanceWorker: + async def test_process_skips_when_lock_token_changes( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + ): + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.IDLE, + ) + + lock_instance(instance) + await session.commit() + item = instance_to_pipeline_item(instance) + new_lock_token = uuid.uuid4() + instance.lock_token = new_lock_token + await session.commit() + + await worker.process(item) + await session.refresh(instance) + + assert instance.lock_token == new_lock_token + assert instance.lock_owner == InstancePipeline.__name__ + + async def 
test_process_unlocks_and_updates_last_processed_at_after_check( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + monkeypatch: pytest.MonkeyPatch, + ): + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.PROVISIONING, + ) + before_processed_at = instance.last_processed_at + + monkeypatch.setattr( + instances_check, + "_check_instance_inner", + Mock(return_value=InstanceCheck(reachable=True)), + ) + await process_instance(session, worker, instance) + + await session.refresh(instance) + + assert instance.status == InstanceStatus.IDLE + assert instance.lock_expires_at is None + assert instance.lock_token is None + assert instance.lock_owner is None + assert instance.last_processed_at > before_processed_at diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_instances/test_ssh_deploy.py b/src/tests/_internal/server/background/pipeline_tasks/test_instances/test_ssh_deploy.py new file mode 100644 index 0000000000..c103458ed4 --- /dev/null +++ b/src/tests/_internal/server/background/pipeline_tasks/test_instances/test_ssh_deploy.py @@ -0,0 +1,248 @@ +import datetime as dt +from typing import Optional +from unittest.mock import Mock + +import pytest +from sqlalchemy.ext.asyncio import AsyncSession + +from dstack._internal.core.errors import SSHProvisioningError +from dstack._internal.core.models.backends.base import BackendType +from dstack._internal.core.models.instances import InstanceStatus, InstanceTerminationReason +from dstack._internal.server.background.pipeline_tasks.instances import InstanceWorker +from dstack._internal.server.background.pipeline_tasks.instances import ( + ssh_deploy as instances_ssh_deploy, +) +from dstack._internal.server.testing.common import ( + create_instance, + create_project, + get_job_provisioning_data, + get_remote_connection_info, +) +from dstack._internal.utils.common import get_current_datetime +from 
tests._internal.server.background.pipeline_tasks.test_instances.helpers import ( + process_instance, +) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) +class TestSSHDeploy: + async def test_pending_ssh_instance_terminates_on_provision_timeout( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + ): + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.PENDING, + created_at=get_current_datetime() - dt.timedelta(days=100), + remote_connection_info=get_remote_connection_info(), + ) + await session.commit() + + await process_instance(session, worker, instance) + + await session.refresh(instance) + assert instance.status == InstanceStatus.TERMINATED + assert instance.termination_reason == InstanceTerminationReason.PROVISIONING_TIMEOUT + + @pytest.mark.parametrize( + ["cpus", "gpus", "requested_blocks", "expected_blocks"], + [ + pytest.param(32, 8, 1, 1, id="gpu-instance-no-blocks"), + pytest.param(32, 8, 2, 2, id="gpu-instance-four-gpu-per-block"), + pytest.param(32, 8, 4, 4, id="gpu-instance-two-gpus-per-block"), + pytest.param(32, 8, None, 8, id="gpu-instance-auto-max-gpu"), + pytest.param(4, 8, None, 4, id="gpu-instance-auto-max-cpu"), + pytest.param(8, 8, None, 8, id="gpu-instance-auto-max-cpu-and-gpu"), + pytest.param(32, 0, 1, 1, id="cpu-instance-no-blocks"), + pytest.param(32, 0, 2, 2, id="cpu-instance-four-cpu-per-block"), + pytest.param(32, 0, 4, 4, id="cpu-instance-two-cpus-per-block"), + pytest.param(32, 0, None, 32, id="cpu-instance-auto-max-cpu"), + ], + ) + async def test_adds_ssh_instance( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + host_info: dict, + deploy_instance_mock: Mock, + cpus: int, + gpus: int, + requested_blocks: Optional[int], + expected_blocks: int, + ): + host_info["cpus"] = cpus + host_info["gpu_count"] = gpus + project = await 
create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.PENDING, + created_at=get_current_datetime(), + remote_connection_info=get_remote_connection_info(), + total_blocks=requested_blocks, + busy_blocks=0, + ) + await session.commit() + + await process_instance(session, worker, instance) + + await session.refresh(instance) + assert instance.status == InstanceStatus.IDLE + assert instance.total_blocks == expected_blocks + assert instance.busy_blocks == 0 + deploy_instance_mock.assert_called_once() + + async def test_retries_ssh_instance_if_provisioning_fails( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + deploy_instance_mock: Mock, + ): + deploy_instance_mock.side_effect = SSHProvisioningError("Expected") + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.PENDING, + created_at=get_current_datetime(), + remote_connection_info=get_remote_connection_info(), + ) + await session.commit() + + await process_instance(session, worker, instance) + + await session.refresh(instance) + assert instance.status == InstanceStatus.PENDING + assert instance.termination_reason is None + + async def test_terminates_ssh_instance_if_deploy_fails_unexpectedly( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + deploy_instance_mock: Mock, + ): + deploy_instance_mock.side_effect = RuntimeError("Unexpected") + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.PENDING, + created_at=get_current_datetime(), + remote_connection_info=get_remote_connection_info(), + ) + await session.commit() + + await process_instance(session, worker, instance) + + await session.refresh(instance) + assert instance.status == InstanceStatus.TERMINATED + assert instance.termination_reason == 
InstanceTerminationReason.ERROR + assert instance.termination_reason_message == "Unexpected error when adding SSH instance" + + async def test_terminates_ssh_instance_if_key_is_invalid( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + monkeypatch: pytest.MonkeyPatch, + ): + monkeypatch.setattr( + instances_ssh_deploy, + "ssh_keys_to_pkeys", + Mock(side_effect=ValueError("Bad key")), + ) + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.PENDING, + created_at=get_current_datetime(), + remote_connection_info=get_remote_connection_info(), + ) + await session.commit() + + await process_instance(session, worker, instance) + + await session.refresh(instance) + assert instance.status == InstanceStatus.TERMINATED + assert instance.termination_reason == InstanceTerminationReason.ERROR + assert instance.termination_reason_message == "Unsupported private SSH key type" + + async def test_terminates_ssh_instance_if_internal_ip_cannot_be_resolved_from_network( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + host_info: dict, + deploy_instance_mock: Mock, + ): + host_info["addresses"] = ["192.168.100.100/24"] + project = await create_project(session=session) + job_provisioning_data = get_job_provisioning_data( + dockerized=True, + backend=BackendType.REMOTE, + internal_ip=None, + ) + job_provisioning_data.instance_network = "10.0.0.0/24" + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.PENDING, + created_at=get_current_datetime(), + remote_connection_info=get_remote_connection_info(), + job_provisioning_data=job_provisioning_data, + ) + await session.commit() + + await process_instance(session, worker, instance) + + await session.refresh(instance) + assert instance.status == InstanceStatus.TERMINATED + assert instance.termination_reason == InstanceTerminationReason.ERROR + assert ( + 
instance.termination_reason_message + == "Failed to locate internal IP address on the given network" + ) + + async def test_terminates_ssh_instance_if_internal_ip_is_not_in_host_interfaces( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + host_info: dict, + deploy_instance_mock: Mock, + ): + host_info["addresses"] = ["192.168.100.100/24"] + project = await create_project(session=session) + job_provisioning_data = get_job_provisioning_data( + dockerized=True, + backend=BackendType.REMOTE, + internal_ip="10.0.0.20", + ) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.PENDING, + created_at=get_current_datetime(), + remote_connection_info=get_remote_connection_info(), + job_provisioning_data=job_provisioning_data, + ) + await session.commit() + + await process_instance(session, worker, instance) + + await session.refresh(instance) + assert instance.status == InstanceStatus.TERMINATED + assert instance.termination_reason == InstanceTerminationReason.ERROR + assert ( + instance.termination_reason_message + == "Specified internal IP not found among instance interfaces" + ) diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_instances/test_termination.py b/src/tests/_internal/server/background/pipeline_tasks/test_instances/test_termination.py new file mode 100644 index 0000000000..b9da58fc11 --- /dev/null +++ b/src/tests/_internal/server/background/pipeline_tasks/test_instances/test_termination.py @@ -0,0 +1,219 @@ +import datetime as dt +from contextlib import contextmanager +from typing import Optional +from unittest.mock import AsyncMock, Mock, patch + +import pytest +from freezegun import freeze_time +from sqlalchemy.ext.asyncio import AsyncSession + +from dstack._internal.core.errors import BackendError, NotYetTerminated +from dstack._internal.core.models.backends.base import BackendType +from dstack._internal.core.models.instances import InstanceStatus, 
InstanceTerminationReason +from dstack._internal.server.background.pipeline_tasks.instances import InstanceWorker +from dstack._internal.server.background.pipeline_tasks.instances import ( + termination as instances_termination, +) +from dstack._internal.server.testing.common import create_instance, create_project +from tests._internal.server.background.pipeline_tasks.test_instances.helpers import ( + instance_to_pipeline_item, + lock_instance, + process_instance, +) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) +class TestTermination: + @staticmethod + @contextmanager + def mock_terminate_in_backend(error: Optional[Exception] = None): + backend = Mock() + backend.TYPE = BackendType.VERDA + terminate_instance = backend.compute.return_value.terminate_instance + if error is not None: + terminate_instance.side_effect = error + with patch.object( + instances_termination.backends_services, + "get_project_backend_by_type", + AsyncMock(return_value=backend), + ): + yield terminate_instance + + async def test_terminate( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + ): + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.TERMINATING, + ) + instance.termination_reason = InstanceTerminationReason.IDLE_TIMEOUT + instance.last_job_processed_at = dt.datetime.now(dt.timezone.utc) + dt.timedelta( + minutes=-19 + ) + await session.commit() + + with self.mock_terminate_in_backend() as mock: + await process_instance(session, worker, instance) + mock.assert_called_once() + + await session.refresh(instance) + + assert instance.status == InstanceStatus.TERMINATED + assert instance.termination_reason == InstanceTerminationReason.IDLE_TIMEOUT + assert instance.deleted is True + assert instance.deleted_at is not None + assert instance.finished_at is not None + + async def 
test_terminates_terminating_deleted_instance( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + ): + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.TERMINATING, + ) + lock_instance(instance) + await session.commit() + item = instance_to_pipeline_item(instance) + instance.deleted = True + instance.termination_reason = InstanceTerminationReason.IDLE_TIMEOUT + instance.last_job_processed_at = instance.deleted_at = dt.datetime.now( + dt.timezone.utc + ) + dt.timedelta(minutes=-19) + await session.commit() + + with self.mock_terminate_in_backend() as mock: + await worker.process(item) + mock.assert_called_once() + + await session.refresh(instance) + + assert instance.status == InstanceStatus.TERMINATED + assert instance.deleted is True + assert instance.deleted_at is not None + assert instance.finished_at is not None + + @pytest.mark.parametrize( + "error", [BackendError("err"), RuntimeError("err"), NotYetTerminated("")] + ) + async def test_terminate_retry( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + error: Exception, + ): + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.TERMINATING, + ) + instance.termination_reason = InstanceTerminationReason.IDLE_TIMEOUT + initial_time = dt.datetime(2025, 1, 1, tzinfo=dt.timezone.utc) + instance.last_job_processed_at = initial_time + instance.last_processed_at = initial_time - dt.timedelta(minutes=1) + await session.commit() + + with ( + freeze_time(initial_time + dt.timedelta(minutes=1)), + self.mock_terminate_in_backend(error=error) as mock, + ): + await process_instance(session, worker, instance) + mock.assert_called_once() + await session.refresh(instance) + assert instance.status == InstanceStatus.TERMINATING + + with ( + freeze_time(initial_time + dt.timedelta(minutes=2)), + 
self.mock_terminate_in_backend(error=None) as mock, + ): + await process_instance(session, worker, instance) + mock.assert_called_once() + await session.refresh(instance) + assert instance.status == InstanceStatus.TERMINATED + + async def test_terminate_not_retries_if_too_early( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + ): + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.TERMINATING, + ) + instance.termination_reason = InstanceTerminationReason.IDLE_TIMEOUT + initial_time = dt.datetime(2025, 1, 1, tzinfo=dt.timezone.utc) + instance.last_job_processed_at = initial_time + instance.last_processed_at = initial_time - dt.timedelta(minutes=1) + await session.commit() + + with ( + freeze_time(initial_time + dt.timedelta(minutes=1)), + self.mock_terminate_in_backend(error=BackendError("err")) as mock, + ): + await process_instance(session, worker, instance) + mock.assert_called_once() + await session.refresh(instance) + assert instance.status == InstanceStatus.TERMINATING + + instance.last_processed_at = initial_time + await session.commit() + + with ( + freeze_time(initial_time + dt.timedelta(minutes=1, seconds=11)), + self.mock_terminate_in_backend(error=None) as mock, + ): + await process_instance(session, worker, instance) + mock.assert_not_called() + await session.refresh(instance) + assert instance.status == InstanceStatus.TERMINATING + + async def test_terminate_on_termination_deadline( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + ): + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.TERMINATING, + ) + instance.termination_reason = InstanceTerminationReason.IDLE_TIMEOUT + initial_time = dt.datetime(2025, 1, 1, tzinfo=dt.timezone.utc) + instance.last_job_processed_at = initial_time + instance.last_processed_at = 
initial_time - dt.timedelta(minutes=1) + await session.commit() + + with ( + freeze_time(initial_time + dt.timedelta(minutes=1)), + self.mock_terminate_in_backend(error=BackendError("err")) as mock, + ): + await process_instance(session, worker, instance) + mock.assert_called_once() + await session.refresh(instance) + assert instance.status == InstanceStatus.TERMINATING + + with ( + freeze_time(initial_time + dt.timedelta(minutes=15, seconds=55)), + self.mock_terminate_in_backend(error=None) as mock, + ): + await process_instance(session, worker, instance) + mock.assert_called_once() + await session.refresh(instance) + assert instance.status == InstanceStatus.TERMINATED diff --git a/src/tests/_internal/server/services/test_instances.py b/src/tests/_internal/server/services/test_instances.py index 4883e309cc..0eef682003 100644 --- a/src/tests/_internal/server/services/test_instances.py +++ b/src/tests/_internal/server/services/test_instances.py @@ -1,4 +1,5 @@ import uuid +from unittest.mock import Mock, call import pytest from sqlalchemy.ext.asyncio import AsyncSession @@ -14,10 +15,16 @@ Resources, ) from dstack._internal.core.models.profiles import Profile +from dstack._internal.core.models.runs import JobStatus from dstack._internal.server.models import InstanceModel +from dstack._internal.server.schemas.runner import TaskListItem, TaskListResponse, TaskStatus +from dstack._internal.server.services.runner.client import ShimClient from dstack._internal.server.testing.common import ( create_instance, + create_job, create_project, + create_repo, + create_run, create_user, get_volume, get_volume_configuration, @@ -155,6 +162,117 @@ async def test_returns_volume_instances(self, test_db, session: AsyncSession): assert res == [runpod_instance2] +@pytest.mark.asyncio +@pytest.mark.usefixtures("image_config_mock") +@pytest.mark.usefixtures("turn_off_keep_shim_tasks_setting") +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) +class 
TestRemoveDanglingTasks: + @pytest.fixture + def turn_off_keep_shim_tasks_setting(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("dstack._internal.server.settings.SERVER_KEEP_SHIM_TASKS", False) + + async def test_terminates_and_removes_dangling_tasks( + self, test_db, session: AsyncSession + ) -> None: + user = await create_user(session=session) + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.BUSY, + ) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + ) + job = await create_job( + session=session, + run=run, + status=JobStatus.RUNNING, + instance=instance, + ) + dangling_task_id_1 = "fe138b77-d0b1-49d3-8c9f-2dfe78ece727" + dangling_task_id_2 = "8b016a75-41de-44f1-91ff-c9b63d2caa1d" + shim_client_mock = Mock(spec_set=ShimClient) + shim_client_mock.is_api_v2_supported.return_value = True + shim_client_mock.list_tasks.return_value = TaskListResponse( + tasks=[ + TaskListItem(id=str(job.id), status=TaskStatus.RUNNING), + TaskListItem(id=dangling_task_id_1, status=TaskStatus.RUNNING), + TaskListItem(id=dangling_task_id_2, status=TaskStatus.TERMINATED), + ] + ) + await session.refresh(instance, attribute_names=["jobs"]) + + instances_services.remove_dangling_tasks_from_instance(shim_client_mock, instance) + + await session.refresh(instance) + assert instance.status == InstanceStatus.BUSY + + shim_client_mock.terminate_task.assert_called_once_with( + task_id=dangling_task_id_1, + reason=None, + message=None, + timeout=0, + ) + assert shim_client_mock.remove_task.call_count == 2 + shim_client_mock.remove_task.assert_has_calls( + [call(task_id=dangling_task_id_1), call(task_id=dangling_task_id_2)] + ) + + async def test_terminates_and_removes_dangling_tasks_legacy_shim( + self, test_db, session: AsyncSession + ) -> None: + user = await 
create_user(session=session) + project = await create_project(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.BUSY, + ) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + ) + job = await create_job( + session=session, + run=run, + status=JobStatus.RUNNING, + instance=instance, + ) + dangling_task_id_1 = "fe138b77-d0b1-49d3-8c9f-2dfe78ece727" + dangling_task_id_2 = "8b016a75-41de-44f1-91ff-c9b63d2caa1d" + shim_client_mock = Mock(spec_set=ShimClient) + shim_client_mock.is_api_v2_supported.return_value = True + shim_client_mock.list_tasks.return_value = TaskListResponse( + ids=[str(job.id), dangling_task_id_1, dangling_task_id_2] + ) + await session.refresh(instance, attribute_names=["jobs"]) + + instances_services.remove_dangling_tasks_from_instance(shim_client_mock, instance) + + await session.refresh(instance) + assert instance.status == InstanceStatus.BUSY + + assert shim_client_mock.terminate_task.call_count == 2 + shim_client_mock.terminate_task.assert_has_calls( + [ + call(task_id=dangling_task_id_1, reason=None, message=None, timeout=0), + call(task_id=dangling_task_id_2, reason=None, message=None, timeout=0), + ] + ) + assert shim_client_mock.remove_task.call_count == 2 + shim_client_mock.remove_task.assert_has_calls( + [call(task_id=dangling_task_id_1), call(task_id=dangling_task_id_2)] + ) + + class TestInstanceModelToInstance: @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) From 197c5f68b6135525b1706ee769c974e57b8ba391 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Tue, 3 Mar 2026 14:09:03 +0500 Subject: [PATCH 17/51] Inline _get_effective_ helpers --- .../pipeline_tasks/instances/__init__.py | 114 +++++++----------- 1 file changed, 42 insertions(+), 72 deletions(-) diff --git 
a/src/dstack/_internal/server/background/pipeline_tasks/instances/__init__.py b/src/dstack/_internal/server/background/pipeline_tasks/instances/__init__.py index d3ca224cea..96389a516c 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/instances/__init__.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/instances/__init__.py @@ -31,7 +31,6 @@ create_cloud_instance, ) from dstack._internal.server.background.pipeline_tasks.instances.common import ( - InstanceUpdateMap, ProcessResult, ) from dstack._internal.server.background.pipeline_tasks.instances.ssh_deploy import ( @@ -395,69 +394,6 @@ async def _refetch_locked_instance_for_check( return res.unique().scalar_one_or_none() -def _get_effective_instance_status( - instance_model: InstanceModel, update_map: InstanceUpdateMap -) -> InstanceStatus: - return update_map.get("status", instance_model.status) - - -def _get_effective_instance_termination_reason( - instance_model: InstanceModel, update_map: InstanceUpdateMap -): - return update_map.get("termination_reason", instance_model.termination_reason) - - -def _get_effective_instance_termination_reason_message( - instance_model: InstanceModel, update_map: InstanceUpdateMap -): - return update_map.get("termination_reason_message", instance_model.termination_reason_message) - - -def _get_effective_instance_health( - instance_model: InstanceModel, update_map: InstanceUpdateMap -) -> HealthStatus: - return update_map.get("health", instance_model.health) - - -def _get_effective_instance_unreachable( - instance_model: InstanceModel, update_map: InstanceUpdateMap -) -> bool: - return update_map.get("unreachable", instance_model.unreachable) - - -def _emit_instance_health_change_event( - session: AsyncSession, - instance_model: InstanceModel, - old_health: HealthStatus, - new_health: HealthStatus, -) -> None: - if old_health == new_health: - return - events.emit( - session, - f"Instance health changed {old_health.upper()} -> {new_health.upper()}", - 
actor=events.SystemActor(), - targets=[events.Target.from_model(instance_model)], - ) - - -def _emit_instance_reachability_change_event( - session: AsyncSession, - instance_model: InstanceModel, - old_status: InstanceStatus, - old_unreachable: bool, - new_unreachable: bool, -) -> None: - if not old_status.is_available() or old_unreachable == new_unreachable: - return - events.emit( - session, - "Instance became unreachable" if new_unreachable else "Instance became reachable", - actor=events.SystemActor(), - targets=[events.Target.from_model(instance_model)], - ) - - async def _apply_process_result(item: InstancePipelineItem, result: ProcessResult) -> None: async with get_session_ctx() as session: res = await session.execute( @@ -540,27 +476,28 @@ async def _apply_process_result(item: InstancePipelineItem, result: ProcessResul session=session, instance_model=instance_model, old_status=instance_model.status, - new_status=_get_effective_instance_status(instance_model, result.instance_update_map), - termination_reason=_get_effective_instance_termination_reason( - instance_model, result.instance_update_map + new_status=result.instance_update_map.get("status", instance_model.status), + termination_reason=result.instance_update_map.get( + "termination_reason", instance_model.termination_reason ), - termination_reason_message=_get_effective_instance_termination_reason_message( - instance_model, result.instance_update_map + termination_reason_message=result.instance_update_map.get( + "termination_reason_message", + instance_model.termination_reason_message, ), ) _emit_instance_health_change_event( session=session, instance_model=instance_model, old_health=instance_model.health, - new_health=_get_effective_instance_health(instance_model, result.instance_update_map), + new_health=result.instance_update_map.get("health", instance_model.health), ) _emit_instance_reachability_change_event( session=session, instance_model=instance_model, old_status=instance_model.status, 
old_unreachable=instance_model.unreachable, - new_unreachable=_get_effective_instance_unreachable( - instance_model, result.instance_update_map + new_unreachable=result.instance_update_map.get( + "unreachable", instance_model.unreachable ), ) @@ -578,3 +515,36 @@ async def _apply_process_result(item: InstancePipelineItem, result: ProcessResul ) ], ) + + +def _emit_instance_health_change_event( + session: AsyncSession, + instance_model: InstanceModel, + old_health: HealthStatus, + new_health: HealthStatus, +) -> None: + if old_health == new_health: + return + events.emit( + session, + f"Instance health changed {old_health.upper()} -> {new_health.upper()}", + actor=events.SystemActor(), + targets=[events.Target.from_model(instance_model)], + ) + + +def _emit_instance_reachability_change_event( + session: AsyncSession, + instance_model: InstanceModel, + old_status: InstanceStatus, + old_unreachable: bool, + new_unreachable: bool, +) -> None: + if not old_status.is_available() or old_unreachable == new_unreachable: + return + events.emit( + session, + "Instance became unreachable" if new_unreachable else "Instance became reachable", + actor=events.SystemActor(), + targets=[events.Target.from_model(instance_model)], + ) From 0b3acbb4b983a028aab42c6bf73206966be5d361 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Tue, 3 Mar 2026 14:48:10 +0500 Subject: [PATCH 18/51] Process new instance immediately --- .../server/background/pipeline_tasks/instances/__init__.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/instances/__init__.py b/src/dstack/_internal/server/background/pipeline_tasks/instances/__init__.py index 96389a516c..c5a5b4b639 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/instances/__init__.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/instances/__init__.py @@ -167,7 +167,10 @@ async def fetch(self, limit: int) -> list[InstancePipelineItem]: ) ), 
InstanceModel.deleted == False, - InstanceModel.last_processed_at <= now - self._min_processing_interval, + or_( + InstanceModel.last_processed_at <= now - self._min_processing_interval, + InstanceModel.last_processed_at == InstanceModel.created_at, + ), or_( InstanceModel.lock_expires_at.is_(None), InstanceModel.lock_expires_at < now, From 07c484b4af44efcc8aa96eba8fdca1dd29868566 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Tue, 3 Mar 2026 14:55:53 +0500 Subject: [PATCH 19/51] Do not refetch status --- .../pipeline_tasks/instances/__init__.py | 31 +++---------------- .../test_instances/test_cloud_provisioning.py | 1 + 2 files changed, 6 insertions(+), 26 deletions(-) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/instances/__init__.py b/src/dstack/_internal/server/background/pipeline_tasks/instances/__init__.py index c5a5b4b639..392bf97fc7 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/instances/__init__.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/instances/__init__.py @@ -228,23 +228,16 @@ def __init__( @sentry_utils.instrument_named_task("pipeline_tasks.InstanceWorker.process") async def process(self, item: InstancePipelineItem): - async with get_session_ctx() as session: - instance_model = await _refetch_locked_instance_status(session=session, item=item) - if instance_model is None: - log_lock_token_mismatch(logger, item) - return - status = instance_model.status - result: Optional[ProcessResult] = None - if status == InstanceStatus.PENDING: + if item.status == InstanceStatus.PENDING: result = await _process_pending_item(item) - elif status == InstanceStatus.PROVISIONING: + elif item.status == InstanceStatus.PROVISIONING: result = await _process_provisioning_item(item) - elif status == InstanceStatus.IDLE: + elif item.status == InstanceStatus.IDLE: result = await _process_idle_item(item) - elif status == InstanceStatus.BUSY: + elif item.status == InstanceStatus.BUSY: result = await 
_process_busy_item(item) - elif status == InstanceStatus.TERMINATING: + elif item.status == InstanceStatus.TERMINATING: result = await _process_terminating_item(item) if result is None: return @@ -309,20 +302,6 @@ async def _process_terminating_item(item: InstancePipelineItem) -> Optional[Proc return await terminate_instance(instance_model) -async def _refetch_locked_instance_status( - session: AsyncSession, item: InstancePipelineItem -) -> Optional[InstanceModel]: - res = await session.execute( - select(InstanceModel) - .where( - InstanceModel.id == item.id, - InstanceModel.lock_token == item.lock_token, - ) - .options(load_only(InstanceModel.status)) - ) - return res.scalar_one_or_none() - - async def _refetch_locked_instance_for_pending_or_terminating( session: AsyncSession, item: InstancePipelineItem ) -> Optional[InstanceModel]: diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_instances/test_cloud_provisioning.py b/src/tests/_internal/server/background/pipeline_tasks/test_instances/test_cloud_provisioning.py index ff76033e0a..e34a758456 100644 --- a/src/tests/_internal/server/background/pipeline_tasks/test_instances/test_cloud_provisioning.py +++ b/src/tests/_internal/server/background/pipeline_tasks/test_instances/test_cloud_provisioning.py @@ -261,6 +261,7 @@ async def test_terminates_cluster_instances_if_master_not_created( with patch("dstack._internal.server.services.backends.get_project_backends") as m: m.return_value = [] for instance in sorted(instances, key=lambda i: (i.instance_num, i.created_at)): + await session.refresh(instance) await process_instance(session, worker, instance) termination_reasons = defaultdict(int) From 5f8980c22192a1bdb6c97fe76106a74f5c403ee7 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Tue, 3 Mar 2026 14:58:50 +0500 Subject: [PATCH 20/51] Fix sibling_update_rows --- .../pipeline_tasks/instances/__init__.py | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git 
a/src/dstack/_internal/server/background/pipeline_tasks/instances/__init__.py b/src/dstack/_internal/server/background/pipeline_tasks/instances/__init__.py index 392bf97fc7..27fbf6e7d3 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/instances/__init__.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/instances/__init__.py @@ -433,19 +433,11 @@ async def _apply_process_result(item: InstancePipelineItem, result: ProcessResul log_lock_token_changed_after_processing(logger, item) return - for sibling_update_row in result.sibling_update_rows: - sibling_id = sibling_update_row.get("id") - if sibling_id is None: - continue - sibling_values = { - key: value for key, value in sibling_update_row.items() if key != "id" - } - if sibling_values: - await session.execute( - update(InstanceModel) - .where(InstanceModel.id == sibling_id) - .values(**sibling_values) - ) + if result.sibling_update_rows: + await session.execute( + update(InstanceModel).execution_options(synchronize_session=False), + result.sibling_update_rows, + ) if result.schedule_pg_deletion_fleet_id is not None: await schedule_fleet_placement_groups_deletion( From bd9342b506e03e325361b04b6e0c67e557df737d Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Tue, 3 Mar 2026 15:05:16 +0500 Subject: [PATCH 21/51] Drop redundant synchronize_session=False --- src/dstack/_internal/server/background/pipeline_tasks/fleets.py | 2 +- .../server/background/pipeline_tasks/instances/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/fleets.py b/src/dstack/_internal/server/background/pipeline_tasks/fleets.py index 6d6295de5e..a18e817a4a 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/fleets.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/fleets.py @@ -304,7 +304,7 @@ async def process(self, item: PipelineItem): ) if instance_update_rows: await session.execute( - 
update(InstanceModel).execution_options(synchronize_session=False), + update(InstanceModel), instance_update_rows, ) if result.new_instances_count > 0: diff --git a/src/dstack/_internal/server/background/pipeline_tasks/instances/__init__.py b/src/dstack/_internal/server/background/pipeline_tasks/instances/__init__.py index 27fbf6e7d3..546226180b 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/instances/__init__.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/instances/__init__.py @@ -435,7 +435,7 @@ async def _apply_process_result(item: InstancePipelineItem, result: ProcessResul if result.sibling_update_rows: await session.execute( - update(InstanceModel).execution_options(synchronize_session=False), + update(InstanceModel), result.sibling_update_rows, ) From 28fa00022854f2f6da87d5400667dfe5a9134805 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Tue, 3 Mar 2026 15:12:52 +0500 Subject: [PATCH 22/51] Add ProcessContext --- .../pipeline_tasks/instances/__init__.py | 94 +++++++++---------- 1 file changed, 44 insertions(+), 50 deletions(-) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/instances/__init__.py b/src/dstack/_internal/server/background/pipeline_tasks/instances/__init__.py index 546226180b..1a0dcf9637 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/instances/__init__.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/instances/__init__.py @@ -228,25 +228,35 @@ def __init__( @sentry_utils.instrument_named_task("pipeline_tasks.InstanceWorker.process") async def process(self, item: InstancePipelineItem): - result: Optional[ProcessResult] = None + process_context: Optional[_ProcessContext] = None if item.status == InstanceStatus.PENDING: - result = await _process_pending_item(item) + process_context = await _process_pending_item(item) elif item.status == InstanceStatus.PROVISIONING: - result = await _process_provisioning_item(item) + process_context = await 
_process_provisioning_item(item) elif item.status == InstanceStatus.IDLE: - result = await _process_idle_item(item) + process_context = await _process_idle_item(item) elif item.status == InstanceStatus.BUSY: - result = await _process_busy_item(item) + process_context = await _process_busy_item(item) elif item.status == InstanceStatus.TERMINATING: - result = await _process_terminating_item(item) - if result is None: + process_context = await _process_terminating_item(item) + if process_context is None: return - set_processed_update_map_fields(result.instance_update_map) - set_unlock_update_map_fields(result.instance_update_map) - await _apply_process_result(item=item, result=result) + set_processed_update_map_fields(process_context.result.instance_update_map) + set_unlock_update_map_fields(process_context.result.instance_update_map) + await _apply_process_result( + item=item, + instance_model=process_context.instance_model, + result=process_context.result, + ) + + +@dataclass +class _ProcessContext: + instance_model: InstanceModel + result: ProcessResult -async def _process_pending_item(item: InstancePipelineItem) -> Optional[ProcessResult]: +async def _process_pending_item(item: InstancePipelineItem) -> Optional[_ProcessContext]: async with get_session_ctx() as session: instance_model = await _refetch_locked_instance_for_pending_or_terminating( session=session, @@ -256,20 +266,23 @@ async def _process_pending_item(item: InstancePipelineItem) -> Optional[ProcessR log_lock_token_mismatch(logger, item) return None if is_ssh_instance(instance_model): - return await add_ssh_instance(instance_model) - return await create_cloud_instance(instance_model) + result = await add_ssh_instance(instance_model) + else: + result = await create_cloud_instance(instance_model) + return _ProcessContext(instance_model=instance_model, result=result) -async def _process_provisioning_item(item: InstancePipelineItem) -> Optional[ProcessResult]: +async def _process_provisioning_item(item: 
InstancePipelineItem) -> Optional[_ProcessContext]: async with get_session_ctx() as session: instance_model = await _refetch_locked_instance_for_check(session=session, item=item) if instance_model is None: log_lock_token_mismatch(logger, item) return None - return await check_instance(instance_model) + result = await check_instance(instance_model) + return _ProcessContext(instance_model=instance_model, result=result) -async def _process_idle_item(item: InstancePipelineItem) -> Optional[ProcessResult]: +async def _process_idle_item(item: InstancePipelineItem) -> Optional[_ProcessContext]: async with get_session_ctx() as session: instance_model = await _refetch_locked_instance_for_idle(session=session, item=item) if instance_model is None: @@ -277,20 +290,22 @@ async def _process_idle_item(item: InstancePipelineItem) -> Optional[ProcessResu return None idle_result = process_idle_timeout(instance_model) if idle_result is not None: - return idle_result - return await check_instance(instance_model) + return _ProcessContext(instance_model=instance_model, result=idle_result) + result = await check_instance(instance_model) + return _ProcessContext(instance_model=instance_model, result=result) -async def _process_busy_item(item: InstancePipelineItem) -> Optional[ProcessResult]: +async def _process_busy_item(item: InstancePipelineItem) -> Optional[_ProcessContext]: async with get_session_ctx() as session: instance_model = await _refetch_locked_instance_for_check(session=session, item=item) if instance_model is None: log_lock_token_mismatch(logger, item) return None - return await check_instance(instance_model) + result = await check_instance(instance_model) + return _ProcessContext(instance_model=instance_model, result=result) -async def _process_terminating_item(item: InstancePipelineItem) -> Optional[ProcessResult]: +async def _process_terminating_item(item: InstancePipelineItem) -> Optional[_ProcessContext]: async with get_session_ctx() as session: instance_model = await 
_refetch_locked_instance_for_pending_or_terminating( session=session, @@ -299,7 +314,8 @@ async def _process_terminating_item(item: InstancePipelineItem) -> Optional[Proc if instance_model is None: log_lock_token_mismatch(logger, item) return None - return await terminate_instance(instance_model) + result = await terminate_instance(instance_model) + return _ProcessContext(instance_model=instance_model, result=result) async def _refetch_locked_instance_for_pending_or_terminating( @@ -376,33 +392,12 @@ async def _refetch_locked_instance_for_check( return res.unique().scalar_one_or_none() -async def _apply_process_result(item: InstancePipelineItem, result: ProcessResult) -> None: +async def _apply_process_result( + item: InstancePipelineItem, + instance_model: InstanceModel, + result: ProcessResult, +) -> None: async with get_session_ctx() as session: - res = await session.execute( - select(InstanceModel) - .where( - InstanceModel.id == item.id, - InstanceModel.lock_token == item.lock_token, - ) - .options( - load_only( - InstanceModel.id, - InstanceModel.project_id, - InstanceModel.name, - InstanceModel.status, - InstanceModel.health, - InstanceModel.unreachable, - InstanceModel.termination_reason, - InstanceModel.termination_reason_message, - InstanceModel.lock_token, - ) - ) - ) - instance_model = res.scalar_one_or_none() - if instance_model is None: - log_lock_token_changed_after_processing(logger, item) - return - if result.health_check_create is not None: session.add(InstanceHealthCheckModel(**result.health_check_create)) if result.placement_group_creates: @@ -424,13 +419,12 @@ async def _apply_process_result(item: InstancePipelineItem, result: ProcessResul InstanceModel.lock_token == item.lock_token, ) .values(**result.instance_update_map) - .execution_options(synchronize_session=False) .returning(InstanceModel.id) ) updated_ids = list(res.scalars().all()) if len(updated_ids) == 0: - await session.rollback() log_lock_token_changed_after_processing(logger, item) 
+ await session.rollback() return if result.sibling_update_rows: From 899c18c0184feedb4b70a50d902caf21567e739c Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Tue, 3 Mar 2026 15:25:50 +0500 Subject: [PATCH 23/51] Simplify placement groups code --- .../pipeline_tasks/instances/__init__.py | 16 ++++++++-------- .../instances/cloud_provisioning.py | 13 +++++-------- .../pipeline_tasks/instances/common.py | 15 +++------------ 3 files changed, 16 insertions(+), 28 deletions(-) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/instances/__init__.py b/src/dstack/_internal/server/background/pipeline_tasks/instances/__init__.py index 1a0dcf9637..cba8f68df1 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/instances/__init__.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/instances/__init__.py @@ -45,7 +45,6 @@ InstanceHealthCheckModel, InstanceModel, JobModel, - PlacementGroupModel, ProjectModel, ) from dstack._internal.server.services import events @@ -400,12 +399,9 @@ async def _apply_process_result( async with get_session_ctx() as session: if result.health_check_create is not None: session.add(InstanceHealthCheckModel(**result.health_check_create)) - if result.placement_group_creates: - session.add_all( - PlacementGroupModel(**placement_group_create) - for placement_group_create in result.placement_group_creates - ) - if result.health_check_create is not None or result.placement_group_creates: + if result.new_placement_group_models: + session.add_all(result.new_placement_group_models) + if result.health_check_create is not None or result.new_placement_group_models: await session.flush() now = get_current_datetime() @@ -437,7 +433,11 @@ async def _apply_process_result( await schedule_fleet_placement_groups_deletion( session=session, fleet_id=result.schedule_pg_deletion_fleet_id, - except_placement_group_ids=result.schedule_pg_deletion_except_ids, + except_placement_group_ids=( + () + if 
result.schedule_pg_deletion_except_id is None + else (result.schedule_pg_deletion_except_id,) + ), ) emit_instance_status_change_event( diff --git a/src/dstack/_internal/server/background/pipeline_tasks/instances/cloud_provisioning.py b/src/dstack/_internal/server/background/pipeline_tasks/instances/cloud_provisioning.py index 4a5a658022..36bc2acc29 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/instances/cloud_provisioning.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/instances/cloud_provisioning.py @@ -32,7 +32,6 @@ from dstack._internal.server import settings as server_settings from dstack._internal.server.background.pipeline_tasks.base import NOW_PLACEHOLDER from dstack._internal.server.background.pipeline_tasks.instances.common import ( - PlacementGroupCreate, ProcessResult, SiblingInstanceUpdateMap, append_sibling_status_event, @@ -61,7 +60,7 @@ class _PlacementGroupState: id: uuid.UUID placement_group: PlacementGroup - create_payload: Optional[PlacementGroupCreate] = None + new_model: Optional[PlacementGroupModel] = None async def create_cloud_instance(instance_model: InstanceModel) -> ProcessResult: @@ -144,14 +143,12 @@ async def create_cloud_instance(instance_model: InstanceModel) -> ProcessResult: if selected_placement_group_state is None: continue if ( - selected_placement_group_state.create_payload is not None + selected_placement_group_state.new_model is not None and selected_placement_group_state.id not in seen_placement_group_ids ): seen_placement_group_ids.add(selected_placement_group_state.id) placement_group_states.append(selected_placement_group_state) - result.placement_group_creates.append( - selected_placement_group_state.create_payload - ) + result.new_placement_group_models.append(selected_placement_group_state.new_model) logger.debug( "Trying %s in %s/%s for $%0.4f per hour", @@ -205,7 +202,7 @@ async def create_cloud_instance(instance_model: InstanceModel) -> ProcessResult: if instance_model.fleet_id 
is not None and instance_model.id == master_instance_model.id: result.schedule_pg_deletion_fleet_id = instance_model.fleet_id if selected_placement_group_state is not None: - result.schedule_pg_deletion_except_ids = (selected_placement_group_state.id,) + result.schedule_pg_deletion_except_id = selected_placement_group_state.id return result set_status_update( @@ -363,7 +360,7 @@ async def _find_or_create_suitable_placement_group_state( return _PlacementGroupState( id=placement_group_id, placement_group=placement_group, - create_payload=PlacementGroupCreate( + new_model=PlacementGroupModel( id=placement_group_id, name=placement_group.name, project_id=instance_model.project_id, diff --git a/src/dstack/_internal/server/background/pipeline_tasks/instances/common.py b/src/dstack/_internal/server/background/pipeline_tasks/instances/common.py index 9c569658b1..a9475cb9f6 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/instances/common.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/instances/common.py @@ -21,7 +21,7 @@ UpdateMapDateTime, ) from dstack._internal.server.background.scheduled_tasks.common import get_provisioning_timeout -from dstack._internal.server.models import FleetModel, InstanceModel +from dstack._internal.server.models import FleetModel, InstanceModel, PlacementGroupModel from dstack._internal.server.services.fleets import fleet_model_to_fleet, is_cloud_cluster from dstack._internal.server.services.instances import ( get_instance_provisioning_data, @@ -77,15 +77,6 @@ class HealthCheckCreate(TypedDict): response: str -class PlacementGroupCreate(TypedDict): - id: uuid.UUID - name: str - project_id: uuid.UUID - fleet_id: uuid.UUID - configuration: str - provisioning_data: str - - @dataclass class SiblingDeferredEvent: message: str @@ -100,9 +91,9 @@ class ProcessResult: sibling_update_rows: list[SiblingInstanceUpdateMap] = field(default_factory=list) sibling_deferred_events: list[SiblingDeferredEvent] = 
field(default_factory=list) health_check_create: Optional[HealthCheckCreate] = None - placement_group_creates: list[PlacementGroupCreate] = field(default_factory=list) + new_placement_group_models: list[PlacementGroupModel] = field(default_factory=list) schedule_pg_deletion_fleet_id: Optional[uuid.UUID] = None - schedule_pg_deletion_except_ids: tuple[uuid.UUID, ...] = () + schedule_pg_deletion_except_id: Optional[uuid.UUID] = None def can_terminate_fleet_instances_on_idle_duration(fleet_model: FleetModel) -> bool: From 316fd8aa0b02951eae8753327a89509a6c78552c Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Tue, 3 Mar 2026 15:43:19 +0500 Subject: [PATCH 24/51] Drop _PlacementGroupState --- .../instances/cloud_provisioning.py | 195 ++++++------------ 1 file changed, 65 insertions(+), 130 deletions(-) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/instances/cloud_provisioning.py b/src/dstack/_internal/server/background/pipeline_tasks/instances/cloud_provisioning.py index 36bc2acc29..a119d1eb66 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/instances/cloud_provisioning.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/instances/cloud_provisioning.py @@ -1,10 +1,7 @@ import uuid -from dataclasses import dataclass from typing import Optional, cast from pydantic import ValidationError -from sqlalchemy import select -from sqlalchemy.orm import joinedload from dstack._internal.core.backends.base.compute import ( ComputeWithCreateInstanceSupport, @@ -24,11 +21,7 @@ InstanceStatus, InstanceTerminationReason, ) -from dstack._internal.core.models.placement import ( - PlacementGroup, - PlacementGroupConfiguration, - PlacementStrategy, -) +from dstack._internal.core.models.placement import PlacementGroupConfiguration, PlacementStrategy from dstack._internal.server import settings as server_settings from dstack._internal.server.background.pipeline_tasks.base import NOW_PLACEHOLDER from 
dstack._internal.server.background.pipeline_tasks.instances.common import ( @@ -49,20 +42,18 @@ get_instance_requirements, ) from dstack._internal.server.services.logging import fmt -from dstack._internal.server.services.placement import placement_group_model_to_placement_group +from dstack._internal.server.services.placement import ( + get_fleet_placement_group_models, + get_placement_group_model_for_instance, + placement_group_model_to_placement_group, + placement_group_model_to_placement_group_optional, +) from dstack._internal.utils.common import get_or_error, run_async from dstack._internal.utils.logging import get_logger logger = get_logger(__name__) -@dataclass -class _PlacementGroupState: - id: uuid.UUID - placement_group: PlacementGroup - new_model: Optional[PlacementGroupModel] = None - - async def create_cloud_instance(instance_model: InstanceModel) -> ProcessResult: result = ProcessResult() master_instance_model = get_fleet_master_instance(instance_model) @@ -93,9 +84,15 @@ async def create_cloud_instance(instance_model: InstanceModel) -> ProcessResult: ) return result - placement_group_states = await _get_fleet_placement_group_states(instance_model.fleet_id) - placement_group_state = _get_placement_group_state_for_instance( - placement_group_states=placement_group_states, + # The placement group is determined when provisioning the master instance + # and used for all other instances in the fleet. 
+ async with get_session_ctx() as session: + placement_group_models = await get_fleet_placement_group_models( + session=session, + fleet_id=instance_model.fleet_id, + ) + placement_group_model = get_placement_group_model_for_instance( + placement_group_models=placement_group_models, instance_model=instance_model, master_instance_model=master_instance_model, ) @@ -104,74 +101,67 @@ async def create_cloud_instance(instance_model: InstanceModel) -> ProcessResult: profile=profile, requirements=requirements, fleet_model=instance_model.fleet, - placement_group=( - placement_group_state.placement_group if placement_group_state is not None else None - ), + placement_group=placement_group_model_to_placement_group_optional(placement_group_model), blocks="auto" if instance_model.total_blocks is None else instance_model.total_blocks, exclude_not_available=True, ) - seen_placement_group_ids = {state.id for state in placement_group_states} for backend, instance_offer in offers[: server_settings.MAX_OFFERS_TRIED]: if instance_offer.backend not in BACKENDS_WITH_CREATE_INSTANCE_SUPPORT: continue compute = backend.compute() assert isinstance(compute, ComputeWithCreateInstanceSupport) - selected_offer = get_instance_offer_for_instance( + instance_offer = get_instance_offer_for_instance( instance_offer=instance_offer, instance_model=instance_model, master_instance_model=master_instance_model, ) - selected_placement_group_state = placement_group_state if ( instance_model.fleet is not None and is_cloud_cluster(instance_model.fleet) and instance_model.id == master_instance_model.id - and selected_offer.backend in BACKENDS_WITH_PLACEMENT_GROUPS_SUPPORT + and instance_offer.backend in BACKENDS_WITH_PLACEMENT_GROUPS_SUPPORT and isinstance(compute, ComputeWithPlacementGroupSupport) and ( - compute.are_placement_groups_compatible_with_reservations(selected_offer.backend) + compute.are_placement_groups_compatible_with_reservations(instance_offer.backend) or instance_configuration.reservation 
is None ) ): - selected_placement_group_state = await _find_or_create_suitable_placement_group_state( + ( + placement_group_model, + created_placement_group_model, + ) = await _find_or_create_suitable_placement_group_model( instance_model=instance_model, - placement_group_states=placement_group_states, - instance_offer=selected_offer, + placement_group_models=placement_group_models, + instance_offer=instance_offer, compute=compute, ) - if selected_placement_group_state is None: + if placement_group_model is None: continue - if ( - selected_placement_group_state.new_model is not None - and selected_placement_group_state.id not in seen_placement_group_ids - ): - seen_placement_group_ids.add(selected_placement_group_state.id) - placement_group_states.append(selected_placement_group_state) - result.new_placement_group_models.append(selected_placement_group_state.new_model) + if created_placement_group_model: + placement_group_models.append(placement_group_model) + result.new_placement_group_models.append(placement_group_model) logger.debug( "Trying %s in %s/%s for $%0.4f per hour", - selected_offer.instance.name, - selected_offer.backend.value, - selected_offer.region, - selected_offer.price, + instance_offer.instance.name, + instance_offer.backend.value, + instance_offer.region, + instance_offer.price, ) try: job_provisioning_data = await run_async( compute.create_instance, - selected_offer, + instance_offer, instance_configuration, - selected_placement_group_state.placement_group - if selected_placement_group_state is not None - else None, + placement_group_model_to_placement_group_optional(placement_group_model), ) except BackendError as exc: logger.warning( "%s launch in %s/%s failed: %s", - selected_offer.instance.name, - selected_offer.backend.value, - selected_offer.region, + instance_offer.instance.name, + instance_offer.backend.value, + instance_offer.region, repr(exc), extra={"instance_name": instance_model.name}, ) @@ -179,9 +169,9 @@ async def 
create_cloud_instance(instance_model: InstanceModel) -> ProcessResult: except Exception: logger.exception( "Got exception when launching %s in %s/%s", - selected_offer.instance.name, - selected_offer.backend.value, - selected_offer.region, + instance_offer.instance.name, + instance_offer.backend.value, + instance_offer.region, ) continue @@ -191,18 +181,18 @@ async def create_cloud_instance(instance_model: InstanceModel) -> ProcessResult: new_status=InstanceStatus.PROVISIONING, ) result.instance_update_map["backend"] = backend.TYPE - result.instance_update_map["region"] = selected_offer.region - result.instance_update_map["price"] = selected_offer.price + result.instance_update_map["region"] = instance_offer.region + result.instance_update_map["price"] = instance_offer.price result.instance_update_map["instance_configuration"] = instance_configuration.json() result.instance_update_map["job_provisioning_data"] = job_provisioning_data.json() - result.instance_update_map["offer"] = selected_offer.json() - result.instance_update_map["total_blocks"] = selected_offer.total_blocks + result.instance_update_map["offer"] = instance_offer.json() + result.instance_update_map["total_blocks"] = instance_offer.total_blocks result.instance_update_map["started_at"] = NOW_PLACEHOLDER if instance_model.fleet_id is not None and instance_model.id == master_instance_model.id: result.schedule_pg_deletion_fleet_id = instance_model.fleet_id - if selected_placement_group_state is not None: - result.schedule_pg_deletion_except_id = selected_placement_group_state.id + if placement_group_model is not None: + result.schedule_pg_deletion_except_id = placement_group_model.id return result set_status_update( @@ -244,65 +234,18 @@ async def create_cloud_instance(instance_model: InstanceModel) -> ProcessResult: return result -async def _get_fleet_placement_group_states( - fleet_id: Optional[uuid.UUID], -) -> list[_PlacementGroupState]: - if fleet_id is None: - return [] - async with 
get_session_ctx() as session: - res = await session.execute( - select(PlacementGroupModel) - .where( - PlacementGroupModel.fleet_id == fleet_id, - PlacementGroupModel.deleted == False, - PlacementGroupModel.fleet_deleted == False, - ) - .options(joinedload(PlacementGroupModel.project)) - ) - placement_group_models = list(res.unique().scalars().all()) - return [ - _PlacementGroupState( - id=placement_group_model.id, - placement_group=placement_group_model_to_placement_group(placement_group_model), - ) - for placement_group_model in placement_group_models - ] - - -def _get_placement_group_state_for_instance( - placement_group_states: list[_PlacementGroupState], - instance_model: InstanceModel, - master_instance_model: InstanceModel, -) -> Optional[_PlacementGroupState]: - if instance_model.id == master_instance_model.id: - return None - if len(placement_group_states) > 1: - logger.error( - ( - "Expected 0 or 1 placement groups associated with fleet %s, found %s." - " An incorrect placement group might have been selected for instance %s" - ), - instance_model.fleet_id, - len(placement_group_states), - instance_model.name, - ) - if placement_group_states: - return placement_group_states[0] - return None - - -async def _find_or_create_suitable_placement_group_state( +async def _find_or_create_suitable_placement_group_model( instance_model: InstanceModel, - placement_group_states: list[_PlacementGroupState], + placement_group_models: list[PlacementGroupModel], instance_offer: InstanceOfferWithAvailability, compute: ComputeWithPlacementGroupSupport, -) -> Optional[_PlacementGroupState]: - for placement_group_state in placement_group_states: +) -> tuple[Optional[PlacementGroupModel], bool]: + for placement_group_model in placement_group_models: if compute.is_suitable_placement_group( - placement_group_state.placement_group, + placement_group_model_to_placement_group(placement_group_model), instance_offer, ): - return placement_group_state + return placement_group_model, 
False assert instance_model.fleet is not None placement_group_id = uuid.uuid4() @@ -310,16 +253,18 @@ async def _find_or_create_suitable_placement_group_state( project_name=instance_model.project.name, fleet_name=instance_model.fleet.name, ) - placement_group = PlacementGroup( + placement_group_model = PlacementGroupModel( + id=placement_group_id, name=placement_group_name, - project_name=instance_model.project.name, + project=instance_model.project, + fleet=get_or_error(instance_model.fleet), configuration=PlacementGroupConfiguration( backend=instance_offer.backend, region=instance_offer.region, placement_strategy=PlacementStrategy.CLUSTER, - ), - provisioning_data=None, + ).json(), ) + placement_group = placement_group_model_to_placement_group(placement_group_model) logger.debug( "Creating placement group %s in %s/%s", placement_group.name, @@ -337,7 +282,7 @@ async def _find_or_create_suitable_placement_group_state( "Skipping offer %s because placement group not supported", instance_offer.instance.name, ) - return None + return None, False except BackendError as exc: logger.warning( "Failed to create placement group %s in %s/%s: %r", @@ -346,7 +291,7 @@ async def _find_or_create_suitable_placement_group_state( placement_group.configuration.region, exc, ) - return None + return None, False except Exception: logger.exception( "Got exception when creating placement group %s in %s/%s", @@ -354,18 +299,8 @@ async def _find_or_create_suitable_placement_group_state( placement_group.configuration.backend.value, placement_group.configuration.region, ) - return None + return None, False placement_group.provisioning_data = provisioning_data - return _PlacementGroupState( - id=placement_group_id, - placement_group=placement_group, - new_model=PlacementGroupModel( - id=placement_group_id, - name=placement_group.name, - project_id=instance_model.project_id, - fleet_id=get_or_error(instance_model.fleet_id), - configuration=placement_group.configuration.json(), - 
provisioning_data=provisioning_data.json(), - ), - ) + placement_group_model.provisioning_data = provisioning_data.json() + return placement_group_model, True From 2d0090b481837edc3efff747c2a1659dfd58dd51 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Tue, 3 Mar 2026 15:54:20 +0500 Subject: [PATCH 25/51] Restore comments --- .../pipeline_tasks/instances/check.py | 24 +++++++++++++++++++ .../instances/cloud_provisioning.py | 4 ++++ .../pipeline_tasks/instances/common.py | 2 ++ .../pipeline_tasks/instances/ssh_deploy.py | 5 ++++ 4 files changed, 35 insertions(+) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/instances/check.py b/src/dstack/_internal/server/background/pipeline_tasks/instances/check.py index 6cf75e827a..7ddddb31c2 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/instances/check.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/instances/check.py @@ -61,6 +61,11 @@ def process_idle_timeout(instance_model: InstanceModel) -> Optional[ProcessResul and not instance_model.jobs ): return None + # Do not terminate instances on idle duration if fleet is already at `nodes.min`. + # This is an optimization to avoid terminate-create loop. + # There may be race conditions since we don't take the fleet lock. + # That's ok: in the worst case we go below `nodes.min`, but + # the fleet consolidation logic will provision new nodes. 
if instance_model.fleet is not None and not can_terminate_fleet_instances_on_idle_duration( instance_model.fleet ): @@ -92,6 +97,8 @@ async def check_instance(instance_model: InstanceModel) -> ProcessResult: and instance_model.jobs and all(job.status.is_finished() for job in instance_model.jobs) ): + # A busy instance could have no active jobs due to this bug: + # https://github.com/dstackai/dstack/issues/2068 set_status_update( update_map=result.instance_update_map, instance_model=instance_model, @@ -143,6 +150,7 @@ async def check_instance(instance_model: InstanceModel) -> ProcessResult: ) if instance_check.has_health_checks(): + # ensured by has_health_checks() assert instance_check.health_response is not None result.health_check_create = HealthCheckCreate( instance_id=instance_model.id, @@ -232,6 +240,7 @@ async def _run_instance_check( instance=instance_model, check_instance_health=check_instance_health, ) + # May return False if fails to establish ssh connection. if instance_check is False: return InstanceCheck(reachable=False, message="SSH or tunnel error") return instance_check @@ -244,6 +253,7 @@ def _get_health_status_for_instance_check( ) -> HealthStatus: if instance_check.reachable and check_instance_health: return instance_check.get_health_status() + # Keep previous health status. return instance_model.health @@ -367,6 +377,7 @@ def _check_instance_inner( except Exception as exc: logger.exception("%s: error removing dangling tasks: %s", fmt(instance), exc) + # There should be no shim API calls after this function call since it can request shim restart. 
_maybe_install_components(instance, shim_client) return runner_client.healthcheck_response_to_instance_check( healthcheck_response, @@ -404,6 +415,11 @@ def _maybe_install_components( else: logger.debug("Instance %s: no shim info", instance_model.name) + # old shim without `dstack-shim` component and `/api/shutdown` support + # or the same version is already running + # or we just requested installation of at least one component + # or at least one component is already being installed + # or at least one shim task won't survive restart running_shim_version = shim_client.get_version_string() if ( installed_shim_version is None @@ -430,6 +446,10 @@ def _maybe_install_runner( shim_client: runner_client.ShimClient, runner_info: ComponentInfo, ) -> bool: + # For developers: + # * To install the latest dev build for the current branch from the CI, + # set DSTACK_USE_LATEST_FROM_BRANCH=1. + # * To provide your own build, set DSTACK_RUNNER_VERSION_URL and DSTACK_RUNNER_DOWNLOAD_URL. expected_version = get_dstack_runner_version() if expected_version is None: logger.debug("Cannot determine the expected runner version") @@ -473,6 +493,10 @@ def _maybe_install_shim( shim_client: runner_client.ShimClient, shim_info: ComponentInfo, ) -> bool: + # For developers: + # * To install the latest dev build for the current branch from the CI, + # set DSTACK_USE_LATEST_FROM_BRANCH=1. + # * To provide your own build, set DSTACK_SHIM_VERSION_URL and DSTACK_SHIM_DOWNLOAD_URL. 
expected_version = get_dstack_shim_version() if expected_version is None: logger.debug("Cannot determine the expected shim version") diff --git a/src/dstack/_internal/server/background/pipeline_tasks/instances/cloud_provisioning.py b/src/dstack/_internal/server/background/pipeline_tasks/instances/cloud_provisioning.py index a119d1eb66..9e58a6fc2e 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/instances/cloud_provisioning.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/instances/cloud_provisioning.py @@ -106,6 +106,7 @@ async def create_cloud_instance(instance_model: InstanceModel) -> ProcessResult: exclude_not_available=True, ) + # Limit number of offers tried to prevent long-running processing in case all offers fail. for backend, instance_offer in offers[: server_settings.MAX_OFFERS_TRIED]: if instance_offer.backend not in BACKENDS_WITH_CREATE_INSTANCE_SUPPORT: continue @@ -190,6 +191,7 @@ async def create_cloud_instance(instance_model: InstanceModel) -> ProcessResult: result.instance_update_map["started_at"] = NOW_PLACEHOLDER if instance_model.fleet_id is not None and instance_model.id == master_instance_model.id: + # Clean up placement groups that did not end up being used. result.schedule_pg_deletion_fleet_id = instance_model.fleet_id if placement_group_model is not None: result.schedule_pg_deletion_except_id = placement_group_model.id @@ -207,6 +209,8 @@ async def create_cloud_instance(instance_model: InstanceModel) -> ProcessResult: and instance_model.id == master_instance_model.id and is_cloud_cluster(instance_model.fleet) ): + # Do not attempt to deploy other instances, as they won't determine the correct cluster + # backend, region, and placement group without a successfully deployed master instance. 
for sibling_instance_model in instance_model.fleet.instances: if sibling_instance_model.id == instance_model.id: continue diff --git a/src/dstack/_internal/server/background/pipeline_tasks/instances/common.py b/src/dstack/_internal/server/background/pipeline_tasks/instances/common.py index a9475cb9f6..123fd29049 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/instances/common.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/instances/common.py @@ -122,6 +122,8 @@ def need_to_wait_fleet_provisioning( instance_model: InstanceModel, master_instance_model: InstanceModel, ) -> bool: + # Cluster cloud instances should wait for the first fleet instance to be provisioned + # so that they are provisioned in the same backend/region. if instance_model.fleet is None: return False if ( diff --git a/src/dstack/_internal/server/background/pipeline_tasks/instances/ssh_deploy.py b/src/dstack/_internal/server/background/pipeline_tasks/instances/ssh_deploy.py index 312f7977ef..b4e3e1122a 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/instances/ssh_deploy.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/instances/ssh_deploy.py @@ -259,10 +259,12 @@ def _deploy_instance( arch = detect_cpu_arch(client) logger.debug("%s: CPU arch is %s", remote_details.host, arch) + # Execute pre start commands shim_pre_start_commands = get_shim_pre_start_commands(arch=arch) run_pre_start_commands(client, shim_pre_start_commands, authorized_keys) logger.debug("The script for installing dstack has been executed") + # Upload envs shim_envs = get_shim_env(arch=arch) try: fleet_configuration_envs = remote_details.env.as_dict() @@ -275,9 +277,11 @@ def _deploy_instance( upload_envs(client, dstack_working_dir, shim_envs) logger.debug("The dstack-shim environment variables have been installed") + # Ensure we have fresh versions of host info.json and dstack-runner remove_host_info_if_exists(client, dstack_working_dir) 
remove_dstack_runner_if_exists(client, dstack_runner_binary_path) + # Run dstack-shim as a systemd service run_shim_as_systemd_service( client=client, binary_path=dstack_shim_binary_path, @@ -285,6 +289,7 @@ def _deploy_instance( dev=settings.DSTACK_VERSION is None, ) + # Get host info host_info = get_host_info(client, dstack_working_dir) logger.debug("Received a host_info %s", host_info) From 24c74cca9f71894ced644144540821e4637e1ae8 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Tue, 3 Mar 2026 15:58:56 +0500 Subject: [PATCH 26/51] Fix result.sibling_update_rows append --- .../instances/cloud_provisioning.py | 5 ++--- .../pipeline_tasks/instances/common.py | 21 +++++++++++++------ 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/instances/cloud_provisioning.py b/src/dstack/_internal/server/background/pipeline_tasks/instances/cloud_provisioning.py index 9e58a6fc2e..7ac27b4c59 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/instances/cloud_provisioning.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/instances/cloud_provisioning.py @@ -215,13 +215,12 @@ async def create_cloud_instance(instance_model: InstanceModel) -> ProcessResult: if sibling_instance_model.id == instance_model.id: continue sibling_update_map = SiblingInstanceUpdateMap(id=sibling_instance_model.id) - set_status_update( + if set_status_update( update_map=sibling_update_map, instance_model=sibling_instance_model, new_status=InstanceStatus.TERMINATED, termination_reason=InstanceTerminationReason.MASTER_FAILED, - ) - if len(sibling_update_map) > 1: + ): result.sibling_update_rows.append(sibling_update_map) append_sibling_status_event( deferred_events=result.sibling_deferred_events, diff --git a/src/dstack/_internal/server/background/pipeline_tasks/instances/common.py b/src/dstack/_internal/server/background/pipeline_tasks/instances/common.py index 123fd29049..f2fcd1f366 100644 --- 
a/src/dstack/_internal/server/background/pipeline_tasks/instances/common.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/instances/common.py @@ -189,18 +189,21 @@ def set_status_update( new_status: InstanceStatus, termination_reason: object = _UNSET, termination_reason_message: object = _UNSET, -) -> None: +) -> bool: old_status = instance_model.status + changed = False if old_status == new_status: if termination_reason is not _UNSET: update_map["termination_reason"] = cast( Optional[InstanceTerminationReason], termination_reason ) + changed = True if termination_reason_message is not _UNSET: update_map["termination_reason_message"] = cast( Optional[str], termination_reason_message ) - return + changed = True + return changed effective_termination_reason = instance_model.termination_reason if termination_reason is not _UNSET: @@ -208,33 +211,39 @@ def set_status_update( Optional[InstanceTerminationReason], termination_reason ) update_map["termination_reason"] = effective_termination_reason + changed = True effective_termination_reason_message = instance_model.termination_reason_message if termination_reason_message is not _UNSET: effective_termination_reason_message = cast(Optional[str], termination_reason_message) update_map["termination_reason_message"] = effective_termination_reason_message + changed = True update_map["status"] = new_status + changed = True + return changed def set_health_update( update_map: InstanceUpdateMap, instance_model: InstanceModel, health: HealthStatus, -) -> None: +) -> bool: if instance_model.health == health: - return + return False update_map["health"] = health + return True def set_unreachable_update( update_map: InstanceUpdateMap, instance_model: InstanceModel, unreachable: bool, -) -> None: +) -> bool: if not instance_model.status.is_available() or instance_model.unreachable == unreachable: - return + return False update_map["unreachable"] = unreachable + return True def append_sibling_status_event( From 
8bd898de8aff99cbf5a083f228fa95d537176638 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Tue, 3 Mar 2026 16:08:20 +0500 Subject: [PATCH 27/51] Fix unset typing --- .../pipeline_tasks/instances/check.py | 4 +-- .../instances/cloud_provisioning.py | 11 +++---- .../pipeline_tasks/instances/common.py | 32 +++++++------------ src/dstack/_internal/utils/common.py | 14 +++++++- 4 files changed, 31 insertions(+), 30 deletions(-) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/instances/check.py b/src/dstack/_internal/server/background/pipeline_tasks/instances/check.py index 7ddddb31c2..8ca135ab69 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/instances/check.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/instances/check.py @@ -1,6 +1,6 @@ import logging from datetime import timedelta -from typing import Dict, Optional +from typing import Optional import gpuhunt import requests @@ -348,7 +348,7 @@ async def _process_wait_for_instance_provisioning_data( @runner_ssh_tunnel(ports=[DSTACK_SHIM_HTTP_PORT], retries=1) def _check_instance_inner( - ports: Dict[int, int], + ports: dict[int, int], *, instance: InstanceModel, check_instance_health: bool = False, diff --git a/src/dstack/_internal/server/background/pipeline_tasks/instances/cloud_provisioning.py b/src/dstack/_internal/server/background/pipeline_tasks/instances/cloud_provisioning.py index 7ac27b4c59..cb2e9f9400 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/instances/cloud_provisioning.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/instances/cloud_provisioning.py @@ -1,5 +1,5 @@ import uuid -from typing import Optional, cast +from typing import Optional from pydantic import ValidationError @@ -226,12 +226,9 @@ async def create_cloud_instance(instance_model: InstanceModel) -> ProcessResult: deferred_events=result.sibling_deferred_events, instance_model=sibling_instance_model, new_status=InstanceStatus.TERMINATED, - 
termination_reason=cast( - Optional[InstanceTerminationReason], - sibling_update_map.get("termination_reason"), - ), - termination_reason_message=cast( - Optional[str], sibling_update_map.get("termination_reason_message") + termination_reason=sibling_update_map.get("termination_reason"), + termination_reason_message=sibling_update_map.get( + "termination_reason_message" ), ) return result diff --git a/src/dstack/_internal/server/background/pipeline_tasks/instances/common.py b/src/dstack/_internal/server/background/pipeline_tasks/instances/common.py index f2fcd1f366..65915c30e2 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/instances/common.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/instances/common.py @@ -2,7 +2,7 @@ import uuid from dataclasses import dataclass, field from datetime import timedelta -from typing import Optional, TypedDict, Union, cast +from typing import Optional, TypedDict, Union from paramiko.pkey import PKey @@ -28,7 +28,7 @@ get_instance_status_change_message, ) from dstack._internal.server.services.offers import get_instance_offer_with_restricted_az -from dstack._internal.utils.common import get_current_datetime +from dstack._internal.utils.common import UNSET, Unset, get_current_datetime from dstack._internal.utils.ssh import pkey_from_str TERMINATION_DEADLINE_OFFSET = timedelta(minutes=20) @@ -36,8 +36,6 @@ TERMINATION_RETRY_MAX_DURATION = timedelta(minutes=15) PROVISIONING_TIMEOUT_SECONDS = 10 * 60 # 10 minutes in seconds -_UNSET = object() - class InstanceUpdateMap(ItemUpdateMap, total=False): status: InstanceStatus @@ -187,35 +185,29 @@ def set_status_update( update_map: Union[InstanceUpdateMap, SiblingInstanceUpdateMap], instance_model: InstanceModel, new_status: InstanceStatus, - termination_reason: object = _UNSET, - termination_reason_message: object = _UNSET, + termination_reason: Union[Optional[InstanceTerminationReason], Unset] = UNSET, + termination_reason_message: Union[Optional[str], 
Unset] = UNSET, ) -> bool: old_status = instance_model.status changed = False if old_status == new_status: - if termination_reason is not _UNSET: - update_map["termination_reason"] = cast( - Optional[InstanceTerminationReason], termination_reason - ) + if not isinstance(termination_reason, Unset): + update_map["termination_reason"] = termination_reason changed = True - if termination_reason_message is not _UNSET: - update_map["termination_reason_message"] = cast( - Optional[str], termination_reason_message - ) + if not isinstance(termination_reason_message, Unset): + update_map["termination_reason_message"] = termination_reason_message changed = True return changed effective_termination_reason = instance_model.termination_reason - if termination_reason is not _UNSET: - effective_termination_reason = cast( - Optional[InstanceTerminationReason], termination_reason - ) + if not isinstance(termination_reason, Unset): + effective_termination_reason = termination_reason update_map["termination_reason"] = effective_termination_reason changed = True effective_termination_reason_message = instance_model.termination_reason_message - if termination_reason_message is not _UNSET: - effective_termination_reason_message = cast(Optional[str], termination_reason_message) + if not isinstance(termination_reason_message, Unset): + effective_termination_reason_message = termination_reason_message update_map["termination_reason_message"] = effective_termination_reason_message changed = True diff --git a/src/dstack/_internal/utils/common.py b/src/dstack/_internal/utils/common.py index ba139c6bfc..3653efc9ed 100644 --- a/src/dstack/_internal/utils/common.py +++ b/src/dstack/_internal/utils/common.py @@ -7,13 +7,25 @@ from datetime import datetime, timedelta, timezone from functools import partial from pathlib import Path -from typing import Any, Iterable, List, Optional, TypeVar +from typing import Any, Final, Iterable, List, Optional, TypeVar from urllib.parse import urlparse from 
typing_extensions import ParamSpec from dstack._internal.core.models.common import Duration + +class Unset: + pass + + +UNSET: Final = Unset() +""" +Use `UNSET` as kwargs default value to distinguish between +specified and non-specified `Optional` values. +""" + + P = ParamSpec("P") R = TypeVar("R") From 07e83e42d9b60446c4f620bed9455045cf3221ab Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Wed, 4 Mar 2026 10:48:03 +0500 Subject: [PATCH 28/51] Add migration --- ...0aa4_add_instancemodel_pipeline_columns.py | 47 +++++++++++++++++++ 1 file changed, 47 insertions(+) create mode 100644 src/dstack/_internal/server/migrations/versions/2026/03_04_0547_8e8647f20aa4_add_instancemodel_pipeline_columns.py diff --git a/src/dstack/_internal/server/migrations/versions/2026/03_04_0547_8e8647f20aa4_add_instancemodel_pipeline_columns.py b/src/dstack/_internal/server/migrations/versions/2026/03_04_0547_8e8647f20aa4_add_instancemodel_pipeline_columns.py new file mode 100644 index 0000000000..cc82d95de4 --- /dev/null +++ b/src/dstack/_internal/server/migrations/versions/2026/03_04_0547_8e8647f20aa4_add_instancemodel_pipeline_columns.py @@ -0,0 +1,47 @@ +"""Add InstanceModel pipeline columns + +Revision ID: 8e8647f20aa4 +Revises: 46150101edec +Create Date: 2026-03-04 05:47:39.307013+00:00 + +""" + +import sqlalchemy as sa +import sqlalchemy_utils +from alembic import op + +import dstack._internal.server.models + +# revision identifiers, used by Alembic. +revision = "8e8647f20aa4" +down_revision = "46150101edec" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table("instances", schema=None) as batch_op: + batch_op.add_column( + sa.Column( + "lock_expires_at", dstack._internal.server.models.NaiveDateTime(), nullable=True + ) + ) + batch_op.add_column( + sa.Column( + "lock_token", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=True + ) + ) + batch_op.add_column(sa.Column("lock_owner", sa.String(length=100), nullable=True)) + + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("instances", schema=None) as batch_op: + batch_op.drop_column("lock_owner") + batch_op.drop_column("lock_token") + batch_op.drop_column("lock_expires_at") + + # ### end Alembic commands ### From 1b99cfcc038f7587ad1cec6ac5f1be90de400df2 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Wed, 4 Mar 2026 10:54:30 +0500 Subject: [PATCH 29/51] Lock instances in fleet pipeline --- .../background/pipeline_tasks/fleets.py | 35 +++++++++---------- 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/fleets.py b/src/dstack/_internal/server/background/pipeline_tasks/fleets.py index a18e817a4a..af29f909f2 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/fleets.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/fleets.py @@ -225,15 +225,14 @@ async def process(self, item: PipelineItem): .where( InstanceModel.fleet_id == item.id, InstanceModel.deleted == False, - # TODO: Lock instance models in the DB - # or_( - # InstanceModel.lock_expires_at.is_(None), - # InstanceModel.lock_expires_at < get_current_datetime(), - # ), - # or_( - # InstanceModel.lock_owner.is_(None), - # InstanceModel.lock_owner == FleetPipeline.__name__, - # ), + or_( + InstanceModel.lock_expires_at.is_(None), + InstanceModel.lock_expires_at < get_current_datetime(), + ), + or_( + InstanceModel.lock_owner.is_(None), + InstanceModel.lock_owner == 
FleetPipeline.__name__, + ), ) .with_for_update(skip_locked=True, key_share=True) ) @@ -243,7 +242,6 @@ async def process(self, item: PipelineItem): "Failed to lock fleet %s instances. The fleet will be processed later.", item.id, ) - now = get_current_datetime() # Keep `lock_owner` so that `InstancePipeline` sees that the fleet is being locked # but unset `lock_expires_at` to process the item again ASAP (after `min_processing_interval`). # Unset `lock_token` so that heartbeater can no longer update the item. @@ -256,19 +254,20 @@ async def process(self, item: PipelineItem): .values( lock_expires_at=None, lock_token=None, - last_processed_at=now, + last_processed_at=get_current_datetime(), ) + .returning(FleetModel.id) ) - if res.rowcount == 0: # pyright: ignore[reportAttributeAccessIssue] + updated_ids = list(res.scalars().all()) + if len(updated_ids) == 0: log_lock_token_changed_on_reset(logger) return - # TODO: Lock instance models in the DB - # for instance_model in locked_instance_models: - # instance_model.lock_expires_at = item.lock_expires_at - # instance_model.lock_token = item.lock_token - # instance_model.lock_owner = FleetPipeline.__name__ - # await session.commit() + for instance_model in locked_instance_models: + instance_model.lock_expires_at = item.lock_expires_at + instance_model.lock_token = item.lock_token + instance_model.lock_owner = FleetPipeline.__name__ + await session.commit() result = await _process_fleet(fleet_model) fleet_update_map = _FleetUpdateMap() From 57255760d110e009f5fdd8a944e3c2e16413a364 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Wed, 4 Mar 2026 11:30:13 +0500 Subject: [PATCH 30/51] Optimize instance lock in fleet pipeline --- .../background/pipeline_tasks/fleets.py | 187 +++++++++------ .../background/pipeline_tasks/test_fleets.py | 219 ++++++++++++++++++ 2 files changed, 338 insertions(+), 68 deletions(-) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/fleets.py 
b/src/dstack/_internal/server/background/pipeline_tasks/fleets.py index af29f909f2..5691c589a5 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/fleets.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/fleets.py @@ -2,7 +2,7 @@ import uuid from dataclasses import dataclass, field from datetime import timedelta -from typing import Sequence, TypedDict +from typing import Optional, Sequence, TypedDict from sqlalchemy import or_, select, update from sqlalchemy.ext.asyncio.session import AsyncSession @@ -216,60 +216,22 @@ async def process(self, item: PipelineItem): log_lock_token_mismatch(logger, item) return - instance_lock, _ = get_locker(get_db().dialect_name).get_lockset( - InstanceModel.__tablename__ - ) - async with instance_lock: - res = await session.execute( - select(InstanceModel) - .where( - InstanceModel.fleet_id == item.id, - InstanceModel.deleted == False, - or_( - InstanceModel.lock_expires_at.is_(None), - InstanceModel.lock_expires_at < get_current_datetime(), - ), - or_( - InstanceModel.lock_owner.is_(None), - InstanceModel.lock_owner == FleetPipeline.__name__, - ), - ) - .with_for_update(skip_locked=True, key_share=True) + # Lock instance only if consolidation is needed. + consolidation_fleet_spec = _get_fleet_spec_if_ready_for_consolidation(fleet_model) + consolidation_instances = None + if consolidation_fleet_spec is not None: + consolidation_instances = await _lock_fleet_instances_for_consolidation( + session=session, + item=item, ) - locked_instance_models = res.scalars().all() - if len(fleet_model.instances) != len(locked_instance_models): - logger.debug( - "Failed to lock fleet %s instances. The fleet will be processed later.", - item.id, - ) - # Keep `lock_owner` so that `InstancePipeline` sees that the fleet is being locked - # but unset `lock_expires_at` to process the item again ASAP (after `min_processing_interval`). - # Unset `lock_token` so that heartbeater can no longer update the item. 
- res = await session.execute( - update(FleetModel) - .where( - FleetModel.id == item.id, - FleetModel.lock_token == item.lock_token, - ) - .values( - lock_expires_at=None, - lock_token=None, - last_processed_at=get_current_datetime(), - ) - .returning(FleetModel.id) - ) - updated_ids = list(res.scalars().all()) - if len(updated_ids) == 0: - log_lock_token_changed_on_reset(logger) + if consolidation_instances is None: return - for instance_model in locked_instance_models: - instance_model.lock_expires_at = item.lock_expires_at - instance_model.lock_token = item.lock_token - instance_model.lock_owner = FleetPipeline.__name__ - await session.commit() - - result = await _process_fleet(fleet_model) + result = await _process_fleet( + fleet_model, + consolidation_fleet_spec=consolidation_fleet_spec, + consolidation_instances=consolidation_instances, + ) fleet_update_map = _FleetUpdateMap() fleet_update_map.update(result.fleet_update_map) set_processed_update_map_fields(fleet_update_map) @@ -340,6 +302,86 @@ class _InstanceUpdateMap(TypedDict, total=False): id: uuid.UUID +def _get_fleet_spec_if_ready_for_consolidation(fleet_model: FleetModel) -> Optional[FleetSpec]: + if fleet_model.status == FleetStatus.TERMINATING: + return None + consolidation_fleet_spec = get_fleet_spec(fleet_model) + if ( + consolidation_fleet_spec.configuration.nodes is None + or consolidation_fleet_spec.autocreated + ): + return None + if not _is_fleet_ready_for_consolidation(fleet_model): + return None + return consolidation_fleet_spec + + +async def _lock_fleet_instances_for_consolidation( + session: AsyncSession, + item: PipelineItem, +) -> Optional[list[InstanceModel]]: + instance_lock, _ = get_locker(get_db().dialect_name).get_lockset(InstanceModel.__tablename__) + async with instance_lock: + res = await session.execute( + select(InstanceModel) + .where( + InstanceModel.fleet_id == item.id, + InstanceModel.deleted == False, + or_( + InstanceModel.lock_expires_at.is_(None), + 
InstanceModel.lock_expires_at < get_current_datetime(), + ), + or_( + InstanceModel.lock_owner.is_(None), + InstanceModel.lock_owner == FleetPipeline.__name__, + ), + ) + .with_for_update(skip_locked=True, key_share=True) + ) + locked_instance_models = list(res.scalars().all()) + locked_instance_ids = {instance_model.id for instance_model in locked_instance_models} + + res = await session.execute( + select(InstanceModel.id).where( + InstanceModel.fleet_id == item.id, + InstanceModel.deleted == False, + ) + ) + current_instance_ids = set(res.scalars().all()) + if current_instance_ids != locked_instance_ids: + logger.debug( + "Failed to lock fleet %s instances. The fleet will be processed later.", + item.id, + ) + # Keep `lock_owner` so that `InstancePipeline` sees that the fleet is being locked + # but unset `lock_expires_at` to process the item again ASAP (after `min_processing_interval`). + # Unset `lock_token` so that heartbeater can no longer update the item. + res = await session.execute( + update(FleetModel) + .where( + FleetModel.id == item.id, + FleetModel.lock_token == item.lock_token, + ) + .values( + lock_expires_at=None, + lock_token=None, + last_processed_at=get_current_datetime(), + ) + .returning(FleetModel.id) + ) + updated_ids = list(res.scalars().all()) + if len(updated_ids) == 0: + log_lock_token_changed_on_reset(logger) + return None + + for instance_model in locked_instance_models: + instance_model.lock_expires_at = item.lock_expires_at + instance_model.lock_token = item.lock_token + instance_model.lock_owner = FleetPipeline.__name__ + await session.commit() + return locked_instance_models + + @dataclass class _ProcessResult: fleet_update_map: _FleetUpdateMap = field(default_factory=_FleetUpdateMap) @@ -358,8 +400,18 @@ def has_changes(self) -> bool: return len(self.instance_id_to_update_map) > 0 or self.new_instances_count > 0 -async def _process_fleet(fleet_model: FleetModel) -> _ProcessResult: - result = 
_consolidate_fleet_state_with_spec(fleet_model) +async def _process_fleet( + fleet_model: FleetModel, + consolidation_fleet_spec: Optional[FleetSpec] = None, + consolidation_instances: Optional[Sequence[InstanceModel]] = None, +) -> _ProcessResult: + result = _ProcessResult() + if consolidation_fleet_spec is not None: + result = _consolidate_fleet_state_with_spec( + fleet_model, + consolidation_fleet_spec=consolidation_fleet_spec, + consolidation_instances=consolidation_instances, + ) if result.new_instances_count > 0: # Avoid deleting fleets that are about to provision new instances. return result @@ -371,17 +423,16 @@ async def _process_fleet(fleet_model: FleetModel) -> _ProcessResult: return result -def _consolidate_fleet_state_with_spec(fleet_model: FleetModel) -> _ProcessResult: +def _consolidate_fleet_state_with_spec( + fleet_model: FleetModel, + consolidation_fleet_spec: FleetSpec, + consolidation_instances: Optional[Sequence[InstanceModel]] = None, +) -> _ProcessResult: result = _ProcessResult() - if fleet_model.status == FleetStatus.TERMINATING: - return result - fleet_spec = get_fleet_spec(fleet_model) - if fleet_spec.configuration.nodes is None or fleet_spec.autocreated: - # Only explicitly created cloud fleets are consolidated. 
- return result - if not _is_fleet_ready_for_consolidation(fleet_model): - return result - maintain_nodes_result = _maintain_fleet_nodes_in_min_max_range(fleet_model, fleet_spec) + maintain_nodes_result = _maintain_fleet_nodes_in_min_max_range( + instances=consolidation_instances or fleet_model.instances, + fleet_spec=consolidation_fleet_spec, + ) if maintain_nodes_result.has_changes: result.instance_id_to_update_map = maintain_nodes_result.instance_id_to_update_map result.new_instances_count = maintain_nodes_result.new_instances_count @@ -420,7 +471,7 @@ def _get_consolidation_retry_delay(consolidation_attempt: int) -> timedelta: def _maintain_fleet_nodes_in_min_max_range( - fleet_model: FleetModel, + instances: Sequence[InstanceModel], fleet_spec: FleetSpec, ) -> _MaintainNodesResult: """ @@ -428,7 +479,7 @@ def _maintain_fleet_nodes_in_min_max_range( """ assert fleet_spec.configuration.nodes is not None result = _MaintainNodesResult() - for instance in fleet_model.instances: + for instance in instances: # Delete terminated but not deleted instances since # they are going to be replaced with new pending instances. if instance.status == InstanceStatus.TERMINATED and not instance.deleted: @@ -438,7 +489,7 @@ def _maintain_fleet_nodes_in_min_max_range( "deleted_at": NOW_PLACEHOLDER, } active_instances = [ - i for i in fleet_model.instances if i.status != InstanceStatus.TERMINATED and not i.deleted + i for i in instances if i.status != InstanceStatus.TERMINATED and not i.deleted ] active_instances_num = len(active_instances) if active_instances_num < fleet_spec.configuration.nodes.min: @@ -456,7 +507,7 @@ def _maintain_fleet_nodes_in_min_max_range( # or if nodes.max is updated. 
result.changes_required = True nodes_redundant = active_instances_num - fleet_spec.configuration.nodes.max - for instance in fleet_model.instances: + for instance in instances: if nodes_redundant == 0: break if instance.status == InstanceStatus.IDLE: diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_fleets.py b/src/tests/_internal/server/background/pipeline_tasks/test_fleets.py index 71807b8042..3b433ea047 100644 --- a/src/tests/_internal/server/background/pipeline_tasks/test_fleets.py +++ b/src/tests/_internal/server/background/pipeline_tasks/test_fleets.py @@ -6,6 +6,7 @@ import pytest from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload from dstack._internal.core.models.fleets import FleetNodesSpec, FleetStatus from dstack._internal.core.models.instances import InstanceStatus @@ -16,7 +17,10 @@ FleetFetcher, FleetPipeline, FleetWorker, + _get_fleet_spec_if_ready_for_consolidation, + _lock_fleet_instances_for_consolidation, ) +from dstack._internal.server.db import get_session_ctx from dstack._internal.server.models import FleetModel, InstanceModel from dstack._internal.server.services.projects import add_project_member from dstack._internal.server.testing.common import ( @@ -28,6 +32,7 @@ create_run, create_user, get_fleet_spec, + get_ssh_fleet_configuration, ) from dstack._internal.utils.common import get_current_datetime @@ -166,6 +171,220 @@ async def test_fetch_returns_oldest_fleets_first_up_to_limit( @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) class TestFleetWorker: + async def test_ready_for_consolidation_helper_returns_none_for_ssh_fleet( + self, test_db, session: AsyncSession + ): + project = await create_project(session) + fleet = await create_fleet( + session=session, + project=project, + spec=get_fleet_spec(conf=get_ssh_fleet_configuration()), + ) + + assert _get_fleet_spec_if_ready_for_consolidation(fleet) is None 
+ + async def test_ready_for_consolidation_helper_returns_none_when_retry_delay_is_active( + self, test_db, session: AsyncSession + ): + project = await create_project(session) + fleet = await create_fleet( + session=session, + project=project, + spec=get_fleet_spec(), + ) + fleet.consolidation_attempt = 1 + fleet.last_consolidated_at = datetime.now(timezone.utc) + await session.commit() + + assert _get_fleet_spec_if_ready_for_consolidation(fleet) is None + + async def test_ready_for_consolidation_helper_returns_consolidation_fleet_spec_for_eligible_cloud_fleet( + self, test_db, session: AsyncSession + ): + project = await create_project(session) + fleet = await create_fleet( + session=session, + project=project, + spec=get_fleet_spec(), + last_processed_at=datetime(2023, 1, 2, 3, 0, tzinfo=timezone.utc), + ) + + consolidation_fleet_spec = _get_fleet_spec_if_ready_for_consolidation(fleet) + + assert consolidation_fleet_spec is not None + assert consolidation_fleet_spec.configuration.nodes is not None + + async def test_skips_instance_locking_for_ssh_fleet( + self, test_db, session: AsyncSession, worker: FleetWorker + ): + project = await create_project(session) + fleet = await create_fleet( + session=session, + project=project, + spec=get_fleet_spec(conf=get_ssh_fleet_configuration()), + ) + instance = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.IDLE, + ) + original_last_processed_at = fleet.last_processed_at + fleet.lock_token = uuid.uuid4() + fleet.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) + instance.lock_token = uuid.uuid4() + instance.lock_expires_at = datetime(2025, 1, 2, 3, 5, tzinfo=timezone.utc) + instance.lock_owner = "OtherPipeline" + await session.commit() + + await worker.process(_fleet_to_pipeline_item(fleet)) + + await session.refresh(fleet) + await session.refresh(instance) + assert not fleet.deleted + assert fleet.lock_owner is None + assert fleet.lock_token is None + 
assert fleet.lock_expires_at is None + assert fleet.last_processed_at > original_last_processed_at + assert instance.lock_owner == "OtherPipeline" + + async def test_skips_instance_locking_when_fleet_is_not_ready_for_consolidation( + self, test_db, session: AsyncSession, worker: FleetWorker + ): + project = await create_project(session) + spec = get_fleet_spec() + fleet = await create_fleet( + session=session, + project=project, + spec=spec, + ) + instance = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.IDLE, + ) + original_last_processed_at = fleet.last_processed_at + original_last_consolidated_at = datetime.now(timezone.utc) + fleet.consolidation_attempt = 1 + fleet.last_consolidated_at = original_last_consolidated_at + fleet.lock_token = uuid.uuid4() + fleet.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) + instance.lock_token = uuid.uuid4() + instance.lock_expires_at = datetime(2025, 1, 2, 3, 5, tzinfo=timezone.utc) + instance.lock_owner = "OtherPipeline" + await session.commit() + + await worker.process(_fleet_to_pipeline_item(fleet)) + + await session.refresh(fleet) + await session.refresh(instance) + assert not fleet.deleted + assert fleet.consolidation_attempt == 1 + assert fleet.last_consolidated_at == original_last_consolidated_at + assert fleet.lock_owner is None + assert fleet.lock_token is None + assert fleet.lock_expires_at is None + assert fleet.last_processed_at > original_last_processed_at + assert instance.lock_owner == "OtherPipeline" + + async def test_resets_fleet_lock_when_not_all_instances_can_be_locked( + self, test_db, session: AsyncSession, worker: FleetWorker + ): + project = await create_project(session) + spec = get_fleet_spec() + fleet = await create_fleet( + session=session, + project=project, + spec=spec, + ) + await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.IDLE, + instance_num=0, + ) + locked_elsewhere = await 
create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.IDLE, + instance_num=1, + ) + original_last_processed_at = fleet.last_processed_at + fleet.lock_token = uuid.uuid4() + fleet.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) + fleet.lock_owner = FleetPipeline.__name__ + locked_elsewhere.lock_token = uuid.uuid4() + locked_elsewhere.lock_expires_at = datetime(2025, 1, 2, 3, 5, tzinfo=timezone.utc) + locked_elsewhere.lock_owner = "OtherPipeline" + await session.commit() + + await worker.process(_fleet_to_pipeline_item(fleet)) + + await session.refresh(fleet) + await session.refresh(locked_elsewhere) + assert fleet.lock_owner == FleetPipeline.__name__ + assert fleet.lock_token is None + assert fleet.lock_expires_at is None + assert fleet.last_processed_at > original_last_processed_at + assert locked_elsewhere.lock_owner == "OtherPipeline" + + async def test_lock_helper_uses_fresh_current_instances_instead_of_stale_relationship( + self, test_db, session: AsyncSession + ): + project = await create_project(session) + spec = get_fleet_spec() + fleet = await create_fleet( + session=session, + project=project, + spec=spec, + ) + await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.IDLE, + instance_num=0, + ) + fleet.lock_token = uuid.uuid4() + fleet.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) + fleet.lock_owner = FleetPipeline.__name__ + await session.commit() + + res = await session.execute( + select(FleetModel) + .where(FleetModel.id == fleet.id) + .options(selectinload(FleetModel.instances.and_(InstanceModel.deleted == False))) + ) + stale_fleet_model = res.unique().scalar_one() + assert len(stale_fleet_model.instances) == 1 + + async with get_session_ctx() as other_session: + project_model = await other_session.get(type(project), project.id) + fleet_model = await other_session.get(FleetModel, fleet.id) + assert project_model is not None + 
assert fleet_model is not None + await create_instance( + session=other_session, + project=project_model, + fleet=fleet_model, + status=InstanceStatus.IDLE, + instance_num=1, + ) + + assert len(stale_fleet_model.instances) == 1 + + locked_instances = await _lock_fleet_instances_for_consolidation( + session=session, + item=_fleet_to_pipeline_item(fleet), + ) + + assert locked_instances is not None + assert len(locked_instances) == 2 + assert {instance.instance_num for instance in locked_instances} == {0, 1} + async def test_deletes_empty_autocreated_fleet( self, test_db, session: AsyncSession, worker: FleetWorker ): From 8955d68ff0f95561c5992a22fd3f32bc8f4761a1 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Wed, 4 Mar 2026 11:38:53 +0500 Subject: [PATCH 31/51] Respect instance lock in delete_fleets --- src/dstack/_internal/server/services/fleets.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/dstack/_internal/server/services/fleets.py b/src/dstack/_internal/server/services/fleets.py index 380052d78b..b7d4f68713 100644 --- a/src/dstack/_internal/server/services/fleets.py +++ b/src/dstack/_internal/server/services/fleets.py @@ -507,6 +507,8 @@ async def apply_plan( ): # Refetch after lock # TODO: Lock instances with FOR UPDATE? + # We do not respect InstanceModel.lock_* fields here because FleetPipeline does not update SSH instances. + # TODO: Respect InstanceModel.lock_* fields if FleetPipeline and apply update the same instances. 
res = await session.execute( select(FleetModel) .where( @@ -730,6 +732,7 @@ async def delete_fleets( .where( InstanceModel.id.in_(instances_ids), InstanceModel.deleted == False, + InstanceModel.lock_expires_at.is_(None), ) .order_by(InstanceModel.id) # take locks in order .with_for_update(key_share=True, of=InstanceModel) From 91bbced92670a44a83430ea36a04bf2cbaa3a32e Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Wed, 4 Mar 2026 11:42:12 +0500 Subject: [PATCH 32/51] Skip locked instances in process_next_terminating_job --- .../server/background/scheduled_tasks/terminating_jobs.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/dstack/_internal/server/background/scheduled_tasks/terminating_jobs.py b/src/dstack/_internal/server/background/scheduled_tasks/terminating_jobs.py index 3749076c1a..27163b53d9 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/terminating_jobs.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/terminating_jobs.py @@ -66,6 +66,7 @@ async def _process_next_terminating_job(): .where( InstanceModel.id == job_model.used_instance_id, InstanceModel.id.not_in(instance_lockset), + InstanceModel.lock_expires_at.is_(None), ) .with_for_update(skip_locked=True, key_share=True) ) From 168b9bdbde3165aee0ce83f164cb05685e45a19a Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Wed, 4 Mar 2026 12:44:40 +0500 Subject: [PATCH 33/51] Respect instance lock in submitted_jobs --- src/dstack/_internal/server/services/runs/plan.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/dstack/_internal/server/services/runs/plan.py b/src/dstack/_internal/server/services/runs/plan.py index 5e3b6e5a02..f1614561ad 100644 --- a/src/dstack/_internal/server/services/runs/plan.py +++ b/src/dstack/_internal/server/services/runs/plan.py @@ -239,9 +239,12 @@ async def select_run_candidate_fleet_models_with_filters( .execution_options(populate_existing=True) ) if lock_instances: - stmt = 
stmt.order_by(InstanceModel.id).with_for_update( # take locks in order - key_share=True, of=InstanceModel - ) + # Skip locked instances since waiting for all the instances to unlock may take indefinite time. + # TODO: Switch to optimistic locking – implement select-lock-reselect loop. + stmt = stmt.where(InstanceModel.lock_expires_at.is_(None)) + stmt = stmt.order_by( + InstanceModel.id # take locks in order + ).with_for_update(skip_locked=True, key_share=True, of=InstanceModel) res = await session.execute(stmt) fleet_models_with_instances = list(res.unique().scalars().all()) fleet_models_with_instances_ids = [f.id for f in fleet_models_with_instances] From 9cf36a45236aaede8640197f586384e8b1602a19 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Wed, 4 Mar 2026 12:52:28 +0500 Subject: [PATCH 34/51] Add ix_instances_pipeline_fetch_q_index --- ...add_ix_instances_pipeline_fetch_q_index.py | 49 +++++++++++++++++++ src/dstack/_internal/server/models.py | 9 ++++ 2 files changed, 58 insertions(+) create mode 100644 src/dstack/_internal/server/migrations/versions/2026/03_04_0751_297c68450cc8_add_ix_instances_pipeline_fetch_q_index.py diff --git a/src/dstack/_internal/server/migrations/versions/2026/03_04_0751_297c68450cc8_add_ix_instances_pipeline_fetch_q_index.py b/src/dstack/_internal/server/migrations/versions/2026/03_04_0751_297c68450cc8_add_ix_instances_pipeline_fetch_q_index.py new file mode 100644 index 0000000000..5139370104 --- /dev/null +++ b/src/dstack/_internal/server/migrations/versions/2026/03_04_0751_297c68450cc8_add_ix_instances_pipeline_fetch_q_index.py @@ -0,0 +1,49 @@ +"""Add ix_instances_pipeline_fetch_q index + +Revision ID: 297c68450cc8 +Revises: 8e8647f20aa4 +Create Date: 2026-03-04 07:51:02.855596+00:00 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. 
+revision = "297c68450cc8" +down_revision = "8e8647f20aa4" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.get_context().autocommit_block(): + op.drop_index( + "ix_instances_pipeline_fetch_q_index", + table_name="instances", + if_exists=True, + postgresql_concurrently=True, + ) + op.create_index( + "ix_instances_pipeline_fetch_q_index", + "instances", + [sa.literal_column("last_processed_at ASC")], + unique=False, + sqlite_where=sa.text("deleted = 0"), + postgresql_where=sa.text("deleted IS FALSE"), + postgresql_concurrently=True, + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.get_context().autocommit_block(): + op.drop_index( + "ix_instances_pipeline_fetch_q_index", + table_name="instances", + if_exists=True, + postgresql_concurrently=True, + ) + # ### end Alembic commands ### diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py index 843b1d798b..285338cb77 100644 --- a/src/dstack/_internal/server/models.py +++ b/src/dstack/_internal/server/models.py @@ -727,6 +727,15 @@ class InstanceModel(PipelineModelMixin, BaseModel): cascade="save-update, merge, delete-orphan, delete", ) + __table_args__ = ( + Index( + "ix_instances_pipeline_fetch_q", + last_processed_at.asc(), + postgresql_where=deleted == false(), + sqlite_where=deleted == false(), + ), + ) + class InstanceHealthCheckModel(BaseModel): __tablename__ = "instance_health_checks" From 994165217038246fa8a71b3c91b384a337c57d34 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Wed, 4 Mar 2026 13:06:52 +0500 Subject: [PATCH 35/51] Wire instance pipeline --- .../server/background/pipeline_tasks/__init__.py | 2 ++ .../background/pipeline_tasks/instances/__init__.py | 2 +- .../server/background/scheduled_tasks/__init__.py | 12 ++++++------ 3 files changed, 9 insertions(+), 7 deletions(-) diff 
--git a/src/dstack/_internal/server/background/pipeline_tasks/__init__.py b/src/dstack/_internal/server/background/pipeline_tasks/__init__.py index 6b3762419f..556e13daaf 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/__init__.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/__init__.py @@ -4,6 +4,7 @@ from dstack._internal.server.background.pipeline_tasks.compute_groups import ComputeGroupPipeline from dstack._internal.server.background.pipeline_tasks.fleets import FleetPipeline from dstack._internal.server.background.pipeline_tasks.gateways import GatewayPipeline +from dstack._internal.server.background.pipeline_tasks.instances import InstancePipeline from dstack._internal.server.background.pipeline_tasks.placement_groups import ( PlacementGroupPipeline, ) @@ -19,6 +20,7 @@ def __init__(self) -> None: ComputeGroupPipeline(), FleetPipeline(), GatewayPipeline(), + InstancePipeline(), PlacementGroupPipeline(), VolumePipeline(), ] diff --git a/src/dstack/_internal/server/background/pipeline_tasks/instances/__init__.py b/src/dstack/_internal/server/background/pipeline_tasks/instances/__init__.py index cba8f68df1..113cbafb3c 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/instances/__init__.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/instances/__init__.py @@ -71,7 +71,7 @@ class InstancePipelineItem(PipelineItem): class InstancePipeline(Pipeline[InstancePipelineItem]): def __init__( self, - workers_num: int = 10, + workers_num: int = 20, queue_lower_limit_factor: float = 0.5, queue_upper_limit_factor: float = 2.0, min_processing_interval: timedelta = timedelta(seconds=10), diff --git a/src/dstack/_internal/server/background/scheduled_tasks/__init__.py b/src/dstack/_internal/server/background/scheduled_tasks/__init__.py index 6b7f6f3389..2994fca37c 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/__init__.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/__init__.py 
@@ -146,13 +146,13 @@ def start_scheduled_tasks() -> AsyncIOScheduler: kwargs={"batch_size": 5}, max_instances=2 if replica == 0 else 1, ) - _scheduler.add_job( - process_instances, - IntervalTrigger(seconds=4, jitter=2), - kwargs={"batch_size": 5}, - max_instances=2 if replica == 0 else 1, - ) if not FeatureFlags.PIPELINE_PROCESSING_ENABLED: + _scheduler.add_job( + process_instances, + IntervalTrigger(seconds=4, jitter=2), + kwargs={"batch_size": 5}, + max_instances=2 if replica == 0 else 1, + ) _scheduler.add_job( process_compute_groups, IntervalTrigger(seconds=15, jitter=2), From 8e3a019f16f5ddd4734130d52847af523c1ce140 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Wed, 4 Mar 2026 15:55:49 +0500 Subject: [PATCH 36/51] Set current_master_instance --- .../background/pipeline_tasks/fleets.py | 90 +++- .../pipeline_tasks/instances/__init__.py | 22 - .../instances/cloud_provisioning.py | 194 ++++--- .../pipeline_tasks/instances/common.py | 98 +--- .../background/scheduled_tasks/instances.py | 1 + ...4d986_add_fleet_current_master_instance.py | 45 ++ src/dstack/_internal/server/models.py | 14 +- .../_internal/server/services/fleets.py | 15 +- .../background/pipeline_tasks/test_fleets.py | 229 +++++++- .../test_instances/test_cloud_provisioning.py | 505 ++++++++++++++++-- 10 files changed, 982 insertions(+), 231 deletions(-) create mode 100644 src/dstack/_internal/server/migrations/versions/2026/03_04_1015_9cb8e4e4d986_add_fleet_current_master_instance.py diff --git a/src/dstack/_internal/server/background/pipeline_tasks/fleets.py b/src/dstack/_internal/server/background/pipeline_tasks/fleets.py index 5691c589a5..df0d5b809a 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/fleets.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/fleets.py @@ -8,7 +8,11 @@ from sqlalchemy.ext.asyncio.session import AsyncSession from sqlalchemy.orm import joinedload, load_only, selectinload -from dstack._internal.core.models.fleets import FleetSpec, 
FleetStatus +from dstack._internal.core.models.fleets import ( + FleetSpec, + FleetStatus, + InstanceGroupPlacement, +) from dstack._internal.core.models.instances import InstanceStatus, InstanceTerminationReason from dstack._internal.core.models.runs import RunStatus from dstack._internal.server.background.pipeline_tasks.base import ( @@ -274,6 +278,14 @@ async def process(self, item: PipelineItem): fleet_model=fleet_model, new_instances_count=result.new_instances_count, ) + await session.flush() + # FleetPipeline is the sole owner of cluster master election. + # Sync it after instance updates and creation so the pointer reflects + # the final post-consolidation fleet state that will be committed. + await _sync_current_master_instance( + session=session, + fleet_model_id=fleet_model.id, + ) emit_fleet_status_change_event( session=session, fleet_model=fleet_model, @@ -596,3 +608,79 @@ async def _create_missing_fleet_instances( new_instances_count, fleet_model.name, ) + + +async def _sync_current_master_instance( + session: AsyncSession, + fleet_model_id: uuid.UUID, +) -> None: + fleet_model = await session.get(FleetModel, fleet_model_id) + if fleet_model is None: + return + + new_current_master_instance_id = None + fleet_spec = get_fleet_spec(fleet_model) + is_cluster = ( + fleet_spec.configuration.placement == InstanceGroupPlacement.CLUSTER + and fleet_spec.configuration.ssh_config is None + ) + if not fleet_model.deleted and is_cluster: + res = await session.execute( + select(InstanceModel) + .where( + InstanceModel.fleet_id == fleet_model_id, + InstanceModel.deleted == False, + ) + .order_by(InstanceModel.instance_num, InstanceModel.created_at) + .options( + load_only( + InstanceModel.id, + InstanceModel.status, + InstanceModel.job_provisioning_data, + ) + ) + ) + current_instance_models = list(res.scalars().all()) + new_current_master_instance_id = _select_current_master_instance_id( + current_master_instance_id=fleet_model.current_master_instance_id, + 
instance_models=current_instance_models, + ) + + if fleet_model.current_master_instance_id == new_current_master_instance_id: + return + + await session.execute( + update(FleetModel) + .where(FleetModel.id == fleet_model_id) + .values(current_master_instance_id=new_current_master_instance_id) + ) + + +def _select_current_master_instance_id( + current_master_instance_id: Optional[uuid.UUID], + instance_models: Sequence[InstanceModel], +) -> Optional[uuid.UUID]: + # Keep the current master stable while it is still alive so InstancePipeline + # does not see fleet-wide election churn between provisioning attempts. + if current_master_instance_id is not None: + for instance_model in instance_models: + if ( + instance_model.id == current_master_instance_id + and instance_model.status != InstanceStatus.TERMINATED + ): + return instance_model.id + + # If the old master is gone, prefer a surviving provisioned instance since it + # already defines backend/region/AZ for the current cluster generation. 
+ for instance_model in instance_models: + if ( + instance_model.status != InstanceStatus.TERMINATED + and instance_model.job_provisioning_data is not None + ): + return instance_model.id + + for instance_model in instance_models: + if instance_model.status != InstanceStatus.TERMINATED: + return instance_model.id + + return None diff --git a/src/dstack/_internal/server/background/pipeline_tasks/instances/__init__.py b/src/dstack/_internal/server/background/pipeline_tasks/instances/__init__.py index 113cbafb3c..9515c92f7f 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/instances/__init__.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/instances/__init__.py @@ -406,7 +406,6 @@ async def _apply_process_result( now = get_current_datetime() resolve_now_placeholders(result.instance_update_map, now=now) - resolve_now_placeholders(result.sibling_update_rows, now=now) res = await session.execute( update(InstanceModel) @@ -423,12 +422,6 @@ async def _apply_process_result( await session.rollback() return - if result.sibling_update_rows: - await session.execute( - update(InstanceModel), - result.sibling_update_rows, - ) - if result.schedule_pg_deletion_fleet_id is not None: await schedule_fleet_placement_groups_deletion( session=session, @@ -469,21 +462,6 @@ async def _apply_process_result( ), ) - for sibling_deferred_event in result.sibling_deferred_events: - events.emit( - session, - sibling_deferred_event.message, - actor=events.SystemActor(), - targets=[ - events.Target( - type=events.EventTargetType.INSTANCE, - project_id=sibling_deferred_event.project_id, - id=sibling_deferred_event.instance_id, - name=sibling_deferred_event.instance_name, - ) - ], - ) - def _emit_instance_health_change_event( session: AsyncSession, diff --git a/src/dstack/_internal/server/background/pipeline_tasks/instances/cloud_provisioning.py b/src/dstack/_internal/server/background/pipeline_tasks/instances/cloud_provisioning.py index cb2e9f9400..17d59747ce 100644 
--- a/src/dstack/_internal/server/background/pipeline_tasks/instances/cloud_provisioning.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/instances/cloud_provisioning.py @@ -2,6 +2,10 @@ from typing import Optional from pydantic import ValidationError +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import load_only +from sqlalchemy.orm.attributes import set_committed_value from dstack._internal.core.backends.base.compute import ( ComputeWithCreateInstanceSupport, @@ -26,25 +30,21 @@ from dstack._internal.server.background.pipeline_tasks.base import NOW_PLACEHOLDER from dstack._internal.server.background.pipeline_tasks.instances.common import ( ProcessResult, - SiblingInstanceUpdateMap, - append_sibling_status_event, - get_fleet_master_instance, - get_instance_offer_for_instance, - need_to_wait_fleet_provisioning, set_status_update, ) from dstack._internal.server.db import get_session_ctx -from dstack._internal.server.models import InstanceModel, PlacementGroupModel +from dstack._internal.server.models import FleetModel, InstanceModel, PlacementGroupModel from dstack._internal.server.services.fleets import get_create_instance_offers, is_cloud_cluster from dstack._internal.server.services.instances import ( get_instance_configuration, get_instance_profile, + get_instance_provisioning_data, get_instance_requirements, ) from dstack._internal.server.services.logging import fmt +from dstack._internal.server.services.offers import get_instance_offer_with_restricted_az from dstack._internal.server.services.placement import ( get_fleet_placement_group_models, - get_placement_group_model_for_instance, placement_group_model_to_placement_group, placement_group_model_to_placement_group_optional, ) @@ -56,13 +56,6 @@ async def create_cloud_instance(instance_model: InstanceModel) -> ProcessResult: result = ProcessResult() - master_instance_model = get_fleet_master_instance(instance_model) - if 
need_to_wait_fleet_provisioning(instance_model, master_instance_model): - logger.debug( - "%s: waiting for the first instance in the fleet to be provisioned", - fmt(instance_model), - ) - return result try: instance_configuration = get_instance_configuration(instance_model) @@ -84,18 +77,63 @@ async def create_cloud_instance(instance_model: InstanceModel) -> ProcessResult: ) return result - # The placement group is determined when provisioning the master instance - # and used for all other instances in the fleet. - async with get_session_ctx() as session: - placement_group_models = await get_fleet_placement_group_models( - session=session, - fleet_id=instance_model.fleet_id, - ) - placement_group_model = get_placement_group_model_for_instance( - placement_group_models=placement_group_models, - instance_model=instance_model, - master_instance_model=master_instance_model, - ) + current_master_instance_model = None + master_job_provisioning_data = None + placement_group_models: list[PlacementGroupModel] = [] + placement_group_model = None + if instance_model.fleet is not None and is_cloud_cluster(instance_model.fleet): + assert instance_model.fleet_id is not None + async with get_session_ctx() as session: + current_master_instance_model = await _load_current_master_instance( + session=session, + fleet_id=instance_model.fleet_id, + ) + placement_group_models = await get_fleet_placement_group_models( + session=session, + fleet_id=instance_model.fleet_id, + ) + if current_master_instance_model is None: + # FleetPipeline elects the current master. Until it does, instance + # workers must wait instead of trying to coordinate bootstrap. 
+ logger.debug( + "%s: waiting for fleet pipeline to elect current cluster master", + fmt(instance_model), + ) + return result + if current_master_instance_model.id != instance_model.id: + if ( + current_master_instance_model.deleted + or current_master_instance_model.status == InstanceStatus.TERMINATED + ): + # Master failover is also owned by FleetPipeline. InstancePipeline + # only terminates the current instance and waits for the next fleet tick. + logger.debug( + "%s: waiting for fleet pipeline to replace current master %s", + fmt(instance_model), + current_master_instance_model.id, + ) + return result + master_job_provisioning_data = get_instance_provisioning_data( + current_master_instance_model + ) + if master_job_provisioning_data is None: + logger.debug( + "%s: waiting for current master %s to determine cluster placement", + fmt(instance_model), + current_master_instance_model.id, + ) + return result + # Non-master instances only reuse the placement group chosen by the + # current master. They never create a new placement group themselves. + placement_group_model = _get_current_master_placement_group_model( + placement_group_models=placement_group_models, + fleet_id=instance_model.fleet_id, + ) + if placement_group_model is not None: + _populate_current_master_placement_group_relations( + placement_group_model=placement_group_model, + instance_model=instance_model, + ) offers = await get_create_instance_offers( project=instance_model.project, profile=profile, @@ -104,6 +142,7 @@ async def create_cloud_instance(instance_model: InstanceModel) -> ProcessResult: placement_group=placement_group_model_to_placement_group_optional(placement_group_model), blocks="auto" if instance_model.total_blocks is None else instance_model.total_blocks, exclude_not_available=True, + master_job_provisioning_data=master_job_provisioning_data, ) # Limit number of offers tried to prevent long-running processing in case all offers fail. 
@@ -112,15 +151,18 @@ async def create_cloud_instance(instance_model: InstanceModel) -> ProcessResult: continue compute = backend.compute() assert isinstance(compute, ComputeWithCreateInstanceSupport) - instance_offer = get_instance_offer_for_instance( - instance_offer=instance_offer, - instance_model=instance_model, - master_instance_model=master_instance_model, - ) + if master_job_provisioning_data is not None: + # Shared offer lookup already restricts backend and region from the master. + # Availability zone still has to be narrowed per offer. + instance_offer = get_instance_offer_with_restricted_az( + instance_offer=instance_offer, + master_job_provisioning_data=master_job_provisioning_data, + ) if ( instance_model.fleet is not None and is_cloud_cluster(instance_model.fleet) - and instance_model.id == master_instance_model.id + and current_master_instance_model is not None + and current_master_instance_model.id == instance_model.id and instance_offer.backend in BACKENDS_WITH_PLACEMENT_GROUPS_SUPPORT and isinstance(compute, ComputeWithPlacementGroupSupport) and ( @@ -190,7 +232,11 @@ async def create_cloud_instance(instance_model: InstanceModel) -> ProcessResult: result.instance_update_map["total_blocks"] = instance_offer.total_blocks result.instance_update_map["started_at"] = NOW_PLACEHOLDER - if instance_model.fleet_id is not None and instance_model.id == master_instance_model.id: + if ( + instance_model.fleet_id is not None + and current_master_instance_model is not None + and current_master_instance_model.id == instance_model.id + ): # Clean up placement groups that did not end up being used. 
result.schedule_pg_deletion_fleet_id = instance_model.fleet_id if placement_group_model is not None: @@ -204,36 +250,64 @@ async def create_cloud_instance(instance_model: InstanceModel) -> ProcessResult: termination_reason=InstanceTerminationReason.NO_OFFERS, termination_reason_message="All offers failed" if offers else "No offers found", ) - if ( - instance_model.fleet is not None - and instance_model.id == master_instance_model.id - and is_cloud_cluster(instance_model.fleet) - ): - # Do not attempt to deploy other instances, as they won't determine the correct cluster - # backend, region, and placement group without a successfully deployed master instance. - for sibling_instance_model in instance_model.fleet.instances: - if sibling_instance_model.id == instance_model.id: - continue - sibling_update_map = SiblingInstanceUpdateMap(id=sibling_instance_model.id) - if set_status_update( - update_map=sibling_update_map, - instance_model=sibling_instance_model, - new_status=InstanceStatus.TERMINATED, - termination_reason=InstanceTerminationReason.MASTER_FAILED, - ): - result.sibling_update_rows.append(sibling_update_map) - append_sibling_status_event( - deferred_events=result.sibling_deferred_events, - instance_model=sibling_instance_model, - new_status=InstanceStatus.TERMINATED, - termination_reason=sibling_update_map.get("termination_reason"), - termination_reason_message=sibling_update_map.get( - "termination_reason_message" - ), - ) return result +async def _load_current_master_instance( + session: AsyncSession, + fleet_id: uuid.UUID, +) -> Optional[InstanceModel]: + res = await session.execute( + select(FleetModel.current_master_instance_id).where(FleetModel.id == fleet_id) + ) + current_master_instance_id = res.scalar_one_or_none() + if current_master_instance_id is None: + return None + res = await session.execute( + select(InstanceModel) + .where( + InstanceModel.id == current_master_instance_id, + ) + .options( + load_only( + InstanceModel.id, + 
InstanceModel.deleted, + InstanceModel.status, + InstanceModel.job_provisioning_data, + ) + ) + ) + return res.scalar_one_or_none() + + +def _get_current_master_placement_group_model( + placement_group_models: list[PlacementGroupModel], + fleet_id: uuid.UUID, +) -> Optional[PlacementGroupModel]: + if not placement_group_models: + return None + if len(placement_group_models) > 1: + logger.error( + "Expected 0 or 1 placement groups associated with fleet master %s, found %s." + " Using the first placement group for this provisioning attempt.", + fleet_id, + len(placement_group_models), + ) + return placement_group_models[0] + + +def _populate_current_master_placement_group_relations( + placement_group_model: PlacementGroupModel, + instance_model: InstanceModel, +) -> None: + # Placement groups are loaded in a separate session from the instance worker. + # Reattach the already-known project/fleet objects so later detached access + # can still build a PlacementGroup value object without lazy loading. 
+ set_committed_value(placement_group_model, "project", instance_model.project) + if instance_model.fleet is not None: + set_committed_value(placement_group_model, "fleet", instance_model.fleet) + + async def _find_or_create_suitable_placement_group_model( instance_model: InstanceModel, placement_group_models: list[PlacementGroupModel], diff --git a/src/dstack/_internal/server/background/pipeline_tasks/instances/common.py b/src/dstack/_internal/server/background/pipeline_tasks/instances/common.py index 65915c30e2..06a833007a 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/instances/common.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/instances/common.py @@ -7,10 +7,8 @@ from paramiko.pkey import PKey from dstack._internal.core.models.backends.base import BackendType -from dstack._internal.core.models.fleets import InstanceGroupPlacement from dstack._internal.core.models.health import HealthStatus from dstack._internal.core.models.instances import ( - InstanceOfferWithAvailability, InstanceStatus, InstanceTerminationReason, SSHKey, @@ -22,12 +20,7 @@ ) from dstack._internal.server.background.scheduled_tasks.common import get_provisioning_timeout from dstack._internal.server.models import FleetModel, InstanceModel, PlacementGroupModel -from dstack._internal.server.services.fleets import fleet_model_to_fleet, is_cloud_cluster -from dstack._internal.server.services.instances import ( - get_instance_provisioning_data, - get_instance_status_change_message, -) -from dstack._internal.server.services.offers import get_instance_offer_with_restricted_az +from dstack._internal.server.services.fleets import fleet_model_to_fleet from dstack._internal.utils.common import UNSET, Unset, get_current_datetime from dstack._internal.utils.ssh import pkey_from_str @@ -61,13 +54,6 @@ class InstanceUpdateMap(ItemUpdateMap, total=False): deleted_at: UpdateMapDateTime -class SiblingInstanceUpdateMap(TypedDict, total=False): - id: uuid.UUID - status: 
InstanceStatus - termination_reason: Optional[InstanceTerminationReason] - termination_reason_message: Optional[str] - - class HealthCheckCreate(TypedDict): instance_id: uuid.UUID collected_at: datetime.datetime @@ -75,19 +61,9 @@ class HealthCheckCreate(TypedDict): response: str -@dataclass -class SiblingDeferredEvent: - message: str - project_id: uuid.UUID - instance_id: uuid.UUID - instance_name: str - - @dataclass class ProcessResult: instance_update_map: InstanceUpdateMap = field(default_factory=InstanceUpdateMap) - sibling_update_rows: list[SiblingInstanceUpdateMap] = field(default_factory=list) - sibling_deferred_events: list[SiblingDeferredEvent] = field(default_factory=list) health_check_create: Optional[HealthCheckCreate] = None new_placement_group_models: list[PlacementGroupModel] = field(default_factory=list) schedule_pg_deletion_fleet_id: Optional[uuid.UUID] = None @@ -104,52 +80,6 @@ def can_terminate_fleet_instances_on_idle_duration(fleet_model: FleetModel) -> b return len(active_instances) > fleet.spec.configuration.nodes.min -def get_fleet_master_instance(instance_model: InstanceModel) -> InstanceModel: - if instance_model.fleet is None: - return instance_model - fleet_instances = list(instance_model.fleet.instances) - if all(fleet_instance.id != instance_model.id for fleet_instance in fleet_instances): - fleet_instances.append(instance_model) - return min( - fleet_instances, - key=lambda fleet_instance: (fleet_instance.instance_num, fleet_instance.created_at), - ) - - -def need_to_wait_fleet_provisioning( - instance_model: InstanceModel, - master_instance_model: InstanceModel, -) -> bool: - # Cluster cloud instances should wait for the first fleet instance to be provisioned - # so that they are provisioned in the same backend/region. 
- if instance_model.fleet is None: - return False - if ( - instance_model.id == master_instance_model.id - or master_instance_model.job_provisioning_data is not None - or master_instance_model.status == InstanceStatus.TERMINATED - ): - return False - return is_cloud_cluster(instance_model.fleet) - - -def get_instance_offer_for_instance( - instance_offer: InstanceOfferWithAvailability, - instance_model: InstanceModel, - master_instance_model: InstanceModel, -) -> InstanceOfferWithAvailability: - if instance_model.fleet is None: - return instance_offer - fleet = fleet_model_to_fleet(instance_model.fleet) - if fleet.spec.configuration.placement == InstanceGroupPlacement.CLUSTER: - master_job_provisioning_data = get_instance_provisioning_data(master_instance_model) - return get_instance_offer_with_restricted_az( - instance_offer=instance_offer, - master_job_provisioning_data=master_job_provisioning_data, - ) - return instance_offer - - def get_instance_idle_duration(instance_model: InstanceModel) -> datetime.timedelta: last_time = instance_model.created_at if instance_model.last_job_processed_at is not None: @@ -182,7 +112,7 @@ def ssh_keys_to_pkeys(ssh_keys: list[SSHKey]) -> list[PKey]: def set_status_update( - update_map: Union[InstanceUpdateMap, SiblingInstanceUpdateMap], + update_map: InstanceUpdateMap, instance_model: InstanceModel, new_status: InstanceStatus, termination_reason: Union[Optional[InstanceTerminationReason], Unset] = UNSET, @@ -236,27 +166,3 @@ def set_unreachable_update( return False update_map["unreachable"] = unreachable return True - - -def append_sibling_status_event( - deferred_events: list[SiblingDeferredEvent], - instance_model: InstanceModel, - new_status: InstanceStatus, - termination_reason: Optional[InstanceTerminationReason], - termination_reason_message: Optional[str], -) -> None: - if instance_model.status == new_status: - return - deferred_events.append( - SiblingDeferredEvent( - message=get_instance_status_change_message( - 
old_status=instance_model.status, - new_status=new_status, - termination_reason=termination_reason, - termination_reason_message=termination_reason_message, - ), - project_id=instance_model.project_id, - instance_id=instance_model.id, - instance_name=instance_model.name, - ) - ) diff --git a/src/dstack/_internal/server/background/scheduled_tasks/instances.py b/src/dstack/_internal/server/background/scheduled_tasks/instances.py index b3de9cb305..1857e0ad09 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/instances.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/instances.py @@ -672,6 +672,7 @@ async def _create_instance(session: AsyncSession, instance: InstanceModel) -> No if instance.fleet and instance.id == master_instance.id and is_cloud_cluster(instance.fleet): # Do not attempt to deploy other instances, as they won't determine the correct cluster # backend, region, and placement group without a successfully deployed master instance + # FIXME: Race condition with siblings processed concurrently. for sibling_instance in instance.fleet.instances: if sibling_instance.id == instance.id: continue diff --git a/src/dstack/_internal/server/migrations/versions/2026/03_04_1015_9cb8e4e4d986_add_fleet_current_master_instance.py b/src/dstack/_internal/server/migrations/versions/2026/03_04_1015_9cb8e4e4d986_add_fleet_current_master_instance.py new file mode 100644 index 0000000000..db8653a65c --- /dev/null +++ b/src/dstack/_internal/server/migrations/versions/2026/03_04_1015_9cb8e4e4d986_add_fleet_current_master_instance.py @@ -0,0 +1,45 @@ +"""Add FleetModel current master instance + +Revision ID: 9cb8e4e4d986 +Revises: 297c68450cc8 +Create Date: 2026-03-04 10:15:00.000000+00:00 + +""" + +import sqlalchemy as sa +import sqlalchemy_utils +from alembic import op + +# revision identifiers, used by Alembic. 
+revision = "9cb8e4e4d986" +down_revision = "297c68450cc8" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("fleets", schema=None) as batch_op: + batch_op.add_column( + sa.Column( + "current_master_instance_id", + sqlalchemy_utils.types.uuid.UUIDType(binary=False), + nullable=True, + ) + ) + batch_op.create_index( + batch_op.f("ix_fleets_current_master_instance_id"), + ["current_master_instance_id"], + unique=False, + ) + + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("fleets", schema=None) as batch_op: + batch_op.drop_index(batch_op.f("ix_fleets_current_master_instance_id")) + batch_op.drop_column("current_master_instance_id") + + # ### end Alembic commands ### diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py index 285338cb77..3d85169ad0 100644 --- a/src/dstack/_internal/server/models.py +++ b/src/dstack/_internal/server/models.py @@ -602,7 +602,14 @@ class FleetModel(PipelineModelMixin, BaseModel): runs: Mapped[List["RunModel"]] = relationship(back_populates="fleet") jobs: Mapped[List["JobModel"]] = relationship(back_populates="fleet") - instances: Mapped[List["InstanceModel"]] = relationship(back_populates="fleet") + instances: Mapped[List["InstanceModel"]] = relationship( + back_populates="fleet", + foreign_keys="InstanceModel.fleet_id", + ) + + current_master_instance_id: Mapped[Optional[uuid.UUID]] = mapped_column( + UUIDType(binary=False), index=True + ) # `consolidation_attempt` counts how many times in a row fleet needed consolidation. # Allows increasing delays between attempts. 
@@ -647,7 +654,10 @@ class InstanceModel(PipelineModelMixin, BaseModel): pool: Mapped[Optional["PoolModel"]] = relationship(back_populates="instances") fleet_id: Mapped[Optional[uuid.UUID]] = mapped_column(ForeignKey("fleets.id"), index=True) - fleet: Mapped[Optional["FleetModel"]] = relationship(back_populates="instances") + fleet: Mapped[Optional["FleetModel"]] = relationship( + back_populates="instances", + foreign_keys=[fleet_id], + ) compute_group_id: Mapped[Optional[uuid.UUID]] = mapped_column(ForeignKey("compute_groups.id")) compute_group: Mapped[Optional["ComputeGroupModel"]] = relationship(back_populates="instances") diff --git a/src/dstack/_internal/server/services/fleets.py b/src/dstack/_internal/server/services/fleets.py index b7d4f68713..23e25c284d 100644 --- a/src/dstack/_internal/server/services/fleets.py +++ b/src/dstack/_internal/server/services/fleets.py @@ -426,19 +426,22 @@ async def get_create_instance_offers( fleet_model: Optional[FleetModel] = None, blocks: Union[int, Literal["auto"]] = 1, exclude_not_available: bool = False, + master_job_provisioning_data: Optional[JobProvisioningData] = None, ) -> List[Tuple[Backend, InstanceOfferWithAvailability]]: multinode = False - master_job_provisioning_data = None if fleet_spec is not None: multinode = fleet_spec.configuration.placement == InstanceGroupPlacement.CLUSTER if fleet_model is not None: fleet = fleet_model_to_fleet(fleet_model) multinode = fleet.spec.configuration.placement == InstanceGroupPlacement.CLUSTER - for instance in fleet_model.instances: - jpd = instances_services.get_instance_provisioning_data(instance) - if jpd is not None: - master_job_provisioning_data = jpd - break + # The caller may override the current cluster master explicitly instead + # of inferring placement restrictions from the loaded fleet instances. 
+ if master_job_provisioning_data is None: + for instance in fleet_model.instances: + jpd = instances_services.get_instance_provisioning_data(instance) + if jpd is not None: + master_job_provisioning_data = jpd + break offers = await offers_services.get_offers_by_requirements( project=project, diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_fleets.py b/src/tests/_internal/server/background/pipeline_tasks/test_fleets.py index 3b433ea047..913abb8a30 100644 --- a/src/tests/_internal/server/background/pipeline_tasks/test_fleets.py +++ b/src/tests/_internal/server/background/pipeline_tasks/test_fleets.py @@ -8,7 +8,11 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload -from dstack._internal.core.models.fleets import FleetNodesSpec, FleetStatus +from dstack._internal.core.models.fleets import ( + FleetNodesSpec, + FleetStatus, + InstanceGroupPlacement, +) from dstack._internal.core.models.instances import InstanceStatus from dstack._internal.core.models.runs import RunStatus from dstack._internal.core.models.users import GlobalRole, ProjectRole @@ -31,7 +35,9 @@ create_repo, create_run, create_user, + get_fleet_configuration, get_fleet_spec, + get_job_provisioning_data, get_ssh_fleet_configuration, ) from dstack._internal.utils.common import get_current_datetime @@ -65,6 +71,12 @@ def _fleet_to_pipeline_item(fleet: FleetModel) -> PipelineItem: ) +async def _lock_fleet_for_processing(session: AsyncSession, fleet: FleetModel) -> None: + fleet.lock_token = uuid.uuid4() + fleet.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) + await session.commit() + + @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) class TestFleetFetcher: @@ -385,6 +397,221 @@ async def test_lock_helper_uses_fresh_current_instances_instead_of_stale_relatio assert len(locked_instances) == 2 assert {instance.instance_num for instance in locked_instances} == {0, 1} + async def 
test_syncs_initial_current_master_for_cluster_fleet( + self, test_db, session: AsyncSession, worker: FleetWorker + ): + project = await create_project(session) + fleet = await create_fleet( + session=session, + project=project, + spec=get_fleet_spec( + conf=get_fleet_configuration( + placement=InstanceGroupPlacement.CLUSTER, + nodes=FleetNodesSpec(min=2, target=2, max=2), + ) + ), + ) + first_instance = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.PENDING, + job_provisioning_data=None, + offer=None, + instance_num=0, + ) + await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.PENDING, + job_provisioning_data=None, + offer=None, + instance_num=1, + ) + await _lock_fleet_for_processing(session, fleet) + + await worker.process(_fleet_to_pipeline_item(fleet)) + + await session.refresh(fleet) + assert fleet.current_master_instance_id == first_instance.id + + async def test_keeps_current_master_when_it_is_still_active( + self, test_db, session: AsyncSession, worker: FleetWorker + ): + project = await create_project(session) + fleet = await create_fleet( + session=session, + project=project, + spec=get_fleet_spec( + conf=get_fleet_configuration( + placement=InstanceGroupPlacement.CLUSTER, + nodes=FleetNodesSpec(min=2, target=2, max=2), + ) + ), + ) + await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.PENDING, + job_provisioning_data=None, + offer=None, + instance_num=0, + ) + current_master = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.PROVISIONING, + job_provisioning_data=get_job_provisioning_data(), + instance_num=1, + ) + fleet.current_master_instance_id = current_master.id + await _lock_fleet_for_processing(session, fleet) + + await worker.process(_fleet_to_pipeline_item(fleet)) + + await session.refresh(fleet) + assert fleet.current_master_instance_id == 
current_master.id + + async def test_promotes_provisioned_survivor_when_current_master_terminated( + self, test_db, session: AsyncSession, worker: FleetWorker + ): + project = await create_project(session) + fleet = await create_fleet( + session=session, + project=project, + spec=get_fleet_spec( + conf=get_fleet_configuration( + placement=InstanceGroupPlacement.CLUSTER, + nodes=FleetNodesSpec(min=1, target=1, max=2), + ) + ), + ) + terminated_master = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.TERMINATED, + job_provisioning_data=None, + offer=None, + instance_num=0, + ) + provisioned_survivor = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.IDLE, + job_provisioning_data=get_job_provisioning_data(), + instance_num=1, + ) + fleet.current_master_instance_id = terminated_master.id + await _lock_fleet_for_processing(session, fleet) + + await worker.process(_fleet_to_pipeline_item(fleet)) + + await session.refresh(fleet) + await session.refresh(terminated_master) + assert terminated_master.deleted + assert fleet.current_master_instance_id == provisioned_survivor.id + + async def test_promotes_next_bootstrap_candidate_when_current_master_terminated( + self, test_db, session: AsyncSession, worker: FleetWorker + ): + project = await create_project(session) + fleet = await create_fleet( + session=session, + project=project, + spec=get_fleet_spec( + conf=get_fleet_configuration( + placement=InstanceGroupPlacement.CLUSTER, + nodes=FleetNodesSpec(min=1, target=1, max=2), + ) + ), + ) + terminated_master = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.TERMINATED, + job_provisioning_data=None, + offer=None, + instance_num=0, + ) + next_candidate = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.PENDING, + job_provisioning_data=None, + offer=None, + instance_num=1, + ) 
+ fleet.current_master_instance_id = terminated_master.id + await _lock_fleet_for_processing(session, fleet) + + await worker.process(_fleet_to_pipeline_item(fleet)) + + await session.refresh(fleet) + assert fleet.current_master_instance_id == next_candidate.id + + async def test_clears_current_master_for_non_cluster_fleet( + self, test_db, session: AsyncSession, worker: FleetWorker + ): + project = await create_project(session) + fleet = await create_fleet( + session=session, + project=project, + spec=get_fleet_spec(), + ) + instance = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.IDLE, + ) + fleet.current_master_instance_id = instance.id + await _lock_fleet_for_processing(session, fleet) + + await worker.process(_fleet_to_pipeline_item(fleet)) + + await session.refresh(fleet) + assert fleet.current_master_instance_id is None + + async def test_syncs_current_master_after_creating_missing_instances( + self, test_db, session: AsyncSession, worker: FleetWorker + ): + project = await create_project(session) + fleet = await create_fleet( + session=session, + project=project, + spec=get_fleet_spec( + conf=get_fleet_configuration( + placement=InstanceGroupPlacement.CLUSTER, + nodes=FleetNodesSpec(min=2, target=2, max=2), + ) + ), + ) + await _lock_fleet_for_processing(session, fleet) + + await worker.process(_fleet_to_pipeline_item(fleet)) + + await session.refresh(fleet) + instances = ( + ( + await session.execute( + select(InstanceModel) + .where(InstanceModel.fleet_id == fleet.id, InstanceModel.deleted == False) + .order_by(InstanceModel.instance_num, InstanceModel.created_at) + ) + ) + .scalars() + .all() + ) + assert len(instances) == 2 + assert fleet.current_master_instance_id == instances[0].id + async def test_deletes_empty_autocreated_fleet( self, test_db, session: AsyncSession, worker: FleetWorker ): diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_instances/test_cloud_provisioning.py 
b/src/tests/_internal/server/background/pipeline_tasks/test_instances/test_cloud_provisioning.py index e34a758456..afcb75336b 100644 --- a/src/tests/_internal/server/background/pipeline_tasks/test_instances/test_cloud_provisioning.py +++ b/src/tests/_internal/server/background/pipeline_tasks/test_instances/test_cloud_provisioning.py @@ -1,5 +1,3 @@ -import datetime as dt -from collections import defaultdict from typing import Optional from unittest.mock import Mock, patch @@ -29,6 +27,7 @@ ComputeMockSpec, create_fleet, create_instance, + create_placement_group, create_project, get_fleet_configuration, get_fleet_spec, @@ -37,10 +36,17 @@ get_placement_group_provisioning_data, ) from tests._internal.server.background.pipeline_tasks.test_instances.helpers import ( + instance_to_pipeline_item, + lock_instance, process_instance, ) +async def _set_current_master_instance(session: AsyncSession, fleet, instance) -> None: + fleet.current_master_instance_id = None if instance is None else instance.id + await session.commit() + + @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) class TestCloudProvisioning: @@ -209,31 +215,99 @@ async def test_fails_if_no_offers( assert instance.status == InstanceStatus.TERMINATED assert instance.termination_reason == InstanceTerminationReason.NO_OFFERS - @pytest.mark.parametrize( - ("placement", "expected_termination_reasons"), - [ - pytest.param( - InstanceGroupPlacement.CLUSTER, - { - InstanceTerminationReason.NO_OFFERS: 1, - InstanceTerminationReason.MASTER_FAILED: 3, - }, - id="cluster", + async def test_waits_when_fleet_has_no_current_master( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + ): + project = await create_project(session=session) + fleet = await create_fleet( + session, + project, + spec=get_fleet_spec( + conf=get_fleet_configuration( + placement=InstanceGroupPlacement.CLUSTER, + nodes=FleetNodesSpec(min=2, target=2, max=2), + ) ), - pytest.param( - None, 
- {InstanceTerminationReason.NO_OFFERS: 4}, - id="non-cluster", + ) + instance = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.PENDING, + offer=None, + job_provisioning_data=None, + instance_num=0, + ) + + backend_mock = Mock() + backend_mock.TYPE = BackendType.AWS + backend_mock.compute.return_value = Mock(spec=ComputeMockSpec) + with patch("dstack._internal.server.services.backends.get_project_backends") as m: + m.return_value = [backend_mock] + await process_instance(session, worker, instance) + + await session.refresh(instance) + assert instance.status == InstanceStatus.PENDING + assert backend_mock.compute.return_value.create_instance.call_count == 0 + + async def test_waits_for_current_master_to_determine_cluster_placement( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + ): + project = await create_project(session=session) + fleet = await create_fleet( + session, + project, + spec=get_fleet_spec( + conf=get_fleet_configuration( + placement=InstanceGroupPlacement.CLUSTER, + nodes=FleetNodesSpec(min=2, target=2, max=2), + ) ), - ], - ) - async def test_terminates_cluster_instances_if_master_not_created( + ) + master_instance = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.PENDING, + offer=None, + job_provisioning_data=None, + instance_num=0, + ) + sibling_instance = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.PENDING, + offer=None, + job_provisioning_data=None, + instance_num=1, + ) + await _set_current_master_instance(session, fleet, master_instance) + + backend_mock = Mock() + backend_mock.TYPE = BackendType.AWS + backend_mock.compute.return_value = Mock(spec=ComputeMockSpec) + with patch("dstack._internal.server.services.backends.get_project_backends") as m: + m.return_value = [backend_mock] + await process_instance(session, worker, sibling_instance) + + await 
session.refresh(master_instance) + await session.refresh(sibling_instance) + assert master_instance.status == InstanceStatus.PENDING + assert sibling_instance.status == InstanceStatus.PENDING + assert backend_mock.compute.return_value.create_instance.call_count == 0 + + async def test_failed_master_does_not_provision_stale_sibling_until_fleet_reassigns_it( self, test_db, session: AsyncSession, worker: InstanceWorker, - placement: Optional[InstanceGroupPlacement], - expected_termination_reasons: dict[str, int], ): project = await create_project(session=session) fleet = await create_fleet( @@ -241,35 +315,376 @@ async def test_terminates_cluster_instances_if_master_not_created( project, spec=get_fleet_spec( conf=get_fleet_configuration( - placement=placement, nodes=FleetNodesSpec(min=4, target=4, max=4) + placement=InstanceGroupPlacement.CLUSTER, + nodes=FleetNodesSpec(min=2, target=2, max=2), ) ), ) - instances = [ - await create_instance( - session=session, - project=project, - fleet=fleet, - status=InstanceStatus.PENDING, - offer=None, - job_provisioning_data=None, - instance_num=index, - created_at=dt.datetime.now(dt.timezone.utc) + dt.timedelta(seconds=index), + master_instance = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.PENDING, + offer=None, + job_provisioning_data=None, + instance_num=0, + ) + sibling_instance = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.PENDING, + offer=None, + job_provisioning_data=None, + instance_num=1, + ) + await _set_current_master_instance(session, fleet, master_instance) + + lock_instance(master_instance) + lock_instance(sibling_instance) + await session.commit() + master_item = instance_to_pipeline_item(master_instance) + sibling_item = instance_to_pipeline_item(sibling_instance) + + with patch("dstack._internal.server.services.backends.get_project_backends") as m: + m.return_value = [] + await 
worker.process(master_item) + + await session.refresh(master_instance) + await session.refresh(sibling_instance) + assert master_instance.status == InstanceStatus.TERMINATED + assert master_instance.termination_reason == InstanceTerminationReason.NO_OFFERS + assert sibling_instance.status == InstanceStatus.PENDING + + gcp_mock = Mock() + gcp_mock.TYPE = BackendType.GCP + gcp_mock.compute.return_value = Mock(spec=ComputeMockSpec) + gcp_mock.compute.return_value.get_offers.return_value = [ + get_instance_offer_with_availability(backend=BackendType.GCP, region="us-central1") + ] + gcp_mock.compute.return_value.create_instance.return_value = get_job_provisioning_data( + backend=BackendType.GCP, + region="us-central1", + ) + aws_mock = Mock() + aws_mock.TYPE = BackendType.AWS + aws_mock.compute.return_value = Mock(spec=ComputeMockSpec) + aws_mock.compute.return_value.get_offers.return_value = [ + get_instance_offer_with_availability(backend=BackendType.AWS, region="us-east-1") + ] + aws_mock.compute.return_value.create_placement_group.return_value = ( + get_placement_group_provisioning_data() + ) + aws_mock.compute.return_value.create_instance.return_value = get_job_provisioning_data( + backend=BackendType.AWS, + region="us-east-1", + ) + + with patch("dstack._internal.server.services.backends.get_project_backends") as m: + m.return_value = [gcp_mock, aws_mock] + await worker.process(sibling_item) + + await session.refresh(sibling_instance) + assert sibling_instance.status == InstanceStatus.PENDING + assert gcp_mock.compute.return_value.get_offers.call_count == 0 + assert gcp_mock.compute.return_value.create_instance.call_count == 0 + assert aws_mock.compute.return_value.create_instance.call_count == 0 + + await _set_current_master_instance(session, fleet, sibling_instance) + promoted_backend_mock = Mock() + promoted_backend_mock.TYPE = BackendType.AWS + promoted_backend_mock.compute.return_value = Mock(spec=ComputeMockSpec) + 
promoted_backend_mock.compute.return_value.get_offers.return_value = [ + get_instance_offer_with_availability(backend=BackendType.AWS, region="us-east-1") + ] + promoted_backend_mock.compute.return_value.create_placement_group.return_value = ( + get_placement_group_provisioning_data() + ) + promoted_backend_mock.compute.return_value.create_instance.return_value = ( + get_job_provisioning_data( + backend=BackendType.AWS, + region="us-east-1", ) - for index in range(4) + ) + with patch("dstack._internal.server.services.backends.get_project_backends") as m: + m.return_value = [promoted_backend_mock] + await process_instance(session, worker, sibling_instance) + + await session.refresh(sibling_instance) + assert sibling_instance.status == InstanceStatus.PROVISIONING + assert sibling_instance.backend == BackendType.AWS + assert sibling_instance.region == "us-east-1" + assert promoted_backend_mock.compute.return_value.create_instance.call_count == 1 + + async def test_follows_current_master_backend_and_region_constraints( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + ): + project = await create_project(session=session) + fleet = await create_fleet( + session, + project, + spec=get_fleet_spec( + conf=get_fleet_configuration( + placement=InstanceGroupPlacement.CLUSTER, + nodes=FleetNodesSpec(min=2, target=2, max=2), + ) + ), + ) + master_instance = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.IDLE, + job_provisioning_data=get_job_provisioning_data( + backend=BackendType.AWS, + region="us-east-1", + ), + instance_num=0, + ) + sibling_instance = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.PENDING, + offer=None, + job_provisioning_data=None, + instance_num=1, + ) + await _set_current_master_instance(session, fleet, master_instance) + + gcp_mock = Mock() + gcp_mock.TYPE = BackendType.GCP + gcp_mock.compute.return_value = 
Mock(spec=ComputeMockSpec) + gcp_mock.compute.return_value.get_offers.return_value = [ + get_instance_offer_with_availability(backend=BackendType.GCP, region="us-central1") ] + gcp_mock.compute.return_value.create_instance.return_value = get_job_provisioning_data( + backend=BackendType.GCP, + region="us-central1", + ) + aws_mock = Mock() + aws_mock.TYPE = BackendType.AWS + aws_mock.compute.return_value = Mock(spec=ComputeMockSpec) + aws_mock.compute.return_value.get_offers.return_value = [ + get_instance_offer_with_availability(backend=BackendType.AWS, region="us-east-1") + ] + aws_mock.compute.return_value.create_instance.return_value = get_job_provisioning_data( + backend=BackendType.AWS, + region="us-east-1", + ) with patch("dstack._internal.server.services.backends.get_project_backends") as m: - m.return_value = [] - for instance in sorted(instances, key=lambda i: (i.instance_num, i.created_at)): - await session.refresh(instance) - await process_instance(session, worker, instance) + m.return_value = [gcp_mock, aws_mock] + await process_instance(session, worker, sibling_instance) + + await session.refresh(sibling_instance) + assert sibling_instance.status == InstanceStatus.PROVISIONING + assert sibling_instance.backend == BackendType.AWS + assert sibling_instance.region == "us-east-1" + assert gcp_mock.compute.return_value.get_offers.call_count == 0 + assert gcp_mock.compute.return_value.create_instance.call_count == 0 + assert aws_mock.compute.return_value.create_instance.call_count == 1 + + async def test_non_master_does_not_create_new_placement_group_without_master_pg( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + ): + project = await create_project(session=session) + fleet = await create_fleet( + session, + project, + spec=get_fleet_spec( + conf=get_fleet_configuration( + placement=InstanceGroupPlacement.CLUSTER, + nodes=FleetNodesSpec(min=2, target=2, max=2), + ) + ), + ) + master_instance = await create_instance( + 
session=session, + project=project, + fleet=fleet, + status=InstanceStatus.IDLE, + job_provisioning_data=get_job_provisioning_data( + backend=BackendType.AWS, + region="us-east-1", + ), + instance_num=0, + ) + sibling_instance = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.PENDING, + offer=None, + job_provisioning_data=None, + instance_num=1, + ) + await _set_current_master_instance(session, fleet, master_instance) + + backend_mock = Mock() + backend_mock.TYPE = BackendType.AWS + backend_mock.compute.return_value = Mock(spec=ComputeMockSpec) + backend_mock.compute.return_value.get_offers.return_value = [ + get_instance_offer_with_availability(backend=BackendType.AWS, region="us-east-1") + ] + backend_mock.compute.return_value.is_suitable_placement_group.return_value = True + backend_mock.compute.return_value.create_instance.return_value = get_job_provisioning_data( + backend=BackendType.AWS, + region="us-east-1", + ) + with patch("dstack._internal.server.services.backends.get_project_backends") as m: + m.return_value = [backend_mock] + await process_instance(session, worker, sibling_instance) + + await session.refresh(sibling_instance) + assert sibling_instance.status == InstanceStatus.PROVISIONING + assert backend_mock.compute.return_value.create_placement_group.call_count == 0 + placement_groups = (await session.execute(select(PlacementGroupModel))).scalars().all() + assert len(placement_groups) == 0 + + async def test_non_master_reuses_existing_current_master_placement_group( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + ): + project = await create_project(session=session) + fleet = await create_fleet( + session, + project, + spec=get_fleet_spec( + conf=get_fleet_configuration( + placement=InstanceGroupPlacement.CLUSTER, + nodes=FleetNodesSpec(min=3, target=3, max=3), + ) + ), + ) + master_instance = await create_instance( + session=session, + project=project, + fleet=fleet, + 
status=InstanceStatus.IDLE, + job_provisioning_data=get_job_provisioning_data( + backend=BackendType.AWS, + region="us-east-1", + ), + instance_num=0, + ) + current_master_pg = await create_placement_group( + session=session, + project=project, + fleet=fleet, + ) + sibling_instance = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.PENDING, + offer=None, + job_provisioning_data=None, + instance_num=1, + ) + await _set_current_master_instance(session, fleet, master_instance) + + backend_mock = Mock() + backend_mock.TYPE = BackendType.AWS + backend_mock.compute.return_value = Mock(spec=ComputeMockSpec) + backend_mock.compute.return_value.get_offers.return_value = [ + get_instance_offer_with_availability(backend=BackendType.AWS, region="us-east-1") + ] + backend_mock.compute.return_value.is_suitable_placement_group.return_value = True + backend_mock.compute.return_value.create_instance.return_value = get_job_provisioning_data( + backend=BackendType.AWS, + region="us-east-1", + ) + with patch("dstack._internal.server.services.backends.get_project_backends") as m: + m.return_value = [backend_mock] + await process_instance(session, worker, sibling_instance) + + await session.refresh(sibling_instance) + assert sibling_instance.status == InstanceStatus.PROVISIONING + assert backend_mock.compute.return_value.create_placement_group.call_count == 0 + create_call = backend_mock.compute.return_value.create_instance.call_args + assert create_call is not None + assert create_call.args[2] is not None + assert create_call.args[2].name == current_master_pg.name + placement_groups = (await session.execute(select(PlacementGroupModel))).scalars().all() + assert len(placement_groups) == 1 + + async def test_allows_parallel_processing_after_master_is_provisioned( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + ): + project = await create_project(session=session) + fleet = await create_fleet( + session, + project, 
+ spec=get_fleet_spec( + conf=get_fleet_configuration( + placement=InstanceGroupPlacement.CLUSTER, + nodes=FleetNodesSpec(min=3, target=3, max=3), + ) + ), + ) + master_instance = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.IDLE, + job_provisioning_data=get_job_provisioning_data( + backend=BackendType.AWS, + region="us-east-1", + ), + instance_num=0, + ) + later_instance = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.PENDING, + offer=None, + job_provisioning_data=None, + instance_num=2, + ) + earlier_instance = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.PENDING, + offer=None, + job_provisioning_data=None, + instance_num=1, + ) + await _set_current_master_instance(session, fleet, master_instance) + + backend_mock = Mock() + backend_mock.TYPE = BackendType.AWS + backend_mock.compute.return_value = Mock(spec=ComputeMockSpec) + backend_mock.compute.return_value.get_offers.return_value = [ + get_instance_offer_with_availability(backend=BackendType.AWS, region="us-east-1") + ] + backend_mock.compute.return_value.create_instance.return_value = get_job_provisioning_data( + backend=BackendType.AWS, + region="us-east-1", + ) + with patch("dstack._internal.server.services.backends.get_project_backends") as m: + m.return_value = [backend_mock] + await process_instance(session, worker, later_instance) + assert backend_mock.compute.return_value.create_instance.call_count == 1 + await process_instance(session, worker, earlier_instance) - termination_reasons = defaultdict(int) - for instance in instances: - await session.refresh(instance) - assert instance.status == InstanceStatus.TERMINATED - termination_reasons[instance.termination_reason] += 1 - assert termination_reasons == expected_termination_reasons + await session.refresh(later_instance) + await session.refresh(earlier_instance) + assert 
later_instance.status == InstanceStatus.PROVISIONING + assert earlier_instance.status == InstanceStatus.PROVISIONING + assert backend_mock.compute.return_value.create_instance.call_count == 2 @pytest.mark.parametrize( ("placement", "should_create"), @@ -304,6 +719,8 @@ async def test_create_placement_group_if_placement_cluster( offer=None, job_provisioning_data=None, ) + if placement == InstanceGroupPlacement.CLUSTER: + await _set_current_master_instance(session, fleet, instance) backend_mock = Mock() backend_mock.TYPE = BackendType.AWS backend_mock.compute.return_value = Mock(spec=ComputeMockSpec) @@ -357,6 +774,7 @@ async def test_reuses_placement_group_between_offers_if_the_group_is_suitable( offer=None, job_provisioning_data=None, ) + await _set_current_master_instance(session, fleet, instance) backend_mock = Mock() backend_mock.TYPE = BackendType.AWS backend_mock.compute.return_value = Mock(spec=ComputeMockSpec) @@ -421,6 +839,7 @@ async def test_handles_create_placement_group_errors( offer=None, job_provisioning_data=None, ) + await _set_current_master_instance(session, fleet, instance) backend_mock = Mock() backend_mock.TYPE = BackendType.AWS backend_mock.compute.return_value = Mock(spec=ComputeMockSpec) From 30ca68f465080ea1f7f7154693d7b8ca55e31575 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Thu, 5 Mar 2026 10:54:55 +0500 Subject: [PATCH 37/51] Refactor current_master_instance --- .../background/pipeline_tasks/fleets.py | 159 +++++++++-------- .../instances/cloud_provisioning.py | 163 +++++++++++------- .../scheduled_tasks/submitted_jobs.py | 4 + ...4d986_add_fleet_current_master_instance.py | 8 - ...dd_ix_fleets_current_master_instance_id.py | 42 +++++ .../_internal/server/services/fleets.py | 5 + .../_internal/server/services/instances.py | 4 +- .../background/pipeline_tasks/test_fleets.py | 161 +++++++---------- 8 files changed, 294 insertions(+), 252 deletions(-) create mode 100644 
src/dstack/_internal/server/migrations/versions/2026/03_05_0545_c7b0a8e57294_add_ix_fleets_current_master_instance_id.py diff --git a/src/dstack/_internal/server/background/pipeline_tasks/fleets.py b/src/dstack/_internal/server/background/pipeline_tasks/fleets.py index df0d5b809a..412ec353f9 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/fleets.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/fleets.py @@ -11,7 +11,6 @@ from dstack._internal.core.models.fleets import ( FleetSpec, FleetStatus, - InstanceGroupPlacement, ) from dstack._internal.core.models.instances import InstanceStatus, InstanceTerminationReason from dstack._internal.core.models.runs import RunStatus @@ -45,6 +44,7 @@ emit_fleet_status_change_event, get_fleet_spec, get_next_instance_num, + is_cloud_cluster, is_fleet_empty, is_fleet_in_use, ) @@ -272,20 +272,12 @@ async def process(self, item: PipelineItem): update(InstanceModel), instance_update_rows, ) - if result.new_instances_count > 0: + if len(result.new_instance_creates) > 0: await _create_missing_fleet_instances( session=session, fleet_model=fleet_model, - new_instances_count=result.new_instances_count, + new_instance_creates=result.new_instance_creates, ) - await session.flush() - # FleetPipeline is the sole owner of cluster master election. - # Sync it after instance updates and creation so the pointer reflects - # the final post-consolidation fleet state that will be committed. 
- await _sync_current_master_instance( - session=session, - fleet_model_id=fleet_model.id, - ) emit_fleet_status_change_event( session=session, fleet_model=fleet_model, @@ -302,6 +294,7 @@ class _FleetUpdateMap(ItemUpdateMap, total=False): deleted_at: UpdateMapDateTime consolidation_attempt: int last_consolidated_at: UpdateMapDateTime + current_master_instance_id: Optional[uuid.UUID] class _InstanceUpdateMap(TypedDict, total=False): @@ -398,18 +391,23 @@ async def _lock_fleet_instances_for_consolidation( class _ProcessResult: fleet_update_map: _FleetUpdateMap = field(default_factory=_FleetUpdateMap) instance_id_to_update_map: dict[uuid.UUID, _InstanceUpdateMap] = field(default_factory=dict) - new_instances_count: int = 0 + new_instance_creates: list["_NewInstanceCreate"] = field(default_factory=list) + + +class _NewInstanceCreate(TypedDict): + id: uuid.UUID + instance_num: int @dataclass class _MaintainNodesResult: instance_id_to_update_map: dict[uuid.UUID, _InstanceUpdateMap] = field(default_factory=dict) - new_instances_count: int = 0 + new_instance_creates: list[_NewInstanceCreate] = field(default_factory=list) changes_required: bool = False @property def has_changes(self) -> bool: - return len(self.instance_id_to_update_map) > 0 or self.new_instances_count > 0 + return len(self.instance_id_to_update_map) > 0 or len(self.new_instance_creates) > 0 async def _process_fleet( @@ -418,36 +416,40 @@ async def _process_fleet( consolidation_instances: Optional[Sequence[InstanceModel]] = None, ) -> _ProcessResult: result = _ProcessResult() + effective_instances = list(consolidation_instances or fleet_model.instances) if consolidation_fleet_spec is not None: result = _consolidate_fleet_state_with_spec( fleet_model, consolidation_fleet_spec=consolidation_fleet_spec, - consolidation_instances=consolidation_instances, + consolidation_instances=effective_instances, ) - if result.new_instances_count > 0: - # Avoid deleting fleets that are about to provision new instances. 
- return result - delete = _should_delete_fleet(fleet_model) - if delete: + if len(result.new_instance_creates) == 0 and _should_delete_fleet(fleet_model): result.fleet_update_map["status"] = FleetStatus.TERMINATED result.fleet_update_map["deleted"] = True result.fleet_update_map["deleted_at"] = NOW_PLACEHOLDER + _set_current_master_instance_id_update( + fleet_model=fleet_model, + fleet_update_map=result.fleet_update_map, + instance_models=effective_instances, + instance_id_to_update_map=result.instance_id_to_update_map, + new_instance_creates=result.new_instance_creates, + ) return result def _consolidate_fleet_state_with_spec( fleet_model: FleetModel, consolidation_fleet_spec: FleetSpec, - consolidation_instances: Optional[Sequence[InstanceModel]] = None, + consolidation_instances: Sequence[InstanceModel], ) -> _ProcessResult: result = _ProcessResult() maintain_nodes_result = _maintain_fleet_nodes_in_min_max_range( - instances=consolidation_instances or fleet_model.instances, + instances=consolidation_instances, fleet_spec=consolidation_fleet_spec, ) if maintain_nodes_result.has_changes: result.instance_id_to_update_map = maintain_nodes_result.instance_id_to_update_map - result.new_instances_count = maintain_nodes_result.new_instances_count + result.new_instance_creates = maintain_nodes_result.new_instance_creates if maintain_nodes_result.changes_required: result.fleet_update_map["consolidation_attempt"] = fleet_model.consolidation_attempt + 1 else: @@ -507,7 +509,13 @@ def _maintain_fleet_nodes_in_min_max_range( if active_instances_num < fleet_spec.configuration.nodes.min: result.changes_required = True nodes_missing = fleet_spec.configuration.nodes.min - active_instances_num - result.new_instances_count = nodes_missing + taken_instance_nums = {instance.instance_num for instance in active_instances} + for _ in range(nodes_missing): + instance_num = get_next_instance_num(taken_instance_nums) + taken_instance_nums.add(instance_num) + 
result.new_instance_creates.append( + _NewInstanceCreate(id=uuid.uuid4(), instance_num=instance_num) + ) return result if ( fleet_spec.configuration.nodes.max is None @@ -572,28 +580,20 @@ def _build_instance_update_rows( async def _create_missing_fleet_instances( session: AsyncSession, fleet_model: FleetModel, - new_instances_count: int, + new_instance_creates: Sequence[_NewInstanceCreate], ): fleet_spec = get_fleet_spec(fleet_model) - res = await session.execute( - select(InstanceModel.instance_num).where( - InstanceModel.fleet_id == fleet_model.id, - InstanceModel.deleted == False, - ) - ) - taken_instance_nums = set(res.scalars().all()) - for _ in range(new_instances_count): - instance_num = get_next_instance_num(taken_instance_nums) + for new_instance_create in new_instance_creates: instance_model = create_fleet_instance_model( session=session, project=fleet_model.project, # TODO: Store fleet.user and pass it instead of the project owner. username=fleet_model.project.owner.name, spec=fleet_spec, - instance_num=instance_num, + instance_num=new_instance_create["instance_num"], + instance_id=new_instance_create["id"], ) instance_model.fleet_id = fleet_model.id - taken_instance_nums.add(instance_num) events.emit( session=session, message=( @@ -605,82 +605,77 @@ async def _create_missing_fleet_instances( ) logger.info( "Added %d instances to fleet %s", - new_instances_count, + len(new_instance_creates), fleet_model.name, ) -async def _sync_current_master_instance( - session: AsyncSession, - fleet_model_id: uuid.UUID, +def _set_current_master_instance_id_update( + fleet_model: FleetModel, + fleet_update_map: _FleetUpdateMap, + instance_models: Sequence[InstanceModel], + instance_id_to_update_map: dict[uuid.UUID, _InstanceUpdateMap], + new_instance_creates: Sequence[_NewInstanceCreate], ) -> None: - fleet_model = await session.get(FleetModel, fleet_model_id) - if fleet_model is None: + if not is_cloud_cluster(fleet_model): + 
fleet_update_map["current_master_instance_id"] = None return - - new_current_master_instance_id = None - fleet_spec = get_fleet_spec(fleet_model) - is_cluster = ( - fleet_spec.configuration.placement == InstanceGroupPlacement.CLUSTER - and fleet_spec.configuration.ssh_config is None + surviving_instance_models = _get_surviving_instance_models_after_updates( + instance_models=instance_models, + instance_id_to_update_map=instance_id_to_update_map, ) - if not fleet_model.deleted and is_cluster: - res = await session.execute( - select(InstanceModel) - .where( - InstanceModel.fleet_id == fleet_model_id, - InstanceModel.deleted == False, - ) - .order_by(InstanceModel.instance_num, InstanceModel.created_at) - .options( - load_only( - InstanceModel.id, - InstanceModel.status, - InstanceModel.job_provisioning_data, - ) - ) - ) - current_instance_models = list(res.scalars().all()) - new_current_master_instance_id = _select_current_master_instance_id( - current_master_instance_id=fleet_model.current_master_instance_id, - instance_models=current_instance_models, - ) + current_master_instance_id = _select_current_master_instance_id( + current_master_instance_id=fleet_model.current_master_instance_id, + surviving_instance_models=surviving_instance_models, + new_instance_creates=new_instance_creates, + ) + fleet_update_map["current_master_instance_id"] = current_master_instance_id - if fleet_model.current_master_instance_id == new_current_master_instance_id: - return - await session.execute( - update(FleetModel) - .where(FleetModel.id == fleet_model_id) - .values(current_master_instance_id=new_current_master_instance_id) - ) +def _get_surviving_instance_models_after_updates( + instance_models: Sequence[InstanceModel], + instance_id_to_update_map: dict[uuid.UUID, _InstanceUpdateMap], +) -> list[InstanceModel]: + surviving_instance_models = [] + for instance_model in sorted(instance_models, key=lambda i: (i.instance_num, i.created_at)): + instance_update_map = 
instance_id_to_update_map.get(instance_model.id) + if instance_update_map is not None and instance_update_map.get("deleted"): + continue + surviving_instance_models.append(instance_model) + return surviving_instance_models def _select_current_master_instance_id( current_master_instance_id: Optional[uuid.UUID], - instance_models: Sequence[InstanceModel], + surviving_instance_models: Sequence[InstanceModel], + new_instance_creates: Sequence[_NewInstanceCreate], ) -> Optional[uuid.UUID]: # Keep the current master stable while it is still alive so InstancePipeline # does not see fleet-wide election churn between provisioning attempts. if current_master_instance_id is not None: - for instance_model in instance_models: + for instance_model in surviving_instance_models: if ( instance_model.id == current_master_instance_id and instance_model.status != InstanceStatus.TERMINATED ): return instance_model.id - # If the old master is gone, prefer a surviving provisioned instance since it - # already defines backend/region/AZ for the current cluster generation. - for instance_model in instance_models: + # If the old master is gone, prefer a surviving provisioned instance so we + # keep following an already-established cluster placement decision. + for instance_model in surviving_instance_models: if ( instance_model.status != InstanceStatus.TERMINATED and instance_model.job_provisioning_data is not None ): return instance_model.id - for instance_model in instance_models: + # Prefer existing surviving instances over freshly planned replacements to + # avoid election churn during min-nodes backfill. 
+ for instance_model in surviving_instance_models: if instance_model.status != InstanceStatus.TERMINATED: return instance_model.id + for new_instance_create in sorted(new_instance_creates, key=lambda i: i["instance_num"]): + return new_instance_create["id"] + return None diff --git a/src/dstack/_internal/server/background/pipeline_tasks/instances/cloud_provisioning.py b/src/dstack/_internal/server/background/pipeline_tasks/instances/cloud_provisioning.py index 17d59747ce..eac6f3ff30 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/instances/cloud_provisioning.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/instances/cloud_provisioning.py @@ -1,4 +1,5 @@ import uuid +from dataclasses import dataclass from typing import Optional from pydantic import ValidationError @@ -26,6 +27,7 @@ InstanceTerminationReason, ) from dstack._internal.core.models.placement import PlacementGroupConfiguration, PlacementStrategy +from dstack._internal.core.models.runs import JobProvisioningData from dstack._internal.server import settings as server_settings from dstack._internal.server.background.pipeline_tasks.base import NOW_PLACEHOLDER from dstack._internal.server.background.pipeline_tasks.instances.common import ( @@ -54,6 +56,13 @@ logger = get_logger(__name__) +@dataclass +class _ClusterMasterContext: + current_master_instance_model: InstanceModel + is_current_instance_master: bool + master_job_provisioning_data: Optional[JobProvisioningData] + + async def create_cloud_instance(instance_model: InstanceModel) -> ProcessResult: result = ProcessResult() @@ -77,63 +86,21 @@ async def create_cloud_instance(instance_model: InstanceModel) -> ProcessResult: ) return result - current_master_instance_model = None - master_job_provisioning_data = None + cluster_context = None placement_group_models: list[PlacementGroupModel] = [] placement_group_model = None + master_job_provisioning_data = None if instance_model.fleet is not None and 
is_cloud_cluster(instance_model.fleet): - assert instance_model.fleet_id is not None - async with get_session_ctx() as session: - current_master_instance_model = await _load_current_master_instance( - session=session, - fleet_id=instance_model.fleet_id, - ) - placement_group_models = await get_fleet_placement_group_models( - session=session, - fleet_id=instance_model.fleet_id, - ) - if current_master_instance_model is None: - # FleetPipeline elects the current master. Until it does, instance - # workers must wait instead of trying to coordinate bootstrap. - logger.debug( - "%s: waiting for fleet pipeline to elect current cluster master", - fmt(instance_model), - ) + cluster_context = await _get_cluster_master_context(instance_model) + if cluster_context is None: + # Waiting for the master return result - if current_master_instance_model.id != instance_model.id: - if ( - current_master_instance_model.deleted - or current_master_instance_model.status == InstanceStatus.TERMINATED - ): - # Master failover is also owned by FleetPipeline. InstancePipeline - # only terminates the current instance and waits for the next fleet tick. - logger.debug( - "%s: waiting for fleet pipeline to replace current master %s", - fmt(instance_model), - current_master_instance_model.id, - ) - return result - master_job_provisioning_data = get_instance_provisioning_data( - current_master_instance_model - ) - if master_job_provisioning_data is None: - logger.debug( - "%s: waiting for current master %s to determine cluster placement", - fmt(instance_model), - current_master_instance_model.id, - ) - return result - # Non-master instances only reuse the placement group chosen by the - # current master. They never create a new placement group themselves. 
- placement_group_model = _get_current_master_placement_group_model( - placement_group_models=placement_group_models, - fleet_id=instance_model.fleet_id, - ) - if placement_group_model is not None: - _populate_current_master_placement_group_relations( - placement_group_model=placement_group_model, - instance_model=instance_model, - ) + placement_group_models, placement_group_model = await _get_cluster_placement_context( + instance_model=instance_model, + cluster_context=cluster_context, + ) + master_job_provisioning_data = cluster_context.master_job_provisioning_data + offers = await get_create_instance_offers( project=instance_model.project, profile=profile, @@ -152,17 +119,15 @@ async def create_cloud_instance(instance_model: InstanceModel) -> ProcessResult: compute = backend.compute() assert isinstance(compute, ComputeWithCreateInstanceSupport) if master_job_provisioning_data is not None: - # Shared offer lookup already restricts backend and region from the master. + # `get_create_instance_offers()` already restricts backend and region from the master. # Availability zone still has to be narrowed per offer. 
instance_offer = get_instance_offer_with_restricted_az( instance_offer=instance_offer, master_job_provisioning_data=master_job_provisioning_data, ) if ( - instance_model.fleet is not None - and is_cloud_cluster(instance_model.fleet) - and current_master_instance_model is not None - and current_master_instance_model.id == instance_model.id + cluster_context is not None + and cluster_context.is_current_instance_master and instance_offer.backend in BACKENDS_WITH_PLACEMENT_GROUPS_SUPPORT and isinstance(compute, ComputeWithPlacementGroupSupport) and ( @@ -234,8 +199,8 @@ async def create_cloud_instance(instance_model: InstanceModel) -> ProcessResult: if ( instance_model.fleet_id is not None - and current_master_instance_model is not None - and current_master_instance_model.id == instance_model.id + and cluster_context is not None + and cluster_context.is_current_instance_master ): # Clean up placement groups that did not end up being used. result.schedule_pg_deletion_fleet_id = instance_model.fleet_id @@ -253,6 +218,84 @@ async def create_cloud_instance(instance_model: InstanceModel) -> ProcessResult: return result +async def _get_cluster_master_context( + instance_model: InstanceModel, +) -> Optional[_ClusterMasterContext]: + assert instance_model.fleet is not None and is_cloud_cluster(instance_model.fleet) + assert instance_model.fleet_id is not None + async with get_session_ctx() as session: + current_master_instance_model = await _load_current_master_instance( + session=session, + fleet_id=instance_model.fleet_id, + ) + if current_master_instance_model is None: + # FleetPipeline elects the current master. Until it does, instance + # workers must wait instead of trying to coordinate bootstrap. 
+ logger.debug( + "%s: waiting for fleet pipeline to elect current cluster master", + fmt(instance_model), + ) + return None + + is_current_instance_master = current_master_instance_model.id == instance_model.id + master_job_provisioning_data = None + if not is_current_instance_master: + if ( + current_master_instance_model.deleted + or current_master_instance_model.status == InstanceStatus.TERMINATED + ): + # Master failover is also owned by FleetPipeline. InstancePipeline + # only terminates the current instance and waits for the next fleet tick. + logger.debug( + "%s: waiting for fleet pipeline to replace current master %s", + fmt(instance_model), + current_master_instance_model.id, + ) + return None + master_job_provisioning_data = get_instance_provisioning_data( + current_master_instance_model + ) + if master_job_provisioning_data is None: + logger.debug( + "%s: waiting for current master %s to determine cluster placement", + fmt(instance_model), + current_master_instance_model.id, + ) + return None + return _ClusterMasterContext( + current_master_instance_model=current_master_instance_model, + is_current_instance_master=is_current_instance_master, + master_job_provisioning_data=master_job_provisioning_data, + ) + + +async def _get_cluster_placement_context( + instance_model: InstanceModel, + cluster_context: _ClusterMasterContext, +) -> tuple[list[PlacementGroupModel], Optional[PlacementGroupModel]]: + assert instance_model.fleet is not None and is_cloud_cluster(instance_model.fleet) + assert instance_model.fleet_id is not None + async with get_session_ctx() as session: + placement_group_models = await get_fleet_placement_group_models( + session=session, + fleet_id=instance_model.fleet_id, + ) + placement_group_model = None + if not cluster_context.is_current_instance_master: + # Non-master instances only reuse the placement group chosen by the + # current master. They never create a new placement group themselves. 
+ placement_group_model = _get_current_master_placement_group_model( + placement_group_models=placement_group_models, + fleet_id=instance_model.fleet_id, + ) + if placement_group_model is not None: + _populate_current_master_placement_group_relations( + placement_group_model=placement_group_model, + instance_model=instance_model, + ) + return placement_group_models, placement_group_model + + async def _load_current_master_instance( session: AsyncSession, fleet_id: uuid.UUID, diff --git a/src/dstack/_internal/server/background/scheduled_tasks/submitted_jobs.py b/src/dstack/_internal/server/background/scheduled_tasks/submitted_jobs.py index 151f07deeb..729ded205c 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/submitted_jobs.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/submitted_jobs.py @@ -574,6 +574,10 @@ async def _fetch_fleet_with_master_instance_provisioning_data( fleet_model: Optional[FleetModel], job: Job, ) -> Optional[JobProvisioningData]: + # TODO: When submitted-jobs provisioning moves to pipelines, stop inferring the + # cluster master from loaded fleet instances here. Resolve the current master via + # FleetModel.current_master_instance_id so jobs follow the same master election + # as FleetPipeline/InstancePipeline. 
master_instance_provisioning_data = None if is_master_job(job) and fleet_model is not None: fleet = fleet_model_to_fleet(fleet_model) diff --git a/src/dstack/_internal/server/migrations/versions/2026/03_04_1015_9cb8e4e4d986_add_fleet_current_master_instance.py b/src/dstack/_internal/server/migrations/versions/2026/03_04_1015_9cb8e4e4d986_add_fleet_current_master_instance.py index db8653a65c..519b58c899 100644 --- a/src/dstack/_internal/server/migrations/versions/2026/03_04_1015_9cb8e4e4d986_add_fleet_current_master_instance.py +++ b/src/dstack/_internal/server/migrations/versions/2026/03_04_1015_9cb8e4e4d986_add_fleet_current_master_instance.py @@ -27,19 +27,11 @@ def upgrade() -> None: nullable=True, ) ) - batch_op.create_index( - batch_op.f("ix_fleets_current_master_instance_id"), - ["current_master_instance_id"], - unique=False, - ) - # ### end Alembic commands ### def downgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### with op.batch_alter_table("fleets", schema=None) as batch_op: - batch_op.drop_index(batch_op.f("ix_fleets_current_master_instance_id")) batch_op.drop_column("current_master_instance_id") - # ### end Alembic commands ### diff --git a/src/dstack/_internal/server/migrations/versions/2026/03_05_0545_c7b0a8e57294_add_ix_fleets_current_master_instance_id.py b/src/dstack/_internal/server/migrations/versions/2026/03_05_0545_c7b0a8e57294_add_ix_fleets_current_master_instance_id.py new file mode 100644 index 0000000000..29f703c1f2 --- /dev/null +++ b/src/dstack/_internal/server/migrations/versions/2026/03_05_0545_c7b0a8e57294_add_ix_fleets_current_master_instance_id.py @@ -0,0 +1,42 @@ +"""Add ix_fleets_current_master_instance_id index + +Revision ID: c7b0a8e57294 +Revises: 9cb8e4e4d986 +Create Date: 2026-03-05 05:45:00.000000+00:00 + +""" + +from alembic import op + +# revision identifiers, used by Alembic. 
+revision = "c7b0a8e57294" +down_revision = "9cb8e4e4d986" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + with op.get_context().autocommit_block(): + op.drop_index( + "ix_fleets_current_master_instance_id", + table_name="fleets", + if_exists=True, + postgresql_concurrently=True, + ) + op.create_index( + "ix_fleets_current_master_instance_id", + "fleets", + ["current_master_instance_id"], + unique=False, + postgresql_concurrently=True, + ) + + +def downgrade() -> None: + with op.get_context().autocommit_block(): + op.drop_index( + "ix_fleets_current_master_instance_id", + table_name="fleets", + if_exists=True, + postgresql_concurrently=True, + ) diff --git a/src/dstack/_internal/server/services/fleets.py b/src/dstack/_internal/server/services/fleets.py index 23e25c284d..ea72f0f609 100644 --- a/src/dstack/_internal/server/services/fleets.py +++ b/src/dstack/_internal/server/services/fleets.py @@ -587,6 +587,7 @@ def create_fleet_instance_model( username: str, spec: FleetSpec, instance_num: int, + instance_id: Optional[uuid.UUID] = None, ) -> InstanceModel: profile = spec.merged_profile requirements = get_fleet_requirements(spec) @@ -598,6 +599,7 @@ def create_fleet_instance_model( requirements=requirements, instance_name=f"{spec.configuration.name}-{instance_num}", instance_num=instance_num, + instance_id=instance_id, reservation=spec.merged_profile.reservation, blocks=spec.configuration.blocks, tags=spec.configuration.tags, @@ -870,6 +872,9 @@ def get_fleet_master_instance_provisioning_data( ) -> Optional[JobProvisioningData]: master_instance_provisioning_data = None if fleet_spec.configuration.placement == InstanceGroupPlacement.CLUSTER: + # TODO: This legacy helper infers the cluster master from fleet instances. + # Pipeline-based provisioning should use FleetModel.current_master_instance_id + # instead of relying on instance ordering in the loaded relationship. # Offers for master jobs must be in the same cluster as existing instances. 
fleet_instance_models = [im for im in fleet_model.instances if not im.deleted] if len(fleet_instance_models) > 0: diff --git a/src/dstack/_internal/server/services/instances.py b/src/dstack/_internal/server/services/instances.py index 051463c57d..45381ce01a 100644 --- a/src/dstack/_internal/server/services/instances.py +++ b/src/dstack/_internal/server/services/instances.py @@ -642,11 +642,13 @@ def create_instance_model( reservation: Optional[str], blocks: Union[Literal["auto"], int], tags: Optional[Dict[str, str]], + instance_id: Optional[uuid.UUID] = None, ) -> InstanceModel: termination_policy, termination_idle_time = get_termination( profile, DEFAULT_FLEET_TERMINATION_IDLE_TIME ) - instance_id = uuid.uuid4() + if instance_id is None: + instance_id = uuid.uuid4() project_ssh_key = SSHKey( public=project.ssh_public_key.strip(), private=project.ssh_private_key.strip(), diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_fleets.py b/src/tests/_internal/server/background/pipeline_tasks/test_fleets.py index 913abb8a30..40eada4f07 100644 --- a/src/tests/_internal/server/background/pipeline_tasks/test_fleets.py +++ b/src/tests/_internal/server/background/pipeline_tasks/test_fleets.py @@ -6,7 +6,6 @@ import pytest from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import selectinload from dstack._internal.core.models.fleets import ( FleetNodesSpec, @@ -21,10 +20,7 @@ FleetFetcher, FleetPipeline, FleetWorker, - _get_fleet_spec_if_ready_for_consolidation, - _lock_fleet_instances_for_consolidation, ) -from dstack._internal.server.db import get_session_ctx from dstack._internal.server.models import FleetModel, InstanceModel from dstack._internal.server.services.projects import add_project_member from dstack._internal.server.testing.common import ( @@ -183,49 +179,6 @@ async def test_fetch_returns_oldest_fleets_first_up_to_limit( @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", 
"postgres"], indirect=True) class TestFleetWorker: - async def test_ready_for_consolidation_helper_returns_none_for_ssh_fleet( - self, test_db, session: AsyncSession - ): - project = await create_project(session) - fleet = await create_fleet( - session=session, - project=project, - spec=get_fleet_spec(conf=get_ssh_fleet_configuration()), - ) - - assert _get_fleet_spec_if_ready_for_consolidation(fleet) is None - - async def test_ready_for_consolidation_helper_returns_none_when_retry_delay_is_active( - self, test_db, session: AsyncSession - ): - project = await create_project(session) - fleet = await create_fleet( - session=session, - project=project, - spec=get_fleet_spec(), - ) - fleet.consolidation_attempt = 1 - fleet.last_consolidated_at = datetime.now(timezone.utc) - await session.commit() - - assert _get_fleet_spec_if_ready_for_consolidation(fleet) is None - - async def test_ready_for_consolidation_helper_returns_consolidation_fleet_spec_for_eligible_cloud_fleet( - self, test_db, session: AsyncSession - ): - project = await create_project(session) - fleet = await create_fleet( - session=session, - project=project, - spec=get_fleet_spec(), - last_processed_at=datetime(2023, 1, 2, 3, 0, tzinfo=timezone.utc), - ) - - consolidation_fleet_spec = _get_fleet_spec_if_ready_for_consolidation(fleet) - - assert consolidation_fleet_spec is not None - assert consolidation_fleet_spec.configuration.nodes is not None - async def test_skips_instance_locking_for_ssh_fleet( self, test_db, session: AsyncSession, worker: FleetWorker ): @@ -343,60 +296,6 @@ async def test_resets_fleet_lock_when_not_all_instances_can_be_locked( assert fleet.last_processed_at > original_last_processed_at assert locked_elsewhere.lock_owner == "OtherPipeline" - async def test_lock_helper_uses_fresh_current_instances_instead_of_stale_relationship( - self, test_db, session: AsyncSession - ): - project = await create_project(session) - spec = get_fleet_spec() - fleet = await create_fleet( - 
session=session, - project=project, - spec=spec, - ) - await create_instance( - session=session, - project=project, - fleet=fleet, - status=InstanceStatus.IDLE, - instance_num=0, - ) - fleet.lock_token = uuid.uuid4() - fleet.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) - fleet.lock_owner = FleetPipeline.__name__ - await session.commit() - - res = await session.execute( - select(FleetModel) - .where(FleetModel.id == fleet.id) - .options(selectinload(FleetModel.instances.and_(InstanceModel.deleted == False))) - ) - stale_fleet_model = res.unique().scalar_one() - assert len(stale_fleet_model.instances) == 1 - - async with get_session_ctx() as other_session: - project_model = await other_session.get(type(project), project.id) - fleet_model = await other_session.get(FleetModel, fleet.id) - assert project_model is not None - assert fleet_model is not None - await create_instance( - session=other_session, - project=project_model, - fleet=fleet_model, - status=InstanceStatus.IDLE, - instance_num=1, - ) - - assert len(stale_fleet_model.instances) == 1 - - locked_instances = await _lock_fleet_instances_for_consolidation( - session=session, - item=_fleet_to_pipeline_item(fleet), - ) - - assert locked_instances is not None - assert len(locked_instances) == 2 - assert {instance.instance_num for instance in locked_instances} == {0, 1} - async def test_syncs_initial_current_master_for_cluster_fleet( self, test_db, session: AsyncSession, worker: FleetWorker ): @@ -612,6 +511,66 @@ async def test_syncs_current_master_after_creating_missing_instances( assert len(instances) == 2 assert fleet.current_master_instance_id == instances[0].id + async def test_prefers_surviving_instance_over_new_replacement_for_master_election( + self, test_db, session: AsyncSession, worker: FleetWorker + ): + project = await create_project(session) + fleet = await create_fleet( + session=session, + project=project, + spec=get_fleet_spec( + conf=get_fleet_configuration( + 
placement=InstanceGroupPlacement.CLUSTER, + nodes=FleetNodesSpec(min=2, target=2, max=2), + ) + ), + ) + terminated_master = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.TERMINATED, + job_provisioning_data=None, + offer=None, + instance_num=0, + ) + surviving_instance = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.PENDING, + job_provisioning_data=None, + offer=None, + instance_num=1, + ) + fleet.current_master_instance_id = terminated_master.id + await _lock_fleet_for_processing(session, fleet) + + await worker.process(_fleet_to_pipeline_item(fleet)) + + await session.refresh(fleet) + await session.refresh(terminated_master) + await session.refresh(surviving_instance) + non_deleted_instances = ( + ( + await session.execute( + select(InstanceModel) + .where(InstanceModel.fleet_id == fleet.id, InstanceModel.deleted == False) + .order_by(InstanceModel.instance_num, InstanceModel.created_at) + ) + ) + .scalars() + .all() + ) + + assert terminated_master.deleted + assert fleet.current_master_instance_id == surviving_instance.id + assert len(non_deleted_instances) == 2 + assert any( + instance.id != surviving_instance.id and instance.instance_num == 0 + for instance in non_deleted_instances + ) + async def test_deletes_empty_autocreated_fleet( self, test_db, session: AsyncSession, worker: FleetWorker ): From 8367521d7f1e3ea9e3a28b02a18b6004d1bfbd5f Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Thu, 5 Mar 2026 11:42:51 +0500 Subject: [PATCH 38/51] Terminate instances with MASTER_FAILED if the master dies with NO_OFFERS --- .../background/pipeline_tasks/fleets.py | 112 +++++++++++- .../background/pipeline_tasks/test_fleets.py | 163 +++++++++++++++++- 2 files changed, 267 insertions(+), 8 deletions(-) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/fleets.py b/src/dstack/_internal/server/background/pipeline_tasks/fleets.py index 
412ec353f9..b3ed80740a 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/fleets.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/fleets.py @@ -11,6 +11,7 @@ from dstack._internal.core.models.fleets import ( FleetSpec, FleetStatus, + InstanceGroupPlacement, ) from dstack._internal.core.models.instances import InstanceStatus, InstanceTerminationReason from dstack._internal.core.models.runs import RunStatus @@ -44,7 +45,6 @@ emit_fleet_status_change_event, get_fleet_spec, get_next_instance_num, - is_cloud_cluster, is_fleet_empty, is_fleet_in_use, ) @@ -427,7 +427,12 @@ async def _process_fleet( result.fleet_update_map["status"] = FleetStatus.TERMINATED result.fleet_update_map["deleted"] = True result.fleet_update_map["deleted_at"] = NOW_PLACEHOLDER - _set_current_master_instance_id_update( + _set_fail_instances_on_master_bootstrap_failure( + fleet_model=fleet_model, + instance_models=effective_instances, + instance_id_to_update_map=result.instance_id_to_update_map, + ) + _set_current_master_instance_id( fleet_model=fleet_model, fleet_update_map=result.fleet_update_map, instance_models=effective_instances, @@ -610,14 +615,73 @@ async def _create_missing_fleet_instances( ) -def _set_current_master_instance_id_update( +def _set_fail_instances_on_master_bootstrap_failure( + fleet_model: FleetModel, + instance_models: Sequence[InstanceModel], + instance_id_to_update_map: dict[uuid.UUID, _InstanceUpdateMap], +) -> None: + """ + Terminates instances with MASTER_FAILED if the master dies with NO_OFFERS in a cluster with node.min == 0. + This is needed to avoid master re-election loop and fail fast. 
+ """ + fleet_spec = get_fleet_spec(fleet_model) + if ( + not _is_cloud_cluster_fleet_spec(fleet_spec) + or fleet_spec.configuration.nodes is None + or fleet_spec.configuration.nodes.min != 0 + or fleet_model.current_master_instance_id is None + ): + return + + current_master_instance_model = None + for instance_model in instance_models: + if instance_model.id == fleet_model.current_master_instance_id: + current_master_instance_model = instance_model + break + if current_master_instance_model is None: + return + + if ( + current_master_instance_model.status != InstanceStatus.TERMINATED + or current_master_instance_model.termination_reason != InstanceTerminationReason.NO_OFFERS + ): + return + + surviving_instance_models = _get_surviving_instance_models_after_updates( + instance_models=instance_models, + instance_id_to_update_map=instance_id_to_update_map, + ) + if any( + instance_model.status not in InstanceStatus.finished_statuses() + for instance_model in surviving_instance_models + ): + return + + for instance_model in surviving_instance_models: + if ( + instance_model.id == current_master_instance_model.id + or instance_model.status in InstanceStatus.finished_statuses() + ): + continue + update_map = instance_id_to_update_map.setdefault(instance_model.id, _InstanceUpdateMap()) + update_map["status"] = InstanceStatus.TERMINATED + update_map["termination_reason"] = InstanceTerminationReason.MASTER_FAILED + + +def _set_current_master_instance_id( fleet_model: FleetModel, fleet_update_map: _FleetUpdateMap, instance_models: Sequence[InstanceModel], instance_id_to_update_map: dict[uuid.UUID, _InstanceUpdateMap], new_instance_creates: Sequence[_NewInstanceCreate], ) -> None: - if not is_cloud_cluster(fleet_model): + """ + Sets `current_master_instance_id` for `fleet_model`. + Master instance can be changed if the previous master is gone. + If there are no active instances, newly selected master may change backend/region/az/placement. 
+ """ + fleet_spec = get_fleet_spec(fleet_model) + if not _is_cloud_cluster_fleet_spec(fleet_spec): fleet_update_map["current_master_instance_id"] = None return surviving_instance_models = _get_surviving_instance_models_after_updates( @@ -627,6 +691,7 @@ def _set_current_master_instance_id_update( current_master_instance_id = _select_current_master_instance_id( current_master_instance_id=fleet_model.current_master_instance_id, surviving_instance_models=surviving_instance_models, + instance_id_to_update_map=instance_id_to_update_map, new_instance_creates=new_instance_creates, ) fleet_update_map["current_master_instance_id"] = current_master_instance_id @@ -648,6 +713,7 @@ def _get_surviving_instance_models_after_updates( def _select_current_master_instance_id( current_master_instance_id: Optional[uuid.UUID], surviving_instance_models: Sequence[InstanceModel], + instance_id_to_update_map: dict[uuid.UUID, _InstanceUpdateMap], new_instance_creates: Sequence[_NewInstanceCreate], ) -> Optional[uuid.UUID]: # Keep the current master stable while it is still alive so InstancePipeline @@ -656,7 +722,11 @@ def _select_current_master_instance_id( for instance_model in surviving_instance_models: if ( instance_model.id == current_master_instance_id - and instance_model.status != InstanceStatus.TERMINATED + and _get_effective_instance_status( + instance_model, + instance_id_to_update_map=instance_id_to_update_map, + ) + not in InstanceStatus.finished_statuses() ): return instance_model.id @@ -664,7 +734,11 @@ def _select_current_master_instance_id( # keep following an already-established cluster placement decision. 
for instance_model in surviving_instance_models: if ( - instance_model.status != InstanceStatus.TERMINATED + _get_effective_instance_status( + instance_model, + instance_id_to_update_map=instance_id_to_update_map, + ) + not in InstanceStatus.finished_statuses() and instance_model.job_provisioning_data is not None ): return instance_model.id @@ -672,10 +746,34 @@ def _select_current_master_instance_id( # Prefer existing surviving instances over freshly planned replacements to # avoid election churn during min-nodes backfill. for instance_model in surviving_instance_models: - if instance_model.status != InstanceStatus.TERMINATED: + if ( + _get_effective_instance_status( + instance_model, + instance_id_to_update_map=instance_id_to_update_map, + ) + not in InstanceStatus.finished_statuses() + ): return instance_model.id for new_instance_create in sorted(new_instance_creates, key=lambda i: i["instance_num"]): return new_instance_create["id"] return None + + +def _get_effective_instance_status( + instance_model: InstanceModel, + instance_id_to_update_map: dict[uuid.UUID, _InstanceUpdateMap], +) -> InstanceStatus: + update_map = instance_id_to_update_map.get(instance_model.id) + if update_map is None: + return instance_model.status + return update_map.get("status", instance_model.status) + + +def _is_cloud_cluster_fleet_spec(fleet_spec: FleetSpec) -> bool: + configuration = fleet_spec.configuration + return ( + configuration.placement == InstanceGroupPlacement.CLUSTER + and configuration.ssh_config is None + ) diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_fleets.py b/src/tests/_internal/server/background/pipeline_tasks/test_fleets.py index 40eada4f07..ed11dd4c15 100644 --- a/src/tests/_internal/server/background/pipeline_tasks/test_fleets.py +++ b/src/tests/_internal/server/background/pipeline_tasks/test_fleets.py @@ -12,7 +12,7 @@ FleetStatus, InstanceGroupPlacement, ) -from dstack._internal.core.models.instances import InstanceStatus +from 
dstack._internal.core.models.instances import InstanceStatus, InstanceTerminationReason from dstack._internal.core.models.runs import RunStatus from dstack._internal.core.models.users import GlobalRole, ProjectRole from dstack._internal.server.background.pipeline_tasks.base import PipelineItem @@ -455,6 +455,55 @@ async def test_promotes_next_bootstrap_candidate_when_current_master_terminated( await session.refresh(fleet) assert fleet.current_master_instance_id == next_candidate.id + async def test_does_not_elect_terminating_bootstrap_candidate_as_master( + self, test_db, session: AsyncSession, worker: FleetWorker + ): + project = await create_project(session) + fleet = await create_fleet( + session=session, + project=project, + spec=get_fleet_spec( + conf=get_fleet_configuration( + placement=InstanceGroupPlacement.CLUSTER, + nodes=FleetNodesSpec(min=1, target=1, max=3), + ) + ), + ) + terminated_master = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.TERMINATED, + job_provisioning_data=None, + offer=None, + instance_num=0, + ) + await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.TERMINATING, + job_provisioning_data=None, + offer=None, + instance_num=1, + ) + pending_candidate = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.PENDING, + job_provisioning_data=None, + offer=None, + instance_num=2, + ) + fleet.current_master_instance_id = terminated_master.id + await _lock_fleet_for_processing(session, fleet) + + await worker.process(_fleet_to_pipeline_item(fleet)) + + await session.refresh(fleet) + assert fleet.current_master_instance_id == pending_candidate.id + async def test_clears_current_master_for_non_cluster_fleet( self, test_db, session: AsyncSession, worker: FleetWorker ): @@ -571,6 +620,118 @@ async def test_prefers_surviving_instance_over_new_replacement_for_master_electi for instance in 
non_deleted_instances ) + async def test_min_zero_failed_master_terminates_unprovisioned_siblings( + self, test_db, session: AsyncSession, worker: FleetWorker + ): + project = await create_project(session) + fleet = await create_fleet( + session=session, + project=project, + spec=get_fleet_spec( + conf=get_fleet_configuration( + placement=InstanceGroupPlacement.CLUSTER, + nodes=FleetNodesSpec(min=0, target=3, max=3), + ) + ), + ) + failed_master = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.TERMINATED, + job_provisioning_data=None, + offer=None, + instance_num=0, + ) + failed_master.termination_reason = InstanceTerminationReason.NO_OFFERS + sibling1 = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.PENDING, + job_provisioning_data=None, + offer=None, + instance_num=1, + ) + sibling2 = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.PENDING, + job_provisioning_data=None, + offer=None, + instance_num=2, + ) + fleet.current_master_instance_id = failed_master.id + await _lock_fleet_for_processing(session, fleet) + + await worker.process(_fleet_to_pipeline_item(fleet)) + + await session.refresh(fleet) + await session.refresh(failed_master) + await session.refresh(sibling1) + await session.refresh(sibling2) + assert failed_master.deleted + assert sibling1.status == InstanceStatus.TERMINATED + assert sibling2.status == InstanceStatus.TERMINATED + assert sibling1.termination_reason == InstanceTerminationReason.MASTER_FAILED + assert sibling2.termination_reason == InstanceTerminationReason.MASTER_FAILED + assert fleet.current_master_instance_id is None + + async def test_min_zero_failed_master_preserves_provisioned_survivor( + self, test_db, session: AsyncSession, worker: FleetWorker + ): + project = await create_project(session) + fleet = await create_fleet( + session=session, + project=project, + 
spec=get_fleet_spec( + conf=get_fleet_configuration( + placement=InstanceGroupPlacement.CLUSTER, + nodes=FleetNodesSpec(min=0, target=2, max=2), + ) + ), + ) + failed_master = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.TERMINATED, + job_provisioning_data=None, + offer=None, + instance_num=0, + ) + failed_master.termination_reason = InstanceTerminationReason.NO_OFFERS + provisioned_survivor = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.IDLE, + job_provisioning_data=get_job_provisioning_data(), + instance_num=1, + ) + pending_sibling = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.PENDING, + job_provisioning_data=None, + offer=None, + instance_num=2, + ) + fleet.current_master_instance_id = failed_master.id + await _lock_fleet_for_processing(session, fleet) + + await worker.process(_fleet_to_pipeline_item(fleet)) + + await session.refresh(fleet) + await session.refresh(provisioned_survivor) + await session.refresh(pending_sibling) + assert provisioned_survivor.status == InstanceStatus.IDLE + assert pending_sibling.status == InstanceStatus.PENDING + assert pending_sibling.termination_reason is None + assert fleet.current_master_instance_id == provisioned_survivor.id + async def test_deletes_empty_autocreated_fleet( self, test_db, session: AsyncSession, worker: FleetWorker ): From edb225b37ce4c1c2f0d67cd1bc3adabb8fb4b218 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Thu, 5 Mar 2026 12:21:14 +0500 Subject: [PATCH 39/51] Fix instance unlock in fleet pipeline --- .../background/pipeline_tasks/fleets.py | 47 +++++++++++-- .../pipeline_tasks/test_compute_groups.py | 6 +- .../background/pipeline_tasks/test_fleets.py | 67 ++++++++++++++++++- .../test_instances/test_check.py | 6 +- .../pipeline_tasks/test_placement_groups.py | 6 +- 5 files changed, 118 insertions(+), 14 deletions(-) diff --git 
a/src/dstack/_internal/server/background/pipeline_tasks/fleets.py b/src/dstack/_internal/server/background/pipeline_tasks/fleets.py index b3ed80740a..2a63e21bd5 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/fleets.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/fleets.py @@ -221,6 +221,7 @@ async def process(self, item: PipelineItem): return # Lock instance only if consolidation is needed. + locked_instance_ids: set[uuid.UUID] = set() consolidation_fleet_spec = _get_fleet_spec_if_ready_for_consolidation(fleet_model) consolidation_instances = None if consolidation_fleet_spec is not None: @@ -230,6 +231,7 @@ async def process(self, item: PipelineItem): ) if consolidation_instances is None: return + locked_instance_ids = {instance.id for instance in consolidation_instances} result = await _process_fleet( fleet_model, @@ -240,7 +242,10 @@ async def process(self, item: PipelineItem): fleet_update_map.update(result.fleet_update_map) set_processed_update_map_fields(fleet_update_map) set_unlock_update_map_fields(fleet_update_map) - instance_update_rows = _build_instance_update_rows(result.instance_id_to_update_map) + instance_update_rows = _build_instance_update_rows( + result.instance_id_to_update_map, + unlock_instance_ids=locked_instance_ids, + ) async with get_session_ctx() as session: now = get_current_datetime() @@ -258,6 +263,12 @@ async def process(self, item: PipelineItem): updated_ids = list(res.scalars().all()) if len(updated_ids) == 0: log_lock_token_changed_after_processing(logger, item) + if locked_instance_ids: + await _unlock_fleet_locked_instances( + session=session, + item=item, + locked_instance_ids=locked_instance_ids, + ) # TODO: Clean up fleet. 
return @@ -297,7 +308,7 @@ class _FleetUpdateMap(ItemUpdateMap, total=False): current_master_instance_id: Optional[uuid.UUID] -class _InstanceUpdateMap(TypedDict, total=False): +class _InstanceUpdateMap(ItemUpdateMap, total=False): status: InstanceStatus termination_reason: InstanceTerminationReason termination_reason_message: str @@ -571,17 +582,42 @@ def _should_delete_fleet(fleet_model: FleetModel) -> bool: def _build_instance_update_rows( instance_id_to_update_map: dict[uuid.UUID, _InstanceUpdateMap], + unlock_instance_ids: set[uuid.UUID], ) -> list[_InstanceUpdateMap]: instance_update_rows = [] - for instance_id, instance_update_map in instance_id_to_update_map.items(): + for instance_id in sorted(instance_id_to_update_map.keys() | unlock_instance_ids): + instance_update_map = instance_id_to_update_map.get(instance_id) update_row = _InstanceUpdateMap() - update_row.update(instance_update_map) + if instance_update_map is not None: + update_row.update(instance_update_map) + if instance_id in unlock_instance_ids: + set_unlock_update_map_fields(update_row) update_row["id"] = instance_id set_processed_update_map_fields(update_row) instance_update_rows.append(update_row) return instance_update_rows +async def _unlock_fleet_locked_instances( + session: AsyncSession, + item: PipelineItem, + locked_instance_ids: set[uuid.UUID], +) -> None: + await session.execute( + update(InstanceModel) + .where( + InstanceModel.id.in_(locked_instance_ids), + InstanceModel.lock_token == item.lock_token, + InstanceModel.lock_owner == FleetPipeline.__name__, + ) + .values( + lock_expires_at=None, + lock_token=None, + lock_owner=None, + ) + ) + + async def _create_missing_fleet_instances( session: AsyncSession, fleet_model: FleetModel, @@ -653,8 +689,11 @@ def _set_fail_instances_on_master_bootstrap_failure( ) if any( instance_model.status not in InstanceStatus.finished_statuses() + and instance_model.job_provisioning_data is not None for instance_model in surviving_instance_models ): + 
# It should not be possible to provision non-master instances ahead of master + # but we still safe-guard against the case when there can be other instances provisioned. return for instance_model in surviving_instance_models: diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_compute_groups.py b/src/tests/_internal/server/background/pipeline_tasks/test_compute_groups.py index cfc1c48d33..776240fc47 100644 --- a/src/tests/_internal/server/background/pipeline_tasks/test_compute_groups.py +++ b/src/tests/_internal/server/background/pipeline_tasks/test_compute_groups.py @@ -148,9 +148,9 @@ async def test_fetch_returns_oldest_compute_groups_first_up_to_limit( assert newest.lock_owner is None +@pytest.mark.asyncio +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) class TestComputeGroupWorker: - @pytest.mark.asyncio - @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) async def test_terminates_compute_group( self, test_db, session: AsyncSession, worker: ComputeGroupWorker ): @@ -176,8 +176,6 @@ async def test_terminates_compute_group( assert compute_group.status == ComputeGroupStatus.TERMINATED assert compute_group.deleted - @pytest.mark.asyncio - @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) async def test_retries_compute_group_termination( self, test_db, session: AsyncSession, worker: ComputeGroupWorker ): diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_fleets.py b/src/tests/_internal/server/background/pipeline_tasks/test_fleets.py index ed11dd4c15..d2b53226e3 100644 --- a/src/tests/_internal/server/background/pipeline_tasks/test_fleets.py +++ b/src/tests/_internal/server/background/pipeline_tasks/test_fleets.py @@ -1,7 +1,7 @@ import asyncio import uuid from datetime import datetime, timedelta, timezone -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock, patch import pytest from sqlalchemy import select @@ -15,6 +15,7 @@ from 
dstack._internal.core.models.instances import InstanceStatus, InstanceTerminationReason from dstack._internal.core.models.runs import RunStatus from dstack._internal.core.models.users import GlobalRole, ProjectRole +from dstack._internal.server.background.pipeline_tasks import fleets as fleets_pipeline from dstack._internal.server.background.pipeline_tasks.base import PipelineItem from dstack._internal.server.background.pipeline_tasks.fleets import ( FleetFetcher, @@ -296,6 +297,70 @@ async def test_resets_fleet_lock_when_not_all_instances_can_be_locked( assert fleet.last_processed_at > original_last_processed_at assert locked_elsewhere.lock_owner == "OtherPipeline" + async def test_unlocks_instances_after_consolidation( + self, test_db, session: AsyncSession, worker: FleetWorker + ): + project = await create_project(session) + spec = get_fleet_spec() + spec.configuration.nodes = FleetNodesSpec(min=1, target=1, max=1) + fleet = await create_fleet( + session=session, + project=project, + spec=spec, + ) + instance = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.IDLE, + instance_num=0, + ) + await _lock_fleet_for_processing(session, fleet) + + await worker.process(_fleet_to_pipeline_item(fleet)) + + await session.refresh(instance) + assert instance.lock_owner is None + assert instance.lock_token is None + assert instance.lock_expires_at is None + + async def test_unlocks_instances_when_fleet_lock_token_changes_after_processing( + self, test_db, session: AsyncSession, worker: FleetWorker + ): + project = await create_project(session) + spec = get_fleet_spec() + spec.configuration.nodes = FleetNodesSpec(min=1, target=1, max=1) + fleet = await create_fleet( + session=session, + project=project, + spec=spec, + ) + instance = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.IDLE, + instance_num=0, + ) + await _lock_fleet_for_processing(session, fleet) + + async def 
mock_process_fleet(*args, **kwargs): + fleet_model = args[0] + fleet_model.lock_token = uuid.uuid4() + return fleets_pipeline._ProcessResult() + + with patch.object( + fleets_pipeline, + "_process_fleet", + AsyncMock(side_effect=mock_process_fleet), + ): + await worker.process(_fleet_to_pipeline_item(fleet)) + + await session.refresh(instance) + assert instance.lock_owner is None + assert instance.lock_token is None + assert instance.lock_expires_at is None + async def test_syncs_initial_current_master_for_cluster_fleet( self, test_db, session: AsyncSession, worker: FleetWorker ): diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_instances/test_check.py b/src/tests/_internal/server/background/pipeline_tasks/test_instances/test_check.py index 5cfc8f887a..f19986b51d 100644 --- a/src/tests/_internal/server/background/pipeline_tasks/test_instances/test_check.py +++ b/src/tests/_internal/server/background/pipeline_tasks/test_instances/test_check.py @@ -44,7 +44,7 @@ @pytest.mark.asyncio @pytest.mark.usefixtures("image_config_mock") @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) -class TestInstanceCheck: +class TestCheckInstance: async def test_check_shim_transitions_provisioning_on_ready( self, test_db, @@ -397,6 +397,10 @@ async def test_check_shim_check_instance_health( assert health_check.status == HealthStatus.WARNING assert health_check.response == health_response.json() + +@pytest.mark.asyncio +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) +class TestProcessIdleTimeout: async def test_terminate_by_idle_timeout( self, test_db, diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_placement_groups.py b/src/tests/_internal/server/background/pipeline_tasks/test_placement_groups.py index 6fa0ea7682..90c8e75194 100644 --- a/src/tests/_internal/server/background/pipeline_tasks/test_placement_groups.py +++ 
b/src/tests/_internal/server/background/pipeline_tasks/test_placement_groups.py @@ -175,9 +175,9 @@ async def test_fetch_returns_oldest_placement_groups_first_up_to_limit( assert newest.lock_owner is None +@pytest.mark.asyncio +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) class TestPlacementGroupWorker: - @pytest.mark.asyncio - @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) async def test_deletes_placement_group( self, test_db, session: AsyncSession, worker: PlacementGroupWorker ): @@ -205,8 +205,6 @@ async def test_deletes_placement_group( await session.refresh(placement_group) assert placement_group.deleted - @pytest.mark.asyncio - @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) async def test_retries_placement_group_deletion_if_still_in_use( self, test_db, session: AsyncSession, worker: PlacementGroupWorker ): From e2152d8f7e7ead4798b31fecefa56f74da062a65 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Thu, 5 Mar 2026 12:29:31 +0500 Subject: [PATCH 40/51] Remove extra fleet_model_to_fleet --- .../server/background/pipeline_tasks/instances/common.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/instances/common.py b/src/dstack/_internal/server/background/pipeline_tasks/instances/common.py index 06a833007a..729842176e 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/instances/common.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/instances/common.py @@ -20,7 +20,7 @@ ) from dstack._internal.server.background.scheduled_tasks.common import get_provisioning_timeout from dstack._internal.server.models import FleetModel, InstanceModel, PlacementGroupModel -from dstack._internal.server.services.fleets import fleet_model_to_fleet +from dstack._internal.server.services.fleets import get_fleet_spec from dstack._internal.utils.common import UNSET, Unset, 
get_current_datetime from dstack._internal.utils.ssh import pkey_from_str @@ -71,13 +71,13 @@ class ProcessResult: def can_terminate_fleet_instances_on_idle_duration(fleet_model: FleetModel) -> bool: - fleet = fleet_model_to_fleet(fleet_model) - if fleet.spec.configuration.nodes is None or fleet.spec.autocreated: + fleet_spec = get_fleet_spec(fleet_model) + if fleet_spec.configuration.nodes is None or fleet_spec.autocreated: return True active_instances = [ instance for instance in fleet_model.instances if instance.status.is_active() ] - return len(active_instances) > fleet.spec.configuration.nodes.min + return len(active_instances) > fleet_spec.configuration.nodes.min def get_instance_idle_duration(instance_model: InstanceModel) -> datetime.timedelta: From 6e2676343dcffdcc5179c723110e2b732d771e52 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Thu, 5 Mar 2026 13:27:02 +0500 Subject: [PATCH 41/51] Wire pipeline_hinter --- .../background/pipeline_tasks/instances/check.py | 4 ---- src/dstack/_internal/server/routers/fleets.py | 5 +++++ src/dstack/_internal/server/services/fleets.py | 13 ++++++++++++- 3 files changed, 17 insertions(+), 5 deletions(-) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/instances/check.py b/src/dstack/_internal/server/background/pipeline_tasks/instances/check.py index 8ca135ab69..30afd97742 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/instances/check.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/instances/check.py @@ -69,10 +69,6 @@ def process_idle_timeout(instance_model: InstanceModel) -> Optional[ProcessResul if instance_model.fleet is not None and not can_terminate_fleet_instances_on_idle_duration( instance_model.fleet ): - logger.debug( - "Skipping instance %s termination on idle duration. 
Fleet is already at `nodes.min`.", - instance_model.name, - ) return None idle_duration = get_instance_idle_duration(instance_model) diff --git a/src/dstack/_internal/server/routers/fleets.py b/src/dstack/_internal/server/routers/fleets.py index a436d1123a..192403ae3f 100644 --- a/src/dstack/_internal/server/routers/fleets.py +++ b/src/dstack/_internal/server/routers/fleets.py @@ -20,6 +20,7 @@ ListFleetsRequest, ) from dstack._internal.server.security.permissions import Authenticated, ProjectMember +from dstack._internal.server.services.pipelines import PipelineHinterProtocol, get_pipeline_hinter from dstack._internal.server.utils.routers import ( CustomORJSONResponse, get_base_api_additional_responses, @@ -127,6 +128,7 @@ async def apply_plan( body: ApplyFleetPlanRequest, session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()), + pipeline_hinter: PipelineHinterProtocol = Depends(get_pipeline_hinter), ): """ Creates a new fleet or updates an existing fleet. @@ -141,6 +143,7 @@ async def apply_plan( project=project, plan=body.plan, force=body.force, + pipeline_hinter=pipeline_hinter, ) ) @@ -150,6 +153,7 @@ async def create_fleet( body: CreateFleetRequest, session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()), + pipeline_hinter: PipelineHinterProtocol = Depends(get_pipeline_hinter), ): """ Creates a fleet given a fleet configuration. 
@@ -161,6 +165,7 @@ async def create_fleet( project=project, user=user, spec=body.spec, + pipeline_hinter=pipeline_hinter, ) ) diff --git a/src/dstack/_internal/server/services/fleets.py b/src/dstack/_internal/server/services/fleets.py index ea72f0f609..59052ef432 100644 --- a/src/dstack/_internal/server/services/fleets.py +++ b/src/dstack/_internal/server/services/fleets.py @@ -73,6 +73,7 @@ get_locker, string_to_lock_id, ) +from dstack._internal.server.services.pipelines import PipelineHinterProtocol from dstack._internal.server.services.plugins import apply_plugin_policies from dstack._internal.server.services.projects import ( get_member, @@ -467,6 +468,7 @@ async def apply_plan( project: ProjectModel, plan: ApplyFleetPlanInput, force: bool, + pipeline_hinter: PipelineHinterProtocol, ) -> Fleet: spec = await apply_plugin_policies( user=user.name, @@ -487,6 +489,7 @@ async def apply_plan( project=project, user=user, spec=spec, + pipeline_hinter=pipeline_hinter, ) fleet_model = await get_project_fleet_model_by_name( @@ -500,6 +503,7 @@ async def apply_plan( project=project, user=user, spec=spec, + pipeline_hinter=pipeline_hinter, ) instances_ids = sorted(i.id for i in fleet_model.instances if not i.deleted) @@ -557,6 +561,7 @@ async def apply_plan( project=project, user=user, spec=spec, + pipeline_hinter=pipeline_hinter, ) @@ -565,6 +570,7 @@ async def create_fleet( project: ProjectModel, user: UserModel, spec: FleetSpec, + pipeline_hinter: PipelineHinterProtocol, ) -> Fleet: spec = await apply_plugin_policies( user=user.name, @@ -578,7 +584,9 @@ async def create_fleet( if spec.configuration.ssh_config is not None: _check_can_manage_ssh_fleets(user=user, project=project) - return await _create_fleet(session=session, project=project, user=user, spec=spec) + return await _create_fleet( + session=session, project=project, user=user, spec=spec, pipeline_hinter=pipeline_hinter + ) def create_fleet_instance_model( @@ -910,6 +918,7 @@ async def _create_fleet( project: 
ProjectModel, user: UserModel, spec: FleetSpec, + pipeline_hinter: PipelineHinterProtocol, ) -> Fleet: lock_namespace = f"fleet_names_{project.name}" if is_db_sqlite(): @@ -990,6 +999,8 @@ async def _create_fleet( targets=[events.Target.from_model(instance_model)], ) fleet_model.instances.append(instance_model) + pipeline_hinter.hint_fetch(FleetModel.__name__) + pipeline_hinter.hint_fetch(InstanceModel.__name__) await session.commit() return fleet_model_to_fleet(fleet_model) From baa79e314019b44fc79423013d3ce0fcca05a37e Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Thu, 5 Mar 2026 14:43:54 +0500 Subject: [PATCH 42/51] Remove extra fleet_model_to_fleet --- .../pipeline_tasks/instances/cloud_provisioning.py | 5 +---- src/dstack/_internal/server/services/fleets.py | 6 +++--- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/instances/cloud_provisioning.py b/src/dstack/_internal/server/background/pipeline_tasks/instances/cloud_provisioning.py index eac6f3ff30..43c5261d0d 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/instances/cloud_provisioning.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/instances/cloud_provisioning.py @@ -229,8 +229,6 @@ async def _get_cluster_master_context( fleet_id=instance_model.fleet_id, ) if current_master_instance_model is None: - # FleetPipeline elects the current master. Until it does, instance - # workers must wait instead of trying to coordinate bootstrap. logger.debug( "%s: waiting for fleet pipeline to elect current cluster master", fmt(instance_model), @@ -244,8 +242,6 @@ async def _get_cluster_master_context( current_master_instance_model.deleted or current_master_instance_model.status == InstanceStatus.TERMINATED ): - # Master failover is also owned by FleetPipeline. InstancePipeline - # only terminates the current instance and waits for the next fleet tick. 
logger.debug( "%s: waiting for fleet pipeline to replace current master %s", fmt(instance_model), @@ -262,6 +258,7 @@ async def _get_cluster_master_context( current_master_instance_model.id, ) return None + return _ClusterMasterContext( current_master_instance_model=current_master_instance_model, is_current_instance_master=is_current_instance_master, diff --git a/src/dstack/_internal/server/services/fleets.py b/src/dstack/_internal/server/services/fleets.py index 59052ef432..e56049d25c 100644 --- a/src/dstack/_internal/server/services/fleets.py +++ b/src/dstack/_internal/server/services/fleets.py @@ -842,10 +842,10 @@ def is_fleet_empty(fleet_model: FleetModel) -> bool: def is_cloud_cluster(fleet_model: FleetModel) -> bool: - fleet = fleet_model_to_fleet(fleet_model) + fleet_spec = get_fleet_spec(fleet_model) return ( - fleet.spec.configuration.placement == InstanceGroupPlacement.CLUSTER - and fleet.spec.configuration.ssh_config is None + fleet_spec.configuration.placement == InstanceGroupPlacement.CLUSTER + and fleet_spec.configuration.ssh_config is None ) From 1bcbca0962ed722f3fb45fc4cca80af947a50ded Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Thu, 5 Mar 2026 14:47:58 +0500 Subject: [PATCH 43/51] Fix index name --- ..._297c68450cc8_add_ix_instances_pipeline_fetch_q_index.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/dstack/_internal/server/migrations/versions/2026/03_04_0751_297c68450cc8_add_ix_instances_pipeline_fetch_q_index.py b/src/dstack/_internal/server/migrations/versions/2026/03_04_0751_297c68450cc8_add_ix_instances_pipeline_fetch_q_index.py index 5139370104..29c88f14e6 100644 --- a/src/dstack/_internal/server/migrations/versions/2026/03_04_0751_297c68450cc8_add_ix_instances_pipeline_fetch_q_index.py +++ b/src/dstack/_internal/server/migrations/versions/2026/03_04_0751_297c68450cc8_add_ix_instances_pipeline_fetch_q_index.py @@ -20,13 +20,13 @@ def upgrade() -> None: # ### commands auto generated by Alembic - 
please adjust! ### with op.get_context().autocommit_block(): op.drop_index( - "ix_instances_pipeline_fetch_q_index", + "ix_instances_pipeline_fetch_q", table_name="instances", if_exists=True, postgresql_concurrently=True, ) op.create_index( - "ix_instances_pipeline_fetch_q_index", + "ix_instances_pipeline_fetch_q", "instances", [sa.literal_column("last_processed_at ASC")], unique=False, @@ -41,7 +41,7 @@ def downgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### with op.get_context().autocommit_block(): op.drop_index( - "ix_instances_pipeline_fetch_q_index", + "ix_instances_pipeline_fetch_q", table_name="instances", if_exists=True, postgresql_concurrently=True, From 0767d6c9de26789fd245e4a85e6724b8715716d7 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Thu, 5 Mar 2026 15:00:37 +0500 Subject: [PATCH 44/51] Rebase migrations --- ...0547_8e8647f20aa4_add_instancemodel_pipeline_columns.py} | 6 +++--- ...297c68450cc8_add_ix_instances_pipeline_fetch_q_index.py} | 2 +- ..._1015_9cb8e4e4d986_add_fleet_current_master_instance.py} | 2 +- ...7b0a8e57294_add_ix_fleets_current_master_instance_id.py} | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) rename src/dstack/_internal/server/migrations/versions/2026/{03_04_0547_8e8647f20aa4_add_instancemodel_pipeline_columns.py => 03_05_0547_8e8647f20aa4_add_instancemodel_pipeline_columns.py} (92%) rename src/dstack/_internal/server/migrations/versions/2026/{03_04_0751_297c68450cc8_add_ix_instances_pipeline_fetch_q_index.py => 03_05_0751_297c68450cc8_add_ix_instances_pipeline_fetch_q_index.py} (96%) rename src/dstack/_internal/server/migrations/versions/2026/{03_04_1015_9cb8e4e4d986_add_fleet_current_master_instance.py => 03_05_1015_9cb8e4e4d986_add_fleet_current_master_instance.py} (95%) rename src/dstack/_internal/server/migrations/versions/2026/{03_05_0545_c7b0a8e57294_add_ix_fleets_current_master_instance_id.py => 03_05_1045_c7b0a8e57294_add_ix_fleets_current_master_instance_id.py} (95%) diff 
--git a/src/dstack/_internal/server/migrations/versions/2026/03_04_0547_8e8647f20aa4_add_instancemodel_pipeline_columns.py b/src/dstack/_internal/server/migrations/versions/2026/03_05_0547_8e8647f20aa4_add_instancemodel_pipeline_columns.py similarity index 92% rename from src/dstack/_internal/server/migrations/versions/2026/03_04_0547_8e8647f20aa4_add_instancemodel_pipeline_columns.py rename to src/dstack/_internal/server/migrations/versions/2026/03_05_0547_8e8647f20aa4_add_instancemodel_pipeline_columns.py index cc82d95de4..f1c2b1217a 100644 --- a/src/dstack/_internal/server/migrations/versions/2026/03_04_0547_8e8647f20aa4_add_instancemodel_pipeline_columns.py +++ b/src/dstack/_internal/server/migrations/versions/2026/03_05_0547_8e8647f20aa4_add_instancemodel_pipeline_columns.py @@ -1,8 +1,8 @@ """Add InstanceModel pipeline columns Revision ID: 8e8647f20aa4 -Revises: 46150101edec -Create Date: 2026-03-04 05:47:39.307013+00:00 +Revises: 5e8c7a9202bc +Create Date: 2026-03-05 05:47:39.307013+00:00 """ @@ -14,7 +14,7 @@ # revision identifiers, used by Alembic. 
revision = "8e8647f20aa4" -down_revision = "46150101edec" +down_revision = "5e8c7a9202bc" branch_labels = None depends_on = None diff --git a/src/dstack/_internal/server/migrations/versions/2026/03_04_0751_297c68450cc8_add_ix_instances_pipeline_fetch_q_index.py b/src/dstack/_internal/server/migrations/versions/2026/03_05_0751_297c68450cc8_add_ix_instances_pipeline_fetch_q_index.py similarity index 96% rename from src/dstack/_internal/server/migrations/versions/2026/03_04_0751_297c68450cc8_add_ix_instances_pipeline_fetch_q_index.py rename to src/dstack/_internal/server/migrations/versions/2026/03_05_0751_297c68450cc8_add_ix_instances_pipeline_fetch_q_index.py index 29c88f14e6..e629de0950 100644 --- a/src/dstack/_internal/server/migrations/versions/2026/03_04_0751_297c68450cc8_add_ix_instances_pipeline_fetch_q_index.py +++ b/src/dstack/_internal/server/migrations/versions/2026/03_05_0751_297c68450cc8_add_ix_instances_pipeline_fetch_q_index.py @@ -2,7 +2,7 @@ Revision ID: 297c68450cc8 Revises: 8e8647f20aa4 -Create Date: 2026-03-04 07:51:02.855596+00:00 +Create Date: 2026-03-05 07:51:02.855596+00:00 """ diff --git a/src/dstack/_internal/server/migrations/versions/2026/03_04_1015_9cb8e4e4d986_add_fleet_current_master_instance.py b/src/dstack/_internal/server/migrations/versions/2026/03_05_1015_9cb8e4e4d986_add_fleet_current_master_instance.py similarity index 95% rename from src/dstack/_internal/server/migrations/versions/2026/03_04_1015_9cb8e4e4d986_add_fleet_current_master_instance.py rename to src/dstack/_internal/server/migrations/versions/2026/03_05_1015_9cb8e4e4d986_add_fleet_current_master_instance.py index 519b58c899..2049236267 100644 --- a/src/dstack/_internal/server/migrations/versions/2026/03_04_1015_9cb8e4e4d986_add_fleet_current_master_instance.py +++ b/src/dstack/_internal/server/migrations/versions/2026/03_05_1015_9cb8e4e4d986_add_fleet_current_master_instance.py @@ -2,7 +2,7 @@ Revision ID: 9cb8e4e4d986 Revises: 297c68450cc8 -Create Date: 2026-03-04 
10:15:00.000000+00:00 +Create Date: 2026-03-05 10:15:00.000000+00:00 """ diff --git a/src/dstack/_internal/server/migrations/versions/2026/03_05_0545_c7b0a8e57294_add_ix_fleets_current_master_instance_id.py b/src/dstack/_internal/server/migrations/versions/2026/03_05_1045_c7b0a8e57294_add_ix_fleets_current_master_instance_id.py similarity index 95% rename from src/dstack/_internal/server/migrations/versions/2026/03_05_0545_c7b0a8e57294_add_ix_fleets_current_master_instance_id.py rename to src/dstack/_internal/server/migrations/versions/2026/03_05_1045_c7b0a8e57294_add_ix_fleets_current_master_instance_id.py index 29f703c1f2..e1cb938750 100644 --- a/src/dstack/_internal/server/migrations/versions/2026/03_05_0545_c7b0a8e57294_add_ix_fleets_current_master_instance_id.py +++ b/src/dstack/_internal/server/migrations/versions/2026/03_05_1045_c7b0a8e57294_add_ix_fleets_current_master_instance_id.py @@ -2,7 +2,7 @@ Revision ID: c7b0a8e57294 Revises: 9cb8e4e4d986 -Create Date: 2026-03-05 05:45:00.000000+00:00 +Create Date: 2026-03-05 10:45:00.000000+00:00 """ From e14fcd4a73de5a54bdc032ca23dd6e56914b5059 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Thu, 5 Mar 2026 16:18:47 +0500 Subject: [PATCH 45/51] Fix redundant fleet.instances loads in instance pipeline --- .../server/background/pipeline_tasks/base.py | 9 ++- .../pipeline_tasks/instances/__init__.py | 52 ++++-------- .../pipeline_tasks/instances/check.py | 43 ++++++++-- .../instances/cloud_provisioning.py | 1 + .../pipeline_tasks/instances/common.py | 19 +++-- .../_internal/server/services/fleets.py | 10 ++- .../test_instances/test_check.py | 79 +++++++++++++++++++ 7 files changed, 162 insertions(+), 51 deletions(-) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/base.py b/src/dstack/_internal/server/background/pipeline_tasks/base.py index a68ad0ff97..76073b7893 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/base.py +++ 
b/src/dstack/_internal/server/background/pipeline_tasks/base.py @@ -2,6 +2,7 @@ import logging import math import random +import time import uuid from abc import ABC, abstractmethod from collections.abc import Iterable, Sequence @@ -332,6 +333,7 @@ async def start(self): self._running = True while self._running: item = await self._queue.get() + start_time = time.time() logger.debug("Processing %s item %s", item.__tablename__, item.id) try: await self.process(item) @@ -339,7 +341,12 @@ async def start(self): logger.exception("Unexpected exception when processing item") finally: await self._heartbeater.untrack(item) - logger.debug("Processed %s item %s", item.__tablename__, item.id) + logger.debug( + "Processed %s item %s in %.3f", + item.__tablename__, + item.id, + time.time() - start_time, + ) def stop(self): self._running = False diff --git a/src/dstack/_internal/server/background/pipeline_tasks/instances/__init__.py b/src/dstack/_internal/server/background/pipeline_tasks/instances/__init__.py index 9515c92f7f..8e73403d79 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/instances/__init__.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/instances/__init__.py @@ -7,7 +7,6 @@ from sqlalchemy import and_, not_, or_, select, update from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import joinedload, load_only -from sqlalchemy.orm.attributes import set_committed_value from dstack._internal.core.models.health import HealthStatus from dstack._internal.core.models.instances import InstanceStatus @@ -41,7 +40,6 @@ ) from dstack._internal.server.db import get_db, get_session_ctx from dstack._internal.server.models import ( - FleetModel, InstanceHealthCheckModel, InstanceModel, JobModel, @@ -287,9 +285,12 @@ async def _process_idle_item(item: InstancePipelineItem) -> Optional[_ProcessCon if instance_model is None: log_lock_token_mismatch(logger, item) return None - idle_result = process_idle_timeout(instance_model) - if 
idle_result is not None: - return _ProcessContext(instance_model=instance_model, result=idle_result) + idle_result = await process_idle_timeout( + session=session, + instance_model=instance_model, + ) + if idle_result is not None: + return _ProcessContext(instance_model=instance_model, result=idle_result) result = await check_instance(instance_model) return _ProcessContext(instance_model=instance_model, result=result) @@ -328,32 +329,9 @@ async def _refetch_locked_instance_for_pending_or_terminating( ) .options(joinedload(InstanceModel.project).joinedload(ProjectModel.backends)) .options(joinedload(InstanceModel.jobs).load_only(JobModel.id, JobModel.status)) - .options( - joinedload(InstanceModel.fleet).joinedload( - FleetModel.instances.and_(InstanceModel.deleted == False) - ) - ) + .options(joinedload(InstanceModel.fleet)) ) - instance_model = res.unique().scalar_one_or_none() - if instance_model is not None: - # Pending/terminating processing runs on detached objects and later traverses - # `fleet.project`, sibling `project`, and sibling `fleet`. Populate those attrs from - # already known objects so detached access works without adding extra joins. 
- _populate_pending_or_terminating_detached_relations(instance_model) - return instance_model - - -def _populate_pending_or_terminating_detached_relations( - instance_model: InstanceModel, -) -> None: - project = instance_model.project - fleet = instance_model.fleet - if fleet is None: - return - set_committed_value(fleet, "project", project) - for sibling_instance_model in fleet.instances: - set_committed_value(sibling_instance_model, "project", project) - set_committed_value(sibling_instance_model, "fleet", fleet) + return res.unique().scalar_one_or_none() async def _refetch_locked_instance_for_idle( @@ -367,11 +345,7 @@ async def _refetch_locked_instance_for_idle( ) .options(joinedload(InstanceModel.project)) .options(joinedload(InstanceModel.jobs).load_only(JobModel.id, JobModel.status)) - .options( - joinedload(InstanceModel.fleet).joinedload( - FleetModel.instances.and_(InstanceModel.deleted == False) - ) - ) + .options(joinedload(InstanceModel.fleet)) ) return res.unique().scalar_one_or_none() @@ -385,7 +359,13 @@ async def _refetch_locked_instance_for_check( InstanceModel.id == item.id, InstanceModel.lock_token == item.lock_token, ) - .options(joinedload(InstanceModel.project).joinedload(ProjectModel.backends)) + .options( + joinedload(InstanceModel.project).load_only( + ProjectModel.id, + ProjectModel.ssh_public_key, + ProjectModel.ssh_private_key, + ) + ) .options(joinedload(InstanceModel.jobs).load_only(JobModel.id, JobModel.status)) ) return res.unique().scalar_one_or_none() diff --git a/src/dstack/_internal/server/background/pipeline_tasks/instances/check.py b/src/dstack/_internal/server/background/pipeline_tasks/instances/check.py index 30afd97742..a57ad229a0 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/instances/check.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/instances/check.py @@ -1,11 +1,15 @@ import logging +import uuid from datetime import timedelta from typing import Optional import gpuhunt import 
requests from sqlalchemy import func, select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import joinedload +from dstack._internal.core.backends.base.backend import Backend from dstack._internal.core.backends.base.compute import ( get_dstack_runner_download_url, get_dstack_runner_version, @@ -14,6 +18,7 @@ ) from dstack._internal.core.consts import DSTACK_SHIM_HTTP_PORT from dstack._internal.core.errors import ProvisioningError +from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.health import HealthStatus from dstack._internal.core.models.instances import InstanceStatus, InstanceTerminationReason from dstack._internal.core.models.profiles import TerminationPolicy @@ -31,7 +36,7 @@ set_unreachable_update, ) from dstack._internal.server.db import get_session_ctx -from dstack._internal.server.models import InstanceHealthCheckModel, InstanceModel +from dstack._internal.server.models import InstanceHealthCheckModel, InstanceModel, ProjectModel from dstack._internal.server.schemas.instances import InstanceCheck from dstack._internal.server.schemas.runner import ( ComponentInfo, @@ -54,7 +59,10 @@ logger = get_logger(__name__) -def process_idle_timeout(instance_model: InstanceModel) -> Optional[ProcessResult]: +async def process_idle_timeout( + session: AsyncSession, + instance_model: InstanceModel, +) -> Optional[ProcessResult]: if not ( instance_model.status == InstanceStatus.IDLE and instance_model.termination_policy == TerminationPolicy.DESTROY_AFTER_IDLE @@ -66,8 +74,12 @@ def process_idle_timeout(instance_model: InstanceModel) -> Optional[ProcessResul # There may be race conditions since we don't take the fleet lock. # That's ok: in the worst case we go below `nodes.min`, but # the fleet consolidation logic will provision new nodes. 
- if instance_model.fleet is not None and not can_terminate_fleet_instances_on_idle_duration( - instance_model.fleet + if ( + instance_model.fleet is not None + and not await can_terminate_fleet_instances_on_idle_duration( + session=session, + fleet_model=instance_model.fleet, + ) ): return None @@ -295,8 +307,8 @@ async def _process_wait_for_instance_provisioning_data( ) return result - backend = await backends_services.get_project_backend_by_type( - project=instance_model.project, + backend = await _get_backend_for_provisioning_wait( + project_id=instance_model.project_id, backend_type=job_provisioning_data.backend, ) if backend is None: @@ -342,6 +354,25 @@ async def _process_wait_for_instance_provisioning_data( return result +async def _get_backend_for_provisioning_wait( + project_id: uuid.UUID, + backend_type: BackendType, +) -> Optional[Backend]: + async with get_session_ctx() as session: + res = await session.execute( + select(ProjectModel) + .where(ProjectModel.id == project_id) + .options(joinedload(ProjectModel.backends)) + ) + project_model = res.unique().scalar_one_or_none() + if project_model is None: + return None + return await backends_services.get_project_backend_by_type( + project=project_model, + backend_type=backend_type, + ) + + @runner_ssh_tunnel(ports=[DSTACK_SHIM_HTTP_PORT], retries=1) def _check_instance_inner( ports: dict[int, int], diff --git a/src/dstack/_internal/server/background/pipeline_tasks/instances/cloud_provisioning.py b/src/dstack/_internal/server/background/pipeline_tasks/instances/cloud_provisioning.py index 43c5261d0d..4d2cbd8696 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/instances/cloud_provisioning.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/instances/cloud_provisioning.py @@ -110,6 +110,7 @@ async def create_cloud_instance(instance_model: InstanceModel) -> ProcessResult: blocks="auto" if instance_model.total_blocks is None else instance_model.total_blocks, 
exclude_not_available=True, master_job_provisioning_data=master_job_provisioning_data, + infer_master_job_provisioning_data_from_fleet_instances=False, ) # Limit number of offers tried to prevent long-running processing in case all offers fail. diff --git a/src/dstack/_internal/server/background/pipeline_tasks/instances/common.py b/src/dstack/_internal/server/background/pipeline_tasks/instances/common.py index 729842176e..34e80311fd 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/instances/common.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/instances/common.py @@ -5,6 +5,8 @@ from typing import Optional, TypedDict, Union from paramiko.pkey import PKey +from sqlalchemy import func, select +from sqlalchemy.ext.asyncio import AsyncSession from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.health import HealthStatus @@ -70,14 +72,21 @@ class ProcessResult: schedule_pg_deletion_except_id: Optional[uuid.UUID] = None -def can_terminate_fleet_instances_on_idle_duration(fleet_model: FleetModel) -> bool: +async def can_terminate_fleet_instances_on_idle_duration( + session: AsyncSession, + fleet_model: FleetModel, +) -> bool: fleet_spec = get_fleet_spec(fleet_model) if fleet_spec.configuration.nodes is None or fleet_spec.autocreated: return True - active_instances = [ - instance for instance in fleet_model.instances if instance.status.is_active() - ] - return len(active_instances) > fleet_spec.configuration.nodes.min + res = await session.execute( + select(func.count(1)).where( + InstanceModel.fleet_id == fleet_model.id, + InstanceModel.deleted == False, + InstanceModel.status.not_in(InstanceStatus.finished_statuses()), + ) + ) + return res.scalar_one() > fleet_spec.configuration.nodes.min def get_instance_idle_duration(instance_model: InstanceModel) -> datetime.timedelta: diff --git a/src/dstack/_internal/server/services/fleets.py b/src/dstack/_internal/server/services/fleets.py index 
74a12ff162..2db1dfab2e 100644 --- a/src/dstack/_internal/server/services/fleets.py +++ b/src/dstack/_internal/server/services/fleets.py @@ -467,16 +467,20 @@ async def get_create_instance_offers( blocks: Union[int, Literal["auto"]] = 1, exclude_not_available: bool = False, master_job_provisioning_data: Optional[JobProvisioningData] = None, + infer_master_job_provisioning_data_from_fleet_instances: bool = True, ) -> List[Tuple[Backend, InstanceOfferWithAvailability]]: multinode = False if fleet_spec is not None: multinode = fleet_spec.configuration.placement == InstanceGroupPlacement.CLUSTER if fleet_model is not None: - fleet = fleet_model_to_fleet(fleet_model) - multinode = fleet.spec.configuration.placement == InstanceGroupPlacement.CLUSTER + fleet_spec_from_model = get_fleet_spec(fleet_model) + multinode = fleet_spec_from_model.configuration.placement == InstanceGroupPlacement.CLUSTER # The caller may override the current cluster master explicitly instead # of inferring placement restrictions from the loaded fleet instances. 
- if master_job_provisioning_data is None: + if ( + master_job_provisioning_data is None + and infer_master_job_provisioning_data_from_fleet_instances + ): for instance in fleet_model.instances: jpd = instances_services.get_instance_provisioning_data(instance) if jpd is not None: diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_instances/test_check.py b/src/tests/_internal/server/background/pipeline_tasks/test_instances/test_check.py index f19986b51d..56772f461c 100644 --- a/src/tests/_internal/server/background/pipeline_tasks/test_instances/test_check.py +++ b/src/tests/_internal/server/background/pipeline_tasks/test_instances/test_check.py @@ -7,6 +7,7 @@ from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession +from dstack._internal.core.models.fleets import FleetNodesSpec from dstack._internal.core.models.health import HealthStatus from dstack._internal.core.models.instances import InstanceStatus, InstanceTerminationReason from dstack._internal.core.models.profiles import TerminationPolicy @@ -26,12 +27,15 @@ ) from dstack._internal.server.services.runner.client import ComponentList, ShimClient from dstack._internal.server.testing.common import ( + create_fleet, create_instance, create_job, create_project, create_repo, create_run, create_user, + get_fleet_configuration, + get_fleet_spec, get_remote_connection_info, list_events, ) @@ -401,6 +405,81 @@ async def test_check_shim_check_instance_health( @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) class TestProcessIdleTimeout: + async def test_does_not_terminate_by_idle_timeout_when_fleet_at_min_nodes( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + monkeypatch: pytest.MonkeyPatch, + ): + project = await create_project(session=session) + fleet = await create_fleet( + session=session, + project=project, + spec=get_fleet_spec( + get_fleet_configuration(nodes=FleetNodesSpec(min=1, target=1, max=1)) + ), 
+ ) + instance = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.IDLE, + ) + instance.termination_idle_time = 300 + instance.termination_policy = TerminationPolicy.DESTROY_AFTER_IDLE + instance.last_job_processed_at = get_current_datetime() + dt.timedelta(minutes=-19) + await session.commit() + + monkeypatch.setattr( + instances_check, + "_check_instance_inner", + Mock(return_value=InstanceCheck(reachable=True)), + ) + + await process_instance(session, worker, instance) + await session.refresh(instance) + + assert instance.status == InstanceStatus.IDLE + assert instance.termination_reason is None + + async def test_terminates_by_idle_timeout_when_fleet_above_min_nodes( + self, + test_db, + session: AsyncSession, + worker: InstanceWorker, + ): + project = await create_project(session=session) + fleet = await create_fleet( + session=session, + project=project, + spec=get_fleet_spec( + get_fleet_configuration(nodes=FleetNodesSpec(min=1, target=2, max=2)) + ), + ) + instance = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.IDLE, + ) + await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.IDLE, + ) + instance.termination_idle_time = 300 + instance.termination_policy = TerminationPolicy.DESTROY_AFTER_IDLE + instance.last_job_processed_at = get_current_datetime() + dt.timedelta(minutes=-19) + await session.commit() + + await process_instance(session, worker, instance) + await session.refresh(instance) + + assert instance.status == InstanceStatus.TERMINATING + assert instance.termination_reason == InstanceTerminationReason.IDLE_TIMEOUT + async def test_terminate_by_idle_timeout( self, test_db, From 32a24184c44448298b2cb416a1a6b846c1c13107 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Thu, 5 Mar 2026 16:39:43 +0500 Subject: [PATCH 46/51] Do not lock all instances in delete_fleets --- 
.../pipeline_tasks/instances/check.py | 2 - .../_internal/server/services/fleets.py | 9 ++- .../_internal/server/routers/test_fleets.py | 78 +++++++++++++++++++ 3 files changed, 84 insertions(+), 5 deletions(-) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/instances/check.py b/src/dstack/_internal/server/background/pipeline_tasks/instances/check.py index a57ad229a0..d23d536cd1 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/instances/check.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/instances/check.py @@ -479,7 +479,6 @@ def _maybe_install_runner( # * To provide your own build, set DSTACK_RUNNER_VERSION_URL and DSTACK_RUNNER_DOWNLOAD_URL. expected_version = get_dstack_runner_version() if expected_version is None: - logger.debug("Cannot determine the expected runner version") return False installed_version = runner_info.version @@ -526,7 +525,6 @@ def _maybe_install_shim( # * To provide your own build, set DSTACK_SHIM_VERSION_URL and DSTACK_SHIM_DOWNLOAD_URL. 
expected_version = get_dstack_shim_version() if expected_version is None: - logger.debug("Cannot determine the expected shim version") return False installed_version = shim_info.version diff --git a/src/dstack/_internal/server/services/fleets.py b/src/dstack/_internal/server/services/fleets.py index 2db1dfab2e..3e3c35fce0 100644 --- a/src/dstack/_internal/server/services/fleets.py +++ b/src/dstack/_internal/server/services/fleets.py @@ -735,7 +735,7 @@ async def delete_fleets( .order_by(FleetModel.id) ) fleets_ids = list(res.scalars().unique().all()) - res = await session.execute( + stmt = ( select(InstanceModel.id) .where( InstanceModel.fleet_id.in_(fleets_ids), @@ -743,6 +743,9 @@ async def delete_fleets( ) .order_by(InstanceModel.id) ) + if instance_nums is not None: + stmt = stmt.where(InstanceModel.instance_num.in_(instance_nums)) + res = await session.execute(stmt) instances_ids = list(res.scalars().unique().all()) await sqlite_commit(session) async with ( @@ -803,8 +806,8 @@ async def delete_fleets( ) raise ServerClientError(msg) for fleet_model in fleet_models: - fleet = fleet_model_to_fleet(fleet_model) - if fleet.spec.configuration.ssh_config is not None: + fleet_spec = get_fleet_spec(fleet_model) + if fleet_spec.configuration.ssh_config is not None: _check_can_manage_ssh_fleets(user=user, project=project) if instance_nums is None: logger.info("Deleting fleets: %s", [f.name for f in fleet_models]) diff --git a/src/tests/_internal/server/routers/test_fleets.py b/src/tests/_internal/server/routers/test_fleets.py index 12108eed31..d14e74e80d 100644 --- a/src/tests/_internal/server/routers/test_fleets.py +++ b/src/tests/_internal/server/routers/test_fleets.py @@ -1659,6 +1659,84 @@ async def test_terminates_fleet_instances( assert instance2.status != InstanceStatus.TERMINATING assert fleet.status != FleetStatus.TERMINATING + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def 
test_ignores_lock_on_non_selected_instances( + self, test_db, session: AsyncSession, client: AsyncClient + ): + user = await create_user(session, global_role=GlobalRole.USER) + project = await create_project(session) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.USER + ) + fleet = await create_fleet(session=session, project=project) + instance1 = await create_instance( + session=session, + project=project, + instance_num=1, + ) + instance2 = await create_instance( + session=session, + project=project, + instance_num=2, + ) + fleet.instances.append(instance1) + fleet.instances.append(instance2) + instance2.lock_expires_at = datetime(2023, 1, 2, 3, 5, tzinfo=timezone.utc) + await session.commit() + + response = await client.post( + f"/api/project/{project.name}/fleets/delete_instances", + headers=get_auth_headers(user.token), + json={"name": fleet.name, "instance_nums": [1]}, + ) + assert response.status_code == 200 + await session.refresh(fleet) + await session.refresh(instance1) + await session.refresh(instance2) + assert instance1.status == InstanceStatus.TERMINATING + assert instance2.status != InstanceStatus.TERMINATING + assert fleet.status != FleetStatus.TERMINATING + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_returns_400_when_selected_instance_locked( + self, test_db, session: AsyncSession, client: AsyncClient + ): + user = await create_user(session, global_role=GlobalRole.USER) + project = await create_project(session) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.USER + ) + fleet = await create_fleet(session=session, project=project) + instance1 = await create_instance( + session=session, + project=project, + instance_num=1, + ) + instance2 = await create_instance( + session=session, + project=project, + instance_num=2, + ) + fleet.instances.append(instance1) + 
fleet.instances.append(instance2) + instance1.lock_expires_at = datetime(2023, 1, 2, 3, 5, tzinfo=timezone.utc) + await session.commit() + + response = await client.post( + f"/api/project/{project.name}/fleets/delete_instances", + headers=get_auth_headers(user.token), + json={"name": fleet.name, "instance_nums": [1]}, + ) + assert response.status_code == 400 + await session.refresh(fleet) + await session.refresh(instance1) + await session.refresh(instance2) + assert instance1.status != InstanceStatus.TERMINATING + assert instance2.status != InstanceStatus.TERMINATING + assert fleet.status != FleetStatus.TERMINATING + @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) async def test_returns_400_when_deleting_busy_instances( From 631c086feae026302f3938b139fdfbab0b364ca3 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Thu, 5 Mar 2026 16:53:14 +0500 Subject: [PATCH 47/51] Add FIXME --- .../server/background/pipeline_tasks/instances/__init__.py | 2 +- src/dstack/_internal/server/services/fleets.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/instances/__init__.py b/src/dstack/_internal/server/background/pipeline_tasks/instances/__init__.py index 8e73403d79..b5289e05e9 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/instances/__init__.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/instances/__init__.py @@ -72,7 +72,7 @@ def __init__( workers_num: int = 20, queue_lower_limit_factor: float = 0.5, queue_upper_limit_factor: float = 2.0, - min_processing_interval: timedelta = timedelta(seconds=10), + min_processing_interval: timedelta = timedelta(seconds=15), lock_timeout: timedelta = timedelta(seconds=30), heartbeat_trigger: timedelta = timedelta(seconds=15), ) -> None: diff --git a/src/dstack/_internal/server/services/fleets.py b/src/dstack/_internal/server/services/fleets.py index 3e3c35fce0..1e4ccb9fc8 100644 
--- a/src/dstack/_internal/server/services/fleets.py +++ b/src/dstack/_internal/server/services/fleets.py @@ -799,6 +799,8 @@ async def delete_fleets( ) instance_models_ids = list(res.scalars().unique().all()) if len(instance_models_ids) != len(instances_ids): + # FIXME: In case of many instances, it can always fail. + # Try locking and waiting for all instances here until requests are queued for processing. msg = ( "Failed to delete fleets: fleet instances are being processed currently. Try again later." if instance_nums is None From 08ca18f397c35d8356838586ceed95532855c747 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Thu, 5 Mar 2026 17:37:48 +0500 Subject: [PATCH 48/51] Fix tests --- .../background/pipeline_tasks/test_instances/test_check.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_instances/test_check.py b/src/tests/_internal/server/background/pipeline_tasks/test_instances/test_check.py index 56772f461c..b555556881 100644 --- a/src/tests/_internal/server/background/pipeline_tasks/test_instances/test_check.py +++ b/src/tests/_internal/server/background/pipeline_tasks/test_instances/test_check.py @@ -587,7 +587,6 @@ async def test_cannot_determine_expected_version( instances_check._maybe_install_components(instance, shim_client_mock) - assert "Cannot determine the expected runner version" in debug_task_log.text shim_client_mock.get_components.assert_called_once() shim_client_mock.install_runner.assert_not_called() @@ -713,7 +712,6 @@ async def test_cannot_determine_expected_version( instances_check._maybe_install_components(instance, shim_client_mock) - assert "Cannot determine the expected shim version" in debug_task_log.text shim_client_mock.get_components.assert_called_once() shim_client_mock.install_shim.assert_not_called() From c6f18eb26ab9da297b5005b946b05143ee49e04a Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Fri, 6 Mar 2026 11:04:32 +0500 Subject: [PATCH 49/51] Retry 
instance lock in delete_fleets --- .../_internal/server/services/fleets.py | 35 +++++++++++-------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/src/dstack/_internal/server/services/fleets.py b/src/dstack/_internal/server/services/fleets.py index 1e4ccb9fc8..3d19cfed30 100644 --- a/src/dstack/_internal/server/services/fleets.py +++ b/src/dstack/_internal/server/services/fleets.py @@ -1,3 +1,4 @@ +import asyncio import uuid from collections.abc import Callable from datetime import datetime @@ -786,21 +787,27 @@ async def delete_fleets( else "Failed to delete fleet instances: fleets are being processed currently. Try again later." ) raise ServerClientError(msg) - res = await session.execute( - select(InstanceModel.id) - .where( - InstanceModel.id.in_(instances_ids), - InstanceModel.deleted == False, - InstanceModel.lock_expires_at.is_(None), + # Try locking instances in a retry loop. + # This is a hack to be able to delete fleets with many instances. + # Won't be necessary after requests are queued. + instances_left_to_lock = set(instances_ids) + for i in range(10): + res = await session.execute( + select(InstanceModel.id) + .where( + InstanceModel.id.in_(instances_left_to_lock), + InstanceModel.deleted == False, + InstanceModel.lock_expires_at.is_(None), + ) + .order_by(InstanceModel.id) # take locks in order + .with_for_update(key_share=True, of=InstanceModel) + .execution_options(populate_existing=True) ) - .order_by(InstanceModel.id) # take locks in order - .with_for_update(key_share=True, of=InstanceModel) - .execution_options(populate_existing=True) - ) - instance_models_ids = list(res.scalars().unique().all()) - if len(instance_models_ids) != len(instances_ids): - # FIXME: In case of many instances, it can always fail. - # Try locking and waiting for all instances here until requests are queued for processing. 
+ instances_left_to_lock.difference_update(res.scalars().unique().all()) + if len(instances_left_to_lock) == 0: + break + await asyncio.sleep(0.5) + if len(instances_left_to_lock) > 0: msg = ( "Failed to delete fleets: fleet instances are being processed currently. Try again later." if instance_nums is None From f9a933838f451bc3838294881550cdd8f05a87cc Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Fri, 6 Mar 2026 11:27:54 +0500 Subject: [PATCH 50/51] Retry lock in all delete endpoints --- .../_internal/server/services/fleets.py | 61 ++++++++++--------- .../server/services/gateways/__init__.py | 35 ++++++----- .../_internal/server/services/volumes.py | 38 +++++++----- 3 files changed, 74 insertions(+), 60 deletions(-) diff --git a/src/dstack/_internal/server/services/fleets.py b/src/dstack/_internal/server/services/fleets.py index 3d19cfed30..600cba41a4 100644 --- a/src/dstack/_internal/server/services/fleets.py +++ b/src/dstack/_internal/server/services/fleets.py @@ -753,43 +753,46 @@ async def delete_fleets( get_locker(get_db().dialect_name).lock_ctx(FleetModel.__tablename__, fleets_ids), get_locker(get_db().dialect_name).lock_ctx(InstanceModel.__tablename__, instances_ids), ): - # Refetch after lock. - # TODO: Do not lock fleet when deleting only instances. - res = await session.execute( - select(FleetModel) - .where( - FleetModel.project_id == project.id, - FleetModel.id.in_(fleets_ids), - FleetModel.deleted == False, - FleetModel.lock_expires_at.is_(None), - ) - .options( - selectinload(FleetModel.instances.and_(InstanceModel.id.in_(instances_ids))) - .selectinload(InstanceModel.jobs) - .load_only(JobModel.id) - ) - .options( - selectinload( - FleetModel.runs.and_(RunModel.status.not_in(RunStatus.finished_statuses())) - ).load_only(RunModel.status) + # Retry locking fleets to increase lock acquisition chances. + # This hack is needed until requests are queued. 
+ fleet_models = [] + for i in range(10): + res = await session.execute( + select(FleetModel) + .where( + FleetModel.project_id == project.id, + FleetModel.id.in_(fleets_ids), + FleetModel.deleted == False, + FleetModel.lock_expires_at.is_(None), + ) + .options( + selectinload(FleetModel.instances.and_(InstanceModel.id.in_(instances_ids))) + .selectinload(InstanceModel.jobs) + .load_only(JobModel.id) + ) + .options( + selectinload( + FleetModel.runs.and_(RunModel.status.not_in(RunStatus.finished_statuses())) + ).load_only(RunModel.status) + ) + .order_by(FleetModel.id) # take locks in order + .with_for_update(key_share=True, of=FleetModel) + .execution_options(populate_existing=True) ) - .execution_options(populate_existing=True) - .order_by(FleetModel.id) # take locks in order - .with_for_update(key_share=True, of=FleetModel) - ) - fleet_models = res.scalars().unique().all() + fleet_models = res.scalars().unique().all() + if len(fleet_models) == len(fleets_ids): + break + await asyncio.sleep(0.5) if len(fleet_models) != len(fleets_ids): - # TODO: Make the endpoint fully async so we don't need to lock and error: - # put the request in queue and process in the background. + # TODO: Make the endpoint fully async so we don't need to lock and error. msg = ( "Failed to delete fleets: fleets are being processed currently. Try again later." if instance_nums is None else "Failed to delete fleet instances: fleets are being processed currently. Try again later." ) raise ServerClientError(msg) - # Try locking instances in a retry loop. - # This is a hack to be able to delete fleets with many instances. - # Won't be necessary after requests are queued. + # Retry locking instances to increase lock acquisition chances. + # This hack is needed until requests are queued. 
instances_left_to_lock = set(instances_ids) for i in range(10): res = await session.execute( diff --git a/src/dstack/_internal/server/services/gateways/__init__.py b/src/dstack/_internal/server/services/gateways/__init__.py index ddc3d64c44..b4dfef083f 100644 --- a/src/dstack/_internal/server/services/gateways/__init__.py +++ b/src/dstack/_internal/server/services/gateways/__init__.py @@ -341,23 +341,28 @@ async def _delete_gateways_pipeline( async with get_locker(get_db().dialect_name).lock_ctx( GatewayModel.__tablename__, gateways_ids ): - # Refetch after lock - res = await session.execute( - select(GatewayModel) - .where( - GatewayModel.id.in_(gateways_ids), - GatewayModel.project_id == project.id, - GatewayModel.lock_expires_at.is_(None), + # Retry locking gateways to increase lock acquisition chances. + # This hack is needed until requests are queued. + gateway_models = [] + for i in range(10): + res = await session.execute( + select(GatewayModel) + .where( + GatewayModel.id.in_(gateways_ids), + GatewayModel.project_id == project.id, + GatewayModel.lock_expires_at.is_(None), + ) + .options(joinedload(GatewayModel.backend).load_only(BackendModel.type)) + .order_by(GatewayModel.id) # take locks in order + .with_for_update(key_share=True, of=GatewayModel) + .execution_options(populate_existing=True) ) - .options(joinedload(GatewayModel.backend).load_only(BackendModel.type)) - .order_by(GatewayModel.id) # take locks in order - .with_for_update(key_share=True, nowait=True, of=GatewayModel) - .execution_options(populate_existing=True) - ) - gateway_models = res.scalars().all() + gateway_models = res.scalars().all() + if len(gateway_models) == len(gateways_ids): + break + await asyncio.sleep(0.5) if len(gateway_models) != len(gateways_ids): - # TODO: Make the endpoint fully async so we don't need to lock and error: - # put the request in queue and process in the background. + # TODO: Make the endpoint fully async so we don't need to lock and error. 
raise ServerClientError( "Failed to delete gateways: gateways are being processed currently. Try again later." ) diff --git a/src/dstack/_internal/server/services/volumes.py b/src/dstack/_internal/server/services/volumes.py index 1c846c724f..ac2f88a5d1 100644 --- a/src/dstack/_internal/server/services/volumes.py +++ b/src/dstack/_internal/server/services/volumes.py @@ -1,3 +1,4 @@ +import asyncio import uuid from datetime import datetime, timedelta from typing import List, Optional @@ -353,24 +354,29 @@ async def _delete_volumes_pipeline( await session.commit() logger.info("Deleting volumes: %s", [v.name for v in volume_models]) async with get_locker(get_db().dialect_name).lock_ctx(VolumeModel.__tablename__, volumes_ids): - # Refetch after lock - res = await session.execute( - select(VolumeModel) - .where( - VolumeModel.project_id == project.id, - VolumeModel.id.in_(volumes_ids), - VolumeModel.deleted == False, - VolumeModel.lock_expires_at.is_(None), + # Retry locking volumes to increase lock acquisition chances. + # This hack is needed until requests are queued. 
+ volume_models = [] + for i in range(10): + res = await session.execute( + select(VolumeModel) + .where( + VolumeModel.project_id == project.id, + VolumeModel.id.in_(volumes_ids), + VolumeModel.deleted == False, + VolumeModel.lock_expires_at.is_(None), + ) + .options(selectinload(VolumeModel.attachments)) + .order_by(VolumeModel.id) # take locks in order + .with_for_update(key_share=True, of=VolumeModel) + .execution_options(populate_existing=True) ) - .options(selectinload(VolumeModel.attachments)) - .execution_options(populate_existing=True) - .order_by(VolumeModel.id) # take locks in order - .with_for_update(key_share=True, of=VolumeModel) - ) - volume_models = res.scalars().unique().all() + volume_models = res.scalars().unique().all() + if len(volume_models) == len(volumes_ids): + break + await asyncio.sleep(0.5) if len(volume_models) != len(volumes_ids): - # TODO: Make the endpoint fully async so we don't need to lock and error: - # put the request in queue and process in the background. + # TODO: Make the endpoint fully async so we don't need to lock and error. raise ServerClientError( "Failed to delete volumes: volumes are being processed currently. Try again later." 
) From caa0abd39422c3c764269a26e35d9fda29fede3a Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Fri, 6 Mar 2026 11:43:36 +0500 Subject: [PATCH 51/51] Fix created_at and last_processed_at init values --- src/dstack/_internal/server/services/fleets.py | 15 ++++++++++++--- src/dstack/_internal/server/services/instances.py | 4 +++- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/src/dstack/_internal/server/services/fleets.py b/src/dstack/_internal/server/services/fleets.py index 600cba41a4..183e81b208 100644 --- a/src/dstack/_internal/server/services/fleets.py +++ b/src/dstack/_internal/server/services/fleets.py @@ -86,7 +86,12 @@ ) from dstack._internal.server.services.resources import set_resources_defaults from dstack._internal.utils import random_names -from dstack._internal.utils.common import EntityID, EntityName, EntityNameOrID +from dstack._internal.utils.common import ( + EntityID, + EntityName, + EntityNameOrID, + get_current_datetime, +) from dstack._internal.utils.logging import get_logger from dstack._internal.utils.ssh import pkey_from_str @@ -999,6 +1004,7 @@ async def _create_fleet( else: spec.configuration.name = await generate_fleet_name(session=session, project=project) + now = get_current_datetime() fleet_model = FleetModel( id=uuid.uuid4(), name=spec.configuration.name, @@ -1006,6 +1012,8 @@ async def _create_fleet( status=FleetStatus.ACTIVE, spec=spec.json(), instances=[], + created_at=now, + last_processed_at=now, ) session.add(fleet_model) events.emit( @@ -1057,9 +1065,10 @@ async def _create_fleet( targets=[events.Target.from_model(instance_model)], ) fleet_model.instances.append(instance_model) - pipeline_hinter.hint_fetch(FleetModel.__name__) - pipeline_hinter.hint_fetch(InstanceModel.__name__) await session.commit() + if spec.configuration.ssh_config is None: + pipeline_hinter.hint_fetch(FleetModel.__name__) + pipeline_hinter.hint_fetch(InstanceModel.__name__) return fleet_model_to_fleet(fleet_model) diff --git 
a/src/dstack/_internal/server/services/instances.py b/src/dstack/_internal/server/services/instances.py index 44cbcc4520..e07bce938b 100644 --- a/src/dstack/_internal/server/services/instances.py +++ b/src/dstack/_internal/server/services/instances.py @@ -679,12 +679,14 @@ def create_instance_model( reservation=reservation, tags=tags, ) + now = common_utils.get_current_datetime() instance = InstanceModel( id=instance_id, name=instance_name, instance_num=instance_num, project=project, - created_at=common_utils.get_current_datetime(), + created_at=now, + last_processed_at=now, status=InstanceStatus.PENDING, unreachable=False, profile=profile.json(),