diff --git a/docs/quantization_support.md b/docs/quantization_support.md
index ea6a364c..df991fe4 100644
--- a/docs/quantization_support.md
+++ b/docs/quantization_support.md
@@ -70,6 +70,7 @@ Quantized OPs that are natively not supported by PyTorch (and possibly TFLite).
 | `softmax` | For QATQuantizer/PostQuantizer, set `config={"set_quantizable_op_stats": True}`<br>For TFLiteConverter, set `rewrite_quantizable=True` |
 | `sum` | For TFLiteConverter, set `rewrite_quantizable=True` |
 | `torch.nn.GLU` | No action needed |
+| `torch.nn.Hardsigmoid` | No action needed |
 | `torch.nn.LogSoftmax` | For QATQuantizer/PostQuantizer, set `config={"set_quantizable_op_stats": True}`<br>For TFLiteConverter, set `rewrite_quantizable=True` |
 | `torch.nn.PReLU` | No action needed |
 | `torch.nn.SiLU` | No action needed |
diff --git a/tinynn/graph/quantization/modules.py b/tinynn/graph/quantization/modules.py
index 4ab43117..9353bde5 100644
--- a/tinynn/graph/quantization/modules.py
+++ b/tinynn/graph/quantization/modules.py
@@ -1,6 +1,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.quantized as nnq
+import torch.quantization as torch_q
 
 
 class QPReLU(nn.Module):
@@ -70,3 +71,21 @@ def __init__(self, glu: nn.GLU) -> None:
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         slices = torch.chunk(input, 2, self.dim)
         return self.f_mul.mul(slices[0], self.sigmoid(slices[1]))
+
+
+class QHardsigmoid(nn.Module):
+    def __init__(self, hardsigmoid: nn.Hardsigmoid) -> None:
+        super().__init__()
+
+        self.f_mul = nnq.FloatFunctional()
+        self.f_add = nnq.FloatFunctional()
+        self.q = torch_q.QuantStub()
+        self.dq = torch_q.DeQuantStub()
+        self.act_hs = nn.Hardsigmoid()
+        self.act_r = nn.ReLU6()
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        x1 = self.f_add.add_scalar(input, 3.0)
+        x2 = self.act_r(x1)
+        x3 = self.q(self.dq(x2))
+        return self.f_mul.mul_scalar(x3, 1 / 6)
diff --git a/tinynn/graph/quantization/quantizer.py b/tinynn/graph/quantization/quantizer.py
index 6204348d..ade794da 100644
--- a/tinynn/graph/quantization/quantizer.py
+++ b/tinynn/graph/quantization/quantizer.py
@@ -23,7 +23,7 @@
     FakeQuantizeBFloat16,
     FakeQuantizeTFLite,
 )
-from tinynn.graph.quantization.modules import QGLU, QPReLU, QSiLU
+from tinynn.graph.quantization.modules import QGLU, QHardsigmoid, QPReLU, QSiLU
 from tinynn.graph.quantization.observer import (
     HistogramObserverKL,
     MinMaxObserver,
@@ -223,6 +223,7 @@
 Q_MODULES_MAPPING = {
     nn.PReLU: QPReLU,
     nn.GLU: QGLU,
+    nn.Hardsigmoid: QHardsigmoid,
 }
 
 FUNCTIONAL_MODULE_MAPPING = {
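For reference, `QHardsigmoid` relies on the identity `hardsigmoid(x) = relu6(x + 3) / 6` and rebuilds it from quantization-friendly pieces (`FloatFunctional.add_scalar`, `ReLU6`, a `DeQuantStub`/`QuantStub` pair, `FloatFunctional.mul_scalar`). The sketch below is not part of the patch; it is a minimal float-mode sanity check (before any prepare/convert step), with the tensor shape and the `allclose` comparison chosen purely for illustration:

```python
import torch
import torch.nn as nn

from tinynn.graph.quantization.modules import QHardsigmoid

# hardsigmoid(x) == relu6(x + 3) / 6; QHardsigmoid computes the same expression
# step by step so that the quantizer can observe and requantize the intermediates.
float_mod = nn.Hardsigmoid()
q_mod = QHardsigmoid(float_mod)

x = torch.randn(4, 16)
# Before prepare/convert, the stubs and FloatFunctionals are pass-through,
# so the rewritten module should match the reference up to float rounding.
print(torch.allclose(float_mod(x), q_mod(x)))  # expected: True
```

With `nn.Hardsigmoid` registered in `Q_MODULES_MAPPING`, the quantizer performs this rewrite automatically, which is why the new row in `docs/quantization_support.md` lists it as "No action needed".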