From 3a1cb226bc31f5528a56a2d9fd62fcfe89ca77e1 Mon Sep 17 00:00:00 2001 From: Jaycee Li Date: Thu, 16 Apr 2026 11:41:36 -0700 Subject: [PATCH] fix: ResponseStream fails to parse error message and multi-line SSE data payloads PiperOrigin-RevId: 900832738 --- .../com/google/genai/ReplayApiClient.java | 13 +- .../com/google/genai/ReplayApiResponse.java | 23 +-- .../java/com/google/genai/ResponseStream.java | 61 +++++-- .../java/com/google/genai/AsyncChatTest.java | 12 +- src/test/java/com/google/genai/ChatTest.java | 12 +- .../com/google/genai/ResponseStreamTest.java | 165 ++++++++++++++++++ 6 files changed, 258 insertions(+), 28 deletions(-) create mode 100644 src/test/java/com/google/genai/ResponseStreamTest.java diff --git a/src/main/java/com/google/genai/ReplayApiClient.java b/src/main/java/com/google/genai/ReplayApiClient.java index c2f379fc0e6..d1a3622ed2e 100644 --- a/src/main/java/com/google/genai/ReplayApiClient.java +++ b/src/main/java/com/google/genai/ReplayApiClient.java @@ -135,7 +135,13 @@ public ApiResponse request( matchRequest( currentInteraction.request().orElse(null), buildRequest(httpMethod, path, requestJson, httpOptions)); - return buildResponseFromReplay(currentInteraction.response().orElse(null)); + boolean isStream = + currentInteraction + .request() + .flatMap(r -> r.url()) + .map(u -> u.contains("streamGenerateContent")) + .orElse(false); + return buildResponseFromReplay(currentInteraction.response().orElse(null), isStream); } else { // Note that if the client mode is "api", then the ReplayApiClient will not be used. throw new IllegalArgumentException("Invalid client mode: " + this.clientMode); @@ -227,7 +233,8 @@ private void matchRequest(ReplayRequest replayRequest, Request actualRequest) { } /** Builds the response from a {@link ReplayResponse}. */ - private ReplayApiResponse buildResponseFromReplay(ReplayResponse replayResponse) { + private ReplayApiResponse buildResponseFromReplay( + ReplayResponse replayResponse, boolean isStream) { if (replayResponse == null) { throw new IllegalArgumentException("Replay response is null."); } @@ -235,7 +242,7 @@ private ReplayApiResponse buildResponseFromReplay(ReplayResponse replayResponse) JsonSerializable.toJsonNode(replayResponse.bodySegments().orElse(new ArrayList<>())); Headers headers = Headers.of(replayResponse.headers().orElse(ImmutableMap.of())); return new ReplayApiResponse( - (ArrayNode) bodyNode, replayResponse.statusCode().orElse(0), headers); + (ArrayNode) bodyNode, replayResponse.statusCode().orElse(0), headers, isStream); } private static String formatUrl(String url) { diff --git a/src/main/java/com/google/genai/ReplayApiResponse.java b/src/main/java/com/google/genai/ReplayApiResponse.java index ed0b418f506..4e58bb70b49 100644 --- a/src/main/java/com/google/genai/ReplayApiResponse.java +++ b/src/main/java/com/google/genai/ReplayApiResponse.java @@ -38,32 +38,35 @@ public final class ReplayApiResponse extends ApiResponse { private final Headers headers; private final ArrayNode bodySegments; - public ReplayApiResponse(ArrayNode bodySegments, int statusCode, Headers headers) { + public ReplayApiResponse( + ArrayNode bodySegments, int statusCode, Headers headers, boolean isStream) { this.bodySegments = bodySegments; this.statusCode = statusCode; this.headers = headers; if (bodySegments.size() == 0) { this.body = ResponseBody.create(MediaType.parse("application/json"), ""); - } else if (bodySegments.size() == 1) { - // For unary response - this.body = - ResponseBody.create( - JsonSerializable.toJsonString(bodySegments.get(0)), - MediaType.parse("application/json")); - } else { + } else if (isStream || bodySegments.size() > 1) { // For streaming response try { ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); - byte[] newline = "\n".getBytes(StandardCharsets.UTF_8); + byte[] dataPrefix = "data: ".getBytes(StandardCharsets.UTF_8); + byte[] doubleNewline = "\n\n".getBytes(StandardCharsets.UTF_8); for (JsonNode segment : bodySegments) { + outputStream.write(dataPrefix); outputStream.write(JsonSerializable.objectMapper.writeValueAsBytes(segment)); - outputStream.write(newline); + outputStream.write(doubleNewline); } this.body = ResponseBody.create(outputStream.toByteArray(), MediaType.parse("application/json")); } catch (IOException e) { throw new GenAiIOException("Failed to convert body segments to a JSON string.", e); } + } else { + // For unary response + this.body = + ResponseBody.create( + JsonSerializable.toJsonString(bodySegments.get(0)), + MediaType.parse("application/json")); } } diff --git a/src/main/java/com/google/genai/ResponseStream.java b/src/main/java/com/google/genai/ResponseStream.java index 3aa8e2b32d2..9065308abbe 100644 --- a/src/main/java/com/google/genai/ResponseStream.java +++ b/src/main/java/com/google/genai/ResponseStream.java @@ -1,5 +1,5 @@ /* - * Copyright 2025 Google LLC + * Copyright 2026 Google LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,7 +17,9 @@ package com.google.genai; import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.node.ArrayNode; import com.fasterxml.jackson.databind.node.ObjectNode; +import com.google.genai.errors.ApiException; import com.google.genai.errors.GenAiIOException; import java.io.BufferedReader; import java.io.IOException; @@ -32,6 +34,7 @@ import java.util.NoSuchElementException; import java.util.logging.Logger; import okhttp3.Headers; +import org.jspecify.annotations.Nullable; /** An iterable of datatype objects. */ public class ResponseStream implements Iterable, AutoCloseable { @@ -114,6 +117,18 @@ public T next() { nextJson = readNextJson(); try { JsonNode currentJsonNode = JsonSerializable.stringToJsonNode(currentJson); + + if (currentJsonNode.isObject() && currentJsonNode.has("error")) { + int extractedCode = 500; + JsonNode errorNode = currentJsonNode.get("error"); + if (errorNode.has("code") && errorNode.get("code").isInt()) { + extractedCode = errorNode.get("code").asInt(); + } + ArrayNode arrayNode = JsonSerializable.objectMapper.createArrayNode(); + arrayNode.add(currentJsonNode); + ApiException.throwFromErrorNode(arrayNode, extractedCode); + } + if (responseHeaders != null && currentJsonNode.isObject()) { ObjectNode rootNode = (ObjectNode) currentJsonNode; ObjectNode headersNode = JsonSerializable.objectMapper.createObjectNode(); @@ -142,23 +157,47 @@ public T next() { } } - private String readNextJson() { + private @Nullable String readNextJson() { // Streaming API returns in the following format: // data: {contents: ...} // \n // data: {contents: ...} // \n // ... + List dataBuffer = new ArrayList<>(); try { - String line = reader.readLine(); - if (line == null) { - return null; - } else if (line.length() == 0) { - return readNextJson(); - } else if (line.startsWith("data: ")) { - return line.substring("data: ".length()); - } else { - return line; + while (true) { + String line = reader.readLine(); + if (line == null) { + if (!dataBuffer.isEmpty()) { + return String.join("\n", dataBuffer); + } + return null; + } + if (line.isEmpty()) { + if (!dataBuffer.isEmpty()) { + // Handle multi-line SSE data + return String.join("\n", dataBuffer); + } + continue; + } + if (line.startsWith(":")) { + continue; + } + int colonIndex = line.indexOf(':'); + String fieldname = line; + String value = ""; + if (colonIndex != -1) { + fieldname = line.substring(0, colonIndex); + value = line.substring(colonIndex + 1); + if (value.startsWith(" ")) { + value = value.substring(1); + } + } + + if (fieldname.equals("data")) { + dataBuffer.add(value); + } } } catch (IOException e) { throw new GenAiIOException("Failed to read next JSON object from the stream", e); diff --git a/src/test/java/com/google/genai/AsyncChatTest.java b/src/test/java/com/google/genai/AsyncChatTest.java index 6dfa9a923d1..37652cfe03c 100644 --- a/src/test/java/com/google/genai/AsyncChatTest.java +++ b/src/test/java/com/google/genai/AsyncChatTest.java @@ -96,8 +96,16 @@ public class AsyncChatTest { String jsonChunk3 = responseChunk3.toJson(); String streamData = - "data: " + jsonChunk1 + "\n" + "data: " + jsonChunk2 + "\n" + "data: " + jsonChunk3 + "\n"; - String streamData2 = "data: " + jsonChunk1 + "\n" + "data: " + jsonChunk2 + "\n"; + "data: " + + jsonChunk1 + + "\n\n" + + "data: " + + jsonChunk2 + + "\n\n" + + "data: " + + jsonChunk3 + + "\n\n"; + String streamData2 = "data: " + jsonChunk1 + "\n\n" + "data: " + jsonChunk2 + "\n\n"; GenerateContentResponse nonStreamingResponse = GenerateContentResponse.builder() diff --git a/src/test/java/com/google/genai/ChatTest.java b/src/test/java/com/google/genai/ChatTest.java index ec3a3267253..a4ffd837abd 100644 --- a/src/test/java/com/google/genai/ChatTest.java +++ b/src/test/java/com/google/genai/ChatTest.java @@ -106,8 +106,16 @@ public static String findTheaters(String movie, String location, String time) { String jsonChunk3 = responseChunk3.toJson(); String streamData = - "data: " + jsonChunk1 + "\n" + "data: " + jsonChunk2 + "\n" + "data: " + jsonChunk3 + "\n"; - String streamData2 = "data: " + jsonChunk1 + "\n" + "data: " + jsonChunk2 + "\n"; + "data: " + + jsonChunk1 + + "\n\n" + + "data: " + + jsonChunk2 + + "\n\n" + + "data: " + + jsonChunk3 + + "\n\n"; + String streamData2 = "data: " + jsonChunk1 + "\n\n" + "data: " + jsonChunk2 + "\n\n"; GenerateContentResponse nonStreamingResponse = GenerateContentResponse.builder() diff --git a/src/test/java/com/google/genai/ResponseStreamTest.java b/src/test/java/com/google/genai/ResponseStreamTest.java new file mode 100644 index 00000000000..80fb84ab852 --- /dev/null +++ b/src/test/java/com/google/genai/ResponseStreamTest.java @@ -0,0 +1,165 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.genai; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.google.genai.types.Candidate; +import com.google.genai.types.GenerateContentResponse; +import java.nio.charset.StandardCharsets; +import java.util.Iterator; +import okhttp3.Headers; +import okhttp3.MediaType; +import okhttp3.ResponseBody; +import org.junit.jupiter.api.Test; + +public final class ResponseStreamTest { + + public static class DummyConverter { + public JsonNode convert(JsonNode fromObject, ObjectNode parentObject) { + return fromObject; + } + } + + @Test + public void testMultiLineSseParsing() throws Exception { + String sseData = + "data: {\"candidates\": [{\"content\": {\"parts\": [{\"text\": \"line1\\n\"}]}},\n" + + "data: {\"content\": {\"parts\": [{\"text\": \"line2\"}]}}]}\n" + + "\n"; // End of event + + ResponseBody body = + ResponseBody.create( + sseData.getBytes(StandardCharsets.UTF_8), MediaType.parse("text/event-stream")); + FakeApiResponse response = new FakeApiResponse(Headers.of(), body); + + DummyConverter converter = new DummyConverter(); + ResponseStream responseStream = + new ResponseStream<>(GenerateContentResponse.class, response, converter, "convert"); + + Iterator iterator = responseStream.iterator(); + + assertTrue(iterator.hasNext()); + GenerateContentResponse response1 = iterator.next(); + + assertTrue(response1.candidates().isPresent()); + assertEquals(2, response1.candidates().get().size()); + + Candidate c1 = response1.candidates().get().get(0); + assertEquals("line1\n", c1.content().get().text()); + + Candidate c2 = response1.candidates().get().get(1); + assertEquals("line2", c2.content().get().text()); + + assertTrue(!iterator.hasNext()); + } + + @Test + public void testIgnoreNonSseLines() throws Exception { + String sseData = + "data: {\"candidates\": [{\"content\": {\"parts\": [{\"text\": \"valid data\"}]}}]}\n" + + ": some comment line\n" + + "ignored field: some value\n" + + "\n"; // End of event + + ResponseBody body = + ResponseBody.create( + sseData.getBytes(StandardCharsets.UTF_8), MediaType.parse("text/event-stream")); + FakeApiResponse response = new FakeApiResponse(Headers.of(), body); + + DummyConverter converter = new DummyConverter(); + ResponseStream responseStream = + new ResponseStream<>(GenerateContentResponse.class, response, converter, "convert"); + + Iterator iterator = responseStream.iterator(); + + assertTrue(iterator.hasNext()); + GenerateContentResponse response1 = iterator.next(); + + assertEquals("valid data", response1.text()); + + assertTrue(!iterator.hasNext()); + } + + @Test + public void testStreamErrorHandling() throws Exception { + String sseData = + "data: {\"candidates\": [{\"content\": {\"parts\": [{\"text\": \"valid data\"}]}}]}\n" + + "\n" + + "data: {\"error\": {\"code\": 429, \"message\": \"Quota exceeded\", \"status\": \"RESOURCE_EXHAUSTED\"}}\n" + + "\n"; + + ResponseBody body = + ResponseBody.create( + sseData.getBytes(StandardCharsets.UTF_8), MediaType.parse("text/event-stream")); + FakeApiResponse response = new FakeApiResponse(Headers.of(), body); + + DummyConverter converter = new DummyConverter(); + ResponseStream responseStream = + new ResponseStream<>(GenerateContentResponse.class, response, converter, "convert"); + + Iterator iterator = responseStream.iterator(); + + assertTrue(iterator.hasNext()); + GenerateContentResponse response1 = iterator.next(); + assertEquals("valid data", response1.text()); + + assertTrue(iterator.hasNext()); + + try { + iterator.next(); + org.junit.jupiter.api.Assertions.fail("Expected ApiException was not thrown"); + } catch (com.google.genai.errors.ApiException e) { + assertEquals(429, e.code()); + assertEquals("RESOURCE_EXHAUSTED", e.status()); + assertTrue(e.getMessage().contains("Quota exceeded")); + } + } + + @Test + public void testMultipleValidEvents() throws Exception { + String sseData = + "data: {\"candidates\": [{\"content\": {\"parts\": [{\"text\": \"chunk1\"}]}}]}\n" + + "\n" + + "data: {\"candidates\": [{\"content\": {\"parts\": [{\"text\": \"chunk2\"}]}}]}\n" + + "\n"; + + ResponseBody body = + ResponseBody.create( + sseData.getBytes(StandardCharsets.UTF_8), MediaType.parse("text/event-stream")); + FakeApiResponse response = new FakeApiResponse(Headers.of(), body); + + DummyConverter converter = new DummyConverter(); + ResponseStream responseStream = + new ResponseStream<>(GenerateContentResponse.class, response, converter, "convert"); + + Iterator iterator = responseStream.iterator(); + + assertTrue(iterator.hasNext()); + GenerateContentResponse response1 = iterator.next(); + assertEquals("chunk1", response1.text()); + + assertTrue(iterator.hasNext()); + GenerateContentResponse response2 = iterator.next(); + assertEquals("chunk2", response2.text()); + + assertTrue(!iterator.hasNext()); + } +}