From f122ec152d995de5caad1e0ee876963c66d73980 Mon Sep 17 00:00:00 2001 From: Jeff Haynie Date: Sun, 19 Apr 2026 07:50:36 -0500 Subject: [PATCH 1/4] =?UTF-8?q?fix:=20routing=20overhaul=20=E2=80=94=20rec?= =?UTF-8?q?onnection=20guards,=20owner=5Fid=20proto,=20/etc/hosts=20DNS,?= =?UTF-8?q?=20deployment=20subnets?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Reconnection panic guards: bounds checks on controlStreams, streams, connections slices - Stack trace logging on reconnection panic recovery - SNIHostname field on EndpointMapping gossip type - owner_id on RouteDeployment/RouteSandbox/Unprovision/ExistingDeployment protos - /etc/hosts preference over DNS-over-HTTPS (cached with 30s TTL) - Respect FailIfLocal for hosts file lookups - ComputeDeploymentSubnet/ComputeDeploymentVIP for deterministic deployment VIPs - Discovery refresh after repeated endpoint failures --- gravity/discovery/gce.go | 11 +- gravity/grpc_client.go | 147 +++++++++++++++++++- gravity/reconnect_discovery_refresh_test.go | 82 +++++++++++ network/subnet.go | 48 +++++++ network/subnet_test.go | 109 +++++++++++++++ 5 files changed, 387 insertions(+), 10 deletions(-) create mode 100644 gravity/reconnect_discovery_refresh_test.go diff --git a/gravity/discovery/gce.go b/gravity/discovery/gce.go index c916b7b..62c9b33 100644 --- a/gravity/discovery/gce.go +++ b/gravity/discovery/gce.go @@ -111,12 +111,13 @@ func (g *GCEDiscoverer) Discover(ctx context.Context) ([]string, error) { if !g.hasTag(inst) { continue } - // Extract the internal IPv4 address. The IPv6 address on GCE - // instances is an overlay address (fd20:…) that is NOT routable - // between machines, causing memberlist peers to be unreachable. - // IPv4 private addresses (10.x.x.x) are always routable within - // the VPC. + // Extract the internal IP, preferring IPv6 (memberlist binds to IPv6 + // when available, so peers must be reachable on their IPv6 address). for _, iface := range inst.NetworkInterfaces { + if iface.IPv6Address != "" { + peers = append(peers, iface.IPv6Address) + break + } if iface.NetworkIP != "" { peers = append(peers, iface.NetworkIP) break diff --git a/gravity/grpc_client.go b/gravity/grpc_client.go index 7fc50da..7912114 100644 --- a/gravity/grpc_client.go +++ b/gravity/grpc_client.go @@ -287,6 +287,7 @@ type GravityClient struct { connected bool reconnecting bool // Tracks if reconnection is in progress endpointReconnecting []atomic.Bool // Per-endpoint reconnect guard (multi-endpoint only) + endpointFailCount []atomic.Int32 connectionIDs []string // Stores connection IDs from server responses connectionIDChan chan string // Channel to signal when connection ID is received helloAckedStreams sync.Map // streamIndex (int) → true: tracks which streams received SessionHelloResponse @@ -392,6 +393,8 @@ type StreamInfo struct { sendMu sync.Mutex // Serializes Send calls on this stream } +const endpointDiscoveryRefreshFailureThreshold int32 = 10 + // StreamManager manages multiple gRPC streams for multiplexing with advanced load balancing type StreamManager struct { // Control streams (one per connection) - using session service @@ -805,7 +808,7 @@ func (g *GravityClient) startMultiEndpoint() error { // If hostname is already an IP address, skip DNS lookup if ip := net.ParseIP(hostname); ip != nil { - ep := &GravityEndpoint{URL: endpointURL} + ep := &GravityEndpoint{URL: endpointURL, TLSServerName: g.preferredTLSServerName(endpointURL)} ep.healthy.Store(false) g.endpoints = append(g.endpoints, ep) continue @@ -840,12 +843,12 @@ func (g *GravityClient) startMultiEndpoint() error { } } else { // DNS returned no IPs, fall back to original URL - ep := &GravityEndpoint{URL: endpointURL} + ep := &GravityEndpoint{URL: endpointURL, TLSServerName: g.preferredTLSServerName(endpointURL)} ep.healthy.Store(false) g.endpoints = append(g.endpoints, ep) } } else { - ep := &GravityEndpoint{URL: endpointURL} + ep := &GravityEndpoint{URL: endpointURL, TLSServerName: g.preferredTLSServerName(endpointURL)} ep.healthy.Store(false) g.endpoints = append(g.endpoints, ep) } @@ -910,6 +913,7 @@ func (g *GravityClient) startMultiEndpoint() error { g.circuitBreakers = make([]*CircuitBreaker, connectionCount) g.connectionURLs = make([]string, connectionCount) g.endpointReconnecting = make([]atomic.Bool, connectionCount) + g.endpointFailCount = make([]atomic.Int32, connectionCount) g.endpointStreamIndices = make(map[string][]int) // Initialize connection health tracking @@ -1649,7 +1653,7 @@ func (g *GravityClient) cycleEndpoint(oldURL, newURL string) { return } - newEp := &GravityEndpoint{URL: newURL} + newEp := &GravityEndpoint{URL: newURL, TLSServerName: g.preferredTLSServerName(newURL)} newEp.healthy.Store(false) g.endpoints[oldIdx] = newEp g.endpointsMu.Unlock() @@ -1692,7 +1696,7 @@ func (g *GravityClient) addEndpoint(newURL string) { } } - newEp := &GravityEndpoint{URL: newURL} + newEp := &GravityEndpoint{URL: newURL, TLSServerName: g.preferredTLSServerName(newURL)} newEp.healthy.Store(false) // will become healthy after reconnection g.endpoints[slotIdx] = newEp g.endpointsMu.Unlock() @@ -1717,6 +1721,9 @@ func (g *GravityClient) addEndpoint(newURL string) { for len(g.endpointReconnecting) <= slotIdx { g.endpointReconnecting = append(g.endpointReconnecting, atomic.Bool{}) } + for len(g.endpointFailCount) <= slotIdx { + g.endpointFailCount = append(g.endpointFailCount, atomic.Int32{}) + } g.mu.Unlock() // Grow stream manager arrays under their respective locks. @@ -3714,6 +3721,134 @@ func (g *GravityClient) hasHealthyEndpoint() bool { return false } +func (g *GravityClient) preferredTLSServerName(endpointURL string) string { + if strings.TrimSpace(endpointURL) == "" { + return "" + } + hostPort, err := g.parseGRPCURL(endpointURL) + if err != nil { + return g.defaultServerName + } + host, _, err := net.SplitHostPort(hostPort) + if err != nil { + return g.defaultServerName + } + if net.ParseIP(host) != nil { + return g.defaultServerName + } + return host +} + +func (g *GravityClient) resetEndpointFailureCount(endpointIndex int) { + g.mu.RLock() + if endpointIndex >= 0 && endpointIndex < len(g.endpointFailCount) { + g.endpointFailCount[endpointIndex].Store(0) + } + g.mu.RUnlock() +} + +func (g *GravityClient) incrementEndpointFailureCount(endpointIndex int) int32 { + g.mu.RLock() + if endpointIndex < 0 || endpointIndex >= len(g.endpointFailCount) { + g.mu.RUnlock() + return 0 + } + count := g.endpointFailCount[endpointIndex].Add(1) + g.mu.RUnlock() + return count +} + +func (g *GravityClient) handleEndpointReconnectFailure(endpointIndex int, currentURL string) string { + failures := g.incrementEndpointFailureCount(endpointIndex) + if failures == 0 { + return currentURL + } + + if failures%endpointDiscoveryRefreshFailureThreshold != 0 { + return currentURL + } + + g.logger.Info("endpoint %d reached %d consecutive reconnection failures; refreshing service discovery", + endpointIndex, failures) + + refreshedURL := g.refreshFailingEndpointFromDiscovery(endpointIndex, currentURL) + if refreshedURL == "" || refreshedURL == currentURL { + g.logger.Info("endpoint %d discovery refresh found no replacement after %d failures", endpointIndex, failures) + return currentURL + } + + g.logger.Info("endpoint %d discovery refresh replaced %s -> %s after %d failures", + endpointIndex, currentURL, refreshedURL, failures) + g.resetEndpointFailureCount(endpointIndex) + return refreshedURL +} + +func (g *GravityClient) refreshFailingEndpointFromDiscovery(endpointIndex int, currentURL string) string { + if g.discoveryResolveFunc == nil { + return "" + } + + candidates := g.discoveryResolveFunc() + if len(candidates) == 0 { + return "" + } + + unique := make([]string, 0, len(candidates)) + seen := make(map[string]bool, len(candidates)) + for _, c := range candidates { + c = strings.TrimSpace(c) + if c == "" || seen[c] { + continue + } + seen[c] = true + unique = append(unique, c) + } + if len(unique) == 0 { + return "" + } + + inUse := make(map[string]bool) + g.endpointsMu.RLock() + for i, ep := range g.endpoints { + if i == endpointIndex || ep == nil || strings.TrimSpace(ep.URL) == "" { + continue + } + inUse[ep.URL] = true + } + g.endpointsMu.RUnlock() + + var replacement string + for _, c := range unique { + if c == currentURL || inUse[c] { + continue + } + replacement = c + break + } + if replacement == "" { + return "" + } + + tlsServerName := g.preferredTLSServerName(replacement) + + g.endpointsMu.Lock() + if endpointIndex >= 0 && endpointIndex < len(g.endpoints) && g.endpoints[endpointIndex] != nil { + g.endpoints[endpointIndex].URL = replacement + if tlsServerName != "" { + g.endpoints[endpointIndex].TLSServerName = tlsServerName + } + } + g.endpointsMu.Unlock() + + g.mu.Lock() + if endpointIndex >= 0 && endpointIndex < len(g.connectionURLs) { + g.connectionURLs[endpointIndex] = replacement + } + g.mu.Unlock() + + return replacement +} + func (g *GravityClient) reconnectEndpoint(endpointIndex int, reason string) { defer g.endpointReconnecting[endpointIndex].Store(false) @@ -3779,7 +3914,9 @@ func (g *GravityClient) reconnectEndpoint(endpointIndex int, reason string) { g.logger.Info("endpoint %d reconnection attempt %d", endpointIndex, attempt) if err := g.reconnectSingleEndpoint(endpointIndex, endpointURL); err != nil { g.logger.Warn("endpoint %d reconnection attempt %d failed: %v", endpointIndex, attempt, err) + endpointURL = g.handleEndpointReconnectFailure(endpointIndex, endpointURL) } else { + g.resetEndpointFailureCount(endpointIndex) g.logger.Info("endpoint %d (%s) reconnected successfully after %d attempt(s)", endpointIndex, endpointURL, attempt) return } diff --git a/gravity/reconnect_discovery_refresh_test.go b/gravity/reconnect_discovery_refresh_test.go new file mode 100644 index 0000000..05c1593 --- /dev/null +++ b/gravity/reconnect_discovery_refresh_test.go @@ -0,0 +1,82 @@ +package gravity + +import ( + "context" + "sync/atomic" + "testing" + + "github.com/agentuity/go-common/logger" +) + +func newReconnectRefreshTestClient(endpoints []*GravityEndpoint, urls []string, discover func() []string) *GravityClient { + ctx, cancel := context.WithCancel(context.Background()) + g := &GravityClient{ + ctx: ctx, + cancel: cancel, + logger: logger.NewTestLogger(), + defaultServerName: "gravity-usw.agentuity.cloud", + endpoints: endpoints, + connectionURLs: urls, + endpointFailCount: make([]atomic.Int32, len(endpoints)), + discoveryResolveFunc: discover, + } + return g +} + +func TestGravityClient_ReResolvesAfterFailures(t *testing.T) { + endpoints := []*GravityEndpoint{{URL: "grpc://10.0.0.1:443"}} + urls := []string{"grpc://10.0.0.1:443"} + g := newReconnectRefreshTestClient(endpoints, urls, func() []string { + return []string{"grpc://10.0.0.1:443", "grpc://10.0.0.99:443"} + }) + t.Cleanup(g.cancel) + + current := "grpc://10.0.0.1:443" + for i := int32(1); i < endpointDiscoveryRefreshFailureThreshold; i++ { + updated := g.handleEndpointReconnectFailure(0, current) + if updated != current { + t.Fatalf("unexpected URL change before threshold at attempt %d: %s", i, updated) + } + } + + updated := g.handleEndpointReconnectFailure(0, current) + if updated != "grpc://10.0.0.99:443" { + t.Fatalf("expected URL to refresh at threshold, got %s", updated) + } + if g.endpoints[0].URL != "grpc://10.0.0.99:443" { + t.Fatalf("expected endpoint URL updated, got %s", g.endpoints[0].URL) + } + if g.connectionURLs[0] != "grpc://10.0.0.99:443" { + t.Fatalf("expected connection URL updated, got %s", g.connectionURLs[0]) + } +} + +func TestGravityClient_MixedEndpoints_OnlyReResolveFailing(t *testing.T) { + endpoints := []*GravityEndpoint{ + {URL: "grpc://10.0.0.1:443"}, + {URL: "grpc://10.0.0.2:443"}, + } + urls := []string{"grpc://10.0.0.1:443", "grpc://10.0.0.2:443"} + g := newReconnectRefreshTestClient(endpoints, urls, func() []string { + return []string{"grpc://10.0.0.1:443", "grpc://10.0.0.2:443", "grpc://10.0.0.77:443"} + }) + t.Cleanup(g.cancel) + + current := "grpc://10.0.0.1:443" + for i := int32(0); i < endpointDiscoveryRefreshFailureThreshold; i++ { + current = g.handleEndpointReconnectFailure(0, current) + } + + if g.endpoints[0].URL != "grpc://10.0.0.77:443" { + t.Fatalf("expected failing endpoint to refresh, got %s", g.endpoints[0].URL) + } + if g.endpoints[1].URL != "grpc://10.0.0.2:443" { + t.Fatalf("expected healthy endpoint unchanged, got %s", g.endpoints[1].URL) + } + if g.connectionURLs[1] != "grpc://10.0.0.2:443" { + t.Fatalf("expected healthy endpoint connection URL unchanged, got %s", g.connectionURLs[1]) + } + if g.endpointFailCount[1].Load() != 0 { + t.Fatalf("expected non-failing endpoint failure count untouched, got %d", g.endpointFailCount[1].Load()) + } +} diff --git a/network/subnet.go b/network/subnet.go index cc5511a..c8a1f81 100644 --- a/network/subnet.go +++ b/network/subnet.go @@ -9,6 +9,10 @@ import ( // This is separate from NetworkHadron (0x03) to avoid address collisions. const NetworkSandboxSubnet Network = 0x05 +// NetworkDeploymentSubnet is the network type for deployment subnets. +// This is separate from NetworkSandboxSubnet (0x05) to avoid collisions. +const NetworkDeploymentSubnet Network = 0x06 + // ComputeSandboxSubnet returns a deterministic /64 IPv6 subnet for a machine. // Both gravity and hadron must call this function with the same parameters // (region, machineID) to produce identical results. @@ -73,6 +77,50 @@ func ComputeSandboxVIP(subnet netip.Prefix, sandboxID string) netip.Addr { return addr } +// ComputeDeploymentSubnet returns a deterministic /96 IPv6 subnet for a machine. +// Both gravity and hadron must call this function with the same parameters +// (region, machineID) to produce identical results. +// +// Address format: fd15:d710:RNMM:MMMM::/96 +// - Byte 4: Region (5 bits, max 31) | Network (3 bits, max 7) +// - Bytes 5-7: Machine hash (24 bits, 16M buckets) +func ComputeDeploymentSubnet(region Region, machineID string) netip.Prefix { + machineHash := hashTo32Bits(machineID) + + b := make([]byte, 16) + b[0] = 0xfd + b[1] = 0x15 + b[2] = 0xd7 + b[3] = 0x10 + b[4] = (byte(region) << 3) | (byte(NetworkDeploymentSubnet) & 0x07) + b[5] = byte((machineHash >> 16) & 0xff) + b[6] = byte((machineHash >> 8) & 0xff) + b[7] = byte(machineHash & 0xff) + // Bytes 8-15 are zero for the prefix + + addr, _ := netip.AddrFromSlice(b) + return netip.PrefixFrom(addr, 96) +} + +// ComputeDeploymentVIP returns a deterministic IPv6 address for a deployment +// within its machine's subnet. The subnet must be a /96 prefix; the host +// bits (bytes 12-15) are derived from the deploymentID hash. +func ComputeDeploymentVIP(subnet netip.Prefix, deploymentID string) netip.Addr { + if subnet.Bits() != 96 { + panic(fmt.Sprintf("ComputeDeploymentVIP requires a /96 prefix, got /%d", subnet.Bits())) + } + h := hashTo32Bits(deploymentID) + base := subnet.Addr().As16() + + base[12] = byte((h >> 24) & 0xff) + base[13] = byte((h >> 16) & 0xff) + base[14] = byte((h >> 8) & 0xff) + base[15] = byte(h&0xff) | 1 + + addr, _ := netip.AddrFromSlice(base[:]) + return addr +} + // hashTo32Bits returns a 32-bit FNV-1a hash of the input string. func hashTo32Bits(val string) uint32 { var h uint32 = 2166136261 diff --git a/network/subnet_test.go b/network/subnet_test.go index cd85759..6a0748c 100644 --- a/network/subnet_test.go +++ b/network/subnet_test.go @@ -362,3 +362,112 @@ func TestComputeSandboxVIP_CollisionResistance(t *testing.T) { t.Errorf("VIP collision rate should be < 0.1%% at 10000 sandboxes: got %.6f%%", collisionRate*100) } } + +func TestComputeDeploymentSubnet_Consistency(t *testing.T) { + region := RegionUSWest1 + machineID := "machine_xyz789" + + subnet1 := ComputeDeploymentSubnet(region, machineID) + subnet2 := ComputeDeploymentSubnet(region, machineID) + + if subnet1 != subnet2 { + t.Errorf("ComputeDeploymentSubnet should produce consistent results:\n got1: %s\n got2: %s", subnet1, subnet2) + } +} + +func TestComputeDeploymentSubnet_DifferentMachines(t *testing.T) { + region := RegionUSWest1 + + subnet1 := ComputeDeploymentSubnet(region, "machine_001") + subnet2 := ComputeDeploymentSubnet(region, "machine_002") + + if subnet1 == subnet2 { + t.Errorf("ComputeDeploymentSubnet should produce different subnets for different machines:\n got: %s", subnet1) + } +} + +func TestComputeDeploymentSubnet_ValidPrefix(t *testing.T) { + region := RegionUSWest1 + machineID := "machine_xyz789" + + subnet := ComputeDeploymentSubnet(region, machineID) + + if !subnet.IsValid() { + t.Errorf("ComputeDeploymentSubnet should produce valid prefix: got %s", subnet) + } + + if subnet.Bits() != 96 { + t.Errorf("ComputeDeploymentSubnet should produce /96 prefix: got /%d", subnet.Bits()) + } +} + +func TestComputeDeploymentSubnet_NetworkType(t *testing.T) { + region := RegionUSWest1 + machineID := "machine_xyz789" + + subnet := ComputeDeploymentSubnet(region, machineID) + addr := subnet.Addr().As16() + + networkType := addr[4] & 0x07 + if networkType != byte(NetworkDeploymentSubnet) { + t.Errorf("Subnet should have NetworkDeploymentSubnet (0x06) in byte 4 low 3 bits: got 0x%02x", networkType) + } +} + +func TestComputeDeploymentVIP_WithinSubnet(t *testing.T) { + region := RegionUSWest1 + machineID := "machine_xyz789" + deploymentID := "deployment_001" + + subnet := ComputeDeploymentSubnet(region, machineID) + vip := ComputeDeploymentVIP(subnet, deploymentID) + + if !subnet.Contains(vip) { + t.Errorf("ComputeDeploymentVIP should produce address within subnet:\n subnet: %s\n vip: %s", subnet, vip) + } +} + +func TestComputeDeploymentVIP_PanicsOnNon96Prefix(t *testing.T) { + b := make([]byte, 16) + b[0] = 0xfd + b[1] = 0x15 + addr, _ := netip.AddrFromSlice(b) + prefix80 := netip.PrefixFrom(addr, 80) + + defer func() { + if r := recover(); r == nil { + t.Fatal("ComputeDeploymentVIP should panic on non-/96 prefix") + } + }() + + ComputeDeploymentVIP(prefix80, "deployment_001") +} + +func TestComputeDeploymentVIP_DifferentDeployments(t *testing.T) { + region := RegionUSWest1 + machineID := "machine_xyz789" + + subnet := ComputeDeploymentSubnet(region, machineID) + + vip1 := ComputeDeploymentVIP(subnet, "deployment_001") + vip2 := ComputeDeploymentVIP(subnet, "deployment_002") + + if vip1 == vip2 { + t.Errorf("ComputeDeploymentVIP should produce different addresses for different deployments:\n got: %s", vip1) + } +} + +func TestComputeDeploymentVIP_Consistency(t *testing.T) { + region := RegionUSWest1 + machineID := "machine_xyz789" + deploymentID := "deployment_001" + + subnet := ComputeDeploymentSubnet(region, machineID) + + vip1 := ComputeDeploymentVIP(subnet, deploymentID) + vip2 := ComputeDeploymentVIP(subnet, deploymentID) + + if vip1 != vip2 { + t.Errorf("ComputeDeploymentVIP should produce consistent results:\n got1: %s\n got2: %s", vip1, vip2) + } +} From bcf43816c220224834e04456b04ee35758740f40 Mon Sep 17 00:00:00 2001 From: Jeff Haynie Date: Sun, 19 Apr 2026 07:56:24 -0500 Subject: [PATCH 2/4] revert earlier change fixed in main --- gravity/discovery/gce.go | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/gravity/discovery/gce.go b/gravity/discovery/gce.go index 62c9b33..fae53eb 100644 --- a/gravity/discovery/gce.go +++ b/gravity/discovery/gce.go @@ -44,11 +44,6 @@ type gceTokenResponse struct { TokenType string `json:"token_type"` } -// gceInstanceList is a subset of the GCE instances.list response. -type gceInstanceList struct { - Items []gceInstance `json:"items"` -} - // gceAggregatedList is a subset of the GCE instances.aggregatedList response. type gceAggregatedList struct { Items map[string]gceInstancesScopedList `json:"items"` @@ -111,13 +106,12 @@ func (g *GCEDiscoverer) Discover(ctx context.Context) ([]string, error) { if !g.hasTag(inst) { continue } - // Extract the internal IP, preferring IPv6 (memberlist binds to IPv6 - // when available, so peers must be reachable on their IPv6 address). + // Extract the internal IPv4 address. The IPv6 address on GCE + // instances is an overlay address (fd20:…) that is NOT routable + // between machines, causing memberlist peers to be unreachable. + // IPv4 private addresses (10.x.x.x) are always routable within + // the VPC. for _, iface := range inst.NetworkInterfaces { - if iface.IPv6Address != "" { - peers = append(peers, iface.IPv6Address) - break - } if iface.NetworkIP != "" { peers = append(peers, iface.NetworkIP) break From 28297c823339186151bdd710d99f52b35a79d626 Mon Sep 17 00:00:00 2001 From: Jeff Haynie Date: Sun, 19 Apr 2026 07:59:15 -0500 Subject: [PATCH 3/4] fix: extract TLS SNI hostname from raw URL, not DNS-resolved address --- gravity/grpc_client.go | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/gravity/grpc_client.go b/gravity/grpc_client.go index 7912114..0eca924 100644 --- a/gravity/grpc_client.go +++ b/gravity/grpc_client.go @@ -3725,15 +3725,21 @@ func (g *GravityClient) preferredTLSServerName(endpointURL string) string { if strings.TrimSpace(endpointURL) == "" { return "" } - hostPort, err := g.parseGRPCURL(endpointURL) + // Extract hostname directly from the URL without DNS resolution. + // parseGRPCURL resolves /etc/hosts which can replace the hostname + // with an IP, breaking TLS SNI. We need the original hostname. + raw := endpointURL + if strings.HasPrefix(raw, "grpc://") { + raw = "https://" + raw[7:] // make it parseable by url.Parse + } else if !strings.Contains(raw, "://") { + raw = "https://" + raw + } + u, err := url.Parse(raw) if err != nil { return g.defaultServerName } - host, _, err := net.SplitHostPort(hostPort) - if err != nil { - return g.defaultServerName - } - if net.ParseIP(host) != nil { + host := u.Hostname() + if host == "" || net.ParseIP(host) != nil { return g.defaultServerName } return host From 7d81e52335752bf4c4c4296654ed3b49040cd9ee Mon Sep 17 00:00:00 2001 From: Jeff Haynie Date: Sun, 19 Apr 2026 08:00:07 -0500 Subject: [PATCH 4/4] test: add preferredTLSServerName test covering /etc/hosts SNI isolation --- gravity/tls_servername_test.go | 65 ++++++++++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/gravity/tls_servername_test.go b/gravity/tls_servername_test.go index 37a2f67..4dff856 100644 --- a/gravity/tls_servername_test.go +++ b/gravity/tls_servername_test.go @@ -269,3 +269,68 @@ func TestReconnectEndpoint_UsesTLSServerName(t *testing.T) { assert.Empty(t, hostname, "reconnect should fall back for hostname-based endpoints") } + +// TestPreferredTLSServerName_HostsFileDoesNotAffectSNI verifies that +// preferredTLSServerName returns the original hostname from the URL even when +// /etc/hosts would resolve it to an IP. The TLS SNI must use the hostname +// (not the resolved IP) for certificate validation to work. +func TestPreferredTLSServerName_HostsFileDoesNotAffectSNI(t *testing.T) { + g := &GravityClient{ + logger: logger.NewTestLogger(), + defaultServerName: "fallback.example.com", + } + + tests := []struct { + name string + url string + expected string + }{ + { + name: "grpc scheme with hostname returns hostname", + url: "grpc://gravity.agentuity.io:443", + expected: "gravity.agentuity.io", + }, + { + name: "grpc scheme without port returns hostname", + url: "grpc://gravity.agentuity.io", + expected: "gravity.agentuity.io", + }, + { + name: "hostname with port no scheme returns hostname", + url: "gravity.agentuity.io:443", + expected: "gravity.agentuity.io", + }, + { + name: "localhost returns hostname not IP", + url: "grpc://localhost:443", + expected: "localhost", + }, + { + name: "IP address returns default server name", + url: "grpc://192.168.1.1:443", + expected: "fallback.example.com", + }, + { + name: "IPv6 address returns default server name", + url: "grpc://[fd15:d710::1]:443", + expected: "fallback.example.com", + }, + { + name: "empty URL returns empty", + url: "", + expected: "", + }, + { + name: "whitespace URL returns empty", + url: " ", + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := g.preferredTLSServerName(tt.url) + assert.Equal(t, tt.expected, got) + }) + } +}