Skip to content

Commit

Permalink
feat: support S3 Express One Zone (#1033)
Browse files Browse the repository at this point in the history
  • Loading branch information
lauzadis authored Feb 28, 2024
1 parent 4a20344 commit 7471609
Show file tree
Hide file tree
Showing 23 changed files with 462 additions and 121 deletions.
5 changes: 5 additions & 0 deletions .changes/d2f0f8cb-b94d-403e-9db8-5cd61bd8eb1b.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"id": "d2f0f8cb-b94d-403e-9db8-5cd61bd8eb1b",
"type": "feature",
"description": "Add support for S3 Express One Zone"
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,7 @@ package software.amazon.smithy.kotlin.codegen.rendering.auth

import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.kotlin.codegen.KotlinSettings
import software.amazon.smithy.kotlin.codegen.core.KotlinWriter
import software.amazon.smithy.kotlin.codegen.core.RuntimeTypes
import software.amazon.smithy.kotlin.codegen.core.clientName
import software.amazon.smithy.kotlin.codegen.core.withBlock
import software.amazon.smithy.kotlin.codegen.core.*
import software.amazon.smithy.kotlin.codegen.model.buildSymbol
import software.amazon.smithy.kotlin.codegen.model.knowledge.AuthIndex
import software.amazon.smithy.kotlin.codegen.rendering.protocol.ProtocolGenerator
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,29 +186,36 @@ private object Sigv4EndpointCustomization : EndpointCustomization {
// SigV4a requires SigV4 so SigV4 integration renders SigV4a auth scheme.
// See comment in example model: https://smithy.io/2.0/aws/aws-auth.html?highlight=sigv4#aws-auth-sigv4a-trait
private fun renderAuthSchemes(writer: KotlinWriter, authSchemes: Expression, expressionRenderer: ExpressionRenderer) {
writer.writeInline("#T to ", RuntimeTypes.SmithyClient.Endpoints.SigningContextAttributeKey)
writer.withBlock("listOf(", ")") {
authSchemes.toNode().expectArrayNode().forEach {
val scheme = it.expectObjectNode()
val schemeName = scheme.expectStringMember("name").value

val authFactoryFn = when (schemeName) {
"sigv4" -> RuntimeTypes.Auth.HttpAuthAws.sigV4
"sigv4a" -> RuntimeTypes.Auth.HttpAuthAws.sigV4A
else -> return@forEach
}

withBlock("#T(", "),", authFactoryFn) {
// we delegate back to the expression visitor for each of these fields because it's possible to
// encounter template strings throughout

writeInline("serviceName = ")
renderOrElse(expressionRenderer, scheme.getStringMember("signingName"), "null")

writeInline("disableDoubleUriEncode = ")
renderOrElse(expressionRenderer, scheme.getBooleanMember("disableDoubleEncoding"), "false")

renderFieldsForScheme(writer, scheme, expressionRenderer)
val schemes = authSchemes.toNode().expectArrayNode().filter {
val name = it.expectObjectNode().expectStringMember("name").value
name == "sigv4" || name == "sigv4a"
}.takeIf { it.isNotEmpty() }

schemes?.let {
writer.writeInline("#T to ", RuntimeTypes.SmithyClient.Endpoints.SigningContextAttributeKey)
writer.withBlock("listOf(", ")") {
schemes.forEach {
val scheme = it.expectObjectNode()
val schemeName = scheme.expectStringMember("name").value

val authFactoryFn = when (schemeName) {
"sigv4" -> RuntimeTypes.Auth.HttpAuthAws.sigV4
"sigv4a" -> RuntimeTypes.Auth.HttpAuthAws.sigV4A
else -> return@forEach
}

withBlock("#T(", "),", authFactoryFn) {
// we delegate back to the expression visitor for each of these fields because it's possible to
// encounter template strings throughout

writeInline("serviceName = ")
renderOrElse(expressionRenderer, scheme.getStringMember("signingName"), "null")

writeInline("disableDoubleUriEncode = ")
renderOrElse(expressionRenderer, scheme.getBooleanMember("disableDoubleEncoding"), "false")

renderFieldsForScheme(writer, scheme, expressionRenderer)
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,10 @@ class DefaultEndpointProviderGenerator(

private val propertyRenderers = endpointCustomizations
.map { it.propertyRenderers }
.fold(mutableMapOf<String, EndpointPropertyRenderer>()) { acc, propRenderers ->
acc.putAll(propRenderers)
.fold(mutableMapOf<String, MutableList<EndpointPropertyRenderer>>()) { acc, propRenderers ->
propRenderers.forEach { (key, propRenderer) ->
acc[key] = acc.getOrDefault(key, mutableListOf()).also { it.add(propRenderer) }
}
acc
}

Expand Down Expand Up @@ -190,7 +192,9 @@ class DefaultEndpointProviderGenerator(

// caller has a chance to generate their own value for a recognized property
if (kStr in propertyRenderers) {
propertyRenderers[kStr]!!(writer, v, this@DefaultEndpointProviderGenerator)
propertyRenderers[kStr]!!.forEach { renderer ->
renderer(writer, v, this@DefaultEndpointProviderGenerator)
}
return@forEach
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,10 @@ class DefaultEndpointProviderTestGenerator(
private val endpointCustomizations = ctx.integrations.mapNotNull { it.customizeEndpointResolution(ctx) }
private val propertyRenderers = endpointCustomizations
.map { it.propertyRenderers }
.fold(mutableMapOf<String, EndpointPropertyRenderer>()) { acc, propRenderers ->
acc.putAll(propRenderers)
.fold(mutableMapOf<String, MutableList<EndpointPropertyRenderer>>()) { acc, propRenderers ->
propRenderers.forEach { (key, propRenderer) ->
acc[key] = acc.getOrDefault(key, mutableListOf()).also { it.add(propRenderer) }
}
acc
}

Expand Down Expand Up @@ -131,7 +133,9 @@ class DefaultEndpointProviderTestGenerator(
withBlock("attributes = #T {", "},", RuntimeTypes.Core.Collections.attributesOf) {
endpoint.properties.entries.forEach { (k, v) ->
if (k in propertyRenderers) {
propertyRenderers[k]!!(writer, Expression.fromNode(v), this@DefaultEndpointProviderTestGenerator)
propertyRenderers[k]!!.forEach { renderer ->
renderer(writer, Expression.fromNode(v), this@DefaultEndpointProviderTestGenerator)
}
return@forEach
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -348,19 +348,12 @@ open class HttpProtocolClientGenerator(
return
}

val requestAlgorithmMember = ctx.model.getShape(input.get()).getOrNull()
?.members()
?.firstOrNull { it.memberName == httpChecksumTrait?.requestAlgorithmMember?.getOrNull() }

if (hasTrait<HttpChecksumRequiredTrait>() || httpChecksumTrait?.isRequestChecksumRequired == true) {
val interceptorSymbol = RuntimeTypes.HttpClient.Interceptors.Md5ChecksumInterceptor
val inputSymbol = ctx.symbolProvider.toSymbol(ctx.model.expectShape(inputShape))

requestAlgorithmMember?.let {
writer.withBlock("op.interceptors.add(#T<#T> { ", "})", interceptorSymbol, inputSymbol) {
writer.write("it.#L?.value == null", requestAlgorithmMember.defaultName())
}
} ?: writer.write("op.interceptors.add(#T<#T>())", interceptorSymbol, inputSymbol)
writer.withBlock("op.interceptors.add(#T<#T> {", "})", interceptorSymbol, inputSymbol) {
writer.write("op.context.getOrNull(#T.ChecksumAlgorithm) == null", RuntimeTypes.HttpClient.Operation.HttpOperationContext)
}
}
}

Expand Down
2 changes: 1 addition & 1 deletion gradle/libs.versions.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ okio-version = "3.6.0"
otel-version = "1.32.0"
slf4j-version = "2.0.9"
slf4j-v1x-version = "1.7.36"
crt-kotlin-version = "0.8.2"
crt-kotlin-version = "0.8.5"

# codegen
smithy-version = "1.42.0"
Expand Down
1 change: 1 addition & 0 deletions runtime/auth/aws-signing-common/api/aws-signing-common.api
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ public final class aws/smithy/kotlin/runtime/auth/awssigning/AwsSigningAttribute
public final fun getEnableAwsChunked ()Laws/smithy/kotlin/runtime/collections/AttributeKey;
public final fun getHashSpecification ()Laws/smithy/kotlin/runtime/collections/AttributeKey;
public final fun getNormalizeUriPath ()Laws/smithy/kotlin/runtime/collections/AttributeKey;
public final fun getOmitSessionToken ()Laws/smithy/kotlin/runtime/collections/AttributeKey;
public final fun getRequestSignature ()Laws/smithy/kotlin/runtime/collections/AttributeKey;
public final fun getSignedBodyHeader ()Laws/smithy/kotlin/runtime/collections/AttributeKey;
public final fun getSigner ()Laws/smithy/kotlin/runtime/collections/AttributeKey;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,4 +86,9 @@ public object AwsSigningAttributes {
* @see <a href="https://docs.aws.amazon.com/AmazonS3/latest/API/sigv4-streaming.html">SigV4 Streaming</a>
*/
public val EnableAwsChunked: AttributeKey<Boolean> = AttributeKey("aws.smithy.kotlin.signing#EnableAwsChunked")

/**
* Flag indicating whether the X-Amz-Security-Token header should be omitted from the canonical request during signing.
*/
public val OmitSessionToken: AttributeKey<Boolean> = AttributeKey("aws.smithy.kotlin.signing#OmitSessionToken")
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,30 @@ import aws.sdk.kotlin.crt.auth.signing.AwsSigningAlgorithm as CrtSigningAlgorith
import aws.sdk.kotlin.crt.auth.signing.AwsSigningConfig as CrtSigningConfig
import aws.sdk.kotlin.crt.http.Headers as CrtHeaders

private const val S3_EXPRESS_HEADER_NAME = "X-Amz-S3session-Token"

public object CrtAwsSigner : AwsSigner {
override suspend fun sign(request: HttpRequest, config: AwsSigningConfig): AwsSigningResult<HttpRequest> {
val isUnsigned = config.hashSpecification is HashSpecification.UnsignedPayload
val isAwsChunked = request.headers.contains("Content-Encoding", "aws-chunked")
val crtRequest = request.toSignableCrtRequest(isUnsigned, isAwsChunked)
val crtConfig = config.toCrtSigningConfig()
val isS3Express = request.headers.contains(S3_EXPRESS_HEADER_NAME)

val requestBuilder = request.toBuilder()

val crtResult = CrtSigner.sign(crtRequest, crtConfig)
val crtConfig = config.toCrtSigningConfig().toBuilder()
if (isS3Express) {
crtConfig.algorithm = CrtSigningAlgorithm.SIGV4_S3EXPRESS
crtConfig.omitSessionToken = false
requestBuilder.headers.remove(S3_EXPRESS_HEADER_NAME) // CRT signer fails if this header is already present
}

val crtRequest = requestBuilder.build().toSignableCrtRequest(isUnsigned, isAwsChunked)

val crtResult = CrtSigner.sign(crtRequest, crtConfig.build())
coroutineContext.debug<CrtAwsSigner> { "Calculated signature: ${crtResult.signature.decodeToString()}" }

val crtSignedResult = checkNotNull(crtResult.signedRequest) { "Signed request unexpectedly null" }

val requestBuilder = request.toBuilder()
requestBuilder.update(crtSignedResult)
return AwsSigningResult(requestBuilder.build(), crtResult.signature)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ public class AwsHttpSigner(private val config: Config) : HttpSigner {
val contextUseDoubleUriEncode = attributes.getOrNull(AwsSigningAttributes.UseDoubleUriEncode)
val contextNormalizeUriPath = attributes.getOrNull(AwsSigningAttributes.NormalizeUriPath)
val contextSigningServiceName = attributes.getOrNull(AwsSigningAttributes.SigningService)
val contextOmitSessionToken = attributes.getOrNull(AwsSigningAttributes.OmitSessionToken)

val enableAwsChunked = attributes.getOrNull(AwsSigningAttributes.EnableAwsChunked) ?: false

Expand All @@ -143,7 +144,7 @@ public class AwsHttpSigner(private val config: Config) : HttpSigner {
?: (Instant.now() + (attributes.getOrNull(HttpOperationContext.ClockSkew) ?: Duration.ZERO))

signatureType = config.signatureType
omitSessionToken = config.omitSessionToken
omitSessionToken = contextOmitSessionToken ?: config.omitSessionToken
normalizeUriPath = contextNormalizeUriPath ?: config.normalizeUriPath
useDoubleUriEncode = contextUseDoubleUriEncode ?: config.useDoubleUriEncode
expiresAfter = config.expiresAfter
Expand Down
Loading

0 comments on commit 7471609

Please sign in to comment.