Skip to content

Commit

Permalink
[converter] remove_tile_before_binary_elementwise_ops (#329)
Browse files Browse the repository at this point in the history
  • Loading branch information
peterjc123 authored Jun 18, 2024
1 parent 7ea8e96 commit 0f9e3d5
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 0 deletions.
48 changes: 48 additions & 0 deletions tests/converter_optimizer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -768,6 +768,54 @@ def forward(self, x):
self.assertEqual(tfl_model.Subgraphs(0).OperatorsLength(), 1)
self.assertEqual(tfl_model.Subgraphs(0).Operators(0).OutputsLength(), 1)

def test_remove_elementwise_add_tile(self):
class TestModel(nn.Module):
def forward(self, x, y):
z = x + y.expand(1, -1, 2, 1)
return z

model = TestModel()
model.eval()

dummy_input = [torch.randn(1, 3, 2, 1), torch.randn(1, 3, 1, 1)]
model_path = get_model_path()

converter = TFLiteConverter(model, dummy_input, model_path, nchw_transpose=False)
converter.convert()

tfl_model = parse_model(model_path)
self.assertEqual(tfl_model.OperatorCodesLength(), 1)
self.assertEqual(tfl_model.OperatorCodes(0).DeprecatedBuiltinCode(), tflite.BuiltinOperator.ADD)
self.assertEqual(tfl_model.SubgraphsLength(), 1)
self.assertEqual(tfl_model.Subgraphs(0).InputsLength(), 2)
self.assertEqual(tfl_model.Subgraphs(0).OutputsLength(), 1)
self.assertEqual(tfl_model.Subgraphs(0).OperatorsLength(), 1)
self.assertEqual(tfl_model.Subgraphs(0).Operators(0).OutputsLength(), 1)

def test_remove_elementwise_mul_tile(self):
class TestModel(nn.Module):
def forward(self, x, y):
z = x.expand(1, -1, 2, 1) * y
return z

model = TestModel()
model.eval()

dummy_input = [torch.randn(1, 3, 1, 1), torch.randn(1, 3, 2, 1)]
model_path = get_model_path()

converter = TFLiteConverter(model, dummy_input, model_path, nchw_transpose=False)
converter.convert()

tfl_model = parse_model(model_path)
self.assertEqual(tfl_model.OperatorCodesLength(), 1)
self.assertEqual(tfl_model.OperatorCodes(0).DeprecatedBuiltinCode(), tflite.BuiltinOperator.MUL)
self.assertEqual(tfl_model.SubgraphsLength(), 1)
self.assertEqual(tfl_model.Subgraphs(0).InputsLength(), 2)
self.assertEqual(tfl_model.Subgraphs(0).OutputsLength(), 1)
self.assertEqual(tfl_model.Subgraphs(0).OperatorsLength(), 1)
self.assertEqual(tfl_model.Subgraphs(0).Operators(0).OutputsLength(), 1)

def test_binary_elementwise_transpose_as_unary(self):
class TestModel(nn.Module):
def __init__(self) -> None:
Expand Down
72 changes: 72 additions & 0 deletions tinynn/converter/operators/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,58 @@ def fuse_reciprocal_sqrt(self):
# Delete div nodes
self.graph.graph.delete_vertices(remove_ids)

@class_conditional(lambda self: self.level >= GraphOptimizer.COMMON_OPTIMIZE)
def remove_tile_before_binary_elementwise_ops(self):
# Find fusable ops
edges = self.graph.graph.es.select(functools.partial(is_tile_binary_op_edge, graph_converter=self.graph.graph))
filtered_pairs = ((self.graph.graph.vs[x.source], self.graph.graph.vs[x.target], x) for x in edges)

remove_ids = []
actions = []
binary_op_ids = set()
for tile, op_node, tensor in filtered_pairs:
tile_op = tile['op']
binary_op = op_node['op']

input_idx = None
for i in range(2):
try:
_ = tile['outputs'].index(binary_op.inputs[i].name)
input_idx = i
break
except ValueError:
pass

if input_idx is None:
continue

alter_input_idx = 1 - input_idx
try:
out_shape = np.broadcast_shapes(binary_op.inputs[alter_input_idx].shape, tile_op.inputs[0].shape)
if out_shape != binary_op.outputs[0].shape:
continue
except ValueError:
continue

if op_node.index not in binary_op_ids:
binary_op_ids.add(op_node.index)
else:
continue

new_tensor = tile_op.inputs[0]

# Replace input tensors
actions.append((self.graph.replace_operator_input, (op_node, input_idx, new_tensor)))

# remove tile op
remove_ids.append(tile.index)

# Process actions
for func, args in actions:
func(*args)
# Delete tile nodes
self.graph.graph.delete_vertices(remove_ids)

@class_conditional(lambda self: self.level >= GraphOptimizer.COMMON_OPTIMIZE)
def fuse_conv2d_gather(self):
# Find fusable ops
Expand Down Expand Up @@ -3449,6 +3501,9 @@ def optimize(self):
# Fuse reciprocal and sqrt
self.fuse_reciprocal_sqrt()

# Remove additional tile nodes before elementwise ops
self.remove_tile_before_binary_elementwise_ops()

# Fuse activation
self.fuse_activation()

Expand Down Expand Up @@ -4111,6 +4166,23 @@ def is_reciprocal_sqrt_edge(edge: ig.Edge, graph_converter: ig.Graph):
)


def is_tile_binary_op_edge(edge: ig.Edge, graph_converter: ig.Graph):
source_vertex = graph_converter.vs[edge.source]
target_vertex = graph_converter.vs[edge.target]

return (
source_vertex['node_type'] == ExtendedOperator.TILE
and target_vertex['node_type']
in (
ExtendedOperator.ADD,
ExtendedOperator.SUB,
ExtendedOperator.MUL,
ExtendedOperator.DIV,
)
and source_vertex.outdegree() == 1
)


def op_input_dims(op: tfl.BaseOperator):
dim_indices = None

Expand Down

0 comments on commit 0f9e3d5

Please sign in to comment.