-
Notifications
You must be signed in to change notification settings - Fork 213
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Action Normalization #57
base: master
Are you sure you want to change the base?
Action Normalization #57
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks mostly good! Left some comments, mainly about naming and doc strings.
Also, should we infer hdf5_normlize_action
from the dataset, rather than manually specifying it in the config?
@@ -421,7 +421,7 @@ class RolloutPolicy(object): | |||
""" | |||
Wraps @Algo object to make it easy to run policies in a rollout loop. | |||
""" | |||
def __init__(self, policy, obs_normalization_stats=None): | |||
def __init__(self, policy, obs_normalization_stats=None, action_normalization_stats=None): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can you add some comments in the function docstring for action_normalization_stats
? Similar to how it's already done for obs_normalization_stats
@@ -474,4 +475,7 @@ def __call__(self, ob, goal=None): | |||
if goal is not None: | |||
goal = self._prepare_observation(goal) | |||
ac = self.policy.get_action(obs_dict=ob, goal_dict=goal) | |||
return TensorUtils.to_numpy(ac[0]) | |||
ac = TensorUtils.to_numpy(ac) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
any reason for changing ac[0]
to ac
? Can we keep things as ac[0]
?
robomimic/config/base_config.py
Outdated
@@ -156,6 +156,8 @@ class has a default implementation that usually doesn't need to be overriden. | |||
# of each observation in each dimension, computed across the training set. See SequenceDataset.normalize_obs | |||
# in utils/dataset.py for more information. | |||
self.train.hdf5_normalize_obs = False | |||
|
|||
self.train.hdf5_normalize_action = False |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can you add a comment to describe the use case (similar to rest of file)
robomimic/utils/dataset.py
Outdated
@@ -30,6 +30,7 @@ def __init__( | |||
hdf5_cache_mode=None, | |||
hdf5_use_swmr=True, | |||
hdf5_normalize_obs=False, | |||
hdf5_normalize_action=False, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
function docstring needs comment for this attribute
@@ -499,6 +499,17 @@ def normalize_obs(obs_dict, obs_normalization_stats): | |||
|
|||
return obs_dict | |||
|
|||
def normalize_actions(actions, action_normalization_stats): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
small nitpick: our convention here is to use "
rather than '
. can you make the style change?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
also, both normalize_actions
and unnormalize_actions
need docstring
@@ -366,6 +372,99 @@ def get_obs_normalization_stats(self): | |||
assert self.hdf5_normalize_obs, "not using observation normalization!" | |||
return deepcopy(self.obs_normalization_stats) | |||
|
|||
def normalize_actions(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Naming of this function may be confused for normalize_actions
in ObsUtils. How about renaming this to get_action_normalization_stats
?
robomimic/utils/dataset.py
Outdated
return obs_traj | ||
|
||
ep = self.dataset.demos[0] | ||
obs_traj = get_obs_traj(ep) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
naming of obs
here might be confused for observations that we pass into the policy (eg. images). Can we replace this term to be more general? And all other places where we name things with obs
in this function
As per discussion with @snasiriany, this is my current implementation of action normalization which is required for diffusion policy integration. These code are not fully tested and is meant to be a starting point for discussions.
DO NOT MERGE