diff --git a/aws/aws-event-streams/src/main/java/software/amazon/smithy/java/aws/events/AwsEventShapeDecoder.java b/aws/aws-event-streams/src/main/java/software/amazon/smithy/java/aws/events/AwsEventShapeDecoder.java index 3ec3829bd..9903d37f3 100644 --- a/aws/aws-event-streams/src/main/java/software/amazon/smithy/java/aws/events/AwsEventShapeDecoder.java +++ b/aws/aws-event-streams/src/main/java/software/amazon/smithy/java/aws/events/AwsEventShapeDecoder.java @@ -170,7 +170,8 @@ static class EventStreamDeserializer extends SpecificShapeDeserializer { public void readStruct(Schema schema, T builder, ShapeDeserializer.StructMemberConsumer consumer) { var payloadWritten = false; for (Schema member : schema.members()) { - if (member.hasTrait(TraitKey.EVENT_HEADER_TRAIT)) { + if (member.hasTrait(TraitKey.EVENT_HEADER_TRAIT) + && headersDeserializer.headers.containsKey(member.memberName())) { consumer.accept(builder, member, headersDeserializer); } else if (member.hasTrait(TraitKey.EVENT_PAYLOAD_TRAIT)) { consumer.accept(builder, member, codecDeserializer); diff --git a/aws/client/aws-client-restjson/src/it/java/software/amazon/smithy/java/client/aws/restjson/RestJson1ProtocolTests.java b/aws/client/aws-client-restjson/src/it/java/software/amazon/smithy/java/client/aws/restjson/RestJson1ProtocolTests.java index c689c08a4..640cac6cb 100644 --- a/aws/client/aws-client-restjson/src/it/java/software/amazon/smithy/java/client/aws/restjson/RestJson1ProtocolTests.java +++ b/aws/client/aws-client-restjson/src/it/java/software/amazon/smithy/java/client/aws/restjson/RestJson1ProtocolTests.java @@ -11,6 +11,7 @@ import java.nio.charset.StandardCharsets; import software.amazon.smithy.java.io.ByteBufferUtils; import software.amazon.smithy.java.io.datastream.DataStream; +import software.amazon.smithy.java.protocoltests.harness.EventStreamClientTests; import software.amazon.smithy.java.protocoltests.harness.HttpClientRequestTests; import software.amazon.smithy.java.protocoltests.harness.HttpClientResponseTests; import software.amazon.smithy.java.protocoltests.harness.ProtocolTest; @@ -26,6 +27,40 @@ skipOperations = { // We dont ignore defaults on input shapes "aws.protocoltests.restjson#OperationWithDefaults", + }, + skipTests = { + // Need to add exception type to header + "ClientErrorInput", + "DuplexClientErrorInput", + // Currently we are using JSON codec for plain text payload, need to correct it. + "StringPayloadOutput", + "DuplexStringPayloadOutput", + // eventstream:1.0.1 made ByteValue.encodeValue() and + // ShortValue.encodeValue() no-ops, producing malformed frames. + "ByteHeaderInput", + "DuplexByteHeaderInput", + "ShortHeaderInput", + "DuplexShortHeaderInput", + // Blob test params use inconsistent encoding conventions — + // headers use base64, payloads use raw strings. + "BlobPayloadInput", + "DuplexBlobPayloadInput", + "BlobPayloadOutput", + "DuplexBlobPayloadOutput", + "BlobHeaderInput", + "DuplexBlobHeaderInput", + "BlobHeaderOutput", + "DuplexBlobHeaderOutput", + "MultipleHeaderInput", + "DuplexMultipleHeaderInput", + "MultipleHeaderOutput", + "DuplexMultipleHeaderOutput", + // Decoder returns modeled error events instead of throwing + "ClientErrorOutput", + "DuplexClientErrorOutput", + // Client doesn't validate missing @required initial response members + "MissingRequiredInitialResponseOutput", + "DuplexMissingRequiredInitialResponseOutput" }) public class RestJson1ProtocolTests { private static final String EMPTY_BODY = ""; @@ -57,4 +92,9 @@ public void requestTest(DataStream expected, DataStream actual) { public void responseTest(Runnable test) { test.run(); } + + @EventStreamClientTests + public void eventStreamClientTest(Runnable test) { + test.run(); + } } diff --git a/protocol-test-harness/build.gradle.kts b/protocol-test-harness/build.gradle.kts index 2f15acb9e..c12be1ebe 100644 --- a/protocol-test-harness/build.gradle.kts +++ b/protocol-test-harness/build.gradle.kts @@ -19,6 +19,7 @@ dependencies { implementation(project(":client:client-http")) implementation(project(":codecs:json-codec", configuration = "shadow")) implementation(libs.assertj.core) + implementation(project(":aws:aws-event-streams")) api(platform(libs.junit.bom)) api(libs.junit.jupiter.api) diff --git a/protocol-test-harness/src/main/java/software/amazon/smithy/java/protocoltests/harness/Assertions.java b/protocol-test-harness/src/main/java/software/amazon/smithy/java/protocoltests/harness/Assertions.java index bbd4e05e2..a733bcbf4 100644 --- a/protocol-test-harness/src/main/java/software/amazon/smithy/java/protocoltests/harness/Assertions.java +++ b/protocol-test-harness/src/main/java/software/amazon/smithy/java/protocoltests/harness/Assertions.java @@ -5,18 +5,35 @@ package software.amazon.smithy.java.protocoltests.harness; +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import java.nio.charset.StandardCharsets; +import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; import java.util.stream.Collectors; +import software.amazon.eventstream.HeaderValue; +import software.amazon.eventstream.MessageDecoder; +import software.amazon.smithy.java.core.error.ModeledException; import software.amazon.smithy.java.http.api.HttpMessage; import software.amazon.smithy.java.http.api.HttpRequest; import software.amazon.smithy.java.io.uri.SmithyUri; +import software.amazon.smithy.model.node.Node; +import software.amazon.smithy.model.node.ObjectNode; +import software.amazon.smithy.model.node.StringNode; import software.amazon.smithy.protocoltests.traits.HttpRequestTestCase; +import software.amazon.smithy.protocoltests.traits.TestFailureExpectation; +import software.amazon.smithy.protocoltests.traits.eventstream.Event; +import software.amazon.smithy.protocoltests.traits.eventstream.EventHeaderValue; +import software.amazon.smithy.protocoltests.traits.eventstream.EventStreamTestCase; /** * Provides a number of testing utilities for validating protocol test results. @@ -52,7 +69,7 @@ static void assertUriEquals(HttpRequestTestCase testCase, SmithyUri uri) { } private static void assertQueryParamsEquals(List expectedParams, String actualQuery) { - var expectedSet = paserQueryParamsList(expectedParams); + var expectedSet = parseQueryParamsList(expectedParams); var actualSet = parseQueryParamsString(actualQuery); assertEquals(expectedSet, actualSet, "Query parameters mismatch"); } @@ -67,7 +84,7 @@ private static Set parseQueryParamsString(String query) { return result; } - private static Set paserQueryParamsList(List params) { + private static Set parseQueryParamsList(List params) { Set result = new HashSet<>(); for (String paramPair : params) { var pair = paramPair.split("=", 2); @@ -106,4 +123,123 @@ private static String convertHeaderToString(String key, List values) { return value; }).collect(Collectors.joining(", ")); } + + static void assertEventHeaderEquals(String key, EventHeaderValue expected, HeaderValue actual) { + switch (expected.getType()) { + case BOOLEAN -> assertEquals(expected.asBoolean(), actual.getBoolean(), key); + case BYTE -> assertEquals(expected.asByte(), actual.getByte(), key); + case SHORT -> assertEquals(expected.asShort(), actual.getShort(), key); + case INTEGER -> assertEquals(expected.asInteger(), actual.getInteger(), key); + case LONG -> assertEquals(expected.asLong(), actual.getLong(), key); + case STRING -> assertEquals(expected.asString(), actual.getString(), key); + case BLOB -> assertArrayEquals(expected.asBlob(), actual.getByteArray(), key); + case TIMESTAMP -> assertEquals(expected.asTimestamp(), actual.getTimestamp(), key); + } + } + + static void assertEventStreamRequestEquals(HttpRequest request, Event event) { + var bodyBytes = request.body().asByteBuffer(); + var decoder = new MessageDecoder(); + decoder.feed(bodyBytes.duplicate()); + var messages = decoder.getDecodedMessages(); + var message = messages.getFirst(); + var actualHeaders = message.getHeaders(); + for (var entry : event.getHeaders().entrySet()) { + var key = entry.getKey(); + var expected = entry.getValue(); + var actual = actualHeaders.get(key); + assertThat(actual).as("Missing header: " + key).isNotNull(); + Assertions.assertEventHeaderEquals(key, expected, actual); + } + for (var header : event.getForbidHeaders()) { + assertFalse(actualHeaders.containsKey(header)); + } + for (var header : event.getRequireHeaders()) { + assertTrue(actualHeaders.containsKey(header)); + } + event.getBody().ifPresent(expectedBody -> { + assertEventStreamBodyEquals(expectedBody, + new String(message.getPayload(), StandardCharsets.UTF_8), + event.getBodyMediaType().orElse(null)); + }); + } + + static void assertInitialRequestEquals(EventStreamTestCase testCase, HttpRequest request) { + if (testCase.getInitialRequest().isPresent()) { + var initialRequest = testCase.getInitialRequest().get(); + assertEquals(initialRequest.expectStringMember("uri").getValue(), request.uri().getPath()); + assertEquals(initialRequest.expectStringMember("method").getValue(), request.method()); + initialRequest.getStringMember("resolvedHost").ifPresent(host -> { + assertEquals(host.getValue(), request.uri().getHost()); + }); + var actualQueryParams = request.uri().getQuery(); + if (actualQueryParams != null) { + assertInitialRequestQueryMatches(initialRequest, actualQueryParams); + } + assertInitialRequestHeaderMatches(initialRequest, request); + initialRequest.getStringMember("body").ifPresent(bodyNode -> { + assertEventStreamBodyEquals(bodyNode.getValue(), + new StringBuildingSubscriber(request.body()).getResult(), + initialRequest.getStringMember("bodyMediaType").map(StringNode::getValue).orElse(null)); + }); + } + } + + private static void assertInitialRequestQueryMatches(ObjectNode initialRequest, String actualQuery) { + initialRequest.getArrayMember("queryParams").ifPresent(params -> { + assertQueryParamsEquals(params.getElementsAs(StringNode::getValue), actualQuery); + }); + var queryParamSet = parseQueryParamsString(actualQuery); + initialRequest.getArrayMember("forbidQueryParams").ifPresent(params -> { + for (var param : params.getElementsAs(StringNode::getValue)) { + assertFalse(queryParamSet.contains(param)); + } + }); + + initialRequest.getArrayMember("requireQueryParams").ifPresent(params -> { + for (var param : params.getElementsAs(StringNode::getValue)) { + assertTrue(queryParamSet.contains(param)); + } + }); + } + + private static void assertInitialRequestHeaderMatches(ObjectNode initialRequest, HttpRequest actualRequest) { + var actualHeaders = actualRequest.headers().map(); + initialRequest.getObjectMember("headers").ifPresent(headersNode -> { + Map headers = new HashMap<>(); + headersNode.getStringMap().forEach((k, v) -> headers.put(k, v.expectStringNode().getValue())); + assertHeadersEqual(actualRequest, headers); + }); + + initialRequest.getArrayMember("forbidHeaders").ifPresent(headers -> { + for (var header : headers.getElementsAs(StringNode::getValue)) { + assertFalse(actualHeaders.containsKey(header)); + } + }); + + initialRequest.getArrayMember("requireHeaders").ifPresent(headers -> { + for (var header : headers.getElementsAs(StringNode::getValue)) { + assertTrue(actualHeaders.containsKey(header)); + } + }); + } + + private static void assertEventStreamBodyEquals(String expectedBody, String actualBody, String bodyType) { + if ("application/json".equals(bodyType)) { + Node.assertEquals(Node.parse(expectedBody), Node.parse(actualBody)); + } else { + assertEquals(expectedBody, actualBody); + } + } + + static void assertExpectationEquals(EventStreamTestCase testCase, Throwable e) { + testCase.getExpectation() + .getFailure() + .flatMap(TestFailureExpectation::getErrorId) + .ifPresent(errorId -> { + assertInstanceOf(ModeledException.class, e); + assertEquals(errorId.getName(), + ((ModeledException) e).schema().id().getName()); + }); + } } diff --git a/protocol-test-harness/src/main/java/software/amazon/smithy/java/protocoltests/harness/EventStreamClientTests.java b/protocol-test-harness/src/main/java/software/amazon/smithy/java/protocoltests/harness/EventStreamClientTests.java new file mode 100644 index 000000000..47f608174 --- /dev/null +++ b/protocol-test-harness/src/main/java/software/amazon/smithy/java/protocoltests/harness/EventStreamClientTests.java @@ -0,0 +1,23 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.java.protocoltests.harness; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.platform.commons.annotation.Testable; + +@Target(ElementType.METHOD) +@Retention(RetentionPolicy.RUNTIME) +@TestTemplate +@Testable +@Timeout(5) +@ExtendWith(EventStreamClientTestsProtocolTestProvider.class) +public @interface EventStreamClientTests {} diff --git a/protocol-test-harness/src/main/java/software/amazon/smithy/java/protocoltests/harness/EventStreamClientTestsProtocolTestProvider.java b/protocol-test-harness/src/main/java/software/amazon/smithy/java/protocoltests/harness/EventStreamClientTestsProtocolTestProvider.java new file mode 100644 index 000000000..0cd994c68 --- /dev/null +++ b/protocol-test-harness/src/main/java/software/amazon/smithy/java/protocoltests/harness/EventStreamClientTestsProtocolTestProvider.java @@ -0,0 +1,394 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.java.protocoltests.harness; + +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; +import static org.junit.jupiter.api.Assertions.fail; + +import java.nio.charset.StandardCharsets; +import java.util.Base64; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.function.Supplier; +import java.util.stream.Stream; +import org.junit.jupiter.api.extension.Extension; +import org.junit.jupiter.api.extension.TestTemplateInvocationContext; +import software.amazon.eventstream.HeaderValue; +import software.amazon.eventstream.Message; +import software.amazon.smithy.java.client.core.ClientTransport; +import software.amazon.smithy.java.client.core.MessageExchange; +import software.amazon.smithy.java.client.core.RequestOverrideConfig; +import software.amazon.smithy.java.client.core.auth.scheme.AuthSchemeResolver; +import software.amazon.smithy.java.client.http.HttpMessageExchange; +import software.amazon.smithy.java.context.Context; +import software.amazon.smithy.java.core.schema.ApiOperation; +import software.amazon.smithy.java.core.schema.SerializableStruct; +import software.amazon.smithy.java.core.serde.event.EventStream; +import software.amazon.smithy.java.core.serde.event.EventStreamReader; +import software.amazon.smithy.java.http.api.HttpHeaders; +import software.amazon.smithy.java.http.api.HttpRequest; +import software.amazon.smithy.java.http.api.HttpResponse; +import software.amazon.smithy.java.http.api.HttpVersion; +import software.amazon.smithy.java.io.datastream.DataStream; +import software.amazon.smithy.model.node.ObjectNode; +import software.amazon.smithy.model.node.StringNode; +import software.amazon.smithy.protocoltests.traits.eventstream.Event; +import software.amazon.smithy.protocoltests.traits.eventstream.EventHeaderValue; +import software.amazon.smithy.protocoltests.traits.eventstream.EventStreamTestCase; +import software.amazon.smithy.protocoltests.traits.eventstream.EventType; + +/** + * Provides client test cases for {@link EventStreamTestCase}'s. See also the {@link EventStreamClientTests} annotation. + */ +final class EventStreamClientTestsProtocolTestProvider extends + ProtocolTestProvider { + + @Override + protected Class getAnnotationType() { + return EventStreamClientTests.class; + } + + @Override + protected Class getSharedTestDataType() { + return ProtocolTestExtension.SharedClientTestData.class; + } + + @Override + @SuppressWarnings("unchecked") + protected Stream generateProtocolTests( + ProtocolTestExtension.SharedClientTestData store, + EventStreamClientTests annotation, + TestFilter filter + ) { + return store.operations() + .stream() + .flatMap(operation -> operation.eventStreamTestCases() + .stream() + .map(testCase -> { + if (filter.skipOperation(operation.id()) || filter.skipTestCase(testCase)) { + return new IgnoredTestCase(testCase.getId()); + } + var testProtocol = store.getProtocol(testCase.getProtocol()); + var placeholderTransport = + (MockClient.PlaceHolderTransport) store + .mockClient() + .config() + .transport(); + var overrideConfig = RequestOverrideConfig.builder() + .protocol(testProtocol) + .authSchemeResolver(AuthSchemeResolver.NO_AUTH) + .build(); + var writer = + operation.operationModel().inputStreamMember() != null ? EventStream.newWriter() + : null; + var input = buildInput(writer, + operation.operationModel(), + testCase.getInitialRequestParams()); + + if (testCase.getInitialRequest().isPresent()) { + var testTransport = new RequestTestTransport(); + placeholderTransport.setTransport(testTransport); + return new RequestTestInvocationContext( + testCase, + null, + store.mockClient(), + operation.operationModel(), + input, + null, + writer, + overrideConfig, + testTransport::getCapturedRequest); + } + + if (testCase.getInitialResponse().isPresent()) { + var testTransport = + new InitialResponseTestTransport(testCase.getInitialResponse().get()); + placeholderTransport.setTransport(testTransport); + var outputBuilder = operation.operationModel().outputBuilder(); + testCase.getInitialResponseParams() + .ifPresent(params -> new ProtocolTestDocument(params, null) + .deserializeInto(outputBuilder)); + return new ResponseTestInvocationContext( + testCase, + null, + store.mockClient(), + operation.operationModel(), + input, + outputBuilder.errorCorrection().build(), + writer, + overrideConfig); + } + + var event = testCase.getEvents().getFirst(); // Currently each test case only has one event. + if (event.getType().equals(EventType.REQUEST)) { + var testTransport = new RequestTestTransport(); + placeholderTransport.setTransport(testTransport); + var eventBuilder = operation.operationModel().inputEventBuilderSupplier().get(); + event.getParams() + .ifPresent(params -> new ProtocolTestDocument(params, null) + .deserializeInto(eventBuilder)); + return new RequestTestInvocationContext( + testCase, + event, + store.mockClient(), + operation.operationModel(), + input, + eventBuilder.build(), + writer, + overrideConfig, + testTransport::getCapturedRequest); + } else { + SerializableStruct expectedEvent = null; + if (event.getParams().isPresent()) { + var eventBuilder = operation.operationModel().outputEventBuilderSupplier().get(); + new ProtocolTestDocument(event.getParams().get(), null) + .deserializeInto(eventBuilder); + expectedEvent = eventBuilder.build(); + } + var testTransport = new ResponseTestTransport(event); + placeholderTransport.setTransport(testTransport); + return new ResponseTestInvocationContext( + testCase, + event, + store.mockClient(), + operation.operationModel(), + input, + expectedEvent, + writer, + overrideConfig); + } + })); + } + + private record RequestTestInvocationContext( + EventStreamTestCase testCase, + Event event, + MockClient mockClient, + ApiOperation apiOperation, + SerializableStruct input, + SerializableStruct expected, + EventStream writer, + RequestOverrideConfig overrideConfig, + Supplier requestSupplier) implements TestTemplateInvocationContext { + + @Override + public String getDisplayName(int invocationIndex) { + return testCase.getId(); + } + + @Override + public List getAdditionalExtensions() { + return List.of((ProtocolTestParameterResolver) () -> { + if (event != null) { // normal request event. + Thread.ofVirtual().start(() -> { + try (var w = writer.asWriter()) { + w.write(expected); + } + }); + } + try { + mockClient.clientRequest(input, apiOperation, overrideConfig); + var request = requestSupplier.get(); + if (event != null) { + Assertions.assertEventStreamRequestEquals(request, event); + } else { + Assertions.assertInitialRequestEquals(testCase, request); + } + } finally { + writer.close(); + } + }); + } + } + + private record ResponseTestInvocationContext( + EventStreamTestCase testCase, + Event event, + MockClient mockClient, + ApiOperation apiOperation, + SerializableStruct input, + SerializableStruct expected, + EventStream writer, + RequestOverrideConfig overrideConfig) implements TestTemplateInvocationContext { + + @Override + public String getDisplayName(int invocationIndex) { + return testCase.getId(); + } + + @Override + public List getAdditionalExtensions() { + return List.of((ProtocolTestParameterResolver) () -> { + try { + var output = mockClient.clientRequest(input, apiOperation, overrideConfig); + var actual = output; + if (event != null) { // Normal response event + EventStreamReader reader = + output.getMemberValue(apiOperation.outputStreamMember()); + actual = reader.read(); + } + if (testCase.getExpectation().isFailure()) { + fail("Expected failure but got: " + actual); + } + // Ignore stream field for initial response comparison. + assertThat(actual) + .usingRecursiveComparison(ComparisonUtils.getComparisonConfig()) + .ignoringFields("stream") + .isEqualTo(expected); + } catch (Exception e) { + if (testCase.getExpectation().isFailure()) { + Assertions.assertExpectationEquals(testCase, e); + return; + } + throw e; + } finally { + if (writer != null) { + writer.close(); + } + } + }); + } + } + + private static final class RequestTestTransport implements ClientTransport { + private static final HttpResponse DUMMY_RESPONSE = HttpResponse.create() + .setStatusCode(555) + .toUnmodifiable(); + + private HttpRequest capturedRequest; + + public HttpRequest getCapturedRequest() { + return capturedRequest; + } + + @Override + public HttpResponse send(Context context, HttpRequest request) { + this.capturedRequest = request; + return DUMMY_RESPONSE; + } + + @Override + public MessageExchange messageExchange() { + return HttpMessageExchange.INSTANCE; + } + } + + private record ResponseTestTransport(Event event) implements ClientTransport { + + private static byte[] buildFrameBytes(Event event) { + var headers = new HashMap(); + for (var entry : event.getHeaders().entrySet()) { + headers.put(entry.getKey(), toHeaderValue(entry.getValue())); + } + byte[] payload = event.getBody() + .map(b -> decodeBody(b, event.getBodyMediaType())) + .orElse(new byte[0]); + var message = new Message(headers, payload); + var buf = message.toByteBuffer(); + var bytes = new byte[buf.remaining()]; + buf.get(bytes); + return bytes; + } + + private static HeaderValue toHeaderValue(EventHeaderValue value) { + return switch (value.getType()) { + case BOOLEAN -> + HeaderValue.fromBoolean(value.asBoolean()); + case BYTE -> HeaderValue.fromByte(value.asByte()); + case SHORT -> HeaderValue.fromShort(value.asShort()); + case INTEGER -> + HeaderValue.fromInteger(value.asInteger()); + case LONG -> HeaderValue.fromLong(value.asLong()); + case BLOB -> HeaderValue.fromByteArray(value.asBlob()); + case STRING -> HeaderValue.fromString(value.asString()); + case TIMESTAMP -> + HeaderValue.fromTimestamp(value.asTimestamp()); + }; + } + + @Override + public HttpResponse send(Context context, HttpRequest request) { + byte[] frameBytes; + if (event.getBytes().isPresent()) { + frameBytes = event.getBytes().get(); + } else { + frameBytes = buildFrameBytes(event); + } + + return HttpResponse.create() + .setHttpVersion(HttpVersion.HTTP_1_1) + .setStatusCode(200) + .setHeaders(HttpHeaders.of(Map.of( + "content-type", + List.of("application/vnd.amazon.eventstream")))) + .setBody(DataStream.ofBytes(frameBytes, "application/vnd.amazon.eventstream")) + .toUnmodifiable(); + } + + @Override + public MessageExchange messageExchange() { + return HttpMessageExchange.INSTANCE; + } + } + + private record InitialResponseTestTransport(ObjectNode response) + implements ClientTransport { + + @Override + public HttpResponse send(Context context, HttpRequest request) { + var builder = HttpResponse.create() + .setHttpVersion(HttpVersion.HTTP_1_1) + .setStatusCode(response.expectNumberMember("code").getValue().intValue()); + + response.getObjectMember("headers").ifPresent(headers -> { + Map> headerMap = new HashMap<>(); + for (var headerEntry : headers.getMembers().entrySet()) { + headerMap.put(headerEntry.getKey().getValue(), + List.of(headerEntry.getValue().expectStringNode().getValue())); + } + response.getStringMember("bodyMediaType") + .ifPresent(mediaType -> headerMap.put("content-type", List.of(mediaType.getValue()))); + builder.setHeaders(HttpHeaders.of(headerMap)); + }); + + response.getStringMember("body").ifPresent(body -> { + var mediaType = response.getStringMember("bodyMediaType") + .map(StringNode::getValue); + builder.setBody(DataStream.ofBytes( + decodeBody(body.getValue(), mediaType), + mediaType.orElse(null))); + }); + return builder.toUnmodifiable(); + } + + @Override + public MessageExchange messageExchange() { + return HttpMessageExchange.INSTANCE; + } + } + + private static byte[] decodeBody(String body, Optional mediaType) { + return mediaType + .filter(ProtocolTestProvider::isBinaryMediaType) + .map(type -> Base64.getDecoder().decode(body)) + .orElseGet(() -> body.getBytes(StandardCharsets.UTF_8)); + } + + private static SerializableStruct buildInput( + EventStream writer, + ApiOperation operation, + Optional params + ) { + var inputBuilder = operation.inputBuilder(); + if (writer != null) { + inputBuilder.setMemberValue(operation.inputStreamMember(), writer); + } + params.ifPresent(p -> new ProtocolTestDocument(p, null).deserializeInto(inputBuilder)); + return inputBuilder.errorCorrection().build(); + } +} diff --git a/protocol-test-harness/src/main/java/software/amazon/smithy/java/protocoltests/harness/HttpClientRequestProtocolTestProvider.java b/protocol-test-harness/src/main/java/software/amazon/smithy/java/protocoltests/harness/HttpClientRequestProtocolTestProvider.java index 3cf4bd6b8..cfaf8541f 100644 --- a/protocol-test-harness/src/main/java/software/amazon/smithy/java/protocoltests/harness/HttpClientRequestProtocolTestProvider.java +++ b/protocol-test-harness/src/main/java/software/amazon/smithy/java/protocoltests/harness/HttpClientRequestProtocolTestProvider.java @@ -62,7 +62,7 @@ protected Stream generateProtocolTests( .stream() .map(testCase -> { if (filter.skipOperation(operation.id()) || filter.skipTestCase(testCase)) { - return new IgnoredTestCase(testCase); + return new IgnoredTestCase(testCase.getId()); } var testProtocol = store.getProtocol(testCase.getProtocol()); var testResolver = testCase.getAuthScheme().isEmpty() diff --git a/protocol-test-harness/src/main/java/software/amazon/smithy/java/protocoltests/harness/HttpClientResponseProtocolTestProvider.java b/protocol-test-harness/src/main/java/software/amazon/smithy/java/protocoltests/harness/HttpClientResponseProtocolTestProvider.java index e5e685b47..8065e8f5f 100644 --- a/protocol-test-harness/src/main/java/software/amazon/smithy/java/protocoltests/harness/HttpClientResponseProtocolTestProvider.java +++ b/protocol-test-harness/src/main/java/software/amazon/smithy/java/protocoltests/harness/HttpClientResponseProtocolTestProvider.java @@ -64,7 +64,7 @@ protected Stream generateProtocolTests( .map(protocolTestCase -> { var testCase = protocolTestCase.responseTestCase(); if (filter.skipOperation(operation.id()) || filter.skipTestCase(testCase)) { - return new IgnoredTestCase(testCase); + return new IgnoredTestCase(testCase.getId()); } boolean isErrorTestCase = protocolTestCase.isErrorTest(); // Get specific values to use for this test case's context diff --git a/protocol-test-harness/src/main/java/software/amazon/smithy/java/protocoltests/harness/HttpServerRequestProtocolTestProvider.java b/protocol-test-harness/src/main/java/software/amazon/smithy/java/protocoltests/harness/HttpServerRequestProtocolTestProvider.java index 9e83b4ce9..17d833294 100644 --- a/protocol-test-harness/src/main/java/software/amazon/smithy/java/protocoltests/harness/HttpServerRequestProtocolTestProvider.java +++ b/protocol-test-harness/src/main/java/software/amazon/smithy/java/protocoltests/harness/HttpServerRequestProtocolTestProvider.java @@ -55,7 +55,7 @@ protected Stream generateProtocolTests( boolean shouldSkip = testFilter.skipOperation(testOperation.id()); for (var testCase : testOperation.requestTestCases()) { if (shouldSkip || testFilter.skipTestCase(testCase)) { - invocationContexts.add(new IgnoredTestCase(testCase)); + invocationContexts.add(new IgnoredTestCase(testCase.getId())); continue; } var createUri = createUri(testData.endpoint(), testCase.getUri(), testCase.getQueryParams()); diff --git a/protocol-test-harness/src/main/java/software/amazon/smithy/java/protocoltests/harness/HttpServerResponseProtocolTestProvider.java b/protocol-test-harness/src/main/java/software/amazon/smithy/java/protocoltests/harness/HttpServerResponseProtocolTestProvider.java index 3c1b2b777..fa1d4077f 100644 --- a/protocol-test-harness/src/main/java/software/amazon/smithy/java/protocoltests/harness/HttpServerResponseProtocolTestProvider.java +++ b/protocol-test-harness/src/main/java/software/amazon/smithy/java/protocoltests/harness/HttpServerResponseProtocolTestProvider.java @@ -55,7 +55,7 @@ protected Stream generateProtocolTests( for (var protocolTestCase : testOperation.responseTestCases()) { var testCase = protocolTestCase.responseTestCase(); if (shouldSkip || filter.skipTestCase(testCase)) { - invocationContexts.add(new IgnoredTestCase(testCase)); + invocationContexts.add(new IgnoredTestCase(testCase.getId())); continue; } Map> headers = new HashMap<>(); diff --git a/protocol-test-harness/src/main/java/software/amazon/smithy/java/protocoltests/harness/HttpTestOperation.java b/protocol-test-harness/src/main/java/software/amazon/smithy/java/protocoltests/harness/HttpTestOperation.java index f6e6bbdf0..8c3faca62 100644 --- a/protocol-test-harness/src/main/java/software/amazon/smithy/java/protocoltests/harness/HttpTestOperation.java +++ b/protocol-test-harness/src/main/java/software/amazon/smithy/java/protocoltests/harness/HttpTestOperation.java @@ -10,6 +10,7 @@ import software.amazon.smithy.model.shapes.ShapeId; import software.amazon.smithy.protocoltests.traits.HttpMalformedRequestTestCase; import software.amazon.smithy.protocoltests.traits.HttpRequestTestCase; +import software.amazon.smithy.protocoltests.traits.eventstream.EventStreamTestCase; /** * Data class holding information needed to execute a protocol test for a given operation. @@ -27,4 +28,5 @@ record HttpTestOperation( ApiOperation operationModel, List requestTestCases, List responseTestCases, - List malformedRequestTestCases) {} + List malformedRequestTestCases, + List eventStreamTestCases) {} diff --git a/protocol-test-harness/src/main/java/software/amazon/smithy/java/protocoltests/harness/IgnoredTestCase.java b/protocol-test-harness/src/main/java/software/amazon/smithy/java/protocoltests/harness/IgnoredTestCase.java index 3f0b9caa7..afa96db6f 100644 --- a/protocol-test-harness/src/main/java/software/amazon/smithy/java/protocoltests/harness/IgnoredTestCase.java +++ b/protocol-test-harness/src/main/java/software/amazon/smithy/java/protocoltests/harness/IgnoredTestCase.java @@ -10,13 +10,12 @@ import org.junit.jupiter.api.extension.ExecutionCondition; import org.junit.jupiter.api.extension.Extension; import org.junit.jupiter.api.extension.TestTemplateInvocationContext; -import software.amazon.smithy.protocoltests.traits.HttpMessageTestCase; -record IgnoredTestCase(HttpMessageTestCase testCase) implements TestTemplateInvocationContext { +record IgnoredTestCase(String testId) implements TestTemplateInvocationContext { @Override public String getDisplayName(int invocationIndex) { - return testCase.getId(); + return testId; } @Override diff --git a/protocol-test-harness/src/main/java/software/amazon/smithy/java/protocoltests/harness/ProtocolTestDocument.java b/protocol-test-harness/src/main/java/software/amazon/smithy/java/protocoltests/harness/ProtocolTestDocument.java index 14f4572d2..45054cc82 100644 --- a/protocol-test-harness/src/main/java/software/amazon/smithy/java/protocoltests/harness/ProtocolTestDocument.java +++ b/protocol-test-harness/src/main/java/software/amazon/smithy/java/protocoltests/harness/ProtocolTestDocument.java @@ -190,6 +190,8 @@ public Instant asTimestamp() { return node.asNumberNode() .map(NumberNode::getValue) .map(ProtocolTestDocument::readNumberTimestamp) + .or(() -> node.asStringNode() + .map(s -> Instant.parse(s.getValue()))) .orElseGet(Document.super::asTimestamp); } diff --git a/protocol-test-harness/src/main/java/software/amazon/smithy/java/protocoltests/harness/ProtocolTestExtension.java b/protocol-test-harness/src/main/java/software/amazon/smithy/java/protocoltests/harness/ProtocolTestExtension.java index ee20d23ba..2dfb2a127 100644 --- a/protocol-test-harness/src/main/java/software/amazon/smithy/java/protocoltests/harness/ProtocolTestExtension.java +++ b/protocol-test-harness/src/main/java/software/amazon/smithy/java/protocoltests/harness/ProtocolTestExtension.java @@ -46,6 +46,8 @@ import software.amazon.smithy.protocoltests.traits.HttpRequestTestCase; import software.amazon.smithy.protocoltests.traits.HttpRequestTestsTrait; import software.amazon.smithy.protocoltests.traits.HttpResponseTestsTrait; +import software.amazon.smithy.protocoltests.traits.eventstream.EventStreamTestCase; +import software.amazon.smithy.protocoltests.traits.eventstream.EventStreamTestsTrait; import software.amazon.smithy.utils.SmithyInternalApi; /** @@ -301,6 +303,10 @@ private static List getTestOperations( .map(ap -> ap.equals(testType.appliesTo)) .orElse(true); + Predicate eventStreamTestTypeFiler = tc -> tc.getAppliesTo() + .map(ap -> ap.equals(testType.appliesTo)) + .orElse(true); + var symbolProvider = new JavaSymbolProvider( serviceModel, service, @@ -315,6 +321,7 @@ private static List getTestOperations( List requestTestsCases = new ArrayList<>(); List responseTestsCases = new ArrayList<>(); List malformedRequestTestCases = new ArrayList<>(); + List eventStreamTestCases = new ArrayList<>(); operation.getTrait(HttpRequestTestsTrait.class) .map(HttpRequestTestsTrait::getTestCases) .map(l -> l.stream().filter(testTypeFiler).toList()) @@ -347,6 +354,11 @@ private static List getTestOperations( operation.getTrait(HttpMalformedRequestTestsTrait.class) .map(HttpMalformedRequestTestsTrait::getTestCases) .ifPresent(malformedRequestTestCases::addAll); + + operation.getTrait(EventStreamTestsTrait.class) + .map(EventStreamTestsTrait::getTestCases) + .map(l -> l.stream().filter(eventStreamTestTypeFiler).toList()) + .ifPresent(eventStreamTestCases::addAll); result.add( new HttpTestOperation( operationId.toShapeId(), @@ -354,7 +366,8 @@ private static List getTestOperations( apiOperation, requestTestsCases, responseTestsCases, - malformedRequestTestCases)); + malformedRequestTestCases, + eventStreamTestCases)); } } diff --git a/protocol-test-harness/src/main/java/software/amazon/smithy/java/protocoltests/harness/TestFilter.java b/protocol-test-harness/src/main/java/software/amazon/smithy/java/protocoltests/harness/TestFilter.java index 0bec08557..db4f2f604 100644 --- a/protocol-test-harness/src/main/java/software/amazon/smithy/java/protocoltests/harness/TestFilter.java +++ b/protocol-test-harness/src/main/java/software/amazon/smithy/java/protocoltests/harness/TestFilter.java @@ -11,6 +11,7 @@ import java.util.stream.Collectors; import software.amazon.smithy.model.shapes.ShapeId; import software.amazon.smithy.protocoltests.traits.HttpMessageTestCase; +import software.amazon.smithy.protocoltests.traits.eventstream.EventStreamTestCase; /** * Test filter class that implements filtering of protocol tests. This filter is configured by the {@link ProtocolTestFilter} annotation. @@ -28,6 +29,11 @@ sealed interface TestFilter { */ boolean skipTestCase(HttpMessageTestCase testCase); + /** + * Filters event stream test cases. + */ + boolean skipTestCase(EventStreamTestCase testCase); + default TestFilter combine(TestFilter other) { return new CombinedTestFilter(this, other); } @@ -75,6 +81,12 @@ public boolean skipTestCase(HttpMessageTestCase testCase) { return skippedTests.contains(testCase.getId()) || (!tests.isEmpty() && !tests.contains(testCase.getId())); } + + @Override + public boolean skipTestCase(EventStreamTestCase testCase) { + return skippedTests.contains(testCase.getId()) + || (!tests.isEmpty() && !tests.contains(testCase.getId())); + } } final class EmptyFilter implements TestFilter { @@ -88,6 +100,11 @@ public boolean skipOperation(ShapeId operationId) { public boolean skipTestCase(HttpMessageTestCase testCase) { return false; } + + @Override + public boolean skipTestCase(EventStreamTestCase testCase) { + return false; + } } final class CombinedTestFilter implements TestFilter { @@ -109,5 +126,10 @@ public boolean skipOperation(ShapeId operationId) { public boolean skipTestCase(HttpMessageTestCase testCase) { return first.skipTestCase(testCase) || second.skipTestCase(testCase); } + + @Override + public boolean skipTestCase(EventStreamTestCase testCase) { + return first.skipTestCase(testCase) || second.skipTestCase(testCase); + } } }