Skip to content

Commit

Permalink
Initial interceptors implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
raphael committed Nov 26, 2024
1 parent af86bfb commit 0955550
Show file tree
Hide file tree
Showing 48 changed files with 2,608 additions and 83 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
<a href="https://gophers.slack.com/messages/goa"><img alt="Slack: Goa" src="https://img.shields.io/badge/Goa-gray.svg?longCache=true&logo=slack&colorB=red&style=for-the-badge"></a>
<a href="https://invite.slack.golangbridge.org/"><img alt="Slack: Sign-up" src="https://img.shields.io/badge/Signup-gray.svg?longCache=true&logo=slack&colorB=red&style=for-the-badge"></a>
<a href="https://bsky.app/profile/goadesign.bsky.social"><img alt="BSky: Goa" src="https://img.shields.io/badge/Bluesky-0285FF?longCache=true&logo=bluesky&logoColor=fff&style=for-the-badge"></a>
<a href="https://twitter.com/goadesign"><img alt="Twitter: @goadesign" src="https://img.shields.io/badge/@goadesign-gray.svg?logo=twitter&colorB=blue&style=for-the-badge"></a>
</p>
</p>

Expand Down
12 changes: 12 additions & 0 deletions codegen/service/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,18 @@ func ClientFile(_ string, service *expr.ServiceExpr) *codegen.File {
Source: readTemplate("service_client_method"),
Data: m,
})
if len(m.ClientInterceptors) > 0 {
sections = append(sections, &codegen.SectionTemplate{
Name: "client-wrapper",
Source: readTemplate("client_wrappers"),
Data: map[string]interface{}{
"Method": m.Name,
"MethodVarName": codegen.Goify(m.Name, true),
"Service": svc.Name,
"ClientInterceptors": m.ClientInterceptors,
},
})
}
}
}

Expand Down
1 change: 1 addition & 0 deletions codegen/service/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ func TestClient(t *testing.T) {
{"client-streaming-payload-no-result", testdata.StreamingPayloadNoResultMethodDSL, testdata.StreamingPayloadNoResultMethodClient},
{"client-bidirectional-streaming", testdata.BidirectionalStreamingMethodDSL, testdata.BidirectionalStreamingMethodClient},
{"client-bidirectional-streaming-no-payload", testdata.BidirectionalStreamingNoPayloadMethodDSL, testdata.BidirectionalStreamingNoPayloadMethodClient},
{"client-interceptor", testdata.EndpointWithClientInterceptorDSL, testdata.InterceptorClient},
}
for _, c := range cases {
t.Run(c.Name, func(t *testing.T) {
Expand Down
24 changes: 0 additions & 24 deletions codegen/service/convert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import (
"goa.design/goa/v3/codegen"
"goa.design/goa/v3/codegen/service/testdata"
"goa.design/goa/v3/dsl"
"goa.design/goa/v3/eval"
"goa.design/goa/v3/expr"
)

Expand Down Expand Up @@ -257,30 +256,7 @@ func TestConvertFile(t *testing.T) {
}
}

// runDSL returns the DSL root resulting from running the given DSL.
func runDSL(t *testing.T, dsl func()) *expr.RootExpr {
// reset all roots and codegen data structures
Services = make(ServicesData)
eval.Reset()
expr.Root = new(expr.RootExpr)
expr.GeneratedResultTypes = new(expr.ResultTypesRoot)
require.NoError(t, eval.Register(expr.Root))
require.NoError(t, eval.Register(expr.GeneratedResultTypes))
expr.Root.API = expr.NewAPIExpr("test api", func() {})
expr.Root.API.Servers = []*expr.ServerExpr{expr.Root.API.DefaultServer()}

// run DSL (first pass)
require.True(t, eval.Execute(dsl, nil))

// run DSL (second pass)
require.NoError(t, eval.RunDSL())

// return generated root
return expr.Root
}

// Test fixtures

var obj = &expr.UserTypeExpr{
AttributeExpr: &expr.AttributeExpr{
Type: &expr.Object{
Expand Down
66 changes: 53 additions & 13 deletions codegen/service/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ type (
ServiceVarName string
// Methods lists the endpoint struct methods.
Methods []*EndpointMethodData
// HasServerInterceptors indicates if the service has server interceptors.
HasServerInterceptors bool
// HasClientInterceptors indicates if the service has client interceptors.
HasClientInterceptors bool
// ClientInitArgs lists the arguments needed to instantiate the client.
ClientInitArgs string
// Schemes contains the security schemes types used by the
Expand All @@ -44,6 +48,10 @@ type (
ServiceName string
// ServiceVarName is the name of the owner service Go interface.
ServiceVarName string
// ServerInterceptors contains the server-side interceptors for this method
ServerInterceptors []*InterceptorData
// ClientInterceptors contains the client-side interceptors for this method
ClientInterceptors []*InterceptorData
}
)

Expand Down Expand Up @@ -122,6 +130,18 @@ func EndpointFile(genpkg string, service *expr.ServiceExpr) *codegen.File {
Data: m,
FuncMap: map[string]any{"payloadVar": payloadVar},
})
if len(m.ServerInterceptors) > 0 {
sections = append(sections, &codegen.SectionTemplate{
Name: "endpoint-wrapper",
Source: readTemplate("endpoint_wrappers"),
Data: map[string]interface{}{
"MethodVarName": codegen.Goify(m.Name, true),
"Method": m.Name,
"Service": svc.Name,
"ServerInterceptors": m.ServerInterceptors,
},
})
}
}
}

Expand All @@ -133,25 +153,45 @@ func endpointData(service *expr.ServiceExpr) *EndpointsData {
methods := make([]*EndpointMethodData, len(svc.Methods))
names := make([]string, len(svc.Methods))
for i, m := range svc.Methods {
serverInts, clientInts := buildMethodInterceptors(service.Method(m.Name), svc.Scope)
methods[i] = &EndpointMethodData{
MethodData: m,
ArgName: codegen.Goify(m.VarName, false),
ServiceName: svc.Name,
ServiceVarName: serviceInterfaceName,
ClientVarName: clientStructName,
MethodData: m,
ArgName: codegen.Goify(m.VarName, false),
ServiceName: svc.Name,
ServiceVarName: serviceInterfaceName,
ClientVarName: clientStructName,
ServerInterceptors: serverInts,
ClientInterceptors: clientInts,
}
names[i] = codegen.Goify(m.VarName, false)
}
desc := fmt.Sprintf("%s wraps the %q service endpoints.", endpointsStructName, service.Name)
var hasServerInterceptors, hasClientInterceptors bool
for _, m := range methods {
if len(m.ServerInterceptors) > 0 {
hasServerInterceptors = true
if hasClientInterceptors {
break
}
}
if len(m.ClientInterceptors) > 0 {
hasClientInterceptors = true
if hasServerInterceptors {
break
}
}
}
return &EndpointsData{
Name: service.Name,
Description: desc,
VarName: endpointsStructName,
ClientVarName: clientStructName,
ServiceVarName: serviceInterfaceName,
ClientInitArgs: strings.Join(names, ", "),
Methods: methods,
Schemes: svc.Schemes,
Name: service.Name,
Description: desc,
VarName: endpointsStructName,
ClientVarName: clientStructName,
ServiceVarName: serviceInterfaceName,
ClientInitArgs: strings.Join(names, ", "),
Methods: methods,
HasServerInterceptors: hasServerInterceptors,
HasClientInterceptors: hasClientInterceptors,
Schemes: svc.Schemes,
}
}

Expand Down
2 changes: 2 additions & 0 deletions codegen/service/endpoint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ func TestEndpoint(t *testing.T) {
{"endpoint-streaming-payload-no-result", testdata.StreamingPayloadNoResultMethodDSL, testdata.StreamingPayloadNoResultMethodEndpoint},
{"endpoint-bidirectional-streaming", testdata.BidirectionalStreamingEndpointDSL, testdata.BidirectionalStreamingMethodEndpoint},
{"endpoint-bidirectional-streaming-no-payload", testdata.BidirectionalStreamingNoPayloadMethodDSL, testdata.BidirectionalStreamingNoPayloadMethodEndpoint},
{"endpoint-with-server-interceptor", testdata.EndpointWithServerInterceptorDSL, testdata.EndpointWithServerInterceptor},
{"endpoint-with-multiple-interceptors", testdata.EndpointWithMultipleInterceptorsDSL, testdata.EndpointWithMultipleInterceptors},
}
for _, c := range cases {
t.Run(c.Name, func(t *testing.T) {
Expand Down
215 changes: 215 additions & 0 deletions codegen/service/interceptors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
package service

import (
"path/filepath"

"goa.design/goa/v3/codegen"
"goa.design/goa/v3/expr"
)

type (
// ServiceInterceptorData contains all data needed for generating interceptor code
ServiceInterceptorData struct {
Service string
PkgName string
Methods []*MethodInterceptorData
ServerInterceptors []*InterceptorData
ClientInterceptors []*InterceptorData
AllInterceptors []*InterceptorData
HasPrivateImplementationTypes bool
}

// MethodInterceptorData contains interceptor data for a single method
MethodInterceptorData struct {
Service string
Method string
MethodVarName string
PayloadRef string
ResultRef string
ServerInterceptors []*InterceptorData
ClientInterceptors []*InterceptorData
}

// InterceptorData describes a single interceptor.
InterceptorData struct {
Name string
UnexportedName string
Description string
PayloadRef string
ResultRef string
ReadPayload []*AttributeData
WritePayload []*AttributeData
ReadResult []*AttributeData
WriteResult []*AttributeData
ServerStreamInputStruct string
ClientStreamInputStruct string
}

// AttributeData describes a single attribute.
AttributeData struct {
Name string
TypeRef string
FieldPointer bool
}
)

// InterceptorsFile returns the interceptors file for the given service.
func InterceptorsFile(genpkg string, service *expr.ServiceExpr) *codegen.File {
svc := Services.Get(service.Name)
data := interceptorsData(service)
if len(data.ServerInterceptors) == 0 && len(data.ClientInterceptors) == 0 {
return nil
}

path := filepath.Join(codegen.Gendir, svc.PathName, "interceptors.go")
sections := []*codegen.SectionTemplate{
codegen.Header(service.Name+" interceptors", svc.PkgName, []*codegen.ImportSpec{
{Path: "context"},
codegen.GoaImport(""),
}),
{
Name: "interceptors",
Source: readTemplate("interceptors"),
Data: data,
},
}

return &codegen.File{Path: path, SectionTemplates: sections}
}

func interceptorsData(service *expr.ServiceExpr) *ServiceInterceptorData {
svc := Services.Get(service.Name)
scope := svc.Scope

// Build method data first
methods := make([]*MethodInterceptorData, 0, len(service.Methods))
seenInts := make(map[string]*InterceptorData)
var serviceServerInts, serviceClientInts, allInts []*InterceptorData
var hasTypes bool

for _, m := range service.Methods {
methodServerInts, methodClientInts := buildMethodInterceptors(m, scope)
if len(methodServerInts) == 0 && len(methodClientInts) == 0 {
continue
}
hasTypes = hasTypes || hasPrivateImplementationTypes(methodServerInts) || hasPrivateImplementationTypes(methodClientInts)

// Add method data
methods = append(methods, &MethodInterceptorData{
Service: svc.Name,
Method: m.Name,
MethodVarName: codegen.Goify(m.Name, true),
PayloadRef: scope.GoFullTypeRef(m.Payload, ""),
ResultRef: scope.GoFullTypeRef(m.Result, ""),
ServerInterceptors: methodServerInts,
ClientInterceptors: methodClientInts,
})

// Collect unique interceptors
for _, i := range methodServerInts {
if _, ok := seenInts[i.Name]; !ok {
seenInts[i.Name] = i
serviceServerInts = append(serviceServerInts, i)
allInts = append(allInts, i)
}
}
for _, i := range methodClientInts {
if _, ok := seenInts[i.Name]; !ok {
seenInts[i.Name] = i
serviceClientInts = append(serviceClientInts, i)
allInts = append(allInts, i)
}
}
}

return &ServiceInterceptorData{
Service: service.Name,
PkgName: svc.PkgName,
Methods: methods,
ServerInterceptors: serviceServerInts,
ClientInterceptors: serviceClientInts,
AllInterceptors: allInts,
HasPrivateImplementationTypes: hasTypes,
}
}

func buildMethodInterceptors(m *expr.MethodExpr, scope *codegen.NameScope) ([]*InterceptorData, []*InterceptorData) {
svc := Services.Get(m.Service.Name)
methodData := svc.Method(m.Name)
var serverEndpointStruct, clientEndpointStruct string
if methodData.ServerStream != nil {
serverEndpointStruct = methodData.ServerStream.EndpointStruct
}
if methodData.ClientStream != nil {
clientEndpointStruct = methodData.ClientStream.EndpointStruct
}
var hasPrivateImplementationTypes bool
buildInterceptor := func(intr *expr.InterceptorExpr) *InterceptorData {
hasPrivateImplementationTypes = hasPrivateImplementationTypes ||
intr.ReadPayload != nil || intr.WritePayload != nil || intr.ReadResult != nil || intr.WriteResult != nil

return &InterceptorData{
Name: codegen.Goify(intr.Name, true),
UnexportedName: codegen.Goify(intr.Name, false),
Description: intr.Description,
PayloadRef: methodData.PayloadRef,
ResultRef: methodData.ResultRef,
ServerStreamInputStruct: serverEndpointStruct,
ClientStreamInputStruct: clientEndpointStruct,
ReadPayload: collectAttributes(intr.ReadPayload, m.Payload, scope),
WritePayload: collectAttributes(intr.WritePayload, m.Payload, scope),
ReadResult: collectAttributes(intr.ReadResult, m.Result, scope),
WriteResult: collectAttributes(intr.WriteResult, m.Result, scope),
}
}

serverInts := make([]*InterceptorData, len(m.ServerInterceptors))
for i, intr := range m.ServerInterceptors {
serverInts[i] = buildInterceptor(intr)
}

clientInts := make([]*InterceptorData, len(m.ClientInterceptors))
for i, intr := range m.ClientInterceptors {
clientInts[i] = buildInterceptor(intr)
}

return serverInts, clientInts
}

// hasPrivateImplementationTypes returns true if any of the interceptors have
// private implementation types.
func hasPrivateImplementationTypes(interceptors []*InterceptorData) bool {
for _, intr := range interceptors {
if intr.ReadPayload != nil || intr.WritePayload != nil || intr.ReadResult != nil || intr.WriteResult != nil {
return true
}
}
return false
}

// collectAttributes builds AttributeData from an AttributeExpr
func collectAttributes(attrNames, parent *expr.AttributeExpr, scope *codegen.NameScope) []*AttributeData {
if attrNames == nil {
return nil
}

obj := expr.AsObject(attrNames.Type)
if obj == nil {
return nil
}

data := make([]*AttributeData, len(*obj))
for i, nat := range *obj {
parentAttr := parent.Find(nat.Name)
if parentAttr == nil {
continue
}

data[i] = &AttributeData{
Name: codegen.Goify(nat.Name, true),
TypeRef: scope.GoTypeRef(parentAttr),
FieldPointer: parent.IsPrimitivePointer(nat.Name, true),
}
}
return data
}
Loading

0 comments on commit 0955550

Please sign in to comment.