Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions internal/transport/server_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,17 @@ type ServerStream struct {
headerSent atomic.Bool // atomically set when the headers are sent out.

headerWireLength int

doNotCompress bool
Comment thread
arjan-bal marked this conversation as resolved.
Outdated
Comment thread
arjan-bal marked this conversation as resolved.
Outdated
}

// SetDoNotCompress sets whether compression should be disabled for subsequent
// messages sent on this stream.
func (s *ServerStream) SetDoNotCompress(v bool) { s.doNotCompress = v }

// IsDoNotCompress reports whether compression is disabled for this stream.
func (s *ServerStream) IsDoNotCompress() bool { return s.doNotCompress }

// Read reads an n byte message from the input stream.
func (s *ServerStream) Read(n int) (mem.BufferSlice, error) {
b, err := s.Stream.read(n)
Expand Down
6 changes: 5 additions & 1 deletion server.go
Original file line number Diff line number Diff line change
Expand Up @@ -1470,7 +1470,11 @@ func (s *Server) processUnaryRPC(ctx context.Context, stream *transport.ServerSt
if stream.SendCompress() != sendCompressorName {
comp = encoding.GetCompressor(stream.SendCompress())
}
if err := s.sendResponse(ctx, stream, reply, cp, opts, comp); err != nil {
compV0, compV1 := cp, comp
if stream.IsDoNotCompress() {
compV0, compV1 = nil, nil
}
Comment on lines +1473 to +1476
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code is executed before the server handler is invoked. This would mean that stream.IsCompressionEnabled would always return true. Am I missing something here? Do we really need this check here?

And if we really need this check, why is it missing for the streaming case?

if err := s.sendResponse(ctx, stream, reply, compV0, opts, compV1); err != nil {
if err == io.EOF {
// The entire stream is done (for unary RPC only).
return err
Expand Down
63 changes: 60 additions & 3 deletions stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ package grpc
import (
"context"
"errors"
"fmt"
"io"
"math"
rand "math/rand/v2"
Expand Down Expand Up @@ -51,6 +52,47 @@ import (

var metadataFromOutgoingContextRaw = internal.FromOutgoingContextRaw.(func(context.Context) (metadata.MD, [][]string, bool))

type doNotCompressSetter interface {
SetDoNotCompress(bool)
}

type compressKey struct{}

// SetMessageCompression enables or disables per-message compression on a stream
// if a compressor is specified for the stream (e.g. using UseCompressor or
// SetSendCompressor) and if the encoding type is supported by the receiver.
// By default, message compression is enabled, but is a no-op if compression
// is not enabled on the stream.
//
// On the server side, the context provided must be the context passed to the
// server's handler. On the client side, the context provided must be the
// context associated with the stream, obtained via ClientStream.Context().
//
// This method must not be called concurrently with SendMsg.
//
// # Experimental
//
// Notice: This API is EXPERIMENTAL and may be changed or removed in a
// later release.
func SetMessageCompression(ctx context.Context, enable bool) error {
Comment thread
arjan-bal marked this conversation as resolved.
Outdated
// Server side: transport.ServerStream is stored in context via streamKey.
// We use an interface upgrade to avoid importing transport directly here.
Comment thread
arjan-bal marked this conversation as resolved.
Outdated

if sts := ServerTransportStreamFromContext(ctx); sts != nil {
if s, ok := sts.(doNotCompressSetter); ok {
s.SetDoNotCompress(!enable)
return nil
}
}
// Client side: *bool pointing to clientStream.doNotCompress is stored in context.
flag, ok := ctx.Value(compressKey{}).(*bool)
if !ok || flag == nil {
return fmt.Errorf("grpc: SetMessageCompression called on an uninitialized or non-stream context")
Comment thread
arjan-bal marked this conversation as resolved.
Outdated
}
*flag = !enable
return nil
}

// StreamHandler defines the handler called by gRPC server to complete the
// execution of a streaming RPC. srv is the service implementation on which the
// RPC was invoked.
Expand Down Expand Up @@ -366,6 +408,9 @@ func newClientStreamWithParams(ctx context.Context, desc *StreamDesc, cc *Client
onCommit: onCommit,
nameResolutionDelay: nameResolutionDelayed,
}
if compressorV0 != nil || compressorV1 != nil {
cs.ctx = context.WithValue(cs.ctx, compressKey{}, new(bool))
}
Comment thread
arjan-bal marked this conversation as resolved.
Comment thread
arjan-bal marked this conversation as resolved.
if !cc.dopts.disableRetry {
cs.retryThrottler = cc.retryThrottler.Load().(*retryThrottler)
}
Expand Down Expand Up @@ -954,7 +999,11 @@ func (cs *clientStream) SendMsg(m any) (err error) {
}

// load hdr, payload, data
hdr, data, payload, pf, err := prepareMsg(m, cs.codec, cs.compressorV0, cs.compressorV1, cs.cc.dopts.copts.BufferPool)
compV0, compV1 := cs.compressorV0, cs.compressorV1
if flag, ok := cs.ctx.Value(compressKey{}).(*bool); ok && *flag {
compV0, compV1 = nil, nil
}
hdr, data, payload, pf, err := prepareMsg(m, cs.codec, compV0, compV1, cs.cc.dopts.copts.BufferPool)
if err != nil {
return err
}
Expand Down Expand Up @@ -1461,7 +1510,11 @@ func (as *addrConnStream) SendMsg(m any) (err error) {
}

// load hdr, payload, data
hdr, data, payload, pf, err := prepareMsg(m, as.codec, as.sendCompressorV0, as.sendCompressorV1, as.ac.dopts.copts.BufferPool)
compV0, compV1 := as.sendCompressorV0, as.sendCompressorV1
if flag, ok := as.ctx.Value(compressKey{}).(*bool); ok && *flag {
compV0, compV1 = nil, nil
}
hdr, data, payload, pf, err := prepareMsg(m, as.codec, compV0, compV1, as.ac.dopts.copts.BufferPool)
if err != nil {
return err
}
Expand Down Expand Up @@ -1741,7 +1794,11 @@ func (ss *serverStream) SendMsg(m any) (err error) {
}

// load hdr, payload, data
hdr, data, payload, pf, err := prepareMsg(m, ss.codec, ss.compressorV0, ss.compressorV1, ss.p.bufferPool)
compV0, compV1 := ss.compressorV0, ss.compressorV1
if ss.s.IsDoNotCompress() {
compV0, compV1 = nil, nil
}
hdr, data, payload, pf, err := prepareMsg(m, ss.codec, compV0, compV1, ss.p.bufferPool)
if err != nil {
return err
}
Expand Down
139 changes: 139 additions & 0 deletions test/compressor_test.go
Comment thread
arjan-bal marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"io"
"reflect"
"strings"
"sync/atomic"
"testing"

"google.golang.org/grpc"
Expand All @@ -34,6 +35,7 @@ import (
"google.golang.org/grpc/internal/grpcutil"
"google.golang.org/grpc/internal/stubserver"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/stats"
"google.golang.org/grpc/status"

testgrpc "google.golang.org/grpc/interop/grpc_testing"
Expand Down Expand Up @@ -692,3 +694,140 @@ func (s) TestGzipBadChecksum(t *testing.T) {
t.Errorf("ss.Client.UnaryCall(_) = _, %v\n\twant: _, status(codes.Internal, contains %q)", err, gzip.ErrChecksum)
}
}

// statsHandler is a stats.Handler that counts the number of compressed
// outbound and inbound messages by comparing CompressedLength to Length.
type statsHandler struct {
Comment thread
Dostonlv marked this conversation as resolved.
Outdated
stats.Handler
compress atomic.Int32
decompress atomic.Int32
}

func (h *statsHandler) TagRPC(ctx context.Context, _ *stats.RPCTagInfo) context.Context { return ctx }
func (h *statsHandler) TagConn(ctx context.Context, _ *stats.ConnTagInfo) context.Context { return ctx }
func (h *statsHandler) HandleConn(context.Context, stats.ConnStats) {}
func (h *statsHandler) HandleRPC(_ context.Context, s stats.RPCStats) {
switch st := s.(type) {
case *stats.OutPayload:
if st.CompressedLength < st.Length {
h.compress.Add(1)
}
case *stats.InPayload:
if st.CompressedLength < st.Length {
h.decompress.Add(1)
}
}
}

// TestMessageCompression_StreamToggle tests that SetMessageCompression
// correctly enables and disables per-message compression mid-stream on both
// the client and server side.
func (s) TestMessageCompression_StreamToggle(t *testing.T) {
sh := &statsHandler{}
ss := &stubserver.StubServer{
FullDuplexCallF: func(stream testgrpc.TestService_FullDuplexCallServer) error {
if _, err := stream.Recv(); err != nil {
return err
}
if err := stream.Send(&testpb.StreamingOutputCallResponse{
Payload: &testpb.Payload{Body: make([]byte, 1000)},
}); err != nil {
return err
}
if _, err := stream.Recv(); err != nil {
return err
}
if err := grpc.SetMessageCompression(stream.Context(), false); err != nil {
return err
}
if err := stream.Send(&testpb.StreamingOutputCallResponse{
Payload: &testpb.Payload{Body: make([]byte, 1000)},
}); err != nil {
return err
}
if _, err := stream.Recv(); err != nil {
return err
}
if err := grpc.SetMessageCompression(stream.Context(), true); err != nil {
return err
}
return stream.Send(&testpb.StreamingOutputCallResponse{
Payload: &testpb.Payload{Body: make([]byte, 1000)},
})
},
}

if err := ss.Start(nil, grpc.WithStatsHandler(sh)); err != nil {
t.Fatalf("Error starting server: %v", err)
}
defer ss.Stop()

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()

stream, err := ss.Client.FullDuplexCall(ctx, grpc.UseCompressor("gzip"))
if err != nil {
t.Fatalf("FullDuplexCall failed: %v", err)
}

// 1. Send first compressed message
stream.Send(&testpb.StreamingOutputCallRequest{Payload: &testpb.Payload{Body: make([]byte, 1000)}})
stream.Recv()
if sh.compress.Load() != 1 || sh.decompress.Load() != 1 {
t.Fatalf("After call 1 (compression enabled): got compress=%d decompress=%d, want compress=1 decompress=1",
sh.compress.Load(), sh.decompress.Load())
}

// 2. Disable message compression and send second message
grpc.SetMessageCompression(stream.Context(), false)
stream.Send(&testpb.StreamingOutputCallRequest{Payload: &testpb.Payload{Body: make([]byte, 1000)}})
stream.Recv()
if sh.compress.Load() != 1 || sh.decompress.Load() != 1 {
t.Fatalf("After call 2 (compression disabled): got compress=%d decompress=%d, want compress=1 decompress=1",
sh.compress.Load(), sh.decompress.Load())
}

// 3. Enable message compression and send third message
grpc.SetMessageCompression(stream.Context(), true)
stream.Send(&testpb.StreamingOutputCallRequest{Payload: &testpb.Payload{Body: make([]byte, 1000)}})
stream.Recv()
if sh.compress.Load() != 2 || sh.decompress.Load() != 2 {
t.Fatalf("After call 3 (compression re-enabled): got compress=%d decompress=%d, want compress=2 decompress=2",
sh.compress.Load(), sh.decompress.Load())
}
}

func (s) TestMessageCompression_Unary(t *testing.T) {
sh := &statsHandler{}
ss := &stubserver.StubServer{
UnaryCallF: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
grpc.SetSendCompressor(ctx, "gzip")
if in.ResponseSize == 0 {
grpc.SetMessageCompression(ctx, false)
}
return &testpb.SimpleResponse{Payload: &testpb.Payload{Body: make([]byte, 10000)}}, nil
},
}

if err := ss.Start(nil, grpc.WithStatsHandler(sh)); err != nil {
t.Fatalf("Error starting server: %v", err)
}
defer ss.Stop()

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()

// Call 1: Compression ON
ss.Client.UnaryCall(ctx, &testpb.SimpleRequest{ResponseSize: 1, Payload: &testpb.Payload{Body: make([]byte, 1000)}}, grpc.UseCompressor("gzip"))
if sh.compress.Load() != 1 || sh.decompress.Load() != 1 {
t.Fatalf("After call 1 (compression enabled): got compress=%d decompress=%d, want compress=1 decompress=1",
sh.compress.Load(), sh.decompress.Load())
}

// Call 2: Compression OFF (for response, but request is still compressed by UseCompressor)
ss.Client.UnaryCall(ctx, &testpb.SimpleRequest{ResponseSize: 0, Payload: &testpb.Payload{Body: make([]byte, 1000)}}, grpc.UseCompressor("gzip"))
if sh.compress.Load() != 2 || sh.decompress.Load() != 1 {
t.Fatalf("After call 2 (server response compression disabled): got compress=%d decompress=%d, want compress=2 decompress=1",
sh.compress.Load(), sh.decompress.Load())
}
}
Loading