diff --git a/encoding/compressor_test.go b/encoding/compressor_test.go index 18260ae37078..b7ee3de82ef3 100644 --- a/encoding/compressor_test.go +++ b/encoding/compressor_test.go @@ -27,9 +27,11 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/encoding" "google.golang.org/grpc/encoding/internal" "google.golang.org/grpc/internal/stubserver" + "google.golang.org/grpc/stats" "google.golang.org/grpc/status" testgrpc "google.golang.org/grpc/interop/grpc_testing" @@ -237,3 +239,252 @@ func (s) TestDecompressionExceedsMaxMessageSize(t *testing.T) { t.Errorf("Client.UnaryCall(%+v) returned status %v, want %v", req, got, want) } } + +// statsHandler is a stats.Handler that counts the number of compressed +// outbound and inbound messages by comparing CompressedLength to Length. +type statsHandler struct { + stats.Handler + compress atomic.Int32 + decompress atomic.Int32 +} + +func (h *statsHandler) TagRPC(ctx context.Context, _ *stats.RPCTagInfo) context.Context { return ctx } +func (h *statsHandler) TagConn(ctx context.Context, _ *stats.ConnTagInfo) context.Context { return ctx } +func (h *statsHandler) HandleConn(context.Context, stats.ConnStats) {} +func (h *statsHandler) HandleRPC(_ context.Context, s stats.RPCStats) { + switch st := s.(type) { + case *stats.OutPayload: + if st.CompressedLength < st.Length { + h.compress.Add(1) + } + case *stats.InPayload: + if st.CompressedLength < st.Length { + h.decompress.Add(1) + } + } +} + +// TestMessageCompression_StreamToggle tests that SetServerStreamMessageCompression +// and SetClientStreamMessageCompression correctly enable and disable per-message +// compression mid-stream on the server and client side respectively. +func (s) TestMessageCompression_StreamToggle(t *testing.T) { + sh := &statsHandler{} + ss := &stubserver.StubServer{ + FullDuplexCallF: func(stream testgrpc.TestService_FullDuplexCallServer) error { + if _, err := stream.Recv(); err != nil { + return err + } + if err := stream.Send(&testpb.StreamingOutputCallResponse{ + Payload: &testpb.Payload{Body: make([]byte, 1000)}, + }); err != nil { + return err + } + if _, err := stream.Recv(); err != nil { + return err + } + if err := grpc.SetServerStreamMessageCompression(stream.Context(), false); err != nil { + return err + } + if err := stream.Send(&testpb.StreamingOutputCallResponse{ + Payload: &testpb.Payload{Body: make([]byte, 1000)}, + }); err != nil { + return err + } + if _, err := stream.Recv(); err != nil { + return err + } + if err := grpc.SetServerStreamMessageCompression(stream.Context(), true); err != nil { + return err + } + return stream.Send(&testpb.StreamingOutputCallResponse{ + Payload: &testpb.Payload{Body: make([]byte, 1000)}, + }) + }, + } + + if err := ss.Start(nil, grpc.WithStatsHandler(sh)); err != nil { + t.Fatalf("Error starting server: %v", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + stream, err := ss.Client.FullDuplexCall(ctx, grpc.UseCompressor("gzip")) + if err != nil { + t.Fatalf("FullDuplexCall failed: %v", err) + } + + // 1. Send first compressed message + stream.Send(&testpb.StreamingOutputCallRequest{Payload: &testpb.Payload{Body: make([]byte, 1000)}}) + stream.Recv() + if sh.compress.Load() != 1 || sh.decompress.Load() != 1 { + t.Fatalf("After call 1 (compression enabled): got compress=%d decompress=%d, want compress=1 decompress=1", sh.compress.Load(), sh.decompress.Load()) + } + + // 2. Disable message compression and send second message + grpc.SetClientStreamMessageCompression(stream, false) + stream.Send(&testpb.StreamingOutputCallRequest{Payload: &testpb.Payload{Body: make([]byte, 1000)}}) + stream.Recv() + if sh.compress.Load() != 1 || sh.decompress.Load() != 1 { + t.Fatalf("After call 2 (compression disabled): got compress=%d decompress=%d, want compress=1 decompress=1", sh.compress.Load(), sh.decompress.Load()) + } + + // 3. Enable message compression and send third message + grpc.SetClientStreamMessageCompression(stream, true) + stream.Send(&testpb.StreamingOutputCallRequest{Payload: &testpb.Payload{Body: make([]byte, 1000)}}) + stream.Recv() + if sh.compress.Load() != 2 || sh.decompress.Load() != 2 { + t.Fatalf("After call 3 (compression re-enabled): got compress=%d decompress=%d, want compress=2 decompress=2", sh.compress.Load(), sh.decompress.Load()) + } +} + +// TestMessageCompression_AmbiguousContext verifies that +// SetServerStreamMessageCompression and SetClientStreamMessageCompression work +// independently when a server handler propagates its context to an outbound +// gRPC call for deadline propagation. +func (s) TestMessageCompression_AmbiguousContext(t *testing.T) { + backendSH := &statsHandler{} + backend := &stubserver.StubServer{ + FullDuplexCallF: func(stream testgrpc.TestService_FullDuplexCallServer) error { + for { + if _, err := stream.Recv(); err == io.EOF { + return nil + } else if err != nil { + return err + } + if err := stream.Send(&testpb.StreamingOutputCallResponse{ + Payload: &testpb.Payload{Body: make([]byte, 1000)}, + }); err != nil { + return err + } + } + }, + } + if err := backend.StartServer(grpc.StatsHandler(backendSH)); err != nil { + t.Fatalf("Error starting backend: %v", err) + } + defer backend.Stop() + + // errCh carries any errors from the two SetXxx calls inside the proxy handler. + errCh := make(chan error, 2) + proxy := &stubserver.StubServer{ + FullDuplexCallF: func(serverStream testgrpc.TestService_FullDuplexCallServer) error { + // Use the server handler's context as the parent for the outbound + // call. This is the standard deadline-propagation pattern. + backendConn, err := grpc.NewClient(backend.Address, grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + return err + } + defer backendConn.Close() + + clientStream, err := testgrpc.NewTestServiceClient(backendConn).FullDuplexCall(serverStream.Context(), grpc.UseCompressor("gzip")) + if err != nil { + return err + } + + // Must only affect the server stream (proxy → original caller). + errCh <- grpc.SetServerStreamMessageCompression(clientStream.Context(), false) + // Must only affect the client stream (proxy → backend). + errCh <- grpc.SetClientStreamMessageCompression(clientStream, false) + + // Forward one request/response pair through the proxy. + req, err := serverStream.Recv() + if err != nil { + return err + } + if err := clientStream.Send(req); err != nil { + return err + } + if err := clientStream.CloseSend(); err != nil { + return err + } + resp, err := clientStream.Recv() + if err != nil { + return err + } + return serverStream.Send(resp) + }, + } + proxySH := &statsHandler{} + if err := proxy.Start([]grpc.ServerOption{grpc.StatsHandler(proxySH)}); err != nil { + t.Fatalf("Error starting proxy: %v", err) + } + defer proxy.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + stream, err := proxy.Client.FullDuplexCall(ctx, grpc.UseCompressor("gzip")) + if err != nil { + t.Fatalf("FullDuplexCall failed: %v", err) + } + if err := stream.Send(&testpb.StreamingOutputCallRequest{ + Payload: &testpb.Payload{Body: make([]byte, 1000)}, + }); err != nil { + t.Fatalf("Send failed: %v", err) + } + if _, err := stream.Recv(); err != nil { + t.Fatalf("Recv failed: %v", err) + } + + // Collect results from both SetXxx calls. + for _, name := range []string{"SetServerStreamMessageCompression", "SetClientStreamMessageCompression"} { + select { + case err := <-errCh: + if err != nil { + t.Fatalf("%s on ambiguous context returned unexpected error: %v", name, err) + } + case <-ctx.Done(): + t.Fatalf("timed out waiting for %s result", name) + } + } + + // SetServerStreamMessageCompression disabled compression on the + // proxy → caller direction: the proxy's server stream must not have + // sent any compressed messages. + if got := proxySH.compress.Load(); got != 0 { + t.Errorf("proxy server outbound compress count = %d, want 0 (SetServerStreamMessageCompression disabled it)", got) + } + // SetClientStreamMessageCompression disabled compression on the + // proxy → backend direction: the backend must not have received any + // compressed messages. + if got := backendSH.decompress.Load(); got != 0 { + t.Errorf("backend inbound decompress count = %d, want 0 (SetClientStreamMessageCompression disabled it)", got) + } +} + +func (s) TestMessageCompression_Unary(t *testing.T) { + sh := &statsHandler{} + ss := &stubserver.StubServer{ + UnaryCallF: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { + grpc.SetSendCompressor(ctx, "gzip") + if in.ResponseSize == 0 { + grpc.SetServerStreamMessageCompression(ctx, false) + } + return &testpb.SimpleResponse{Payload: &testpb.Payload{Body: make([]byte, 10000)}}, nil + }, + } + + if err := ss.Start(nil, grpc.WithStatsHandler(sh)); err != nil { + t.Fatalf("Error starting server: %v", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + // Call 1: Compression ON + ss.Client.UnaryCall(ctx, &testpb.SimpleRequest{ResponseSize: 1, Payload: &testpb.Payload{Body: make([]byte, 1000)}}, grpc.UseCompressor("gzip")) + if sh.compress.Load() != 1 || sh.decompress.Load() != 1 { + t.Fatalf("After call 1 (compression enabled): got compress=%d decompress=%d, want compress=1 decompress=1", + sh.compress.Load(), sh.decompress.Load()) + } + + // Call 2: Compression OFF (for response, but request is still compressed by UseCompressor) + ss.Client.UnaryCall(ctx, &testpb.SimpleRequest{ResponseSize: 0, Payload: &testpb.Payload{Body: make([]byte, 1000)}}, grpc.UseCompressor("gzip")) + if sh.compress.Load() != 2 || sh.decompress.Load() != 1 { + t.Fatalf("After call 2 (server response compression disabled): got compress=%d decompress=%d, want compress=2 decompress=1", + sh.compress.Load(), sh.decompress.Load()) + } +} diff --git a/internal/internal.go b/internal/internal.go index 4b3d563f8d76..628f4a691d82 100644 --- a/internal/internal.go +++ b/internal/internal.go @@ -298,3 +298,9 @@ type Timer interface { type EnforceMetricsRecorderEmbedding interface { enforceMetricsRecorderEmbedding() } + +// EnforceServerTransportStreamEmbedding is used to enforce proper +// ServerTransportStream implementation embedding. +type EnforceServerTransportStreamEmbedding interface { + enforceServerTransportStreamEmbedding() +} diff --git a/internal/transport/handler_server.go b/internal/transport/handler_server.go index 7ab3422b8a27..ad910db36b52 100644 --- a/internal/transport/handler_server.go +++ b/internal/transport/handler_server.go @@ -424,6 +424,7 @@ func (ht *serverHandlerTransport) HandleStreams(ctx context.Context, startStream st: ht, headerWireLength: 0, // won't have access to header wire length until golang/go#18997. } + s.enableCompression.Store(true) // compression is enabled by default s.Stream.buf.init() s.readRequester = s s.trReader = transportReader{ diff --git a/internal/transport/http2_server.go b/internal/transport/http2_server.go index 3a8c36e4f94f..0735ed110013 100644 --- a/internal/transport/http2_server.go +++ b/internal/transport/http2_server.go @@ -399,6 +399,7 @@ func (t *http2Server) operateHeaders(ctx context.Context, frame *http2.MetaHeade st: t, headerWireLength: int(frame.Header().Length), } + s.enableCompression.Store(true) // compression is enabled by default s.Stream.buf.init() var ( // if false, content-type was missing or invalid diff --git a/internal/transport/server_stream.go b/internal/transport/server_stream.go index ed6a13b7501a..9a6e33f5a117 100644 --- a/internal/transport/server_stream.go +++ b/internal/transport/server_stream.go @@ -25,6 +25,7 @@ import ( "sync" "sync/atomic" + "google.golang.org/grpc/internal" "google.golang.org/grpc/mem" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" @@ -32,6 +33,7 @@ import ( // ServerStream implements streaming functionality for a gRPC server. type ServerStream struct { + internal.EnforceServerTransportStreamEmbedding Stream // Embed for common stream functionality. st internalServerTransport @@ -50,8 +52,20 @@ type ServerStream struct { headerSent atomic.Bool // atomically set when the headers are sent out. headerWireLength int + + // enableCompression controls whether per-message compression is enabled for + // this stream. + enableCompression atomic.Bool } +// EnableCompression sets whether per-message compression is enabled for +// subsequent messages sent on this stream. +func (s *ServerStream) EnableCompression(v bool) { s.enableCompression.Store(v) } + +// IsCompressionEnabled reports whether per-message compression is enabled for +// this stream. +func (s *ServerStream) IsCompressionEnabled() bool { return s.enableCompression.Load() } + // Read reads an n byte message from the input stream. func (s *ServerStream) Read(n int) (mem.BufferSlice, error) { b, err := s.Stream.read(n) diff --git a/internal/xds/rbac/rbac_engine_test.go b/internal/xds/rbac/rbac_engine_test.go index fe84bf249d35..7500106db914 100644 --- a/internal/xds/rbac/rbac_engine_test.go +++ b/internal/xds/rbac/rbac_engine_test.go @@ -1805,6 +1805,7 @@ func (s) TestChainEngine(t *testing.T) { } type ServerTransportStreamWithMethod struct { + grpc.ServerTransportStream method string } @@ -1824,6 +1825,8 @@ func (sts *ServerTransportStreamWithMethod) SetTrailer(metadata.MD) error { return nil } +func (sts *ServerTransportStreamWithMethod) EnableCompression(bool) {} + // An audit logger that will log to the auditEvents slice. type TestAuditLoggerBuffer struct { auditEvents *[]*audit.Event diff --git a/server.go b/server.go index 1b5cefe81715..8afb76d64ab1 100644 --- a/server.go +++ b/server.go @@ -1470,7 +1470,11 @@ func (s *Server) processUnaryRPC(ctx context.Context, stream *transport.ServerSt if stream.SendCompress() != sendCompressorName { comp = encoding.GetCompressor(stream.SendCompress()) } - if err := s.sendResponse(ctx, stream, reply, cp, opts, comp); err != nil { + compV0, compV1 := cp, comp + if !stream.IsCompressionEnabled() { + compV0, compV1 = nil, nil + } + if err := s.sendResponse(ctx, stream, reply, compV0, opts, compV1); err != nil { if err == io.EOF { // The entire stream is done (for unary RPC only). return err @@ -1888,10 +1892,14 @@ func NewContextWithServerTransportStream(ctx context.Context, stream ServerTrans // Notice: This type is EXPERIMENTAL and may be changed or removed in a // later release. type ServerTransportStream interface { + internal.EnforceServerTransportStreamEmbedding Method() string SetHeader(md metadata.MD) error SendHeader(md metadata.MD) error SetTrailer(md metadata.MD) error + // EnableCompression controls whether per-message compression is enabled + // for subsequent messages sent on this stream. + EnableCompression(enable bool) } // ServerTransportStreamFromContext returns the ServerTransportStream saved in diff --git a/stream.go b/stream.go index eedb5f9b99c7..26623360770b 100644 --- a/stream.go +++ b/stream.go @@ -21,6 +21,7 @@ package grpc import ( "context" "errors" + "fmt" "io" "math" rand "math/rand/v2" @@ -51,6 +52,49 @@ import ( var metadataFromOutgoingContextRaw = internal.FromOutgoingContextRaw.(func(context.Context) (metadata.MD, [][]string, bool)) +// clientStreamKey is the context key used to store a *clientStream pointer +// in the stream's context so that SetClientStreamMessageCompression can retrieve +// it from a generated stream wrapper without a direct type assertion. +type clientStreamKey struct{} + +// SetServerStreamMessageCompression enables or disables per-message compression +// on a server stream. The provided context must be the context passed to the +// server handler. Compression is enabled by default and is a no-op if no +// compressor is configured on the stream (e.g. via SetSendCompressor). +// +// # Experimental +// +// Notice: This API is EXPERIMENTAL and may be changed or removed in a +// later release. +func SetServerStreamMessageCompression(ctx context.Context, enable bool) error { + sts := ServerTransportStreamFromContext(ctx) + if sts == nil { + return fmt.Errorf("grpc: SetServerStreamMessageCompression called on a non-server-stream context") + } + sts.EnableCompression(enable) + return nil +} + +// SetClientStreamMessageCompression enables or disables per-message compression +// on a client stream. Compression is enabled by default and is a no-op if no +// compressor is configured on the stream (e.g. via UseCompressor). An error is +// returned if the provided stream is not a gRPC client stream. +// +// This method must not be called concurrently with SendMsg. +// +// # Experimental +// +// Notice: This API is EXPERIMENTAL and may be changed or removed in a +// later release. +func SetClientStreamMessageCompression(cs ClientStream, enable bool) error { + s, ok := cs.Context().Value(clientStreamKey{}).(*clientStream) + if !ok { + return fmt.Errorf("grpc: SetClientStreamMessageCompression called on a non-client-stream context") + } + s.enableCompression = enable + return nil +} + // StreamHandler defines the handler called by gRPC server to complete the // execution of a streaming RPC. srv is the service implementation on which the // RPC was invoked. @@ -365,6 +409,10 @@ func newClientStreamWithParams(ctx context.Context, desc *StreamDesc, cc *Client firstAttempt: true, onCommit: onCommit, nameResolutionDelay: nameResolutionDelayed, + enableCompression: compressorV0 != nil || compressorV1 != nil, + } + if compressorV0 != nil || compressorV1 != nil { + cs.ctx = context.WithValue(cs.ctx, clientStreamKey{}, cs) } if !cc.dopts.disableRetry { cs.retryThrottler = cc.retryThrottler.Load().(*retryThrottler) @@ -623,6 +671,11 @@ type clientStream struct { // nameResolutionDelay indicates if there was a delay in the name resolution. // This field is only valid on client side, it's always false on server side. nameResolutionDelay bool + + // enableCompression controls whether per-message compression is enabled for + // this stream. It is accessed serially alongside SendMsg calls, so no mutex + // is needed. + enableCompression bool } type replayOp struct { @@ -954,7 +1007,11 @@ func (cs *clientStream) SendMsg(m any) (err error) { } // load hdr, payload, data - hdr, data, payload, pf, err := prepareMsg(m, cs.codec, cs.compressorV0, cs.compressorV1, cs.cc.dopts.copts.BufferPool) + compV0, compV1 := cs.compressorV0, cs.compressorV1 + if !cs.enableCompression { + compV0, compV1 = nil, nil + } + hdr, data, payload, pf, err := prepareMsg(m, cs.codec, compV0, compV1, cs.cc.dopts.copts.BufferPool) if err != nil { return err } @@ -1340,18 +1397,19 @@ func newNonRetryClientStream(ctx context.Context, desc *StreamDesc, method strin // Use a special addrConnStream to avoid retry. as := &addrConnStream{ - callHdr: callHdr, - ac: ac, - ctx: ctx, - cancel: cancel, - opts: opts, - callInfo: c, - desc: desc, - codec: c.codec, - sendCompressorV0: cp, - sendCompressorV1: comp, - decompressorV0: ac.cc.dopts.dc, - transport: t, + callHdr: callHdr, + ac: ac, + ctx: ctx, + cancel: cancel, + opts: opts, + callInfo: c, + desc: desc, + codec: c.codec, + sendCompressorV0: cp, + sendCompressorV1: comp, + enableCompression: cp != nil || comp != nil, + decompressorV0: ac.cc.dopts.dc, + transport: t, } // nil stats handler: internal streams like health and ORCA do not support telemetry. @@ -1387,24 +1445,25 @@ func newNonRetryClientStream(ctx context.Context, desc *StreamDesc, method strin } type addrConnStream struct { - transportStream *transport.ClientStream - ac *addrConn - callHdr *transport.CallHdr - cancel context.CancelFunc - opts []CallOption - callInfo *callInfo - transport transport.ClientTransport - ctx context.Context - sentLast bool - receivedFirstMsg bool - desc *StreamDesc - codec baseCodec - sendCompressorV0 Compressor - sendCompressorV1 encoding.Compressor - decompressorSet bool - decompressorV0 Decompressor - decompressorV1 encoding.Compressor - parser parser + transportStream *transport.ClientStream + ac *addrConn + callHdr *transport.CallHdr + cancel context.CancelFunc + opts []CallOption + callInfo *callInfo + transport transport.ClientTransport + ctx context.Context + sentLast bool + receivedFirstMsg bool + desc *StreamDesc + codec baseCodec + sendCompressorV0 Compressor + sendCompressorV1 encoding.Compressor + enableCompression bool + decompressorSet bool + decompressorV0 Decompressor + decompressorV1 encoding.Compressor + parser parser // mu guards finished and is held for the entire finish method. mu sync.Mutex @@ -1461,7 +1520,11 @@ func (as *addrConnStream) SendMsg(m any) (err error) { } // load hdr, payload, data - hdr, data, payload, pf, err := prepareMsg(m, as.codec, as.sendCompressorV0, as.sendCompressorV1, as.ac.dopts.copts.BufferPool) + compV0, compV1 := as.sendCompressorV0, as.sendCompressorV1 + if !as.enableCompression { + compV0, compV1 = nil, nil + } + hdr, data, payload, pf, err := prepareMsg(m, as.codec, compV0, compV1, as.ac.dopts.copts.BufferPool) if err != nil { return err } @@ -1741,7 +1804,11 @@ func (ss *serverStream) SendMsg(m any) (err error) { } // load hdr, payload, data - hdr, data, payload, pf, err := prepareMsg(m, ss.codec, ss.compressorV0, ss.compressorV1, ss.p.bufferPool) + compV0, compV1 := ss.compressorV0, ss.compressorV1 + if !ss.s.IsCompressionEnabled() { + compV0, compV1 = nil, nil + } + hdr, data, payload, pf, err := prepareMsg(m, ss.codec, compV0, compV1, ss.p.bufferPool) if err != nil { return err }