diff --git a/corelib/src/starknet/storage/vec.cairo b/corelib/src/starknet/storage/vec.cairo index a8fa4ca085d..0fdbf204e5b 100644 --- a/corelib/src/starknet/storage/vec.cairo +++ b/corelib/src/starknet/storage/vec.cairo @@ -321,6 +321,55 @@ pub trait MutableVecTrait { /// } /// ``` fn append(self: T) -> StoragePath>; + + /// Pushes a new value onto the vector. + /// + /// This operation: + /// 1. Increments the vector's length. + /// 2. Writes the provided value to the new storage location at the end of the vector. + /// + /// # Examples + /// + /// ``` + /// use core::starknet::storage::{Vec, MutableVecTrait}; + /// + /// #[storage] + /// struct Storage { + /// numbers: Vec, + /// } + /// + /// fn push_number(ref self: ContractState, number: u256) { + /// self.numbers.push(number); + /// } + /// ``` + fn push<+Drop, +starknet::Store>( + self: T, value: Self::ElementType, + ); + + /// Pops the last value off the vector. + /// + /// This operation: + /// 1. Retrieves the value stored at the last position in the vector. + /// 2. Decrements the vector's length. + /// 3. Returns the retrieved value or `None` if the vector is empty. + /// + /// # Examples + /// + /// ``` + /// use core::starknet::storage::{Vec, MutableVecTrait}; + /// + /// #[storage] + /// struct Storage { + /// numbers: Vec, + /// } + /// + /// fn pop_number(ref self: ContractState) -> Option { + /// self.numbers.pop() + /// } + /// ``` + fn pop<+Drop, +starknet::Store>( + self: T, + ) -> Option; } /// Implement `MutableVecTrait` for `StoragePath>`. @@ -350,8 +399,34 @@ impl MutableVecImpl of MutableVecTrait>>> { self.as_ptr().write(vec_len + 1); self.update(vec_len) } -} + fn push<+Drop, +starknet::Store>( + self: StoragePath>>, value: Self::ElementType, + ) { + self.append().write(value); + } + + fn pop<+Drop, +starknet::Store>( + self: StoragePath>>, + ) -> Option { + let len_ptr = self.as_ptr(); + let vec_len: u64 = len_ptr.read(); + if vec_len == 0 { + return None; + } + let entry: StoragePath> = self.update(vec_len - 1); + let last_element = entry.read(); + // Remove the element's data from the storage. + let entry_ptr = entry.as_ptr(); + starknet::SyscallResultTrait::unwrap_syscall( + starknet::Store::< + Self::ElementType, + >::scrub(0, entry_ptr.__storage_pointer_address__, 0), + ); + len_ptr.write(vec_len - 1); + Some(last_element) + } +} /// Implement `MutableVecTrait` for any type that implements StorageAsPath into a storage /// path that implements MutableVecTrait. impl PathableMutableVecImpl< @@ -377,6 +452,18 @@ impl PathableMutableVecImpl< fn append(self: T) -> StoragePath> { self.as_path().append() } + + fn push<+Drop, +starknet::Store>( + self: T, value: Self::ElementType, + ) { + self.as_path().push(value) + } + + fn pop<+Drop, +starknet::Store>( + self: T, + ) -> Option { + self.as_path().pop() + } } pub impl VecIndexView< diff --git a/corelib/src/starknet/storage_access.cairo b/corelib/src/starknet/storage_access.cairo index 1ef80f83476..e009eecb48b 100644 --- a/corelib/src/starknet/storage_access.cairo +++ b/corelib/src/starknet/storage_access.cairo @@ -203,6 +203,37 @@ pub trait Store { /// This is bounded to 255, as the offset is a u8. As such, a single type can only take up to /// 255 slots in storage. fn size() -> u8; + + /// Clears the storage area by writing zeroes to it. + /// + /// # Arguments + /// + /// * `address_domain` - The storage domain + /// * `base` - The base storage address to start clearing + /// * `offset` - The offset from the base address where clearing should start + /// + /// The operation writes zeroes to storage starting from the specified base address and offset, + /// and continues for the size of the type as determined by the `size()` function. + #[inline] + fn scrub( + address_domain: u32, base: StorageBaseAddress, offset: u8, + ) -> SyscallResult< + (), + > { + let mut result = Result::Ok(()); + let mut offset = offset; + for _ in 0..Self::size() { + if let Result::Err(err) = + storage_write_syscall( + address_domain, storage_address_from_base_and_offset(base, offset), 0, + ) { + result = Result::Err(err); + break; + } + offset += 1; + }; + result + } } /// Trait for efficient packing of values into optimized storage representations. diff --git a/crates/bin/cairo-execute/src/main.rs b/crates/bin/cairo-execute/src/main.rs index 38e4922d9dc..3d7cf043154 100644 --- a/crates/bin/cairo-execute/src/main.rs +++ b/crates/bin/cairo-execute/src/main.rs @@ -39,7 +39,7 @@ struct Args { /// In `--build-only` this would be the executable artifact. /// In bootloader mode it will be the resulting cairo PIE file. /// In standalone mode this parameter is disallowed. - #[clap(long, required_if_eq("standalone", "false"))] + #[clap(long, required_unless_present("standalone"))] output_path: Option, /// Whether to only run a prebuilt executable. @@ -94,7 +94,7 @@ struct RunArgs { long, default_value_t = false, conflicts_with_all = ["build_only", "output_path"], - requires_all=["trace_file", "memory_file", "air_public_input", "air_private_input"], + requires_all=["air_public_input", "air_private_input"], )] standalone: bool, /// If set, the program will be run in secure mode. @@ -104,7 +104,7 @@ struct RunArgs { #[clap(long)] allow_missing_builtins: Option, #[clap(flatten)] - standalone_outputs: StandaloneOutputArgs, + proof_outputs: ProofOutputArgs, } #[derive(Parser, Debug)] @@ -123,19 +123,20 @@ struct SerializedArgs { as_file: Option, } +/// The arguments for output files required for creating a proof. #[derive(Parser, Debug)] -struct StandaloneOutputArgs { +struct ProofOutputArgs { /// The resulting trace file. - #[clap(long, conflicts_with = "build_only", requires = "standalone")] + #[clap(long, conflicts_with = "build_only")] trace_file: Option, /// The resulting memory file. - #[clap(long, conflicts_with = "build_only", requires = "standalone")] + #[clap(long, conflicts_with = "build_only")] memory_file: Option, /// The resulting AIR public input file. - #[clap(long, conflicts_with = "build_only", requires = "standalone")] + #[clap(long, conflicts_with = "build_only")] air_public_input: Option, /// The resulting AIR private input file. - #[clap(long, conflicts_with = "build_only", requires = "standalone")] + #[clap(long, conflicts_with = "build_only", requires_all=["trace_file", "memory_file"])] air_private_input: Option, } @@ -230,8 +231,8 @@ fn main() -> anyhow::Result<()> { }; let cairo_run_config = CairoRunConfig { - trace_enabled: args.run.standalone, - relocate_mem: args.run.standalone, + trace_enabled: args.run.proof_outputs.trace_file.is_some(), + relocate_mem: args.run.proof_outputs.memory_file.is_some(), layout: args.run.layout, proof_mode: args.run.standalone, secure_run: args.run.secure_run, @@ -257,7 +258,7 @@ fn main() -> anyhow::Result<()> { } } - if let Some(trace_path) = &args.run.standalone_outputs.trace_file { + if let Some(trace_path) = &args.run.proof_outputs.trace_file { let relocated_trace = runner.relocated_trace.as_ref().with_context(|| "Trace not relocated.")?; let mut writer = FileWriter::new(3 * 1024 * 1024, trace_path)?; @@ -265,21 +266,21 @@ fn main() -> anyhow::Result<()> { writer.flush()?; } - if let Some(memory_path) = &args.run.standalone_outputs.memory_file { + if let Some(memory_path) = &args.run.proof_outputs.memory_file { let mut writer = FileWriter::new(5 * 1024 * 1024, memory_path)?; cairo_run::write_encoded_memory(&runner.relocated_memory, &mut writer)?; writer.flush()?; } - if let Some(file_path) = args.run.standalone_outputs.air_public_input { + if let Some(file_path) = args.run.proof_outputs.air_public_input { let json = runner.get_air_public_input()?.serialize_json()?; std::fs::write(file_path, json)?; } if let (Some(file_path), Some(trace_file), Some(memory_file)) = ( - args.run.standalone_outputs.air_private_input, - args.run.standalone_outputs.trace_file, - args.run.standalone_outputs.memory_file, + args.run.proof_outputs.air_private_input, + args.run.proof_outputs.trace_file, + args.run.proof_outputs.memory_file, ) { let absolute = |path_buf: PathBuf| { path_buf.as_path().canonicalize().unwrap_or(path_buf).to_string_lossy().to_string() diff --git a/crates/cairo-lang-semantic/src/expr/inference.rs b/crates/cairo-lang-semantic/src/expr/inference.rs index 1ee0de9fc3b..7319222fb9f 100644 --- a/crates/cairo-lang-semantic/src/expr/inference.rs +++ b/crates/cairo-lang-semantic/src/expr/inference.rs @@ -1230,15 +1230,9 @@ impl SemanticRewriter for Inference<'_> { let impl_id = impl_type_id.impl_id(); let trait_ty = impl_type_id.ty(); return Ok(match impl_id.lookup_intern(self.db) { - ImplLongId::GenericParameter(_) | ImplLongId::SelfImpl(_) => { - impl_type_id_rewrite_result - } - ImplLongId::ImplImpl(impl_impl) => { - // The grand parent impl must be var free since we are rewriting the parent, - // and the parent is not var. - assert!(impl_impl.impl_id().is_var_free(self.db)); - impl_type_id_rewrite_result - } + ImplLongId::GenericParameter(_) + | ImplLongId::SelfImpl(_) + | ImplLongId::ImplImpl(_) => impl_type_id_rewrite_result, ImplLongId::Concrete(_) => { if let Ok(ty) = self.db.impl_type_concrete_implized(ImplTypeId::new( impl_id, trait_ty, self.db, diff --git a/crates/cairo-lang-semantic/src/expr/inference/infers.rs b/crates/cairo-lang-semantic/src/expr/inference/infers.rs index 382649d97bd..8350e97f22a 100644 --- a/crates/cairo-lang-semantic/src/expr/inference/infers.rs +++ b/crates/cairo-lang-semantic/src/expr/inference/infers.rs @@ -376,11 +376,7 @@ impl InferenceEmbeddings for Inference<'_> { .map_err(|diag_added| self.set_error(InferenceError::Reported(diag_added)))?; let impl_id = self.new_impl_var(concrete_trait_id, stable_ptr, lookup_context); for (trait_ty, ty1) in param.type_constraints.iter() { - let ty0 = self.reduce_impl_ty(ImplTypeId::new( - impl_id, - trait_ty.trait_type(self.db), - self.db, - ))?; + let ty0 = self.reduce_impl_ty(ImplTypeId::new(impl_id, *trait_ty, self.db))?; // Conforming the type will always work as the impl is a new inference variable. self.conform_ty(ty0, *ty1).ok(); } diff --git a/crates/cairo-lang-semantic/src/items/functions.rs b/crates/cairo-lang-semantic/src/items/functions.rs index 950a079cdaf..ef400f707ec 100644 --- a/crates/cairo-lang-semantic/src/items/functions.rs +++ b/crates/cairo-lang-semantic/src/items/functions.rs @@ -855,9 +855,22 @@ pub fn concrete_function_closure_params( let ConcreteFunction { generic_function, generic_args, .. } = function_id.lookup_intern(db).function; let generic_params = generic_function.generic_params(db)?; - let generic_closure_params = db.get_closure_params(generic_function)?; + let mut generic_closure_params = db.get_closure_params(generic_function)?; let substitution = GenericSubstitution::new(&generic_params, &generic_args); - SubstitutionRewriter { db, substitution: &substitution }.rewrite(generic_closure_params) + let mut rewriter = SubstitutionRewriter { db, substitution: &substitution }; + let mut changed_keys = vec![]; + for (key, value) in generic_closure_params.iter_mut() { + rewriter.internal_rewrite(value)?; + let updated_key = rewriter.rewrite(*key)?; + if updated_key != *key { + changed_keys.push((*key, updated_key)); + } + } + for (old_key, new_key) in changed_keys { + let v = generic_closure_params.swap_remove(&old_key).unwrap(); + generic_closure_params.insert(new_key, v); + } + Ok(generic_closure_params) } /// For a given list of AST parameters, returns the list of semantic parameters along with the diff --git a/crates/cairo-lang-semantic/src/items/generics.rs b/crates/cairo-lang-semantic/src/items/generics.rs index f842eaf5b46..a711f951822 100644 --- a/crates/cairo-lang-semantic/src/items/generics.rs +++ b/crates/cairo-lang-semantic/src/items/generics.rs @@ -5,7 +5,7 @@ use cairo_lang_debug::DebugWithDb; use cairo_lang_defs::db::DefsGroup; use cairo_lang_defs::ids::{ GenericItemId, GenericKind, GenericModuleItemId, GenericParamId, GenericParamLongId, - LanguageElementId, LookupItemId, ModuleFileId, TraitId, + LanguageElementId, LookupItemId, ModuleFileId, TraitId, TraitTypeId, }; use cairo_lang_diagnostics::{Diagnostics, Maybe}; use cairo_lang_proc_macros::{DebugWithDb, SemanticObject}; @@ -200,7 +200,7 @@ pub struct GenericParamConst { pub struct GenericParamImpl { pub id: GenericParamId, pub concrete_trait: Maybe, - pub type_constraints: OrderedHashMap, + pub type_constraints: OrderedHashMap, } impl Hash for GenericParamImpl { fn hash(&self, state: &mut H) { @@ -397,8 +397,7 @@ pub fn generic_params_type_constraints( let Ok(concrete_trait_id) = imp.concrete_trait else { continue; }; - for (concrete_trait_type_id, ty1) in imp.type_constraints { - let trait_ty = concrete_trait_type_id.trait_type(db); + for (trait_ty, ty1) in imp.type_constraints { let impl_type = TypeLongId::ImplType(ImplTypeId::new( ImplLongId::GenericParameter(*param).intern(db), trait_ty, @@ -617,7 +616,7 @@ fn impl_generic_param_semantic( let concrete_trait_type_id = ConcreteTraitTypeId::new(db, concrete_trait_id, trait_type_id); - match map.entry(concrete_trait_type_id) { + match map.entry(trait_type_id) { Entry::Vacant(entry) => { entry.insert(resolve_type( db, diff --git a/crates/cairo-lang-semantic/src/items/tests/trait_impl b/crates/cairo-lang-semantic/src/items/tests/trait_impl index c43922036eb..f5b133ceba3 100644 --- a/crates/cairo-lang-semantic/src/items/tests/trait_impl +++ b/crates/cairo-lang-semantic/src/items/tests/trait_impl @@ -1367,3 +1367,64 @@ error: Trait has no implementation in context: core::metaprogramming::TypeEqual: --> lib.cairo:31:11 1_u64.foo(1_u32); ^^^ + +//! > ========================================================================== + +//! > Trait default implementation with recursive types. + +//! > test_runner_name +test_function_diagnostics(expect_diagnostics: true) + +//! > function +fn foo() { + // TODO(orizi): Remove this diagnostic. + mix(1_u64.into_inner(), 1_u32.into_inner()); +} + +//! > function_name +foo + +//! > module_code +trait Outer { + type Inner; + impl Impl: Inner; + fn into_inner(self: T) -> Self::Inner; +} + +trait Inner { + type InnerMost; +} + +fn mix< + T, + U, + impl InnerT: Inner, + impl OuterU: Outer, + +core::metaprogramming::TypeEqual, +>( + lhs: T, rhs: U, +) {} + +impl OuterU32 of Outer { + type Inner = u64; + fn into_inner(self: u32) -> u64 { + self.into() + } +} + +impl OuterU64 of Outer { + type Inner = u64; + fn into_inner(self: u64) -> u64 { + self + } +} + +impl InnerImpl of Inner { + type InnerMost = u128; +} + +//! > expected_diagnostics +error: Impl mismatch: `Outer::Impl` and `test::InnerImpl`. + --> lib.cairo:40:5 + mix(1_u64.into_inner(), 1_u32.into_inner()); + ^^^ diff --git a/crates/cairo-lang-semantic/src/resolve/mod.rs b/crates/cairo-lang-semantic/src/resolve/mod.rs index 8da2c662c24..06a58d027d9 100644 --- a/crates/cairo-lang-semantic/src/resolve/mod.rs +++ b/crates/cairo-lang-semantic/src/resolve/mod.rs @@ -1677,7 +1677,7 @@ impl<'db> Resolver<'db> { for (trait_ty, ty1) in param.type_constraints.iter() { let ty0 = TypeLongId::ImplType(ImplTypeId::new( resolved_impl, - trait_ty.trait_type(self.db), + *trait_ty, self.db, )) .intern(self.db); diff --git a/crates/cairo-lang-semantic/src/substitution.rs b/crates/cairo-lang-semantic/src/substitution.rs index 068c0eeb447..1b7dcf441c1 100644 --- a/crates/cairo-lang-semantic/src/substitution.rs +++ b/crates/cairo-lang-semantic/src/substitution.rs @@ -4,8 +4,8 @@ use std::ops::{Deref, DerefMut}; use cairo_lang_defs::ids::{ EnumId, ExternFunctionId, ExternTypeId, FreeFunctionId, GenericParamId, ImplAliasId, ImplDefId, - ImplFunctionId, ImplImplDefId, LocalVarId, MemberId, ParamId, StructId, TraitConstantId, - TraitFunctionId, TraitId, TraitImplId, TraitTypeId, VariantId, + ImplFunctionId, ImplImplDefId, LanguageElementId, LocalVarId, MemberId, ParamId, StructId, + TraitConstantId, TraitFunctionId, TraitId, TraitImplId, TraitTypeId, VariantId, }; use cairo_lang_diagnostics::{DiagnosticAdded, Maybe}; use cairo_lang_utils::ordered_hash_map::OrderedHashMap; @@ -202,21 +202,12 @@ impl> SemanticRewriter, E> for TR } } -impl + SemanticRewriter> - SemanticRewriter, E> for TRewriter +impl> + SemanticRewriter, E> for TRewriter { - fn internal_rewrite(&mut self, value: &mut OrderedHashMap) -> Result { + fn internal_rewrite(&mut self, value: &mut OrderedHashMap) -> Result { let mut result = RewriteResult::NoChange; - let mut changed_key = Vec::new(); - for (k, v) in value.iter_mut() { - let mut temp_key = k.clone(); - match self.internal_rewrite(&mut temp_key)? { - RewriteResult::Modified => { - changed_key.push((k.clone(), temp_key)); - result = RewriteResult::Modified; - } - RewriteResult::NoChange => {} - } + for (_, v) in value.iter_mut() { match self.internal_rewrite(v)? { RewriteResult::Modified => { result = RewriteResult::Modified; @@ -224,11 +215,6 @@ impl + Sema RewriteResult::NoChange => {} } } - for (old_key, new_key) in changed_key { - let v = value.swap_remove(&old_key).unwrap(); - value.insert(new_key, v); - } - Ok(result) } } diff --git a/crates/cairo-lang-starknet-classes/src/casm_contract_class.rs b/crates/cairo-lang-starknet-classes/src/casm_contract_class.rs index 12ec13cb552..d3d191c3b92 100644 --- a/crates/cairo-lang-starknet-classes/src/casm_contract_class.rs +++ b/crates/cairo-lang-starknet-classes/src/casm_contract_class.rs @@ -467,20 +467,26 @@ impl CasmContractClass { let statement_id = function.entry_point; // The expected return types are [builtins.., gas_builtin, system, PanicResult]. - require(function.signature.ret_types.len() >= 3) + let (panic_result, output_builtins) = function + .signature + .ret_types + .split_last() + .ok_or(StarknetSierraCompilationError::InvalidEntryPointSignature)?; + let (builtins, [gas_ty, system_ty]) = output_builtins + .split_last_chunk() + .ok_or(StarknetSierraCompilationError::InvalidEntryPointSignature)?; + let (input_span, input_builtins) = function + .signature + .param_types + .split_last() .ok_or(StarknetSierraCompilationError::InvalidEntryPointSignatureMissingArgs)?; - - let (input_span, input_builtins) = function.signature.param_types.split_last().unwrap(); + require(input_builtins == output_builtins) + .ok_or(StarknetSierraCompilationError::InvalidEntryPointSignature)?; let type_resolver = TypeResolver { type_decl: &program.type_declarations }; require(type_resolver.is_felt252_span(input_span)) .ok_or(StarknetSierraCompilationError::InvalidEntryPointSignature)?; - let (panic_result, output_builtins) = - function.signature.ret_types.split_last().unwrap(); - - require(input_builtins == output_builtins) - .ok_or(StarknetSierraCompilationError::InvalidEntryPointSignature)?; require(type_resolver.is_valid_entry_point_return_type(panic_result)) .ok_or(StarknetSierraCompilationError::InvalidEntryPointSignature)?; @@ -491,8 +497,6 @@ impl CasmContractClass { )); } } - let (system_ty, builtins) = input_builtins.split_last().unwrap(); - let (gas_ty, builtins) = builtins.split_last().unwrap(); // Check that the last builtins are gas and system. if *type_resolver.get_generic_id(system_ty) != SystemType::id() diff --git a/crates/cairo-lang-starknet-classes/src/felt252_serde.rs b/crates/cairo-lang-starknet-classes/src/felt252_serde.rs index 334d5acdd33..4316a13212e 100644 --- a/crates/cairo-lang-starknet-classes/src/felt252_serde.rs +++ b/crates/cairo-lang-starknet-classes/src/felt252_serde.rs @@ -176,10 +176,11 @@ impl Felt252Serde for BigInt { Ok(()) } fn deserialize(input: &[BigUintAsHex]) -> Result<(Self, &[BigUintAsHex]), Felt252SerdeError> { - let first = input.first().ok_or(Felt252SerdeError::InvalidInputForDeserialization)?; + let (first, rest) = + input.split_first().ok_or(Felt252SerdeError::InvalidInputForDeserialization)?; Ok(( - first.value.to_bigint().expect("Unsigned should always be convertible to signed."), - &input[1..], + first.value.to_bigint().ok_or(Felt252SerdeError::InvalidInputForDeserialization)?, + rest, )) } } @@ -466,8 +467,12 @@ impl Felt252Serde for ConcreteTypeInfo { fn deserialize(input: &[BigUintAsHex]) -> Result<(Self, &[BigUintAsHex]), Felt252SerdeError> { let (generic_id, input) = GenericTypeId::deserialize(input)?; let (len_and_decl_ti_value, mut input) = BigInt::deserialize(input)?; - let len = (len_and_decl_ti_value.clone() & BigInt::from(u128::MAX)).to_usize().unwrap(); - let decl_ti_value = (len_and_decl_ti_value.shr(128) as BigInt).to_u64().unwrap(); + let len = (len_and_decl_ti_value.clone() & BigInt::from(u128::MAX)) + .to_usize() + .ok_or(Felt252SerdeError::InvalidInputForDeserialization)?; + let decl_ti_value = (len_and_decl_ti_value.shr(128) as BigInt) + .to_u64() + .ok_or(Felt252SerdeError::InvalidInputForDeserialization)?; let mut generic_args = vec_with_bounded_capacity(len, input.len())?; for _ in 0..len { let (arg, next) = GenericArg::deserialize(input)?; diff --git a/crates/cairo-lang-starknet-classes/src/felt252_vec_compression.rs b/crates/cairo-lang-starknet-classes/src/felt252_vec_compression.rs index 630a513d94e..d42682c7754 100644 --- a/crates/cairo-lang-starknet-classes/src/felt252_vec_compression.rs +++ b/crates/cairo-lang-starknet-classes/src/felt252_vec_compression.rs @@ -3,46 +3,42 @@ use cairo_lang_utils::ordered_hash_map::OrderedHashMap; use cairo_lang_utils::require; use num_bigint::BigUint; use num_integer::Integer; -use num_traits::ToPrimitive; +use num_traits::{ToPrimitive, Zero}; use starknet_types_core::felt::Felt as Felt252; /// Compresses a vector of `BigUintAsHex` representing felts into `result`, by creating a code /// mapping, and then compressing several original code words into the given felts. -pub fn compress>(values: &[BigUintAsHex], result: &mut Result) { - let mut code = OrderedHashMap::<&BigUint, usize>::default(); +pub fn compress(values: &[BigUintAsHex], result: &mut Vec) { + let mut code = OrderedHashMap::<&BigUintAsHex, usize>::default(); for value in values { let idx = code.len(); - code.entry(&value.value).or_insert(idx); + code.entry(value).or_insert(idx); } // Limiting the number of possible encodings by working only on powers of 2, as well as only // starting at 256 (or 8 bits per code word). let padded_code_size = std::cmp::max(256, code.len()).next_power_of_two(); - result.extend([BigUintAsHex { value: BigUint::from(code.len()) }]); - result.extend([BigUintAsHex { value: BigUint::from(padded_code_size - code.len()) }]); - result.extend(code.keys().map(|value| BigUintAsHex { value: (*value).clone() })); - result.extend([BigUintAsHex { value: BigUint::from(values.len()) }]); + result.extend([code.len(), padded_code_size - code.len()].map(BigUintAsHex::from)); + result.extend(code.keys().copied().cloned()); + result.push(values.len().into()); let words_per_felt = words_per_felt(padded_code_size); for values in values.chunks(words_per_felt) { - let mut packed_value = BigUint::from(0u64); + let mut packed_value = BigUint::zero(); for value in values.iter().rev() { packed_value *= padded_code_size; - packed_value += code[&value.value]; + packed_value += code[&value]; } - result.extend([BigUintAsHex { value: packed_value }]); + result.push(packed_value.into()); } } /// Decompresses `packed_values` created using `compress` into `result`. -pub fn decompress>( - packed_values: &[BigUintAsHex], - result: &mut Result, -) -> Option<()> { +pub fn decompress(packed_values: &[BigUintAsHex], result: &mut Vec) -> Option<()> { let (packed_values, code_size) = pop_usize(packed_values)?; require(code_size < packed_values.len())?; let (packed_values, padding_size) = pop_usize(packed_values)?; let (code, packed_values) = packed_values.split_at(code_size); let (packed_values, mut remaining_unpacked_size) = pop_usize(packed_values)?; - let padded_code_size = code_size + padding_size; + let padded_code_size = code_size.checked_add(padding_size)?; let words_per_felt = words_per_felt(padded_code_size); let padded_code_size = BigUint::from(padded_code_size); for packed_value in packed_values { @@ -50,9 +46,7 @@ pub fn decompress>( let mut v = packed_value.value.clone(); for _ in 0..curr_words { let (remaining, code_word) = v.div_mod_floor(&padded_code_size); - result.extend([BigUintAsHex { - value: code.get(code_word.to_usize().unwrap())?.value.clone(), - }]); + result.push(code.get(code_word.to_usize()?)?.clone()); v = remaining; } remaining_unpacked_size -= curr_words; diff --git a/crates/cairo-lang-starknet-classes/src/keccak.rs b/crates/cairo-lang-starknet-classes/src/keccak.rs index e007b4c76cf..6bf59057022 100644 --- a/crates/cairo-lang-starknet-classes/src/keccak.rs +++ b/crates/cairo-lang-starknet-classes/src/keccak.rs @@ -12,6 +12,6 @@ pub fn starknet_keccak(data: &[u8]) -> BigUint { let mut result = hasher.finalize(); // Truncate result to 250 bits. - *result.first_mut().unwrap() &= 3; + result[0] &= 3; BigUint::from_bytes_be(&result) } diff --git a/crates/cairo-lang-starknet/cairo_level_tests/collections_test.cairo b/crates/cairo-lang-starknet/cairo_level_tests/collections_test.cairo index e5f492a7b49..eff1667e5e5 100644 --- a/crates/cairo-lang-starknet/cairo_level_tests/collections_test.cairo +++ b/crates/cairo-lang-starknet/cairo_level_tests/collections_test.cairo @@ -101,3 +101,28 @@ fn test_nested_member_write_to_vec() { vec_contract_state.nested.append().append().write(1); assert_eq!(map_contract_state.nested.entry(0).entry(0).read(), 1); } + +#[test] +fn test_simple_member_push_to_vec() { + let mut state = contract_with_vec::contract_state_for_testing(); + state.simple.push(10); + state.simple.push(20); + state.simple.push(30); + assert_eq!(state.simple.len(), 3); + assert_eq!(state.simple.at(0).read(), 10); + assert_eq!(state.simple.at(1).read(), 20); + assert_eq!(state.simple.at(2).read(), 30); +} + +#[test] +fn test_simple_member_pop_from_vec() { + let mut state = contract_with_vec::contract_state_for_testing(); + state.simple.append().write(10); + state.simple.append().write(20); + state.simple.append().write(30); + assert_eq!(state.simple.pop(), Some(30)); + assert_eq!(state.simple.pop(), Some(20)); + assert_eq!(state.simple.pop(), Some(10)); + assert_eq!(state.simple.len(), 0); + assert_eq!(state.simple.pop(), None); +} diff --git a/crates/cairo-lang-starknet/cairo_level_tests/storage_access.cairo b/crates/cairo-lang-starknet/cairo_level_tests/storage_access.cairo index 2d9800f45bb..59cdd27bba4 100644 --- a/crates/cairo-lang-starknet/cairo_level_tests/storage_access.cairo +++ b/crates/cairo-lang-starknet/cairo_level_tests/storage_access.cairo @@ -1,7 +1,7 @@ #[feature("deprecated-bounded-int-trait")] use core::integer::BoundedInt; use core::num::traits::Zero; -use starknet::storage::Vec; +use starknet::storage::{StoragePointerReadAccess, StoragePointerWriteAccess, Vec}; use starknet::{ClassHash, ContractAddress, EthAddress, StorageAddress}; use super::utils::{deserialized, serialized}; @@ -106,8 +106,8 @@ mod test_contract { }; #[storage] - struct Storage { - data: AbcEtc, + pub struct Storage { + pub data: AbcEtc, byte_arrays: ByteArrays, non_zeros: NonZeros, vecs: Vecs, @@ -402,3 +402,38 @@ fn test_enum_sub_pointers() { deserialized(test_contract::__external::get_queryable_enum_low(serialized(()))), 789, ); } + +#[test] +fn test_scrub_clears_memory() { + let base_address = starknet::storage_access::storage_base_address_from_felt252( + selector!("data"), + ); + for i in 0..=255_u8 { + starknet::Store::::write_at_offset(0, base_address, i, 1).unwrap(); + }; + starknet::Store::< + ( + felt252, + felt252, + felt252, + felt252, + felt252, + felt252, + felt252, + felt252, + felt252, + felt252, + felt252, + ), + >::scrub(0, base_address, 7) + .unwrap(); + for i in 0..7_u8 { + assert_eq!(starknet::Store::::read_at_offset(0, base_address, i).unwrap(), 1); + }; + for i in 7..18_u8 { + assert_eq!(starknet::Store::::read_at_offset(0, base_address, i).unwrap(), 0); + }; + for i in 18..=255_u8 { + assert_eq!(starknet::Store::::read_at_offset(0, base_address, i).unwrap(), 1); + }; +} diff --git a/crates/cairo-lang-utils/src/bigint.rs b/crates/cairo-lang-utils/src/bigint.rs index 559c5267613..9131f6ea29d 100644 --- a/crates/cairo-lang-utils/src/bigint.rs +++ b/crates/cairo-lang-utils/src/bigint.rs @@ -12,7 +12,7 @@ use num_bigint::{BigInt, BigUint}; use num_traits::{Num, Signed}; /// A wrapper for BigUint that serializes as hex. -#[derive(Clone, Default, Debug, PartialEq, Eq)] +#[derive(Clone, Default, Debug, Hash, PartialEq, Eq)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize), serde(transparent))] pub struct BigUintAsHex { /// A field element that encodes the signature of the called function. @@ -23,6 +23,12 @@ pub struct BigUintAsHex { pub value: BigUint, } +impl> From for BigUintAsHex { + fn from(x: T) -> Self { + Self { value: x.into() } + } +} + #[cfg(feature = "serde")] fn deserialize_from_str<'a, D>(s: &str) -> Result where @@ -55,7 +61,7 @@ where } // A wrapper for BigInt that serializes as hex. -#[derive(Default, Clone, Debug, PartialEq, Eq)] +#[derive(Default, Clone, Debug, Hash, PartialEq, Eq)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize), serde(transparent))] #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] pub struct BigIntAsHex {