diff --git a/Cargo.toml b/Cargo.toml index 413a30236..7d93dcaf9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,7 +3,7 @@ members = [ "tokio", # "tokio-buf", - # "tokio-codec", + "tokio-codec", "tokio-current-thread", "tokio-executor", # "tokio-fs", diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 2c397e608..4ed9e175b 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -45,7 +45,7 @@ jobs: rust: $(nightly) crates: # - tokio-buf - # - tokio-codec + - tokio-codec - tokio-current-thread - tokio-executor - tokio-io diff --git a/tokio-codec/Cargo.toml b/tokio-codec/Cargo.toml index bbc29c334..e34e7b2e0 100644 --- a/tokio-codec/Cargo.toml +++ b/tokio-codec/Cargo.toml @@ -24,4 +24,10 @@ publish = false [dependencies] tokio-io = { version = "0.2.0", path = "../tokio-io" } bytes = "0.4.7" -futures = "0.1.18" +tokio-futures = { version = "0.2.0", path = "../tokio-futures" } +log = "0.4" + +[dev-dependencies] +futures-preview = "0.3.0-alpha.16" +tokio-current-thread = { version = "0.2.0", path = "../tokio-current-thread" } +tokio-test = { version = "0.2.0", path = "../tokio-test" } \ No newline at end of file diff --git a/tokio-codec/src/bytes_codec.rs b/tokio-codec/src/bytes_codec.rs index 3d6e979d9..b4a0fa31a 100644 --- a/tokio-codec/src/bytes_codec.rs +++ b/tokio-codec/src/bytes_codec.rs @@ -1,6 +1,7 @@ +use crate::decoder::Decoder; +use crate::encoder::Encoder; use bytes::{BufMut, Bytes, BytesMut}; use std::io; -use tokio_io::_tokio_codec::{Decoder, Encoder}; /// A simple `Codec` implementation that just ships bytes around. #[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)] diff --git a/tokio-codec/src/decoder.rs b/tokio-codec/src/decoder.rs new file mode 100644 index 000000000..f492b0f19 --- /dev/null +++ b/tokio-codec/src/decoder.rs @@ -0,0 +1,117 @@ +use bytes::BytesMut; +use std::io; +use tokio_io::{AsyncRead, AsyncWrite}; + +use super::encoder::Encoder; + +use super::Framed; + +/// Decoding of frames via buffers. +/// +/// This trait is used when constructing an instance of `Framed` or +/// `FramedRead`. An implementation of `Decoder` takes a byte stream that has +/// already been buffered in `src` and decodes the data into a stream of +/// `Self::Item` frames. +/// +/// Implementations are able to track state on `self`, which enables +/// implementing stateful streaming parsers. In many cases, though, this type +/// will simply be a unit struct (e.g. `struct HttpDecoder`). + +// Note: We can't deprecate this trait, because the deprecation carries through to tokio-codec, and +// there doesn't seem to be a way to un-deprecate the re-export. +pub trait Decoder { + /// The type of decoded frames. + type Item; + + /// The type of unrecoverable frame decoding errors. + /// + /// If an individual message is ill-formed but can be ignored without + /// interfering with the processing of future messages, it may be more + /// useful to report the failure as an `Item`. + /// + /// `From` is required in the interest of making `Error` suitable + /// for returning directly from a `FramedRead`, and to enable the default + /// implementation of `decode_eof` to yield an `io::Error` when the decoder + /// fails to consume all available data. + /// + /// Note that implementors of this trait can simply indicate `type Error = + /// io::Error` to use I/O errors as this type. + type Error: From; + + /// Attempts to decode a frame from the provided buffer of bytes. + /// + /// This method is called by `FramedRead` whenever bytes are ready to be + /// parsed. The provided buffer of bytes is what's been read so far, and + /// this instance of `Decode` can determine whether an entire frame is in + /// the buffer and is ready to be returned. + /// + /// If an entire frame is available, then this instance will remove those + /// bytes from the buffer provided and return them as a decoded + /// frame. Note that removing bytes from the provided buffer doesn't always + /// necessarily copy the bytes, so this should be an efficient operation in + /// most circumstances. + /// + /// If the bytes look valid, but a frame isn't fully available yet, then + /// `Ok(None)` is returned. This indicates to the `Framed` instance that + /// it needs to read some more bytes before calling this method again. + /// + /// Note that the bytes provided may be empty. If a previous call to + /// `decode` consumed all the bytes in the buffer then `decode` will be + /// called again until it returns `Ok(None)`, indicating that more bytes need to + /// be read. + /// + /// Finally, if the bytes in the buffer are malformed then an error is + /// returned indicating why. This informs `Framed` that the stream is now + /// corrupt and should be terminated. + fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error>; + + /// A default method available to be called when there are no more bytes + /// available to be read from the underlying I/O. + /// + /// This method defaults to calling `decode` and returns an error if + /// `Ok(None)` is returned while there is unconsumed data in `buf`. + /// Typically this doesn't need to be implemented unless the framing + /// protocol differs near the end of the stream. + /// + /// Note that the `buf` argument may be empty. If a previous call to + /// `decode_eof` consumed all the bytes in the buffer, `decode_eof` will be + /// called again until it returns `None`, indicating that there are no more + /// frames to yield. This behavior enables returning finalization frames + /// that may not be based on inbound data. + fn decode_eof(&mut self, buf: &mut BytesMut) -> Result, Self::Error> { + match self.decode(buf)? { + Some(frame) => Ok(Some(frame)), + None => { + if buf.is_empty() { + Ok(None) + } else { + Err(io::Error::new(io::ErrorKind::Other, "bytes remaining on stream").into()) + } + } + } + } + + /// Provides a `Stream` and `Sink` interface for reading and writing to this + /// `Io` object, using `Decode` and `Encode` to read and write the raw data. + /// + /// Raw I/O objects work with byte sequences, but higher-level code usually + /// wants to batch these into meaningful chunks, called "frames". This + /// method layers framing on top of an I/O object, by using the `Codec` + /// traits to handle encoding and decoding of messages frames. Note that + /// the incoming and outgoing frame types may be distinct. + /// + /// This function returns a *single* object that is both `Stream` and + /// `Sink`; grouping this into a single object is often useful for layering + /// things like gzip or TLS, which require both read and write access to the + /// underlying object. + /// + /// If you want to work more directly with the streams and sink, consider + /// calling `split` on the `Framed` returned by this method, which will + /// break them into separate objects, allowing them to interact more easily. + fn framed(self, io: T) -> Framed + where + Self: Encoder + Sized, + { + Framed::new(io, self) + } +} diff --git a/tokio-codec/src/encoder.rs b/tokio-codec/src/encoder.rs new file mode 100644 index 000000000..506508032 --- /dev/null +++ b/tokio-codec/src/encoder.rs @@ -0,0 +1,25 @@ +use bytes::BytesMut; +use std::io; + +/// Trait of helper objects to write out messages as bytes, for use with +/// `FramedWrite`. + +// Note: We can't deprecate this trait, because the deprecation carries through to tokio-codec, and +// there doesn't seem to be a way to un-deprecate the re-export. +pub trait Encoder { + /// The type of items consumed by the `Encoder` + type Item; + + /// The type of encoding errors. + /// + /// `FramedWrite` requires `Encoder`s errors to implement `From` + /// in the interest letting it return `Error`s directly. + type Error: From; + + /// Encodes a frame into the buffer provided. + /// + /// This method will encode `item` into the byte buffer provided by `dst`. + /// The `dst` provided is an internal buffer of the `Framed` instance and + /// will be written out when possible. + fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error>; +} diff --git a/tokio-codec/src/framed.rs b/tokio-codec/src/framed.rs new file mode 100644 index 000000000..1929b3eb0 --- /dev/null +++ b/tokio-codec/src/framed.rs @@ -0,0 +1,308 @@ +#![allow(deprecated)] + +use std::fmt; +use std::io::{self, Read, Write}; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use crate::decoder::Decoder; +use crate::encoder::Encoder; +use crate::framed_read::{framed_read2, framed_read2_with_buffer, FramedRead2}; +use crate::framed_write::{framed_write2, framed_write2_with_buffer, FramedWrite2}; +use tokio_futures::{Sink, Stream}; +use tokio_io::{AsyncRead, AsyncWrite}; + +use bytes::BytesMut; + +/// A unified `Stream` and `Sink` interface to an underlying I/O object, using +/// the `Encoder` and `Decoder` traits to encode and decode frames. +/// +/// You can create a `Framed` instance by using the `AsyncRead::framed` adapter. +pub struct Framed { + inner: FramedRead2>>, +} + +pub struct Fuse(pub T, pub U); + +impl Framed +where + T: AsyncRead + AsyncWrite, + U: Decoder + Encoder, +{ + /// Provides a `Stream` and `Sink` interface for reading and writing to this + /// `Io` object, using `Decode` and `Encode` to read and write the raw data. + /// + /// Raw I/O objects work with byte sequences, but higher-level code usually + /// wants to batch these into meaningful chunks, called "frames". This + /// method layers framing on top of an I/O object, by using the `Codec` + /// traits to handle encoding and decoding of messages frames. Note that + /// the incoming and outgoing frame types may be distinct. + /// + /// This function returns a *single* object that is both `Stream` and + /// `Sink`; grouping this into a single object is often useful for layering + /// things like gzip or TLS, which require both read and write access to the + /// underlying object. + /// + /// If you want to work more directly with the streams and sink, consider + /// calling `split` on the `Framed` returned by this method, which will + /// break them into separate objects, allowing them to interact more easily. + pub fn new(inner: T, codec: U) -> Framed { + Framed { + inner: framed_read2(framed_write2(Fuse(inner, codec))), + } + } +} + +impl Framed { + /// Provides a `Stream` and `Sink` interface for reading and writing to this + /// `Io` object, using `Decode` and `Encode` to read and write the raw data. + /// + /// Raw I/O objects work with byte sequences, but higher-level code usually + /// wants to batch these into meaningful chunks, called "frames". This + /// method layers framing on top of an I/O object, by using the `Codec` + /// traits to handle encoding and decoding of messages frames. Note that + /// the incoming and outgoing frame types may be distinct. + /// + /// This function returns a *single* object that is both `Stream` and + /// `Sink`; grouping this into a single object is often useful for layering + /// things like gzip or TLS, which require both read and write access to the + /// underlying object. + /// + /// This objects takes a stream and a readbuffer and a writebuffer. These field + /// can be obtained from an existing `Framed` with the `into_parts` method. + /// + /// If you want to work more directly with the streams and sink, consider + /// calling `split` on the `Framed` returned by this method, which will + /// break them into separate objects, allowing them to interact more easily. + pub fn from_parts(parts: FramedParts) -> Framed { + Framed { + inner: framed_read2_with_buffer( + framed_write2_with_buffer(Fuse(parts.io, parts.codec), parts.write_buf), + parts.read_buf, + ), + } + } + + /// Returns a reference to the underlying I/O stream wrapped by + /// `Frame`. + /// + /// Note that care should be taken to not tamper with the underlying stream + /// of data coming in as it may corrupt the stream of frames otherwise + /// being worked with. + pub fn get_ref(&self) -> &T { + &self.inner.get_ref().get_ref().0 + } + + /// Returns a mutable reference to the underlying I/O stream wrapped by + /// `Frame`. + /// + /// Note that care should be taken to not tamper with the underlying stream + /// of data coming in as it may corrupt the stream of frames otherwise + /// being worked with. + pub fn get_mut(&mut self) -> &mut T { + &mut self.inner.get_mut().get_mut().0 + } + + /// Returns a reference to the underlying codec wrapped by + /// `Frame`. + /// + /// Note that care should be taken to not tamper with the underlying codec + /// as it may corrupt the stream of frames otherwise being worked with. + pub fn codec(&self) -> &U { + &self.inner.get_ref().get_ref().1 + } + + /// Returns a mutable reference to the underlying codec wrapped by + /// `Frame`. + /// + /// Note that care should be taken to not tamper with the underlying codec + /// as it may corrupt the stream of frames otherwise being worked with. + pub fn codec_mut(&mut self) -> &mut U { + &mut self.inner.get_mut().get_mut().1 + } + + /// Consumes the `Frame`, returning its underlying I/O stream. + /// + /// Note that care should be taken to not tamper with the underlying stream + /// of data coming in as it may corrupt the stream of frames otherwise + /// being worked with. + pub fn into_inner(self) -> T { + self.inner.into_inner().into_inner().0 + } + + /// Consumes the `Frame`, returning its underlying I/O stream, the buffer + /// with unprocessed data, and the codec. + /// + /// Note that care should be taken to not tamper with the underlying stream + /// of data coming in as it may corrupt the stream of frames otherwise + /// being worked with. + pub fn into_parts(self) -> FramedParts { + let (inner, read_buf) = self.inner.into_parts(); + let (inner, write_buf) = inner.into_parts(); + + FramedParts { + io: inner.0, + codec: inner.1, + read_buf: read_buf, + write_buf: write_buf, + _priv: (), + } + } +} + +impl Stream for Framed +where + T: AsyncRead + Unpin, + U: Decoder + Unpin, +{ + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + pin!(self.get_mut().inner).poll_next(cx) + } +} + +impl Sink for Framed +where + T: AsyncWrite + Unpin, + U: Encoder + Unpin, + U::Error: From, +{ + type Error = U::Error; + + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(Pin::get_mut(self).inner.get_mut()).poll_ready(cx) + } + + fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> { + Pin::new(Pin::get_mut(self).inner.get_mut()).start_send(item) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(Pin::get_mut(self).inner.get_mut()).poll_flush(cx) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(Pin::get_mut(self).inner.get_mut()).poll_close(cx) + } +} + +impl fmt::Debug for Framed +where + T: fmt::Debug, + U: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Framed") + .field("io", &self.inner.get_ref().get_ref().0) + .field("codec", &self.inner.get_ref().get_ref().1) + .finish() + } +} + +// ===== impl Fuse ===== + +impl Read for Fuse { + fn read(&mut self, dst: &mut [u8]) -> io::Result { + self.0.read(dst) + } +} + +impl AsyncRead for Fuse { + unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { + self.0.prepare_uninitialized_buffer(buf) + } + + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + pin!(self.get_mut().0).poll_read(cx, buf) + } +} + +impl Write for Fuse { + fn write(&mut self, src: &[u8]) -> io::Result { + self.0.write(src) + } + + fn flush(&mut self) -> io::Result<()> { + self.0.flush() + } +} + +impl AsyncWrite for Fuse { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + pin!(self.get_mut().0).poll_write(cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + pin!(self.get_mut().0).poll_flush(cx) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + pin!(self.get_mut().0).poll_shutdown(cx) + } +} + +impl Decoder for Fuse { + type Item = U::Item; + type Error = U::Error; + + fn decode(&mut self, buffer: &mut BytesMut) -> Result, Self::Error> { + self.1.decode(buffer) + } + + fn decode_eof(&mut self, buffer: &mut BytesMut) -> Result, Self::Error> { + self.1.decode_eof(buffer) + } +} + +impl Encoder for Fuse { + type Item = U::Item; + type Error = U::Error; + + fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> { + self.1.encode(item, dst) + } +} + +/// `FramedParts` contains an export of the data of a Framed transport. +/// It can be used to construct a new `Framed` with a different codec. +/// It contains all current buffers and the inner transport. +#[derive(Debug)] +pub struct FramedParts { + /// The inner transport used to read bytes to and write bytes to + pub io: T, + + /// The codec + pub codec: U, + + /// The buffer with read but unprocessed data. + pub read_buf: BytesMut, + + /// A buffer with unprocessed data which are not written yet. + pub write_buf: BytesMut, + + /// This private field allows us to add additional fields in the future in a + /// backwards compatible way. + _priv: (), +} + +impl FramedParts { + /// Create a new, default, `FramedParts` + pub fn new(io: T, codec: U) -> FramedParts { + FramedParts { + io, + codec, + read_buf: BytesMut::new(), + write_buf: BytesMut::new(), + _priv: (), + } + } +} diff --git a/tokio-codec/src/framed_read.rs b/tokio-codec/src/framed_read.rs new file mode 100644 index 000000000..13c475414 --- /dev/null +++ b/tokio-codec/src/framed_read.rs @@ -0,0 +1,225 @@ +use std::fmt; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use super::framed::Fuse; +use super::Decoder; +use tokio_futures::{Sink, Stream}; +use tokio_io::AsyncRead; + +use bytes::BytesMut; +use log::trace; + +/// A `Stream` of messages decoded from an `AsyncRead`. +pub struct FramedRead { + inner: FramedRead2>, +} + +pub struct FramedRead2 { + inner: T, + eof: bool, + is_readable: bool, + buffer: BytesMut, +} + +const INITIAL_CAPACITY: usize = 8 * 1024; + +// ===== impl FramedRead ===== + +impl FramedRead +where + T: AsyncRead, + D: Decoder, +{ + /// Creates a new `FramedRead` with the given `decoder`. + pub fn new(inner: T, decoder: D) -> FramedRead { + FramedRead { + inner: framed_read2(Fuse(inner, decoder)), + } + } +} + +impl FramedRead { + /// Returns a reference to the underlying I/O stream wrapped by + /// `FramedRead`. + /// + /// Note that care should be taken to not tamper with the underlying stream + /// of data coming in as it may corrupt the stream of frames otherwise + /// being worked with. + pub fn get_ref(&self) -> &T { + &self.inner.inner.0 + } + + /// Returns a mutable reference to the underlying I/O stream wrapped by + /// `FramedRead`. + /// + /// Note that care should be taken to not tamper with the underlying stream + /// of data coming in as it may corrupt the stream of frames otherwise + /// being worked with. + pub fn get_mut(&mut self) -> &mut T { + &mut self.inner.inner.0 + } + + /// Consumes the `FramedRead`, returning its underlying I/O stream. + /// + /// Note that care should be taken to not tamper with the underlying stream + /// of data coming in as it may corrupt the stream of frames otherwise + /// being worked with. + pub fn into_inner(self) -> T { + self.inner.inner.0 + } + + /// Returns a reference to the underlying decoder. + pub fn decoder(&self) -> &D { + &self.inner.inner.1 + } + + /// Returns a mutable reference to the underlying decoder. + pub fn decoder_mut(&mut self) -> &mut D { + &mut self.inner.inner.1 + } +} + +impl Stream for FramedRead +where + T: AsyncRead + Unpin, + D: Decoder + Unpin, +{ + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + pin!(self.get_mut().inner).poll_next(cx) + } +} + +// This impl just defers to the underlying T: Sink +impl Sink for FramedRead +where + T: Sink + Unpin, + D: Unpin, +{ + type Error = T::Error; + + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + pin!(Pin::get_mut(self).inner.inner.0).poll_ready(cx) + } + + fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> { + pin!(Pin::get_mut(self).inner.inner.0).start_send(item) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + pin!(Pin::get_mut(self).inner.inner.0).poll_flush(cx) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + pin!(Pin::get_mut(self).inner.inner.0).poll_close(cx) + } +} + +impl fmt::Debug for FramedRead +where + T: fmt::Debug, + D: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("FramedRead") + .field("inner", &self.inner.inner.0) + .field("decoder", &self.inner.inner.1) + .field("eof", &self.inner.eof) + .field("is_readable", &self.inner.is_readable) + .field("buffer", &self.inner.buffer) + .finish() + } +} + +// ===== impl FramedRead2 ===== + +pub fn framed_read2(inner: T) -> FramedRead2 { + FramedRead2 { + inner: inner, + eof: false, + is_readable: false, + buffer: BytesMut::with_capacity(INITIAL_CAPACITY), + } +} + +pub fn framed_read2_with_buffer(inner: T, mut buf: BytesMut) -> FramedRead2 { + if buf.capacity() < INITIAL_CAPACITY { + let bytes_to_reserve = INITIAL_CAPACITY - buf.capacity(); + buf.reserve(bytes_to_reserve); + } + FramedRead2 { + inner: inner, + eof: false, + is_readable: buf.len() > 0, + buffer: buf, + } +} + +impl FramedRead2 { + pub fn get_ref(&self) -> &T { + &self.inner + } + + pub fn into_inner(self) -> T { + self.inner + } + + pub fn into_parts(self) -> (T, BytesMut) { + (self.inner, self.buffer) + } + + pub fn get_mut(&mut self) -> &mut T { + &mut self.inner + } +} + +impl Stream for FramedRead2 +where + T: AsyncRead + Decoder + Unpin, +{ + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let pinned = Pin::get_mut(self); + loop { + // Repeatedly call `decode` or `decode_eof` as long as it is + // "readable". Readable is defined as not having returned `None`. If + // the upstream has returned EOF, and the decoder is no longer + // readable, it can be assumed that the decoder will never become + // readable again, at which point the stream is terminated. + if pinned.is_readable { + if pinned.eof { + let frame = pinned.inner.decode_eof(&mut pinned.buffer)?; + return Poll::Ready(frame.map(Ok)); + } + + trace!("attempting to decode a frame"); + + if let Some(frame) = pinned.inner.decode(&mut pinned.buffer)? { + trace!("frame decoded from buffer"); + return Poll::Ready(Some(Ok(frame))); + } + + pinned.is_readable = false; + } + + assert!(!pinned.eof); + + // Otherwise, try to read more data and try again. Make sure we've + // got room for at least one byte to read to ensure that we don't + // get a spurious 0 that looks like EOF + pinned.buffer.reserve(1); + let bytect = match pin!(pinned.inner).poll_read_buf(cx, &mut pinned.buffer)? { + Poll::Ready(ct) => ct, + Poll::Pending => return Poll::Pending, + }; + if bytect == 0 { + pinned.eof = true; + } + + pinned.is_readable = true; + } + } +} diff --git a/tokio-codec/src/framed_write.rs b/tokio-codec/src/framed_write.rs new file mode 100644 index 000000000..153f58815 --- /dev/null +++ b/tokio-codec/src/framed_write.rs @@ -0,0 +1,271 @@ +#![allow(deprecated)] + +use log::trace; +use std::fmt; +use std::io::{self, Read}; + +use super::framed::Fuse; +use crate::decoder::Decoder; +use crate::encoder::Encoder; +use tokio_futures::{Sink, Stream}; +use tokio_io::{AsyncRead, AsyncWrite}; + +use bytes::BytesMut; +use std::pin::Pin; +use std::task::{Context, Poll}; + +/// A `Sink` of frames encoded to an `AsyncWrite`. +pub struct FramedWrite { + inner: FramedWrite2>, +} + +pub struct FramedWrite2 { + inner: T, + buffer: BytesMut, +} + +const INITIAL_CAPACITY: usize = 8 * 1024; +const BACKPRESSURE_BOUNDARY: usize = INITIAL_CAPACITY; + +impl FramedWrite +where + T: AsyncWrite, + E: Encoder, +{ + /// Creates a new `FramedWrite` with the given `encoder`. + pub fn new(inner: T, encoder: E) -> FramedWrite { + FramedWrite { + inner: framed_write2(Fuse(inner, encoder)), + } + } +} + +impl FramedWrite { + /// Returns a reference to the underlying I/O stream wrapped by + /// `FramedWrite`. + /// + /// Note that care should be taken to not tamper with the underlying stream + /// of data coming in as it may corrupt the stream of frames otherwise + /// being worked with. + pub fn get_ref(&self) -> &T { + &self.inner.inner.0 + } + + /// Returns a mutable reference to the underlying I/O stream wrapped by + /// `FramedWrite`. + /// + /// Note that care should be taken to not tamper with the underlying stream + /// of data coming in as it may corrupt the stream of frames otherwise + /// being worked with. + pub fn get_mut(&mut self) -> &mut T { + &mut self.inner.inner.0 + } + + /// Consumes the `FramedWrite`, returning its underlying I/O stream. + /// + /// Note that care should be taken to not tamper with the underlying stream + /// of data coming in as it may corrupt the stream of frames otherwise + /// being worked with. + pub fn into_inner(self) -> T { + self.inner.inner.0 + } + + /// Returns a reference to the underlying decoder. + pub fn encoder(&self) -> &E { + &self.inner.inner.1 + } + + /// Returns a mutable reference to the underlying decoder. + pub fn encoder_mut(&mut self) -> &mut E { + &mut self.inner.inner.1 + } +} + +// This impl just defers to the underlying FramedWrite2 +impl Sink for FramedWrite +where + T: AsyncWrite + Unpin, + E: Encoder + Unpin, + E::Error: From, +{ + type Error = E::Error; + + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + pin!(Pin::get_mut(self).inner).poll_ready(cx) + } + + fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> { + pin!(Pin::get_mut(self).inner).start_send(item) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + pin!(Pin::get_mut(self).inner).poll_flush(cx) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + pin!(Pin::get_mut(self).inner).poll_close(cx) + } +} + +impl Stream for FramedWrite +where + T: Stream + Unpin, + D: Unpin, +{ + type Item = T::Item; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(Pin::get_mut(self).get_mut()).poll_next(cx) + } +} + +impl fmt::Debug for FramedWrite +where + T: fmt::Debug, + U: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("FramedWrite") + .field("inner", &self.inner.get_ref().0) + .field("encoder", &self.inner.get_ref().1) + .field("buffer", &self.inner.buffer) + .finish() + } +} + +// ===== impl FramedWrite2 ===== + +pub fn framed_write2(inner: T) -> FramedWrite2 { + FramedWrite2 { + inner: inner, + buffer: BytesMut::with_capacity(INITIAL_CAPACITY), + } +} + +pub fn framed_write2_with_buffer(inner: T, mut buf: BytesMut) -> FramedWrite2 { + if buf.capacity() < INITIAL_CAPACITY { + let bytes_to_reserve = INITIAL_CAPACITY - buf.capacity(); + buf.reserve(bytes_to_reserve); + } + FramedWrite2 { + inner: inner, + buffer: buf, + } +} + +impl FramedWrite2 { + pub fn get_ref(&self) -> &T { + &self.inner + } + + pub fn into_inner(self) -> T { + self.inner + } + + pub fn into_parts(self) -> (T, BytesMut) { + (self.inner, self.buffer) + } + + pub fn get_mut(&mut self) -> &mut T { + &mut self.inner + } +} + +impl Sink for FramedWrite2 +where + T: AsyncWrite + Encoder + Unpin, +{ + type Error = T::Error; + + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + // If the buffer is already over 8KiB, then attempt to flush it. If after flushing it's + // *still* over 8KiB, then apply backpressure (reject the send). + if self.buffer.len() >= BACKPRESSURE_BOUNDARY { + match self.as_mut().poll_flush(cx) { + Poll::Pending => return Poll::Pending, + Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), + Poll::Ready(Ok(())) => (), + }; + + if self.buffer.len() >= BACKPRESSURE_BOUNDARY { + return Poll::Pending; + } + } + Poll::Ready(Ok(())) + } + + fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> { + let pinned = Pin::get_mut(self); + pinned.inner.encode(item, &mut pinned.buffer)?; + Ok(()) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + trace!("flushing framed transport"); + let pinned = Pin::get_mut(self); + + while !pinned.buffer.is_empty() { + trace!("writing; remaining={}", pinned.buffer.len()); + + let buf = &pinned.buffer; + let n = try_ready!(pin!(pinned.inner).poll_write(cx, &buf)); + + if n == 0 { + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::WriteZero, + "failed to \ + write frame to transport", + ) + .into())); + } + + // TODO: Add a way to `bytes` to do this w/o returning the drained data. + let _ = pinned.buffer.split_to(n); + } + + // Try flushing the underlying IO + try_ready!(pin!(pinned.inner).poll_flush(cx)); + + trace!("framed transport flushed"); + Poll::Ready(Ok(())) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let () = try_ready!(pin!(self).poll_flush(cx)); + let () = try_ready!(pin!(self.inner).poll_shutdown(cx)); + Poll::Ready(Ok(())) + } +} + +impl Decoder for FramedWrite2 { + type Item = T::Item; + type Error = T::Error; + + fn decode(&mut self, src: &mut BytesMut) -> Result, T::Error> { + self.inner.decode(src) + } + + fn decode_eof(&mut self, src: &mut BytesMut) -> Result, T::Error> { + self.inner.decode_eof(src) + } +} + +impl Read for FramedWrite2 { + fn read(&mut self, dst: &mut [u8]) -> io::Result { + self.inner.read(dst) + } +} + +impl AsyncRead for FramedWrite2 { + unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { + self.inner.prepare_uninitialized_buffer(buf) + } + + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + pin!(Pin::get_mut(self).inner).poll_read(cx, buf) + } +} diff --git a/tokio-codec/src/lib.rs b/tokio-codec/src/lib.rs index 1e22feb48..d8abd5cee 100644 --- a/tokio-codec/src/lib.rs +++ b/tokio-codec/src/lib.rs @@ -15,9 +15,21 @@ //! [`Stream`]: # //! [transports]: # +#[macro_use] +mod macros; + mod bytes_codec; +mod decoder; +mod encoder; +mod framed; +mod framed_read; +mod framed_write; mod lines_codec; pub use crate::bytes_codec::BytesCodec; +pub use crate::decoder::Decoder; +pub use crate::encoder::Encoder; +pub use crate::framed::{Framed, FramedParts}; +pub use crate::framed_read::FramedRead; +pub use crate::framed_write::FramedWrite; pub use crate::lines_codec::LinesCodec; -pub use tokio_io::_tokio_codec::{Decoder, Encoder, Framed, FramedParts, FramedRead, FramedWrite}; diff --git a/tokio-codec/src/lines_codec.rs b/tokio-codec/src/lines_codec.rs index 9422d312f..5ebfec004 100644 --- a/tokio-codec/src/lines_codec.rs +++ b/tokio-codec/src/lines_codec.rs @@ -1,6 +1,7 @@ +use crate::decoder::Decoder; +use crate::encoder::Encoder; use bytes::{BufMut, BytesMut}; -use std::{cmp, io, str, usize}; -use tokio_io::_tokio_codec::{Decoder, Encoder}; +use std::{cmp, fmt, io, str, usize}; /// A simple `Codec` implementation that splits up data into lines. #[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)] @@ -117,11 +118,9 @@ fn without_carriage_return(s: &[u8]) -> &[u8] { impl Decoder for LinesCodec { type Item = String; - // TODO: in the next breaking change, this should be changed to a custom - // error type that indicates the "max length exceeded" condition better. - type Error = io::Error; + type Error = LinesCodecError; - fn decode(&mut self, buf: &mut BytesMut) -> Result, io::Error> { + fn decode(&mut self, buf: &mut BytesMut) -> Result, LinesCodecError> { loop { // Determine how far into the buffer we'll search for a newline. If // there's no max_length set, we'll read to the end of the buffer. @@ -149,10 +148,7 @@ impl Decoder for LinesCodec { // newline, return an error and start discarding on the // next call. self.is_discarding = true; - Err(io::Error::new( - io::ErrorKind::Other, - "line length limit exceeded", - )) + Err(LinesCodecError::MaxLineLengthExceeded) } else { // We didn't find a line or reach the length limit, so the next // call will resume searching at the current offset. @@ -163,7 +159,7 @@ impl Decoder for LinesCodec { } } - fn decode_eof(&mut self, buf: &mut BytesMut) -> Result, io::Error> { + fn decode_eof(&mut self, buf: &mut BytesMut) -> Result, LinesCodecError> { Ok(match self.decode(buf)? { Some(frame) => Some(frame), None => { @@ -184,12 +180,35 @@ impl Decoder for LinesCodec { impl Encoder for LinesCodec { type Item = String; - type Error = io::Error; + type Error = LinesCodecError; - fn encode(&mut self, line: String, buf: &mut BytesMut) -> Result<(), io::Error> { + fn encode(&mut self, line: String, buf: &mut BytesMut) -> Result<(), LinesCodecError> { buf.reserve(line.len() + 1); buf.put(line); buf.put_u8(b'\n'); Ok(()) } } + +#[derive(Debug)] +pub enum LinesCodecError { + MaxLineLengthExceeded, + Io(io::Error), +} + +impl fmt::Display for LinesCodecError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + LinesCodecError::MaxLineLengthExceeded => write!(f, "max line length exceeded"), + LinesCodecError::Io(e) => write!(f, "{}", e), + } + } +} + +impl From for LinesCodecError { + fn from(e: io::Error) -> LinesCodecError { + LinesCodecError::Io(e) + } +} + +impl std::error::Error for LinesCodecError {} diff --git a/tokio-codec/src/macros.rs b/tokio-codec/src/macros.rs new file mode 100644 index 000000000..871f81822 --- /dev/null +++ b/tokio-codec/src/macros.rs @@ -0,0 +1,22 @@ +// TODO this macro should probably be somewhere in tokio-futures +/// A macro for extracting the successful type of a `Poll`. +/// +/// This macro bakes propagation of both errors and `Pending` signals by +/// returning early. +macro_rules! try_ready { + ($e:expr) => { + match $e { + std::task::Poll::Pending => return std::task::Poll::Pending, + std::task::Poll::Ready(Err(e)) => return std::task::Poll::Ready(Err(From::from(e))), + std::task::Poll::Ready(Ok(t)) => t, + } + }; +} + +/// A macro to reduce some of the boilerplate for projecting from +/// `Pin<&mut T>` to `Pin<&mut T.field>` +macro_rules! pin { + ($e:expr) => { + std::pin::Pin::new(&mut $e) + }; +} diff --git a/tokio-codec/tests/codecs.rs b/tokio-codec/tests/codecs.rs index 4256bbabf..6a5117158 100644 --- a/tokio-codec/tests/codecs.rs +++ b/tokio-codec/tests/codecs.rs @@ -118,8 +118,8 @@ fn lines_decoder_max_length() { // Line that's one character too long. This could cause an out of bounds // error if we peek at the next characters using slice indexing. - // buf.put("aaabbbc"); - // assert!(codec.decode(buf).is_err()); + buf.put("aaabbbc"); + assert!(codec.decode(buf).is_err()); } #[test] diff --git a/tokio-codec/tests/framed.rs b/tokio-codec/tests/framed.rs index b17472ffb..97f01228c 100644 --- a/tokio-codec/tests/framed.rs +++ b/tokio-codec/tests/framed.rs @@ -1,11 +1,16 @@ #![deny(warnings, rust_2018_idioms)] -use bytes::{Buf, BufMut, BytesMut, IntoBuf}; -use futures::{Future, Stream}; use std::io::{self, Read}; +use std::pin::Pin; +use std::task::{Context, Poll}; + use tokio_codec::{Decoder, Encoder, Framed, FramedParts}; +use tokio_current_thread::block_on_all; use tokio_io::AsyncRead; +use bytes::{Buf, BufMut, BytesMut, IntoBuf}; +use futures::prelude::{FutureExt, StreamExt}; + const INITIAL_CAPACITY: usize = 8 * 1024; /// Encode and decode u32 values. @@ -49,7 +54,15 @@ impl Read for DontReadIntoThis { } } -impl AsyncRead for DontReadIntoThis {} +impl AsyncRead for DontReadIntoThis { + fn poll_read( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + _buf: &mut [u8], + ) -> Poll> { + unreachable!() + } +} #[test] fn can_read_from_existing_buf() { @@ -58,12 +71,12 @@ fn can_read_from_existing_buf() { let framed = Framed::from_parts(parts); - let num = framed - .into_future() - .map(|(first_num, _)| first_num.unwrap()) - .wait() - .map_err(|e| e.0) - .unwrap(); + let num = block_on_all( + framed + .into_future() + .map(|(first_num, _)| first_num.unwrap()), + ) + .unwrap(); assert_eq!(num, 42); } diff --git a/tokio-codec/tests/framed_read.rs b/tokio-codec/tests/framed_read.rs index ee7cc4566..64d574c08 100644 --- a/tokio-codec/tests/framed_read.rs +++ b/tokio-codec/tests/framed_read.rs @@ -1,12 +1,18 @@ #![deny(warnings, rust_2018_idioms)] -use bytes::{Buf, BytesMut, IntoBuf}; -use futures::Async::{NotReady, Ready}; -use futures::Stream; use std::collections::VecDeque; use std::io::{self, Read}; +use std::pin::Pin; +use std::task::Poll::{Pending, Ready}; +use std::task::{Context, Poll}; + +use bytes::{Buf, BytesMut, IntoBuf}; +use futures::Stream; + use tokio_codec::{Decoder, FramedRead}; use tokio_io::AsyncRead; +use tokio_test::assert_ready; +use tokio_test::task::MockTask; macro_rules! mock { ($($x:expr,)*) => {{ @@ -16,6 +22,19 @@ macro_rules! mock { }}; } +macro_rules! assert_read { + ($e:expr, $n:expr) => {{ + let val = assert_ready!($e); + assert_eq!(val.unwrap().unwrap(), $n); + }}; +} + +macro_rules! pin { + ($id:ident) => { + Pin::new(&mut $id) + }; +} + struct U32Decoder; impl Decoder for U32Decoder { @@ -34,104 +53,146 @@ impl Decoder for U32Decoder { #[test] fn read_multi_frame_in_packet() { + let mut task = MockTask::new(); let mock = mock! { Ok(b"\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x02".to_vec()), }; - let mut framed = FramedRead::new(mock, U32Decoder); - assert_eq!(Ready(Some(0)), framed.poll().unwrap()); - assert_eq!(Ready(Some(1)), framed.poll().unwrap()); - assert_eq!(Ready(Some(2)), framed.poll().unwrap()); - assert_eq!(Ready(None), framed.poll().unwrap()); + + task.enter(|cx| { + assert_read!(pin!(framed).poll_next(cx), 0); + assert_read!(pin!(framed).poll_next(cx), 1); + assert_read!(pin!(framed).poll_next(cx), 2); + assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none()); + }); } #[test] fn read_multi_frame_across_packets() { + let mut task = MockTask::new(); let mock = mock! { Ok(b"\x00\x00\x00\x00".to_vec()), Ok(b"\x00\x00\x00\x01".to_vec()), Ok(b"\x00\x00\x00\x02".to_vec()), }; - let mut framed = FramedRead::new(mock, U32Decoder); - assert_eq!(Ready(Some(0)), framed.poll().unwrap()); - assert_eq!(Ready(Some(1)), framed.poll().unwrap()); - assert_eq!(Ready(Some(2)), framed.poll().unwrap()); - assert_eq!(Ready(None), framed.poll().unwrap()); + + task.enter(|cx| { + assert_read!(pin!(framed).poll_next(cx), 0); + assert_read!(pin!(framed).poll_next(cx), 1); + assert_read!(pin!(framed).poll_next(cx), 2); + assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none()); + }); } #[test] fn read_not_ready() { + let mut task = MockTask::new(); let mock = mock! { Err(io::Error::new(io::ErrorKind::WouldBlock, "")), Ok(b"\x00\x00\x00\x00".to_vec()), Ok(b"\x00\x00\x00\x01".to_vec()), }; - let mut framed = FramedRead::new(mock, U32Decoder); - assert_eq!(NotReady, framed.poll().unwrap()); - assert_eq!(Ready(Some(0)), framed.poll().unwrap()); - assert_eq!(Ready(Some(1)), framed.poll().unwrap()); - assert_eq!(Ready(None), framed.poll().unwrap()); + + task.enter(|cx| { + assert!(pin!(framed).poll_next(cx).is_pending()); + assert_read!(pin!(framed).poll_next(cx), 0); + assert_read!(pin!(framed).poll_next(cx), 1); + assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none()); + }); } #[test] fn read_partial_then_not_ready() { + let mut task = MockTask::new(); let mock = mock! { Ok(b"\x00\x00".to_vec()), Err(io::Error::new(io::ErrorKind::WouldBlock, "")), Ok(b"\x00\x00\x00\x00\x00\x01\x00\x00\x00\x02".to_vec()), }; - let mut framed = FramedRead::new(mock, U32Decoder); - assert_eq!(NotReady, framed.poll().unwrap()); - assert_eq!(Ready(Some(0)), framed.poll().unwrap()); - assert_eq!(Ready(Some(1)), framed.poll().unwrap()); - assert_eq!(Ready(Some(2)), framed.poll().unwrap()); - assert_eq!(Ready(None), framed.poll().unwrap()); + + task.enter(|cx| { + assert!(pin!(framed).poll_next(cx).is_pending()); + assert_read!(pin!(framed).poll_next(cx), 0); + assert_read!(pin!(framed).poll_next(cx), 1); + assert_read!(pin!(framed).poll_next(cx), 2); + assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none()); + }); } #[test] fn read_err() { + let mut task = MockTask::new(); let mock = mock! { Err(io::Error::new(io::ErrorKind::Other, "")), }; - let mut framed = FramedRead::new(mock, U32Decoder); - assert_eq!(io::ErrorKind::Other, framed.poll().unwrap_err().kind()); + + task.enter(|cx| { + assert_eq!( + io::ErrorKind::Other, + assert_ready!(pin!(framed).poll_next(cx)) + .unwrap() + .unwrap_err() + .kind() + ) + }); } #[test] fn read_partial_then_err() { + let mut task = MockTask::new(); let mock = mock! { Ok(b"\x00\x00".to_vec()), Err(io::Error::new(io::ErrorKind::Other, "")), }; - let mut framed = FramedRead::new(mock, U32Decoder); - assert_eq!(io::ErrorKind::Other, framed.poll().unwrap_err().kind()); + + task.enter(|cx| { + assert_eq!( + io::ErrorKind::Other, + assert_ready!(pin!(framed).poll_next(cx)) + .unwrap() + .unwrap_err() + .kind() + ) + }); } #[test] fn read_partial_would_block_then_err() { + let mut task = MockTask::new(); let mock = mock! { Ok(b"\x00\x00".to_vec()), Err(io::Error::new(io::ErrorKind::WouldBlock, "")), Err(io::Error::new(io::ErrorKind::Other, "")), }; - let mut framed = FramedRead::new(mock, U32Decoder); - assert_eq!(NotReady, framed.poll().unwrap()); - assert_eq!(io::ErrorKind::Other, framed.poll().unwrap_err().kind()); + + task.enter(|cx| { + assert!(pin!(framed).poll_next(cx).is_pending()); + assert_eq!( + io::ErrorKind::Other, + assert_ready!(pin!(framed).poll_next(cx)) + .unwrap() + .unwrap_err() + .kind() + ) + }); } #[test] fn huge_size() { + let mut task = MockTask::new(); let data = [0; 32 * 1024]; + let mut framed = FramedRead::new(Slice(&data[..]), BigDecoder); - let mut framed = FramedRead::new(&data[..], BigDecoder); - assert_eq!(Ready(Some(0)), framed.poll().unwrap()); - assert_eq!(Ready(None), framed.poll().unwrap()); + task.enter(|cx| { + assert_read!(pin!(framed).poll_next(cx), 0); + assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none()); + }); struct BigDecoder; @@ -151,15 +212,19 @@ fn huge_size() { #[test] fn data_remaining_is_error() { - let data = [0; 5]; + let mut task = MockTask::new(); + let slice = Slice(&[0; 5]); + let mut framed = FramedRead::new(slice, U32Decoder); - let mut framed = FramedRead::new(&data[..], U32Decoder); - assert_eq!(Ready(Some(0)), framed.poll().unwrap()); - assert!(framed.poll().is_err()); + task.enter(|cx| { + assert_read!(pin!(framed).poll_next(cx), 0); + assert!(assert_ready!(pin!(framed).poll_next(cx)).unwrap().is_err()); + }); } #[test] fn multi_frames_on_eof() { + let mut task = MockTask::new(); struct MyDecoder(Vec); impl Decoder for MyDecoder { @@ -180,11 +245,14 @@ fn multi_frames_on_eof() { } let mut framed = FramedRead::new(mock!(), MyDecoder(vec![0, 1, 2, 3])); - assert_eq!(Ready(Some(0)), framed.poll().unwrap()); - assert_eq!(Ready(Some(1)), framed.poll().unwrap()); - assert_eq!(Ready(Some(2)), framed.poll().unwrap()); - assert_eq!(Ready(Some(3)), framed.poll().unwrap()); - assert_eq!(Ready(None), framed.poll().unwrap()); + + task.enter(|cx| { + assert_read!(pin!(framed).poll_next(cx), 0); + assert_read!(pin!(framed).poll_next(cx), 1); + assert_read!(pin!(framed).poll_next(cx), 2); + assert_read!(pin!(framed).poll_next(cx), 3); + assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none()); + }); } // ===== Mock ====== @@ -207,4 +275,28 @@ impl Read for Mock { } } -impl AsyncRead for Mock {} +impl AsyncRead for Mock { + fn poll_read( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + match Pin::get_mut(self).read(buf) { + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Pending, + other => Ready(other), + } + } +} + +// TODO this newtype is necessary because `&[u8]` does not currently implement `AsyncRead` +struct Slice<'a>(&'a [u8]); + +impl<'a> AsyncRead for Slice<'a> { + fn poll_read( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + Ready(Pin::get_mut(self).0.read(buf)) + } +} diff --git a/tokio-codec/tests/framed_write.rs b/tokio-codec/tests/framed_write.rs index 8f7a1c12f..f29ce2900 100644 --- a/tokio-codec/tests/framed_write.rs +++ b/tokio-codec/tests/framed_write.rs @@ -1,11 +1,17 @@ #![deny(warnings, rust_2018_idioms)] use bytes::{BufMut, BytesMut}; -use futures::{Poll, Sink}; use std::collections::VecDeque; -use std::io::{self, Write}; use tokio_codec::{Encoder, FramedWrite}; +use tokio_futures::Sink; use tokio_io::AsyncWrite; +use tokio_test::assert_ready; +use tokio_test::task::MockTask; + +use std::io::{self, Write}; +use std::pin::Pin; +use std::task::Poll::{Pending, Ready}; +use std::task::{Context, Poll}; macro_rules! mock { ($($x:expr,)*) => {{ @@ -15,6 +21,12 @@ macro_rules! mock { }}; } +macro_rules! pin { + ($id:ident) => { + Pin::new(&mut $id) + }; +} + struct U32Encoder; impl Encoder for U32Encoder { @@ -31,22 +43,28 @@ impl Encoder for U32Encoder { #[test] fn write_multi_frame_in_packet() { + let mut task = MockTask::new(); let mock = mock! { Ok(b"\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x02".to_vec()), }; - let mut framed = FramedWrite::new(mock, U32Encoder); - assert!(framed.start_send(0).unwrap().is_ready()); - assert!(framed.start_send(1).unwrap().is_ready()); - assert!(framed.start_send(2).unwrap().is_ready()); - // Nothing written yet - assert_eq!(1, framed.get_ref().calls.len()); + task.enter(|cx| { + assert!(assert_ready!(pin!(framed).poll_ready(cx)).is_ok()); + assert!(pin!(framed).start_send(0).is_ok()); + assert!(assert_ready!(pin!(framed).poll_ready(cx)).is_ok()); + assert!(pin!(framed).start_send(1).is_ok()); + assert!(assert_ready!(pin!(framed).poll_ready(cx)).is_ok()); + assert!(pin!(framed).start_send(2).is_ok()); - // Flush the writes - assert!(framed.poll_complete().unwrap().is_ready()); + // Nothing written yet + assert_eq!(1, framed.get_ref().calls.len()); - assert_eq!(0, framed.get_ref().calls.len()); + // Flush the writes + assert!(assert_ready!(pin!(framed).poll_flush(cx)).is_ok()); + + assert_eq!(0, framed.get_ref().calls.len()); + }); } #[test] @@ -59,7 +77,7 @@ fn write_hits_backpressure() { Ok(b"".to_vec()), }; - for i in 0..(ITER + 1) { + for i in 0..=ITER { let mut b = BytesMut::with_capacity(4); b.put_u32_be(i as u32); @@ -70,7 +88,7 @@ fn write_hits_backpressure() { if data.len() < ITER { data.extend_from_slice(&b[..]); continue; - } + } // else fall through and create a new buffer } _ => unreachable!(), } @@ -78,27 +96,38 @@ fn write_hits_backpressure() { // Push a new new chunk mock.calls.push_back(Ok(b[..].to_vec())); } + // 1 'wouldblock', 4 * 2KB buffers, 1 b-byte buffer + assert_eq!(mock.calls.len(), 6); + let mut task = MockTask::new(); let mut framed = FramedWrite::new(mock, U32Encoder); + task.enter(|cx| { + // Send 8KB. This fills up FramedWrite2 buffer + for i in 0..ITER { + assert!(assert_ready!(pin!(framed).poll_ready(cx)).is_ok()); + assert!(pin!(framed).start_send(i as u32).is_ok()); + } - for i in 0..ITER { - assert!(framed.start_send(i as u32).unwrap().is_ready()); - } + // Now we poll_ready which forces a flush. The mock pops the front message + // and decides to block. + assert!(pin!(framed).poll_ready(cx).is_pending()); - // This should reject - assert!(!framed.start_send(ITER as u32).unwrap().is_ready()); + // We poll again, forcing another flush, which this time succeeds + // The whole 8KB buffer is flushed + assert!(assert_ready!(pin!(framed).poll_ready(cx)).is_ok()); - // This should succeed and start flushing the buffer. - assert!(framed.start_send(ITER as u32).unwrap().is_ready()); + // Send more data. This matches the final message expected by the mock + assert!(pin!(framed).start_send(ITER as u32).is_ok()); - // Flush the rest of the buffer - assert!(framed.poll_complete().unwrap().is_ready()); + // Flush the rest of the buffer + assert!(assert_ready!(pin!(framed).poll_flush(cx)).is_ok()); - // Ensure the mock is empty - assert_eq!(0, framed.get_ref().calls.len()); + // Ensure the mock is empty + assert_eq!(0, framed.get_ref().calls.len()); + }) } -// ===== Mock ====== +// // ===== Mock ====== struct Mock { calls: VecDeque>>, @@ -123,7 +152,23 @@ impl Write for Mock { } impl AsyncWrite for Mock { - fn shutdown(&mut self) -> Poll<(), io::Error> { - Ok(().into()) + fn poll_write( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + match Pin::get_mut(self).write(buf) { + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Pending, + other => Ready(other), + } + } + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + match Pin::get_mut(self).flush() { + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Pending, + other => Ready(other), + } + } + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + unimplemented!() } }