diff --git a/README.md b/README.md index 1a0fd19..7b942c6 100644 --- a/README.md +++ b/README.md @@ -66,8 +66,13 @@ azd env set POSTGRES_ENTRA_ADMIN_OBJECT_ID "" azd env set POSTGRES_ENTRA_ADMIN_PRINCIPAL_NAME "" azd env set POSTGRES_ENTRA_ADMIN_PRINCIPAL_TYPE "User" azd env set DUCKLAKE_DATA_PATH "az://lakehouse/data/" +azd env set LAKEHOUSE_SECRET_KEY "$(openssl rand -base64 32)" ``` +`openssl rand -base64 32` generates 32 random bytes and stores their base64-encoded form, usually a 44-character string ending in `=`. Use that encoded string directly as `LAKEHOUSE_SECRET_KEY`; the server treats it as a UTF-8 HMAC/JWT signing key and does not base64-decode it. + +Azure deployments require `LAKEHOUSE_SECRET_KEY` before `azd up` so the deployed Container App keeps a stable token-signing key across restarts and revisions. + Find the Entra values if you need them: ```bash @@ -93,6 +98,7 @@ azd env set POSTGRES_ENTRA_ADMIN_OBJECT_ID "" azd env set POSTGRES_ENTRA_ADMIN_PRINCIPAL_NAME "" azd env set POSTGRES_ENTRA_ADMIN_PRINCIPAL_TYPE "User" azd env set DUCKLAKE_DATA_PATH "az://lakehouse/data/" +azd env set LAKEHOUSE_SECRET_KEY "$(openssl rand -base64 32)" azd up ``` @@ -134,6 +140,25 @@ mvn -q -Dexec.mainClass=lakehouse.AzureDemo test-compile exec:java The `MAVEN_OPTS` flag is required for Apache Arrow on Java 17+. +### Run the live backend tests + +The live backend pytest is opt-in because it queries the deployed Azure Container App and reads the `lakehouse-password` secret from Key Vault: + +```bash +LAKEHOUSE_LIVE_BACKEND=1 uv run pytest -q tests/test_live_azure_backend.py +``` + +That default path uses PyArrow to perform the Basic-token handshake, then gives the returned Bearer token to ADBC for the query. It verifies the deployed endpoint, TLS, Key Vault password, and Bearer auth path. 
+ +There is also a separate opt-in check for ADBC's direct Basic-auth path: + +```bash +LAKEHOUSE_LIVE_BACKEND=1 LAKEHOUSE_LIVE_BACKEND_ADBC_BASIC=1 \ + uv run pytest -q tests/test_live_azure_backend.py +``` + +The ADBC Basic check is marked `xfail` because that is the known client path currently failing against the deployed Container App. A result such as `1 passed, 1 xfailed` means the supported bearer smoke test passed and the tracked ADBC Basic issue reproduced as expected. If that changes to `1 passed, 1 xpassed`, the ADBC Basic path has started working and the `xfail` marker should be removed. + If you want one copy/paste block for the demo itself: ```bash @@ -378,7 +403,7 @@ A few settings are environment-only (`.env` also works). | Database | `--database` | `LAKEHOUSE_DATABASE` | `:memory:` | CLI + Env | DuckDB database path | | Username | `--username` | `LAKEHOUSE_USERNAME` | `lakehouse` | CLI + Env | Auth username | | Password | `--password` | `LAKEHOUSE_PASSWORD` | *(empty)* | CLI + Env | Auth password (empty disables auth) | -| Secret Key | `--secret-key` | `LAKEHOUSE_SECRET_KEY` | *(auto-generated)* | CLI + Env | HMAC / JWT signing key | +| Secret Key | `--secret-key` | `LAKEHOUSE_SECRET_KEY` | *(auto-generated locally; required for Azure deploy)* | CLI + Env | HMAC / JWT signing key | | Health Port | `--health-check-port` | `LAKEHOUSE_HEALTH_CHECK_PORT` | `8081` | CLI + Env | gRPC health service port | | Health Enabled | `--health-check-enabled` | `LAKEHOUSE_HEALTH_CHECK_ENABLED` | `true` | CLI + Env | Enable health check server | | Log Level | `--log-level` | `LAKEHOUSE_LOG_LEVEL` | `INFO` | CLI + Env | Python log level | @@ -495,7 +520,7 @@ Lakehouse implements all standard Flight SQL RPCs: - **Azure Container Apps** — runs the Lakehouse Docker image - **User-assigned managed identity** — attached to the Container App, with `Storage Blob Data Contributor` RBAC - **PostgreSQL Entra admin principal** — granted `Storage Blob Data Contributor` RBAC for 
local DuckLake validation -- **Azure Key Vault** — stores the Lakehouse password +- **Azure Key Vault** — stores the Lakehouse password and stable HMAC/JWT signing key A `postprovision` hook runs automatically to configure PostgreSQL Entra auth grants for the managed identity. diff --git a/infra/main.bicep b/infra/main.bicep index c2ecc28..c4291be 100644 --- a/infra/main.bicep +++ b/infra/main.bicep @@ -76,6 +76,11 @@ param ducklakeDataPath string = 'az://lakehouse/data/' @description('Auto-generated Flight SQL password. Random on each fresh provision.') param lakehousePassword string = newGuid() +@secure() +@minLength(1) +@description('Stable non-empty UTF-8 HMAC/JWT signing key for the Flight SQL server.') +param lakehouseSecretKey string + var uniqueSuffix = toLower(uniqueString(subscription().id, resourceGroup().id, environmentName)) var storageAccountName = 'st${uniqueSuffix}' var postgresServerName = toLower('psql-${environmentName}-${substring(uniqueSuffix, 0, 6)}') @@ -126,6 +131,7 @@ module keyvault './modules/keyvault.bicep' = { containerAppPrincipalId: containerAppIdentity.properties.principalId deployerPrincipalId: postgresEntraAdminObjectId lakehousePassword: lakehousePassword + lakehouseSecretKey: lakehouseSecretKey } } @@ -156,6 +162,7 @@ module containerApp './modules/container-app.bicep' = { ducklakeDataPath: ducklakeDataPath acrLoginServer: acr.outputs.acrLoginServer lakehousePasswordSecretUri: keyvault.outputs.lakehousePasswordSecretUri + lakehouseSecretKeySecretUri: keyvault.outputs.lakehouseSecretKeySecretUri } } diff --git a/infra/main.json b/infra/main.json index 865a40c..8c8b178 100644 --- a/infra/main.json +++ b/infra/main.json @@ -5,7 +5,7 @@ "_generator": { "name": "bicep", "version": "0.40.2.10011", - "templateHash": "7370216803261364591" + "templateHash": "8435078362355386948" } }, "parameters": { @@ -152,6 +152,13 @@ "metadata": { "description": "Auto-generated Flight SQL password. Random on each fresh provision." 
} + }, + "lakehouseSecretKey": { + "type": "securestring", + "minLength": 1, + "metadata": { + "description": "Stable non-empty UTF-8 HMAC/JWT signing key for the Flight SQL server." + } } }, "variables": { @@ -511,6 +518,9 @@ }, "lakehousePassword": { "value": "[parameters('lakehousePassword')]" + }, + "lakehouseSecretKey": { + "value": "[parameters('lakehouseSecretKey')]" } }, "template": { @@ -520,7 +530,7 @@ "_generator": { "name": "bicep", "version": "0.40.2.10011", - "templateHash": "15964644140239373440" + "templateHash": "1804434252725837723" } }, "parameters": { @@ -554,6 +564,12 @@ "metadata": { "description": "Flight SQL password to store in the vault." } + }, + "lakehouseSecretKey": { + "type": "securestring", + "metadata": { + "description": "Flight SQL HMAC/JWT signing key to store in the vault." + } } }, "resources": [ @@ -622,6 +638,21 @@ "dependsOn": [ "[resourceId('Microsoft.KeyVault/vaults', parameters('keyVaultName'))]" ] + }, + { + "type": "Microsoft.KeyVault/vaults/secrets", + "apiVersion": "2023-07-01", + "name": "[format('{0}/{1}', parameters('keyVaultName'), 'lakehouse-secret-key')]", + "properties": { + "value": "[parameters('lakehouseSecretKey')]", + "contentType": "text/plain", + "attributes": { + "enabled": true + } + }, + "dependsOn": [ + "[resourceId('Microsoft.KeyVault/vaults', parameters('keyVaultName'))]" + ] } ], "outputs": { @@ -636,6 +667,10 @@ "lakehousePasswordSecretUri": { "type": "string", "value": "[reference(resourceId('Microsoft.KeyVault/vaults/secrets', parameters('keyVaultName'), 'lakehouse-password'), '2023-07-01').secretUri]" + }, + "lakehouseSecretKeySecretUri": { + "type": "string", + "value": "[reference(resourceId('Microsoft.KeyVault/vaults/secrets', parameters('keyVaultName'), 'lakehouse-secret-key'), '2023-07-01').secretUri]" } } } @@ -809,6 +844,9 @@ }, "lakehousePasswordSecretUri": { "value": "[reference(resourceId('Microsoft.Resources/deployments', 'keyvault'), 
'2025-04-01').outputs.lakehousePasswordSecretUri.value]" + }, + "lakehouseSecretKeySecretUri": { + "value": "[reference(resourceId('Microsoft.Resources/deployments', 'keyvault'), '2025-04-01').outputs.lakehouseSecretKeySecretUri.value]" } }, "template": { @@ -818,7 +856,7 @@ "_generator": { "name": "bicep", "version": "0.40.2.10011", - "templateHash": "6413779264902107673" + "templateHash": "7408886526487755529" } }, "parameters": { @@ -869,6 +907,12 @@ "metadata": { "description": "Key Vault secret URI for the Flight SQL password." } + }, + "lakehouseSecretKeySecretUri": { + "type": "string", + "metadata": { + "description": "Key Vault secret URI for the Flight SQL HMAC/JWT signing key." + } } }, "resources": [ @@ -930,6 +974,11 @@ "name": "lakehouse-password", "keyVaultUrl": "[parameters('lakehousePasswordSecretUri')]", "identity": "[parameters('containerAppIdentityId')]" + }, + { + "name": "lakehouse-secret-key", + "keyVaultUrl": "[parameters('lakehouseSecretKeySecretUri')]", + "identity": "[parameters('containerAppIdentityId')]" } ], "registries": [ @@ -981,6 +1030,10 @@ { "name": "LAKEHOUSE_PASSWORD", "secretRef": "lakehouse-password" + }, + { + "name": "LAKEHOUSE_SECRET_KEY", + "secretRef": "lakehouse-secret-key" } ], "probes": [ diff --git a/infra/main.parameters.json b/infra/main.parameters.json index 6cee846..1a03fec 100644 --- a/infra/main.parameters.json +++ b/infra/main.parameters.json @@ -31,6 +31,9 @@ }, "ducklakeDataPath": { "value": "${DUCKLAKE_DATA_PATH}" + }, + "lakehouseSecretKey": { + "value": "${LAKEHOUSE_SECRET_KEY}" } } } diff --git a/infra/modules/container-app.bicep b/infra/modules/container-app.bicep index b31cc8d..c56f100 100644 --- a/infra/modules/container-app.bicep +++ b/infra/modules/container-app.bicep @@ -16,6 +16,9 @@ param acrLoginServer string @description('Key Vault secret URI for the Flight SQL password.') param lakehousePasswordSecretUri string +@description('Key Vault secret URI for the Flight SQL HMAC/JWT signing key.') 
+param lakehouseSecretKeySecretUri string + resource logAnalyticsWorkspace 'Microsoft.OperationalInsights/workspaces@2022-10-01' = { name: 'law-${containerAppsEnvironmentName}' location: location @@ -68,6 +71,11 @@ resource containerApp 'Microsoft.App/containerApps@2024-03-01' = { keyVaultUrl: lakehousePasswordSecretUri identity: containerAppIdentityId } + { + name: 'lakehouse-secret-key' + keyVaultUrl: lakehouseSecretKeySecretUri + identity: containerAppIdentityId + } ] registries: [ { @@ -119,6 +127,10 @@ resource containerApp 'Microsoft.App/containerApps@2024-03-01' = { name: 'LAKEHOUSE_PASSWORD' secretRef: 'lakehouse-password' } + { + name: 'LAKEHOUSE_SECRET_KEY' + secretRef: 'lakehouse-secret-key' + } ] probes: [ { diff --git a/infra/modules/keyvault.bicep b/infra/modules/keyvault.bicep index 4ba13ae..fff6030 100644 --- a/infra/modules/keyvault.bicep +++ b/infra/modules/keyvault.bicep @@ -14,6 +14,10 @@ param deployerPrincipalId string = '' @description('Flight SQL password to store in the vault.') param lakehousePassword string +@secure() +@description('Flight SQL HMAC/JWT signing key to store in the vault.') +param lakehouseSecretKey string + resource keyVault 'Microsoft.KeyVault/vaults@2023-07-01' = { name: keyVaultName location: location @@ -76,9 +80,26 @@ resource lakehousePasswordSecret 'Microsoft.KeyVault/vaults/secrets@2023-07-01' } } +// ── Secret: lakehouse-secret-key ──────────────────────────────────────── +resource lakehouseSecretKeySecret 'Microsoft.KeyVault/vaults/secrets@2023-07-01' = { + parent: keyVault + name: 'lakehouse-secret-key' + properties: { + value: lakehouseSecretKey + contentType: 'text/plain' + attributes: { + enabled: true + } + } +} + output keyVaultName string = keyVault.name output keyVaultUri string = keyVault.properties.vaultUri // URI only (no secret value) — safe to expose for Container App Key Vault references. 
#disable-next-line outputs-should-not-contain-secrets output lakehousePasswordSecretUri string = lakehousePasswordSecret.properties.secretUri + +// URI only (no secret value) — safe to expose for Container App Key Vault references. +#disable-next-line outputs-should-not-contain-secrets +output lakehouseSecretKeySecretUri string = lakehouseSecretKeySecret.properties.secretUri diff --git a/scripts/allow-current-ip-postgres.sh b/scripts/allow-current-ip-postgres.sh new file mode 100755 index 0000000..aec8513 --- /dev/null +++ b/scripts/allow-current-ip-postgres.sh @@ -0,0 +1,87 @@ +#!/usr/bin/env bash +set -euo pipefail + +script_dir="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" && pwd)" +repo_root="$(cd -- "${script_dir}/.." && pwd)" + +if ! command -v az >/dev/null 2>&1; then + echo "Azure CLI (az) is not installed." + exit 1 +fi + +if ! az account show >/dev/null 2>&1; then + echo "Azure CLI (az) is not authenticated; run az login first." + exit 1 +fi + +if command -v azd >/dev/null 2>&1; then + # Load only the non-secret azd values needed for this firewall rule. + while IFS= read -r line; do + [[ -z "${line}" ]] && continue + line="${line#export }" + [[ "${line}" != *=* ]] && continue + key="${line%%=*}" + value="${line#*=}" + value="${value#\"}" + value="${value%\"}" + case "${key}" in + AZURE_RESOURCE_GROUP) + [[ -z "${AZURE_RESOURCE_GROUP:-}" ]] && AZURE_RESOURCE_GROUP="${value}" + ;; + POSTGRES_SERVER_NAME) + [[ -z "${POSTGRES_SERVER_NAME:-}" ]] && POSTGRES_SERVER_NAME="${value}" + ;; + esac + done < <(cd "${repo_root}" && azd env get-values) +fi + +required_vars=( + AZURE_RESOURCE_GROUP + POSTGRES_SERVER_NAME +) + +for var_name in "${required_vars[@]}"; do + if [[ -z "${!var_name:-}" ]]; then + echo "Missing ${var_name}; set it or run this from an azd environment." 
+ exit 1 + fi +done + +rule_name="${FIREWALL_RULE_NAME:-AllowCurrentClientIpForTests}" +current_ip="${CURRENT_IP:-}" +if [[ -z "${current_ip}" ]]; then + current_ip="$(curl -fsS https://api.ipify.org)" +fi + +if [[ -z "${current_ip}" ]]; then + echo "Could not detect current public IP; set CURRENT_IP to override." + exit 1 +fi + +echo "Allowing current IP ${current_ip} on PostgreSQL server ${POSTGRES_SERVER_NAME}." + +if az postgres flexible-server firewall-rule show \ + --resource-group "${AZURE_RESOURCE_GROUP}" \ + --name "${POSTGRES_SERVER_NAME}" \ + --rule-name "${rule_name}" \ + --only-show-errors >/dev/null 2>&1; then + az postgres flexible-server firewall-rule update \ + --resource-group "${AZURE_RESOURCE_GROUP}" \ + --name "${POSTGRES_SERVER_NAME}" \ + --rule-name "${rule_name}" \ + --start-ip-address "${current_ip}" \ + --end-ip-address "${current_ip}" \ + --only-show-errors \ + --output none +else + az postgres flexible-server firewall-rule create \ + --resource-group "${AZURE_RESOURCE_GROUP}" \ + --name "${POSTGRES_SERVER_NAME}" \ + --rule-name "${rule_name}" \ + --start-ip-address "${current_ip}" \ + --end-ip-address "${current_ip}" \ + --only-show-errors \ + --output none +fi + +echo "Firewall rule ${rule_name} now allows ${current_ip}." 
diff --git a/src/lakehouse/__main__.py b/src/lakehouse/__main__.py index 2b74b8b..f674765 100644 --- a/src/lakehouse/__main__.py +++ b/src/lakehouse/__main__.py @@ -29,6 +29,7 @@ BasicAuthServerMiddlewareFactory, BearerAuthServerMiddlewareFactory, NoOpAuthHandler, + RequiredAuthServerMiddlewareFactory, ) from lakehouse.config import ServerConfig from lakehouse.health import BackgroundHealthPoller, HealthServer @@ -75,6 +76,7 @@ def build_server(config: ServerConfig) -> DuckDBFlightSqlServer: secret_key=config.secret_key, issuer=config.jwt_issuer, ) + middleware["required-auth"] = RequiredAuthServerMiddlewareFactory() logger.info("Authentication enabled (username=%s)", config.username) else: logger.warning("No password configured — authentication is DISABLED") diff --git a/src/lakehouse/_azd_env.py b/src/lakehouse/_azd_env.py new file mode 100644 index 0000000..fc95344 --- /dev/null +++ b/src/lakehouse/_azd_env.py @@ -0,0 +1,312 @@ +"""Local Azure Developer CLI environment discovery helpers. + +The azd environment directory is intentionally ignored by git because it +contains local deployment state and may include secrets. This module reads +that state only to derive the non-secret values needed by local e2e tests. +""" + +from __future__ import annotations + +import json +import os +import shlex +import shutil +import subprocess +from collections.abc import Callable, Mapping, MutableMapping, Sequence +from dataclasses import dataclass +from pathlib import Path + +__all__ = [ + "CONTAINER_APP_AZD_ENV_KEYS", + "DUCKLAKE_REQUIRED_ENV", + "AzdValues", + "EnvResolution", + "apply_env_resolution", + "load_azd_values", + "parse_azd_env_lines", + "postgres_firewall_hint", + "resolve_container_app_env", + "resolve_ducklake_env", +] + +DUCKLAKE_REQUIRED_ENV: tuple[str, ...] 
= ( + "DUCKLAKE_PG_HOST", + "DUCKLAKE_PG_DATABASE", + "DUCKLAKE_PG_USER", + "DUCKLAKE_AZURE_STORAGE_ACCOUNT", + "DUCKLAKE_DATA_PATH", +) + +DUCKLAKE_AZD_ENV_MAP: Mapping[str, str] = { + "POSTGRES_FQDN": "DUCKLAKE_PG_HOST", + "POSTGRES_DATABASE_NAME": "DUCKLAKE_PG_DATABASE", + "POSTGRES_ENTRA_ADMIN_PRINCIPAL_NAME": "DUCKLAKE_PG_USER", + "STORAGE_ACCOUNT_NAME": "DUCKLAKE_AZURE_STORAGE_ACCOUNT", + "DUCKLAKE_DATA_PATH": "DUCKLAKE_DATA_PATH", +} + +CONTAINER_APP_AZD_ENV_KEYS: tuple[str, ...] = ( + "AZURE_RESOURCE_GROUP", + "CONTAINER_APP_NAME", + "KEY_VAULT_NAME", +) + +_CommandRunner = Callable[[Sequence[str], Path], subprocess.CompletedProcess[str]] + + +@dataclass(frozen=True) +class EnvResolution: + """Resolved environment values plus missing-key metadata.""" + + values: dict[str, str] + missing: tuple[str, ...] + source: str + azd_environment: str | None = None + + @property + def ready(self) -> bool: + """Whether all required values were discovered.""" + return not self.missing + + def skip_reason(self, prefix: str = "Required environment values missing") -> str: + """Return a secret-safe pytest skip reason.""" + if self.ready: + return "" + missing = ", ".join(self.missing) + detail = ( + "set explicit environment variables or run `azd up` to populate " + "the local .azure environment" + ) + if self.azd_environment: + detail = f"{detail} ({self.azd_environment})" + return f"{prefix}: {missing}; {detail}." 
+ + +@dataclass(frozen=True) +class AzdValues: + """Raw azd values plus their source metadata.""" + + values: dict[str, str] + source: str + environment: str | None + + +def parse_azd_env_lines(lines: str) -> dict[str, str]: + """Parse azd KEY=VALUE lines without expanding shell syntax.""" + values: dict[str, str] = {} + for raw_line in lines.splitlines(): + line = raw_line.strip() + if not line or line.startswith("#"): + continue + if line.startswith("export "): + line = line[len("export ") :].lstrip() + if "=" not in line: + continue + key, raw_value = line.split("=", 1) + key = key.strip() + if not key: + continue + values[key] = _parse_env_value(raw_value.strip()) + return values + + +def load_azd_values( + repo_root: Path | None = None, + *, + use_azd_cli: bool = True, + command_runner: _CommandRunner | None = None, +) -> AzdValues: + """Load local azd values from the CLI, falling back to .azure files.""" + root = find_repo_root(repo_root) + runner = command_runner or _run_command + default_environment = _read_default_environment(root) + + if use_azd_cli: + cli_values = _load_cli_values(root, runner) + if cli_values.values: + environment = cli_values.values.get("AZURE_ENV_NAME") or default_environment + return AzdValues(cli_values.values, cli_values.source, environment) + + return _load_file_values(root, default_environment) + + +def resolve_ducklake_env( + repo_root: Path | None = None, + environ: Mapping[str, str] | None = None, + *, + use_azd_cli: bool = True, + command_runner: _CommandRunner | None = None, +) -> EnvResolution: + """Resolve DuckLake e2e variables from explicit env and azd state.""" + env = os.environ if environ is None else environ + azd = load_azd_values( + repo_root, + use_azd_cli=use_azd_cli, + command_runner=command_runner, + ) + azd_values = _map_azd_to_ducklake(azd.values) + values = _resolve_values(DUCKLAKE_REQUIRED_ENV, env, azd_values) + missing = tuple(name for name in DUCKLAKE_REQUIRED_ENV if not values.get(name)) + source = 
_source_label(DUCKLAKE_REQUIRED_ENV, env, azd_values, azd.source) + return EnvResolution(values, missing, source, azd.environment) + + +def resolve_container_app_env( + repo_root: Path | None = None, + environ: Mapping[str, str] | None = None, + *, + use_azd_cli: bool = True, + command_runner: _CommandRunner | None = None, +) -> EnvResolution: + """Resolve non-secret azd values needed to find the deployed Container App.""" + env = os.environ if environ is None else environ + azd = load_azd_values( + repo_root, + use_azd_cli=use_azd_cli, + command_runner=command_runner, + ) + values = _resolve_values(CONTAINER_APP_AZD_ENV_KEYS, env, azd.values) + missing = tuple(name for name in CONTAINER_APP_AZD_ENV_KEYS if not values.get(name)) + source = _source_label(CONTAINER_APP_AZD_ENV_KEYS, env, azd.values, azd.source) + return EnvResolution(values, missing, source, azd.environment) + + +def apply_env_resolution( + resolution: EnvResolution, + environ: MutableMapping[str, str] | None = None, +) -> None: + """Apply discovered values, treating blank explicit env vars as unset.""" + env = os.environ if environ is None else environ + for name, value in resolution.values.items(): + if not env.get(name): + env[name] = value + + +def postgres_firewall_hint() -> str: + """Return a secret-safe hint for Azure PostgreSQL network failures.""" + return ( + "If Azure PostgreSQL is unreachable, allow this machine's current public IP " + "on the PostgreSQL Flexible Server firewall. Prefer running " + "`./scripts/allow-current-ip-postgres.sh`, or create an equivalent temporary " + "rule with `az postgres flexible-server firewall-rule create`." 
+ ) + + +def find_repo_root(start: Path | None = None) -> Path: + """Find the nearest ancestor that looks like this azd project root.""" + current = (start or Path.cwd()).resolve() + if current.is_file(): + current = current.parent + for candidate in (current, *current.parents): + if (candidate / "azure.yaml").is_file(): + return candidate + return current + + +def _parse_env_value(raw_value: str) -> str: + if raw_value == "": + return "" + try: + parts = shlex.split(raw_value, posix=True) + except ValueError: + return raw_value.strip().strip("\"'") + if not parts: + return "" + return " ".join(parts) + + +def _run_command(command: Sequence[str], cwd: Path) -> subprocess.CompletedProcess[str]: + return subprocess.run( + command, + cwd=cwd, + capture_output=True, + check=False, + text=True, + ) + + +def _load_cli_values(root: Path, runner: _CommandRunner) -> AzdValues: + if runner is _run_command and shutil.which("azd") is None: + return AzdValues({}, "missing", None) + try: + result = runner(("azd", "env", "get-values"), root) + except OSError: + return AzdValues({}, "missing", None) + if result.returncode != 0: + return AzdValues({}, "missing", None) + values = parse_azd_env_lines(result.stdout) + return AzdValues(values, "azd-cli", values.get("AZURE_ENV_NAME")) + + +def _load_file_values(root: Path, default_environment: str | None) -> AzdValues: + for environment in _candidate_environments(root, default_environment): + env_file = root / ".azure" / environment / ".env" + if env_file.is_file(): + return AzdValues( + parse_azd_env_lines(env_file.read_text(encoding="utf-8")), + "azd-file", + environment, + ) + return AzdValues({}, "missing", default_environment) + + +def _read_default_environment(root: Path) -> str | None: + config_file = root / ".azure" / "config.json" + if not config_file.is_file(): + return None + try: + raw_config = json.loads(config_file.read_text(encoding="utf-8")) + except (OSError, json.JSONDecodeError): + return None + environment = 
raw_config.get("defaultEnvironment") + if isinstance(environment, str) and environment: + return environment + return None + + +def _candidate_environments(root: Path, default_environment: str | None) -> tuple[str, ...]: + names: list[str] = [] + if default_environment: + names.append(default_environment) + azure_dir = root / ".azure" + if azure_dir.is_dir(): + for child in sorted(azure_dir.iterdir()): + if child.is_dir() and (child / ".env").is_file() and child.name not in names: + names.append(child.name) + return tuple(names) + + +def _map_azd_to_ducklake(azd_values: Mapping[str, str]) -> dict[str, str]: + return { + ducklake_key: azd_values[azd_key] + for azd_key, ducklake_key in DUCKLAKE_AZD_ENV_MAP.items() + if azd_values.get(azd_key) + } + + +def _resolve_values( + required_keys: Sequence[str], + environ: Mapping[str, str], + azd_values: Mapping[str, str], +) -> dict[str, str]: + values: dict[str, str] = {} + for name in required_keys: + if environ.get(name): + values[name] = environ[name] + elif azd_values.get(name): + values[name] = azd_values[name] + return values + + +def _source_label( + required_keys: Sequence[str], + environ: Mapping[str, str], + azd_values: Mapping[str, str], + azd_source: str, +) -> str: + sources: list[str] = [] + if any(environ.get(name) for name in required_keys): + sources.append("environment") + if azd_source != "missing" and any(azd_values.get(name) for name in required_keys): + sources.append(azd_source) + return "+".join(sources) if sources else "missing" diff --git a/src/lakehouse/auth.py b/src/lakehouse/auth.py index 3d6bb90..ac1698c 100644 --- a/src/lakehouse/auth.py +++ b/src/lakehouse/auth.py @@ -7,6 +7,8 @@ response headers. * **BearerAuthServerMiddlewareFactory** — validates incoming ``Bearer`` JWT tokens. +* **RequiredAuthServerMiddlewareFactory** — rejects calls with no supported + authorization header when auth is enabled. * **AccessLogMiddlewareFactory** — logs every RPC call with method name and elapsed time. 
""" @@ -286,6 +288,36 @@ def start_call( return BearerAuthServerMiddleware(payload=payload) +# ═══════════════════════════════════════════════════════════════════════════ +# RequiredAuthServerMiddlewareFactory +# ═══════════════════════════════════════════════════════════════════════════ + + +class RequiredAuthServerMiddlewareFactory(flight.ServerMiddlewareFactory): + """Reject calls that do not present a supported Authorization header.""" + + def start_call( + self, + info: flight.CallInfo, + headers: dict[str, list[str]], + ) -> None: + """Require either Basic or Bearer auth on every Flight RPC. + + Args: + info: RPC call information. + headers: Incoming request headers. + + Raises: + flight.FlightUnauthenticatedError: If the request has no supported + Authorization header. + """ + auth_header = _get_header(headers, "authorization") + if auth_header is None: + raise flight.FlightUnauthenticatedError("Authorization header is required") + if not auth_header.startswith(("Basic ", "Bearer ")): + raise flight.FlightUnauthenticatedError("Unsupported authorization scheme") + + # ═══════════════════════════════════════════════════════════════════════════ # AccessLogMiddlewareFactory # ═══════════════════════════════════════════════════════════════════════════ diff --git a/tests/jdbc/run_local_jdbc_tests.py b/tests/jdbc/run_local_jdbc_tests.py index 935f3a2..8b05e6e 100644 --- a/tests/jdbc/run_local_jdbc_tests.py +++ b/tests/jdbc/run_local_jdbc_tests.py @@ -10,13 +10,7 @@ import duckdb -_REQUIRED_ENV = ( - "DUCKLAKE_PG_HOST", - "DUCKLAKE_PG_DATABASE", - "DUCKLAKE_PG_USER", - "DUCKLAKE_AZURE_STORAGE_ACCOUNT", - "DUCKLAKE_DATA_PATH", -) +from lakehouse._azd_env import apply_env_resolution, postgres_firewall_hint, resolve_ducklake_env def _free_port() -> int: @@ -26,10 +20,11 @@ def _free_port() -> int: def _require_environment() -> None: - missing = [name for name in _REQUIRED_ENV if not os.environ.get(name)] - if missing: - joined = ", ".join(missing) - raise 
SystemExit(f"Missing DuckLake environment variables: {joined}") + resolution = resolve_ducklake_env() + apply_env_resolution(resolution) + + if resolution.missing: + raise SystemExit(resolution.skip_reason("Missing DuckLake environment variables")) if shutil.which("mvn") is None: raise SystemExit("Maven (mvn) not found on PATH") @@ -69,10 +64,10 @@ def main() -> int: except duckdb.Error as exc: message = ( "Failed to bootstrap the local DuckLake catalog. " - "Verify the DUCKLAKE_* settings and that this machine can reach the configured " - "PostgreSQL server." + "Verify the DUCKLAKE_* settings or azd environment outputs. " + f"{postgres_firewall_hint()}" ) - raise SystemExit(f"{message}\n{exc}") from exc + raise SystemExit(f"{message}\nError type: {type(exc).__name__}") from None thread = threading.Thread(target=server.serve, daemon=True) thread.start() diff --git a/tests/test_auth.py b/tests/test_auth.py index 22eb8c4..06f7bbc 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -18,6 +18,7 @@ BearerAuthServerMiddleware, BearerAuthServerMiddlewareFactory, NoOpAuthHandler, + RequiredAuthServerMiddlewareFactory, _get_header, _parse_basic_header, ) @@ -63,6 +64,12 @@ def access_factory(): return AccessLogMiddlewareFactory() +@pytest.fixture +def required_auth_factory(): + """RequiredAuthServerMiddlewareFactory instance.""" + return RequiredAuthServerMiddlewareFactory() + + @pytest.fixture def call_info(): """A minimal CallInfo for testing.""" @@ -364,6 +371,39 @@ def test_payload_accessible(self): assert mw.payload["role"] == "admin" +# ═══════════════════════════════════════════════════════════════════════════ +# RequiredAuthServerMiddlewareFactory +# ═══════════════════════════════════════════════════════════════════════════ +class TestRequiredAuthFactoryStartCall: + """Tests for RequiredAuthServerMiddlewareFactory.start_call().""" + + def test_missing_auth_header_rejected(self, required_auth_factory, call_info): + """No authorization header is rejected when 
auth is required.""" + with pytest.raises( + flight.FlightUnauthenticatedError, + match="Authorization header is required", + ): + required_auth_factory.start_call(call_info, {}) + + def test_basic_auth_header_allowed(self, required_auth_factory, call_info): + """Basic auth is allowed for BasicAuthServerMiddlewareFactory to validate.""" + headers = _make_basic_header(USERNAME, PASSWORD) + assert required_auth_factory.start_call(call_info, headers) is None + + def test_bearer_auth_header_allowed(self, required_auth_factory, call_info): + """Bearer auth is allowed for BearerAuthServerMiddlewareFactory to validate.""" + headers = _make_bearer_header("jwt") + assert required_auth_factory.start_call(call_info, headers) is None + + def test_unsupported_auth_scheme_rejected(self, required_auth_factory, call_info): + """Unsupported authorization schemes are rejected.""" + with pytest.raises( + flight.FlightUnauthenticatedError, + match="Unsupported authorization scheme", + ): + required_auth_factory.start_call(call_info, {"authorization": ["Token abc"]}) + + # ═══════════════════════════════════════════════════════════════════════════ # AccessLogMiddlewareFactory # ═══════════════════════════════════════════════════════════════════════════ diff --git a/tests/test_azd_env.py b/tests/test_azd_env.py new file mode 100644 index 0000000..5ed99a7 --- /dev/null +++ b/tests/test_azd_env.py @@ -0,0 +1,167 @@ +from __future__ import annotations + +import subprocess + +from lakehouse._azd_env import ( + DUCKLAKE_REQUIRED_ENV, + apply_env_resolution, + load_azd_values, + parse_azd_env_lines, + postgres_firewall_hint, + resolve_container_app_env, + resolve_ducklake_env, +) + + +def _write_azd_env(tmp_path, body: str, env_name: str = "lakehouse-dev") -> None: + azure_dir = tmp_path / ".azure" + env_dir = azure_dir / env_name + env_dir.mkdir(parents=True) + (azure_dir / "config.json").write_text( + f'{{"version": 1, "defaultEnvironment": "{env_name}"}}', + encoding="utf-8", + ) + 
(env_dir / ".env").write_text(body, encoding="utf-8") + (tmp_path / "azure.yaml").write_text("name: lakehouse\n", encoding="utf-8") + + +def _azd_ducklake_body() -> str: + return """ +AZURE_ENV_NAME="lakehouse-dev" +POSTGRES_FQDN="pg.example.postgres.database.azure.com" +POSTGRES_DATABASE_NAME="ducklake" +POSTGRES_ENTRA_ADMIN_PRINCIPAL_NAME="user#EXT#@example.onmicrosoft.com" +STORAGE_ACCOUNT_NAME="stlakehouse" +DUCKLAKE_DATA_PATH="az://lakehouse/data/" +AZURE_RESOURCE_GROUP="rg-lakehouse" +CONTAINER_APP_NAME="ca-lakehouse" +KEY_VAULT_NAME="kv-lakehouse" +POSTGRES_ADMIN_PASSWORD="super-secret-password" +""" + + +def test_parse_azd_env_lines_handles_quotes_and_hashes(): + values = parse_azd_env_lines( + """ +export AZURE_ENV_NAME="lakehouse-dev" +POSTGRES_ENTRA_ADMIN_PRINCIPAL_NAME='user#EXT#@example.onmicrosoft.com' +DUCKLAKE_DATA_PATH=az://lakehouse/data/ +""" + ) + + assert values["AZURE_ENV_NAME"] == "lakehouse-dev" + assert values["POSTGRES_ENTRA_ADMIN_PRINCIPAL_NAME"] == ("user#EXT#@example.onmicrosoft.com") + assert values["DUCKLAKE_DATA_PATH"] == "az://lakehouse/data/" + + +def test_resolve_ducklake_env_maps_azd_outputs_from_default_environment(tmp_path): + _write_azd_env(tmp_path, _azd_ducklake_body()) + + resolution = resolve_ducklake_env(tmp_path, environ={}, use_azd_cli=False) + + assert resolution.ready + assert resolution.missing == () + assert resolution.source == "azd-file" + assert resolution.azd_environment == "lakehouse-dev" + assert resolution.values == { + "DUCKLAKE_PG_HOST": "pg.example.postgres.database.azure.com", + "DUCKLAKE_PG_DATABASE": "ducklake", + "DUCKLAKE_PG_USER": "user#EXT#@example.onmicrosoft.com", + "DUCKLAKE_AZURE_STORAGE_ACCOUNT": "stlakehouse", + "DUCKLAKE_DATA_PATH": "az://lakehouse/data/", + } + + +def test_resolve_ducklake_env_prefers_explicit_environment_values(tmp_path): + _write_azd_env(tmp_path, _azd_ducklake_body()) + + resolution = resolve_ducklake_env( + tmp_path, + environ={"DUCKLAKE_PG_HOST": 
"manual.postgres.database.azure.com"}, + use_azd_cli=False, + ) + + assert resolution.ready + assert resolution.source == "environment+azd-file" + assert resolution.values["DUCKLAKE_PG_HOST"] == "manual.postgres.database.azure.com" + assert resolution.values["DUCKLAKE_PG_DATABASE"] == "ducklake" + + +def test_apply_env_resolution_treats_blank_environment_values_as_unset(tmp_path): + _write_azd_env(tmp_path, _azd_ducklake_body()) + environ = {"DUCKLAKE_PG_HOST": ""} + + resolution = resolve_ducklake_env(tmp_path, environ=environ, use_azd_cli=False) + apply_env_resolution(resolution, environ) + + assert resolution.ready + assert environ["DUCKLAKE_PG_HOST"] == "pg.example.postgres.database.azure.com" + assert environ["DUCKLAKE_PG_DATABASE"] == "ducklake" + + +def test_resolve_ducklake_env_reports_missing_config_without_crashing(tmp_path): + (tmp_path / "azure.yaml").write_text("name: lakehouse\n", encoding="utf-8") + + resolution = resolve_ducklake_env(tmp_path, environ={}, use_azd_cli=False) + + assert not resolution.ready + assert resolution.missing == DUCKLAKE_REQUIRED_ENV + assert "DuckLake env vars missing" in resolution.skip_reason("DuckLake env vars missing") + + +def test_resolve_ducklake_env_does_not_leak_secret_values(tmp_path): + _write_azd_env(tmp_path, _azd_ducklake_body()) + + resolution = resolve_ducklake_env(tmp_path, environ={}, use_azd_cli=False) + + assert "super-secret-password" not in repr(resolution) + assert "super-secret-password" not in resolution.skip_reason("DuckLake env vars missing") + + +def test_azd_cli_values_are_preferred_before_file_values(tmp_path): + _write_azd_env(tmp_path, _azd_ducklake_body()) + + def runner(command, cwd): + return subprocess.CompletedProcess( + command, + 0, + stdout=""" +AZURE_ENV_NAME="lakehouse-cli" +POSTGRES_FQDN="pg.cli.postgres.database.azure.com" +POSTGRES_DATABASE_NAME="ducklake_cli" +POSTGRES_ENTRA_ADMIN_PRINCIPAL_NAME="cli-user@example.com" +STORAGE_ACCOUNT_NAME="stcli" 
+DUCKLAKE_DATA_PATH="az://cli/data/" +""", + stderr="", + ) + + values = load_azd_values(tmp_path, command_runner=runner) + resolution = resolve_ducklake_env(tmp_path, environ={}, command_runner=runner) + + assert values.source == "azd-cli" + assert values.environment == "lakehouse-cli" + assert resolution.source == "azd-cli" + assert resolution.values["DUCKLAKE_PG_HOST"] == "pg.cli.postgres.database.azure.com" + + +def test_resolve_container_app_env_maps_non_secret_azd_outputs(tmp_path): + _write_azd_env(tmp_path, _azd_ducklake_body()) + + resolution = resolve_container_app_env(tmp_path, environ={}, use_azd_cli=False) + + assert resolution.ready + assert resolution.values == { + "AZURE_RESOURCE_GROUP": "rg-lakehouse", + "CONTAINER_APP_NAME": "ca-lakehouse", + "KEY_VAULT_NAME": "kv-lakehouse", + } + + +def test_postgres_firewall_hint_points_to_ip_allowlist_without_values(): + hint = postgres_firewall_hint() + + assert "./scripts/allow-current-ip-postgres.sh" in hint + assert "az postgres flexible-server firewall-rule create" in hint + assert "pg.example.postgres.database.azure.com" not in hint + assert "super-secret-password" not in hint diff --git a/tests/test_e2e.py b/tests/test_e2e.py index 3eff1ed..b7e20cd 100644 --- a/tests/test_e2e.py +++ b/tests/test_e2e.py @@ -49,6 +49,7 @@ AccessLogMiddlewareFactory, BasicAuthServerMiddlewareFactory, BearerAuthServerMiddlewareFactory, + RequiredAuthServerMiddlewareFactory, ) from lakehouse.health import BackgroundHealthPoller, HealthServer from lakehouse.security import hash_password @@ -153,6 +154,7 @@ def auth_server(): "bearer-auth": BearerAuthServerMiddlewareFactory( secret_key=_TEST_SECRET, ), + "required-auth": RequiredAuthServerMiddlewareFactory(), } srv = DuckDBFlightSqlServer( @@ -363,7 +365,7 @@ def test_wrong_password_rejected(self, auth_server): """Connection with wrong password is rejected.""" _srv, port = auth_server bad_token = base64.b64encode(f"{_TEST_USERNAME}:wrong-password".encode()).decode() - with 
pytest.raises(Exception): # noqa: B017 + with pytest.raises(Exception, match=r"UNAUTHENTICATED|Invalid credentials"): conn = flightsql.connect( f"grpc://127.0.0.1:{port}", db_kwargs={ @@ -377,6 +379,19 @@ cursor.close() conn.close() + def test_missing_auth_rejected(self, auth_server): + """Connection without an auth header is rejected.""" + _srv, port = auth_server + with pytest.raises( + Exception, + match=r"UNAUTHENTICATED|Authorization header is required", + ): + conn = flightsql.connect(f"grpc://127.0.0.1:{port}") + cursor = conn.execute("SELECT 1") + cursor.fetchall() + cursor.close() + conn.close() + # ═══════════════════════════════════════════════════════════════════════════ # 7.5 — DDL / DML via Flight SQL diff --git a/tests/test_e2e_ducklake.py b/tests/test_e2e_ducklake.py index 08ca1cf..0204339 100644 --- a/tests/test_e2e_ducklake.py +++ b/tests/test_e2e_ducklake.py @@ -2,8 +2,9 @@ These tests start a Flight SQL server wired to a real DuckLake catalog (PostgreSQL on Azure + Parquet on Azure Blob Storage) and exercise -DDL/DML through ADBC. They are **skipped** unless the required -environment variables are set. +DDL/DML through ADBC. They use explicit ``DUCKLAKE_*`` environment +variables first, then fall back to local azd deployment outputs from +``.azure/<env>/.env``. 
Required environment variables ------------------------------ import time import adbc_driver_flightsql.dbapi as flightsql +import duckdb import pyarrow as pa import pytest +from lakehouse._azd_env import apply_env_resolution, postgres_firewall_hint, resolve_ducklake_env from lakehouse.server import DuckDBFlightSqlServer # ─────────────────────────────────────────────────────────────────────────── # Environment & skip logic # ─────────────────────────────────────────────────────────────────────────── -_REQUIRED_ENV = ( - "DUCKLAKE_PG_HOST", - "DUCKLAKE_PG_DATABASE", - "DUCKLAKE_PG_USER", - "DUCKLAKE_AZURE_STORAGE_ACCOUNT", - "DUCKLAKE_DATA_PATH", -) +_resolution = resolve_ducklake_env() +apply_env_resolution(_resolution) -_missing = [v for v in _REQUIRED_ENV if not os.environ.get(v)] +_missing = _resolution.missing pytestmark = pytest.mark.skipif( bool(_missing), - reason=f"DuckLake env vars missing: {', '.join(_missing)}", + reason=_resolution.skip_reason("DuckLake env vars missing"), ) @@ -142,7 +140,20 @@ def ducklake_server(): # DuckLake init: extensions, secrets, ATTACH, USE token_mgr = PostgresTokenManager(srv._db, config) token = token_mgr.get_initial_token() - initialize_ducklake(srv._db, config, token=token) + ducklake_error_type: str | None = None + try: + initialize_ducklake(srv._db, config, token=token) + except duckdb.Error as exc: + ducklake_error_type = type(exc).__name__ + + if ducklake_error_type is not None: + token_mgr.stop() + srv.shutdown() + message = ( + "Failed to bootstrap the real Azure DuckLake catalog " + f"({ducklake_error_type}). {postgres_firewall_hint()}" + ) + pytest.fail(message, pytrace=False) _start_server(srv) yield srv, port diff --git a/tests/test_jdbc.py b/tests/test_jdbc.py index 4f141b4..d76b96c 100644 --- a/tests/test_jdbc.py +++ b/tests/test_jdbc.py @@ -5,7 +5,7 @@ ``-Dflight.url=grpc://127.0.0.1:<port>``. Skipped when: -* DuckLake env vars are missing (same gate as ``test_e2e_ducklake.py``). 
+* DuckLake env vars/azd outputs are missing (same gate as ``test_e2e_ducklake.py``). * Maven (``mvn``) is not on ``$PATH``. """ @@ -18,23 +18,25 @@ import threading import time +import duckdb import pytest +from lakehouse._azd_env import apply_env_resolution, postgres_firewall_hint, resolve_ducklake_env + # ─────────────────────────────────────────────────────────────────────────── # Skip conditions # ─────────────────────────────────────────────────────────────────────────── -_REQUIRED_ENV = ( - "DUCKLAKE_PG_HOST", - "DUCKLAKE_PG_DATABASE", - "DUCKLAKE_PG_USER", - "DUCKLAKE_AZURE_STORAGE_ACCOUNT", - "DUCKLAKE_DATA_PATH", -) -_missing = [v for v in _REQUIRED_ENV if not os.environ.get(v)] +_resolution = resolve_ducklake_env() +apply_env_resolution(_resolution) + +_missing = _resolution.missing pytestmark = [ - pytest.mark.skipif(bool(_missing), reason=f"DuckLake env vars missing: {', '.join(_missing)}"), + pytest.mark.skipif( + bool(_missing), + reason=_resolution.skip_reason("DuckLake env vars missing"), + ), pytest.mark.skipif(shutil.which("mvn") is None, reason="Maven (mvn) not found"), ] @@ -71,7 +73,20 @@ def ducklake_server(): # type: ignore[no-redef] srv = DuckDBFlightSqlServer(location=location, db_path=":memory:", ducklake_alias=alias) token_mgr = PostgresTokenManager(srv._db, config) token = token_mgr.get_initial_token() - initialize_ducklake(srv._db, config, token=token) + ducklake_error_type: str | None = None + try: + initialize_ducklake(srv._db, config, token=token) + except duckdb.Error as exc: + ducklake_error_type = type(exc).__name__ + + if ducklake_error_type is not None: + token_mgr.stop() + srv.shutdown() + message = ( + "Failed to bootstrap the real Azure DuckLake catalog " + f"({ducklake_error_type}). 
{postgres_firewall_hint()}" + ) + pytest.fail(message, pytrace=False) t = threading.Thread(target=srv.serve, daemon=True) t.start() diff --git a/tests/test_live_azure_backend.py b/tests/test_live_azure_backend.py new file mode 100644 index 0000000..895ed94 --- /dev/null +++ b/tests/test_live_azure_backend.py @@ -0,0 +1,182 @@ +"""Opt-in tests against the deployed Azure Container App backend. + +The default live check authenticates with PyArrow's Basic-token handshake and +then runs the query through ADBC with the returned Bearer token. That proves +the deployed backend, Key Vault password, TLS endpoint, and Bearer auth path are +working. + +The separate ADBC Basic-auth check is gated by ``LAKEHOUSE_LIVE_BACKEND_ADBC_BASIC`` +and is marked ``xfail`` because ADBC's Basic-to-Bearer exchange is the currently +tracked client path that fails against the deployed Container App. A result like +``1 passed, 1 xfailed`` means the supported bearer smoke test passed and the +known ADBC Basic issue reproduced as expected. 
+""" + +from __future__ import annotations + +import contextlib +import os +import re +import shutil +import subprocess + +import pytest + +from lakehouse._azd_env import resolve_container_app_env + +_LIVE_BACKEND_FLAG = "LAKEHOUSE_LIVE_BACKEND" +_LIVE_BACKEND_ADBC_BASIC_FLAG = "LAKEHOUSE_LIVE_BACKEND_ADBC_BASIC" + +pytestmark = pytest.mark.skipif( + os.environ.get(_LIVE_BACKEND_FLAG) != "1", + reason=f"set {_LIVE_BACKEND_FLAG}=1 to query the deployed Azure Container App", +) + + +def _run_az(args: list[str], purpose: str) -> str: + if shutil.which("az") is None: + pytest.skip("Azure CLI (az) not found") + result = subprocess.run( + ["az", *args], + capture_output=True, + check=False, + text=True, + timeout=60, + ) + if result.returncode != 0: + raise RuntimeError(f"Azure CLI failed while {purpose} (exit {result.returncode})") + value = result.stdout.strip() + if not value: + raise RuntimeError(f"Azure CLI returned no value while {purpose}") + return value + + +def _discover_endpoint(values: dict[str, str]) -> str: + fqdn = _run_az( + [ + "containerapp", + "show", + "-g", + values["AZURE_RESOURCE_GROUP"], + "-n", + values["CONTAINER_APP_NAME"], + "--query", + "properties.configuration.ingress.fqdn", + "-o", + "tsv", + ], + "discovering the Container App endpoint", + ) + return f"grpc+tls://{fqdn}:443" + + +def _read_password(values: dict[str, str]) -> str: + return _run_az( + [ + "keyvault", + "secret", + "show", + "--vault-name", + values["KEY_VAULT_NAME"], + "--name", + "lakehouse-password", + "--query", + "value", + "-o", + "tsv", + ], + "reading the lakehouse-password Key Vault secret", + ) + + +def _redact_auth_material(message: str, *secrets: str) -> str: + message = re.sub(r"Basic\s+[A-Za-z0-9+/=_-]+", "Basic ", message) + message = re.sub(r"Bearer\s+[A-Za-z0-9._~+/=-]+", "Bearer ", message) + for secret in secrets: + if secret: + message = message.replace(secret, "") + return message + + +def _bootstrap_bearer_header(endpoint: str, username: str, password: 
str) -> str: + import pyarrow.flight as flight + + client = flight.connect(endpoint) + header_name, header_value = client.authenticate_basic_token(username, password) + if isinstance(header_name, bytes): + header_name = header_name.decode() + if isinstance(header_value, bytes): + header_value = header_value.decode() + if header_name.lower() != "authorization" or not header_value.startswith("Bearer "): + raise RuntimeError("Basic auth did not return a Bearer authorization header") + return header_value + + +def _connect_with_pyarrow_bootstrapped_bearer(endpoint: str, password: str): + import adbc_driver_flightsql.dbapi as flightsql + from adbc_driver_flightsql import DatabaseOptions + + bearer_header = _bootstrap_bearer_header(endpoint, "lakehouse", password) + return flightsql.connect( + endpoint, + db_kwargs={DatabaseOptions.AUTHORIZATION_HEADER.value: bearer_header}, + ) + + +def _connect_with_adbc_basic(endpoint: str, password: str): + import base64 + + import adbc_driver_flightsql.dbapi as flightsql + from adbc_driver_flightsql import DatabaseOptions + + token = base64.b64encode(f"lakehouse:{password}".encode()).decode() + return flightsql.connect( + endpoint, + db_kwargs={DatabaseOptions.AUTHORIZATION_HEADER.value: f"Basic {token}"}, + ) + + +def _run_live_query(connect): + resolution = resolve_container_app_env() + if not resolution.ready: + pytest.skip(resolution.skip_reason("Azure Container App azd outputs missing")) + + conn = None + cursor = None + failure: str | None = None + password = "" + try: + endpoint = _discover_endpoint(resolution.values) + password = _read_password(resolution.values) + conn = connect(endpoint, password) + cursor = conn.execute("SELECT 1 AS value") + assert cursor.fetchall() == [(1,)] + except Exception as exc: + detail = _redact_auth_material(str(exc), password)[:500] + failure = f"{type(exc).__name__}: {detail}" + finally: + if cursor is not None: + with contextlib.suppress(Exception): + cursor.close() + if conn is not None: + 
with contextlib.suppress(Exception): + conn.close() + + if failure is not None: + pytest.fail(f"Live Azure backend query failed ({failure})", pytrace=False) + + +def test_deployed_container_app_accepts_pyarrow_bootstrapped_bearer_query(): + _run_live_query(_connect_with_pyarrow_bootstrapped_bearer) + + +@pytest.mark.skipif( + os.environ.get(_LIVE_BACKEND_ADBC_BASIC_FLAG) != "1", + reason=f"set {_LIVE_BACKEND_ADBC_BASIC_FLAG}=1 to exercise ADBC Basic auth", +) +@pytest.mark.xfail( + reason="ADBC Basic-to-Bearer handshake currently fails against the deployed Container App", + strict=False, +) +def test_deployed_container_app_accepts_adbc_basic_query(): + _run_live_query(_connect_with_adbc_basic) diff --git a/tests/test_main.py b/tests/test_main.py index 585be13..f912da8 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -3,8 +3,13 @@ from __future__ import annotations import re +import socket +import threading +import time from unittest.mock import patch +import adbc_driver_flightsql.dbapi as flightsql +import pytest from typer.testing import CliRunner from lakehouse.__main__ import app, build_server @@ -18,6 +23,12 @@ def _strip_ansi(text: str) -> str: return re.sub(r"\x1b\[[0-9;]*m", "", text) +def _free_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind(("127.0.0.1", 0)) + return int(sock.getsockname()[1]) + + # ═══════════════════════════════════════════════════════════════════════════ # build_server # ═══════════════════════════════════════════════════════════════════════════ @@ -49,6 +60,38 @@ def test_auth_with_password(self): server = build_server(config) server.shutdown() + def test_auth_with_password_rejects_missing_credentials(self): + """build_server wires the production auth middleware stack.""" + port = _free_port() + config = ServerConfig( + host="127.0.0.1", + port=port, + database=":memory:", + password="test-password", + secret_key="test-key-at-least-32-bytes-long-x", + ) + server = 
build_server(config) + thread = threading.Thread(target=server.serve, daemon=True) + thread.start() + time.sleep(0.5) + + conn = None + cursor = None + try: + with pytest.raises( + Exception, + match=r"UNAUTHENTICATED|Authorization header is required", + ): + conn = flightsql.connect(f"grpc://127.0.0.1:{port}") + cursor = conn.execute("SELECT 1") + cursor.fetchall() + finally: + if cursor is not None: + cursor.close() + if conn is not None: + conn.close() + server.shutdown() + def test_custom_database(self, tmp_path): """build_server can create a DuckDB file-backed database.""" db_path = tmp_path / "test.duckdb"