Skip to content

Commit

Permalink
LiftStructViews: Lift direct struct member access AccessNodes
Browse files Browse the repository at this point in the history
  • Loading branch information
phschaad committed Jan 23, 2025
1 parent f0ca36b commit 20f468e
Showing 1 changed file with 47 additions and 0 deletions.
47 changes: 47 additions & 0 deletions dace/transformation/passes/lift_struct_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,51 @@ def _lift_isedge(self, cfg: ControlFlowRegion, edge: Edge[InterstateEdge], resul
lifted_something = True
return lifted_something

def _lift_access_node(self, state: SDFGState, data_node: nd.AccessNode, result: Dict[str, Set[str]]) -> bool:
parts = data_node.data.split('.')
if not len(parts) >= 2:
return False
root_container = parts[0]
struct = state.sdfg.arrays[root_container]
if isinstance(struct, (dt.Structure, dt.StructureView)):
view_name = 'v_' + '.'.join(parts)
try:
view = state.sdfg.arrays[view_name]
except KeyError:
view = dt.View.view(struct.members[parts[1]])
view_name = state.sdfg.add_datadesc(view_name, view, find_new_name=True)

if state.in_degree(data_node) > 0:
view_node = state.add_access(view_name)
for ie in state.in_edges(data_node):
ie.data.data = view_name
state.add_edge(ie.src, ie.src_conn, view_node, None, Memlet.from_memlet(ie.data))
for path_edge in state.memlet_path(ie):
if path_edge is ie:
continue
if path_edge.data.data == data_node.data:
path_edge.data.data = view_name
state.remove_edge(ie)
state.add_edge(view_node, 'views', data_node, None,
Memlet.from_array(root_container + '.' + parts[1], struct.members[parts[1]]))
if state.out_degree(data_node) > 0:
view_node = state.add_access(view_name)
for oe in state.out_edges(data_node):
oe.data.data = view_name
state.add_edge(view_node, None, oe.dst, oe.dst_conn, Memlet.from_memlet(oe.data))
for path_edge in state.memlet_path(oe):
if path_edge is oe:
continue
if path_edge.data.data == data_node.data:
path_edge.data.data = view_name
state.remove_edge(oe)
state.add_edge(data_node, None, view_node, 'views',
Memlet.from_array(root_container + '.' + parts[1], struct.members[parts[1]]))
result[data_node.data] = view_name
data_node.data = root_container
return True
return False

def _lift_tasklet(self, state: SDFGState, data_node: nd.AccessNode, tasklet: nd.Tasklet,
edge: MultiConnectorEdge[Memlet], data: dt.Structure, connector: str,
direction: dirtype) -> Set[str]:
Expand Down Expand Up @@ -486,6 +531,8 @@ def apply_pass(self, sdfg: SDFG, _) -> Optional[Dict[str, Set[str]]]:
'out')
result[node.data].update(res)
lifted_something_this_round = True
elif '.' in node.data:
lifted_something_this_round |= self._lift_access_node(block, node, result)
for edge in cfg.edges():
lifted_something_this_round |= self._lift_isedge(cfg, edge, result)

Expand Down

0 comments on commit 20f468e

Please sign in to comment.