Doc: Document how to enable distributed error aggregation according to RFC #5598 for pytorch distributed tasks (#1776)

Signed-off-by: Fabio Grätz <[email protected]>
Co-authored-by: Fabio Grätz <[email protected]>
fg91 and Fabio Grätz authored Feb 12, 2025
1 parent 7dc2890 commit d675ebe
Showing 1 changed file with 24 additions and 7 deletions.
31 changes: 24 additions & 7 deletions examples/kfpytorch_plugin/kfpytorch_plugin/pytorch_mnist.py
@@ -350,6 +350,14 @@ def pytorch_training_wf(
# To visualize the outcomes, you can point Tensorboard on your local machine to these storage locations.
# :::
#
# :::{note}
# In distributed training, the return values of the individual workers may differ.
# If you need to control which worker's return value is passed on to subsequent tasks in the workflow,
# you can raise an
# [IgnoreOutputs exception](https://docs.flyte.org/projects/flytekit/en/latest/generated/flytekit.core.base_task.IgnoreOutputs.html)
# in all other ranks, as sketched below.
# :::
#
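# A minimal sketch of this pattern (the helper name, the task body, and the use of the
# `RANK` environment variable, which `torchrun` and the PyTorch operator set for each
# worker, are illustrative assumptions, not part of this example's source):
#
# ```python
# import os
#
# from flytekit.core.base_task import IgnoreOutputs
#
#
# def return_only_from_rank_zero(result: str) -> str:
#     # Discard the outputs of every rank except global rank zero so that only
#     # rank zero's return value is passed on to downstream tasks.
#     if int(os.environ.get("RANK", "0")) != 0:
#         raise IgnoreOutputs("Only rank zero returns its outputs.")
#     return result
# ```
#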
# ## PyTorch elastic training (torchrun)
#
# Flyte supports distributed training through [torch elastic](https://pytorch.org/docs/stable/elastic/run.html) using `torchrun`.
@@ -388,10 +396,19 @@ def pytorch_training_wf(
#
# This configuration runs distributed training on two nodes, each with four worker processes.
#
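# A minimal sketch of such a configuration (the task name and body are illustrative
# placeholders; `Elastic` is provided by the `flytekitplugins-kfpytorch` plugin):
#
# ```python
# from flytekit import task
# from flytekitplugins.kfpytorch import Elastic
#
#
# @task(
#     task_config=Elastic(
#         nnodes=2,  # two nodes
#         nproc_per_node=4,  # four worker processes per node
#     )
# )
# def train_elastic() -> None:
#     ...
# ```
#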
# :::{note}
# In distributed training, the return values of the individual workers may differ.
# If you need to control which worker's return value is passed on to subsequent tasks in the workflow,
# you can raise an
# [IgnoreOutputs exception](https://docs.flyte.org/projects/flytekit/en/latest/generated/flytekit.core.base_task.IgnoreOutputs.html)
# in all other ranks.
# :::
#
# ## Error handling for distributed PyTorch tasks
#
# Exceptions occurring in Flyte task pods are propagated to the Flyte backend by writing so-called *error files* into
# a preconfigured location in blob storage. In the case of PyTorch distributed tasks, each failed worker pod tries to write such
# an error file. By default, the backend expects and evaluates only a single error file, which leads to a race condition:
# it is not deterministic which worker pod's error file is considered. Flyte can instead aggregate the error files of all
# worker pods and use the timestamps of the exceptions to determine the root-cause error. To enable this behavior, add the
# following to your Helm chart values:
#
# ```yaml
# configmap:
#   k8s:
#     plugins:
#       k8s:
#         enable-distributed-error-aggregation: true
# ```
