diff --git a/cmd/kratos/internal/proto/server/server.go b/cmd/kratos/internal/proto/server/server.go index 665dea8ae07..016eb8116c7 100644 --- a/cmd/kratos/internal/proto/server/server.go +++ b/cmd/kratos/internal/proto/server/server.go @@ -20,10 +20,15 @@ var CmdServer = &cobra.Command{ Long: "Generate the proto server implementations. Example: kratos proto server api/xxx.proto --target-dir=internal/service", Run: run, } -var targetDir string + +var ( + targetDir string + templatePath string +) func init() { CmdServer.Flags().StringVarP(&targetDir, "target-dir", "t", "internal/service", "generate target directory") + CmdServer.Flags().StringVarP(&templatePath, "template-file", "m", "", "specify custom template file") } func run(_ *cobra.Command, args []string) { @@ -81,7 +86,7 @@ func run(_ *cobra.Command, args []string) { fmt.Fprintf(os.Stderr, "%s already exists: %s\n", s.Service, to) continue } - b, err := s.execute() + b, err := s.execute(templatePath) if err != nil { log.Fatal(err) } diff --git a/cmd/kratos/internal/proto/server/template.go b/cmd/kratos/internal/proto/server/template.go index 732947ae832..48271e9fa21 100644 --- a/cmd/kratos/internal/proto/server/template.go +++ b/cmd/kratos/internal/proto/server/template.go @@ -3,10 +3,11 @@ package server import ( "bytes" "html/template" + "os" ) //nolint:lll -var serviceTemplate = ` +var defaultServiceTemplate = ` {{- /* delete empty line */ -}} package service @@ -116,7 +117,21 @@ type Method struct { Type MethodType } -func (s *Service) execute() ([]byte, error) { +// LoadTemplate loads a template from a file or uses the default template +func LoadTemplate(templatePath string) (string, error) { + if templatePath == "" { + return defaultServiceTemplate, nil + } + + content, err := os.ReadFile(templatePath) + if err != nil { + return "", err + } + + return string(content), nil +} + +func (s *Service) execute(templatePath string) ([]byte, error) { const empty = "google.protobuf.Empty" buf := new(bytes.Buffer) for _, method := range s.Methods { @@ -131,7 +146,13 @@ func (s *Service) execute() ([]byte, error) { s.UseContext = true } } - tmpl, err := template.New("service").Parse(serviceTemplate) + + templateContent, err := LoadTemplate(templatePath) + if err != nil { + return nil, err + } + + tmpl, err := template.New("service").Parse(templateContent) if err != nil { return nil, err }