net: ignore NotConnected in TcpStream::shutdown (#7290)

This commit is contained in:
soundofspace 2025-05-06 10:27:40 +02:00 committed by GitHub
parent 00754c8f9c
commit f0fdef80c4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 63 additions and 2 deletions

View File

@ -1112,8 +1112,16 @@ impl TcpStream {
/// This function will cause all pending and future I/O on the specified
/// portions to return immediately with an appropriate value (see the
/// documentation of `Shutdown`).
///
/// Remark: this function transforms `Err(std::io::ErrorKind::NotConnected)` to `Ok(())`.
/// It does this to abstract away OS specific logic and to prevent a race condition between
/// this function call and the OS closing this socket because of external events (e.g. TCP reset).
/// See <https://github.com/tokio-rs/tokio/issues/4665> for more information.
pub(super) fn shutdown_std(&self, how: Shutdown) -> io::Result<()> {
self.io.shutdown(how)
match self.io.shutdown(how) {
Err(err) if err.kind() == std::io::ErrorKind::NotConnected => Ok(()),
result => result,
}
}
/// Gets the value of the `TCP_NODELAY` option on this socket.

View File

@ -2,8 +2,10 @@
#![cfg(all(feature = "full", not(target_os = "wasi"), not(miri)))] // Wasi doesn't support bind
// No `socket` on miri.
use std::time::Duration;
use tokio::io::{self, AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::oneshot::channel;
use tokio_test::assert_ok;
#[tokio::test]
@ -11,7 +13,7 @@ async fn shutdown() {
let srv = assert_ok!(TcpListener::bind("127.0.0.1:0").await);
let addr = assert_ok!(srv.local_addr());
tokio::spawn(async move {
let handle = tokio::spawn(async move {
let mut stream = assert_ok!(TcpStream::connect(&addr).await);
assert_ok!(AsyncWriteExt::shutdown(&mut stream).await);
@ -26,4 +28,55 @@ async fn shutdown() {
let n = assert_ok!(io::copy(&mut rd, &mut wr).await);
assert_eq!(n, 0);
assert_ok!(AsyncWriteExt::shutdown(&mut stream).await);
handle.await.unwrap()
}
#[tokio::test]
async fn shutdown_after_tcp_reset() {
let srv = assert_ok!(TcpListener::bind("127.0.0.1:0").await);
let addr = assert_ok!(srv.local_addr());
let (connected_tx, connected_rx) = channel();
let (dropped_tx, dropped_rx) = channel();
let handle = tokio::spawn(async move {
let mut stream = assert_ok!(TcpStream::connect(&addr).await);
connected_tx.send(()).unwrap();
dropped_rx.await.unwrap();
assert_ok!(AsyncWriteExt::shutdown(&mut stream).await);
});
let (stream, _) = assert_ok!(srv.accept().await);
// By setting linger to 0 we will trigger a TCP reset
stream.set_linger(Some(Duration::new(0, 0))).unwrap();
connected_rx.await.unwrap();
drop(stream);
dropped_tx.send(()).unwrap();
handle.await.unwrap();
}
#[tokio::test]
async fn shutdown_multiple_calls() {
let srv = assert_ok!(TcpListener::bind("127.0.0.1:0").await);
let addr = assert_ok!(srv.local_addr());
let (connected_tx, connected_rx) = channel();
let handle = tokio::spawn(async move {
let mut stream = assert_ok!(TcpStream::connect(&addr).await);
connected_tx.send(()).unwrap();
assert_ok!(AsyncWriteExt::shutdown(&mut stream).await);
assert_ok!(AsyncWriteExt::shutdown(&mut stream).await);
assert_ok!(AsyncWriteExt::shutdown(&mut stream).await);
});
let (mut stream, _) = assert_ok!(srv.accept().await);
connected_rx.await.unwrap();
assert_ok!(AsyncWriteExt::shutdown(&mut stream).await);
handle.await.unwrap();
}