Skip to content

Commit

Permalink
refactor: gateway compiler handle declare tx
Browse files Browse the repository at this point in the history
  • Loading branch information
ArniStarkware committed Jul 18, 2024
1 parent ab3a562 commit 989780d
Show file tree
Hide file tree
Showing 10 changed files with 8,095 additions and 67 deletions.
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

70 changes: 40 additions & 30 deletions crates/gateway/src/compilation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use blockifier::execution::contract_class::{ClassInfo, ContractClass, ContractCl
use cairo_lang_starknet_classes::casm_contract_class::{
CasmContractClass, CasmContractEntryPoints,
};
use cairo_lang_starknet_classes::contract_class::ContractClass as CairoLangContractClass;
use starknet_api::core::CompiledClassHash;
use starknet_api::rpc_transaction::RPCDeclareTransaction;
use starknet_sierra_compile::compile::compile_sierra_to_casm;
Expand All @@ -29,44 +30,37 @@ impl GatewayCompiler {
/// Formats the contract class for compilation, compiles it, and returns the compiled contract
/// class wrapped in a [`ClassInfo`].
/// Assumes the contract class is of a Sierra program which is compiled to Casm.
pub fn compile_contract_class(
pub fn process_declare_tx(
&self,
declare_tx: &RPCDeclareTransaction,
) -> GatewayResult<ClassInfo> {
let RPCDeclareTransaction::V3(tx) = declare_tx;
let starknet_api_contract_class = &tx.contract_class;
let cairo_lang_contract_class =
into_contract_class_for_compilation(starknet_api_contract_class);
let rpc_contract_class = &tx.contract_class;
let cairo_lang_contract_class = into_contract_class_for_compilation(rpc_contract_class);

// Compile Sierra to Casm.
let catch_unwind_result =
panic::catch_unwind(|| compile_sierra_to_casm(cairo_lang_contract_class));
let casm_contract_class = match catch_unwind_result {
Ok(compilation_result) => compilation_result?,
Err(_) => {
// TODO(Arni): Log the panic.
return Err(GatewayError::CompilationError(CompilationUtilError::CompilationPanic));
}
};
let casm_contract_class = self.compile(cairo_lang_contract_class)?;

validate_compiled_class_hash(&casm_contract_class, &tx.compiled_class_hash)?;
self.validate_casm_class(&casm_contract_class)?;

let hash_result = CompiledClassHash(casm_contract_class.compiled_class_hash());
if hash_result != tx.compiled_class_hash {
return Err(GatewayError::CompiledClassHashMismatch {
supplied: tx.compiled_class_hash,
hash_result,
});
}
Ok(ClassInfo::new(
&ContractClass::V1(ContractClassV1::try_from(casm_contract_class)?),
rpc_contract_class.sierra_program.len(),
rpc_contract_class.abi.len(),
)?)
}

// Convert Casm contract class to Starknet contract class directly.
let blockifier_contract_class =
ContractClass::V1(ContractClassV1::try_from(casm_contract_class)?);
let class_info = ClassInfo::new(
&blockifier_contract_class,
starknet_api_contract_class.sierra_program.len(),
starknet_api_contract_class.abi.len(),
)?;
Ok(class_info)
/// TODO(Arni): Pass the compilation args from the config.
fn compile(
&self,
cairo_lang_contract_class: CairoLangContractClass,
) -> Result<CasmContractClass, GatewayError> {
let catch_unwind_result =
panic::catch_unwind(|| compile_sierra_to_casm(cairo_lang_contract_class));
let casm_contract_class =
catch_unwind_result.map_err(|_| CompilationUtilError::CompilationPanic)??;

Ok(casm_contract_class)
}

// TODO(Arni): Add test.
Expand Down Expand Up @@ -101,3 +95,19 @@ fn supported_builtins() -> &'static Vec<String> {
SUPPORTED_BUILTIN_NAMES.iter().map(|builtin| builtin.to_string()).collect::<Vec<String>>()
})
}

/// Validates that the compiled class hash of the compiled contract class matches the supplied
/// compiled class hash.
fn validate_compiled_class_hash(
casm_contract_class: &CasmContractClass,
supplied_compiled_class_hash: &CompiledClassHash,
) -> Result<(), GatewayError> {
let compiled_class_hash = CompiledClassHash(casm_contract_class.compiled_class_hash());
if compiled_class_hash != *supplied_compiled_class_hash {
return Err(GatewayError::CompiledClassHashMismatch {
supplied: *supplied_compiled_class_hash,
hash_result: compiled_class_hash,
});
}
Ok(())
}
53 changes: 27 additions & 26 deletions crates/gateway/src/compilation_test.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
use assert_matches::assert_matches;
use blockifier::execution::contract_class::ContractClass;
use cairo_lang_starknet_classes::allowed_libfuncs::AllowedLibfuncsError;
use mempool_test_utils::starknet_api_test_utils::declare_tx;
use cairo_lang_starknet_classes::casm_contract_class::CasmContractClass;
use mempool_test_utils::starknet_api_test_utils::{
casm_contract_class, compiled_class_hash, contract_class, declare_tx,
};
use rstest::{fixture, rstest};
use starknet_api::core::CompiledClassHash;
use starknet_api::rpc_transaction::{RPCDeclareTransaction, RPCTransaction};
use starknet_api::rpc_transaction::{
ContractClass as RpcContractClass, RPCDeclareTransaction, RPCTransaction,
};
use starknet_sierra_compile::errors::CompilationUtilError;
use starknet_sierra_compile::utils::into_contract_class_for_compilation;

use crate::compilation::GatewayCompiler;
use crate::compilation::{validate_compiled_class_hash, GatewayCompiler};
use crate::errors::GatewayError;

#[fixture]
Expand All @@ -16,36 +22,31 @@ fn gateway_compiler() -> GatewayCompiler {
}

#[rstest]
fn test_compile_contract_class_compiled_class_hash_missmatch(gateway_compiler: GatewayCompiler) {
let mut tx = assert_matches!(
declare_tx(),
RPCTransaction::Declare(RPCDeclareTransaction::V3(tx)) => tx
);
let expected_hash_result = tx.compiled_class_hash;
let supplied_hash = CompiledClassHash::default();

tx.compiled_class_hash = supplied_hash;
let declare_tx = RPCDeclareTransaction::V3(tx);
fn test_compile_contract_class_compiled_class_hash_mismatch(
casm_contract_class: CasmContractClass,
compiled_class_hash: CompiledClassHash,
) {
let wrong_supplied_hash = CompiledClassHash::default();
let expected_hash = compiled_class_hash;

let result = gateway_compiler.compile_contract_class(&declare_tx);
let result = validate_compiled_class_hash(&casm_contract_class, &wrong_supplied_hash);
assert_matches!(
result.unwrap_err(),
GatewayError::CompiledClassHashMismatch { supplied, hash_result }
if supplied == supplied_hash && hash_result == expected_hash_result
if supplied == wrong_supplied_hash && hash_result == expected_hash
);
}

#[rstest]
fn test_compile_contract_class_bad_sierra(gateway_compiler: GatewayCompiler) {
let mut tx = assert_matches!(
declare_tx(),
RPCTransaction::Declare(RPCDeclareTransaction::V3(tx)) => tx
);
// Truncate the sierra program to trigger an error.
tx.contract_class.sierra_program = tx.contract_class.sierra_program[..100].to_vec();
let declare_tx = RPCDeclareTransaction::V3(tx);
fn test_compile_contract_class_bad_sierra(
gateway_compiler: GatewayCompiler,
mut contract_class: RpcContractClass,
) {
// Create a corrupted contract class.
contract_class.sierra_program = contract_class.sierra_program[..100].to_vec();

let result = gateway_compiler.compile_contract_class(&declare_tx);
let cairo_lang_contract_class = into_contract_class_for_compilation(&contract_class);
let result = gateway_compiler.compile(cairo_lang_contract_class);
assert_matches!(
result.unwrap_err(),
GatewayError::CompilationError(CompilationUtilError::AllowedLibfuncsError(
Expand All @@ -55,15 +56,15 @@ fn test_compile_contract_class_bad_sierra(gateway_compiler: GatewayCompiler) {
}

#[rstest]
fn test_compile_contract_class(gateway_compiler: GatewayCompiler) {
fn test_process_declare_tx(gateway_compiler: GatewayCompiler) {
let declare_tx = assert_matches!(
declare_tx(),
RPCTransaction::Declare(declare_tx) => declare_tx
);
let RPCDeclareTransaction::V3(declare_tx_v3) = &declare_tx;
let contract_class = &declare_tx_v3.contract_class;

let class_info = gateway_compiler.compile_contract_class(&declare_tx).unwrap();
let class_info = gateway_compiler.process_declare_tx(&declare_tx).unwrap();
assert_matches!(class_info.contract_class(), ContractClass::V1(_));
assert_eq!(class_info.sierra_program_length(), contract_class.sierra_program.len());
assert_eq!(class_info.abi_length(), contract_class.abi.len());
Expand Down
2 changes: 1 addition & 1 deletion crates/gateway/src/gateway.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ fn process_tx(
// Compile Sierra to Casm.
let optional_class_info = match &tx {
RPCTransaction::Declare(declare_tx) => {
Some(gateway_compiler.compile_contract_class(declare_tx)?)
Some(gateway_compiler.process_declare_tx(declare_tx)?)
}
_ => None,
};
Expand Down
2 changes: 1 addition & 1 deletion crates/gateway/src/stateful_transaction_validator_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ fn test_stateful_tx_validator(
let optional_class_info = match &external_tx {
RPCTransaction::Declare(declare_tx) => Some(
GatewayCompiler { config: GatewayCompilerConfig {} }
.compile_contract_class(declare_tx)
.process_declare_tx(declare_tx)
.unwrap(),
),
_ => None,
Expand Down
2 changes: 2 additions & 0 deletions crates/mempool_test_utils/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ license.workspace = true
[dependencies]
assert_matches.workspace = true
blockifier = { workspace = true, features = ["testing"] }
cairo-lang-starknet-classes.workspace = true
rstest.workspace = true
starknet-types-core.workspace = true
starknet_api.workspace = true
serde_json.workspace = true
Expand Down
1 change: 1 addition & 0 deletions crates/mempool_test_utils/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ pub mod starknet_api_test_utils;

pub const TEST_FILES_FOLDER: &str = "crates/mempool_test_utils/test_files";
pub const CONTRACT_CLASS_FILE: &str = "contract_class.json";
pub const CASM_CONTRACT_CLASS_FILE: &str = "casm_contract_class.json";
pub const COMPILED_CLASS_HASH_OF_CONTRACT_CLASS: &str =
"0x01e4f1248860f32c336f93f2595099aaa4959be515e40b75472709ef5243ae17";
pub const FAULTY_ACCOUNT_CLASS_FILE: &str = "faulty_account.sierra.json";
Expand Down
31 changes: 27 additions & 4 deletions crates/mempool_test_utils/src/starknet_api_test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ use std::rc::Rc;
use assert_matches::assert_matches;
use blockifier::test_utils::contracts::FeatureContract;
use blockifier::test_utils::{create_trivial_calldata, CairoVersion, NonceManager};
use cairo_lang_starknet_classes::casm_contract_class::CasmContractClass;
use rstest::fixture;
use serde_json::to_string_pretty;
use starknet_api::core::{
calculate_contract_address, ClassHash, CompiledClassHash, ContractAddress, Nonce,
Expand All @@ -26,7 +28,8 @@ use starknet_types_core::felt::Felt;

use crate::{
declare_tx_args, deploy_account_tx_args, get_absolute_path, invoke_tx_args,
COMPILED_CLASS_HASH_OF_CONTRACT_CLASS, CONTRACT_CLASS_FILE, TEST_FILES_FOLDER,
CASM_CONTRACT_CLASS_FILE, COMPILED_CLASS_HASH_OF_CONTRACT_CLASS, CONTRACT_CLASS_FILE,
TEST_FILES_FOLDER,
};

pub const VALID_L1_GAS_MAX_AMOUNT: u64 = 203484;
Expand Down Expand Up @@ -90,11 +93,31 @@ pub fn executable_resource_bounds_mapping() -> ResourceBoundsMapping {
)
}

pub fn declare_tx() -> RPCTransaction {
/// Get the contract class used for testing.
#[fixture]
pub fn contract_class() -> ContractClass {
env::set_current_dir(get_absolute_path(TEST_FILES_FOLDER)).expect("Couldn't set working dir.");
let json_file_path = Path::new(CONTRACT_CLASS_FILE);
let contract_class = serde_json::from_reader(File::open(json_file_path).unwrap()).unwrap();
let compiled_class_hash = CompiledClassHash(felt!(COMPILED_CLASS_HASH_OF_CONTRACT_CLASS));
serde_json::from_reader(File::open(json_file_path).unwrap()).unwrap()
}

/// Get the casm contract class corresponding to the contract class used for testing.
#[fixture]
pub fn casm_contract_class() -> CasmContractClass {
env::set_current_dir(get_absolute_path(TEST_FILES_FOLDER)).expect("Couldn't set working dir.");
let json_file_path = Path::new(CASM_CONTRACT_CLASS_FILE);
serde_json::from_reader(File::open(json_file_path).unwrap()).unwrap()
}

/// Get the compiled class hash corresponding to the contract class used for testing.
#[fixture]
pub fn compiled_class_hash() -> CompiledClassHash {
CompiledClassHash(felt!(COMPILED_CLASS_HASH_OF_CONTRACT_CLASS))
}

pub fn declare_tx() -> RPCTransaction {
let contract_class = contract_class();
let compiled_class_hash = compiled_class_hash();

let account_contract = FeatureContract::AccountWithoutValidations(CairoVersion::Cairo1);
let account_address = account_contract.get_instance_address(0);
Expand Down
Loading

0 comments on commit 989780d

Please sign in to comment.