From e23be56fac45ae84b83b19afb25451f1bfaa0ee6 Mon Sep 17 00:00:00 2001 From: eitanm-starkware <144585602+eitanm-starkware@users.noreply.github.com> Date: Thu, 18 Jul 2024 09:51:35 +0300 Subject: [PATCH] refactor(network): move responses sender into query sender (#2233) --- crates/papyrus_network/src/lib.rs | 2 - .../src/network_manager/mod.rs | 230 ++++++++---------- .../src/network_manager/test.rs | 26 +- crates/papyrus_node/src/main.rs | 15 +- .../src/client/header_test.rs | 38 +-- crates/papyrus_p2p_sync/src/client/mod.rs | 70 +++--- .../src/client/state_diff_test.rs | 47 ++-- .../src/client/stream_builder.rs | 26 +- .../papyrus_p2p_sync/src/client/test_utils.rs | 48 ++-- 9 files changed, 250 insertions(+), 252 deletions(-) diff --git a/crates/papyrus_network/src/lib.rs b/crates/papyrus_network/src/lib.rs index 0322ec3695..819588b3b4 100644 --- a/crates/papyrus_network/src/lib.rs +++ b/crates/papyrus_network/src/lib.rs @@ -30,8 +30,6 @@ use papyrus_config::{ParamPath, ParamPrivacyInput, SerializedParam}; use serde::{Deserialize, Serialize}; use validator::Validate; -pub use crate::network_manager::SqmrSubscriberChannels; - // TODO: add peer manager config to the network config #[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Validate)] pub struct NetworkConfig { diff --git a/crates/papyrus_network/src/network_manager/mod.rs b/crates/papyrus_network/src/network_manager/mod.rs index b058d9c339..4c05a60661 100644 --- a/crates/papyrus_network/src/network_manager/mod.rs +++ b/crates/papyrus_network/src/network_manager/mod.rs @@ -10,7 +10,7 @@ use futures::channel::oneshot; use futures::future::{ready, BoxFuture, Ready}; use futures::sink::With; use futures::stream::{self, BoxStream, FuturesUnordered, Map, Stream}; -use futures::{FutureExt, Sink, SinkExt, StreamExt}; +use futures::{pin_mut, FutureExt, Sink, SinkExt, StreamExt}; use libp2p::gossipsub::{SubscriptionError, TopicHash}; use libp2p::swarm::SwarmEvent; use libp2p::{PeerId, StreamProtocol, Swarm}; @@ -39,17 +39,15 @@ pub struct GenericNetworkManager { sqmr_inbound_response_receivers: StreamHashMap>>, sqmr_inbound_query_senders: HashMap)>>, - // Splitting the response receivers from the query senders in order to poll all - // receivers simultaneously. - // Each receiver has a matching sender and vice versa (i.e the maps have the same keys). - sqmr_outbound_query_receivers: StreamHashMap>, - sqmr_outbound_response_senders: HashMap>, + + sqmr_outbound_payload_receivers: StreamHashMap, + sqmr_outbound_response_senders: HashMap, + sqmr_outbound_report_receivers: HashMap, // Splitting the broadcast receivers from the broadcasted senders in order to poll all // receivers simultaneously. // Each receiver has a matching sender and vice versa (i.e the maps have the same keys). messages_to_broadcast_receivers: StreamHashMap>, broadcasted_messages_senders: HashMap>, - outbound_session_id_to_protocol: HashMap, reported_peer_receivers: FuturesUnordered>>, // Fields for metrics num_active_inbound_sessions: usize, @@ -62,8 +60,8 @@ impl GenericNetworkManager { tokio::select! { Some(event) = self.swarm.next() => self.handle_swarm_event(event), Some(res) = self.sqmr_inbound_response_receivers.next() => self.handle_response_for_inbound_query(res), - Some((protocol, (query, report_receiver))) = self.sqmr_outbound_query_receivers.next() => { - self.handle_local_sqmr_query(protocol, query, report_receiver) + Some((protocol, client_payload)) = self.sqmr_outbound_payload_receivers.next() => { + self.handle_local_sqmr_payload(protocol, client_payload) } Some((topic_hash, message)) = self.messages_to_broadcast_receivers.next() => { self.broadcast_message(message, topic_hash); @@ -82,11 +80,11 @@ impl GenericNetworkManager { inbound_protocol_to_buffer_size: HashMap::new(), sqmr_inbound_response_receivers: StreamHashMap::new(HashMap::new()), sqmr_inbound_query_senders: HashMap::new(), - sqmr_outbound_query_receivers: StreamHashMap::new(HashMap::new()), + sqmr_outbound_payload_receivers: StreamHashMap::new(HashMap::new()), sqmr_outbound_response_senders: HashMap::new(), + sqmr_outbound_report_receivers: HashMap::new(), messages_to_broadcast_receivers: StreamHashMap::new(HashMap::new()), broadcasted_messages_senders: HashMap::new(), - outbound_session_id_to_protocol: HashMap::new(), reported_peer_receivers, num_active_inbound_sessions: 0, num_active_outbound_sessions: 0, @@ -133,36 +131,33 @@ impl GenericNetworkManager { &mut self, protocol: String, buffer_size: usize, - ) -> SqmrSubscriberChannels + ) -> SqmrClientSender where Bytes: From, - Response: TryFrom, + Response: TryFrom + 'static + Send, + >::Error: 'static + Send, + Query: 'static, { let protocol = StreamProtocol::try_from_owned(protocol) .expect("Could not parse protocol into StreamProtocol."); self.swarm.add_new_supported_inbound_protocol(protocol.clone()); - let (query_sender, query_receiver) = futures::channel::mpsc::channel(buffer_size); - let (response_sender, response_receiver) = futures::channel::mpsc::channel(buffer_size); + let (payload_sender, payload_receiver) = futures::channel::mpsc::channel(buffer_size); - let insert_result = - self.sqmr_outbound_query_receivers.insert(protocol.clone(), query_receiver); - if insert_result.is_some() { - panic!("Protocol '{}' has already been registered as a client.", protocol); - } - let insert_result = - self.sqmr_outbound_response_senders.insert(protocol.clone(), response_sender); + let insert_result = self + .sqmr_outbound_payload_receivers + .insert(protocol.clone(), Box::new(payload_receiver)); if insert_result.is_some() { panic!("Protocol '{}' has already been registered as a client.", protocol); } - let query_fn: SendQueryConverterFn = - |(query, report_receiver)| ready(Ok((Bytes::from(query), report_receiver))); - let query_sender = query_sender.with(query_fn); + let payload_fn = |payload: SqmrClientPayload| { + ready(Ok(SqmrClientPayloadForNetwork::from(payload))) + }; + let payload_sender = payload_sender.with(payload_fn); - let response_fn: ReceivedMessagesConverterFn = |x| Response::try_from(x); - let response_receiver = response_receiver.map(response_fn); + // let response_fn: ReceivedMessagesConverterFn = |x| Response::try_from(x); - SqmrSubscriberChannels { query_sender, response_receiver } + Box::new(payload_sender) } /// Register a new subscriber for broadcasting and receiving broadcasts for a given topic. @@ -336,8 +331,9 @@ impl GenericNetworkManager { "A protocol is registered in NetworkManager but it has no buffer size.", ), ); + // TODO(shahak): Close the inbound session if the buffer is full. - send_now( + server_send_now( query_sender, (query, response_sender), format!( @@ -359,14 +355,12 @@ impl GenericNetworkManager { "Received response from peer for session id: {outbound_session_id:?}. sending \ to sync subscriber." ); - let protocol = self - .outbound_session_id_to_protocol - .get(&outbound_session_id) - .expect("Received response from an unknown session id"); - if let Some(response_sender) = self.sqmr_outbound_response_senders.get_mut(protocol) + + if let Some(response_sender) = + self.sqmr_outbound_response_senders.get_mut(&outbound_session_id) { // TODO(shahak): Close the channel if the buffer is full. - send_now( + network_send_now( response_sender, response, format!( @@ -381,14 +375,20 @@ impl GenericNetworkManager { self.report_session_removed_to_metrics(session_id); // TODO: Handle reputation and retry. if let SessionId::OutboundSessionId(outbound_session_id) = session_id { - self.outbound_session_id_to_protocol.remove(&outbound_session_id); + self.sqmr_outbound_response_senders.remove(&outbound_session_id); + // TODO: check if the report receiver was already removed when session was + // assigned + self.sqmr_outbound_report_receivers.remove(&outbound_session_id); } } sqmr::behaviour::ExternalEvent::SessionFinishedSuccessfully { session_id } => { debug!("Session completed successfully. session_id: {session_id:?}"); self.report_session_removed_to_metrics(session_id); if let SessionId::OutboundSessionId(outbound_session_id) = session_id { - self.outbound_session_id_to_protocol.remove(&outbound_session_id); + self.sqmr_outbound_response_senders.remove(&outbound_session_id); + // TODO: check if the report receiver was already removed when session was + // assigned + self.sqmr_outbound_report_receivers.remove(&outbound_session_id); } } } @@ -398,14 +398,6 @@ impl GenericNetworkManager { match event { gossipsub_impl::ExternalEvent::Received { originated_peer_id, message, topic_hash } => { let (report_sender, report_receiver) = oneshot::channel::<()>(); - let peer_id = originated_peer_id; - let report_sender = Box::new(move || { - error!("Report sender was called for message from {peer_id:?}"); - let res = report_sender.send(()); - if let Err(e) = res { - error!("Failed to send report. Error: {e:?}"); - } - }); self.handle_new_report_receiver(originated_peer_id, report_receiver); let Some(sender) = self.broadcasted_messages_senders.get_mut(&topic_hash) else { error!( @@ -451,12 +443,13 @@ impl GenericNetworkManager { }; } - fn handle_local_sqmr_query( + fn handle_local_sqmr_payload( &mut self, protocol: StreamProtocol, - query: Bytes, - report_receiver: ReportReceiver, + client_payload: SqmrClientPayloadForNetwork, ) { + let SqmrClientPayloadForNetwork { query, report_receiver, responses_sender } = + client_payload; match self.swarm.send_query(query, PeerId::random(), protocol.clone()) { Ok(outbound_session_id) => { debug!("Sent query to peer. outbound_session_id: {outbound_session_id:?}"); @@ -465,24 +458,10 @@ impl GenericNetworkManager { papyrus_metrics::PAPYRUS_NUM_ACTIVE_OUTBOUND_SESSIONS, self.num_active_outbound_sessions as f64 ); - self.outbound_session_id_to_protocol.insert(outbound_session_id, protocol); - // TODO(eitan): this match always results in error as the session isnt assigned yet. - // save map between outbound_session_id and report_receiver. once session is - // assigned call handle_new_report_receiver - match self - .swarm - .get_peer_id_from_session_id(SessionId::OutboundSessionId(outbound_session_id)) - { - Ok(peer_id) => { - self.handle_new_report_receiver(peer_id, report_receiver); - } - Err(e) => { - error!( - "Got a report before any message was received. Ignoring report. \ - Error: {e:?}" - ); - } - } + self.sqmr_outbound_response_senders.insert(outbound_session_id, responses_sender); + // TODO(eitan): once session is assigned call handle_new_report_receiver using map + // below + self.sqmr_outbound_report_receivers.insert(outbound_session_id, report_receiver); } Err(e) => { info!( @@ -527,6 +506,34 @@ impl GenericNetworkManager { } } +fn network_send_now( + sender: &mut GenericSender, + item: Item, + buffer_full_message: String, +) { + pin_mut!(sender); + match sender.as_mut().send(item).now_or_never() { + Some(Ok(())) => {} + Some(Err(error)) => { + error!("Received error while sending message: {:?}", error); + } + None => { + error!(buffer_full_message); + } + } +} + +fn server_send_now(sender: &mut Sender, item: Item, buffer_full_message: String) { + if let Err(error) = sender.try_send(item) { + if error.is_disconnected() { + panic!("Receiver was dropped. This should never happen.") + } else if error.is_full() { + // TODO(shahak): Consider doing something else rather than dropping the message. + error!(buffer_full_message); + } + } +} + pub type NetworkManager = GenericNetworkManager>; impl NetworkManager { @@ -610,53 +617,60 @@ where #[cfg(feature = "testing")] pub fn dummy_report_sender() -> ReportSender { - Box::new(|| {}) + oneshot::channel::<()>().0 } -pub type GenericSender = Box + Unpin>; +type GenericSender = Box + Unpin + Send>; // Box implements Stream only if S: Stream + Unpin -pub type GenericReceiver = Box + Unpin>; +type GenericReceiver = Box + Unpin + Send>; + +type ResponsesSenderForNetwork = GenericSender; +type ResponsesSender = + GenericSender>::Error>>; + +type ReportSender = oneshot::Sender<()>; +type ReportReceiver = oneshot::Receiver<()>; pub struct SqmrClientPayload> { pub query: Query, - pub report_receiver: oneshot::Receiver<()>, - pub responses_sender: GenericSender>::Error>>, + pub report_receiver: ReportReceiver, + pub responses_sender: ResponsesSender, } + +pub type SqmrClientSender = GenericSender>; + pub struct SqmrServerPayload> { pub query: Query, - pub report_sender: oneshot::Sender<()>, - pub responses_sender: GenericSender>::Error>>, + pub report_sender: ReportSender, + pub responses_sender: ResponsesSender, } -#[allow(dead_code)] +// TODO(shahak): Return this type in register_sqmr_protocol_server +pub type SqmrServerReceiver = GenericReceiver>; + struct SqmrClientPayloadForNetwork { - pub query: Bytes, - pub report_receiver: oneshot::Receiver<()>, - pub responses_sender: GenericSender, + query: Bytes, + report_receiver: ReportReceiver, + responses_sender: ResponsesSenderForNetwork, } +type SqmrClientReceiver = GenericReceiver; + #[allow(dead_code)] struct SqmrServerPayloadForNetwork { - pub query: Bytes, - pub report_sender: oneshot::Sender<()>, - pub responses_sender: GenericSender, + query: Bytes, + report_sender: ReportSender, + responses_sender: ResponsesSenderForNetwork, } -// TODO(shahak): Return this type in register_sqmr_protocol_client -pub type SqmrClientSender = GenericSender>; -#[allow(dead_code)] -type SqmrClientReceiver = GenericReceiver; - -// TODO(shahak): Return this type in register_sqmr_protocol_server -pub type SqmrServerReceiver = GenericReceiver>; #[allow(dead_code)] type SqmrServerSender = GenericSender; impl From> for SqmrClientPayloadForNetwork where Bytes: From, - Response: TryFrom + 'static, - >::Error: 'static, + Response: TryFrom + 'static + Send, + >::Error: 'static + Send, { fn from(payload: SqmrClientPayload) -> Self { let SqmrClientPayload { query, report_receiver, responses_sender } = payload; @@ -670,17 +684,11 @@ where impl> From for SqmrServerPayload { - fn from(_query: SqmrServerPayloadForNetwork) -> Self { + fn from(_payload: SqmrServerPayloadForNetwork) -> Self { unimplemented!() } } -// TODO(shahak): Create a custom struct if Box dyn becomes an overhead. -// TODO(eitan): Change type to oneshot::Sender<()> -pub type ReportSender = Box; -pub type ReportReceiver = oneshot::Receiver<()>; - -// TODO(shahak): Add report sender. pub type SqmrQueryReceiver = Map)>, ReceivedQueryConverterFn>; @@ -699,34 +707,15 @@ pub type BroadcastSubscriberSender = With< fn(T) -> Ready>, >; -pub type SqmrSubscriberSender = With< - Sender<(Bytes, ReportReceiver)>, - (Bytes, ReportReceiver), - (T, ReportReceiver), - Ready>, - fn((T, ReportReceiver)) -> Ready>, ->; - pub type SendQueryConverterFn = fn((Query, ReportReceiver)) -> Ready>; -// TODO(shahak): rename to ConvertFromBytesReceiver and add an alias called BroadcastReceiver -pub type SqmrSubscriberReceiver = Map, ReceivedMessagesConverterFn>; - -type ReceivedMessagesConverterFn = fn(Bytes) -> Result>::Error>; - pub type BroadcastSubscriberReceiver = Map, BroadcastReceivedMessagesConverterFn>; type BroadcastReceivedMessagesConverterFn = fn((Bytes, ReportSender)) -> (Result>::Error>, ReportSender); -// TODO(shahak): Unite channels to a Sender of Query and Receiver of Responses. -pub struct SqmrSubscriberChannels, Response: TryFrom> { - pub query_sender: SqmrSubscriberSender, - pub response_receiver: SqmrSubscriberReceiver, -} - pub struct BroadcastSubscriberChannels> { pub messages_to_broadcast_sender: BroadcastSubscriberSender, pub broadcasted_messages_receiver: BroadcastSubscriberReceiver, @@ -755,14 +744,3 @@ pub struct TestSubscriberChannels> { pub subscriber_channels: BroadcastSubscriberChannels, pub mock_network: BroadcastNetworkMock, } - -fn send_now(sender: &mut Sender, item: Item, buffer_full_message: String) { - if let Err(error) = sender.try_send(item) { - if error.is_disconnected() { - panic!("Receiver was dropped. This should never happen.") - } else if error.is_full() { - // TODO(shahak): Consider doing something else rather than dropping the message. - error!(buffer_full_message); - } - } -} diff --git a/crates/papyrus_network/src/network_manager/test.rs b/crates/papyrus_network/src/network_manager/test.rs index 55213b6eb6..4d076e20e9 100644 --- a/crates/papyrus_network/src/network_manager/test.rs +++ b/crates/papyrus_network/src/network_manager/test.rs @@ -1,4 +1,5 @@ use std::collections::{HashMap, HashSet}; +use std::convert::Infallible; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; @@ -21,9 +22,10 @@ use tokio::sync::Mutex; use tokio::time::sleep; use super::swarm_trait::{Event, SwarmTrait}; -use super::{GenericNetworkManager, SqmrSubscriberChannels}; +use super::GenericNetworkManager; use crate::gossipsub_impl::{self, Topic}; use crate::mixed_behaviour; +use crate::network_manager::SqmrClientPayload; use crate::sqmr::behaviour::{PeerNotConnected, SessionIdNotFoundError}; use crate::sqmr::{Bytes, GenericEvent, InboundSessionId, OutboundSessionId}; @@ -215,18 +217,20 @@ async fn register_sqmr_protocol_client_and_use_channels() { let (event_notifier, mut event_listner) = oneshot::channel(); mock_swarm.first_polled_event_notifier = Some(event_notifier); - // network manager to register subscriber and send query + // network manager to register subscriber let mut network_manager = GenericNetworkManager::generic_new(mock_swarm); - // register subscriber and send query - let SqmrSubscriberChannels { mut query_sender, response_receiver } = network_manager - .register_sqmr_protocol_client::, Vec>( - SIGNED_BLOCK_HEADER_PROTOCOL.to_string(), - BUFFER_SIZE, - ); + // register subscriber and send payload + let mut payload_sender = network_manager.register_sqmr_protocol_client::, Vec>( + SIGNED_BLOCK_HEADER_PROTOCOL.to_string(), + BUFFER_SIZE, + ); let response_receiver_length = Arc::new(Mutex::new(0)); let cloned_response_receiver_length = Arc::clone(&response_receiver_length); + let (responses_sender, response_receiver) = + futures::channel::mpsc::channel::, Infallible>>(BUFFER_SIZE); + let responses_sender = Box::new(responses_sender); let response_receiver_collector = response_receiver .enumerate() .take(VEC1.len()) @@ -237,11 +241,11 @@ async fn register_sqmr_protocol_client_and_use_channels() { result }) .collect::>(); - let (_report_callback, report_receiver) = oneshot::channel::<()>(); + let (_report_sender, report_receiver) = oneshot::channel::<()>(); tokio::select! { _ = network_manager.run() => panic!("network manager ended"), _ = poll_fn(|cx| event_listner.poll_unpin(cx)).then(|_| async move { - query_sender.send((VEC1.clone(), report_receiver)).await.unwrap()}) + payload_sender.send(SqmrClientPayload{query : VEC1.clone(), report_receiver, responses_sender}).await.unwrap()}) .then(|_| async move { *cloned_response_receiver_length.lock().await = response_receiver_collector.await.len(); }) => {}, @@ -364,7 +368,7 @@ async fn receive_broadcasted_message_and_report_it() { .then(|result| { let (message_result, report_callback) = result.unwrap().unwrap(); assert_eq!(message, message_result.unwrap()); - report_callback(); + report_callback.send(()).unwrap(); tokio::time::timeout(TIMEOUT, reported_peer_receiver.next()) }) => { assert_eq!(originated_peer_id, reported_peer_result.unwrap().unwrap()); diff --git a/crates/papyrus_node/src/main.rs b/crates/papyrus_node/src/main.rs index 1100a15cf8..70fff41bb7 100644 --- a/crates/papyrus_node/src/main.rs +++ b/crates/papyrus_node/src/main.rs @@ -355,11 +355,11 @@ fn run_network( }; let mut network_manager = network_manager::NetworkManager::new(network_config.clone()); let local_peer_id = network_manager.get_local_peer_id(); - let header_client_channels = network_manager + let header_client_sender = network_manager .register_sqmr_protocol_client(Protocol::SignedBlockHeader.into(), BUFFER_SIZE); - let state_diff_client_channels = + let state_diff_client_sender = network_manager.register_sqmr_protocol_client(Protocol::StateDiff.into(), BUFFER_SIZE); - let transaction_client_channels = + let transaction_client_sender = network_manager.register_sqmr_protocol_client(Protocol::Transaction.into(), BUFFER_SIZE); let header_server_channel = network_manager @@ -381,12 +381,9 @@ fn run_network( None => None, }; let p2p_sync_channels = P2PSyncClientChannels { - header_query_sender: Box::new(header_client_channels.query_sender), - header_response_receiver: Box::new(header_client_channels.response_receiver), - state_diff_query_sender: Box::new(state_diff_client_channels.query_sender), - state_diff_response_receiver: Box::new(state_diff_client_channels.response_receiver), - transaction_query_sender: Box::new(transaction_client_channels.query_sender), - transaction_response_receiver: Box::new(transaction_client_channels.response_receiver), + header_payload_sender: header_client_sender, + state_diff_payload_sender: state_diff_client_sender, + transaction_payload_sender: transaction_client_sender, }; Ok(( diff --git a/crates/papyrus_p2p_sync/src/client/header_test.rs b/crates/papyrus_p2p_sync/src/client/header_test.rs index 38caa8cc2c..912c48f106 100644 --- a/crates/papyrus_p2p_sync/src/client/header_test.rs +++ b/crates/papyrus_p2p_sync/src/client/header_test.rs @@ -1,4 +1,5 @@ use futures::{SinkExt, StreamExt}; +use papyrus_network::network_manager::SqmrClientPayload; use papyrus_protobuf::sync::{ BlockHashOrNumber, DataOrFin, @@ -27,11 +28,9 @@ async fn signed_headers_basic_flow() { let TestArgs { p2p_sync, storage_reader, - mut header_query_receiver, - mut headers_sender, + mut header_payload_receiver, // The test will fail if we drop these - state_diff_query_receiver: _state_diff_query_receiver, - state_diffs_sender: _state_diffs_sender, + state_diff_payload_receiver: _state_diff_query_receiver, .. } = setup(); let block_hashes_and_signatures = @@ -44,7 +43,11 @@ async fn signed_headers_basic_flow() { let end_block_number = (query_index + 1) * HEADER_QUERY_LENGTH; // Receive query and validate it. - let (query, _report_receiver) = header_query_receiver.next().await.unwrap(); + let SqmrClientPayload { + query, + report_receiver: _report_receiver, + responses_sender: mut headers_sender, + } = header_payload_receiver.next().await.unwrap(); assert_eq!( query, HeaderQuery(Query { @@ -110,18 +113,20 @@ async fn sync_sends_new_header_query_if_it_got_partial_responses() { let TestArgs { p2p_sync, - mut header_query_receiver, - mut headers_sender, + mut header_payload_receiver, // The test will fail if we drop these - state_diff_query_receiver: _state_diff_query_receiver, - state_diffs_sender: _state_diffs_sender, + state_diff_payload_receiver: _state_diff_query_receiver, .. } = setup(); let block_hashes_and_signatures = create_block_hashes_and_signatures(NUM_ACTUAL_RESPONSES); // Create a future that will receive a query, send partial responses and receive the next query. let parse_queries_future = async move { - let _query = header_query_receiver.next().await.unwrap(); + let SqmrClientPayload { + query: _query, + report_receiver: _report_receiver, + responses_sender: mut headers_sender, + } = header_payload_receiver.next().await.unwrap(); for (i, (block_hash, signature)) in block_hashes_and_signatures.into_iter().enumerate() { headers_sender @@ -140,11 +145,14 @@ async fn sync_sends_new_header_query_if_it_got_partial_responses() { headers_sender.send(Ok(DataOrFin(None))).await.unwrap(); // First unwrap is for the timeout. Second unwrap is for the Option returned from Stream. - let (query, _report_receiver) = - timeout(TIMEOUT_FOR_NEW_QUERY_AFTER_PARTIAL_RESPONSE, header_query_receiver.next()) - .await - .unwrap() - .unwrap(); + let SqmrClientPayload { + query, + report_receiver: _report_receiver, + responses_sender: _responses_sender, + } = timeout(TIMEOUT_FOR_NEW_QUERY_AFTER_PARTIAL_RESPONSE, header_payload_receiver.next()) + .await + .unwrap() + .unwrap(); assert_eq!( query, diff --git a/crates/papyrus_p2p_sync/src/client/mod.rs b/crates/papyrus_p2p_sync/src/client/mod.rs index c62cecff13..cf6112db37 100644 --- a/crates/papyrus_p2p_sync/src/client/mod.rs +++ b/crates/papyrus_p2p_sync/src/client/mod.rs @@ -14,12 +14,12 @@ use std::time::Duration; use futures::channel::mpsc::SendError; use futures::future::{ready, Ready}; use futures::sink::With; -use futures::{Sink, SinkExt, Stream}; +use futures::{SinkExt, Stream}; use header::HeaderStreamBuilder; use papyrus_config::converters::deserialize_seconds_to_duration; use papyrus_config::dumping::{ser_optional_param, ser_param, SerializeConfig}; use papyrus_config::{ParamPath, ParamPrivacyInput, SerializedParam}; -use papyrus_network::network_manager::ReportReceiver; +use papyrus_network::network_manager::{SqmrClientPayload, SqmrClientSender}; use papyrus_protobuf::converters::ProtobufConversionError; use papyrus_protobuf::sync::{ DataOrFin, @@ -159,32 +159,29 @@ pub enum P2PSyncError { SendError(#[from] SendError), } -type Response = Result, ProtobufConversionError>; // TODO(Eitan): Use SqmrSubscriberChannels once there is a utility function for testing -type QuerySender = - Box + Unpin + Send + 'static>; -type WithQuerySender = With< - QuerySender, - (T, ReportReceiver), - (Query, ReportReceiver), - Ready>, - fn((Query, ReportReceiver)) -> Ready>, + +type WithPayloadSender = With< + SqmrClientSender, + SqmrClientPayload, + SqmrClientPayload, + Ready, SendError>>, + fn( + SqmrClientPayload, + ) -> Ready, SendError>>, >; -type ResponseReceiver = Box> + Unpin + Send + 'static>; -type HeaderQuerySender = QuerySender; -type HeaderResponseReceiver = ResponseReceiver; -type StateDiffQuerySender = QuerySender; -type StateDiffResponseReceiver = ResponseReceiver; -type TransactionQuerySender = QuerySender; -type TransactionResponseReceiver = ResponseReceiver<(Transaction, TransactionOutput)>; +type SyncResponse = Result, ProtobufConversionError>; +type ResponseReceiver = Box> + Unpin + Send>; + +type HeaderPayloadSender = SqmrClientSender>; +type StateDiffPayloadSender = SqmrClientSender>; +type TransactionPayloadSender = + SqmrClientSender>; pub struct P2PSyncClientChannels { - pub header_query_sender: HeaderQuerySender, - pub header_response_receiver: HeaderResponseReceiver, - pub state_diff_query_sender: StateDiffQuerySender, - pub state_diff_response_receiver: StateDiffResponseReceiver, - pub transaction_query_sender: TransactionQuerySender, - pub transaction_response_receiver: TransactionResponseReceiver, + pub header_payload_sender: HeaderPayloadSender, + pub state_diff_payload_sender: StateDiffPayloadSender, + pub transaction_payload_sender: TransactionPayloadSender, } impl P2PSyncClientChannels { @@ -194,9 +191,15 @@ impl P2PSyncClientChannels { config: P2PSyncClientConfig, ) -> impl Stream + Send + 'static { let header_stream = HeaderStreamBuilder::create_stream( - self.header_query_sender - .with(|(query, report_receiver)| ready(Ok((HeaderQuery(query), report_receiver)))), - self.header_response_receiver, + self.header_payload_sender.with( + |SqmrClientPayload { query, report_receiver, responses_sender }| { + ready(Ok(SqmrClientPayload { + query: HeaderQuery(query), + report_receiver, + responses_sender, + })) + }, + ), storage_reader.clone(), config.wait_period_for_new_data, config.num_headers_per_query, @@ -204,10 +207,15 @@ impl P2PSyncClientChannels { ); let state_diff_stream = StateDiffStreamBuilder::create_stream( - self.state_diff_query_sender.with(|(query, report_receiver)| { - ready(Ok((StateDiffQuery(query), report_receiver))) - }), - self.state_diff_response_receiver, + self.state_diff_payload_sender.with( + |SqmrClientPayload { query, report_receiver, responses_sender }| { + ready(Ok(SqmrClientPayload { + query: StateDiffQuery(query), + report_receiver, + responses_sender, + })) + }, + ), storage_reader.clone(), config.wait_period_for_new_data, config.num_block_state_diffs_per_query, diff --git a/crates/papyrus_p2p_sync/src/client/state_diff_test.rs b/crates/papyrus_p2p_sync/src/client/state_diff_test.rs index 78e4b754a2..1f720fbfd4 100644 --- a/crates/papyrus_p2p_sync/src/client/state_diff_test.rs +++ b/crates/papyrus_p2p_sync/src/client/state_diff_test.rs @@ -3,6 +3,7 @@ use std::time::Duration; use assert_matches::assert_matches; use futures::{FutureExt, SinkExt, StreamExt}; use indexmap::indexmap; +use papyrus_network::network_manager::SqmrClientPayload; use papyrus_protobuf::sync::{ BlockHashOrNumber, ContractDiff, @@ -46,13 +47,8 @@ async fn state_diff_basic_flow() { let TestArgs { p2p_sync, storage_reader, - mut state_diff_query_receiver, - mut headers_sender, - mut state_diffs_sender, - // The test will fail if we drop this. - // We don't need to read the header query in order to know which headers to send, and we - // already validate the header query in a different test. - header_query_receiver: _header_query_receiver, + mut state_diff_payload_receiver, + mut header_payload_receiver, .. } = setup(); @@ -71,7 +67,12 @@ async fn state_diff_basic_flow() { tokio::time::sleep(SLEEP_DURATION_TO_LET_SYNC_ADVANCE).await; // Check that before we send headers there is no state diff query. - assert!(state_diff_query_receiver.next().now_or_never().is_none()); + assert!(state_diff_payload_receiver.next().now_or_never().is_none()); + let SqmrClientPayload { + query: _query, + report_receiver: _report_receiver, + responses_sender: mut headers_sender, + } = header_payload_receiver.next().await.unwrap(); // Send headers for entire query. for (i, ((block_hash, block_signature), state_diff)) in @@ -96,7 +97,11 @@ async fn state_diff_basic_flow() { (STATE_DIFF_QUERY_LENGTH, HEADER_QUERY_LENGTH - STATE_DIFF_QUERY_LENGTH), ] { // Get a state diff query and validate it - let (query, _report_receiver) = state_diff_query_receiver.next().await.unwrap(); + let SqmrClientPayload { + query, + report_receiver: _report_receiver, + responses_sender: mut state_diff_sender, + } = state_diff_payload_receiver.next().await.unwrap(); assert_eq!( query, StateDiffQuery(Query { @@ -116,7 +121,7 @@ async fn state_diff_basic_flow() { let txn = storage_reader.begin_ro_txn().unwrap(); assert_eq!(block_number, txn.get_state_marker().unwrap()); - state_diffs_sender + state_diff_sender .send(Ok(DataOrFin(Some(state_diff_chunk.clone())))) .await .unwrap(); @@ -164,7 +169,7 @@ async fn state_diff_basic_flow() { }; assert_eq!(state_diff, expected_state_diff); } - state_diffs_sender.send(Ok(DataOrFin(None))).await.unwrap(); + state_diff_sender.send(Ok(DataOrFin(None))).await.unwrap(); } }; @@ -307,13 +312,8 @@ async fn validate_state_diff_fails( let TestArgs { p2p_sync, storage_reader, - mut state_diff_query_receiver, - mut headers_sender, - mut state_diffs_sender, - // The test will fail if we drop this. - // We don't need to read the header query in order to know which headers to send, and we - // already validate the header query in a different test. - header_query_receiver: _header_query_receiver, + mut state_diff_payload_receiver, + mut header_payload_receiver, .. } = setup(); @@ -322,6 +322,11 @@ async fn validate_state_diff_fails( // Create a future that will receive queries, send responses and validate the results. let parse_queries_future = async move { // Send a single header. There's no need to fill the entire query. + let SqmrClientPayload { + query: _query, + report_receiver: _report_receiver, + responses_sender: mut headers_sender, + } = header_payload_receiver.next().await.unwrap(); headers_sender .send(Ok(DataOrFin(Some(SignedBlockHeader { block_header: BlockHeader { @@ -336,7 +341,11 @@ async fn validate_state_diff_fails( .unwrap(); // Get a state diff query and validate it - let (query, _report_reciever) = state_diff_query_receiver.next().await.unwrap(); + let SqmrClientPayload { + query, + report_receiver: _report_reciever, + responses_sender: mut state_diffs_sender, + } = state_diff_payload_receiver.next().await.unwrap(); assert_eq!( query, StateDiffQuery(Query { diff --git a/crates/papyrus_p2p_sync/src/client/stream_builder.rs b/crates/papyrus_p2p_sync/src/client/stream_builder.rs index b31182ad59..13b88011a4 100644 --- a/crates/papyrus_p2p_sync/src/client/stream_builder.rs +++ b/crates/papyrus_p2p_sync/src/client/stream_builder.rs @@ -6,13 +6,17 @@ use futures::channel::oneshot; use futures::future::BoxFuture; use futures::stream::BoxStream; use futures::{SinkExt, StreamExt}; +use papyrus_network::network_manager::SqmrClientPayload; +use papyrus_protobuf::converters::ProtobufConversionError; use papyrus_protobuf::sync::{BlockHashOrNumber, DataOrFin, Direction, Query}; use papyrus_storage::header::HeaderStorageReader; use papyrus_storage::{StorageError, StorageReader, StorageWriter}; use starknet_api::block::BlockNumber; use tracing::{debug, info}; -use super::{P2PSyncError, ResponseReceiver, WithQuerySender, STEP}; +use super::{P2PSyncError, ResponseReceiver, WithPayloadSender, STEP}; +use crate::client::SyncResponse; +use crate::BUFFER_SIZE; pub type DataStreamResult = Result, P2PSyncError>; @@ -33,8 +37,7 @@ pub(crate) enum BlockNumberLimit { pub(crate) trait DataStreamBuilder where InputFromNetwork: Send + 'static, - DataOrFin: TryFrom>, - as TryFrom>>::Error: Send, + DataOrFin: TryFrom, Error = ProtobufConversionError>, { type Output: BlockData + 'static; @@ -51,8 +54,7 @@ where fn get_start_block_number(storage_reader: &StorageReader) -> Result; fn create_stream( - mut query_sender: WithQuerySender, - mut data_receiver: ResponseReceiver, + mut payload_sender: WithPayloadSender>, storage_reader: StorageReader, wait_period_for_new_data: Duration, num_blocks_per_query: u64, @@ -87,20 +89,24 @@ where // TODO(shahak): Use the report callback. //TODO(Eitan): abstract report functionality to the channel struct let (_report_sender, report_receiver) = oneshot::channel::<()>(); - query_sender - .send(( + let (responses_sender, responses_receiver) = futures::channel::mpsc::channel::>(BUFFER_SIZE); + let responses_sender = Box::new(responses_sender); + let mut responses_receiver: ResponseReceiver = Box::new(responses_receiver); + payload_sender + .send(SqmrClientPayload { query: Query { start_block: BlockHashOrNumber::Number(current_block_number), direction: Direction::Forward, limit, step: STEP, - }, report_receiver,) + }, report_receiver, responses_sender + } ) .await?; while current_block_number.0 < end_block_number { match Self::parse_data_for_block( - &mut data_receiver, current_block_number, &storage_reader + &mut responses_receiver, current_block_number, &storage_reader ).await? { Some(output) => yield Ok(Box::::from(Box::new(output))), None => { @@ -125,7 +131,7 @@ where } // Consume the None message signaling the end of the query. - match data_receiver.next().await { + match responses_receiver.next().await { Some(Ok(DataOrFin(None))) => { debug!("Query sent to network for {:?} finished", Self::TYPE_DESCRIPTION); }, diff --git a/crates/papyrus_p2p_sync/src/client/test_utils.rs b/crates/papyrus_p2p_sync/src/client/test_utils.rs index 6dd6e01fbb..aa35e05a40 100644 --- a/crates/papyrus_p2p_sync/src/client/test_utils.rs +++ b/crates/papyrus_p2p_sync/src/client/test_utils.rs @@ -1,9 +1,10 @@ use std::time::Duration; -use futures::channel::mpsc::{Receiver, Sender}; +use futures::channel::mpsc::Receiver; use lazy_static::lazy_static; -use papyrus_network::network_manager::ReportReceiver; +use papyrus_network::network_manager::SqmrClientPayload; use papyrus_protobuf::sync::{ + DataOrFin, HeaderQuery, SignedBlockHeader, StateDiffChunk, @@ -18,7 +19,7 @@ use starknet_api::hash::StarkHash; use starknet_api::transaction::{Transaction, TransactionOutput}; use starknet_types_core::felt::Felt; -use super::{P2PSyncClient, P2PSyncClientChannels, P2PSyncClientConfig, Response}; +use super::{P2PSyncClient, P2PSyncClientChannels, P2PSyncClientConfig}; pub const BUFFER_SIZE: usize = 1000; pub const HEADER_QUERY_LENGTH: u64 = 5; @@ -43,37 +44,29 @@ pub struct TestArgs { #[allow(clippy::type_complexity)] pub p2p_sync: P2PSyncClient, pub storage_reader: StorageReader, - pub header_query_receiver: Receiver<(HeaderQuery, ReportReceiver)>, - pub state_diff_query_receiver: Receiver<(StateDiffQuery, ReportReceiver)>, + pub header_payload_receiver: + Receiver>>, + pub state_diff_payload_receiver: + Receiver>>, #[allow(dead_code)] - pub transaction_query_receiver: Receiver<(TransactionQuery, ReportReceiver)>, - pub headers_sender: Sender>, - pub state_diffs_sender: Sender>, - #[allow(dead_code)] - pub transaction_sender: Sender>, + pub transaction_payload_receiver: + Receiver>>, } pub fn setup() -> TestArgs { let p2p_sync_config = *TEST_CONFIG; let buffer_size = p2p_sync_config.buffer_size; let ((storage_reader, storage_writer), _temp_dir) = get_test_storage(); - let (header_query_sender, header_query_receiver) = futures::channel::mpsc::channel(buffer_size); - let (state_diff_query_sender, state_diff_query_receiver) = - futures::channel::mpsc::channel(buffer_size); - let (transaction_query_sender, transaction_query_receiver) = + let (header_payload_sender, header_payload_receiver) = futures::channel::mpsc::channel(buffer_size); - let (headers_sender, header_response_receiver) = futures::channel::mpsc::channel(buffer_size); - let (state_diffs_sender, state_diff_response_receiver) = + let (state_diff_payload_sender, state_diff_payload_receiver) = futures::channel::mpsc::channel(buffer_size); - let (transaction_sender, transaction_response_receiver) = + let (transaction_payload_sender, transaction_payload_receiver) = futures::channel::mpsc::channel(buffer_size); let p2p_sync_channels = P2PSyncClientChannels { - header_query_sender: Box::new(header_query_sender), - state_diff_query_sender: Box::new(state_diff_query_sender), - header_response_receiver: Box::new(header_response_receiver), - state_diff_response_receiver: Box::new(state_diff_response_receiver), - transaction_query_sender: Box::new(transaction_query_sender), - transaction_response_receiver: Box::new(transaction_response_receiver), + header_payload_sender: Box::new(header_payload_sender), + state_diff_payload_sender: Box::new(state_diff_payload_sender), + transaction_payload_sender: Box::new(transaction_payload_sender), }; let p2p_sync = P2PSyncClient::new( p2p_sync_config, @@ -84,12 +77,9 @@ pub fn setup() -> TestArgs { TestArgs { p2p_sync, storage_reader, - header_query_receiver, - state_diff_query_receiver, - transaction_query_receiver, - headers_sender, - state_diffs_sender, - transaction_sender, + header_payload_receiver, + state_diff_payload_receiver, + transaction_payload_receiver, } }