Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(network): close the stream when finishing an inbound session #1214

Merged
merged 1 commit into from
Sep 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 47 additions & 11 deletions crates/papyrus_network/src/streamed_data_protocol/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use libp2p::swarm::handler::{
use libp2p::swarm::{ConnectionHandler, ConnectionHandlerEvent, KeepAlive, SubstreamProtocol};
use tracing::debug;

use self::session::InboundSession;
use self::session::{FinishReason, InboundSession};
use super::protocol::{InboundProtocol, OutboundProtocol, PROTOCOL_NAME};
use super::{DataBound, GenericEvent, InboundSessionId, OutboundSessionId, QueryBound, SessionId};
use crate::messages::read_message;
Expand Down Expand Up @@ -104,6 +104,29 @@ impl<Query: QueryBound, Data: DataBound> Handler<Query, Data> {
// _ => true,
// })
// }

/// Poll an inbound session, inserting any events needed to pending_events, and return whether
/// the inbound session has finished.
fn poll_inbound_session(
inbound_session: &mut InboundSession<Data>,
inbound_session_id: InboundSessionId,
pending_events: &mut VecDeque<HandlerEvent<Self>>,
cx: &mut Context<'_>,
) -> bool {
let Poll::Ready(finish_reason) = inbound_session.poll_unpin(cx) else {
let is_session_alive = false;
return is_session_alive;
};
if let FinishReason::Error(io_error) = finish_reason {
pending_events.push_back(ConnectionHandlerEvent::NotifyBehaviour(
ToBehaviourEvent::SessionFailed {
session_id: SessionId::InboundSessionId(inbound_session_id),
error: SessionError::IOError(io_error),
},
));
}
true
}
}

impl<Query: QueryBound, Data: DataBound> ConnectionHandler for Handler<Query, Data> {
Expand Down Expand Up @@ -141,17 +164,30 @@ impl<Query: QueryBound, Data: DataBound> ConnectionHandler for Handler<Query, Da
> {
// Handle inbound sessions.
self.id_to_inbound_session.retain(|inbound_session_id, inbound_session| {
if let Poll::Ready(io_error) = inbound_session.poll_unpin(cx) {
self.pending_events.push_back(ConnectionHandlerEvent::NotifyBehaviour(
ToBehaviourEvent::SessionFailed {
session_id: SessionId::InboundSessionId(*inbound_session_id),
error: SessionError::IOError(io_error),
},
));
return false;
if Self::poll_inbound_session(
inbound_session,
*inbound_session_id,
&mut self.pending_events,
cx,
) {
let is_session_alive = false;
return is_session_alive;
}
if self.inbound_sessions_marked_to_end.contains(inbound_session_id)
&& inbound_session.is_waiting()
{
inbound_session.start_closing();
if Self::poll_inbound_session(
inbound_session,
*inbound_session_id,
&mut self.pending_events,
cx,
) {
let is_session_alive = false;
return is_session_alive;
}
}
!(self.inbound_sessions_marked_to_end.contains(inbound_session_id)
&& inbound_session.is_waiting())
true
});

// Handle outbound sessions.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::pin::Pin;
use std::task::{Context, Poll};

use futures::future::BoxFuture;
use futures::FutureExt;
use futures::{AsyncWriteExt, FutureExt};
use libp2p::swarm::Stream;
use replace_with::replace_with_or_abort;

Expand All @@ -17,9 +17,15 @@ pub(super) struct InboundSession<Data: DataBound> {
current_task: WriteMessageTask,
}

pub(super) enum FinishReason {
Error(io::Error),
Closed,
}

enum WriteMessageTask {
Waiting(Stream),
Running(BoxFuture<'static, Result<Stream, io::Error>>),
Closing(BoxFuture<'static, Result<(), io::Error>>),
}

impl<Data: DataBound> InboundSession<Data> {
Expand All @@ -45,11 +51,20 @@ impl<Data: DataBound> InboundSession<Data> {
&& self.pending_messages.is_empty()
}

fn handle_waiting(&mut self, cx: &mut Context<'_>) -> Option<io::Error> {
pub fn start_closing(&mut self) {
replace_with_or_abort(&mut self.current_task, |current_task| {
let WriteMessageTask::Waiting(mut stream) = current_task else {
panic!("Called start_closing while not waiting.");
};
WriteMessageTask::Closing(async move { stream.close().await }.boxed())
})
}

fn handle_waiting(&mut self, cx: &mut Context<'_>) -> Option<FinishReason> {
if let Some(data) = self.pending_messages.pop_front() {
replace_with_or_abort(&mut self.current_task, |current_task| {
let WriteMessageTask::Waiting(mut stream) = current_task else {
panic!("Called handle_waiting while running.");
panic!("Called handle_waiting while not waiting.");
};
WriteMessageTask::Running(
async move {
Expand All @@ -64,32 +79,44 @@ impl<Data: DataBound> InboundSession<Data> {
None
}

fn handle_running(&mut self, cx: &mut Context<'_>) -> Option<io::Error> {
fn handle_running(&mut self, cx: &mut Context<'_>) -> Option<FinishReason> {
let WriteMessageTask::Running(fut) = &mut self.current_task else {
panic!("Called handle_running while waiting.");
panic!("Called handle_running while not running.");
};
match fut.poll_unpin(cx) {
Poll::Pending => None,
Poll::Ready(Ok(stream)) => {
self.current_task = WriteMessageTask::Waiting(stream);
self.handle_waiting(cx)
}
Poll::Ready(Err(io_error)) => Some(io_error),
Poll::Ready(Err(io_error)) => Some(FinishReason::Error(io_error)),
}
}

fn handle_closing(&mut self, cx: &mut Context<'_>) -> Option<FinishReason> {
let WriteMessageTask::Closing(fut) = &mut self.current_task else {
panic!("Called handle_closing while not closing.");
};
match fut.poll_unpin(cx) {
Poll::Pending => None,
Poll::Ready(Ok(())) => Some(FinishReason::Closed),
Poll::Ready(Err(io_error)) => Some(FinishReason::Error(io_error)),
}
}
}

impl<Data: DataBound> Future for InboundSession<Data> {
type Output = io::Error;
type Output = FinishReason;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let unpinned_self = Pin::into_inner(self);
let result = match &mut unpinned_self.current_task {
WriteMessageTask::Running(_) => unpinned_self.handle_running(cx),
WriteMessageTask::Waiting(_) => unpinned_self.handle_waiting(cx),
WriteMessageTask::Closing(_) => unpinned_self.handle_closing(cx),
};
match result {
Some(error) => Poll::Ready(error),
Some(finish_reason) => Poll::Ready(finish_reason),
None => Poll::Pending,
}
}
Expand Down
75 changes: 63 additions & 12 deletions crates/papyrus_network/src/streamed_data_protocol/handler_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ use futures::{select, FutureExt, Stream as StreamTrait, StreamExt};
use libp2p::swarm::handler::{ConnectionEvent, FullyNegotiatedInbound, FullyNegotiatedOutbound};
use libp2p::swarm::{ConnectionHandler, ConnectionHandlerEvent, Stream};

use super::super::{DataBound, InboundSessionId, OutboundSessionId, QueryBound};
use super::{Handler, HandlerEvent, RequestFromBehaviourEvent, SessionId, ToBehaviourEvent};
use super::super::{DataBound, InboundSessionId, OutboundSessionId, QueryBound, SessionId};
use super::{Handler, HandlerEvent, RequestFromBehaviourEvent, ToBehaviourEvent};
use crate::messages::block::{GetBlocks, GetBlocksResponse};
use crate::messages::{read_message, write_message};
use crate::test_utils::{get_connected_streams, hardcoded_data};
Expand Down Expand Up @@ -127,12 +127,30 @@ async fn validate_request_to_swarm_new_outbound_session_to_swarm_event<
);
}

async fn read_messages(stream: &mut Stream, num_messages: usize) -> Vec<GetBlocksResponse> {
let mut result = Vec::new();
for _ in 0..num_messages {
result.push(read_message::<GetBlocksResponse, _>(&mut *stream).await.unwrap().unwrap());
async fn read_messages<Query: QueryBound, Data: DataBound>(
handler: Handler<Query, Data>,
stream: &mut Stream,
num_messages: usize,
) -> Vec<GetBlocksResponse> {
async fn read_messages_inner(
stream: &mut Stream,
num_messages: usize,
) -> Vec<GetBlocksResponse> {
let mut result = Vec::new();
for _ in 0..num_messages {
match read_message::<GetBlocksResponse, _>(&mut *stream).await.unwrap() {
Some(message) => result.push(message),
None => return result,
}
}
result
}

let mut fused_handler = handler.fuse();
select! {
data = read_messages_inner(stream, num_messages).fuse() => data,
_ = fused_handler.next() => panic!("There shouldn't be another event from the handler"),
}
result
}

#[tokio::test]
Expand Down Expand Up @@ -160,14 +178,47 @@ async fn process_inbound_session() {
simulate_request_to_send_data_from_swarm(&mut handler, data.clone(), inbound_session_id);
}

let mut fused_handler = handler.fuse();
let data_received = select! {
data = read_messages(&mut outbound_stream, hardcoded_data_vec.len()).fuse() => data,
_ = fused_handler.next() => panic!("There shouldn't be another event from the handler"),
};
let data_received =
read_messages(handler, &mut outbound_stream, hardcoded_data_vec.len()).await;
assert_eq!(hardcoded_data_vec, data_received);
}

#[tokio::test]
async fn finished_inbound_session_ignores_behaviour_request_to_send_data() {
let mut handler = Handler::<GetBlocks, GetBlocksResponse>::new(
SUBSTREAM_TIMEOUT,
Arc::new(Default::default()),
);

let (inbound_stream, mut outbound_stream, _) = get_connected_streams().await;
// TODO(shahak): Change to GetBlocks::default() when the bug that forbids sending default
// messages is fixed.
let query = GetBlocks { limit: 10, ..Default::default() };
let inbound_session_id = InboundSessionId { value: 1 };

simulate_negotiated_inbound_session_from_swarm(
&mut handler,
query.clone(),
inbound_stream,
inbound_session_id,
);

// consume the new inbound session event without reading it.
handler.next().await;

simulate_request_to_finish_session(
&mut handler,
SessionId::InboundSessionId(inbound_session_id),
);

let hardcoded_data_vec = hardcoded_data();
for data in &hardcoded_data_vec {
simulate_request_to_send_data_from_swarm(&mut handler, data.clone(), inbound_session_id);
}
let data_received = read_messages(handler, &mut outbound_stream, 1).await;
assert!(data_received.is_empty());
}

#[test]
fn listen_protocol_across_multiple_handlers() {
let next_inbound_session_id = Arc::new(AtomicUsize::default());
Expand Down
Loading