Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make sure we insert the match inputs after the variables that are use… #7146

Merged
merged 1 commit into from
Jan 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions crates/cairo-lang-lowering/src/optimizations/reorder_statements.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,18 @@ pub fn reorder_statements(db: &dyn LoweringGroup, lowered: &mut FlatLowered) {

for (block_id, block_changes) in changes_by_block.into_iter() {
let statements = &mut lowered.blocks[block_id].statements;
let block_len = statements.len();

// Apply block changes in reverse order to prevent a change from invalidating the
// indices of the other changes.
for (index, opt_statement) in
block_changes.into_iter().sorted_by_key(|(index, _)| Reverse(*index))
{
match opt_statement {
Some(stmt) => statements.insert(index, stmt),
Some(stmt) => {
// If index > block_len, we insert the statement at the end of the block.
statements.insert(std::cmp::min(index, block_len), stmt)
}
None => {
statements.remove(index);
}
Expand All @@ -70,6 +74,9 @@ pub struct ReorderStatementsInfo {
// A mapping from var_id to a candidate location that it can be moved to.
// If the variable is used in multiple match arms we define the next use to be
// the match.

// Note that StatementLocation.0 might >= block.len() and it means that
// the variable should be inserted at the end of the block.
next_use: UnorderedHashMap<VariableId, StatementLocation>,
}

Expand Down Expand Up @@ -170,7 +177,9 @@ impl Analyzer<'_> for ReorderStatementsContext<'_> {
}

for var_usage in match_info.inputs() {
info.next_use.insert(var_usage.var_id, statement_location);
// Make sure we insert the match inputs after the variables that are used in the arms.
info.next_use
.insert(var_usage.var_id, (statement_location.0, statement_location.1 + 1));
}

info
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -318,3 +318,94 @@ Statements:
(v3: ()) <- struct_construct()
End:
Return(v3)

//! > ==========================================================================

//! > Test match inputs are moved next to the matchh.

//! > test_runner_name
test_reorder_statements

//! > function
fn foo() -> felt252 {
let a = true;
let v = 5;
if a {
v + 3
} else {
v + 3
}
}

//! > function_name
foo

//! > module_code

//! > semantic_diagnostics

//! > lowering_diagnostics

//! > before
Parameters:
blk0 (root):
Statements:
(v0: ()) <- struct_construct()
(v1: core::bool) <- bool::True(v0)
(v2: core::felt252) <- 5
End:
Match(match_enum(v1) {
bool::False(v3) => blk1,
bool::True(v4) => blk2,
})

blk1:
Statements:
(v5: core::felt252) <- 3
(v6: core::felt252) <- core::felt252_add(v2, v5)
End:
Goto(blk3, {v6 -> v7})

blk2:
Statements:
(v8: core::felt252) <- 3
(v9: core::felt252) <- core::felt252_add(v2, v8)
End:
Goto(blk3, {v9 -> v7})

blk3:
Statements:
End:
Return(v7)

//! > after
Parameters:
blk0 (root):
Statements:
(v2: core::felt252) <- 5
(v0: ()) <- struct_construct()
(v1: core::bool) <- bool::True(v0)
End:
Match(match_enum(v1) {
bool::False(v3) => blk1,
bool::True(v4) => blk2,
})

blk1:
Statements:
(v5: core::felt252) <- 3
(v6: core::felt252) <- core::felt252_add(v2, v5)
End:
Goto(blk3, {v6 -> v7})

blk2:
Statements:
(v8: core::felt252) <- 3
(v9: core::felt252) <- core::felt252_add(v2, v8)
End:
Goto(blk3, {v9 -> v7})

blk3:
Statements:
End:
Return(v7)
73 changes: 31 additions & 42 deletions crates/cairo-lang-runner/src/profiling_test_data/profiling
Original file line number Diff line number Diff line change
Expand Up @@ -45,62 +45,51 @@ main

//! > expected_profiling_info
Weight by sierra statement:
statement 18: 4 (store_temp<test::MyEnum>([4]) -> ([4]))
statement 19: 4 (store_temp<test::MyEnum>([8]) -> ([8]))
statement 20: 4 (store_temp<test::MyEnum>([13]) -> ([13]))
statement 21: 4 (store_temp<test::MyEnum>([18]) -> ([18]))
statement 89: 4 (u8_overflowing_add([41], [63], [64]) { fallthrough([68], [69]) 135([70], [71]) })
statement 113: 4 (u8_overflowing_add([59], [81], [82]) { fallthrough([86], [87]) 123([88], [89]) })
statement 121: 3 (store_temp<core::panics::PanicResult::<((),)>>([93]) -> ([93]))
statement 46: 2 (enum_match<test::MyEnum>([8]) { fallthrough([37]) 50([38]) 54([39]) 60([40]) })
statement 70: 2 (enum_match<test::MyEnum>([13]) { fallthrough([55]) 74([56]) 78([57]) 84([58]) })
statement 94: 2 (enum_match<test::MyEnum>([18]) { fallthrough([73]) 98([74]) 102([75]) 108([76]) })
statement 22: 1 (enum_match<test::MyEnum>([4]) { fallthrough([19]) 26([20]) 30([21]) 36([22]) })
statement 23: 1 (branch_align() -> ())
statement 25: 1 (jump() { 28() })
statement 28: 1 (store_temp<RangeCheck>([0]) -> ([23]))
statement 29: 1 (jump() { 46() })
statement 50: 1 (branch_align() -> ())
statement 52: 1 (store_temp<RangeCheck>([23]) -> ([41]))
statement 53: 1 (jump() { 70() })
statement 81: 1 (store_temp<u8>([60]) -> ([63]))
statement 82: 1 (store_temp<u8>([62]) -> ([64]))
statement 83: 1 (jump() { 89() })
statement 93: 1 (store_temp<RangeCheck>([68]) -> ([59]))
statement 111: 1 (store_temp<u8>([83]) -> ([81]))
statement 112: 1 (store_temp<u8>([85]) -> ([82]))
statement 114: 1 (branch_align() -> ())
statement 117: 1 (store_temp<RangeCheck>([86]) -> ([77]))
statement 122: 1 (return([77], [93]))
statement 10: 4 (store_temp<test::MyEnum>([10]) -> ([10]))
statement 11: 4 (store_temp<test::MyEnum>([5]) -> ([5]))
statement 31: 4 (u8_overflowing_add([0], [19], [20]) { fallthrough([24], [25]) 77([26], [27]) })
statement 55: 4 (u8_overflowing_add([15], [37], [38]) { fallthrough([42], [43]) 65([44], [45]) })
statement 63: 3 (store_temp<core::panics::PanicResult::<((),)>>([49]) -> ([49]))
statement 12: 2 (enum_match<test::MyEnum>([10]) { fallthrough([11]) 16([12]) 20([13]) 26([14]) })
statement 36: 2 (enum_match<test::MyEnum>([5]) { fallthrough([29]) 40([30]) 44([31]) 50([32]) })
statement 23: 1 (store_temp<u8>([16]) -> ([19]))
statement 24: 1 (store_temp<u8>([18]) -> ([20]))
statement 25: 1 (jump() { 31() })
statement 35: 1 (store_temp<RangeCheck>([24]) -> ([15]))
statement 53: 1 (store_temp<u8>([39]) -> ([37]))
statement 54: 1 (store_temp<u8>([41]) -> ([38]))
statement 56: 1 (branch_align() -> ())
statement 59: 1 (store_temp<RangeCheck>([42]) -> ([33]))
statement 64: 1 (return([33], [49]))
Weight by concrete libfunc:
libfunc store_temp<test::MyEnum>: 16
libfunc store_temp<test::MyEnum>: 8
libfunc u8_overflowing_add: 8
libfunc enum_match<test::MyEnum>: 7
libfunc jump: 4
libfunc store_temp<RangeCheck>: 4
libfunc enum_match<test::MyEnum>: 4
libfunc store_temp<u8>: 4
libfunc branch_align: 3
libfunc store_temp<core::panics::PanicResult::<((),)>>: 3
libfunc store_temp<RangeCheck>: 2
libfunc branch_align: 1
libfunc jump: 1
return: 1
Weight by generic libfunc:
libfunc store_temp: 27
libfunc store_temp: 17
libfunc u8_overflowing_add: 8
libfunc enum_match: 7
libfunc jump: 4
libfunc branch_align: 3
libfunc enum_match: 4
libfunc branch_align: 1
libfunc jump: 1
return: 1
Weight by user function (inc. generated):
function test::main: 50
function test::main: 32
Weight by original user function (exc. generated):
function test::main: 50
function test::main: 32
Weight by Cairo function:
function lib.cairo::foo: 31
function lib.cairo::foo: 17
function core::integer::U8Add::add: 11
function lib.cairo::main: 8
function lib.cairo::main: 4
Weight by Sierra stack trace:
test::main: 50
test::main: 32
Weight by Cairo stack trace:
test::main: 50
test::main: 32

//! > ==========================================================================

Expand Down
Loading
Loading