Skip to content

Commit

Permalink
[quantizer] add linear support for CLE (#367)
Browse files Browse the repository at this point in the history
* [quantizer]add linear support for cle

* [example]fix typo for ptq example
  • Loading branch information
zk1998 authored Oct 9, 2024
1 parent 5cf1194 commit a2acdf4
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 5 deletions.
2 changes: 1 addition & 1 deletion examples/quantization/post.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def main_worker(args):
# For per-tensor quantization, if there are many outliers in the weight, CLE can significantly improve the
# quantization accuracy
if args.cle:
cross_layer_equalize(model, dummy_input, get_device())
model = cross_layer_equalize(model, dummy_input, get_device())

# TinyNeuralNetwork provides a PostQuantizer class that may rewrite the graph for and perform model fusion for
# quantization. The model returned by the `quantize` function is ready for quantization calibration.
Expand Down
25 changes: 25 additions & 0 deletions tests/cross_layer_equalization_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,28 @@ def forward(self, x):
cle_model = cross_layer_equalize(model, dummy_input, torch.device('cpu'), hba_flag=False)
cle_output = cle_model(dummy_input)
torch.testing.assert_allclose(origin_output, cle_output)

def test_cle_linear(self):
class TestModel(nn.Module):
def __init__(self):
super().__init__()
self.fc0 = nn.Linear(3, 8)
self.fc1 = nn.Linear(8, 16)
self.fc2 = nn.Linear(16, 32)

def forward(self, x):
fc0 = self.fc0(x)
fc1 = self.fc1(fc0)
fc2 = self.fc2(fc1)
return fc2

torch.manual_seed(10)

dummy_input = torch.randn(1, 3)
model = TestModel()
model.eval()

origin_output = model(dummy_input)
cle_model = cross_layer_equalize(model, dummy_input, torch.device('cpu'), hba_flag=False)
cle_output = cle_model(dummy_input)
torch.testing.assert_allclose(origin_output, cle_output)
16 changes: 12 additions & 4 deletions tinynn/graph/quantization/algorithm/cross_layer_equalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

log = get_logger(__name__)

cls_support_type = (torch.nn.Conv2d, torch.nn.Conv1d)
cls_support_type = (torch.nn.Conv2d, torch.nn.Conv1d, torch.nn.Linear)
cls_scalable_type = (torch.nn.ReLU, torch.nn.LeakyReLU, torch.nn.PReLU, torch.nn.Identity)


Expand Down Expand Up @@ -89,7 +89,10 @@ def equalize(weight_1, weight_2, group=1, threshold=0.5, s_min=1e-6, s_max=1e6):
)
+ shape_2[1:],
)
weight_2_re = weight_2_re.permute((2, 0, 1, 3, 4))
num_dims = weight_2_re.dim()
assert num_dims >= 3, f"weight_2_re shape dim={num_dims}, <3"
new_order = [2, 0, 1] + list(range(3, num_dims))
weight_2_re = weight_2_re.permute(new_order)
weight_2_re = torch.reshape(weight_2_re, (weight_2_re.shape[0] * weight_2_re.shape[1], -1))
r1 = weight_1_re.abs().max(1).values.double()
r2 = weight_2_re.abs().max(1).values.double()
Expand All @@ -107,7 +110,12 @@ def _weight_equal_helper(cls, threshold=0.5):
layer_pair = [m for n, m in cls]
if len(layer_pair) == 2:
conv_0, conv_1 = layer_pair
weight1, bias1, weight2, groups = conv_0.weight, conv_0.bias, conv_1.weight, conv_1.groups
weight1, bias1, weight2, groups = (
conv_0.weight,
conv_0.bias,
conv_1.weight,
conv_1.groups if hasattr(conv_1, 'groups') else 1,
)
s = equalize(weight1, weight2, group=groups, threshold=threshold)
weight_1 = weight1 / s.reshape([-1] + ([1] * (weight1.ndim - 1)))
weight_2 = torch.reshape(weight2, (groups, weight2.shape[0] // groups) + weight2.shape[1:])
Expand Down Expand Up @@ -150,7 +158,7 @@ def equalize_model(model: nn.Module, dummy_input, threshold=0.5, iters=2) -> Tup
stat_we = model.state_dict()
for k, v in stat_we.items():
p, mod = cur_graph.get_submodule_with_parent_from_name(k)
if isinstance(mod, torch.nn.Conv2d):
if isinstance(mod, cls_support_type):
if k.endswith('.weight'):
after_max = p.abs().max()
elif k.endswith('.bias'):
Expand Down

0 comments on commit a2acdf4

Please sign in to comment.