diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 44a9348346c..7f7ba6d8e31 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -367,8 +367,8 @@ def reset( tensordict_reset = step_tensordict( tensordict_reset, exclude_done=False, - exclude_reward=True, - exclude_action=True, + exclude_reward=False, # some policies may need reward and action at reset time + exclude_action=False, ) if tensordict is not None: tensordict.update(tensordict_reset)