From fc1b05a4342439658455b5093aa80de882ee9762 Mon Sep 17 00:00:00 2001 From: Lin Jiang Date: Tue, 26 Dec 2023 19:35:08 +0800 Subject: [PATCH] [autodiff] Fix the type of cmp statements in autodiff (#8452) Issue: fixes #8444 The return type of cmp statements of tensors should be tensors of u1 instead of tensors of i32. Sometimes the CFG detects that an AdStackLoadTopStmt and an AdStackLoadTopAdjStmt loading the same address. I don't know if this should happen, but it stops raising error if I don't eliminate the latter statement. --- taichi/ir/control_flow_graph.cpp | 9 +++++---- taichi/transforms/auto_diff.cpp | 14 +++++++++----- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/taichi/ir/control_flow_graph.cpp b/taichi/ir/control_flow_graph.cpp index 5cbf8d8852237..a78259e663cf2 100644 --- a/taichi/ir/control_flow_graph.cpp +++ b/taichi/ir/control_flow_graph.cpp @@ -867,10 +867,11 @@ bool CFGNode::dead_store_elimination(bool after_lower_access) { !may_contain_variable(killed_in_this_node, load_ptr)) { // Only perform identical load elimination within a CFGNode. auto next_load_stmt = live_load_in_this_node[load_ptr]; - TI_ASSERT(irpass::analysis::same_statements(stmt, next_load_stmt)); - next_load_stmt->replace_usages_with(stmt); - erase(block->locate(next_load_stmt)); - modified = true; + if (irpass::analysis::same_statements(stmt, next_load_stmt)) { + next_load_stmt->replace_usages_with(stmt); + erase(block->locate(next_load_stmt)); + modified = true; + } } update_container_with_alias(tensor_to_matrix_ptrs_map, diff --git a/taichi/transforms/auto_diff.cpp b/taichi/transforms/auto_diff.cpp index b03e35b74b693..d1d35c8893b5b 100644 --- a/taichi/transforms/auto_diff.cpp +++ b/taichi/transforms/auto_diff.cpp @@ -643,10 +643,11 @@ class RegulateTensorTypedStatements : public BasicStmtVisitor { auto matrix_index = Stmt::make(index_values); matrix_index->ret_type = index_tensor_type; - + auto cmp_tensor_type = TypeFactory::get_instance().get_tensor_type( + tensor_shape, PrimitiveType::u1); auto matrix_eq = Stmt::make( BinaryOpType::cmp_eq, matrix_offset.get(), matrix_index.get()); - matrix_eq->ret_type = index_tensor_type; + matrix_eq->ret_type = cmp_tensor_type; auto orig_value = Stmt::make(orig_stmt); orig_value->ret_type = tensor_type; @@ -843,9 +844,11 @@ class ReplaceLocalVarWithStacks : public BasicStmtVisitor { auto matrix_index = Stmt::make(index_values); matrix_index->ret_type = index_tensor_type; + auto cmp_tensor_type = TypeFactory::get_instance().get_tensor_type( + tensor_shape, PrimitiveType::u1); auto matrix_eq = Stmt::make( BinaryOpType::cmp_eq, matrix_offset.get(), matrix_index.get()); - matrix_eq->ret_type = index_tensor_type; + matrix_eq->ret_type = cmp_tensor_type; auto matrix_alloca_value = Stmt::make(stack_top_stmt->stack); @@ -1817,11 +1820,12 @@ class MakeAdjoint : public ADTransform { auto offset_matrix_init_stmt = insert(offset_values); offset_matrix_init_stmt->ret_type = index_tensor_type; - + auto cmp_tensor_type = TypeFactory::get_instance().get_tensor_type( + tensor_shape, PrimitiveType::u1); auto bin_eq_stmt = insert(BinaryOpType::cmp_eq, offset_matrix_init_stmt, indices_matrix_init_stmt); - bin_eq_stmt->ret_type = index_tensor_type; + bin_eq_stmt->ret_type = cmp_tensor_type; auto select_stmt = insert( TernaryOpType::select, bin_eq_stmt, stmt_adj_matrix_init_stmt,