Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/agent/src/agent.lua
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,7 @@ function agent:step(prompt_builder: any, runtime_options: any): (table?, string?
arguments = tool_call.arguments,
registry_id = tool_call.registry_id,
context = tool_call.context,
provider_metadata = tool_call.provider_metadata,
agent_id = tool_info.agent_id
}

Expand Down
1 change: 1 addition & 0 deletions src/agent/src/tools/caller.lua
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ function tool_caller:validate(tool_calls: {ToolCall}?): (any, string?)
registry_id = registry_id,
meta = meta,
context = tool_call.context, -- Preserve tool context
provider_metadata = tool_call.provider_metadata,
valid = true
}

Expand Down
2 changes: 2 additions & 0 deletions src/llm/src/google/_index.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ entries:
modules:
- json
- http_client
imports:
output: wippy.llm:output

# wippy.llm.google:client_test
- name: client_test
Expand Down
224 changes: 224 additions & 0 deletions src/llm/src/google/client.lua
Original file line number Diff line number Diff line change
@@ -1,5 +1,27 @@
local json = require("json")
local http_client = require("http_client")
local output = require("output")

-- Optional event hooks fired while a Gemini SSE stream is consumed.
-- Any hook may be omitted; process_stream substitutes a no-op.
type StreamCallbacks = {
    on_content: ((text: string) -> ())?,
    on_tool_call: ((part: any) -> ())?,
    on_thinking: ((text: string) -> ())?,
    on_error: ((error_info: any) -> ())?,
    on_done: ((result: StreamResult) -> ())?,
}

-- Input to process_stream: a readable stream handle plus optional seed
-- metadata that is enriched (model_version, response_id) during parsing.
type StreamInput = {
    stream: any,
    metadata: table?,
}

-- Aggregate produced after the stream is fully drained.
type StreamResult = {
    content: string,
    tool_calls: {any},
    finish_reason: string?,
    usage: any?,
    metadata: table,
}

local client = {
_http_client = http_client
Expand Down Expand Up @@ -41,9 +63,203 @@ local function parse_error_response(http_response)
return error_info
end

--- Consume a Google Gemini SSE stream and aggregate it into one result.
--- Reads the stream chunk by chunk, extracts `data:` SSE payload lines,
--- JSON-decodes each one, and dispatches to the optional callbacks.
--- Uses goto labels, so requires Lua 5.2+/LuaJIT (or Luau).
--- @param stream_response holds the readable stream and optional seed metadata
--- @param callbacks optional hooks; missing ones default to no-ops
--- @return aggregated non-thinking text content (nil on failure)
--- @return error message when the transport or an in-band payload failed
--- @return full StreamResult aggregate (or `{ error = ... }` on in-band error)
function client.process_stream(stream_response: StreamInput, callbacks: StreamCallbacks?): (string?, string?, StreamResult?)
    if not stream_response or not stream_response.stream then
        return nil, "Invalid stream response"
    end

    -- Default every callback to a no-op so call sites stay unconditional.
    callbacks = callbacks or {}
    local on_content = callbacks.on_content or function() end
    local on_tool_call = callbacks.on_tool_call or function() end
    local on_thinking = callbacks.on_thinking or function() end
    local on_error = callbacks.on_error or function() end
    local on_done = callbacks.on_done or function() end

    local full_content = ""
    local tool_calls = {}
    local finish_reason = nil
    local usage = nil
    -- Caller-provided metadata is the seed; enriched with modelVersion /
    -- responseId as payloads arrive.
    local metadata = stream_response.metadata or {}

    while true do
        local chunk, err = stream_response.stream:read()

        -- Transport-level failure: report and abort. Note this path returns
        -- only two values (no StreamResult) and does not invoke on_done.
        if err then
            on_error({ message = err })
            return nil, err
        end

        -- nil chunk signals end of stream.
        if not chunk then
            break
        end

        if chunk == "" then
            goto continue
        end

        -- Pull out SSE "data:" payloads. NOTE(review): the pattern requires a
        -- trailing "\n", so a data line split across chunk boundaries (or a
        -- final line without a newline) is silently dropped — confirm the
        -- stream reader always yields whole lines per read().
        for data_line in chunk:gmatch('data:%s*(.-)%s*\n') do
            if data_line == "" then
                goto continue_line
            end

            -- Payloads that fail to decode are skipped, not treated as errors.
            local parsed, parse_err = json.decode(data_line)
            if parse_err then
                goto continue_line
            end

            -- In-band API error payload: surface it and stop immediately.
            -- Third return value here is { error = ... }, not a StreamResult.
            if parsed.error then
                local error_info = {
                    message = parsed.error.message,
                    code = parsed.error.code,
                    status = parsed.error.status
                }
                on_error(error_info)
                return nil, error_info.message, { error = error_info }
            end

            -- Response-level identifiers; last payload seen wins.
            if parsed.modelVersion then
                metadata.model_version = parsed.modelVersion
            end
            if parsed.responseId then
                metadata.response_id = parsed.responseId
            end

            -- Only the first candidate is inspected.
            if parsed.candidates and parsed.candidates[1] then
                local candidate = parsed.candidates[1]

                if candidate.content and candidate.content.parts then
                    for _, part in ipairs(candidate.content.parts) do
                        if part.functionCall then
                            -- Whole part (not just functionCall) is kept so
                            -- sibling fields like thoughtSignature survive.
                            table.insert(tool_calls, part)
                            on_tool_call(part)
                        elseif part.text then
                            -- part.thought marks reasoning text: streamed to
                            -- on_thinking and excluded from full_content.
                            if part.thought == true then
                                on_thinking(part.text)
                            else
                                full_content = full_content .. part.text
                                on_content(part.text)
                            end
                        end
                    end
                end

                if candidate.finishReason then
                    finish_reason = candidate.finishReason
                end
            end

            -- Usage arrives cumulatively; keep the latest snapshot.
            if parsed.usageMetadata then
                usage = parsed.usageMetadata
            end

            ::continue_line::
        end

        ::continue::
    end

    local result: StreamResult = {
        content = full_content,
        tool_calls = tool_calls,
        finish_reason = finish_reason,
        usage = usage,
        metadata = metadata
    }

    on_done(result)
    return full_content, nil, result
end

--- Process a streaming response and send chunks via output.streamer.
--- Returns an aggregated Google-like response compatible with map_success_response().
--- Process a streaming response and send chunks via output.streamer.
--- Returns an aggregated Google-like response compatible with map_success_response().
--- @param response HTTP response carrying a readable .stream (and optional .status_code)
--- @param http_options must provide stream_reply_to / stream_topic; stream_buffer_size optional
--- @return aggregated response table, or nil plus an error table on stream failure
local function handle_stream_response(response, http_options)
    local streamer = output.streamer(
        http_options.stream_reply_to,
        http_options.stream_topic,
        http_options.stream_buffer_size or 10  -- default buffer size in chunks
    )

    -- Aggregation state mutated by the callbacks below; on_done fires after
    -- the stream is drained, so finish_reason/usage are set before we return.
    local full_content = ""
    local tool_call_parts = {}
    local finish_reason = nil
    local usage_metadata = nil
    local response_metadata = {}

    local _, stream_err = client.process_stream(
        { stream = response.stream, metadata = {} },
        {
            -- Accumulate locally and forward to the streamer's buffer.
            on_content = function(chunk: string)
                full_content = full_content .. chunk
                streamer:buffer_content(chunk)
            end,

            on_tool_call = function(tool_part: any)
                table.insert(tool_call_parts, tool_part)
                if tool_part.functionCall then
                    -- NOTE(review): the function name is reused as the third
                    -- argument (presumably the tool-call id) — confirm the
                    -- streamer contract; ids elsewhere are name + timestamp.
                    streamer:send_tool_call(
                        tool_part.functionCall.name,
                        tool_part.functionCall.args or {},
                        tool_part.functionCall.name
                    )
                end
            end,

            on_thinking = function(text: string)
                streamer:send_thinking(text)
            end,

            on_error = function(error_info: any)
                streamer:send_error("server_error", error_info.message)
            end,

            -- Flush buffered content, then capture the final aggregates.
            on_done = function(result: StreamResult)
                streamer:flush()
                finish_reason = result.finish_reason
                usage_metadata = result.usage
                response_metadata = result.metadata
            end
        }
    )

    if stream_err then
        return nil, {
            status_code = 500,
            message = "Stream processing failed: " .. tostring(stream_err)
        }
    end

    -- Reconstruct Google-like response: one text part (if any), then the
    -- original tool-call parts so fields like thoughtSignature are preserved.
    local parts = {}
    if full_content ~= "" then
        table.insert(parts, { text = full_content })
    end
    for _, tc_part in ipairs(tool_call_parts) do
        table.insert(parts, tc_part)
    end

    return {
        candidates = {
            {
                content = { parts = parts, role = "model" },
                finishReason = finish_reason
            }
        },
        usageMetadata = usage_metadata,
        modelVersion = response_metadata.model_version,
        responseId = response_metadata.response_id,
        metadata = response_metadata,
        status_code = response.status_code or 200
    }
end

function client.request(method, url, http_options)
http_options.headers["Accept"] = "application/json"

if http_options.stream then
url = url .. "?alt=sse"
http_options.headers["Accept"] = "text/event-stream"
end

local response = nil
local err = nil
if method == "GET" then
Expand All @@ -61,10 +277,18 @@ function client.request(method, url, http_options)
end

if response.status_code < 200 or response.status_code >= 300 then
if http_options.stream and response.stream and not response.body then
response.body = response.stream:read()
end
local parsed_error = parse_error_response(response)
return nil, parsed_error
end

-- Streaming: process stream, send chunks via streamer, return aggregated response
if http_options.stream and response.stream then
return handle_stream_response(response, http_options)
end

local parsed, parse_err = json.decode(response.body)
if parse_err then
local parse_error = {
Expand Down
15 changes: 13 additions & 2 deletions src/llm/src/google/generate.lua
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,22 @@ function generate.handler(contract_args)
})
end

local endpoint_path = "generateContent"
local request_options = { timeout = contract_args.timeout }

if contract_args.stream and contract_args.stream.reply_to then
endpoint_path = "streamGenerateContent"
request_options.stream = true
request_options.stream_reply_to = contract_args.stream.reply_to
request_options.stream_topic = contract_args.stream.topic
request_options.stream_buffer_size = contract_args.stream.buffer_size
end

local response = client_instance:request({
endpoint_path = "generateContent",
endpoint_path = endpoint_path,
model = contract_args.model,
payload = payload,
options = { timeout = contract_args.timeout }
options = request_options
})

if response.status_code < 200 or response.status_code >= 300 then
Expand Down
6 changes: 6 additions & 0 deletions src/llm/src/google/generative_ai/client.lua
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@ function generative_ai_client.request(contract_args)
if contract_args.options.method == "POST" then
options.body = json.encode(contract_args.payload or {})
end
if contract_args.options.stream then
options.stream = true
options.stream_reply_to = contract_args.options.stream_reply_to
options.stream_topic = contract_args.options.stream_topic
options.stream_buffer_size = contract_args.options.stream_buffer_size
end

local base_url = contract_args.options.base_url or generative_ai_client._config.get_generative_ai_base_url()
if contract_args.model and contract_args.model ~= "" then
Expand Down
33 changes: 22 additions & 11 deletions src/llm/src/google/mapper.lua
Original file line number Diff line number Diff line change
Expand Up @@ -185,12 +185,16 @@ function mapper.map_messages(contract_messages, options)
and json.decode(msg.function_call.arguments)
or msg.function_call.arguments

table.insert(processed_messages, { role = "model", parts = {
local part = {
functionCall = {
name = msg.function_call.name,
args = next(arguments or {}) ~= nil and arguments or nil
}
} })
}
if msg.function_call.provider_metadata and msg.function_call.provider_metadata.thought_signature then
part.thoughtSignature = msg.function_call.provider_metadata.thought_signature
end
table.insert(processed_messages, { role = "model", parts = part })
i = i + 1
else
-- Skip unknown message types
Expand Down Expand Up @@ -255,18 +259,25 @@ function mapper.map_options(contract_options)
}
end

function mapper.map_tool_calls(function_calls)
if not function_calls then
function mapper.map_tool_calls(content_parts)
if not content_parts or #content_parts == 0 then
return {}
end

local contract_tool_calls = {}
for i, function_call in ipairs(function_calls) do
contract_tool_calls[i] = {
id = (function_call.name or "func") .. "_" .. time.now():unix(),
name = function_call.name,
arguments = function_call.args or {},
}
for i, content_part in ipairs(content_parts) do
if content_part.functionCall then
contract_tool_calls[i] = {
id = (content_part.functionCall.name or "func") .. "_" .. time.now():unix(),
name = content_part.functionCall.name,
arguments = content_part.functionCall.args or {},
}
if content_part.thoughtSignature then
contract_tool_calls[i].provider_metadata = {
thought_signature = content_part.thoughtSignature
}
end
end
end

return contract_tool_calls
Expand Down Expand Up @@ -327,7 +338,7 @@ function mapper.map_success_response(google_response)
if content_part.text then
content = content .. content_part.text
elseif content_part.functionCall then
table.insert(tool_calls, content_part.functionCall)
table.insert(tool_calls, content_part)
end
end
end
Expand Down
9 changes: 8 additions & 1 deletion src/llm/src/google/vertex/client.lua
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ local vertex_client = {
}

local PROJECT_REQUIRED_ENDPOINTS = {
"generateContent"
"generateContent",
"streamGenerateContent"
}

local function build_url(base_url, contract_args)
Expand Down Expand Up @@ -60,6 +61,12 @@ function vertex_client.request(contract_args)
if contract_args.options.method == "POST" then
options.body = json.encode(contract_args.payload or {})
end
if contract_args.options.stream then
options.stream = true
options.stream_reply_to = contract_args.options.stream_reply_to
options.stream_topic = contract_args.options.stream_topic
options.stream_buffer_size = contract_args.options.stream_buffer_size
end

local base_url = contract_args.options.base_url
if not base_url then
Expand Down
Loading
Loading