Best way to store State #368

Closed
Jogima-cyber opened this issue Jun 2, 2023 · 2 comments

@Jogima-cyber

Jogima-cyber commented Jun 2, 2023

Hello, my question is related to this issue: I want to do some offline work where I need to reset the simulator to past states using the generalized backend, and it was stated there that I should store the entire State.
My question is how to store it, since State is a complex Python structure. The most naïve way would be to store the States in a Python list, but that list would be awful to sample from. To me, the best approach would be to convert the State structure into a list of tensors. What do you think about this approach?

@erikfrey
Collaborator

erikfrey commented Jun 2, 2023

Hello! Yes, JAX comes with some handy tools for manipulating, or even discarding, the tree structure. For example:

```python
import jax
state_leaves = jax.tree_util.tree_leaves(state)
```

This gives you just the leaves in a list. The tree can also be reconstructed from such a list with `jax.tree_util.tree_unflatten`, given the `treedef` returned by `jax.tree_util.tree_flatten`.
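A minimal round-trip sketch of that flatten/unflatten approach. It assumes a hypothetical `ToyState` NamedTuple as a stand-in for the simulator's State (its field names are illustrative only); any pytree behaves the same way.

```python
from typing import NamedTuple

import jax
from jax import numpy as jp


class ToyState(NamedTuple):
    # Hypothetical stand-in for the real State; field names are illustrative.
    q: jax.Array
    qd: jax.Array
    info: dict


state = ToyState(q=jp.zeros(3), qd=jp.ones(3), info={"step": jp.array(0)})

# Flatten: a flat list of arrays plus a description of the tree structure.
leaves, treedef = jax.tree_util.tree_flatten(state)

# ... store `leaves` however you like (e.g. as NumPy arrays) ...

# Unflatten: rebuild a ToyState with the same nesting from the stored leaves.
restored = jax.tree_util.tree_unflatten(treedef, leaves)
```

Only the `treedef` has to be kept next to the stored leaves, and it is the same for every State with the same structure, so it can be computed once.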

Or, if what you want is a list of States split along the batch dimension, you can do:

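```python
import jax
from jax import numpy as jp
# Split a batched State into a list of unbatched States along axis 0.
states = [jax.tree_util.tree_map(lambda a: jp.take(a, i, axis=0), state) for i in range(batch_size)]
```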
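Going the other way, a list of unbatched States can be stacked back into one batched State, because `jax.tree_util.tree_map` accepts several trees and passes matching leaves together. A minimal sketch, using a hypothetical dict-based state and an assumed `batch_size` of 4:

```python
import jax
from jax import numpy as jp

batch_size = 4  # assumed for illustration

# Hypothetical batched state: every leaf has a leading batch dimension.
state = {"q": jp.arange(batch_size * 3.0).reshape(batch_size, 3),
         "step": jp.arange(batch_size)}

# Unstack: one unbatched state per batch index.
states = [jax.tree_util.tree_map(lambda a: jp.take(a, i, axis=0), state)
          for i in range(batch_size)]

# Restack: matching leaves from all states are stacked along a new axis 0.
restacked = jax.tree_util.tree_map(lambda *xs: jp.stack(xs), *states)
```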

That kind of thing. Is this what you're looking for? I'm also not entirely sure what you mean by "awful to sample from".

@Jogima-cyber
Author

Thank you! Yes, that's exactly the kind of thing I was looking for. I'm going to use the list along the batch dimension of State.

What I meant by "awful to sample from" is the wall-clock time of the operation. I may be wrong, but I think sampling from tensor structures with libraries such as NumPy is much faster than sampling from general-purpose Python containers such as deque or list. That's why, in the end, I'm going to use tree_leaves to discard the tree structure as you suggested, although reconstruction seems a little painful.
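One way to get array-speed sampling without hand-written reconstruction is to keep the tree structure but store each leaf as one preallocated array with a leading capacity dimension, letting `jax.tree_util.tree_map` do the bookkeeping. A minimal sketch, assuming every stored State has the same structure and leaf shapes; `TreeBuffer` is a hypothetical helper (not part of JAX or Brax), and it uses NumPy because JAX arrays cannot be written in place:

```python
import jax
import numpy as np


class TreeBuffer:
    """Hypothetical fixed-capacity ring buffer of pytrees (e.g. simulator States)."""

    def __init__(self, example_state, capacity: int):
        self.capacity = capacity
        self.size = 0  # number of valid entries
        self.pos = 0  # next slot to write
        self.treedef = jax.tree_util.tree_structure(example_state)
        # One NumPy array per leaf, shaped (capacity, *leaf.shape).
        self.data = jax.tree_util.tree_map(
            lambda a: np.zeros((capacity,) + np.shape(a), dtype=np.asarray(a).dtype),
            example_state,
        )

    def add(self, state) -> None:
        leaves, treedef = jax.tree_util.tree_flatten(state)
        assert treedef == self.treedef, "state structure differs from example_state"
        # Write each leaf into slot `pos`; overwrites the oldest entry once full.
        for buf, leaf in zip(jax.tree_util.tree_leaves(self.data), leaves):
            buf[self.pos] = np.asarray(leaf)
        self.pos = (self.pos + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def sample(self, rng: np.random.Generator, n: int):
        # Uniformly sample n stored entries (assumes size > 0).
        idx = rng.integers(0, self.size, size=n)
        # Indexing every leaf returns a batched tree with the original structure.
        return jax.tree_util.tree_map(lambda buf: buf[idx], self.data)

    def get(self, i: int):
        # A single stored state, e.g. to reset the simulator to it.
        return jax.tree_util.tree_map(lambda buf: buf[i], self.data)
```

Because `sample` and `get` go through `tree_map`, the result already has the State's structure (a NamedTuple stays a NamedTuple), so no separate unflatten step is needed. If the real State contains non-array fields, those would need separate handling.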
