From fb865b243a8cc39fae89ae7f7357ac863326d8dd Mon Sep 17 00:00:00 2001 From: Raphael Simon Date: Sat, 9 Mar 2024 16:34:47 -0800 Subject: [PATCH] Properly handle decoding of maps with nil entries (#3490) --- codegen/go_transform.go | 40 ++++++++++--------- codegen/go_transform_test.go | 8 ++++ .../testdata/parse_endpoint_functions.go | 4 ++ .../testdata/payload_constructor_functions.go | 4 ++ 4 files changed, 38 insertions(+), 18 deletions(-) diff --git a/codegen/go_transform.go b/codegen/go_transform.go index fc29ce5b58..805ea43ed5 100644 --- a/codegen/go_transform.go +++ b/codegen/go_transform.go @@ -72,24 +72,6 @@ func GoTransform(source, target *expr.AttributeExpr, sourceVar, targetVar string return strings.TrimRight(code, "\n"), funcs, nil } -// transformPrimitive returns the code to transform source primtive type to -// target primitive type. It returns an error if source and target are not -// compatible for transformation. -func transformPrimitive(source, target *expr.AttributeExpr, sourceVar, targetVar string, newVar bool, ta *TransformAttrs) (string, error) { - if err := IsCompatible(source.Type, target.Type, sourceVar, targetVar); err != nil { - return "", err - } - assign := "=" - if newVar { - assign = ":=" - } - if source.Type.Name() != target.Type.Name() { - cast := ta.TargetCtx.Scope.Ref(target, ta.TargetCtx.Pkg(target)) - return fmt.Sprintf("%s %s %s(%s)\n", targetVar, assign, cast, sourceVar), nil - } - return fmt.Sprintf("%s %s %s\n", targetVar, assign, sourceVar), nil -} - // transformAttribute returns the code to transform source attribute to target // attribute. It returns an error if source and target are not compatible for // transformation. @@ -112,6 +94,24 @@ func transformAttribute(source, target *expr.AttributeExpr, sourceVar, targetVar return } +// transformPrimitive returns the code to transform source primtive type to +// target primitive type. It returns an error if source and target are not +// compatible for transformation. +func transformPrimitive(source, target *expr.AttributeExpr, sourceVar, targetVar string, newVar bool, ta *TransformAttrs) (string, error) { + if err := IsCompatible(source.Type, target.Type, sourceVar, targetVar); err != nil { + return "", err + } + assign := "=" + if newVar { + assign = ":=" + } + if source.Type.Name() != target.Type.Name() { + cast := ta.TargetCtx.Scope.Ref(target, ta.TargetCtx.Pkg(target)) + return fmt.Sprintf("%s %s %s(%s)\n", targetVar, assign, cast, sourceVar), nil + } + return fmt.Sprintf("%s %s %s\n", targetVar, assign, sourceVar), nil +} + // transformObject generates Go code to transform source object to target // object. func transformObject(source, target *expr.AttributeExpr, sourceVar, targetVar string, newVar bool, ta *TransformAttrs) (string, error) { @@ -682,6 +682,10 @@ for key, val := range {{ .SourceVar }} { {{ transformAttribute .SourceKey .TargetKey "key" "tk" true .TransformAttrs -}} {{ end -}} {{ if .IsElemStruct -}} + if val == nil { + {{ .TargetVar }}[tk] = nil + continue + } {{ .TargetVar }}[tk] = {{ transformHelperName .SourceElem .TargetElem .TransformAttrs -}}(val) {{ else -}} {{ transformAttribute .SourceElem .TargetElem "val" (printf "tv%s" .LoopVar) true .TransformAttrs -}} diff --git a/codegen/go_transform_test.go b/codegen/go_transform_test.go index 316bcb8919..9ad3c2681f 100644 --- a/codegen/go_transform_test.go +++ b/codegen/go_transform_test.go @@ -432,6 +432,10 @@ const ( target.TypeMap = make(map[string]*SimpleMap, len(source.TypeMap)) for key, val := range source.TypeMap { tk := key + if val == nil { + target.TypeMap[tk] = nil + continue + } target.TypeMap[tk] = transformSimpleMapToSimpleMap(val) } } @@ -613,6 +617,10 @@ const ( target.Recursive = make(map[string]*RecursiveMap, len(source.Recursive)) for key, val := range source.Recursive { tk := key + if val == nil { + target.Recursive[tk] = nil + continue + } target.Recursive[tk] = transformRecursiveMapToRecursiveMap(val) } } diff --git a/http/codegen/testdata/parse_endpoint_functions.go b/http/codegen/testdata/parse_endpoint_functions.go index 0a5c6c0c88..4b5ff9427d 100644 --- a/http/codegen/testdata/parse_endpoint_functions.go +++ b/http/codegen/testdata/parse_endpoint_functions.go @@ -992,6 +992,10 @@ func BuildMethodBodyInlineMapUserPayload(serviceBodyInlineMapUserMethodBodyInlin v := make(map[*servicebodyinlinemapuser.KeyType]*servicebodyinlinemapuser.ElemType, len(body)) for key, val := range body { tk := marshalKeyTypeRequestBodyToServicebodyinlinemapuserKeyType(val) + if val == nil { + v[tk] = nil + continue + } v[tk] = marshalElemTypeRequestBodyToServicebodyinlinemapuserElemType(val) } return v, nil diff --git a/http/codegen/testdata/payload_constructor_functions.go b/http/codegen/testdata/payload_constructor_functions.go index e26b5512d3..57b700e6cd 100644 --- a/http/codegen/testdata/payload_constructor_functions.go +++ b/http/codegen/testdata/payload_constructor_functions.go @@ -1009,6 +1009,10 @@ func NewMethodBodyInlineMapUserMapKeyTypeElemType(body map[*KeyTypeRequestBody]* v := make(map[*servicebodyinlinemapuser.KeyType]*servicebodyinlinemapuser.ElemType, len(body)) for key, val := range body { tk := unmarshalKeyTypeRequestBodyToServicebodyinlinemapuserKeyType(val) + if val == nil { + v[tk] = nil + continue + } v[tk] = unmarshalElemTypeRequestBodyToServicebodyinlinemapuserElemType(val) } return v