// Copyright (c) 2023 Gitpod GmbH. All rights reserved. // Licensed under the GNU Affero General Public License (AGPL). // See License.AGPL.txt in the project root for license information. package main import ( "fmt" "path" "google.golang.org/protobuf/compiler/protogen" "google.golang.org/protobuf/types/pluginpb" ) const ( contextPackage = protogen.GoImportPath("context") connectPackage = protogen.GoImportPath("github.com/bufbuild/connect-go") ) func main() { protogen.Options{}.Run(func(gen *protogen.Plugin) error { gen.SupportedFeatures = uint64(pluginpb.CodeGeneratorResponse_FEATURE_PROTO3_OPTIONAL) for _, f := range gen.Files { if !f.Generate { continue } generateFile(gen, f) } return nil }) } func generateFile(gen *protogen.Plugin, file *protogen.File) { // We only generate our proxy implementation for services, not for raw structs if len(file.Services) == 0 { return } var ( targetPackageName = fmt.Sprintf("%sconnect", file.GoPackageName) filename = path.Join( path.Dir(file.GeneratedFilenamePrefix), targetPackageName, fmt.Sprintf("%s.proxy.connect.go", path.Base(file.GeneratedFilenamePrefix))) importPath = protogen.GoImportPath(path.Join(string(file.GoImportPath), string(file.GoPackageName))) ) // Setup a new generated file g := gen.NewGeneratedFile(filename, importPath) // generate preamble g.P("// Code generated by protoc-proxy-gen. DO NOT EDIT.") g.P() g.P("package ", targetPackageName) g.P() g.Import(file.GoImportPath) g.P() // generate individual services for _, service := range file.Services { // generate struct definition handlerStructName := fmt.Sprintf("Proxy%sHandler", service.GoName) // Generate a type assertion to ensure the handler implements the connect handler interface g.P(fmt.Sprintf("var _ %sHandler = (*%s)(nil)", service.GoName, handlerStructName)) g.Annotate(handlerStructName, service.Location) g.P(fmt.Sprintf("type %s struct {", handlerStructName)) g.P(fmt.Sprintf(" Client %s", g.QualifiedGoIdent(file.GoImportPath.Ident(service.GoName+"Client")))) g.P(fmt.Sprintf(" Unimplemented%sHandler", service.GoName)) g.P("}") g.P() for _, method := range service.Methods { // We do not generate any non-unary methods, for now. // Should we need these, we can choose to do so and handle them explicitly. // The handler still continues to work fine, as it inherits from the default Unimplemented handling, and will // always return Unimplemented. if method.Desc.IsStreamingClient() || method.Desc.IsStreamingServer() { continue } // method signature g.P(fmt.Sprintf("func (s *%s) %s(ctx %s, req *%s) (*%s, error) {", handlerStructName, method.GoName, g.QualifiedGoIdent(contextPackage.Ident("Context")), g.QualifiedGoIdent(connectPackage.Ident("Request"))+"["+g.QualifiedGoIdent(method.Input.GoIdent)+"]", g.QualifiedGoIdent(connectPackage.Ident("Response"))+"["+g.QualifiedGoIdent(method.Output.GoIdent)+"]", )) // method implementation g.P(fmt.Sprintf(" resp, err := s.Client.%s(ctx, req.Msg)", method.GoName)) g.P(" if err != nil {") g.P(" // TODO(milan): Convert to correct status code") g.P(" return nil, err") g.P(" }") g.P() g.P(fmt.Sprintf(" return %s(resp), nil", g.QualifiedGoIdent(connectPackage.Ident("NewResponse")))) // method end g.P("}") g.P() } } }