-
Notifications
You must be signed in to change notification settings - Fork 4.7k
grpc: support per-message compression and enforce embedding in ServerTransportStream implementations #8972
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
grpc: support per-message compression and enforce embedding in ServerTransportStream implementations #8972
Changes from 9 commits
e311529
846f339
b1e38ee
1feab8c
a7ff095
99f8640
731535b
489f488
9e906cb
a8bf7be
04e6bdc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -27,9 +27,11 @@ import ( | |
|
|
||
| "google.golang.org/grpc" | ||
| "google.golang.org/grpc/codes" | ||
| "google.golang.org/grpc/credentials/insecure" | ||
| "google.golang.org/grpc/encoding" | ||
| "google.golang.org/grpc/encoding/internal" | ||
| "google.golang.org/grpc/internal/stubserver" | ||
| "google.golang.org/grpc/stats" | ||
| "google.golang.org/grpc/status" | ||
|
|
||
| testgrpc "google.golang.org/grpc/interop/grpc_testing" | ||
|
|
@@ -237,3 +239,271 @@ func (s) TestDecompressionExceedsMaxMessageSize(t *testing.T) { | |
| t.Errorf("Client.UnaryCall(%+v) returned status %v, want %v", req, got, want) | ||
| } | ||
| } | ||
|
|
||
| // statsHandler is a stats.Handler that counts the number of compressed | ||
| // outbound and inbound messages by comparing CompressedLength to Length. | ||
| type statsHandler struct { | ||
| stats.Handler | ||
| compress atomic.Int32 | ||
| decompress atomic.Int32 | ||
| } | ||
|
|
||
| func (h *statsHandler) TagRPC(ctx context.Context, _ *stats.RPCTagInfo) context.Context { return ctx } | ||
| func (h *statsHandler) TagConn(ctx context.Context, _ *stats.ConnTagInfo) context.Context { return ctx } | ||
| func (h *statsHandler) HandleConn(context.Context, stats.ConnStats) {} | ||
| func (h *statsHandler) HandleRPC(_ context.Context, s stats.RPCStats) { | ||
| switch st := s.(type) { | ||
| case *stats.OutPayload: | ||
| if st.CompressedLength < st.Length { | ||
| h.compress.Add(1) | ||
| } | ||
| case *stats.InPayload: | ||
| if st.CompressedLength < st.Length { | ||
| h.decompress.Add(1) | ||
| } | ||
| } | ||
| } | ||
|
|
||
| // TestMessageCompression_StreamToggle tests that SetServerStreamMessageCompression | ||
| // and SetClientStreamMessageCompression correctly enable and disable per-message | ||
| // compression mid-stream on the server and client side respectively. | ||
| func (s) TestMessageCompression_StreamToggle(t *testing.T) { | ||
| sh := &statsHandler{} | ||
| ss := &stubserver.StubServer{ | ||
| FullDuplexCallF: func(stream testgrpc.TestService_FullDuplexCallServer) error { | ||
| if _, err := stream.Recv(); err != nil { | ||
| return err | ||
| } | ||
| if err := stream.Send(&testpb.StreamingOutputCallResponse{ | ||
| Payload: &testpb.Payload{Body: make([]byte, 1000)}, | ||
| }); err != nil { | ||
| return err | ||
| } | ||
| if _, err := stream.Recv(); err != nil { | ||
| return err | ||
| } | ||
| if err := grpc.SetServerStreamMessageCompression(stream.Context(), false); err != nil { | ||
| return err | ||
| } | ||
| if err := stream.Send(&testpb.StreamingOutputCallResponse{ | ||
| Payload: &testpb.Payload{Body: make([]byte, 1000)}, | ||
| }); err != nil { | ||
| return err | ||
| } | ||
| if _, err := stream.Recv(); err != nil { | ||
| return err | ||
| } | ||
| if err := grpc.SetServerStreamMessageCompression(stream.Context(), true); err != nil { | ||
| return err | ||
| } | ||
| return stream.Send(&testpb.StreamingOutputCallResponse{ | ||
| Payload: &testpb.Payload{Body: make([]byte, 1000)}, | ||
| }) | ||
| }, | ||
| } | ||
|
|
||
| if err := ss.Start(nil, grpc.WithStatsHandler(sh)); err != nil { | ||
| t.Fatalf("Error starting server: %v", err) | ||
| } | ||
| defer ss.Stop() | ||
|
|
||
| ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) | ||
| defer cancel() | ||
|
|
||
| stream, err := ss.Client.FullDuplexCall(ctx, grpc.UseCompressor("gzip")) | ||
| if err != nil { | ||
| t.Fatalf("FullDuplexCall failed: %v", err) | ||
| } | ||
|
|
||
| // 1. Send first compressed message | ||
| stream.Send(&testpb.StreamingOutputCallRequest{Payload: &testpb.Payload{Body: make([]byte, 1000)}}) | ||
| stream.Recv() | ||
| if sh.compress.Load() != 1 || sh.decompress.Load() != 1 { | ||
| t.Fatalf("After call 1 (compression enabled): got compress=%d decompress=%d, want compress=1 decompress=1", | ||
| sh.compress.Load(), sh.decompress.Load()) | ||
| } | ||
|
|
||
| // 2. Disable message compression and send second message | ||
| grpc.SetClientStreamMessageCompression(stream.Context(), false) | ||
| stream.Send(&testpb.StreamingOutputCallRequest{Payload: &testpb.Payload{Body: make([]byte, 1000)}}) | ||
| stream.Recv() | ||
| if sh.compress.Load() != 1 || sh.decompress.Load() != 1 { | ||
| t.Fatalf("After call 2 (compression disabled): got compress=%d decompress=%d, want compress=1 decompress=1", | ||
| sh.compress.Load(), sh.decompress.Load()) | ||
| } | ||
|
|
||
| // 3. Enable message compression and send third message | ||
| grpc.SetClientStreamMessageCompression(stream.Context(), true) | ||
| stream.Send(&testpb.StreamingOutputCallRequest{Payload: &testpb.Payload{Body: make([]byte, 1000)}}) | ||
| stream.Recv() | ||
| if sh.compress.Load() != 2 || sh.decompress.Load() != 2 { | ||
| t.Fatalf("After call 3 (compression re-enabled): got compress=%d decompress=%d, want compress=2 decompress=2", | ||
| sh.compress.Load(), sh.decompress.Load()) | ||
| } | ||
| } | ||
|
|
||
| // TestMessageCompression_AmbiguousContext verifies that | ||
| // SetServerStreamMessageCompression and SetClientStreamMessageCompression work | ||
| // independently when called on a context that contains keys for both a server | ||
| // stream and a client stream. This situation arises when a server handler | ||
| // propagates its context to an outbound gRPC call for deadline propagation: | ||
| // the outbound ClientStream.Context() inherits the server-stream key from the | ||
| // parent and also adds its own client-stream compression key. | ||
| func (s) TestMessageCompression_AmbiguousContext(t *testing.T) { | ||
| backendSH := &statsHandler{} | ||
| backend := &stubserver.StubServer{ | ||
| FullDuplexCallF: func(stream testgrpc.TestService_FullDuplexCallServer) error { | ||
| for { | ||
| if _, err := stream.Recv(); err == io.EOF { | ||
| return nil | ||
| } else if err != nil { | ||
| return err | ||
| } | ||
| if err := stream.Send(&testpb.StreamingOutputCallResponse{ | ||
| Payload: &testpb.Payload{Body: make([]byte, 1000)}, | ||
| }); err != nil { | ||
| return err | ||
| } | ||
| } | ||
| }, | ||
| } | ||
| if err := backend.StartServer(grpc.StatsHandler(backendSH)); err != nil { | ||
| t.Fatalf("Error starting backend: %v", err) | ||
| } | ||
| defer backend.Stop() | ||
|
|
||
| // errCh carries any errors from the two SetXxx calls inside the proxy handler. | ||
| errCh := make(chan error, 2) | ||
| proxy := &stubserver.StubServer{ | ||
| FullDuplexCallF: func(serverStream testgrpc.TestService_FullDuplexCallServer) error { | ||
| // Use the server handler's context as the parent for the outbound | ||
| // call. This is the standard deadline-propagation pattern and is | ||
| // what makes the resulting context "ambiguous": it inherits | ||
| // streamKey{} from the server infrastructure, and the gRPC client | ||
| // stack appends compressKey{} when creating the outbound stream. | ||
| backendConn, err := grpc.NewClient( | ||
| backend.Address, | ||
| grpc.WithTransportCredentials(insecure.NewCredentials()), | ||
| ) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Single line here too. Please see: https://google.github.io/styleguide/go/guide#line-length |
||
| if err != nil { | ||
| return err | ||
| } | ||
| defer backendConn.Close() | ||
|
|
||
| clientStream, err := testgrpc.NewTestServiceClient(backendConn).FullDuplexCall( | ||
| serverStream.Context(), grpc.UseCompressor("gzip"), | ||
| ) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here too |
||
| if err != nil { | ||
| return err | ||
| } | ||
|
|
||
| // clientStream.Context() now has BOTH: | ||
| // - streamKey{} → the proxy's own server-side transport stream | ||
| // - compressKey{} → the outbound client stream's compression flag | ||
| ambiguousCtx := clientStream.Context() | ||
|
|
||
| // Must only affect the server stream (proxy → original caller). | ||
| errCh <- grpc.SetServerStreamMessageCompression(ambiguousCtx, false) | ||
| // Must only affect the client stream (proxy → backend). | ||
| errCh <- grpc.SetClientStreamMessageCompression(ambiguousCtx, false) | ||
|
|
||
| // Forward one request/response pair through the proxy. | ||
| req, err := serverStream.Recv() | ||
| if err != nil { | ||
| return err | ||
| } | ||
| if err := clientStream.Send(req); err != nil { | ||
| return err | ||
| } | ||
| if err := clientStream.CloseSend(); err != nil { | ||
| return err | ||
| } | ||
| resp, err := clientStream.Recv() | ||
| if err != nil { | ||
| return err | ||
| } | ||
| return serverStream.Send(resp) | ||
| }, | ||
| } | ||
| proxySH := &statsHandler{} | ||
| if err := proxy.Start([]grpc.ServerOption{grpc.StatsHandler(proxySH)}); err != nil { | ||
| t.Fatalf("Error starting proxy: %v", err) | ||
| } | ||
| defer proxy.Stop() | ||
|
|
||
| ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) | ||
| defer cancel() | ||
|
|
||
| stream, err := proxy.Client.FullDuplexCall(ctx, grpc.UseCompressor("gzip")) | ||
| if err != nil { | ||
| t.Fatalf("FullDuplexCall failed: %v", err) | ||
| } | ||
| if err := stream.Send(&testpb.StreamingOutputCallRequest{ | ||
| Payload: &testpb.Payload{Body: make([]byte, 1000)}, | ||
| }); err != nil { | ||
| t.Fatalf("Send failed: %v", err) | ||
| } | ||
| if _, err := stream.Recv(); err != nil { | ||
| t.Fatalf("Recv failed: %v", err) | ||
| } | ||
|
|
||
| // Collect results from both SetXxx calls. | ||
| for _, name := range []string{"SetServerStreamMessageCompression", "SetClientStreamMessageCompression"} { | ||
| select { | ||
| case err := <-errCh: | ||
| if err != nil { | ||
| t.Fatalf("%s on ambiguous context returned unexpected error: %v", name, err) | ||
| } | ||
| case <-ctx.Done(): | ||
| t.Fatalf("timed out waiting for %s result", name) | ||
| } | ||
| } | ||
|
|
||
| // SetServerStreamMessageCompression disabled compression on the | ||
| // proxy → caller direction: the proxy's server stream must not have | ||
| // sent any compressed messages. | ||
| if got := proxySH.compress.Load(); got != 0 { | ||
| t.Errorf("proxy server outbound compress count = %d, want 0 (SetServerStreamMessageCompression disabled it)", got) | ||
| } | ||
| // SetClientStreamMessageCompression disabled compression on the | ||
| // proxy → backend direction: the backend must not have received any | ||
| // compressed messages. | ||
| if got := backendSH.decompress.Load(); got != 0 { | ||
| t.Errorf("backend inbound decompress count = %d, want 0 (SetClientStreamMessageCompression disabled it)", got) | ||
| } | ||
| } | ||
|
|
||
| func (s) TestMessageCompression_Unary(t *testing.T) { | ||
| sh := &statsHandler{} | ||
| ss := &stubserver.StubServer{ | ||
| UnaryCallF: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { | ||
| grpc.SetSendCompressor(ctx, "gzip") | ||
| if in.ResponseSize == 0 { | ||
| grpc.SetServerStreamMessageCompression(ctx, false) | ||
| } | ||
| return &testpb.SimpleResponse{Payload: &testpb.Payload{Body: make([]byte, 10000)}}, nil | ||
| }, | ||
| } | ||
|
|
||
| if err := ss.Start(nil, grpc.WithStatsHandler(sh)); err != nil { | ||
| t.Fatalf("Error starting server: %v", err) | ||
| } | ||
| defer ss.Stop() | ||
|
|
||
| ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) | ||
| defer cancel() | ||
|
|
||
| // Call 1: Compression ON | ||
| ss.Client.UnaryCall(ctx, &testpb.SimpleRequest{ResponseSize: 1, Payload: &testpb.Payload{Body: make([]byte, 1000)}}, grpc.UseCompressor("gzip")) | ||
| if sh.compress.Load() != 1 || sh.decompress.Load() != 1 { | ||
| t.Fatalf("After call 1 (compression enabled): got compress=%d decompress=%d, want compress=1 decompress=1", | ||
| sh.compress.Load(), sh.decompress.Load()) | ||
| } | ||
|
|
||
| // Call 2: Compression OFF (for response, but request is still compressed by UseCompressor) | ||
| ss.Client.UnaryCall(ctx, &testpb.SimpleRequest{ResponseSize: 0, Payload: &testpb.Payload{Body: make([]byte, 1000)}}, grpc.UseCompressor("gzip")) | ||
| if sh.compress.Load() != 2 || sh.decompress.Load() != 1 { | ||
| t.Fatalf("After call 2 (server response compression disabled): got compress=%d decompress=%d, want compress=2 decompress=1", | ||
| sh.compress.Load(), sh.decompress.Load()) | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -25,13 +25,15 @@ import ( | |
| "sync" | ||
| "sync/atomic" | ||
|
|
||
| "google.golang.org/grpc/internal" | ||
| "google.golang.org/grpc/mem" | ||
| "google.golang.org/grpc/metadata" | ||
| "google.golang.org/grpc/status" | ||
| ) | ||
|
|
||
| // ServerStream implements streaming functionality for a gRPC server. | ||
| type ServerStream struct { | ||
| internal.EnforceServerTransportStreamEmbedding | ||
|
easwars marked this conversation as resolved.
|
||
| Stream // Embed for common stream functionality. | ||
|
|
||
| st internalServerTransport | ||
|
|
@@ -50,8 +52,21 @@ type ServerStream struct { | |
| headerSent atomic.Bool // atomically set when the headers are sent out. | ||
|
|
||
| headerWireLength int | ||
|
|
||
| // enableCompression controls whether per-message compression is enabled for | ||
| // this stream. It is accessed serially alongside SendMsg calls, so no mutex | ||
| // is needed. | ||
| enableCompression bool | ||
| } | ||
|
|
||
| // SetEnableCompression sets whether per-message compression is enabled for | ||
| // subsequent messages sent on this stream. | ||
| func (s *ServerStream) SetEnableCompression(v bool) { s.enableCompression = v } | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit: Let's name this |
||
|
|
||
| // IsCompressionEnabled reports whether per-message compression is enabled for | ||
| // this stream. | ||
| func (s *ServerStream) IsCompressionEnabled() bool { return s.enableCompression } | ||
|
|
||
| // Read reads an n byte message from the input stream. | ||
| func (s *ServerStream) Read(n int) (mem.BufferSlice, error) { | ||
| b, err := s.Stream.read(n) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: Here and elsewhere in this test, please have these calls to
t.Fatalbe on a single line. Thanks.