From f0c27da166db11dc74288f237d4b5742ee7902ea Mon Sep 17 00:00:00 2001 From: Ian Botsford <83236726+ianbotsf@users.noreply.github.com> Date: Wed, 23 Nov 2022 06:57:07 -0800 Subject: [PATCH] fix: correctly deserialize lists of documents (#748) --- .../2032f172-904d-474b-b554-4d7fa04e160c.json | 5 ++ .../smithy/kotlin/codegen/lang/KotlinTypes.kt | 2 + .../serde/DeserializeStructGenerator.kt | 65 +++++++++++++------ .../serde/DeserializeStructGeneratorTest.kt | 40 ++++++++++++ 4 files changed, 92 insertions(+), 20 deletions(-) create mode 100644 .changes/2032f172-904d-474b-b554-4d7fa04e160c.json diff --git a/.changes/2032f172-904d-474b-b554-4d7fa04e160c.json b/.changes/2032f172-904d-474b-b554-4d7fa04e160c.json new file mode 100644 index 000000000..458674b2d --- /dev/null +++ b/.changes/2032f172-904d-474b-b554-4d7fa04e160c.json @@ -0,0 +1,5 @@ +{ + "id": "2032f172-904d-474b-b554-4d7fa04e160c", + "type": "bugfix", + "description": "Fix deserialization error for shapes with lists of document types" +} \ No newline at end of file diff --git a/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/lang/KotlinTypes.kt b/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/lang/KotlinTypes.kt index 967fa9251..5ae8fb3dd 100644 --- a/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/lang/KotlinTypes.kt +++ b/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/lang/KotlinTypes.kt @@ -38,6 +38,8 @@ object KotlinTypes { val List: Symbol = builtInSymbol("List", "kotlin.collections") val Set: Symbol = builtInSymbol("Set", "kotlin.collections") val Map: Symbol = builtInSymbol("Map", "kotlin.collections") + val mutableListOf: Symbol = builtInSymbol("mutableListOf", "kotlin.collections") + val mutableMapOf: Symbol = builtInSymbol("mutableMapOf", "kotlin.collections") } object Jvm { diff --git a/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/DeserializeStructGenerator.kt b/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/DeserializeStructGenerator.kt index 942a05d75..5bb455d42 100644 --- a/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/DeserializeStructGenerator.kt +++ b/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/DeserializeStructGenerator.kt @@ -6,6 +6,7 @@ package software.amazon.smithy.kotlin.codegen.rendering.serde import software.amazon.smithy.codegen.core.CodegenException import software.amazon.smithy.kotlin.codegen.core.* +import software.amazon.smithy.kotlin.codegen.lang.KotlinTypes import software.amazon.smithy.kotlin.codegen.model.* import software.amazon.smithy.kotlin.codegen.rendering.protocol.ProtocolGenerator import software.amazon.smithy.model.shapes.* @@ -137,7 +138,6 @@ open class DeserializeStructGenerator( val nestingLevel = 0 val memberName = ctx.symbolProvider.toMemberName(memberShape) val descriptorName = memberShape.descriptorName() - val mutableCollectionType = targetShape.mutableCollectionType() val valueCollector = deserializationResultName("builder.$memberName") val mutableCollectionName = nestingLevel.variableNameFor(NestedIdentifierType.MAP) val collectionReturnExpression = collectionReturnExpression(memberShape, mutableCollectionName) @@ -145,7 +145,13 @@ open class DeserializeStructGenerator( writer.write("$descriptorName.index -> $valueCollector = ") .indent() .withBlock("deserializer.#T($descriptorName) {", "}", RuntimeTypes.Serde.deserializeMap) { - write("val $mutableCollectionName = $mutableCollectionType()") + write( + "val #L = #T()", + mutableCollectionName, + KotlinTypes.Collections.mutableMapOf, + ctx.symbolProvider.toSymbol(targetShape.value), + nullabilitySuffix(targetShape.isSparse), + ) withBlock("while (hasNextEntry()) {", "}") { delegateMapDeserialization(memberShape, targetShape, nestingLevel, mutableCollectionName) } @@ -252,7 +258,6 @@ open class DeserializeStructGenerator( val valueName = nestingLevel.variableNameFor(NestedIdentifierType.VALUE) val populateNullValuePostfix = if (isSparse) "" else "; continue" val descriptorName = rootMemberShape.descriptorName(nestingLevel.nestedDescriptorName()) - val mutableCollectionType = mapShape.mutableCollectionType() val nextNestingLevel = nestingLevel + 1 val memberName = nextNestingLevel.variableNameFor(NestedIdentifierType.MAP) val collectionReturnExpression = collectionReturnExpression(rootMemberShape, memberName) @@ -261,7 +266,13 @@ open class DeserializeStructGenerator( writer.withBlock("val $valueName =", "") { withBlock("if (nextHasValue()) {", "} else { deserializeNull()$populateNullValuePostfix }") { withBlock("deserializer.#T($descriptorName) {", "}", RuntimeTypes.Serde.deserializeMap) { - write("val $memberName = $mutableCollectionType()") + write( + "val #L = #T()", + memberName, + KotlinTypes.Collections.mutableMapOf, + ctx.symbolProvider.toSymbol(mapShape.value), + nullabilitySuffix(mapShape.isSparse), + ) withBlock("while (hasNextEntry()) {", "}") { delegateMapDeserialization(rootMemberShape, mapShape, nextNestingLevel, memberName) } @@ -298,7 +309,6 @@ open class DeserializeStructGenerator( val valueName = nestingLevel.variableNameFor(NestedIdentifierType.VALUE) val populateNullValuePostfix = if (isSparse) "" else "; continue" val descriptorName = rootMemberShape.descriptorName(nestingLevel.nestedDescriptorName()) - val mutableCollectionType = collectionShape.mutableCollectionType() val nextNestingLevel = nestingLevel + 1 val memberName = nextNestingLevel.variableNameFor(NestedIdentifierType.COLLECTION) val collectionReturnExpression = collectionReturnExpression(rootMemberShape, memberName) @@ -307,7 +317,13 @@ open class DeserializeStructGenerator( writer.withBlock("val $valueName =", "") { withBlock("if (nextHasValue()) {", "} else { deserializeNull()$populateNullValuePostfix }") { withBlock("deserializer.#T($descriptorName) {", "}", RuntimeTypes.Serde.deserializeList) { - write("val $memberName = $mutableCollectionType()") + write( + "val #L = #T<#T#L>()", + memberName, + KotlinTypes.Collections.mutableListOf, + ctx.symbolProvider.toSymbol(collectionShape.member), + nullabilitySuffix(collectionShape.isSparse), + ) withBlock("while (hasNextElement()) {", "}") { delegateListDeserialization(rootMemberShape, collectionShape, nextNestingLevel, memberName) } @@ -353,7 +369,6 @@ open class DeserializeStructGenerator( val nestingLevel = 0 val memberName = ctx.symbolProvider.toMemberName(memberShape) val descriptorName = memberShape.descriptorName() - val mutableCollectionType = targetShape.mutableCollectionType() val valueCollector = deserializationResultName("builder.$memberName") val mutableCollectionName = nestingLevel.variableNameFor(NestedIdentifierType.COLLECTION) val collectionReturnExpression = collectionReturnExpression(memberShape, mutableCollectionName) @@ -361,7 +376,13 @@ open class DeserializeStructGenerator( writer.write("$descriptorName.index -> $valueCollector = ") .indent() .withBlock("deserializer.#T($descriptorName) {", "}", RuntimeTypes.Serde.deserializeList) { - write("val $mutableCollectionName = $mutableCollectionType()") + write( + "val #L = #T<#T#L>()", + mutableCollectionName, + KotlinTypes.Collections.mutableListOf, + ctx.symbolProvider.toSymbol(targetShape.member), + nullabilitySuffix(targetShape.isSparse), + ) withBlock("while (hasNextElement()) {", "}") { delegateListDeserialization(memberShape, targetShape, nestingLevel, mutableCollectionName) } @@ -454,11 +475,16 @@ open class DeserializeStructGenerator( val elementName = nestingLevel.variableNameFor(NestedIdentifierType.ELEMENT) val nextNestingLevel = nestingLevel + 1 val mapName = nextNestingLevel.variableNameFor(NestedIdentifierType.MAP) - val mutableCollectionType = mapShape.mutableCollectionType() val collectionReturnExpression = collectionReturnExpression(rootMemberShape, mapName) writer.withBlock("val $elementName = deserializer.#T($descriptorName) {", "}", RuntimeTypes.Serde.deserializeMap) { - write("val $mapName = $mutableCollectionType()") + write( + "val #L = #T()", + mapName, + KotlinTypes.Collections.mutableMapOf, + ctx.symbolProvider.toSymbol(mapShape.value), + nullabilitySuffix(mapShape.isSparse), + ) withBlock("while (hasNextEntry()) {", "}") { delegateMapDeserialization(rootMemberShape, mapShape, nextNestingLevel, mapName) } @@ -487,11 +513,16 @@ open class DeserializeStructGenerator( val elementName = nestingLevel.variableNameFor(NestedIdentifierType.ELEMENT) val nextNestingLevel = nestingLevel + 1 val listName = nextNestingLevel.variableNameFor(NestedIdentifierType.COLLECTION) - val mutableCollectionType = elementShape.mutableCollectionType() val collectionReturnExpression = collectionReturnExpression(rootMemberShape, listName) writer.withBlock("val $elementName = deserializer.#T($descriptorName) {", "}", RuntimeTypes.Serde.deserializeList) { - write("val $listName = $mutableCollectionType()") + write( + "val #L = #T<#T#L>()", + listName, + KotlinTypes.Collections.mutableListOf, + ctx.symbolProvider.toSymbol(elementShape.member), + nullabilitySuffix(elementShape.isSparse), + ) withBlock("while (hasNextElement()) {", "}") { delegateListDeserialization(rootMemberShape, elementShape, nextNestingLevel, listName) } @@ -569,12 +600,6 @@ open class DeserializeStructGenerator( else -> throw CodegenException("unknown deserializer for member: $shape; target: $target") } } - - // Return the function to generate a mutable instance of collection type of input shape. - private fun MapShape.mutableCollectionType(): String = - ctx.symbolProvider.toSymbol(this).getProperty(SymbolProperty.MUTABLE_COLLECTION_FUNCTION).get() as String - - // Return the function to generate a mutable instance of collection type of input shape. - private fun CollectionShape.mutableCollectionType(): String = - ctx.symbolProvider.toSymbol(this).getProperty(SymbolProperty.MUTABLE_COLLECTION_FUNCTION).get() as String } + +private fun nullabilitySuffix(isSparse: Boolean): String = if (isSparse) "?" else "" diff --git a/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/DeserializeStructGeneratorTest.kt b/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/DeserializeStructGeneratorTest.kt index 27fa5f941..3730a144b 100644 --- a/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/DeserializeStructGeneratorTest.kt +++ b/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/DeserializeStructGeneratorTest.kt @@ -118,6 +118,46 @@ class DeserializeStructGeneratorTest { actual.shouldContainOnlyOnceWithDiff(expected) } + @Test + fun `it deserializes a structure with a list of document values`() { + val model = ( + modelPrefix + """ + structure FooResponse { + payload: DocumentList + } + + list DocumentList { + member: Document + } + """ + ).toSmithyModel() + + val expected = """ + deserializer.deserializeStruct(OBJ_DESCRIPTOR) { + loop@while (true) { + when (findNextFieldIndex()) { + PAYLOAD_DESCRIPTOR.index -> builder.payload = + deserializer.deserializeList(PAYLOAD_DESCRIPTOR) { + val col0 = mutableListOf() + while (hasNextElement()) { + val el0 = if (nextHasValue()) { deserializeDocument() } else { deserializeNull(); continue } + col0.add(el0) + } + col0 + } + null -> break@loop + else -> skipValue() + } + } + } + """.trimIndent() + + val actual = codegenDeserializerForShape(model, "com.test#Foo") + + actual.shouldContainOnlyOnceWithDiff(expected) + actual.shouldContainOnlyOnceWithDiff("import aws.smithy.kotlin.runtime.smithy.Document") + } + @Test fun `it deserializes a structure with a nested structure`() { val model = (