Skip to content

Commit

Permalink
Fixed the map fusion.
Browse files Browse the repository at this point in the history
  • Loading branch information
philip-paul-mueller committed Feb 6, 2025
1 parent d916ae5 commit 6095829
Showing 1 changed file with 118 additions and 65 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from typing import Any, Dict, List, Optional, Set, Tuple, Union

import dace
from dace import data, dtypes, properties, subsets, symbolic, transformation
from dace import data, properties, subsets, symbolic, transformation
from dace.sdfg import SDFG, SDFGState, graph, nodes

from . import map_fusion_helper as mfh
Expand Down Expand Up @@ -752,8 +752,10 @@ def handle_intermediate_set(
Before the transformation the `state` does not have to be valid and
after this function has run the state is (most likely) invalid.
"""

map_params = map_exit_1.map.params.copy()
first_map_exit = map_exit_1
second_map_entry = map_entry_2
second_map_exit = map_exit_2
map_params = first_map_exit.map.params.copy()

# Now we will iterate over all intermediate edges and process them.
# If not stated otherwise the comments assume that we run in exclusive mode.
Expand All @@ -763,36 +765,22 @@ def handle_intermediate_set(
inter_node: nodes.AccessNode = out_edge.dst
inter_name = inter_node.data
inter_desc = inter_node.desc(sdfg)
inter_shape = inter_desc.shape

# Now we will determine the shape of the new intermediate. This size of
# this temporary is given by the Memlet that goes into the first map exit.
pre_exit_edges = list(
state.in_edges_by_connector(map_exit_1, "IN_" + out_edge.src_conn[4:])
state.in_edges_by_connector(first_map_exit, "IN_" + out_edge.src_conn[4:])
)
if len(pre_exit_edges) != 1:
raise NotImplementedError()
pre_exit_edge = pre_exit_edges[0]
new_inter_shape_raw = symbolic.overapproximate(pre_exit_edge.data.subset.size())

# Over approximation will leave us with some unneeded size one dimensions.
# If they are removed some dace transformations (especially auto optimization)
# will have problems.
if not self.strict_dataflow:
squeezed_dims: List[int] = [] # These are the dimensions we removed.
new_inter_shape: List[int] = [] # This is the final shape of the new intermediate.
for dim, (proposed_dim_size, full_dim_size) in enumerate(
zip(new_inter_shape_raw, inter_shape)
):
if full_dim_size == 1: # Must be kept!
new_inter_shape.append(proposed_dim_size)
elif proposed_dim_size == 1: # This dimension was reduced, so we can remove it.
squeezed_dims.append(dim)
else:
new_inter_shape.append(proposed_dim_size)
else:
squeezed_dims = []
new_inter_shape = list(new_inter_shape_raw)

(new_inter_shape_raw, new_inter_shape, squeezed_dims) = (
self.compute_reduced_intermediate(
producer_subset=pre_exit_edge.data.dst_subset,
inter_desc=inter_desc,
)
)

# This is the name of the new "intermediate" node that we will create.
# It will only have the shape `new_inter_shape` which is basically its
Expand All @@ -808,7 +796,6 @@ def handle_intermediate_set(
new_inter_name,
dtype=inter_desc.dtype,
transient=True,
storage=dtypes.StorageType.Register,
find_new_name=True,
)

Expand All @@ -822,32 +809,30 @@ def handle_intermediate_set(
shape=new_inter_shape,
dtype=inter_desc.dtype,
find_new_name=True,
storage=dtypes.StorageType.Register,
)
new_inter_node: nodes.AccessNode = state.add_access(new_inter_name)

# Get the subset that defined into which part of the old intermediate
# the old output edge wrote to. We need that to adjust the producer
# Memlets, since they now write into the new (smaller) intermediate.
assert pre_exit_edge.data.data == inter_name
assert pre_exit_edge.data.dst_subset is not None
producer_offset = self.compute_offset_subset(
original_subset=pre_exit_edge.data.dst_subset,
intermediate_desc=inter_desc,
map_params=map_params,
producer_offset=None,
)

# Memlets have a lot of additional informations, such as dynamic.
# To ensure that we get all of them, we will now copy them and modify
# the one that was originally there. We also hope that propagate will
# set the rest for us correctly.
# Memlets have a lot of additional informations, to ensure that we get
# all of them, we have to do it this way. The main reason for this is
# to handle the case were the "Memlet reverse direction", i.e. `data`
# refers to the other end of the connection than before.
assert pre_exit_edge.data.dst_subset is not None
new_pre_exit_memlet_src_subset = copy.deepcopy(pre_exit_edge.data.src_subset)
new_pre_exit_memlet_dst_subset = subsets.Range.from_array(new_inter_desc)

new_pre_exit_memlet = copy.deepcopy(pre_exit_edge.data)
new_pre_exit_memlet.data = new_inter_name
new_pre_exit_memlet.dst_subset = subsets.Range.from_array(new_inter_desc)

# New we will reroute the output Memlet, thus it will no longer pass
# through the Map exit but through the newly created intermediate.
# NOTE: We will delete the previous edge later.
new_pre_exit_edge = state.add_edge(
pre_exit_edge.src,
pre_exit_edge.src_conn,
Expand All @@ -856,18 +841,27 @@ def handle_intermediate_set(
new_pre_exit_memlet,
)

# We can update `{src, dst}_subset` only after we have inserted the
# edge, this is because the direction of the Memlet might change.
new_pre_exit_edge.data.src_subset = new_pre_exit_memlet_src_subset
new_pre_exit_edge.data.dst_subset = new_pre_exit_memlet_dst_subset

# We now handle the MemletTree defined by this edge.
# The newly created edge, only handled the last collection step.
for producer_tree in state.memlet_tree(new_pre_exit_edge).traverse_children(
include_self=False
):
producer_edge = producer_tree.edge

# Associate the (already existing) Memlet with the new data.
# TODO(phimuell): Improve the code below to remove the check.
assert producer_edge.data.data == inter_name
producer_edge.data.data = new_inter_name
# In order to preserve the intrinsic direction of Memlets we only have to change
# the `.data` attribute of the producer Memlet if it refers to the old intermediate.
# If it refers to something different we keep it. Note that this case can only
# occur if the producer is an AccessNode.
if producer_edge.data.data == inter_name:
producer_edge.data.data = new_inter_name

# Regardless of the intrinsic direction of the Memlet, the subset we care about
# is always `dst_subset`.
if is_scalar:
producer_edge.data.dst_subset = "0"
elif producer_edge.data.dst_subset is not None:
Expand All @@ -885,7 +879,7 @@ def handle_intermediate_set(
# NOTE: Assumes that map (if connected is the direct neighbour).
conn_names: Set[str] = set()
for inter_node_out_edge in state.out_edges(inter_node):
if inter_node_out_edge.dst == map_entry_2:
if inter_node_out_edge.dst == second_map_entry:
assert inter_node_out_edge.dst_conn.startswith("IN_")
conn_names.add(inter_node_out_edge.dst_conn)
else:
Expand All @@ -900,9 +894,7 @@ def handle_intermediate_set(
for in_conn_name in conn_names:
out_conn_name = "OUT_" + in_conn_name[3:]

for inner_edge in state.out_edges_by_connector(map_entry_2, out_conn_name):
assert inner_edge.data.data == inter_name # DIRECTION!!

for inner_edge in state.out_edges_by_connector(second_map_entry, out_conn_name):
# As for the producer side, we now read from a smaller array,
# So we must offset them, we use the original edge for this.
assert inner_edge.data.src_subset is not None
Expand All @@ -913,11 +905,17 @@ def handle_intermediate_set(
producer_offset=producer_offset,
)

# Now we create a new connection that instead reads from the new
# intermediate, instead of the old one. For this we use the
# old Memlet as template. However it is not fully initialized.
# Now create the memlet for the new consumer. To make sure that we get all attributes
# of the Memlet we make a deep copy of it. There is a tricky part here, we have to
# access `src_subset` however, this is only correctly set once it is put inside the
# SDFG. Furthermore, we have to make sure that the Memlet does not change its direction.
# i.e. that the association of `subset` and `other_subset` does not change. For this
# reason we only modify `.data` attribute of the Memlet if its name refers to the old
# intermediate. Furthermore, to play it safe, we only access the subset, `src_subset`
# after we have inserted it to the SDFG.
new_inner_memlet = copy.deepcopy(inner_edge.data)
new_inner_memlet.data = new_inter_name
if inner_edge.data.data == inter_name:
new_inner_memlet.data = new_inter_name

# Now we replace the edge from the SDFG.
state.remove_edge(inner_edge)
Expand All @@ -934,30 +932,38 @@ def handle_intermediate_set(
if is_scalar:
new_inner_memlet.subset = "0"
elif new_inner_memlet.src_subset is not None:
# TODO(phimuell): Figuring out if `src_subset` is None is an error.
new_inner_memlet.src_subset.offset(consumer_offset, negative=True)
new_inner_memlet.src_subset.pop(squeezed_dims)

# Now we have to make sure that all consumers are properly updated.
for consumer_tree in state.memlet_tree(new_inner_edge).traverse_children(
include_self=False
):
assert consumer_tree.edge.data.data == inter_name

consumer_edge = consumer_tree.edge
consumer_edge.data.data = new_inter_name

# We only modify the data if the Memlet refers to the old intermediate data.
# We can not do this unconditionally, because it might change the intrinsic
# direction of a Memlet and then `src_subset` would at the next `try_initialize`
# be wrong. Note that this case only occurs if the destination is an AccessNode.
if consumer_edge.data.data == inter_name:
consumer_edge.data.data = new_inter_name

# Now we have to adapt the subsets.
if is_scalar:
consumer_edge.data.src_subset = "0"
elif consumer_edge.data.src_subset is not None:
# TODO(phimuell): Figuring out if `src_subset` is None is an error.
consumer_edge.data.src_subset.offset(consumer_offset, negative=True)
consumer_edge.data.src_subset.pop(squeezed_dims)

# The edge that leaves the second map entry was already deleted. We now delete
# the edges that connected the intermediate node with the second map entry.
for edge in list(state.in_edges_by_connector(map_entry_2, in_conn_name)):
for edge in list(state.in_edges_by_connector(second_map_entry, in_conn_name)):
assert edge.src == inter_node
state.remove_edge(edge)
map_entry_2.remove_in_connector(in_conn_name)
map_entry_2.remove_out_connector(out_conn_name)
second_map_entry.remove_in_connector(in_conn_name)
second_map_entry.remove_out_connector(out_conn_name)

if is_exclusive_set:
# In exclusive mode the old intermediate node is no longer needed.
Expand All @@ -967,41 +973,88 @@ def handle_intermediate_set(
state.remove_node(inter_node)

state.remove_edge(pre_exit_edge)
map_exit_1.remove_in_connector(pre_exit_edge.dst_conn)
map_exit_1.remove_out_connector(out_edge.src_conn)
first_map_exit.remove_in_connector(pre_exit_edge.dst_conn)
first_map_exit.remove_out_connector(out_edge.src_conn)
del sdfg.arrays[inter_name]

else:
# TODO(phimuell): Lift this restriction
assert pre_exit_edge.data.data == inter_name

# This is the shared mode, so we have to recreate the intermediate
# node, but this time it is at the exit of the second map.
state.remove_edge(pre_exit_edge)
map_exit_1.remove_in_connector(pre_exit_edge.dst_conn)
first_map_exit.remove_in_connector(pre_exit_edge.dst_conn)

# This is the Memlet that goes from the map internal intermediate
# temporary node to the Map output. This will essentially restore
# or preserve the output for the intermediate node. It is important
# that we use the data that `preExitEdge` was used.
final_pre_exit_memlet = copy.deepcopy(pre_exit_edge.data)
assert pre_exit_edge.data.data == inter_name
final_pre_exit_memlet.other_subset = subsets.Range.from_array(new_inter_desc)

new_pre_exit_conn = map_exit_2.next_connector()
new_pre_exit_conn = second_map_exit.next_connector()
state.add_edge(
new_inter_node,
None,
map_exit_2,
second_map_exit,
"IN_" + new_pre_exit_conn,
final_pre_exit_memlet,
)
state.add_edge(
map_exit_2,
second_map_exit,
"OUT_" + new_pre_exit_conn,
inter_node,
out_edge.dst_conn,
copy.deepcopy(out_edge.data),
)
map_exit_2.add_in_connector("IN_" + new_pre_exit_conn)
map_exit_2.add_out_connector("OUT_" + new_pre_exit_conn)
second_map_exit.add_in_connector("IN_" + new_pre_exit_conn)
second_map_exit.add_out_connector("OUT_" + new_pre_exit_conn)

map_exit_1.remove_out_connector(out_edge.src_conn)
first_map_exit.remove_out_connector(out_edge.src_conn)
state.remove_edge(out_edge)

def compute_reduced_intermediate(
self,
producer_subset: subsets.Range,
inter_desc: dace.data.Data,
) -> Tuple[Tuple[int, ...], Tuple[int, ...], List[int]]:
"""Compute the size of the new (reduced) intermediate.
`MapFusion` does not only fuses map, but, depending on the situation, also
eliminates intermediate arrays between the two maps. To transmit data between
the two maps a new, but much smaller intermediate is needed.
:return: The function returns a tuple with three values with the following meaning:
* The raw shape of the reduced intermediate.
* The cleared shape of the reduced intermediate, essentially the raw shape
with all shape 1 dimensions removed.
* Which dimensions of the raw shape have been removed to get the cleared shape.
:param producer_subset: The subset that was used to write into the intermediate.
:param inter_desc: The data descriptor for the intermediate.
"""
assert producer_subset is not None

# Over approximation will leave us with some unneeded size one dimensions.
# If they are removed some dace transformations (especially auto optimization)
# will have problems.
new_inter_shape_raw = symbolic.overapproximate(producer_subset.size())
inter_shape = inter_desc.shape
if not self.strict_dataflow:
squeezed_dims: List[int] = [] # These are the dimensions we removed.
new_inter_shape: List[int] = [] # This is the final shape of the new intermediate.
for dim, (proposed_dim_size, full_dim_size) in enumerate(
zip(new_inter_shape_raw, inter_shape)
):
if full_dim_size == 1: # Must be kept!
new_inter_shape.append(proposed_dim_size)
elif proposed_dim_size == 1: # This dimension was reduced, so we can remove it.
squeezed_dims.append(dim)
else:
new_inter_shape.append(proposed_dim_size)
else:
squeezed_dims = []
new_inter_shape = list(new_inter_shape_raw)

return (tuple(new_inter_shape_raw), tuple(new_inter_shape), squeezed_dims)

0 comments on commit 6095829

Please sign in to comment.