diff --git a/taichi/ir/control_flow_graph.cpp b/taichi/ir/control_flow_graph.cpp index 5cbf8d8852237..a78259e663cf2 100644 --- a/taichi/ir/control_flow_graph.cpp +++ b/taichi/ir/control_flow_graph.cpp @@ -867,10 +867,11 @@ bool CFGNode::dead_store_elimination(bool after_lower_access) { !may_contain_variable(killed_in_this_node, load_ptr)) { // Only perform identical load elimination within a CFGNode. auto next_load_stmt = live_load_in_this_node[load_ptr]; - TI_ASSERT(irpass::analysis::same_statements(stmt, next_load_stmt)); - next_load_stmt->replace_usages_with(stmt); - erase(block->locate(next_load_stmt)); - modified = true; + if (irpass::analysis::same_statements(stmt, next_load_stmt)) { + next_load_stmt->replace_usages_with(stmt); + erase(block->locate(next_load_stmt)); + modified = true; + } } update_container_with_alias(tensor_to_matrix_ptrs_map, diff --git a/taichi/transforms/auto_diff.cpp b/taichi/transforms/auto_diff.cpp index b03e35b74b693..d1d35c8893b5b 100644 --- a/taichi/transforms/auto_diff.cpp +++ b/taichi/transforms/auto_diff.cpp @@ -643,10 +643,11 @@ class RegulateTensorTypedStatements : public BasicStmtVisitor { auto matrix_index = Stmt::make(index_values); matrix_index->ret_type = index_tensor_type; - + auto cmp_tensor_type = TypeFactory::get_instance().get_tensor_type( + tensor_shape, PrimitiveType::u1); auto matrix_eq = Stmt::make( BinaryOpType::cmp_eq, matrix_offset.get(), matrix_index.get()); - matrix_eq->ret_type = index_tensor_type; + matrix_eq->ret_type = cmp_tensor_type; auto orig_value = Stmt::make(orig_stmt); orig_value->ret_type = tensor_type; @@ -843,9 +844,11 @@ class ReplaceLocalVarWithStacks : public BasicStmtVisitor { auto matrix_index = Stmt::make(index_values); matrix_index->ret_type = index_tensor_type; + auto cmp_tensor_type = TypeFactory::get_instance().get_tensor_type( + tensor_shape, PrimitiveType::u1); auto matrix_eq = Stmt::make( BinaryOpType::cmp_eq, matrix_offset.get(), matrix_index.get()); - matrix_eq->ret_type = index_tensor_type; + matrix_eq->ret_type = cmp_tensor_type; auto matrix_alloca_value = Stmt::make(stack_top_stmt->stack); @@ -1817,11 +1820,12 @@ class MakeAdjoint : public ADTransform { auto offset_matrix_init_stmt = insert(offset_values); offset_matrix_init_stmt->ret_type = index_tensor_type; - + auto cmp_tensor_type = TypeFactory::get_instance().get_tensor_type( + tensor_shape, PrimitiveType::u1); auto bin_eq_stmt = insert(BinaryOpType::cmp_eq, offset_matrix_init_stmt, indices_matrix_init_stmt); - bin_eq_stmt->ret_type = index_tensor_type; + bin_eq_stmt->ret_type = cmp_tensor_type; auto select_stmt = insert( TernaryOpType::select, bin_eq_stmt, stmt_adj_matrix_init_stmt,