[converter] int16 dynamic quantized lstm (#335)
* [converter] int16 dynamic quantized lstm

* rebase to main

* refine
peterjc123 authored Jun 25, 2024
1 parent 3752f6d commit f358aa6
Showing 6 changed files with 172 additions and 8 deletions.
1 change: 1 addition & 0 deletions docs/FAQ.md
@@ -232,6 +232,7 @@ Note: These state variables are all two-dimensional with the shape of `[batch_si
Usually, when the hidden size is large enough (128+), the LSTM OP will be time-consuming in the TFLite backend. In this case, consider using dynamic range quantization to optimize its performance, see [dynamic.py](../examples/converter/dynamic.py).
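As a rough illustration (the maintained end-to-end version lives in [dynamic.py](../examples/converter/dynamic.py), and the exact flags used there may differ), dynamic range quantization is driven entirely by the converter:

```python
# Minimal sketch, assuming a float `model` and a representative `dummy_input`.
from tinynn.converter import TFLiteConverter

converter = TFLiteConverter(
    model,
    dummy_input,
    tflite_path='out/dynamic_quant_model.tflite',
    quantize_target_type='int8',
    # Quantize the weights offline and keep activations in float at runtime
    hybrid_quantization_from_float=True,
    hybrid_asymmetric_inputs=True,
)
converter.convert()
```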

You may also try out static quantization for LSTMs when you have PyTorch 1.13+. But it may take much more effort to minimize the quantization error, and you will probably need to inspect the quantization error carefully layer by layer.
We also support int16 LSTM via the combination of static quantization and LSTM-only dynamic quantization. Please take a look at [ptq_with_dynamic_q_lstm.py](../examples/quantization/ptq_with_dynamic_q_lstm.py).
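Condensed from that example (the full runnable script is shown later in this commit), the two key pieces are excluding the LSTM from static quantization via `quantize_op_action` and enabling int16 hybrid quantization in the converter:

```python
# Sketch only; see examples/quantization/ptq_with_dynamic_q_lstm.py for the complete flow.
quantizer = PostQuantizer(
    model, dummy_input, work_dir='out', config={'quantize_op_action': {nn.LSTM: 'rewrite'}}
)
ptq_model = quantizer.quantize()
# ... calibrate with representative data, then call `quantizer.convert(ptq_model)` ...

converter = TFLiteConverter(
    ptq_model,
    dummy_input,
    tflite_path='out/ptq_with_dynamic_quant_lstm_model.tflite',
    quantize_target_type='int8',
    rewrite_quantizable=True,
    hybrid_quantization_from_float=True,
    hybrid_int16_lstm=True,
)
converter.convert()
```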

#### What if my model runs slower when dynamic quantization is enabled?
Please refer to [dynamic_with_selection.py](../examples/converter/dynamic_with_selection.py) for selective dynamic quantization.
1 change: 1 addition & 0 deletions docs/FAQ_zh-CN.md
@@ -231,6 +231,7 @@ Note: These state variables are all two-dimensional, with the shape `[batch_size, hidden_size or
Usually, when the hidden size is large (e.g. 128 and above), LSTM models are time-consuming in the TFLite backend. In this case, consider using dynamic range quantization to optimize performance; see [dynamic.py](../examples/converter/dynamic.py)

For users on PyTorch 1.13+, you may also try static quantization for LSTMs. However, fully quantizing an LSTM is usually difficult and may require careful per-layer analysis of the quantization error.
We also support the int16 LSTM available in newer versions of TFLite; see [ptq_with_dynamic_q_lstm.py](../examples/quantization/ptq_with_dynamic_q_lstm.py)

#### What if my model runs slower when dynamic quantization is enabled?
Please refer to [dynamic_with_selection.py](../examples/converter/dynamic_with_selection.py) for enabling dynamic quantization selectively.
86 changes: 86 additions & 0 deletions examples/quantization/ptq_with_dynamic_q_lstm.py
@@ -0,0 +1,86 @@
import argparse
import os
import sys

CURRENT_PATH = os.path.abspath(os.path.dirname(__file__))

sys.path.insert(1, os.path.join(CURRENT_PATH, '../../'))

import torch
import torch.nn as nn

from tinynn.converter import TFLiteConverter
from tinynn.graph.quantization.quantizer import PostQuantizer


class SimpleLSTM(nn.Module):
    def __init__(self, in_dim, out_dim, layers, num_classes):
        super(SimpleLSTM, self).__init__()
        self.lstm = torch.nn.LSTM(in_dim, out_dim, layers)
        self.fc = torch.nn.Linear(out_dim, num_classes)
        self.relu = torch.nn.ReLU()

    def forward(self, inputs):
        out, _ = self.lstm(inputs)
        out = self.fc(out)
        out = self.relu(out)
        return out


def main_worker(args):
    model = SimpleLSTM(args.input_size, args.hidden_size, args.num_layers, args.num_classes)

    # Provide a viable input for the model
    dummy_input = torch.rand((args.steps, args.batch_size, args.input_size))

    # Please see 'ptq.py' for more details on using PostQuantizer.
    quantizer = PostQuantizer(model, dummy_input, work_dir='out', config={'quantize_op_action': {nn.LSTM: 'rewrite'}})
    ptq_model = quantizer.quantize()

    print(ptq_model)

    # Calibrate the observers with a few representative batches
    for _ in range(5):
        ptq_model(torch.rand_like(dummy_input))

    with torch.no_grad():
        ptq_model.eval()
        ptq_model.cpu()

        # The step below converts the model to an actual quantized model, which uses the quantized kernels.
        ptq_model = quantizer.convert(ptq_model)

        print(ptq_model)

        # When converting quantized models, please ensure the quantization backend is set.
        torch.backends.quantized.engine = quantizer.backend

        # The code section below is used to convert the model to the TFLite format
        converter = TFLiteConverter(
            ptq_model,
            dummy_input,
            tflite_path='out/ptq_with_dynamic_quant_lstm_model.tflite',
            quantize_target_type='int8',
            rewrite_quantizable=True,
            # Enable hybrid quantization
            hybrid_quantization_from_float=True,
            # Enable hybrid per-channel quantization (lower q-loss, but slower)
            hybrid_per_channel=False,
            # Use asymmetric inputs for hybrid quantization (probably lower q-loss, but a bit slower)
            hybrid_asymmetric_inputs=False,
            # Enable int16 hybrid lstm quantization
            hybrid_int16_lstm=True,
        )
        converter.convert()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--steps', type=int, default=20)
    parser.add_argument('--batch-size', type=int, default=1)
    parser.add_argument('--hidden-size', type=int, default=512)
    parser.add_argument('--input-size', type=int, default=128)
    parser.add_argument('--num-layers', type=int, default=1)
    parser.add_argument('--num-classes', type=int, default=10)

    args = parser.parse_args()
    main_worker(args)
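A quick way to sanity-check the generated file is to load it with the TensorFlow Lite interpreter. This snippet is not part of the example above and assumes the `tensorflow` package is installed:

```python
import numpy as np
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path='out/ptq_with_dynamic_quant_lstm_model.tflite')
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Feed random data of the expected shape/dtype and run a single inference
shape = input_details[0]['shape']
dtype = input_details[0]['dtype']
if np.issubdtype(dtype, np.integer):
    data = np.random.randint(-128, 127, size=shape, dtype=dtype)
else:
    data = np.random.rand(*shape).astype(dtype)

interpreter.set_tensor(input_details[0]['index'], data)
interpreter.invoke()
print(interpreter.get_tensor(output_details[0]['index']).shape)
```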
5 changes: 5 additions & 0 deletions tinynn/converter/base.py
@@ -52,6 +52,7 @@ def __init__(
        conv_transpose_with_bias: bool = True,
        max_transpose_dims: int = -1,
        hybrid_conv: bool = True,
        hybrid_int16_lstm: bool = False,
        unroll_rnn: bool = False,
        separated_rnn_gate_calc: bool = False,
        bypass_elementwise_passthrough_constraint: bool = False,
@@ -104,6 +105,7 @@ def __init__(
            conv_transpose_with_bias (bool): ConvTranspose ops with bias. Defaults to True
            max_transpose_dims (int): Max dimensions for the `Transpose` op. Defaults to -1, which means unlimited
            hybrid_conv (bool): Enable hybrid quantization for Conv2d and DepthwiseConv2d. Defaults to True
            hybrid_int16_lstm (bool): Enable hybrid int16 quantization for LSTM. Defaults to False
            unroll_rnn (bool): Unrolling LSTM (translate LSTM to separate ops). Defaults to False
            separated_rnn_gate_calc (bool): Separated calculation for every gate in RNN. Effective only when \
                `unroll_rnn=True`. Defaults to False
@@ -162,6 +164,7 @@ def __init__(
        self.conv_transpose_with_bias = conv_transpose_with_bias
        self.max_transpose_dims = max_transpose_dims
        self.hybrid_conv = hybrid_conv
        self.hybrid_int16_lstm = hybrid_int16_lstm
        self.unroll_rnn = unroll_rnn
        self.separated_rnn_gate_calc = separated_rnn_gate_calc
        self.bypass_elementwise_passthrough_constraint = bypass_elementwise_passthrough_constraint
@@ -529,6 +532,7 @@ def convert(self):
            self.bypass_elementwise_passthrough_constraint,
            self.group_tensors,
            self.conv_transpose_with_bias,
            self.hybrid_int16_lstm,
        )
        optimizer.optimize()

@@ -541,6 +545,7 @@ def convert(self):
                self.hybrid_q_type,
                self.hybrid_per_channel,
                self.hybrid_conv,
                self.hybrid_int16_lstm,
                self.hybrid_gen_single_op_models,
                self.hybrid_config,
            )
71 changes: 69 additions & 2 deletions tinynn/converter/operators/hybrid_quantizer.py
@@ -5,7 +5,7 @@
import numpy as np
import torch

from tinynn.util.util import get_logger
from tinynn.util.util import class_conditional, get_logger

from . import tflite as tfl
from .base import ExtendedOperator
@@ -19,18 +19,33 @@
    ExtendedOperator.BIDIRECTIONAL_SEQUENCE_LSTM: [1, 2, 3, 4, 5, 6, 7, 8, 18, 19, 20, 21, 22, 23, 24, 25],
}

BIAS_MAPPING = {
    ExtendedOperator.UNIDIRECTIONAL_SEQUENCE_LSTM: {1: 12, 2: 13, 3: 14, 4: 15},
}

STATE_MAPPING = {
    ExtendedOperator.UNIDIRECTIONAL_SEQUENCE_LSTM: [18],
}

CELL_STATE_MAPPING = {
    ExtendedOperator.UNIDIRECTIONAL_SEQUENCE_LSTM: [19],
}


class HybridQuantizer(object):
    graph: CommonGraph

    def __init__(self, graph, asymmetric, q_type, per_channel, enable_conv, gen_single_op_models, config) -> None:
    def __init__(
        self, graph, asymmetric, q_type, per_channel, enable_conv, enable_int16_lstm, gen_single_op_models, config
    ) -> None:
        super().__init__()

        self.graph = graph
        self.asymmetric = asymmetric
        self.q_type = q_type
        self.per_channel = per_channel
        self.enable_conv = enable_conv
        self.enable_int16_lstm = enable_int16_lstm
        self.gen_single_op_models = gen_single_op_models

        if config is None:
@@ -40,6 +55,54 @@ def __init__(self, graph, asymmetric, q_type, per_channel, enable_conv, gen_sing

    def quantize(self):
        self.quantize_pass()
        self.int16_lstm_pass()

    @class_conditional(lambda self: self.enable_int16_lstm)
    def int16_lstm_pass(self):
        filtered_nodes = self.graph.graph.vs.select(functools.partial(is_int16_quantizable_lstm_node))

        actions = []
        replaced_tensors = {}
        for node in filtered_nodes:
            if self.config.get(node['outputs'][0], True) is False:
                continue

            if node['node_type'] == ExtendedOperator.UNIDIRECTIONAL_SEQUENCE_LSTM:
                lstm_input = node['op'].inputs[0]
                if lstm_input.dtype == np.int8:
                    bias_indices = BIAS_MAPPING.get(node['node_type'])
                    for weight_idx, bias_idx in bias_indices.items():
                        bias_t = node['op'].inputs[bias_idx]
                        weight_t = node['op'].inputs[weight_idx]
                        name = bias_t.name
                        new_name = f'{name}_hybrid_q'
                        bias_a = np.frombuffer(bias_t.buffer.data, dtype='float32').reshape(bias_t.shape)
                        bias = torch.from_numpy(bias_a.copy())

                        bias_scale = weight_t.quantization.scale * lstm_input.quantization.scale
                        new_bias = torch.round(bias.detach() / bias_scale).to(dtype=torch.int32)
                        new_bias_t = tfl.Tensor(tfl.FakeQuantTensor(new_bias, bias_scale, 0), new_name)

                        replaced_tensors.setdefault(new_bias_t.name, new_bias_t)
                        new_bias_t = replaced_tensors[new_bias_t.name]
                        actions.append((self.graph.replace_operator_input, (node, bias_idx, new_bias_t)))

                    state_indices = STATE_MAPPING.get(node['node_type'])
                    for state_idx in state_indices:
                        node['op'].inputs[state_idx].quantization = copy.deepcopy(node['op'].outputs[0].quantization)
                        node['op'].inputs[state_idx].tensor = node['op'].inputs[state_idx].tensor.astype(np.int8)
                        node['op'].inputs[state_idx].dtype = node['op'].inputs[state_idx].tensor.dtype

                    cell_state_indices = CELL_STATE_MAPPING.get(node['node_type'])
                    for cell_state_idx in cell_state_indices:
                        node['op'].inputs[cell_state_idx].quantization = tfl.QuantizationParameters(0.00048828125, 0)
                        node['op'].inputs[cell_state_idx].tensor = (
                            node['op'].inputs[cell_state_idx].tensor.astype(np.int16)
                        )
                        node['op'].inputs[cell_state_idx].dtype = node['op'].inputs[cell_state_idx].tensor.dtype

        for func, args in actions:
            func(*args)

    def quantize_pass(self):
        filtered_nodes = self.graph.graph.vs.select(functools.partial(is_quantizable_node, with_conv=self.enable_conv))
@@ -116,6 +179,10 @@ def is_quantizable_node(vertex: ig.Vertex, with_conv: bool):
    )


def is_int16_quantizable_lstm_node(vertex: ig.Vertex):
    return vertex['node_type'] in (ExtendedOperator.UNIDIRECTIONAL_SEQUENCE_LSTM,)


def quantize(name, tensor, dtype, qscheme, axis=None, q_type=np.uint8):
    assert qscheme in (torch.per_tensor_symmetric, torch.per_channel_symmetric)

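For intuition, the rewrites performed by `int16_lstm_pass` above boil down to simple scale arithmetic. A standalone sketch with toy numbers (not taken from a real model):

```python
import numpy as np

# Toy scales for illustration only
input_scale = 0.05    # scale of the int8 LSTM input
weight_scale = 0.002  # per-tensor scale of one LSTM weight matrix
bias_f = np.array([0.013, -0.027, 0.004], dtype=np.float32)

# Biases are requantized to int32 with scale = weight_scale * input_scale,
# mirroring `torch.round(bias / bias_scale)` in the pass above.
bias_scale = weight_scale * input_scale
bias_q = np.round(bias_f / bias_scale).astype(np.int32)
print(bias_q)               # [ 130 -270   40]
print(bias_q * bias_scale)  # close to the original float biases

# The hidden state input copies the output quantization and becomes int8,
# while the cell state becomes int16 with a fixed scale of 2 ** -11,
# i.e. 0.00048828125, covering roughly [-16, 16).
cell_scale = 2.0 ** -11
assert cell_scale == 0.00048828125
```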
16 changes: 10 additions & 6 deletions tinynn/converter/operators/optimize.py
@@ -51,6 +51,7 @@ def __init__(
        bypass_elementwise_passthrough_constraint: bool = False,
        group_tensors: bool = False,
        conv_transpose_with_bias: bool = True,
        hybrid_int16_lstm: bool = False,
    ) -> None:
        self.graph = graph
        self.fuse_tensor_count = 0
@@ -67,6 +68,7 @@ def __init__(
        self.bypass_elementwise_passthrough_constraint = bypass_elementwise_passthrough_constraint
        self.group_tensors = group_tensors
        self.conv_transpose_with_bias = conv_transpose_with_bias
        self.hybrid_int16_lstm = hybrid_int16_lstm

    def create_attr_tensor(
        self, tensor: tfl.Tensor, name: str = None, quantization: typing.Optional[tfl.QuantizationParameters] = None
@@ -1499,7 +1501,9 @@ def elementwise_reshape_transpose_passthrough_pass(self) -> int:
    @class_conditional(lambda self: self.rewrite_quantizable)
    def elementwise_op_quantize_passthrough_pass(self):
        edges = self.graph.graph.es.select(
            functools.partial(is_quantize_elementwise_op_edge, graph_converter=self.graph.graph)
            functools.partial(
                is_quantize_elementwise_op_edge, graph_converter=self.graph.graph, with_lstm=self.hybrid_int16_lstm
            )
        )
        pairs = ((self.graph.graph.vs[edge.source], self.graph.graph.vs[edge.target]) for edge in edges)
        filtered_nodes = (k[0] if k[0]['node_type'] != ExtendedOperator.DEQUANTIZE else k[1] for k in pairs)
@@ -3707,17 +3711,17 @@ def is_multi_output_op_node(vertex: ig.Vertex, graph_converter: ig.Graph):
    return vertex['node_type'] >= 0 and len(vertex['outputs']) > 1 and vertex.outdegree() > 0


def is_quantize_elementwise_op_edge(edge: ig.Edge, graph_converter: ig.Graph):
def is_quantize_elementwise_op_edge(edge: ig.Edge, graph_converter: ig.Graph, with_lstm: bool):
    source_vertex = graph_converter.vs[edge.source]
    target_vertex = graph_converter.vs[edge.target]
    return (
        (
            source_vertex['node_type'] == ExtendedOperator.DEQUANTIZE
            and is_quantizable_rewrite_op(target_vertex['node_type'], target_vertex['op'])
            and is_quantizable_rewrite_op(target_vertex['node_type'], target_vertex['op'], with_lstm)
        )
        or (
            target_vertex['node_type'] == ExtendedOperator.QUANTIZE
            and is_quantizable_rewrite_op(source_vertex['node_type'], source_vertex['op'])
            and is_quantizable_rewrite_op(source_vertex['node_type'], source_vertex['op'], with_lstm)
        )
    ) and target_vertex['op'].inputs[0].name in source_vertex['outputs']

@@ -3867,7 +3871,7 @@ def is_elementwise_unary_op(op_code: ExtendedOperator, op: tfl.BaseOperator):
    ) or is_elementwise_reduce_op(op_code, op)


def is_quantizable_rewrite_op(op_code: ExtendedOperator, op: tfl.BaseOperator):
def is_quantizable_rewrite_op(op_code: ExtendedOperator, op: tfl.BaseOperator, with_lstm: bool):
    return op_code in (
        ExtendedOperator.BATCH_MATMUL,
        ExtendedOperator.SOFTMAX,
@@ -3878,7 +3882,7 @@ def is_quantizable_rewrite_op(op_code: ExtendedOperator, op: tfl.BaseOperator):
        ExtendedOperator.RSQRT,
        ExtendedOperator.MAXIMUM,
        ExtendedOperator.MINIMUM,
    )
    ) or (with_lstm and op_code == ExtendedOperator.UNIDIRECTIONAL_SEQUENCE_LSTM)


def is_elementwise_binary_op(op_code: ExtendedOperator, op: tfl.BaseOperator):
