diff --git a/connection.go b/connection.go index 484fae11..5b8a91c2 100644 --- a/connection.go +++ b/connection.go @@ -60,7 +60,7 @@ func (c *conn) Close() error { // Record DELETE_SESSION regardless of error (matches JDBC), then flush and release if c.telemetry != nil { - c.telemetry.RecordOperation(ctx, c.id, telemetry.OperationTypeDeleteSession, time.Since(closeStart).Milliseconds(), err) + c.telemetry.RecordOperation(ctx, c.id, "", telemetry.OperationTypeDeleteSession, time.Since(closeStart).Milliseconds(), err) _ = c.telemetry.Close(ctx) telemetry.ReleaseForConnection(c.cfg.Host) } @@ -130,15 +130,20 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name executeStart := time.Now() exStmtResp, opStatusResp, err := c.runQuery(ctx, query, args) log, ctx = client.LoggerAndContext(ctx, exStmtResp) - stagingErr := c.execStagingOperation(exStmtResp, ctx) - // Telemetry: track statement execution + // Telemetry: set up metric context BEFORE staging operation so that the + // staging op's telemetryUpdate callback can attach tags to the metric context. var statementID string var closeOpErr error // Track CloseOperation errors for telemetry if c.telemetry != nil && exStmtResp != nil && exStmtResp.OperationHandle != nil && exStmtResp.OperationHandle.OperationId != nil { statementID = client.SprintGuid(exStmtResp.OperationHandle.OperationId.GUID) - // Use BeforeExecuteWithTime to set the correct start time (before execution) ctx = c.telemetry.BeforeExecuteWithTime(ctx, c.id, statementID, executeStart) + c.telemetry.AddTag(ctx, telemetry.TagOperationType, telemetry.OperationTypeExecuteStatement) + } + + stagingErr := c.execStagingOperation(exStmtResp, ctx) + + if c.telemetry != nil && statementID != "" { defer func() { finalErr := err if stagingErr != nil { @@ -163,7 +168,7 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name OperationHandle: exStmtResp.OperationHandle, }) if c.telemetry != nil { - c.telemetry.RecordOperation(ctx, c.id, telemetry.OperationTypeCloseStatement, time.Since(closeOpStart).Milliseconds(), err1) + c.telemetry.RecordOperation(ctx, c.id, statementID, telemetry.OperationTypeCloseStatement, time.Since(closeOpStart).Milliseconds(), err1) } if err1 != nil { log.Err(err1).Msg("databricks: failed to close operation after executing statement") @@ -179,7 +184,7 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name if stagingErr != nil { log.Err(stagingErr).Msgf("databricks: failed to execute query: query %s", query) - return nil, dbsqlerrint.NewExecutionError(ctx, dbsqlerr.ErrQueryExecution, err, opStatusResp) + return nil, dbsqlerrint.NewExecutionError(ctx, dbsqlerr.ErrQueryExecution, stagingErr, opStatusResp) } res := result{AffectedRows: opStatusResp.GetNumModifiedRows()} @@ -187,6 +192,47 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name return &res, nil } +// chunkTimingAccumulator aggregates per-chunk fetch latencies for telemetry. +// It tracks the initial, slowest, and cumulative latencies, plus the number +// of CloudFetch file downloads. All fields should be accessed under the +// serialization provided by database/sql's closemu (see QueryContext). +type chunkTimingAccumulator struct { + initialMs int64 + slowestMs int64 + sumMs int64 + initialSet bool + // cloudFetchFileCount counts individual S3 files downloaded via CloudFetch. 
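+	// It is incremented in cloudFetchCallback (see QueryContext) once per downloaded
+	// result link, even when the download completes in under 1ms.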
+ // Used to set chunk_total_present correctly for both bulk and paginated CloudFetch: + // - paginated CF (1 link/FetchResults): file count == page count == correct total + // - bulk CF (all links in DirectResults): file count == actual S3 downloads + // For inline ArrowBatch results this stays 0 and chunk_total_present falls back to chunkCount. + cloudFetchFileCount int +} + +// record accumulates a single chunk or download latency. Returns true if +// the latency was positive and tags should be updated; false otherwise. +func (a *chunkTimingAccumulator) record(latencyMs int64) bool { + if latencyMs <= 0 { + return false + } + if !a.initialSet { + a.initialMs = latencyMs + a.initialSet = true + } + if latencyMs > a.slowestMs { + a.slowestMs = latencyMs + } + a.sumMs += latencyMs + return true +} + +// applyTags writes the current timing state to the telemetry context. +func (a *chunkTimingAccumulator) applyTags(ctx context.Context, interceptor *telemetry.Interceptor) { + interceptor.AddTag(ctx, telemetry.TagChunkInitialLatencyMs, a.initialMs) + interceptor.AddTag(ctx, telemetry.TagChunkSlowestLatencyMs, a.slowestMs) + interceptor.AddTag(ctx, telemetry.TagChunkSumLatencyMs, a.sumMs) +} + // QueryContext executes a query that may return rows, such as a // SELECT. // @@ -206,32 +252,116 @@ func (c *conn) QueryContext(ctx context.Context, query string, args []driver.Nam log, ctx = client.LoggerAndContext(ctx, exStmtResp) defer log.Duration(msg, start) - // Telemetry: track statement execution + // Telemetry: set up metric context for the statement. + // BeforeExecuteWithTime anchors startTime to before runQuery() ran. var statementID string if c.telemetry != nil && exStmtResp != nil && exStmtResp.OperationHandle != nil && exStmtResp.OperationHandle.OperationId != nil { statementID = client.SprintGuid(exStmtResp.OperationHandle.OperationId.GUID) - // Use BeforeExecuteWithTime to set the correct start time (before execution) ctx = c.telemetry.BeforeExecuteWithTime(ctx, c.id, statementID, executeStart) - defer func() { - c.telemetry.AfterExecute(ctx, err) - c.telemetry.CompleteStatement(ctx, statementID, err != nil) - }() + c.telemetry.AddTag(ctx, telemetry.TagOperationType, telemetry.OperationTypeExecuteStatement) } if err != nil { + // Error path: finalize and emit the EXECUTE_STATEMENT metric immediately — + // there are no rows to iterate so the metric is complete right now. + if c.telemetry != nil && statementID != "" { + c.telemetry.AfterExecute(ctx, err) + c.telemetry.CompleteStatement(ctx, statementID, true) + } log.Err(err).Msg("databricks: failed to run query") // To log query we need to redact credentials return nil, dbsqlerrint.NewExecutionError(ctx, dbsqlerr.ErrQueryExecution, err, opStatusResp) } - // Telemetry callback for tracking row fetching metrics - telemetryUpdate := func(chunkCount int, bytesDownloaded int64) { - if c.telemetry != nil { - c.telemetry.AddTag(ctx, "chunk_count", chunkCount) - c.telemetry.AddTag(ctx, "bytes_downloaded", bytesDownloaded) + // Success path: freeze execute latency NOW (before row iteration inflates time.Since). + // AfterExecute/CompleteStatement are called from closeCallback after all chunks + // are fetched, so the final metric carries complete chunk timing data. + if c.telemetry != nil && statementID != "" { + c.telemetry.FinalizeLatency(ctx) + } + + // chunkTimingAccumulator aggregates per-chunk fetch latencies across all + // fetchResultPage calls. 
These fields are safe without a mutex because they + // are only mutated from callbacks serialized by database/sql's closemu lock: + // telemetryUpdate and cloudFetchCallback run inside rows.Next() (which + // holds closemu.RLock), and closeCallback runs inside rows.Close() (which + // holds closemu.Lock). This ensures mutual exclusion even when Close() is + // called from database/sql's awaitDone goroutine on context cancellation. + var timing chunkTimingAccumulator + + // Detach from caller's context so that telemetry tag writes and flushes + // survive context cancellation (e.g. query timeout, database/sql awaitDone). + // All three callbacks (telemetryUpdate, cloudFetchCallback, closeCallback) + // use this detached context uniformly. + telemetryCtx := context2.WithoutCancel(ctx) + + // Telemetry callback invoked after each result page is fetched. + telemetryUpdate := func(chunkCount int, bytesDownloaded int64, chunkIndex int, chunkLatencyMs int64, _ int32) { + if c.telemetry == nil { + return + } + c.telemetry.AddTag(telemetryCtx, telemetry.TagChunkCount, chunkCount) + c.telemetry.AddTag(telemetryCtx, telemetry.TagBytesDownloaded, bytesDownloaded) + + // Aggregate per-chunk fetch latencies (skip direct results where latency is 0). + if timing.record(chunkLatencyMs) { + timing.applyTags(telemetryCtx, c.telemetry) + } + // chunk_total_present is set definitively in closeCallback once all pages are known. + } + + // cloudFetchCallback is invoked per S3 file download for CloudFetch result sets. + // It aggregates individual file download times into the same initial/slowest/sum vars + // used for inline chunk timing, matching JDBC's per-chunk HTTP GET timing model. + // For inline (non-CloudFetch) result sets this is never called. + var cloudFetchCallback func(downloadMs int64) + if c.telemetry != nil { + cloudFetchCallback = func(downloadMs int64) { + timing.cloudFetchFileCount++ // always count files for chunk_total_present, even sub-ms downloads + if timing.record(downloadMs) { + timing.applyTags(telemetryCtx, c.telemetry) + } + } + } + + // closeCallback is invoked from rows.Close() after all rows have been consumed. + // At that point chunk timing is fully accumulated in telemetryCtx tags, so we + // finalize EXECUTE_STATEMENT here rather than at QueryContext return time. + var closeCallback func(latencyMs int64, chunkCount int, iterErr error, closeErr error) + if c.telemetry != nil && statementID != "" { + interceptor := c.telemetry + connID := c.id + stmtID := statementID + closeCallback = func(latencyMs int64, chunkCount int, iterErr error, closeErr error) { + // Set chunk_total_present to the definitive total now that all iteration is done. + // For CloudFetch, use cloudFetchFileCount (actual S3 downloads) — this handles + // both paginated CF (1 link/page, so file count == page count) and bulk CF + // (all links in DirectResults, so file count == total S3 files). + // For inline ArrowBatch, cloudFetchFileCount is 0; fall back to chunkCount. + if timing.cloudFetchFileCount > 0 { + interceptor.AddTag(telemetryCtx, telemetry.TagChunkTotalPresent, timing.cloudFetchFileCount) + } else if chunkCount > 0 { + interceptor.AddTag(telemetryCtx, telemetry.TagChunkTotalPresent, chunkCount) + } + // EXECUTE_STATEMENT uses the iteration error (row consumption failure) + // to correctly report whether the statement succeeded or failed. 
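+			// Worked example (hypothetical timings; tag keys illustrative): a statement whose
+			// CloudFetch files downloaded in 12ms, 30ms, and 18ms finalizes here with
+			// chunk_initial_latency_ms=12, chunk_slowest_latency_ms=30, chunk_sum_latency_ms=60,
+			// and chunk_total_present=3, with latency already frozen before row iteration.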
+ interceptor.AfterExecute(telemetryCtx, iterErr) + interceptor.CompleteStatement(telemetryCtx, stmtID, iterErr != nil) + // CLOSE_STATEMENT uses the actual CloseOperation RPC error. + interceptor.RecordOperation(telemetryCtx, connID, stmtID, telemetry.OperationTypeCloseStatement, latencyMs, closeErr) + } + } else if c.telemetry != nil { + interceptor := c.telemetry + connID := c.id + closeCallback = func(latencyMs int64, _ int, _ error, closeErr error) { + interceptor.RecordOperation(telemetryCtx, connID, "", telemetry.OperationTypeCloseStatement, latencyMs, closeErr) } } - rows, err := rows.NewRows(ctx, exStmtResp.OperationHandle, c.client, c.cfg, exStmtResp.DirectResults, telemetryUpdate) + rows, err := rows.NewRows(ctx, exStmtResp.OperationHandle, c.client, c.cfg, exStmtResp.DirectResults, &rows.TelemetryCallbacks{ + OnChunkFetched: telemetryUpdate, + OnClose: closeCallback, + OnCloudFetchFile: cloudFetchCallback, + }) return rows, err } @@ -396,14 +526,7 @@ func (c *conn) executeStatement(ctx context.Context, query string, args []driver } } - executeStart := time.Now() resp, err := c.client.ExecuteStatement(ctx, &req) - // Record the Thrift call latency as a separate operation metric. - // This is distinct from the statement-level metric (BeforeExecuteWithTime), which - // measures end-to-end latency including polling and row fetching. - if c.telemetry != nil { - c.telemetry.RecordOperation(ctx, c.id, telemetry.OperationTypeExecuteStatement, time.Since(executeStart).Milliseconds(), err) - } var log *logger.DBSQLLogger log, ctx = client.LoggerAndContext(ctx, resp) @@ -668,14 +791,16 @@ func (c *conn) execStagingOperation( } if len(driverctx.StagingPathsFromContext(ctx)) != 0 { - // Telemetry callback for staging operation row fetching - telemetryUpdate := func(chunkCount int, bytesDownloaded int64) { + // Telemetry callback for staging operation row fetching (chunk timing not tracked for staging ops). 
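+		// The chunkIndex, chunkLatencyMs, and totalChunksPresent parameters exist only to
+		// satisfy the widened OnChunkFetched signature; staging results record just
+		// chunk_count and bytes_downloaded.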
+ telemetryUpdate := func(chunkCount int, bytesDownloaded int64, chunkIndex int, chunkLatencyMs int64, totalChunksPresent int32) { if c.telemetry != nil { - c.telemetry.AddTag(ctx, "chunk_count", chunkCount) - c.telemetry.AddTag(ctx, "bytes_downloaded", bytesDownloaded) + c.telemetry.AddTag(ctx, telemetry.TagChunkCount, chunkCount) + c.telemetry.AddTag(ctx, telemetry.TagBytesDownloaded, bytesDownloaded) } } - row, err = rows.NewRows(ctx, exStmtResp.OperationHandle, c.client, c.cfg, exStmtResp.DirectResults, telemetryUpdate) + row, err = rows.NewRows(ctx, exStmtResp.OperationHandle, c.client, c.cfg, exStmtResp.DirectResults, &rows.TelemetryCallbacks{ + OnChunkFetched: telemetryUpdate, + }) if err != nil { return dbsqlerrint.NewDriverError(ctx, "error reading row.", err) } diff --git a/connection_test.go b/connection_test.go index c4cb9f15..badf6a7b 100644 --- a/connection_test.go +++ b/connection_test.go @@ -1975,6 +1975,59 @@ func TestConn_execStagingOperation(t *testing.T) { }) } +func TestChunkTimingAccumulator_Record(t *testing.T) { + tests := []struct { + name string + latencies []int64 + wantInit int64 + wantSlow int64 + wantSum int64 + wantReturn []bool + }{ + {"zero latency skipped", []int64{0}, 0, 0, 0, []bool{false}}, + {"negative skipped", []int64{-5}, 0, 0, 0, []bool{false}}, + {"single positive", []int64{10}, 10, 10, 10, []bool{true}}, + {"initial preserved across calls", []int64{10, 20}, 10, 20, 30, []bool{true, true}}, + {"slowest tracks max not last", []int64{30, 10, 50}, 30, 50, 90, []bool{true, true, true}}, + {"zero interleaved skipped", []int64{10, 0, 20}, 10, 20, 30, []bool{true, false, true}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var a chunkTimingAccumulator + for i, lat := range tt.latencies { + got := a.record(lat) + if got != tt.wantReturn[i] { + t.Errorf("record(%d) = %v, want %v", lat, got, tt.wantReturn[i]) + } + } + if a.initialMs != tt.wantInit { + t.Errorf("initialMs = %d, want %d", a.initialMs, tt.wantInit) + } + if a.slowestMs != tt.wantSlow { + t.Errorf("slowestMs = %d, want %d", a.slowestMs, tt.wantSlow) + } + if a.sumMs != tt.wantSum { + t.Errorf("sumMs = %d, want %d", a.sumMs, tt.wantSum) + } + }) + } +} + +func TestChunkTimingAccumulator_CloudFetchFileCount(t *testing.T) { + var a chunkTimingAccumulator + a.cloudFetchFileCount++ + a.record(0) // sub-ms download — still counted but not timed + a.cloudFetchFileCount++ + a.record(5) + + if a.cloudFetchFileCount != 2 { + t.Errorf("cloudFetchFileCount = %d, want 2", a.cloudFetchFileCount) + } + if a.initialMs != 5 { + t.Errorf("initialMs = %d, want 5 (zero-latency file should not set initial)", a.initialMs) + } +} + func getTestSession() *cli_service.TOpenSessionResp { return &cli_service.TOpenSessionResp{SessionHandle: &cli_service.TSessionHandle{ SessionId: &cli_service.THandleIdentifier{ diff --git a/connector.go b/connector.go index 3e3ad330..9b1e0872 100644 --- a/connector.go +++ b/connector.go @@ -81,18 +81,19 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { log := logger.WithContext(conn.id, driverctx.CorrelationIdFromContext(ctx), "") // Initialize telemetry: client config overlay decides; if unset, feature flags decide - conn.telemetry = telemetry.InitializeForConnection( - ctx, - c.cfg.Host, - c.cfg.DriverVersion, - c.client, - c.cfg.EnableTelemetry, - c.cfg.TelemetryBatchSize, - c.cfg.TelemetryFlushInterval, - ) + conn.telemetry = telemetry.InitializeForConnection(ctx, telemetry.TelemetryInitOptions{ + Host: c.cfg.Host, + DriverVersion: 
c.cfg.DriverVersion, + HTTPClient: c.client, + EnableTelemetry: c.cfg.EnableTelemetry, + BatchSize: c.cfg.TelemetryBatchSize, + FlushInterval: c.cfg.TelemetryFlushInterval, + RetryCount: c.cfg.TelemetryRetryCount, + RetryDelay: c.cfg.TelemetryRetryDelay, + }) if conn.telemetry != nil { log.Debug().Msg("telemetry initialized for connection") - conn.telemetry.RecordOperation(ctx, conn.id, telemetry.OperationTypeCreateSession, sessionLatencyMs, nil) + conn.telemetry.RecordOperation(ctx, conn.id, "", telemetry.OperationTypeCreateSession, sessionLatencyMs, nil) } log.Info().Msgf("connect: host=%s port=%d httpPath=%s serverProtocolVersion=0x%X", c.cfg.Host, c.cfg.Port, c.cfg.HTTPPath, session.ServerProtocolVersion) diff --git a/connector_test.go b/connector_test.go index 66f351c9..d613027f 100644 --- a/connector_test.go +++ b/connector_test.go @@ -51,23 +51,24 @@ func TestNewConnector(t *testing.T) { HTTPClient: &http.Client{Transport: roundTripper}, } expectedUserConfig := config.UserConfig{ - Host: host, - Port: port, - Protocol: "https", - AccessToken: accessToken, - Authenticator: &pat.PATAuth{AccessToken: accessToken}, - HTTPPath: "/" + httpPath, - MaxRows: maxRows, - QueryTimeout: timeout, - Catalog: catalog, - Schema: schema, - UserAgentEntry: userAgentEntry, - SessionParams: sessionParams, - RetryMax: 10, - RetryWaitMin: 3 * time.Second, - RetryWaitMax: 60 * time.Second, - Transport: roundTripper, - CloudFetchConfig: expectedCloudFetchConfig, + Host: host, + Port: port, + Protocol: "https", + AccessToken: accessToken, + Authenticator: &pat.PATAuth{AccessToken: accessToken}, + HTTPPath: "/" + httpPath, + MaxRows: maxRows, + QueryTimeout: timeout, + Catalog: catalog, + Schema: schema, + UserAgentEntry: userAgentEntry, + SessionParams: sessionParams, + RetryMax: 10, + RetryWaitMin: 3 * time.Second, + RetryWaitMax: 60 * time.Second, + Transport: roundTripper, + TelemetryRetryCount: -1, + CloudFetchConfig: expectedCloudFetchConfig, } expectedCfg := config.WithDefaults() expectedCfg.DriverVersion = DriverVersion @@ -98,18 +99,19 @@ func TestNewConnector(t *testing.T) { CloudFetchSpeedThresholdMbps: 0.1, } expectedUserConfig := config.UserConfig{ - Host: host, - Port: port, - Protocol: "https", - AccessToken: accessToken, - Authenticator: &pat.PATAuth{AccessToken: accessToken}, - HTTPPath: "/" + httpPath, - MaxRows: maxRows, - SessionParams: sessionParams, - RetryMax: 4, - RetryWaitMin: 1 * time.Second, - RetryWaitMax: 30 * time.Second, - CloudFetchConfig: expectedCloudFetchConfig, + Host: host, + Port: port, + Protocol: "https", + AccessToken: accessToken, + Authenticator: &pat.PATAuth{AccessToken: accessToken}, + HTTPPath: "/" + httpPath, + MaxRows: maxRows, + SessionParams: sessionParams, + RetryMax: 4, + RetryWaitMin: 1 * time.Second, + RetryWaitMax: 30 * time.Second, + TelemetryRetryCount: -1, + CloudFetchConfig: expectedCloudFetchConfig, } expectedCfg := config.WithDefaults() expectedCfg.UserConfig = expectedUserConfig @@ -140,18 +142,19 @@ func TestNewConnector(t *testing.T) { CloudFetchSpeedThresholdMbps: 0.1, } expectedUserConfig := config.UserConfig{ - Host: host, - Port: port, - Protocol: "https", - AccessToken: accessToken, - Authenticator: &pat.PATAuth{AccessToken: accessToken}, - HTTPPath: "/" + httpPath, - MaxRows: maxRows, - SessionParams: sessionParams, - RetryMax: -1, - RetryWaitMin: 0, - RetryWaitMax: 0, - CloudFetchConfig: expectedCloudFetchConfig, + Host: host, + Port: port, + Protocol: "https", + AccessToken: accessToken, + Authenticator: &pat.PATAuth{AccessToken: 
accessToken}, + HTTPPath: "/" + httpPath, + MaxRows: maxRows, + SessionParams: sessionParams, + RetryMax: -1, + RetryWaitMin: 0, + RetryWaitMax: 0, + TelemetryRetryCount: -1, + CloudFetchConfig: expectedCloudFetchConfig, } expectedCfg := config.WithDefaults() expectedCfg.DriverVersion = DriverVersion diff --git a/internal/config/config.go b/internal/config/config.go index e5446ac8..b8be59cb 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -104,6 +104,8 @@ type UserConfig struct { EnableTelemetry ConfigValue[bool] TelemetryBatchSize int // 0 = use default (100) TelemetryFlushInterval time.Duration // 0 = use default (5s) + TelemetryRetryCount int // -1 = use default (3); 0 = disable retries; set via telemetry_retry_count + TelemetryRetryDelay time.Duration // 0 = use default (100ms); set via telemetry_retry_delay Transport http.RoundTripper UseLz4Compression bool EnableMetricViewMetadata bool @@ -153,6 +155,8 @@ func (ucfg UserConfig) DeepCopy() UserConfig { EnableTelemetry: ucfg.EnableTelemetry, TelemetryBatchSize: ucfg.TelemetryBatchSize, TelemetryFlushInterval: ucfg.TelemetryFlushInterval, + TelemetryRetryCount: ucfg.TelemetryRetryCount, + TelemetryRetryDelay: ucfg.TelemetryRetryDelay, } } @@ -191,6 +195,10 @@ func (ucfg UserConfig) WithDefaults() UserConfig { // EnableTelemetry defaults to unset (ConfigValue zero value), // meaning telemetry is controlled by server feature flags. + // TelemetryRetryCount uses -1 as "not set" so that an explicit 0 from the + // DSN (meaning "disable retries") is distinguishable from the default. + ucfg.TelemetryRetryCount = -1 + return ucfg } @@ -314,6 +322,19 @@ func ParseDSN(dsn string) (UserConfig, error) { ucfg.TelemetryFlushInterval = d } } + if retryCount, ok, err := params.extractAsInt("telemetry_retry_count"); ok { + if err != nil { + return UserConfig{}, err + } + if retryCount >= 0 { + ucfg.TelemetryRetryCount = retryCount + } + } + if retryDelay, ok := params.extract("telemetry_retry_delay"); ok { + if d, err := time.ParseDuration(retryDelay); err == nil && d > 0 { + ucfg.TelemetryRetryDelay = d + } + } // for timezone we do a case insensitive key match. 
// We use getNoCase because we want to leave timezone in the params so that it will also diff --git a/internal/config/config_test.go b/internal/config/config_test.go index abea52b0..b5e52e4c 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -33,18 +33,19 @@ func TestParseConfig(t *testing.T) { name: "base case", args: args{dsn: "token:supersecret@example.cloud.databricks.com:443/sql/1.0/endpoints/12346a5b5b0e123a"}, wantCfg: UserConfig{ - Protocol: "https", - Host: "example.cloud.databricks.com", - Port: 443, - MaxRows: defaultMaxRows, - Authenticator: &pat.PATAuth{AccessToken: "supersecret"}, - AccessToken: "supersecret", - HTTPPath: "/sql/1.0/endpoints/12346a5b5b0e123a", - SessionParams: make(map[string]string), - RetryMax: 4, - RetryWaitMin: 1 * time.Second, - RetryWaitMax: 30 * time.Second, - CloudFetchConfig: defCloudConfig, + Protocol: "https", + Host: "example.cloud.databricks.com", + Port: 443, + MaxRows: defaultMaxRows, + Authenticator: &pat.PATAuth{AccessToken: "supersecret"}, + AccessToken: "supersecret", + HTTPPath: "/sql/1.0/endpoints/12346a5b5b0e123a", + SessionParams: make(map[string]string), + RetryMax: 4, + RetryWaitMin: 1 * time.Second, + RetryWaitMax: 30 * time.Second, + TelemetryRetryCount: -1, + CloudFetchConfig: defCloudConfig, }, wantURL: "https://example.cloud.databricks.com:443/sql/1.0/endpoints/12346a5b5b0e123a", wantErr: false, @@ -53,18 +54,19 @@ func TestParseConfig(t *testing.T) { name: "with https scheme", args: args{dsn: "https://token:supersecret@example.cloud.databricks.com:443/sql/1.0/endpoints/12346a5b5b0e123a"}, //nolint:gosec // G101: test DSN with example password, not a real credential wantCfg: UserConfig{ - Protocol: "https", - Host: "example.cloud.databricks.com", - Port: 443, - MaxRows: defaultMaxRows, - AccessToken: "supersecret", - Authenticator: &pat.PATAuth{AccessToken: "supersecret"}, - HTTPPath: "/sql/1.0/endpoints/12346a5b5b0e123a", - SessionParams: make(map[string]string), - RetryMax: 4, - RetryWaitMin: 1 * time.Second, - RetryWaitMax: 30 * time.Second, - CloudFetchConfig: defCloudConfig, + Protocol: "https", + Host: "example.cloud.databricks.com", + Port: 443, + MaxRows: defaultMaxRows, + AccessToken: "supersecret", + Authenticator: &pat.PATAuth{AccessToken: "supersecret"}, + HTTPPath: "/sql/1.0/endpoints/12346a5b5b0e123a", + SessionParams: make(map[string]string), + RetryMax: 4, + RetryWaitMin: 1 * time.Second, + RetryWaitMax: 30 * time.Second, + TelemetryRetryCount: -1, + CloudFetchConfig: defCloudConfig, }, wantURL: "https://example.cloud.databricks.com:443/sql/1.0/endpoints/12346a5b5b0e123a", wantErr: false, @@ -73,17 +75,18 @@ func TestParseConfig(t *testing.T) { name: "with http scheme", args: args{dsn: "http://localhost:8080/sql/1.0/endpoints/12346a5b5b0e123a"}, wantCfg: UserConfig{ - Protocol: "http", - Host: "localhost", - Port: 8080, - MaxRows: defaultMaxRows, - Authenticator: &noop.NoopAuth{}, - HTTPPath: "/sql/1.0/endpoints/12346a5b5b0e123a", - SessionParams: make(map[string]string), - RetryMax: 4, - RetryWaitMin: 1 * time.Second, - RetryWaitMax: 30 * time.Second, - CloudFetchConfig: defCloudConfig, + Protocol: "http", + Host: "localhost", + Port: 8080, + MaxRows: defaultMaxRows, + Authenticator: &noop.NoopAuth{}, + HTTPPath: "/sql/1.0/endpoints/12346a5b5b0e123a", + SessionParams: make(map[string]string), + RetryMax: 4, + RetryWaitMin: 1 * time.Second, + RetryWaitMax: 30 * time.Second, + TelemetryRetryCount: -1, + CloudFetchConfig: defCloudConfig, }, wantErr: false, wantURL: 
"http://localhost:8080/sql/1.0/endpoints/12346a5b5b0e123a", @@ -92,16 +95,17 @@ func TestParseConfig(t *testing.T) { name: "with localhost", args: args{dsn: "http://localhost:8080"}, wantCfg: UserConfig{ - Protocol: "http", - Host: "localhost", - Port: 8080, - Authenticator: &noop.NoopAuth{}, - MaxRows: defaultMaxRows, - SessionParams: make(map[string]string), - RetryMax: 4, - RetryWaitMin: 1 * time.Second, - RetryWaitMax: 30 * time.Second, - CloudFetchConfig: defCloudConfig, + Protocol: "http", + Host: "localhost", + Port: 8080, + Authenticator: &noop.NoopAuth{}, + MaxRows: defaultMaxRows, + SessionParams: make(map[string]string), + RetryMax: 4, + RetryWaitMin: 1 * time.Second, + RetryWaitMax: 30 * time.Second, + TelemetryRetryCount: -1, + CloudFetchConfig: defCloudConfig, }, wantErr: false, wantURL: "http://localhost:8080", @@ -110,19 +114,20 @@ func TestParseConfig(t *testing.T) { name: "with query params", args: args{dsn: "token:supersecret@example.cloud.databricks.com:8000/sql/1.0/endpoints/12346a5b5b0e123a?timeout=100&maxRows=1000"}, wantCfg: UserConfig{ - Protocol: "https", - Host: "example.cloud.databricks.com", - Port: 8000, - AccessToken: "supersecret", - Authenticator: &pat.PATAuth{AccessToken: "supersecret"}, - HTTPPath: "/sql/1.0/endpoints/12346a5b5b0e123a", - QueryTimeout: 100 * time.Second, - MaxRows: 1000, - SessionParams: make(map[string]string), - RetryMax: 4, - RetryWaitMin: 1 * time.Second, - RetryWaitMax: 30 * time.Second, - CloudFetchConfig: defCloudConfig, + Protocol: "https", + Host: "example.cloud.databricks.com", + Port: 8000, + AccessToken: "supersecret", + Authenticator: &pat.PATAuth{AccessToken: "supersecret"}, + HTTPPath: "/sql/1.0/endpoints/12346a5b5b0e123a", + QueryTimeout: 100 * time.Second, + MaxRows: 1000, + SessionParams: make(map[string]string), + RetryMax: 4, + RetryWaitMin: 1 * time.Second, + RetryWaitMax: 30 * time.Second, + TelemetryRetryCount: -1, + CloudFetchConfig: defCloudConfig, }, wantURL: "https://example.cloud.databricks.com:8000/sql/1.0/endpoints/12346a5b5b0e123a", wantErr: false, @@ -131,20 +136,21 @@ func TestParseConfig(t *testing.T) { name: "with query params and session params", args: args{dsn: "token:supersecret@example.cloud.databricks.com:8000/sql/1.0/endpoints/12346a5b5b0e123a?timeout=100&maxRows=1000&timezone=America/Vancouver&QUERY_TAGS=team:testing,driver:go"}, wantCfg: UserConfig{ - Protocol: "https", - Host: "example.cloud.databricks.com", - Port: 8000, - AccessToken: "supersecret", - Authenticator: &pat.PATAuth{AccessToken: "supersecret"}, - HTTPPath: "/sql/1.0/endpoints/12346a5b5b0e123a", - QueryTimeout: 100 * time.Second, - MaxRows: 1000, - Location: tz, - SessionParams: map[string]string{"timezone": "America/Vancouver", "QUERY_TAGS": "team:testing,driver:go"}, - RetryMax: 4, - RetryWaitMin: 1 * time.Second, - RetryWaitMax: 30 * time.Second, - CloudFetchConfig: defCloudConfig, + Protocol: "https", + Host: "example.cloud.databricks.com", + Port: 8000, + AccessToken: "supersecret", + Authenticator: &pat.PATAuth{AccessToken: "supersecret"}, + HTTPPath: "/sql/1.0/endpoints/12346a5b5b0e123a", + QueryTimeout: 100 * time.Second, + MaxRows: 1000, + Location: tz, + SessionParams: map[string]string{"timezone": "America/Vancouver", "QUERY_TAGS": "team:testing,driver:go"}, + RetryMax: 4, + RetryWaitMin: 1 * time.Second, + RetryWaitMax: 30 * time.Second, + TelemetryRetryCount: -1, + CloudFetchConfig: defCloudConfig, }, wantURL: "https://example.cloud.databricks.com:8000/sql/1.0/endpoints/12346a5b5b0e123a", wantErr: false, @@ -153,17 
+159,18 @@ func TestParseConfig(t *testing.T) { name: "bare", args: args{dsn: "example.cloud.databricks.com:8000/sql/1.0/endpoints/12346a5b5b0e123a"}, wantCfg: UserConfig{ - Protocol: "https", - Host: "example.cloud.databricks.com", - Authenticator: &noop.NoopAuth{}, - Port: 8000, - MaxRows: defaultMaxRows, - HTTPPath: "/sql/1.0/endpoints/12346a5b5b0e123a", - SessionParams: make(map[string]string), - RetryMax: 4, - RetryWaitMin: 1 * time.Second, - RetryWaitMax: 30 * time.Second, - CloudFetchConfig: defCloudConfig, + Protocol: "https", + Host: "example.cloud.databricks.com", + Authenticator: &noop.NoopAuth{}, + Port: 8000, + MaxRows: defaultMaxRows, + HTTPPath: "/sql/1.0/endpoints/12346a5b5b0e123a", + SessionParams: make(map[string]string), + RetryMax: 4, + RetryWaitMin: 1 * time.Second, + RetryWaitMax: 30 * time.Second, + TelemetryRetryCount: -1, + CloudFetchConfig: defCloudConfig, }, wantURL: "https://example.cloud.databricks.com:8000/sql/1.0/endpoints/12346a5b5b0e123a", wantErr: false, @@ -172,19 +179,20 @@ func TestParseConfig(t *testing.T) { name: "with catalog", args: args{dsn: "token:supersecret@example.cloud.databricks.com:8000/sql/1.0/endpoints/12346a5b5b0e123b?catalog=default"}, wantCfg: UserConfig{ - Protocol: "https", - Host: "example.cloud.databricks.com", - Port: 8000, - MaxRows: defaultMaxRows, - AccessToken: "supersecret", - Authenticator: &pat.PATAuth{AccessToken: "supersecret"}, - HTTPPath: "/sql/1.0/endpoints/12346a5b5b0e123b", - Catalog: "default", - SessionParams: make(map[string]string), - RetryMax: 4, - RetryWaitMin: 1 * time.Second, - RetryWaitMax: 30 * time.Second, - CloudFetchConfig: defCloudConfig, + Protocol: "https", + Host: "example.cloud.databricks.com", + Port: 8000, + MaxRows: defaultMaxRows, + AccessToken: "supersecret", + Authenticator: &pat.PATAuth{AccessToken: "supersecret"}, + HTTPPath: "/sql/1.0/endpoints/12346a5b5b0e123b", + Catalog: "default", + SessionParams: make(map[string]string), + RetryMax: 4, + RetryWaitMin: 1 * time.Second, + RetryWaitMax: 30 * time.Second, + TelemetryRetryCount: -1, + CloudFetchConfig: defCloudConfig, }, wantURL: "https://example.cloud.databricks.com:8000/sql/1.0/endpoints/12346a5b5b0e123b", wantErr: false, @@ -193,19 +201,20 @@ func TestParseConfig(t *testing.T) { name: "with user agent entry", args: args{dsn: "token:supersecret@example.cloud.databricks.com:8000/sql/1.0/endpoints/12346a5b5b0e123b?userAgentEntry=partner-name"}, wantCfg: UserConfig{ - Protocol: "https", - Host: "example.cloud.databricks.com", - Port: 8000, - MaxRows: defaultMaxRows, - AccessToken: "supersecret", - Authenticator: &pat.PATAuth{AccessToken: "supersecret"}, - HTTPPath: "/sql/1.0/endpoints/12346a5b5b0e123b", - UserAgentEntry: "partner-name", - SessionParams: make(map[string]string), - RetryMax: 4, - RetryWaitMin: 1 * time.Second, - RetryWaitMax: 30 * time.Second, - CloudFetchConfig: defCloudConfig, + Protocol: "https", + Host: "example.cloud.databricks.com", + Port: 8000, + MaxRows: defaultMaxRows, + AccessToken: "supersecret", + Authenticator: &pat.PATAuth{AccessToken: "supersecret"}, + HTTPPath: "/sql/1.0/endpoints/12346a5b5b0e123b", + UserAgentEntry: "partner-name", + SessionParams: make(map[string]string), + RetryMax: 4, + RetryWaitMin: 1 * time.Second, + RetryWaitMax: 30 * time.Second, + TelemetryRetryCount: -1, + CloudFetchConfig: defCloudConfig, }, wantURL: "https://example.cloud.databricks.com:8000/sql/1.0/endpoints/12346a5b5b0e123b", wantErr: false, @@ -214,19 +223,20 @@ func TestParseConfig(t *testing.T) { name: "with schema", args: 
args{dsn: "token:supersecret2@example.cloud.databricks.com:8000/sql/1.0/endpoints/12346a5b5b0e123a?schema=system"}, wantCfg: UserConfig{ - Protocol: "https", - Host: "example.cloud.databricks.com", - Port: 8000, - MaxRows: defaultMaxRows, - AccessToken: "supersecret2", - Authenticator: &pat.PATAuth{AccessToken: "supersecret2"}, - HTTPPath: "/sql/1.0/endpoints/12346a5b5b0e123a", - Schema: "system", - SessionParams: make(map[string]string), - RetryMax: 4, - RetryWaitMin: 1 * time.Second, - RetryWaitMax: 30 * time.Second, - CloudFetchConfig: defCloudConfig, + Protocol: "https", + Host: "example.cloud.databricks.com", + Port: 8000, + MaxRows: defaultMaxRows, + AccessToken: "supersecret2", + Authenticator: &pat.PATAuth{AccessToken: "supersecret2"}, + HTTPPath: "/sql/1.0/endpoints/12346a5b5b0e123a", + Schema: "system", + SessionParams: make(map[string]string), + RetryMax: 4, + RetryWaitMin: 1 * time.Second, + RetryWaitMax: 30 * time.Second, + TelemetryRetryCount: -1, + CloudFetchConfig: defCloudConfig, }, wantURL: "https://example.cloud.databricks.com:8000/sql/1.0/endpoints/12346a5b5b0e123a", wantErr: false, @@ -235,17 +245,18 @@ func TestParseConfig(t *testing.T) { name: "with useCloudFetch", args: args{dsn: "token:supersecret@example.cloud.databricks.com:8000/sql/1.0/endpoints/12346a5b5b0e123b?useCloudFetch=true"}, wantCfg: UserConfig{ - Protocol: "https", - Host: "example.cloud.databricks.com", - Port: 8000, - MaxRows: defaultMaxRows, - AccessToken: "supersecret", - Authenticator: &pat.PATAuth{AccessToken: "supersecret"}, - HTTPPath: "/sql/1.0/endpoints/12346a5b5b0e123b", - SessionParams: make(map[string]string), - RetryMax: 4, - RetryWaitMin: 1 * time.Second, - RetryWaitMax: 30 * time.Second, + Protocol: "https", + Host: "example.cloud.databricks.com", + Port: 8000, + MaxRows: defaultMaxRows, + AccessToken: "supersecret", + Authenticator: &pat.PATAuth{AccessToken: "supersecret"}, + HTTPPath: "/sql/1.0/endpoints/12346a5b5b0e123b", + SessionParams: make(map[string]string), + RetryMax: 4, + RetryWaitMin: 1 * time.Second, + RetryWaitMax: 30 * time.Second, + TelemetryRetryCount: -1, CloudFetchConfig: CloudFetchConfig{ UseCloudFetch: true, MaxDownloadThreads: 10, @@ -260,17 +271,18 @@ func TestParseConfig(t *testing.T) { name: "with useCloudFetch and maxDownloadThreads", args: args{dsn: "token:supersecret@example.cloud.databricks.com:8000/sql/1.0/endpoints/12346a5b5b0e123b?useCloudFetch=true&maxDownloadThreads=15"}, wantCfg: UserConfig{ - Protocol: "https", - Host: "example.cloud.databricks.com", - Port: 8000, - MaxRows: defaultMaxRows, - AccessToken: "supersecret", - Authenticator: &pat.PATAuth{AccessToken: "supersecret"}, - HTTPPath: "/sql/1.0/endpoints/12346a5b5b0e123b", - SessionParams: make(map[string]string), - RetryMax: 4, - RetryWaitMin: 1 * time.Second, - RetryWaitMax: 30 * time.Second, + Protocol: "https", + Host: "example.cloud.databricks.com", + Port: 8000, + MaxRows: defaultMaxRows, + AccessToken: "supersecret", + Authenticator: &pat.PATAuth{AccessToken: "supersecret"}, + HTTPPath: "/sql/1.0/endpoints/12346a5b5b0e123b", + SessionParams: make(map[string]string), + RetryMax: 4, + RetryWaitMin: 1 * time.Second, + RetryWaitMax: 30 * time.Second, + TelemetryRetryCount: -1, CloudFetchConfig: CloudFetchConfig{ UseCloudFetch: true, MaxDownloadThreads: 15, @@ -285,21 +297,22 @@ func TestParseConfig(t *testing.T) { name: "with everything", args: args{dsn: 
"token:supersecret2@example.cloud.databricks.com:8000/sql/1.0/endpoints/12346a5b5b0e123a?catalog=default&schema=system&userAgentEntry=partner-name&timeout=100&maxRows=1000&ANSI_MODE=true&useCloudFetch=true&maxDownloadThreads=15"}, wantCfg: UserConfig{ - Protocol: "https", - Host: "example.cloud.databricks.com", - Port: 8000, - AccessToken: "supersecret2", - Authenticator: &pat.PATAuth{AccessToken: "supersecret2"}, - HTTPPath: "/sql/1.0/endpoints/12346a5b5b0e123a", - QueryTimeout: 100 * time.Second, - MaxRows: 1000, - UserAgentEntry: "partner-name", - Catalog: "default", - Schema: "system", - SessionParams: map[string]string{"ANSI_MODE": "true"}, - RetryMax: 4, - RetryWaitMin: 1 * time.Second, - RetryWaitMax: 30 * time.Second, + Protocol: "https", + Host: "example.cloud.databricks.com", + Port: 8000, + AccessToken: "supersecret2", + Authenticator: &pat.PATAuth{AccessToken: "supersecret2"}, + HTTPPath: "/sql/1.0/endpoints/12346a5b5b0e123a", + QueryTimeout: 100 * time.Second, + MaxRows: 1000, + UserAgentEntry: "partner-name", + Catalog: "default", + Schema: "system", + SessionParams: map[string]string{"ANSI_MODE": "true"}, + RetryMax: 4, + RetryWaitMin: 1 * time.Second, + RetryWaitMax: 30 * time.Second, + TelemetryRetryCount: -1, CloudFetchConfig: CloudFetchConfig{ UseCloudFetch: true, MaxDownloadThreads: 15, @@ -314,17 +327,18 @@ func TestParseConfig(t *testing.T) { name: "missing http path", args: args{dsn: "token:supersecret@example.cloud.databricks.com:443"}, wantCfg: UserConfig{ - Protocol: "https", - Host: "example.cloud.databricks.com", - Port: 443, - MaxRows: defaultMaxRows, - AccessToken: "supersecret", - Authenticator: &pat.PATAuth{AccessToken: "supersecret"}, - SessionParams: make(map[string]string), - RetryMax: 4, - RetryWaitMin: 1 * time.Second, - RetryWaitMax: 30 * time.Second, - CloudFetchConfig: defCloudConfig, + Protocol: "https", + Host: "example.cloud.databricks.com", + Port: 443, + MaxRows: defaultMaxRows, + AccessToken: "supersecret", + Authenticator: &pat.PATAuth{AccessToken: "supersecret"}, + SessionParams: make(map[string]string), + RetryMax: 4, + RetryWaitMin: 1 * time.Second, + RetryWaitMax: 30 * time.Second, + TelemetryRetryCount: -1, + CloudFetchConfig: defCloudConfig, }, wantURL: "https://example.cloud.databricks.com:443", wantErr: false, @@ -334,20 +348,21 @@ func TestParseConfig(t *testing.T) { name: "missing http path 2", args: args{dsn: "token:supersecret2@example.cloud.databricks.com:443?catalog=default&schema=system&timeout=100&maxRows=1000"}, wantCfg: UserConfig{ - Protocol: "https", - Host: "example.cloud.databricks.com", - Port: 443, - AccessToken: "supersecret2", - Authenticator: &pat.PATAuth{AccessToken: "supersecret2"}, - QueryTimeout: 100 * time.Second, - MaxRows: 1000, - Catalog: "default", - Schema: "system", - SessionParams: make(map[string]string), - RetryMax: 4, - RetryWaitMin: 1 * time.Second, - RetryWaitMax: 30 * time.Second, - CloudFetchConfig: defCloudConfig, + Protocol: "https", + Host: "example.cloud.databricks.com", + Port: 443, + AccessToken: "supersecret2", + Authenticator: &pat.PATAuth{AccessToken: "supersecret2"}, + QueryTimeout: 100 * time.Second, + MaxRows: 1000, + Catalog: "default", + Schema: "system", + SessionParams: make(map[string]string), + RetryMax: 4, + RetryWaitMin: 1 * time.Second, + RetryWaitMax: 30 * time.Second, + TelemetryRetryCount: -1, + CloudFetchConfig: defCloudConfig, }, wantURL: "https://example.cloud.databricks.com:443", wantErr: false, @@ -393,19 +408,20 @@ func TestParseConfig(t *testing.T) { name: "missing 
host", args: args{dsn: "token:supersecret2@:443?catalog=default&schema=system&timeout=100&maxRows=1000"}, wantCfg: UserConfig{ - Port: 443, - Protocol: "https", - AccessToken: "supersecret2", - Authenticator: &pat.PATAuth{AccessToken: "supersecret2"}, - MaxRows: 1000, - QueryTimeout: 100 * time.Second, - Catalog: "default", - Schema: "system", - SessionParams: make(map[string]string), - RetryMax: 4, - RetryWaitMin: 1 * time.Second, - RetryWaitMax: 30 * time.Second, - CloudFetchConfig: defCloudConfig, + Port: 443, + Protocol: "https", + AccessToken: "supersecret2", + Authenticator: &pat.PATAuth{AccessToken: "supersecret2"}, + MaxRows: 1000, + QueryTimeout: 100 * time.Second, + Catalog: "default", + Schema: "system", + SessionParams: make(map[string]string), + RetryMax: 4, + RetryWaitMin: 1 * time.Second, + RetryWaitMax: 30 * time.Second, + TelemetryRetryCount: -1, + CloudFetchConfig: defCloudConfig, }, wantURL: "https://:443", wantErr: false, @@ -415,22 +431,23 @@ func TestParseConfig(t *testing.T) { name: "with accessToken param", args: args{dsn: "example.cloud.databricks.com:8000/sql/1.0/endpoints/12346a5b5b0e123a?catalog=default&schema=system&userAgentEntry=partner-name&timeout=100&maxRows=1000&ANSI_MODE=true&accessToken=supersecret2"}, wantCfg: UserConfig{ - Protocol: "https", - Host: "example.cloud.databricks.com", - Port: 8000, - AccessToken: "supersecret2", - Authenticator: &pat.PATAuth{AccessToken: "supersecret2"}, - HTTPPath: "/sql/1.0/endpoints/12346a5b5b0e123a", - QueryTimeout: 100 * time.Second, - MaxRows: 1000, - UserAgentEntry: "partner-name", - Catalog: "default", - Schema: "system", - SessionParams: map[string]string{"ANSI_MODE": "true"}, - RetryMax: 4, - RetryWaitMin: 1 * time.Second, - RetryWaitMax: 30 * time.Second, - CloudFetchConfig: defCloudConfig, + Protocol: "https", + Host: "example.cloud.databricks.com", + Port: 8000, + AccessToken: "supersecret2", + Authenticator: &pat.PATAuth{AccessToken: "supersecret2"}, + HTTPPath: "/sql/1.0/endpoints/12346a5b5b0e123a", + QueryTimeout: 100 * time.Second, + MaxRows: 1000, + UserAgentEntry: "partner-name", + Catalog: "default", + Schema: "system", + SessionParams: map[string]string{"ANSI_MODE": "true"}, + RetryMax: 4, + RetryWaitMin: 1 * time.Second, + RetryWaitMax: 30 * time.Second, + TelemetryRetryCount: -1, + CloudFetchConfig: defCloudConfig, }, wantURL: "https://example.cloud.databricks.com:8000/sql/1.0/endpoints/12346a5b5b0e123a", wantErr: false, @@ -439,22 +456,23 @@ func TestParseConfig(t *testing.T) { name: "with accessToken param and client id/secret params", args: args{dsn: "example.cloud.databricks.com:8000/sql/1.0/endpoints/12346a5b5b0e123a?catalog=default&schema=system&userAgentEntry=partner-name&timeout=100&maxRows=1000&ANSI_MODE=true&accessToken=supersecret2&clientId=client_id&clientSecret=client_secret"}, wantCfg: UserConfig{ - Protocol: "https", - Host: "example.cloud.databricks.com", - Port: 8000, - AccessToken: "supersecret2", - Authenticator: &pat.PATAuth{AccessToken: "supersecret2"}, - HTTPPath: "/sql/1.0/endpoints/12346a5b5b0e123a", - QueryTimeout: 100 * time.Second, - MaxRows: 1000, - UserAgentEntry: "partner-name", - Catalog: "default", - Schema: "system", - SessionParams: map[string]string{"ANSI_MODE": "true"}, - RetryMax: 4, - RetryWaitMin: 1 * time.Second, - RetryWaitMax: 30 * time.Second, - CloudFetchConfig: defCloudConfig, + Protocol: "https", + Host: "example.cloud.databricks.com", + Port: 8000, + AccessToken: "supersecret2", + Authenticator: &pat.PATAuth{AccessToken: "supersecret2"}, + HTTPPath: 
"/sql/1.0/endpoints/12346a5b5b0e123a", + QueryTimeout: 100 * time.Second, + MaxRows: 1000, + UserAgentEntry: "partner-name", + Catalog: "default", + Schema: "system", + SessionParams: map[string]string{"ANSI_MODE": "true"}, + RetryMax: 4, + RetryWaitMin: 1 * time.Second, + RetryWaitMax: 30 * time.Second, + TelemetryRetryCount: -1, + CloudFetchConfig: defCloudConfig, }, wantURL: "https://example.cloud.databricks.com:8000/sql/1.0/endpoints/12346a5b5b0e123a", wantErr: false, @@ -463,22 +481,23 @@ func TestParseConfig(t *testing.T) { name: "authType unknown with accessTokenParam", args: args{dsn: "example.cloud.databricks.com:8000/sql/1.0/endpoints/12346a5b5b0e123a?authType=unknown&catalog=default&schema=system&userAgentEntry=partner-name&timeout=100&maxRows=1000&ANSI_MODE=true&accessToken=supersecret2&clientId=client_id&clientSecret=client_secret"}, wantCfg: UserConfig{ - Protocol: "https", - Host: "example.cloud.databricks.com", - Port: 8000, - AccessToken: "supersecret2", - Authenticator: &pat.PATAuth{AccessToken: "supersecret2"}, - HTTPPath: "/sql/1.0/endpoints/12346a5b5b0e123a", - QueryTimeout: 100 * time.Second, - MaxRows: 1000, - UserAgentEntry: "partner-name", - Catalog: "default", - Schema: "system", - SessionParams: map[string]string{"ANSI_MODE": "true"}, - RetryMax: 4, - RetryWaitMin: 1 * time.Second, - RetryWaitMax: 30 * time.Second, - CloudFetchConfig: defCloudConfig, + Protocol: "https", + Host: "example.cloud.databricks.com", + Port: 8000, + AccessToken: "supersecret2", + Authenticator: &pat.PATAuth{AccessToken: "supersecret2"}, + HTTPPath: "/sql/1.0/endpoints/12346a5b5b0e123a", + QueryTimeout: 100 * time.Second, + MaxRows: 1000, + UserAgentEntry: "partner-name", + Catalog: "default", + Schema: "system", + SessionParams: map[string]string{"ANSI_MODE": "true"}, + RetryMax: 4, + RetryWaitMin: 1 * time.Second, + RetryWaitMax: 30 * time.Second, + TelemetryRetryCount: -1, + CloudFetchConfig: defCloudConfig, }, wantURL: "https://example.cloud.databricks.com:8000/sql/1.0/endpoints/12346a5b5b0e123a", wantErr: false, @@ -517,25 +536,47 @@ func TestParseConfig(t *testing.T) { name: "authType unknown with client id/secret", args: args{dsn: "example.cloud.databricks.com:8000/sql/1.0/endpoints/12346a5b5b0e123a?authType=unknown&clientId=client_id&clientSecret=client_secret&catalog=default&schema=system&userAgentEntry=partner-name&timeout=100&maxRows=1000&ANSI_MODE=true"}, wantCfg: UserConfig{ - Protocol: "https", - Host: "example.cloud.databricks.com", - Port: 8000, - Authenticator: m2m.NewAuthenticator("client_id", "client_secret", "example.cloud.databricks.com"), - HTTPPath: "/sql/1.0/endpoints/12346a5b5b0e123a", - QueryTimeout: 100 * time.Second, - MaxRows: 1000, - UserAgentEntry: "partner-name", - Catalog: "default", - Schema: "system", - SessionParams: map[string]string{"ANSI_MODE": "true"}, - RetryMax: 4, - RetryWaitMin: 1 * time.Second, - RetryWaitMax: 30 * time.Second, - CloudFetchConfig: defCloudConfig, + Protocol: "https", + Host: "example.cloud.databricks.com", + Port: 8000, + Authenticator: m2m.NewAuthenticator("client_id", "client_secret", "example.cloud.databricks.com"), + HTTPPath: "/sql/1.0/endpoints/12346a5b5b0e123a", + QueryTimeout: 100 * time.Second, + MaxRows: 1000, + UserAgentEntry: "partner-name", + Catalog: "default", + Schema: "system", + SessionParams: map[string]string{"ANSI_MODE": "true"}, + RetryMax: 4, + RetryWaitMin: 1 * time.Second, + RetryWaitMax: 30 * time.Second, + TelemetryRetryCount: -1, + CloudFetchConfig: defCloudConfig, }, wantURL: 
"https://example.cloud.databricks.com:8000/sql/1.0/endpoints/12346a5b5b0e123a", wantErr: false, }, + { + name: "with telemetry_retry_count=0 (disable retries)", + args: args{dsn: "token:supersecret@example.cloud.databricks.com:443/sql/1.0/endpoints/12346a5b5b0e123a?telemetry_retry_count=0"}, + wantCfg: UserConfig{ + Protocol: "https", + Host: "example.cloud.databricks.com", + Port: 443, + MaxRows: defaultMaxRows, + Authenticator: &pat.PATAuth{AccessToken: "supersecret"}, + AccessToken: "supersecret", + HTTPPath: "/sql/1.0/endpoints/12346a5b5b0e123a", + SessionParams: make(map[string]string), + RetryMax: 4, + RetryWaitMin: 1 * time.Second, + RetryWaitMax: 30 * time.Second, + TelemetryRetryCount: 0, + CloudFetchConfig: defCloudConfig, + }, + wantURL: "https://example.cloud.databricks.com:443/sql/1.0/endpoints/12346a5b5b0e123a", + wantErr: false, + }, { name: "authType m2m with accessToken", args: args{dsn: "example.cloud.databricks.com:8000/sql/1.0/endpoints/12346a5b5b0e123a?authType=oauthm2m&accessToken=supersecret2&catalog=default&schema=system&userAgentEntry=partner-name&timeout=100&maxRows=1000&ANSI_MODE=true"}, diff --git a/internal/rows/arrowbased/arrowIPCStreamIterator.go b/internal/rows/arrowbased/arrowIPCStreamIterator.go index e3ac82c5..2aeb0d28 100644 --- a/internal/rows/arrowbased/arrowIPCStreamIterator.go +++ b/internal/rows/arrowbased/arrowIPCStreamIterator.go @@ -138,7 +138,7 @@ func (ri *arrowIPCStreamIterator) fetchNextData() error { func (ri *arrowIPCStreamIterator) newIPCStreamIterator(fr *cli_service.TFetchResultsResp) (IPCStreamIterator, error) { rowSet := fr.Results if len(rowSet.ResultLinks) > 0 { - return NewCloudIPCStreamIterator(ri.ctx, rowSet.ResultLinks, rowSet.StartRowOffset, &ri.cfg) + return NewCloudIPCStreamIterator(ri.ctx, rowSet.ResultLinks, rowSet.StartRowOffset, &ri.cfg, nil) } else { return NewLocalIPCStreamIterator(ri.ctx, rowSet.ArrowBatches, rowSet.StartRowOffset, ri.arrowSchemaBytes, &ri.cfg) } diff --git a/internal/rows/arrowbased/arrowRecordIterator.go b/internal/rows/arrowbased/arrowRecordIterator.go index d0b620db..4fcd97f6 100644 --- a/internal/rows/arrowbased/arrowRecordIterator.go +++ b/internal/rows/arrowbased/arrowRecordIterator.go @@ -169,7 +169,7 @@ func (ri *arrowRecordIterator) getBatchIterator() error { func (ri *arrowRecordIterator) newBatchIterator(fr *cli_service.TFetchResultsResp) (BatchIterator, error) { rowSet := fr.Results if len(rowSet.ResultLinks) > 0 { - return NewCloudBatchIterator(ri.ctx, rowSet.ResultLinks, rowSet.StartRowOffset, &ri.cfg) + return NewCloudBatchIterator(ri.ctx, rowSet.ResultLinks, rowSet.StartRowOffset, &ri.cfg, nil) } else { return NewLocalBatchIterator(ri.ctx, rowSet.ArrowBatches, rowSet.StartRowOffset, ri.arrowSchemaBytes, &ri.cfg) } diff --git a/internal/rows/arrowbased/arrowRows.go b/internal/rows/arrowbased/arrowRows.go index c4a98e69..8a285f9e 100644 --- a/internal/rows/arrowbased/arrowRows.go +++ b/internal/rows/arrowbased/arrowRows.go @@ -80,8 +80,10 @@ type arrowRowScanner struct { // Make sure arrowRowScanner fulfills the RowScanner interface var _ rowscanner.RowScanner = (*arrowRowScanner)(nil) -// NewArrowRowScanner returns an instance of RowScanner which handles arrow format results -func NewArrowRowScanner(resultSetMetadata *cli_service.TGetResultSetMetadataResp, rowSet *cli_service.TRowSet, cfg *config.Config, logger *dbsqllog.DBSQLLogger, ctx context.Context) (rowscanner.RowScanner, dbsqlerr.DBError) { +// NewArrowRowScanner returns an instance of RowScanner which handles arrow format 
results. +// onCloudFetchDownload is an optional callback invoked for each CloudFetch S3 file download +// with the download duration in milliseconds. Pass nil for non-telemetry paths. +func NewArrowRowScanner(resultSetMetadata *cli_service.TGetResultSetMetadataResp, rowSet *cli_service.TRowSet, cfg *config.Config, logger *dbsqllog.DBSQLLogger, ctx context.Context, onCloudFetchDownload func(downloadMs int64)) (rowscanner.RowScanner, dbsqlerr.DBError) { // we take a passed in logger, rather than just using the global from dbsqllog, so that the containing rows // instance can pass in a logger with context such as correlation ID and operation ID @@ -119,7 +121,7 @@ func NewArrowRowScanner(resultSetMetadata *cli_service.TGetResultSetMetadataResp for _, resultLink := range rowSet.ResultLinks { logger.Debug().Msgf("- start row offset: %d, row count: %d", resultLink.StartRowOffset, resultLink.RowCount) } - bi, err2 = NewCloudBatchIterator(context.Background(), rowSet.ResultLinks, rowSet.StartRowOffset, cfg) + bi, err2 = NewCloudBatchIterator(context.Background(), rowSet.ResultLinks, rowSet.StartRowOffset, cfg, onCloudFetchDownload) } else { bi, err2 = NewLocalBatchIterator(context.Background(), rowSet.ArrowBatches, rowSet.StartRowOffset, schemaBytes, cfg) } diff --git a/internal/rows/arrowbased/arrowRows_test.go b/internal/rows/arrowbased/arrowRows_test.go index 84b27426..705bfec7 100644 --- a/internal/rows/arrowbased/arrowRows_test.go +++ b/internal/rows/arrowbased/arrowRows_test.go @@ -213,19 +213,19 @@ func TestArrowRowScanner(t *testing.T) { schema := &cli_service.TTableSchema{} metadataResp := getMetadataResp(schema) - ars, err := NewArrowRowScanner(metadataResp, rowSet, nil, nil, context.Background()) + ars, err := NewArrowRowScanner(metadataResp, rowSet, nil, nil, context.Background(), nil) assert.NotNil(t, ars) assert.Nil(t, err) assert.Equal(t, int64(0), ars.NRows()) rowSet.ArrowBatches = []*cli_service.TSparkArrowBatch{} - ars, err = NewArrowRowScanner(metadataResp, rowSet, nil, nil, context.Background()) + ars, err = NewArrowRowScanner(metadataResp, rowSet, nil, nil, context.Background(), nil) assert.NotNil(t, ars) assert.Nil(t, err) assert.Equal(t, int64(0), ars.NRows()) rowSet.ArrowBatches = []*cli_service.TSparkArrowBatch{{RowCount: 2}, {RowCount: 3}} - ars, _ = NewArrowRowScanner(metadataResp, rowSet, nil, nil, context.Background()) + ars, _ = NewArrowRowScanner(metadataResp, rowSet, nil, nil, context.Background(), nil) assert.NotNil(t, ars) assert.Equal(t, int64(5), ars.NRows()) }) @@ -237,7 +237,7 @@ func TestArrowRowScanner(t *testing.T) { schema := getAllTypesSchema() metadataResp := getMetadataResp(schema) - d, _ := NewArrowRowScanner(metadataResp, rowSet, nil, nil, context.Background()) + d, _ := NewArrowRowScanner(metadataResp, rowSet, nil, nil, context.Background(), nil) ars := d.(*arrowRowScanner) @@ -313,7 +313,7 @@ func TestArrowRowScanner(t *testing.T) { cfg.UseArrowNativeTimestamp = true cfg.UseArrowNativeDecimal = true - d, _ := NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, context.Background()) + d, _ := NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, context.Background(), nil) ars := d.(*arrowRowScanner) @@ -344,14 +344,14 @@ func TestArrowRowScanner(t *testing.T) { cfg.UseArrowNativeTimestamp = true cfg.UseArrowNativeDecimal = true - _, err := NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, context.Background()) + _, err := NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, context.Background(), nil) require.Nil(t, err) // missing type qualifiers 
schema = getAllTypesSchema() schema.Columns[13].TypeDesc.Types[0].PrimitiveEntry.TypeQualifiers = nil metadataResp.Schema = schema - _, err = NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, context.Background()) + _, err = NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, context.Background(), nil) require.NotNil(t, err) assert.True(t, strings.HasPrefix(err.Error(), "databricks: driver error: "+errArrowRowsConvertSchema+": "+errArrowRowsInvalidDecimalType)) @@ -359,7 +359,7 @@ func TestArrowRowScanner(t *testing.T) { schema = getAllTypesSchema() schema.Columns[13].TypeDesc.Types[0].PrimitiveEntry.TypeQualifiers.Qualifiers = nil metadataResp.Schema = schema - _, err = NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, context.Background()) + _, err = NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, context.Background(), nil) require.NotNil(t, err) assert.True(t, strings.HasPrefix(err.Error(), "databricks: driver error: "+errArrowRowsConvertSchema+": "+errArrowRowsInvalidDecimalType)) @@ -367,7 +367,7 @@ func TestArrowRowScanner(t *testing.T) { schema = getAllTypesSchema() schema.Columns[13].TypeDesc.Types[0].PrimitiveEntry.TypeQualifiers.Qualifiers = map[string]*cli_service.TTypeQualifierValue{} metadataResp.Schema = schema - _, err = NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, context.Background()) + _, err = NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, context.Background(), nil) require.NotNil(t, err) assert.True(t, strings.HasPrefix(err.Error(), "databricks: driver error: "+errArrowRowsConvertSchema+": "+errArrowRowsInvalidDecimalType)) @@ -375,7 +375,7 @@ func TestArrowRowScanner(t *testing.T) { schema = getAllTypesSchema() schema.Columns[13].TypeDesc.Types[0].PrimitiveEntry.TypeQualifiers.Qualifiers["precision"] = nil metadataResp.Schema = schema - _, err = NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, context.Background()) + _, err = NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, context.Background(), nil) require.NotNil(t, err) assert.True(t, strings.HasPrefix(err.Error(), "databricks: driver error: "+errArrowRowsConvertSchema+": "+errArrowRowsInvalidDecimalType)) @@ -383,7 +383,7 @@ func TestArrowRowScanner(t *testing.T) { schema = getAllTypesSchema() schema.Columns[13].TypeDesc.Types[0].PrimitiveEntry.TypeQualifiers.Qualifiers["precision"].I32Value = nil metadataResp.Schema = schema - _, err = NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, context.Background()) + _, err = NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, context.Background(), nil) require.NotNil(t, err) assert.True(t, strings.HasPrefix(err.Error(), "databricks: driver error: "+errArrowRowsConvertSchema+": "+errArrowRowsInvalidDecimalType)) @@ -391,7 +391,7 @@ func TestArrowRowScanner(t *testing.T) { schema = getAllTypesSchema() schema.Columns[13].TypeDesc.Types[0].PrimitiveEntry.TypeQualifiers.Qualifiers["scale"] = nil metadataResp.Schema = schema - _, err = NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, context.Background()) + _, err = NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, context.Background(), nil) require.NotNil(t, err) assert.True(t, strings.HasPrefix(err.Error(), "databricks: driver error: "+errArrowRowsConvertSchema+": "+errArrowRowsInvalidDecimalType)) @@ -399,7 +399,7 @@ func TestArrowRowScanner(t *testing.T) { schema = getAllTypesSchema() schema.Columns[13].TypeDesc.Types[0].PrimitiveEntry.TypeQualifiers.Qualifiers["scale"].I32Value = nil metadataResp.Schema = schema - _, err = NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, 
context.Background()) + _, err = NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, context.Background(), nil) require.NotNil(t, err) msg := err.Error() pre := "databricks: driver error: " + errArrowRowsConvertSchema + ": " + errArrowRowsInvalidDecimalType @@ -414,7 +414,7 @@ func TestArrowRowScanner(t *testing.T) { cfg := config.Config{} cfg.UseArrowBatches = true - d, err1 := NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, context.Background()) + d, err1 := NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, context.Background(), nil) require.Nil(t, err1) ars := d.(*arrowRowScanner) @@ -444,7 +444,7 @@ func TestArrowRowScanner(t *testing.T) { cfg := config.Config{} cfg.UseArrowBatches = true - d, err := NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, context.Background()) + d, err := NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, context.Background(), nil) require.Nil(t, err) d.Close() @@ -483,7 +483,7 @@ func TestArrowRowScanner(t *testing.T) { cfg := config.Config{} cfg.UseLz4Compression = false - d, _ := NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, context.Background()) + d, _ := NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, context.Background(), nil) ars := d.(*arrowRowScanner) @@ -553,7 +553,7 @@ func TestArrowRowScanner(t *testing.T) { cfg := config.Config{} cfg.UseLz4Compression = false - d, _ := NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, nil) + d, _ := NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, nil, nil) ars := d.(*arrowRowScanner) @@ -591,7 +591,7 @@ func TestArrowRowScanner(t *testing.T) { cfg := config.Config{} cfg.UseLz4Compression = false - d, _ := NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, nil) + d, _ := NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, nil, nil) ars := d.(*arrowRowScanner) @@ -630,7 +630,7 @@ func TestArrowRowScanner(t *testing.T) { cfg := config.Config{} cfg.UseLz4Compression = false - d, _ := NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, nil) + d, _ := NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, nil, nil) ars := d.(*arrowRowScanner) @@ -671,7 +671,7 @@ func TestArrowRowScanner(t *testing.T) { cfg := config.Config{} cfg.UseLz4Compression = false - d, _ := NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, nil) + d, _ := NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, nil, nil) ars := d.(*arrowRowScanner) @@ -707,7 +707,7 @@ func TestArrowRowScanner(t *testing.T) { cfg := config.Config{} cfg.UseLz4Compression = false - d, _ := NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, context.Background()) + d, _ := NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, context.Background(), nil) ars := d.(*arrowRowScanner) @@ -856,7 +856,7 @@ func TestArrowRowScanner(t *testing.T) { cfg := config.Config{} cfg.UseLz4Compression = false - d, _ := NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, context.Background()) + d, _ := NewArrowRowScanner(metadataResp, rowSet, &cfg, nil, context.Background(), nil) ars := d.(*arrowRowScanner) ars.UseArrowNativeComplexTypes = true @@ -933,7 +933,7 @@ func TestArrowRowScanner(t *testing.T) { config.UseArrowNativeComplexTypes = false config.UseArrowNativeDecimal = false config.UseArrowNativeIntervalTypes = false - d, err := NewArrowRowScanner(executeStatementResp.DirectResults.ResultSetMetadata, executeStatementResp.DirectResults.ResultSet.Results, config, nil, context.Background()) + d, err := NewArrowRowScanner(executeStatementResp.DirectResults.ResultSetMetadata, executeStatementResp.DirectResults.ResultSet.Results, config, nil, context.Background(), 
nil) assert.Nil(t, err) ars := d.(*arrowRowScanner) @@ -961,7 +961,7 @@ func TestArrowRowScanner(t *testing.T) { config := config.WithDefaults() config.UseArrowNativeComplexTypes = false - d, err := NewArrowRowScanner(executeStatementResp.DirectResults.ResultSetMetadata, executeStatementResp.DirectResults.ResultSet.Results, config, nil, context.Background()) + d, err := NewArrowRowScanner(executeStatementResp.DirectResults.ResultSetMetadata, executeStatementResp.DirectResults.ResultSet.Results, config, nil, context.Background(), nil) assert.Nil(t, err) ars := d.(*arrowRowScanner) @@ -984,7 +984,7 @@ func TestArrowRowScanner(t *testing.T) { config := config.WithDefaults() config.UseArrowNativeComplexTypes = false - d, err := NewArrowRowScanner(executeStatementResp.DirectResults.ResultSetMetadata, executeStatementResp.DirectResults.ResultSet.Results, config, nil, context.Background()) + d, err := NewArrowRowScanner(executeStatementResp.DirectResults.ResultSetMetadata, executeStatementResp.DirectResults.ResultSet.Results, config, nil, context.Background(), nil) assert.Nil(t, err) ars := d.(*arrowRowScanner) @@ -1016,7 +1016,7 @@ func TestArrowRowScanner(t *testing.T) { config := config.WithDefaults() config.UseArrowNativeComplexTypes = false - d, err := NewArrowRowScanner(executeStatementResp.DirectResults.ResultSetMetadata, executeStatementResp.DirectResults.ResultSet.Results, config, nil, context.Background()) + d, err := NewArrowRowScanner(executeStatementResp.DirectResults.ResultSetMetadata, executeStatementResp.DirectResults.ResultSet.Results, config, nil, context.Background(), nil) assert.Nil(t, err) ars := d.(*arrowRowScanner) @@ -1036,7 +1036,7 @@ func TestArrowRowScanner(t *testing.T) { config := config.WithDefaults() config.UseArrowNativeComplexTypes = false - d, err := NewArrowRowScanner(executeStatementResp.DirectResults.ResultSetMetadata, executeStatementResp.DirectResults.ResultSet.Results, config, nil, context.Background()) + d, err := NewArrowRowScanner(executeStatementResp.DirectResults.ResultSetMetadata, executeStatementResp.DirectResults.ResultSet.Results, config, nil, context.Background(), nil) assert.Nil(t, err) ars := d.(*arrowRowScanner) @@ -1106,7 +1106,7 @@ func TestArrowRowScanner(t *testing.T) { config.UseArrowNativeComplexTypes = false config.UseArrowNativeDecimal = false config.UseArrowNativeIntervalTypes = false - d, err := NewArrowRowScanner(executeStatementResp.DirectResults.ResultSetMetadata, executeStatementResp.DirectResults.ResultSet.Results, config, nil, context.Background()) + d, err := NewArrowRowScanner(executeStatementResp.DirectResults.ResultSetMetadata, executeStatementResp.DirectResults.ResultSet.Results, config, nil, context.Background(), nil) assert.Nil(t, err) ars := d.(*arrowRowScanner) @@ -1134,7 +1134,7 @@ func TestArrowRowScanner(t *testing.T) { config.UseArrowNativeComplexTypes = false config.UseArrowNativeDecimal = false config.UseArrowNativeIntervalTypes = false - d, err := NewArrowRowScanner(executeStatementResp.DirectResults.ResultSetMetadata, executeStatementResp.DirectResults.ResultSet.Results, config, nil, context.Background()) + d, err := NewArrowRowScanner(executeStatementResp.DirectResults.ResultSetMetadata, executeStatementResp.DirectResults.ResultSet.Results, config, nil, context.Background(), nil) assert.Nil(t, err) ars := d.(*arrowRowScanner) @@ -1186,7 +1186,7 @@ func TestArrowRowScanner(t *testing.T) { config.UseArrowNativeComplexTypes = true config.UseArrowNativeDecimal = false config.UseArrowNativeIntervalTypes = false - d, 
err1 := NewArrowRowScanner(executeStatementResp.DirectResults.ResultSetMetadata, executeStatementResp.DirectResults.ResultSet.Results, config, nil, context.Background()) + d, err1 := NewArrowRowScanner(executeStatementResp.DirectResults.ResultSetMetadata, executeStatementResp.DirectResults.ResultSet.Results, config, nil, context.Background(), nil) assert.Nil(t, err1) ars := d.(*arrowRowScanner) @@ -1227,7 +1227,7 @@ func TestArrowRowScanner(t *testing.T) { config.UseArrowNativeComplexTypes = true config.UseArrowNativeDecimal = false config.UseArrowNativeIntervalTypes = false - d, err := NewArrowRowScanner(executeStatementResp.DirectResults.ResultSetMetadata, executeStatementResp.DirectResults.ResultSet.Results, config, nil, context.Background()) + d, err := NewArrowRowScanner(executeStatementResp.DirectResults.ResultSetMetadata, executeStatementResp.DirectResults.ResultSet.Results, config, nil, context.Background(), nil) assert.Nil(t, err) ars := d.(*arrowRowScanner) @@ -1286,7 +1286,7 @@ func TestArrowRowScanner(t *testing.T) { config.UseArrowNativeComplexTypes = true config.UseArrowNativeDecimal = false config.UseArrowNativeIntervalTypes = false - d, err := NewArrowRowScanner(executeStatementResp.DirectResults.ResultSetMetadata, executeStatementResp.DirectResults.ResultSet.Results, config, nil, context.Background()) + d, err := NewArrowRowScanner(executeStatementResp.DirectResults.ResultSetMetadata, executeStatementResp.DirectResults.ResultSet.Results, config, nil, context.Background(), nil) assert.Nil(t, err) ars := d.(*arrowRowScanner) @@ -1319,7 +1319,7 @@ func TestArrowRowScanner(t *testing.T) { config.UseArrowNativeComplexTypes = true config.UseArrowNativeDecimal = false config.UseArrowNativeIntervalTypes = false - d, err := NewArrowRowScanner(executeStatementResp.DirectResults.ResultSetMetadata, executeStatementResp.DirectResults.ResultSet.Results, config, nil, context.Background()) + d, err := NewArrowRowScanner(executeStatementResp.DirectResults.ResultSetMetadata, executeStatementResp.DirectResults.ResultSet.Results, config, nil, context.Background(), nil) assert.Nil(t, err) ars := d.(*arrowRowScanner) @@ -1358,7 +1358,7 @@ func TestArrowRowScanner(t *testing.T) { config.UseArrowNativeComplexTypes = true config.UseArrowNativeDecimal = false config.UseArrowNativeIntervalTypes = false - d, err := NewArrowRowScanner(executeStatementResp.DirectResults.ResultSetMetadata, executeStatementResp.DirectResults.ResultSet.Results, config, nil, context.Background()) + d, err := NewArrowRowScanner(executeStatementResp.DirectResults.ResultSetMetadata, executeStatementResp.DirectResults.ResultSet.Results, config, nil, context.Background(), nil) assert.Nil(t, err) ars := d.(*arrowRowScanner) @@ -1397,7 +1397,7 @@ func TestArrowRowScanner(t *testing.T) { config.UseArrowNativeComplexTypes = true config.UseArrowNativeDecimal = false config.UseArrowNativeIntervalTypes = false - d, err := NewArrowRowScanner(executeStatementResp.DirectResults.ResultSetMetadata, executeStatementResp.DirectResults.ResultSet.Results, config, nil, context.Background()) + d, err := NewArrowRowScanner(executeStatementResp.DirectResults.ResultSetMetadata, executeStatementResp.DirectResults.ResultSet.Results, config, nil, context.Background(), nil) assert.Nil(t, err) ars := d.(*arrowRowScanner) @@ -1433,7 +1433,7 @@ func TestArrowRowScanner(t *testing.T) { config.UseArrowNativeComplexTypes = true config.UseArrowNativeDecimal = false config.UseArrowNativeIntervalTypes = false - d, err := 
NewArrowRowScanner(executeStatementResp.DirectResults.ResultSetMetadata, executeStatementResp.DirectResults.ResultSet.Results, config, nil, context.Background()) + d, err := NewArrowRowScanner(executeStatementResp.DirectResults.ResultSetMetadata, executeStatementResp.DirectResults.ResultSet.Results, config, nil, context.Background(), nil) assert.Nil(t, err) ars := d.(*arrowRowScanner) @@ -1512,7 +1512,7 @@ func TestArrowRowScanner(t *testing.T) { config.UseArrowNativeComplexTypes = true config.UseArrowNativeDecimal = false config.UseArrowNativeIntervalTypes = false - _, err := NewArrowRowScanner(executeStatementResp.DirectResults.ResultSetMetadata, executeStatementResp.DirectResults.ResultSet.Results, config, nil, context.Background()) + _, err := NewArrowRowScanner(executeStatementResp.DirectResults.ResultSetMetadata, executeStatementResp.DirectResults.ResultSet.Results, config, nil, context.Background(), nil) assert.Nil(t, err) }) @@ -1550,6 +1550,7 @@ func TestArrowRowScanner(t *testing.T) { cfg, logger, context.Background(), + nil, ) assert.Nil(t, err) diff --git a/internal/rows/arrowbased/batchloader.go b/internal/rows/arrowbased/batchloader.go index 545aa6e7..58fcc059 100644 --- a/internal/rows/arrowbased/batchloader.go +++ b/internal/rows/arrowbased/batchloader.go @@ -33,6 +33,7 @@ func NewCloudIPCStreamIterator( files []*cli_service.TSparkArrowResultLink, startRowOffset int64, cfg *config.Config, + onFileDownloaded func(downloadMs int64), ) (IPCStreamIterator, dbsqlerr.DBError) { httpClient := http.DefaultClient if cfg.HTTPClient != nil { @@ -40,12 +41,13 @@ func NewCloudIPCStreamIterator( } bi := &cloudIPCStreamIterator{ - ctx: ctx, - cfg: cfg, - startRowOffset: startRowOffset, - pendingLinks: NewQueue[cli_service.TSparkArrowResultLink](), - downloadTasks: NewQueue[cloudFetchDownloadTask](), - httpClient: httpClient, + ctx: ctx, + cfg: cfg, + startRowOffset: startRowOffset, + pendingLinks: NewQueue[cli_service.TSparkArrowResultLink](), + downloadTasks: NewQueue[cloudFetchDownloadTask](), + httpClient: httpClient, + onFileDownloaded: onFileDownloaded, } for _, link := range files { @@ -61,8 +63,9 @@ func NewCloudBatchIterator( files []*cli_service.TSparkArrowResultLink, startRowOffset int64, cfg *config.Config, + onFileDownloaded func(downloadMs int64), ) (BatchIterator, dbsqlerr.DBError) { - ipcIterator, err := NewCloudIPCStreamIterator(ctx, files, startRowOffset, cfg) + ipcIterator, err := NewCloudIPCStreamIterator(ctx, files, startRowOffset, cfg, onFileDownloaded) if err != nil { return nil, err } @@ -141,12 +144,13 @@ func (bi *localIPCStreamIterator) Close() { } type cloudIPCStreamIterator struct { - ctx context.Context - cfg *config.Config - startRowOffset int64 - pendingLinks Queue[cli_service.TSparkArrowResultLink] - downloadTasks Queue[cloudFetchDownloadTask] - httpClient *http.Client + ctx context.Context + cfg *config.Config + startRowOffset int64 + pendingLinks Queue[cli_service.TSparkArrowResultLink] + downloadTasks Queue[cloudFetchDownloadTask] + httpClient *http.Client + onFileDownloaded func(downloadMs int64) // nil for non-telemetry paths } var _ IPCStreamIterator = (*cloudIPCStreamIterator)(nil) @@ -180,7 +184,7 @@ func (bi *cloudIPCStreamIterator) Next() (io.Reader, error) { return nil, io.EOF } - data, err := task.GetResult() + data, downloadMs, err := task.GetResult() // once we've got an errored out task - cancel the remaining ones if err != nil { @@ -190,6 +194,15 @@ func (bi *cloudIPCStreamIterator) Next() (io.Reader, error) { // explicitly call cancel 
function on successfully completed task to avoid context leak task.cancel() + + // Notify telemetry with per-file download time (matches JDBC's per-chunk HTTP GET timing). + // Always invoke for successfully completed downloads so the caller can count files; + // sub-millisecond downloads report downloadMs=0 and the caller decides whether to + // include them in timing aggregation. + if bi.onFileDownloaded != nil { + bi.onFileDownloaded(downloadMs) + } + return data, nil } @@ -206,8 +219,9 @@ func (bi *cloudIPCStreamIterator) Close() { } type cloudFetchDownloadTaskResult struct { - data io.Reader - err error + data io.Reader + err error + downloadMs int64 // wall-clock time for HTTP GET + decompression } type cloudFetchDownloadTask struct { @@ -221,7 +235,7 @@ type cloudFetchDownloadTask struct { httpClient *http.Client } -func (cft *cloudFetchDownloadTask) GetResult() (io.Reader, error) { +func (cft *cloudFetchDownloadTask) GetResult() (io.Reader, int64, error) { link := cft.link result, ok := <-cft.resultChan @@ -233,14 +247,14 @@ func (cft *cloudFetchDownloadTask) GetResult() (io.Reader, error) { link.RowCount, result.err.Error(), ) - return nil, result.err + return nil, 0, result.err } logger.Debug().Msgf( "CloudFetch: received data for link at offset %d row count %d", link.StartRowOffset, link.RowCount, ) - return result.data, nil + return result.data, result.downloadMs, nil } // This branch should never be reached. If you see this message - something got really wrong @@ -249,7 +263,7 @@ func (cft *cloudFetchDownloadTask) GetResult() (io.Reader, error) { link.StartRowOffset, link.RowCount, ) - return nil, nil + return nil, 0, nil } func (cft *cloudFetchDownloadTask) Run() { @@ -261,6 +275,7 @@ func (cft *cloudFetchDownloadTask) Run() { cft.link.StartRowOffset, cft.link.RowCount, ) + downloadStart := time.Now() data, err := fetchBatchBytes(cft.ctx, cft.link, cft.minTimeToExpiry, cft.speedThresholdMbps, cft.httpClient) if err != nil { cft.resultChan <- cloudFetchDownloadTaskResult{data: nil, err: err} @@ -270,6 +285,7 @@ func (cft *cloudFetchDownloadTask) Run() { // Read all data into memory before closing buf, err := io.ReadAll(getReader(data, cft.useLz4Compression)) data.Close() //nolint:errcheck,gosec // G104: close after reading data + downloadMs := time.Since(downloadStart).Milliseconds() if err != nil { cft.resultChan <- cloudFetchDownloadTaskResult{data: nil, err: err} return @@ -281,7 +297,7 @@ func (cft *cloudFetchDownloadTask) Run() { cft.link.RowCount, ) - cft.resultChan <- cloudFetchDownloadTaskResult{data: bytes.NewReader(buf), err: nil} + cft.resultChan <- cloudFetchDownloadTaskResult{data: bytes.NewReader(buf), err: nil, downloadMs: downloadMs} }() } diff --git a/internal/rows/arrowbased/batchloader_test.go b/internal/rows/arrowbased/batchloader_test.go index 5b230610..dc0d68db 100644 --- a/internal/rows/arrowbased/batchloader_test.go +++ b/internal/rows/arrowbased/batchloader_test.go @@ -6,6 +6,7 @@ import ( "fmt" "net/http" "net/http/httptest" + "sync" "testing" "time" @@ -77,6 +78,7 @@ func TestCloudFetchIterator(t *testing.T) { links, startRowOffset, cfg, + nil, ) if err != nil { panic(err) @@ -151,6 +153,7 @@ func TestCloudFetchIterator(t *testing.T) { links, startRowOffset, cfg, + nil, ) if err != nil { panic(err) @@ -209,6 +212,7 @@ func TestCloudFetchIterator(t *testing.T) { links, startRowOffset, cfg, + nil, ) if err != nil { panic(err) @@ -283,6 +287,7 @@ func TestCloudFetchIterator(t *testing.T) { }}, startRowOffset, cfg, + nil, ) assert.Nil(t, err) @@ -321,6 
+326,7 @@ func TestCloudFetchIterator(t *testing.T) { }}, startRowOffset, cfg, + nil, ) assert.Nil(t, err) @@ -334,6 +340,115 @@ func TestCloudFetchIterator(t *testing.T) { }) } +// TestCloudFetchIterator_OnFileDownloaded_CallbackInvokedWithPositiveDuration verifies +// that the onFileDownloaded telemetry callback is called once per downloaded S3 file with +// a positive downloadMs value. +// +// This covers the CloudFetch timing fix where per-S3-file download durations are measured +// and reported as initial_chunk_latency_ms / slowest_chunk_latency_ms / sum_chunks_download_time_ms +// in the telemetry payload. +func TestCloudFetchIterator_OnFileDownloaded_CallbackInvokedWithPositiveDuration(t *testing.T) { + // Serve real arrow bytes so the iterator can parse them successfully. + arrowBytes := generateMockArrowBytes(generateArrowRecord()) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Add a tiny sleep so the measured download time is reliably > 0ms. + time.Sleep(2 * time.Millisecond) + w.WriteHeader(http.StatusOK) + _, _ = w.Write(arrowBytes) + })) + defer server.Close() + + startRowOffset := int64(0) + links := []*cli_service.TSparkArrowResultLink{ + { + FileLink: server.URL, + ExpiryTime: time.Now().Add(10 * time.Minute).Unix(), + StartRowOffset: startRowOffset, + RowCount: 1, + }, + { + FileLink: server.URL, + ExpiryTime: time.Now().Add(10 * time.Minute).Unix(), + StartRowOffset: startRowOffset + 1, + RowCount: 1, + }, + } + + cfg := config.WithDefaults() + cfg.UseLz4Compression = false + cfg.MaxDownloadThreads = 1 + + var callbackMu sync.Mutex + var downloadDurations []int64 + + onFileDownloaded := func(downloadMs int64) { + callbackMu.Lock() + downloadDurations = append(downloadDurations, downloadMs) + callbackMu.Unlock() + } + + bi, err := NewCloudBatchIterator(context.Background(), links, startRowOffset, cfg, onFileDownloaded) + assert.Nil(t, err) + + // Consume all batches to trigger the downloads. + for bi.HasNext() { + _, nextErr := bi.Next() + assert.Nil(t, nextErr) + } + + callbackMu.Lock() + durations := make([]int64, len(downloadDurations)) + copy(durations, downloadDurations) + callbackMu.Unlock() + + // Callback must be invoked once per link. + assert.Equal(t, len(links), len(durations), + "onFileDownloaded must be called once per downloaded file") + + // Each reported duration must be positive (the server adds a 2ms delay). + for i, d := range durations { + assert.Greater(t, d, int64(0), + "onFileDownloaded must report positive downloadMs for file %d, got %d", i, d) + } +} + +// TestCloudFetchIterator_OnFileDownloaded_NilCallbackDoesNotPanic verifies that passing +// nil for onFileDownloaded (non-telemetry paths) does not cause a panic during iteration. 
+func TestCloudFetchIterator_OnFileDownloaded_NilCallbackDoesNotPanic(t *testing.T) { + arrowBytes := generateMockArrowBytes(generateArrowRecord()) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write(arrowBytes) + })) + defer server.Close() + + startRowOffset := int64(0) + links := []*cli_service.TSparkArrowResultLink{ + { + FileLink: server.URL, + ExpiryTime: time.Now().Add(10 * time.Minute).Unix(), + StartRowOffset: startRowOffset, + RowCount: 1, + }, + } + + cfg := config.WithDefaults() + cfg.UseLz4Compression = false + cfg.MaxDownloadThreads = 1 + + // nil callback — must not panic + bi, err := NewCloudBatchIterator(context.Background(), links, startRowOffset, cfg, nil) + assert.Nil(t, err) + + assert.NotPanics(t, func() { + for bi.HasNext() { + _, _ = bi.Next() + } + }, "nil onFileDownloaded must not cause a panic") +} + func generateArrowRecord() arrow.Record { mem := memory.NewCheckedAllocator(memory.NewGoAllocator()) diff --git a/internal/rows/rows.go b/internal/rows/rows.go index 7ee2db44..bd0c2605 100644 --- a/internal/rows/rows.go +++ b/internal/rows/rows.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "database/sql/driver" + "io" "math" "reflect" "time" @@ -59,9 +60,21 @@ type rows struct { ctx context.Context // Telemetry tracking - telemetryUpdate func(chunkCount int, bytesDownloaded int64) - chunkCount int - bytesDownloaded int64 + // telemetryUpdate is called after each chunk is fetched with: + // chunkCount: total chunks fetched so far (including direct results) + // bytesDownloaded: cumulative bytes + // chunkIndex: 0-based index of the chunk just fetched + // chunkLatencyMs: fetch latency for this chunk (0 for direct results or CloudFetch pages) + // totalChunksPresent: server-reported total, 0 if unknown + telemetryUpdate func(chunkCount int, bytesDownloaded int64, chunkIndex int, chunkLatencyMs int64, totalChunksPresent int32) + // cloudFetchCallback is invoked per S3 file download for CloudFetch result sets. + // It receives the individual file download duration so that telemetry can track + // initial/slowest/sum download times matching JDBC's per-chunk HTTP GET timing. + cloudFetchCallback func(downloadMs int64) + closeCallback func(latencyMs int64, chunkCount int, iterErr error, closeErr error) + chunkCount int + bytesDownloaded int64 + iterationErr error // first error from Next()/fetchResultPage, passed to closeCallback } var _ driver.Rows = (*rows)(nil) @@ -71,13 +84,26 @@ var _ driver.RowsColumnTypeNullable = (*rows)(nil) var _ driver.RowsColumnTypeLength = (*rows)(nil) var _ dbsqlrows.Rows = (*rows)(nil) +// TelemetryCallbacks bundles the optional telemetry hooks passed into NewRows. +// Pass nil when telemetry is not active; individual fields may also be nil. +type TelemetryCallbacks struct { + // OnChunkFetched is called after each result page fetch with chunk-level stats. + OnChunkFetched func(chunkCount int, bytesDownloaded int64, chunkIndex int, chunkLatencyMs int64, totalChunksPresent int32) + // OnClose is called from rows.Close() after all rows have been consumed. + // iterErr is the first error from Next()/fetchResultPage (nil if iteration succeeded). + // closeErr is the error from the CloseOperation RPC (nil if close succeeded). + OnClose func(latencyMs int64, chunkCount int, iterErr error, closeErr error) + // OnCloudFetchFile is called per S3 file download for CloudFetch result sets. 
+ OnCloudFetchFile func(downloadMs int64) +} + func NewRows( ctx context.Context, opHandle *cli_service.TOperationHandle, client cli_service.TCLIService, config *config.Config, directResults *cli_service.TSparkDirectResults, - telemetryUpdate func(chunkCount int, bytesDownloaded int64), + callbacks *TelemetryCallbacks, ) (driver.Rows, dbsqlerr.DBError) { connId := driverctx.ConnIdFromContext(ctx) @@ -117,10 +143,14 @@ func NewRows( config: config, logger_: logger, ctx: ctx, - telemetryUpdate: telemetryUpdate, chunkCount: 0, bytesDownloaded: 0, } + if callbacks != nil { + r.telemetryUpdate = callbacks.OnChunkFetched + r.cloudFetchCallback = callbacks.OnCloudFetchFile + r.closeCallback = callbacks.OnClose + } // if we already have results for the query do some additional initialization if directResults != nil { @@ -145,7 +175,18 @@ func NewRows( } if r.telemetryUpdate != nil { - r.telemetryUpdate(r.chunkCount, r.bytesDownloaded) + // Determine totalChunksPresent for direct results. + // If the server already closed the operation, all data is here (totalPresent=1). + // For CloudFetch direct results, use the number of result links. + var totalPresent int32 + if directResults.CloseOperation != nil { + totalPresent = int32(r.chunkCount) + } else if directResults.ResultSet != nil && directResults.ResultSet.Results != nil && + directResults.ResultSet.Results.ResultLinks != nil { + totalPresent = int32(len(directResults.ResultSet.Results.ResultLinks)) //nolint:gosec + } + // chunkIndex=0, chunkLatencyMs=0: direct results have no separate fetch latency. + r.telemetryUpdate(r.chunkCount, r.bytesDownloaded, 0, 0, totalPresent) } } @@ -210,7 +251,11 @@ func (r *rows) Close() error { if r.ResultPageIterator != nil { r.logger().Debug().Msgf("databricks: closing Rows operation") + closeStart := time.Now() err := r.ResultPageIterator.Close() + if r.closeCallback != nil { + r.closeCallback(time.Since(closeStart).Milliseconds(), r.chunkCount, r.iterationErr, err) + } if err != nil { r.logger().Err(err).Msg(errRowsCloseFailed) return dbsqlerr_int.NewRequestError(r.ctx, errRowsCloseFailed, err) @@ -232,6 +277,7 @@ func (r *rows) Close() error { func (r *rows) Next(dest []driver.Value) error { err := isValidRows(r) if err != nil { + r.trackIterationErr(err) return err } @@ -242,17 +288,20 @@ func (r *rows) Next(dest []driver.Value) error { if b, e = r.isNextRowInPage(); !b && e == nil { err := r.fetchResultPage() if err != nil { + r.trackIterationErr(err) return err } } if e != nil { + r.trackIterationErr(e) return e } // Put values into the destination slice err = r.ScanRow(dest, r.nextRowNumber) if err != nil { + r.trackIterationErr(err) return err } @@ -473,7 +522,11 @@ func (r *rows) fetchResultPage() error { r.RowScanner = nil } + // Record 0-based chunk index before fetching (direct results occupied index 0 if present). + chunkIndex := r.chunkCount + fetchStart := time.Now() fetchResult, err1 := r.ResultPageIterator.Next() + chunkLatencyMs := time.Since(fetchStart).Milliseconds() if err1 != nil { return err1 } @@ -487,8 +540,23 @@ func (r *rows) fetchResultPage() error { } } + // For CloudFetch, the FetchResults RPC only returns presigned S3 URLs — the actual data + // transfer happens later via S3 HTTP GETs timed by cloudFetchCallback. Report 0 latency + // here so the Thrift round-trip is not misreported as chunk download time. 
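The three hooks bundled in TelemetryCallbacks above are typically wired together by whatever constructs the row set. A minimal sketch of such a caller, assuming it sits alongside NewRows in this package (imports elided); the chunkStats sink and newRowsWithStats helper are illustrative, not part of the driver:

// Sketch only: a caller that funnels the three hooks into a thread-safe sink.
type chunkStats struct {
	mu        sync.Mutex
	chunks    int
	bytes     int64
	downloads []int64 // per-S3-file CloudFetch latencies, in ms
}

func newRowsWithStats(
	ctx context.Context,
	opHandle *cli_service.TOperationHandle,
	cl cli_service.TCLIService,
	cfg *config.Config,
	direct *cli_service.TSparkDirectResults,
	stats *chunkStats,
) (driver.Rows, dbsqlerr.DBError) {
	cb := &TelemetryCallbacks{
		OnChunkFetched: func(chunkCount int, bytesDownloaded int64, chunkIndex int, chunkLatencyMs int64, totalChunksPresent int32) {
			stats.mu.Lock()
			stats.chunks, stats.bytes = chunkCount, bytesDownloaded
			stats.mu.Unlock()
		},
		OnCloudFetchFile: func(downloadMs int64) {
			stats.mu.Lock()
			stats.downloads = append(stats.downloads, downloadMs)
			stats.mu.Unlock()
		},
		OnClose: func(latencyMs int64, chunkCount int, iterErr, closeErr error) {
			// Last chance to reconcile counts and errors before the metric is emitted.
		},
	}
	return NewRows(ctx, opHandle, cl, cfg, direct, cb)
}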
+ var totalPresent int32 + isCloudFetch := false + if fetchResult != nil && fetchResult.Results != nil && fetchResult.Results.ResultLinks != nil { + totalPresent = int32(len(fetchResult.Results.ResultLinks)) //nolint:gosec + isCloudFetch = true + } + + effectiveLatencyMs := chunkLatencyMs + if isCloudFetch { + effectiveLatencyMs = 0 + } + if r.telemetryUpdate != nil { - r.telemetryUpdate(r.chunkCount, r.bytesDownloaded) + r.telemetryUpdate(r.chunkCount, r.bytesDownloaded, chunkIndex, effectiveLatencyMs, totalPresent) } err1 = r.makeRowScanner(fetchResult) @@ -525,9 +593,9 @@ func (r *rows) makeRowScanner(fetchResults *cli_service.TFetchResultsResp) dbsql if fetchResults.Results.Columns != nil { rs, err = columnbased.NewColumnRowScanner(schema, fetchResults.Results, r.config, r.logger(), r.ctx) } else if fetchResults.Results.ArrowBatches != nil { - rs, err = arrowbased.NewArrowRowScanner(r.resultSetMetadata, fetchResults.Results, r.config, r.logger(), r.ctx) + rs, err = arrowbased.NewArrowRowScanner(r.resultSetMetadata, fetchResults.Results, r.config, r.logger(), r.ctx, nil) } else if fetchResults.Results.ResultLinks != nil { - rs, err = arrowbased.NewArrowRowScanner(r.resultSetMetadata, fetchResults.Results, r.config, r.logger(), r.ctx) + rs, err = arrowbased.NewArrowRowScanner(r.resultSetMetadata, fetchResults.Results, r.config, r.logger(), r.ctx, r.cloudFetchCallback) } else { r.logger().Error().Msg(errRowsUnknowRowType) err = dbsqlerr_int.NewDriverError(r.ctx, errRowsUnknowRowType, nil) @@ -547,6 +615,14 @@ func (r *rows) makeRowScanner(fetchResults *cli_service.TFetchResultsResp) dbsql return err } +// trackIterationErr records the first non-EOF error from Next()/fetchResultPage +// so that closeCallback can report it as the statement's error. +func (r *rows) trackIterationErr(err error) { + if r != nil && r.iterationErr == nil && err != nil && err != io.EOF { + r.iterationErr = err + } +} + func (r *rows) logger() *dbsqllog.DBSQLLogger { if r.logger_ == nil { if r.opHandle != nil { diff --git a/internal/rows/rows_test.go b/internal/rows/rows_test.go index bb8ac196..8947bf05 100644 --- a/internal/rows/rows_test.go +++ b/internal/rows/rows_test.go @@ -1563,3 +1563,121 @@ func TestFetchResultPage_PropagatesGetNextPageError(t *testing.T) { assert.ErrorContains(t, actualErr, errorMsg) } + +// TestRows_CloseCallback_ReceivesChunkCount verifies that when rows.Close() is called, +// the closeCallback receives the correct chunkCount reflecting the number of result pages +// that were fetched during iteration. +// +// This covers the fix where total_chunks_present in the telemetry payload was always null +// for paginated CloudFetch queries: the driver now derives it from r.chunkCount and passes +// it through closeCallback so connection.go can set the "chunk_total_present" tag. +func TestRows_CloseCallback_ReceivesChunkCount(t *testing.T) { + t.Parallel() + + noMoreRows := false + moreRows := true + + // Two pages: page 0 (5 rows, has more), page 1 (3 rows, no more). 
+ colVals := []*cli_service.TColumn{ + {BoolVal: &cli_service.TBoolColumn{Values: []bool{true, false, true, false, true}}}, + } + colVals2 := []*cli_service.TColumn{ + {BoolVal: &cli_service.TBoolColumn{Values: []bool{true, false, true}}}, + } + + pages := []cli_service.TFetchResultsResp{ + { + Status: &cli_service.TStatus{StatusCode: cli_service.TStatusCode_SUCCESS_STATUS}, + HasMoreRows: &moreRows, + Results: &cli_service.TRowSet{StartRowOffset: 0, Columns: colVals}, + }, + { + Status: &cli_service.TStatus{StatusCode: cli_service.TStatusCode_SUCCESS_STATUS}, + HasMoreRows: &noMoreRows, + Results: &cli_service.TRowSet{StartRowOffset: 5, Columns: colVals2}, + }, + } + + pageIndex := -1 + fetchFn := func(ctx context.Context, req *cli_service.TFetchResultsReq) (*cli_service.TFetchResultsResp, error) { + pageIndex++ + p := pages[pageIndex] + return &p, nil + } + metaFn := func(ctx context.Context, req *cli_service.TGetResultSetMetadataReq) (*cli_service.TGetResultSetMetadataResp, error) { + return &cli_service.TGetResultSetMetadataResp{ + Status: &cli_service.TStatus{StatusCode: cli_service.TStatusCode_SUCCESS_STATUS}, + Schema: &cli_service.TTableSchema{ + Columns: []*cli_service.TColumnDesc{ + {ColumnName: "flag", Position: 0, TypeDesc: &cli_service.TTypeDesc{ + Types: []*cli_service.TTypeEntry{{ + PrimitiveEntry: &cli_service.TPrimitiveTypeEntry{Type: cli_service.TTypeId_BOOLEAN_TYPE}, + }}, + }}, + }, + }, + }, nil + } + + testClient := &client.TestClient{ + FnFetchResults: fetchFn, + FnGetResultSetMetadata: metaFn, + } + + var callbackChunkCount int + closeCallback := func(latencyMs int64, chunkCount int, iterErr error, closeErr error) { + callbackChunkCount = chunkCount + } + + ctx := driverctx.NewContextWithConnId(context.Background(), "connId") + cfg := config.WithDefaults() + cfg.MaxRows = 5 // force paging + + dr, dbErr := NewRows(ctx, nil, testClient, cfg, nil, &TelemetryCallbacks{OnClose: closeCallback}) + assert.Nil(t, dbErr) + + // Drain all rows to force two FetchResults calls. + dest := make([]driver.Value, 1) + for dr.Next(dest) == nil { + } + + // Close should invoke the callback with the total chunk count (2 pages fetched). + assert.Nil(t, dr.Close()) + + // direct results count as chunk 0; two FetchResults calls give chunkCount=2. + // (No directResults here so chunkCount starts at 0, then +1 per FetchResults call.) + assert.Equal(t, 2, callbackChunkCount, + "closeCallback must receive the total number of result pages fetched") +} + +// TestRows_CloseCallback_NilDoesNotPanic verifies that passing nil for closeCallback +// does not cause a panic when rows.Close() is called. 
+func TestRows_CloseCallback_NilDoesNotPanic(t *testing.T) { + t.Parallel() + + noMoreRows := false + pages := []cli_service.TFetchResultsResp{ + { + Status: &cli_service.TStatus{StatusCode: cli_service.TStatusCode_SUCCESS_STATUS}, + HasMoreRows: &noMoreRows, + Results: &cli_service.TRowSet{StartRowOffset: 0, Columns: []*cli_service.TColumn{}}, + }, + } + pageIndex := -1 + fetchFn := func(ctx context.Context, req *cli_service.TFetchResultsReq) (*cli_service.TFetchResultsResp, error) { + pageIndex++ + p := pages[pageIndex] + return &p, nil + } + testClient := &client.TestClient{FnFetchResults: fetchFn} + + ctx := driverctx.NewContextWithConnId(context.Background(), "connId") + cfg := config.WithDefaults() + + dr, dbErr := NewRows(ctx, nil, testClient, cfg, nil, nil) + assert.Nil(t, dbErr) + + assert.NotPanics(t, func() { + _ = dr.Close() + }, "nil closeCallback must not cause a panic on rows.Close()") +} diff --git a/telemetry/aggregator.go b/telemetry/aggregator.go index 38dd636c..97c37f64 100644 --- a/telemetry/aggregator.go +++ b/telemetry/aggregator.go @@ -36,6 +36,7 @@ type metricsAggregator struct { ctx context.Context // Cancellable context — cancelled on close to stop workers cancel context.CancelFunc exportQueue chan exportJob // Worker queue; drop batch only when full (matches JDBC LinkedBlockingQueue) + inFlight sync.WaitGroup // tracks jobs submitted to exportQueue but not yet exported } // statementMetrics holds aggregated metrics for a statement. @@ -84,7 +85,10 @@ func (agg *metricsAggregator) exportWorker() { if !ok { return } - agg.exporter.export(job.ctx, job.metrics) + func() { + defer agg.inFlight.Done() + agg.exporter.export(job.ctx, job.metrics) + }() case <-agg.ctx.Done(): return } @@ -115,7 +119,7 @@ func (agg *metricsAggregator) recordMetric(ctx context.Context, metric *telemetr // Terminal operations (session/statement close) flush immediately so metrics // are not lost if the connection closes before the next batch flush — matching // JDBC behavior where CLOSE_STATEMENT and DELETE_SESSION trigger immediate export. 
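The aggregator hunks that follow all lean on one standard WaitGroup accounting pattern: Add before the channel send, Done on every path that consumes or drops the job, and a bounded Wait at shutdown. A generic, self-contained sketch of the pattern; the names here are illustrative, not the driver's:

package dispatch

import (
	"context"
	"fmt"
	"sync"
)

type dispatcher struct {
	jobs     chan int
	inFlight sync.WaitGroup
}

// submit counts the job BEFORE the channel send, so a concurrent Wait can
// never miss it; the drop path must call Done to keep the counter balanced.
func (d *dispatcher) submit(job int) bool {
	d.inFlight.Add(1)
	select {
	case d.jobs <- job:
		return true
	default:
		d.inFlight.Done() // dropped, not queued: undo the Add
		return false
	}
}

// worker marks each job done only after fully processing it.
func (d *dispatcher) worker(ctx context.Context, process func(int)) {
	for {
		select {
		case j := <-d.jobs:
			process(j)
			d.inFlight.Done()
		case <-ctx.Done():
			return
		}
	}
}

// shutdown waits for in-flight jobs but gives up when ctx expires, so a hung
// job cannot block the caller forever.
func (d *dispatcher) shutdown(ctx context.Context) {
	done := make(chan struct{})
	go func() { d.inFlight.Wait(); close(done) }()
	select {
	case <-done:
	case <-ctx.Done():
		fmt.Println("shutdown: gave up waiting for in-flight jobs")
	}
}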
- opType, _ := metric.tags["operation_type"].(string) + opType, _ := metric.tags[TagOperationType].(string) if isTerminalOperationType(opType) || len(agg.batch) >= agg.batchSize { agg.flushUnlocked(ctx) } @@ -134,13 +138,13 @@ func (agg *metricsAggregator) recordMetric(ctx context.Context, metric *telemetr // Update aggregated values stmt.totalLatency += time.Duration(metric.latencyMs) * time.Millisecond - if chunkCount, ok := metric.tags["chunk_count"].(int); ok { + if chunkCount, ok := metric.tags[TagChunkCount].(int); ok { stmt.chunkCount += chunkCount } - if bytes, ok := metric.tags["bytes_downloaded"].(int64); ok { + if bytes, ok := metric.tags[TagBytesDownloaded].(int64); ok { stmt.bytesDownloaded += bytes } - if pollCount, ok := metric.tags["poll_count"].(int); ok { + if pollCount, ok := metric.tags[TagPollCount].(int); ok { stmt.pollCount += pollCount } @@ -197,9 +201,9 @@ func (agg *metricsAggregator) completeStatement(ctx context.Context, statementID } // Add aggregated counts - metric.tags["chunk_count"] = stmt.chunkCount - metric.tags["bytes_downloaded"] = stmt.bytesDownloaded - metric.tags["poll_count"] = stmt.pollCount + metric.tags[TagChunkCount] = stmt.chunkCount + metric.tags[TagBytesDownloaded] = stmt.bytesDownloaded + metric.tags[TagPollCount] = stmt.pollCount // Add error information if failed if failed && len(stmt.errors) > 0 { @@ -265,10 +269,14 @@ func (agg *metricsAggregator) flushUnlocked(ctx context.Context) { copy(metrics, agg.batch) agg.batch = agg.batch[:0] + // Increment before send so close()'s inFlight.Wait() sees the job + // even if a worker picks it up before the drain step runs. + agg.inFlight.Add(1) select { case agg.exportQueue <- exportJob{ctx: ctx, metrics: metrics}: default: // Queue full — drop batch silently (matches JDBC's RejectedExecutionException path) + agg.inFlight.Done() // undo Add — job was dropped, not queued logger.Debug().Msg("telemetry: export queue full, dropping metrics batch") } } @@ -277,19 +285,54 @@ func (agg *metricsAggregator) flushUnlocked(ctx context.Context) { // Safe to call multiple times — subsequent calls are no-ops (closeOnce). // // Shutdown order matters: -// 1. Stop periodic flush (close stopCh) so no new async exports are queued. -// 2. Synchronously flush the current batch directly (flushSync bypasses the -// worker queue, so it works even after workers are stopped). -// 3. Cancel the aggregator context to stop the 10 export worker goroutines. +// 1. Stop periodic flush (close stopCh) so no new async flushes are queued. +// 2. Flush agg.batch synchronously (direct export, bypasses worker queue). +// 3. Drain agg.exportQueue — export any jobs still sitting in the queue +// synchronously, bypassing workers (matches their inFlight.Done() call). +// 4. Wait for jobs already picked up by workers to finish their HTTP exports. +// Without this step, cancel() in step 5 could fire while a worker is +// mid-export, causing EXECUTE_STATEMENT/CLOSE_STATEMENT to be silently lost. +// 5. Cancel the aggregator context to stop the 10 export worker goroutines. func (agg *metricsAggregator) close(ctx context.Context) error { agg.closeOnce.Do(func() { - close(agg.stopCh) // Stop periodic flush loop - agg.flushSync(ctx) // Final flush — direct export, no workers needed - agg.cancel() // Stop export workers after final flush + close(agg.stopCh) // 1. Stop periodic flush loop + agg.flushSync(ctx) // 2. Flush agg.batch directly + + // 3. Drain any jobs still sitting in the exportQueue synchronously. 
+ // Each job was counted by inFlight.Add(1) in flushUnlocked; call Done() + // here to match, since the worker won't process these jobs. + agg.drainExportQueue() + // 4. Wait for jobs already picked up by workers before the drain above. + // Workers call inFlight.Done() after their HTTP export completes. + // Use a goroutine + select so we respect the caller's context deadline; + // without this, a hung HTTP export would block conn.Close() indefinitely. + waitCh := make(chan struct{}) + go func() { agg.inFlight.Wait(); close(waitCh) }() + select { + case <-waitCh: + case <-ctx.Done(): + logger.Debug().Msg("telemetry: close timed out waiting for in-flight exports") + } + + agg.cancel() // 5. Stop export workers (all in-flight exports complete) }) return nil } +// drainExportQueue synchronously processes any jobs remaining in the export +// queue, matching inFlight.Done() for each one since workers won't handle them. +func (agg *metricsAggregator) drainExportQueue() { + for { + select { + case job := <-agg.exportQueue: + agg.exporter.export(job.ctx, job.metrics) + agg.inFlight.Done() + default: + return + } + } +} + // simpleError is a simple error implementation for testing. type simpleError struct { msg string diff --git a/telemetry/aggregator_test.go b/telemetry/aggregator_test.go new file mode 100644 index 00000000..f98c44c5 --- /dev/null +++ b/telemetry/aggregator_test.go @@ -0,0 +1,344 @@ +package telemetry + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "sync" + "sync/atomic" + "testing" + "time" +) + +// TestAggregatorClose_WaitsForInFlightWorkerExports verifies that close() does not +// return until every metric picked up by an export worker has been delivered. +// +// Regression test for the race where agg.cancel() fired while a worker was +// mid-HTTP-export, causing EXECUTE_STATEMENT / CLOSE_STATEMENT to be silently lost. +// The fix: step 4 in close() calls agg.inFlight.Wait() before agg.cancel(). +func TestAggregatorClose_WaitsForInFlightWorkerExports(t *testing.T) { + const exportDelay = 100 * time.Millisecond + + var receivedCount int32 + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + var req TelemetryRequest + if err := json.Unmarshal(body, &req); err == nil { + atomic.AddInt32(&receivedCount, int32(len(req.ProtoLogs))) + } + // Simulate slow server — forces the worker to be mid-HTTP-export when close() runs + time.Sleep(exportDelay) + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + cfg := DefaultConfig() + cfg.FlushInterval = 10 * time.Second // disable periodic flush — we flush manually + cfg.BatchSize = 1 // one metric per batch → one worker export per metric + httpClient := &http.Client{Timeout: 5 * time.Second} + + exporter := newTelemetryExporter(server.URL, "test-version", httpClient, cfg) + agg := newMetricsAggregator(exporter, cfg) + + ctx := context.Background() + + // Record 5 operation metrics — each triggers an immediate flushUnlocked (terminal op). + for i := 0; i < 5; i++ { + agg.recordMetric(ctx, &telemetryMetric{ + metricType: "operation", + timestamp: time.Now(), + tags: map[string]interface{}{ + TagOperationType: OperationTypeCloseStatement, // terminal → immediate flush + }, + }) + } + + // close() must block until all 5 in-flight worker exports complete. 
+ closeStart := time.Now()
+ _ = agg.close(ctx)
+ closeDuration := time.Since(closeStart)
+
+ // With 5 metrics and 10 workers running in parallel, close() must have waited
+ // at least one exportDelay of wall-clock time; allow 2x tolerance for jitter.
+ if closeDuration < exportDelay/2 {
+ t.Errorf("close() returned too quickly (%v); expected it to wait for in-flight exports (delay=%v)", closeDuration, exportDelay)
+ }
+
+ // All 5 metrics must have been received by the server.
+ got := atomic.LoadInt32(&receivedCount)
+ if got != 5 {
+ t.Errorf("expected 5 metrics received by server, got %d", got)
+ }
+}
+
+// TestAggregatorClose_DrainsPendingQueueJobsBeforeCancel verifies that metrics
+// sitting in the exportQueue (submitted but not yet picked up by a worker) are
+// exported synchronously during the drain phase of close(), not lost.
+func TestAggregatorClose_DrainsPendingQueueJobsBeforeCancel(t *testing.T) {
+ var mu sync.Mutex
+ var receivedLogs []string
+
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ body, _ := io.ReadAll(r.Body)
+ var req TelemetryRequest
+ if err := json.Unmarshal(body, &req); err == nil {
+ mu.Lock()
+ receivedLogs = append(receivedLogs, req.ProtoLogs...)
+ mu.Unlock()
+ }
+ w.WriteHeader(http.StatusOK)
+ }))
+ defer server.Close()
+
+ cfg := DefaultConfig()
+ cfg.FlushInterval = 10 * time.Second // no periodic flush
+ cfg.BatchSize = 100 // large batch — won't auto-flush on size
+ httpClient := &http.Client{Timeout: 5 * time.Second}
+
+ exporter := newTelemetryExporter(server.URL, "test-version", httpClient, cfg)
+ agg := newMetricsAggregator(exporter, cfg)
+
+ ctx := context.Background()
+
+ // Record a "connection" metric; connection metrics flush immediately, which
+ // places an export job on the queue. close() must export that job during its
+ // drain phase rather than drop it when the workers are cancelled.
+ agg.recordMetric(ctx, &telemetryMetric{
+ metricType: "connection",
+ timestamp: time.Now(),
+ sessionID: "drain-test-session",
+ statementID: "drain-test-stmt",
+ })
+
+ // close() should drain the queue and export the metric before returning.
+ _ = agg.close(ctx)
+
+ mu.Lock()
+ count := len(receivedLogs)
+ mu.Unlock()
+
+ if count == 0 {
+ t.Error("expected metric to be exported during drain phase of close(), got none")
+ }
+}
+
+// TestAggregatorFlushUnlocked_InFlightAddBeforeSend verifies that inFlight.Add(1) is
+// called before the job is sent to exportQueue so that close()'s inFlight.Wait()
+// cannot miss a job that a worker picks up before the drain step runs.
+func TestAggregatorFlushUnlocked_InFlightAddBeforeSend(t *testing.T) {
+ var receivedCount int32
+
+ // Server with a brief delay so workers stay busy during the close() call.
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + var req TelemetryRequest + if err := json.Unmarshal(body, &req); err == nil { + atomic.AddInt32(&receivedCount, int32(len(req.ProtoLogs))) + } + time.Sleep(20 * time.Millisecond) + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + cfg := DefaultConfig() + cfg.FlushInterval = 10 * time.Second + cfg.BatchSize = 1 + httpClient := &http.Client{Timeout: 5 * time.Second} + + exporter := newTelemetryExporter(server.URL, "test-version", httpClient, cfg) + agg := newMetricsAggregator(exporter, cfg) + ctx := context.Background() + + const numMetrics = 20 + for i := 0; i < numMetrics; i++ { + agg.recordMetric(ctx, &telemetryMetric{ + metricType: "operation", + timestamp: time.Now(), + tags: map[string]interface{}{ + TagOperationType: OperationTypeCloseStatement, + }, + }) + } + + _ = agg.close(ctx) + + got := atomic.LoadInt32(&receivedCount) + if got != numMetrics { + t.Errorf("expected %d metrics, got %d — inFlight ordering may be broken", numMetrics, got) + } +} + +// TestAggregatorClose_SafeToCallMultipleTimes verifies that calling close() multiple +// times (via sync.Once) does not panic or deadlock. +func TestAggregatorClose_SafeToCallMultipleTimes(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + cfg := DefaultConfig() + httpClient := &http.Client{Timeout: 5 * time.Second} + exporter := newTelemetryExporter(server.URL, "test-version", httpClient, cfg) + agg := newMetricsAggregator(exporter, cfg) + ctx := context.Background() + + // Call close() concurrently several times — must not panic or deadlock. + var wg sync.WaitGroup + for i := 0; i < 5; i++ { + wg.Add(1) + go func() { + defer wg.Done() + _ = agg.close(ctx) + }() + } + + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + select { + case <-done: + // pass + case <-time.After(5 * time.Second): + t.Fatal("close() deadlocked when called concurrently multiple times") + } +} + +// TestAggregatorFlushUnlocked_DropWhenQueueFull verifies that when the export queue +// is full, the batch is silently dropped and inFlight is not left incremented +// (which would cause inFlight.Wait() to block forever in close()). +// +// Strategy: cancel the aggregator context immediately so workers stop draining the queue, +// then fill the queue to capacity and call flushUnlocked once more. The drop path must +// call inFlight.Done() to undo the earlier Add so that inFlight.Wait() returns promptly. +func TestAggregatorFlushUnlocked_DropWhenQueueFull(t *testing.T) { + cfg := DefaultConfig() + cfg.FlushInterval = 10 * time.Second + cfg.BatchSize = 1 + httpClient := &http.Client{Timeout: 1 * time.Second} + + // Use a no-op exporter — we never actually export in this test. + exporter := newTelemetryExporter("http://127.0.0.1:0", "test-version", httpClient, cfg) + agg := newMetricsAggregator(exporter, cfg) + + // Cancel the aggregator context immediately so workers stop consuming from the queue. + agg.cancel() + // Give workers a moment to exit their select loop. + time.Sleep(20 * time.Millisecond) + + ctx := context.Background() + + // Fill the export queue to capacity with synthetic jobs, each paired with an inFlight.Add. 
+ for i := 0; i < exportQueueSize; i++ { + agg.inFlight.Add(1) + agg.exportQueue <- exportJob{ctx: ctx, metrics: nil} + } + + // Now call flushUnlocked — the queue is full, so the batch must be dropped. + // The drop path must call inFlight.Done() to undo the Add it made before the send attempt. + agg.mu.Lock() + agg.batch = append(agg.batch, &telemetryMetric{ + metricType: "operation", + timestamp: time.Now(), + }) + agg.flushUnlocked(ctx) + agg.mu.Unlock() + + // Drain the synthetic queue entries and release their inFlight counts. + for i := 0; i < exportQueueSize; i++ { + <-agg.exportQueue + agg.inFlight.Done() + } + + // If flushUnlocked properly called Done() on drop, inFlight counter is now at 0 + // and inFlight.Wait() must return immediately (not block forever). + waitDone := make(chan struct{}) + go func() { + agg.inFlight.Wait() + close(waitDone) + }() + + select { + case <-waitDone: + // pass — inFlight counter was properly balanced on the drop path + case <-time.After(2 * time.Second): + t.Fatal("inFlight.Wait() blocked — inFlight counter is unbalanced after queue-full drop") + } +} + +// TestAggregatorClose_RespectsContextTimeout verifies that close() returns promptly +// when the caller's context deadline expires, rather than blocking indefinitely on a +// hung HTTP export. +// +// Regression test for the scenario where a telemetry server is unresponsive and +// conn.Close() would hang forever waiting for in-flight exports to finish. +func TestAggregatorClose_RespectsContextTimeout(t *testing.T) { + const serverDelay = 5 * time.Second + + // serverGotRequest is signaled when the HTTP handler receives a request, + // confirming a worker has picked up the job and started an HTTP export. + serverGotRequest := make(chan struct{}, 1) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Signal that the worker has started the HTTP export. + select { + case serverGotRequest <- struct{}{}: + default: + } + // Simulate a hung/slow telemetry server that takes much longer than our timeout. + time.Sleep(serverDelay) + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + cfg := DefaultConfig() + cfg.FlushInterval = 10 * time.Second + cfg.BatchSize = 1 + httpClient := &http.Client{Timeout: 10 * time.Second} + + exporter := newTelemetryExporter(server.URL, "test-version", httpClient, cfg) + agg := newMetricsAggregator(exporter, cfg) + + // Record a metric that triggers an immediate flush (terminal op). + // A worker will pick it up and start a slow HTTP export. + agg.recordMetric(context.Background(), &telemetryMetric{ + metricType: "operation", + timestamp: time.Now(), + tags: map[string]interface{}{TagOperationType: OperationTypeCloseStatement}, + }) + + // Wait for the worker to actually start the HTTP request, rather than using + // a racy time.Sleep. This ensures close() enters the inFlight.Wait path + // (step 4) rather than draining the job synchronously (step 3). + select { + case <-serverGotRequest: + case <-time.After(5 * time.Second): + t.Fatal("worker did not pick up the job and start HTTP export in time") + } + + // Call close() with a short timeout — it must return when the context expires, + // NOT wait for the full serverDelay. + timeout := 200 * time.Millisecond + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + closeStart := time.Now() + _ = agg.close(ctx) + closeDuration := time.Since(closeStart) + + // close() must have returned near the timeout, not after serverDelay. 
+ if closeDuration >= serverDelay { + t.Errorf("close() blocked for %v (server delay %v); expected it to respect the %v context timeout", closeDuration, serverDelay, timeout) + } + // Allow some slack but it should be well under the server delay. + if closeDuration > 1*time.Second { + t.Errorf("close() took %v; expected it to return near the %v context timeout", closeDuration, timeout) + } +} diff --git a/telemetry/benchmark_test.go b/telemetry/benchmark_test.go index d4687a08..ea20e1fe 100644 --- a/telemetry/benchmark_test.go +++ b/telemetry/benchmark_test.go @@ -77,7 +77,7 @@ func BenchmarkAggregator_RecordMetric(b *testing.B) { sessionID: "bench-session", statementID: "bench-stmt", latencyMs: 10, - tags: map[string]interface{}{"operation_type": OperationTypeExecuteStatement}, + tags: map[string]interface{}{TagOperationType: OperationTypeExecuteStatement}, } b.ResetTimer() @@ -300,7 +300,7 @@ func TestGracefulShutdown_FinalFlush(t *testing.T) { timestamp: time.Now(), sessionID: "test-session", latencyMs: int64(i), - tags: map[string]interface{}{"operation_type": OperationTypeExecuteStatement}, + tags: map[string]interface{}{TagOperationType: OperationTypeExecuteStatement}, }) } diff --git a/telemetry/config.go b/telemetry/config.go index 1238333f..9054cb36 100644 --- a/telemetry/config.go +++ b/telemetry/config.go @@ -5,6 +5,18 @@ import ( "net/http" "strconv" "time" + + "github.com/databricks/databricks-sql-go/logger" +) + +const ( + // maxTelemetryRetryCount caps DSN-provided retry count to prevent + // excessive retries from misconfiguration. + maxTelemetryRetryCount = 10 + + // maxTelemetryRetryDelay caps DSN-provided retry delay to prevent + // excessively long backoff from misconfiguration. + maxTelemetryRetryDelay = 30 * time.Second ) // Config holds telemetry configuration. @@ -12,9 +24,11 @@ type Config struct { // Enabled controls whether telemetry is active Enabled bool - // EnableTelemetry indicates user wants telemetry enabled. - // Follows client > server > default priority. - EnableTelemetry bool + // EnableTelemetry is a tristate for the client DSN setting: + // nil — not set by the client; server feature flag controls enablement + // &true — client explicitly opted in (overrides server flag) + // &false— client explicitly opted out (overrides server flag) + EnableTelemetry *bool // BatchSize is the number of metrics to batch before flushing BatchSize int @@ -39,12 +53,18 @@ type Config struct { } // DefaultConfig returns default telemetry configuration. -// Note: Telemetry is disabled by default. The default will remain false until -// server-side feature flags are wired in to control the rollout. +// +// BEHAVIORAL NOTE (SDR-approved): When EnableTelemetry is nil (the default), +// telemetry enablement is controlled by the server-side feature flag +// (databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForGoDriver). +// This means telemetry may be active without the user explicitly opting in. +// The user can always override by setting enableTelemetry=true or enableTelemetry=false +// in the DSN or via WithEnableTelemetry(). No PII is collected; only aggregate +// driver performance metrics are sent to the Databricks telemetry endpoint. 
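The tristate rule in the note above reduces to a few lines of code. A compact sketch, using a hypothetical resolveTelemetry helper rather than the driver's actual API (the real logic is isTelemetryEnabled, shown below):

// Hypothetical helper, for illustration only.
func resolveTelemetry(clientSetting *bool, serverFlag func() (bool, error)) bool {
	if clientSetting != nil {
		return *clientSetting // explicit DSN opt-in/opt-out always wins
	}
	enabled, err := serverFlag()
	if err != nil {
		return false // fail closed: a flag lookup failure must not affect the driver
	}
	return enabled
}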
func DefaultConfig() *Config { return &Config{ Enabled: false, - EnableTelemetry: false, + EnableTelemetry: nil, // unset — server feature flag decides BatchSize: 100, FlushInterval: 5 * time.Second, MaxRetries: 3, @@ -61,7 +81,7 @@ func ParseTelemetryConfig(params map[string]string) *Config { if v, ok := params["enableTelemetry"]; ok { if b, err := strconv.ParseBool(v); err == nil { - cfg.EnableTelemetry = b + cfg.EnableTelemetry = &b // non-nil: client explicitly set via DSN } } @@ -72,50 +92,48 @@ func ParseTelemetryConfig(params map[string]string) *Config { } if v, ok := params["telemetry_flush_interval"]; ok { - if duration, err := time.ParseDuration(v); err == nil { + if duration, err := time.ParseDuration(v); err == nil && duration > 0 { cfg.FlushInterval = duration } } + if v, ok := params["telemetry_retry_count"]; ok { + if n, err := strconv.Atoi(v); err == nil && n >= 0 { + if n > maxTelemetryRetryCount { + logger.Debug().Msgf("telemetry: retry_count %d exceeds max %d, clamping", n, maxTelemetryRetryCount) + n = maxTelemetryRetryCount + } + cfg.MaxRetries = n + } + } + + if v, ok := params["telemetry_retry_delay"]; ok { + if d, err := time.ParseDuration(v); err == nil && d > 0 { + if d > maxTelemetryRetryDelay { + logger.Debug().Msgf("telemetry: retry_delay %v exceeds max %v, clamping", d, maxTelemetryRetryDelay) + d = maxTelemetryRetryDelay + } + cfg.RetryDelay = d + } + } + return cfg } -// isTelemetryEnabled checks if telemetry should be enabled for this connection. -// Implements the priority-based decision tree for telemetry enablement. -// -// Priority (highest to lowest): -// 1. enableTelemetry=true - Client opt-in (server feature flag still consulted) -// 2. enableTelemetry=false - Explicit opt-out (always disabled) -// 3. Server Feature Flag Only - Default behavior (Databricks-controlled) -// 4. Default - Disabled (false) +// isTelemetryEnabled returns true in exactly two cases: +// 1. The client explicitly set enableTelemetry=true in the DSN. +// 2. The client did not set enableTelemetry and the server feature flag is enabled +// (databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForGoDriver). // -// Parameters: -// - ctx: Context for the request -// - cfg: Telemetry configuration -// - host: Databricks host to check feature flags against -// - httpClient: HTTP client for making feature flag requests -// -// Returns: -// - bool: true if telemetry should be enabled, false otherwise +// In all other cases — explicit opt-out or server flag absent/unreachable — returns false. 
 func isTelemetryEnabled(ctx context.Context, cfg *Config, host string, driverVersion string, httpClient *http.Client) bool {
-	// Priority 1 & 2: Respect client preference when explicitly set
-	// enableTelemetry=false → always disabled; enableTelemetry=true → check server flag
-	// When enableTelemetry is explicitly set to false, respect that
-	if !cfg.EnableTelemetry {
-		return false
+	if cfg.EnableTelemetry != nil {
+		return *cfg.EnableTelemetry
 	}
 
-	// Priority 3 & 4: Check server-side feature flag
-	// This handles both:
-	// - User explicitly opted in (enableTelemetry=true) - respect server decision
-	// - Default behavior (no explicit setting) - server controls enablement
-	flagCache := getFeatureFlagCache()
-	serverEnabled, err := flagCache.isTelemetryEnabled(ctx, host, driverVersion, httpClient)
+	serverEnabled, err := getFeatureFlagCache().isTelemetryEnabled(ctx, host, driverVersion, httpClient)
 	if err != nil {
-		// On error, respect default (disabled)
-		// This ensures telemetry failures don't impact driver operation
 		return false
 	}
-
 	return serverEnabled
 }
diff --git a/telemetry/config_test.go b/telemetry/config_test.go
index 55f24de3..6806b184 100644
--- a/telemetry/config_test.go
+++ b/telemetry/config_test.go
@@ -11,93 +11,61 @@ import (
 func TestDefaultConfig(t *testing.T) {
 	cfg := DefaultConfig()
 
-	// Verify telemetry is disabled by default
 	if cfg.Enabled {
 		t.Error("Expected telemetry to be disabled by default, got enabled")
 	}
-
-	// Verify other defaults
+	if cfg.EnableTelemetry != nil {
+		t.Error("Expected EnableTelemetry to be nil (unset) by default")
+	}
 	if cfg.BatchSize != 100 {
 		t.Errorf("Expected BatchSize 100, got %d", cfg.BatchSize)
 	}
-
 	if cfg.FlushInterval != 5*time.Second {
 		t.Errorf("Expected FlushInterval 5s, got %v", cfg.FlushInterval)
 	}
-
 	if cfg.MaxRetries != 3 {
 		t.Errorf("Expected MaxRetries 3, got %d", cfg.MaxRetries)
 	}
-
 	if cfg.RetryDelay != 100*time.Millisecond {
 		t.Errorf("Expected RetryDelay 100ms, got %v", cfg.RetryDelay)
 	}
-
 	if !cfg.CircuitBreakerEnabled {
 		t.Error("Expected CircuitBreakerEnabled true, got false")
 	}
-
 	if cfg.CircuitBreakerThreshold != 5 {
 		t.Errorf("Expected CircuitBreakerThreshold 5, got %d", cfg.CircuitBreakerThreshold)
 	}
-
 	if cfg.CircuitBreakerTimeout != 1*time.Minute {
 		t.Errorf("Expected CircuitBreakerTimeout 1m, got %v", cfg.CircuitBreakerTimeout)
 	}
 }
 
-func TestParseTelemetryConfig_EmptyParams(t *testing.T) {
-	params := map[string]string{}
-	cfg := ParseTelemetryConfig(params)
-
-	// Should return defaults
-	if cfg.Enabled {
-		t.Error("Expected telemetry to be disabled by default")
-	}
-
-	if cfg.BatchSize != 100 {
-		t.Errorf("Expected BatchSize 100, got %d", cfg.BatchSize)
-	}
-}
-
 func TestParseTelemetryConfig_EnabledTrue(t *testing.T) {
-	params := map[string]string{
-		"enableTelemetry": "true",
-	}
-	cfg := ParseTelemetryConfig(params)
+	cfg := ParseTelemetryConfig(map[string]string{"enableTelemetry": "true"})
 
-	if !cfg.EnableTelemetry {
-		t.Error("Expected EnableTelemetry to be true when set to 'true'")
+	if cfg.EnableTelemetry == nil || !*cfg.EnableTelemetry {
+		t.Error("Expected EnableTelemetry to be &true when set to 'true'")
 	}
 }
 
 func TestParseTelemetryConfig_Enabled1(t *testing.T) {
-	params := map[string]string{
-		"enableTelemetry": "1",
-	}
-	cfg := ParseTelemetryConfig(params)
+	cfg := ParseTelemetryConfig(map[string]string{"enableTelemetry": "1"})
 
-	if !cfg.EnableTelemetry {
-		t.Error("Expected EnableTelemetry to be true when set to '1'")
+	if cfg.EnableTelemetry == nil || !*cfg.EnableTelemetry {
+		t.Error("Expected EnableTelemetry to be &true when set to '1'")
 	}
 }
 
 func TestParseTelemetryConfig_EnabledFalse(t *testing.T) {
-	params := map[string]string{
-		"enableTelemetry": "false",
-	}
-	cfg := ParseTelemetryConfig(params)
+	cfg := ParseTelemetryConfig(map[string]string{"enableTelemetry": "false"})
 
-	if cfg.EnableTelemetry {
-		t.Error("Expected EnableTelemetry to be false when set to 'false'")
+	if cfg.EnableTelemetry == nil || *cfg.EnableTelemetry {
+		t.Error("Expected EnableTelemetry to be &false when set to 'false'")
 	}
 }
 
 func TestParseTelemetryConfig_BatchSize(t *testing.T) {
-	params := map[string]string{
-		"telemetry_batch_size": "50",
-	}
-	cfg := ParseTelemetryConfig(params)
+	cfg := ParseTelemetryConfig(map[string]string{"telemetry_batch_size": "50"})
 
 	if cfg.BatchSize != 50 {
 		t.Errorf("Expected BatchSize 50, got %d", cfg.BatchSize)
@@ -105,46 +73,31 @@ func TestParseTelemetryConfig_BatchSize(t *testing.T) {
 }
 
 func TestParseTelemetryConfig_BatchSizeInvalid(t *testing.T) {
-	params := map[string]string{
-		"telemetry_batch_size": "invalid",
-	}
-	cfg := ParseTelemetryConfig(params)
+	cfg := ParseTelemetryConfig(map[string]string{"telemetry_batch_size": "invalid"})
 
-	// Should fall back to default
 	if cfg.BatchSize != 100 {
 		t.Errorf("Expected BatchSize to fallback to 100, got %d", cfg.BatchSize)
 	}
 }
 
 func TestParseTelemetryConfig_BatchSizeZero(t *testing.T) {
-	params := map[string]string{
-		"telemetry_batch_size": "0",
-	}
-	cfg := ParseTelemetryConfig(params)
+	cfg := ParseTelemetryConfig(map[string]string{"telemetry_batch_size": "0"})
 
-	// Should ignore zero and use default
 	if cfg.BatchSize != 100 {
 		t.Errorf("Expected BatchSize to fallback to 100 when zero, got %d", cfg.BatchSize)
 	}
 }
 
 func TestParseTelemetryConfig_BatchSizeNegative(t *testing.T) {
-	params := map[string]string{
-		"telemetry_batch_size": "-10",
-	}
-	cfg := ParseTelemetryConfig(params)
+	cfg := ParseTelemetryConfig(map[string]string{"telemetry_batch_size": "-10"})
 
-	// Should ignore negative and use default
 	if cfg.BatchSize != 100 {
 		t.Errorf("Expected BatchSize to fallback to 100 when negative, got %d", cfg.BatchSize)
 	}
 }
 
 func TestParseTelemetryConfig_FlushInterval(t *testing.T) {
-	params := map[string]string{
-		"telemetry_flush_interval": "10s",
-	}
-	cfg := ParseTelemetryConfig(params)
+	cfg := ParseTelemetryConfig(map[string]string{"telemetry_flush_interval": "10s"})
 
 	if cfg.FlushInterval != 10*time.Second {
 		t.Errorf("Expected FlushInterval 10s, got %v", cfg.FlushInterval)
@@ -152,213 +105,199 @@ func TestParseTelemetryConfig_FlushInterval(t *testing.T) {
 }
 
 func TestParseTelemetryConfig_FlushIntervalInvalid(t *testing.T) {
-	params := map[string]string{
-		"telemetry_flush_interval": "invalid",
-	}
-	cfg := ParseTelemetryConfig(params)
+	cfg := ParseTelemetryConfig(map[string]string{"telemetry_flush_interval": "invalid"})
 
-	// Should fall back to default
 	if cfg.FlushInterval != 5*time.Second {
 		t.Errorf("Expected FlushInterval to fallback to 5s, got %v", cfg.FlushInterval)
 	}
 }
 
-func TestParseTelemetryConfig_MultipleParams(t *testing.T) {
-	params := map[string]string{
-		"enableTelemetry":          "true",
-		"telemetry_batch_size":     "200",
-		"telemetry_flush_interval": "30s",
+func TestParseTelemetryConfig_RetryCount(t *testing.T) {
+	cfg := ParseTelemetryConfig(map[string]string{"telemetry_retry_count": "5"})
+
+	if cfg.MaxRetries != 5 {
+		t.Errorf("Expected MaxRetries 5, got %d", cfg.MaxRetries)
 	}
-	cfg := ParseTelemetryConfig(params)
+}
 
-	if !cfg.EnableTelemetry {
-		t.Error("Expected EnableTelemetry to be true")
+func TestParseTelemetryConfig_RetryCountZero(t *testing.T) {
+	// Zero is valid — it disables retries entirely (unlike batch_size where zero is rejected)
+	cfg := ParseTelemetryConfig(map[string]string{"telemetry_retry_count": "0"})
+
+	if cfg.MaxRetries != 0 {
+		t.Errorf("Expected MaxRetries 0 (disable retries), got %d", cfg.MaxRetries)
 	}
+}
 
-	if cfg.BatchSize != 200 {
-		t.Errorf("Expected BatchSize 200, got %d", cfg.BatchSize)
+func TestParseTelemetryConfig_RetryCountInvalid(t *testing.T) {
+	cfg := ParseTelemetryConfig(map[string]string{"telemetry_retry_count": "invalid"})
+
+	if cfg.MaxRetries != 3 {
+		t.Errorf("Expected MaxRetries to fallback to 3, got %d", cfg.MaxRetries)
 	}
+}
 
-	if cfg.FlushInterval != 30*time.Second {
-		t.Errorf("Expected FlushInterval 30s, got %v", cfg.FlushInterval)
+func TestParseTelemetryConfig_RetryDelay(t *testing.T) {
+	cfg := ParseTelemetryConfig(map[string]string{"telemetry_retry_delay": "500ms"})
+
+	if cfg.RetryDelay != 500*time.Millisecond {
+		t.Errorf("Expected RetryDelay 500ms, got %v", cfg.RetryDelay)
 	}
+}
 
-	// Other fields should still have defaults
-	if cfg.MaxRetries != 3 {
-		t.Errorf("Expected MaxRetries to remain default 3, got %d", cfg.MaxRetries)
+func TestParseTelemetryConfig_RetryDelayInvalid(t *testing.T) {
+	cfg := ParseTelemetryConfig(map[string]string{"telemetry_retry_delay": "invalid"})
+
+	if cfg.RetryDelay != 100*time.Millisecond {
+		t.Errorf("Expected RetryDelay to fallback to 100ms, got %v", cfg.RetryDelay)
 	}
 }
 
-// TestIsTelemetryEnabled_ExplicitOptOut tests Priority 1 (client opt-out): enableTelemetry=false
-func TestIsTelemetryEnabled_ExplicitOptOut(t *testing.T) {
-	// Setup: Create a server that returns enabled
-	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		// Even if server says enabled, explicit opt-out should disable
-		w.Header().Set("Content-Type", "application/json")
-		_, _ = w.Write([]byte(`{"flags": [{"name": "databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForGoDriver", "value": "true"}], "ttl_seconds": 300}`))
-	}))
-	defer server.Close()
+func TestParseTelemetryConfig_RetryCountExceedsCap(t *testing.T) {
+	cfg := ParseTelemetryConfig(map[string]string{"telemetry_retry_count": "15"})
+	if cfg.MaxRetries != maxTelemetryRetryCount {
+		t.Errorf("Expected MaxRetries clamped to %d, got %d", maxTelemetryRetryCount, cfg.MaxRetries)
+	}
+}
 
-	cfg := &Config{
-		EnableTelemetry: false, // Priority 2: Explicit opt-out
+func TestParseTelemetryConfig_RetryCountAtCap(t *testing.T) {
+	cfg := ParseTelemetryConfig(map[string]string{"telemetry_retry_count": "10"})
+	if cfg.MaxRetries != 10 {
+		t.Errorf("Expected MaxRetries 10, got %d", cfg.MaxRetries)
 	}
+}
 
-	ctx := context.Background()
-	httpClient := &http.Client{Timeout: 5 * time.Second}
+func TestParseTelemetryConfig_RetryDelayExceedsCap(t *testing.T) {
+	cfg := ParseTelemetryConfig(map[string]string{"telemetry_retry_delay": "60s"})
+	if cfg.RetryDelay != maxTelemetryRetryDelay {
+		t.Errorf("Expected RetryDelay clamped to %v, got %v", maxTelemetryRetryDelay, cfg.RetryDelay)
+	}
+}
 
-	result := isTelemetryEnabled(ctx, cfg, server.URL, "test-version", httpClient)
+func TestParseTelemetryConfig_RetryDelayAtCap(t *testing.T) {
+	cfg := ParseTelemetryConfig(map[string]string{"telemetry_retry_delay": "30s"})
+	if cfg.RetryDelay != 30*time.Second {
+		t.Errorf("Expected RetryDelay 30s, got %v", cfg.RetryDelay)
+	}
+}
 
-	if result {
-		t.Error("Expected telemetry to be disabled with EnableTelemetry=false, got enabled")
+func TestParseTelemetryConfig_AllParams(t *testing.T) {
+	cfg := ParseTelemetryConfig(map[string]string{
+		"enableTelemetry":          "true",
+		"telemetry_batch_size":     "200",
+		"telemetry_flush_interval": "30s",
+		"telemetry_retry_count":    "5",
+		"telemetry_retry_delay":    "250ms",
+	})
+
+	if cfg.EnableTelemetry == nil || !*cfg.EnableTelemetry {
+		t.Error("Expected EnableTelemetry to be &true")
+	}
+	if cfg.BatchSize != 200 {
+		t.Errorf("Expected BatchSize 200, got %d", cfg.BatchSize)
+	}
+	if cfg.FlushInterval != 30*time.Second {
+		t.Errorf("Expected FlushInterval 30s, got %v", cfg.FlushInterval)
+	}
+	if cfg.MaxRetries != 5 {
+		t.Errorf("Expected MaxRetries 5, got %d", cfg.MaxRetries)
+	}
+	if cfg.RetryDelay != 250*time.Millisecond {
+		t.Errorf("Expected RetryDelay 250ms, got %v", cfg.RetryDelay)
 	}
 }
 
-// TestIsTelemetryEnabled_UserOptInServerEnabled tests Priority 1 (client opt-in): user opts in + server enabled
-func TestIsTelemetryEnabled_UserOptInServerEnabled(t *testing.T) {
-	// Setup: Create a server that returns enabled
+// TestIsTelemetryEnabled_ExplicitOptOut: client sets enableTelemetry=false →
+// disabled even when server flag is true. Server is not consulted.
+func TestIsTelemetryEnabled_ExplicitOptOut(t *testing.T) {
 	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		w.Header().Set("Content-Type", "application/json")
 		_, _ = w.Write([]byte(`{"flags": [{"name": "databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForGoDriver", "value": "true"}], "ttl_seconds": 300}`))
 	}))
 	defer server.Close()
 
-	cfg := &Config{
-		EnableTelemetry: true, // User wants telemetry
-	}
-
-	ctx := context.Background()
-	httpClient := &http.Client{Timeout: 5 * time.Second}
+	result := isTelemetryEnabled(context.Background(), &Config{EnableTelemetry: boolPtr(false)}, server.URL, "test-version", &http.Client{Timeout: 5 * time.Second})
 
-	// Setup feature flag cache context
-	flagCache := getFeatureFlagCache()
-	flagCache.getOrCreateContext(server.URL)
-	defer flagCache.releaseContext(server.URL)
+	if result {
+		t.Error("Expected telemetry to be disabled when client sets enableTelemetry=false, got enabled")
+	}
+}
 
-	result := isTelemetryEnabled(ctx, cfg, server.URL, "test-version", httpClient)
+// TestIsTelemetryEnabled_ExplicitOptIn: client sets enableTelemetry=true →
+// enabled without any server call (unreachable host proves no network call is made).
+func TestIsTelemetryEnabled_ExplicitOptIn(t *testing.T) {
+	result := isTelemetryEnabled(context.Background(), &Config{EnableTelemetry: boolPtr(true)}, "http://unreachable-host", "test-version", &http.Client{Timeout: 5 * time.Second})
 
 	if !result {
-		t.Error("Expected telemetry to be enabled when user opts in and server allows, got disabled")
+		t.Error("Expected telemetry to be enabled when client sets enableTelemetry=true, got disabled")
 	}
 }
 
-// TestIsTelemetryEnabled_UserOptInServerDisabled tests: user opts in but server disabled
-func TestIsTelemetryEnabled_UserOptInServerDisabled(t *testing.T) {
-	// Setup: Create a server that returns disabled
+// TestIsTelemetryEnabled_ServerEnabled: no DSN override, server flag=true → enabled.
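+// The httptest server stands in for the feature-flag endpoint; the
+// getOrCreateContext/releaseContext pair below mirrors the reference counting
+// that InitializeForConnection performs around the flag-cache lookup.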
+func TestIsTelemetryEnabled_ServerEnabled(t *testing.T) {
 	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		w.Header().Set("Content-Type", "application/json")
-		_, _ = w.Write([]byte(`{"flags": [{"name": "databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForGoDriver", "value": "false"}], "ttl_seconds": 300}`))
+		_, _ = w.Write([]byte(`{"flags": [{"name": "databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForGoDriver", "value": "true"}], "ttl_seconds": 300}`))
 	}))
 	defer server.Close()
 
-	cfg := &Config{
-		EnableTelemetry: true, // User wants telemetry
-	}
-
-	ctx := context.Background()
-	httpClient := &http.Client{Timeout: 5 * time.Second}
-
-	// Setup feature flag cache context
 	flagCache := getFeatureFlagCache()
 	flagCache.getOrCreateContext(server.URL)
 	defer flagCache.releaseContext(server.URL)
 
-	result := isTelemetryEnabled(ctx, cfg, server.URL, "test-version", httpClient)
+	result := isTelemetryEnabled(context.Background(), &Config{}, server.URL, "test-version", &http.Client{Timeout: 5 * time.Second})
 
-	if result {
-		t.Error("Expected telemetry to be disabled when server disables it, got enabled")
+	if !result {
+		t.Error("Expected telemetry to be enabled when server flag is true and EnableTelemetry is nil, got disabled")
 	}
 }
 
-// TestIsTelemetryEnabled_ServerFlagOnly tests: default EnableTelemetry=false is always disabled
-func TestIsTelemetryEnabled_ServerFlagOnly(t *testing.T) {
-	// Setup: Create a server that returns enabled
+// TestIsTelemetryEnabled_ServerDisabled: no DSN override, server flag=false → disabled.
+func TestIsTelemetryEnabled_ServerDisabled(t *testing.T) {
 	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		w.Header().Set("Content-Type", "application/json")
-		_, _ = w.Write([]byte(`{"flags": [{"name": "databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForGoDriver", "value": "true"}], "ttl_seconds": 300}`))
+		_, _ = w.Write([]byte(`{"flags": [{"name": "databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForGoDriver", "value": "false"}], "ttl_seconds": 300}`))
 	}))
 	defer server.Close()
 
-	cfg := &Config{
-		EnableTelemetry: false, // Default: no explicit user preference
-	}
-
-	ctx := context.Background()
-	httpClient := &http.Client{Timeout: 5 * time.Second}
-
-	// Setup feature flag cache context
 	flagCache := getFeatureFlagCache()
 	flagCache.getOrCreateContext(server.URL)
 	defer flagCache.releaseContext(server.URL)
 
-	result := isTelemetryEnabled(ctx, cfg, server.URL, "test-version", httpClient)
-
-	// When enableTelemetry is false (default), should return false (Priority 2)
-	if result {
-		t.Error("Expected telemetry to be disabled with default EnableTelemetry=false, got enabled")
-	}
-}
-
-// TestIsTelemetryEnabled_Default tests Priority 5: default disabled
-func TestIsTelemetryEnabled_Default(t *testing.T) {
-	cfg := DefaultConfig()
-
-	ctx := context.Background()
-	httpClient := &http.Client{Timeout: 5 * time.Second}
-
-	result := isTelemetryEnabled(ctx, cfg, "test-host", "test-version", httpClient)
+	result := isTelemetryEnabled(context.Background(), &Config{}, server.URL, "test-version", &http.Client{Timeout: 5 * time.Second})
 
 	if result {
-		t.Error("Expected telemetry to be disabled by default, got enabled")
+		t.Error("Expected telemetry to be disabled when server flag is false and EnableTelemetry is nil, got enabled")
 	}
 }
 
-// TestIsTelemetryEnabled_ServerError tests error handling
+// TestIsTelemetryEnabled_ServerError: no DSN override, server returns 500 → disabled.
 func TestIsTelemetryEnabled_ServerError(t *testing.T) {
-	// Setup: Create a server that returns error
 	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		w.WriteHeader(http.StatusInternalServerError)
 	}))
 	defer server.Close()
 
-	cfg := &Config{
-		EnableTelemetry: true, // User wants telemetry
-	}
-
-	ctx := context.Background()
-	httpClient := &http.Client{Timeout: 5 * time.Second}
-
-	// Setup feature flag cache context
 	flagCache := getFeatureFlagCache()
 	flagCache.getOrCreateContext(server.URL)
 	defer flagCache.releaseContext(server.URL)
 
-	result := isTelemetryEnabled(ctx, cfg, server.URL, "test-version", httpClient)
+	result := isTelemetryEnabled(context.Background(), &Config{}, server.URL, "test-version", &http.Client{Timeout: 5 * time.Second})
 
-	// On error, should default to disabled
 	if result {
-		t.Error("Expected telemetry to be disabled on server error, got enabled")
+		t.Error("Expected telemetry to be disabled when server errors and EnableTelemetry is nil, got enabled")
 	}
 }
 
-// TestIsTelemetryEnabled_ServerUnreachable tests unreachable server
+// TestIsTelemetryEnabled_ServerUnreachable: no DSN override, server unreachable → disabled.
 func TestIsTelemetryEnabled_ServerUnreachable(t *testing.T) {
-	cfg := &Config{
-		EnableTelemetry: true, // User wants telemetry
-	}
-
-	ctx := context.Background()
-	httpClient := &http.Client{Timeout: 1 * time.Second}
-
-	// Setup feature flag cache context with unreachable host
 	flagCache := getFeatureFlagCache()
-	unreachableHost := "http://localhost:9999"
-	flagCache.getOrCreateContext(unreachableHost)
-	defer flagCache.releaseContext(unreachableHost)
+	flagCache.getOrCreateContext("http://localhost:9999")
+	defer flagCache.releaseContext("http://localhost:9999")
 
-	result := isTelemetryEnabled(ctx, cfg, unreachableHost, "test-version", httpClient)
+	result := isTelemetryEnabled(context.Background(), &Config{}, "http://localhost:9999", "test-version", &http.Client{Timeout: 1 * time.Second})
 
-	// On error, should default to disabled
 	if result {
-		t.Error("Expected telemetry to be disabled when server unreachable, got enabled")
+		t.Error("Expected telemetry to be disabled when server is unreachable and EnableTelemetry is nil, got enabled")
 	}
 }
diff --git a/telemetry/driver_integration.go b/telemetry/driver_integration.go
index 743b4107..f8b32ffc 100644
--- a/telemetry/driver_integration.go
+++ b/telemetry/driver_integration.go
@@ -8,63 +8,88 @@ import (
 	"github.com/databricks/databricks-sql-go/internal/config"
 )
 
+// TelemetryInitOptions bundles the parameters for InitializeForConnection.
+type TelemetryInitOptions struct {
+	// Host is the Databricks host.
+	Host string
+
+	// DriverVersion is the driver version string.
+	DriverVersion string
+
+	// HTTPClient is the HTTP client used for both feature-flag checks and
+	// telemetry export. The /telemetry-ext endpoint requires authentication,
+	// so this should be the authenticated driver client.
+	HTTPClient *http.Client
+
+	// EnableTelemetry is a tristate from the client DSN:
+	//	unset — server feature flag controls enablement
+	//	true  — client explicitly opted in
+	//	false — client explicitly opted out
+	EnableTelemetry config.ConfigValue[bool]
+
+	// BatchSize is the number of metrics per batch (0 = use default 100).
+	BatchSize int
+
+	// FlushInterval is the flush interval (0 = use default 5s).
+	FlushInterval time.Duration
+
+	// RetryCount is the maximum number of retry attempts (-1 = use default 3;
+	// 0 = disable retries).
+	// IMPORTANT: Go's zero value for int is 0, which disables retries. Callers
+	// constructing TelemetryInitOptions must set RetryCount = -1 explicitly to
+	// get the default retry behavior.
+	RetryCount int
+
+	// RetryDelay is the base delay between retries (0 = use default 100ms).
+	RetryDelay time.Duration
+}
+
 // InitializeForConnection initializes telemetry for a database connection.
 // Returns an Interceptor if telemetry is enabled, nil otherwise.
 // This function handles all the logic for checking feature flags and creating the interceptor.
-//
-// Parameters:
-// - ctx: Context for the initialization
-// - host: Databricks host
-// - driverVersion: Driver version string
-// - httpClient: HTTP client for making requests
-// - enableTelemetry: Client config overlay (unset = check server flag, true/false = override server)
-//
-// Returns:
-// - *Interceptor: Telemetry interceptor if enabled, nil otherwise
-func InitializeForConnection(
-	ctx context.Context,
-	host string,
-	driverVersion string,
-	httpClient *http.Client,
-	enableTelemetry config.ConfigValue[bool],
-	batchSize int,
-	flushInterval time.Duration,
-) *Interceptor {
+func InitializeForConnection(ctx context.Context, opts TelemetryInitOptions) *Interceptor {
 	// Create telemetry config and apply client overlay.
-	// ConfigValue[bool] semantics:
-	// - unset → true (let server feature flag decide)
-	// - true → true (server feature flag still consulted)
-	// - false → false (explicitly disabled, skip server flag check)
+	// Priority: client DSN > server feature flag > default (disabled).
 	cfg := DefaultConfig()
-	if val, isSet := enableTelemetry.Get(); isSet {
-		cfg.EnableTelemetry = val
-	} else {
-		cfg.EnableTelemetry = true // Unset: default to enabled, server flag decides
+	if val, isSet := opts.EnableTelemetry.Get(); isSet {
+		cfg.EnableTelemetry = &val // non-nil: client explicitly set via DSN
 	}
+	// When unset: cfg.EnableTelemetry remains nil, server feature flag controls enablement.
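+	// The numeric overrides below mirror ParseTelemetryConfig: zero (or, for
+	// RetryCount, negative) values keep the defaults, and retry settings are
+	// clamped to maxTelemetryRetryCount / maxTelemetryRetryDelay.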
+	if opts.BatchSize > 0 {
+		cfg.BatchSize = opts.BatchSize
+	}
+	if opts.FlushInterval > 0 {
+		cfg.FlushInterval = opts.FlushInterval
 	}
 
-	if batchSize > 0 {
-		cfg.BatchSize = batchSize
+	if opts.RetryCount >= 0 {
+		cfg.MaxRetries = opts.RetryCount
+		if cfg.MaxRetries > maxTelemetryRetryCount {
+			cfg.MaxRetries = maxTelemetryRetryCount
+		}
 	}
 
-	if flushInterval > 0 {
-		cfg.FlushInterval = flushInterval
+	if opts.RetryDelay > 0 {
+		cfg.RetryDelay = opts.RetryDelay
+		if cfg.RetryDelay > maxTelemetryRetryDelay {
+			cfg.RetryDelay = maxTelemetryRetryDelay
+		}
 	}
 
 	// Get feature flag cache context FIRST (for reference counting)
 	flagCache := getFeatureFlagCache()
-	flagCache.getOrCreateContext(host)
+	flagCache.getOrCreateContext(opts.Host)
 
 	// Check if telemetry should be enabled
-	enabled := isTelemetryEnabled(ctx, cfg, host, driverVersion, httpClient)
+	enabled := isTelemetryEnabled(ctx, cfg, opts.Host, opts.DriverVersion, opts.HTTPClient)
 	if !enabled {
-		flagCache.releaseContext(host)
+		flagCache.releaseContext(opts.Host)
 		return nil
 	}
 
 	// Get or create telemetry client for this host
 	clientMgr := getClientManager()
-	telemetryClient := clientMgr.getOrCreateClient(host, driverVersion, httpClient, cfg)
+	telemetryClient := clientMgr.getOrCreateClient(opts.Host, opts.DriverVersion, opts.HTTPClient, cfg)
 	if telemetryClient == nil {
 		// Client failed to start; release the flag cache ref we incremented above
-		flagCache.releaseContext(host)
+		flagCache.releaseContext(opts.Host)
 		return nil
 	}
diff --git a/telemetry/exporter.go b/telemetry/exporter.go
index 53fbc14a..3ecbf81f 100644
--- a/telemetry/exporter.go
+++ b/telemetry/exporter.go
@@ -41,18 +41,6 @@ type telemetryMetric struct {
 	tags map[string]interface{}
 }
 
-// exportedMetric is a single metric in the payload.
-type exportedMetric struct {
-	MetricType  string                 `json:"metric_type"`
-	Timestamp   string                 `json:"timestamp"` // RFC3339
-	WorkspaceID string                 `json:"workspace_id,omitempty"`
-	SessionID   string                 `json:"session_id,omitempty"`
-	StatementID string                 `json:"statement_id,omitempty"`
-	LatencyMs   int64                  `json:"latency_ms,omitempty"`
-	ErrorType   string                 `json:"error_type,omitempty"`
-	Tags        map[string]interface{} `json:"tags,omitempty"`
-}
-
 // ensureHTTPScheme adds https:// prefix to host if no scheme is present.
 func ensureHTTPScheme(host string) string {
 	if strings.HasPrefix(host, httpPrefix) || strings.HasPrefix(host, httpsPrefix) {
@@ -167,28 +155,6 @@ func (e *telemetryExporter) doExport(ctx context.Context, metrics []*telemetryMe
 	return nil
 }
 
-// toExportedMetric converts internal metric to exported format with tag filtering.
-func (m *telemetryMetric) toExportedMetric() *exportedMetric {
-	// Filter tags based on export scope
-	filteredTags := make(map[string]interface{})
-	for k, v := range m.tags {
-		if shouldExportToDatabricks(m.metricType, k) {
-			filteredTags[k] = v
-		}
-	}
-
-	return &exportedMetric{
-		MetricType:  m.metricType,
-		Timestamp:   m.timestamp.Format(time.RFC3339),
-		WorkspaceID: m.workspaceID,
-		SessionID:   m.sessionID,
-		StatementID: m.statementID,
-		LatencyMs:   m.latencyMs,
-		ErrorType:   m.errorType,
-		Tags:        filteredTags,
-	}
-}
-
 // isRetryableStatus returns true if HTTP status is retryable.
 // Retryable statuses: 429 (Too Many Requests), 503 (Service Unavailable), 5xx (Server Errors)
 func isRetryableStatus(status int) bool {
diff --git a/telemetry/exporter_test.go b/telemetry/exporter_test.go
index 864f0c7e..10156f82 100644
--- a/telemetry/exporter_test.go
+++ b/telemetry/exporter_test.go
@@ -243,57 +243,6 @@ func TestExport_CircuitBreakerOpen(t *testing.T) {
 	}
 }
 
-func TestToExportedMetric_TagFiltering(t *testing.T) {
-	metric := &telemetryMetric{
-		metricType:  "connection",
-		timestamp:   time.Date(2026, 1, 30, 10, 0, 0, 0, time.UTC),
-		workspaceID: "test-workspace",
-		sessionID:   "test-session",
-		statementID: "test-statement",
-		latencyMs:   100,
-		errorType:   "test-error",
-		tags: map[string]interface{}{
-			"workspace.id":   "ws-123",         // Should be exported
-			"driver.version": "1.0.0",          // Should be exported
-			"server.address": "localhost:8080", // Should NOT be exported (local only)
-			"unknown.tag":    "value",          // Should NOT be exported
-		},
-	}
-
-	exported := metric.toExportedMetric()
-
-	// Verify basic fields
-	if exported.MetricType != "connection" {
-		t.Errorf("Expected MetricType 'connection', got %s", exported.MetricType)
-	}
-
-	if exported.WorkspaceID != "test-workspace" {
-		t.Errorf("Expected WorkspaceID 'test-workspace', got %s", exported.WorkspaceID)
-	}
-
-	// Verify timestamp format
-	if exported.Timestamp != "2026-01-30T10:00:00Z" {
-		t.Errorf("Expected timestamp '2026-01-30T10:00:00Z', got %s", exported.Timestamp)
-	}
-
-	// Verify tag filtering
-	if _, ok := exported.Tags["workspace.id"]; !ok {
-		t.Error("Expected 'workspace.id' tag to be exported")
-	}
-
-	if _, ok := exported.Tags["driver.version"]; !ok {
-		t.Error("Expected 'driver.version' tag to be exported")
-	}
-
-	if _, ok := exported.Tags["server.address"]; ok {
-		t.Error("Expected 'server.address' tag to NOT be exported (local only)")
-	}
-
-	if _, ok := exported.Tags["unknown.tag"]; ok {
-		t.Error("Expected 'unknown.tag' to NOT be exported")
-	}
-}
-
 func TestIsRetryableStatus(t *testing.T) {
 	tests := []struct {
 		status int
diff --git a/telemetry/integration_test.go b/telemetry/integration_test.go
index 13a6c8b3..20bd2fc0 100644
--- a/telemetry/integration_test.go
+++ b/telemetry/integration_test.go
@@ -149,39 +149,6 @@ func TestIntegration_CircuitBreakerOpening(t *testing.T) {
 	}
 }
 
-// TestIntegration_OptInPriority_ExplicitOptOut tests explicit opt-out.
-func TestIntegration_OptInPriority_ExplicitOptOut(t *testing.T) {
-	cfg := &Config{
-		EnableTelemetry: false, // Priority 1 (client): Explicit opt-out
-		BatchSize:       100,
-		FlushInterval:   5 * time.Second,
-		MaxRetries:      3,
-		RetryDelay:      100 * time.Millisecond,
-	}
-
-	httpClient := &http.Client{Timeout: 5 * time.Second}
-
-	// Server that returns enabled
-	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		resp := map[string]interface{}{
-			"flags": map[string]bool{
-				"databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForGoDriver": true,
-			},
-		}
-		_ = json.NewEncoder(w).Encode(resp)
-	}))
-	defer server.Close()
-
-	ctx := context.Background()
-
-	// Should be disabled due to explicit opt-out
-	result := isTelemetryEnabled(ctx, cfg, server.URL, "test-version", httpClient)
-
-	if result {
-		t.Error("Expected telemetry to be disabled by explicit opt-out")
-	}
-}
-
 // TestIntegration_PrivacyCompliance verifies no sensitive data is collected.
 func TestIntegration_PrivacyCompliance_NoQueryText(t *testing.T) {
 	cfg := DefaultConfig()
@@ -240,72 +207,318 @@ func TestIntegration_PrivacyCompliance_NoQueryText(t *testing.T) {
 	t.Log("Privacy compliance test passed: sensitive data not present in payload")
 }
 
-// TestIntegration_FieldMapping verifies that only known metric fields are exported
-// in the TelemetryRequest format (no generic tag pass-through).
-func TestIntegration_FieldMapping(t *testing.T) {
+// TestIntegration_TelemetryEventCorrectnessAllFields verifies that every field of the
+// TelemetryRequest and nested TelemetryFrontendLog is correctly populated and present
+// when a metric is exported. This is the canonical correctness check for the wire format.
+func TestIntegration_TelemetryEventCorrectnessAllFields(t *testing.T) {
+	const (
+		testDriverVersion = "9.9.9-test"
+		testSessionID     = "sess-correctness-123"
+		testStatementID   = "stmt-correctness-456"
+		testLatencyMs     = int64(123)
+		testOperationType = "EXECUTE_STATEMENT"
+		testChunkCount    = 7
+		testPollCount     = 4
+		testErrorName     = "NETWORK_ERROR"
+	)
+
 	cfg := DefaultConfig()
 	cfg.FlushInterval = 50 * time.Millisecond
 	httpClient := &http.Client{Timeout: 5 * time.Second}
 
-	var capturedRequest TelemetryRequest
+	var mu sync.Mutex
+	var capturedReq TelemetryRequest
 
 	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		body, _ := io.ReadAll(r.Body)
-		_ = json.Unmarshal(body, &capturedRequest)
+		mu.Lock()
+		_ = json.Unmarshal(body, &capturedReq)
+		mu.Unlock()
 		w.WriteHeader(http.StatusOK)
 	}))
 	defer server.Close()
 
-	exporter := newTelemetryExporter(server.URL, "test-version", httpClient, cfg)
+	exporter := newTelemetryExporter(server.URL, testDriverVersion, httpClient, cfg)
 
 	metric := &telemetryMetric{
-		metricType: "connection",
+		metricType: "operation",
 		timestamp:  time.Now(),
-		workspaceID: "ws-test",
-		sessionID:   "sess-1",
-		latencyMs:   42,
+		sessionID:   testSessionID,
+		statementID: testStatementID,
+		latencyMs:   testLatencyMs,
+		errorType:   testErrorName,
 		tags: map[string]interface{}{
-			"chunk_count":      3,
-			"bytes_downloaded": int64(1024),
-			"unknown.tag":      "value", // should NOT appear in output
+			TagOperationType: testOperationType,
+			TagChunkCount:    testChunkCount,
+			TagPollCount:     testPollCount,
 		},
 	}
 
 	ctx := context.Background()
 	exporter.export(ctx, []*telemetryMetric{metric})
+	time.Sleep(200 * time.Millisecond)
 
-	time.Sleep(150 * time.Millisecond)
+	mu.Lock()
+	req := capturedReq
+	mu.Unlock()
 
-	if len(capturedRequest.ProtoLogs) == 0 {
-		t.Fatal("Expected at least one ProtoLog entry")
+	// --- TelemetryRequest top-level fields ---
+	if req.UploadTime == 0 {
+		t.Error("TelemetryRequest.UploadTime must be non-zero")
+	}
+	if req.Items == nil {
+		t.Error("TelemetryRequest.Items must not be nil (required by server schema)")
+	}
+	if len(req.ProtoLogs) == 0 {
+		t.Fatal("TelemetryRequest.ProtoLogs must contain at least one entry")
 	}
 
-	// Each ProtoLog entry is a JSON-encoded TelemetryFrontendLog.
-	var log TelemetryFrontendLog
-	if err := json.Unmarshal([]byte(capturedRequest.ProtoLogs[0]), &log); err != nil {
-		t.Fatalf("Failed to unmarshal ProtoLog: %v", err)
+	// --- Parse TelemetryFrontendLog from ProtoLogs[0] ---
+	var frontendLog TelemetryFrontendLog
+	if err := json.Unmarshal([]byte(req.ProtoLogs[0]), &frontendLog); err != nil {
+		t.Fatalf("Failed to unmarshal ProtoLogs[0] as TelemetryFrontendLog: %v", err)
 	}
 
+	// FrontendLogEventID must be generated (non-empty, unique timestamp-based ID)
+	if frontendLog.FrontendLogEventID == "" {
+		t.Error("TelemetryFrontendLog.FrontendLogEventID must be non-empty")
+	}
+
+	// --- Context / ClientContext ---
+	if frontendLog.Context == nil {
+		t.Fatal("TelemetryFrontendLog.Context must not be nil")
+	}
+	if frontendLog.Context.ClientContext == nil {
+		t.Fatal("FrontendLogContext.ClientContext must not be nil")
+	}
+	cc := frontendLog.Context.ClientContext
+	if cc.ClientType != "golang" {
+		t.Errorf("ClientContext.ClientType must be %q, got %q", "golang", cc.ClientType)
+	}
+	if cc.ClientVersion != testDriverVersion {
+		t.Errorf("ClientContext.ClientVersion must be %q, got %q", testDriverVersion, cc.ClientVersion)
+	}
+
+	// --- Entry / SQLDriverLog ---
+	if frontendLog.Entry == nil {
+		t.Fatal("TelemetryFrontendLog.Entry must not be nil")
+	}
+	if frontendLog.Entry.SQLDriverLog == nil {
+		t.Fatal("FrontendLogEntry.SQLDriverLog must not be nil")
+	}
+	ev := frontendLog.Entry.SQLDriverLog
+
+	if ev.SessionID != testSessionID {
+		t.Errorf("TelemetryEvent.SessionID must be %q, got %q", testSessionID, ev.SessionID)
+	}
+	if ev.SQLStatementID != testStatementID {
+		t.Errorf("TelemetryEvent.SQLStatementID must be %q, got %q", testStatementID, ev.SQLStatementID)
+	}
+	if ev.OperationLatencyMs != testLatencyMs {
+		t.Errorf("TelemetryEvent.OperationLatencyMs must be %d, got %d", testLatencyMs, ev.OperationLatencyMs)
+	}
+
+	// --- SystemConfiguration ---
+	if ev.SystemConfiguration == nil {
+		t.Fatal("TelemetryEvent.SystemConfiguration must not be nil")
+	}
+	sc := ev.SystemConfiguration
+	if sc.DriverName != "databricks-sql-go" {
+		t.Errorf("SystemConfiguration.DriverName must be %q, got %q", "databricks-sql-go", sc.DriverName)
+	}
+	if sc.DriverVersion != testDriverVersion {
+		t.Errorf("SystemConfiguration.DriverVersion must be %q, got %q", testDriverVersion, sc.DriverVersion)
+	}
+	if sc.RuntimeName != "go" {
+		t.Errorf("SystemConfiguration.RuntimeName must be %q, got %q", "go", sc.RuntimeName)
+	}
+	if sc.RuntimeVersion == "" {
+		t.Error("SystemConfiguration.RuntimeVersion must be non-empty")
+	}
+	if sc.OSName == "" {
+		t.Error("SystemConfiguration.OSName must be non-empty")
+	}
+	if sc.OSArch == "" {
+		t.Error("SystemConfiguration.OSArch must be non-empty")
+	}
+	if sc.CharSetEncoding != "UTF-8" {
+		t.Errorf("SystemConfiguration.CharSetEncoding must be %q, got %q", "UTF-8", sc.CharSetEncoding)
+	}
+	if sc.ProcessName == "" {
+		t.Error("SystemConfiguration.ProcessName must be non-empty")
+	}
+
+	// --- SQLOperation / OperationDetail ---
+	if ev.SQLOperation == nil {
+		t.Fatal("TelemetryEvent.SQLOperation must not be nil for operation metrics with tags")
+	}
+	if ev.SQLOperation.OperationDetail == nil {
+		t.Fatal("SQLExecutionEvent.OperationDetail must not be nil when operation_type tag is set")
+	}
+	od := ev.SQLOperation.OperationDetail
+	if od.OperationType != testOperationType {
+		t.Errorf("OperationDetail.OperationType must be %q, got %q", testOperationType, od.OperationType)
+	}
+	if od.NOperationStatusCalls != int32(testPollCount) {
+		t.Errorf("OperationDetail.NOperationStatusCalls must be %d, got %d", testPollCount, od.NOperationStatusCalls)
+	}
+
+	// ChunkDetails must be populated for chunk_count tag
+	if ev.SQLOperation.ChunkDetails == nil {
+		t.Fatal("SQLExecutionEvent.ChunkDetails must not be nil when chunk_count tag is set")
+	}
+	if ev.SQLOperation.ChunkDetails.TotalChunksIterated != int32(testChunkCount) {
+		t.Errorf("ChunkDetails.TotalChunksIterated must be %d, got %d", testChunkCount, ev.SQLOperation.ChunkDetails.TotalChunksIterated)
+	}
+
+	// --- ErrorInfo ---
+	if ev.ErrorInfo == nil {
+		t.Fatal("TelemetryEvent.ErrorInfo must not be nil when errorType is set")
+	}
+	if ev.ErrorInfo.ErrorName != testErrorName {
+		t.Errorf("DriverErrorInfo.ErrorName must be %q, got %q", testErrorName, ev.ErrorInfo.ErrorName)
+	}
+
+	t.Log("Telemetry event correctness check passed: all fields verified")
+}
+
+// TestIntegration_OperationLatencyMs_ZeroNotOmitted verifies that OperationLatencyMs=0
+// is serialised as "operation_latency_ms":0 in the JSON payload, not omitted.
+//
+// Regression test: the field previously had `json:"operation_latency_ms,omitempty"` which
+// caused 0ms latency (e.g. CloseOperation that completes in <1ms) to appear as null in
+// the Databricks telemetry table.
+func TestIntegration_OperationLatencyMs_ZeroNotOmitted(t *testing.T) {
+	cfg := DefaultConfig()
+	cfg.FlushInterval = 50 * time.Millisecond
+	httpClient := &http.Client{Timeout: 5 * time.Second}
+
+	var mu sync.Mutex
+	var capturedReq TelemetryRequest
+
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		body, _ := io.ReadAll(r.Body)
+		mu.Lock()
+		_ = json.Unmarshal(body, &capturedReq)
+		mu.Unlock()
+		w.WriteHeader(http.StatusOK)
+	}))
+	defer server.Close()
+
+	exporter := newTelemetryExporter(server.URL, "test-version", httpClient, cfg)
+
+	// latencyMs=0 — simulates a CloseOperation that completed in <1ms.
+	metric := &telemetryMetric{
+		metricType:  "operation",
+		timestamp:   time.Now(),
+		sessionID:   "sess-zero-latency",
+		statementID: "stmt-zero-latency",
+		latencyMs:   0, // <1ms rounded to 0
+	}
+
+	ctx := context.Background()
+	exporter.export(ctx, []*telemetryMetric{metric})
+	time.Sleep(200 * time.Millisecond)
+
+	mu.Lock()
+	req := capturedReq
+	mu.Unlock()
+
+	if len(req.ProtoLogs) == 0 {
+		t.Fatal("expected ProtoLogs to be non-empty")
+	}
+
+	// Verify the raw JSON contains the key with value 0, not absent.
+	raw := req.ProtoLogs[0]
+	if !strings.Contains(raw, `"operation_latency_ms":0`) {
+		t.Errorf("expected raw JSON to contain \"operation_latency_ms\":0, got: %s", raw)
+	}
+
+	// Also verify via struct parse.
+	var log TelemetryFrontendLog
+	if err := json.Unmarshal([]byte(raw), &log); err != nil {
+		t.Fatalf("failed to unmarshal ProtoLogs[0]: %v", err)
+	}
 	if log.Entry == nil || log.Entry.SQLDriverLog == nil {
-		t.Fatal("Expected SQLDriverLog to be populated")
+		t.Fatal("Entry.SQLDriverLog must not be nil")
 	}
+	if log.Entry.SQLDriverLog.OperationLatencyMs != 0 {
+		t.Errorf("expected OperationLatencyMs=0, got %d", log.Entry.SQLDriverLog.OperationLatencyMs)
+	}
+
+	t.Log("OperationLatencyMs=0 correctly serialised (omitempty fix verified)")
+}
+
+// TestIntegration_ChunkTotalPresent_DerivedFromChunkCount verifies that when the
+// "chunk_total_present" tag is explicitly set (e.g. derived from r.chunkCount in
+// rows.Close()), it is propagated to ChunkDetails.TotalChunksPresent in the payload.
+//
+// This covers the fix for paginated CloudFetch where the server never reports the
+// grand total across all FetchResults calls; the driver derives it from chunkCount.
+func TestIntegration_ChunkTotalPresent_DerivedFromChunkCount(t *testing.T) {
+	cfg := DefaultConfig()
+	cfg.FlushInterval = 50 * time.Millisecond
+	httpClient := &http.Client{Timeout: 5 * time.Second}
+
+	var mu sync.Mutex
+	var capturedReq TelemetryRequest
 
-	entry := log.Entry.SQLDriverLog
-	if entry.SessionID != "sess-1" {
-		t.Errorf("Expected session_id=sess-1, got %q", entry.SessionID)
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		body, _ := io.ReadAll(r.Body)
+		mu.Lock()
+		_ = json.Unmarshal(body, &capturedReq)
+		mu.Unlock()
+		w.WriteHeader(http.StatusOK)
+	}))
+	defer server.Close()
+
+	const (
+		totalChunksPresent  = 32
+		totalChunksIterated = 32
+	)
+
+	exporter := newTelemetryExporter(server.URL, "test-version", httpClient, cfg)
+	metric := &telemetryMetric{
+		metricType:  "operation",
+		timestamp:   time.Now(),
+		sessionID:   "sess-chunks",
+		statementID: "stmt-chunks",
+		latencyMs:   500,
+		tags: map[string]interface{}{
+			TagChunkCount:        totalChunksIterated, // total pages fetched
+			TagChunkTotalPresent: totalChunksPresent,  // derived from r.chunkCount
+		},
 	}
-	if entry.OperationLatencyMs != 42 {
-		t.Errorf("Expected latency=42, got %d", entry.OperationLatencyMs)
+
+	ctx := context.Background()
+	exporter.export(ctx, []*telemetryMetric{metric})
+	time.Sleep(200 * time.Millisecond)
+
+	mu.Lock()
+	req := capturedReq
+	mu.Unlock()
+
+	if len(req.ProtoLogs) == 0 {
+		t.Fatal("expected ProtoLogs to be non-empty")
 	}
-	if entry.SQLOperation != nil && entry.SQLOperation.ChunkDetails != nil {
-		if entry.SQLOperation.ChunkDetails.TotalChunksIterated != 3 {
-			t.Errorf("Expected total_chunks_iterated=3, got %d", entry.SQLOperation.ChunkDetails.TotalChunksIterated)
-		}
+
+	var log TelemetryFrontendLog
+	if err := json.Unmarshal([]byte(req.ProtoLogs[0]), &log); err != nil {
+		t.Fatalf("failed to unmarshal ProtoLogs[0]: %v", err)
 	}
-	// unknown.tag must not appear anywhere in the serialised output
-	if strings.Contains(capturedRequest.ProtoLogs[0], "unknown.tag") {
-		t.Error("unknown.tag must not be exported")
+	ev := log.Entry.SQLDriverLog
+	if ev == nil {
+		t.Fatal("SQLDriverLog must not be nil")
+	}
+	if ev.SQLOperation == nil || ev.SQLOperation.ChunkDetails == nil {
+		t.Fatal("ChunkDetails must not be nil when chunk_count tag is set")
 	}
-	t.Log("Field mapping test passed")
+	cd := ev.SQLOperation.ChunkDetails
+	if cd.TotalChunksIterated != int32(totalChunksIterated) {
+		t.Errorf("TotalChunksIterated: expected %d, got %d", totalChunksIterated, cd.TotalChunksIterated)
+	}
+	if cd.TotalChunksPresent != int32(totalChunksPresent) {
+		t.Errorf("TotalChunksPresent: expected %d, got %d", totalChunksPresent, cd.TotalChunksPresent)
+	}
 }
diff --git a/telemetry/interceptor.go b/telemetry/interceptor.go
index 5ef01b1a..a2abfa92 100644
--- a/telemetry/interceptor.go
+++ b/telemetry/interceptor.go
@@ -15,11 +15,26 @@ type Interceptor struct {
 }
 
 // metricContext holds metric collection state in context.
+//
+// Thread safety: metricContext.tags is NOT protected by its own mutex.
+// It relies on database/sql's closemu for serialization:
+// - Writes (AddTag) happen during rows.Next() under closemu.RLock
+// - Final reads (AfterExecute/closeCallback) happen during rows.Close() under closemu.Lock
+//
+// This ensures mutual exclusion between concurrent Next() and Close() calls,
+// including Close() triggered by database/sql's awaitDone goroutine on context
+// cancellation. Do NOT access tags outside of these serialized paths.
 type metricContext struct {
 	sessionID   string
 	statementID string
 	startTime   time.Time
 	tags        map[string]interface{}
+
+	// capturedLatencyMs is set by FinalizeLatency() to freeze the execute-phase
+	// latency before row iteration begins. AfterExecute uses this value instead
+	// of re-measuring from startTime (which would include row-scan time).
+	capturedLatencyMs int64
+	latencyCaptured   bool
 }
 
 type contextKey int
@@ -83,6 +98,22 @@ func (i *Interceptor) BeforeExecuteWithTime(ctx context.Context, sessionID strin
 	return withMetricContext(ctx, mc)
 }
 
+// FinalizeLatency freezes the elapsed time as the statement's execution latency.
+// Call this when the execute phase is complete (i.e. when QueryContext returns) so
+// that AfterExecute, even if called later from rows.Close(), still reports
+// execute-only latency rather than total latency that would include row iteration.
+// Exported for use by the driver package.
+func (i *Interceptor) FinalizeLatency(ctx context.Context) {
+	if !i.enabled {
+		return
+	}
+	mc := getMetricContext(ctx)
+	if mc != nil && !mc.latencyCaptured {
+		mc.capturedLatencyMs = time.Since(mc.startTime).Milliseconds()
+		mc.latencyCaptured = true
+	}
+}
+
 // AfterExecute is called after statement execution.
 // Records the metric with timing and error information.
 // Exported for use by the driver package.
@@ -103,12 +134,20 @@ func (i *Interceptor) AfterExecute(ctx context.Context, err error) {
 		}
 	}()
 
+	// Use pre-captured latency if available (set by FinalizeLatency), otherwise
+	// fall back to measuring from startTime (covers the error-path where
+	// FinalizeLatency was never called).
+	latencyMs := time.Since(mc.startTime).Milliseconds()
+	if mc.latencyCaptured {
+		latencyMs = mc.capturedLatencyMs
+	}
+
 	metric := &telemetryMetric{
 		metricType:  "statement",
 		timestamp:   mc.startTime,
 		sessionID:   mc.sessionID,
 		statementID: mc.statementID,
-		latencyMs:   time.Since(mc.startTime).Milliseconds(),
+		latencyMs:   latencyMs,
 		tags:        mc.tags,
 	}
 
@@ -167,8 +206,10 @@ func (i *Interceptor) CompleteStatement(ctx context.Context, statementID string,
 }
 
 // RecordOperation records an operation with type, latency, and optional error.
+// statementID is included when the operation is scoped to a specific statement (e.g. CLOSE_STATEMENT).
+// Pass "" for session-level operations (CREATE_SESSION, DELETE_SESSION).
 // Exported for use by the driver package.
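+//
+// Illustrative calls (identifiers and latency values hypothetical):
+//
+//	i.RecordOperation(ctx, sessID, stmtID, OperationTypeCloseStatement, 3, nil) // statement-scoped
+//	i.RecordOperation(ctx, sessID, "", OperationTypeDeleteSession, 12, err)     // session-scoped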
-func (i *Interceptor) RecordOperation(ctx context.Context, sessionID string, operationType string, latencyMs int64, err error) {
+func (i *Interceptor) RecordOperation(ctx context.Context, sessionID string, statementID string, operationType string, latencyMs int64, err error) {
 	if !i.enabled {
 		return
 	}
@@ -180,11 +221,12 @@ func (i *Interceptor) RecordOperation(ctx context.Context, sessionID string, ope
 	}()
 
 	metric := &telemetryMetric{
-		metricType: "operation",
-		timestamp:  time.Now(),
-		sessionID:  sessionID,
-		latencyMs:  latencyMs,
-		tags:       map[string]interface{}{"operation_type": operationType},
+		metricType:  "operation",
+		timestamp:   time.Now(),
+		sessionID:   sessionID,
+		statementID: statementID,
+		latencyMs:   latencyMs,
+		tags:        map[string]interface{}{TagOperationType: operationType},
 	}
 
 	if err != nil {
diff --git a/telemetry/request.go b/telemetry/request.go
index f516406d..6390a363 100644
--- a/telemetry/request.go
+++ b/telemetry/request.go
@@ -48,7 +48,7 @@ type TelemetryEvent struct {
 	VolumeOperation    *VolumeOperationEvent `json:"vol_operation,omitempty"`
 	SQLOperation       *SQLExecutionEvent    `json:"sql_operation,omitempty"`
 	ErrorInfo          *DriverErrorInfo      `json:"error_info,omitempty"`
-	OperationLatencyMs int64                 `json:"operation_latency_ms,omitempty"`
+	OperationLatencyMs int64                 `json:"operation_latency_ms"`
 }
 
 // DriverSystemConfiguration maps to DriverSystemConfiguration in the proto schema.
@@ -177,20 +177,32 @@ func createTelemetryRequest(metrics []*telemetryMetric, driverVersion string) (*
 		if tags := metric.tags; tags != nil {
 			sqlOp := &SQLExecutionEvent{}
-			if v, ok := tags["result.format"].(string); ok {
+			if v, ok := tags[TagResultFormat].(string); ok {
 				sqlOp.ExecutionResult = v
 			}
-			if chunkCount, ok := tags["chunk_count"].(int); ok && chunkCount > 0 {
+			if chunkCount, ok := tags[TagChunkCount].(int); ok && chunkCount > 0 {
 				sqlOp.ChunkDetails = &ChunkDetails{
 					TotalChunksIterated: int32(chunkCount), //nolint:gosec // chunk count is always small
 				}
+				if v, ok := tags[TagChunkInitialLatencyMs].(int64); ok && v > 0 {
+					sqlOp.ChunkDetails.InitialChunkLatencyMs = v
+				}
+				if v, ok := tags[TagChunkSlowestLatencyMs].(int64); ok && v > 0 {
+					sqlOp.ChunkDetails.SlowestChunkLatencyMs = v
+				}
+				if v, ok := tags[TagChunkSumLatencyMs].(int64); ok && v > 0 {
+					sqlOp.ChunkDetails.SumChunksDownloadTimeMs = v
+				}
+				if v, ok := tags[TagChunkTotalPresent].(int); ok && v > 0 {
+					sqlOp.ChunkDetails.TotalChunksPresent = int32(v) //nolint:gosec // chunk count is always small
+				}
 			}
-			if opType, ok := tags["operation_type"].(string); ok {
+			if opType, ok := tags[TagOperationType].(string); ok {
 				detail := &OperationDetail{
 					OperationType: opType,
 				}
-				if pollCount, ok := tags["poll_count"].(int); ok {
+				if pollCount, ok := tags[TagPollCount].(int); ok {
 					detail.NOperationStatusCalls = int32(pollCount) //nolint:gosec // poll count is always small
 				}
 				sqlOp.OperationDetail = detail
diff --git a/telemetry/tags.go b/telemetry/tags.go
index f4b391f5..32a81eb2 100644
--- a/telemetry/tags.go
+++ b/telemetry/tags.go
@@ -10,15 +10,26 @@ const (
 	TagServerAddress = "server.address" // Not exported to Databricks
 )
 
-// Tag names for statement metrics
+// Tag names for statement metrics.
+// Values must match the keys used in metricContext.tags and read by
+// createTelemetryRequest / aggregator — keep them in sync.
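+// For example, createTelemetryRequest maps TagChunkCount ("chunk_count") to
+// ChunkDetails.TotalChunksIterated and TagPollCount ("poll_count") to
+// OperationDetail.NOperationStatusCalls.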
 const (
-	TagStatementID           = "statement.id"
-	TagResultFormat          = "result.format"
-	TagResultChunkCount      = "result.chunk_count"
-	TagResultBytesDownloaded = "result.bytes_downloaded"
-	TagCompressionEnabled    = "result.compression_enabled"
-	TagPollCount             = "poll.count"
-	TagPollLatency           = "poll.latency_ms"
+	TagStatementID        = "statement.id"
+	TagResultFormat       = "result.format"
+	TagChunkCount         = "chunk_count"
+	TagBytesDownloaded    = "bytes_downloaded"
+	TagCompressionEnabled = "result.compression_enabled"
+	TagOperationType      = "operation_type"
+	TagPollCount          = "poll_count"
+	TagPollLatency        = "poll.latency_ms"
+)
+
+// Tag names for chunk timing metrics
+const (
+	TagChunkInitialLatencyMs = "chunk_initial_latency_ms"
+	TagChunkSlowestLatencyMs = "chunk_slowest_latency_ms"
+	TagChunkSumLatencyMs     = "chunk_sum_latency_ms"
+	TagChunkTotalPresent     = "chunk_total_present"
 )
 
 // Tag names for error metrics
@@ -33,69 +44,3 @@
 	TagFeatureLZ4           = "feature.lz4"
 	TagFeatureDirectResults = "feature.direct_results"
 )
-
-// tagExportScope defines where a tag can be exported.
-type tagExportScope int
-
-const (
-	exportNone  tagExportScope = 0
-	exportLocal                = 1 << iota
-	exportDatabricks
-	exportAll = exportLocal | exportDatabricks
-)
-
-// tagDefinition defines a metric tag and its export scope.
-type tagDefinition struct {
-	name        string
-	exportScope tagExportScope
-	description string
-	required    bool
-}
-
-// connectionTags returns tags allowed for connection events.
-func connectionTags() []tagDefinition {
-	return []tagDefinition{
-		{TagWorkspaceID, exportDatabricks, "Databricks workspace ID", true},
-		{TagSessionID, exportDatabricks, "Connection session ID", true},
-		{TagDriverVersion, exportAll, "Driver version", false},
-		{TagDriverOS, exportAll, "Operating system", false},
-		{TagDriverRuntime, exportAll, "Go runtime version", false},
-		{TagFeatureCloudFetch, exportDatabricks, "CloudFetch enabled", false},
-		{TagFeatureLZ4, exportDatabricks, "LZ4 compression enabled", false},
-		{TagServerAddress, exportLocal, "Server address (local only)", false},
-	}
-}
-
-// statementTags returns tags allowed for statement events.
-func statementTags() []tagDefinition {
-	return []tagDefinition{
-		{TagStatementID, exportDatabricks, "Statement ID", true},
-		{TagSessionID, exportDatabricks, "Session ID", true},
-		{TagResultFormat, exportDatabricks, "Result format", false},
-		{TagResultChunkCount, exportDatabricks, "Chunk count", false},
-		{TagResultBytesDownloaded, exportDatabricks, "Bytes downloaded", false},
-		{TagCompressionEnabled, exportDatabricks, "Compression enabled", false},
-		{TagPollCount, exportDatabricks, "Poll count", false},
-		{TagPollLatency, exportDatabricks, "Poll latency", false},
-	}
-}
-
-// shouldExportToDatabricks returns true if tag should be exported to Databricks.
-func shouldExportToDatabricks(metricType, tagName string) bool {
-	var tags []tagDefinition
-	switch metricType {
-	case "connection":
-		tags = connectionTags()
-	case "statement":
-		tags = statementTags()
-	default:
-		return false
-	}
-
-	for _, tag := range tags {
-		if tag.name == tagName {
-			return tag.exportScope&exportDatabricks != 0
-		}
-	}
-	return false
-}
diff --git a/telemetry/tags_test.go b/telemetry/tags_test.go
index 268dc8b9..b3057b6c 100644
--- a/telemetry/tags_test.go
+++ b/telemetry/tags_test.go
@@ -1,178 +1,30 @@
 package telemetry
 
-import (
-	"testing"
-)
+import "testing"
 
-func TestConnectionTags(t *testing.T) {
-	tags := connectionTags()
-
-	if len(tags) == 0 {
-		t.Fatal("Expected connection tags to be defined")
-	}
-
-	// Verify required tags are present
-	requiredTags := []string{TagWorkspaceID, TagSessionID}
-	for _, requiredTag := range requiredTags {
-		found := false
-		for _, tag := range tags {
-			if tag.name == requiredTag && tag.required {
-				found = true
-				break
-			}
-		}
-		if !found {
-			t.Errorf("Required tag %s not found in connection tags", requiredTag)
-		}
-	}
-
-	// Verify server address is local-only
-	for _, tag := range tags {
-		if tag.name == TagServerAddress {
-			if tag.exportScope&exportDatabricks != 0 {
-				t.Error("server.address should not be exported to Databricks")
-			}
-			if tag.exportScope&exportLocal == 0 {
-				t.Error("server.address should be exported locally")
-			}
-		}
-	}
-}
-
-func TestStatementTags(t *testing.T) {
-	tags := statementTags()
-
-	if len(tags) == 0 {
-		t.Fatal("Expected statement tags to be defined")
-	}
-
-	// Verify required tags are present
-	requiredTags := []string{TagStatementID, TagSessionID}
-	for _, requiredTag := range requiredTags {
-		found := false
-		for _, tag := range tags {
-			if tag.name == requiredTag && tag.required {
-				found = true
-				break
-			}
-		}
-		if !found {
-			t.Errorf("Required tag %s not found in statement tags", requiredTag)
-		}
-	}
-}
-
-func TestShouldExportToDatabricks_ConnectionTags(t *testing.T) {
-	tests := []struct {
-		name     string
-		tagName  string
-		expected bool
+func TestTagConstants_NonEmpty(t *testing.T) {
+	// Verify all exported tag constants have non-empty values.
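+	// An accidentally blanked constant would make AddTag store values under ""
+	// and the typed lookups in createTelemetryRequest would silently miss them.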
+	tags := []struct {
+		name  string
+		value string
 	}{
-		{"workspace.id should export", TagWorkspaceID, true},
-		{"session.id should export", TagSessionID, true},
-		{"driver.version should export", TagDriverVersion, true},
-		{"server.address should NOT export", TagServerAddress, false},
-		{"unknown tag should NOT export", "unknown.tag", false},
-	}
-
-	for _, tt := range tests {
-		t.Run(tt.name, func(t *testing.T) {
-			result := shouldExportToDatabricks("connection", tt.tagName)
-			if result != tt.expected {
-				t.Errorf("shouldExportToDatabricks(%q, %q) = %v, want %v",
-					"connection", tt.tagName, result, tt.expected)
-			}
-		})
-	}
-}
-
-func TestShouldExportToDatabricks_StatementTags(t *testing.T) {
-	tests := []struct {
-		name     string
-		tagName  string
-		expected bool
-	}{
-		{"statement.id should export", TagStatementID, true},
-		{"session.id should export", TagSessionID, true},
-		{"result.format should export", TagResultFormat, true},
-		{"result.chunk_count should export", TagResultChunkCount, true},
-		{"unknown tag should NOT export", "unknown.tag", false},
-	}
-
-	for _, tt := range tests {
-		t.Run(tt.name, func(t *testing.T) {
-			result := shouldExportToDatabricks("statement", tt.tagName)
-			if result != tt.expected {
-				t.Errorf("shouldExportToDatabricks(%q, %q) = %v, want %v",
-					"statement", tt.tagName, result, tt.expected)
-			}
-		})
-	}
-}
-
-func TestShouldExportToDatabricks_UnknownMetricType(t *testing.T) {
-	result := shouldExportToDatabricks("unknown_type", TagWorkspaceID)
-	if result {
-		t.Error("shouldExportToDatabricks with unknown metric type should return false")
-	}
-}
-
-func TestExportScopeFlags(t *testing.T) {
-	// Test bit flag operations
-	if exportNone != 0 {
-		t.Error("exportNone should be 0")
-	}
-
-	if exportAll&exportLocal == 0 {
-		t.Error("exportAll should include exportLocal")
-	}
-
-	if exportAll&exportDatabricks == 0 {
-		t.Error("exportAll should include exportDatabricks")
-	}
-
-	if exportLocal&exportDatabricks != 0 {
-		t.Error("exportLocal and exportDatabricks should be separate flags")
-	}
-}
-
-func TestTagDefinitionStructure(t *testing.T) {
-	// Verify tag definition structure is correct
-	tags := connectionTags()
-	for _, tag := range tags {
-		if tag.name == "" {
-			t.Error("Tag name should not be empty")
-		}
-		if tag.description == "" {
-			t.Error("Tag description should not be empty")
-		}
-		// exportScope can be 0 (exportNone) which is valid
-	}
-}
-
-func TestAllConnectionTagsHaveValidScopes(t *testing.T) {
-	tags := connectionTags()
-	for _, tag := range tags {
-		// Each tag should have at least one export scope set
-		// (except exportNone which is 0)
-		if tag.exportScope != exportNone {
-			hasValidScope := (tag.exportScope&exportLocal != 0) || (tag.exportScope&exportDatabricks != 0)
-			if !hasValidScope {
-				t.Errorf("Tag %s has invalid export scope: %d", tag.name, tag.exportScope)
-			}
-		}
-	}
-}
-
-func TestAllStatementTagsHaveValidScopes(t *testing.T) {
-	tags := statementTags()
-	for _, tag := range tags {
-		// Each tag should have at least one export scope set
-		if tag.exportScope != exportNone {
-			hasValidScope := (tag.exportScope&exportLocal != 0) || (tag.exportScope&exportDatabricks != 0)
-			if !hasValidScope {
-				t.Errorf("Tag %s has invalid export scope: %d", tag.name, tag.exportScope)
-			}
+		{"TagWorkspaceID", TagWorkspaceID},
+		{"TagSessionID", TagSessionID},
+		{"TagDriverVersion", TagDriverVersion},
+		{"TagStatementID", TagStatementID},
+		{"TagResultFormat", TagResultFormat},
+		{"TagChunkCount", TagChunkCount},
+		{"TagBytesDownloaded", TagBytesDownloaded},
+		{"TagOperationType", TagOperationType},
+		{"TagPollCount", TagPollCount},
+		{"TagChunkInitialLatencyMs", TagChunkInitialLatencyMs},
+		{"TagChunkSlowestLatencyMs", TagChunkSlowestLatencyMs},
+		{"TagChunkSumLatencyMs", TagChunkSumLatencyMs},
+		{"TagChunkTotalPresent", TagChunkTotalPresent},
+	}
+	for _, tt := range tags {
+		if tt.value == "" {
+			t.Errorf("Tag constant %s must not be empty", tt.name)
 		}
 	}
 }