diff --git a/src/agent/src/agent.lua b/src/agent/src/agent.lua index b8990cf..c7621c5 100644 --- a/src/agent/src/agent.lua +++ b/src/agent/src/agent.lua @@ -587,6 +587,7 @@ function agent:step(prompt_builder: any, runtime_options: any): (table?, string? arguments = tool_call.arguments, registry_id = tool_call.registry_id, context = tool_call.context, + provider_metadata = tool_call.provider_metadata, agent_id = tool_info.agent_id } diff --git a/src/agent/src/tools/caller.lua b/src/agent/src/tools/caller.lua index b3e7cc3..0e21360 100644 --- a/src/agent/src/tools/caller.lua +++ b/src/agent/src/tools/caller.lua @@ -126,6 +126,7 @@ function tool_caller:validate(tool_calls: {ToolCall}?): (any, string?) registry_id = registry_id, meta = meta, context = tool_call.context, -- Preserve tool context + provider_metadata = tool_call.provider_metadata, valid = true } diff --git a/src/llm/src/google/_index.yaml b/src/llm/src/google/_index.yaml index 5d58922..180b1e0 100644 --- a/src/llm/src/google/_index.yaml +++ b/src/llm/src/google/_index.yaml @@ -113,6 +113,8 @@ entries: modules: - json - http_client + imports: + output: wippy.llm:output # wippy.llm.google:client_test - name: client_test diff --git a/src/llm/src/google/client.lua b/src/llm/src/google/client.lua index 1307bca..c3694bc 100644 --- a/src/llm/src/google/client.lua +++ b/src/llm/src/google/client.lua @@ -1,5 +1,27 @@ local json = require("json") local http_client = require("http_client") +local output = require("output") + +type StreamCallbacks = { + on_content: ((text: string) -> ())?, + on_tool_call: ((part: any) -> ())?, + on_thinking: ((text: string) -> ())?, + on_error: ((error_info: any) -> ())?, + on_done: ((result: StreamResult) -> ())?, +} + +type StreamInput = { + stream: any, + metadata: table?, +} + +type StreamResult = { + content: string, + tool_calls: {any}, + finish_reason: string?, + usage: any?, + metadata: table, +} local client = { _http_client = http_client @@ -41,9 +63,203 @@ local function 
--- Read a streamed (SSE) generation response to completion and aggregate it.
---
--- Text is pulled from `stream_response.stream:read()` until it returns a falsy
--- chunk. Incoming text is buffered and only complete lines are decoded, so a
--- `data:` line that is split across two reads is reassembled instead of being
--- dropped. A trailing line without a final newline is processed at end of stream.
---
--- Lines that do not start with `data:`, empty data lines, and payloads that
--- fail to decode as a JSON object are skipped (best-effort, as before).
---
--- @param stream_response StreamInput `{ stream = <reader>, metadata = table? }`;
---        when `metadata` is supplied it is updated in place with
---        `model_version` / `response_id`.
--- @param callbacks StreamCallbacks? optional `on_content`, `on_tool_call`,
---        `on_thinking`, `on_error`, `on_done` hooks.
--- @return string? concatenated non-thought text, or nil on failure
--- @return string? error message on failure (never nil when the first value is nil)
--- @return StreamResult? aggregated result on success; `{ error = error_info }`
---         when the stream carried an API error payload
function client.process_stream(stream_response: StreamInput, callbacks: StreamCallbacks?): (string?, string?, StreamResult?)
    if not stream_response or not stream_response.stream then
        return nil, "Invalid stream response"
    end

    callbacks = callbacks or {}
    local noop = function() end
    local on_content = callbacks.on_content or noop
    local on_tool_call = callbacks.on_tool_call or noop
    local on_thinking = callbacks.on_thinking or noop
    local on_error = callbacks.on_error or noop
    local on_done = callbacks.on_done or noop

    -- Text fragments are collected and joined once at the end.
    local content_buf = {}
    local tool_calls = {}
    local finish_reason = nil
    local usage = nil
    local metadata = stream_response.metadata or {}

    -- Applies one decoded payload to the aggregate state.
    -- Returns an error_info table when the payload is an API error, else nil.
    local function handle_payload(parsed)
        if parsed.error then
            local api_error = parsed.error
            if type(api_error) ~= "table" then
                api_error = { message = tostring(api_error) }
            end
            return {
                -- A nil message would be returned as a nil error string and
                -- read as success by callers, so always provide one.
                message = api_error.message or "Unknown stream error",
                code = api_error.code,
                status = api_error.status
            }
        end

        if parsed.modelVersion then
            metadata.model_version = parsed.modelVersion
        end
        if parsed.responseId then
            metadata.response_id = parsed.responseId
        end

        local candidate = parsed.candidates and parsed.candidates[1]
        if candidate then
            if candidate.content and candidate.content.parts then
                for _, part in ipairs(candidate.content.parts) do
                    if part.functionCall then
                        -- Keep the whole part, not just functionCall, so
                        -- sibling fields on the part travel with it.
                        table.insert(tool_calls, part)
                        on_tool_call(part)
                    elseif part.text then
                        if part.thought == true then
                            on_thinking(part.text)
                        else
                            content_buf[#content_buf + 1] = part.text
                            on_content(part.text)
                        end
                    end
                end
            end

            if candidate.finishReason then
                finish_reason = candidate.finishReason
            end
        end

        if parsed.usageMetadata then
            usage = parsed.usageMetadata
        end

        return nil
    end

    -- Handles one complete line of the event stream (without its newline).
    -- Returns an error_info table for an API error payload, else nil.
    local function handle_line(line)
        -- Anchored match: only real `data:` fields, with surrounding
        -- whitespace (including a trailing "\r") trimmed.
        local data = line:match("^data:%s*(.-)%s*$")
        if not data or data == "" then
            return nil
        end

        local parsed, parse_err = json.decode(data)
        if parse_err or type(parsed) ~= "table" then
            return nil
        end

        return handle_payload(parsed)
    end

    -- Reports an in-stream API error through the callback and return values.
    local function fail(error_info)
        on_error(error_info)
        return nil, error_info.message, { error = error_info }
    end

    -- Unconsumed tail of the stream: text after the last newline seen so far.
    local pending = ""

    while true do
        local chunk, err = stream_response.stream:read()

        if err then
            on_error({ message = err })
            return nil, err
        end

        if not chunk then
            break
        end

        if chunk ~= "" then
            pending = pending .. chunk

            local line_start = 1
            while true do
                local newline_at = pending:find("\n", line_start, true)
                if not newline_at then
                    break
                end

                local error_info = handle_line(pending:sub(line_start, newline_at - 1))
                if error_info then
                    return fail(error_info)
                end
                line_start = newline_at + 1
            end

            -- Carry the incomplete last line over to the next read.
            pending = pending:sub(line_start)
        end
    end

    -- The stream may end without a final newline; do not lose that event.
    if pending ~= "" then
        local error_info = handle_line(pending)
        if error_info then
            return fail(error_info)
        end
    end

    local full_content = table.concat(content_buf)

    local result: StreamResult = {
        content = full_content,
        tool_calls = tool_calls,
        finish_reason = finish_reason,
        usage = usage,
        metadata = metadata
    }

    on_done(result)
    return full_content, nil, result
end
--- Process a streaming response and send chunks via output.streamer.
--- Returns an aggregated Google-like response compatible with map_success_response().
---
--- Forwards content, tool calls, thinking text and errors to a streamer built
--- from `http_options.stream_reply_to`, `http_options.stream_topic` and
--- `http_options.stream_buffer_size` (default 10), then rebuilds a single
--- response table from what `client.process_stream` aggregated.
---
--- @param response table HTTP response; `response.stream` is read, and
---        `response.status_code` is echoed back (default 200).
--- @param http_options table request options carrying the `stream_*` fields.
--- @return table? aggregated response (`candidates`, `usageMetadata`,
---         `modelVersion`, `responseId`, `metadata`, `status_code`), or nil on failure
--- @return table? `{ status_code = 500, message = ... }` on failure
local function handle_stream_response(response, http_options)
    local streamer = output.streamer(
        http_options.stream_reply_to,
        http_options.stream_topic,
        http_options.stream_buffer_size or 10
    )

    -- Set by on_error. Tracked separately from the returned error string,
    -- because process_stream returns `error_info.message`, which can be nil
    -- for an API error payload without a message; that must still fail.
    local stream_failure = nil

    local content, stream_err, result = client.process_stream(
        { stream = response.stream, metadata = {} },
        {
            on_content = function(chunk: string)
                streamer:buffer_content(chunk)
            end,

            on_tool_call = function(tool_part: any)
                local call = tool_part.functionCall
                if call then
                    -- The function name is passed both as the name and as
                    -- the third argument, as in the original call.
                    streamer:send_tool_call(call.name, call.args or {}, call.name)
                end
            end,

            on_thinking = function(text: string)
                streamer:send_thinking(text)
            end,

            on_error = function(error_info: any)
                stream_failure = error_info or {}
                streamer:send_error("server_error", error_info.message)
            end,

            on_done = function(_result: StreamResult)
                streamer:flush()
            end
        }
    )

    if stream_err or stream_failure then
        local reason = stream_err
        if reason == nil then
            reason = (stream_failure and stream_failure.message) or "unknown stream error"
        end
        return nil, {
            status_code = 500,
            message = "Stream processing failed: " .. tostring(reason)
        }
    end

    -- On success process_stream returns the aggregate; fall back defensively
    -- so the reconstruction below never indexes nil.
    result = result or {}
    local response_metadata = result.metadata or {}

    -- Reconstruct Google-like response: text part first, then tool-call parts.
    local parts = {}
    if content and content ~= "" then
        table.insert(parts, { text = content })
    end
    for _, tc_part in ipairs(result.tool_calls or {}) do
        table.insert(parts, tc_part)
    end

    return {
        candidates = {
            {
                content = { parts = parts, role = "model" },
                finishReason = result.finish_reason
            }
        },
        usageMetadata = result.usage,
        modelVersion = response_metadata.model_version,
        responseId = response_metadata.response_id,
        metadata = response_metadata,
        status_code = response.status_code or 200
    }
end
request_options.stream = true + request_options.stream_reply_to = contract_args.stream.reply_to + request_options.stream_topic = contract_args.stream.topic + request_options.stream_buffer_size = contract_args.stream.buffer_size + end + local response = client_instance:request({ - endpoint_path = "generateContent", + endpoint_path = endpoint_path, model = contract_args.model, payload = payload, - options = { timeout = contract_args.timeout } + options = request_options }) if response.status_code < 200 or response.status_code >= 300 then diff --git a/src/llm/src/google/generative_ai/client.lua b/src/llm/src/google/generative_ai/client.lua index 31c722c..410d773 100644 --- a/src/llm/src/google/generative_ai/client.lua +++ b/src/llm/src/google/generative_ai/client.lua @@ -26,6 +26,12 @@ function generative_ai_client.request(contract_args) if contract_args.options.method == "POST" then options.body = json.encode(contract_args.payload or {}) end + if contract_args.options.stream then + options.stream = true + options.stream_reply_to = contract_args.options.stream_reply_to + options.stream_topic = contract_args.options.stream_topic + options.stream_buffer_size = contract_args.options.stream_buffer_size + end local base_url = contract_args.options.base_url or generative_ai_client._config.get_generative_ai_base_url() if contract_args.model and contract_args.model ~= "" then diff --git a/src/llm/src/google/mapper.lua b/src/llm/src/google/mapper.lua index 4bf2d5e..349cc1b 100644 --- a/src/llm/src/google/mapper.lua +++ b/src/llm/src/google/mapper.lua @@ -185,12 +185,16 @@ function mapper.map_messages(contract_messages, options) and json.decode(msg.function_call.arguments) or msg.function_call.arguments - table.insert(processed_messages, { role = "model", parts = { + local part = { functionCall = { name = msg.function_call.name, args = next(arguments or {}) ~= nil and arguments or nil } - } }) + } + if msg.function_call.provider_metadata and 
--- Convert response content parts into contract tool calls.
---
--- Only parts carrying a `functionCall` produce an entry. The result is always
--- a dense sequence (no holes), so `#result` and `ipairs` are reliable even when
--- non-function parts are interleaved in `content_parts`.
---
--- Ids are `<name or "func">_<unix seconds>`. When the same id would be
--- produced more than once within one call (e.g. the same function called twice
--- in one response), later occurrences get a `_2`, `_3`, ... suffix so every
--- returned id is distinct.
---
--- @param content_parts table? sequence of content parts; each may have
---        `functionCall = { name, args }` and an optional `thoughtSignature`.
--- @return table sequence of `{ id, name, arguments, provider_metadata? }`;
---         `provider_metadata.thought_signature` is set when the part has a
---         `thoughtSignature`. Empty table for nil/empty input.
function mapper.map_tool_calls(content_parts)
    if not content_parts or #content_parts == 0 then
        return {}
    end

    local contract_tool_calls = {}
    -- id -> number of times it has been issued in this call
    local issued = {}
    -- One timestamp for the whole batch keeps ids of one response consistent.
    local stamp = time.now():unix()

    for _, content_part in ipairs(content_parts) do
        local function_call = content_part.functionCall
        if function_call then
            local id = (function_call.name or "func") .. "_" .. stamp
            local occurrence = (issued[id] or 0) + 1
            issued[id] = occurrence
            if occurrence > 1 then
                -- Disambiguate repeated calls; the first keeps the plain id.
                id = id .. "_" .. occurrence
            end

            local tool_call = {
                id = id,
                name = function_call.name,
                arguments = function_call.args or {},
            }
            if content_part.thoughtSignature then
                tool_call.provider_metadata = {
                    thought_signature = content_part.thoughtSignature
                }
            end

            -- Append rather than index by source position, which would
            -- leave holes for parts without a functionCall.
            contract_tool_calls[#contract_tool_calls + 1] = tool_call
        end
    end

    return contract_tool_calls
end
content_part.text elseif content_part.functionCall then - table.insert(tool_calls, content_part.functionCall) + table.insert(tool_calls, content_part) end end end diff --git a/src/llm/src/google/vertex/client.lua b/src/llm/src/google/vertex/client.lua index ded5b10..3845e6a 100644 --- a/src/llm/src/google/vertex/client.lua +++ b/src/llm/src/google/vertex/client.lua @@ -8,7 +8,8 @@ local vertex_client = { } local PROJECT_REQUIRED_ENDPOINTS = { - "generateContent" + "generateContent", + "streamGenerateContent" } local function build_url(base_url, contract_args) @@ -60,6 +61,12 @@ function vertex_client.request(contract_args) if contract_args.options.method == "POST" then options.body = json.encode(contract_args.payload or {}) end + if contract_args.options.stream then + options.stream = true + options.stream_reply_to = contract_args.options.stream_reply_to + options.stream_topic = contract_args.options.stream_topic + options.stream_buffer_size = contract_args.options.stream_buffer_size + end local base_url = contract_args.options.base_url if not base_url then diff --git a/src/llm/src/prompt.lua b/src/llm/src/prompt.lua index 2cada79..cda55e8 100644 --- a/src/llm/src/prompt.lua +++ b/src/llm/src/prompt.lua @@ -16,10 +16,15 @@ type ContentPart = { source: ImageSource?, } +type FunctionCallOptions = { + provider_metadata: table?, +} + type FunctionCall = { name: string, arguments: string, id: string?, + provider_metadata: table?, } type Message = { @@ -39,7 +44,7 @@ type PromptBuilder = { add_assistant: (self: any, content: string, meta: table?) -> PromptBuilder, add_developer: (self: any, content: string, meta: table?) -> PromptBuilder, add_message: (self: any, role: string, content_parts: {ContentPart}, name: string?, metadata: table?) -> PromptBuilder, - add_function_call: (self: any, function_name: string, arguments: string, function_call_id: string?) 
-> PromptBuilder, + add_function_call: (self: any, function_name: string, arguments: string, function_call_id: string?, options: FunctionCallOptions?) -> PromptBuilder, add_function_result: (self: any, name: string, content: any, function_call_id: string?) -> PromptBuilder, add_cache_marker: (self: any, marker_id: string?) -> PromptBuilder, get_messages: (self: any) -> {Message}, @@ -284,7 +289,7 @@ function prompt.new(messages: {Message}?) end -- Add a function call by assistant - builder.add_function_call = function(self: any, function_name: string, arguments: string, function_call_id: string?) + builder.add_function_call = function(self: any, function_name: string, arguments: string, function_call_id: string?, options: FunctionCallOptions?) if function_name and arguments then local message = { role = prompt.ROLE.FUNCTION_CALL, @@ -299,6 +304,10 @@ function prompt.new(messages: {Message}?) message.function_call.id = function_call_id end + if options and options.provider_metadata then + message.function_call.provider_metadata = options.provider_metadata + end + table.insert(self.messages, message) end return self