diff --git a/tokio-util/src/compat.rs b/tokio-util/src/compat.rs index 4e13ac52e..6a8802d96 100644 --- a/tokio-util/src/compat.rs +++ b/tokio-util/src/compat.rs @@ -13,6 +13,7 @@ pin_project! { pub struct Compat { #[pin] inner: T, + seek_pos: Option, } } @@ -80,7 +81,10 @@ impl TokioAsyncWriteCompatExt for T {} impl Compat { fn new(inner: T) -> Self { - Self { inner } + Self { + inner, + seek_pos: None, + } } /// Get a reference to the `Future`, `Stream`, `AsyncRead`, or `AsyncWrite` object @@ -216,6 +220,45 @@ where } } +impl futures_io::AsyncSeek for Compat { + fn poll_seek( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + pos: io::SeekFrom, + ) -> Poll> { + if self.seek_pos != Some(pos) { + self.as_mut().project().inner.start_seek(pos)?; + *self.as_mut().project().seek_pos = Some(pos); + } + let res = ready!(self.as_mut().project().inner.poll_complete(cx)); + *self.as_mut().project().seek_pos = None; + Poll::Ready(res.map(|p| p as u64)) + } +} + +impl tokio::io::AsyncSeek for Compat { + fn start_seek(mut self: Pin<&mut Self>, pos: io::SeekFrom) -> io::Result<()> { + *self.as_mut().project().seek_pos = Some(pos); + Ok(()) + } + + fn poll_complete(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let pos = match self.seek_pos { + None => { + // tokio 1.x AsyncSeek recommends calling poll_complete before start_seek. + // We don't have to guarantee that the value returned by + // poll_complete called without start_seek is correct, + // so we'll return 0. + return Poll::Ready(Ok(0)); + } + Some(pos) => pos, + }; + let res = ready!(self.as_mut().project().inner.poll_seek(cx, pos)); + *self.as_mut().project().seek_pos = None; + Poll::Ready(res.map(|p| p as u64)) + } +} + #[cfg(unix)] impl std::os::unix::io::AsRawFd for Compat { fn as_raw_fd(&self) -> std::os::unix::io::RawFd {