diff --git a/cmd/protoc-gen-go-http/http.go b/cmd/protoc-gen-go-http/http.go index 8fae077f03f..696e1a6d17b 100644 --- a/cmd/protoc-gen-go-http/http.go +++ b/cmd/protoc-gen-go-http/http.go @@ -161,6 +161,7 @@ func buildHTTPRule(g *protogen.GeneratedFile, service *protogen.Service, m *prot } else if body != "" { md.HasBody = true md.Body = "." + camelCaseVars(body) + md.BodyProtoName = body } else { md.HasBody = false } @@ -168,6 +169,7 @@ func buildHTTPRule(g *protogen.GeneratedFile, service *protogen.Service, m *prot md.ResponseBody = "" } else if responseBody != "" { md.ResponseBody = "." + camelCaseVars(responseBody) + md.ResponseBodyProtoName = responseBody } return md } @@ -259,6 +261,23 @@ func replacePath(name string, value string, path string) string { return path } +// toGetterAccessor converts a CamelCase field accessor like ".Payload" or ".User.Profile" +// into a chained getter call like ".GetPayload()" or ".GetUser().GetProfile()". +// This is compatible with both open and opaque protobuf Go API styles. +func toGetterAccessor(camelCaseBody string) string { + if camelCaseBody == "" { + return "" + } + parts := strings.Split(camelCaseBody[1:], ".") // strip leading dot, split segments + var result strings.Builder + for _, part := range parts { + result.WriteString(".Get") + result.WriteString(part) + result.WriteString("()") + } + return result.String() +} + func camelCaseVars(s string) string { subs := strings.Split(s, ".") vars := make([]string, 0, len(subs)) diff --git a/cmd/protoc-gen-go-http/httpTemplate.tpl b/cmd/protoc-gen-go-http/httpTemplate.tpl index 3dc9e0b2add..7c2435a4751 100644 --- a/cmd/protoc-gen-go-http/httpTemplate.tpl +++ b/cmd/protoc-gen-go-http/httpTemplate.tpl @@ -26,9 +26,15 @@ func _{{$svrType}}_{{.Name}}{{.Num}}_HTTP_Handler(srv {{$svrType}}HTTPServer) fu return func(ctx http.Context) error { var in {{.Request}} {{- if .HasBody}} - if err := ctx.Bind(&in{{.Body}}); err != nil { + {{- if .BodyProtoName}} + if err := ctx.Bind(in.ProtoReflect().Mutable(in.ProtoReflect().Descriptor().Fields().ByName("{{.BodyProtoName}}")).Message().Interface()); err != nil { return err } + {{- else}} + if err := ctx.Bind(&in); err != nil { + return err + } + {{- end}} {{- end}} if err := ctx.BindQuery(&in); err != nil { return err @@ -47,7 +53,11 @@ func _{{$svrType}}_{{.Name}}{{.Num}}_HTTP_Handler(srv {{$svrType}}HTTPServer) fu return err } reply := out.(*{{.Reply}}) + {{- if .ResponseBodyProtoName}} + return ctx.Result(200, reply{{getterAccessor .ResponseBody}}) + {{- else}} return ctx.Result(200, reply{{.ResponseBody}}) + {{- end}} } } {{end}} @@ -79,11 +89,15 @@ func (c *{{$svrType}}HTTPClientImpl) {{.Name}}(ctx context.Context, in *{{.Reque path := binding.EncodeURL(pattern, in, {{not .HasBody}}) opts = append(opts, http.Operation(Operation{{$svrType}}{{.OriginalName}})) opts = append(opts, http.PathTemplate(pattern)) - {{if .HasBody -}} - err := c.cc.Invoke(ctx, "{{.Method}}", path, in{{.Body}}, &out{{.ResponseBody}}, opts...) - {{else -}} - err := c.cc.Invoke(ctx, "{{.Method}}", path, nil, &out{{.ResponseBody}}, opts...) - {{end -}} + {{- if .HasBody}} + {{- if .BodyProtoName}} + err := c.cc.Invoke(ctx, "{{.Method}}", path, in{{getterAccessor .Body}}, {{if .ResponseBodyProtoName}}out.ProtoReflect().Mutable(out.ProtoReflect().Descriptor().Fields().ByName("{{.ResponseBodyProtoName}}")).Message().Interface(){{else}}&out{{.ResponseBody}}{{end}}, opts...) + {{- else}} + err := c.cc.Invoke(ctx, "{{.Method}}", path, in{{.Body}}, {{if .ResponseBodyProtoName}}out.ProtoReflect().Mutable(out.ProtoReflect().Descriptor().Fields().ByName("{{.ResponseBodyProtoName}}")).Message().Interface(){{else}}&out{{.ResponseBody}}{{end}}, opts...) + {{- end}} + {{- else}} + err := c.cc.Invoke(ctx, "{{.Method}}", path, nil, {{if .ResponseBodyProtoName}}out.ProtoReflect().Mutable(out.ProtoReflect().Descriptor().Fields().ByName("{{.ResponseBodyProtoName}}")).Message().Interface(){{else}}&out{{.ResponseBody}}{{end}}, opts...) + {{- end}} if err != nil { return nil, err } diff --git a/cmd/protoc-gen-go-http/http_test.go b/cmd/protoc-gen-go-http/http_test.go index 40680b3f15e..59031371b76 100644 --- a/cmd/protoc-gen-go-http/http_test.go +++ b/cmd/protoc-gen-go-http/http_test.go @@ -2,6 +2,7 @@ package main import ( "reflect" + "strings" "testing" ) @@ -98,3 +99,90 @@ func TestReplaceBoundary(t *testing.T) { t.Fatal(`"/test/{message.namespace=*}/name/{message.name=*}" should be "/test/{message.namespace:.*}/name/{message.name:.*}"`) } } + +func TestGetterAccessor(t *testing.T) { + tests := []struct { + input string + expected string + }{ + {"", ""}, + {".Payload", ".GetPayload()"}, + {".UserProfile", ".GetUserProfile()"}, + {".User.Profile", ".GetUser().GetProfile()"}, + } + for _, tt := range tests { + got := toGetterAccessor(tt.input) + if got != tt.expected { + t.Fatalf("toGetterAccessor(%q) = %q, want %q", tt.input, got, tt.expected) + } + } +} + +func TestTemplateBodyNamedField(t *testing.T) { + sd := &serviceDesc{ + ServiceType: "TestSvc", + ServiceName: "test.v1.TestSvc", + Metadata: "test/v1/test.proto", + Methods: []*methodDesc{ + { + Name: "BotHandler", + OriginalName: "BotHandler", + Num: 0, + Request: "BotHandlerRequest", + Reply: "emptypb.Empty", + Path: "/api/v1/bot", + Method: "POST", + HasBody: true, + Body: ".Payload", + BodyProtoName: "payload", + }, + }, + } + output := sd.execute() + + // Server handler should use proto reflection for named body field + if !strings.Contains(output, `in.ProtoReflect().Mutable(in.ProtoReflect().Descriptor().Fields().ByName("payload")).Message().Interface()`) { + t.Fatal("server handler should use ProtoReflect().Mutable() for named body field") + } + if strings.Contains(output, "&in.Payload") { + t.Fatal("server handler should NOT use direct field access &in.Payload") + } + + // Client should use getter for named body field + if !strings.Contains(output, "in.GetPayload()") { + t.Fatal("client should use getter in.GetPayload() for named body field") + } + if strings.Contains(output, "in.Payload") && !strings.Contains(output, "in.GetPayload()") { + t.Fatal("client should NOT use direct field access in.Payload") + } +} + +func TestTemplateBodyStar(t *testing.T) { + sd := &serviceDesc{ + ServiceType: "TestSvc", + ServiceName: "test.v1.TestSvc", + Metadata: "test/v1/test.proto", + Methods: []*methodDesc{ + { + Name: "Create", + OriginalName: "Create", + Num: 0, + Request: "CreateRequest", + Reply: "CreateReply", + Path: "/api/v1/create", + Method: "POST", + HasBody: true, + Body: "", + }, + }, + } + output := sd.execute() + + // body="*" should use &in for server bind and in for client + if !strings.Contains(output, "ctx.Bind(&in)") { + t.Fatal("body=* should use ctx.Bind(&in)") + } + if strings.Contains(output, "ProtoReflect") { + t.Fatal("body=* should NOT use ProtoReflect") + } +} diff --git a/cmd/protoc-gen-go-http/template.go b/cmd/protoc-gen-go-http/template.go index 9de8e7e1790..fef6fc3a836 100644 --- a/cmd/protoc-gen-go-http/template.go +++ b/cmd/protoc-gen-go-http/template.go @@ -27,12 +27,14 @@ type methodDesc struct { Reply string Comment string // http_rule - Path string - Method string - HasVars bool - HasBody bool - Body string - ResponseBody string + Path string + Method string + HasVars bool + HasBody bool + Body string + BodyProtoName string // proto field name for named body (empty when body is "*") + ResponseBody string + ResponseBodyProtoName string // proto field name for named response_body (empty when "*") } func (s *serviceDesc) execute() string { @@ -41,7 +43,9 @@ func (s *serviceDesc) execute() string { s.MethodSets[m.Name] = m } buf := new(bytes.Buffer) - tmpl, err := template.New("http").Parse(strings.TrimSpace(httpTemplate)) + tmpl, err := template.New("http").Funcs(template.FuncMap{ + "getterAccessor": toGetterAccessor, + }).Parse(strings.TrimSpace(httpTemplate)) if err != nil { panic(err) } diff --git a/cmd/protoc-gen-go-http/version.go b/cmd/protoc-gen-go-http/version.go index b93a354a3ef..39679615c91 100644 --- a/cmd/protoc-gen-go-http/version.go +++ b/cmd/protoc-gen-go-http/version.go @@ -1,4 +1,4 @@ package main // release is the current protoc-gen-go-http version. -const release = "v2.9.2" +const release = "v2.9.3"