diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion_serial.py b/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion_serial.py index 0ef33cae97..27d962d0bd 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion_serial.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion_serial.py @@ -17,7 +17,7 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Union import dace -from dace import data, dtypes, properties, subsets, symbolic, transformation +from dace import data, properties, subsets, symbolic, transformation from dace.sdfg import SDFG, SDFGState, graph, nodes from . import map_fusion_helper as mfh @@ -752,8 +752,10 @@ def handle_intermediate_set( Before the transformation the `state` does not have to be valid and after this function has run the state is (most likely) invalid. """ - - map_params = map_exit_1.map.params.copy() + first_map_exit = map_exit_1 + second_map_entry = map_entry_2 + second_map_exit = map_exit_2 + map_params = first_map_exit.map.params.copy() # Now we will iterate over all intermediate edges and process them. # If not stated otherwise the comments assume that we run in exclusive mode. @@ -763,36 +765,22 @@ def handle_intermediate_set( inter_node: nodes.AccessNode = out_edge.dst inter_name = inter_node.data inter_desc = inter_node.desc(sdfg) - inter_shape = inter_desc.shape # Now we will determine the shape of the new intermediate. This size of # this temporary is given by the Memlet that goes into the first map exit. 
pre_exit_edges = list( - state.in_edges_by_connector(map_exit_1, "IN_" + out_edge.src_conn[4:]) + state.in_edges_by_connector(first_map_exit, "IN_" + out_edge.src_conn[4:]) ) if len(pre_exit_edges) != 1: raise NotImplementedError() pre_exit_edge = pre_exit_edges[0] - new_inter_shape_raw = symbolic.overapproximate(pre_exit_edge.data.subset.size()) - - # Over approximation will leave us with some unneeded size one dimensions. - # If they are removed some dace transformations (especially auto optimization) - # will have problems. - if not self.strict_dataflow: - squeezed_dims: List[int] = [] # These are the dimensions we removed. - new_inter_shape: List[int] = [] # This is the final shape of the new intermediate. - for dim, (proposed_dim_size, full_dim_size) in enumerate( - zip(new_inter_shape_raw, inter_shape) - ): - if full_dim_size == 1: # Must be kept! - new_inter_shape.append(proposed_dim_size) - elif proposed_dim_size == 1: # This dimension was reduced, so we can remove it. - squeezed_dims.append(dim) - else: - new_inter_shape.append(proposed_dim_size) - else: - squeezed_dims = [] - new_inter_shape = list(new_inter_shape_raw) + + (new_inter_shape_raw, new_inter_shape, squeezed_dims) = ( + self.compute_reduced_intermediate( + producer_subset=pre_exit_edge.data.dst_subset, + inter_desc=inter_desc, + ) + ) # This is the name of the new "intermediate" node that we will create. # It will only have the shape `new_inter_shape` which is basically its @@ -808,7 +796,6 @@ def handle_intermediate_set( new_inter_name, dtype=inter_desc.dtype, transient=True, - storage=dtypes.StorageType.Register, find_new_name=True, ) @@ -822,32 +809,30 @@ def handle_intermediate_set( shape=new_inter_shape, dtype=inter_desc.dtype, find_new_name=True, - storage=dtypes.StorageType.Register, ) new_inter_node: nodes.AccessNode = state.add_access(new_inter_name) # Get the subset that defined into which part of the old intermediate # the old output edge wrote to. 
We need that to adjust the producer # Memlets, since they now write into the new (smaller) intermediate. - assert pre_exit_edge.data.data == inter_name - assert pre_exit_edge.data.dst_subset is not None producer_offset = self.compute_offset_subset( original_subset=pre_exit_edge.data.dst_subset, intermediate_desc=inter_desc, map_params=map_params, + producer_offset=None, ) - # Memlets have a lot of additional informations, such as dynamic. - # To ensure that we get all of them, we will now copy them and modify - # the one that was originally there. We also hope that propagate will - # set the rest for us correctly. + # Memlets have a lot of additional informations, to ensure that we get + # all of them, we have to do it this way. The main reason for this is + # to handle the case were the "Memlet reverse direction", i.e. `data` + # refers to the other end of the connection than before. + assert pre_exit_edge.data.dst_subset is not None + new_pre_exit_memlet_src_subset = copy.deepcopy(pre_exit_edge.data.src_subset) + new_pre_exit_memlet_dst_subset = subsets.Range.from_array(new_inter_desc) + new_pre_exit_memlet = copy.deepcopy(pre_exit_edge.data) new_pre_exit_memlet.data = new_inter_name - new_pre_exit_memlet.dst_subset = subsets.Range.from_array(new_inter_desc) - # New we will reroute the output Memlet, thus it will no longer pass - # through the Map exit but through the newly created intermediate. - # NOTE: We will delete the previous edge later. new_pre_exit_edge = state.add_edge( pre_exit_edge.src, pre_exit_edge.src_conn, @@ -856,6 +841,11 @@ def handle_intermediate_set( new_pre_exit_memlet, ) + # We can update `{src, dst}_subset` only after we have inserted the + # edge, this is because the direction of the Memlet might change. + new_pre_exit_edge.data.src_subset = new_pre_exit_memlet_src_subset + new_pre_exit_edge.data.dst_subset = new_pre_exit_memlet_dst_subset + # We now handle the MemletTree defined by this edge. 
# The newly created edge, only handled the last collection step. for producer_tree in state.memlet_tree(new_pre_exit_edge).traverse_children( @@ -863,11 +853,15 @@ def handle_intermediate_set( ): producer_edge = producer_tree.edge - # Associate the (already existing) Memlet with the new data. - # TODO(phimuell): Improve the code below to remove the check. - assert producer_edge.data.data == inter_name - producer_edge.data.data = new_inter_name + # In order to preserve the intrinsic direction of Memlets we only have to change + # the `.data` attribute of the producer Memlet if it refers to the old intermediate. + # If it refers to something different we keep it. Note that this case can only + # occur if the producer is an AccessNode. + if producer_edge.data.data == inter_name: + producer_edge.data.data = new_inter_name + # Regardless of the intrinsic direction of the Memlet, the subset we care about + # is always `dst_subset`. if is_scalar: producer_edge.data.dst_subset = "0" elif producer_edge.data.dst_subset is not None: @@ -885,7 +879,7 @@ def handle_intermediate_set( # NOTE: Assumes that map (if connected is the direct neighbour). conn_names: Set[str] = set() for inter_node_out_edge in state.out_edges(inter_node): - if inter_node_out_edge.dst == map_entry_2: + if inter_node_out_edge.dst == second_map_entry: assert inter_node_out_edge.dst_conn.startswith("IN_") conn_names.add(inter_node_out_edge.dst_conn) else: @@ -900,9 +894,7 @@ def handle_intermediate_set( for in_conn_name in conn_names: out_conn_name = "OUT_" + in_conn_name[3:] - for inner_edge in state.out_edges_by_connector(map_entry_2, out_conn_name): - assert inner_edge.data.data == inter_name # DIRECTION!! - + for inner_edge in state.out_edges_by_connector(second_map_entry, out_conn_name): # As for the producer side, we now read from a smaller array, # So we must offset them, we use the original edge for this. 
assert inner_edge.data.src_subset is not None @@ -913,11 +905,17 @@ def handle_intermediate_set( producer_offset=producer_offset, ) - # Now we create a new connection that instead reads from the new - # intermediate, instead of the old one. For this we use the - # old Memlet as template. However it is not fully initialized. + # Now create the memlet for the new consumer. To make sure that we get all attributes + # of the Memlet we make a deep copy of it. There is a tricky part here, we have to + # access `src_subset` however, this is only correctly set once it is put inside the + # SDFG. Furthermore, we have to make sure that the Memlet does not change its direction. + # i.e. that the association of `subset` and `other_subset` does not change. For this + # reason we only modify `.data` attribute of the Memlet if its name refers to the old + # intermediate. Furthermore, to play it safe, we only access the subset, `src_subset` + # after we have inserted it to the SDFG. new_inner_memlet = copy.deepcopy(inner_edge.data) - new_inner_memlet.data = new_inter_name + if inner_edge.data.data == inter_name: + new_inner_memlet.data = new_inter_name # Now we replace the edge from the SDFG. state.remove_edge(inner_edge) @@ -934,6 +932,7 @@ def handle_intermediate_set( if is_scalar: new_inner_memlet.subset = "0" elif new_inner_memlet.src_subset is not None: + # TODO(phimuell): Figuring out if `src_subset` is None is an error. new_inner_memlet.src_subset.offset(consumer_offset, negative=True) new_inner_memlet.src_subset.pop(squeezed_dims) @@ -941,23 +940,30 @@ def handle_intermediate_set( for consumer_tree in state.memlet_tree(new_inner_edge).traverse_children( include_self=False ): - assert consumer_tree.edge.data.data == inter_name - consumer_edge = consumer_tree.edge - consumer_edge.data.data = new_inter_name + + # We only modify the data if the Memlet refers to the old intermediate data. 
+ # We can not do this unconditionally, because it might change the intrinsic + # direction of a Memlet and then `src_subset` would at the next `try_initialize` + # be wrong. Note that this case only occurs if the destination is an AccessNode. + if consumer_edge.data.data == inter_name: + consumer_edge.data.data = new_inter_name + + # Now we have to adapt the subsets. if is_scalar: consumer_edge.data.src_subset = "0" elif consumer_edge.data.src_subset is not None: + # TODO(phimuell): Figuring out if `src_subset` is None is an error. consumer_edge.data.src_subset.offset(consumer_offset, negative=True) consumer_edge.data.src_subset.pop(squeezed_dims) # The edge that leaves the second map entry was already deleted. We now delete # the edges that connected the intermediate node with the second map entry. - for edge in list(state.in_edges_by_connector(map_entry_2, in_conn_name)): + for edge in list(state.in_edges_by_connector(second_map_entry, in_conn_name)): assert edge.src == inter_node state.remove_edge(edge) - map_entry_2.remove_in_connector(in_conn_name) - map_entry_2.remove_out_connector(out_conn_name) + second_map_entry.remove_in_connector(in_conn_name) + second_map_entry.remove_out_connector(out_conn_name) if is_exclusive_set: # In exclusive mode the old intermediate node is no longer needed. @@ -967,41 +973,88 @@ def handle_intermediate_set( state.remove_node(inter_node) state.remove_edge(pre_exit_edge) - map_exit_1.remove_in_connector(pre_exit_edge.dst_conn) - map_exit_1.remove_out_connector(out_edge.src_conn) + first_map_exit.remove_in_connector(pre_exit_edge.dst_conn) + first_map_exit.remove_out_connector(out_edge.src_conn) del sdfg.arrays[inter_name] else: + # TODO(phimuell): Lift this restriction + assert pre_exit_edge.data.data == inter_name + # This is the shared mode, so we have to recreate the intermediate # node, but this time it is at the exit of the second map. 
def compute_reduced_intermediate(
    self,
    producer_subset: subsets.Range,
    inter_desc: dace.data.Data,
) -> Tuple[Tuple[int, ...], Tuple[int, ...], List[int]]:
    """Determine the shape of the reduced intermediate that replaces `inter_desc`.

    `MapFusion` does not only fuse maps, it may also replace the intermediate
    array between the two maps by a much smaller one. This function computes
    the shape of that smaller intermediate.

    The raw shape is the over-approximated size of `producer_subset`. Unless
    `self.strict_dataflow` is set, every dimension whose raw size is 1 is
    dropped ("squeezed"), with one exception: dimensions that already have
    size 1 in `inter_desc.shape` are always kept.

    :param producer_subset: The subset that was used to write into the intermediate.
    :param inter_desc: The data descriptor of the (original) intermediate.

    :return: A tuple with three entries:
        * The raw shape of the reduced intermediate.
        * The final shape, i.e. the raw shape without the squeezed dimensions.
        * The indices, with respect to the raw shape, of the dimensions that
          were squeezed; empty if `self.strict_dataflow` is set.

    :raises ValueError: If squeezing is performed and the raw shape and
        `inter_desc.shape` have different lengths (`zip(..., strict=True)`).
    """
    assert producer_subset is not None

    raw_shape = symbolic.overapproximate(producer_subset.size())
    full_shape = inter_desc.shape

    if self.strict_dataflow:
        # Strict dataflow: nothing is squeezed, the final shape is the raw shape.
        # NOTE(review): presumably because removing size-one dimensions upsets
        #  some dace transformations (especially auto optimization) -- confirm.
        unsqueezed: List[int] = list(raw_shape)
        return (tuple(raw_shape), tuple(unsqueezed), [])

    kept_sizes: List[int] = []  # Sizes of the dimensions that survive.
    dropped_dims: List[int] = []  # Indices of the dimensions that were squeezed.
    for dim, (raw_size, full_size) in enumerate(zip(raw_shape, full_shape, strict=True)):
        # A dimension is squeezed only if it was reduced to size 1, i.e. it has
        # raw size 1 but did not already have size 1 in the full intermediate.
        if full_size == 1 or not (raw_size == 1):
            kept_sizes.append(raw_size)
        else:
            dropped_dims.append(dim)

    return (tuple(raw_shape), tuple(kept_sizes), dropped_dims)