diff --git a/Cargo.toml b/Cargo.toml index 2360310d..422734c6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,28 +1,24 @@ -[package] -name = "gadget" -version = "0.0.1" -authors = ["Thomas P Braun"] -license = "GPL-3.0-or-later WITH Classpath-exception-2.0" -edition = "2021" - -[features] -substrate = [ - "sp-runtime", - "sc-client-api", - "sp-api", - "futures" +[workspace] +members = [ + "gadget-core", + "webb-gadget", + "zk-gadget" ] -[dependencies] -sync_wrapper = "0.1.2" -parking_lot = "0.12.1" -tokio = { version = "1.32.0", features = ["sync", "time", "macros", "rt"] } -hex = "0.4.3" -async-trait = "0.1.73" +[workspace.dependencies] +gadget-core = { path = "./gadget-core" } +webb-gadget = { path = "./webb-gadget" } +zk-gadget = { path = "./zk-gadget" } -sp-runtime = { optional = true, git = "https://github.com/paritytech/substrate", branch = "polkadot-v1.0.0", default-features = false } -sc-client-api = { optional = true, git = "https://github.com/paritytech/substrate", branch = "polkadot-v1.0.0", default-features = false } -sp-api = { optional = true, git = "https://github.com/paritytech/substrate", branch = "polkadot-v1.0.0", default-features = false } -futures = { optional = true, version = "0.3.28" } +sc-client-api = { git = "https://github.com/paritytech/substrate", branch = "polkadot-v1.0.0", default-features = false } +sp-core = { git = "https://github.com/paritytech/substrate", branch = "polkadot-v1.0.0", default-features = false } +sp-runtime = { git = "https://github.com/paritytech/substrate", branch = "polkadot-v1.0.0", default-features = false } +sp-api = { git = "https://github.com/paritytech/substrate", branch = "polkadot-v1.0.0", default-features = false } -[dev-dependencies] \ No newline at end of file +mpc-net = { git = "https://github.com/webb-tools/zk-SaaS/" } +tokio-rustls = "0.24.1" +tokio = "1.32.0" +bincode2 = "2" +futures-util = "0.3.28" +serde = "1.0.188" +async-trait = "0.1.73" \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 00000000..267fb5ba --- /dev/null +++ b/README.md @@ -0,0 +1,9 @@ +# Gadget + +## Design + +The core library is `gadget-core`. The core library allows gadgets to hold standardization of use across different blockchains. The core library is the base of all gadgets, and expects to receive `FinalityNotifications` and `BlockImportNotifications`. + +Once such blockchain is a substrate blockchain. This is where `webb-gadget` comes into play. The `webb-gadget` is the core-gadget endowed with a connection to a substrate blockchain, a networking layer to communicate with other gadgets, and a *WebbModule* that has application-specific logic. + +Since `webb-gadget` allows varying connections to a substrate blockchain and differing network layers, we can thus design above it the `zk-gadget` and `tss-gadget`. These gadgets are endowed with the same functionalities as the `webb-gadget` but with a different connection and networking layer. diff --git a/gadget-core/Cargo.toml b/gadget-core/Cargo.toml new file mode 100644 index 00000000..c26b4b67 --- /dev/null +++ b/gadget-core/Cargo.toml @@ -0,0 +1,28 @@ +[package] +name = "gadget-core" +version = "0.0.1" +authors = ["Thomas P Braun"] +license = "GPL-3.0-or-later WITH Classpath-exception-2.0" +edition = "2021" + +[features] +substrate = [ + "sp-runtime", + "sc-client-api", + "sp-api", + "futures" +] + +[dependencies] +sync_wrapper = "0.1.2" +parking_lot = "0.12.1" +tokio = { workspace = true, features = ["sync", "time", "macros", "rt"] } +hex = "0.4.3" +async-trait = "0.1.73" + +sp-runtime = { optional = true, workspace = true, default-features = false } +sc-client-api = { optional = true, workspace = true, default-features = false } +sp-api = { optional = true, workspace = true, default-features = false } +futures = { optional = true, version = "0.3.28" } + +[dev-dependencies] \ No newline at end of file diff --git a/src/gadget/manager.rs b/gadget-core/src/gadget/manager.rs similarity index 100% rename from src/gadget/manager.rs rename to gadget-core/src/gadget/manager.rs diff --git a/src/gadget/mod.rs b/gadget-core/src/gadget/mod.rs similarity index 100% rename from src/gadget/mod.rs rename to gadget-core/src/gadget/mod.rs diff --git a/src/gadget/substrate/mod.rs b/gadget-core/src/gadget/substrate/mod.rs similarity index 77% rename from src/gadget/substrate/mod.rs rename to gadget-core/src/gadget/substrate/mod.rs index 8c406cef..5a72de84 100644 --- a/src/gadget/substrate/mod.rs +++ b/gadget-core/src/gadget/substrate/mod.rs @@ -9,13 +9,14 @@ use sp_api::ProvideRuntimeApi; use sp_runtime::traits::Block; use std::error::Error; use std::fmt::{Debug, Display, Formatter}; +use std::sync::Arc; use tokio::sync::Mutex; -pub struct SubstrateGadget { +pub struct SubstrateGadget { module: Module, - finality_notification_stream: Mutex>, - block_import_notification_stream: Mutex>, - _pd: std::marker::PhantomData<(B, BE, API)>, + finality_notification_stream: Mutex>, + block_import_notification_stream: Mutex>, + client: Arc, } #[derive(Copy, Clone, Debug, Eq, PartialEq)] @@ -25,18 +26,19 @@ pub struct SubstrateGadgetError {} #[async_trait] pub trait SubstrateGadgetModule: Send + Sync { type Error: Error + Send; - type FinalityNotification: Send; - type BlockImportNotification: Send; type ProtocolMessage: Send; + type Block: Block; + type Backend: Backend; + type Client: Client; async fn get_next_protocol_message(&self) -> Option; async fn process_finality_notification( &self, - notification: Self::FinalityNotification, + notification: FinalityNotification, ) -> Result<(), Self::Error>; async fn process_block_import_notification( &self, - notification: Self::BlockImportNotification, + notification: BlockImportNotification, ) -> Result<(), Self::Error>; async fn process_protocol_message( &self, @@ -61,14 +63,11 @@ where { } -impl SubstrateGadget +impl SubstrateGadget where - B: Block, - BE: Backend, Module: SubstrateGadgetModule, - Api: Send + Sync, { - pub fn new>(client: &C, module: Module) -> Self { + pub fn new(client: Module::Client, module: Module) -> Self { let finality_notification_stream = client.finality_notification_stream(); let block_import_notification_stream = client.import_notification_stream(); @@ -76,24 +75,22 @@ where module, finality_notification_stream: Mutex::new(finality_notification_stream), block_import_notification_stream: Mutex::new(block_import_notification_stream), - _pd: std::marker::PhantomData, + client: Arc::new(client), } } + + pub fn client(&self) -> &Arc { + &self.client + } } #[async_trait] -impl AbstractGadget for SubstrateGadget +impl AbstractGadget for SubstrateGadget where - B: Block, - BE: Backend, - Module: SubstrateGadgetModule< - FinalityNotification = FinalityNotification, - BlockImportNotification = BlockImportNotification, - >, - Api: Send + Sync, + Module: SubstrateGadgetModule, { - type FinalityNotification = FinalityNotification; - type BlockImportNotification = BlockImportNotification; + type FinalityNotification = FinalityNotification; + type BlockImportNotification = BlockImportNotification; type ProtocolMessage = Module::ProtocolMessage; type Error = Module::Error; diff --git a/src/job_manager.rs b/gadget-core/src/job_manager.rs similarity index 95% rename from src/job_manager.rs rename to gadget-core/src/job_manager.rs index 886b2445..86ced27c 100644 --- a/src/job_manager.rs +++ b/gadget-core/src/job_manager.rs @@ -48,7 +48,7 @@ pub struct WorkManagerInner { pub active_tasks: HashSet>, pub enqueued_tasks: VecDeque>, // task hash => SSID => enqueued messages - pub enqueued_messages: EnqueuedMessage<[u8; 32], WM::SSID, WM::ProtocolMessage>, + pub enqueued_messages: EnqueuedMessage, } pub type EnqueuedMessage = HashMap>>; @@ -70,6 +70,8 @@ pub trait WorkManagerInterface: Send + Sync + 'static + Sized { type ProtocolMessage: ProtocolMessageMetadata + Send + Sync + 'static; type Error: Debug + Send + Sync + 'static; type SessionID: Copy + Hash + Eq + PartialEq + Display + Debug + Send + Sync + 'static; + type TaskID: Copy + Hash + Eq + PartialEq + Debug + Send + Sync + AsRef<[u8]> + 'static; + fn debug(&self, input: String); fn error(&self, input: String); fn warn(&self, input: String); @@ -99,6 +101,7 @@ pub trait ProtocolMessageMetadata { fn associated_block_id(&self) -> WM::Clock; fn associated_session_id(&self) -> WM::SessionID; fn associated_ssid(&self) -> WM::SSID; + fn associated_task(&self) -> WM::TaskID; } /// The [`ProtocolRemote`] is the interface between the [`ProtocolWorkManager`] and the async protocol. @@ -178,7 +181,7 @@ impl ProtocolWorkManager { /// Pushes the task, but does not necessarily start it pub fn push_task( &self, - task_hash: [u8; 32], + task_hash: WM::TaskID, force_start: bool, handle: Arc>, task: Pin>>, @@ -395,20 +398,21 @@ impl ProtocolWorkManager { Ok(()) } - pub fn job_exists(&self, job: &[u8; 32]) -> bool { + pub fn job_exists(&self, job: &WM::TaskID) -> bool { let lock = self.inner.read(); - lock.active_tasks.contains(job) || lock.enqueued_tasks.iter().any(|j| &j.task_hash == job) + lock.active_tasks.iter().any(|r| &r.task_hash == job) + || lock.enqueued_tasks.iter().any(|j| &j.task_hash == job) } pub fn deliver_message( &self, msg: WM::ProtocolMessage, - message_task_hash: [u8; 32], ) -> Result { self.utility.debug(format!( "Delivered message is intended for session_id = {}", msg.associated_session_id() )); + let message_task_hash = msg.associated_task(); let mut lock = self.inner.write(); // check the enqueued @@ -471,7 +475,7 @@ impl ProtocolWorkManager { } pub struct Job { - task_hash: [u8; 32], + task_hash: WM::TaskID, utility: Arc, handle: Arc>, task: Arc>>>, @@ -505,12 +509,6 @@ impl<'a, F: Send + Future + 'a, T> SendFuture<'a, T> for F {} pub type SyncFuture = SyncWrapper>>>; -impl std::borrow::Borrow<[u8; 32]> for Job { - fn borrow(&self) -> &[u8; 32] { - &self.task_hash - } -} - impl PartialEq for Job { fn eq(&self, other: &Self) -> bool { self.task_hash == other.task_hash @@ -538,7 +536,7 @@ impl Drop for Job { fn should_deliver( task: &Job, msg: &WM::ProtocolMessage, - message_task_hash: [u8; 32], + message_task_hash: WM::TaskID, ) -> bool { task.handle.session_id() == msg.associated_session_id() && task.task_hash == message_task_hash @@ -567,6 +565,7 @@ mod tests { associated_block_id: u64, associated_session_id: u32, associated_ssid: u32, + associated_task: [u8; 32], } impl ProtocolMessageMetadata for TestMessage { @@ -579,6 +578,9 @@ mod tests { fn associated_ssid(&self) -> u32 { self.associated_ssid } + fn associated_task(&self) -> [u8; 32] { + self.associated_task + } } impl WorkManagerInterface for TestWorkManager { @@ -587,6 +589,7 @@ mod tests { type ProtocolMessage = TestMessage; type Error = (); type SessionID = u32; + type TaskID = [u8; 32]; fn debug(&self, _input: String) {} fn error(&self, input: String) { @@ -709,7 +712,7 @@ mod tests { let (remote, task, mut rx) = generate_async_protocol(0, 0, 0); work_manager - .push_task([0; 32], true, remote.clone(), task) + .push_task(Default::default(), true, remote.clone(), task) .unwrap(); let message = TestMessage { @@ -717,10 +720,11 @@ mod tests { associated_block_id: 0, associated_session_id: 0, associated_ssid: 0, + associated_task: [0; 32], }; assert_ne!( DeliveryType::EnqueuedMessage, - work_manager.deliver_message(message, [0; 32]).unwrap() + work_manager.deliver_message(message).unwrap() ); let _ = rx.recv().await.unwrap(); } @@ -773,7 +777,7 @@ mod tests { // Add a queued task work_manager - .push_task([1; 32], false, remote.clone(), task) + .push_task([0; 32], false, remote.clone(), task) .unwrap(); // Deliver message, should succeed @@ -782,10 +786,11 @@ mod tests { associated_block_id: 0, associated_session_id: 1, associated_ssid: 1, + associated_task: [0; 32], }; assert_ne!( DeliveryType::EnqueuedMessage, - work_manager.deliver_message(msg.clone(), [1; 32]).unwrap() + work_manager.deliver_message(msg.clone()).unwrap() ); let next_message = rx.recv().await.unwrap(); assert_eq!(next_message, msg); @@ -1012,10 +1017,11 @@ mod tests { associated_block_id: 0, associated_session_id: 1, associated_ssid: 1, + associated_task: [0; 32], }; // Deliver a message to a non-existent job - let delivery_type = work_manager.deliver_message(message, [1; 32]).unwrap(); + let delivery_type = work_manager.deliver_message(message).unwrap(); // The message should be enqueued for future use assert_eq!(delivery_type, DeliveryType::EnqueuedMessage); @@ -1027,7 +1033,7 @@ mod tests { let (remote1, task1, _rx) = generate_async_protocol(1, 1, 0); work_manager - .push_task([1; 32], true, remote1, task1) + .push_task([0; 32], true, remote1, task1) .unwrap(); let message = TestMessage { @@ -1035,10 +1041,11 @@ mod tests { associated_block_id: 10, // Outdated block ID associated_session_id: 1, associated_ssid: 1, + associated_task: [0; 32], }; // Try to deliver a message with an outdated block ID - let delivery_type = work_manager.deliver_message(message, [1; 32]).unwrap(); + let delivery_type = work_manager.deliver_message(message).unwrap(); // The message should be enqueued for future use assert_eq!(delivery_type, DeliveryType::EnqueuedMessage); @@ -1062,7 +1069,7 @@ mod tests { let (remote, task, mut rx) = generate_async_protocol(1, 1, 0); work_manager - .push_task([1; 32], true, remote.clone(), task) + .push_task([0; 32], true, remote.clone(), task) .unwrap(); let message = TestMessage { @@ -1070,14 +1077,13 @@ mod tests { associated_block_id: 0, associated_session_id: 1, associated_ssid: 1, + associated_task: [0; 32], }; for _ in 0..10 { assert_ne!( DeliveryType::EnqueuedMessage, - work_manager - .deliver_message(message.clone(), [1; 32]) - .unwrap() + work_manager.deliver_message(message.clone()).unwrap() ); } @@ -1132,12 +1138,13 @@ mod tests { associated_block_id: 0, associated_session_id: 1, associated_ssid: 1, + associated_task: [0; 32], }; - work_manager.push_task([1; 32], true, remote, task).unwrap(); + work_manager.push_task([0; 32], true, remote, task).unwrap(); work_manager.poll(); // should identify and remove the stalled task - let delivery_type = work_manager.deliver_message(msg, [1; 32]).unwrap(); + let delivery_type = work_manager.deliver_message(msg).unwrap(); assert_eq!(delivery_type, DeliveryType::EnqueuedMessage); // message should be enqueued because task is stalled and removed } @@ -1150,11 +1157,12 @@ mod tests { associated_block_id: 0, associated_session_id: 1, associated_ssid: 1, + associated_task: [1; 32], // incorrect task hash, not 0;32 as below }; - work_manager.push_task([1; 32], true, remote, task).unwrap(); + work_manager.push_task([0; 32], true, remote, task).unwrap(); - let delivery_type = work_manager.deliver_message(msg, [2; 32]).unwrap(); // incorrect task hash + let delivery_type = work_manager.deliver_message(msg).unwrap(); // incorrect task hash assert_eq!(delivery_type, DeliveryType::EnqueuedMessage); // message should be enqueued because the task hash is incorrect } @@ -1165,14 +1173,15 @@ mod tests { fn associated_block_id(&self) -> u64 { 0u64 } - fn associated_session_id(&self) -> u64 { 0u64 } - fn associated_ssid(&self) -> u64 { 0u64 } + fn associated_task(&self) -> [u8; 32] { + [0; 32] + } } impl WorkManagerInterface for DummyRangeChecker { @@ -1181,6 +1190,7 @@ mod tests { type ProtocolMessage = DummyProtocolMessage; type Error = (); type SessionID = u64; + type TaskID = [u8; 32]; fn debug(&self, _input: String) { todo!() diff --git a/gadget-core/src/lib.rs b/gadget-core/src/lib.rs new file mode 100644 index 00000000..f0906143 --- /dev/null +++ b/gadget-core/src/lib.rs @@ -0,0 +1,5 @@ +pub use sc_client_api::Backend; +pub use sp_runtime::traits::Block; + +pub mod gadget; +pub mod job_manager; diff --git a/src/lib.rs b/src/lib.rs deleted file mode 100644 index 6feac97d..00000000 --- a/src/lib.rs +++ /dev/null @@ -1,2 +0,0 @@ -pub mod gadget; -pub mod job_manager; diff --git a/webb-gadget/Cargo.toml b/webb-gadget/Cargo.toml new file mode 100644 index 00000000..4a57bd94 --- /dev/null +++ b/webb-gadget/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "webb-gadget" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[features] + +[dependencies] +mpc-net = { workspace = true } +tokio-rustls = { workspace = true } +gadget-core = { workspace = true, features = ["substrate"] } +tokio = { workspace = true } +serde = { workspace = true, features = ["derive"] } +tokio-util = { version = "0.7.9" } +async-trait = { workspace = true } +log = "0.4.20" +parking_lot = "0.12.1" +auto_impl = "1.1.0" +sc-client-api = { workspace = true } +sp-core = { workspace = true } +sp-runtime = { workspace = true } \ No newline at end of file diff --git a/webb-gadget/src/gadget/message.rs b/webb-gadget/src/gadget/message.rs new file mode 100644 index 00000000..d2ee3196 --- /dev/null +++ b/webb-gadget/src/gadget/message.rs @@ -0,0 +1,36 @@ +use crate::gadget::work_manager::WebbWorkManager; +use gadget_core::job_manager::{ProtocolMessageMetadata, WorkManagerInterface}; +use serde::{Deserialize, Serialize}; + +#[derive(Serialize, Deserialize)] +pub struct GadgetProtocolMessage { + pub associated_block_id: ::Clock, + pub associated_session_id: ::SessionID, + pub associated_ssid: ::SSID, + pub from: UserID, + // If None, this is a broadcasted message + pub to: Option, + // A unique marker for the associated task this message belongs to + pub task_hash: ::TaskID, + pub payload: Vec, +} + +pub type UserID = u32; + +impl ProtocolMessageMetadata for GadgetProtocolMessage { + fn associated_block_id(&self) -> ::Clock { + self.associated_block_id + } + + fn associated_session_id(&self) -> ::SessionID { + self.associated_session_id + } + + fn associated_ssid(&self) -> ::SSID { + self.associated_ssid + } + + fn associated_task(&self) -> ::TaskID { + self.task_hash + } +} diff --git a/webb-gadget/src/gadget/mod.rs b/webb-gadget/src/gadget/mod.rs new file mode 100644 index 00000000..f823b47b --- /dev/null +++ b/webb-gadget/src/gadget/mod.rs @@ -0,0 +1,124 @@ +use crate::gadget::message::GadgetProtocolMessage; +use crate::gadget::network::Network; +use crate::gadget::work_manager::WebbWorkManager; +use crate::Error; +use async_trait::async_trait; +use gadget_core::gadget::substrate::{Client, SubstrateGadgetModule}; +use gadget_core::job_manager::{PollMethod, ProtocolWorkManager}; +use parking_lot::RwLock; +use sc_client_api::{Backend, BlockImportNotification, FinalityNotification}; +use sp_runtime::traits::{Block, Header}; +use sp_runtime::SaturatedConversion; +use std::marker::PhantomData; +use std::sync::Arc; +use tokio::sync::Mutex; + +pub mod message; +pub mod network; +pub mod work_manager; + +/// Used as a module to place inside the SubstrateGadget +pub struct WebbGadget { + #[allow(dead_code)] + network: N, + module: M, + job_manager: ProtocolWorkManager, + from_network: Mutex>, + clock: Arc>>, + _pd: PhantomData<(B, C, BE)>, +} + +const MAX_ACTIVE_TASKS: usize = 4; +const MAX_PENDING_TASKS: usize = 4; + +impl, B: Block, BE: Backend, N: Network, M: WebbGadgetModule> + WebbGadget +{ + pub fn new(mut network: N, mut module: M, now: Option) -> Self { + let clock = Arc::new(RwLock::new(now)); + let clock_clone = clock.clone(); + let from_registry = network.take_message_receiver().expect("Should exist"); + + let job_manager_zk = WebbWorkManager::new(move || *clock_clone.read()); + + let job_manager = ProtocolWorkManager::new( + job_manager_zk, + MAX_ACTIVE_TASKS, + MAX_PENDING_TASKS, + PollMethod::Interval { millis: 200 }, + ); + + module.on_job_manager_created(job_manager.clone()); + + WebbGadget { + module, + network, + job_manager, + clock, + from_network: Mutex::new(from_registry), + _pd: Default::default(), + } + } +} + +#[async_trait] +impl, B: Block, BE: Backend, N: Network, M: WebbGadgetModule> + SubstrateGadgetModule for WebbGadget +{ + type Error = Error; + type ProtocolMessage = GadgetProtocolMessage; + type Block = B; + type Backend = BE; + type Client = C; + + async fn get_next_protocol_message(&self) -> Option { + self.from_network.lock().await.recv().await + } + + async fn process_finality_notification( + &self, + notification: FinalityNotification, + ) -> Result<(), Self::Error> { + *self.clock.write() = Some((*notification.header.number()).saturated_into()); + self.module + .process_finality_notification(notification) + .await + } + + async fn process_block_import_notification( + &self, + notification: BlockImportNotification, + ) -> Result<(), Self::Error> { + self.module + .process_block_import_notification(notification) + .await + } + + async fn process_protocol_message( + &self, + message: Self::ProtocolMessage, + ) -> Result<(), Self::Error> { + self.job_manager + .deliver_message(message) + .map(|_| ()) + .map_err(|err| Error::WorkManagerError { err }) + } + + async fn process_error(&self, error: Self::Error) { + self.module.process_error(error).await + } +} + +#[async_trait] +pub trait WebbGadgetModule: Send + Sync { + fn on_job_manager_created(&mut self, job_manager: ProtocolWorkManager); + async fn process_finality_notification( + &self, + notification: FinalityNotification, + ) -> Result<(), Error>; + async fn process_block_import_notification( + &self, + notification: BlockImportNotification, + ) -> Result<(), Error>; + async fn process_error(&self, error: Error); +} diff --git a/webb-gadget/src/gadget/network.rs b/webb-gadget/src/gadget/network.rs new file mode 100644 index 00000000..7867377e --- /dev/null +++ b/webb-gadget/src/gadget/network.rs @@ -0,0 +1,6 @@ +use crate::gadget::message::GadgetProtocolMessage; +use tokio::sync::mpsc::UnboundedReceiver; + +pub trait Network: Send + Sync { + fn take_message_receiver(&mut self) -> Option>; +} diff --git a/webb-gadget/src/gadget/work_manager.rs b/webb-gadget/src/gadget/work_manager.rs new file mode 100644 index 00000000..a9fc188a --- /dev/null +++ b/webb-gadget/src/gadget/work_manager.rs @@ -0,0 +1,56 @@ +use crate::gadget::message::GadgetProtocolMessage; +use gadget_core::job_manager::WorkManagerInterface; +use std::sync::Arc; + +pub struct WebbWorkManager { + pub(crate) clock: Arc< + dyn Fn() -> Option<::Clock> + + Send + + Sync + + 'static, + >, +} + +impl WebbWorkManager { + pub fn new( + clock: impl Fn() -> Option<::Clock> + + Send + + Sync + + 'static, + ) -> Self { + Self { + clock: Arc::new(clock), + } + } +} + +const ACCEPTABLE_BLOCK_TOLERANCE: u64 = 5; + +impl WorkManagerInterface for WebbWorkManager { + type SSID = u16; + type Clock = u64; + type ProtocolMessage = GadgetProtocolMessage; + type Error = crate::Error; + type SessionID = u64; + type TaskID = [u8; 32]; + + fn debug(&self, input: String) { + log::debug!(target: "gadget", "{input}") + } + + fn error(&self, input: String) { + log::error!(target: "gadget", "{input}") + } + + fn warn(&self, input: String) { + log::warn!(target: "gadget", "{input}") + } + + fn clock(&self) -> Self::Clock { + (self.clock)().expect("No finality notification received") + } + + fn acceptable_block_tolerance() -> Self::Clock { + ACCEPTABLE_BLOCK_TOLERANCE + } +} diff --git a/webb-gadget/src/lib.rs b/webb-gadget/src/lib.rs new file mode 100644 index 00000000..0105e022 --- /dev/null +++ b/webb-gadget/src/lib.rs @@ -0,0 +1,46 @@ +use crate::gadget::network::Network; +use crate::gadget::{WebbGadget, WebbGadgetModule}; +use gadget_core::gadget::manager::{GadgetError, GadgetManager}; +use gadget_core::gadget::substrate::{Client, SubstrateGadget}; +use gadget_core::job_manager::WorkManagerError; +pub use sc_client_api::BlockImportNotification; +pub use sc_client_api::{Backend, FinalityNotification}; +use sp_runtime::traits::Block; +use std::fmt::{Debug, Display, Formatter}; + +pub mod gadget; + +#[derive(Debug)] +pub enum Error { + RegistryCreateError { err: String }, + RegistrySendError { err: String }, + RegistryRecvError { err: String }, + RegistryListenError { err: String }, + GadgetManagerError { err: GadgetError }, + InitError { err: String }, + WorkManagerError { err: WorkManagerError }, +} + +impl Display for Error { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + Debug::fmt(self, f) + } +} + +impl std::error::Error for Error {} + +pub async fn run, B: Block, BE: Backend, N: Network, M: WebbGadgetModule>( + network: N, + module: M, + client: C, +) -> Result<(), Error> { + let now = None; + let webb_gadget = WebbGadget::new(network, module, now); + // Plug the module into the substrate gadget to interface the WebbGadget with Substrate + let substrate_gadget = SubstrateGadget::new(client, webb_gadget); + + // Run the GadgetManager to execute the substrate gadget + GadgetManager::new(substrate_gadget) + .await + .map_err(|err| Error::GadgetManagerError { err }) +} diff --git a/zk-gadget/Cargo.toml b/zk-gadget/Cargo.toml new file mode 100644 index 00000000..077aa763 --- /dev/null +++ b/zk-gadget/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "zk-gadget" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +tokio-rustls = { workspace = true } +mpc-net = { workspace = true } +webb-gadget = { workspace = true } +gadget-core = { workspace = true } +bincode2 = { workspace = true } +tokio = { workspace = true } +futures-util = { workspace = true } +serde = { workspace = true, features = ["derive"] } +async-trait = { workspace = true } \ No newline at end of file diff --git a/zk-gadget/src/lib.rs b/zk-gadget/src/lib.rs new file mode 100644 index 00000000..d1994609 --- /dev/null +++ b/zk-gadget/src/lib.rs @@ -0,0 +1,58 @@ +use crate::module::ZkModule; +use crate::network::RegistantId; +use gadget_core::gadget::substrate::Client; +use gadget_core::{Backend, Block}; +use mpc_net::prod::RustlsCertificate; +use std::net::SocketAddr; +use tokio_rustls::rustls::{Certificate, PrivateKey, RootCertStore}; +use webb_gadget::Error; + +pub mod module; +pub mod network; + +pub struct ZkGadgetConfig { + king_bind_addr: Option, + client_only_king_addr: Option, + id: RegistantId, + public_identity_der: Vec, + private_identity_der: Vec, + client_only_king_public_identity_der: Option>, +} + +pub async fn run, B: Block, BE: Backend>( + config: ZkGadgetConfig, + client: C, +) -> Result<(), Error> { + // Create the zk gadget module + let our_identity = RustlsCertificate { + cert: Certificate(config.public_identity_der), + private_key: PrivateKey(config.private_identity_der), + }; + + let network = if let Some(addr) = &config.king_bind_addr { + network::ZkNetworkService::new_king(*addr, our_identity).await? + } else { + let king_addr = config + .client_only_king_addr + .expect("King address must be specified if king bind address is not specified"); + + let mut king_certs = RootCertStore::empty(); + king_certs + .add(&Certificate( + config + .client_only_king_public_identity_der + .expect("The client must specify the identity of the king"), + )) + .map_err(|err| Error::InitError { + err: err.to_string(), + })?; + + network::ZkNetworkService::new_client(king_addr, config.id, our_identity, king_certs) + .await? + }; + + let zk_module = ZkModule { job_manager: None }; // TODO: proper implementation + + // Plug the module into the webb gadget + webb_gadget::run(network, zk_module, client).await +} diff --git a/zk-gadget/src/module/mod.rs b/zk-gadget/src/module/mod.rs new file mode 100644 index 00000000..167cb465 --- /dev/null +++ b/zk-gadget/src/module/mod.rs @@ -0,0 +1,36 @@ +use async_trait::async_trait; +use gadget_core::job_manager::ProtocolWorkManager; +use gadget_core::Block; +use webb_gadget::gadget::work_manager::WebbWorkManager; +use webb_gadget::gadget::WebbGadgetModule; +use webb_gadget::{BlockImportNotification, Error, FinalityNotification}; + +pub struct ZkModule { + #[allow(dead_code)] + pub job_manager: Option>, +} + +#[async_trait] +impl WebbGadgetModule for ZkModule { + fn on_job_manager_created(&mut self, job_manager: ProtocolWorkManager) { + self.job_manager = Some(job_manager); + } + + async fn process_finality_notification( + &self, + _notification: FinalityNotification, + ) -> Result<(), Error> { + todo!() + } + + async fn process_block_import_notification( + &self, + _notification: BlockImportNotification, + ) -> Result<(), Error> { + todo!() + } + + async fn process_error(&self, _error: Error) { + todo!() + } +} diff --git a/zk-gadget/src/network/mod.rs b/zk-gadget/src/network/mod.rs new file mode 100644 index 00000000..7292659e --- /dev/null +++ b/zk-gadget/src/network/mod.rs @@ -0,0 +1,367 @@ +use futures_util::sink::SinkExt; +use futures_util::StreamExt; +use mpc_net::multi::WrappedStream; +use mpc_net::prod::{CertToDer, RustlsCertificate}; +use mpc_net::MpcNetError; +use serde::de::DeserializeOwned; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::net::SocketAddr; +use std::sync::Arc; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::net::TcpStream; +use tokio::sync::mpsc::UnboundedReceiver; +use tokio::sync::Mutex; +use tokio_rustls::rustls::server::NoClientAuth; +use tokio_rustls::rustls::{RootCertStore, ServerConfig}; +use tokio_rustls::{rustls, TlsAcceptor, TlsStream}; + +/// Type should correspond to the on-chain identifier of the registrant +pub type RegistantId = u64; + +pub enum ZkNetworkService { + King { + listener: Option, + registrants: Arc>>, + to_gadget: tokio::sync::mpsc::UnboundedSender, + from_registry: Option>, + identity: RustlsCertificate, + }, + Client { + king_registry_addr: SocketAddr, + registrant_id: RegistantId, + connection: Option>, + cert_der: Vec, + to_gadget: tokio::sync::mpsc::UnboundedSender, + from_registry: Option>, + }, +} + +#[allow(dead_code)] +pub struct Registrant { + id: RegistantId, + cert_der: Vec, +} + +use crate::Error; +use webb_gadget::gadget::message::GadgetProtocolMessage; +use webb_gadget::gadget::network::Network; + +pub fn create_server_tls_acceptor( + server_certificate: T, +) -> Result { + let client_auth = NoClientAuth::boxed(); + let server_config = ServerConfig::builder() + .with_safe_defaults() + .with_client_cert_verifier(client_auth) + .with_single_cert( + vec![rustls::Certificate( + server_certificate.serialize_certificate_to_der()?, + )], + rustls::PrivateKey(server_certificate.serialize_private_key_to_der()?), + ) + .unwrap(); + Ok(TlsAcceptor::from(Arc::new(server_config))) +} + +impl ZkNetworkService { + pub async fn new_king( + bind_addr: T, + identity: RustlsCertificate, + ) -> Result { + let bind_addr: SocketAddr = bind_addr + .to_socket_addrs() + .map_err(|err| Error::RegistryCreateError { + err: err.to_string(), + })? + .next() + .ok_or(Error::RegistryCreateError { + err: "No address found".to_string(), + })?; + + let listener = tokio::net::TcpListener::bind(bind_addr) + .await + .map_err(|err| Error::RegistryCreateError { + err: err.to_string(), + })?; + let registrants = Arc::new(Mutex::new(HashMap::new())); + let (to_gadget, from_registry) = tokio::sync::mpsc::unbounded_channel(); + Ok(ZkNetworkService::King { + listener: Some(listener), + registrants, + to_gadget, + identity, + from_registry: Some(from_registry), + }) + } + + pub async fn new_client( + king_registry_addr: T, + registrant_id: RegistantId, + client_identity: RustlsCertificate, + king_certs: RootCertStore, + ) -> Result { + let king_registry_addr: SocketAddr = king_registry_addr + .to_socket_addrs() + .map_err(|err| Error::RegistryCreateError { + err: err.to_string(), + })? + .next() + .ok_or(Error::RegistryCreateError { + err: "No address found".to_string(), + })?; + + let cert_der = client_identity.cert.0.clone(); + + let connection = TcpStream::connect(king_registry_addr) + .await + .map_err(|err| Error::RegistryCreateError { + err: err.to_string(), + })?; + + // Upgrade to TLS + let tls = mpc_net::prod::create_client_mutual_tls_connector(king_certs, client_identity) + .map_err(|err| Error::RegistryCreateError { + err: format!("{err:?}"), + })?; + + let connection = tls + .connect( + tokio_rustls::rustls::ServerName::IpAddress(king_registry_addr.ip()), + connection, + ) + .await + .map_err(|err| Error::RegistryCreateError { + err: err.to_string(), + })?; + + let (to_gadget, from_registry) = tokio::sync::mpsc::unbounded_channel(); + + let mut this = ZkNetworkService::Client { + king_registry_addr, + registrant_id, + cert_der, + connection: Some(TlsStream::Client(connection)), + to_gadget, + from_registry: Some(from_registry), + }; + + this.client_register().await?; + + Ok(this) + } + + pub async fn run(self) -> Result<(), Error> { + match self { + Self::King { + listener, + registrants, + to_gadget, + identity, + .. + } => { + let listener = listener.expect("Should exist"); + let tls_acceptor = create_server_tls_acceptor(identity).map_err(|err| { + Error::RegistryCreateError { + err: format!("{err:?}"), + } + })?; + + while let Ok((stream, peer_addr)) = listener.accept().await { + println!("[Registry] Accepted connection from {peer_addr}, upgrading to TLS"); + let stream = tls_acceptor.accept(stream).await.map_err(|err| { + Error::RegistryCreateError { + err: format!("{err:?}"), + } + })?; + + handle_stream_as_king( + TlsStream::Server(stream), + peer_addr, + registrants.clone(), + to_gadget.clone(), + ); + } + + Err(Error::RegistryCreateError { + err: "Listener closed".to_string(), + }) + } + Self::Client { + connection, + to_gadget, + .. + } => { + let stream = connection.expect("Should exist"); + let mut wrapped_stream = mpc_net::multi::wrap_stream(stream); + while let Some(Ok(message)) = wrapped_stream.next().await { + match bincode2::deserialize::(&message) { + Ok(packet) => match packet { + RegistryPacket::SubstrateGadgetMessage { payload } => { + if let Err(err) = to_gadget.send(payload) { + eprintln!( + "[Registry] Failed to send message to gadget: {err:?}" + ); + } + } + _ => { + println!("[Registry] Received invalid packet"); + } + }, + Err(err) => { + println!("[Registry] Received invalid packet: {err}"); + } + } + } + + Err(Error::RegistryListenError { + err: "Connection closed".to_string(), + }) + } + } + } + + async fn client_register(&mut self) -> Result<(), Error> { + match self { + Self::King { .. } => Err(Error::RegistryCreateError { + err: "Cannot register as king".to_string(), + }), + Self::Client { + king_registry_addr: _, + registrant_id, + connection, + cert_der, + .. + } => { + let conn = connection.as_mut().expect("Should exist"); + let mut wrapped_stream = mpc_net::multi::wrap_stream(conn); + + send_stream( + &mut wrapped_stream, + RegistryPacket::Register { + id: *registrant_id, + cert_der: cert_der.clone(), + }, + ) + .await?; + + let response = recv_stream::(&mut wrapped_stream).await?; + + if !matches!( + &response, + &RegistryPacket::RegisterResponse { success: true, .. } + ) { + return Err(Error::RegistryCreateError { + err: "Unexpected response".to_string(), + }); + } + + Ok(()) + } + } + } +} + +#[derive(Serialize, Deserialize)] +enum RegistryPacket { + Register { id: RegistantId, cert_der: Vec }, + RegisterResponse { id: RegistantId, success: bool }, + // A message for the substrate gadget + SubstrateGadgetMessage { payload: GadgetProtocolMessage }, +} + +fn handle_stream_as_king( + stream: TlsStream, + peer_addr: SocketAddr, + registrants: Arc>>, + to_gadget: tokio::sync::mpsc::UnboundedSender, +) { + tokio::task::spawn(async move { + let mut wrapped_stream = mpc_net::multi::wrap_stream(stream); + let mut peer_id = None; + while let Some(Ok(message)) = wrapped_stream.next().await { + match bincode2::deserialize::(&message) { + Ok(packet) => match packet { + RegistryPacket::Register { id, cert_der } => { + println!("[Registry] Received registration for id {id}"); + peer_id = Some(id); + let mut registrants = registrants.lock().await; + registrants.insert(id, Registrant { id, cert_der }); + if let Err(err) = send_stream( + &mut wrapped_stream, + RegistryPacket::RegisterResponse { id, success: true }, + ) + .await + { + eprintln!("[Registry] Failed to send registration response: {err:?}"); + } + } + RegistryPacket::SubstrateGadgetMessage { payload } => { + if let Err(err) = to_gadget.send(payload) { + eprintln!("[Registry] Failed to send message to gadget: {err:?}"); + } + } + _ => { + println!("[Registry] Received invalid packet"); + } + }, + Err(err) => { + println!("[Registry] Received invalid packet: {err}"); + } + } + } + + // Deregister peer + if let Some(id) = peer_id { + let mut registrants = registrants.lock().await; + registrants.remove(&id); + } + + eprintln!("[Registry] Connection closed to peer {peer_addr}") + }); +} + +async fn send_stream( + stream: &mut WrappedStream, + payload: T, +) -> Result<(), Error> { + let serialized = bincode2::serialize(&payload).map_err(|err| Error::RegistrySendError { + err: err.to_string(), + })?; + + stream + .send(serialized.into()) + .await + .map_err(|err| Error::RegistrySendError { + err: err.to_string(), + }) +} + +async fn recv_stream( + stream: &mut WrappedStream, +) -> Result { + let message = stream + .next() + .await + .ok_or(Error::RegistryRecvError { + err: "Stream closed".to_string(), + })? + .map_err(|err| Error::RegistryRecvError { + err: err.to_string(), + })?; + + let deserialized = bincode2::deserialize(&message).map_err(|err| Error::RegistryRecvError { + err: err.to_string(), + })?; + + Ok(deserialized) +} + +impl Network for ZkNetworkService { + fn take_message_receiver(&mut self) -> Option> { + match self { + Self::King { from_registry, .. } => from_registry.take(), + Self::Client { from_registry, .. } => from_registry.take(), + } + } +}