Skip to content

Commit

Permalink
[converter] add intermediate tensors for Int16 LSTM (#336)
Browse files Browse the repository at this point in the history
  • Loading branch information
peterjc123 authored Jun 26, 2024
1 parent f358aa6 commit 3307632
Show file tree
Hide file tree
Showing 4 changed files with 302 additions and 2 deletions.
8 changes: 6 additions & 2 deletions tinynn/converter/operators/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,14 +322,17 @@ def replace_operator_input(
else:
self.graph.delete_edges(remove_edges)

def append_operator_input(self, node: ig.Vertex, new_tensor: tfl.Tensor):
def append_operator_input(self, node: ig.Vertex, new_tensor: tfl.Tensor, as_intermediate: bool = False):
"""Add a new input (or intermediate) tensor to an op node
Args:
node (ig.Vertex): An op node
new_tensor (tfl.Tensor): The tensor to be added
as_intermediate (bool): When True, the tensor is appended to the op's
`intermediates` list instead of its `inputs` list. Defaults to False.
"""
# NOTE(review): the unconditional append below appears to be the removed side of
# this diff hunk (the if/else that follows is its replacement) - confirm against
# the post-commit file before reading both as live code.
node['op'].inputs.append(new_tensor)
# Record the tensor on the op itself: intermediates and regular inputs are kept
# in separate lists.
if as_intermediate:
node['op'].intermediates.append(new_tensor)
else:
node['op'].inputs.append(new_tensor)
# The graph node and the tensor -> op edge are created in both cases.
new_node = self.add_nodes([new_tensor])[0]
edge = self.graph.add_edge(new_node, node, name=new_tensor.name, label=new_tensor.name)
log.debug(f'NEW EDGE: {new_node["label"]} -> {node["label"]} {self.tensor_map[edge["name"]]}')
Expand Down Expand Up @@ -566,6 +569,7 @@ def collect_operators(
op.op.index = idx
op.tfl_inputs_idx = [x.index for x in op.inputs]
op.tfl_outputs_idx = [x.index for x in op.outputs]
op.tfl_intermediates_idx = [x.index for x in op.intermediates]
result.append(op)
return result

Expand Down
19 changes: 19 additions & 0 deletions tinynn/converter/operators/hybrid_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,25 @@ def int16_lstm_pass(self):
)
node['op'].inputs[cell_state_idx].dtype = node['op'].inputs[cell_state_idx].tensor.dtype

# Add intermediates for int8x8_16 lstm
name = node['op'].outputs[0].name
input_to_input_intermediate = tfl.Tensor(np.zeros(0, dtype='float32'), f'{name}_intermediate_1')
input_to_forget_intermediate = tfl.Tensor(np.zeros(0, dtype='float32'), f'{name}_intermediate_2')
input_to_cell_intermediate = tfl.Tensor(np.zeros(0, dtype='float32'), f'{name}_intermediate_3')
input_to_output_intermediate = tfl.Tensor(np.zeros(0, dtype='float32'), f'{name}_intermediate_4')
effective_hidden_scale_intermediate = tfl.Tensor(
tfl.FakeQuantTensor(np.zeros(0, dtype='int8'), node['op'].outputs[0].quantization.scale, 0),
f'{name}_intermediate_5',
)

actions.append((self.graph.append_operator_input, (node, input_to_input_intermediate, True)))
actions.append((self.graph.append_operator_input, (node, input_to_forget_intermediate, True)))
actions.append((self.graph.append_operator_input, (node, input_to_cell_intermediate, True)))
actions.append((self.graph.append_operator_input, (node, input_to_output_intermediate, True)))
actions.append(
(self.graph.append_operator_input, (node, effective_hidden_scale_intermediate, True))
)

for func, args in actions:
func(*args)

Expand Down
5 changes: 5 additions & 0 deletions tinynn/converter/operators/tflite/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def build(self, builder: flatbuffers.Builder) -> Offset:
class BaseOperator(object):
inputs: typing.List['Tensor']
outputs: typing.List['Tensor']
intermediates: typing.List['Tensor']
op: OpCode
tfl_op: Offset
tfl_inputs_idx: typing.Iterable[int]
Expand All @@ -54,22 +55,26 @@ class BaseOperator(object):
def __init__(self, op: int, inputs: typing.List['Tensor'], outputs: typing.List['Tensor'], op_version: int = 1):
    """Set up the operator with its tensors and opcode.

    Args:
        op (int): Operator code, wrapped in an `OpCode` together with `op_version`
        inputs (typing.List[Tensor]): Input tensors, stored by reference
        outputs (typing.List[Tensor]): Output tensors, stored by reference
        op_version (int): Operator version passed to `OpCode`. Defaults to 1.
    """
    # Tensors attached to this operator; intermediates always start out empty.
    self.inputs, self.outputs = inputs, outputs
    self.intermediates = []
    self.op = OpCode(op, op_version)

    # Build-time state: the operator offset and the per-list tensor indices.
    # Each index list is a distinct object so they can be replaced independently.
    self.tfl_op = 0
    self.tfl_inputs_idx, self.tfl_outputs_idx, self.tfl_intermediates_idx = [], [], []

    self.extra_hints = {}

def build(self, builder: flatbuffers.Builder) -> Offset:
    """Write this operator into `builder` and return the resulting offset.

    The input, output and intermediate index vectors are created first, then
    the operator table is opened, filled and closed. The offset returned by
    `OperatorEnd` is also stored on `self.tfl_op`.

    Args:
        builder (flatbuffers.Builder): The builder to write into

    Returns:
        Offset: The offset of the finished operator table
    """
    # Vectors are created before OperatorStart, in inputs/outputs/intermediates order.
    inputs_vec = create_numpy_array(builder, tflite.Operator.Inputs, self.tfl_inputs_idx)
    outputs_vec = create_numpy_array(builder, tflite.Operator.Outputs, self.tfl_outputs_idx)
    intermediates_vec = create_numpy_array(
        builder, tflite.Operator.Intermediates, self.tfl_intermediates_idx
    )

    tflite.OperatorStart(builder)
    tflite.OperatorAddOpcodeIndex(builder, self.op.index)
    # Attach the vectors in the same order they were created.
    for add_field, vec in (
        (tflite.OperatorAddInputs, inputs_vec),
        (tflite.OperatorAddOutputs, outputs_vec),
        (tflite.OperatorAddIntermediates, intermediates_vec),
    ):
        add_field(builder, vec)

    op_offset = tflite.OperatorEnd(builder)
    self.tfl_op = op_offset
    return op_offset
Expand Down
Loading

0 comments on commit 3307632

Please sign in to comment.