diff --git a/internal/envconfig/envconfig.go b/internal/envconfig/envconfig.go index 936e60c726ea..ba05b65d5b83 100644 --- a/internal/envconfig/envconfig.go +++ b/internal/envconfig/envconfig.go @@ -59,6 +59,15 @@ var ( // unconditionally. XDSEndpointHashKeyBackwardCompat = boolFromEnv("GRPC_XDS_ENDPOINT_HASH_KEY_BACKWARD_COMPAT", false) + // LabelServerGoroutines controls setting [runtime/pprof.Labels] on the + // goroutines spawned by [grpc.Server] type. + // For now, this is limited to the goroutines spawned to handle incoming + // requests on the server. + // Set "GRPC_GO_SERVER_GOROUTINE_LABELS" to "grpc.method=true" to + // enable this grpc.method label, or "all" to enable all valid labels. + // This variable is a bit-field. + LabelServerGoroutines = goroutineLabelsFromEnv("GRPC_GO_SERVER_GOROUTINE_LABELS", 0) + // RingHashSetRequestHashKey is set if the ring hash balancer can get the // request hash header by setting the "requestHashHeader" field, according // to gRFC A76. It can be disabled by setting the environment variable @@ -156,3 +165,52 @@ func uint64FromEnv(envVar string, def, min, max uint64) uint64 { } return v } + +// GoroutineLabels is a bitfield indicating which goroutine labels are enabled. +type GoroutineLabels uint16 + +func goroutineLabelsFromEnv(envVar string, def GoroutineLabels) GoroutineLabels { + val := def + v := os.Getenv(envVar) + if strings.EqualFold(v, "all") { + return AllGoroutineLabels + } else if strings.EqualFold(v, "none") { + return 0 + } + for s := range strings.SplitSeq(v, ",") { + s = strings.TrimSpace(s) + if len(s) == 0 { + continue + } + pre, post, ok := strings.Cut(s, "=") + if !ok { + // no equals sign + continue + } + post = strings.TrimSpace(post) + pre = strings.TrimSpace(pre) + bitDesignator := GoroutineLabels(0) + switch { + case strings.EqualFold(pre, "grpc.method"): + bitDesignator = GoroutineLabelServerMethod + default: + continue + } + if strings.EqualFold(post, "true") { + val |= bitDesignator + } else if strings.EqualFold(post, "false") { + val &^= bitDesignator + } + } + return val +} + +const ( + // GoroutineLabelServerMethod sets the grpc.method label on new + // server-side gRPC streams. + GoroutineLabelServerMethod GoroutineLabels = 1 << iota +) + +// AllGoroutineLabels is an or'd together bitfield of all valid GoroutineLabels +// constant values (above). +const AllGoroutineLabels = GoroutineLabelServerMethod diff --git a/internal/envconfig/envconfig_test.go b/internal/envconfig/envconfig_test.go index 68fdf6c73a7f..4382860a06b3 100644 --- a/internal/envconfig/envconfig_test.go +++ b/internal/envconfig/envconfig_test.go @@ -101,3 +101,118 @@ func (s) TestBoolFromEnv(t *testing.T) { }) } } + +func (s) TestGoroutineLabelsFromEnv(t *testing.T) { + var testCases = []struct { + name string + val string + def GoroutineLabels + want GoroutineLabels + }{ + { + name: "unset_env_non-zero_default", + val: "", + def: GoroutineLabelServerMethod, + want: GoroutineLabelServerMethod, + }, { + name: "unset_env_zero_default", + val: "", + def: 0, + want: 0, + }, { + name: "force-enable_zero_default", + val: "grpc.method=true", + def: 0, + want: GoroutineLabelServerMethod, + }, { + name: "force-enable_zero_default_all_caps", + val: "grpc.method=TRUE", + def: 0, + want: GoroutineLabelServerMethod, + }, { + name: "force-enable_zero_default_with_whitespace", + val: " grpc.method\t= true", + def: 0, + want: GoroutineLabelServerMethod, + }, { + name: "force-enable_zero_default_with_other_garbage", + val: "grpc.method=true,foobar", + def: 0, + want: GoroutineLabelServerMethod, + }, { + name: "force-enable_mixed_case_zero_default_with_other_garbage", + val: "grpc.method=tRuE,foobar", + def: 0, + want: GoroutineLabelServerMethod, + }, { + name: "force-disable_zero_default", + val: "grpc.method=false", + def: 0, + want: 0, + }, { + name: "force-disable_non-zero_default", + val: "grpc.method=false", + def: GoroutineLabelServerMethod, + want: 0, + }, { + name: "force-disable_non-zero_default_all_caps", + val: "grpc.method=FALSE", + def: GoroutineLabelServerMethod, + want: 0, + }, { + name: "force-disable_non-zero_default_mixed_case", + val: "grpc.method=fAlSe", + def: GoroutineLabelServerMethod, + want: 0, + }, { + name: "unknown_val_no_equal", + val: "grpc.unknown.garbage", + def: GoroutineLabelServerMethod, + want: GoroutineLabelServerMethod, + }, { + name: "unknown_val", + val: "grpc.unknown.garbage=fooble", + def: GoroutineLabelServerMethod, + want: GoroutineLabelServerMethod, + }, { + name: "all_with_empty_default", + val: "all", + def: 0, + want: AllGoroutineLabels, + }, { + name: "all_with_server_method_default", + val: "all", + def: GoroutineLabelServerMethod, + want: AllGoroutineLabels, + }, { + name: "none_with_empty_default", + val: "none", + def: 0, + want: 0, + }, { + name: "none_with_server_method_default", + val: "none", + def: GoroutineLabelServerMethod, + want: 0, + }, { + name: "unparseable_rhs", + val: "grpc.method=quux", + def: GoroutineLabelServerMethod, + want: GoroutineLabelServerMethod, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + const testVar = "testvar" + if tc.val == "" { + os.Unsetenv(testVar) + } else { + os.Setenv(testVar, tc.val) + } + if got := goroutineLabelsFromEnv(testVar, tc.def); got != tc.want { + t.Errorf("goroutineLabelsFromEnv(%q(=%q), %v) = %v; want %v", testVar, tc.val, tc.def, got, tc.want) + } + }) + } + +} diff --git a/server.go b/server.go index 6fb7e0944fad..cf0a2067190c 100644 --- a/server.go +++ b/server.go @@ -28,6 +28,7 @@ import ( "net/http" "reflect" "runtime" + "runtime/pprof" "strings" "sync" "sync/atomic" @@ -42,6 +43,7 @@ import ( "google.golang.org/grpc/internal" "google.golang.org/grpc/internal/binarylog" "google.golang.org/grpc/internal/channelz" + "google.golang.org/grpc/internal/envconfig" "google.golang.org/grpc/internal/grpcsync" "google.golang.org/grpc/internal/grpcutil" istats "google.golang.org/grpc/internal/stats" @@ -1798,6 +1800,12 @@ func (s *Server) handleMalformedMethodName(stream *transport.ServerStream, ti *t func (s *Server) handleStream(t transport.ServerTransport, stream *transport.ServerStream) { ctx := stream.Context() ctx = contextWithServer(ctx, s) + if envconfig.LabelServerGoroutines&envconfig.GoroutineLabelServerMethod != 0 { + // This method always runs in its own goroutine, so we can set a + // goroutine label without needing to restore a previous context. + ctx = pprof.WithLabels(ctx, pprof.Labels("grpc.method", stream.Method())) + pprof.SetGoroutineLabels(ctx) + } var ti *traceInfo if EnableTracing { tr := newTrace("grpc.Recv."+methodFamily(stream.Method()), stream.Method()) diff --git a/test/server_test.go b/test/server_test.go index 0441c08fd00e..72d07506a414 100644 --- a/test/server_test.go +++ b/test/server_test.go @@ -21,12 +21,15 @@ package test import ( "context" "io" + "runtime/pprof" "sync/atomic" "testing" "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/internal/envconfig" "google.golang.org/grpc/internal/stubserver" + "google.golang.org/grpc/internal/testutils" "google.golang.org/grpc/status" testgrpc "google.golang.org/grpc/interop/grpc_testing" @@ -70,6 +73,101 @@ func (s) TestServerReturningContextError(t *testing.T) { } +func pprofCtxCollectLabels(ctx context.Context) map[string]string { + seenLabels := map[string]string{} + pprof.ForLabels(ctx, func(k, val string) bool { + seenLabels[k] = val + return true + }) + return seenLabels +} + +// TestServerSetGoroutineLabelsInContext verifies that when enabled, the +// grpc.method runtime/pprof goroutine label gets set in the context that's +// passed to the handlers. +func (s) TestServerSetGoroutineLabelsInContext(t *testing.T) { + testutils.SetEnvConfig(t, &envconfig.LabelServerGoroutines, envconfig.GoroutineLabelServerMethod) + ss := &stubserver.StubServer{ + EmptyCallF: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) { + ctxLabels := pprofCtxCollectLabels(ctx) + if val, ok := ctxLabels["grpc.method"]; !ok { + t.Errorf("missing \"grpc.method\" label; found labels: %v", ctxLabels) + } else if wantVal := "/grpc.testing.TestService/EmptyCall"; val != wantVal { + t.Errorf("unexpected value for \"grpc.method\" label %q; want %q", ctxLabels["grpc.method"], wantVal) + } + return &testpb.Empty{}, nil + }, + FullDuplexCallF: func(stream testgrpc.TestService_FullDuplexCallServer) error { + ctxLabels := pprofCtxCollectLabels(stream.Context()) + if val, ok := ctxLabels["grpc.method"]; !ok { + t.Errorf("missing \"grpc.method\" label; found labels: %v", ctxLabels) + } else if wantVal := "/grpc.testing.TestService/FullDuplexCall"; val != wantVal { + t.Errorf("unexpected value for \"grpc.method\" label %q; want %q", ctxLabels["grpc.method"], wantVal) + } + return nil + }, + } + if err := ss.Start(nil); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); err != nil { + t.Fatalf("ss.Client.EmptyCall() got error %v; want OK", err) + } + + stream, err := ss.Client.FullDuplexCall(ctx) + if err != nil { + t.Fatalf("unexpected error starting the stream: %v", err) + } + if _, err = stream.Recv(); err != io.EOF { + t.Fatalf("ss.Client.FullDuplexCall().Recv() got error %v; want io.EOF", err) + } +} + +// TestServerSetGoroutineLabelsInContextEnvVarDisabled verifies that when +// disable, the grpc.method runtime/pprof goroutine label does _not_ get set in +// the context that's passed to the handlers. +func (s) TestServerSetGoroutineLabelsInContextEnvVarDisabled(t *testing.T) { + testutils.SetEnvConfig(t, &envconfig.LabelServerGoroutines, 0) + ss := &stubserver.StubServer{ + EmptyCallF: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) { + ctxLabels := pprofCtxCollectLabels(ctx) + if val, ok := ctxLabels["grpc.method"]; ok { + t.Errorf("\"grpc.method\" label set with value %q; found labels: %v", val, ctxLabels) + } + return &testpb.Empty{}, nil + }, + FullDuplexCallF: func(stream testgrpc.TestService_FullDuplexCallServer) error { + ctxLabels := pprofCtxCollectLabels(stream.Context()) + if val, ok := ctxLabels["grpc.method"]; ok { + t.Errorf("\"grpc.method\" label set with value %q; found labels: %v", val, ctxLabels) + } + return nil + }, + } + if err := ss.Start(nil); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); err != nil { + t.Fatalf("ss.Client.EmptyCall() got error %v; want OK", err) + } + + stream, err := ss.Client.FullDuplexCall(ctx) + if err != nil { + t.Fatalf("unexpected error starting the stream: %v", err) + } + if _, err = stream.Recv(); err != io.EOF { + t.Fatalf("ss.Client.FullDuplexCall().Recv() got error %v; want io.EOF", err) + } +} + func (s) TestChainUnaryServerInterceptor(t *testing.T) { var ( firstIntKey = ctxKey("firstIntKey")