diff --git a/docs/docs/concepts/backends.md b/docs/docs/concepts/backends.md
index 0213b669d4..a385b322a5 100644
--- a/docs/docs/concepts/backends.md
+++ b/docs/docs/concepts/backends.md
@@ -631,6 +631,29 @@ gcloud projects list --format="json(projectId)"
Using private subnets assumes that both the `dstack` server and users can access the configured VPC's private subnets.
Additionally, [Cloud NAT](https://cloud.google.com/nat/docs/overview) must be configured to provide access to external resources for provisioned instances.
+??? info "TPU"
+    By default, `dstack` does not include TPU offers.
+    To enable TPU provisioning, set `tpu` to `true` in the backend settings.
+
+
+
+    ```yaml
+    projects:
+    - name: main
+      backends:
+        - type: gcp
+          project_id: gcp-project-id
+          creds:
+            type: default
+
+          tpu: true
+    ```
+
+
+
+    Make sure the required TPU permissions and the `serviceAccountUser` role are granted
+    (see "Required permissions" above).
+
### Lambda
Log into your [Lambda Cloud](https://lambdalabs.com/service/gpu-cloud) account, click API keys in the sidebar, and then click the `Generate API key`
diff --git a/src/dstack/_internal/core/backends/gcp/compute.py b/src/dstack/_internal/core/backends/gcp/compute.py
index cd5ecb829f..322f02f3b2 100644
--- a/src/dstack/_internal/core/backends/gcp/compute.py
+++ b/src/dstack/_internal/core/backends/gcp/compute.py
@@ -135,7 +135,7 @@ def get_all_offers_with_availability(self) -> List[InstanceOfferWithAvailability
regions = get_or_error(self.config.regions)
offers = get_catalog_offers(
backend=BackendType.GCP,
- extra_filter=_supported_instances_and_zones(regions),
+ extra_filter=_supported_instances_and_zones(regions, tpu=self.config.allow_tpu),
)
quotas: Dict[str, Dict[str, float]] = defaultdict(dict)
for region in self.regions_client.list(project=self.config.project_id):
@@ -989,14 +989,17 @@ def _find_reservation(self, configured_name: str) -> dict[str, compute_v1.Reserv
def _supported_instances_and_zones(
regions: List[str],
+ tpu: bool = False,
) -> Optional[Callable[[InstanceOffer], bool]]:
def _filter(offer: InstanceOffer) -> bool:
# strip zone
if offer.region[:-2] not in regions:
return False
- # remove multi-host TPUs for initial release
- if _is_tpu(offer.instance.name) and not _is_single_host_tpu(offer.instance.name):
- return False
+ if _is_tpu(offer.instance.name):
+ if not tpu:
+ return False
+ if not _is_single_host_tpu(offer.instance.name):
+ return False
for family in [
"m4-",
"c4-",
diff --git a/src/dstack/_internal/core/backends/gcp/models.py b/src/dstack/_internal/core/backends/gcp/models.py
index 4d06144ee8..1fc6bbd995 100644
--- a/src/dstack/_internal/core/backends/gcp/models.py
+++ b/src/dstack/_internal/core/backends/gcp/models.py
@@ -5,6 +5,8 @@
from dstack._internal.core.backends.base.models import fill_data
from dstack._internal.core.models.common import CoreModel
+GCP_TPU_DEFAULT = False
+
class GCPServiceAccountCreds(CoreModel):
type: Annotated[Literal["service_account"], Field(description="The type of credentials")] = (
@@ -89,6 +91,15 @@ class GCPBackendConfig(CoreModel):
description="The tags (labels) that will be assigned to resources created by `dstack`"
),
] = None
+ tpu: Annotated[
+ Optional[bool],
+ Field(
+ description=(
+ "Whether TPU offers can be used for provisioning."
+ f" Defaults to `{str(GCP_TPU_DEFAULT).lower()}`"
+ )
+ ),
+ ] = None
preview_features: Annotated[
Optional[List[Literal["g4"]]],
Field(
@@ -143,6 +154,12 @@ class GCPStoredConfig(GCPBackendConfig):
class GCPConfig(GCPStoredConfig):
creds: AnyGCPCreds
+    @property
+    def allow_tpu(self) -> bool:
+        """Return the configured `tpu` flag, or `GCP_TPU_DEFAULT` when it is unset."""
+        configured = self.tpu
+        return GCP_TPU_DEFAULT if configured is None else configured
+
@property
def allocate_public_ips(self) -> bool:
if self.public_ips is not None:
diff --git a/src/tests/_internal/core/backends/gcp/test_compute.py b/src/tests/_internal/core/backends/gcp/test_compute.py
new file mode 100644
index 0000000000..89d29c5292
--- /dev/null
+++ b/src/tests/_internal/core/backends/gcp/test_compute.py
@@ -0,0 +1,83 @@
+from dstack._internal.core.backends.gcp.compute import _supported_instances_and_zones
+from dstack._internal.core.backends.gcp.models import GCPConfig, GCPDefaultCreds
+from dstack._internal.core.models.backends.base import BackendType
+from dstack._internal.core.models.instances import (
+ Gpu,
+ InstanceOffer,
+ InstanceType,
+ Resources,
+)
+
+
+def _make_offer(instance_name: str, region: str = "us-central1-a", gpus=None) -> InstanceOffer:
+    """Build a minimal non-spot GCP offer; `gpus` defaults to an empty list."""
+    resources = Resources(
+        cpus=8,
+        memory_mib=32768,
+        gpus=[] if gpus is None else gpus,
+        spot=False,
+    )
+    return InstanceOffer(
+        backend=BackendType.GCP,
+        instance=InstanceType(
+            name=instance_name,
+            resources=resources,
+        ),
+        region=region,
+        price=1.0,
+    )
+
+
+class TestSupportedInstancesAndZones:
+    """Checks how the GCP offer filter treats TPU and GPU offers for each `tpu` setting."""
+
+    def test_filters_tpu_when_disabled(self):
+        f = _supported_instances_and_zones(["us-central1"], tpu=False)
+        offer = _make_offer(
+            "v5litepod-8",
+            region="us-central1-b",
+            gpus=[Gpu(name="v5litepod", memory_mib=16384)],
+        )
+        assert f(offer) is False
+
+    def test_allows_single_host_tpu_when_enabled(self):
+        f = _supported_instances_and_zones(["us-central1"], tpu=True)
+        offer = _make_offer(
+            "v5litepod-8",
+            region="us-central1-b",
+            gpus=[Gpu(name="v5litepod", memory_mib=16384)],
+        )
+        assert f(offer) is True
+
+    def test_filters_multi_host_tpu_when_enabled(self):
+        f = _supported_instances_and_zones(["us-central1"], tpu=True)
+        offer = _make_offer(
+            "v5litepod-16",
+            region="us-central1-b",
+            gpus=[Gpu(name="v5litepod", memory_mib=16384)],
+        )
+        assert f(offer) is False
+
+    def test_allows_gpu_instances_regardless_of_tpu_flag(self):
+        gpus = [Gpu(name="A100", memory_mib=40960)]
+        offer = _make_offer("a2-highgpu-1g", region="us-central1-b", gpus=gpus)
+        # The TPU switch must not affect non-TPU offers, so exercise both values.
+        for tpu in (False, True):
+            assert _supported_instances_and_zones(["us-central1"], tpu=tpu)(offer) is True
+
+
+class TestGCPConfigAllowTpu:
+    """Checks `GCPConfig.allow_tpu` for an unset, enabled, and explicitly disabled `tpu`."""
+
+    def _make_config(self, tpu=None) -> GCPConfig:
+        return GCPConfig(project_id="test-project", creds=GCPDefaultCreds(), tpu=tpu)
+
+    def test_default(self):
+        assert self._make_config(tpu=None).allow_tpu is False
+
+    def test_explicit_true(self):
+        assert self._make_config(tpu=True).allow_tpu is True
+
+    def test_explicit_false(self):
+        # An explicit `False` takes the `is not None` branch, unlike an unset value.
+        assert self._make_config(tpu=False).allow_tpu is False