```python
def _update_split_index_mapping(self, split_node: Node):
    if split_node.type != ops.OPTYPE.SPLIT:
        return
    if hasattr(split_node.grad_fn, '_saved_dim'):  # this only works for PyTorch>=1.12
        # There is an issue in some PyTorch versions where _saved_dim is an uninitialized value like 118745347895359,
        # so we need to check that _saved_dim is valid (< len(_saved_self_sym_sizes), or below a nominal value like 20).
        if hasattr(split_node.grad_fn, '_saved_self_sym_sizes'):
            if split_node.grad_fn._saved_dim < len(split_node.grad_fn._saved_self_sym_sizes) and split_node.grad_fn._saved_dim != 1:
                return
        else:
            THRESHOLD = 20
            if split_node.grad_fn._saved_dim < THRESHOLD and split_node.grad_fn._saved_dim >= 0 and split_node.grad_fn._saved_dim != 1:
                return
    # ... (rest of the method not included in the issue)


def _update_reshape_index_mapping(self, reshape_node: Node):
    # Only Supports 2D/4D tensors
    # TODO: Better support for reshape/view/flatten
    if hasattr(reshape_node.grad_fn, '_saved_self_sizes'):
        size = reshape_node.grad_fn._saved_self_sizes
        if (len(size) != 1 and len(size) != 4):
            return
    else:  # legacy version
        if not self._2d_4d:
            return
    # ... (rest of the method not included in the issue)
```
Please ask your question
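I'm working with the torch_pruning project and want to modify it to support Paddle. In the backward pass, there are some variables whose Paddle counterparts I can't identify. Could someone please help?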
For example, in the `grad_fn` of the split operation, PyTorch has a `_saved_dim` variable. What is the equivalent in Paddle?
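Whatever Paddle does or does not expose, the split check in `_update_split_index_mapping` only consumes two values: the dimension the split was taken along (`_saved_dim`) and, optionally, the input shape (`_saved_self_sym_sizes`). A minimal sketch, assuming a port obtains those values some other way (for example, by recording them when the op is traced); `should_skip_split` is a hypothetical name, not part of torch_pruning or Paddle:

```python
from typing import Optional, Sequence

# Hypothetical helper: reproduces the early-return logic of
# _update_split_index_mapping, with the PyTorch autograd attributes
# (_saved_dim, _saved_self_sym_sizes) passed in as plain values.
CHANNEL_DIM = 1
DIM_THRESHOLD = 20  # nominal bound used when the input shape is unknown


def should_skip_split(saved_dim: Optional[int],
                      saved_sizes: Optional[Sequence[int]] = None) -> bool:
    """Return True when the original method would return early,
    i.e. when the split is judged to be along a non-channel dim."""
    if saved_dim is None:
        # Mirrors the missing-_saved_dim case: no early return.
        return False
    if saved_sizes is not None:
        # Input shape known: a dim below the rank counts as valid,
        # so skip unless it is the channel dimension.
        return saved_dim < len(saved_sizes) and saved_dim != CHANNEL_DIM
    # No input shape: only 0 <= dim < DIM_THRESHOLD is trusted, since some
    # PyTorch versions report uninitialized values such as 118745347895359.
    return 0 <= saved_dim < DIM_THRESHOLD and saved_dim != CHANNEL_DIM
```

For an NCHW input of shape `[8, 64, 32, 32]`, a split along dim 1 (channels) returns `False` and proceeds to the index mapping, while a split along dim 2 returns `True`. Keeping the decision separate from how the dimension is obtained means only the data source has to change when porting.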
Likewise, in the `grad_fn` of the reshape operation (used in `_update_reshape_index_mapping(self, reshape_node: Node)`), PyTorch has a `_saved_self_sizes` variable. What is the Paddle equivalent?
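The reshape check likewise needs only the rank of the tensor being reshaped. In PyTorch, that rank comes from the input sizes saved on the autograd node. A minimal sketch under the same assumption; `should_skip_reshape` and `legacy_2d_4d` are hypothetical names, and the latter stands in for `self._2d_4d`. The quoted check accepts ranks 1 and 4 even though its comment says 2D/4D. The sketch follows the code as written:

```python
from typing import Optional, Sequence

# Hypothetical helper: reproduces the early-return logic of
# _update_reshape_index_mapping as quoted in the issue.
SUPPORTED_RANKS = (1, 4)  # as in the quoted check, despite the "2D/4D" comment


def should_skip_reshape(saved_sizes: Optional[Sequence[int]],
                        legacy_2d_4d: bool = False) -> bool:
    """Return True when the original method would return early."""
    if saved_sizes is not None:
        # Input shape of the reshape is known: only the listed ranks are handled.
        return len(saved_sizes) not in SUPPORTED_RANKS
    # Legacy path (no saved input shape): fall back to the graph-level flag.
    return not legacy_2d_4d
```

For example, flattening an `[8, 64, 7, 7]` feature map gives `False` (handled), while reshaping a rank-2 `[8, 3136]` tensor gives `True` (skipped).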