Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
181 changes: 176 additions & 5 deletions encoding/compressor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"bytes"
"context"
"io"
"sync"
"sync/atomic"
"testing"

Expand All @@ -38,16 +39,21 @@ import (
_ "google.golang.org/grpc/encoding/gzip"
)

// wrapCompressor is a wrapper of encoding.Compressor which maintains count of
// Compressor method invokes.
// wrapCompressor is a wrapper of encoding.Compressor which records invocation
// count and the options passed to each Compress call.
type wrapCompressor struct {
encoding.Compressor
compressInvokes int32
mu sync.Mutex
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's use the mutex to protect compressInvokes too. Unifying the synchronization mechanisms will make the code much simpler.

Additionally, please use empty lines to group the mutex-guarded fields, similar to this example:

// mu protects the following fields and all fields within balancerCurrent
// and balancerPending. mu does not need to be held when calling into the
// child balancers, as all calls into these children happen only as a direct
// result of a call into the gracefulSwitchBalancer, which are also
// guaranteed to be synchronous. There is one exception: an UpdateState call
// from a child balancer when current and pending are populated can lead to
// calling Close() on the current. To prevent that racing with an
// UpdateSubConnState from the channel, we hold currentMu during Close and
// UpdateSubConnState calls.
mu sync.Mutex
balancerCurrent *balancerWrapper
balancerPending *balancerWrapper
closed bool // set to true when this balancer is closed

receivedOpts [][]any
}

func (wc *wrapCompressor) Compress(w io.Writer) (io.WriteCloser, error) {
func (wc *wrapCompressor) Compress(w io.Writer, opts ...any) (io.WriteCloser, error) {
atomic.AddInt32(&wc.compressInvokes, 1)
return wc.Compressor.Compress(w)
wc.mu.Lock()
wc.receivedOpts = append(wc.receivedOpts, opts)
wc.mu.Unlock()
return wc.Compressor.Compress(w, opts...)
}

func setupGzipWrapCompressor(t *testing.T) *wrapCompressor {
Expand Down Expand Up @@ -186,7 +192,7 @@ type fakeCompressor struct {
decompressedMessageSize int
}

func (f *fakeCompressor) Compress(w io.Writer) (io.WriteCloser, error) {
func (f *fakeCompressor) Compress(w io.Writer, _ ...any) (io.WriteCloser, error) {
return nopWriteCloser{w}, nil
}

Expand Down Expand Up @@ -237,3 +243,168 @@ func (s) TestDecompressionExceedsMaxMessageSize(t *testing.T) {
t.Errorf("Client.UnaryCall(%+v) returned status %v, want %v", req, got, want)
}
}

// TestSetSendCompressorOptionsPropagate verifies that options passed to
// SetSendCompressor are forwarded to the compressor's Compress method.
func (s) TestSetSendCompressorOptionsPropagate(t *testing.T) {
wantOpt := "dict-id-42"
for _, tc := range []struct {
name string
run func(*testing.T, *wrapCompressor)
}{
{"unary", testUnarySendCompressorOptionsPropagate},
{"stream", testStreamSendCompressorOptionsPropagate},
Comment on lines +255 to +256
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

style: The helper seems to be abstracting most of the test logic here. I would recommend avoiding the table driven approach and have separate tests for unary and streaming RPCs, inlining the helpers. Same for the TestUseCompressorOptionsPropagate test.

See https://google.github.io/styleguide/go/best-practices#leave-testing-to-the-test-function

} {
t.Run(tc.name, func(t *testing.T) {
wc := setupGzipWrapCompressor(t)
tc.run(t, wc)
wc.mu.Lock()
defer wc.mu.Unlock()
if len(wc.receivedOpts) == 0 {
t.Fatal("Compress was not called")
}
if got := wc.receivedOpts[0]; len(got) == 0 || got[0] != wantOpt {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: You can use cmp.Diff to compare the entire slice.

t.Fatalf("Compress received opts %v, want [%q]", got, wantOpt)
}
})
}
}

func testUnarySendCompressorOptionsPropagate(t *testing.T, _ *wrapCompressor) {
t.Helper()
ss := &stubserver.StubServer{
UnaryCallF: func(ctx context.Context, _ *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
if err := grpc.SetSendCompressor(ctx, "gzip", "dict-id-42"); err != nil {
return nil, err
}
return &testpb.SimpleResponse{Payload: &testpb.Payload{Body: []byte("payload")}}, nil
},
}
if err := ss.Start(nil); err != nil {
t.Fatalf("Error starting endpoint server: %v", err)
}
defer ss.Stop()

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()

if _, err := ss.Client.UnaryCall(ctx, &testpb.SimpleRequest{}); err != nil {
t.Fatalf("Unexpected unary call error: %v", err)
}
}

func testStreamSendCompressorOptionsPropagate(t *testing.T, _ *wrapCompressor) {
t.Helper()
ss := &stubserver.StubServer{
FullDuplexCallF: func(stream testgrpc.TestService_FullDuplexCallServer) error {
if _, err := stream.Recv(); err != nil {
return err
}
if err := grpc.SetSendCompressor(stream.Context(), "gzip", "dict-id-42"); err != nil {
return err
}
return stream.Send(&testpb.StreamingOutputCallResponse{
Payload: &testpb.Payload{Body: []byte("payload")},
})
},
}
if err := ss.Start(nil); err != nil {
t.Fatalf("Error starting endpoint server: %v", err)
}
defer ss.Stop()

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()

s, err := ss.Client.FullDuplexCall(ctx)
if err != nil {
t.Fatalf("Unexpected full duplex call error: %v", err)
}
if err := s.Send(&testpb.StreamingOutputCallRequest{}); err != nil {
t.Fatalf("Unexpected send error: %v", err)
}
if _, err := s.Recv(); err != nil {
t.Fatalf("Unexpected recv error: %v", err)
}
}

// TestUseCompressorOptionsPropagate verifies that options passed to
// UseCompressor are forwarded to the compressor's Compress method.
func (s) TestUseCompressorOptionsPropagate(t *testing.T) {
wantOpt := "dict-id-42"
for _, tc := range []struct {
name string
run func(*testing.T, *wrapCompressor)
}{
{"unary", testUnaryUseCompressorOptionsPropagate},
{"stream", testStreamUseCompressorOptionsPropagate},
} {
t.Run(tc.name, func(t *testing.T) {
wc := setupGzipWrapCompressor(t)
tc.run(t, wc)
wc.mu.Lock()
defer wc.mu.Unlock()
if len(wc.receivedOpts) == 0 {
t.Fatal("Compress was not called")
}
if got := wc.receivedOpts[0]; len(got) == 0 || got[0] != wantOpt {
t.Fatalf("Compress received opts %v, want [%q]", got, wantOpt)
}
})
}
}

func testUnaryUseCompressorOptionsPropagate(t *testing.T, _ *wrapCompressor) {
t.Helper()
ss := &stubserver.StubServer{
UnaryCallF: func(_ context.Context, _ *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
return &testpb.SimpleResponse{Payload: &testpb.Payload{Body: []byte("payload")}}, nil
},
}
if err := ss.Start(nil); err != nil {
t.Fatalf("Error starting endpoint server: %v", err)
}
defer ss.Stop()

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()

if _, err := ss.Client.UnaryCall(ctx, &testpb.SimpleRequest{Payload: &testpb.Payload{Body: []byte("data")}}, grpc.UseCompressor("gzip", "dict-id-42")); err != nil {
t.Fatalf("Unexpected unary call error: %v", err)
}
}

func testStreamUseCompressorOptionsPropagate(t *testing.T, _ *wrapCompressor) {
t.Helper()
ss := &stubserver.StubServer{
FullDuplexCallF: func(stream testgrpc.TestService_FullDuplexCallServer) error {
req, err := stream.Recv()
if err != nil {
return err
}
return stream.Send(&testpb.StreamingOutputCallResponse{
Payload: req.GetPayload(),
})
},
}
if err := ss.Start(nil); err != nil {
t.Fatalf("Error starting endpoint server: %v", err)
}
defer ss.Stop()

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()

s, err := ss.Client.FullDuplexCall(ctx, grpc.UseCompressor("gzip", "dict-id-42"))
if err != nil {
t.Fatalf("Unexpected full duplex call error: %v", err)
}
if err := s.Send(&testpb.StreamingOutputCallRequest{
Payload: &testpb.Payload{Body: []byte("payload")},
}); err != nil {
t.Fatalf("Unexpected send error: %v", err)
}
if _, err := s.Recv(); err != nil {
t.Fatalf("Unexpected recv error: %v", err)
}
}
6 changes: 4 additions & 2 deletions encoding/encoding.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,10 @@ func init() {
type Compressor interface {
// Compress writes the data written to wc to w after compressing it. If an
// error occurs while initializing the compressor, that error is returned
// instead.
Compress(w io.Writer) (io.WriteCloser, error)
// instead. opts passes caller-provided context to the compressor (e.g.
// dictionary IDs for trained compression formats). Unknown options must
// be silently ignored.
Compress(w io.Writer, opts ...any) (io.WriteCloser, error)
// Decompress reads data from r, decompresses it, and provides the
// uncompressed data via the returned io.Reader. If an error occurs while
// initializing the decompressor, that error is returned instead.
Expand Down
2 changes: 1 addition & 1 deletion encoding/gzip/gzip.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func SetLevel(level int) error {
return nil
}

func (c *compressor) Compress(w io.Writer) (io.WriteCloser, error) {
func (c *compressor) Compress(w io.Writer, _ ...any) (io.WriteCloser, error) {
z := c.poolCompressor.Get().(*writer)
z.Writer.Reset(w)
return z, nil
Expand Down
13 changes: 11 additions & 2 deletions internal/transport/server_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ type ServerStream struct {
headerSent atomic.Bool // atomically set when the headers are sent out.

headerWireLength int

sendCompressOptions []any
}

// Read reads an n byte message from the input stream.
Expand Down Expand Up @@ -108,13 +110,20 @@ func (s *ServerStream) ContentSubtype() string {
return s.contentSubtype
}

// SetSendCompress sets the compression algorithm to the stream.
func (s *ServerStream) SetSendCompress(name string) error {
// SendCompressOptions returns the compressor options set for the stream.
func (s *ServerStream) SendCompressOptions() []any {
return s.sendCompressOptions
}

// SetSendCompress sets the compression algorithm to the stream. opts are
// forwarded to the compressor's Compress method on each send.
func (s *ServerStream) SetSendCompress(name string, opts ...any) error {
if s.isHeaderSent() || s.getState() == streamDone {
return errors.New("transport: set send compressor called after headers sent or stream done")
}

s.sendCompress = name
s.sendCompressOptions = opts
return nil
}

Expand Down
17 changes: 11 additions & 6 deletions rpc_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ type callInfo struct {
onFinish []func(err error)
authority string
acceptedResponseCompressors []string
compressorOptions []any
}

func acceptedCompressorAllows(allowed []string, name string) bool {
Expand Down Expand Up @@ -490,14 +491,16 @@ func (o PerRPCCredsCallOption) after(*callInfo, *csAttempt) {}

// UseCompressor returns a CallOption which sets the compressor used when
// sending the request. If WithCompressor is also set, UseCompressor has
// higher priority.
// higher priority. The optional compressorOptions are forwarded to the
// compressor's Compress method, allowing callers to pass additional context
// such as dictionary IDs for trained compression formats.
//
// # Experimental
//
// Notice: This API is EXPERIMENTAL and may be changed or removed in a
// later release.
func UseCompressor(name string) CallOption {
return CompressorCallOption{CompressorType: name}
func UseCompressor(name string, compressorOptions ...any) CallOption {
Comment thread
Pranjali-2501 marked this conversation as resolved.
return CompressorCallOption{CompressorType: name, CompressorOptions: compressorOptions}
}

// CompressorCallOption is a CallOption that indicates the compressor to use.
Expand All @@ -507,11 +510,13 @@ func UseCompressor(name string) CallOption {
// Notice: This type is EXPERIMENTAL and may be changed or removed in a
// later release.
type CompressorCallOption struct {
CompressorType string
CompressorType string
CompressorOptions []any
}

func (o CompressorCallOption) before(c *callInfo) error {
c.compressorName = o.CompressorType
c.compressorOptions = o.CompressorOptions
return nil
}
func (o CompressorCallOption) after(*callInfo, *csAttempt) {}
Expand Down Expand Up @@ -817,7 +822,7 @@ func encode(c baseCodec, msg any) (mem.BufferSlice, error) {
// indicating no compression was done.
//
// TODO(dfawley): eliminate cp parameter by wrapping Compressor in an encoding.Compressor.
func compress(in mem.BufferSlice, cp Compressor, compressor encoding.Compressor, pool mem.BufferPool) (mem.BufferSlice, payloadFormat, error) {
func compress(in mem.BufferSlice, cp Compressor, compressor encoding.Compressor, pool mem.BufferPool, compressorOptions ...any) (mem.BufferSlice, payloadFormat, error) {
if (compressor == nil && cp == nil) || in.Len() == 0 {
return nil, compressionNone, nil
}
Expand All @@ -828,7 +833,7 @@ func compress(in mem.BufferSlice, cp Compressor, compressor encoding.Compressor,
return status.Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error())
}
if compressor != nil {
z, err := compressor.Compress(w)
z, err := compressor.Compress(w, compressorOptions...)
if err != nil {
return nil, 0, wrapErr(err)
}
Expand Down
4 changes: 2 additions & 2 deletions rpc_util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ type testCompressorForRegistry struct {
name string
}

func (c *testCompressorForRegistry) Compress(w io.Writer) (io.WriteCloser, error) {
func (c *testCompressorForRegistry) Compress(w io.Writer, _ ...any) (io.WriteCloser, error) {
return &testWriteCloser{w}, nil
}

Expand Down Expand Up @@ -541,7 +541,7 @@ type mockCompressor struct {
ch chan<- struct{}
}

func (m *mockCompressor) Compress(io.Writer) (io.WriteCloser, error) {
func (m *mockCompressor) Compress(io.Writer, ...any) (io.WriteCloser, error) {
panic("unimplemented")
}

Expand Down
14 changes: 9 additions & 5 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -1172,14 +1172,14 @@ func (s *Server) incrCallsFailed() {
s.channelz.ServerMetrics.CallsFailed.Add(1)
}

func (s *Server) sendResponse(ctx context.Context, stream *transport.ServerStream, msg any, cp Compressor, opts *transport.WriteOptions, comp encoding.Compressor) error {
func (s *Server) sendResponse(ctx context.Context, stream *transport.ServerStream, msg any, cp Compressor, opts *transport.WriteOptions, comp encoding.Compressor, compressorOptions ...any) error {
data, err := encode(s.getCodec(stream.ContentSubtype()), msg)
if err != nil {
channelz.Error(logger, s.channelz, "grpc: server failed to encode response: ", err)
return err
}

compData, pf, err := compress(data, cp, comp, s.opts.bufferPool)
compData, pf, err := compress(data, cp, comp, s.opts.bufferPool, compressorOptions...)
if err != nil {
data.Free()
channelz.Error(logger, s.channelz, "grpc: server failed to compress response: ", err)
Expand Down Expand Up @@ -1474,7 +1474,7 @@ func (s *Server) processUnaryRPC(ctx context.Context, stream *transport.ServerSt
if stream.SendCompress() != sendCompressorName {
comp = encoding.GetCompressor(stream.SendCompress())
}
if err := s.sendResponse(ctx, stream, reply, cp, opts, comp); err != nil {
if err := s.sendResponse(ctx, stream, reply, cp, opts, comp, stream.SendCompressOptions()...); err != nil {
if err == io.EOF {
// The entire stream is done (for unary RPC only).
return err
Expand Down Expand Up @@ -2146,11 +2146,15 @@ func SendHeader(ctx context.Context, md metadata.MD) error {
// It is not safe to call SetSendCompressor concurrently with SendHeader and
// SendMsg.
//
// The optional compressorOptions are forwarded to the compressor's Compress
// method on each SendMsg call, allowing callers to pass additional context
// such as dictionary IDs for trained compression formats.
//
// # Experimental
//
// Notice: This function is EXPERIMENTAL and may be changed or removed in a
// later release.
func SetSendCompressor(ctx context.Context, name string) error {
func SetSendCompressor(ctx context.Context, name string, compressorOptions ...any) error {
Comment thread
Pranjali-2501 marked this conversation as resolved.
stream, ok := ServerTransportStreamFromContext(ctx).(*transport.ServerStream)
if !ok || stream == nil {
return fmt.Errorf("failed to fetch the stream from the given context")
Expand All @@ -2160,7 +2164,7 @@ func SetSendCompressor(ctx context.Context, name string) error {
return fmt.Errorf("unable to set send compressor: %w", err)
}

return stream.SetSendCompress(name)
return stream.SetSendCompress(name, compressorOptions...)
}

// ClientSupportedCompressors returns compressor names advertised by the client
Expand Down
Loading
Loading