From 6e60bbe9f6ba7d7c1599281efec622c3c08b221c Mon Sep 17 00:00:00 2001 From: Ali Afzal Date: Tue, 4 Feb 2025 02:22:14 -0800 Subject: [PATCH] positional and kwargs corner case fix for _build_args_kwargs (#2714) Summary: Pull Request resolved: https://github.com/pytorch/torchrec/pull/2714 ## Context: Appending None to the positional args collides with kwargs when kwargs already contains the name of an argument accepted by the method. E.g.: ``` def input_dist(ctx, id_feature_list): ... // If _build_args_kwargs returns: args = [None] kwargs = {'id_feature_list': KJT} input_dist(ctx, *args, **kwargs) // expands to input_dist(ctx, None, id_feature_list=KJT) ``` which results in "TypeError: got multiple values for argument 'id_feature_list'" because id_feature_list is provided both positionally (None) and via kwargs. Reviewed By: sarckk Differential Revision: D68892351 fbshipit-source-id: e02ff8f744bc35b9ab507c057f2a09f6d5e2bb68 --- requirements.txt | 1 + .../tests/test_train_pipelines_utils.py | 87 +++++++++++++++++++ torchrec/distributed/train_pipeline/utils.py | 5 +- 3 files changed, 92 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 6d63107dd..6b17aeac6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,6 +12,7 @@ torchmetrics==1.0.3 torchx tqdm usort +parameterized # for tests # https://github.com/pytorch/pytorch/blob/b96b1e8cff029bb0a73283e6e7f6cc240313f1dc/requirements.txt#L3 diff --git a/torchrec/distributed/train_pipeline/tests/test_train_pipelines_utils.py b/torchrec/distributed/train_pipeline/tests/test_train_pipelines_utils.py index f23dc0fe0..53fae9001 100644 --- a/torchrec/distributed/train_pipeline/tests/test_train_pipelines_utils.py +++ b/torchrec/distributed/train_pipeline/tests/test_train_pipelines_utils.py @@ -10,9 +10,11 @@ import copy import enum import unittest +from typing import List from unittest.mock import MagicMock import torch +from parameterized import parameterized from 
torchrec.distributed.embedding_types import EmbeddingComputeKernel from torchrec.distributed.test_utils.test_model import ModelInput, TestNegSamplingModule @@ -21,8 +23,10 @@ TrainPipelineSparseDistTestBase, ) from torchrec.distributed.train_pipeline.utils import ( + _build_args_kwargs, _get_node_args, _rewrite_model, + ArgInfo, PipelinedForward, PipelinedPostproc, TrainPipelineContext, @@ -253,6 +257,89 @@ def test_restore_from_snapshot(self) -> None: for source_model_type, recipient_model_type in variants: self._test_restore_from_snapshot(source_model_type, recipient_model_type) + @parameterized.expand( + [ + ( + [ + # Empty attrs to ignore any attr based logic. + ArgInfo( + input_attrs=[ + "", + ], + is_getitems=[False], + postproc_modules=[None], + constants=[None], + name="id_list_features", + ), + ArgInfo( + input_attrs=[], + is_getitems=[], + postproc_modules=[], + constants=[], + name="id_score_list_features", + ), + ], + 0, + ["id_list_features", "id_score_list_features"], + ), + ( + [ + # Empty attrs to ignore any attr based logic. + ArgInfo( + input_attrs=[ + "", + ], + is_getitems=[False], + postproc_modules=[None], + constants=[None], + name=None, + ), + ArgInfo( + input_attrs=[], + is_getitems=[], + postproc_modules=[], + constants=[], + name=None, + ), + ], + 2, + [], + ), + ( + [ + # Empty attrs to ignore any attr based logic. 
+ ArgInfo( + input_attrs=[ + "", + ], + is_getitems=[False], + postproc_modules=[None], + constants=[None], + name=None, + ), + ArgInfo( + input_attrs=[], + is_getitems=[], + postproc_modules=[], + constants=[], + name="id_score_list_features", + ), + ], + 1, + ["id_score_list_features"], + ), + ] + ) + def test_build_args_kwargs( + self, + fwd_args: List[ArgInfo], + args_len: int, + kwarges_keys: List[str], + ) -> None: + args, kwargs = _build_args_kwargs("initial_input", fwd_args) + self.assertEqual(len(args), args_len) + self.assertEqual(list(kwargs.keys()), kwarges_keys) + class TestUtils(unittest.TestCase): def test_get_node_args_helper_call_module_kjt(self) -> None: diff --git a/torchrec/distributed/train_pipeline/utils.py b/torchrec/distributed/train_pipeline/utils.py index 76cf87370..5b76c1e2d 100644 --- a/torchrec/distributed/train_pipeline/utils.py +++ b/torchrec/distributed/train_pipeline/utils.py @@ -230,7 +230,10 @@ def _build_args_kwargs( else: args.append(arg) else: - args.append(None) + if arg_info.name: + kwargs[arg_info.name] = None + else: + args.append(None) return args, kwargs