diff --git a/docs/quantization_support.md b/docs/quantization_support.md
index ea6a364c..df991fe4 100644
--- a/docs/quantization_support.md
+++ b/docs/quantization_support.md
@@ -70,6 +70,7 @@ Quantized OPs that are natively not supported by PyTorch (and possibly TFLite).
| `softmax` | For QATQuantizer/PostQuantizer, set `config={"set_quantizable_op_stats": True}` <br> For TFLiteConverter, set `rewrite_quantizable=True` |
| `sum` | For TFLiteConverter, set `rewrite_quantizable=True` |
| `torch.nn.GLU` | No action needed |
+| `torch.nn.Hardsigmoid` | No action needed |
| `torch.nn.LogSoftmax` | For QATQuantizer/PostQuantizer, set `config={"set_quantizable_op_stats": True}` <br> For TFLiteConverter, set `rewrite_quantizable=True` |
| `torch.nn.PReLU` | No action needed |
| `torch.nn.SiLU` | No action needed |
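The new table row says `torch.nn.Hardsigmoid` now quantizes with no extra configuration. A minimal usage sketch, assuming the standard `QATQuantizer` entry point from the project examples (the model, `dummy_input`, and `work_dir` below are placeholders):

```python
import torch
import torch.nn as nn

from tinynn.graph.quantization.quantizer import QATQuantizer

# Placeholder model containing the newly supported activation.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.Hardsigmoid())
dummy_input = torch.randn(1, 3, 32, 32)

# With this change, nn.Hardsigmoid is rewritten automatically during quantization;
# no extra `config` keys or converter flags are required for it.
quantizer = QATQuantizer(model, dummy_input, work_dir='out')
qat_model = quantizer.quantize()
```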
diff --git a/tinynn/graph/quantization/modules.py b/tinynn/graph/quantization/modules.py
index 4ab43117..9353bde5 100644
--- a/tinynn/graph/quantization/modules.py
+++ b/tinynn/graph/quantization/modules.py
@@ -1,6 +1,7 @@
import torch
import torch.nn as nn
import torch.nn.quantized as nnq
+import torch.quantization as torch_q


class QPReLU(nn.Module):
@@ -70,3 +71,21 @@ def __init__(self, glu: nn.GLU) -> None:
def forward(self, input: torch.Tensor) -> torch.Tensor:
slices = torch.chunk(input, 2, self.dim)
return self.f_mul.mul(slices[0], self.sigmoid(slices[1]))
+
+
+class QHardsigmoid(nn.Module):
+ def __init__(self, hardsigmoid: nn.Hardsigmoid) -> None:
+ super().__init__()
+
+ self.f_mul = nnq.FloatFunctional()
+ self.f_add = nnq.FloatFunctional()
+ self.q = torch_q.QuantStub()
+ self.dq = torch_q.DeQuantStub()
+ self.act_hs = nn.Hardsigmoid()
+ self.act_r = nn.ReLU6()
+
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
+ x1 = self.f_add.add_scalar(input, 3.0)
+ x2 = self.act_r(x1)
+ x3 = self.q(self.dq(x2))
+ return self.f_mul.mul_scalar(x3, 1 / 6)
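`QHardsigmoid` expresses `hardsigmoid(x) = relu6(x + 3) / 6` with quantization-friendly building blocks: `FloatFunctional.add_scalar`/`mul_scalar` for the affine parts and a `DeQuantStub`/`QuantStub` pair so the `ReLU6` output gets its own observed quantization parameters. Before `prepare()`/`convert()` these are all pass-through, so the wrapper should match the reference module in float mode. A quick sanity check, assuming the module path from this diff:

```python
import torch
import torch.nn as nn

from tinynn.graph.quantization.modules import QHardsigmoid

x = torch.randn(4, 8)

# In float mode the stubs and FloatFunctional are identity/plain ops, so the
# rewritten graph computes relu6(x + 3) * (1 / 6), i.e. exactly hardsigmoid(x).
ref = nn.Hardsigmoid()(x)
out = QHardsigmoid(nn.Hardsigmoid())(x)
assert torch.allclose(ref, out)
```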
diff --git a/tinynn/graph/quantization/quantizer.py b/tinynn/graph/quantization/quantizer.py
index 6204348d..ade794da 100644
--- a/tinynn/graph/quantization/quantizer.py
+++ b/tinynn/graph/quantization/quantizer.py
@@ -23,7 +23,7 @@
FakeQuantizeBFloat16,
FakeQuantizeTFLite,
)
-from tinynn.graph.quantization.modules import QGLU, QPReLU, QSiLU
+from tinynn.graph.quantization.modules import QGLU, QHardsigmoid, QPReLU, QSiLU
from tinynn.graph.quantization.observer import (
HistogramObserverKL,
MinMaxObserver,
@@ -223,6 +223,7 @@
Q_MODULES_MAPPING = {
nn.PReLU: QPReLU,
nn.GLU: QGLU,
+ nn.Hardsigmoid: QHardsigmoid,
}
FUNCTIONAL_MODULE_MAPPING = {
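Adding `nn.Hardsigmoid: QHardsigmoid` to `Q_MODULES_MAPPING` is what lets the quantizer replace each `nn.Hardsigmoid` instance with its rewrite before preparing the model. The helper below is only an illustration of that kind of type-keyed swap, not the quantizer's actual rewrite pass (which operates on the traced graph); each `Q*` wrapper is constructed from the module it replaces, as in the mapping above:

```python
import torch.nn as nn

from tinynn.graph.quantization.modules import QGLU, QHardsigmoid, QPReLU

# Illustrative type-keyed mapping, mirroring Q_MODULES_MAPPING in this diff.
MODULE_REWRITES = {
    nn.PReLU: QPReLU,
    nn.GLU: QGLU,
    nn.Hardsigmoid: QHardsigmoid,
}


def swap_unsupported_modules(module: nn.Module) -> None:
    """Recursively replace unsupported activations with their quantizable rewrites."""
    for name, child in module.named_children():
        rewrite = MODULE_REWRITES.get(type(child))
        if rewrite is not None:
            # Each wrapper takes the original module so it can copy any state it needs.
            setattr(module, name, rewrite(child))
        else:
            swap_unsupported_modules(child)
```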