Skip to content

Commit

Permalink
chore: add the gateway compiler to appstate (#391)
Browse files Browse the repository at this point in the history
  • Loading branch information
ArniStarkware authored Jul 14, 2024
1 parent ffb8b05 commit 47925f1
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 16 deletions.
1 change: 1 addition & 0 deletions crates/gateway/src/compilation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use crate::utils::is_subsequence;
#[path = "compilation_test.rs"]
mod compilation_test;

#[derive(Clone)]
pub struct GatewayCompiler {
#[allow(dead_code)]
pub config: GatewayCompilerConfig,
Expand Down
2 changes: 1 addition & 1 deletion crates/gateway/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ impl StatefulTransactionValidatorConfig {
}
}

#[derive(Clone, Debug, Default, Serialize, Deserialize, Validate, PartialEq)]
#[derive(Clone, Copy, Debug, Default, Serialize, Deserialize, Validate, PartialEq)]
pub struct GatewayCompilerConfig {}

impl SerializeConfig for GatewayCompilerConfig {
Expand Down
9 changes: 7 additions & 2 deletions crates/gateway/src/gateway.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,15 @@ pub struct AppState {
pub stateless_tx_validator: StatelessTransactionValidator,
pub stateful_tx_validator: Arc<StatefulTransactionValidator>,
pub state_reader_factory: Arc<dyn StateReaderFactory>,
pub gateway_compiler: GatewayCompiler,
pub mempool_client: SharedMempoolClient,
}

impl Gateway {
pub fn new(
config: GatewayConfig,
state_reader_factory: Arc<dyn StateReaderFactory>,
gateway_compiler: GatewayCompiler,
mempool_client: SharedMempoolClient,
) -> Self {
let app_state = AppState {
Expand All @@ -53,6 +55,7 @@ impl Gateway {
config: config.stateful_tx_validator_config.clone(),
}),
state_reader_factory,
gateway_compiler,
mempool_client,
};
Gateway { config, app_state }
Expand Down Expand Up @@ -93,6 +96,7 @@ async fn add_tx(
app_state.stateless_tx_validator,
app_state.stateful_tx_validator.as_ref(),
app_state.state_reader_factory.as_ref(),
app_state.gateway_compiler,
tx,
)
})
Expand All @@ -113,6 +117,7 @@ fn process_tx(
stateless_tx_validator: StatelessTransactionValidator,
stateful_tx_validator: &StatefulTransactionValidator,
state_reader_factory: &dyn StateReaderFactory,
gateway_compiler: GatewayCompiler,
tx: RPCTransaction,
) -> GatewayResult<MempoolInput> {
// TODO(Arni, 1/5/2024): Perform congestion control.
Expand All @@ -123,7 +128,6 @@ fn process_tx(
// Compile Sierra to Casm.
let optional_class_info = match &tx {
RPCTransaction::Declare(declare_tx) => {
let gateway_compiler = GatewayCompiler { config: Default::default() };
Some(gateway_compiler.compile_contract_class(declare_tx)?)
}
_ => None,
Expand All @@ -145,7 +149,8 @@ pub fn create_gateway(
client: SharedMempoolClient,
) -> Gateway {
let state_reader_factory = Arc::new(RpcStateReaderFactory { config: rpc_state_reader_config });
Gateway::new(config, state_reader_factory, client)
let gateway_compiler = GatewayCompiler { config: config.compiler_config };
Gateway::new(config, state_reader_factory, gateway_compiler, client)
}

#[async_trait]
Expand Down
20 changes: 12 additions & 8 deletions crates/gateway/src/gateway_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ use starknet_mempool_types::communication::{MempoolClientImpl, MempoolRequestAnd
use tokio::sync::mpsc::channel;
use tokio::task;

use crate::config::{StatefulTransactionValidatorConfig, StatelessTransactionValidatorConfig};
use crate::config::{
GatewayCompilerConfig, StatefulTransactionValidatorConfig, StatelessTransactionValidatorConfig,
};
use crate::gateway::{add_tx, AppState, GatewayCompiler, SharedMempoolClient};
use crate::state_reader_test_utils::{
local_test_state_reader_factory, local_test_state_reader_factory_for_deploy_account,
Expand Down Expand Up @@ -52,6 +54,7 @@ pub fn app_state(
stateful_tx_validator: Arc::new(StatefulTransactionValidator {
config: StatefulTransactionValidatorConfig::create_for_testing(),
}),
gateway_compiler: GatewayCompiler { config: GatewayCompilerConfig {} },
state_reader_factory: Arc::new(state_reader_factory),
mempool_client,
}
Expand Down Expand Up @@ -94,7 +97,7 @@ async fn test_add_tx(

let app_state = app_state(mempool_client, state_reader_factory);

let tx_hash = calculate_hash(&tx);
let tx_hash = calculate_hash(&tx, &app_state.gateway_compiler);
let response = add_tx(State(app_state), tx.into()).await.into_response();

let status_code = response.status();
Expand All @@ -108,13 +111,14 @@ async fn to_bytes(res: Response) -> Bytes {
res.into_body().collect().await.unwrap().to_bytes()
}

fn calculate_hash(external_tx: &RPCTransaction) -> TransactionHash {
fn calculate_hash(
external_tx: &RPCTransaction,
gateway_compiler: &GatewayCompiler,
) -> TransactionHash {
let optional_class_info = match &external_tx {
RPCTransaction::Declare(declare_tx) => Some(
GatewayCompiler { config: Default::default() }
.compile_contract_class(declare_tx)
.unwrap(),
),
RPCTransaction::Declare(declare_tx) => {
Some(gateway_compiler.compile_contract_class(declare_tx).unwrap())
}
_ => None,
};

Expand Down
11 changes: 6 additions & 5 deletions crates/gateway/src/stateful_transaction_validator_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use starknet_api::transaction::TransactionHash;
use starknet_types_core::felt::Felt;

use crate::compilation::GatewayCompiler;
use crate::config::StatefulTransactionValidatorConfig;
use crate::config::{GatewayCompilerConfig, StatefulTransactionValidatorConfig};
use crate::errors::{StatefulTransactionValidatorError, StatefulTransactionValidatorResult};
use crate::state_reader_test_utils::{
local_test_state_reader_factory, local_test_state_reader_factory_for_deploy_account,
Expand Down Expand Up @@ -95,10 +95,11 @@ fn test_stateful_tx_validator(
stateful_validator: StatefulTransactionValidator,
) {
let optional_class_info = match &external_tx {
RPCTransaction::Declare(declare_tx) => {
let gateway_compiler = GatewayCompiler { config: Default::default() };
Some(gateway_compiler.compile_contract_class(declare_tx).unwrap())
}
RPCTransaction::Declare(declare_tx) => Some(
GatewayCompiler { config: GatewayCompilerConfig {} }
.compile_contract_class(declare_tx)
.unwrap(),
),
_ => None,
};

Expand Down

0 comments on commit 47925f1

Please sign in to comment.