Skip to content

Commit

Permalink
Fix bug where the CLI ParseEndpoint method would try to wrap every me…
Browse files Browse the repository at this point in the history
…thod with interceptors even if they do not apply
  • Loading branch information
douglaswth committed Feb 1, 2025
1 parent 9c35929 commit cc0e16a
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 14 deletions.
19 changes: 15 additions & 4 deletions codegen/cli/cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ type (
Conversion string
// Example is a valid command invocation, starting with the command name.
Example string
// Interceptors contains the data for client interceptors if any apply to the endpoint method.
Interceptors *InterceptorData
}

// InterceptorData contains the data needed to generate interceptor code.
Expand Down Expand Up @@ -181,18 +183,19 @@ func BuildCommandData(data *service.Data) *CommandData {

// BuildSubcommandData builds the data needed by CLI code generators to render
// the CLI parsing of the service sub-command.
func BuildSubcommandData(svcName string, m *service.MethodData, buildFunction *BuildFunctionData, flags []*FlagData) *SubcommandData {
func BuildSubcommandData(data *service.Data, m *service.MethodData, buildFunction *BuildFunctionData, flags []*FlagData) *SubcommandData {
var (
name string
fullName string
description string

conversion string
conversion string
interceptors *InterceptorData
)
{
en := m.Name
name = codegen.KebabCase(en)
fullName = goifyTerms(svcName, en)
fullName = goifyTerms(data.Name, en)
description = m.Description
if description == "" {
description = fmt.Sprintf("Make request to the %q endpoint", m.Name)
Expand Down Expand Up @@ -227,6 +230,13 @@ func BuildSubcommandData(svcName string, m *service.MethodData, buildFunction *B
conversion += "\n}"
}
}

if len(m.ClientInterceptors) > 0 {
interceptors = &InterceptorData{
VarName: "inter",
PkgName: data.PkgName,
}
}
}
sub := &SubcommandData{
Name: name,
Expand All @@ -236,8 +246,9 @@ func BuildSubcommandData(svcName string, m *service.MethodData, buildFunction *B
MethodVarName: m.VarName,
BuildFunction: buildFunction,
Conversion: conversion,
Interceptors: interceptors,
}
generateExample(sub, svcName)
generateExample(sub, data.Name)

return sub
}
Expand Down
2 changes: 1 addition & 1 deletion grpc/codegen/client_cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func ClientCLIFiles(genpkg string, root *expr.RootExpr) []*codegen.File {
command := cli.BuildCommandData(sd.Service)
for _, e := range sd.Endpoints {
flags, buildFunction := buildFlags(e)
subcmd := cli.BuildSubcommandData(sd.Service.Name, e.Method, buildFunction, flags)
subcmd := cli.BuildSubcommandData(sd.Service, e.Method, buildFunction, flags)
command.Subcommands = append(command.Subcommands, subcmd)
}
command.Example = command.Subcommands[0].Example
Expand Down
7 changes: 3 additions & 4 deletions grpc/codegen/templates/parse_endpoint.go.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,11 @@ func ParseEndpoint(
c := {{ .PkgName }}.NewClient(cc, opts...)
switch epn {
{{- $pkgName := .PkgName }}
{{- $interceptors := .Interceptors }}
{{ range .Subcommands }}
case "{{ .Name }}":
endpoint = c.{{ .MethodVarName }}()
{{- if $interceptors }}
endpoint = {{ $interceptors.PkgName }}.Wrap{{ .MethodVarName }}ClientEndpoint(endpoint, {{ $interceptors.VarName }})
{{- if .Interceptors }}
endpoint = {{ .Interceptors.PkgName }}.Wrap{{ .MethodVarName }}ClientEndpoint(endpoint, {{ .Interceptors.VarName }})
{{- end }}
{{- if .BuildFunction }}
data, err = {{ $pkgName}}.{{ .BuildFunction.Name }}({{ range .BuildFunction.ActualParams }}*{{ . }}Flag, {{ end }})
Expand All @@ -44,4 +43,4 @@ func ParseEndpoint(
}

return endpoint, data, nil
}
}
4 changes: 2 additions & 2 deletions http/codegen/client_cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ type subcommandData struct {
// MultipartFuncName is the name of the function used to render a multipart
// request encoder.
MultipartFuncName string
// MultipartFuncName is the name of the variabl used to render a multipart
// MultipartFuncName is the name of the variable used to render a multipart
// request encoder.
MultipartVarName string
// StreamFlag is the flag used to identify the file to be streamed when
Expand Down Expand Up @@ -87,7 +87,7 @@ func buildSubcommandData(sd *ServiceData, e *EndpointData) *subcommandData {
flags, buildFunction := buildFlags(sd, e)

sub := &subcommandData{
SubcommandData: cli.BuildSubcommandData(sd.Service.Name, e.Method, buildFunction, flags),
SubcommandData: cli.BuildSubcommandData(sd.Service, e.Method, buildFunction, flags),
}
if e.MultipartRequestEncoder != nil {
sub.MultipartVarName = e.MultipartRequestEncoder.VarName
Expand Down
5 changes: 2 additions & 3 deletions http/codegen/templates/parse_endpoint.go.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,11 @@ func ParseEndpoint(
c := {{ .PkgName }}.NewClient(scheme, host, doer, enc, dec, restore{{ if .NeedStream }}, dialer, {{ .VarName }}Configurer{{ end }})
switch epn {
{{- $pkgName := .PkgName }}
{{- $interceptors := .Interceptors }}
{{- range .Subcommands }}
case "{{ .Name }}":
endpoint = c.{{ .MethodVarName }}({{ if .MultipartVarName }}{{ .MultipartVarName }}{{ end }})
{{- if $interceptors }}
endpoint = {{ $interceptors.PkgName }}.Wrap{{ .MethodVarName }}ClientEndpoint(endpoint, {{ $interceptors.VarName }})
{{- if .Interceptors }}
endpoint = {{ .Interceptors.PkgName }}.Wrap{{ .MethodVarName }}ClientEndpoint(endpoint, {{ .Interceptors.VarName }})
{{- end }}
{{- if .BuildFunction }}
data, err = {{ $pkgName }}.{{ .BuildFunction.Name }}({{ range .BuildFunction.ActualParams }}*{{ . }}Flag, {{ end }})
Expand Down

0 comments on commit cc0e16a

Please sign in to comment.