Commit

[pruner] fix transformer pruner error (#326)
* [pruner] fix YAML config handling for new ruamel.yaml versions

* [pruner] add Conv1d, LayerNorm, and constant support; fix OP_add dim_change_forward error
zk1998 authored Jun 5, 2024
1 parent 8571617 commit 32359e9
Showing 3 changed files with 79 additions and 14 deletions.
75 changes: 63 additions & 12 deletions tinynn/graph/modifier.py
@@ -858,6 +858,15 @@ def dim_change_forward(self, center, tensor, dim_changes_i: typing.List, dim_tra
raise e

for tensor_o in self.next_tensors():
# Case [1, c0, c1] + [c0, c1](center_node) -> [1, c0, c1]: keep dim_change_o consistent.
if len(tensor_o.shape) > len(tensor.shape) and tensor_o.shape[0] == 1:
old_dim_change_i = dim_changes_i
omitted_dim_len = 1
dim_changes_i = [i + omitted_dim_len for i in dim_changes_i]
for dim_ in old_dim_change_i:
tensor_constraint[dim_ + omitted_dim_len] = tensor_constraint[dim_]
tensor_constraint.pop(dim_)

self.dim_changes_info.update_o(center, tensor_o, dim_changes_i)

for m in self.next_modifiers(tensor_o):
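
The new branch above handles broadcasting between the center tensor and a higher-rank output: when a result such as [1, c0, c1] + [c0, c1] gains a leading singleton dimension, every recorded changed-dimension index must be shifted by the number of omitted leading dims so the constraint map keeps pointing at the right axes. A minimal shape-reasoning sketch in plain PyTorch (tensor names and shapes are illustrative only):

    import torch

    # A 2-D "center" tensor broadcast against a 3-D operand, as in [c0, c1] + [1, c0, c1].
    center = torch.randn(8, 16)         # shape [c0, c1]
    operand = torch.randn(1, 8, 16)     # shape [1, c0, c1]
    out = operand + center              # broadcasting keeps shape [1, c0, c1]

    # A channel change on center's dim 0 (c0) lands on dim 1 of the output, so the
    # recorded dim indices are shifted by the number of omitted leading dimensions.
    dim_changes_center = [0]
    omitted = out.dim() - center.dim()  # one leading broadcast dimension
    dim_changes_out = [d + omitted for d in dim_changes_center]
    assert dim_changes_out == [1]
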
@@ -1146,9 +1155,9 @@ def modify_input(self, remove_idx):
bn.num_parameters = len(preserve_idx)


class BatchNormChannelModifier(Modifier):
class NormChannelModifier(Modifier):
def __init__(self, node: TraceNode):
super(BatchNormChannelModifier, self).__init__(node)
super(NormChannelModifier, self).__init__(node)
self.prunable = True

def register_mask(self, modifiers, importance, sparsity):
@@ -1229,13 +1238,23 @@ def modify_input(self, remove_idx):
conv.bias = torch.nn.Parameter(conv.bias + bn_bias)
break

log.info(f'[BN] {self.unique_name()}: channel {bn.num_features} -> {len(preserve_idx)}')
if isinstance(bn, nn.BatchNorm1d) or isinstance(bn, nn.BatchNorm2d):
log.info(f'[BN] {self.unique_name()}: channel {bn.num_features} -> {len(preserve_idx)}')
bn.register_buffer('running_mean', bn.running_mean[preserve_idx])
bn.register_buffer('running_var', bn.running_var[preserve_idx])
bn.num_batches_tracked = bn.num_batches_tracked.zero_()
bn.num_features = len(preserve_idx)
elif isinstance(bn, nn.LayerNorm):
if len(bn.normalized_shape) == 1:
log.info(f'[LN] {self.unique_name()}: channel {bn.normalized_shape} -> ({len(preserve_idx)},)')
bn.normalized_shape = (len(preserve_idx),)
else:
log.error("The Layer Normalization (LN) Modifier supports only one-dimensional normalized_shape.")
else:
log.error("Unsupported Norm Type")

bn.weight = torch.nn.Parameter(bn.weight[preserve_idx])
bn.bias = torch.nn.Parameter(bn.bias[preserve_idx])
bn.register_buffer('running_mean', bn.running_mean[preserve_idx])
bn.register_buffer('running_var', bn.running_var[preserve_idx])
bn.num_batches_tracked = bn.num_batches_tracked.zero_()
bn.num_features = len(preserve_idx)


class ReIndexModifier(Modifier):
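
The rewritten branch above distinguishes BatchNorm, which also carries running statistics, from LayerNorm, whose per-channel parameters live only in weight/bias and whose size is recorded in normalized_shape; only a one-dimensional normalized_shape is supported. A standalone sketch of what pruning a 1-D LayerNorm down to a kept-channel set looks like (illustrative names, not the modifier's actual call path):

    import torch
    import torch.nn as nn

    ln = nn.LayerNorm(16)                      # normalized_shape == (16,)
    preserve_idx = torch.tensor([0, 2, 5, 7])  # channels kept after pruning

    # LayerNorm has no running statistics, so pruning only slices the affine
    # parameters and shrinks normalized_shape accordingly.
    ln.weight = nn.Parameter(ln.weight[preserve_idx])
    ln.bias = nn.Parameter(ln.bias[preserve_idx])
    ln.normalized_shape = (len(preserve_idx),)

    x = torch.randn(2, 10, len(preserve_idx))
    assert ln(x).shape == (2, 10, 4)
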
@@ -2190,7 +2209,7 @@ def dim_change_forward(self, center, tensor, dim_changes_i, dim_transform, tenso
assert False


class Conv2dChannelModifier(Modifier):
class ConvChannelModifier(Modifier):
def __init__(self, node: TraceNode):
super().__init__(node)
self.dim_n = 0
@@ -2416,7 +2435,7 @@ def dim_change_forward(self, center, tensor, dim_changes_i, dim_transform, tenso
m.dim_change_forward(center, self.next_tensors()[0], dim_changes_o, transform, None)


class TransConvChannelModifier(Conv2dChannelModifier):
class TransConvChannelModifier(ConvChannelModifier):
def register_mask(self, modifiers, importance, sparsity):
if self.dim_changes_info.pruned_idx_i:
remove_idx = self.dim_changes_info.pruned_idx_i
@@ -2464,11 +2483,40 @@ def modify_output(self, remove_idx):
conv.bias = torch.nn.Parameter(conv.bias[preserve_idx])


class ConstantModifier(LinearChannelModifier):
def __init__(self, node: TraceNode):
Modifier.__init__(self, node)
self.output_tensor = self.next_tensors()[0]
self.input_tensor = self.output_tensor
# Pruning operation occurs along the second dimension
self.dim_c = 1
self.prunable = True

def change_dimension(self) -> bool:
dim_changes_o = [self.dim_c]

fill_tensor_by_dim_changes(self.output_tensor, dim_changes_o)

tensor_constraint = self.dim_changes_info.update_o(
self, self.next_tensors()[0], dim_changes_o, update_constraint=True
)

for m in self.next_modifiers():
m.dim_change_forward(self, self.next_tensors()[0], dim_changes_o, None, tensor_constraint)

return True

def register_mask(self, modifiers, importance, sparsity):
Modifier.reset_mask(self)
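
ConstantModifier treats a free-standing graph constant (a bare 'weight' tensor) as both its own input and output and always records channel changes on the second dimension (dim_c = 1). Conceptually, pruning such a constant is just an index-select along that axis; a hedged sketch with made-up shapes:

    import torch

    # A constant participating in the graph, e.g. a learned table of shape [1, C, d].
    constant = torch.randn(1, 64, 128)
    keep_idx = torch.arange(0, 64, 2)   # channels that survive the mask

    # Pruning acts on dim 1, the channel dimension recorded by dim_c = 1.
    pruned = constant.index_select(1, keep_idx)
    assert pruned.shape == (1, 32, 128)
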


CHANNEL_MODIFIERS = {
nn.Conv2d: Conv2dChannelModifier,
nn.Conv1d: ConvChannelModifier,
nn.Conv2d: ConvChannelModifier,
nn.Linear: LinearChannelModifier,
nn.ConvTranspose2d: TransConvChannelModifier,
nn.ConvTranspose1d: TransConvChannelModifier,
nn.AvgPool1d: PoolingModifier,
nn.AvgPool2d: PoolingModifier,
nn.AdaptiveAvgPool2d: PoolingModifier,
nn.MaxPool2d: PoolingModifier,
@@ -2480,9 +2528,11 @@ def modify_output(self, remove_idx):
nn.UpsamplingNearest2d: PoolingModifier,
"interpolate": PoolingModifier,
nn.PReLU: PReLUChannelModifier,
nn.BatchNorm2d: BatchNormChannelModifier,
nn.BatchNorm1d: BatchNormChannelModifier,
nn.BatchNorm2d: NormChannelModifier,
nn.BatchNorm1d: NormChannelModifier,
nn.LayerNorm: NormChannelModifier,
'matmul': MatMulModifier,
'bmm': MatMulModifier,
'cat': CatModifier,
'view': ReshapeModifier,
"flatten": ReIndexModifier,
@@ -2500,6 +2550,7 @@ def modify_output(self, remove_idx):
nn.RNN: RNNChannelModifier,
nn.GRU: RNNChannelModifier,
nn.LSTM: RNNChannelModifier,
'weight': ConstantModifier,
}
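
With these entries, Conv1d/ConvTranspose1d, LayerNorm, bmm, and bare 'weight' constants are now routed to channel modifiers as well. A rough illustration of how such a table can be consulted when a traced node is mapped to its modifier (the pick_modifier helper below is hypothetical, not TinyNN's actual dispatch code):

    import torch.nn as nn

    def pick_modifier(module, op_kind, registry):
        """Return the modifier class registered for a module type or an op-kind string."""
        if module is not None and type(module) in registry:
            return registry[type(module)]
        return registry.get(op_kind)  # e.g. 'matmul', 'bmm', 'weight'

    # Usage sketch against the table above (names illustrative):
    #   pick_modifier(nn.LayerNorm(64), None, CHANNEL_MODIFIERS)  -> NormChannelModifier
    #   pick_modifier(None, 'weight', CHANNEL_MODIFIERS)          -> ConstantModifier
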


15 changes: 13 additions & 2 deletions tinynn/prune/base_pruner.py
@@ -21,6 +21,8 @@
except ModuleNotFoundError:
import ruamel.yaml as yaml

NEW_YAML_FLAG = "error_deprecation" in getsource(yaml.load)

log = get_logger(__name__)


@@ -136,7 +138,11 @@ def load_config(cls, path: str) -> dict:
"""Loads the configuration file and returns it as a dictionary"""

with open(path, 'r') as f:
config = yaml.load(f, Loader=yaml.RoundTripLoader)
if NEW_YAML_FLAG:
yaml_ = yaml.YAML(typ='rt')
config = yaml_.load(f)
else:
config = yaml.load(f, Loader=yaml.RoundTripLoader)
return config

@conditional(lambda: not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0)
@@ -147,7 +153,12 @@ def generate_config(self, path: str, config: dict = None) -> None:
config = self.config

with open(path, 'w') as f:
yaml.dump(config, f, default_flow_style=False, Dumper=yaml.RoundTripDumper)
if NEW_YAML_FLAG:
yaml_ = yaml.YAML(typ='rt')
yaml_.default_flow_style = False
yaml_.dump(config, f)
else:
yaml.dump(config, f, default_flow_style=False, Dumper=yaml.RoundTripDumper)

def trace(self) -> TraceGraph:
with torch.no_grad():
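
Newer ruamel.yaml releases deprecate and eventually remove the plain yaml.load/yaml.dump functions in favour of the YAML object API, which is why the pruner now probes the installed version (via the "error_deprecation" marker in the source of yaml.load) and branches between the two call styles. A minimal round-trip sketch of the new-style path, assuming ruamel.yaml is installed (the config keys are illustrative):

    import io
    import ruamel.yaml as yaml

    config = {'sparsity': 0.5, 'metrics': 'l2_norm'}  # illustrative pruner options

    rt = yaml.YAML(typ='rt')          # round-trip mode, mirrors the old RoundTrip* behaviour
    rt.default_flow_style = False

    buf = io.StringIO()
    rt.dump(config, buf)              # replaces yaml.dump(..., Dumper=yaml.RoundTripDumper)
    reloaded = rt.load(io.StringIO(buf.getvalue()))
    assert dict(reloaded) == config   # replaces yaml.load(..., Loader=yaml.RoundTripLoader)
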
3 changes: 3 additions & 0 deletions tinynn/prune/oneshot_pruner.py
@@ -70,6 +70,9 @@ def __init__(self, model, dummy_input, config):

self.graph_modifier = modifier.GraphChannelModifier(self.graph, self.center_nodes, self.bn_compensation)

# TODO: To preserve pruning accuracy, temporarily exclude all pruning subgraphs that involve constants.
self.exclude_ops.append('weight')

for sub_graph in self.graph_modifier.sub_graphs.values():
exclude = False
for m in sub_graph.modifiers:
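
Appending 'weight' to exclude_ops routes any subgraph that contains a constant modifier into the existing exclusion check in the loop above, so constant-carrying subgraphs are skipped for now (as the TODO notes, to protect pruning accuracy). A schematic of that kind of filtering, using simplified stand-in types rather than TinyNN's real classes:

    # Schematic only: SubGraph/Modifier here are simplified stand-ins.
    from dataclasses import dataclass, field
    from typing import List

    @dataclass
    class Modifier:
        kind: str                 # e.g. 'conv2d', 'linear', 'weight'

    @dataclass
    class SubGraph:
        modifiers: List[Modifier] = field(default_factory=list)

    def prunable_subgraphs(sub_graphs, exclude_ops):
        """Keep only subgraphs in which no modifier's op kind is excluded."""
        kept = []
        for sg in sub_graphs:
            if any(m.kind in exclude_ops for m in sg.modifiers):
                continue          # e.g. a constant ('weight') participates
            kept.append(sg)
        return kept

    graphs = [SubGraph([Modifier('conv2d')]),
              SubGraph([Modifier('weight'), Modifier('add')])]
    assert len(prunable_subgraphs(graphs, exclude_ops=['weight'])) == 1
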