diff --git a/tokio/src/net/tcp/stream.rs b/tokio/src/net/tcp/stream.rs index b0e3ec27c..f64a526b4 100644 --- a/tokio/src/net/tcp/stream.rs +++ b/tokio/src/net/tcp/stream.rs @@ -1112,8 +1112,16 @@ impl TcpStream { /// This function will cause all pending and future I/O on the specified /// portions to return immediately with an appropriate value (see the /// documentation of `Shutdown`). + /// + /// Remark: this function transforms `Err(std::io::ErrorKind::NotConnected)` to `Ok(())`. + /// It does this to abstract away OS specific logic and to prevent a race condition between + /// this function call and the OS closing this socket because of external events (e.g. TCP reset). + /// See for more information. pub(super) fn shutdown_std(&self, how: Shutdown) -> io::Result<()> { - self.io.shutdown(how) + match self.io.shutdown(how) { + Err(err) if err.kind() == std::io::ErrorKind::NotConnected => Ok(()), + result => result, + } } /// Gets the value of the `TCP_NODELAY` option on this socket. diff --git a/tokio/tests/tcp_shutdown.rs b/tokio/tests/tcp_shutdown.rs index 2497c1a40..837e61230 100644 --- a/tokio/tests/tcp_shutdown.rs +++ b/tokio/tests/tcp_shutdown.rs @@ -2,8 +2,10 @@ #![cfg(all(feature = "full", not(target_os = "wasi"), not(miri)))] // Wasi doesn't support bind // No `socket` on miri. +use std::time::Duration; use tokio::io::{self, AsyncReadExt, AsyncWriteExt}; use tokio::net::{TcpListener, TcpStream}; +use tokio::sync::oneshot::channel; use tokio_test::assert_ok; #[tokio::test] @@ -11,7 +13,7 @@ async fn shutdown() { let srv = assert_ok!(TcpListener::bind("127.0.0.1:0").await); let addr = assert_ok!(srv.local_addr()); - tokio::spawn(async move { + let handle = tokio::spawn(async move { let mut stream = assert_ok!(TcpStream::connect(&addr).await); assert_ok!(AsyncWriteExt::shutdown(&mut stream).await); @@ -26,4 +28,55 @@ async fn shutdown() { let n = assert_ok!(io::copy(&mut rd, &mut wr).await); assert_eq!(n, 0); + assert_ok!(AsyncWriteExt::shutdown(&mut stream).await); + handle.await.unwrap() +} + +#[tokio::test] +async fn shutdown_after_tcp_reset() { + let srv = assert_ok!(TcpListener::bind("127.0.0.1:0").await); + let addr = assert_ok!(srv.local_addr()); + + let (connected_tx, connected_rx) = channel(); + let (dropped_tx, dropped_rx) = channel(); + + let handle = tokio::spawn(async move { + let mut stream = assert_ok!(TcpStream::connect(&addr).await); + connected_tx.send(()).unwrap(); + + dropped_rx.await.unwrap(); + assert_ok!(AsyncWriteExt::shutdown(&mut stream).await); + }); + + let (stream, _) = assert_ok!(srv.accept().await); + // By setting linger to 0 we will trigger a TCP reset + stream.set_linger(Some(Duration::new(0, 0))).unwrap(); + connected_rx.await.unwrap(); + + drop(stream); + dropped_tx.send(()).unwrap(); + + handle.await.unwrap(); +} + +#[tokio::test] +async fn shutdown_multiple_calls() { + let srv = assert_ok!(TcpListener::bind("127.0.0.1:0").await); + let addr = assert_ok!(srv.local_addr()); + + let (connected_tx, connected_rx) = channel(); + + let handle = tokio::spawn(async move { + let mut stream = assert_ok!(TcpStream::connect(&addr).await); + connected_tx.send(()).unwrap(); + assert_ok!(AsyncWriteExt::shutdown(&mut stream).await); + assert_ok!(AsyncWriteExt::shutdown(&mut stream).await); + assert_ok!(AsyncWriteExt::shutdown(&mut stream).await); + }); + + let (mut stream, _) = assert_ok!(srv.accept().await); + connected_rx.await.unwrap(); + + assert_ok!(AsyncWriteExt::shutdown(&mut stream).await); + handle.await.unwrap(); }