[quantizer] add quantize_op_action config to disable op quantization (#334)

* [quantizer] add quantize_op_action config to disable op quantization

* [quantizer] fix typo

* fix minor bug

---------

Co-authored-by: peterjc123 <[email protected]>
zk1998 and peterjc123 authored Jun 25, 2024
1 parent ec5fa35 commit 3752f6d
Showing 3 changed files with 32 additions and 2 deletions.
11 changes: 11 additions & 0 deletions docs/FAQ.md
@@ -63,6 +63,17 @@ with model_tracer():
qat_model = quantizer.quantize()
```

Q: How do I specify mixed quantization by operator type?

A: Set the `quantize_op_action` entry of `config` when constructing the Quantizer. For each operator type that should not be quantized, specify an action: `'disable'` leaves the operator completely unquantized, while `'rewrite'` leaves it unquantized but keeps the quantization parameters of its inputs and outputs.

```python
# For a model containing an LSTM op, perform mixed quantization while keeping the quantization parameters of the LSTM's inputs, so that it can later be quantized directly in the converter.
with model_tracer():
quantizer = QATQuantizer(model, dummy_input, work_dir='out', config={ 'quantize_op_action': {nn.LSTM: 'rewrite'} })
qat_model = quantizer.quantize()
```
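
The same config also accepts the `'disable'` action. A minimal sketch for completeness — the choice of `nn.GRU` is only illustrative and not taken from this commit:

```python
# Hypothetical sketch: leave all GRU ops entirely unquantized ('disable'),
# assuming the same model / dummy_input setup as the example above.
with model_tracer():
    quantizer = QATQuantizer(model, dummy_input, work_dir='out', config={ 'quantize_op_action': {nn.GRU: 'disable'} })
    qat_model = quantizer.quantize()
```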


#### How to handle the case of inconsistent training and inference computation graphs?

10 changes: 10 additions & 0 deletions docs/FAQ_zh-CN.md
@@ -63,6 +63,16 @@ with model_tracer():
qat_model = quantizer.quantize()
```

Q: How do I specify mixed quantization by operator type?

A: Set the `quantize_op_action` entry of `config` when constructing the Quantizer. For each operator type that should not be quantized, specify an action: `'disable'` means completely unquantized, while `'rewrite'` means unquantized but keeping the quantization parameters of the op's inputs and outputs.
```python
# For a model containing an LSTM op, perform mixed quantization while keeping the quantization parameters of its inputs, so that it can later be quantized directly in the converter.
with model_tracer():
quantizer = QATQuantizer(model, dummy_input, work_dir='out', config={ 'quantize_op_action': {nn.LSTM: 'rewrite'} })
qat_model = quantizer.quantize()
```


#### How to handle the case of inconsistent training and inference computation graphs?

13 changes: 11 additions & 2 deletions tinynn/graph/quantization/quantizer.py
@@ -399,6 +399,7 @@ def parse_config(self, config: typing.Optional[dict]):
'ignore_layerwise_config': False,
'inplace': False,
'override_qconfig_func': None,
'quantize_op_action': {},
}

if config is None:
@@ -2557,12 +2558,20 @@ def _is_batch_norm_1d(node, custom_data):
self.layerwise_config.yaml_add_eol_comment(f'type: {t}', n)

skip_types = set(k[0] for k in REWRITE_QUANTIZABLE_RULE_LIST if len(k) == 1)
for module_cls, action in self.quantize_op_action.items():
if action == 'rewrite':
skip_types.add(module_cls)
if self.set_quantizable_op_stats:
skip_types |= set(KNOWN_QSTATS.keys())
skip_types_prev = skip_types | set(k[-1] for k in REWRITE_QUANTIZABLE_RULE_LIST if len(k) > 1)
skip_types_next = skip_types | set(k[0] for k in REWRITE_QUANTIZABLE_RULE_LIST if len(k) > 1)

# Add quant/dequant nodes for non-quantizable OPs
disable_quantize_op_list = UNSUPPORTED_PYTORCH_QUANTIZATION_OP_LIST.copy()
for module_cls, action in self.quantize_op_action.items():
if action in ('disable', 'rewrite'):
disable_quantize_op_list[module_cls] = None

def _is_not_quantizable(node, custom_data):
cur_module = node.module
cur_class = type(cur_module)
@@ -2578,7 +2587,7 @@ def _is_not_quantizable(node, custom_data):
return False
if self.layerwise_config.get(node.unique_name, True) is False:
return True
supported_version = UNSUPPORTED_PYTORCH_QUANTIZATION_OP_LIST.get(cur_module.kind, torch.__version__)
supported_version = disable_quantize_op_list.get(cur_module.kind, torch.__version__)
return supported_version is None or LooseVersion(torch.__version__) < supported_version
else:
if isinstance(cur_module, (torch_q.QuantStub, torch_q.DeQuantStub)):
@@ -2587,7 +2596,7 @@ def _is_not_quantizable(node, custom_data):
return True
unsupported_types = tuple(
k
for k, v in UNSUPPORTED_PYTORCH_QUANTIZATION_OP_LIST.items()
for k, v in disable_quantize_op_list.items()
if type(k) != str
and k not in Q_MODULES_MAPPING
and (v is None or LooseVersion(torch.__version__) < v)
