Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions cmd/protoc-gen-go-http/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,13 +161,15 @@ func buildHTTPRule(g *protogen.GeneratedFile, service *protogen.Service, m *prot
} else if body != "" {
md.HasBody = true
md.Body = "." + camelCaseVars(body)
md.BodyProtoName = body
} else {
md.HasBody = false
}
if responseBody == "*" {
md.ResponseBody = ""
} else if responseBody != "" {
md.ResponseBody = "." + camelCaseVars(responseBody)
md.ResponseBodyProtoName = responseBody
}
return md
}
Expand Down Expand Up @@ -259,6 +261,23 @@ func replacePath(name string, value string, path string) string {
return path
}

// toGetterAccessor converts a CamelCase field accessor like ".Payload" or ".User.Profile"
// into a chained getter call like ".GetPayload()" or ".GetUser().GetProfile()".
// This is compatible with both open and opaque protobuf Go API styles.
func toGetterAccessor(camelCaseBody string) string {
if camelCaseBody == "" {
return ""
}
parts := strings.Split(camelCaseBody[1:], ".") // strip leading dot, split segments
var result strings.Builder
for _, part := range parts {
result.WriteString(".Get")
result.WriteString(part)
result.WriteString("()")
}
return result.String()
}

func camelCaseVars(s string) string {
subs := strings.Split(s, ".")
vars := make([]string, 0, len(subs))
Expand Down
26 changes: 20 additions & 6 deletions cmd/protoc-gen-go-http/httpTemplate.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,15 @@ func _{{$svrType}}_{{.Name}}{{.Num}}_HTTP_Handler(srv {{$svrType}}HTTPServer) fu
return func(ctx http.Context) error {
var in {{.Request}}
{{- if .HasBody}}
if err := ctx.Bind(&in{{.Body}}); err != nil {
{{- if .BodyProtoName}}
if err := ctx.Bind(in.ProtoReflect().Mutable(in.ProtoReflect().Descriptor().Fields().ByName("{{.BodyProtoName}}")).Message().Interface()); err != nil {
return err
}
{{- else}}
if err := ctx.Bind(&in); err != nil {
return err
}
{{- end}}
{{- end}}
if err := ctx.BindQuery(&in); err != nil {
return err
Expand All @@ -47,7 +53,11 @@ func _{{$svrType}}_{{.Name}}{{.Num}}_HTTP_Handler(srv {{$svrType}}HTTPServer) fu
return err
}
reply := out.(*{{.Reply}})
{{- if .ResponseBodyProtoName}}
return ctx.Result(200, reply{{getterAccessor .ResponseBody}})
{{- else}}
return ctx.Result(200, reply{{.ResponseBody}})
{{- end}}
}
}
{{end}}
Expand Down Expand Up @@ -79,11 +89,15 @@ func (c *{{$svrType}}HTTPClientImpl) {{.Name}}(ctx context.Context, in *{{.Reque
path := binding.EncodeURL(pattern, in, {{not .HasBody}})
opts = append(opts, http.Operation(Operation{{$svrType}}{{.OriginalName}}))
opts = append(opts, http.PathTemplate(pattern))
{{if .HasBody -}}
err := c.cc.Invoke(ctx, "{{.Method}}", path, in{{.Body}}, &out{{.ResponseBody}}, opts...)
{{else -}}
err := c.cc.Invoke(ctx, "{{.Method}}", path, nil, &out{{.ResponseBody}}, opts...)
{{end -}}
{{- if .HasBody}}
{{- if .BodyProtoName}}
err := c.cc.Invoke(ctx, "{{.Method}}", path, in{{getterAccessor .Body}}, {{if .ResponseBodyProtoName}}out.ProtoReflect().Mutable(out.ProtoReflect().Descriptor().Fields().ByName("{{.ResponseBodyProtoName}}")).Message().Interface(){{else}}&out{{.ResponseBody}}{{end}}, opts...)
{{- else}}
err := c.cc.Invoke(ctx, "{{.Method}}", path, in{{.Body}}, {{if .ResponseBodyProtoName}}out.ProtoReflect().Mutable(out.ProtoReflect().Descriptor().Fields().ByName("{{.ResponseBodyProtoName}}")).Message().Interface(){{else}}&out{{.ResponseBody}}{{end}}, opts...)
{{- end}}
{{- else}}
err := c.cc.Invoke(ctx, "{{.Method}}", path, nil, {{if .ResponseBodyProtoName}}out.ProtoReflect().Mutable(out.ProtoReflect().Descriptor().Fields().ByName("{{.ResponseBodyProtoName}}")).Message().Interface(){{else}}&out{{.ResponseBody}}{{end}}, opts...)
{{- end}}
if err != nil {
return nil, err
}
Expand Down
88 changes: 88 additions & 0 deletions cmd/protoc-gen-go-http/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package main

import (
"reflect"
"strings"
"testing"
)

Expand Down Expand Up @@ -98,3 +99,90 @@ func TestReplaceBoundary(t *testing.T) {
t.Fatal(`"/test/{message.namespace=*}/name/{message.name=*}" should be "/test/{message.namespace:.*}/name/{message.name:.*}"`)
}
}

func TestGetterAccessor(t *testing.T) {
tests := []struct {
input string
expected string
}{
{"", ""},
{".Payload", ".GetPayload()"},
{".UserProfile", ".GetUserProfile()"},
{".User.Profile", ".GetUser().GetProfile()"},
}
for _, tt := range tests {
got := toGetterAccessor(tt.input)
if got != tt.expected {
t.Fatalf("toGetterAccessor(%q) = %q, want %q", tt.input, got, tt.expected)
}
}
}

func TestTemplateBodyNamedField(t *testing.T) {
sd := &serviceDesc{
ServiceType: "TestSvc",
ServiceName: "test.v1.TestSvc",
Metadata: "test/v1/test.proto",
Methods: []*methodDesc{
{
Name: "BotHandler",
OriginalName: "BotHandler",
Num: 0,
Request: "BotHandlerRequest",
Reply: "emptypb.Empty",
Path: "/api/v1/bot",
Method: "POST",
HasBody: true,
Body: ".Payload",
BodyProtoName: "payload",
},
},
}
output := sd.execute()

// Server handler should use proto reflection for named body field
if !strings.Contains(output, `in.ProtoReflect().Mutable(in.ProtoReflect().Descriptor().Fields().ByName("payload")).Message().Interface()`) {
t.Fatal("server handler should use ProtoReflect().Mutable() for named body field")
}
if strings.Contains(output, "&in.Payload") {
t.Fatal("server handler should NOT use direct field access &in.Payload")
}

// Client should use getter for named body field
if !strings.Contains(output, "in.GetPayload()") {
t.Fatal("client should use getter in.GetPayload() for named body field")
}
if strings.Contains(output, "in.Payload") && !strings.Contains(output, "in.GetPayload()") {
t.Fatal("client should NOT use direct field access in.Payload")
}
}

func TestTemplateBodyStar(t *testing.T) {
sd := &serviceDesc{
ServiceType: "TestSvc",
ServiceName: "test.v1.TestSvc",
Metadata: "test/v1/test.proto",
Methods: []*methodDesc{
{
Name: "Create",
OriginalName: "Create",
Num: 0,
Request: "CreateRequest",
Reply: "CreateReply",
Path: "/api/v1/create",
Method: "POST",
HasBody: true,
Body: "",
},
},
}
output := sd.execute()

// body="*" should use &in for server bind and in for client
if !strings.Contains(output, "ctx.Bind(&in)") {
t.Fatal("body=* should use ctx.Bind(&in)")
}
if strings.Contains(output, "ProtoReflect") {
t.Fatal("body=* should NOT use ProtoReflect")
}
}
18 changes: 11 additions & 7 deletions cmd/protoc-gen-go-http/template.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,14 @@ type methodDesc struct {
Reply string
Comment string
// http_rule
Path string
Method string
HasVars bool
HasBody bool
Body string
ResponseBody string
Path string
Method string
HasVars bool
HasBody bool
Body string
BodyProtoName string // proto field name for named body (empty when body is "*")
ResponseBody string
ResponseBodyProtoName string // proto field name for named response_body (empty when "*")
}

func (s *serviceDesc) execute() string {
Expand All @@ -41,7 +43,9 @@ func (s *serviceDesc) execute() string {
s.MethodSets[m.Name] = m
}
buf := new(bytes.Buffer)
tmpl, err := template.New("http").Parse(strings.TrimSpace(httpTemplate))
tmpl, err := template.New("http").Funcs(template.FuncMap{
"getterAccessor": toGetterAccessor,
}).Parse(strings.TrimSpace(httpTemplate))
if err != nil {
panic(err)
}
Expand Down
2 changes: 1 addition & 1 deletion cmd/protoc-gen-go-http/version.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package main

// release is the current protoc-gen-go-http version.
const release = "v2.9.2"
const release = "v2.9.3"