Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
251 changes: 251 additions & 0 deletions encoding/compressor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,11 @@ import (

"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/encoding"
"google.golang.org/grpc/encoding/internal"
"google.golang.org/grpc/internal/stubserver"
"google.golang.org/grpc/stats"
"google.golang.org/grpc/status"

testgrpc "google.golang.org/grpc/interop/grpc_testing"
Expand Down Expand Up @@ -237,3 +239,252 @@ func (s) TestDecompressionExceedsMaxMessageSize(t *testing.T) {
t.Errorf("Client.UnaryCall(%+v) returned status %v, want %v", req, got, want)
}
}

// statsHandler is a stats.Handler that counts the number of compressed
// outbound and inbound messages by comparing CompressedLength to Length.
type statsHandler struct {
stats.Handler
compress atomic.Int32
decompress atomic.Int32
}

func (h *statsHandler) TagRPC(ctx context.Context, _ *stats.RPCTagInfo) context.Context { return ctx }
func (h *statsHandler) TagConn(ctx context.Context, _ *stats.ConnTagInfo) context.Context { return ctx }
func (h *statsHandler) HandleConn(context.Context, stats.ConnStats) {}
func (h *statsHandler) HandleRPC(_ context.Context, s stats.RPCStats) {
switch st := s.(type) {
case *stats.OutPayload:
if st.CompressedLength < st.Length {
h.compress.Add(1)
}
case *stats.InPayload:
if st.CompressedLength < st.Length {
h.decompress.Add(1)
}
}
}

// TestMessageCompression_StreamToggle tests that SetServerStreamMessageCompression
// and SetClientStreamMessageCompression correctly enable and disable per-message
// compression mid-stream on the server and client side respectively.
func (s) TestMessageCompression_StreamToggle(t *testing.T) {
sh := &statsHandler{}
ss := &stubserver.StubServer{
FullDuplexCallF: func(stream testgrpc.TestService_FullDuplexCallServer) error {
if _, err := stream.Recv(); err != nil {
return err
}
if err := stream.Send(&testpb.StreamingOutputCallResponse{
Payload: &testpb.Payload{Body: make([]byte, 1000)},
}); err != nil {
return err
}
if _, err := stream.Recv(); err != nil {
return err
}
if err := grpc.SetServerStreamMessageCompression(stream.Context(), false); err != nil {
return err
}
if err := stream.Send(&testpb.StreamingOutputCallResponse{
Payload: &testpb.Payload{Body: make([]byte, 1000)},
}); err != nil {
return err
}
if _, err := stream.Recv(); err != nil {
return err
}
if err := grpc.SetServerStreamMessageCompression(stream.Context(), true); err != nil {
return err
}
return stream.Send(&testpb.StreamingOutputCallResponse{
Payload: &testpb.Payload{Body: make([]byte, 1000)},
})
},
}

if err := ss.Start(nil, grpc.WithStatsHandler(sh)); err != nil {
t.Fatalf("Error starting server: %v", err)
}
defer ss.Stop()

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()

stream, err := ss.Client.FullDuplexCall(ctx, grpc.UseCompressor("gzip"))
if err != nil {
t.Fatalf("FullDuplexCall failed: %v", err)
}

// 1. Send first compressed message
stream.Send(&testpb.StreamingOutputCallRequest{Payload: &testpb.Payload{Body: make([]byte, 1000)}})
stream.Recv()
if sh.compress.Load() != 1 || sh.decompress.Load() != 1 {
t.Fatalf("After call 1 (compression enabled): got compress=%d decompress=%d, want compress=1 decompress=1", sh.compress.Load(), sh.decompress.Load())
}

// 2. Disable message compression and send second message
grpc.SetClientStreamMessageCompression(stream, false)
stream.Send(&testpb.StreamingOutputCallRequest{Payload: &testpb.Payload{Body: make([]byte, 1000)}})
stream.Recv()
if sh.compress.Load() != 1 || sh.decompress.Load() != 1 {
t.Fatalf("After call 2 (compression disabled): got compress=%d decompress=%d, want compress=1 decompress=1", sh.compress.Load(), sh.decompress.Load())
}

// 3. Enable message compression and send third message
grpc.SetClientStreamMessageCompression(stream, true)
stream.Send(&testpb.StreamingOutputCallRequest{Payload: &testpb.Payload{Body: make([]byte, 1000)}})
stream.Recv()
if sh.compress.Load() != 2 || sh.decompress.Load() != 2 {
t.Fatalf("After call 3 (compression re-enabled): got compress=%d decompress=%d, want compress=2 decompress=2", sh.compress.Load(), sh.decompress.Load())
}
}

// TestMessageCompression_AmbiguousContext verifies that
// SetServerStreamMessageCompression and SetClientStreamMessageCompression work
// independently when a server handler propagates its context to an outbound
// gRPC call for deadline propagation.
func (s) TestMessageCompression_AmbiguousContext(t *testing.T) {
backendSH := &statsHandler{}
backend := &stubserver.StubServer{
FullDuplexCallF: func(stream testgrpc.TestService_FullDuplexCallServer) error {
for {
if _, err := stream.Recv(); err == io.EOF {
return nil
} else if err != nil {
return err
}
if err := stream.Send(&testpb.StreamingOutputCallResponse{
Payload: &testpb.Payload{Body: make([]byte, 1000)},
}); err != nil {
return err
}
}
},
}
if err := backend.StartServer(grpc.StatsHandler(backendSH)); err != nil {
t.Fatalf("Error starting backend: %v", err)
}
defer backend.Stop()

// errCh carries any errors from the two SetXxx calls inside the proxy handler.
errCh := make(chan error, 2)
proxy := &stubserver.StubServer{
FullDuplexCallF: func(serverStream testgrpc.TestService_FullDuplexCallServer) error {
// Use the server handler's context as the parent for the outbound
// call. This is the standard deadline-propagation pattern.
backendConn, err := grpc.NewClient(backend.Address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return err
}
defer backendConn.Close()

clientStream, err := testgrpc.NewTestServiceClient(backendConn).FullDuplexCall(serverStream.Context(), grpc.UseCompressor("gzip"))
if err != nil {
return err
}

// Must only affect the server stream (proxy → original caller).
errCh <- grpc.SetServerStreamMessageCompression(clientStream.Context(), false)
// Must only affect the client stream (proxy → backend).
errCh <- grpc.SetClientStreamMessageCompression(clientStream, false)

// Forward one request/response pair through the proxy.
req, err := serverStream.Recv()
if err != nil {
return err
}
if err := clientStream.Send(req); err != nil {
return err
}
if err := clientStream.CloseSend(); err != nil {
return err
}
resp, err := clientStream.Recv()
if err != nil {
return err
}
return serverStream.Send(resp)
},
}
proxySH := &statsHandler{}
if err := proxy.Start([]grpc.ServerOption{grpc.StatsHandler(proxySH)}); err != nil {
t.Fatalf("Error starting proxy: %v", err)
}
defer proxy.Stop()

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()

stream, err := proxy.Client.FullDuplexCall(ctx, grpc.UseCompressor("gzip"))
if err != nil {
t.Fatalf("FullDuplexCall failed: %v", err)
}
if err := stream.Send(&testpb.StreamingOutputCallRequest{
Payload: &testpb.Payload{Body: make([]byte, 1000)},
}); err != nil {
t.Fatalf("Send failed: %v", err)
}
if _, err := stream.Recv(); err != nil {
t.Fatalf("Recv failed: %v", err)
}

// Collect results from both SetXxx calls.
for _, name := range []string{"SetServerStreamMessageCompression", "SetClientStreamMessageCompression"} {
select {
case err := <-errCh:
if err != nil {
t.Fatalf("%s on ambiguous context returned unexpected error: %v", name, err)
}
case <-ctx.Done():
t.Fatalf("timed out waiting for %s result", name)
}
}

// SetServerStreamMessageCompression disabled compression on the
// proxy → caller direction: the proxy's server stream must not have
// sent any compressed messages.
if got := proxySH.compress.Load(); got != 0 {
t.Errorf("proxy server outbound compress count = %d, want 0 (SetServerStreamMessageCompression disabled it)", got)
}
// SetClientStreamMessageCompression disabled compression on the
// proxy → backend direction: the backend must not have received any
// compressed messages.
if got := backendSH.decompress.Load(); got != 0 {
t.Errorf("backend inbound decompress count = %d, want 0 (SetClientStreamMessageCompression disabled it)", got)
}
}

func (s) TestMessageCompression_Unary(t *testing.T) {
sh := &statsHandler{}
ss := &stubserver.StubServer{
UnaryCallF: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
grpc.SetSendCompressor(ctx, "gzip")
if in.ResponseSize == 0 {
grpc.SetServerStreamMessageCompression(ctx, false)
}
return &testpb.SimpleResponse{Payload: &testpb.Payload{Body: make([]byte, 10000)}}, nil
},
}

if err := ss.Start(nil, grpc.WithStatsHandler(sh)); err != nil {
t.Fatalf("Error starting server: %v", err)
}
defer ss.Stop()

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()

// Call 1: Compression ON
ss.Client.UnaryCall(ctx, &testpb.SimpleRequest{ResponseSize: 1, Payload: &testpb.Payload{Body: make([]byte, 1000)}}, grpc.UseCompressor("gzip"))
if sh.compress.Load() != 1 || sh.decompress.Load() != 1 {
t.Fatalf("After call 1 (compression enabled): got compress=%d decompress=%d, want compress=1 decompress=1",
sh.compress.Load(), sh.decompress.Load())
}

// Call 2: Compression OFF (for response, but request is still compressed by UseCompressor)
ss.Client.UnaryCall(ctx, &testpb.SimpleRequest{ResponseSize: 0, Payload: &testpb.Payload{Body: make([]byte, 1000)}}, grpc.UseCompressor("gzip"))
if sh.compress.Load() != 2 || sh.decompress.Load() != 1 {
t.Fatalf("After call 2 (server response compression disabled): got compress=%d decompress=%d, want compress=2 decompress=1",
sh.compress.Load(), sh.decompress.Load())
}
}
6 changes: 6 additions & 0 deletions internal/internal.go
Original file line number Diff line number Diff line change
Expand Up @@ -298,3 +298,9 @@ type Timer interface {
type EnforceMetricsRecorderEmbedding interface {
enforceMetricsRecorderEmbedding()
}

// EnforceServerTransportStreamEmbedding is used to enforce proper
// ServerTransportStream implementation embedding.
type EnforceServerTransportStreamEmbedding interface {
enforceServerTransportStreamEmbedding()
}
1 change: 1 addition & 0 deletions internal/transport/handler_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,7 @@ func (ht *serverHandlerTransport) HandleStreams(ctx context.Context, startStream
st: ht,
headerWireLength: 0, // won't have access to header wire length until golang/go#18997.
}
s.enableCompression.Store(true) // compression is enabled by default
s.Stream.buf.init()
s.readRequester = s
s.trReader = transportReader{
Expand Down
1 change: 1 addition & 0 deletions internal/transport/http2_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,7 @@ func (t *http2Server) operateHeaders(ctx context.Context, frame *http2.MetaHeade
st: t,
headerWireLength: int(frame.Header().Length),
}
s.enableCompression.Store(true) // compression is enabled by default
s.Stream.buf.init()
var (
// if false, content-type was missing or invalid
Expand Down
14 changes: 14 additions & 0 deletions internal/transport/server_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,15 @@ import (
"sync"
"sync/atomic"

"google.golang.org/grpc/internal"
"google.golang.org/grpc/mem"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)

// ServerStream implements streaming functionality for a gRPC server.
type ServerStream struct {
internal.EnforceServerTransportStreamEmbedding
Comment thread
easwars marked this conversation as resolved.
Stream // Embed for common stream functionality.

st internalServerTransport
Expand All @@ -50,8 +52,20 @@ type ServerStream struct {
headerSent atomic.Bool // atomically set when the headers are sent out.

headerWireLength int

// enableCompression controls whether per-message compression is enabled for
// this stream.
enableCompression atomic.Bool
}

// EnableCompression sets whether per-message compression is enabled for
// subsequent messages sent on this stream.
func (s *ServerStream) EnableCompression(v bool) { s.enableCompression.Store(v) }

// IsCompressionEnabled reports whether per-message compression is enabled for
// this stream.
func (s *ServerStream) IsCompressionEnabled() bool { return s.enableCompression.Load() }

// Read reads an n byte message from the input stream.
func (s *ServerStream) Read(n int) (mem.BufferSlice, error) {
b, err := s.Stream.read(n)
Expand Down
3 changes: 3 additions & 0 deletions internal/xds/rbac/rbac_engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1805,6 +1805,7 @@ func (s) TestChainEngine(t *testing.T) {
}

type ServerTransportStreamWithMethod struct {
grpc.ServerTransportStream
method string
}

Expand All @@ -1824,6 +1825,8 @@ func (sts *ServerTransportStreamWithMethod) SetTrailer(metadata.MD) error {
return nil
}

func (sts *ServerTransportStreamWithMethod) EnableCompression(bool) {}

// An audit logger that will log to the auditEvents slice.
type TestAuditLoggerBuffer struct {
auditEvents *[]*audit.Event
Expand Down
10 changes: 9 additions & 1 deletion server.go
Original file line number Diff line number Diff line change
Expand Up @@ -1470,7 +1470,11 @@ func (s *Server) processUnaryRPC(ctx context.Context, stream *transport.ServerSt
if stream.SendCompress() != sendCompressorName {
comp = encoding.GetCompressor(stream.SendCompress())
}
if err := s.sendResponse(ctx, stream, reply, cp, opts, comp); err != nil {
compV0, compV1 := cp, comp
if !stream.IsCompressionEnabled() {
compV0, compV1 = nil, nil
}
Comment on lines +1473 to +1476
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code is executed before the server handler is invoked. This would mean that stream.IsCompressionEnabled would always return true. Am I missing something here? Do we really need this check here?

And if we really need this check, why is it missing for the streaming case?

if err := s.sendResponse(ctx, stream, reply, compV0, opts, compV1); err != nil {
if err == io.EOF {
// The entire stream is done (for unary RPC only).
return err
Expand Down Expand Up @@ -1888,10 +1892,14 @@ func NewContextWithServerTransportStream(ctx context.Context, stream ServerTrans
// Notice: This type is EXPERIMENTAL and may be changed or removed in a
// later release.
type ServerTransportStream interface {
internal.EnforceServerTransportStreamEmbedding
Method() string
SetHeader(md metadata.MD) error
SendHeader(md metadata.MD) error
SetTrailer(md metadata.MD) error
// EnableCompression controls whether per-message compression is enabled
// for subsequent messages sent on this stream.
EnableCompression(enable bool)
}

// ServerTransportStreamFromContext returns the ServerTransportStream saved in
Expand Down
Loading
Loading