[torch_xla] scan only captures aten operations #8691

Open
tengyifei opened this issue Feb 9, 2025 · 2 comments

@tengyifei
Collaborator

I'm encountering an issue when trying to integrate custom kernels—specifically those built with pallas—and SPMD operations (like xs.mark_sharding) into the scan function. The current design of scan relies on AOTAutograd to capture the forward and backward passes. Because of that, only standard PyTorch ATen operations are recognized in the captured graph.

pallas Kernels Excluded

  • Because pallas kernels are neither standard ATen ops nor automatically recognized by AOTAutograd, they don’t appear in the traced graph. This makes them invisible to scan, even if the kernel code is correct and runs fine outside of a scan.
  • One workaround is to wrap each pallas kernel in a custom ATen op. But that approach adds friction: every kernel needs extra boilerplate to register the op so that PyTorch’s dispatch and AOTAutograd can see it (a sketch of that registration follows this list).
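
A rough sketch of that boilerplate, using `torch.library.custom_op` (PyTorch ≥ 2.4). The op name `xla_ext::pallas_attention` and the plain-PyTorch body are placeholders standing in for a real pallas kernel; this only illustrates the per-kernel registration the bullet above refers to, not an endorsed path.

```python
import torch
from torch.library import custom_op

@custom_op("xla_ext::pallas_attention", mutates_args=())
def pallas_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # The real pallas kernel would be invoked here; a plain-PyTorch stand-in
    # keeps this sketch runnable.
    return torch.softmax(q @ k.transpose(-1, -2), dim=-1) @ v

@pallas_attention.register_fake
def _(q, k, v):
    # Shape/dtype propagation so FakeTensor tracing (and hence AOTAutograd)
    # can capture the op without actually running the kernel.
    return torch.empty_like(q)
```

Once registered this way, the kernel is expected to show up as a single custom-op node in the captured graph, which is exactly the per-kernel friction described above.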

SPMD Ops Unrecognized

  • Similar to pallas kernels, calls such as xs.mark_sharding(...) do not appear in the captured graph. The AOTAutograd tracing step sees them as Python calls that do not translate into recognized ATen ops.
  • This prevents us from assigning SPMD partitioning attributes within a scan function, making it impossible to apply sharding or other SPMD strategies inside the scanned layer (see the sketch after this list).
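
A minimal repro sketch of the SPMD case, assuming the torch_xla.experimental.scan API (`scan(fn, init, xs)`) and a trivial mesh; the shapes, mesh layout, and variable names are illustrative, not taken from a real model.

```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.experimental.scan import scan

xr.use_spmd()
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices, 1), ("data", "model"))

def step(carry, x):
    y = carry + x
    # Python-level SPMD annotation: it only marks the LazyTensor IR and has no
    # ATen counterpart, so (per this issue) it does not survive the
    # AOTAutograd capture that scan performs.
    xs.mark_sharding(y, mesh, ("data", "model"))
    return y, y

init = torch.zeros(num_devices, 8, device=xm.xla_device())
inputs = torch.ones(10, num_devices, 8, device=xm.xla_device())
final_carry, ys = scan(step, init, inputs)
```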

Tracing annotations are skipped

  • xp.trace_me(...) likewise only annotates the LazyTensor IR and has no corresponding ATen representation. As a result, the scanned layer won't carry tracing annotations (a short example follows).
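
For context, this is the usual (non-scan) usage of the annotation; the module and span name are just examples. Inside a scanned layer the same decorator would leave no trace in the captured graph, per the bullet above.

```python
import torch
import torch_xla.debug.profiler as xp

class Layer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(128, 128)

    # Annotates the lazy IR so this span shows up in profiler traces.
    @xp.trace_me("layer.forward")
    def forward(self, x):
        return self.linear(x)
```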

AOTAutograd vs. LazyTensor C++ approaches

  • scan leverages AOTAutograd to partition the forward and backward passes, enabling advanced features like gradient checkpointing or user-defined partitioning strategies. However, this means the captured graph must be composed entirely of ATen ops recognized by AOTAutograd.
  • On the other hand, the LazyTensor C++ backend (in the XLA stack) can sometimes capture more exotic or low-level operations by intercepting them at the IR level, but that path isn’t used by AOTAutograd-based flows. There’s a trade-off: AOTAutograd provides a pure-PyTorch approach for graph capture and transformation, yet it effectively filters out non-ATen ops (a minimal illustration follows this list).
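
A standalone illustration of that filtering effect, using the generic functorch `aot_function` entry point rather than scan's actual internals; the compiler hook below just prints the captured graph, which contains only ATen nodes.

```python
import torch
from functorch.compile import aot_function, make_boxed_func

def fn(x):
    return torch.sin(x) * 2.0

def fw_compiler(gm, example_inputs):
    # The traced graph is pure ATen (aten.sin, aten.mul); any Python-level
    # side effect with no ATen counterpart never becomes a node here.
    gm.graph.print_tabular()
    return make_boxed_func(gm.forward)

compiled = aot_function(fn, fw_compiler=fw_compiler)
compiled(torch.randn(4))
```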

Desired Behavior

  • Ideally, we want scan to handle any operation recognized by the underlying XLA or LazyTensor stack, including pallas kernels and SPMD operations.
  • If there is a route for AOTAutograd to allow extension ops, such as a stable mechanism for capturing pallas kernels or calls like mark_sharding, that would solve the problem. Otherwise, a different capture/trace mechanism might be needed to allow these ops in a scanned function.

Discussion Points

  • Can we extend AOTAutograd so it recognizes pallas kernels and SPMD ops (e.g., by whitelisting custom ops or hooking into the IR generation)?
  • Do we need a custom ATen registration for each pallas kernel? If so, can we streamline that process or document it?
  • Is there a recommended workaround to ensure xs.mark_sharding(...) is picked up in the graph capture?
  • Are there plans for broader extensibility in the AOTAutograd stack, allowing custom or otherwise “non-ATen” ops?
@ysiraichi added the enhancement (New feature or request) and SPMD / Distributed labels on Feb 10, 2025
@ysiraichi
Collaborator

I'm not sure whether we can extend AOTAutograd, or whether we would have to modify it inside PyTorch core. That said, if it's the latter, I think it shouldn't be too hard to do it. By the way, is it only AOTAutograd, or is it also dynamo?

@tengyifei
Collaborator Author

@ysiraichi AFAIK this will impact both AOTAutograd and Dynamo. Dynamo transforms the captured graph with AOTAutograd internally.
