Skip to content

Commit

Permalink
fix: correctly deserialize lists of documents (#748)
Browse files Browse the repository at this point in the history
  • Loading branch information
ianbotsf authored Nov 23, 2022
1 parent 1f954c8 commit f0c27da
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 20 deletions.
5 changes: 5 additions & 0 deletions .changes/2032f172-904d-474b-b554-4d7fa04e160c.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"id": "2032f172-904d-474b-b554-4d7fa04e160c",
"type": "bugfix",
"description": "Fix deserialization error for shapes with lists of document types"
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ object KotlinTypes {
val List: Symbol = builtInSymbol("List", "kotlin.collections")
val Set: Symbol = builtInSymbol("Set", "kotlin.collections")
val Map: Symbol = builtInSymbol("Map", "kotlin.collections")
val mutableListOf: Symbol = builtInSymbol("mutableListOf", "kotlin.collections")
val mutableMapOf: Symbol = builtInSymbol("mutableMapOf", "kotlin.collections")
}

object Jvm {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package software.amazon.smithy.kotlin.codegen.rendering.serde

import software.amazon.smithy.codegen.core.CodegenException
import software.amazon.smithy.kotlin.codegen.core.*
import software.amazon.smithy.kotlin.codegen.lang.KotlinTypes
import software.amazon.smithy.kotlin.codegen.model.*
import software.amazon.smithy.kotlin.codegen.rendering.protocol.ProtocolGenerator
import software.amazon.smithy.model.shapes.*
Expand Down Expand Up @@ -137,15 +138,20 @@ open class DeserializeStructGenerator(
val nestingLevel = 0
val memberName = ctx.symbolProvider.toMemberName(memberShape)
val descriptorName = memberShape.descriptorName()
val mutableCollectionType = targetShape.mutableCollectionType()
val valueCollector = deserializationResultName("builder.$memberName")
val mutableCollectionName = nestingLevel.variableNameFor(NestedIdentifierType.MAP)
val collectionReturnExpression = collectionReturnExpression(memberShape, mutableCollectionName)

writer.write("$descriptorName.index -> $valueCollector = ")
.indent()
.withBlock("deserializer.#T($descriptorName) {", "}", RuntimeTypes.Serde.deserializeMap) {
write("val $mutableCollectionName = $mutableCollectionType()")
write(
"val #L = #T<String, #T#L>()",
mutableCollectionName,
KotlinTypes.Collections.mutableMapOf,
ctx.symbolProvider.toSymbol(targetShape.value),
nullabilitySuffix(targetShape.isSparse),
)
withBlock("while (hasNextEntry()) {", "}") {
delegateMapDeserialization(memberShape, targetShape, nestingLevel, mutableCollectionName)
}
Expand Down Expand Up @@ -252,7 +258,6 @@ open class DeserializeStructGenerator(
val valueName = nestingLevel.variableNameFor(NestedIdentifierType.VALUE)
val populateNullValuePostfix = if (isSparse) "" else "; continue"
val descriptorName = rootMemberShape.descriptorName(nestingLevel.nestedDescriptorName())
val mutableCollectionType = mapShape.mutableCollectionType()
val nextNestingLevel = nestingLevel + 1
val memberName = nextNestingLevel.variableNameFor(NestedIdentifierType.MAP)
val collectionReturnExpression = collectionReturnExpression(rootMemberShape, memberName)
Expand All @@ -261,7 +266,13 @@ open class DeserializeStructGenerator(
writer.withBlock("val $valueName =", "") {
withBlock("if (nextHasValue()) {", "} else { deserializeNull()$populateNullValuePostfix }") {
withBlock("deserializer.#T($descriptorName) {", "}", RuntimeTypes.Serde.deserializeMap) {
write("val $memberName = $mutableCollectionType()")
write(
"val #L = #T<String, #T#L>()",
memberName,
KotlinTypes.Collections.mutableMapOf,
ctx.symbolProvider.toSymbol(mapShape.value),
nullabilitySuffix(mapShape.isSparse),
)
withBlock("while (hasNextEntry()) {", "}") {
delegateMapDeserialization(rootMemberShape, mapShape, nextNestingLevel, memberName)
}
Expand Down Expand Up @@ -298,7 +309,6 @@ open class DeserializeStructGenerator(
val valueName = nestingLevel.variableNameFor(NestedIdentifierType.VALUE)
val populateNullValuePostfix = if (isSparse) "" else "; continue"
val descriptorName = rootMemberShape.descriptorName(nestingLevel.nestedDescriptorName())
val mutableCollectionType = collectionShape.mutableCollectionType()
val nextNestingLevel = nestingLevel + 1
val memberName = nextNestingLevel.variableNameFor(NestedIdentifierType.COLLECTION)
val collectionReturnExpression = collectionReturnExpression(rootMemberShape, memberName)
Expand All @@ -307,7 +317,13 @@ open class DeserializeStructGenerator(
writer.withBlock("val $valueName =", "") {
withBlock("if (nextHasValue()) {", "} else { deserializeNull()$populateNullValuePostfix }") {
withBlock("deserializer.#T($descriptorName) {", "}", RuntimeTypes.Serde.deserializeList) {
write("val $memberName = $mutableCollectionType()")
write(
"val #L = #T<#T#L>()",
memberName,
KotlinTypes.Collections.mutableListOf,
ctx.symbolProvider.toSymbol(collectionShape.member),
nullabilitySuffix(collectionShape.isSparse),
)
withBlock("while (hasNextElement()) {", "}") {
delegateListDeserialization(rootMemberShape, collectionShape, nextNestingLevel, memberName)
}
Expand Down Expand Up @@ -353,15 +369,20 @@ open class DeserializeStructGenerator(
val nestingLevel = 0
val memberName = ctx.symbolProvider.toMemberName(memberShape)
val descriptorName = memberShape.descriptorName()
val mutableCollectionType = targetShape.mutableCollectionType()
val valueCollector = deserializationResultName("builder.$memberName")
val mutableCollectionName = nestingLevel.variableNameFor(NestedIdentifierType.COLLECTION)
val collectionReturnExpression = collectionReturnExpression(memberShape, mutableCollectionName)

writer.write("$descriptorName.index -> $valueCollector = ")
.indent()
.withBlock("deserializer.#T($descriptorName) {", "}", RuntimeTypes.Serde.deserializeList) {
write("val $mutableCollectionName = $mutableCollectionType()")
write(
"val #L = #T<#T#L>()",
mutableCollectionName,
KotlinTypes.Collections.mutableListOf,
ctx.symbolProvider.toSymbol(targetShape.member),
nullabilitySuffix(targetShape.isSparse),
)
withBlock("while (hasNextElement()) {", "}") {
delegateListDeserialization(memberShape, targetShape, nestingLevel, mutableCollectionName)
}
Expand Down Expand Up @@ -454,11 +475,16 @@ open class DeserializeStructGenerator(
val elementName = nestingLevel.variableNameFor(NestedIdentifierType.ELEMENT)
val nextNestingLevel = nestingLevel + 1
val mapName = nextNestingLevel.variableNameFor(NestedIdentifierType.MAP)
val mutableCollectionType = mapShape.mutableCollectionType()
val collectionReturnExpression = collectionReturnExpression(rootMemberShape, mapName)

writer.withBlock("val $elementName = deserializer.#T($descriptorName) {", "}", RuntimeTypes.Serde.deserializeMap) {
write("val $mapName = $mutableCollectionType()")
write(
"val #L = #T<String, #T#L>()",
mapName,
KotlinTypes.Collections.mutableMapOf,
ctx.symbolProvider.toSymbol(mapShape.value),
nullabilitySuffix(mapShape.isSparse),
)
withBlock("while (hasNextEntry()) {", "}") {
delegateMapDeserialization(rootMemberShape, mapShape, nextNestingLevel, mapName)
}
Expand Down Expand Up @@ -487,11 +513,16 @@ open class DeserializeStructGenerator(
val elementName = nestingLevel.variableNameFor(NestedIdentifierType.ELEMENT)
val nextNestingLevel = nestingLevel + 1
val listName = nextNestingLevel.variableNameFor(NestedIdentifierType.COLLECTION)
val mutableCollectionType = elementShape.mutableCollectionType()
val collectionReturnExpression = collectionReturnExpression(rootMemberShape, listName)

writer.withBlock("val $elementName = deserializer.#T($descriptorName) {", "}", RuntimeTypes.Serde.deserializeList) {
write("val $listName = $mutableCollectionType()")
write(
"val #L = #T<#T#L>()",
listName,
KotlinTypes.Collections.mutableListOf,
ctx.symbolProvider.toSymbol(elementShape.member),
nullabilitySuffix(elementShape.isSparse),
)
withBlock("while (hasNextElement()) {", "}") {
delegateListDeserialization(rootMemberShape, elementShape, nextNestingLevel, listName)
}
Expand Down Expand Up @@ -569,12 +600,6 @@ open class DeserializeStructGenerator(
else -> throw CodegenException("unknown deserializer for member: $shape; target: $target")
}
}

// Return the function to generate a mutable instance of collection type of input shape.
private fun MapShape.mutableCollectionType(): String =
ctx.symbolProvider.toSymbol(this).getProperty(SymbolProperty.MUTABLE_COLLECTION_FUNCTION).get() as String

// Return the function to generate a mutable instance of collection type of input shape.
private fun CollectionShape.mutableCollectionType(): String =
ctx.symbolProvider.toSymbol(this).getProperty(SymbolProperty.MUTABLE_COLLECTION_FUNCTION).get() as String
}

private fun nullabilitySuffix(isSparse: Boolean): String = if (isSparse) "?" else ""
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,46 @@ class DeserializeStructGeneratorTest {
actual.shouldContainOnlyOnceWithDiff(expected)
}

@Test
fun `it deserializes a structure with a list of document values`() {
val model = (
modelPrefix + """
structure FooResponse {
payload: DocumentList
}
list DocumentList {
member: Document
}
"""
).toSmithyModel()

val expected = """
deserializer.deserializeStruct(OBJ_DESCRIPTOR) {
loop@while (true) {
when (findNextFieldIndex()) {
PAYLOAD_DESCRIPTOR.index -> builder.payload =
deserializer.deserializeList(PAYLOAD_DESCRIPTOR) {
val col0 = mutableListOf<Document>()
while (hasNextElement()) {
val el0 = if (nextHasValue()) { deserializeDocument() } else { deserializeNull(); continue }
col0.add(el0)
}
col0
}
null -> break@loop
else -> skipValue()
}
}
}
""".trimIndent()

val actual = codegenDeserializerForShape(model, "com.test#Foo")

actual.shouldContainOnlyOnceWithDiff(expected)
actual.shouldContainOnlyOnceWithDiff("import aws.smithy.kotlin.runtime.smithy.Document")
}

@Test
fun `it deserializes a structure with a nested structure`() {
val model = (
Expand Down

0 comments on commit f0c27da

Please sign in to comment.