diff --git a/rpc_util.go b/rpc_util.go index c3651a470860..52f4ea513df1 100644 --- a/rpc_util.go +++ b/rpc_util.go @@ -128,6 +128,16 @@ func NewGZIPDecompressor() Decompressor { } func (d *gzipDecompressor) Do(r io.Reader) ([]byte, error) { + return d.doWithMaxSize(r, math.MaxInt64) +} + +// doWithMaxSize behaves like Do but caps the size of the decompressed +// payload at maxMessageSize+1 bytes. The Decompressor interface does not +// allow extra parameters, so callers inside the package type-assert to +// *gzipDecompressor to invoke this method directly. The +1 byte makes it +// possible for the caller to detect that the limit was exceeded and +// return ResourceExhausted instead of materializing an unbounded payload. +func (d *gzipDecompressor) doWithMaxSize(r io.Reader, maxMessageSize int64) ([]byte, error) { var z *gzip.Reader switch maybeZ := d.pool.Get().(type) { case nil: @@ -148,7 +158,11 @@ func (d *gzipDecompressor) Do(r io.Reader) ([]byte, error) { z.Close() d.pool.Put(z) }() - return io.ReadAll(z) + var src io.Reader = z + if maxMessageSize < math.MaxInt64 { + src = io.LimitReader(z, maxMessageSize+1) + } + return io.ReadAll(src) } func (d *gzipDecompressor) Type() string { @@ -971,7 +985,20 @@ func recvAndDecompress(p *parser, s recvCompressor, dc Decompressor, maxReceiveM func decompress(compressor encoding.Compressor, d mem.BufferSlice, dc Decompressor, maxReceiveMessageSize int, pool mem.BufferPool) (mem.BufferSlice, error) { if dc != nil { r := d.Reader() - uncompressed, err := dc.Do(r) + // For the built-in gzip decompressor, bound the decompressed output + // at maxReceiveMessageSize+1 so that a small but highly compressed + // payload (a "zip bomb") cannot expand to gigabytes in memory before + // the post-decompression size check below has a chance to fire. The + // Decompressor interface does not accept an extra size parameter, + // so we type-assert to invoke a size-aware helper. Third-party + // Decompressor implementations keep the original Do behavior. + var uncompressed []byte + var err error + if gd, ok := dc.(*gzipDecompressor); ok { + uncompressed, err = gd.doWithMaxSize(r, int64(maxReceiveMessageSize)) + } else { + uncompressed, err = dc.Do(r) + } if err != nil { r.Close() // ensure buffers are reused return nil, status.Errorf(codes.Internal, "grpc: failed to decompress the received message: %v", err) diff --git a/rpc_util_test.go b/rpc_util_test.go index 2abe7516fb4d..08a4a33feeaf 100644 --- a/rpc_util_test.go +++ b/rpc_util_test.go @@ -419,16 +419,17 @@ func BenchmarkGZIPCompressor1MiB(b *testing.B) { bmCompressor(b, 1024*1024, NewGZIPCompressor()) } -// compressWithDeterministicError compresses the input data and returns a BufferSlice. -func compressWithDeterministicError(t *testing.T, input []byte) mem.BufferSlice { +// mustCompress gzip-compresses input and returns it as a BufferSlice, +// failing the test if compression fails. +func mustCompress(t *testing.T, input []byte) mem.BufferSlice { t.Helper() var buf bytes.Buffer gz := gzip.NewWriter(&buf) if _, err := gz.Write(input); err != nil { - t.Fatalf("compressInput() failed to write data: %v", err) + t.Fatalf("mustCompress() failed to write data: %v", err) } if err := gz.Close(); err != nil { - t.Fatalf("compressInput() failed to close gzip writer: %v", err) + t.Fatalf("mustCompress() failed to close gzip writer: %v", err) } compressedData := buf.Bytes() return mem.BufferSlice{mem.NewBuffer(&compressedData, nil)} @@ -475,7 +476,7 @@ func (s) TestDecompress(t *testing.T) { }{ { name: "Decompresses successfully with sufficient buffer size", - input: compressWithDeterministicError(t, []byte("decompressed data")), + input: mustCompress(t, []byte("decompressed data")), dc: nil, maxReceiveMessageSize: 50, want: []byte("decompressed data"), @@ -483,7 +484,7 @@ func (s) TestDecompress(t *testing.T) { }, { name: "Fails due to exceeding maxReceiveMessageSize", - input: compressWithDeterministicError(t, []byte("message that is too large")), + input: mustCompress(t, []byte("message that is too large")), dc: nil, maxReceiveMessageSize: len("message that is too large") - 1, want: nil, @@ -491,7 +492,7 @@ func (s) TestDecompress(t *testing.T) { }, { name: "Decompresses to exactly maxReceiveMessageSize", - input: compressWithDeterministicError(t, []byte("exact size message")), + input: mustCompress(t, []byte("exact size message")), dc: nil, maxReceiveMessageSize: len("exact size message"), want: []byte("exact size message"), @@ -499,7 +500,7 @@ func (s) TestDecompress(t *testing.T) { }, { name: "Decompresses successfully with maxReceiveMessageSize MaxInt", - input: compressWithDeterministicError(t, []byte("large message")), + input: mustCompress(t, []byte("large message")), dc: nil, maxReceiveMessageSize: math.MaxInt, want: []byte("large message"), @@ -507,7 +508,7 @@ func (s) TestDecompress(t *testing.T) { }, { name: "Fails with decompression error due to invalid format", - input: compressWithDeterministicError(t, []byte("invalid compressed data")), + input: mustCompress(t, []byte("invalid compressed data")), dc: invalidFormatDecompressor, maxReceiveMessageSize: 50, want: nil, @@ -515,12 +516,32 @@ func (s) TestDecompress(t *testing.T) { }, { name: "Fails with resourceExhausted error when decompressed message exceeds maxReceiveMessageSize", - input: compressWithDeterministicError(t, []byte("large compressed data")), + input: mustCompress(t, []byte("large compressed data")), dc: validDecompressor, maxReceiveMessageSize: 20, want: nil, wantErr: status.Errorf(codes.ResourceExhausted, "grpc: message after decompression larger than max (%d vs. %d)", 25, 20), }, + { + // Bombs the legacy gzipDecompressor with 1 MiB of zeros (which + // gzips down to a few KiB). doWithMaxSize must cap the read at + // maxRecv+1 bytes; the error reports exactly that size so we + // know only maxRecv+1 bytes were materialised. + name: "Legacy gzipDecompressor bounds decompressed bomb to maxReceiveMessageSize+1", + input: mustCompress(t, make([]byte, 1<<20)), + dc: NewGZIPDecompressor(), + maxReceiveMessageSize: 1024, + want: nil, + wantErr: status.Errorf(codes.ResourceExhausted, "grpc: message after decompression larger than max (%d vs. %d)", 1025, 1024), + }, + { + name: "Legacy gzipDecompressor decompresses successfully when payload fits", + input: mustCompress(t, []byte("hello legacy decompressor")), + dc: NewGZIPDecompressor(), + maxReceiveMessageSize: len("hello legacy decompressor"), + want: []byte("hello legacy decompressor"), + wantErr: nil, + }, } for _, tc := range testCases { @@ -637,7 +658,7 @@ func (s) TestDecompress_ClosesReader(t *testing.T) { ch := make(chan struct{}) compressor := &fakeCloseCompressor{ch: ch} - in := compressWithDeterministicError(t, []byte("some data")) + in := mustCompress(t, []byte("some data")) out, err := decompress(compressor, in, nil, tc.maxReceiveMessageSize, mem.DefaultBufferPool()) if status.Code(err) != tc.wantCode { t.Fatalf("decompress() failed with error code %v, want %v", status.Code(err), tc.wantCode)