Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions gravity/discovery/gce.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,6 @@ type gceTokenResponse struct {
TokenType string `json:"token_type"`
}

// gceInstanceList is a subset of the GCE instances.list response.
type gceInstanceList struct {
Items []gceInstance `json:"items"`
}

// gceAggregatedList is a subset of the GCE instances.aggregatedList response.
type gceAggregatedList struct {
Items map[string]gceInstancesScopedList `json:"items"`
Expand Down
153 changes: 148 additions & 5 deletions gravity/grpc_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,7 @@ type GravityClient struct {
connected bool
reconnecting bool // Tracks if reconnection is in progress
endpointReconnecting []atomic.Bool // Per-endpoint reconnect guard (multi-endpoint only)
endpointFailCount []atomic.Int32
connectionIDs []string // Stores connection IDs from server responses
connectionIDChan chan string // Channel to signal when connection ID is received
helloAckedStreams sync.Map // streamIndex (int) → true: tracks which streams received SessionHelloResponse
Expand Down Expand Up @@ -392,6 +393,8 @@ type StreamInfo struct {
sendMu sync.Mutex // Serializes Send calls on this stream
}

const endpointDiscoveryRefreshFailureThreshold int32 = 10

// StreamManager manages multiple gRPC streams for multiplexing with advanced load balancing
type StreamManager struct {
// Control streams (one per connection) - using session service
Expand Down Expand Up @@ -805,7 +808,7 @@ func (g *GravityClient) startMultiEndpoint() error {

// If hostname is already an IP address, skip DNS lookup
if ip := net.ParseIP(hostname); ip != nil {
ep := &GravityEndpoint{URL: endpointURL}
ep := &GravityEndpoint{URL: endpointURL, TLSServerName: g.preferredTLSServerName(endpointURL)}
ep.healthy.Store(false)
g.endpoints = append(g.endpoints, ep)
continue
Expand Down Expand Up @@ -840,12 +843,12 @@ func (g *GravityClient) startMultiEndpoint() error {
}
} else {
// DNS returned no IPs, fall back to original URL
ep := &GravityEndpoint{URL: endpointURL}
ep := &GravityEndpoint{URL: endpointURL, TLSServerName: g.preferredTLSServerName(endpointURL)}
ep.healthy.Store(false)
g.endpoints = append(g.endpoints, ep)
}
} else {
ep := &GravityEndpoint{URL: endpointURL}
ep := &GravityEndpoint{URL: endpointURL, TLSServerName: g.preferredTLSServerName(endpointURL)}
ep.healthy.Store(false)
g.endpoints = append(g.endpoints, ep)
}
Expand Down Expand Up @@ -910,6 +913,7 @@ func (g *GravityClient) startMultiEndpoint() error {
g.circuitBreakers = make([]*CircuitBreaker, connectionCount)
g.connectionURLs = make([]string, connectionCount)
g.endpointReconnecting = make([]atomic.Bool, connectionCount)
g.endpointFailCount = make([]atomic.Int32, connectionCount)
g.endpointStreamIndices = make(map[string][]int)

// Initialize connection health tracking
Expand Down Expand Up @@ -1649,7 +1653,7 @@ func (g *GravityClient) cycleEndpoint(oldURL, newURL string) {
return
}

newEp := &GravityEndpoint{URL: newURL}
newEp := &GravityEndpoint{URL: newURL, TLSServerName: g.preferredTLSServerName(newURL)}
newEp.healthy.Store(false)
g.endpoints[oldIdx] = newEp
g.endpointsMu.Unlock()
Expand Down Expand Up @@ -1692,7 +1696,7 @@ func (g *GravityClient) addEndpoint(newURL string) {
}
}

newEp := &GravityEndpoint{URL: newURL}
newEp := &GravityEndpoint{URL: newURL, TLSServerName: g.preferredTLSServerName(newURL)}
newEp.healthy.Store(false) // will become healthy after reconnection
g.endpoints[slotIdx] = newEp
g.endpointsMu.Unlock()
Expand All @@ -1717,6 +1721,9 @@ func (g *GravityClient) addEndpoint(newURL string) {
for len(g.endpointReconnecting) <= slotIdx {
g.endpointReconnecting = append(g.endpointReconnecting, atomic.Bool{})
}
for len(g.endpointFailCount) <= slotIdx {
g.endpointFailCount = append(g.endpointFailCount, atomic.Int32{})
}
g.mu.Unlock()

// Grow stream manager arrays under their respective locks.
Expand Down Expand Up @@ -3714,6 +3721,140 @@ func (g *GravityClient) hasHealthyEndpoint() bool {
return false
}

func (g *GravityClient) preferredTLSServerName(endpointURL string) string {
if strings.TrimSpace(endpointURL) == "" {
return ""
}
// Extract hostname directly from the URL without DNS resolution.
// parseGRPCURL resolves /etc/hosts which can replace the hostname
// with an IP, breaking TLS SNI. We need the original hostname.
raw := endpointURL
if strings.HasPrefix(raw, "grpc://") {
raw = "https://" + raw[7:] // make it parseable by url.Parse
} else if !strings.Contains(raw, "://") {
raw = "https://" + raw
}
u, err := url.Parse(raw)
if err != nil {
return g.defaultServerName
}
host := u.Hostname()
if host == "" || net.ParseIP(host) != nil {
return g.defaultServerName
}
return host
Comment thread
coderabbitai[bot] marked this conversation as resolved.
}

func (g *GravityClient) resetEndpointFailureCount(endpointIndex int) {
g.mu.RLock()
if endpointIndex >= 0 && endpointIndex < len(g.endpointFailCount) {
g.endpointFailCount[endpointIndex].Store(0)
}
g.mu.RUnlock()
}

func (g *GravityClient) incrementEndpointFailureCount(endpointIndex int) int32 {
g.mu.RLock()
if endpointIndex < 0 || endpointIndex >= len(g.endpointFailCount) {
g.mu.RUnlock()
return 0
}
count := g.endpointFailCount[endpointIndex].Add(1)
g.mu.RUnlock()
return count
}

func (g *GravityClient) handleEndpointReconnectFailure(endpointIndex int, currentURL string) string {
failures := g.incrementEndpointFailureCount(endpointIndex)
if failures == 0 {
return currentURL
}

if failures%endpointDiscoveryRefreshFailureThreshold != 0 {
return currentURL
}

g.logger.Info("endpoint %d reached %d consecutive reconnection failures; refreshing service discovery",
endpointIndex, failures)

refreshedURL := g.refreshFailingEndpointFromDiscovery(endpointIndex, currentURL)
if refreshedURL == "" || refreshedURL == currentURL {
g.logger.Info("endpoint %d discovery refresh found no replacement after %d failures", endpointIndex, failures)
return currentURL
}

g.logger.Info("endpoint %d discovery refresh replaced %s -> %s after %d failures",
endpointIndex, currentURL, refreshedURL, failures)
g.resetEndpointFailureCount(endpointIndex)
return refreshedURL
}

func (g *GravityClient) refreshFailingEndpointFromDiscovery(endpointIndex int, currentURL string) string {
if g.discoveryResolveFunc == nil {
return ""
}

candidates := g.discoveryResolveFunc()
if len(candidates) == 0 {
return ""
}

unique := make([]string, 0, len(candidates))
seen := make(map[string]bool, len(candidates))
for _, c := range candidates {
c = strings.TrimSpace(c)
if c == "" || seen[c] {
continue
}
seen[c] = true
unique = append(unique, c)
}
if len(unique) == 0 {
return ""
}

inUse := make(map[string]bool)
g.endpointsMu.RLock()
for i, ep := range g.endpoints {
if i == endpointIndex || ep == nil || strings.TrimSpace(ep.URL) == "" {
continue
}
inUse[ep.URL] = true
}
g.endpointsMu.RUnlock()

var replacement string
for _, c := range unique {
if c == currentURL || inUse[c] {
continue
}
replacement = c
break
}
if replacement == "" {
return ""
}

tlsServerName := g.preferredTLSServerName(replacement)

g.endpointsMu.Lock()
if endpointIndex >= 0 && endpointIndex < len(g.endpoints) && g.endpoints[endpointIndex] != nil {
g.endpoints[endpointIndex].URL = replacement
if tlsServerName != "" {
g.endpoints[endpointIndex].TLSServerName = tlsServerName
}
}
g.endpointsMu.Unlock()

g.mu.Lock()
if endpointIndex >= 0 && endpointIndex < len(g.connectionURLs) {
g.connectionURLs[endpointIndex] = replacement
}
g.mu.Unlock()

return replacement
}

func (g *GravityClient) reconnectEndpoint(endpointIndex int, reason string) {
defer g.endpointReconnecting[endpointIndex].Store(false)

Expand Down Expand Up @@ -3779,7 +3920,9 @@ func (g *GravityClient) reconnectEndpoint(endpointIndex int, reason string) {
g.logger.Info("endpoint %d reconnection attempt %d", endpointIndex, attempt)
if err := g.reconnectSingleEndpoint(endpointIndex, endpointURL); err != nil {
g.logger.Warn("endpoint %d reconnection attempt %d failed: %v", endpointIndex, attempt, err)
endpointURL = g.handleEndpointReconnectFailure(endpointIndex, endpointURL)
} else {
g.resetEndpointFailureCount(endpointIndex)
g.logger.Info("endpoint %d (%s) reconnected successfully after %d attempt(s)", endpointIndex, endpointURL, attempt)
return
}
Expand Down
82 changes: 82 additions & 0 deletions gravity/reconnect_discovery_refresh_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package gravity

import (
"context"
"sync/atomic"
"testing"

"github.com/agentuity/go-common/logger"
)

func newReconnectRefreshTestClient(endpoints []*GravityEndpoint, urls []string, discover func() []string) *GravityClient {
ctx, cancel := context.WithCancel(context.Background())
g := &GravityClient{
ctx: ctx,
cancel: cancel,
logger: logger.NewTestLogger(),
defaultServerName: "gravity-usw.agentuity.cloud",
endpoints: endpoints,
connectionURLs: urls,
endpointFailCount: make([]atomic.Int32, len(endpoints)),
discoveryResolveFunc: discover,
}
return g
}

func TestGravityClient_ReResolvesAfterFailures(t *testing.T) {
endpoints := []*GravityEndpoint{{URL: "grpc://10.0.0.1:443"}}
urls := []string{"grpc://10.0.0.1:443"}
g := newReconnectRefreshTestClient(endpoints, urls, func() []string {
return []string{"grpc://10.0.0.1:443", "grpc://10.0.0.99:443"}
})
t.Cleanup(g.cancel)

current := "grpc://10.0.0.1:443"
for i := int32(1); i < endpointDiscoveryRefreshFailureThreshold; i++ {
updated := g.handleEndpointReconnectFailure(0, current)
if updated != current {
t.Fatalf("unexpected URL change before threshold at attempt %d: %s", i, updated)
}
}

updated := g.handleEndpointReconnectFailure(0, current)
if updated != "grpc://10.0.0.99:443" {
t.Fatalf("expected URL to refresh at threshold, got %s", updated)
}
if g.endpoints[0].URL != "grpc://10.0.0.99:443" {
t.Fatalf("expected endpoint URL updated, got %s", g.endpoints[0].URL)
}
if g.connectionURLs[0] != "grpc://10.0.0.99:443" {
t.Fatalf("expected connection URL updated, got %s", g.connectionURLs[0])
}
}

func TestGravityClient_MixedEndpoints_OnlyReResolveFailing(t *testing.T) {
endpoints := []*GravityEndpoint{
{URL: "grpc://10.0.0.1:443"},
{URL: "grpc://10.0.0.2:443"},
}
urls := []string{"grpc://10.0.0.1:443", "grpc://10.0.0.2:443"}
g := newReconnectRefreshTestClient(endpoints, urls, func() []string {
return []string{"grpc://10.0.0.1:443", "grpc://10.0.0.2:443", "grpc://10.0.0.77:443"}
})
t.Cleanup(g.cancel)

current := "grpc://10.0.0.1:443"
for i := int32(0); i < endpointDiscoveryRefreshFailureThreshold; i++ {
current = g.handleEndpointReconnectFailure(0, current)
}

if g.endpoints[0].URL != "grpc://10.0.0.77:443" {
t.Fatalf("expected failing endpoint to refresh, got %s", g.endpoints[0].URL)
}
if g.endpoints[1].URL != "grpc://10.0.0.2:443" {
t.Fatalf("expected healthy endpoint unchanged, got %s", g.endpoints[1].URL)
}
if g.connectionURLs[1] != "grpc://10.0.0.2:443" {
t.Fatalf("expected healthy endpoint connection URL unchanged, got %s", g.connectionURLs[1])
}
if g.endpointFailCount[1].Load() != 0 {
t.Fatalf("expected non-failing endpoint failure count untouched, got %d", g.endpointFailCount[1].Load())
}
}
65 changes: 65 additions & 0 deletions gravity/tls_servername_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -269,3 +269,68 @@ func TestReconnectEndpoint_UsesTLSServerName(t *testing.T) {
assert.Empty(t, hostname,
"reconnect should fall back for hostname-based endpoints")
}

// TestPreferredTLSServerName_HostsFileDoesNotAffectSNI verifies that
// preferredTLSServerName returns the original hostname from the URL even when
// /etc/hosts would resolve it to an IP. The TLS SNI must use the hostname
// (not the resolved IP) for certificate validation to work.
func TestPreferredTLSServerName_HostsFileDoesNotAffectSNI(t *testing.T) {
g := &GravityClient{
logger: logger.NewTestLogger(),
defaultServerName: "fallback.example.com",
}

tests := []struct {
name string
url string
expected string
}{
{
name: "grpc scheme with hostname returns hostname",
url: "grpc://gravity.agentuity.io:443",
expected: "gravity.agentuity.io",
},
{
name: "grpc scheme without port returns hostname",
url: "grpc://gravity.agentuity.io",
expected: "gravity.agentuity.io",
},
{
name: "hostname with port no scheme returns hostname",
url: "gravity.agentuity.io:443",
expected: "gravity.agentuity.io",
},
{
name: "localhost returns hostname not IP",
url: "grpc://localhost:443",
expected: "localhost",
},
{
name: "IP address returns default server name",
url: "grpc://192.168.1.1:443",
expected: "fallback.example.com",
},
{
name: "IPv6 address returns default server name",
url: "grpc://[fd15:d710::1]:443",
expected: "fallback.example.com",
},
{
name: "empty URL returns empty",
url: "",
expected: "",
},
{
name: "whitespace URL returns empty",
url: " ",
expected: "",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := g.preferredTLSServerName(tt.url)
assert.Equal(t, tt.expected, got)
})
}
}
Loading
Loading