From 0a33cb3ceb5fbcae9a8f3f7966f2fc06b38fa44e Mon Sep 17 00:00:00 2001 From: Julian Gonzalez Calderon Date: Thu, 6 Feb 2025 11:43:20 -0300 Subject: [PATCH] Fix circuit gas diff: Implement u96_limbs_less_than_guarantee_verify (#1084) * Fix circuit implementation * Add circuit compare test (#1076) * Add more tests * Add compiler assert message * Adapt CircuitOutputs so that it can contain the modulo Instead of an array, now its a struct, where the second array is the modulo * Save circuit output in struct representation, instead of integer It is ultimately returned in struct representation, so now it is converted sooner (useful for guarantee representation) * Save modulus to circuit output * Implement limbs_guarantee type with gate value and circuit modulus * Implement guarantee verify * Add function build_aggregate_slice * Use always struct for limbs * Fix comment * Improve documentation * Remove single_limb implementation and use generic noop --- src/compiler.rs | 6 +- src/libfuncs/circuit.rs | 200 +++++++++++++++++++++++++++++----------- src/types.rs | 2 +- src/types/circuit.rs | 59 ++++++++++-- tests/tests/circuit.rs | 186 +++++++++++++++++++++++++++++++++++++ tests/tests/mod.rs | 1 + 6 files changed, 390 insertions(+), 64 deletions(-) create mode 100644 tests/tests/circuit.rs diff --git a/src/compiler.rs b/src/compiler.rs index dbe929b69..93870174d 100644 --- a/src/compiler.rs +++ b/src/compiler.rs @@ -602,7 +602,11 @@ fn compile_func( &helper, metadata, )?; - assert!(block.terminator().is_some()); + assert!( + block.terminator().is_some(), + "libfunc {} had no terminator", + libfunc_name + ); if let Some(tailrec_meta) = metadata.remove::() { if let Some(return_block) = tailrec_meta.return_target() { diff --git a/src/libfuncs/circuit.rs b/src/libfuncs/circuit.rs index 93a6fd0fe..154d44058 100644 --- a/src/libfuncs/circuit.rs +++ b/src/libfuncs/circuit.rs @@ -7,7 +7,7 @@ use crate::{ error::{Result, SierraAssertError}, libfuncs::r#struct::build_struct_value, metadata::MetadataStorage, - types::TypeBuilder, + types::{circuit::build_u384_struct_type, TypeBuilder}, utils::{get_integer_layout, layout_repeat, BlockExt, ProgramRegistryExt}, }; use cairo_lang_sierra::{ @@ -28,8 +28,8 @@ use melior::{ cf, llvm, }, ir::{ - attribute::DenseI32ArrayAttribute, r#type::IntegerType, Block, BlockLike, Location, Value, - ValueLike, + attribute::DenseI32ArrayAttribute, r#type::IntegerType, Block, BlockLike, Location, Type, + Value, ValueLike, }, Context, }; @@ -70,6 +70,9 @@ pub fn build<'ctx, 'this>( signature, .. }) + | CircuitConcreteLibfunc::U96SingleLimbLessThanGuaranteeVerify( + SignatureOnlyConcreteLibfunc { signature, .. }, + ) | CircuitConcreteLibfunc::U96GuaranteeVerify(SignatureOnlyConcreteLibfunc { signature }) => { super::build_noop::<1, true>( context, @@ -86,11 +89,6 @@ pub fn build<'ctx, 'this>( context, registry, entry, location, helper, metadata, info, ) } - CircuitConcreteLibfunc::U96SingleLimbLessThanGuaranteeVerify(info) => { - build_u96_single_limb_less_than_guarantee_verify( - context, registry, entry, location, helper, metadata, info, - ) - } } } @@ -373,7 +371,6 @@ fn build_eval<'ctx, 'this>( // let zero = entry.argument(5)?; // let one = entry.argument(6)?; - // We multiply the amount of gates evaluated by 4 (the amount of u96s in each gate) let add_mod = increment_builtin_counter_by( context, entry, @@ -402,6 +399,21 @@ fn build_eval<'ctx, 'this>( circuit_info.mul_offsets.len() * MOD_BUILTIN_INSTANCE_SIZE, )?; + // convert circuit output from integer representation to struct representation + let gates = gates + .into_iter() + .map(|value| u384_integer_to_struct(context, ok_block, location, value)) + .collect::>>()?; + + let n_gates = circuit_info.values.len(); + let gates_array = ok_block.append_op_result(llvm::undef( + llvm::r#type::array(build_u384_struct_type(context), n_gates as u32), + location, + ))?; + let gates_array = ok_block.insert_values(context, location, gates_array, &gates)?; + + let modulus_struct = u384_integer_to_struct(context, ok_block, location, circuit_modulus)?; + // Build output struct let outputs_type_id = &info.branch_signatures()[0].vars[2].ty; let outputs = build_struct_value( @@ -412,7 +424,7 @@ fn build_eval<'ctx, 'this>( helper, metadata, outputs_type_id, - &gates, + &[gates_array, modulus_struct], )?; ok_block.append_operation(helper.br(0, &[add_mod, mul_mod, outputs], location)); @@ -719,7 +731,6 @@ fn build_failure_guarantee_verify<'ctx, 'this>( } /// Generate MLIR operations for the `u96_limbs_less_than_guarantee_verify` libfunc. -/// NOOP #[allow(clippy::too_many_arguments)] fn build_u96_limbs_less_than_guarantee_verify<'ctx, 'this>( context: &'ctx Context, @@ -730,45 +741,81 @@ fn build_u96_limbs_less_than_guarantee_verify<'ctx, 'this>( metadata: &mut MetadataStorage, info: &ConcreteU96LimbsLessThanGuaranteeVerifyLibfunc, ) -> Result<()> { - let guarantee_type_id = &info.branch_signatures()[0].vars[0].ty; - let guarantee_type = registry.build_type(context, helper, metadata, guarantee_type_id)?; + let guarantee = entry.arg(0)?; + let limb_count = info.limb_count; - let guarantee = entry.append_op_result(llvm::undef(guarantee_type, location))?; + let u96_type = IntegerType::new(context, 96).into(); + let limb_struct_type = llvm::r#type::r#struct(context, &vec![u96_type; limb_count], false); + + // extract gate and modulus from input value + let gate = entry.extract_value(context, location, guarantee, limb_struct_type, 0)?; + let modulus = entry.extract_value(context, location, guarantee, limb_struct_type, 1)?; - let u96_type_id = &info.branch_signatures()[1].vars[0].ty; - let u96_type = registry.build_type(context, helper, metadata, u96_type_id)?; + // extract last limb from gate and modulus + let gate_last_limb = entry.extract_value(context, location, gate, u96_type, limb_count - 1)?; + let modulus_last_limb = + entry.extract_value(context, location, modulus, u96_type, limb_count - 1)?; - let u96 = entry.append_op_result(llvm::undef(u96_type, location))?; + // calcualte diff between limbs + let diff = entry.append_op_result(arith::subi(modulus_last_limb, gate_last_limb, location))?; + let k0 = entry.const_int_from_type(context, location, 0, u96_type)?; + let has_diff = entry.cmpi(context, CmpiPredicate::Ne, diff, k0, location)?; - let kfalse = entry.const_int(context, location, 0, 64)?; - entry.append_operation(helper.cond_br( + let diff_block = helper.append_block(Block::new(&[])); + let next_block = helper.append_block(Block::new(&[])); + entry.append_operation(cf::cond_br( context, - kfalse, - [0, 1], - [&[guarantee], &[u96]], + has_diff, + diff_block, + next_block, + &[], + &[], location, )); - Ok(()) -} + { + // if there is diff, return it + diff_block.append_operation(helper.br(1, &[diff], location)); + } + { + // if there is no diff, build a new guarantee, skipping last limb + let new_limb_struct_type = + llvm::r#type::r#struct(context, &vec![u96_type; limb_count - 1], false); + let new_gate = build_array_slice( + context, + next_block, + location, + gate, + u96_type, + new_limb_struct_type, + 0, + limb_count - 1, + )?; + let new_modulus = build_array_slice( + context, + next_block, + location, + modulus, + u96_type, + new_limb_struct_type, + 0, + limb_count - 1, + )?; -/// Generate MLIR operations for the `u96_single_limb_less_than_guarantee_verify` libfunc. -/// NOOP -#[allow(clippy::too_many_arguments)] -fn build_u96_single_limb_less_than_guarantee_verify<'ctx, 'this>( - context: &'ctx Context, - registry: &ProgramRegistry, - entry: &'this Block<'ctx>, - location: Location<'ctx>, - helper: &LibfuncHelper<'ctx, 'this>, - metadata: &mut MetadataStorage, - info: &SignatureOnlyConcreteLibfunc, -) -> Result<()> { - let u96_type_id = &info.branch_signatures()[0].vars[0].ty; - let u96_type = registry.build_type(context, helper, metadata, u96_type_id)?; - let u96 = entry.append_op_result(llvm::undef(u96_type, location))?; + let guarantee_type_id = &info.branch_signatures()[0].vars[0].ty; + let new_guarantee = build_struct_value( + context, + registry, + next_block, + location, + helper, + metadata, + guarantee_type_id, + &[new_gate, new_modulus], + )?; - entry.append_operation(helper.br(0, &[u96], location)); + next_block.append_operation(helper.br(0, &[new_guarantee], location)); + } Ok(()) } @@ -798,18 +845,41 @@ fn build_get_output<'ctx, 'this>( let output_idx = output_offset_idx - circuit_info.n_inputs - 1; let outputs = entry.arg(0)?; - let output_integer = entry.extract_value( + + let n_gates = circuit_info.values.len(); + let output_gates = entry.extract_value( context, location, outputs, - IntegerType::new(context, 384).into(), + llvm::r#type::array(build_u384_struct_type(context), n_gates as u32), + 0, + )?; + let modulus_struct = entry.extract_value( + context, + location, + outputs, + build_u384_struct_type(context), + 1, + )?; + let output_struct = entry.extract_value( + context, + location, + output_gates, + build_u384_struct_type(context), output_idx, )?; - let output_struct = u384_integer_to_struct(context, entry, location, output_integer)?; let guarantee_type_id = &info.branch_signatures()[0].vars[1].ty; - let guarantee_type = registry.build_type(context, helper, metadata, guarantee_type_id)?; - let guarantee = entry.append_op_result(llvm::undef(guarantee_type, location))?; + let guarantee = build_struct_value( + context, + registry, + entry, + location, + helper, + metadata, + guarantee_type_id, + &[output_struct, modulus_struct], + )?; entry.append_operation(helper.br(0, &[output_struct, guarantee], location)); @@ -892,16 +962,7 @@ fn u384_integer_to_struct<'a>( block.trunci(limb, u96_type, location)? }; - let struct_type = llvm::r#type::r#struct( - context, - &[ - IntegerType::new(context, 96).into(), - IntegerType::new(context, 96).into(), - IntegerType::new(context, 96).into(), - IntegerType::new(context, 96).into(), - ], - false, - ); + let struct_type = build_u384_struct_type(context); let struct_value = block.append_op_result(llvm::undef(struct_type, location))?; block.insert_values( @@ -993,6 +1054,35 @@ fn build_euclidean_algorithm<'ctx, 'this>( Ok(end_block) } +/// Extracts values from indexes `from` - `to` (exclusive) and builds a new value of type `result_type` +/// +/// Can be used with arrays, or structs with multiple elements of a single type. +#[allow(clippy::too_many_arguments)] +fn build_array_slice<'ctx>( + context: &'ctx Context, + block: &'ctx Block<'ctx>, + location: Location<'ctx>, + aggregate: Value<'ctx, 'ctx>, + element_type: Type<'ctx>, + result_type: Type<'ctx>, + from: usize, + to: usize, +) -> Result> { + let mut values = Vec::with_capacity(to - from); + + for i in from..to { + let value = block.extract_value(context, location, aggregate, element_type, i)?; + values.push(value); + } + + block.insert_values( + context, + location, + block.append_op_result(llvm::undef(result_type, location))?, + &values, + ) +} + #[cfg(test)] mod test { diff --git a/src/types.rs b/src/types.rs index 8c0f0a0ca..213f90556 100644 --- a/src/types.rs +++ b/src/types.rs @@ -35,7 +35,7 @@ mod bounded_int; mod r#box; mod builtin_costs; mod bytes31; -mod circuit; +pub mod circuit; mod coupon; mod ec_op; mod ec_point; diff --git a/src/types/circuit.rs b/src/types/circuit.rs index b8b9c93a4..613113a40 100644 --- a/src/types/circuit.rs +++ b/src/types/circuit.rs @@ -10,7 +10,7 @@ use crate::{ }; use cairo_lang_sierra::{ extensions::{ - circuit::CircuitTypeConcrete, + circuit::{CircuitTypeConcrete, ConcreteU96LimbsLessThanGuarantee}, core::{CoreLibfunc, CoreType, CoreTypeConcrete}, types::InfoOnlyConcreteType, }, @@ -57,10 +57,19 @@ pub fn build<'ctx>( metadata, WithSelf::new(selector.self_ty(), info), ), + CircuitTypeConcrete::U96LimbsLessThanGuarantee(info) => { + build_u96_limbs_less_than_guarantee( + context, + module, + registry, + metadata, + WithSelf::new(selector.self_ty(), info), + ) + } // builtins - CircuitTypeConcrete::AddMod(_) - | CircuitTypeConcrete::U96LimbsLessThanGuarantee(_) - | CircuitTypeConcrete::MulMod(_) => Ok(IntegerType::new(context, 64).into()), + CircuitTypeConcrete::AddMod(_) | CircuitTypeConcrete::MulMod(_) => { + Ok(IntegerType::new(context, 64).into()) + } // noops CircuitTypeConcrete::CircuitDescriptor(_) | CircuitTypeConcrete::CircuitFailureGuarantee(_) @@ -147,9 +156,32 @@ pub fn build_circuit_outputs<'ctx>( let n_gates = circuit.circuit_info.values.len(); - Ok(llvm::r#type::array( - IntegerType::new(context, 384).into(), - n_gates as u32, + Ok(llvm::r#type::r#struct( + context, + &[ + llvm::r#type::array(build_u384_struct_type(context), n_gates as u32), + build_u384_struct_type(context), + ], + false, + )) +} + +pub fn build_u96_limbs_less_than_guarantee<'ctx>( + context: &'ctx Context, + _module: &Module<'ctx>, + _registry: &ProgramRegistry, + _metadata: &mut MetadataStorage, + info: WithSelf, +) -> Result> { + let limbs = info.inner.limb_count; + + let u96_type = IntegerType::new(context, 96).into(); + let limb_struct_type = llvm::r#type::r#struct(context, &vec![u96_type; limbs], false); + + Ok(llvm::r#type::r#struct( + context, + &[limb_struct_type, limb_struct_type], + false, )) } @@ -282,3 +314,16 @@ pub fn layout( CircuitTypeConcrete::CircuitPartialOutputs(_) => Ok(Layout::new::<()>()), } } + +pub fn build_u384_struct_type(context: &Context) -> Type<'_> { + llvm::r#type::r#struct( + context, + &[ + IntegerType::new(context, 96).into(), + IntegerType::new(context, 96).into(), + IntegerType::new(context, 96).into(), + IntegerType::new(context, 96).into(), + ], + false, + ) +} diff --git a/tests/tests/circuit.rs b/tests/tests/circuit.rs new file mode 100644 index 000000000..049f5a649 --- /dev/null +++ b/tests/tests/circuit.rs @@ -0,0 +1,186 @@ +use crate::common::{compare_outputs, DEFAULT_GAS}; +use crate::common::{load_cairo, run_native_program, run_vm_program}; +use cairo_lang_runner::SierraCasmRunner; +use cairo_lang_sierra::program::Program; +use cairo_native::starknet::DummySyscallHandler; +use cairo_native::Value; +use lazy_static::lazy_static; + +lazy_static! { + // Taken from: https://github.com/starkware-libs/sequencer/blob/7ee6f4c8a81def87402c626c9d72a33c74bc3243/crates/blockifier/feature_contracts/cairo1/test_contract.cairo#L656 + static ref TEST: (String, Program, SierraCasmRunner) = load_cairo! { + use core::circuit::{ + CircuitElement, CircuitInput, circuit_add, circuit_sub, circuit_mul, circuit_inverse, + EvalCircuitResult, EvalCircuitTrait, u384, CircuitOutputsTrait, CircuitModulus, + CircuitInputs, AddInputResultTrait + }; + + fn test_guarantee_first_limb() { + let in1 = CircuitElement::> {}; + let in2 = CircuitElement::> {}; + let add = circuit_add(in1, in2); + let inv = circuit_inverse(add); + let sub = circuit_sub(inv, in2); + let mul = circuit_mul(inv, sub); + + let modulus = TryInto::<_, CircuitModulus>::try_into([7, 0, 0, 0]).unwrap(); + let outputs = (mul,) + .new_inputs() + .next([3, 0, 0, 0]) + .next([6, 0, 0, 0]) + .done() + .eval(modulus) + .unwrap(); + + assert!(outputs.get_output(mul) == u384 { limb0: 6, limb1: 0, limb2: 0, limb3: 0 }); + } + + fn test_guarantee_last_limb() { + let in1 = CircuitElement::> {}; + let in2 = CircuitElement::> {}; + let add = circuit_add(in1, in2); + + let modulus = TryInto::<_, CircuitModulus>::try_into([7, 0, 0, 1]).unwrap(); + let outputs = (add,) + .new_inputs() + .next([5, 0, 0, 0]) + .next([9, 0, 0, 0]) + .done() + .eval(modulus) + .unwrap(); + + assert!(outputs.get_output(add) == u384 { limb0: 14, limb1: 0, limb2: 0, limb3: 0 }); + } + + fn test_guarantee_middle_limb() { + let in1 = CircuitElement::> {}; + let in2 = CircuitElement::> {}; + let add = circuit_add(in1, in2); + + let modulus = TryInto::<_, CircuitModulus>::try_into([7, 0, 1, 0]).unwrap(); + let outputs = (add,) + .new_inputs() + .next([5, 0, 0, 0]) + .next([9, 0, 0, 0]) + .done() + .eval(modulus) + .unwrap(); + + assert!(outputs.get_output(add) == u384 { limb0: 14, limb1: 0, limb2: 0, limb3: 0 }); + } + }; +} + +#[test] +fn circuit_guarantee_first_limb() { + let program = &TEST; + + let result_vm = run_vm_program( + program, + "test_guarantee_first_limb", + vec![], + Some(DEFAULT_GAS as usize), + ) + .unwrap(); + + let result_native = run_native_program( + program, + "test_guarantee_first_limb", + &[], + Some(DEFAULT_GAS), + Option::::None, + ); + + assert!(matches!( + result_native.return_value, + Value::Enum { tag: 0, .. } + )); + + compare_outputs( + &program.1, + &program + .2 + .find_function("test_guarantee_first_limb") + .unwrap() + .id, + &result_vm, + &result_native, + ) + .unwrap(); +} + +#[test] +fn circuit_guarantee_last_limb() { + let program = &TEST; + + let result_vm = run_vm_program( + program, + "test_guarantee_last_limb", + vec![], + Some(DEFAULT_GAS as usize), + ) + .unwrap(); + + let result_native = run_native_program( + program, + "test_guarantee_last_limb", + &[], + Some(DEFAULT_GAS), + Option::::None, + ); + + assert!(matches!( + result_native.return_value, + Value::Enum { tag: 0, .. } + )); + + compare_outputs( + &program.1, + &program + .2 + .find_function("test_guarantee_last_limb") + .unwrap() + .id, + &result_vm, + &result_native, + ) + .unwrap(); +} + +#[test] +fn circuit_guarantee_middle_limb() { + let program = &TEST; + + let result_vm = run_vm_program( + program, + "test_guarantee_middle_limb", + vec![], + Some(DEFAULT_GAS as usize), + ) + .unwrap(); + + let result_native = run_native_program( + program, + "test_guarantee_middle_limb", + &[], + Some(DEFAULT_GAS), + Option::::None, + ); + + assert!(matches!( + result_native.return_value, + Value::Enum { tag: 0, .. } + )); + + compare_outputs( + &program.1, + &program + .2 + .find_function("test_guarantee_middle_limb") + .unwrap() + .id, + &result_vm, + &result_native, + ) + .unwrap(); +} diff --git a/tests/tests/mod.rs b/tests/tests/mod.rs index 650d3d7fc..d864318e8 100644 --- a/tests/tests/mod.rs +++ b/tests/tests/mod.rs @@ -2,6 +2,7 @@ pub mod alexandria; pub mod arrays; pub mod boolean; pub mod cases; +pub mod circuit; pub mod compile_library; pub mod dict; pub mod ec;