diff --git a/benches/client.rs b/benches/client.rs index 23757872..9ac3a7c6 100644 --- a/benches/client.rs +++ b/benches/client.rs @@ -20,7 +20,7 @@ use pushpin::connmgr::client::TestClient; use pushpin::core::channel; use pushpin::core::executor::Executor; use pushpin::core::reactor::Reactor; -use pushpin::future::{AsyncReadExt, AsyncSender, AsyncTcpListener, AsyncTcpStream, AsyncWriteExt}; +use pushpin::future::{AsyncReadExt, AsyncTcpListener, AsyncTcpStream, AsyncWriteExt}; use std::net::SocketAddr; use std::rc::Rc; use std::str; @@ -46,7 +46,7 @@ where executor .spawn(async move { - let s = AsyncSender::new(s); + let s = channel::AsyncSender::new(s); let listener = AsyncTcpListener::new(listener); for _ in 0..REQS_PER_ITER { diff --git a/src/connmgr/client.rs b/src/connmgr/client.rs index e2a80d7a..cb97aebe 100644 --- a/src/connmgr/client.rs +++ b/src/connmgr/client.rs @@ -24,7 +24,7 @@ use crate::connmgr::zhttppacket; use crate::connmgr::zhttpsocket::{self, SessionKey, FROM_MAX, REQ_ID_MAX}; use crate::core::arena; use crate::core::buffer::TmpBuffer; -use crate::core::channel; +use crate::core::channel::{self, AsyncLocalReceiver, AsyncLocalSender, AsyncReceiver}; use crate::core::event; use crate::core::executor::{Executor, Spawner}; use crate::core::list; @@ -33,8 +33,7 @@ use crate::core::tnetstring; use crate::core::zmq::{MultipartHeader, SpecInfo}; use crate::future::{ event_wait, select_2, select_5, select_6, select_option, yield_to_local_events, - AsyncLocalReceiver, AsyncLocalSender, AsyncReceiver, CancellationSender, CancellationToken, - Select2, Select5, Select6, Timeout, + CancellationSender, CancellationToken, Select2, Select5, Select6, Timeout, }; use arrayvec::ArrayVec; use ipnet::IpNet; diff --git a/src/connmgr/connection.rs b/src/connmgr/connection.rs index a852a82b..b2729895 100644 --- a/src/connmgr/connection.rs +++ b/src/connmgr/connection.rs @@ -47,6 +47,7 @@ use crate::core::arena; use crate::core::buffer::{ Buffer, ContiguousBuffer, LimitBufsMut, TmpBuffer, VecRingBuffer, VECTORED_MAX, }; +use crate::core::channel::{AsyncLocalReceiver, AsyncLocalSender}; use crate::core::defer::Defer; use crate::core::http1::Error as CoreHttpError; use crate::core::http1::{self, client, server, RecvStatus, SendStatus}; @@ -56,10 +57,9 @@ use crate::core::shuffle::random; use crate::core::waker::RefWakerData; use crate::core::zmq::MultipartHeader; use crate::future::{ - io_split, poll_async, select_2, select_3, select_4, select_option, AsyncLocalReceiver, - AsyncLocalSender, AsyncRead, AsyncReadExt, AsyncResolver, AsyncTcpStream, AsyncTlsStream, - AsyncWrite, AsyncWriteExt, CancellationToken, ReadHalf, Select2, Select3, Select4, - StdWriteWrapper, Timeout, TlsWaker, WriteHalf, + io_split, poll_async, select_2, select_3, select_4, select_option, AsyncRead, AsyncReadExt, + AsyncResolver, AsyncTcpStream, AsyncTlsStream, AsyncWrite, AsyncWriteExt, CancellationToken, + ReadHalf, Select2, Select3, Select4, StdWriteWrapper, Timeout, TlsWaker, WriteHalf, }; use arrayvec::{ArrayString, ArrayVec}; use ipnet::IpNet; diff --git a/src/connmgr/listener.rs b/src/connmgr/listener.rs index 284f6a4b..a87c91dc 100644 --- a/src/connmgr/listener.rs +++ b/src/connmgr/listener.rs @@ -19,10 +19,7 @@ use crate::core::channel; use crate::core::executor::Executor; use crate::core::net::{NetListener, NetStream, SocketAddr}; use crate::core::reactor::Reactor; -use crate::future::{ - select_2, select_slice, AsyncNetListener, AsyncReceiver, AsyncSender, NetAcceptFuture, Select2, - WaitWritableFuture, -}; +use crate::future::{select_2, select_slice, AsyncNetListener, NetAcceptFuture, Select2}; use log::{debug, error}; use std::cmp; use std::sync::mpsc; @@ -67,18 +64,18 @@ impl Listener { listeners: Vec, senders: Vec>, ) { - let stop = AsyncReceiver::new(stop); + let stop = channel::AsyncReceiver::new(stop); let mut listeners: Vec = listeners.into_iter().map(AsyncNetListener::new).collect(); - let mut senders: Vec> = - senders.into_iter().map(AsyncSender::new).collect(); + let mut senders: Vec> = + senders.into_iter().map(channel::AsyncSender::new).collect(); let mut listeners_pos = 0; let mut senders_pos = 0; - let mut sender_tasks_mem: Vec> = + let mut sender_tasks_mem: Vec> = Vec::with_capacity(senders.len()); let mut listener_tasks_mem: Vec = Vec::with_capacity(listeners.len()); diff --git a/src/connmgr/server.rs b/src/connmgr/server.rs index db101340..a20a6ddd 100644 --- a/src/connmgr/server.rs +++ b/src/connmgr/server.rs @@ -26,7 +26,7 @@ use crate::connmgr::zhttpsocket; use crate::connmgr::{ListenConfig, ListenSpec}; use crate::core::arena; use crate::core::buffer::TmpBuffer; -use crate::core::channel; +use crate::core::channel::{self, AsyncLocalReceiver, AsyncLocalSender, AsyncReceiver}; use crate::core::event; use crate::core::executor::{Executor, Spawner}; use crate::core::fs::{set_group, set_user}; @@ -38,9 +38,8 @@ use crate::core::waker::RefWakerData; use crate::core::zmq::SpecInfo; use crate::future::{ event_wait, select_2, select_3, select_6, select_8, select_option, yield_to_local_events, - AsyncLocalReceiver, AsyncLocalSender, AsyncReceiver, AsyncTcpStream, AsyncTlsStream, - AsyncUnixStream, CancellationSender, CancellationToken, Select2, Select3, Select6, Select8, - Timeout, TlsWaker, + AsyncTcpStream, AsyncTlsStream, AsyncUnixStream, CancellationSender, CancellationToken, + Select2, Select3, Select6, Select8, Timeout, TlsWaker, }; use arrayvec::{ArrayString, ArrayVec}; use log::{debug, error, info, warn}; diff --git a/src/connmgr/track.rs b/src/connmgr/track.rs index 654e04a4..8552445f 100644 --- a/src/connmgr/track.rs +++ b/src/connmgr/track.rs @@ -15,7 +15,7 @@ * limitations under the License. */ -use crate::future::AsyncLocalReceiver; +use crate::core::channel::AsyncLocalReceiver; use std::cell::Cell; use std::future::Future; use std::ops::Deref; diff --git a/src/connmgr/zhttpsocket.rs b/src/connmgr/zhttpsocket.rs index 60f57af7..6425fe66 100644 --- a/src/connmgr/zhttpsocket.rs +++ b/src/connmgr/zhttpsocket.rs @@ -17,7 +17,7 @@ use crate::connmgr::zhttppacket::{parse_ids, Id, ParseScratch}; use crate::core::arena; use crate::core::buffer::trim_for_display; -use crate::core::channel; +use crate::core::channel::{self, AsyncReceiver, AsyncSender, RecvFuture, WaitWritableFuture}; use crate::core::event; use crate::core::executor::Executor; use crate::core::list; @@ -25,9 +25,8 @@ use crate::core::reactor::Reactor; use crate::core::tnetstring; use crate::core::zmq::{MultipartHeader, SpecInfo, ZmqSocket}; use crate::future::{ - select_10, select_9, select_option, select_slice, AsyncReceiver, AsyncSender, AsyncZmqSocket, - RecvFuture, Select10, Select9, WaitWritableFuture, ZmqSendFuture, ZmqSendToFuture, - REGISTRATIONS_PER_CHANNEL, REGISTRATIONS_PER_ZMQSOCKET, + select_10, select_9, select_option, select_slice, AsyncZmqSocket, Select10, Select9, + ZmqSendFuture, ZmqSendToFuture, REGISTRATIONS_PER_CHANNEL, REGISTRATIONS_PER_ZMQSOCKET, }; use arrayvec::{ArrayString, ArrayVec}; use log::{debug, error, log_enabled, trace, warn}; diff --git a/src/core/channel.rs b/src/core/channel.rs index 04a0b8d8..fa958685 100644 --- a/src/core/channel.rs +++ b/src/core/channel.rs @@ -17,14 +17,19 @@ use crate::core::arena; use crate::core::event; use crate::core::list; +use crate::core::reactor::CustomEvented; +use crate::future::get_reactor; use slab::Slab; use std::cell::RefCell; use std::collections::VecDeque; +use std::future::Future; use std::mem; +use std::pin::Pin; use std::rc::Rc; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::mpsc; use std::sync::Arc; +use std::task::{Context, Poll}; pub struct Sender { sender: Option>, @@ -482,9 +487,441 @@ pub fn local_channel( (sender, receiver) } +pub struct AsyncSender { + evented: CustomEvented, + inner: Sender, +} + +impl AsyncSender { + pub fn new(s: Sender) -> Self { + let evented = CustomEvented::new( + s.get_write_registration(), + mio::Interest::WRITABLE, + &get_reactor(), + ) + .unwrap(); + + // assume we can write, unless can_send() returns false. note that + // if can_send() returns true, it doesn't mean we can actually write + evented.registration().set_ready(s.can_send()); + + Self { evented, inner: s } + } + + pub fn is_writable(&self) -> bool { + self.evented.registration().is_ready() + } + + pub fn wait_writable(&self) -> WaitWritableFuture<'_, T> { + WaitWritableFuture { s: self } + } + + pub fn try_send(&self, t: T) -> Result<(), mpsc::TrySendError> { + match self.inner.try_send(t) { + Ok(_) => { + // if can_send() returns false, then we know we can't write + if !self.inner.can_send() { + self.evented.registration().set_ready(false); + } + + Ok(()) + } + Err(mpsc::TrySendError::Full(t)) => { + self.evented.registration().set_ready(false); + + Err(mpsc::TrySendError::Full(t)) + } + Err(mpsc::TrySendError::Disconnected(t)) => Err(mpsc::TrySendError::Disconnected(t)), + } + } + + pub fn send(&self, t: T) -> SendFuture<'_, T> { + SendFuture { + s: self, + t: Some(t), + } + } +} + +pub struct AsyncReceiver { + evented: CustomEvented, + inner: Receiver, +} + +impl AsyncReceiver { + pub fn new(r: Receiver) -> Self { + let evented = CustomEvented::new( + r.get_read_registration(), + mio::Interest::READABLE, + &get_reactor(), + ) + .unwrap(); + + evented.registration().set_ready(true); + + Self { evented, inner: r } + } + + pub fn recv(&self) -> RecvFuture<'_, T> { + RecvFuture { r: self } + } +} + +pub struct AsyncLocalSender { + evented: CustomEvented, + inner: LocalSender, +} + +impl AsyncLocalSender { + pub fn new(s: LocalSender) -> Self { + let evented = CustomEvented::new_local( + s.get_write_registration(), + mio::Interest::WRITABLE, + &get_reactor(), + ) + .unwrap(); + + evented.registration().set_ready(true); + + Self { evented, inner: s } + } + + pub fn into_inner(self) -> LocalSender { + // normally, the poll registration would be deregistered when the + // sender drops, but here we are keeping the sender alive, so we need + // to explicitly deregister + self.evented + .registration() + .deregister_custom_local(self.inner.get_write_registration()) + .unwrap(); + + self.inner + } + + pub fn send(&self, t: T) -> LocalSendFuture<'_, T> { + LocalSendFuture { + s: self, + t: Some(t), + } + } + + // it's okay to run multiple instances of this future within the same + // task. see the comment on the CheckSendFuture struct + pub fn check_send(&self) -> CheckSendFuture<'_, T> { + CheckSendFuture { s: self } + } + + pub fn try_send(&self, t: T) -> Result<(), mpsc::TrySendError> { + self.inner.try_send(t) + } + + pub fn cancel(&self) { + self.inner.cancel(); + } +} + +pub struct AsyncLocalReceiver { + evented: CustomEvented, + inner: LocalReceiver, +} + +impl AsyncLocalReceiver { + pub fn new(r: LocalReceiver) -> Self { + let evented = CustomEvented::new_local( + r.get_read_registration(), + mio::Interest::READABLE, + &get_reactor(), + ) + .unwrap(); + + evented.registration().set_ready(true); + + Self { evented, inner: r } + } + + pub fn into_inner(self) -> LocalReceiver { + // normally, the poll registration would be deregistered when the + // receiver drops, but here we are keeping the receiver alive, so we + // need to explicitly deregister + self.evented + .registration() + .deregister_custom_local(self.inner.get_read_registration()) + .unwrap(); + + self.inner + } + + pub fn recv(&self) -> LocalRecvFuture<'_, T> { + LocalRecvFuture { r: self } + } +} + +pub struct WaitWritableFuture<'a, T> { + s: &'a AsyncSender, +} + +impl Future for WaitWritableFuture<'_, T> { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let f = &*self; + + f.s.evented + .registration() + .set_waker(cx.waker(), mio::Interest::WRITABLE); + + // if can_send() returns false, then we know we can't write. this + // check prevents spurious wakups of a rendezvous channel from + // indicating writability when the channel is not actually writable + if !f.s.inner.can_send() { + f.s.evented.registration().set_ready(false); + } + + if !f.s.evented.registration().is_ready() { + return Poll::Pending; + } + + Poll::Ready(()) + } +} + +impl Drop for WaitWritableFuture<'_, T> { + fn drop(&mut self) { + self.s.evented.registration().clear_waker(); + } +} + +pub struct SendFuture<'a, T> { + s: &'a AsyncSender, + t: Option, +} + +impl Future for SendFuture<'_, T> +where + T: Unpin, +{ + type Output = Result<(), mpsc::SendError>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let f = &mut *self; + + f.s.evented + .registration() + .set_waker(cx.waker(), mio::Interest::WRITABLE); + + if !f.s.evented.registration().is_ready() { + return Poll::Pending; + } + + if !f.s.evented.registration().pull_from_budget() { + return Poll::Pending; + } + + let t = f.t.take().unwrap(); + + // try_send will update the registration readiness, so we don't need + // to do that here + match f.s.try_send(t) { + Ok(()) => Poll::Ready(Ok(())), + Err(mpsc::TrySendError::Full(t)) => { + f.t = Some(t); + + Poll::Pending + } + Err(mpsc::TrySendError::Disconnected(t)) => Poll::Ready(Err(mpsc::SendError(t))), + } + } +} + +impl Drop for SendFuture<'_, T> { + fn drop(&mut self) { + self.s.evented.registration().clear_waker(); + } +} + +pub struct RecvFuture<'a, T> { + r: &'a AsyncReceiver, +} + +impl Future for RecvFuture<'_, T> { + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let f = &*self; + + f.r.evented + .registration() + .set_waker(cx.waker(), mio::Interest::READABLE); + + if !f.r.evented.registration().is_ready() { + return Poll::Pending; + } + + if !f.r.evented.registration().pull_from_budget() { + return Poll::Pending; + } + + match f.r.inner.try_recv() { + Ok(v) => Poll::Ready(Ok(v)), + Err(mpsc::TryRecvError::Empty) => { + f.r.evented.registration().set_ready(false); + + Poll::Pending + } + Err(mpsc::TryRecvError::Disconnected) => Poll::Ready(Err(mpsc::RecvError)), + } + } +} + +impl Drop for RecvFuture<'_, T> { + fn drop(&mut self) { + self.r.evented.registration().clear_waker(); + } +} + +pub struct LocalSendFuture<'a, T> { + s: &'a AsyncLocalSender, + t: Option, +} + +impl Future for LocalSendFuture<'_, T> +where + T: Unpin, +{ + type Output = Result<(), mpsc::SendError>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let f = &mut *self; + + f.s.evented + .registration() + .set_waker(cx.waker(), mio::Interest::WRITABLE); + + if !f.s.evented.registration().is_ready() { + return Poll::Pending; + } + + if !f.s.evented.registration().pull_from_budget() { + return Poll::Pending; + } + + let t = f.t.take().unwrap(); + + match f.s.inner.try_send(t) { + Ok(()) => Poll::Ready(Ok(())), + Err(mpsc::TrySendError::Full(t)) => { + f.s.evented.registration().set_ready(false); + f.t = Some(t); + + Poll::Pending + } + Err(mpsc::TrySendError::Disconnected(t)) => Poll::Ready(Err(mpsc::SendError(t))), + } + } +} + +impl Drop for LocalSendFuture<'_, T> { + fn drop(&mut self) { + self.s.inner.cancel(); + + self.s.evented.registration().clear_waker(); + } +} + +// it's okay to maintain multiple instances of this future at the same time +// within the same task. calling poll() won't negatively affect other +// instances. the drop() method clears the waker on the shared registration, +// which may look problematic. however, whenever any instance is (re-)polled, +// the waker will be reinstated. +// +// notably, these scenarios work: +// +// * creating two instances and awaiting them sequentially +// * creating two instances and selecting on them in a loop. both will +// eventually complete +// * creating one instance, polling it to pending, then creating a second +// instance and polling it to completion, then polling on the first +// instance again +pub struct CheckSendFuture<'a, T> { + s: &'a AsyncLocalSender, +} + +impl Future for CheckSendFuture<'_, T> +where + T: Unpin, +{ + type Output = (); + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let f = &mut *self; + + f.s.evented + .registration() + .set_waker(cx.waker(), mio::Interest::WRITABLE); + + if !f.s.inner.check_send() { + f.s.evented.registration().set_ready(false); + + return Poll::Pending; + } + + Poll::Ready(()) + } +} + +impl Drop for CheckSendFuture<'_, T> { + fn drop(&mut self) { + self.s.evented.registration().clear_waker(); + } +} + +pub struct LocalRecvFuture<'a, T> { + r: &'a AsyncLocalReceiver, +} + +impl Future for LocalRecvFuture<'_, T> { + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let f = &*self; + + f.r.evented + .registration() + .set_waker(cx.waker(), mio::Interest::READABLE); + + if !f.r.evented.registration().is_ready() { + return Poll::Pending; + } + + if !f.r.evented.registration().pull_from_budget() { + return Poll::Pending; + } + + match f.r.inner.try_recv() { + Ok(v) => Poll::Ready(Ok(v)), + Err(mpsc::TryRecvError::Empty) => { + f.r.evented.registration().set_ready(false); + + Poll::Pending + } + Err(mpsc::TryRecvError::Disconnected) => Poll::Ready(Err(mpsc::RecvError)), + } + } +} + +impl Drop for LocalRecvFuture<'_, T> { + fn drop(&mut self) { + self.r.evented.registration().clear_waker(); + } +} + #[cfg(test)] mod tests { use super::*; + use crate::core::executor::Executor; + use crate::core::reactor::Reactor; + use crate::future::poll_async; + use std::cell::Cell; use std::time; #[test] @@ -859,4 +1296,356 @@ mod tests { assert_eq!(receiver.try_recv(), Ok(2)); } + + #[test] + fn test_async_send_bound0() { + let reactor = Reactor::new(2); + let executor = Executor::new(2); + + let (s, r) = channel::(0); + + let s = AsyncSender::new(s); + let r = AsyncReceiver::new(r); + + executor + .spawn(async move { + s.send(1).await.unwrap(); + + assert_eq!(s.is_writable(), false); + }) + .unwrap(); + + executor.run_until_stalled(); + + assert_eq!(executor.have_tasks(), true); + + executor + .spawn(async move { + assert_eq!(r.recv().await, Ok(1)); + assert_eq!(r.recv().await, Err(mpsc::RecvError)); + }) + .unwrap(); + + executor.run(|timeout| reactor.poll(timeout)).unwrap(); + } + + #[test] + fn test_async_send_bound1() { + let reactor = Reactor::new(2); + let executor = Executor::new(1); + + let (s, r) = channel::(1); + + let s = AsyncSender::new(s); + let r = AsyncReceiver::new(r); + + executor + .spawn(async move { + s.send(1).await.unwrap(); + + assert_eq!(s.is_writable(), true); + }) + .unwrap(); + + executor.run_until_stalled(); + + assert_eq!(executor.have_tasks(), false); + + executor + .spawn(async move { + assert_eq!(r.recv().await, Ok(1)); + assert_eq!(r.recv().await, Err(mpsc::RecvError)); + }) + .unwrap(); + + executor.run(|timeout| reactor.poll(timeout)).unwrap(); + } + + #[test] + fn test_async_recv() { + let reactor = Reactor::new(2); + let executor = Executor::new(2); + + let (s, r) = channel::(0); + + let s = AsyncSender::new(s); + let r = AsyncReceiver::new(r); + + executor + .spawn(async move { + assert_eq!(r.recv().await, Ok(1)); + assert_eq!(r.recv().await, Err(mpsc::RecvError)); + }) + .unwrap(); + + executor.run_until_stalled(); + + assert_eq!(executor.have_tasks(), true); + + executor + .spawn(async move { + s.send(1).await.unwrap(); + }) + .unwrap(); + + executor.run(|timeout| reactor.poll(timeout)).unwrap(); + } + + #[test] + fn test_async_writable() { + let reactor = Reactor::new(1); + let executor = Executor::new(1); + + let (s, r) = channel::(0); + + let s = AsyncSender::new(s); + + executor + .spawn(async move { + assert_eq!(s.is_writable(), false); + + s.wait_writable().await; + }) + .unwrap(); + + executor.run_until_stalled(); + + assert_eq!(executor.have_tasks(), true); + + // attempting to receive on a rendezvous channel will make the + // sender writable + assert_eq!(r.try_recv(), Err(mpsc::TryRecvError::Empty)); + + executor.run(|timeout| reactor.poll(timeout)).unwrap(); + } + + #[test] + fn test_async_local_channel() { + let reactor = Reactor::new(2); + let executor = Executor::new(2); + + let (s, r) = local_channel::(1, 1, &reactor.local_registration_memory()); + + let s = AsyncLocalSender::new(s); + let r = AsyncLocalReceiver::new(r); + + executor + .spawn(async move { + assert_eq!(r.recv().await, Ok(1)); + assert_eq!(r.recv().await, Err(mpsc::RecvError)); + }) + .unwrap(); + + executor.run_until_stalled(); + + assert_eq!(executor.have_tasks(), true); + + executor + .spawn(async move { + s.send(1).await.unwrap(); + }) + .unwrap(); + + executor.run(|timeout| reactor.poll(timeout)).unwrap(); + } + + #[test] + fn test_async_check_send_sequential() { + // create two instances and await them sequentially + + let reactor = Reactor::new(2); + let executor = Executor::new(2); + + let (s, r) = local_channel::(1, 1, &reactor.local_registration_memory()); + + let state = Rc::new(Cell::new(0)); + + { + let state = state.clone(); + + executor + .spawn(async move { + let s = AsyncLocalSender::new(s); + + // fill the queue + s.send(1).await.unwrap(); + state.set(1); + + // create two instances and await them sequentially + + let fut1 = s.check_send(); + let fut2 = s.check_send(); + + fut1.await; + + s.send(2).await.unwrap(); + state.set(2); + + fut2.await; + + state.set(3); + }) + .unwrap(); + } + + reactor.poll_nonblocking(reactor.now()).unwrap(); + executor.run_until_stalled(); + assert_eq!(state.get(), 1); + + assert_eq!(r.try_recv(), Ok(1)); + assert_eq!(r.try_recv(), Err(mpsc::TryRecvError::Empty)); + reactor.poll_nonblocking(reactor.now()).unwrap(); + executor.run_until_stalled(); + assert_eq!(state.get(), 2); + + assert_eq!(r.try_recv(), Ok(2)); + assert_eq!(r.try_recv(), Err(mpsc::TryRecvError::Empty)); + reactor.poll_nonblocking(reactor.now()).unwrap(); + executor.run_until_stalled(); + assert_eq!(state.get(), 3); + + executor.run(|timeout| reactor.poll(timeout)).unwrap(); + } + + #[test] + fn test_async_check_send_alternating() { + // create one instance, poll it to pending, then create a second + // instance and poll it to completion, then poll the first again + + let reactor = Reactor::new(2); + let executor = Executor::new(2); + + let (s, r) = local_channel::(1, 1, &reactor.local_registration_memory()); + + let state = Rc::new(Cell::new(0)); + + { + let state = state.clone(); + + executor + .spawn(async move { + let s = AsyncLocalSender::new(s); + + // fill the queue + s.send(1).await.unwrap(); + + // create one instance + let mut fut1 = s.check_send(); + + // poll it to pending + assert_eq!(poll_async(&mut fut1).await, Poll::Pending); + state.set(1); + + // create a second instance and poll it to completion + s.check_send().await; + + s.send(2).await.unwrap(); + state.set(2); + + // poll the first again + fut1.await; + + state.set(3); + }) + .unwrap(); + } + + reactor.poll_nonblocking(reactor.now()).unwrap(); + executor.run_until_stalled(); + assert_eq!(state.get(), 1); + + assert_eq!(r.try_recv(), Ok(1)); + assert_eq!(r.try_recv(), Err(mpsc::TryRecvError::Empty)); + reactor.poll_nonblocking(reactor.now()).unwrap(); + executor.run_until_stalled(); + assert_eq!(state.get(), 2); + + assert_eq!(r.try_recv(), Ok(2)); + assert_eq!(r.try_recv(), Err(mpsc::TryRecvError::Empty)); + reactor.poll_nonblocking(reactor.now()).unwrap(); + executor.run_until_stalled(); + assert_eq!(state.get(), 3); + + executor.run(|timeout| reactor.poll(timeout)).unwrap(); + } + + #[test] + fn test_budget_unlimited() { + let reactor = Reactor::new(1); + let executor = Executor::new(1); + + let (s, r) = channel::(3); + + s.send(1).unwrap(); + s.send(2).unwrap(); + s.send(3).unwrap(); + mem::drop(s); + + let r = AsyncReceiver::new(r); + + executor + .spawn(async move { + assert_eq!(r.recv().await, Ok(1)); + assert_eq!(r.recv().await, Ok(2)); + assert_eq!(r.recv().await, Ok(3)); + assert_eq!(r.recv().await, Err(mpsc::RecvError)); + }) + .unwrap(); + + let mut park_count = 0; + + executor + .run(|timeout| { + park_count += 1; + + reactor.poll(timeout) + }) + .unwrap(); + + assert_eq!(park_count, 0); + } + + #[test] + fn test_budget_1() { + let reactor = Reactor::new(1); + let executor = Executor::new(1); + + { + let reactor = reactor.clone(); + + executor.set_pre_poll(move || { + reactor.set_budget(Some(1)); + }); + } + + let (s, r) = channel::(3); + + s.send(1).unwrap(); + s.send(2).unwrap(); + s.send(3).unwrap(); + mem::drop(s); + + let r = AsyncReceiver::new(r); + + executor + .spawn(async move { + assert_eq!(r.recv().await, Ok(1)); + assert_eq!(r.recv().await, Ok(2)); + assert_eq!(r.recv().await, Ok(3)); + assert_eq!(r.recv().await, Err(mpsc::RecvError)); + }) + .unwrap(); + + let mut park_count = 0; + + executor + .run(|timeout| { + park_count += 1; + + reactor.poll(timeout) + }) + .unwrap(); + + assert_eq!(park_count, 3); + } } diff --git a/src/future.rs b/src/future.rs index 187af5f5..e1c0d6e1 100644 --- a/src/future.rs +++ b/src/future.rs @@ -17,7 +17,6 @@ use crate::connmgr::resolver; use crate::connmgr::tls::{TlsStream, TlsStreamError, VerifyMode}; use crate::core::arena; -use crate::core::channel; use crate::core::event::{self, ReadinessExt}; use crate::core::net::{NetListener, NetStream, SocketAddr}; use crate::core::reactor::{ @@ -37,7 +36,6 @@ use std::os::fd::{FromRawFd, IntoRawFd}; use std::path::Path; use std::pin::Pin; use std::rc::Rc; -use std::sync::mpsc; use std::task::{Context, Poll, Waker}; use std::time::{Duration, Instant}; @@ -458,179 +456,10 @@ pub fn io_split(handle: &RefCell) -> (ReadHalf, } #[track_caller] -fn get_reactor() -> Reactor { +pub fn get_reactor() -> Reactor { Reactor::current().expect("no reactor in thread") } -pub struct AsyncSender { - evented: CustomEvented, - inner: channel::Sender, -} - -impl AsyncSender { - pub fn new(s: channel::Sender) -> Self { - let evented = CustomEvented::new( - s.get_write_registration(), - mio::Interest::WRITABLE, - &get_reactor(), - ) - .unwrap(); - - // assume we can write, unless can_send() returns false. note that - // if can_send() returns true, it doesn't mean we can actually write - evented.registration().set_ready(s.can_send()); - - Self { evented, inner: s } - } - - pub fn is_writable(&self) -> bool { - self.evented.registration().is_ready() - } - - pub fn wait_writable(&self) -> WaitWritableFuture<'_, T> { - WaitWritableFuture { s: self } - } - - pub fn try_send(&self, t: T) -> Result<(), mpsc::TrySendError> { - match self.inner.try_send(t) { - Ok(_) => { - // if can_send() returns false, then we know we can't write - if !self.inner.can_send() { - self.evented.registration().set_ready(false); - } - - Ok(()) - } - Err(mpsc::TrySendError::Full(t)) => { - self.evented.registration().set_ready(false); - - Err(mpsc::TrySendError::Full(t)) - } - Err(mpsc::TrySendError::Disconnected(t)) => Err(mpsc::TrySendError::Disconnected(t)), - } - } - - pub fn send(&self, t: T) -> SendFuture<'_, T> { - SendFuture { - s: self, - t: Some(t), - } - } -} - -pub struct AsyncReceiver { - evented: CustomEvented, - inner: channel::Receiver, -} - -impl AsyncReceiver { - pub fn new(r: channel::Receiver) -> Self { - let evented = CustomEvented::new( - r.get_read_registration(), - mio::Interest::READABLE, - &get_reactor(), - ) - .unwrap(); - - evented.registration().set_ready(true); - - Self { evented, inner: r } - } - - pub fn recv(&self) -> RecvFuture<'_, T> { - RecvFuture { r: self } - } -} - -pub struct AsyncLocalSender { - evented: CustomEvented, - inner: channel::LocalSender, -} - -impl AsyncLocalSender { - pub fn new(s: channel::LocalSender) -> Self { - let evented = CustomEvented::new_local( - s.get_write_registration(), - mio::Interest::WRITABLE, - &get_reactor(), - ) - .unwrap(); - - evented.registration().set_ready(true); - - Self { evented, inner: s } - } - - pub fn into_inner(self) -> channel::LocalSender { - // normally, the poll registration would be deregistered when the - // sender drops, but here we are keeping the sender alive, so we need - // to explicitly deregister - self.evented - .registration() - .deregister_custom_local(self.inner.get_write_registration()) - .unwrap(); - - self.inner - } - - pub fn send(&self, t: T) -> LocalSendFuture<'_, T> { - LocalSendFuture { - s: self, - t: Some(t), - } - } - - // it's okay to run multiple instances of this future within the same - // task. see the comment on the CheckSendFuture struct - pub fn check_send(&self) -> CheckSendFuture<'_, T> { - CheckSendFuture { s: self } - } - - pub fn try_send(&self, t: T) -> Result<(), mpsc::TrySendError> { - self.inner.try_send(t) - } - - pub fn cancel(&self) { - self.inner.cancel(); - } -} - -pub struct AsyncLocalReceiver { - evented: CustomEvented, - inner: channel::LocalReceiver, -} - -impl AsyncLocalReceiver { - pub fn new(r: channel::LocalReceiver) -> Self { - let evented = CustomEvented::new_local( - r.get_read_registration(), - mio::Interest::READABLE, - &get_reactor(), - ) - .unwrap(); - - evented.registration().set_ready(true); - - Self { evented, inner: r } - } - - pub fn into_inner(self) -> channel::LocalReceiver { - // normally, the poll registration would be deregistered when the - // receiver drops, but here we are keeping the receiver alive, so we - // need to explicitly deregister - self.evented - .registration() - .deregister_custom_local(self.inner.get_read_registration()) - .unwrap(); - - self.inner - } - - pub fn recv(&self) -> LocalRecvFuture<'_, T> { - LocalRecvFuture { r: self } - } -} - pub struct AsyncResolver<'a> { resolver: &'a resolver::Resolver, } @@ -1250,265 +1079,6 @@ impl<'a> EventWaiter<'a> { } } -pub struct WaitWritableFuture<'a, T> { - s: &'a AsyncSender, -} - -impl Future for WaitWritableFuture<'_, T> { - type Output = (); - - fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { - let f = &*self; - - f.s.evented - .registration() - .set_waker(cx.waker(), mio::Interest::WRITABLE); - - // if can_send() returns false, then we know we can't write. this - // check prevents spurious wakups of a rendezvous channel from - // indicating writability when the channel is not actually writable - if !f.s.inner.can_send() { - f.s.evented.registration().set_ready(false); - } - - if !f.s.evented.registration().is_ready() { - return Poll::Pending; - } - - Poll::Ready(()) - } -} - -impl Drop for WaitWritableFuture<'_, T> { - fn drop(&mut self) { - self.s.evented.registration().clear_waker(); - } -} - -pub struct SendFuture<'a, T> { - s: &'a AsyncSender, - t: Option, -} - -impl Future for SendFuture<'_, T> -where - T: Unpin, -{ - type Output = Result<(), mpsc::SendError>; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { - let f = &mut *self; - - f.s.evented - .registration() - .set_waker(cx.waker(), mio::Interest::WRITABLE); - - if !f.s.evented.registration().is_ready() { - return Poll::Pending; - } - - if !f.s.evented.registration().pull_from_budget() { - return Poll::Pending; - } - - let t = f.t.take().unwrap(); - - // try_send will update the registration readiness, so we don't need - // to do that here - match f.s.try_send(t) { - Ok(()) => Poll::Ready(Ok(())), - Err(mpsc::TrySendError::Full(t)) => { - f.t = Some(t); - - Poll::Pending - } - Err(mpsc::TrySendError::Disconnected(t)) => Poll::Ready(Err(mpsc::SendError(t))), - } - } -} - -impl Drop for SendFuture<'_, T> { - fn drop(&mut self) { - self.s.evented.registration().clear_waker(); - } -} - -pub struct RecvFuture<'a, T> { - r: &'a AsyncReceiver, -} - -impl Future for RecvFuture<'_, T> { - type Output = Result; - - fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { - let f = &*self; - - f.r.evented - .registration() - .set_waker(cx.waker(), mio::Interest::READABLE); - - if !f.r.evented.registration().is_ready() { - return Poll::Pending; - } - - if !f.r.evented.registration().pull_from_budget() { - return Poll::Pending; - } - - match f.r.inner.try_recv() { - Ok(v) => Poll::Ready(Ok(v)), - Err(mpsc::TryRecvError::Empty) => { - f.r.evented.registration().set_ready(false); - - Poll::Pending - } - Err(mpsc::TryRecvError::Disconnected) => Poll::Ready(Err(mpsc::RecvError)), - } - } -} - -impl Drop for RecvFuture<'_, T> { - fn drop(&mut self) { - self.r.evented.registration().clear_waker(); - } -} - -pub struct LocalSendFuture<'a, T> { - s: &'a AsyncLocalSender, - t: Option, -} - -impl Future for LocalSendFuture<'_, T> -where - T: Unpin, -{ - type Output = Result<(), mpsc::SendError>; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { - let f = &mut *self; - - f.s.evented - .registration() - .set_waker(cx.waker(), mio::Interest::WRITABLE); - - if !f.s.evented.registration().is_ready() { - return Poll::Pending; - } - - if !f.s.evented.registration().pull_from_budget() { - return Poll::Pending; - } - - let t = f.t.take().unwrap(); - - match f.s.inner.try_send(t) { - Ok(()) => Poll::Ready(Ok(())), - Err(mpsc::TrySendError::Full(t)) => { - f.s.evented.registration().set_ready(false); - f.t = Some(t); - - Poll::Pending - } - Err(mpsc::TrySendError::Disconnected(t)) => Poll::Ready(Err(mpsc::SendError(t))), - } - } -} - -impl Drop for LocalSendFuture<'_, T> { - fn drop(&mut self) { - self.s.inner.cancel(); - - self.s.evented.registration().clear_waker(); - } -} - -// it's okay to maintain multiple instances of this future at the same time -// within the same task. calling poll() won't negatively affect other -// instances. the drop() method clears the waker on the shared registration, -// which may look problematic. however, whenever any instance is (re-)polled, -// the waker will be reinstated. -// -// notably, these scenarios work: -// -// * creating two instances and awaiting them sequentially -// * creating two instances and selecting on them in a loop. both will -// eventually complete -// * creating one instance, polling it to pending, then creating a second -// instance and polling it to completion, then polling on the first -// instance again -pub struct CheckSendFuture<'a, T> { - s: &'a AsyncLocalSender, -} - -impl Future for CheckSendFuture<'_, T> -where - T: Unpin, -{ - type Output = (); - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { - let f = &mut *self; - - f.s.evented - .registration() - .set_waker(cx.waker(), mio::Interest::WRITABLE); - - if !f.s.inner.check_send() { - f.s.evented.registration().set_ready(false); - - return Poll::Pending; - } - - Poll::Ready(()) - } -} - -impl Drop for CheckSendFuture<'_, T> { - fn drop(&mut self) { - self.s.evented.registration().clear_waker(); - } -} - -pub struct LocalRecvFuture<'a, T> { - r: &'a AsyncLocalReceiver, -} - -impl Future for LocalRecvFuture<'_, T> { - type Output = Result; - - fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { - let f = &*self; - - f.r.evented - .registration() - .set_waker(cx.waker(), mio::Interest::READABLE); - - if !f.r.evented.registration().is_ready() { - return Poll::Pending; - } - - if !f.r.evented.registration().pull_from_budget() { - return Poll::Pending; - } - - match f.r.inner.try_recv() { - Ok(v) => Poll::Ready(Ok(v)), - Err(mpsc::TryRecvError::Empty) => { - f.r.evented.registration().set_ready(false); - - Poll::Pending - } - Err(mpsc::TryRecvError::Disconnected) => Poll::Ready(Err(mpsc::RecvError)), - } - } -} - -impl Drop for LocalRecvFuture<'_, T> { - fn drop(&mut self) { - self.r.evented.registration().clear_waker(); - } -} - pub struct QueryFuture { evented: Option, query: Option, @@ -2807,11 +2377,11 @@ pub fn yield_to_local_events() -> YieldToLocalEvents { mod tests { use super::*; use crate::connmgr::tls::TlsAcceptor; + use crate::core::channel; use crate::core::executor::Executor; use crate::core::zmq::SpecInfo; use std::cmp; use std::fs; - use std::mem; use std::net::IpAddr; use std::rc::Rc; use std::str; @@ -2870,278 +2440,6 @@ mod tests { fn cancel(&mut self) {} } - #[test] - fn test_channel_send_bound0() { - let reactor = Reactor::new(2); - let executor = Executor::new(2); - - let (s, r) = channel::channel::(0); - - let s = AsyncSender::new(s); - let r = AsyncReceiver::new(r); - - executor - .spawn(async move { - s.send(1).await.unwrap(); - - assert_eq!(s.is_writable(), false); - }) - .unwrap(); - - executor.run_until_stalled(); - - assert_eq!(executor.have_tasks(), true); - - executor - .spawn(async move { - assert_eq!(r.recv().await, Ok(1)); - assert_eq!(r.recv().await, Err(mpsc::RecvError)); - }) - .unwrap(); - - executor.run(|timeout| reactor.poll(timeout)).unwrap(); - } - - #[test] - fn test_channel_send_bound1() { - let reactor = Reactor::new(2); - let executor = Executor::new(1); - - let (s, r) = channel::channel::(1); - - let s = AsyncSender::new(s); - let r = AsyncReceiver::new(r); - - executor - .spawn(async move { - s.send(1).await.unwrap(); - - assert_eq!(s.is_writable(), true); - }) - .unwrap(); - - executor.run_until_stalled(); - - assert_eq!(executor.have_tasks(), false); - - executor - .spawn(async move { - assert_eq!(r.recv().await, Ok(1)); - assert_eq!(r.recv().await, Err(mpsc::RecvError)); - }) - .unwrap(); - - executor.run(|timeout| reactor.poll(timeout)).unwrap(); - } - - #[test] - fn test_channel_recv() { - let reactor = Reactor::new(2); - let executor = Executor::new(2); - - let (s, r) = channel::channel::(0); - - let s = AsyncSender::new(s); - let r = AsyncReceiver::new(r); - - executor - .spawn(async move { - assert_eq!(r.recv().await, Ok(1)); - assert_eq!(r.recv().await, Err(mpsc::RecvError)); - }) - .unwrap(); - - executor.run_until_stalled(); - - assert_eq!(executor.have_tasks(), true); - - executor - .spawn(async move { - s.send(1).await.unwrap(); - }) - .unwrap(); - - executor.run(|timeout| reactor.poll(timeout)).unwrap(); - } - - #[test] - fn test_channel_writable() { - let reactor = Reactor::new(1); - let executor = Executor::new(1); - - let (s, r) = channel::channel::(0); - - let s = AsyncSender::new(s); - - executor - .spawn(async move { - assert_eq!(s.is_writable(), false); - - s.wait_writable().await; - }) - .unwrap(); - - executor.run_until_stalled(); - - assert_eq!(executor.have_tasks(), true); - - // attempting to receive on a rendezvous channel will make the - // sender writable - assert_eq!(r.try_recv(), Err(mpsc::TryRecvError::Empty)); - - executor.run(|timeout| reactor.poll(timeout)).unwrap(); - } - - #[test] - fn test_local_channel() { - let reactor = Reactor::new(2); - let executor = Executor::new(2); - - let (s, r) = channel::local_channel::(1, 1, &reactor.local_registration_memory()); - - let s = AsyncLocalSender::new(s); - let r = AsyncLocalReceiver::new(r); - - executor - .spawn(async move { - assert_eq!(r.recv().await, Ok(1)); - assert_eq!(r.recv().await, Err(mpsc::RecvError)); - }) - .unwrap(); - - executor.run_until_stalled(); - - assert_eq!(executor.have_tasks(), true); - - executor - .spawn(async move { - s.send(1).await.unwrap(); - }) - .unwrap(); - - executor.run(|timeout| reactor.poll(timeout)).unwrap(); - } - - #[test] - fn test_check_send_sequential() { - // create two instances and await them sequentially - - let reactor = Reactor::new(2); - let executor = Executor::new(2); - - let (s, r) = channel::local_channel::(1, 1, &reactor.local_registration_memory()); - - let state = Rc::new(Cell::new(0)); - - { - let state = state.clone(); - - executor - .spawn(async move { - let s = AsyncLocalSender::new(s); - - // fill the queue - s.send(1).await.unwrap(); - state.set(1); - - // create two instances and await them sequentially - - let fut1 = s.check_send(); - let fut2 = s.check_send(); - - fut1.await; - - s.send(2).await.unwrap(); - state.set(2); - - fut2.await; - - state.set(3); - }) - .unwrap(); - } - - reactor.poll_nonblocking(reactor.now()).unwrap(); - executor.run_until_stalled(); - assert_eq!(state.get(), 1); - - assert_eq!(r.try_recv(), Ok(1)); - assert_eq!(r.try_recv(), Err(mpsc::TryRecvError::Empty)); - reactor.poll_nonblocking(reactor.now()).unwrap(); - executor.run_until_stalled(); - assert_eq!(state.get(), 2); - - assert_eq!(r.try_recv(), Ok(2)); - assert_eq!(r.try_recv(), Err(mpsc::TryRecvError::Empty)); - reactor.poll_nonblocking(reactor.now()).unwrap(); - executor.run_until_stalled(); - assert_eq!(state.get(), 3); - - executor.run(|timeout| reactor.poll(timeout)).unwrap(); - } - - #[test] - fn test_check_send_alternating() { - // create one instance, poll it to pending, then create a second - // instance and poll it to completion, then poll the first again - - let reactor = Reactor::new(2); - let executor = Executor::new(2); - - let (s, r) = channel::local_channel::(1, 1, &reactor.local_registration_memory()); - - let state = Rc::new(Cell::new(0)); - - { - let state = state.clone(); - - executor - .spawn(async move { - let s = AsyncLocalSender::new(s); - - // fill the queue - s.send(1).await.unwrap(); - - // create one instance - let mut fut1 = s.check_send(); - - // poll it to pending - assert_eq!(poll_async(&mut fut1).await, Poll::Pending); - state.set(1); - - // create a second instance and poll it to completion - s.check_send().await; - - s.send(2).await.unwrap(); - state.set(2); - - // poll the first again - fut1.await; - - state.set(3); - }) - .unwrap(); - } - - reactor.poll_nonblocking(reactor.now()).unwrap(); - executor.run_until_stalled(); - assert_eq!(state.get(), 1); - - assert_eq!(r.try_recv(), Ok(1)); - assert_eq!(r.try_recv(), Err(mpsc::TryRecvError::Empty)); - reactor.poll_nonblocking(reactor.now()).unwrap(); - executor.run_until_stalled(); - assert_eq!(state.get(), 2); - - assert_eq!(r.try_recv(), Ok(2)); - assert_eq!(r.try_recv(), Err(mpsc::TryRecvError::Empty)); - reactor.poll_nonblocking(reactor.now()).unwrap(); - executor.run_until_stalled(); - assert_eq!(state.get(), 3); - - executor.run(|timeout| reactor.poll(timeout)).unwrap(); - } - #[test] fn test_read_write() { let executor = Executor::new(1); @@ -3605,86 +2903,6 @@ mod tests { assert_eq!(msg, zmq::Message::from(&b"3"[..])); } - #[test] - fn test_budget_unlimited() { - let reactor = Reactor::new(1); - let executor = Executor::new(1); - - let (s, r) = channel::channel::(3); - - s.send(1).unwrap(); - s.send(2).unwrap(); - s.send(3).unwrap(); - mem::drop(s); - - let r = AsyncReceiver::new(r); - - executor - .spawn(async move { - assert_eq!(r.recv().await, Ok(1)); - assert_eq!(r.recv().await, Ok(2)); - assert_eq!(r.recv().await, Ok(3)); - assert_eq!(r.recv().await, Err(mpsc::RecvError)); - }) - .unwrap(); - - let mut park_count = 0; - - executor - .run(|timeout| { - park_count += 1; - - reactor.poll(timeout) - }) - .unwrap(); - - assert_eq!(park_count, 0); - } - - #[test] - fn test_budget_1() { - let reactor = Reactor::new(1); - let executor = Executor::new(1); - - { - let reactor = reactor.clone(); - - executor.set_pre_poll(move || { - reactor.set_budget(Some(1)); - }); - } - - let (s, r) = channel::channel::(3); - - s.send(1).unwrap(); - s.send(2).unwrap(); - s.send(3).unwrap(); - mem::drop(s); - - let r = AsyncReceiver::new(r); - - executor - .spawn(async move { - assert_eq!(r.recv().await, Ok(1)); - assert_eq!(r.recv().await, Ok(2)); - assert_eq!(r.recv().await, Ok(3)); - assert_eq!(r.recv().await, Err(mpsc::RecvError)); - }) - .unwrap(); - - let mut park_count = 0; - - executor - .run(|timeout| { - park_count += 1; - - reactor.poll(timeout) - }) - .unwrap(); - - assert_eq!(park_count, 3); - } - #[test] fn test_timeout() { let now = Instant::now(); @@ -3775,7 +2993,7 @@ mod tests { let (s, r) = channel::local_channel::(1, 1, &reactor.local_registration_memory()); - let s = AsyncLocalSender::new(s); + let s = channel::AsyncLocalSender::new(s); executor .spawn(async move { @@ -3864,7 +3082,7 @@ mod tests { executor .spawn(async move { - let r = AsyncLocalReceiver::new(r); + let r = channel::AsyncLocalReceiver::new(r); r.recv().await.unwrap(); state.set(1); })