diff --git a/crates/papyrus_network/src/streamed_data_protocol/handler.rs b/crates/papyrus_network/src/streamed_data_protocol/handler.rs index 24b42214b4..2955f52210 100644 --- a/crates/papyrus_network/src/streamed_data_protocol/handler.rs +++ b/crates/papyrus_network/src/streamed_data_protocol/handler.rs @@ -22,7 +22,7 @@ use libp2p::swarm::handler::{ use libp2p::swarm::{ConnectionHandler, ConnectionHandlerEvent, KeepAlive, SubstreamProtocol}; use tracing::debug; -use self::session::InboundSession; +use self::session::{FinishReason, InboundSession}; use super::protocol::{InboundProtocol, OutboundProtocol, PROTOCOL_NAME}; use super::{DataBound, GenericEvent, InboundSessionId, OutboundSessionId, QueryBound, SessionId}; use crate::messages::read_message; @@ -104,6 +104,29 @@ impl<Query: QueryBound, Data: DataBound> Handler<Query, Data> { // _ => true, // }) // } + + /// Poll an inbound session, inserting any events needed to pending_events, and return whether + /// the inbound session has finished. + fn poll_inbound_session( + inbound_session: &mut InboundSession<Data>, + inbound_session_id: InboundSessionId, + pending_events: &mut VecDeque<HandlerEvent<Self>>, + cx: &mut Context<'_>, + ) -> bool { + let Poll::Ready(finish_reason) = inbound_session.poll_unpin(cx) else { + let is_session_alive = false; + return is_session_alive; + }; + if let FinishReason::Error(io_error) = finish_reason { + pending_events.push_back(ConnectionHandlerEvent::NotifyBehaviour( + ToBehaviourEvent::SessionFailed { + session_id: SessionId::InboundSessionId(inbound_session_id), + error: SessionError::IOError(io_error), + }, + )); + } + true + } } impl<Query: QueryBound, Data: DataBound> ConnectionHandler for Handler<Query, Data> { @@ -141,17 +164,30 @@ // Handle inbound sessions. 
self.id_to_inbound_session.retain(|inbound_session_id, inbound_session| { - if let Poll::Ready(io_error) = inbound_session.poll_unpin(cx) { - self.pending_events.push_back(ConnectionHandlerEvent::NotifyBehaviour( - ToBehaviourEvent::SessionFailed { - session_id: SessionId::InboundSessionId(*inbound_session_id), - error: SessionError::IOError(io_error), - }, - )); - return false; + if Self::poll_inbound_session( + inbound_session, + *inbound_session_id, + &mut self.pending_events, + cx, + ) { + let is_session_alive = false; + return is_session_alive; + } + if self.inbound_sessions_marked_to_end.contains(inbound_session_id) + && inbound_session.is_waiting() + { + inbound_session.start_closing(); + if Self::poll_inbound_session( + inbound_session, + *inbound_session_id, + &mut self.pending_events, + cx, + ) { + let is_session_alive = false; + return is_session_alive; + } } - !(self.inbound_sessions_marked_to_end.contains(inbound_session_id) - && inbound_session.is_waiting()) + true }); // Handle outbound sessions. 
diff --git a/crates/papyrus_network/src/streamed_data_protocol/handler/session.rs b/crates/papyrus_network/src/streamed_data_protocol/handler/session.rs index 14cef24442..5d03bf556a 100644 --- a/crates/papyrus_network/src/streamed_data_protocol/handler/session.rs +++ b/crates/papyrus_network/src/streamed_data_protocol/handler/session.rs @@ -5,7 +5,7 @@ use std::pin::Pin; use std::task::{Context, Poll}; use futures::future::BoxFuture; -use futures::FutureExt; +use futures::{AsyncWriteExt, FutureExt}; use libp2p::swarm::Stream; use replace_with::replace_with_or_abort; @@ -17,9 +17,15 @@ pub(super) struct InboundSession<Data: DataBound> { current_task: WriteMessageTask, } +pub(super) enum FinishReason { + Error(io::Error), + Closed, +} + enum WriteMessageTask { Waiting(Stream), Running(BoxFuture<'static, Result<Stream, io::Error>>), + Closing(BoxFuture<'static, Result<(), io::Error>>), } impl<Data: DataBound> InboundSession<Data> { @@ -45,11 +51,20 @@ && self.pending_messages.is_empty() } - fn handle_waiting(&mut self, cx: &mut Context<'_>) -> Option<io::Error> { + pub fn start_closing(&mut self) { + replace_with_or_abort(&mut self.current_task, |current_task| { + let WriteMessageTask::Waiting(mut stream) = current_task else { + panic!("Called start_closing while not waiting."); + }; + WriteMessageTask::Closing(async move { stream.close().await }.boxed()) + }) + } + + fn handle_waiting(&mut self, cx: &mut Context<'_>) -> Option<FinishReason> { if let Some(data) = self.pending_messages.pop_front() { replace_with_or_abort(&mut self.current_task, |current_task| { let WriteMessageTask::Waiting(mut stream) = current_task else { - panic!("Called handle_waiting while running."); + panic!("Called handle_waiting while not waiting."); }; WriteMessageTask::Running( async move { @@ -64,9 +79,9 @@ None } - fn handle_running(&mut self, cx: &mut Context<'_>) -> Option<io::Error> { + fn handle_running(&mut self, cx: &mut Context<'_>) -> Option<FinishReason> { let WriteMessageTask::Running(fut) = &mut self.current_task else { - panic!("Called 
handle_running while waiting."); + panic!("Called handle_running while not running."); }; match fut.poll_unpin(cx) { Poll::Pending => None, @@ -74,22 +89,34 @@ impl InboundSession { self.current_task = WriteMessageTask::Waiting(stream); self.handle_waiting(cx) } - Poll::Ready(Err(io_error)) => Some(io_error), + Poll::Ready(Err(io_error)) => Some(FinishReason::Error(io_error)), + } + } + + fn handle_closing(&mut self, cx: &mut Context<'_>) -> Option { + let WriteMessageTask::Closing(fut) = &mut self.current_task else { + panic!("Called handle_closing while not closing."); + }; + match fut.poll_unpin(cx) { + Poll::Pending => None, + Poll::Ready(Ok(())) => Some(FinishReason::Closed), + Poll::Ready(Err(io_error)) => Some(FinishReason::Error(io_error)), } } } impl Future for InboundSession { - type Output = io::Error; + type Output = FinishReason; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let unpinned_self = Pin::into_inner(self); let result = match &mut unpinned_self.current_task { WriteMessageTask::Running(_) => unpinned_self.handle_running(cx), WriteMessageTask::Waiting(_) => unpinned_self.handle_waiting(cx), + WriteMessageTask::Closing(_) => unpinned_self.handle_closing(cx), }; match result { - Some(error) => Poll::Ready(error), + Some(finish_reason) => Poll::Ready(finish_reason), None => Poll::Pending, } } diff --git a/crates/papyrus_network/src/streamed_data_protocol/handler_test.rs b/crates/papyrus_network/src/streamed_data_protocol/handler_test.rs index 0a394b1e22..5d055030fe 100644 --- a/crates/papyrus_network/src/streamed_data_protocol/handler_test.rs +++ b/crates/papyrus_network/src/streamed_data_protocol/handler_test.rs @@ -10,8 +10,8 @@ use futures::{select, FutureExt, Stream as StreamTrait, StreamExt}; use libp2p::swarm::handler::{ConnectionEvent, FullyNegotiatedInbound, FullyNegotiatedOutbound}; use libp2p::swarm::{ConnectionHandler, ConnectionHandlerEvent, Stream}; -use super::super::{DataBound, InboundSessionId, OutboundSessionId, 
QueryBound}; -use super::{Handler, HandlerEvent, RequestFromBehaviourEvent, SessionId, ToBehaviourEvent}; +use super::super::{DataBound, InboundSessionId, OutboundSessionId, QueryBound, SessionId}; +use super::{Handler, HandlerEvent, RequestFromBehaviourEvent, ToBehaviourEvent}; use crate::messages::block::{GetBlocks, GetBlocksResponse}; use crate::messages::{read_message, write_message}; use crate::test_utils::{get_connected_streams, hardcoded_data}; @@ -127,12 +127,30 @@ async fn validate_request_to_swarm_new_outbound_session_to_swarm_event< ); } -async fn read_messages(stream: &mut Stream, num_messages: usize) -> Vec<GetBlocksResponse> { - let mut result = Vec::new(); - for _ in 0..num_messages { - result.push(read_message::<GetBlocksResponse, _>(&mut *stream).await.unwrap().unwrap()); +async fn read_messages( + handler: Handler<GetBlocks, GetBlocksResponse>, + stream: &mut Stream, + num_messages: usize, +) -> Vec<GetBlocksResponse> { + async fn read_messages_inner( + stream: &mut Stream, + num_messages: usize, + ) -> Vec<GetBlocksResponse> { + let mut result = Vec::new(); + for _ in 0..num_messages { + match read_message::<GetBlocksResponse, _>(&mut *stream).await.unwrap() { + Some(message) => result.push(message), + None => return result, + } + } + result + } + + let mut fused_handler = handler.fuse(); + select! { + data = read_messages_inner(stream, num_messages).fuse() => data, + _ = fused_handler.next() => panic!("There shouldn't be another event from the handler"), } - result } #[tokio::test] @@ -160,14 +178,47 @@ async fn process_inbound_session() { simulate_request_to_send_data_from_swarm(&mut handler, data.clone(), inbound_session_id); } - let mut fused_handler = handler.fuse(); - let data_received = select! 
{ - data = read_messages(&mut outbound_stream, hardcoded_data_vec.len()).fuse() => data, - _ = fused_handler.next() => panic!("There shouldn't be another event from the handler"), - }; + let data_received = + read_messages(handler, &mut outbound_stream, hardcoded_data_vec.len()).await; assert_eq!(hardcoded_data_vec, data_received); } +#[tokio::test] +async fn finished_inbound_session_ignores_behaviour_request_to_send_data() { + let mut handler = Handler::<GetBlocks, GetBlocksResponse>::new( + SUBSTREAM_TIMEOUT, + Arc::new(Default::default()), + ); + + let (inbound_stream, mut outbound_stream, _) = get_connected_streams().await; + // TODO(shahak): Change to GetBlocks::default() when the bug that forbids sending default + // messages is fixed. + let query = GetBlocks { limit: 10, ..Default::default() }; + let inbound_session_id = InboundSessionId { value: 1 }; + + simulate_negotiated_inbound_session_from_swarm( + &mut handler, + query.clone(), + inbound_stream, + inbound_session_id, + ); + + // consume the new inbound session event without reading it. + handler.next().await; + + simulate_request_to_finish_session( + &mut handler, + SessionId::InboundSessionId(inbound_session_id), + ); + + let hardcoded_data_vec = hardcoded_data(); + for data in &hardcoded_data_vec { + simulate_request_to_send_data_from_swarm(&mut handler, data.clone(), inbound_session_id); + } + let data_received = read_messages(handler, &mut outbound_stream, 1).await; + assert!(data_received.is_empty()); +} + #[test] fn listen_protocol_across_multiple_handlers() { let next_inbound_session_id = Arc::new(AtomicUsize::default());