From 747160976f0940cb9dc2acf16fec43cf83cd7dc9 Mon Sep 17 00:00:00 2001 From: Matas Date: Wed, 28 Feb 2024 15:37:16 -0500 Subject: [PATCH] feat: support S3 Express One Zone (#1033) --- .../d2f0f8cb-b94d-403e-9db8-5cd61bd8eb1b.json | 5 ++ .../auth/IdentityProviderConfigGenerator.kt | 5 +- .../auth/SigV4AuthSchemeIntegration.kt | 53 ++++++------ .../DefaultEndpointProviderGenerator.kt | 10 ++- .../DefaultEndpointProviderTestGenerator.kt | 10 ++- .../protocol/HttpProtocolClientGenerator.kt | 13 +-- gradle/libs.versions.toml | 2 +- .../api/aws-signing-common.api | 1 + .../auth/awssigning/AwsSigningAttributes.kt | 5 ++ .../auth/awssigning/crt/CrtAwsSigner.kt | 19 ++++- .../kotlin/runtime/http/auth/AwsHttpSigner.kt | 3 +- .../protocol/http-client/api/http-client.api | 57 ++++++------- .../AbstractChecksumInterceptor.kt | 24 ++++++ .../FlexibleChecksumsRequestInterceptor.kt | 64 ++++++++++----- .../interceptors/Md5ChecksumInterceptor.kt | 38 ++++----- .../http/operation/HttpOperationContext.kt | 5 ++ .../http/operation/SdkOperationExecution.kt | 3 + .../AbstractChecksumInterceptorTest.kt | 77 +++++++++++++++++ ...FlexibleChecksumsRequestInterceptorTest.kt | 17 ++++ runtime/runtime-core/api/runtime-core.api | 10 +++ .../kotlin/runtime/collections/LruCache.kt | 78 ++++++++++++++++++ .../smithy/kotlin/runtime/hashing/Crc32.kt | 2 +- .../runtime/collections/LruCacheTest.kt | 82 +++++++++++++++++++ 23 files changed, 462 insertions(+), 121 deletions(-) create mode 100644 .changes/d2f0f8cb-b94d-403e-9db8-5cd61bd8eb1b.json create mode 100644 runtime/protocol/http-client/common/src/aws/smithy/kotlin/runtime/http/interceptors/AbstractChecksumInterceptor.kt create mode 100644 runtime/protocol/http-client/common/test/aws/smithy/kotlin/runtime/http/interceptors/AbstractChecksumInterceptorTest.kt create mode 100644 runtime/runtime-core/common/src/aws/smithy/kotlin/runtime/collections/LruCache.kt create mode 100644 runtime/runtime-core/common/test/aws/smithy/kotlin/runtime/collections/LruCacheTest.kt diff --git a/.changes/d2f0f8cb-b94d-403e-9db8-5cd61bd8eb1b.json b/.changes/d2f0f8cb-b94d-403e-9db8-5cd61bd8eb1b.json new file mode 100644 index 000000000..9cd86e2ef --- /dev/null +++ b/.changes/d2f0f8cb-b94d-403e-9db8-5cd61bd8eb1b.json @@ -0,0 +1,5 @@ +{ + "id": "d2f0f8cb-b94d-403e-9db8-5cd61bd8eb1b", + "type": "feature", + "description": "Add support for S3 Express One Zone" +} \ No newline at end of file diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/auth/IdentityProviderConfigGenerator.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/auth/IdentityProviderConfigGenerator.kt index ed4466690..c53f7c385 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/auth/IdentityProviderConfigGenerator.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/auth/IdentityProviderConfigGenerator.kt @@ -7,10 +7,7 @@ package software.amazon.smithy.kotlin.codegen.rendering.auth import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.kotlin.codegen.KotlinSettings -import software.amazon.smithy.kotlin.codegen.core.KotlinWriter -import software.amazon.smithy.kotlin.codegen.core.RuntimeTypes -import software.amazon.smithy.kotlin.codegen.core.clientName -import software.amazon.smithy.kotlin.codegen.core.withBlock +import software.amazon.smithy.kotlin.codegen.core.* import software.amazon.smithy.kotlin.codegen.model.buildSymbol import software.amazon.smithy.kotlin.codegen.model.knowledge.AuthIndex import software.amazon.smithy.kotlin.codegen.rendering.protocol.ProtocolGenerator diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/auth/SigV4AuthSchemeIntegration.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/auth/SigV4AuthSchemeIntegration.kt index 9d7953454..66be38767 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/auth/SigV4AuthSchemeIntegration.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/auth/SigV4AuthSchemeIntegration.kt @@ -186,29 +186,36 @@ private object Sigv4EndpointCustomization : EndpointCustomization { // SigV4a requires SigV4 so SigV4 integration renders SigV4a auth scheme. // See comment in example model: https://smithy.io/2.0/aws/aws-auth.html?highlight=sigv4#aws-auth-sigv4a-trait private fun renderAuthSchemes(writer: KotlinWriter, authSchemes: Expression, expressionRenderer: ExpressionRenderer) { - writer.writeInline("#T to ", RuntimeTypes.SmithyClient.Endpoints.SigningContextAttributeKey) - writer.withBlock("listOf(", ")") { - authSchemes.toNode().expectArrayNode().forEach { - val scheme = it.expectObjectNode() - val schemeName = scheme.expectStringMember("name").value - - val authFactoryFn = when (schemeName) { - "sigv4" -> RuntimeTypes.Auth.HttpAuthAws.sigV4 - "sigv4a" -> RuntimeTypes.Auth.HttpAuthAws.sigV4A - else -> return@forEach - } - - withBlock("#T(", "),", authFactoryFn) { - // we delegate back to the expression visitor for each of these fields because it's possible to - // encounter template strings throughout - - writeInline("serviceName = ") - renderOrElse(expressionRenderer, scheme.getStringMember("signingName"), "null") - - writeInline("disableDoubleUriEncode = ") - renderOrElse(expressionRenderer, scheme.getBooleanMember("disableDoubleEncoding"), "false") - - renderFieldsForScheme(writer, scheme, expressionRenderer) + val schemes = authSchemes.toNode().expectArrayNode().filter { + val name = it.expectObjectNode().expectStringMember("name").value + name == "sigv4" || name == "sigv4a" + }.takeIf { it.isNotEmpty() } + + schemes?.let { + writer.writeInline("#T to ", RuntimeTypes.SmithyClient.Endpoints.SigningContextAttributeKey) + writer.withBlock("listOf(", ")") { + schemes.forEach { + val scheme = it.expectObjectNode() + val schemeName = scheme.expectStringMember("name").value + + val authFactoryFn = when (schemeName) { + "sigv4" -> RuntimeTypes.Auth.HttpAuthAws.sigV4 + "sigv4a" -> RuntimeTypes.Auth.HttpAuthAws.sigV4A + else -> return@forEach + } + + withBlock("#T(", "),", authFactoryFn) { + // we delegate back to the expression visitor for each of these fields because it's possible to + // encounter template strings throughout + + writeInline("serviceName = ") + renderOrElse(expressionRenderer, scheme.getStringMember("signingName"), "null") + + writeInline("disableDoubleUriEncode = ") + renderOrElse(expressionRenderer, scheme.getBooleanMember("disableDoubleEncoding"), "false") + + renderFieldsForScheme(writer, scheme, expressionRenderer) + } } } } diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/endpoints/DefaultEndpointProviderGenerator.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/endpoints/DefaultEndpointProviderGenerator.kt index ab8743e65..f7bae2485 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/endpoints/DefaultEndpointProviderGenerator.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/endpoints/DefaultEndpointProviderGenerator.kt @@ -80,8 +80,10 @@ class DefaultEndpointProviderGenerator( private val propertyRenderers = endpointCustomizations .map { it.propertyRenderers } - .fold(mutableMapOf()) { acc, propRenderers -> - acc.putAll(propRenderers) + .fold(mutableMapOf>()) { acc, propRenderers -> + propRenderers.forEach { (key, propRenderer) -> + acc[key] = acc.getOrDefault(key, mutableListOf()).also { it.add(propRenderer) } + } acc } @@ -190,7 +192,9 @@ class DefaultEndpointProviderGenerator( // caller has a chance to generate their own value for a recognized property if (kStr in propertyRenderers) { - propertyRenderers[kStr]!!(writer, v, this@DefaultEndpointProviderGenerator) + propertyRenderers[kStr]!!.forEach { renderer -> + renderer(writer, v, this@DefaultEndpointProviderGenerator) + } return@forEach } diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/endpoints/DefaultEndpointProviderTestGenerator.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/endpoints/DefaultEndpointProviderTestGenerator.kt index 0492bb580..d9573ebb5 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/endpoints/DefaultEndpointProviderTestGenerator.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/endpoints/DefaultEndpointProviderTestGenerator.kt @@ -42,8 +42,10 @@ class DefaultEndpointProviderTestGenerator( private val endpointCustomizations = ctx.integrations.mapNotNull { it.customizeEndpointResolution(ctx) } private val propertyRenderers = endpointCustomizations .map { it.propertyRenderers } - .fold(mutableMapOf()) { acc, propRenderers -> - acc.putAll(propRenderers) + .fold(mutableMapOf>()) { acc, propRenderers -> + propRenderers.forEach { (key, propRenderer) -> + acc[key] = acc.getOrDefault(key, mutableListOf()).also { it.add(propRenderer) } + } acc } @@ -131,7 +133,9 @@ class DefaultEndpointProviderTestGenerator( withBlock("attributes = #T {", "},", RuntimeTypes.Core.Collections.attributesOf) { endpoint.properties.entries.forEach { (k, v) -> if (k in propertyRenderers) { - propertyRenderers[k]!!(writer, Expression.fromNode(v), this@DefaultEndpointProviderTestGenerator) + propertyRenderers[k]!!.forEach { renderer -> + renderer(writer, Expression.fromNode(v), this@DefaultEndpointProviderTestGenerator) + } return@forEach } diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/protocol/HttpProtocolClientGenerator.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/protocol/HttpProtocolClientGenerator.kt index 3db122389..a25332260 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/protocol/HttpProtocolClientGenerator.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/protocol/HttpProtocolClientGenerator.kt @@ -348,19 +348,12 @@ open class HttpProtocolClientGenerator( return } - val requestAlgorithmMember = ctx.model.getShape(input.get()).getOrNull() - ?.members() - ?.firstOrNull { it.memberName == httpChecksumTrait?.requestAlgorithmMember?.getOrNull() } - if (hasTrait() || httpChecksumTrait?.isRequestChecksumRequired == true) { val interceptorSymbol = RuntimeTypes.HttpClient.Interceptors.Md5ChecksumInterceptor val inputSymbol = ctx.symbolProvider.toSymbol(ctx.model.expectShape(inputShape)) - - requestAlgorithmMember?.let { - writer.withBlock("op.interceptors.add(#T<#T> { ", "})", interceptorSymbol, inputSymbol) { - writer.write("it.#L?.value == null", requestAlgorithmMember.defaultName()) - } - } ?: writer.write("op.interceptors.add(#T<#T>())", interceptorSymbol, inputSymbol) + writer.withBlock("op.interceptors.add(#T<#T> {", "})", interceptorSymbol, inputSymbol) { + writer.write("op.context.getOrNull(#T.ChecksumAlgorithm) == null", RuntimeTypes.HttpClient.Operation.HttpOperationContext) + } } } diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index eea8dd4af..e97e8e243 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -12,7 +12,7 @@ okio-version = "3.6.0" otel-version = "1.32.0" slf4j-version = "2.0.9" slf4j-v1x-version = "1.7.36" -crt-kotlin-version = "0.8.2" +crt-kotlin-version = "0.8.5" # codegen smithy-version = "1.42.0" diff --git a/runtime/auth/aws-signing-common/api/aws-signing-common.api b/runtime/auth/aws-signing-common/api/aws-signing-common.api index 970cef170..632fe22ae 100644 --- a/runtime/auth/aws-signing-common/api/aws-signing-common.api +++ b/runtime/auth/aws-signing-common/api/aws-signing-common.api @@ -55,6 +55,7 @@ public final class aws/smithy/kotlin/runtime/auth/awssigning/AwsSigningAttribute public final fun getEnableAwsChunked ()Laws/smithy/kotlin/runtime/collections/AttributeKey; public final fun getHashSpecification ()Laws/smithy/kotlin/runtime/collections/AttributeKey; public final fun getNormalizeUriPath ()Laws/smithy/kotlin/runtime/collections/AttributeKey; + public final fun getOmitSessionToken ()Laws/smithy/kotlin/runtime/collections/AttributeKey; public final fun getRequestSignature ()Laws/smithy/kotlin/runtime/collections/AttributeKey; public final fun getSignedBodyHeader ()Laws/smithy/kotlin/runtime/collections/AttributeKey; public final fun getSigner ()Laws/smithy/kotlin/runtime/collections/AttributeKey; diff --git a/runtime/auth/aws-signing-common/common/src/aws/smithy/kotlin/runtime/auth/awssigning/AwsSigningAttributes.kt b/runtime/auth/aws-signing-common/common/src/aws/smithy/kotlin/runtime/auth/awssigning/AwsSigningAttributes.kt index 10e080a84..c4929ebcf 100644 --- a/runtime/auth/aws-signing-common/common/src/aws/smithy/kotlin/runtime/auth/awssigning/AwsSigningAttributes.kt +++ b/runtime/auth/aws-signing-common/common/src/aws/smithy/kotlin/runtime/auth/awssigning/AwsSigningAttributes.kt @@ -86,4 +86,9 @@ public object AwsSigningAttributes { * @see SigV4 Streaming */ public val EnableAwsChunked: AttributeKey = AttributeKey("aws.smithy.kotlin.signing#EnableAwsChunked") + + /** + * Flag indicating whether the X-Amz-Security-Token header should be omitted from the canonical request during signing. + */ + public val OmitSessionToken: AttributeKey = AttributeKey("aws.smithy.kotlin.signing#OmitSessionToken") } diff --git a/runtime/auth/aws-signing-crt/jvm/src/aws/smithy/kotlin/runtime/auth/awssigning/crt/CrtAwsSigner.kt b/runtime/auth/aws-signing-crt/jvm/src/aws/smithy/kotlin/runtime/auth/awssigning/crt/CrtAwsSigner.kt index cadcc0f12..c0069c1ac 100644 --- a/runtime/auth/aws-signing-crt/jvm/src/aws/smithy/kotlin/runtime/auth/awssigning/crt/CrtAwsSigner.kt +++ b/runtime/auth/aws-signing-crt/jvm/src/aws/smithy/kotlin/runtime/auth/awssigning/crt/CrtAwsSigner.kt @@ -22,19 +22,30 @@ import aws.sdk.kotlin.crt.auth.signing.AwsSigningAlgorithm as CrtSigningAlgorith import aws.sdk.kotlin.crt.auth.signing.AwsSigningConfig as CrtSigningConfig import aws.sdk.kotlin.crt.http.Headers as CrtHeaders +private const val S3_EXPRESS_HEADER_NAME = "X-Amz-S3session-Token" + public object CrtAwsSigner : AwsSigner { override suspend fun sign(request: HttpRequest, config: AwsSigningConfig): AwsSigningResult { val isUnsigned = config.hashSpecification is HashSpecification.UnsignedPayload val isAwsChunked = request.headers.contains("Content-Encoding", "aws-chunked") - val crtRequest = request.toSignableCrtRequest(isUnsigned, isAwsChunked) - val crtConfig = config.toCrtSigningConfig() + val isS3Express = request.headers.contains(S3_EXPRESS_HEADER_NAME) + + val requestBuilder = request.toBuilder() - val crtResult = CrtSigner.sign(crtRequest, crtConfig) + val crtConfig = config.toCrtSigningConfig().toBuilder() + if (isS3Express) { + crtConfig.algorithm = CrtSigningAlgorithm.SIGV4_S3EXPRESS + crtConfig.omitSessionToken = false + requestBuilder.headers.remove(S3_EXPRESS_HEADER_NAME) // CRT signer fails if this header is already present + } + + val crtRequest = requestBuilder.build().toSignableCrtRequest(isUnsigned, isAwsChunked) + + val crtResult = CrtSigner.sign(crtRequest, crtConfig.build()) coroutineContext.debug { "Calculated signature: ${crtResult.signature.decodeToString()}" } val crtSignedResult = checkNotNull(crtResult.signedRequest) { "Signed request unexpectedly null" } - val requestBuilder = request.toBuilder() requestBuilder.update(crtSignedResult) return AwsSigningResult(requestBuilder.build(), crtResult.signature) } diff --git a/runtime/auth/http-auth-aws/common/src/aws/smithy/kotlin/runtime/http/auth/AwsHttpSigner.kt b/runtime/auth/http-auth-aws/common/src/aws/smithy/kotlin/runtime/http/auth/AwsHttpSigner.kt index 81ad58c29..bb4532cbb 100644 --- a/runtime/auth/http-auth-aws/common/src/aws/smithy/kotlin/runtime/http/auth/AwsHttpSigner.kt +++ b/runtime/auth/http-auth-aws/common/src/aws/smithy/kotlin/runtime/http/auth/AwsHttpSigner.kt @@ -124,6 +124,7 @@ public class AwsHttpSigner(private val config: Config) : HttpSigner { val contextUseDoubleUriEncode = attributes.getOrNull(AwsSigningAttributes.UseDoubleUriEncode) val contextNormalizeUriPath = attributes.getOrNull(AwsSigningAttributes.NormalizeUriPath) val contextSigningServiceName = attributes.getOrNull(AwsSigningAttributes.SigningService) + val contextOmitSessionToken = attributes.getOrNull(AwsSigningAttributes.OmitSessionToken) val enableAwsChunked = attributes.getOrNull(AwsSigningAttributes.EnableAwsChunked) ?: false @@ -143,7 +144,7 @@ public class AwsHttpSigner(private val config: Config) : HttpSigner { ?: (Instant.now() + (attributes.getOrNull(HttpOperationContext.ClockSkew) ?: Duration.ZERO)) signatureType = config.signatureType - omitSessionToken = config.omitSessionToken + omitSessionToken = contextOmitSessionToken ?: config.omitSessionToken normalizeUriPath = contextNormalizeUriPath ?: config.normalizeUriPath useDoubleUriEncode = contextUseDoubleUriEncode ?: config.useDoubleUriEncode expiresAfter = config.expiresAfter diff --git a/runtime/protocol/http-client/api/http-client.api b/runtime/protocol/http-client/api/http-client.api index 8e429b641..7c28406ec 100644 --- a/runtime/protocol/http-client/api/http-client.api +++ b/runtime/protocol/http-client/api/http-client.api @@ -255,13 +255,10 @@ public final class aws/smithy/kotlin/runtime/http/engine/internal/ManagedHttpCli public static final fun manage (Laws/smithy/kotlin/runtime/http/engine/HttpClientEngine;)Laws/smithy/kotlin/runtime/http/engine/HttpClientEngine; } -public final class aws/smithy/kotlin/runtime/http/interceptors/ChecksumMismatchException : aws/smithy/kotlin/runtime/ClientException { - public fun (Ljava/lang/String;)V -} - -public final class aws/smithy/kotlin/runtime/http/interceptors/ContinueInterceptor : aws/smithy/kotlin/runtime/client/Interceptor { - public fun (J)V - public final fun getThresholdLengthBytes ()J +public abstract class aws/smithy/kotlin/runtime/http/interceptors/AbstractChecksumInterceptor : aws/smithy/kotlin/runtime/client/Interceptor { + public fun ()V + public abstract fun applyChecksum (Laws/smithy/kotlin/runtime/client/ProtocolRequestInterceptorContext;Ljava/lang/String;)Laws/smithy/kotlin/runtime/http/request/HttpRequest; + public abstract fun calculateChecksum (Laws/smithy/kotlin/runtime/client/ProtocolRequestInterceptorContext;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; public fun modifyBeforeAttemptCompletion-gIAlu-s (Laws/smithy/kotlin/runtime/client/ResponseInterceptorContext;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; public fun modifyBeforeCompletion-gIAlu-s (Laws/smithy/kotlin/runtime/client/ResponseInterceptorContext;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; public fun modifyBeforeDeserialization (Laws/smithy/kotlin/runtime/client/ProtocolResponseInterceptorContext;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; @@ -283,8 +280,13 @@ public final class aws/smithy/kotlin/runtime/http/interceptors/ContinueIntercept public fun readBeforeTransmit (Laws/smithy/kotlin/runtime/client/ProtocolRequestInterceptorContext;)V } -public final class aws/smithy/kotlin/runtime/http/interceptors/DiscoveredEndpointErrorInterceptor : aws/smithy/kotlin/runtime/client/Interceptor { - public fun (Lkotlin/reflect/KClass;Lkotlin/jvm/functions/Function1;)V +public final class aws/smithy/kotlin/runtime/http/interceptors/ChecksumMismatchException : aws/smithy/kotlin/runtime/ClientException { + public fun (Ljava/lang/String;)V +} + +public final class aws/smithy/kotlin/runtime/http/interceptors/ContinueInterceptor : aws/smithy/kotlin/runtime/client/Interceptor { + public fun (J)V + public final fun getThresholdLengthBytes ()J public fun modifyBeforeAttemptCompletion-gIAlu-s (Laws/smithy/kotlin/runtime/client/ResponseInterceptorContext;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; public fun modifyBeforeCompletion-gIAlu-s (Laws/smithy/kotlin/runtime/client/ResponseInterceptorContext;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; public fun modifyBeforeDeserialization (Laws/smithy/kotlin/runtime/client/ProtocolResponseInterceptorContext;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; @@ -306,8 +308,8 @@ public final class aws/smithy/kotlin/runtime/http/interceptors/DiscoveredEndpoin public fun readBeforeTransmit (Laws/smithy/kotlin/runtime/client/ProtocolRequestInterceptorContext;)V } -public final class aws/smithy/kotlin/runtime/http/interceptors/FlexibleChecksumsRequestInterceptor : aws/smithy/kotlin/runtime/client/Interceptor { - public fun (Lkotlin/jvm/functions/Function1;)V +public final class aws/smithy/kotlin/runtime/http/interceptors/DiscoveredEndpointErrorInterceptor : aws/smithy/kotlin/runtime/client/Interceptor { + public fun (Lkotlin/reflect/KClass;Lkotlin/jvm/functions/Function1;)V public fun modifyBeforeAttemptCompletion-gIAlu-s (Laws/smithy/kotlin/runtime/client/ResponseInterceptorContext;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; public fun modifyBeforeCompletion-gIAlu-s (Laws/smithy/kotlin/runtime/client/ResponseInterceptorContext;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; public fun modifyBeforeDeserialization (Laws/smithy/kotlin/runtime/client/ProtocolResponseInterceptorContext;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; @@ -329,6 +331,16 @@ public final class aws/smithy/kotlin/runtime/http/interceptors/FlexibleChecksums public fun readBeforeTransmit (Laws/smithy/kotlin/runtime/client/ProtocolRequestInterceptorContext;)V } +public final class aws/smithy/kotlin/runtime/http/interceptors/FlexibleChecksumsRequestInterceptor : aws/smithy/kotlin/runtime/http/interceptors/AbstractChecksumInterceptor { + public fun ()V + public fun (Lkotlin/jvm/functions/Function1;)V + public synthetic fun (Lkotlin/jvm/functions/Function1;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public fun applyChecksum (Laws/smithy/kotlin/runtime/client/ProtocolRequestInterceptorContext;Ljava/lang/String;)Laws/smithy/kotlin/runtime/http/request/HttpRequest; + public fun calculateChecksum (Laws/smithy/kotlin/runtime/client/ProtocolRequestInterceptorContext;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public fun modifyBeforeSigning (Laws/smithy/kotlin/runtime/client/ProtocolRequestInterceptorContext;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public fun readAfterSerialization (Laws/smithy/kotlin/runtime/client/ProtocolRequestInterceptorContext;)V +} + public final class aws/smithy/kotlin/runtime/http/interceptors/FlexibleChecksumsResponseInterceptor : aws/smithy/kotlin/runtime/client/Interceptor { public static final field Companion Laws/smithy/kotlin/runtime/http/interceptors/FlexibleChecksumsResponseInterceptor$Companion; public fun (Lkotlin/jvm/functions/Function1;)V @@ -357,29 +369,13 @@ public final class aws/smithy/kotlin/runtime/http/interceptors/FlexibleChecksums public final fun getChecksumHeaderValidated ()Laws/smithy/kotlin/runtime/collections/AttributeKey; } -public final class aws/smithy/kotlin/runtime/http/interceptors/Md5ChecksumInterceptor : aws/smithy/kotlin/runtime/client/Interceptor { +public final class aws/smithy/kotlin/runtime/http/interceptors/Md5ChecksumInterceptor : aws/smithy/kotlin/runtime/http/interceptors/AbstractChecksumInterceptor { public fun ()V public fun (Lkotlin/jvm/functions/Function1;)V public synthetic fun (Lkotlin/jvm/functions/Function1;ILkotlin/jvm/internal/DefaultConstructorMarker;)V - public fun modifyBeforeAttemptCompletion-gIAlu-s (Laws/smithy/kotlin/runtime/client/ResponseInterceptorContext;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; - public fun modifyBeforeCompletion-gIAlu-s (Laws/smithy/kotlin/runtime/client/ResponseInterceptorContext;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; - public fun modifyBeforeDeserialization (Laws/smithy/kotlin/runtime/client/ProtocolResponseInterceptorContext;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; - public fun modifyBeforeRetryLoop (Laws/smithy/kotlin/runtime/client/ProtocolRequestInterceptorContext;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; - public fun modifyBeforeSerialization (Laws/smithy/kotlin/runtime/client/RequestInterceptorContext;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public fun applyChecksum (Laws/smithy/kotlin/runtime/client/ProtocolRequestInterceptorContext;Ljava/lang/String;)Laws/smithy/kotlin/runtime/http/request/HttpRequest; + public fun calculateChecksum (Laws/smithy/kotlin/runtime/client/ProtocolRequestInterceptorContext;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; public fun modifyBeforeSigning (Laws/smithy/kotlin/runtime/client/ProtocolRequestInterceptorContext;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; - public fun modifyBeforeTransmit (Laws/smithy/kotlin/runtime/client/ProtocolRequestInterceptorContext;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; - public fun readAfterAttempt (Laws/smithy/kotlin/runtime/client/ResponseInterceptorContext;)V - public fun readAfterDeserialization (Laws/smithy/kotlin/runtime/client/ResponseInterceptorContext;)V - public fun readAfterExecution (Laws/smithy/kotlin/runtime/client/ResponseInterceptorContext;)V - public fun readAfterSerialization (Laws/smithy/kotlin/runtime/client/ProtocolRequestInterceptorContext;)V - public fun readAfterSigning (Laws/smithy/kotlin/runtime/client/ProtocolRequestInterceptorContext;)V - public fun readAfterTransmit (Laws/smithy/kotlin/runtime/client/ProtocolResponseInterceptorContext;)V - public fun readBeforeAttempt (Laws/smithy/kotlin/runtime/client/ProtocolRequestInterceptorContext;)V - public fun readBeforeDeserialization (Laws/smithy/kotlin/runtime/client/ProtocolResponseInterceptorContext;)V - public fun readBeforeExecution (Laws/smithy/kotlin/runtime/client/RequestInterceptorContext;)V - public fun readBeforeSerialization (Laws/smithy/kotlin/runtime/client/RequestInterceptorContext;)V - public fun readBeforeSigning (Laws/smithy/kotlin/runtime/client/ProtocolRequestInterceptorContext;)V - public fun readBeforeTransmit (Laws/smithy/kotlin/runtime/client/ProtocolRequestInterceptorContext;)V } public final class aws/smithy/kotlin/runtime/http/interceptors/RequestCompressionInterceptor : aws/smithy/kotlin/runtime/client/Interceptor { @@ -476,6 +472,7 @@ public abstract interface class aws/smithy/kotlin/runtime/http/operation/HttpDes public final class aws/smithy/kotlin/runtime/http/operation/HttpOperationContext { public static final field INSTANCE Laws/smithy/kotlin/runtime/http/operation/HttpOperationContext; + public final fun getChecksumAlgorithm ()Laws/smithy/kotlin/runtime/collections/AttributeKey; public final fun getClockSkew ()Laws/smithy/kotlin/runtime/collections/AttributeKey; public final fun getClockSkewApproximateSigningTime ()Laws/smithy/kotlin/runtime/collections/AttributeKey; public final fun getHostPrefix ()Laws/smithy/kotlin/runtime/collections/AttributeKey; diff --git a/runtime/protocol/http-client/common/src/aws/smithy/kotlin/runtime/http/interceptors/AbstractChecksumInterceptor.kt b/runtime/protocol/http-client/common/src/aws/smithy/kotlin/runtime/http/interceptors/AbstractChecksumInterceptor.kt new file mode 100644 index 000000000..3fa8406bf --- /dev/null +++ b/runtime/protocol/http-client/common/src/aws/smithy/kotlin/runtime/http/interceptors/AbstractChecksumInterceptor.kt @@ -0,0 +1,24 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package aws.smithy.kotlin.runtime.http.interceptors + +import aws.smithy.kotlin.runtime.InternalApi +import aws.smithy.kotlin.runtime.client.ProtocolRequestInterceptorContext +import aws.smithy.kotlin.runtime.http.request.HttpRequest + +@InternalApi +public abstract class AbstractChecksumInterceptor : HttpInterceptor { + private var cachedChecksum: String? = null + + override suspend fun modifyBeforeSigning(context: ProtocolRequestInterceptorContext): HttpRequest { + cachedChecksum ?: calculateChecksum(context).also { cachedChecksum = it } + return cachedChecksum?.let { applyChecksum(context, it) } ?: context.protocolRequest + } + + public abstract suspend fun calculateChecksum(context: ProtocolRequestInterceptorContext): String? + + public abstract fun applyChecksum(context: ProtocolRequestInterceptorContext, checksum: String): HttpRequest +} diff --git a/runtime/protocol/http-client/common/src/aws/smithy/kotlin/runtime/http/interceptors/FlexibleChecksumsRequestInterceptor.kt b/runtime/protocol/http-client/common/src/aws/smithy/kotlin/runtime/http/interceptors/FlexibleChecksumsRequestInterceptor.kt index 7ee18cdf7..a7ebdc08c 100644 --- a/runtime/protocol/http-client/common/src/aws/smithy/kotlin/runtime/http/interceptors/FlexibleChecksumsRequestInterceptor.kt +++ b/runtime/protocol/http-client/common/src/aws/smithy/kotlin/runtime/http/interceptors/FlexibleChecksumsRequestInterceptor.kt @@ -10,6 +10,7 @@ import aws.smithy.kotlin.runtime.InternalApi import aws.smithy.kotlin.runtime.client.ProtocolRequestInterceptorContext import aws.smithy.kotlin.runtime.hashing.* import aws.smithy.kotlin.runtime.http.* +import aws.smithy.kotlin.runtime.http.operation.HttpOperationContext import aws.smithy.kotlin.runtime.http.request.HttpRequest import aws.smithy.kotlin.runtime.http.request.header import aws.smithy.kotlin.runtime.http.request.toBuilder @@ -30,23 +31,25 @@ import kotlin.coroutines.coroutineContext * In this case, a [LazyAsyncValue] will be added to the execution context which allows the trailing checksum to be sent * after the entire body has been streamed. * - * @param checksumAlgorithmNameInitializer a function which parses the input [I] to return the checksum algorithm name + * @param checksumAlgorithmNameInitializer an optional function which parses the input [I] to return the checksum algorithm name. + * if not set, then the [HttpOperationContext.ChecksumAlgorithm] execution context attribute will be used. */ @InternalApi public class FlexibleChecksumsRequestInterceptor( - private val checksumAlgorithmNameInitializer: (I) -> String?, -) : HttpInterceptor { + private val checksumAlgorithmNameInitializer: ((I) -> String?)? = null, +) : AbstractChecksumInterceptor() { private var checksumAlgorithmName: String? = null - override fun readAfterSerialization(context: ProtocolRequestInterceptorContext) { - @Suppress("UNCHECKED_CAST") - val input = context.request as I - checksumAlgorithmName = checksumAlgorithmNameInitializer(input) - } + @Deprecated("readAfterSerialization is no longer used") + override fun readAfterSerialization(context: ProtocolRequestInterceptorContext) { } - override suspend fun modifyBeforeRetryLoop(context: ProtocolRequestInterceptorContext): HttpRequest { + override suspend fun modifyBeforeSigning(context: ProtocolRequestInterceptorContext): HttpRequest { val logger = coroutineContext.logger>() + @Suppress("UNCHECKED_CAST") + val input = context.request as I + checksumAlgorithmName = checksumAlgorithmNameInitializer?.invoke(input) ?: context.executionContext.getOrNull(HttpOperationContext.ChecksumAlgorithm) + checksumAlgorithmName ?: run { logger.debug { "no checksum algorithm specified, skipping flexible checksums processing" } return context.protocolRequest @@ -65,7 +68,7 @@ public class FlexibleChecksumsRequestInterceptor( // this handles the case where a user inputs a precalculated checksum, but it doesn't match the input checksum algorithm req.headers.removeAllChecksumHeadersExcept(headerName) - val checksumAlgorithm = checksumAlgorithmName!!.toHashFunction() ?: throw ClientException("Could not parse checksum algorithm $checksumAlgorithmName") + val checksumAlgorithm = checksumAlgorithmName?.toHashFunction() ?: throw ClientException("Could not parse checksum algorithm $checksumAlgorithmName") if (!checksumAlgorithm.isSupported) { throw ClientException("Checksum algorithm $checksumAlgorithmName is not supported for flexible checksums") @@ -91,21 +94,38 @@ public class FlexibleChecksumsRequestInterceptor( } req.trailingHeaders.append(headerName, deferredChecksum) - } else if (req.headers[headerName] == null) { - logger.debug { "Calculating checksum" } + return req.build() + } else { + return super.modifyBeforeSigning(context) + } + } - val checksum: String = when { - req.body.contentLength == null && !req.body.isOneShot -> { - val channel = req.body.toSdkByteReadChannel()!! - channel.rollingHash(checksumAlgorithm).encodeBase64String() - } - else -> { - val bodyBytes = req.body.readAll()!! - req.body = bodyBytes.toHttpBody() // replace the consumed body - bodyBytes.hash(checksumAlgorithm).encodeBase64String() - } + override suspend fun calculateChecksum(context: ProtocolRequestInterceptorContext): String? { + val req = context.protocolRequest.toBuilder() + val checksumAlgorithm = checksumAlgorithmName?.toHashFunction() ?: return null + + return when { + req.body.contentLength == null && !req.body.isOneShot -> { + val channel = req.body.toSdkByteReadChannel()!! + channel.rollingHash(checksumAlgorithm).encodeBase64String() + } + else -> { + val bodyBytes = req.body.readAll()!! + req.body = bodyBytes.toHttpBody() + bodyBytes.hash(checksumAlgorithm).encodeBase64String() } + } + } + + override fun applyChecksum( + context: ProtocolRequestInterceptorContext, + checksum: String, + ): HttpRequest { + val headerName = "x-amz-checksum-$checksumAlgorithmName".lowercase() + + val req = context.protocolRequest.toBuilder() + if (!req.headers.contains(headerName)) { req.header(headerName, checksum) } diff --git a/runtime/protocol/http-client/common/src/aws/smithy/kotlin/runtime/http/interceptors/Md5ChecksumInterceptor.kt b/runtime/protocol/http-client/common/src/aws/smithy/kotlin/runtime/http/interceptors/Md5ChecksumInterceptor.kt index cdb122071..cbe6bcabe 100644 --- a/runtime/protocol/http-client/common/src/aws/smithy/kotlin/runtime/http/interceptors/Md5ChecksumInterceptor.kt +++ b/runtime/protocol/http-client/common/src/aws/smithy/kotlin/runtime/http/interceptors/Md5ChecksumInterceptor.kt @@ -19,36 +19,36 @@ import aws.smithy.kotlin.runtime.text.encoding.encodeBase64String * See: * - https://awslabs.github.io/smithy/1.0/spec/core/behavior-traits.html#httpchecksumrequired-trait * - https://datatracker.ietf.org/doc/html/rfc1864.html + * @param block An optional function which parses the input [I] to determine if the `Content-MD5` header should be set. + * If not provided, the default behavior will set the header. */ @InternalApi public class Md5ChecksumInterceptor( private val block: ((input: I) -> Boolean)? = null, -) : HttpInterceptor { +) : AbstractChecksumInterceptor() { + override suspend fun modifyBeforeSigning(context: ProtocolRequestInterceptorContext): HttpRequest { + @Suppress("UNCHECKED_CAST") + val input = context.request as I - private var shouldInjectMD5Header: Boolean = false - - override fun readAfterSerialization(context: ProtocolRequestInterceptorContext) { - shouldInjectMD5Header = block?.let { - @Suppress("UNCHECKED_CAST") - val input = context.request as I - it(input) - } ?: true - } - - override suspend fun modifyBeforeRetryLoop(context: ProtocolRequestInterceptorContext): HttpRequest { - if (!shouldInjectMD5Header) { + val injectMd5Header = block?.invoke(input) ?: true + if (!injectMd5Header) { return context.protocolRequest } - val checksum = when (val body = context.protocolRequest.body) { + return super.modifyBeforeSigning(context) + } + + public override suspend fun calculateChecksum(context: ProtocolRequestInterceptorContext): String? = + when (val body = context.protocolRequest.body) { is HttpBody.Bytes -> body.bytes().md5().encodeBase64String() else -> null } - return checksum?.let { - val req = context.protocolRequest.toBuilder() - req.header("Content-MD5", it) - req.build() - } ?: context.protocolRequest + public override fun applyChecksum(context: ProtocolRequestInterceptorContext, checksum: String): HttpRequest { + val req = context.protocolRequest.toBuilder() + if (!req.headers.contains("Content-MD5")) { + req.header("Content-MD5", checksum) + } + return req.build() } } diff --git a/runtime/protocol/http-client/common/src/aws/smithy/kotlin/runtime/http/operation/HttpOperationContext.kt b/runtime/protocol/http-client/common/src/aws/smithy/kotlin/runtime/http/operation/HttpOperationContext.kt index 28ad9f770..524614a66 100644 --- a/runtime/protocol/http-client/common/src/aws/smithy/kotlin/runtime/http/operation/HttpOperationContext.kt +++ b/runtime/protocol/http-client/common/src/aws/smithy/kotlin/runtime/http/operation/HttpOperationContext.kt @@ -64,6 +64,11 @@ public object HttpOperationContext { * The approximate signing time of the request, used to compute client clock skew. */ public val ClockSkewApproximateSigningTime: AttributeKey = AttributeKey("aws.smithy.kotlin#ClockSkewApproximateSigningTime") + + /** + * The name of the algorithm to be used for computing a checksum of the request. + */ + public val ChecksumAlgorithm: AttributeKey = AttributeKey("aws.smithy.kotlin#ChecksumAlgorithm") } internal val ExecutionContext.operationMetrics: OperationMetrics diff --git a/runtime/protocol/http-client/common/src/aws/smithy/kotlin/runtime/http/operation/SdkOperationExecution.kt b/runtime/protocol/http-client/common/src/aws/smithy/kotlin/runtime/http/operation/SdkOperationExecution.kt index a17ebe437..29deb381c 100644 --- a/runtime/protocol/http-client/common/src/aws/smithy/kotlin/runtime/http/operation/SdkOperationExecution.kt +++ b/runtime/protocol/http-client/common/src/aws/smithy/kotlin/runtime/http/operation/SdkOperationExecution.kt @@ -294,6 +294,9 @@ internal class AuthHandler( // update the request context with endpoint specific auth signing context val endpointAuthAttributes = endpoint.authOptions.firstOrNull { it.schemeId == authScheme.schemeId }?.attributes ?: emptyAttributes() request.context.merge(endpointAuthAttributes) + + // also update the request context with endpoint attributes + request.context.merge(endpoint.attributes) } val modified = interceptors.modifyBeforeSigning(request.subject.immutableView(true)) diff --git a/runtime/protocol/http-client/common/test/aws/smithy/kotlin/runtime/http/interceptors/AbstractChecksumInterceptorTest.kt b/runtime/protocol/http-client/common/test/aws/smithy/kotlin/runtime/http/interceptors/AbstractChecksumInterceptorTest.kt new file mode 100644 index 000000000..c3b28eff6 --- /dev/null +++ b/runtime/protocol/http-client/common/test/aws/smithy/kotlin/runtime/http/interceptors/AbstractChecksumInterceptorTest.kt @@ -0,0 +1,77 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +import aws.smithy.kotlin.runtime.client.ProtocolRequestInterceptorContext +import aws.smithy.kotlin.runtime.collections.get +import aws.smithy.kotlin.runtime.http.HttpBody +import aws.smithy.kotlin.runtime.http.SdkHttpClient +import aws.smithy.kotlin.runtime.http.interceptors.AbstractChecksumInterceptor +import aws.smithy.kotlin.runtime.http.operation.HttpOperationContext +import aws.smithy.kotlin.runtime.http.operation.newTestOperation +import aws.smithy.kotlin.runtime.http.operation.roundTrip +import aws.smithy.kotlin.runtime.http.request.HttpRequest +import aws.smithy.kotlin.runtime.http.request.HttpRequestBuilder +import aws.smithy.kotlin.runtime.http.request.header +import aws.smithy.kotlin.runtime.http.request.toBuilder +import aws.smithy.kotlin.runtime.httptest.TestEngine +import kotlinx.coroutines.test.runTest +import kotlin.test.Test +import kotlin.test.assertEquals + +class AbstractChecksumInterceptorTest { + private val client = SdkHttpClient(TestEngine()) + private val CHECKSUM_TEST_HEADER = "x-amz-kotlin-sdk-test-checksum-header" + + @Test + fun testChecksumIsCalculatedAndApplied() = runTest { + val req = HttpRequestBuilder().apply { + body = HttpBody.fromBytes("hello".encodeToByteArray()) + } + val expectedChecksumValue = "abcd" + + val op = newTestOperation(req, Unit) + + op.interceptors.add(TestAbstractChecksumInterceptor(expectedChecksumValue)) + + op.roundTrip(client, Unit) + val call = op.context.attributes[HttpOperationContext.HttpCallList].first() + assertEquals(expectedChecksumValue, call.request.headers[CHECKSUM_TEST_HEADER]) + } + + @Test + fun testCachedChecksumIsUsed() = runTest { + val req = HttpRequestBuilder().apply { + body = HttpBody.fromBytes("hello".encodeToByteArray()) + } + val expectedChecksumValue = "abcd" + + val op = newTestOperation(req, Unit) + + op.interceptors.add(TestAbstractChecksumInterceptor(expectedChecksumValue)) + + // the TestAbstractChecksumInterceptor will throw an exception if calculateChecksum is called more than once. + op.roundTrip(client, Unit) + op.roundTrip(client, Unit) + } + + inner class TestAbstractChecksumInterceptor( + private val expectedChecksum: String?, + ) : AbstractChecksumInterceptor() { + private var alreadyCalculatedChecksum = false + + override suspend fun calculateChecksum(context: ProtocolRequestInterceptorContext): String? { + check(!alreadyCalculatedChecksum) { "calculateChecksum was called more than once!" } + return expectedChecksum.also { alreadyCalculatedChecksum = true } + } + + override fun applyChecksum( + context: ProtocolRequestInterceptorContext, + checksum: String, + ): HttpRequest { + val req = context.protocolRequest.toBuilder() + req.header(CHECKSUM_TEST_HEADER, checksum) + return req.build() + } + } +} diff --git a/runtime/protocol/http-client/common/test/aws/smithy/kotlin/runtime/http/interceptors/FlexibleChecksumsRequestInterceptorTest.kt b/runtime/protocol/http-client/common/test/aws/smithy/kotlin/runtime/http/interceptors/FlexibleChecksumsRequestInterceptorTest.kt index 6413cb3bb..c4c85de66 100644 --- a/runtime/protocol/http-client/common/test/aws/smithy/kotlin/runtime/http/interceptors/FlexibleChecksumsRequestInterceptorTest.kt +++ b/runtime/protocol/http-client/common/test/aws/smithy/kotlin/runtime/http/interceptors/FlexibleChecksumsRequestInterceptorTest.kt @@ -126,6 +126,23 @@ class FlexibleChecksumsRequestInterceptorTest { assertEquals(0, call.request.headers.getNumChecksumHeaders()) } + @Test + fun itSetsChecksumHeaderViaExecutionContext() = runTest { + checksums.forEach { (checksumAlgorithmName, expectedChecksumValue) -> + val req = HttpRequestBuilder().apply { + body = HttpBody.fromBytes("bar".encodeToByteArray()) + } + + val op = newTestOperation(req, Unit) + op.context[HttpOperationContext.ChecksumAlgorithm] = checksumAlgorithmName + op.interceptors.add(FlexibleChecksumsRequestInterceptor()) + + op.roundTrip(client, Unit) + val call = op.context.attributes[HttpOperationContext.HttpCallList].first() + assertEquals(expectedChecksumValue, call.request.headers["x-amz-checksum-$checksumAlgorithmName"]) + } + } + @Test fun testCompletingSource() = runTest { val hashFunctionName = "crc32" diff --git a/runtime/runtime-core/api/runtime-core.api b/runtime/runtime-core/api/runtime-core.api index 3210317ca..bfe9a1f31 100644 --- a/runtime/runtime-core/api/runtime-core.api +++ b/runtime/runtime-core/api/runtime-core.api @@ -117,6 +117,16 @@ public final class aws/smithy/kotlin/runtime/collections/CollectionExtKt { public static final fun createOrAppend (Ljava/util/List;Ljava/lang/Object;)Ljava/util/List; } +public final class aws/smithy/kotlin/runtime/collections/LruCache { + public fun (I)V + public final fun get (Ljava/lang/Object;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun getCapacity ()I + public final fun getEntries ()Ljava/util/Set; + public final fun getSize ()I + public final fun put (Ljava/lang/Object;Ljava/lang/Object;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun remove (Ljava/lang/Object;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; +} + public abstract interface class aws/smithy/kotlin/runtime/collections/MultiMap : java/util/Map, kotlin/jvm/internal/markers/KMappedMarker { public abstract fun contains (Ljava/lang/Object;Ljava/lang/Object;)Z public abstract fun getEntryValues ()Lkotlin/sequences/Sequence; diff --git a/runtime/runtime-core/common/src/aws/smithy/kotlin/runtime/collections/LruCache.kt b/runtime/runtime-core/common/src/aws/smithy/kotlin/runtime/collections/LruCache.kt new file mode 100644 index 000000000..a3b738d2f --- /dev/null +++ b/runtime/runtime-core/common/src/aws/smithy/kotlin/runtime/collections/LruCache.kt @@ -0,0 +1,78 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package aws.smithy.kotlin.runtime.collections + +import aws.smithy.kotlin.runtime.InternalApi +import kotlinx.coroutines.sync.Mutex +import kotlinx.coroutines.sync.withLock + +/** + * A thread-safe generic LRU (least recently used) cache. + * Entries will be added up to a configured [capacity]. + * Once full, adding a new entry will evict the least recently used entry. + */ +@InternalApi +public class LruCache( + public val capacity: Int, +) { + private val mu = Mutex() // protects map + private val map = linkedMapOf() + + init { + require(capacity > 0) { "cache capacity must be greater than 0, was $capacity" } + } + + /** + * Returns the value for a key [k], or null if it does not exist. + * @param k the key to look up + * @return the value associated with the key, or null if it does not exist + */ + public suspend fun get(k: K): V? = mu.withLock { + map.moveKeyToBack(k) + return map[k] + } + + /** + * Add or update a cache entry with a key [k] and value [v]. + * @param k the key to associate the value with + * @param v the value to store in the cache + * @return [Unit] + */ + public suspend fun put(k: K, v: V): Unit = mu.withLock { + if (k !in map && map.size == capacity) { + map.remove(map.keys.first()) + } + map[k] = v + map.moveKeyToBack(k) + } + + /** + * Remove an entry associated with a key [k], if it exists. + * @param k the key to remove from the cache + * @return the value removed, or null if it did not exist + */ + public suspend fun remove(k: K): V? = mu.withLock { map.remove(k) } + + /** + * Get a snapshot of the entries in the cache. + * Note: This is not thread-safe! the underlying entries may change immediately after calling. + */ + public val entries: Set> + get() = map.toMap().entries + + /** + * Get the current size of the cache + * Note: This is not thread-safe! The size may change immediately after calling. + */ + public val size: Int + get() = map.size +} + +// Move a key [k] to the back of the map (indicating it is most recently used) +private fun LinkedHashMap.moveKeyToBack(k: K) { + if (containsKey(k)) { + put(k, remove(k)!!) + } +} diff --git a/runtime/runtime-core/common/src/aws/smithy/kotlin/runtime/hashing/Crc32.kt b/runtime/runtime-core/common/src/aws/smithy/kotlin/runtime/hashing/Crc32.kt index 48704fcf9..e86402b9e 100644 --- a/runtime/runtime-core/common/src/aws/smithy/kotlin/runtime/hashing/Crc32.kt +++ b/runtime/runtime-core/common/src/aws/smithy/kotlin/runtime/hashing/Crc32.kt @@ -33,7 +33,7 @@ public abstract class Crc32Base : HashFunction { public expect class Crc32() : Crc32Base /** - * Compute the MD5 hash of the current [ByteArray] + * Compute the CRC32 checksum of the given [ByteArray] */ @InternalApi public fun ByteArray.crc32(): UInt = Crc32().apply { update(this@crc32) }.digestValue() diff --git a/runtime/runtime-core/common/test/aws/smithy/kotlin/runtime/collections/LruCacheTest.kt b/runtime/runtime-core/common/test/aws/smithy/kotlin/runtime/collections/LruCacheTest.kt new file mode 100644 index 000000000..89fddcb4e --- /dev/null +++ b/runtime/runtime-core/common/test/aws/smithy/kotlin/runtime/collections/LruCacheTest.kt @@ -0,0 +1,82 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package aws.smithy.kotlin.runtime.collections + +import kotlinx.coroutines.test.runTest +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith + +public class LruCacheTest { + @Test + fun testGetAndPut() = runTest { + val cache = LruCache(1) + cache.put("a", "1") + assertEquals(cache.get("a"), "1") + } + + @Test + fun testGetShouldUpdateLruStatus() = runTest { + val cache = LruCache(3) + cache.put("a", "1") + cache.put("b", "2") + cache.put("c", "3") + + assertEquals(3, cache.size) + assertEquals("a, b, c", cache.entries.joinToString { it.key }) + + cache.get("a") + assertEquals("b, c, a", cache.entries.joinToString { it.key }) + } + + @Test + fun testPutShouldUpdateLruStatus() = runTest { + val cache = LruCache(3) + cache.put("a", "1") + cache.put("b", "2") + cache.put("c", "3") + + assertEquals(3, cache.size) + assertEquals("a, b, c", cache.entries.joinToString { it.key }) + + cache.put("a", "4") + assertEquals("b, c, a", cache.entries.joinToString { it.key }) + } + + @Test + fun testEviction() = runTest { + val cache = LruCache(2) + cache.put("a", "1") + cache.put("b", "2") + assertEquals(2, cache.size) + + cache.put("c", "3") + assertEquals(2, cache.size) + assertEquals("b, c", cache.entries.joinToString { it.key }) + } + + @Test + fun testUpdatingKeyWhenCacheIsFullDoesNotEvict() = runTest { + val cache = LruCache(2) + + cache.put("a", 1) + + cache.put("b", 2) + cache.put("b", 3) + + assertEquals(2, cache.size) + } + + @Test + fun testCapacity() = runTest { + assertFailsWith { + LruCache(-1) + } + assertFailsWith { + LruCache(0) + } + LruCache(1) + } +}