Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions src/main/java/com/google/genai/ReplayApiClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,13 @@ public ApiResponse request(
matchRequest(
currentInteraction.request().orElse(null),
buildRequest(httpMethod, path, requestJson, httpOptions));
return buildResponseFromReplay(currentInteraction.response().orElse(null));
boolean isStream =
currentInteraction
.request()
.flatMap(r -> r.url())
.map(u -> u.contains("streamGenerateContent"))
.orElse(false);
return buildResponseFromReplay(currentInteraction.response().orElse(null), isStream);
} else {
// Note that if the client mode is "api", then the ReplayApiClient will not be used.
throw new IllegalArgumentException("Invalid client mode: " + this.clientMode);
Expand Down Expand Up @@ -227,15 +233,16 @@ private void matchRequest(ReplayRequest replayRequest, Request actualRequest) {
}

/** Builds the response from a {@link ReplayResponse}. */
private ReplayApiResponse buildResponseFromReplay(ReplayResponse replayResponse) {
private ReplayApiResponse buildResponseFromReplay(
ReplayResponse replayResponse, boolean isStream) {
if (replayResponse == null) {
throw new IllegalArgumentException("Replay response is null.");
}
JsonNode bodyNode =
JsonSerializable.toJsonNode(replayResponse.bodySegments().orElse(new ArrayList<>()));
Headers headers = Headers.of(replayResponse.headers().orElse(ImmutableMap.of()));
return new ReplayApiResponse(
(ArrayNode) bodyNode, replayResponse.statusCode().orElse(0), headers);
(ArrayNode) bodyNode, replayResponse.statusCode().orElse(0), headers, isStream);
}

private static String formatUrl(String url) {
Expand Down
23 changes: 13 additions & 10 deletions src/main/java/com/google/genai/ReplayApiResponse.java
Original file line number Diff line number Diff line change
Expand Up @@ -38,32 +38,35 @@ public final class ReplayApiResponse extends ApiResponse {
private final Headers headers;
private final ArrayNode bodySegments;

public ReplayApiResponse(ArrayNode bodySegments, int statusCode, Headers headers) {
public ReplayApiResponse(
ArrayNode bodySegments, int statusCode, Headers headers, boolean isStream) {
this.bodySegments = bodySegments;
this.statusCode = statusCode;
this.headers = headers;
if (bodySegments.size() == 0) {
this.body = ResponseBody.create(MediaType.parse("application/json"), "");
} else if (bodySegments.size() == 1) {
// For unary response
this.body =
ResponseBody.create(
JsonSerializable.toJsonString(bodySegments.get(0)),
MediaType.parse("application/json"));
} else {
} else if (isStream || bodySegments.size() > 1) {
// For streaming response
try {
ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
byte[] newline = "\n".getBytes(StandardCharsets.UTF_8);
byte[] dataPrefix = "data: ".getBytes(StandardCharsets.UTF_8);
byte[] doubleNewline = "\n\n".getBytes(StandardCharsets.UTF_8);
for (JsonNode segment : bodySegments) {
outputStream.write(dataPrefix);
outputStream.write(JsonSerializable.objectMapper.writeValueAsBytes(segment));
outputStream.write(newline);
outputStream.write(doubleNewline);
}
this.body =
ResponseBody.create(outputStream.toByteArray(), MediaType.parse("application/json"));
} catch (IOException e) {
throw new GenAiIOException("Failed to convert body segments to a JSON string.", e);
}
} else {
// For unary response
this.body =
ResponseBody.create(
JsonSerializable.toJsonString(bodySegments.get(0)),
MediaType.parse("application/json"));
}
}

Expand Down
61 changes: 50 additions & 11 deletions src/main/java/com/google/genai/ResponseStream.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2025 Google LLC
* Copyright 2026 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -17,7 +17,9 @@
package com.google.genai;

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.node.ArrayNode;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.google.genai.errors.ApiException;
import com.google.genai.errors.GenAiIOException;
import java.io.BufferedReader;
import java.io.IOException;
Expand All @@ -32,6 +34,7 @@
import java.util.NoSuchElementException;
import java.util.logging.Logger;
import okhttp3.Headers;
import org.jspecify.annotations.Nullable;

/** An iterable of datatype objects. */
public class ResponseStream<T extends JsonSerializable> implements Iterable<T>, AutoCloseable {
Expand Down Expand Up @@ -114,6 +117,18 @@ public T next() {
nextJson = readNextJson();
try {
JsonNode currentJsonNode = JsonSerializable.stringToJsonNode(currentJson);

if (currentJsonNode.isObject() && currentJsonNode.has("error")) {
int extractedCode = 500;
JsonNode errorNode = currentJsonNode.get("error");
if (errorNode.has("code") && errorNode.get("code").isInt()) {
extractedCode = errorNode.get("code").asInt();
}
ArrayNode arrayNode = JsonSerializable.objectMapper.createArrayNode();
arrayNode.add(currentJsonNode);
ApiException.throwFromErrorNode(arrayNode, extractedCode);
}

if (responseHeaders != null && currentJsonNode.isObject()) {
ObjectNode rootNode = (ObjectNode) currentJsonNode;
ObjectNode headersNode = JsonSerializable.objectMapper.createObjectNode();
Expand Down Expand Up @@ -142,23 +157,47 @@ public T next() {
}
}

private String readNextJson() {
private @Nullable String readNextJson() {
// Streaming API returns in the following format:
// data: {contents: ...}
// \n
// data: {contents: ...}
// \n
// ...
List<String> dataBuffer = new ArrayList<>();
try {
String line = reader.readLine();
if (line == null) {
return null;
} else if (line.length() == 0) {
return readNextJson();
} else if (line.startsWith("data: ")) {
return line.substring("data: ".length());
} else {
return line;
while (true) {
String line = reader.readLine();
if (line == null) {
if (!dataBuffer.isEmpty()) {
return String.join("\n", dataBuffer);
}
return null;
}
if (line.isEmpty()) {
if (!dataBuffer.isEmpty()) {
// Handle multi-line SSE data
return String.join("\n", dataBuffer);
}
continue;
}
if (line.startsWith(":")) {
continue;
}
int colonIndex = line.indexOf(':');
String fieldname = line;
String value = "";
if (colonIndex != -1) {
fieldname = line.substring(0, colonIndex);
value = line.substring(colonIndex + 1);
if (value.startsWith(" ")) {
value = value.substring(1);
}
}

if (fieldname.equals("data")) {
dataBuffer.add(value);
}
}
} catch (IOException e) {
throw new GenAiIOException("Failed to read next JSON object from the stream", e);
Expand Down
12 changes: 10 additions & 2 deletions src/test/java/com/google/genai/AsyncChatTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,16 @@ public class AsyncChatTest {
String jsonChunk3 = responseChunk3.toJson();

String streamData =
"data: " + jsonChunk1 + "\n" + "data: " + jsonChunk2 + "\n" + "data: " + jsonChunk3 + "\n";
String streamData2 = "data: " + jsonChunk1 + "\n" + "data: " + jsonChunk2 + "\n";
"data: "
+ jsonChunk1
+ "\n\n"
+ "data: "
+ jsonChunk2
+ "\n\n"
+ "data: "
+ jsonChunk3
+ "\n\n";
String streamData2 = "data: " + jsonChunk1 + "\n\n" + "data: " + jsonChunk2 + "\n\n";

GenerateContentResponse nonStreamingResponse =
GenerateContentResponse.builder()
Expand Down
12 changes: 10 additions & 2 deletions src/test/java/com/google/genai/ChatTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,16 @@ public static String findTheaters(String movie, String location, String time) {
String jsonChunk3 = responseChunk3.toJson();

String streamData =
"data: " + jsonChunk1 + "\n" + "data: " + jsonChunk2 + "\n" + "data: " + jsonChunk3 + "\n";
String streamData2 = "data: " + jsonChunk1 + "\n" + "data: " + jsonChunk2 + "\n";
"data: "
+ jsonChunk1
+ "\n\n"
+ "data: "
+ jsonChunk2
+ "\n\n"
+ "data: "
+ jsonChunk3
+ "\n\n";
String streamData2 = "data: " + jsonChunk1 + "\n\n" + "data: " + jsonChunk2 + "\n\n";

GenerateContentResponse nonStreamingResponse =
GenerateContentResponse.builder()
Expand Down
165 changes: 165 additions & 0 deletions src/test/java/com/google/genai/ResponseStreamTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
/*
* Copyright 2026 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.genai;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.google.genai.types.Candidate;
import com.google.genai.types.GenerateContentResponse;
import java.nio.charset.StandardCharsets;
import java.util.Iterator;
import okhttp3.Headers;
import okhttp3.MediaType;
import okhttp3.ResponseBody;
import org.junit.jupiter.api.Test;

public final class ResponseStreamTest {

public static class DummyConverter {
public JsonNode convert(JsonNode fromObject, ObjectNode parentObject) {
return fromObject;
}
}

@Test
public void testMultiLineSseParsing() throws Exception {
String sseData =
"data: {\"candidates\": [{\"content\": {\"parts\": [{\"text\": \"line1\\n\"}]}},\n"
+ "data: {\"content\": {\"parts\": [{\"text\": \"line2\"}]}}]}\n"
+ "\n"; // End of event

ResponseBody body =
ResponseBody.create(
sseData.getBytes(StandardCharsets.UTF_8), MediaType.parse("text/event-stream"));
FakeApiResponse response = new FakeApiResponse(Headers.of(), body);

DummyConverter converter = new DummyConverter();
ResponseStream<GenerateContentResponse> responseStream =
new ResponseStream<>(GenerateContentResponse.class, response, converter, "convert");

Iterator<GenerateContentResponse> iterator = responseStream.iterator();

assertTrue(iterator.hasNext());
GenerateContentResponse response1 = iterator.next();

assertTrue(response1.candidates().isPresent());
assertEquals(2, response1.candidates().get().size());

Candidate c1 = response1.candidates().get().get(0);
assertEquals("line1\n", c1.content().get().text());

Candidate c2 = response1.candidates().get().get(1);
assertEquals("line2", c2.content().get().text());

assertTrue(!iterator.hasNext());
}

@Test
public void testIgnoreNonSseLines() throws Exception {
String sseData =
"data: {\"candidates\": [{\"content\": {\"parts\": [{\"text\": \"valid data\"}]}}]}\n"
+ ": some comment line\n"
+ "ignored field: some value\n"
+ "\n"; // End of event

ResponseBody body =
ResponseBody.create(
sseData.getBytes(StandardCharsets.UTF_8), MediaType.parse("text/event-stream"));
FakeApiResponse response = new FakeApiResponse(Headers.of(), body);

DummyConverter converter = new DummyConverter();
ResponseStream<GenerateContentResponse> responseStream =
new ResponseStream<>(GenerateContentResponse.class, response, converter, "convert");

Iterator<GenerateContentResponse> iterator = responseStream.iterator();

assertTrue(iterator.hasNext());
GenerateContentResponse response1 = iterator.next();

assertEquals("valid data", response1.text());

assertTrue(!iterator.hasNext());
}

@Test
public void testStreamErrorHandling() throws Exception {
String sseData =
"data: {\"candidates\": [{\"content\": {\"parts\": [{\"text\": \"valid data\"}]}}]}\n"
+ "\n"
+ "data: {\"error\": {\"code\": 429, \"message\": \"Quota exceeded\", \"status\": \"RESOURCE_EXHAUSTED\"}}\n"
+ "\n";

ResponseBody body =
ResponseBody.create(
sseData.getBytes(StandardCharsets.UTF_8), MediaType.parse("text/event-stream"));
FakeApiResponse response = new FakeApiResponse(Headers.of(), body);

DummyConverter converter = new DummyConverter();
ResponseStream<GenerateContentResponse> responseStream =
new ResponseStream<>(GenerateContentResponse.class, response, converter, "convert");

Iterator<GenerateContentResponse> iterator = responseStream.iterator();

assertTrue(iterator.hasNext());
GenerateContentResponse response1 = iterator.next();
assertEquals("valid data", response1.text());

assertTrue(iterator.hasNext());

try {
iterator.next();
org.junit.jupiter.api.Assertions.fail("Expected ApiException was not thrown");
} catch (com.google.genai.errors.ApiException e) {
assertEquals(429, e.code());
assertEquals("RESOURCE_EXHAUSTED", e.status());
assertTrue(e.getMessage().contains("Quota exceeded"));
}
}

@Test
public void testMultipleValidEvents() throws Exception {
String sseData =
"data: {\"candidates\": [{\"content\": {\"parts\": [{\"text\": \"chunk1\"}]}}]}\n"
+ "\n"
+ "data: {\"candidates\": [{\"content\": {\"parts\": [{\"text\": \"chunk2\"}]}}]}\n"
+ "\n";

ResponseBody body =
ResponseBody.create(
sseData.getBytes(StandardCharsets.UTF_8), MediaType.parse("text/event-stream"));
FakeApiResponse response = new FakeApiResponse(Headers.of(), body);

DummyConverter converter = new DummyConverter();
ResponseStream<GenerateContentResponse> responseStream =
new ResponseStream<>(GenerateContentResponse.class, response, converter, "convert");

Iterator<GenerateContentResponse> iterator = responseStream.iterator();

assertTrue(iterator.hasNext());
GenerateContentResponse response1 = iterator.next();
assertEquals("chunk1", response1.text());

assertTrue(iterator.hasNext());
GenerateContentResponse response2 = iterator.next();
assertEquals("chunk2", response2.text());

assertTrue(!iterator.hasNext());
}
}
Loading