Milan Pavlik 6774bf3338
[public-api] Auto-generate proxy handler (#16626)
* fix

* fix

* Fix

* fix

* fix

* fix

* fix

* fix

* Fix

* Fix

* Fix
2023-03-06 09:27:03 +01:00

109 lines
3.3 KiB
Go

// Copyright (c) 2023 Gitpod GmbH. All rights reserved.
// Licensed under the GNU Affero General Public License (AGPL).
// See License.AGPL.txt in the project root for license information.
package main
import (
"fmt"
"path"
"google.golang.org/protobuf/compiler/protogen"
"google.golang.org/protobuf/types/pluginpb"
)
// Import paths of the packages whose identifiers are referenced by the generated proxy code.
// They are turned into qualified identifiers (and the matching imports) when generating files.
const (
// contextPackage is the import path used to reference context.Context in generated method signatures.
contextPackage = protogen.GoImportPath("context")
// connectPackage is the import path used to reference the connect Request, Response and NewResponse identifiers.
connectPackage = protogen.GoImportPath("github.com/bufbuild/connect-go")
)
// main is the plugin entry point. It hands control to protogen, declares support
// for proto3 optional fields and emits proxy handlers for every input file that
// is marked for generation.
func main() {
	generateAll := func(plugin *protogen.Plugin) error {
		plugin.SupportedFeatures = uint64(pluginpb.CodeGeneratorResponse_FEATURE_PROTO3_OPTIONAL)

		for _, protoFile := range plugin.Files {
			// Files that are only present as dependencies are not generated.
			if protoFile.Generate {
				generateFile(plugin, protoFile)
			}
		}

		return nil
	}

	protogen.Options{}.Run(generateAll)
}
// generateFile emits the proxy handler implementation for the services declared in file.
//
// Files without services produce no output. For every service a
// Proxy<Service>Handler struct is generated. It holds a <Service>Client from the
// file's own Go package, embeds Unimplemented<Service>Handler, and implements each
// unary method by forwarding the request message to the client. Streaming methods
// are skipped, so they fall through to the embedded Unimplemented handler.
//
// The output is written next to the file's generated code, in a sub-package named
// <package>connect, as <base>.proxy.connect.go.
func generateFile(gen *protogen.Plugin, file *protogen.File) {
	// We only generate our proxy implementation for services, not for raw structs
	if len(file.Services) == 0 {
		return
	}

	targetPackageName := fmt.Sprintf("%sconnect", file.GoPackageName)
	filename := path.Join(
		path.Dir(file.GeneratedFilenamePrefix),
		targetPackageName,
		fmt.Sprintf("%s.proxy.connect.go", path.Base(file.GeneratedFilenamePrefix)),
	)
	// The generated file lives in the <package>connect sub-package, so its import path
	// has to end in targetPackageName as well. The import path is what decides which
	// identifiers are treated as local (and therefore left unqualified); deriving it
	// from file.GoPackageName would describe a different package than the one the
	// file is actually written to.
	importPath := protogen.GoImportPath(path.Join(string(file.GoImportPath), targetPackageName))

	// Setup a new generated file
	g := gen.NewGeneratedFile(filename, importPath)

	// generate preamble
	g.P("// Code generated by protoc-proxy-gen. DO NOT EDIT.")
	g.P()
	g.P("package ", targetPackageName)
	g.P()
	g.Import(file.GoImportPath)
	g.P()

	// generate individual services
	for _, service := range file.Services {
		handlerStructName := fmt.Sprintf("Proxy%sHandler", service.GoName)

		// Generate a type assertion to ensure the handler implements the connect handler interface
		g.P("var _ ", service.GoName, "Handler = (*", handlerStructName, ")(nil)")

		// generate struct definition; GoIdent arguments are qualified (and imported) by g.P itself
		g.Annotate(handlerStructName, service.Location)
		g.P("type ", handlerStructName, " struct {")
		g.P(" Client ", file.GoImportPath.Ident(service.GoName+"Client"))
		g.P(" Unimplemented", service.GoName, "Handler")
		g.P("}")
		g.P()

		for _, method := range service.Methods {
			// We do not generate any non-unary methods, for now.
			// Should we need these, we can choose to do so and handle them explicitly.
			// The handler still continues to work fine, as it inherits from the default Unimplemented handling, and will
			// always return Unimplemented.
			if method.Desc.IsStreamingClient() || method.Desc.IsStreamingServer() {
				continue
			}

			// method signature
			g.P("func (s *", handlerStructName, ") ", method.GoName,
				"(ctx ", contextPackage.Ident("Context"),
				", req *", connectPackage.Ident("Request"), "[", method.Input.GoIdent, "]",
				") (*", connectPackage.Ident("Response"), "[", method.Output.GoIdent, "]", ", error) {")

			// method implementation
			g.P(" resp, err := s.Client.", method.GoName, "(ctx, req.Msg)")
			g.P(" if err != nil {")
			g.P(" // TODO(milan): Convert to correct status code")
			g.P(" return nil, err")
			g.P(" }")
			g.P()
			g.P(" return ", connectPackage.Ident("NewResponse"), "(resp), nil")

			// method end
			g.P("}")
			g.P()
		}
	}
}