Skip to content

Commit

Permalink
feat(network): add add_address and remove_address to behaviour
Browse files Browse the repository at this point in the history
  • Loading branch information
ShahakShama committed Oct 1, 2023
1 parent e3e9c3c commit a29430c
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 8 deletions.
48 changes: 42 additions & 6 deletions crates/papyrus_network/src/streamed_data_protocol/behaviour.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use std::task::{Context, Poll};
use std::time::Duration;

use defaultmap::DefaultHashMap;
use libp2p::core::Endpoint;
use libp2p::core::{ConnectedPoint, Endpoint};
use libp2p::swarm::behaviour::ConnectionEstablished;
use libp2p::swarm::dial_opts::{DialOpts, PeerCondition};
use libp2p::swarm::{
Expand Down Expand Up @@ -80,6 +80,10 @@ pub(crate) type Event<Query, Data> = GenericEvent<Query, Data, SessionError>;
#[error("The given session ID doesn't exist.")]
pub(crate) struct SessionIdNotFoundError;

#[derive(thiserror::Error, Debug)]
#[error("There are no known addresses for the given peer. Add one with `add_address`.")]
pub(crate) struct NoKnownAddressesError;

// TODO(shahak) remove allow dead code.
#[allow(dead_code)]
pub(crate) struct Behaviour<Query: QueryBound, Data: DataBound> {
Expand All @@ -88,6 +92,7 @@ pub(crate) struct Behaviour<Query: QueryBound, Data: DataBound> {
pending_queries: DefaultHashMap<PeerId, Vec<(Query, OutboundSessionId)>>,
connected_peers: HashSet<PeerId>,
session_id_to_peer_id: HashMap<SessionId, PeerId>,
peer_id_to_addresses: DefaultHashMap<PeerId, HashSet<Multiaddr>>,
next_outbound_session_id: OutboundSessionId,
next_inbound_session_id: Arc<AtomicUsize>,
}
Expand All @@ -102,28 +107,42 @@ impl<Query: QueryBound, Data: DataBound> Behaviour<Query, Data> {
pending_queries: Default::default(),
connected_peers: Default::default(),
session_id_to_peer_id: Default::default(),
peer_id_to_addresses: Default::default(),
next_outbound_session_id: Default::default(),
next_inbound_session_id: Arc::new(Default::default()),
}
}

/// Send query to the given peer and start a new outbound session with it. Return the id of the
/// new session.
pub fn send_query(&mut self, query: Query, peer_id: PeerId) -> OutboundSessionId {
pub fn send_query(
&mut self,
query: Query,
peer_id: PeerId,
) -> Result<OutboundSessionId, NoKnownAddressesError> {
let outbound_session_id = self.next_outbound_session_id;
self.next_outbound_session_id.value += 1;
self.session_id_to_peer_id
.insert(SessionId::OutboundSessionId(outbound_session_id), peer_id);

if self.connected_peers.contains(&peer_id) {
self.send_query_to_handler(peer_id, query, outbound_session_id);
return outbound_session_id;
return Ok(outbound_session_id);
}

let addresses_set = self.peer_id_to_addresses.get(peer_id);
if addresses_set.is_empty() {
return Err(NoKnownAddressesError);
}
let addresses = addresses_set.clone().into_iter().collect();
self.pending_events.push_back(ToSwarm::Dial {
opts: DialOpts::peer_id(peer_id).condition(PeerCondition::Disconnected).build(),
opts: DialOpts::peer_id(peer_id)
.addresses(addresses)
.condition(PeerCondition::Disconnected)
.build(),
});
self.pending_queries.get_mut(peer_id).push((query, outbound_session_id));
outbound_session_id
Ok(outbound_session_id)
}

/// Send a data message to an open inbound session.
Expand Down Expand Up @@ -154,6 +173,20 @@ impl<Query: QueryBound, Data: DataBound> Behaviour<Query, Data> {
Ok(())
}

/// Add an address as a known address that the given peer listens on for incoming connections.
/// This function must be called on a peer before calling send_query to that peer, unless the
/// given peer created an inbound session with us.
pub fn add_address(&mut self, peer_id: PeerId, address: Multiaddr) {
self.peer_id_to_addresses.get_mut(peer_id).insert(address);
}

/// Remove an address that was previously added with `add_address` or from an incoming inbound
/// session.
/// If the given address wasn't added, this function doesn't do anything.
pub fn remove_address(&mut self, peer_id: PeerId, address: Multiaddr) {
self.peer_id_to_addresses.get_mut(peer_id).remove(&address);
}

fn send_query_to_handler(
&mut self,
peer_id: PeerId,
Expand Down Expand Up @@ -195,12 +228,15 @@ impl<Query: QueryBound, Data: DataBound> NetworkBehaviour for Behaviour<Query, D
fn on_swarm_event(&mut self, event: FromSwarm<'_, Self::ConnectionHandler>) {
match event {
FromSwarm::ConnectionEstablished(connection_established) => {
let ConnectionEstablished { peer_id, .. } = connection_established;
let ConnectionEstablished { peer_id, endpoint, .. } = connection_established;
if let Some(queries) = self.pending_queries.remove(&peer_id) {
for (query, outbound_session_id) in queries.into_iter() {
self.send_query_to_handler(peer_id, query, outbound_session_id);
}
}
if let ConnectedPoint::Listener { send_back_addr, .. } = endpoint {
self.add_address(peer_id, send_back_addr.clone());
}
}
_ => {
unimplemented!();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,8 @@ fn simulate_outbound_session_closed_by_peer<Query: QueryBound, Data: DataBound>(
);
}

// There's no way to extract addresses from DialOpts, so we can't test if the addresses are
// correct.
async fn validate_dial_event<Query: QueryBound, Data: DataBound>(
behaviour: &mut Behaviour<Query, Data>,
peer_id: &PeerId,
Expand Down Expand Up @@ -317,7 +319,8 @@ async fn create_and_process_outbound_session() {
// messages is fixed.
let query = GetBlocks { limit: 10, ..Default::default() };
let peer_id = PeerId::random();
let outbound_session_id = behaviour.send_query(query.clone(), peer_id);
behaviour.add_address(peer_id, Multiaddr::empty());
let outbound_session_id = behaviour.send_query(query.clone(), peer_id).unwrap();

validate_dial_event(&mut behaviour, &peer_id).await;
validate_no_events(&mut behaviour);
Expand Down Expand Up @@ -361,7 +364,8 @@ async fn outbound_session_closed_by_peer() {
// messages is fixed.
let query = GetBlocks { limit: 10, ..Default::default() };
let peer_id = PeerId::random();
let outbound_session_id = behaviour.send_query(query.clone(), peer_id);
behaviour.add_address(peer_id, Multiaddr::empty());
let outbound_session_id = behaviour.send_query(query.clone(), peer_id).unwrap();

// Consume the dial event.
behaviour.next().await.unwrap();
Expand Down Expand Up @@ -392,3 +396,28 @@ async fn send_data_non_existing_session_fails() {
behaviour.send_data(data, InboundSessionId::default()).unwrap_err();
}
}

#[tokio::test]
async fn send_query_no_known_address_fails() {
let mut behaviour = Behaviour::<GetBlocks, GetBlocksResponse>::new(SUBSTREAM_TIMEOUT);

// TODO(shahak): Change to GetBlocks::default() when the bug that forbids sending default
// messages is fixed.
let query = GetBlocks { limit: 10, ..Default::default() };
let peer_id = PeerId::random();
behaviour.send_query(query, peer_id).unwrap_err();
}

#[tokio::test]
async fn new_inbound_session_adds_address() {
let mut behaviour = Behaviour::<GetBlocks, GetBlocksResponse>::new(SUBSTREAM_TIMEOUT);

let peer_id = PeerId::random();

simulate_listener_connection_from_swarm(&mut behaviour, peer_id);

// TODO(shahak): Change to GetBlocks::default() when the bug that forbids sending default
// messages is fixed.
let query = GetBlocks { limit: 20, ..Default::default() };
behaviour.send_query(query, peer_id).unwrap();
}

0 comments on commit a29430c

Please sign in to comment.