diff --git a/deepmd/pd/model/descriptor/se_atten_v2.py b/deepmd/pd/model/descriptor/se_atten_v2.py index 032661a8e2..6c90846f9a 100644 --- a/deepmd/pd/model/descriptor/se_atten_v2.py +++ b/deepmd/pd/model/descriptor/se_atten_v2.py @@ -258,7 +258,7 @@ def deserialize(cls, data: dict) -> "DescrptSeAttenV2": obj = cls(**data) def t_cvt(xx): - return paddle.to_tensor(xx, dtype=obj.se_atten.prec, device=env.DEVICE) + return paddle.to_tensor(xx, dtype=obj.se_atten.prec, place=env.DEVICE) obj.type_embedding.embedding = TypeEmbedNetConsistent.deserialize( type_embedding