From 34b574a3d3e448fb2ccd3a131c106a1ac26b16c7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20M=C3=BCller?= <147368808+philip-paul-mueller@users.noreply.github.com> Date: Fri, 7 Feb 2025 11:51:31 +0100 Subject: [PATCH] fix[dace]: Updating MapFusion (#1850) The [MapFusion PR](https://github.com/spcl/dace/pull/1629) in DaCe is still under review. However, the MapFusion in that PR has evolved, i.e. some bugs were fixed and now GT4Py is also these bugs. This PR essentially back ports some of the fixes to GT4Py. Note that this is a temporary solution and as soon as the MapFusion PR has been merged (and parallel map fusion has been introduced) the GT4Py version will go away. --- .../dace/transformations/map_fusion_serial.py | 183 +++++++++++------- 1 file changed, 118 insertions(+), 65 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion_serial.py b/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion_serial.py index 0ef33cae97..27d962d0bd 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion_serial.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion_serial.py @@ -17,7 +17,7 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Union import dace -from dace import data, dtypes, properties, subsets, symbolic, transformation +from dace import data, properties, subsets, symbolic, transformation from dace.sdfg import SDFG, SDFGState, graph, nodes from . import map_fusion_helper as mfh @@ -752,8 +752,10 @@ def handle_intermediate_set( Before the transformation the `state` does not have to be valid and after this function has run the state is (most likely) invalid. """ - - map_params = map_exit_1.map.params.copy() + first_map_exit = map_exit_1 + second_map_entry = map_entry_2 + second_map_exit = map_exit_2 + map_params = first_map_exit.map.params.copy() # Now we will iterate over all intermediate edges and process them. # If not stated otherwise the comments assume that we run in exclusive mode. @@ -763,36 +765,22 @@ def handle_intermediate_set( inter_node: nodes.AccessNode = out_edge.dst inter_name = inter_node.data inter_desc = inter_node.desc(sdfg) - inter_shape = inter_desc.shape # Now we will determine the shape of the new intermediate. This size of # this temporary is given by the Memlet that goes into the first map exit. pre_exit_edges = list( - state.in_edges_by_connector(map_exit_1, "IN_" + out_edge.src_conn[4:]) + state.in_edges_by_connector(first_map_exit, "IN_" + out_edge.src_conn[4:]) ) if len(pre_exit_edges) != 1: raise NotImplementedError() pre_exit_edge = pre_exit_edges[0] - new_inter_shape_raw = symbolic.overapproximate(pre_exit_edge.data.subset.size()) - - # Over approximation will leave us with some unneeded size one dimensions. - # If they are removed some dace transformations (especially auto optimization) - # will have problems. - if not self.strict_dataflow: - squeezed_dims: List[int] = [] # These are the dimensions we removed. - new_inter_shape: List[int] = [] # This is the final shape of the new intermediate. - for dim, (proposed_dim_size, full_dim_size) in enumerate( - zip(new_inter_shape_raw, inter_shape) - ): - if full_dim_size == 1: # Must be kept! - new_inter_shape.append(proposed_dim_size) - elif proposed_dim_size == 1: # This dimension was reduced, so we can remove it. - squeezed_dims.append(dim) - else: - new_inter_shape.append(proposed_dim_size) - else: - squeezed_dims = [] - new_inter_shape = list(new_inter_shape_raw) + + (new_inter_shape_raw, new_inter_shape, squeezed_dims) = ( + self.compute_reduced_intermediate( + producer_subset=pre_exit_edge.data.dst_subset, + inter_desc=inter_desc, + ) + ) # This is the name of the new "intermediate" node that we will create. # It will only have the shape `new_inter_shape` which is basically its @@ -808,7 +796,6 @@ def handle_intermediate_set( new_inter_name, dtype=inter_desc.dtype, transient=True, - storage=dtypes.StorageType.Register, find_new_name=True, ) @@ -822,32 +809,30 @@ def handle_intermediate_set( shape=new_inter_shape, dtype=inter_desc.dtype, find_new_name=True, - storage=dtypes.StorageType.Register, ) new_inter_node: nodes.AccessNode = state.add_access(new_inter_name) # Get the subset that defined into which part of the old intermediate # the old output edge wrote to. We need that to adjust the producer # Memlets, since they now write into the new (smaller) intermediate. - assert pre_exit_edge.data.data == inter_name - assert pre_exit_edge.data.dst_subset is not None producer_offset = self.compute_offset_subset( original_subset=pre_exit_edge.data.dst_subset, intermediate_desc=inter_desc, map_params=map_params, + producer_offset=None, ) - # Memlets have a lot of additional informations, such as dynamic. - # To ensure that we get all of them, we will now copy them and modify - # the one that was originally there. We also hope that propagate will - # set the rest for us correctly. + # Memlets have a lot of additional informations, to ensure that we get + # all of them, we have to do it this way. The main reason for this is + # to handle the case were the "Memlet reverse direction", i.e. `data` + # refers to the other end of the connection than before. + assert pre_exit_edge.data.dst_subset is not None + new_pre_exit_memlet_src_subset = copy.deepcopy(pre_exit_edge.data.src_subset) + new_pre_exit_memlet_dst_subset = subsets.Range.from_array(new_inter_desc) + new_pre_exit_memlet = copy.deepcopy(pre_exit_edge.data) new_pre_exit_memlet.data = new_inter_name - new_pre_exit_memlet.dst_subset = subsets.Range.from_array(new_inter_desc) - # New we will reroute the output Memlet, thus it will no longer pass - # through the Map exit but through the newly created intermediate. - # NOTE: We will delete the previous edge later. new_pre_exit_edge = state.add_edge( pre_exit_edge.src, pre_exit_edge.src_conn, @@ -856,6 +841,11 @@ def handle_intermediate_set( new_pre_exit_memlet, ) + # We can update `{src, dst}_subset` only after we have inserted the + # edge, this is because the direction of the Memlet might change. + new_pre_exit_edge.data.src_subset = new_pre_exit_memlet_src_subset + new_pre_exit_edge.data.dst_subset = new_pre_exit_memlet_dst_subset + # We now handle the MemletTree defined by this edge. # The newly created edge, only handled the last collection step. for producer_tree in state.memlet_tree(new_pre_exit_edge).traverse_children( @@ -863,11 +853,15 @@ def handle_intermediate_set( ): producer_edge = producer_tree.edge - # Associate the (already existing) Memlet with the new data. - # TODO(phimuell): Improve the code below to remove the check. - assert producer_edge.data.data == inter_name - producer_edge.data.data = new_inter_name + # In order to preserve the intrinsic direction of Memlets we only have to change + # the `.data` attribute of the producer Memlet if it refers to the old intermediate. + # If it refers to something different we keep it. Note that this case can only + # occur if the producer is an AccessNode. + if producer_edge.data.data == inter_name: + producer_edge.data.data = new_inter_name + # Regardless of the intrinsic direction of the Memlet, the subset we care about + # is always `dst_subset`. if is_scalar: producer_edge.data.dst_subset = "0" elif producer_edge.data.dst_subset is not None: @@ -885,7 +879,7 @@ def handle_intermediate_set( # NOTE: Assumes that map (if connected is the direct neighbour). conn_names: Set[str] = set() for inter_node_out_edge in state.out_edges(inter_node): - if inter_node_out_edge.dst == map_entry_2: + if inter_node_out_edge.dst == second_map_entry: assert inter_node_out_edge.dst_conn.startswith("IN_") conn_names.add(inter_node_out_edge.dst_conn) else: @@ -900,9 +894,7 @@ def handle_intermediate_set( for in_conn_name in conn_names: out_conn_name = "OUT_" + in_conn_name[3:] - for inner_edge in state.out_edges_by_connector(map_entry_2, out_conn_name): - assert inner_edge.data.data == inter_name # DIRECTION!! - + for inner_edge in state.out_edges_by_connector(second_map_entry, out_conn_name): # As for the producer side, we now read from a smaller array, # So we must offset them, we use the original edge for this. assert inner_edge.data.src_subset is not None @@ -913,11 +905,17 @@ def handle_intermediate_set( producer_offset=producer_offset, ) - # Now we create a new connection that instead reads from the new - # intermediate, instead of the old one. For this we use the - # old Memlet as template. However it is not fully initialized. + # Now create the memlet for the new consumer. To make sure that we get all attributes + # of the Memlet we make a deep copy of it. There is a tricky part here, we have to + # access `src_subset` however, this is only correctly set once it is put inside the + # SDFG. Furthermore, we have to make sure that the Memlet does not change its direction. + # i.e. that the association of `subset` and `other_subset` does not change. For this + # reason we only modify `.data` attribute of the Memlet if its name refers to the old + # intermediate. Furthermore, to play it safe, we only access the subset, `src_subset` + # after we have inserted it to the SDFG. new_inner_memlet = copy.deepcopy(inner_edge.data) - new_inner_memlet.data = new_inter_name + if inner_edge.data.data == inter_name: + new_inner_memlet.data = new_inter_name # Now we replace the edge from the SDFG. state.remove_edge(inner_edge) @@ -934,6 +932,7 @@ def handle_intermediate_set( if is_scalar: new_inner_memlet.subset = "0" elif new_inner_memlet.src_subset is not None: + # TODO(phimuell): Figuring out if `src_subset` is None is an error. new_inner_memlet.src_subset.offset(consumer_offset, negative=True) new_inner_memlet.src_subset.pop(squeezed_dims) @@ -941,23 +940,30 @@ def handle_intermediate_set( for consumer_tree in state.memlet_tree(new_inner_edge).traverse_children( include_self=False ): - assert consumer_tree.edge.data.data == inter_name - consumer_edge = consumer_tree.edge - consumer_edge.data.data = new_inter_name + + # We only modify the data if the Memlet refers to the old intermediate data. + # We can not do this unconditionally, because it might change the intrinsic + # direction of a Memlet and then `src_subset` would at the next `try_initialize` + # be wrong. Note that this case only occurs if the destination is an AccessNode. + if consumer_edge.data.data == inter_name: + consumer_edge.data.data = new_inter_name + + # Now we have to adapt the subsets. if is_scalar: consumer_edge.data.src_subset = "0" elif consumer_edge.data.src_subset is not None: + # TODO(phimuell): Figuring out if `src_subset` is None is an error. consumer_edge.data.src_subset.offset(consumer_offset, negative=True) consumer_edge.data.src_subset.pop(squeezed_dims) # The edge that leaves the second map entry was already deleted. We now delete # the edges that connected the intermediate node with the second map entry. - for edge in list(state.in_edges_by_connector(map_entry_2, in_conn_name)): + for edge in list(state.in_edges_by_connector(second_map_entry, in_conn_name)): assert edge.src == inter_node state.remove_edge(edge) - map_entry_2.remove_in_connector(in_conn_name) - map_entry_2.remove_out_connector(out_conn_name) + second_map_entry.remove_in_connector(in_conn_name) + second_map_entry.remove_out_connector(out_conn_name) if is_exclusive_set: # In exclusive mode the old intermediate node is no longer needed. @@ -967,41 +973,88 @@ def handle_intermediate_set( state.remove_node(inter_node) state.remove_edge(pre_exit_edge) - map_exit_1.remove_in_connector(pre_exit_edge.dst_conn) - map_exit_1.remove_out_connector(out_edge.src_conn) + first_map_exit.remove_in_connector(pre_exit_edge.dst_conn) + first_map_exit.remove_out_connector(out_edge.src_conn) del sdfg.arrays[inter_name] else: + # TODO(phimuell): Lift this restriction + assert pre_exit_edge.data.data == inter_name + # This is the shared mode, so we have to recreate the intermediate # node, but this time it is at the exit of the second map. state.remove_edge(pre_exit_edge) - map_exit_1.remove_in_connector(pre_exit_edge.dst_conn) + first_map_exit.remove_in_connector(pre_exit_edge.dst_conn) # This is the Memlet that goes from the map internal intermediate # temporary node to the Map output. This will essentially restore # or preserve the output for the intermediate node. It is important # that we use the data that `preExitEdge` was used. final_pre_exit_memlet = copy.deepcopy(pre_exit_edge.data) - assert pre_exit_edge.data.data == inter_name final_pre_exit_memlet.other_subset = subsets.Range.from_array(new_inter_desc) - new_pre_exit_conn = map_exit_2.next_connector() + new_pre_exit_conn = second_map_exit.next_connector() state.add_edge( new_inter_node, None, - map_exit_2, + second_map_exit, "IN_" + new_pre_exit_conn, final_pre_exit_memlet, ) state.add_edge( - map_exit_2, + second_map_exit, "OUT_" + new_pre_exit_conn, inter_node, out_edge.dst_conn, copy.deepcopy(out_edge.data), ) - map_exit_2.add_in_connector("IN_" + new_pre_exit_conn) - map_exit_2.add_out_connector("OUT_" + new_pre_exit_conn) + second_map_exit.add_in_connector("IN_" + new_pre_exit_conn) + second_map_exit.add_out_connector("OUT_" + new_pre_exit_conn) - map_exit_1.remove_out_connector(out_edge.src_conn) + first_map_exit.remove_out_connector(out_edge.src_conn) state.remove_edge(out_edge) + + def compute_reduced_intermediate( + self, + producer_subset: subsets.Range, + inter_desc: dace.data.Data, + ) -> Tuple[Tuple[int, ...], Tuple[int, ...], List[int]]: + """Compute the size of the new (reduced) intermediate. + + `MapFusion` does not only fuses map, but, depending on the situation, also + eliminates intermediate arrays between the two maps. To transmit data between + the two maps a new, but much smaller intermediate is needed. + + :return: The function returns a tuple with three values with the following meaning: + * The raw shape of the reduced intermediate. + * The cleared shape of the reduced intermediate, essentially the raw shape + with all shape 1 dimensions removed. + * Which dimensions of the raw shape have been removed to get the cleared shape. + + :param producer_subset: The subset that was used to write into the intermediate. + :param inter_desc: The data descriptor for the intermediate. + """ + assert producer_subset is not None + + # Over approximation will leave us with some unneeded size one dimensions. + # If they are removed some dace transformations (especially auto optimization) + # will have problems. + new_inter_shape_raw = symbolic.overapproximate(producer_subset.size()) + inter_shape = inter_desc.shape + if not self.strict_dataflow: + squeezed_dims: List[int] = [] # These are the dimensions we removed. + new_inter_shape: List[int] = [] # This is the final shape of the new intermediate. + for dim, (proposed_dim_size, full_dim_size) in enumerate( + zip(new_inter_shape_raw, inter_shape, strict=True) + ): + if full_dim_size == 1: # Must be kept! + new_inter_shape.append(proposed_dim_size) + elif proposed_dim_size == 1: # This dimension was reduced, so we can remove it. + squeezed_dims.append(dim) + else: + new_inter_shape.append(proposed_dim_size) + else: + squeezed_dims = [] + new_inter_shape = list(new_inter_shape_raw) + + return (tuple(new_inter_shape_raw), tuple(new_inter_shape), squeezed_dims)