Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 66 additions & 2 deletions metadata.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,76 @@
// resp, meta, _ := proxy.Forward(ctx, req)
package llmproxy

import "encoding/json"

// Message represents a single message in a chat completion request.
type Message struct {
// Role is the role of the message author (e.g., "user", "assistant", "system").
Role string `json:"role"`
// Content is the text content of the message.
Content string `json:"content"`
// Content is the content of the message (can be string or array for multimodal).
Content any `json:"content"`
// Custom holds provider-specific message fields that don't map to standard fields.
Custom map[string]any `json:"-"`
}

// UnmarshalJSON implements custom JSON unmarshaling to capture unknown fields.
func (m *Message) UnmarshalJSON(data []byte) error {
type Alias Message
aux := &struct {
*Alias
}{
Alias: (*Alias)(m),
}
if err := json.Unmarshal(data, &aux); err != nil {
return err
}

var raw map[string]any
if err := json.Unmarshal(data, &raw); err != nil {
return err
}

m.Custom = make(map[string]any)
for k, v := range raw {
if k != "role" && k != "content" {
m.Custom[k] = v
}
}

return nil
}

// MarshalJSON implements custom JSON marshaling to include Custom fields.
func (m Message) MarshalJSON() ([]byte, error) {
type Alias Message
aux := &struct {
Alias
}{
Alias: (Alias)(m),
}

data, err := json.Marshal(aux)
if err != nil {
return nil, err
}

if len(m.Custom) == 0 {
return data, nil
}

var result map[string]any
if err := json.Unmarshal(data, &result); err != nil {
return nil, err
}

for k, v := range m.Custom {
if k == "role" || k == "content" {
continue
}
result[k] = v
}

return json.Marshal(result)
}

// BodyMetadata contains extracted metadata from a parsed request body.
Expand Down
20 changes: 18 additions & 2 deletions providers/anthropic/parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ func TestParser(t *testing.T) {
t.Errorf("expected 1 message, got %d", len(meta.Messages))
}
if meta.Messages[0].Content != "hello" {
t.Errorf("expected content 'hello', got %s", meta.Messages[0].Content)
t.Errorf("expected content 'hello', got %v", meta.Messages[0].Content)
}
})

Expand Down Expand Up @@ -148,7 +148,7 @@ func TestExtractor(t *testing.T) {
t.Errorf("expected 1 choice, got %d", len(meta.Choices))
}
if meta.Choices[0].Message.Content != "Hello!" {
t.Errorf("expected content 'Hello!', got %s", meta.Choices[0].Message.Content)
t.Errorf("expected content 'Hello!', got %v", meta.Choices[0].Message.Content)
}
if string(raw) != respBody {
t.Error("raw body mismatch")
Expand Down Expand Up @@ -201,3 +201,19 @@ func TestExtractor(t *testing.T) {
}
})
}

func TestParser_ContentArrayWithMultipleTypes(t *testing.T) {
body := `{"model":"claude-3-opus-20240229","max_tokens":1024,"messages":[{"role":"user","content":[{"type":"text","text":"hello"},{"type":"image","source":{"type":"base64","media_type":"image/png","data":"abc123"}}]}]}`
parser := &Parser{}

meta, _, err := parser.Parse(io.NopCloser(bytes.NewReader([]byte(body))))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(meta.Messages) != 1 {
t.Errorf("expected 1 message, got %d", len(meta.Messages))
}
if meta.Messages[0].Content != "hello" {
t.Errorf("expected content 'hello', got %v", meta.Messages[0].Content)
}
}
17 changes: 15 additions & 2 deletions providers/bedrock/parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func TestParser(t *testing.T) {
t.Errorf("expected role user, got %s", meta.Messages[0].Role)
}
if meta.Messages[0].Content != "hello" {
t.Errorf("expected content 'hello', got %s", meta.Messages[0].Content)
t.Errorf("expected content 'hello', got %v", meta.Messages[0].Content)
}
if string(raw) != body {
t.Error("raw body mismatch")
Expand Down Expand Up @@ -181,7 +181,7 @@ func TestExtractor(t *testing.T) {
t.Errorf("expected 1 choice, got %d", len(meta.Choices))
}
if meta.Choices[0].Message.Content != "Hello!" {
t.Errorf("expected content 'Hello!', got %s", meta.Choices[0].Message.Content)
t.Errorf("expected content 'Hello!', got %v", meta.Choices[0].Message.Content)
}
if meta.Choices[0].FinishReason != "end_turn" {
t.Errorf("expected finish_reason end_turn, got %s", meta.Choices[0].FinishReason)
Expand All @@ -191,3 +191,16 @@ func TestExtractor(t *testing.T) {
}
})
}

func TestParser_MessageWithMultipleContentBlocks(t *testing.T) {
body := `{"modelId":"anthropic.claude-3-sonnet-20240229-v1:0","messages":[{"role":"user","content":[{"text":"hello"},{"text":"world"}]}]}`
parser := &Parser{}

meta, _, err := parser.Parse(io.NopCloser(bytes.NewReader([]byte(body))))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if meta.Messages[0].Content != "helloworld" {
t.Errorf("expected combined content 'helloworld', got %v", meta.Messages[0].Content)
}
}
6 changes: 5 additions & 1 deletion providers/bedrock/streaming_extractor.go
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,11 @@ func processBedrockStreamEvent(event *eventStreamEvent, meta *llmproxy.ResponseM
})
}
if meta.Choices[0].Message != nil {
meta.Choices[0].Message.Content += delta.Delta.Text
if str, ok := meta.Choices[0].Message.Content.(string); ok {
meta.Choices[0].Message.Content = str + delta.Delta.Text
} else {
meta.Choices[0].Message.Content = delta.Delta.Text
}
}
}

Expand Down
30 changes: 28 additions & 2 deletions providers/googleai/parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func TestParser(t *testing.T) {
t.Errorf("expected role user, got %s", meta.Messages[0].Role)
}
if meta.Messages[0].Content != "hello" {
t.Errorf("expected content 'hello', got %s", meta.Messages[0].Content)
t.Errorf("expected content 'hello', got %v", meta.Messages[0].Content)
}
if string(raw) != body {
t.Error("raw body mismatch")
Expand Down Expand Up @@ -155,7 +155,7 @@ func TestExtractor(t *testing.T) {
t.Errorf("expected 1 choice, got %d", len(meta.Choices))
}
if meta.Choices[0].Message.Content != "Hello!" {
t.Errorf("expected content 'Hello!', got %s", meta.Choices[0].Message.Content)
t.Errorf("expected content 'Hello!', got %v", meta.Choices[0].Message.Content)
}
if meta.Choices[0].FinishReason != "stop" {
t.Errorf("expected finish_reason 'stop', got %s", meta.Choices[0].FinishReason)
Expand Down Expand Up @@ -185,3 +185,29 @@ func TestExtractor(t *testing.T) {
}
})
}

func TestParser_MessageWithMultipleParts(t *testing.T) {
body := `{"contents":[{"role":"user","parts":[{"text":"hello"},{"text":"world"}]}]}`
parser := &Parser{}

meta, _, err := parser.Parse(io.NopCloser(bytes.NewReader([]byte(body))))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if meta.Messages[0].Content != "helloworld" {
t.Errorf("expected combined content 'helloworld', got %v", meta.Messages[0].Content)
}
}

func TestParser_MessageWithInlineData(t *testing.T) {
body := `{"contents":[{"role":"user","parts":[{"text":"describe this"},{"inlineData":{"mimeType":"image/png","data":"abc123"}}]}]}`
parser := &Parser{}

meta, _, err := parser.Parse(io.NopCloser(bytes.NewReader([]byte(body))))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if meta.Messages[0].Content != "describe this" {
t.Errorf("expected text content, got %v", meta.Messages[0].Content)
}
}
6 changes: 5 additions & 1 deletion providers/googleai/streaming_extractor.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,11 @@ func (e *StreamingExtractor) extractStreamingWithController(resp *http.Response,
}
if candidate.Content != nil {
text := extractTextFromParts(candidate.Content.Parts)
meta.Choices[i].Message.Content += text
if str, ok := meta.Choices[i].Message.Content.(string); ok {
meta.Choices[i].Message.Content = str + text
} else {
meta.Choices[i].Message.Content = text
}
}
if candidate.FinishReason != "" {
meta.Choices[i].FinishReason = mapFinishReason(candidate.FinishReason)
Expand Down
Loading
Loading