Commit f6703c8

[converter] correct non-compliant axis handling in split and splitV ops (#346)

LynnL4 authored Aug 12, 2024
1 parent bda6c1c commit f6703c8
Showing 3 changed files with 17 additions and 17 deletions.
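
The common thread across all three files: the split axis fed to SPLIT/SPLIT_V is expected to be a 0-D (scalar) tensor, while the converter had been emitting a 1-D tensor of shape (1,). A minimal NumPy sketch (not part of the commit) of the difference, including the matching read-side change from `tensor[0]` to `tensor.item()`:

```python
import numpy as np

# Before the fix: the axis was wrapped in a list, yielding shape (1,).
axis_old = np.array([3], dtype='int32')
print(axis_old.shape)   # (1,)
print(axis_old[0])      # 3 -- read back by indexing

# After the fix: a 0-D scalar tensor, matching the expected layout.
axis_new = np.array(3, dtype='int32')
print(axis_new.shape)   # ()
print(axis_new.item())  # 3 -- indexing a 0-D array raises IndexError,
                        # hence tensor[0] -> tensor.item() in op_input_dims
```
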
16 changes: 8 additions & 8 deletions tinynn/converter/operators/optimize.py
@@ -1958,12 +1958,12 @@ def elementwise_op_transpose_passthrough_pass(self, quantizable_ops_only: bool =
elif node['node_type'] == ExtendedOperator.SPLIT_V:
old_dim = op.inputs[2].tensor
new_dim = np.where(inv_perm_arr == old_dim)[0][0]
- new_dim_tensor = self.create_attr_tensor(np.array([new_dim], dtype='int32'))
+ new_dim_tensor = self.create_attr_tensor(np.array(new_dim, dtype='int32'))
actions.append((self.graph.replace_operator_input, (node, 2, new_dim_tensor, True)))
elif node['node_type'] == ExtendedOperator.SPLIT:
old_dim = op.inputs[0].tensor
new_dim = np.where(inv_perm_arr == old_dim)[0][0]
- new_dim_tensor = self.create_attr_tensor(np.array([new_dim], dtype='int32'))
+ new_dim_tensor = self.create_attr_tensor(np.array(new_dim, dtype='int32'))
actions.append((self.graph.replace_operator_input, (node, 0, new_dim_tensor, True)))
elif node['node_type'] in (
ExtendedOperator.PAD,
@@ -2318,11 +2318,11 @@ def elementwise_op_reshape_passthrough_pass(self) -> int:
op.axis = new_axis
elif node['node_type'] == ExtendedOperator.SPLIT_V:
new_dim = prev_shape.index(-1)
- new_dim_tensor = self.create_attr_tensor(np.array([new_dim], dtype='int32'))
+ new_dim_tensor = self.create_attr_tensor(np.array(new_dim, dtype='int32'))
actions.append((self.graph.replace_operator_input, (node, 2, new_dim_tensor, True)))
elif node['node_type'] == ExtendedOperator.SPLIT:
new_dim = prev_shape.index(-1)
- new_dim_tensor = self.create_attr_tensor(np.array([new_dim], dtype='int32'))
+ new_dim_tensor = self.create_attr_tensor(np.array(new_dim, dtype='int32'))
actions.append((self.graph.replace_operator_input, (node, 0, new_dim_tensor, True)))
elif node['node_type'] in (ExtendedOperator.PAD, ExtendedOperator.PADV2, ExtendedOperator.MIRROR_PAD):
old_pad = op.inputs[1].tensor
@@ -2716,7 +2716,7 @@ def group_conv_rewrite_pass(self):
else:
biases = [None] * num_chunks

- dim_tensor = self.create_attr_tensor(np.array([3], dtype='int32'))
+ dim_tensor = self.create_attr_tensor(np.array(3, dtype='int32'))
ops.append(tfl.SplitOperator([dim_tensor, input_tensor], input_tensors, num_chunks))

for it, ot, w, b in zip(input_tensors, output_tensors, weights, biases):
@@ -2815,7 +2815,7 @@ def group_deconv_rewrite_pass(self):
new_os = output_shape_tensor.tensor.copy()
new_os[3] = num_weight_channel
new_ost = self.create_attr_tensor(new_os)
- dim_tensor = self.create_attr_tensor(np.array([3], dtype='int32'))
+ dim_tensor = self.create_attr_tensor(np.array(3, dtype='int32'))
ops.append(tfl.SplitOperator([dim_tensor, input_tensor], input_tensors, num_chunks))

for it, ot, w, b in zip(input_tensors, output_tensors, weights, biases):
@@ -4193,9 +4193,9 @@ def op_input_dims(op: tfl.BaseOperator):
if isinstance(op, (tfl.ConcatenationOperator, tfl.GatherOperator, tfl.PackOperator, tfl.UnpackOperator)):
dim_indices = op.axis
elif isinstance(op, tfl.SplitOperator):
- dim_indices = op.inputs[0].tensor[0]
+ dim_indices = op.inputs[0].tensor.item()
elif isinstance(op, tfl.SplitVOperator):
- dim_indices = op.inputs[2].tensor[0]
+ dim_indices = op.inputs[2].tensor.item()
elif isinstance(op, (tfl.PadOperator, tfl.Padv2Operator, tfl.MirrorPadOperator)):
pads = np.sum(op.inputs[1].tensor, axis=-1)
nonzero_idx = np.nonzero(pads)[0]
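
For context on the passthrough passes above: when a transpose is pushed past a split, the split axis has to be remapped through the inverse permutation, which is what the `np.where` lookup computes. A standalone sketch with illustrative values (not from the repo):

```python
import numpy as np

# A transpose with permutation [1, 2, 0] has inverse permutation [2, 0, 1].
inv_perm_arr = np.array([2, 0, 1])    # assumed inverse permutation
old_dim = np.array(1, dtype='int32')  # split axis, now a 0-D tensor
new_dim = np.where(inv_perm_arr == old_dim)[0][0]
print(new_dim)  # 2 -- the index j with inv_perm_arr[j] == old_dim
```
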
16 changes: 8 additions & 8 deletions tinynn/converter/operators/torch/aten.py
@@ -274,7 +274,7 @@ def parse_common(

if not self.separated_rnn_gate_calc:
gate_outs = [self.create_transform_tensor(t) for t in np.split(add_out.tensor, 4, 1)]
- split_dim_tensor = self.create_attr_tensor(np.array([1], dtype='int32'))
+ split_dim_tensor = self.create_attr_tensor(np.array(1, dtype='int32'))
ops.append(tfl.SplitOperator([split_dim_tensor, add_out], gate_outs, 4))

gate_i = self.create_transform_tensor(
@@ -709,7 +709,7 @@ def parse_common(
ops.append(tfl.FullyConnectedOperator([h, w_h, b_h], [hidden_mm]))

left_in = np.split(input_mm.tensor, 3, axis=1)
- dim_tensor = self.create_attr_tensor(np.array([1], dtype='int32'))
+ dim_tensor = self.create_attr_tensor(np.array(1, dtype='int32'))
splited_left_in = [self.create_transform_tensor(t) for t in left_in]

ops.append(tfl.SplitOperator([dim_tensor, input_mm], splited_left_in, 3))
@@ -2459,7 +2459,7 @@ def parse(self, node, attrs, args, graph_converter):
else:
axis = tuple(range(1, dims))
axis_tensor = self.create_attr_tensor(np.array(axis, dtype='int32'))
- split_dim_tensor = self.create_attr_tensor(np.array([1], dtype='int32'))
+ split_dim_tensor = self.create_attr_tensor(np.array(1, dtype='int32'))
inputs = [self.create_transform_tensor(t) for t in np.split(inp.tensor, n_groups, axis=1)]
ops.append(tfl.SplitOperator([split_dim_tensor, inp], inputs, n_groups))

@@ -2982,7 +2982,7 @@ def parse_common(self, node, attrs, args, graph_converter):
if dim < 0:
dim += len(self.input_tensors[0].shape)

- dim_tensor = self.create_attr_tensor(np.array([dim], dtype='int32'))
+ dim_tensor = self.create_attr_tensor(np.array(dim, dtype='int32'))
size_splits = np.array([t.size(dim) for t in self.output_tensors[0]], dtype='int32')
chunks = len(size_splits)
split_tensor = self.create_attr_tensor(size_splits)
@@ -3024,7 +3024,7 @@ def parse(self, node, attrs, args, graph_converter):
chunks = dim_size

input_tensor = self.find_or_create_input(0, graph_converter)
- dim_tensor = self.create_attr_tensor(np.array([dim], dtype='int32'))
+ dim_tensor = self.create_attr_tensor(np.array(dim, dtype='int32'))

output_names = [f'{self.output_names[0]}:{i}' for i in range(len(self.output_tensors[0]))]
graph_converter.add_iterable_pair(self.output_names, output_names, 'input')
@@ -3783,7 +3783,7 @@ def parse(self, node, attrs, args, graph_converter):
ops = []

mid_arrs = np.split(input_tensor.tensor, 2, axis=dim)
- dim_tensor = self.create_attr_tensor(np.array([dim], dtype='int32'))
+ dim_tensor = self.create_attr_tensor(np.array(dim, dtype='int32'))
mid_tensors = [self.create_transform_tensor(t) for t in mid_arrs]
ops.append(tfl.SplitOperator([dim_tensor, input_tensor], mid_tensors, 2))

@@ -4119,7 +4119,7 @@ def parse(self, node, attrs, args, graph_converter):
actual_shift = shift % dim_size
if actual_shift != 0:
split_sizes = self.create_attr_tensor(np.array([dim_size - actual_shift, actual_shift], dtype='int32'))
- dim_tensor = self.create_attr_tensor(np.array([dim], dtype='int32'))
+ dim_tensor = self.create_attr_tensor(np.array(dim, dtype='int32'))
chunks = 2

splitted = [
@@ -4491,4 +4491,4 @@ def parse(self, node, attrs, args, graph_converter):
ops.append(tfl.TileOperator([actual_input, repeat_tensor], [outp]))

for op in ops:
- graph_converter.add_operator(op)
\ No newline at end of file
+ graph_converter.add_operator(op)
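
One of the touched call sites (the -4119 hunk above) appears to lower a roll-style rotation to a two-way split followed by a concatenation. A NumPy sketch of that identity with assumed values (the real code emits SPLIT_V and CONCATENATION ops instead):

```python
import numpy as np

# Rotate by `shift` via split + concatenate, mirroring the
# split_sizes = [dim_size - actual_shift, actual_shift] logic above.
x = np.arange(6)
shift = 2
dim_size = x.shape[0]
actual_shift = shift % dim_size
head, tail = np.split(x, [dim_size - actual_shift])
print(np.concatenate([tail, head]).tolist())  # [4, 5, 0, 1, 2, 3]
print(np.roll(x, shift).tolist())             # identical result
```
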
2 changes: 1 addition & 1 deletion tinynn/converter/operators/torch/prim.py
@@ -152,7 +152,7 @@ def parse(self, node, attrs, args, graph_converter):
chunks = dim_size

input_tensor = self.find_or_create_input(0, graph_converter)
- dim_tensor = self.create_attr_tensor(np.array([dim], dtype='int32'))
+ dim_tensor = self.create_attr_tensor(np.array(dim, dtype='int32'))

if dim_size % chunks != 0:
size_splits = np.array([t.size(dim) for t in self.output_tensors], dtype='int32')
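
The hunk above sits in a lowering that uses plain SPLIT when the dimension divides evenly into the requested chunks and falls back to SPLIT_V with explicit size_splits otherwise. A NumPy sketch of the uneven case (assumed values, not from the repo):

```python
import numpy as np

# 7 elements in chunks of 3 -> sizes [3, 3, 1]; SPLIT requires equal
# pieces, so this is the case SPLIT_V encodes via size_splits.
x = np.arange(7)
size_splits = np.array([3, 3, 1], dtype='int32')
pieces = np.split(x, np.cumsum(size_splits)[:-1])
print([p.tolist() for p in pieces])  # [[0, 1, 2], [3, 4, 5], [6]]
```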
