mirror of
https://github.com/tokio-rs/tokio.git
synced 2025-09-28 12:10:37 +00:00
io: remove unsafe from ReadToString (#2384)
This commit is contained in:
parent
7e88b56be5
commit
2da15b5f24
@ -121,7 +121,7 @@ default-features = false
|
||||
optional = true
|
||||
|
||||
[dev-dependencies]
|
||||
tokio-test = { version = "0.2.0" }
|
||||
tokio-test = { version = "0.2.0", path = "../tokio-test" }
|
||||
futures = { version = "0.3.0", features = ["async-await"] }
|
||||
proptest = "0.9.4"
|
||||
tempfile = "3.1.0"
|
||||
|
@ -4,7 +4,7 @@ use crate::io::AsyncRead;
|
||||
use std::future::Future;
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
use std::{io, mem, str};
|
||||
use std::{io, mem};
|
||||
|
||||
cfg_io_util! {
|
||||
/// Future for the [`read_to_string`](super::AsyncReadExt::read_to_string) method.
|
||||
@ -25,7 +25,7 @@ where
|
||||
let start_len = buf.len();
|
||||
ReadToString {
|
||||
reader,
|
||||
bytes: unsafe { mem::replace(buf.as_mut_vec(), Vec::new()) },
|
||||
bytes: mem::replace(buf, String::new()).into_bytes(),
|
||||
buf,
|
||||
start_len,
|
||||
}
|
||||
@ -38,19 +38,20 @@ fn read_to_string_internal<R: AsyncRead + ?Sized>(
|
||||
bytes: &mut Vec<u8>,
|
||||
start_len: usize,
|
||||
) -> Poll<io::Result<usize>> {
|
||||
let ret = ready!(read_to_end_internal(reader, cx, bytes, start_len));
|
||||
if str::from_utf8(&bytes).is_err() {
|
||||
Poll::Ready(ret.and_then(|_| {
|
||||
Err(io::Error::new(
|
||||
let ret = ready!(read_to_end_internal(reader, cx, bytes, start_len))?;
|
||||
match String::from_utf8(mem::replace(bytes, Vec::new())) {
|
||||
Ok(string) => {
|
||||
debug_assert!(buf.is_empty());
|
||||
*buf = string;
|
||||
Poll::Ready(Ok(ret))
|
||||
}
|
||||
Err(e) => {
|
||||
*bytes = e.into_bytes();
|
||||
Poll::Ready(Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
"stream did not contain valid UTF-8",
|
||||
))
|
||||
}))
|
||||
} else {
|
||||
debug_assert!(buf.is_empty());
|
||||
// Safety: `bytes` is a valid UTF-8 because `str::from_utf8` returned `Ok`.
|
||||
mem::swap(unsafe { buf.as_mut_vec() }, bytes);
|
||||
Poll::Ready(ret)
|
||||
)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -67,7 +68,14 @@ where
|
||||
bytes,
|
||||
start_len,
|
||||
} = &mut *self;
|
||||
read_to_string_internal(Pin::new(reader), cx, buf, bytes, *start_len)
|
||||
let ret = read_to_string_internal(Pin::new(reader), cx, buf, bytes, *start_len);
|
||||
if let Poll::Ready(Err(_)) = ret {
|
||||
// Put back the original string.
|
||||
bytes.truncate(*start_len);
|
||||
**buf = String::from_utf8(mem::replace(bytes, Vec::new()))
|
||||
.expect("original string no longer utf-8");
|
||||
}
|
||||
ret
|
||||
}
|
||||
}
|
||||
|
||||
|
49
tokio/tests/read_to_string.rs
Normal file
49
tokio/tests/read_to_string.rs
Normal file
@ -0,0 +1,49 @@
|
||||
use std::io;
|
||||
use tokio::io::AsyncReadExt;
|
||||
use tokio_test::io::Builder;
|
||||
|
||||
#[tokio::test]
|
||||
async fn to_string_does_not_truncate_on_utf8_error() {
|
||||
let data = vec![0xff, 0xff, 0xff];
|
||||
|
||||
let mut s = "abc".to_string();
|
||||
|
||||
match AsyncReadExt::read_to_string(&mut data.as_slice(), &mut s).await {
|
||||
Ok(len) => panic!("Should fail: {} bytes.", len),
|
||||
Err(err) if err.to_string() == "stream did not contain valid UTF-8" => {}
|
||||
Err(err) => panic!("Fail: {}.", err),
|
||||
}
|
||||
|
||||
assert_eq!(s, "abc");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn to_string_does_not_truncate_on_io_error() {
|
||||
let mut mock = Builder::new()
|
||||
.read(b"def")
|
||||
.read_error(io::Error::new(io::ErrorKind::Other, "whoops"))
|
||||
.build();
|
||||
let mut s = "abc".to_string();
|
||||
|
||||
match AsyncReadExt::read_to_string(&mut mock, &mut s).await {
|
||||
Ok(len) => panic!("Should fail: {} bytes.", len),
|
||||
Err(err) if err.to_string() == "whoops" => {}
|
||||
Err(err) => panic!("Fail: {}.", err),
|
||||
}
|
||||
|
||||
assert_eq!(s, "abc");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn to_string_appends() {
|
||||
let data = b"def".to_vec();
|
||||
|
||||
let mut s = "abc".to_string();
|
||||
|
||||
let len = AsyncReadExt::read_to_string(&mut data.as_slice(), &mut s)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(len, 3);
|
||||
assert_eq!(s, "abcdef");
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user