Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions cmd/kratos/internal/proto/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,15 @@ var CmdServer = &cobra.Command{
Long: "Generate the proto server implementations. Example: kratos proto server api/xxx.proto --target-dir=internal/service",
Run: run,
}
var targetDir string

var (
targetDir string
templatePath string
)

func init() {
CmdServer.Flags().StringVarP(&targetDir, "target-dir", "t", "internal/service", "generate target directory")
CmdServer.Flags().StringVarP(&templatePath, "template-file", "m", "", "specify custom template file")
}

func run(_ *cobra.Command, args []string) {
Expand Down Expand Up @@ -81,7 +86,7 @@ func run(_ *cobra.Command, args []string) {
fmt.Fprintf(os.Stderr, "%s already exists: %s\n", s.Service, to)
continue
}
b, err := s.execute()
b, err := s.execute(templatePath)
if err != nil {
log.Fatal(err)
}
Expand Down
27 changes: 24 additions & 3 deletions cmd/kratos/internal/proto/server/template.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@ package server
import (
"bytes"
"html/template"
"os"
)

//nolint:lll
var serviceTemplate = `
var defaultServiceTemplate = `
{{- /* delete empty line */ -}}
package service

Expand Down Expand Up @@ -116,7 +117,21 @@ type Method struct {
Type MethodType
}

func (s *Service) execute() ([]byte, error) {
// LoadTemplate loads a template from a file or uses the default template
func LoadTemplate(templatePath string) (string, error) {
if templatePath == "" {
return defaultServiceTemplate, nil
}

content, err := os.ReadFile(templatePath)
if err != nil {
return "", err
}

return string(content), nil
}

func (s *Service) execute(templatePath string) ([]byte, error) {
const empty = "google.protobuf.Empty"
buf := new(bytes.Buffer)
for _, method := range s.Methods {
Expand All @@ -131,7 +146,13 @@ func (s *Service) execute() ([]byte, error) {
s.UseContext = true
}
}
tmpl, err := template.New("service").Parse(serviceTemplate)

templateContent, err := LoadTemplate(templatePath)
if err != nil {
return nil, err
}

tmpl, err := template.New("service").Parse(templateContent)
if err != nil {
return nil, err
}
Expand Down