Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Streamline and refactor #24

Merged
merged 2 commits into from
Nov 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ members = [
gadget-core = { path = "./gadget-core" }
webb-gadget = { path = "./webb-gadget" }
zk-gadget = { path = "./zk-gadget" }
zk-playground = { path = "./playground" }

sc-client-api = { git = "https://github.com/paritytech/substrate", branch = "polkadot-v1.0.0", default-features = false }
sp-core = { git = "https://github.com/paritytech/substrate", branch = "polkadot-v1.0.0", default-features = false }
Expand Down
39 changes: 17 additions & 22 deletions gadget-core/src/job.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,15 @@ impl Display for JobError {
impl Error for JobError {}

#[async_trait]
pub trait ExecutableJob: SendFuture<'static, Result<(), JobError>> + Unpin {
pub trait ExecutableJob: Send + 'static {
async fn pre_job_hook(&mut self) -> Result<ProceedWithExecution, JobError>;
async fn job(&mut self) -> Result<(), JobError>;
async fn post_job_hook(&mut self) -> Result<(), JobError>;

async fn execute(&mut self) -> Result<(), JobError> {
match self.pre_job_hook().await? {
ProceedWithExecution::True => {
let result = (&mut self).await;
let result = self.job().await;
let post_result = self.post_job_hook().await;
result.and(post_result)
}
Expand All @@ -48,34 +50,38 @@ pub trait ExecutableJob: SendFuture<'static, Result<(), JobError>> + Unpin {
}
}

pub struct ExecutableJobWrapper<Pre: ?Sized, Protocol, Post: ?Sized> {
pub struct ExecutableJobWrapper<Pre: ?Sized, Protocol: ?Sized, Post: ?Sized> {
pre: Pin<Box<Pre>>,
protocol: Pin<Box<Protocol>>,
post: Pin<Box<Post>>,
}

#[async_trait]
impl<Pre: ?Sized, Protocol, Post: ?Sized> ExecutableJob
impl<Pre: ?Sized, Protocol: ?Sized, Post: ?Sized> ExecutableJob
for ExecutableJobWrapper<Pre, Protocol, Post>
where
Pre: Future<Output = Result<ProceedWithExecution, JobError>> + Send + 'static,
Pre: SendFuture<'static, Result<ProceedWithExecution, JobError>>,
Protocol: SendFuture<'static, Result<(), JobError>>,
Post: Future<Output = Result<(), JobError>> + Send + 'static,
Post: SendFuture<'static, Result<(), JobError>>,
{
async fn pre_job_hook(&mut self) -> Result<ProceedWithExecution, JobError> {
self.pre.as_mut().await
}

async fn job(&mut self) -> Result<(), JobError> {
self.protocol.as_mut().await
}

async fn post_job_hook(&mut self) -> Result<(), JobError> {
self.post.as_mut().await
}
}

impl<Pre, Protocol, Post> ExecutableJobWrapper<Pre, Protocol, Post>
where
Pre: Future<Output = Result<ProceedWithExecution, JobError>>,
Pre: SendFuture<'static, Result<ProceedWithExecution, JobError>>,
Protocol: SendFuture<'static, Result<(), JobError>>,
Post: Future<Output = Result<(), JobError>>,
Post: SendFuture<'static, Result<(), JobError>>,
{
pub fn new(pre: Pre, protocol: Protocol, post: Post) -> Self {
Self {
Expand All @@ -86,17 +92,6 @@ where
}
}

impl<Pre: ?Sized, Protocol, Post: ?Sized> Future for ExecutableJobWrapper<Pre, Protocol, Post>
where
Protocol: SendFuture<'static, Result<(), JobError>>,
{
type Output = <Protocol as Future>::Output;

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.protocol.as_mut().poll(cx)
}
}

#[derive(Default)]
pub struct JobBuilder {
pre: Option<Pin<Box<PreJobHook>>>,
Expand Down Expand Up @@ -124,9 +119,9 @@ impl Future for DefaultPostJobHook {
}
}

pub type BuiltExecutableJobWrapper<Protocol> = ExecutableJobWrapper<
pub type BuiltExecutableJobWrapper = ExecutableJobWrapper<
dyn SendFuture<'static, Result<ProceedWithExecution, JobError>>,
Protocol,
dyn SendFuture<'static, Result<(), JobError>>,
dyn SendFuture<'static, Result<(), JobError>>,
>;

Expand All @@ -151,7 +146,7 @@ impl JobBuilder {
self
}

pub fn build<Protocol>(self, protocol: Protocol) -> BuiltExecutableJobWrapper<Protocol>
pub fn build<Protocol>(self, protocol: Protocol) -> BuiltExecutableJobWrapper
where
Protocol: SendFuture<'static, Result<(), JobError>>,
{
Expand Down
4 changes: 2 additions & 2 deletions gadget-core/src/job_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -551,7 +551,7 @@ fn should_deliver<WM: WorkManagerInterface>(
#[cfg(test)]
mod tests {
use super::*;
use crate::job::{BuiltExecutableJobWrapper, JobBuilder, JobError};
use crate::job::{BuiltExecutableJobWrapper, JobBuilder};
use parking_lot::Mutex;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
Expand Down Expand Up @@ -647,7 +647,7 @@ mod tests {
started_at: u64,
) -> (
Arc<TestProtocolRemote>,
BuiltExecutableJobWrapper<impl SendFuture<'static, Result<(), JobError>>>,
BuiltExecutableJobWrapper,
UnboundedReceiver<TestMessage>,
) {
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
Expand Down
4 changes: 2 additions & 2 deletions playground/examples/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ use ark_std::{cfg_iter, start_timer};
use ark_std::{end_timer, Zero};
use clap::Parser;
use futures::{future, TryFutureExt};
use gadget_core::gadget::substrate::Client;
use gadget_core::job::{BuiltExecutableJobWrapper, JobBuilder, JobError};
use gadget_core::{gadget::substrate::Client, job_manager::SendFuture};
use groth16::proving_key::PackedProvingKeyShare;
use mpc_net::prod::CertToDer;
use mpc_net::MultiplexedStreamID;
Expand Down Expand Up @@ -186,7 +186,7 @@ fn async_protocol_generator(
BlockchainClient,
TestBlock,
>,
) -> BuiltExecutableJobWrapper<impl SendFuture<'static, Result<(), JobError>>> {
) -> BuiltExecutableJobWrapper {
let stop_tx = params.extra_parameters.stop_tx.clone();
let party_id = params.party_id;
JobBuilder::default()
Expand Down
10 changes: 3 additions & 7 deletions test-gadget/src/gadget.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@ use crate::work_manager::{
};
use async_trait::async_trait;
use gadget_core::gadget::manager::AbstractGadget;
use gadget_core::job::{BuiltExecutableJobWrapper, JobBuilder, JobError};
use gadget_core::job_manager::{PollMethod, ProtocolWorkManager, SendFuture, WorkManagerInterface};
use gadget_core::job::{BuiltExecutableJobWrapper, JobBuilder};
use gadget_core::job_manager::{PollMethod, ProtocolWorkManager, WorkManagerInterface};
use parking_lot::RwLock;
use std::pin::Pin;
use std::sync::atomic::AtomicBool;
use std::sync::Arc;
use tokio::sync::Mutex;
Expand Down Expand Up @@ -142,10 +141,7 @@ fn create_test_async_protocol<B: Send + Sync + 'static>(
task_id: <TestWorkManager as WorkManagerInterface>::TaskID,
test_bundle: B,
proto_gen: &dyn AsyncProtocolGenerator<B>,
) -> (
TestProtocolRemote,
BuiltExecutableJobWrapper<Pin<Box<dyn SendFuture<'static, Result<(), JobError>>>>>,
) {
) -> (TestProtocolRemote, BuiltExecutableJobWrapper) {
let is_done = Arc::new(AtomicBool::new(false));
let (to_async_protocol, protocol_message_rx) = tokio::sync::mpsc::unbounded_channel();
let (start_tx, start_rx) = tokio::sync::oneshot::channel();
Expand Down
67 changes: 67 additions & 0 deletions webb-gadget/src/helpers.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
use gadget_core::job::ExecutableJob;
use gadget_core::job::JobError;
use gadget_core::job::{BuiltExecutableJobWrapper, JobBuilder, ProceedWithExecution};
use gadget_core::job_manager::ShutdownReason;
use std::fmt::Display;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

/// Wraps an async protocol with logic that makes it compatible with the job manager
pub fn create_job_manager_compatible_job<T: Display + Send + Clone + 'static>(
task_name: T,
is_done: Arc<AtomicBool>,
start_rx: tokio::sync::oneshot::Receiver<()>,
shutdown_rx: tokio::sync::oneshot::Receiver<ShutdownReason>,
mut async_protocol: BuiltExecutableJobWrapper,
) -> BuiltExecutableJobWrapper {
let task_name_cloned = task_name.clone();
let pre_hook = async move {
match start_rx.await {
Ok(_) => Ok(ProceedWithExecution::True),
Err(err) => {
log::error!("Protocol {task_name_cloned} failed to receive start signal: {err:?}");
Ok(ProceedWithExecution::False)
}
}
};

let post_hook = async move {
// Mark the task as done
is_done.store(true, Ordering::SeqCst);
Ok(())
};

// This wrapped future enables proper functionality between the async protocol and the
// job manager
let wrapped_future = async move {
tokio::select! {
res0 = async_protocol.execute() => {
if let Err(err) = res0 {
log::error!("Protocol {task_name} failed: {err:?}");
Err(JobError::from(err.to_string()))
} else {
log::info!("Protocol {task_name} finished");
Ok(())
}
},

res1 = shutdown_rx => {
match res1 {
Ok(reason) => {
log::info!("Protocol {task_name} shutdown: {reason:?}");
Ok(())
},
Err(err) => {
log::error!("Protocol {task_name} shutdown failed: {err:?}");
Err(JobError::from(err.to_string()))
},
}
}
}
};

JobBuilder::default()
.pre(pre_hook)
.post(post_hook)
.build(wrapped_future)
}
2 changes: 2 additions & 0 deletions webb-gadget/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ use std::fmt::{Debug, Display, Formatter};

pub mod gadget;

pub mod helpers;

#[derive(Debug)]
pub enum Error {
RegistryCreateError { err: String },
Expand Down
5 changes: 1 addition & 4 deletions zk-gadget/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@ use crate::client_ext::ClientWithApi;
use crate::module::proto_gen::AsyncProtocolGenerator;
use crate::module::{AdditionalProtocolParams, ZkModule};
use crate::network::{RegistantId, ZkNetworkService};
use gadget_core::job::JobError;
use gadget_core::job_manager::SendFuture;
use mpc_net::prod::RustlsCertificate;
use sp_runtime::traits::Block;
use std::net::SocketAddr;
Expand All @@ -29,8 +27,7 @@ pub async fn run<
C: ClientWithApi<B>,
B: Block,
T: AdditionalProtocolParams,
F: SendFuture<'static, Result<(), JobError>>,
Gen: AsyncProtocolGenerator<T, Error, ZkNetworkService, C, B, F>,
Gen: AsyncProtocolGenerator<T, Error, ZkNetworkService, C, B>,
>(
config: ZkGadgetConfig,
client: C,
Expand Down
15 changes: 5 additions & 10 deletions zk-gadget/src/module/mod.rs
Original file line number Diff line number Diff line change
@@ -1,35 +1,30 @@
use crate::client_ext::ClientWithApi;
use crate::network::{RegistantId, ZkNetworkService};
use async_trait::async_trait;
use gadget_core::job::JobError;
use gadget_core::job_manager::{ProtocolWorkManager, SendFuture};
use gadget_core::job_manager::ProtocolWorkManager;
use sp_runtime::traits::Block;
use webb_gadget::gadget::work_manager::WebbWorkManager;
use webb_gadget::gadget::WebbGadgetModule;
use webb_gadget::{BlockImportNotification, Error, FinalityNotification};

pub mod proto_gen;

pub struct ZkModule<T, C, B, F: SendFuture<'static, Result<(), JobError>>> {
pub struct ZkModule<T, C, B> {
pub party_id: RegistantId,
pub n_parties: usize,
pub additional_protocol_params: T,
pub client: C,
pub network: ZkNetworkService,
pub async_protocol_generator:
Box<dyn proto_gen::AsyncProtocolGenerator<T, Error, ZkNetworkService, C, B, F>>,
Box<dyn proto_gen::AsyncProtocolGenerator<T, Error, ZkNetworkService, C, B>>,
}

pub trait AdditionalProtocolParams: Send + Sync + Clone + 'static {}
impl<T: Send + Sync + Clone + 'static> AdditionalProtocolParams for T {}

#[async_trait]
impl<
B: Block,
T: AdditionalProtocolParams,
C: ClientWithApi<B>,
F: SendFuture<'static, Result<(), JobError>>,
> WebbGadgetModule<B> for ZkModule<T, C, B, F>
impl<B: Block, T: AdditionalProtocolParams, C: ClientWithApi<B>> WebbGadgetModule<B>
for ZkModule<T, C, B>
{
async fn process_finality_notification(
&self,
Expand Down
Loading
Loading