[tracer&quantizer&converter] int16 cell state qparams (#340)
* [tracer&quantizer&converter] int16 cell state qparams

* minor fixes
peterjc123 authored Jul 9, 2024
1 parent 6d0bfb5 commit 25ced94
Showing 7 changed files with 111 additions and 4 deletions.
8 changes: 5 additions & 3 deletions tinynn/converter/base.py
@@ -10,8 +10,8 @@
 from .operators.op_version import OPVersioner
 from .operators.tflite import Tensor
 from .operators.torch import OPERATOR_CONVERTER_DICT
-from .operators.torch.base import NoTrackOperator, TrackQParamsOperator
-from .operators.torch.aten import ATenDequantizeOperator
+from .operators.torch.base import NoTrackOperator, TrackRevQParamsOperator, TrackQParamsOperator
+from .operators.torch.aten import ATenDequantizeOperator, ATenQuantizePerTensorOperator
 from ..util.converter_util import generate_converter_config
 from ..util.util import get_logger

@@ -436,6 +436,8 @@ def init_operations(self):
 if no_track_flag:
     if converter_type == ATenDequantizeOperator:
         converter_type = TrackQParamsOperator
+    elif converter_type == ATenQuantizePerTensorOperator:
+        converter_type = TrackRevQParamsOperator
     else:
         converter_type = NoTrackOperator
 converter = converter_type(
@@ -456,7 +458,7 @@ def init_operations(self):
 if k != 'prim::Constant':
     log.debug(f'{k} {converter.input_names} -> {converter.output_names} {converter_type.__name__}')
 # Don't fetch attrs and schemas for non-tracking nodes
-if converter_type not in (NoTrackOperator, TrackQParamsOperator):
+if converter_type not in (NoTrackOperator, TrackRevQParamsOperator, TrackQParamsOperator):
     try:
         attrs = converter.fetch_all_attrs(node)
     except StopIteration:
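Taken together, these hunks keep quantization-parameter tracking alive even for nodes that are otherwise skipped during hybrid conversion: aten::dequantize nodes keep recording qparams forward (keyed by their output name, via TrackQParamsOperator and q_mapping), while aten::quantize_per_tensor nodes now also record them in reverse (keyed by their float input name, via the new TrackRevQParamsOperator and the rev_q_mapping dict added below). A toy sketch of the two keying schemes, using plain dicts and made-up tensor names rather than the tinynn classes:

q_mapping = {}       # dequantize: float output name -> qparams of the quantized input
rev_q_mapping = {}   # quantize_per_tensor: float input name -> qparams of the quantized output

def track_dequantize(output_name, input_qparams):
    # the quantized *input* carries the qparams; file them under the float
    # output so later consumers of that tensor can recover them
    q_mapping[output_name] = input_qparams

def track_quantize(input_name, output_qparams):
    # the quantized *output* carries the qparams; file them under the float
    # input (e.g. the LSTM cell output) that produced it
    rev_q_mapping[input_name] = output_qparams

track_quantize('lstm_cell_output', (0.05, 0))   # (scale, zero_point), illustrative values
assert rev_q_mapping['lstm_cell_output'] == (0.05, 0)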
1 change: 1 addition & 0 deletions tinynn/converter/operators/graph.py
@@ -36,6 +36,7 @@ def __init__(self) -> None:
 self.output_transpose = None
 self.node_op_counter = 0
 self.q_mapping = {}
+self.rev_q_mapping = {}
 self.transform_store = {}
 self.constant_mapping = {}

5 changes: 4 additions & 1 deletion tinynn/converter/operators/hybrid_quantizer.py
@@ -95,7 +95,10 @@ def int16_lstm_pass(self):

 cell_state_indices = CELL_STATE_MAPPING.get(node['node_type'])
 for cell_state_idx in cell_state_indices:
-    node['op'].inputs[cell_state_idx].quantization = tfl.QuantizationParameters(1 / 32768, 0)
+    q_cell_output = self.graph.rev_q_mapping[node['op'].extra_hints['cell_output']].quantization
+    q_cell_max = q_cell_output.scale * (127 - q_cell_output.zero_point)
+    cell_pot = np.pow(2, np.maximum(np.ceil(np.log2(q_cell_max)), 0)).item()
+    node['op'].inputs[cell_state_idx].quantization = tfl.QuantizationParameters(cell_pot / 32768, 0)
     node['op'].inputs[cell_state_idx].tensor = (
         node['op'].inputs[cell_state_idx].tensor.astype(np.int16)
     )
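Before this change, the int16 cell state was pinned to scale 1/32768, i.e. a representable range of [-1, 1). The pass now looks up the qparams that aten::quantize_per_tensor assigned to the LSTM cell output (through rev_q_mapping and the cell_output hint stored by the LSTM converter in the aten.py hunk below) and widens the range to the smallest power of two covering the observed maximum, which matches the power-of-two cell-state scale that TFLite's integer LSTM kernels expect. A worked example, assuming an int8 cell output with scale 0.05 and zero point 0 (illustrative values):

import numpy as np

scale, zero_point = 0.05, 0                                    # assumed int8 qparams of the cell output
q_cell_max = scale * (127 - zero_point)                        # 6.35, largest representable cell value
cell_pot = 2.0 ** max(np.ceil(np.log2(q_cell_max)).item(), 0)  # 8.0, next power of two (floored at 1)
cell_state_scale = cell_pot / 32768                            # 0.000244140625 = 8 / 2**15
print(cell_state_scale)                                        # int16 cell state now spans [-8, 8)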
1 change: 1 addition & 0 deletions tinynn/converter/operators/torch/aten.py
@@ -430,6 +430,7 @@ def parse_common(
     pack_op.extra_hints['warn_on_unused'] = False
     ops.append(pack_op)
 else:
+    ops[-1].extra_hints['cell_output'] = self.output_names[-1]
     common_names = set(self.output_names[1:]) & set(graph_converter.outputs)
     assert len(common_names) == 0, (
         f"Please remove the LSTM state outputs ({common_names}) from the model. Alternatively, you can try"
10 changes: 10 additions & 0 deletions tinynn/converter/operators/torch/base.py
@@ -772,6 +772,16 @@ def parse(self, node, attrs, args, graph_converter):
         graph_converter.q_mapping[self.output_names[0]] = t


+class TrackRevQParamsOperator(OperatorConverter):
+    def parse(self, node, attrs, args, graph_converter):
+        super().parse(node, attrs, args, graph_converter)
+
+        self.run(node)
+
+        t = self.to_tfl_tensors(self.output_names, self.output_tensors)[0]
+        graph_converter.rev_q_mapping[self.input_names[0]] = t
+
+
 class TrackConstantOperator(OperatorConverter):
     def parse(self, node, attrs, args, graph_converter):
         super().parse(node, attrs, args, graph_converter)
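TrackRevQParamsOperator mirrors TrackQParamsOperator directly above it, except that the converted tensor is filed under the op's first input name rather than its first output name. A minimal runnable stand-in for the round trip between the three files (SimpleNamespace stubs replace the tinynn classes; the tensor name is illustrative):

from types import SimpleNamespace

graph = SimpleNamespace(rev_q_mapping={})
extra_hints = {'cell_output': 'lstm_cell_output'}   # set by the LSTM converter in aten.py

# what TrackRevQParamsOperator effectively records at conversion time
quantized_t = SimpleNamespace(quantization=SimpleNamespace(scale=0.05, zero_point=0))
graph.rev_q_mapping['lstm_cell_output'] = quantized_t

# what int16_lstm_pass later reads back in hybrid_quantizer.py
q = graph.rev_q_mapping[extra_hints['cell_output']].quantization
print(q.scale, q.zero_point)   # 0.05 0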
53 changes: 53 additions & 0 deletions tinynn/graph/quantization/quantizer.py
@@ -2572,6 +2572,59 @@ def _is_batch_norm_1d(node, custom_data):
 if action in ('disable', 'rewrite'):
     disable_quantize_op_list[module_cls] = None

+def _is_rewritable_lstm_node(node, custom_data):
+    cur_module = node.module
+    cur_class = type(cur_module)
+    return cur_class == nn.LSTM
+
+if self.quantize_op_action.get(nn.LSTM, 'enable') == 'rewrite':
+    rewritable_lstm_nodes = graph.filter_forward_nodes(_is_rewritable_lstm_node)
+    fake_dequant_cls = torch_q.DeQuantStub
+    for idx, node in enumerate(rewritable_lstm_nodes):
+        cell_state = node.next_tensors[1][1]
+
+        fake_dequant = fake_dequant_cls()
+
+        fake_dequant_name = f'fake_dequant_rewrite_{idx}'
+
+        graph.module_unique_name_dict[id(fake_dequant)] = fake_dequant_name
+        graph.module_original_name_dict[id(fake_dequant)] = fake_dequant_name
+
+        module_constructor_lines[id(fake_dequant)] = f'{qualified_name(fake_dequant_cls)}()'
+
+        new_node = graph.insert_new_after(
+            node, fake_dequant, [cell_state], [[1, 1]], before_node=node.next_nodes[0]
+        )
+
+        with override_current_trace_graph(graph):
+            size_func = TraceFunction(
+                'torch.Tensor.size', is_class=True, prefix='fake_dequant_rewrite_'
+            ).parse_args(new_node.next_tensors[0], -1)
+
+        size_node = graph.insert_new_after(
+            new_node,
+            size_func,
+            [new_node.next_tensors[0]],
+            [None],
+            next_tensors=[torch.tensor(new_node.next_tensors[0].size(-1))],
+            before_node=node.next_nodes[0],
+        )
+        size_len = len(node.next_tensors[0].shape)
+
+        with override_current_trace_graph(graph):
+            expand_func = TraceFunction(
+                'torch.Tensor.expand', is_class=True, prefix='fake_dequant_rewrite_'
+            ).parse_args(node.next_tensors[0], *((-1,) * (size_len - 1)), size_node.next_tensors[0])
+
+        graph.insert_between(
+            node, node.next_nodes[0], expand_func, tensor_ptrs=[id(node.next_tensors[0])], move_idx=True
+        )
+        expand_node = graph.nodes_map[expand_func.unique_name]
+        size_node.next_nodes.append(expand_node)
+        expand_node.prev_nodes.append(node)
+        expand_node.prev_tensors.append(size_node.next_tensors[0])
+        expand_node.prev_indices.append(None)
+
 def _is_not_quantizable(node, custom_data):
     cur_module = node.module
     cur_class = type(cur_module)
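The rewrite threads the LSTM cell state back into the main output path so it receives its own quantize/dequantize treatment: a DeQuantStub is inserted on the cell state (next_tensors[1][1]), its last dimension is read via torch.Tensor.size(-1), and the LSTM output is passed through a shape-preserving expand whose last argument is that size, creating a data dependency on the dequantized cell state without changing any values. A rough eager-mode picture of the traced result, assuming a single-layer unidirectional nn.LSTM (names are illustrative, not tinynn internals):

import torch
import torch.nn as nn
from torch.ao.quantization import DeQuantStub

lstm = nn.LSTM(input_size=4, hidden_size=8)
fake_dequant_rewrite_0 = DeQuantStub()

x = torch.randn(5, 1, 4)                    # (seq_len, batch, input_size)
output, (h, c) = lstm(x)                    # output: (5, 1, 8), c: (1, 1, 8)

c_dq = fake_dequant_rewrite_0(c)            # cell state routed through the stub
last_dim = c_dq.size(-1)                    # 8, the hidden size
output = output.expand(-1, -1, last_dim)    # shape-preserving; ties output to c_dq

print(output.shape)                         # torch.Size([5, 1, 8])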
37 changes: 37 additions & 0 deletions tinynn/graph/tracer.py
@@ -2664,6 +2664,43 @@ def insert_after(self, node: TraceNode, module, next_tensors: typing.Optional[ty
         next_node.module.replace_tensor_name(old_unique_name, new_unique_name)
         next_node.module.update_args_string()

+    def insert_new_after(
+        self,
+        node,
+        module_or_func,
+        prev_tensors: typing.List[torch.Tensor],
+        prev_indices: typing.List[torch.Tensor],
+        next_tensors: typing.Optional[typing.List[torch.Tensor]] = None,
+        before_node: typing.Optional[TraceNode] = None,
+    ):
+        assert type(module_or_func) != TraceNode
+        new_node = TraceNode(module_or_func, cur_graph=self)
+
+        if next_tensors is None:
+            next_tensors = [t.clone() for t in prev_tensors]
+
+        for new_t, new_i in zip(next_tensors, prev_indices):
+            self.tensor_pre_node_dict[id(new_t)] = new_node.unique_name
+            if new_i is not None:
+                self.tensor_pre_index_dict[id(new_t)] = new_i
+
+        new_node.prev_tensors.extend(prev_tensors)
+        new_node.next_tensors.extend(next_tensors)
+
+        new_node.prev_indices.extend(prev_indices)
+        new_node.prev_nodes.append(node)
+
+        node.next_nodes.append(new_node)
+
+        if before_node is not None:
+            idx = self.forward_nodes.index(before_node)
+            self.forward_nodes.insert(idx, new_node)
+        else:
+            self.forward_nodes.append(new_node)
+        self.nodes_map[new_node.unique_name] = new_node
+
+        return new_node
+
     def insert_between(
         self,
         prev_node: TraceNode,
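The new insert_new_after helper builds a TraceNode from a module or TraceFunction, wires its inputs to existing tensors of a predecessor node (with prev_indices recording where each tensor sits in that node's outputs), clones the inputs as outputs when next_tensors is not given, and splices the node into forward_nodes, optionally before a given node to keep topological order. An annotated restatement of the call from the quantizer rewrite above (variable names come from that hunk):

# Insert `fake_dequant` right after the traced LSTM `node`, feeding it the
# cell state, which lives at index [1, 1] of the LSTM's outputs
# (output, (h_n, c_n)); keep it ahead of the LSTM's first consumer so the
# regenerated forward() stays topologically ordered.
new_node = graph.insert_new_after(
    node,                            # predecessor TraceNode (the nn.LSTM)
    fake_dequant,                    # module to wrap in the new TraceNode
    [cell_state],                    # prev_tensors: taken from node.next_tensors
    [[1, 1]],                        # prev_indices: cell state's position in those outputs
    before_node=node.next_nodes[0],  # insertion point in graph.forward_nodes
)                                    # next_tensors omitted -> clones of prev_tensors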
