diff --git a/.opencode/skills/data-parity/SKILL.md b/.opencode/skills/data-parity/SKILL.md index 2bb7fa5df..6bdd054d2 100644 --- a/.opencode/skills/data-parity/SKILL.md +++ b/.opencode/skills/data-parity/SKILL.md @@ -71,6 +71,19 @@ WHERE table_schema = 'mydb' AND table_name = 'orders' ORDER BY ordinal_position ``` +```sql +-- SQL Server / Fabric +SELECT c.name AS column_name, tp.name AS data_type, c.is_nullable, + dc.definition AS column_default +FROM sys.columns c +INNER JOIN sys.types tp ON c.user_type_id = tp.user_type_id +INNER JOIN sys.objects o ON c.object_id = o.object_id +INNER JOIN sys.schemas s ON o.schema_id = s.schema_id +LEFT JOIN sys.default_constraints dc ON c.default_object_id = dc.object_id +WHERE s.name = 'dbo' AND o.name = 'orders' +ORDER BY c.column_id +``` + ```sql -- ClickHouse DESCRIBE TABLE source_db.events @@ -409,3 +422,56 @@ Even when tables match perfectly, state what was checked: **Silently excluding auto-timestamp columns without asking the user** → Always present detected auto-timestamp columns (Step 4) and get explicit confirmation. In migration scenarios, `created_at` should be *identical* — excluding it silently hides real bugs. + +--- + +## SQL Server and Microsoft Fabric + +### Minimum Version Requirements + +| Component | Minimum Version | Why | +|---|---|---| +| **SQL Server** | 2022 (16.x) | `DATETRUNC()` used for date partitioning; `LEAST()`/`GREATEST()` used by Rust engine | +| **Azure SQL Database** | Any current version | Always has `DATETRUNC()` and `LEAST()` | +| **Microsoft Fabric** | Any current version | T-SQL surface includes all required functions | +| **mssql** (npm) | 12.0.0 | `ConnectionPool` isolation for concurrent connections, tedious 19 | +| **@azure/identity** (npm) | 4.0.0 | Required only for Azure AD authentication; tedious imports it internally | + +> **Note:** Date partitioning (`partition_column` + `partition_granularity`) uses `DATETRUNC()` which is **not available on SQL Server 2019 or earlier**. 
Basic diff operations (joindiff, hashdiff, profile) work on older versions. If you need partitioned diffs on SQL Server < 2022, use numeric or categorical partitioning instead. + +### Supported Configurations + +| Warehouse Type | Authentication | Notes | +|---|---|---| +| `sqlserver` / `mssql` | User/password or Azure AD | On-prem or Azure SQL. SQL Server 2022+ required for date partitioning. | +| `fabric` | Azure AD only | Microsoft Fabric SQL endpoint. Always uses TLS encryption. | + +### Connecting to Microsoft Fabric + +Fabric uses the same TDS protocol as SQL Server — no separate driver needed. Configuration: + +```yaml +type: "fabric" +host: "<workspace>-<endpoint-id>.datawarehouse.fabric.microsoft.com" +database: "<warehouse-name>" +authentication: "azure-active-directory-default" # recommended +``` + +Auth shorthands (mapped to full tedious type names): +- `CLI` or `default` → `azure-active-directory-default` +- `password` → `azure-active-directory-password` +- `service-principal` → `azure-active-directory-service-principal-secret` +- `msi` or `managed-identity` → `azure-active-directory-msi-vm` + +Full Azure AD authentication types: +- `azure-active-directory-default` — auto-discovers credentials via `DefaultAzureCredential` (recommended; works with `az login`) +- `azure-active-directory-password` — username/password with `azure_client_id` and `azure_tenant_id` +- `azure-active-directory-access-token` — pre-obtained token (does **not** auto-refresh) +- `azure-active-directory-service-principal-secret` — service principal with `azure_client_id`, `azure_client_secret`, `azure_tenant_id` +- `azure-active-directory-msi-vm` / `azure-active-directory-msi-app-service` — managed identity + +### Algorithm Behavior + +- **Same-warehouse** MSSQL or Fabric → `joindiff` (single FULL OUTER JOIN, most efficient) +- **Cross-warehouse** MSSQL/Fabric ↔ other database → `hashdiff` (automatic when using `auto`) +- The Rust engine maps `sqlserver`/`mssql` to `tsql` dialect and `fabric` to `fabric` dialect — both 
generate valid T-SQL syntax with bracket quoting (`[schema].[table]`). diff --git a/bun.lock b/bun.lock index 1b06053a5..25e43809d 100644 --- a/bun.lock +++ b/bun.lock @@ -48,7 +48,7 @@ "@google-cloud/bigquery": "^8.0.0", "duckdb": "^1.0.0", "mongodb": "^6.0.0", - "mssql": "^11.0.0", + "mssql": "^12.0.0", "mysql2": "^3.0.0", "oracledb": "^6.0.0", "pg": "^8.0.0", @@ -1034,7 +1034,7 @@ "@techteamer/ocsp": ["@techteamer/ocsp@1.0.1", "", { "dependencies": { "asn1.js": "^5.4.1", "asn1.js-rfc2560": "^5.0.1", "asn1.js-rfc5280": "^3.0.0", "async": "^3.2.4", "simple-lru-cache": "^0.0.2" } }, "sha512-q4pW5wAC6Pc3JI8UePwE37CkLQ5gDGZMgjSX4MEEm4D4Di59auDQ8UNIDzC4gRnPNmmcwjpPxozq8p5pjiOmOw=="], - "@tediousjs/connection-string": ["@tediousjs/connection-string@0.5.0", "", {}, "sha512-7qSgZbincDDDFyRweCIEvZULFAw5iz/DeunhvuxpL31nfntX3P4Yd4HkHBRg9H8CdqY1e5WFN1PZIz/REL9MVQ=="], + "@tediousjs/connection-string": ["@tediousjs/connection-string@0.6.0", "", {}, "sha512-GxlsW354Vi6QqbUgdPyQVcQjI7cZBdGV5vOYVYuCVDTylx2wl3WHR2HlhcxxHTrMigbelpXsdcZso+66uxPfow=="], "@tokenizer/token": ["@tokenizer/token@0.3.0", "", {}, "sha512-OvjF+z51L3ov0OyAU0duzsYuvO01PH7x4t6DJx+guahgTnBHkhJdG7soQeTSFLWN3efnHyibZ4Z8l2EuWwJN3A=="], @@ -1902,7 +1902,7 @@ "msgpackr-extract": ["msgpackr-extract@3.0.3", "", { "dependencies": { "node-gyp-build-optional-packages": "5.2.2" }, "optionalDependencies": { "@msgpackr-extract/msgpackr-extract-darwin-arm64": "3.0.3", "@msgpackr-extract/msgpackr-extract-darwin-x64": "3.0.3", "@msgpackr-extract/msgpackr-extract-linux-arm": "3.0.3", "@msgpackr-extract/msgpackr-extract-linux-arm64": "3.0.3", "@msgpackr-extract/msgpackr-extract-linux-x64": "3.0.3", "@msgpackr-extract/msgpackr-extract-win32-x64": "3.0.3" }, "bin": { "download-msgpackr-prebuilds": "bin/download-prebuilds.js" } }, "sha512-P0efT1C9jIdVRefqjzOQ9Xml57zpOXnIuS+csaB4MdZbTdmGDLo8XhzBG1N7aO11gKDDkJvBLULeFTo46wwreA=="], - "mssql": ["mssql@11.0.1", "", { "dependencies": { "@tediousjs/connection-string": "^0.5.0", 
"commander": "^11.0.0", "debug": "^4.3.3", "rfdc": "^1.3.0", "tarn": "^3.0.2", "tedious": "^18.2.1" }, "bin": { "mssql": "bin/mssql" } }, "sha512-KlGNsugoT90enKlR8/G36H0kTxPthDhmtNUCwEHvgRza5Cjpjoj+P2X6eMpFUDN7pFrJZsKadL4x990G8RBE1w=="], + "mssql": ["mssql@12.2.1", "", { "dependencies": { "@tediousjs/connection-string": "^0.6.0", "commander": "^11.0.0", "debug": "^4.3.3", "tarn": "^3.0.2", "tedious": "^19.0.0" }, "bin": { "mssql": "bin/mssql" } }, "sha512-TU89g82WatOVcinw3etO/crKbd67ugC3Wm6TJDklHjp7211brVENWIs++UoPC2H+TWvyi0OSlzMou8GY15onOA=="], "multicast-dns": ["multicast-dns@7.2.5", "", { "dependencies": { "dns-packet": "^5.2.2", "thunky": "^1.0.2" }, "bin": { "multicast-dns": "cli.js" } }, "sha512-2eznPJP8z2BFLX50tf0LuODrpINqP1RVIm/CObbTcBRITQgmC/TjcREF1NeTBzIcR5XO/ukWo+YHOjBbFwIupg=="], @@ -2336,7 +2336,7 @@ "tarn": ["tarn@3.0.2", "", {}, "sha512-51LAVKUSZSVfI05vjPESNc5vwqqZpbXCsU+/+wxlOrUjk2SnFTt97v9ZgQrD4YmxYW1Px6w2KjaDitCfkvgxMQ=="], - "tedious": ["tedious@18.6.2", "", { "dependencies": { "@azure/core-auth": "^1.7.2", "@azure/identity": "^4.2.1", "@azure/keyvault-keys": "^4.4.0", "@js-joda/core": "^5.6.1", "@types/node": ">=18", "bl": "^6.0.11", "iconv-lite": "^0.6.3", "js-md4": "^0.3.2", "native-duplexpair": "^1.0.0", "sprintf-js": "^1.1.3" } }, "sha512-g7jC56o3MzLkE3lHkaFe2ZdOVFBahq5bsB60/M4NYUbocw/MCrS89IOEQUFr+ba6pb8ZHczZ/VqCyYeYq0xBAg=="], + "tedious": ["tedious@19.2.1", "", { "dependencies": { "@azure/core-auth": "^1.7.2", "@azure/identity": "^4.2.1", "@azure/keyvault-keys": "^4.4.0", "@js-joda/core": "^5.6.5", "@types/node": ">=18", "bl": "^6.1.4", "iconv-lite": "^0.7.0", "js-md4": "^0.3.2", "native-duplexpair": "^1.0.0", "sprintf-js": "^1.1.3" } }, "sha512-pk1Q16Yl62iocuQB+RWbg6rFUFkIyzqOFQ6NfysCltRvQqKwfurgj8v/f2X+CKvDhSL4IJ0cCOfCHDg9PWEEYA=="], "teeny-request": ["teeny-request@10.1.0", "", { "dependencies": { "http-proxy-agent": "^5.0.0", "https-proxy-agent": "^5.0.0", "node-fetch": "^3.3.2", "stream-events": "^1.0.5" } }, 
"sha512-3ZnLvgWF29jikg1sAQ1g0o+lr5JX6sVgYvfUJazn7ZjJroDBUTWp44/+cFVX0bULjv4vci+rBD+oGVAkWqhUbw=="], @@ -2988,6 +2988,8 @@ "@smithy/util-waiter/@smithy/types": ["@smithy/types@4.13.1", "", { "dependencies": { "tslib": "^2.6.2" } }, "sha512-787F3yzE2UiJIQ+wYW1CVg2odHjmaWLGksnKQHUrK/lYZSEcy1msuLVvxaR/sI2/aDe9U+TBuLsXnr3vod1g0g=="], + "@types/mssql/tedious": ["tedious@18.6.2", "", { "dependencies": { "@azure/core-auth": "^1.7.2", "@azure/identity": "^4.2.1", "@azure/keyvault-keys": "^4.4.0", "@js-joda/core": "^5.6.1", "@types/node": ">=18", "bl": "^6.0.11", "iconv-lite": "^0.6.3", "js-md4": "^0.3.2", "native-duplexpair": "^1.0.0", "sprintf-js": "^1.1.3" } }, "sha512-g7jC56o3MzLkE3lHkaFe2ZdOVFBahq5bsB60/M4NYUbocw/MCrS89IOEQUFr+ba6pb8ZHczZ/VqCyYeYq0xBAg=="], + "@types/request/form-data": ["form-data@2.5.5", "", { "dependencies": { "asynckit": "^0.4.0", "combined-stream": "^1.0.8", "es-set-tostringtag": "^2.1.0", "hasown": "^2.0.2", "mime-types": "^2.1.35", "safe-buffer": "^5.2.1" } }, "sha512-jqdObeR2rxZZbPSGL+3VckHMYtu+f9//KXBsVny6JSX/pa38Fy+bGjuG8eW/H6USNQWhLi8Num++cU2yOCNz4A=="], "accepts/negotiator": ["negotiator@1.0.0", "", {}, "sha512-8Ofs/AUQh8MaEcrlq5xOX0CQ9ypTF5dl78mjlMNfOK08fzpgTHQRQPBxcPlEtIw0yRpws+Zo/3r+5WRby7u3Gg=="], @@ -3040,6 +3042,8 @@ "cross-spawn/which": ["which@2.0.2", "", { "dependencies": { "isexe": "^2.0.0" }, "bin": { "node-which": "./bin/node-which" } }, "sha512-BLI3Tl1TW3Pvl70l3yq3Y64i+awpwXqsGBYWkkqMtnbXgrMD+yj7rhW0kuEDxzJaYXGjEW5ogapKNMEKNMjibA=="], + "drizzle-orm/mssql": ["mssql@11.0.1", "", { "dependencies": { "@tediousjs/connection-string": "^0.5.0", "commander": "^11.0.0", "debug": "^4.3.3", "rfdc": "^1.3.0", "tarn": "^3.0.2", "tedious": "^18.2.1" }, "bin": { "mssql": "bin/mssql" } }, "sha512-KlGNsugoT90enKlR8/G36H0kTxPthDhmtNUCwEHvgRza5Cjpjoj+P2X6eMpFUDN7pFrJZsKadL4x990G8RBE1w=="], + "effect/@standard-schema/spec": ["@standard-schema/spec@1.1.0", "", {}, 
"sha512-l2aFy5jALhniG5HgqrD6jXLi/rUWrKvqN/qJx6yoJsgKhblVd+iqqU4RCXavm/jPityDo5TCvKMnpjKnOriy0w=="], "effect/yaml": ["yaml@2.8.2", "", { "bin": { "yaml": "bin.mjs" } }, "sha512-mplynKqc1C2hTVYxd0PU2xQAc22TI1vShAYGksCCfxbn/dFwnHTNi1bvYsBTkhdUNtGIf5xNOg938rrSSYvS9A=="], @@ -3164,8 +3168,6 @@ "tar-stream/bl": ["bl@4.1.0", "", { "dependencies": { "buffer": "^5.5.0", "inherits": "^2.0.4", "readable-stream": "^3.4.0" } }, "sha512-1W07cM9gS6DcLperZfFSj+bWLtaPGSOHWhPiGzXmvVJbRLdG82sH/Kn8EtW1VqWVA54AKf2h5k5BbnIbwF3h6w=="], - "tedious/iconv-lite": ["iconv-lite@0.6.3", "", { "dependencies": { "safer-buffer": ">= 2.1.2 < 3.0.0" } }, "sha512-4fCk79wshMdzMp2rH06qWrJE4iolqLhCUH+OiuIgU++RB0+94NlDL81atO7GX55uUKueo0txHNtvEyI6D7WdMw=="], - "teeny-request/http-proxy-agent": ["http-proxy-agent@5.0.0", "", { "dependencies": { "@tootallnate/once": "2", "agent-base": "6", "debug": "4" } }, "sha512-n2hY8YdoRE1i7r6M0w9DIw5GgZN0G25P8zLCRQ8rjXtTU3vsNFBI/vWK/UIeE6g5MUUz6avwAPXmL6Fy9D/90w=="], "teeny-request/https-proxy-agent": ["https-proxy-agent@5.0.1", "", { "dependencies": { "agent-base": "6", "debug": "4" } }, "sha512-dFcAjpTQFgoLMzC2VwU+C/CbS7uRL0lWmxDITmqm7C+7F0Odmj6s9l6alZc6AELXhrnggM2CeWSXHGOdX2YtwA=="], @@ -3518,6 +3520,8 @@ "@smithy/util-stream/@smithy/node-http-handler/@smithy/querystring-builder": ["@smithy/querystring-builder@4.2.8", "", { "dependencies": { "@smithy/types": "^4.12.0", "@smithy/util-uri-escape": "^4.2.0", "tslib": "^2.6.2" } }, "sha512-Xr83r31+DrE8CP3MqPgMJl+pQlLLmOfiEUnoyAlGzzJIrEsbKsPy1hqH0qySaQm4oWrCBlUqRt+idEgunKB+iw=="], + "@types/mssql/tedious/iconv-lite": ["iconv-lite@0.6.3", "", { "dependencies": { "safer-buffer": ">= 2.1.2 < 3.0.0" } }, "sha512-4fCk79wshMdzMp2rH06qWrJE4iolqLhCUH+OiuIgU++RB0+94NlDL81atO7GX55uUKueo0txHNtvEyI6D7WdMw=="], + "@types/request/form-data/mime-types": ["mime-types@2.1.35", "", { "dependencies": { "mime-db": "1.52.0" } }, "sha512-ZDY+bPm5zTTF+YpCrAU9nK0UgICYPT0QtT1NZWFv4s++TNkcgVaT0g6+4R2uI4MjQjzysHB1zxuWL50hzaeXiw=="], 
"ai-gateway-provider/@ai-sdk/amazon-bedrock/@ai-sdk/anthropic": ["@ai-sdk/anthropic@2.0.62", "", { "dependencies": { "@ai-sdk/provider": "2.0.1", "@ai-sdk/provider-utils": "3.0.21" }, "peerDependencies": { "zod": "^3.25.76 || ^4.1.8" } }, "sha512-I3RhaOEMnWlWnrvjNBOYvUb19Dwf2nw01IruZrVJRDi688886e11wnd5DxrBZLd2V29Gizo3vpOPnnExsA+wTA=="], @@ -3556,6 +3560,12 @@ "cross-spawn/which/isexe": ["isexe@2.0.0", "", {}, "sha512-RHxMLp9lnKHGHRng9QFhRCMbYAcVpn69smSGcq3f36xjgVVWThj4qqLbTLlq7Ssj8B+fIQ1EuCEGI2lKsyQeIw=="], + "drizzle-orm/mssql/@tediousjs/connection-string": ["@tediousjs/connection-string@0.5.0", "", {}, "sha512-7qSgZbincDDDFyRweCIEvZULFAw5iz/DeunhvuxpL31nfntX3P4Yd4HkHBRg9H8CdqY1e5WFN1PZIz/REL9MVQ=="], + + "drizzle-orm/mssql/commander": ["commander@11.1.0", "", {}, "sha512-yPVavfyCcRhmorC7rWlkHn15b4wDVgVmBA7kV4QVBsF7kv/9TKJAbAXVTxvTnwP8HHKjRCJDClKbciiYS7p0DQ=="], + + "drizzle-orm/mssql/tedious": ["tedious@18.6.2", "", { "dependencies": { "@azure/core-auth": "^1.7.2", "@azure/identity": "^4.2.1", "@azure/keyvault-keys": "^4.4.0", "@js-joda/core": "^5.6.1", "@types/node": ">=18", "bl": "^6.0.11", "iconv-lite": "^0.6.3", "js-md4": "^0.3.2", "native-duplexpair": "^1.0.0", "sprintf-js": "^1.1.3" } }, "sha512-g7jC56o3MzLkE3lHkaFe2ZdOVFBahq5bsB60/M4NYUbocw/MCrS89IOEQUFr+ba6pb8ZHczZ/VqCyYeYq0xBAg=="], + "form-data/mime-types/mime-db": ["mime-db@1.52.0", "", {}, "sha512-sPU4uV7dYlvtWJxwwxHD0PuihVNiE7TyAbQ5SWxDCB9mUYvOgroQOwYQQOKPJ8CIbE+1ETVlOoK1UC2nU3gYvg=="], "fs-minipass/minipass/yallist": ["yallist@4.0.0", "", {}, "sha512-3wdGidZyq5PB084XLES5TpOSRA3wjXAlIWMhum2kRcv/41Sn2emQ0dycQW4uZXLejwKvg6EsvbdlVL+FYEct7A=="], @@ -3778,6 +3788,8 @@ "cross-fetch/node-fetch/whatwg-url/webidl-conversions": ["webidl-conversions@3.0.1", "", {}, "sha512-2JAn3z8AR6rjK8Sm8orRC0h/bcl/DqL7tRPdGZ4I1CjdF+EaMLmYxBHyXuKL849eucPFhvBoxMsflfOb8kxaeQ=="], + "drizzle-orm/mssql/tedious/iconv-lite": ["iconv-lite@0.6.3", "", { "dependencies": { "safer-buffer": ">= 2.1.2 < 3.0.0" } }, 
"sha512-4fCk79wshMdzMp2rH06qWrJE4iolqLhCUH+OiuIgU++RB0+94NlDL81atO7GX55uUKueo0txHNtvEyI6D7WdMw=="], + "gaxios/rimraf/glob/jackspeak": ["jackspeak@3.4.3", "", { "dependencies": { "@isaacs/cliui": "^8.0.2" }, "optionalDependencies": { "@pkgjs/parseargs": "^0.11.0" } }, "sha512-OGlZQpz2yfahA/Rd1Y8Cd9SIEsqvXkLVoSw/cgwhnhFMDbsQFeZYoJJ7bIZBS9BcamUW96asq/npPWugM+RQBw=="], "gaxios/rimraf/glob/minimatch": ["minimatch@9.0.5", "", { "dependencies": { "brace-expansion": "^2.0.1" } }, "sha512-G6T0ZX48xgozx7587koeX9Ys2NYy6Gmv//P89sEte9V9whIapMNF4idKxnW2QtCcLiTWlb/wfCabAtAFWhhBow=="], diff --git a/packages/drivers/package.json b/packages/drivers/package.json index 98a0112cf..361c1dd96 100644 --- a/packages/drivers/package.json +++ b/packages/drivers/package.json @@ -17,7 +17,7 @@ "@google-cloud/bigquery": "^8.0.0", "@databricks/sql": "^1.0.0", "mysql2": "^3.0.0", - "mssql": "^11.0.0", + "mssql": "^12.0.0", "oracledb": "^6.0.0", "duckdb": "^1.0.0", "mongodb": "^6.0.0", diff --git a/packages/drivers/src/normalize.ts b/packages/drivers/src/normalize.ts index 5afc20cee..2d3c36127 100644 --- a/packages/drivers/src/normalize.ts +++ b/packages/drivers/src/normalize.ts @@ -65,6 +65,12 @@ const SQLSERVER_ALIASES: AliasMap = { ...COMMON_ALIASES, host: ["server", "serverName", "server_name"], trust_server_certificate: ["trustServerCertificate"], + authentication: ["authenticationType", "auth_type", "authentication_type"], + azure_tenant_id: ["tenantId", "tenant_id", "azureTenantId"], + azure_client_id: ["clientId", "client_id", "azureClientId"], + azure_client_secret: ["clientSecret", "client_secret", "azureClientSecret"], + access_token: ["token", "accessToken"], + azure_resource_url: ["azureResourceUrl", "resourceUrl", "resource_url"], } const ORACLE_ALIASES: AliasMap = { @@ -104,6 +110,7 @@ const DRIVER_ALIASES: Record = { mariadb: MYSQL_ALIASES, sqlserver: SQLSERVER_ALIASES, mssql: SQLSERVER_ALIASES, + fabric: SQLSERVER_ALIASES, oracle: ORACLE_ALIASES, mongodb: MONGODB_ALIASES, mongo: 
MONGODB_ALIASES, diff --git a/packages/drivers/src/sqlserver.ts b/packages/drivers/src/sqlserver.ts index 3ea1e390f..2a38cfe82 100644 --- a/packages/drivers/src/sqlserver.ts +++ b/packages/drivers/src/sqlserver.ts @@ -4,12 +4,82 @@ import type { ConnectionConfig, Connector, ConnectorResult, ExecuteOptions, SchemaColumn } from "./types" +// --------------------------------------------------------------------------- +// Azure AD helpers — cache + resource URL resolution +// --------------------------------------------------------------------------- + +// Module-scoped token cache, keyed by `${resource}|${clientId ?? ""}`. +// Tokens are reused across `connect()` calls in the same process and refreshed +// a few minutes before expiry. Fixes the issue where every new connection +// fetched a fresh token (wasteful, risks throttling) and long-lived diffs +// failed silently when the embedded token hit its ~1h TTL. +const tokenCache = new Map() +const TOKEN_REFRESH_MARGIN_MS = 5 * 60 * 1000 // refresh 5 minutes before expiry +const TOKEN_FALLBACK_TTL_MS = 50 * 60 * 1000 // used when JWT has no exp claim + +/** + * Parse the `exp` claim from a JWT access token (milliseconds since epoch). + * Returns undefined if the token isn't a JWT or has no exp claim. + */ +function parseTokenExpiry(token: string): number | undefined { + try { + const parts = token.split(".") + if (parts.length !== 3) return undefined + const payload = parts[1] + // base64url → base64 + padding + const padded = payload.replace(/-/g, "+").replace(/_/g, "/") + + "=".repeat((4 - (payload.length % 4)) % 4) + const decoded = Buffer.from(padded, "base64").toString("utf-8") + const claims = JSON.parse(decoded) + return typeof claims.exp === "number" ? claims.exp * 1000 : undefined + } catch { + return undefined + } +} + +/** + * Resolve the Azure resource URL for token acquisition. + * + * Preference order: + * 1. Explicit `config.azure_resource_url`. + * 2. Inferred from host suffix (Azure Gov / China). 
+ * 3. Default Azure commercial cloud. + * + * The returned URL is guaranteed to end with `/` — callers append `.default` + * to build the OAuth scope (e.g. `https://database.windows.net/.default`). + * Without the trailing slash, an explicit user value like + * `"https://custom-host"` would produce an invalid scope + * `"https://custom-host.default"` and `credential.getToken` would fail. + */ +function resolveAzureResourceUrl(config: ConnectionConfig): string { + const explicit = config.azure_resource_url as string | undefined + const raw = explicit ?? (() => { + const host = (config.host as string | undefined) ?? "" + if (host.includes(".usgovcloudapi.net") || host.includes(".datawarehouse.fabric.microsoft.us")) { + return "https://database.usgovcloudapi.net/" + } + if (host.includes(".chinacloudapi.cn")) { + return "https://database.chinacloudapi.cn/" + } + return "https://database.windows.net/" + })() + return raw.endsWith("/") ? raw : `${raw}/` +} + +/** Visible for testing: reset the module-scoped token cache. */ +export function _resetTokenCacheForTests(): void { + tokenCache.clear() +} + export async function connect(config: ConnectionConfig): Promise { let mssql: any + let MssqlConnectionPool: any try { // @ts-expect-error — mssql has no type declarations; installed as optional peerDependency - mssql = await import("mssql") - mssql = mssql.default || mssql + const mod = await import("mssql") + mssql = mod.default || mod + // ConnectionPool is a named export, not on .default + MssqlConnectionPool = mod.ConnectionPool ?? mssql.ConnectionPool } catch { throw new Error( "SQL Server driver not installed. Run: npm install mssql", @@ -24,8 +94,6 @@ export async function connect(config: ConnectionConfig): Promise { server: config.host ?? "127.0.0.1", port: config.port ?? 1433, database: config.database, - user: config.user, - password: config.password, options: { encrypt: config.encrypt ?? false, trustServerCertificate: config.trust_server_certificate ?? 
true, @@ -39,7 +107,206 @@ export async function connect(config: ConnectionConfig): Promise { }, } - pool = await mssql.connect(mssqlConfig) + // Normalize shorthand auth values to tedious-compatible types + const AUTH_SHORTHANDS: Record = { + cli: "azure-active-directory-default", + default: "azure-active-directory-default", + password: "azure-active-directory-password", + "service-principal": "azure-active-directory-service-principal-secret", + serviceprincipal: "azure-active-directory-service-principal-secret", + "managed-identity": "azure-active-directory-msi-vm", + msi: "azure-active-directory-msi-vm", + } + // `config.authentication` is typed as unknown upstream — accept only + // strings here. A caller passing a non-string (object, null, pre-built + // auth block) shouldn't crash with "toLowerCase is not a function"; + // treat as "no shorthand requested" and leave authType undefined. + const rawAuth = config.authentication + const authType = + typeof rawAuth === "string" + ? (AUTH_SHORTHANDS[rawAuth.toLowerCase()] ?? rawAuth) + : undefined + + if (authType?.startsWith("azure-active-directory")) { + ;(mssqlConfig.options as any).encrypt = true + + // Resolve a raw Azure AD access token. + // Used by both `azure-active-directory-default` and by + // `azure-active-directory-access-token` when no token was provided. + // + // We acquire the token ourselves rather than letting tedious do it because: + // 1. Bun can resolve @azure/identity to the browser bundle (inside + // tedious or even our own import), where DefaultAzureCredential + // is a non-functional stub that throws. + // 2. Passing a credential object via type:"token-credential" hits a + // CJS/ESM isTokenCredential boundary mismatch in Bun. + // + // Strategy: try @azure/identity first (works when module resolution + // is correct), fall back to shelling out to `az account get-access-token` + // (works everywhere Azure CLI is installed). 
+ // + // Tokens are cached module-scope keyed by (resource, client_id) and + // refreshed 5 minutes before expiry — reuses tokens across connections + // and prevents silent failures when embedded tokens hit their TTL. + const resourceUrl = resolveAzureResourceUrl(config) + const clientId = (config.azure_client_id as string | undefined) ?? "" + const cacheKey = `${resourceUrl}|${clientId}` + + const acquireAzureToken = async (): Promise => { + const cached = tokenCache.get(cacheKey) + if (cached && cached.expiresAt - Date.now() > TOKEN_REFRESH_MARGIN_MS) { + return cached.token + } + + let token: string | undefined + let expiresAt: number | undefined + let azureIdentityError: unknown = null + let azCliStderr = "" + + try { + const azureIdentity = await import("@azure/identity") + const credential = new azureIdentity.DefaultAzureCredential( + config.azure_client_id + ? { managedIdentityClientId: config.azure_client_id as string } + : undefined, + ) + const tokenResponse = await credential.getToken(`${resourceUrl}.default`) + if (tokenResponse?.token) { + token = tokenResponse.token + // @azure/identity provides expiresOnTimestamp (ms). Prefer it; fall + // back to parsing the JWT exp claim so both paths share the cache. + expiresAt = tokenResponse.expiresOnTimestamp ?? parseTokenExpiry(token) + } + } catch (err) { + azureIdentityError = err + // @azure/identity unavailable or browser bundle — fall through to CLI + } + + if (!token) { + try { + // Use async `execFile` (NOT `exec`/`execSync`): + // - `execFile` skips the shell entirely and passes args as an + // array, preventing shell interpolation when `resourceUrl` + // comes from user config (e.g. `"https://x/;other-cmd"`). + // - async keeps the connect path non-blocking during the + // multi-second `az` round trip. 
+ const childProcess = await import("node:child_process") + const { promisify } = await import("node:util") + const execFileAsync = promisify(childProcess.execFile) + const { stdout } = await execFileAsync( + "az", + [ + "account", "get-access-token", + "--resource", resourceUrl, + "--query", "accessToken", + "-o", "tsv", + ], + { encoding: "utf-8", timeout: 15000 }, + ) + const out = String(stdout).trim() + if (out) { + token = out + expiresAt = parseTokenExpiry(out) + } + } catch (err: any) { + // Capture stderr so the final error message can hint at the root cause + // (e.g. "Please run 'az login'", "subscription not found"). + azCliStderr = String(err?.stderr ?? err?.message ?? "").slice(0, 200).trim() + } + } + + if (!token) { + const hints: string[] = [] + if (azureIdentityError) hints.push(`@azure/identity: ${String(azureIdentityError).slice(0, 120)}`) + if (azCliStderr) hints.push(`az CLI: ${azCliStderr}`) + const detail = hints.length > 0 ? ` (${hints.join("; ")})` : "" + throw new Error( + `Azure AD token acquisition failed${detail}. Either install @azure/identity (npm install @azure/identity) ` + + "or log in with Azure CLI (az login).", + ) + } + + tokenCache.set(cacheKey, { + token, + expiresAt: expiresAt ?? Date.now() + TOKEN_FALLBACK_TTL_MS, + }) + return token + } + + if (authType === "azure-active-directory-default") { + mssqlConfig.authentication = { + type: "azure-active-directory-access-token", + options: { token: await acquireAzureToken() }, + } + } else if (authType === "azure-active-directory-password") { + mssqlConfig.authentication = { + type: "azure-active-directory-password", + options: { + userName: config.user, + password: config.password, + clientId: config.azure_client_id, + tenantId: config.azure_tenant_id, + }, + } + } else if (authType === "azure-active-directory-access-token") { + // If the caller supplied a token, use it; otherwise acquire one + // automatically (DefaultAzureCredential → az CLI). 
+ const suppliedToken = (config.token ?? config.access_token) as string | undefined + mssqlConfig.authentication = { + type: "azure-active-directory-access-token", + options: { token: suppliedToken ?? (await acquireAzureToken()) }, + } + } else if ( + authType === "azure-active-directory-msi-vm" || + authType === "azure-active-directory-msi-app-service" + ) { + mssqlConfig.authentication = { + type: authType, + options: { + ...(config.azure_client_id ? { clientId: config.azure_client_id } : {}), + }, + } + } else if (authType === "azure-active-directory-service-principal-secret") { + mssqlConfig.authentication = { + type: "azure-active-directory-service-principal-secret", + options: { + clientId: config.azure_client_id, + clientSecret: config.azure_client_secret, + tenantId: config.azure_tenant_id, + }, + } + } else { + // Any other `azure-active-directory-*` subtype (typo or future + // tedious addition). Fail fast — otherwise we'd silently connect + // with no `authentication` block and tedious would surface an + // opaque error far from the root cause. + throw new Error( + `Unsupported Azure AD authentication subtype: "${authType}". ` + + "Supported subtypes: azure-active-directory-default, " + + "azure-active-directory-password, azure-active-directory-access-token, " + + "azure-active-directory-msi-vm, azure-active-directory-msi-app-service, " + + "azure-active-directory-service-principal-secret.", + ) + } + } else { + // Standard SQL Server user/password + mssqlConfig.user = config.user + mssqlConfig.password = config.password + } + + // Use an explicit ConnectionPool (not the global mssql.connect()) so + // multiple simultaneous connections to different servers are isolated. + // `mssql@^12` guarantees ConnectionPool as a named export — if it's + // missing, the installed driver version is too old. Fail fast rather + // than silently use the global shared pool (which reintroduces the + // cross-database interference bug this branch was added to fix). 
+ if (!MssqlConnectionPool) { + throw new Error( + "mssql.ConnectionPool is not available — the installed `mssql` package is too old. Upgrade to mssql@^12.", + ) + } + pool = new MssqlConnectionPool(mssqlConfig) + await pool.connect() }, async execute(sql: string, limit?: number, _binds?: any[], options?: ExecuteOptions): Promise { @@ -62,22 +329,56 @@ export async function connect(config: ConnectionConfig): Promise { } const result = await pool.request().query(query) - const rows = result.recordset ?? [] - const columns = - rows.length > 0 - ? Object.keys(rows[0]).filter((k) => !k.startsWith("_")) - : (result.recordset?.columns - ? Object.keys(result.recordset.columns) - : []) - const truncated = effectiveLimit > 0 && rows.length > effectiveLimit - const limitedRows = truncated ? rows.slice(0, effectiveLimit) : rows + const recordset = result.recordset ?? [] + const truncated = effectiveLimit > 0 && recordset.length > effectiveLimit + const limitedRecordset = truncated ? recordset.slice(0, effectiveLimit) : recordset + + // mssql merges unnamed columns (e.g. SELECT COUNT(*), SUM(...)) into a + // single array under the empty-string key: row[""] = [val1, val2, ...]. + // When a query mixes named and unnamed columns (e.g. + // SELECT name, COUNT(*), SUM(x) → { name: "alice", "": [42, 100] }), + // we must preserve the known header for `name` and synthesize col_N only + // for the unnamed positions. Build columns and rows in a single pass so + // they stay aligned regardless of how many unnamed values the row + // contains. 
+ let columns: string[] = [] + let columnsBuilt = false + const flatten = (row: any): any[] => { + const vals: any[] = [] + let unnamedCounter = 0 + const entries = Object.entries(row) + for (const [k, v] of entries) { + if (k === "" && Array.isArray(v)) { + for (const inner of v) { + if (!columnsBuilt) columns.push(`col_${unnamedCounter}`) + unnamedCounter++ + vals.push(inner) + } + } else if (k === "") { + // Empty-string key with non-array value — rare edge case, give it + // a synthetic name rather than producing a column named "". + if (!columnsBuilt) columns.push(`col_${unnamedCounter}`) + unnamedCounter++ + vals.push(v) + } else { + if (!columnsBuilt) columns.push(k) + vals.push(v) + } + } + columnsBuilt = true + return vals + } + + const rows = limitedRecordset.map(flatten) + if (!columnsBuilt) { + // No rows — fall back to driver-reported column metadata. + columns = result.recordset?.columns ? Object.keys(result.recordset.columns) : [] + } return { columns, - rows: limitedRows.map((row: any) => - columns.map((col) => row[col]), - ), - row_count: limitedRows.length, + rows, + row_count: rows.length, truncated, } }, diff --git a/packages/drivers/test/sqlserver-unit.test.ts b/packages/drivers/test/sqlserver-unit.test.ts new file mode 100644 index 000000000..776f46a48 --- /dev/null +++ b/packages/drivers/test/sqlserver-unit.test.ts @@ -0,0 +1,800 @@ +/** + * Unit tests for SQL Server driver logic: + * - TOP injection (vs LIMIT) + * - Truncation detection + * - Azure AD authentication (7 flows) + * - Schema introspection queries + * - Connection lifecycle + * - Result format mapping + */ +import { describe, test, expect, mock, beforeEach } from "bun:test" + +// --- Mock mssql --- + +let mockQueryCalls: string[] = [] +let mockQueryResult: any = { recordset: [] } +let mockConnectCalls: any[] = [] +let mockCloseCalls = 0 +let mockInputs: Array<{ name: string; value: any }> = [] + +function resetMocks() { + mockQueryCalls = [] + mockQueryResult = { recordset: [] } 
+ mockConnectCalls = [] + mockCloseCalls = 0 + mockInputs = [] +} + +function createMockRequest() { + const req: any = { + input(name: string, value: any) { + mockInputs.push({ name, value }) + return req + }, + async query(sql: string) { + mockQueryCalls.push(sql) + return mockQueryResult + }, + } + return req +} + +function createMockPool(config: any) { + mockConnectCalls.push(config) + return { + connect: async () => {}, + request: () => createMockRequest(), + close: async () => { + mockCloseCalls++ + }, + } +} + +mock.module("mssql", () => ({ + default: { + connect: async (config: any) => createMockPool(config), + }, + ConnectionPool: class { + _pool: any + constructor(config: any) { + this._pool = createMockPool(config) + } + async connect() { return this._pool.connect() } + request() { return this._pool.request() } + async close() { return this._pool.close() } + }, +})) + +// Exposed to individual tests so they can assert scope / force failures. +const azureIdentityState = { + lastScope: "" as string, + tokenOverride: null as null | { token: string; expiresOnTimestamp?: number }, + throwOnGetToken: false as boolean, +} +mock.module("@azure/identity", () => ({ + DefaultAzureCredential: class { + _opts: any + constructor(opts?: any) { this._opts = opts } + async getToken(scope: string) { + azureIdentityState.lastScope = scope + if (azureIdentityState.throwOnGetToken) throw new Error("mock identity failure") + if (azureIdentityState.tokenOverride) return azureIdentityState.tokenOverride + return { token: "mock-azure-token-12345", expiresOnTimestamp: Date.now() + 3600000 } + } + }, +})) + +// Exposed to tests to stub the `az` CLI fallback. 
+const cliState = { + lastCmd: "" as string, + output: "mock-cli-token-fallback\n" as string, + throwError: null as null | { stderr?: string; message?: string }, +} +const realChildProcess = await import("node:child_process") +const realUtil = await import("node:util") + +// Helper: build a mock with callback + util.promisify.custom support so +// `promisify(child_process.exec)` or `promisify(child_process.execFile)` +// yields { stdout, stderr } exactly like the real implementation. +function makeChildProcessMock(captureCmd: (args: string) => void) { + const stub: any = (arg0: any, arg1: any, arg2: any, arg3: any) => { + // Accept both exec(cmd, opts?, cb?) and execFile(file, args?, opts?, cb?) + const cb = [arg0, arg1, arg2, arg3].find((x) => typeof x === "function") + // Pick the best "command" representation for test assertions: + // - exec: first arg is the full command string + // - execFile: first arg is the program, second arg is the args array + if (Array.isArray(arg1)) { + captureCmd(`${arg0} ${arg1.join(" ")}`) + } else { + captureCmd(String(arg0)) + } + if (cliState.throwError) { + const e: any = new Error(cliState.throwError.message ?? "az failed") + e.stderr = cliState.throwError.stderr + if (cb) cb(e, "", cliState.throwError.stderr ?? "") + return { on() {}, stdout: null, stderr: null } + } + if (cb) cb(null, cliState.output, "") + return { on() {}, stdout: null, stderr: null } + } + stub[realUtil.promisify.custom] = (arg0: any, arg1: any) => { + if (Array.isArray(arg1)) { + captureCmd(`${arg0} ${arg1.join(" ")}`) + } else { + captureCmd(String(arg0)) + } + if (cliState.throwError) { + const e: any = new Error(cliState.throwError.message ?? 
"az failed") + e.stderr = cliState.throwError.stderr + return Promise.reject(e) + } + return Promise.resolve({ stdout: cliState.output, stderr: "" }) + } + return stub +} + +const execStub = makeChildProcessMock((c) => { cliState.lastCmd = c }) +const execFileStub = makeChildProcessMock((c) => { cliState.lastCmd = c }) + +mock.module("node:child_process", () => ({ + ...realChildProcess, + execSync: (cmd: string) => { + cliState.lastCmd = cmd + if (cliState.throwError) { + const e: any = new Error(cliState.throwError.message ?? "az failed") + e.stderr = cliState.throwError.stderr + throw e + } + return cliState.output + }, + exec: execStub, + execFile: execFileStub, +})) + +// Import after mocking +const { connect, _resetTokenCacheForTests } = await import("../src/sqlserver") + +describe("SQL Server driver unit tests", () => { + let connector: Awaited> + + beforeEach(async () => { + resetMocks() + connector = await connect({ host: "localhost", port: 1433, database: "testdb", user: "sa", password: "pass" }) + await connector.connect() + }) + + // --- TOP injection --- + + describe("TOP injection", () => { + test("injects TOP for SELECT without one", async () => { + mockQueryResult = { recordset: [{ id: 1, name: "a" }] } + await connector.execute("SELECT * FROM t") + expect(mockQueryCalls[0]).toContain("TOP 1001") + }) + + test("does NOT double-TOP when TOP already present", async () => { + mockQueryResult = { recordset: [{ id: 1 }] } + await connector.execute("SELECT TOP 5 * FROM t") + expect(mockQueryCalls[0]).toBe("SELECT TOP 5 * FROM t") + }) + + test("does NOT inject TOP when LIMIT present", async () => { + mockQueryResult = { recordset: [] } + await connector.execute("SELECT * FROM t LIMIT 10") + expect(mockQueryCalls[0]).toBe("SELECT * FROM t LIMIT 10") + }) + + test("noLimit bypasses TOP injection", async () => { + mockQueryResult = { recordset: [] } + await connector.execute("SELECT * FROM t", undefined, undefined, { noLimit: true }) + 
expect(mockQueryCalls[0]).toBe("SELECT * FROM t") + }) + + test("uses custom limit value", async () => { + mockQueryResult = { recordset: [] } + await connector.execute("SELECT * FROM t", 50) + expect(mockQueryCalls[0]).toContain("TOP 51") + }) + + test("default limit is 1000", async () => { + mockQueryResult = { recordset: [] } + await connector.execute("SELECT * FROM t") + expect(mockQueryCalls[0]).toContain("TOP 1001") + }) + }) + + // --- Truncation --- + + describe("truncation detection", () => { + test("detects truncation when rows exceed limit", async () => { + const rows = Array.from({ length: 11 }, (_, i) => ({ id: i })) + mockQueryResult = { recordset: rows } + const result = await connector.execute("SELECT * FROM t", 10) + expect(result.truncated).toBe(true) + expect(result.rows.length).toBe(10) + }) + + test("no truncation when rows at or below limit", async () => { + mockQueryResult = { recordset: [{ id: 1 }, { id: 2 }] } + const result = await connector.execute("SELECT * FROM t", 10) + expect(result.truncated).toBe(false) + }) + + test("empty result returns correctly", async () => { + // mssql exposes column metadata as `recordset.columns` (a property ON + // the recordset array), not as a sibling key — mirror the real shape. 
+ const recordset: any[] = [] + ;(recordset as any).columns = {} + mockQueryResult = { recordset } + const result = await connector.execute("SELECT * FROM t") + expect(result.rows).toEqual([]) + expect(result.truncated).toBe(false) + }) + }) + + // --- Azure AD authentication --- + + describe("Azure AD authentication", () => { + test("standard auth uses user/password directly", async () => { + resetMocks() + const c = await connect({ host: "localhost", database: "db", user: "sa", password: "pass" }) + await c.connect() + const cfg = mockConnectCalls[0] + expect(cfg.user).toBe("sa") + expect(cfg.password).toBe("pass") + expect(cfg.authentication).toBeUndefined() + }) + + test("azure-active-directory-password builds correct auth object", async () => { + resetMocks() + const c = await connect({ + host: "myserver.database.windows.net", + database: "db", + user: "user@domain.com", + password: "secret", + authentication: "azure-active-directory-password", + azure_client_id: "client-123", + azure_tenant_id: "tenant-456", + }) + await c.connect() + const cfg = mockConnectCalls[0] + expect(cfg.authentication).toEqual({ + type: "azure-active-directory-password", + options: { + userName: "user@domain.com", + password: "secret", + clientId: "client-123", + tenantId: "tenant-456", + }, + }) + expect(cfg.user).toBeUndefined() + expect(cfg.password).toBeUndefined() + }) + + test("azure-active-directory-access-token passes supplied token unchanged", async () => { + resetMocks() + const c = await connect({ + host: "myserver.database.windows.net", + database: "db", + authentication: "azure-active-directory-access-token", + access_token: "eyJhbGciOi...", + }) + await c.connect() + const cfg = mockConnectCalls[0] + expect(cfg.authentication).toEqual({ + type: "azure-active-directory-access-token", + options: { token: "eyJhbGciOi..." 
}, + }) + }) + + test("azure-active-directory-access-token with no token auto-acquires one", async () => { + // Regression: prior to this, omitting `token`/`access_token` resulted in + // `options.token: undefined`, which tedious rejects with + // "config.authentication.options.token must be of type string". + resetMocks() + const c = await connect({ + host: "myserver.database.windows.net", + database: "db", + authentication: "azure-active-directory-access-token", + }) + await c.connect() + const cfg = mockConnectCalls[0] + expect(cfg.authentication.type).toBe("azure-active-directory-access-token") + expect(cfg.authentication.options.token).toBe("mock-azure-token-12345") + }) + + test("azure-active-directory-service-principal-secret builds SP auth", async () => { + resetMocks() + const c = await connect({ + host: "myserver.database.windows.net", + database: "db", + authentication: "azure-active-directory-service-principal-secret", + azure_client_id: "sp-client", + azure_client_secret: "sp-secret", + azure_tenant_id: "sp-tenant", + }) + await c.connect() + const cfg = mockConnectCalls[0] + expect(cfg.authentication).toEqual({ + type: "azure-active-directory-service-principal-secret", + options: { + clientId: "sp-client", + clientSecret: "sp-secret", + tenantId: "sp-tenant", + }, + }) + }) + + test("azure-active-directory-msi-vm builds MSI auth with optional clientId", async () => { + resetMocks() + const c = await connect({ + host: "myserver.database.windows.net", + database: "db", + authentication: "azure-active-directory-msi-vm", + azure_client_id: "msi-client", + }) + await c.connect() + const cfg = mockConnectCalls[0] + expect(cfg.authentication).toEqual({ + type: "azure-active-directory-msi-vm", + options: { clientId: "msi-client" }, + }) + }) + + test("azure-active-directory-msi-app-service works without clientId", async () => { + resetMocks() + const c = await connect({ + host: "myserver.database.windows.net", + database: "db", + authentication: 
"azure-active-directory-msi-app-service", + }) + await c.connect() + const cfg = mockConnectCalls[0] + expect(cfg.authentication).toEqual({ + type: "azure-active-directory-msi-app-service", + options: {}, + }) + }) + + test("azure-active-directory-default acquires token and passes as access-token", async () => { + resetMocks() + const c = await connect({ + host: "myserver.database.windows.net", + database: "db", + authentication: "azure-active-directory-default", + }) + await c.connect() + const cfg = mockConnectCalls[0] + expect(cfg.authentication.type).toBe("azure-active-directory-access-token") + expect(cfg.authentication.options.token).toBe("mock-azure-token-12345") + }) + + test("azure-active-directory-default with client_id passes managedIdentityClientId to credential", async () => { + resetMocks() + const c = await connect({ + host: "myserver.database.windows.net", + database: "db", + authentication: "azure-active-directory-default", + azure_client_id: "mi-client-id", + }) + await c.connect() + const cfg = mockConnectCalls[0] + // Token is still passed as access-token regardless of client_id + expect(cfg.authentication.type).toBe("azure-active-directory-access-token") + expect(cfg.authentication.options.token).toBe("mock-azure-token-12345") + }) + + test("encryption forced for all Azure AD connections", async () => { + resetMocks() + const c = await connect({ + host: "myserver.database.windows.net", + database: "db", + authentication: "azure-active-directory-password", + user: "u", + password: "p", + }) + await c.connect() + const cfg = mockConnectCalls[0] + expect(cfg.options.encrypt).toBe(true) + }) + + test("standard auth does not force encryption", async () => { + resetMocks() + const c = await connect({ host: "localhost", database: "db", user: "sa", password: "pass" }) + await c.connect() + const cfg = mockConnectCalls[0] + expect(cfg.options.encrypt).toBe(false) + }) + + test("'CLI' shorthand acquires token via DefaultAzureCredential", async () => { + 
resetMocks() + const c = await connect({ + host: "myserver.datawarehouse.fabric.microsoft.com", + database: "migration", + authentication: "CLI", + }) + await c.connect() + const cfg = mockConnectCalls[0] + expect(cfg.authentication.type).toBe("azure-active-directory-access-token") + expect(cfg.authentication.options.token).toBe("mock-azure-token-12345") + expect(cfg.options.encrypt).toBe(true) + }) + + test("'service-principal' shorthand maps correctly", async () => { + resetMocks() + const c = await connect({ + host: "myserver.database.windows.net", + database: "db", + authentication: "service-principal", + azure_client_id: "cid", + azure_client_secret: "csec", + azure_tenant_id: "tid", + }) + await c.connect() + const cfg = mockConnectCalls[0] + expect(cfg.authentication.type).toBe("azure-active-directory-service-principal-secret") + expect(cfg.authentication.options.clientId).toBe("cid") + }) + + test("'msi' shorthand maps to azure-active-directory-msi-vm", async () => { + resetMocks() + const c = await connect({ + host: "myserver.database.windows.net", + database: "db", + authentication: "msi", + }) + await c.connect() + const cfg = mockConnectCalls[0] + expect(cfg.authentication.type).toBe("azure-active-directory-msi-vm") + }) + }) + + // --- Schema introspection --- + + describe("schema introspection", () => { + test("listSchemas queries sys.schemas", async () => { + mockQueryResult = { recordset: [{ name: "dbo" }, { name: "sales" }] } + const schemas = await connector.listSchemas() + expect(mockQueryCalls[0]).toContain("sys.schemas") + expect(schemas).toEqual(["dbo", "sales"]) + }) + + test("listTables queries sys.tables and sys.views", async () => { + mockQueryResult = { + recordset: [ + { name: "orders", type: "U " }, + { name: "order_summary", type: "V" }, + ], + } + const tables = await connector.listTables("dbo") + expect(mockQueryCalls[0]).toContain("UNION ALL") + expect(mockQueryCalls[0]).toContain("sys.tables") + 
expect(mockQueryCalls[0]).toContain("sys.views") + expect(tables).toEqual([ + { name: "orders", type: "table" }, + { name: "order_summary", type: "view" }, + ]) + }) + + test("describeTable queries sys.columns", async () => { + mockQueryResult = { + recordset: [ + { column_name: "id", data_type: "int", is_nullable: 0 }, + { column_name: "name", data_type: "nvarchar", is_nullable: 1 }, + ], + } + const cols = await connector.describeTable("dbo", "users") + expect(mockQueryCalls[0]).toContain("sys.columns") + expect(cols).toEqual([ + { name: "id", data_type: "int", nullable: false }, + { name: "name", data_type: "nvarchar", nullable: true }, + ]) + }) + }) + + // --- Connection lifecycle --- + + describe("connection lifecycle", () => { + test("close is idempotent", async () => { + await connector.close() + await connector.close() + expect(mockCloseCalls).toBe(1) + }) + }) + + // --- Result format --- + + describe("result format", () => { + test("maps recordset to column-ordered arrays", async () => { + mockQueryResult = { + recordset: [ + { id: 1, name: "alice", age: 30 }, + { id: 2, name: "bob", age: 25 }, + ], + } + const result = await connector.execute("SELECT id, name, age FROM t") + expect(result.columns).toEqual(["id", "name", "age"]) + expect(result.rows).toEqual([ + [1, "alice", 30], + [2, "bob", 25], + ]) + }) + + test("preserves underscore-prefixed columns", async () => { + mockQueryResult = { + recordset: [{ id: 1, _p: "Delivered", name: "x" }], + } + const result = await connector.execute("SELECT * FROM t") + expect(result.columns).toEqual(["id", "_p", "name"]) + }) + }) + + // --- Unnamed column flattening --- + + describe("unnamed column flattening", () => { + test("flattens unnamed columns merged under empty-string key", async () => { + // mssql merges SELECT COUNT(*), SUM(amount) into row[""] = [42, 1000] + mockQueryResult = { + recordset: [{ "": [42, 1000] }], + } + const result = await connector.execute("SELECT COUNT(*), SUM(amount) FROM t") + 
expect(result.rows).toEqual([[42, 1000]]) + expect(result.columns).toEqual(["col_0", "col_1"]) + }) + + test("preserves legitimate array values from named columns", async () => { + // A named column containing an array (e.g. from JSON aggregation) + // should NOT be spread — only the empty-string key gets flattened + mockQueryResult = { + recordset: [{ id: 1, tags: ["a", "b", "c"] }], + } + const result = await connector.execute("SELECT * FROM t") + expect(result.columns).toEqual(["id", "tags"]) + expect(result.rows).toEqual([[1, ["a", "b", "c"]]]) + }) + + test("handles mix of named and unnamed columns", async () => { + mockQueryResult = { + recordset: [{ name: "alice", "": [42] }], + } + const result = await connector.execute("SELECT * FROM t") + // Named header preserved; single unnamed aggregate synthesized. + expect(result.columns).toEqual(["name", "col_0"]) + expect(result.rows).toEqual([["alice", 42]]) + }) + + test("mixed named + MULTIPLE unnamed aggregates keep named header", async () => { + // SELECT name, COUNT(*), SUM(x) FROM t → { name: "alice", "": [42, 100] }. + // Regression: previous implementation fell back to col_0..col_N for all + // columns, erasing the known `name` header. 
+ mockQueryResult = { + recordset: [{ name: "alice", "": [42, 100] }], + } + const result = await connector.execute("SELECT name, COUNT(*), SUM(x) FROM t") + expect(result.columns).toEqual(["name", "col_0", "col_1"]) + expect(result.rows).toEqual([["alice", 42, 100]]) + }) + + test("single unnamed column gets synthetic name (no blank header)", async () => { + // SELECT COUNT(*) FROM t → { "": [5] } + mockQueryResult = { + recordset: [{ "": [5] }], + } + const result = await connector.execute("SELECT COUNT(*) FROM t") + expect(result.columns).toEqual(["col_0"]) + expect(result.columns).not.toContain("") + expect(result.rows).toEqual([[5]]) + }) + }) + + // --- Azure token caching (Fix #2) --- + + describe("Azure token cache", () => { + beforeEach(() => { + _resetTokenCacheForTests() + azureIdentityState.throwOnGetToken = false + azureIdentityState.tokenOverride = null + cliState.throwError = null + cliState.output = "mock-cli-token-fallback\n" + }) + + test("second connect with same (resource, clientId) reuses cached token", async () => { + let getTokenCalls = 0 + azureIdentityState.tokenOverride = { token: "cached-token-A", expiresOnTimestamp: Date.now() + 3600_000 } + // Hook getToken counter + const origCredential = (await import("@azure/identity")).DefaultAzureCredential + const origGetToken = origCredential.prototype.getToken + origCredential.prototype.getToken = async function (scope: string) { + getTokenCalls++ + return origGetToken.call(this, scope) + } + try { + resetMocks() + const c1 = await connect({ + host: "h.database.windows.net", database: "d", + authentication: "azure-active-directory-default", + }) + await c1.connect() + const c2 = await connect({ + host: "h.database.windows.net", database: "d", + authentication: "azure-active-directory-default", + }) + await c2.connect() + expect(getTokenCalls).toBe(1) + // Both pool configs embed the same cached token + expect(mockConnectCalls[0].authentication.options.token).toBe("cached-token-A") + 
expect(mockConnectCalls[1].authentication.options.token).toBe("cached-token-A") + } finally { + origCredential.prototype.getToken = origGetToken + } + }) + + test("near-expiry token triggers refresh", async () => { + // First token expires in 1 minute (well under the 5-minute refresh margin) + azureIdentityState.tokenOverride = { token: "about-to-expire", expiresOnTimestamp: Date.now() + 60_000 } + resetMocks() + const c1 = await connect({ + host: "h.database.windows.net", database: "d", + authentication: "azure-active-directory-default", + }) + await c1.connect() + // Now change the mock to issue a new token on refresh + azureIdentityState.tokenOverride = { token: "fresh-token", expiresOnTimestamp: Date.now() + 3600_000 } + const c2 = await connect({ + host: "h.database.windows.net", database: "d", + authentication: "azure-active-directory-default", + }) + await c2.connect() + expect(mockConnectCalls[0].authentication.options.token).toBe("about-to-expire") + expect(mockConnectCalls[1].authentication.options.token).toBe("fresh-token") + }) + + test("different clientIds cache separately", async () => { + // Prove cache keying by counting distinct getToken invocations: with + // separate clientIds we expect 2 calls (one per key); with a shared + // clientId we expect 1 on the second connect. 
+ let getTokenCalls = 0 + azureIdentityState.tokenOverride = { token: "shared-token", expiresOnTimestamp: Date.now() + 3600_000 } + const origCredential = (await import("@azure/identity")).DefaultAzureCredential + const origGetToken = origCredential.prototype.getToken + origCredential.prototype.getToken = async function (scope: string) { + getTokenCalls++ + return origGetToken.call(this, scope) + } + try { + resetMocks() + const a = await connect({ + host: "h.database.windows.net", database: "d", + authentication: "azure-active-directory-default", + azure_client_id: "client-1", + }) + await a.connect() + const b = await connect({ + host: "h.database.windows.net", database: "d", + authentication: "azure-active-directory-default", + azure_client_id: "client-2", + }) + await b.connect() + // Two distinct client IDs → two distinct cache entries → two getToken + // calls. If the cache were keyed only on resource URL this would be 1. + expect(getTokenCalls).toBe(2) + expect(mockConnectCalls[0].authentication.options.token).toBe("shared-token") + expect(mockConnectCalls[1].authentication.options.token).toBe("shared-token") + + // Reconnect with client-1 again — should hit the cache, no new getToken + const c = await connect({ + host: "h.database.windows.net", database: "d", + authentication: "azure-active-directory-default", + azure_client_id: "client-1", + }) + await c.connect() + expect(getTokenCalls).toBe(2) + } finally { + origCredential.prototype.getToken = origGetToken + } + }) + }) + + // --- Configurable / inferred Azure resource URL (Fix #5) --- + + describe("Azure resource URL resolution", () => { + beforeEach(() => { + _resetTokenCacheForTests() + azureIdentityState.throwOnGetToken = false + azureIdentityState.tokenOverride = null + cliState.throwError = null + }) + + test("commercial cloud: default to database.windows.net", async () => { + resetMocks() + const c = await connect({ + host: "myserver.database.windows.net", database: "d", + authentication: 
"azure-active-directory-default", + }) + await c.connect() + expect(azureIdentityState.lastScope).toBe("https://database.windows.net/.default") + }) + + test("Azure Government host infers usgovcloudapi.net", async () => { + resetMocks() + const c = await connect({ + host: "myserver.database.usgovcloudapi.net", database: "d", + authentication: "azure-active-directory-default", + }) + await c.connect() + expect(azureIdentityState.lastScope).toBe("https://database.usgovcloudapi.net/.default") + }) + + test("Azure China host infers chinacloudapi.cn", async () => { + resetMocks() + const c = await connect({ + host: "myserver.database.chinacloudapi.cn", database: "d", + authentication: "azure-active-directory-default", + }) + await c.connect() + expect(azureIdentityState.lastScope).toBe("https://database.chinacloudapi.cn/.default") + }) + + test("explicit azure_resource_url wins over host inference", async () => { + resetMocks() + const c = await connect({ + host: "myserver.database.windows.net", // commercial host + database: "d", + authentication: "azure-active-directory-default", + azure_resource_url: "https://custom.sovereign.example/", + }) + await c.connect() + expect(azureIdentityState.lastScope).toBe("https://custom.sovereign.example/.default") + }) + + test("azure_resource_url without trailing slash is normalized", async () => { + // Regression: without the slash, `${resourceUrl}.default` produced an + // invalid scope like "https://custom-host.default", and `getToken` + // would reject it. 
+ resetMocks() + const c = await connect({ + host: "x.database.windows.net", database: "d", + authentication: "azure-active-directory-default", + azure_resource_url: "https://custom-host", + }) + await c.connect() + expect(azureIdentityState.lastScope).toBe("https://custom-host/.default") + }) + + test("az CLI fallback uses the same resource URL", async () => { + // Disable @azure/identity so we hit the az CLI fallback + azureIdentityState.throwOnGetToken = true + cliState.output = "eyJ.eyJ.sig\n" // looks like JWT; parseTokenExpiry returns undefined → fallback TTL + resetMocks() + const c = await connect({ + host: "myserver.database.usgovcloudapi.net", database: "d", + authentication: "azure-active-directory-default", + }) + await c.connect() + expect(cliState.lastCmd).toContain("--resource https://database.usgovcloudapi.net/") + }) + }) + + // --- Error surfacing when auth fails (Fix #5 bonus, Minor #10 addressed) --- + + describe("Azure auth error surfacing", () => { + beforeEach(() => { + _resetTokenCacheForTests() + azureIdentityState.throwOnGetToken = false + azureIdentityState.tokenOverride = null + cliState.throwError = null + }) + + test("both @azure/identity and az CLI fail → error includes both hints", async () => { + azureIdentityState.throwOnGetToken = true + cliState.throwError = { stderr: "Please run 'az login' to set up an account.", message: "failed" } + resetMocks() + const c = await connect({ + host: "h.database.windows.net", database: "d", + authentication: "azure-active-directory-default", + }) + await expect(c.connect()).rejects.toThrow(/Azure AD token acquisition failed/) + await expect(c.connect()).rejects.toThrow(/az CLI:.*az login/) + }) + }) +}) diff --git a/packages/opencode/src/altimate/native/connections/data-diff.ts b/packages/opencode/src/altimate/native/connections/data-diff.ts index 294c43745..a4fc16671 100644 --- a/packages/opencode/src/altimate/native/connections/data-diff.ts +++ 
b/packages/opencode/src/altimate/native/connections/data-diff.ts @@ -10,6 +10,50 @@ import type { DataDiffParams, DataDiffResult, PartitionDiffResult } from "../types" import * as Registry from "./registry" +// --------------------------------------------------------------------------- +// Dialect mapping — bridge warehouse config types to Rust SqlDialect serde names +// --------------------------------------------------------------------------- + +/** Map warehouse config types to Rust SqlDialect serde names. */ +const WAREHOUSE_TO_DIALECT: Record = { + sqlserver: "tsql", + mssql: "tsql", + fabric: "fabric", + postgresql: "postgres", + mariadb: "mysql", +} + +/** Convert a warehouse config type to the Rust-compatible SqlDialect name. */ +export function warehouseTypeToDialect(warehouseType: string): string { + return WAREHOUSE_TO_DIALECT[warehouseType.toLowerCase()] ?? warehouseType.toLowerCase() +} + +// --------------------------------------------------------------------------- +// Dialect-aware identifier quoting +// --------------------------------------------------------------------------- + +/** + * Quote a SQL identifier using the correct delimiter for the dialect. + * Used both for partition column/value quoting and for plain-table-name + * wrapping inside CTEs (via `resolveTableSources`). + */ +function quoteIdentForDialect(identifier: string, dialect: string): string { + switch (dialect) { + case "mysql": + case "mariadb": + case "clickhouse": + return `\`${identifier.replace(/`/g, "``")}\`` + case "tsql": + case "fabric": + case "sqlserver": + case "mssql": + return `[${identifier.replace(/\]/g, "]]")}]` + default: + // ANSI SQL: Postgres, Snowflake, BigQuery, DuckDB, Oracle, Redshift, etc. 
+ return `"${identifier.replace(/"/g, '""')}"` + } +} + // --------------------------------------------------------------------------- // Query-source detection // --------------------------------------------------------------------------- @@ -18,49 +62,82 @@ const SQL_KEYWORDS = /^\s*(SELECT|WITH|VALUES)\b/i /** * Detect whether a string is an arbitrary SQL query (vs a plain table name). - * Plain table names may contain dots (schema.table, db.schema.table) but not spaces. + * + * A SQL query starts with a keyword AND contains whitespace (e.g., "SELECT * FROM ..."). + * A plain table name — even one named "select" or "with" — is a single token without + * internal whitespace (possibly dot-separated like schema.table or db.schema.table). + * + * The \b in SQL_KEYWORDS already prevents matching "with_metadata" or "select_results", + * but the whitespace check additionally handles bare keyword table names like "select". */ function isQuery(input: string): boolean { - return SQL_KEYWORDS.test(input) + const trimmed = input.trim() + return SQL_KEYWORDS.test(trimmed) && /\s/.test(trimmed) } /** * If either source or target is an arbitrary query, wrap them in CTEs so the * DataParity engine can treat them as tables named `__diff_source` / `__diff_target`. * - * Returns `{ table1Name, table2Name, ctePrefix | null }`. + * Returns both a combined prefix (used for same-warehouse tasks where a JOIN + * might reference both CTEs) and side-specific prefixes (used for cross-warehouse + * tasks where each warehouse only has access to its own base tables). * - * When a CTE prefix is returned, it must be prepended to every SQL task emitted - * by the engine before execution. + * **Why side-specific prefixes matter:** T-SQL / Fabric parse-bind every CTE body + * at parse time, even unreferenced ones. Sending a combined `WITH __diff_source + * AS (... FROM mssql_only_table), __diff_target AS (... 
FROM fabric_only_table)` + * to MSSQL fails because MSSQL can't resolve the Fabric-only table referenced in + * the unused `__diff_target` CTE. + * + * Callers must prepend the appropriate prefix to every SQL task emitted by the + * engine before execution. */ export function resolveTableSources( source: string, target: string, -): { table1Name: string; table2Name: string; ctePrefix: string | null } { + sourceDialect?: string, + targetDialect?: string, +): { + table1Name: string + table2Name: string + ctePrefix: string | null + sourceCtePrefix: string | null + targetCtePrefix: string | null +} { const source_is_query = isQuery(source) const target_is_query = isQuery(target) if (!source_is_query && !target_is_query) { // Both are plain table names — pass through unchanged - return { table1Name: source, table2Name: target, ctePrefix: null } + return { + table1Name: source, + table2Name: target, + ctePrefix: null, + sourceCtePrefix: null, + targetCtePrefix: null, + } } - // At least one is a query — wrap both in CTEs - // Quote identifier parts so table names with special chars don't inject SQL. - // Use double-quote escaping (ANSI SQL standard, works in Postgres/Snowflake/DuckDB/etc.) - const quoteIdent = (name: string) => - name - .split(".") - .map((p) => `"${p.replace(/"/g, '""')}"`) - .join(".") - const srcExpr = source_is_query ? source : `SELECT * FROM ${quoteIdent(source)}` - const tgtExpr = target_is_query ? target : `SELECT * FROM ${quoteIdent(target)}` + // At least one is a query — wrap both in CTEs. Quote plain-table names with + // the *side's own* dialect so T-SQL / Fabric get `[schema].[table]` and + // ANSI dialects get `"schema"."table"` — avoids `QUOTED_IDENTIFIER OFF` + // surprises on MSSQL/Fabric. Fallback to ANSI when dialect is unspecified. + const quoteTableRef = (name: string, dialect: string | undefined): string => { + const d = dialect ?? 
"generic" + return name.split(".").map((p) => quoteIdentForDialect(p, d)).join(".") + } + const srcExpr = source_is_query ? source : `SELECT * FROM ${quoteTableRef(source, sourceDialect)}` + const tgtExpr = target_is_query ? target : `SELECT * FROM ${quoteTableRef(target, targetDialect)}` + const sourceCtePrefix = `WITH __diff_source AS (\n${srcExpr}\n)` + const targetCtePrefix = `WITH __diff_target AS (\n${tgtExpr}\n)` const ctePrefix = `WITH __diff_source AS (\n${srcExpr}\n), __diff_target AS (\n${tgtExpr}\n)` return { table1Name: "__diff_source", table2Name: "__diff_target", ctePrefix, + sourceCtePrefix, + targetCtePrefix, } } @@ -403,28 +480,12 @@ const MAX_STEPS = 200 // Partition support // --------------------------------------------------------------------------- -/** - * Quote a SQL identifier using the correct delimiter for the dialect. - */ -function quoteIdentForDialect(identifier: string, dialect: string): string { - switch (dialect) { - case "mysql": - case "mariadb": - case "clickhouse": - return `\`${identifier.replace(/`/g, "``")}\`` - case "tsql": - case "fabric": - return `[${identifier.replace(/\]/g, "]]")}]` - default: - // ANSI SQL: Postgres, Snowflake, BigQuery, DuckDB, Oracle, Redshift, etc. - return `"${identifier.replace(/"/g, '""')}"` - } -} - /** * Build a DATE_TRUNC expression appropriate for the warehouse dialect. + * + * Exported for targeted unit testing; not part of the stable public API. */ -function dateTruncExpr(granularity: string, column: string, dialect: string): string { +export function dateTruncExpr(granularity: string, column: string, dialect: string): string { const g = granularity.toLowerCase() switch (dialect) { case "bigquery": @@ -449,6 +510,12 @@ function dateTruncExpr(granularity: string, column: string, dialect: string): st } return `TRUNC(${column}, '${oracleFmt[g] ?? 
g.toUpperCase()}')` } + case "sqlserver": + case "mssql": + case "tsql": + case "fabric": + // SQL Server 2022+ / Fabric: DATETRUNC expects unquoted datepart keyword + return `DATETRUNC(${g.toUpperCase()}, ${column})` default: // Postgres, Snowflake, Redshift, DuckDB, etc. return `DATE_TRUNC('${g}', ${column})` @@ -500,8 +567,10 @@ function buildPartitionDiscoverySQL( /** * Build a WHERE clause that scopes to a single partition. + * + * Exported for targeted unit testing; not part of the stable public API. */ -function buildPartitionWhereClause( +export function buildPartitionWhereClause( partitionColumn: string, partitionValue: string, granularity: string | undefined, @@ -526,18 +595,47 @@ function buildPartitionWhereClause( // date mode const expr = dateTruncExpr(granularity!, quotedCol, dialect) + // Normalize to ISO `yyyy-mm-dd` ONLY for T-SQL / Fabric, which use + // `CONVERT(DATE, '…', 23)` (strict ISO-8601 parser). The mssql driver + // returns date columns as JS Date objects that coerce to strings like + // "Mon Jan 01 2024 00:00:00 GMT+0000 (UTC)" — that format must be parsed + // to ISO before CONVERT will accept it. + // + // For other dialects, pass the value through unchanged. MySQL/MariaDB + // produce non-ISO `DATE_FORMAT` outputs (e.g. `YYYY-%u` for ISO week, + // which is `YYYY-42` not `YYYY-MM-DD`), and forcing ISO conversion would + // corrupt them — the WHERE would never match. Postgres / BigQuery / + // ClickHouse accept whatever their own `DATE_TRUNC`/`toStartOf*` + // emits verbatim on the round trip. + const needsIso = dialect === "tsql" || dialect === "fabric" || + dialect === "sqlserver" || dialect === "mssql" + const normalized = needsIso + ? (() => { + const trimmed = partitionValue.trim() + if (/^\d{4}-\d{2}-\d{2}(\s|T|$)/.test(trimmed)) return trimmed.slice(0, 10) + const d = new Date(trimmed) + return Number.isNaN(d.getTime()) ? 
trimmed : d.toISOString().slice(0, 10) + })() + : partitionValue + const escaped = normalized.replace(/'/g, "''") // Cast the literal appropriately per dialect switch (dialect) { case "bigquery": - return `${expr} = '${partitionValue}'` + return `${expr} = '${escaped}'` case "clickhouse": - return `${expr} = toDate('${partitionValue}')` + return `${expr} = toDate('${escaped}')` case "mysql": case "mariadb": - return `${expr} = '${partitionValue}'` + return `${expr} = '${escaped}'` + case "sqlserver": + case "mssql": + case "tsql": + case "fabric": + // Style 23 = ISO-8601 (yyyy-mm-dd), locale-safe + return `${expr} = CONVERT(DATE, '${escaped}', 23)` default: - return `${expr} = '${partitionValue}'` + return `${expr} = '${escaped}'` } } @@ -623,15 +721,20 @@ async function runPartitionedDiff(params: DataDiffParams): Promise { if (warehouse) { const cfg = Registry.getConfig(warehouse) - return cfg?.type ?? "generic" + return warehouseTypeToDialect(cfg?.type ?? "generic") } const warehouses = Registry.list().warehouses - return warehouses[0]?.type ?? "generic" + return warehouseTypeToDialect(warehouses[0]?.type ?? "generic") } const sourceDialect = resolveDialect(params.source_warehouse) const targetDialect = resolveDialect(params.target_warehouse ?? params.source_warehouse) - const { table1Name, table2Name } = resolveTableSources(params.source, params.target) + const { table1Name, table2Name } = resolveTableSources( + params.source, + params.target, + sourceDialect, + targetDialect, + ) // Discover partition values from BOTH source and target to catch target-only partitions. // Without this, rows that exist only in target partitions are silently missed. 
@@ -680,24 +783,75 @@ async function runPartitionedDiff(params: DataDiffParams): Promise + name.split(".").map((p) => quoteIdentForDialect(p, dialect)).join(".") + const sourceTableRef = quoteTableRefForDialect(params.source, sourceDialect) + const targetTableRef = quoteTableRefForDialect(params.target, targetDialect) + for (const pVal of partitionValues) { - const partWhere = buildPartitionWhereClause( + // Build per-side partition WHERE clauses. The dialects can differ + // (cross-warehouse diff) — the engine applies `where_clause` to both + // sides identically, so we can't use it to carry dialect-specific syntax. + // Bake each side's WHERE into its own subquery-wrapped SQL source instead. + const sourcePartWhere = buildPartitionWhereClause( params.partition_column!, pVal, params.partition_granularity, params.partition_bucket_size, sourceDialect, ) - const fullWhere = params.where_clause ? `(${params.where_clause}) AND (${partWhere})` : partWhere + const targetPartWhere = buildPartitionWhereClause( + params.partition_column!, + pVal, + params.partition_granularity, + params.partition_bucket_size, + targetDialect, + ) + + // Wrap each side's table as a SELECT subquery filtered to this partition. + // The recursive runDataDiff below will detect these as SQL queries and + // route them through the CTE-injection path, which is already side-aware. + const sourceSql = `SELECT * FROM ${sourceTableRef} WHERE ${sourcePartWhere}` + const targetSql = `SELECT * FROM ${targetTableRef} WHERE ${targetPartWhere}` const result = await runDataDiff({ ...params, - where_clause: fullWhere, + source: sourceSql, + target: targetSql, + // Preserve the user's shared where_clause — it's dialect-neutral. + where_clause: params.where_clause, + // Pass auto-discovered extras explicitly — `runDataDiff`'s own + // discovery path would skip these wrapped SELECT subqueries and + // regress to key-only comparison. 
+ extra_columns: resolvedExtraColumns, partition_column: undefined, // prevent recursion }) @@ -718,6 +872,7 @@ async function runPartitionedDiff(params: DataDiffParams): Promise 0 ? { excluded_audit_columns: partitionExcludedAudit } : {}), } } @@ -727,6 +882,50 @@ export async function runDataDiff(params: DataDiffParams): Promise { + if (warehouse) return warehouse + const warehouses = Registry.list().warehouses + return warehouses[0]?.name + } + + // Resolve dialect from warehouse config + const resolveDialect = (warehouse: string | undefined): string => { + if (warehouse) { + const cfg = Registry.getConfig(warehouse) + return warehouseTypeToDialect(cfg?.type ?? "generic") + } + const warehouses = Registry.list().warehouses + return warehouseTypeToDialect(warehouses[0]?.type ?? "generic") + } + + const resolvedSource = resolveWarehouseName(params.source_warehouse) + const resolvedTarget = resolveWarehouseName(params.target_warehouse ?? params.source_warehouse) + + const dialect1 = resolveDialect(params.source_warehouse) + const dialect2 = resolveDialect(params.target_warehouse ?? params.source_warehouse) + + // Input-validation guards — run BEFORE the NAPI import so they produce the + // right error even in environments where `@altimateai/altimate-core` isn't + // built locally. + // + // JoinDiff cannot work across warehouses: it emits one FULL OUTER JOIN task + // referencing both CTE aliases, but side-aware injection only defines one + // side per task — the other alias would be unresolved. Guard early so users + // get a clear error instead of an obscure SQL parse failure. 
+ const crossWarehousePre = resolvedSource !== resolvedTarget + if (params.algorithm === "joindiff" && crossWarehousePre) { + return { + success: false, + steps: 0, + error: + "joindiff requires both tables in the same warehouse; use hashdiff or auto for cross-warehouse comparisons.", + } + } + // Dynamically import NAPI module (not available in test environments without the binary) let DataParitySession: new (specJson: string) => { start(): string @@ -745,11 +944,20 @@ export async function runDataDiff(params: DataDiffParams): Promise { @@ -762,19 +970,6 @@ export async function runDataDiff(params: DataDiffParams): Promise { - if (warehouse) { - const cfg = Registry.getConfig(warehouse) - return cfg?.type ?? "generic" - } - const warehouses = Registry.list().warehouses - return warehouses[0]?.type ?? "generic" - } - - const dialect1 = resolveDialect(params.source_warehouse) - const dialect2 = resolveDialect(params.target_warehouse ?? params.source_warehouse) - // Auto-discover extra_columns when not explicitly provided. // The Rust engine only compares columns listed in extra_columns — if the list is // empty, it compares key existence only and reports all matched rows as "identical" @@ -873,8 +1068,24 @@ export async function runDataDiff(params: DataDiffParams): Promise { const warehouse = warehouseFor(task.table_side) - // Inject CTE definitions if we're in query-comparison mode - const sql = ctePrefix ? injectCte(task.sql, ctePrefix) : task.sql + // Inject CTE definitions if we're in query-comparison mode. In + // cross-warehouse mode each task only gets the CTE for its own side — + // the other side's base tables aren't bindable on this warehouse. + let prefix: string | null = null + if (ctePrefix) { + if (crossWarehouse) { + prefix = task.table_side === "Table2" ? 
targetCtePrefix : sourceCtePrefix + } else { + if (task.table_side === "Table1") { + prefix = sourceCtePrefix + } else if (task.table_side === "Table2") { + prefix = targetCtePrefix + } else { + prefix = ctePrefix + } + } + } + const sql = prefix ? injectCte(task.sql, prefix) : task.sql try { const rows = await executeQuery(sql, warehouse) return { id: task.id, rows, error: null } diff --git a/packages/opencode/src/altimate/native/connections/registry.ts b/packages/opencode/src/altimate/native/connections/registry.ts index 617d6685d..cc871682c 100644 --- a/packages/opencode/src/altimate/native/connections/registry.ts +++ b/packages/opencode/src/altimate/native/connections/registry.ts @@ -122,6 +122,7 @@ const DRIVER_MAP: Record = { mariadb: "@altimateai/drivers/mysql", sqlserver: "@altimateai/drivers/sqlserver", mssql: "@altimateai/drivers/sqlserver", + fabric: "@altimateai/drivers/sqlserver", databricks: "@altimateai/drivers/databricks", duckdb: "@altimateai/drivers/duckdb", oracle: "@altimateai/drivers/oracle", @@ -165,6 +166,7 @@ async function createConnector(name: string, config: ConnectionConfig): Promise< "mariadb", "sqlserver", "mssql", + "fabric", "oracle", "snowflake", "clickhouse", diff --git a/packages/opencode/src/altimate/tools/data-diff.ts b/packages/opencode/src/altimate/tools/data-diff.ts index bf9948748..163cbb8bf 100644 --- a/packages/opencode/src/altimate/tools/data-diff.ts +++ b/packages/opencode/src/altimate/tools/data-diff.ts @@ -203,7 +203,11 @@ function formatOutcome(outcome: any, source: string, target: string): string { lines.push(` Sample differences (first ${Math.min(diffRows.length, 5)}):`) for (const d of diffRows.slice(0, 5)) { const label = d.sign === "-" ? "source only" : "target only" - lines.push(` [${label}] ${d.values?.join(" | ")}`) + // `d.values?.join(" | ") ?? "(no values)"` misses the common case where + // `values` is an empty array — `[].join(" | ")` returns "" (not null), + // so the coalesce never triggers. 
Gate on length explicitly. + const values = d.values?.length ? d.values.join(" | ") : "(no values)" + lines.push(` [${label}] ${values}`) } } diff --git a/packages/opencode/test/altimate/connections.test.ts b/packages/opencode/test/altimate/connections.test.ts index f741a8cf1..5c9680297 100644 --- a/packages/opencode/test/altimate/connections.test.ts +++ b/packages/opencode/test/altimate/connections.test.ts @@ -81,6 +81,23 @@ describe("ConnectionRegistry", () => { await expect(Registry.get("mydb")).rejects.toThrow("Supported:") }) + test("fabric type is recognized in DRIVER_MAP and routes to sqlserver driver", () => { + Registry.setConfigs({ + fabricdb: { + type: "fabric", + host: "myserver.datawarehouse.fabric.microsoft.com", + database: "migration", + authentication: "default", + }, + }) + const config = Registry.getConfig("fabricdb") + expect(config).toBeDefined() + expect(config?.type).toBe("fabric") + const result = Registry.list() + expect(result.warehouses).toHaveLength(1) + expect(result.warehouses[0].type).toBe("fabric") + }) + test("getConfig returns config for known connection", () => { Registry.setConfigs({ mydb: { type: "postgres", host: "localhost" }, diff --git a/packages/opencode/test/altimate/data-diff-cross-dialect.test.ts b/packages/opencode/test/altimate/data-diff-cross-dialect.test.ts new file mode 100644 index 000000000..bf349e323 --- /dev/null +++ b/packages/opencode/test/altimate/data-diff-cross-dialect.test.ts @@ -0,0 +1,232 @@ +/** + * Tests for cross-dialect partitioned diff and joindiff cross-warehouse guard. + * + * These cover the two CRITICAL/MAJOR bugs fixed in the review follow-up: + * 1. Partitioned WHERE was built with sourceDialect only and applied to both + * warehouses; cross-dialect diffs blew up the target with foreign syntax. + * 2. Explicit `algorithm: "joindiff"` with different warehouses silently + * produced SQL referencing an undefined CTE alias. 
+ * + * ## Why unit tests, not integration + * + * An earlier version of this file integration-tested the fixes by driving + * `runDataDiff` end-to-end with mocked Registry + mocked `@altimateai/altimate-core`. + * That approach leaked `mock.module()` state across test files in Bun (bun:test + * runs the whole suite in one process), breaking `connections.test.ts` and + * `telemetry-safety.test.ts`. Additionally, when other test files imported the + * real `@altimateai/altimate-core` first, Bun cached it and our NAPI mock was + * bypassed — the npm-published `0.2.6` lacks `DataParitySession`, so our + * integration test would fail with "altimate-core NAPI module unavailable" + * regardless of our mock. + * + * The fix is in pure-function SQL builders (`dateTruncExpr`, + * `buildPartitionWhereClause`). Testing them directly is both more targeted + * (zero coupling to NAPI availability / Registry state) and more reliable in a + * single-process test runner. + */ +import { describe, test, expect, beforeAll, afterAll, beforeEach } from "bun:test" + +import { + buildPartitionWhereClause, + dateTruncExpr, + runDataDiff, +} from "../../src/altimate/native/connections/data-diff" +import * as Registry from "../../src/altimate/native/connections/registry" + +describe("dateTruncExpr — dialect-native output", () => { + test("tsql uses DATETRUNC with unquoted datepart keyword", () => { + expect(dateTruncExpr("month", "[order_date]", "tsql")).toBe("DATETRUNC(MONTH, [order_date])") + }) + + test("fabric matches tsql", () => { + expect(dateTruncExpr("day", "[d]", "fabric")).toBe("DATETRUNC(DAY, [d])") + }) + + test("postgres uses DATE_TRUNC with lowercase string literal", () => { + expect(dateTruncExpr("month", `"order_date"`, "postgres")).toBe(`DATE_TRUNC('month', "order_date")`) + }) + + test("bigquery uses DATE_TRUNC with uppercase unit keyword", () => { + expect(dateTruncExpr("month", "order_date", "bigquery")).toBe("DATE_TRUNC(order_date, MONTH)") + }) + + test("mysql uses 
DATE_FORMAT with format strings", () => { + expect(dateTruncExpr("month", "`d`", "mysql")).toContain("DATE_FORMAT(`d`") + expect(dateTruncExpr("month", "`d`", "mysql")).toContain("%Y-%m-01") + }) +}) + +describe("buildPartitionWhereClause — cross-dialect correctness (the CRITICAL fix)", () => { + const col = "order_date" + const value = "2026-04-01" + + test("tsql: DATETRUNC + CONVERT(DATE, ..., 23) ISO-8601 style", () => { + const sql = buildPartitionWhereClause(col, value, "month", undefined, "tsql") + expect(sql).toContain("DATETRUNC(MONTH, [order_date])") + expect(sql).toContain("CONVERT(DATE, '2026-04-01', 23)") + // Must not leak generic single-quoted literal in tsql + expect(sql).not.toMatch(/=\s*'2026-04-01'\s*$/) + }) + + test("fabric: same as tsql", () => { + const sql = buildPartitionWhereClause(col, value, "month", undefined, "fabric") + expect(sql).toContain("DATETRUNC(MONTH, [order_date])") + expect(sql).toContain("CONVERT(DATE, '2026-04-01', 23)") + }) + + test("postgres: DATE_TRUNC + bare date literal", () => { + const sql = buildPartitionWhereClause(col, value, "month", undefined, "postgres") + expect(sql).toContain(`DATE_TRUNC('month', "order_date")`) + expect(sql).toContain(`'2026-04-01'`) + // Must not produce T-SQL syntax + expect(sql).not.toMatch(/DATETRUNC\(/i) + expect(sql).not.toMatch(/CONVERT\(DATE/i) + }) + + test("clickhouse: toStartOfMonth + toDate() cast", () => { + const sql = buildPartitionWhereClause(col, value, "month", undefined, "clickhouse") + expect(sql).toContain("toStartOfMonth(`order_date`)") + expect(sql).toContain("toDate('2026-04-01')") + }) + + test("bigquery: DATE_TRUNC uppercase + bare literal", () => { + const sql = buildPartitionWhereClause(col, value, "month", undefined, "bigquery") + // `quoteIdentForDialect` falls through to ANSI double-quotes for bigquery + expect(sql).toContain(`DATE_TRUNC("order_date", MONTH)`) + expect(sql).toContain(`'2026-04-01'`) + }) + + // The regression this guards against: before the 
fix, the orchestrator built + // ONE partition WHERE using `sourceDialect` and passed it to both sides. + // A cross-dialect MSSQL → Postgres diff would send `DATETRUNC`/`CONVERT` to + // Postgres and blow up. With per-side WHERE generation, the two outputs are + // independent — asserted directly here. + test("cross-dialect sanity: MSSQL and Postgres outputs are independent and incompatible", () => { + const mssqlWhere = buildPartitionWhereClause(col, value, "month", undefined, "tsql") + const pgWhere = buildPartitionWhereClause(col, value, "month", undefined, "postgres") + expect(mssqlWhere).not.toEqual(pgWhere) + // MSSQL WHERE would break when sent to Postgres and vice versa — the test + // proves each dialect yields only its own syntax. + expect(mssqlWhere).toMatch(/DATETRUNC/i) + expect(pgWhere).not.toMatch(/DATETRUNC/i) + expect(pgWhere).toMatch(/DATE_TRUNC/i) + expect(mssqlWhere).not.toMatch(/DATE_TRUNC/i) + }) + + test("numeric mode produces bucket range, ignores dialect", () => { + const sql = buildPartitionWhereClause("amount", "100000", undefined, 1000, "tsql") + expect(sql).toContain("[amount] >= 100000") + expect(sql).toContain("[amount] < 101000") + }) + + test("categorical mode quotes the value with single-quote escaping", () => { + const sql = buildPartitionWhereClause("status", "it's active", undefined, undefined, "postgres") + expect(sql).toContain(`"status" = 'it''s active'`) + }) + + test("tsql date literal normalizes timestamp inputs to ISO yyyy-mm-dd", () => { + // Regression: mssql returns Date-like strings (e.g. "Mon Apr 01 2024 …") + // that must be normalized before CONVERT(DATE, …, 23) can parse them. + const sql = buildPartitionWhereClause(col, "Mon Apr 01 2024 00:00:00 GMT+0000", "month", undefined, "tsql") + expect(sql).toContain("CONVERT(DATE, '2024-04-01', 23)") + }) + + test("mysql week-format values pass through unchanged (regression: no ISO rewrite)", () => { + // Regression guard: MySQL `DATE_FORMAT(%Y-%u)` emits e.g. 
"2024-42" for + // week 42 — that's not a parseable JS Date. An earlier revision tried + // to normalize it to ISO `yyyy-mm-dd`, which either produced NaN or a + // wildly wrong date (Dec 2024 in 0042 AD). Must be passed through. + const sql = buildPartitionWhereClause("ts", "2024-42", "week", undefined, "mysql") + expect(sql).toContain("= '2024-42'") + expect(sql).not.toContain("0042") + expect(sql).not.toContain("NaN") + }) + + test("mysql DATE_FORMAT month output flows through verbatim", () => { + const sql = buildPartitionWhereClause("ts", "2024-04-01", "month", undefined, "mariadb") + expect(sql).toContain("DATE_FORMAT(`ts`, '%Y-%m-01')") + expect(sql).toContain("= '2024-04-01'") + }) +}) + +// The joindiff guard runs BEFORE `runDataDiff`'s NAPI import, so we can drive +// it end-to-end without any mock. This verifies the actual wiring, not just +// the pure-function output — complementary to the unit tests above. +describe("joindiff + cross-warehouse guard", () => { + beforeAll(() => { + process.env.ALTIMATE_TELEMETRY_DISABLED = "true" + }) + afterAll(() => { + delete process.env.ALTIMATE_TELEMETRY_DISABLED + Registry.reset() + }) + beforeEach(() => { + Registry.reset() + }) + + test("explicit joindiff with different warehouses (mixed dialect) returns early error", async () => { + Registry.setConfigs({ + msrc: { type: "sqlserver", host: "mssql-host", database: "src" }, + ptgt: { type: "postgres", host: "pg-host", database: "tgt" }, + }) + const result = await runDataDiff({ + source: "dbo.orders", + target: "public.orders", + key_columns: ["id"], + source_warehouse: "msrc", + target_warehouse: "ptgt", + algorithm: "joindiff", + }) + expect(result.success).toBe(false) + expect(result.error).toMatch(/joindiff requires both tables in the same warehouse/i) + // Guard must fire before any NAPI/driver work, so steps stays at 0. 
+ expect(result.steps).toBe(0) + }) + + test("explicit joindiff with different warehouses (SAME dialect) still errors", async () => { + // Regression guard: if `crossWarehouse` were computed from dialect + // equality (as an earlier revision did) instead of resolved warehouse + // identity, this case would slip through and route a JOIN query to a + // warehouse that doesn't have the other side's tables. Two MSSQL + // servers share a dialect but are independent physical databases. + Registry.setConfigs({ + mssql_a: { type: "sqlserver", host: "server-a", database: "src" }, + mssql_b: { type: "sqlserver", host: "server-b", database: "tgt" }, + }) + const result = await runDataDiff({ + source: "dbo.orders", + target: "dbo.orders", + key_columns: ["id"], + source_warehouse: "mssql_a", + target_warehouse: "mssql_b", + algorithm: "joindiff", + }) + expect(result.success).toBe(false) + expect(result.error).toMatch(/joindiff requires both tables in the same warehouse/i) + expect(result.steps).toBe(0) + }) + + test("same-name warehouse on both sides does NOT trigger the guard", async () => { + // Guard compares resolved warehouse identity, not dialect — same name → + // guard stays quiet. We can't drive the whole diff without NAPI, but we + // can confirm the guard error is NOT the one returned (the call will + // instead fail with the NAPI-unavailable error in test envs that lack the + // built binary, which is fine). + Registry.setConfigs({ + shared: { type: "sqlserver", host: "shared-host", database: "d" }, + }) + const result = await runDataDiff({ + source: "dbo.orders", + target: "dbo.orders_v2", + key_columns: ["id"], + source_warehouse: "shared", + target_warehouse: "shared", + algorithm: "joindiff", + }) + // May succeed or fail depending on NAPI availability — the assertion here + // is only that the joindiff guard did not reject this same-warehouse case. 
+ if (!result.success) { + expect(result.error).not.toMatch(/joindiff requires both tables in the same warehouse/i) + } + }) +}) diff --git a/packages/opencode/test/altimate/data-diff-cte.test.ts b/packages/opencode/test/altimate/data-diff-cte.test.ts new file mode 100644 index 000000000..aea08f27c --- /dev/null +++ b/packages/opencode/test/altimate/data-diff-cte.test.ts @@ -0,0 +1,161 @@ +/** + * Tests for CTE wrapping and injection in SQL-query mode. + * + * The tricky case is cross-warehouse comparison where source and target are both + * SQL queries referencing tables that only exist on their own side. The combined + * CTE prefix cannot be sent to both warehouses because T-SQL / Fabric parse-bind + * every CTE body even when unreferenced — the "other side" CTE would fail to + * resolve its base table. + */ +import { describe, test, expect } from "bun:test" + +import { resolveTableSources, injectCte } from "../../src/altimate/native/connections/data-diff" + +describe("resolveTableSources", () => { + test("plain table names pass through without wrapping", () => { + const r = resolveTableSources("orders", "orders_v2") + expect(r.table1Name).toBe("orders") + expect(r.table2Name).toBe("orders_v2") + expect(r.ctePrefix).toBeNull() + expect(r.sourceCtePrefix).toBeNull() + expect(r.targetCtePrefix).toBeNull() + }) + + test("schema-qualified plain names pass through", () => { + const r = resolveTableSources("gold.dim_customer", "TRANSFORMED.DimCustomer") + expect(r.table1Name).toBe("gold.dim_customer") + expect(r.table2Name).toBe("TRANSFORMED.DimCustomer") + expect(r.ctePrefix).toBeNull() + }) + + test("both queries are wrapped in CTEs with aliases", () => { + const r = resolveTableSources( + "SELECT id, val FROM [TRANSFORMED].[DimCustomer]", + "SELECT id, val FROM [gold].[dim_customer]", + ) + expect(r.table1Name).toBe("__diff_source") + expect(r.table2Name).toBe("__diff_target") + expect(r.ctePrefix).toContain("__diff_source AS (") + 
expect(r.ctePrefix).toContain("__diff_target AS (") + expect(r.ctePrefix).toContain("[TRANSFORMED].[DimCustomer]") + expect(r.ctePrefix).toContain("[gold].[dim_customer]") + }) + + test("side-specific prefixes contain only the relevant CTE", () => { + const r = resolveTableSources( + "SELECT id FROM [TRANSFORMED].[DimCustomer]", + "SELECT id FROM [gold].[dim_customer]", + ) + // Source prefix has source table only — must not leak target table ref + expect(r.sourceCtePrefix).toContain("__diff_source AS (") + expect(r.sourceCtePrefix).toContain("[TRANSFORMED].[DimCustomer]") + expect(r.sourceCtePrefix).not.toContain("__diff_target") + expect(r.sourceCtePrefix).not.toContain("[gold].[dim_customer]") + + // Target prefix has target table only — must not leak source table ref + expect(r.targetCtePrefix).toContain("__diff_target AS (") + expect(r.targetCtePrefix).toContain("[gold].[dim_customer]") + expect(r.targetCtePrefix).not.toContain("__diff_source") + expect(r.targetCtePrefix).not.toContain("[TRANSFORMED].[DimCustomer]") + }) + + test("mixed: plain source + query target still wraps both sides", () => { + const r = resolveTableSources( + "orders", + "SELECT * FROM other.orders WHERE region = 'EU'", + ) + expect(r.table1Name).toBe("__diff_source") + expect(r.table2Name).toBe("__diff_target") + // Plain table wrapped with ANSI double-quoted identifiers + expect(r.sourceCtePrefix).toContain('SELECT * FROM "orders"') + expect(r.targetCtePrefix).toContain("other.orders") + }) + + test("dialect-aware quoting: tsql uses square brackets", () => { + // Fix #4: plain table names wrapped inside CTEs must use the side's + // native quoting. `"schema"."table"` fails on MSSQL with QUOTED_IDENTIFIER OFF. 
+ const r = resolveTableSources( + "dbo.orders", + "SELECT * FROM base", + "tsql", + "postgres", + ) + expect(r.sourceCtePrefix).toContain("[dbo].[orders]") + expect(r.sourceCtePrefix).not.toContain('"dbo"."orders"') + }) + + test("dialect-aware quoting: fabric uses square brackets; mysql uses backticks", () => { + // Pair the plain-table side with a SQL-query counterpart to force CTE wrapping. + const fabric = resolveTableSources( + "gold.dim_customer", + "SELECT * FROM other", + "fabric", + "fabric", + ) + expect(fabric.sourceCtePrefix).toContain("[gold].[dim_customer]") + + const mysql = resolveTableSources( + "SELECT 1 AS id", + "db.orders", + "mysql", + "mysql", + ) + expect(mysql.targetCtePrefix).toContain("`db`.`orders`") + }) + + test("query detection requires both keyword AND whitespace", () => { + // A table literally named "select" should NOT be treated as a query + const r = resolveTableSources("select", "with") + expect(r.table1Name).toBe("select") + expect(r.table2Name).toBe("with") + expect(r.ctePrefix).toBeNull() + }) +}) + +describe("injectCte", () => { + test("prepends CTE prefix to a plain SELECT", () => { + const prefix = "WITH __diff_source AS (\nSELECT 1 AS id\n)" + const sql = "SELECT COUNT(*) FROM __diff_source" + const out = injectCte(sql, prefix) + expect(out.startsWith(prefix)).toBe(true) + expect(out).toContain("SELECT COUNT(*) FROM __diff_source") + }) + + test("merges with an engine-emitted WITH clause", () => { + const prefix = "WITH __diff_source AS (\nSELECT * FROM base\n)" + const engineSql = "WITH engine_cte AS (SELECT id FROM __diff_source) SELECT * FROM engine_cte" + const out = injectCte(engineSql, prefix) + // Must start with a single WITH, with our CTE first, then engine's + expect(out.match(/^WITH /)).not.toBeNull() + expect((out.match(/\bWITH\b/g) ?? 
[]).length).toBe(1) + expect(out.indexOf("__diff_source AS")).toBeLessThan(out.indexOf("engine_cte AS")) + }) + + test("side-specific injection: source prefix does not leak target refs", () => { + // Simulates cross-warehouse fp1_1 task going to MSSQL. It must not see any + // reference to the Fabric-only target table, since MSSQL parse-binds every + // CTE body. + const r = resolveTableSources( + "SELECT id FROM [TRANSFORMED].[DimCustomer]", + "SELECT id FROM [gold].[dim_customer]", + ) + const engineFp1Sql = + "SELECT COUNT(*), SUM(CAST(...HASHBYTES('MD5', CONCAT(CAST([id] AS NVARCHAR(MAX))))...)) FROM [__diff_source]" + const sqlForMssql = injectCte(engineFp1Sql, r.sourceCtePrefix!) + expect(sqlForMssql).toContain("[TRANSFORMED].[DimCustomer]") + expect(sqlForMssql).not.toContain("[gold].[dim_customer]") + expect(sqlForMssql).not.toContain("__diff_target") + }) + + test("side-specific injection: target prefix does not leak source refs", () => { + const r = resolveTableSources( + "SELECT id FROM [TRANSFORMED].[DimCustomer]", + "SELECT id FROM [gold].[dim_customer]", + ) + const engineFp2Sql = "SELECT COUNT(*) FROM [__diff_target]" + const sqlForFabric = injectCte(engineFp2Sql, r.targetCtePrefix!) + expect(sqlForFabric).toContain("[gold].[dim_customer]") + expect(sqlForFabric).not.toContain("[TRANSFORMED].[DimCustomer]") + expect(sqlForFabric).not.toContain("__diff_source") + }) +}) diff --git a/packages/opencode/test/altimate/data-diff-dialect.test.ts b/packages/opencode/test/altimate/data-diff-dialect.test.ts new file mode 100644 index 000000000..083c64d57 --- /dev/null +++ b/packages/opencode/test/altimate/data-diff-dialect.test.ts @@ -0,0 +1,55 @@ +/** + * Tests for warehouse-type-to-dialect mapping in the data-diff orchestrator. + * + * The Rust engine's SqlDialect serde deserialization only accepts exact lowercase + * variant names (e.g., "tsql", not "sqlserver"). This mapping bridges the gap + * between warehouse config types and Rust dialect names. 
+ */ +import { describe, test, expect } from "bun:test" + +import { warehouseTypeToDialect } from "../../src/altimate/native/connections/data-diff" + +describe("warehouseTypeToDialect", () => { + // --- Remapped types --- + + test("maps sqlserver to tsql", () => { + expect(warehouseTypeToDialect("sqlserver")).toBe("tsql") + }) + + test("maps mssql to tsql", () => { + expect(warehouseTypeToDialect("mssql")).toBe("tsql") + }) + + test("maps fabric to fabric", () => { + expect(warehouseTypeToDialect("fabric")).toBe("fabric") + }) + + test("maps postgresql to postgres", () => { + expect(warehouseTypeToDialect("postgresql")).toBe("postgres") + }) + + test("maps mariadb to mysql", () => { + expect(warehouseTypeToDialect("mariadb")).toBe("mysql") + }) + + // --- Passthrough types (already match Rust names) --- + + test("passes through postgres unchanged", () => { + expect(warehouseTypeToDialect("postgres")).toBe("postgres") + }) + + test("passes through snowflake unchanged", () => { + expect(warehouseTypeToDialect("snowflake")).toBe("snowflake") + }) + + test("passes through generic unchanged", () => { + expect(warehouseTypeToDialect("generic")).toBe("generic") + }) + + // --- Case insensitivity --- + + test("handles uppercase input", () => { + expect(warehouseTypeToDialect("SQLSERVER")).toBe("tsql") + expect(warehouseTypeToDialect("PostgreSQL")).toBe("postgres") + }) +}) diff --git a/packages/opencode/test/altimate/driver-normalize.test.ts b/packages/opencode/test/altimate/driver-normalize.test.ts index 95f348289..43b31c4e8 100644 --- a/packages/opencode/test/altimate/driver-normalize.test.ts +++ b/packages/opencode/test/altimate/driver-normalize.test.ts @@ -463,6 +463,19 @@ describe("normalizeConfig — SQL Server", () => { expect(result.host).toBe("myserver") expect(result.user).toBe("sa") }) + + test("fabric type uses SQLSERVER_ALIASES", () => { + const result = normalizeConfig({ + type: "fabric", + server: "myserver.datawarehouse.fabric.microsoft.com", + 
trustServerCertificate: false, + authentication: "default", + }) + expect(result.host).toBe("myserver.datawarehouse.fabric.microsoft.com") + expect(result.server).toBeUndefined() + expect(result.trust_server_certificate).toBe(false) + expect(result.trustServerCertificate).toBeUndefined() + }) }) // ---------------------------------------------------------------------------