diff --git a/Cargo.toml b/Cargo.toml index 23c5833a..206bbbf3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,55 +14,48 @@ gadget-common = { path = "./gadget-common" } zk-gadget = { path = "./zk-gadget" } mp-ecdsa-protocol = { path = "./protocols/mp-ecdsa" } -pallet-jobs-rpc-runtime-api = { git = "https://github.com/webb-tools/tangle" } -tangle-primitives = { git = "https://github.com/webb-tools/tangle" } -tangle-testnet-runtime = { git = "https://github.com/webb-tools/tangle" } -tangle-mainnet-runtime = { git = "https://github.com/webb-tools/tangle" } +pallet-jobs-rpc-runtime-api = { git = "https://github.com/webb-tools/tangle", default-features = false } +tangle-primitives = { git = "https://github.com/webb-tools/tangle", default-features = false } +tangle-testnet-runtime = { git = "https://github.com/webb-tools/tangle", default-features = false } +tangle-mainnet-runtime = { git = "https://github.com/webb-tools/tangle", default-features = false } multi-party-ecdsa = { git = "https://github.com/webb-tools/cggmp-threshold-ecdsa/" } round-based = { git = "https://github.com/webb-tools/round-based-protocol", features = [] } curv = { package = "curv-kzen", version = "0.10.0" } -sc-client-api = { git = "https://github.com/paritytech/polkadot-sdk", branch = "release-polkadot-v1.1.0" } -sp-core = { git = "https://github.com/paritytech/polkadot-sdk", branch = "release-polkadot-v1.1.0" } -sp-runtime = { git = "https://github.com/paritytech/polkadot-sdk", branch = "release-polkadot-v1.1.0" } -sc-utils = { git = "https://github.com/paritytech/polkadot-sdk", branch = "release-polkadot-v1.1.0" } -sp-api = { git = "https://github.com/paritytech/polkadot-sdk", branch = "release-polkadot-v1.1.0" } -sp-application-crypto = { git = "https://github.com/paritytech/polkadot-sdk", branch = "release-polkadot-v1.1.0" } -sp-consensus-aura = { git = "https://github.com/paritytech/polkadot-sdk", branch = "release-polkadot-v1.1.0" } -sp-keyring = { git = 
"https://github.com/paritytech/polkadot-sdk", branch = "release-polkadot-v1.1.0" } -sp-timestamp = { git = "https://github.com/paritytech/polkadot-sdk", branch = "release-polkadot-v1.1.0" } -sp-blockchain = { git = "https://github.com/paritytech/polkadot-sdk", branch = "release-polkadot-v1.1.0" } -sp-block-builder = { git = "https://github.com/paritytech/polkadot-sdk", branch = "release-polkadot-v1.1.0" } +sc-client-api = { git = "https://github.com/paritytech/polkadot-sdk", default-features = false, branch = "release-polkadot-v1.1.0" } +sp-core = { git = "https://github.com/paritytech/polkadot-sdk", default-features = false, branch = "release-polkadot-v1.1.0" } +sp-runtime = { git = "https://github.com/paritytech/polkadot-sdk", default-features = false, branch = "release-polkadot-v1.1.0" } +sc-utils = { git = "https://github.com/paritytech/polkadot-sdk", default-features = false, branch = "release-polkadot-v1.1.0" } +sp-api = { git = "https://github.com/paritytech/polkadot-sdk", default-features = false, branch = "release-polkadot-v1.1.0" } +sp-application-crypto = { git = "https://github.com/paritytech/polkadot-sdk", default-features = false, branch = "release-polkadot-v1.1.0" } +sp-keyring = { git = "https://github.com/paritytech/polkadot-sdk", default-features = false, branch = "release-polkadot-v1.1.0" } -sc-offchain = { git = "https://github.com/paritytech/polkadot-sdk", branch = "release-polkadot-v1.1.0" } -sc-basic-authorship = { git = "https://github.com/paritytech/polkadot-sdk", branch = "release-polkadot-v1.1.0" } -sc-consensus-aura = { git = "https://github.com/paritytech/polkadot-sdk", branch = "release-polkadot-v1.1.0" } -sc-consensus-grandpa = { git = "https://github.com/paritytech/polkadot-sdk", branch = "release-polkadot-v1.1.0" } -sc-executor = { git = "https://github.com/paritytech/polkadot-sdk", branch = "release-polkadot-v1.1.0" } sc-network = { git = "https://github.com/paritytech/polkadot-sdk", branch = "release-polkadot-v1.1.0" } 
sc-network-common = { git = "https://github.com/paritytech/polkadot-sdk", branch = "release-polkadot-v1.1.0" } sc-network-sync = { git = "https://github.com/paritytech/polkadot-sdk", branch = "release-polkadot-v1.1.0" } -sc-service = { git = "https://github.com/paritytech/polkadot-sdk", branch = "release-polkadot-v1.1.0" } -sc-telemetry = { git = "https://github.com/paritytech/polkadot-sdk", branch = "release-polkadot-v1.1.0" } -sc-transaction-pool-api = { git = "https://github.com/paritytech/polkadot-sdk", branch = "release-polkadot-v1.1.0" } -sc-cli = { git = "https://github.com/paritytech/polkadot-sdk", branch = "release-polkadot-v1.1.0" } -sc-consensus = { git = "https://github.com/paritytech/polkadot-sdk", branch = "release-polkadot-v1.1.0" } -sc-transaction-pool = { git = "https://github.com/paritytech/polkadot-sdk", branch = "release-polkadot-v1.1.0" } -sc-rpc-api = { git = "https://github.com/paritytech/polkadot-sdk", branch = "release-polkadot-v1.1.0" } parity-scale-codec = "3.6.5" -substrate-build-script-utils = { git = "https://github.com/paritytech/polkadot-sdk", branch = "release-polkadot-v1.1.0" } -pallet-im-online = { git = "https://github.com/paritytech/polkadot-sdk", branch = "release-polkadot-v1.1.0" } -substrate-frame-rpc-system = { git = "https://github.com/paritytech/polkadot-sdk", branch = "release-polkadot-v1.1.0" } -pallet-transaction-payment-rpc = { git = "https://github.com/paritytech/polkadot-sdk", branch = "release-polkadot-v1.1.0" } mpc-net = { git = "https://github.com/webb-tools/zk-SaaS/" } dist-primitives = { git = "https://github.com/webb-tools/zk-SaaS/" } secret-sharing = { git = "https://github.com/webb-tools/zk-SaaS/" } groth16 = { git = "https://github.com/webb-tools/zk-SaaS/" } -tokio-rustls = "0.24.1" +# ARK Libraries +ark-std = { version = "0.4.0", default-features = false, features = [ "print-trace", "std" ] } +ark-crypto-primitives = { version = "0.4.0", default-features = false } +ark-ff = { version = "0.4.2", 
default-features = false } +ark-poly = { version = "0.4.2", default-features = false } +ark-ec = { version = "0.4.2", default-features = false } +ark-relations = { version = "0.4.0", default-features = false } +ark-serialize = { version = "0.4.2", default-features = false, features = [ "derive" ] } +ark-groth16 = { version = "0.4.0", default-features = false } +ark-circom = { git = "https://github.com/webb-tools/ark-circom.git" } +# ARK curves +ark-bn254 = { version = "0.4.0", default-features = false, features = ["curve"] } + +tokio-rustls = "0.24" tokio = "1.32.0" bincode2 = "2" futures-util = "0.3.28" @@ -81,4 +74,4 @@ clap = "4.0.32" hex-literal = "0.4.1" rand = "0.8.5" jsonrpsee = "0.16.2" -linked-hash-map = "0.5.6" \ No newline at end of file +linked-hash-map = "0.5.6" diff --git a/gadget-common/Cargo.toml b/gadget-common/Cargo.toml index 9d1bc634..d867b32e 100644 --- a/gadget-common/Cargo.toml +++ b/gadget-common/Cargo.toml @@ -6,15 +6,8 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [features] -default = [ - "std" -] - -std = [ - "sp-api/std", - "tangle-primitives/std", - "pallet-jobs-rpc-runtime-api/std" -] +default = ["std"] +std = [] [dependencies] gadget-core = { workspace = true, features = ["substrate"] } @@ -32,4 +25,4 @@ bincode2 = { workspace = true } sp-api = { workspace = true, default-features = false } tangle-primitives = { workspace = true, default-features = false } -pallet-jobs-rpc-runtime-api = { workspace = true, default-features = false } \ No newline at end of file +pallet-jobs-rpc-runtime-api = { workspace = true, default-features = false } diff --git a/gadget-common/src/client.rs b/gadget-common/src/client.rs index 0a768d09..36c97852 100644 --- a/gadget-common/src/client.rs +++ b/gadget-common/src/client.rs @@ -2,7 +2,7 @@ use crate::debug_logger::DebugLogger; use crate::keystore::{ECDSAKeyStore, KeystoreBackend}; use async_trait::async_trait; use 
gadget_core::gadget::substrate::Client; -use pallet_jobs_rpc_runtime_api::{JobsApi, RpcResponseJobsData}; +use pallet_jobs_rpc_runtime_api::JobsApi; use sc_client_api::{ Backend, BlockImportNotification, BlockchainEvents, FinalityNotification, HeaderBackend, }; @@ -10,7 +10,7 @@ use sp_api::BlockT as Block; use sp_api::ProvideRuntimeApi; use std::error::Error; use std::sync::Arc; -use tangle_primitives::jobs::{JobId, JobKey, JobResult}; +use tangle_primitives::jobs::{JobId, JobKey, JobResult, RpcResponseJobsData}; pub struct MpEcdsaClient { client: Arc, diff --git a/gadget-core/Cargo.toml b/gadget-core/Cargo.toml index 790dd5db..5bf2cdde 100644 --- a/gadget-core/Cargo.toml +++ b/gadget-core/Cargo.toml @@ -7,8 +7,8 @@ edition = "2021" [features] substrate = [ - "sp-runtime", - "sc-client-api", + "sp-runtime", + "sc-client-api", ] [dependencies] @@ -23,4 +23,4 @@ sp-runtime = { optional = true, workspace = true } sc-client-api = { optional = true, workspace = true } [dev-dependencies] -tokio = { workspace = true, features = ["macros"]} \ No newline at end of file +tokio = { workspace = true, features = ["macros"] } diff --git a/protocols/mp-ecdsa/Cargo.toml b/protocols/mp-ecdsa/Cargo.toml index 842531eb..4710c717 100644 --- a/protocols/mp-ecdsa/Cargo.toml +++ b/protocols/mp-ecdsa/Cargo.toml @@ -20,14 +20,13 @@ itertools = { workspace = true } bincode2 = { workspace = true } linked-hash-map = { workspace = true } -pallet-jobs-rpc-runtime-api = { workspace = true, features = ["std"] } -tangle-primitives = { workspace = true, features = ["std"] } - -sp-core = { workspace = true, features = ["std"] } -sp-api = { workspace = true, features = ["std"] } -sp-runtime = { workspace = true, features = ["std"] } -sp-application-crypto = { workspace = true, features = ["std"] } -sp-consensus-aura = { workspace = true, features = ["std"] } +pallet-jobs-rpc-runtime-api = { workspace = true, default-features = false, features = ["std"] } +tangle-primitives = { workspace = true, 
default-features = false } + +sp-core = { workspace = true } +sp-api = { workspace = true } +sp-runtime = { workspace = true } +sp-application-crypto = { workspace = true } sc-client-api = { workspace = true } sc-network = { workspace = true } diff --git a/zk-gadget/Cargo.toml b/zk-gadget/Cargo.toml index e7d62eb7..e3601118 100644 --- a/zk-gadget/Cargo.toml +++ b/zk-gadget/Cargo.toml @@ -8,6 +8,9 @@ edition = "2021" [dependencies] tokio-rustls = { workspace = true } mpc-net = { workspace = true } +secret-sharing = { workspace = true } +dist-primitives = { workspace = true } +groth16 = { workspace = true } gadget-common = { workspace = true } gadget-core = { workspace = true } bincode2 = { workspace = true } @@ -19,22 +22,33 @@ async-trait = { workspace = true } parking_lot = { workspace = true } log = { workspace = true } bytes = { workspace = true, features = ["serde"] } -sp-runtime = { workspace = true } fflonk = { git = "https://github.com/w3f/fflonk", features = ["std"] } +sp-runtime = { workspace = true } +sp-core = { workspace = true } +sp-api = { workspace = true } + pallet-jobs-rpc-runtime-api = { workspace = true } -tangle-primitives = { workspace = true } +tangle-primitives = { workspace = true, features = ["verifying"] } +ark-serialize = { workspace = true, default-features = false } +ark-groth16 = { workspace = true, default-features = false } +ark-poly = { workspace = true, default-features = false } +ark-relations = { workspace = true, default-features = false } +ark-ec = { workspace = true, default-features = false } +ark-crypto-primitives = { workspace = true, default-features = false } +ark-ff = { workspace = true, default-features = false } +ark-circom = { workspace = true } -sp-core = { workspace = true } -sp-api = { workspace = true } + +ark-bn254 = { workspace = true, default-features = false, features = ["curve"] } [dev-dependencies] sp-runtime = { workspace = true } sc-utils = { workspace = true } sc-client-api = { workspace = true } uuid = { 
workspace = true, features = ["v4"] } -rcgen = "0.11.3" +rcgen = "0.12" parity-scale-codec = { workspace = true, features = ["derive"] } tracing-subscriber = { version = "0.3.17", features = ["env-filter"] } diff --git a/zk-gadget/src/lib.rs b/zk-gadget/src/lib.rs index 6dae4946..990267b8 100644 --- a/zk-gadget/src/lib.rs +++ b/zk-gadget/src/lib.rs @@ -1,6 +1,7 @@ use crate::client_ext::ClientWithApi; use crate::network::ZkNetworkService; -use crate::protocol::{AdditionalProtocolParams, ZkProtocol}; +use crate::protocol::ZkProtocol; +use client_ext::AccountId; use gadget_common::Error; use mpc_net::prod::RustlsCertificate; use pallet_jobs_rpc_runtime_api::JobsApi; @@ -8,7 +9,6 @@ use sp_api::ProvideRuntimeApi; use sp_core::ecdsa; use sp_runtime::traits::Block; use std::net::SocketAddr; -use std::sync::Arc; use tokio_rustls::rustls::{Certificate, PrivateKey, RootCertStore}; pub mod network; @@ -23,23 +23,12 @@ pub struct ZkGadgetConfig { pub public_identity_der: Vec, pub private_identity_der: Vec, pub client_only_king_public_identity_der: Option>, + pub account_id: AccountId, } -pub struct ZkSaaSConfig {} - -pub async fn run(config: ZkSaaSConfig, client: C) -> Result<(), Error> -where - B: Block, - C: ClientWithApi, - >::Api: JobsApi, -{ - let client = Arc::new(client); -} - -pub async fn run_zk_saas, B: Block, T: AdditionalProtocolParams>( +pub async fn run + 'static, B: Block>( config: ZkGadgetConfig, client: C, - extra_parameters: T, ) -> Result<(), Error> where >::Api: JobsApi, @@ -49,9 +38,9 @@ where log::info!("Created zk network for party {}", config.network_id); let zk_protocol = ZkProtocol { - additional_params: extra_parameters, client: client.clone(), _pd: std::marker::PhantomData, + account_id: config.account_id, network: network.clone(), }; @@ -64,6 +53,7 @@ pub async fn create_zk_network(config: &ZkGadgetConfig) -> Result>>, rx: Option>>, } -struct KingRegistryResult { +pub struct KingRegistryResult { tx: Option>, rx: Option>, } diff --git 
a/zk-gadget/src/protocol/mod.rs b/zk-gadget/src/protocol/mod.rs index bf0aa4ed..36618513 100644 --- a/zk-gadget/src/protocol/mod.rs +++ b/zk-gadget/src/protocol/mod.rs @@ -1,7 +1,17 @@ use crate::client_ext::{AccountId, ClientWithApi}; use crate::network::{RegistantId, ZkNetworkService, ZkSetupPacket}; use crate::protocol::proto_gen::ZkAsyncProtocolParameters; +use ark_circom::CircomReduction; +use ark_crypto_primitives::snark::SNARK; +use ark_ec::pairing::Pairing; +use ark_ec::CurveGroup; +use ark_ff::Zero; +use ark_groth16::{Groth16, ProvingKey}; +use ark_poly::{EvaluationDomain, Radix2EvaluationDomain}; +use ark_relations::r1cs::SynthesisError; +use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; use async_trait::async_trait; +use futures_util::TryFutureExt; use gadget_common::gadget::message::GadgetProtocolMessage; use gadget_common::gadget::work_manager::WebbWorkManager; use gadget_common::gadget::{Job, WebbGadgetProtocol}; @@ -9,19 +19,27 @@ use gadget_common::protocol::AsyncProtocol; use gadget_common::{BlockImportNotification, Error, FinalityNotification}; use gadget_core::job::{BuiltExecutableJobWrapper, JobBuilder, JobError}; use gadget_core::job_manager::{ProtocolWorkManager, WorkManagerInterface}; +use groth16::proving_key::PackedProvingKeyShare; +use mpc_net::MultiplexedStreamID; use pallet_jobs_rpc_runtime_api::JobsApi; +use secret_sharing::pss::PackedSharingParams; use sp_api::ProvideRuntimeApi; use sp_runtime::traits::Block; use std::collections::HashMap; -use tangle_primitives::jobs::{JobId, JobKey, JobType}; +use tangle_primitives::jobs::{ + HyperData, JobId, JobKey, JobType, ZkSaaSPhaseTwoRequest, ZkSaaSSystem, +}; +use tangle_primitives::verifier::to_field_elements; use tokio::sync::mpsc::UnboundedReceiver; pub mod proto_gen; -pub struct ZkProtocol { +type F = ark_bn254::Fr; +type E = ark_bn254::Bn254; + +pub struct ZkProtocol { pub client: C, pub account_id: AccountId, - pub additional_params: V, pub network: ZkNetworkService, pub 
_pd: std::marker::PhantomData, } @@ -32,8 +50,7 @@ pub trait AdditionalProtocolParams: Send + Sync + Clone + 'static { } #[async_trait] -impl> WebbGadgetProtocol - for ZkProtocol +impl + 'static> WebbGadgetProtocol for ZkProtocol where >::Api: JobsApi, { @@ -74,6 +91,42 @@ where .latest_retry_id(&task_id) .map(|r| r + 1) .unwrap_or(0); + let phase_one_id = + job.job_type + .get_phase_one_id() + .ok_or_else(|| Error::JobError { + err: JobError { + reason: "Phase one id not found".to_string(), + }, + })?; + let phase_one_job = self + .client + .runtime_api() + .query_phase_one_by_id(notification.hash, JobKey::ZkSaaSCircuit, phase_one_id) + .map_err(|err| crate::Error::ClientError { + err: format!("Failed to query phase one by id: {err:?}"), + })? + .ok_or_else(|| crate::Error::JobError { + err: JobError { + reason: "Phase one job not found".to_string(), + }, + })?; + + let JobType::ZkSaaSPhaseOne(phase_one) = phase_one_job.job_type else { + return Err(Error::JobError { + err: JobError { + reason: "Phase one job type not ZkSaaS".to_string(), + }, + }); + }; + + let JobType::ZkSaaSPhaseTwo(phase_two) = job.job_type else { + return Err(Error::JobError { + err: JobError { + reason: "Phase two job type not ZkSaaS".to_string(), + }, + }); + }; let job_specific_params = ZkJobAdditionalParams { n_parties: participants.len(), @@ -83,17 +136,12 @@ where .expect("Should exist") as _, job_id: job.job_id, job_key: JobKey::ZkSaaSProve, - // TODO: add phase one job data here + system: phase_one.system, + request: phase_two.request, }; let job = self - .create( - session_id, - now, - retry_id, - task_id, - self.additional_params.clone(), - ) + .create(session_id, now, retry_id, task_id, job_specific_params) .await?; ret.push(job); @@ -126,6 +174,8 @@ pub struct ZkJobAdditionalParams { party_id: u32, job_id: JobId, job_key: JobKey, + system: ZkSaaSSystem, + request: ZkSaaSPhaseTwoRequest, } impl AdditionalProtocolParams for ZkJobAdditionalParams { @@ -138,8 +188,7 @@ impl 
AdditionalProtocolParams for ZkJobAdditionalParams { } #[async_trait] -impl, V: AdditionalProtocolParams> AsyncProtocol - for ZkProtocol +impl + 'static> AsyncProtocol for ZkProtocol where >::Api: pallet_jobs_rpc_runtime_api::JobsApi, @@ -152,58 +201,10 @@ where associated_retry_id: ::RetryID, associated_session_id: ::SessionID, associated_task_id: ::TaskID, - mut protocol_message_rx: UnboundedReceiver, + protocol_message_rx: UnboundedReceiver, additional_params: Self::AdditionalParams, ) -> Result { - let mut txs = HashMap::new(); - let mut rxs = HashMap::new(); - for peer_id in 0..additional_params.n_parties() { - // Create 3 multiplexed channels - let mut txs_for_this_peer = vec![]; - let mut rxs_for_this_peer = vec![]; - for _ in 0..3 { - let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); - txs_for_this_peer.push(tx); - rxs_for_this_peer.push(tokio::sync::Mutex::new(rx)); - } - - txs.insert(peer_id as u32, txs_for_this_peer); - rxs.insert(peer_id as u32, rxs_for_this_peer); - } - - tokio::task::spawn(async move { - while let Some(message) = protocol_message_rx.recv().await { - let message: GadgetProtocolMessage = message; - match bincode2::deserialize::(&message.payload) { - Ok(deserialized) => { - let (source, sid) = (deserialized.source, deserialized.sid); - if let Some(txs) = txs.get(&source) { - if let Some(tx) = txs.get(sid as usize) { - if let Err(err) = tx.send(deserialized) { - log::warn!( - "Failed to forward message from {source} to stream {sid:?} because {err:?}", - ); - } - } else { - log::warn!( - "Failed to forward message from {source} to stream {sid:?} because the tx handle was not found", - ); - } - } else { - log::warn!( - "Failed to forward message from {source} to stream {sid:?} because the tx handle was not found", - ); - } - } - Err(err) => { - log::warn!("Failed to deserialize protocol message: {err:?}"); - } - } - } - - log::warn!("Async protocol message_rx died") - }); - + let rxs = zk_setup_rxs(additional_params.n_parties(), 
protocol_message_rx).await?; let other_network_ids = zk_setup_phase( additional_params.n_parties(), &associated_task_id, @@ -211,7 +212,7 @@ where ) .await?; - let params = ZkAsyncProtocolParameters::<_, _, _, B> { + let params = ZkAsyncProtocolParameters:: { associated_block_id, associated_retry_id, associated_session_id, @@ -223,19 +224,261 @@ where other_network_ids, network: self.network.clone(), client: self.client.clone(), - extra_parameters: self.additional_params.clone(), + extra_parameters: additional_params.clone(), _pd: Default::default(), }; Ok(JobBuilder::new() .protocol(async move { - // TODO: build the protocol, using the "params" object as a handle that has the MpcNet implementation + log::debug!( + "Running ZkSaaS for {:?} with JobId: {}", + params.extra_parameters.job_key, + params.extra_parameters.job_id + ); + let ZkSaaSSystem::Groth16(ref system) = params.extra_parameters.system; + let ZkSaaSPhaseTwoRequest::Groth16(ref job) = params.extra_parameters.request; + let HyperData::Raw(ref proving_key_bytes) = system.proving_key else { + return Err(JobError { + reason: "Only raw proving key is supported".to_string(), + }); + }; + + let pk = ProvingKey::::deserialize_compressed(&proving_key_bytes[..]).map_err( + |err| JobError { + reason: format!("Failed to deserialize proving key: {err:?}"), + }, + )?; + let l = params.n_parties / 4; + let pp = PackedSharingParams::new(l); + let crs_shares = + PackedProvingKeyShare::::pack_from_arkworks_proving_key(&pk, pp); + let our_qap_share = + job.qap_shares + .get(params.party_id as usize) + .ok_or_else(|| JobError { + reason: "Failed to get our qap share".to_string(), + })?; + let HyperData::Raw(ref qap_a) = our_qap_share.a else { + return Err(JobError { + reason: "Only raw qap_a is supported".to_string(), + }); + }; + let HyperData::Raw(ref qap_b) = our_qap_share.b else { + return Err(JobError { + reason: "Only raw qap_b is supported".to_string(), + }); + }; + let HyperData::Raw(ref qap_c) = our_qap_share.c 
else { + return Err(JobError { + reason: "Only raw qap_c is supported".to_string(), + }); + }; + let our_a_share = + job.a_shares + .get(params.party_id as usize) + .ok_or_else(|| JobError { + reason: "Failed to get our a share".to_string(), + })?; + let HyperData::Raw(a_share_bytes) = our_a_share else { + return Err(JobError { + reason: "Only raw a_share is supported".to_string(), + }); + }; + + let our_ax_share = + job.ax_shares + .get(params.party_id as usize) + .ok_or_else(|| JobError { + reason: "Failed to get our ax share".to_string(), + })?; + + let HyperData::Raw(ax_share_bytes) = our_ax_share else { + return Err(JobError { + reason: "Only raw ax_share is supported".to_string(), + }); + }; + let m = system.num_inputs + system.num_constraints; + let domain = Radix2EvaluationDomain::::new(m as usize) + .ok_or(SynthesisError::PolynomialDegreeTooLarge) + .map_err(|err| JobError { + reason: format!("Failed to create evaluation domain: {err:?}"), + })?; + let qap_share = groth16::qap::PackedQAPShare { + num_inputs: system.num_inputs as _, + num_constraints: system.num_constraints as _, + a: to_field_elements(qap_a).map_err(|err| JobError { + reason: format!("Failed to convert a to field elements: {err:?}"), + })?, + b: to_field_elements(qap_b).map_err(|err| JobError { + reason: format!("Failed to convert b to field elements: {err:?}"), + })?, + c: to_field_elements(qap_c).map_err(|err| JobError { + reason: format!("Failed to convert c to field elements: {err:?}"), + })?, + domain, + }; + let a_share = to_field_elements(a_share_bytes).map_err(|err| JobError { + reason: format!("Failed to convert a_share to field elements: {err:?}"), + })?; + let ax_share = to_field_elements(ax_share_bytes).map_err(|err| JobError { + reason: format!("Failed to convert ax_share to field elements: {err:?}"), + })?; + let crs_share = + crs_shares + .get(params.party_id as usize) + .ok_or_else(|| JobError { + reason: "Failed to get crs share".to_string(), + })?; + let h_share = 
groth16::ext_wit::circom_h(qap_share, &pp, &params) + .map_err(|err| JobError { + reason: format!("Failed to compute circom_h: {err:?}"), + }) + .await?; + let pi_a_share = groth16::prove::A::<E> { + L: Default::default(), + N: Default::default(), + r: <E as Pairing>::ScalarField::zero(), + pp: &pp, + S: &crs_share.s, + a: &a_share, + } + .compute(&params, MultiplexedStreamID::Zero) + .map_err(|err| JobError { + reason: format!("Failed to compute pi_a_share: {err:?}"), + }) + .await?; + let pi_b_share = groth16::prove::B::<E> { + Z: Default::default(), + K: Default::default(), + s: <E as Pairing>::ScalarField::zero(), + pp: &pp, + V: &crs_share.v, + a: &a_share, + } + .compute(&params, MultiplexedStreamID::Zero) + .map_err(|err| JobError { + reason: format!("Failed to compute pi_b_share: {err:?}"), + }) + .await?; + let pi_c_share = groth16::prove::C::<E> { + W: &crs_share.w, + U: &crs_share.u, + A: pi_a_share, + M: Default::default(), + r: <E as Pairing>::ScalarField::zero(), + s: <E as Pairing>::ScalarField::zero(), + pp: &pp, + H: &crs_share.h, + a: &a_share, + ax: &ax_share, + h: &h_share, + } + .compute(&params) + .map_err(|err| JobError { + reason: format!("Failed to compute pi_c_share: {err:?}"), + }) + .await?; + if params.party_id == 0 { + let (mut a, mut b, c) = (pi_a_share, pi_b_share, pi_c_share); + // These elements are needed to construct the full proof; they are part of the proving key. + // However, we can just send these values to the client, not the full proving key.
+ a += pk.a_query[0] + pk.vk.alpha_g1; + b += pk.b_g2_query[0] + pk.vk.beta_g2; + + let proof = ark_groth16::Proof:: { + a: a.into_affine(), + b: b.into_affine(), + c: c.into_affine(), + }; + + // Verify the proof + // convert the public inputs from string to bigints + let public_inputs = + to_field_elements(&job.public_input).map_err(|err| JobError { + reason: format!( + "Failed to convert public inputs to field elements: {err:?}" + ), + })?; + let pvk = ark_groth16::prepare_verifying_key(&pk.vk); + let verified = Groth16::::verify_with_processed_vk( + &pvk, + public_inputs.as_slice(), + &proof, + ) + .unwrap(); + if verified { + log::info!("Proof verified"); + } else { + log::error!("Proof verification failed"); + } + let mut proof_bytes = Vec::new(); + proof.serialize_compressed(&mut proof_bytes).unwrap(); + // TODO save the proof to the chain + } Ok(()) }) .build()) } } +async fn zk_setup_rxs( + n_parties: usize, + mut protocol_message_rx: UnboundedReceiver, +) -> Result< + HashMap>>>, + JobError, +> { + let mut txs = HashMap::new(); + let mut rxs = HashMap::new(); + for peer_id in 0..n_parties { + // Create 3 multiplexed channels + let mut txs_for_this_peer = vec![]; + let mut rxs_for_this_peer = vec![]; + for _ in 0..3 { + let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); + txs_for_this_peer.push(tx); + rxs_for_this_peer.push(tokio::sync::Mutex::new(rx)); + } + + txs.insert(peer_id as u32, txs_for_this_peer); + rxs.insert(peer_id as u32, rxs_for_this_peer); + } + + tokio::task::spawn(async move { + while let Some(message) = protocol_message_rx.recv().await { + let message: GadgetProtocolMessage = message; + match bincode2::deserialize::(&message.payload) { + Ok(deserialized) => { + let (source, sid) = (deserialized.source, deserialized.sid); + if let Some(txs) = txs.get(&source) { + if let Some(tx) = txs.get(sid as usize) { + if let Err(err) = tx.send(deserialized) { + log::warn!( + "Failed to forward message from {source} to stream {sid:?} because 
{err:?}", + ); + } + } else { + log::warn!( + "Failed to forward message from {source} to stream {sid:?} because the tx handle was not found", + ); + } + } else { + log::warn!( + "Failed to forward message from {source} to stream {sid:?} because the tx handle was not found", + ); + } + } + Err(err) => { + log::warn!("Failed to deserialize protocol message: {err:?}"); + } + } + } + + log::warn!("Async protocol message_rx died") + }); + Ok(rxs) +} + /// The goal of the ZK setup phase it to determine the mapping of party_id -> network_id /// This will allow proper routing of messages to the correct parties. ///