Skip to content

Commit

Permalink
feat: remove Control and ControlledConnection (#164)
Browse files Browse the repository at this point in the history
  • Loading branch information
thomaseizinger authored Jul 3, 2023
1 parent 34dcb97 commit 69404a1
Show file tree
Hide file tree
Showing 11 changed files with 451 additions and 957 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
# 0.12.0 - unreleased

- Remove `Control` and `ControlledConnection`.
Users have to move to the `poll_` functions of `Connection`.
See [PR #164](https://github.com/libp2p/rust-yamux/pull/164).

# 0.11.1

- Avoid race condition between pending frames and closing stream.
Expand Down
10 changes: 10 additions & 0 deletions test-harness/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,15 @@ anyhow = "1"
log = "0.4.17"

[dev-dependencies]
criterion = "0.5"
env_logger = "0.10"
futures = "0.3.4"
quickcheck = "1.0"
tokio = { version = "1.0", features = ["net", "rt-multi-thread", "macros", "time"] }
tokio-util = { version = "0.7", features = ["compat"] }
constrained-connection = "0.1"
futures_ringbuf = "0.4.0"

[[bench]]
name = "concurrent"
harness = false
73 changes: 16 additions & 57 deletions yamux/benches/concurrent.rs → test-harness/benches/concurrent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@

use constrained_connection::{new_unconstrained_connection, samples, Endpoint};
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use futures::{channel::mpsc, future, io::AsyncReadExt, prelude::*};
use std::iter;
use std::sync::Arc;
use test_harness::{dev_null_server, MessageSender, MessageSenderStrategy, Msg};
use tokio::{runtime::Runtime, task};
use yamux::{Config, Connection, Control, Mode};
use yamux::{Config, Connection, Mode};

criterion_group!(benches, concurrent);
criterion_main!(benches);
Expand Down Expand Up @@ -86,62 +87,20 @@ async fn oneway(
server: Endpoint,
client: Endpoint,
) {
let msg_len = data.0.len();
let (tx, rx) = mpsc::unbounded();
let server = Connection::new(server, config(), Mode::Server);
let client = Connection::new(client, config(), Mode::Client);

let server = async move {
let mut connection = Connection::new(server, config(), Mode::Server);
task::spawn(dev_null_server(server));

while let Some(Ok(mut stream)) = stream::poll_fn(|cx| connection.poll_next_inbound(cx))
.next()
.await
{
let tx = tx.clone();

task::spawn(async move {
let mut n = 0;
let mut b = vec![0; msg_len];

// Receive `nmessages` messages.
for _ in 0..nmessages {
stream.read_exact(&mut b[..]).await.unwrap();
n += b.len();
}

tx.unbounded_send(n).expect("unbounded_send");
stream.close().await.unwrap();
});
}
};
task::spawn(server);

let conn = Connection::new(client, config(), Mode::Client);
let (mut ctrl, conn) = Control::new(conn);

task::spawn(conn.for_each(|r| {
r.unwrap();
future::ready(())
}));

for _ in 0..nstreams {
let data = data.clone();
let mut ctrl = ctrl.clone();
task::spawn(async move {
let mut stream = ctrl.open_stream().await.unwrap();

// Send `nmessages` messages.
for _ in 0..nmessages {
stream.write_all(data.as_ref()).await.unwrap();
}

stream.close().await.unwrap();
});
}

let n = rx
let messages = iter::repeat(data)
.map(|b| Msg(b.0.to_vec()))
.take(nstreams)
.fold(0, |acc, n| future::ready(acc + n))
.await;
assert_eq!(n, nstreams * nmessages * msg_len);
ctrl.close().await.expect("close");
.collect(); // `MessageSender` will use 1 stream per message.
let num_streams_used = MessageSender::new(client, messages, true)
.with_message_multiplier(nmessages as u64)
.with_strategy(MessageSenderStrategy::Send)
.await
.unwrap();

assert_eq!(num_streams_used, nstreams);
}
222 changes: 210 additions & 12 deletions test-harness/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{fmt, io, mem};
use tokio::net::{TcpListener, TcpStream};
use tokio::net::{TcpListener, TcpSocket, TcpStream};
use tokio::task;
use tokio_util::compat::{Compat, TokioAsyncReadCompatExt};
use yamux::ConnectionError;
use yamux::{Config, WindowUpdateMode};
Expand All @@ -19,8 +20,9 @@ use yamux::{Connection, Mode};
pub async fn connected_peers(
server_config: Config,
client_config: Config,
buffer_sizes: Option<TcpBufferSizes>,
) -> io::Result<(Connection<Compat<TcpStream>>, Connection<Compat<TcpStream>>)> {
let (listener, addr) = bind().await?;
let (listener, addr) = bind(buffer_sizes).await?;

let server = async {
let (stream, _) = listener.accept().await?;
Expand All @@ -31,7 +33,7 @@ pub async fn connected_peers(
))
};
let client = async {
let stream = TcpStream::connect(addr).await?;
let stream = new_socket(buffer_sizes)?.connect(addr).await?;
Ok(Connection::new(
stream.compat(),
client_config,
Expand All @@ -42,12 +44,27 @@ pub async fn connected_peers(
futures::future::try_join(server, client).await
}

pub async fn bind() -> io::Result<(TcpListener, SocketAddr)> {
let i = Ipv4Addr::new(127, 0, 0, 1);
let s = SocketAddr::V4(SocketAddrV4::new(i, 0));
let l = TcpListener::bind(&s).await?;
let a = l.local_addr()?;
Ok((l, a))
pub async fn bind(buffer_sizes: Option<TcpBufferSizes>) -> io::Result<(TcpListener, SocketAddr)> {
let socket = new_socket(buffer_sizes)?;
socket.bind(SocketAddr::V4(SocketAddrV4::new(
Ipv4Addr::new(127, 0, 0, 1),
0,
)))?;

let listener = socket.listen(1024)?;
let address = listener.local_addr()?;

Ok((listener, address))
}

fn new_socket(buffer_sizes: Option<TcpBufferSizes>) -> io::Result<TcpSocket> {
let socket = TcpSocket::new_v4()?;
if let Some(size) = buffer_sizes {
socket.set_send_buffer_size(size.send)?;
socket.set_recv_buffer_size(size.recv)?;
}

Ok(socket)
}

/// For each incoming stream of `c` echo back to the sender.
Expand All @@ -67,6 +84,145 @@ where
.await
}

/// For each incoming stream of `c`, read to end but don't write back.
pub async fn dev_null_server<T>(mut c: Connection<T>) -> Result<(), ConnectionError>
where
T: AsyncRead + AsyncWrite + Unpin,
{
stream::poll_fn(|cx| c.poll_next_inbound(cx))
.try_for_each_concurrent(None, |mut stream| async move {
let mut buf = [0u8; 1024];

while let Ok(n) = stream.read(&mut buf).await {
if n == 0 {
break;
}
}

stream.close().await?;
Ok(())
})
.await
}

pub struct MessageSender<T> {
connection: Connection<T>,
pending_messages: Vec<Msg>,
worker_streams: FuturesUnordered<BoxFuture<'static, ()>>,
streams_processed: usize,
/// Whether to spawn a new task for each stream.
spawn_tasks: bool,
/// How many times to send each message on the stream
message_multiplier: u64,
strategy: MessageSenderStrategy,
}

#[derive(Copy, Clone)]
pub enum MessageSenderStrategy {
SendRecv,
Send,
}

impl<T> MessageSender<T> {
pub fn new(connection: Connection<T>, messages: Vec<Msg>, spawn_tasks: bool) -> Self {
Self {
connection,
pending_messages: messages,
worker_streams: FuturesUnordered::default(),
streams_processed: 0,
spawn_tasks,
message_multiplier: 1,
strategy: MessageSenderStrategy::SendRecv,
}
}

pub fn with_message_multiplier(mut self, multiplier: u64) -> Self {
self.message_multiplier = multiplier;
self
}

pub fn with_strategy(mut self, strategy: MessageSenderStrategy) -> Self {
self.strategy = strategy;
self
}
}

impl<T> Future for MessageSender<T>
where
T: AsyncRead + AsyncWrite + Unpin,
{
type Output = yamux::Result<usize>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.get_mut();

loop {
if this.pending_messages.is_empty() && this.worker_streams.is_empty() {
futures::ready!(this.connection.poll_close(cx)?);

return Poll::Ready(Ok(this.streams_processed));
}

if let Some(message) = this.pending_messages.pop() {
match this.connection.poll_new_outbound(cx)? {
Poll::Ready(mut stream) => {
let multiplier = this.message_multiplier;
let strategy = this.strategy;

let future = async move {
for _ in 0..multiplier {
match strategy {
MessageSenderStrategy::SendRecv => {
send_recv_message(&mut stream, &message).await.unwrap()
}
MessageSenderStrategy::Send => {
stream.write_all(&message.0).await.unwrap()
}
};
}

stream.close().await.unwrap();
};

let worker_stream_future = if this.spawn_tasks {
async { task::spawn(future).await.unwrap() }.boxed()
} else {
future.boxed()
};

this.worker_streams.push(worker_stream_future);
continue;
}
Poll::Pending => {
this.pending_messages.push(message);
}
}
}

match this.worker_streams.poll_next_unpin(cx) {
Poll::Ready(Some(())) => {
this.streams_processed += 1;
continue;
}
Poll::Ready(None) | Poll::Pending => {}
}

match this.connection.poll_next_inbound(cx)? {
Poll::Ready(Some(stream)) => {
drop(stream);
panic!("Did not expect remote to open a stream");
}
Poll::Ready(None) => {
panic!("Did not expect remote to close the connection");
}
Poll::Pending => {}
}

return Poll::Pending;
}
}
}

/// For each incoming stream, do nothing.
pub async fn noop_server(c: impl Stream<Item = Result<yamux::Stream, yamux::ConnectionError>>) {
c.for_each(|maybe_stream| {
Expand All @@ -76,13 +232,39 @@ pub async fn noop_server(c: impl Stream<Item = Result<yamux::Stream, yamux::Conn
.await;
}

pub async fn send_recv_message(stream: &mut yamux::Stream, Msg(msg): Msg) -> io::Result<()> {
/// Send and receive buffer size for a TCP socket.
#[derive(Clone, Debug, Copy)]
pub struct TcpBufferSizes {
send: u32,
recv: u32,
}

impl Arbitrary for TcpBufferSizes {
fn arbitrary(g: &mut Gen) -> Self {
let send = if bool::arbitrary(g) {
16 * 1024
} else {
32 * 1024
};

// Have receive buffer size be some multiple of send buffer size.
let recv = if bool::arbitrary(g) {
send * 2
} else {
send * 4
};

TcpBufferSizes { send, recv }
}
}

pub async fn send_recv_message(stream: &mut yamux::Stream, Msg(msg): &Msg) -> io::Result<()> {
let id = stream.id();
let (mut reader, mut writer) = AsyncReadExt::split(stream);

let len = msg.len();
let write_fut = async {
writer.write_all(&msg).await.unwrap();
writer.write_all(msg).await.unwrap();
log::debug!("C: {}: sent {} bytes", id, len);
};
let mut data = vec![0; msg.len()];
Expand All @@ -91,7 +273,23 @@ pub async fn send_recv_message(stream: &mut yamux::Stream, Msg(msg): Msg) -> io:
log::debug!("C: {}: received {} bytes", id, data.len());
};
futures::future::join(write_fut, read_fut).await;
assert_eq!(data, msg);
assert_eq!(&data, msg);

Ok(())
}

/// Send all messages, using only a single stream.
pub async fn send_on_single_stream(
mut stream: yamux::Stream,
iter: impl IntoIterator<Item = Msg>,
) -> Result<(), ConnectionError> {
log::debug!("C: new stream: {}", stream);

for msg in iter {
send_recv_message(&mut stream, &msg).await?;
}

stream.close().await?;

Ok(())
}
Expand Down
Loading

0 comments on commit 69404a1

Please sign in to comment.