I'm encountering an issue when trying to integrate custom kernels—specifically those built with pallas—and SPMD operations (like xs.mark_sharding) into the scan function. The current design of scan relies on AOTAutograd to capture the forward and backward passes. Because of that, only standard PyTorch ATen operations are recognized in the captured graph.
pallas Kernels Excluded
Because pallas kernels are neither standard ATen ops nor automatically recognized by AOTAutograd, they don’t appear in the traced graph. This makes them invisible to scan, even if the kernel code is correct and runs fine outside of a scan.
One workaround is to wrap each pallas kernel in a custom ATen op. But that approach adds friction, requiring additional boilerplate to register the op so that PyTorch’s dispatch and AOTAutograd can see it.
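For reference, a minimal sketch of what that boilerplate looks like with torch.library (the op name, signature, and kernel body here are hypothetical stand-ins, not the actual pallas integration):

```python
import torch
from torch.library import custom_op

# Hypothetical wrapper: expose a pallas kernel as a custom op so that
# AOTAutograd can record it in the captured graph.
@custom_op("xla_ext::my_pallas_add", mutates_args=())
def my_pallas_add(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # A real wrapper would invoke the compiled pallas kernel here;
    # plain addition stands in for illustration.
    return a + b

# A fake/meta implementation is required so tracing can infer output
# shapes without actually running the kernel.
@my_pallas_add.register_fake
def _(a, b):
    return torch.empty_like(a)
```

For training, a backward would also have to be registered (e.g. via register_autograd) so AOTAutograd can partition it, which is exactly the per-kernel boilerplate that makes this workaround unattractive.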
SPMD Ops Unrecognized
Similar to pallas kernels, calls such as xs.mark_sharding(...) do not appear in the captured graph. The AOTAutograd tracing step sees them as Python calls that do not translate into recognized ATen ops.
This prevents us from assigning SPMD partitioning attributes within a scan function, making it impossible to apply sharding or other SPMD strategies inside the scanned layer.
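To make the failure mode concrete, this is roughly the usage we would like to support (the mesh setup, layer shapes, and scan import path are illustrative assumptions, not a tested recipe):

```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr
from torch_xla.experimental.scan import scan  # import path assumed

xr.use_spmd()
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices,), ('model',))

def layer_fn(carry, weight):
    # Desired: annotate an intermediate with a sharding from inside the
    # scanned body. Today this call is invisible to the AOTAutograd capture.
    out = carry @ weight
    xs.mark_sharding(out, mesh, ('model', None))
    return out, out

weights = torch.randn(4, 128, 128, device=xm.xla_device())
init = torch.randn(128, 128, device=xm.xla_device())
final_carry, outs = scan(layer_fn, init, weights)
```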
Tracing annotations are skipped
Similarly, xp.trace_me(...) annotates the LazyTensor IR directly and has no corresponding ATen representation. As a result, the scanned layer loses its tracing annotations.
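For completeness, the annotation in question is the profiler decorator from torch_xla.debug.profiler; a minimal example of how it is normally applied outside of scan (the function name is arbitrary):

```python
import torch_xla.debug.profiler as xp

# Outside of scan, this decorator attaches a named span so the layer shows
# up in profiler traces. Inside a scanned function the annotation is dropped
# during AOTAutograd capture.
@xp.trace_me("decoder_layer")
def decoder_layer(x, weight):
    return x @ weight
```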
AOTAutograd vs. LazyTensor C++ approaches
scan leverages AOTAutograd to partition the forward and backward passes, enabling advanced features like gradient checkpointing or user-defined partitioning strategies. However, this means the captured graph must be composed entirely of ATen ops recognized by AOTAutograd.
On the other hand, the LazyTensor C++ backend (in the XLA stack) can sometimes capture more exotic or low-level operations by intercepting them at the IR level, but that path isn’t used by AOTAutograd-based flows. There’s a trade-off: AOTAutograd provides a pure-PyTorch approach for graph capture and transformation, yet it effectively filters out non-ATen ops.
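As a rough illustration of why non-ATen calls fall out of the picture, here is the kind of capture scan relies on, shown with the public aot_function helper (the internals of scan may differ; this is just a sketch):

```python
import torch
from functorch.compile import aot_function

def print_fw_graph(gm, _example_inputs):
    # The captured graph contains only ATen ops (mm, relu, ...). Any
    # Python-level call that does not lower to an ATen op never appears here.
    gm.print_readable()
    return gm

def fn(x, w):
    return torch.relu(x @ w)

compiled = aot_function(fn, fw_compiler=print_fw_graph)
compiled(torch.randn(4, 8, requires_grad=True),
         torch.randn(8, 8, requires_grad=True))
```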
Desired Behavior
Ideally, we want scan to handle any operation recognized by the underlying XLA or LazyTensor stack, including pallas kernels and SPMD operations.
If there is a route for AOTAutograd to allow extension ops (such as a stable mechanism for capturing pallas kernels or calls like mark_sharding), that would solve the problem. Otherwise, a different capture/trace mechanism might be needed to allow these ops in a scanned function.
Discussion Points
Can we extend AOTAutograd so it recognizes pallas kernels and SPMD ops (e.g., by whitelisting custom ops or hooking into the IR generation)?
Do we need a custom ATen registration for each pallas kernel? If so, can we streamline that process or document it?
Is there a recommended workaround to ensure xs.mark_sharding(...) is picked up in the graph capture?
Are there plans for broader extensibility in the AOTAutograd stack, allowing custom or otherwise “non-ATen” ops?
I'm not sure whether we can extend AOTAutograd, or whether we would have to modify it inside PyTorch core. That said, if it's the latter, I think it shouldn't be too hard to do it. By the way, is it only AOTAutograd, or is it also dynamo?