diff --git a/aws/client/aws-client-awsjson/src/it/java/software/amazon/smithy/java/client/aws/jsonprotocols/AwsJson1ProtocolTests.java b/aws/client/aws-client-awsjson/src/it/java/software/amazon/smithy/java/client/aws/jsonprotocols/AwsJson1ProtocolTests.java index 36d4cf867..e4f302a00 100644 --- a/aws/client/aws-client-awsjson/src/it/java/software/amazon/smithy/java/client/aws/jsonprotocols/AwsJson1ProtocolTests.java +++ b/aws/client/aws-client-awsjson/src/it/java/software/amazon/smithy/java/client/aws/jsonprotocols/AwsJson1ProtocolTests.java @@ -20,10 +20,6 @@ public class AwsJson1ProtocolTests { @HttpClientRequestTests @ProtocolTestFilter( skipTests = { - // TODO: implement content-encoding - "SDKAppliedContentEncoding_awsJson1_0", - "SDKAppendsGzipAndIgnoresHttpProvidedEncoding_awsJson1_0", - // Skipping top-level input defaults isn't necessary in Smithy-Java given it uses builders and // the defaults don't impact nullability. This applies to the following tests. "AwsJson10ClientSkipsTopLevelDefaultValuesInInput", @@ -42,8 +38,9 @@ public void requestTest(DataStream expected, DataStream actual) { Node.parse(new String(ByteBufferUtils.getBytes(expected.asByteBuffer()), StandardCharsets.UTF_8))); } - assertEquals(expectedJson, new StringBuildingSubscriber(actual).getResult()); - + if (expected.contentType() != null) { // Skip request compression tests since they do not have expected body + assertEquals(expectedJson, new StringBuildingSubscriber(actual).getResult()); + } } @HttpClientResponseTests diff --git a/aws/client/aws-client-restjson/src/it/java/software/amazon/smithy/java/client/aws/restjson/RestJson1ProtocolTests.java b/aws/client/aws-client-restjson/src/it/java/software/amazon/smithy/java/client/aws/restjson/RestJson1ProtocolTests.java index e5dc54b93..9aa3023b9 100644 --- a/aws/client/aws-client-restjson/src/it/java/software/amazon/smithy/java/client/aws/restjson/RestJson1ProtocolTests.java +++ 
b/aws/client/aws-client-restjson/src/it/java/software/amazon/smithy/java/client/aws/restjson/RestJson1ProtocolTests.java @@ -26,8 +26,6 @@ skipOperations = { // We dont ignore defaults on input shapes "aws.protocoltests.restjson#OperationWithDefaults", - // TODO: support content-encoding - "aws.protocoltests.restjson#PutWithContentEncoding" }) public class RestJson1ProtocolTests { private static final String EMPTY_BODY = ""; @@ -50,7 +48,7 @@ public void requestTest(DataStream expected, DataStream actual) { } else { assertEquals(expectedStr, actualStr); } - } else { + } else if (expected.contentType() != null) { // Skip request compression tests since they do not have expected body assertEquals(EMPTY_BODY, actualStr); } } diff --git a/aws/client/aws-client-restxml/src/it/java/software/amazon/smithy/java/aws/client/restxml/RestXmlProtocolTests.java b/aws/client/aws-client-restxml/src/it/java/software/amazon/smithy/java/aws/client/restxml/RestXmlProtocolTests.java index 021e0ba12..9e45c3dea 100644 --- a/aws/client/aws-client-restxml/src/it/java/software/amazon/smithy/java/aws/client/restxml/RestXmlProtocolTests.java +++ b/aws/client/aws-client-restxml/src/it/java/software/amazon/smithy/java/aws/client/restxml/RestXmlProtocolTests.java @@ -25,7 +25,6 @@ import software.amazon.smithy.java.protocoltests.harness.HttpClientRequestTests; import software.amazon.smithy.java.protocoltests.harness.HttpClientResponseTests; import software.amazon.smithy.java.protocoltests.harness.ProtocolTest; -import software.amazon.smithy.java.protocoltests.harness.ProtocolTestFilter; import software.amazon.smithy.java.protocoltests.harness.StringBuildingSubscriber; import software.amazon.smithy.java.protocoltests.harness.TestType; @@ -34,11 +33,6 @@ testType = TestType.CLIENT) public class RestXmlProtocolTests { @HttpClientRequestTests - @ProtocolTestFilter( - skipTests = { - "SDKAppliedContentEncoding_restXml", - "SDKAppendedGzipAfterProvidedEncoding_restXml", - }) public void 
requestTest(DataStream expected, DataStream actual) { if (expected.contentLength() != 0) { var a = new String(ByteBufferUtils.getBytes(actual.asByteBuffer()), StandardCharsets.UTF_8); @@ -51,7 +45,7 @@ public void requestTest(DataStream expected, DataStream actual) { } else { assertEquals(a, b); } - } else { + } else if (expected.contentType() != null) { // Skip request compression tests since they do not have expected body assertEquals("", new StringBuildingSubscriber(actual).getResult()); } } diff --git a/client/client-core/src/test/java/software/amazon/smithy/java/client/core/ClientTest.java b/client/client-core/src/test/java/software/amazon/smithy/java/client/core/ClientTest.java index e971b6d61..266ae3d15 100644 --- a/client/client-core/src/test/java/software/amazon/smithy/java/client/core/ClientTest.java +++ b/client/client-core/src/test/java/software/amazon/smithy/java/client/core/ClientTest.java @@ -35,6 +35,7 @@ import software.amazon.smithy.java.client.http.mock.MockQueue; import software.amazon.smithy.java.client.http.plugins.ApplyHttpRetryInfoPlugin; import software.amazon.smithy.java.client.http.plugins.HttpChecksumPlugin; +import software.amazon.smithy.java.client.http.plugins.RequestCompressionPlugin; import software.amazon.smithy.java.client.http.plugins.UserAgentPlugin; import software.amazon.smithy.java.core.serde.document.Document; import software.amazon.smithy.java.dynamicclient.DynamicClient; @@ -83,6 +84,7 @@ public class ClientTest { SimpleAuthDetectionPlugin.class, UserAgentPlugin.class, ApplyHttpRetryInfoPlugin.class, + RequestCompressionPlugin.class, HttpChecksumPlugin.class, FooPlugin.class); diff --git a/client/client-http/src/main/java/software/amazon/smithy/java/client/http/HttpContext.java b/client/client-http/src/main/java/software/amazon/smithy/java/client/http/HttpContext.java index 845e78a11..63689e16d 100644 --- a/client/client-http/src/main/java/software/amazon/smithy/java/client/http/HttpContext.java +++ 
b/client/client-http/src/main/java/software/amazon/smithy/java/client/http/HttpContext.java @@ -27,5 +27,17 @@ public final class HttpContext { public static final Context.Key ENDPOINT_RESOLVER_HTTP_HEADERS = Context.key( "HTTP headers to use with the request returned from an endpoint resolver"); + /** + * The minimum request body size, in bytes, required for compression to be applied. Defaults to 10240 bytes if not set. + */ + public static final Context.Key REQUEST_MIN_COMPRESSION_SIZE_BYTES = + Context.key("Minimum bytes size for request compression"); + + /** + * Whether request compression is disabled. + */ + public static final Context.Key DISABLE_REQUEST_COMPRESSION = + Context.key("If request compression is disabled"); + private HttpContext() {} } diff --git a/client/client-http/src/main/java/software/amazon/smithy/java/client/http/compression/CompressionAlgorithm.java b/client/client-http/src/main/java/software/amazon/smithy/java/client/http/compression/CompressionAlgorithm.java new file mode 100644 index 000000000..02d3a72e9 --- /dev/null +++ b/client/client-http/src/main/java/software/amazon/smithy/java/client/http/compression/CompressionAlgorithm.java @@ -0,0 +1,31 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.java.client.http.compression; + +import java.util.List; +import software.amazon.smithy.java.io.datastream.DataStream; +import software.amazon.smithy.utils.ListUtils; + +/** + * Represents a compression algorithm that can be used to compress request + * bodies. + */ +public interface CompressionAlgorithm { + /** + * The ID of the compression algorithm. This is matched against the algorithm + * names used in the trait e.g. 
"gzip" + */ + String algorithmId(); + + /** + * Compresses content of fixed length + */ + DataStream compress(DataStream data); + + static List supportedAlgorithms() { + return ListUtils.of(new Gzip()); + } +} diff --git a/client/client-http/src/main/java/software/amazon/smithy/java/client/http/compression/Gzip.java b/client/client-http/src/main/java/software/amazon/smithy/java/client/http/compression/Gzip.java new file mode 100644 index 000000000..c650e8357 --- /dev/null +++ b/client/client-http/src/main/java/software/amazon/smithy/java/client/http/compression/Gzip.java @@ -0,0 +1,40 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.java.client.http.compression; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.zip.GZIPOutputStream; +import software.amazon.smithy.java.io.ByteBufferOutputStream; +import software.amazon.smithy.java.io.datastream.DataStream; + +public final class Gzip implements CompressionAlgorithm { + + @Override + public String algorithmId() { + return "gzip"; + } + + @Override + public DataStream compress(DataStream data) { + if (!data.hasKnownLength()) { // Using streaming + return DataStream.ofInputStream( + new GzipCompressingInputStream(data.asInputStream()), + data.contentType(), + -1); + } + + try (var bos = new ByteBufferOutputStream(); + var in = data.asInputStream()) { + var gzip = new GZIPOutputStream(bos); + in.transferTo(gzip); + gzip.close(); + return DataStream.ofBytes(bos.toByteBuffer().array()); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } +} diff --git a/client/client-http/src/main/java/software/amazon/smithy/java/client/http/compression/GzipCompressingInputStream.java b/client/client-http/src/main/java/software/amazon/smithy/java/client/http/compression/GzipCompressingInputStream.java new file mode 100644 index 000000000..1d90282b5 --- /dev/null +++ 
b/client/client-http/src/main/java/software/amazon/smithy/java/client/http/compression/GzipCompressingInputStream.java @@ -0,0 +1,148 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.java.client.http.compression; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.UncheckedIOException; +import java.util.zip.GZIPOutputStream; + +/** + * An InputStream that compresses data from a source InputStream using GZIP compression. + * This implementation lazily compresses the source data on-demand as it's read. + */ +final class GzipCompressingInputStream extends InputStream { + private static final int CHUNK_SIZE = 8192; + private final InputStream source; + private final ByteArrayOutputStream bufferStream; + private final GZIPOutputStream gzipStream; + private final byte[] chunk = new byte[CHUNK_SIZE]; + private byte[] buffer; + private int bufferPos; + private int bufferLimit; + private boolean sourceExhausted; + private boolean closed; + + public GzipCompressingInputStream(InputStream source) { + this.source = source; + this.bufferStream = new ByteArrayOutputStream(); + this.gzipStream = createGzipOutputStream(bufferStream); + this.buffer = new byte[0]; + this.bufferPos = 0; + this.bufferLimit = 0; + this.sourceExhausted = false; + this.closed = false; + } + + @Override + public int read() throws IOException { + byte[] b = new byte[1]; + int result = read(b, 0, 1); + return result == -1 ? 
-1 : (b[0] & 0xFF); + } + + @Override + public int read(byte[] b, int off, int len) throws IOException { + if (closed) { + throw new IOException("Stream closed"); + } + + if (b == null) { + throw new NullPointerException("b"); + } else if (off < 0 || len < 0 || len > b.length - off) { + throw new IndexOutOfBoundsException(); + } else if (len == 0) { + return 0; + } + + // Try to fill the output buffer if it's empty + while (bufferPos >= bufferLimit) { + if (!fillBuffer()) { + return -1; // End of stream + } + } + + // Copy available data from buffer + int available = bufferLimit - bufferPos; + int toRead = Math.min(available, len); + System.arraycopy(buffer, bufferPos, b, off, toRead); + bufferPos += toRead; + + return toRead; + } + + /** + * Reads a chunk from the source, compresses it, and fills the internal buffer. + * + * @return true if data was added to buffer, false if end of stream reached + */ + private boolean fillBuffer() throws IOException { + if (sourceExhausted) { + return false; + } + + // Read a chunk from source + int bytesRead = source.read(chunk); + + if (bytesRead == -1) { + // Source is exhausted, finish compression + gzipStream.finish(); + sourceExhausted = true; + } else { + // Compress the chunk + gzipStream.write(chunk, 0, bytesRead); + gzipStream.flush(); + } + + // Get compressed data from buffer stream + byte[] compressed = bufferStream.toByteArray(); + if (compressed.length > 0) { + buffer = compressed; + bufferPos = 0; + bufferLimit = compressed.length; + bufferStream.reset(); + return true; + } + + if (sourceExhausted) { + return bufferPos >= bufferLimit; + } + return true; + } + + @Override + public void close() throws IOException { + if (!closed) { + closed = true; + try { + gzipStream.close(); + } finally { + source.close(); + } + } + } + + @Override + public int available() throws IOException { + if (closed) { + throw new IOException("Stream closed"); + } + return bufferLimit - bufferPos; + } + + /** + * Utility method to avoid 
having to throw the checked IOException. + */ + private GZIPOutputStream createGzipOutputStream(OutputStream bufferStream) { + try { + return new GZIPOutputStream(bufferStream); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } +} diff --git a/client/client-http/src/main/java/software/amazon/smithy/java/client/http/plugins/RequestCompressionPlugin.java b/client/client-http/src/main/java/software/amazon/smithy/java/client/http/plugins/RequestCompressionPlugin.java new file mode 100644 index 000000000..51ad5f869 --- /dev/null +++ b/client/client-http/src/main/java/software/amazon/smithy/java/client/http/plugins/RequestCompressionPlugin.java @@ -0,0 +1,103 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.java.client.http.plugins; + +import java.util.List; +import software.amazon.smithy.java.client.core.AutoClientPlugin; +import software.amazon.smithy.java.client.core.ClientConfig; +import software.amazon.smithy.java.client.core.interceptors.ClientInterceptor; +import software.amazon.smithy.java.client.core.interceptors.RequestHook; +import software.amazon.smithy.java.client.http.HttpContext; +import software.amazon.smithy.java.client.http.HttpMessageExchange; +import software.amazon.smithy.java.client.http.compression.CompressionAlgorithm; +import software.amazon.smithy.java.context.Context; +import software.amazon.smithy.java.core.schema.TraitKey; +import software.amazon.smithy.java.http.api.HttpRequest; +import software.amazon.smithy.java.io.datastream.DataStream; +import software.amazon.smithy.model.traits.RequestCompressionTrait; +import software.amazon.smithy.utils.SmithyInternalApi; + +/** + * Compress the request body using the provided compression algorithm if the @requestCompression trait is applied. 
+ */ +@SmithyInternalApi +public final class RequestCompressionPlugin implements AutoClientPlugin { + + @Override + public void configureClient(ClientConfig.Builder config) { + if (config.isUsingMessageExchange(HttpMessageExchange.INSTANCE)) { + config.addInterceptor(RequestCompressionInterceptor.INSTANCE); + } + } + + static final class RequestCompressionInterceptor implements ClientInterceptor { + + private static final int DEFAULT_MIN_COMPRESSION_SIZE_BYTES = 10240; + // This cap matches ApiGateway's spec: https://docs.aws.amazon.com/apigateway/latest/developerguide/api-gateway-openapi-minimum-compression-size.html + private static final int MIN_COMPRESSION_SIZE_CAP = 10485760; + private static final String CONTENT_ENCODING_HEADER = "Content-Encoding"; + private static final ClientInterceptor INSTANCE = new RequestCompressionInterceptor(); + private static final TraitKey REQUEST_COMPRESSION_TRAIT_KEY = + TraitKey.get(RequestCompressionTrait.class); + // Currently only Gzip is supported in Smithy model: https://smithy.io/2.0/spec/behavior-traits.html#requestcompression-trait + private static final List supportedAlgorithms = + CompressionAlgorithm.supportedAlgorithms(); + + @Override + public RequestT modifyBeforeTransmit(RequestHook hook) { + return hook.mapRequest(HttpRequest.class, RequestCompressionInterceptor::processRequest); + } + + private static HttpRequest processRequest(RequestHook hook) { + if (shouldCompress(hook)) { + var compressionTrait = + hook.operation().schema().getTrait(REQUEST_COMPRESSION_TRAIT_KEY); + var request = hook.request(); + // Will pick the first supported algorithm to compress the body. 
+ for (var algorithmId : compressionTrait.getEncodings()) { + for (var algorithm : supportedAlgorithms) { + if (algorithmId.equals(algorithm.algorithmId())) { + var compressed = algorithm.compress(request.body()); + return request.toBuilder() + .body(compressed) + .withAddedHeader(CONTENT_ENCODING_HEADER, algorithmId) + .build(); + } + } + } + } + return hook.request(); + } + + private static boolean shouldCompress(RequestHook hook) { + var context = hook.context(); + var operation = hook.operation(); + if (!operation.schema().hasTrait(REQUEST_COMPRESSION_TRAIT_KEY) + || context.getOrDefault(HttpContext.DISABLE_REQUEST_COMPRESSION, false)) { + return false; + } + var requestBody = hook.request().body(); + // Streaming should not have known length + if (operation.inputStreamMember() != null && !requestBody.hasKnownLength()) { + return true; + } + return isBodySizeValid(requestBody, context); + } + + private static boolean isBodySizeValid(DataStream requestBody, Context context) { + var minCompressionSize = context.getOrDefault(HttpContext.REQUEST_MIN_COMPRESSION_SIZE_BYTES, + DEFAULT_MIN_COMPRESSION_SIZE_BYTES); + validateCompressionSize(minCompressionSize); + return requestBody.contentLength() >= minCompressionSize; + } + + private static void validateCompressionSize(int minCompressionSize) { + if (minCompressionSize < 0 || minCompressionSize > MIN_COMPRESSION_SIZE_CAP) { + throw new IllegalArgumentException("Min compression size must be between 0 and 10485760"); + } + } + } +} diff --git a/client/client-http/src/main/resources/META-INF/services/software.amazon.smithy.java.client.core.AutoClientPlugin b/client/client-http/src/main/resources/META-INF/services/software.amazon.smithy.java.client.core.AutoClientPlugin index 4755e76fe..4c632c1f3 100644 --- a/client/client-http/src/main/resources/META-INF/services/software.amazon.smithy.java.client.core.AutoClientPlugin +++ 
b/client/client-http/src/main/resources/META-INF/services/software.amazon.smithy.java.client.core.AutoClientPlugin @@ -1,3 +1,4 @@ software.amazon.smithy.java.client.http.plugins.UserAgentPlugin software.amazon.smithy.java.client.http.plugins.ApplyHttpRetryInfoPlugin +software.amazon.smithy.java.client.http.plugins.RequestCompressionPlugin software.amazon.smithy.java.client.http.plugins.HttpChecksumPlugin diff --git a/client/client-http/src/test/java/software/amazon/smithy/java/client/http/compression/GzipCompressingInputStreamTest.java b/client/client-http/src/test/java/software/amazon/smithy/java/client/http/compression/GzipCompressingInputStreamTest.java new file mode 100644 index 000000000..b44162319 --- /dev/null +++ b/client/client-http/src/test/java/software/amazon/smithy/java/client/http/compression/GzipCompressingInputStreamTest.java @@ -0,0 +1,100 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.java.client.http.compression; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.lessThan; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.nio.charset.StandardCharsets; +import java.util.zip.GZIPInputStream; +import org.junit.jupiter.api.Test; + +public class GzipCompressingInputStreamTest { + + @Test + public void compressesDataCorrectly() throws Exception { + var original = "Hello World!"; + var source = new ByteArrayInputStream(original.getBytes(StandardCharsets.UTF_8)); + + try (var gzipStream = new GzipCompressingInputStream(source)) { + var compressed = gzipStream.readAllBytes(); + var decompressed = decompress(compressed); + assertThat(decompressed, equalTo(original)); + } + } + + @Test + public void compressesLargeData() throws Exception { + var original = "Hello World! 
".repeat(10000); + var source = new ByteArrayInputStream(original.getBytes(StandardCharsets.UTF_8)); + + try (var gzipStream = new GzipCompressingInputStream(source)) { + var compressed = gzipStream.readAllBytes(); + assertThat(compressed.length, lessThan(original.length())); + var decompressed = decompress(compressed); + assertThat(decompressed, equalTo(original)); + } + } + + @Test + public void compressesEmptyData() throws Exception { + var source = new ByteArrayInputStream(new byte[0]); + + try (var gzipStream = new GzipCompressingInputStream(source)) { + var compressed = gzipStream.readAllBytes(); + var decompressed = decompress(compressed); + assertThat(decompressed, equalTo("")); + } + } + + @Test + public void readSingleByteWorks() throws Exception { + var original = "AB"; + var source = new ByteArrayInputStream(original.getBytes(StandardCharsets.UTF_8)); + + try (var gzipStream = new GzipCompressingInputStream(source); + var out = new ByteArrayOutputStream()) { + int b; + while ((b = gzipStream.read()) != -1) { + out.write(b); + } + var decompressed = decompress(out.toByteArray()); + assertThat(decompressed, equalTo(original)); + } + } + + @Test + public void readWithBufferWorks() throws Exception { + var original = "Test buffer read"; + var source = new ByteArrayInputStream(original.getBytes(StandardCharsets.UTF_8)); + + try (var gzipStream = new GzipCompressingInputStream(source); + var out = new ByteArrayOutputStream()) { + var buffer = new byte[4]; + int len; + while ((len = gzipStream.read(buffer, 0, buffer.length)) != -1) { + out.write(buffer, 0, len); + } + var decompressed = decompress(out.toByteArray()); + assertThat(decompressed, equalTo(original)); + } + } + + private String decompress(byte[] compressed) throws Exception { + try (var gzipIn = new GZIPInputStream(new ByteArrayInputStream(compressed)); + var out = new ByteArrayOutputStream()) { + var buffer = new byte[1024]; + int len; + while ((len = gzipIn.read(buffer)) > 0) { + 
out.write(buffer, 0, len); + } + return out.toString(StandardCharsets.UTF_8); + } + } +} diff --git a/client/client-http/src/test/java/software/amazon/smithy/java/client/http/plugins/RequestCompressionPluginTest.java b/client/client-http/src/test/java/software/amazon/smithy/java/client/http/plugins/RequestCompressionPluginTest.java new file mode 100644 index 000000000..1d5d6b1d1 --- /dev/null +++ b/client/client-http/src/test/java/software/amazon/smithy/java/client/http/plugins/RequestCompressionPluginTest.java @@ -0,0 +1,300 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.java.client.http.plugins; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.equalTo; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.net.URI; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.zip.GZIPInputStream; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import software.amazon.smithy.java.client.core.interceptors.RequestHook; +import software.amazon.smithy.java.client.http.HttpContext; +import software.amazon.smithy.java.context.Context; +import software.amazon.smithy.java.core.schema.ApiOperation; +import software.amazon.smithy.java.core.schema.ApiService; +import software.amazon.smithy.java.core.schema.Schema; +import software.amazon.smithy.java.core.schema.SerializableStruct; +import software.amazon.smithy.java.core.schema.ShapeBuilder; +import software.amazon.smithy.java.core.serde.ShapeSerializer; +import software.amazon.smithy.java.core.serde.TypeRegistry; +import software.amazon.smithy.java.http.api.HttpRequest; +import software.amazon.smithy.java.io.datastream.DataStream; +import software.amazon.smithy.model.shapes.ShapeId; +import 
software.amazon.smithy.model.traits.RequestCompressionTrait; +import software.amazon.smithy.model.traits.StreamingTrait; + +public class RequestCompressionPluginTest { + + private static final String REQUEST_BODY = "THIS IS MY COMPRESSION TEST BODY!"; + + @Test + public void doesNotCompressWhenDisabled() throws Exception { + var interceptor = new RequestCompressionPlugin.RequestCompressionInterceptor(); + var context = Context.create(); + context.put(HttpContext.DISABLE_REQUEST_COMPRESSION, true); + context.put(HttpContext.REQUEST_MIN_COMPRESSION_SIZE_BYTES, 1); + var req = HttpRequest.builder() + .uri(new URI("/")) + .method("POST") + .body(DataStream.ofString(REQUEST_BODY)) + .build(); + + var result = interceptor.modifyBeforeTransmit( + new RequestHook<>(createOperationWithCompressionTrait(), context, new TestInput(), req)); + + assertThat(result.headers().allValues("Content-Encoding"), empty()); + } + + @Test + public void compressesWhenBodyMeetsMinSize() throws Exception { + var interceptor = new RequestCompressionPlugin.RequestCompressionInterceptor(); + var context = Context.create(); + String largeBody = REQUEST_BODY.repeat(10); + context.put(HttpContext.REQUEST_MIN_COMPRESSION_SIZE_BYTES, 10); + var req = HttpRequest.builder() + .uri(new URI("/")) + .method("POST") + .body(DataStream.ofString(largeBody)) + .build(); + + var result = interceptor.modifyBeforeTransmit( + new RequestHook<>(createOperationWithCompressionTrait(), context, new TestInput(), req)); + + assertThat(result.headers().allValues("Content-Encoding"), contains("gzip")); + String decompressed = decompress(result.body().asByteBuffer().array()); + assertThat(decompressed, equalTo(largeBody)); + } + + @Test + public void doesNotCompressWhenBodyBelowMinSize() throws Exception { + var interceptor = new RequestCompressionPlugin.RequestCompressionInterceptor(); + var context = Context.create(); + context.put(HttpContext.REQUEST_MIN_COMPRESSION_SIZE_BYTES, 10000); + var req = HttpRequest.builder() 
+ .uri(new URI("/")) + .method("POST") + .body(DataStream.ofString(REQUEST_BODY)) + .build(); + + var result = interceptor.modifyBeforeTransmit( + new RequestHook<>(createOperationWithCompressionTrait(), context, new TestInput(), req)); + + assertThat(result.headers().allValues("Content-Encoding"), empty()); + } + + @Test + public void usesDefaultMinCompressionSize() throws Exception { + var interceptor = new RequestCompressionPlugin.RequestCompressionInterceptor(); + var context = Context.create(); + // Body is smaller than default 10240 + var req = HttpRequest.builder() + .uri(new URI("/")) + .method("POST") + .body(DataStream.ofString(REQUEST_BODY)) + .build(); + + var result = interceptor.modifyBeforeTransmit( + new RequestHook<>(createOperationWithCompressionTrait(), context, new TestInput(), req)); + + assertThat(result.headers().allValues("Content-Encoding"), empty()); + } + + @Test + public void alwaysCompressesStreamingWithoutKnownLength() throws Exception { + var interceptor = new RequestCompressionPlugin.RequestCompressionInterceptor(); + var context = Context.create(); + context.put(HttpContext.REQUEST_MIN_COMPRESSION_SIZE_BYTES, 999999); + String original = "small"; + var streamBody = DataStream.ofInputStream(new ByteArrayInputStream(original.getBytes(StandardCharsets.UTF_8))); + var req = HttpRequest.builder() + .uri(new URI("/")) + .method("POST") + .body(streamBody) + .build(); + + var result = interceptor.modifyBeforeTransmit( + new RequestHook<>(createOperationWithStreamingInput(), context, new TestInput(), req)); + + assertThat(result.headers().allValues("Content-Encoding"), contains("gzip")); + String decompressed = decompress(result.body().asInputStream().readAllBytes()); + assertThat(decompressed, equalTo(original)); + } + + @Test + public void throwsForNegativeMinCompressionSize() throws Exception { + var interceptor = new RequestCompressionPlugin.RequestCompressionInterceptor(); + var context = Context.create(); + 
context.put(HttpContext.REQUEST_MIN_COMPRESSION_SIZE_BYTES, -1); + String largeBody = REQUEST_BODY.repeat(100); + var req = HttpRequest.builder() + .uri(new URI("/")) + .method("POST") + .body(DataStream.ofString(largeBody)) + .build(); + + var hook = new RequestHook<>(createOperationWithCompressionTrait(), context, new TestInput(), req); + Assertions.assertThrows(IllegalArgumentException.class, () -> interceptor.modifyBeforeTransmit(hook)); + } + + @Test + public void throwsForMinCompressionSizeExceedingCap() throws Exception { + var interceptor = new RequestCompressionPlugin.RequestCompressionInterceptor(); + var context = Context.create(); + context.put(HttpContext.REQUEST_MIN_COMPRESSION_SIZE_BYTES, 10485761); + String largeBody = REQUEST_BODY.repeat(100); + var req = HttpRequest.builder() + .uri(new URI("/")) + .method("POST") + .body(DataStream.ofString(largeBody)) + .build(); + + var hook = new RequestHook<>(createOperationWithCompressionTrait(), context, new TestInput(), req); + Assertions.assertThrows(IllegalArgumentException.class, () -> interceptor.modifyBeforeTransmit(hook)); + } + + @Test + public void doesNotCompressWhenTraitAbsent() throws Exception { + var interceptor = new RequestCompressionPlugin.RequestCompressionInterceptor(); + var context = Context.create(); + context.put(HttpContext.REQUEST_MIN_COMPRESSION_SIZE_BYTES, 1); + var req = HttpRequest.builder() + .uri(new URI("/")) + .method("POST") + .body(DataStream.ofString(REQUEST_BODY)) + .build(); + + var result = interceptor.modifyBeforeTransmit( + new RequestHook<>(createOperationWithoutCompressionTrait(), context, new TestInput(), req)); + + assertThat(result.headers().allValues("Content-Encoding"), empty()); + } + + @Test + public void appendsGzipToExistingContentEncoding() throws Exception { + var interceptor = new RequestCompressionPlugin.RequestCompressionInterceptor(); + var context = Context.create(); + context.put(HttpContext.REQUEST_MIN_COMPRESSION_SIZE_BYTES, 10); + var req = 
HttpRequest.builder() + .uri(new URI("/")) + .method("POST") + .body(DataStream.ofString(REQUEST_BODY)) + .withAddedHeader("Content-Encoding", "custom") + .build(); + + var result = interceptor.modifyBeforeTransmit( + new RequestHook<>(createOperationWithCompressionTrait(), context, new TestInput(), req)); + + var encodings = result.headers().allValues("Content-Encoding"); + assertThat(encodings, contains("custom", "gzip")); + } + + // Helper: Create operation with @requestCompression trait + private ApiOperation createOperationWithCompressionTrait() { + var trait = RequestCompressionTrait.builder().addEncoding("gzip").build(); + var schema = Schema.createOperation(ShapeId.from("com.test#TestOp"), trait); + return createOperation(schema, null); + } + + // Helper: Create operation without @requestCompression trait + private ApiOperation createOperationWithoutCompressionTrait() { + var schema = Schema.createOperation(ShapeId.from("com.test#TestOp")); + return createOperation(schema, null); + } + + // Helper: Create operation with streaming input + private ApiOperation createOperationWithStreamingInput() { + var trait = RequestCompressionTrait.builder().addEncoding("gzip").build(); + var schema = Schema.createOperation(ShapeId.from("com.test#TestOp"), trait); + var blobSchema = Schema.createBlob(ShapeId.from("com.test#StreamBody"), new StreamingTrait()); + return createOperation(schema, blobSchema); + } + + private ApiOperation createOperation(Schema schema, Schema streamMember) { + return new ApiOperation<>() { + @Override + public ShapeBuilder inputBuilder() { + return null; + } + + @Override + public ShapeBuilder outputBuilder() { + return null; + } + + @Override + public Schema schema() { + return schema; + } + + @Override + public Schema inputSchema() { + return null; + } + + @Override + public Schema outputSchema() { + return null; + } + + @Override + public ApiService service() { + return null; + } + + @Override + public TypeRegistry errorRegistry() { + return 
null; + } + + @Override + public List effectiveAuthSchemes() { + return List.of(); + } + + @Override + public Schema inputStreamMember() { + return streamMember; + } + }; + } + + private static final class TestInput implements SerializableStruct { + @Override + public Schema schema() { + throw new UnsupportedOperationException(); + } + + @Override + public void serializeMembers(ShapeSerializer serializer) { + throw new UnsupportedOperationException(); + } + + @Override + public T getMemberValue(Schema member) { + return null; + } + } + + private String decompress(byte[] compressed) throws Exception { + try (var gzipIn = new GZIPInputStream(new ByteArrayInputStream(compressed)); + var out = new ByteArrayOutputStream()) { + byte[] buffer = new byte[1024]; + int len; + while ((len = gzipIn.read(buffer)) > 0) { + out.write(buffer, 0, len); + } + return out.toString(StandardCharsets.UTF_8); + } + } +} diff --git a/codegen/codegen-core/src/main/java/software/amazon/smithy/java/codegen/integrations/core/RequestCompressionTraitInitializer.java b/codegen/codegen-core/src/main/java/software/amazon/smithy/java/codegen/integrations/core/RequestCompressionTraitInitializer.java index 2ff180bd4..db4b2baf9 100644 --- a/codegen/codegen-core/src/main/java/software/amazon/smithy/java/codegen/integrations/core/RequestCompressionTraitInitializer.java +++ b/codegen/codegen-core/src/main/java/software/amazon/smithy/java/codegen/integrations/core/RequestCompressionTraitInitializer.java @@ -22,6 +22,6 @@ public void accept(JavaWriter writer, RequestCompressionTrait requestCompression writer.putContext("requestComp", RequestCompressionTrait.class); writer.putContext("list", List.class); writer.writeInline( - "${requestComp:T}.builder().encodings(${list:T}.of(${#enc}${enc:S}${^key.last}, ${/key.last}${/enc})).build()"); + "${requestComp:T}.builder().encodings(${list:T}.of(${#enc}${value:S}${^key.last}, ${/key.last}${/enc})).build()"); } }