diff --git a/tokio-reactor/src/lib.rs b/tokio-reactor/src/lib.rs index 0b7cf77d2..1a1b8b3f1 100644 --- a/tokio-reactor/src/lib.rs +++ b/tokio-reactor/src/lib.rs @@ -96,6 +96,12 @@ pub struct Reactor { /// and will instead use the default reactor for the execution context. #[derive(Clone)] pub struct Handle { + inner: Option, +} + +/// Like `Handle`, but never `None`. +#[derive(Clone)] +struct HandlePriv { inner: Weak, } @@ -116,6 +122,12 @@ pub struct SetFallbackError(()); #[doc(hidden)] pub type SetDefaultError = SetFallbackError; +#[test] +fn test_handle_size() { + use std::mem; + assert_eq!(mem::size_of::(), mem::size_of::()); +} + struct Inner { /// The underlying system event queue. io: mio::Poll, @@ -147,7 +159,7 @@ pub(crate) enum Direction { static HANDLE_FALLBACK: AtomicUsize = ATOMIC_USIZE_INIT; /// Tracks the reactor for the current execution context. -thread_local!(static CURRENT_REACTOR: RefCell> = RefCell::new(None)); +thread_local!(static CURRENT_REACTOR: RefCell> = RefCell::new(None)); const TOKEN_SHIFT: usize = 22; @@ -199,8 +211,17 @@ where F: FnOnce(&mut Enter) -> R CURRENT_REACTOR.with(|current| { { let mut current = current.borrow_mut(); + assert!(current.is_none(), "default Tokio reactor already set \ for execution context"); + + let handle = match handle.as_priv() { + Some(handle) => handle, + None => { + panic!("`handle` does not reference a reactor"); + } + }; + *current = Some(handle.clone()); } @@ -240,7 +261,9 @@ impl Reactor { /// to bind them to this event loop. pub fn handle(&self) -> Handle { Handle { - inner: Arc::downgrade(&self.inner), + inner: Some(HandlePriv { + inner: Arc::downgrade(&self.inner), + }), } } @@ -268,7 +291,7 @@ impl Reactor { /// then this function will also return an error. (aka if `Handle::default` /// has been called previously in this program). pub fn set_fallback(&self) -> Result<(), SetFallbackError> { - set_fallback(self.handle()) + set_fallback(self.handle().into_priv().unwrap()) } /// Performs one iteration of the event loop, blocking on waiting for events @@ -416,24 +439,83 @@ impl fmt::Debug for Reactor { impl Handle { /// Returns a handle to the current reactor. pub fn current() -> Handle { - Handle::try_current() - .unwrap_or(Handle { inner: Weak::new() }) + // TODO: Should this panic on error? + HandlePriv::try_current() + .map(|handle| Handle { + inner: Some(handle), + }) + .unwrap_or(Handle { + inner: Some(HandlePriv { + inner: Weak::new(), + }) + }) } + fn as_priv(&self) -> Option<&HandlePriv> { + self.inner.as_ref() + } + + fn into_priv(self) -> Option { + self.inner + } + + fn wakeup(&self) { + if let Some(handle) = self.as_priv() { + handle.wakeup(); + } + } +} + +impl Unpark for Handle { + fn unpark(&self) { + if let Some(ref h) = self.inner { + h.wakeup(); + } + } +} + +impl Default for Handle { + fn default() -> Handle { + Handle { inner: None } + } +} + +impl fmt::Debug for Handle { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "Handle") + } +} + +fn set_fallback(handle: HandlePriv) -> Result<(), SetFallbackError> { + unsafe { + let val = handle.into_usize(); + match HANDLE_FALLBACK.compare_exchange(0, val, SeqCst, SeqCst) { + Ok(_) => Ok(()), + Err(_) => { + drop(HandlePriv::from_usize(val)); + Err(SetFallbackError(())) + } + } + } +} + +// ===== impl HandlePriv ===== + +impl HandlePriv { /// Try to get a handle to the current reactor. /// /// Returns `Err` if no handle is found. - pub(crate) fn try_current() -> io::Result { + pub(crate) fn try_current() -> io::Result { CURRENT_REACTOR.with(|current| { match *current.borrow() { Some(ref handle) => Ok(handle.clone()), - None => Handle::fallback(), + None => HandlePriv::fallback(), } }) } /// Returns a handle to the fallback reactor. - fn fallback() -> io::Result { + fn fallback() -> io::Result { let mut fallback = HANDLE_FALLBACK.load(SeqCst); // If the fallback hasn't been previously initialized then let's spin @@ -454,8 +536,8 @@ impl Handle { // that someone was racing with this call to `Handle::default`. // They ended up winning so we'll destroy our helper thread (which // shuts down the thread) and reload the fallback. - if set_fallback(reactor.handle().clone()).is_ok() { - let ret = reactor.handle().clone(); + if set_fallback(reactor.handle().into_priv().unwrap()).is_ok() { + let ret = reactor.handle().into_priv().unwrap(); match reactor.background() { Ok(bg) => bg.forget(), @@ -476,9 +558,13 @@ impl Handle { assert!(fallback != 0); let ret = unsafe { - let handle = Handle::from_usize(fallback); + let handle = HandlePriv::from_usize(fallback); let ret = handle.clone(); + + // This prevents `handle` from being dropped and having the ref + // count decremented. drop(handle.into_usize()); + ret }; @@ -506,9 +592,9 @@ impl Handle { } } - unsafe fn from_usize(val: usize) -> Handle { + unsafe fn from_usize(val: usize) -> HandlePriv { let inner = mem::transmute::>(val);; - Handle { inner } + HandlePriv { inner } } fn inner(&self) -> Option> { @@ -516,34 +602,9 @@ impl Handle { } } -impl Unpark for Handle { - fn unpark(&self) { - self.wakeup(); - } -} - -impl Default for Handle { - fn default() -> Handle { - Handle::current() - } -} - -impl fmt::Debug for Handle { +impl fmt::Debug for HandlePriv { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "Handle") - } -} - -fn set_fallback(handle: Handle) -> Result<(), SetFallbackError> { - unsafe { - let val = handle.into_usize(); - match HANDLE_FALLBACK.compare_exchange(0, val, SeqCst, SeqCst) { - Ok(_) => Ok(()), - Err(_) => { - drop(Handle::from_usize(val)); - Err(SetFallbackError(())) - } - } + write!(f, "HandlePriv") } } diff --git a/tokio-reactor/src/poll_evented.rs b/tokio-reactor/src/poll_evented.rs index 6082df949..23b250bf6 100644 --- a/tokio-reactor/src/poll_evented.rs +++ b/tokio-reactor/src/poll_evented.rs @@ -160,7 +160,12 @@ where E: Evented /// Creates a new `PollEvented` associated with the specified reactor. pub fn new_with_handle(io: E, handle: &Handle) -> io::Result { let ret = PollEvented::new(io); - ret.inner.registration.register_with(ret.io.as_ref().unwrap(), handle)?; + + if let Some(handle) = handle.as_priv() { + ret.inner.registration + .register_with_priv(ret.io.as_ref().unwrap(), handle)?; + } + Ok(ret) } diff --git a/tokio-reactor/src/registration.rs b/tokio-reactor/src/registration.rs index 73298f9a4..278b57680 100644 --- a/tokio-reactor/src/registration.rs +++ b/tokio-reactor/src/registration.rs @@ -1,4 +1,4 @@ -use {Handle, Direction, Task}; +use {Handle, HandlePriv, Direction, Task}; use futures::{Async, Poll, task}; use mio::{self, Evented}; @@ -59,7 +59,7 @@ pub struct Registration { #[derive(Debug)] struct Inner { - handle: Handle, + handle: HandlePriv, token: usize, } @@ -117,7 +117,7 @@ impl Registration { pub fn register(&self, io: &T) -> io::Result where T: Evented, { - self.register2(io, || Handle::try_current()) + self.register2(io, || HandlePriv::try_current()) } /// Deregister the I/O resource from the reactor it is associated with. @@ -163,13 +163,24 @@ impl Registration { /// If an error is encountered during registration, `Err` is returned. pub fn register_with(&self, io: &T, handle: &Handle) -> io::Result where T: Evented, + { + self.register2(io, || { + match handle.as_priv() { + Some(handle) => Ok(handle.clone()), + None => HandlePriv::try_current(), + } + }) + } + + pub(crate) fn register_with_priv(&self, io: &T, handle: &HandlePriv) -> io::Result + where T: Evented, { self.register2(io, || Ok(handle.clone())) } fn register2(&self, io: &T, f: F) -> io::Result where T: Evented, - F: Fn() -> io::Result, + F: Fn() -> io::Result, { let mut state = self.state.load(SeqCst); @@ -434,7 +445,7 @@ unsafe impl Sync for Registration {} // ===== impl Inner ===== impl Inner { - fn new(io: &T, handle: Handle) -> (Self, io::Result<()>) + fn new(io: &T, handle: HandlePriv) -> (Self, io::Result<()>) where T: Evented, { let mut res = Ok(());