Skip to content

Commit 764686d

Browse files
committed
fix(http1): flush buffered data before shutdown
1 parent 90ede30 commit 764686d

5 files changed

Lines changed: 225 additions & 11 deletions

File tree

‎Cargo.toml‎

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -362,3 +362,8 @@ required-features = ["full"]
362362
name = "server"
363363
path = "tests/server.rs"
364364
required-features = ["full"]
365+
366+
[[test]]
367+
name = "h1_shutdown_while_buffered"
368+
path = "tests/h1_shutdown_while_buffered.rs"
369+
required-features = ["full"]

‎src/proto/h1/conn.rs‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -833,7 +833,7 @@ where
833833
}
834834

835835
pub(crate) fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
836-
match ready!(Pin::new(self.io.io_mut()).poll_shutdown(cx)) {
836+
match ready!(self.io.poll_shutdown(cx)) {
837837
Ok(()) => {
838838
trace!("shut down IO complete");
839839
Poll::Ready(Ok(()))

‎src/proto/h1/io.rs‎

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -242,10 +242,10 @@ where
242242
Poll::Ready(Ok(_)) => {
243243
let n = buf.filled().len();
244244
trace!("received {} bytes", n);
245+
// Safety: we just read that many bytes into the
246+
// uninitialized part of the buffer, so this is okay.
247+
// @tokio pls give me back `poll_read_buf` thanks
245248
unsafe {
246-
// Safety: we just read that many bytes into the
247-
// uninitialized part of the buffer, so this is okay.
248-
// @tokio pls give me back `poll_read_buf` thanks
249249
self.read_buf.advance_mut(n);
250250
}
251251
self.read_buf_strategy.record(n);
@@ -263,10 +263,6 @@ where
263263
(self.io, self.read_buf.freeze())
264264
}
265265

266-
pub(crate) fn io_mut(&mut self) -> &mut T {
267-
&mut self.io
268-
}
269-
270266
pub(crate) fn is_read_blocked(&self) -> bool {
271267
self.read_blocked
272268
}
@@ -330,6 +326,11 @@ where
330326
Pin::new(&mut self.io).poll_flush(cx)
331327
}
332328

329+
pub(crate) fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
330+
ready!(self.poll_flush(cx))?;
331+
Pin::new(&mut self.io).poll_shutdown(cx)
332+
}
333+
333334
#[cfg(test)]
334335
fn flush(&mut self) -> impl std::future::Future<Output = io::Result<()>> + '_ {
335336
futures_util::future::poll_fn(move |cx| self.poll_flush(cx))
Lines changed: 205 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,205 @@
1+
// Test: Ensures poll_shutdown() is never called with buffered data
2+
//
3+
// Reproduces rare timing bug where HTTP/1.1 server calls shutdown() on a socket while response
4+
// data is still buffered (not flushed), leading to data loss.
5+
//
6+
// Scenario:
7+
// 1. Request fully received and read.
8+
// 2. Server computes a "large" response with Full::new()
9+
// 3. Socket accepts only a chunk of response and then pends
10+
// 3. Flush returns Pending (remaining data still buffered), result ignored
11+
// 4. self.conn.wants_read_again() is false and poll_loop returns Ready
12+
// 5. BUG: poll_shutdown called despite body remaining and buffered body is lost
13+
// 6. FIX: Ideally never try to shutdown without body being flushed completely.
14+
15+
use std::{
16+
pin::Pin,
17+
sync::{Arc, Mutex},
18+
task::Poll,
19+
time::Duration,
20+
};
21+
22+
use bytes::Bytes;
23+
use http::{Request, Response};
24+
use http_body_util::Full;
25+
use hyper::{body::Incoming, service::service_fn};
26+
use support::TokioIo;
27+
use tokio::{
28+
io::{AsyncRead, AsyncWrite},
29+
net::{TcpListener, TcpStream},
30+
time::{sleep, timeout},
31+
};
32+
mod support;
33+
34+
#[derive(Debug, Default)]
35+
struct PendingStreamStatistics {
36+
bytes_written: usize,
37+
total_attempted: usize,
38+
shutdown_called_with_buffered: bool,
39+
buffered_at_shutdown: usize,
40+
}
41+
42+
// Simple struct that simply does one write and then pends perpetually
43+
struct PendingStream {
44+
inner: TcpStream,
45+
// Keep track of how many times we entered poll_write so as to be able to write only the first
46+
// time out
47+
write_count: usize,
48+
// Only write this chunk size out of full buffer
49+
write_chunk_size: usize,
50+
stats: Arc<Mutex<PendingStreamStatistics>>,
51+
}
52+
53+
impl PendingStream {
54+
fn new(
55+
inner: TcpStream,
56+
write_chunk_size: usize,
57+
stats: Arc<Mutex<PendingStreamStatistics>>,
58+
) -> Self {
59+
Self {
60+
inner,
61+
stats,
62+
write_chunk_size,
63+
write_count: 0,
64+
}
65+
}
66+
}
67+
68+
impl AsyncRead for PendingStream {
69+
fn poll_read(
70+
mut self: Pin<&mut Self>,
71+
cx: &mut std::task::Context<'_>,
72+
buf: &mut tokio::io::ReadBuf<'_>,
73+
) -> Poll<std::io::Result<()>> {
74+
Pin::new(&mut self.inner).poll_read(cx, buf)
75+
}
76+
}
77+
78+
impl AsyncWrite for PendingStream {
79+
fn poll_write(
80+
mut self: Pin<&mut Self>,
81+
cx: &mut std::task::Context<'_>,
82+
buf: &[u8],
83+
) -> Poll<std::io::Result<usize>> {
84+
self.write_count += 1;
85+
86+
let mut stats = self.stats.lock().unwrap();
87+
stats.total_attempted += buf.len();
88+
89+
if self.write_count == 1 {
90+
// First write: partial only
91+
let partial = std::cmp::min(buf.len(), self.write_chunk_size);
92+
drop(stats);
93+
94+
let result = Pin::new(&mut self.inner).poll_write(cx, &buf[..partial]);
95+
if let Poll::Ready(Ok(n)) = result {
96+
self.stats.lock().unwrap().bytes_written += n;
97+
}
98+
return result;
99+
}
100+
101+
// Block all further writes to simulate pending buffer
102+
Poll::Pending
103+
}
104+
105+
fn poll_shutdown(
106+
mut self: Pin<&mut Self>,
107+
cx: &mut std::task::Context<'_>,
108+
) -> Poll<std::io::Result<()>> {
109+
let mut stats = self.stats.lock().unwrap();
110+
let buffered = stats.total_attempted - stats.bytes_written;
111+
112+
if buffered > 0 {
113+
eprintln!(
114+
"\n❌BUG: shutdown() called with {} bytes buffered",
115+
buffered
116+
);
117+
stats.shutdown_called_with_buffered = true;
118+
stats.buffered_at_shutdown = buffered;
119+
}
120+
drop(stats);
121+
Pin::new(&mut self.inner).poll_shutdown(cx)
122+
}
123+
124+
fn poll_flush(
125+
mut self: Pin<&mut Self>,
126+
cx: &mut std::task::Context<'_>,
127+
) -> Poll<std::io::Result<()>> {
128+
let stats = self.stats.lock().unwrap();
129+
let buffered = stats.total_attempted - stats.bytes_written;
130+
131+
if buffered > 0 {
132+
return Poll::Pending;
133+
}
134+
135+
drop(stats);
136+
Pin::new(&mut self.inner).poll_flush(cx)
137+
}
138+
}
139+
140+
// Test doesn't necessarily check that the connections ended successfully but mainly that shutdown
141+
// wasn't called with data still remaining within hyper's internal buffer
142+
#[tokio::test]
143+
async fn test_no_premature_shutdown_while_buffered() {
144+
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
145+
let addr = listener.local_addr().unwrap();
146+
let stats = Arc::new(Mutex::new(PendingStreamStatistics::default()));
147+
148+
let stats_clone = stats.clone();
149+
let server = tokio::spawn(async move {
150+
let (stream, _) = listener.accept().await.unwrap();
151+
let pending_stream = PendingStream::new(stream, 212_992, stats_clone);
152+
let io = TokioIo::new(pending_stream);
153+
154+
let service = service_fn(|_req: Request<Incoming>| async move {
155+
// Larger Full response than write_chunk_size
156+
let body = Full::new(Bytes::from(vec![b'X'; 500_000]));
157+
Ok::<_, hyper::Error>(Response::new(body))
158+
});
159+
160+
hyper::server::conn::http1::Builder::new()
161+
.serve_connection(io, service)
162+
.await
163+
});
164+
165+
// Wait for server to be ready
166+
sleep(Duration::from_millis(50)).await;
167+
168+
// Client sends request
169+
tokio::spawn(async move {
170+
let mut stream = TcpStream::connect(addr).await.unwrap();
171+
172+
use tokio::io::AsyncWriteExt;
173+
174+
stream
175+
.write_all(
176+
b"POST / HTTP/1.1\r\n\
177+
Host: localhost\r\n\
178+
Transfer-Encoding: chunked\r\n\
179+
\r\n",
180+
)
181+
.await
182+
.unwrap();
183+
184+
stream.write_all(b"A\r\nHello World\r\n").await.unwrap();
185+
stream.write_all(b"0\r\n\r\n").await.unwrap();
186+
stream.flush().await.unwrap();
187+
188+
// keep connection open
189+
sleep(Duration::from_secs(2)).await;
190+
});
191+
192+
// Wait for completion
193+
let result = timeout(Duration::from_millis(900), server).await;
194+
195+
let stats = stats.lock().unwrap();
196+
197+
assert!(
198+
!stats.shutdown_called_with_buffered,
199+
"shutdown() called with {} bytes still buffered (wrote {} of {} bytes)",
200+
stats.buffered_at_shutdown, stats.bytes_written, stats.total_attempted
201+
);
202+
if let Ok(Ok(conn_result)) = result {
203+
conn_result.ok();
204+
}
205+
}

‎tests/ready_on_poll_stream.rs‎

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,15 +109,18 @@ impl Write for ReadyOnPollStream {
109109

110110
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
111111
self.flush_count += 1;
112+
const TOTAL_CHUNKS: usize = 16;
113+
114+
if self.pending_write.is_none() {
115+
return Poll::Ready(Ok(()));
116+
}
117+
112118
// We require two flushes to complete each chunk, simulating a success at the end of the old
113119
// poll loop. After all chunks are written, we always succeed on flush to allow for finish.
114-
const TOTAL_CHUNKS: usize = 16;
115120
if self.flush_count % 2 != 0 && self.flush_count < TOTAL_CHUNKS * 2 {
116121
if let Some(sleep) = self.pending_write.as_mut() {
117122
let sleep = sleep.as_mut();
118123
ready!(Future::poll(sleep, cx));
119-
} else {
120-
return Poll::Pending;
121124
}
122125
}
123126
let mut this = self.as_mut().project();

0 commit comments

Comments
 (0)