Commit
[quantizer] add lstm weight fake_quant to align tflite (#339)
* [example] add weight fake_quant when doing lstm dynamic quantization

* [quantizer] add lstm weight_fake_quant to align tflite

* [quantizer] move lstm_fake_quant impl to quantizer
zk1998 authored Jul 12, 2024
1 parent 25ced94 commit 250abdc
Showing 2 changed files with 60 additions and 0 deletions.
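
For context: a PyTorch LSTM stacks its four gate weight blocks (i, f, g, o) along dim 0 of weight_ih_l0 / weight_hh_l0, while TFLite's LSTM op stores per-gate weight tensors (quantized to int8 with per-tensor scales in dynamic-range mode). The commit therefore fake-quantizes each gate block separately during preparation so that calibration/QAT already sees that rounding error. Below is a standalone, illustrative sketch of the per-gate round trip; the sizes and helper code are examples, not TinyNN API.

import torch

hidden_size, input_size = 32, 16                      # example sizes
w_ih = torch.randn(4 * hidden_size, input_size)       # i, f, g, o gates stacked along dim 0

fake_quantized = []
for gate in torch.chunk(w_ih, 4):                     # one block per gate
    scale = (gate.abs().max() / 127.0).clamp_min(1e-6)    # symmetric int8 scale
    q = torch.clamp(torch.round(gate / scale), -127, 127)  # quantize to [-127, 127]
    fake_quantized.append(q * scale)                  # dequantize back to float
w_ih_fq = torch.cat(fake_quantized)                   # what calibration/QAT should see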
43 changes: 43 additions & 0 deletions tinynn/graph/quantization/quantizer.py
@@ -56,6 +56,7 @@
)
from tinynn.util.train_util import get_logger, get_module_device
from tinynn.util.util import import_from_path
from tinynn.graph.quantization.utils import fake_quantize

from . import fused_modules as fm

@@ -376,6 +377,8 @@ def __init__(self, model, dummy_input, work_dir: typing.Optional[str] = None, co
        if config is not None and 'layerwise_config' in config:
            self.layerwise_config.update(config['layerwise_config'])

        # Float LSTM weights stashed by `fake_quantize_lstm_weights` and restored
        # by `restore_lstm_weights` before conversion.
        self.lstm_origin_weight_dict = {}

    def parse_config(self, config: typing.Optional[dict]):
        default_values = {
            'rewrite_graph': True,
@@ -1142,6 +1145,9 @@ def new_no_observer_set():
idx = next_n.prev_nodes.index(n)
q.put((next_n, q_mod, state, idx))

        # When an action is configured for LSTMs (dynamic quantization) and the
        # backend is qnnpack, fake-quantize their weights so that calibration/QAT
        # runs against (approximately) the weight values TFLite will use.
        if self.quantize_op_action.get(nn.LSTM, None) and self.backend == 'qnnpack':
            self.fake_quantize_lstm_weights()

        return graph.module

def extra_qat_fusion_postprocess(self, graph):
@@ -3202,6 +3208,8 @@ def convert(self, q_model: nn.Module, backend: str = 'tflite') -> nn.Module:
nn.Module: The QAT/PTQ-converted model. When the backend is set to `pytorch`, it is used for validation \
in PyTorch only.
"""
        # Put the original float LSTM weights back before conversion; they were
        # fake-quantized in place during preparation.
        if self.quantize_op_action.get(nn.LSTM, None) and self.backend == 'qnnpack':
            self.restore_lstm_weights(q_model)

        for acp, post_acp, dq_name, q_name, activ_name, activ_type in self.extra_qparams_mappings:
            if backend != 'pytorch' and activ_type in ('relu', 'relu6', torch.nn.ReLU, torch.nn.ReLU6):
@@ -3504,6 +3512,38 @@ def freeze_fake_quantize_hook(mod, inp):
if n.endswith('.weight_fake_quant'):
hooks.append(m.register_forward_pre_hook(freeze_fake_quantize_hook))

    def fake_quantize_lstm_weights(self, asym=False, eps=1e-6):
        def _lstm_weight_fake_quantize(weight, asym=False, eps=1e-6):
            if weight.numel() == 0:
                return weight
            # Each of the four stacked gate blocks gets its own symmetric int8 range,
            # mirroring TFLite's per-gate LSTM weight tensors.
            quant_min, quant_max = -127, 127
            weight_parts = torch.chunk(weight, 4)
            weight_quant_parts = []
            for i in range(4):
                weight_quant_parts.append(
                    fake_quantize(weight_parts[i], asym=asym, eps=eps, quant_min=quant_min, quant_max=quant_max)
                )
            return torch.cat(weight_quant_parts)

        with torch.no_grad():
            for name, module in self.model.named_modules():
                if isinstance(module, torch.nn.LSTM):
                    for weight_name in ['weight_ih_l0', 'weight_hh_l0']:
                        # Stash the float original so `restore_lstm_weights` can put it
                        # back before conversion, then overwrite the weight in place.
                        weight = getattr(module, weight_name)
                        self.lstm_origin_weight_dict[f"{name}.{weight_name}"] = weight.clone()
                        weight.data.copy_(_lstm_weight_fake_quantize(weight, asym=asym, eps=eps))

    def restore_lstm_weights(self, model):
        with torch.no_grad():
            for name, module in model.named_modules():
                if isinstance(module, torch.nn.LSTM):
                    for weight_name in ['weight_ih_l0', 'weight_hh_l0']:
                        full_weight_name = f"{name}.{weight_name}"
                        if full_weight_name in self.lstm_origin_weight_dict:
                            getattr(module, weight_name).data.copy_(
                                self.lstm_origin_weight_dict[full_weight_name]
                            )
        self.lstm_origin_weight_dict.clear()


class BF16Quantizer(QATQuantizer):
def __init__(self, model, dummy_input, work_dir: typing.Optional[str] = None, config: typing.Optional[dict] = None):
@@ -3821,6 +3861,9 @@ def new_no_observer_set():
        if self.quantized_op_stats is not None:
            self.prepare_quantized_ops_pass(graph)

        # Same LSTM weight fake-quantization step as in the QAT preparation path above.
        if self.quantize_op_action.get(nn.LSTM, None) and self.backend == 'qnnpack':
            self.fake_quantize_lstm_weights()

        return graph.module


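The two methods added above follow a stash / overwrite / restore pattern: fake_quantize_lstm_weights() clones each float weight into lstm_origin_weight_dict and overwrites the parameter data in place, and restore_lstm_weights() copies the originals back so the converter receives unmodified float weights. A minimal standalone sketch of that pattern (the rounding line is only a stand-in for the real fake-quantization math):

import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=16, hidden_size=32)          # example module
stash = {}

with torch.no_grad():
    for name in ('weight_ih_l0', 'weight_hh_l0'):
        w = getattr(lstm, name)
        stash[name] = w.clone()                        # keep the float original
        w.data.copy_(torch.round(w * 64) / 64)         # stand-in for fake quantization

# ... calibration / QAT would run here against the fake-quantized weights ...

with torch.no_grad():
    for name, original in stash.items():
        getattr(lstm, name).data.copy_(original)       # restore floats before export
stash.clear()

Because only the parameter data is swapped via .data.copy_(), the Parameter objects, registered hooks, and any optimizer state attached to the LSTM stay untouched.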
17 changes: 17 additions & 0 deletions tinynn/graph/quantization/utils.py
@@ -83,3 +83,20 @@ def clamp_with_fusion_(x: torch.Tensor, min_val: float, max_val: float) -> torch
    if not x.is_quantized:
        return torch.clamp_(x, min_val, max_val)
    return x


def fake_quantize(tensor, asym, eps, quant_max, quant_min):
    """Fake-quantize `tensor` per-tensor into [quant_min, quant_max], symmetric unless `asym` is set."""
    min_val, max_val = torch.aminmax(tensor)
    device = tensor.device
    # Keep zero_point as int32, matching the dtype produced by the asymmetric branch
    zero_point = torch.zeros(min_val.size(), dtype=torch.int32, device=device)
    if not asym:
        # Symmetric: the scale covers the larger of |min| and |max|, zero_point stays 0
        max_val_pos = torch.max(-min_val, max_val)
        scale = max_val_pos / (float(quant_max - quant_min) / 2)
        scale = torch.max(scale, torch.tensor(eps, device=device))
    else:
        # Asymmetric: affine scale/zero_point derived from the full [min, max] range
        scale = (max_val - min_val) / float(quant_max - quant_min)
        scale = torch.max(scale, torch.tensor(eps, device=device))
        zero_point = quant_min - torch.round(min_val / scale).to(torch.int)
        zero_point = torch.clamp(zero_point, quant_min, quant_max)
    # do fake quantize (quantize then dequantize, staying in float)
    return torch.fake_quantize_per_tensor_affine(tensor, scale, zero_point, quant_min, quant_max)
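
A usage sketch of the helper above (assuming TinyNN is importable). For values spanning [-0.5, 0.25], the symmetric branch gives scale = max(0.5, 0.25) / ((127 - (-127)) / 2) = 0.5 / 127 ≈ 0.0039 and zero_point = 0, so every returned element is an integer multiple of that scale:

import torch
from tinynn.graph.quantization.utils import fake_quantize

t = torch.tensor([-0.5, -0.1, 0.0, 0.2, 0.25])
t_fq = fake_quantize(t, asym=False, eps=1e-6, quant_min=-127, quant_max=127)
# t_fq holds the values an int8 dynamic-range kernel would reconstruct after
# quantizing and dequantizing t.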
