From 9456a5549ba6c5132727bb57d6b10006cc792d6a Mon Sep 17 00:00:00 2001 From: YaelD <70628564+Yael-Starkware@users.noreply.github.com> Date: Mon, 1 Jul 2024 14:19:50 +0300 Subject: [PATCH] test: add deploy_account to integration test (#302) --- Cargo.lock | 1 + Cargo.toml | 1 + crates/gateway/Cargo.toml | 2 +- crates/gateway/src/starknet_api_test_utils.rs | 19 ++++- crates/gateway/src/state_reader_test_utils.rs | 21 +---- crates/tests-integration/Cargo.toml | 1 + crates/tests-integration/src/state_reader.rs | 84 +++++++++++++++---- .../tests/end_to_end_test.rs | 18 ++-- 8 files changed, 101 insertions(+), 46 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 3220ba49..0491c040 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5553,6 +5553,7 @@ dependencies = [ "blockifier 0.7.0-dev.1 (git+https://github.com/starkware-libs/blockifier.git?branch=main-mempool)", "cairo-lang-starknet-classes", "indexmap 2.2.6", + "lazy_static", "papyrus_common", "papyrus_rpc", "papyrus_storage", diff --git a/Cargo.toml b/Cargo.toml index cca58215..63e2a21d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -52,6 +52,7 @@ futures = "0.3.30" hyper = { version = "0.14", features = ["client", "http1", "http2"] } indexmap = "2.1.0" itertools = "0.13.0" +lazy_static = "1.4.0" num-bigint = { version = "0.4.5", default-features = false } # TODO(YaelD, 28/5/2024): The special Papyrus version is needed in order to be aligned with the # starknet-api version. This should be removed once we have a mono-repo. diff --git a/crates/gateway/Cargo.toml b/crates/gateway/Cargo.toml index b66a6ec8..c9514273 100644 --- a/crates/gateway/Cargo.toml +++ b/crates/gateway/Cargo.toml @@ -13,6 +13,7 @@ testing = [] [dependencies] axum.workspace = true +assert_matches.workspace = true blockifier.workspace = true cairo-lang-starknet-classes.workspace = true cairo-vm.workspace = true @@ -30,7 +31,6 @@ tokio.workspace = true validator.workspace = true [dev-dependencies] -assert_matches.workspace = true pretty_assertions.workspace = true rstest.workspace = true starknet_mempool = { path = "../mempool", version = "0.0" } diff --git a/crates/gateway/src/starknet_api_test_utils.rs b/crates/gateway/src/starknet_api_test_utils.rs index 10076c9a..ab5547b9 100644 --- a/crates/gateway/src/starknet_api_test_utils.rs +++ b/crates/gateway/src/starknet_api_test_utils.rs @@ -2,10 +2,13 @@ use std::env; use std::fs::File; use std::path::Path; +use assert_matches::assert_matches; use blockifier::test_utils::contracts::FeatureContract; use blockifier::test_utils::{create_trivial_calldata, CairoVersion, NonceManager}; use serde_json::to_string_pretty; -use starknet_api::core::{ClassHash, CompiledClassHash, ContractAddress, Nonce}; +use starknet_api::core::{ + calculate_contract_address, ClassHash, CompiledClassHash, ContractAddress, Nonce, +}; use starknet_api::data_availability::DataAvailabilityMode; use starknet_api::hash::StarkFelt; use starknet_api::rpc_transaction::{ @@ -380,3 +383,17 @@ pub fn deploy_account_tx() -> RPCTransaction { resource_bounds: executable_resource_bounds_mapping(), )) } + +pub fn deployed_account_contract_address(deploy_tx: &RPCTransaction) -> ContractAddress { + let tx = assert_matches!( + deploy_tx, + RPCTransaction::DeployAccount(RPCDeployAccountTransaction::V3(tx)) => tx + ); + calculate_contract_address( + tx.contract_address_salt, + tx.class_hash, + &tx.constructor_calldata, + ContractAddress::default(), + ) + .unwrap() +} diff --git a/crates/gateway/src/state_reader_test_utils.rs b/crates/gateway/src/state_reader_test_utils.rs index 4c12325a..38258cd9 100644 --- a/crates/gateway/src/state_reader_test_utils.rs +++ b/crates/gateway/src/state_reader_test_utils.rs @@ -1,4 +1,3 @@ -use assert_matches::assert_matches; use blockifier::blockifier::block::BlockInfo; use blockifier::context::{BlockContext, ChainInfo}; use blockifier::execution::contract_class::ContractClass; @@ -10,13 +9,12 @@ use blockifier::test_utils::initial_test_state::{fund_account, test_state_reader use blockifier::test_utils::{CairoVersion, BALANCE}; use blockifier::versioned_constants::VersionedConstants; use starknet_api::block::BlockNumber; -use starknet_api::core::{ - calculate_contract_address, ClassHash, CompiledClassHash, ContractAddress, Nonce, -}; +use starknet_api::core::{ClassHash, CompiledClassHash, ContractAddress, Nonce}; use starknet_api::hash::StarkFelt; -use starknet_api::rpc_transaction::{RPCDeployAccountTransaction, RPCTransaction}; +use starknet_api::rpc_transaction::RPCTransaction; use starknet_api::state::StorageKey; +use crate::starknet_api_test_utils::deployed_account_contract_address; use crate::state_reader::{MempoolStateReader, StateReaderFactory}; #[derive(Clone)] @@ -104,18 +102,7 @@ pub fn local_test_state_reader_factory_for_deploy_account( let mut state_reader_factory = local_test_state_reader_factory(CairoVersion::Cairo1, false); // Fund the deployed_account_address. - let tx = assert_matches!( - deploy_tx, - RPCTransaction::DeployAccount(RPCDeployAccountTransaction::V3(tx)) => tx - ); - - let deployed_account_address = calculate_contract_address( - tx.contract_address_salt, - tx.class_hash, - &tx.constructor_calldata, - ContractAddress::default(), - ) - .unwrap(); + let deployed_account_address = deployed_account_contract_address(deploy_tx); fund_account( BlockContext::create_for_testing().chain_info(), deployed_account_address, diff --git a/crates/tests-integration/Cargo.toml b/crates/tests-integration/Cargo.toml index bfd56dbf..b6e36b7d 100644 --- a/crates/tests-integration/Cargo.toml +++ b/crates/tests-integration/Cargo.toml @@ -13,6 +13,7 @@ axum.workspace = true blockifier.workspace = true cairo-lang-starknet-classes.workspace = true indexmap.workspace = true +lazy_static.workspace = true papyrus_common.workspace = true papyrus_rpc.workspace = true papyrus_storage.workspace = true diff --git a/crates/tests-integration/src/state_reader.rs b/crates/tests-integration/src/state_reader.rs index e886d2ae..91332630 100644 --- a/crates/tests-integration/src/state_reader.rs +++ b/crates/tests-integration/src/state_reader.rs @@ -6,11 +6,12 @@ use blockifier::context::{BlockContext, ChainInfo}; use blockifier::test_utils::contracts::FeatureContract; use blockifier::test_utils::{ CairoVersion, BALANCE, CURRENT_BLOCK_TIMESTAMP, DEFAULT_ETH_L1_GAS_PRICE, - DEFAULT_STRK_L1_GAS_PRICE, + DEFAULT_STRK_L1_GAS_PRICE, TEST_SEQUENCER_ADDRESS, }; use blockifier::transaction::objects::FeeType; use cairo_lang_starknet_classes::casm_contract_class::CasmContractClass; use indexmap::{indexmap, IndexMap}; +use lazy_static::lazy_static; use papyrus_common::pending_classes::PendingClasses; use papyrus_common::BlockHashAndNumber; use papyrus_rpc::{run_server, RpcConfig}; @@ -23,14 +24,17 @@ use papyrus_storage::{open_storage, StorageConfig, StorageReader}; use starknet_api::block::{ BlockBody, BlockHeader, BlockNumber, BlockTimestamp, GasPrice, GasPricePerToken, }; -use starknet_api::core::{ClassHash, ContractAddress}; +use starknet_api::core::{ClassHash, ContractAddress, PatriciaKey, SequencerContractAddress}; use starknet_api::deprecated_contract_class::ContractClass as DeprecatedContractClass; -use starknet_api::hash::StarkFelt; -use starknet_api::stark_felt; +use starknet_api::hash::{StarkFelt, StarkHash}; use starknet_api::state::{StorageKey, ThinStateDiff}; +use starknet_api::{contract_address, patricia_key, stark_felt}; use starknet_client::reader::PendingData; use starknet_gateway::config::RpcStateReaderConfig; use starknet_gateway::rpc_state_reader::RpcStateReaderFactory; +use starknet_gateway::starknet_api_test_utils::{ + deploy_account_tx, deployed_account_contract_address, +}; use strum::IntoEnumIterator; use tempfile::tempdir; use tokio::sync::RwLock; @@ -38,6 +42,13 @@ use tokio::sync::RwLock; type ContractClassesMap = (Vec<(ClassHash, DeprecatedContractClass)>, Vec<(ClassHash, CasmContractClass)>); +lazy_static! { + static ref DEPLOY_ACCCOUNT_TX_CONTRACT_ADDRESS: ContractAddress = { + let deploy_tx = deploy_account_tx(); + deployed_account_contract_address(&deploy_tx) + }; +} + /// StateReader for integration tests. /// /// Creates a papyrus storage reader and Spawns a papyrus rpc server for it. @@ -52,16 +63,20 @@ pub async fn rpc_test_state_reader_factory( let test_contract_cairo0 = FeatureContract::TestContract(CairoVersion::Cairo0); let account_contract_cairo1 = FeatureContract::AccountWithoutValidations(CairoVersion::Cairo1); let test_contract_cairo1 = FeatureContract::TestContract(CairoVersion::Cairo1); + let erc20 = FeatureContract::ERC20; + let fund_accounts = vec![*DEPLOY_ACCCOUNT_TX_CONTRACT_ADDRESS]; let storage_reader = initialize_papyrus_test_state( block_context.chain_info(), BALANCE, &[ + (erc20, 1), (account_contract_cairo0, 1), (test_contract_cairo0, 1), (account_contract_cairo1, n_initialized_account_contracts), (test_contract_cairo1, 1), ], + fund_accounts, ); let addr = run_papyrus_rpc_server(storage_reader).await; @@ -77,8 +92,14 @@ fn initialize_papyrus_test_state( chain_info: &ChainInfo, initial_balances: u128, contract_instances: &[(FeatureContract, u16)], + fund_additional_accounts: Vec, ) -> StorageReader { - let state_diff = prepare_state_diff(chain_info, contract_instances, initial_balances); + let state_diff = prepare_state_diff( + chain_info, + contract_instances, + initial_balances, + fund_additional_accounts, + ); let (cairo0_contract_classes, cairo1_contract_classes) = prepare_compiled_contract_classes(contract_instances); @@ -90,6 +111,7 @@ fn prepare_state_diff( chain_info: &ChainInfo, contract_instances: &[(FeatureContract, u16)], initial_balances: u128, + fund_accounts: Vec, ) -> ThinStateDiff { let erc20 = FeatureContract::ERC20; let erc20_class_hash = erc20.get_class_hash(); @@ -116,10 +138,20 @@ fn prepare_state_diff( } deployed_contracts .insert(contract.get_instance_address(instance), contract.get_class_hash()); - fund_account(&mut storage_diffs, contract, instance, initial_balances, chain_info); + fund_feature_account_contract( + &mut storage_diffs, + contract, + instance, + initial_balances, + chain_info, + ); } } + fund_accounts.iter().for_each(|address| { + fund_account(&mut storage_diffs, address, initial_balances, chain_info) + }); + ThinStateDiff { storage_diffs, deployed_contracts, @@ -132,8 +164,8 @@ fn prepare_state_diff( fn prepare_compiled_contract_classes( contract_instances: &[(FeatureContract, u16)], ) -> ContractClassesMap { - let mut cairo0_contract_classes: Vec<(ClassHash, DeprecatedContractClass)> = Vec::new(); - let mut cairo1_contract_classes: Vec<(ClassHash, CasmContractClass)> = Vec::new(); + let mut cairo0_contract_classes = Vec::new(); + let mut cairo1_contract_classes = Vec::new(); for (contract, _) in contract_instances.iter() { match cairo_version(contract) { CairoVersion::Cairo0 => { @@ -196,6 +228,7 @@ fn cairo_version(contract: &FeatureContract) -> CairoVersion { | FeatureContract::Empty(version) | FeatureContract::FaultyAccount(version) | FeatureContract::TestContract(version) => *version, + FeatureContract::ERC20 => CairoVersion::Cairo0, _ => panic!("{contract:?} contract has no configurable version."), } } @@ -203,6 +236,7 @@ fn cairo_version(contract: &FeatureContract) -> CairoVersion { fn test_block_header(block_number: BlockNumber) -> BlockHeader { BlockHeader { block_number, + sequencer: SequencerContractAddress(contract_address!(TEST_SEQUENCER_ADDRESS)), l1_gas_price: GasPricePerToken { price_in_wei: GasPrice(DEFAULT_ETH_L1_GAS_PRICE), price_in_fri: GasPrice(DEFAULT_STRK_L1_GAS_PRICE), @@ -216,7 +250,7 @@ fn test_block_header(block_number: BlockNumber) -> BlockHeader { } } -fn fund_account( +fn fund_feature_account_contract( storage_diffs: &mut IndexMap>, contract: &FeatureContract, instance: u16, @@ -227,20 +261,34 @@ fn fund_account( FeatureContract::AccountWithLongValidate(_) | FeatureContract::AccountWithoutValidations(_) | FeatureContract::FaultyAccount(_) => { - let key_value = indexmap! { - get_fee_token_var_address(contract.get_instance_address(instance)) => stark_felt!(initial_balances), - }; - for fee_type in FeeType::iter() { - storage_diffs - .entry(chain_info.fee_token_address(&fee_type)) - .or_default() - .extend(key_value.clone()); - } + fund_account( + storage_diffs, + &contract.get_instance_address(instance), + initial_balances, + chain_info, + ); } _ => (), } } +fn fund_account( + storage_diffs: &mut IndexMap>, + account_address: &ContractAddress, + initial_balances: u128, + chain_info: &ChainInfo, +) { + let key_value = indexmap! { + get_fee_token_var_address(*account_address) => stark_felt!(initial_balances), + }; + for fee_type in FeeType::iter() { + storage_diffs + .entry(chain_info.fee_token_address(&fee_type)) + .or_default() + .extend(key_value.clone()); + } +} + // TODO(Yael 5/6/2024): remove this function and use the one from papyrus test utils once we have // mono-repo. fn get_test_highest_block() -> Arc>> { diff --git a/crates/tests-integration/tests/end_to_end_test.rs b/crates/tests-integration/tests/end_to_end_test.rs index e68e0998..3364acbc 100644 --- a/crates/tests-integration/tests/end_to_end_test.rs +++ b/crates/tests-integration/tests/end_to_end_test.rs @@ -1,24 +1,24 @@ use blockifier::test_utils::CairoVersion; use starknet_api::transaction::TransactionHash; -use starknet_gateway::starknet_api_test_utils::invoke_tx; +use starknet_gateway::starknet_api_test_utils::{deploy_account_tx, invoke_tx}; use starknet_mempool_integration_tests::integration_test_setup::IntegrationTestSetup; #[tokio::test] async fn test_end_to_end() { let mut mock_running_system = IntegrationTestSetup::new(1).await; - let mut expected_tx_hashs = Vec::new(); - expected_tx_hashs + let mut expected_tx_hashes = Vec::new(); + expected_tx_hashes .push(mock_running_system.assert_add_tx_success(&invoke_tx(CairoVersion::Cairo0)).await); - expected_tx_hashs + expected_tx_hashes .push(mock_running_system.assert_add_tx_success(&invoke_tx(CairoVersion::Cairo1)).await); + expected_tx_hashes.push(mock_running_system.assert_add_tx_success(&deploy_account_tx()).await); - let mempool_txs = mock_running_system.get_txs(3).await; - - assert_eq!(mempool_txs.len(), 2); + let mempool_txs = mock_running_system.get_txs(4).await; + assert_eq!(mempool_txs.len(), 3); let mut actual_tx_hashes: Vec = mempool_txs.iter().map(|tx| tx.tx_hash).collect(); actual_tx_hashes.sort(); - expected_tx_hashs.sort(); - assert_eq!(expected_tx_hashs, actual_tx_hashes); + expected_tx_hashes.sort(); + assert_eq!(expected_tx_hashes, actual_tx_hashes); }