Commit b475703

fix(network): close the stream when finishing an inbound session

When the behaviour asks to finish an inbound session, the handler now drains the session's pending writes, closes the underlying stream, and reports whether the session closed cleanly or failed with an IO error.

ShahakShama committed Sep 28, 2023
1 parent 38a3c52 commit b475703

Showing 3 changed files with 147 additions and 31 deletions.
60 changes: 49 additions & 11 deletions crates/papyrus_network/src/streamed_data_protocol/handler.rs
@@ -22,7 +22,7 @@ use libp2p::swarm::handler::{
 use libp2p::swarm::{ConnectionHandler, ConnectionHandlerEvent, KeepAlive, SubstreamProtocol};
 use tracing::debug;
 
-use self::session::InboundSession;
+use self::session::{FinishReason, InboundSession};
 use super::protocol::{InboundProtocol, OutboundProtocol, PROTOCOL_NAME};
 use super::{DataBound, GenericEvent, InboundSessionId, OutboundSessionId, QueryBound, SessionId};
 use crate::messages::read_message;
@@ -104,6 +104,30 @@ impl<Query: QueryBound, Data: DataBound> Handler<Query, Data> {
     //         _ => true,
     //     })
     // }
+
+    /// Poll an inbound session, pushing any resulting events onto `pending_events`, and return
+    /// whether the inbound session has finished.
+    fn poll_inbound_session(
+        inbound_session: &mut InboundSession<Data>,
+        inbound_session_id: InboundSessionId,
+        pending_events: &mut VecDeque<HandlerEvent<Self>>,
+        cx: &mut Context<'_>,
+    ) -> bool {
+        let Poll::Ready(finish_reason) = inbound_session.poll_unpin(cx) else {
+            let has_finished = false;
+            return has_finished;
+        };
+        if let FinishReason::Error(io_error) = finish_reason {
+            pending_events.push_back(ConnectionHandlerEvent::NotifyBehaviour(
+                ToBehaviourEvent::SessionFailed {
+                    session_id: SessionId::InboundSessionId(inbound_session_id),
+                    error: SessionError::IOError(io_error),
+                },
+            ));
+        }
+        let has_finished = true;
+        has_finished
+    }
 }
 
 impl<Query: QueryBound, Data: DataBound> ConnectionHandler for Handler<Query, Data> {
@@ -141,17 +165,31 @@ impl<Query: QueryBound, Data: DataBound> ConnectionHandler for Handler<Query, Da
     > {
         // Handle inbound sessions.
         self.id_to_inbound_session.retain(|inbound_session_id, inbound_session| {
-            if let Poll::Ready(io_error) = inbound_session.poll_unpin(cx) {
-                self.pending_events.push_back(ConnectionHandlerEvent::NotifyBehaviour(
-                    ToBehaviourEvent::SessionFailed {
-                        session_id: SessionId::InboundSessionId(*inbound_session_id),
-                        error: SessionError::IOError(io_error),
-                    },
-                ));
-                return false;
+            if Self::poll_inbound_session(
+                inbound_session,
+                *inbound_session_id,
+                &mut self.pending_events,
+                cx,
+            ) {
+                let is_session_alive = false;
+                return is_session_alive;
             }
-            !(self.inbound_sessions_marked_to_end.contains(inbound_session_id)
-                && inbound_session.is_waiting())
+            if self.inbound_sessions_marked_to_end.contains(inbound_session_id)
+                && inbound_session.is_waiting()
+            {
+                inbound_session.start_closing();
+                if Self::poll_inbound_session(
+                    inbound_session,
+                    *inbound_session_id,
+                    &mut self.pending_events,
+                    cx,
+                ) {
+                    let is_session_alive = false;
+                    return is_session_alive;
+                }
+            }
+            let is_session_alive = true;
+            is_session_alive
        });
 
         // Handle outbound sessions.
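A note on the shape of this change: id_to_inbound_session stores sessions that are themselves futures, and the retain closure doubles as their poll loop, dropping a session once it resolves. Below is a minimal standalone sketch of that poll-and-prune pattern, with illustrative usize keys and () outputs standing in for the crate's real session types:

use std::collections::HashMap;
use std::task::{Context, Poll};

use futures::future::BoxFuture;
use futures::FutureExt;

/// Poll every stored session once and drop the ones that have finished.
fn prune_finished(sessions: &mut HashMap<usize, BoxFuture<'static, ()>>, cx: &mut Context<'_>) {
    sessions.retain(|_id, session| {
        // Poll::Ready means the session finished; returning false removes it.
        !matches!(session.poll_unpin(cx), Poll::Ready(()))
    });
}

The commit applies this twice per session: one poll to drive normal writes, and, for sessions marked to end that have drained their queue, a second poll right after start_closing() so the close can complete within the same poll call.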
43 changes: 35 additions & 8 deletions crates/papyrus_network/src/streamed_data_protocol/handler/session.rs
@@ -5,7 +5,7 @@ use std::pin::Pin;
 use std::task::{Context, Poll};
 
 use futures::future::BoxFuture;
-use futures::FutureExt;
+use futures::{AsyncWriteExt, FutureExt};
 use libp2p::swarm::Stream;
 use replace_with::replace_with_or_abort;
 
@@ -17,9 +17,15 @@ pub(super) struct InboundSession<Data: DataBound> {
     current_task: WriteMessageTask,
 }
 
+pub(super) enum FinishReason {
+    Error(io::Error),
+    Closed,
+}
+
 enum WriteMessageTask {
     Waiting(Stream),
     Running(BoxFuture<'static, Result<Stream, io::Error>>),
+    Closing(BoxFuture<'static, Result<(), io::Error>>),
 }
 
 impl<Data: DataBound> InboundSession<Data> {
@@ -45,11 +51,20 @@ impl<Data: DataBound> InboundSession<Data> {
             && self.pending_messages.is_empty()
     }
 
+    pub fn start_closing(&mut self) {
+        replace_with_or_abort(&mut self.current_task, |current_task| {
+            let WriteMessageTask::Waiting(mut stream) = current_task else {
+                panic!("Called start_closing while not waiting.");
+            };
+            WriteMessageTask::Closing(async move { stream.close().await }.boxed())
+        })
+    }
+
-    fn handle_waiting(&mut self, cx: &mut Context<'_>) -> Option<io::Error> {
+    fn handle_waiting(&mut self, cx: &mut Context<'_>) -> Option<FinishReason> {
         if let Some(data) = self.pending_messages.pop_front() {
             replace_with_or_abort(&mut self.current_task, |current_task| {
                 let WriteMessageTask::Waiting(mut stream) = current_task else {
-                    panic!("Called handle_waiting while running.");
+                    panic!("Called handle_waiting while not waiting.");
                 };
                 WriteMessageTask::Running(
                     async move {
@@ -64,32 +79,44 @@ impl<Data: DataBound> InboundSession<Data> {
         None
     }
 
-    fn handle_running(&mut self, cx: &mut Context<'_>) -> Option<io::Error> {
+    fn handle_running(&mut self, cx: &mut Context<'_>) -> Option<FinishReason> {
         let WriteMessageTask::Running(fut) = &mut self.current_task else {
-            panic!("Called handle_running while waiting.");
+            panic!("Called handle_running while not running.");
         };
         match fut.poll_unpin(cx) {
             Poll::Pending => None,
             Poll::Ready(Ok(stream)) => {
                 self.current_task = WriteMessageTask::Waiting(stream);
                 self.handle_waiting(cx)
             }
-            Poll::Ready(Err(io_error)) => Some(io_error),
+            Poll::Ready(Err(io_error)) => Some(FinishReason::Error(io_error)),
         }
     }
+
+    fn handle_closing(&mut self, cx: &mut Context<'_>) -> Option<FinishReason> {
+        let WriteMessageTask::Closing(fut) = &mut self.current_task else {
+            panic!("Called handle_closing while not closing.");
+        };
+        match fut.poll_unpin(cx) {
+            Poll::Pending => None,
+            Poll::Ready(Ok(())) => Some(FinishReason::Closed),
+            Poll::Ready(Err(io_error)) => Some(FinishReason::Error(io_error)),
+        }
+    }
 }
 
 impl<Data: DataBound> Future for InboundSession<Data> {
-    type Output = io::Error;
+    type Output = FinishReason;
 
     fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
         let unpinned_self = Pin::into_inner(self);
         let result = match &mut unpinned_self.current_task {
             WriteMessageTask::Running(_) => unpinned_self.handle_running(cx),
             WriteMessageTask::Waiting(_) => unpinned_self.handle_waiting(cx),
+            WriteMessageTask::Closing(_) => unpinned_self.handle_closing(cx),
         };
         match result {
-            Some(error) => Poll::Ready(error),
+            Some(finish_reason) => Poll::Ready(finish_reason),
             None => Poll::Pending,
         }
     }
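The one subtle move in this file is start_closing: the owned Stream must migrate out of the Waiting variant into a boxed close future while the session is only borrowed mutably. Here is a compilable sketch of that trick, assuming only the futures and replace_with crates this file already uses; the in-memory Cursor and all other names are illustrative stand-ins, not the crate's API:

use futures::future::BoxFuture;
use futures::io::Cursor;
use futures::{AsyncWriteExt, FutureExt};
use replace_with::replace_with_or_abort;

enum State {
    // An in-memory writer stands in for the libp2p Stream.
    Waiting(Cursor<Vec<u8>>),
    Closing(BoxFuture<'static, std::io::Result<()>>),
}

/// Move the writer out of Waiting and into a boxed close future.
fn start_closing(state: &mut State) {
    replace_with_or_abort(state, |state| {
        let State::Waiting(mut writer) = state else {
            panic!("Called start_closing while not waiting.");
        };
        State::Closing(async move { writer.close().await }.boxed())
    });
}

fn main() {
    futures::executor::block_on(async {
        let mut state = State::Waiting(Cursor::new(Vec::new()));
        start_closing(&mut state);
        if let State::Closing(close_fut) = &mut state {
            close_fut.await.expect("closing an in-memory writer should not fail");
        }
    });
}

replace_with_or_abort aborts the process rather than unwinding if the closure panics, which is what makes it safe to temporarily move the value out of &mut self; the panic! arms mark transitions the handler is expected never to request.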
75 changes: 63 additions & 12 deletions crates/papyrus_network/src/streamed_data_protocol/handler_test.rs
@@ -10,8 +10,8 @@ use futures::{select, FutureExt, Stream as StreamTrait, StreamExt};
 use libp2p::swarm::handler::{ConnectionEvent, FullyNegotiatedInbound, FullyNegotiatedOutbound};
 use libp2p::swarm::{ConnectionHandler, ConnectionHandlerEvent, Stream};
 
-use super::super::{DataBound, InboundSessionId, OutboundSessionId, QueryBound};
-use super::{Handler, HandlerEvent, RequestFromBehaviourEvent, SessionId, ToBehaviourEvent};
+use super::super::{DataBound, InboundSessionId, OutboundSessionId, QueryBound, SessionId};
+use super::{Handler, HandlerEvent, RequestFromBehaviourEvent, ToBehaviourEvent};
 use crate::messages::block::{GetBlocks, GetBlocksResponse};
 use crate::messages::{read_message, write_message};
 use crate::test_utils::{get_connected_streams, hardcoded_data};
@@ -127,12 +127,30 @@ async fn validate_request_to_swarm_new_outbound_session_to_swarm_event<
     );
 }
 
-async fn read_messages(stream: &mut Stream, num_messages: usize) -> Vec<GetBlocksResponse> {
-    let mut result = Vec::new();
-    for _ in 0..num_messages {
-        result.push(read_message::<GetBlocksResponse, _>(&mut *stream).await.unwrap().unwrap());
+async fn read_messages<Query: QueryBound, Data: DataBound>(
+    handler: Handler<Query, Data>,
+    stream: &mut Stream,
+    num_messages: usize,
+) -> Vec<GetBlocksResponse> {
+    async fn read_messages_inner(
+        stream: &mut Stream,
+        num_messages: usize,
+    ) -> Vec<GetBlocksResponse> {
+        let mut result = Vec::new();
+        for _ in 0..num_messages {
+            match read_message::<GetBlocksResponse, _>(&mut *stream).await.unwrap() {
+                Some(message) => result.push(message),
+                None => return result,
+            }
+        }
+        result
     }
-    result
+
+    let mut fused_handler = handler.fuse();
+    select! {
+        data = read_messages_inner(stream, num_messages).fuse() => data,
+        _ = fused_handler.next() => panic!("There shouldn't be another event from the handler"),
+    }
 }
 
 #[tokio::test]
Expand Down Expand Up @@ -160,14 +178,47 @@ async fn process_inbound_session() {
simulate_request_to_send_data_from_swarm(&mut handler, data.clone(), inbound_session_id);
}

let mut fused_handler = handler.fuse();
let data_received = select! {
data = read_messages(&mut outbound_stream, hardcoded_data_vec.len()).fuse() => data,
_ = fused_handler.next() => panic!("There shouldn't be another event from the handler"),
};
let data_received =
read_messages(handler, &mut outbound_stream, hardcoded_data_vec.len()).await;
assert_eq!(hardcoded_data_vec, data_received);
}

#[tokio::test]
async fn finished_inbound_session_ignores_behaviour_request_to_send_data() {
let mut handler = Handler::<GetBlocks, GetBlocksResponse>::new(
SUBSTREAM_TIMEOUT,
Arc::new(Default::default()),
);

let (inbound_stream, mut outbound_stream, _) = get_connected_streams().await;
// TODO(shahak): Change to GetBlocks::default() when the bug that forbids sending default
// messages is fixed.
let query = GetBlocks { limit: 10, ..Default::default() };
let inbound_session_id = InboundSessionId { value: 1 };

simulate_negotiated_inbound_session_from_swarm(
&mut handler,
query.clone(),
inbound_stream,
inbound_session_id,
);

// consume the new inbound session event without reading it.
handler.next().await;

simulate_request_to_finish_session(
&mut handler,
SessionId::InboundSessionId(inbound_session_id),
);

let hardcoded_data_vec = hardcoded_data();
for data in &hardcoded_data_vec {
simulate_request_to_send_data_from_swarm(&mut handler, data.clone(), inbound_session_id);
}
let data_received = read_messages(handler, &mut outbound_stream, 1).await;
assert!(data_received.is_empty());
}

#[test]
fn listen_protocol_across_multiple_handlers() {
let next_inbound_session_id = Arc::new(AtomicUsize::default());
Expand Down
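The reworked read_messages helper races the reader against the handler itself: the handler must keep being polled for its writes to make progress, and any event it emits mid-read fails the test instead of deadlocking it. A minimal version of that select! pattern, assuming the futures crate (illustrative names only):

use std::future::Future;

use futures::{select, FutureExt, Stream, StreamExt};

/// Run `work` to completion while asserting that `events` stays silent.
async fn run_while_silent<W: Future, E: Stream + Unpin>(work: W, mut events: E) -> W::Output {
    select! {
        output = work.fuse() => output,
        _ = events.next().fuse() => panic!("There shouldn't be another event from the handler"),
    }
}

This is also why finished_inbound_session_ignores_behaviour_request_to_send_data gets back an empty Vec rather than hanging: once the session closes the stream, read_message yields None and read_messages_inner returns early.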
