diff --git a/codegen/cli/cli.go b/codegen/cli/cli.go index 7e3c461acf..f91758f456 100644 --- a/codegen/cli/cli.go +++ b/codegen/cli/cli.go @@ -192,43 +192,41 @@ func BuildSubcommandData(data *service.Data, m *service.MethodData, buildFunctio conversion string interceptors *InterceptorData ) - { - en := m.Name - name = codegen.KebabCase(en) - fullName = goifyTerms(data.Name, en) - description = m.Description - if description == "" { - description = fmt.Sprintf("Make request to the %q endpoint", m.Name) - } + en := m.Name + name = codegen.KebabCase(en) + fullName = goifyTerms(data.Name, en) + description = m.Description + if description == "" { + description = fmt.Sprintf("Make request to the %q endpoint", m.Name) + } - if m.Payload != "" && buildFunction == nil && len(flags) > 0 { - // No build function, just convert the arg to the body type - var convPre, convSuff string - target := "data" + if m.Payload != "" && buildFunction == nil && len(flags) > 0 { + // No build function, just convert the arg to the body type + var convPre, convSuff string + target := "data" + if flagType(m.Payload) == "JSON" { + target = "val" + convPre = fmt.Sprintf("var val %s\n", m.Payload) + convSuff = "\ndata = val" + } + conv, _, check := conversionCode( + "*"+flags[0].FullName+"Flag", + target, + m.Payload, + false, + ) + conversion = convPre + conv + convSuff + if check { + conversion = "var err error\n" + conversion + conversion += "\nif err != nil {\n" if flagType(m.Payload) == "JSON" { - target = "val" - convPre = fmt.Sprintf("var val %s\n", m.Payload) - convSuff = "\ndata = val" - } - conv, _, check := conversionCode( - "*"+flags[0].FullName+"Flag", - target, - m.Payload, - false, - ) - conversion = convPre + conv + convSuff - if check { - conversion = "var err error\n" + conversion - conversion += "\nif err != nil {\n" - if flagType(m.Payload) == "JSON" { - conversion += fmt.Sprintf(`return nil, nil, fmt.Errorf("invalid JSON for %s, \nerror: %%s, \nexample of valid JSON:\n%%s", err, %q)`, - flags[0].FullName+"Flag", flags[0].Example) - } else { - conversion += fmt.Sprintf(`return nil, nil, fmt.Errorf("invalid value for %s, must be %s")`, - flags[0].FullName+"Flag", flags[0].Type) - } - conversion += "\n}" + conversion += fmt.Sprintf(`return nil, nil, fmt.Errorf("invalid JSON for %s, \nerror: %%s, \nexample of valid JSON:\n%%s", err, %q)`, + flags[0].FullName+"Flag", flags[0].Example) + } else { + conversion += fmt.Sprintf(`return nil, nil, fmt.Errorf("invalid value for %s, must be %s")`, + flags[0].FullName+"Flag", flags[0].Type) } + conversion += "\n}" } if len(m.ClientInterceptors) > 0 { @@ -364,66 +362,64 @@ func FieldLoadCode(f *FlagData, argName, argTypeName, validate string, defaultVa startIf string endIf string ) - { - if !f.Required { - startIf = fmt.Sprintf("if %s != \"\" {\n", f.FullName) - endIf = "\n}" + if !f.Required { + startIf = fmt.Sprintf("if %s != \"\" {\n", f.FullName) + endIf = "\n}" + } + if argTypeName == codegen.GoNativeTypeName(expr.String) { + ref := "&" + if f.Required || defaultValue != nil { + ref = "" } - if argTypeName == codegen.GoNativeTypeName(expr.String) { - ref := "&" - if f.Required || defaultValue != nil { - ref = "" - } - code = argName + " = " + ref + f.FullName - declErr = validate != "" - } else { - var checkErr bool - code, declErr, checkErr = conversionCode(f.FullName, argName, argTypeName, !f.Required && defaultValue == nil) - if checkErr { - code += "\nif err != nil {\n" - nilVal := "nil" - if expr.IsPrimitive(payload) { - code += fmt.Sprintf("var zero %s\n", payloadRef) - nilVal = "zero" - } - if flagType(argTypeName) == "JSON" { - code += fmt.Sprintf(`return %s, fmt.Errorf("invalid JSON for %s, \nerror: %%s, \nexample of valid JSON:\n%%s", err, %q)`, - nilVal, argName, f.Example) - } else { - code += fmt.Sprintf(`return %s, fmt.Errorf("invalid value for %s, must be %s")`, - nilVal, argName, f.Type) - } - code += "\n}" - } - } - if validate != "" { - nilCheck := "if " + argName + " != nil {" - if strings.HasPrefix(validate, nilCheck) { - // hackety hack... the validation code is generated for the client and needs to - // account for the fact that the field could be nil in this case. We are reusing - // that code to validate a CLI flag which can never be nil. Lint tools complain - // about that so remove the if statements. Ideally we'd have a better way to do - // this but that requires a lot of changes and the added complexity might not be - // worth it. - var lines []string - ls := strings.Split(validate, "\n") - for i := 1; i < len(ls)-1; i++ { - if ls[i+1] == nilCheck { - i++ // skip both closing brace on previous line and check - continue - } - lines = append(lines, ls[i]) - } - validate = strings.Join(lines, "\n") - } - code += "\n" + validate + "\n" + code = argName + " = " + ref + f.FullName + declErr = validate != "" + } else { + var checkErr bool + code, declErr, checkErr = conversionCode(f.FullName, argName, argTypeName, !f.Required && defaultValue == nil) + if checkErr { + code += "\nif err != nil {\n" nilVal := "nil" if expr.IsPrimitive(payload) { code += fmt.Sprintf("var zero %s\n", payloadRef) nilVal = "zero" } - code += fmt.Sprintf("if err != nil {\n\treturn %s, err\n}", nilVal) + if flagType(argTypeName) == "JSON" { + code += fmt.Sprintf(`return %s, fmt.Errorf("invalid JSON for %s, \nerror: %%s, \nexample of valid JSON:\n%%s", err, %q)`, + nilVal, argName, f.Example) + } else { + code += fmt.Sprintf(`return %s, fmt.Errorf("invalid value for %s, must be %s")`, + nilVal, argName, f.Type) + } + code += "\n}" + } + } + if validate != "" { + nilCheck := "if " + argName + " != nil {" + if strings.HasPrefix(validate, nilCheck) { + // hackety hack... the validation code is generated for the client and needs to + // account for the fact that the field could be nil in this case. We are reusing + // that code to validate a CLI flag which can never be nil. Lint tools complain + // about that so remove the if statements. Ideally we'd have a better way to do + // this but that requires a lot of changes and the added complexity might not be + // worth it. + var lines []string + ls := strings.Split(validate, "\n") + for i := 1; i < len(ls)-1; i++ { + if ls[i+1] == nilCheck { + i++ // skip both closing brace on previous line and check + continue + } + lines = append(lines, ls[i]) + } + validate = strings.Join(lines, "\n") + } + code += "\n" + validate + "\n" + nilVal := "nil" + if expr.IsPrimitive(payload) { + code += fmt.Sprintf("var zero %s\n", payloadRef) + nilVal = "zero" } + code += fmt.Sprintf("if err != nil {\n\treturn %s, err\n}", nilVal) } return fmt.Sprintf("%s%s%s", startIf, code, endIf), declErr } diff --git a/codegen/service/interceptors.go b/codegen/service/interceptors.go index f24f270f46..aa4bdfc0e9 100644 --- a/codegen/service/interceptors.go +++ b/codegen/service/interceptors.go @@ -8,7 +8,7 @@ import ( ) // InterceptorsFiles returns the interceptors files for the given service. -func InterceptorsFiles(genpkg string, service *expr.ServiceExpr) []*codegen.File { +func InterceptorsFiles(_ string, service *expr.ServiceExpr) []*codegen.File { var files []*codegen.File svc := Services.Get(service.Name) diff --git a/codegen/service/service.go b/codegen/service/service.go index 4fb12d67fe..bc900ce79f 100644 --- a/codegen/service/service.go +++ b/codegen/service/service.go @@ -122,16 +122,9 @@ func Files(genpkg string, service *expr.ServiceExpr, userTypePkgs map[string][]s // transform result type functions for _, t := range svc.viewedResultTypes { - svcSections = append(svcSections, &codegen.SectionTemplate{ - Name: "viewed-result-type-to-service-result-type", - Source: readTemplate("type_init"), - Data: t.ResultInit, - }) - svcSections = append(svcSections, &codegen.SectionTemplate{ - Name: "service-result-type-to-viewed-result-type", - Source: readTemplate("type_init"), - Data: t.Init, - }) + svcSections = append(svcSections, + &codegen.SectionTemplate{Name: "viewed-result-type-to-service-result-type", Source: readTemplate("type_init"), Data: t.ResultInit}, + &codegen.SectionTemplate{Name: "service-result-type-to-viewed-result-type", Source: readTemplate("type_init"), Data: t.Init}) } var projh []*codegen.TransformFunctionData for _, t := range svc.projectedTypes { diff --git a/codegen/service/service_data.go b/codegen/service/service_data.go index 195722f7bd..42679c31f7 100644 --- a/codegen/service/service_data.go +++ b/codegen/service/service_data.go @@ -696,79 +696,76 @@ func (d ServicesData) analyze(service *expr.ServiceExpr) *Data { seenProj map[string]*ProjectedTypeData seenViewed map[string]*ViewedResultTypeData ) - { - scope = codegen.NewNameScope() - scope.Unique("Use") // Reserve "Use" for Endpoints struct Use method. - viewScope = codegen.NewNameScope() - pkgName = scope.HashedUnique(service, strings.ToLower(codegen.Goify(service.Name, false)), "svc") - viewspkg = pkgName + "views" - seen = make(map[string]struct{}) - seenErrors = make(map[string]struct{}) - seenProj = make(map[string]*ProjectedTypeData) - seenViewed = make(map[string]*ViewedResultTypeData) - - // A function to collect user types from an error expression - recordError := func(er *expr.ErrorExpr) { - errTypes = append(errTypes, collectTypes(er.AttributeExpr, scope, seen)...) - if er.Type == expr.ErrorResult { - if _, ok := seenErrors[er.Name]; ok { - return - } - seenErrors[er.Name] = struct{}{} - errorInits = append(errorInits, buildErrorInitData(er, scope)) + scope = codegen.NewNameScope() + scope.Unique("Use") // Reserve "Use" for Endpoints struct Use method. + viewScope = codegen.NewNameScope() + pkgName = scope.HashedUnique(service, strings.ToLower(codegen.Goify(service.Name, false)), "svc") + viewspkg = pkgName + "views" + seen = make(map[string]struct{}) + seenErrors = make(map[string]struct{}) + seenProj = make(map[string]*ProjectedTypeData) + seenViewed = make(map[string]*ViewedResultTypeData) + + // A function to collect user types from an error expression + recordError := func(er *expr.ErrorExpr) { + errTypes = append(errTypes, collectTypes(er.AttributeExpr, scope, seen)...) + if er.Type == expr.ErrorResult { + if _, ok := seenErrors[er.Name]; ok { + return } + seenErrors[er.Name] = struct{}{} + errorInits = append(errorInits, buildErrorInitData(er, scope)) } - for _, er := range service.Errors { - recordError(er) - } + } + for _, er := range service.Errors { + recordError(er) + } - // A function to collect inner user types from an attribute expression - collectUserTypes := func(att *expr.AttributeExpr) { - if ut, ok := att.Type.(expr.UserType); ok { - att = ut.Attribute() - } - types = append(types, collectTypes(att, scope, seen)...) - } - for _, m := range service.Methods { - // collect inner user types - collectUserTypes(m.Payload) - collectUserTypes(m.StreamingPayload) - collectUserTypes(m.Result) - // Collect projected types - if hasResultType(m.Result) { - types, umeths := collectProjectedTypes(expr.DupAtt(m.Result), m.Result, viewspkg, scope, viewScope, seenProj) - projTypes = append(projTypes, types...) - viewedUnionMeths = append(viewedUnionMeths, umeths...) - } - for _, er := range m.Errors { - recordError(er) - } + // A function to collect inner user types from an attribute expression + collectUserTypes := func(att *expr.AttributeExpr) { + if ut, ok := att.Type.(expr.UserType); ok { + att = ut.Attribute() + } + types = append(types, collectTypes(att, scope, seen)...) + } + for _, m := range service.Methods { + // collect inner user types + collectUserTypes(m.Payload) + collectUserTypes(m.StreamingPayload) + collectUserTypes(m.Result) + // Collect projected types + if hasResultType(m.Result) { + types, umeths := collectProjectedTypes(expr.DupAtt(m.Result), m.Result, viewspkg, scope, viewScope, seenProj) + projTypes = append(projTypes, types...) + viewedUnionMeths = append(viewedUnionMeths, umeths...) + } + for _, er := range m.Errors { + recordError(er) } + } - // A function to convert raw object type to user type. - wrapObject := func(att *expr.AttributeExpr, name, id string) { - if _, ok := att.Type.(*expr.Object); ok { - att.Type = &expr.UserTypeExpr{ - AttributeExpr: expr.DupAtt(att), - TypeName: scope.Name(name), - UID: id, - } - } - if ut, ok := att.Type.(expr.UserType); ok { - seen[ut.ID()] = struct{}{} + // A function to convert raw object type to user type. + wrapObject := func(att *expr.AttributeExpr, name, id string) { + if _, ok := att.Type.(*expr.Object); ok { + att.Type = &expr.UserTypeExpr{ + AttributeExpr: expr.DupAtt(att), + TypeName: scope.Name(name), + UID: id, } } - - for _, m := range service.Methods { - name := codegen.Goify(m.Name, true) - // Create user type for raw object payloads - wrapObject(m.Payload, name+"Payload", service.Name+"#"+name+"Payload") - // Create user type for raw object streaming payloads - wrapObject(m.StreamingPayload, name+"StreamingPayload", service.Name+"#"+name+"StreamingPayload") - // Create user type for raw object results - wrapObject(m.Result, name+"Result", service.Name+"#"+name+"Result") + if ut, ok := att.Type.(expr.UserType); ok { + seen[ut.ID()] = struct{}{} } + } + for _, m := range service.Methods { + name := codegen.Goify(m.Name, true) + // Create user type for raw object payloads + wrapObject(m.Payload, name+"Payload", service.Name+"#"+name+"Payload") + // Create user type for raw object streaming payloads + wrapObject(m.StreamingPayload, name+"StreamingPayload", service.Name+"#"+name+"StreamingPayload") + // Create user type for raw object results + wrapObject(m.Result, name+"Result", service.Name+"#"+name+"Result") } // Add forced types @@ -796,86 +793,77 @@ func (d ServicesData) analyze(service *expr.ServiceExpr) *Data { methods []*MethodData schemes SchemesData ) - { - methods = make([]*MethodData, len(service.Methods)) - for i, e := range service.Methods { - m := buildMethodData(e, scope) - methods[i] = m - for _, s := range m.Schemes { - schemes = schemes.Append(s) - } - rt, ok := e.Result.Type.(*expr.ResultTypeExpr) - if !ok { - continue - } - var view string - if v, ok := e.Result.Meta.Last(expr.ViewMetaKey); ok { - view = v - } - if vrt, ok := seenViewed[m.Result+"::"+view]; ok { - m.ViewedResult = vrt - continue - } - projected := seenProj[rt.ID()] - projAtt := &expr.AttributeExpr{Type: projected.Type} - vrt := buildViewedResultType(e.Result, projAtt, viewspkg, scope, viewScope) - found := false - for _, rt := range viewedRTs { - if rt.Type.ID() == vrt.Type.ID() { - found = true - break - } - } - if !found { - viewedRTs = append(viewedRTs, vrt) - } + methods = make([]*MethodData, len(service.Methods)) + for i, e := range service.Methods { + m := buildMethodData(e, scope) + methods[i] = m + for _, s := range m.Schemes { + schemes = schemes.Append(s) + } + rt, ok := e.Result.Type.(*expr.ResultTypeExpr) + if !ok { + continue + } + var view string + if v, ok := e.Result.Meta.Last(expr.ViewMetaKey); ok { + view = v + } + if vrt, ok := seenViewed[m.Result+"::"+view]; ok { m.ViewedResult = vrt - seenViewed[vrt.Name+"::"+view] = vrt + continue + } + projected := seenProj[rt.ID()] + projAtt := &expr.AttributeExpr{Type: projected.Type} + vrt := buildViewedResultType(e.Result, projAtt, viewspkg, scope, viewScope) + found := false + for _, rt := range viewedRTs { + if rt.Type.ID() == vrt.Type.ID() { + found = true + break + } + } + if !found { + viewedRTs = append(viewedRTs, vrt) } + m.ViewedResult = vrt + seenViewed[vrt.Name+"::"+view] = vrt } var ( unionMethods []*UnionValueMethodData + ms []*UnionValueMethodData ) - { - var ms []*UnionValueMethodData - seen := make(map[string]struct{}) - for _, t := range types { - ms = append(ms, collectUnionMethods(&expr.AttributeExpr{Type: t.Type}, scope, t.Loc, seen)...) - } - for _, t := range errTypes { - ms = append(ms, collectUnionMethods(&expr.AttributeExpr{Type: t.Type}, scope, t.Loc, seen)...) - } - for _, m := range service.Methods { - ms = append(ms, collectUnionMethods(m.Payload, scope, codegen.UserTypeLocation(m.Payload.Type), seen)...) - ms = append(ms, collectUnionMethods(m.StreamingPayload, scope, codegen.UserTypeLocation(m.StreamingPayload.Type), seen)...) - ms = append(ms, collectUnionMethods(m.Result, scope, codegen.UserTypeLocation(m.Result.Type), seen)...) - for _, e := range m.Errors { - ms = append(ms, collectUnionMethods(e.AttributeExpr, scope, codegen.UserTypeLocation(e.Type), seen)...) - } + seen = make(map[string]struct{}) + for _, t := range types { + ms = append(ms, collectUnionMethods(&expr.AttributeExpr{Type: t.Type}, scope, t.Loc, seen)...) + } + for _, t := range errTypes { + ms = append(ms, collectUnionMethods(&expr.AttributeExpr{Type: t.Type}, scope, t.Loc, seen)...) + } + for _, m := range service.Methods { + ms = append(ms, collectUnionMethods(m.Payload, scope, codegen.UserTypeLocation(m.Payload.Type), seen)...) + ms = append(ms, collectUnionMethods(m.StreamingPayload, scope, codegen.UserTypeLocation(m.StreamingPayload.Type), seen)...) + ms = append(ms, collectUnionMethods(m.Result, scope, codegen.UserTypeLocation(m.Result.Type), seen)...) + for _, e := range m.Errors { + ms = append(ms, collectUnionMethods(e.AttributeExpr, scope, codegen.UserTypeLocation(e.Type), seen)...) } - sort.Slice(ms, func(i, j int) bool { - return ms[i].Name < ms[j].Name - }) - pkgs := make(map[string]struct{}) - for _, m := range ms { - key := m.TypeRef + "::" + m.Name + "::" + m.Loc.PackageName() - if _, ok := pkgs[key]; ok { - continue - } - pkgs[key] = struct{}{} - unionMethods = append(unionMethods, m) + } + sort.Slice(ms, func(i, j int) bool { + return ms[i].Name < ms[j].Name + }) + pkgs := make(map[string]struct{}) + for _, m := range ms { + key := m.TypeRef + "::" + m.Name + "::" + m.Loc.PackageName() + if _, ok := pkgs[key]; ok { + continue } + pkgs[key] = struct{}{} + unionMethods = append(unionMethods, m) } - var ( - desc string - ) - { - desc = service.Description - if desc == "" { - desc = fmt.Sprintf("Service is the %s service interface.", service.Name) - } + desc := service.Description + if desc == "" { + desc = fmt.Sprintf("Service is the %s service interface.", service.Name) } varName := codegen.Goify(service.Name, false) @@ -1626,14 +1614,12 @@ func buildProjectedType(projected, att *expr.AttributeExpr, viewspkg string, sco varname = viewScope.GoTypeName(projected) pt = projected.Type.(expr.UserType) ) - { - if _, isrt := pt.(*expr.ResultTypeExpr); isrt { - typeInits = buildTypeInits(projected, att, viewspkg, scope, viewScope) - projections = buildProjections(projected, att, viewspkg, scope, viewScope) - views = buildViews(att.Type.(*expr.ResultTypeExpr), viewScope) - } - validations = buildValidations(projected, viewScope) + if _, isrt := pt.(*expr.ResultTypeExpr); isrt { + typeInits = buildTypeInits(projected, att, viewspkg, scope, viewScope) + projections = buildProjections(projected, att, viewspkg, scope, viewScope) + views = buildViews(att.Type.(*expr.ResultTypeExpr), viewScope) } + validations = buildValidations(projected, viewScope) removeMeta(projected) return &ProjectedTypeData{ UserTypeData: &UserTypeData{ @@ -1678,112 +1664,98 @@ func buildViewedResultType(att, projected *expr.AttributeExpr, viewspkg string, var ( viewName string views []*ViewData - - rt = att.Type.(*expr.ResultTypeExpr) - isarr = expr.IsArray(att.Type) ) - { - if !rt.HasMultipleViews() { - viewName = expr.DefaultView - } - if v, ok := att.Meta.Last(expr.ViewMetaKey); ok { - viewName = v - } - views = buildViews(rt, viewScope) + + rt := att.Type.(*expr.ResultTypeExpr) + isarr := expr.IsArray(att.Type) + if !rt.HasMultipleViews() { + viewName = expr.DefaultView } + if v, ok := att.Meta.Last(expr.ViewMetaKey); ok { + viewName = v + } + views = buildViews(rt, viewScope) // build validation data - var ( - validate *ValidateData - - resvar = scope.GoTypeName(att) - resref = scope.GoTypeRef(att) - ) - { - data := map[string]any{ - "Projected": scope.GoTypeName(projected), - "ArgVar": "result", - "Source": "result", - "Views": views, - "IsViewed": true, - } - buf := &bytes.Buffer{} - if err := validateTypeCodeTmpl.Execute(buf, data); err != nil { - panic(err) // bug - } - name := "Validate" + resvar - validate = &ValidateData{ - Name: name, - Description: fmt.Sprintf("%s runs the validations defined on the viewed result type %s.", name, resvar), - Ref: resref, - Validate: buf.String(), - } + var validate *ValidateData + resvar := scope.GoTypeName(att) + resref := scope.GoTypeRef(att) + data := map[string]any{ + "Projected": scope.GoTypeName(projected), + "ArgVar": "result", + "Source": "result", + "Views": views, + "IsViewed": true, + } + buf := &bytes.Buffer{} + if err := validateTypeCodeTmpl.Execute(buf, data); err != nil { + panic(err) // bug + } + name := "Validate" + resvar + validate = &ValidateData{ + Name: name, + Description: fmt.Sprintf("%s runs the validations defined on the viewed result type %s.", name, resvar), + Ref: resref, + Validate: buf.String(), } // build constructor to initialize viewed result type from result type - var ( - init *InitData - - vresref = viewScope.GoFullTypeRef(att, viewspkg) - ) - { - data := map[string]any{ - "ToViewed": true, - "ArgVar": "res", - "ReturnVar": "vres", - "Views": views, - "ReturnTypeRef": vresref, - "IsCollection": isarr, - "TargetType": scope.GoFullTypeName(att, viewspkg), - "InitName": "new" + viewScope.GoTypeName(projected), - } - buf := &bytes.Buffer{} - if err := initTypeCodeTmpl.Execute(buf, data); err != nil { - panic(err) // bug - } - pkg := "" - if loc := codegen.UserTypeLocation(att.Type); loc != nil { - pkg = loc.PackageName() - } - name := "NewViewed" + resvar - init = &InitData{ - Name: name, - Description: fmt.Sprintf("%s initializes viewed result type %s from result type %s using the given view.", name, resvar, resvar), - Args: []*InitArgData{ - {Name: "res", Ref: scope.GoFullTypeRef(att, pkg)}, - {Name: "view", Ref: "string"}, - }, - ReturnTypeRef: vresref, - Code: buf.String(), - } + var init *InitData + vresref := viewScope.GoFullTypeRef(att, viewspkg) + data = map[string]any{ + "ToViewed": true, + "ArgVar": "res", + "ReturnVar": "vres", + "Views": views, + "ReturnTypeRef": vresref, + "IsCollection": isarr, + "TargetType": scope.GoFullTypeName(att, viewspkg), + "InitName": "new" + viewScope.GoTypeName(projected), + } + buf = &bytes.Buffer{} + if err := initTypeCodeTmpl.Execute(buf, data); err != nil { + panic(err) // bug + } + pkg := "" + if loc := codegen.UserTypeLocation(att.Type); loc != nil { + pkg = loc.PackageName() + } + name = "NewViewed" + resvar + init = &InitData{ + Name: name, + Description: fmt.Sprintf("%s initializes viewed result type %s from result type %s using the given view.", name, resvar, resvar), + Args: []*InitArgData{ + {Name: "res", Ref: scope.GoFullTypeRef(att, pkg)}, + {Name: "view", Ref: "string"}, + }, + ReturnTypeRef: vresref, + Code: buf.String(), } // build constructor to initialize result type from viewed result type var resinit *InitData - { - if loc := codegen.UserTypeLocation(att.Type); loc != nil { - resref = scope.GoFullTypeRef(att, loc.PackageName()) - } - data := map[string]any{ - "ToResult": true, - "ArgVar": "vres", - "ReturnVar": "res", - "Views": views, - "ReturnTypeRef": resref, - "InitName": "new" + scope.GoTypeName(att), - } - buf := &bytes.Buffer{} - if err := initTypeCodeTmpl.Execute(buf, data); err != nil { - panic(err) // bug - } - name := "New" + resvar - resinit = &InitData{ - Name: name, - Description: fmt.Sprintf("%s initializes result type %s from viewed result type %s.", name, resvar, resvar), - Args: []*InitArgData{{Name: "vres", Ref: scope.GoFullTypeRef(att, viewspkg)}}, - ReturnTypeRef: resref, - Code: buf.String(), - } + if loc := codegen.UserTypeLocation(att.Type); loc != nil { + resref = scope.GoFullTypeRef(att, loc.PackageName()) + } + data = map[string]any{ + "ToResult": true, + "ArgVar": "vres", + "ReturnVar": "res", + "Views": views, + "ReturnTypeRef": resref, + "InitName": "new" + scope.GoTypeName(att), + } + buf = &bytes.Buffer{} + if err := initTypeCodeTmpl.Execute(buf, data); err != nil { + panic(err) // bug + } + name = "New" + resvar + resinit = &InitData{ + Name: name, + Description: fmt.Sprintf("%s initializes result type %s from viewed result type %s.", name, resvar, resvar), + Args: []*InitArgData{{Name: "vres", Ref: scope.GoFullTypeRef(att, viewspkg)}}, + ReturnTypeRef: resref, + Code: buf.String(), } projT := wrapProjected(projected.Type.(expr.UserType)) @@ -1846,72 +1818,62 @@ func buildTypeInits(projected, att *expr.AttributeExpr, viewspkg string, scope, // For every view defined in the result type, build a constructor function // to create the result type from a projected type based on the view. - var init []*InitData - { - init = make([]*InitData, 0, len(prt.Views)) - for _, view := range prt.Views { - var ( - typ expr.DataType - - obj = &expr.Object{} - ) - { - walkViewAttrs(pobj, view, func(name string, att, _ *expr.AttributeExpr) { - obj.Set(name, att) - }) - typ = obj - if parr != nil { - typ = &expr.Array{ElemType: &expr.AttributeExpr{ - Type: &expr.ResultTypeExpr{ - UserTypeExpr: &expr.UserTypeExpr{ - AttributeExpr: &expr.AttributeExpr{Type: obj}, - TypeName: scope.GoTypeName(parr.ElemType), - }, - }, - }} - } - } - src := &expr.AttributeExpr{ + init := make([]*InitData, 0, len(prt.Views)) + for _, view := range prt.Views { + var typ expr.DataType + obj := &expr.Object{} + walkViewAttrs(pobj, view, func(name string, att, _ *expr.AttributeExpr) { + obj.Set(name, att) + }) + typ = obj + if parr != nil { + typ = &expr.Array{ElemType: &expr.AttributeExpr{ Type: &expr.ResultTypeExpr{ UserTypeExpr: &expr.UserTypeExpr{ - AttributeExpr: &expr.AttributeExpr{Type: typ}, - TypeName: scope.GoTypeName(projected), + AttributeExpr: &expr.AttributeExpr{Type: obj}, + TypeName: scope.GoTypeName(parr.ElemType), }, - Views: prt.Views, - Identifier: prt.Identifier, }, - } + }} + } + src := &expr.AttributeExpr{ + Type: &expr.ResultTypeExpr{ + UserTypeExpr: &expr.UserTypeExpr{ + AttributeExpr: &expr.AttributeExpr{Type: typ}, + TypeName: scope.GoTypeName(projected), + }, + Views: prt.Views, + Identifier: prt.Identifier, + }, + } - var ( - name string - code string - helpers []*codegen.TransformFunctionData + var ( + name string + code string + helpers []*codegen.TransformFunctionData + ) - srcCtx = projectedTypeContext(viewspkg, true, viewScope) - tgtCtx = typeContext("", scope) - resvar = scope.GoTypeName(att) - ) - { - name = "new" + resvar - if view.Name != expr.DefaultView { - name += codegen.Goify(view.Name, true) - } - code, helpers = buildConstructorCode(src, att, "vres", "res", srcCtx, tgtCtx, view.Name) - } + srcCtx := projectedTypeContext(viewspkg, true, viewScope) + tgtCtx := typeContext("", scope) + resvar := scope.GoTypeName(att) + name = "new" + resvar + if view.Name != expr.DefaultView { + name += codegen.Goify(view.Name, true) + } + code, helpers = buildConstructorCode(src, att, "vres", "res", srcCtx, tgtCtx, view.Name) - pkg := "" - if loc := codegen.UserTypeLocation(att.Type); loc != nil { - pkg = loc.PackageName() - } - init = append(init, &InitData{ - Name: name, - Description: fmt.Sprintf("%s converts projected type %s to service type %s.", name, resvar, resvar), - Args: []*InitArgData{{Name: "vres", Ref: viewScope.GoFullTypeRef(projected, viewspkg)}}, - ReturnTypeRef: scope.GoFullTypeRef(att, pkg), - Code: code, - Helpers: helpers, - }) + pkg := "" + if loc := codegen.UserTypeLocation(att.Type); loc != nil { + pkg = loc.PackageName() } + init = append(init, &InitData{ + Name: name, + Description: fmt.Sprintf("%s converts projected type %s to service type %s.", name, resvar, resvar), + Args: []*InitArgData{{Name: "vres", Ref: viewScope.GoFullTypeRef(projected, viewspkg)}}, + ReturnTypeRef: scope.GoFullTypeRef(att, pkg), + Code: code, + Helpers: helpers, + }) } return init } @@ -1919,40 +1881,31 @@ func buildTypeInits(projected, att *expr.AttributeExpr, viewspkg string, scope, // buildProjections builds the data to generate the constructor code to // project a result type to a projected type based on a view. func buildProjections(projected, att *expr.AttributeExpr, viewspkg string, scope, viewScope *codegen.NameScope) []*InitData { - var ( - projections []*InitData - - rt = att.Type.(*expr.ResultTypeExpr) - ) - + var projections []*InitData + rt := att.Type.(*expr.ResultTypeExpr) projections = make([]*InitData, 0, len(rt.Views)) for _, view := range rt.Views { - var ( - typ expr.DataType - - obj = &expr.Object{} - ) - { - pobj := expr.AsObject(projected.Type) - parr := expr.AsArray(projected.Type) - if parr != nil { - // result type collection - pobj = expr.AsObject(parr.ElemType.Type) - } - walkViewAttrs(pobj, view, func(name string, att, _ *expr.AttributeExpr) { - obj.Set(name, att) - }) - typ = obj - if parr != nil { - typ = &expr.Array{ElemType: &expr.AttributeExpr{ - Type: &expr.ResultTypeExpr{ - UserTypeExpr: &expr.UserTypeExpr{ - AttributeExpr: &expr.AttributeExpr{Type: obj}, - TypeName: parr.ElemType.Type.Name(), - }, + var typ expr.DataType + obj := &expr.Object{} + pobj := expr.AsObject(projected.Type) + parr := expr.AsArray(projected.Type) + if parr != nil { + // result type collection + pobj = expr.AsObject(parr.ElemType.Type) + } + walkViewAttrs(pobj, view, func(name string, att, _ *expr.AttributeExpr) { + obj.Set(name, att) + }) + typ = obj + if parr != nil { + typ = &expr.Array{ElemType: &expr.AttributeExpr{ + Type: &expr.ResultTypeExpr{ + UserTypeExpr: &expr.UserTypeExpr{ + AttributeExpr: &expr.AttributeExpr{Type: obj}, + TypeName: parr.ElemType.Type.Name(), }, - }} - } + }, + }} } tgt := &expr.AttributeExpr{ Type: &expr.ResultTypeExpr{ @@ -1969,18 +1922,15 @@ func buildProjections(projected, att *expr.AttributeExpr, viewspkg string, scope name string code string helpers []*codegen.TransformFunctionData - - srcCtx = typeContext("", scope) - tgtCtx = projectedTypeContext(viewspkg, true, viewScope) - tname = scope.GoTypeName(projected) ) - { - name = "new" + tname - if view.Name != expr.DefaultView { - name += codegen.Goify(view.Name, true) - } - code, helpers = buildConstructorCode(att, tgt, "res", "vres", srcCtx, tgtCtx, view.Name) + srcCtx := typeContext("", scope) + tgtCtx := projectedTypeContext(viewspkg, true, viewScope) + tname := scope.GoTypeName(projected) + name = "new" + tname + if view.Name != expr.DefaultView { + name += codegen.Goify(view.Name, true) } + code, helpers = buildConstructorCode(att, tgt, "res", "vres", srcCtx, tgtCtx, view.Name) pkg := "" if loc := codegen.UserTypeLocation(att.Type); loc != nil { @@ -2001,12 +1951,9 @@ func buildProjections(projected, att *expr.AttributeExpr, viewspkg string, scope // buildValidations builds the data required to generate validations for the // projected types. func buildValidations(projected *expr.AttributeExpr, scope *codegen.NameScope) []*ValidateData { - var ( - validations []*ValidateData - - ut = projected.Type.(expr.UserType) - tname = scope.GoTypeName(projected) - ) + var validations []*ValidateData + ut := projected.Type.(expr.UserType) + tname := scope.GoTypeName(projected) if rt, isrt := ut.(*expr.ResultTypeExpr); isrt { // for result types we create a validation function containing view // specific validation logic for each view @@ -2022,12 +1969,10 @@ func buildValidations(projected *expr.AttributeExpr, scope *codegen.NameScope) [ name string vn string ) - { - name = "Validate" + tname - if view.Name != expr.DefaultView { - vn = codegen.Goify(view.Name, true) - name += vn - } + name = "Validate" + tname + if view.Name != expr.DefaultView { + vn = codegen.Goify(view.Name, true) + name += vn } if arr != nil { @@ -2038,29 +1983,26 @@ func buildValidations(projected *expr.AttributeExpr, scope *codegen.NameScope) [ var ( ctx *codegen.AttributeContext fields []map[string]any - - o = &expr.Object{} ) - { - walkViewAttrs(expr.AsObject(projected.Type), view, func(name string, attr, vatt *expr.AttributeExpr) { - if rt, ok := attr.Type.(*expr.ResultTypeExpr); ok { - // use explicitly specified view (if any) for the attribute, - // otherwise use default - vw := "" - if v, ok := vatt.Meta.Last(expr.ViewMetaKey); ok && v != expr.DefaultView { - vw = v - } - fields = append(fields, map[string]any{ - "Name": name, - "ValidateVar": "Validate" + scope.GoTypeName(attr) + codegen.Goify(vw, true), - "IsRequired": rt.Attribute().IsRequired(name), - }) - } else { - o.Set(name, attr) + o := &expr.Object{} + walkViewAttrs(expr.AsObject(projected.Type), view, func(name string, attr, vatt *expr.AttributeExpr) { + if rt, ok := attr.Type.(*expr.ResultTypeExpr); ok { + // use explicitly specified view (if any) for the attribute, + // otherwise use default + vw := "" + if v, ok := vatt.Meta.Last(expr.ViewMetaKey); ok && v != expr.DefaultView { + vw = v } - }) - ctx = projectedTypeContext("", !expr.IsPrimitive(projected.Type), scope) - } + fields = append(fields, map[string]any{ + "Name": name, + "ValidateVar": "Validate" + scope.GoTypeName(attr) + codegen.Goify(vw, true), + "IsRequired": rt.Attribute().IsRequired(name), + }) + } else { + o.Set(name, attr) + } + }) + ctx = projectedTypeContext("", !expr.IsPrimitive(projected.Type), scope) data["Validate"] = codegen.ValidationCode(&expr.AttributeExpr{Type: o, Validation: rt.Validation}, rt, ctx, true, false, true, "result") data["Fields"] = fields } @@ -2142,15 +2084,10 @@ func buildConstructorCode(src, tgt *expr.AttributeExpr, sourceVar, targetVar str data["Source"] = sourceVar data["Target"] = targetVar - var ( - code string - err error - ) - { - // build code for target with no result types - if code, helpers, err = codegen.GoTransform(src, tatt, sourceVar, targetVar, sourceCtx, targetCtx, "transform", true); err != nil { - panic(err) // bug - } + // build code for target with no result types + code, helpers, err := codegen.GoTransform(src, tatt, sourceVar, targetVar, sourceCtx, targetCtx, "transform", true) + if err != nil { + panic(err) // bug } data["Code"] = code diff --git a/dsl/interceptor_test.go b/dsl/interceptor_test.go index 719b600709..7b94d86011 100644 --- a/dsl/interceptor_test.go +++ b/dsl/interceptor_test.go @@ -15,7 +15,7 @@ import ( func TestInterceptor(t *testing.T) { cases := map[string]struct { DSL func() - Assert func(t *testing.T, intr *expr.InterceptorExpr) + Assert func(*testing.T, *expr.InterceptorExpr) }{ "valid-minimal": { func() { @@ -132,7 +132,7 @@ func TestInterceptor(t *testing.T) { func() { Interceptor("", func() {}) }, - func(t *testing.T, intr *expr.InterceptorExpr) { + func(t *testing.T, _ *expr.InterceptorExpr) { assert.NotNil(t, eval.Context.Errors, "expected a validation error") }, }, @@ -141,7 +141,7 @@ func TestInterceptor(t *testing.T) { Interceptor("duplicate", func() {}) Interceptor("duplicate", func() {}) }, - func(t *testing.T, intr *expr.InterceptorExpr) { + func(t *testing.T, _ *expr.InterceptorExpr) { if eval.Context.Errors == nil { t.Error("expected a validation error, got none") } @@ -164,7 +164,7 @@ func TestInterceptor(t *testing.T) { func TestServerInterceptor(t *testing.T) { cases := map[string]struct { DSL func() - Assert func(t *testing.T, svc *expr.ServiceExpr, err error) + Assert func(*testing.T, *expr.ServiceExpr, error) }{ "valid-reference": { func() { @@ -200,7 +200,7 @@ func TestServerInterceptor(t *testing.T) { ServerInterceptor(42) // Invalid type }) }, - func(t *testing.T, svc *expr.ServiceExpr, err error) { + func(t *testing.T, _ *expr.ServiceExpr, err error) { require.Error(t, err) }, }, @@ -210,7 +210,7 @@ func TestServerInterceptor(t *testing.T) { ServerInterceptor("invalid") }) }, - func(t *testing.T, svc *expr.ServiceExpr, err error) { + func(t *testing.T, _ *expr.ServiceExpr, err error) { require.Error(t, err) }, }, @@ -230,7 +230,7 @@ func TestServerInterceptor(t *testing.T) { func TestClientInterceptor(t *testing.T) { cases := map[string]struct { DSL func() - Assert func(t *testing.T, svc *expr.ServiceExpr, err error) + Assert func(*testing.T, *expr.ServiceExpr, error) }{ "valid-reference": { func() { @@ -264,7 +264,7 @@ func TestClientInterceptor(t *testing.T) { ClientInterceptor(42) // Invalid type }) }, - func(t *testing.T, svc *expr.ServiceExpr, err error) { + func(t *testing.T, _ *expr.ServiceExpr, err error) { require.Error(t, err) }, }, @@ -274,7 +274,7 @@ func TestClientInterceptor(t *testing.T) { ClientInterceptor("invalid") }) }, - func(t *testing.T, svc *expr.ServiceExpr, err error) { + func(t *testing.T, _ *expr.ServiceExpr, err error) { require.Error(t, err) }, }, diff --git a/expr/http_endpoint.go b/expr/http_endpoint.go index bd17478469..bc6c92a679 100644 --- a/expr/http_endpoint.go +++ b/expr/http_endpoint.go @@ -878,7 +878,7 @@ func (e *HTTPEndpointExpr) validateHeadersAndCookies() *eval.ValidationErrors { // EvalName returns the generic definition name used in error messages. func (r *RouteExpr) EvalName() string { - return fmt.Sprintf(`route %s "%s" of %s`, r.Method, r.Path, r.Endpoint.EvalName()) + return fmt.Sprintf(`route %s %q of %s`, r.Method, r.Path, r.Endpoint.EvalName()) } // Validate validates a route expression by ensuring that the route parameters diff --git a/grpc/codegen/client_cli.go b/grpc/codegen/client_cli.go index 9f7f9d8f92..fe54ff40ca 100644 --- a/grpc/codegen/client_cli.go +++ b/grpc/codegen/client_cli.go @@ -19,31 +19,37 @@ func ClientCLIFiles(genpkg string, root *expr.RootExpr) []*codegen.File { data []*cli.CommandData svcs []*expr.GRPCServiceExpr ) - { - for _, svc := range root.API.GRPC.Services { - if len(svc.GRPCEndpoints) == 0 { - continue - } - sd := GRPCServices.Get(svc.Name()) - command := cli.BuildCommandData(sd.Service) - for _, e := range sd.Endpoints { - flags, buildFunction := buildFlags(e) - subcmd := cli.BuildSubcommandData(sd.Service, e.Method, buildFunction, flags) - command.Subcommands = append(command.Subcommands, subcmd) - } - command.Example = command.Subcommands[0].Example - data = append(data, command) - svcs = append(svcs, svc) + for _, svc := range root.API.GRPC.Services { + if len(svc.GRPCEndpoints) == 0 { + continue } - } - var files []*codegen.File - { - for _, svr := range root.API.Servers { - files = append(files, endpointParser(genpkg, root, svr, data)) + sd := GRPCServices.Get(svc.Name()) + command := cli.BuildCommandData(sd.Service) + for _, e := range sd.Endpoints { + flags, buildFunction := buildFlags(e) + subcmd := cli.BuildSubcommandData(sd.Service, e.Method, buildFunction, flags) + command.Subcommands = append(command.Subcommands, subcmd) } - for i, svc := range svcs { - files = append(files, payloadBuilders(genpkg, svc, data[i])) + command.Example = command.Subcommands[0].Example + data = append(data, command) + svcs = append(svcs, svc) + sd := GRPCServices.Get(svc.Name()) + command := cli.BuildCommandData(sd.Service) + for _, e := range sd.Endpoints { + flags, buildFunction := buildFlags(e) + subcmd := cli.BuildSubcommandData(sd.Service.Name, e.Method, buildFunction, flags) + command.Subcommands = append(command.Subcommands, subcmd) } + command.Example = command.Subcommands[0].Example + data = append(data, command) + svcs = append(svcs, svc) + } + var files []*codegen.File + for _, svr := range root.API.Servers { + files = append(files, endpointParser(genpkg, root, svr, data)) + } + for i, svc := range svcs { + files = append(files, payloadBuilders(genpkg, svc, data[i])) } return files } @@ -71,14 +77,9 @@ func endpointParser(genpkg string, root *expr.RootExpr, svr *expr.ServerExpr, da continue } svcName := sd.Service.PathName - specs = append(specs, &codegen.ImportSpec{ - Path: path.Join(genpkg, "grpc", svcName, "client"), - Name: sd.Service.PkgName + "c", - }) - specs = append(specs, &codegen.ImportSpec{ - Path: path.Join(genpkg, "grpc", svcName, pbPkgName), - Name: svcName + pbPkgName, - }) + specs = append(specs, + &codegen.ImportSpec{Path: path.Join(genpkg, "grpc", svcName, "client"), Name: sd.Service.PkgName + "c"}, + &codegen.ImportSpec{Path: path.Join(genpkg, "grpc", svcName, pbPkgName), Name: svcName + pbPkgName}) specs = append(specs, sd.Service.UserTypeImports...) // Add interceptors import if service has client interceptors if len(sd.Service.ClientInterceptors) > 0 { diff --git a/grpc/codegen/service_data.go b/grpc/codegen/service_data.go index e375aa9c3f..0c42bbdc80 100644 --- a/grpc/codegen/service_data.go +++ b/grpc/codegen/service_data.go @@ -439,29 +439,26 @@ func (ServicesData) analyze(gs *expr.GRPCServiceExpr) *ServiceData { sd *ServiceData seen, imported map[string]struct{} svcVarN string - - svc = service.Services.Get(gs.Name()) - scope = codegen.NewNameScope() - pkg = codegen.SnakeCase(codegen.Goify(svc.Name, false)) + pbPkgName ) - { - svcVarN = scope.HashedUnique(gs.ServiceExpr, codegen.Goify(svc.Name, true)) - sd = &ServiceData{ - Service: svc, - Name: svcVarN, - Description: svc.Description, - PkgName: pkg, - ServerStruct: "Server", - ClientStruct: "Client", - ServerInit: "New", - ClientInit: "NewClient", - ServerInterface: svcVarN + "Server", - ClientInterface: svcVarN + "Client", - ClientInterfaceInit: fmt.Sprintf("%s.New%sClient", pkg, svcVarN), - Scope: scope, - } - seen, imported = make(map[string]struct{}), make(map[string]struct{}) + svc := service.Services.Get(gs.Name()) + scope := codegen.NewNameScope() + pkg := codegen.SnakeCase(codegen.Goify(svc.Name, false)) + pbPkgName + svcVarN = scope.HashedUnique(gs.ServiceExpr, codegen.Goify(svc.Name, true)) + sd = &ServiceData{ + Service: svc, + Name: svcVarN, + Description: svc.Description, + PkgName: pkg, + ServerStruct: "Server", + ClientStruct: "Client", + ServerInit: "New", + ClientInit: "NewClient", + ServerInterface: svcVarN + "Server", + ClientInterface: svcVarN + "Client", + ClientInterfaceInit: fmt.Sprintf("%s.New%sClient", pkg, svcVarN), + Scope: scope, } + seen, imported = make(map[string]struct{}), make(map[string]struct{}) for _, e := range gs.GRPCEndpoints { // convert request and response types to protocol buffer message types e.Request = makeProtoBufMessage(e.Request, protoBufify(e.Name()+"_request", true, true), sd) @@ -513,101 +510,89 @@ func (ServicesData) analyze(gs *expr.GRPCServiceExpr) *ServiceData { resultRef string viewedResultRef string errors []*ErrorData - - md = svc.Method(e.Name()) ) - { - if e.MethodExpr.Payload.Type != expr.Empty { - payloadRef = svc.Scope.GoFullTypeRef(e.MethodExpr.Payload, - pkgWithDefault(md.PayloadLoc, svc.PkgName)) - } - if e.MethodExpr.Result.Type != expr.Empty { - resultRef = svc.Scope.GoFullTypeRef(e.MethodExpr.Result, - pkgWithDefault(md.ResultLoc, svc.PkgName)) - } - if md.ViewedResult != nil { - viewedResultRef = md.ViewedResult.FullRef - } - errors = buildErrorsData(e, sd) - for _, er := range e.GRPCErrors { - if er.ErrorExpr.Type == expr.ErrorResult || !expr.IsObject(er.ErrorExpr.Type) { - continue - } - collect(er.Response.Message) + md := svc.Method(e.Name()) + if e.MethodExpr.Payload.Type != expr.Empty { + payloadRef = svc.Scope.GoFullTypeRef(e.MethodExpr.Payload, + pkgWithDefault(md.PayloadLoc, svc.PkgName)) + } + if e.MethodExpr.Result.Type != expr.Empty { + resultRef = svc.Scope.GoFullTypeRef(e.MethodExpr.Result, + pkgWithDefault(md.ResultLoc, svc.PkgName)) + } + if md.ViewedResult != nil { + viewedResultRef = md.ViewedResult.FullRef + } + errors = buildErrorsData(e, sd) + for _, er := range e.GRPCErrors { + if er.ErrorExpr.Type == expr.ErrorResult || !expr.IsObject(er.ErrorExpr.Type) { + continue } + collect(er.Response.Message) } // build request data - var ( - request *RequestData - reqMD []*MetadataData - ) - { - reqMD = extractMetadata(e.Metadata, e.MethodExpr.Payload, svc.Scope) - request = &RequestData{ - Description: e.Request.Description, - Metadata: reqMD, - ServerConvert: buildRequestConvertData(e.Request, e.MethodExpr.Payload, reqMD, e, sd, true), - ClientConvert: buildRequestConvertData(e.Request, e.MethodExpr.Payload, reqMD, e, sd, false), - } - if obj := expr.AsObject(e.Request.Type); (obj != nil && len(*obj) > 0) || expr.IsUnion(e.Request.Type) { - // add the request message as the first argument to the CLI - request.CLIArgs = append(request.CLIArgs, &InitArgData{ - Name: "message", - Ref: "message", - TypeName: protoBufGoFullTypeName(e.Request, sd.PkgName, sd.Scope), - TypeRef: protoBufGoFullTypeRef(e.Request, sd.PkgName, sd.Scope), - Example: e.Request.Example(expr.Root.API.ExampleGenerator), - }) - } - // pass the metadata as arguments to client CLI args - for _, m := range reqMD { - request.CLIArgs = append(request.CLIArgs, &InitArgData{ - Name: m.VarName, - Ref: m.VarName, - FieldName: m.FieldName, - FieldType: m.FieldType, - TypeName: m.TypeName, - TypeRef: m.TypeRef, - Type: m.Type, - Pointer: m.Pointer, - Required: m.Required, - Validate: m.Validate, - Example: m.Example, - DefaultValue: m.DefaultValue, - }) - } - if e.StreamingRequest.Type != expr.Empty { - request.Message = collect(e.StreamingRequest) - } else { - request.Message = collect(e.Request) - } + var request *RequestData + reqMD := extractMetadata(e.Metadata, e.MethodExpr.Payload, svc.Scope) + request = &RequestData{ + Description: e.Request.Description, + Metadata: reqMD, + ServerConvert: buildRequestConvertData(e.Request, e.MethodExpr.Payload, reqMD, e, sd, true), + ClientConvert: buildRequestConvertData(e.Request, e.MethodExpr.Payload, reqMD, e, sd, false), + } + if obj := expr.AsObject(e.Request.Type); (obj != nil && len(*obj) > 0) || expr.IsUnion(e.Request.Type) { + // add the request message as the first argument to the CLI + request.CLIArgs = append(request.CLIArgs, &InitArgData{ + Name: "message", + Ref: "message", + TypeName: protoBufGoFullTypeName(e.Request, sd.PkgName, sd.Scope), + TypeRef: protoBufGoFullTypeRef(e.Request, sd.PkgName, sd.Scope), + Example: e.Request.Example(expr.Root.API.ExampleGenerator), + }) + } + // pass the metadata as arguments to client CLI args + for _, m := range reqMD { + request.CLIArgs = append(request.CLIArgs, &InitArgData{ + Name: m.VarName, + Ref: m.VarName, + FieldName: m.FieldName, + FieldType: m.FieldType, + TypeName: m.TypeName, + TypeRef: m.TypeRef, + Type: m.Type, + Pointer: m.Pointer, + Required: m.Required, + Validate: m.Validate, + Example: m.Example, + DefaultValue: m.DefaultValue, + }) + } + if e.StreamingRequest.Type != expr.Empty { + request.Message = collect(e.StreamingRequest) + } else { + request.Message = collect(e.Request) } // build response data var ( response *ResponseData - hdrs []*MetadataData trlrs []*MetadataData - - result, svcCtx = resultContext(e, sd) ) - { - hdrs = extractMetadata(e.Response.Headers, result, svc.Scope) - trlrs = extractMetadata(e.Response.Trailers, result, svc.Scope) - response = &ResponseData{ - StatusCode: statusCodeToGRPCConst(e.Response.StatusCode), - Description: e.Response.Description, - Headers: hdrs, - Trailers: trlrs, - ServerConvert: buildResponseConvertData(e.Response.Message, result, svcCtx, hdrs, trlrs, e, sd, true), - ClientConvert: buildResponseConvertData(e.Response.Message, result, svcCtx, hdrs, trlrs, e, sd, false), - } - // If the endpoint is a streaming endpoint, no message is returned - // by gRPC. Hence, no need to set response message. - if e.Response.Message.Type != expr.Empty || !e.MethodExpr.IsStreaming() { - response.Message = collect(e.Response.Message) - } + result, svcCtx := resultContext(e, sd) + hdrs := extractMetadata(e.Response.Headers, result, svc.Scope) + trlrs = extractMetadata(e.Response.Trailers, result, svc.Scope) + response = &ResponseData{ + StatusCode: statusCodeToGRPCConst(e.Response.StatusCode), + Description: e.Response.Description, + Headers: hdrs, + Trailers: trlrs, + ServerConvert: buildResponseConvertData(e.Response.Message, result, svcCtx, hdrs, trlrs, e, sd, true), + ClientConvert: buildResponseConvertData(e.Response.Message, result, svcCtx, hdrs, trlrs, e, sd, false), + } + // If the endpoint is a streaming endpoint, no message is returned + // by gRPC. Hence, no need to set response message. + if e.Response.Message.Type != expr.Empty || !e.MethodExpr.IsStreaming() { + response.Message = collect(e.Response.Message) } // gather security requirements @@ -615,17 +600,15 @@ func (ServicesData) analyze(gs *expr.GRPCServiceExpr) *ServiceData { msgSch service.SchemesData metSch service.SchemesData ) - { - for _, req := range e.Requirements { - for _, sch := range req.Schemes { - s := md.Requirements.Scheme(sch.SchemeName).Dup() - s.In = sch.In - switch s.In { - case "message": - msgSch = msgSch.Append(s) - default: - metSch = metSch.Append(s) - } + for _, req := range e.Requirements { + for _, sch := range req.Schemes { + s := md.Requirements.Scheme(sch.SchemeName).Dup() + s.In = sch.In + switch s.In { + case "message": + msgSch = msgSch.Append(s) + default: + metSch = metSch.Append(s) } } } @@ -874,27 +857,24 @@ func buildRequestConvertData(request, payload *expr.AttributeExpr, md []*Metadat if svr { // server side - var data *InitData - { - data = buildInitData(request, payload, "message", "v", svcCtx, false, svr, false, sd) - data.Name = fmt.Sprintf("New%sPayload", codegen.Goify(e.Name(), true)) - data.Description = fmt.Sprintf("%s builds the payload of the %q endpoint of the %q service from the gRPC request type.", data.Name, e.Name(), svc.Name) - for _, m := range md { - // pass the metadata as arguments to payload constructor in server - data.Args = append(data.Args, &InitArgData{ - Name: m.VarName, - Ref: m.VarName, - FieldName: m.FieldName, - FieldType: m.FieldType, - TypeName: m.TypeName, - TypeRef: m.TypeRef, - Type: m.Type, - Pointer: m.Pointer, - Required: m.Required, - Validate: m.Validate, - Example: m.Example, - }) - } + data := buildInitData(request, payload, "message", "v", svcCtx, false, svr, false, sd) + data.Name = fmt.Sprintf("New%sPayload", codegen.Goify(e.Name(), true)) + data.Description = fmt.Sprintf("%s builds the payload of the %q endpoint of the %q service from the gRPC request type.", data.Name, e.Name(), svc.Name) + for _, m := range md { + // pass the metadata as arguments to payload constructor in server + data.Args = append(data.Args, &InitArgData{ + Name: m.VarName, + Ref: m.VarName, + FieldName: m.FieldName, + FieldType: m.FieldType, + TypeName: m.TypeName, + TypeRef: m.TypeRef, + Type: m.Type, + Pointer: m.Pointer, + Required: m.Required, + Validate: m.Validate, + Example: m.Example, + }) } return &ConvertData{ SrcName: protoBufGoFullTypeName(request, sd.PkgName, sd.Scope), @@ -907,14 +887,8 @@ func buildRequestConvertData(request, payload *expr.AttributeExpr, md []*Metadat } // client side - - var ( - data *InitData - ) - { - data = buildInitData(payload, request, "payload", "message", svcCtx, true, svr, false, sd) - data.Description = fmt.Sprintf("%s builds the gRPC request type from the payload of the %q endpoint of the %q service.", data.Name, e.Name(), svc.Name) - } + data := buildInitData(payload, request, "payload", "message", svcCtx, true, svr, false, sd) + data.Description = fmt.Sprintf("%s builds the gRPC request type from the payload of the %q endpoint of the %q service.", data.Name, e.Name(), svc.Name) return &ConvertData{ SrcName: svc.Scope.GoFullTypeName(payload, pkg), SrcRef: svc.Scope.GoFullTypeRef(payload, pkg), @@ -936,19 +910,11 @@ func buildResponseConvertData(response, result *expr.AttributeExpr, svcCtx *code if !svr && (e.MethodExpr.IsStreaming() || isEmpty(e.MethodExpr.Result.Type)) { return nil } - - var ( - svc = sd.Service - ) - + svc := sd.Service if svr { // server side - - var data *InitData - { - data = buildInitData(result, response, "result", "message", svcCtx, true, svr, false, sd) - data.Description = fmt.Sprintf("%s builds the gRPC response type from the result of the %q endpoint of the %q service.", data.Name, e.Name(), svc.Name) - } + data := buildInitData(result, response, "result", "message", svcCtx, true, svr, false, sd) + data.Description = fmt.Sprintf("%s builds the gRPC response type from the result of the %q endpoint of the %q service.", data.Name, e.Name(), svc.Name) return &ConvertData{ SrcName: svcCtx.Scope.Name(result, svcCtx.Pkg(result), svcCtx.Pointer, svcCtx.UseDefault), SrcRef: svcCtx.Scope.Ref(result, svcCtx.Pkg(result)), @@ -959,44 +925,40 @@ func buildResponseConvertData(response, result *expr.AttributeExpr, svcCtx *code } // client side - - var data *InitData - { - data = buildInitData(response, result, "message", "result", svcCtx, false, svr, false, sd) - data.Name = fmt.Sprintf("New%sResult", codegen.Goify(e.Name(), true)) - data.Description = fmt.Sprintf("%s builds the result type of the %q endpoint of the %q service from the gRPC response type.", data.Name, e.Name(), svc.Name) - for _, m := range hdrs { - // pass the headers as arguments to result constructor in client - data.Args = append(data.Args, &InitArgData{ - Name: m.VarName, - Ref: m.VarName, - FieldName: m.FieldName, - FieldType: m.FieldType, - TypeName: m.TypeName, - TypeRef: m.TypeRef, - Type: m.Type, - Pointer: m.Pointer, - Required: m.Required, - Validate: m.Validate, - Example: m.Example, - }) - } - for _, m := range trlrs { - // pass the trailers as arguments to result constructor in client - data.Args = append(data.Args, &InitArgData{ - Name: m.VarName, - Ref: m.VarName, - FieldName: m.FieldName, - FieldType: m.FieldType, - TypeName: m.TypeName, - TypeRef: m.TypeRef, - Type: m.Type, - Pointer: m.Pointer, - Required: m.Required, - Validate: m.Validate, - Example: m.Example, - }) - } + data := buildInitData(response, result, "message", "result", svcCtx, false, svr, false, sd) + data.Name = fmt.Sprintf("New%sResult", codegen.Goify(e.Name(), true)) + data.Description = fmt.Sprintf("%s builds the result type of the %q endpoint of the %q service from the gRPC response type.", data.Name, e.Name(), svc.Name) + for _, m := range hdrs { + // pass the headers as arguments to result constructor in client + data.Args = append(data.Args, &InitArgData{ + Name: m.VarName, + Ref: m.VarName, + FieldName: m.FieldName, + FieldType: m.FieldType, + TypeName: m.TypeName, + TypeRef: m.TypeRef, + Type: m.Type, + Pointer: m.Pointer, + Required: m.Required, + Validate: m.Validate, + Example: m.Example, + }) + } + for _, m := range trlrs { + // pass the trailers as arguments to result constructor in client + data.Args = append(data.Args, &InitArgData{ + Name: m.VarName, + Ref: m.VarName, + FieldName: m.FieldName, + FieldType: m.FieldType, + TypeName: m.TypeName, + TypeRef: m.TypeRef, + Type: m.Type, + Pointer: m.Pointer, + Required: m.Required, + Validate: m.Validate, + Example: m.Example, + }) } return &ConvertData{ SrcName: protoBufGoFullTypeName(response, sd.PkgName, sd.Scope), @@ -1031,40 +993,38 @@ func buildInitData(source, target *expr.AttributeExpr, sourceVar, targetVar stri // pbCtx = protoBufTypeContext(sd.PkgName, sd.Scope, proto && svr || !proto && !svr) pbCtx = protoBufTypeContext(sd.PkgName, sd.Scope, false) ) - { - name = "New" - srcCtx = pbCtx - tgtCtx = svcCtx - if proto { - srcCtx = svcCtx - tgtCtx = pbCtx - name += "Proto" - } - isStruct = expr.IsObject(target.Type) || expr.IsUnion(target.Type) - if _, ok := source.Type.(expr.UserType); ok && usesrc { - name += protoBufGoTypeName(source, sd.Scope) - } - n := protoBufGoTypeName(target, sd.Scope) - if !isStruct { - // If target is array, map, or primitive the name will be suffixed with - // the definition (e.g int, []string, map[int]string) which is incorrect. - n = protoBufGoTypeName(source, sd.Scope) - } - name += n - code, helpers, err = protoBufTransform(source, target, sourceVar, targetVar, srcCtx, tgtCtx, proto, true) - if err != nil { - panic(err) // bug - } - sd.transformHelpers = codegen.AppendHelpers(sd.transformHelpers, helpers) - if (!proto && !isEmpty(source.Type)) || (proto && !isEmpty(target.Type)) { - args = []*InitArgData{{ - Name: sourceVar, - Ref: sourceVar, - TypeName: srcCtx.Scope.Name(source, srcCtx.Pkg(source), srcCtx.Pointer, srcCtx.UseDefault), - TypeRef: srcCtx.Scope.Ref(source, srcCtx.Pkg(source)), - Example: source.Example(expr.Root.API.ExampleGenerator), - }} - } + name = "New" + srcCtx = pbCtx + tgtCtx = svcCtx + if proto { + srcCtx = svcCtx + tgtCtx = pbCtx + name += "Proto" + } + isStruct = expr.IsObject(target.Type) || expr.IsUnion(target.Type) + if _, ok := source.Type.(expr.UserType); ok && usesrc { + name += protoBufGoTypeName(source, sd.Scope) + } + n := protoBufGoTypeName(target, sd.Scope) + if !isStruct { + // If target is array, map, or primitive the name will be suffixed with + // the definition (e.g int, []string, map[int]string) which is incorrect. + n = protoBufGoTypeName(source, sd.Scope) + } + name += n + code, helpers, err = protoBufTransform(source, target, sourceVar, targetVar, srcCtx, tgtCtx, proto, true) + if err != nil { + panic(err) // bug + } + sd.transformHelpers = codegen.AppendHelpers(sd.transformHelpers, helpers) + if (!proto && !isEmpty(source.Type)) || (proto && !isEmpty(target.Type)) { + args = []*InitArgData{{ + Name: sourceVar, + Ref: sourceVar, + TypeName: srcCtx.Scope.Name(source, srcCtx.Pkg(source), srcCtx.Pointer, srcCtx.UseDefault), + TypeRef: srcCtx.Scope.Ref(source, srcCtx.Pkg(source)), + Example: source.Example(expr.Root.API.ExampleGenerator), + }} } return &InitData{ Name: name, @@ -1081,21 +1041,15 @@ func buildInitData(source, target *expr.AttributeExpr, sourceVar, targetVar stri // endpoint expression. The response message for each error response are // inferred from the method's error expression if not specified explicitly. func buildErrorsData(e *expr.GRPCEndpointExpr, sd *ServiceData) []*ErrorData { - var ( - errors []*ErrorData - - svc = sd.Service - ) + var errors []*ErrorData + svc := sd.Service errors = make([]*ErrorData, 0, len(e.GRPCErrors)) for _, v := range e.GRPCErrors { - var responseData *ResponseData - { - responseData = &ResponseData{ - StatusCode: statusCodeToGRPCConst(v.Response.StatusCode), - Description: v.Response.Description, - ServerConvert: buildErrorConvertData(v, e, sd, true), - ClientConvert: buildErrorConvertData(v, e, sd, false), - } + responseData := &ResponseData{ + StatusCode: statusCodeToGRPCConst(v.Response.StatusCode), + Description: v.Response.Description, + ServerConvert: buildErrorConvertData(v, e, sd, true), + ClientConvert: buildErrorConvertData(v, e, sd, false), } errorLoc := svc.Method(e.MethodExpr.Name).ErrorLocs[v.Name] errors = append(errors, &ErrorData{ @@ -1113,20 +1067,13 @@ func buildErrorConvertData(ge *expr.GRPCErrorExpr, e *expr.GRPCEndpointExpr, sd if ge.ErrorExpr.Type == expr.ErrorResult || !expr.IsObject(ge.ErrorExpr.Type) { return nil } - var ( - svc = sd.Service - svcCtx = serviceTypeContext(svc.PkgName, svc.Scope) - ) - + svc := sd.Service + svcCtx := serviceTypeContext(svc.PkgName, svc.Scope) if svr { // server side - - var data *InitData - { - data = buildInitData(ge.ErrorExpr.AttributeExpr, ge.Response.Message, "er", "message", svcCtx, true, svr, false, sd) - data.Name = fmt.Sprintf("New%s%sError", codegen.Goify(e.Name(), true), codegen.Goify(ge.Name, true)) - data.Description = fmt.Sprintf("%s builds the gRPC error response type from the error of the %q endpoint of the %q service.", data.Name, e.Name(), svc.Name) - } + data := buildInitData(ge.ErrorExpr.AttributeExpr, ge.Response.Message, "er", "message", svcCtx, true, svr, false, sd) + data.Name = fmt.Sprintf("New%s%sError", codegen.Goify(e.Name(), true), codegen.Goify(ge.Name, true)) + data.Description = fmt.Sprintf("%s builds the gRPC error response type from the error of the %q endpoint of the %q service.", data.Name, e.Name(), svc.Name) return &ConvertData{ SrcName: svcCtx.Scope.Name(ge.ErrorExpr.AttributeExpr, svcCtx.Pkg(ge.ErrorExpr.AttributeExpr), svcCtx.Pointer, svcCtx.UseDefault), SrcRef: svcCtx.Scope.Ref(ge.ErrorExpr.AttributeExpr, svcCtx.Pkg(ge.ErrorExpr.AttributeExpr)), @@ -1137,13 +1084,9 @@ func buildErrorConvertData(ge *expr.GRPCErrorExpr, e *expr.GRPCEndpointExpr, sd } // client side - - var data *InitData - { - data = buildInitData(ge.Response.Message, ge.ErrorExpr.AttributeExpr, "message", "er", svcCtx, false, svr, false, sd) - data.Name = fmt.Sprintf("New%s%sError", codegen.Goify(e.Name(), true), codegen.Goify(ge.Name, true)) - data.Description = fmt.Sprintf("%s builds the error type of the %q endpoint of the %q service from the gRPC error response type.", data.Name, e.Name(), svc.Name) - } + data := buildInitData(ge.Response.Message, ge.ErrorExpr.AttributeExpr, "message", "er", svcCtx, false, svr, false, sd) + data.Name = fmt.Sprintf("New%s%sError", codegen.Goify(e.Name(), true), codegen.Goify(ge.Name, true)) + data.Description = fmt.Sprintf("%s builds the error type of the %q endpoint of the %q service from the gRPC error response type.", data.Name, e.Name(), svc.Name) return &ConvertData{ SrcName: protoBufGoFullTypeName(ge.Response.Message, sd.PkgName, sd.Scope), SrcRef: protoBufGoFullTypeRef(ge.Response.Message, sd.PkgName, sd.Scope), @@ -1183,73 +1126,70 @@ func buildStreamData(e *expr.GRPCEndpointExpr, sd *ServiceData, svr bool) *Strea svcCtx = serviceTypeContext(svc.PkgName, svc.Scope) result, resCtx = resultContext(e, sd) ) - { - resVar := "result" - if md.ViewedResult != nil { - resVar = "vresult" - } - if svr { - typ = "server" - varn = md.ServerStream.VarName - intName = fmt.Sprintf("%s.%s_%sServer", sd.PkgName, svc.StructName, md.VarName) - svcInt = fmt.Sprintf("%s.%s", svc.PkgName, md.ServerStream.Interface) - if e.MethodExpr.Result.Type != expr.Empty { - sendName = md.ServerStream.SendName - sendRef = ed.ResultRef - sendWithContextName = md.ServerStream.SendWithContextName - sendConvert = &ConvertData{ - SrcName: resCtx.Scope.Name(result, resCtx.Pkg(result), resCtx.Pointer, resCtx.UseDefault), - SrcRef: resCtx.Scope.Ref(result, resCtx.Pkg(result)), - TgtName: protoBufGoFullTypeName(e.Response.Message, sd.PkgName, sd.Scope), - TgtRef: protoBufGoFullTypeRef(e.Response.Message, sd.PkgName, sd.Scope), - Init: buildInitData(result, e.Response.Message, resVar, "v", resCtx, true, svr, true, sd), - } + resVar := "result" + if md.ViewedResult != nil { + resVar = "vresult" + } + if svr { + typ = "server" + varn = md.ServerStream.VarName + intName = fmt.Sprintf("%s.%s_%sServer", sd.PkgName, svc.StructName, md.VarName) + svcInt = fmt.Sprintf("%s.%s", svc.PkgName, md.ServerStream.Interface) + if e.MethodExpr.Result.Type != expr.Empty { + sendName = md.ServerStream.SendName + sendRef = ed.ResultRef + sendWithContextName = md.ServerStream.SendWithContextName + sendConvert = &ConvertData{ + SrcName: resCtx.Scope.Name(result, resCtx.Pkg(result), resCtx.Pointer, resCtx.UseDefault), + SrcRef: resCtx.Scope.Ref(result, resCtx.Pkg(result)), + TgtName: protoBufGoFullTypeName(e.Response.Message, sd.PkgName, sd.Scope), + TgtRef: protoBufGoFullTypeRef(e.Response.Message, sd.PkgName, sd.Scope), + Init: buildInitData(result, e.Response.Message, resVar, "v", resCtx, true, svr, true, sd), } - if e.MethodExpr.StreamingPayload.Type != expr.Empty { - recvName = md.ServerStream.RecvName - recvWithContextName = md.ServerStream.RecvWithContextName - recvRef = svcCtx.Scope.Ref(e.MethodExpr.StreamingPayload, svcCtx.Pkg(e.MethodExpr.StreamingPayload)) - recvConvert = &ConvertData{ - SrcName: protoBufGoFullTypeName(e.StreamingRequest, sd.PkgName, sd.Scope), - SrcRef: protoBufGoFullTypeRef(e.StreamingRequest, sd.PkgName, sd.Scope), - TgtName: svcCtx.Scope.Name(e.MethodExpr.StreamingPayload, svcCtx.Pkg(e.MethodExpr.StreamingPayload), svcCtx.Pointer, svcCtx.UseDefault), - TgtRef: recvRef, - Init: buildInitData(e.StreamingRequest, e.MethodExpr.StreamingPayload, "v", "spayload", svcCtx, false, svr, true, sd), - Validation: addValidation(e.StreamingRequest, "stream", sd, true), - } + } + if e.MethodExpr.StreamingPayload.Type != expr.Empty { + recvName = md.ServerStream.RecvName + recvWithContextName = md.ServerStream.RecvWithContextName + recvRef = svcCtx.Scope.Ref(e.MethodExpr.StreamingPayload, svcCtx.Pkg(e.MethodExpr.StreamingPayload)) + recvConvert = &ConvertData{ + SrcName: protoBufGoFullTypeName(e.StreamingRequest, sd.PkgName, sd.Scope), + SrcRef: protoBufGoFullTypeRef(e.StreamingRequest, sd.PkgName, sd.Scope), + TgtName: svcCtx.Scope.Name(e.MethodExpr.StreamingPayload, svcCtx.Pkg(e.MethodExpr.StreamingPayload), svcCtx.Pointer, svcCtx.UseDefault), + TgtRef: recvRef, + Init: buildInitData(e.StreamingRequest, e.MethodExpr.StreamingPayload, "v", "spayload", svcCtx, false, svr, true, sd), + Validation: addValidation(e.StreamingRequest, "stream", sd, true), } - mustClose = md.ServerStream.MustClose - } else { - typ = "client" - varn = md.ClientStream.VarName - intName = fmt.Sprintf("%s.%s_%sClient", sd.PkgName, svc.StructName, md.VarName) - svcInt = fmt.Sprintf("%s.%s", svc.PkgName, md.ClientStream.Interface) - if e.MethodExpr.StreamingPayload.Type != expr.Empty { - sendName = md.ClientStream.SendName - sendWithContextName = md.ClientStream.SendWithContextName - sendRef = svcCtx.Scope.Ref(e.MethodExpr.StreamingPayload, svcCtx.Pkg(e.MethodExpr.StreamingPayload)) - sendConvert = &ConvertData{ - SrcName: svcCtx.Scope.Name(e.MethodExpr.StreamingPayload, svcCtx.Pkg(e.MethodExpr.StreamingPayload), svcCtx.Pointer, svcCtx.UseDefault), - SrcRef: sendRef, - TgtName: protoBufGoFullTypeName(e.StreamingRequest, sd.PkgName, sd.Scope), - TgtRef: protoBufGoFullTypeRef(e.StreamingRequest, sd.PkgName, sd.Scope), - Init: buildInitData(e.MethodExpr.StreamingPayload, e.StreamingRequest, "spayload", "v", svcCtx, true, svr, true, sd), - } + } + mustClose = md.ServerStream.MustClose + } else { + typ = "client" + varn = md.ClientStream.VarName + intName = fmt.Sprintf("%s.%s_%sClient", sd.PkgName, svc.StructName, md.VarName) + svcInt = fmt.Sprintf("%s.%s", svc.PkgName, md.ClientStream.Interface) + if e.MethodExpr.StreamingPayload.Type != expr.Empty { + sendName = md.ClientStream.SendName + sendWithContextName = md.ClientStream.SendWithContextName + sendRef = svcCtx.Scope.Ref(e.MethodExpr.StreamingPayload, svcCtx.Pkg(e.MethodExpr.StreamingPayload)) + sendConvert = &ConvertData{ + SrcName: svcCtx.Scope.Name(e.MethodExpr.StreamingPayload, svcCtx.Pkg(e.MethodExpr.StreamingPayload), svcCtx.Pointer, svcCtx.UseDefault), + SrcRef: sendRef, + TgtName: protoBufGoFullTypeName(e.StreamingRequest, sd.PkgName, sd.Scope), + TgtRef: protoBufGoFullTypeRef(e.StreamingRequest, sd.PkgName, sd.Scope), + Init: buildInitData(e.MethodExpr.StreamingPayload, e.StreamingRequest, "spayload", "v", svcCtx, true, svr, true, sd), } - if e.MethodExpr.Result.Type != expr.Empty { - recvName = md.ClientStream.RecvName - recvWithContextName = md.ClientStream.RecvWithContextName - recvRef = ed.ResultRef - recvConvert = &ConvertData{ - SrcName: protoBufGoFullTypeName(e.Response.Message, sd.PkgName, sd.Scope), - SrcRef: protoBufGoFullTypeRef(e.Response.Message, sd.PkgName, sd.Scope), - TgtName: resCtx.Scope.Name(result, resCtx.Pkg(result), resCtx.Pointer, resCtx.UseDefault), - TgtRef: resCtx.Scope.Ref(result, resCtx.Pkg(result)), - Init: buildInitData(e.Response.Message, result, "v", resVar, resCtx, false, svr, true, sd), - Validation: addValidation(e.Response.Message, "stream", sd, false), - } + } + if e.MethodExpr.Result.Type != expr.Empty { + recvName = md.ClientStream.RecvName + recvWithContextName = md.ClientStream.RecvWithContextName + recvRef = ed.ResultRef + recvConvert = &ConvertData{ + SrcName: protoBufGoFullTypeName(e.Response.Message, sd.PkgName, sd.Scope), + SrcRef: protoBufGoFullTypeRef(e.Response.Message, sd.PkgName, sd.Scope), + TgtName: resCtx.Scope.Name(result, resCtx.Pkg(result), resCtx.Pointer, resCtx.UseDefault), + TgtRef: resCtx.Scope.Ref(result, resCtx.Pkg(result)), + Init: buildInitData(e.Response.Message, result, "v", resVar, resCtx, false, svr, true, sd), + Validation: addValidation(e.Response.Message, "stream", sd, false), } - mustClose = md.ClientStream.MustClose } if sendConvert != nil { sendDesc = fmt.Sprintf("%s streams instances of %q to the %q endpoint gRPC stream.", sendName, sendConvert.TgtName, md.Name) @@ -1259,6 +1199,13 @@ func buildStreamData(e *expr.GRPCEndpointExpr, sd *ServiceData, svr bool) *Strea recvDesc = fmt.Sprintf("%s reads instances of %q from the %q endpoint gRPC stream.", recvName, recvConvert.SrcName, md.Name) recvWithContextDesc = fmt.Sprintf("%s reads instances of %q from the %q endpoint gRPC stream with context.", recvWithContextName, recvConvert.SrcName, md.Name) } + mustClose = md.ClientStream.MustClose + } + if sendConvert != nil { + sendDesc = fmt.Sprintf("%s streams instances of %q to the %q endpoint gRPC stream.", sendName, sendConvert.TgtName, md.Name) + } + if recvConvert != nil { + recvDesc = fmt.Sprintf("%s reads instances of %q from the %q endpoint gRPC stream.", recvName, recvConvert.SrcName, md.Name) } return &StreamData{ VarName: varn, @@ -1298,18 +1245,16 @@ func extractMetadata(a *expr.MappedAttributeExpr, service *expr.AttributeExpr, s typeRef = scope.GoTypeRef(unalias(c)) ft = service.Type ) - { - varn = scope.Name(codegen.Goify(name, false)) - fieldName = codegen.Goify(name, true) - if !expr.IsObject(service.Type) { - fieldName = "" - } else { - pointer = service.IsPrimitivePointer(name, true) - ft = service.Find(name).Type - } - if pointer { - typeRef = "*" + typeRef - } + varn = scope.Name(codegen.Goify(name, false)) + fieldName = codegen.Goify(name, true) + if !expr.IsObject(service.Type) { + fieldName = "" + } else { + pointer = service.IsPrimitivePointer(name, true) + ft = service.Find(name).Type + } + if pointer { + typeRef = "*" + typeRef } metadata = append(metadata, &MetadataData{ Name: elem, diff --git a/http/codegen/websocket.go b/http/codegen/websocket.go index 0439b81622..aec33ff1d4 100644 --- a/http/codegen/websocket.go +++ b/http/codegen/websocket.go @@ -94,110 +94,132 @@ func initWebSocketData(ed *EndpointData, e *expr.HTTPEndpointExpr, sd *ServiceDa svc = sd.Service svcctx = serviceContext(sd.Service.PkgName, sd.Service.Scope) ) - { - svrSendTypeName = ed.Result.Name - svrSendTypeRef = ed.Result.Ref - svrSendDesc = fmt.Sprintf("%s streams instances of %q to the %q endpoint websocket connection.", md.ServerStream.SendName, svrSendTypeName, md.Name) - svrSendWithContextDesc = fmt.Sprintf("%s streams instances of %q to the %q endpoint websocket connection with context.", md.ServerStream.SendWithContextName, svrSendTypeName, md.Name) - cliRecvDesc = fmt.Sprintf("%s reads instances of %q from the %q endpoint websocket connection.", md.ClientStream.RecvName, svrSendTypeName, md.Name) - cliRecvWithContextDesc = fmt.Sprintf("%s reads instances of %q from the %q endpoint websocket connection with context.", md.ClientStream.RecvWithContextName, svrSendTypeName, md.Name) - if e.MethodExpr.Stream == expr.ClientStreamKind || e.MethodExpr.Stream == expr.BidirectionalStreamKind { - svrRecvTypeName = sd.Scope.GoFullTypeName(e.MethodExpr.StreamingPayload, svc.PkgName) - svrRecvTypeRef = sd.Scope.GoFullTypeRef(e.MethodExpr.StreamingPayload, svc.PkgName) - svrPayload = buildRequestBodyType(e.StreamingBody, e.MethodExpr.StreamingPayload, e, true, sd) - if needInit(e.MethodExpr.StreamingPayload.Type) { - makeHTTPType(e.StreamingBody) - body := e.StreamingBody.Type - // generate constructor function to transform request body, - // into the method streaming payload type - var ( - name string - desc string - serverArgs []*InitArgData - serverCode string - err error - ) - { - n := codegen.Goify(e.MethodExpr.Name, true) - p := codegen.Goify(svrPayload.Name, true) - // Raw payload object has type name prefixed with endpoint name. No need to - // prefix the type name again. - if strings.HasPrefix(p, n) { - name = fmt.Sprintf("New%s", p) - } else { - name = fmt.Sprintf("New%s%s", n, p) - } - desc = fmt.Sprintf("%s builds a %s service %s endpoint payload.", name, svc.Name, e.MethodExpr.Name) - if body != expr.Empty { - var ( - ref string - svcode string - ) - { - ref = "body" - if expr.IsObject(body) { - ref = "&body" - } - if ut, ok := body.(expr.UserType); ok { - if val := ut.Attribute().Validation; val != nil { - httpctx := httpContext("", sd.Scope, true, true) - svcode = codegen.ValidationCode(ut.Attribute(), ut, httpctx, true, expr.IsAlias(ut), false, "body") - } + svrSendTypeName = ed.Result.Name + svrSendTypeRef = ed.Result.Ref + svrSendDesc = fmt.Sprintf("%s streams instances of %q to the %q endpoint websocket connection.", md.ServerStream.SendName, svrSendTypeName, md.Name) + svrSendWithContextDesc = fmt.Sprintf("%s streams instances of %q to the %q endpoint websocket connection with context.", md.ServerStream.SendWithContextName, svrSendTypeName, md.Name) + cliRecvDesc = fmt.Sprintf("%s reads instances of %q from the %q endpoint websocket connection.", md.ClientStream.RecvName, svrSendTypeName, md.Name) + cliRecvWithContextDesc = fmt.Sprintf("%s reads instances of %q from the %q endpoint websocket connection with context.", md.ClientStream.RecvWithContextName, svrSendTypeName, md.Name) + if e.MethodExpr.Stream == expr.ClientStreamKind || e.MethodExpr.Stream == expr.BidirectionalStreamKind { + svrRecvTypeName = sd.Scope.GoFullTypeName(e.MethodExpr.StreamingPayload, svc.PkgName) + svrRecvTypeRef = sd.Scope.GoFullTypeRef(e.MethodExpr.StreamingPayload, svc.PkgName) + svrPayload = buildRequestBodyType(e.StreamingBody, e.MethodExpr.StreamingPayload, e, true, sd) + if needInit(e.MethodExpr.StreamingPayload.Type) { + makeHTTPType(e.StreamingBody) + body := e.StreamingBody.Type + // generate constructor function to transform request body, + // into the method streaming payload type + var ( + name string + desc string + serverArgs []*InitArgData + serverCode string + err error + ) + { + n := codegen.Goify(e.MethodExpr.Name, true) + p := codegen.Goify(svrPayload.Name, true) + // Raw payload object has type name prefixed with endpoint name. No need to + // prefix the type name again. + if strings.HasPrefix(p, n) { + name = fmt.Sprintf("New%s", p) + } else { + name = fmt.Sprintf("New%s%s", n, p) + } + desc = fmt.Sprintf("%s builds a %s service %s endpoint payload.", name, svc.Name, e.MethodExpr.Name) + if body != expr.Empty { + var ( + ref string + svcode string + ) + { + ref = "body" + if expr.IsObject(body) { + ref = "&body" + } + if ut, ok := body.(expr.UserType); ok { + if val := ut.Attribute().Validation; val != nil { + httpctx := httpContext("", sd.Scope, true, true) + svcode = codegen.ValidationCode(ut.Attribute(), ut, httpctx, true, expr.IsAlias(ut), false, "body") } } - serverArgs = []*InitArgData{{ - Ref: ref, - AttributeData: &AttributeData{ - Name: "payload", - VarName: "body", - TypeName: sd.Scope.GoTypeName(e.StreamingBody), - TypeRef: sd.Scope.GoTypeRef(e.StreamingBody), - Type: e.StreamingBody.Type, - Required: true, - Example: e.Body.Example(expr.Root.API.ExampleGenerator), - Validate: svcode, - }, - }} } - if body != expr.Empty { - var helpers []*codegen.TransformFunctionData + serverArgs = []*InitArgData{{ + Ref: ref, + AttributeData: &AttributeData{ + Name: "payload", + VarName: "body", + TypeName: sd.Scope.GoTypeName(e.StreamingBody), + TypeRef: sd.Scope.GoTypeRef(e.StreamingBody), + Type: e.StreamingBody.Type, + Required: true, + Example: e.Body.Example(expr.Root.API.ExampleGenerator), + Validate: svcode, + }, + }} + if ut, ok := body.(expr.UserType); ok { + if val := ut.Attribute().Validation; val != nil { httpctx := httpContext("", sd.Scope, true, true) - serverCode, helpers, err = marshal(e.StreamingBody, e.MethodExpr.StreamingPayload, "body", "v", httpctx, svcctx) - if err == nil { - sd.ServerTransformHelpers = codegen.AppendHelpers(sd.ServerTransformHelpers, helpers) - } - } - if err != nil { - fmt.Println(err.Error()) // TBD validate DSL so errors are not possible + svcode = codegen.ValidationCode(ut.Attribute(), ut, httpctx, true, expr.IsAlias(ut), false, "body") } } - svrPayload.Init = &InitData{ - Name: name, - Description: desc, - ServerArgs: serverArgs, - ReturnTypeName: svc.Scope.GoFullTypeName(e.MethodExpr.StreamingPayload, svc.PkgName), - ReturnTypeRef: svc.Scope.GoFullTypeRef(e.MethodExpr.StreamingPayload, svc.PkgName), - ReturnIsStruct: expr.IsObject(e.MethodExpr.StreamingPayload.Type), - ReturnTypePkg: svc.PkgName, - ServerCode: serverCode, - } + serverArgs = []*InitArgData{{ + Ref: ref, + AttributeData: &AttributeData{ + Name: "payload", + VarName: "body", + TypeName: sd.Scope.GoTypeName(e.StreamingBody), + TypeRef: sd.Scope.GoTypeRef(e.StreamingBody), + Type: e.StreamingBody.Type, + Required: true, + Example: e.Body.Example(expr.Root.API.ExampleGenerator), + Validate: svcode, + }, + }} } - cliPayload = buildRequestBodyType(e.StreamingBody, e.MethodExpr.StreamingPayload, e, false, sd) - if cliPayload != nil { - sd.ClientTypeNames[cliPayload.Name] = false - sd.ServerTypeNames[cliPayload.Name] = false + if body != expr.Empty { + var helpers []*codegen.TransformFunctionData + httpctx := httpContext("", sd.Scope, true, true) + serverCode, helpers, err = marshal(e.StreamingBody, e.MethodExpr.StreamingPayload, "body", "v", httpctx, svcctx) + if err == nil { + sd.ServerTransformHelpers = codegen.AppendHelpers(sd.ServerTransformHelpers, helpers) + } } - if e.MethodExpr.Stream == expr.ClientStreamKind { - svrSendDesc = fmt.Sprintf("%s streams instances of %q to the %q endpoint websocket connection and closes the connection.", md.ServerStream.SendName, svrSendTypeName, md.Name) - svrSendWithContextDesc = fmt.Sprintf("%s streams instances of %q to the %q endpoint websocket connection with context and closes the connection.", md.ServerStream.SendWithContextName, svrSendTypeName, md.Name) - cliRecvDesc = fmt.Sprintf("%s stops sending messages to the %q endpoint websocket connection and reads instances of %q from the connection.", md.ClientStream.RecvName, md.Name, svrSendTypeName) - cliRecvWithContextDesc = fmt.Sprintf("%s stops sending messages to the %q endpoint websocket connection and reads instances of %q from the connection with context.", md.ClientStream.RecvWithContextName, md.Name, svrSendTypeName) + if err != nil { + fmt.Println(err.Error()) // TBD validate DSL so errors are not possible } - svrRecvDesc = fmt.Sprintf("%s reads instances of %q from the %q endpoint websocket connection.", md.ServerStream.RecvName, svrRecvTypeName, md.Name) - svrRecvWithContextDesc = fmt.Sprintf("%s reads instances of %q from the %q endpoint websocket connection with context.", md.ServerStream.RecvWithContextName, svrRecvTypeName, md.Name) - cliSendDesc = fmt.Sprintf("%s streams instances of %q to the %q endpoint websocket connection.", md.ClientStream.SendName, svrRecvTypeName, md.Name) - cliSendWithContextDesc = fmt.Sprintf("%s streams instances of %q to the %q endpoint websocket connection with context.", md.ClientStream.SendWithContextName, svrRecvTypeName, md.Name) } + svrPayload.Init = &InitData{ + Name: name, + Description: desc, + ServerArgs: serverArgs, + ReturnTypeName: svc.Scope.GoFullTypeName(e.MethodExpr.StreamingPayload, svc.PkgName), + ReturnTypeRef: svc.Scope.GoFullTypeRef(e.MethodExpr.StreamingPayload, svc.PkgName), + ReturnIsStruct: expr.IsObject(e.MethodExpr.StreamingPayload.Type), + ReturnTypePkg: svc.PkgName, + ServerCode: serverCode, + } + if e.MethodExpr.Stream == expr.ClientStreamKind { + svrSendDesc = fmt.Sprintf("%s streams instances of %q to the %q endpoint websocket connection and closes the connection.", md.ServerStream.SendName, svrSendTypeName, md.Name) + svrSendWithContextDesc = fmt.Sprintf("%s streams instances of %q to the %q endpoint websocket connection with context and closes the connection.", md.ServerStream.SendWithContextName, svrSendTypeName, md.Name) + cliRecvDesc = fmt.Sprintf("%s stops sending messages to the %q endpoint websocket connection and reads instances of %q from the connection.", md.ClientStream.RecvName, md.Name, svrSendTypeName) + cliRecvWithContextDesc = fmt.Sprintf("%s stops sending messages to the %q endpoint websocket connection and reads instances of %q from the connection with context.", md.ClientStream.RecvWithContextName, md.Name, svrSendTypeName) + } + svrRecvDesc = fmt.Sprintf("%s reads instances of %q from the %q endpoint websocket connection.", md.ServerStream.RecvName, svrRecvTypeName, md.Name) + svrRecvWithContextDesc = fmt.Sprintf("%s reads instances of %q from the %q endpoint websocket connection with context.", md.ServerStream.RecvWithContextName, svrRecvTypeName, md.Name) + cliSendDesc = fmt.Sprintf("%s streams instances of %q to the %q endpoint websocket connection.", md.ClientStream.SendName, svrRecvTypeName, md.Name) + cliSendWithContextDesc = fmt.Sprintf("%s streams instances of %q to the %q endpoint websocket connection with context.", md.ClientStream.SendWithContextName, svrRecvTypeName, md.Name) + cliPayload = buildRequestBodyType(e.StreamingBody, e.MethodExpr.StreamingPayload, e, false, sd) + if cliPayload != nil { + sd.ClientTypeNames[cliPayload.Name] = false + sd.ServerTypeNames[cliPayload.Name] = false + } + if e.MethodExpr.Stream == expr.ClientStreamKind { + svrSendDesc = fmt.Sprintf("%s streams instances of %q to the %q endpoint websocket connection and closes the connection.", md.ServerStream.SendName, svrSendTypeName, md.Name) + cliRecvDesc = fmt.Sprintf("%s stops sending messages to the %q endpoint websocket connection and reads instances of %q from the connection.", md.ClientStream.RecvName, md.Name, svrSendTypeName) + } + svrRecvDesc = fmt.Sprintf("%s reads instances of %q from the %q endpoint websocket connection.", md.ServerStream.RecvName, svrRecvTypeName, md.Name) + cliSendDesc = fmt.Sprintf("%s streams instances of %q to the %q endpoint websocket connection.", md.ClientStream.SendName, svrRecvTypeName, md.Name) } ed.ServerWebSocket = &WebSocketData{ VarName: md.ServerStream.VarName,