From d675ebe4eaa0538487b35a8bff2e87f8574655f7 Mon Sep 17 00:00:00 2001
From: "Fabio M. Graetz, Ph.D."
Date: Wed, 12 Feb 2025 17:21:51 +0100
Subject: [PATCH] Doc: Document how to enable distributed error aggregation
 according to RFC #5598 for pytorch distributed tasks (#1776)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Fabio Grätz
Co-authored-by: Fabio Grätz
---
 .../kfpytorch_plugin/pytorch_mnist.py | 31 ++++++++++++++-----
 1 file changed, 24 insertions(+), 7 deletions(-)

diff --git a/examples/kfpytorch_plugin/kfpytorch_plugin/pytorch_mnist.py b/examples/kfpytorch_plugin/kfpytorch_plugin/pytorch_mnist.py
index 2640cb8fd..9fb43fe1f 100644
--- a/examples/kfpytorch_plugin/kfpytorch_plugin/pytorch_mnist.py
+++ b/examples/kfpytorch_plugin/kfpytorch_plugin/pytorch_mnist.py
@@ -350,6 +350,14 @@ def pytorch_training_wf(
 # To visualize the outcomes, you can point Tensorboard on your local machine to these storage locations.
 # :::
 #
+# :::{note}
+# In the context of distributed training, be aware that the return values of the different workers can differ.
+# If you need to control which worker's return value is passed on to downstream tasks in the workflow,
+# you can raise an
+# [IgnoreOutputs exception](https://docs.flyte.org/projects/flytekit/en/latest/generated/flytekit.core.base_task.IgnoreOutputs.html)
+# in all remaining ranks.
+# :::
+#
 # ## Pytorch elastic training (torchrun)
 #
 # Flyte supports distributed training through [torch elastic](https://pytorch.org/docs/stable/elastic/run.html) using `torchrun`.
@@ -388,10 +396,19 @@ def pytorch_training_wf(
 #
 # This configuration runs distributed training on two nodes, each with four worker processes.
 #
-# :::{note}
-# In the context of distributed training, it's important to acknowledge that return values from various workers could potentially vary.
-# If you need to regulate which worker's return value gets passed on to subsequent tasks in the workflow,
-# you have the option to raise an
-# [IgnoreOutputs exception](https://docs.flyte.org/projects/flytekit/en/latest/generated/flytekit.core.base_task.IgnoreOutputs.html)
-# for all remaining ranks.
-# :::
+# ## Error handling for distributed PyTorch tasks
+#
+# Exceptions raised in Flyte task pods are propagated to the Flyte backend by writing so-called *error files* into
+# a preconfigured location in blob storage. In a distributed PyTorch task, each failed worker pod tries to write such
+# an error file. By default, the backend expects and evaluates only a single error file, which leads to a race condition:
+# it is not deterministic which worker pod's error file is considered. Flyte can instead aggregate the error files of all worker pods
+# and use the timestamps of the exceptions to try to determine the root-cause error. To enable this behavior, add the following to your
+# Helm chart values:
+#
+# ```yaml
+# configmap:
+#   k8s:
+#     plugins:
+#       k8s:
+#         enable-distributed-error-aggregation: true
+# ```
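
For illustration, a minimal sketch of the `IgnoreOutputs` pattern described in the note, assuming the `Elastic` task config from `flytekitplugins.kfpytorch` (two nodes with four worker processes each, matching the torchrun configuration described above); the task name `train`, the returned metric, and the rank lookup via the `RANK` environment variable set by `torchrun` are illustrative and not part of the example file:

```python
import os

from flytekit import task
from flytekit.core.base_task import IgnoreOutputs
from flytekitplugins.kfpytorch import Elastic


@task(task_config=Elastic(nnodes=2, nproc_per_node=4))
def train() -> float:
    # ... distributed training loop; `final_loss` stands in for a real metric ...
    final_loss = 0.0

    # torchrun sets the global rank of each worker process in the RANK env var.
    global_rank = int(os.environ.get("RANK", "0"))

    # Only global rank 0 passes its return value on to downstream tasks;
    # all remaining ranks raise IgnoreOutputs.
    if global_rank != 0:
        raise IgnoreOutputs("Not global rank 0, ignoring outputs.")
    return final_loss
```

With this pattern, only rank 0's return value is a candidate output of the task, so downstream tasks in the workflow receive a deterministic result.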