From 7aba7c0236edcd888df9ccc49214f31cbe2b76d0 Mon Sep 17 00:00:00 2001 From: Jeff Haynie Date: Mon, 13 Apr 2026 23:17:40 -0500 Subject: [PATCH 1/7] feat: add SSE streaming support with billing extraction - Add streaming SSE parser for OpenAI and Anthropic responses - Implement StreamingResponseExtractor interface for all providers - Use http.ResponseController for efficient flushing - Add simdjson-go for fast JSON parsing of usage data - Auto-inject stream_options.include_usage for billing - Extract token usage from streaming responses for cost calculation - Add comprehensive unit tests for streaming functionality --- autorouter.go | 262 +++++++++++++++- autorouter_test.go | 276 +++++++++++++++++ billing_calculator.go | 62 ++++ examples/basic/main.go | 25 +- extractor.go | 71 +++-- go.mod | 4 + go.sum | 8 + internal/fastjson/extractor.go | 268 ++++++++++++++++ internal/fastjson/extractor_test.go | 197 ++++++++++++ providers/anthropic/provider.go | 16 +- providers/anthropic/streaming_extractor.go | 217 +++++++++++++ providers/azure/provider.go | 4 +- providers/bedrock/provider.go | 22 +- providers/bedrock/streaming_extractor.go | 73 +++++ providers/googleai/provider.go | 2 +- providers/googleai/streaming_extractor.go | 73 +++++ providers/openai_compatible/multiapi.go | 72 ++++- providers/openai_compatible/provider.go | 4 +- .../openai_compatible/streaming_extractor.go | 192 ++++++++++++ .../streaming_extractor_test.go | 147 +++++++++ providers/perplexity/provider.go | 2 +- streaming.go | 291 ++++++++++++++++++ streaming_test.go | 286 +++++++++++++++++ 23 files changed, 2494 insertions(+), 80 deletions(-) create mode 100644 billing_calculator.go create mode 100644 internal/fastjson/extractor.go create mode 100644 internal/fastjson/extractor_test.go create mode 100644 providers/anthropic/streaming_extractor.go create mode 100644 providers/bedrock/streaming_extractor.go create mode 100644 providers/googleai/streaming_extractor.go create mode 100644 providers/openai_compatible/streaming_extractor.go create mode 100644 providers/openai_compatible/streaming_extractor_test.go create mode 100644 streaming.go create mode 100644 streaming_test.go diff --git a/autorouter.go b/autorouter.go index 5c23378..bcc8f9c 100644 --- a/autorouter.go +++ b/autorouter.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "io" "net/http" @@ -17,6 +18,7 @@ type AutoRouter struct { interceptors InterceptorChain client *http.Client fallbackProvider Provider + billingCalculator *BillingCalculator } type AutoRouterOption func(*AutoRouter) @@ -45,6 +47,10 @@ func WithAutoRouterModelProviderLookup(lookup ModelProviderLookup) AutoRouterOpt return func(a *AutoRouter) { a.modelProviderLookup = lookup } } +func WithAutoRouterBillingCalculator(calculator *BillingCalculator) AutoRouterOption { + return func(a *AutoRouter) { a.billingCalculator = calculator } +} + func NewAutoRouter(opts ...AutoRouterOption) *AutoRouter { a := &AutoRouter{ registry: NewRegistry(), @@ -57,6 +63,10 @@ func NewAutoRouter(opts ...AutoRouterOption) *AutoRouter { return a } +func (a *AutoRouter) BillingCalculator() *BillingCalculator { + return a.billingCalculator +} + func (a *AutoRouter) RegisterProvider(p Provider) { a.registry.Register(p) } @@ -87,7 +97,6 @@ func (a *AutoRouter) Forward(ctx context.Context, req *http.Request) (*http.Resp } providerName := a.detector.Detect(hint) - // If no provider detected and we have a model provider lookup, try that if providerName == "" && a.modelProviderLookup != nil && model != "" { 
providerName = a.modelProviderLookup(model) } @@ -96,18 +105,15 @@ func (a *AutoRouter) Forward(ctx context.Context, req *http.Request) (*http.Resp if providerName != "" { provider, _ = a.registry.Get(providerName) if provider == nil { - // Explicit provider name was provided but not found in registry return nil, ResponseMetadata{}, ErrNoProvider } } else { - // No provider detected, use fallback provider = a.fallbackProvider if provider == nil { return nil, ResponseMetadata{}, ErrNoProvider } } - // Strip provider prefix from model name (e.g., "openai/gpt-4" -> "gpt-4") if raw != nil { if strippedModel, hasPrefix := stripProviderPrefix(model); hasPrefix { raw["model"] = strippedModel @@ -120,7 +126,6 @@ func (a *AutoRouter) Forward(ctx context.Context, req *http.Request) (*http.Resp } } - // Detect API type: path takes precedence, then body+provider detection apiType := DetectAPITypeFromPath(req.URL.Path) if apiType == "" { apiType = DetectAPITypeFromBodyAndProvider(body, providerName) @@ -190,19 +195,222 @@ func (a *AutoRouter) roundTrip(provider Provider, req *http.Request) (*http.Resp return resp, respMeta, rawBody, nil } +func (a *AutoRouter) ForwardStreaming(ctx context.Context, req *http.Request, w http.ResponseWriter) (ResponseMetadata, error) { + body, err := io.ReadAll(req.Body) + if err != nil { + return ResponseMetadata{}, err + } + req.Body.Close() + + var raw map[string]any + var model string + if err := json.Unmarshal(body, &raw); err == nil { + if m, ok := raw["model"].(string); ok { + model = m + } + } + + hint := ProviderHint{ + Model: model, + Headers: req.Header, + } + providerName := a.detector.Detect(hint) + + if providerName == "" && a.modelProviderLookup != nil && model != "" { + providerName = a.modelProviderLookup(model) + } + + var provider Provider + if providerName != "" { + provider, _ = a.registry.Get(providerName) + if provider == nil { + return ResponseMetadata{}, ErrNoProvider + } + } else { + provider = a.fallbackProvider + if provider == nil { + return ResponseMetadata{}, ErrNoProvider + } + } + + if raw != nil { + if strippedModel, hasPrefix := stripProviderPrefix(model); hasPrefix { + raw["model"] = strippedModel + model = strippedModel + } + if a.billingCalculator != nil { + if stream, ok := raw["stream"].(bool); ok && stream { + raw["stream_options"] = map[string]any{"include_usage": true} + } + } + var err error + body, err = json.Marshal(raw) + if err != nil { + return ResponseMetadata{}, fmt.Errorf("failed to marshal request body: %w", err) + } + } + + apiType := DetectAPITypeFromPath(req.URL.Path) + if apiType == "" { + apiType = DetectAPITypeFromBodyAndProvider(body, providerName) + } + + meta, _, err := provider.BodyParser().Parse(io.NopCloser(bytes.NewReader(body))) + if err != nil { + return ResponseMetadata{}, err + } + + if meta.Custom == nil { + meta.Custom = make(map[string]any) + } + meta.Custom["api_type"] = apiType + meta.Custom["provider"] = providerName + + upstreamURL, err := provider.URLResolver().Resolve(meta) + if err != nil { + return ResponseMetadata{}, err + } + + upstreamReq, err := http.NewRequestWithContext(ctx, req.Method, upstreamURL.String(), bytes.NewReader(body)) + if err != nil { + return ResponseMetadata{}, err + } + + for k, v := range req.Header { + upstreamReq.Header[k] = v + } + + if err := provider.RequestEnricher().Enrich(upstreamReq, meta, body); err != nil { + return ResponseMetadata{}, err + } + + ctxValue := MetaContextValue{Meta: meta, RawBody: body} + upstreamReq = 
upstreamReq.WithContext(context.WithValue(upstreamReq.Context(), MetaContextKey{}, ctxValue)) + + upstreamResp, err := a.client.Do(upstreamReq) + if err != nil { + return ResponseMetadata{}, err + } + defer upstreamResp.Body.Close() + + for k, v := range upstreamResp.Header { + if k != "Content-Length" { + w.Header()[k] = v + } + } + + w.WriteHeader(upstreamResp.StatusCode) + + rc := http.NewResponseController(w) + + extractor := provider.ResponseExtractor() + streamExtractor, isStreaming := extractor.(StreamingResponseExtractor) + + var respMeta ResponseMetadata + + if isStreaming && streamExtractor.IsStreamingResponse(upstreamResp) { + respMeta, err = streamExtractor.ExtractStreamingWithController(upstreamResp, w, rc) + if err != nil { + return respMeta, err + } + } else { + respMeta, err = a.streamResponseWithFlush(upstreamResp.Body, w, rc, extractor) + if err != nil { + return respMeta, err + } + } + + if a.billingCalculator != nil { + a.billingCalculator.Calculate(meta, &respMeta) + } + + return respMeta, nil +} + +func (a *AutoRouter) streamResponseWithFlush(r io.Reader, w http.ResponseWriter, rc *http.ResponseController, extractor ResponseExtractor) (ResponseMetadata, error) { + var buf bytes.Buffer + tee := io.TeeReader(r, &buf) + + respMeta, _, err := extractor.Extract(&http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(tee), + }) + if err != nil { + return respMeta, err + } + + readBuf := make([]byte, 1024*512) + for { + n, err := buf.Read(readBuf) + if err != nil { + if err == io.EOF { + if n > 0 { + if _, writeErr := w.Write(readBuf[:n]); writeErr != nil { + return respMeta, fmt.Errorf("write chunk: %w", writeErr) + } + } + break + } + if errors.Is(err, context.Canceled) { + break + } + return respMeta, fmt.Errorf("copy chunk: %w", err) + } + if n == 0 { + break + } + if _, writeErr := w.Write(readBuf[:n]); writeErr != nil { + return respMeta, fmt.Errorf("write chunk: %w", writeErr) + } + if flushErr := rc.Flush(); flushErr != nil { + return respMeta, fmt.Errorf("flush: %w", flushErr) + } + } + + return respMeta, nil +} + func (a *AutoRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) { - resp, meta, err := a.Forward(r.Context(), r) + body, err := io.ReadAll(r.Body) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } - defer resp.Body.Close() + r.Body.Close() - body, err := io.ReadAll(resp.Body) + var raw map[string]any + var isStreamingRequest bool + if err := json.Unmarshal(body, &raw); err == nil { + if stream, ok := raw["stream"].(bool); ok && stream { + isStreamingRequest = true + } + } + + r.Body = io.NopCloser(bytes.NewReader(body)) + + if isStreamingRequest { + meta, err := a.ForwardStreaming(r.Context(), r, w) + if err != nil { + if !headerSent(w) { + http.Error(w, err.Error(), http.StatusInternalServerError) + } + return + } + + if billing, ok := meta.Custom["billing_result"].(BillingResult); ok { + w.Header().Set("X-Gateway-Cost", fmt.Sprintf("%.6f", billing.TotalCost)) + w.Header().Set("X-Gateway-Prompt-Tokens", fmt.Sprintf("%d", billing.PromptTokens)) + w.Header().Set("X-Gateway-Completion-Tokens", fmt.Sprintf("%d", billing.CompletionTokens)) + } + return + } + + resp, meta, err := a.Forward(r.Context(), r) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } + defer resp.Body.Close() for k, v := range resp.Header { w.Header()[k] = v @@ -215,9 +423,43 @@ func (a *AutoRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) { } w.WriteHeader(resp.StatusCode) - if _, err := 
w.Write(body); err != nil { - // Headers already sent, can't report error to client + + rc := http.NewResponseController(w) + readBuf := make([]byte, 1024*512) + for { + n, err := resp.Body.Read(readBuf) + if err != nil { + if err == io.EOF { + if n > 0 { + if _, writeErr := w.Write(readBuf[:n]); writeErr != nil { + return + } + } + break + } + if errors.Is(err, context.Canceled) { + break + } + return + } + if n == 0 { + break + } + if _, writeErr := w.Write(readBuf[:n]); writeErr != nil { + return + } + _ = rc.Flush() + } +} + +func headerSent(w http.ResponseWriter) bool { + type headerChecker interface { + WroteHeader() bool + } + if hc, ok := w.(headerChecker); ok { + return hc.WroteHeader() } + return false } var ErrNoProvider = &ProviderError{Message: "no provider available for request"} diff --git a/autorouter_test.go b/autorouter_test.go index f4cc319..e48d391 100644 --- a/autorouter_test.go +++ b/autorouter_test.go @@ -66,6 +66,32 @@ func (m *mockExtractor) Extract(resp *http.Response) (ResponseMetadata, []byte, return m.extractFn(resp) } +type mockStreamingProvider struct { + *mockProvider + streamingExtractor *mockStreamingExtractor +} + +func (m *mockStreamingProvider) ResponseExtractor() ResponseExtractor { + return m.streamingExtractor +} + +type mockStreamingExtractor struct { + *mockExtractor + isStreaming bool + extractStreamingFn func(resp *http.Response, w http.ResponseWriter, rc *http.ResponseController) (ResponseMetadata, error) +} + +func (m *mockStreamingExtractor) IsStreamingResponse(resp *http.Response) bool { + return m.isStreaming +} + +func (m *mockStreamingExtractor) ExtractStreamingWithController(resp *http.Response, w http.ResponseWriter, rc *http.ResponseController) (ResponseMetadata, error) { + if m.extractStreamingFn != nil { + return m.extractStreamingFn(resp, w, rc) + } + return ResponseMetadata{}, nil +} + type mockDetector struct{ detectFn func(ProviderHint) string } func (m *mockDetector) Detect(hint ProviderHint) string { return m.detectFn(hint) } @@ -390,3 +416,253 @@ func TestAutoRouter_PreservesModelWithoutPrefix(t *testing.T) { t.Errorf("StatusCode = %d, want 200", resp.StatusCode) } } + +func TestAutoRouter_StreamingInjectsStreamOptions(t *testing.T) { + var receivedBody map[string]any + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + json.Unmarshal(body, &receivedBody) + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + w.Write([]byte("data: {\"id\":\"test\"}\n\ndata: [DONE]\n\n")) + })) + defer upstream.Close() + + provider := &mockStreamingProvider{ + mockProvider: &mockProvider{ + name: "test", + parseFn: func(body io.ReadCloser) (BodyMetadata, []byte, error) { + data, _ := io.ReadAll(body) + return BodyMetadata{Model: "gpt-4", Stream: true}, data, nil + }, + enrichFn: func(req *http.Request, meta BodyMetadata, body []byte) error { return nil }, + resolveFn: func(meta BodyMetadata) (*url.URL, error) { + return url.Parse(upstream.URL) + }, + }, + streamingExtractor: &mockStreamingExtractor{ + isStreaming: true, + extractStreamingFn: func(resp *http.Response, w http.ResponseWriter, rc *http.ResponseController) (ResponseMetadata, error) { + io.Copy(w, resp.Body) + rc.Flush() + return ResponseMetadata{ID: "test"}, nil + }, + }, + } + provider.mockProvider.extractFn = func(resp *http.Response) (ResponseMetadata, []byte, error) { + body, _ := io.ReadAll(resp.Body) + return ResponseMetadata{ID: "test"}, body, nil + } + + billing := 
NewBillingCalculator( + func(provider, model string) (CostInfo, bool) { + return CostInfo{Input: 1, Output: 2}, true + }, + nil, + ) + + router := NewAutoRouter( + WithAutoRouterDetector(ProviderDetectorFunc(func(hint ProviderHint) string { return "test" })), + WithAutoRouterBillingCalculator(billing), + ) + router.RegisterProvider(provider) + + req := httptest.NewRequest("POST", "/", bytes.NewReader([]byte(`{"model":"gpt-4","stream":true,"messages":[{"role":"user","content":"Hello"}]}`))) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("StatusCode = %d, want 200", w.Code) + } + + streamOpts, ok := receivedBody["stream_options"].(map[string]any) + if !ok { + t.Fatal("stream_options not injected") + } + if includeUsage, ok := streamOpts["include_usage"].(bool); !ok || !includeUsage { + t.Errorf("stream_options.include_usage = %v, want true", streamOpts["include_usage"]) + } +} + +func TestAutoRouter_StreamingOverridesStreamOptions(t *testing.T) { + var receivedBody map[string]any + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + json.Unmarshal(body, &receivedBody) + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + w.Write([]byte("data: {\"id\":\"test\"}\n\ndata: [DONE]\n\n")) + })) + defer upstream.Close() + + provider := &mockStreamingProvider{ + mockProvider: &mockProvider{ + name: "test", + parseFn: func(body io.ReadCloser) (BodyMetadata, []byte, error) { + data, _ := io.ReadAll(body) + return BodyMetadata{Model: "gpt-4", Stream: true}, data, nil + }, + enrichFn: func(req *http.Request, meta BodyMetadata, body []byte) error { return nil }, + resolveFn: func(meta BodyMetadata) (*url.URL, error) { + return url.Parse(upstream.URL) + }, + }, + streamingExtractor: &mockStreamingExtractor{ + isStreaming: true, + extractStreamingFn: func(resp *http.Response, w http.ResponseWriter, rc *http.ResponseController) (ResponseMetadata, error) { + io.Copy(w, resp.Body) + rc.Flush() + return ResponseMetadata{ID: "test"}, nil + }, + }, + } + provider.mockProvider.extractFn = func(resp *http.Response) (ResponseMetadata, []byte, error) { + body, _ := io.ReadAll(resp.Body) + return ResponseMetadata{ID: "test"}, body, nil + } + + billing := NewBillingCalculator( + func(provider, model string) (CostInfo, bool) { + return CostInfo{Input: 1, Output: 2}, true + }, + nil, + ) + + router := NewAutoRouter( + WithAutoRouterDetector(ProviderDetectorFunc(func(hint ProviderHint) string { return "test" })), + WithAutoRouterBillingCalculator(billing), + ) + router.RegisterProvider(provider) + + req := httptest.NewRequest("POST", "/", bytes.NewReader([]byte(`{"model":"gpt-4","stream":true,"stream_options":{"include_usage":false},"messages":[{"role":"user","content":"Hello"}]}`))) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("StatusCode = %d, want 200", w.Code) + } + + streamOpts, ok := receivedBody["stream_options"].(map[string]any) + if !ok { + t.Fatal("stream_options not present") + } + if includeUsage, ok := streamOpts["include_usage"].(bool); !ok || !includeUsage { + t.Errorf("stream_options.include_usage = %v, want true (should override false)", streamOpts["include_usage"]) + } +} + +func TestAutoRouter_StreamingNoBillingNoStreamOptions(t *testing.T) { + var receivedBody map[string]any + 
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + json.Unmarshal(body, &receivedBody) + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + w.Write([]byte("data: {\"id\":\"test\"}\n\ndata: [DONE]\n\n")) + })) + defer upstream.Close() + + provider := &mockStreamingProvider{ + mockProvider: &mockProvider{ + name: "test", + parseFn: func(body io.ReadCloser) (BodyMetadata, []byte, error) { + data, _ := io.ReadAll(body) + return BodyMetadata{Model: "gpt-4", Stream: true}, data, nil + }, + enrichFn: func(req *http.Request, meta BodyMetadata, body []byte) error { return nil }, + resolveFn: func(meta BodyMetadata) (*url.URL, error) { + return url.Parse(upstream.URL) + }, + }, + streamingExtractor: &mockStreamingExtractor{ + isStreaming: true, + extractStreamingFn: func(resp *http.Response, w http.ResponseWriter, rc *http.ResponseController) (ResponseMetadata, error) { + io.Copy(w, resp.Body) + rc.Flush() + return ResponseMetadata{ID: "test"}, nil + }, + }, + } + provider.mockProvider.extractFn = func(resp *http.Response) (ResponseMetadata, []byte, error) { + body, _ := io.ReadAll(resp.Body) + return ResponseMetadata{ID: "test"}, body, nil + } + + router := NewAutoRouter( + WithAutoRouterDetector(ProviderDetectorFunc(func(hint ProviderHint) string { return "test" })), + ) + router.RegisterProvider(provider) + + req := httptest.NewRequest("POST", "/", bytes.NewReader([]byte(`{"model":"gpt-4","stream":true,"messages":[{"role":"user","content":"Hello"}]}`))) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("StatusCode = %d, want 200", w.Code) + } + + if _, ok := receivedBody["stream_options"]; ok { + t.Error("stream_options should not be injected when no billing calculator") + } +} + +func TestAutoRouter_NonStreamingNoStreamOptions(t *testing.T) { + var receivedBody map[string]any + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + json.Unmarshal(body, &receivedBody) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"id":"test"}`)) + })) + defer upstream.Close() + + provider := &mockProvider{ + name: "test", + parseFn: func(body io.ReadCloser) (BodyMetadata, []byte, error) { + data, _ := io.ReadAll(body) + return BodyMetadata{Model: "gpt-4"}, data, nil + }, + enrichFn: func(req *http.Request, meta BodyMetadata, body []byte) error { return nil }, + resolveFn: func(meta BodyMetadata) (*url.URL, error) { + return url.Parse(upstream.URL) + }, + extractFn: func(resp *http.Response) (ResponseMetadata, []byte, error) { + body, _ := io.ReadAll(resp.Body) + return ResponseMetadata{ID: "test"}, body, nil + }, + } + + billing := NewBillingCalculator( + func(provider, model string) (CostInfo, bool) { + return CostInfo{Input: 1, Output: 2}, true + }, + nil, + ) + + router := NewAutoRouter( + WithAutoRouterDetector(ProviderDetectorFunc(func(hint ProviderHint) string { return "test" })), + WithAutoRouterBillingCalculator(billing), + ) + router.RegisterProvider(provider) + + req := httptest.NewRequest("POST", "/", bytes.NewReader([]byte(`{"model":"gpt-4","messages":[{"role":"user","content":"Hello"}]}`))) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("StatusCode = 
%d, want 200", w.Code) + } + + if _, ok := receivedBody["stream_options"]; ok { + t.Error("stream_options should not be injected for non-streaming requests") + } +} diff --git a/billing_calculator.go b/billing_calculator.go new file mode 100644 index 0000000..ca610c6 --- /dev/null +++ b/billing_calculator.go @@ -0,0 +1,62 @@ +package llmproxy + +type BillingCalculator struct { + lookup CostLookup + onResult func(BillingResult) +} + +func NewBillingCalculator(lookup CostLookup, onResult func(BillingResult)) *BillingCalculator { + return &BillingCalculator{ + lookup: lookup, + onResult: onResult, + } +} + +func (c *BillingCalculator) Calculate(meta BodyMetadata, respMeta *ResponseMetadata) *BillingResult { + var provider string + if meta.Custom != nil { + if p, ok := meta.Custom["provider"].(string); ok && p != "" { + provider = p + } + } + if provider == "" { + provider = DetectProviderFromModel(meta.Model) + } + + costInfo, found := c.lookup(provider, meta.Model) + if !found { + costInfo, found = c.lookup("", meta.Model) + } + + if !found { + return nil + } + + var cacheUsage *CacheUsage + if cu, ok := respMeta.Custom["cache_usage"]; ok { + if usage, ok := cu.(CacheUsage); ok { + cacheUsage = &usage + } + } + + result := CalculateCost(provider, meta.Model, costInfo, respMeta.Usage.PromptTokens, respMeta.Usage.CompletionTokens, cacheUsage) + + if respMeta.Custom == nil { + respMeta.Custom = make(map[string]any) + } + respMeta.Custom["billing_result"] = result + + if c.onResult != nil { + c.onResult(result) + } + + return &result +} + +func (c *BillingCalculator) Lookup() CostLookup { + return c.lookup +} + +func (c *BillingCalculator) OnResult() func(BillingResult) { + return c.onResult +} diff --git a/examples/basic/main.go b/examples/basic/main.go index 2d3d051..f2e7962 100644 --- a/examples/basic/main.go +++ b/examples/basic/main.go @@ -208,7 +208,7 @@ func main() { llmproxy.WithAutoRouterInterceptor(tracingInterceptor), llmproxy.WithAutoRouterInterceptor(loggingInterceptor), llmproxy.WithAutoRouterInterceptor(interceptors.NewMetrics(metrics)), - llmproxy.WithAutoRouterInterceptor(interceptors.NewResponseHeaderBan("Openai-Organization", "Openai-Project", "Set-Cookie")), + llmproxy.WithAutoRouterInterceptor(interceptors.NewResponseHeaderBan("Openai-Organization", "Openai-Project", "Anthropic-Organization-Id", "Set-Cookie")), llmproxy.WithAutoRouterInterceptor(interceptors.NewAddRequestHeader(interceptors.NewHeader("User-Agent", "Agentuity AI Gateway/1.0"))), llmproxy.WithAutoRouterInterceptor(interceptors.NewAddResponseHeader(interceptors.NewHeader("Server", "Agentuity AI Gateway/1.0"))), llmproxy.WithAutoRouterFallbackProvider(providers[0]), @@ -219,9 +219,11 @@ func main() { } if costLookup != nil { - opts = append(opts, llmproxy.WithAutoRouterInterceptor(interceptors.NewBilling(costLookup, func(r llmproxy.BillingResult) { + billingCallback := func(r llmproxy.BillingResult) { logr.Info("Billing: provider=%s model=%s tokens=%d/%d cost=$%.6f", r.Provider, r.Model, r.PromptTokens, r.CompletionTokens, r.TotalCost) - }))) + } + opts = append(opts, llmproxy.WithAutoRouterInterceptor(interceptors.NewBilling(costLookup, billingCallback))) + opts = append(opts, llmproxy.WithAutoRouterBillingCalculator(llmproxy.NewBillingCalculator(costLookup, billingCallback))) } router := llmproxy.NewAutoRouter(opts...) 
@@ -268,6 +270,23 @@ func main() { logr.Info(" curl -X POST http://localhost:8080/ \\") logr.Info(" -H 'Content-Type: application/json' \\") logr.Info(" -d '{\"model\":\"gpt-4o\",\"input\":\"Hello\"}'") + logr.Info("") + logr.Info("Streaming examples:") + logr.Info(" # Streaming Chat Completions with usage (OpenAI)") + logr.Info(" # Note: stream_options.include_usage is required to get token counts") + logr.Info(" curl -X POST http://localhost:8080/ \\") + logr.Info(" -H 'Content-Type: application/json' \\") + logr.Info(" -d '{\"model\":\"gpt-4\",\"stream\":true,\"stream_options\":{\"include_usage\":true},\"messages\":[{\"role\":\"user\",\"content\":\"Tell me a short story\"}]}'") + logr.Info("") + logr.Info(" # Streaming Messages (Anthropic)") + logr.Info(" curl -X POST http://localhost:8080/ \\") + logr.Info(" -H 'Content-Type: application/json' \\") + logr.Info(" -d '{\"model\":\"claude-3-opus\",\"stream\":true,\"max_tokens\":1024,\"messages\":[{\"role\":\"user\",\"content\":\"Tell me a joke\"}]}'") + logr.Info("") + logr.Info(" # Streaming with provider prefix") + logr.Info(" curl -X POST http://localhost:8080/ \\") + logr.Info(" -H 'Content-Type: application/json' \\") + logr.Info(" -d '{\"model\":\"openai/gpt-4\",\"stream\":true,\"stream_options\":{\"include_usage\":true},\"messages\":[{\"role\":\"user\",\"content\":\"Hello\"}]}'") if err := http.ListenAndServe(":8080", nil); err != nil { log.Fatalf("server error: %v", err) diff --git a/extractor.go b/extractor.go index dead380..bd0c724 100644 --- a/extractor.go +++ b/extractor.go @@ -1,29 +1,52 @@ package llmproxy -import "net/http" - -// ResponseExtractor parses an upstream provider response and extracts metadata. -// -// Implementations handle provider-specific response formats and map them -// to the common ResponseMetadata structure. This allows the proxy to track -// token usage, costs, and other metrics in a provider-agnostic way. -// -// The extractor must return the raw response body bytes so the proxy can -// re-attach them to the response for the caller. This preserves any -// custom/unsupported fields in the original JSON. +import ( + "io" + "net/http" +) + type ResponseExtractor interface { - // Extract parses the HTTP response and returns unified metadata. - // - // The method reads and consumes the response body, parses it for metadata, - // and returns both the metadata and the raw body bytes. The proxy will - // re-attach the raw bytes to the response so the caller can read them. - // - // Parameters: - // - resp: The HTTP response from the upstream provider - // - // Returns: - // - metadata: Parsed response metadata (tokens, model, etc.) 
- // - rawBody: The original response body bytes (must be returned for forwarding) - // - error: Any parsing error Extract(resp *http.Response) (metadata ResponseMetadata, rawBody []byte, err error) } + +type StreamingResponseExtractor interface { + ResponseExtractor + ExtractStreamingWithController(resp *http.Response, w http.ResponseWriter, rc *http.ResponseController) (ResponseMetadata, error) + IsStreamingResponse(resp *http.Response) bool +} + +type StreamingHandler interface { + HandleStream(resp *http.Response, w http.ResponseWriter, meta BodyMetadata) (ResponseMetadata, error) +} + +type DefaultStreamingHandler struct { + extractor StreamingResponseExtractor +} + +func NewDefaultStreamingHandler(extractor StreamingResponseExtractor) *DefaultStreamingHandler { + return &DefaultStreamingHandler{extractor: extractor} +} + +func (h *DefaultStreamingHandler) HandleStream(resp *http.Response, w http.ResponseWriter, meta BodyMetadata) (ResponseMetadata, error) { + rc := http.NewResponseController(w) + return h.extractor.ExtractStreamingWithController(resp, w, rc) +} + +type TeeReader struct { + r io.Reader + w io.Writer +} + +func NewTeeReader(r io.Reader, w io.Writer) *TeeReader { + return &TeeReader{r: r, w: w} +} + +func (t *TeeReader) Read(p []byte) (n int, err error) { + n, err = t.r.Read(p) + if n > 0 { + if _, writeErr := t.w.Write(p[:n]); writeErr != nil { + return n, writeErr + } + } + return +} diff --git a/go.mod b/go.mod index 73bb9a3..ceab32e 100644 --- a/go.mod +++ b/go.mod @@ -6,5 +6,9 @@ require go.opentelemetry.io/otel/trace v1.43.0 require ( github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/klauspost/compress v1.15.15 // indirect + github.com/klauspost/cpuid/v2 v2.2.3 // indirect + github.com/minio/simdjson-go v0.4.5 // indirect go.opentelemetry.io/otel v1.43.0 // indirect + golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e // indirect ) diff --git a/go.sum b/go.sum index 86176c1..6d4b853 100644 --- a/go.sum +++ b/go.sum @@ -4,6 +4,12 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/klauspost/compress v1.15.15 h1:EF27CXIuDsYJ6mmvtBRlEuB2UVOqHG1tAXgZ7yIO+lw= +github.com/klauspost/compress v1.15.15/go.mod h1:ZcK2JAFqKOpnBlxcLsJzYfrS9X1akm9fHZNnD9+Vo/4= +github.com/klauspost/cpuid/v2 v2.2.3 h1:sxCkb+qR91z4vsqw4vGGZlDgPz3G7gjaLyK3V8y70BU= +github.com/klauspost/cpuid/v2 v2.2.3/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY= +github.com/minio/simdjson-go v0.4.5 h1:r4IQwjRGmWCQ2VeMc7fGiilu1z5du0gJ/I/FsKwgo5A= +github.com/minio/simdjson-go v0.4.5/go.mod h1:eoNz0DcLQRyEDeaPr4Ru6JpjlZPzbA0IodxVJk8lO8E= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= @@ -12,5 +18,7 @@ go.opentelemetry.io/otel v1.43.0 h1:mYIM03dnh5zfN7HautFE4ieIig9amkNANT+xcVxAj9I= go.opentelemetry.io/otel v1.43.0/go.mod h1:JuG+u74mvjvcm8vj8pI5XiHy1zDeoCS2LB1spIq7Ay0= go.opentelemetry.io/otel/trace v1.43.0 h1:BkNrHpup+4k4w+ZZ86CZoHHEkohws8AY+WTX09nk+3A= go.opentelemetry.io/otel/trace v1.43.0/go.mod h1:/QJhyVBUUswCphDVxq+8mld+AvhXZLhe+8WVFxiFff0= +golang.org/x/sys 
v0.0.0-20220704084225-05e143d24a9e h1:CsOuNlbOuf0mzxJIefr6Q4uAUetRUwZE4qt7VfzP+xo= +golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/fastjson/extractor.go b/internal/fastjson/extractor.go new file mode 100644 index 0000000..0b9d7cc --- /dev/null +++ b/internal/fastjson/extractor.go @@ -0,0 +1,268 @@ +package fastjson + +import ( + "encoding/json" + "io" + + "github.com/agentuity/llmproxy" + "github.com/minio/simdjson-go" +) + +type UsageExtractor struct { + simdSupported bool +} + +func NewUsageExtractor() *UsageExtractor { + return &UsageExtractor{ + simdSupported: simdjson.SupportedCPU(), + } +} + +func (e *UsageExtractor) ExtractOpenAI(data []byte) (*llmproxy.Usage, *llmproxy.CacheUsage, error) { + if e.simdSupported && len(data) > 1024 { + return e.extractOpenAISimd(data) + } + return e.extractOpenAIStd(data) +} + +func (e *UsageExtractor) extractOpenAISimd(data []byte) (*llmproxy.Usage, *llmproxy.CacheUsage, error) { + pj, err := simdjson.Parse(data, nil, simdjson.WithCopyStrings(false)) + if err != nil { + return nil, nil, err + } + + iter := pj.Iter() + + var elem *simdjson.Element + + usageElem, err := iter.FindElement(elem, "usage") + if err != nil || usageElem == nil { + return &llmproxy.Usage{}, nil, nil + } + + usage := &llmproxy.Usage{} + var cacheUsage *llmproxy.CacheUsage + + usageIter := usageElem.Iter + obj, err := usageIter.Object(nil) + if err != nil { + return usage, nil, nil + } + + var tmpIter simdjson.Iter + for { + name, t, err := obj.NextElement(&tmpIter) + if err == io.EOF { + break + } + if err != nil { + break + } + + switch name { + case "prompt_tokens": + if t == simdjson.TypeInt { + v, _ := tmpIter.Int() + usage.PromptTokens = int(v) + } + case "completion_tokens": + if t == simdjson.TypeInt { + v, _ := tmpIter.Int() + usage.CompletionTokens = int(v) + } + case "total_tokens": + if t == simdjson.TypeInt { + v, _ := tmpIter.Int() + usage.TotalTokens = int(v) + } + case "prompt_tokens_details": + if t == simdjson.TypeObject { + cacheUsage = e.extractOpenAICacheUsage(&tmpIter) + } + } + } + + if usage.TotalTokens == 0 { + usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens + } + + return usage, cacheUsage, nil +} + +func (e *UsageExtractor) extractOpenAICacheUsage(iter *simdjson.Iter) *llmproxy.CacheUsage { + obj, err := iter.Object(nil) + if err != nil { + return nil + } + + cacheUsage := &llmproxy.CacheUsage{} + var tmpIter simdjson.Iter + found := false + + for { + name, t, err := obj.NextElement(&tmpIter) + if err == io.EOF { + break + } + if err != nil { + break + } + + if name == "cached_tokens" && t == simdjson.TypeInt { + v, _ := tmpIter.Int() + cacheUsage.CachedTokens = int(v) + found = true + } + } + + if !found { + return nil + } + return cacheUsage +} + +func (e *UsageExtractor) extractOpenAIStd(data []byte) (*llmproxy.Usage, *llmproxy.CacheUsage, error) { + var resp struct { + Usage *struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` + PromptTokensDetails *struct { + CachedTokens int `json:"cached_tokens"` + } `json:"prompt_tokens_details"` + } `json:"usage"` + } + + if err := json.Unmarshal(data, &resp); err != nil { + return nil, nil, err + } + + if resp.Usage == nil { + return &llmproxy.Usage{}, nil, nil + } + + usage 
:= &llmproxy.Usage{ + PromptTokens: resp.Usage.PromptTokens, + CompletionTokens: resp.Usage.CompletionTokens, + TotalTokens: resp.Usage.TotalTokens, + } + + var cacheUsage *llmproxy.CacheUsage + if resp.Usage.PromptTokensDetails != nil && resp.Usage.PromptTokensDetails.CachedTokens > 0 { + cacheUsage = &llmproxy.CacheUsage{ + CachedTokens: resp.Usage.PromptTokensDetails.CachedTokens, + } + } + + return usage, cacheUsage, nil +} + +func (e *UsageExtractor) ExtractAnthropic(data []byte) (*llmproxy.Usage, *llmproxy.CacheUsage, error) { + if e.simdSupported && len(data) > 1024 { + return e.extractAnthropicSimd(data) + } + return e.extractAnthropicStd(data) +} + +func (e *UsageExtractor) extractAnthropicSimd(data []byte) (*llmproxy.Usage, *llmproxy.CacheUsage, error) { + pj, err := simdjson.Parse(data, nil, simdjson.WithCopyStrings(false)) + if err != nil { + return nil, nil, err + } + + iter := pj.Iter() + + var elem *simdjson.Element + + usageElem, err := iter.FindElement(elem, "usage") + if err != nil || usageElem == nil { + return &llmproxy.Usage{}, nil, nil + } + + usage := &llmproxy.Usage{} + cacheUsage := &llmproxy.CacheUsage{} + + usageIter := usageElem.Iter + obj, err := usageIter.Object(nil) + if err != nil { + return usage, nil, nil + } + + var tmpIter simdjson.Iter + for { + name, t, err := obj.NextElement(&tmpIter) + if err == io.EOF { + break + } + if err != nil { + break + } + + switch name { + case "input_tokens": + if t == simdjson.TypeInt { + v, _ := tmpIter.Int() + usage.PromptTokens = int(v) + } + case "output_tokens": + if t == simdjson.TypeInt { + v, _ := tmpIter.Int() + usage.CompletionTokens = int(v) + } + case "cache_creation_input_tokens": + if t == simdjson.TypeInt { + v, _ := tmpIter.Int() + cacheUsage.CacheCreationInputTokens = int(v) + } + case "cache_read_input_tokens": + if t == simdjson.TypeInt { + v, _ := tmpIter.Int() + cacheUsage.CacheReadInputTokens = int(v) + } + } + } + + usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens + + hasCache := cacheUsage.CacheCreationInputTokens > 0 || cacheUsage.CacheReadInputTokens > 0 + if !hasCache { + cacheUsage = nil + } + + return usage, cacheUsage, nil +} + +func (e *UsageExtractor) extractAnthropicStd(data []byte) (*llmproxy.Usage, *llmproxy.CacheUsage, error) { + var resp struct { + Usage *struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + CacheCreationInputTokens int `json:"cache_creation_input_tokens"` + CacheReadInputTokens int `json:"cache_read_input_tokens"` + } `json:"usage"` + } + + if err := json.Unmarshal(data, &resp); err != nil { + return nil, nil, err + } + + if resp.Usage == nil { + return &llmproxy.Usage{}, nil, nil + } + + usage := &llmproxy.Usage{ + PromptTokens: resp.Usage.InputTokens, + CompletionTokens: resp.Usage.OutputTokens, + TotalTokens: resp.Usage.InputTokens + resp.Usage.OutputTokens, + } + + var cacheUsage *llmproxy.CacheUsage + if resp.Usage.CacheCreationInputTokens > 0 || resp.Usage.CacheReadInputTokens > 0 { + cacheUsage = &llmproxy.CacheUsage{ + CacheCreationInputTokens: resp.Usage.CacheCreationInputTokens, + CacheReadInputTokens: resp.Usage.CacheReadInputTokens, + } + } + + return usage, cacheUsage, nil +} diff --git a/internal/fastjson/extractor_test.go b/internal/fastjson/extractor_test.go new file mode 100644 index 0000000..5549d52 --- /dev/null +++ b/internal/fastjson/extractor_test.go @@ -0,0 +1,197 @@ +package fastjson + +import ( + "testing" + + "github.com/agentuity/llmproxy" +) + +func TestUsageExtractor_ExtractOpenAI(t 
*testing.T) { + tests := []struct { + name string + input string + expectedPrompt int + expectedCompletion int + expectedCached int + }{ + { + name: "basic usage", + input: `{"id":"test","usage":{"prompt_tokens":100,"completion_tokens":50,"total_tokens":150}}`, + expectedPrompt: 100, + expectedCompletion: 50, + expectedCached: 0, + }, + { + name: "with cache", + input: `{"id":"test","usage":{"prompt_tokens":100,"completion_tokens":50,"total_tokens":150,"prompt_tokens_details":{"cached_tokens":80}}}`, + expectedPrompt: 100, + expectedCompletion: 50, + expectedCached: 80, + }, + { + name: "no usage", + input: `{"id":"test"}`, + expectedPrompt: 0, + expectedCompletion: 0, + expectedCached: 0, + }, + } + + extractor := NewUsageExtractor() + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + usage, cacheUsage, err := extractor.ExtractOpenAI([]byte(tt.input)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if usage.PromptTokens != tt.expectedPrompt { + t.Errorf("expected prompt tokens %d, got %d", tt.expectedPrompt, usage.PromptTokens) + } + if usage.CompletionTokens != tt.expectedCompletion { + t.Errorf("expected completion tokens %d, got %d", tt.expectedCompletion, usage.CompletionTokens) + } + + if tt.expectedCached > 0 { + if cacheUsage == nil { + t.Error("expected cache usage, got nil") + } else if cacheUsage.CachedTokens != tt.expectedCached { + t.Errorf("expected cached tokens %d, got %d", tt.expectedCached, cacheUsage.CachedTokens) + } + } + }) + } +} + +func TestUsageExtractor_ExtractAnthropic(t *testing.T) { + tests := []struct { + name string + input string + expectedPrompt int + expectedCompletion int + expectedCacheRead int + }{ + { + name: "basic usage", + input: `{"id":"test","usage":{"input_tokens":100,"output_tokens":50}}`, + expectedPrompt: 100, + expectedCompletion: 50, + expectedCacheRead: 0, + }, + { + name: "with cache", + input: `{"id":"test","usage":{"input_tokens":50,"output_tokens":100,"cache_read_input_tokens":2000,"cache_creation_input_tokens":500}}`, + expectedPrompt: 50, + expectedCompletion: 100, + expectedCacheRead: 2000, + }, + { + name: "no usage", + input: `{"id":"test"}`, + expectedPrompt: 0, + expectedCompletion: 0, + expectedCacheRead: 0, + }, + } + + extractor := NewUsageExtractor() + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + usage, cacheUsage, err := extractor.ExtractAnthropic([]byte(tt.input)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if usage.PromptTokens != tt.expectedPrompt { + t.Errorf("expected prompt tokens %d, got %d", tt.expectedPrompt, usage.PromptTokens) + } + if usage.CompletionTokens != tt.expectedCompletion { + t.Errorf("expected completion tokens %d, got %d", tt.expectedCompletion, usage.CompletionTokens) + } + + if tt.expectedCacheRead > 0 { + if cacheUsage == nil { + t.Error("expected cache usage, got nil") + } else if cacheUsage.CacheReadInputTokens != tt.expectedCacheRead { + t.Errorf("expected cache read tokens %d, got %d", tt.expectedCacheRead, cacheUsage.CacheReadInputTokens) + } + } + }) + } +} + +func BenchmarkUsageExtractor_ExtractOpenAI_Std(b *testing.B) { + extractor := NewUsageExtractor() + data := []byte(`{"id":"test","usage":{"prompt_tokens":100,"completion_tokens":50,"total_tokens":150,"prompt_tokens_details":{"cached_tokens":80}}}`) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _, _ = extractor.ExtractOpenAI(data) + } +} + +func BenchmarkUsageExtractor_ExtractAnthropic_Std(b *testing.B) { + extractor := NewUsageExtractor() + data := 
[]byte(`{"id":"test","usage":{"input_tokens":100,"output_tokens":50,"cache_read_input_tokens":2000}}`) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _, _ = extractor.ExtractAnthropic(data) + } +} + +func TestBillingCalculator(t *testing.T) { + lookup := func(provider, model string) (llmproxy.CostInfo, bool) { + if model == "gpt-4" { + return llmproxy.CostInfo{Input: 30, Output: 60}, true + } + return llmproxy.CostInfo{}, false + } + + var results []llmproxy.BillingResult + onResult := func(r llmproxy.BillingResult) { + results = append(results, r) + } + + calculator := llmproxy.NewBillingCalculator(lookup, onResult) + + meta := llmproxy.BodyMetadata{ + Model: "gpt-4", + Custom: map[string]any{"provider": "openai"}, + } + + respMeta := &llmproxy.ResponseMetadata{ + Usage: llmproxy.Usage{ + PromptTokens: 100, + CompletionTokens: 50, + }, + } + + result := calculator.Calculate(meta, respMeta) + + if result == nil { + t.Fatal("expected result, got nil") + } + + if result.PromptTokens != 100 { + t.Errorf("expected prompt tokens 100, got %d", result.PromptTokens) + } + if result.CompletionTokens != 50 { + t.Errorf("expected completion tokens 50, got %d", result.CompletionTokens) + } + + expectedCost := (30.0 * 100 / 1_000_000) + (60.0 * 50 / 1_000_000) + if result.TotalCost != expectedCost { + t.Errorf("expected cost %.6f, got %.6f", expectedCost, result.TotalCost) + } + + if len(results) != 1 { + t.Errorf("expected 1 result callback, got %d", len(results)) + } + + if _, ok := respMeta.Custom["billing_result"]; !ok { + t.Error("expected billing_result in custom map") + } +} diff --git a/providers/anthropic/provider.go b/providers/anthropic/provider.go index 84b0b0e..f6e85b5 100644 --- a/providers/anthropic/provider.go +++ b/providers/anthropic/provider.go @@ -4,17 +4,10 @@ import ( "github.com/agentuity/llmproxy" ) -// Provider is an Anthropic provider implementation. type Provider struct { *llmproxy.BaseProvider } -// New creates a new Anthropic provider with the given API key. -// The provider is configured to use Anthropic's API endpoint (https://api.anthropic.com). -// -// Example: -// -// provider, _ := anthropic.New("sk-ant-your-api-key") func New(apiKey string) (*Provider, error) { resolver, err := NewResolver("https://api.anthropic.com") if err != nil { @@ -25,17 +18,12 @@ func New(apiKey string) (*Provider, error) { BaseProvider: llmproxy.NewBaseProvider("anthropic", llmproxy.WithBodyParser(&Parser{}), llmproxy.WithRequestEnricher(NewEnricher(apiKey)), - llmproxy.WithResponseExtractor(NewExtractor()), + llmproxy.WithResponseExtractor(NewStreamingExtractor()), llmproxy.WithURLResolver(resolver), ), }, nil } -// NewWithVersion creates a new Anthropic provider with a specific API version. 
-// -// Example: -// -// provider, _ := anthropic.NewWithVersion("sk-ant-your-api-key", "2024-01-01") func NewWithVersion(apiKey, version string) (*Provider, error) { resolver, err := NewResolver("https://api.anthropic.com") if err != nil { @@ -46,7 +34,7 @@ func NewWithVersion(apiKey, version string) (*Provider, error) { BaseProvider: llmproxy.NewBaseProvider("anthropic", llmproxy.WithBodyParser(&Parser{}), llmproxy.WithRequestEnricher(NewEnricherWithVersion(apiKey, version)), - llmproxy.WithResponseExtractor(NewExtractor()), + llmproxy.WithResponseExtractor(NewStreamingExtractor()), llmproxy.WithURLResolver(resolver), ), }, nil diff --git a/providers/anthropic/streaming_extractor.go b/providers/anthropic/streaming_extractor.go new file mode 100644 index 0000000..724abc0 --- /dev/null +++ b/providers/anthropic/streaming_extractor.go @@ -0,0 +1,217 @@ +package anthropic + +import ( + "bufio" + "bytes" + "context" + "errors" + "io" + "net/http" + + "github.com/agentuity/llmproxy" +) + +type StreamingExtractor struct { + *Extractor +} + +func NewStreamingExtractor() *StreamingExtractor { + return &StreamingExtractor{ + Extractor: NewExtractor(), + } +} + +func (e *StreamingExtractor) IsStreamingResponse(resp *http.Response) bool { + return llmproxy.IsSSEStream(resp.Header.Get("Content-Type")) +} + +func (e *StreamingExtractor) ExtractStreamingWithController(resp *http.Response, w http.ResponseWriter, rc *http.ResponseController) (llmproxy.ResponseMetadata, error) { + if !e.IsStreamingResponse(resp) { + return e.extractNonStreamingWithController(resp, w, rc) + } + return e.extractStreamingWithController(resp, w, rc) +} + +func (e *StreamingExtractor) extractNonStreamingWithController(resp *http.Response, w http.ResponseWriter, rc *http.ResponseController) (llmproxy.ResponseMetadata, error) { + var buf bytes.Buffer + tee := io.TeeReader(resp.Body, &buf) + + meta, _, err := e.Extractor.Extract(&http.Response{ + StatusCode: resp.StatusCode, + Header: resp.Header, + Body: io.NopCloser(tee), + }) + if err != nil { + return meta, err + } + + readBuf := make([]byte, 1024*512) + for { + n, err := buf.Read(readBuf) + if err != nil { + if err == io.EOF { + if n > 0 { + if _, writeErr := w.Write(readBuf[:n]); writeErr != nil { + return meta, writeErr + } + } + break + } + if errors.Is(err, context.Canceled) { + break + } + return meta, err + } + if n == 0 { + break + } + if _, writeErr := w.Write(readBuf[:n]); writeErr != nil { + return meta, writeErr + } + if flushErr := rc.Flush(); flushErr != nil { + return meta, flushErr + } + } + + return meta, nil +} + +func (e *StreamingExtractor) extractStreamingWithController(resp *http.Response, w http.ResponseWriter, rc *http.ResponseController) (llmproxy.ResponseMetadata, error) { + meta := llmproxy.ResponseMetadata{ + Choices: make([]llmproxy.Choice, 0), + Custom: make(map[string]any), + } + + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + w.Header().Set("X-Accel-Buffering", "no") + + scanner := bufio.NewScanner(resp.Body) + scanner.Buffer(make([]byte, 64*1024), 1024*1024) + + var accumulatedUsage *llmproxy.StreamingUsage + var messageStart *llmproxy.AnthropicStreamMessage + + for scanner.Scan() { + line := scanner.Bytes() + + if len(line) == 0 { + continue + } + + if bytes.HasPrefix(line, []byte("data: ")) { + data := bytes.TrimPrefix(line, []byte("data: ")) + data = bytes.TrimSpace(data) + + event, err := llmproxy.ParseAnthropicSSEEvent(data) + if err != nil 
{ + continue + } + + if event == nil { + continue + } + + switch event.Type { + case "message_start": + if event.Message != nil { + messageStart = event.Message + meta.ID = event.Message.ID + meta.Model = event.Message.Model + if event.Message.Usage != nil { + usage := &llmproxy.StreamingUsage{ + PromptTokens: event.Message.Usage.InputTokens, + } + if event.Message.Usage.CacheCreationInputTokens > 0 || event.Message.Usage.CacheReadInputTokens > 0 { + usage.CacheUsage = &llmproxy.CacheUsage{ + CacheCreationInputTokens: event.Message.Usage.CacheCreationInputTokens, + CacheReadInputTokens: event.Message.Usage.CacheReadInputTokens, + } + } + accumulatedUsage = usage + } + } + case "content_block_start": + if event.ContentBlock != nil && event.Index == 0 { + meta.Choices = append(meta.Choices, llmproxy.Choice{ + Index: 0, + Message: &llmproxy.Message{ + Role: "assistant", + }, + }) + } + case "content_block_delta": + if event.Delta != nil && event.Delta.Type == "text_delta" { + if len(meta.Choices) > 0 { + if meta.Choices[0].Message == nil { + meta.Choices[0].Message = &llmproxy.Message{Role: "assistant"} + } + } + } + case "message_delta": + if event.Usage != nil { + if accumulatedUsage == nil { + accumulatedUsage = &llmproxy.StreamingUsage{} + } + accumulatedUsage.CompletionTokens = event.Usage.OutputTokens + if event.Usage.CacheReadInputTokens > 0 || event.Usage.CacheCreationInputTokens > 0 { + if accumulatedUsage.CacheUsage == nil { + accumulatedUsage.CacheUsage = &llmproxy.CacheUsage{} + } + accumulatedUsage.CacheUsage.CacheReadInputTokens = event.Usage.CacheReadInputTokens + accumulatedUsage.CacheUsage.CacheCreationInputTokens = event.Usage.CacheCreationInputTokens + } + } + if event.Delta != nil && event.Delta.StopReason != "" { + if len(meta.Choices) > 0 { + meta.Choices[0].FinishReason = event.Delta.StopReason + } + } + case "message_stop": + } + + if _, err := w.Write(line); err != nil { + return meta, err + } + if _, err := w.Write([]byte("\n\n")); err != nil { + return meta, err + } + _ = rc.Flush() + } else if bytes.HasPrefix(line, []byte("event: ")) { + if _, err := w.Write(line); err != nil { + return meta, err + } + if _, err := w.Write([]byte("\n")); err != nil { + return meta, err + } + _ = rc.Flush() + } + } + + if err := scanner.Err(); err != nil { + return meta, err + } + + if accumulatedUsage != nil { + meta.Usage = llmproxy.Usage{ + PromptTokens: accumulatedUsage.PromptTokens, + CompletionTokens: accumulatedUsage.CompletionTokens, + TotalTokens: accumulatedUsage.PromptTokens + accumulatedUsage.CompletionTokens, + } + if accumulatedUsage.CacheUsage != nil { + meta.Custom["cache_usage"] = *accumulatedUsage.CacheUsage + } + } + + if messageStart != nil { + if meta.ID == "" { + meta.ID = messageStart.ID + } + if meta.Model == "" { + meta.Model = messageStart.Model + } + } + + return meta, nil +} diff --git a/providers/azure/provider.go b/providers/azure/provider.go index 3257c26..6e331f9 100644 --- a/providers/azure/provider.go +++ b/providers/azure/provider.go @@ -66,7 +66,7 @@ func New(resourceName, deploymentID, apiVersion string, opts ...Option) (*Provid BaseProvider: llmproxy.NewBaseProvider("azure", llmproxy.WithBodyParser(&openai_compatible.Parser{}), llmproxy.WithRequestEnricher(enricher), - llmproxy.WithResponseExtractor(&openai_compatible.Extractor{}), + llmproxy.WithResponseExtractor(openai_compatible.NewStreamingExtractor()), llmproxy.WithURLResolver(resolver), ), resourceName: resourceName, @@ -92,7 +92,7 @@ func NewWithDynamicDeployment(resourceName, apiVersion 
string, opts ...Option) ( BaseProvider: llmproxy.NewBaseProvider("azure", llmproxy.WithBodyParser(&openai_compatible.Parser{}), llmproxy.WithRequestEnricher(enricher), - llmproxy.WithResponseExtractor(&openai_compatible.Extractor{}), + llmproxy.WithResponseExtractor(openai_compatible.NewStreamingExtractor()), llmproxy.WithURLResolver(resolver), ), resourceName: resourceName, diff --git a/providers/bedrock/provider.go b/providers/bedrock/provider.go index 561dedf..86b95b3 100644 --- a/providers/bedrock/provider.go +++ b/providers/bedrock/provider.go @@ -4,39 +4,21 @@ import ( "github.com/agentuity/llmproxy" ) -// Provider is an AWS Bedrock provider implementation. type Provider struct { *llmproxy.BaseProvider } -// New creates a new Bedrock provider with AWS credentials. -// Uses the Converse API which provides a unified format across models. -// -// Parameters: -// - region: AWS region (e.g., "us-east-1", "us-west-2", "eu-west-1") -// - accessKeyID: AWS Access Key ID -// - secretAccessKey: AWS Secret Access Key -// - sessionToken: AWS Session Token (optional, pass "" for long-term credentials) -// -// Example: -// -// // Long-term credentials -// provider, _ := bedrock.New("us-east-1", "AKIAIOSFODNN7EXAMPLE", "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", "") -// -// // Temporary credentials (from AssumeRole, etc.) -// provider, _ := bedrock.New("us-east-1", "AKIAIOSFODNN7EXAMPLE", "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", "AQoDYXdzEJr...") func New(region, accessKeyID, secretAccessKey, sessionToken string) (*Provider, error) { return &Provider{ BaseProvider: llmproxy.NewBaseProvider("bedrock", llmproxy.WithBodyParser(&Parser{}), llmproxy.WithRequestEnricher(NewEnricher(region, accessKeyID, secretAccessKey, sessionToken)), - llmproxy.WithResponseExtractor(NewExtractor()), + llmproxy.WithResponseExtractor(NewStreamingExtractor()), llmproxy.WithURLResolver(NewResolver(region)), ), }, nil } -// NewWithConfig creates a Bedrock provider with full configuration. 
func NewWithConfig(region, accessKeyID, secretAccessKey, sessionToken string, useConverseAPI bool) (*Provider, error) { var resolver llmproxy.URLResolver if useConverseAPI { @@ -49,7 +31,7 @@ func NewWithConfig(region, accessKeyID, secretAccessKey, sessionToken string, us BaseProvider: llmproxy.NewBaseProvider("bedrock", llmproxy.WithBodyParser(&Parser{}), llmproxy.WithRequestEnricher(NewEnricher(region, accessKeyID, secretAccessKey, sessionToken)), - llmproxy.WithResponseExtractor(NewExtractor()), + llmproxy.WithResponseExtractor(NewStreamingExtractor()), llmproxy.WithURLResolver(resolver), ), }, nil diff --git a/providers/bedrock/streaming_extractor.go b/providers/bedrock/streaming_extractor.go new file mode 100644 index 0000000..b4e74f8 --- /dev/null +++ b/providers/bedrock/streaming_extractor.go @@ -0,0 +1,73 @@ +package bedrock + +import ( + "bytes" + "context" + "errors" + "io" + "net/http" + + "github.com/agentuity/llmproxy" +) + +type StreamingExtractor struct { + *Extractor +} + +func NewStreamingExtractor() *StreamingExtractor { + return &StreamingExtractor{ + Extractor: NewExtractor(), + } +} + +func (e *StreamingExtractor) IsStreamingResponse(resp *http.Response) bool { + return llmproxy.IsSSEStream(resp.Header.Get("Content-Type")) +} + +func (e *StreamingExtractor) ExtractStreamingWithController(resp *http.Response, w http.ResponseWriter, rc *http.ResponseController) (llmproxy.ResponseMetadata, error) { + return e.extractNonStreamingWithController(resp, w, rc) +} + +func (e *StreamingExtractor) extractNonStreamingWithController(resp *http.Response, w http.ResponseWriter, rc *http.ResponseController) (llmproxy.ResponseMetadata, error) { + var buf bytes.Buffer + tee := io.TeeReader(resp.Body, &buf) + + meta, _, err := e.Extractor.Extract(&http.Response{ + StatusCode: resp.StatusCode, + Header: resp.Header, + Body: io.NopCloser(tee), + }) + if err != nil { + return meta, err + } + + readBuf := make([]byte, 1024*512) + for { + n, err := buf.Read(readBuf) + if err != nil { + if err == io.EOF { + if n > 0 { + if _, writeErr := w.Write(readBuf[:n]); writeErr != nil { + return meta, writeErr + } + } + break + } + if errors.Is(err, context.Canceled) { + break + } + return meta, err + } + if n == 0 { + break + } + if _, writeErr := w.Write(readBuf[:n]); writeErr != nil { + return meta, writeErr + } + if flushErr := rc.Flush(); flushErr != nil { + return meta, flushErr + } + } + + return meta, nil +} diff --git a/providers/googleai/provider.go b/providers/googleai/provider.go index 0127605..88ac9df 100644 --- a/providers/googleai/provider.go +++ b/providers/googleai/provider.go @@ -25,7 +25,7 @@ func New(apiKey string) (*Provider, error) { BaseProvider: llmproxy.NewBaseProvider("googleai", llmproxy.WithBodyParser(&Parser{}), llmproxy.WithRequestEnricher(NewEnricher(apiKey)), - llmproxy.WithResponseExtractor(NewExtractor()), + llmproxy.WithResponseExtractor(NewStreamingExtractor()), llmproxy.WithURLResolver(resolver), ), }, nil diff --git a/providers/googleai/streaming_extractor.go b/providers/googleai/streaming_extractor.go new file mode 100644 index 0000000..de6201e --- /dev/null +++ b/providers/googleai/streaming_extractor.go @@ -0,0 +1,73 @@ +package googleai + +import ( + "bytes" + "context" + "errors" + "io" + "net/http" + + "github.com/agentuity/llmproxy" +) + +type StreamingExtractor struct { + *Extractor +} + +func NewStreamingExtractor() *StreamingExtractor { + return &StreamingExtractor{ + Extractor: NewExtractor(), + } +} + +func (e *StreamingExtractor) 
IsStreamingResponse(resp *http.Response) bool { + return llmproxy.IsSSEStream(resp.Header.Get("Content-Type")) +} + +func (e *StreamingExtractor) ExtractStreamingWithController(resp *http.Response, w http.ResponseWriter, rc *http.ResponseController) (llmproxy.ResponseMetadata, error) { + return e.extractNonStreamingWithController(resp, w, rc) +} + +func (e *StreamingExtractor) extractNonStreamingWithController(resp *http.Response, w http.ResponseWriter, rc *http.ResponseController) (llmproxy.ResponseMetadata, error) { + var buf bytes.Buffer + tee := io.TeeReader(resp.Body, &buf) + + meta, _, err := e.Extractor.Extract(&http.Response{ + StatusCode: resp.StatusCode, + Header: resp.Header, + Body: io.NopCloser(tee), + }) + if err != nil { + return meta, err + } + + readBuf := make([]byte, 1024*512) + for { + n, err := buf.Read(readBuf) + if err != nil { + if err == io.EOF { + if n > 0 { + if _, writeErr := w.Write(readBuf[:n]); writeErr != nil { + return meta, writeErr + } + } + break + } + if errors.Is(err, context.Canceled) { + break + } + return meta, err + } + if n == 0 { + break + } + if _, writeErr := w.Write(readBuf[:n]); writeErr != nil { + return meta, writeErr + } + if flushErr := rc.Flush(); flushErr != nil { + return meta, flushErr + } + } + + return meta, nil +} diff --git a/providers/openai_compatible/multiapi.go b/providers/openai_compatible/multiapi.go index 8b65cb0..4b4e013 100644 --- a/providers/openai_compatible/multiapi.go +++ b/providers/openai_compatible/multiapi.go @@ -2,7 +2,9 @@ package openai_compatible import ( "bytes" + "context" "encoding/json" + "errors" "io" "net/http" @@ -56,8 +58,6 @@ func (e *MultiAPIExtractor) Extract(resp *http.Response) (llmproxy.ResponseMetad } resp.Body.Close() - // Detect response type by inspecting response-specific fields - // Responses API has "output" and "status", Chat Completions has "choices" var raw map[string]any isResponsesAPI := false if err := json.Unmarshal(body, &raw); err == nil { @@ -68,7 +68,6 @@ func (e *MultiAPIExtractor) Extract(resp *http.Response) (llmproxy.ResponseMetad } } - // Restore body for downstream extractors resp.Body = io.NopCloser(bytes.NewReader(body)) if isResponsesAPI { @@ -76,3 +75,70 @@ func (e *MultiAPIExtractor) Extract(resp *http.Response) (llmproxy.ResponseMetad } return e.chatCompletionsExtractor.Extract(resp) } + +type StreamingMultiAPIExtractor struct { + *MultiAPIExtractor + chatCompletionsStreaming *StreamingExtractor +} + +func NewStreamingMultiAPIExtractor() *StreamingMultiAPIExtractor { + return &StreamingMultiAPIExtractor{ + MultiAPIExtractor: NewMultiAPIExtractor(), + chatCompletionsStreaming: NewStreamingExtractor(), + } +} + +func (e *StreamingMultiAPIExtractor) IsStreamingResponse(resp *http.Response) bool { + return llmproxy.IsSSEStream(resp.Header.Get("Content-Type")) +} + +func (e *StreamingMultiAPIExtractor) ExtractStreamingWithController(resp *http.Response, w http.ResponseWriter, rc *http.ResponseController) (llmproxy.ResponseMetadata, error) { + if !e.IsStreamingResponse(resp) { + return e.extractNonStreamingWithController(resp, w, rc) + } + return e.chatCompletionsStreaming.ExtractStreamingWithController(resp, w, rc) +} + +func (e *StreamingMultiAPIExtractor) extractNonStreamingWithController(resp *http.Response, w http.ResponseWriter, rc *http.ResponseController) (llmproxy.ResponseMetadata, error) { + var buf bytes.Buffer + tee := io.TeeReader(resp.Body, &buf) + + meta, _, err := e.MultiAPIExtractor.Extract(&http.Response{ + StatusCode: resp.StatusCode, + Header: 
resp.Header, + Body: io.NopCloser(tee), + }) + if err != nil { + return meta, err + } + + readBuf := make([]byte, 1024*512) + for { + n, err := buf.Read(readBuf) + if err != nil { + if err == io.EOF { + if n > 0 { + if _, writeErr := w.Write(readBuf[:n]); writeErr != nil { + return meta, writeErr + } + } + break + } + if errors.Is(err, context.Canceled) { + break + } + return meta, err + } + if n == 0 { + break + } + if _, writeErr := w.Write(readBuf[:n]); writeErr != nil { + return meta, writeErr + } + if flushErr := rc.Flush(); flushErr != nil { + return meta, flushErr + } + } + + return meta, nil +} diff --git a/providers/openai_compatible/provider.go b/providers/openai_compatible/provider.go index 6830f04..95810e5 100644 --- a/providers/openai_compatible/provider.go +++ b/providers/openai_compatible/provider.go @@ -30,7 +30,7 @@ func New(name, apiKey, baseURL string) (*Provider, error) { BaseProvider: llmproxy.NewBaseProvider(name, llmproxy.WithBodyParser(&Parser{}), llmproxy.WithRequestEnricher(NewEnricher(apiKey)), - llmproxy.WithResponseExtractor(NewExtractor()), + llmproxy.WithResponseExtractor(NewStreamingExtractor()), llmproxy.WithURLResolver(resolver), ), }, nil @@ -46,7 +46,7 @@ func NewMultiAPI(name, apiKey, baseURL string) (*Provider, error) { BaseProvider: llmproxy.NewBaseProvider(name, llmproxy.WithBodyParser(NewMultiAPIParser()), llmproxy.WithRequestEnricher(NewEnricher(apiKey)), - llmproxy.WithResponseExtractor(NewMultiAPIExtractor()), + llmproxy.WithResponseExtractor(NewStreamingMultiAPIExtractor()), llmproxy.WithURLResolver(resolver), ), }, nil diff --git a/providers/openai_compatible/streaming_extractor.go b/providers/openai_compatible/streaming_extractor.go new file mode 100644 index 0000000..6eb8b5f --- /dev/null +++ b/providers/openai_compatible/streaming_extractor.go @@ -0,0 +1,192 @@ +package openai_compatible + +import ( + "bufio" + "bytes" + "context" + "errors" + "io" + "net/http" + + "github.com/agentuity/llmproxy" +) + +type StreamingExtractor struct { + *Extractor +} + +func NewStreamingExtractor() *StreamingExtractor { + return &StreamingExtractor{ + Extractor: NewExtractor(), + } +} + +func (e *StreamingExtractor) IsStreamingResponse(resp *http.Response) bool { + return llmproxy.IsSSEStream(resp.Header.Get("Content-Type")) +} + +func (e *StreamingExtractor) ExtractStreamingWithController(resp *http.Response, w http.ResponseWriter, rc *http.ResponseController) (llmproxy.ResponseMetadata, error) { + if !e.IsStreamingResponse(resp) { + return e.extractNonStreamingWithController(resp, w, rc) + } + + return e.extractStreamingWithController(resp, w, rc) +} + +func (e *StreamingExtractor) extractNonStreamingWithController(resp *http.Response, w http.ResponseWriter, rc *http.ResponseController) (llmproxy.ResponseMetadata, error) { + var buf bytes.Buffer + tee := io.TeeReader(resp.Body, &buf) + + meta, _, err := e.Extractor.Extract(&http.Response{ + StatusCode: resp.StatusCode, + Header: resp.Header, + Body: io.NopCloser(tee), + }) + if err != nil { + return meta, err + } + + readBuf := make([]byte, 1024*512) + for { + n, err := buf.Read(readBuf) + if err != nil { + if err == io.EOF { + if n > 0 { + if _, writeErr := w.Write(readBuf[:n]); writeErr != nil { + return meta, writeErr + } + } + break + } + if errors.Is(err, context.Canceled) { + break + } + return meta, err + } + if n == 0 { + break + } + if _, writeErr := w.Write(readBuf[:n]); writeErr != nil { + return meta, writeErr + } + if flushErr := rc.Flush(); flushErr != nil { + return meta, flushErr + } + } + 
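+	// The buffered body has now been replayed to the client in full, with a
+	// flush after each chunk, so the metadata extracted above can be
+	// returned for billing.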
+ return meta, nil +} + +func (e *StreamingExtractor) extractStreamingWithController(resp *http.Response, w http.ResponseWriter, rc *http.ResponseController) (llmproxy.ResponseMetadata, error) { + meta := llmproxy.ResponseMetadata{ + Choices: make([]llmproxy.Choice, 0), + Custom: make(map[string]any), + } + + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + w.Header().Set("X-Accel-Buffering", "no") + + scanner := bufio.NewScanner(resp.Body) + scanner.Buffer(make([]byte, 64*1024), 1024*1024) + + var lastChunk *llmproxy.OpenAIStreamChunk + var accumulatedUsage *llmproxy.StreamingUsage + + for scanner.Scan() { + line := scanner.Bytes() + + if len(line) == 0 { + continue + } + + if bytes.HasPrefix(line, []byte("data: ")) { + data := bytes.TrimPrefix(line, []byte("data: ")) + data = bytes.TrimSpace(data) + + if bytes.Equal(data, []byte("[DONE]")) { + if _, err := w.Write([]byte("data: [DONE]\n\n")); err != nil { + return meta, err + } + _ = rc.Flush() + break + } + + chunk, err := llmproxy.ParseOpenAISSEEvent(data) + if err != nil { + continue + } + + if chunk == nil { + continue + } + + lastChunk = chunk + + if chunk.ID != "" { + meta.ID = chunk.ID + } + if chunk.Model != "" { + meta.Model = chunk.Model + } + if chunk.Object != "" { + meta.Object = chunk.Object + } + + if chunk.Usage != nil { + usage := llmproxy.ExtractUsageFromOpenAIChunk(chunk) + if usage != nil { + accumulatedUsage = usage + } + } + + if _, err := w.Write(line); err != nil { + return meta, err + } + if _, err := w.Write([]byte("\n\n")); err != nil { + return meta, err + } + _ = rc.Flush() + } else { + if _, err := w.Write(line); err != nil { + return meta, err + } + if _, err := w.Write([]byte("\n")); err != nil { + return meta, err + } + _ = rc.Flush() + } + } + + if err := scanner.Err(); err != nil { + return meta, err + } + + if accumulatedUsage != nil { + meta.Usage = llmproxy.Usage{ + PromptTokens: accumulatedUsage.PromptTokens, + CompletionTokens: accumulatedUsage.CompletionTokens, + TotalTokens: accumulatedUsage.TotalTokens, + } + if accumulatedUsage.CacheUsage != nil { + meta.Custom["cache_usage"] = *accumulatedUsage.CacheUsage + } + } else if lastChunk != nil { + for _, choice := range lastChunk.Choices { + c := llmproxy.Choice{ + Index: choice.Index, + FinishReason: choice.FinishReason, + } + if choice.Delta != nil { + c.Delta = &llmproxy.Message{ + Role: choice.Delta.Role, + Content: choice.Delta.Content, + } + } + meta.Choices = append(meta.Choices, c) + } + } + + return meta, nil +} diff --git a/providers/openai_compatible/streaming_extractor_test.go b/providers/openai_compatible/streaming_extractor_test.go new file mode 100644 index 0000000..c25dd4c --- /dev/null +++ b/providers/openai_compatible/streaming_extractor_test.go @@ -0,0 +1,147 @@ +package openai_compatible + +import ( + "bytes" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/agentuity/llmproxy" +) + +func TestStreamingExtractor_ExtractStreaming(t *testing.T) { + streamData := `data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4","choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}]} + +data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4","choices":[{"index":0,"delta":{"content":" world"},"finish_reason":null}]} + +data: 
{"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4","choices":[{"index":0,"delta":{},"finish_reason":"stop"}],"usage":{"prompt_tokens":10,"completion_tokens":5,"total_tokens":15}} + +data: [DONE] + +` + + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + Body: io.NopCloser(strings.NewReader(streamData)), + } + + recorder := httptest.NewRecorder() + rc := http.NewResponseController(recorder) + + extractor := NewStreamingExtractor() + + meta, err := extractor.ExtractStreamingWithController(resp, recorder, rc) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if meta.ID != "chatcmpl-123" { + t.Errorf("expected ID 'chatcmpl-123', got %q", meta.ID) + } + if meta.Model != "gpt-4" { + t.Errorf("expected model 'gpt-4', got %q", meta.Model) + } + if meta.Usage.PromptTokens != 10 { + t.Errorf("expected prompt tokens 10, got %d", meta.Usage.PromptTokens) + } + if meta.Usage.CompletionTokens != 5 { + t.Errorf("expected completion tokens 5, got %d", meta.Usage.CompletionTokens) + } + + output := recorder.Body.String() + if !strings.Contains(output, "data: ") { + t.Error("expected SSE data format in output") + } + if !strings.Contains(output, "[DONE]") { + t.Error("expected [DONE] in output") + } +} + +func TestStreamingExtractor_IsStreamingResponse(t *testing.T) { + extractor := NewStreamingExtractor() + + tests := []struct { + contentType string + expected bool + }{ + {"text/event-stream", true}, + {"text/event-stream; charset=utf-8", true}, + {"application/json", false}, + {"text/plain", false}, + } + + for _, tt := range tests { + t.Run(tt.contentType, func(t *testing.T) { + resp := &http.Response{ + Header: http.Header{"Content-Type": []string{tt.contentType}}, + } + result := extractor.IsStreamingResponse(resp) + if result != tt.expected { + t.Errorf("expected %v, got %v", tt.expected, result) + } + }) + } +} + +func TestStreamingExtractor_NonStreamingFallback(t *testing.T) { + extractor := NewStreamingExtractor() + + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(bytes.NewReader([]byte(`{"id":"test","model":"gpt-4","usage":{"prompt_tokens":100,"completion_tokens":50,"total_tokens":150}}`))), + } + + recorder := httptest.NewRecorder() + rc := http.NewResponseController(recorder) + + meta, err := extractor.ExtractStreamingWithController(resp, recorder, rc) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if meta.Usage.PromptTokens != 100 { + t.Errorf("expected prompt tokens 100, got %d", meta.Usage.PromptTokens) + } +} + +func TestStreamingExtractor_ExtractStreamingWithCache(t *testing.T) { + streamData := `data: {"id":"chatcmpl-123","model":"gpt-4","choices":[{"index":0,"delta":{"content":"test"}}]} + +data: {"id":"chatcmpl-123","model":"gpt-4","choices":[{"index":0,"delta":{},"finish_reason":"stop"}],"usage":{"prompt_tokens":100,"completion_tokens":50,"total_tokens":150,"prompt_tokens_details":{"cached_tokens":80}}} + +data: [DONE] + +` + + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + Body: io.NopCloser(strings.NewReader(streamData)), + } + + recorder := httptest.NewRecorder() + rc := http.NewResponseController(recorder) + + extractor := NewStreamingExtractor() + + meta, err := extractor.ExtractStreamingWithController(resp, recorder, rc) + if err != nil { + t.Fatalf("unexpected 
error: %v", err) + } + + if meta.Usage.PromptTokens != 100 { + t.Errorf("expected prompt tokens 100, got %d", meta.Usage.PromptTokens) + } + + cacheUsage, ok := meta.Custom["cache_usage"].(llmproxy.CacheUsage) + if !ok { + t.Fatal("expected cache_usage in custom map") + } + if cacheUsage.CachedTokens != 80 { + t.Errorf("expected cached tokens 80, got %d", cacheUsage.CachedTokens) + } +} diff --git a/providers/perplexity/provider.go b/providers/perplexity/provider.go index 3e1e3b1..27d6cac 100644 --- a/providers/perplexity/provider.go +++ b/providers/perplexity/provider.go @@ -19,7 +19,7 @@ func New(apiKey string) (*Provider, error) { BaseProvider: llmproxy.NewBaseProvider("perplexity", llmproxy.WithBodyParser(&openai_compatible.Parser{}), llmproxy.WithRequestEnricher(openai_compatible.NewEnricher(apiKey)), - llmproxy.WithResponseExtractor(openai_compatible.NewExtractor()), + llmproxy.WithResponseExtractor(openai_compatible.NewStreamingExtractor()), llmproxy.WithURLResolver(resolver), ), }, nil diff --git a/streaming.go b/streaming.go new file mode 100644 index 0000000..1681308 --- /dev/null +++ b/streaming.go @@ -0,0 +1,291 @@ +package llmproxy + +import ( + "bufio" + "bytes" + "encoding/json" + "errors" + "io" + "strings" +) + +var ( + ErrStreamComplete = errors.New("stream complete") +) + +type SSEEvent struct { + ID []byte + Event []byte + Data []byte + Retry []byte +} + +type SSEParser struct { + scanner *bufio.Scanner +} + +func NewSSEParser(r io.Reader) *SSEParser { + return &SSEParser{ + scanner: bufio.NewScanner(r), + } +} + +func (p *SSEParser) Next() (*SSEEvent, error) { + var event SSEEvent + + for p.scanner.Scan() { + line := p.scanner.Bytes() + + if len(line) == 0 { + if len(event.Data) > 0 { + return &event, nil + } + continue + } + + if line[0] == ':' { + continue + } + + colon := bytes.IndexByte(line, ':') + if colon < 0 { + event.Data = append(event.Data, line...) + continue + } + + field := line[:colon] + value := line[colon+1:] + + if len(value) > 0 && value[0] == ' ' { + value = value[1:] + } + + switch string(field) { + case "id": + event.ID = append(event.ID[:0], value...) + case "event": + event.Event = append(event.Event[:0], value...) + case "data": + if len(event.Data) > 0 { + event.Data = append(event.Data, '\n') + } + event.Data = append(event.Data, value...) + case "retry": + event.Retry = append(event.Retry[:0], value...) 
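+		// Field names other than id, event, data, and retry are
+		// ignored, as the SSE specification requires.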
+ } + } + + if err := p.scanner.Err(); err != nil { + return nil, err + } + + if len(event.Data) > 0 { + return &event, nil + } + + return nil, io.EOF +} + +type StreamingUsage struct { + PromptTokens int + CompletionTokens int + TotalTokens int + CacheUsage *CacheUsage +} + +type OpenAIStreamChunk struct { + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + Choices []OpenAIStreamChoice `json:"choices"` + Usage *OpenAIStreamUsage `json:"usage,omitempty"` +} + +type OpenAIStreamChoice struct { + Index int `json:"index"` + Delta *OpenAIStreamDelta `json:"delta,omitempty"` + FinishReason string `json:"finish_reason,omitempty"` + Logprobs *OpenAIStreamLogprobs `json:"logprobs,omitempty"` +} + +type OpenAIStreamDelta struct { + Role string `json:"role,omitempty"` + Content string `json:"content,omitempty"` +} + +type OpenAIStreamLogprobs struct { + Content []OpenAIStreamLogprobContent `json:"content,omitempty"` +} + +type OpenAIStreamLogprobContent struct { + Token string `json:"token"` + Logprob float64 `json:"logprob"` +} + +type OpenAIStreamUsage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` + PromptTokensDetails *OpenAIStreamPromptDetails `json:"prompt_tokens_details,omitempty"` + CompletionTokensDetails *OpenAIStreamCompletionDetails `json:"completion_tokens_details,omitempty"` +} + +type OpenAIStreamPromptDetails struct { + CachedTokens int `json:"cached_tokens,omitempty"` + AudioTokens int `json:"audio_tokens,omitempty"` +} + +type OpenAIStreamCompletionDetails struct { + ReasoningTokens int `json:"reasoning_tokens,omitempty"` + AudioTokens int `json:"audio_tokens,omitempty"` + AcceptedPredictionTokens int `json:"accepted_prediction_tokens,omitempty"` + RejectedPredictionTokens int `json:"rejected_prediction_tokens,omitempty"` +} + +func ParseOpenAISSEEvent(data []byte) (*OpenAIStreamChunk, error) { + data = bytes.TrimSpace(data) + if len(data) == 0 { + return nil, nil + } + + if bytes.Equal(data, []byte("[DONE]")) { + return nil, ErrStreamComplete + } + + var chunk OpenAIStreamChunk + if err := json.Unmarshal(data, &chunk); err != nil { + return nil, err + } + + return &chunk, nil +} + +type AnthropicStreamEvent struct { + Type string `json:"type"` + Index int `json:"index,omitempty"` + Delta *AnthropicStreamDelta `json:"delta,omitempty"` + ContentBlock *AnthropicContentBlock `json:"content_block,omitempty"` + Usage *AnthropicStreamUsage `json:"usage,omitempty"` + Message *AnthropicStreamMessage `json:"message,omitempty"` +} + +type AnthropicStreamDelta struct { + Type string `json:"type,omitempty"` + Text string `json:"text,omitempty"` + StopReason string `json:"stop_reason,omitempty"` +} + +type AnthropicContentBlock struct { + Type string `json:"type"` + Text string `json:"text,omitempty"` +} + +type AnthropicStreamUsage struct { + InputTokens int `json:"input_tokens,omitempty"` + OutputTokens int `json:"output_tokens,omitempty"` + CacheCreationInputTokens int `json:"cache_creation_input_tokens,omitempty"` + CacheReadInputTokens int `json:"cache_read_input_tokens,omitempty"` +} + +type AnthropicStreamMessage struct { + ID string `json:"id,omitempty"` + Type string `json:"type,omitempty"` + Role string `json:"role,omitempty"` + Content []AnthropicContentBlock `json:"content,omitempty"` + Model string `json:"model,omitempty"` + StopReason string `json:"stop_reason,omitempty"` + Usage *AnthropicStreamUsage 
`json:"usage,omitempty"` +} + +func ParseAnthropicSSEEvent(data []byte) (*AnthropicStreamEvent, error) { + data = bytes.TrimSpace(data) + if len(data) == 0 { + return nil, nil + } + + var event AnthropicStreamEvent + if err := json.Unmarshal(data, &event); err != nil { + return nil, err + } + + return &event, nil +} + +func IsSSEStream(contentType string) bool { + return strings.Contains(strings.ToLower(contentType), "text/event-stream") +} + +func ExtractUsageFromOpenAIChunk(chunk *OpenAIStreamChunk) *StreamingUsage { + if chunk == nil || chunk.Usage == nil { + return nil + } + + usage := &StreamingUsage{ + PromptTokens: chunk.Usage.PromptTokens, + CompletionTokens: chunk.Usage.CompletionTokens, + TotalTokens: chunk.Usage.TotalTokens, + } + + if chunk.Usage.PromptTokensDetails != nil && chunk.Usage.PromptTokensDetails.CachedTokens > 0 { + usage.CacheUsage = &CacheUsage{ + CachedTokens: chunk.Usage.PromptTokensDetails.CachedTokens, + } + } + + return usage +} + +func ExtractUsageFromAnthropicEvent(event *AnthropicStreamEvent) *StreamingUsage { + if event == nil { + return nil + } + + var usage *AnthropicStreamUsage + + switch event.Type { + case "message_start": + if event.Message != nil && event.Message.Usage != nil { + usage = event.Message.Usage + } + case "message_delta": + if event.Usage != nil { + usage = event.Usage + } + case "message_stop": + return nil + default: + return nil + } + + if usage == nil { + return nil + } + + result := &StreamingUsage{ + PromptTokens: usage.InputTokens, + CompletionTokens: usage.OutputTokens, + } + + if usage.CacheCreationInputTokens > 0 || usage.CacheReadInputTokens > 0 { + result.CacheUsage = &CacheUsage{ + CacheCreationInputTokens: usage.CacheCreationInputTokens, + CacheReadInputTokens: usage.CacheReadInputTokens, + } + } + + return result +} + +func FormatSSEEvent(event string, data []byte) []byte { + var buf bytes.Buffer + if len(event) > 0 { + buf.WriteString("event: ") + buf.WriteString(event) + buf.WriteByte('\n') + } + buf.WriteString("data: ") + buf.Write(data) + buf.WriteString("\n\n") + return buf.Bytes() +} diff --git a/streaming_test.go b/streaming_test.go new file mode 100644 index 0000000..4b7b154 --- /dev/null +++ b/streaming_test.go @@ -0,0 +1,286 @@ +package llmproxy + +import ( + "bytes" + "io" + "testing" +) + +func TestSSEParser(t *testing.T) { + tests := []struct { + name string + input string + expected []*SSEEvent + }{ + { + name: "simple event", + input: "data: {\"test\":\"value\"}\n\n", + expected: []*SSEEvent{ + {Data: []byte(`{"test":"value"}`)}, + }, + }, + { + name: "event with type", + input: "event: message\ndata: {\"test\":\"value\"}\n\n", + expected: []*SSEEvent{ + {Event: []byte("message"), Data: []byte(`{"test":"value"}`)}, + }, + }, + { + name: "multiple events", + input: "data: first\n\ndata: second\n\n", + expected: []*SSEEvent{ + {Data: []byte("first")}, + {Data: []byte("second")}, + }, + }, + { + name: "multiline data", + input: "data: line1\ndata: line2\n\n", + expected: []*SSEEvent{ + {Data: []byte("line1\nline2")}, + }, + }, + { + name: "OpenAI streaming format", + input: "data: {\"id\":\"chatcmpl-123\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hello\"},\"finish_reason\":null}]}\n\ndata: [DONE]\n\n", + expected: []*SSEEvent{ + {Data: []byte(`{"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4","choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}]}`)}, + {Data: 
[]byte("[DONE]")}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parser := NewSSEParser(bytes.NewReader([]byte(tt.input))) + + var events []*SSEEvent + for { + event, err := parser.Next() + if err == io.EOF { + break + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + events = append(events, event) + } + + if len(events) != len(tt.expected) { + t.Fatalf("expected %d events, got %d", len(tt.expected), len(events)) + } + + for i, event := range events { + if !bytes.Equal(event.Data, tt.expected[i].Data) { + t.Errorf("event %d: expected data %q, got %q", i, tt.expected[i].Data, event.Data) + } + if !bytes.Equal(event.Event, tt.expected[i].Event) { + t.Errorf("event %d: expected event %q, got %q", i, tt.expected[i].Event, event.Event) + } + } + }) + } +} + +func TestParseOpenAISSEEvent(t *testing.T) { + tests := []struct { + name string + input []byte + expectError bool + expectDone bool + }{ + { + name: "valid chunk", + input: []byte(`{"id":"chatcmpl-123","object":"chat.completion.chunk","model":"gpt-4","choices":[{"index":0,"delta":{"content":"Hello"}}]}`), + expectError: false, + }, + { + name: "done marker", + input: []byte("[DONE]"), + expectDone: true, + expectError: false, + }, + { + name: "empty input", + input: []byte{}, + expectError: false, + }, + { + name: "invalid JSON", + input: []byte(`{invalid}`), + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + chunk, err := ParseOpenAISSEEvent(tt.input) + + if tt.expectDone { + if err != ErrStreamComplete { + t.Errorf("expected ErrStreamComplete, got %v", err) + } + return + } + + if tt.expectError { + if err == nil { + t.Error("expected error, got nil") + } + return + } + + if err != nil && !tt.expectError { + t.Errorf("unexpected error: %v", err) + } + + if chunk != nil && tt.input != nil && len(tt.input) > 0 { + _ = chunk.ID + } + }) + } +} + +func TestExtractUsageFromOpenAIChunk(t *testing.T) { + tests := []struct { + name string + chunk *OpenAIStreamChunk + expected *StreamingUsage + }{ + { + name: "nil chunk", + chunk: nil, + expected: nil, + }, + { + name: "chunk without usage", + chunk: &OpenAIStreamChunk{ID: "test"}, + expected: nil, + }, + { + name: "chunk with basic usage", + chunk: &OpenAIStreamChunk{ + Usage: &OpenAIStreamUsage{ + PromptTokens: 100, + CompletionTokens: 50, + TotalTokens: 150, + }, + }, + expected: &StreamingUsage{ + PromptTokens: 100, + CompletionTokens: 50, + TotalTokens: 150, + }, + }, + { + name: "chunk with cache usage", + chunk: &OpenAIStreamChunk{ + Usage: &OpenAIStreamUsage{ + PromptTokens: 100, + CompletionTokens: 50, + TotalTokens: 150, + PromptTokensDetails: &OpenAIStreamPromptDetails{ + CachedTokens: 80, + }, + }, + }, + expected: &StreamingUsage{ + PromptTokens: 100, + CompletionTokens: 50, + TotalTokens: 150, + CacheUsage: &CacheUsage{ + CachedTokens: 80, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ExtractUsageFromOpenAIChunk(tt.chunk) + + if tt.expected == nil { + if result != nil { + t.Errorf("expected nil, got %+v", result) + } + return + } + + if result == nil { + t.Fatal("expected result, got nil") + } + + if result.PromptTokens != tt.expected.PromptTokens { + t.Errorf("expected PromptTokens %d, got %d", tt.expected.PromptTokens, result.PromptTokens) + } + if result.CompletionTokens != tt.expected.CompletionTokens { + t.Errorf("expected CompletionTokens %d, got %d", tt.expected.CompletionTokens, result.CompletionTokens) + } + if 
result.TotalTokens != tt.expected.TotalTokens { + t.Errorf("expected TotalTokens %d, got %d", tt.expected.TotalTokens, result.TotalTokens) + } + + if tt.expected.CacheUsage != nil { + if result.CacheUsage == nil { + t.Error("expected CacheUsage, got nil") + } else if result.CacheUsage.CachedTokens != tt.expected.CacheUsage.CachedTokens { + t.Errorf("expected CachedTokens %d, got %d", tt.expected.CacheUsage.CachedTokens, result.CacheUsage.CachedTokens) + } + } + }) + } +} + +func TestIsSSEStream(t *testing.T) { + tests := []struct { + contentType string + expected bool + }{ + {"text/event-stream", true}, + {"text/event-stream; charset=utf-8", true}, + {"application/json", false}, + {"text/plain", false}, + {"", false}, + } + + for _, tt := range tests { + t.Run(tt.contentType, func(t *testing.T) { + result := IsSSEStream(tt.contentType) + if result != tt.expected { + t.Errorf("expected %v, got %v", tt.expected, result) + } + }) + } +} + +func TestFormatSSEEvent(t *testing.T) { + tests := []struct { + name string + event string + data []byte + expected []byte + }{ + { + name: "data only", + event: "", + data: []byte(`{"test":"value"}`), + expected: []byte("data: {\"test\":\"value\"}\n\n"), + }, + { + name: "event and data", + event: "message", + data: []byte(`{"test":"value"}`), + expected: []byte("event: message\ndata: {\"test\":\"value\"}\n\n"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := FormatSSEEvent(tt.event, tt.data) + if !bytes.Equal(result, tt.expected) { + t.Errorf("expected %q, got %q", tt.expected, result) + } + }) + } +} From bbefc9c83e46bf6d4bec6bece158600194abb688 Mon Sep 17 00:00:00 2001 From: Jeff Haynie Date: Mon, 13 Apr 2026 23:20:28 -0500 Subject: [PATCH 2/7] docs: add streaming support documentation --- DESIGN.md | 145 +++++++++++++++++++++++++++++++++++++++++++++++++++++- README.md | 32 ++++++++++++ 2 files changed, 175 insertions(+), 2 deletions(-) diff --git a/DESIGN.md b/DESIGN.md index 3753d06..ebaf0e3 100644 --- a/DESIGN.md +++ b/DESIGN.md @@ -394,6 +394,138 @@ curl -X POST http://localhost:8080/v1/chat/completions \ --- +## Streaming + +The proxy fully supports SSE (Server-Sent Events) streaming with efficient token usage extraction for billing. + +### StreamingResponseExtractor + +Extends `ResponseExtractor` for streaming responses: + +``` +type StreamingResponseExtractor interface { + ResponseExtractor + ExtractStreamingWithController(resp, w, rc) -> (ResponseMetadata, error) + IsStreamingResponse(resp) -> bool +} +``` + +All built-in providers implement this interface for streaming support. 
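+
+A custom provider can satisfy the interface with the shared SSE helpers. A minimal sketch (illustrative only; the type name is made up, and the non-streaming `Extract` method required by `ResponseExtractor` is omitted):
+
+```go
+type CustomStreamingExtractor struct{}
+
+func (e *CustomStreamingExtractor) IsStreamingResponse(resp *http.Response) bool {
+	return llmproxy.IsSSEStream(resp.Header.Get("Content-Type"))
+}
+
+func (e *CustomStreamingExtractor) ExtractStreamingWithController(
+	resp *http.Response, w http.ResponseWriter, rc *http.ResponseController,
+) (llmproxy.ResponseMetadata, error) {
+	var meta llmproxy.ResponseMetadata
+	parser := llmproxy.NewSSEParser(resp.Body)
+	for {
+		event, err := parser.Next()
+		if err == io.EOF {
+			break
+		}
+		if err != nil {
+			return meta, err
+		}
+		// Forward each event to the client and flush immediately.
+		if _, err := w.Write(llmproxy.FormatSSEEvent(string(event.Event), event.Data)); err != nil {
+			return meta, err
+		}
+		if err := rc.Flush(); err != nil {
+			return meta, err
+		}
+	}
+	return meta, nil
+}
+```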
+
+### Streaming Flow
+
+```
++------------------+
+| Incoming Request |
+|   stream: true   |
++--------+---------+
+         |
+    Parse body, detect
+    stream: true flag
+         |
++--------v---------+
+|    AutoRouter    |
+| ForwardStreaming |
++--------+---------+
+         |
+    Auto-inject
+    stream_options:
+    {include_usage:true}
+         |
++--------v---------+
+|   HTTP Request   |
+|   to upstream    |
++--------+---------+
+         |
++--------v---------+
+| Upstream Response|
+| text/event-stream|
++--------+---------+
+         |
++--------v---------+
+|StreamingExtractor|
+| Parse SSE events |
+|  Extract usage   |
+| Flush each chunk |
++--------+---------+
+         |
++--------v---------+
+| BillingCalculator|
+|  Calculate cost  |
++--------+---------+
+         |
++--------v---------+
+|  HTTP Response   |
+|    to client     |
++------------------+
+```
+
+### Usage Extraction
+
+**OpenAI**: Usage is sent in the final chunk before `[DONE]`:
+
+```
+data: {"id":"...","usage":{"prompt_tokens":100,"completion_tokens":50,"total_tokens":150}}
+data: [DONE]
+```
+
+**Anthropic**: Usage is sent in `message_start` and `message_delta` events:
+
+```
+data: {"type":"message_start","message":{"usage":{"input_tokens":100}}}
+...
+data: {"type":"message_delta","usage":{"output_tokens":50}}
+data: {"type":"message_stop"}
+```
+
+### Auto stream_options Injection
+
+When a `BillingCalculator` is configured and the request has `stream: true`, the proxy automatically injects:
+
+```json
+{
+  "stream": true,
+  "stream_options": { "include_usage": true }
+}
+```
+
+This ensures OpenAI returns token usage in the streaming response for billing calculation.
+
+### Efficient Flushing
+
+Uses `http.ResponseController` so each event reaches the client as soon as it is written (schematic; `events` stands in for the parsed SSE event stream):
+
+```go
+rc := http.NewResponseController(w)
+
+for _, event := range events {
+	if _, err := w.Write(event); err != nil {
+		return err
+	}
+	rc.Flush() // immediate flush after each chunk
+}
+```
+
+Non-streaming responses also use chunked read/write/flush with a 512KB buffer for better performance.
+
+### Billing with Streaming
+
+```go
+adapter, _ := modelsdev.LoadFromURL()
+
+billingCallback := func(r llmproxy.BillingResult) {
+	log.Printf("Cost: $%.6f", r.TotalCost)
+}
+
+router := llmproxy.NewAutoRouter(
+	llmproxy.WithAutoRouterBillingCalculator(
+		llmproxy.NewBillingCalculator(adapter.GetCostLookup(), billingCallback),
+	),
+)
+```
+
+After the stream completes, the billing callback is invoked with the extracted token usage.
+
+---
+
 ## Providers
 
 Nine providers are included. Six share the OpenAI-compatible base; three have fully custom implementations.
@@ -884,12 +1016,16 @@ Matches the signature of `github.com/agentuity/go-common/logger` without requiri ``` llmproxy/ ├── apitype.go # API type detection and constants -├── autorouter.go # AutoRouter, provider/API auto-detection +├── autorouter.go # AutoRouter, provider/API auto-detection, streaming ├── billing.go # CostInfo, CostLookup, BillingResult, CalculateCost +├── billing_calculator.go # BillingCalculator for streaming/non-streaming ├── detection.go # Provider detection from model/header ├── enricher.go # RequestEnricher interface -├── extractor.go # ResponseExtractor interface +├── extractor.go # ResponseExtractor, StreamingResponseExtractor interface ├── interceptor.go # Interceptor, InterceptorChain, RoundTripFunc +├── internal/ +│ └── fastjson/ +│ └── extractor.go # Fast JSON parsing with simdjson-go ├── logger.go # Logger interface, LoggerFunc adapter ├── metadata.go # BodyMetadata, ResponseMetadata, Message, Usage, Choice ├── parser.go # BodyParser interface @@ -897,6 +1033,7 @@ llmproxy/ ├── proxy.go # Proxy struct, Forward method ├── registry.go # Registry interface, MapRegistry ├── resolver.go # URLResolver interface +├── streaming.go # SSE parser, streaming types, usage extraction ├── interceptors/ │ ├── addheader.go # AddHeaderInterceptor │ ├── billing.go # BillingInterceptor @@ -911,14 +1048,18 @@ llmproxy/ │ └── adapter.go # models.dev pricing adapter ├── providers/ │ ├── anthropic/ # Anthropic Messages API +│ │ └── streaming_extractor.go # Anthropic SSE streaming │ ├── azure/ # Azure OpenAI │ ├── bedrock/ # AWS Bedrock Converse API +│ │ └── streaming_extractor.go # Bedrock streaming │ ├── fireworks/ # Fireworks (OpenAI-compatible) │ ├── googleai/ # Google AI Gemini +│ │ └── streaming_extractor.go # Google AI streaming │ ├── groq/ # Groq (OpenAI-compatible) │ ├── openai/ # OpenAI (Chat Completions + Responses) │ ├── openai_compatible/ # Base for OpenAI-compatible providers │ │ ├── multiapi.go # Multi-API parser/extractor +│ │ ├── streaming_extractor.go # SSE streaming with usage extraction │ │ ├── responses_parser.go # Responses API parser │ │ └── responses_extractor.go # Responses API extractor │ └── xai/ # x.AI (OpenAI-compatible) diff --git a/README.md b/README.md index 1219276..3ee6a7b 100644 --- a/README.md +++ b/README.md @@ -108,6 +108,7 @@ curl -X POST http://localhost:8080/ \ - **9 Provider Implementations**: OpenAI, Anthropic, Groq, Fireworks, x.AI, Google AI, AWS Bedrock, Azure OpenAI, OpenAI-compatible base - **AutoRouter**: Single endpoint with automatic provider/API detection - **Responses API**: Full support for OpenAI's new Responses API +- **SSE Streaming**: Full streaming support with efficient token usage extraction - **8 Built-in Interceptors**: Logging, Metrics, Retry, Billing, Tracing (OTel), HeaderBan, AddHeader, PromptCaching - **Pricing Integration**: models.dev adapter with markup support - **Prompt Caching**: prompt caching support for Anthropic, OpenAI, xAI, Fireworks, and Bedrock @@ -153,6 +154,37 @@ curl -X POST http://localhost:8080/v1/chat/completions \ -d '{"model":"gpt-4","messages":[{"role":"user","content":"Hello"}]}' ``` +## Streaming + +SSE streaming is fully supported with automatic token usage extraction for billing: + +```bash +# Streaming with automatic usage extraction +curl -X POST http://localhost:8080/ \ + -H 'Content-Type: application/json' \ + -d '{"model":"gpt-4","stream":true,"messages":[{"role":"user","content":"Hello"}]}' +``` + +**Key Features:** + +- **Efficient flushing**: Uses `http.ResponseController` for 
immediate SSE delivery +- **Token extraction**: Extracts usage from streaming responses for billing +- **Auto stream_options**: Automatically injects `stream_options.include_usage` when billing is configured +- **Works with billing**: Billing is calculated after stream completes + +**Example with billing:** + +```go +adapter, _ := modelsdev.LoadFromURL() +billingCallback := func(r llmproxy.BillingResult) { + log.Printf("Cost: $%.6f (tokens: %d/%d)", r.TotalCost, r.PromptTokens, r.CompletionTokens) +} + +router := llmproxy.NewAutoRouter( + llmproxy.WithAutoRouterBillingCalculator(llmproxy.NewBillingCalculator(adapter.GetCostLookup(), billingCallback)), +) +``` + ## Providers | Provider | Auth | API Format | Notes | From 35b9213d5fa20523692d114f19441f0764a5cadb Mon Sep 17 00:00:00 2001 From: Jeff Haynie Date: Mon, 13 Apr 2026 23:30:05 -0500 Subject: [PATCH 3/7] fix: use exclusion list for stream_options injection, add Anthropic streaming tests Replace whitelist-based provider switch with exclusion list approach for stream_options injection. Providers with native streaming usage reporting (Anthropic, Bedrock, Google AI) are excluded; all others automatically get stream_options injected. This is more extensible for new providers. Also adds comprehensive Anthropic streaming tests: - TestParseAnthropicSSEEvent - TestExtractUsageFromAnthropicEvent - TestAnthropicSSEParser (full event stream) - TestAutoRouter_AnthropicStreamingNoStreamOptions --- autorouter.go | 12 ++- autorouter_test.go | 65 ++++++++++++ streaming_test.go | 249 +++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 325 insertions(+), 1 deletion(-) diff --git a/autorouter.go b/autorouter.go index bcc8f9c..ee16b0f 100644 --- a/autorouter.go +++ b/autorouter.go @@ -240,7 +240,9 @@ func (a *AutoRouter) ForwardStreaming(ctx context.Context, req *http.Request, w } if a.billingCalculator != nil { if stream, ok := raw["stream"].(bool); ok && stream { - raw["stream_options"] = map[string]any{"include_usage": true} + if !nativeStreamUsageProviders[providerName] { + raw["stream_options"] = map[string]any{"include_usage": true} + } } } var err error @@ -472,6 +474,14 @@ func (e *ProviderError) Error() string { return e.Message } +// nativeStreamUsageProviders are providers that include usage data +// natively in their streaming events without needing stream_options. 
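+// OpenAI-compatible providers are intentionally absent from this list:
+// they only report usage when the request sets stream_options.include_usage,
+// which ForwardStreaming injects for them.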
+var nativeStreamUsageProviders = map[string]bool{ + "anthropic": true, + "bedrock": true, + "googleai": true, +} + var knownProviderPrefixes = map[string]bool{ "openai": true, "anthropic": true, diff --git a/autorouter_test.go b/autorouter_test.go index e48d391..6417c31 100644 --- a/autorouter_test.go +++ b/autorouter_test.go @@ -666,3 +666,68 @@ func TestAutoRouter_NonStreamingNoStreamOptions(t *testing.T) { t.Error("stream_options should not be injected for non-streaming requests") } } + +func TestAutoRouter_AnthropicStreamingNoStreamOptions(t *testing.T) { + var receivedBody map[string]any + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + json.Unmarshal(body, &receivedBody) + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + w.Write([]byte("event: message_start\ndata: {\"type\":\"message_start\",\"message\":{\"id\":\"test\",\"usage\":{\"input_tokens\":100}}}\n\nevent: message_delta\ndata: {\"type\":\"message_delta\",\"usage\":{\"output_tokens\":50}}\n\n")) + })) + defer upstream.Close() + + provider := &mockStreamingProvider{ + mockProvider: &mockProvider{ + name: "anthropic", + parseFn: func(body io.ReadCloser) (BodyMetadata, []byte, error) { + data, _ := io.ReadAll(body) + return BodyMetadata{Model: "claude-3-opus", Stream: true}, data, nil + }, + enrichFn: func(req *http.Request, meta BodyMetadata, body []byte) error { return nil }, + resolveFn: func(meta BodyMetadata) (*url.URL, error) { + return url.Parse(upstream.URL) + }, + }, + streamingExtractor: &mockStreamingExtractor{ + isStreaming: true, + extractStreamingFn: func(resp *http.Response, w http.ResponseWriter, rc *http.ResponseController) (ResponseMetadata, error) { + io.Copy(w, resp.Body) + rc.Flush() + return ResponseMetadata{ID: "test", Usage: Usage{PromptTokens: 100, CompletionTokens: 50}}, nil + }, + }, + } + provider.mockProvider.extractFn = func(resp *http.Response) (ResponseMetadata, []byte, error) { + body, _ := io.ReadAll(resp.Body) + return ResponseMetadata{ID: "test"}, body, nil + } + + billing := NewBillingCalculator( + func(provider, model string) (CostInfo, bool) { + return CostInfo{Input: 3, Output: 15}, true + }, + nil, + ) + + router := NewAutoRouter( + WithAutoRouterDetector(ProviderDetectorFunc(func(hint ProviderHint) string { return "anthropic" })), + WithAutoRouterBillingCalculator(billing), + ) + router.RegisterProvider(provider) + + req := httptest.NewRequest("POST", "/", bytes.NewReader([]byte(`{"model":"claude-3-opus","stream":true,"max_tokens":1024,"messages":[{"role":"user","content":"Hello"}]}`))) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("StatusCode = %d, want 200", w.Code) + } + + if _, ok := receivedBody["stream_options"]; ok { + t.Error("stream_options should NOT be injected for Anthropic (always sends usage in events)") + } +} diff --git a/streaming_test.go b/streaming_test.go index 4b7b154..5204ef3 100644 --- a/streaming_test.go +++ b/streaming_test.go @@ -284,3 +284,252 @@ func TestFormatSSEEvent(t *testing.T) { }) } } + +func TestParseAnthropicSSEEvent(t *testing.T) { + tests := []struct { + name string + input []byte + expectError bool + eventType string + expectedInputTok int + expectedOutputTok int + }{ + { + name: "empty input", + input: []byte{}, + expectError: false, + }, + { + name: "invalid JSON", + input: []byte(`{invalid}`), + expectError: true, + }, + { 
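+			// message_start carries input-token usage on the nested
+			// message object rather than at the event's top level.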
+ name: "message_start event", + input: []byte(`{"type":"message_start","message":{"id":"msg_123","type":"message","role":"assistant","model":"claude-3-opus-20240229","usage":{"input_tokens":150,"cache_read_input_tokens":1000}}}`), + eventType: "message_start", + expectedInputTok: 150, + expectedOutputTok: 0, + }, + { + name: "message_delta event", + input: []byte(`{"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":75}}`), + eventType: "message_delta", + expectedInputTok: 0, + expectedOutputTok: 75, + }, + { + name: "content_block_start event", + input: []byte(`{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}`), + eventType: "content_block_start", + }, + { + name: "content_block_delta event", + input: []byte(`{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}}`), + eventType: "content_block_delta", + }, + { + name: "message_stop event", + input: []byte(`{"type":"message_stop"}`), + eventType: "message_stop", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + event, err := ParseAnthropicSSEEvent(tt.input) + + if tt.expectError { + if err == nil { + t.Error("expected error, got nil") + } + return + } + + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + if len(tt.input) == 0 { + if event != nil { + t.Error("expected nil event for empty input") + } + return + } + + if event.Type != tt.eventType { + t.Errorf("expected type %q, got %q", tt.eventType, event.Type) + } + + if tt.eventType == "message_start" && event.Message != nil { + if event.Message.Usage == nil { + t.Error("expected usage in message_start") + } else { + if event.Message.Usage.InputTokens != tt.expectedInputTok { + t.Errorf("expected input tokens %d, got %d", tt.expectedInputTok, event.Message.Usage.InputTokens) + } + } + } + + if tt.eventType == "message_delta" && event.Usage != nil { + if event.Usage.OutputTokens != tt.expectedOutputTok { + t.Errorf("expected output tokens %d, got %d", tt.expectedOutputTok, event.Usage.OutputTokens) + } + } + }) + } +} + +func TestExtractUsageFromAnthropicEvent(t *testing.T) { + tests := []struct { + name string + event *AnthropicStreamEvent + expectedPrompt int + expectedCompletion int + expectedCacheRead int + expectedCacheCreate int + }{ + { + name: "nil event", + event: nil, + }, + { + name: "message_stop returns nil", + event: &AnthropicStreamEvent{Type: "message_stop"}, + }, + { + name: "message_start with usage", + event: &AnthropicStreamEvent{ + Type: "message_start", + Message: &AnthropicStreamMessage{ + Usage: &AnthropicStreamUsage{ + InputTokens: 100, + CacheReadInputTokens: 500, + }, + }, + }, + expectedPrompt: 100, + expectedCacheRead: 500, + }, + { + name: "message_delta with usage", + event: &AnthropicStreamEvent{ + Type: "message_delta", + Usage: &AnthropicStreamUsage{ + OutputTokens: 50, + CacheCreationInputTokens: 200, + }, + }, + expectedCompletion: 50, + expectedCacheCreate: 200, + }, + { + name: "content_block_delta returns nil", + event: &AnthropicStreamEvent{ + Type: "content_block_delta", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ExtractUsageFromAnthropicEvent(tt.event) + + if tt.expectedPrompt == 0 && tt.expectedCompletion == 0 { + if result != nil && (result.PromptTokens != 0 || result.CompletionTokens != 0) { + t.Errorf("expected nil or zero usage, got %+v", result) + } + return + } + + if result == nil { + t.Fatal("expected result, got nil") + } + + if 
result.PromptTokens != tt.expectedPrompt { + t.Errorf("expected prompt tokens %d, got %d", tt.expectedPrompt, result.PromptTokens) + } + if result.CompletionTokens != tt.expectedCompletion { + t.Errorf("expected completion tokens %d, got %d", tt.expectedCompletion, result.CompletionTokens) + } + + if tt.expectedCacheRead > 0 || tt.expectedCacheCreate > 0 { + if result.CacheUsage == nil { + t.Fatal("expected cache usage") + } + if result.CacheUsage.CacheReadInputTokens != tt.expectedCacheRead { + t.Errorf("expected cache read %d, got %d", tt.expectedCacheRead, result.CacheUsage.CacheReadInputTokens) + } + if result.CacheUsage.CacheCreationInputTokens != tt.expectedCacheCreate { + t.Errorf("expected cache create %d, got %d", tt.expectedCacheCreate, result.CacheUsage.CacheCreationInputTokens) + } + } + }) + } +} + +func TestAnthropicSSEParser(t *testing.T) { + // Realistic Anthropic streaming format + input := `event: message_start +data: {"type":"message_start","message":{"id":"msg_1a2b3c","type":"message","role":"assistant","model":"claude-3-opus-20240229","content":[],"stop_reason":null,"usage":{"input_tokens":150,"cache_read_input_tokens":1000}}} + +event: content_block_start +data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}} + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}} + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" world"}} + +event: content_block_stop +data: {"type":"content_block_stop","index":0} + +event: message_delta +data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":25}} + +event: message_stop +data: {"type":"message_stop"} + +` + + parser := NewSSEParser(bytes.NewReader([]byte(input))) + + var events []*SSEEvent + for { + event, err := parser.Next() + if err == io.EOF { + break + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + events = append(events, event) + } + + if len(events) != 7 { + t.Fatalf("expected 7 events, got %d", len(events)) + } + + // Verify message_start + if string(events[0].Event) != "message_start" { + t.Errorf("expected event 'message_start', got %q", events[0].Event) + } + startEvent, _ := ParseAnthropicSSEEvent(events[0].Data) + if startEvent.Message.Usage.InputTokens != 150 { + t.Errorf("expected 150 input tokens, got %d", startEvent.Message.Usage.InputTokens) + } + if startEvent.Message.Usage.CacheReadInputTokens != 1000 { + t.Errorf("expected 1000 cache read tokens, got %d", startEvent.Message.Usage.CacheReadInputTokens) + } + + // Verify message_delta + if string(events[5].Event) != "message_delta" { + t.Errorf("expected event 'message_delta', got %q", events[5].Event) + } + deltaEvent, _ := ParseAnthropicSSEEvent(events[5].Data) + if deltaEvent.Usage.OutputTokens != 25 { + t.Errorf("expected 25 output tokens, got %d", deltaEvent.Usage.OutputTokens) + } +} From e2ce42f0347071b25b77bcdc666a74ba624e8467 Mon Sep 17 00:00:00 2001 From: Jeff Haynie Date: Mon, 13 Apr 2026 23:36:17 -0500 Subject: [PATCH 4/7] fix: improve test assertions in streaming_test.go - Add error checking for ParseAnthropicSSEEvent calls to prevent panics - Require strict nil result for expected-nil test cases - Replace no-op _ = chunk.ID with real assertions for ID and content --- streaming_test.go | 52 +++++++++++++++++++++++++++++++++++------------ 1 file changed, 39 insertions(+), 13 deletions(-) diff --git a/streaming_test.go b/streaming_test.go 
index 5204ef3..2cb5db9 100644 --- a/streaming_test.go +++ b/streaming_test.go @@ -85,15 +85,19 @@ func TestSSEParser(t *testing.T) { func TestParseOpenAISSEEvent(t *testing.T) { tests := []struct { - name string - input []byte - expectError bool - expectDone bool + name string + input []byte + expectError bool + expectDone bool + expectedID string + expectedContent string }{ { - name: "valid chunk", - input: []byte(`{"id":"chatcmpl-123","object":"chat.completion.chunk","model":"gpt-4","choices":[{"index":0,"delta":{"content":"Hello"}}]}`), - expectError: false, + name: "valid chunk", + input: []byte(`{"id":"chatcmpl-123","object":"chat.completion.chunk","model":"gpt-4","choices":[{"index":0,"delta":{"content":"Hello"}}]}`), + expectError: false, + expectedID: "chatcmpl-123", + expectedContent: "Hello", }, { name: "done marker", @@ -133,10 +137,26 @@ func TestParseOpenAISSEEvent(t *testing.T) { if err != nil && !tt.expectError { t.Errorf("unexpected error: %v", err) + return + } + + if len(tt.input) == 0 { + if chunk != nil { + t.Error("expected nil chunk for empty input") + } + return } - if chunk != nil && tt.input != nil && len(tt.input) > 0 { - _ = chunk.ID + if chunk == nil { + t.Fatal("expected non-nil chunk for non-empty input") + } + + if chunk.ID != tt.expectedID { + t.Errorf("expected ID %q, got %q", tt.expectedID, chunk.ID) + } + + if len(chunk.Choices) > 0 && chunk.Choices[0].Delta.Content != tt.expectedContent { + t.Errorf("expected content %q, got %q", tt.expectedContent, chunk.Choices[0].Delta.Content) } }) } @@ -437,8 +457,8 @@ func TestExtractUsageFromAnthropicEvent(t *testing.T) { result := ExtractUsageFromAnthropicEvent(tt.event) if tt.expectedPrompt == 0 && tt.expectedCompletion == 0 { - if result != nil && (result.PromptTokens != 0 || result.CompletionTokens != 0) { - t.Errorf("expected nil or zero usage, got %+v", result) + if result != nil { + t.Errorf("expected nil result, got %+v", result) } return } @@ -516,7 +536,10 @@ data: {"type":"message_stop"} if string(events[0].Event) != "message_start" { t.Errorf("expected event 'message_start', got %q", events[0].Event) } - startEvent, _ := ParseAnthropicSSEEvent(events[0].Data) + startEvent, err := ParseAnthropicSSEEvent(events[0].Data) + if err != nil { + t.Fatalf("ParseAnthropicSSEEvent failed for events[0]: %v", err) + } if startEvent.Message.Usage.InputTokens != 150 { t.Errorf("expected 150 input tokens, got %d", startEvent.Message.Usage.InputTokens) } @@ -528,7 +551,10 @@ data: {"type":"message_stop"} if string(events[5].Event) != "message_delta" { t.Errorf("expected event 'message_delta', got %q", events[5].Event) } - deltaEvent, _ := ParseAnthropicSSEEvent(events[5].Data) + deltaEvent, err := ParseAnthropicSSEEvent(events[5].Data) + if err != nil { + t.Fatalf("ParseAnthropicSSEEvent failed for events[5]: %v", err) + } if deltaEvent.Usage.OutputTokens != 25 { t.Errorf("expected 25 output tokens, got %d", deltaEvent.Usage.OutputTokens) } From 04537cade356de6ab47d878c7ee5049aca5198d8 Mon Sep 17 00:00:00 2001 From: Jeff Haynie Date: Mon, 13 Apr 2026 23:46:36 -0500 Subject: [PATCH 5/7] fix: address PR review findings Critical fixes: - ForwardStreaming now wraps upstream call with interceptor chain - Use HTTP trailers for billing headers in streaming (fixes timing) - SSE events now forwarded even when parsing fails - stream_options merged instead of overwritten Robustness improvements: - Add nil guards in BillingCalculator.Calculate - Increase SSEParser buffer to 1MB for large payloads - FormatSSEEvent splits multi-line 
data correctly - Add TotalTokens fallback in extractOpenAIStd - Support Anthropic ephemeral cache token fields Test improvements: - Remove unused mockExtractor embedding - Add cacheUsage nil assertion in tests - Add error checking for ParseAnthropicSSEEvent calls Dependency fix: - Run go mod tidy to mark simdjson-go as direct dependency --- autorouter.go | 49 +++++++++++++++---- autorouter_test.go | 6 ++- billing_calculator.go | 3 ++ go.mod | 6 ++- internal/fastjson/extractor.go | 22 ++++++++- internal/fastjson/extractor_test.go | 4 ++ providers/anthropic/streaming_extractor.go | 17 ++++--- .../openai_compatible/streaming_extractor.go | 17 ++++--- streaming.go | 14 ++++-- 9 files changed, 104 insertions(+), 34 deletions(-) diff --git a/autorouter.go b/autorouter.go index ee16b0f..86f309a 100644 --- a/autorouter.go +++ b/autorouter.go @@ -241,7 +241,13 @@ func (a *AutoRouter) ForwardStreaming(ctx context.Context, req *http.Request, w if a.billingCalculator != nil { if stream, ok := raw["stream"].(bool); ok && stream { if !nativeStreamUsageProviders[providerName] { - raw["stream_options"] = map[string]any{"include_usage": true} + // Merge include_usage into existing stream_options if present + streamOpts, ok := raw["stream_options"].(map[string]any) + if !ok { + streamOpts = make(map[string]any) + raw["stream_options"] = streamOpts + } + streamOpts["include_usage"] = true } } } @@ -289,12 +295,36 @@ func (a *AutoRouter) ForwardStreaming(ctx context.Context, req *http.Request, w ctxValue := MetaContextValue{Meta: meta, RawBody: body} upstreamReq = upstreamReq.WithContext(context.WithValue(upstreamReq.Context(), MetaContextKey{}, ctxValue)) - upstreamResp, err := a.client.Do(upstreamReq) + // Wrap with interceptor chain (mirrors Forward method pattern) + chain := a.interceptors + doRequest := func(req *http.Request) (*http.Response, ResponseMetadata, []byte, error) { + resp, err := a.client.Do(req) + if err != nil { + return nil, ResponseMetadata{}, nil, err + } + // For streaming: return response with body still open. + // ResponseMetadata will be extracted during streaming. 
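+		// Interceptors in the chain therefore observe empty metadata
+		// and a nil body for streaming requests.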
+ return resp, ResponseMetadata{}, nil, nil + } + + if len(chain) > 0 { + doRequest = chain.Wrap(doRequest) + } + + upstreamResp, _, _, err := doRequest(upstreamReq) if err != nil { return ResponseMetadata{}, err } + if upstreamResp == nil { + return ResponseMetadata{}, errors.New("no response from upstream") + } defer upstreamResp.Body.Close() + // Declare HTTP trailers for billing headers (must be before WriteHeader) + if a.billingCalculator != nil { + w.Header().Set("Trailer", "X-Gateway-Cost,X-Gateway-Prompt-Tokens,X-Gateway-Completion-Tokens") + } + for k, v := range upstreamResp.Header { if k != "Content-Length" { w.Header()[k] = v @@ -324,6 +354,12 @@ func (a *AutoRouter) ForwardStreaming(ctx context.Context, req *http.Request, w if a.billingCalculator != nil { a.billingCalculator.Calculate(meta, &respMeta) + // Set billing headers as HTTP trailers (sent after body completes) + if billing, ok := respMeta.Custom["billing_result"].(BillingResult); ok { + w.Header().Set("X-Gateway-Cost", fmt.Sprintf("%.6f", billing.TotalCost)) + w.Header().Set("X-Gateway-Prompt-Tokens", fmt.Sprintf("%d", billing.PromptTokens)) + w.Header().Set("X-Gateway-Completion-Tokens", fmt.Sprintf("%d", billing.CompletionTokens)) + } } return respMeta, nil @@ -391,19 +427,14 @@ func (a *AutoRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) { r.Body = io.NopCloser(bytes.NewReader(body)) if isStreamingRequest { - meta, err := a.ForwardStreaming(r.Context(), r, w) + _, err := a.ForwardStreaming(r.Context(), r, w) if err != nil { if !headerSent(w) { http.Error(w, err.Error(), http.StatusInternalServerError) } return } - - if billing, ok := meta.Custom["billing_result"].(BillingResult); ok { - w.Header().Set("X-Gateway-Cost", fmt.Sprintf("%.6f", billing.TotalCost)) - w.Header().Set("X-Gateway-Prompt-Tokens", fmt.Sprintf("%d", billing.PromptTokens)) - w.Header().Set("X-Gateway-Completion-Tokens", fmt.Sprintf("%d", billing.CompletionTokens)) - } + // Billing headers are sent as HTTP trailers in ForwardStreaming return } diff --git a/autorouter_test.go b/autorouter_test.go index 6417c31..23933a2 100644 --- a/autorouter_test.go +++ b/autorouter_test.go @@ -76,11 +76,15 @@ func (m *mockStreamingProvider) ResponseExtractor() ResponseExtractor { } type mockStreamingExtractor struct { - *mockExtractor isStreaming bool extractStreamingFn func(resp *http.Response, w http.ResponseWriter, rc *http.ResponseController) (ResponseMetadata, error) } +func (m *mockStreamingExtractor) Extract(resp *http.Response) (ResponseMetadata, []byte, error) { + body, _ := io.ReadAll(resp.Body) + return ResponseMetadata{}, body, nil +} + func (m *mockStreamingExtractor) IsStreamingResponse(resp *http.Response) bool { return m.isStreaming } diff --git a/billing_calculator.go b/billing_calculator.go index ca610c6..e8e7e22 100644 --- a/billing_calculator.go +++ b/billing_calculator.go @@ -13,6 +13,9 @@ func NewBillingCalculator(lookup CostLookup, onResult func(BillingResult)) *Bill } func (c *BillingCalculator) Calculate(meta BodyMetadata, respMeta *ResponseMetadata) *BillingResult { + if c.lookup == nil || respMeta == nil { + return nil + } var provider string if meta.Custom != nil { if p, ok := meta.Custom["provider"].(string); ok && p != "" { diff --git a/go.mod b/go.mod index ceab32e..f867248 100644 --- a/go.mod +++ b/go.mod @@ -2,13 +2,15 @@ module github.com/agentuity/llmproxy go 1.26.2 -require go.opentelemetry.io/otel/trace v1.43.0 +require ( + github.com/minio/simdjson-go v0.4.5 + go.opentelemetry.io/otel/trace v1.43.0 +) require ( 
github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/klauspost/compress v1.15.15 // indirect github.com/klauspost/cpuid/v2 v2.2.3 // indirect - github.com/minio/simdjson-go v0.4.5 // indirect go.opentelemetry.io/otel v1.43.0 // indirect golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e // indirect ) diff --git a/internal/fastjson/extractor.go b/internal/fastjson/extractor.go index 0b9d7cc..5432454 100644 --- a/internal/fastjson/extractor.go +++ b/internal/fastjson/extractor.go @@ -147,6 +147,10 @@ func (e *UsageExtractor) extractOpenAIStd(data []byte) (*llmproxy.Usage, *llmpro TotalTokens: resp.Usage.TotalTokens, } + if usage.TotalTokens == 0 { + usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens + } + var cacheUsage *llmproxy.CacheUsage if resp.Usage.PromptTokensDetails != nil && resp.Usage.PromptTokensDetails.CachedTokens > 0 { cacheUsage = &llmproxy.CacheUsage{ @@ -219,12 +223,22 @@ func (e *UsageExtractor) extractAnthropicSimd(data []byte) (*llmproxy.Usage, *ll v, _ := tmpIter.Int() cacheUsage.CacheReadInputTokens = int(v) } + case "ephemeral_5m_input_tokens": + if t == simdjson.TypeInt { + v, _ := tmpIter.Int() + cacheUsage.Ephemeral5mInputTokens = int(v) + } + case "ephemeral_1h_input_tokens": + if t == simdjson.TypeInt { + v, _ := tmpIter.Int() + cacheUsage.Ephemeral1hInputTokens = int(v) + } } } usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens - hasCache := cacheUsage.CacheCreationInputTokens > 0 || cacheUsage.CacheReadInputTokens > 0 + hasCache := cacheUsage.CacheCreationInputTokens > 0 || cacheUsage.CacheReadInputTokens > 0 || cacheUsage.Ephemeral5mInputTokens > 0 || cacheUsage.Ephemeral1hInputTokens > 0 if !hasCache { cacheUsage = nil } @@ -239,6 +253,8 @@ func (e *UsageExtractor) extractAnthropicStd(data []byte) (*llmproxy.Usage, *llm OutputTokens int `json:"output_tokens"` CacheCreationInputTokens int `json:"cache_creation_input_tokens"` CacheReadInputTokens int `json:"cache_read_input_tokens"` + Ephemeral5mInputTokens int `json:"ephemeral_5m_input_tokens"` + Ephemeral1hInputTokens int `json:"ephemeral_1h_input_tokens"` } `json:"usage"` } @@ -257,10 +273,12 @@ func (e *UsageExtractor) extractAnthropicStd(data []byte) (*llmproxy.Usage, *llm } var cacheUsage *llmproxy.CacheUsage - if resp.Usage.CacheCreationInputTokens > 0 || resp.Usage.CacheReadInputTokens > 0 { + if resp.Usage.CacheCreationInputTokens > 0 || resp.Usage.CacheReadInputTokens > 0 || resp.Usage.Ephemeral5mInputTokens > 0 || resp.Usage.Ephemeral1hInputTokens > 0 { cacheUsage = &llmproxy.CacheUsage{ CacheCreationInputTokens: resp.Usage.CacheCreationInputTokens, CacheReadInputTokens: resp.Usage.CacheReadInputTokens, + Ephemeral5mInputTokens: resp.Usage.Ephemeral5mInputTokens, + Ephemeral1hInputTokens: resp.Usage.Ephemeral1hInputTokens, } } diff --git a/internal/fastjson/extractor_test.go b/internal/fastjson/extractor_test.go index 5549d52..ab6a81e 100644 --- a/internal/fastjson/extractor_test.go +++ b/internal/fastjson/extractor_test.go @@ -59,6 +59,10 @@ func TestUsageExtractor_ExtractOpenAI(t *testing.T) { } else if cacheUsage.CachedTokens != tt.expectedCached { t.Errorf("expected cached tokens %d, got %d", tt.expectedCached, cacheUsage.CachedTokens) } + } else { + if cacheUsage != nil { + t.Errorf("expected nil cache usage, got %+v", cacheUsage) + } } }) } diff --git a/providers/anthropic/streaming_extractor.go b/providers/anthropic/streaming_extractor.go index 724abc0..5967759 100644 --- a/providers/anthropic/streaming_extractor.go +++ 
b/providers/anthropic/streaming_extractor.go @@ -104,6 +104,15 @@ func (e *StreamingExtractor) extractStreamingWithController(resp *http.Response, data := bytes.TrimPrefix(line, []byte("data: ")) data = bytes.TrimSpace(data) + // Forward the raw data regardless of parsing success + if _, err := w.Write(line); err != nil { + return meta, err + } + if _, err := w.Write([]byte("\n\n")); err != nil { + return meta, err + } + _ = rc.Flush() + event, err := llmproxy.ParseAnthropicSSEEvent(data) if err != nil { continue @@ -170,14 +179,6 @@ func (e *StreamingExtractor) extractStreamingWithController(resp *http.Response, } case "message_stop": } - - if _, err := w.Write(line); err != nil { - return meta, err - } - if _, err := w.Write([]byte("\n\n")); err != nil { - return meta, err - } - _ = rc.Flush() } else if bytes.HasPrefix(line, []byte("event: ")) { if _, err := w.Write(line); err != nil { return meta, err diff --git a/providers/openai_compatible/streaming_extractor.go b/providers/openai_compatible/streaming_extractor.go index 6eb8b5f..1a8bb2f 100644 --- a/providers/openai_compatible/streaming_extractor.go +++ b/providers/openai_compatible/streaming_extractor.go @@ -113,6 +113,15 @@ func (e *StreamingExtractor) extractStreamingWithController(resp *http.Response, break } + // Forward the raw data regardless of parsing success + if _, err := w.Write(line); err != nil { + return meta, err + } + if _, err := w.Write([]byte("\n\n")); err != nil { + return meta, err + } + _ = rc.Flush() + chunk, err := llmproxy.ParseOpenAISSEEvent(data) if err != nil { continue @@ -140,14 +149,6 @@ func (e *StreamingExtractor) extractStreamingWithController(resp *http.Response, accumulatedUsage = usage } } - - if _, err := w.Write(line); err != nil { - return meta, err - } - if _, err := w.Write([]byte("\n\n")); err != nil { - return meta, err - } - _ = rc.Flush() } else { if _, err := w.Write(line); err != nil { return meta, err diff --git a/streaming.go b/streaming.go index 1681308..6a10195 100644 --- a/streaming.go +++ b/streaming.go @@ -25,8 +25,10 @@ type SSEParser struct { } func NewSSEParser(r io.Reader) *SSEParser { + scanner := bufio.NewScanner(r) + scanner.Buffer(make([]byte, 64*1024), 1024*1024) return &SSEParser{ - scanner: bufio.NewScanner(r), + scanner: scanner, } } @@ -284,8 +286,12 @@ func FormatSSEEvent(event string, data []byte) []byte { buf.WriteString(event) buf.WriteByte('\n') } - buf.WriteString("data: ") - buf.Write(data) - buf.WriteString("\n\n") + // Split data on newlines and write each as a separate "data:" line + for _, line := range bytes.Split(data, []byte{'\n'}) { + buf.WriteString("data: ") + buf.Write(line) + buf.WriteByte('\n') + } + buf.WriteByte('\n') return buf.Bytes() } From 9703cb33614831f373ef2a7e2839b19db7167152 Mon Sep 17 00:00:00 2001 From: Jeff Haynie Date: Mon, 13 Apr 2026 23:59:51 -0500 Subject: [PATCH 6/7] feat: implement true streaming for Bedrock and Google AI Previously both providers buffered entire responses before forwarding to client, defeating SSE streaming. Now: Resolvers: - Google AI: Use :streamGenerateContent?alt=sse for streaming - Bedrock: Use /converse-stream for streaming Streaming extractors: - Google AI: Incremental SSE parsing, forward each event immediately - Bedrock: Binary AWS event stream parsing, forward each message immediately Tests: - Verify incremental streaming (data arrives before stream completes) - Test metadata extraction from streams - Test resolver endpoint selection based on stream flag All 100+ tests pass. 
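
For context, a minimal client sketch of the new behavior (the proxy address,
the OpenAI-style route, and the model name are deployment-specific
assumptions, not part of this change):

    package main

    import (
        "bufio"
        "fmt"
        "log"
        "net/http"
        "strings"
    )

    func main() {
        body := `{"model":"gemini-pro","stream":true,"messages":[{"role":"user","content":"hi"}]}`
        resp, err := http.Post("http://localhost:8080/v1/chat/completions", "application/json", strings.NewReader(body))
        if err != nil {
            log.Fatal(err)
        }
        defer resp.Body.Close()

        // SSE events now print as they arrive instead of after the upstream
        // response has been fully buffered.
        scanner := bufio.NewScanner(resp.Body)
        for scanner.Scan() {
            fmt.Println(scanner.Text())
        }

        // Billing totals are exposed as HTTP trailers once the body is drained.
        fmt.Println("cost:", resp.Trailer.Get("X-Gateway-Cost"))
    }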
--- autorouter.go | 1 + providers/bedrock/resolver.go | 12 +- providers/bedrock/streaming_extractor.go | 261 ++++++++++- providers/bedrock/streaming_extractor_test.go | 404 ++++++++++++++++++ providers/googleai/resolver.go | 15 +- providers/googleai/streaming_extractor.go | 107 ++++- .../googleai/streaming_extractor_test.go | 287 +++++++++++++ 7 files changed, 1079 insertions(+), 8 deletions(-) create mode 100644 providers/bedrock/streaming_extractor_test.go create mode 100644 providers/googleai/streaming_extractor_test.go diff --git a/autorouter.go b/autorouter.go index 86f309a..b5cbc73 100644 --- a/autorouter.go +++ b/autorouter.go @@ -273,6 +273,7 @@ func (a *AutoRouter) ForwardStreaming(ctx context.Context, req *http.Request, w } meta.Custom["api_type"] = apiType meta.Custom["provider"] = providerName + meta.Stream = true upstreamURL, err := provider.URLResolver().Resolve(meta) if err != nil { diff --git a/providers/bedrock/resolver.go b/providers/bedrock/resolver.go index 7890eed..4cc0d9f 100644 --- a/providers/bedrock/resolver.go +++ b/providers/bedrock/resolver.go @@ -16,8 +16,10 @@ type Resolver struct { } // Resolve returns the Bedrock endpoint URL for the given model. -// The URL format depends on whether we use the Converse or Invoke API: -// - Converse: https://bedrock-runtime.{region}.amazonaws.com/model/{modelId}/converse +// The URL format depends on whether we use the Converse or Invoke API, +// and whether streaming is requested: +// - Converse (streaming): https://bedrock-runtime.{region}.amazonaws.com/model/{modelId}/converse-stream +// - Converse (non-streaming): https://bedrock-runtime.{region}.amazonaws.com/model/{modelId}/converse // - Invoke: https://bedrock-runtime.{region}.amazonaws.com/model/{modelId}/invoke func (r *Resolver) Resolve(meta llmproxy.BodyMetadata) (*url.URL, error) { modelID := meta.Model @@ -30,7 +32,11 @@ func (r *Resolver) Resolve(meta llmproxy.BodyMetadata) (*url.URL, error) { var endpoint string if r.UseConverse { - endpoint = fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com/model/%s/converse", r.Region, encodedModelID) + if meta.Stream { + endpoint = fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com/model/%s/converse-stream", r.Region, encodedModelID) + } else { + endpoint = fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com/model/%s/converse", r.Region, encodedModelID) + } } else { endpoint = fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com/model/%s/invoke", r.Region, encodedModelID) } diff --git a/providers/bedrock/streaming_extractor.go b/providers/bedrock/streaming_extractor.go index b4e74f8..30d9fb8 100644 --- a/providers/bedrock/streaming_extractor.go +++ b/providers/bedrock/streaming_extractor.go @@ -3,9 +3,13 @@ package bedrock import ( "bytes" "context" + "encoding/binary" + "encoding/json" "errors" + "fmt" "io" "net/http" + "strings" "github.com/agentuity/llmproxy" ) @@ -21,13 +25,268 @@ func NewStreamingExtractor() *StreamingExtractor { } func (e *StreamingExtractor) IsStreamingResponse(resp *http.Response) bool { - return llmproxy.IsSSEStream(resp.Header.Get("Content-Type")) + ct := resp.Header.Get("Content-Type") + return llmproxy.IsSSEStream(ct) || isEventStream(ct) +} + +func isEventStream(contentType string) bool { + return strings.Contains(contentType, "vnd.amazon.eventstream") } func (e *StreamingExtractor) ExtractStreamingWithController(resp *http.Response, w http.ResponseWriter, rc *http.ResponseController) (llmproxy.ResponseMetadata, error) { + ct := resp.Header.Get("Content-Type") + if isEventStream(ct) 
{
+		return e.extractEventStreamWithController(resp, w, rc)
+	}
 	return e.extractNonStreamingWithController(resp, w, rc)
 }
 
+// eventStreamEvent represents a parsed AWS event stream event.
+type eventStreamEvent struct {
+	EventType string
+	Payload   []byte
+}
+
+func (e *StreamingExtractor) extractEventStreamWithController(resp *http.Response, w http.ResponseWriter, rc *http.ResponseController) (llmproxy.ResponseMetadata, error) {
+	meta := llmproxy.ResponseMetadata{
+		Choices: make([]llmproxy.Choice, 0, 1),
+		Custom:  make(map[string]any),
+	}
+
+	for {
+		event, rawBytes, err := readEventStreamMessage(resp.Body)
+		if err != nil {
+			if err == io.EOF || err == io.ErrUnexpectedEOF {
+				// Write any partial bytes we managed to read
+				if len(rawBytes) > 0 {
+					_, _ = w.Write(rawBytes)
+					_ = rc.Flush()
+				}
+				break
+			}
+			if errors.Is(err, context.Canceled) {
+				break
+			}
+			return meta, err
+		}
+
+		// Forward raw bytes immediately
+		if _, writeErr := w.Write(rawBytes); writeErr != nil {
+			return meta, writeErr
+		}
+		_ = rc.Flush()
+
+		// Process event for metadata extraction
+		if event != nil {
+			processBedrockStreamEvent(event, &meta)
+		}
+	}
+
+	return meta, nil
+}
+
+// readEventStreamMessage reads a single AWS event stream message from the reader.
+// AWS event stream format:
+//
+//	[4 bytes: total length][4 bytes: headers length][4 bytes: prelude CRC]
+//	[headers...][payload...][4 bytes: message CRC]
+func readEventStreamMessage(r io.Reader) (*eventStreamEvent, []byte, error) {
+	// Read the 12-byte prelude
+	prelude := make([]byte, 12)
+	if _, err := io.ReadFull(r, prelude); err != nil {
+		return nil, nil, err
+	}
+
+	totalLen := binary.BigEndian.Uint32(prelude[0:4])
+	headersLen := binary.BigEndian.Uint32(prelude[4:8])
+
+	// Sanity check: minimum message is 16 bytes (12 prelude + 4 message CRC),
+	// maximum is 16MB to prevent memory issues
+	if totalLen < 16 || totalLen > 16*1024*1024 {
+		return nil, prelude, fmt.Errorf("invalid event stream message length: %d", totalLen)
+	}
+
+	// Headers must fit inside the message body, leaving 4 bytes for the
+	// message CRC; otherwise the slicing below could panic on malformed input.
+	if headersLen > totalLen-16 {
+		return nil, prelude, fmt.Errorf("invalid event stream headers length: %d", headersLen)
+	}
+
+	// Read remaining bytes (total - 12 bytes of prelude already read)
+	remaining := make([]byte, totalLen-12)
+	if _, err := io.ReadFull(r, remaining); err != nil {
+		// Return prelude bytes so they can still be forwarded
+		return nil, prelude, err
+	}
+
+	// Reconstruct full raw message for forwarding
+	rawBytes := make([]byte, totalLen)
+	copy(rawBytes, prelude)
+	copy(rawBytes[12:], remaining)
+
+	// Parse headers
+	headers := parseEventStreamHeaders(remaining[:headersLen])
+
+	// Extract payload (between headers and message CRC)
+	payloadLen := totalLen - 12 - headersLen - 4
+	var payload []byte
+	if payloadLen > 0 {
+		payload = remaining[headersLen : headersLen+payloadLen]
+	}
+
+	return &eventStreamEvent{
+		EventType: headers[":event-type"],
+		Payload:   payload,
+	}, rawBytes, nil
+}
+
+// parseEventStreamHeaders parses AWS event stream binary headers.
+// Header format: [1 byte: name length][name][1 byte: type][value...]
+// Type 7 (string): [2 bytes: value length][value] +func parseEventStreamHeaders(data []byte) map[string]string { + headers := make(map[string]string) + offset := uint32(0) + dataLen := uint32(len(data)) + + for offset < dataLen { + // Read header name length + nameLen := uint32(data[offset]) + offset++ + if offset+nameLen > dataLen { + break + } + + // Read header name + name := string(data[offset : offset+nameLen]) + offset += nameLen + if offset >= dataLen { + break + } + + // Read header type + headerType := data[offset] + offset++ + + // Skip value based on type + switch headerType { + case 0, 1: // bool_true, bool_false - no value bytes + case 2: // byte + offset++ + case 3: // short + offset += 2 + case 4: // int + offset += 4 + case 5, 8: // long, timestamp + offset += 8 + case 6, 7: // bytes, string - 2-byte length prefix + value + if offset+2 > dataLen { + return headers + } + valueLen := uint32(binary.BigEndian.Uint16(data[offset : offset+2])) + offset += 2 + if offset+valueLen > dataLen { + return headers + } + if headerType == 7 { + headers[name] = string(data[offset : offset+valueLen]) + } + offset += valueLen + case 9: // uuid + offset += 16 + default: + return headers // unknown type, bail + } + } + + return headers +} + +// Bedrock stream event payload types + +type bedrockStreamStart struct { + Role string `json:"role"` +} + +type bedrockStreamDelta struct { + ContentBlockIndex int `json:"contentBlockIndex"` + Delta struct { + Text string `json:"text,omitempty"` + } `json:"delta"` +} + +type bedrockStreamStop struct { + StopReason string `json:"stopReason"` +} + +type bedrockStreamMetadata struct { + Usage ResponseUsage `json:"usage"` + Metrics *ResponseMetrics `json:"metrics,omitempty"` +} + +func processBedrockStreamEvent(event *eventStreamEvent, meta *llmproxy.ResponseMetadata) { + if len(event.Payload) == 0 { + return + } + + switch event.EventType { + case "messageStart": + var start bedrockStreamStart + if json.Unmarshal(event.Payload, &start) == nil { + role := start.Role + if role == "" { + role = "assistant" + } + if len(meta.Choices) == 0 { + meta.Choices = append(meta.Choices, llmproxy.Choice{ + Index: 0, + Message: &llmproxy.Message{ + Role: role, + }, + }) + } + } + + case "contentBlockDelta": + var delta bedrockStreamDelta + if json.Unmarshal(event.Payload, &delta) == nil { + if len(meta.Choices) == 0 { + meta.Choices = append(meta.Choices, llmproxy.Choice{ + Index: 0, + Message: &llmproxy.Message{ + Role: "assistant", + }, + }) + } + if meta.Choices[0].Message != nil { + meta.Choices[0].Message.Content += delta.Delta.Text + } + } + + case "messageStop": + var stop bedrockStreamStop + if json.Unmarshal(event.Payload, &stop) == nil { + if len(meta.Choices) > 0 { + meta.Choices[0].FinishReason = stop.StopReason + } + } + + case "metadata": + var metadata bedrockStreamMetadata + if json.Unmarshal(event.Payload, &metadata) == nil { + meta.Usage = llmproxy.Usage{ + PromptTokens: metadata.Usage.InputTokens, + CompletionTokens: metadata.Usage.OutputTokens, + TotalTokens: metadata.Usage.TotalTokens, + } + if metadata.Metrics != nil { + meta.Custom["latency_ms"] = metadata.Metrics.LatencyMs + } + if metadata.Usage.CacheReadInputTokens > 0 || metadata.Usage.CacheWriteInputTokens > 0 { + cacheDetails := extractCacheDetails(metadata.Usage.CacheDetails) + meta.Custom["cache_usage"] = llmproxy.CacheUsage{ + CachedTokens: metadata.Usage.CacheReadInputTokens, + CacheWriteTokens: metadata.Usage.CacheWriteInputTokens, + CacheDetails: cacheDetails, + } + } + } + } +} + func 
(e *StreamingExtractor) extractNonStreamingWithController(resp *http.Response, w http.ResponseWriter, rc *http.ResponseController) (llmproxy.ResponseMetadata, error) { var buf bytes.Buffer tee := io.TeeReader(resp.Body, &buf) diff --git a/providers/bedrock/streaming_extractor_test.go b/providers/bedrock/streaming_extractor_test.go new file mode 100644 index 0000000..74e7867 --- /dev/null +++ b/providers/bedrock/streaming_extractor_test.go @@ -0,0 +1,404 @@ +package bedrock + +import ( + "bytes" + "encoding/binary" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" + "time" + + "github.com/agentuity/llmproxy" +) + +// buildEventStreamMessage constructs a binary AWS event stream message +// with the given event type and JSON payload. +func buildEventStreamMessage(eventType string, payload []byte) []byte { + var headers bytes.Buffer + writeEventStreamHeader(&headers, ":event-type", eventType) + writeEventStreamHeader(&headers, ":content-type", "application/json") + writeEventStreamHeader(&headers, ":message-type", "event") + + headersBytes := headers.Bytes() + headersLen := uint32(len(headersBytes)) + totalLen := uint32(12 + headersLen + uint32(len(payload)) + 4) + + var buf bytes.Buffer + binary.Write(&buf, binary.BigEndian, totalLen) + binary.Write(&buf, binary.BigEndian, headersLen) + binary.Write(&buf, binary.BigEndian, uint32(0)) // prelude CRC (not validated in parser) + buf.Write(headersBytes) + buf.Write(payload) + binary.Write(&buf, binary.BigEndian, uint32(0)) // message CRC (not validated in parser) + + return buf.Bytes() +} + +func writeEventStreamHeader(buf *bytes.Buffer, name, value string) { + buf.WriteByte(byte(len(name))) + buf.WriteString(name) + buf.WriteByte(7) // string type + binary.Write(buf, binary.BigEndian, uint16(len(value))) + buf.WriteString(value) +} + +func TestStreamingExtractor_EventStream(t *testing.T) { + // Build a complete Bedrock streaming response with binary events + var stream bytes.Buffer + + // messageStart event + startPayload, _ := json.Marshal(map[string]string{"role": "assistant"}) + stream.Write(buildEventStreamMessage("messageStart", startPayload)) + + // contentBlockDelta events + delta1, _ := json.Marshal(map[string]any{ + "contentBlockIndex": 0, + "delta": map[string]string{"text": "Hello"}, + }) + stream.Write(buildEventStreamMessage("contentBlockDelta", delta1)) + + delta2, _ := json.Marshal(map[string]any{ + "contentBlockIndex": 0, + "delta": map[string]string{"text": " World"}, + }) + stream.Write(buildEventStreamMessage("contentBlockDelta", delta2)) + + // messageStop event + stopPayload, _ := json.Marshal(map[string]string{"stopReason": "end_turn"}) + stream.Write(buildEventStreamMessage("messageStop", stopPayload)) + + // metadata event + metadataPayload, _ := json.Marshal(map[string]any{ + "usage": map[string]int{ + "inputTokens": 10, + "outputTokens": 5, + "totalTokens": 15, + }, + "metrics": map[string]int64{ + "latencyMs": 100, + }, + }) + stream.Write(buildEventStreamMessage("metadata", metadataPayload)) + + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/vnd.amazon.eventstream"}}, + Body: io.NopCloser(&stream), + } + + recorder := httptest.NewRecorder() + rc := http.NewResponseController(recorder) + + extractor := NewStreamingExtractor() + meta, err := extractor.ExtractStreamingWithController(resp, recorder, rc) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Verify usage metadata + if 
meta.Usage.PromptTokens != 10 { + t.Errorf("expected prompt tokens 10, got %d", meta.Usage.PromptTokens) + } + if meta.Usage.CompletionTokens != 5 { + t.Errorf("expected completion tokens 5, got %d", meta.Usage.CompletionTokens) + } + if meta.Usage.TotalTokens != 15 { + t.Errorf("expected total tokens 15, got %d", meta.Usage.TotalTokens) + } + + // Verify choices + if len(meta.Choices) != 1 { + t.Fatalf("expected 1 choice, got %d", len(meta.Choices)) + } + if meta.Choices[0].Message.Role != "assistant" { + t.Errorf("expected role 'assistant', got %q", meta.Choices[0].Message.Role) + } + if meta.Choices[0].Message.Content != "Hello World" { + t.Errorf("expected content 'Hello World', got %q", meta.Choices[0].Message.Content) + } + if meta.Choices[0].FinishReason != "end_turn" { + t.Errorf("expected finish_reason 'end_turn', got %q", meta.Choices[0].FinishReason) + } + + // Verify latency metric + if latency, ok := meta.Custom["latency_ms"]; !ok || latency != int64(100) { + t.Errorf("expected latency_ms 100, got %v", meta.Custom["latency_ms"]) + } + + // Verify data was forwarded to client + if recorder.Body.Len() == 0 { + t.Error("no data written to client") + } +} + +func TestStreamingExtractor_EventStreamIncremental(t *testing.T) { + // Use a pipe to simulate slow upstream + pr, pw := io.Pipe() + + var mu sync.Mutex + var firstByteTime time.Time + var streamDoneTime time.Time + + // Send events with delay to verify incrementality + go func() { + defer pw.Close() + + startPayload, _ := json.Marshal(map[string]string{"role": "assistant"}) + pw.Write(buildEventStreamMessage("messageStart", startPayload)) + + delta1, _ := json.Marshal(map[string]any{ + "contentBlockIndex": 0, + "delta": map[string]string{"text": "Hello"}, + }) + pw.Write(buildEventStreamMessage("contentBlockDelta", delta1)) + + time.Sleep(100 * time.Millisecond) + + delta2, _ := json.Marshal(map[string]any{ + "contentBlockIndex": 0, + "delta": map[string]string{"text": " World"}, + }) + pw.Write(buildEventStreamMessage("contentBlockDelta", delta2)) + + metadataPayload, _ := json.Marshal(map[string]any{ + "usage": map[string]int{ + "inputTokens": 10, + "outputTokens": 5, + "totalTokens": 15, + }, + }) + pw.Write(buildEventStreamMessage("metadata", metadataPayload)) + }() + + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/vnd.amazon.eventstream"}}, + Body: io.NopCloser(pr), + } + + recorder := httptest.NewRecorder() + rc := http.NewResponseController(recorder) + + extractor := NewStreamingExtractor() + + // Monitor when data arrives + go func() { + for { + mu.Lock() + if recorder.Body.Len() > 0 && firstByteTime.IsZero() { + firstByteTime = time.Now() + } + mu.Unlock() + time.Sleep(10 * time.Millisecond) + } + }() + + meta, err := extractor.ExtractStreamingWithController(resp, recorder, rc) + mu.Lock() + streamDoneTime = time.Now() + mu.Unlock() + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Verify first bytes arrived before stream completed + mu.Lock() + defer mu.Unlock() + if firstByteTime.IsZero() { + t.Fatal("no data was received") + } + timeDiff := streamDoneTime.Sub(firstByteTime) + if timeDiff < 50*time.Millisecond { + t.Errorf("data did not arrive incrementally: first chunk and completion were only %v apart", timeDiff) + } + + // Verify metadata was still extracted + if meta.Usage.TotalTokens != 15 { + t.Errorf("expected total tokens 15, got %d", meta.Usage.TotalTokens) + } +} + +func TestStreamingExtractor_IsStreamingResponse(t 
*testing.T) { + extractor := NewStreamingExtractor() + + tests := []struct { + contentType string + expected bool + }{ + {"text/event-stream", true}, + {"text/event-stream; charset=utf-8", true}, + {"application/vnd.amazon.eventstream", true}, + {"application/json", false}, + {"text/plain", false}, + } + + for _, tt := range tests { + t.Run(tt.contentType, func(t *testing.T) { + resp := &http.Response{ + Header: http.Header{"Content-Type": []string{tt.contentType}}, + } + result := extractor.IsStreamingResponse(resp) + if result != tt.expected { + t.Errorf("expected %v, got %v", tt.expected, result) + } + }) + } +} + +func TestStreamingExtractor_NonStreamingFallback(t *testing.T) { + extractor := NewStreamingExtractor() + + respBody := `{"requestId":"req-123","modelId":"anthropic.claude-3-sonnet-20240229-v1:0","output":{"message":{"role":"assistant","content":[{"text":"Hello!"}]}},"usage":{"inputTokens":10,"outputTokens":5,"totalTokens":15},"stopReason":"end_turn"}` + + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(respBody)), + } + + recorder := httptest.NewRecorder() + rc := http.NewResponseController(recorder) + + meta, err := extractor.ExtractStreamingWithController(resp, recorder, rc) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if meta.Usage.PromptTokens != 10 { + t.Errorf("expected prompt tokens 10, got %d", meta.Usage.PromptTokens) + } + if meta.Usage.CompletionTokens != 5 { + t.Errorf("expected completion tokens 5, got %d", meta.Usage.CompletionTokens) + } + if len(meta.Choices) != 1 { + t.Fatalf("expected 1 choice, got %d", len(meta.Choices)) + } + if meta.Choices[0].Message.Content != "Hello!" { + t.Errorf("expected content 'Hello!', got %q", meta.Choices[0].Message.Content) + } +} + +func TestStreamingExtractor_EventStreamWithCache(t *testing.T) { + var stream bytes.Buffer + + startPayload, _ := json.Marshal(map[string]string{"role": "assistant"}) + stream.Write(buildEventStreamMessage("messageStart", startPayload)) + + deltaPayload, _ := json.Marshal(map[string]any{ + "contentBlockIndex": 0, + "delta": map[string]string{"text": "cached response"}, + }) + stream.Write(buildEventStreamMessage("contentBlockDelta", deltaPayload)) + + stopPayload, _ := json.Marshal(map[string]string{"stopReason": "end_turn"}) + stream.Write(buildEventStreamMessage("messageStop", stopPayload)) + + metadataPayload, _ := json.Marshal(map[string]any{ + "usage": map[string]any{ + "inputTokens": 100, + "outputTokens": 50, + "totalTokens": 150, + "cacheReadInputTokens": 80, + "cacheWriteInputTokens": 20, + }, + }) + stream.Write(buildEventStreamMessage("metadata", metadataPayload)) + + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/vnd.amazon.eventstream"}}, + Body: io.NopCloser(&stream), + } + + recorder := httptest.NewRecorder() + rc := http.NewResponseController(recorder) + + extractor := NewStreamingExtractor() + meta, err := extractor.ExtractStreamingWithController(resp, recorder, rc) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if meta.Usage.PromptTokens != 100 { + t.Errorf("expected prompt tokens 100, got %d", meta.Usage.PromptTokens) + } + + cacheUsage, ok := meta.Custom["cache_usage"].(llmproxy.CacheUsage) + if !ok { + t.Fatal("expected cache_usage in custom map") + } + if cacheUsage.CachedTokens != 80 { + t.Errorf("expected cached tokens 80, got %d", cacheUsage.CachedTokens) + 
} + if cacheUsage.CacheWriteTokens != 20 { + t.Errorf("expected cache write tokens 20, got %d", cacheUsage.CacheWriteTokens) + } +} + +func TestResolver_StreamingEndpoint(t *testing.T) { + t.Run("resolves to converse-stream when streaming", func(t *testing.T) { + resolver := NewResolver("us-east-1") + meta := llmproxy.BodyMetadata{Model: "anthropic.claude-3-sonnet-20240229-v1:0", Stream: true} + + u, err := resolver.Resolve(meta) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !strings.Contains(u.String(), "/converse-stream") { + t.Errorf("expected converse-stream in URL, got %s", u.String()) + } + }) + + t.Run("resolves to converse when not streaming", func(t *testing.T) { + resolver := NewResolver("us-east-1") + meta := llmproxy.BodyMetadata{Model: "anthropic.claude-3-sonnet-20240229-v1:0", Stream: false} + + u, err := resolver.Resolve(meta) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if strings.Contains(u.String(), "converse-stream") { + t.Errorf("expected converse (not stream) in URL, got %s", u.String()) + } + if !strings.Contains(u.String(), "/converse") { + t.Errorf("expected converse in URL, got %s", u.String()) + } + }) + + t.Run("invoke endpoint ignores streaming flag", func(t *testing.T) { + resolver := NewInvokeResolver("us-east-1") + meta := llmproxy.BodyMetadata{Model: "amazon.titan-text-express-v1", Stream: true} + + u, err := resolver.Resolve(meta) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !strings.Contains(u.String(), "/invoke") { + t.Errorf("expected invoke in URL, got %s", u.String()) + } + }) +} + +func TestParseEventStreamHeaders(t *testing.T) { + var buf bytes.Buffer + writeEventStreamHeader(&buf, ":event-type", "contentBlockDelta") + writeEventStreamHeader(&buf, ":content-type", "application/json") + writeEventStreamHeader(&buf, ":message-type", "event") + + headers := parseEventStreamHeaders(buf.Bytes()) + + if headers[":event-type"] != "contentBlockDelta" { + t.Errorf("expected event-type 'contentBlockDelta', got %q", headers[":event-type"]) + } + if headers[":content-type"] != "application/json" { + t.Errorf("expected content-type 'application/json', got %q", headers[":content-type"]) + } + if headers[":message-type"] != "event" { + t.Errorf("expected message-type 'event', got %q", headers[":message-type"]) + } +} diff --git a/providers/googleai/resolver.go b/providers/googleai/resolver.go index 3851a9a..4bbc8b5 100644 --- a/providers/googleai/resolver.go +++ b/providers/googleai/resolver.go @@ -14,8 +14,9 @@ type Resolver struct { BaseURL *url.URL } -// Resolve returns the full URL for the generateContent endpoint. -// The URL format is: {base}/v1beta/models/{model}:generateContent +// Resolve returns the full URL for the generateContent or streamGenerateContent endpoint. +// When meta.Stream is true, the URL uses streamGenerateContent with alt=sse for SSE format. +// Otherwise, the URL uses generateContent. // // If meta.Model is empty, defaults to "gemini-pro". 
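+//
+// For example, "gemini-pro" with meta.Stream set resolves to:
+//
+//	{base}/v1beta/models/gemini-pro:streamGenerateContent?alt=sse
+//
+// (See the resolver tests for the exact URLs expected in both modes.)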
func (r *Resolver) Resolve(meta llmproxy.BodyMetadata) (*url.URL, error) { @@ -24,7 +25,15 @@ func (r *Resolver) Resolve(meta llmproxy.BodyMetadata) (*url.URL, error) { model = "gemini-pro" } - endpoint := r.BaseURL.JoinPath("v1beta", "models", fmt.Sprintf("%s:generateContent", model)) + var endpoint *url.URL + if meta.Stream { + endpoint = r.BaseURL.JoinPath("v1beta", "models", fmt.Sprintf("%s:streamGenerateContent", model)) + q := endpoint.Query() + q.Set("alt", "sse") + endpoint.RawQuery = q.Encode() + } else { + endpoint = r.BaseURL.JoinPath("v1beta", "models", fmt.Sprintf("%s:generateContent", model)) + } return endpoint, nil } diff --git a/providers/googleai/streaming_extractor.go b/providers/googleai/streaming_extractor.go index de6201e..82c20b9 100644 --- a/providers/googleai/streaming_extractor.go +++ b/providers/googleai/streaming_extractor.go @@ -1,8 +1,10 @@ package googleai import ( + "bufio" "bytes" "context" + "encoding/json" "errors" "io" "net/http" @@ -25,7 +27,110 @@ func (e *StreamingExtractor) IsStreamingResponse(resp *http.Response) bool { } func (e *StreamingExtractor) ExtractStreamingWithController(resp *http.Response, w http.ResponseWriter, rc *http.ResponseController) (llmproxy.ResponseMetadata, error) { - return e.extractNonStreamingWithController(resp, w, rc) + if !e.IsStreamingResponse(resp) { + return e.extractNonStreamingWithController(resp, w, rc) + } + return e.extractStreamingWithController(resp, w, rc) +} + +func (e *StreamingExtractor) extractStreamingWithController(resp *http.Response, w http.ResponseWriter, rc *http.ResponseController) (llmproxy.ResponseMetadata, error) { + meta := llmproxy.ResponseMetadata{ + Choices: make([]llmproxy.Choice, 0), + Custom: make(map[string]any), + } + + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + w.Header().Set("X-Accel-Buffering", "no") + + scanner := bufio.NewScanner(resp.Body) + scanner.Buffer(make([]byte, 64*1024), 1024*1024) + + for scanner.Scan() { + line := scanner.Bytes() + + if len(line) == 0 { + continue + } + + if bytes.HasPrefix(line, []byte("data: ")) { + data := bytes.TrimPrefix(line, []byte("data: ")) + data = bytes.TrimSpace(data) + + // Forward the raw line immediately + if _, err := w.Write(line); err != nil { + return meta, err + } + if _, err := w.Write([]byte("\n\n")); err != nil { + return meta, err + } + _ = rc.Flush() + + if len(data) == 0 { + continue + } + + // Parse for metadata extraction + var chunk Response + if err := json.Unmarshal(data, &chunk); err != nil { + continue + } + + // Extract model + if chunk.ModelName != "" { + meta.Model = chunk.ModelName + } + + // Extract usage (typically in each chunk, final values in last chunk) + if chunk.UsageMetadata.TotalTokenCount > 0 { + meta.Usage = llmproxy.Usage{ + PromptTokens: chunk.UsageMetadata.PromptTokenCount, + CompletionTokens: chunk.UsageMetadata.CandidatesTokenCount, + TotalTokens: chunk.UsageMetadata.TotalTokenCount, + } + } + + // Extract text from candidates + for i, candidate := range chunk.Candidates { + if len(meta.Choices) <= i { + meta.Choices = append(meta.Choices, llmproxy.Choice{ + Index: i, + Message: &llmproxy.Message{ + Role: "assistant", + }, + }) + } + if candidate.Content != nil { + text := extractTextFromParts(candidate.Content.Parts) + meta.Choices[i].Message.Content += text + } + if candidate.FinishReason != "" { + meta.Choices[i].FinishReason = mapFinishReason(candidate.FinishReason) + } + } + + // Extract prompt 
feedback + if chunk.PromptFeedback != nil { + meta.Custom["prompt_feedback"] = chunk.PromptFeedback + } + } else { + // Forward non-data lines (e.g., comments, event types) + if _, err := w.Write(line); err != nil { + return meta, err + } + if _, err := w.Write([]byte("\n")); err != nil { + return meta, err + } + _ = rc.Flush() + } + } + + if err := scanner.Err(); err != nil { + return meta, err + } + + return meta, nil } func (e *StreamingExtractor) extractNonStreamingWithController(resp *http.Response, w http.ResponseWriter, rc *http.ResponseController) (llmproxy.ResponseMetadata, error) { diff --git a/providers/googleai/streaming_extractor_test.go b/providers/googleai/streaming_extractor_test.go new file mode 100644 index 0000000..cc1cad1 --- /dev/null +++ b/providers/googleai/streaming_extractor_test.go @@ -0,0 +1,287 @@ +package googleai + +import ( + "io" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" + "time" + + "github.com/agentuity/llmproxy" +) + +func TestStreamingExtractor_ExtractStreaming(t *testing.T) { + streamData := `data: {"candidates":[{"content":{"parts":[{"text":"Hello"}],"role":"model"}}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":2,"totalTokenCount":12}} + +data: {"candidates":[{"content":{"parts":[{"text":" World"}],"role":"model"},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":5,"totalTokenCount":15}} + +` + + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + Body: io.NopCloser(strings.NewReader(streamData)), + } + + recorder := httptest.NewRecorder() + rc := http.NewResponseController(recorder) + + extractor := NewStreamingExtractor() + + meta, err := extractor.ExtractStreamingWithController(resp, recorder, rc) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Verify usage extracted from last chunk + if meta.Usage.PromptTokens != 10 { + t.Errorf("expected prompt tokens 10, got %d", meta.Usage.PromptTokens) + } + if meta.Usage.CompletionTokens != 5 { + t.Errorf("expected completion tokens 5, got %d", meta.Usage.CompletionTokens) + } + if meta.Usage.TotalTokens != 15 { + t.Errorf("expected total tokens 15, got %d", meta.Usage.TotalTokens) + } + + // Verify choices extracted + if len(meta.Choices) != 1 { + t.Fatalf("expected 1 choice, got %d", len(meta.Choices)) + } + if meta.Choices[0].Message.Content != "Hello World" { + t.Errorf("expected content 'Hello World', got %q", meta.Choices[0].Message.Content) + } + if meta.Choices[0].FinishReason != "stop" { + t.Errorf("expected finish_reason 'stop', got %q", meta.Choices[0].FinishReason) + } + + // Verify data was forwarded to client + output := recorder.Body.String() + if !strings.Contains(output, "data: ") { + t.Error("expected SSE data format in output") + } + if !strings.Contains(output, `"Hello"`) { + t.Error("expected Hello text in output") + } + if !strings.Contains(output, `" World"`) { + t.Error("expected World text in output") + } +} + +func TestStreamingExtractor_StreamsIncrementally(t *testing.T) { + // Use a pipe to simulate slow upstream that sends data over time + pr, pw := io.Pipe() + + var mu sync.Mutex + var firstChunkTime time.Time + var streamDoneTime time.Time + + // Simulate upstream sending events with delay + go func() { + defer pw.Close() + pw.Write([]byte("data: 
{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"Hello\"}],\"role\":\"model\"}}],\"usageMetadata\":{\"promptTokenCount\":10,\"candidatesTokenCount\":2,\"totalTokenCount\":12}}\n\n")) + time.Sleep(100 * time.Millisecond) + pw.Write([]byte("data: {\"candidates\":[{\"content\":{\"parts\":[{\"text\":\" World\"}],\"role\":\"model\"},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":10,\"candidatesTokenCount\":5,\"totalTokenCount\":15}}\n\n")) + }() + + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + Body: io.NopCloser(pr), + } + + recorder := httptest.NewRecorder() + rc := http.NewResponseController(recorder) + + extractor := NewStreamingExtractor() + + // Monitor when data arrives at the recorder + go func() { + for { + mu.Lock() + if recorder.Body.Len() > 0 && firstChunkTime.IsZero() { + firstChunkTime = time.Now() + } + mu.Unlock() + time.Sleep(10 * time.Millisecond) + } + }() + + meta, err := extractor.ExtractStreamingWithController(resp, recorder, rc) + mu.Lock() + streamDoneTime = time.Now() + mu.Unlock() + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Verify first data arrived before stream completed + mu.Lock() + defer mu.Unlock() + if firstChunkTime.IsZero() { + t.Fatal("no data was received") + } + // The stream takes ~100ms. First data should arrive well before completion. + timeDiff := streamDoneTime.Sub(firstChunkTime) + if timeDiff < 50*time.Millisecond { + t.Errorf("data did not arrive incrementally: first chunk and completion were only %v apart", timeDiff) + } + + // Verify metadata was still extracted + if meta.Usage.TotalTokens != 15 { + t.Errorf("expected total tokens 15, got %d", meta.Usage.TotalTokens) + } +} + +func TestStreamingExtractor_IsStreamingResponse(t *testing.T) { + extractor := NewStreamingExtractor() + + tests := []struct { + contentType string + expected bool + }{ + {"text/event-stream", true}, + {"text/event-stream; charset=utf-8", true}, + {"application/json", false}, + {"text/plain", false}, + } + + for _, tt := range tests { + t.Run(tt.contentType, func(t *testing.T) { + resp := &http.Response{ + Header: http.Header{"Content-Type": []string{tt.contentType}}, + } + result := extractor.IsStreamingResponse(resp) + if result != tt.expected { + t.Errorf("expected %v, got %v", tt.expected, result) + } + }) + } +} + +func TestStreamingExtractor_NonStreamingFallback(t *testing.T) { + extractor := NewStreamingExtractor() + + respBody := `{"candidates":[{"content":{"role":"model","parts":[{"text":"Hello!"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":5,"totalTokenCount":15}}` + + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(respBody)), + } + + recorder := httptest.NewRecorder() + rc := http.NewResponseController(recorder) + + meta, err := extractor.ExtractStreamingWithController(resp, recorder, rc) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if meta.Usage.PromptTokens != 10 { + t.Errorf("expected prompt tokens 10, got %d", meta.Usage.PromptTokens) + } + if meta.Usage.CompletionTokens != 5 { + t.Errorf("expected completion tokens 5, got %d", meta.Usage.CompletionTokens) + } + if len(meta.Choices) != 1 { + t.Fatalf("expected 1 choice, got %d", len(meta.Choices)) + } + if meta.Choices[0].Message.Content != "Hello!" 
{ + t.Errorf("expected content 'Hello!', got %q", meta.Choices[0].Message.Content) + } + + // Verify body was forwarded + output := recorder.Body.String() + if output != respBody { + t.Errorf("expected body to be forwarded, got %q", output) + } +} + +func TestStreamingExtractor_ModelExtraction(t *testing.T) { + streamData := `data: {"candidates":[{"content":{"parts":[{"text":"Hi"}],"role":"model"}}],"model":"gemini-1.5-flash","usageMetadata":{"promptTokenCount":5,"candidatesTokenCount":1,"totalTokenCount":6}} + +` + + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + Body: io.NopCloser(strings.NewReader(streamData)), + } + + recorder := httptest.NewRecorder() + rc := http.NewResponseController(recorder) + + extractor := NewStreamingExtractor() + meta, err := extractor.ExtractStreamingWithController(resp, recorder, rc) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if meta.Model != "gemini-1.5-flash" { + t.Errorf("expected model 'gemini-1.5-flash', got %q", meta.Model) + } +} + +func TestStreamingExtractor_EmptyStream(t *testing.T) { + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + Body: io.NopCloser(strings.NewReader("")), + } + + recorder := httptest.NewRecorder() + rc := http.NewResponseController(recorder) + + extractor := NewStreamingExtractor() + _, err := extractor.ExtractStreamingWithController(resp, recorder, rc) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestResolver_StreamingEndpoint(t *testing.T) { + resolver, err := NewResolver("https://generativelanguage.googleapis.com") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + t.Run("resolves to streamGenerateContent when streaming", func(t *testing.T) { + meta := llmproxy.BodyMetadata{Model: "gemini-pro", Stream: true} + u, err := resolver.Resolve(meta) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + expected := "https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:streamGenerateContent?alt=sse" + if u.String() != expected { + t.Errorf("expected %s, got %s", expected, u.String()) + } + }) + + t.Run("resolves to generateContent when not streaming", func(t *testing.T) { + meta := llmproxy.BodyMetadata{Model: "gemini-pro", Stream: false} + u, err := resolver.Resolve(meta) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + expected := "https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent" + if u.String() != expected { + t.Errorf("expected %s, got %s", expected, u.String()) + } + }) + + t.Run("defaults to gemini-pro when streaming with empty model", func(t *testing.T) { + meta := llmproxy.BodyMetadata{Stream: true} + u, err := resolver.Resolve(meta) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + expected := "https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:streamGenerateContent?alt=sse" + if u.String() != expected { + t.Errorf("expected %s, got %s", expected, u.String()) + } + }) +} From 70f7809911d0ca0b5c713b9c1707b36161907035 Mon Sep 17 00:00:00 2001 From: Jeff Haynie Date: Tue, 14 Apr 2026 00:05:20 -0500 Subject: [PATCH 7/7] fix: resolve data race in streaming tests Use thread-safe response writer with channel-based signaling instead of polling shared buffer. This eliminates the race condition where one goroutine writes to the buffer while another reads its length. 
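
The signaling idiom, reduced to an illustrative sketch (names here are
hypothetical; the test writer below pairs the close with an atomic.Bool
guard, which has the same effect as sync.Once):

    // signalWriter closes firstWrite exactly once, on the first Write,
    // so a watcher goroutine can block on the channel instead of polling
    // the buffer under a lock.
    type signalWriter struct {
        mu         sync.Mutex
        buf        bytes.Buffer
        once       sync.Once
        firstWrite chan struct{}
    }

    func (w *signalWriter) Write(p []byte) (int, error) {
        w.mu.Lock()
        n, err := w.buf.Write(p)
        w.mu.Unlock()
        w.once.Do(func() { close(w.firstWrite) })
        return n, err
    }
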
- Added threadSafeResponseWriter with mutex-protected buffer - Use channel to signal first write time (eliminates race on time.Time) - Added Flush() method to satisfy http.Flusher interface - Both bedrock and googleai streaming tests now pass with -race --- providers/bedrock/streaming_extractor_test.go | 108 +++++++++++----- .../googleai/streaming_extractor_test.go | 121 +++++++++++++----- 2 files changed, 168 insertions(+), 61 deletions(-) diff --git a/providers/bedrock/streaming_extractor_test.go b/providers/bedrock/streaming_extractor_test.go index 74e7867..eac080a 100644 --- a/providers/bedrock/streaming_extractor_test.go +++ b/providers/bedrock/streaming_extractor_test.go @@ -6,15 +6,76 @@ import ( "encoding/json" "io" "net/http" - "net/http/httptest" "strings" "sync" + "sync/atomic" "testing" "time" "github.com/agentuity/llmproxy" ) +// threadSafeResponseWriter is an http.ResponseWriter that is safe for concurrent access. +// It signals via a channel when the first write occurs. +type threadSafeResponseWriter struct { + mu sync.Mutex + buf bytes.Buffer + header http.Header + wroteHead bool + firstWrite chan struct{} + closed atomic.Bool +} + +func newThreadSafeResponseWriter() *threadSafeResponseWriter { + return &threadSafeResponseWriter{ + header: make(http.Header), + firstWrite: make(chan struct{}), + } +} + +func (w *threadSafeResponseWriter) Header() http.Header { + w.mu.Lock() + defer w.mu.Unlock() + return w.header +} + +func (w *threadSafeResponseWriter) Write(data []byte) (int, error) { + w.mu.Lock() + wrote := w.wroteHead + if !wrote { + w.wroteHead = true + } + n, err := w.buf.Write(data) + w.mu.Unlock() + + if !wrote && !w.closed.Swap(true) { + close(w.firstWrite) + } + return n, err +} + +func (w *threadSafeResponseWriter) WriteHeader(code int) { + w.mu.Lock() + w.wroteHead = true + w.mu.Unlock() +} + +func (w *threadSafeResponseWriter) Flush() { + // No-op for test - the actual flush would happen in real ResponseWriter +} + +func (w *threadSafeResponseWriter) Bytes() []byte { + w.mu.Lock() + defer w.mu.Unlock() + return w.buf.Bytes() +} + +func (w *threadSafeResponseWriter) Len() int { + w.mu.Lock() + defer w.mu.Unlock() + return w.buf.Len() +} + // buildEventStreamMessage constructs a binary AWS event stream message // with the given event type and JSON payload. 
 func buildEventStreamMessage(eventType string, payload []byte) []byte {
@@ -90,7 +151,7 @@ func TestStreamingExtractor_EventStream(t *testing.T) {
 		Body:       io.NopCloser(&stream),
 	}
 
-	recorder := httptest.NewRecorder()
+	recorder := newThreadSafeResponseWriter()
 	rc := http.NewResponseController(recorder)
 
 	extractor := NewStreamingExtractor()
@@ -130,7 +191,7 @@
 	}
 
 	// Verify data was forwarded to client
-	if recorder.Body.Len() == 0 {
+	if recorder.Len() == 0 {
 		t.Error("no data written to client")
 	}
 }
@@ -139,10 +200,6 @@ func TestStreamingExtractor_EventStreamIncremental(t *testing.T) {
 	// Use a pipe to simulate slow upstream
 	pr, pw := io.Pipe()
 
-	var mu sync.Mutex
-	var firstByteTime time.Time
-	var streamDoneTime time.Time
-
 	// Send events with delay to verify incrementality
 	go func() {
 		defer pw.Close()
@@ -180,42 +237,35 @@
 		Body:       io.NopCloser(pr),
 	}
 
-	recorder := httptest.NewRecorder()
+	recorder := newThreadSafeResponseWriter()
 	rc := http.NewResponseController(recorder)
 
 	extractor := NewStreamingExtractor()
 
-	// Monitor when data arrives
+	// Use a channel to safely receive the first write time
+	firstByteTimeCh := make(chan time.Time, 1)
 	go func() {
-		for {
-			mu.Lock()
-			if recorder.Body.Len() > 0 && firstByteTime.IsZero() {
-				firstByteTime = time.Now()
-			}
-			mu.Unlock()
-			time.Sleep(10 * time.Millisecond)
-		}
+		<-recorder.firstWrite
+		firstByteTimeCh <- time.Now()
 	}()
 
 	meta, err := extractor.ExtractStreamingWithController(resp, recorder, rc)
-	mu.Lock()
-	streamDoneTime = time.Now()
-	mu.Unlock()
+	streamDoneTime := time.Now()
 
 	if err != nil {
 		t.Fatalf("unexpected error: %v", err)
 	}
 
 	// Verify first bytes arrived before stream completed
-	mu.Lock()
-	defer mu.Unlock()
-	if firstByteTime.IsZero() {
+	select {
+	case firstByteTime := <-firstByteTimeCh:
+		timeDiff := streamDoneTime.Sub(firstByteTime)
+		if timeDiff < 50*time.Millisecond {
+			t.Errorf("data did not arrive incrementally: first chunk and completion were only %v apart", timeDiff)
+		}
+	case <-time.After(2 * time.Second):
 		t.Fatal("no data was received")
 	}
-	timeDiff := streamDoneTime.Sub(firstByteTime)
-	if timeDiff < 50*time.Millisecond {
-		t.Errorf("data did not arrive incrementally: first chunk and completion were only %v apart", timeDiff)
-	}
 
 	// Verify metadata was still extracted
 	if meta.Usage.TotalTokens != 15 {
@@ -261,7 +311,7 @@
 	}
 
-	recorder := httptest.NewRecorder()
+	recorder := newThreadSafeResponseWriter()
 	rc := http.NewResponseController(recorder)
 
 	meta, err := extractor.ExtractStreamingWithController(resp, recorder, rc)
 
@@ -315,7 +365,7 @@
 		Body:       io.NopCloser(&stream),
 	}
 
-	recorder := httptest.NewRecorder()
+	recorder := newThreadSafeResponseWriter()
 	rc := http.NewResponseController(recorder)
 
 	extractor := NewStreamingExtractor()
diff --git a/providers/googleai/streaming_extractor_test.go b/providers/googleai/streaming_extractor_test.go
index cc1cad1..7feb5dd 100644
--- a/providers/googleai/streaming_extractor_test.go
+++ b/providers/googleai/streaming_extractor_test.go
@@ -1,17 +1,85 @@
 package googleai
 
 import (
+	"bytes"
 	"io"
 	"net/http"
-	"net/http/httptest"
 	"strings"
 	"sync"
+	"sync/atomic"
 	"testing"
 	"time"
 
 	"github.com/agentuity/llmproxy"
 )
 
+// threadSafeResponseWriter is an http.ResponseWriter that is safe for concurrent access.
+// It signals via a channel when the first write occurs. +type threadSafeResponseWriter struct { + mu sync.Mutex + buf bytes.Buffer + header http.Header + wroteHead bool + firstWrite chan struct{} + closed atomic.Bool +} + +func newThreadSafeResponseWriter() *threadSafeResponseWriter { + return &threadSafeResponseWriter{ + header: make(http.Header), + firstWrite: make(chan struct{}), + } +} + +func (w *threadSafeResponseWriter) Header() http.Header { + w.mu.Lock() + defer w.mu.Unlock() + return w.header +} + +func (w *threadSafeResponseWriter) Write(data []byte) (int, error) { + w.mu.Lock() + wrote := w.wroteHead + if !wrote { + w.wroteHead = true + } + n, err := w.buf.Write(data) + w.mu.Unlock() + + if !wrote && !w.closed.Swap(true) { + close(w.firstWrite) + } + return n, err +} + +func (w *threadSafeResponseWriter) WriteHeader(code int) { + w.mu.Lock() + w.wroteHead = true + w.mu.Unlock() +} + +func (w *threadSafeResponseWriter) Flush() { + // No-op for test - the actual flush would happen in real ResponseWriter +} + +func (w *threadSafeResponseWriter) Bytes() []byte { + w.mu.Lock() + defer w.mu.Unlock() + return w.buf.Bytes() +} + +func (w *threadSafeResponseWriter) Len() int { + w.mu.Lock() + defer w.mu.Unlock() + return w.buf.Len() +} + +func (w *threadSafeResponseWriter) String() string { + w.mu.Lock() + defer w.mu.Unlock() + return w.buf.String() +} + func TestStreamingExtractor_ExtractStreaming(t *testing.T) { streamData := `data: {"candidates":[{"content":{"parts":[{"text":"Hello"}],"role":"model"}}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":2,"totalTokenCount":12}} @@ -25,7 +93,7 @@ data: {"candidates":[{"content":{"parts":[{"text":" World"}],"role":"model"},"fi Body: io.NopCloser(strings.NewReader(streamData)), } - recorder := httptest.NewRecorder() + recorder := newThreadSafeResponseWriter() rc := http.NewResponseController(recorder) extractor := NewStreamingExtractor() @@ -58,7 +126,7 @@ data: {"candidates":[{"content":{"parts":[{"text":" World"}],"role":"model"},"fi } // Verify data was forwarded to client - output := recorder.Body.String() + output := recorder.String() if !strings.Contains(output, "data: ") { t.Error("expected SSE data format in output") } @@ -74,10 +142,6 @@ func TestStreamingExtractor_StreamsIncrementally(t *testing.T) { // Use a pipe to simulate slow upstream that sends data over time pr, pw := io.Pipe() - var mu sync.Mutex - var firstChunkTime time.Time - var streamDoneTime time.Time - // Simulate upstream sending events with delay go func() { defer pw.Close() @@ -92,43 +156,36 @@ func TestStreamingExtractor_StreamsIncrementally(t *testing.T) { Body: io.NopCloser(pr), } - recorder := httptest.NewRecorder() + recorder := newThreadSafeResponseWriter() rc := http.NewResponseController(recorder) extractor := NewStreamingExtractor() - // Monitor when data arrives at the recorder + // Use a channel to safely receive the first write time + firstChunkTimeCh := make(chan time.Time, 1) go func() { - for { - mu.Lock() - if recorder.Body.Len() > 0 && firstChunkTime.IsZero() { - firstChunkTime = time.Now() - } - mu.Unlock() - time.Sleep(10 * time.Millisecond) - } + <-recorder.firstWrite + firstChunkTimeCh <- time.Now() }() meta, err := extractor.ExtractStreamingWithController(resp, recorder, rc) - mu.Lock() - streamDoneTime = time.Now() - mu.Unlock() + streamDoneTime := time.Now() if err != nil { t.Fatalf("unexpected error: %v", err) } // Verify first data arrived before stream completed - mu.Lock() - defer mu.Unlock() - if 
firstChunkTime.IsZero() {
+	select {
+	case firstChunkTime := <-firstChunkTimeCh:
+		// The stream takes ~100ms. First data should arrive well before completion.
+		timeDiff := streamDoneTime.Sub(firstChunkTime)
+		if timeDiff < 50*time.Millisecond {
+			t.Errorf("data did not arrive incrementally: first chunk and completion were only %v apart", timeDiff)
+		}
+	case <-time.After(2 * time.Second):
 		t.Fatal("no data was received")
 	}
-	timeDiff := streamDoneTime.Sub(firstChunkTime)
-	if timeDiff < 50*time.Millisecond {
-		t.Errorf("data did not arrive incrementally: first chunk and completion were only %v apart", timeDiff)
-	}
 
 	// Verify metadata was still extracted
 	if meta.Usage.TotalTokens != 15 {
@@ -173,7 +230,7 @@ func TestStreamingExtractor_NonStreamingFallback(t *testing.T) {
 		Body:       io.NopCloser(strings.NewReader(respBody)),
 	}
 
-	recorder := httptest.NewRecorder()
+	recorder := newThreadSafeResponseWriter()
 	rc := http.NewResponseController(recorder)
 
 	meta, err := extractor.ExtractStreamingWithController(resp, recorder, rc)
@@ -195,7 +252,7 @@
 	}
 
 	// Verify body was forwarded
-	output := recorder.Body.String()
+	output := recorder.String()
 	if output != respBody {
 		t.Errorf("expected body to be forwarded, got %q", output)
 	}
@@ -212,7 +269,7 @@ func TestStreamingExtractor_ModelExtraction(t *testing.T) {
 		Body:       io.NopCloser(strings.NewReader(streamData)),
 	}
 
-	recorder := httptest.NewRecorder()
+	recorder := newThreadSafeResponseWriter()
 	rc := http.NewResponseController(recorder)
 
 	extractor := NewStreamingExtractor()
@@ -233,7 +290,7 @@ func TestStreamingExtractor_EmptyStream(t *testing.T) {
 		Body:       io.NopCloser(strings.NewReader("")),
 	}
 
-	recorder := httptest.NewRecorder()
+	recorder := newThreadSafeResponseWriter()
 	rc := http.NewResponseController(recorder)
 
 	extractor := NewStreamingExtractor()