From 6d653799accf0787efb440642baba1f918db4e8c Mon Sep 17 00:00:00 2001 From: dostonlv Date: Thu, 2 Apr 2026 23:38:54 +0500 Subject: [PATCH] encoding: add support for compressor options in Compress methods --- encoding/compressor_test.go | 181 +++++++++++++++++++++++++++- encoding/encoding.go | 6 +- encoding/gzip/gzip.go | 2 +- internal/transport/server_stream.go | 13 +- rpc_util.go | 17 ++- rpc_util_test.go | 4 +- server.go | 14 ++- stream.go | 17 ++- 8 files changed, 225 insertions(+), 29 deletions(-) diff --git a/encoding/compressor_test.go b/encoding/compressor_test.go index 18260ae37078..86e1657d2c81 100644 --- a/encoding/compressor_test.go +++ b/encoding/compressor_test.go @@ -22,6 +22,7 @@ import ( "bytes" "context" "io" + "sync" "sync/atomic" "testing" @@ -38,16 +39,21 @@ import ( _ "google.golang.org/grpc/encoding/gzip" ) -// wrapCompressor is a wrapper of encoding.Compressor which maintains count of -// Compressor method invokes. +// wrapCompressor is a wrapper of encoding.Compressor which records invocation +// count and the options passed to each Compress call. type wrapCompressor struct { encoding.Compressor compressInvokes int32 + mu sync.Mutex + receivedOpts [][]any } -func (wc *wrapCompressor) Compress(w io.Writer) (io.WriteCloser, error) { +func (wc *wrapCompressor) Compress(w io.Writer, opts ...any) (io.WriteCloser, error) { atomic.AddInt32(&wc.compressInvokes, 1) - return wc.Compressor.Compress(w) + wc.mu.Lock() + wc.receivedOpts = append(wc.receivedOpts, opts) + wc.mu.Unlock() + return wc.Compressor.Compress(w, opts...) } func setupGzipWrapCompressor(t *testing.T) *wrapCompressor { @@ -186,7 +192,7 @@ type fakeCompressor struct { decompressedMessageSize int } -func (f *fakeCompressor) Compress(w io.Writer) (io.WriteCloser, error) { +func (f *fakeCompressor) Compress(w io.Writer, _ ...any) (io.WriteCloser, error) { return nopWriteCloser{w}, nil } @@ -237,3 +243,168 @@ func (s) TestDecompressionExceedsMaxMessageSize(t *testing.T) { t.Errorf("Client.UnaryCall(%+v) returned status %v, want %v", req, got, want) } } + +// TestSetSendCompressorOptionsPropagate verifies that options passed to +// SetSendCompressor are forwarded to the compressor's Compress method. +func (s) TestSetSendCompressorOptionsPropagate(t *testing.T) { + wantOpt := "dict-id-42" + for _, tc := range []struct { + name string + run func(*testing.T, *wrapCompressor) + }{ + {"unary", testUnarySendCompressorOptionsPropagate}, + {"stream", testStreamSendCompressorOptionsPropagate}, + } { + t.Run(tc.name, func(t *testing.T) { + wc := setupGzipWrapCompressor(t) + tc.run(t, wc) + wc.mu.Lock() + defer wc.mu.Unlock() + if len(wc.receivedOpts) == 0 { + t.Fatal("Compress was not called") + } + if got := wc.receivedOpts[0]; len(got) == 0 || got[0] != wantOpt { + t.Fatalf("Compress received opts %v, want [%q]", got, wantOpt) + } + }) + } +} + +func testUnarySendCompressorOptionsPropagate(t *testing.T, _ *wrapCompressor) { + t.Helper() + ss := &stubserver.StubServer{ + UnaryCallF: func(ctx context.Context, _ *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { + if err := grpc.SetSendCompressor(ctx, "gzip", "dict-id-42"); err != nil { + return nil, err + } + return &testpb.SimpleResponse{Payload: &testpb.Payload{Body: []byte("payload")}}, nil + }, + } + if err := ss.Start(nil); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + if _, err := ss.Client.UnaryCall(ctx, &testpb.SimpleRequest{}); err != nil { + t.Fatalf("Unexpected unary call error: %v", err) + } +} + +func testStreamSendCompressorOptionsPropagate(t *testing.T, _ *wrapCompressor) { + t.Helper() + ss := &stubserver.StubServer{ + FullDuplexCallF: func(stream testgrpc.TestService_FullDuplexCallServer) error { + if _, err := stream.Recv(); err != nil { + return err + } + if err := grpc.SetSendCompressor(stream.Context(), "gzip", "dict-id-42"); err != nil { + return err + } + return stream.Send(&testpb.StreamingOutputCallResponse{ + Payload: &testpb.Payload{Body: []byte("payload")}, + }) + }, + } + if err := ss.Start(nil); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + s, err := ss.Client.FullDuplexCall(ctx) + if err != nil { + t.Fatalf("Unexpected full duplex call error: %v", err) + } + if err := s.Send(&testpb.StreamingOutputCallRequest{}); err != nil { + t.Fatalf("Unexpected send error: %v", err) + } + if _, err := s.Recv(); err != nil { + t.Fatalf("Unexpected recv error: %v", err) + } +} + +// TestUseCompressorOptionsPropagate verifies that options passed to +// UseCompressor are forwarded to the compressor's Compress method. +func (s) TestUseCompressorOptionsPropagate(t *testing.T) { + wantOpt := "dict-id-42" + for _, tc := range []struct { + name string + run func(*testing.T, *wrapCompressor) + }{ + {"unary", testUnaryUseCompressorOptionsPropagate}, + {"stream", testStreamUseCompressorOptionsPropagate}, + } { + t.Run(tc.name, func(t *testing.T) { + wc := setupGzipWrapCompressor(t) + tc.run(t, wc) + wc.mu.Lock() + defer wc.mu.Unlock() + if len(wc.receivedOpts) == 0 { + t.Fatal("Compress was not called") + } + if got := wc.receivedOpts[0]; len(got) == 0 || got[0] != wantOpt { + t.Fatalf("Compress received opts %v, want [%q]", got, wantOpt) + } + }) + } +} + +func testUnaryUseCompressorOptionsPropagate(t *testing.T, _ *wrapCompressor) { + t.Helper() + ss := &stubserver.StubServer{ + UnaryCallF: func(_ context.Context, _ *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { + return &testpb.SimpleResponse{Payload: &testpb.Payload{Body: []byte("payload")}}, nil + }, + } + if err := ss.Start(nil); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + if _, err := ss.Client.UnaryCall(ctx, &testpb.SimpleRequest{Payload: &testpb.Payload{Body: []byte("data")}}, grpc.UseCompressor("gzip", "dict-id-42")); err != nil { + t.Fatalf("Unexpected unary call error: %v", err) + } +} + +func testStreamUseCompressorOptionsPropagate(t *testing.T, _ *wrapCompressor) { + t.Helper() + ss := &stubserver.StubServer{ + FullDuplexCallF: func(stream testgrpc.TestService_FullDuplexCallServer) error { + req, err := stream.Recv() + if err != nil { + return err + } + return stream.Send(&testpb.StreamingOutputCallResponse{ + Payload: req.GetPayload(), + }) + }, + } + if err := ss.Start(nil); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + s, err := ss.Client.FullDuplexCall(ctx, grpc.UseCompressor("gzip", "dict-id-42")) + if err != nil { + t.Fatalf("Unexpected full duplex call error: %v", err) + } + if err := s.Send(&testpb.StreamingOutputCallRequest{ + Payload: &testpb.Payload{Body: []byte("payload")}, + }); err != nil { + t.Fatalf("Unexpected send error: %v", err) + } + if _, err := s.Recv(); err != nil { + t.Fatalf("Unexpected recv error: %v", err) + } +} diff --git a/encoding/encoding.go b/encoding/encoding.go index 296f38c3a804..fd5a39f4223b 100644 --- a/encoding/encoding.go +++ b/encoding/encoding.go @@ -61,8 +61,10 @@ func init() { type Compressor interface { // Compress writes the data written to wc to w after compressing it. If an // error occurs while initializing the compressor, that error is returned - // instead. - Compress(w io.Writer) (io.WriteCloser, error) + // instead. opts passes caller-provided context to the compressor (e.g. + // dictionary IDs for trained compression formats). Unknown options must + // be silently ignored. + Compress(w io.Writer, opts ...any) (io.WriteCloser, error) // Decompress reads data from r, decompresses it, and provides the // uncompressed data via the returned io.Reader. If an error occurs while // initializing the decompressor, that error is returned instead. diff --git a/encoding/gzip/gzip.go b/encoding/gzip/gzip.go index 153e4dbfbf7a..cf99fe4076a3 100644 --- a/encoding/gzip/gzip.go +++ b/encoding/gzip/gzip.go @@ -70,7 +70,7 @@ func SetLevel(level int) error { return nil } -func (c *compressor) Compress(w io.Writer) (io.WriteCloser, error) { +func (c *compressor) Compress(w io.Writer, _ ...any) (io.WriteCloser, error) { z := c.poolCompressor.Get().(*writer) z.Writer.Reset(w) return z, nil diff --git a/internal/transport/server_stream.go b/internal/transport/server_stream.go index ed6a13b7501a..bb496f7cc651 100644 --- a/internal/transport/server_stream.go +++ b/internal/transport/server_stream.go @@ -50,6 +50,8 @@ type ServerStream struct { headerSent atomic.Bool // atomically set when the headers are sent out. headerWireLength int + + sendCompressOptions []any } // Read reads an n byte message from the input stream. @@ -108,13 +110,20 @@ func (s *ServerStream) ContentSubtype() string { return s.contentSubtype } -// SetSendCompress sets the compression algorithm to the stream. -func (s *ServerStream) SetSendCompress(name string) error { +// SendCompressOptions returns the compressor options set for the stream. +func (s *ServerStream) SendCompressOptions() []any { + return s.sendCompressOptions +} + +// SetSendCompress sets the compression algorithm to the stream. opts are +// forwarded to the compressor's Compress method on each send. +func (s *ServerStream) SetSendCompress(name string, opts ...any) error { if s.isHeaderSent() || s.getState() == streamDone { return errors.New("transport: set send compressor called after headers sent or stream done") } s.sendCompress = name + s.sendCompressOptions = opts return nil } diff --git a/rpc_util.go b/rpc_util.go index ee7f7dead1a3..bb2b5c79f007 100644 --- a/rpc_util.go +++ b/rpc_util.go @@ -168,6 +168,7 @@ type callInfo struct { onFinish []func(err error) authority string acceptedResponseCompressors []string + compressorOptions []any } func acceptedCompressorAllows(allowed []string, name string) bool { @@ -490,14 +491,16 @@ func (o PerRPCCredsCallOption) after(*callInfo, *csAttempt) {} // UseCompressor returns a CallOption which sets the compressor used when // sending the request. If WithCompressor is also set, UseCompressor has -// higher priority. +// higher priority. The optional compressorOptions are forwarded to the +// compressor's Compress method, allowing callers to pass additional context +// such as dictionary IDs for trained compression formats. // // # Experimental // // Notice: This API is EXPERIMENTAL and may be changed or removed in a // later release. -func UseCompressor(name string) CallOption { - return CompressorCallOption{CompressorType: name} +func UseCompressor(name string, compressorOptions ...any) CallOption { + return CompressorCallOption{CompressorType: name, CompressorOptions: compressorOptions} } // CompressorCallOption is a CallOption that indicates the compressor to use. @@ -507,11 +510,13 @@ func UseCompressor(name string) CallOption { // Notice: This type is EXPERIMENTAL and may be changed or removed in a // later release. type CompressorCallOption struct { - CompressorType string + CompressorType string + CompressorOptions []any } func (o CompressorCallOption) before(c *callInfo) error { c.compressorName = o.CompressorType + c.compressorOptions = o.CompressorOptions return nil } func (o CompressorCallOption) after(*callInfo, *csAttempt) {} @@ -817,7 +822,7 @@ func encode(c baseCodec, msg any) (mem.BufferSlice, error) { // indicating no compression was done. // // TODO(dfawley): eliminate cp parameter by wrapping Compressor in an encoding.Compressor. -func compress(in mem.BufferSlice, cp Compressor, compressor encoding.Compressor, pool mem.BufferPool) (mem.BufferSlice, payloadFormat, error) { +func compress(in mem.BufferSlice, cp Compressor, compressor encoding.Compressor, pool mem.BufferPool, compressorOptions ...any) (mem.BufferSlice, payloadFormat, error) { if (compressor == nil && cp == nil) || in.Len() == 0 { return nil, compressionNone, nil } @@ -828,7 +833,7 @@ func compress(in mem.BufferSlice, cp Compressor, compressor encoding.Compressor, return status.Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error()) } if compressor != nil { - z, err := compressor.Compress(w) + z, err := compressor.Compress(w, compressorOptions...) if err != nil { return nil, 0, wrapErr(err) } diff --git a/rpc_util_test.go b/rpc_util_test.go index 79628d1be1d1..31dfdf66f1e6 100644 --- a/rpc_util_test.go +++ b/rpc_util_test.go @@ -52,7 +52,7 @@ type testCompressorForRegistry struct { name string } -func (c *testCompressorForRegistry) Compress(w io.Writer) (io.WriteCloser, error) { +func (c *testCompressorForRegistry) Compress(w io.Writer, _ ...any) (io.WriteCloser, error) { return &testWriteCloser{w}, nil } @@ -541,7 +541,7 @@ type mockCompressor struct { ch chan<- struct{} } -func (m *mockCompressor) Compress(io.Writer) (io.WriteCloser, error) { +func (m *mockCompressor) Compress(io.Writer, ...any) (io.WriteCloser, error) { panic("unimplemented") } diff --git a/server.go b/server.go index 5229adf71174..2b4ea18a03bb 100644 --- a/server.go +++ b/server.go @@ -1172,14 +1172,14 @@ func (s *Server) incrCallsFailed() { s.channelz.ServerMetrics.CallsFailed.Add(1) } -func (s *Server) sendResponse(ctx context.Context, stream *transport.ServerStream, msg any, cp Compressor, opts *transport.WriteOptions, comp encoding.Compressor) error { +func (s *Server) sendResponse(ctx context.Context, stream *transport.ServerStream, msg any, cp Compressor, opts *transport.WriteOptions, comp encoding.Compressor, compressorOptions ...any) error { data, err := encode(s.getCodec(stream.ContentSubtype()), msg) if err != nil { channelz.Error(logger, s.channelz, "grpc: server failed to encode response: ", err) return err } - compData, pf, err := compress(data, cp, comp, s.opts.bufferPool) + compData, pf, err := compress(data, cp, comp, s.opts.bufferPool, compressorOptions...) if err != nil { data.Free() channelz.Error(logger, s.channelz, "grpc: server failed to compress response: ", err) @@ -1474,7 +1474,7 @@ func (s *Server) processUnaryRPC(ctx context.Context, stream *transport.ServerSt if stream.SendCompress() != sendCompressorName { comp = encoding.GetCompressor(stream.SendCompress()) } - if err := s.sendResponse(ctx, stream, reply, cp, opts, comp); err != nil { + if err := s.sendResponse(ctx, stream, reply, cp, opts, comp, stream.SendCompressOptions()...); err != nil { if err == io.EOF { // The entire stream is done (for unary RPC only). return err @@ -2146,11 +2146,15 @@ func SendHeader(ctx context.Context, md metadata.MD) error { // It is not safe to call SetSendCompressor concurrently with SendHeader and // SendMsg. // +// The optional compressorOptions are forwarded to the compressor's Compress +// method on each SendMsg call, allowing callers to pass additional context +// such as dictionary IDs for trained compression formats. +// // # Experimental // // Notice: This function is EXPERIMENTAL and may be changed or removed in a // later release. -func SetSendCompressor(ctx context.Context, name string) error { +func SetSendCompressor(ctx context.Context, name string, compressorOptions ...any) error { stream, ok := ServerTransportStreamFromContext(ctx).(*transport.ServerStream) if !ok || stream == nil { return fmt.Errorf("failed to fetch the stream from the given context") @@ -2160,7 +2164,7 @@ func SetSendCompressor(ctx context.Context, name string) error { return fmt.Errorf("unable to set send compressor: %w", err) } - return stream.SetSendCompress(name) + return stream.SetSendCompress(name, compressorOptions...) } // ClientSupportedCompressors returns compressor names advertised by the client diff --git a/stream.go b/stream.go index 046f5493d26f..04b337aec82f 100644 --- a/stream.go +++ b/stream.go @@ -375,6 +375,7 @@ func newClientStreamWithParams(ctx context.Context, desc *StreamDesc, cc *Client firstAttempt: true, onCommit: onCommit, nameResolutionDelay: nameResolutionDelayed, + compressorOptions: callInfo.compressorOptions, } if !cc.dopts.disableRetry { cs.retryThrottler = cc.retryThrottler.Load().(*retryThrottler) @@ -633,6 +634,8 @@ type clientStream struct { // nameResolutionDelay indicates if there was a delay in the name resolution. // This field is only valid on client side, it's always false on server side. nameResolutionDelay bool + + compressorOptions []any } type replayOp struct { @@ -964,7 +967,7 @@ func (cs *clientStream) SendMsg(m any) (err error) { } // load hdr, payload, data - hdr, data, payload, pf, err := prepareMsg(m, cs.codec, cs.compressorV0, cs.compressorV1, cs.cc.dopts.copts.BufferPool) + hdr, data, payload, pf, err := prepareMsg(m, cs.codec, cs.compressorV0, cs.compressorV1, cs.cc.dopts.copts.BufferPool, cs.compressorOptions...) if err != nil { return err } @@ -1471,7 +1474,7 @@ func (as *addrConnStream) SendMsg(m any) (err error) { } // load hdr, payload, data - hdr, data, payload, pf, err := prepareMsg(m, as.codec, as.sendCompressorV0, as.sendCompressorV1, as.ac.dopts.copts.BufferPool) + hdr, data, payload, pf, err := prepareMsg(m, as.codec, as.sendCompressorV0, as.sendCompressorV1, as.ac.dopts.copts.BufferPool, as.callInfo.compressorOptions...) if err != nil { return err } @@ -1669,7 +1672,8 @@ type serverStream struct { // synchronized. serverHeaderBinlogged bool - mu sync.Mutex // protects trInfo.tr after the service handler runs. + mu sync.Mutex // protects trInfo.tr after the service handler runs. + sendCompressorOptions []any } func (ss *serverStream) Context() context.Context { @@ -1748,10 +1752,11 @@ func (ss *serverStream) SendMsg(m any) (err error) { if sendCompressorsName := ss.s.SendCompress(); sendCompressorsName != ss.sendCompressorName { ss.compressorV1 = encoding.GetCompressor(sendCompressorsName) ss.sendCompressorName = sendCompressorsName + ss.sendCompressorOptions = ss.s.SendCompressOptions() } // load hdr, payload, data - hdr, data, payload, pf, err := prepareMsg(m, ss.codec, ss.compressorV0, ss.compressorV1, ss.p.bufferPool) + hdr, data, payload, pf, err := prepareMsg(m, ss.codec, ss.compressorV0, ss.compressorV1, ss.p.bufferPool, ss.sendCompressorOptions...) if err != nil { return err } @@ -1893,7 +1898,7 @@ func MethodFromServerStream(stream ServerStream) (string, bool) { // compression was made and therefore whether the payload needs to be freed in // addition to the returned data. Freeing the payload if the returned boolean is // false can lead to undefined behavior. -func prepareMsg(m any, codec baseCodec, cp Compressor, comp encoding.Compressor, pool mem.BufferPool) (hdr []byte, data, payload mem.BufferSlice, pf payloadFormat, err error) { +func prepareMsg(m any, codec baseCodec, cp Compressor, comp encoding.Compressor, pool mem.BufferPool, compressorOptions ...any) (hdr []byte, data, payload mem.BufferSlice, pf payloadFormat, err error) { if preparedMsg, ok := m.(*PreparedMsg); ok { return preparedMsg.hdr, preparedMsg.encodedData, preparedMsg.payload, preparedMsg.pf, nil } @@ -1903,7 +1908,7 @@ func prepareMsg(m any, codec baseCodec, cp Compressor, comp encoding.Compressor, if err != nil { return nil, nil, nil, 0, err } - compData, pf, err := compress(data, cp, comp, pool) + compData, pf, err := compress(data, cp, comp, pool, compressorOptions...) if err != nil { data.Free() return nil, nil, nil, 0, err