diff --git a/src/bin/echo.rs b/src/bin/echo.rs index a6fb1c324..6bc76b7c7 100644 --- a/src/bin/echo.rs +++ b/src/bin/echo.rs @@ -5,11 +5,14 @@ extern crate futures_io; extern crate futures_mio; use std::env; +use std::io::{self, Read, Write}; use std::net::SocketAddr; +use std::sync::Arc; use futures::Future; -use futures_io::{copy, TaskIo}; use futures::stream::Stream; +use futures_io::copy; +use futures_mio::TcpStream; fn main() { let addr = env::args().nth(1).unwrap_or("127.0.0.1:8080".to_string()); @@ -27,17 +30,14 @@ fn main() { // Pull out the stream of incoming connections and then for each new // one spin up a new task copying data. We put the `socket` into a - // `TaskIo` structure which then allows us to `split` it into the read - // and write halves of the socket. + // `Arc` structure which then allows us to share it across the + // read/write halves with a small shim. // // Finally we use the `io::copy` future to copy all data from the // reading half onto the writing half. socket.incoming().for_each(|(socket, addr)| { - let io = TaskIo::new(socket); - let pair = io.map(|io| io.split()); - let amt = pair.and_then(|(reader, writer)| { - copy(reader, writer) - }); + let socket = Arc::new(socket); + let amt = copy(SocketIo(socket.clone()), SocketIo(socket)); // Once all that is done we print out how much we wrote, and then // critically we *forget* this future which allows it to run @@ -51,3 +51,21 @@ fn main() { }); l.run(done).unwrap(); } + +struct SocketIo(Arc); + +impl Read for SocketIo { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + (&*self.0).read(buf) + } +} + +impl Write for SocketIo { + fn write(&mut self, buf: &[u8]) -> io::Result { + (&*self.0).write(buf) + } + + fn flush(&mut self) -> io::Result<()> { + (&*self.0).flush() + } +} diff --git a/src/bin/sink.rs b/src/bin/sink.rs index 9f81b30c0..3d572fb4c 100644 --- a/src/bin/sink.rs +++ b/src/bin/sink.rs @@ -9,7 +9,6 @@ extern crate futures_io; extern crate futures_mio; use std::env; -use std::io::{self, Write}; use std::net::SocketAddr; use futures::Future; @@ -24,7 +23,7 @@ fn main() { let server = l.handle().tcp_listen(&addr).and_then(|socket| { socket.incoming().and_then(|(socket, addr)| { println!("got a socket: {}", addr); - write(socket) + write(socket).or_else(|_| Ok(())) }).for_each(|()| { println!("lost the socket"); Ok(()) @@ -34,20 +33,10 @@ fn main() { l.run(server).unwrap(); } +// TODO: this blows the stack... fn write(socket: futures_mio::TcpStream) -> IoFuture<()> { - static BUF: &'static [u8] = &[0; 64 * 1024]; - socket.into_future().map_err(|e| e.0).and_then(move |(ready, mut socket)| { - let ready = match ready { - Some(ready) => ready, - None => return futures::finished(()).boxed(), - }; - while ready.is_write() { - match socket.write(&BUF) { - Ok(_) => {} - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => break, - Err(e) => return futures::failed(e).boxed(), - } - } + static BUF: &'static [u8] = &[0; 1 * 1024 * 1024]; + futures_io::write_all(socket, BUF).and_then(|(socket, _)| { write(socket) }).boxed() } diff --git a/src/event_loop.rs b/src/event_loop.rs index 8a68013c3..0b00ba91a 100644 --- a/src/event_loop.rs +++ b/src/event_loop.rs @@ -8,9 +8,9 @@ use std::sync::atomic::{AtomicUsize, ATOMIC_USIZE_INIT, Ordering}; use std::sync::mpsc; use std::time::{Instant, Duration}; -use futures::{Future, Task, TaskHandle, Poll}; +use futures::{Future, Poll}; +use futures::task::{self, TaskHandle}; use futures::executor::{ExecuteCallback, Executor}; -use futures_io::Ready; use mio; use slab::Slab; @@ -80,7 +80,8 @@ pub struct LoopPin { struct Scheduled { source: IoSource, - waiter: Option, + reader: Option, + writer: Option, } enum TimeoutState { @@ -89,11 +90,15 @@ enum TimeoutState { Waiting(TaskHandle), } +enum Direction { + Read, + Write, +} + enum Message { AddSource(IoSource, Arc>>), DropSource(usize), - Schedule(usize, TaskHandle), - Deschedule(usize), + Schedule(usize, TaskHandle, Direction), AddTimeout(Instant, Arc>>), UpdateTimeout(TimeoutToken, TaskHandle), CancelTimeout(TimeoutToken), @@ -258,13 +263,15 @@ impl Loop { // supposed to do. If there's a waiter we get ready to notify // it, and we also or-in atomically any events that have // happened (currently read/write events). - let mut waiter = None; + let mut reader = None; + let mut writer = None; if let Some(sched) = self.dispatch.borrow_mut().get_mut(token) { - waiter = sched.waiter.take(); if event.kind().is_readable() { + reader = sched.reader.take(); sched.source.readiness.fetch_or(1, Ordering::Relaxed); } if event.kind().is_writable() { + writer = sched.writer.take(); sched.source.readiness.fetch_or(2, Ordering::Relaxed); } } else { @@ -272,8 +279,13 @@ impl Loop { } // If we actually got a waiter, then notify! - if let Some(waiter) = waiter { - self.notify_handle(waiter); + // + // TODO: don't notify the same task twice + if let Some(reader) = reader { + self.notify_handle(reader); + } + if let Some(writer) = writer { + self.notify_handle(writer); } } @@ -299,16 +311,17 @@ impl Loop { /// Method used to notify a task handle. /// - /// Note that this should be used instead fo `handle.notify()` to ensure + /// Note that this should be used instead fo `handle.unpark()` to ensure /// that the `CURRENT_LOOP` variable is set appropriately. fn notify_handle(&self, handle: TaskHandle) { - CURRENT_LOOP.set(&self, || handle.notify()); + CURRENT_LOOP.set(&self, || handle.unpark()); } fn add_source(&self, source: IoSource) -> io::Result { let sched = Scheduled { source: source, - waiter: None, + reader: None, + writer: None, }; let mut dispatch = self.dispatch.borrow_mut(); if dispatch.vacant_entry().is_none() { @@ -325,15 +338,21 @@ impl Loop { deregister(&self.io, &sched); } - fn schedule(&self, token: usize, wake: TaskHandle) { + fn schedule(&self, token: usize, wake: TaskHandle, dir: Direction) { let to_call = { let mut dispatch = self.dispatch.borrow_mut(); let sched = dispatch.get_mut(token).unwrap(); - if sched.source.readiness.load(Ordering::Relaxed) != 0 { - sched.waiter = None; + let (slot, bit) = match dir { + Direction::Read => (&mut sched.reader, 1), + Direction::Write => (&mut sched.writer, 2), + }; + let ready = sched.source.readiness.load(Ordering::SeqCst); + if ready & bit != 0 { + *slot = None; + sched.source.readiness.store(ready & !bit, Ordering::SeqCst); Some(wake) } else { - sched.waiter = Some(wake); + *slot = Some(wake); None } }; @@ -343,12 +362,6 @@ impl Loop { } } - fn deschedule(&self, token: usize) { - let mut dispatch = self.dispatch.borrow_mut(); - let sched = dispatch.get_mut(token).unwrap(); - sched.waiter = None; - } - fn add_timeout(&self, at: Instant) -> io::Result { let mut timeouts = self.timeouts.borrow_mut(); if timeouts.vacant_entry().is_none() { @@ -390,8 +403,7 @@ impl Loop { .ok().expect("interference with try_produce"); } Message::DropSource(tok) => self.drop_source(tok), - Message::Schedule(tok, wake) => self.schedule(tok, wake), - Message::Deschedule(tok) => self.deschedule(tok), + Message::Schedule(tok, wake, dir) => self.schedule(tok, wake, dir), Message::Shutdown => self.active.set(false), Message::AddTimeout(at, slot) => { @@ -475,42 +487,48 @@ impl LoopHandle { } } - /// Begin listening for events on an event loop. + /// Begin listening for read events on an event loop. /// /// Once an I/O object has been registered with the event loop through the /// `add_source` method, this method can be used with the assigned token to - /// begin awaiting notifications. + /// begin awaiting read notifications. /// - /// The `dir` argument indicates how the I/O object is expected to be - /// awaited on (either readable or writable) and the `wake` callback will be - /// invoked. Note that one the `wake` callback is invoked once it will not - /// be invoked again, it must be re-`schedule`d to continue receiving - /// notifications. + /// Currently the current task will be notified with *edge* semantics. This + /// means that whenever the underlying I/O object changes state, e.g. it was + /// not readable and now it is, then a notification will be sent. /// /// # Panics /// /// This function will panic if the event loop this handle is associated /// with has gone away, or if there is an error communicating with the event /// loop. - pub fn schedule(&self, tok: usize, task: &mut Task) { - // TODO: plumb through `&mut Task` if we're on the event loop - self.send(Message::Schedule(tok, task.handle().clone())); + /// + /// This function will also panic if there is not a currently running future + /// task. + pub fn schedule_read(&self, tok: usize) { + self.send(Message::Schedule(tok, task::park(), Direction::Read)); } - /// Stop listening for events on an event loop. + /// Begin listening for write events on an event loop. /// - /// Once a callback has been scheduled with the `schedule` method, it can be - /// unregistered from the event loop with this method. This method does not - /// guarantee that the callback will not be invoked if it hasn't already, - /// but a best effort will be made to ensure it is not called. + /// Once an I/O object has been registered with the event loop through the + /// `add_source` method, this method can be used with the assigned token to + /// begin awaiting write notifications. + /// + /// Currently the current task will be notified with *edge* semantics. This + /// means that whenever the underlying I/O object changes state, e.g. it was + /// not writable and now it is, then a notification will be sent. /// /// # Panics /// /// This function will panic if the event loop this handle is associated /// with has gone away, or if there is an error communicating with the event /// loop. - pub fn deschedule(&self, tok: usize) { - self.send(Message::Deschedule(tok)); + /// + /// This function will also panic if there is not a currently running future + /// task. + pub fn schedule_write(&self, tok: usize) { + self.send(Message::Schedule(tok, task::park(), Direction::Write)); } /// Unregister all information associated with a token on an event loop, @@ -554,9 +572,9 @@ impl LoopHandle { /// /// This method will panic if the timeout specified was not created by this /// loop handle's `add_timeout` method. - pub fn update_timeout(&self, timeout: &TimeoutToken, task: &mut Task) { + pub fn update_timeout(&self, timeout: &TimeoutToken) { let timeout = TimeoutToken { token: timeout.token }; - self.send(Message::UpdateTimeout(timeout, task.handle().clone())) + self.send(Message::UpdateTimeout(timeout, task::park())) } /// Cancel a previously added timeout. @@ -652,8 +670,8 @@ impl Future for AddSource { type Item = usize; type Error = io::Error; - fn poll(&mut self, task: &mut Task) -> Poll { - self.inner.poll(task, Loop::add_source, Message::AddSource) + fn poll(&mut self) -> Poll { + self.inner.poll(Loop::add_source, Message::AddSource) } } @@ -672,8 +690,8 @@ impl Future for AddTimeout { type Item = TimeoutToken; type Error = io::Error; - fn poll(&mut self, task: &mut Task) -> Poll { - self.inner.poll(task, Loop::add_timeout, Message::AddTimeout) + fn poll(&mut self) -> Poll { + self.inner.poll(Loop::add_timeout, Message::AddTimeout) } } @@ -715,8 +733,8 @@ impl Future for AddLoopData type Item = LoopData; type Error = io::Error; - fn poll(&mut self, task: &mut Task) -> Poll, io::Error> { - let ret = self.inner.poll(task, |_lp, f| { + fn poll(&mut self) -> Poll, io::Error> { + let ret = self.inner.poll(|_lp, f| { Ok(DropBox::new(f())) }, |f, slot| { Message::Run(Box::new(move || { @@ -777,13 +795,13 @@ impl Future for LoopData { type Item = A::Item; type Error = A::Error; - fn poll(&mut self, task: &mut Task) -> Poll { + fn poll(&mut self) -> Poll { // If we're on the right thread, then we can proceed. Otherwise we need // to go and get polled on the right thread. if let Some(inner) = self.get_mut() { - return inner.poll(task) + return inner.poll() } - task.poll_on(self.executor()); + task::poll_on(self.executor()); Poll::NotReady } } @@ -954,7 +972,7 @@ struct LoopFuture { impl LoopFuture where T: 'static, { - fn poll(&mut self, task: &mut Task, f: F, g: G) -> Poll + fn poll(&mut self, f: F, g: G) -> Poll where F: FnOnce(&Loop, U) -> io::Result, G: FnOnce(U, Arc>>) -> Message, { @@ -965,9 +983,9 @@ impl LoopFuture Ok(t) => return t.into(), Err(_) => {} } - let handle = task.handle().clone(); + let task = task::park(); *token = result.on_full(move |_| { - handle.notify(); + task.unpark(); }); return Poll::NotReady } @@ -980,10 +998,10 @@ impl LoopFuture return ret.into() } - let handle = task.handle().clone(); + let task = task::park(); let result = Arc::new(Slot::new(None)); let token = result.on_full(move |_| { - handle.notify(); + task.unpark(); }); self.result = Some((result.clone(), token)); self.loop_handle.send(g(data.take().unwrap(), result)); @@ -1033,14 +1051,8 @@ impl Source { /// The event loop will fill in this information and then inform futures /// that they're ready to go with the `schedule` method, and then the `poll` /// method can use this to figure out what happened. - pub fn take_readiness(&self) -> Option { - match self.readiness.swap(0, Ordering::SeqCst) { - 0 => None, - 1 => Some(Ready::Read), - 2 => Some(Ready::Write), - 3 => Some(Ready::ReadWrite), - _ => panic!(), - } + pub fn take_readiness(&self) -> usize { + self.readiness.swap(0, Ordering::SeqCst) } /// Gets access to the underlying I/O object. diff --git a/src/readiness_stream.rs b/src/readiness_stream.rs index efa180421..93e1eb09c 100644 --- a/src/readiness_stream.rs +++ b/src/readiness_stream.rs @@ -1,8 +1,7 @@ use std::io; +use std::sync::atomic::{AtomicUsize, Ordering}; -use futures::stream::Stream; -use futures::{Future, Task, Poll}; -use futures_io::{Ready}; +use futures::{Future, Poll}; use event_loop::{IoSource, LoopHandle, AddSource}; @@ -23,6 +22,7 @@ pub struct ReadinessStream { io_token: usize, loop_handle: LoopHandle, source: IoSource, + readiness: AtomicUsize, } pub struct ReadinessStreamNew { @@ -45,43 +45,64 @@ impl ReadinessStream { handle: Some(loop_handle), } } + + /// Tests to see if this source is ready to be read from or not. + pub fn poll_read(&self) -> Poll<(), io::Error> { + if self.readiness.load(Ordering::SeqCst) & 1 != 0 { + return Poll::Ok(()) + } + self.readiness.fetch_or(self.source.take_readiness(), Ordering::SeqCst); + if self.readiness.load(Ordering::SeqCst) & 1 != 0 { + Poll::Ok(()) + } else { + self.loop_handle.schedule_read(self.io_token); + Poll::NotReady + } + } + + /// Tests to see if this source is ready to be written to or not. + pub fn poll_write(&self) -> Poll<(), io::Error> { + if self.readiness.load(Ordering::SeqCst) & 2 != 0 { + return Poll::Ok(()) + } + self.readiness.fetch_or(self.source.take_readiness(), Ordering::SeqCst); + if self.readiness.load(Ordering::SeqCst) & 2 != 0 { + Poll::Ok(()) + } else { + self.loop_handle.schedule_write(self.io_token); + Poll::NotReady + } + } + + /// Tests to see if this source is ready to be read from or not. + pub fn need_read(&self) { + self.readiness.fetch_and(!1, Ordering::SeqCst); + self.loop_handle.schedule_read(self.io_token); + } + + /// Tests to see if this source is ready to be written to or not. + pub fn need_write(&self) { + self.readiness.fetch_and(!2, Ordering::SeqCst); + self.loop_handle.schedule_write(self.io_token); + } } impl Future for ReadinessStreamNew { type Item = ReadinessStream; type Error = io::Error; - fn poll(&mut self, task: &mut Task) -> Poll { - self.inner.poll(task).map(|token| { + fn poll(&mut self) -> Poll { + self.inner.poll().map(|token| { ReadinessStream { io_token: token, source: self.source.take().unwrap(), loop_handle: self.handle.take().unwrap(), + readiness: AtomicUsize::new(0), } }) } } -impl Stream for ReadinessStream { - type Item = Ready; - type Error = io::Error; - - fn poll(&mut self, task: &mut Task) -> Poll, io::Error> { - match self.source.take_readiness() { - None => { - self.loop_handle.schedule(self.io_token, task); - Poll::NotReady - } - Some(r) => { - if !r.is_read() || !r.is_write() { - self.loop_handle.schedule(self.io_token, task); - } - Poll::Ok(Some(r)) - } - } - } -} - impl Drop for ReadinessStream { fn drop(&mut self) { self.loop_handle.drop_source(self.io_token) diff --git a/src/tcp.rs b/src/tcp.rs index 90c768465..f004c2698 100644 --- a/src/tcp.rs +++ b/src/tcp.rs @@ -4,9 +4,9 @@ use std::mem; use std::net::{self, SocketAddr, Shutdown}; use std::sync::Arc; -use futures::stream::{self, Stream}; -use futures::{Future, IntoFuture, failed, Task, Poll}; -use futures_io::{Ready, IoFuture, IoStream}; +use futures::stream::Stream; +use futures::{Future, IntoFuture, failed, Poll}; +use futures_io::{IoFuture, IoStream}; use mio; use {ReadinessStream, LoopHandle}; @@ -71,6 +71,11 @@ impl TcpListener { .boxed() } + /// Test whether this socket is ready to be read or not. + pub fn poll_read(&self) -> Poll<(), io::Error> { + self.ready.poll_read() + } + /// Returns the local address that this listener is bound to. /// /// This can be useful, for example, when binding to port 0 to figure out @@ -85,13 +90,28 @@ impl TcpListener { /// This method returns an implementation of the `Stream` trait which /// resolves to the sockets the are accepted on this listener. pub fn incoming(self) -> IoStream<(TcpStream, SocketAddr)> { - let TcpListener { loop_handle, listener, ready } = self; + struct Incoming { + inner: TcpListener, + } - ready - .map(move |_| { - stream::iter(NonblockingIter { source: listener.clone() }.fuse()) - }) - .flatten() + impl Stream for Incoming { + type Item = (mio::tcp::TcpStream, SocketAddr); + type Error = io::Error; + + fn poll(&mut self) -> Poll, io::Error> { + match self.inner.listener.io().accept() { + Ok(Some(pair)) => Poll::Ok(Some(pair)), + Ok(None) => { + self.inner.ready.need_read(); + Poll::NotReady + } + Err(e) => Poll::Err(e), + } + } + } + + let loop_handle = self.loop_handle.clone(); + Incoming { inner: self } .and_then(move |(tcp, addr)| { let tcp = Arc::new(Source::new(tcp)); ReadinessStream::new(loop_handle.clone(), @@ -106,43 +126,12 @@ impl TcpListener { } } -struct NonblockingIter { - source: Arc>, -} - -impl Iterator for NonblockingIter { - type Item = io::Result<(mio::tcp::TcpStream, SocketAddr)>; - - fn next(&mut self) -> Option> { - match self.source.io().accept() { - Ok(Some(e)) => { - debug!("accepted connection"); - Some(Ok(e)) - } - Ok(None) => { - debug!("no connection ready"); - None - } - Err(e) => Some(Err(e)), - } - } -} - impl fmt::Debug for TcpListener { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { self.listener.io().fmt(f) } } -impl Stream for TcpListener { - type Item = Ready; - type Error = io::Error; - - fn poll(&mut self, task: &mut Task) -> Poll, io::Error> { - self.ready.poll(task) - } -} - /// An I/O object representing a TCP stream connected to a remote endpoint. /// /// A TCP stream can either be created by connecting to an endpoint or by @@ -233,6 +222,16 @@ impl TcpStream { } } + /// Test whether this socket is ready to be read or not. + pub fn poll_read(&self) -> Poll<(), io::Error> { + self.ready.poll_read() + } + + /// Test whether this socket is writey to be written to or not. + pub fn poll_write(&self) -> Poll<(), io::Error> { + self.ready.poll_write() + } + /// Returns the local address that this stream is bound to. pub fn local_addr(&self) -> io::Result { self.source.io().local_addr() @@ -273,14 +272,13 @@ impl Future for TcpStreamNew { type Item = TcpStream; type Error = io::Error; - fn poll(&mut self, task: &mut Task) -> Poll { - let mut stream = match mem::replace(self, TcpStreamNew::Empty) { + fn poll(&mut self) -> Poll { + let stream = match mem::replace(self, TcpStreamNew::Empty) { TcpStreamNew::Waiting(s) => s, TcpStreamNew::Empty => panic!("can't poll TCP stream twice"), }; - match stream.ready.poll(task) { - Poll::Ok(None) => panic!(), - Poll::Ok(Some(_)) => { + match stream.ready.poll_write() { + Poll::Ok(()) => { match stream.source.io().take_socket_error() { Ok(()) => return Poll::Ok(stream), Err(ref e) if e.kind() == ErrorKind::WouldBlock => {} @@ -298,6 +296,9 @@ impl Future for TcpStreamNew { impl Read for TcpStream { fn read(&mut self, buf: &mut [u8]) -> io::Result { let r = self.source.io().read(buf); + if is_wouldblock(&r) { + self.ready.need_read(); + } trace!("read[{:p}] {:?} on {:?}", self, r, self.source.io()); return r } @@ -306,26 +307,53 @@ impl Read for TcpStream { impl Write for TcpStream { fn write(&mut self, buf: &[u8]) -> io::Result { let r = self.source.io().write(buf); + if is_wouldblock(&r) { + self.ready.need_write(); + } trace!("write[{:p}] {:?} on {:?}", self, r, self.source.io()); return r } fn flush(&mut self) -> io::Result<()> { - self.source.io().flush() + let r = self.source.io().flush(); + if is_wouldblock(&r) { + self.ready.need_write(); + } + return r } } impl<'a> Read for &'a TcpStream { fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.source.io().read(buf) + let r = self.source.io().read(buf); + if is_wouldblock(&r) { + self.ready.need_read(); + } + return r } } impl<'a> Write for &'a TcpStream { fn write(&mut self, buf: &[u8]) -> io::Result { - self.source.io().write(buf) + let r = self.source.io().write(buf); + if is_wouldblock(&r) { + self.ready.need_write(); + } + return r } + fn flush(&mut self) -> io::Result<()> { - self.source.io().flush() + let r = self.source.io().flush(); + if is_wouldblock(&r) { + self.ready.need_write(); + } + return r + } +} + +fn is_wouldblock(r: &io::Result) -> bool { + match *r { + Ok(_) => false, + Err(ref e) => e.kind() == io::ErrorKind::WouldBlock, } } @@ -335,15 +363,6 @@ impl fmt::Debug for TcpStream { } } -impl Stream for TcpStream { - type Item = Ready; - type Error = io::Error; - - fn poll(&mut self, task: &mut Task) -> Poll, io::Error> { - self.ready.poll(task) - } -} - #[cfg(unix)] mod sys { use std::os::unix::prelude::*; diff --git a/src/timeout.rs b/src/timeout.rs index dccbd91e7..d71a51e05 100644 --- a/src/timeout.rs +++ b/src/timeout.rs @@ -1,7 +1,7 @@ use std::io; use std::time::{Duration, Instant}; -use futures::{Future, Task, Poll}; +use futures::{Future, Poll}; use futures_io::IoFuture; use LoopHandle; @@ -50,12 +50,12 @@ impl Future for Timeout { type Item = (); type Error = io::Error; - fn poll(&mut self, task: &mut Task) -> Poll<(), io::Error> { + fn poll(&mut self) -> Poll<(), io::Error> { // TODO: is this fast enough? if self.at <= Instant::now() { Poll::Ok(()) } else { - self.handle.update_timeout(&self.token, task); + self.handle.update_timeout(&self.token); Poll::NotReady } } diff --git a/src/udp.rs b/src/udp.rs index 754fc8b9c..bc71ec41c 100644 --- a/src/udp.rs +++ b/src/udp.rs @@ -3,9 +3,8 @@ use std::net::{self, SocketAddr, Ipv4Addr, Ipv6Addr}; use std::sync::Arc; use std::fmt; -use futures::stream::Stream; -use futures::{Future, failed, Task, Poll}; -use futures_io::{Ready, IoFuture}; +use futures::{Future, failed, Poll}; +use futures_io::IoFuture; use mio; use {ReadinessStream, LoopHandle}; @@ -66,6 +65,16 @@ impl UdpSocket { self.source.io().local_addr() } + /// Test whether this socket is ready to be read or not. + pub fn poll_read(&self) -> Poll<(), io::Error> { + self.ready.poll_read() + } + + /// Test whether this socket is writey to be written to or not. + pub fn poll_write(&self) -> Poll<(), io::Error> { + self.ready.poll_write() + } + /// Sends data on the socket to the given address. On success, returns the /// number of bytes written. /// @@ -74,7 +83,10 @@ impl UdpSocket { pub fn send_to(&self, buf: &[u8], target: &SocketAddr) -> io::Result { match self.source.io().send_to(buf, target) { Ok(Some(n)) => Ok(n), - Ok(None) => Err(io::Error::new(io::ErrorKind::WouldBlock, "would block")), + Ok(None) => { + self.ready.need_write(); + Err(io::Error::new(io::ErrorKind::WouldBlock, "would block")) + } Err(e) => Err(e), } } @@ -84,7 +96,10 @@ impl UdpSocket { pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { match self.source.io().recv_from(buf) { Ok(Some(n)) => Ok(n), - Ok(None) => Err(io::Error::new(io::ErrorKind::WouldBlock, "would block")), + Ok(None) => { + self.ready.need_read(); + Err(io::Error::new(io::ErrorKind::WouldBlock, "would block")) + } Err(e) => Err(e), } } @@ -236,15 +251,6 @@ impl fmt::Debug for UdpSocket { } } -impl Stream for UdpSocket { - type Item = Ready; - type Error = io::Error; - - fn poll(&mut self, task: &mut Task) -> Poll, io::Error> { - self.ready.poll(task) - } -} - #[cfg(unix)] mod sys { use std::os::unix::prelude::*; diff --git a/tests/buffered.rs b/tests/buffered.rs index 3c30847e0..a6bc89925 100644 --- a/tests/buffered.rs +++ b/tests/buffered.rs @@ -5,11 +5,11 @@ extern crate env_logger; use std::net::TcpStream; use std::thread; -use std::io::{Read, Write}; +use std::io::{Read, Write, BufReader, BufWriter}; use futures::Future; use futures::stream::Stream; -use futures_io::{BufReader, BufWriter, copy}; +use futures_io::copy; macro_rules! t { ($e:expr) => (match $e { diff --git a/tests/chain.rs b/tests/chain.rs index fb2c9b8fa..97eccb4d1 100644 --- a/tests/chain.rs +++ b/tests/chain.rs @@ -4,11 +4,11 @@ extern crate futures_mio; use std::net::TcpStream; use std::thread; -use std::io::Write; +use std::io::{Write, Read}; use futures::Future; use futures::stream::Stream; -use futures_io::{chain, read_to_end}; +use futures_io::read_to_end; macro_rules! t { ($e:expr) => (match $e { @@ -40,9 +40,7 @@ fn chain_clients() { let b = clients.next().unwrap(); let c = clients.next().unwrap(); - let d = chain(a, b); - let d = chain(d, c); - read_to_end(d, Vec::new()) + read_to_end(a.chain(b).chain(c), Vec::new()) }); let data = t!(l.run(copied)); diff --git a/tests/echo.rs b/tests/echo.rs index 6247a10a8..724886472 100644 --- a/tests/echo.rs +++ b/tests/echo.rs @@ -3,13 +3,14 @@ extern crate futures; extern crate futures_io; extern crate futures_mio; -use std::net::TcpStream; +use std::io::{self, Read, Write}; +use std::sync::Arc; use std::thread; -use std::io::{Read, Write}; use futures::Future; use futures::stream::Stream; -use futures_io::{copy, TaskIo}; +use futures_io::copy; +use futures_mio::TcpStream; macro_rules! t { ($e:expr) => (match $e { @@ -29,6 +30,8 @@ fn echo_server() { let msg = "foo bar baz"; let t = thread::spawn(move || { + use std::net::TcpStream; + let mut s = TcpStream::connect(&addr).unwrap(); for _i in 0..1024 { @@ -41,7 +44,10 @@ fn echo_server() { let clients = srv.incoming(); let client = clients.into_future().map(|e| e.0.unwrap()).map_err(|e| e.0); - let halves = client.and_then(|s| TaskIo::new(s.0)).map(|i| i.split()); + let halves = client.map(|s| { + let s = Arc::new(s.0); + (SocketIo(s.clone()), SocketIo(s)) + }); let copied = halves.and_then(|(a, b)| copy(a, b)); let amt = t!(l.run(copied)); @@ -49,3 +55,21 @@ fn echo_server() { assert_eq!(amt, msg.len() as u64 * 1024); } + +struct SocketIo(Arc); + +impl Read for SocketIo { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + (&*self.0).read(buf) + } +} + +impl Write for SocketIo { + fn write(&mut self, buf: &[u8]) -> io::Result { + (&*self.0).write(buf) + } + + fn flush(&mut self) -> io::Result<()> { + (&*self.0).flush() + } +} diff --git a/tests/limit.rs b/tests/limit.rs index d7082f459..08c97eeda 100644 --- a/tests/limit.rs +++ b/tests/limit.rs @@ -4,11 +4,11 @@ extern crate futures_mio; use std::net::TcpStream; use std::thread; -use std::io::Write; +use std::io::{Write, Read}; use futures::Future; use futures::stream::Stream; -use futures_io::{read_to_end, take}; +use futures_io::read_to_end; macro_rules! t { ($e:expr) => (match $e { @@ -34,7 +34,7 @@ fn limit() { let mut clients = clients.into_iter(); let a = clients.next().unwrap(); - read_to_end(take(a, 4), Vec::new()) + read_to_end(a.take(4), Vec::new()) }); let data = t!(l.run(copied)); diff --git a/tests/stream-buffered.rs b/tests/stream-buffered.rs index 34af93b30..86d5989a9 100644 --- a/tests/stream-buffered.rs +++ b/tests/stream-buffered.rs @@ -3,13 +3,14 @@ extern crate futures_io; extern crate futures_mio; extern crate env_logger; -use std::net::TcpStream; +use std::sync::Arc; use std::thread; -use std::io::{Read, Write}; +use std::io::{self, Read, Write}; use futures::Future; use futures::stream::Stream; -use futures_io::{copy, TaskIo}; +use futures_io::copy; +use futures_mio::TcpStream; macro_rules! t { ($e:expr) => (match $e { @@ -28,6 +29,8 @@ fn echo_server() { let addr = t!(srv.local_addr()); let t = thread::spawn(move || { + use std::net::TcpStream; + let mut s1 = t!(TcpStream::connect(&addr)); let mut s2 = t!(TcpStream::connect(&addr)); @@ -42,9 +45,9 @@ fn echo_server() { }); let future = srv.incoming() - .and_then(|s| TaskIo::new(s.0)) - .map(|i| i.split()) - .map(|(a,b)| copy(a,b).map(|_| ())) + .map(|s| Arc::new(s.0)) + .map(|i| (SocketIo(i.clone()), SocketIo(i))) + .map(|(a, b)| copy(a, b).map(|_| ())) .buffered(10) .take(2) .collect(); @@ -53,3 +56,21 @@ fn echo_server() { t.join().unwrap(); } + +struct SocketIo(Arc); + +impl Read for SocketIo { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + (&*self.0).read(buf) + } +} + +impl Write for SocketIo { + fn write(&mut self, buf: &[u8]) -> io::Result { + (&*self.0).write(buf) + } + + fn flush(&mut self) -> io::Result<()> { + (&*self.0).flush() + } +} diff --git a/tests/udp.rs b/tests/udp.rs index faf9f25df..8f1ba5a9b 100644 --- a/tests/udp.rs +++ b/tests/udp.rs @@ -1,8 +1,11 @@ extern crate futures; extern crate futures_mio; -use futures::Future; -use futures::stream::Stream; +use std::io; +use std::net::SocketAddr; + +use futures::{Future, Poll}; +use futures_mio::UdpSocket; macro_rules! t { ($e:expr) => (match $e { @@ -20,24 +23,50 @@ fn send_messages() { let a_addr = t!(a.local_addr()); let b_addr = t!(b.local_addr()); - let ((ar, a), (br, b)) = t!(l.run(a.into_future().join(b.into_future()))); - let ar = ar.unwrap(); - let br = br.unwrap(); - - assert!(ar.is_write()); - assert!(!ar.is_read()); - assert!(br.is_write()); - assert!(!br.is_read()); - - assert_eq!(t!(a.send_to(b"1234", &b_addr)), 4); - let (br, b) = t!(l.run(b.into_future())); - let br = br.unwrap(); - - assert!(br.is_read()); - - let mut buf = [0; 32]; - let (size, addr) = t!(b.recv_from(&mut buf)); - assert_eq!(size, 4); - assert_eq!(&buf[..4], b"1234"); - assert_eq!(addr, a_addr); + let send = SendMessage { socket: a, addr: b_addr }; + let recv = RecvMessage { socket: b, expected_addr: a_addr }; + t!(l.run(send.join(recv))); +} + +struct SendMessage { + socket: UdpSocket, + addr: SocketAddr, +} + +impl Future for SendMessage { + type Item = (); + type Error = io::Error; + + fn poll(&mut self) -> Poll<(), io::Error> { + match self.socket.send_to(b"1234", &self.addr) { + Ok(4) => Poll::Ok(()), + Ok(n) => panic!("didn't send 4 bytes: {}", n), + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Poll::NotReady, + Err(e) => Poll::Err(e), + } + } +} + +struct RecvMessage { + socket: UdpSocket, + expected_addr: SocketAddr, +} + +impl Future for RecvMessage { + type Item = (); + type Error = io::Error; + + fn poll(&mut self) -> Poll<(), io::Error> { + let mut buf = [0; 32]; + match self.socket.recv_from(&mut buf) { + Ok((4, addr)) => { + assert_eq!(&buf[..4], b"1234"); + assert_eq!(addr, self.expected_addr); + Poll::Ok(()) + } + Ok((n, _)) => panic!("didn't read 4 bytes: {}", n), + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Poll::NotReady, + Err(e) => Poll::Err(e), + } + } }