diff --git a/src/codec/framed_write.rs b/src/codec/framed_write.rs index 17e557623..bddd10ba2 100644 --- a/src/codec/framed_write.rs +++ b/src/codec/framed_write.rs @@ -4,12 +4,13 @@ use crate::frame::{self, Frame, FrameSize}; use crate::hpack; use bytes::{Buf, BufMut, BytesMut}; +use std::collections::VecDeque; use std::pin::Pin; use std::task::{Context, Poll}; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; -use tokio_util::io::poll_write_buf; -use std::io::{self, Cursor}; +use std::io::{self, Cursor, IoSlice}; +use std::ops::ControlFlow; // A macro to get around a method needing to borrow &mut self macro_rules! limited_write_buf { @@ -38,11 +39,11 @@ struct Encoder { /// TODO: Should this be a ring buffer? buf: Cursor, - /// Next frame to encode - next: Option>, + /// Vector of buffer data and data frames to send next + next_vec: VecDeque>, - /// Last data frame - last_data_frame: Option>, + /// Next continuation frame to encode + next_continuation: Option, /// Max frame size, this is specified by the peer max_frame_size: FrameSize, @@ -55,9 +56,11 @@ struct Encoder { } #[derive(Debug)] -enum Next { - Data(frame::Data), - Continuation(frame::Continuation), +struct BufElement { + /// Number of bytes in the buffer that should be written before the next data frame payload. + buf_len: usize, + /// Data frame of the payload that should be written as part of this buffer element. + data_frame: frame::Data, } /// Initialize the connection with this amount of write buffer. @@ -76,6 +79,8 @@ const CHAIN_THRESHOLD: usize = 256; /// fragmented data being sent, and hereby improve the throughput. const CHAIN_THRESHOLD_WITHOUT_VECTORED_IO: usize = 1024; +const MAX_VECTORED_IO_COUNT: usize = 1024; + // TODO: Make generic impl FramedWrite where @@ -83,10 +88,10 @@ where B: Buf, { pub fn new(inner: T) -> FramedWrite { - let chain_threshold = if inner.is_write_vectored() { - CHAIN_THRESHOLD + let (chain_threshold, next_vec_capacity) = if inner.is_write_vectored() { + (CHAIN_THRESHOLD, MAX_VECTORED_IO_COUNT / 2) } else { - CHAIN_THRESHOLD_WITHOUT_VECTORED_IO + (CHAIN_THRESHOLD_WITHOUT_VECTORED_IO, 1) }; FramedWrite { inner, @@ -94,8 +99,8 @@ where encoder: Encoder { hpack: hpack::Encoder::default(), buf: Cursor::new(BytesMut::with_capacity(DEFAULT_BUFFER_CAPACITY)), - next: None, - last_data_frame: None, + next_vec: VecDeque::with_capacity(next_vec_capacity), + next_continuation: None, max_frame_size: frame::DEFAULT_MAX_FRAME_SIZE, chain_threshold, min_buffer_capacity: chain_threshold + frame::HEADER_LEN, @@ -108,16 +113,23 @@ where /// Calling this function may result in the current contents of the buffer /// to be flushed to `T`. pub fn poll_ready(&mut self, cx: &mut Context) -> Poll> { - if !self.encoder.has_capacity() { + if !self.encoder.has_vec_capacity() { + self.poll_ready_inner(cx) + } else { + Poll::Ready(Ok(())) + } + } + + #[cold] + fn poll_ready_inner(&mut self, cx: &mut Context) -> Poll> { + loop { // Try flushing - ready!(self.flush(cx))?; + ready!(self.flush_inner(cx, /* flush_all: */ false))?; - if !self.encoder.has_capacity() { - return Poll::Pending; + if self.encoder.has_capacity() { + return Poll::Ready(Ok(())); } } - - Poll::Ready(Ok(())) } /// Buffer a frame. @@ -130,31 +142,35 @@ where /// Flush buffered data to the wire pub fn flush(&mut self, cx: &mut Context) -> Poll> { - let span = tracing::trace_span!("FramedWrite::flush"); + self.flush_inner(cx, /* flush_all: */ true) + } + + #[inline] + fn flush_inner(&mut self, cx: &mut Context, flush_all: bool) -> Poll> { + let span = tracing::trace_span!("FramedWrite::flush", %flush_all); let _e = span.enter(); loop { - while !self.encoder.is_empty() { - match self.encoder.next { - Some(Next::Data(ref mut frame)) => { - tracing::trace!(queued_data_frame = true); - let mut buf = (&mut self.encoder.buf).chain(frame.payload_mut()); - ready!(poll_write_buf(Pin::new(&mut self.inner), cx, &mut buf))? - } - _ => { - tracing::trace!(queued_data_frame = false); - ready!(poll_write_buf( - Pin::new(&mut self.inner), - cx, - &mut self.encoder.buf - ))? - } - }; + while ready!(poll_write_buf( + Pin::new(&mut self.inner), + cx, + &mut self.encoder + ))? + .is_continue() + && (flush_all || !self.encoder.has_capacity()) + {} + + if flush_all { + assert_eq!(self.encoder.buf.position(), 0); + assert_eq!(self.encoder.buf.remaining(), 0); } + self.encoder.reclaim_buffer(); - match self.encoder.unset_frame() { - ControlFlow::Continue => (), - ControlFlow::Break => break, + if let Some(frame) = self.encoder.next_continuation.take() { + let mut buf = limited_write_buf!(self.encoder); + self.encoder.next_continuation = frame.encode(&mut buf); + } else { + break; } } @@ -175,38 +191,14 @@ where } } -#[must_use] -enum ControlFlow { - Continue, - Break, -} - impl Encoder where B: Buf, { - fn unset_frame(&mut self) -> ControlFlow { - // Clear internal buffer - self.buf.set_position(0); - self.buf.get_mut().clear(); - - // The data frame has been written, so unset it - match self.next.take() { - Some(Next::Data(frame)) => { - self.last_data_frame = Some(frame); - debug_assert!(self.is_empty()); - ControlFlow::Break - } - Some(Next::Continuation(frame)) => { - // Buffer the continuation frame, then try to write again - let mut buf = limited_write_buf!(self); - if let Some(continuation) = frame.encode(&mut buf) { - self.next = Some(Next::Continuation(continuation)); - } - ControlFlow::Continue - } - None => ControlFlow::Break, - } + #[inline] + fn reclaim_buffer(&mut self) { + let buf = self.buf.get_mut(); + let _ = buf.try_reclaim(buf.capacity() + buf.len() + 1); } fn buffer(&mut self, item: Frame) -> Result<(), UserError> { @@ -219,6 +211,8 @@ where match item { Frame::Data(mut v) => { + assert!(self.has_vec_capacity()); + // Ensure that the payload is not greater than the max frame. let len = v.payload().remaining(); @@ -226,6 +220,8 @@ where return Err(PayloadTooBig); } + let mut buf_len_to_push = 0; + if len >= self.chain_threshold { let head = v.head(); @@ -237,30 +233,29 @@ where self.buf.get_mut().put(v.payload_mut().take(extra_bytes)); } - // Save the data frame - self.next = Some(Next::Data(v)); + buf_len_to_push = self.buf.remaining(); + self.buf.advance(buf_len_to_push); } else { v.encode_chunk(self.buf.get_mut()); // The chunk has been fully encoded, so there is no need to // keep it around assert_eq!(v.payload().remaining(), 0, "chunk not fully encoded"); - - // Save off the last frame... - self.last_data_frame = Some(v); } + + // Push the most recent data frame... + self.next_vec.push_back(BufElement { + buf_len: buf_len_to_push, + data_frame: v, + }); } Frame::Headers(v) => { let mut buf = limited_write_buf!(self); - if let Some(continuation) = v.encode(&mut self.hpack, &mut buf) { - self.next = Some(Next::Continuation(continuation)); - } + self.next_continuation = v.encode(&mut self.hpack, &mut buf); } Frame::PushPromise(v) => { let mut buf = limited_write_buf!(self); - if let Some(continuation) = v.encode(&mut self.hpack, &mut buf) { - self.next = Some(Next::Continuation(continuation)); - } + self.next_continuation = v.encode(&mut self.hpack, &mut buf); } Frame::Settings(v) => { v.encode(self.buf.get_mut()); @@ -296,16 +291,16 @@ where } fn has_capacity(&self) -> bool { - self.next.is_none() + self.next_continuation.is_none() && (self.buf.get_ref().capacity() - self.buf.get_ref().len() >= self.min_buffer_capacity) } - fn is_empty(&self) -> bool { - match self.next { - Some(Next::Data(ref frame)) => !frame.payload().has_remaining(), - _ => !self.buf.has_remaining(), - } + fn has_vec_capacity(&self) -> bool { + self.next_continuation.is_none() + && self.next_vec.len() < self.next_vec.capacity() + && (self.buf.get_ref().capacity() - self.buf.get_ref().len() + >= self.min_buffer_capacity) } } @@ -315,6 +310,110 @@ impl Encoder { } } +impl Buf for Encoder { + fn remaining(&self) -> usize { + let mut n = self.buf.get_ref().len(); + for next in self.next_vec.iter() { + n = n.saturating_add(next.data_frame.payload().remaining()); + } + n + } + + fn chunk(&self) -> &[u8] { + for next in self.next_vec.iter() { + if next.buf_len > 0 { + return &self.buf.get_ref()[..next.buf_len]; + } + let slice = next.data_frame.payload().chunk(); + if !slice.is_empty() { + return slice; + } + } + self.buf.get_ref() + } + + fn advance(&mut self, mut n: usize) { + for next in self.next_vec.iter_mut() { + if next.buf_len > 0 { + let i = n.min(next.buf_len); + self.buf.get_mut().advance(i); + self.buf + .set_position(self.buf.position().checked_sub(i as u64).unwrap()); + n -= i; + next.buf_len -= i; + if next.buf_len > 0 { + return; + } + } + let rem = next.data_frame.payload().remaining(); + if rem > 0 { + let i = n.min(rem); + n -= i; + next.data_frame.payload_mut().advance(i); + if i < rem { + return; + } + } + } + if n > 0 { + self.buf.get_mut().advance(n); + } + } + + fn chunks_vectored<'a>(&'a self, dst: &mut [IoSlice<'a>]) -> usize { + let mut n = 0; + let mut buf_index = 0; + let mut vec_iter = self.next_vec.iter(); + while n < dst.len() { + if let Some(next) = vec_iter.next() { + if next.buf_len > 0 { + let buf_end = buf_index + next.buf_len; + dst[n] = IoSlice::new(&self.buf.get_ref()[buf_index..buf_end]); + buf_index = buf_end; + n += 1; + if n == dst.len() { + break; + } + } + let mut rem = next.data_frame.payload().remaining(); + if rem > 0 { + let n0 = n; + n = n.wrapping_add(next.data_frame.payload().chunks_vectored(&mut dst[n..])); + assert!(n0 <= n && n <= dst.len()); + if rem < usize::MAX { + for s in &dst[n0..n] { + rem = rem.saturating_sub(s.len()); + } + } + if rem > 0 { + break; + } + } + assert!(n <= dst.len()); + } else { + if self.buf.get_ref().len() > buf_index { + dst[n] = IoSlice::new(&self.buf.get_ref()[buf_index..]); + n += 1; + } + break; + } + } + n + } + + fn has_remaining(&self) -> bool { + if !self.buf.get_ref().is_empty() { + return true; + } + for next in self.next_vec.iter() { + if next.data_frame.payload().has_remaining() { + return true; + } + } + false + } +} + impl FramedWrite { /// Returns the max frame size that can be sent pub fn max_frame_size(&self) -> usize { @@ -332,16 +431,21 @@ impl FramedWrite { self.encoder.hpack.update_max_size(val); } - /// Retrieve the last data frame that has been sent - pub fn take_last_data_frame(&mut self) -> Option> { - self.encoder.last_data_frame.take() - } - pub fn get_mut(&mut self) -> &mut T { &mut self.inner } } +impl FramedWrite { + /// Take back data frames that have been buffered and/or fully written. + pub fn take_used_data_frames(&mut self) -> impl Iterator> + '_ { + UsedDataFrameTaker { + vec: &mut self.encoder.next_vec, + index: 0, + } + } +} + impl AsyncRead for FramedWrite { fn poll_read( mut self: Pin<&mut Self>, @@ -365,3 +469,53 @@ mod unstable { } } } + +fn poll_write_buf( + io: Pin<&mut T>, + cx: &mut Context<'_>, + buf: &mut B, +) -> Poll>> { + let n = if io.is_write_vectored() { + let mut slices = [IoSlice::new(&[]); MAX_VECTORED_IO_COUNT]; + let cnt = buf.chunks_vectored(&mut slices); + if cnt == 0 { + return Poll::Ready(Ok(ControlFlow::Break(0))); + } + ready!(io.poll_write_vectored(cx, &slices[..cnt]))? + } else { + let slice = buf.chunk(); + if slice.is_empty() { + return Poll::Ready(Ok(ControlFlow::Break(0))); + } + ready!(io.poll_write(cx, slice))? + }; + + buf.advance(n); + + Poll::Ready(Ok(ControlFlow::Continue(n))) +} + +struct UsedDataFrameTaker<'a, B> { + vec: &'a mut VecDeque>, + index: usize, +} + +impl<'a, B: Buf> Iterator for UsedDataFrameTaker<'a, B> { + type Item = frame::Data; + + #[inline] + fn next(&mut self) -> Option { + while let Some(item) = self.vec.get(self.index) { + if item.buf_len == 0 && !item.data_frame.payload().has_remaining() { + return self.vec.remove(self.index).map(|x| x.data_frame); + } + self.index += 1; + } + None + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + (0, Some(self.vec.len())) + } +} diff --git a/src/codec/mod.rs b/src/codec/mod.rs index 6cbdc1e18..04ec0be69 100644 --- a/src/codec/mod.rs +++ b/src/codec/mod.rs @@ -116,16 +116,18 @@ impl Codec { self.inner.get_mut().get_mut() } - /// Takes the data payload value that was fully written to the socket - pub(crate) fn take_last_data_frame(&mut self) -> Option> { - self.framed_write().take_last_data_frame() - } - fn framed_write(&mut self) -> &mut FramedWrite { self.inner.get_mut() } } +impl Codec { + /// Take back data frames that have been buffered and/or fully written. + pub fn take_used_data_frames(&mut self) -> impl Iterator> + '_ { + self.framed_write().take_used_data_frames() + } +} + impl Codec where T: AsyncWrite + Unpin, diff --git a/src/proto/streams/prioritize.rs b/src/proto/streams/prioritize.rs index 75358d473..a8f489adf 100644 --- a/src/proto/streams/prioritize.rs +++ b/src/proto/streams/prioritize.rs @@ -9,7 +9,8 @@ use crate::codec::UserError::*; use bytes::buf::Take; use std::{ cmp::{self, Ordering}, - fmt, io, mem, + fmt, io, + ops::ControlFlow, task::{Context, Poll, Waker}, }; @@ -51,23 +52,10 @@ pub(super) struct Prioritize { /// Stream ID of the last stream opened. last_opened_id: StreamId, - /// What `DATA` frame is currently being sent in the codec. - in_flight_data_frame: InFlightData, - /// The maximum amount of bytes a stream should buffer. max_buffer_size: usize, } -#[derive(Debug, Eq, PartialEq)] -enum InFlightData { - /// There is no `DATA` frame in flight. - Nothing, - /// There is a `DATA` frame in flight belonging to the given stream. - DataFrame(store::Key), - /// There was a `DATA` frame, but the stream's queue was since cleared. - Drop, -} - pub(crate) struct Prioritized { // The buffer inner: Take, @@ -99,7 +87,6 @@ impl Prioritize { pending_open: store::Queue::new(), flow, last_opened_id: StreamId::ZERO, - in_flight_data_frame: InFlightData::Nothing, max_buffer_size: config.local_max_buffer_size, } } @@ -495,16 +482,13 @@ impl Prioritize { if stream.buffered_send_data > 0 && stream.is_send_ready() { // TODO: This assertion isn't *exactly* correct. There can still be // buffered send data while the stream's pending send queue is - // empty. This can happen when a large data frame is in the process - // of being **partially** sent. Once the window has been sent, the - // data frame will be returned to the prioritization layer to be - // re-scheduled. + // empty and the stream is send ready. This can happen when + // try_assign_capacity is called from send_data. // // That said, it would be nice to figure out how to make this // assertion correctly. // // debug_assert!(!stream.pending_send.is_empty()); - self.pending_send.push(stream); } } @@ -524,8 +508,8 @@ impl Prioritize { // Ensure codec is ready ready!(dst.poll_ready(cx))?; - // Reclaim any frame that has previously been written - self.reclaim_frame(buffer, store, dst); + // Reclaim any frames that have previously been written + self.reclaim_frames(buffer, store, dst); // The max frame length let max_frame_len = dst.max_send_frame_size(); @@ -542,24 +526,20 @@ impl Prioritize { Some(frame) => { tracing::trace!(?frame, "writing"); - debug_assert_eq!(self.in_flight_data_frame, InFlightData::Nothing); - if let Frame::Data(ref frame) = frame { - self.in_flight_data_frame = InFlightData::DataFrame(frame.payload().stream); - } dst.buffer(frame).expect("invalid frame"); // Ensure the codec is ready to try the loop again. ready!(dst.poll_ready(cx))?; // Because, always try to reclaim... - self.reclaim_frame(buffer, store, dst); + self.reclaim_frames(buffer, store, dst); } None => { // Try to flush the codec. ready!(dst.flush(cx))?; - // This might release a data frame... - if !self.reclaim_frame(buffer, store, dst) { + // This might release data frames... + if !self.reclaim_frames(buffer, store, dst) { return Poll::Ready(Ok(())); } @@ -577,7 +557,7 @@ impl Prioritize { /// When a data frame is written to the codec, it may not be written in its /// entirety (large chunks are split up into potentially many data frames). /// In this case, the stream needs to be reprioritized. - fn reclaim_frame( + fn reclaim_frames( &mut self, buffer: &mut Buffer>, store: &mut Store, @@ -589,12 +569,14 @@ impl Prioritize { let span = tracing::trace_span!("try_reclaim_frame"); let _e = span.enter(); + let mut ret = false; + // First check if there are any data chunks to take back - if let Some(frame) = dst.take_last_data_frame() { - self.reclaim_frame_inner(buffer, store, frame) - } else { - false + for frame in dst.take_used_data_frames() { + ret |= self.reclaim_frame_inner(buffer, store, frame); } + + ret } fn reclaim_frame_inner( @@ -612,36 +594,28 @@ impl Prioritize { "reclaimed" ); - let mut eos = false; let key = frame.payload().stream; - - match mem::replace(&mut self.in_flight_data_frame, InFlightData::Nothing) { - InFlightData::Nothing => panic!("wasn't expecting a frame to reclaim"), - InFlightData::Drop => { - tracing::trace!("not reclaiming frame for cancelled stream"); - return false; - } - InFlightData::DataFrame(k) => { - debug_assert_eq!(k, key); - } - } - + let eos = frame.payload().end_of_stream; let mut frame = frame.map(|prioritized| { // TODO: Ensure fully written - eos = prioritized.end_of_stream; prioritized.inner.into_inner() }); if frame.payload().has_remaining() { let mut stream = store.resolve(key); - - if eos { - frame.set_end_stream(true); + match stream.in_flight_partial_send.take() { + Some(ControlFlow::Continue(())) => { + if eos { + frame.set_end_stream(true); + } + self.push_back_frame(frame.into(), buffer, &mut stream); + return true; + } + Some(ControlFlow::Break(())) => { + tracing::trace!("not reclaiming frame for cancelled stream"); + } + None => panic!("wasn't expecting a frame to reclaim"), } - - self.push_back_frame(frame.into(), buffer, &mut stream); - - return true; } false @@ -676,11 +650,8 @@ impl Prioritize { stream.buffered_send_data = 0; stream.requested_send_capacity = 0; - if let InFlightData::DataFrame(key) = self.in_flight_data_frame { - if stream.key() == key { - // This stream could get cleaned up now - don't allow the buffered frame to get reclaimed. - self.in_flight_data_frame = InFlightData::Drop; - } + if stream.in_flight_partial_send == Some(ControlFlow::Continue(())) { + stream.in_flight_partial_send = Some(ControlFlow::Break(())); } } @@ -720,6 +691,14 @@ impl Prioritize { let span = tracing::trace_span!("popped", ?stream.id, ?stream.state); let _e = span.enter(); + if stream.in_flight_partial_send == Some(ControlFlow::Continue(())) { + tracing::trace!( + "stream has an in-flight partial send data frame \ + that needs to be reclaimed before proceeding" + ); + continue; + } + // It's possible that this stream, besides having data to send, // is also queued to send a reset, and thus is already in the queue // to wait for "some time" after a reset. @@ -813,6 +792,8 @@ impl Prioritize { if frame.payload().remaining() > len { frame.set_end_stream(false); + stream.in_flight_partial_send = + Some(ControlFlow::Continue(())); } (eos, len) }); diff --git a/src/proto/streams/stream.rs b/src/proto/streams/stream.rs index f522f5d5d..0f7811139 100644 --- a/src/proto/streams/stream.rs +++ b/src/proto/streams/stream.rs @@ -3,6 +3,7 @@ use crate::Reason; use super::*; use std::fmt; +use std::ops::ControlFlow; use std::task::{Context, Waker}; use std::time::Instant; @@ -73,6 +74,11 @@ pub(super) struct Stream { /// Set to true when a push is pending for this stream pub is_pending_push: bool, + /// Set to Some(_) when a data frame is in the process of being partially sent. + /// Some(ControlFlow::Continue) means that the rest of the data frame should still be sent. + /// Some(ControlFlow::Break) means that the rest of the data frame should NOT be sent. + pub in_flight_partial_send: Option>, + // ===== Fields related to receiving ===== /// Next node in the accept linked list pub next_pending_accept: Option, @@ -178,6 +184,7 @@ impl Stream { is_pending_open: false, next_open: None, is_pending_push: false, + in_flight_partial_send: None, // ===== Fields related to receiving ===== next_pending_accept: None, @@ -232,7 +239,12 @@ impl Stream { // This is different from the "open" check because reserved streams don't count // toward the concurrency limit. // See https://httpwg.org/specs/rfc7540.html#rfc.section.5.1.2 - !self.is_pending_open && !self.is_pending_push + // + // With in_flight_partial_send, we track whether a data frame is in the process of being partially sent. + // If so, we should wait until the last part of that data frame is encoded before sending any other frames for this stream. + !self.is_pending_open + && !self.is_pending_push + && self.in_flight_partial_send != Some(ControlFlow::Continue(())) } /// Returns true if the stream is closed @@ -418,6 +430,7 @@ impl fmt::Debug for Stream { .h2_field_some("next_open", &self.next_open) .h2_field_if("is_pending_open", &self.is_pending_open) .h2_field_if("is_pending_push", &self.is_pending_push) + .h2_field_some("in_flight_partial_send", &self.in_flight_partial_send) .h2_field_some("next_pending_accept", &self.next_pending_accept) .h2_field_if("is_pending_accept", &self.is_pending_accept) .field("recv_flow", &self.recv_flow) diff --git a/tests/h2-support/src/mock.rs b/tests/h2-support/src/mock.rs index 9ec5ba379..810e1c2c9 100644 --- a/tests/h2-support/src/mock.rs +++ b/tests/h2-support/src/mock.rs @@ -117,6 +117,10 @@ impl Handle { p }) .await?; + + // Take the frame back from the codec + self.codec.take_used_data_frames().next(); + Ok(()) }