Skip to content

Commit

Permalink
remove fluid & fix bug in auto_layer_map (#96)
Browse files Browse the repository at this point in the history
* remove fluid

* update

* update
  • Loading branch information
feifei-111 authored Aug 14, 2023
1 parent 87ac80b commit e40e0e6
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 6 deletions.
9 changes: 8 additions & 1 deletion padiff/abstracts/marker.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,14 @@ def auto_layer_map(self, model_place):
registered = init_pool.registered_base_models if model_place == "base" else init_pool.registered_raw_models

log("Auto set layer_map start searching...")
for layer in self.traversal_for_assign_weight():
for layer in self.traversal_for_auto_layer_map():
if layer.fullname in registered:
print(f"++++ {model_place}_model found `{layer.fullname}` add to layer_map ++++")
_layer_map.append(layer)
self.unassigned_weights_list.add(layer.model)
self.unassigned_weights_list_recursively.add(layer.model)
print()
self.layer_map = _layer_map
return True

def update_black_list_with_class(self, layer_class, recursively=True):
Expand All @@ -107,6 +109,11 @@ def traversal_for_assign_weight(self):
continue
yield model

def traversal_for_auto_layer_map(self):
yield self.proxy_model
for model in traversal_for_assign_weight(self.proxy_model, self):
yield model


def traversal_prototype(fn0, fn1):
# if fn0 returns True, yield current model
Expand Down
2 changes: 1 addition & 1 deletion padiff/abstracts/proxy_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def __init__(self, param, param_type):
def create_from(param):
if isinstance(param, ProxyParam):
return param
elif isinstance(param, paddle.fluid.framework.EagerParamBase):
elif isinstance(param, paddle.framework.io.EagerParamBase):
return PaddleParam(param)
elif isinstance(param, torch.nn.parameter.Parameter):
return TorchParam(param)
Expand Down
5 changes: 1 addition & 4 deletions padiff/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,7 @@
import torch


try:
from paddle.fluid.layers.utils import flatten, pack_sequence_as, map_structure
except:
from paddle.utils import flatten, pack_sequence_as, map_structure
from paddle.utils import flatten, pack_sequence_as, map_structure


"""
Expand Down

0 comments on commit e40e0e6

Please sign in to comment.