io: fix panic in read_line (#2541)

Fixes: #2532
This commit is contained in:
Alice Ryhl 2020-05-24 23:26:33 +02:00 committed by GitHub
parent d562e58871
commit 954f2b7304
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 195 additions and 33 deletions

View File

@ -91,6 +91,7 @@ where
let me = self.project();
let n = ready!(read_line_internal(me.reader, cx, me.buf, me.bytes, me.read))?;
debug_assert_eq!(*me.read, 0);
if n == 0 && me.buf.is_empty() {
return Poll::Ready(Ok(None));

View File

@ -5,7 +5,6 @@ use std::future::Future;
use std::io;
use std::mem;
use std::pin::Pin;
use std::str;
use std::task::{Context, Poll};
cfg_io_util! {
@ -14,45 +13,72 @@ cfg_io_util! {
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct ReadLine<'a, R: ?Sized> {
reader: &'a mut R,
buf: &'a mut String,
bytes: Vec<u8>,
/// This is the buffer we were provided. It will be replaced with an empty string
/// while reading to postpone utf-8 handling until after reading.
output: &'a mut String,
/// The actual allocation of the string is moved into a vector instead.
buf: Vec<u8>,
/// The number of bytes appended to buf. This can be less than buf.len() if
/// the buffer was not empty when the operation was started.
read: usize,
}
}
pub(crate) fn read_line<'a, R>(reader: &'a mut R, buf: &'a mut String) -> ReadLine<'a, R>
pub(crate) fn read_line<'a, R>(reader: &'a mut R, string: &'a mut String) -> ReadLine<'a, R>
where
R: AsyncBufRead + ?Sized + Unpin,
{
ReadLine {
reader,
bytes: unsafe { mem::replace(buf.as_mut_vec(), Vec::new()) },
buf,
buf: mem::replace(string, String::new()).into_bytes(),
output: string,
read: 0,
}
}
fn put_back_original_data(output: &mut String, mut vector: Vec<u8>, num_bytes_read: usize) {
let original_len = vector.len() - num_bytes_read;
vector.truncate(original_len);
*output = String::from_utf8(vector).expect("The original data must be valid utf-8.");
}
pub(super) fn read_line_internal<R: AsyncBufRead + ?Sized>(
reader: Pin<&mut R>,
cx: &mut Context<'_>,
buf: &mut String,
bytes: &mut Vec<u8>,
output: &mut String,
buf: &mut Vec<u8>,
read: &mut usize,
) -> Poll<io::Result<usize>> {
let ret = ready!(read_until_internal(reader, cx, b'\n', bytes, read));
if str::from_utf8(&bytes).is_err() {
Poll::Ready(ret.and_then(|_| {
Err(io::Error::new(
let io_res = ready!(read_until_internal(reader, cx, b'\n', buf, read));
let utf8_res = String::from_utf8(mem::replace(buf, Vec::new()));
// At this point both buf and output are empty. The allocation is in utf8_res.
debug_assert!(buf.is_empty());
match (io_res, utf8_res) {
(Ok(num_bytes), Ok(string)) => {
debug_assert_eq!(*read, 0);
*output = string;
Poll::Ready(Ok(num_bytes))
}
(Err(io_err), Ok(string)) => {
*output = string;
Poll::Ready(Err(io_err))
}
(Ok(num_bytes), Err(utf8_err)) => {
debug_assert_eq!(*read, 0);
put_back_original_data(output, utf8_err.into_bytes(), num_bytes);
Poll::Ready(Err(io::Error::new(
io::ErrorKind::InvalidData,
"stream did not contain valid UTF-8",
))
}))
} else {
debug_assert!(buf.is_empty());
debug_assert_eq!(*read, 0);
// Safety: `bytes` is a valid UTF-8 because `str::from_utf8` returned `Ok`.
mem::swap(unsafe { buf.as_mut_vec() }, bytes);
Poll::Ready(ret)
)))
}
(Err(io_err), Err(utf8_err)) => {
put_back_original_data(output, utf8_err.into_bytes(), *read);
Poll::Ready(Err(io_err))
}
}
}
@ -62,11 +88,12 @@ impl<R: AsyncBufRead + ?Sized + Unpin> Future for ReadLine<'_, R> {
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let Self {
reader,
output,
buf,
bytes,
read,
} = &mut *self;
read_line_internal(Pin::new(reader), cx, buf, bytes, read)
read_line_internal(Pin::new(reader), cx, output, buf, read)
}
}

View File

@ -8,19 +8,22 @@ use std::task::{Context, Poll};
cfg_io_util! {
/// Future for the [`read_until`](crate::io::AsyncBufReadExt::read_until) method.
/// The delimeter is included in the resulting vector.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct ReadUntil<'a, R: ?Sized> {
reader: &'a mut R,
byte: u8,
delimeter: u8,
buf: &'a mut Vec<u8>,
/// The number of bytes appended to buf. This can be less than buf.len() if
/// the buffer was not empty when the operation was started.
read: usize,
}
}
pub(crate) fn read_until<'a, R>(
reader: &'a mut R,
byte: u8,
delimeter: u8,
buf: &'a mut Vec<u8>,
) -> ReadUntil<'a, R>
where
@ -28,7 +31,7 @@ where
{
ReadUntil {
reader,
byte,
delimeter,
buf,
read: 0,
}
@ -37,14 +40,14 @@ where
pub(super) fn read_until_internal<R: AsyncBufRead + ?Sized>(
mut reader: Pin<&mut R>,
cx: &mut Context<'_>,
byte: u8,
delimeter: u8,
buf: &mut Vec<u8>,
read: &mut usize,
) -> Poll<io::Result<usize>> {
loop {
let (done, used) = {
let available = ready!(reader.as_mut().poll_fill_buf(cx))?;
if let Some(i) = memchr::memchr(byte, available) {
if let Some(i) = memchr::memchr(delimeter, available) {
buf.extend_from_slice(&available[..=i]);
(true, i + 1)
} else {
@ -66,11 +69,11 @@ impl<R: AsyncBufRead + ?Sized + Unpin> Future for ReadUntil<'_, R> {
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let Self {
reader,
byte,
delimeter,
buf,
read,
} = &mut *self;
read_until_internal(Pin::new(reader), cx, *byte, buf, read)
read_until_internal(Pin::new(reader), cx, *delimeter, buf, read)
}
}

View File

@ -75,6 +75,8 @@ where
let n = ready!(read_until_internal(
me.reader, cx, *me.delim, me.buf, me.read,
))?;
// read_until_internal resets me.read to zero once it finds the delimeter
debug_assert_eq!(*me.read, 0);
if n == 0 && me.buf.is_empty() {
return Poll::Ready(Ok(None));

View File

@ -1,8 +1,9 @@
#![warn(rust_2018_idioms)]
#![cfg(feature = "full")]
use tokio::io::AsyncBufReadExt;
use tokio_test::assert_ok;
use std::io::ErrorKind;
use tokio::io::{AsyncBufReadExt, BufReader, Error};
use tokio_test::{assert_ok, io::Builder};
use std::io::Cursor;
@ -27,3 +28,80 @@ async fn read_line() {
assert_eq!(n, 0);
assert_eq!(buf, "");
}
#[tokio::test]
async fn read_line_not_all_ready() {
let mock = Builder::new()
.read(b"Hello Wor")
.read(b"ld\nFizzBuz")
.read(b"z\n1\n2")
.build();
let mut read = BufReader::new(mock);
let mut line = "We say ".to_string();
let bytes = read.read_line(&mut line).await.unwrap();
assert_eq!(bytes, "Hello World\n".len());
assert_eq!(line.as_str(), "We say Hello World\n");
line = "I solve ".to_string();
let bytes = read.read_line(&mut line).await.unwrap();
assert_eq!(bytes, "FizzBuzz\n".len());
assert_eq!(line.as_str(), "I solve FizzBuzz\n");
line.clear();
let bytes = read.read_line(&mut line).await.unwrap();
assert_eq!(bytes, 2);
assert_eq!(line.as_str(), "1\n");
line.clear();
let bytes = read.read_line(&mut line).await.unwrap();
assert_eq!(bytes, 1);
assert_eq!(line.as_str(), "2");
}
#[tokio::test]
async fn read_line_invalid_utf8() {
let mock = Builder::new().read(b"Hello Wor\xffld.\n").build();
let mut read = BufReader::new(mock);
let mut line = "Foo".to_string();
let err = read.read_line(&mut line).await.expect_err("Should fail");
assert_eq!(err.kind(), ErrorKind::InvalidData);
assert_eq!(err.to_string(), "stream did not contain valid UTF-8");
assert_eq!(line.as_str(), "Foo");
}
#[tokio::test]
async fn read_line_fail() {
let mock = Builder::new()
.read(b"Hello Wor")
.read_error(Error::new(ErrorKind::Other, "The world has no end"))
.build();
let mut read = BufReader::new(mock);
let mut line = "Foo".to_string();
let err = read.read_line(&mut line).await.expect_err("Should fail");
assert_eq!(err.kind(), ErrorKind::Other);
assert_eq!(err.to_string(), "The world has no end");
assert_eq!(line.as_str(), "FooHello Wor");
}
#[tokio::test]
async fn read_line_fail_and_utf8_fail() {
let mock = Builder::new()
.read(b"Hello Wor")
.read(b"\xff\xff\xff")
.read_error(Error::new(ErrorKind::Other, "The world has no end"))
.build();
let mut read = BufReader::new(mock);
let mut line = "Foo".to_string();
let err = read.read_line(&mut line).await.expect_err("Should fail");
assert_eq!(err.kind(), ErrorKind::Other);
assert_eq!(err.to_string(), "The world has no end");
assert_eq!(line.as_str(), "Foo");
}

View File

@ -1,8 +1,9 @@
#![warn(rust_2018_idioms)]
#![cfg(feature = "full")]
use tokio::io::AsyncBufReadExt;
use tokio_test::assert_ok;
use std::io::ErrorKind;
use tokio::io::{AsyncBufReadExt, BufReader, Error};
use tokio_test::{assert_ok, io::Builder};
#[tokio::test]
async fn read_until() {
@ -21,3 +22,53 @@ async fn read_until() {
assert_eq!(n, 0);
assert_eq!(buf, []);
}
#[tokio::test]
async fn read_until_not_all_ready() {
let mock = Builder::new()
.read(b"Hello Wor")
.read(b"ld#Fizz\xffBuz")
.read(b"z#1#2")
.build();
let mut read = BufReader::new(mock);
let mut chunk = b"We say ".to_vec();
let bytes = read.read_until(b'#', &mut chunk).await.unwrap();
assert_eq!(bytes, b"Hello World#".len());
assert_eq!(chunk, b"We say Hello World#");
chunk = b"I solve ".to_vec();
let bytes = read.read_until(b'#', &mut chunk).await.unwrap();
assert_eq!(bytes, b"Fizz\xffBuzz\n".len());
assert_eq!(chunk, b"I solve Fizz\xffBuzz#");
chunk.clear();
let bytes = read.read_until(b'#', &mut chunk).await.unwrap();
assert_eq!(bytes, 2);
assert_eq!(chunk, b"1#");
chunk.clear();
let bytes = read.read_until(b'#', &mut chunk).await.unwrap();
assert_eq!(bytes, 1);
assert_eq!(chunk, b"2");
}
#[tokio::test]
async fn read_until_fail() {
let mock = Builder::new()
.read(b"Hello \xffWor")
.read_error(Error::new(ErrorKind::Other, "The world has no end"))
.build();
let mut read = BufReader::new(mock);
let mut chunk = b"Foo".to_vec();
let err = read
.read_until(b'#', &mut chunk)
.await
.expect_err("Should fail");
assert_eq!(err.kind(), ErrorKind::Other);
assert_eq!(err.to_string(), "The world has no end");
assert_eq!(chunk, b"FooHello \xffWor");
}