Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions auth/oauth/u2m/u2m.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,19 @@ func GetConfig(ctx context.Context, hostName, clientID, clientSecret, callbackUR
}

config := oauth2.Config{
ClientID: clientID,
ClientSecret: clientSecret,
Endpoint: endpoint,
RedirectURL: callbackURL,
Scopes: scopes,
ClientID: clientID,
Endpoint: endpoint,
RedirectURL: callbackURL,
Scopes: scopes,
}
if clientSecret != "" {
config.ClientSecret = clientSecret
} else {
// For U2M (public apps using PKCE), force AuthStyleInParams to avoid
// sending Basic auth with empty password. AuthStyleInHeader sends
// "Authorization: Basic base64(clientID:)" which the server rejects
// with "Public app should not use a client secret".
config.Endpoint.AuthStyle = oauth2.AuthStyleInParams
}

return config, nil
Expand Down
25 changes: 25 additions & 0 deletions connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"database/sql/driver"
"fmt"
"net/http"
"net/url"
"strings"
"time"

Expand Down Expand Up @@ -76,12 +77,16 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
}
log := logger.WithContext(conn.id, driverctx.CorrelationIdFromContext(ctx), "")

// Extract SPOG routing headers from ?o= in HTTPPath
spogHeaders := extractSpogHeaders(c.cfg.HTTPPath)

// Initialize telemetry: client config overlay decides; if unset, feature flags decide
conn.telemetry = telemetry.InitializeForConnection(
ctx,
c.cfg.Host,
c.client,
c.cfg.EnableTelemetry,
spogHeaders,
)
if conn.telemetry != nil {
log.Debug().Msg("telemetry initialized for connection")
Expand All @@ -107,6 +112,7 @@ func NewConnector(options ...ConnOption) (driver.Connector, error) {
// config with default options
cfg := config.WithDefaults()
cfg.DriverVersion = DriverVersion
telemetry.SetDriverVersion(DriverVersion)

for _, opt := range options {
opt(cfg)
Expand All @@ -117,6 +123,25 @@ func NewConnector(options ...ConnOption) (driver.Connector, error) {
return &connector{cfg: cfg, client: client}, nil
}

// extractSpogHeaders extracts ?o=<workspaceId> from httpPath and returns
// an x-databricks-org-id header for SPOG routing.
func extractSpogHeaders(httpPath string) map[string]string {
if !strings.Contains(httpPath, "?") {
return nil
}
// Parse query string from httpPath
parts := strings.SplitN(httpPath, "?", 2)
params, err := url.ParseQuery(parts[1])
if err != nil {
return nil
}
orgID := params.Get("o")
if orgID == "" {
return nil
}
return map[string]string{"x-databricks-org-id": orgID}
}

func withUserConfig(ucfg config.UserConfig) ConnOption {
return func(c *config.Config) {
c.UserConfig = ucfg
Expand Down
4 changes: 2 additions & 2 deletions telemetry/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ func ParseTelemetryConfig(params map[string]string) *Config {
//
// Returns:
// - bool: true if telemetry should be enabled, false otherwise
func isTelemetryEnabled(ctx context.Context, cfg *Config, host string, httpClient *http.Client) bool {
func isTelemetryEnabled(ctx context.Context, cfg *Config, host string, httpClient *http.Client, extraHeaders map[string]string) bool {
// Priority 1: Client explicitly set (overrides server)
if cfg.EnableTelemetry.IsSet() {
val, _ := cfg.EnableTelemetry.Get()
Expand All @@ -111,7 +111,7 @@ func isTelemetryEnabled(ctx context.Context, cfg *Config, host string, httpClien

// Priority 2: Check server-side feature flag
flagCache := getFeatureFlagCache()
serverEnabled, err := flagCache.isTelemetryEnabled(ctx, host, httpClient)
serverEnabled, err := flagCache.isTelemetryEnabled(ctx, host, httpClient, extraHeaders)
if err != nil {
// Priority 3: Fail-safe default (disabled)
return false
Expand Down
50 changes: 17 additions & 33 deletions telemetry/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package telemetry

import (
"context"
"encoding/json"

"net/http"
"net/http/httptest"
"testing"
Expand Down Expand Up @@ -206,12 +206,8 @@ func TestIsTelemetryEnabled_ClientOverrideEnabled(t *testing.T) {
// Setup: Create a server that returns disabled
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Server says disabled, but client override should win
resp := map[string]interface{}{
"flags": map[string]bool{
"databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForGoDriver": false,
},
}
_ = json.NewEncoder(w).Encode(resp)
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"flags": [{"name": "databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForGoDriver", "value": "false"}]}`))
}))
defer server.Close()

Expand All @@ -228,7 +224,7 @@ func TestIsTelemetryEnabled_ClientOverrideEnabled(t *testing.T) {
defer flagCache.releaseContext(server.URL)

// Client override should bypass server check
result := isTelemetryEnabled(ctx, cfg, server.URL, httpClient)
result := isTelemetryEnabled(ctx, cfg, server.URL, httpClient, nil)

if !result {
t.Error("Expected telemetry to be enabled when client explicitly sets enableTelemetry=true, got disabled")
Expand All @@ -240,12 +236,8 @@ func TestIsTelemetryEnabled_ClientOverrideDisabled(t *testing.T) {
// Setup: Create a server that returns enabled
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Server says enabled, but client override should win
resp := map[string]interface{}{
"flags": map[string]bool{
"databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForGoDriver": true,
},
}
_ = json.NewEncoder(w).Encode(resp)
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"flags": [{"name": "databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForGoDriver", "value": "true"}]}`))
}))
defer server.Close()

Expand All @@ -261,7 +253,7 @@ func TestIsTelemetryEnabled_ClientOverrideDisabled(t *testing.T) {
flagCache.getOrCreateContext(server.URL)
defer flagCache.releaseContext(server.URL)

result := isTelemetryEnabled(ctx, cfg, server.URL, httpClient)
result := isTelemetryEnabled(ctx, cfg, server.URL, httpClient, nil)

if result {
t.Error("Expected telemetry to be disabled when client explicitly sets enableTelemetry=false, got enabled")
Expand All @@ -272,12 +264,8 @@ func TestIsTelemetryEnabled_ClientOverrideDisabled(t *testing.T) {
func TestIsTelemetryEnabled_ServerEnabled(t *testing.T) {
// Setup: Create a server that returns enabled
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
resp := map[string]interface{}{
"flags": map[string]bool{
"databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForGoDriver": true,
},
}
_ = json.NewEncoder(w).Encode(resp)
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"flags": [{"name": "databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForGoDriver", "value": "true"}]}`))
}))
defer server.Close()

Expand All @@ -293,7 +281,7 @@ func TestIsTelemetryEnabled_ServerEnabled(t *testing.T) {
flagCache.getOrCreateContext(server.URL)
defer flagCache.releaseContext(server.URL)

result := isTelemetryEnabled(ctx, cfg, server.URL, httpClient)
result := isTelemetryEnabled(ctx, cfg, server.URL, httpClient, nil)

if !result {
t.Error("Expected telemetry to be enabled when server flag is true, got disabled")
Expand All @@ -304,12 +292,8 @@ func TestIsTelemetryEnabled_ServerEnabled(t *testing.T) {
func TestIsTelemetryEnabled_ServerDisabled(t *testing.T) {
// Setup: Create a server that returns disabled
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
resp := map[string]interface{}{
"flags": map[string]bool{
"databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForGoDriver": false,
},
}
_ = json.NewEncoder(w).Encode(resp)
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"flags": [{"name": "databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForGoDriver", "value": "false"}]}`))
}))
defer server.Close()

Expand All @@ -325,7 +309,7 @@ func TestIsTelemetryEnabled_ServerDisabled(t *testing.T) {
flagCache.getOrCreateContext(server.URL)
defer flagCache.releaseContext(server.URL)

result := isTelemetryEnabled(ctx, cfg, server.URL, httpClient)
result := isTelemetryEnabled(ctx, cfg, server.URL, httpClient, nil)

if result {
t.Error("Expected telemetry to be disabled when server flag is false, got enabled")
Expand All @@ -340,7 +324,7 @@ func TestIsTelemetryEnabled_FailSafeDefault(t *testing.T) {
httpClient := &http.Client{Timeout: 5 * time.Second}

// No server available, should default to disabled (fail-safe)
result := isTelemetryEnabled(ctx, cfg, "nonexistent-host", httpClient)
result := isTelemetryEnabled(ctx, cfg, "nonexistent-host", httpClient, nil)

if result {
t.Error("Expected telemetry to be disabled when server unavailable (fail-safe), got enabled")
Expand All @@ -367,7 +351,7 @@ func TestIsTelemetryEnabled_ServerError(t *testing.T) {
flagCache.getOrCreateContext(server.URL)
defer flagCache.releaseContext(server.URL)

result := isTelemetryEnabled(ctx, cfg, server.URL, httpClient)
result := isTelemetryEnabled(ctx, cfg, server.URL, httpClient, nil)

// On error, should default to disabled (fail-safe)
if result {
Expand All @@ -390,7 +374,7 @@ func TestIsTelemetryEnabled_ServerUnreachable(t *testing.T) {
flagCache.getOrCreateContext(unreachableHost)
defer flagCache.releaseContext(unreachableHost)

result := isTelemetryEnabled(ctx, cfg, unreachableHost, httpClient)
result := isTelemetryEnabled(ctx, cfg, unreachableHost, httpClient, nil)

// On error, should default to disabled (fail-safe)
if result {
Expand Down Expand Up @@ -418,7 +402,7 @@ func TestIsTelemetryEnabled_ClientOverridesServerError(t *testing.T) {
flagCache.getOrCreateContext(server.URL)
defer flagCache.releaseContext(server.URL)

result := isTelemetryEnabled(ctx, cfg, server.URL, httpClient)
result := isTelemetryEnabled(ctx, cfg, server.URL, httpClient, nil)

// Client override should work even when server errors
if !result {
Expand Down
4 changes: 3 additions & 1 deletion telemetry/driver_integration.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
// - host: Databricks host
// - httpClient: HTTP client for making requests
// - enableTelemetry: Client config overlay (unset = check server flag, true/false = override server)
// - extraHeaders: Additional HTTP headers for SPOG routing (e.g. x-databricks-org-id)
//
// Returns:
// - *Interceptor: Telemetry interceptor if enabled, nil otherwise
Expand All @@ -24,13 +25,14 @@ func InitializeForConnection(
host string,
httpClient *http.Client,
enableTelemetry config.ConfigValue[bool],
extraHeaders map[string]string,
) *Interceptor {
// Create telemetry config and apply client overlay
cfg := DefaultConfig()
cfg.EnableTelemetry = enableTelemetry

// Check if telemetry should be enabled
if !isTelemetryEnabled(ctx, cfg, host, httpClient) {
if !isTelemetryEnabled(ctx, cfg, host, httpClient, extraHeaders) {
return nil
}

Expand Down
Loading
Loading