Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add instruction for exporting inlined constant #8707

Open
wants to merge 2 commits into
base: master
Choose a base branch
from

Conversation

qihqi
Copy link
Collaborator

@qihqi qihqi commented Feb 13, 2025

No description provided.

@@ -64,6 +64,31 @@ print(stablehlo.mlir_module())
The second to last line we used `jax.ShapedDtypeStruct` to specify the input shape.
You can also pass a numpy array here.

### Inline some weights in generated stablehlo

Suppose that you want to inline some (or all) of the model's weight
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggest:
You can inline some or all of your model's weights into the StableHLO graph as constants by exporting a separate function that calls your model.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done.

into the generated StableHLO graph as constant. You can accomplish it by
exporting a different function that calls your model.

The convention used in `jax.jit` is that, all the input of the `jit`'d python
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggest: The convention used in jax.jit is all inputs to jited Python functions are exported as parameters, everything else is inlined as constants.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done.

@qihqi qihqi requested a review from mikegre-google February 13, 2025 22:19
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants