Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initial interceptors implementation #3616

Open
wants to merge 6 commits into
base: v3
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions codegen/service/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,18 @@ func ClientFile(_ string, service *expr.ServiceExpr) *codegen.File {
Source: readTemplate("service_client_method"),
Data: m,
})
if len(m.ClientInterceptors) > 0 {
sections = append(sections, &codegen.SectionTemplate{
Name: "client-wrapper",
Source: readTemplate("client_wrappers"),
Data: map[string]interface{}{
"Method": m.Name,
"MethodVarName": codegen.Goify(m.Name, true),
"Service": svc.Name,
"ClientInterceptors": m.ClientInterceptors,
},
})
}
}
}

Expand Down
1 change: 1 addition & 0 deletions codegen/service/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ func TestClient(t *testing.T) {
{"client-streaming-payload-no-result", testdata.StreamingPayloadNoResultMethodDSL, testdata.StreamingPayloadNoResultMethodClient},
{"client-bidirectional-streaming", testdata.BidirectionalStreamingMethodDSL, testdata.BidirectionalStreamingMethodClient},
{"client-bidirectional-streaming-no-payload", testdata.BidirectionalStreamingNoPayloadMethodDSL, testdata.BidirectionalStreamingNoPayloadMethodClient},
{"client-interceptor", testdata.EndpointWithClientInterceptorDSL, testdata.InterceptorClient},
}
for _, c := range cases {
t.Run(c.Name, func(t *testing.T) {
Expand Down
24 changes: 0 additions & 24 deletions codegen/service/convert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import (
"goa.design/goa/v3/codegen"
"goa.design/goa/v3/codegen/service/testdata"
"goa.design/goa/v3/dsl"
"goa.design/goa/v3/eval"
"goa.design/goa/v3/expr"
)

Expand Down Expand Up @@ -257,30 +256,7 @@ func TestConvertFile(t *testing.T) {
}
}

// runDSL returns the DSL root resulting from running the given DSL.
func runDSL(t *testing.T, dsl func()) *expr.RootExpr {
// reset all roots and codegen data structures
Services = make(ServicesData)
eval.Reset()
expr.Root = new(expr.RootExpr)
expr.GeneratedResultTypes = new(expr.ResultTypesRoot)
require.NoError(t, eval.Register(expr.Root))
require.NoError(t, eval.Register(expr.GeneratedResultTypes))
expr.Root.API = expr.NewAPIExpr("test api", func() {})
expr.Root.API.Servers = []*expr.ServerExpr{expr.Root.API.DefaultServer()}

// run DSL (first pass)
require.True(t, eval.Execute(dsl, nil))

// run DSL (second pass)
require.NoError(t, eval.RunDSL())

// return generated root
return expr.Root
}

// Test fixtures

var obj = &expr.UserTypeExpr{
AttributeExpr: &expr.AttributeExpr{
Type: &expr.Object{
Expand Down
66 changes: 53 additions & 13 deletions codegen/service/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ type (
ServiceVarName string
// Methods lists the endpoint struct methods.
Methods []*EndpointMethodData
// HasServerInterceptors indicates if the service has server interceptors.
HasServerInterceptors bool
// HasClientInterceptors indicates if the service has client interceptors.
HasClientInterceptors bool
// ClientInitArgs lists the arguments needed to instantiate the client.
ClientInitArgs string
// Schemes contains the security schemes types used by the
Expand All @@ -44,6 +48,10 @@ type (
ServiceName string
// ServiceVarName is the name of the owner service Go interface.
ServiceVarName string
// ServerInterceptors contains the server-side interceptors for this method
ServerInterceptors []*InterceptorData
// ClientInterceptors contains the client-side interceptors for this method
ClientInterceptors []*InterceptorData
}
)

Expand Down Expand Up @@ -122,6 +130,18 @@ func EndpointFile(genpkg string, service *expr.ServiceExpr) *codegen.File {
Data: m,
FuncMap: map[string]any{"payloadVar": payloadVar},
})
if len(m.ServerInterceptors) > 0 {
sections = append(sections, &codegen.SectionTemplate{
Name: "endpoint-wrapper",
Source: readTemplate("endpoint_wrappers"),
Data: map[string]interface{}{
"MethodVarName": codegen.Goify(m.Name, true),
"Method": m.Name,
"Service": svc.Name,
"ServerInterceptors": m.ServerInterceptors,
},
})
}
}
}

Expand All @@ -133,25 +153,45 @@ func endpointData(service *expr.ServiceExpr) *EndpointsData {
methods := make([]*EndpointMethodData, len(svc.Methods))
names := make([]string, len(svc.Methods))
for i, m := range svc.Methods {
serverInts, clientInts := buildMethodInterceptors(service.Method(m.Name), svc.Scope)
methods[i] = &EndpointMethodData{
MethodData: m,
ArgName: codegen.Goify(m.VarName, false),
ServiceName: svc.Name,
ServiceVarName: serviceInterfaceName,
ClientVarName: clientStructName,
MethodData: m,
ArgName: codegen.Goify(m.VarName, false),
ServiceName: svc.Name,
ServiceVarName: serviceInterfaceName,
ClientVarName: clientStructName,
ServerInterceptors: serverInts,
ClientInterceptors: clientInts,
}
names[i] = codegen.Goify(m.VarName, false)
}
desc := fmt.Sprintf("%s wraps the %q service endpoints.", endpointsStructName, service.Name)
var hasServerInterceptors, hasClientInterceptors bool
for _, m := range methods {
if len(m.ServerInterceptors) > 0 {
hasServerInterceptors = true
if hasClientInterceptors {
break
}
}
if len(m.ClientInterceptors) > 0 {
hasClientInterceptors = true
if hasServerInterceptors {
break
}
}
}
return &EndpointsData{
Name: service.Name,
Description: desc,
VarName: endpointsStructName,
ClientVarName: clientStructName,
ServiceVarName: serviceInterfaceName,
ClientInitArgs: strings.Join(names, ", "),
Methods: methods,
Schemes: svc.Schemes,
Name: service.Name,
Description: desc,
VarName: endpointsStructName,
ClientVarName: clientStructName,
ServiceVarName: serviceInterfaceName,
ClientInitArgs: strings.Join(names, ", "),
Methods: methods,
HasServerInterceptors: hasServerInterceptors,
HasClientInterceptors: hasClientInterceptors,
Schemes: svc.Schemes,
}
}

Expand Down
2 changes: 2 additions & 0 deletions codegen/service/endpoint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ func TestEndpoint(t *testing.T) {
{"endpoint-streaming-payload-no-result", testdata.StreamingPayloadNoResultMethodDSL, testdata.StreamingPayloadNoResultMethodEndpoint},
{"endpoint-bidirectional-streaming", testdata.BidirectionalStreamingEndpointDSL, testdata.BidirectionalStreamingMethodEndpoint},
{"endpoint-bidirectional-streaming-no-payload", testdata.BidirectionalStreamingNoPayloadMethodDSL, testdata.BidirectionalStreamingNoPayloadMethodEndpoint},
{"endpoint-with-server-interceptor", testdata.EndpointWithServerInterceptorDSL, testdata.EndpointWithServerInterceptor},
{"endpoint-with-multiple-interceptors", testdata.EndpointWithMultipleInterceptorsDSL, testdata.EndpointWithMultipleInterceptors},
}
for _, c := range cases {
t.Run(c.Name, func(t *testing.T) {
Expand Down
217 changes: 217 additions & 0 deletions codegen/service/interceptors.go
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we also generate example interceptor implementation like example service implementation?

Original file line number Diff line number Diff line change
@@ -0,0 +1,217 @@
package service

import (
"path/filepath"

"goa.design/goa/v3/codegen"
"goa.design/goa/v3/expr"
)

type (
// ServiceInterceptorData contains all data needed for generating interceptor code
ServiceInterceptorData struct {
Service string
PkgName string
Methods []*MethodInterceptorData
ServerInterceptors []*InterceptorData
ClientInterceptors []*InterceptorData
AllInterceptors []*InterceptorData
HasPrivateImplementationTypes bool
}

// MethodInterceptorData contains interceptor data for a single method
MethodInterceptorData struct {
Service string
Method string
MethodVarName string
PayloadRef string
ResultRef string
ServerInterceptors []*InterceptorData
ClientInterceptors []*InterceptorData
}

// InterceptorData describes a single interceptor.
InterceptorData struct {
Name string
DesignName string
UnexportedName string
Description string
PayloadRef string
ResultRef string
ReadPayload []*AttributeData
WritePayload []*AttributeData
ReadResult []*AttributeData
WriteResult []*AttributeData
ServerStreamInputStruct string
ClientStreamInputStruct string
}

// AttributeData describes a single attribute.
AttributeData struct {
Name string
TypeRef string
FieldPointer bool
}
)

// InterceptorsFile returns the interceptors file for the given service.
func InterceptorsFile(genpkg string, service *expr.ServiceExpr) *codegen.File {
svc := Services.Get(service.Name)
data := interceptorsData(service)
if len(data.ServerInterceptors) == 0 && len(data.ClientInterceptors) == 0 {
return nil
}

path := filepath.Join(codegen.Gendir, svc.PathName, "interceptors.go")
sections := []*codegen.SectionTemplate{
codegen.Header(service.Name+" interceptors", svc.PkgName, []*codegen.ImportSpec{
{Path: "context"},
codegen.GoaImport(""),
}),
{
Name: "interceptors",
Source: readTemplate("interceptors"),
Data: data,
},
}

return &codegen.File{Path: path, SectionTemplates: sections}
}

func interceptorsData(service *expr.ServiceExpr) *ServiceInterceptorData {
svc := Services.Get(service.Name)
scope := svc.Scope

// Build method data first
methods := make([]*MethodInterceptorData, 0, len(service.Methods))
seenInts := make(map[string]*InterceptorData)
var serviceServerInts, serviceClientInts, allInts []*InterceptorData
var hasTypes bool

for _, m := range service.Methods {
methodServerInts, methodClientInts := buildMethodInterceptors(m, scope)
if len(methodServerInts) == 0 && len(methodClientInts) == 0 {
continue
}
hasTypes = hasTypes || hasPrivateImplementationTypes(methodServerInts) || hasPrivateImplementationTypes(methodClientInts)

// Add method data
methods = append(methods, &MethodInterceptorData{
Service: svc.Name,
Method: m.Name,
MethodVarName: codegen.Goify(m.Name, true),
PayloadRef: scope.GoFullTypeRef(m.Payload, ""),
ResultRef: scope.GoFullTypeRef(m.Result, ""),
ServerInterceptors: methodServerInts,
ClientInterceptors: methodClientInts,
})

// Collect unique interceptors
for _, i := range methodServerInts {
if _, ok := seenInts[i.Name]; !ok {
seenInts[i.Name] = i
serviceServerInts = append(serviceServerInts, i)
allInts = append(allInts, i)
}
}
for _, i := range methodClientInts {
if _, ok := seenInts[i.Name]; !ok {
seenInts[i.Name] = i
serviceClientInts = append(serviceClientInts, i)
allInts = append(allInts, i)
}
}
}

return &ServiceInterceptorData{
Service: service.Name,
PkgName: svc.PkgName,
Methods: methods,
ServerInterceptors: serviceServerInts,
ClientInterceptors: serviceClientInts,
AllInterceptors: allInts,
HasPrivateImplementationTypes: hasTypes,
}
}

func buildMethodInterceptors(m *expr.MethodExpr, scope *codegen.NameScope) ([]*InterceptorData, []*InterceptorData) {
svc := Services.Get(m.Service.Name)
methodData := svc.Method(m.Name)
var serverEndpointStruct, clientEndpointStruct string
if methodData.ServerStream != nil {
serverEndpointStruct = methodData.ServerStream.EndpointStruct
}
if methodData.ClientStream != nil {
clientEndpointStruct = methodData.ClientStream.EndpointStruct
}
var hasPrivateImplementationTypes bool
buildInterceptor := func(intr *expr.InterceptorExpr) *InterceptorData {
hasPrivateImplementationTypes = hasPrivateImplementationTypes ||
intr.ReadPayload != nil || intr.WritePayload != nil || intr.ReadResult != nil || intr.WriteResult != nil

return &InterceptorData{
Name: codegen.Goify(intr.Name, true),
DesignName: intr.Name,
UnexportedName: codegen.Goify(intr.Name, false),
Description: intr.Description,
PayloadRef: methodData.PayloadRef,
ResultRef: methodData.ResultRef,
ServerStreamInputStruct: serverEndpointStruct,
ClientStreamInputStruct: clientEndpointStruct,
ReadPayload: collectAttributes(intr.ReadPayload, m.Payload, scope),
WritePayload: collectAttributes(intr.WritePayload, m.Payload, scope),
ReadResult: collectAttributes(intr.ReadResult, m.Result, scope),
WriteResult: collectAttributes(intr.WriteResult, m.Result, scope),
}
}

serverInts := make([]*InterceptorData, len(m.ServerInterceptors))
for i, intr := range m.ServerInterceptors {
serverInts[i] = buildInterceptor(intr)
}

clientInts := make([]*InterceptorData, len(m.ClientInterceptors))
for i, intr := range m.ClientInterceptors {
clientInts[i] = buildInterceptor(intr)
}

return serverInts, clientInts
}

// hasPrivateImplementationTypes returns true if any of the interceptors have
// private implementation types.
func hasPrivateImplementationTypes(interceptors []*InterceptorData) bool {
for _, intr := range interceptors {
if intr.ReadPayload != nil || intr.WritePayload != nil || intr.ReadResult != nil || intr.WriteResult != nil {
return true
}
}
return false
}

// collectAttributes builds AttributeData from an AttributeExpr
func collectAttributes(attrNames, parent *expr.AttributeExpr, scope *codegen.NameScope) []*AttributeData {
if attrNames == nil {
return nil
}

obj := expr.AsObject(attrNames.Type)
if obj == nil {
return nil
}

data := make([]*AttributeData, len(*obj))
for i, nat := range *obj {
parentAttr := parent.Find(nat.Name)
if parentAttr == nil {
continue
}

data[i] = &AttributeData{
Name: codegen.Goify(nat.Name, true),
TypeRef: scope.GoTypeRef(parentAttr),
FieldPointer: parent.IsPrimitivePointer(nat.Name, true),
}
}
return data
}
Loading