Skip to content

Commit

Permalink
Properly handle decoding of maps with nil entries
Browse files Browse the repository at this point in the history
  • Loading branch information
raphael committed Mar 10, 2024
1 parent 6570900 commit e12e25f
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 18 deletions.
40 changes: 22 additions & 18 deletions codegen/go_transform.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,24 +72,6 @@ func GoTransform(source, target *expr.AttributeExpr, sourceVar, targetVar string
return strings.TrimRight(code, "\n"), funcs, nil
}

// transformPrimitive returns the code to transform source primtive type to
// target primitive type. It returns an error if source and target are not
// compatible for transformation.
func transformPrimitive(source, target *expr.AttributeExpr, sourceVar, targetVar string, newVar bool, ta *TransformAttrs) (string, error) {
if err := IsCompatible(source.Type, target.Type, sourceVar, targetVar); err != nil {
return "", err
}
assign := "="
if newVar {
assign = ":="
}
if source.Type.Name() != target.Type.Name() {
cast := ta.TargetCtx.Scope.Ref(target, ta.TargetCtx.Pkg(target))
return fmt.Sprintf("%s %s %s(%s)\n", targetVar, assign, cast, sourceVar), nil
}
return fmt.Sprintf("%s %s %s\n", targetVar, assign, sourceVar), nil
}

// transformAttribute returns the code to transform source attribute to target
// attribute. It returns an error if source and target are not compatible for
// transformation.
Expand All @@ -112,6 +94,24 @@ func transformAttribute(source, target *expr.AttributeExpr, sourceVar, targetVar
return
}

// transformPrimitive returns the code to transform source primtive type to
// target primitive type. It returns an error if source and target are not
// compatible for transformation.
func transformPrimitive(source, target *expr.AttributeExpr, sourceVar, targetVar string, newVar bool, ta *TransformAttrs) (string, error) {
if err := IsCompatible(source.Type, target.Type, sourceVar, targetVar); err != nil {
return "", err
}
assign := "="
if newVar {
assign = ":="
}
if source.Type.Name() != target.Type.Name() {
cast := ta.TargetCtx.Scope.Ref(target, ta.TargetCtx.Pkg(target))
return fmt.Sprintf("%s %s %s(%s)\n", targetVar, assign, cast, sourceVar), nil
}
return fmt.Sprintf("%s %s %s\n", targetVar, assign, sourceVar), nil
}

// transformObject generates Go code to transform source object to target
// object.
func transformObject(source, target *expr.AttributeExpr, sourceVar, targetVar string, newVar bool, ta *TransformAttrs) (string, error) {
Expand Down Expand Up @@ -682,6 +682,10 @@ for key, val := range {{ .SourceVar }} {
{{ transformAttribute .SourceKey .TargetKey "key" "tk" true .TransformAttrs -}}
{{ end -}}
{{ if .IsElemStruct -}}
if val == nil {
{{ .TargetVar }}[tk] = nil
continue
}
{{ .TargetVar }}[tk] = {{ transformHelperName .SourceElem .TargetElem .TransformAttrs -}}(val)
{{ else -}}
{{ transformAttribute .SourceElem .TargetElem "val" (printf "tv%s" .LoopVar) true .TransformAttrs -}}
Expand Down
8 changes: 8 additions & 0 deletions codegen/go_transform_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,10 @@ const (
target.TypeMap = make(map[string]*SimpleMap, len(source.TypeMap))
for key, val := range source.TypeMap {
tk := key
if val == nil {
target.TypeMap[tk] = nil
continue
}
target.TypeMap[tk] = transformSimpleMapToSimpleMap(val)
}
}
Expand Down Expand Up @@ -613,6 +617,10 @@ const (
target.Recursive = make(map[string]*RecursiveMap, len(source.Recursive))
for key, val := range source.Recursive {
tk := key
if val == nil {
target.Recursive[tk] = nil
continue
}
target.Recursive[tk] = transformRecursiveMapToRecursiveMap(val)
}
}
Expand Down
4 changes: 4 additions & 0 deletions http/codegen/testdata/parse_endpoint_functions.go
Original file line number Diff line number Diff line change
Expand Up @@ -992,6 +992,10 @@ func BuildMethodBodyInlineMapUserPayload(serviceBodyInlineMapUserMethodBodyInlin
v := make(map[*servicebodyinlinemapuser.KeyType]*servicebodyinlinemapuser.ElemType, len(body))
for key, val := range body {
tk := marshalKeyTypeRequestBodyToServicebodyinlinemapuserKeyType(val)
if val == nil {
v[tk] = nil
continue
}
v[tk] = marshalElemTypeRequestBodyToServicebodyinlinemapuserElemType(val)
}
return v, nil
Expand Down
4 changes: 4 additions & 0 deletions http/codegen/testdata/payload_constructor_functions.go
Original file line number Diff line number Diff line change
Expand Up @@ -1009,6 +1009,10 @@ func NewMethodBodyInlineMapUserMapKeyTypeElemType(body map[*KeyTypeRequestBody]*
v := make(map[*servicebodyinlinemapuser.KeyType]*servicebodyinlinemapuser.ElemType, len(body))
for key, val := range body {
tk := unmarshalKeyTypeRequestBodyToServicebodyinlinemapuserKeyType(val)
if val == nil {
v[tk] = nil
continue
}
v[tk] = unmarshalElemTypeRequestBodyToServicebodyinlinemapuserElemType(val)
}
return v
Expand Down

0 comments on commit e12e25f

Please sign in to comment.