From dfdbee865240d0885d5c0ffc23685a04141f5e63 Mon Sep 17 00:00:00 2001 From: facebook-github-bot Date: Thu, 18 Nov 2021 09:42:53 -0800 Subject: [PATCH] Initial commit fbshipit-source-id: 8f9686235729bb0aa9e03e3dbf73f74e75932b3f --- .github/workflows/pre-commit.yaml | 20 + .pre-commit-config.yaml | 16 + pyproject.toml | 3 + setup.py | 9 + test_installation.py | 36 + test_installation_main.py | 120 ++ torchrec/__init__.py | 23 + torchrec/datasets/__init__.py | 6 + torchrec/datasets/criteo.py | 126 ++ torchrec/datasets/movielens.py | 133 ++ torchrec/datasets/random.py | 147 ++ torchrec/datasets/tests/__init__.py | 0 torchrec/datasets/tests/test_criteo.py | 173 +++ torchrec/datasets/tests/test_movielens.py | 129 ++ torchrec/datasets/tests/test_utils.py | 133 ++ torchrec/datasets/utils.py | 340 +++++ torchrec/distributed/__init__.py | 22 + torchrec/distributed/collective_utils.py | 66 + torchrec/distributed/comm.py | 136 ++ torchrec/distributed/comm_ops.py | 892 ++++++++++++ torchrec/distributed/cw_sharding.py | 74 + torchrec/distributed/dist_data.py | 543 +++++++ torchrec/distributed/dp_sharding.py | 194 +++ torchrec/distributed/embedding.py | 102 ++ torchrec/distributed/embedding_lookup.py | 1251 +++++++++++++++++ torchrec/distributed/embedding_sharding.py | 368 +++++ torchrec/distributed/embedding_types.py | 267 ++++ torchrec/distributed/embeddingbag.py | 800 +++++++++++ .../distributed/grouped_position_weighted.py | 98 ++ torchrec/distributed/model_parallel.py | 390 +++++ torchrec/distributed/planner/__init__.py | 6 + .../distributed/planner/cost_functions.py | 54 + .../distributed/planner/embedding_planner.py | 515 +++++++ .../distributed/planner/new/calculators.py | 333 +++++ torchrec/distributed/planner/new/constants.py | 47 + .../distributed/planner/new/enumerators.py | 581 ++++++++ .../distributed/planner/new/partitioners.py | 262 ++++ torchrec/distributed/planner/new/placers.py | 200 +++ torchrec/distributed/planner/new/planners.py | 119 ++ torchrec/distributed/planner/new/rankers.py | 108 ++ torchrec/distributed/planner/new/stats.py | 196 +++ .../planner/new/tests/test_calculators.py | 147 ++ .../planner/new/tests/test_enumerators.py | 586 ++++++++ .../planner/new/tests/test_partitioners.py | 259 ++++ .../planner/new/tests/test_placers.py | 105 ++ .../planner/new/tests/test_rankers.py | 129 ++ torchrec/distributed/planner/new/types.py | 373 +++++ .../distributed/planner/parameter_sharding.py | 282 ++++ .../planner/tests/test_embedding_planner.py | 930 ++++++++++++ torchrec/distributed/planner/types.py | 154 ++ torchrec/distributed/planner/utils.py | 292 ++++ torchrec/distributed/rw_sharding.py | 349 +++++ .../tests/collective_utils_test.py | 119 ++ torchrec/distributed/tests/test_comm.py | 151 ++ torchrec/distributed/tests/test_dist_data.py | 411 ++++++ .../distributed/tests/test_fused_optim.py | 297 ++++ .../distributed/tests/test_lazy_awaitable.py | 248 ++++ torchrec/distributed/tests/test_model.py | 473 +++++++ .../distributed/tests/test_model_parallel.py | 645 +++++++++ .../tests/test_model_parallel_base.py | 275 ++++ .../tests/test_quant_model_parallel.py | 171 +++ .../distributed/tests/test_train_pipeline.py | 265 ++++ torchrec/distributed/tests/test_utils.py | 316 +++++ torchrec/distributed/train_pipeline.py | 509 +++++++ torchrec/distributed/tw_sharding.py | 314 +++++ torchrec/distributed/twrw_sharding.py | 452 ++++++ torchrec/distributed/types.py | 542 +++++++ torchrec/distributed/utils.py | 123 ++ torchrec/examples/__init__.py | 0 torchrec/examples/dlrm/.torchxconfig | 20 + torchrec/examples/dlrm/README.MD | 11 + torchrec/examples/dlrm/__init__.py | 0 torchrec/examples/dlrm/dlrm_main.py | 194 +++ torchrec/examples/dlrm/modules/__init__.py | 0 torchrec/examples/dlrm/modules/dlrm_train.py | 71 + .../examples/dlrm/tests/test_dlrm_main.py | 47 + .../examples/notebooks/criteo_tutorial.ipynb | 869 ++++++++++++ .../notebooks/movielens_tutorial.ipynb | 639 +++++++++ torchrec/fx/__init__.py | 3 + torchrec/fx/tests/test_tracer.py | 286 ++++ torchrec/fx/tracer.py | 44 + torchrec/linter/module_linter.py | 276 ++++ torchrec/linter/tests/test_module_linter.py | 283 ++++ torchrec/models/__init__.py | 0 torchrec/models/deepfm.py | 325 +++++ torchrec/models/dlrm.py | 360 +++++ torchrec/models/tests/test_deepfm.py | 194 +++ torchrec/models/tests/test_dlrm.py | 538 +++++++ torchrec/modules/__init__.py | 1 + torchrec/modules/activation.py | 46 + torchrec/modules/crossnet.py | 401 ++++++ torchrec/modules/deepfm.py | 189 +++ torchrec/modules/embedding_configs.py | 80 ++ torchrec/modules/embedding_modules.py | 310 ++++ torchrec/modules/feature_processor.py | 52 + torchrec/modules/lazy_extension.py | 248 ++++ torchrec/modules/mlp.py | 161 +++ torchrec/modules/score_learning.py | 89 ++ torchrec/modules/tests/__init__.py | 0 torchrec/modules/tests/test_activation.py | 27 + torchrec/modules/tests/test_code_quality.py | 30 + torchrec/modules/tests/test_crossnet.py | 170 +++ torchrec/modules/tests/test_deepfm.py | 166 +++ .../modules/tests/test_embedding_modules.py | 264 ++++ torchrec/modules/tests/test_lazy_extension.py | 297 ++++ torchrec/modules/tests/test_mlp.py | 108 ++ torchrec/modules/tests/test_score_learning.py | 51 + torchrec/modules/utils.py | 112 ++ torchrec/optim/__init__.py | 11 + torchrec/optim/clipping.py | 48 + torchrec/optim/fused.py | 37 + torchrec/optim/keyed.py | 313 +++++ torchrec/optim/tests/test_clipping.py | 191 +++ torchrec/optim/tests/test_keyed.py | 212 +++ torchrec/optim/tests/test_utils.py | 14 + torchrec/optim/tests/test_warmup.py | 69 + torchrec/optim/warmup.py | 142 ++ torchrec/quant/__init__.py | 3 + torchrec/quant/embedding_modules.py | 219 +++ .../quant/tests/test_embedding_modules.py | 67 + torchrec/sparse/__init__.py | 0 torchrec/sparse/jagged_tensor.py | 1018 ++++++++++++++ torchrec/sparse/tests/__init__.py | 0 torchrec/sparse/tests/test_jagged_tensor.py | 846 +++++++++++ torchrec/sparse/tests/tests_utils.py | 40 + torchrec/tests/__init__.py | 0 torchrec/tests/utils.py | 96 ++ torchrec/types.py | 41 + 128 files changed, 28284 insertions(+) create mode 100644 .github/workflows/pre-commit.yaml create mode 100644 .pre-commit-config.yaml create mode 100644 pyproject.toml create mode 100644 setup.py create mode 100644 test_installation.py create mode 100644 test_installation_main.py create mode 100644 torchrec/__init__.py create mode 100644 torchrec/datasets/__init__.py create mode 100644 torchrec/datasets/criteo.py create mode 100644 torchrec/datasets/movielens.py create mode 100644 torchrec/datasets/random.py create mode 100644 torchrec/datasets/tests/__init__.py create mode 100644 torchrec/datasets/tests/test_criteo.py create mode 100644 torchrec/datasets/tests/test_movielens.py create mode 100644 torchrec/datasets/tests/test_utils.py create mode 100644 torchrec/datasets/utils.py create mode 100644 torchrec/distributed/__init__.py create mode 100644 torchrec/distributed/collective_utils.py create mode 100644 torchrec/distributed/comm.py create mode 100644 torchrec/distributed/comm_ops.py create mode 100644 torchrec/distributed/cw_sharding.py create mode 100644 torchrec/distributed/dist_data.py create mode 100644 torchrec/distributed/dp_sharding.py create mode 100644 torchrec/distributed/embedding.py create mode 100644 torchrec/distributed/embedding_lookup.py create mode 100644 torchrec/distributed/embedding_sharding.py create mode 100644 torchrec/distributed/embedding_types.py create mode 100644 torchrec/distributed/embeddingbag.py create mode 100644 torchrec/distributed/grouped_position_weighted.py create mode 100644 torchrec/distributed/model_parallel.py create mode 100644 torchrec/distributed/planner/__init__.py create mode 100644 torchrec/distributed/planner/cost_functions.py create mode 100644 torchrec/distributed/planner/embedding_planner.py create mode 100644 torchrec/distributed/planner/new/calculators.py create mode 100644 torchrec/distributed/planner/new/constants.py create mode 100644 torchrec/distributed/planner/new/enumerators.py create mode 100644 torchrec/distributed/planner/new/partitioners.py create mode 100644 torchrec/distributed/planner/new/placers.py create mode 100644 torchrec/distributed/planner/new/planners.py create mode 100644 torchrec/distributed/planner/new/rankers.py create mode 100644 torchrec/distributed/planner/new/stats.py create mode 100644 torchrec/distributed/planner/new/tests/test_calculators.py create mode 100644 torchrec/distributed/planner/new/tests/test_enumerators.py create mode 100644 torchrec/distributed/planner/new/tests/test_partitioners.py create mode 100644 torchrec/distributed/planner/new/tests/test_placers.py create mode 100644 torchrec/distributed/planner/new/tests/test_rankers.py create mode 100644 torchrec/distributed/planner/new/types.py create mode 100644 torchrec/distributed/planner/parameter_sharding.py create mode 100644 torchrec/distributed/planner/tests/test_embedding_planner.py create mode 100644 torchrec/distributed/planner/types.py create mode 100644 torchrec/distributed/planner/utils.py create mode 100644 torchrec/distributed/rw_sharding.py create mode 100644 torchrec/distributed/tests/collective_utils_test.py create mode 100644 torchrec/distributed/tests/test_comm.py create mode 100644 torchrec/distributed/tests/test_dist_data.py create mode 100644 torchrec/distributed/tests/test_fused_optim.py create mode 100644 torchrec/distributed/tests/test_lazy_awaitable.py create mode 100644 torchrec/distributed/tests/test_model.py create mode 100644 torchrec/distributed/tests/test_model_parallel.py create mode 100644 torchrec/distributed/tests/test_model_parallel_base.py create mode 100644 torchrec/distributed/tests/test_quant_model_parallel.py create mode 100644 torchrec/distributed/tests/test_train_pipeline.py create mode 100644 torchrec/distributed/tests/test_utils.py create mode 100644 torchrec/distributed/train_pipeline.py create mode 100644 torchrec/distributed/tw_sharding.py create mode 100644 torchrec/distributed/twrw_sharding.py create mode 100644 torchrec/distributed/types.py create mode 100644 torchrec/distributed/utils.py create mode 100644 torchrec/examples/__init__.py create mode 100644 torchrec/examples/dlrm/.torchxconfig create mode 100644 torchrec/examples/dlrm/README.MD create mode 100644 torchrec/examples/dlrm/__init__.py create mode 100644 torchrec/examples/dlrm/dlrm_main.py create mode 100644 torchrec/examples/dlrm/modules/__init__.py create mode 100644 torchrec/examples/dlrm/modules/dlrm_train.py create mode 100644 torchrec/examples/dlrm/tests/test_dlrm_main.py create mode 100644 torchrec/examples/notebooks/criteo_tutorial.ipynb create mode 100644 torchrec/examples/notebooks/movielens_tutorial.ipynb create mode 100644 torchrec/fx/__init__.py create mode 100644 torchrec/fx/tests/test_tracer.py create mode 100644 torchrec/fx/tracer.py create mode 100644 torchrec/linter/module_linter.py create mode 100644 torchrec/linter/tests/test_module_linter.py create mode 100644 torchrec/models/__init__.py create mode 100644 torchrec/models/deepfm.py create mode 100644 torchrec/models/dlrm.py create mode 100644 torchrec/models/tests/test_deepfm.py create mode 100644 torchrec/models/tests/test_dlrm.py create mode 100644 torchrec/modules/__init__.py create mode 100644 torchrec/modules/activation.py create mode 100644 torchrec/modules/crossnet.py create mode 100644 torchrec/modules/deepfm.py create mode 100644 torchrec/modules/embedding_configs.py create mode 100644 torchrec/modules/embedding_modules.py create mode 100644 torchrec/modules/feature_processor.py create mode 100644 torchrec/modules/lazy_extension.py create mode 100644 torchrec/modules/mlp.py create mode 100644 torchrec/modules/score_learning.py create mode 100644 torchrec/modules/tests/__init__.py create mode 100644 torchrec/modules/tests/test_activation.py create mode 100644 torchrec/modules/tests/test_code_quality.py create mode 100644 torchrec/modules/tests/test_crossnet.py create mode 100644 torchrec/modules/tests/test_deepfm.py create mode 100644 torchrec/modules/tests/test_embedding_modules.py create mode 100644 torchrec/modules/tests/test_lazy_extension.py create mode 100644 torchrec/modules/tests/test_mlp.py create mode 100644 torchrec/modules/tests/test_score_learning.py create mode 100644 torchrec/modules/utils.py create mode 100644 torchrec/optim/__init__.py create mode 100644 torchrec/optim/clipping.py create mode 100644 torchrec/optim/fused.py create mode 100644 torchrec/optim/keyed.py create mode 100644 torchrec/optim/tests/test_clipping.py create mode 100644 torchrec/optim/tests/test_keyed.py create mode 100644 torchrec/optim/tests/test_utils.py create mode 100644 torchrec/optim/tests/test_warmup.py create mode 100644 torchrec/optim/warmup.py create mode 100644 torchrec/quant/__init__.py create mode 100644 torchrec/quant/embedding_modules.py create mode 100644 torchrec/quant/tests/test_embedding_modules.py create mode 100644 torchrec/sparse/__init__.py create mode 100644 torchrec/sparse/jagged_tensor.py create mode 100644 torchrec/sparse/tests/__init__.py create mode 100644 torchrec/sparse/tests/test_jagged_tensor.py create mode 100644 torchrec/sparse/tests/tests_utils.py create mode 100644 torchrec/tests/__init__.py create mode 100644 torchrec/tests/utils.py create mode 100644 torchrec/types.py diff --git a/.github/workflows/pre-commit.yaml b/.github/workflows/pre-commit.yaml new file mode 100644 index 000000000..a8fab7333 --- /dev/null +++ b/.github/workflows/pre-commit.yaml @@ -0,0 +1,20 @@ +name: pre-commit + +on: + push: + branches: [master] + pull_request: + +jobs: + pre-commit: + runs-on: ubuntu-latest + steps: + - name: Setup Python + uses: actions/setup-python@v2 + with: + python-version: 3.8 + architecture: x64 + - name: Checkout Torchrec + uses: actions/checkout@v2 + - name: Run pre-commit + uses: pre-commit/action@v2.0.3 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 000000000..c7f310e60 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,16 @@ +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.0.1 + hooks: + - id: check-toml + - id: check-yaml + exclude: packaging/.* + - id: end-of-file-fixer + + - repo: https://github.com/omnilib/ufmt + rev: v1.3.0 + hooks: + - id: ufmt + additional_dependencies: + - black == 21.9b0 + - usort == 0.6.4 diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 000000000..9f90a09d2 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,3 @@ +[tool.usort] + +first_party_detection = false diff --git a/setup.py b/setup.py new file mode 100644 index 000000000..63604d953 --- /dev/null +++ b/setup.py @@ -0,0 +1,9 @@ +#!/usr/bin/env python3 + +from setuptools import setup, find_packages + +# Minimal setup configuration. +setup( + name="torchrec", + packages=find_packages(exclude=("*tests",)), +) diff --git a/test_installation.py b/test_installation.py new file mode 100644 index 000000000..71664dd59 --- /dev/null +++ b/test_installation.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python3 + +import os + +import torchx.specs as specs +from torchx.components.base import torch_dist_role +from torchx.specs.api import Resource + + +def test_installation() -> specs.AppDef: + cwd = os.getcwd() + entrypoint = os.path.join(cwd, "test_installation_main.py") + + user = os.environ.get("USER") + image = f"/data/home/{user}" + + return specs.AppDef( + name="test_installation", + roles=[ + torch_dist_role( + name="trainer", + image=image, + # AWS p4d instance (https://aws.amazon.com/ec2/instance-types/p4/). + resource=Resource( + cpu=96, + gpu=8, + memMB=-1, + ), + num_replicas=1, + entrypoint=entrypoint, + nproc_per_node="1", + rdzv_backend="c10d", + args=[], + ), + ], + ) diff --git a/test_installation_main.py b/test_installation_main.py new file mode 100644 index 000000000..7753ee5cf --- /dev/null +++ b/test_installation_main.py @@ -0,0 +1,120 @@ +#!/usr/bin/env python3 + +import os +import sys +from typing import List, Iterator + +import torch +import torch.distributed as dist +from torchrec import EmbeddingBagCollection +from torchrec import KeyedJaggedTensor +from torchrec.distributed.model_parallel import DistributedModelParallel +from torchrec.models.dlrm import DLRM +from torchrec.modules.embedding_configs import EmbeddingBagConfig +from torchrec.optim.keyed import KeyedOptimizerWrapper + + +class RandomIterator(Iterator): + def __init__( + self, batch_size: int, num_dense: int, num_sparse: int, num_embeddings: int + ) -> None: + self.batch_size = batch_size + self.num_dense = num_dense + self.num_sparse = num_sparse + self.sparse_keys = [f"feature{id}" for id in range(self.num_sparse)] + self.num_embeddings = num_embeddings + self.num_ids_per_feature = 3 + self.num_ids_to_generate = ( + self.num_sparse * self.num_ids_per_feature * self.batch_size + ) + + def __next__(self) -> (torch.Tensor, KeyedJaggedTensor, torch.Tensor): + float_features = torch.randn( + self.batch_size, + self.num_dense, + ) + labels = torch.randint( + low=0, + high=2, + size=(self.batch_size,), + ) + sparse_ids = torch.randint( + high=self.num_sparse, + size=(self.num_ids_to_generate,), + ) + sparse_features = KeyedJaggedTensor.from_offsets_sync( + keys=self.sparse_keys, + values=sparse_ids, + offsets=torch.tensor( + list(range(0, self.num_ids_to_generate + 1, self.num_ids_per_feature)), + dtype=torch.int32, + ), + ) + return (float_features, sparse_features, labels) + + +def main(argv: List[str]) -> None: + batch_size = 1024 + num_dense = 1000 + num_sparse = 20 + num_embeddings = 1000000 + + configs = [ + EmbeddingBagConfig( + name=f"table{id}", + embedding_dim=64, + num_embeddings=num_embeddings, + feature_names=[f"feature{id}"], + ) + for id in range(num_sparse) + ] + + rank = int(os.environ["LOCAL_RANK"]) + if torch.cuda.is_available(): + device = torch.device(f"cuda:{rank}") + backend = "nccl" + torch.cuda.set_device(device) + else: + raise Exception("Cuda not available") + + if not torch.distributed.is_initialized(): + dist.init_process_group(backend=backend) + + model = DLRM( + embedding_bag_collection=EmbeddingBagCollection( + tables=configs, device=torch.device("meta") + ), + dense_in_features=num_dense, + dense_arch_layer_sizes=[500, 64], + over_arch_layer_sizes=[32, 16, 1], + dense_device=device, + ) + model = DistributedModelParallel( + module=model, + device=device, + ) + optimizer = KeyedOptimizerWrapper( + dict(model.named_parameters()), + lambda params: torch.optim.SGD(params, lr=0.01), + ) + + random_iterator = RandomIterator(batch_size, num_dense, num_sparse, num_embeddings) + loss_fn = torch.nn.BCEWithLogitsLoss() + for _ in range(10): + (dense_features, sparse_features, labels) = next(random_iterator) + dense_features = dense_features.to(device) + sparse_features = sparse_features.to(device) + output = model(dense_features, sparse_features) + labels = labels.to(device) + loss = loss_fn(output.squeeze(), labels.float()) + torch.sum(loss, dim=0).backward() + optimizer.zero_grad() + optimizer.step() + + print( + "\033[92m" + "Successfully ran a few epochs for DLRM. Installation looks good!" + ) + + +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/torchrec/__init__.py b/torchrec/__init__.py new file mode 100644 index 000000000..96da1a766 --- /dev/null +++ b/torchrec/__init__.py @@ -0,0 +1,23 @@ +#!/usr/bin/env python3 + +import torchrec.distributed # noqa +import torchrec.quant # noqa +from torchrec.fx import tracer # noqa +from torchrec.modules.embedding_configs import ( # noqa + EmbeddingBagConfig, + EmbeddingConfig, + DataType, + PoolingType, +) +from torchrec.modules.embedding_modules import ( # noqa + EmbeddingBagCollection, + EmbeddingCollection, + EmbeddingBagCollectionInterface, +) # noqa +from torchrec.modules.score_learning import PositionWeightsAttacher # noqa +from torchrec.sparse.jagged_tensor import ( # noqa + JaggedTensor, + KeyedJaggedTensor, + KeyedTensor, +) +from torchrec.types import Pipelineable, Multistreamable # noqa diff --git a/torchrec/datasets/__init__.py b/torchrec/datasets/__init__.py new file mode 100644 index 000000000..0582b2e49 --- /dev/null +++ b/torchrec/datasets/__init__.py @@ -0,0 +1,6 @@ +#!/usr/bin/env python3 + +import torchrec.datasets.criteo # noqa +import torchrec.datasets.movielens # noqa +import torchrec.datasets.random # noqa +import torchrec.datasets.utils # noqa diff --git a/torchrec/datasets/criteo.py b/torchrec/datasets/criteo.py new file mode 100644 index 000000000..ed716029b --- /dev/null +++ b/torchrec/datasets/criteo.py @@ -0,0 +1,126 @@ +#!/usr/bin/env python3 + +from typing import ( + Iterator, + Any, + Callable, + Dict, + Iterable, + List, + Optional, + Union, +) + +import torch +import torch.utils.data.datapipes as dp +from torch.utils.data import IterDataPipe +from torchrec.datasets.utils import LoadFiles, ReadLinesFromCSV, safe_cast + + +INT_FEATURE_COUNT = 13 +CAT_FEATURE_COUNT = 26 +DEFAULT_LABEL_NAME = "label" +DEFAULT_INT_NAMES: List[str] = [f"int_{idx}" for idx in range(INT_FEATURE_COUNT)] +DEFAULT_CAT_NAMES: List[str] = [f"cat_{idx}" for idx in range(CAT_FEATURE_COUNT)] +DEFAULT_COLUMN_NAMES: List[str] = [ + DEFAULT_LABEL_NAME, + *DEFAULT_INT_NAMES, + *DEFAULT_CAT_NAMES, +] + +COLUMN_TYPE_CASTERS: List[Callable[[Union[int, str]], Union[int, str]]] = [ + lambda val: safe_cast(val, int, 0), + *(lambda val: safe_cast(val, int, 0) for _ in range(INT_FEATURE_COUNT)), + *(lambda val: safe_cast(val, str, "") for _ in range(CAT_FEATURE_COUNT)), +] + + +def _default_row_mapper(example: List[str]) -> Dict[str, Union[int, str]]: + column_names = reversed(DEFAULT_COLUMN_NAMES) + column_type_casters = reversed(COLUMN_TYPE_CASTERS) + return { + next(column_names): next(column_type_casters)(val) for val in reversed(example) + } + + +class CriteoIterDataPipe(IterDataPipe): + def __init__( + self, + paths: Iterable[str], + *, + # pyre-ignore[2] + row_mapper: Optional[Callable[[List[str]], Any]] = _default_row_mapper, + # pyre-ignore[2] + **open_kw, + ) -> None: + self.paths = paths + self.row_mapper = row_mapper + self.open_kw: Any = open_kw # pyre-ignore[4] + + # pyre-ignore[3] + def __iter__(self) -> Iterator[Any]: + worker_info = torch.utils.data.get_worker_info() + paths = self.paths + if worker_info is not None: + paths = ( + path + for (idx, path) in enumerate(paths) + if idx % worker_info.num_workers == worker_info.id + ) + datapipe = LoadFiles(paths, mode="r", **self.open_kw) + datapipe = ReadLinesFromCSV(datapipe, delimiter="\t") + if self.row_mapper: + datapipe = dp.iter.Mapper(datapipe, self.row_mapper) + yield from datapipe + + +def criteo_terabyte( + paths: Iterable[str], + *, + # pyre-ignore[2] + row_mapper: Optional[Callable[[List[str]], Any]] = _default_row_mapper, + # pyre-ignore[2] + **open_kw, +) -> IterDataPipe: + """`Criteo 1TB Click Logs `_ Dataset + Args: + paths (str): local paths to TSV files that constitute the Criteo 1TB dataset. + row_mapper (Optional[Callable[[List[str]], Any]]): function to apply to each split TSV line. + open_kw: options to pass to underlying invocation of iopath.common.file_io.PathManager.open. + + Example: + >>> datapipe = criteo_terabyte( + >>> ("/home/datasets/criteo/day_0.tsv", "/home/datasets/criteo/day_1.tsv") + >>> ) + >>> datapipe = dp.iter.Batcher(datapipe, 100) + >>> datapipe = dp.iter.Collator(datapipe) + >>> batch = next(iter(datapipe)) + """ + return CriteoIterDataPipe(paths, row_mapper=row_mapper, **open_kw) + + +def criteo_kaggle( + path: str, + *, + # pyre-ignore[2] + row_mapper: Optional[Callable[[List[str]], Any]] = _default_row_mapper, + # pyre-ignore[2] + **open_kw, +) -> IterDataPipe: + """`Kaggle/Criteo Display Advertising `_ Dataset + Args: + root (str): local path to train or test dataset file. + row_mapper (Optional[Callable[[List[str]], Any]]): function to apply to each split TSV line. + open_kw: options to pass to underlying invocation of iopath.common.file_io.PathManager.open. + + Example: + >>> train_datapipe = criteo_kaggle( + >>> "/home/datasets/criteo_kaggle/train.txt", + >>> ) + >>> example = next(iter(train_datapipe)) + >>> test_datapipe = criteo_kaggle( + >>> "/home/datasets/criteo_kaggle/test.txt", + >>> ) + >>> example = next(iter(test_datapipe)) + """ + return CriteoIterDataPipe((path,), row_mapper=row_mapper, **open_kw) diff --git a/torchrec/datasets/movielens.py b/torchrec/datasets/movielens.py new file mode 100644 index 000000000..c51f1d20b --- /dev/null +++ b/torchrec/datasets/movielens.py @@ -0,0 +1,133 @@ +#!/usr/bin/env python3 + +import os +from typing import Any, Callable, Dict, List, Optional, Union + +from torch.utils.data import IterDataPipe +from torchrec.datasets.utils import LoadFiles, ReadLinesFromCSV, safe_cast + +RATINGS_FILENAME = "ratings.csv" +MOVIES_FILENAME = "movies.csv" + +DEFAULT_RATINGS_COLUMN_NAMES: List[str] = ["userId", "movieId", "rating", "timestamp"] +DEFAULT_MOVIES_COLUMN_NAMES: List[str] = ["movieId", "title", "genres"] +DEFAULT_COLUMN_NAMES: List[str] = ( + DEFAULT_RATINGS_COLUMN_NAMES + DEFAULT_MOVIES_COLUMN_NAMES[1:] +) + +COLUMN_TYPE_CASTERS: List[ + Callable[[Union[float, int, str]], Union[float, int, str]] +] = [ + lambda val: safe_cast(val, int, 0), + lambda val: safe_cast(val, int, 0), + lambda val: safe_cast(val, float, 0.0), + lambda val: safe_cast(val, int, 0), + lambda val: safe_cast(val, str, ""), + lambda val: safe_cast(val, str, ""), +] + + +def _default_row_mapper(example: List[str]) -> Dict[str, Union[float, int, str]]: + return { + DEFAULT_COLUMN_NAMES[idx]: COLUMN_TYPE_CASTERS[idx](val) + for idx, val in enumerate(example) + } + + +def _join_with_movies(datapipe: IterDataPipe, root: str) -> IterDataPipe: + movies_path = os.path.join(root, MOVIES_FILENAME) + movies_datapipe = LoadFiles((movies_path,), mode="r") + movies_datapipe = ReadLinesFromCSV( + movies_datapipe, + skip_first_line=True, + delimiter=",", + ) + movie_id_to_movie: Dict[str, List[str]] = { + row[0]: row[1:] for row in movies_datapipe + } + + def join_rating_movie(val: List[str]) -> List[str]: + return val + movie_id_to_movie[val[1]] + + return datapipe.map(join_rating_movie) + + +def _movielens( + root: str, + *, + include_movies_data: bool = False, + # pyre-ignore[2] + row_mapper: Optional[Callable[[List[str]], Any]] = _default_row_mapper, + # pyre-ignore[2] + **open_kw, +) -> IterDataPipe: + ratings_path = os.path.join(root, RATINGS_FILENAME) + datapipe = LoadFiles((ratings_path,), mode="r", **open_kw) + datapipe = ReadLinesFromCSV(datapipe, skip_first_line=True, delimiter=",") + + if include_movies_data: + datapipe = _join_with_movies(datapipe, root) + if row_mapper: + datapipe = datapipe.map(row_mapper) + + return datapipe + + +def movielens_20m( + root: str, + *, + include_movies_data: bool = False, + # pyre-ignore[2] + row_mapper: Optional[Callable[[List[str]], Any]] = _default_row_mapper, + # pyre-ignore[2] + **open_kw, +) -> IterDataPipe: + """`MovieLens 20M `_ Dataset + Args: + root (str): local path to root directory containing MovieLens 20M dataset files. + include_movies_data (bool): if True, adds movies data to each line. + row_mapper (Optional[Callable[[List[str]], Any]]): function to apply to each split line. + open_kw: options to pass to underlying invocation of iopath.common.file_io.PathManager.open. + + Examples: + >>> datapipe = movielens_20m("/home/datasets/ml-20") + >>> datapipe = dp.iter.Batch(datapipe, 100) + >>> datapipe = dp.iter.Collate(datapipe) + >>> batch = next(iter(datapipe)) + """ + return _movielens( + root, + include_movies_data=include_movies_data, + row_mapper=row_mapper, + **open_kw, + ) + + +def movielens_25m( + root: str, + *, + include_movies_data: bool = False, + # pyre-ignore[2] + row_mapper: Optional[Callable[[List[str]], Any]] = _default_row_mapper, + # pyre-ignore[2] + **open_kw, +) -> IterDataPipe: + """`MovieLens 25M `_ Dataset + Args: + root (str): local path to root directory containing MovieLens 25M dataset files. + include_movies_data (bool): if True, adds movies data to each line. + row_mapper (Optional[Callable[[List[str]], Any]]): function to apply to each split line. + open_kw: options to pass to underlying invocation of iopath.common.file_io.PathManager.open. + + Examples: + >>> datapipe = movielens_25m("/home/datasets/ml-25") + >>> datapipe = dp.iter.Batch(datapipe, 100) + >>> datapipe = dp.iter.Collate(datapipe) + >>> batch = next(iter(datapipe)) + """ + return _movielens( + root, + include_movies_data=include_movies_data, + row_mapper=row_mapper, + **open_kw, + ) diff --git a/torchrec/datasets/random.py b/torchrec/datasets/random.py new file mode 100644 index 000000000..a78ca3f92 --- /dev/null +++ b/torchrec/datasets/random.py @@ -0,0 +1,147 @@ +#!/usr/bin/env python3 + +from typing import Iterator, List, Optional + +import torch +from pyre_extensions import none_throws +from torch.utils.data.dataset import IterableDataset +from torchrec.datasets.utils import Batch +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor + + +class _RandomRecBatch: + generator: Optional[torch.Generator] + + def __init__( + self, + keys: List[str], + batch_size: int, + hash_size: Optional[int], + hash_sizes: Optional[List[int]], + ids_per_feature: int, + num_dense: int, + manual_seed: Optional[int] = None, + ) -> None: + if (hash_size is None and hash_sizes is None) or ( + hash_size is not None and hash_sizes is not None + ): + raise ValueError( + "One - and only one - of hash_size or hash_sizes must be set." + ) + + self.keys = keys + self.keys_length: int = len(keys) + self.batch_size = batch_size + self.hash_size = hash_size + self.hash_sizes = hash_sizes + self.ids_per_feature = ids_per_feature + self.num_dense = num_dense + + if manual_seed is not None: + self.generator = torch.Generator() + # pyre-ignore[16] + self.generator.manual_seed(manual_seed) + else: + self.generator = None + + self.iter_num = 0 + self._num_ids_in_batch: int = ( + self.ids_per_feature * self.keys_length * self.batch_size + ) + self.max_values: Optional[torch.Tensor] = None + if hash_sizes is not None: + self.max_values: torch.Tensor = torch.tensor( + [ + hash_size + for hash_size in hash_sizes + for b in range(batch_size) + for i in range(ids_per_feature) + ] + ) + self._generated_batches: List[Batch] = [self._generate_batch()] * 10 + self.batch_index = 0 + + def __iter__(self) -> "_RandomRecBatch": + return self + + def __next__(self) -> Batch: + batch = self._generated_batches[self.batch_index % len(self._generated_batches)] + self.batch_index += 1 + return batch + + def _generate_batch(self) -> Batch: + if self.hash_sizes is None: + # pyre-ignore[28] + values = torch.randint( + high=self.hash_size, + size=(self._num_ids_in_batch,), + generator=self.generator, + ) + else: + values = ( + torch.rand( + self._num_ids_in_batch, + generator=self.generator, + ) + * none_throws(self.max_values) + ).type(torch.LongTensor) + sparse_features = KeyedJaggedTensor.from_offsets_sync( + keys=self.keys, + values=values, + offsets=torch.tensor( + list( + range( + 0, + self._num_ids_in_batch + 1, + self.ids_per_feature, + ) + ), + dtype=torch.int32, + ), + ) + + dense_features = torch.randn( + self.batch_size, + self.num_dense, + generator=self.generator, + ) + # pyre-ignore[28] + labels = torch.randint( + low=0, + high=2, + size=(self.batch_size,), + generator=self.generator, + ) + + batch = Batch( + dense_features=dense_features, + sparse_features=sparse_features, + labels=labels, + ) + return batch + + +class RandomRecDataset(IterableDataset[Batch]): + def __init__( + self, + keys: List[str], + batch_size: int, + hash_size: Optional[int] = 100, + hash_sizes: Optional[List[int]] = None, + ids_per_feature: int = 2, + num_dense: int = 50, + manual_seed: Optional[int] = None, + ) -> None: + super().__init__() + self.batch_generator = _RandomRecBatch( + keys=keys, + batch_size=batch_size, + hash_size=hash_size, + hash_sizes=hash_sizes, + ids_per_feature=ids_per_feature, + num_dense=num_dense, + manual_seed=manual_seed, + ) + + def __iter__(self) -> Iterator[Batch]: + return iter(self.batch_generator) diff --git a/torchrec/datasets/tests/__init__.py b/torchrec/datasets/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/torchrec/datasets/tests/test_criteo.py b/torchrec/datasets/tests/test_criteo.py new file mode 100644 index 000000000..7b3d40070 --- /dev/null +++ b/torchrec/datasets/tests/test_criteo.py @@ -0,0 +1,173 @@ +#!/usr/bin/env python3 + +import contextlib +import csv +import os +import random +import tempfile +import unittest +from typing import List, Any, Dict, Generator + +from torch.utils.data import DataLoader +from torchrec.datasets.criteo import criteo_kaggle, criteo_terabyte + + +class _CriteoTest(unittest.TestCase): + INT_FEATURE_COUNT = 13 + CAT_FEATURE_COUNT = 26 + + LABEL_VAL_RANGE = (0, 1) + INT_VAL_RANGE = (0, 100) + CAT_VAL_RANGE = (0, 1000) + + @classmethod + @contextlib.contextmanager + def _create_dataset_tsv( + cls, + num_rows: int = 10, + train: bool = True, + filename: str = "criteo", + ) -> Generator[str, None, None]: + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, filename) + with open(path, "w") as f: + rows = [] + for _ in range(num_rows): + row = [] + if train: + row.append(str(random.randint(*cls.LABEL_VAL_RANGE))) + row += [ + *( + str(random.randint(*cls.INT_VAL_RANGE)) + for _ in range(cls.INT_FEATURE_COUNT) + ), + *( + ( + "%x" + % abs(hash(str(random.randint(*cls.CAT_VAL_RANGE)))) + ).zfill(8)[:8] + for _ in range(cls.CAT_FEATURE_COUNT) + ), + ] + rows.append(row) + # pyre-ignore[6] + cf = csv.writer(f, delimiter="\t") + cf.writerows(rows) + yield path + + def _validate_sample(self, sample: Dict[str, Any], train: bool = True) -> None: + if train: + self.assertEqual( + len(sample), self.INT_FEATURE_COUNT + self.CAT_FEATURE_COUNT + 1 + ) + label_val = sample["label"] + self.assertTrue( + self.LABEL_VAL_RANGE[0] <= label_val <= self.LABEL_VAL_RANGE[1] + ) + else: + self.assertEqual( + len(sample), self.INT_FEATURE_COUNT + self.CAT_FEATURE_COUNT + ) + for idx in range(self.INT_FEATURE_COUNT): + int_val = sample[f"int_{idx}"] + self.assertTrue(self.INT_VAL_RANGE[0] <= int_val <= self.INT_VAL_RANGE[1]) + for idx in range(self.CAT_FEATURE_COUNT): + cat_val = int(sample[f"cat_{idx}"], 16) + self.assertTrue(0 <= cat_val <= 16 ** 8 - 1) + + +class CriteoTerabyteTest(_CriteoTest): + def test_single_file(self) -> None: + with self._create_dataset_tsv() as dataset_pathname: + dataset = criteo_terabyte((dataset_pathname,)) + for sample in dataset: + self._validate_sample(sample) + self.assertEqual(len(list(iter(dataset))), 10) + + def test_multiple_files(self) -> None: + with contextlib.ExitStack() as stack: + pathnames = [ + stack.enter_context(self._create_dataset_tsv()) for _ in range(3) + ] + dataset = criteo_terabyte(pathnames) + for sample in dataset: + self._validate_sample(sample) + self.assertEqual(len(list(iter(dataset))), 30) + + +class CriteoKaggleTest(_CriteoTest): + def test_train_file(self) -> None: + with self._create_dataset_tsv() as path: + dataset = criteo_kaggle(path) + for sample in dataset: + self._validate_sample(sample) + self.assertEqual(len(list(iter(dataset))), 10) + + def test_test_file(self) -> None: + with self._create_dataset_tsv(train=False) as path: + dataset = criteo_kaggle(path) + for sample in dataset: + self._validate_sample(sample, train=False) + self.assertEqual(len(list(iter(dataset))), 10) + + +class CriteoDataLoaderTest(_CriteoTest): + def _validate_dataloader_sample( + self, + sample: Dict[str, List[Any]], # pyre-ignore[2] + batch_size: int, + train: bool = True, + ) -> None: + unbatched_samples = [{} for _ in range(self._sample_len(sample))] + for k, batched_values in sample.items(): + for (idx, value) in enumerate(batched_values): + unbatched_samples[idx][k] = value + for sample in unbatched_samples: + self._validate_sample(sample, train=train) + + def _sample_len( + self, + sample: Dict[str, List[Any]], # pyre-ignore[2] + ) -> int: + return len(next(iter(sample.values()))) + + def _test_dataloader( + self, + num_workers: int = 0, + batch_size: int = 1, + num_tsvs: int = 1, + num_rows_per_tsv: int = 10, + train: bool = True, + ) -> None: + with contextlib.ExitStack() as stack: + pathnames = [ + stack.enter_context( + self._create_dataset_tsv(num_rows=num_rows_per_tsv, train=train) + ) + for _ in range(num_tsvs) + ] + dataset = criteo_terabyte(pathnames) + dataloader = DataLoader( + dataset, batch_size=batch_size, num_workers=num_workers + ) + total_len = 0 + for sample in dataloader: + sample_len = self._sample_len(sample) + total_len += sample_len + self._validate_dataloader_sample( + sample, batch_size=batch_size, train=train + ) + self.assertEqual(total_len, len(list(iter(dataset)))) + + def test_multiple_train_workers(self) -> None: + self._test_dataloader( + num_workers=4, batch_size=16, num_tsvs=5, num_rows_per_tsv=32 + ) + + def test_fewer_tsvs_than_workers(self) -> None: + self._test_dataloader( + num_workers=2, batch_size=16, num_tsvs=1, num_rows_per_tsv=16 + ) + + def test_single_worker(self) -> None: + self._test_dataloader(batch_size=16, num_tsvs=2, num_rows_per_tsv=16) diff --git a/torchrec/datasets/tests/test_movielens.py b/torchrec/datasets/tests/test_movielens.py new file mode 100644 index 000000000..a686d20f6 --- /dev/null +++ b/torchrec/datasets/tests/test_movielens.py @@ -0,0 +1,129 @@ +#!/usr/bin/env python3 + +import contextlib +import csv +import os +import random +import tempfile +import unittest +from typing import Any, Callable, Dict, Generator, Iterable, List, Type, Union + +from torchrec.datasets.movielens import movielens_20m, movielens_25m + + +class MovieLensTest(unittest.TestCase): + RATINGS_FILENAME = "ratings.csv" + MOVIES_FILENAME = "movies.csv" + + DEFAULT_RATINGS_COLUMN_NAMES: List[str] = [ + "userId", + "movieId", + "rating", + "timestamp", + ] + DEFAULT_RATINGS_COLUMN_TYPES: List[Type[Union[float, int, str]]] = [ + int, + int, + float, + int, + ] + DEFAULT_MOVIES_COLUMN_NAMES: List[str] = ["movieId", "title", "genres"] + DEFAULT_MOVIES_COLUMN_TYPES: List[Type[Union[float, int, str]]] = [int, str, str] + + MOVIE_ID_RANGE = (0, 100) + + # pyre-ignore[2] + def _create_csv(self, filename: str, rows: Iterable[Any]) -> None: + with open(filename, "w") as f: + # pyre-ignore[6]: Expected `_csv._Writer` for 1st positional only parameter + cf = csv.writer(f, delimiter=",") + cf.writerows(rows) + + def _create_ratings_csv(self, filename: str, num_rows: int) -> None: + self._create_csv( + filename, + [self.DEFAULT_RATINGS_COLUMN_NAMES] + + [ + [ + str(random.randint(0, 100)), + str(random.randint(*self.MOVIE_ID_RANGE)), + str(random.randint(0, 10) / 2), + str(random.randint(0, 100000)), + ] + for _ in range(num_rows) + ], + ) + + def _create_movies_csv(self, filename: str, movie_ids: Iterable[str]) -> None: + self._create_csv( + filename, + [self.DEFAULT_MOVIES_COLUMN_NAMES] + + [ + [movie_id, "title", "action|adventure|comedy"] for movie_id in movie_ids + ], + ) + + @contextlib.contextmanager + def _create_root(self, ratings_row_count: int) -> Generator[str, None, None]: + with tempfile.TemporaryDirectory() as tmpdir: + self._create_ratings_csv( + os.path.join(tmpdir, self.RATINGS_FILENAME), ratings_row_count + ) + self._create_movies_csv( + os.path.join(tmpdir, self.MOVIES_FILENAME), + movie_ids=[ + str(movie_id) for movie_id in range(self.MOVIE_ID_RANGE[1] + 1) + ], + ) + yield tmpdir + + def _validate_sample( + self, + sample: Dict[str, Any], + expected_column_names: List[str], + expected_column_types: List[Type[Union[float, int, str]]], + ) -> None: + self.assertSetEqual(set(sample.keys()), set(expected_column_names)) + ordered_vals = [sample[column_name] for column_name in expected_column_names] + for val, expected_type in zip(ordered_vals, expected_column_types): + self.assertTrue(isinstance(val, expected_type)) + + # pyre-ignore[24] + def _test_ratings(self, dataset_fn: Callable) -> None: + ratings_row_count = 100 + with self._create_root(ratings_row_count) as tmpdir: + dataset = dataset_fn(tmpdir) + for sample in dataset: + self._validate_sample( + sample, + self.DEFAULT_RATINGS_COLUMN_NAMES, + self.DEFAULT_RATINGS_COLUMN_TYPES, + ) + self.assertEqual(len(list(dataset)), ratings_row_count) + + # pyre-ignore[24] + def _test_ratings_movies(self, dataset_fn: Callable) -> None: + ratings_row_count = 200 + with self._create_root(ratings_row_count) as tmpdir: + dataset = dataset_fn(tmpdir, include_movies_data=True) + for sample in dataset: + self._validate_sample( + sample, + self.DEFAULT_RATINGS_COLUMN_NAMES + + self.DEFAULT_MOVIES_COLUMN_NAMES[1:], + self.DEFAULT_RATINGS_COLUMN_TYPES + + self.DEFAULT_MOVIES_COLUMN_TYPES[1:], + ) + self.assertEqual(len(list(dataset)), ratings_row_count) + + def test_20m_ratings(self) -> None: + self._test_ratings(movielens_20m) + + def test_25m_ratings(self) -> None: + self._test_ratings(movielens_25m) + + def test_20m_ratings_movies(self) -> None: + self._test_ratings_movies(movielens_20m) + + def test_25m_ratings_movies(self) -> None: + self._test_ratings_movies(movielens_25m) diff --git a/torchrec/datasets/tests/test_utils.py b/torchrec/datasets/tests/test_utils.py new file mode 100644 index 000000000..dcbad1131 --- /dev/null +++ b/torchrec/datasets/tests/test_utils.py @@ -0,0 +1,133 @@ +#!/usr/bin/env python3 + +import random +import unittest +from typing import Any, Iterator, List, Tuple +from unittest.mock import Mock, patch + +from torch.utils.data import IterDataPipe +from torchrec.datasets.utils import ( + idx_split_train_val, + rand_split_train_val, + ParallelReadConcat, +) + + +class _DummyDataReader(IterDataPipe): + def __init__(self, num_rows: int, val: str = "") -> None: + self.num_rows = num_rows + self.val = val + + def __iter__(self) -> Iterator[Tuple[int, str]]: + for idx in range(self.num_rows): + yield idx, self.val + + +class TestLimit(unittest.TestCase): + def test(self) -> None: + datapipe = _DummyDataReader(100).limit(10) + self.assertEqual(len(list(datapipe)), 10) + + +class TestIdxSplitTrainVal(unittest.TestCase): + def test_even_split(self) -> None: + datapipe = _DummyDataReader(int(1000)) + train_datapipe, val_datapipe = idx_split_train_val(datapipe, 0.5) + self.assertEqual(len(list(train_datapipe)), 500) + self.assertEqual(len(list(val_datapipe)), 500) + + def test_uneven_split(self) -> None: + datapipe = _DummyDataReader(int(100000)) + train_datapipe, val_datapipe = idx_split_train_val(datapipe, 0.6) + self.assertEqual(len(list(train_datapipe)), 100000 * 0.6) + self.assertEqual(len(list(val_datapipe)), 100000 * 0.4) + + def test_invalid_train_perc(self) -> None: + datapipe = _DummyDataReader(123) + with self.assertRaisesRegex(ValueError, "train_perc"): + train_datapipe, val_datapipe = idx_split_train_val(datapipe, 0.0) + with self.assertRaisesRegex(ValueError, "train_perc"): + train_datapipe, val_datapipe = idx_split_train_val(datapipe, 1.0) + with self.assertRaisesRegex(ValueError, "train_perc"): + train_datapipe, val_datapipe = idx_split_train_val(datapipe, 10.2) + with self.assertRaisesRegex(ValueError, "train_perc"): + train_datapipe, val_datapipe = idx_split_train_val(datapipe, -50.15) + + +class _FakeRandom(random.Random): + def __init__(self, num_vals: int) -> None: + super().__init__() + self.num_vals = num_vals + self.vals: List[float] = [val / num_vals for val in range(num_vals)] + self.current_idx = 0 + + def random(self) -> float: + val = self.vals[self.current_idx] + self.current_idx += 1 + return val + + # pyre-ignore[3] + def getstate(self) -> Tuple[Any, ...]: + return (self.vals, self.current_idx) + + # pyre-ignore[2] + def setstate(self, state: Tuple[Any, ...]) -> None: + self.vals, self.current_idx = state + + +class TestRandSplitTrainVal(unittest.TestCase): + def test_deterministic_split(self) -> None: + num_vals = 1000 + datapipe = _DummyDataReader(num_vals) + with patch("random.Random", new=lambda a: _FakeRandom(num_vals)): + train_datapipe, val_datapipe = rand_split_train_val(datapipe, 0.8) + self.assertEqual(len(list(train_datapipe)), num_vals * 0.8) + self.assertEqual(len(list(val_datapipe)), num_vals * 0.2) + self.assertEqual( + len(set(train_datapipe).intersection(set(val_datapipe))), 0 + ) + + def test_rand_split(self) -> None: + datapipe = _DummyDataReader(100000) + train_datapipe, val_datapipe = rand_split_train_val(datapipe, 0.7) + self.assertEqual(len(set(train_datapipe).intersection(set(val_datapipe))), 0) + + def test_invalid_train_perc(self) -> None: + datapipe = _DummyDataReader(123) + with self.assertRaisesRegex(ValueError, "train_perc"): + train_datapipe, val_datapipe = rand_split_train_val(datapipe, 0.0) + with self.assertRaisesRegex(ValueError, "train_perc"): + train_datapipe, val_datapipe = rand_split_train_val(datapipe, 1.0) + with self.assertRaisesRegex(ValueError, "train_perc"): + train_datapipe, val_datapipe = rand_split_train_val(datapipe, 10.2) + with self.assertRaisesRegex(ValueError, "train_perc"): + train_datapipe, val_datapipe = rand_split_train_val(datapipe, -50.15) + + +class TestParallelReadConcat(unittest.TestCase): + def test_worker_assignment(self) -> None: + datapipes = [_DummyDataReader(1000, str(idx)) for idx in range(10)] + all_res = [] + num_workers = 4 + for idx in range(num_workers): + with patch("torchrec.datasets.utils.get_worker_info") as get_worker_info: + get_worker_info.return_value = Mock(id=idx, num_workers=num_workers) + all_res += list(ParallelReadConcat(*datapipes)) + expected_res = [] + for dp in datapipes: + expected_res += list(dp) + self.assertEqual(all_res, expected_res) + + def test_no_workers(self) -> None: + datapipes = [_DummyDataReader(1000, str(idx)) for idx in range(10)] + with patch("torchrec.datasets.utils.get_worker_info") as get_worker_info: + get_worker_info.return_value = None + dp = ParallelReadConcat(*datapipes) + self.assertEqual(len(list(dp)), 10000) + + def test_more_workers_than_dps(self) -> None: + datapipes = [_DummyDataReader(1000, str(idx)) for idx in range(2)] + with patch("torchrec.datasets.utils.get_worker_info") as get_worker_info: + get_worker_info.return_value = Mock(id=2, num_workers=10) + with self.assertRaises(ValueError): + next(iter(ParallelReadConcat(*datapipes))) diff --git a/torchrec/datasets/utils.py b/torchrec/datasets/utils.py new file mode 100644 index 000000000..00a07723f --- /dev/null +++ b/torchrec/datasets/utils.py @@ -0,0 +1,340 @@ +#!/usr/bin/env python3 + +import csv +import math +import random +from dataclasses import dataclass +from functools import partial +from io import IOBase +from typing import ( + Any, + Callable, + Iterable, + Iterator, + List, + Sequence, + Tuple, + TypeVar, +) + +import torch +from iopath.common.file_io import PathManager, PathManagerFactory +from torch.utils.data import IterDataPipe, functional_datapipe, get_worker_info +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor +from torchrec.types import Pipelineable + +PATH_MANAGER_KEY = "torchrec" + + +@dataclass +class Batch(Pipelineable): + dense_features: torch.Tensor + sparse_features: KeyedJaggedTensor + labels: torch.Tensor + + def to(self, device: torch.device, non_blocking: bool = False) -> "Batch": + return Batch( + dense_features=self.dense_features.to( + device=device, non_blocking=non_blocking + ), + sparse_features=self.sparse_features.to( + device=device, non_blocking=non_blocking + ), + labels=self.labels.to(device=device, non_blocking=non_blocking), + ) + + def record_stream(self, stream: torch.cuda.streams.Stream) -> None: + self.dense_features.record_stream(stream) + self.sparse_features.record_stream(stream) + self.labels.record_stream(stream) + + def pin_memory(self) -> "Batch": + return Batch( + dense_features=self.dense_features.pin_memory(), + sparse_features=self.sparse_features.pin_memory(), + labels=self.labels.pin_memory(), + ) + + +class _IdxFilter(IterDataPipe): + def __init__( + self, datapipe: IterDataPipe, filter_fn: Callable[[int], bool] + ) -> None: + super().__init__() + self.datapipe = datapipe + self.filter_fn = filter_fn + + # pyre-ignore[3] + def __iter__(self) -> Iterator[Any]: + for idx, data in enumerate(self.datapipe): + if self.filter_fn(idx): + yield data + + +def _default_key_fn(idx: int) -> int: + return idx + + +def train_filter( + key_fn: Callable[[int], int], + train_perc: float, + decimal_places: int, + idx: int, +) -> bool: + return (key_fn(idx) % 10 ** decimal_places) < round( + train_perc * 10 ** decimal_places + ) + + +def val_filter( + key_fn: Callable[[int], int], + train_perc: float, + decimal_places: int, + idx: int, +) -> bool: + return not train_filter(key_fn, train_perc, decimal_places, idx) + + +def idx_split_train_val( + datapipe: IterDataPipe, + train_perc: float, + decimal_places: int = 2, + key_fn: Callable[[int], int] = _default_key_fn, +) -> Tuple[IterDataPipe, IterDataPipe]: + if not 0.0 < train_perc < 1.0: + raise ValueError("train_perc must be in range (0.0, 1.0)") + return ( + _IdxFilter(datapipe, partial(train_filter, key_fn, train_perc, decimal_places)), + _IdxFilter(datapipe, partial(val_filter, key_fn, train_perc, decimal_places)), + ) + + +class _RandFilter(IterDataPipe): + def __init__( + self, + datapipe: IterDataPipe, + filter_fn: Callable[[random.Random], bool], + rand_gen: random.Random, + ) -> None: + super().__init__() + self.datapipe = datapipe + self.filter_fn = filter_fn + self.rand_gen = rand_gen + # pyre-ignore[4] + self.rand_gen_init_state: Tuple[Any, ...] = rand_gen.getstate() + + # pyre-ignore[3] + def __iter__(self) -> Iterator[Any]: + self.rand_gen.setstate(self.rand_gen_init_state) + for data in self.datapipe: + if self.filter_fn(self.rand_gen): + yield data + + +def _rand_train_filter_fn( + train_perc: float, + rand_gen: random.Random, +) -> bool: + return rand_gen.random() < train_perc + + +def _rand_val_filter_fn(train_perc: float, rand_gen: random.Random) -> bool: + return not _rand_train_filter_fn(train_perc, rand_gen) + + +def rand_split_train_val( + datapipe: IterDataPipe, + train_perc: float, + random_seed: int = 0, +) -> Tuple[IterDataPipe, IterDataPipe]: + """Via uniform random sampling, generates two IterDataPipe instances representing + disjoint train and val splits of the given IterDataPipe. + Args: + datapipe (IterDataPipe): datapipe to split. + train_perc (float): value in range (0.0, 1.0) specifying target proportion of + datapipe samples to include in train split. Note that the actual proportion + is not guaranteed to match train_perc exactly. + random_seed (int): determines split membership for a given sample + and train_perc. Use the same value across calls to generate consistent splits. + Example: + >>> datapipe = criteo_terabyte( + >>> ("/home/datasets/criteo/day_0.tsv", "/home/datasets/criteo/day_1.tsv") + >>> ) + >>> train_datapipe, val_datapipe = rand_split_train_val(datapipe, 0.75) + >>> train_batch = next(iter(train_datapipe)) + >>> val_batch = next(iter(val_datapipe)) + """ + if not 0.0 < train_perc < 1.0: + raise ValueError("train_perc must be in range (0.0, 1.0)") + + return _RandFilter( + datapipe, partial(_rand_train_filter_fn, train_perc), random.Random(random_seed) + ), _RandFilter( + datapipe, partial(_rand_val_filter_fn, train_perc), random.Random(random_seed) + ) + + +T = TypeVar("T") + + +def safe_cast(val: T, dest_type: Callable[[T], T], default: T) -> T: + try: + return dest_type(val) + except ValueError: + return default + + +@functional_datapipe("limit") +class Limit(IterDataPipe): + def __init__(self, datapipe: IterDataPipe, limit: int) -> None: + super().__init__() + self.datapipe = datapipe + self.limit = limit + + # pyre-ignore[3] + def __iter__(self) -> Iterator[Any]: + for idx, data in enumerate(self.datapipe): + if idx >= self.limit: + break + yield data + + +class ReadLinesFromCSV(IterDataPipe): + def __init__( + self, + datapipe: IterDataPipe[Tuple[str, "IOBase"]], + skip_first_line: bool = False, + # pyre-ignore[2] + **kw, + ) -> None: + super().__init__() + self.datapipe = datapipe + self.skip_first_line = skip_first_line + # pyre-ignore[4] + self.kw = kw + + def __iter__(self) -> Iterator[List[str]]: + for _, data in self.datapipe: + reader = csv.reader(data, **self.kw) + if self.skip_first_line: + next(reader, None) + for line in reader: + yield line + + +class LoadFiles(IterDataPipe[Tuple[str, "IOBase"]]): + """ + Taken and adapted from torch.utils.data.datapipes.iter.LoadFilesFromDisk + + TODO: + Merge this back or replace this with something in core Datapipes lib + """ + + def __init__( + self, + datapipe: Iterable[str], + mode: str = "b", + length: int = -1, + path_manager_key: str = PATH_MANAGER_KEY, + # pyre-ignore[2] + **open_kw, + ) -> None: + super().__init__() + self.datapipe: Iterable[str] = datapipe + self.mode: str = mode + if self.mode not in ("b", "t", "rb", "rt", "r"): + raise ValueError("Invalid mode {}".format(mode)) + # TODO: enforce typing for each instance based on mode, otherwise + # `argument_validation` with this DataPipe may be potentially broken + self.length: int = length + # pyre-ignore[4] + self.open_kw = open_kw + self.path_manager: PathManager = PathManagerFactory().get(path_manager_key) + self.path_manager.set_strict_kwargs_checking(False) + + # Remove annotation due to 'IOBase' is a general type and true type + # is determined at runtime based on mode. Some `DataPipe` requiring + # a subtype would cause mypy error. + # pyre-ignore[3] + def __iter__(self): + if self.mode in ("b", "t"): + self.mode = "r" + self.mode + for pathname in self.datapipe: + if not isinstance(pathname, str): + raise TypeError( + "Expected string type for pathname, but got {}".format( + type(pathname) + ) + ) + yield ( + pathname, + self.path_manager.open(pathname, self.mode, **self.open_kw), + ) + + def __len__(self) -> int: + if self.length == -1: + raise NotImplementedError + return self.length + + +def _default_dp_selector( + datapipes: Sequence[IterDataPipe], +) -> Sequence[IterDataPipe]: + worker_info = get_worker_info() + if worker_info is None: + return datapipes + else: + if worker_info.num_workers > len(datapipes): + raise ValueError( + f"Number of workers {worker_info.num_workers} exceeds" + f"number of datapipes ({len(datapipes)})!" + ) + offsets = [0] + for num_workers in reversed(range(1, worker_info.num_workers + 1)): + remaining_dps = len(datapipes) - offsets[-1] + dps_to_assign = math.ceil(remaining_dps / num_workers) + offsets.append(offsets[-1] + dps_to_assign) + return datapipes[offsets[worker_info.id] : offsets[worker_info.id + 1]] + + +class ParallelReadConcat(IterDataPipe): + r""":class:`ParallelReadConcat`. + + Iterable DataPipe that concatenates multiple Iterable DataPipes. + When used with a DataLoader, assigns a subset of datapipes to each DataLoader worker + to allow for parallel reading. + Args: + datapipes: IterDataPipe instances to read from. + dp_selector: function that each DataLoader worker would use to determine the subset of datapipes + to read from. + Example: + >>> datapipes = [ + >>> criteo_terabyte( + >>> (f"/home/local/datasets/criteo/shard_{idx}.tsv",), + >>> ) + >>> .batch(100) + >>> .collate() + >>> for idx in range(4) + >>> ] + >>> dataloader = DataLoader( + >>> ParallelReadConcat(*datapipes), num_workers=4, batch_size=None + >>> ) + """ + + def __init__( + self, + *datapipes: IterDataPipe, + dp_selector: Callable[ + [Sequence[IterDataPipe]], Sequence[IterDataPipe] + ] = _default_dp_selector, + ) -> None: + super().__init__() + self.datapipes: Tuple[IterDataPipe, ...] = datapipes + self.dp_selector = dp_selector + + # pyre-ignore[3] + def __iter__(self) -> Iterator[Any]: + selected_dps = self.dp_selector(self.datapipes) + for dp in selected_dps: + for data in dp: + yield data diff --git a/torchrec/distributed/__init__.py b/torchrec/distributed/__init__.py new file mode 100644 index 000000000..19c62cb55 --- /dev/null +++ b/torchrec/distributed/__init__.py @@ -0,0 +1,22 @@ +#!/usr/bin/env python3 + +from torchrec.distributed.model_parallel import DistributedModelParallel # noqa +from torchrec.distributed.train_pipeline import ( # noqa + TrainPipeline, + TrainPipelineBase, + TrainPipelineSparseDist, +) +from torchrec.distributed.types import ( # noqa + Awaitable, + NoWait, + ParameterSharding, + ModuleSharder, + ShardingPlanner, + ShardedModule, + ShardedTensor, + ShardingEnv, +) +from torchrec.distributed.utils import ( # noqa + get_unsharded_module_names, + sharded_model_copy, +) diff --git a/torchrec/distributed/collective_utils.py b/torchrec/distributed/collective_utils.py new file mode 100644 index 000000000..a37999b3d --- /dev/null +++ b/torchrec/distributed/collective_utils.py @@ -0,0 +1,66 @@ +#!/usr/bin/env python3 +""" +This file contains utilities for constructing collective based control flows. +""" + +from functools import wraps +from typing import Optional, Callable, TypeVar, cast, Any + +import torch.distributed as dist + + +def is_leader(pg: Optional[dist.ProcessGroup], leader_rank: int = 0) -> bool: + """ + Check if the current processs is the leader. + + Args: + - pg: the process's rank within the pg is used to determine if + the process is the leader. pg being None implies that the process + is the only member in the group (e.g. a single process program). + + - leader_rank: the definition of leader (defaults to 0). The caller can + override it with a context-specific definition. + """ + if pg is None: + return leader_rank == 0 + return pg.rank() == leader_rank + + +T = TypeVar("T") + + +def invoke_on_rank_and_broadcast_result( + pg: dist.ProcessGroup, + rank: int, + func: Callable[..., T], + *args: Any, + **kwargs: Any, +) -> T: + """ + Invokes a function on the designated rank and broadcasts the result to all + members within the group. + + Example usage: + + >>> id = invoke_on_rank_and_broadcast_result(pg, 0, allocate_id) + """ + if pg.rank() == rank: + res = func(*args, **kwargs) + object_list = [res] + else: + object_list = [None] + if pg.size() > 1: + dist.broadcast_object_list(object_list, rank, group=pg) + return cast(T, object_list[0]) + + +# pyre-ignore Missing return annotation [3] +def run_on_leader(pg: dist.ProcessGroup, rank: int): + def callable(func: Callable[..., T]) -> T: + @wraps(func) + def wrapped(*args: Any, **kwargs: Any) -> T: + return invoke_on_rank_and_broadcast_result(pg, rank, func, *args, **kwargs) + + return wrapped + + return callable diff --git a/torchrec/distributed/comm.py b/torchrec/distributed/comm.py new file mode 100644 index 000000000..fd60c53a7 --- /dev/null +++ b/torchrec/distributed/comm.py @@ -0,0 +1,136 @@ +#!/usr/bin/env python3 + +import logging +import os +from typing import List, Tuple, Optional + +import torch +import torch.distributed as dist + +logger: logging.Logger = logging.getLogger(__name__) + +# Global, only should be accessed via intra_and_cross_node_pg() +_INTRA_PG: Optional[dist.ProcessGroup] = None +_CROSS_PG: Optional[dist.ProcessGroup] = None + + +def _env2int(env_list: List[str], default: int = -1) -> int: + for e in env_list: + val = int(os.environ.get(e, -1)) + if val >= 0: + return val + return default + + +def get_local_size(world_size: int) -> int: + """ + Get the local world size (see https://pytorch.org/docs/stable/elastic/run.html) + This is usually the size of workers on each node, or nproc_per_node + """ + local_size = _env2int( + [ + "LOCAL_WORLD_SIZE", + "MPI_LOCALNRANKS", + "OMPI_COMM_WORLD_LOCAL_SIZE", + "MV2_COMM_WORLD_LOCAL_SIZE", + ], + 8, + ) + + if local_size == -1 or world_size % local_size != 0: + logging.warning( + "Could not determine LOCAL_WORLD_SIZE from environment, falling back to WORLD_SIZE." + ) + local_size = world_size + return local_size + + +def get_local_rank(world_size: int, rank: int) -> int: + """ + Get the local rank of the local processes (see https://pytorch.org/docs/stable/elastic/run.html) + This is usually the rank of the worker on its node + """ + my_local_rank = _env2int( + [ + "LOCAL_RANK", + "MPI_LOCALRANKID", + "OMPI_COMM_WORLD_LOCAL_RANK", + "MV2_COMM_WORLD_LOCAL_RANK", + ], + -1, + ) + local_size = get_local_size(world_size) + + if my_local_rank == -1 or my_local_rank >= local_size: + logging.warning( + "Could not determine LOCAL_RANK from environment, falling back to GLOBAL_RANK % LOCAL_SIZE." + ) + my_local_rank = rank % local_size + return my_local_rank + + +def get_group_rank(world_size: int, rank: int) -> int: + """ + Get the group rank of the worker group. Also available with GROUP_RANK environment varible + A number between 0 and get_num_groups() (See https://pytorch.org/docs/stable/elastic/run.html) + """ + return rank // get_local_size(world_size) + + +def get_num_groups(world_size: int) -> int: + """ + Get the number of worker groups. + Usually equivalent to max_nnodes (See https://pytorch.org/docs/stable/elastic/run.html) + """ + return world_size // get_local_size(world_size) + + +def intra_and_cross_node_pg( + device: Optional[torch.device] = None, + backend: str = "nccl", +) -> Tuple[Optional[dist.ProcessGroup], Optional[dist.ProcessGroup]]: + """ + This function creates sub process groups (intra and cross node) + """ + if device is not None and device.type == "meta": + return None, None + + global _INTRA_PG # intra node process group + global _CROSS_PG # cross node process group + + my_size = dist.get_world_size() + my_rank = dist.get_rank() + my_local_rank = get_local_rank(my_size, my_rank) + local_size = get_local_size(my_size) + my_group_rank = get_group_rank(my_size, my_rank) + group_count = get_num_groups(my_size) + + logger.info( + f"[{my_rank}] my_local_rank = {my_local_rank}, local_size = {local_size}," + f"my_group_rank = {my_group_rank}, group_count = {group_count}" + ) + if _INTRA_PG is None: + for group_rank in range(group_count): + peers = [group_rank * local_size + r for r in range(local_size)] + curr_intra_group_pg = dist.new_group(backend=backend, ranks=peers) + if my_group_rank == group_rank: + logger.warning( + "[Connection] intra_group: [%d] -> [%s]" % (my_rank, peers) + ) + _INTRA_PG = curr_intra_group_pg + + dist.barrier() + + if _CROSS_PG is None: + for l_rank in range(local_size): + peers = [l_rank + g * local_size for g in range(group_count)] + curr_cross_group_pg = dist.new_group(backend=backend, ranks=peers) + if l_rank == my_local_rank: + logger.warning( + "[Connection] cross_group: [%d] -> [%s]" % (my_rank, peers) + ) + _CROSS_PG = curr_cross_group_pg + + dist.barrier() + + return _INTRA_PG, _CROSS_PG diff --git a/torchrec/distributed/comm_ops.py b/torchrec/distributed/comm_ops.py new file mode 100644 index 000000000..2e967d46b --- /dev/null +++ b/torchrec/distributed/comm_ops.py @@ -0,0 +1,892 @@ +#!/usr/bin/env python3 + +from dataclasses import dataclass, field +from typing import List, Optional, Tuple, TypeVar, Any + +import torch +import torch.distributed as dist +from torch import Tensor +from torch.autograd import Function +from torch.autograd.profiler import record_function +from torchrec.distributed.types import Awaitable, NoWait + +try: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") +except OSError: + pass + + +W = TypeVar("W") + +# TODO: T96382816, NE Parity Backward compatibility +GRADIENT_DIVISION: bool = True + + +def set_gradient_division(val: bool) -> None: + global GRADIENT_DIVISION + GRADIENT_DIVISION = val + + +# Some commonly used notations for comm ops: +# B - batch size +# T - number of embedding tables +# D - embedding dimension + + +class Request(Awaitable[W]): + def __init__(self, pg: dist.ProcessGroup) -> None: + super().__init__() + self.pg: dist.ProcessGroup = pg + # pyre-fixme[11]: Annotation dist.Work is not defined as a type. + self.req: Optional[dist.Work] = None + self.tensor: Optional[W] = None + self.a2ai = None # type: ignore + self.rsi = None # type: ignore + self.wait_function = None # type: ignore + + def wait(self) -> W: + ret = self.wait_function.apply(self.pg, self, self.tensor) + self.req = None + self.tensor = None + return ret + + +@dataclass +class All2AllPooledInfo(object): + """The data class that collects the attributes when calling the + alltoall_pooled operation. + + Attributes: + dim_sum_per_rank (list[Int]): number of features (sum of dimensions) + of the embedding in each rank + dim_sum_per_rank_tensor: (tensor, optional): the tensor version of + `dim_sum_per_rank`, this is only used by the fast kernel of + `_recat_pooled_embedding_grad_out` + cumsum_dim_sum_per_rank_tensor (tensor, optional): cumulative sum of + dim_sum_per_rank, this is only used by the fast kernel of + `_recat_pooled_embedding_grad_out` + mixed_dim: (bool): the flag whether the input is mixed + dimensioned or not. + D : (int, optional): embedding dimension of the embedding table + B_local: (int, optional): local batch size before scattering + """ + + dim_sum_per_rank: List[int] + dim_sum_per_rank_tensor: Optional[Tensor] + cumsum_dim_sum_per_rank_tensor: Optional[Tensor] + mixed_dim: bool + D: int = -1 # -1 means doesn't use + B_local: int = -1 + + +@dataclass +class All2AllSequenceInfo(object): + """The data class that collects the attributes when calling the + alltoall_sequence operation. + + Attributes: + embedding_dim (int): embedding dimension + lengths_after_sparse_data_all2all (tensor): lengths of sparse features after all2all + forward_recat_tensor (tensor): recat tensor for forward + backward_recat_tensor (tensor): recat tensor for backward + input_splits (tensor): input splits + output_splits (tensor): output splits + lengths_sparse_before_features_all2all (tensor): lengths of sparse features before all2all + """ + + embedding_dim: int + lengths_after_sparse_data_all2all: Tensor + forward_recat_tensor: Tensor + backward_recat_tensor: Tensor + input_splits: List[int] + output_splits: List[int] + permuted_lengths_after_sparse_data_all2all: Optional[Tensor] = None + + +@dataclass +class All2AllVInfo(object): + """The data class that collects the attributes when calling the + alltoallv operation. + + Attributes: + dim_sum_per_rank (list[Int]): number of features (sum of dimensions) + of the embedding in each rank + B_global: (int): The global batch size for each rank + B_local: (int, optional): local batch size before scattering + B_local_list: (List[int]): local batch sizes for each embedding table + locally (in my current rank) + D_local_list : (List[Int]): embedding dimension of each embedding table + locally (in my current rank) + input_split_sizes (list[Int]): The input split sizes for each rank, this + remembers how to split the input when doing the all_to_all_single operation + output_split_sizes (list[Int]): The output split sizes for each rank, this + remembers how to fill the output when doing the all_to_all_single operation + """ + + dims_sum_per_rank: List[int] + B_global: int + B_local: int + B_local_list: List[int] + D_local_list: List[int] + input_split_sizes: List[int] = field(default_factory=list) + output_split_sizes: List[int] = field(default_factory=list) + + +@dataclass +class ReduceScatterInfo(object): + """The data class that collects the attributes when calling the + reduce_scatter_pooled operation. + + Attributes: + input_sizes (List[int]) : the sizes of the input tensors, this remembers + the sizes of the input tensors when running the backward pass and + producing the gradient + """ + + input_sizes: List[int] + + +def _get_split_lengths_by_len( + world_size: int, my_rank: int, n: int +) -> Tuple[int, List[int]]: + k, m = divmod(n, world_size) + if m == 0: + splits = [k] * world_size + my_len = k + else: + splits = [(k + 1) if i < m else k for i in range(world_size)] + my_len = splits[my_rank] + return (my_len, splits) + + +def alltoall_pooled( + a2a_pooled_embs_tensor: Tensor, + dim_sum_per_rank: List[int], + mixed_dim: bool = False, + dim_sum_per_rank_tensor: Optional[Tensor] = None, + cumsum_dim_sum_per_rank_tensor: Optional[Tensor] = None, + group: Optional[dist.ProcessGroup] = None, +) -> Awaitable[Tensor]: + """ + Perform alltoall operation for a single pooled embedding tensor. Each process splits + the input pooled embeddings tensor based on the world size, and then scatters the + split list to all processes in the group. Then concatenate the received tensors from + all the processes in the group and return single output tensor. + + Args: + a2a_pooled_embs_tensor (tensor): input pooled embeddings. Must be pooled + together before passing into this function. Usually with the shape of + B x T x D, where B - batch size, T - number of embedding tables, + D - embedding dimension. When `mixed_dim=True`, the input shape should + be B x D_local_sum, where D_local_sum is the dimension sum of all the + local embedding tables. + dim_sum_per_rank (list[Int]): number of features (sum of dimensions) + of the embedding in each rank + mixed_dim: (bool, optional): the flag whether the input is mixed + dimensioned or not. + dim_sum_per_rank_tensor: (tensor, optional): the tensor version of + `dim_sum_per_rank`, this is only used by the fast kernel of + `_recat_pooled_embedding_grad_out` + cumsum_dim_sum_per_rank_tensor (tensor, optional): cumulative sum of + dim_sum_per_rank, this is only used by the fast kernel of + `_recat_pooled_embedding_grad_out` + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + + Returns: + Async work handle (Awaitable), which can be `wait()` later to get the result + tensor. + + .. warning:: + `alltoall_pooled` is experimental and subject to change. + + """ + if group is None: + group = dist.distributed_c10d._get_default_group() + + if dist.get_world_size(group) <= 1: + return NoWait(a2a_pooled_embs_tensor) + + myreq = Request(group) + a2ai = All2AllPooledInfo( + dim_sum_per_rank=dim_sum_per_rank, + dim_sum_per_rank_tensor=dim_sum_per_rank_tensor, + cumsum_dim_sum_per_rank_tensor=cumsum_dim_sum_per_rank_tensor, + mixed_dim=mixed_dim, + ) + # pyre-fixme[16]: `All2All_Pooled_Req` has no attribute `apply`. + All2All_Pooled_Req.apply(group, myreq, a2ai, a2a_pooled_embs_tensor) + return myreq + + +# a2a operator for (T * B * L_i, D) tensors +# not support mixed dimensions +def alltoall_sequence( + # (T, B, L_i * D) flattened + a2a_sequence_embs_tensor: Tensor, + forward_recat_tensor: Tensor, + backward_recat_tensor: Tensor, + lengths_after_sparse_data_all2all: Tensor, + input_splits: List[int], + output_splits: List[int], + group: Optional[dist.ProcessGroup] = None, +) -> Awaitable[Tensor]: + """ + Perform alltoall operation for sequence embeddings. Each process splits input + tensor base on the world size, and then scatters the split list to all processes + in a group. Then concatenate the received tensors from all the processes in the + group and return single output tensor. + + Args: + a2a_sequence_embs_tensor (tensor): input embeddings. Usually with the shape + of (T * B * L_i, D), where B - batch size, T - number of embedding tables, + D - embedding dimension. + embedding_dim: embedding dimension + lengths_after_sparse_data_all2all (tensor): lengths of sparse features after all2all + forward_recat_tensor (tensor): recat tensor for forward + backward_recat_tensor (tensor): recat tensor for backward + input_splits (tensor): input splits + output_splits (tensor): output splits + + Returns: + Async work handle (Awaitable), which can be `wait()` later to get the result + tensor. + + .. warning:: + `alltoall_sequence` is experimental and subject to change. + + """ + if group is None: + group = dist.distributed_c10d._get_default_group() + + if dist.get_world_size(group) <= 1: + return NoWait(a2a_sequence_embs_tensor) + + myreq = Request(group) + a2ai = All2AllSequenceInfo( + embedding_dim=a2a_sequence_embs_tensor.shape[1], + lengths_after_sparse_data_all2all=lengths_after_sparse_data_all2all, + forward_recat_tensor=forward_recat_tensor, + backward_recat_tensor=backward_recat_tensor, + input_splits=input_splits, + output_splits=output_splits, + ) + # sequence of embeddings, bags are definitely non-uniform + + # pyre-fixme[16]: `All2All_Seq_Req` has no attribute `apply`. + All2All_Seq_Req.apply(group, myreq, a2ai, a2a_sequence_embs_tensor) + return myreq + + +def alltoallv( + inputs: List[Tensor], + out_split: Optional[List[int]] = None, + per_rank_split_lengths: Optional[List[int]] = None, + group: Optional[dist.ProcessGroup] = None, +) -> Awaitable[List[Tensor]]: + """ + Perform alltoallv operation for a list of input embeddings. Each process scatters + the list to all processes in a group. + + Args: + input (list[tensor]): List of tensors to scatter one per rank. The tensors + in the list usually have different lengths. If output_split + out_split (list[Int], optional): output split sizes (or dims sum per rank), + if not specified, we will use `per_rank_split_lengths` to construct a output + split with the assumption that all the embs have the same dimension. + per_rank_split_lengths: (list[Int], optional): split lengths per rank. + If not specified, the `output_split` must be specified. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + Returns: + Async work handle (Awaitable), which can be `wait()` later to get the result + list of tensor. + + .. warning:: + `alltoallv` is experimental and subject to change. + + """ + + if group is None: + group = dist.distributed_c10d._get_default_group() + + world_size = dist.get_world_size(group) + my_rank = dist.get_rank(group) + + myreq = Request(group) + B_global, _ = inputs[0].size() + D_local_list = [e.size()[1] for e in inputs] + B_local, B_local_list = _get_split_lengths_by_len(world_size, my_rank, B_global) + + if out_split is not None: + dims_sum_per_rank = out_split + elif per_rank_split_lengths is not None: + # all the embs have the same dimension + dims_sum_per_rank = [s * D_local_list[0] for s in per_rank_split_lengths] + else: + raise RuntimeError("Need to specify either out_split or per_rank_split_lengths") + + a2ai = All2AllVInfo( + dims_sum_per_rank=dims_sum_per_rank, + B_local=B_local, + B_local_list=B_local_list, + D_local_list=D_local_list, + B_global=B_global, + ) + + # pyre-fixme[16]: `All2Allv_Req` has no attribute `apply`. + All2Allv_Req.apply(group, myreq, a2ai, inputs) + + return myreq + + +def reduce_scatter_pooled( + inputs: List[Tensor], + group: Optional[dist.ProcessGroup] = None, +) -> Awaitable[Tensor]: + if group is None: + group = dist.distributed_c10d._get_default_group() + + if dist.get_world_size(group) <= 1: + return NoWait(inputs[dist.get_rank(group)]) + + myreq = Request(group) + rsi = ReduceScatterInfo(input_sizes=[tensor.size() for tensor in inputs]) + # pyre-fixme[16]: `ReduceScatter_Req` has no attribute `apply`. + ReduceScatter_Req.apply(group, myreq, rsi, *inputs) + return myreq + + +# TODO: improve performance of _recat_pooled_embedding_grad_out and +# recat_pooled_embedding_mixed_dim_grad_out, see T87591139 +def _recat_pooled_embedding_grad_out( + grad_output: Tensor, num_features_per_rank: List[int] +) -> Tensor: + + """ + TODO: improve performance of _recat_pooled_embedding_grad_out in an + efficient fashion (the .contiguous() calls are extremely expensive). + see T87591139 + """ + grad_outputs_by_rank = grad_output.split(num_features_per_rank, dim=1) + return torch.cat( + [ + grad_output_by_rank.contiguous().view(-1) + for grad_output_by_rank in grad_outputs_by_rank + ], + dim=0, + ) + + +def _recat_seq_embedding( + input_embeddings: Tensor, + split_sizes: List[int], + T_local: int, + my_size: int, + forward: bool, +) -> Tensor: + seq_embeddings_by_rank = input_embeddings.split(split_sizes) + if forward: + return torch.cat( + [ + seq_embeddings_by_rank[t * my_size + i] + # .contiguous().view(-1) + for i in range(my_size) + for t in range(T_local) + ], + dim=0, + ) + else: + return torch.cat( + [ + seq_embeddings_by_rank[i * T_local + t] + # .contiguous() + # .view(-1) + for t in range(T_local) + for i in range(my_size) + ], + dim=0, + ) + + +class All2All_Pooled_Req(Function): + @staticmethod + # pyre-fixme[14]: `forward` overrides method defined in `Function` inconsistently. + def forward( + # pyre-fixme[2]: Parameter must be annotated. + ctx, + pg: dist.ProcessGroup, + myreq: Request[Tensor], + a2ai: All2AllPooledInfo, + input_embeddings: Tensor, + ) -> Tensor: + world_size = dist.get_world_size(pg) + if a2ai.mixed_dim: + (B_global, D_local_sum) = input_embeddings.shape + else: + (B_global, T_local, D) = input_embeddings.shape + D_local_sum = T_local * D + a2ai.D = D + dim_sum_per_rank = a2ai.dim_sum_per_rank + B_local = B_global // world_size + a2ai.B_local = B_local + assert ( + B_global % world_size == 0 + ), f"num of ranks {world_size} doesn't divide global batch size {B_global}" + + sharded_input_embeddings = input_embeddings.view( + world_size, B_local, D_local_sum + ) + D_global_sum = sum(dim_sum_per_rank) + sharded_output_embeddings = torch.empty( + B_local * D_global_sum, + dtype=input_embeddings.dtype, + device=input_embeddings.device, + ) + with record_function("## alltoall_fwd_single ##"): + req = dist.all_to_all_single( + output=sharded_output_embeddings, + input=sharded_input_embeddings, + output_split_sizes=[ + B_local * D_rank_sum for D_rank_sum in dim_sum_per_rank + ], + input_split_sizes=None, + group=pg, + async_op=True, + ) + assert ( + sum(B_local * D_rank_sum for D_rank_sum in dim_sum_per_rank) + == B_local * D_global_sum + ) + + myreq.req = req + myreq.tensor = sharded_output_embeddings + myreq.a2ai = a2ai + myreq.wait_function = All2All_Pooled_Wait + ctx.myreq = myreq + ctx.pg = pg + ctx.mixed_dim = a2ai.mixed_dim + return sharded_output_embeddings + + @staticmethod + # pyre-fixme[2]: Parameter must be annotated. + # pyre-fixme[2]: Parameter must be annotated. + def backward(ctx, *unused) -> Tuple[None, None, None, Tensor]: + myreq = ctx.myreq + myreq.req.wait() + myreq.req = None + grad_output = myreq.tensor + if ctx.mixed_dim: + (W, B_local, D_local_sum) = grad_output.shape + grad_input = grad_output.view(W * B_local, D_local_sum) + else: + (W, B_local, T_local, D) = grad_output.shape + grad_input = grad_output.view(W * B_local, T_local, D) + if GRADIENT_DIVISION: + grad_input.div_(dist.get_world_size(ctx.pg)) + myreq.tensor = None + return (None, None, None, grad_input) + + +class All2All_Pooled_Wait(Function): + @staticmethod + # pyre-fixme[14]: `forward` overrides method defined in `Function` inconsistently. + def forward( + # pyre-fixme[2]: Parameter must be annotated. + ctx, + pg: dist.ProcessGroup, + myreq: Request[Tensor], + sharded_output_embeddings: Tensor, + ) -> Tensor: + a2ai = myreq.a2ai + ctx.a2ai = a2ai + myreq.req.wait() + myreq.req = None + myreq.tensor = None + ctx.pg = pg + ctx.myreq = myreq + dim_sum_per_rank = a2ai.dim_sum_per_rank + B_local = a2ai.B_local + mixed_dim = a2ai.mixed_dim + outputs_by_rank = sharded_output_embeddings.split( + [B_local * D_rank_sum for D_rank_sum in dim_sum_per_rank] + ) + if mixed_dim: + result = torch.cat( + [output.view(B_local, -1) for output in outputs_by_rank], dim=1 + ) + else: + D = a2ai.D + result = torch.cat( + [output.view(B_local, -1, D) for output in outputs_by_rank], dim=1 + ) + return result + + @staticmethod + # pyre-fixme[14]: `backward` overrides method defined in `Function` inconsistently. + # pyre-fixme[2]: Parameter must be annotated. + def backward(ctx, grad_output: Tensor) -> Tuple[None, None, Tensor]: + myreq = ctx.myreq + a2ai = ctx.a2ai + pg = ctx.pg + world_size = dist.get_world_size(pg) + my_rank = dist.get_rank(pg) + dim_sum_per_rank = a2ai.dim_sum_per_rank + + D_local_sum = dim_sum_per_rank[my_rank] + if a2ai.mixed_dim: + (B_local, D_global_sum) = grad_output.shape + sharded_grad_input_sizes = (world_size, B_local, D_local_sum) + else: + (B_local, T_global, D) = grad_output.shape + D_global_sum = T_global * D + grad_output = grad_output.view(B_local, -1) + T_local = D_local_sum // D + sharded_grad_input_sizes = (world_size, B_local, T_local, D) + assert sum(dim_sum_per_rank) == D_global_sum + + sharded_grad_output = _recat_pooled_embedding_grad_out( + grad_output.contiguous(), + dim_sum_per_rank, + ) + + sharded_grad_input = torch.empty( + sharded_grad_input_sizes, device=grad_output.device, dtype=grad_output.dtype + ) + with record_function("## alltoall_bwd_single ##"): + req = dist.all_to_all_single( + output=sharded_grad_input, + input=sharded_grad_output, + output_split_sizes=None, + input_split_sizes=[ + B_local * D_rank_sum for D_rank_sum in dim_sum_per_rank + ], + group=pg, + async_op=True, + ) + myreq.req = req + myreq.tensor = sharded_grad_input + # Note - this mismatch is by design! We return sharded_grad_output to allow PyTorch shape matching to proceed correctly. + return (None, None, sharded_grad_output) + + +class All2All_Seq_Req(Function): + @staticmethod + # pyre-fixme[14]: `forward` overrides method defined in `Function` inconsistently. + def forward( + # pyre-fixme[2]: Parameter must be annotated. + ctx, + pg: dist.ProcessGroup, + myreq: Request[Tensor], + a2ai: All2AllSequenceInfo, + sharded_input_embeddings: Tensor, + ) -> torch.Tensor: + world_size = dist.get_world_size(pg) + my_rank = dist.get_rank(pg) + D = a2ai.embedding_dim + forward_recat_tensor = a2ai.forward_recat_tensor + lengths_after_sparse_data_all2all = a2ai.lengths_after_sparse_data_all2all * D + input_splits = [i * D for i in a2ai.output_splits] + output_splits = [i * D for i in a2ai.input_splits] + local_T = lengths_after_sparse_data_all2all.shape[0] + if local_T > 0: + with record_function("## alltoall_seq_embedding_fwd_permute ##"): + ( + permuted_lengths_after_sparse_data_all2all, + sharded_input_embeddings, + _, + ) = torch.ops.fbgemm.permute_sparse_data( + forward_recat_tensor, + lengths_after_sparse_data_all2all.view(local_T * world_size, -1), + sharded_input_embeddings.view(-1), + None, + sharded_input_embeddings.numel(), + ) + else: + permuted_lengths_after_sparse_data_all2all = None + sharded_output_embeddings = torch.empty( + sum(output_splits), + dtype=sharded_input_embeddings.dtype, + device=sharded_input_embeddings.device, + ) + + with record_function("## alltoall_seq_embedding_fwd_single ##"): + req = dist.all_to_all_single( + output=sharded_output_embeddings, + input=sharded_input_embeddings, + output_split_sizes=output_splits, + input_split_sizes=input_splits, + group=pg, + async_op=True, + ) + a2ai.permuted_lengths_after_sparse_data_all2all = ( + permuted_lengths_after_sparse_data_all2all + ) + a2ai.input_splits = input_splits + a2ai.output_splits = output_splits + myreq.req = req + myreq.tensor = sharded_output_embeddings + myreq.a2ai = a2ai + myreq.wait_function = All2All_Seq_Req_Wait + ctx.myreq = myreq + ctx.pg = pg + ctx.my_rank = my_rank + ctx.world_size = world_size + return sharded_output_embeddings + + @staticmethod + # pyre-fixme[2]: Parameter must be annotated. + # pyre-fixme[2]: Parameter must be annotated. + def backward(ctx, *unused) -> Tuple[None, None, None, Tensor]: + myreq = ctx.myreq + a2ai = myreq.a2ai + D = a2ai.embedding_dim + backward_recat_tensor = a2ai.backward_recat_tensor + permuted_lengths_after_sparse_data_all2all = ( + a2ai.permuted_lengths_after_sparse_data_all2all + ) + myreq.req.wait() + sharded_grad_input = myreq.tensor + myreq.req = None + myreq.tensor = None + + if permuted_lengths_after_sparse_data_all2all is not None: + with record_function("## alltoall_seq_embedding_bwd_permute ##"): + _, sharded_grad_input, _ = torch.ops.fbgemm.permute_sparse_data( + backward_recat_tensor, + permuted_lengths_after_sparse_data_all2all, + sharded_grad_input, + None, + sharded_grad_input.numel(), + ) + return (None, None, None, sharded_grad_input.view(-1, D)) + + +class All2All_Seq_Req_Wait(Function): + @staticmethod + # pyre-fixme[14]: `forward` overrides method defined in `Function` inconsistently. + def forward( + # pyre-fixme[2]: Parameter must be annotated. + ctx, + pg: dist.ProcessGroup, + myreq: Request[Tensor], + sharded_output_embeddings: Tensor, + ) -> Tensor: + a2ai = myreq.a2ai + D = a2ai.embedding_dim + ctx.a2ai = a2ai + myreq.req.wait() + myreq.req = None + myreq.tensor = None + ctx.pg = pg + ctx.myreq = myreq + return sharded_output_embeddings.view(-1, D) + + @staticmethod + # pyre-fixme[14]: `backward` overrides method defined in `Function` inconsistently. + # pyre-fixme[2]: Parameter must be annotated. + def backward(ctx, sharded_grad_output: Tensor) -> Tuple[None, None, Tensor]: + myreq = ctx.myreq + a2ai = ctx.a2ai + pg = ctx.pg + input_splits = a2ai.output_splits + output_splits = a2ai.input_splits + sharded_grad_input = torch.empty( + sum(output_splits), + device=sharded_grad_output.device, + dtype=sharded_grad_output.dtype, + ) + with record_function("## alltoall_seq_embedding_bwd_single ##"): + req = dist.all_to_all_single( + output=sharded_grad_input, + input=sharded_grad_output.view(-1), + output_split_sizes=output_splits, + input_split_sizes=input_splits, + group=pg, + async_op=True, + ) + myreq.req = req + myreq.tensor = sharded_grad_input + + # Note - this mismatch is by design! We return sharded_grad_output + # to allow PyTorch shape matching to proceed correctly. + return (None, None, sharded_grad_output.view(-1)) + + +class All2Allv_Req(Function): + @staticmethod + # pyre-fixme[14]: `forward` overrides method defined in `Function` inconsistently. + def forward( + # pyre-fixme[2]: Parameter must be annotated. + ctx, + pg: dist.ProcessGroup, + myreq: Request[Tensor], + a2ai: All2AllVInfo, + inputs: List[Tensor], + ) -> Tensor: + input_split_sizes = [m * sum(a2ai.D_local_list) for m in a2ai.B_local_list] + output_split_sizes = [a2ai.B_local * e for e in a2ai.dims_sum_per_rank] + input = torch.cat(inputs, dim=1).view([-1]) + output = input.new_empty(sum(output_split_sizes)) + with record_function("## alltoallv_bwd_single ##"): + req = dist.all_to_all_single( + output, + input, + output_split_sizes, + input_split_sizes, + group=pg, + async_op=True, + ) + + myreq.req = req + myreq.tensor = output + myreq.wait_function = All2Allv_Wait + a2ai.input_split_sizes = input_split_sizes + a2ai.output_split_sizes = output_split_sizes + myreq.a2ai = a2ai + ctx.a2ai = a2ai + ctx.myreq = myreq + return output + + @staticmethod + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. + # pyre-fixme[2]: Parameter must be annotated. + def backward(ctx, *grad_output): + a2ai = ctx.a2ai + myreq = ctx.myreq + myreq.req.wait() + myreq.req = None + grad_input = myreq.tensor + grad_inputs = grad_input.view([a2ai.B_global, -1]).split( + a2ai.D_local_list, dim=1 + ) + grad_inputs = [gin.contiguous() for gin in grad_inputs] + myreq.tensor = None + return (None, None, None, *grad_inputs) + + +class All2Allv_Wait(Function): + @staticmethod + # pyre-fixme[14]: `forward` overrides method defined in `Function` inconsistently. + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. + def forward(ctx, pg: dist.ProcessGroup, myreq, output): + a2ai = myreq.a2ai + ctx.a2ai = a2ai + myreq.req.wait() + myreq.req = None + myreq.tensor = None + ctx.pg = pg + ctx.myreq = myreq + outputs = tuple( + [ + out.view([a2ai.B_local, -1]) + for out in output.split(a2ai.output_split_sizes) + ] + ) + return outputs + + @staticmethod + # pyre-fixme[2]: Parameter must be annotated. + def backward(ctx, *grad_outputs) -> Tuple[None, None, Tensor]: + pg = ctx.pg + myreq = ctx.myreq + a2ai = ctx.a2ai + grad_outputs = [gout.contiguous().view([-1]) for gout in grad_outputs] + grad_output = torch.cat(grad_outputs) + grad_input = grad_output.new_empty([a2ai.B_global * sum(a2ai.D_local_list)]) + with record_function("## alltoall_bwd_single ##"): + req = dist.all_to_all_single( + grad_input, + grad_output, + a2ai.input_split_sizes, + a2ai.output_split_sizes, + group=pg, + async_op=True, + ) + myreq.req = req + myreq.tensor = grad_input + return (None, None, grad_output) + + +class ReduceScatter_Req(Function): + @staticmethod + # pyre-fixme[14]: `forward` overrides method defined in `Function` inconsistently. + def forward( + # pyre-fixme[2]: Parameter must be annotated. + ctx, + pg: dist.ProcessGroup, + myreq: Request[Tensor], + rsi: ReduceScatterInfo, + *inputs: Any, + ) -> Tensor: + my_rank = dist.get_rank(pg) + output = inputs[my_rank].new_empty( + inputs[my_rank].size(), + dtype=inputs[my_rank].dtype, + device=inputs[my_rank].device, + ) + with record_function("## reduce_scatter ##"): + req = dist.reduce_scatter(output, list(inputs), group=pg, async_op=True) + myreq.req = req + myreq.tensor = output + myreq.wait_function = ReduceScatter_Wait + myreq.rsi = rsi + ctx.myreq = myreq + ctx.pg = pg + return output + + @staticmethod + # pyre-fixme[2]: Parameter must be annotated. + def backward(ctx, *unused: Tensor) -> Tuple[Optional[Tensor], ...]: + myreq = ctx.myreq + myreq.req.wait() + myreq.req = None + grad_inputs = list(myreq.tensor) + # Make it equivalent to running on a single rank. + if GRADIENT_DIVISION: + for grad_input in grad_inputs: + grad_input.div_(dist.get_world_size(ctx.pg)) + myreq.tensor = None + return (None, None, None, *grad_inputs) + + +class ReduceScatter_Wait(Function): + @staticmethod + # pyre-fixme[14]: `forward` overrides method defined in `Function` inconsistently. + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. + def forward(ctx, pg: dist.ProcessGroup, myreq: Request[Tensor], output: Tensor): + myreq.req.wait() + myreq.req = None + myreq.tensor = None + ctx.myreq = myreq + ctx.pg = pg + return output + + @staticmethod + # pyre-fixme[14]: `backward` overrides method defined in `Function` inconsistently. + # pyre-fixme[2]: Parameter must be annotated. + def backward(ctx, grad_output: Tensor) -> Tuple[None, None, Tensor]: + myreq = ctx.myreq + rsi = myreq.rsi + grad_inputs = [ + grad_output.new_empty( + in_size, + dtype=grad_output.dtype, + device=grad_output.device, + ) + for in_size in rsi.input_sizes + ] + with record_function("## reduce_scatter_bw (all_gather) ##"): + req = dist.all_gather( + grad_inputs, + grad_output.contiguous(), + group=ctx.pg, + async_op=True, + ) + myreq.req = req + myreq.tensor = grad_inputs + return (None, None, grad_output) diff --git a/torchrec/distributed/cw_sharding.py b/torchrec/distributed/cw_sharding.py new file mode 100644 index 000000000..f23c4d33a --- /dev/null +++ b/torchrec/distributed/cw_sharding.py @@ -0,0 +1,74 @@ +#!/usr/bin/env python3 + +from typing import List, Optional, Tuple, cast + +import torch +import torch.distributed as dist +from torchrec.distributed.embedding_types import ( + ShardedEmbeddingTable, + EmbeddingComputeKernel, +) +from torchrec.distributed.tw_sharding import TwEmbeddingSharding +from torchrec.distributed.types import ( + ShardMetadata, + ParameterSharding, +) +from torchrec.modules.embedding_configs import EmbeddingTableConfig + + +class CwEmbeddingSharding(TwEmbeddingSharding): + """ + Shards embedding bags table-wise, i.e.. a given embedding table is entirely placed on a selected rank. + """ + + def __init__( + self, + embedding_configs: List[ + Tuple[EmbeddingTableConfig, ParameterSharding, torch.Tensor] + ], + pg: dist.ProcessGroup, + device: Optional[torch.device] = None, + ) -> None: + super().__init__(embedding_configs, pg, device) + + def _shard( + self, + embedding_configs: List[ + Tuple[EmbeddingTableConfig, ParameterSharding, torch.Tensor] + ], + ) -> List[List[ShardedEmbeddingTable]]: + world_size = self._pg.size() + tables_per_rank: List[List[ShardedEmbeddingTable]] = [ + [] for i in range(world_size) + ] + for config in embedding_configs: + # pyre-fixme [16] + shards: List[ShardMetadata] = config[1].sharding_spec.shards + placed_ranks = cast(List[int], config[1].ranks) + + for rank in range(world_size): + table = ShardedEmbeddingTable( + num_embeddings=config[0].num_embeddings, + embedding_dim=config[0].embedding_dim, + name=config[0].name, + embedding_names=[], + data_type=config[0].data_type, + feature_names=[], + pooling=config[0].pooling, + is_weighted=config[0].is_weighted, + has_feature_processor=config[0].has_feature_processor, + block_size=config[1].block_size, + compute_kernel=EmbeddingComputeKernel(config[1].compute_kernel), + ) + + if rank in placed_ranks: + shard_idx = placed_ranks.index(rank) + table.embedding_names = config[0].embedding_names + table.feature_names = config[0].feature_names + table.local_rows = config[0].num_embeddings + table.local_cols = config[0].embedding_dim + table.local_metadata = shards[shard_idx] + + tables_per_rank[rank].append(table) + + return tables_per_rank diff --git a/torchrec/distributed/dist_data.py b/torchrec/distributed/dist_data.py new file mode 100644 index 000000000..ea9e394fa --- /dev/null +++ b/torchrec/distributed/dist_data.py @@ -0,0 +1,543 @@ +#!/usr/bin/env python3 + +import itertools +from typing import List, Optional, Callable + +import torch +import torch.distributed as dist +from torch import nn +from torch.autograd.profiler import record_function +from torchrec.distributed.comm_ops import ( + alltoall_pooled, + alltoall_sequence, + reduce_scatter_pooled, +) +from torchrec.distributed.types import Awaitable, NoWait +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor + +try: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") +except OSError: + pass + + +def _recat(local_split: int, num_splits: int, stagger: int = 1) -> List[int]: + """ + Calculates relevant recat indices required to reorder All-to-All Collective + + Call Args: + local_split: how many features in local split + num_splits: how many splits (typically WORLD_SIZE) + stagger: secondary reordering, (typically 1, but WORLD_SIZE/LOCAL_WORLD_SIZE for TWRW) + + Returns: + List[int] + + Example: + >>> _recat(2, 4, 1) + [0, 2, 4, 6, 1, 3, 5, 7] + >>> _recat(2, 4, 2) + [0, 4, 2, 6, 1, 5, 3, 7] + + """ + recat: List[int] = [] + + feature_order: List[int] = [ + x + num_splits // stagger * y + for x in range(num_splits // stagger) + for y in range(stagger) + ] + + for i in range(local_split): + for j in feature_order: # range(num_splits): + recat.append(i + j * local_split) + return recat + + +def _split_lengths( + splits: List[int], keys: List[str], offset_per_key: List[int] +) -> List[int]: + # Calculates lengths [x1, x2, x3, ..., y1, y2], splits [3, ..., 2] + # -> [x1+x2+x3, ..., y1+y2] + length_per_split: List[int] = [] + i = 0 + offset = 0 + for split in splits: + new_offset = offset_per_key[i + split] + length_per_split.append(new_offset - offset) + i += split + offset = new_offset + return length_per_split + + +class KJTAllToAllAwaitable(Awaitable[KeyedJaggedTensor]): + """ + Awaitable for KJT all2all + + Constructor Args: + pg (dist.ProcessGroup): ProcessGroup for AlltoAll communication. + input (KeyedJaggedTensor): Input KJT tensor + splits (List[int]): List of len(pg.size()) which indicates how many features to send to + each pg.rank(). It is assumed the KeyedJaggedTensor is ordered by destination rank. + Same for all ranks. + keys (List[str]): KJT keys after all2all + recat (torch.Tensor): recat tensor for reordering tensor order after all2all + + Call Args: + None + + Returns: + Synced KJT after all2all + """ + + def __init__( + self, + pg: dist.ProcessGroup, + input: KeyedJaggedTensor, + splits: List[int], + keys: List[str], + recat: torch.Tensor, + ) -> None: + super().__init__() + self._workers: int = pg.size() + self._input = input + self._callback: Optional[ + Callable[[KeyedJaggedTensor], KeyedJaggedTensor] + ] = None + self._in_lengths_per_worker: List[int] = [] + self._out_lengths_per_worker: List[int] = [] + if self._workers == 1: + return + self._recat = recat + self._splits = splits + self._pg: dist.ProcessGroup = pg + self._device: torch.device = input.values().device + self._keys = keys + + dim_0 = splits[pg.rank()] + dim_1 = input.stride() + in_lengths = input.lengths().view(-1) + out_lengths = torch.empty( + dim_0 * dim_1 * self._workers, + device=self._device, + dtype=in_lengths.dtype, + ) + + with record_function("## all2all_data:lengths ##"): + dist.all_to_all_single( + output=out_lengths, + input=in_lengths, + output_split_sizes=[dim_0 * dim_1] * self._workers, + input_split_sizes=[split * dim_1 for split in self._splits], + group=self._pg, + async_op=False, + ) + + self._in_lengths_per_worker = _split_lengths( + splits, input.keys(), input.offset_per_key() + ) + self._out_lengths_per_worker = ( + out_lengths.view(self._workers, -1).sum(dim=1).cpu().tolist() + ) + in_values = input.values().view(-1) + out_values = torch.empty( + sum(self._out_lengths_per_worker), + device=self._device, + dtype=in_values.dtype, + ) + # Pyre-fixme [11] + self._values_awaitable: dist.Work = dist.all_to_all_single( + output=out_values, + input=in_values, + output_split_sizes=self._out_lengths_per_worker, + input_split_sizes=self._in_lengths_per_worker, + group=self._pg, + async_op=True, + ) + + self._values: torch.Tensor = out_values + self._lengths: torch.Tensor = out_lengths + + self._weights_awaitable: Optional[dist.Work] = None + self._weights: Optional[torch.Tensor] = None + + if input.weights_or_none() is not None: + in_weights = input.weights().view(-1) + out_weights = torch.empty( + sum(self._out_lengths_per_worker), + device=self._device, + dtype=in_weights.dtype, + ) + self._weights_awaitable: dist.Work = dist.all_to_all_single( + output=out_weights, + input=in_weights, + output_split_sizes=self._out_lengths_per_worker, + input_split_sizes=self._in_lengths_per_worker, + group=self._pg, + async_op=True, + ) + self._weights: torch.Tensor = out_weights + + def wait(self) -> KeyedJaggedTensor: + if self._workers == 1: + # TODO: add callback logic to awaitable type directly + self._input.sync() + return ( + self._callback(self._input) + if self._callback is not None + else self._input + ) + + with record_function("## all2all_data:values ##"): + self._values_awaitable.wait() + + if self._weights_awaitable: + with record_function("## all2all_data:weights ##"): + self._weights_awaitable.wait() + + keys = self._keys + lengths = self._lengths + values = self._values + weights = self._weights + + with record_function("## all2all_data:recat_values ##"): + if self._recat.numel(): + lengths, values, weights = torch.ops.fbgemm.permute_sparse_data( + self._recat, + lengths.view(self._workers * self._splits[self._pg.rank()], -1), + values, + weights, + values.numel(), + ) + lengths = lengths.view(-1) + + ret = KeyedJaggedTensor.from_lengths_sync( + keys=keys, + values=values, + weights=weights, + lengths=lengths, + stride=self._workers * self._input.stride(), + ) + + # TODO: add callback logic to awaitable type directly + return self._callback(ret) if self._callback is not None else ret + + +class KJTAllToAll(nn.Module): + """ + Redistributes KeyedJaggedTensor to a ProcessGroup according to splits + + Implementation utilizes alltoall collective as part of torch.distributed. + Requires two collective calls, one to transmit final tensor lengths (to allocate + correct space), and one to transmit actual sparse values. + + Example: + + >>> keys=['A','B','C'] + >>> splits=[2,1] + >>> sdd = SparseDataDist(pg, splits, device) + >>> awaitable = sdd(rank0_input) + + where: + rank0_input is KeyedJaggedTensor holding + + 0 1 2 + 'A' [A.V0] None [A.V1, A.V2] + 'B' None [B.V0] [B.V1] + 'C' [C.V0] [C.V1] None + + rank1_input is KeyedJaggedTensor holding + + 0 1 2 + 'A' [A.V3] [A.V4] None + 'B' None [B.V2] [B.V3, B.V4] + 'C' [C.V2] [C.V3] None + + >>> rank0_output = awaitable.wait() + + rank0_output is KeyedJaggedTensor holding + + 0 1 2 3 4 5 + 'A' [A.V0] None [A.V1, A.V2] [A.V3] [A.V4] None + 'B' None [B.V0] [B.V1] None [B.V2] [B.V3, B.V4] + + rank1_output is KeyedJaggedTensor holding + 0 1 2 3 4 5 + 'C' [C.V0] [C.V1] None [C.V2] [C.V3] None + + Constructor Args: + pg (dist.ProcessGroup): ProcessGroup for AlltoAll communication. + splits (List[int]): List of len(pg.size()) which indicates how many features to send to + each pg.rank(). It is assumed the KeyedJaggedTensor is ordered by destination rank. + Same for all ranks. + device (Optional[torch.device]): device on which buffers will be allocated + stagger (int): stagger value to apply to recat tensor, see _recat function for more detail + + Call Args: + input (KeyedJaggedTensor): Input KJT tensor + + Returns: + None + """ + + def __init__( + self, + pg: dist.ProcessGroup, + splits: List[int], + device: Optional[torch.device] = None, + stagger: int = 1, + ) -> None: + super().__init__() + assert len(splits) == pg.size() + self._pg: dist.ProcessGroup = pg + self._splits = splits + self._no_dist: bool = all(s == 0 for s in splits) + self._splits_cumsum: List[int] = [0] + list(itertools.accumulate(splits)) + self.register_buffer( + "_recat", + torch.tensor( + _recat( + local_split=splits[pg.rank()], + num_splits=len(splits), + stagger=stagger, + ), + device=device, + dtype=torch.int, + ), + ) + + def forward(self, input: KeyedJaggedTensor) -> Awaitable[KeyedJaggedTensor]: + """ + Sends input to relevant ProcessGroup ranks + + Call Args: + input (KeyedJaggedTensor): A Jagged tensor of values to distribute + + Returns: + awaitable of a KeyedJaggedTensor + """ + with torch.no_grad(): + if self._no_dist: + assert len(input.keys()) == 0 + return NoWait(input) + else: + assert len(input.keys()) == sum(self._splits) + rank = dist.get_rank(self._pg) + local_keys = input.keys()[ + self._splits_cumsum[rank] : self._splits_cumsum[rank + 1] + ] + + return KJTAllToAllAwaitable( + pg=self._pg, + input=input, + splits=self._splits, + keys=local_keys, + recat=self._recat, + ) + + +class PooledEmbeddingsAwaitable(Awaitable[torch.Tensor]): + def __init__( + self, + tensor_awaitable: Awaitable[torch.Tensor], + ) -> None: + super().__init__() + self._tensor_awaitable = tensor_awaitable + self._callback: Optional[Callable[[torch.Tensor], torch.Tensor]] = None + + def wait(self) -> torch.Tensor: + ret = self._tensor_awaitable.wait() + # TODO: add callback logic to awaitable type directly + if self._callback is not None: + ret = self._callback(ret) + + return ret + + +class PooledEmbeddingsAllToAll(nn.Module): + # TODO: potentially refactor to take KT instead of torch.Tensor: D29174501 + """ + Shards batchs and collects keys of Tensor with a ProcessGroup according to dim_sum_per_rank + + Implementation utilizes alltoall_pooled operation. + + Example: + >>> dim_sum_per_rank = [2, 1] + >>> a2a = PooledEmbeddingsAllToAll(pg, dim_sum_per_rank, device) + + >>> t0 = torch.rand((6, 2)) + >>> t1 = torch.rand((6, 1)) + >>> rank0_output = a2a(t0).wait() + >>> rank1_output = a2a(t1).wait() + >>> print(rank0_output.size()) + torch.Size([3, 3]) + >>> print(rank1_output.size()) + torch.Size([3, 3]) + + Constructor Args: + pg: dist.ProcessGroup, + dim_sum_per_rank: List[int], + device: Optional[torch.device] = None, + + Call Args: + local_embs: torch.Tensor + + Returns: + PooledEmbeddingsAwaitable + """ + + def __init__( + self, + pg: dist.ProcessGroup, + dim_sum_per_rank: List[int], + device: Optional[torch.device] = None, + ) -> None: + super().__init__() + self._pg = pg + + self._dim_sum_per_rank = dim_sum_per_rank + self.register_buffer( + "_dim_sum_per_rank_tensor", + torch.tensor(dim_sum_per_rank, device=device, dtype=torch.int), + ) + cumsum_dim_sum_per_rank = list(itertools.accumulate(dim_sum_per_rank)) + self.register_buffer( + "_cumsum_dim_sum_per_rank_tensor", + torch.tensor(cumsum_dim_sum_per_rank, device=device, dtype=torch.int), + ) + + def forward(self, local_embs: torch.Tensor) -> PooledEmbeddingsAwaitable: + if local_embs.numel() == 0: + local_embs.view(local_embs.size(0) * self._pg.size(), 0) + tensor_awaitable = alltoall_pooled( + a2a_pooled_embs_tensor=local_embs, + dim_sum_per_rank=self._dim_sum_per_rank, + mixed_dim=True, + dim_sum_per_rank_tensor=self._dim_sum_per_rank_tensor, + cumsum_dim_sum_per_rank_tensor=self._cumsum_dim_sum_per_rank_tensor, + group=self._pg, + ) + return PooledEmbeddingsAwaitable( + tensor_awaitable=tensor_awaitable, + ) + + +class PooledEmbeddingsReduceScatter(nn.Module): + def __init__( + self, + pg: dist.ProcessGroup, + ) -> None: + """The module class that wraps reduce-scatter communication primitive + for pooled embedding communication in row-wise and twrw sharding. + + For pooled embeddings, we have a local model-parallel output tensor with + a layout of [num_buckets x batch_size, dimension]. We need to sum over num_buckets dimension across batches. + We split tensor along the first dimension into equal chunks(tensor slices of different buckets) and + reduce them into the output tensor and scatter the results for corresponding ranks. + The class returns the async Awaitable handle for pooled embeddings tensor. + The reduce-scatter is only available for nccl backend. + + Constructor Args:: + pg (dist.ProcessGroup): The process group that the reduce-scatter communication happens within. + + Call Args: + input (torch.Tensor): tensor of shape [num_buckets x batch_size, dimension]. + + Returns: + output (torch.Tensor): PooledEmbeddingsAwaitable of tensor of shape [batch_size, dimension]. + + Example: + >>> init_distributed(rank=rank, size=2, backend="nccl") + >>> pg = dist.new_group(backend="nccl") + >>> input = torch.randn(2 * 2, 2) + >>> m = PooledEmbeddingsReduceScatter(pg) + >>> output = m(input) + >>> tensor = output.wait() + """ + super().__init__() + self._pg = pg + + def forward(self, local_embs: torch.Tensor) -> PooledEmbeddingsAwaitable: + tensor_awaitable = reduce_scatter_pooled( + list(torch.chunk(local_embs, self._pg.size(), dim=0)), self._pg + ) + return PooledEmbeddingsAwaitable(tensor_awaitable=tensor_awaitable) + + +class SequenceEmbeddingsAwaitable(Awaitable[torch.Tensor]): + def __init__( + self, + tensor_awaitable: Awaitable[torch.Tensor], + unbucketize_permute_tensor: Optional[torch.Tensor], + embedding_dim: int, + ) -> None: + super().__init__() + self._tensor_awaitable = tensor_awaitable + self._unbucketize_permute_tensor = unbucketize_permute_tensor + self._callback: Optional[Callable[[torch.Tensor], torch.Tensor]] = None + self._embedding_dim = embedding_dim + + def wait(self) -> torch.Tensor: + ret = self._tensor_awaitable.wait() + # TODO: add callback logic to awaitable type directly + if self._callback is not None: + ret = self._callback(ret) + if self._unbucketize_permute_tensor is not None: + ret = torch.index_select( + ret.view(-1, self._embedding_dim), + 0, + self._unbucketize_permute_tensor, + ) + return ret + + +class SequenceEmbeddingAllToAll(nn.Module): + def __init__( + self, + pg: dist.ProcessGroup, + features_per_rank: List[int], + device: Optional[torch.device] = None, + ) -> None: + super().__init__() + self._pg = pg + + forward_recat = [] + for j in range(self._pg.size()): + for i in range(features_per_rank[self._pg.rank()]): + forward_recat.append(j + i * self._pg.size()) + self.register_buffer( + "_forward_recat_tensor", + torch.tensor(forward_recat, device=device, dtype=torch.int), + ) + backward_recat = [] + for i in range(features_per_rank[self._pg.rank()]): + for j in range(self._pg.size()): + backward_recat.append(i + j * features_per_rank[self._pg.rank()]) + self.register_buffer( + "_backward_recat_tensor", + torch.tensor(backward_recat, device=device, dtype=torch.int), + ) + + def forward( + self, + local_embs: torch.Tensor, + lengths: torch.Tensor, + input_splits: List[int], + output_splits: List[int], + unbucketize_permute_tensor: Optional[torch.Tensor] = None, + ) -> SequenceEmbeddingsAwaitable: + tensor_awaitable = alltoall_sequence( + a2a_sequence_embs_tensor=local_embs, + forward_recat_tensor=self._forward_recat_tensor, + backward_recat_tensor=self._backward_recat_tensor, + lengths_after_sparse_data_all2all=lengths, + input_splits=input_splits, + output_splits=output_splits, + group=self._pg, + ) + return SequenceEmbeddingsAwaitable( + tensor_awaitable=tensor_awaitable, + unbucketize_permute_tensor=unbucketize_permute_tensor, + embedding_dim=local_embs.shape[1], + ) diff --git a/torchrec/distributed/dp_sharding.py b/torchrec/distributed/dp_sharding.py new file mode 100644 index 000000000..ab5117e94 --- /dev/null +++ b/torchrec/distributed/dp_sharding.py @@ -0,0 +1,194 @@ +#!/usr/bin/env python3 + +from typing import List, Optional, Dict, Any, Tuple + +import torch +from torch.distributed._sharding_spec import ShardMetadata +from torchrec.distributed.embedding_lookup import ( + GroupedPooledEmbeddingsLookup, + GroupedEmbeddingsLookup, +) +from torchrec.distributed.embedding_sharding import ( + EmbeddingSharding, + group_tables, + BasePooledEmbeddingDist, + BaseSequenceEmbeddingDist, + BaseSparseFeaturesDist, + SequenceShardingContext, + BaseEmbeddingLookup, +) +from torchrec.distributed.embedding_types import ( + GroupedEmbeddingConfig, + SparseFeatures, + ShardedEmbeddingTable, + EmbeddingComputeKernel, + BaseGroupedFeatureProcessor, +) +from torchrec.distributed.types import Awaitable, NoWait, ParameterSharding, ShardingEnv +from torchrec.modules.embedding_configs import EmbeddingTableConfig + + +class DpSparseFeaturesDist(BaseSparseFeaturesDist): + def __init__(self) -> None: + super().__init__() + + def forward( + self, + sparse_features: SparseFeatures, + ) -> Awaitable[SparseFeatures]: + return NoWait(sparse_features) + + +class DpPooledEmbeddingDist(BasePooledEmbeddingDist): + def __init__(self) -> None: + super().__init__() + + def forward(self, local_embs: torch.Tensor) -> Awaitable[torch.Tensor]: + return NoWait(local_embs) + + +class DpSequenceEmbeddingDist(BaseSequenceEmbeddingDist): + def __init__(self) -> None: + super().__init__() + + def forward( + self, sharding_ctx: SequenceShardingContext, local_embs: torch.Tensor + ) -> Awaitable[torch.Tensor]: + return NoWait(local_embs) + + +class DpEmbeddingSharding(EmbeddingSharding): + """ + Use data-parallel, no table sharding + """ + + def __init__( + self, + embedding_configs: List[ + Tuple[EmbeddingTableConfig, ParameterSharding, torch.Tensor] + ], + env: ShardingEnv, + device: Optional[torch.device] = None, + is_sequence: bool = False, + ) -> None: + super().__init__() + self._env = env + self._device = device + self._is_sequence = is_sequence + sharded_tables_per_rank = self._shard(embedding_configs) + self._grouped_embedding_configs_per_rank: List[ + List[GroupedEmbeddingConfig] + ] = [] + self._score_grouped_embedding_configs_per_rank: List[ + List[GroupedEmbeddingConfig] + ] = [] + ( + self._grouped_embedding_configs_per_rank, + self._score_grouped_embedding_configs_per_rank, + ) = group_tables(sharded_tables_per_rank) + self._grouped_embedding_configs: List[ + GroupedEmbeddingConfig + ] = self._grouped_embedding_configs_per_rank[env.rank] + self._score_grouped_embedding_configs: List[ + GroupedEmbeddingConfig + ] = self._score_grouped_embedding_configs_per_rank[env.rank] + + def _shard( + self, + embedding_configs: List[ + Tuple[EmbeddingTableConfig, ParameterSharding, torch.Tensor] + ], + ) -> List[List[ShardedEmbeddingTable]]: + world_size = self._env.world_size + tables_per_rank: List[List[ShardedEmbeddingTable]] = [ + [] for i in range(world_size) + ] + for config in embedding_configs: + for rank in range(world_size): + tables_per_rank[rank].append( + ShardedEmbeddingTable( + num_embeddings=config[0].num_embeddings, + embedding_dim=config[0].embedding_dim, + name=config[0].name, + embedding_names=config[0].embedding_names, + data_type=config[0].data_type, + feature_names=config[0].feature_names, + pooling=config[0].pooling, + is_weighted=config[0].is_weighted, + has_feature_processor=config[0].has_feature_processor, + local_rows=config[2].size(0), + local_cols=config[2].size(1), + compute_kernel=EmbeddingComputeKernel(config[1].compute_kernel), + local_metadata=None, + weight_init_max=config[0].weight_init_max, + weight_init_min=config[0].weight_init_min, + ) + ) + return tables_per_rank + + def create_input_dist(self) -> DpSparseFeaturesDist: + return DpSparseFeaturesDist() + + def create_lookup( + self, + fused_params: Optional[Dict[str, Any]], + feature_processor: Optional[BaseGroupedFeatureProcessor] = None, + ) -> BaseEmbeddingLookup: + if self._is_sequence: + return GroupedEmbeddingsLookup( + grouped_configs=self._grouped_embedding_configs, + fused_params=fused_params, + pg=self._env.process_group, + device=self._device, + ) + else: + return GroupedPooledEmbeddingsLookup( + grouped_configs=self._grouped_embedding_configs, + grouped_score_configs=self._score_grouped_embedding_configs, + fused_params=fused_params, + pg=self._env.process_group, + device=self._device, + feature_processor=feature_processor, + ) + + def create_pooled_output_dist(self) -> DpPooledEmbeddingDist: + return DpPooledEmbeddingDist() + + def create_sequence_output_dist(self) -> DpSequenceEmbeddingDist: + return DpSequenceEmbeddingDist() + + def embedding_dims(self) -> List[int]: + embedding_dims = [] + for grouped_config in self._grouped_embedding_configs: + embedding_dims.extend(grouped_config.embedding_dims()) + for grouped_config in self._score_grouped_embedding_configs: + embedding_dims.extend(grouped_config.embedding_dims()) + return embedding_dims + + def embedding_names(self) -> List[str]: + embedding_names = [] + for grouped_config in self._grouped_embedding_configs: + embedding_names.extend(grouped_config.embedding_names()) + for grouped_config in self._score_grouped_embedding_configs: + embedding_names.extend(grouped_config.embedding_names()) + return embedding_names + + def embedding_shard_metadata(self) -> List[Optional[ShardMetadata]]: + embedding_shard_metadata = [] + for grouped_config in self._grouped_embedding_configs: + embedding_shard_metadata.extend(grouped_config.embedding_shard_metadata()) + for grouped_config in self._score_grouped_embedding_configs: + embedding_shard_metadata.extend(grouped_config.embedding_shard_metadata()) + return embedding_shard_metadata + + def id_list_feature_names(self) -> List[str]: + id_list_feature_names = [] + for grouped_config in self._grouped_embedding_configs: + id_list_feature_names.extend(grouped_config.feature_names()) + return id_list_feature_names + + def id_score_list_feature_names(self) -> List[str]: + id_score_list_feature_names = [] + for grouped_config in self._score_grouped_embedding_configs: + id_score_list_feature_names.extend(grouped_config.feature_names()) + return id_score_list_feature_names diff --git a/torchrec/distributed/embedding.py b/torchrec/distributed/embedding.py new file mode 100644 index 000000000..af9e176fa --- /dev/null +++ b/torchrec/distributed/embedding.py @@ -0,0 +1,102 @@ +#!/usr/bin/env python3 + +from typing import ( + List, + Dict, + Optional, + Type, + Any, + TypeVar, +) + +import torch +from torch import nn +from torchrec.distributed.embedding_types import ( + BaseEmbeddingSharder, + SparseFeaturesList, +) +from torchrec.distributed.types import ( + Awaitable, + LazyAwaitable, + ParameterSharding, + ShardedModule, + ShardedModuleContext, + ShardingEnv, +) +from torchrec.modules.embedding_modules import ( + EmbeddingCollection, +) +from torchrec.optim.fused import FusedOptimizerModule +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor + + +class ShardedEmbeddingCollection( + ShardedModule[ + SparseFeaturesList, + List[torch.Tensor], + KeyedTensor, + ], + FusedOptimizerModule, +): + """ + Sharded implementation of EmbeddingCollection. + This is part of public API to allow for manual data dist pipelining. + """ + + def __init__( + self, + module: EmbeddingCollection, + table_name_to_parameter_sharding: Dict[str, ParameterSharding], + env: ShardingEnv, + fused_params: Optional[Dict[str, Any]] = None, + device: Optional[torch.device] = None, + ) -> None: + super().__init__() + + # pyre-ignore [14] + def input_dist( + self, ctx: ShardedModuleContext, features: KeyedJaggedTensor + ) -> Awaitable[SparseFeaturesList]: + # pyre-ignore [7] + pass + + def compute( + self, ctx: ShardedModuleContext, dist_input: SparseFeaturesList + ) -> List[torch.Tensor]: + # pyre-ignore [7] + pass + + def output_dist( + self, ctx: ShardedModuleContext, output: List[torch.Tensor] + ) -> LazyAwaitable[KeyedTensor]: + # pyre-ignore [7] + pass + + +M = TypeVar("M", bound=nn.Module) + + +class EmbeddingCollectionSharder(BaseEmbeddingSharder[M]): + """ + This implementation uses non-fused EmbeddingCollection + """ + + def shard( + self, + module: EmbeddingCollection, + params: Dict[str, ParameterSharding], + env: ShardingEnv, + device: Optional[torch.device] = None, + ) -> ShardedEmbeddingCollection: + return ShardedEmbeddingCollection( + module, params, env, self.fused_params, device + ) + + def shardable_parameters( + self, module: EmbeddingCollection + ) -> Dict[str, nn.Parameter]: + return {} + + @property + def module_type(self) -> Type[EmbeddingCollection]: + return EmbeddingCollection diff --git a/torchrec/distributed/embedding_lookup.py b/torchrec/distributed/embedding_lookup.py new file mode 100644 index 000000000..612b51172 --- /dev/null +++ b/torchrec/distributed/embedding_lookup.py @@ -0,0 +1,1251 @@ +#!/usr/bin/env python3 + +import abc +import copy +import itertools +from collections import OrderedDict +from typing import List, Optional, Dict, Any, Union, Tuple, cast, Iterator + +import torch +import torch.distributed as dist +import torch.distributed._sharded_tensor as sharded_tensor +from fbgemm_gpu.split_embedding_configs import EmbOptimType as OptimType, SparseType +from fbgemm_gpu.split_table_batched_embeddings_ops import ( + EmbeddingLocation, + ComputeDevice, + PoolingMode, + DenseTableBatchedEmbeddingBagsCodegen, + SplitTableBatchedEmbeddingBagsCodegen, + IntNBitTableBatchedEmbeddingBagsCodegen, + rounded_row_size_in_bytes, +) +from torch import nn +from torch.nn.modules.module import _IncompatibleKeys +from torchrec.distributed.embedding_types import ( + GroupedEmbeddingConfig, + BaseEmbeddingLookup, + SparseFeatures, + EmbeddingComputeKernel, + ShardedEmbeddingTable, + BaseGroupedFeatureProcessor, +) +from torchrec.distributed.grouped_position_weighted import ( + GroupedPositionWeightedModule, +) +from torchrec.distributed.types import ( + Shard, + ShardMetadata, + ShardedTensor, +) +from torchrec.distributed.utils import append_prefix +from torchrec.modules.embedding_configs import ( + PoolingType, + DataType, + DATA_TYPE_NUM_BITS, +) +from torchrec.optim.fused import FusedOptimizerModule, FusedOptimizer +from torchrec.sparse.jagged_tensor import ( + KeyedJaggedTensor, + KeyedTensor, +) + + +def _load_state_dict( + emb_modules: "nn.ModuleList[nn.Module]", + state_dict: "OrderedDict[str, torch.Tensor]", +) -> Tuple[List[str], List[str]]: + missing_keys = [] + unexpected_keys = list(state_dict.keys()) + for emb_module in emb_modules: + for key, param in emb_module.state_dict().items(): + if key in state_dict: + if isinstance(param, ShardedTensor): + assert len(param.local_shards()) == 1 + dst_tensor = param.local_shards()[0].tensor + else: + dst_tensor = param + if isinstance(state_dict[key], ShardedTensor): + # pyre-fixme[16] + assert len(state_dict[key].local_shards()) == 1 + src_tensor = state_dict[key].local_shards()[0].tensor + else: + src_tensor = state_dict[key] + dst_tensor.detach().copy_(src_tensor) + unexpected_keys.remove(key) + else: + missing_keys.append(cast(str, key)) + return missing_keys, unexpected_keys + + +class BaseEmbedding(abc.ABC, nn.Module): + """ + abstract base class for grouped nn.Embedding + """ + + @abc.abstractmethod + def forward( + self, + features: KeyedJaggedTensor, + ) -> torch.Tensor: + pass + + """ + return sparse gradient parameter names + """ + + def sparse_grad_parameter_names( + self, destination: Optional[List[str]] = None, prefix: str = "" + ) -> List[str]: + destination = [] if destination is None else destination + return destination + + +class GroupedEmbedding(BaseEmbedding): + def __init__( + self, + config: GroupedEmbeddingConfig, + sparse: bool, + pg: Optional[dist.ProcessGroup] = None, + device: Optional[torch.device] = None, + ) -> None: + super().__init__() + torch._C._log_api_usage_once(f"torchrec.distributed.{self.__class__.__name__}") + self._config = config + self._pg = pg + self._emb_modules: nn.ModuleList[nn.Module] = nn.ModuleList() + self._sparse = sparse + for embedding_config in self._config.local_embedding_tables: + self._emb_modules.append( + nn.Embedding( + num_embeddings=embedding_config.local_rows, + embedding_dim=embedding_config.local_cols, + device=device, + sparse=self._sparse, + _weight=torch.empty( + embedding_config.local_rows, + embedding_config.local_cols, + device=device, + ).uniform_( + embedding_config.get_weight_init_min(), + embedding_config.get_weight_init_max(), + ), + ) + ) + + def forward(self, features: KeyedJaggedTensor) -> torch.Tensor: + indices_dict: Dict[str, torch.Tensor] = {} + indices_list = torch.split(features.values(), features.length_per_key()) + for key, indices in zip(features.keys(), indices_list): + indices_dict[key] = indices + unpooled_embeddings: List[torch.Tensor] = [] + for embedding_config, emb_module in zip( + self._config.local_embedding_tables, self._emb_modules + ): + for feature_name in embedding_config.feature_names: + unpooled_embeddings.append(emb_module(input=indices_dict[feature_name])) + return torch.cat(unpooled_embeddings, dim=0) + + def state_dict( + self, + destination: Optional[Dict[str, Any]] = None, + prefix: str = "", + keep_vars: bool = False, + ) -> Dict[str, Any]: + if destination is None: + destination = OrderedDict() + # pyre-ignore [16] + destination._metadata = OrderedDict() + + for config in self._config.global_embedding_tables: + key = prefix + f"{config.name}.weight" + if config in self._config.local_embedding_tables: + config_idx = self._config.local_embedding_tables.index(config) + emb_module = self._emb_modules[config_idx] + param = emb_module.weight if keep_vars else emb_module.weight.data + assert config.local_rows == param.size(0) + assert config.local_cols == param.size(1) + if config.local_metadata is not None: + destination[key] = sharded_tensor.init_from_local_shards( + # pyre-ignore [6] + [Shard(param, config.local_metadata)], + [config.num_embeddings, config.embedding_dim], + process_group=self._pg, + ) + else: + destination[key] = param + else: + # just an handler for tw-related sharding on the rank that + # those tables aren't exist, this is to comply with SPMD + sharded_tensor.init_from_local_shards( + [], + [config.num_embeddings, config.embedding_dim], + process_group=self._pg, + ) + return destination + + def named_parameters( + self, prefix: str = "", recurse: bool = True + ) -> Iterator[Tuple[str, nn.Parameter]]: + for config, emb_module in zip( + self._config.local_embedding_tables, + self._emb_modules, + ): + param = emb_module.weight + assert config.local_rows == param.size(0) + assert config.local_cols == param.size(1) + yield append_prefix(prefix, f"{config.name}.weight"), param + + def sparse_grad_parameter_names( + self, destination: Optional[List[str]] = None, prefix: str = "" + ) -> List[str]: + destination = [] if destination is None else destination + if self._sparse: + for config in self._config.local_embedding_tables: + destination.append(append_prefix(prefix, f"{config.name}.weight")) + return destination + + def config(self) -> GroupedEmbeddingConfig: + return self._config + + +class GroupedEmbeddingsLookup(BaseEmbeddingLookup): + def __init__( + self, + grouped_configs: List[GroupedEmbeddingConfig], + pg: Optional[dist.ProcessGroup] = None, + device: Optional[torch.device] = None, + fused_params: Optional[Dict[str, Any]] = None, + ) -> None: + def _create_lookup( + config: GroupedEmbeddingConfig, + ) -> BaseEmbedding: + if config.compute_kernel == EmbeddingComputeKernel.DENSE: + return GroupedEmbedding( + config=config, + sparse=False, + pg=pg, + device=device, + ) + elif config.compute_kernel == EmbeddingComputeKernel.SPARSE: + return GroupedEmbedding( + config=config, + sparse=True, + pg=pg, + device=device, + ) + else: + raise ValueError( + f"Compute kernel not supported {config.compute_kernel}" + ) + + super().__init__() + self._emb_modules: nn.ModuleList[BaseEmbedding] = nn.ModuleList() + for config in grouped_configs: + self._emb_modules.append(_create_lookup(config)) + + self._id_list_feature_splits: List[int] = [] + for config in grouped_configs: + self._id_list_feature_splits.append(config.num_features()) + + # return a dummy empty tensor when grouped_configs is empty + self.register_buffer( + "_dummy_embs_tensor", + torch.empty( + [0], + dtype=torch.float32, + device=device, + requires_grad=True, + ), + ) + + self.grouped_configs = grouped_configs + + def forward( + self, + sparse_features: SparseFeatures, + ) -> torch.Tensor: + assert sparse_features.id_list_features is not None + embeddings: List[torch.Tensor] = [] + id_list_features_by_group = sparse_features.id_list_features.split( + self._id_list_feature_splits, + ) + for emb_op, features in zip(self._emb_modules, id_list_features_by_group): + embeddings.append(emb_op(features).view(-1)) + + if len(embeddings) == 0: + # a hack for empty ranks + return self._dummy_embs_tensor + elif len(embeddings) == 1: + return embeddings[0] + else: + return torch.cat(embeddings) + + def state_dict( + self, + destination: Optional[Dict[str, Any]] = None, + prefix: str = "", + keep_vars: bool = False, + ) -> Dict[str, Any]: + if destination is None: + destination = OrderedDict() + # pyre-ignore [16] + destination._metadata = OrderedDict() + + for emb_module in self._emb_modules: + emb_module.state_dict(destination, prefix, keep_vars) + + return destination + + def load_state_dict( + self, + state_dict: "OrderedDict[str, torch.Tensor]", + strict: bool = True, + ) -> _IncompatibleKeys: + # pyre-ignore [6] + m, u = _load_state_dict(self._emb_modules, state_dict) + return _IncompatibleKeys(missing_keys=m, unexpected_keys=u) + + def named_parameters( + self, prefix: str = "", recurse: bool = True + ) -> Iterator[Tuple[str, nn.Parameter]]: + for emb_module in self._emb_modules: + yield from emb_module.named_parameters(prefix, recurse) + + def named_buffers( + self, prefix: str = "", recurse: bool = True + ) -> Iterator[Tuple[str, torch.Tensor]]: + for emb_module in self._emb_modules: + yield from emb_module.named_buffers(prefix, recurse) + + def sparse_grad_parameter_names( + self, destination: Optional[List[str]] = None, prefix: str = "" + ) -> List[str]: + destination = [] if destination is None else destination + for emb_module in self._emb_modules: + emb_module.sparse_grad_parameter_names(destination, prefix) + return destination + + +class BaseEmbeddingBag(nn.Module): + """ + abstract base class for grouped nn.EmbeddingBag + """ + + """ + return sparse gradient parameter names + """ + + def sparse_grad_parameter_names( + self, destination: Optional[List[str]] = None, prefix: str = "" + ) -> List[str]: + destination = [] if destination is None else destination + return destination + + @property + @abc.abstractmethod + def config(self) -> GroupedEmbeddingConfig: + pass + + +class GroupedEmbeddingBag(BaseEmbeddingBag): + def __init__( + self, + config: GroupedEmbeddingConfig, + sparse: bool, + pg: Optional[dist.ProcessGroup] = None, + device: Optional[torch.device] = None, + ) -> None: + def _to_mode(pooling: PoolingType) -> str: + if pooling == PoolingType.SUM: + return "sum" + elif pooling == PoolingType.MEAN: + return "mean" + else: + raise ValueError(f"Unsupported pooling {pooling}") + + super().__init__() + torch._C._log_api_usage_once(f"torchrec.distributed.{self.__class__.__name__}") + self._config = config + self._pg = pg + self._emb_modules: nn.ModuleList[nn.Module] = nn.ModuleList() + self._sparse = sparse + self._emb_names: List[str] = [] + self._lengths_per_emb: List[int] = [] + + shared_feature: Dict[str, bool] = {} + for embedding_config in self._config.local_embedding_tables: + self._emb_modules.append( + nn.EmbeddingBag( + num_embeddings=embedding_config.local_rows, + embedding_dim=embedding_config.local_cols, + mode=_to_mode(embedding_config.pooling), + device=device, + include_last_offset=True, + sparse=self._sparse, + _weight=torch.empty( + embedding_config.local_rows, + embedding_config.local_cols, + device=device, + ).uniform_( + embedding_config.get_weight_init_min(), + embedding_config.get_weight_init_max(), + ), + ) + ) + for feature_name in embedding_config.feature_names: + if feature_name not in shared_feature: + shared_feature[feature_name] = False + else: + shared_feature[feature_name] = True + self._lengths_per_emb.append(embedding_config.embedding_dim) + + for embedding_config in self._config.local_embedding_tables: + for feature_name in embedding_config.feature_names: + if shared_feature[feature_name]: + self._emb_names.append(feature_name + "@" + embedding_config.name) + else: + self._emb_names.append(feature_name) + + def forward(self, features: KeyedJaggedTensor) -> KeyedTensor: + pooled_embeddings: List[torch.Tensor] = [] + for embedding_config, emb_module in zip( + self._config.local_embedding_tables, self._emb_modules + ): + for feature_name in embedding_config.feature_names: + values = features[feature_name].values() + offsets = features[feature_name].offsets() + weights = features[feature_name].weights_or_none() + if weights is not None and not torch.is_floating_point(weights): + weights = None + pooled_embeddings.append( + emb_module( + input=values, + offsets=offsets, + per_sample_weights=weights, + ) + ) + return KeyedTensor( + keys=self._emb_names, + values=torch.cat(pooled_embeddings, dim=1), + length_per_key=self._lengths_per_emb, + ) + + def state_dict( + self, + destination: Optional[Dict[str, Any]] = None, + prefix: str = "", + keep_vars: bool = False, + ) -> Dict[str, Any]: + if destination is None: + destination = OrderedDict() + # pyre-ignore [16] + destination._metadata = OrderedDict() + + for config in self._config.global_embedding_tables: + key = prefix + f"{config.name}.weight" + if config in self._config.local_embedding_tables: + config_idx = self._config.local_embedding_tables.index(config) + emb_module = self._emb_modules[config_idx] + param = emb_module.weight if keep_vars else emb_module.weight.data + assert config.local_rows == param.size(0) + assert config.local_cols == param.size(1) + if config.local_metadata is not None: + destination[key] = sharded_tensor.init_from_local_shards( + # pyre-ignore [6] + [Shard(param, config.local_metadata)], + [config.num_embeddings, config.embedding_dim], + process_group=self._pg, + ) + else: + destination[key] = param + else: + # just an handler for tw-related sharding on the rank that + # those tables aren't exist, this is to comply with SPMD + sharded_tensor.init_from_local_shards( + [], + [config.num_embeddings, config.embedding_dim], + process_group=self._pg, + ) + return destination + + def named_parameters( + self, prefix: str = "", recurse: bool = True + ) -> Iterator[Tuple[str, nn.Parameter]]: + for config, emb_module in zip( + self._config.local_embedding_tables, + self._emb_modules, + ): + param = emb_module.weight + assert config.local_rows == param.size(0) + assert config.local_cols == param.size(1) + yield append_prefix(prefix, f"{config.name}.weight"), param + + def sparse_grad_parameter_names( + self, destination: Optional[List[str]] = None, prefix: str = "" + ) -> List[str]: + destination = [] if destination is None else destination + if self._sparse: + for config in self._config.local_embedding_tables: + destination.append(append_prefix(prefix, f"{config.name}.weight")) + return destination + + def config(self) -> GroupedEmbeddingConfig: + return self._config + + +class BaseBatchedEmbeddingBag(BaseEmbeddingBag): + def __init__( + self, + config: GroupedEmbeddingConfig, + pg: Optional[dist.ProcessGroup] = None, + device: Optional[torch.device] = None, + ) -> None: + super().__init__() + torch._C._log_api_usage_once(f"torchrec.distributed.{self.__class__.__name__}") + self._config = config + self._pg = pg + + def to_pooling_mode(pooling_type: PoolingType) -> PoolingMode: + if pooling_type == PoolingType.SUM: + return PoolingMode.SUM + else: + assert pooling_type == PoolingType.MEAN + return PoolingMode.MEAN + + self._pooling: PoolingMode = to_pooling_mode(config.pooling) + + self._local_rows: List[int] = [] + self._weight_init_mins: List[float] = [] + self._weight_init_maxs: List[float] = [] + self._num_embeddings: List[int] = [] + self._local_cols: List[int] = [] + self._feature_table_map: List[int] = [] + self._emb_names: List[str] = [] + self._lengths_per_emb: List[int] = [] + + shared_feature: Dict[str, bool] = {} + for idx, config in enumerate(self._config.local_embedding_tables): + self._local_rows.append(config.local_rows) + self._weight_init_mins.append(config.get_weight_init_min()) + self._weight_init_maxs.append(config.get_weight_init_max()) + self._num_embeddings.append(config.num_embeddings) + self._local_cols.append(config.local_cols) + self._feature_table_map.extend([idx] * config.num_features()) + for feature_name in config.feature_names: + if feature_name not in shared_feature: + shared_feature[feature_name] = False + else: + shared_feature[feature_name] = True + self._lengths_per_emb.append(config.embedding_dim) + + for embedding_config in self._config.local_embedding_tables: + for feature_name in embedding_config.feature_names: + if shared_feature[feature_name]: + self._emb_names.append(feature_name + "@" + embedding_config.name) + else: + self._emb_names.append(feature_name) + + def init_parameters(self) -> None: + # initialize embedding weights + assert len(self._num_embeddings) == len(self.split_embedding_weights()) + for (rows, emb_dim, weight_init_min, weight_init_max, param) in zip( + self._local_rows, + self._local_cols, + self._weight_init_mins, + self._weight_init_maxs, + self.split_embedding_weights(), + ): + assert param.shape == (rows, emb_dim) + param.data.uniform_( + weight_init_min, + weight_init_max, + ) + + def forward(self, features: KeyedJaggedTensor) -> KeyedTensor: + weights = features.weights_or_none() + if weights is not None and not torch.is_floating_point(weights): + weights = None + values = self.emb_module( + indices=features.values().long(), + offsets=features.offsets().long(), + per_sample_weights=weights, + ) + return KeyedTensor( + keys=self._emb_names, + values=values, + length_per_key=self._lengths_per_emb, + ) + + def state_dict( + self, + destination: Optional[Dict[str, Any]] = None, + prefix: str = "", + keep_vars: bool = False, + ) -> Dict[str, Any]: + if destination is None: + destination = OrderedDict() + # pyre-ignore [16] + destination._metadata = OrderedDict() + + for config in self._config.global_embedding_tables: + key = prefix + f"{config.name}.weight" + if config in self._config.local_embedding_tables: + config_idx = self._config.local_embedding_tables.index(config) + param = self.split_embedding_weights()[config_idx] + assert config.local_rows == param.size(0) + assert config.local_cols == param.size(1) + if config.local_metadata is not None: + destination[key] = sharded_tensor.init_from_local_shards( + [Shard(param, config.local_metadata)], + [config.num_embeddings, config.embedding_dim], + process_group=self._pg, + ) + else: + destination[key] = param + else: + # just an handler for tw-related sharding on the rank that + # those tables aren't exist, this is to comply with SPMD + sharded_tensor.init_from_local_shards( + [], + [config.num_embeddings, config.embedding_dim], + process_group=self._pg, + ) + return destination + + def split_embedding_weights(self) -> List[torch.Tensor]: + return self.emb_module.split_embedding_weights() + + @property + @abc.abstractmethod + def emb_module( + self, + ) -> Union[ + DenseTableBatchedEmbeddingBagsCodegen, + SplitTableBatchedEmbeddingBagsCodegen, + IntNBitTableBatchedEmbeddingBagsCodegen, + ]: + ... + + def config(self) -> GroupedEmbeddingConfig: + return self._config + + +class EmbeddingBagFusedOptimizer(FusedOptimizer): + def __init__( + self, + config: GroupedEmbeddingConfig, + emb_module: SplitTableBatchedEmbeddingBagsCodegen, + pg: Optional[dist.ProcessGroup] = None, + ) -> None: + self._emb_module: SplitTableBatchedEmbeddingBagsCodegen = emb_module + self._pg = pg + self._fused_optim_rowwise: bool = False + if self._emb_module.optimizer in [ + OptimType.ROWWISE_ADAGRAD, + OptimType.EXACT_ROWWISE_ADAGRAD, + OptimType.PARTIAL_ROWWISE_ADAM, + OptimType.PARTIAL_ROWWISE_LAMB, + ]: + self._fused_optim_rowwise = True + + def to_rowwise_sharded_metadata( + metadata: ShardMetadata, + sharding_dim: int, + table_config: ShardedEmbeddingTable, + ) -> ShardMetadata: + offset = metadata.shard_offsets[0] + if sharding_dim == 1: + # for column-wise sharding, we still create row-wise sharded metadata for optimizer + # manually create a row-wise offset + offset = ( + metadata.shard_offsets[1] // table_config.block_size + ) * metadata.shard_lengths[0] + + rw_shard = ShardMetadata( + shard_lengths=[metadata.shard_lengths[0]], + shard_offsets=[offset], + placement=metadata.placement, + ) + + return rw_shard + + def to_rowwise_num_shards( + sharding_dim: int, table_config: ShardedEmbeddingTable + ) -> int: + if sharding_dim == 1: + # for column-wise sharding, we create len_shards base on + # the block size of the table + len_shards, res = divmod( + table_config.embedding_dim, table_config.block_size + ) + if res > 0: + len_shards += 1 + else: + len_shards = 1 + return len_shards + + # pyre-ignore [33] + state: Dict[Any, Any] = {} + param_group: Dict[str, Any] = { + "params": [], + "lr": emb_module.optimizer_args.learning_rate, + } + params: Dict[str, torch.Tensor] = {} + + # Fused optimizers use buffers (they don't use autograd) and we want to make sure + # that state_dict look identical to non-fused version. + split_embedding_weights = emb_module.split_embedding_weights() + split_optimizer_states = emb_module.split_optimizer_states() + for table_config, weight in zip( + config.local_embedding_tables, + split_embedding_weights, + ): + param_group["params"].append(weight) + param_key = table_config.name + ".weight" + params[param_key] = weight + + # set up states if there's momentums + if len(split_optimizer_states) > 0: + for global_config in config.global_embedding_tables: + has_local_shards = global_config in config.local_embedding_tables + has_momentum2 = self._emb_module.optimizer in ( + OptimType.ADAM, + OptimType.PARTIAL_ROWWISE_ADAM, + OptimType.LAMB, + OptimType.PARTIAL_ROWWISE_LAMB, + ) + + sharding_dim = ( + 1 + if global_config.local_cols != global_config.embedding_dim + and global_config.local_cols != 0 + else 0 + ) + if self._fused_optim_rowwise: + len_rw_shards = to_rowwise_num_shards(sharding_dim, global_config) + momentum_size = [global_config.num_embeddings * len_rw_shards] + else: + momentum_size = [ + global_config.num_embeddings, + global_config.embedding_dim, + ] + + if has_local_shards: + config_idx = config.local_embedding_tables.index(global_config) + weight = split_embedding_weights[config_idx] + optimizer_states = split_optimizer_states[config_idx] + state[weight] = {} + # momentum1 + assert global_config.local_rows == optimizer_states[0].size(0) + assert global_config.local_metadata is not None + momentum1_key = f"{global_config.name}.momentum1" + + local_metadata = ( + to_rowwise_sharded_metadata( + global_config.local_metadata, sharding_dim, global_config + ) + if self._fused_optim_rowwise + else global_config.local_metadata + ) + + momentum1 = sharded_tensor.init_from_local_shards( + [Shard(optimizer_states[0], local_metadata)], + momentum_size, + process_group=self._pg, + ) + state[weight][momentum1_key] = momentum1 + + # momentum2 + if has_momentum2: + assert table_config.local_rows == optimizer_states[1].size(0) + momentum2_key = f"{table_config.name}.momentum2" + + momentum2 = sharded_tensor.init_from_local_shards( + [Shard(optimizer_states[1], local_metadata)], + momentum_size, + process_group=self._pg, + ) + state[weight][momentum2_key] = momentum2 + else: + # just an handler for tw-related sharding on the rank that + # those tables aren't exist, this is to comply with SPMD + # momentum1 + sharded_tensor.init_from_local_shards( + [], + momentum_size, + process_group=self._pg, + ) + + if has_momentum2: + sharded_tensor.init_from_local_shards( + [], + momentum_size, + process_group=self._pg, + ) + + super().__init__(params, state, [param_group]) + + def zero_grad(self, set_to_none: bool = False) -> None: + # pyre-ignore [16] + self._emb_module.set_learning_rate(self.param_groups[0]["lr"]) + + # pyre-ignore [2] + def step(self, closure: Any = None) -> None: + # pyre-ignore [16] + self._emb_module.set_learning_rate(self.param_groups[0]["lr"]) + + +class BatchedFusedEmbeddingBag(BaseBatchedEmbeddingBag, FusedOptimizerModule): + def __init__( + self, + config: GroupedEmbeddingConfig, + pg: Optional[dist.ProcessGroup] = None, + device: Optional[torch.device] = None, + fused_params: Optional[Dict[str, Any]] = None, + ) -> None: + super().__init__(config, pg, device) + + def to_embedding_location( + compute_kernel: EmbeddingComputeKernel, + ) -> EmbeddingLocation: + if compute_kernel == EmbeddingComputeKernel.BATCHED_FUSED: + return EmbeddingLocation.DEVICE + elif compute_kernel == EmbeddingComputeKernel.BATCHED_FUSED_UVM: + return EmbeddingLocation.MANAGED + elif compute_kernel == EmbeddingComputeKernel.BATCHED_FUSED_UVM_CACHING: + return EmbeddingLocation.MANAGED_CACHING + else: + raise ValueError(f"Invalid EmbeddingComputeKernel {compute_kernel}") + + managed: List[EmbeddingLocation] = [] + compute_devices: List[ComputeDevice] = [] + for table in config.local_embedding_tables: + if device is not None and device.type == "cuda": + compute_devices.append(ComputeDevice.CUDA) + managed.append(to_embedding_location(table.compute_kernel)) + else: + compute_devices.append(ComputeDevice.CPU) + managed.append(EmbeddingLocation.HOST) + if fused_params is None: + fused_params = {} + self._emb_module: SplitTableBatchedEmbeddingBagsCodegen = ( + SplitTableBatchedEmbeddingBagsCodegen( + embedding_specs=list( + zip(self._local_rows, self._local_cols, managed, compute_devices) + ), + feature_table_map=self._feature_table_map, + pooling_mode=self._pooling, + weights_precision=BatchedFusedEmbeddingBag.to_sparse_type( + config.data_type + ), + device=device, + **fused_params, + ) + ) + self._optim: EmbeddingBagFusedOptimizer = EmbeddingBagFusedOptimizer( + config, + self._emb_module, + pg, + ) + + self.init_parameters() + + @staticmethod + def to_sparse_type(data_type: DataType) -> SparseType: + if data_type == DataType.FP32: + return SparseType.FP32 + elif data_type == DataType.FP16: + return SparseType.FP16 + elif data_type == DataType.INT8: + return SparseType.INT8 + else: + raise ValueError(f"Invalid DataType {data_type}") + + @property + def emb_module( + self, + ) -> SplitTableBatchedEmbeddingBagsCodegen: + return self._emb_module + + @property + def fused_optimizer(self) -> FusedOptimizer: + return self._optim + + def named_parameters( + self, prefix: str = "", recurse: bool = True + ) -> Iterator[Tuple[str, nn.Parameter]]: + yield from () + + def named_buffers( + self, prefix: str = "", recurse: bool = True + ) -> Iterator[Tuple[str, torch.Tensor]]: + for config, param in zip( + self._config.local_embedding_tables, + self.emb_module.split_embedding_weights(), + ): + key = append_prefix(prefix, f"{config.name}.weight") + yield key, param + + +class BatchedDenseEmbeddingBag(BaseBatchedEmbeddingBag): + def __init__( + self, + config: GroupedEmbeddingConfig, + pg: Optional[dist.ProcessGroup] = None, + device: Optional[torch.device] = None, + ) -> None: + super().__init__(config, pg, device) + + self._emb_module: DenseTableBatchedEmbeddingBagsCodegen = ( + DenseTableBatchedEmbeddingBagsCodegen( + list(zip(self._local_rows, self._local_cols)), + feature_table_map=self._feature_table_map, + pooling_mode=self._pooling, + use_cpu=device is None or device.type == "cpu", + ) + ) + + self.init_parameters() + + @property + def emb_module( + self, + ) -> DenseTableBatchedEmbeddingBagsCodegen: + return self._emb_module + + def named_parameters( + self, prefix: str = "", recurse: bool = True + ) -> Iterator[Tuple[str, nn.Parameter]]: + combined_key = "/".join( + [config.name for config in self._config.local_embedding_tables] + ) + yield append_prefix(prefix, f"{combined_key}.weight"), cast( + nn.Parameter, self._emb_module.weights + ) + + +class QuantBatchedEmbeddingBag(BaseBatchedEmbeddingBag): + def __init__( + self, + config: GroupedEmbeddingConfig, + pg: Optional[dist.ProcessGroup] = None, + device: Optional[torch.device] = None, + ) -> None: + super().__init__(config, pg, device) + + self._emb_module: IntNBitTableBatchedEmbeddingBagsCodegen = ( + IntNBitTableBatchedEmbeddingBagsCodegen( + embedding_specs=[ + ( + "", + local_rows, + table.embedding_dim, + QuantBatchedEmbeddingBag.to_sparse_type(config.data_type), + EmbeddingLocation.DEVICE + if (device is not None and device.type == "cuda") + else EmbeddingLocation.HOST, + ) + for local_rows, table in zip( + self._local_rows, config.local_embedding_tables + ) + ], + pooling_mode=self._pooling, + ) + ) + if device is not None and device.type != "meta": + self._emb_module.initialize_weights() + + @staticmethod + def to_sparse_type(data_type: DataType) -> SparseType: + if data_type == DataType.FP16: + return SparseType.FP16 + elif data_type == DataType.INT8: + return SparseType.INT8 + elif data_type == DataType.INT4: + return SparseType.INT4 + elif data_type == DataType.INT2: + return SparseType.INT2 + else: + raise ValueError(f"Invalid DataType {data_type}") + + def init_parameters(self) -> None: + pass + + @property + def emb_module( + self, + ) -> IntNBitTableBatchedEmbeddingBagsCodegen: + return self._emb_module + + def forward(self, features: KeyedJaggedTensor) -> KeyedTensor: + values = self.emb_module( + indices=features.values().int(), + offsets=features.offsets().int(), + per_sample_weights=features.weights_or_none(), + ) + return KeyedTensor( + keys=self._emb_names, + values=values, + length_per_key=self._lengths_per_emb, + ) + + def named_buffers( + self, prefix: str = "", recurse: bool = True + ) -> Iterator[Tuple[str, torch.Tensor]]: + for config, weight in zip( + self._config.local_embedding_tables, + self.emb_module.split_embedding_weights(), + ): + yield append_prefix(prefix, f"{config.name}.weight"), weight[0] + + def split_embedding_weights(self) -> List[torch.Tensor]: + return [ + weight + for weight, _ in self.emb_module.split_embedding_weights( + split_scale_shifts=False + ) + ] + + @classmethod + def from_float(cls, module: BaseEmbeddingBag) -> "QuantBatchedEmbeddingBag": + assert hasattr( + module, "qconfig" + ), "EmbeddingBagCollectionInterface input float module must have qconfig defined" + + def _to_data_type(dtype: torch.dtype) -> DataType: + if dtype == torch.quint8 or dtype == torch.qint8: + return DataType.INT8 + elif dtype == torch.quint4 or dtype == torch.qint4: + return DataType.INT4 + elif dtype == torch.quint2 or dtype == torch.qint2: + return DataType.INT2 + else: + raise Exception(f"Invalid data type {dtype}") + + # pyre-ignore [16] + data_type = _to_data_type(module.qconfig.weight().dtype) + sparse_type = QuantBatchedEmbeddingBag.to_sparse_type(data_type) + + state_dict = dict( + itertools.chain(module.named_buffers(), module.named_parameters()) + ) + device = next(iter(state_dict.values())).device + + # Adjust config to quantized version. + # This obviously doesn't work for column-wise sharding. + # pyre-ignore [29] + config = copy.deepcopy(module.config()) + config.data_type = data_type + for table in config.local_embedding_tables: + table.local_cols = rounded_row_size_in_bytes(table.local_cols, sparse_type) + if table.local_metadata is not None: + table.local_metadata.shard_lengths = [ + table.local_rows, + table.local_cols, + ] + + ret = QuantBatchedEmbeddingBag(config=config, device=device) + + # Quantize weights. + quant_weight_list = [] + for _, weight in state_dict.items(): + quantized_weights = torch.ops.fbgemm.FloatToFusedNBitRowwiseQuantizedSBHalf( + weight, DATA_TYPE_NUM_BITS[data_type] + ) + # weight and 4 byte scale shift (2xfp16) + quant_weight = quantized_weights[:, :-4] + scale_shift = quantized_weights[:, -4:] + + quant_weight_list.append((quant_weight, scale_shift)) + ret.emb_module.assign_embedding_weights(quant_weight_list) + + return ret + + +class GroupedPooledEmbeddingsLookup(BaseEmbeddingLookup): + def __init__( + self, + grouped_configs: List[GroupedEmbeddingConfig], + grouped_score_configs: List[GroupedEmbeddingConfig], + device: Optional[torch.device] = None, + fused_params: Optional[Dict[str, Any]] = None, + pg: Optional[dist.ProcessGroup] = None, + feature_processor: Optional[BaseGroupedFeatureProcessor] = None, + ) -> None: + def _create_lookup( + config: GroupedEmbeddingConfig, + ) -> BaseEmbeddingBag: + if config.compute_kernel == EmbeddingComputeKernel.BATCHED_DENSE: + return BatchedDenseEmbeddingBag( + config=config, + pg=pg, + device=device, + ) + elif config.compute_kernel == EmbeddingComputeKernel.BATCHED_FUSED: + return BatchedFusedEmbeddingBag( + config=config, + pg=pg, + device=device, + fused_params=fused_params, + ) + elif config.compute_kernel == EmbeddingComputeKernel.DENSE: + return GroupedEmbeddingBag( + config=config, + sparse=False, + device=device, + ) + elif config.compute_kernel == EmbeddingComputeKernel.SPARSE: + return GroupedEmbeddingBag( + config=config, + sparse=True, + device=device, + ) + elif config.compute_kernel == EmbeddingComputeKernel.BATCHED_QUANT: + return QuantBatchedEmbeddingBag( + config=config, + pg=pg, + device=device, + ) + else: + raise ValueError( + f"Compute kernel not supported {config.compute_kernel}" + ) + + super().__init__() + self._emb_modules: nn.ModuleList[BaseEmbeddingBag] = nn.ModuleList() + for config in grouped_configs: + self._emb_modules.append(_create_lookup(config)) + + self._score_emb_modules: nn.ModuleList[BaseEmbeddingBag] = nn.ModuleList() + for config in grouped_score_configs: + self._score_emb_modules.append(_create_lookup(config)) + + self._id_list_feature_splits: List[int] = [] + for config in grouped_configs: + self._id_list_feature_splits.append(config.num_features()) + self._id_score_list_feature_splits: List[int] = [] + for config in grouped_score_configs: + self._id_score_list_feature_splits.append(config.num_features()) + + # return a dummy empty tensor + # when grouped_configs and grouped_score_configs are empty + self.register_buffer( + "_dummy_embs_tensor", + torch.empty( + [0], + dtype=torch.float32, + device=device, + requires_grad=True, + ), + ) + + self.grouped_configs = grouped_configs + self.grouped_score_configs = grouped_score_configs + self._feature_processor = feature_processor + + def forward( + self, + sparse_features: SparseFeatures, + ) -> torch.Tensor: + assert ( + sparse_features.id_list_features is not None + or sparse_features.id_score_list_features is not None + ) + embeddings: List[torch.Tensor] = [] + if len(self._emb_modules) > 0: + assert sparse_features.id_list_features is not None + id_list_features_by_group = sparse_features.id_list_features.split( + self._id_list_feature_splits, + ) + for config, emb_op, features in zip( + self.grouped_configs, self._emb_modules, id_list_features_by_group + ): + if ( + config.has_feature_processor + and self._feature_processor is not None + and isinstance( + self._feature_processor, GroupedPositionWeightedModule + ) + ): + features = self._feature_processor(features) + embeddings.append(emb_op(features).values()) + if len(self._score_emb_modules) > 0: + assert sparse_features.id_score_list_features is not None + id_score_list_features_by_group = ( + sparse_features.id_score_list_features.split( + self._id_score_list_feature_splits, + ) + ) + for emb_op, features in zip( + self._score_emb_modules, id_score_list_features_by_group + ): + embeddings.append(emb_op(features).values()) + + if len(embeddings) == 0: + # a hack for empty ranks + batch_size: int = ( + sparse_features.id_list_features.stride() + if sparse_features.id_list_features is not None + # pyre-fixme[16]: `Optional` has no attribute `stride`. + else sparse_features.id_score_list_features.stride() + ) + return self._dummy_embs_tensor.view(batch_size, 0) + elif len(embeddings) == 1: + return embeddings[0] + else: + return torch.cat(embeddings, dim=1) + + def state_dict( + self, + destination: Optional[Dict[str, Any]] = None, + prefix: str = "", + keep_vars: bool = False, + ) -> Dict[str, Any]: + if destination is None: + destination = OrderedDict() + # pyre-ignore [16] + destination._metadata = OrderedDict() + + for emb_module in self._emb_modules: + emb_module.state_dict(destination, prefix, keep_vars) + for emb_module in self._score_emb_modules: + emb_module.state_dict(destination, prefix, keep_vars) + + return destination + + def load_state_dict( + self, + state_dict: "OrderedDict[str, torch.Tensor]", + strict: bool = True, + ) -> _IncompatibleKeys: + # pyre-ignore [6] + m1, u1 = _load_state_dict(self._emb_modules, state_dict) + # pyre-ignore [6] + m2, u2 = _load_state_dict(self._score_emb_modules, state_dict) + return _IncompatibleKeys(missing_keys=m1 + m2, unexpected_keys=u1 + u2) + + def named_parameters( + self, prefix: str = "", recurse: bool = True + ) -> Iterator[Tuple[str, nn.Parameter]]: + for emb_module in self._emb_modules: + yield from emb_module.named_parameters(prefix, recurse) + for emb_module in self._score_emb_modules: + yield from emb_module.named_parameters(prefix, recurse) + + def named_buffers( + self, prefix: str = "", recurse: bool = True + ) -> Iterator[Tuple[str, torch.Tensor]]: + for emb_module in self._emb_modules: + yield from emb_module.named_buffers(prefix, recurse) + for emb_module in self._score_emb_modules: + yield from emb_module.named_buffers(prefix, recurse) + + def sparse_grad_parameter_names( + self, destination: Optional[List[str]] = None, prefix: str = "" + ) -> List[str]: + destination = [] if destination is None else destination + for emb_module in self._emb_modules: + emb_module.sparse_grad_parameter_names(destination, prefix) + for emb_module in self._score_emb_modules: + emb_module.sparse_grad_parameter_names(destination, prefix) + return destination diff --git a/torchrec/distributed/embedding_sharding.py b/torchrec/distributed/embedding_sharding.py new file mode 100644 index 000000000..53f8f2b10 --- /dev/null +++ b/torchrec/distributed/embedding_sharding.py @@ -0,0 +1,368 @@ +#!/usr/bin/env python3 + +import abc +from dataclasses import dataclass, field +from typing import List, Tuple, Optional, Dict, Any + +import torch +import torch.distributed as dist +from torch import nn +from torch.distributed._sharding_spec import ShardMetadata +from torchrec.distributed.dist_data import KJTAllToAll +from torchrec.distributed.embedding_types import ( + GroupedEmbeddingConfig, + BaseEmbeddingLookup, + SparseFeatures, + EmbeddingComputeKernel, + ShardedEmbeddingTable, + BaseGroupedFeatureProcessor, + SparseFeaturesList, +) +from torchrec.distributed.types import Awaitable +from torchrec.modules.embedding_configs import ( + PoolingType, + DataType, +) +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor +from torchrec.types import Multistreamable + + +@dataclass +class SequenceShardingContext(Multistreamable): + """ + SequenceEmbeddingAll2all has the same comm pattern as KJTAll2all. + Stores KJTAll2all context and reuse it in SequenceEmbeddingAll2all. + + features_before_input_dist: stores the original KJT before input dist + input_splits: stores the input splits of KJT ALl2all + input_splits: stores the output splits of KJT ALl2all + unbucketize_permute_tensor: stores the permute order of + KJT bucketize (forrow-wise sharding only) + lengths_after_input_dist: stores the KJT length after input dist + """ + + features_before_input_dist: Optional[KeyedJaggedTensor] = None + input_splits: List[int] = field(default_factory=list) + output_splits: List[int] = field(default_factory=list) + unbucketize_permute_tensor: Optional[torch.Tensor] = None + lengths_after_input_dist: Optional[torch.Tensor] = None + + def record_stream(self, stream: torch.cuda.streams.Stream) -> None: + if self.features_before_input_dist is not None: + self.features_before_input_dist.record_stream(stream) + if self.unbucketize_permute_tensor is not None: + self.unbucketize_permute_tensor.record_stream(stream) + if self.lengths_after_input_dist is not None: + self.lengths_after_input_dist.record_stream(stream) + + +class SparseFeaturesAllToAllAwaitable(Awaitable[SparseFeatures]): + def __init__( + self, + id_list_features_awaitable: Optional[Awaitable[KeyedJaggedTensor]], + id_score_list_features_awaitable: Optional[Awaitable[KeyedJaggedTensor]], + ) -> None: + super().__init__() + self._id_list_features_awaitable = id_list_features_awaitable + self._id_score_list_features_awaitable = id_score_list_features_awaitable + + def wait(self) -> SparseFeatures: + return SparseFeatures( + id_list_features=self._id_list_features_awaitable.wait() + if self._id_list_features_awaitable is not None + else None, + id_score_list_features=self._id_score_list_features_awaitable.wait() + if self._id_score_list_features_awaitable is not None + else None, + ) + + +def bucketize_kjt_before_all2all( + kjt: KeyedJaggedTensor, + num_buckets: int, + block_sizes: torch.Tensor, + output_permute: bool = False, + bucketize_pos: bool = False, +) -> Tuple[KeyedJaggedTensor, Optional[torch.Tensor]]: + """ + Bucketize the `values` in KeyedJaggedTensor into `num_buckets` buckets, + `lengths` are readjusted based on the bucketization results. + + Note: This function should be used only for row-wise sharding before calling SparseFeaturesAllToAll + + Args: + num_buckets (int): The number of buckets to bucketize the values into. + block_sizes: (torch.Tensor): The bucket sizes for the keyed dimension. + output_permute (bool): Output the memory location mapping from the unbucketized values to bucketized values or not. + bucketize_pos (bool): Output the changed position of the bucketized values or not. + Returns: + The bucketized `KeyedJaggedTensor` and the optional permute mapping from the unbucketized values to bucketized values. + """ + num_features = len(kjt.keys()) + assert ( + block_sizes.numel() == num_features + ), f"Expecting block sizes for {num_features} features, but {block_sizes.numel()} received." + + # kernel expects them to be same type, cast to avoid type mismatch + block_sizes_new_type = block_sizes.type(kjt.values().type()) + ( + bucketized_lengths, + bucketized_indices, + bucketized_weights, + pos, + unbucketize_permute, + ) = torch.ops.fbgemm.block_bucketize_sparse_features( + kjt.lengths().view(-1), + kjt.values(), + bucketize_pos=bucketize_pos, + sequence=output_permute, + block_sizes=block_sizes_new_type, + my_size=num_buckets, + weights=kjt.weights_or_none(), + ) + + return ( + KeyedJaggedTensor( + # duplicate keys will be resolved by AllToAll + keys=kjt.keys() * num_buckets, + values=bucketized_indices, + weights=pos if bucketize_pos else bucketized_weights, + lengths=bucketized_lengths.view(-1), + offsets=None, + stride=kjt.stride(), + length_per_key=None, + offset_per_key=None, + index_per_key=None, + ), + unbucketize_permute, + ) + + +class SparseFeaturesAllToAll(nn.Module): + def __init__( + self, + pg: dist.ProcessGroup, + id_list_features_per_rank: List[int], + id_score_list_features_per_rank: List[int], + device: Optional[torch.device] = None, + stagger: int = 1, + ) -> None: + super().__init__() + self._id_list_features_all2all = KJTAllToAll( + pg, id_list_features_per_rank, device, stagger + ) + self._id_score_list_features_all2all = KJTAllToAll( + pg, id_score_list_features_per_rank, device, stagger + ) + + def forward( + self, + sparse_features: SparseFeatures, + ) -> Awaitable[SparseFeatures]: + return SparseFeaturesAllToAllAwaitable( + id_list_features_awaitable=self._id_list_features_all2all.forward( + sparse_features.id_list_features + ) + if sparse_features.id_list_features is not None + else None, + id_score_list_features_awaitable=self._id_score_list_features_all2all.forward( + sparse_features.id_score_list_features + ) + if sparse_features.id_score_list_features is not None + else None, + ) + + +# group tables by DataType, PoolingType, Weighted, and EmbeddingComputeKernel. +def group_tables( + tables_per_rank: List[List[ShardedEmbeddingTable]], +) -> Tuple[List[List[GroupedEmbeddingConfig]], List[List[GroupedEmbeddingConfig]]]: + def _group_tables_per_rank( + embedding_tables: List[ShardedEmbeddingTable], + ) -> Tuple[List[GroupedEmbeddingConfig], List[GroupedEmbeddingConfig]]: + grouped_embedding_configs: List[GroupedEmbeddingConfig] = [] + score_grouped_embedding_configs: List[GroupedEmbeddingConfig] = [] + for data_type in DataType: + for pooling in PoolingType: + for is_weighted in [True, False]: + for has_feature_processor in [True, False]: + for compute_kernel in [ + EmbeddingComputeKernel.DENSE, + EmbeddingComputeKernel.SPARSE, + EmbeddingComputeKernel.BATCHED_DENSE, + EmbeddingComputeKernel.BATCHED_FUSED, + EmbeddingComputeKernel.BATCHED_QUANT, + ]: + global_grouped_tables: List[ShardedEmbeddingTable] = [] + global_grouped_score_tables: List[ + ShardedEmbeddingTable + ] = [] + local_grouped_tables: List[ShardedEmbeddingTable] = [] + local_grouped_score_tables: List[ShardedEmbeddingTable] = [] + for table in embedding_tables: + if table.compute_kernel in [ + EmbeddingComputeKernel.BATCHED_FUSED_UVM, + EmbeddingComputeKernel.BATCHED_FUSED_UVM_CACHING, + ]: + compute_kernel_type = ( + EmbeddingComputeKernel.BATCHED_FUSED + ) + else: + compute_kernel_type = table.compute_kernel + if ( + table.data_type == data_type + and table.pooling == pooling + and table.is_weighted == is_weighted + and table.has_feature_processor + == has_feature_processor + and compute_kernel_type == compute_kernel + ): + + # if not empty on the rank, add to local configs + table_not_empty = ( + table.local_rows != 0 and table.local_cols != 0 + ) + if table.is_weighted: + global_grouped_score_tables.append(table) + if table_not_empty: + local_grouped_score_tables.append(table) + else: + global_grouped_tables.append(table) + if table_not_empty: + local_grouped_tables.append(table) + if local_grouped_tables: + grouped_embedding_configs.append( + GroupedEmbeddingConfig( + data_type=data_type, + pooling=pooling, + is_weighted=is_weighted, + has_feature_processor=has_feature_processor, + compute_kernel=compute_kernel, + global_embedding_tables=global_grouped_tables, + local_embedding_tables=local_grouped_tables, + ) + ) + if local_grouped_score_tables: + score_grouped_embedding_configs.append( + GroupedEmbeddingConfig( + data_type=data_type, + pooling=pooling, + is_weighted=is_weighted, + has_feature_processor=has_feature_processor, + compute_kernel=compute_kernel, + global_embedding_tables=global_grouped_score_tables, + local_embedding_tables=local_grouped_score_tables, + ) + ) + return grouped_embedding_configs, score_grouped_embedding_configs + + grouped_embedding_configs_by_rank: List[List[GroupedEmbeddingConfig]] = [] + score_grouped_embedding_configs_by_rank: List[List[GroupedEmbeddingConfig]] = [] + for tables in tables_per_rank: + ( + grouped_embedding_configs, + score_grouped_embedding_configs, + ) = _group_tables_per_rank(tables) + grouped_embedding_configs_by_rank.append(grouped_embedding_configs) + score_grouped_embedding_configs_by_rank.append(score_grouped_embedding_configs) + return ( + grouped_embedding_configs_by_rank, + score_grouped_embedding_configs_by_rank, + ) + + +class SparseFeaturesListAwaitable(Awaitable[SparseFeaturesList]): + def __init__( + self, + awaitables: List[Awaitable[SparseFeatures]], + ) -> None: + super().__init__() + self.awaitables = awaitables + + def wait(self) -> SparseFeaturesList: + return SparseFeaturesList([w.wait() for w in self.awaitables]) + + +class BaseSparseFeaturesDist(abc.ABC, nn.Module): + """ + Converts input from data-parallel to model-parallel. + """ + + @abc.abstractmethod + def forward( + self, + sparse_features: SparseFeatures, + ) -> Awaitable[SparseFeatures]: + pass + + +class BasePooledEmbeddingDist(abc.ABC, nn.Module): + """ + Converts output of pooled EmbeddingLookup + from model-parallel to data-parallel. + """ + + @abc.abstractmethod + def forward(self, local_embs: torch.Tensor) -> Awaitable[torch.Tensor]: + pass + + +class BaseSequenceEmbeddingDist(abc.ABC, nn.Module): + """ + Converts output of sequence EmbeddingLookup + from model-parallel to data-parallel. + """ + + pass + + @abc.abstractmethod + def forward( + self, sharding_ctx: SequenceShardingContext, local_embs: torch.Tensor + ) -> Awaitable[torch.Tensor]: + pass + + +class EmbeddingSharding(abc.ABC): + """ + Used to implement different sharding type for EmbeddingBagCollection, e.g. table_wise. + """ + + @abc.abstractmethod + def create_input_dist(self) -> BaseSparseFeaturesDist: + pass + + @abc.abstractmethod + def create_pooled_output_dist(self) -> BasePooledEmbeddingDist: + pass + + @abc.abstractmethod + def create_sequence_output_dist(self) -> BaseSequenceEmbeddingDist: + pass + + @abc.abstractmethod + def create_lookup( + self, + fused_params: Optional[Dict[str, Any]], + feature_processor: Optional[BaseGroupedFeatureProcessor] = None, + ) -> BaseEmbeddingLookup: + pass + + @abc.abstractmethod + def embedding_dims(self) -> List[int]: + pass + + @abc.abstractmethod + def embedding_shard_metadata(self) -> List[Optional[ShardMetadata]]: + pass + + @abc.abstractmethod + def embedding_names(self) -> List[str]: + pass + + @abc.abstractmethod + def id_list_feature_names(self) -> List[str]: + pass + + @abc.abstractmethod + def id_score_list_feature_names(self) -> List[str]: + pass diff --git a/torchrec/distributed/embedding_types.py b/torchrec/distributed/embedding_types.py new file mode 100644 index 000000000..a106487de --- /dev/null +++ b/torchrec/distributed/embedding_types.py @@ -0,0 +1,267 @@ +#!/usr/bin/env python3 + +import abc +from dataclasses import dataclass +from enum import Enum, unique +from typing import List, Optional, Dict, Any, TypeVar, Iterator + +import torch +from torch import nn +from torch.distributed._sharded_tensor import ShardMetadata +from torchrec.distributed.types import ( + ModuleSharder, + ShardingType, + ParameterStorage, +) +from torchrec.modules.embedding_configs import ( + PoolingType, + DataType, + EmbeddingTableConfig, +) +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor +from torchrec.types import Multistreamable + + +@unique +class OptimType(Enum): + SGD = "SGD" + LARS_SGD = "LARS_SGD" + LAMB = "LAMB" + PARTIAL_ROWWISE_LAMB = "PARTIAL_ROWWISE_LAMB" + ADAM = "ADAM" + PARTIAL_ROWWISE_ADAM = "PARTIAL_ROWWISE_ADAM" + ADAGRAD = "ADAGRAD" + ROWWISE_ADAGRAD = "ROWWISE_ADAGRAD" + + +@unique +class EmbeddingComputeKernel(Enum): + DENSE = "dense" + SPARSE = "sparse" + BATCHED_DENSE = "batched_dense" + BATCHED_FUSED = "batched_fused" + BATCHED_FUSED_UVM = "batched_fused_uvm" + BATCHED_FUSED_UVM_CACHING = "batched_fused_uvm_caching" + BATCHED_QUANT = "batched_quant" + + +@dataclass +class SparseFeatures(Multistreamable): + id_list_features: Optional[KeyedJaggedTensor] = None + id_score_list_features: Optional[KeyedJaggedTensor] = None + + def record_stream(self, stream: torch.cuda.streams.Stream) -> None: + if self.id_list_features is not None: + self.id_list_features.record_stream(stream) + if self.id_score_list_features is not None: + self.id_score_list_features.record_stream(stream) + + +class SparseFeaturesList(Multistreamable): + def __init__(self, features: List[SparseFeatures]) -> None: + self.features = features + + def __len__(self) -> int: + return len(self.features) + + def __setitem__(self, key: int, item: SparseFeatures) -> None: + self.features[key] = item + + def __getitem__(self, key: int) -> SparseFeatures: + return self.features[key] + + def __iter__(self) -> Iterator[SparseFeatures]: + return iter(self.features) + + def record_stream(self, stream: torch.cuda.streams.Stream) -> None: + for feature in self.features: + feature.record_stream(stream) + + +@dataclass +class ShardedConfig: + local_rows: int = 0 + local_cols: int = 0 + # The block size of sharding dim on each shard. + # mainly used in cw, not applicable in tw/dp + block_size: int = 0 + + +@dataclass +class ShardedMetaConfig(ShardedConfig): + local_metadata: Optional[ShardMetadata] = None + + +@dataclass +class EmbeddingAttributes: + compute_kernel: EmbeddingComputeKernel = EmbeddingComputeKernel.DENSE + + +@dataclass +class ShardedEmbeddingTable( + ShardedMetaConfig, + EmbeddingAttributes, + EmbeddingTableConfig, +): + pass + + +@dataclass +class GroupedEmbeddingConfig: + data_type: DataType + pooling: PoolingType + is_weighted: bool + has_feature_processor: bool + compute_kernel: EmbeddingComputeKernel + # a global logical view of the grouped embedding tables (including the tables that + # does not exist locally, i.e. tw, tw-rw sharding) + global_embedding_tables: List[ShardedEmbeddingTable] + # a local grouped embedding tables that are being created on this rank + local_embedding_tables: List[ShardedEmbeddingTable] + + def feature_hash_sizes(self) -> List[int]: + feature_hash_sizes = [] + for table in self.local_embedding_tables: + feature_hash_sizes.extend(table.num_features() * [table.num_embeddings]) + return feature_hash_sizes + + def num_features(self) -> int: + num_features = 0 + for table in self.local_embedding_tables: + num_features += table.num_features() + return num_features + + def dim_sum(self) -> int: + dim_sum = 0 + for table in self.local_embedding_tables: + dim_sum += table.num_features() * table.local_cols + return dim_sum + + def feature_names(self) -> List[str]: + feature_names = [] + for table in self.local_embedding_tables: + feature_names.extend(table.feature_names) + return feature_names + + def embedding_dims(self) -> List[int]: + embedding_dims = [] + for table in self.local_embedding_tables: + embedding_dims.extend([table.local_cols] * table.num_features()) + return embedding_dims + + def embedding_names(self) -> List[str]: + embedding_names = [] + for table in self.local_embedding_tables: + embedding_names.extend(table.embedding_names) + return embedding_names + + def embedding_shard_metadata(self) -> List[Optional[ShardMetadata]]: + embedding_shard_metadata: List[Optional[ShardMetadata]] = [] + for table in self.local_embedding_tables: + for _ in table.feature_names: + embedding_shard_metadata.append(table.local_metadata) + return embedding_shard_metadata + + +class BaseEmbeddingLookup(abc.ABC, nn.Module): + """ + Interface implemented by different embedding implementations: + e.g. one, which relies on nn.EmbeddingBag or table-batched one, etc. + """ + + @abc.abstractmethod + def forward( + self, + sparse_features: SparseFeatures, + ) -> torch.Tensor: + pass + + def sparse_grad_parameter_names( + self, destination: Optional[List[str]] = None, prefix: str = "" + ) -> List[str]: + destination = [] if destination is None else destination + return destination + + +M = TypeVar("M", bound=nn.Module) + + +class BaseEmbeddingSharder(ModuleSharder[M]): + def __init__(self, fused_params: Optional[Dict[str, Any]] = None) -> None: + self._fused_params = fused_params + + def sharding_types(self, compute_device_type: str) -> List[str]: + types = [ + ShardingType.DATA_PARALLEL.value, + ShardingType.TABLE_WISE.value, + ShardingType.ROW_WISE.value, + ] + if compute_device_type in {"cuda"}: + # TWRW supported for CUDA only + types.append(ShardingType.TABLE_ROW_WISE.value) + + return types + + def compute_kernels( + self, sharding_type: str, compute_device_type: str + ) -> List[str]: + ret = [ + EmbeddingComputeKernel.DENSE.value, + EmbeddingComputeKernel.BATCHED_DENSE.value, + ] + if sharding_type != ShardingType.DATA_PARALLEL.value: + ret += [ + EmbeddingComputeKernel.BATCHED_FUSED.value, + EmbeddingComputeKernel.SPARSE.value, + ] + if compute_device_type in {"cuda"}: + ret += [ + EmbeddingComputeKernel.BATCHED_FUSED_UVM.value, + EmbeddingComputeKernel.BATCHED_FUSED_UVM_CACHING.value, + ] + return ret + + @property + def fused_params(self) -> Optional[Dict[str, Any]]: + return self._fused_params + + def storage_usage( + self, tensor: torch.Tensor, compute_device_type: str, compute_kernel: str + ) -> Dict[str, int]: + """ + List of system resources and corresponding usage given a compute device and + compute kernel + """ + tensor_bytes = tensor.element_size() * tensor.nelement() + if compute_kernel in { + EmbeddingComputeKernel.BATCHED_FUSED_UVM.value, + EmbeddingComputeKernel.BATCHED_FUSED_UVM_CACHING.value, + }: + assert compute_device_type in {"cuda"} + return {ParameterStorage.DDR.value: tensor_bytes} + else: + assert compute_device_type in {"cuda", "cpu"} + storage_map = {"cuda": ParameterStorage.HBM, "cpu": ParameterStorage.DDR} + return { + storage_map[compute_device_type].value: tensor.element_size() + * tensor.nelement() + } + + +class BaseGroupedFeatureProcessor(nn.Module): + """ + abstract base class for grouped feature processor + """ + + @abc.abstractmethod + def forward( + self, + features: KeyedJaggedTensor, + ) -> KeyedJaggedTensor: + pass + + def sparse_grad_parameter_names( + self, destination: Optional[List[str]] = None, prefix: str = "" + ) -> List[str]: + destination = [] if destination is None else destination + return destination diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py new file mode 100644 index 000000000..940ae6cd4 --- /dev/null +++ b/torchrec/distributed/embeddingbag.py @@ -0,0 +1,800 @@ +#!/usr/bin/env python3 + +import copy +from collections import OrderedDict +from typing import ( + List, + Dict, + Optional, + Type, + Any, + TypeVar, + Mapping, + Union, + Tuple, + Iterator, + Set, +) + +import torch +from torch import Tensor +from torch import nn +from torch.distributed._sharding_spec import ( + EnumerableShardingSpec, +) +from torch.nn.modules.module import _IncompatibleKeys +from torchrec.distributed.cw_sharding import CwEmbeddingSharding +from torchrec.distributed.dp_sharding import DpEmbeddingSharding +from torchrec.distributed.embedding_sharding import ( + EmbeddingSharding, + SparseFeaturesListAwaitable, +) +from torchrec.distributed.embedding_types import ( + SparseFeatures, + BaseEmbeddingSharder, + EmbeddingComputeKernel, + BaseEmbeddingLookup, + SparseFeaturesList, +) +from torchrec.distributed.rw_sharding import RwEmbeddingSharding +from torchrec.distributed.tw_sharding import TwEmbeddingSharding +from torchrec.distributed.twrw_sharding import TwRwEmbeddingSharding +from torchrec.distributed.types import ( + Awaitable, + LazyAwaitable, + ParameterSharding, + ParameterStorage, + ShardedModule, + ShardingType, + ShardedModuleContext, + ShardedTensor, + ModuleSharder, + ShardingEnv, +) +from torchrec.distributed.utils import append_prefix +from torchrec.modules.embedding_configs import EmbeddingTableConfig, PoolingType +from torchrec.modules.embedding_modules import ( + EmbeddingBagCollection, + EmbeddingBagCollectionInterface, +) +from torchrec.optim.fused import FusedOptimizerModule +from torchrec.optim.keyed import KeyedOptimizer, CombinedOptimizer +from torchrec.quant.embedding_modules import ( + EmbeddingBagCollection as QuantEmbeddingBagCollection, +) +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor + + +def create_embedding_sharding( + sharding_type: str, + embedding_configs: List[ + Tuple[EmbeddingTableConfig, ParameterSharding, torch.Tensor] + ], + env: ShardingEnv, + device: Optional[torch.device] = None, +) -> EmbeddingSharding: + pg = env.process_group + if device is not None and device.type == "meta": + replace_placement_with_meta_device(embedding_configs) + if pg is not None: + if sharding_type == ShardingType.TABLE_WISE.value: + return TwEmbeddingSharding(embedding_configs, pg, device) + elif sharding_type == ShardingType.ROW_WISE.value: + return RwEmbeddingSharding(embedding_configs, pg, device) + elif sharding_type == ShardingType.DATA_PARALLEL.value: + return DpEmbeddingSharding(embedding_configs, env, device) + elif sharding_type == ShardingType.TABLE_ROW_WISE.value: + return TwRwEmbeddingSharding(embedding_configs, pg, device) + elif sharding_type == ShardingType.COLUMN_WISE.value: + return CwEmbeddingSharding(embedding_configs, pg, device) + else: + raise ValueError(f"Sharding not supported {sharding_type}") + else: + if sharding_type == ShardingType.DATA_PARALLEL.value: + return DpEmbeddingSharding(embedding_configs, env, device) + else: + raise ValueError(f"Sharding not supported {sharding_type}") + + +def replace_placement_with_meta_device( + embedding_configs: List[ + Tuple[EmbeddingTableConfig, ParameterSharding, torch.Tensor] + ] +) -> None: + """Placement device and tensor device could be unmatched in some + scenarios, e.g. passing meta device to DMP and passing cuda + to EmbeddingShardingPlanner. We need to make device consistent + after getting sharding planner. + """ + for config in embedding_configs: + sharding_spec = config[1].sharding_spec + if sharding_spec is None: + continue + if isinstance(sharding_spec, EnumerableShardingSpec): + for shard_metadata in sharding_spec.shards: + placement = shard_metadata.placement + if isinstance(placement, str): + placement = torch.distributed._remote_device(placement) + assert isinstance(placement, torch.distributed._remote_device) + placement._device = torch.device("meta") + shard_metadata.placement = placement + else: + # We only support EnumerableShardingSpec at present. + raise RuntimeError( + f"Unsupported ShardingSpec {type(sharding_spec)} with meta device" + ) + + +def filter_state_dict( + state_dict: "OrderedDict[str, torch.Tensor]", name: str +) -> "OrderedDict[str, torch.Tensor]": + rtn_dict = OrderedDict() + for key, value in state_dict.items(): + if key.startswith(name): + # + 1 to length is to remove the '.' after the key + rtn_dict[key[len(name) + 1 :]] = value + return rtn_dict + + +def _create_embedding_configs_by_sharding( + module: EmbeddingBagCollectionInterface, + table_name_to_parameter_sharding: Dict[str, ParameterSharding], + prefix: str, +) -> Dict[str, List[Tuple[EmbeddingTableConfig, ParameterSharding, torch.Tensor]]]: + shared_feature: Dict[str, bool] = {} + for embedding_config in module.embedding_bag_configs: + if not embedding_config.feature_names: + embedding_config.feature_names = [embedding_config.name] + for feature_name in embedding_config.feature_names: + if feature_name not in shared_feature: + shared_feature[feature_name] = False + else: + shared_feature[feature_name] = True + + sharding_type_to_embedding_configs: Dict[ + str, List[Tuple[EmbeddingTableConfig, ParameterSharding, torch.Tensor]] + ] = {} + state_dict = module.state_dict() + for config in module.embedding_bag_configs: + table_name = config.name + assert table_name in table_name_to_parameter_sharding + parameter_sharding = table_name_to_parameter_sharding[table_name] + if parameter_sharding.compute_kernel not in [ + kernel.value for kernel in EmbeddingComputeKernel + ]: + raise ValueError( + f"Compute kernel not supported {parameter_sharding.compute_kernel}" + ) + embedding_names: List[str] = [] + for feature_name in config.feature_names: + if shared_feature[feature_name]: + embedding_names.append(feature_name + "@" + config.name) + else: + embedding_names.append(feature_name) + + param_name = prefix + table_name + ".weight" + assert param_name in state_dict + param = state_dict[param_name] + + if parameter_sharding.sharding_type not in sharding_type_to_embedding_configs: + sharding_type_to_embedding_configs[parameter_sharding.sharding_type] = [] + sharding_type_to_embedding_configs[parameter_sharding.sharding_type].append( + ( + EmbeddingTableConfig( + num_embeddings=config.num_embeddings, + embedding_dim=config.embedding_dim, + name=config.name, + data_type=config.data_type, + feature_names=copy.deepcopy(config.feature_names), + pooling=config.pooling, + is_weighted=module.is_weighted, + has_feature_processor=False, + embedding_names=embedding_names, + weight_init_max=config.weight_init_max, + weight_init_min=config.weight_init_min, + ), + parameter_sharding, + param, + ) + ) + return sharding_type_to_embedding_configs + + +class EmbeddingCollectionAwaitable(LazyAwaitable[KeyedTensor]): + def __init__( + self, + awaitables: List[Awaitable[torch.Tensor]], + embedding_dims: List[int], + embedding_names: List[str], + ) -> None: + super().__init__() + self._awaitables = awaitables + self._embedding_dims = embedding_dims + self._embedding_names = embedding_names + + def wait(self) -> KeyedTensor: + embeddings = [w.wait() for w in self._awaitables] + if len(embeddings) == 1: + embeddings = embeddings[0] + else: + embeddings = torch.cat(embeddings, dim=1) + return KeyedTensor( + keys=self._embedding_names, + length_per_key=self._embedding_dims, + values=embeddings, + key_dim=1, + ) + + +class ShardedEmbeddingBagCollection( + ShardedModule[ + SparseFeaturesList, + List[torch.Tensor], + KeyedTensor, + ], + FusedOptimizerModule, +): + """ + Sharded implementation of EmbeddingBagCollection. + This is part of public API to allow for manual data dist pipelining. + """ + + def __init__( + self, + module: EmbeddingBagCollectionInterface, + table_name_to_parameter_sharding: Dict[str, ParameterSharding], + env: ShardingEnv, + fused_params: Optional[Dict[str, Any]] = None, + device: Optional[torch.device] = None, + ) -> None: + super().__init__() + sharding_type_to_embedding_configs = _create_embedding_configs_by_sharding( + module, table_name_to_parameter_sharding, "embedding_bags." + ) + self._sharding_type_to_sharding: Dict[str, EmbeddingSharding] = { + sharding_type: create_embedding_sharding( + sharding_type, embedding_confings, env, device + ) + for sharding_type, embedding_confings in sharding_type_to_embedding_configs.items() + } + + self._is_weighted: bool = module.is_weighted + self._device = device + self._create_lookups(fused_params) + self._output_dists: nn.ModuleList[nn.Module] = nn.ModuleList() + self._embedding_names: List[str] = [] + self._embedding_dims: List[int] = [] + self._input_dists: nn.ModuleList[nn.Module] = nn.ModuleList() + self._feature_splits: List[int] = [] + self._features_order: List[int] = [] + + # forward pass flow control + self._has_uninitialized_input_dist: bool = True + self._has_uninitialized_output_dist: bool = True + self._has_features_permute: bool = True + + # Get all fused optimizers and combine them. + optims = [] + for lookup in self._lookups: + for _, module in lookup.named_modules(): + if isinstance(module, FusedOptimizerModule): + # modify param keys to match EmbeddingBagCollection + params: Mapping[str, Union[torch.Tensor, ShardedTensor]] = {} + for param_key, weight in module.fused_optimizer.params.items(): + params["embedding_bags." + param_key] = weight + module.fused_optimizer.params = params + optims.append(("", module.fused_optimizer)) + self._optim: CombinedOptimizer = CombinedOptimizer(optims) + + def _create_input_dist( + self, + input_feature_names: List[str], + ) -> None: + + feature_names: List[str] = [] + for sharding in self._sharding_type_to_sharding.values(): + self._input_dists.append(sharding.create_input_dist()) + feature_names.extend( + sharding.id_score_list_feature_names() + if self._is_weighted + else sharding.id_list_feature_names() + ) + self._feature_splits.append( + len( + sharding.id_score_list_feature_names() + if self._is_weighted + else sharding.id_list_feature_names() + ) + ) + + if feature_names == input_feature_names: + self._has_features_permute = False + else: + for f in feature_names: + self._features_order.append(input_feature_names.index(f)) + self.register_buffer( + "_features_order_tensor", + torch.tensor( + self._features_order, device=self._device, dtype=torch.int32 + ), + ) + + def _create_lookups( + self, + fused_params: Optional[Dict[str, Any]], + ) -> None: + self._lookups: nn.ModuleList[BaseEmbeddingLookup] = nn.ModuleList() + for sharding in self._sharding_type_to_sharding.values(): + self._lookups.append(sharding.create_lookup(fused_params)) + + def _create_output_dist(self) -> None: + for sharding in self._sharding_type_to_sharding.values(): + self._output_dists.append(sharding.create_pooled_output_dist()) + self._embedding_names.extend(sharding.embedding_names()) + self._embedding_dims.extend(sharding.embedding_dims()) + + # pyre-ignore [14] + def input_dist( + self, ctx: ShardedModuleContext, features: KeyedJaggedTensor + ) -> Awaitable[SparseFeaturesList]: + if self._has_uninitialized_input_dist: + self._create_input_dist(features.keys()) + self._has_uninitialized_input_dist = False + with torch.no_grad(): + if self._has_features_permute: + features = features.permute( + self._features_order, + # pyre-ignore [6] + self._features_order_tensor, + ) + features_by_shards = features.split( + self._feature_splits, + ) + awaitables = [ + module( + SparseFeatures( + id_list_features=None + if self._is_weighted + else features_by_shard, + id_score_list_features=features_by_shard + if self._is_weighted + else None, + ) + ) + for module, features_by_shard in zip( + self._input_dists, features_by_shards + ) + ] + return SparseFeaturesListAwaitable(awaitables) + + def compute( + self, ctx: ShardedModuleContext, dist_input: SparseFeaturesList + ) -> List[torch.Tensor]: + return [lookup(features) for lookup, features in zip(self._lookups, dist_input)] + + def output_dist( + self, ctx: ShardedModuleContext, output: List[torch.Tensor] + ) -> LazyAwaitable[KeyedTensor]: + if self._has_uninitialized_output_dist: + self._create_output_dist() + self._has_uninitialized_output_dist = False + return EmbeddingCollectionAwaitable( + awaitables=[ + dist(embeddings) for dist, embeddings in zip(self._output_dists, output) + ], + embedding_dims=self._embedding_dims, + embedding_names=self._embedding_names, + ) + + def compute_and_output_dist( + self, ctx: ShardedModuleContext, input: SparseFeaturesList + ) -> LazyAwaitable[KeyedTensor]: + if self._has_uninitialized_output_dist: + self._create_output_dist() + self._has_uninitialized_output_dist = False + return EmbeddingCollectionAwaitable( + awaitables=[ + dist(lookup(features)) + for lookup, dist, features in zip( + self._lookups, self._output_dists, input + ) + ], + embedding_dims=self._embedding_dims, + embedding_names=self._embedding_names, + ) + + def state_dict( + self, + destination: Optional[Dict[str, Any]] = None, + prefix: str = "", + keep_vars: bool = False, + ) -> Dict[str, Any]: + if destination is None: + destination = OrderedDict() + # pyre-ignore [16] + destination._metadata = OrderedDict() + for lookup in self._lookups: + lookup.state_dict(destination, prefix + "embedding_bags.", keep_vars) + return destination + + def named_modules( + self, + memo: Optional[Set[nn.Module]] = None, + prefix: str = "", + remove_duplicate: bool = True, + ) -> Iterator[Tuple[str, nn.Module]]: + yield from [(prefix, self)] + + def named_parameters( + self, prefix: str = "", recurse: bool = True + ) -> Iterator[Tuple[str, nn.Parameter]]: + for lookup in self._lookups: + yield from lookup.named_parameters( + append_prefix(prefix, "embedding_bags"), recurse + ) + + def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]: + for lookup, sharding_type in zip( + self._lookups, self._sharding_type_to_sharding.keys() + ): + if sharding_type == ShardingType.DATA_PARALLEL.value: + continue + for name, _ in lookup.named_parameters( + append_prefix(prefix, "embedding_bags") + ): + yield name + + def named_buffers( + self, prefix: str = "", recurse: bool = True + ) -> Iterator[Tuple[str, torch.Tensor]]: + for lookup in self._lookups: + yield from lookup.named_buffers( + append_prefix(prefix, "embedding_bags"), recurse + ) + + def load_state_dict( + self, + state_dict: "OrderedDict[str, torch.Tensor]", + strict: bool = True, + ) -> _IncompatibleKeys: + missing_keys = [] + unexpected_keys = [] + for lookup in self._lookups: + missing, unexpected = lookup.load_state_dict( + filter_state_dict(state_dict, "embedding_bags"), + strict, + ) + missing_keys.extend(missing) + unexpected_keys.extend(unexpected) + return _IncompatibleKeys( + missing_keys=missing_keys, unexpected_keys=unexpected_keys + ) + + def sparse_grad_parameter_names( + self, + destination: Optional[List[str]] = None, + prefix: str = "", + ) -> List[str]: + destination = [] if destination is None else destination + for lookup in self._lookups: + lookup.sparse_grad_parameter_names( + destination, append_prefix(prefix, "embedding_bags") + ) + return destination + + @property + def fused_optimizer(self) -> KeyedOptimizer: + return self._optim + + +M = TypeVar("M", bound=nn.Module) + + +class EmbeddingBagCollectionSharder(BaseEmbeddingSharder[M]): + """ + This implementation uses non-fused EmbeddingBagCollection + """ + + def shard( + self, + module: EmbeddingBagCollection, + params: Dict[str, ParameterSharding], + env: ShardingEnv, + device: Optional[torch.device] = None, + ) -> ShardedEmbeddingBagCollection: + return ShardedEmbeddingBagCollection( + module, params, env, self.fused_params, device + ) + + def shardable_parameters( + self, module: EmbeddingBagCollection + ) -> Dict[str, nn.Parameter]: + return { + name.split(".")[0]: param + for name, param in module.embedding_bags.named_parameters() + } + + @property + def module_type(self) -> Type[EmbeddingBagCollection]: + return EmbeddingBagCollection + + +class QuantEmbeddingBagCollectionSharder(ModuleSharder[QuantEmbeddingBagCollection]): + def shard( + self, + module: QuantEmbeddingBagCollection, + params: Dict[str, ParameterSharding], + env: ShardingEnv, + device: Optional[torch.device] = None, + ) -> ShardedEmbeddingBagCollection: + return ShardedEmbeddingBagCollection(module, params, env, None, device) + + def sharding_types(self, compute_device_type: str) -> List[str]: + return [ShardingType.DATA_PARALLEL.value] + + def compute_kernels( + self, sharding_type: str, compute_device_type: str + ) -> List[str]: + return [ + EmbeddingComputeKernel.BATCHED_QUANT.value, + ] + + def storage_usage( + self, tensor: torch.Tensor, compute_device_type: str, compute_kernel: str + ) -> Dict[str, int]: + tensor_bytes = tensor.numel() * tensor.element_size() + tensor.shape[0] * 4 + assert compute_device_type in {"cuda", "cpu"} + storage_map = {"cuda": ParameterStorage.HBM, "cpu": ParameterStorage.DDR} + return {storage_map[compute_device_type].value: tensor_bytes} + + def shardable_parameters( + self, module: QuantEmbeddingBagCollection + ) -> Dict[str, nn.Parameter]: + return { + name.split(".")[-2]: param + for name, param in module.state_dict().items() + if name.endswith(".weight") + } + + @property + def module_type(self) -> Type[QuantEmbeddingBagCollection]: + return QuantEmbeddingBagCollection + + +class EmbeddingAwaitable(LazyAwaitable[torch.Tensor]): + def __init__( + self, + awaitable: Awaitable[torch.Tensor], + ) -> None: + super().__init__() + self._awaitable = awaitable + + def wait(self) -> torch.Tensor: + embedding = self._awaitable.wait() + return embedding + + +class ShardedEmbeddingBag( + ShardedModule[ + SparseFeatures, + torch.Tensor, + torch.Tensor, + ], + FusedOptimizerModule, +): + """ + Sharded implementation of nn.EmbeddingBag. + This is part of public API to allow for manual data dist pipelining. + """ + + def __init__( + self, + module: nn.EmbeddingBag, + table_name_to_parameter_sharding: Dict[str, ParameterSharding], + env: ShardingEnv, + fused_params: Optional[Dict[str, Any]] = None, + device: Optional[torch.device] = None, + ) -> None: + super().__init__() + + assert ( + len(table_name_to_parameter_sharding) == 1 + ), "expect 1 table, but got len(table_name_to_parameter_sharding)" + assert module.mode == "sum", "ShardedEmbeddingBag only supports sum pooling" + + self._dummy_embedding_table_name = "dummy_embedding_table_name" + self._dummy_feature_name = "dummy_feature_name" + self.parameter_sharding: ParameterSharding = next( + iter(table_name_to_parameter_sharding.values()) + ) + embedding_table_config = EmbeddingTableConfig( + num_embeddings=module.num_embeddings, + embedding_dim=module.embedding_dim, + name=self._dummy_embedding_table_name, + feature_names=[self._dummy_feature_name], + pooling=PoolingType.SUM, + # We set is_weighted to True for now, + # if per_sample_weights is None in forward(), + # we could assign a all-one vector to per_sample_weights + is_weighted=True, + embedding_names=[self._dummy_feature_name], + ) + + if self.parameter_sharding.sharding_type == ShardingType.TABLE_WISE.value: + # TODO: enable it with correct semantics, see T104397332 + raise RuntimeError( + "table-wise sharding on a single EmbeddingBag is not supported yet" + ) + + self._embedding_sharding: EmbeddingSharding = create_embedding_sharding( + sharding_type=self.parameter_sharding.sharding_type, + embedding_configs=[ + ( + embedding_table_config, + self.parameter_sharding, + next(iter(module.parameters())), + ) + ], + env=env, + device=device, + ) + self._input_dist: nn.Module = self._embedding_sharding.create_input_dist() + self._lookup: nn.Module = self._embedding_sharding.create_lookup(fused_params) + self._output_dist: nn.Module = ( + self._embedding_sharding.create_pooled_output_dist() + ) + + # Get all fused optimizers and combine them. + optims = [] + for _, module in self._lookup.named_modules(): + if isinstance(module, FusedOptimizerModule): + # modify param keys to match EmbeddingBag + params: Mapping[str, Union[torch.Tensor, ShardedTensor]] = {} + for param_key, weight in module.fused_optimizer.params.items(): + params[param_key.split(".")[-1]] = weight + module.fused_optimizer.params = params + optims.append(("", module.fused_optimizer)) + self._optim: CombinedOptimizer = CombinedOptimizer(optims) + + # pyre-ignore [14] + def input_dist( + self, + ctx: ShardedModuleContext, + input: Tensor, + offsets: Optional[Tensor] = None, + per_sample_weights: Optional[Tensor] = None, + ) -> Awaitable[SparseFeatures]: + if per_sample_weights is None: + per_sample_weights = torch.ones_like(input, dtype=torch.float) + features = KeyedJaggedTensor( + keys=[self._dummy_feature_name], + values=input, + offsets=offsets, + weights=per_sample_weights, + ) + return self._input_dist( + SparseFeatures( + id_list_features=None, + id_score_list_features=features, + ) + ) + + def compute( + self, ctx: ShardedModuleContext, dist_input: SparseFeatures + ) -> torch.Tensor: + return self._lookup(dist_input) + + def output_dist( + self, ctx: ShardedModuleContext, output: torch.Tensor + ) -> LazyAwaitable[torch.Tensor]: + return EmbeddingAwaitable( + awaitable=self._output_dist(output), + ) + + def state_dict( + self, + destination: Optional[Dict[str, Any]] = None, + prefix: str = "", + keep_vars: bool = False, + ) -> Dict[str, Any]: + if destination is None: + destination = OrderedDict() + # pyre-ignore [16] + destination._metadata = OrderedDict() + lookup_state_dict = self._lookup.state_dict(None, "", keep_vars) + # update key to match embeddingBag state_dict key + for key, item in lookup_state_dict.items(): + new_key = prefix + key.split(".")[-1] + destination[new_key] = item + return destination + + def named_modules( + self, + memo: Optional[Set[nn.Module]] = None, + prefix: str = "", + remove_duplicate: bool = True, + ) -> Iterator[Tuple[str, nn.Module]]: + yield from [(prefix, self)] + + def named_parameters( + self, prefix: str = "", recurse: bool = True + ) -> Iterator[Tuple[str, nn.Parameter]]: + for name, parameter in self._lookup.named_parameters("", recurse): + # update name to match embeddingBag parameter name + yield append_prefix(prefix, name.split(".")[-1]), parameter + + def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]: + if self.parameter_sharding.sharding_type == ShardingType.DATA_PARALLEL.value: + yield from [] + else: + for name, _ in self._lookup.named_parameters(""): + yield append_prefix(prefix, name.split(".")[-1]) + + def named_buffers( + self, prefix: str = "", recurse: bool = True + ) -> Iterator[Tuple[str, torch.Tensor]]: + for name, buffer in self._lookup.named_buffers("", recurse): + yield append_prefix(prefix, name.split(".")[-1]), buffer + + def load_state_dict( + self, + state_dict: "OrderedDict[str, torch.Tensor]", + strict: bool = True, + ) -> _IncompatibleKeys: + missing_keys = [] + unexpected_keys = [] + # update key to match embeddingBag state_dict key + for key, value in state_dict.items(): + new_key = ".".join([self._dummy_embedding_table_name, key]) + state_dict[new_key] = value + state_dict.pop(key) + missing, unexpected = self._lookup.load_state_dict( + state_dict, + strict, + ) + missing_keys.extend(missing) + unexpected_keys.extend(unexpected) + + return _IncompatibleKeys( + missing_keys=missing_keys, unexpected_keys=unexpected_keys + ) + + def sparse_grad_parameter_names( + self, + destination: Optional[List[str]] = None, + prefix: str = "", + ) -> List[str]: + destination = [] if destination is None else destination + # pyre-ignore [29] + lookup_sparse_grad_parameter_names = self._lookup.sparse_grad_parameter_names( + None, "" + ) + for name in lookup_sparse_grad_parameter_names: + destination.append(name.split(".")[-1]) + return destination + + @property + def fused_optimizer(self) -> KeyedOptimizer: + return self._optim + + +class EmbeddingBagSharder(BaseEmbeddingSharder[M]): + """ + This implementation uses non-fused nn.EmbeddingBag + """ + + def shard( + self, + module: nn.EmbeddingBag, + params: Dict[str, ParameterSharding], + env: ShardingEnv, + device: Optional[torch.device] = None, + ) -> ShardedEmbeddingBag: + return ShardedEmbeddingBag(module, params, env, self.fused_params, device) + + def shardable_parameters(self, module: nn.EmbeddingBag) -> Dict[str, nn.Parameter]: + return {name: param for name, param in module.named_parameters()} + + @property + def module_type(self) -> Type[nn.EmbeddingBag]: + return nn.EmbeddingBag diff --git a/torchrec/distributed/grouped_position_weighted.py b/torchrec/distributed/grouped_position_weighted.py new file mode 100644 index 000000000..08d3b103e --- /dev/null +++ b/torchrec/distributed/grouped_position_weighted.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python3 + +from collections import OrderedDict +from typing import Dict, Optional, Iterator, Tuple, Any, List + +import torch +import torch.nn as nn +from torchrec.distributed.embedding_types import BaseGroupedFeatureProcessor +from torchrec.distributed.utils import append_prefix +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor + + +class GroupedPositionWeightedModule(BaseGroupedFeatureProcessor): + def __init__( + self, max_feature_lengths: Dict[str, int], device: Optional[torch.device] = None + ) -> None: + super().__init__() + self.max_feature_lengths = max_feature_lengths + for length in self.max_feature_lengths.values(): + if length <= 0: + raise + self.position_weights: nn.ParameterDict = nn.ParameterDict() + for key, length in max_feature_lengths.items(): + # pyre-ignore [29] + self.position_weights[key] = nn.Parameter( + torch.empty([length], device=device).fill_(1.0) + ) + self.register_buffer( + "_dummy_weights", + torch.tensor( + max(self.max_feature_lengths.values()), + device=device, + ).fill_(1.0), + ) + + def forward(self, features: KeyedJaggedTensor) -> KeyedJaggedTensor: + if features.weights_or_none() is None: + cat_seq = torch.ops.fbgemm.offsets_range( + features.offsets().long(), torch.numel(features.values()) + ) + else: + # for row-wise sharding + cat_seq = features.weights().long() + seqs = torch.split(cat_seq, features.length_per_key()) + weights_list = [] + for key, seq in zip(features.keys(), seqs): + if key in self.max_feature_lengths: + weights_list.append( + torch.gather(self.position_weights[key], dim=0, index=seq) + ) + else: + weights_list.append( + self._dummy_weights[: self.max_feature_lengths[key]] + ) + weights = torch.cat(weights_list) + + return KeyedJaggedTensor( + keys=features.keys(), + values=features.values(), + weights=weights, + lengths=features.lengths(), + offsets=features.offsets(), + stride=features.stride(), + length_per_key=features.length_per_key(), + ) + + def named_parameters( + self, prefix: str = "", recurse: bool = True + ) -> Iterator[Tuple[str, nn.Parameter]]: + # pyre-ignore [29] + for name, param in self.position_weights.items(): + yield append_prefix(prefix, f"position_weights.{name}"), param + + def named_buffers( + self, prefix: str = "", recurse: bool = True + ) -> Iterator[Tuple[str, torch.Tensor]]: + yield from () + + def state_dict( + self, + destination: Optional[Dict[str, Any]] = None, + prefix: str = "", + keep_vars: bool = False, + ) -> Dict[str, Any]: + if destination is None: + destination = OrderedDict() + # pyre-ignore [16] + destination._metadata = OrderedDict() + # pyre-ignore [29] + for name, param in self.position_weights.items(): + destination[prefix + f"position_weights.{name}"] = param + return destination + + def sparse_grad_parameter_names( + self, destination: Optional[List[str]] = None, prefix: str = "" + ) -> List[str]: + destination = [] if destination is None else destination + return destination diff --git a/torchrec/distributed/model_parallel.py b/torchrec/distributed/model_parallel.py new file mode 100644 index 000000000..10f70870c --- /dev/null +++ b/torchrec/distributed/model_parallel.py @@ -0,0 +1,390 @@ +#!/usr/bin/env python3 + +from collections import OrderedDict +from typing import Dict, Any, Optional, cast, List, Tuple, Iterator + +import torch +import torch.distributed as dist +from torch import nn +from torch.nn.modules.module import _IncompatibleKeys +from torch.nn.parallel import DistributedDataParallel +from torchrec.distributed.embeddingbag import ( + EmbeddingBagCollectionSharder, + QuantEmbeddingBagCollectionSharder, +) +from torchrec.distributed.planner import EmbeddingShardingPlanner, sharder_name +from torchrec.distributed.types import ( + ShardingPlan, + ModuleSharder, + ShardedModule, + ShardingEnv, +) +from torchrec.distributed.utils import append_prefix +from torchrec.distributed.utils import filter_state_dict +from torchrec.optim.fused import FusedOptimizerModule +from torchrec.optim.keyed import KeyedOptimizer, CombinedOptimizer + +# pyre-ignore [9] +default_sharders: List[ModuleSharder[nn.Module]] = [ + EmbeddingBagCollectionSharder(), + QuantEmbeddingBagCollectionSharder(), +] + + +class DistributedModelParallel(nn.Module, FusedOptimizerModule): + """ + Entry point to model parallelism. + Example: + >>> @torch.no_grad() + def init_weights(m): + if isinstance(m, nn.Linear) + m.weight.fill_(1.0) + elif isinstance(m, EmbeddingBagCollection) + for param in m.parameters(): + init.kaiming_normal_(param) + + m = MyModel(device='meta') + m = DistributedModelParallel(m) + m.apply(init_weights) + + Constructor Args: + module: module to wrap, + pg: this processes' process group, defaults to dist.GroupMember.WORLD, + device: this device, defaults to cpu, + plan: plan to use when sharding, defaults to EmbeddingShardingPlanner.collective_plan(), + sharders: ModuleSharders available to shard with, defaults to EmbeddingBagCollectionSharder(), + init_data_parallel: data-parallel modules can be lazy, i.e. they delay parameter initialization until + the first forward pass. Pass True if that's a case to delay initialization of data parallel modules. + Do first forward pass and then call DistributedModelParallel.init_data_parallel(). + init_parameters: initialize parameters for modules still on meta device. + + Call Args: + + Returns: + None + """ + + def __init__( + self, + module: nn.Module, + env: Optional[ShardingEnv] = None, + device: Optional[torch.device] = None, + plan: Optional[ShardingPlan] = None, + sharders: List[ModuleSharder[nn.Module]] = default_sharders, + init_data_parallel: bool = True, + init_parameters: bool = True, + ) -> None: + super().__init__() + torch._C._log_api_usage_once(f"torchrec.distributed.{self.__class__.__name__}") + + self.module = module + self.init_parameters = init_parameters + + if env is None: + pg = dist.GroupMember.WORLD + assert pg is not None, "Process group is not initialized" + env = ShardingEnv.from_process_group(pg) + self._env: ShardingEnv = env + + if device is None: + device = torch.device("cpu") + + self.device: torch.device = device + + self._sharder_map: Dict[str, ModuleSharder[nn.Module]] = { + sharder_name(sharder.module_type): sharder for sharder in sharders + } + + # 2. Call ShardingPlanner.collective_plan passing all found modules and corresponding sharders. + if plan is None: + planner = EmbeddingShardingPlanner(self._env.world_size, self.device.type) + pg = self._env.process_group + if pg is not None: + plan = planner.collective_plan(module, sharders, pg) + else: + plan = planner.plan(module, sharders) + self._plan: ShardingPlan = plan + + # 3. Replace modules w/ sharded versions, + # and then wrap w/ DistributedDataParallel. + fused_optims = [] + self._init_dmp( + fused_optims=fused_optims, + ) + if init_data_parallel: + self._init_ddp() + self._optim = CombinedOptimizer(fused_optims) + + @property + def dmp_module(self) -> nn.Module: + """ + Property to directly access sharded module, which + may or may not yet be wrapped in DDP + """ + # pyre-ignore [7] + return ( + self.module.module + if isinstance(self.module, DistributedDataParallel) + else self.module + ) + + # pyre-ignore [2, 3] + def forward(self, *args, **kwargs) -> Any: + return self.module(*args, **kwargs) + + def init_data_parallel(self) -> None: + """ + See init_data_parallel c-tor argument for usage. + It's safe to call this method multiple times. + """ + if not isinstance(self.module, DistributedDataParallel): + self._init_ddp() + + def _init_dmp( + self, + fused_optims: List[Tuple[str, KeyedOptimizer]], + ) -> None: + self._shard_modules_impl( + self.module, + "", + fused_optims, + ) + + def _shard_modules_impl( + self, + module: nn.Module, + path: str, + fused_optims: List[Tuple[str, KeyedOptimizer]], + ) -> None: + sharded_children = set() + for name, child in module.named_children(): + curr_path = path + name + sharded_params = self._plan.get_plan_for_module(curr_path) + if sharded_params: + # Shard module + sharder_key = sharder_name(type(child)) + sharded_child = self._sharder_map[sharder_key].shard( + child, + sharded_params, + self._env, + self.device, + ) + setattr(module, name, sharded_child) + if isinstance(sharded_child, FusedOptimizerModule): + fused_optims.append((curr_path, sharded_child.fused_optimizer)) + sharded_children.add(name) + else: + self._shard_modules_impl( + child, + curr_path + ".", + fused_optims, + ) + + def _init_ddp(self) -> None: + pg = self._env.process_group + if pg is None: + raise RuntimeError("Can only init DDP for ProcessGroup-based ShardingEnv") + sharded_parameter_names = set(self._sharded_parameter_names(self.module)) + DistributedDataParallel._set_params_and_buffers_to_ignore_for_model( + module=self.module, + params_and_buffers_to_ignore=[ + key + for key, _ in self.named_parameters() + if key in sharded_parameter_names + ], + ) + # Allocate any 'meta' tensors + if self.init_parameters: + self._init_parameters(self.module) + # initailize DDP + self.module = cast( + nn.Module, + DistributedDataParallel( + module=self.module.to(self.device), + device_ids=None if self.device.type == "cpu" else [self.device], + process_group=pg, + gradient_as_bucket_view=True, + broadcast_buffers=False, + ), + ) + + # Enable static graph for better DPP performance + # pyre-ignore + self.module._set_static_graph() + + def _init_parameters(self, module: nn.Module) -> None: + @torch.no_grad() + def init_parameters(module: nn.Module) -> None: + # Allocate parameters and buffers if over 'meta' device. + has_meta_param = False + # pyre-ignore [16] + for name, param in module._parameters.items(): + if isinstance(param, torch.Tensor) and param.device.type == "meta": + # pyre-ignore [29] + module._parameters[name] = nn.Parameter( + torch.empty_like(param, device=self.device), + requires_grad=param.requires_grad, + ) + has_meta_param = True + for name, buffer in module._buffers.items(): + if isinstance(buffer, torch.Tensor) and buffer.device.type == "meta": + # pyre-ignore [29] + module._buffers[name] = torch.empty_like(buffer, device=self.device) + + # Init parameters if at least one parameter is over 'meta' device. + if has_meta_param and hasattr(module, "reset_parameters"): + # pyre-ignore [29] + module.reset_parameters() + + module.apply(init_parameters) + + def sparse_grad_parameter_names( + self, destination: Optional[List[str]] = None, prefix: str = "" + ) -> List[str]: + destination = [] if destination is None else destination + return self._sparse_grad_parameter_names(self.dmp_module, destination, prefix) + + def _sparse_grad_parameter_names( + self, module: nn.Module, destination: List[str], prefix: str = "" + ) -> List[str]: + if isinstance(module, ShardedModule): + module.sparse_grad_parameter_names(destination, prefix) + elif isinstance(module, nn.Embedding): + if module.sparse: + destination.append(append_prefix(prefix, "weight")) + elif isinstance(module, nn.EmbeddingBag): + if module.sparse: + destination.append(append_prefix(prefix, "weight")) + else: + for name, child in module.named_children(): + self._sparse_grad_parameter_names( + child, destination, append_prefix(prefix, name) + ) + return destination + + def state_dict( + self, + destination: Optional[Dict[str, Any]] = None, + prefix: str = "", + keep_vars: bool = False, + ) -> Dict[str, Any]: + if destination is None: + destination = OrderedDict() + # pyre-ignore [16] + destination._metadata = OrderedDict() + + return self._state_dict(self.dmp_module, destination, prefix, keep_vars) + + def _state_dict( + self, + module: nn.Module, + destination: Dict[str, Any], + prefix: str, + keep_vars: bool, + ) -> Dict[str, Any]: + if isinstance(module, ShardedModule): + module.state_dict(destination, prefix, keep_vars) + else: + # pyre-ignore [29] + module._save_to_state_dict(destination, prefix, keep_vars) + for name, child in module.named_children(): + self._state_dict(child, destination, prefix + name + ".", keep_vars) + return destination + + def load_state_dict( + self, + state_dict: "OrderedDict[str, torch.Tensor]", + prefix: str = "", + strict: bool = True, + ) -> _IncompatibleKeys: + return self._load_state_dict(self.dmp_module, state_dict, prefix, strict) + + def _load_state_dict( + self, + module: nn.Module, + state_dict: "OrderedDict[str, torch.Tensor]", + prefix: str = "", + strict: bool = True, + ) -> _IncompatibleKeys: + module.load_state_dict(state_dict, strict=strict) + missing_keys = [] + unexpected_keys = [] + if isinstance(module, ShardedModule): + return module.load_state_dict(state_dict, strict=strict) + else: + # pyre-ignore [29] + module._load_from_state_dict( + state_dict, prefix, {}, strict, missing_keys, unexpected_keys, [] + ) + for name, child in module.named_children(): + m_keys, u_keys = self._load_state_dict( + child, + filter_state_dict(state_dict, prefix + name), + "", + strict, + ) + missing_keys.extend(m_keys) + unexpected_keys.extend(u_keys) + return _IncompatibleKeys( + missing_keys=missing_keys, unexpected_keys=unexpected_keys + ) + + def _named_parameters( + self, module: nn.Module, prefix: str = "", recurse: bool = True + ) -> Iterator[Tuple[str, torch.nn.Parameter]]: + if isinstance(module, ShardedModule): + yield from module.named_parameters(prefix, recurse) + else: + yield from module.named_parameters(prefix, recurse=False) + for name, child in module.named_children(): + yield from self._named_parameters( + child, append_prefix(prefix, name), recurse + ) + + def named_parameters( + self, prefix: str = "", recurse: bool = True + ) -> Iterator[Tuple[str, torch.nn.Parameter]]: + yield from self._named_parameters(self.dmp_module, prefix, recurse) + + def _sharded_parameter_names( + self, module: nn.Module, prefix: str = "" + ) -> Iterator[str]: + if isinstance(module, ShardedModule): + yield from module.sharded_parameter_names(prefix) + else: + for name, child in module.named_children(): + yield from self._sharded_parameter_names( + child, append_prefix(prefix, name) + ) + + def _named_buffers( + self, module: nn.Module, prefix: str = "", recurse: bool = True + ) -> Iterator[Tuple[str, torch.Tensor]]: + if isinstance(module, ShardedModule): + yield from module.named_buffers(prefix, recurse) + else: + yield from module.named_buffers(prefix, recurse=False) + for name, child in module.named_children(): + yield from self._named_buffers( + child, append_prefix(prefix, name), recurse + ) + + def named_buffers( + self, prefix: str = "", recurse: bool = True + ) -> Iterator[Tuple[str, torch.Tensor]]: + yield from self._named_buffers(self.dmp_module, prefix, recurse) + + @property + def fused_optimizer(self) -> KeyedOptimizer: + return self._optim + + @property + def plan(self) -> ShardingPlan: + return self._plan + + @staticmethod + def _reset_parameters(module: nn.Module) -> None: + for _, m in module.named_modules(): + if hasattr(m, "reset_parameters"): + # pyre-ignore [29] + m.reset_parameters() diff --git a/torchrec/distributed/planner/__init__.py b/torchrec/distributed/planner/__init__.py new file mode 100644 index 000000000..4ee56adf3 --- /dev/null +++ b/torchrec/distributed/planner/__init__.py @@ -0,0 +1,6 @@ +#!/usr/bin/env python3 + +from torchrec.distributed.planner.embedding_planner import ( + EmbeddingShardingPlanner, +) # noqa +from torchrec.distributed.planner.utils import sharder_name # noqa diff --git a/torchrec/distributed/planner/cost_functions.py b/torchrec/distributed/planner/cost_functions.py new file mode 100644 index 000000000..d521ceaad --- /dev/null +++ b/torchrec/distributed/planner/cost_functions.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python3 + +import math +from typing import Dict + +from torchrec.distributed.embedding_types import EmbeddingComputeKernel +from torchrec.distributed.planner.types import CostInput +from torchrec.distributed.types import ShardingType + + +# Constants +COMMS_MULTIPLER: Dict[str, int] = { + ShardingType.TABLE_WISE.value: 2, + ShardingType.COLUMN_WISE.value: 2, + ShardingType.ROW_WISE.value: 5, + ShardingType.TABLE_ROW_WISE.value: 3, + ShardingType.DATA_PARALLEL.value: 1, +} +KERNEL_MULTIPLER: Dict[str, int] = { + EmbeddingComputeKernel.DENSE.value: 25, + EmbeddingComputeKernel.SPARSE.value: 5, + EmbeddingComputeKernel.BATCHED_DENSE.value: 20, + EmbeddingComputeKernel.BATCHED_FUSED.value: 1, + EmbeddingComputeKernel.BATCHED_FUSED_UVM.value: 15, + EmbeddingComputeKernel.BATCHED_FUSED_UVM_CACHING.value: 10, + EmbeddingComputeKernel.BATCHED_QUANT.value: 1, +} + + +def cost_func_compute_based(cost_input: CostInput) -> int: + sharding_type = cost_input.sharding_type + compute_kernel = cost_input.compute_kernel + param = cost_input.param + input_stats = cost_input.input_stats + hash_size = param.shape[0] + emb_dim = param.shape[1] + pooling_factor = ( + input_stats.mean + if input_stats is not None + and input_stats.mean is not None + and None not in input_stats.mean + else [1.0] + ) + cost = math.log(hash_size, 10) * emb_dim * sum(pooling_factor) + + if sharding_type not in COMMS_MULTIPLER: + raise ValueError(f"cost function does not support {sharding_type}") + + if compute_kernel not in KERNEL_MULTIPLER: + raise ValueError(f"cost function does not support {compute_kernel}") + + return round( + cost * COMMS_MULTIPLER[sharding_type] * KERNEL_MULTIPLER[compute_kernel] + ) diff --git a/torchrec/distributed/planner/embedding_planner.py b/torchrec/distributed/planner/embedding_planner.py new file mode 100644 index 000000000..4f76e3f16 --- /dev/null +++ b/torchrec/distributed/planner/embedding_planner.py @@ -0,0 +1,515 @@ +#!/usr/bin/env python3 + +import heapq +import logging +from collections import deque +from typing import Dict, Optional, List, Callable, Tuple, Any + +import torch +import torch.distributed as dist +from torch import nn +from torchrec.distributed.collective_utils import ( + invoke_on_rank_and_broadcast_result, +) +from torchrec.distributed.comm import get_local_size +from torchrec.distributed.planner.cost_functions import ( + cost_func_compute_based, +) +from torchrec.distributed.planner.types import ( + ShardingOption, + ParameterInfo, + ParameterHints, + ParameterInputStats, + CostInput, + Topology, + ParamSortKey, +) +from torchrec.distributed.planner.utils import ( + sharder_name, + get_topology, + is_enough_storage, + allocate_param, + deallocate_param, + param_sort_key, + to_plan, + bytes_to_gb, +) +from torchrec.distributed.types import ( + ShardingPlan, + ShardingPlanner, + ModuleSharder, + ShardingType, + ParameterSharding, +) + + +logger: logging.Logger = logging.getLogger(__name__) + + +class EmbeddingShardingPlanner(ShardingPlanner): + def __init__( + self, + world_size: int, + compute_device_type: str = "cuda", + hints: Optional[Dict[str, ParameterHints]] = None, + input_stats: Optional[Dict[str, ParameterInputStats]] = None, + storage: Optional[Dict[str, int]] = None, + cost_functions: Optional[List[Callable[[CostInput], int]]] = None, + ) -> None: + self._world_size: int = world_size + self._local_size: int = get_local_size(world_size) + self._hints: Dict[str, ParameterHints] = hints if hints else {} + self._input_stats: Dict[str, ParameterInputStats] = ( + input_stats if input_stats else {} + ) + self._compute_device_type = compute_device_type + + if cost_functions is None: + self._cost_functions: List[Callable[[CostInput], int]] = [ + cost_func_compute_based + ] + else: + self._cost_functions = cost_functions + + self._topology: Topology = get_topology( + world_size, compute_device_type, storage + ) + self._counter: int = 1 + + def collective_plan( + self, + module: nn.Module, + sharders: List[ModuleSharder[nn.Module]], + pg: dist.ProcessGroup, + ) -> ShardingPlan: + """ + Call self.plan(...) on rank 0 and broadcast + """ + return invoke_on_rank_and_broadcast_result( + pg, + 0, + self.plan, + module, + sharders, + ) + + def plan( + self, + module: nn.Module, + sharders: List[ModuleSharder[nn.Module]], + ) -> ShardingPlan: + """ + Algorithm + + Each parameter has a set of sharding options, ordered in terms of compute cost (lowest to highest) + + Using the first sharding option for each parameter, the planner attempts to place the parameters in a + greedy fashion by placing the highest compute cost parameter remaining on the lowest total cost device. + + In event that planner hits a global storage constraint, the planner with remove the sharding option of + the parameter with the highest storage cost; and retry same greedy approach. Typically removing a + sharding option with a high storage cost will reduce storage cost but increase compute cost for a given + parameter. + + If no solution found using this approach, planner will fail. This search is not exhaustive, + so it does not mean a solution is not possible. + + """ + param_infos = self._get_param_infos( + module=module, + sharders=sharders, + ) + unplaced_param_infos: List[Tuple[ParamSortKey, ParameterInfo]] = [ + (param_sort_key(param_info, self._world_size), param_info) + for param_info in param_infos + ] + placed_param_infos: List[Tuple[ParamSortKey, ParameterInfo]] = [] + + heapq.heapify(unplaced_param_infos) + while unplaced_param_infos: + if not self._place(unplaced_param_infos, placed_param_infos): + self._counter += 1 + self._backtrack(unplaced_param_infos, placed_param_infos) + + sharding_plan = to_plan( + param_infos, + self._compute_device_type, + self._world_size, + self._local_size, + ) + self._log_stats( + sharding_plan=sharding_plan, + param_infos=param_infos, + ) + return sharding_plan + + def _log_stats( + self, sharding_plan: ShardingPlan, param_infos: List[ParameterInfo] + ) -> None: + """ + Builds Stats, then logs out results + """ + + # Data structures to read and store statistics + shard_by_fqn = { + module_name + "." + param_name: value + for module_name, param_dict in sharding_plan.plan.items() + for param_name, value in param_dict.items() + } + stats: Dict[int, Dict[str, Any]] = { + rank: {"type": {}, "pooling_factor": 0.0, "emb_dims": 0} + for rank in range(self._world_size) + } + + # Populate stats table + for param_info in param_infos: + name = param_info.name + fqn = param_info.fqn + shard: ParameterSharding = shard_by_fqn[fqn] + input_stats = self._input_stats.get(name, None) + ranks = list(range(self._world_size)) + pooling_factor = [ + sum(input_stats.mean) + if input_stats + and isinstance(input_stats.mean, list) + and None not in input_stats.mean + else 0.0 + ] + emb_dims = [param_info.param.shape[1]] + if shard.sharding_type == ShardingType.ROW_WISE.value: + pooling_factor = [pooling_factor[0] / self._world_size] * len(ranks) + emb_dims = emb_dims * len(ranks) + elif shard.sharding_type == ShardingType.TABLE_ROW_WISE.value: + # pyre-ignore [16] + host_id = shard.ranks[0] // self._local_size + ranks = list( + range(host_id * self._local_size, (host_id + 1) * self._local_size) + ) + pooling_factor = [pooling_factor[0] / self._local_size] * len(ranks) + emb_dims = emb_dims * len(ranks) + elif shard.sharding_type == ShardingType.COLUMN_WISE.value: + ranks = shard.ranks + emb_dims = [ + shard.shard_lengths[1] + # pyre-ignore [16] + for shard in shard.sharding_spec.shards + ] + # pyre-ignore [6] + pooling_factor = pooling_factor * len(ranks) + elif shard.sharding_type == ShardingType.TABLE_WISE.value: + ranks = shard.ranks + else: # DATA PARALLEL + emb_dims = emb_dims * len(ranks) + pooling_factor = pooling_factor * len(ranks) + + # pyre-ignore [6] + for i, rank in enumerate(ranks): + count = stats[rank]["type"].get(shard.sharding_type, 0) + stats[rank]["type"][shard.sharding_type] = count + 1 + stats[rank]["pooling_factor"] += pooling_factor[i] + stats[rank]["emb_dims"] += emb_dims[i] + + # Log out results + logger.info(f"------ {self.__class__.__name__} Statistics ------") + for rank in range(self._world_size): + host = self._topology.get_host(rank) + device = self._topology.get_device(rank) + logger.info( + f" Rank {rank} -- " + f"HBM/DDR: {bytes_to_gb(device.hbm.capacity - device.hbm.free):.1f}/" + f"{bytes_to_gb(host.ddr.capacity - host.ddr.free):.1f}, " + f"Cost: {device.total_cost}, " + f"Mean Pooling: {int(stats[rank]['pooling_factor'])}, " + f"Emb Dims: {stats[rank]['emb_dims']}, " + f"Shards: {stats[rank]['type']}" + ) + logger.info( + f"------ Executed {self._counter} iteration(s) to find a solution ------" + ) + + def _place( + self, + unplaced_param_infos: List[Tuple[ParamSortKey, ParameterInfo]], + placed_param_infos: List[Tuple[ParamSortKey, ParameterInfo]], + ) -> bool: + """ + Places parameters until all parameters are placed, or a storage contraint is hit + """ + candidate_devices = [ + self._topology.get_device(rank) for rank in range(self._world_size) + ] + heapq.heapify(candidate_devices) + sort_key, param_info = heapq.heappop(unplaced_param_infos) + sharding_option = param_info.sharding_options[0] + + is_placed = False + if sharding_option.sharding_type == ShardingType.TABLE_WISE.value: + constrained_devices = [] + ranks = [] + while candidate_devices: + candidate_device = heapq.heappop(candidate_devices) + if is_enough_storage(sharding_option, self._topology, candidate_device): + ranks.append(candidate_device.rank) + sharding_option.ranks = ranks + allocate_param(sharding_option, self._topology) + heapq.heappush(candidate_devices, candidate_device) + heapq.heappush( + placed_param_infos, + ( + param_sort_key(param_info, self._world_size, "storage"), + param_info, + ), + ) + is_placed = True + break + constrained_devices.append(candidate_device) + + for constrained_device in constrained_devices: + heapq.heappush(candidate_devices, constrained_device) + elif sharding_option.sharding_type == ShardingType.COLUMN_WISE.value: + constrained_devices = [] + ranks = [] + while candidate_devices: + candidate_device = heapq.heappop(candidate_devices) + if is_enough_storage(sharding_option, self._topology, candidate_device): + ranks.append(candidate_device.rank) + sharding_option.ranks = ranks + allocate_param(sharding_option, self._topology) + heapq.heappush(candidate_devices, candidate_device) + if len(ranks) == sharding_option._num_col_wise_shards: + heapq.heappush( + placed_param_infos, + ( + param_sort_key(param_info, self._world_size, "storage"), + param_info, + ), + ) + is_placed = True + break + constrained_devices.append(candidate_device) + + for constrained_device in constrained_devices: + heapq.heappush(candidate_devices, constrained_device) + elif sharding_option.sharding_type == ShardingType.TABLE_ROW_WISE.value: + num_hosts = len(self._topology.hosts) + devices_per_host = len(self._topology.hosts[0].devices) + candidate_hosts = [0] * num_hosts + constrained_devices = [] + ranks = [] + while candidate_devices: + candidate_device = heapq.heappop(candidate_devices) + host_idx, _ = self._topology.host_and_device_by_rank[ + candidate_device.rank + ] + candidate_hosts[host_idx] += 1 + if candidate_hosts[host_idx] == devices_per_host and is_enough_storage( + sharding_option, self._topology, candidate_device + ): + ranks.append(candidate_device.rank) + sharding_option.ranks = ranks + allocate_param(sharding_option, self._topology) + heapq.heappush( + placed_param_infos, + ( + param_sort_key(param_info, self._world_size, "storage"), + param_info, + ), + ) + heapq.heappush(candidate_devices, candidate_device) + is_placed = True + break + constrained_devices.append(candidate_device) + + for constrained_device in constrained_devices: + heapq.heappush(candidate_devices, constrained_device) + + elif sharding_option.sharding_type in [ + ShardingType.DATA_PARALLEL.value, + ShardingType.ROW_WISE.value, + ]: + if is_enough_storage(sharding_option, self._topology): + sharding_option.ranks = None + allocate_param(sharding_option, self._topology) + heapq.heappush( + placed_param_infos, + ( + param_sort_key(param_info, self._world_size, "storage"), + param_info, + ), + ) + is_placed = True + else: + raise ValueError( + f"{self.__class__.__name__} does not support {sharding_option.sharding_type}" + ) + + if not is_placed: + heapq.heappush(unplaced_param_infos, (sort_key, param_info)) + + return is_placed + + def _backtrack( + self, + unplaced_param_infos: List[Tuple[ParamSortKey, ParameterInfo]], + placed_param_infos: List[Tuple[ParamSortKey, ParameterInfo]], + ) -> None: + """ + Called when the planner hits a storage constraint. A single sharding option is discarded, + and then reset state such that _place method can recalled. If no there are no available + sharding options to discard an error will be raised. + """ + + is_option_discarded = False + _, param_info = heapq.heappop(unplaced_param_infos) + + # Temporarily place param_info into solution set + heapq.heappush( + placed_param_infos, + ( + param_sort_key(param_info, self._world_size, "storage"), + param_info, + ), + ) + while placed_param_infos: + (_, placed_param_info) = heapq.heappop(placed_param_infos) + + # Deallocate in necessary + if placed_param_info is not param_info: + deallocate_param(placed_param_info.sharding_options[0], self._topology) + + # Discard sharding option from first parameter with more than one sharding option + if len(placed_param_info.sharding_options) > 1 and not is_option_discarded: + placed_param_info.sharding_options.popleft() + is_option_discarded = True + heapq.heappush( + unplaced_param_infos, + ( + param_sort_key(placed_param_info, self._world_size), + placed_param_info, + ), + ) + if not is_option_discarded: + raise RuntimeError( + f"------ {self.__class__.__name__} is unable to find a plan for model. ------\n" + "Try: \n" + " 1) Increasing the number of devices\n" + " 2) Reducing the model size\n" + " 3) Removing sharding hints that may reduce search space\n" + f"------ attempted {self._counter} iteration(s)) ------" + ) + + def _get_param_infos( + self, + module: nn.Module, + sharders: List[ModuleSharder[nn.Module]], + ) -> List[ParameterInfo]: + sharder_map: Dict[str, ModuleSharder[nn.Module]] = { + sharder_name(sharder.module_type): sharder for sharder in sharders + } + param_infos: List[ParameterInfo] = [] + + for child_path, child_module in module.named_modules(): + sharder_key = sharder_name(type(child_module)) + sharder = sharder_map.get(sharder_key, None) + if not sharder: + continue + + for name, param in sharder.shardable_parameters(child_module).items(): + sharding_options = [] + for sharding_type in self._filter_sharding_types( + name, sharder.sharding_types(self._compute_device_type) + ): + num_col_wise_shards, shard_size = self._get_num_col_wise_shards( + name, param, sharding_type + ) + for compute_kernel in self._filter_compute_kernels( + name, + sharder.compute_kernels( + sharding_type, self._compute_device_type + ), + ): + cost_input = CostInput( + param=param, + compute_device_type=self._compute_device_type, + compute_kernel=compute_kernel, + sharding_type=sharding_type, + input_stats=self._input_stats.get(name, None), + ) + cost = sum( + [ + cost_function(cost_input) + for cost_function in self._cost_functions + ] + ) + sharding_options.append( + ShardingOption( + cost=cost, + sharding_type=sharding_type, + compute_kernel=compute_kernel, + storage_usage=sharder.storage_usage( + param, self._compute_device_type, compute_kernel + ), + _num_col_wise_shards=num_col_wise_shards, + col_wise_shard_dim=shard_size, + ) + ) + param_infos.append( + ParameterInfo( + param=param, + name=name, + prefix=child_path, + sharding_options=deque(sorted(sharding_options)), + ) + ) + return param_infos + + def _filter_sharding_types(self, name: str, sharding_types: List[str]) -> List[str]: + hint = self._hints.get(name, None) + if not hint or not hint.sharding_types: + return sharding_types + sharding_types = list( + set(hint.sharding_types).intersection(set(sharding_types)) + ) + if not sharding_types: + raise RuntimeError( + f"No available sharding types after applying hints for {name}" + ) + return sharding_types + + def _filter_compute_kernels( + self, name: str, compute_kernels: List[str] + ) -> List[str]: + hint = self._hints.get(name, None) + if not hint or not hint.compute_kernels: + return compute_kernels + compute_kernels = list( + set(hint.compute_kernels).intersection(set(compute_kernels)) + ) + if not compute_kernels: + raise RuntimeError( + f"No available compute kernels after applying hints for {name}" + ) + return compute_kernels + + def _get_num_col_wise_shards( + self, name: str, param: torch.Tensor, sharding_type: str + ) -> Tuple[Optional[int], Optional[int]]: + num_col_wise_shards = None + col_wise_shard_dim = None + if sharding_type == ShardingType.COLUMN_WISE.value: + _hint = self._hints.get(name, None) + col_wise_shard_dim_hint = ( + None if _hint is None else _hint.col_wise_shard_dim + ) + col_wise_shard_dim = ( + col_wise_shard_dim_hint + if col_wise_shard_dim_hint is not None + else param.shape[1] + ) + # column-wise shard the weights + num_col_wise_shards, residual = divmod(param.shape[1], col_wise_shard_dim) + if residual > 0: + num_col_wise_shards += 1 + elif sharding_type == ShardingType.TABLE_WISE.value: + num_col_wise_shards = 1 + return num_col_wise_shards, col_wise_shard_dim diff --git a/torchrec/distributed/planner/new/calculators.py b/torchrec/distributed/planner/new/calculators.py new file mode 100644 index 000000000..7eb59a718 --- /dev/null +++ b/torchrec/distributed/planner/new/calculators.py @@ -0,0 +1,333 @@ +#!/usr/bin/env python3 + +from typing import Dict, Optional, Tuple, List + +from torchrec.distributed.planner.new.constants import ( + BIGINT_DTYPE, + INTRA_NODE_BANDWIDTH, + CROSS_NODE_BANDWIDTH, + kernel_bw_lookup, +) +from torchrec.distributed.planner.new.types import ( + PlannerConstraints, + Calculator, + Topology, + ShardingOption, +) +from torchrec.distributed.types import ShardingType + + +def cost_func_emb_wall_time( + shard_lengths: List[List[int]], + compute_kernel: str, + compute_device: str, + sharding_type: str, + batch_size: int, + world_size: int, + local_world_size: int, + input_lengths: List[float], + input_data_type_size: float, + output_data_type_size: float, + bw_intra_host: int, + bw_inter_host: int, + has_input_dist: bool = True, + has_output_dist: bool = True, + caching_ratio: Optional[float] = None, +) -> List[float]: + """ + Attempts to model costs as a function of relative wall times. + Only models forward costs (ignores backward costs) + The computation cost estimation is based on EmbeddingBagCollectionSharder + (pooledEmbedding) + + shard_lengths: the list of (local_rows, local_cols) pf each shard + input_lengths: the list of the average number of lookups of each input query feature + bw_intra_host: the bandwidth within the single host like multiple threads + bw_inter_host: the bandwidth between two hosts like multiple machines + input_dist: + tw_sharding: https://fburl.com/code/uxueh8wh + rw_sharding: https://fburl.com/code/zemh4rzw + cw_sharding: same as tw, consider as multiple tables (cw_emb_dim * num_sharding = tw_emb_dim) + twrw_sharding: https://fburl.com/code/vrweq0ri + output_dist: + tw_sharding: https://fburl.com/code/ete7schi + rw_sharding: https://fburl.com/code/gl9186u1 + cw_sharding: same as tw, consider as multiple tables (cw_emb_dim * num_sharding = tw_emb_dim) + twrw_sharding: https://fburl.com/code/z9nyjflj + + Note: the computation of the output cost will count len(input_length) due to pooling + + """ + shard_costs = [] + B = 1.0 * world_size * batch_size # global batch size + device_bw = kernel_bw_lookup(compute_device, compute_kernel, caching_ratio) + + for hash_size, emb_dim in shard_lengths: + + if sharding_type is ShardingType.TABLE_WISE.value: + input_cost, compute_cost, output_cost = _get_tw_sharding_cost( + B, + world_size, + input_lengths, + emb_dim, + input_data_type_size, + output_data_type_size, + device_bw, + bw_inter_host, + ) + elif sharding_type is ShardingType.COLUMN_WISE.value: + input_cost, compute_cost, output_cost = _get_cw_sharding_cost( + B, + world_size, + input_lengths, + emb_dim, + input_data_type_size, + output_data_type_size, + device_bw, + bw_inter_host, + ) + elif sharding_type is ShardingType.ROW_WISE.value: + input_cost, compute_cost, output_cost = _get_rw_sharding_cost( + B, + world_size, + input_lengths, + emb_dim, + input_data_type_size, + output_data_type_size, + device_bw, + bw_inter_host, + ) + elif sharding_type is ShardingType.TABLE_ROW_WISE.value: + input_cost, compute_cost, output_cost = _get_twrw_sharding_cost( + B, + world_size, + local_world_size, + input_lengths, + emb_dim, + input_data_type_size, + output_data_type_size, + device_bw, + bw_inter_host, + bw_intra_host, + ) + elif sharding_type is ShardingType.DATA_PARALLEL.value: + input_cost, compute_cost, output_cost = _get_dp_sharding_cost( + batch_size, + input_lengths, + hash_size * emb_dim, + bw_inter_host, + emb_dim, + output_data_type_size, + device_bw, + ) + else: + raise RuntimeError(f"Unexpected sharding type: {sharding_type}") + + shard_cost = 0 + shard_cost += input_cost if has_input_dist else 0 + shard_cost += compute_cost + shard_cost += output_cost if has_output_dist else 0 + shard_costs.append(shard_cost) + + return shard_costs + + +def _get_tw_sharding_cost( + global_batch_size: float, + world_size: int, + input_lengths: List[float], + emb_dim: int, + input_data_type_size: float, + output_data_type_size: float, + device_bw: float, + bw_inter_host: int, +) -> Tuple[float, float, float]: + input_cost = ( + global_batch_size * sum(input_lengths) * input_data_type_size / bw_inter_host + ) + compute_cost = ( + global_batch_size + * sum(input_lengths) + * emb_dim + * output_data_type_size + / device_bw + ) + output_cost = ( + global_batch_size + * emb_dim + * len(input_lengths) + * output_data_type_size + / bw_inter_host + ) + return (input_cost, compute_cost, output_cost) + + +def _get_cw_sharding_cost( + global_batch_size: float, + world_size: int, + input_lengths: List[float], + emb_dim: int, + input_data_type_size: float, + output_data_type_size: float, + device_bw: float, + bw_inter_host: int, +) -> Tuple[float, float, float]: + input_cost = ( + global_batch_size * sum(input_lengths) * input_data_type_size / bw_inter_host + ) + compute_cost = ( + global_batch_size + * sum(input_lengths) + * emb_dim + * output_data_type_size + / device_bw + ) + output_cost = ( + global_batch_size + * emb_dim + * len(input_lengths) + * output_data_type_size + / bw_inter_host + ) + return (input_cost, compute_cost, output_cost) + + +def _get_rw_sharding_cost( + global_batch_size: float, + world_size: int, + input_lengths: List[float], + emb_dim: int, + input_data_type_size: float, + output_data_type_size: float, + device_bw: float, + bw_inter_host: int, +) -> Tuple[float, float, float]: + input_cost = ( + global_batch_size + * sum(input_lengths) + / world_size + * input_data_type_size + / bw_inter_host + ) + compute_cost = ( + global_batch_size + * sum(input_lengths) + / world_size + * emb_dim + * output_data_type_size + / device_bw + ) + output_cost = ( + global_batch_size + * emb_dim + * len(input_lengths) + * output_data_type_size + / bw_inter_host + ) + return (input_cost, compute_cost, output_cost) + + +def _get_twrw_sharding_cost( + global_batch_size: float, + world_size: int, + local_world_size: int, + input_lengths: List[float], + emb_dim: int, + input_data_type_size: float, + output_data_type_size: float, + device_bw: float, + bw_inter_host: int, + bw_intra_host: int, +) -> Tuple[float, float, float]: + input_cost = ( + global_batch_size + * sum(input_lengths) + / local_world_size + * input_data_type_size + / bw_inter_host + ) + compute_cost = ( + global_batch_size + * sum(input_lengths) + / local_world_size + * emb_dim + * output_data_type_size + / device_bw + ) + output_cost = ( + global_batch_size + * emb_dim + * len(input_lengths) + * output_data_type_size + / bw_intra_host + + global_batch_size + * emb_dim + * len(input_lengths) + * output_data_type_size + * (local_world_size / world_size) + / bw_inter_host + ) + return (input_cost, compute_cost, output_cost) + + +def _get_dp_sharding_cost( + batch_size: float, + input_lengths: List[float], + grad_num_elem: int, + bw_inter_host: int, + emb_dim: int, + output_data_type_size: float, + device_bw: float, +) -> Tuple[float, float, float]: + input_cost = 0 + compute_cost = ( + batch_size * sum(input_lengths) * emb_dim * output_data_type_size / device_bw + ) + # TODO: this is allreduce cost, better separated out as backward cost + output_cost = grad_num_elem * output_data_type_size / bw_inter_host + return (input_cost, compute_cost, output_cost) + + +class EmbeddingWTCostCalculator(Calculator): + """ + Embedding Wall Time Cost Calculator + """ + + def __init__( + self, + topology: Topology, + constraints: Optional[Dict[str, PlannerConstraints]] = None, + ) -> None: + self._topology = topology + self._constraints = constraints + + def run(self, sharding_options: List[ShardingOption]) -> None: + for sharding_option in sharding_options: + caching_ratio = ( + self._constraints[sharding_option.name].caching_ratio + if self._constraints and self._constraints.get(sharding_option.name) + else None + ) + shard_costs = cost_func_emb_wall_time( + shard_lengths=[shard.length for shard in sharding_option.shards], + compute_kernel=sharding_option.compute_kernel, + compute_device=self._topology.compute_device, + sharding_type=sharding_option.sharding_type, + batch_size=sharding_option.batch_size, + world_size=self._topology.world_size, + local_world_size=self._topology.local_world_size, + input_lengths=sharding_option.input_lengths, + input_data_type_size=BIGINT_DTYPE, + output_data_type_size=sharding_option.tensor.element_size(), + bw_intra_host=getattr( + self._topology, "intra_host_bw", INTRA_NODE_BANDWIDTH + ), + bw_inter_host=getattr( + self._topology, "inter_host_bw", CROSS_NODE_BANDWIDTH + ), + has_input_dist=True if sharding_option.upstream_modules else False, + has_output_dist=False if sharding_option.downstream_modules else True, + caching_ratio=caching_ratio, + ) + for shard, cost in zip(sharding_option.shards, shard_costs): + shard.cost = cost diff --git a/torchrec/distributed/planner/new/constants.py b/torchrec/distributed/planner/new/constants.py new file mode 100644 index 000000000..6d8e6f439 --- /dev/null +++ b/torchrec/distributed/planner/new/constants.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python3 + +from typing import Optional + +from torchrec.distributed.embedding_types import EmbeddingComputeKernel + +MAX_SIZE: int = (1 << 63) - 1 + +INTRA_NODE_BANDWIDTH: int = 600 +CROSS_NODE_BANDWIDTH: int = 12 + +MIN_CW_DIM: int = 32 +POOLING_FACTOR: float = 1.0 + +BIGINT_DTYPE: int = 8 + +HBM_CAP: int = 32 * 1024 * 1024 * 1024 # 32 GB +DDR_CAP: int = 128 * 1024 * 1024 * 1024 # 128 GB +DDR_MEM_BW: int = 51 +HBM_MEM_BW: int = 897 +CACHING_RATIO: float = 0.2 +BATCH_SIZE: int = 512 + + +def kernel_bw_lookup( + compute_device: str, + compute_kernel: str, + caching_ratio: Optional[float] = None, +) -> float: + caching_ratio = caching_ratio if caching_ratio else CACHING_RATIO + return { + # CPU + ("cpu", EmbeddingComputeKernel.DENSE.value): 0.35 * DDR_MEM_BW, + ("cpu", EmbeddingComputeKernel.SPARSE.value): 0.35 * DDR_MEM_BW, + ("cpu", EmbeddingComputeKernel.BATCHED_DENSE.value): 0.5 * DDR_MEM_BW, + ("cpu", EmbeddingComputeKernel.BATCHED_FUSED.value): 1 * DDR_MEM_BW, + # CUDA + ("cuda", EmbeddingComputeKernel.DENSE.value): 0.35 * HBM_MEM_BW, + ("cuda", EmbeddingComputeKernel.SPARSE.value): 0.35 * HBM_MEM_BW, + ("cuda", EmbeddingComputeKernel.BATCHED_DENSE.value): 0.5 * HBM_MEM_BW, + ("cuda", EmbeddingComputeKernel.BATCHED_FUSED.value): 1 * HBM_MEM_BW, + ("cuda", EmbeddingComputeKernel.BATCHED_FUSED_UVM.value): DDR_MEM_BW / 100, + ("cuda", EmbeddingComputeKernel.BATCHED_FUSED_UVM_CACHING.value): ( + caching_ratio * HBM_MEM_BW + (1 - caching_ratio) * DDR_MEM_BW + ) + / 100, + }[(compute_device, compute_kernel)] diff --git a/torchrec/distributed/planner/new/enumerators.py b/torchrec/distributed/planner/new/enumerators.py new file mode 100644 index 000000000..12d647495 --- /dev/null +++ b/torchrec/distributed/planner/new/enumerators.py @@ -0,0 +1,581 @@ +#!/usr/bin/env python3 + +import math +from typing import Tuple, Optional, Dict, List + +import torch +from torch import nn +from torchrec.distributed.embedding_types import EmbeddingComputeKernel +from torchrec.distributed.planner.new.constants import ( + MIN_CW_DIM, + POOLING_FACTOR, + BIGINT_DTYPE, + CACHING_RATIO, +) +from torchrec.distributed.planner.new.types import ( + PlannerConstraints, + InputStats, + Enumerator, + ShardingOption, + Shard, + Storage, + Topology, + PartitionByType, +) +from torchrec.distributed.planner.utils import sharder_name +from torchrec.distributed.types import ModuleSharder, ShardingType + + +class EmbeddingEnumerator(Enumerator): + def __init__( + self, + topology: Topology, + constraints: Optional[Dict[str, PlannerConstraints]] = None, + input_stats: Optional[Dict[str, InputStats]] = None, + ) -> None: + self._compute_device: str = topology.compute_device + self._world_size: int = topology.world_size + self._local_world_size: int = topology.local_world_size + self._constraints = constraints + self._input_stats = input_stats + self._batch_size: int = topology.batch_size + + def run( + self, module: nn.Module, sharders: List[ModuleSharder[nn.Module]] + ) -> List[ShardingOption]: + sharder_map: Dict[str, ModuleSharder[nn.Module]] = { + sharder_name(sharder.module_type): sharder for sharder in sharders + } + sharding_options: List[ShardingOption] = [] + + for child_path, child_module in module.named_modules(): + sharder_key = sharder_name(type(child_module)) + sharder = sharder_map.get(sharder_key, None) + if not sharder: + continue + + for name, param in sharder.shardable_parameters(child_module).items(): + for sharding_type in self._filter_sharding_types( + name, sharder.sharding_types(self._compute_device) + ): + for compute_kernel in self._filter_compute_kernels( + name, + sharder.compute_kernels(sharding_type, self._compute_device), + ): + col_wise_shard_dim = ( + self._constraints[name].min_partition + if self._constraints and self._constraints.get(name) + else None + ) + shard_lengths, shard_offsets = get_shard_lengths_and_offsets( + tensor=param, + world_size=self._world_size, + local_world_size=self._local_world_size, + sharding_type=sharding_type, + col_wise_shard_dim=col_wise_shard_dim, + ) + input_lengths = self._get_input_lengths(name) + caching_ratio = ( + self._constraints[name].caching_ratio + if self._constraints and self._constraints.get(name) + else None + ) + shard_storages = get_shard_storages( + sharder=sharder, + sharding_type=sharding_type, + tensor=param, + compute_device=self._compute_device, + compute_kernel=compute_kernel, + shard_lengths=shard_lengths, + batch_size=self._batch_size, + world_size=self._world_size, + local_world_size=self._local_world_size, + input_lengths=input_lengths, + caching_ratio=caching_ratio + if caching_ratio + else CACHING_RATIO, + ) + sharding_options.append( + ShardingOption( + name=name, + tensor=param, + module=(child_path, child_module), + upstream_modules=[], + downstream_modules=[], + input_lengths=input_lengths, + batch_size=self._batch_size, + compute_kernel=compute_kernel, + sharding_type=sharding_type, + partition_by=get_partition_by_type(sharding_type), + shards=[ + Shard(length=length, offset=offset, storage=storage) + for length, offset, storage in zip( + shard_lengths, shard_offsets, shard_storages + ) + ], + ) + ) + + return sharding_options + + def _filter_sharding_types(self, name: str, sharding_types: List[str]) -> List[str]: + if not self._constraints or not self._constraints.get(name): + return sharding_types + constraints: PlannerConstraints = self._constraints[name] + if not constraints.sharding_types: + return sharding_types + constrained_sharding_types: List[str] = constraints.sharding_types + + sharding_types = list(set(constrained_sharding_types) & set(sharding_types)) + + if not sharding_types: + raise RuntimeError( + f"No available sharding types after applying user provided constraints for {name}" + ) + return sharding_types + + def _filter_compute_kernels( + self, name: str, compute_kernels: List[str] + ) -> List[str]: + if not self._constraints or not self._constraints.get(name): + return compute_kernels + constraints: PlannerConstraints = self._constraints[name] + if not constraints.compute_kernels: + return compute_kernels + constrained_compute_kernels: List[str] = constraints.compute_kernels + + compute_kernels = list(set(constrained_compute_kernels) & set(compute_kernels)) + + if not compute_kernels: + raise RuntimeError( + f"No available compute kernels after applying user provided constraints for {name}" + ) + return compute_kernels + + def _get_input_lengths(self, name: str) -> List[float]: + return ( + self._input_stats[name].pooling_factors + if self._input_stats and self._input_stats.get(name) + else [POOLING_FACTOR] + ) + + +def get_partition_by_type(sharding_type: str) -> str: + device_sharding_types = { + ShardingType.TABLE_WISE.value, + ShardingType.COLUMN_WISE.value, + } + host_sharding_types = {ShardingType.TABLE_ROW_WISE.value} + uniform_sharding_types = { + ShardingType.ROW_WISE.value, + ShardingType.DATA_PARALLEL.value, + } + + if sharding_type in device_sharding_types: + return PartitionByType.DEVICE.value + elif sharding_type in host_sharding_types: + return PartitionByType.HOST.value + elif sharding_type in uniform_sharding_types: + return PartitionByType.UNIFORM.value + + raise ValueError(f"Unrecognized sharding type provided: {sharding_type}") + + +def get_shard_lengths_and_offsets( + tensor: torch.Tensor, + world_size: int, + local_world_size: int, + sharding_type: str, + col_wise_shard_dim: Optional[int] = None, +) -> Tuple[List[List[int]], List[List[int]]]: + (rows, columns) = tensor.shape + + if sharding_type == ShardingType.DATA_PARALLEL.value: + return [[rows, columns]] * world_size, [[0, 0]] * world_size + elif sharding_type == ShardingType.TABLE_WISE.value: + return [[rows, columns]], [[0, 0]] + elif sharding_type == ShardingType.COLUMN_WISE.value: + return _get_cw_shard_lengths_and_offsets(columns, rows, col_wise_shard_dim) + elif sharding_type == ShardingType.ROW_WISE.value: + return _get_rw_shard_lengths_and_offsets(rows, world_size, columns) + elif sharding_type == ShardingType.TABLE_ROW_WISE.value: + return _get_rw_shard_lengths_and_offsets(rows, local_world_size, columns) + + raise ValueError(f"Unrecognized sharding type provided: {sharding_type}") + + +def _get_rw_shard_lengths_and_offsets( + hash_size: int, num_devices: int, columns: int +) -> Tuple[List[List[int]], List[List[int]]]: + # Set prefix of shard_lengths to be ceil(hash_size/num_devices). For exmaple + # if hash_size = 10, num_devices = 3, we will allocate the rows as + # 3,3,3,1 (rather than 3,3,2,2). This is due to implementation in RWSharding that sets + # block_size_lists to be ceil. The balanced way is harder to support on GPU. For more details + # see https://fb.quip.com/xbgbAchCTOL0 + # Also consider the example of hash_size = 5, num_devices = 4. The expected rows per rank is + # [2,2,1,0]. + num_devices: int = min(num_devices, hash_size) + + block_size: int = math.ceil(hash_size / num_devices) + last_rank: int = hash_size // block_size + last_block_size: int = hash_size - block_size * last_rank + shard_lengths: List[List[int]] = [] + + for rank in range(num_devices): + if rank < last_rank: + local_row: int = block_size + elif rank == last_rank: + local_row: int = last_block_size + else: + local_row: int = 0 + shard_lengths.append([local_row, columns]) + shard_offsets = [[0, 0]] + + for i in range(num_devices - 1): + shard_offsets.append([shard_lengths[i][0] + shard_offsets[i][0], 0]) + + return shard_lengths, shard_offsets + + +def _get_cw_shard_lengths_and_offsets( + hash_size: int, + rows: int, + col_wise_shard_dim: Optional[int] = None, +) -> Tuple[List[List[int]], List[List[int]]]: + block_size: int = min( + col_wise_shard_dim if col_wise_shard_dim else MIN_CW_DIM, hash_size + ) + num_col_wise_shards, residual = divmod(hash_size, block_size) + + shard_lengths: List[List[int]] = [[rows, block_size]] * (num_col_wise_shards - 1) + shard_lengths.append([rows, block_size + residual]) + + shard_offsets: List[List[int]] = [ + [0, block_size * rank] for rank in range(num_col_wise_shards) + ] + return shard_lengths, shard_offsets + + +def get_shard_storages( + sharder: ModuleSharder[nn.Module], + sharding_type: str, + tensor: torch.Tensor, + compute_device: str, + compute_kernel: str, + shard_lengths: List[List[int]], + batch_size: int, + world_size: int, + local_world_size: int, + input_lengths: List[float], + caching_ratio: float, +) -> List[Storage]: + input_data_type_size = BIGINT_DTYPE + output_data_type_size = tensor.element_size() + + input_sizes, output_sizes = _get_shard_io_sizes( + sharding_type=sharding_type, + batch_size=batch_size, + world_size=world_size, + local_world_size=local_world_size, + input_lengths=input_lengths, + emb_dim=tensor.shape[1], + shard_lengths=shard_lengths, + input_data_type_size=input_data_type_size, + output_data_type_size=output_data_type_size, + ) + + tensor_storage = sharder.storage_usage(tensor, compute_device, compute_kernel) + hbm_storage: int = tensor_storage.get("hbm", 0) + ddr_storage: int = tensor_storage.get("ddr", 0) + + if compute_kernel == EmbeddingComputeKernel.BATCHED_FUSED_UVM_CACHING.value: + hbm_storage = round(ddr_storage * caching_ratio) + ddr_storage = ddr_storage - hbm_storage + + hbm_specific_sizes: List[int] = _get_storage_specific_sizes( + storage=hbm_storage, + shape=tensor.shape, + shard_lengths=shard_lengths, + sharding_type=sharding_type, + compute_kernel=compute_kernel, + on_device=compute_device == "cuda", + input_sizes=input_sizes, + input_data_type_size=input_data_type_size, + output_data_type_size=output_data_type_size, + ) + ddr_specific_sizes: List[int] = _get_storage_specific_sizes( + storage=ddr_storage, + shape=tensor.shape, + shard_lengths=shard_lengths, + sharding_type=sharding_type, + compute_kernel=compute_kernel, + on_device=compute_device == "cpu", + input_sizes=input_sizes, + input_data_type_size=input_data_type_size, + output_data_type_size=output_data_type_size, + ) + + hbm_sizes: List[int] = [ + input_size + output_size + hbm_specific_size if compute_device == "cuda" else 0 + for input_size, output_size, hbm_specific_size in zip( + input_sizes, + output_sizes, + hbm_specific_sizes, + ) + ] + ddr_sizes: List[int] = [ + input_size + output_size + ddr_specific_size + if compute_device == "cpu" + else ddr_specific_size + for input_size, output_size, ddr_specific_size in zip( + input_sizes, + output_sizes, + ddr_specific_sizes, + ) + ] + + return [ + Storage( + hbm=hbm_size, + ddr=ddr_size, + ) + for hbm_size, ddr_size in zip(hbm_sizes, ddr_sizes) + ] + + +def _get_shard_io_sizes( + sharding_type: str, + batch_size: int, + world_size: int, + local_world_size: int, + input_lengths: List[float], + emb_dim: int, + shard_lengths: List[List[int]], + input_data_type_size: int, + output_data_type_size: int, +) -> Tuple[List[int], List[int]]: + if sharding_type == ShardingType.DATA_PARALLEL.value: + return _get_dp_shard_io_sizes( + batch_size=batch_size, + input_lengths=input_lengths, + emb_dim=emb_dim, + num_shards=len(shard_lengths), + input_data_type_size=input_data_type_size, + output_data_type_size=output_data_type_size, + ) + elif sharding_type == ShardingType.TABLE_WISE.value: + return _get_tw_shard_io_sizes( + batch_size=batch_size, + world_size=world_size, + input_lengths=input_lengths, + emb_dim=emb_dim, + input_data_type_size=input_data_type_size, + output_data_type_size=output_data_type_size, + ) + elif sharding_type == ShardingType.COLUMN_WISE.value: + return _get_cw_shard_io_sizes( + batch_size=batch_size, + world_size=world_size, + input_lengths=input_lengths, + shard_lengths=shard_lengths, + input_data_type_size=input_data_type_size, + output_data_type_size=output_data_type_size, + ) + elif sharding_type == ShardingType.ROW_WISE.value: + return _get_rw_shard_io_sizes( + batch_size=batch_size, + world_size=world_size, + input_lengths=input_lengths, + shard_lengths=shard_lengths, + input_data_type_size=input_data_type_size, + output_data_type_size=output_data_type_size, + ) + elif sharding_type == ShardingType.TABLE_ROW_WISE.value: + return _get_twrw_shard_io_sizes( + batch_size=batch_size, + world_size=world_size, + local_world_size=local_world_size, + input_lengths=input_lengths, + shard_lengths=shard_lengths, + input_data_type_size=input_data_type_size, + output_data_type_size=output_data_type_size, + ) + else: + raise ValueError(f"Unrecognized sharding type provided: {sharding_type}") + + +def _get_dp_shard_io_sizes( + batch_size: int, + input_lengths: List[float], + emb_dim: int, + num_shards: int, + input_data_type_size: int, + output_data_type_size: int, +) -> Tuple[List[int], List[int]]: + input_sizes: List[int] = [ + # pyre-ignore[58] + math.ceil(batch_size * sum(input_lengths) * input_data_type_size) + ] * num_shards + + output_sizes: List[int] = [ + batch_size * emb_dim * len(input_lengths) * output_data_type_size + ] * num_shards + + return input_sizes, output_sizes + + +def _get_tw_shard_io_sizes( + batch_size: int, + world_size: int, + input_lengths: List[float], + emb_dim: int, + input_data_type_size: int, + output_data_type_size: int, +) -> Tuple[List[int], List[int]]: + input_sizes: List[int] = [ + # pyre-ignore[58] + math.ceil(batch_size * world_size * sum(input_lengths) * input_data_type_size) + ] + + output_sizes: List[int] = [ + batch_size * world_size * emb_dim * len(input_lengths) * output_data_type_size + ] + + return input_sizes, output_sizes + + +def _get_cw_shard_io_sizes( + batch_size: int, + world_size: int, + input_lengths: List[float], + shard_lengths: List[List[int]], + input_data_type_size: int, + output_data_type_size: int, +) -> Tuple[List[int], List[int]]: + input_sizes: List[int] = [ + # pyre-ignore[58] + math.ceil(batch_size * world_size * sum(input_lengths) * input_data_type_size) + ] * len(shard_lengths) + + output_sizes: List[int] = [ + ( + batch_size + * world_size + * shard_lengths[i][1] + * len(input_lengths) + * output_data_type_size + ) + for i in range(len(shard_lengths)) + ] + + return input_sizes, output_sizes + + +def _get_rw_shard_io_sizes( + batch_size: int, + world_size: int, + input_lengths: List[float], + shard_lengths: List[List[int]], + input_data_type_size: int, + output_data_type_size: int, +) -> Tuple[List[int], List[int]]: + input_sizes: List[int] = [ + math.ceil( + batch_size + * world_size + # pyre-ignore[58] + * sum(input_lengths) + / world_size + * input_data_type_size + ) + ] * len(shard_lengths) + + output_sizes: List[int] = [ + ( + batch_size + * world_size + * shard_lengths[i][1] + * len(input_lengths) + * output_data_type_size + ) + for i in range(len(shard_lengths)) + ] + + return input_sizes, output_sizes + + +def _get_twrw_shard_io_sizes( + batch_size: int, + world_size: int, + local_world_size: int, + input_lengths: List[float], + shard_lengths: List[List[int]], + input_data_type_size: int, + output_data_type_size: int, +) -> Tuple[List[int], List[int]]: + input_sizes: List[int] = [ + math.ceil( + batch_size + * world_size + # pyre-ignore[58] + * sum(input_lengths) + / local_world_size + * input_data_type_size + ) + ] * len(shard_lengths) + + output_sizes: List[int] = [ + ( + batch_size + * world_size + * shard_lengths[i][1] + * len(input_lengths) + * output_data_type_size + ) + for i in range(len(shard_lengths)) + ] + + return input_sizes, output_sizes + + +def _get_storage_specific_sizes( + storage: int, + shape: torch.Size, + shard_lengths: List[List[int]], + sharding_type: str, + compute_kernel: str, + on_device: bool, + input_sizes: List[int], + input_data_type_size: int, + output_data_type_size: int, +) -> List[int]: + tensor_sizes: List[int] = [ + math.ceil(storage * math.prod(length) / math.prod(shape)) + if sharding_type != ShardingType.DATA_PARALLEL.value + else storage + for length in shard_lengths + ] + + gradient_sizes: List[int] = tensor_sizes + if compute_kernel == EmbeddingComputeKernel.SPARSE.value and on_device: + gradient_sizes = [ + math.ceil( + input_size + * shard_length[1] + * output_data_type_size + / input_data_type_size + ) + for input_size, shard_length in zip(input_sizes, shard_lengths) + ] + + optimizer_sizes: List[int] = [ + tensor_size * 2 if sharding_type == ShardingType.DATA_PARALLEL.value else 0 + for tensor_size in tensor_sizes + ] + + return [ + tensor_size + gradient_size + optimizer_size + for tensor_size, gradient_size, optimizer_size in zip( + tensor_sizes, gradient_sizes, optimizer_sizes + ) + ] diff --git a/torchrec/distributed/planner/new/partitioners.py b/torchrec/distributed/planner/new/partitioners.py new file mode 100644 index 000000000..f59411e1c --- /dev/null +++ b/torchrec/distributed/planner/new/partitioners.py @@ -0,0 +1,262 @@ +#!/usr/bin/env python3 + +from typing import List, Tuple, Optional, Dict, cast + +from torchrec.distributed.planner.new.constants import MAX_SIZE +from torchrec.distributed.planner.new.types import ( + Partitioner, + Topology, + ShardingOption, + Storage, + PartitionByType, + PartitionError, +) + + +def greedy_partition( + num_partitions: int, + sharding_options: List[ShardingOption], + shard_idxes: Optional[List[Tuple[int, int]]] = None, + partition_sums: Optional[List[float]] = None, + mem_cap: Optional[List[Storage]] = None, +) -> List[List[Tuple[int, int]]]: + """ + Divides indexes among `num_parititions` partitions in a greedy + fashion based on cost weights associated with each [option_idx, shard_idx]. + Returns a list of indices of (option_idx, shard_idx) that should be allocated to each partition + + For example if we have sharding_options = [ + [0,1,2,3] with costs [10,20,30,40] + [0,1] with costs [200,300] + ] with num_partitions=3 + + The final output would be + [ + partition_0 = [(1,1)], with a cost of 300 + partition_1 = [(1,0)], with a cost of 200 + partition_2 = [(0,0),(0,1),(0,2)], with a cost of 100 (10+20+30+40) + ] + """ + + if shard_idxes is None: + shard_idxes = [] + for option_idx, sharding_option in enumerate(sharding_options): + for shard_idx in range(sharding_option.num_shards): + shard_idxes.append((option_idx, shard_idx)) + + def to_comparable(order_shard_idx: Tuple[int, int]) -> Tuple[float, Storage]: + sharding_option: ShardingOption = sharding_options[order_shard_idx[0]] + return ( + cast(float, sharding_option.shards[order_shard_idx[1]].cost), + sharding_option.shards[order_shard_idx[1]].storage, + ) + + sorted_shard_idxes = sorted( + shard_idxes, key=lambda order_shard_idx: to_comparable(order_shard_idx) + ) + + partitions = [[] for p in range(num_partitions)] + if partition_sums is None: + partition_sums = [0.0] * num_partitions + + partition_size_sums = [Storage(hbm=0, ddr=0) for _ in range(num_partitions)] + + if mem_cap is None: + mem_cap = [Storage(hbm=MAX_SIZE, ddr=MAX_SIZE) for _ in range(num_partitions)] + + assert len(partition_size_sums) == len( + mem_cap + ), "partition_size_sums and mem_cap must have the same dimensions" + + """ + Successively add remaining pairs to the partition with the + minimum sum. + """ + while sorted_shard_idxes: + option_idx, shard_idx = sorted_shard_idxes.pop() + storage_size = sharding_options[option_idx].shards[shard_idx].storage + cost = cast(float, sharding_options[option_idx].shards[shard_idx].cost) + + min_sum = MAX_SIZE + min_partition_idx = -1 + for partition_idx in range(num_partitions): + partition_mem_cap: Storage = mem_cap[partition_idx] + partition_size_sum: Storage = partition_size_sums[partition_idx] + if ( + partition_mem_cap.hbm >= partition_size_sum.hbm + storage_size.hbm + ) and (partition_mem_cap.ddr >= partition_size_sum.ddr + storage_size.ddr): + if partition_sums[partition_idx] < min_sum: + min_sum = partition_sums[partition_idx] + min_partition_idx = partition_idx + + if min_partition_idx == -1: + raise PartitionError( + f"Table of size {storage_size}GB cannot be added to any rank. partition_size_sums: {partition_size_sums}. mem_cap: {mem_cap}." + ) + + partitions[min_partition_idx].append((option_idx, shard_idx)) + + partition_size_sums[min_partition_idx] += storage_size + partition_sums[min_partition_idx] += cost + + return partitions + + +def uniform_partition( + num_partitions: int, + sharding_options: List[ShardingOption], + mem_cap: List[Storage], +) -> List[List[Tuple[int, int]]]: + """ + We assign one shard to each rank. For example, For example if we have sharding_options = [ + [0,1,2,3], + [0,1,2,3], + ] with num_partitions=4 + The final output would be + [ + partition_0 = [(0,0),(1,0)] + partition_1 = [(0,1),(1,1)] + partition_2 = [(0,2),(1,2)] + partition_3 = [(0,3),(1,3)] + ] + """ + + shard_idxes: List[Tuple[int, int]] = [] + partition_size_sums = [Storage(hbm=0, ddr=0) for _ in range(num_partitions)] + + for option_idx, sharding_option in enumerate(sharding_options): + for shard_idx in range(sharding_option.num_shards): + shard_idxes.append((option_idx, shard_idx)) + + partitions: List[List[Tuple[int, int]]] = [[] for _ in range(num_partitions)] + for option_idx, shard_idx in shard_idxes: + storage_size = sharding_options[option_idx].shards[shard_idx].storage + if partition_size_sums[shard_idx] + storage_size > mem_cap[shard_idx]: + raise PartitionError( + f"Table of size {storage_size}GB cannot be added to any rank. partition_size_sums: {partition_size_sums}. mem_cap: {mem_cap}." + ) + partition_size_sums[shard_idx] += storage_size + partitions[shard_idx].append((option_idx, shard_idx)) + + return partitions + + +def _group_sharding_options( + sharding_options: List[ShardingOption], +) -> Dict[str, List[ShardingOption]]: + partition_by_groups = {} + for sharding_option in sharding_options: + if sharding_option.partition_by not in partition_by_groups: + partition_by_groups[sharding_option.partition_by] = [] + partition_by_groups[sharding_option.partition_by].append(sharding_option) + return partition_by_groups + + +class GreedyCostPartitioner(Partitioner): + """ + Greedy Partitioner + """ + + def run( + self, + sharding_options: List[ShardingOption], + topology: Topology, + ) -> None: + # pyre-ignore[16] + self._topology = topology + grouped_sharding_options = _group_sharding_options(sharding_options) + + if PartitionByType.UNIFORM.value in grouped_sharding_options: + self._partition_by_uniform( + grouped_sharding_options[PartitionByType.UNIFORM.value] + ) + if PartitionByType.HOST.value in grouped_sharding_options: + self._partition_by_host( + grouped_sharding_options[PartitionByType.HOST.value] + ) + if PartitionByType.DEVICE.value in grouped_sharding_options: + self._partition_by_device( + grouped_sharding_options[PartitionByType.DEVICE.value] + ) + + def _partition_by_uniform(self, sharding_options: List[ShardingOption]) -> None: + partitions = uniform_partition( + # pyre-ignore [16]: `GreedyCostPartitioner` has no attribute `_topology`. + num_partitions=self._topology.world_size, + sharding_options=sharding_options, + mem_cap=[device.storage for device in self._topology.devices], + ) + self._update_shards(partitions, sharding_options) + + def _partition_by_device(self, sharding_options: List[ShardingOption]) -> None: + # pyre-ignore [16]: `GreedyCostPartitioner` has no attribute `_topology`. + partition_sums = [float(device.cost) for device in self._topology.devices] + mem_cap: List[Storage] = [device.storage for device in self._topology.devices] + partitions = greedy_partition( + num_partitions=self._topology.world_size, + sharding_options=sharding_options, + partition_sums=partition_sums, + mem_cap=mem_cap, + ) + self._update_shards(partitions, sharding_options) + + def _partition_by_host(self, sharding_options: List[ShardingOption]) -> None: + # pyre-ignore [16]: `GreedyCostPartitioner` has no attribute `_topology`. + num_hosts: int = self._topology.world_size // self._topology.local_world_size + mem_cap: List[Storage] = [] + partition_sums = [] + + shard_idxes = [] + for option_idx, _ in enumerate(sharding_options): + # only take the first shard from each sharding option. We can infer the rest + shard_idxes.append((option_idx, 0)) + + for i in range(num_hosts): + devices_in_host = self._topology.devices[ + i + * self._topology.local_world_size : (i + 1) + * self._topology.local_world_size + ] + + # mem_cap of a host is the min of the storage of all devies on that host + mem_cap.append(min([device.storage for device in devices_in_host])) + # Cost of a host is the sum across all of its devices. Typically this should be zero at entry point. + partition_sums.append( + max([float(device.cost) for device in devices_in_host]) + ) + + host_level_partitions: List[List[Tuple[int, int]]] = greedy_partition( + num_partitions=num_hosts, + sharding_options=sharding_options, + shard_idxes=shard_idxes, + partition_sums=partition_sums, + mem_cap=mem_cap, + ) + partitions: List[List[Tuple[int, int]]] = [[] for _ in self._topology.devices] + for host_idx, host_partition in enumerate(host_level_partitions): + for [option_idx, shard_idx] in host_partition: + # each shard is placed on one device + # host+idx + offset is the device within that host + for offset in range(sharding_options[option_idx].num_shards): + partitions[ + self._topology.local_world_size * host_idx + offset + ].append((option_idx, shard_idx + offset)) + self._update_shards(partitions, sharding_options) + + def _update_shards( + self, + partitions: List[List[Tuple[int, int]]], + sharding_options: List[ShardingOption], + ) -> None: + + # here we update the ranks of the shards as well as device costs + for partition_idx, partition in enumerate(partitions): + for [option_idx, shard_idx] in partition: + sharding_options[option_idx].shards[shard_idx].rank = partition_idx + # pyre-ignore [16]: `GreedyCostPartitioner` has no attribute `_topology`. + self._topology.devices[partition_idx].storage -= ( + sharding_options[option_idx].shards[shard_idx].storage + ) + self._topology.devices[partition_idx].cost += ( + sharding_options[option_idx].shards[shard_idx].cost + ) diff --git a/torchrec/distributed/planner/new/placers.py b/torchrec/distributed/planner/new/placers.py new file mode 100644 index 000000000..203dd9be7 --- /dev/null +++ b/torchrec/distributed/planner/new/placers.py @@ -0,0 +1,200 @@ +#!/usr/bin/env python3 + +import copy +from typing import Optional, List, Tuple, cast + +from torch.distributed._sharding_spec import EnumerableShardingSpec, ShardMetadata +from torchrec.distributed.planner.new.constants import MAX_SIZE +from torchrec.distributed.planner.new.types import ( + Partitioner, + Topology, + ShardingOption, + Placer, + RankStack, + PartitionError, + PlacerStats, + Shard, + Ranker, +) +from torchrec.distributed.types import ShardingPlan, ParameterSharding, ShardingType + + +def _merge_shards_by_dim(shards: List[Shard], dim: int) -> List[Shard]: + # merges shards down to one per rank along dimension. + # Will recompute shard offsets + merged_shards = [] + shards = sorted(shards, key=lambda x: x.rank) + + current_rank = -1 + current_shard: Optional[Shard] = None + current_dim_offset = 0 + for shard in shards: + if shard.rank != current_rank: + current_shard = copy.deepcopy(shard) + current_shard.offset[dim] = current_dim_offset + merged_shards.append(current_shard) + current_rank = shard.rank + else: + # pyre-ignore [16] + current_shard.length[dim] += shard.length[dim] + # pyre-ignore [16] + current_shard.storage += shard.storage + # pyre-ignore [16] + current_shard.cost += shard.cost + current_dim_offset += shard.length[dim] + return merged_shards + + +def _to_sharding_plan( + sharding_options: List[ShardingOption], + topology: Topology, +) -> ShardingPlan: + def _placement( + compute_device: str, + rank: int, + local_size: int, + ) -> str: + param_device = compute_device + if compute_device == "cuda": + param_device = f"cuda:{rank % local_size}" + return f"rank:{rank}/{param_device}" + + compute_device = topology.compute_device + local_size = topology.local_world_size + + plan = {} + for sharding_option in sharding_options: + shards = sharding_option.shards + sharding_type = sharding_option.sharding_type + if sharding_type == ShardingType.COLUMN_WISE.value: + shards = _merge_shards_by_dim(shards, 1) + if len(shards) == 1: + sharding_type = ShardingType.TABLE_WISE.value + + module_plan = plan.get(sharding_option.path, {}) + module_plan[sharding_option.name] = ParameterSharding( + sharding_spec=None + if sharding_type == ShardingType.DATA_PARALLEL.value + else EnumerableShardingSpec( + [ + ShardMetadata( + shard_lengths=shard.length, + shard_offsets=shard.offset, + placement=_placement( + compute_device, cast(int, shard.rank), local_size + ), + ) + for shard in shards + ] + ), + sharding_type=sharding_type, + compute_kernel=sharding_option.compute_kernel, + ranks=[cast(int, shard.rank) for shard in shards], + ) + plan[sharding_option.path] = module_plan + return ShardingPlan(plan) + + +class EmbeddingPlacer(Placer): + def __init__( + self, + topology: Topology, + partitioners: List[Partitioner], + rankers: List[Ranker], + ) -> None: + self._topology = topology + self._partitioners = partitioners + self._rankers = rankers + self._sharding_solution: Optional[List[ShardingOption]] = None + self._topology_solution: Optional[Topology] = None + self._counter = 0 + self._num_errors = 0 + + def run(self, sharding_options: List[ShardingOption]) -> ShardingPlan: + min_cost = MAX_SIZE + sharding_solution = None + topology_solution = None + for ranker in self._rankers: + rank_stack = ranker.run(copy.deepcopy(sharding_options)) + cur_sharding_options = rank_stack.bulk_pop() + while cur_sharding_options: + for partitioner in self._partitioners: + try: + sharding_candidate, topology_candidate = self._partition( + cur_sharding_options, partitioner + ) + cost = max( + [device.cost for device in topology_candidate.devices] + ) + if cost < min_cost: + sharding_solution = sharding_candidate + topology_solution = topology_candidate + min_cost = cost + except PartitionError: + self._num_errors += 1 + + self._counter += 1 + cur_sharding_options = self._backtrack(rank_stack, cur_sharding_options) + + if sharding_solution: + self._sharding_solution = sharding_solution + self._topology_solution = topology_solution + return _to_sharding_plan( + sharding_options=sharding_solution, topology=self._topology + ) + else: + raise PartitionError( + "Unable to find a plan for this model. Possible solutions:\n" + " 1) Increase the number of devices\n" + " 2) Reduce the model size\n" + " 3) Reduce batch size\n" + " 4) Remove planner constraints that might be reducing search space\n" + f"------ attempted {self._counter} iteration(s) ------\n" + ) + + @property + def stats(self) -> PlacerStats: + return PlacerStats( + num_iterations=self._counter, + num_errors=self._num_errors, + topology_solution=self._topology_solution, + sharding_solution=self._sharding_solution, + ) + + def _partition( + self, sharding_options: List[ShardingOption], partitioner: Partitioner + ) -> Tuple[List[ShardingOption], Topology]: + # create a working copy of topology and candidate + topology_candidate = copy.deepcopy(self._topology) + sharding_candidate = copy.deepcopy(sharding_options) + partitioner.run( + sharding_options=sharding_candidate, + topology=topology_candidate, + ) + return sharding_candidate, topology_candidate + + def _backtrack( + self, rank_stack: RankStack, sharding_options: List[ShardingOption] + ) -> List[ShardingOption]: + # attempt to remove sharding option with highest single shard storage cost + sharding_options.sort( + key=lambda x: ( + max([shard.storage.hbm for shard in x.shards]), + sum([shard.storage.hbm for shard in x.shards]), + max([shard.storage.ddr for shard in x.shards]), + sum([shard.storage.ddr for shard in x.shards]), + ), + reverse=True, + ) + idx = 0 + for sharding_option in sharding_options: + if rank_stack.remove(sharding_option): + break + idx += 1 + + if idx < len(sharding_options): + del sharding_options[idx] + sharding_options.append(rank_stack.pop()) + return sharding_options + + return [] diff --git a/torchrec/distributed/planner/new/planners.py b/torchrec/distributed/planner/new/planners.py new file mode 100644 index 000000000..c655271b4 --- /dev/null +++ b/torchrec/distributed/planner/new/planners.py @@ -0,0 +1,119 @@ +#!/usr/bin/env python3 + +from typing import Dict, Optional, List + +import torch.distributed as dist +from torch import nn +from torchrec.distributed.collective_utils import ( + invoke_on_rank_and_broadcast_result, +) +from torchrec.distributed.planner.new.calculators import EmbeddingWTCostCalculator +from torchrec.distributed.planner.new.enumerators import EmbeddingEnumerator +from torchrec.distributed.planner.new.partitioners import GreedyCostPartitioner +from torchrec.distributed.planner.new.placers import EmbeddingPlacer +from torchrec.distributed.planner.new.rankers import DepthRanker, TotalWorkRanker +from torchrec.distributed.planner.new.stats import EmbeddingStats +from torchrec.distributed.planner.new.types import ( + PlannerConstraints, + InputStats, + Enumerator, + Placer, + Ranker, + Calculator, + Partitioner, + Topology, + Stats, +) +from torchrec.distributed.types import ( + ShardingPlan, + ShardingPlanner, + ModuleSharder, +) + + +def _reserve_storage_percentage(topology: Topology, percent: float) -> None: + for device in topology.devices: + device.storage.hbm = int((1 - percent) * device.storage.hbm) + device.storage.ddr = int((1 - percent) * device.storage.ddr) + + +class EmbeddingShardingPlanner(ShardingPlanner): + def __init__( + self, + topology: Topology, + components: Optional[Dict[str, object]] = None, + constraints: Optional[Dict[str, PlannerConstraints]] = None, + input_stats: Optional[Dict[str, InputStats]] = None, + ) -> None: + self._topology = topology + self._input_stats = input_stats + + if components is None: + components = {} + + self._enumerator: Enumerator = components.get( + "enumerator", + EmbeddingEnumerator( + topology=topology, + constraints=constraints, + input_stats=input_stats, + ), + ) + self._calculator: Calculator = components.get( + "calculator", + EmbeddingWTCostCalculator(topology=topology, constraints=constraints), + ) + self._partitioners: List[Partitioner] = components.get( + "paritioners", [GreedyCostPartitioner()] + ) + self._rankers: List[Ranker] = components.get( + "rankers", [DepthRanker(), TotalWorkRanker()] + ) + + self._placer: Placer = components.get( + "placer", + EmbeddingPlacer( + topology=topology, + partitioners=self._partitioners, + rankers=self._rankers, + ), + ) + self._stats: Stats = components.get("stats", EmbeddingStats()) + + def collective_plan( + self, + module: nn.Module, + sharders: List[ModuleSharder[nn.Module]], + pg: dist.ProcessGroup, + ) -> ShardingPlan: + """ + Call self.plan(...) on rank 0 and broadcast + """ + return invoke_on_rank_and_broadcast_result( + pg, + 0, + self.plan, + module, + sharders, + ) + + def plan( + self, + module: nn.Module, + sharders: List[ModuleSharder[nn.Module]], + ) -> ShardingPlan: + + # TODO: actually estimate storage of non-sharded modules: + _reserve_storage_percentage(self._topology, 0.40) + + sharding_options = self._enumerator.run(module=module, sharders=sharders) + self._calculator.run(sharding_options=sharding_options) + sharding_plan = self._placer.run(sharding_options=sharding_options) + self._stats.run( + sharding_plan=sharding_plan, + topology=self._topology, + placer_stats=self._placer.stats, + input_stats=self._input_stats, + ) + + return sharding_plan diff --git a/torchrec/distributed/planner/new/rankers.py b/torchrec/distributed/planner/new/rankers.py new file mode 100644 index 000000000..d7e0af5bd --- /dev/null +++ b/torchrec/distributed/planner/new/rankers.py @@ -0,0 +1,108 @@ +#!/usr/bin/env python3 + +from typing import cast, List, Dict, Set + +from torchrec.distributed.planner.new.types import ( + Ranker, + RankStack, + ShardingOption, +) + + +class MinCostRankStack(RankStack): + """ + RankStack which orders sharding options on cost alone + """ + + def __init__( + self, + sharding_options: List[ShardingOption], + ) -> None: + fqns = {sharding_option.fqn for sharding_option in sharding_options} + + self._options_by_fqn: Dict[str, List[ShardingOption]] = {} + self._valid_fqns: Set[str] = fqns + self._remaining_fqns: Set[str] = set() + self._ordered_fqns: List[str] = [] + for fqn in fqns: + self._options_by_fqn[fqn] = [] + + self.bulk_push(sharding_options) + + def bulk_push(self, sharding_options: List[ShardingOption]) -> None: + seen_fqns = set() + for sharding_option in sharding_options: + seen_fqns.add(sharding_option.fqn) + self._push(sharding_option) + + for fqn in seen_fqns: + self._options_by_fqn[fqn].sort(key=lambda x: -x.cost) + self._reorder() + + def push(self, sharding_option: ShardingOption) -> None: + self._push(sharding_option) + self._options_by_fqn[sharding_option.fqn].sort(key=lambda x: -x.cost) + self._reorder() + + def _push(self, sharding_option: ShardingOption) -> None: + fqn = sharding_option.fqn + assert fqn in self._valid_fqns, f"Attempt to push unknown tensor {fqn}" + self._remaining_fqns.add(fqn) + self._options_by_fqn[fqn].append(sharding_option) + + def _reorder(self) -> None: + options = [ + fqn_options[-1] + for fqn, fqn_options in self._options_by_fqn.items() + if fqn in self._remaining_fqns + ] + options.sort(key=lambda x: -x.cost) + self._ordered_fqns = [option.fqn for option in options] + + def pop(self) -> ShardingOption: + fqn = self._ordered_fqns.pop() + sharding_option = self._options_by_fqn[fqn].pop() + self._remaining_fqns.remove(fqn) + return sharding_option + + def bulk_pop(self) -> List[ShardingOption]: + sharding_options = [] + num_fqns = len(self._ordered_fqns) + for _ in range(num_fqns): + sharding_options.append(self.pop()) + return sharding_options + + def remove(self, sharding_option: ShardingOption) -> bool: + fqn = sharding_option.fqn + assert fqn in self._valid_fqns, f"Attempt to remove unknown tensor {fqn}" + # check if another alternative option exists, if not return false + if not self._options_by_fqn[fqn]: + return False + self._remaining_fqns.add(fqn) + self._reorder() + return True + + def __len__(self) -> int: + return len(self._remaining_fqns) + + +class DepthRanker(Ranker): + def run(self, sharding_options: List[ShardingOption]) -> RankStack: + for sharding_option in sharding_options: + sharding_option.cost = max( + [cast(float, shard.cost) for shard in sharding_option.shards] + ) + return MinCostRankStack( + sharding_options=sharding_options, + ) + + +class TotalWorkRanker(Ranker): + def run(self, sharding_options: List[ShardingOption]) -> RankStack: + for sharding_option in sharding_options: + sharding_option.cost = sum( + [cast(float, shard.cost) for shard in sharding_option.shards] + ) + return MinCostRankStack( + sharding_options=sharding_options, + ) diff --git a/torchrec/distributed/planner/new/stats.py b/torchrec/distributed/planner/new/stats.py new file mode 100644 index 000000000..74608e8f4 --- /dev/null +++ b/torchrec/distributed/planner/new/stats.py @@ -0,0 +1,196 @@ +#!/usr/bin/env python3 + +import logging +from typing import Tuple, Optional, Any, List, Dict + +from tabulate import tabulate +from torchrec.distributed.planner.new.types import ( + PlacerStats, + ShardingOption, + Stats, + Topology, + InputStats, +) +from torchrec.distributed.planner.utils import bytes_to_gb +from torchrec.distributed.types import ShardingType, ParameterSharding, ShardingPlan + + +logger: logging.Logger = logging.getLogger(__name__) + + +STATS_DIVIDER = "####################################################################################################" +STATS_BAR = f"#{'------------------------------------------------------------------------------------------------': ^98}#" + + +class EmbeddingStats(Stats): + def run( + self, + sharding_plan: ShardingPlan, + topology: Topology, + placer_stats: PlacerStats, + input_stats: Optional[Dict[str, InputStats]] = None, + ) -> None: + shard_by_fqn = { + module_name + "." + param_name: value + for module_name, param_dict in sharding_plan.plan.items() + for param_name, value in param_dict.items() + } + stats: Dict[int, Dict[str, Any]] = { + rank: {"type": {}, "pooling_factor": 0.0, "embedding_dims": 0} + for rank in range(topology.world_size) + } + sharding_solution = ( + placer_stats.sharding_solution if placer_stats.sharding_solution else [] + ) + if not placer_stats.topology_solution: + return + topology_solution = placer_stats.topology_solution + used_sharding_types = set() + + for sharding_option in sharding_solution: + fqn = sharding_option.fqn + + if shard_by_fqn.get(fqn) is None: + continue + shard: ParameterSharding = shard_by_fqn[fqn] + + ranks, pooling_factor, emb_dims = self._get_shard_stats( + shard=shard, + sharding_option=sharding_option, + world_size=topology.world_size, + local_size=topology.local_world_size, + input_stats=input_stats, + ) + sharding_type_abbr = _get_sharding_type_abbr(shard.sharding_type) + used_sharding_types.add(sharding_type_abbr) + + for i, rank in enumerate(ranks): + count = stats[rank]["type"].get(sharding_type_abbr, 0) + stats[rank]["type"][sharding_type_abbr] = count + 1 + stats[rank]["pooling_factor"] += pooling_factor[i] + stats[rank]["embedding_dims"] += emb_dims[i] + + table = [] + for rank, (initial_device, solution_device) in enumerate( + zip(topology.devices, topology_solution.devices) + ): + used_hbm = bytes_to_gb( + initial_device.storage.hbm - solution_device.storage.hbm + ) + used_hbm_ratio = ( + 1 - solution_device.storage.hbm / initial_device.storage.hbm + ) + used_ddr = bytes_to_gb( + initial_device.storage.ddr - solution_device.storage.ddr + ) + used_ddr_ratio = ( + 1 - solution_device.storage.ddr / initial_device.storage.ddr + ) + for sharding_type in used_sharding_types: + if sharding_type not in stats[rank]["type"]: + stats[rank]["type"][sharding_type] = 0 + + hbm = f"{used_hbm:.1f} ({used_hbm_ratio:.0%})" + ddr = f"{used_ddr:.1f} ({used_ddr_ratio:.0%})" + cost = f"{solution_device.cost / 1000:,.0f}" + pooling = f"{int(stats[rank]['pooling_factor']):,}" + dims = f"{stats[rank]['embedding_dims']:,}" + shards = " ".join( + f"{sharding_type}: {num_tables}" + for sharding_type, num_tables in sorted(stats[rank]["type"].items()) + ) + table.append([rank, hbm, ddr, cost, pooling, dims, shards]) + + headers = ["Rank", "HBM (GB)", "DDR (GB)", "Cost", "Input", "Output", "Shards"] + table = tabulate(table, headers=headers).split("\n") + + logger.info(STATS_DIVIDER) + header_text = "--- Planner Statistics ---" + logger.info(f"#{header_text: ^98}#") + + num_iterations = placer_stats.num_iterations + num_errors = placer_stats.num_errors + iter_text = ( + f"--- Ran {num_iterations} iteration(s), " + f"found {num_iterations - num_errors} possible plan(s) ---" + ) + logger.info(f"#{iter_text: ^98}#") + logger.info(STATS_BAR) + + for row in table: + logger.info(f"# {row: <97}#") + + logger.info(f"#{'' : ^98}#") + legend = "Input: pooling factor, Output: embedding dimension, Shards: number of tables" + logger.info(f"# {legend: <97}#") + logger.info(STATS_DIVIDER) + + def _get_shard_stats( + self, + shard: ParameterSharding, + sharding_option: ShardingOption, + world_size: int, + local_size: int, + input_stats: Optional[Dict[str, InputStats]] = None, + ) -> Tuple[List[int], List[float], List[int]]: + """ + Gets ranks, pooling factors, and embedding dimensions per shard + + Returns: + ranks: list of ranks + pooling_factor: list of pooling factors across ranks + emb_dims: list of embedding dimensions across ranks + """ + ranks = list(range(world_size)) + pooling_factor = [ + sum(input_stats[sharding_option.name].pooling_factors) + if input_stats and input_stats.get(sharding_option.name) + else 0.0 + ] + emb_dims = [sharding_option.tensor.shape[1]] + + if shard.sharding_type == ShardingType.DATA_PARALLEL.value: + emb_dims = emb_dims * len(ranks) + pooling_factor = pooling_factor * len(ranks) + + elif shard.sharding_type == ShardingType.TABLE_WISE.value: + assert shard.ranks + ranks = shard.ranks + + elif shard.sharding_type == ShardingType.COLUMN_WISE.value: + assert shard.ranks + ranks = shard.ranks + emb_dims = [ + int(shard.shard_lengths[1]) + # pyre-ignore [16] + for shard in shard.sharding_spec.shards + ] + pooling_factor = pooling_factor * len(ranks) + + elif shard.sharding_type == ShardingType.ROW_WISE.value: + pooling_factor = [pooling_factor[0] / world_size] * len(ranks) + emb_dims = emb_dims * len(ranks) + + elif shard.sharding_type == ShardingType.TABLE_ROW_WISE.value: + assert shard.ranks + host_id = shard.ranks[0] // local_size + ranks = list(range(host_id * local_size, (host_id + 1) * local_size)) + pooling_factor = [pooling_factor[0] / local_size] * len(ranks) + emb_dims = emb_dims * len(ranks) + + return ranks, pooling_factor, emb_dims + + +def _get_sharding_type_abbr(sharding_type: str) -> str: + if sharding_type == ShardingType.DATA_PARALLEL.value: + return "DP" + elif sharding_type == ShardingType.TABLE_WISE.value: + return "TW" + elif sharding_type == ShardingType.COLUMN_WISE.value: + return "CW" + elif sharding_type == ShardingType.ROW_WISE.value: + return "RW" + elif sharding_type == ShardingType.TABLE_ROW_WISE.value: + return "TWRW" + else: + raise ValueError(f"Unrecognized sharding type provided: {sharding_type}") diff --git a/torchrec/distributed/planner/new/tests/test_calculators.py b/torchrec/distributed/planner/new/tests/test_calculators.py new file mode 100644 index 000000000..37f8f6c4e --- /dev/null +++ b/torchrec/distributed/planner/new/tests/test_calculators.py @@ -0,0 +1,147 @@ +#!/usr/bin/env python3 + +import unittest + +from torchrec.distributed.embedding_types import EmbeddingTableConfig +from torchrec.distributed.embeddingbag import ( + EmbeddingBagCollectionSharder, +) +from torchrec.distributed.planner.new.calculators import EmbeddingWTCostCalculator +from torchrec.distributed.planner.new.enumerators import EmbeddingEnumerator +from torchrec.distributed.planner.new.types import Topology +from torchrec.distributed.tests.test_model import TestSparseNN +from torchrec.fb.distributed.pooled_embedding_arch import PooledEmbeddingArchSharder +from torchrec.fb.modules.embedding_arch import PooledEmbeddingArch +from torchrec.modules.embedding_configs import EmbeddingBagConfig + + +class TestEmbeddingWTCostCalculator(unittest.TestCase): + def setUp(self) -> None: + topology = Topology(world_size=2, compute_device="cuda") + self.enumerator = EmbeddingEnumerator(topology=topology) + self.calculator = EmbeddingWTCostCalculator(topology=topology) + + def test_1_table_cost(self) -> None: + tables = [ + EmbeddingBagConfig( + num_embeddings=100, + embedding_dim=10, + name="table_0", + feature_names=["feature_0"], + ) + ] + model = TestSparseNN(tables=tables, weighted_tables=[]) + sharding_options = self.enumerator.run( + module=model, sharders=[EmbeddingBagCollectionSharder()] + ) + + self.calculator.run(sharding_options=sharding_options) + expected_costs = { + ("dense", "data_parallel"): [398.5666507405638, 398.5666507405638], + ("batched_dense", "data_parallel"): [378.9966555183946, 378.9966555183946], + ("dense", "table_wise"): [3543.7999681477945], + ("batched_dense", "table_wise"): [3504.659977703456], + ("batched_fused", "table_wise"): [3458.9966555183946], + ("sparse", "table_wise"): [3543.7999681477945], + ("batched_fused_uvm", "table_wise"): [83727.05882352941], + ("batched_fused_uvm_caching", "table_wise"): [22014.604904632153], + ("dense", "row_wise"): [3478.566650740564, 3478.566650740564], + ("batched_dense", "row_wise"): [3458.9966555183946, 3458.9966555183946], + ("batched_fused", "row_wise"): [3436.1649944258643, 3436.1649944258643], + ("sparse", "row_wise"): [3478.566650740564, 3478.566650740564], + ("batched_fused_uvm", "row_wise"): [43570.19607843138, 43570.19607843138], + ("batched_fused_uvm_caching", "row_wise"): [ + 12713.969118982744, + 12713.969118982744, + ], + ("dense", "table_row_wise"): [3546.833317407231, 3546.833317407231], + ("batched_dense", "table_row_wise"): [ + 3527.2633221850615, + 3527.2633221850615, + ], + ("batched_fused", "table_row_wise"): [3504.431661092531, 3504.431661092531], + ("sparse", "table_row_wise"): [3546.833317407231, 3546.833317407231], + ("batched_fused_uvm", "table_row_wise"): [ + 43638.46274509804, + 43638.46274509804, + ], + ("batched_fused_uvm_caching", "table_row_wise"): [ + 12782.23578564941, + 12782.23578564941, + ], + } + costs = { + ( + sharding_option.compute_kernel, + sharding_option.sharding_type, + ): [shard.cost for shard in sharding_option.shards] + for sharding_option in sharding_options + } + self.assertEqual(expected_costs, costs) + + def test_2_table_cost_for_pooledEmbArch_model(self) -> None: + tables = [ + EmbeddingTableConfig( + num_embeddings=100, + embedding_dim=10, + name="table_1", + feature_names=["feature_1"], + ) + ] + model = PooledEmbeddingArch( + tables=tables, + embedding_groups={"group_1": ["feature_1"]}, + ) + + sharding_options = self.enumerator.run( + module=model, sharders=[PooledEmbeddingArchSharder()] + ) + self.calculator.run(sharding_options=sharding_options) + expected_costs = { + ("dense", "data_parallel"): [398.5666507405638, 398.5666507405638], + ("batched_dense", "data_parallel"): [378.9966555183946, 378.9966555183946], + ("dense", "table_wise"): [3543.7999681477945], + ("batched_dense", "table_wise"): [3504.659977703456], + ("batched_fused", "table_wise"): [3458.9966555183946], + ("sparse", "table_wise"): [3543.7999681477945], + ("batched_fused_uvm", "table_wise"): [83727.05882352941], + ("batched_fused_uvm_caching", "table_wise"): [22014.604904632153], + ("dense", "table_row_wise"): [3546.833317407231, 3546.833317407231], + ("batched_dense", "table_row_wise"): [ + 3527.2633221850615, + 3527.2633221850615, + ], + ("batched_fused", "table_row_wise"): [3504.431661092531, 3504.431661092531], + ("sparse", "table_row_wise"): [3546.833317407231, 3546.833317407231], + ("batched_fused_uvm", "table_row_wise"): [ + 43638.46274509804, + 43638.46274509804, + ], + ("batched_fused_uvm_caching", "table_row_wise"): [ + 12782.23578564941, + 12782.23578564941, + ], + ("dense", "row_wise"): [3478.566650740564, 3478.566650740564], + ("batched_dense", "row_wise"): [3458.9966555183946, 3458.9966555183946], + ("batched_fused", "row_wise"): [3436.1649944258643, 3436.1649944258643], + ("sparse", "row_wise"): [3478.566650740564, 3478.566650740564], + ("batched_fused_uvm", "row_wise"): [43570.19607843138, 43570.19607843138], + ("batched_fused_uvm_caching", "row_wise"): [ + 12713.969118982744, + 12713.969118982744, + ], + ("dense", "column_wise"): [3543.7999681477945], + ("batched_dense", "column_wise"): [3504.659977703456], + ("batched_fused", "column_wise"): [3458.9966555183946], + ("sparse", "column_wise"): [3543.7999681477945], + ("batched_fused_uvm", "column_wise"): [83727.05882352941], + ("batched_fused_uvm_caching", "column_wise"): [22014.604904632153], + } + costs = { + ( + sharding_option.compute_kernel, + sharding_option.sharding_type, + ): [shard.cost for shard in sharding_option.shards] + for sharding_option in sharding_options + } + self.assertEqual(expected_costs, costs) diff --git a/torchrec/distributed/planner/new/tests/test_enumerators.py b/torchrec/distributed/planner/new/tests/test_enumerators.py new file mode 100644 index 000000000..026878a9b --- /dev/null +++ b/torchrec/distributed/planner/new/tests/test_enumerators.py @@ -0,0 +1,586 @@ +#!/usr/bin/env python3 + +import math +import unittest +from typing import List + +from torchrec.distributed.embedding_types import ( + EmbeddingComputeKernel, +) +from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder +from torchrec.distributed.planner.new.constants import ( + BIGINT_DTYPE, +) +from torchrec.distributed.planner.new.enumerators import ( + EmbeddingEnumerator, + _get_tw_shard_io_sizes, + _get_dp_shard_io_sizes, +) +from torchrec.distributed.planner.new.types import ( + InputStats, + PlannerConstraints, + Storage, + Topology, +) +from torchrec.distributed.tests.test_model import TestSparseNN +from torchrec.distributed.types import ShardingType +from torchrec.modules.embedding_configs import EmbeddingBagConfig +from torchrec.modules.embedding_modules import EmbeddingBagCollection + + +EXPECTED_RW_SHARD_LENGTHS = [ + [[13, 10], [13, 10], [13, 10], [13, 10], [13, 10], [13, 10], [13, 10], [9, 10]], + [[14, 20], [14, 20], [14, 20], [14, 20], [14, 20], [14, 20], [14, 20], [12, 20]], + [[15, 30], [15, 30], [15, 30], [15, 30], [15, 30], [15, 30], [15, 30], [15, 30]], + [[17, 40], [17, 40], [17, 40], [17, 40], [17, 40], [17, 40], [17, 40], [11, 40]], +] + +EXPECTED_RW_SHARD_OFFSETS = [ + [[0, 0], [13, 0], [26, 0], [39, 0], [52, 0], [65, 0], [78, 0], [91, 0]], + [[0, 0], [14, 0], [28, 0], [42, 0], [56, 0], [70, 0], [84, 0], [98, 0]], + [[0, 0], [15, 0], [30, 0], [45, 0], [60, 0], [75, 0], [90, 0], [105, 0]], + [[0, 0], [17, 0], [34, 0], [51, 0], [68, 0], [85, 0], [102, 0], [119, 0]], +] + +EXPECTED_RW_SHARD_STORAGE = [ + [ + Storage(hbm=85008, ddr=0), + Storage(hbm=85008, ddr=0), + Storage(hbm=85008, ddr=0), + Storage(hbm=85008, ddr=0), + Storage(hbm=85008, ddr=0), + Storage(hbm=85008, ddr=0), + Storage(hbm=85008, ddr=0), + Storage(hbm=84688, ddr=0), + ], + [ + Storage(hbm=512192, ddr=0), + Storage(hbm=512192, ddr=0), + Storage(hbm=512192, ddr=0), + Storage(hbm=512192, ddr=0), + Storage(hbm=512192, ddr=0), + Storage(hbm=512192, ddr=0), + Storage(hbm=512192, ddr=0), + Storage(hbm=511872, ddr=0), + ], + [ + Storage(hbm=515600, ddr=0), + Storage(hbm=515600, ddr=0), + Storage(hbm=515600, ddr=0), + Storage(hbm=515600, ddr=0), + Storage(hbm=515600, ddr=0), + Storage(hbm=515600, ddr=0), + Storage(hbm=515600, ddr=0), + Storage(hbm=515600, ddr=0), + ], + [ + Storage(hbm=1342784, ddr=0), + Storage(hbm=1342784, ddr=0), + Storage(hbm=1342784, ddr=0), + Storage(hbm=1342784, ddr=0), + Storage(hbm=1342784, ddr=0), + Storage(hbm=1342784, ddr=0), + Storage(hbm=1342784, ddr=0), + Storage(hbm=1340864, ddr=0), + ], +] + + +EXPECTED_UVM_CACHING_RW_SHARD_STORAGE = [ + [ + Storage(hbm=84176, ddr=832), + Storage(hbm=84176, ddr=832), + Storage(hbm=84176, ddr=832), + Storage(hbm=84176, ddr=832), + Storage(hbm=84176, ddr=832), + Storage(hbm=84176, ddr=832), + Storage(hbm=84176, ddr=832), + Storage(hbm=84112, ddr=576), + ], + [ + Storage(hbm=510400, ddr=1792), + Storage(hbm=510400, ddr=1792), + Storage(hbm=510400, ddr=1792), + Storage(hbm=510400, ddr=1792), + Storage(hbm=510400, ddr=1792), + Storage(hbm=510400, ddr=1792), + Storage(hbm=510400, ddr=1792), + Storage(hbm=510336, ddr=1536), + ], + [ + Storage(hbm=513296, ddr=2304), + Storage(hbm=513296, ddr=2304), + Storage(hbm=513296, ddr=2304), + Storage(hbm=513296, ddr=2304), + Storage(hbm=513296, ddr=2304), + Storage(hbm=513296, ddr=2304), + Storage(hbm=513296, ddr=2304), + Storage(hbm=513296, ddr=2304), + ], + [ + Storage(hbm=1341968, ddr=816), + Storage(hbm=1341968, ddr=816), + Storage(hbm=1341968, ddr=816), + Storage(hbm=1341968, ddr=816), + Storage(hbm=1341968, ddr=816), + Storage(hbm=1341968, ddr=816), + Storage(hbm=1341968, ddr=816), + Storage(hbm=1340336, ddr=528), + ], +] + + +EXPECTED_TWRW_SHARD_LENGTHS = [ + [[25, 10], [25, 10], [25, 10], [25, 10]], + [[28, 20], [28, 20], [28, 20], [26, 20]], + [[30, 30], [30, 30], [30, 30], [30, 30]], + [[33, 40], [33, 40], [33, 40], [31, 40]], +] + +EXPECTED_TWRW_SHARD_OFFSETS = [ + [[0, 0], [25, 0], [50, 0], [75, 0]], + [[0, 0], [28, 0], [56, 0], [84, 0]], + [[0, 0], [30, 0], [60, 0], [90, 0]], + [[0, 0], [33, 0], [66, 0], [99, 0]], +] + +EXPECTED_TWRW_SHARD_STORAGE = [ + [ + Storage(hbm=88016, ddr=0), + Storage(hbm=88016, ddr=0), + Storage(hbm=88016, ddr=0), + Storage(hbm=88016, ddr=0), + ], + [ + Storage(hbm=532864, ddr=0), + Storage(hbm=532864, ddr=0), + Storage(hbm=532864, ddr=0), + Storage(hbm=532544, ddr=0), + ], + [ + Storage(hbm=539680, ddr=0), + Storage(hbm=539680, ddr=0), + Storage(hbm=539680, ddr=0), + Storage(hbm=539680, ddr=0), + ], + [ + Storage(hbm=1374528, ddr=0), + Storage(hbm=1374528, ddr=0), + Storage(hbm=1374528, ddr=0), + Storage(hbm=1373888, ddr=0), + ], +] + +EXPECTED_CW_SHARD_LENGTHS = [ + [[100, 10]], + [[110, 8], [110, 12]], + [[120, 9], [120, 9], [120, 12]], + [[130, 12], [130, 12], [130, 16]], +] + +EXPECTED_CW_SHARD_OFFSETS = [ + [[0, 0]], + [[0, 0], [0, 8]], + [[0, 0], [0, 9], [0, 18]], + [[0, 0], [0, 12], [0, 24]], +] + +EXPECTED_CW_SHARD_STORAGE = [ + [Storage(hbm=106304, ddr=0)], + [Storage(hbm=351104, ddr=0), Storage(hbm=452928, ddr=0)], + [ + Storage(hbm=319936, ddr=0), + Storage(hbm=319936, ddr=0), + Storage(hbm=371968, ddr=0), + ], + [ + Storage(hbm=618688, ddr=0), + Storage(hbm=618688, ddr=0), + Storage(hbm=753920, ddr=0), + ], +] + + +class TWSharder(EmbeddingBagCollectionSharder[EmbeddingBagCollection]): + def sharding_types(self, compute_device_type: str) -> List[str]: + return [ShardingType.TABLE_WISE.value] + + def compute_kernels( + self, sharding_type: str, compute_device_type: str + ) -> List[str]: + return [EmbeddingComputeKernel.DENSE.value, EmbeddingComputeKernel.SPARSE.value] + + +class RWSharder(EmbeddingBagCollectionSharder[EmbeddingBagCollection]): + def sharding_types(self, compute_device_type: str) -> List[str]: + return [ShardingType.ROW_WISE.value] + + def compute_kernels( + self, sharding_type: str, compute_device_type: str + ) -> List[str]: + return [EmbeddingComputeKernel.DENSE.value] + + +class UVMCachingRWSharder(EmbeddingBagCollectionSharder[EmbeddingBagCollection]): + def sharding_types(self, compute_device_type: str) -> List[str]: + return [ShardingType.ROW_WISE.value] + + def compute_kernels( + self, sharding_type: str, compute_device_type: str + ) -> List[str]: + return [EmbeddingComputeKernel.BATCHED_FUSED_UVM_CACHING.value] + + +class TWRWSharder(EmbeddingBagCollectionSharder[EmbeddingBagCollection]): + def sharding_types(self, compute_device_type: str) -> List[str]: + return [ShardingType.TABLE_ROW_WISE.value] + + def compute_kernels( + self, sharding_type: str, compute_device_type: str + ) -> List[str]: + return [EmbeddingComputeKernel.DENSE.value] + + +class CWSharder(EmbeddingBagCollectionSharder[EmbeddingBagCollection]): + def sharding_types(self, compute_device_type: str) -> List[str]: + return [ShardingType.COLUMN_WISE.value] + + def compute_kernels( + self, sharding_type: str, compute_device_type: str + ) -> List[str]: + return [EmbeddingComputeKernel.DENSE.value] + + +class DPSharder(EmbeddingBagCollectionSharder[EmbeddingBagCollection]): + def sharding_types(self, compute_device_type: str) -> List[str]: + return [ShardingType.DATA_PARALLEL.value] + + def compute_kernels( + self, sharding_type: str, compute_device_type: str + ) -> List[str]: + return [EmbeddingComputeKernel.DENSE.value, EmbeddingComputeKernel.SPARSE.value] + + +class AllTypesSharder(EmbeddingBagCollectionSharder[EmbeddingBagCollection]): + def sharding_types(self, compute_device_type: str) -> List[str]: + return [ + ShardingType.DATA_PARALLEL.value, + ShardingType.TABLE_WISE.value, + ShardingType.ROW_WISE.value, + ShardingType.TABLE_ROW_WISE.value, + ShardingType.COLUMN_WISE.value, + ] + + def compute_kernels( + self, sharding_type: str, compute_device_type: str + ) -> List[str]: + return [ + EmbeddingComputeKernel.DENSE.value, + EmbeddingComputeKernel.SPARSE.value, + EmbeddingComputeKernel.BATCHED_DENSE.value, + EmbeddingComputeKernel.BATCHED_FUSED.value, + EmbeddingComputeKernel.BATCHED_FUSED_UVM.value, + EmbeddingComputeKernel.BATCHED_FUSED_UVM_CACHING.value, + EmbeddingComputeKernel.BATCHED_QUANT.value, + ] + + +class TestEnumerators(unittest.TestCase): + def setUp(self) -> None: + self.compute_device = "cuda" + self.batch_size = 256 + self.world_size = 8 + self.local_world_size = 4 + self.constraints = { + "table_0": PlannerConstraints(min_partition=20), + "table_1": PlannerConstraints(min_partition=8), + "table_2": PlannerConstraints(min_partition=9, caching_ratio=0.36), + "table_3": PlannerConstraints(min_partition=12, caching_ratio=0.85), + } + self.input_stats = { + "table_0": InputStats(), + "table_1": InputStats(pooling_factors=[1, 3, 5]), + "table_2": InputStats(pooling_factors=[8, 2]), + "table_3": InputStats(pooling_factors=[2, 1, 3, 7]), + } + self.num_tables = 4 + tables = [ + EmbeddingBagConfig( + num_embeddings=100 + i * 10, + embedding_dim=10 + i * 10, + name="table_" + str(i), + feature_names=["feature_" + str(i)], + ) + for i in range(self.num_tables) + ] + self.model = TestSparseNN(tables=tables, weighted_tables=[]) + self.enumerator = EmbeddingEnumerator( + topology=Topology( + world_size=self.world_size, + compute_device=self.compute_device, + local_world_size=self.local_world_size, + batch_size=self.batch_size, + ), + constraints=self.constraints, + input_stats=self.input_stats, + ) + + def test_dp_sharding(self) -> None: + # pyre-ignore[6] + sharding_options = self.enumerator.run(self.model, [DPSharder()]) + + for sharding_option in sharding_options: + self.assertEqual( + sharding_option.sharding_type, ShardingType.DATA_PARALLEL.value + ) + self.assertEqual( + [shard.length for shard in sharding_option.shards], + [list(sharding_option.tensor.shape)] * self.world_size, + ) + self.assertEqual( + [shard.offset for shard in sharding_option.shards], + [[0, 0]] * self.world_size, + ) + + input_data_type_size = BIGINT_DTYPE + output_data_type_size = sharding_option.tensor.element_size() + + input_sizes, output_sizes = _get_dp_shard_io_sizes( + batch_size=self.batch_size, + input_lengths=self.input_stats[sharding_option.name].pooling_factors, + emb_dim=sharding_option.tensor.shape[1], + num_shards=self.world_size, + input_data_type_size=input_data_type_size, + output_data_type_size=output_data_type_size, + ) + + tensor_sizes = [ + math.prod(sharding_option.tensor.shape) + * sharding_option.tensor.element_size() + ] * self.world_size + + gradient_sizes = ( + [ + input_sizes[0] + * sharding_option.tensor.shape[1] + * output_data_type_size + / input_data_type_size + ] + * self.world_size + if sharding_option.compute_kernel == EmbeddingComputeKernel.SPARSE.value + else tensor_sizes + ) + + optimizer_sizes = [tensor_size * 2 for tensor_size in tensor_sizes] + + storage_sizes = [ + input_size + tensor_size + output_size + gradient_size + optimizer_size + for input_size, tensor_size, output_size, gradient_size, optimizer_size in zip( + input_sizes, + tensor_sizes, + output_sizes, + gradient_sizes, + optimizer_sizes, + ) + ] + + expected_storage = [ + Storage(hbm=storage_size, ddr=0) for storage_size in storage_sizes + ] + + self.assertEqual( + [shard.storage for shard in sharding_option.shards], expected_storage + ) + + def test_tw_sharding(self) -> None: + # pyre-ignore[6] + sharding_options = self.enumerator.run(self.model, [TWSharder()]) + + for sharding_option in sharding_options: + self.assertEqual( + sharding_option.sharding_type, ShardingType.TABLE_WISE.value + ) + self.assertEqual( + sharding_option.shards[0].length, list(sharding_option.tensor.shape) + ) + self.assertEqual(sharding_option.shards[0].offset, [0, 0]) + + input_data_type_size = BIGINT_DTYPE + output_data_type_size = sharding_option.tensor.element_size() + + input_sizes, output_sizes = _get_tw_shard_io_sizes( + batch_size=self.batch_size, + world_size=self.world_size, + input_lengths=self.input_stats[sharding_option.name].pooling_factors, + emb_dim=sharding_option.tensor.shape[1], + input_data_type_size=input_data_type_size, + output_data_type_size=output_data_type_size, + ) + + tensor_size = ( + math.prod(sharding_option.tensor.shape) + * sharding_option.tensor.element_size() + ) + gradient_size = ( + ( + input_sizes[0] + * sharding_option.tensor.shape[1] + * output_data_type_size + / input_data_type_size + ) + if sharding_option.compute_kernel == EmbeddingComputeKernel.SPARSE.value + else tensor_size + ) + optimizer_size = 0 + + storage_size = ( + input_sizes[0] + + output_sizes[0] + + tensor_size + + gradient_size + + optimizer_size + ) + + self.assertEqual( + sharding_option.shards[0].storage, Storage(hbm=storage_size, ddr=0) + ) + + def test_rw_sharding(self) -> None: + # pyre-ignore[6] + sharding_options = self.enumerator.run(self.model, [RWSharder()]) + + for i, sharding_option in enumerate(sharding_options): + self.assertEqual(sharding_option.sharding_type, ShardingType.ROW_WISE.value) + self.assertEqual( + [shard.length for shard in sharding_option.shards], + EXPECTED_RW_SHARD_LENGTHS[i], + ) + self.assertEqual( + [shard.offset for shard in sharding_option.shards], + EXPECTED_RW_SHARD_OFFSETS[i], + ) + self.assertEqual( + [shard.storage for shard in sharding_option.shards], + EXPECTED_RW_SHARD_STORAGE[i], + ) + + def test_uvm_caching_rw_sharding(self) -> None: + # pyre-ignore[6] + sharding_options = self.enumerator.run(self.model, [UVMCachingRWSharder()]) + for i, sharding_option in enumerate(sharding_options): + self.assertEqual(sharding_option.sharding_type, ShardingType.ROW_WISE.value) + self.assertEqual( + [shard.length for shard in sharding_option.shards], + EXPECTED_RW_SHARD_LENGTHS[i], + ) + self.assertEqual( + [shard.offset for shard in sharding_option.shards], + EXPECTED_RW_SHARD_OFFSETS[i], + ) + self.assertEqual( + [shard.storage for shard in sharding_option.shards], + EXPECTED_UVM_CACHING_RW_SHARD_STORAGE[i], + ) + + def test_twrw_sharding(self) -> None: + # pyre-ignore[6] + sharding_options = self.enumerator.run(self.model, [TWRWSharder()]) + + for i, sharding_option in enumerate(sharding_options): + self.assertEqual( + sharding_option.sharding_type, ShardingType.TABLE_ROW_WISE.value + ) + self.assertEqual( + [shard.length for shard in sharding_option.shards], + EXPECTED_TWRW_SHARD_LENGTHS[i], + ) + self.assertEqual( + [shard.offset for shard in sharding_option.shards], + EXPECTED_TWRW_SHARD_OFFSETS[i], + ) + self.assertEqual( + [shard.storage for shard in sharding_option.shards], + EXPECTED_TWRW_SHARD_STORAGE[i], + ) + + def test_cw_sharding(self) -> None: + # pyre-ignore[6] + sharding_options = self.enumerator.run(self.model, [CWSharder()]) + + for i, sharding_option in enumerate(sharding_options): + self.assertEqual( + sharding_option.sharding_type, ShardingType.COLUMN_WISE.value + ) + self.assertEqual( + [shard.length for shard in sharding_option.shards], + EXPECTED_CW_SHARD_LENGTHS[i], + ) + self.assertEqual( + [shard.offset for shard in sharding_option.shards], + EXPECTED_CW_SHARD_OFFSETS[i], + ) + self.assertEqual( + [shard.storage for shard in sharding_option.shards], + EXPECTED_CW_SHARD_STORAGE[i], + ) + + def test_filtering(self) -> None: + constraint = PlannerConstraints( + sharding_types=[ + ShardingType.TABLE_ROW_WISE.value, + ShardingType.COLUMN_WISE.value, + ], + compute_kernels=[ + EmbeddingComputeKernel.SPARSE.value, + EmbeddingComputeKernel.BATCHED_FUSED_UVM.value, + EmbeddingComputeKernel.BATCHED_QUANT.value, + ], + ) + constraints = { + "table_0": constraint, + "table_1": constraint, + "table_2": constraint, + "table_3": constraint, + } + + enumerator = EmbeddingEnumerator( + topology=Topology( + world_size=self.world_size, + compute_device=self.compute_device, + local_world_size=self.local_world_size, + batch_size=self.batch_size, + ), + constraints=constraints, + ) + sharder = AllTypesSharder() + # pyre-ignore[6] + sharding_options = enumerator.run(self.model, [sharder]) + + expected_sharding_types = { + ShardingType.TABLE_ROW_WISE.value, + ShardingType.COLUMN_WISE.value, + } + expected_compute_kernels = { + EmbeddingComputeKernel.SPARSE.value, + EmbeddingComputeKernel.BATCHED_FUSED_UVM.value, + EmbeddingComputeKernel.BATCHED_QUANT.value, + } + unexpected_sharding_types = ( + set(sharder.sharding_types(self.compute_device)) - expected_sharding_types + ) + unexpected_compute_kernels = ( + set(sharder.compute_kernels("", "")) - expected_compute_kernels + ) + + self.assertEqual( + len(sharding_options), + self.num_tables + * len(expected_sharding_types) + * len(expected_compute_kernels), + ) + + for sharding_option in sharding_options: + self.assertIn(sharding_option.sharding_type, expected_sharding_types) + self.assertNotIn(sharding_option.sharding_type, unexpected_sharding_types) + self.assertIn(sharding_option.compute_kernel, expected_compute_kernels) + self.assertNotIn(sharding_option.compute_kernel, unexpected_compute_kernels) diff --git a/torchrec/distributed/planner/new/tests/test_partitioners.py b/torchrec/distributed/planner/new/tests/test_partitioners.py new file mode 100644 index 000000000..7a9df6891 --- /dev/null +++ b/torchrec/distributed/planner/new/tests/test_partitioners.py @@ -0,0 +1,259 @@ +#!/usr/bin/env python3 + +import copy +import unittest +from typing import List + +from torch import nn +from torchrec.distributed.embedding_types import EmbeddingComputeKernel +from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder +from torchrec.distributed.planner.new.enumerators import EmbeddingEnumerator +from torchrec.distributed.planner.new.partitioners import GreedyCostPartitioner +from torchrec.distributed.planner.new.types import Storage, Topology, PartitionByType +from torchrec.distributed.tests.test_model import TestSparseNN +from torchrec.distributed.types import ModuleSharder, ShardingType +from torchrec.modules.embedding_configs import EmbeddingBagConfig +from torchrec.modules.embedding_modules import ( + EmbeddingBagCollection, +) + + +class RWSharder( + EmbeddingBagCollectionSharder[EmbeddingBagCollection], ModuleSharder[nn.Module] +): + def sharding_types(self, compute_device_type: str) -> List[str]: + return [ShardingType.ROW_WISE.value] + + def compute_kernels( + self, sharding_type: str, compute_device_type: str + ) -> List[str]: + return [EmbeddingComputeKernel.DENSE.value] + + +class TWSharder( + EmbeddingBagCollectionSharder[EmbeddingBagCollection], ModuleSharder[nn.Module] +): + def sharding_types(self, compute_device_type: str) -> List[str]: + return [ShardingType.TABLE_WISE.value] + + def compute_kernels( + self, sharding_type: str, compute_device_type: str + ) -> List[str]: + return [EmbeddingComputeKernel.DENSE.value] + + +class TWRWSharder( + EmbeddingBagCollectionSharder[EmbeddingBagCollection], ModuleSharder[nn.Module] +): + def sharding_types(self, compute_device_type: str) -> List[str]: + return [ShardingType.TABLE_ROW_WISE.value] + + def compute_kernels( + self, sharding_type: str, compute_device_type: str + ) -> List[str]: + return [EmbeddingComputeKernel.DENSE.value] + + +class TestGreedyCostPartitioner(unittest.TestCase): + def setUp(self) -> None: + compute_device = "cuda" + self.topology = Topology(world_size=2, compute_device=compute_device) + tables = [ + EmbeddingBagConfig( + num_embeddings=100 + i, + embedding_dim=10 + i, + name="table_" + str(i), + feature_names=["feature_" + str(i)], + ) + for i in range(4) + ] + self.topology = Topology(world_size=2, compute_device=compute_device) + self.model = TestSparseNN(tables=tables, weighted_tables=[]) + self.enumerator = EmbeddingEnumerator(topology=self.topology) + self.partitioner = GreedyCostPartitioner() + + def test_tw_balanced_cost_device(self) -> None: + sharding_options = self.enumerator.run( + module=self.model, sharders=[TWSharder()] + ) + + for sharding_option in sharding_options: + sharding_option.cost = 100 + sharding_option.shards[0].cost = 100 + sharding_option.shards[0].storage = Storage(hbm=1000, ddr=1000) + + candidate_topology = copy.deepcopy(self.topology) + self.partitioner.run( + sharding_options=sharding_options, + topology=candidate_topology, + ) + expected_ranks = { + "table_0": [1], + "table_1": [0], + "table_2": [1], + "table_3": [0], + } + + ranks = { + sharding_option.name: [shard.rank for shard in sharding_option.shards] + for sharding_option in sharding_options + } + self.assertEqual(expected_ranks, ranks) + + self.assertEqual(candidate_topology.devices[0].cost, 200) + self.assertEqual(candidate_topology.devices[1].cost, 200) + + self.assertEqual( + candidate_topology.devices[0].storage, + self.topology.devices[0].storage - Storage(2000, 2000), + ) + self.assertEqual( + candidate_topology.devices[1].storage, + self.topology.devices[1].storage - Storage(2000, 2000), + ) + + def test_tw_unbalanced_cost_device(self) -> None: + sharding_options = self.enumerator.run( + module=self.model, sharders=[TWSharder()] + ) + + for i, sharding_option in enumerate(sharding_options): + cost = 100 if i > 0 else 300 + sharding_option.cost = cost + sharding_option.shards[0].cost = cost + sharding_option.shards[0].storage = Storage(hbm=1000, ddr=1000) + + candidate_topology = copy.deepcopy(self.topology) + self.partitioner.run( + sharding_options=sharding_options, + topology=candidate_topology, + ) + expected_ranks = { + "table_0": [0], + "table_1": [1], + "table_2": [1], + "table_3": [1], + } + + ranks = { + sharding_option.name: [shard.rank for shard in sharding_option.shards] + for sharding_option in sharding_options + } + self.assertEqual(expected_ranks, ranks) + + self.assertEqual(candidate_topology.devices[0].cost, 300) + self.assertEqual(candidate_topology.devices[1].cost, 300) + + self.assertEqual( + candidate_topology.devices[0].storage, + self.topology.devices[0].storage - Storage(1000, 1000), + ) + self.assertEqual( + candidate_topology.devices[1].storage, + self.topology.devices[1].storage - Storage(3000, 3000), + ) + + def test_tw_balanced_cost_host(self) -> None: + self.topology = Topology( + world_size=16, local_world_size=8, compute_device="cuda" + ) + tables = [ + EmbeddingBagConfig( + num_embeddings=64, + embedding_dim=10 + i, + name="table_" + str(i), + feature_names=["feature_" + str(i)], + ) + for i in range(4) + ] + + self.model = TestSparseNN(tables=tables, weighted_tables=[]) + self.enumerator = EmbeddingEnumerator(topology=self.topology) + self.partitioner = GreedyCostPartitioner() + sharding_options = self.enumerator.run( + module=self.model, sharders=[TWRWSharder()] + ) + for sharding_option in sharding_options: + cost = 100.0 + for shard in sharding_option.shards: + shard.cost = cost + shard.storage = Storage(hbm=1000, ddr=1000) + sharding_option.cost = cost * sharding_option.num_shards + sharding_option.partition_by = PartitionByType.HOST.value + + candidate_topology = copy.deepcopy(self.topology) + self.partitioner.run( + sharding_options=sharding_options, + topology=candidate_topology, + ) + + expected_ranks = { + "table_0": [8, 9, 10, 11, 12, 13, 14, 15], + "table_1": [0, 1, 2, 3, 4, 5, 6, 7], + "table_2": [8, 9, 10, 11, 12, 13, 14, 15], + "table_3": [0, 1, 2, 3, 4, 5, 6, 7], + } + + ranks = { + sharding_option.name: [shard.rank for shard in sharding_option.shards] + for sharding_option in sharding_options + } + self.assertEqual(expected_ranks, ranks) + + for i in range(self.topology.world_size): + self.assertEqual( + candidate_topology.devices[i].storage, + # there are two shards allocated to each device + self.topology.devices[i].storage - Storage(2000, 2000), + ) + + def test_rw_unbalanced_cost_uniform(self) -> None: + self.topology = Topology(world_size=4, compute_device="cuda") + tables = [ + EmbeddingBagConfig( + num_embeddings=64, + embedding_dim=10 + i, + name="table_" + str(i), + feature_names=["feature_" + str(i)], + ) + for i in range(4) + ] + + self.model = TestSparseNN(tables=tables, weighted_tables=[]) + self.enumerator = EmbeddingEnumerator(topology=self.topology) + self.partitioner = GreedyCostPartitioner() + sharding_options = self.enumerator.run( + module=self.model, sharders=[RWSharder()] + ) + for sharding_option in sharding_options: + cost = 100.0 + for shard in sharding_option.shards: + shard.cost = cost + shard.storage = Storage(hbm=1000, ddr=1000) + sharding_option.cost = cost * sharding_option.num_shards + sharding_option.partition_by = PartitionByType.UNIFORM.value + + candidate_topology = copy.deepcopy(self.topology) + self.partitioner.run( + sharding_options=sharding_options, + topology=candidate_topology, + ) + + expected_ranks = { + "table_0": [0, 1, 2, 3], + "table_1": [0, 1, 2, 3], + "table_2": [0, 1, 2, 3], + "table_3": [0, 1, 2, 3], + } + + ranks = { + sharding_option.name: [shard.rank for shard in sharding_option.shards] + for sharding_option in sharding_options + } + self.assertEqual(expected_ranks, ranks) + + for i in range(self.topology.world_size): + self.assertEqual( + candidate_topology.devices[i].storage, + self.topology.devices[i].storage - Storage(4000, 4000), + ) diff --git a/torchrec/distributed/planner/new/tests/test_placers.py b/torchrec/distributed/planner/new/tests/test_placers.py new file mode 100644 index 000000000..0e19459cc --- /dev/null +++ b/torchrec/distributed/planner/new/tests/test_placers.py @@ -0,0 +1,105 @@ +#!/usr/bin/env python3 + +import unittest +from typing import List, cast + +import torch +from torch import nn +from torchrec.distributed.embedding_types import EmbeddingComputeKernel +from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder +from torchrec.distributed.planner.new.calculators import EmbeddingWTCostCalculator +from torchrec.distributed.planner.new.enumerators import EmbeddingEnumerator +from torchrec.distributed.planner.new.partitioners import GreedyCostPartitioner +from torchrec.distributed.planner.new.placers import EmbeddingPlacer +from torchrec.distributed.planner.new.rankers import DepthRanker +from torchrec.distributed.planner.new.types import Topology, PartitionError +from torchrec.distributed.tests.test_model import TestSparseNN +from torchrec.distributed.types import ModuleSharder, ShardingType +from torchrec.modules.embedding_configs import EmbeddingBagConfig +from torchrec.modules.embedding_modules import ( + EmbeddingBagCollection, +) + + +class TWvsRWSharder( + EmbeddingBagCollectionSharder[EmbeddingBagCollection], ModuleSharder[nn.Module] +): + def sharding_types(self, compute_device_type: str) -> List[str]: + return [ShardingType.ROW_WISE.value, ShardingType.TABLE_WISE.value] + + def compute_kernels( + self, sharding_type: str, compute_device_type: str + ) -> List[str]: + return [EmbeddingComputeKernel.DENSE.value] + + +class TestEmbeddingShardingPlacer(unittest.TestCase): + def setUp(self) -> None: + compute_device = "cuda" + self.topology = Topology(world_size=2, compute_device=compute_device) + self.enumerator = EmbeddingEnumerator(topology=self.topology) + self.calculator = EmbeddingWTCostCalculator(topology=self.topology) + self.placer = EmbeddingPlacer( + topology=self.topology, + partitioners=[GreedyCostPartitioner()], + rankers=[DepthRanker()], + ) + + def test_tw_solution(self) -> None: + tables = [ + EmbeddingBagConfig( + num_embeddings=100, + embedding_dim=10, + name="table_" + str(i), + feature_names=["feature_" + str(i)], + ) + for i in range(4) + ] + model = TestSparseNN(tables=tables, sparse_device=torch.device("meta")) + sharding_options = self.enumerator.run(module=model, sharders=[TWvsRWSharder()]) + self.calculator.run(sharding_options=sharding_options) + sharding_plan = self.placer.run(sharding_options=sharding_options) + expected_ranks = [[0], [0], [1], [1]] + ranks = [ + cast(List[int], param_shard.ranks) + for param_shard in sharding_plan.plan["sparse.ebc"].values() + ] + self.assertEqual(sorted(expected_ranks), sorted(ranks)) + + def test_hidden_rw_solution(self) -> None: + tables = [ + EmbeddingBagConfig( + num_embeddings=100, + embedding_dim=10, + name="table_" + str(i), + feature_names=["feature_" + str(i)], + ) + for i in range(3) + ] + model = TestSparseNN(tables=tables, sparse_device=torch.device("meta")) + sharding_options = self.enumerator.run(module=model, sharders=[TWvsRWSharder()]) + self.calculator.run(sharding_options=sharding_options) + sharding_plan = self.placer.run(sharding_options=sharding_options) + expected_ranks = [[0], [0, 1], [1]] + ranks = [ + cast(List[int], param_shard.ranks) + for param_shard in sharding_plan.plan["sparse.ebc"].values() + ] + self.assertEqual(sorted(expected_ranks), sorted(ranks)) + + def test_never_fit(self) -> None: + tables = [ + EmbeddingBagConfig( + num_embeddings=10000000, + embedding_dim=10000000, + name="table_" + str(i), + feature_names=["feature_" + str(i)], + ) + for i in range(10) + ] + model = TestSparseNN(tables=tables, sparse_device=torch.device("meta")) + sharding_options = self.enumerator.run(module=model, sharders=[TWvsRWSharder()]) + self.calculator.run(sharding_options=sharding_options) + with self.assertRaises(PartitionError): + self.placer.run(sharding_options=sharding_options) + self.assertEqual(self.placer._counter, 11) diff --git a/torchrec/distributed/planner/new/tests/test_rankers.py b/torchrec/distributed/planner/new/tests/test_rankers.py new file mode 100644 index 000000000..a750eb51d --- /dev/null +++ b/torchrec/distributed/planner/new/tests/test_rankers.py @@ -0,0 +1,129 @@ +#!/usr/bin/env python3 + +import unittest + +from torchrec.distributed.embeddingbag import ( + EmbeddingBagCollectionSharder, +) +from torchrec.distributed.planner.new.calculators import EmbeddingWTCostCalculator +from torchrec.distributed.planner.new.enumerators import EmbeddingEnumerator +from torchrec.distributed.planner.new.rankers import DepthRanker +from torchrec.distributed.planner.new.types import Topology +from torchrec.distributed.tests.test_model import TestSparseNN +from torchrec.modules.embedding_configs import EmbeddingBagConfig + + +class TestDepthRanker(unittest.TestCase): + def setUp(self) -> None: + topology = Topology(world_size=2, compute_device="cuda") + self.calculator = EmbeddingWTCostCalculator(topology=topology) + self.enumerator = EmbeddingEnumerator(topology=topology) + self.ranker = DepthRanker() + + def test_two_table_cost(self) -> None: + tables = [ + EmbeddingBagConfig( + num_embeddings=100, + embedding_dim=10, + name="table_0", + feature_names=["feature_0"], + ), + EmbeddingBagConfig( + num_embeddings=100, + embedding_dim=10, + name="table_1", + feature_names=["feature_1"], + ), + ] + + model = TestSparseNN(tables=tables, weighted_tables=[]) + sharding_options = self.enumerator.run( + module=model, sharders=[EmbeddingBagCollectionSharder()] + ) + self.calculator.run(sharding_options) + rank_stack = self.ranker.run(sharding_options=sharding_options) + + # simulate first five iterations: + output = [] + for _ in range(5): + candidates = rank_stack.bulk_pop() + candidates.sort(key=lambda x: (x.cost, x.name)) + output.append( + [ + ( + candidate.name, + candidate.sharding_type, + candidate.compute_kernel, + ) + for candidate in candidates + ] + ) + drop = candidates[0] + keep = candidates[1:] + rank_stack.remove(drop) + rank_stack.bulk_push(keep) + + expected_output = [ + [ + ( + "table_0", + "data_parallel", + "batched_dense", + ), + ( + "table_1", + "data_parallel", + "batched_dense", + ), + ], + [ + ( + "table_1", + "data_parallel", + "batched_dense", + ), + ( + "table_0", + "data_parallel", + "dense", + ), + ], + [ + ( + "table_0", + "data_parallel", + "dense", + ), + ( + "table_1", + "data_parallel", + "dense", + ), + ], + [ + ( + "table_1", + "data_parallel", + "dense", + ), + ( + "table_0", + "row_wise", + "batched_fused", + ), + ], + [ + ( + "table_0", + "row_wise", + "batched_fused", + ), + ( + "table_1", + "row_wise", + "batched_fused", + ), + ], + ] + + self.assertEqual(expected_output, output) diff --git a/torchrec/distributed/planner/new/types.py b/torchrec/distributed/planner/new/types.py new file mode 100644 index 000000000..3e328eee4 --- /dev/null +++ b/torchrec/distributed/planner/new/types.py @@ -0,0 +1,373 @@ +#!/usr/bin/env python3 + +from __future__ import annotations + +import abc +from dataclasses import field, dataclass +from enum import Enum +from typing import Optional, List, Dict, Tuple + +import torch +from torch import nn +from torchrec.distributed.planner.new.constants import ( + CROSS_NODE_BANDWIDTH, + INTRA_NODE_BANDWIDTH, + HBM_CAP, + DDR_CAP, + POOLING_FACTOR, + BATCH_SIZE, +) +from torchrec.distributed.types import ModuleSharder, ShardingPlan + +# ---- TOPOLOGY ---- # + + +@dataclass(repr=True, order=True, eq=True) +class Storage: + hbm: int + ddr: int + + def __add__(self, other: Storage) -> Storage: + return Storage( + hbm=self.hbm + other.hbm, + ddr=self.ddr + other.ddr, + ) + + def __sub__(self, other: Storage) -> Storage: + return Storage( + hbm=self.hbm - other.hbm, + ddr=self.ddr - other.ddr, + ) + + +@dataclass +class DeviceHardware: + rank: int + storage: Storage + cost: int = 0 + + +class Topology: + def __init__( + self, + world_size: int, + compute_device: str, + hbm_cap: Optional[int] = None, + ddr_cap: int = DDR_CAP, + local_world_size: Optional[int] = None, + intra_host_bw: int = INTRA_NODE_BANDWIDTH, + inter_host_bw: int = CROSS_NODE_BANDWIDTH, + batch_size: int = BATCH_SIZE, + ) -> None: + # validate input + assert compute_device in [ + "cpu", + "cuda", + ], f"unsupported compute device {compute_device}" + + self._compute_device = compute_device + self._world_size = world_size + + hbm_per_device = 0 + if self._compute_device == "cuda": + hbm_per_device = hbm_cap if hbm_cap else HBM_CAP + + self._devices: List[DeviceHardware] = [] + for rank in range(world_size): + self._devices.append( + DeviceHardware( + rank=rank, + storage=Storage(hbm=hbm_per_device, ddr=ddr_cap), + ) + ) + + self._local_world_size: int = ( + local_world_size if local_world_size else world_size + ) + self._intra_host_bw = intra_host_bw + self._inter_host_bw = inter_host_bw + self._batch_size = batch_size + + @property + def batch_size(self) -> int: + return self._batch_size + + @property + def compute_device(self) -> str: + return self._compute_device + + @property + def devices(self) -> List[DeviceHardware]: + return self._devices + + @property + def world_size(self) -> int: + return self._world_size + + @property + def local_world_size(self) -> int: + return self._local_world_size + + @property + def intra_host_bw(self) -> int: + return self._intra_host_bw + + @property + def inter_host_bw(self) -> int: + return self._inter_host_bw + + def __repr__(self) -> str: + topology_repr: str = f"world_size={self._world_size} \n" + topology_repr += f"compute_device={self._compute_device}\n" + topology_repr += "devices=\n" + for idx, device in enumerate(self._devices): + topology_repr += f"\tdevice {idx} {device}\n" + topology_repr += f"local_world_size={self._local_world_size} \n" + topology_repr += f"intra_host_bw={self._intra_host_bw} \n" + topology_repr += f"inter_host_bw={self._inter_host_bw} \n" + return topology_repr + + +# ---- INPUT / OUTPUT ----- # + + +@dataclass +class Shard: + length: List[int] + offset: List[int] + storage: Storage + cost: Optional[float] = None + rank: Optional[int] = None + + +@dataclass +class ShardingOption: + name: str + tensor: torch.Tensor + module: Tuple[str, nn.Module] + upstream_modules: List[Tuple[str, nn.Module]] + downstream_modules: List[Tuple[str, nn.Module]] + input_lengths: List[float] + batch_size: int # per single device + sharding_type: str + partition_by: str # {DEVICE, HOST, UNIFORM} + compute_kernel: str + cost: Optional[float] = None # main ranker value + # relevant to planner output, must be populated if sharding option + # part of final solution + shards: List[Shard] = field(default_factory=list) + + @property + def fqn(self) -> str: + return self.module[0] + "." + self.name + + @property + def path(self) -> str: + return self.module[0] + + @property + def num_shards(self) -> int: + return len(self.shards) + + @property + def num_inputs(self) -> int: + return len(self.input_lengths) + + +class PartitionByType(Enum): + """ + Well-known partition types + """ + + # Partitioning based on device + DEVICE = "device" + # Partitioning based on host + HOST = "host" + # Uniform, (ie. fixed layout) + UNIFORM = "uniform" + + +@dataclass +class PlannerConstraints: + """ + Stores user provided constraints around + sharding types, compute kernels and partitioning + """ + + sharding_types: Optional[List[str]] = None + compute_kernels: Optional[List[str]] = None + min_partition: Optional[int] = None # CW sharding + caching_ratio: Optional[float] = None # UVM caching + + +@dataclass +class InputStats: + """ + Stores statistics around input data for + a given tensor + """ + + pooling_factors: List[float] = field(default_factory=lambda: [POOLING_FACTOR]) + + +class PartitionError(Exception): + ... + + +@dataclass +class PlacerStats: + num_iterations: int + num_errors: int + topology_solution: Optional[Topology] + sharding_solution: Optional[List[ShardingOption]] + + +# ---- PLANNER COMPONENTS ---- # + + +class Enumerator(abc.ABC): + """ + Generate all relevant sharding options for given nn.Module, + input stats and user constraints + """ + + @abc.abstractmethod + def __init__( + self, + topology: Topology, + constraints: Optional[Dict[str, PlannerConstraints]] = None, + input_stats: Optional[Dict[str, InputStats]] = None, + ) -> None: + ... + + @abc.abstractmethod + def run( + self, module: nn.Module, sharders: List[ModuleSharder[nn.Module]] + ) -> List[ShardingOption]: + ... + + +class Calculator(abc.ABC): + """ + Calculate costs, requires fully specified sharding options + (ie. ranks/lengths) + """ + + @abc.abstractmethod + def __init__( + self, + topology: Topology, + constraints: Optional[Dict[str, PlannerConstraints]] = None, + ) -> None: + ... + + @abc.abstractmethod + def run(self, sharding_options: List[ShardingOption]) -> None: + # actual costs + ... + + +class RankStack(abc.ABC): + """ + "Stack"-like interface to manage complexity of providing + next sharding option for placer + """ + + @abc.abstractmethod + def pop(self) -> ShardingOption: + # pop next sharding option, no more than one sharding option per tensor + # should be returned + ... + + @abc.abstractmethod + def push(self, sharding_option: ShardingOption) -> None: + # push back shading_option, rerank as necessary + ... + + @abc.abstractmethod + def remove(self, sharding_option: ShardingOption) -> bool: + # remove a given sharding_option from consideration + ... + + @abc.abstractmethod + def bulk_pop(self) -> List[ShardingOption]: + # pop any remaining sharing options + ... + + @abc.abstractmethod + def bulk_push(self, sharding_options: List[ShardingOption]) -> None: + # push a list of sharding options + ... + + +class Ranker(abc.ABC): + """ + Given a calculator, topology and sharding options, populate a + RankStack and return it + """ + + @abc.abstractmethod + def run(self, sharding_options: List[ShardingOption]) -> RankStack: + ... + + +class Partitioner(abc.ABC): + """ + Parition + + Today we have multiple strategies ie. + (Greedy, BLDM, Linear) + """ + + @abc.abstractmethod + def run( + self, + sharding_options: List[ShardingOption], + topology: Topology, + ) -> None: + # modifies sharding_options and topology in-place + ... + + +class Placer(abc.ABC): + """ + Controls actual placement via: + 1) calls to rank stack + 2) calling into partitioners + 3) final ShardingOptions + 4) determining stopping conditions + """ + + @abc.abstractmethod + def __init__( + self, + topology: Topology, + partitioners: Optional[List[Partitioner]] = None, + rankers: Optional[List[Ranker]] = None, + ) -> None: + ... + + @abc.abstractmethod + def run(self, sharding_options: List[ShardingOption]) -> ShardingPlan: + ... + + @property + @abc.abstractmethod + def stats(self) -> PlacerStats: + ... + + +class Stats(abc.ABC): + """ + Log statistics related to the sharding plan + """ + + @abc.abstractmethod + def run( + self, + sharding_plan: ShardingPlan, + topology: Topology, + placer_stats: PlacerStats, + input_stats: Optional[Dict[str, InputStats]] = None, + ) -> None: + ... diff --git a/torchrec/distributed/planner/parameter_sharding.py b/torchrec/distributed/planner/parameter_sharding.py new file mode 100644 index 000000000..c20df0150 --- /dev/null +++ b/torchrec/distributed/planner/parameter_sharding.py @@ -0,0 +1,282 @@ +#!/usr/bin/env python3 +import abc +import itertools +import math +from typing import List, Tuple, cast + +import torch +from torch.distributed._sharding_spec import EnumerableShardingSpec, ShardMetadata +from torchrec.distributed.planner.types import ParameterInfo +from torchrec.distributed.types import ShardingType, ParameterSharding + + +def _twrw_shard_table_rows( + table_node: int, + hash_size: int, + embedding_dim: int, + world_size: int, + local_size: int, +) -> Tuple[List[int], List[int], List[int]]: + + block_size = math.ceil(hash_size / local_size) + last_rank_offset = hash_size // block_size + last_block_size = hash_size - block_size * (last_rank_offset) + + first_local_rank = (table_node) * local_size + last_local_rank = first_local_rank + last_rank_offset + local_rows: List[int] = [] + local_cols: List[int] = [] + local_row_offsets: List[int] = [] + cumul_row_offset = 0 + for rank in range(world_size): + local_col = embedding_dim + if rank < first_local_rank: + local_row = 0 + local_col = 0 + elif rank < last_local_rank: + local_row = block_size + elif rank == last_local_rank: + local_row = last_block_size + else: + local_row = 0 + local_rows.append(local_row) + local_cols.append(local_col) + local_row_offsets.append(cumul_row_offset) + cumul_row_offset += local_row + + return (local_rows, local_cols, local_row_offsets) + + +def _rw_shard_table_rows(hash_size: int, world_size: int) -> Tuple[List[int], int, int]: + block_size = (hash_size + world_size - 1) // world_size + last_rank = hash_size // block_size + last_block_size = hash_size - block_size * last_rank + local_rows: List[int] = [] + for rank in range(world_size): + if rank < last_rank: + local_row = block_size + elif rank == last_rank: + local_row = last_block_size + else: + local_row = 0 + local_rows.append(local_row) + return (local_rows, block_size, last_rank) + + +def _device_placement( + compute_device_type: str, + rank: int, + local_size: int, +) -> str: + param_device = torch.device("cpu") + if compute_device_type == "cuda": + param_device = torch.device("cuda", rank % local_size) + return f"rank:{rank}/{param_device}" + + +class ParameterShardingFactory(abc.ABC): + @staticmethod + def shard_parameters( + param_info: ParameterInfo, + compute_device_type: str, + world_size: int, + local_size: int, + ) -> ParameterSharding: + sharding_option = param_info.sharding_options[0] + sharding_type = sharding_option.sharding_type + if sharding_type == ShardingType.TABLE_WISE.value: + parameter_sharding = TwParameterSharding.shard_parameters( + param_info, compute_device_type, world_size, local_size + ) + elif sharding_type == ShardingType.ROW_WISE.value: + parameter_sharding = RwParameterSharding.shard_parameters( + param_info, compute_device_type, world_size, local_size + ) + elif sharding_type == ShardingType.TABLE_ROW_WISE.value: + parameter_sharding = TwRwParameterSharding.shard_parameters( + param_info, compute_device_type, world_size, local_size + ) + elif sharding_type == ShardingType.COLUMN_WISE.value: + parameter_sharding = CwParameterSharding.shard_parameters( + param_info, compute_device_type, world_size, local_size + ) + elif sharding_type == ShardingType.DATA_PARALLEL.value: + parameter_sharding = DpParameterSharding.shard_parameters( + param_info, compute_device_type, world_size, local_size + ) + else: + raise ValueError( + f"unsupported {sharding_option.sharding_type} sharding type" + ) + return parameter_sharding + + +class TwParameterSharding: + @classmethod + def shard_parameters( + cls, + param_info: ParameterInfo, + compute_device_type: str, + world_size: int, + local_size: int, + ) -> ParameterSharding: + sharding_option = param_info.sharding_options[0] + tensor = param_info.param + # pyre-fixme [16] + rank = sharding_option.ranks[0] + shards = [ + ShardMetadata( + shard_lengths=[ + tensor.shape[0], + tensor.shape[1], + ], + shard_offsets=[0, 0], + placement=_device_placement(compute_device_type, rank, local_size), + ) + ] + return ParameterSharding( + sharding_spec=EnumerableShardingSpec(shards), + sharding_type=sharding_option.sharding_type, + compute_kernel=sharding_option.compute_kernel, + ranks=sharding_option.ranks, + ) + + +class RwParameterSharding: + @classmethod + def shard_parameters( + cls, + param_info: ParameterInfo, + compute_device_type: str, + world_size: int, + local_size: int, + ) -> ParameterSharding: + sharding_option = param_info.sharding_options[0] + tensor = param_info.param + local_rows, block_size, last_rank = _rw_shard_table_rows( + tensor.shape[0], world_size + ) + shards = [ + ShardMetadata( + shard_lengths=[ + local_rows[rank], + tensor.shape[1], + ], + shard_offsets=[block_size * min(rank, last_rank), 0], + placement=_device_placement(compute_device_type, rank, local_size), + ) + for rank in range(world_size) + ] + return ParameterSharding( + sharding_type=sharding_option.sharding_type, + compute_kernel=sharding_option.compute_kernel, + ranks=sharding_option.ranks, + block_size=block_size, + sharding_spec=EnumerableShardingSpec(shards), + ) + + +class TwRwParameterSharding: + @classmethod + def shard_parameters( + cls, + param_info: ParameterInfo, + compute_device_type: str, + world_size: int, + local_size: int, + ) -> ParameterSharding: + sharding_option = param_info.sharding_options[0] + tensor = param_info.param + # pyre-fixme [16] + rank = sharding_option.ranks[0] + table_node = rank // local_size + local_rows, local_cols, local_row_offsets = _twrw_shard_table_rows( + table_node=table_node, + hash_size=tensor.shape[0], + embedding_dim=tensor.shape[1], + world_size=world_size, + local_size=local_size, + ) + shards = [ + ShardMetadata( + shard_lengths=[ + local_rows[rank], + local_cols[rank], + ], + shard_offsets=[local_row_offsets[rank], 0], + placement=_device_placement(compute_device_type, rank, local_size), + ) + for rank in range(table_node * local_size, (table_node + 1) * local_size) + ] + + return ParameterSharding( + sharding_type=sharding_option.sharding_type, + compute_kernel=sharding_option.compute_kernel, + ranks=sharding_option.ranks, + block_size=math.ceil(tensor.shape[0] / local_size), + sharding_spec=EnumerableShardingSpec(shards), + ) + + +class CwParameterSharding: + @classmethod + def shard_parameters( + cls, + param_info: ParameterInfo, + compute_device_type: str, + world_size: int, + local_size: int, + ) -> ParameterSharding: + sharding_option = param_info.sharding_options[0] + tensor = param_info.param + # pyre-fixme [6] + ranks = sorted(sharding_option.ranks) + block_size = cast(int, sharding_option.col_wise_shard_dim) + num_col_wise_shards, residual = divmod(tensor.shape[1], block_size) + sizes = [block_size] * num_col_wise_shards + if residual > 0: + sizes += [residual] + merged_sizes = [] + merged_ranks = [] + for i, rank in enumerate(ranks): + if rank not in merged_ranks: + merged_ranks.append(rank) + merged_sizes.append(sizes[i]) + else: + merged_sizes[-1] += sizes[i] + offsets = [0] + list(itertools.accumulate(merged_sizes))[:-1] + shards = [ + ShardMetadata( + shard_lengths=[ + tensor.shape[0], + merged_sizes[i], + ], + shard_offsets=[0, offsets[i]], + placement=_device_placement(compute_device_type, rank, local_size), + ) + for i, rank in enumerate(merged_ranks) + ] + return ParameterSharding( + sharding_type=sharding_option.sharding_type, + compute_kernel=sharding_option.compute_kernel, + ranks=merged_ranks, + block_size=block_size, + sharding_spec=EnumerableShardingSpec(shards), + ) + + +class DpParameterSharding: + @classmethod + def shard_parameters( + cls, + param_info: ParameterInfo, + compute_device_type: str, + world_size: int, + local_size: int, + ) -> ParameterSharding: + sharding_option = param_info.sharding_options[0] + return ParameterSharding( + sharding_type=sharding_option.sharding_type, + compute_kernel=sharding_option.compute_kernel, + ranks=sharding_option.ranks, + ) diff --git a/torchrec/distributed/planner/tests/test_embedding_planner.py b/torchrec/distributed/planner/tests/test_embedding_planner.py new file mode 100644 index 000000000..322af91e2 --- /dev/null +++ b/torchrec/distributed/planner/tests/test_embedding_planner.py @@ -0,0 +1,930 @@ +#!/usr/bin/env python3 + +import unittest +from typing import List +from unittest.mock import MagicMock, patch, call + +from torch.distributed._sharding_spec import ShardMetadata, EnumerableShardingSpec +from torchrec.distributed.embedding_types import EmbeddingComputeKernel +from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder +from torchrec.distributed.planner.embedding_planner import EmbeddingShardingPlanner +from torchrec.distributed.planner.parameter_sharding import _rw_shard_table_rows +from torchrec.distributed.planner.types import ParameterHints +from torchrec.distributed.planner.utils import MIN_DIM +from torchrec.distributed.tests.test_model import TestSparseNN +from torchrec.distributed.types import ParameterSharding, ShardingType +from torchrec.modules.embedding_configs import EmbeddingBagConfig +from torchrec.modules.embedding_modules import ( + EmbeddingBagCollection, +) + + +class CWSharder(EmbeddingBagCollectionSharder[EmbeddingBagCollection]): + def sharding_types(self, compute_device_type: str) -> List[str]: + return [ShardingType.COLUMN_WISE.value] + + """ + Restricts to single impl. + """ + + def compute_kernels( + self, sharding_type: str, compute_device_type: str + ) -> List[str]: + return [EmbeddingComputeKernel.DENSE.value] + + +class DPCWSharder(EmbeddingBagCollectionSharder[EmbeddingBagCollection]): + def sharding_types(self, compute_device_type: str) -> List[str]: + return [ + ShardingType.COLUMN_WISE.value, + ShardingType.DATA_PARALLEL.value, + ] + + """ + Restricts to single impl. + """ + + def compute_kernels( + self, sharding_type: str, compute_device_type: str + ) -> List[str]: + return [EmbeddingComputeKernel.DENSE.value] + + +class TWSharder(EmbeddingBagCollectionSharder[EmbeddingBagCollection]): + def sharding_types(self, compute_device_type: str) -> List[str]: + return [ShardingType.TABLE_WISE.value] + + """ + Restricts to single impl. + """ + + def compute_kernels( + self, sharding_type: str, compute_device_type: str + ) -> List[str]: + return [EmbeddingComputeKernel.DENSE.value] + + +class DPTWSharder(EmbeddingBagCollectionSharder[EmbeddingBagCollection]): + def sharding_types(self, compute_device_type: str) -> List[str]: + return [ + ShardingType.TABLE_WISE.value, + ShardingType.DATA_PARALLEL.value, + ] + + """ + Restricts to single impl. + """ + + def compute_kernels( + self, sharding_type: str, compute_device_type: str + ) -> List[str]: + return [EmbeddingComputeKernel.DENSE.value] + + +class DPRWTWSharder(EmbeddingBagCollectionSharder[EmbeddingBagCollection]): + def sharding_types(self, compute_device_type: str) -> List[str]: + return [ + ShardingType.TABLE_WISE.value, + ShardingType.DATA_PARALLEL.value, + ShardingType.ROW_WISE.value, + ] + + """ + Restricts to single impl. + """ + + def compute_kernels( + self, sharding_type: str, compute_device_type: str + ) -> List[str]: + return [EmbeddingComputeKernel.DENSE.value] + + +class TWRWSharder(EmbeddingBagCollectionSharder[EmbeddingBagCollection]): + def sharding_types(self, compute_device_type: str) -> List[str]: + return [ + ShardingType.TABLE_ROW_WISE.value, + ] + + """ + Restricts to single impl. + """ + + def compute_kernels( + self, sharding_type: str, compute_device_type: str + ) -> List[str]: + return [EmbeddingComputeKernel.DENSE.value] + + +class TestEmbeddingPlanner(unittest.TestCase): + def setUp(self) -> None: + # Mocks + self.compute_device_type = "cuda" + + @patch("torchrec.distributed.planner.embedding_planner.logger", create=True) + def test_allocation_planner_balanced(self, mock_logger: MagicMock) -> None: + tables = [ + EmbeddingBagConfig( + num_embeddings=100 + i, + embedding_dim=10 + i, + name="table_" + str(i), + feature_names=["feature_" + str(i)], + ) + for i in range(4) + ] + storage = {"hbm": 1} + expected_plan = { + "sparse.ebc": { + "table_0": ParameterSharding( + sharding_type=ShardingType.TABLE_WISE.value, + compute_kernel="dense", + ranks=[0], + sharding_spec=EnumerableShardingSpec( + shards=[ + ShardMetadata( + shard_lengths=[ + tables[0].num_embeddings, + tables[0].embedding_dim, + ], + shard_offsets=[0, 0], + placement="rank:0/cuda:0", + ) + ] + ), + ), + "table_1": ParameterSharding( + sharding_type=ShardingType.TABLE_WISE.value, + compute_kernel="dense", + ranks=[1], + sharding_spec=EnumerableShardingSpec( + shards=[ + ShardMetadata( + shard_lengths=[ + tables[1].num_embeddings, + tables[1].embedding_dim, + ], + shard_offsets=[0, 0], + placement="rank:1/cuda:1", + ) + ] + ), + ), + "table_2": ParameterSharding( + sharding_type=ShardingType.TABLE_WISE.value, + compute_kernel="dense", + ranks=[1], + sharding_spec=EnumerableShardingSpec( + shards=[ + ShardMetadata( + shard_lengths=[ + tables[2].num_embeddings, + tables[2].embedding_dim, + ], + shard_offsets=[0, 0], + placement="rank:1/cuda:1", + ) + ] + ), + ), + "table_3": ParameterSharding( + sharding_type=ShardingType.TABLE_WISE.value, + compute_kernel="dense", + ranks=[0], + sharding_spec=EnumerableShardingSpec( + shards=[ + ShardMetadata( + shard_lengths=[ + tables[3].num_embeddings, + tables[3].embedding_dim, + ], + shard_offsets=[0, 0], + placement="rank:0/cuda:0", + ) + ] + ), + ), + } + } + + model = TestSparseNN(tables=tables, weighted_tables=[]) + world_size = 2 + planner = EmbeddingShardingPlanner( + world_size=world_size, + compute_device_type=self.compute_device_type, + storage=storage, + ) + + sharders = [TWSharder()] + # pyre-ignore [6] + output = planner.plan(model, sharders) + self.assertEqual(output.plan, expected_plan) + + # check logger + self.assertEqual( + mock_logger.mock_calls[1 : world_size + 1], + [ + call.info( + " Rank 0 -- HBM/DDR: 0.0/0.0, Cost: 2308, Mean Pooling: 0, Emb Dims: 23, Shards: {'table_wise': 2}" + ), + call.info( + " Rank 1 -- HBM/DDR: 0.0/0.0, Cost: 2307, Mean Pooling: 0, Emb Dims: 23, Shards: {'table_wise': 2}" + ), + ], + ) + + @patch("torchrec.distributed.planner.embedding_planner.logger", create=True) + def test_twrw_no_constraints(self, mock_logger: MagicMock) -> None: + tables = [ + EmbeddingBagConfig( + num_embeddings=100 + i, + embedding_dim=10 + i, + name="table_" + str(i), + feature_names=["feature_" + str(i)], + ) + for i in range(1) + ] + expected_plan = { + "sparse.ebc": { + "table_0": ParameterSharding( + sharding_type=ShardingType.TABLE_ROW_WISE.value, + compute_kernel="dense", + ranks=[1], + block_size=50, + sharding_spec=EnumerableShardingSpec( + shards=[ + ShardMetadata( + shard_lengths=[ + tables[0].num_embeddings // 2, + tables[0].embedding_dim, + ], + shard_offsets=[0, 0], + placement="rank:0/cuda:0", + ), + ShardMetadata( + shard_lengths=[ + tables[0].num_embeddings // 2, + tables[0].embedding_dim, + ], + shard_offsets=[50, 0], + placement="rank:1/cuda:1", + ), + ] + ), + ), + } + } + + model = TestSparseNN(tables=tables, weighted_tables=[]) + world_size = 2 + planner = EmbeddingShardingPlanner( + world_size=world_size, + compute_device_type=self.compute_device_type, + ) + + sharders = [TWRWSharder()] + # pyre-ignore [6] + output = planner.plan(model, sharders) + self.assertEqual(output.plan, expected_plan) + + # check logger + self.assertEqual( + mock_logger.mock_calls[1 : world_size + 1], + [ + call.info( + " Rank 0 -- HBM/DDR: 0.0/0.0, Cost: 1500, Mean Pooling: 0, Emb Dims: 10, Shards: {'table_row_wise': 1}" + ), + call.info( + " Rank 1 -- HBM/DDR: 0.0/0.0, Cost: 1500, Mean Pooling: 0, Emb Dims: 10, Shards: {'table_row_wise': 1}" + ), + ], + ) + + @patch("torchrec.distributed.planner.embedding_planner.logger", create=True) + def test_twrw_no_constraints_edge_case(self, mock_logger: MagicMock) -> None: + tables = [ + EmbeddingBagConfig( + num_embeddings=5, + embedding_dim=10, + name="table_0", + feature_names=["feature_0"], + ) + ] + expected_plan = { + "sparse.ebc": { + "table_0": ParameterSharding( + sharding_type=ShardingType.TABLE_ROW_WISE.value, + compute_kernel="dense", + ranks=[3], + sharding_spec=EnumerableShardingSpec( + shards=[ + ShardMetadata( + shard_lengths=[ + 2, + 10, + ], + shard_offsets=[0, 0], + placement="rank:0/cuda:0", + ), + ShardMetadata( + shard_lengths=[ + 2, + 10, + ], + shard_offsets=[2, 0], + placement="rank:1/cuda:1", + ), + ShardMetadata( + shard_lengths=[ + 1, + 10, + ], + shard_offsets=[4, 0], + placement="rank:2/cuda:2", + ), + ShardMetadata( + shard_lengths=[ + 0, + 10, + ], + shard_offsets=[5, 0], + placement="rank:3/cuda:3", + ), + ] + ), + ), + } + } + + model = TestSparseNN(tables=tables, weighted_tables=[]) + world_size = 4 + planner = EmbeddingShardingPlanner( + world_size=world_size, + compute_device_type=self.compute_device_type, + ) + + sharders = [TWRWSharder()] + # pyre-ignore [6] + output = planner.plan(model, sharders) + self.assertEqual(output.plan, expected_plan) + + @patch("torchrec.distributed.planner.embedding_planner.logger", create=True) + def test_allocation_planner_one_big_rest_small( + self, mock_logger: MagicMock + ) -> None: + big_hash = int(1024 * 1024 * 1024 / 16 / 4) + small_hash = 1000 + tables = [ + EmbeddingBagConfig( + num_embeddings=big_hash if i == 0 else small_hash, + embedding_dim=16, + name="table_" + str(i), + feature_names=["feature_" + str(i)], + ) + for i in range(4) + ] + + storage = {"hbm": 1} + + expected_plan = { + "sparse.ebc": { + "table_0": ParameterSharding( + sharding_type=ShardingType.TABLE_WISE.value, + compute_kernel="dense", + ranks=[0], + sharding_spec=EnumerableShardingSpec( + shards=[ + ShardMetadata( + shard_lengths=[ + tables[0].num_embeddings, + tables[0].embedding_dim, + ], + shard_offsets=[0, 0], + placement="rank:0/cuda:0", + ) + ] + ), + ), + "table_1": ParameterSharding( + sharding_type=ShardingType.TABLE_WISE.value, + compute_kernel="dense", + ranks=[1], + sharding_spec=EnumerableShardingSpec( + shards=[ + ShardMetadata( + shard_lengths=[ + tables[1].num_embeddings, + tables[1].embedding_dim, + ], + shard_offsets=[0, 0], + placement="rank:1/cuda:1", + ) + ] + ), + ), + "table_2": ParameterSharding( + sharding_type=ShardingType.TABLE_WISE.value, + compute_kernel="dense", + ranks=[1], + sharding_spec=EnumerableShardingSpec( + shards=[ + ShardMetadata( + shard_lengths=[ + tables[2].num_embeddings, + tables[2].embedding_dim, + ], + shard_offsets=[0, 0], + placement="rank:1/cuda:1", + ) + ] + ), + ), + "table_3": ParameterSharding( + sharding_type=ShardingType.TABLE_WISE.value, + compute_kernel="dense", + ranks=[1], + sharding_spec=EnumerableShardingSpec( + shards=[ + ShardMetadata( + shard_lengths=[ + tables[3].num_embeddings, + tables[3].embedding_dim, + ], + shard_offsets=[0, 0], + placement="rank:1/cuda:1", + ) + ] + ), + ), + } + } + model = TestSparseNN(tables=tables, weighted_tables=[]) + + world_size = 2 + planner = EmbeddingShardingPlanner( + world_size=world_size, + compute_device_type=self.compute_device_type, + storage=storage, + ) + sharders = [DPTWSharder()] + # pyre-ignore [6] + output = planner.plan(model, sharders) + self.assertEqual(output.plan, expected_plan) + + # check logger + self.assertEqual( + mock_logger.mock_calls[1 : world_size + 1], + [ + call.info( + " Rank 0 -- HBM/DDR: 1.0/0.0, Cost: 5780, Mean Pooling: 0, Emb Dims: 16, Shards: {'table_wise': 1}" + ), + call.info( + " Rank 1 -- HBM/DDR: 0.0/0.0, Cost: 7200, Mean Pooling: 0, Emb Dims: 48, Shards: {'table_wise': 3}" + ), + ], + ) + + @patch("torchrec.distributed.planner.embedding_planner.logger", create=True) + def test_allocation_planner_two_big_rest_small( + self, mock_logger: MagicMock + ) -> None: + big_hash = int(1024 * 1024 * 1024 / 16 / 4) + small_hash = 1000 + tables = [ + EmbeddingBagConfig( + num_embeddings=big_hash if i <= 1 else small_hash, + embedding_dim=16, + name="table_" + str(i), + feature_names=["feature_" + str(i)], + ) + for i in range(4) + ] + + storage = {"hbm": 1.1} + + expected_plan = { + "sparse.ebc": { + "table_0": ParameterSharding( + sharding_type=ShardingType.TABLE_WISE.value, + compute_kernel="dense", + ranks=[0], + sharding_spec=EnumerableShardingSpec( + shards=[ + ShardMetadata( + shard_lengths=[ + tables[0].num_embeddings, + tables[0].embedding_dim, + ], + shard_offsets=[0, 0], + placement="rank:0/cuda:0", + ) + ] + ), + ), + "table_1": ParameterSharding( + sharding_type=ShardingType.TABLE_WISE.value, + compute_kernel="dense", + ranks=[1], + sharding_spec=EnumerableShardingSpec( + shards=[ + ShardMetadata( + shard_lengths=[ + tables[1].num_embeddings, + tables[1].embedding_dim, + ], + shard_offsets=[0, 0], + placement="rank:1/cuda:1", + ) + ] + ), + ), + "table_2": ParameterSharding( + sharding_type=ShardingType.DATA_PARALLEL.value, + compute_kernel="dense", + ranks=None, + sharding_spec=None, + ), + "table_3": ParameterSharding( + sharding_type=ShardingType.DATA_PARALLEL.value, + compute_kernel="dense", + ranks=None, + sharding_spec=None, + ), + } + } + model = TestSparseNN(tables=tables, weighted_tables=[]) + + world_size = 2 + planner = EmbeddingShardingPlanner( + world_size=world_size, + compute_device_type=self.compute_device_type, + # pyre-fixme[6]: Expected `Optional[typing.Dict[str, int]]` for 3rd + # param but got `Dict[str, float]`. + storage=storage, + ) + sharders = [DPRWTWSharder()] + # pyre-ignore [6] + output = planner.plan(model, sharders) + self.assertEqual(output.plan, expected_plan) + + # check logger + self.assertEqual( + mock_logger.mock_calls[1 : world_size + 1], + [ + call.info( + " Rank 0 -- HBM/DDR: 1.0/0.0, Cost: 8180, Mean Pooling: 0, Emb Dims: 48, Shards: {'table_wise': 1, 'data_parallel': 2}" + ), + call.info( + " Rank 1 -- HBM/DDR: 1.0/0.0, Cost: 8180, Mean Pooling: 0, Emb Dims: 48, Shards: {'table_wise': 1, 'data_parallel': 2}" + ), + ], + ) + + @patch("torchrec.distributed.planner.embedding_planner.logger", create=True) + def test_allocation_planner_rw_two_big_rest_small( + self, mock_logger: MagicMock + ) -> None: + big_hash = int(1024 * 1024 * 1024 / 16 / 4) + small_hash = 1000 + tables = [ + EmbeddingBagConfig( + num_embeddings=big_hash if i <= 1 else small_hash, + embedding_dim=16, + name="table_" + str(i), + feature_names=["feature_" + str(i)], + ) + for i in range(4) + ] + local_rows, block_size, last_rank = _rw_shard_table_rows(big_hash, 4) + storage = {"hbm": 0.6} + + expected_plan = { + "sparse.ebc": { + "table_0": ParameterSharding( + sharding_type=ShardingType.ROW_WISE.value, + compute_kernel="dense", + ranks=None, + block_size=block_size, + sharding_spec=EnumerableShardingSpec( + shards=[ + ShardMetadata( + shard_lengths=[ + local_rows[0], + tables[0].embedding_dim, + ], + shard_offsets=[0, 0], + placement="rank:0/cuda:0", + ), + ShardMetadata( + shard_lengths=[ + local_rows[1], + tables[0].embedding_dim, + ], + shard_offsets=[block_size, 0], + placement="rank:1/cuda:1", + ), + ShardMetadata( + shard_lengths=[ + local_rows[2], + tables[0].embedding_dim, + ], + shard_offsets=[2 * block_size, 0], + placement="rank:2/cuda:2", + ), + ShardMetadata( + shard_lengths=[ + local_rows[3], + tables[0].embedding_dim, + ], + shard_offsets=[3 * block_size, 0], + placement="rank:3/cuda:3", + ), + ], + ), + ), + "table_1": ParameterSharding( + sharding_type=ShardingType.ROW_WISE.value, + compute_kernel="dense", + ranks=None, + block_size=block_size, + sharding_spec=EnumerableShardingSpec( + shards=[ + ShardMetadata( + shard_lengths=[ + local_rows[0], + tables[1].embedding_dim, + ], + shard_offsets=[0, 0], + placement="rank:0/cuda:0", + ), + ShardMetadata( + shard_lengths=[ + local_rows[1], + tables[1].embedding_dim, + ], + shard_offsets=[block_size, 0], + placement="rank:1/cuda:1", + ), + ShardMetadata( + shard_lengths=[ + local_rows[2], + tables[1].embedding_dim, + ], + shard_offsets=[2 * block_size, 0], + placement="rank:2/cuda:2", + ), + ShardMetadata( + shard_lengths=[ + local_rows[3], + tables[1].embedding_dim, + ], + shard_offsets=[3 * block_size, 0], + placement="rank:3/cuda:3", + ), + ], + ), + ), + "table_2": ParameterSharding( + sharding_type=ShardingType.DATA_PARALLEL.value, + compute_kernel="dense", + ranks=None, + ), + "table_3": ParameterSharding( + sharding_type=ShardingType.DATA_PARALLEL.value, + compute_kernel="dense", + ranks=None, + ), + } + } + model = TestSparseNN(tables=tables, weighted_tables=[]) + + world_size = 4 + planner = EmbeddingShardingPlanner( + world_size=world_size, + compute_device_type=self.compute_device_type, + # pyre-fixme[6]: Expected `Optional[typing.Dict[str, int]]` for 3rd + # param but got `Dict[str, float]`. + storage=storage, + ) + sharders = [DPRWTWSharder()] + # pyre-ignore [6] + output = planner.plan(model, sharders) + self.assertEqual(output.plan, expected_plan) + + # check logger + self.assertEqual( + mock_logger.mock_calls[1 : world_size + 1], + [ + call.info( + f" Rank {rank} -- HBM/DDR: 0.5/0.0, Cost: 31298, Mean Pooling: 0, Emb Dims: 64, Shards: {{'row_wise': 2, 'data_parallel': 2}}" + ) + for rank in range(world_size) + ], + ) + + @patch("torchrec.distributed.planner.embedding_planner.logger", create=True) + def test_allocation_planner_cw_balanced(self, mock_logger: MagicMock) -> None: + tables = [ + EmbeddingBagConfig( + num_embeddings=100, + embedding_dim=128, + name="table_0", + feature_names=["feature_0"], + ) + ] + storage = {"hbm": 1} + block_size, residual = divmod(128, 2) + expected_plan = { + "sparse.ebc": { + "table_0": ParameterSharding( + sharding_type=ShardingType.COLUMN_WISE.value, + compute_kernel="dense", + ranks=[ + 0, + 1, + ], + block_size=block_size, + sharding_spec=EnumerableShardingSpec( + shards=[ + ShardMetadata( + shard_lengths=[ + tables[0].num_embeddings, + block_size, + ], + shard_offsets=[0, 0], + placement="rank:0/cuda:0", + ), + ShardMetadata( + shard_lengths=[ + tables[0].num_embeddings, + block_size, + ], + shard_offsets=[0, block_size], + placement="rank:1/cuda:1", + ), + ] + ), + ), + } + } + + model = TestSparseNN(tables=tables, weighted_tables=[]) + world_size = 2 + planner = EmbeddingShardingPlanner( + world_size=world_size, + compute_device_type=self.compute_device_type, + storage=storage, + hints={ + "table_0": ParameterHints( + sharding_types=[ShardingType.COLUMN_WISE.value], + col_wise_shard_dim=64, + ), + }, + ) + + sharders = [CWSharder()] + # pyre-ignore [6] + output = planner.plan(model, sharders) + self.assertEqual(output.plan, expected_plan) + # check logger + self.assertEqual( + mock_logger.mock_calls[1 : world_size + 1], + [ + call.info( + " Rank 0 -- HBM/DDR: 0.0/0.0, Cost: 12800, Mean Pooling: 0, Emb Dims: 64, Shards: {'column_wise': 1}" + ), + call.info( + " Rank 1 -- HBM/DDR: 0.0/0.0, Cost: 12800, Mean Pooling: 0, Emb Dims: 64, Shards: {'column_wise': 1}" + ), + ], + ) + + @patch("torchrec.distributed.planner.embedding_planner.logger", create=True) + def test_allocation_planner_cw_two_big_rest_small_with_residual( + self, mock_logger: MagicMock + ) -> None: + big_hash = int(1024 * 1024 * 1024 / 16 / 4) + small_hash = 1000 + tables = [ + EmbeddingBagConfig( + num_embeddings=(big_hash if i <= 1 else small_hash) // 4, + embedding_dim=62, + name="table_" + str(i), + feature_names=["feature_" + str(i)], + ) + for i in range(4) + ] + + num_shards, residual = divmod(62, MIN_DIM) + if residual > 0: + num_shards += 1 + + storage = {"hbm": 0.6} + + expected_plan = { + "sparse.ebc": { + "table_0": ParameterSharding( + sharding_type=ShardingType.COLUMN_WISE.value, + compute_kernel="dense", + ranks=[0, 1], + block_size=MIN_DIM, + sharding_spec=EnumerableShardingSpec( + shards=[ + ShardMetadata( + shard_lengths=[ + tables[0].num_embeddings, + MIN_DIM, + ], + shard_offsets=[0, 0], + placement="rank:0/cuda:0", + ), + ShardMetadata( + shard_lengths=[ + tables[0].num_embeddings, + residual, + ], + shard_offsets=[0, MIN_DIM], + placement="rank:1/cuda:1", + ), + ] + ), + ), + "table_1": ParameterSharding( + sharding_type=ShardingType.COLUMN_WISE.value, + compute_kernel="dense", + ranks=[2, 3], + block_size=MIN_DIM, + sharding_spec=EnumerableShardingSpec( + shards=[ + ShardMetadata( + shard_lengths=[ + tables[1].num_embeddings, + MIN_DIM, + ], + shard_offsets=[0, 0], + placement="rank:2/cuda:2", + ), + ShardMetadata( + shard_lengths=[ + tables[1].num_embeddings, + residual, + ], + shard_offsets=[0, MIN_DIM], + placement="rank:3/cuda:3", + ), + ] + ), + ), + "table_2": ParameterSharding( + sharding_type=ShardingType.DATA_PARALLEL.value, + compute_kernel="dense", + ranks=None, + ), + "table_3": ParameterSharding( + sharding_type=ShardingType.DATA_PARALLEL.value, + compute_kernel="dense", + ranks=None, + ), + } + } + model = TestSparseNN(tables=tables, weighted_tables=[]) + + world_size = 4 + planner = EmbeddingShardingPlanner( + world_size=world_size, + compute_device_type=self.compute_device_type, + # pyre-fixme[6]: Expected `Optional[typing.Dict[str, int]]` for 3rd + # param but got `Dict[str, float]`. + storage=storage, + hints={ + "table_0": ParameterHints( + sharding_types=[ShardingType.COLUMN_WISE.value], + col_wise_shard_dim=32, + ), + "table_1": ParameterHints( + sharding_types=[ShardingType.COLUMN_WISE.value], + col_wise_shard_dim=32, + ), + }, + ) + sharders = [DPCWSharder()] + # pyre-ignore [6] + output = planner.plan(model, sharders) + self.assertEqual(output.plan, expected_plan) + + # check logger + self.assertEqual( + mock_logger.mock_calls[1 : world_size + 1], + [ + call.info( + " Rank 0 -- HBM/DDR: 0.5/0.0, Cost: 27964, Mean Pooling: 0, Emb Dims: 156, Shards: {'column_wise': 1, 'data_parallel': 2}" + ), + call.info( + " Rank 1 -- HBM/DDR: 0.5/0.0, Cost: 27964, Mean Pooling: 0, Emb Dims: 154, Shards: {'column_wise': 1, 'data_parallel': 2}" + ), + call.info( + " Rank 2 -- HBM/DDR: 0.5/0.0, Cost: 27964, Mean Pooling: 0, Emb Dims: 156, Shards: {'column_wise': 1, 'data_parallel': 2}" + ), + call.info( + " Rank 3 -- HBM/DDR: 0.5/0.0, Cost: 27964, Mean Pooling: 0, Emb Dims: 154, Shards: {'column_wise': 1, 'data_parallel': 2}" + ), + ], + ) diff --git a/torchrec/distributed/planner/types.py b/torchrec/distributed/planner/types.py new file mode 100644 index 000000000..ba444cc7c --- /dev/null +++ b/torchrec/distributed/planner/types.py @@ -0,0 +1,154 @@ +#!/usr/bin/env python3 + +from dataclasses import dataclass, field +from typing import Optional, List, Dict, Deque, Tuple + +import torch +from torchrec.distributed.types import ParameterStorage + + +@dataclass +class ParameterHints: + """ + Stores user provided hints around + sharding types and compute kernels + """ + + sharding_types: Optional[List[str]] = None + compute_kernels: Optional[List[str]] = None + col_wise_shard_dim: Optional[int] = None + + +@dataclass +class ParameterInputStats: + """ + Stores statistics around input data for + a given parameter + """ + + mean: Optional[List[float]] = None + std: Optional[List[float]] = None + + +@dataclass +class Storage: + capacity: int = 0 + free: int = 0 + + +@dataclass +class DeviceInfo: + rank: int + compute_device: str = "cpu" + total_cost: int = 0 + # Device level storage + hbm: Storage = field(default_factory=Storage) + + def __lt__(self, other: "DeviceInfo") -> bool: + return (self.total_cost, -self.hbm.free, self.rank) < ( + other.total_cost, + -other.hbm.free, + other.rank, + ) + + +@dataclass +class HostInfo: + devices: List[DeviceInfo] + # Host level storage + ddr: Storage = field(default_factory=Storage) + + +@dataclass +class Topology: + hosts: List[HostInfo] + world_size: int + host_and_device_by_rank: Dict[int, Tuple[int, int]] = field(default_factory=dict) + + def get_host(self, rank: int) -> HostInfo: + host_idx, _ = self.host_and_device_by_rank[rank] + return self.hosts[host_idx] + + def get_device(self, rank: int) -> DeviceInfo: + host_idx, device_idx = self.host_and_device_by_rank[rank] + return self.hosts[host_idx].devices[device_idx] + + +@dataclass +class ShardingOption: + sharding_type: str + compute_kernel: str + storage_usage: Dict[str, int] + cost: int = 0 + ranks: Optional[List[int]] = None + _num_col_wise_shards: Optional[int] = None + col_wise_shard_dim: Optional[int] = None + + def __lt__(self, other: "ShardingOption") -> bool: + """ + Sharding option with lowest cost is preferable + If cost same, pick option with lowest (HBM, DDR, SDD) usage + """ + return ( + self.cost, + self.storage_usage.get(ParameterStorage.HBM.value, 0), + self.storage_usage.get(ParameterStorage.DDR.value, 0), + ) < ( + other.cost, + other.storage_usage.get(ParameterStorage.HBM.value, 0), + other.storage_usage.get(ParameterStorage.DDR.value, 0), + ) + + +@dataclass +class CostInput: + param: torch.Tensor + compute_device_type: str + compute_kernel: str + sharding_type: str + input_stats: Optional[ParameterInputStats] + + +@dataclass +class ParameterInfo: + param: torch.Tensor + name: str + prefix: str + sharding_options: Deque[ShardingOption] + + @property + def fqn(self) -> str: + return self.prefix + "." + self.name + + +@dataclass +class ParamSortKey: + compute_cost: int + storage_cost: int + sharding_cost: int + fqn: str + sort_by: str = "compute" + + def __lt__(self, other: "ParamSortKey") -> bool: + if self.sort_by == "compute": + return self._lt_compute_cost(other) + elif self.sort_by == "storage": + return self._lt_storage_cost(other) + else: + raise ValueError(f"Invalid sort_by value {self.sort_by}") + + def _lt_compute_cost(self, other: "ParamSortKey") -> bool: + return ( + -self.compute_cost, + -self.storage_cost, + self.sharding_cost, + self.fqn, + ) < (-other.compute_cost, -other.storage_cost, other.sharding_cost, other.fqn) + + def _lt_storage_cost(self, other: "ParamSortKey") -> bool: + return ( + -self.storage_cost, + self.sharding_cost, + -self.compute_cost, + self.fqn, + ) < (-other.storage_cost, other.sharding_cost, -other.compute_cost, other.fqn) diff --git a/torchrec/distributed/planner/utils.py b/torchrec/distributed/planner/utils.py new file mode 100644 index 000000000..ba944b517 --- /dev/null +++ b/torchrec/distributed/planner/utils.py @@ -0,0 +1,292 @@ +#!/usr/bin/env python3 + +import math +from typing import Any, Type, Dict, Optional, List, cast + +from torchrec.distributed.comm import get_local_size, get_num_groups +from torchrec.distributed.planner.parameter_sharding import ParameterShardingFactory +from torchrec.distributed.planner.types import ( + ShardingOption, + Topology, + DeviceInfo, + HostInfo, + ParameterInfo, + Storage, + ParamSortKey, +) +from torchrec.distributed.types import ( + ParameterStorage, + ShardingPlan, + ShardingType, +) + +DEFAULT_DDR_STORAGE: int = 4 * 1024 * 1024 * 1024 * 1024 # 4 TB +DEFAULT_HBM_STORAGE: int = 32 * 1024 * 1024 * 1024 # 32 GB +MIN_DIM: int = 32 + +SHARDING_PREFERENCE: Dict[str, int] = { + ShardingType.DATA_PARALLEL.value: 0, + ShardingType.TABLE_WISE.value: 1, + ShardingType.TABLE_ROW_WISE.value: 2, + ShardingType.ROW_WISE.value: 3, + ShardingType.COLUMN_WISE.value: 4, +} + + +def gb_to_bytes(gb: float) -> int: + return int(gb * 1024 * 1024 * 1024) + + +def bytes_to_gb(num_bytes: int) -> float: + return float(num_bytes / (1024 * 1024 * 1024)) + + +# pyre-ignore[2] +def sharder_name(t: Type[Any]) -> str: + return t.__module__ + "." + t.__name__ + + +def is_enough_storage( + sharding_option: ShardingOption, + topology: Topology, + device: Optional[DeviceInfo] = None, +) -> bool: + storage = sharding_option.storage_usage + device_ranks = range(topology.world_size) + host_ranks = range(len(topology.hosts)) + if sharding_option.sharding_type == ShardingType.DATA_PARALLEL.value: + pass + elif sharding_option.sharding_type == ShardingType.ROW_WISE.value: + storage = { + k: math.ceil(v / len(device_ranks if k == "hbm" else host_ranks)) + for k, v in storage.items() + } + elif sharding_option.sharding_type == ShardingType.TABLE_ROW_WISE.value: + assert ( + device is not None + ), "Sharding option must have a device for TWRW storage calcuation" + device_ranks = [ + device.rank for device in topology.get_host(device.rank).devices + ] + host_ranks = [topology.host_and_device_by_rank[device.rank][0]] + storage = { + k: math.ceil(v / len(device_ranks if k == "hbm" else host_ranks)) + for k, v in storage.items() + } + elif sharding_option.sharding_type == ShardingType.TABLE_WISE.value: + assert ( + device is not None + ), "Sharding option must have a device for TW storage calcuation" + device_ranks = [device.rank] + host_ranks = [topology.host_and_device_by_rank[device.rank][0]] + elif sharding_option.sharding_type == ShardingType.COLUMN_WISE.value: + assert ( + device is not None + ), "Sharding option must have a device for CW storage calcuation" + device_ranks = [device.rank] + host_ranks = [topology.host_and_device_by_rank[device.rank][0]] + storage = { + # pyre-fixme[58] + k: math.ceil(v / sharding_option._num_col_wise_shards) + for k, v in storage.items() + } + else: + raise ValueError(f"unsupported sharding_type {sharding_option.sharding_type}") + for storage_type, storage_usage in storage.items(): + if storage_type == ParameterStorage.HBM.value: + for device_rank in device_ranks: + if topology.get_device(device_rank).hbm.free < storage_usage: + return False + elif storage_type == ParameterStorage.DDR.value: + for host_rank in host_ranks: + if topology.get_host(host_rank).ddr.free < storage_usage: + return False + else: + raise ValueError(f"Unknown ParameterStorage type {storage_type}") + return True + + +def allocate_param( + sharding_option: ShardingOption, topology: Topology, is_deallocation: bool = False +) -> None: + """ + Reduces relevant free storage in toplogy based on sharding option + + Setting is_deallocation=True will do inverse (free up storage) + """ + storage = sharding_option.storage_usage + device_ranks = range(topology.world_size) + host_ranks = range(len(topology.hosts)) + if sharding_option.sharding_type == ShardingType.DATA_PARALLEL.value: + pass + elif sharding_option.sharding_type == ShardingType.ROW_WISE.value: + storage = { + k: math.ceil(v / len(device_ranks if k == "hbm" else host_ranks)) + for k, v in storage.items() + } + elif sharding_option.sharding_type == ShardingType.TABLE_ROW_WISE.value: + assert ( + sharding_option.ranks is not None + ), "Sharding option must have a device for TWRW storage calcuation" + device_ranks = [ + device.rank + # pyre-fixme[22]: The cast is redundant. + for device in topology.get_host(cast(int, sharding_option.ranks[0])).devices + ] + host_ranks = [ + # pyre-fixme[22]: The cast is redundant. + topology.host_and_device_by_rank[cast(int, sharding_option.ranks[0])][0] + ] + storage = { + k: math.ceil(v / len(device_ranks if k == "hbm" else host_ranks)) + for k, v in storage.items() + } + elif sharding_option.sharding_type == ShardingType.TABLE_WISE.value: + assert ( + sharding_option.ranks is not None + ), "Sharding option must have a device for TW storage calcuation" + # pyre-fixme[22]: The cast is redundant. + device_ranks = [cast(int, sharding_option.ranks[0])] + # pyre-fixme[16] + host_ranks = [topology.host_and_device_by_rank[sharding_option.ranks[0]][0]] + elif sharding_option.sharding_type == ShardingType.COLUMN_WISE.value: + assert ( + sharding_option.ranks is not None + ), "Sharding option must have at least one device for CW storage calcuation" + # for col-wise sharding, we allocate one shard at a time + device_ranks = [sharding_option.ranks[-1]] + host_ranks = [topology.host_and_device_by_rank[sharding_option.ranks[-1]][0]] + storage = { + # pyre-fixme[58] + k: math.ceil(v / sharding_option._num_col_wise_shards) + for k, v in storage.items() + } + else: + raise ValueError(f"unsupported sharding_type {sharding_option.sharding_type}") + + for storage_type, storage_usage in storage.items(): + if is_deallocation: + storage_usage = -storage_usage + if storage_type == ParameterStorage.HBM.value: + for device_rank in device_ranks: + topology.get_device(device_rank).hbm.free -= storage_usage + elif storage_type == ParameterStorage.DDR.value: + for host_rank in host_ranks: + topology.get_host(host_rank).ddr.free -= storage_usage + else: + raise ValueError(f"Unknown ParameterStorage type {storage_type}") + + for device_rank in device_ranks: + cost = -sharding_option.cost if is_deallocation else sharding_option.cost + topology.get_device(device_rank).total_cost += cost + + +def deallocate_param( + sharding_option: ShardingOption, + topology: Topology, +) -> None: + allocate_param(sharding_option, topology, is_deallocation=True) + + +def param_sort_key( + parameter_info: ParameterInfo, world_size: int, sort_by: str = "compute" +) -> ParamSortKey: + sharding_option = parameter_info.sharding_options[0] + compute_cost = sharding_option.cost + storage_cost = sum(sharding_option.storage_usage.values()) + if sharding_option.sharding_type == ShardingType.DATA_PARALLEL.value: + storage_cost *= world_size + sharding_preference = SHARDING_PREFERENCE[ + parameter_info.sharding_options[0].sharding_type + ] + return ParamSortKey( + compute_cost=compute_cost, + storage_cost=storage_cost, + sharding_cost=sharding_preference, + fqn=parameter_info.fqn, + sort_by=sort_by, + ) + + +def to_plan( + parameter_infos: List[ParameterInfo], + compute_device_type: str, + world_size: int, + local_size: int, +) -> ShardingPlan: + plan = {} + for parameter_info in parameter_infos: + shards = plan.get(parameter_info.prefix, {}) + shards[parameter_info.name] = ParameterShardingFactory.shard_parameters( + param_info=parameter_info, + compute_device_type=compute_device_type, + world_size=world_size, + local_size=local_size, + ) + plan[parameter_info.prefix] = shards + return ShardingPlan(plan) + + +def _get_storage( + compute_device: str, storage_in_gb: Optional[Dict[str, int]] +) -> Dict[str, int]: + if storage_in_gb is None: + storage_in_gb = {} + + hbm = storage_in_gb.get("hbm", None) + if hbm is None and compute_device == "cuda": + hbm = DEFAULT_HBM_STORAGE + elif hbm is None: + hbm = 0 + else: + hbm = gb_to_bytes(hbm) + + ddr = storage_in_gb.get("ddr", None) + if ddr is None: + ddr = DEFAULT_DDR_STORAGE + + return { + "hbm": hbm, + "ddr": ddr, + } + + +def get_topology( + world_size: int, + compute_device_type: str, + storage_in_gb: Optional[Dict[str, int]], +) -> Topology: + devices_per_host = get_local_size(world_size) + num_hosts = get_num_groups(world_size) + compute_device = compute_device_type + storage = _get_storage(compute_device_type, storage_in_gb) + topology = Topology( + hosts=[ + HostInfo( + devices=[ + DeviceInfo( + rank=rank, + compute_device=compute_device, + hbm=Storage( + capacity=storage["hbm"], + free=storage["hbm"], + ), + ) + for rank in range( + num_host * devices_per_host, + min(world_size, (num_host + 1) * devices_per_host), + ) + ], + ddr=Storage( + capacity=storage["ddr"], + free=storage["ddr"], + ), + ) + for num_host in range(num_hosts) + ], + world_size=world_size, + ) + for i, host in enumerate(topology.hosts): + for j, device in enumerate(host.devices): + topology.host_and_device_by_rank[device.rank] = (i, j) + return topology diff --git a/torchrec/distributed/rw_sharding.py b/torchrec/distributed/rw_sharding.py new file mode 100644 index 000000000..d8b431a40 --- /dev/null +++ b/torchrec/distributed/rw_sharding.py @@ -0,0 +1,349 @@ +#!/usr/bin/env python3 + +from typing import List, Optional, Dict, Any, Tuple + +import torch +import torch.distributed as dist +from torch.distributed._sharding_spec import ShardMetadata +from torchrec.distributed.dist_data import ( + PooledEmbeddingsReduceScatter, + SequenceEmbeddingAllToAll, +) +from torchrec.distributed.embedding_lookup import ( + GroupedPooledEmbeddingsLookup, + GroupedEmbeddingsLookup, +) +from torchrec.distributed.embedding_sharding import ( + group_tables, + SparseFeaturesAllToAll, + BasePooledEmbeddingDist, + BaseSparseFeaturesDist, + EmbeddingSharding, + BaseSequenceEmbeddingDist, + SequenceShardingContext, + BaseEmbeddingLookup, + bucketize_kjt_before_all2all, +) +from torchrec.distributed.embedding_types import ( + ShardedEmbeddingTable, + GroupedEmbeddingConfig, + SparseFeatures, + EmbeddingComputeKernel, + BaseGroupedFeatureProcessor, +) +from torchrec.distributed.types import ( + Awaitable, + ParameterSharding, +) +from torchrec.modules.embedding_configs import EmbeddingTableConfig + + +class RwSparseFeaturesDist(BaseSparseFeaturesDist): + def __init__( + self, + pg: dist.ProcessGroup, + num_id_list_features: int, + num_id_score_list_features: int, + id_list_feature_hash_sizes: List[int], + id_score_list_feature_hash_sizes: List[int], + device: Optional[torch.device] = None, + is_sequence: bool = False, + has_feature_processor: bool = False, + ) -> None: + super().__init__() + self._world_size: int = pg.size() + self._num_id_list_features = num_id_list_features + self._num_id_score_list_features = num_id_score_list_features + id_list_feature_block_sizes = [ + (hash_size + self._world_size - 1) // self._world_size + for hash_size in id_list_feature_hash_sizes + ] + id_score_list_feature_block_sizes = [ + (hash_size + self._world_size - 1) // self._world_size + for hash_size in id_score_list_feature_hash_sizes + ] + self.register_buffer( + "_id_list_feature_block_sizes_tensor", + torch.tensor( + id_list_feature_block_sizes, + device=device, + dtype=torch.int32, + ), + ) + self.register_buffer( + "_id_score_list_feature_block_sizes_tensor", + torch.tensor( + id_score_list_feature_block_sizes, + device=device, + dtype=torch.int32, + ), + ) + self._dist = SparseFeaturesAllToAll( + pg, + self._world_size * [self._num_id_list_features], + self._world_size * [self._num_id_score_list_features], + device, + ) + self._is_sequence = is_sequence + self._has_feature_processor = has_feature_processor + self.unbucketize_permute_tensor: Optional[torch.Tensor] = None + + def forward( + self, + sparse_features: SparseFeatures, + ) -> Awaitable[SparseFeatures]: + if self._num_id_list_features > 0: + assert sparse_features.id_list_features is not None + ( + id_list_features, + self.unbucketize_permute_tensor, + ) = bucketize_kjt_before_all2all( + sparse_features.id_list_features, + num_buckets=self._world_size, + block_sizes=self._id_list_feature_block_sizes_tensor, + output_permute=self._is_sequence, + bucketize_pos=self._has_feature_processor, + ) + else: + id_list_features = None + + if self._num_id_score_list_features > 0: + assert sparse_features.id_score_list_features is not None + id_score_list_features, _ = bucketize_kjt_before_all2all( + sparse_features.id_score_list_features, + num_buckets=self._world_size, + block_sizes=self._id_score_list_feature_block_sizes_tensor, + output_permute=False, + bucketize_pos=False, + ) + else: + id_score_list_features = None + + bucketized_sparse_features = SparseFeatures( + id_list_features=id_list_features, + id_score_list_features=id_score_list_features, + ) + return self._dist(bucketized_sparse_features) + + +class RwPooledEmbeddingDist(BasePooledEmbeddingDist): + def __init__( + self, + pg: dist.ProcessGroup, + ) -> None: + super().__init__() + self._dist = PooledEmbeddingsReduceScatter(pg) + + def forward(self, local_embs: torch.Tensor) -> Awaitable[torch.Tensor]: + return self._dist(local_embs) + + +class RwSequenceEmbeddingDist(BaseSequenceEmbeddingDist): + def __init__( + self, + pg: dist.ProcessGroup, + num_features: int, + device: Optional[torch.device] = None, + ) -> None: + super().__init__() + self._dist = SequenceEmbeddingAllToAll(pg, [num_features] * pg.size(), device) + + def forward( + self, sharding_ctx: SequenceShardingContext, local_embs: torch.Tensor + ) -> Awaitable[torch.Tensor]: + return self._dist( + local_embs=local_embs, + lengths=sharding_ctx.lengths_after_input_dist, + input_splits=sharding_ctx.input_splits, + output_splits=sharding_ctx.output_splits, + unbucketize_permute_tensor=sharding_ctx.unbucketize_permute_tensor, + ) + + +class RwEmbeddingSharding(EmbeddingSharding): + """ + Shards embedding bags row-wise, i.e.. a given embedding table is evenly distribued by rows and table slices are placed on all ranks. + """ + + def __init__( + self, + embedding_configs: List[ + Tuple[EmbeddingTableConfig, ParameterSharding, torch.Tensor] + ], + pg: dist.ProcessGroup, + device: Optional[torch.device] = None, + is_sequence: bool = False, + ) -> None: + super().__init__() + self._pg = pg + if device is None: + device = torch.device("cpu") + self._device = device + self._is_sequence = is_sequence + sharded_tables_per_rank = self._shard(embedding_configs) + self._grouped_embedding_configs_per_rank: List[ + List[GroupedEmbeddingConfig] + ] = [] + self._score_grouped_embedding_configs_per_rank: List[ + List[GroupedEmbeddingConfig] + ] = [] + ( + self._grouped_embedding_configs_per_rank, + self._score_grouped_embedding_configs_per_rank, + ) = group_tables(sharded_tables_per_rank) + self._grouped_embedding_configs: List[ + GroupedEmbeddingConfig + ] = self._grouped_embedding_configs_per_rank[dist.get_rank(pg)] + self._score_grouped_embedding_configs: List[ + GroupedEmbeddingConfig + ] = self._score_grouped_embedding_configs_per_rank[dist.get_rank(pg)] + + self._has_feature_processor: bool = False + for group_config in self._grouped_embedding_configs: + if group_config.has_feature_processor: + self._has_feature_processor = True + + def _shard( + self, + embedding_configs: List[ + Tuple[EmbeddingTableConfig, ParameterSharding, torch.Tensor] + ], + ) -> List[List[ShardedEmbeddingTable]]: + world_size = self._pg.size() + tables_per_rank: List[List[ShardedEmbeddingTable]] = [ + [] for i in range(world_size) + ] + for config in embedding_configs: + # pyre-fixme [16] + shards = config[1].sharding_spec.shards + + for rank in range(world_size): + tables_per_rank[rank].append( + ShardedEmbeddingTable( + num_embeddings=config[0].num_embeddings, + embedding_dim=config[0].embedding_dim, + name=config[0].name, + embedding_names=config[0].embedding_names, + data_type=config[0].data_type, + feature_names=config[0].feature_names, + pooling=config[0].pooling, + is_weighted=config[0].is_weighted, + has_feature_processor=config[0].has_feature_processor, + local_rows=shards[rank].shard_lengths[0], + local_cols=config[0].embedding_dim, + block_size=config[1].block_size, + compute_kernel=EmbeddingComputeKernel(config[1].compute_kernel), + local_metadata=shards[rank], + weight_init_max=config[0].weight_init_max, + weight_init_min=config[0].weight_init_min, + ) + ) + return tables_per_rank + + def create_input_dist(self) -> BaseSparseFeaturesDist: + num_id_list_features = self._get_id_list_features_num() + num_id_score_list_features = self._get_id_score_list_features_num() + id_list_feature_hash_sizes = self._get_id_list_features_hash_sizes() + id_score_list_feature_hash_sizes = self._get_id_score_list_features_hash_sizes() + return RwSparseFeaturesDist( + pg=self._pg, + num_id_list_features=num_id_list_features, + num_id_score_list_features=num_id_score_list_features, + id_list_feature_hash_sizes=id_list_feature_hash_sizes, + id_score_list_feature_hash_sizes=id_score_list_feature_hash_sizes, + device=self._device, + is_sequence=self._is_sequence, + has_feature_processor=self._has_feature_processor, + ) + + def create_lookup( + self, + fused_params: Optional[Dict[str, Any]], + feature_processor: Optional[BaseGroupedFeatureProcessor] = None, + ) -> BaseEmbeddingLookup: + if self._is_sequence: + return GroupedEmbeddingsLookup( + grouped_configs=self._grouped_embedding_configs, + fused_params=fused_params, + pg=self._pg, + device=self._device, + ) + else: + return GroupedPooledEmbeddingsLookup( + grouped_configs=self._grouped_embedding_configs, + grouped_score_configs=self._score_grouped_embedding_configs, + fused_params=fused_params, + pg=self._pg, + device=self._device, + feature_processor=feature_processor, + ) + + def create_pooled_output_dist(self) -> RwPooledEmbeddingDist: + return RwPooledEmbeddingDist(self._pg) + + def create_sequence_output_dist(self) -> RwSequenceEmbeddingDist: + return RwSequenceEmbeddingDist( + self._pg, + self._get_id_list_features_num(), + self._device, + ) + + def embedding_dims(self) -> List[int]: + embedding_dims = [] + for grouped_config in self._grouped_embedding_configs: + embedding_dims.extend(grouped_config.embedding_dims()) + for grouped_config in self._score_grouped_embedding_configs: + embedding_dims.extend(grouped_config.embedding_dims()) + return embedding_dims + + def embedding_names(self) -> List[str]: + embedding_names = [] + for grouped_config in self._grouped_embedding_configs: + embedding_names.extend(grouped_config.embedding_names()) + for grouped_config in self._score_grouped_embedding_configs: + embedding_names.extend(grouped_config.embedding_names()) + return embedding_names + + def embedding_shard_metadata(self) -> List[Optional[ShardMetadata]]: + embedding_shard_metadata = [] + for grouped_config in self._grouped_embedding_configs: + embedding_shard_metadata.extend(grouped_config.embedding_shard_metadata()) + for grouped_config in self._score_grouped_embedding_configs: + embedding_shard_metadata.extend(grouped_config.embedding_shard_metadata()) + return embedding_shard_metadata + + def id_list_feature_names(self) -> List[str]: + id_list_feature_names = [] + for grouped_config in self._grouped_embedding_configs: + id_list_feature_names.extend(grouped_config.feature_names()) + return id_list_feature_names + + def id_score_list_feature_names(self) -> List[str]: + id_score_list_feature_names = [] + for grouped_config in self._score_grouped_embedding_configs: + id_score_list_feature_names.extend(grouped_config.feature_names()) + return id_score_list_feature_names + + def _get_id_list_features_num(self) -> int: + return sum( + group_config.num_features() + for group_config in self._grouped_embedding_configs + ) + + def _get_id_score_list_features_num(self) -> int: + return sum( + group_config.num_features() + for group_config in self._score_grouped_embedding_configs + ) + + def _get_id_list_features_hash_sizes(self) -> List[int]: + id_list_feature_hash_sizes: List[int] = [] + for group_config in self._grouped_embedding_configs: + id_list_feature_hash_sizes.extend(group_config.feature_hash_sizes()) + return id_list_feature_hash_sizes + + def _get_id_score_list_features_hash_sizes(self) -> List[int]: + id_score_list_feature_hash_sizes: List[int] = [] + for group_config in self._score_grouped_embedding_configs: + id_score_list_feature_hash_sizes.extend(group_config.feature_hash_sizes()) + return id_score_list_feature_hash_sizes diff --git a/torchrec/distributed/tests/collective_utils_test.py b/torchrec/distributed/tests/collective_utils_test.py new file mode 100644 index 000000000..faa41c3c4 --- /dev/null +++ b/torchrec/distributed/tests/collective_utils_test.py @@ -0,0 +1,119 @@ +#!/usr/bin/env python3 + +import logging +import os +from unittest import mock + +import caffe2.torch.fb.distributed.utils.log_utils as log_utils +import torch.distributed as dist +from torch._utils_internal import TEST_MASTER_ADDR as MASTER_ADDR # @manual +from torch.testing._internal.common_distributed import MultiProcessTestCase # @manual +from torchrec.distributed.collective_utils import ( + is_leader, + invoke_on_rank_and_broadcast_result, + run_on_leader, +) +from torchrec.tests.utils import get_free_port + + +logger: logging.Logger = log_utils.getLogger() + + +""" +buck test @mode/dev-nosan //torchrec/distributed/tests:collective_utils_test + +Mirrors the test cases implemented for ExtendProcessGroup collective_utils located in: +fbcode/caffe2/torch/fb/hpc/tests/collective_utils_test.py +""" + + +class CollectiveUtilsTest(MultiProcessTestCase): + @classmethod + def setUpClass(cls) -> None: + os.environ["MASTER_ADDR"] = str(MASTER_ADDR) + os.environ["MASTER_PORT"] = str(get_free_port()) + os.environ["GLOO_DEVICE_TRANSPORT"] = "TCP" + super().setUpClass() + + def setUp(self) -> None: + super(CollectiveUtilsTest, self).setUp() + self._spawn_processes() + + def tearDown(self) -> None: + super(CollectiveUtilsTest, self).tearDown() + try: + os.remove(self.file_name) + except OSError: + pass + + @property + def world_size(self) -> int: + return 2 + + def test_is_leader(self) -> None: + dist.init_process_group( + rank=self.rank, world_size=self.world_size, backend="gloo" + ) + pg = dist.new_group( + ranks=[0, 1], + backend="gloo", + ) + + if pg.rank() == 0: + assert is_leader(pg, 0) is True + assert is_leader(pg, 1) is False + else: + assert is_leader(pg, 1) is True + assert is_leader(pg, 0) is False + + def test_invoke_on_rank_and_broadcast_result(self) -> None: + dist.init_process_group( + rank=self.rank, world_size=self.world_size, backend="gloo" + ) + pg = dist.new_group( + ranks=[0, 1], + backend="gloo", + ) + + func = mock.MagicMock() + func.return_value = pg.rank() + + res = invoke_on_rank_and_broadcast_result(pg=pg, rank=0, func=func) + assert res == 0, f"Expect res to be 0 (got {res})" + + if pg.rank() == 0: + func.assert_called_once() + else: + func.assert_not_called() + func.reset_mock() + + res = invoke_on_rank_and_broadcast_result(pg=pg, rank=1, func=func) + assert res == 1, f"Expect res to be 1 (got {res})" + + if pg.rank() == 0: + func.assert_not_called() + else: + func.assert_called_once() + + def test_run_on_leader_decorator(self) -> None: + dist.init_process_group( + rank=self.rank, world_size=self.world_size, backend="gloo" + ) + pg = dist.new_group( + ranks=[0, 1], + backend="gloo", + ) + + @run_on_leader(pg, 0) + def _test_run_on_0(rank: int) -> int: + return rank + + res = _test_run_on_0(pg.rank()) + assert res == 0 + + @run_on_leader(pg, 1) + def _test_run_on_1(rank: int) -> int: + return rank + + res = _test_run_on_1(pg.rank()) + assert res == 1 diff --git a/torchrec/distributed/tests/test_comm.py b/torchrec/distributed/tests/test_comm.py new file mode 100644 index 000000000..04d39e07a --- /dev/null +++ b/torchrec/distributed/tests/test_comm.py @@ -0,0 +1,151 @@ +#!/usr/bin/env python3 + +import itertools +import os +import unittest + +import numpy +import torch +import torch.distributed as dist +import torchrec.distributed.comm_ops as comm_ops +from torch._utils_internal import TEST_MASTER_ADDR as MASTER_ADDR # @manual +from torch.testing._internal.common_distributed import MultiProcessTestCase # @manual +from torchrec.tests.utils import get_free_port + + +class TestAllToAll(MultiProcessTestCase): + @classmethod + def setUpClass(cls) -> None: + os.environ["MASTER_ADDR"] = str(MASTER_ADDR) + os.environ["MASTER_PORT"] = str(get_free_port()) + super().setUpClass() + + def setUp(self) -> None: + super(TestAllToAll, self).setUp() + self._spawn_processes() + + def tearDown(self) -> None: + super(TestAllToAll, self).tearDown() + try: + os.remove(self.file_name) + except OSError: + pass + + @property + def world_size(self) -> int: + return 2 + + # pyre-fixme[56]: Pyre was not able to infer the type of argument + # `torch.cuda.device_count() = 0` to decorator factory `unittest.skipIf`. + @unittest.skipIf( + torch.cuda.device_count() < 2, "Need at least two ranks to run this test" + ) + def test_alltoallv(self) -> None: + dist.init_process_group( + rank=self.rank, world_size=self.world_size, backend="nccl" + ) + device = torch.device(f"cuda:{self.rank}") + + torch.cuda.set_device(device) + + B_global = 10 + D0 = 8 + D1 = 9 + + input_embedding0 = torch.rand( + (B_global, D0), + device=device, + requires_grad=True, + ) + input_embedding1 = torch.rand( + (B_global, D1), + device=device, + requires_grad=True, + ) + + input_embeddings = [input_embedding0, input_embedding1] + out_split = [17, 17] + + a2a_req = comm_ops.alltoallv(input_embeddings, out_split) + v_embs_out = a2a_req.wait() + res = torch.cat(v_embs_out, dim=1).cpu() + self.assertEqual(tuple(res.size()), (5, 34)) + + # pyre-fixme[56]: Pyre was not able to infer the type of argument + # `torch.cuda.device_count() = 0` to decorator factory `unittest.skipIf`. + @unittest.skipIf( + torch.cuda.device_count() < 2, "Need at least two ranks to run this test" + ) + def test_alltoall_sequence(self) -> None: + dist.init_process_group( + rank=self.rank, world_size=self.world_size, backend="nccl" + ) + device = torch.device(f"cuda:{self.rank}") + torch.cuda.set_device(device) + + ranks = 2 + tables_mp = [[0], [1, 2]] + lengths_dp = [ + numpy.array([[1, 2], [1, 1], [2, 1]]), + numpy.array([[1, 2], [2, 1], [3, 1]]), + ] # W, T_g, B_l + lengths_a2a = [ + numpy.array([[[1, 2]], [[1, 2]]]), # Rank 0 + numpy.array( + [ + [[1, 1], [2, 1]], # from Rank 0 + [[2, 1], [3, 1]], # from rank 1 + ] + ), # Rank 1 + ] # W, W, T_l, B_l + lengths_mp = [ + numpy.array( + [ + [1, 2, 1, 2], + ] + ), + numpy.array([[1, 1, 2, 1], [2, 1, 3, 1]]), + ] # w, t_l, b_g + input_seg = list(itertools.accumulate([0] + [len(i) for i in tables_mp])) + input_splits = [ + [ + lengths_dp[i][input_seg[j] : input_seg[j + 1], :].sum() + for j in range(ranks) + ] + for i in range(ranks) + ] + output_splits = [lengths_a2a[i].sum(axis=(1, 2)).tolist() for i in range(ranks)] + table_dim = 3 + num_features_per_rank = [len(features) for features in tables_mp] + seq_all2all_forward_recat = [] + for j in range(ranks): + for i in range(num_features_per_rank[self.rank]): + seq_all2all_forward_recat.append(j + i * ranks) + seq_all2all_forward_recat_tensor = torch.IntTensor(seq_all2all_forward_recat) + seq_all2all_backward_recat = [] + for i in range(num_features_per_rank[self.rank]): + for j in range(ranks): + seq_all2all_backward_recat.append( + i + j * num_features_per_rank[self.rank] + ) + + seq_all2all_backward_recat_tensor = torch.IntTensor(seq_all2all_backward_recat) + input_embeddings = torch.rand( + lengths_mp[self.rank].sum(), + table_dim, + device=device, + requires_grad=True, + ) + lengths_after_sparse_data_all2all = torch.IntTensor(lengths_mp[self.rank]) + a2a_req = comm_ops.alltoall_sequence( + a2a_sequence_embs_tensor=input_embeddings.cuda(), + forward_recat_tensor=seq_all2all_forward_recat_tensor.cuda(), + backward_recat_tensor=seq_all2all_backward_recat_tensor.cuda(), + lengths_after_sparse_data_all2all=lengths_after_sparse_data_all2all.cuda(), + input_splits=input_splits[self.rank], + output_splits=output_splits[self.rank], + ) + seq_embs_out = a2a_req.wait() + seq_embs_out.backward(seq_embs_out) + grad = input_embeddings.grad + self.assertEqual(input_embeddings.cpu().detach(), grad.cpu().detach()) diff --git a/torchrec/distributed/tests/test_dist_data.py b/torchrec/distributed/tests/test_dist_data.py new file mode 100644 index 000000000..1c7fded6d --- /dev/null +++ b/torchrec/distributed/tests/test_dist_data.py @@ -0,0 +1,411 @@ +#!/usr/bin/env python3 + +import abc +import itertools +import multiprocessing +import os +import random +import unittest +from typing import List, Tuple, TypeVar, Any, Generator, Union + +import hypothesis.strategies as st +import torch +import torch.distributed as dist +from hypothesis import given, settings + +# @manual=//python/wheel/numpy:numpy +from numpy.testing import assert_array_equal +from torchrec.distributed.dist_data import ( + KJTAllToAll, + PooledEmbeddingsAllToAll, + PooledEmbeddingsReduceScatter, + KJTAllToAllAwaitable, +) +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor +from torchrec.tests.utils import get_free_port, seed_and_log + + +T = TypeVar("T", int, float) + +# Lightly adapted from Stack Overflow #10823877 +def _flatten(iterable: List[T]) -> Generator[T, None, None]: + iterator, sentinel, stack = iter(iterable), object(), [] + while True: + value = next(iterator, sentinel) + if value is sentinel: + if not stack: + break + iterator = stack.pop() + else: + try: + new_iterator = iter(value) + except TypeError: + yield value + else: + stack.append(iterator) + iterator = new_iterator + + +def _to_tensor(iterator: List[T], device_id: int, dtype: torch.dtype) -> torch.Tensor: + return torch.tensor(list(_flatten(iterator)), dtype=dtype).cuda(device_id) + + +def _generate_sparse_features_batch( + keys: List[str], splits: List[int], B: int, is_weighted: bool = False +) -> Tuple[List[KeyedJaggedTensor], List[KeyedJaggedTensor]]: + world_size = len(splits) + offsets = [0] + list(itertools.accumulate(splits)) + values = {} + lengths = {} + weights = {} if is_weighted else None + + for key in keys: + lengths[key] = [ + [random.randint(0, 10) for _ in range(B)] for i in range(world_size) + ] + values[key] = [ + [random.randint(0, 1000) for _ in range(sum(lengths[key][i]))] + for i in range(world_size) + ] + + if weights: + weights[key] = [ + [random.random() for _ in range(sum(lengths[key][i]))] + for i in range(world_size) + ] + + in_jagged: List[KeyedJaggedTensor] = [] + out_jagged: List[KeyedJaggedTensor] = [] + for i in range(world_size): + in_jagged.append( + KeyedJaggedTensor.from_lengths_sync( + keys=keys, + lengths=_to_tensor([lengths[key][i] for key in keys], i, torch.int), + values=_to_tensor([values[key][i] for key in keys], i, torch.int), + weights=_to_tensor([weights[key][i] for key in keys], i, torch.float) + if weights + else None, + ) + ) + key_index = [] + out_keys = keys[offsets[i] : offsets[i + 1]] + for key in out_keys: + for j in range(world_size): + key_index.append((key, j)) + out_jagged.append( + KeyedJaggedTensor.from_lengths_sync( + keys=out_keys, + lengths=_to_tensor( + [lengths[key][j] for key, j in key_index], + i, + torch.int, + ), + values=_to_tensor( + [values[key][j] for key, j in key_index], + i, + torch.int, + ), + weights=_to_tensor( + [weights[key][j] for key, j in key_index], + i, + torch.float, + ) + if weights + else None, + ) + ) + return in_jagged, out_jagged + + +def _generate_pooled_embedding_batch( + keys: List[str], dims: List[int], splits: List[int], B: int +) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + world_size = len(splits) + offsets = [0] + list(itertools.accumulate(splits)) + local_emb = {} + + for key, dim in zip(keys, dims): + local_emb[key] = [ + [random.random() for _ in range(dim)] for _ in range(B * world_size) + ] + + in_tensor: List[torch.Tensor] = [] + out_tensor: List[torch.Tensor] = [] + for i in range(world_size): + in_keys = keys[offsets[i] : offsets[i + 1]] + in_tensor.append( + _to_tensor( + [local_emb[key][b] for b in range(B * world_size) for key in in_keys], + i, + torch.float, + ).view(B * world_size, -1) + if in_keys + else torch.empty(B * world_size, 0, dtype=torch.float).cuda(i) + ) + out_tensor.append( + _to_tensor( + [local_emb[key][b] for b in range(B * i, B * (i + 1)) for key in keys], + i, + torch.float, + ).view(B, -1) + ) + + return in_tensor, out_tensor + + +class DistDataTestCase(abc.ABC, unittest.TestCase): + @seed_and_log + def setUp(self) -> None: + torch.use_deterministic_algorithms(True) + os.environ["MASTER_ADDR"] = str("localhost") + os.environ["MASTER_PORT"] = str(get_free_port()) + os.environ["GLOO_DEVICE_TRANSPORT"] = "TCP" + self.WORLD_SIZE = 2 + + def _run_multi_process_test( + self, + _input: Union[List[KeyedJaggedTensor], List[torch.Tensor]], + output: Union[List[KeyedJaggedTensor], List[torch.Tensor]], + **kwargs: Any, + ) -> None: + processes = [] + ctx = multiprocessing.get_context("spawn") + for rank in range(self.WORLD_SIZE): + p = ctx.Process( + target=self._run_test_dist, + args=( + rank, + self.WORLD_SIZE, + _input[rank], + output[rank], + ), + kwargs=kwargs, + ) + p.start() + processes.append(p) + + for p in processes: + p.join() + self.assertEqual(0, p.exitcode) + + @classmethod + @abc.abstractmethod + def _run_test_dist(cls) -> None: + pass + + +class KJTAllToAllTest(DistDataTestCase): + @classmethod + def _validate( + cls, + actual_output_awaitable: Union[KJTAllToAllAwaitable, KeyedJaggedTensor], + expected_output_awaitable: Union[KJTAllToAllAwaitable, KeyedJaggedTensor], + ) -> None: + actual_output = ( + actual_output_awaitable + if isinstance(actual_output_awaitable, KeyedJaggedTensor) + else actual_output_awaitable.wait() + ) + expected_output = ( + expected_output_awaitable + if isinstance(expected_output_awaitable, KeyedJaggedTensor) + else expected_output_awaitable.wait() + ) + assert_array_equal( + actual_output.values().cpu(), + expected_output.values().cpu(), + ) + assert_array_equal( + actual_output.weights().cpu() + if actual_output.weights_or_none() is not None + else [], + expected_output.weights().cpu() + if expected_output.weights_or_none() is not None + else [], + ) + assert_array_equal( + actual_output.lengths().cpu(), + expected_output.lengths().cpu(), + ) + assert_array_equal( + actual_output.keys(), + expected_output.keys(), + ) + + @classmethod + def _run_test_dist( + cls, + rank: int, + world_size: int, + _input: KeyedJaggedTensor, + output: KeyedJaggedTensor, + backend: str, + splits: List[int], + ) -> None: + dist.init_process_group(rank=rank, world_size=world_size, backend=backend) + device = torch.device(f"cuda:{rank}") + if backend == "gloo": + device = torch.device("cpu") + _input = _input.to(device=device) + output = output.to(device=device) + pg = dist.group.WORLD + a2a = KJTAllToAll(pg=pg, splits=splits, device=device) + cls._validate(a2a(_input), output) + + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) + # pyre-fixme[56] + @given( + backend=st.sampled_from(["gloo", "nccl"]), + B=st.integers(min_value=1, max_value=3), + features=st.integers(min_value=2, max_value=5), + is_weighted=st.booleans(), + ) + @settings(max_examples=10, deadline=None) + def test_features( + self, backend: str, B: int, features: int, is_weighted: bool + ) -> None: + keys = [f"F{feature}" for feature in range(features)] + rank0_split = random.randint(0, features) + splits = [rank0_split, features - rank0_split] + _input, output = _generate_sparse_features_batch( + keys=keys, splits=splits, B=B, is_weighted=is_weighted + ) + + self._run_multi_process_test( + _input=_input, + output=output, + backend=backend, + splits=splits, + ) + + +class PooledEmbeddingsAllToAllTest(DistDataTestCase): + @classmethod + def _run_test_dist( + cls, + rank: int, + world_size: int, + _input: torch.Tensor, + output: torch.Tensor, + backend: str, + dim_sum_per_rank: List[int], + ) -> None: + dist.init_process_group(rank=rank, world_size=world_size, backend=backend) + pg = dist.group.WORLD + if backend == "gloo": + device = torch.device("cpu") + else: + device = torch.device(f"cuda:{rank}") + _input = _input.to(device=device) + output = output.to(device=device) + a2a = PooledEmbeddingsAllToAll( + pg=pg, + dim_sum_per_rank=dim_sum_per_rank, + device=device, + ) + _input.requires_grad = True + res = a2a(_input).wait() + res.backward(res) + assert_array_equal( + res.cpu().detach(), + output.cpu().detach(), + ) + assert_array_equal( + _input.cpu().detach().div_(world_size), + _input.grad.cpu().detach(), + ) + + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) + # pyre-fixme[56] + @given( + backend=st.sampled_from(["gloo", "nccl"]), + B=st.integers(min_value=2, max_value=5), + features=st.integers(min_value=2, max_value=5), + is_reversed=st.booleans(), + ) + @settings(max_examples=20, deadline=None) + def test_pooled_embeddings( + self, backend: str, B: int, features: int, is_reversed: bool + ) -> None: + keys = [f"F{feature}" for feature in range(features)] + dims = random.sample([8, 16, 32] * features, features) + rank0_split = random.randint(1, features - 1) + splits = [rank0_split, features - rank0_split] + if is_reversed: + splits.reverse() + dim_sum_per_rank = [sum(dims[: splits[0]]), sum(dims[splits[0] :])] + + _input, output = _generate_pooled_embedding_batch( + keys=keys, + dims=dims, + splits=splits, + B=B, + ) + + self._run_multi_process_test( + _input=_input, + output=output, + backend=backend, + dim_sum_per_rank=dim_sum_per_rank, + ) + + +class PooledEmbeddingsReduceScatterTest(DistDataTestCase): + @classmethod + def _validate( + cls, + actual_output: torch.Tensor, + expected_output: torch.Tensor, + input: torch.Tensor, + world_size: int, + ) -> None: + assert_array_equal(actual_output.cpu().detach(), expected_output.cpu().detach()) + assert_array_equal( + input.grad.cpu().detach(), + torch.ones(input.size()).div_(world_size), + ) + + @classmethod + def _run_test_dist( + cls, + rank: int, + world_size: int, + input: torch.Tensor, + expected_output: torch.Tensor, + ) -> None: + dist.init_process_group(rank=rank, world_size=2, backend="nccl") + pg = dist.group.WORLD + input = input.cuda(rank) + input.requires_grad = True + rs = PooledEmbeddingsReduceScatter(pg).cuda(rank) + actual_output = rs(input).wait() + s = torch.sum(actual_output) + s.backward() + cls._validate(actual_output, expected_output, input, world_size) + + # pyre-fixme[56] + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) + @settings(deadline=60000) + def test_pooled_embedding_reduce_scatter(self) -> None: + embeddding_dim = 10 + batch_size = 2 + embeddings = torch.rand((batch_size * 2, embeddding_dim)) + embeddings_by_rank = list(torch.chunk(embeddings, batch_size, dim=0)) + expect_results = torch.chunk( + torch.stack(embeddings_by_rank, dim=0).sum(dim=0), + 2, + dim=0, + ) + self._run_multi_process_test( + embeddings_by_rank, + expect_results, + ) diff --git a/torchrec/distributed/tests/test_fused_optim.py b/torchrec/distributed/tests/test_fused_optim.py new file mode 100644 index 000000000..0eda75687 --- /dev/null +++ b/torchrec/distributed/tests/test_fused_optim.py @@ -0,0 +1,297 @@ +#!/usr/bin/env python3 + +import os +import unittest +from typing import List, Optional, Dict, cast, Union + +import hypothesis.strategies as st +import torch +import torch.distributed as dist +import torch.nn as nn +from fbgemm_gpu.split_embedding_configs import EmbOptimType +from hypothesis import Verbosity, given, settings +from torchrec.distributed.embedding_types import EmbeddingComputeKernel +from torchrec.distributed.embedding_types import EmbeddingTableConfig +from torchrec.distributed.model_parallel import ( + DistributedModelParallel, +) +from torchrec.distributed.planner.types import ParameterHints +from torchrec.distributed.tests.test_model import ( + TestSparseNN, + TestSparseNNBase, + TestEBCSharder, + TestEBSharder, +) +from torchrec.distributed.tests.test_model_parallel_base import ModelParallelTestBase +from torchrec.distributed.types import ( + ModuleSharder, + ShardingType, + ShardingEnv, +) +from torchrec.modules.embedding_configs import BaseEmbeddingConfig +from torchrec.modules.embedding_configs import EmbeddingBagConfig +from torchrec.tests.utils import ( + skip_if_asan_class, + init_distributed_single_host, + seed_and_log, +) + + +def create_test_sharder( + sharding_type: str, kernel_type: str, optim: EmbOptimType +) -> Union[TestEBSharder, TestEBCSharder]: + fused_params = {} + fused_params["optimizer"] = optim + if optim == EmbOptimType.EXACT_SGD: + fused_params["learning_rate"] = 0.1 + else: + fused_params["learning_rate"] = 0.01 + return TestEBCSharder(sharding_type, kernel_type, fused_params) + + +@skip_if_asan_class +class ModelParallelTest(ModelParallelTestBase): + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) + # pyre-fixme[56] + @given( + sharding_type=st.sampled_from( + [ + ShardingType.ROW_WISE.value, + ShardingType.TABLE_WISE.value, + ] + ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.BATCHED_FUSED.value, + ] + ), + optim_type=st.sampled_from( + [ + EmbOptimType.EXACT_SGD, + EmbOptimType.EXACT_ROWWISE_ADAGRAD, + ] + ), + ) + @settings(verbosity=Verbosity.verbose, max_examples=10, deadline=None) + def test_sharding_nccl_rw( + self, + sharding_type: str, + kernel_type: str, + optim_type: EmbOptimType, + ) -> None: + self._test_sharding( + # pyre-ignore[6] + sharders=[ + create_test_sharder(sharding_type, kernel_type, optim_type), + ], + backend="nccl", + optim=optim_type, + ) + + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) + # pyre-fixme[56] + @given( + sharding_type=st.sampled_from( + [ + ShardingType.TABLE_WISE.value, + ] + ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.BATCHED_FUSED.value, + ] + ), + optim_type=st.sampled_from( + [ + EmbOptimType.EXACT_SGD, + EmbOptimType.EXACT_ROWWISE_ADAGRAD, + ] + ), + ) + @settings(verbosity=Verbosity.verbose, max_examples=10, deadline=None) + def test_sharding_nccl_tw( + self, + sharding_type: str, + kernel_type: str, + optim_type: EmbOptimType, + ) -> None: + self._test_sharding( + # pyre-ignore[6] + sharders=[ + create_test_sharder(sharding_type, kernel_type, optim_type), + ], + backend="nccl", + optim=optim_type, + ) + + @seed_and_log + def setUp(self) -> None: + super().setUp() + torch.use_deterministic_algorithms(True) + if torch.cuda.is_available(): + torch.backends.cudnn.allow_tf32 = False + torch.backends.cuda.matmul.allow_tf32 = False + + num_features = 4 + num_weighted_features = 2 + + self.tables = [ + EmbeddingBagConfig( + num_embeddings=(i + 1) * 10, + embedding_dim=(i + 2) * 4, + name="table_" + str(i), + feature_names=["feature_" + str(i)], + ) + for i in range(num_features) + ] + self.weighted_tables = [ + EmbeddingBagConfig( + num_embeddings=(i + 1) * 10, + embedding_dim=(i + 2) * 4, + name="weighted_table_" + str(i), + feature_names=["weighted_feature_" + str(i)], + ) + for i in range(num_weighted_features) + ] + + self.embedding_groups = { + "group_0": ["feature_" + str(i) for i in range(num_features)] + } + + def _test_sharding( + self, + sharders: List[ModuleSharder[nn.Module]], + optim: EmbOptimType, + backend: str = "gloo", + world_size: int = 2, + local_size: Optional[int] = None, + hints: Optional[Dict[str, ParameterHints]] = None, + ) -> None: + self._run_multi_process_test( + # pyre-ignore [6] + callable=self._test_optim_single_rank, + world_size=world_size, + local_size=local_size, + model_class=cast(TestSparseNNBase, TestSparseNN), + tables=self.tables, + weighted_tables=self.weighted_tables, + embedding_groups=self.embedding_groups, + sharders=sharders, + backend=backend, + optim=optim, + hints=hints, + ) + + @classmethod + def _test_optim_single_rank( + cls, + rank: int, + world_size: int, + model_class: TestSparseNNBase, + embedding_groups: Dict[str, List[str]], + tables: List[EmbeddingTableConfig], + sharders: List[ModuleSharder[nn.Module]], + backend: str, + optim: EmbOptimType, + weighted_tables: Optional[List[EmbeddingTableConfig]] = None, + hints: Optional[Dict[str, ParameterHints]] = None, + local_size: Optional[int] = None, + ) -> None: + # Override local_size after pg construction because unit test device count + # is larger than local_size setup. This can be problematic for twrw because + # we have ShardedTensor placement check. + os.environ["LOCAL_WORLD_SIZE"] = str(world_size) + if backend == "nccl": + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + else: + device = torch.device("cpu") + pg = init_distributed_single_host( + rank=rank, + world_size=world_size, + backend=backend, + local_size=local_size, + ) + if rank == 0: + global_pg = dist.new_group(ranks=[0], backend=backend) + dist.new_group(ranks=[1], backend=backend) + else: + dist.new_group(ranks=[0], backend=backend) + global_pg = dist.new_group(ranks=[1], backend=backend) + + # Generate model & inputs. + (global_model, inputs) = cls._gen_model_and_input( + model_class=model_class, + tables=tables, + weighted_tables=weighted_tables, + embedding_groups=embedding_groups, + world_size=world_size, + num_float_features=16, + ) + global_model = global_model.cuda(0) + global_model = DistributedModelParallel( + global_model, + env=ShardingEnv.from_process_group(global_pg), + sharders=sharders, + device=torch.device("cuda:0"), + init_data_parallel=False, + ) + global_input = inputs[0][0].to(torch.device("cuda:0")) + local_input = inputs[0][1][rank].to(device) + + # Run single step of unsharded model to populate optimizer states. + global_opt = torch.optim.SGD(global_model.parameters(), lr=0.1) + cls._gen_full_pred_after_one_step(global_model, global_opt, global_input) + + # Shard model. + local_model = model_class( + tables=cast(List[BaseEmbeddingConfig], tables), + weighted_tables=weighted_tables, + embedding_groups=embedding_groups, + dense_device=device, + sparse_device=torch.device("meta"), + num_float_features=16, + ) + local_model = DistributedModelParallel( + local_model, + env=ShardingEnv.from_process_group(pg), + sharders=sharders, + device=device, + ) + local_opt = torch.optim.SGD(local_model.parameters(), lr=0.1) + + # Load model & optimizer states from the global model. + cls._copy_state_dict(local_model.state_dict(), global_model.state_dict()) + for param_name, local_state in local_model.fused_optimizer.state_dict()[ + "state" + ].items(): + global_state = global_model.fused_optimizer.state_dict()["state"][ + param_name + ] + cls._copy_state_dict(local_state, global_state) + + # Run a single training step of the sharded model. + local_pred = cls._gen_full_pred_after_one_step( + local_model, local_opt, local_input + ) + all_local_pred = [] + for _ in range(world_size): + all_local_pred.append(torch.empty_like(local_pred)) + dist.all_gather(all_local_pred, local_pred, group=pg) + + # Run second training step of the unsharded model. + global_opt = torch.optim.SGD(global_model.parameters(), lr=0.1) + global_pred = cls._gen_full_pred_after_one_step( + global_model, global_opt, global_input + ) + + # Compare predictions of sharded vs unsharded models. + torch.testing.assert_allclose( + global_pred.cpu(), torch.cat(all_local_pred).cpu() + ) diff --git a/torchrec/distributed/tests/test_lazy_awaitable.py b/torchrec/distributed/tests/test_lazy_awaitable.py new file mode 100644 index 000000000..44ef133af --- /dev/null +++ b/torchrec/distributed/tests/test_lazy_awaitable.py @@ -0,0 +1,248 @@ +#!/usr/bin/env python3 + +import unittest +from typing import Dict + +import torch +import torch.fx +from torchrec.distributed.types import LazyAwaitable + + +class NeedWait(LazyAwaitable[torch.Tensor]): + def __init__(self, actual_value: torch.Tensor) -> None: + super().__init__() + self.actual_value = actual_value + + def wait(self) -> torch.Tensor: + self.actual_value += 8 + return self.actual_value + + +class NeedWaitNoInit(LazyAwaitable[torch.Tensor]): + def __init__(self, actual_value: torch.Tensor) -> None: + # ill-formed type, no super.__init__() here + # should error out when using it + self.actual_value = actual_value + + def wait(self) -> torch.Tensor: + self.actual_value += 8 + return self.actual_value + + +class NeedWaitDict(LazyAwaitable[Dict[str, torch.Tensor]]): + def __init__(self, actual_value: Dict[str, torch.Tensor], key: str) -> None: + super().__init__() + self.actual_value = actual_value + self.key = key + + def wait(self) -> Dict[str, torch.Tensor]: + self.actual_value[self.key] *= 3 + return self.actual_value + + +class AsyncModule(torch.nn.Module): + """ + Dummy async module + + Constructor Args: + + + Call Args: + x: torch.Tensor + + Returns: + LazyAwaitable[torch.Tensor] + + Example: + >>> AsyncModule() + """ + + def forward(self, x: torch.Tensor) -> LazyAwaitable[torch.Tensor]: + return NeedWait(x) + + +class TestLazyAwaitable(unittest.TestCase): + def test_lazy_torch_function(self) -> None: + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.async_compute = AsyncModule() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + async_fut = self.async_compute(x) + y = x * 3 + 5 + return torch.add(async_fut, y) + + # ensure computation of y happens earlier than wait() + m = Model() + ref_res = m(torch.ones(3, 4)) + self.assertTrue(torch.equal(ref_res, 17 * torch.ones(3, 4))) + + # ensure fx tracing works + gm = torch.fx.symbolic_trace(m) + traced_res = gm(torch.ones(3, 4)) + self.assertTrue(torch.equal(traced_res, ref_res)) + + def test_lazy_getattr(self) -> None: + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.async_compute = AsyncModule() + + def forward(self, x: torch.Tensor) -> int: + async_fut = self.async_compute(x) + y = x * 3 + 5 + return async_fut.numel() + y.numel() + + m = Model() + ref_res = m(torch.ones(3, 4)) + self.assertEqual(ref_res, 24) + + # ensure fx tracing works + gm = torch.fx.symbolic_trace(m) + traced_res = gm(torch.ones(3, 4)) + self.assertEqual(traced_res, ref_res) + + def test_lazy_getattr_explicit_wait(self) -> None: + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.async_compute = AsyncModule() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + async_fut = self.async_compute(x) + y = x * 3 + 5 + return async_fut.wait() + y + + m = Model() + ref_res = m(torch.ones(3, 4)) + self.assertTrue(torch.equal(ref_res, 17 * torch.ones(3, 4))) + + # ensure fx tracing works + gm = torch.fx.symbolic_trace(m) + traced_res = gm(torch.ones(3, 4)) + self.assertTrue(torch.equal(traced_res, ref_res)) + + def test_lazy_awaitable_init_error(self) -> None: + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.lazy_awaitable = NeedWaitNoInit(torch.ones(2, 3)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x + self.lazy_awaitable + + m = Model() + + with self.assertRaisesRegex(RuntimeError, "has not been initialized properly"): + m(torch.ones(2, 3)) + + def test_lazy_wait_and_result(self) -> None: + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.async_compute = AsyncModule() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + async_fut = self.async_compute(x) + y = x * 3 + 5 + numel = async_fut.numel() + return numel + async_fut._result + y + + m = Model() + ref_res = m(torch.ones(3, 4)) + self.assertTrue(torch.equal(ref_res, 29 * torch.ones(3, 4))) + + # ensure fx tracing works + gm = torch.fx.symbolic_trace(m) + traced_res = gm(torch.ones(3, 4)) + self.assertTrue(torch.equal(traced_res, ref_res)) + + def test_lazy_get_item(self) -> None: + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.async_compute = AsyncModule() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + async_fut = self.async_compute(x) + return async_fut[1:3] + + m = Model() + ref_res = m(torch.ones(3, 4)) + self.assertTrue(torch.equal(ref_res, 9 * torch.ones(2, 4))) + + # ensure fx tracing works + gm = torch.fx.symbolic_trace(m) + traced_res = gm(torch.ones(3, 4)) + self.assertTrue(torch.equal(traced_res, ref_res)) + + def test_lazy_magic_methods(self) -> None: + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.async_compute1 = AsyncModule() + self.async_compute2 = AsyncModule() + + def forward(self, x: torch.Tensor) -> int: + async_fut1 = self.async_compute1(x) + async_fut2 = self.async_compute2(x) + y = x * 3 + 5 + return 2 * async_fut1 + y - async_fut2 + + m = Model() + ref_res = m(torch.ones(3, 4)) + self.assertTrue(torch.equal(ref_res, 9 * torch.ones(3, 4))) + + gm = torch.fx.symbolic_trace(m) + traced_res = gm(torch.ones(3, 4)) + self.assertTrue(torch.equal(traced_res, ref_res)) + + def test_lazy_wait_dict(self) -> None: + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.dict = {"t1": torch.ones(2, 3)} + self.wait_dict = NeedWaitDict(self.dict, "t1") + + def forward(self) -> torch.Tensor: + return self.wait_dict["t1"] + 2 + + m = Model() + ref_res = m() + self.assertTrue(torch.equal(ref_res, 5 * torch.ones(2, 3))) + + # ensure fx tracing works + gm = torch.fx.symbolic_trace(m) + traced_res = gm() + self.assertTrue(torch.equal(traced_res, ref_res)) + + def test_lazy_awaitable_serde(self) -> None: + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.async_compute = AsyncModule() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + async_fut = self.async_compute(x) + y = x * 3 + 5 + return torch.add(async_fut, y) + + m = Model() + gm = torch.fx.symbolic_trace(m) + + import pickle + import tempfile + + tempFile = None + with tempfile.NamedTemporaryFile(delete=False) as f: + pickle.dump(gm, f) + tempFile = f + + with open(tempFile.name, "rb") as f: + loaded = pickle.load(f) + + ref_res = loaded(torch.ones(3, 4)) + self.assertTrue(torch.equal(ref_res, 17 * torch.ones(3, 4))) + + tempFile.close() diff --git a/torchrec/distributed/tests/test_model.py b/torchrec/distributed/tests/test_model.py new file mode 100644 index 000000000..0da6122ab --- /dev/null +++ b/torchrec/distributed/tests/test_model.py @@ -0,0 +1,473 @@ +#!/usr/bin/env python3 + +from dataclasses import dataclass +from typing import List, cast, Optional, Tuple, Any, Dict, Union + +import torch +import torch.nn as nn +from torchrec.distributed.embedding_types import EmbeddingTableConfig +from torchrec.distributed.embeddingbag import ( + EmbeddingBagSharder, + EmbeddingBagCollectionSharder, +) +from torchrec.modules.embedding_configs import EmbeddingBagConfig, BaseEmbeddingConfig +from torchrec.modules.embedding_modules import EmbeddingBagCollection +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor +from torchrec.types import Pipelineable + + +@dataclass +class ModelInput(Pipelineable): + float_features: torch.Tensor + idlist_features: KeyedJaggedTensor + idscore_features: KeyedJaggedTensor + label: torch.Tensor + + @staticmethod + def generate( + batch_size: int, + world_size: int, + num_float_features: int, + tables: Union[List[EmbeddingTableConfig], List[EmbeddingBagConfig]], + weighted_tables: Union[List[EmbeddingTableConfig], List[EmbeddingBagConfig]], + pooling_avg: int = 10, + ) -> Tuple["ModelInput", List["ModelInput"]]: + """ + Returns a global (single-rank training) batch + and a list of local (multi-rank training) batches of world_size. + """ + idlist_features = [ + feature for table in tables for feature in table.feature_names + ] + idscore_features = [ + feature for table in weighted_tables for feature in table.feature_names + ] + + idlist_ind_ranges = [table.num_embeddings for table in tables] + idscore_ind_ranges = [table.num_embeddings for table in weighted_tables] + + # Generate global batch. + global_idlist_lengths = [] + global_idlist_indices = [] + global_idscore_lengths = [] + global_idscore_indices = [] + global_idscore_weights = [] + + for ind_range in idlist_ind_ranges: + lengths = torch.abs( + torch.randn(batch_size * world_size) + pooling_avg + ).int() + num_indices = cast(int, torch.sum(lengths).item()) + indices = torch.randint(0, ind_range, (num_indices,)) + global_idlist_lengths.append(lengths) + global_idlist_indices.append(indices) + global_idlist_kjt = KeyedJaggedTensor( + keys=idlist_features, + values=torch.cat(global_idlist_indices), + lengths=torch.cat(global_idlist_lengths), + ) + + for ind_range in idscore_ind_ranges: + lengths = torch.abs( + torch.randn(batch_size * world_size) + pooling_avg + ).int() + num_indices = cast(int, torch.sum(lengths).item()) + indices = torch.randint(0, ind_range, (num_indices,)) + weights = torch.rand((num_indices,)) + global_idscore_lengths.append(lengths) + global_idscore_indices.append(indices) + global_idscore_weights.append(weights) + global_idscore_kjt = ( + KeyedJaggedTensor( + keys=idscore_features, + values=torch.cat(global_idscore_indices), + lengths=torch.cat(global_idscore_lengths), + weights=torch.cat(global_idscore_weights), + ) + if global_idscore_indices + else None + ) + + global_float = torch.rand((batch_size * world_size, num_float_features)) + global_label = torch.rand(batch_size * world_size) + + # Split global batch into local batches. + local_inputs = [] + for r in range(world_size): + local_idlist_lengths = [] + local_idlist_indices = [] + local_idscore_lengths = [] + local_idscore_indices = [] + local_idscore_weights = [] + + for lengths, indices in zip(global_idlist_lengths, global_idlist_indices): + local_idlist_lengths.append( + lengths[r * batch_size : (r + 1) * batch_size] + ) + lengths_cumsum = [0] + lengths.view(world_size, -1).sum(dim=1).cumsum( + dim=0 + ).tolist() + local_idlist_indices.append( + indices[lengths_cumsum[r] : lengths_cumsum[r + 1]] + ) + + for lengths, indices, weights in zip( + global_idscore_lengths, global_idscore_indices, global_idscore_weights + ): + local_idscore_lengths.append( + lengths[r * batch_size : (r + 1) * batch_size] + ) + lengths_cumsum = [0] + lengths.view(world_size, -1).sum(dim=1).cumsum( + dim=0 + ).tolist() + local_idscore_indices.append( + indices[lengths_cumsum[r] : lengths_cumsum[r + 1]] + ) + local_idscore_weights.append( + weights[lengths_cumsum[r] : lengths_cumsum[r + 1]] + ) + + local_idlist_kjt = KeyedJaggedTensor( + keys=idlist_features, + values=torch.cat(local_idlist_indices), + lengths=torch.cat(local_idlist_lengths), + ) + + local_idscore_kjt = ( + KeyedJaggedTensor( + keys=idscore_features, + values=torch.cat(local_idscore_indices), + lengths=torch.cat(local_idscore_lengths), + weights=torch.cat(local_idscore_weights), + ) + if local_idscore_indices + else None + ) + + local_input = ModelInput( + float_features=global_float[r * batch_size : (r + 1) * batch_size], + idlist_features=local_idlist_kjt, + idscore_features=local_idscore_kjt, + label=global_label[r * batch_size : (r + 1) * batch_size], + ) + local_inputs.append(local_input) + + return ( + ModelInput( + float_features=global_float, + idlist_features=global_idlist_kjt, + idscore_features=global_idscore_kjt, + label=global_label, + ), + local_inputs, + ) + + def to(self, device: torch.device, non_blocking: bool = False) -> "ModelInput": + return ModelInput( + float_features=self.float_features.to( + device=device, non_blocking=non_blocking + ), + idlist_features=self.idlist_features.to( + device=device, non_blocking=non_blocking + ), + # pyre-ignore [6] + idscore_features=self.idscore_features.to( + device=device, non_blocking=non_blocking + ) + if self.idscore_features is not None + else None, + label=self.label.to(device=device, non_blocking=non_blocking), + ) + + def record_stream(self, stream: torch.cuda.streams.Stream) -> None: + self.float_features.record_stream(stream) + self.idlist_features.record_stream(stream) + if self.idscore_features is not None: + self.idscore_features.record_stream(stream) + self.label.record_stream(stream) + + +class TestDenseArch(nn.Module): + """ + Basic nn.Module for testing + + Constructor Args: + device + + Call Args: + dense_input: torch.Tensor + + Returns: + KeyedTensor + + Example: + >>> TestDenseArch() + """ + + def __init__( + self, + num_float_features: int = 10, + device: Optional[torch.device] = None, + ) -> None: + super().__init__() + if device is None: + device = torch.device("cpu") + self.linear: nn.modules.Linear = nn.Linear( + in_features=num_float_features, out_features=8, device=device + ) + + def forward(self, dense_input: torch.Tensor) -> torch.Tensor: + return self.linear(dense_input) + + +class TestOverArch(nn.Module): + """ + Basic nn.Module for testing + + Constructor Args: + device + + Call Args: + dense: torch.Tensor, + sparse: KeyedTensor, + + Returns: + torch.Tensor + + Example: + >>> TestOverArch() + """ + + def __init__( + self, + tables: List[EmbeddingBagConfig], + weighted_tables: List[EmbeddingBagConfig], + device: Optional[torch.device] = None, + ) -> None: + super().__init__() + if device is None: + device = torch.device("cpu") + self._features: List[str] = [ + feature for table in tables for feature in table.feature_names + ] + self._weighted_features: List[str] = [ + feature for table in weighted_tables for feature in table.feature_names + ] + in_features = ( + 8 + + sum([table.embedding_dim * len(table.feature_names) for table in tables]) + + sum( + [ + table.embedding_dim * len(table.feature_names) + for table in weighted_tables + ] + ) + ) + self.linear: nn.modules.Linear = nn.Linear( + in_features=in_features, out_features=16, device=device + ) + + def forward( + self, + dense: torch.Tensor, + sparse: KeyedTensor, + ) -> torch.Tensor: + ret_list = [] + ret_list.append(dense) + for feature_name in self._features: + ret_list.append(sparse[feature_name]) + for feature_name in self._weighted_features: + ret_list.append(sparse[feature_name]) + return self.linear(torch.cat(ret_list, dim=1)) + + +class TestSparseArch(nn.Module): + """ + Basic nn.Module for testing + + Constructor Args: + tables + device + + Call Args: + features + + Returns: + KeyedTensor + """ + + def __init__( + self, + tables: List[EmbeddingBagConfig], + weighted_tables: List[EmbeddingBagConfig], + device: Optional[torch.device] = None, + ) -> None: + super().__init__() + if device is None: + device = torch.device("cpu") + self.ebc: EmbeddingBagCollection = EmbeddingBagCollection( + tables=tables, + device=device, + ) + self.weighted_ebc: EmbeddingBagCollection = EmbeddingBagCollection( + tables=weighted_tables, + is_weighted=True, + device=device, + ) + + def forward( + self, features: KeyedJaggedTensor, weighted_features: KeyedJaggedTensor + ) -> KeyedTensor: + ebc = self.ebc(features) + w_ebc = self.weighted_ebc(weighted_features) + return KeyedTensor( + keys=ebc.keys() + w_ebc.keys(), + length_per_key=ebc.length_per_key() + w_ebc.length_per_key(), + values=torch.cat([ebc.values(), w_ebc.values()], dim=1), + ) + + +class TestSparseNNBase(nn.Module): + """ + Base class for a SparseNN model. + + Constructor Args: + tables: List[BaseEmbeddingConfig], + weighted_tables: Optional[List[BaseEmbeddingConfig]], + embedding_groups: Optional[Dict[str, List[str]]], + dense_device: Optional[torch.device], + sparse_device: Optional[torch.device], + """ + + def __init__( + self, + tables: List[BaseEmbeddingConfig], + weighted_tables: Optional[List[BaseEmbeddingConfig]] = None, + embedding_groups: Optional[Dict[str, List[str]]] = None, + dense_device: Optional[torch.device] = None, + sparse_device: Optional[torch.device] = None, + ) -> None: + super().__init__() + if dense_device is None: + dense_device = torch.device("cpu") + if sparse_device is None: + sparse_device = torch.device("cpu") + + +class TestSparseNN(TestSparseNNBase): + """ + Simple version of a SparseNN model. + + Constructor Args: + tables: List[EmbeddingBagConfig], + weighted_tables: Optional[List[EmbeddingBagConfig]], + embedding_groups: Optional[Dict[str, List[str]]], + dense_device: Optional[torch.device], + sparse_device: Optional[torch.device], + + Call Args: + input: ModelInput, + + Returns: + torch.Tensor + + Example: + >>> TestSparseNN() + """ + + def __init__( + self, + tables: List[EmbeddingBagConfig], + num_float_features: int = 10, + weighted_tables: Optional[List[EmbeddingBagConfig]] = None, + embedding_groups: Optional[Dict[str, List[str]]] = None, + dense_device: Optional[torch.device] = None, + sparse_device: Optional[torch.device] = None, + ) -> None: + super().__init__( + tables=cast(List[BaseEmbeddingConfig], tables), + weighted_tables=cast(Optional[List[BaseEmbeddingConfig]], weighted_tables), + embedding_groups=embedding_groups, + dense_device=dense_device, + sparse_device=sparse_device, + ) + if weighted_tables is None: + weighted_tables = [] + + self.dense = TestDenseArch(num_float_features, dense_device) + self.sparse = TestSparseArch(tables, weighted_tables, sparse_device) + self.over = TestOverArch(tables, weighted_tables, dense_device) + + def forward( + self, + input: ModelInput, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + dense_r = self.dense(input.float_features) + sparse_r = self.sparse(input.idlist_features, input.idscore_features) + over_r = self.over(dense_r, sparse_r) + pred = torch.sigmoid(torch.mean(over_r, dim=1)) + if self.training: + return ( + torch.nn.functional.binary_cross_entropy_with_logits(pred, input.label), + pred, + ) + else: + return pred + + +class TestEBCSharder(EmbeddingBagCollectionSharder[EmbeddingBagCollection]): + def __init__( + self, sharding_type: str, kernel_type: str, fused_params: Dict[str, Any] = {} + ) -> None: + self._sharding_type = sharding_type + self._kernel_type = kernel_type + self._fused_params = fused_params + + """ + Restricts sharding to single type only. + """ + + def sharding_types(self, compute_device_type: str) -> List[str]: + return [self._sharding_type] + + """ + Restricts to single impl. + """ + + def compute_kernels( + self, sharding_type: str, compute_device_type: str + ) -> List[str]: + return [self._kernel_type] + + @property + def fused_params(self) -> Optional[Dict[str, Any]]: + return self._fused_params + + +class TestEBSharder(EmbeddingBagSharder[nn.EmbeddingBag]): + def __init__( + self, sharding_type: str, kernel_type: str, fused_params: Dict[str, Any] + ) -> None: + self._sharding_type = sharding_type + self._kernel_type = kernel_type + self._fused_params = fused_params + + """ + Restricts sharding to single type only. + """ + + def sharding_types(self, compute_device_type: str) -> List[str]: + return [self._sharding_type] + + """ + Restricts to single impl. + """ + + def compute_kernels( + self, sharding_type: str, compute_device_type: str + ) -> List[str]: + return [self._kernel_type] + + @property + def fused_params(self) -> Optional[Dict[str, Any]]: + return self._fused_params diff --git a/torchrec/distributed/tests/test_model_parallel.py b/torchrec/distributed/tests/test_model_parallel.py new file mode 100644 index 000000000..5f116e808 --- /dev/null +++ b/torchrec/distributed/tests/test_model_parallel.py @@ -0,0 +1,645 @@ +#!/usr/bin/env python3 + +import os +import unittest +from collections import OrderedDict +from enum import Enum +from typing import List, Tuple, Optional, Dict, cast, Union + +import hypothesis.strategies as st +import numpy as np +import torch +import torch.distributed as dist +import torch.nn as nn +from fbgemm_gpu.split_embedding_configs import EmbOptimType +from hypothesis import Verbosity, given, settings +from torchrec.distributed.embedding_types import EmbeddingComputeKernel +from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder +from torchrec.distributed.embeddingbag import EmbeddingBagSharder +from torchrec.distributed.model_parallel import ( + DistributedModelParallel, + default_sharders, +) +from torchrec.distributed.planner import EmbeddingShardingPlanner +from torchrec.distributed.planner.types import ParameterHints +from torchrec.distributed.tests.test_model import ( + TestSparseNN, + TestSparseNNBase, + TestEBCSharder, + TestEBSharder, + ModelInput, +) +from torchrec.distributed.tests.test_model_parallel_base import ModelParallelTestBase +from torchrec.distributed.types import ( + ModuleSharder, + ShardedTensor, + ShardingType, + ShardingEnv, +) +from torchrec.modules.embedding_configs import EmbeddingBagConfig +from torchrec.tests.utils import ( + get_free_port, + skip_if_asan_class, + init_distributed_single_host, + seed_and_log, +) + + +class SharderType(Enum): + EMBEDDING_BAG = "embedding_bag" + EMBEDDING_BAG_COLLECTION = "embedding_bag_collection" + + +def create_test_sharder( + sharder_type: str, sharding_type: str, kernel_type: str +) -> Union[TestEBSharder, TestEBCSharder]: + if sharder_type == SharderType.EMBEDDING_BAG.value: + return TestEBSharder(sharding_type, kernel_type, {"learning_rate": 0.1}) + elif sharder_type == SharderType.EMBEDDING_BAG_COLLECTION.value: + return TestEBCSharder(sharding_type, kernel_type, {"learning_rate": 0.1}) + else: + raise ValueError(f"Sharder not supported {sharder_type}") + + +@skip_if_asan_class +class ModelParallelTest(ModelParallelTestBase): + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) + # pyre-fixme[56] + @given( + sharder_type=st.sampled_from( + [ + SharderType.EMBEDDING_BAG.value, + SharderType.EMBEDDING_BAG_COLLECTION.value, + ] + ), + sharding_type=st.sampled_from( + [ + ShardingType.ROW_WISE.value, + ] + ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.DENSE.value, + EmbeddingComputeKernel.SPARSE.value, + EmbeddingComputeKernel.BATCHED_DENSE.value, + EmbeddingComputeKernel.BATCHED_FUSED.value, + ] + ), + ) + @settings(verbosity=Verbosity.verbose, max_examples=10, deadline=None) + def test_sharding_nccl_rw( + self, + sharder_type: str, + sharding_type: str, + kernel_type: str, + ) -> None: + self._test_sharding( + # pyre-ignore[6] + sharders=[ + create_test_sharder(sharder_type, sharding_type, kernel_type), + ], + backend="nccl", + ) + + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) + # pyre-fixme[56] + @given( + sharder_type=st.sampled_from( + [ + SharderType.EMBEDDING_BAG.value, + SharderType.EMBEDDING_BAG_COLLECTION.value, + ] + ), + sharding_type=st.sampled_from( + [ + ShardingType.DATA_PARALLEL.value, + ] + ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.DENSE.value, + EmbeddingComputeKernel.BATCHED_DENSE.value, + ] + ), + ) + @settings(verbosity=Verbosity.verbose, max_examples=10, deadline=None) + def test_sharding_nccl_dp( + self, sharder_type: str, sharding_type: str, kernel_type: str + ) -> None: + self._test_sharding( + # pyre-ignore[6] + sharders=[ + create_test_sharder(sharder_type, sharding_type, kernel_type), + ], + backend="nccl", + ) + + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) + # pyre-fixme[56] + @given( + sharder_type=st.sampled_from( + [ + # SharderType.EMBEDDING_BAG.value, + SharderType.EMBEDDING_BAG_COLLECTION.value, + ] + ), + sharding_type=st.sampled_from( + [ + ShardingType.TABLE_WISE.value, + ] + ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.DENSE.value, + EmbeddingComputeKernel.SPARSE.value, + EmbeddingComputeKernel.BATCHED_DENSE.value, + EmbeddingComputeKernel.BATCHED_FUSED.value, + ] + ), + ) + @settings(verbosity=Verbosity.verbose, max_examples=10, deadline=None) + def test_sharding_nccl_tw( + self, sharder_type: str, sharding_type: str, kernel_type: str + ) -> None: + self._test_sharding( + # pyre-ignore[6] + sharders=[ + create_test_sharder(sharder_type, sharding_type, kernel_type), + ], + backend="nccl", + ) + + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least four GPUs", + ) + # pyre-fixme[56] + @given( + sharder_type=st.sampled_from( + [ + SharderType.EMBEDDING_BAG.value, + SharderType.EMBEDDING_BAG_COLLECTION.value, + ] + ), + sharding_type=st.sampled_from( + [ + ShardingType.TABLE_ROW_WISE.value, + ] + ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.DENSE.value, + EmbeddingComputeKernel.SPARSE.value, + EmbeddingComputeKernel.BATCHED_DENSE.value, + EmbeddingComputeKernel.BATCHED_FUSED.value, + ] + ), + local_size=st.sampled_from([2]), + ) + @settings(verbosity=Verbosity.verbose, max_examples=20, deadline=None) + def test_sharding_nccl_twrw( + self, + sharder_type: str, + sharding_type: str, + kernel_type: str, + local_size: int, + ) -> None: + self._test_sharding( + # pyre-ignore[6] + sharders=[ + create_test_sharder(sharder_type, sharding_type, kernel_type), + ], + backend="nccl", + world_size=2, + local_size=local_size, + ) + + # pyre-fixme[56] + @given( + sharder_type=st.sampled_from( + [ + # TODO: enable it with correct semantics, see T104397332 + # SharderType.EMBEDDING_BAG.value, + SharderType.EMBEDDING_BAG_COLLECTION.value, + ] + ), + sharding_type=st.sampled_from( + [ + ShardingType.TABLE_WISE.value, + ] + ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.DENSE.value, + EmbeddingComputeKernel.SPARSE.value, + EmbeddingComputeKernel.BATCHED_DENSE.value, + EmbeddingComputeKernel.BATCHED_FUSED.value, + ] + ), + ) + @settings(verbosity=Verbosity.verbose, max_examples=20, deadline=None) + def test_sharding_gloo_tw( + self, + sharder_type: str, + sharding_type: str, + kernel_type: str, + ) -> None: + self._test_sharding( + # pyre-ignore[6] + sharders=[ + create_test_sharder(sharder_type, sharding_type, kernel_type), + ], + backend="gloo", + ) + + # pyre-fixme[56] + @given( + sharder_type=st.sampled_from( + [ + SharderType.EMBEDDING_BAG.value, + SharderType.EMBEDDING_BAG_COLLECTION.value, + ] + ), + sharding_type=st.sampled_from( + [ + ShardingType.DATA_PARALLEL.value, + ] + ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.DENSE.value, + EmbeddingComputeKernel.BATCHED_DENSE.value, + # TODO dp+batch_fused is numerically buggy in cpu + # EmbeddingComputeKernel.SPARSE.value, + # EmbeddingComputeKernel.BATCHED_FUSED.value, + ] + ), + ) + @settings(verbosity=Verbosity.verbose, max_examples=20, deadline=None) + def test_sharding_gloo_dp( + self, sharder_type: str, sharding_type: str, kernel_type: str + ) -> None: + self._test_sharding( + # pyre-ignore[6] + sharders=[ + create_test_sharder(sharder_type, sharding_type, kernel_type), + ], + backend="gloo", + ) + + def test_parameter_init(self) -> None: + class MyModel(nn.Module): + def __init__(self, device: str, val: float) -> None: + super().__init__() + self.p = nn.Parameter( + torch.empty(3, dtype=torch.float32, device=device) + ) + self.val = val + self.reset_parameters() + + def reset_parameters(self) -> None: + nn.init.constant_(self.p, self.val) + + pg = init_distributed_single_host(rank=0, world_size=1, backend="gloo") + + # Check that already allocated parameters are left 'as is'. + cpu_model = MyModel(device="cpu", val=3.2) + sharded_model = DistributedModelParallel( + cpu_model, + env=ShardingEnv.from_process_group(pg), + ) + sharded_param = next(sharded_model.parameters()) + np.testing.assert_array_equal( + np.array([3.2, 3.2, 3.2], dtype=np.float32), sharded_param.detach().numpy() + ) + + # Check that parameters over 'meta' device are allocated and initialized. + meta_model = MyModel(device="meta", val=7.5) + sharded_model = DistributedModelParallel( + meta_model, + env=ShardingEnv.from_process_group(pg), + ) + sharded_param = next(sharded_model.parameters()) + np.testing.assert_array_equal( + np.array([7.5, 7.5, 7.5], dtype=np.float32), sharded_param.detach().numpy() + ) + + @seed_and_log + def setUp(self) -> None: + super().setUp() + torch.use_deterministic_algorithms(True) + if torch.cuda.is_available(): + torch.backends.cudnn.allow_tf32 = False + torch.backends.cuda.matmul.allow_tf32 = False + + num_features = 4 + num_weighted_features = 2 + + self.tables = [ + EmbeddingBagConfig( + num_embeddings=(i + 1) * 10, + embedding_dim=(i + 2) * 4, + name="table_" + str(i), + feature_names=["feature_" + str(i)], + ) + for i in range(num_features) + ] + self.weighted_tables = [ + EmbeddingBagConfig( + num_embeddings=(i + 1) * 10, + embedding_dim=(i + 2) * 4, + name="weighted_table_" + str(i), + feature_names=["weighted_feature_" + str(i)], + ) + for i in range(num_weighted_features) + ] + + self.embedding_groups = { + "group_0": ["feature_" + str(i) for i in range(num_features)] + } + + def _test_sharding( + self, + sharders: List[ModuleSharder[nn.Module]], + backend: str = "gloo", + world_size: int = 2, + local_size: Optional[int] = None, + hints: Optional[Dict[str, ParameterHints]] = None, + ) -> None: + self._run_multi_process_test( + # pyre-ignore [6] + callable=self._test_sharding_single_rank, + world_size=world_size, + local_size=local_size, + model_class=cast(TestSparseNNBase, TestSparseNN), + tables=self.tables, + weighted_tables=self.weighted_tables, + embedding_groups=self.embedding_groups, + sharders=sharders, + backend=backend, + optim=EmbOptimType.EXACT_SGD, + hints=hints, + ) + + +class ModelParallelStateDictTest(unittest.TestCase): + def setUp(self) -> None: + os.environ["RANK"] = "0" + os.environ["WORLD_SIZE"] = "1" + os.environ["LOCAL_WORLD_SIZE"] = "1" + os.environ["MASTER_ADDR"] = str("localhost") + os.environ["MASTER_PORT"] = str(get_free_port()) + os.environ["GLOO_DEVICE_TRANSPORT"] = "TCP" + if torch.cuda.is_available(): + self.device = torch.device("cuda:0") + backend = "nccl" + torch.cuda.set_device(self.device) + else: + self.device = torch.device("cpu") + backend = "gloo" + if not dist.is_initialized(): + dist.init_process_group(backend=backend) + + num_features = 4 + num_weighted_features = 2 + self.batch_size = 3 + self.num_float_features = 10 + + self.tables = [ + EmbeddingBagConfig( + num_embeddings=(i + 1) * 10, + embedding_dim=(i + 1) * 4, + name="table_" + str(i), + feature_names=["feature_" + str(i)], + ) + for i in range(num_features) + ] + self.weighted_tables = [ + EmbeddingBagConfig( + num_embeddings=(i + 1) * 10, + embedding_dim=(i + 1) * 4, + name="weighted_table_" + str(i), + feature_names=["weighted_feature_" + str(i)], + ) + for i in range(num_weighted_features) + ] + + def _generate_dmps_and_batch( + self, sharders: List[ModuleSharder[nn.Module]] = default_sharders + ) -> Tuple[List[DistributedModelParallel], ModelInput]: + _, local_batch = ModelInput.generate( + batch_size=self.batch_size, + world_size=1, + num_float_features=self.num_float_features, + tables=self.tables, + weighted_tables=self.weighted_tables, + ) + batch = local_batch[0].to(self.device) + + # Create two TestSparseNN modules, wrap both in DMP + dmps = [] + for _ in range(2): + m = TestSparseNN( + tables=self.tables, + num_float_features=self.num_float_features, + weighted_tables=self.weighted_tables, + dense_device=self.device, + sparse_device=torch.device("meta"), + ) + dmp = DistributedModelParallel( + module=m, + init_data_parallel=False, + device=self.device, + sharders=sharders, + ) + + with torch.no_grad(): + dmp(batch) + dmp.init_data_parallel() + dmps.append(dmp) + return (dmps, batch) + + def test_meta_device_dmp_state_dict(self) -> None: + env = ShardingEnv.from_process_group(dist.GroupMember.WORLD) + + m1 = TestSparseNN( + tables=self.tables, + num_float_features=self.num_float_features, + weighted_tables=self.weighted_tables, + dense_device=self.device, + ) + # dmp with real device + dmp1 = DistributedModelParallel( + module=m1, + init_data_parallel=False, + init_parameters=False, + sharders=default_sharders, + device=self.device, + env=env, + plan=EmbeddingShardingPlanner( + world_size=env.world_size, + compute_device_type=self.device.type, + ).plan(m1, default_sharders), + ) + + m2 = TestSparseNN( + tables=self.tables, + num_float_features=self.num_float_features, + weighted_tables=self.weighted_tables, + dense_device=self.device, + ) + # dmp with meta device + dmp2 = DistributedModelParallel( + module=m2, + init_data_parallel=False, + init_parameters=False, + sharders=default_sharders, + device=torch.device("meta"), + env=env, + plan=EmbeddingShardingPlanner( + world_size=env.world_size, + compute_device_type=self.device.type, + ).plan(m2, default_sharders), + ) + + sd1 = dmp1.state_dict() + for key, v2 in dmp2.state_dict().items(): + v1 = sd1[key] + if isinstance(v2, nn.parameter.UninitializedParameter) and isinstance( + v1, nn.parameter.UninitializedParameter + ): + continue + if isinstance(v2, ShardedTensor): + self.assertTrue(isinstance(v1, ShardedTensor)) + assert len(v2.local_shards()) == 1 + dst = v2.local_shards()[0].tensor + else: + dst = v2 + if isinstance(v1, ShardedTensor): + assert len(v1.local_shards()) == 1 + src = v1.local_shards()[0].tensor + else: + src = v1 + self.assertEqual(src.size(), dst.size()) + + # pyre-ignore[56] + @given( + sharders=st.sampled_from( + [ + [EmbeddingBagCollectionSharder()], + [EmbeddingBagSharder()], + ] + ), + ) + @settings(verbosity=Verbosity.verbose, max_examples=2, deadline=None) + def test_load_state_dict(self, sharders: List[ModuleSharder[nn.Module]]) -> None: + models, batch = self._generate_dmps_and_batch(sharders) + m1, m2 = models + + # load the second's (m2's) with the first (m1's) state_dict + m2.load_state_dict( + cast("OrderedDict[str, torch.Tensor]", m1.state_dict()), strict=False + ) + # validate the models are equivalent + with torch.no_grad(): + loss1, pred1 = m1(batch) + loss2, pred2 = m2(batch) + self.assertTrue(torch.equal(loss1, loss2)) + self.assertTrue(torch.equal(pred1, pred2)) + sd1 = m1.state_dict() + for key, value in m2.state_dict().items(): + v2 = sd1[key] + if isinstance(value, ShardedTensor): + assert len(value.local_shards()) == 1 + dst = value.local_shards()[0].tensor + else: + dst = value + if isinstance(v2, ShardedTensor): + assert len(v2.local_shards()) == 1 + src = v2.local_shards()[0].tensor + else: + src = v2 + self.assertTrue(torch.equal(src, dst)) + + # pyre-ignore[56] + @given( + sharders=st.sampled_from( + [ + [EmbeddingBagCollectionSharder()], + [EmbeddingBagSharder()], + ] + ), + ) + @settings(verbosity=Verbosity.verbose, max_examples=2, deadline=None) + def test_load_state_dict_prefix( + self, sharders: List[ModuleSharder[nn.Module]] + ) -> None: + (m1, m2), batch = self._generate_dmps_and_batch(sharders) + + # load the second's (m2's) with the first (m1's) state_dict + m2.load_state_dict( + cast("OrderedDict[str, torch.Tensor]", m1.state_dict(prefix="alpha")), + prefix="alpha", + strict=False, + ) + + # validate the models are equivalent + sd1 = m1.state_dict() + for key, value in m2.state_dict().items(): + v2 = sd1[key] + if isinstance(value, ShardedTensor): + assert len(value.local_shards()) == 1 + dst = value.local_shards()[0].tensor + else: + dst = value + if isinstance(v2, ShardedTensor): + assert len(v2.local_shards()) == 1 + src = v2.local_shards()[0].tensor + else: + src = v2 + self.assertTrue(torch.equal(src, dst)) + + # pyre-fixme[56] + @given( + sharder_type=st.sampled_from( + [ + SharderType.EMBEDDING_BAG.value, + SharderType.EMBEDDING_BAG_COLLECTION.value, + ] + ), + sharding_type=st.sampled_from( + [ + ShardingType.TABLE_WISE.value, + ] + ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.DENSE.value, + EmbeddingComputeKernel.SPARSE.value, + # EmbeddingComputeKernel.BATCHED_DENSE.value, + EmbeddingComputeKernel.BATCHED_FUSED.value, + ] + ), + ) + @settings(verbosity=Verbosity.verbose, max_examples=10, deadline=None) + def test_params_and_buffers( + self, sharder_type: str, sharding_type: str, kernel_type: str + ) -> None: + sharders = [ + create_test_sharder(sharder_type, sharding_type, kernel_type), + ] + # pyre-ignore[6] + (m, _), batch = self._generate_dmps_and_batch(sharders=sharders) + state_dict_keys = set(m.state_dict().keys()) + param_keys = {key for (key, _) in m.named_parameters()} + buffer_keys = {key for (key, _) in m.named_buffers()} + self.assertEqual(state_dict_keys, {*param_keys, *buffer_keys}) diff --git a/torchrec/distributed/tests/test_model_parallel_base.py b/torchrec/distributed/tests/test_model_parallel_base.py new file mode 100644 index 000000000..1114dcf2c --- /dev/null +++ b/torchrec/distributed/tests/test_model_parallel_base.py @@ -0,0 +1,275 @@ +#!/usr/bin/env python3 + +import multiprocessing +import os +import unittest +from typing import cast, Callable, Dict, List, Optional, Tuple, Union + +import torch +import torch.distributed as dist +import torch.nn as nn +from fbgemm_gpu.split_embedding_configs import EmbOptimType +from torchrec.distributed.embedding_types import EmbeddingTableConfig +from torchrec.distributed.model_parallel import DistributedModelParallel +from torchrec.distributed.planner import EmbeddingShardingPlanner +from torchrec.distributed.planner.types import ParameterHints +from torchrec.distributed.tests.test_model import ( + ModelInput, + TestSparseNNBase, +) +from torchrec.distributed.types import ( + ShardedTensor, + ModuleSharder, + ShardingEnv, +) +from torchrec.modules.embedding_configs import BaseEmbeddingConfig +from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizerWrapper +from torchrec.tests.utils import ( + get_free_port, + seed_and_log, + init_distributed_single_host, +) + + +class ModelParallelTestBase(unittest.TestCase): + @seed_and_log + def setUp(self) -> None: + os.environ["MASTER_ADDR"] = str("localhost") + os.environ["MASTER_PORT"] = str(get_free_port()) + os.environ["GLOO_DEVICE_TRANSPORT"] = "TCP" + os.environ["NCCL_SOCKET_IFNAME"] = "lo" + + @classmethod + def _test_sharding_single_rank( + cls, + rank: int, + world_size: int, + model_class: TestSparseNNBase, + embedding_groups: Dict[str, List[str]], + tables: List[EmbeddingTableConfig], + sharders: List[ModuleSharder[nn.Module]], + backend: str, + optim: EmbOptimType, + weighted_tables: Optional[List[EmbeddingTableConfig]] = None, + hints: Optional[Dict[str, ParameterHints]] = None, + local_size: Optional[int] = None, + ) -> None: + # Override local_size after pg construction because unit test device count + # is larger than local_size setup. This can be problematic for twrw because + # we have ShardedTensor placement check. + os.environ["LOCAL_WORLD_SIZE"] = str(world_size) + if backend == "nccl": + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + else: + device = torch.device("cpu") + pg = init_distributed_single_host( + rank=rank, + world_size=world_size, + backend=backend, + local_size=local_size, + ) + + # Generate model & inputs. + (global_model, inputs) = cls._gen_model_and_input( + model_class=model_class, + tables=tables, + weighted_tables=weighted_tables, + embedding_groups=embedding_groups, + world_size=world_size, + num_float_features=16, + ) + global_model = global_model.to(device) + global_input = inputs[0][0].to(device) + local_input = inputs[0][1][rank].to(device) + + # Shard model. + local_model = model_class( + tables=cast(List[BaseEmbeddingConfig], tables), + weighted_tables=cast(List[BaseEmbeddingConfig], weighted_tables), + embedding_groups=embedding_groups, + dense_device=device, + sparse_device=torch.device("meta"), + num_float_features=16, + ) + + planner = EmbeddingShardingPlanner(world_size, device.type, hints) + plan = planner.collective_plan(local_model, sharders, pg) + + local_model = DistributedModelParallel( + local_model, + env=ShardingEnv.from_process_group(pg), + plan=plan, + sharders=sharders, + device=device, + ) + + dense_optim = KeyedOptimizerWrapper( + dict(local_model.named_parameters()), + lambda params: torch.optim.SGD(params, lr=0.1), + ) + local_opt = CombinedOptimizer([local_model.fused_optimizer, dense_optim]) + + # Load model state from the global model. + cls._copy_state_dict(local_model.state_dict(), global_model.state_dict()) + + # Run a single training step of the sharded model. + local_pred = cls._gen_full_pred_after_one_step( + local_model, local_opt, local_input + ) + all_local_pred = [] + for _ in range(world_size): + all_local_pred.append(torch.empty_like(local_pred)) + dist.all_gather(all_local_pred, local_pred, group=pg) + + # Run second training step of the unsharded model. + assert optim == EmbOptimType.EXACT_SGD + global_opt = torch.optim.SGD(global_model.parameters(), lr=0.1) + global_pred = cls._gen_full_pred_after_one_step( + global_model, global_opt, global_input + ) + + # Compare predictions of sharded vs unsharded models. + torch.testing.assert_allclose(global_pred, torch.cat(all_local_pred)) + + def _run_multi_process_test( + self, + callable: Callable[ + [int, int, List[ModuleSharder[nn.Module]], List[torch.Tensor]], None + ], + world_size: int, + sharders: List[ModuleSharder[nn.Module]], + tables: List[EmbeddingTableConfig], + backend: str, + optim: EmbOptimType, + model_class: TestSparseNNBase, + embedding_groups: Optional[Dict[str, List[str]]] = None, + weighted_tables: Optional[List[EmbeddingTableConfig]] = None, + hints: Optional[Dict[str, ParameterHints]] = None, + local_size: Optional[int] = None, + ) -> None: + ctx = multiprocessing.get_context("spawn") + processes = [] + for rank in range(world_size): + p = ctx.Process( + target=callable, + args=( + rank, + world_size, + model_class, + embedding_groups, + tables, + sharders, + backend, + optim, + weighted_tables, + hints, + ), + kwargs={ + "local_size": local_size, + }, + ) + p.start() + processes.append(p) + + for p in processes: + p.join() + self.assertEqual(0, p.exitcode) + + @staticmethod + def _generate_inputs( + world_size: int, + tables: List[EmbeddingTableConfig], + weighted_tables: Optional[List[EmbeddingTableConfig]] = None, + batch_size: int = 4, + num_float_features: int = 16, + ) -> Tuple[ModelInput, List[ModelInput]]: + return ModelInput.generate( + batch_size=batch_size, + world_size=world_size, + num_float_features=num_float_features, + tables=tables, + weighted_tables=weighted_tables or [], + ) + + @classmethod + def _gen_model_and_input( + cls, + model_class: TestSparseNNBase, + tables: List[EmbeddingTableConfig], + embedding_groups: Dict[str, List[str]], + world_size: int, + weighted_tables: Optional[List[EmbeddingTableConfig]] = None, + num_float_features: int = 16, + ) -> Tuple[nn.Module, List[Tuple[ModelInput, List[ModelInput]]]]: + torch.manual_seed(0) + + model = model_class( + tables=cast(List[BaseEmbeddingConfig], tables), + num_float_features=num_float_features, + weighted_tables=cast(List[BaseEmbeddingConfig], weighted_tables), + embedding_groups=embedding_groups, + ) + inputs = [ + cls._generate_inputs( + world_size=world_size, + tables=tables, + weighted_tables=weighted_tables, + num_float_features=num_float_features, + ) + ] + return (model, inputs) + + @classmethod + def _gen_full_pred_after_one_step( + cls, + model: nn.Module, + opt: torch.optim.Optimizer, + input: ModelInput, + ) -> torch.Tensor: + # Run a single training step of the global model. + opt.zero_grad() + model.train(True) + loss, _ = model(input) + loss.backward() + opt.step() + + # Run a forward pass of the global model. + with torch.no_grad(): + model.train(False) + full_pred = model(input) + return full_pred + + @classmethod + def _copy_state_dict( + cls, + loc: Dict[str, Union[torch.Tensor, ShardedTensor]], + glob: Dict[str, torch.Tensor], + ) -> None: + for name, tensor in loc.items(): + assert name in glob + global_tensor = glob[name] + if isinstance(global_tensor, ShardedTensor): + global_tensor = global_tensor.local_shards()[0].tensor + if isinstance(tensor, ShardedTensor): + for local_shard in tensor.local_shards(): + assert global_tensor.ndim == local_shard.tensor.ndim + shard_meta = local_shard.metadata + t = global_tensor.detach() + if t.ndim == 1: + t = t[ + shard_meta.shard_offsets[0] : shard_meta.shard_offsets[0] + + local_shard.tensor.shape[0] + ] + elif t.ndim == 2: + t = t[ + shard_meta.shard_offsets[0] : shard_meta.shard_offsets[0] + + local_shard.tensor.shape[0], + shard_meta.shard_offsets[1] : shard_meta.shard_offsets[1] + + local_shard.tensor.shape[1], + ] + else: + raise ValueError("Tensors with ndim > 2 are not supported") + local_shard.tensor.copy_(t) + else: + tensor.copy_(global_tensor) diff --git a/torchrec/distributed/tests/test_quant_model_parallel.py b/torchrec/distributed/tests/test_quant_model_parallel.py new file mode 100644 index 000000000..6a84ed70b --- /dev/null +++ b/torchrec/distributed/tests/test_quant_model_parallel.py @@ -0,0 +1,171 @@ +#!/usr/bin/env python3 + +import copy +import unittest +from typing import List + +import torch +from torch import nn +from torch import quantization as quant +from torchrec.distributed.embedding_lookup import ( + GroupedEmbeddingBag, + BatchedFusedEmbeddingBag, + BatchedDenseEmbeddingBag, + QuantBatchedEmbeddingBag, +) +from torchrec.distributed.embedding_types import EmbeddingComputeKernel +from torchrec.distributed.embeddingbag import ( + QuantEmbeddingBagCollectionSharder, +) +from torchrec.distributed.model_parallel import DistributedModelParallel +from torchrec.distributed.tests.test_model import ( + TestSparseNN, + TestEBCSharder, +) +from torchrec.distributed.types import ShardingType, ShardingEnv +from torchrec.distributed.utils import sharded_model_copy +from torchrec.modules.embedding_configs import EmbeddingBagConfig +from torchrec.modules.embedding_modules import EmbeddingBagCollection +from torchrec.quant.embedding_modules import ( + EmbeddingBagCollection as QuantEmbeddingBagCollection, +) + + +class TestQuantEBCSharder(QuantEmbeddingBagCollectionSharder): + def __init__(self, sharding_type: str, kernel_type: str) -> None: + self._sharding_type = sharding_type + self._kernel_type = kernel_type + + def sharding_types(self, compute_device_type: str) -> List[str]: + return [self._sharding_type] + + def compute_kernels( + self, sharding_type: str, compute_device_type: str + ) -> List[str]: + return [self._kernel_type] + + +def _quantize_sharded(module: nn.Module, inplace: bool) -> nn.Module: + qconfig = quant.QConfigDynamic( + activation=quant.PlaceholderObserver, + weight=quant.PlaceholderObserver.with_args(dtype=torch.qint8), + ) + return quant.quantize_dynamic( + module, + qconfig_spec={ + GroupedEmbeddingBag: qconfig, + BatchedFusedEmbeddingBag: qconfig, + BatchedDenseEmbeddingBag: qconfig, + }, + mapping={ + GroupedEmbeddingBag: QuantBatchedEmbeddingBag, + BatchedFusedEmbeddingBag: QuantBatchedEmbeddingBag, + BatchedDenseEmbeddingBag: QuantBatchedEmbeddingBag, + }, + inplace=inplace, + ) + + +def _quantize(module: nn.Module, inplace: bool) -> nn.Module: + qconfig = quant.QConfigDynamic( + activation=quant.PlaceholderObserver, + weight=quant.PlaceholderObserver.with_args(dtype=torch.qint8), + ) + return quant.quantize_dynamic( + module, + qconfig_spec={ + EmbeddingBagCollection: qconfig, + }, + mapping={ + EmbeddingBagCollection: QuantEmbeddingBagCollection, + }, + inplace=inplace, + ) + + +class QuantModelParallelTest(unittest.TestCase): + def setUp(self) -> None: + self.device = torch.device("cuda:0") + torch.cuda.set_device(self.device) + + num_features = 4 + num_weighted_features = 2 + + self.tables = [ + EmbeddingBagConfig( + num_embeddings=(i + 1) * 10, + embedding_dim=(i + 1) * 4, + name="table_" + str(i), + feature_names=["feature_" + str(i)], + ) + for i in range(num_features) + ] + self.weighted_tables = [ + EmbeddingBagConfig( + num_embeddings=(i + 1) * 10, + embedding_dim=(i + 1) * 4, + name="weighted_table_" + str(i), + feature_names=["weighted_feature_" + str(i)], + ) + for i in range(num_weighted_features) + ] + + # pyre-fixme[56] + @unittest.skipIf( + torch.cuda.device_count() <= 0, + "No GPUs available", + ) + def test_quant_pred(self) -> None: + device = torch.device("cuda:0") + model = TestSparseNN( + tables=self.tables, + weighted_tables=self.weighted_tables, + num_float_features=10, + dense_device=device, + sparse_device=torch.device("meta"), + ) + quant_model = _quantize(model, inplace=True) + _ = DistributedModelParallel( + quant_model, + # pyre-ignore [6] + sharders=[ + TestQuantEBCSharder( + sharding_type=ShardingType.DATA_PARALLEL.value, + kernel_type=EmbeddingComputeKernel.BATCHED_QUANT.value, + ) + ], + device=device, + env=ShardingEnv.from_local(world_size=1, rank=0), + init_data_parallel=False, + ) + + # pyre-fixme[56] + @unittest.skipIf( + torch.cuda.device_count() <= 0, + "No GPUs available", + ) + def test_quant_train(self) -> None: + device = torch.device("cuda:0") + model = TestSparseNN( + tables=self.tables, + weighted_tables=self.weighted_tables, + num_float_features=10, + dense_device=device, + sparse_device=torch.device("meta"), + ) + sharded_model = DistributedModelParallel( + model, + # pyre-ignore [6] + sharders=[ + TestEBCSharder( + sharding_type=ShardingType.DATA_PARALLEL.value, + kernel_type=EmbeddingComputeKernel.BATCHED_FUSED.value, + ) + ], + device=device, + env=ShardingEnv.from_local(world_size=1, rank=0), + init_data_parallel=False, + ) + with sharded_model_copy(device="cpu"): + sharded_model_cpu = copy.deepcopy(sharded_model) + _ = _quantize_sharded(sharded_model_cpu, inplace=True) diff --git a/torchrec/distributed/tests/test_train_pipeline.py b/torchrec/distributed/tests/test_train_pipeline.py new file mode 100644 index 000000000..660b18fdc --- /dev/null +++ b/torchrec/distributed/tests/test_train_pipeline.py @@ -0,0 +1,265 @@ +#!/usr/bin/env python3 + +import os +import unittest +from dataclasses import dataclass +from typing import Tuple, List, Optional, Dict + +import torch +import torch.distributed as dist +from torch import nn, optim +from torchrec.distributed import DistributedModelParallel +from torchrec.distributed.embedding_types import EmbeddingComputeKernel +from torchrec.distributed.embedding_types import ( + SparseFeaturesList, +) +from torchrec.distributed.embeddingbag import ( + ShardedEmbeddingBagCollection, + EmbeddingBagCollectionSharder, +) +from torchrec.distributed.tests.test_model import ( + TestSparseNN, + ModelInput, + TestEBCSharder, +) +from torchrec.distributed.train_pipeline import ( + TrainPipelineBase, + TrainPipelineSparseDist, +) +from torchrec.distributed.types import ( + Awaitable, + ParameterSharding, + ShardedModuleContext, + ShardingEnv, +) +from torchrec.distributed.types import ( + ShardingType, +) +from torchrec.modules.embedding_configs import EmbeddingBagConfig +from torchrec.modules.embedding_modules import EmbeddingBagCollection +from torchrec.optim.keyed import KeyedOptimizerWrapper +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor +from torchrec.tests.utils import get_free_port, init_distributed_single_host + + +class TestShardedEmbeddingBagCollection(ShardedEmbeddingBagCollection): + def input_dist( + self, + ctx: ShardedModuleContext, + features: KeyedJaggedTensor, + ) -> Awaitable[SparseFeaturesList]: + return super().input_dist(ctx, features) + + +class TestCustomEBCSharder(EmbeddingBagCollectionSharder[EmbeddingBagCollection]): + def shard( + self, + module: EmbeddingBagCollection, + params: Dict[str, ParameterSharding], + env: ShardingEnv, + device: Optional[torch.device] = None, + ) -> TestShardedEmbeddingBagCollection: + return TestShardedEmbeddingBagCollection( + module, params, env, self.fused_params, device + ) + + def sharding_types(self, compute_device_type: str) -> List[str]: + return [ + ShardingType.TABLE_WISE.value, + ] + + def compute_kernels( + self, sharding_type: str, compute_device_type: str + ) -> List[str]: + return [EmbeddingComputeKernel.DENSE.value] + + +@dataclass +class ModelInputSimple: + float_features: torch.Tensor + label: torch.Tensor + + def to(self, device: torch.device, non_blocking: bool) -> "ModelInputSimple": + return ModelInputSimple( + float_features=self.float_features.to( + device=device, non_blocking=non_blocking + ), + label=self.label.to(device=device, non_blocking=non_blocking), + ) + + def record_stream(self, stream: torch.cuda.streams.Stream) -> None: + self.float_features.record_stream(stream) + self.label.record_stream(stream) + + +class TestModule(nn.Module): + def __init__(self) -> None: + super().__init__() + self.model = nn.Linear(10, 1) + self.loss_fn = nn.BCEWithLogitsLoss() + + def forward( + self, model_input: ModelInputSimple + ) -> Tuple[torch.Tensor, torch.Tensor]: + pred = self.model(model_input.float_features) + loss = self.loss_fn(pred, model_input.label) + return (loss, pred) + + +class TrainPipelineBaseTest(unittest.TestCase): + def setUp(self) -> None: + self.device = torch.device("cuda:0") + + # pyre-fixme[56]: Pyre was not able to infer the type of argument + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) + def test_equal_to_non_pipelined(self) -> None: + model_cpu = TestModule() + model_gpu = TestModule().to(self.device) + model_gpu.load_state_dict(model_cpu.state_dict()) + optimizer_cpu = optim.SGD(model_cpu.model.parameters(), lr=0.01) + optimizer_gpu = optim.SGD(model_gpu.model.parameters(), lr=0.01) + data = [ + ModelInputSimple( + float_features=torch.rand((10,)), + label=torch.randint(2, (1,), dtype=torch.float32), + ) + for b in range(5) + ] + dataloader = iter(data) + pipeline = TrainPipelineBase(model_gpu, optimizer_gpu, self.device) + + for example in data[:-1]: + optimizer_cpu.zero_grad() + loss, pred = model_cpu(example) + loss.backward() + optimizer_cpu.step() + + pred_gpu = pipeline.progress(dataloader) + + self.assertEquals(pred_gpu.device, self.device) + self.assertTrue(torch.isclose(pred_gpu.cpu(), pred)) + + +class TrainPipelineSparseDistTest(unittest.TestCase): + def setUp(self) -> None: + os.environ["MASTER_ADDR"] = str("localhost") + os.environ["MASTER_PORT"] = str(get_free_port()) + if not dist.is_initialized(): + self.pg = init_distributed_single_host(backend="gloo", rank=0, world_size=1) + else: + self.pg = dist.group.WORLD + + num_features = 4 + num_weighted_features = 2 + + self.tables = [ + EmbeddingBagConfig( + num_embeddings=(i + 1) * 100, + embedding_dim=(i + 1) * 4, + name="table_" + str(i), + feature_names=["feature_" + str(i)], + ) + for i in range(num_features) + ] + self.weighted_tables = [ + EmbeddingBagConfig( + num_embeddings=(i + 1) * 100, + embedding_dim=(i + 1) * 4, + name="weighted_table_" + str(i), + feature_names=["weighted_feature_" + str(i)], + ) + for i in range(num_weighted_features) + ] + + self.device = torch.device("cuda:0") + + def _test_move_cpu_gpu_helper( + self, distributed_model: DistributedModelParallel + ) -> None: + model_cpu = TestSparseNN( + tables=self.tables, weighted_tables=self.weighted_tables + ) + optimizer_cpu = optim.SGD(model_cpu.parameters(), lr=0.1) + optimizer_distributed = KeyedOptimizerWrapper( + dict(distributed_model.named_parameters()), + lambda params: optim.SGD(params, lr=0.1), + ) + pipeline = TrainPipelineSparseDist( + distributed_model, optimizer_distributed, self.device + ) + + data = [ + ModelInput.generate( + tables=self.tables, + weighted_tables=self.weighted_tables, + batch_size=1, + world_size=1, + num_float_features=10, + )[0] + for i in range(5) + ] + dataloader = iter(data) + + for example in data[:-2]: + optimizer_cpu.zero_grad() + loss, pred = model_cpu(example) + loss.backward() + optimizer_cpu.step() + + pred_gpu = pipeline.progress(dataloader) + + self.assertEquals(pred_gpu.device, self.device) + self.assertEquals(pred_gpu.cpu().size(), pred.size()) + self.assertEquals(len(pipeline._pipelined_modules), 2) + + # pyre-fixme[56]: Pyre was not able to infer the type of argument + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) + def test_move_cpu_gpu(self) -> None: + unsharded_model = TestSparseNN( + tables=self.tables, + weighted_tables=self.weighted_tables, + dense_device=self.device, + sparse_device=torch.device("meta"), + ) + distributed_model = DistributedModelParallel( + unsharded_model, + env=ShardingEnv.from_process_group(self.pg), + init_data_parallel=False, + device=self.device, + # pyre-ignore [6] + sharders=[ + TestEBCSharder( + sharding_type=ShardingType.TABLE_WISE.value, + kernel_type=EmbeddingComputeKernel.DENSE.value, + ) + ], + ) + self._test_move_cpu_gpu_helper(distributed_model) + + # pyre-fixme[56]: Pyre was not able to infer the type of argument + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) + def test_pipelining(self) -> None: + unsharded_model = TestSparseNN( + tables=self.tables, + weighted_tables=self.weighted_tables, + dense_device=self.device, + sparse_device=torch.device("meta"), + ) + distributed_model = DistributedModelParallel( + unsharded_model, + env=ShardingEnv.from_process_group(self.pg), + init_data_parallel=False, + device=self.device, + # pyre-fixme [6] + sharders=[TestCustomEBCSharder()], + ) + self._test_move_cpu_gpu_helper(distributed_model) diff --git a/torchrec/distributed/tests/test_utils.py b/torchrec/distributed/tests/test_utils.py new file mode 100644 index 000000000..f4b4140e1 --- /dev/null +++ b/torchrec/distributed/tests/test_utils.py @@ -0,0 +1,316 @@ +#!/usr/bin/env python3 + +import itertools +import math +import os +import random +import unittest +from typing import List, Optional, Tuple + +import numpy as np +import torch +import torch.distributed as dist +from torchrec.distributed.embedding_sharding import bucketize_kjt_before_all2all +from torchrec.distributed.embeddingbag import ( + EmbeddingBagCollectionSharder, +) +from torchrec.distributed.model_parallel import DistributedModelParallel +from torchrec.distributed.tests.test_model import TestSparseNN +from torchrec.distributed.utils import get_unsharded_module_names +from torchrec.modules.embedding_configs import EmbeddingBagConfig +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor +from torchrec.sparse.tests.tests_utils import keyed_jagged_tensor_equals +from torchrec.tests.utils import get_free_port + + +def _compute_translated_lengths( + row_indices: List[int], + indices_offsets: List[int], + lengths_size: int, + trainers_size: int, + block_sizes: List[int], +) -> List[int]: + translated_lengths = [0] * trainers_size * lengths_size + + batch_size = int(lengths_size / len(block_sizes)) + iteration = feature_offset = batch_iteration = 0 + for start_offset, end_offset in zip(indices_offsets, indices_offsets[1:]): + # iterate all rows that belong to current feature and batch iteration + for row_idx in row_indices[start_offset:end_offset]: + # compute the owner of this row + trainer_offset = int(row_idx / block_sizes[feature_offset]) + # we do not have enough trainers to handle this row + if trainer_offset >= trainers_size: + continue + trainer_lengths_offset = trainer_offset * lengths_size + # compute the offset in lengths that is local in each trainer + local_lengths_offset = feature_offset * batch_size + batch_iteration + # increment the corresponding length in the trainer + translated_lengths[trainer_lengths_offset + local_lengths_offset] += 1 + # bookkeeping + iteration += 1 + feature_offset = int(iteration / batch_size) + batch_iteration = (batch_iteration + 1) % batch_size + return translated_lengths + + +def _compute_translated_indices_with_weights( + translated_lengths: List[int], + row_indices: List[int], + indices_offsets: List[int], + lengths_size: int, + weights: Optional[List[int]], + trainers_size: int, + block_sizes: List[int], +) -> List[Tuple[int, int]]: + translated_indices_with_weights = [(0, 0)] * len(row_indices) + + translated_indices_offsets = np.cumsum([0] + translated_lengths) + batch_size = int(lengths_size / len(block_sizes)) + iteration = feature_offset = batch_iteration = 0 + for start_offset, end_offset in zip(indices_offsets, indices_offsets[1:]): + # iterate all rows that belong to current feature and batch iteration + # and assign the translated row index to the corresponding offset in output + for current_offset in range(start_offset, end_offset): + row_idx = row_indices[current_offset] + feature_block_size = block_sizes[feature_offset] + # compute the owner of this row + trainer_offset = int(row_idx / feature_block_size) + if trainer_offset >= trainers_size: + continue + trainer_lengths_offset = trainer_offset * lengths_size + # compute the offset in lengths that is local in each trainer + local_lengths_offset = feature_offset * batch_size + batch_iteration + # since we know the number of rows belonging to each trainer, + # we can figure out the corresponding offset in the translated indices list + # for the current translated index + translated_indices_offset = translated_indices_offsets[ + trainer_lengths_offset + local_lengths_offset + ] + translated_indices_with_weights[translated_indices_offset] = ( + row_idx % feature_block_size, + weights[current_offset] if weights else 0, + ) + # the next row that goes to this trainer for this feature and batch + # combination goes to the next offset + translated_indices_offsets[ + trainer_lengths_offset + local_lengths_offset + ] += 1 + # bookkeeping + iteration += 1 + feature_offset = int(iteration / batch_size) + batch_iteration = (batch_iteration + 1) % batch_size + return translated_indices_with_weights + + +def block_bucketize_ref( + keyed_jagged_tensor: KeyedJaggedTensor, + trainers_size: int, + block_sizes: torch.Tensor, +) -> KeyedJaggedTensor: + lengths_list = keyed_jagged_tensor.lengths().view(-1).tolist() + indices_list = keyed_jagged_tensor.values().view(-1).tolist() + weights_list = ( + keyed_jagged_tensor.weights().view(-1).tolist() + if keyed_jagged_tensor.weights() is not None + else None + ) + block_sizes_list = block_sizes.view(-1).tolist() + lengths_size = len(lengths_list) + + """ + each element in indices_offsets signifies both the starting offset, in indices_list, + that corresponds to all rows in a particular feature and batch iteration, + and the ending offset of the previous feature/batch iteration + + For example: + given that features_size = 2 and batch_size = 2, an indices_offsets of + [0,1,4,6,6] signifies that: + + elements in indices_list[0:1] belongs to feature 0 batch 0 + elements in indices_list[1:4] belongs to feature 0 batch 1 + elements in indices_list[4:6] belongs to feature 1 batch 0 + elements in indices_list[6:6] belongs to feature 1 batch 1 + """ + indices_offsets = np.cumsum([0] + lengths_list) + + translated_lengths = _compute_translated_lengths( + row_indices=indices_list, + indices_offsets=indices_offsets, + lengths_size=lengths_size, + trainers_size=trainers_size, + block_sizes=block_sizes_list, + ) + translated_indices_with_weights = _compute_translated_indices_with_weights( + translated_lengths=translated_lengths, + row_indices=indices_list, + indices_offsets=indices_offsets, + lengths_size=lengths_size, + weights=weights_list, + trainers_size=trainers_size, + block_sizes=block_sizes_list, + ) + + translated_indices = [ + translated_index for translated_index, _ in translated_indices_with_weights + ] + + translated_weights = [ + translated_weight for _, translated_weight in translated_indices_with_weights + ] + + expected_keys = [ + f"{key}@bucket_{index}" + for index in range(trainers_size) + for key in keyed_jagged_tensor.keys() + ] + + return KeyedJaggedTensor( + keys=expected_keys, + lengths=torch.tensor( + translated_lengths, dtype=keyed_jagged_tensor.lengths().dtype + ) + .view(-1) + .cuda(), + values=torch.tensor( + translated_indices, dtype=keyed_jagged_tensor.values().dtype + ).cuda(), + weights=torch.tensor(translated_weights).float().cuda() + if weights_list + else None, + ) + + +class UtilsTest(unittest.TestCase): + def test_get_unsharded_module_names(self) -> None: + os.environ["RANK"] = "0" + os.environ["WORLD_SIZE"] = "1" + os.environ["LOCAL_WORLD_SIZE"] = "1" + os.environ["MASTER_ADDR"] = str("localhost") + os.environ["MASTER_PORT"] = str(get_free_port()) + os.environ["GLOO_DEVICE_TRANSPORT"] = "TCP" + device = torch.device("cpu") + backend = "gloo" + if not dist.is_initialized(): + dist.init_process_group(backend=backend) + tables = [ + EmbeddingBagConfig( + num_embeddings=10, + embedding_dim=4, + name="table_" + str(i), + feature_names=["feature_" + str(i)], + ) + for i in range(2) + ] + weighted_tables = [ + EmbeddingBagConfig( + num_embeddings=10, + embedding_dim=4, + name="weighted_table_" + str(i), + feature_names=["weighted_feature_" + str(i)], + ) + for i in range(2) + ] + m = TestSparseNN( + tables=tables, + weighted_tables=weighted_tables, + dense_device=device, + sparse_device=device, + ) + dmp = DistributedModelParallel( + module=m, + init_data_parallel=False, + device=device, + sharders=[ + EmbeddingBagCollectionSharder(), + ], + ) + + np.testing.assert_array_equal( + sorted(get_unsharded_module_names(dmp)), + sorted(["module.over", "module.dense"]), + ) + + # pyre-ignore[56] + @unittest.skipIf( + torch.cuda.device_count() <= 0, + "CUDA is not available", + ) + def test_kjt_bucketize_before_all2all(self) -> None: + index_type = random.choice([torch.int, torch.long]) + offset_type = random.choice([torch.int, torch.long]) + world_size = random.randint(1, 129) + MAX_NUM_FEATURES = 15 + MAX_BATCH_SIZE = 15 + MAX_LENGTH = 10 + # max number of rows needed for a given feature to have unique row index + MAX_ROW_COUNT = MAX_LENGTH * MAX_BATCH_SIZE + + num_features = random.randint(2, MAX_NUM_FEATURES) + batch_size = random.randint(2, MAX_BATCH_SIZE) + lengths_list = [ + random.randrange(MAX_LENGTH + 1) for _ in range(num_features * batch_size) + ] + keys_list = [f"feature_{i}" for i in range(num_features)] + # for each feature, generate unrepeated row indices + indices_lists = [ + random.sample( + range(MAX_ROW_COUNT), + # number of indices needed is the length sum of all batches for a feature + sum( + lengths_list[ + feature_offset * batch_size : (feature_offset + 1) * batch_size + ] + ), + ) + for feature_offset in range(num_features) + ] + indices_list = list(itertools.chain(*indices_lists)) + + weights_list = [random.randint(1, 100) for _ in range(len(indices_list))] + + # for each feature, calculate the minimum block size needed to + # distribute all rows to the available trainers + block_sizes_list = [ + math.ceil((max(feature_indices_list) + 1) / world_size) + for feature_indices_list in indices_lists + ] + + kjt = KeyedJaggedTensor( + keys=keys_list, + lengths=torch.tensor(lengths_list, dtype=offset_type) + .view(num_features * batch_size) + .cuda(), + values=torch.tensor(indices_list, dtype=index_type).cuda(), + weights=torch.tensor(weights_list, dtype=torch.float).cuda(), + ) + """ + each entry in block_sizes identifies how many hashes for each feature goes + to every rank; we have three featues in `self.features` + """ + block_sizes = torch.tensor(block_sizes_list, dtype=index_type).cuda() + + block_bucketized_kjt, _ = bucketize_kjt_before_all2all( + kjt, world_size, block_sizes, False, False + ) + + expected_block_bucketized_kjt = block_bucketize_ref( + kjt, + world_size, + block_sizes, + ) + + print(f"block_sizes: {block_sizes}") + print(f"num_features: {num_features}") + print(f"batch_size: {batch_size}") + print(f"world_size: {world_size}") + print(f"KeyedJaggedTensor: {kjt}") + print(f"block_bucketized KeyedJaggedTensor: {block_bucketized_kjt}") + print( + f"expected_block_bucketized KeyedJaggedTensor: {expected_block_bucketized_kjt}" + ) + self.assertTrue( + keyed_jagged_tensor_equals( + block_bucketized_kjt, expected_block_bucketized_kjt + ) + ) diff --git a/torchrec/distributed/train_pipeline.py b/torchrec/distributed/train_pipeline.py new file mode 100644 index 000000000..874cc1089 --- /dev/null +++ b/torchrec/distributed/train_pipeline.py @@ -0,0 +1,509 @@ +#!/usr/bin/env python3 + +import abc +import logging +from dataclasses import dataclass, field +from typing import ( + Iterator, + Tuple, + Optional, + TypeVar, + Generic, + cast, + Any, + Dict, + List, + Set, +) + +import torch +from torch.autograd.profiler import record_function +from torch.fx.node import Node +from torch.nn.parallel import DistributedDataParallel +from torchrec.distributed.model_parallel import DistributedModelParallel, ShardedModule +from torchrec.distributed.types import Awaitable, ShardedModuleContext +from torchrec.types import Pipelineable, Multistreamable + +logger: logging.Logger = logging.getLogger(__name__) + + +In = TypeVar("In", bound=Pipelineable) +Out = TypeVar("Out") +DistIn = TypeVar("DistIn", bound=Multistreamable) +DistOut = TypeVar("DistOut") + + +class TrainPipeline(abc.ABC, Generic[In, Out]): + @abc.abstractmethod + def progress(self, dataloader_iter: Iterator[In]) -> Out: + pass + + +def _to_device(batch: In, device: torch.device, non_blocking: bool) -> In: + assert isinstance( + batch, (torch.Tensor, Pipelineable) + ), f"{type(batch)} must implement Pipelineable interface" + return cast(In, batch.to(device=device, non_blocking=non_blocking)) + + +def _wait_for_batch(batch: In, stream: Optional[torch.cuda.streams.Stream]) -> None: + if stream is None: + return + torch.cuda.current_stream().wait_stream(stream) + # As mentioned in https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html, + # PyTorch uses the "caching allocator" for memroy allocation for tensors. When a tensor is + # freed, its memory is likely to be reused by newly constructed tenosrs. By default, + # this allocator traces whether a tensor is still in use by only the CUDA stream where it + # was created. When a tensor is used by additional CUDA streams, we need to call record_stream + # to tell the allocator about all these streams. Otherwise, the allocator might free the + # underlying memory of the tensor once it is no longer used by the creator stream. This is + # a notable programming trick when we write programs using multi CUDA streams. + cur_stream = torch.cuda.current_stream() + assert isinstance( + batch, (torch.Tensor, Multistreamable) + ), f"{type(batch)} must implement Multistreamable interface" + batch.record_stream(cur_stream) + + +class TrainPipelineBase(TrainPipeline[In, Out]): + """ + This class runs training iterations using a pipeline of two stages, each as a CUDA stream, + namely, the current (default) stream and self._memcpy_stream. For each iteration, + self._memcpy_stream moves the input from host (CPU) memory to GPU memory, and the default + stream runs forward, backward, and optimization. + """ + + def __init__( + self, + model: torch.nn.Module, + optimizer: torch.optim.Optimizer, + device: torch.device, + ) -> None: + self._model = model + self._optimizer = optimizer + self._device = device + self._memcpy_stream: Optional[torch.cuda.streams.Stream] = ( + torch.cuda.Stream() if device.type == "cuda" else None + ) + self._cur_batch: Optional[In] = None + self._connected = False + + def _connect(self, dataloader_iter: Iterator[In]) -> None: + cur_batch = next(dataloader_iter) + self._cur_batch = cur_batch + with torch.cuda.stream(self._memcpy_stream): + self._cur_batch = _to_device(cur_batch, self._device, non_blocking=True) + + with torch.no_grad(): + # Init lazy modules if any. An example lazy module is + # https://pytorch.org/docs/stable/generated/torch.nn.LazyLinear.html + model = self._model + model(self._cur_batch) + + # Make sure we init data parallel modules if not done yet. + if isinstance(model, DistributedModelParallel): + model.init_data_parallel() + + self._connected = True + + def progress(self, dataloader_iter: Iterator[In]) -> Out: + if not self._connected: + self._connect(dataloader_iter) + + # Fetch next batch + with record_function("## next_batch ##"): + next_batch = next(dataloader_iter) + cur_batch = self._cur_batch + assert cur_batch is not None + + if self._model.training: + with record_function("## zero_grad ##"): + self._optimizer.zero_grad() + + with record_function("## wait_for_batch ##"): + _wait_for_batch(cur_batch, self._memcpy_stream) + + with record_function("## forward ##"): + (losses, output) = self._model(cur_batch) + + if self._model.training: + with record_function("## backward ##"): + torch.sum(losses, dim=0).backward() + + # Copy the next batch to GPU + self._cur_batch = cur_batch = next_batch + with record_function("## copy_batch_to_gpu ##"): + with torch.cuda.stream(self._memcpy_stream): + self._cur_batch = _to_device(cur_batch, self._device, non_blocking=True) + + # Update + if self._model.training: + with record_function("## optimizer ##"): + self._optimizer.step() + + return output + + +class Tracer(torch.fx.Tracer): + def __init__(self, unsharded_module_names: List[str]) -> None: + super().__init__() + self._unsharded_module_names = unsharded_module_names + + def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool: + if ( + isinstance(m, ShardedModule) + or module_qualified_name in self._unsharded_module_names + ): + return True + return super().is_leaf_module(m, module_qualified_name) + + +@dataclass +class TrainPipelineContext: + # pyre-ignore [4] + input_dist_requests: Dict[str, Awaitable[Any]] = field(default_factory=dict) + module_contexts: Dict[str, ShardedModuleContext] = field(default_factory=dict) + + +@dataclass +class ArgInfo: + # attributes of input batch, e.g. batch.attr1.attr2 call + # will produce ["attr1", "attr2"] + input_attrs: List[str] + # name for kwarg of pipelined forward() call or None + # for a positional arg + name: Optional[str] + + +class PipelinedForward(Generic[DistIn, DistOut, Out]): + def __init__( + self, + name: str, + args: List[ArgInfo], + module: ShardedModule[DistIn, DistOut, Out], + context: TrainPipelineContext, + dist_stream: Optional[torch.cuda.streams.Stream], + ) -> None: + self._name = name + self._args = args + self._module = module + self._context = context + self._dist_stream = dist_stream + + # pyre-ignore [2] + def __call__(self, *input, **kwargs) -> Awaitable[Out]: + assert self._name in self._context.input_dist_requests + request = self._context.input_dist_requests[self._name] + assert isinstance(request, Awaitable) + with record_function("## wait_sparse_data_dist ##"): + # Finish waiting on the dist_stream, + # in case some delayed stream scheduling happens during the wait() call. + with torch.cuda.stream(self._dist_stream): + data = request.wait() + + # Make sure that both result of input_dist and context + # are properly transferred to the current stream. + if self._dist_stream is not None: + torch.cuda.current_stream().wait_stream(self._dist_stream) + cur_stream = torch.cuda.current_stream() + + assert isinstance( + data, (torch.Tensor, Multistreamable) + ), f"{type(data)} must implement Multistreamable interface" + data.record_stream(cur_stream) + + ctx = self._context.module_contexts[self._name] + ctx.record_stream(cur_stream) + + return self._module.compute_and_output_dist( + self._context.module_contexts[self._name], data + ) + + @property + def name(self) -> str: + return self._name + + @property + def args(self) -> List[ArgInfo]: + return self._args + + +def _start_data_dist( + pipelined_modules: List[ShardedModule], + batch: In, + context: TrainPipelineContext, +) -> None: + context.input_dist_requests.clear() + context.module_contexts.clear() + for module in pipelined_modules: + forward = module.forward + assert isinstance(forward, PipelinedForward) + + # Retrieve argument. + args = [] + kwargs = {} + for arg_info in forward.args: + if arg_info.input_attrs: + arg = batch + for attr in arg_info.input_attrs: + arg = getattr(arg, attr) + if arg_info.name: + kwargs[arg_info.name] = arg + else: + args.append(arg) + else: + args.append(None) + + # Start input distribution. + module_ctx = module.create_context() + context.module_contexts[forward.name] = module_ctx + context.input_dist_requests[forward.name] = module.input_dist( + module_ctx, *args, **kwargs + ) + + +# pyre-ignore +def _get_node_args_helper(arguments, num_found: int) -> Tuple[List[ArgInfo], int]: + """A helper funtion that goes through the args/kwargs of a + Node and arranges them into a list of ArgInfos. It also counts the number + of (args + kwargs) found. + """ + arg_info_list = [ArgInfo([], None) for _ in range(len(arguments))] + for arg, arg_info in zip(arguments, arg_info_list): + if arg is None: + num_found += 1 + continue + while True: + if not isinstance(arg, torch.fx.Node): + break + child_node = arg + + if child_node.op == "placeholder": + num_found += 1 + break + elif ( + child_node.op == "call_function" + and child_node.target.__module__ == "builtins" + # pyre-ignore [16] + and child_node.target.__name__ == "getattr" + ): + arg_info.input_attrs.insert(0, child_node.args[1]) + arg = child_node.args[0] + else: + break + return arg_info_list, num_found + + +def _get_node_args(node: Node) -> Tuple[List[ArgInfo], int]: + num_found = 0 + pos_arg_info_list, num_found = _get_node_args_helper(node.args, num_found) + kwargs_arg_info_list, num_found = _get_node_args_helper( + node.kwargs.values(), num_found + ) + + # Replace with proper names for kwargs + for name, arg_info_list in zip(node.kwargs, kwargs_arg_info_list): + arg_info_list.name = name + + arg_info_list = pos_arg_info_list + kwargs_arg_info_list + return arg_info_list, num_found + + +def _get_unsharded_module_names_helper( + model: torch.nn.Module, + path: str, + unsharded_module_names: Set[str], +) -> bool: + sharded_children = set() + for name, child in model.named_children(): + curr_path = path + name + if isinstance(child, ShardedModule): + sharded_children.add(name) + else: + child_sharded = _get_unsharded_module_names_helper( + child, + curr_path + ".", + unsharded_module_names, + ) + if child_sharded: + sharded_children.add(name) + + if len(sharded_children) > 0: + for name, _ in model.named_children(): + if name not in sharded_children: + unsharded_module_names.add(path + name) + + return len(sharded_children) > 0 + + +def _get_unsharded_module_names(model: torch.nn.Module) -> List[str]: + """ + Returns a list of top level modules do not contain any sharded sub modules. + """ + unsharded_module_names: Set[str] = set() + _get_unsharded_module_names_helper( + model, + "", + unsharded_module_names, + ) + return list(unsharded_module_names) + + +def _rewrite_model( # noqa C901 + model: torch.nn.Module, + context: TrainPipelineContext, + dist_stream: Optional[torch.cuda.streams.Stream], +) -> List[ShardedModule]: + + # Get underlying nn.Module + while isinstance(model, DistributedModelParallel) or isinstance( + model, DistributedDataParallel + ): + model = model.module + + # Collect a list of sharded modules. + sharded_modules = {} + for name, m in model.named_modules(): + if isinstance(m, ShardedModule): + sharded_modules[name] = m + + # Trace a model. + tracer = Tracer(_get_unsharded_module_names(model)) + graph = tracer.trace(model) + + # Select sharded modules, which are top-level in the forward call graph, + # i.e. which don't have input transformations, i.e. + # rely only on 'builtins.getattr'. + ret = [] + for node in graph.nodes: + if node.op == "call_module" and node.target in sharded_modules: + total_num_args = len(node.args) + len(node.kwargs) + if total_num_args == 0: + continue + arg_info_list, num_found = _get_node_args(node) + if num_found == total_num_args: + logger.info(f"Module '{node.target}'' will be pipelined") + child = sharded_modules[node.target] + child.forward = PipelinedForward( + node.target, arg_info_list, child, context, dist_stream + ) + ret.append(child) + return ret + + +class TrainPipelineSparseDist(TrainPipeline[In, Out]): + """ + This pipeline overlaps device transfer, and ShardedModule.input_dist() with + forward and backward. This helps hiding all2all latency while preserving + the training forward / backward ordering. + stage 3: forward, backward + stage 2: ShardedModule.input_dist() + stage 1: device transfer + + ShardedModule.input_dist() is only done for top-level modules in the call graph. + To be considered a top-level module, + a module can only depend on 'getattr' calls on input. + + Input model must be symbolically traceable + with the exception of ShardedModule and DistributedDataParallel modules. + """ + + def __init__( + self, + model: torch.nn.Module, + optimizer: torch.optim.Optimizer, + device: torch.device, + ) -> None: + self._model = model + self._optimizer = optimizer + self._device = device + # use two data streams to support two concurrent batches + if device.type == "cuda": + self._memcpy_stream: Optional[ + torch.cuda.streams.Stream + ] = torch.cuda.Stream() + self._data_stream: Optional[torch.cuda.streams.Stream] = torch.cuda.Stream() + else: + self._memcpy_stream: Optional[torch.cuda.streams.Stream] = None + self._data_stream: Optional[torch.cuda.streams.Stream] = None + self._batch_i: Optional[In] = None + self._batch_ip1: Optional[In] = None + self._batch_ip2: Optional[In] = None + self._connected = False + self._context = TrainPipelineContext() + self._pipelined_modules: List[ShardedModule] = [] + + def _connect(self, dataloader_iter: Iterator[In]) -> None: + # batch 1 + with torch.cuda.stream(self._memcpy_stream): + batch_i = next(dataloader_iter) + self._batch_i = batch_i = _to_device( + batch_i, self._device, non_blocking=True + ) + with torch.no_grad(): + # Init lazy modules if any. + model = self._model + model(self._batch_i) + + if isinstance(model, DistributedModelParallel): + model.init_data_parallel() + + # Try to pipeline input data dist. + self._pipelined_modules = _rewrite_model( + model, self._context, self._data_stream + ) + + with torch.cuda.stream(self._data_stream): + _wait_for_batch(batch_i, self._memcpy_stream) + _start_data_dist(self._pipelined_modules, batch_i, self._context) + + # batch 2 + with torch.cuda.stream(self._memcpy_stream): + batch_ip1 = next(dataloader_iter) + self._batch_ip1 = batch_ip1 = _to_device( + batch_ip1, self._device, non_blocking=True + ) + self._connected = True + + def progress(self, dataloader_iter: Iterator[In]) -> Out: + if not self._connected: + self._connect(dataloader_iter) + + if self._model.training: + with record_function("## zero_grad ##"): + self._optimizer.zero_grad() + + with record_function("## copy_batch_to_gpu ##"): + with torch.cuda.stream(self._memcpy_stream): + batch_ip2 = next(dataloader_iter) + self._batch_ip2 = batch_ip2 = _to_device( + batch_ip2, self._device, non_blocking=True + ) + batch_i = cast(In, self._batch_i) + batch_ip1 = cast(In, self._batch_ip1) + + with record_function("## wait_for_batch ##"): + _wait_for_batch(batch_i, self._data_stream) + + # Forward + with record_function("## forward ##"): + (losses, output) = cast(Tuple[torch.Tensor, Out], self._model(batch_i)) + + with record_function("## sparse_data_dist ##"): + with torch.cuda.stream(self._data_stream): + _wait_for_batch(batch_ip1, self._memcpy_stream) + _start_data_dist(self._pipelined_modules, batch_ip1, self._context) + + if self._model.training: + # Backward + with record_function("## backward ##"): + torch.sum(losses, dim=0).backward() + + # Update + with record_function("## optimizer ##"): + self._optimizer.step() + + self._batch_i = batch_ip1 + self._batch_ip1 = batch_ip2 + + return output diff --git a/torchrec/distributed/tw_sharding.py b/torchrec/distributed/tw_sharding.py new file mode 100644 index 000000000..4f1e0cb41 --- /dev/null +++ b/torchrec/distributed/tw_sharding.py @@ -0,0 +1,314 @@ +#!/usr/bin/env python3 + +from typing import List, Optional, Any, Dict, Tuple + +import torch +import torch.distributed as dist +from torch.distributed._sharding_spec import ShardMetadata +from torchrec.distributed.dist_data import ( + PooledEmbeddingsAllToAll, + SequenceEmbeddingAllToAll, +) +from torchrec.distributed.embedding_lookup import ( + GroupedPooledEmbeddingsLookup, + GroupedEmbeddingsLookup, +) +from torchrec.distributed.embedding_sharding import ( + EmbeddingSharding, + SparseFeaturesAllToAll, + group_tables, + BasePooledEmbeddingDist, + BaseSequenceEmbeddingDist, + BaseSparseFeaturesDist, + SequenceShardingContext, + BaseEmbeddingLookup, +) +from torchrec.distributed.embedding_types import ( + GroupedEmbeddingConfig, + SparseFeatures, + ShardedEmbeddingTable, + EmbeddingComputeKernel, + BaseGroupedFeatureProcessor, +) +from torchrec.distributed.types import ( + Awaitable, + ParameterSharding, +) +from torchrec.modules.embedding_configs import EmbeddingTableConfig + + +class TwSparseFeaturesDist(BaseSparseFeaturesDist): + def __init__( + self, + pg: dist.ProcessGroup, + id_list_features_per_rank: List[int], + id_score_list_features_per_rank: List[int], + device: Optional[torch.device] = None, + ) -> None: + super().__init__() + self._dist = SparseFeaturesAllToAll( + pg, + id_list_features_per_rank, + id_score_list_features_per_rank, + device, + ) + + def forward( + self, + sparse_features: SparseFeatures, + ) -> Awaitable[SparseFeatures]: + return self._dist(sparse_features) + + +class TwPooledEmbeddingDist(BasePooledEmbeddingDist): + def __init__( + self, + pg: dist.ProcessGroup, + dim_sum_per_rank: List[int], + device: Optional[torch.device] = None, + ) -> None: + super().__init__() + self._dist = PooledEmbeddingsAllToAll( + pg, + dim_sum_per_rank, + device, + ) + + def forward(self, local_embs: torch.Tensor) -> Awaitable[torch.Tensor]: + return self._dist(local_embs) + + +class TwSequenceEmbeddingDist(BaseSequenceEmbeddingDist): + def __init__( + self, + pg: dist.ProcessGroup, + features_per_rank: List[int], + device: Optional[torch.device] = None, + ) -> None: + super().__init__() + self._dist = SequenceEmbeddingAllToAll(pg, features_per_rank, device) + + def forward( + self, sharding_ctx: SequenceShardingContext, local_embs: torch.Tensor + ) -> Awaitable[torch.Tensor]: + return self._dist( + local_embs=local_embs, + lengths=sharding_ctx.lengths_after_input_dist, + input_splits=sharding_ctx.input_splits, + output_splits=sharding_ctx.output_splits, + unbucketize_permute_tensor=None, + ) + + +class TwEmbeddingSharding(EmbeddingSharding): + """ + Shards embedding bags table-wise, i.e.. a given embedding table is entirely placed on a selected rank. + """ + + def __init__( + self, + embedding_configs: List[ + Tuple[EmbeddingTableConfig, ParameterSharding, torch.Tensor] + ], + pg: dist.ProcessGroup, + device: Optional[torch.device] = None, + is_sequence: bool = False, + ) -> None: + super().__init__() + self._pg = pg + self._device = device + self._is_sequence = is_sequence + sharded_tables_per_rank = self._shard(embedding_configs) + self._grouped_embedding_configs_per_rank: List[ + List[GroupedEmbeddingConfig] + ] = [] + self._score_grouped_embedding_configs_per_rank: List[ + List[GroupedEmbeddingConfig] + ] = [] + ( + self._grouped_embedding_configs_per_rank, + self._score_grouped_embedding_configs_per_rank, + ) = group_tables(sharded_tables_per_rank) + self._grouped_embedding_configs: List[ + GroupedEmbeddingConfig + ] = self._grouped_embedding_configs_per_rank[dist.get_rank(pg)] + self._score_grouped_embedding_configs: List[ + GroupedEmbeddingConfig + ] = self._score_grouped_embedding_configs_per_rank[dist.get_rank(pg)] + + def _shard( + self, + embedding_configs: List[ + Tuple[EmbeddingTableConfig, ParameterSharding, torch.Tensor] + ], + ) -> List[List[ShardedEmbeddingTable]]: + world_size = self._pg.size() + tables_per_rank: List[List[ShardedEmbeddingTable]] = [ + [] for i in range(world_size) + ] + for config in embedding_configs: + # pyre-fixme [16] + shards = config[1].sharding_spec.shards + + for rank in range(world_size): + table = ShardedEmbeddingTable( + num_embeddings=config[0].num_embeddings, + embedding_dim=config[0].embedding_dim, + name=config[0].name, + embedding_names=[], + data_type=config[0].data_type, + feature_names=[], + pooling=config[0].pooling, + is_weighted=config[0].is_weighted, + has_feature_processor=config[0].has_feature_processor, + compute_kernel=EmbeddingComputeKernel(config[1].compute_kernel), + ) + + # pyre-fixme [16] + if rank == config[1].ranks[0]: + table.embedding_names = config[0].embedding_names + table.feature_names = config[0].feature_names + table.local_rows = config[0].num_embeddings + table.local_cols = config[0].embedding_dim + table.local_metadata = shards[0] + table.weight_init_min = config[0].weight_init_min + table.weight_init_max = config[0].weight_init_max + + tables_per_rank[rank].append(table) + return tables_per_rank + + def create_input_dist(self) -> BaseSparseFeaturesDist: + return TwSparseFeaturesDist( + self._pg, + self._id_list_features_per_rank(), + self._id_score_list_features_per_rank(), + self._device, + ) + + def create_lookup( + self, + fused_params: Optional[Dict[str, Any]], + feature_processor: Optional[BaseGroupedFeatureProcessor] = None, + ) -> BaseEmbeddingLookup: + if self._is_sequence: + return GroupedEmbeddingsLookup( + grouped_configs=self._grouped_embedding_configs, + fused_params=fused_params, + pg=self._pg, + device=self._device, + ) + else: + return GroupedPooledEmbeddingsLookup( + grouped_configs=self._grouped_embedding_configs, + grouped_score_configs=self._score_grouped_embedding_configs, + fused_params=fused_params, + pg=self._pg, + device=self._device, + feature_processor=feature_processor, + ) + + def create_pooled_output_dist(self) -> TwPooledEmbeddingDist: + return TwPooledEmbeddingDist( + self._pg, + self._dim_sum_per_rank(), + self._device, + ) + + def create_sequence_output_dist( + self, + ) -> BaseSequenceEmbeddingDist: + return TwSequenceEmbeddingDist( + self._pg, + self._id_list_features_per_rank(), + self._device, + ) + + def _dim_sum_per_rank(self) -> List[int]: + dim_sum_per_rank = [] + for grouped_embedding_configs, score_grouped_embedding_configs in zip( + self._grouped_embedding_configs_per_rank, + self._score_grouped_embedding_configs_per_rank, + ): + dim_sum = 0 + for grouped_config in grouped_embedding_configs: + dim_sum += grouped_config.dim_sum() + for grouped_config in score_grouped_embedding_configs: + dim_sum += grouped_config.dim_sum() + dim_sum_per_rank.append(dim_sum) + return dim_sum_per_rank + + def embedding_dims(self) -> List[int]: + embedding_dims = [] + for grouped_embedding_configs, score_grouped_embedding_configs in zip( + self._grouped_embedding_configs_per_rank, + self._score_grouped_embedding_configs_per_rank, + ): + for grouped_config in grouped_embedding_configs: + embedding_dims.extend(grouped_config.embedding_dims()) + for grouped_config in score_grouped_embedding_configs: + embedding_dims.extend(grouped_config.embedding_dims()) + return embedding_dims + + def embedding_names(self) -> List[str]: + embedding_names = [] + for grouped_embedding_configs, score_grouped_embedding_configs in zip( + self._grouped_embedding_configs_per_rank, + self._score_grouped_embedding_configs_per_rank, + ): + for grouped_config in grouped_embedding_configs: + embedding_names.extend(grouped_config.embedding_names()) + for grouped_config in score_grouped_embedding_configs: + embedding_names.extend(grouped_config.embedding_names()) + return embedding_names + + def embedding_shard_metadata(self) -> List[Optional[ShardMetadata]]: + embedding_shard_metadata = [] + for grouped_embedding_configs, score_grouped_embedding_configs in zip( + self._grouped_embedding_configs_per_rank, + self._score_grouped_embedding_configs_per_rank, + ): + for grouped_config in grouped_embedding_configs: + embedding_shard_metadata.extend( + grouped_config.embedding_shard_metadata() + ) + for grouped_config in score_grouped_embedding_configs: + embedding_shard_metadata.extend( + grouped_config.embedding_shard_metadata() + ) + return embedding_shard_metadata + + def id_list_feature_names(self) -> List[str]: + id_list_feature_names = [] + for grouped_embedding_configs in self._grouped_embedding_configs_per_rank: + for grouped_config in grouped_embedding_configs: + id_list_feature_names.extend(grouped_config.feature_names()) + return id_list_feature_names + + def id_score_list_feature_names(self) -> List[str]: + id_score_list_feature_names = [] + for ( + score_grouped_embedding_configs + ) in self._score_grouped_embedding_configs_per_rank: + for grouped_config in score_grouped_embedding_configs: + id_score_list_feature_names.extend(grouped_config.feature_names()) + return id_score_list_feature_names + + def _id_list_features_per_rank(self) -> List[int]: + id_list_features_per_rank = [] + for grouped_embedding_configs in self._grouped_embedding_configs_per_rank: + num_features = 0 + for grouped_config in grouped_embedding_configs: + num_features += grouped_config.num_features() + id_list_features_per_rank.append(num_features) + return id_list_features_per_rank + + def _id_score_list_features_per_rank(self) -> List[int]: + id_score_list_features_per_rank = [] + for ( + score_grouped_embedding_configs + ) in self._score_grouped_embedding_configs_per_rank: + num_features = 0 + for grouped_config in score_grouped_embedding_configs: + num_features += grouped_config.num_features() + id_score_list_features_per_rank.append(num_features) + return id_score_list_features_per_rank diff --git a/torchrec/distributed/twrw_sharding.py b/torchrec/distributed/twrw_sharding.py new file mode 100644 index 000000000..afce350d8 --- /dev/null +++ b/torchrec/distributed/twrw_sharding.py @@ -0,0 +1,452 @@ +#!/usr/bin/env python3 + +import itertools +import math +from typing import List, Optional, Dict, Any, Tuple, cast + +import torch +import torch.distributed as dist +from torch.distributed._sharding_spec import ShardMetadata +from torchrec.distributed.comm import ( + intra_and_cross_node_pg, + get_local_size, +) +from torchrec.distributed.dist_data import ( + PooledEmbeddingsReduceScatter, + PooledEmbeddingsAllToAll, +) +from torchrec.distributed.embedding_lookup import GroupedPooledEmbeddingsLookup +from torchrec.distributed.embedding_sharding import ( + group_tables, + SparseFeaturesAllToAll, + BasePooledEmbeddingDist, + BaseSequenceEmbeddingDist, + BaseSparseFeaturesDist, + EmbeddingSharding, + BaseEmbeddingLookup, + bucketize_kjt_before_all2all, +) +from torchrec.distributed.embedding_types import ( + GroupedEmbeddingConfig, + SparseFeatures, + ShardedEmbeddingTable, + EmbeddingComputeKernel, + BaseGroupedFeatureProcessor, +) +from torchrec.distributed.types import ( + Awaitable, + ParameterSharding, +) +from torchrec.modules.embedding_configs import EmbeddingTableConfig + + +class TwRwSparseFeaturesDist(BaseSparseFeaturesDist): + def __init__( + self, + pg: dist.ProcessGroup, + intra_pg: dist.ProcessGroup, + num_id_list_features: int, + num_id_score_list_features: int, + id_list_features_per_rank: List[int], + id_score_list_features_per_rank: List[int], + id_list_feature_hash_sizes: List[int], + id_score_list_feature_hash_sizes: List[int], + device: Optional[torch.device] = None, + has_feature_processor: bool = False, + ) -> None: + super().__init__() + assert ( + pg.size() % intra_pg.size() == 0 + ), "currently group granularity must be node" + + self._world_size: int = pg.size() + self._local_size: int = intra_pg.size() + self._num_cross_nodes: int = self._world_size // self._local_size + id_list_feature_block_sizes = [ + math.ceil(hash_size / self._local_size) + for hash_size in id_list_feature_hash_sizes + ] + id_score_list_feature_block_sizes = [ + math.ceil(hash_size / self._local_size) + for hash_size in id_score_list_feature_hash_sizes + ] + + self._id_list_sf_staggered_shuffle: List[int] = self._staggered_shuffle( + id_list_features_per_rank + ) + self._id_score_list_sf_staggered_shuffle: List[int] = self._staggered_shuffle( + id_score_list_features_per_rank + ) + self.register_buffer( + "_id_list_feature_block_sizes_tensor", + torch.tensor( + id_list_feature_block_sizes, + device=device, + dtype=torch.int32, + ), + ) + self.register_buffer( + "_id_score_list_feature_block_sizes_tensor", + torch.tensor( + id_score_list_feature_block_sizes, + device=device, + dtype=torch.int32, + ), + ) + self.register_buffer( + "_id_list_sf_staggerd_shuffle_tensor", + torch.tensor( + self._id_list_sf_staggered_shuffle, + device=device, + dtype=torch.int32, + ), + ) + self.register_buffer( + "_id_score_list_sf_staggered_shuffle_tensor", + torch.tensor( + self._id_score_list_sf_staggered_shuffle, + device=device, + dtype=torch.int32, + ), + ) + self._dist = SparseFeaturesAllToAll( + pg, + id_list_features_per_rank, + id_score_list_features_per_rank, + device, + self._num_cross_nodes, + ) + self._has_feature_processor = has_feature_processor + + def forward( + self, + sparse_features: SparseFeatures, + ) -> Awaitable[SparseFeatures]: + bucketized_sparse_features = SparseFeatures( + id_list_features=bucketize_kjt_before_all2all( + sparse_features.id_list_features, + num_buckets=self._local_size, + block_sizes=self._id_list_feature_block_sizes_tensor, + output_permute=False, + bucketize_pos=self._has_feature_processor, + )[0].permute( + self._id_list_sf_staggered_shuffle, + self._id_list_sf_staggerd_shuffle_tensor, + ) + if sparse_features.id_list_features is not None + else None, + id_score_list_features=bucketize_kjt_before_all2all( + sparse_features.id_score_list_features, + num_buckets=self._local_size, + block_sizes=self._id_score_list_feature_block_sizes_tensor, + output_permute=False, + bucketize_pos=False, + )[0].permute( + self._id_score_list_sf_staggered_shuffle, + self._id_score_list_sf_staggered_shuffle_tensor, + ) + if sparse_features.id_score_list_features is not None + else None, + ) + return self._dist(bucketized_sparse_features) + + def _staggered_shuffle(self, features_per_rank: List[int]) -> List[int]: + nodes = self._world_size // self._local_size + features_per_node = [ + features_per_rank[node * self._local_size] for node in range(nodes) + ] + node_offsets = [0] + list(itertools.accumulate(features_per_node)) + num_features = node_offsets[-1] + + return [ + bucket * num_features + feature + for node in range(nodes) + for bucket in range(self._local_size) + for feature in range(node_offsets[node], node_offsets[node + 1]) + ] + + +class TwRwEmbeddingDist(BasePooledEmbeddingDist): + def __init__( + self, + cross_pg: dist.ProcessGroup, + intra_pg: dist.ProcessGroup, + dim_sum_per_node: List[int], + device: Optional[torch.device] = None, + ) -> None: + super().__init__() + self._intra_dist = PooledEmbeddingsReduceScatter(intra_pg) + self._cross_dist = PooledEmbeddingsAllToAll( + cross_pg, + dim_sum_per_node, + device, + ) + + def forward(self, local_embs: torch.Tensor) -> Awaitable[torch.Tensor]: + return self._cross_dist(self._intra_dist(local_embs).wait()) + + +class TwRwEmbeddingSharding(EmbeddingSharding): + """ + Shards embedding bags table-wise then row-wise + """ + + def __init__( + self, + embedding_configs: List[ + Tuple[EmbeddingTableConfig, ParameterSharding, torch.Tensor] + ], + pg: dist.ProcessGroup, + device: Optional[torch.device] = None, + is_sequence: bool = False, + ) -> None: + super().__init__() + if is_sequence: + raise RuntimeError( + "TABLE_ROW_WISE sharding does not support sequence embeddings." + ) + self._pg = pg + self._world_size: int = self._pg.size() + self._my_rank: int = self._pg.rank() + self._device = device + self._is_sequence = is_sequence + intra_pg, cross_pg = intra_and_cross_node_pg(device) + self._intra_pg: Optional[dist.ProcessGroup] = intra_pg + self._cross_pg: Optional[dist.ProcessGroup] = cross_pg + self._local_size: int = ( + intra_pg.size() if intra_pg else get_local_size(self._world_size) + ) + + sharded_tables_per_rank = self._shard(embedding_configs) + self._grouped_embedding_configs_per_rank: List[ + List[GroupedEmbeddingConfig] + ] = [] + self._score_grouped_embedding_configs_per_rank: List[ + List[GroupedEmbeddingConfig] + ] = [] + self._grouped_embedding_configs_per_node: List[ + List[GroupedEmbeddingConfig] + ] = [] + self._score_grouped_embedding_configs_per_node: List[ + List[GroupedEmbeddingConfig] + ] = [] + ( + self._grouped_embedding_configs_per_rank, + self._score_grouped_embedding_configs_per_rank, + ) = group_tables(sharded_tables_per_rank) + self._grouped_embedding_configs_per_node = [ + self._grouped_embedding_configs_per_rank[rank] + for rank in range(self._world_size) + if rank % self._local_size == 0 + ] + self._score_grouped_embedding_configs_per_node = [ + self._score_grouped_embedding_configs_per_rank[rank] + for rank in range(self._world_size) + if rank % self._local_size == 0 + ] + self._has_feature_processor: bool = False + for group_config in self._score_grouped_embedding_configs_per_node[ + self._my_rank // self._local_size + ]: + if group_config.has_feature_processor: + self._has_feature_processor = True + + def _shard( + self, + embedding_configs: List[ + Tuple[EmbeddingTableConfig, ParameterSharding, torch.Tensor] + ], + ) -> List[List[ShardedEmbeddingTable]]: + world_size = self._world_size + local_size = self._local_size + tables_per_rank: List[List[ShardedEmbeddingTable]] = [ + [] for i in range(world_size) + ] + for config in embedding_configs: + # pyre-ignore [16] + table_node = config[1].ranks[0] // local_size + # pyre-fixme [16] + shards = config[1].sharding_spec.shards + + for rank in range(world_size): + table = ShardedEmbeddingTable( + num_embeddings=config[0].num_embeddings, + embedding_dim=config[0].embedding_dim, + name=config[0].name, + embedding_names=[], + data_type=config[0].data_type, + feature_names=[], + pooling=config[0].pooling, + is_weighted=config[0].is_weighted, + block_size=config[1].block_size, + has_feature_processor=config[0].has_feature_processor, + compute_kernel=EmbeddingComputeKernel(config[1].compute_kernel), + ) + if ( + rank >= table_node * local_size + and rank < (table_node + 1) * local_size + ): + shard_idx = rank - (table_node * local_size) + table.embedding_names = config[0].embedding_names + table.feature_names = config[0].feature_names + table.local_rows = shards[shard_idx].shard_lengths[0] + table.local_cols = config[0].embedding_dim + table.local_metadata = shards[shard_idx] + table.weight_init_max = config[0].weight_init_max + table.weight_init_min = config[0].weight_init_min + + tables_per_rank[rank].append(table) + + return tables_per_rank + + def create_input_dist(self) -> BaseSparseFeaturesDist: + num_id_list_features = self._get_id_list_features_num() + num_id_score_list_features = self._get_id_score_list_features_num() + id_list_features_per_rank = self._features_per_rank( + self._grouped_embedding_configs_per_rank + ) + id_score_list_features_per_rank = self._features_per_rank( + self._score_grouped_embedding_configs_per_rank + ) + id_list_feature_hash_sizes = self._get_id_list_features_hash_sizes() + id_score_list_feature_hash_sizes = self._get_id_score_list_features_hash_sizes() + return TwRwSparseFeaturesDist( + pg=self._pg, + intra_pg=cast(dist.ProcessGroup, self._intra_pg), + num_id_list_features=num_id_list_features, + num_id_score_list_features=num_id_score_list_features, + id_list_features_per_rank=id_list_features_per_rank, + id_score_list_features_per_rank=id_score_list_features_per_rank, + id_list_feature_hash_sizes=id_list_feature_hash_sizes, + id_score_list_feature_hash_sizes=id_score_list_feature_hash_sizes, + device=self._device, + has_feature_processor=self._has_feature_processor, + ) + + def create_lookup( + self, + fused_params: Optional[Dict[str, Any]], + feature_processor: Optional[BaseGroupedFeatureProcessor] = None, + ) -> BaseEmbeddingLookup: + return GroupedPooledEmbeddingsLookup( + grouped_configs=self._grouped_embedding_configs_per_rank[self._my_rank], + grouped_score_configs=self._score_grouped_embedding_configs_per_rank[ + self._my_rank + ], + fused_params=fused_params, + pg=self._pg, + device=self._device, + feature_processor=feature_processor, + ) + + def create_pooled_output_dist(self) -> BasePooledEmbeddingDist: + return TwRwEmbeddingDist( + cross_pg=cast(dist.ProcessGroup, self._cross_pg), + intra_pg=cast(dist.ProcessGroup, self._intra_pg), + dim_sum_per_node=self._dim_sum_per_node(), + device=self._device, + ) + + def create_sequence_output_dist(self) -> BaseSequenceEmbeddingDist: + raise NotImplementedError + + def embedding_dims(self) -> List[int]: + embedding_dims = [] + for grouped_embedding_configs, score_grouped_embedding_configs in zip( + self._grouped_embedding_configs_per_node, + self._score_grouped_embedding_configs_per_node, + ): + for grouped_config in grouped_embedding_configs: + embedding_dims.extend(grouped_config.embedding_dims()) + for grouped_config in score_grouped_embedding_configs: + embedding_dims.extend(grouped_config.embedding_dims()) + return embedding_dims + + def embedding_names(self) -> List[str]: + embedding_names = [] + for grouped_embedding_configs, score_grouped_embedding_configs in zip( + self._grouped_embedding_configs_per_node, + self._score_grouped_embedding_configs_per_node, + ): + for grouped_config in grouped_embedding_configs: + embedding_names.extend(grouped_config.embedding_names()) + for grouped_config in score_grouped_embedding_configs: + embedding_names.extend(grouped_config.embedding_names()) + return embedding_names + + def embedding_shard_metadata(self) -> List[Optional[ShardMetadata]]: + embedding_shard_metadata = [] + for grouped_config in self._grouped_embedding_configs_per_node: + for config in grouped_config: + embedding_shard_metadata.extend(config.embedding_shard_metadata()) + for grouped_config in self._score_grouped_embedding_configs_per_node: + for config in grouped_config: + embedding_shard_metadata.extend(config.embedding_shard_metadata()) + return embedding_shard_metadata + + def id_list_feature_names(self) -> List[str]: + id_list_feature_names = [] + for grouped_config in self._grouped_embedding_configs_per_node: + for config in grouped_config: + id_list_feature_names.extend(config.feature_names()) + return id_list_feature_names + + def id_score_list_feature_names(self) -> List[str]: + id_score_list_feature_names = [] + for grouped_config in self._score_grouped_embedding_configs_per_node: + for config in grouped_config: + id_score_list_feature_names.extend(config.feature_names()) + return id_score_list_feature_names + + def _get_id_list_features_num(self) -> int: + id_list_features_num: int = 0 + for grouped_config in self._grouped_embedding_configs_per_node: + for config in grouped_config: + id_list_features_num += config.num_features() + return id_list_features_num + + def _get_id_score_list_features_num(self) -> int: + id_score_list_features_num: int = 0 + for grouped_config in self._score_grouped_embedding_configs_per_node: + for config in grouped_config: + id_score_list_features_num += config.num_features() + return id_score_list_features_num + + def _get_id_list_features_hash_sizes(self) -> List[int]: + id_list_feature_hash_sizes: List[int] = [] + for grouped_config in self._grouped_embedding_configs_per_node: + for config in grouped_config: + id_list_feature_hash_sizes.extend(config.feature_hash_sizes()) + return id_list_feature_hash_sizes + + def _get_id_score_list_features_hash_sizes(self) -> List[int]: + id_score_list_feature_hash_sizes: List[int] = [] + for grouped_config in self._score_grouped_embedding_configs_per_node: + for config in grouped_config: + id_score_list_feature_hash_sizes.extend(config.feature_hash_sizes()) + return id_score_list_feature_hash_sizes + + def _dim_sum_per_node(self) -> List[int]: + dim_sum_per_rank = [] + for grouped_embedding_configs, score_grouped_embedding_configs in zip( + self._grouped_embedding_configs_per_node, + self._score_grouped_embedding_configs_per_node, + ): + dim_sum = 0 + for grouped_config in grouped_embedding_configs: + dim_sum += grouped_config.dim_sum() + for grouped_config in score_grouped_embedding_configs: + dim_sum += grouped_config.dim_sum() + dim_sum_per_rank.append(dim_sum) + return dim_sum_per_rank + + def _features_per_rank( + self, group: List[List[GroupedEmbeddingConfig]] + ) -> List[int]: + features_per_rank = [] + for grouped_embedding_configs in group: + num_features = 0 + for grouped_config in grouped_embedding_configs: + num_features += grouped_config.num_features() + features_per_rank.append(num_features) + return features_per_rank diff --git a/torchrec/distributed/types.py b/torchrec/distributed/types.py new file mode 100644 index 000000000..254fea016 --- /dev/null +++ b/torchrec/distributed/types.py @@ -0,0 +1,542 @@ +#!/usr/bin/env python3 + +import abc +import operator +from dataclasses import dataclass +from enum import Enum, unique +from typing import ( + Any, + Dict, + Generic, + Optional, + TypeVar, + List, + Type, + Iterator, +) + +from torch.distributed._sharding_spec import ShardingSpec + +try: + # For python 3.6 and below, GenericMeta will be used by + # other metaclasses (i.e. AwaitableMeta) for customized + # behaviors, as Generic is non-trival metaclass in + # python 3.6 and below + from typing import GenericMeta # pyre-ignore: python 3.6 +except ImportError: + # In python 3.7+, GenericMeta doesn't exist as it's no + # longer a non-trival metaclass, + # see https://www.python.org/dev/peps/pep-0560/ + # So we make a fake type here in order to share the same + # code with python 3.6 or below, it will just be used as + # a placeholder for customized metaclass behaviors + # (i.e. Awaitable) + class GenericMeta(type): + pass + + +import torch +import torch.distributed as dist +import torch.fx +from torch import nn +from torch.distributed._sharded_tensor import ( # noqa + Shard, + ShardedTensor, + ShardedTensorMetadata, +) +from torch.distributed._sharding_spec import ShardMetadata # noqa +from torchrec.types import Multistreamable + + +class ShardingType(Enum): + """ + Well-known sharding types, + used by inter-module optimizations. + """ + + # Replicated on all ranks + DATA_PARALLEL = "data_parallel" + # Placed on a single rank + TABLE_WISE = "table_wise" + # Placed on multiple ranks as different sharded tables + COLUMN_WISE = "column_wise" + # Range-split on the first dimension across all ranks + ROW_WISE = "row_wise" + # Row-wise on the same node and table-wise across nodes + # Useful when having multiple ranks perf node + # and comms within a single node are more efficient than across nodes. + TABLE_ROW_WISE = "table_row_wise" + + +class ParameterStorage(Enum): + """ + Well-known physical resources, + which can be used as constraints by ShardingPlanner. + """ + + # GPU-attached memory + HBM = "hbm" + # CPU-attached memory + DDR = "ddr" + + +@unique +class ComputeKernel(Enum): + DEFAULT = "default" + + +W = TypeVar("W") +M = TypeVar("M", bound=nn.Module) +Out = TypeVar("Out") +CompIn = TypeVar("CompIn", bound=Multistreamable) +DistOut = TypeVar("DistOut") + + +class Awaitable(abc.ABC, Generic[W]): + @abc.abstractmethod + def wait(self) -> W: + pass + + +class NoWait(Awaitable[W]): + def __init__(self, obj: W) -> None: + self._obj = obj + + def wait(self) -> W: + return self._obj + + +class _LazyAwaitableMeta(GenericMeta, abc.ABCMeta, torch.fx.ProxyableClassMeta): + """ + The _LazyAwaitableMeta class that inherits both ABCMeta and ProxyableClassMeta + This is because ABCMeta/ProxyableClassMeta are both non-trival metaclasses + Declaring this separately to ensure the init order of metaclasses + + XXX: Generics are non-trival metaclass before python 3.7 but are removed + afterwards. we add GenericsMeta here to support version before 3.7. + """ + + pass + + +class LazyAwaitable(Awaitable[W], metaclass=_LazyAwaitableMeta): + """ + The LazyAwaitable type which exposes a `wait()` API, concrete types + can control how to initialize and how the `wait()` behavior should + be in order to achieve specific async operation. + + This base LazyAwaitable type is a "lazy" async type, which means it will + delay `wait()` as late as possible, see details in `__torch_function__` + below. This could help the model automatically enable computation and + communication overlap, model author doesn't need to manually call + `wait()` if the results is used by a pytorch function, or by other python + operations (NOTE: need to implement corresponding magic methods + like __getattr__ below) + + Some caveats: + * This works with Pytorch functions, but not any generic method, if + you would like to do arbitary python operations, you need to + implement the corresponding magic methods + * In the case that one function have two or more arguments are LazyAwaitable, + the lazy wait mechanism can't ensure perfect computation/communication + overlap (i.e. quickly waited the first one but long wait on the second) + """ + + def __init__( + self, + ) -> None: + super().__init__() + # _result is used to cache the results after the wait() is called. + self._result: Optional[W] = None + + @staticmethod + # pyre-ignore [2, 3] + def _wait_async(obj: Any) -> Any: + """ + This method is used internally to automatically wait when necessary + and cache the results of the `LazyAwaitable.wait()` call + """ + if isinstance(obj, LazyAwaitable): + if obj._result is None: + obj._result = obj.wait() + return obj._result + else: + return obj + + # pyre-ignore [2, 3] + def __torch_function__(self, func, types, args=(), kwargs=None): + """ + The LazyAwaitable type has a `__torch_function__` implementation. + This means when this type is seens as an argument to a PyTorch + function in a position where it expects a W, the PyTorch's + dispatcher will call into this function for special handling + + Our `__torch_function__` implementation goes through all of the + args and kwargs and checks if any of them are `LazyAwaitable`. + If it is, it will call `wait()` on it and replace the LazyAwaitable + type object with the result of wait. In this way, async values + are waited on when the concrete value is first needed and without + the user having to write an explicit `wait()` call. + """ + kwargs = kwargs or {} + + # wait() on all LazyAwaitable args/kwargs and replace + # them with the resulting value. + new_args = torch.fx.node.map_aggregate(args, LazyAwaitable._wait_async) + new_kwargs = torch.fx.node.map_aggregate(kwargs, LazyAwaitable._wait_async) + + return func(*new_args, **new_kwargs) + + # pyre-ignore [2, 3] + def __getattr__(self, name): + """ + overriding __getattr__ to allow LazyAwaitable to first wait and + then call getattr on the wait results. + """ + if name == "_result": + raise RuntimeError( + f"LazyAwaitable type {type(self)} has not been initialized properly, " + f"did you forget to call 'super()'?" + ) + + res = LazyAwaitable._wait_async(self) + return getattr(res, name) + + +class LazyNoWait(LazyAwaitable[W]): + def __init__(self, obj: W) -> None: + super().__init__() + self._obj = obj + + def wait(self) -> W: + return self._obj + + +# install magic methods +for orig_method_name in torch.fx.graph.magic_methods: + as_magic = f"__{orig_method_name}__" + + def scope(method): + def impl(*args, **kwargs): + lhs = args[0] + op_fn = getattr(operator, method) + if len(args) == 1: + return op_fn(LazyAwaitable._wait_async(lhs)) + elif len(args) == 2: + rhs = args[1] + return op_fn( + LazyAwaitable._wait_async(lhs), LazyAwaitable._wait_async(rhs) + ) + else: + raise RuntimeError(f"magic method {as_magic} not supported!") + + impl.__name__ = as_magic + setattr(LazyAwaitable, as_magic, impl) + + # pyre-ignore [16] + scope(orig_method_name) + +# install reflective magic methods +for orig_method_name in torch.fx.graph.reflectable_magic_methods: + as_magic = f"__r{orig_method_name}__" + # pyre-ignore [2, 3] + def scope(method): + # pyre-ignore [2, 3, 53] + def impl(self, rhs): + op_fn = getattr(operator, method) + return op_fn( + LazyAwaitable._wait_async(rhs), LazyAwaitable._wait_async(self) + ) + + impl.__name__ = as_magic + impl.__qualname__ = as_magic + setattr(LazyAwaitable, as_magic, impl) + + # pyre-ignore [16] + scope(orig_method_name) + + +@dataclass +class ParameterSharding: + """ + sharding_type and location are the only two dimensions + produced by ShardingPlanner. + We can add more in future, but this seems sufficient for immediate needs. + """ + + """ + How this parameter is sharded. See ShardingType for well-known types. + """ + sharding_type: str + + """ + Compute kernel to be used by this parameter. + """ + compute_kernel: str + + """ + ShardingType.TABLE_WISE - rank where this embedding is placed + ShardingType.COLUMN_WISE - rank where this embedding shards are placed, we see them as individual tables + ShardingType.TABLE_ROW_WISE - first rank when this embedding is placed + ShardingType.ROW_WISE, ShardingType.DATA_PARALLEL - unused + """ + ranks: Optional[List[int]] = None + + """ + The block size of sharding dim on each shard. + mainly used in cw, not applicable in tw/dp + """ + block_size: int = 0 + sharding_spec: Optional[ShardingSpec] = None + + +@dataclass +class ShardedModuleContext(Multistreamable): + pass + + +class EmptyContext(ShardedModuleContext): + def record_stream(self, stream: torch.cuda.streams.Stream) -> None: + pass + + # pyre-ignore [2] + def __setattr__(self, key: str, value: Any) -> None: + raise NotImplementedError() + + +class ShardingEnv: + """ + Provides abstraction over torch.distributed.ProcessGroup, + which practically enables DistributedModelParallel to be used during inference. + """ + + def __init__( + self, world_size: int, rank: int, pg: Optional[dist.ProcessGroup] = None + ) -> None: + self.world_size = world_size + self.rank = rank + self.process_group: Optional[dist.ProcessGroup] = pg + + @classmethod + def from_process_group(cls, pg: dist.ProcessGroup) -> "ShardingEnv": + """ + Creates ProcessGroup-based sharding environment. + Typically used during training. + """ + return cls(dist.get_world_size(pg), dist.get_rank(pg), pg) + + @classmethod + def from_local(cls, world_size: int, rank: int) -> "ShardingEnv": + """ + Creates a local host-based sharding environment. + Typically used during single host inference. + """ + return cls(world_size, rank, None) + + +class ShardedModule(abc.ABC, nn.Module, Generic[CompIn, DistOut, Out]): + """ + All model-parallel modules implement this interface. + Inputs and outputs are data-parallel. + 'input_dist' / 'output_dist' are responsible of transforming inputs / outputs + from data-parallel to model parallel and vise-versa. + """ + + @abc.abstractmethod + def __init__(self) -> None: + super().__init__() + torch._C._log_api_usage_once(f"torchrec.distributed.{self.__class__.__name__}") + + def create_context(self) -> ShardedModuleContext: + return EmptyContext() + + @abc.abstractmethod + def input_dist( + self, + ctx: ShardedModuleContext, + # pyre-ignore[2] + *input, + # pyre-ignore[2] + **kwargs, + ) -> Awaitable[CompIn]: + pass + + @abc.abstractmethod + def compute(self, ctx: ShardedModuleContext, dist_input: CompIn) -> DistOut: + pass + + @abc.abstractmethod + def output_dist( + self, ctx: ShardedModuleContext, output: DistOut + ) -> LazyAwaitable[Out]: + pass + + def compute_and_output_dist( + self, ctx: ShardedModuleContext, input: CompIn + ) -> LazyAwaitable[Out]: + """ + In case of multiple output distributions + it makes sense to override this method and initiate + output distibution as soon as corresponding compute completes. + """ + output = self.compute(ctx, input) + return self.output_dist(ctx, output) + + # pyre-ignore[2] + def forward(self, *input, **kwargs) -> LazyAwaitable[Out]: + ctx = self.create_context() + dist_input = self.input_dist(ctx, *input, **kwargs).wait() + return self.compute_and_output_dist(ctx, dist_input) + + def sparse_grad_parameter_names( + self, + destination: Optional[List[str]] = None, + prefix: str = "", + ) -> List[str]: + destination = [] if destination is None else destination + return destination + + def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]: + for key, _ in self.named_parameters(prefix): + yield key + + +class ModuleSharder(abc.ABC, Generic[M]): + """ + ModuleSharder is per each module, which supports sharding, + e.g. EmbeddingBagCollection. + """ + + @abc.abstractclassmethod + # pyre-ignore [3] + def shard( + self, + module: M, + params: Dict[str, ParameterSharding], + env: ShardingEnv, + device: torch.device, + ) -> ShardedModule[Any, Any, Any]: + """ + Does actual sharding. It will allocate parameters on the requested locations + as specified by corresponding ParameterSharding. + Default implementation is just data-parallel replication. + + Args: + module: module to shard + params: dict of fully qualified parameter name + (module path + parameter name, '.'-separated) to its sharding spec. + pg: process group to use + device: compute device + + Returns: + sharded module implementation + """ + ... + + @property + @abc.abstractmethod + def module_type(self) -> Type[M]: + ... + + def shardable_parameters(self, module: M) -> Dict[str, nn.Parameter]: + """ + List of parameters, which can be sharded. + """ + return dict(module.named_parameters()) + + def sharding_types(self, compute_device_type: str) -> List[str]: + """ + List of supported sharding types. See ShardingType for well-known examples. + """ + return [ShardingType.DATA_PARALLEL.value] + + def compute_kernels( + self, sharding_type: str, compute_device_type: str + ) -> List[str]: + """ + List of supported compute kernels for a given sharding_type and compute device. + """ + + return [ComputeKernel.DEFAULT.value] + + def storage_usage( + self, tensor: torch.Tensor, compute_device_type: str, compute_kernel: str + ) -> Dict[str, int]: + """ + List of system resources and corresponding usage given a compute device and + compute kernel + """ + + assert compute_device_type in {"cuda", "cpu"} + storage_map = {"cuda": ParameterStorage.HBM, "cpu": ParameterStorage.DDR} + return { + storage_map[compute_device_type].value: tensor.element_size() + * tensor.nelement() + } + + +@dataclass +class ShardingPlan: + """ + Representation of sharding plan. + Dict keyed by module path of dict of parameter sharding specs keyed by parameter name. + """ + + plan: Dict[str, Dict[str, ParameterSharding]] + + def get_plan_for_module( + self, module_path: str + ) -> Optional[Dict[str, ParameterSharding]]: + """ + Args: + module_path + + Returns: + dict of parameter sharding specs keyed by parameter name. + Returns none if sharding specs does not exist for given module_path. + """ + return self.plan.get(module_path, None) + + def __str__(self) -> str: + return str(self.plan) + + +class ShardingPlanner(abc.ABC): + """ + Plans sharding. + This plan can be saved and re-used to ensure sharding stability. + """ + + @abc.abstractmethod + def plan( + self, + module: nn.Module, + sharders: List[ModuleSharder[nn.Module]], + ) -> ShardingPlan: + """ + Args: + modules + sharders + + Returns: + Sharding plan. + """ + ... + + @abc.abstractmethod + def collective_plan( + self, + module: nn.Module, + sharders: List[ModuleSharder[nn.Module]], + ) -> ShardingPlan: + """ + Call self.plan(...) on rank 0 and broadcast + + Args: + modules + sharders + + Returns: + Sharding plan. + """ + ... diff --git a/torchrec/distributed/utils.py b/torchrec/distributed/utils.py new file mode 100644 index 000000000..5a46babbc --- /dev/null +++ b/torchrec/distributed/utils.py @@ -0,0 +1,123 @@ +#!/usr/bin/env python3 + +from collections import OrderedDict +from typing import List, Set, Union + +import torch +from torchrec.distributed.types import ShardedModule + + +def append_prefix(prefix: str, name: str) -> str: + if prefix != "" and name != "": + return prefix + "." + name + else: + return prefix + name + + +def filter_state_dict( + state_dict: "OrderedDict[str, torch.Tensor]", name: str +) -> "OrderedDict[str, torch.Tensor]": + rtn_dict = OrderedDict() + for key, value in state_dict.items(): + if key.startswith(name): + # + 1 to length is to remove the '.' after the key + rtn_dict[key[len(name) + 1 :]] = value + return rtn_dict + + +def _get_unsharded_module_names_helper( + model: torch.nn.Module, + path: str, + unsharded_module_names: Set[str], +) -> bool: + sharded_children = set() + for name, child in model.named_children(): + curr_path = path + name + if isinstance(child, ShardedModule): + sharded_children.add(name) + else: + child_sharded = _get_unsharded_module_names_helper( + child, + curr_path + ".", + unsharded_module_names, + ) + if child_sharded: + sharded_children.add(name) + + if len(sharded_children) > 0: + for name, _ in model.named_children(): + if name not in sharded_children: + unsharded_module_names.add(path + name) + + return len(sharded_children) > 0 + + +def get_unsharded_module_names(model: torch.nn.Module) -> List[str]: + """ + Returns a list of top level modules do not contain any sharded sub modules. + """ + unsharded_module_names: Set[str] = set() + _get_unsharded_module_names_helper( + model, + "", + unsharded_module_names, + ) + return list(unsharded_module_names) + + +class sharded_model_copy: + """ + Allows to copy DistributedModelParallel module to a target device. + Example coping model to CPU: + >>> m = DistributedModelParallel(m) + with sharded_model_copy("cpu"): + m_cpu = copy.deepcopy(m) + + """ + + def __init__(self, device: Union[str, int, torch.device]) -> None: + self.device = device + + def __enter__(self) -> None: + # pyre-ignore [16] + self.t_copy_save_ = torch.Tensor.__deepcopy__ + # pyre-ignore [16] + self.p_copy_save_ = torch.nn.Parameter.__deepcopy__ + + device = self.device + + # pyre-ignore [2, 3, 53] + def _tensor_copy(tensor, memo): + if tensor.device != device: + return tensor.detach().to(device) + else: + return tensor.detach().clone() + + # pyre-ignore [2, 3] + def _param_copy(param, memo): + return torch.nn.Parameter(_tensor_copy(param, memo)) + + # pyre-ignore [2, 3] + def _no_copy(obj, memo): + return obj + + # pyre-ignore [16] + torch.Tensor.__deepcopy__ = _tensor_copy + torch.nn.Parameter.__deepcopy__ = _param_copy + torch._C._distributed_c10d.ProcessGroupNCCL.__deepcopy__ = _no_copy + torch._C._distributed_c10d.ProcessGroupGloo.__deepcopy__ = _no_copy + torch._C._distributed_c10d.Work.__deepcopy__ = _no_copy + # pyre-ignore [16] + torch.cuda.streams.Stream.__deepcopy__ = _no_copy + + # pyre-ignore [2] + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + # pyre-ignore [16] + torch.Tensor.__deepcopy__ = self.t_copy_save_ + # pyre-ignore [16] + torch.nn.Parameter.__deepcopy__ = self.p_copy_save_ + torch._C._distributed_c10d.ProcessGroupNCCL.__deepcopy__ = None + torch._C._distributed_c10d.ProcessGroupGloo.__deepcopy__ = None + torch._C._distributed_c10d.Work.__deepcopy__ = None + # pyre-ignore [16] + torch.cuda.streams.Stream.__deepcopy__ = None diff --git a/torchrec/examples/__init__.py b/torchrec/examples/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/torchrec/examples/dlrm/.torchxconfig b/torchrec/examples/dlrm/.torchxconfig new file mode 100644 index 000000000..14066fd25 --- /dev/null +++ b/torchrec/examples/dlrm/.torchxconfig @@ -0,0 +1,20 @@ +# generated by running +# cd ~/fbsource/fbcode/torchx/fb/example +# torchx configure --all --schedulers local_cwd,mast,flow +# +# these are project defaults, you can override them from cli as +# $ torchx run -s mast -cfg hpcClusterUuid=FooBar,hpcJobOncall=oncall_barbaz + +[local_cwd] +log_dir = None +prepend_cwd = False + +[mast] +hpcClusterUuid = TSCTestCluster +hpcIdentity = oncall_torchrec +hpcJobOncall = torchrec + +[flow] +secure_group = oncall_torchrec +entitlement = default +proxy_workflow_image = None diff --git a/torchrec/examples/dlrm/README.MD b/torchrec/examples/dlrm/README.MD new file mode 100644 index 000000000..197721d46 --- /dev/null +++ b/torchrec/examples/dlrm/README.MD @@ -0,0 +1,11 @@ +# Running + +## Torchx +We recommend using [torchx](https://pytorch.org/torchx/main/quickstart.html) to run. +Here we use the [DDP builtin](https://pytorch.org/torchx/main/components/distributed.html) + +1. pip install torchx +2. (optional) setup a slurm or kubernetes cluster +3. + a. locally: torchx run dist.ddp -j 1x2 --script dlrm_main.py + b. remotely: torchx run -s slurm dist.ddp -j 1x8 --script dlrm_main.py diff --git a/torchrec/examples/dlrm/__init__.py b/torchrec/examples/dlrm/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/torchrec/examples/dlrm/dlrm_main.py b/torchrec/examples/dlrm/dlrm_main.py new file mode 100644 index 000000000..b8d6a2713 --- /dev/null +++ b/torchrec/examples/dlrm/dlrm_main.py @@ -0,0 +1,194 @@ +#!/usr/bin/env python3 +import argparse +import os +import sys +from typing import List + +import torch +from pyre_extensions import none_throws +from torch import distributed as dist +from torch.utils.data import DataLoader +from torchrec import EmbeddingBagCollection +from torchrec.datasets.criteo import DEFAULT_CAT_NAMES, DEFAULT_INT_NAMES +from torchrec.datasets.random import RandomRecDataset +from torchrec.distributed import TrainPipelineSparseDist +from torchrec.distributed.model_parallel import DistributedModelParallel +from torchrec.examples.dlrm.modules.dlrm_train import DLRMTrain +from torchrec.modules.embedding_configs import EmbeddingBagConfig +from torchrec.optim.keyed import KeyedOptimizerWrapper +from tqdm import tqdm + + +# TODO(T102703283): Clean up configuration options for main module for OSS. +def parse_args(argv: List[str]) -> argparse.Namespace: + parser = argparse.ArgumentParser(description="torchrec + lightning app") + parser.add_argument( + "--epochs", type=int, default=1, help="number of epochs to train" + ) + parser.add_argument( + "--batch_size", type=int, default=32, help="batch size to use for training" + ) + parser.add_argument( + "--limit_train_batches", + type=int, + default=100, + help="number of train batches", + ) + parser.add_argument( + "--dataset_name", + type=str, + default="criteo_1t", + help="dataset for experiment, current support criteo_1tb, criteo_kaggle", + ) + parser.add_argument( + "--num_workers", + type=int, + default=2, + help="number of dataloader workers", + ) + parser.add_argument( + "--num_embeddings", + type=int, + default=100_000, + help="max_ind_size. The number of embeddings in each embedding table. Defaults" + " to 100_000 if num_embeddings_per_feature is not supplied.", + ) + parser.add_argument( + "--num_embeddings_per_feature", + type=str, + default=None, + help="Comma separated max_ind_size per sparse feature. The number of embeddings" + " in each embedding table. 26 values are expected for the Criteo dataset.", + ) + parser.add_argument( + "--dense_arch_layer_sizes", + type=str, + default="512,256,64", + help="Comma separated layer sizes for dense arch.", + ) + parser.add_argument( + "--over_arch_layer_sizes", + type=str, + default="512,512,256,1", + help="Comma separated layer sizes for over arch.", + ) + parser.add_argument( + "--embedding_dim", + type=int, + default=64, + help="Size of each embedding.", + ) + parser.add_argument( + "--undersampling_rate", + type=float, + help="Desired proportion of zero-labeled samples to retain (i.e. undersampling zero-labeled rows)." + " Ex. 0.3 indicates only 30pct of the rows with label 0 will be kept." + " All rows with label 1 will be kept. Value should be between 0 and 1." + " When not supplied, no undersampling occurs.", + ) + parser.add_argument( + "--seed", + type=float, + help="Random seed for reproducibility.", + ) + parser.add_argument( + "--pin_memory", + dest="pin_memory", + action="store_true", + help="Use pinned memory when loading data.", + ) + parser.set_defaults(pin_memory=False) + return parser.parse_args(argv) + + +def main(argv: List[str]) -> None: + args = parse_args(argv) + + rank = int(os.environ["LOCAL_RANK"]) + if torch.cuda.is_available(): + device = torch.device(f"cuda:{rank}") + backend = "nccl" + torch.cuda.set_device(device) + else: + device = torch.device("cpu") + backend = "gloo" + + if not torch.distributed.is_initialized(): + dist.init_process_group(backend=backend) + + if args.num_embeddings_per_feature is not None: + num_embeddings_per_feature = list( + map(int, args.num_embeddings_per_feature.split(",")) + ) + num_embeddings = None + else: + num_embeddings_per_feature = None + num_embeddings = args.num_embeddings + + dataloader = DataLoader( + RandomRecDataset( + keys=DEFAULT_CAT_NAMES, + batch_size=args.batch_size, + hash_size=num_embeddings, + hash_sizes=num_embeddings_per_feature, + manual_seed=args.seed, + ids_per_feature=1, + num_dense=len(DEFAULT_INT_NAMES), + ), + batch_size=None, + batch_sampler=None, + pin_memory=args.pin_memory, + num_workers=args.num_workers, + ) + iterator = iter(dataloader) + # TODO add criteo support and add random_dataloader arg + + eb_configs = [ + EmbeddingBagConfig( + name=f"t_{feature_name}", + embedding_dim=args.embedding_dim, + num_embeddings=none_throws(num_embeddings_per_feature)[feature_idx] + if num_embeddings is None + else num_embeddings, + feature_names=[feature_name], + ) + for feature_idx, feature_name in enumerate(DEFAULT_CAT_NAMES) + ] + sharded_module_kwargs = {} + if args.over_arch_layer_sizes is not None: + sharded_module_kwargs["over_arch_layer_sizes"] = list( + map(int, args.over_arch_layer_sizes.split(",")) + ) + + train_model = DLRMTrain( + embedding_bag_collection=EmbeddingBagCollection( + tables=eb_configs, device=torch.device("meta") + ), + dense_in_features=len(DEFAULT_INT_NAMES), + dense_arch_layer_sizes=list(map(int, args.dense_arch_layer_sizes.split(","))), + over_arch_layer_sizes=list(map(int, args.over_arch_layer_sizes.split(","))), + dense_device=device, + ) + + model = DistributedModelParallel( + module=train_model, + device=device, + ) + optimizer = KeyedOptimizerWrapper( + dict(model.named_parameters()), + lambda params: torch.optim.SGD(params, lr=0.01), + ) + + train_pipeline = TrainPipelineSparseDist( + model, + optimizer, + device, + ) + + for _ in range(args.epochs): + for _ in tqdm(range(args.limit_train_batches)): + loss, logits, labels = train_pipeline.progress(iterator) + + +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/torchrec/examples/dlrm/modules/__init__.py b/torchrec/examples/dlrm/modules/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/torchrec/examples/dlrm/modules/dlrm_train.py b/torchrec/examples/dlrm/modules/dlrm_train.py new file mode 100644 index 000000000..87d27435e --- /dev/null +++ b/torchrec/examples/dlrm/modules/dlrm_train.py @@ -0,0 +1,71 @@ +#!/usr/bin/env python3 + +from typing import Tuple, Optional, List + +import torch +from torch import nn +from torchrec.datasets.utils import Batch +from torchrec.models.dlrm import DLRM +from torchrec.modules.embedding_modules import EmbeddingBagCollection + + +class DLRMTrain(nn.Module): + """ + nn.Module to wrap DLRM model to use with train_pipeline. + + DLRM Recsys model from "Deep Learning Recommendation Model for Personalization and + Recommendation Systems" (https://arxiv.org/abs/1906.00091). Processes sparse + features by learning pooled embeddings for each feature. Learns the relationship + between dense features and sparse features by projecting dense features into the + same embedding space. Also, learns the pairwise relationships between sparse + features. + + The module assumes all sparse features have the same embedding dimension + (i.e, each EmbeddingBagConfig uses the same embedding_dim) + + Constructor Args: + embedding_bag_collection (EmbeddingBagCollection): collection of embedding bags + used to define SparseArch. + dense_in_features (int): the dimensionality of the dense input features. + dense_arch_layer_sizes (list[int]): the layer sizes for the DenseArch. + over_arch_layer_sizes (list[int]): the layer sizes for the OverArch. NOTE: The + output dimension of the InteractionArch should not be manually specified + here. + dense_device: (Optional[torch.device]). + + Call Args: + batch: batch used with criteo and random data from torchrec.datasets + + Returns: + Tuple[loss, Tuple[loss, logits, labels]] + + Example: + >>> TODO + """ + + def __init__( + self, + embedding_bag_collection: EmbeddingBagCollection, + dense_in_features: int, + dense_arch_layer_sizes: List[int], + over_arch_layer_sizes: List[int], + dense_device: Optional[torch.device] = None, + ) -> None: + super().__init__() + self.model = DLRM( + embedding_bag_collection=embedding_bag_collection, + dense_in_features=dense_in_features, + dense_arch_layer_sizes=dense_arch_layer_sizes, + over_arch_layer_sizes=over_arch_layer_sizes, + dense_device=dense_device, + ) + self.loss_fn: nn.Module = nn.BCEWithLogitsLoss() + + def forward( + self, batch: Batch + ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: + logits = self.model(batch.dense_features, batch.sparse_features) + logits = logits.squeeze() + loss = self.loss_fn(logits, batch.labels.float()) + + return loss, (loss.detach(), logits.detach(), batch.labels.detach()) diff --git a/torchrec/examples/dlrm/tests/test_dlrm_main.py b/torchrec/examples/dlrm/tests/test_dlrm_main.py new file mode 100644 index 000000000..2edfa7360 --- /dev/null +++ b/torchrec/examples/dlrm/tests/test_dlrm_main.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python3 + +import os +import tempfile +import unittest +import uuid + +from torch.distributed.launcher.api import elastic_launch, LaunchConfig +from torchrec.examples.dlrm.dlrm_main import main +from torchrec.tests import utils + + +class MainTest(unittest.TestCase): + @classmethod + def _run_trainer(cls) -> None: + main( + [ + "--limit_train_batches", + "5", + "--over_arch_layer_sizes", + "8,1", + "--dense_arch_layer_sizes", + "8,8", + "--embedding_dim", + "8", + "--num_embeddings", + "64", + ] + ) + + @utils.skip_if_asan + def test_main_function(self) -> None: + with tempfile.TemporaryDirectory() as tmpdir: + lc = LaunchConfig( + min_nodes=1, + max_nodes=1, + nproc_per_node=2, + run_id=str(uuid.uuid4()), + rdzv_backend="c10d", + rdzv_endpoint=os.path.join(tmpdir, "rdzv"), + rdzv_configs={"store_type": "file"}, + start_method="spawn", + monitor_interval=1, + max_restarts=0, + ) + + elastic_launch(config=lc, entrypoint=self._run_trainer)() diff --git a/torchrec/examples/notebooks/criteo_tutorial.ipynb b/torchrec/examples/notebooks/criteo_tutorial.ipynb new file mode 100644 index 000000000..de0006ab2 --- /dev/null +++ b/torchrec/examples/notebooks/criteo_tutorial.ipynb @@ -0,0 +1,869 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# torchrec Criteo Terabyte Tutorial" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Table of contents\n", + "1. Instantiating Criteo Terabyte dataset\n", + "2. Defining and applying batch data transformation function\n", + "3. Defining model\n", + "4. Training and evaluating model\n", + "5. Training and evaluating model on GPU" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "from typing import Dict, List, Tuple, Union\n", + "\n", + "import torch\n", + "from torchrec.datasets.criteo import criteo_terabyte\n", + "\n", + "torch.set_printoptions(threshold=20)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Instantiating Criteo Terabyte dataset\n", + "Let's begin by instantiating a datapipe representing the Criteo 1TB Click Logs https://ailab.criteo.com/download-criteo-1tb-click-logs-dataset/ dataset (we'll refer to it here as the Criteo Terabyte dataset)." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "datapipe = criteo_terabyte(\n", + " (\"/home/jeffhwang/local/datasets/criteo/day_11.tsv\",),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "By default, the datapipe returns each sample as a dictionary that maps each default feature name to a typecasted feature value (int for each of the label and 13 integer features, and str for each of the 26 categorical features)." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'label': 0,\n", + " 'int_0': 0,\n", + " 'int_1': 0,\n", + " 'int_2': 0,\n", + " 'int_3': 0,\n", + " 'int_4': 0,\n", + " 'int_5': 1,\n", + " 'int_6': 0,\n", + " 'int_7': 124,\n", + " 'int_8': 0,\n", + " 'int_9': 1,\n", + " 'int_10': 0,\n", + " 'int_11': 1,\n", + " 'int_12': 0,\n", + " 'cat_0': '35b29d1c',\n", + " 'cat_1': '11b5bc17',\n", + " 'cat_2': '63f76c15',\n", + " 'cat_3': 'f2463ffb',\n", + " 'cat_4': '16420cce',\n", + " 'cat_5': '6fcd6dcb',\n", + " 'cat_6': '6e1739cb',\n", + " 'cat_7': '337bf7a5',\n", + " 'cat_8': '2e4e821f',\n", + " 'cat_9': '4dc5d654',\n", + " 'cat_10': '59e53f80',\n", + " 'cat_11': '12716184',\n", + " 'cat_12': '00c5ffb7',\n", + " 'cat_13': 'be4ee537',\n", + " 'cat_14': 'eb24f585',\n", + " 'cat_15': '4cdc3efa',\n", + " 'cat_16': 'd20856aa',\n", + " 'cat_17': '7232d217',\n", + " 'cat_18': '9512c20b',\n", + " 'cat_19': '6c8c076c',\n", + " 'cat_20': '174c2fe8',\n", + " 'cat_21': 'b32f71aa',\n", + " 'cat_22': '59f8acf3',\n", + " 'cat_23': 'f3a1835d',\n", + " 'cat_24': '30436bfc',\n", + " 'cat_25': 'b757e957'}" + ] + }, + "execution_count": 6, + "metadata": { + "bento_obj_id": "140592387409728" + }, + "output_type": "execute_result" + } + ], + "source": [ + "next(iter(datapipe))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can adjust the format of each sample via input parameter `row_mapper`. For instance, if we'd prefer to work with lists of feature values, we can define and provide a function that maps a raw split TSV line to a list of typecasted values:" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 0,\n", + " 1,\n", + " 0,\n", + " 124,\n", + " 0,\n", + " 1,\n", + " 0,\n", + " 1,\n", + " 0,\n", + " '35b29d1c',\n", + " '11b5bc17',\n", + " '63f76c15',\n", + " 'f2463ffb',\n", + " '16420cce',\n", + " '6fcd6dcb',\n", + " '6e1739cb',\n", + " '337bf7a5',\n", + " '2e4e821f',\n", + " '4dc5d654',\n", + " '59e53f80',\n", + " '12716184',\n", + " '00c5ffb7',\n", + " 'be4ee537',\n", + " 'eb24f585',\n", + " '4cdc3efa',\n", + " 'd20856aa',\n", + " '7232d217',\n", + " '9512c20b',\n", + " '6c8c076c',\n", + " '174c2fe8',\n", + " 'b32f71aa',\n", + " '59f8acf3',\n", + " 'f3a1835d',\n", + " '30436bfc',\n", + " 'b757e957']" + ] + }, + "execution_count": 7, + "metadata": { + "bento_obj_id": "140590507879232" + }, + "output_type": "execute_result" + } + ], + "source": [ + "from torchrec.datasets.utils import safe_cast\n", + "\n", + "def row_to_list(row):\n", + " return [\n", + " safe_cast(val, int, 0) for val in row[:14]\n", + " ] + [\n", + " safe_cast(val, str, \"\") for val in row[14:]\n", + " ]\n", + "\n", + "list_datapipe = criteo_terabyte(\n", + " (\"/home/jeffhwang/local/datasets/criteo/day_11.tsv\",),\n", + " row_mapper=row_to_list,\n", + ")\n", + "next(iter(list_datapipe))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Or, if we'd prefer to operate directly on raw split TSV lines, we can pass `None`:" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['0',\n", + " '',\n", + " '',\n", + " '0',\n", + " '0',\n", + " '',\n", + " '1',\n", + " '0',\n", + " '124',\n", + " '0',\n", + " '1',\n", + " '',\n", + " '1',\n", + " '0',\n", + " '35b29d1c',\n", + " '11b5bc17',\n", + " '63f76c15',\n", + " 'f2463ffb',\n", + " '16420cce',\n", + " '6fcd6dcb',\n", + " '6e1739cb',\n", + " '337bf7a5',\n", + " '2e4e821f',\n", + " '4dc5d654',\n", + " '59e53f80',\n", + " '12716184',\n", + " '00c5ffb7',\n", + " 'be4ee537',\n", + " 'eb24f585',\n", + " '4cdc3efa',\n", + " 'd20856aa',\n", + " '7232d217',\n", + " '9512c20b',\n", + " '6c8c076c',\n", + " '174c2fe8',\n", + " 'b32f71aa',\n", + " '59f8acf3',\n", + " 'f3a1835d',\n", + " '30436bfc',\n", + " 'b757e957']" + ] + }, + "execution_count": 8, + "metadata": { + "bento_obj_id": "140590530453760" + }, + "output_type": "execute_result" + } + ], + "source": [ + "raw_datapipe = criteo_terabyte(\n", + " (\"/home/jeffhwang/local/datasets/criteo/day_11.tsv\",),\n", + " row_mapper=None,\n", + ")\n", + "next(iter(raw_datapipe))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Next, we move onto creating train and validation datapipes representing complementary subsets of the dataset and applying a sample limit, batching, and collation to each:" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "from torchrec.datasets.utils import idx_split_train_val\n", + "\n", + "datapipe = criteo_terabyte(\n", + " (\"/home/jeffhwang/local/datasets/criteo/day_11.tsv\",),\n", + ")\n", + "train_datapipe, val_datapipe = idx_split_train_val(datapipe, 0.7)\n", + "train_datapipe = train_datapipe.limit(int(1e3)).batch(100).collate()\n", + "val_datapipe = val_datapipe.limit(int(1e3)).batch(100).collate()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. Defining and applying batch data transformation function\n", + "\n", + "At this point, each item that is read from `train_datapipe` and `val_datapipe` is a dictionary representing a batch of 100 Criteo Terabyte samples (\"batch dictionary\"). The dictionary maps each string feature name to 100 feature values, each corresponding to a sample in the batch.\n", + "\n", + "Each of the 13 feature names corresponding to integer-valued features (\"int_0\" through \"int_12\") maps to a shape-(100,) tensor of integers; each of the 26 feature names corresponding to categorical features (\"cat_0\" through \"cat_25\") maps to a length-100 list of hex strings." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "int_0: tensor([ 0, 118, 3, ..., 24, 12, 1])\n", + "cat_0: ['35b29d1c', '0ede8acc', '9a38fdbd', 'b7590909', 'f7f317e1', 'e5f3fd8d', '74a30cd8', 'a2309537', '0d2de9b7', 'd173a71b', '9bb030cc', 'd080dcdd', 'e5f3fd8d', '75bbaf08', 'fd2294fd', '6f88737d', '7f5629e3', '4ba9ec22', 'b1e51346', 'ae08ee40', '6dfe5365', 'b401509b', '', '288878ba', '', 'e5f3fd8d', '11440f4a', 'e5f3fd8d', '4a3130c4', '6f4012dc', 'a1c393aa', 'a5ba1c3d', '105fc022', 'e5f3fd8d', '', '5deaeb35', '8175c6fa', '265366bf', '', '8a2b1e43', 'ad98e872', 'ad98e872', '36ad0c3a', 'faec4515', 'ad98e872', '372034f9', '788a5d5b', 'e5f3fd8d', '240b1f33', 'ad98e872', 'a6367ddd', '84bff54b', '265366bf', 'cc1858ef', '03fd28c6', 'f6771153', '76d82355', 'ad98e872', '73de94cd', '265366bf', 'ad98e872', 'ad98e872', '32818e9b', '788a5d5b', 'b2d27a4e', '341cc7aa', 'ad98e872', '4d4b357f', '10a8c43d', '6a6402aa', 'ad98e872', '2edf58c3', '', 'ad98e872', 'b2d27a4e', 'b401509b', '2c4bc41a', '7592d348', 'ad98e872', '0d5c791d', 'ad98e872', 'ad98e872', '922980a7', 'ff0adf28', '788a5d5b', 'ad98e872', '5f430440', '', 'ad98e872', 'ad98e872', 'ad98e872', '8a2b1e43', '265366bf', '', '15548013', 'b380001c', 'c250bf94', 'ad98e872', 'ad98e872', '41a99438']\n" + ] + } + ], + "source": [ + "batch = next(iter(train_datapipe))\n", + "print(\"int_0:\", batch[\"int_0\"])\n", + "print(\"cat_0:\", batch[\"cat_0\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "There are a few data transformations we'd like to apply to each batch dictionary to produce the data we want to feed into our model:\n", + "- Normalize integer feature values, e.g. by applying a logarithmic function.\n", + "- Map each categorical feature hex string value to an integer that can be used to index into an embedding table.\n", + "- Separate integer features, categorical features, and labels into individual tensors reshaped appropriately.\n", + "\n", + "Towards accomplishing this, we define a function `_transform` that accepts a batch dictionary as an input, applies the aforementioned transformations, and returns a tuple of three tensors corresponding to integer features, categorical features, and labels:" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "from torchrec.datasets.criteo import DEFAULT_CAT_NAMES, DEFAULT_INT_NAMES, DEFAULT_LABEL_NAME\n", + "\n", + "NUM_EMBEDDINGS = int(1e5)\n", + "\n", + "col_transforms = {\n", + " **{name: lambda x: torch.log(x + 2) for name in DEFAULT_INT_NAMES},\n", + " **{\n", + " name: lambda x: x.fmod(NUM_EMBEDDINGS - 1) + 1\n", + " for name in DEFAULT_CAT_NAMES\n", + " },\n", + "}\n", + " \n", + "def _transform(\n", + " batch: Dict[str, List[Union[int, str]]]\n", + ") -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n", + " int_x = torch.cat(\n", + " [\n", + " col_transforms[col_name](torch.tensor(batch[col_name]).unsqueeze(0).T)\n", + " for col_name in DEFAULT_INT_NAMES\n", + " if col_name in col_transforms\n", + " ],\n", + " dim=1,\n", + " )\n", + " cat_x = torch.cat(\n", + " [\n", + " col_transforms[col_name](\n", + " torch.tensor([int(v, 16) if v else -1 for v in batch[col_name]])\n", + " .unsqueeze(0)\n", + " .T\n", + " )\n", + " for col_name in DEFAULT_CAT_NAMES\n", + " if col_name in col_transforms\n", + " ],\n", + " dim=1,\n", + " )\n", + " y = torch.tensor(batch[DEFAULT_LABEL_NAME], dtype=torch.float32).unsqueeze(1)\n", + " return int_x, cat_x, y" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Then, using `map`, we produce a new pair of train and validation datapipes that applies `_transform` to each batch dictionary of data:" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "train_datapipe = train_datapipe.map(_transform)\n", + "val_datapipe = val_datapipe.map(_transform)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(tensor([[0.6931, 0.6931, 0.6931, ..., 0.6931, 1.0986, 0.6931],\n", + " [4.7875, 0.6931, 2.5649, ..., 0.6931, 6.0234, 2.8332],\n", + " [1.6094, 5.4638, 2.6391, ..., 1.7918, 7.9697, 2.6391],\n", + " ...,\n", + " [3.2581, 7.3343, 1.7918, ..., 1.7918, 7.8921, 1.7918],\n", + " [2.6391, 2.9957, 1.3863, ..., 1.0986, 8.0064, 1.3863],\n", + " [1.0986, 1.0986, 1.0986, ..., 0.6931, 4.8675, 1.0986]]),\n", + " tensor([[ 7086, 25811, 76217, ..., 89288, 33022, 22656],\n", + " [68043, 68258, 52745, ..., 81118, 40776, 34095],\n", + " [52112, 50486, 12400, ..., 6322, 33022, 47765],\n", + " ...,\n", + " [ 8472, 85233, 86687, ..., 68498, 33022, 87620],\n", + " [ 8472, 94259, 77092, ..., 77871, 70499, 87620],\n", + " [43585, 52600, 2570, ..., 3211, 51896, 67374]]),\n", + " tensor([[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]]))" + ] + }, + "execution_count": 13, + "metadata": { + "bento_obj_id": "140592381608960" + }, + "output_type": "execute_result" + } + ], + "source": [ + "next(iter(train_datapipe))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we've got datapipes that produce data that we can train and evaluate a model on!" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. Defining model\n", + "To utilize the integer (dense) and categorical (sparse) features present in the Criteo Terabyte dataset, we define `TestSparseNN`, which maps dense and sparse features to embeddings and interacts the embeddings to produce an output:" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "from torchrec.fb.modules.mlp import LazyMLP\n", + "\n", + "\n", + "class TestSparseNN(torch.nn.Module):\n", + " def __init__(\n", + " self,\n", + " *,\n", + " hidden_layer_size,\n", + " output_dim,\n", + " sparse_input_size,\n", + " num_embeddings,\n", + " embedding_dim,\n", + " ):\n", + " super(TestSparseNN, self).__init__()\n", + " self.dense_arch = LazyMLP([hidden_layer_size, embedding_dim])\n", + " self.embedding_layers = self._embedding_layers(\n", + " sparse_input_size, num_embeddings, embedding_dim\n", + " )\n", + " self.over_arch = LazyMLP([output_dim])\n", + " self.final = torch.nn.LazyLinear(1)\n", + "\n", + " def _embedding_layers(self, sparse_input_size, num_embeddings, embedding_dim):\n", + " return torch.nn.ModuleList(\n", + " [\n", + " torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=0)\n", + " for _ in range(sparse_input_size)\n", + " ]\n", + " )\n", + "\n", + " def _interact(self, embeddings):\n", + " batch_size, embedding_dim = embeddings[0].shape\n", + " stacked_embeddings = torch.cat(embeddings, dim=1).view(\n", + " batch_size, -1, embedding_dim\n", + " )\n", + " interactions = torch.matmul(\n", + " stacked_embeddings, torch.transpose(stacked_embeddings, 1, 2)\n", + " )\n", + " _, embedding_count, _ = interactions.shape\n", + " rows, cols = torch.tril_indices(embedding_count, embedding_count)\n", + " return interactions[:, rows, cols]\n", + "\n", + " def forward(self, dense_x, cat_x):\n", + " embedded_dense = self.dense_arch(dense_x)\n", + " embedded_sparse = [\n", + " embedding_layer(cat_x[:, idx])\n", + " for idx, embedding_layer in enumerate(self.embedding_layers)\n", + " ]\n", + " interactions = self._interact([embedded_dense] + embedded_sparse)\n", + " return self.final(\n", + " self.over_arch(torch.cat([embedded_dense, interactions], dim=1))\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. Training and evaluating model\n", + "We can now train an instance of `TestSparseNN` on data supplied by `train_datapipe`" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "loss: 0.536839 0\n", + "loss: 0.127229 100\n", + "loss: 0.322483 200\n", + "loss: 0.155279 300\n", + "loss: 0.143317 400\n", + "loss: 0.244847 500\n", + "loss: 0.264295 600\n", + "loss: 0.136220 700\n", + "loss: 0.040959 800\n", + "loss: 0.080234 900\n" + ] + } + ], + "source": [ + "model = TestSparseNN(\n", + " hidden_layer_size=20,\n", + " output_dim=10,\n", + " sparse_input_size=26,\n", + " num_embeddings=NUM_EMBEDDINGS,\n", + " embedding_dim=16,\n", + ")\n", + "\n", + "# Initialize lazy modules.\n", + "int_x, cat_x, y = next(iter(train_datapipe))\n", + "model(int_x, cat_x)\n", + "\n", + "loss_fn = torch.nn.BCEWithLogitsLoss()\n", + "optimizer = torch.optim.Adagrad(model.parameters(), lr=1e-2, weight_decay=1e-6)\n", + "\n", + "for batch_num, (int_x, cat_x, y) in enumerate(train_datapipe):\n", + " res = model(int_x, cat_x)\n", + " loss = loss_fn(res, y)\n", + " \n", + " optimizer.zero_grad()\n", + " loss.backward()\n", + " optimizer.step()\n", + " \n", + " if batch_num % 1 == 0:\n", + " loss, current = loss.item(), batch_num * len(y)\n", + " print(f\"loss: {loss:>7f} {current}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + ", and evaluate the trained model on data supplied by `val_datapipe`" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Test results:\n", + "AUROC: 0.530097 Avg loss: 0.126038\n" + ] + } + ], + "source": [ + "import sklearn.metrics\n", + "\n", + "\n", + "y_true = []\n", + "y_pred = []\n", + "with torch.no_grad():\n", + " for int_x, cat_x, y in val_datapipe:\n", + " pred = model(int_x, cat_x)\n", + " y_pred.append(pred)\n", + " y_true.append(y)\n", + "\n", + "auroc = sklearn.metrics.roc_auc_score(\n", + " torch.cat(y_true).view(-1),\n", + " torch.sigmoid(torch.cat(y_pred).view(-1)),\n", + ")\n", + "val_loss = loss_fn(\n", + " torch.cat(y_pred).view(-1),\n", + " torch.cat(y_true).view(-1),\n", + ")\n", + "\n", + "print(\"Test results:\")\n", + "print(f\"AUROC: {auroc:>8f} Avg loss: {val_loss:>8f}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 5. Training and evaluating model on GPU\n", + "\n", + "If we have access to a GPU device, we can leverage it as follows to accelerate model training and evaluation." + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "loss: 0.120394 0\n", + "loss: 0.122464 10000\n", + "loss: 0.148007 20000\n", + "loss: 0.153441 30000\n", + "loss: 0.124577 40000\n", + "loss: 0.146918 50000\n", + "loss: 0.153290 60000\n", + "loss: 0.124814 70000\n", + "loss: 0.139988 80000\n", + "loss: 0.155026 90000\n", + "loss: 0.127663 100000\n", + "loss: 0.128998 110000\n", + "loss: 0.149047 120000\n", + "loss: 0.130756 130000\n", + "loss: 0.098698 140000\n", + "loss: 0.156221 150000\n", + "loss: 0.144935 160000\n", + "loss: 0.111010 170000\n", + "loss: 0.165140 180000\n", + "loss: 0.162504 190000\n", + "loss: 0.126021 200000\n", + "loss: 0.133386 210000\n", + "loss: 0.146822 220000\n", + "loss: 0.139569 230000\n", + "loss: 0.134351 240000\n", + "loss: 0.143748 250000\n", + "loss: 0.127452 260000\n", + "loss: 0.150848 270000\n", + "loss: 0.110147 280000\n", + "loss: 0.121761 290000\n", + "loss: 0.148827 300000\n", + "loss: 0.135799 310000\n", + "loss: 0.143518 320000\n", + "loss: 0.147040 330000\n", + "loss: 0.147874 340000\n", + "loss: 0.158020 350000\n", + "loss: 0.117170 360000\n", + "loss: 0.160049 370000\n", + "loss: 0.124524 380000\n", + "loss: 0.147427 390000\n", + "loss: 0.143686 400000\n", + "loss: 0.145967 410000\n", + "loss: 0.140429 420000\n", + "loss: 0.129113 430000\n", + "loss: 0.136281 440000\n", + "loss: 0.163455 450000\n", + "loss: 0.102815 460000\n", + "loss: 0.126294 470000\n", + "loss: 0.152309 480000\n", + "loss: 0.130393 490000\n", + "loss: 0.151293 500000\n", + "loss: 0.140869 510000\n", + "loss: 0.156620 520000\n", + "loss: 0.151464 530000\n", + "loss: 0.146070 540000\n", + "loss: 0.153463 550000\n", + "loss: 0.142922 560000\n", + "loss: 0.152070 570000\n", + "loss: 0.123993 580000\n", + "loss: 0.166800 590000\n", + "loss: 0.126718 600000\n", + "loss: 0.187246 610000\n", + "loss: 0.139779 620000\n", + "loss: 0.132810 630000\n", + "loss: 0.149490 640000\n", + "loss: 0.125739 650000\n", + "loss: 0.156822 660000\n", + "loss: 0.137232 670000\n", + "loss: 0.146410 680000\n", + "loss: 0.122474 690000\n", + "loss: 0.116913 700000\n", + "loss: 0.133779 710000\n", + "loss: 0.150961 720000\n", + "loss: 0.121909 730000\n", + "loss: 0.130351 740000\n", + "loss: 0.137554 750000\n", + "loss: 0.139059 760000\n", + "loss: 0.116831 770000\n", + "loss: 0.139617 780000\n", + "loss: 0.150021 790000\n", + "loss: 0.155689 800000\n", + "loss: 0.140969 810000\n", + "loss: 0.122985 820000\n", + "loss: 0.145107 830000\n", + "loss: 0.146708 840000\n", + "loss: 0.113037 850000\n", + "loss: 0.081020 860000\n", + "loss: 0.139679 870000\n", + "loss: 0.151576 880000\n", + "loss: 0.125169 890000\n", + "loss: 0.148480 900000\n", + "loss: 0.154493 910000\n", + "loss: 0.148526 920000\n", + "loss: 0.141710 930000\n", + "loss: 0.138688 940000\n", + "loss: 0.166732 950000\n", + "loss: 0.146822 960000\n", + "loss: 0.140306 970000\n", + "loss: 0.161611 980000\n", + "loss: 0.149240 990000\n", + "Test results:\n", + "AUROC: 0.724610 Avg loss: 0.131163\n" + ] + } + ], + "source": [ + "assert(torch.cuda.is_available())\n", + "\n", + "device = torch.device(\"cuda:0\")\n", + "\n", + "datapipe = criteo_terabyte(\n", + " (\"/home/jeffhwang/local/datasets/criteo/day_11.tsv\",),\n", + ")\n", + "train_datapipe, val_datapipe = idx_split_train_val(datapipe, 70)\n", + "train_datapipe = train_datapipe.limit(int(1e6)).batch(1000).collate().map(_transform)\n", + "val_datapipe = val_datapipe.limit(int(1e5)).batch(1000).collate().map(_transform)\n", + "\n", + "model.to(device)\n", + "\n", + "int_x, cat_x, y = next(iter(train_datapipe))\n", + "int_x, cat_x, y = int_x.to(device), cat_x.to(device), y.to(device)\n", + "model(int_x, cat_x)\n", + "\n", + "loss_fn = torch.nn.BCEWithLogitsLoss()\n", + "optimizer = torch.optim.Adagrad(model.parameters(), lr=1e-2, weight_decay=1e-6)\n", + "\n", + "for batch_num, (int_x, cat_x, y) in enumerate(train_datapipe):\n", + " int_x, cat_x, y = int_x.to(device), cat_x.to(device), y.to(device)\n", + " res = model(int_x, cat_x)\n", + " loss = loss_fn(res, y)\n", + "\n", + " optimizer.zero_grad()\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + " if batch_num % 10 == 0:\n", + " loss, current = loss.item(), batch_num * len(y)\n", + " print(f\"loss: {loss:>7f} {current}\")\n", + "\n", + "y_true = []\n", + "y_pred = []\n", + "with torch.no_grad():\n", + " for int_x, cat_x, y in val_datapipe:\n", + " int_x, cat_x, y = int_x.to(device), cat_x.to(device), y.to(device)\n", + " pred = model(int_x, cat_x)\n", + " y_pred.append(pred)\n", + " y_true.append(y)\n", + "\n", + "auroc = sklearn.metrics.roc_auc_score(\n", + " torch.cat(y_true).view(-1).cpu(),\n", + " torch.sigmoid(torch.cat(y_pred).view(-1)).cpu(),\n", + ")\n", + "val_loss = loss_fn(\n", + " torch.cat(y_pred).view(-1).cpu(),\n", + " torch.cat(y_true).view(-1).cpu(),\n", + ")\n", + "\n", + "print(\"Test results:\")\n", + "print(f\"AUROC: {auroc:>8f} Avg loss: {val_loss:>8f}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "anp_metadata": { + "path": "notebooks/fbsource/fbcode/torchrec/examples/criteo_tutorial.ipynb" + }, + "bento_stylesheets": { + "bento/extensions/flow/main.css": true, + "bento/extensions/kernel_selector/main.css": true, + "bento/extensions/kernel_ui/main.css": true, + "bento/extensions/new_kernel/main.css": true, + "bento/extensions/system_usage/main.css": true, + "bento/extensions/theme/main.css": true + }, + "disseminate_notebook_id": { + "notebook_id": "515185369527336" + }, + "disseminate_notebook_info": { + "backup_notebook_id": "472451483979616", + "bento_version": "20210606-210329", + "description": "", + "hide_code": false, + "hipster_group": "", + "kernel_build_info": { + "deps": [ + "//caffe2/caffe2/fb/ifbpy:all_pytorch_and_caffe2_deps", + "//github/third-party/PyTorchLightning/pytorch-lightning:lib" + ], + "external_deps": [] + }, + "no_uii": true, + "notebook_number": "770927", + "others_can_edit": false, + "reviewers": "", + "revision_id": "483679666198664", + "tags": "", + "tasks": "", + "title": "criteo_tutorial" + }, + "kernelspec": { + "display_name": "pytorch", + "language": "python", + "name": "bento_kernel_pytorch" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.6" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/torchrec/examples/notebooks/movielens_tutorial.ipynb b/torchrec/examples/notebooks/movielens_tutorial.ipynb new file mode 100644 index 000000000..625ff5e36 --- /dev/null +++ b/torchrec/examples/notebooks/movielens_tutorial.ipynb @@ -0,0 +1,639 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# torchrec MovieLens Tutorial" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Table of contents\n", + "1. Instantiating MovieLens-25M dataset\n", + "2. Defining model\n", + "3. Training and evaluating model\n", + "4. Finding similar movies" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Instantiating MovieLens-25M dataset\n", + "\n", + "To start, we can load the MovieLens-25M dataset using `torchrec.datasets.movielens.movielens_25m`. The function loads just the user-movie ratings data in `ratings.csv` by default; we call the function with `include_movies_data=True` such that it adds movie data from `movies.csv` to each user-movie sample." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from torchrec.datasets.movielens import movielens_25m\n", + "\n", + "dp = movielens_25m(\"/home/jeffhwang/local/datasets/ml-25m\", include_movies_data=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's check out a single sample." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'userId': 1,\n", + " 'movieId': 296,\n", + " 'rating': 5.0,\n", + " 'timestamp': 1147880044,\n", + " 'title': 'Pulp Fiction (1994)',\n", + " 'genres': 'Comedy|Crime|Drama|Thriller'}" + ] + }, + "execution_count": 2, + "metadata": { + "bento_obj_id": "139906334494848" + }, + "output_type": "execute_result" + } + ], + "source": [ + "next(iter(dp))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Seems reasonable.\n", + "\n", + "Next, we instantiate datapipes representing training and validation data splits and apply shuffling and batching." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "from torchrec.datasets.utils import rand_split_train_val\n", + "\n", + "train_dp, val_dp = rand_split_train_val(dp, 0.9)\n", + "batched_train_dp = train_dp.shuffle(buffer_size=int(1e5)).batch(8192)\n", + "batched_val_dp = val_dp.batch(8192)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Turns out that the integer user ids and movie ids referenced by the dataset aren't contiguous. Let's remap them to contiguous values so that we can use them with `torch.nn.Embedding` more easily downstream.\n", + "\n", + "To do so, we first populate dictionaries that map movie and user ids to ids in contiguous ranges" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "contig_movie_ids = {}\n", + "contig_user_ids = {}\n", + "movie_id_to_title_genre = {}\n", + "\n", + "available_movie_id = 0\n", + "available_user_id = 0\n", + "for sample in dp:\n", + " if sample[\"movieId\"] not in contig_movie_ids:\n", + " contig_movie_ids[sample[\"movieId\"]] = available_movie_id\n", + " available_movie_id += 1\n", + " if sample[\"userId\"] not in contig_user_ids:\n", + " contig_user_ids[sample[\"userId\"]] = available_user_id\n", + " available_user_id += 1\n", + " movie_id_to_title_genre[sample[\"movieId\"]] = (sample[\"title\"], sample[\"genres\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + ", and then define a function `_transform` that uses those dictionaries to remap movie and user ids for a batch of data. While we're at it, we'll also have `_transform` reformat the batch as tensors representing user ids, movie ids, and labels (numerical movie ratings given by users)." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "\n", + "def _transform(batch):\n", + " user_ids = torch.tensor([contig_user_ids[sample[\"userId\"]] for sample in batch], dtype=torch.int32)\n", + " movie_ids = torch.tensor([contig_movie_ids[sample[\"movieId\"]] for sample in batch], dtype=torch.int32)\n", + " labels = torch.tensor([sample[\"rating\"] for sample in batch], dtype=torch.float)\n", + " return user_ids, movie_ids, labels" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Finally, we configure our training and validation datapipes to apply `_transform` to each batch of data using `map`." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "preproc_train_dp = batched_train_dp.map(_transform)\n", + "preproc_val_dp = batched_val_dp.map(_transform)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "At this point, `preproc_train_dp` and `preproc_val_dp` are set up to produce the data that our model expects." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. Defining model\n", + "\n", + "Next, we define the model we're going to train. We'll go with a simplified two-tower model `TwoTowerModel` resembling a matrix factorization model that attempts to learn a low-rank approximation of the user-movie ratings matrix. More specifically, we want to find matrices $U \\in \\mathbb{R}^{u \\times d}$ and $M \\in \\mathbb{R}^{m \\times d}$ such that $U M^T \\approx A$, where each row in $U$ represents a user embedding of dimension $d$ and each row in $M$ a movie embedding also of dimension $d$. Once we find matrices $U$ and $M$, we can infer the rating that the $i$-th user gives the $j$-th movie as $u_i^T \\cdot m_j^T$, i.e. the dot product of the $i$-th row in $U$ and $j$-th row in $M$.\n", + "\n", + "`TwoTowerModel` represents $U$ and $M$ as embedding tables — instances of `torch.nn.Embedding`." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "class TwoTowerModel(torch.nn.Module):\n", + " def __init__(self, num_embeddings_0, num_embeddings_1, embedding_dim):\n", + " super().__init__()\n", + " self.model_0 = torch.nn.Embedding(num_embeddings_0, embedding_dim)\n", + " self.model_1 = torch.nn.Embedding(num_embeddings_1, embedding_dim)\n", + " \n", + " def forward(self, input):\n", + " embeddings_0 = self.model_0(input[0])\n", + " embeddings_1 = self.model_1(input[1])\n", + " return torch.sum(embeddings_0 * embeddings_1, axis=1)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. Training and evaluating model\n", + "We're ready to train our model. Let's instantiate the model we just defined" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "model = TwoTowerModel(\n", + " len(contig_user_ids),\n", + " len(contig_movie_ids),\n", + " 32\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + ", instantiate our loss function and optimizer" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "loss_fn = torch.nn.MSELoss()\n", + "optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-6)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + ", and define our train and test loops." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "def train_loop(dp, model, loss_fn, optimizer):\n", + " for batch, (users, movies, labels) in enumerate(dp):\n", + " pred = model((users, movies))\n", + " loss = loss_fn(pred, labels)\n", + "\n", + " optimizer.zero_grad()\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + " if batch % 100 == 0:\n", + " loss, current = loss.item(), batch * len(labels)\n", + " print(f\"loss: {loss:>7f}; batch: {batch}\")\n", + "\n", + "def test_loop(dp, model, loss_fn):\n", + " test_loss = 0\n", + " batch_count = 0\n", + " with torch.no_grad():\n", + " for batch, (users, movies, labels) in enumerate(dp):\n", + " pred = model((users, movies))\n", + " test_loss += loss_fn(pred, labels).item()\n", + " batch_count += 1\n", + " \n", + " print(f\"Test loss: {test_loss / batch_count}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "And now, we train." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "loss: 46.037308; batch: 0\n", + "loss: 38.465691; batch: 100\n", + "loss: 36.324226; batch: 200\n", + "loss: 32.601963; batch: 300\n", + "loss: 31.454372; batch: 400\n", + "loss: 28.976149; batch: 500\n", + "loss: 28.139290; batch: 600\n", + "loss: 26.033499; batch: 700\n", + "loss: 24.564535; batch: 800\n", + "loss: 24.771233; batch: 900\n", + "loss: 23.982235; batch: 1000\n", + "loss: 23.603422; batch: 1100\n", + "loss: 22.836342; batch: 1200\n", + "loss: 22.366283; batch: 1300\n", + "loss: 21.981915; batch: 1400\n", + "loss: 20.872860; batch: 1500\n", + "loss: 21.898285; batch: 1600\n", + "loss: 21.096714; batch: 1700\n", + "loss: 21.146002; batch: 1800\n", + "loss: 20.964083; batch: 1900\n", + "loss: 20.462320; batch: 2000\n", + "loss: 19.815897; batch: 2100\n", + "loss: 20.486073; batch: 2200\n", + "loss: 19.710440; batch: 2300\n", + "loss: 19.632444; batch: 2400\n", + "loss: 20.404284; batch: 2500\n", + "loss: 19.290865; batch: 2600\n", + "loss: 19.100807; batch: 2700\n", + "Test loss: 18.480480686511868\n", + "loss: 18.086700; batch: 0\n", + "loss: 17.338081; batch: 100\n", + "loss: 17.585329; batch: 200\n", + "loss: 17.180315; batch: 300\n", + "loss: 17.549704; batch: 400\n", + "loss: 17.645161; batch: 500\n", + "loss: 17.722076; batch: 600\n", + "loss: 17.680521; batch: 700\n", + "loss: 17.408783; batch: 800\n", + "loss: 17.525982; batch: 900\n", + "loss: 17.376921; batch: 1000\n", + "loss: 17.426014; batch: 1100\n", + "loss: 17.220980; batch: 1200\n", + "loss: 16.952011; batch: 1300\n", + "loss: 17.134249; batch: 1400\n", + "loss: 17.048244; batch: 1500\n", + "loss: 17.227037; batch: 1600\n", + "loss: 17.148981; batch: 1700\n", + "loss: 17.023678; batch: 1800\n", + "loss: 16.871969; batch: 1900\n", + "loss: 17.318562; batch: 2000\n", + "loss: 16.544071; batch: 2100\n", + "loss: 16.600208; batch: 2200\n", + "loss: 16.858044; batch: 2300\n", + "loss: 16.681547; batch: 2400\n", + "loss: 17.338814; batch: 2500\n", + "loss: 16.720623; batch: 2600\n", + "loss: 16.843494; batch: 2700\n", + "Test loss: 16.428918199601515\n", + "loss: 16.221512; batch: 0\n", + "loss: 16.358212; batch: 100\n", + "loss: 15.603410; batch: 200\n", + "loss: 15.793720; batch: 300\n", + "loss: 16.120581; batch: 400\n", + "loss: 15.609840; batch: 500\n", + "loss: 15.646942; batch: 600\n", + "loss: 15.803013; batch: 700\n", + "loss: 15.460732; batch: 800\n", + "loss: 15.707933; batch: 900\n", + "loss: 15.387153; batch: 1000\n", + "loss: 15.972283; batch: 1100\n", + "loss: 15.815547; batch: 1200\n", + "loss: 15.986384; batch: 1300\n", + "loss: 15.839499; batch: 1400\n", + "loss: 16.124321; batch: 1500\n", + "loss: 15.958432; batch: 1600\n", + "loss: 15.983364; batch: 1700\n", + "loss: 15.935410; batch: 1800\n", + "loss: 15.357247; batch: 1900\n", + "loss: 15.969000; batch: 2000\n", + "loss: 15.663831; batch: 2100\n", + "loss: 15.580512; batch: 2200\n", + "loss: 15.733629; batch: 2300\n", + "loss: 15.704931; batch: 2400\n", + "loss: 15.884350; batch: 2500\n", + "loss: 15.432696; batch: 2600\n", + "loss: 15.874282; batch: 2700\n", + "Test loss: 15.481964759577334\n" + ] + } + ], + "source": [ + "epochs = 3\n", + "\n", + "for __ in range(epochs):\n", + " train_loop(preproc_train_dp, model, loss_fn, optimizer)\n", + " test_loop(preproc_val_dp, model, loss_fn)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We've got a trained model!" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. Finding similar movies\n", + "For kicks, let's see if we can use our model's trained embeddings to find movies that are most similar to some query movie. In theory, movies with embeddings that are similar should themselves be similar." + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "contig_to_movie_id = {v: k for k, v in contig_movie_ids.items()}\n", + "\n", + "def get_topk_sim_movies(movie_id, k=20):\n", + " embedding = model.model_1(torch.tensor([contig_movie_ids[movie_id]]))\n", + " movie_embeddings = model.get_parameter(\"model_1.weight\")\n", + " movie_similarities = torch.sum(embedding * movie_embeddings, axis=1) / torch.maximum(torch.norm(embedding) * torch.norm(movie_embeddings, dim=1), torch.ones(movie_embeddings.shape[0]) * 1e-12)\n", + " topk_sim = torch.topk(movie_similarities, 20)\n", + " contig_ids = topk_sim.indices.tolist()\n", + " return [\n", + " (*movie_id_to_title_genre[contig_to_movie_id[movie_id]], contig_to_movie_id[movie_id]) \n", + " for movie_id in contig_ids\n", + " ]" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[('Drive (2011)', 'Crime|Drama|Film-Noir|Thriller', 88129),\n", + " ('Wolf of Wall Street, The (2013)', 'Comedy|Crime|Drama', 106782),\n", + " ('Ex Machina (2015)', 'Drama|Sci-Fi|Thriller', 115713),\n", + " ('Shutter Island (2010)', 'Drama|Mystery|Thriller', 74458),\n", + " ('Girl with the Dragon Tattoo, The (2011)', 'Drama|Thriller', 91658),\n", + " ('True Grit (2010)', 'Western', 82459),\n", + " ('Gone Girl (2014)', 'Drama|Thriller', 112556),\n", + " ('Looper (2012)', 'Action|Crime|Sci-Fi', 96610),\n", + " ('127 Hours (2010)', 'Adventure|Drama|Thriller', 81562),\n", + " ('No Country for Old Men (2007)', 'Crime|Drama', 55820),\n", + " ('Black Swan (2010)', 'Drama|Thriller', 81591),\n", + " ('Big Short, The (2015)', 'Drama', 148626),\n", + " ('Grand Budapest Hotel, The (2014)', 'Comedy|Drama', 109374),\n", + " ('Skyfall (2012)', 'Action|Adventure|Thriller|IMAX', 96079),\n", + " ('Dark Knight Rises, The (2012)', 'Action|Adventure|Crime|IMAX', 91529),\n", + " ('Sherlock Holmes: A Game of Shadows (2011)',\n", + " 'Action|Adventure|Comedy|Crime|Mystery|Thriller',\n", + " 91542),\n", + " ('Gran Torino (2008)', 'Crime|Drama', 64614),\n", + " ('Whiplash (2014)', 'Drama', 112552),\n", + " ('American History X (1998)', 'Crime|Drama', 2329),\n", + " ('Edge of Tomorrow (2014)', 'Action|Sci-Fi|IMAX', 111759)]" + ] + }, + "execution_count": 17, + "metadata": { + "bento_obj_id": "139905706007040" + }, + "output_type": "execute_result" + } + ], + "source": [ + "# Drive\n", + "get_topk_sim_movies(88129)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[('Lost in Translation (2003)', 'Comedy|Drama|Romance', 6711),\n", + " ('Blade (1998)', 'Action|Horror|Thriller', 2167),\n", + " ('O Brother, Where Art Thou? (2000)', 'Adventure|Comedy|Crime', 4027),\n", + " ('Minority Report (2002)', 'Action|Crime|Mystery|Sci-Fi|Thriller', 5445),\n", + " ('Bowling for Columbine (2002)', 'Documentary', 5669),\n", + " ('Lord of War (2005)', 'Action|Crime|Drama|Thriller|War', 36529),\n", + " ('Dark City (1998)', 'Adventure|Film-Noir|Sci-Fi|Thriller', 1748),\n", + " ('Rocky (1976)', 'Drama', 1954),\n", + " ('Best in Show (2000)', 'Comedy', 3911),\n", + " ('Deliverance (1972)', 'Adventure|Drama|Thriller', 2871),\n", + " ('Eternal Sunshine of the Spotless Mind (2004)',\n", + " 'Drama|Romance|Sci-Fi',\n", + " 7361),\n", + " ('Crouching Tiger, Hidden Dragon (Wo hu cang long) (2000)',\n", + " 'Action|Drama|Romance',\n", + " 3996),\n", + " ('Fahrenheit 9/11 (2004)', 'Documentary', 8622),\n", + " ('Back to the Future Part III (1990)',\n", + " 'Adventure|Comedy|Sci-Fi|Western',\n", + " 2012),\n", + " ('Boogie Nights (1997)', 'Drama', 1673),\n", + " ('Run Lola Run (Lola rennt) (1998)', 'Action|Crime', 2692),\n", + " ('Con Air (1997)', 'Action|Adventure|Thriller', 1552),\n", + " (\"Ocean's Eleven (2001)\", 'Crime|Thriller', 4963),\n", + " ('Austin Powers: International Man of Mystery (1997)',\n", + " 'Action|Adventure|Comedy',\n", + " 1517),\n", + " ('Bad Boys (1995)', 'Action|Comedy|Crime|Drama|Thriller', 145)]" + ] + }, + "execution_count": 18, + "metadata": { + "bento_obj_id": "139906355972864" + }, + "output_type": "execute_result" + } + ], + "source": [ + "# Lost in Translation\n", + "get_topk_sim_movies(6711)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[('Ratatouille (2007)', 'Animation|Children|Drama', 50872),\n", + " ('Harry Potter and the Goblet of Fire (2005)',\n", + " 'Adventure|Fantasy|Thriller|IMAX',\n", + " 40815),\n", + " ('Toy Story 2 (1999)', 'Adventure|Animation|Children|Comedy|Fantasy', 3114),\n", + " ('Ice Age (2002)', 'Adventure|Animation|Children|Comedy', 5218),\n", + " ('Monsters, Inc. (2001)',\n", + " 'Adventure|Animation|Children|Comedy|Fantasy',\n", + " 4886),\n", + " ('Harry Potter and the Chamber of Secrets (2002)', 'Adventure|Fantasy', 5816),\n", + " ('Wedding Singer, The (1998)', 'Comedy|Romance', 1777),\n", + " ('Armageddon (1998)', 'Action|Romance|Sci-Fi|Thriller', 1917),\n", + " ('WALL·E (2008)', 'Adventure|Animation|Children|Romance|Sci-Fi', 60069),\n", + " ('Finding Nemo (2003)', 'Adventure|Animation|Children|Comedy', 6377),\n", + " ('Indiana Jones and the Temple of Doom (1984)',\n", + " 'Action|Adventure|Fantasy',\n", + " 2115),\n", + " ('Signs (2002)', 'Horror|Sci-Fi|Thriller', 5502),\n", + " ('Toy Story 3 (2010)',\n", + " 'Adventure|Animation|Children|Comedy|Fantasy|IMAX',\n", + " 78499),\n", + " ('My Big Fat Greek Wedding (2002)', 'Comedy|Romance', 5299),\n", + " ('Harry Potter and the Prisoner of Azkaban (2004)',\n", + " 'Adventure|Fantasy|IMAX',\n", + " 8368),\n", + " ('Star Wars: Episode VI - Return of the Jedi (1983)',\n", + " 'Action|Adventure|Sci-Fi',\n", + " 1210),\n", + " ('Black Hawk Down (2001)', 'Action|Drama|War', 5010),\n", + " (\"Harry Potter and the Sorcerer's Stone (a.k.a. Harry Potter and the Philosopher's Stone) (2001)\",\n", + " 'Adventure|Children|Fantasy',\n", + " 4896),\n", + " ('Big Hero 6 (2014)', 'Action|Animation|Comedy', 115617),\n", + " ('Hot Fuzz (2007)', 'Action|Comedy|Crime|Mystery', 51255)]" + ] + }, + "execution_count": 21, + "metadata": { + "bento_obj_id": "139905661210368" + }, + "output_type": "execute_result" + } + ], + "source": [ + "# Ratatouille\n", + "get_topk_sim_movies(50872)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "What do you think? Can we do better?" + ] + } + ], + "metadata": { + "bento_stylesheets": { + "bento/extensions/flow/main.css": true, + "bento/extensions/kernel_selector/main.css": true, + "bento/extensions/kernel_ui/main.css": true, + "bento/extensions/new_kernel/main.css": true, + "bento/extensions/system_usage/main.css": true, + "bento/extensions/theme/main.css": true + }, + "disseminate_notebook_id": { + "notebook_id": "552815895724328" + }, + "disseminate_notebook_info": { + "bento_version": "20210627-210324", + "description": "", + "hide_code": false, + "hipster_group": "", + "kernel_build_info": { + "deps": [ + "//caffe2/caffe2/fb/ifbpy:all_pytorch_and_caffe2_deps", + "//github/third-party/PyTorchLightning/pytorch-lightning:lib" + ], + "external_deps": [] + }, + "no_uii": true, + "notebook_number": "954867", + "others_can_edit": false, + "reviewers": "", + "revision_id": "489172452146612", + "tags": "", + "tasks": "", + "title": "torchrec movielens" + }, + "kernelspec": { + "display_name": "pytorch", + "language": "python", + "name": "bento_kernel_pytorch" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.6" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/torchrec/fx/__init__.py b/torchrec/fx/__init__.py new file mode 100644 index 000000000..099f7b959 --- /dev/null +++ b/torchrec/fx/__init__.py @@ -0,0 +1,3 @@ +#!/usr/bin/env python3 + +from torchrec.fx.tracer import Tracer, symbolic_trace # noqa diff --git a/torchrec/fx/tests/test_tracer.py b/torchrec/fx/tests/test_tracer.py new file mode 100644 index 000000000..489c7ebb6 --- /dev/null +++ b/torchrec/fx/tests/test_tracer.py @@ -0,0 +1,286 @@ +#!/usr/bin/env python3 + +import unittest +from typing import List, Optional, Tuple + +import torch +import torch.fx +from torch.testing import FileCheck # @manual +from torchrec.distributed.types import LazyAwaitable +from torchrec.fx import symbolic_trace +from torchrec.sparse.jagged_tensor import ( + JaggedTensor, + KeyedJaggedTensor, +) + + +torch.fx.wrap("len") + + +class TestTracer(unittest.TestCase): + maxDiff: Optional[int] = None + + def test_jagged_tensor(self) -> None: + class ModuleCreateAndAccessJaggedTensor(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input: int) -> int: + features = JaggedTensor( + values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]), + weights=torch.Tensor([1.0, 0.5, 1.5, 1.0, 0.5, 1.0, 1.0, 1.5]), + offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]), + ) + return ( + features.values().numel() + + features.weights().numel() + + features.lengths().numel() + + features.offsets().numel() + ) + + class ModuleUseJaggedTensorAsInputAndOutput(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input: JaggedTensor) -> JaggedTensor: + return JaggedTensor( + input.values(), + input.weights(), + lengths=input.lengths(), + offsets=input.offsets(), + ) + + class ModuleUseJaggedTensorAsInput(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input: JaggedTensor) -> int: + return ( + input.values().numel() + + input.weights().numel() + + input.lengths().numel() + + input.offsets().numel() + ) + + class ModuleUseJaggedTensorAsOutput(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward( + self, + values: torch.Tensor, + weights: torch.Tensor, + lengths: torch.Tensor, + ) -> JaggedTensor: + return JaggedTensor(values, weights, lengths) + + # Case 1: JaggedTensor is only used as an output of the root module. + m = ModuleUseJaggedTensorAsOutput() + gm = symbolic_trace(m) + FileCheck().check("JaggedTensor").check("return jagged_tensor").run(gm.code) + + values = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]) + weights = torch.tensor([1.0, 0.5, 1.5, 1.0, 0.5, 1.0, 1.0, 1.5]) + lengths = torch.tensor([0, 2, 2, 3, 4, 5, 8]) + + ref_jt = m(values, weights, lengths) + traced_jt = gm(values, weights, lengths) + + self.assertTrue(torch.equal(traced_jt.values(), ref_jt.values())) + self.assertTrue(torch.equal(traced_jt.weights(), ref_jt.weights())) + self.assertTrue(torch.equal(traced_jt.lengths(), ref_jt.lengths())) + + # Case 2: JaggedTensor is only used as an input of the root module. + m = ModuleUseJaggedTensorAsInput() + gm = symbolic_trace(m) + FileCheck().check("values()").check("numel()").check("weights").check( + "lengths" + ).check("offsets").run(gm.code) + + input = JaggedTensor( + values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]), + weights=torch.Tensor([1.0, 0.5, 1.5, 1.0, 0.5, 1.0, 1.0, 1.5]), + offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]), + ) + ref_out = m(input) + traced_out = gm(input) + self.assertEqual(ref_out, traced_out) + + # Case 3: JaggedTensor is used as both an input and an output of the root module. + m = ModuleUseJaggedTensorAsInputAndOutput() + gm = symbolic_trace(m) + FileCheck().check("values()").check("weights").check("lengths").check( + "offsets" + ).check("JaggedTensor").run(gm.code) + + ref_out = m(input) + traced_out = gm(input) + self.assertTrue(torch.equal(traced_out.values(), ref_out.values())) + self.assertTrue(torch.equal(traced_out.weights(), ref_out.weights())) + self.assertTrue(torch.equal(traced_out.lengths(), ref_out.lengths())) + + # Case 4: JaggedTensor is only used within the root module and not as part of + # the root module's input/output interface. + m = ModuleCreateAndAccessJaggedTensor() + gm = symbolic_trace(m) + FileCheck().check("return 29").check_not("JaggedTensor").run(gm.code) + ref_out = m(8) + traced_out = gm(8) + self.assertEqual(ref_out, traced_out) + + def test_keyed_jagged_tensor(self) -> None: + class ModuleCreateAndAccessKeyedJaggedTensor(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input: int) -> int: + features = KeyedJaggedTensor.from_offsets_sync( + values=torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]), + weights=torch.Tensor([1.0, 0.5, 1.5, 1.0, 0.5, 1.0, 1.0, 1.5]), + keys=["index_0", "index_1"], + offsets=torch.IntTensor([0, 0, 2, 2, 3, 4, 5, 5, 8]), + ) + return ( + len(features.keys()) + + features.values().numel() + + features.weights().numel() + + features.lengths().numel() + + features.offsets().numel() + ) + + class ModuleUseKeyedJaggedTensorAsInputAndOutput(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward( + self, input: KeyedJaggedTensor + ) -> Tuple[KeyedJaggedTensor, int]: + output = KeyedJaggedTensor( + input.keys(), + input.values(), + input.weights(), + lengths=input.lengths(), + offsets=input.offsets(), + ) + return output, output._stride + + class ModuleUseKeyedJaggedTensorAsInput(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input: KeyedJaggedTensor) -> int: + return ( + len(input.keys()) + + input.values().numel() + + input.weights().numel() + + input.lengths().numel() + + input.offsets().numel() + ) + + class ModuleUseKeyedJaggedTensorAsOutput(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward( + self, + keys: List[str], + values: torch.Tensor, + weights: torch.Tensor, + lengths: torch.Tensor, + ) -> Tuple[KeyedJaggedTensor, int]: + output = KeyedJaggedTensor(keys, values, weights, lengths) + return output, output._stride + + # Case 1: KeyedJaggedTensor is only used as an output of the root module. + m = ModuleUseKeyedJaggedTensorAsOutput() + gm = symbolic_trace(m) + FileCheck().check("KeyedJaggedTensor").check( + "return (keyed_jagged_tensor," + ).run(gm.code) + + values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]) + weights = torch.Tensor([1.0, 0.5, 1.5, 1.0, 0.5, 1.0, 1.0, 1.5]) + keys = ["index_0", "index_1"] + lengths = torch.IntTensor([2, 0, 1, 1, 1, 3]) + + ref_out = m(keys, values, weights, lengths) + traced_out = gm(keys, values, weights, lengths) + + self.assertEqual(ref_out[1], traced_out[1]) + self.assertTrue(torch.equal(traced_out[0].offsets(), ref_out[0].offsets())) + + # Case 2: KeyedJaggedTensor is only used as an input of the root module. + m = ModuleUseKeyedJaggedTensorAsInput() + gm = symbolic_trace(m) + FileCheck().check("KeyedJaggedTensor").check("keys()").check("len").check( + "values()" + ).check("numel()").run(gm.code) + + input = KeyedJaggedTensor.from_offsets_sync( + values=torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]), + weights=torch.Tensor([1.0, 0.5, 1.5, 1.0, 0.5, 1.0, 1.0, 1.5]), + keys=["index_0", "index_1"], + offsets=torch.IntTensor([0, 0, 2, 2, 3, 4, 5, 5, 8]), + ) + ref_out = m(input) + traced_out = gm(input) + self.assertEqual(ref_out, traced_out) + + # Case 3: KeyedJaggedTensor is used as both an input and an output of the root module. + m = ModuleUseKeyedJaggedTensorAsInputAndOutput() + gm = symbolic_trace(m) + FileCheck().check("KeyedJaggedTensor").check("keys()").check("values()").check( + "._stride" + ).run(gm.code) + input = KeyedJaggedTensor.from_offsets_sync( + values=torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]), + weights=torch.Tensor([1.0, 0.5, 1.5, 1.0, 0.5, 1.0, 1.0, 1.5]), + keys=["index_0", "index_1"], + offsets=torch.IntTensor([0, 0, 2, 2, 3, 4, 5, 5, 8]), + ) + ref_out = m(input) + traced_out = gm(input) + self.assertEqual(ref_out[1], traced_out[1]) + + # Case 4: KeyedJaggedTensor is only used within the root module and not as part of + # the root module's input/output interface. + m = ModuleCreateAndAccessKeyedJaggedTensor() + gm = symbolic_trace(m) + FileCheck().check("return 35").check_not("KeyedJaggedTensor").run(gm.code) + ref_out = m(8) + traced_out = gm(8) + self.assertEqual(ref_out, traced_out) + + def test_trace_async_module(self) -> None: + class NeedWait(LazyAwaitable[torch.Tensor]): + def __init__(self, obj: torch.Tensor) -> None: + super().__init__() + self._obj = obj + + def wait(self) -> torch.Tensor: + return self._obj + 3 + + class MyAsyncModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input) -> LazyAwaitable[torch.Tensor]: + return NeedWait(input + 2) + + # Test automated LazyAwaitable type `wait()` + class AutoModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.sparse = MyAsyncModule() + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return torch.add(self.sparse(input), input * 10) + + auto_model = AutoModel() + auto_gm = symbolic_trace(auto_model) + FileCheck().check("+ 2").check("NeedWait").check("* 10").run(auto_gm.code) + + input = torch.randn(3, 4) + ref_out = auto_model(input) + traced_out = auto_gm(input) + self.assertTrue(torch.equal(ref_out, traced_out)) diff --git a/torchrec/fx/tracer.py b/torchrec/fx/tracer.py new file mode 100644 index 000000000..9d190300b --- /dev/null +++ b/torchrec/fx/tracer.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 + +from typing import Any, Optional, Dict, Union, Callable + +import torch +from torch.fx.node import Argument +from torchrec.distributed.types import NoWait + + +class Tracer(torch.fx.Tracer): + """ + NOTE [ Custom FX tracer for torchrec ] + + We create a custom FX tracer to trace torchrec based models. The custom tracer + right now have several purposes (the list might expand if we have more use cases): + 1. Handling python generic types (i.e. NoWait[T], Awaitable[T]) and lower it to + TorchScript if needed + """ + + def __init__(self) -> None: + super().__init__() + + # pyre-ignore[2] + def create_arg(self, a: Any) -> Argument: + if isinstance(a, NoWait): + return self.create_node( + "call_function", + target=NoWait, + args=self.create_arg((a._obj,)), + kwargs={}, + type_expr=NoWait, + ) + return super().create_arg(a) + + +def symbolic_trace( + # pyre-ignore[24] + root: Union[torch.nn.Module, Callable], + concrete_args: Optional[Dict[str, Any]] = None, +) -> torch.fx.GraphModule: + + tracer = Tracer() + graph = tracer.trace(root, concrete_args) + return torch.fx.GraphModule(root, graph) diff --git a/torchrec/linter/module_linter.py b/torchrec/linter/module_linter.py new file mode 100644 index 000000000..37d52b0b2 --- /dev/null +++ b/torchrec/linter/module_linter.py @@ -0,0 +1,276 @@ +#!/usr/bin/env python3 + +import ast +import json +from argparse import ArgumentParser, Namespace +from typing import Any, Dict, List, Optional, Tuple + + +MAX_NUM_ARGS_IN_MODULE_CTOR = 5 + + +def print_error_message( + python_path: str, node: ast.AST, name: str, message: str, severity: str = "warning" +) -> None: + """ + This function will print linter error in a format that is compatible with + our internal tools. + + Args: + python_path: Path to the file with the error + node: AST node describing snippet of code + name: Name of the linter error + message: Error message to show to user + + Optional Args: + severity: How severe should be considered the error. Default level: 'error' + + Returns: + None + """ + lint_item = { + "path": python_path, + "line": node.lineno, + "char": node.col_offset + 1, + "severity": severity, + "name": name, + "description": message, + } + print(json.dumps(lint_item)) + + +# pyre-ignore[3]: Return type must be specified as type that does not contain `Any`. +def get_function_args(node: ast.FunctionDef) -> Tuple[List[Any], List[Any]]: + """ + This functon will process function definition and will extract all + arguments used by a given function and return all optional and non-optional + args used by the function. + + Args: + node: Function node containing function that needs to be analyzed + + Returns: + (non_optional_args, optional_args): named function args + """ + assert ( + type(node) == ast.FunctionDef + ), "Incorrect node type. Expected ast.FunctionDef, got {}".format(type(node)) + total_args = len(node.args.args) + default_args = len(node.args.defaults) + + optional_args = [] + non_optional_args = [] + # Handle positional args + for i in range(total_args): + if i + default_args < total_args: + non_optional_args.append(node.args.args[i].arg) + else: + optional_args.append(node.args.args[i].arg) + + # Handle named args + for arg in node.args.kwonlyargs: + optional_args.append(arg.arg) + + return non_optional_args, optional_args + + +def check_class_definition(python_path: str, node: ast.ClassDef) -> None: + """ + This function will run set of sanity checks against class definitions + and their docstrings. + + Args: + python_path: Path to the file that is getting checked + node: AST node with the ClassDef that needs to be checked + + Returns: + None + """ + assert ( + type(node) == ast.ClassDef + ), "Received invalid node type. Expected ClassDef, got: {}".format(type(node)) + + is_TorchRec_module = False + is_test_file = "tests" in python_path + for base in node.bases: + # For now only names and attributes are supported + if type(base) != ast.Name and type(base) != ast.Attribute: # pragma: nocover + continue + + # We assume that TorchRec module has one of the following inheritance patterns: + # 1. `class SomeTorchRecModule(LazyModuleExtensionMixin, torch.nn.Module)` + # 2. `class SomeTorchRecModule(torch.nn.Module)` + # pyre-ignore[16]: `_ast.expr` has no attribute `id`. + if hasattr(base, "id") and base.id == "LazyModuleExtensionMixin": + is_TorchRec_module = True + break + # pyre-ignore[16]: `_ast.expr` has no attribute `id`. + elif hasattr(base, "attr") and base.attr == "Module": + is_TorchRec_module = True + break + + if not is_TorchRec_module or is_test_file: + return + + docstring: Optional[str] = ast.get_docstring(node) + if docstring is None: + print_error_message( + python_path, + node, + "No docstring found in a TorchRec module", + "TorchRec modules are required to have a docstring describing how " + "to use them. Given Module don't have a docstring, please fix this.", + ) + return + + # Check presence of the example: + if "Example:" not in docstring or ">>> " not in docstring: + print_error_message( + python_path, + node, + "No runnable example in a TorchRec module", + "TorchRec modules are required to have runnable examples in " + '"Example:" section, that start from ">>> ". Please fix the docstring', + ) + + # Check correctness of the Args for a class definition: + required_keywords = ["Constructor Args:", "Call Args:", "Returns:"] + missing_keywords = [] + for keyword in required_keywords: + if keyword not in docstring: + missing_keywords.append(keyword) + + if len(missing_keywords) > 0: + print_error_message( + python_path, + node, + "Missing required keywords from TorchRec module", + "TorchRec modules are required to description of their args and " + 'results in "Constructor Args:", "Call Args:", "Returns:". ' + "Missing keywords: {}.".format(missing_keywords), + ) + + # Check actual args from the functions + # pyre-ignore[33]: Explicit annotation for `functions` cannot contain `Any`. + functions: Dict[str, Tuple[List[Any], List[Any]]] = {} + for sub_node in node.body: + if type(sub_node) == ast.FunctionDef: + assert isinstance(sub_node, ast.FunctionDef) + functions[sub_node.name] = get_function_args(sub_node) + + def check_function(function_name: str) -> None: + if function_name not in functions: + return + + if function_name == "__init__": + # NOTE: -1 to not count the `self` argument. + num_args = sum([len(args) for args in functions[function_name]]) - 1 + if num_args > MAX_NUM_ARGS_IN_MODULE_CTOR: + print_error_message( + python_path, + node, + "TorchRec module has too many constructor arguments", + "TorchRec module can have at most {} constructor arguments, but this module has {}.".format( + MAX_NUM_ARGS_IN_MODULE_CTOR, + len(functions[function_name][1]), + ), + ) + if function_name in functions: + missing_required_args = [] + missing_optional_args = [] + for arg in functions[function_name][0]: + # Ignore checks for required self and net args + if arg == "self" or arg == "net": + continue + assert docstring is not None + if arg not in docstring: + missing_required_args.append(arg) + for arg in functions[function_name][1]: + assert docstring is not None + if arg not in docstring: + missing_optional_args.append(arg) + if len(missing_required_args) > 0 or len(missing_optional_args) > 0: + print_error_message( + python_path, + node, + "Missing docstring descriptions for {} function arguments.".format( + function_name + ), + ( + "Missing descriptions for {} function arguments. " + "Missing required args: {}, missing optional args: {}" + ).format( + function_name, missing_required_args, missing_optional_args + ), + ) + + check_function("__init__") + check_function("forward") + + +def read_file(path: str) -> str: # pragma: nocover + """ + This function simply reads contents of the file. It's moved out to a function + purely to simplify testing process. + + Args: + path: File to read. + + Returns: + content(str): Content of given file. + """ + return open(path).read() + + +def linter_one_file(python_path: str) -> None: + """ + This function will check all Modules defined in the given file for a valid + documentation based on the AST. + + Input args: + python_path: Path to the file that need to be verified with the linter. + + Returns: + None + """ + python_path = python_path.strip() + try: + for node in ast.parse(read_file(python_path)).body: + if type(node) == ast.ClassDef: + assert isinstance(node, ast.ClassDef) + check_class_definition(python_path, node) + except SyntaxError as e: # pragma: nocover + # possible failing due to file parsing error + lint_item = { + "path": python_path, + "line": e.lineno, + "char": e.offset, + "severity": "warning", + "name": "syntax-error", + "description": ( + f"There is a linter parser error with message: {e.msg}. " + "Please report the diff to torchrec oncall" + ), + "bypassChangedLineFiltering": True, + } + print(json.dumps(lint_item)) + + +def _make_argparse() -> ArgumentParser: # pragma: nocover + parser = ArgumentParser( + description="TorchRec docstring linter", fromfile_prefix_chars="@" + ) + parser.add_argument("source_files", nargs="+", help="Path to python source files") + + return parser + + +def _parse_args() -> Namespace: # pragma: nocover + ap = _make_argparse() + return ap.parse_args() + + +if __name__ == "__main__": # pragma: nocover + args: Namespace = _parse_args() + for filename in args.source_files: + linter_one_file(filename) diff --git a/torchrec/linter/tests/test_module_linter.py b/torchrec/linter/tests/test_module_linter.py new file mode 100644 index 000000000..5b25ed267 --- /dev/null +++ b/torchrec/linter/tests/test_module_linter.py @@ -0,0 +1,283 @@ +#!/usr/bin/env python3 + +import unittest +from unittest.mock import patch + +import hypothesis.strategies as st +from hypothesis import given +from torchrec.linter import module_linter + + +def populate_parent_class_list( + class_src: str, + uses_LazyModuleExtensionMixin: bool, +) -> str: + if uses_LazyModuleExtensionMixin: + parent_class_list = "LazyModuleExtensionMixin, torch.nn.Module" + else: + parent_class_list = "torch.nn.Module" + return class_src.replace("${parent_class_list}", parent_class_list) + + +class DocStringLinterTest(unittest.TestCase): + def test_docstring_empty(self) -> None: + src = "" + with patch("builtins.print") as p, patch( + "torchrec.linter.module_linter.read_file", return_value=src + ): + module_linter.linter_one_file("a") + + self.assertEqual(p.call_count, 0) + + def test_docstring_no_modules(self) -> None: + src = """ +class A: + pass + """ + with patch("builtins.print") as p, patch( + "torchrec.linter.module_linter.read_file", return_value=src + ): + module_linter.linter_one_file("a") + + self.assertEqual(p.call_count, 0) + + # pyre-ignore[56]: Pyre was not able to infer the type of argument + # `hypothesis.strategies.booleans()` to decorator factory `hypothesis.given`. + @given(uses_LazyModuleExtensionMixin=st.booleans()) + def test_docstring_no_docstring(self, uses_LazyModuleExtensionMixin: bool) -> None: + src = """ +class F(${parent_class_list}): + def __init__(self): + pass + """ + src = populate_parent_class_list(src, uses_LazyModuleExtensionMixin) + with patch("builtins.print") as p, patch( + "torchrec.linter.module_linter.read_file", return_value=src + ): + module_linter.linter_one_file("a") + + self.assertEqual(p.call_count, 1) + self.assertTrue( + "No docstring found in a TorchRec module" in p.call_args_list[0][0][0] + ) + + # pyre-ignore[56]: Pyre was not able to infer the type of argument + # `hypothesis.strategies.booleans()` to decorator factory `hypothesis.given`. + @given(uses_LazyModuleExtensionMixin=st.booleans()) + def test_docstring_no_module_init( + self, uses_LazyModuleExtensionMixin: bool + ) -> None: + src = """ +class F(${parent_class_list}): + \""" + \""" + def forward(self, net, z): + pass + """ + src = populate_parent_class_list(src, uses_LazyModuleExtensionMixin) + with patch("builtins.print") as p, patch( + "torchrec.linter.module_linter.read_file", return_value=src + ): + module_linter.linter_one_file("a") + + self.assertEqual(p.call_count, 3) + self.assertTrue( + "No runnable example in a TorchRec module" in p.call_args_list[0][0][0] + ) + self.assertTrue( + "Missing required keywords from TorchRec module" + in p.call_args_list[1][0][0] + ) + self.assertTrue("Missing docstring descriptions" in p.call_args_list[2][0][0]) + self.assertTrue("['z']" in p.call_args_list[2][0][0]) + + # pyre-ignore[56]: Pyre was not able to infer the type of argument + # `hypothesis.strategies.booleans()` to decorator factory `hypothesis.given`. + @given(uses_LazyModuleExtensionMixin=st.booleans()) + def test_missing_args(self, uses_LazyModuleExtensionMixin: bool) -> None: + src = """ +class F(${parent_class_list}): + \""" + \""" + def __init__(self, x, y='a', *arg, k='k'): + pass + + def forward(self, net, z): + pass + """ + src = populate_parent_class_list(src, uses_LazyModuleExtensionMixin) + with patch("builtins.print") as p, patch( + "torchrec.linter.module_linter.read_file", return_value=src + ): + module_linter.linter_one_file("a") + + self.assertEqual(p.call_count, 4) + self.assertTrue( + "No runnable example in a TorchRec module" in p.call_args_list[0][0][0] + ) + self.assertTrue( + "Missing required keywords from TorchRec module" + in p.call_args_list[1][0][0] + ) + self.assertTrue("Missing docstring descriptions" in p.call_args_list[2][0][0]) + self.assertTrue("['x']" in p.call_args_list[2][0][0]) + self.assertTrue("['y', 'k']" in p.call_args_list[2][0][0]) + self.assertTrue("Missing docstring descriptions" in p.call_args_list[3][0][0]) + self.assertTrue("['z']" in p.call_args_list[3][0][0]) + + # pyre-ignore[56]: Pyre was not able to infer the type of argument + # `hypothesis.strategies.booleans()` to decorator factory `hypothesis.given`. + @given(uses_LazyModuleExtensionMixin=st.booleans()) + def test_valid_module(self, uses_LazyModuleExtensionMixin: bool) -> None: + src = """ +class F(${parent_class_list}): + \""" + Blah. + + Constructor Args: + x: Blah + y: Blah. Default: "a" + + Call Args: + z: Blah + + Returns: + None + + Example: + >>> pass + \""" + def __init__(self, x, y='a'): + pass + + def forward(self, z): + pass + """ + src = populate_parent_class_list(src, uses_LazyModuleExtensionMixin) + with patch("builtins.print") as p, patch( + "torchrec.linter.module_linter.read_file", return_value=src + ): + module_linter.linter_one_file("a") + + self.assertEqual(p.call_count, 0) + + # pyre-ignore[56]: Pyre was not able to infer the type of argument + # `hypothesis.strategies.booleans()` to decorator factory `hypothesis.given`. + @given(uses_LazyModuleExtensionMixin=st.booleans()) + def test_num_ctor_args(self, uses_LazyModuleExtensionMixin: bool) -> None: + # Case 1: TorchRec module has less than 5 ctor args -> pass + src = """ +class F(${parent_class_list}): + \""" + Blah. + + Constructor Args: + a: Blah + b: Blah + c: Blah + d: Blah + e: Blah. Default: "e". + + Call Args: + z: Blah + + Returns: + None + + Example: + >>> pass + \""" + def __init__(self, a, b, c, d, e='e'): + pass + + def forward(self, z): + pass + """ + src = populate_parent_class_list(src, uses_LazyModuleExtensionMixin) + with patch("builtins.print") as p, patch( + "torchrec.linter.module_linter.read_file", return_value=src + ): + module_linter.linter_one_file("a") + + self.assertEqual(p.call_count, 0) + + # Case 2: TorchRec module has more than 5 ctor args -> print error + src = """ +class F(${parent_class_list}): + \""" + Blah. + + Constructor Args: + a: Blah + b: Blah + c: Blah + d: Blah + e: Blah + f: Blah. Default: "f". + + Call Args: + z: Blah + + Returns: + None + + Example: + >>> pass + \""" + def __init__(self, a, b, c, d, e, f='f'): + pass + + def forward(self, z): + pass + """ + src = populate_parent_class_list(src, uses_LazyModuleExtensionMixin) + with patch("builtins.print") as p, patch( + "torchrec.linter.module_linter.read_file", return_value=src + ): + module_linter.linter_one_file("a") + + self.assertEqual(p.call_count, 1) + self.assertTrue( + "TorchRec module has too many constructor arguments" + in p.call_args_list[0][0][0] + ) + + # Case 3: not a TorchRec module -> pass + src = """ +class F: + \""" + Blah. + + Constructor Args: + a: Blah + b: Blah + c: Blah + d: Blah + e: Blah + f: Blah. Default: "f". + + Call Args: + z: Blah + + Returns: + None + + Example: + >>> pass + \""" + def __init__(self, a, b, c, d, e, f='f'): + pass + + def forward(self, z): + pass + """ + with patch("builtins.print") as p, patch( + "torchrec.linter.module_linter.read_file", return_value=src + ): + module_linter.linter_one_file("a") + + self.assertEqual(p.call_count, 0) + + +if __name__ == "__main__": + unittest.main() diff --git a/torchrec/models/__init__.py b/torchrec/models/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/torchrec/models/deepfm.py b/torchrec/models/deepfm.py new file mode 100644 index 000000000..c62536374 --- /dev/null +++ b/torchrec/models/deepfm.py @@ -0,0 +1,325 @@ +#!/usr/bin/env python3 + +from typing import List + +import torch +from torch import nn +from torchrec import EmbeddingBagCollection, KeyedJaggedTensor +from torchrec.modules.deepfm import DeepFM, FactorizationMachine +from torchrec.sparse.jagged_tensor import KeyedTensor + + +class SparseArch(nn.Module): + """ + Processes the Sparse Features of SparseNN. Does Embedding Lookup for all + EmbeddingBag and Embedding features of each collection. + + Constructor Args: + embedding_bag_collection: EmbeddingBagCollection, + + Call Args: + features: KeyedJaggedTensor, + + Returns: + KeyedJaggedTensor - size F * D X B + + Example: + >>> eb1_config = EmbeddingBagConfig( + name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1"] + ) + eb2_config = EmbeddingBagConfig( + name="t2", embedding_dim=4, num_embeddings=10, feature_names=["f2"] + ) + ebc_config = EmbeddingBagCollectionConfig(tables=[eb1_config, eb2_config]) + + ebc = EmbeddingBagCollection(config=ebc_config) + + # 0 1 2 <-- batch + # 0 [0,1] None [2] + # 1 [3] [4] [5,6,7] + # ^ + # feature + features = KeyedJaggedTensor.from_offsets_sync( + keys=["f1", "f2"], + values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]), + offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]), + ) + + sparse_arch(features) + """ + + def __init__(self, embedding_bag_collection: EmbeddingBagCollection) -> None: + super().__init__() + self.embedding_bag_collection: EmbeddingBagCollection = embedding_bag_collection + + def forward( + self, + features: KeyedJaggedTensor, + ) -> KeyedTensor: + return self.embedding_bag_collection(features) + + +class DenseArch(nn.Module): + """ + Processes the dense features of DeepFMNN model. Output layer is sized to + the embedding_dimension of the EmbeddingBagCollection embeddings + + Constructor Args: + in_features: int + hidden_layer_size: int + embedding_dim: int - the same size of the embedding_dimension of sparseArch + device: torch.device + + Call Args: + features: torch.Tensor - size B X num_features + + Returns: + torch.Tensor - size B X D + + Example: + >>> B = 20 + >>> D = 3 + >>> in_features = 10 + >>> dense_arch = DenseArch(in_features=10, hidden_layer_size=10, embedding_dim=D) + >>> dense_embedded = dense_arch(torch.rand((B, 10))) + """ + + def __init__( + self, + in_features: int, + hidden_layer_size: int, + embedding_dim: int, + ) -> None: + super().__init__() + self.model: nn.Module = nn.Sequential( + nn.Linear(in_features, hidden_layer_size), + nn.ReLU(), + nn.Linear(hidden_layer_size, embedding_dim), + nn.ReLU(), + ) + + def forward(self, features: torch.Tensor) -> torch.Tensor: + return self.model(features) + + +class FMInteractionArch(nn.Module): + """ + Processes the output of both SparseArch (sparse_features) and DenseArch + (dense_features) and apply the general DeepFM interaction according to the + extenal source of DeepFM paper: https://arxiv.org/pdf/1703.04247.pdf + + The output dimension is expected to be a cat of dense_features, D + + Constructor Args: + fm_in_features: int - the input dimension of dense_module in DeepFM. For example, + the input embeddings is [randn(3, 2, 3), randn(3, 4, 5)], the fm_in_features should + be: 2*3+4*5. + sparse_feature_names: List[str] - length of F + deep_fm_dimension: int - output of the deep interaction (DI) in the DeepFM arch. + + Call Args: + dense_features: torch.Tensor - size B X D + sparse_features: KeyedJaggedTensor - size F * D X B + + Returns: + torch.Tensor - B X (D + DI + 1) + + Example: + >>> D = 3 + >>> B = 10 + >>> keys = ["f1", "f2"] + >>> F = len(keys) + >>> fm_inter_arch = FMInteractionArch(sparse_feature_names=keys) + >>> dense_features = torch.rand((B, D)) + >>> sparse_features = KeyedTensor( + >>> keys=keys, + >>> length_per_key=[D, D], + >>> values=torch.rand((B, D * F)), + >>> ) + >>> cat_fm_output = fm_inter_arch(dense_features, sparse_features) + """ + + def __init__( + self, + fm_in_features: int, + sparse_feature_names: List[str], + deep_fm_dimension: int, + ) -> None: + super().__init__() + self.sparse_feature_names: List[str] = sparse_feature_names + self.deep_fm = DeepFM( + dense_module=nn.Sequential( + nn.Linear(fm_in_features, deep_fm_dimension), + nn.ReLU(), + ) + ) + self.fm = FactorizationMachine() + + def forward( + self, dense_features: torch.Tensor, sparse_features: KeyedTensor + ) -> torch.Tensor: + if len(self.sparse_feature_names) == 0: + return dense_features + + tensor_list: List[torch.Tensor] = [dense_features] + # dense/sparse interaction + # size B X F + for feature_name in self.sparse_feature_names: + tensor_list.append(sparse_features[feature_name]) + + deep_interaction = self.deep_fm(tensor_list) + fm_interaction = self.fm(tensor_list) + + return torch.cat([dense_features, deep_interaction, fm_interaction], dim=1) + + +class OverArch(nn.Module): + r""" + Final Arch - simple MLP. The output is just one target + + Constructor Args: + in_features: int - the output dimension of interaction arch + + Call Args: + features: torch.Tensor + + Returns: + torch.Tensor - size B X 1 + + Example: + >>> B = 20 + >>> over_arch = OverArch() + >>> logits = over_arch(torch.rand((B, 10))) + """ + + def __init__(self, in_features: int) -> None: + super().__init__() + self.model: nn.Module = nn.Sequential( + nn.Linear(in_features, 1), + nn.Sigmoid(), + ) + + def forward(self, features: torch.Tensor) -> torch.Tensor: + return self.model(features) + + +class SimpleDeepFMNN(nn.Module): + r""" + Basic recsys module with DeepFM arch. Processes sparse features by + learning pooled embeddings for each feature. Learns the relationship between + dense features and sparse features by projecting dense features into the same + embedding space. Learns the interaction among those dense and sparse features + by deep_fm proposed in this paper: https://arxiv.org/pdf/1703.04247.pdf + + The module assumes all sparse features have the same embedding dimension + (i.e, each EmbeddingBagConfig uses the same embedding_dim) + + Constructor Args: + num_dense_features: int - the number of input dense features. + embedding_bag_collection: EmbeddingBagCollection, + hidden_layer_size: int, the hidden layer size that used in dense module + deep_fm_dimension: int, the output layer size that used in deep_fm's deep + interaction module + + + Call Args: + dense_features: torch.Tensodr, + sparse_features: KeyedJaggedTensor, + + Returns: + torch.Tensor - logits with size B X 1 + + Example: + >>> B = 2 + D = 8 + + eb1_config = EmbeddingBagConfig( + name="t1", embedding_dim=D, num_embeddings=100, feature_names=["f1", "f3"] + ) + eb2_config = EmbeddingBagConfig( + name="t2", + embedding_dim=D, + num_embeddings=100, + feature_names=["f2"], + ) + ebc_config = EmbeddingBagCollectionConfig(tables=[eb1_config, eb2_config]) + + ebc = EmbeddingBagCollection(config=ebc_config) + sparse_nn = SimpleDeepFMNN( + embedding_bag_collection=ebc, hidden_layer_size=20, over_embedding_dim=5 + ) + + features = torch.rand((B, 100)) + + # 0 1 + # 0 [1,2] [4,5] + # 1 [4,3] [2,9] + # ^ + # feature + sparse_features = KeyedJaggedTensor.from_offsets_sync( + keys=["f1", "f3"], + values=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9]), + offsets=torch.tensor([0, 2, 4, 6, 8]), + ) + + logits = sparse_nn( + dense_features=features, + sparse_features=sparse_features, + ) + """ + + def __init__( + self, + num_dense_features: int, + embedding_bag_collection: EmbeddingBagCollection, + hidden_layer_size: int, + deep_fm_dimension: int, + ) -> None: + super().__init__() + assert ( + len(embedding_bag_collection.embedding_bag_configs) > 0 + ), "At least one embedding bag is required" + for i in range(1, len(embedding_bag_collection.embedding_bag_configs)): + conf_prev = embedding_bag_collection.embedding_bag_configs[i - 1] + conf = embedding_bag_collection.embedding_bag_configs[i] + assert ( + conf_prev.embedding_dim == conf.embedding_dim + ), "All EmbeddingBagConfigs must have the same dimension" + embedding_dim: int = embedding_bag_collection.embedding_bag_configs[ + 0 + ].embedding_dim + + feature_names = [] + + fm_in_features = embedding_dim + for conf in embedding_bag_collection.embedding_bag_configs: + for feat in conf.feature_names: + feature_names.append(feat) + fm_in_features += conf.embedding_dim + + self.sparse_arch = SparseArch(embedding_bag_collection) + self.dense_arch = DenseArch( + in_features=num_dense_features, + hidden_layer_size=hidden_layer_size, + embedding_dim=embedding_dim, + ) + self.inter_arch = FMInteractionArch( + fm_in_features=fm_in_features, + sparse_feature_names=feature_names, + deep_fm_dimension=deep_fm_dimension, + ) + over_in_features = embedding_dim + deep_fm_dimension + 1 + self.over_arch = OverArch(over_in_features) + + def forward( + self, + dense_features: torch.Tensor, + sparse_features: KeyedJaggedTensor, + ) -> torch.Tensor: + embedded_dense = self.dense_arch(dense_features) + embedded_sparse = self.sparse_arch(sparse_features) + concatenated_dense = self.inter_arch( + dense_features=embedded_dense, sparse_features=embedded_sparse + ) + logits = self.over_arch(concatenated_dense) + return logits diff --git a/torchrec/models/dlrm.py b/torchrec/models/dlrm.py new file mode 100644 index 000000000..4b99460e9 --- /dev/null +++ b/torchrec/models/dlrm.py @@ -0,0 +1,360 @@ +#!/usr/bin/env python3 + +from math import comb +from typing import List, Optional + +import torch +from torch import nn +from torchrec.modules.embedding_modules import EmbeddingBagCollection +from torchrec.modules.mlp import MLP +from torchrec.sparse.jagged_tensor import ( + KeyedJaggedTensor, + KeyedTensor, +) + + +""" +Notations uses throughout: + +F: number of sparseFeatures +D: embedding_dimension of sparse features +B: batch_size +num_features: number of dense features + +""" + + +class SparseArch(nn.Module): + """ + Processes the Sparse Features of SparseNN. Does Embedding Lookup for all + EmbeddingBag and Embedding features of each collection. + + Constructor Args: + embedding_bag_collection: EmbeddingBagCollection, + + Call Args: + features: KeyedJaggedTensor, + + Returns: + KeyedJaggedTensor - size F * D X B + + Example: + >>> eb1_config = EmbeddingBagConfig( + name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1"] + ) + eb2_config = EmbeddingBagConfig( + name="t2", embedding_dim=4, num_embeddings=10, feature_names=["f2"] + ) + ebc_config = EmbeddingBagCollectionConfig(tables=[eb1_config, eb2_config]) + + ebc = EmbeddingBagCollection(config=ebc_config) + + # 0 1 2 <-- batch + # 0 [0,1] None [2] + # 1 [3] [4] [5,6,7] + # ^ + # feature + features = KeyedJaggedTensor.from_offsets_sync( + keys=["f1", "f2"], + values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]), + offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]), + ) + + sparse_arch(features) + """ + + def __init__(self, embedding_bag_collection: EmbeddingBagCollection) -> None: + super().__init__() + self.embedding_bag_collection: EmbeddingBagCollection = embedding_bag_collection + + def forward( + self, + features: KeyedJaggedTensor, + ) -> KeyedTensor: + return self.embedding_bag_collection(features) + + +class DenseArch(nn.Module): + """ + Processes the dense features of SparseNN model. + + Constructor Args: + in_features: int - size of the input. + layer_sizes: List[int] - list of layer sizes. + device: (Optional[torch.device]). + + Call Args: + features: torch.Tensor - size B X num_features + + Returns: + torch.Tensor - size B X D + + Example: + >>> B = 20 + D = 3 + dense_arch = DenseArch(10, layer_sizes=[15, D]) + dense_embedded = dense_arch(torch.rand((B, 10))) + """ + + def __init__( + self, + in_features: int, + layer_sizes: List[int], + device: Optional[torch.device] = None, + ) -> None: + super().__init__() + self.model: nn.Module = MLP( + in_features, layer_sizes, bias=True, activation="relu", device=device + ) + + def forward(self, features: torch.Tensor) -> torch.Tensor: + return self.model(features) + + +class InteractionArch(nn.Module): + """ + Processes the output of both SparseArch (sparse_features) and DenseArch + (dense_features). Returns the pairwise dot product of each sparse feature pair, + the dot product of each sparse features with the output of the dense layer, + and the dense layer itself (all concatenated). + + NOTE: The dimensionality of the dense_features (D) is expected to match the + dimensionality of the sparse_features so that the dot products between them can be + computed. + + Constructor Args: + num_sparse_features: int - size F + + Call Args: + dense_features: torch.Tensor - size B X D + sparse_features: KeyedJaggedTensor - size F * D X B + + Returns: + torch.Tensor - B X (D + F + F choose 2) + + Example: + >>> D = 3 + B = 10 + keys = ["f1", "f2"] + F = len(keys) + inter_arch = InteractionArch(num_sparse_features=F) + + dense_features = torch.rand((B, D)) + + sparse_features = KeyedTensor( + keys=keys, + length_per_key=[D, D], + values=torch.rand((B, D * F)), + ) + + # B X (D + F + F choose 2) + concat_dense = inter_arch(dense_features, sparse_features) + """ + + def __init__(self, num_sparse_features: int) -> None: + super().__init__() + self.F = num_sparse_features + self.triu_indices: torch.Tensor = torch.triu_indices( + self.F + 1, self.F + 1, offset=1 + ) + + def forward( + self, dense_features: torch.Tensor, sparse_features: KeyedTensor + ) -> torch.Tensor: + if self.F <= 0: + return dense_features + (B, D) = dense_features.shape + + sparse_values = sparse_features.values().reshape(B, self.F, D) + combined_values = torch.cat((dense_features.unsqueeze(1), sparse_values), dim=1) + + # dense/sparse + sparse/sparse interaction + # size B X (F + F choose 2) + interactions = torch.bmm( + combined_values, torch.transpose(combined_values, 1, 2) + ) + interactions_flat = interactions[:, self.triu_indices[0], self.triu_indices[1]] + + return torch.cat((dense_features, interactions_flat), dim=1) + + +class OverArch(nn.Module): + """ + Final Arch of SparseNN - simple MLP over OverArch + + Constructor Args: + in_features: int + layer_sizes: list[int] + device: (Optional[torch.device]). + + Call Args: + features: torch.Tensor + + Returns: + torch.Tensor - size B X layer_sizes[-1] + + Example: + >>> B = 20 + D = 3 + over_arch = OverArch(10, [5, 1]) + logits = over_arch(torch.rand((B, 10))) + """ + + def __init__( + self, + in_features: int, + layer_sizes: List[int], + device: Optional[torch.device] = None, + ) -> None: + super().__init__() + if len(layer_sizes) <= 1: + raise ValueError("OverArch must have multiple layers.") + self.model: nn.Module = nn.Sequential( + MLP( + in_features, + layer_sizes[:-1], + bias=True, + activation="relu", + device=device, + ), + nn.Linear(layer_sizes[-2], layer_sizes[-1], bias=True, device=device), + ) + + def forward(self, features: torch.Tensor) -> torch.Tensor: + return self.model(features) + + +class DLRM(nn.Module): + """ + Recsys model from "Deep Learning Recommendation Model for Personalization and + Recommendation Systems" (https://arxiv.org/abs/1906.00091). Processes sparse + features by learning pooled embeddings for each feature. Learns the relationship + between dense features and sparse features by projecting dense features into the + same embedding space. Also, learns the pairwise relationships between sparse + features. + + The module assumes all sparse features have the same embedding dimension + (i.e, each EmbeddingBagConfig uses the same embedding_dim) + + Constructor Args: + embedding_bag_collection (EmbeddingBagCollection): collection of embedding bags + used to define SparseArch. + dense_in_features (int): the dimensionality of the dense input features. + dense_arch_layer_sizes (list[int]): the layer sizes for the DenseArch. + over_arch_layer_sizes (list[int]): the layer sizes for the OverArch. NOTE: The + output dimension of the InteractionArch should not be manually specified + here. + dense_device: (Optional[torch.device]). + + Call Args: + dense_features: torch.Tensor, + sparse_features: KeyedJaggedTensor, + + Returns: + torch.Tensor - logits with size B X 1 + + Example: + >>> B = 2 + D = 8 + + eb1_config = EmbeddingBagConfig( + name="t1", embedding_dim=D, num_embeddings=100, feature_names=["f1", "f3"] + ) + eb2_config = EmbeddingBagConfig( + name="t2", + embedding_dim=D, + num_embeddings=100, + feature_names=["f2"], + ) + ebc_config = EmbeddingBagCollectionConfig(tables=[eb1_config, eb2_config]) + + ebc = EmbeddingBagCollection(config=ebc_config) + model = DLRM( + embedding_bag_collection=ebc, + dense_in_features=100, + dense_arch_layer_sizes=[20], + over_arch_layer_sizes=[5, 1], + ) + + features = torch.rand((B, 100)) + + # 0 1 + # 0 [1,2] [4,5] + # 1 [4,3] [2,9] + # ^ + # feature + sparse_features = KeyedJaggedTensor.from_offsets_sync( + keys=["f1", "f3"], + values=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9]), + offsets=torch.tensor([0, 2, 4, 6, 8]), + ) + + logits = model( + dense_features=features, + sparse_features=sparse_features, + ) + """ + + def __init__( + self, + embedding_bag_collection: EmbeddingBagCollection, + dense_in_features: int, + dense_arch_layer_sizes: List[int], + over_arch_layer_sizes: List[int], + dense_device: Optional[torch.device] = None, + ) -> None: + super().__init__() + assert ( + len(embedding_bag_collection.embedding_bag_configs) > 0 + ), "At least one embedding bag is required" + for i in range(1, len(embedding_bag_collection.embedding_bag_configs)): + conf_prev = embedding_bag_collection.embedding_bag_configs[i - 1] + conf = embedding_bag_collection.embedding_bag_configs[i] + assert ( + conf_prev.embedding_dim == conf.embedding_dim + ), "All EmbeddingBagConfigs must have the same dimension" + embedding_dim: int = embedding_bag_collection.embedding_bag_configs[ + 0 + ].embedding_dim + if dense_arch_layer_sizes[-1] != embedding_dim: + raise ValueError( + f"embedding_bag_collection dimension ({embedding_dim}) and final dense " + "arch layer size ({dense_arch_layer_sizes[-1]}) must match." + ) + + num_feature_names = sum( + [ + len(conf.feature_names) + for conf in embedding_bag_collection.embedding_bag_configs + ] + ) + + over_in_features = ( + embedding_dim + comb(num_feature_names, 2) + num_feature_names + ) + + self.sparse_arch = SparseArch(embedding_bag_collection) + self.dense_arch = DenseArch( + in_features=dense_in_features, + layer_sizes=dense_arch_layer_sizes, + device=dense_device, + ) + self.inter_arch = InteractionArch(num_sparse_features=num_feature_names) + self.over_arch = OverArch( + in_features=over_in_features, + layer_sizes=over_arch_layer_sizes, + device=dense_device, + ) + + def forward( + self, + dense_features: torch.Tensor, + sparse_features: KeyedJaggedTensor, + ) -> torch.Tensor: + embedded_dense = self.dense_arch(dense_features) + embedded_sparse = self.sparse_arch(sparse_features) + concatenated_dense = self.inter_arch( + dense_features=embedded_dense, sparse_features=embedded_sparse + ) + logits = self.over_arch(concatenated_dense) + return logits diff --git a/torchrec/models/tests/test_deepfm.py b/torchrec/models/tests/test_deepfm.py new file mode 100644 index 000000000..4aa892e83 --- /dev/null +++ b/torchrec/models/tests/test_deepfm.py @@ -0,0 +1,194 @@ +#!/usr/bin/env python3 + +import unittest + +import torch +from torch.testing import FileCheck # @manual +from torchrec.fx import Tracer +from torchrec.fx import symbolic_trace +from torchrec.models.deepfm import ( + FMInteractionArch, + SimpleDeepFMNN, +) +from torchrec.modules.embedding_configs import ( + EmbeddingBagConfig, +) +from torchrec.modules.embedding_modules import ( + EmbeddingBagCollection, +) +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor + + +class FMInteractionArchTest(unittest.TestCase): + def test_basic(self) -> None: + torch.manual_seed(0) + + D = 3 + B = 3 + DI = 2 + keys = ["f1", "f2"] + F = len(keys) + dense_features = torch.rand((B, D)) + + embeddings = KeyedTensor( + keys=keys, + length_per_key=[D] * F, + values=torch.rand((B, D * F)), + ) + inter_arch = FMInteractionArch( + fm_in_features=D + D * F, + sparse_feature_names=keys, + deep_fm_dimension=DI, + ) + inter_output = inter_arch(dense_features, embeddings) + self.assertEqual(inter_output.size(), (B, D + DI + 1)) + + # check output forward numerical accuracy + expected_output = torch.Tensor( + [ + [0.4963, 0.7682, 0.0885, 0.0000, 0.2646, 4.3660], + [0.1320, 0.3074, 0.6341, 0.0000, 0.0834, 7.6417], + [0.4901, 0.8964, 0.4556, 0.0000, 0.0671, 15.5230], + ], + ) + self.assertTrue( + torch.allclose( + inter_output, + expected_output, + rtol=1e-4, + atol=1e-4, + ) + ) + + # check tracer compatibility + gm = torch.fx.GraphModule(inter_arch, Tracer().trace(inter_arch)) + torch.jit.script(gm) + + +class SimpleDeepFMNNTest(unittest.TestCase): + def test_basic(self) -> None: + B = 2 + D = 8 + num_dense_features = 100 + eb1_config = EmbeddingBagConfig( + name="t1", embedding_dim=D, num_embeddings=100, feature_names=["f1", "f3"] + ) + eb2_config = EmbeddingBagConfig( + name="t2", + embedding_dim=D, + num_embeddings=100, + feature_names=["f2"], + ) + + ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config]) + + features = torch.rand((B, num_dense_features)) + sparse_features = KeyedJaggedTensor.from_offsets_sync( + keys=["f1", "f3", "f2"], + values=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9, 1, 2, 3]), + offsets=torch.tensor([0, 2, 4, 6, 8, 10, 11]), + ) + + deepfm_nn = SimpleDeepFMNN( + num_dense_features=num_dense_features, + embedding_bag_collection=ebc, + hidden_layer_size=20, + deep_fm_dimension=5, + ) + + logits = deepfm_nn( + dense_features=features, + sparse_features=sparse_features, + ) + self.assertEqual(logits.size(), (B, 1)) + + def test_no_sparse(self) -> None: + ebc = EmbeddingBagCollection(tables=[]) + with self.assertRaises(AssertionError): + SimpleDeepFMNN( + num_dense_features=10, + embedding_bag_collection=ebc, + hidden_layer_size=20, + deep_fm_dimension=5, + ) + + def test_fx(self) -> None: + B = 2 + D = 8 + num_dense_features = 100 + + eb1_config = EmbeddingBagConfig( + name="t2", + embedding_dim=D, + num_embeddings=100, + feature_names=["f2"], + ) + + ebc = EmbeddingBagCollection(tables=[eb1_config]) + deepfm_nn = SimpleDeepFMNN( + num_dense_features=num_dense_features, + embedding_bag_collection=ebc, + hidden_layer_size=20, + deep_fm_dimension=5, + ) + gm = symbolic_trace(deepfm_nn) + FileCheck().check("KeyedJaggedTensor").check("cat").check("f2").run(gm.code) + + features = torch.rand((B, num_dense_features)) + sparse_features = KeyedJaggedTensor.from_offsets_sync( + keys=["f2"], + values=torch.tensor(range(3)), + offsets=torch.tensor([0, 2, 3]), + ) + + logits = gm( + dense_features=features, + sparse_features=sparse_features, + ) + self.assertEqual(logits.size(), (B, 1)) + + def test_fx_script(self) -> None: + B = 2 + D = 8 + num_dense_features = 100 + + eb1_config = EmbeddingBagConfig( + name="t1", embedding_dim=D, num_embeddings=100, feature_names=["f1", "f3"] + ) + eb2_config = EmbeddingBagConfig( + name="t2", + embedding_dim=D, + num_embeddings=100, + feature_names=["f2"], + ) + + ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config]) + deepfm_nn = SimpleDeepFMNN( + num_dense_features=num_dense_features, + embedding_bag_collection=ebc, + hidden_layer_size=20, + deep_fm_dimension=5, + ) + + features = torch.rand((B, num_dense_features)) + sparse_features = KeyedJaggedTensor.from_offsets_sync( + keys=["f1", "f3", "f2"], + values=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9, 1, 2, 3]), + offsets=torch.tensor([0, 2, 4, 6, 8, 10, 11]), + ) + + deepfm_nn( + dense_features=features, + sparse_features=sparse_features, + ) + + gm = symbolic_trace(deepfm_nn) + + scripted_gm = torch.jit.script(gm) + + logits = scripted_gm(features, sparse_features) + self.assertEqual(logits.size(), (B, 1)) + + +if __name__ == "__main__": + unittest.main() diff --git a/torchrec/models/tests/test_dlrm.py b/torchrec/models/tests/test_dlrm.py new file mode 100644 index 000000000..faf3071e2 --- /dev/null +++ b/torchrec/models/tests/test_dlrm.py @@ -0,0 +1,538 @@ +#!/usr/bin/env python3 + +import math +import unittest +from itertools import combinations +from typing import List + +import torch +from torch.testing import FileCheck # @manual +from torchrec.fx import symbolic_trace +from torchrec.models.dlrm import ( + SparseArch, + DenseArch, + InteractionArch, + DLRM, +) +from torchrec.modules.embedding_configs import ( + EmbeddingBagConfig, +) +from torchrec.modules.embedding_modules import ( + EmbeddingBagCollection, +) +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor + + +class SparseArchTest(unittest.TestCase): + def test_basic(self) -> None: + torch.manual_seed(0) + eb1_config = EmbeddingBagConfig( + name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1", "f3"] + ) + eb2_config = EmbeddingBagConfig( + name="t2", + embedding_dim=4, + num_embeddings=10, + feature_names=["f2"], + ) + D = sum( + eb_config.embedding_dim * len(eb_config.feature_names) + for eb_config in [eb1_config, eb2_config] + ) + + ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config]) + sparse_arch = SparseArch(ebc) + + keys = ["f1", "f2", "f3", "f4", "f5"] + offsets = torch.tensor([0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 19]) + features = KeyedJaggedTensor.from_offsets_sync( + keys=keys, + values=torch.tensor( + [1, 2, 4, 5, 4, 3, 2, 9, 1, 2, 4, 5, 4, 3, 2, 9, 1, 2, 3] + ), + offsets=offsets, + ) + B = (len(offsets) - 1) // len(keys) + + kt = sparse_arch(features) + self.assertEqual(kt.values().size(), (B, D)) + self.assertEqual(kt.keys(), ["f1", "f3", "f2"]) + self.assertEqual(kt.offset_per_key(), [0, 3, 6, 10]) + + expected_values = torch.tensor( + [ + [ + -0.7499, + -1.2665, + 1.0143, + -0.7499, + -1.2665, + 1.0143, + 2.0283, + 2.8195, + -2.1004, + -0.3142, + ], + [ + 0.0082, + 0.6241, + -0.1119, + 0.0082, + 0.6241, + -0.1119, + -0.6147, + 3.3314, + -0.8118, + -0.5584, + ], + ] + ) + self.assertTrue( + torch.allclose( + kt.values(), + expected_values, + rtol=1e-4, + atol=1e-4, + ), + ) + + def test_fx_and_shape(self) -> None: + eb1_config = EmbeddingBagConfig( + name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1", "f3"] + ) + eb2_config = EmbeddingBagConfig( + name="t2", + embedding_dim=4, + num_embeddings=10, + feature_names=["f2"], + ) + + D = sum( + eb_config.embedding_dim * len(eb_config.feature_names) + for eb_config in [eb1_config, eb2_config] + ) + + ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config]) + sparse_arch = SparseArch(ebc) + gm = symbolic_trace(sparse_arch) + + FileCheck().check("KeyedJaggedTensor").check("cat").run(gm.code) + + keys = ["f1", "f2", "f3", "f4", "f5"] + offsets = torch.tensor([0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 19]) + features = KeyedJaggedTensor.from_offsets_sync( + keys=keys, + values=torch.tensor( + [1, 2, 4, 5, 4, 3, 2, 9, 1, 2, 4, 5, 4, 3, 2, 9, 1, 2, 3] + ), + offsets=offsets, + ) + B = (len(offsets) - 1) // len(keys) + + kt = gm(features) + self.assertEqual(kt.values().size(), (B, D)) + self.assertEqual(kt.keys(), ["f1", "f3", "f2"]) + self.assertEqual(kt.offset_per_key(), [0, 3, 6, 10]) + + # TODO(T89043538): Auto-generate this test. + def test_fx_script(self) -> None: + eb1_config = EmbeddingBagConfig( + name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1"] + ) + eb2_config = EmbeddingBagConfig( + name="t2", embedding_dim=4, num_embeddings=10, feature_names=["f2"] + ) + + ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config]) + sparse_arch = SparseArch(ebc) + + gm = symbolic_trace(sparse_arch) + torch.jit.script(gm) + + +class DenseArchTest(unittest.TestCase): + def test_basic(self) -> None: + torch.manual_seed(0) + B = 4 + D = 3 + in_features = 10 + dense_arch = DenseArch(in_features=in_features, layer_sizes=[10, D]) + dense_embedded = dense_arch(torch.rand((B, in_features))) + self.assertEqual(dense_embedded.size(), (B, D)) + + expected = torch.tensor( + [ + [0.2351, 0.1578, 0.2784], + [0.1579, 0.1012, 0.2660], + [0.2459, 0.2379, 0.2749], + [0.2582, 0.2178, 0.2860], + ] + ) + self.assertTrue( + torch.allclose( + dense_embedded, + expected, + rtol=1e-4, + atol=1e-4, + ) + ) + + def test_fx_and_shape(self) -> None: + B = 20 + D = 3 + in_features = 10 + dense_arch = DenseArch(in_features=in_features, layer_sizes=[10, D]) + gm = symbolic_trace(dense_arch) + dense_embedded = gm(torch.rand((B, in_features))) + self.assertEqual(dense_embedded.size(), (B, D)) + + # TODO(T89043538): Auto-generate this test. + def test_fx_script(self) -> None: + B = 20 + D = 3 + in_features = 10 + dense_arch = DenseArch(in_features=in_features, layer_sizes=[10, D]) + gm = symbolic_trace(dense_arch) + scripted_gm = torch.jit.script(gm) + dense_embedded = scripted_gm(torch.rand((B, in_features))) + self.assertEqual(dense_embedded.size(), (B, D)) + + +class InteractionArchTest(unittest.TestCase): + def test_basic(self) -> None: + D = 3 + B = 10 + keys = ["f1", "f2"] + F = len(keys) + inter_arch = InteractionArch(num_sparse_features=F) + + dense_features = torch.rand((B, D)) + + embeddings = KeyedTensor( + keys=keys, + length_per_key=[D] * F, + values=torch.rand((B, D * F)), + ) + concat_dense = inter_arch(dense_features, embeddings) + # B X (D + F + F choose 2) + self.assertEqual(concat_dense.size(), (B, D + F + math.comb(F, 2))) + + def test_larger(self) -> None: + D = 8 + B = 20 + keys = ["f1", "f2", "f3", "f4"] + F = len(keys) + inter_arch = InteractionArch(num_sparse_features=F) + + dense_features = torch.rand((B, D)) + + embeddings = KeyedTensor( + keys=keys, + length_per_key=[D] * F, + values=torch.rand((B, D * F)), + ) + + concat_dense = inter_arch(dense_features, embeddings) + # B X (D + F + F choose 2) + self.assertEqual(concat_dense.size(), (B, D + F + math.comb(F, 2))) + + def test_fx_and_shape(self) -> None: + D = 3 + B = 10 + keys = ["f1", "f2"] + F = len(keys) + inter_arch = InteractionArch(num_sparse_features=F) + gm = symbolic_trace(inter_arch) + + dense_features = torch.rand((B, D)) + + embeddings = KeyedTensor( + keys=keys, + length_per_key=[D] * F, + values=torch.rand((B, D * F)), + ) + + concat_dense = gm(dense_features, embeddings) + # B X (D + F + F choose 2) + self.assertEqual(concat_dense.size(), (B, D + F + math.comb(F, 2))) + + # TODO(T89043538): Auto-generate this test. + def test_fx_script(self) -> None: + D = 3 + B = 10 + keys = ["f1", "f2"] + F = len(keys) + inter_arch = InteractionArch(num_sparse_features=F) + gm = symbolic_trace(inter_arch) + scripted_gm = torch.jit.script(gm) + + dense_features = torch.rand((B, D)) + + embeddings = KeyedTensor( + keys=keys, + length_per_key=[D] * F, + values=torch.rand((B, D * F)), + ) + + concat_dense = scripted_gm(dense_features, embeddings) + # B X (D + F + F choose 2) + self.assertEqual(concat_dense.size(), (B, D + F + math.comb(F, 2))) + + def test_correctness(self) -> None: + D = 11 + B = 25 + keys = ["f1", "f2", "f3", "f4", "f5", "f6"] + F = len(keys) + inter_arch = InteractionArch(num_sparse_features=F) + + dense_features = torch.rand((B, D)) + + embeddings = KeyedTensor( + keys=keys, + length_per_key=[D] * F, + values=torch.rand((B, D * F)), + ) + + concat_dense = inter_arch(dense_features, embeddings) + # B X (D + F + F choose 2) + self.assertEqual(concat_dense.size(), (B, D + F + math.comb(F, 2))) + + expected = self._test_correctness_helper( + dense_features=dense_features, + sparse_features=embeddings, + sparse_feature_names=keys, + ) + self.assertTrue( + torch.allclose( + concat_dense, + expected, + rtol=1e-4, + atol=1e-4, + ) + ) + + def _test_correctness_helper( + self, + dense_features: torch.Tensor, + sparse_features: KeyedTensor, + sparse_feature_names: List[str], + ) -> torch.Tensor: + interactions: List[torch.Tensor] = [] + # dense/sparse interaction + # size B X F + for feature_name in sparse_feature_names: + sparse_values = sparse_features[feature_name] + dots = torch.sum(sparse_values * dense_features, dim=1) + # dots is size B + interactions.append(dots) + + # sparse/sparse interaction + # size B X (F choose 2) + for (f1, f2) in list(combinations(sparse_feature_names, 2)): + f1_values = sparse_features[f1] + f2_values = sparse_features[f2] + dots = torch.sum(f1_values * f2_values, dim=1) + interactions.append(dots) + + interactions_tensor = torch.stack(interactions).transpose(1, 0) + return torch.cat((dense_features, interactions_tensor), dim=1) + + def test_numerical_stability(self) -> None: + D = 3 + B = 6 + keys = ["f1", "f2"] + F = len(keys) + inter_arch = InteractionArch(num_sparse_features=F) + torch.manual_seed(0) + dense_features = torch.randint(0, 10, (B, D)) + + embeddings = KeyedTensor( + keys=keys, + length_per_key=[D] * F, + values=torch.randint(0, 10, (B, D * F)), + ) + + concat_dense = inter_arch(dense_features, embeddings) + expected = torch.LongTensor( + [ + [4, 9, 3, 61, 57, 63], + [0, 3, 9, 84, 27, 45], + [7, 3, 7, 34, 50, 25], + [3, 1, 6, 21, 50, 91], + [6, 9, 8, 125, 109, 74], + [6, 6, 8, 18, 80, 21], + ] + ) + + self.assertTrue(torch.equal(concat_dense, expected)) + + +class DLRMTest(unittest.TestCase): + def test_basic(self) -> None: + torch.manual_seed(0) + B = 2 + D = 8 + dense_in_features = 100 + + eb1_config = EmbeddingBagConfig( + name="t1", embedding_dim=D, num_embeddings=100, feature_names=["f1", "f3"] + ) + eb2_config = EmbeddingBagConfig( + name="t2", + embedding_dim=D, + num_embeddings=100, + feature_names=["f2"], + ) + + ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config]) + sparse_nn = DLRM( + embedding_bag_collection=ebc, + dense_in_features=dense_in_features, + dense_arch_layer_sizes=[20, D], + over_arch_layer_sizes=[5, 1], + ) + + features = torch.rand((B, dense_in_features)) + sparse_features = KeyedJaggedTensor.from_offsets_sync( + keys=["f1", "f3", "f2"], + values=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9, 1, 2, 3]), + offsets=torch.tensor([0, 2, 4, 6, 8, 10, 11]), + ) + + logits = sparse_nn( + dense_features=features, + sparse_features=sparse_features, + ) + self.assertEqual(logits.size(), (B, 1)) + + expected_logits = torch.tensor([[0.5805], [0.5909]]) + self.assertTrue( + torch.allclose( + logits, + expected_logits, + rtol=1e-4, + atol=1e-4, + ) + ) + + def test_one_sparse(self) -> None: + B = 2 + D = 8 + dense_in_features = 100 + + eb1_config = EmbeddingBagConfig( + name="t2", + embedding_dim=D, + num_embeddings=100, + feature_names=["f2"], + ) + + ebc = EmbeddingBagCollection(tables=[eb1_config]) + sparse_nn = DLRM( + embedding_bag_collection=ebc, + dense_in_features=dense_in_features, + dense_arch_layer_sizes=[20, D], + over_arch_layer_sizes=[5, 1], + ) + + features = torch.rand((B, dense_in_features)) + sparse_features = KeyedJaggedTensor.from_offsets_sync( + keys=["f2"], + values=torch.tensor(range(3)), + offsets=torch.tensor([0, 2, 3]), + ) + + logits = sparse_nn( + dense_features=features, + sparse_features=sparse_features, + ) + self.assertEqual(logits.size(), (B, 1)) + + def test_no_sparse(self) -> None: + ebc = EmbeddingBagCollection(tables=[]) + D_unused = 1 + with self.assertRaises(AssertionError): + DLRM( + embedding_bag_collection=ebc, + dense_in_features=100, + dense_arch_layer_sizes=[20, D_unused], + over_arch_layer_sizes=[5, 1], + ) + + def test_fx(self) -> None: + B = 2 + D = 8 + dense_in_features = 100 + + eb1_config = EmbeddingBagConfig( + name="t2", + embedding_dim=D, + num_embeddings=100, + feature_names=["f2"], + ) + + ebc = EmbeddingBagCollection(tables=[eb1_config]) + sparse_nn = DLRM( + embedding_bag_collection=ebc, + dense_in_features=dense_in_features, + dense_arch_layer_sizes=[20, D], + over_arch_layer_sizes=[5, 1], + ) + gm = symbolic_trace(sparse_nn) + FileCheck().check("KeyedJaggedTensor").check("cat").check("f2").run(gm.code) + + features = torch.rand((B, dense_in_features)) + sparse_features = KeyedJaggedTensor.from_offsets_sync( + keys=["f2"], + values=torch.tensor(range(3)), + offsets=torch.tensor([0, 2, 3]), + ) + + logits = gm( + dense_features=features, + sparse_features=sparse_features, + ) + self.assertEqual(logits.size(), (B, 1)) + + # TODO(T89043538): Auto-generate this test. + def test_fx_script(self) -> None: + B = 2 + D = 8 + dense_in_features = 100 + + eb1_config = EmbeddingBagConfig( + name="t1", embedding_dim=D, num_embeddings=100, feature_names=["f1", "f3"] + ) + eb2_config = EmbeddingBagConfig( + name="t2", + embedding_dim=D, + num_embeddings=100, + feature_names=["f2"], + ) + + ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config]) + sparse_nn = DLRM( + embedding_bag_collection=ebc, + dense_in_features=dense_in_features, + dense_arch_layer_sizes=[20, D], + over_arch_layer_sizes=[5, 1], + ) + + features = torch.rand((B, dense_in_features)) + sparse_features = KeyedJaggedTensor.from_offsets_sync( + keys=["f1", "f3", "f2"], + values=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9, 1, 2, 3]), + offsets=torch.tensor([0, 2, 4, 6, 8, 10, 11]), + ) + + sparse_nn( + dense_features=features, + sparse_features=sparse_features, + ) + + gm = symbolic_trace(sparse_nn) + + scripted_gm = torch.jit.script(gm) + + logits = scripted_gm(features, sparse_features) + self.assertEqual(logits.size(), (B, 1)) diff --git a/torchrec/modules/__init__.py b/torchrec/modules/__init__.py new file mode 100644 index 000000000..e5a0d9b48 --- /dev/null +++ b/torchrec/modules/__init__.py @@ -0,0 +1 @@ +#!/usr/bin/env python3 diff --git a/torchrec/modules/activation.py b/torchrec/modules/activation.py new file mode 100644 index 000000000..1c4ac6a9d --- /dev/null +++ b/torchrec/modules/activation.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python3 + +from typing import Optional, Union, List + +import torch +from torch import nn + + +class SwishLayerNorm(nn.Module): + """ + Applies the Swish function with layer normalization: + 'Y = X * Sigmoid(LayerNorm(X)).' + + Call Args: + input: an input tensor + + Returns: + output: an output tensor + + Constructor Args: + input_dims: dimensions to normalize over. E.g., If an input tensor has shape + [batch_size, d1, d2, d3], set input_dim=[d2, d3] will do the layer normalization + on last two dimensions. + device: (Optional[torch.device]). + + Example: + >>> sln = SwishLayerNorm(100) + + """ + + def __init__( + self, + input_dims: Union[int, List[int], torch.Size], + device: Optional[torch.device] = None, + ) -> None: + super().__init__() + self.norm: torch.nn.modules.Sequential = nn.Sequential( + nn.LayerNorm(input_dims, device=device), + nn.Sigmoid(), + ) + + def forward( + self, + input: torch.Tensor, + ) -> torch.Tensor: + return input * self.norm(input) diff --git a/torchrec/modules/crossnet.py b/torchrec/modules/crossnet.py new file mode 100644 index 000000000..c9319340e --- /dev/null +++ b/torchrec/modules/crossnet.py @@ -0,0 +1,401 @@ +#!/usr/bin/env python3 + +from typing import Optional, Callable, Union + +import torch + + +class CrossNet(torch.nn.Module): + r""" + Cross Network: https://arxiv.org/abs/1708.05123 + + Cross net is a stack of "crossing" operations on a tensor of shape :math:`(*, N)` + to the same shape, effectively creating :math:`N` learnable polynomical functions + over the input tensor. + + In this module, the crossing operations are defined based on a full rank matrix (NxN), + such that crossing effect can cover all bits on each layer. On each layer l, the tensor + is transformed into: + x_{l+1} = x_0 * (W_l x x_l + b_l) + x_l + where W_l is a square matrix (NxN), "*" means element-wise multiplication, "x" means + matrix multiplication. + + Constructor Args: + in_features (int): the dimension of the input. + num_layers (int): the number of layers in the module. + + Call Args: + input (torch.Tensor): tensor with shape [batch_size, in_features] + + Returns: + output (torch.Tensor): tensor with shape [batch_size, in_features] + + Example: + >>> batch_size = 3 + >>> num_layers = 2 + >>> in_features = 10 + >>> input = torch.randn(batch_size, in_features) + >>> dcn = CrossNet(num_layers=num_layers) + >>> output = dcn(input) + """ + + def __init__( + self, + in_features: int, + num_layers: int, + ) -> None: + super().__init__() + self._num_layers = num_layers + self.kernels: torch.nn.Module = torch.nn.ParameterList( + [ + torch.nn.Parameter( + torch.nn.init.xavier_normal_(torch.empty(in_features, in_features)) + ) + for i in range(self._num_layers) + ] + ) + self.bias: torch.nn.Module = torch.nn.ParameterList( + [ + torch.nn.Parameter(torch.nn.init.zeros_(torch.empty(in_features, 1))) + for i in range(self._num_layers) + ] + ) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + x_0 = input.unsqueeze(2) # (B, N, 1) + x_l = x_0 + + for layer in range(self._num_layers): + # pyre-ignore[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function. + xl_w = torch.matmul(self.kernels[layer], x_l) # (B, N, 1) + # pyre-ignore[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function. + x_l = x_0 * (xl_w + self.bias[layer]) + x_l # (B, N, 1) + + return torch.squeeze(x_l, dim=2) + + +class LowRankCrossNet(torch.nn.Module): + r""" + Low Rank Cross net is a high-efficient cross net. Instead of using full rank cross matrix (NxN) + at each layer, it will use two kernels W (`N`x`r`) and V (`r`x`N`), where `r << N`, to simplify the matrix + multiplication. + + On each layer l, the tensor is transformed into + x_{l+1} = x_0 * (W_l x (V_l x x_l) + b_l) + x_l + where W_l is either a vector, "*" means element-wise multiplication, and "x" means matrix multiplication. + + Note that, rank `r` should be chosen smartly. Usually, we should expect `r < N/2` to have computation saving; we should + expect `r` ~= N/4 to perserve the accuracy of full rank cross net. + + Constructor Args: + in_features (int): the dimension of the input. + num_layers (int): the number of layers in the module. + low_rank (int): the rank setup of the cross matrix (default = 0). Value must be always >= 0 + + Call Args: + input (torch.Tensor): tensor with shape [batch_size, in_features] + + Returns: + output (torch.Tensor): tensor with shape [batch_size, in_features] + + Example: + >>> batch_size = 3 + >>> num_layers = 2 + >>> in_features = 10 + >>> input = torch.randn(batch_size, in_features) + >>> dcn = LowRankCrossNet(num_layers=num_layers, low_rank=3) + >>> output = dcn(input) + """ + + def __init__( + self, + in_features: int, + num_layers: int, + low_rank: int = 1, + ) -> None: + super().__init__() + assert low_rank >= 1, "Low rank must be larger or equal to 1" + + self._num_layers = num_layers + self._low_rank = low_rank + self.W_kernels: torch.nn.Module = torch.nn.ParameterList( + [ + torch.nn.Parameter( + torch.nn.init.xavier_normal_( + torch.empty(in_features, self._low_rank) + ) + ) + for i in range(self._num_layers) + ] + ) + self.V_kernels: torch.nn.Module = torch.nn.ParameterList( + [ + torch.nn.Parameter( + torch.nn.init.xavier_normal_( + torch.empty(self._low_rank, in_features) + ) + ) + for i in range(self._num_layers) + ] + ) + self.bias: torch.nn.Module = torch.nn.ParameterList( + [ + torch.nn.Parameter(torch.nn.init.zeros_(torch.empty(in_features, 1))) + for i in range(self._num_layers) + ] + ) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + x_0 = input.unsqueeze(2) # (B, N, 1) + x_l = x_0 + + for layer in range(self._num_layers): + xl_w = torch.matmul( + # pyre-ignore[29]: `Union[torch.Tensor, torch.nn.Module]` is not a + # function. + self.W_kernels[layer], + # pyre-ignore[29]: `Union[torch.Tensor, torch.nn.Module]` is not a + # function. + torch.matmul(self.V_kernels[layer], x_l), + ) # (B, N, 1) + # pyre-ignore[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function. + x_l = x_0 * (xl_w + self.bias[layer]) + x_l # (B, N, 1) + + return torch.squeeze(x_l, dim=2) # (B, N) + + +class VectorCrossNet(torch.nn.Module): + r""" + Vector Cross Network can be refered as DCN-V1 (https://arxiv.org/pdf/1708.05123.pdf). + + It is also a specialized low rank cross net, where rank=1. In this version, on each layer, instead + of keeping two kernels W and V, we only keep one vector kernel W (Nx1). So, we will use dot + operation to compute the "crossing" effect of features; thus, we can save two matrix multiplications + to further reduce computational cost and cut the number of learnable parameter number. + + On each layer l, the tensor is transformed into + x_{l+1} = x_0 * (W_l . x_l + b_l) + x_l + where W_l is either a vector, "*" means element-wise multiplication; "." means dot operations. + + Constructor Args: + in_features (int): the dimension of the input. + num_layers (int): the number of layers in the module. + + Call Args: + input (torch.Tensor): tensor with shape [batch_size, in_features] + + Returns: + output (torch.Tensor): tensor with shape [batch_size, in_features] + + Example: + >>> batch_size = 3 + >>> num_layers = 2 + >>> in_features = 10 + >>> input = torch.randn(batch_size, in_features) + >>> dcn = VectorCrossNet(num_layers=num_layers) + >>> output = dcn(input) + """ + + def __init__( + self, + in_features: int, + num_layers: int, + ) -> None: + super().__init__() + self._num_layers = num_layers + self.kernels: torch.nn.Module = torch.nn.ParameterList( + [ + torch.nn.Parameter( + torch.nn.init.xavier_normal_(torch.empty(in_features, 1)) + ) + for i in range(self._num_layers) + ] + ) + self.bias: torch.nn.Module = torch.nn.ParameterList( + [ + torch.nn.Parameter(torch.nn.init.zeros_(torch.empty(in_features, 1))) + for i in range(self._num_layers) + ] + ) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + x_0 = input.unsqueeze(2) # (B, N, 1) + x_l = x_0 + + for layer in range(self._num_layers): + xl_w = torch.tensordot( + x_l, + # pyre-ignore[29]: `Union[torch.Tensor, torch.nn.Module]` is not a + # function. + self.kernels[layer], + dims=([1], [0]), + ) # (B, 1, 1) + # pyre-ignore[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function. + x_l = torch.matmul(x_0, xl_w) + self.bias[layer] + x_l # (B, N, 1) + + return torch.squeeze(x_l, dim=2) # (B, N) + + +class LowRankMixtureCrossNet(torch.nn.Module): + r""" + LowRankMixtureCrossNet is a DCN V2 implementation from the paper: https://arxiv.org/pdf/2008.13535.pdf + + LowRankMixtureCrossNet defines the learnable crossing parameter per layer as low-rank matrix (Nxr) together + with mixture of expert. Compared to LowRankCrossNet, instead of relying on one single expert to learn + feature crosses, this module leverages such `K` experts, each learning feature interactions in a + different subspaces, and adaptively combine the learned crosses using a gating mechanism that depends + on input `x`. + + On each layer l, the tensor is transformed into + x_{l+1} = MoE(expert_i foreach i in K experts) + x_l + and each expert i is defined as: + expert_i = x_0 * (U_l_i x g(C_l_i x g(V_l_i x x_l)) + b_l) + where U_l_i (N, r), C_l_i (r, r), and V_l_i (r, N) are low-rank matrix, "*" means element-wise multiplication, + "x" means matrix multiplication, and g(.) is the non-linear activation function. + + One optimization is when num_expert is 1, the gate evaluation and MOE will be skipped for computation saving. + + Constructor Args: + in_features (int): the dimension of the input. + num_layers (int): the number of layers in the module. + low_rank (int): the rank setup of the cross matrix (default = 0). Value must be always >= 0 + activation (Union[torch.nn.Module, Callable[[torch.Tensor], torch.Tensor]]): the non-linear activation + function, used in defining experts. Default is relu. + + Call Args: + input (torch.Tensor): tensor with shape [batch_size, in_features] + + Returns: + output (torch.Tensor): tensor with shape [batch_size, in_features] + + Example: + >>> batch_size = 3 + >>> num_layers = 2 + >>> in_features = 10 + >>> input = torch.randn(batch_size, in_features) + >>> dcn = LowRankCrossNet(num_layers=num_layers, num_experts=5, low_rank=3) + >>> output = dcn(input) + """ + + def __init__( + self, + in_features: int, + num_layers: int, + num_experts: int = 1, + low_rank: int = 1, + activation: Union[ + torch.nn.Module, + Callable[[torch.Tensor], torch.Tensor], + ] = torch.relu, + ) -> None: + super().__init__() + assert num_experts >= 1, "num_experts must be larger or equal to 1" + assert low_rank >= 1, "Low rank must be larger or equal to 1" + + self._num_layers = num_layers + self._num_experts = num_experts + self._low_rank = low_rank + self._in_features = in_features + self.U_kernels: torch.nn.Module = torch.nn.ParameterList( + [ + torch.nn.Parameter( + torch.nn.init.xavier_normal_( + torch.empty( + self._num_experts, self._in_features, self._low_rank + ) + ) + ) + for i in range(self._num_layers) + ] + ) + self.V_kernels: torch.nn.Module = torch.nn.ParameterList( + [ + torch.nn.Parameter( + torch.nn.init.xavier_normal_( + torch.empty( + self._num_experts, self._low_rank, self._in_features + ) + ) + ) + for i in range(self._num_layers) + ] + ) + self.bias: torch.nn.Module = torch.nn.ParameterList( + [ + torch.nn.Parameter( + torch.nn.init.zeros_(torch.empty(self._in_features, 1)) + ) + for i in range(self._num_layers) + ] + ) + self.gates: Optional[torch.nn.Module] = ( + torch.nn.ModuleList( + [ + torch.nn.Linear(self._in_features, 1, bias=False) + for i in range(self._num_layers) + ] + ) + if self._num_experts > 1 + else None + ) + + self._activation = activation + self.C_kernels: torch.nn.Module = torch.nn.ParameterList( + [ + torch.nn.Parameter( + torch.nn.init.xavier_normal_( + torch.empty(self._num_experts, self._low_rank, self._low_rank) + ) + ) + for i in range(self._num_layers) + ] + ) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + x_0 = input.unsqueeze(2) # (B, N, 1) + x_l = x_0 + + for layer in range(self._num_layers): + # set up gating: + if self._num_experts > 1: + gating = [] + for i in range(self._num_experts): + # pyre-ignore[16]: `Optional` has no attribute `__getitem__`. + gating.append(self.gates[i](x_l.squeeze(2))) + gating = torch.stack(gating, 1) # (B, K, 1) + + # set up experts + experts = [] + for i in range(self._num_experts): + expert = torch.matmul( + # pyre-ignore[29] + self.V_kernels[layer][i], + x_l, + ) # (B, r, 1) + expert = torch.matmul( + # pyre-ignore[29] + self.C_kernels[layer][i], + self._activation(expert), + ) # (B, r, 1) + expert = torch.matmul( + # pyre-ignore[29] + self.U_kernels[layer][i], + self._activation(expert), + ) # (B, N, 1) + # pyre-ignore[29] + expert = x_0 * (expert + self.bias[layer]) # (B, N, 1) + experts.append(expert.squeeze(2)) # (B, N) + experts = torch.stack(experts, 2) # (B, N, K) + + if self._num_experts > 1: + # MOE update + moe = torch.matmul( + experts, + # pyre-ignore[61]: `gating` may not be initialized here. + torch.nn.functional.softmax(gating, 1), + ) # (B, N, 1) + x_l = moe + x_l # (B, N, 1) + else: + x_l = experts + x_l # (B, N, 1) + + return torch.squeeze(x_l, dim=2) # (B, N) diff --git a/torchrec/modules/deepfm.py b/torchrec/modules/deepfm.py new file mode 100644 index 000000000..4cba32581 --- /dev/null +++ b/torchrec/modules/deepfm.py @@ -0,0 +1,189 @@ +#!/usr/bin/env python3 + +from typing import List + +import torch +from torch import nn +from torch.fx import wrap + + +# pyre-ignore[56]: Pyre was not able to infer the type of the decorator `torch.fx.wrap`. +@wrap +def _get_flatten_input(inputs: List[torch.Tensor]) -> torch.Tensor: + return torch.cat( + [input.flatten(1) for input in inputs], + dim=1, + ) + + +class DeepFM(nn.Module): + r""" + This is the DeepFM module. Extenal source of DeepFM paper: + https://arxiv.org/pdf/1703.04247.pdf + + This module is not to cover the end-end functionality of the published paper. + Instead, it is only the deep component of the publication. It is used to learn + high-order feature interactions. If low-order feature interaction should + be learnt, please use FactorizationMachine module instead, which will share + the same embedding input of this module. + + To support modeling flexibility, we customize the key components as: + * Different from the public paper, we change the input from raw sparse + features to embeddings of the features. It can allow flexibility in + embedding dimensions and number of embeddings, as long as all embedding + tensors have the same batch size. + * On top of the public paper, we allow users to customize the hidden layer + to be any module, not limited to just MLP. + + The general architecture of the module is like: + + 1 x 10 output + /|\ + | pass into `dense_module` + | + 1 x 90 + /|\ + | concat + | + 1 x 20, 1 x 30, 1 x 40 list of embeddings + + Constructor Args: + dense_module: nn.Module. + any customerized module that can be used (such as MLP) in DeepFM. The in_features + of this module must be equal to the elements counts. For example, the input + embeddings is [randn(3, 2, 3), randn(3, 4, 5)], the in_features should be: + 2*3+4*5. + + Call Args: + embeddings: List[torch.Tensor]: + The list of all embeddings (e.g. dense, common_sparse, specialized_sparse, + embedding_features, raw_embedding_features) in the shape of: + (batch_size, num_embeddings, embedding_dim) + + For the ease of operation, embeddings that have the same embedding dimension + have the option to be stacked into a single tensor. For example, when we have + 1 trained embedding with dimension=32, 5 native embedding with dimension=64, + and 3 dense features with dimension=16, we can prepare the embeddings list to + be the list of: + tensor(B, 1, 32) (trained_embedding with num_embeddings=1, embedding_dim=32) + tensor(B, 5, 64) (native_embedding with num_embeddings=5, embedding_dim=64) + tensor(B, 3, 16) (dense_features with num_embeddings=3, embedding_dim=32) + + Note that: batch_size of all input tensors need to be identical. + + Returns: + deepfm_output (torch.Tensor): output of `dense_module` with flattened and + concatenated `embeddings` as input. + + Example: + >>> import torch + >>> from torchrec.fb.modules.deepfm import DeepFM + >>> from torchrec.fb.modules.mlp import LazyMLP + >>> batch_size = 3 + >>> output_dim = 30 + >>> # the input embedding are in torch.Tensor of [batch_size, num_embeddings, embedding_dim] + >>> input_embeddings = [ + >>> torch.randn(batch_size, 2, 64), + >>> torch.randn(batch_size, 2, 32), + >>> ] + >>> dense_module = nn.Linear(192, output_dim) + >>> deepfm = DeepFM(dense_module=dense_module) + >>> deep_fm_output = deepfm(embeddings=input_embeddings) + """ + + def __init__( + self, + dense_module: nn.Module, + ) -> None: + super().__init__() + self.dense_module = dense_module + + def forward( + self, + embeddings: List[torch.Tensor], + ) -> torch.Tensor: + # flatten each embedding to be [B, N, D] -> [B, N*D], then cat them all on dim=1 + deepfm_input = _get_flatten_input(embeddings) + deepfm_output = self.dense_module(deepfm_input) + return deepfm_output + + +class FactorizationMachine(nn.Module): + r""" + This is the Factorization Machine module, mentioned in the DeepFM paper: + https://arxiv.org/pdf/1703.04247.pdf + + This module is not to cover the end-end functionality of the published paper. + Instead, it is only the FM part of the publication. It is used to learn + 2nd-order feature interactions. + + To support modeling flexibility, we customize the key components as: + * Different from the public paper, we change the input from raw sparse + features to embeddings of the features. It can allow flexibility in + embedding dimensions and number of embeddings, as long as all embedding + tensors have the same batch size. + + The general architecture of the module is like: + + 1 x 1 output + /|\ + | pass into `dense_module` + | + 1 x 90 + /|\ + | concat + | + 1 x 20, 1 x 30, 1 x 40 list of embeddings + + Constructor Args: + None + + Call Args: + embeddings: List[torch.Tensor]: + The list of all embeddings (e.g. dense, common_sparse, specialized_sparse, + embedding_features, raw_embedding_features) in the shape of: + (batch_size, num_embeddings, embedding_dim) + + For the ease of operation, embeddings that have the same embedding dimension + have the option to be stacked into a single tensor. For example, when we have + 1 trained embedding with dimension=32, 5 native embedding with dimension=64, + and 3 dense features with dimension=16, we can prepare the embeddings list to + be the list of: + tensor(B, 1, 32) (trained_embedding with num_embeddings=1, embedding_dim=32) + tensor(B, 5, 64) (native_embedding with num_embeddings=5, embedding_dim=64) + tensor(B, 3, 16) (dense_features with num_embeddings=3, embedding_dim=32) + + Note that: batch_size of all input tensors need to be identical. + + Returns: + output (torch.Tensor): output of fm with flattened and + concatenated `embeddings` as input. Expected to be [B, 1] + + Example: + >>> batch_size = 3 + >>> # the input embedding are in torch.Tensor of [batch_size, num_embeddings, embedding_dim] + >>> input_embeddings = [ + >>> torch.randn(batch_size, 2, 64), + >>> torch.randn(batch_size, 2, 32), + >>> ] + >>> fm = FactorizationMachine() + >>> output = fm(embeddings=input_embeddings) + """ + + def __init__( + self, + ) -> None: + super().__init__() + + def forward( + self, + embeddings: List[torch.Tensor], + ) -> torch.Tensor: + # flatten each embedding to be [B, N, D] -> [B, N*D], then cat them all on dim=1 + fm_input = _get_flatten_input(embeddings) + sum_of_input = torch.sum(fm_input, dim=1, keepdim=True) + sum_of_square = torch.sum(fm_input * fm_input, dim=1, keepdim=True) + square_of_sum = sum_of_input * sum_of_input + cross_term = square_of_sum - sum_of_square + cross_term = torch.sum(cross_term, dim=1, keepdim=True) * 0.5 # [B, 1] + return cross_term diff --git a/torchrec/modules/embedding_configs.py b/torchrec/modules/embedding_configs.py new file mode 100644 index 000000000..ec09154f0 --- /dev/null +++ b/torchrec/modules/embedding_configs.py @@ -0,0 +1,80 @@ +#!/usr/bin/env python3 + +from dataclasses import dataclass, field +from enum import Enum, unique +from math import sqrt +from typing import Optional, List, Dict + + +@unique +class PoolingType(Enum): + SUM = "SUM" + MEAN = "MEAN" + NONE = "NONE" + + +@unique +class DataType(Enum): + """ + Our fusion impl supports only certain types of data + so it makes sense to retrict in a non-fused version as well. + """ + + FP32 = "FP32" + FP16 = "FP16" + INT8 = "INT8" + INT4 = "INT4" + INT2 = "INT2" + + +DATA_TYPE_NUM_BITS: Dict[DataType, int] = { + DataType.FP32: 32, + DataType.FP16: 16, + DataType.INT8: 8, + DataType.INT4: 4, + DataType.INT2: 2, +} + + +@dataclass +class BaseEmbeddingConfig: + num_embeddings: int + embedding_dim: int + name: str = "" + data_type: DataType = DataType.FP32 + feature_names: List[str] = field(default_factory=list) + weight_init_max: Optional[float] = None + weight_init_min: Optional[float] = None + + def get_weight_init_max(self) -> float: + if self.weight_init_max is None: + return sqrt(1 / self.num_embeddings) + else: + return self.weight_init_max + + def get_weight_init_min(self) -> float: + if self.weight_init_min is None: + return -sqrt(1 / self.num_embeddings) + else: + return self.weight_init_min + + def num_features(self) -> int: + return len(self.feature_names) + + +@dataclass +class EmbeddingTableConfig(BaseEmbeddingConfig): + pooling: PoolingType = PoolingType.SUM + is_weighted: bool = False + has_feature_processor: bool = False + embedding_names: List[str] = field(default_factory=list) + + +@dataclass +class EmbeddingBagConfig(BaseEmbeddingConfig): + pooling: PoolingType = PoolingType.SUM + + +@dataclass +class EmbeddingConfig(BaseEmbeddingConfig): + pass diff --git a/torchrec/modules/embedding_modules.py b/torchrec/modules/embedding_modules.py new file mode 100644 index 000000000..ae08a106e --- /dev/null +++ b/torchrec/modules/embedding_modules.py @@ -0,0 +1,310 @@ +#!/usr/bin/env python3 + +import abc +from typing import List, Dict, Optional + +import torch +import torch.nn as nn +from torchrec.modules.embedding_configs import ( + DataType, + EmbeddingConfig, + EmbeddingBagConfig, + PoolingType, +) +from torchrec.sparse.jagged_tensor import ( + KeyedJaggedTensor, + JaggedTensor, + KeyedTensor, +) + + +def _to_mode(pooling: PoolingType) -> str: + if pooling == PoolingType.SUM: + return "sum" + elif pooling == PoolingType.MEAN: + return "mean" + else: + raise ValueError(f"Unsupported pooling {pooling}") + + +class EmbeddingBagCollectionInterface(abc.ABC, nn.Module): + """ + Interface for EmbeddingBagCollection, GroupedEmbeddingBag, and BaseBatchedEmbeddingBag. + """ + + @abc.abstractmethod + def forward( + self, + features: KeyedJaggedTensor, + ) -> KeyedTensor: + pass + + @abc.abstractproperty + def embedding_bag_configs( + self, + ) -> List[EmbeddingBagConfig]: + pass + + @abc.abstractproperty + def is_weighted(self) -> bool: + pass + + +class EmbeddingBagCollection(EmbeddingBagCollectionInterface): + """ + EmbeddingBagCollection represents a collection of pooled embeddings (EmbeddingBags) + It processes sparse data in the form of KeyedJaggedTensor + with values of the form [F X B X L] + F: features (keys) + B: batch size + L: Length of sparse features (jagged) + + and outputs a KeyedTensor with values of the form [B * (F * D)] + where + F: features (keys) + D: each feature's (key's) embedding dimension + B: batch size + + Constructor Args: + tables (List[EmbeddingBagConfig]): list of embedding tables + is_weighted: (bool): whether input KeyedJaggedTensor is weighted + device: (Optional[torch.device]): default compute device + + Call Args: + features: KeyedJaggedTensor, + weighted_features: KeyedJaggedTensor, + + Returns: + KeyedTensor + + Example: + table_0 = EmbeddingBagConfig( + name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1"] + ) + table_1 = EmbeddingBagConfig( + name="t2", embedding_dim=4, num_embeddings=10, feature_names=["f2"] + ) + + ebc = EmbeddingBagCollection(tables=[table_0, table_1]) + + # 0 1 2 <-- batch + # "f1" [0,1] None [2] + # "f2" [3] [4] [5,6,7] + # ^ + # feature + features = KeyedJaggedTensor( + keys=["f1", "f2"], + values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]), + offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]), + ) + + pooled_embeddings = ebc(features) + print(pooled_embeddings.values()) + tensor([[-0.6149, 0.0000, -0.3176], + [-0.8876, 0.0000, -1.5606], + [ 1.6805, 0.0000, 0.6810], + [-1.4206, -1.0409, 0.2249], + [ 0.1823, -0.4697, 1.3823], + [-0.2767, -0.9965, -0.1797], + [ 0.8864, 0.1315, -2.0724]], grad_fn=) + print(pooled_embeddings.keys()) + ['f1', 'f2'] + print(pooled_embeddings.offset_per_key()) + tensor([0, 3, 7]) + """ + + def __init__( + self, + tables: List[EmbeddingBagConfig], + is_weighted: bool = False, + device: Optional[torch.device] = None, + ) -> None: + super().__init__() + torch._C._log_api_usage_once(f"torchrec.modules.{self.__class__.__name__}") + self._is_weighted = is_weighted + # pyre-ignore[11] + self.embedding_bags: nn.ModuleDict = nn.ModuleDict() + self._embedding_bag_configs = tables + self._embedding_names: List[str] = [] + self._lengths_per_embedding: List[int] = [] + table_names = set() + shared_feature: Dict[str, bool] = {} + for embedding_config in tables: + if embedding_config.name in table_names: + raise ValueError(f"Duplicate table name {embedding_config.name}") + table_names.add(embedding_config.name) + dtype = ( + torch.float32 + if embedding_config.data_type == DataType.FP32 + else torch.float16 + ) + self.embedding_bags[embedding_config.name] = nn.EmbeddingBag( + num_embeddings=embedding_config.num_embeddings, + embedding_dim=embedding_config.embedding_dim, + mode=_to_mode(embedding_config.pooling), + device=device, + include_last_offset=True, + dtype=dtype, + ) + if not embedding_config.feature_names: + embedding_config.feature_names = [embedding_config.name] + for feature_name in embedding_config.feature_names: + if feature_name not in shared_feature: + shared_feature[feature_name] = False + else: + shared_feature[feature_name] = True + self._lengths_per_embedding.append(embedding_config.embedding_dim) + + for embedding_config in tables: + for feature_name in embedding_config.feature_names: + if shared_feature[feature_name]: + self._embedding_names.append( + feature_name + "@" + embedding_config.name + ) + else: + self._embedding_names.append(feature_name) + + def forward(self, features: KeyedJaggedTensor) -> KeyedTensor: + pooled_embeddings: List[torch.Tensor] = [] + + for embedding_config, embedding_bag in zip( + self._embedding_bag_configs, self.embedding_bags.values() + ): + for feature_name in embedding_config.feature_names: + f = features[feature_name] + res = embedding_bag( + input=f.values(), + offsets=f.offsets(), + per_sample_weights=f.weights() if self._is_weighted else None, + ) + pooled_embeddings.append(res) + data = torch.cat(pooled_embeddings, dim=1) + return KeyedTensor( + keys=self._embedding_names, + values=data, + length_per_key=self._lengths_per_embedding, + ) + + @property + def embedding_bag_configs(self) -> List[EmbeddingBagConfig]: + return self._embedding_bag_configs + + @property + def is_weighted(self) -> bool: + return self._is_weighted + + +class EmbeddingCollection(nn.Module): + """ + EmbeddingCollection represents a collection of non-pooled embeddings + It processes sparse data in the form of KeyedJaggedTensor + of the form [F X B X L] + F: features (keys) + B: batch size + L: Length of sparse features (variable) + + and outputs Dict[feature (key), JaggedTensor]. + Each JaggedTensor contains values of the form (B * L) X D + where + B: batch size + L: Length of sparse features (jagged) + D: each feature's (key's) embedding dimension + and lengths are of the form L + + Constructor Args: + tables (List[EmbeddingBagConfig]): list of embedding tables + device: (Optional[torch.device]): default compute device + + Call Args: + features: KeyedJaggedTensor, + + Returns: + Dict[str, JaggedTensor] + + Example: + >>> e1_config = EmbeddingConfig( + name="t1", embedding_dim=2, num_embeddings=10, feature_names=["f1"] + ) + e2_config = EmbeddingConfig( + name="t2", embedding_dim=3, num_embeddings=10, feature_names=["f2"] + ) + ec_config = EmbeddingCollectionConfig(tables=[e1_config, e2_config]) + + ec = EmbeddingCollection(config=ec_config) + + # 0 1 2 <-- batch + # 0 [0,1] None [2] + # 1 [3] [4] [5,6,7] + # ^ + # feature + features = KeyedJaggedTensor.from_offsets_sync( + keys=["f1", "f2"], + values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]), + offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]), + ) + feature_embeddings = ec(features) + print(feature_embeddings['f2'].values()) + tensor([[-0.2050, 0.5478, 0.6054], + [ 0.7352, 0.3210, -3.0399], + [ 0.1279, -0.1756, -0.4130], + [ 0.7519, -0.4341, -0.0499], + [ 0.9329, -1.0697, -0.8095]], grad_fn=) + + """ + + def __init__( + self, + tables: List[EmbeddingConfig], + device: Optional[torch.device] = None, + ) -> None: + super().__init__() + torch._C._log_api_usage_once(f"torchrec.modules.{self.__class__.__name__}") + self.embeddings: nn.ModuleDict = nn.ModuleDict() + self.embedding_configs = tables + self._embedding_names: List[str] = [] + table_names = set() + shared_feature: Dict[str, bool] = {} + for embedding_config in tables: + if embedding_config.name in table_names: + raise ValueError(f"Duplicate table name {embedding_config.name}") + table_names.add(embedding_config.name) + self.embeddings[embedding_config.name] = nn.Embedding( + num_embeddings=embedding_config.num_embeddings, + embedding_dim=embedding_config.embedding_dim, + device=device, + ) + if not embedding_config.feature_names: + embedding_config.feature_names = [embedding_config.name] + for feature_name in embedding_config.feature_names: + if feature_name not in shared_feature: + shared_feature[feature_name] = False + else: + shared_feature[feature_name] = True + + for embedding_config in tables: + for feature_name in embedding_config.feature_names: + if shared_feature[feature_name]: + self._embedding_names.append( + feature_name + "@" + embedding_config.name + ) + else: + self._embedding_names.append(feature_name) + + def forward(self, features: KeyedJaggedTensor) -> Dict[str, JaggedTensor]: + feature_embeddings: Dict[str, JaggedTensor] = {} + idx = 0 + for embedding_config, embedding in zip( + self.embedding_configs, self.embeddings.values() + ): + for feature_name in embedding_config.feature_names: + f = features[feature_name] + lookup = embedding( + input=f.values(), + ) + feature_embeddings[self._embedding_names[idx]] = JaggedTensor( + values=lookup, + offsets=f.offsets(), + lengths=f.lengths(), + ) + idx += 1 + return feature_embeddings diff --git a/torchrec/modules/feature_processor.py b/torchrec/modules/feature_processor.py new file mode 100644 index 000000000..bb58d68a3 --- /dev/null +++ b/torchrec/modules/feature_processor.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python3 + +import abc +from typing import Dict + +import torch +import torch.nn as nn +from torchrec.sparse.jagged_tensor import JaggedTensor + + +class BaseFeatureProcessor(nn.Module): + """ + abstract base class for feature processor + """ + + @abc.abstractmethod + def forward( + self, + features: Dict[str, JaggedTensor], + ) -> Dict[str, JaggedTensor]: + pass + + +class PositionWeightedModule(BaseFeatureProcessor): + def __init__( + self, + max_feature_lengths: Dict[str, int], + ) -> None: + super().__init__() + self.max_feature_lengths = max_feature_lengths + self.position_weights: nn.ParameterDict = nn.ParameterDict() + for key, length in max_feature_lengths.items(): + # pyre-fixme[29] + self.position_weights[key] = nn.Parameter(torch.empty([length]).fill_(1.0)) + + def forward( + self, + features: Dict[str, JaggedTensor], + ) -> Dict[str, JaggedTensor]: + ret: Dict[str, JaggedTensor] = {} + # pyre-fixme[29] + for key, pos_weight in self.position_weights.items(): + seq = torch.ops.fbgemm.offsets_range( + features[key].lengths().long(), features[key].values().long() + ) + ret[key] = JaggedTensor( + values=features[key].values(), + lengths=features[key].lengths(), + offsets=features[key].offsets(), + weights=torch.gather(pos_weight, dim=0, index=seq), + ) + return ret diff --git a/torchrec/modules/lazy_extension.py b/torchrec/modules/lazy_extension.py new file mode 100644 index 000000000..94194575c --- /dev/null +++ b/torchrec/modules/lazy_extension.py @@ -0,0 +1,248 @@ +#!/usr/bin/env python3 + +import functools +import inspect +from typing import ( + Any, + Callable, +) + +import torch +import torch.utils.hooks as hooks +from torch.nn.modules.lazy import ( + _LazyProtocol, + LazyModuleMixin, +) +from torch.nn.modules.module import ( + _global_backward_hooks, + _global_forward_pre_hooks, + _global_forward_hooks, +) + + +def _apply_functions_after_first_forward( + module: torch.nn.Module, + # pyre-ignore[2] + input: Any, + # pyre-ignore[2] + output: Any, +) -> None: + _functions_to_lazy_apply = getattr(module, "_functions_to_lazy_apply", None) + if _functions_to_lazy_apply is not None: + for fn in _functions_to_lazy_apply: + module.apply(fn) + delattr(module, "_functions_to_lazy_apply") + # pyre-ignore[16] + module._lazy_apply_hook.remove() + delattr(module, "_lazy_apply_hook") + + +def lazy_apply( + module: torch.nn.Module, fn: Callable[[torch.nn.Module], None] +) -> torch.nn.Module: + r"""Attaches a function to a module, which will be applied recursively to every submodule + (as returned by `.children()`) of the module as well as the module itself right after + the first forward pass (i.e. after all submodules and parameters have been initialized). + Typical use includes initializing the numerical value of the parameters of + a lazy module (i.e. modules inherited from LazyModuleMixin). + + Note that `lazy_apply()` can be used on both lazy and non-lazy modules. + + Args: + module (torch.nn.Module): module to recursively apply `fn` on. + fn (Callable[[torch.nn.Module], None]): function to be attached to `module` and later + be applied to each submodule of `module` and the `module` itself. + + Returns: + `module` with `fn` attached. + + Example: + >>> @torch.no_grad() + >>> def init_weights(m): + >>> print(m) + >>> if type(m) == torch.nn.LazyLinear: + >>> m.weight.fill_(1.0) + >>> print(m.weight) + >>> + >>> linear = torch.nn.LazyLinear(2) + >>> lazy_apply(linear, init_weights) # doesn't run `init_weights` immediately + >>> input = torch.randn(2, 10) + >>> linear(input) # runs `init_weights` only once, right after first forward pass + >>> + >>> seq = torch.nn.Sequential(torch.nn.LazyLinear(2), torch.nn.LazyLinear(2)) + >>> lazy_apply(seq, init_weights) # doesn't run `init_weights` immediately + >>> input = torch.randn(2, 10) + >>> seq(input) # runs `init_weights` only once, right after first forward pass + """ + if not hasattr(module, "_functions_to_lazy_apply"): + # pyre-ignore[16] + module._functions_to_lazy_apply = [] + if not hasattr(module, "_lazy_apply_hook"): + # pyre-ignore[16] + module._lazy_apply_hook = module.register_forward_hook( + _apply_functions_after_first_forward + ) + # pyre-ignore[16] + module._functions_to_lazy_apply.append(fn) + return module + + +class _LazyExtensionProtocol(_LazyProtocol): + # pyre-ignore[2,3] + def _call_impl(self, *input, **kwargs): + ... + + +class LazyModuleExtensionMixin(LazyModuleMixin): + """ + This is an temporary extension of LazyModuleMixin to support passing keyword + arguments to lazy module's `forward` method. + + The long-term plan is to upstream this feature to LazyModuleMixin. Please see + https://github.com/pytorch/pytorch/issues/59923 for details. + + Please see TestLazyModuleExtensionMixin, which contains unit tests that ensure: + * LazyModuleExtensionMixin._infer_parameters has source code parity as + torch.nn.modules.lazy.LazyModuleMixin._infer_parameters, except that the former + can accept keyword arguments. + * LazyModuleExtensionMixin._call_impl has source code parity as + torch.nn.Module._call_impl, except that the former can pass keyword arguments + to forward pre hooks." + """ + + def apply(self, fn: Callable[[torch.nn.Module], None]) -> torch.nn.Module: + r"""Applies `fn` recursively to every submodule (as returned by `.children()`) + as well as self. Typical use includes initializing the parameters of a model. + + Note that calling `apply()` on an uninitialized lazy-module will result in an error. + User is required to initialize a lazy-module (by doing a dummy forward pass) before + calling `apply()` on the lazy-module. + + Args: + fn (torch.nn.Module -> None): function to be applied to each submodule + + Returns: + Module: self + + Example:: + >>> @torch.no_grad() + >>> def init_weights(m): + >>> print(m) + >>> if type(m) == torch.nn.LazyLinear: + >>> m.weight.fill_(1.0) + >>> print(m.weight) + >>> + >>> linear = torch.nn.LazyLinear(2) + >>> linear.apply(init_weights) # this fails, because `linear` (a lazy-module) hasn't been initialized yet + >>> + >>> input = torch.randn(2, 10) + >>> linear(input) # run a dummy forward pass to initialize the lazy-module + >>> + >>> linear.apply(init_weights) # this works now + """ + if hasattr(self, "_initialize_hook"): + raise RuntimeError( + "Module {} has not been initialized. ".format(self) + + "Please run a dummy forward pass on the model to initialize all modules, " + + "or use torchrec.modules.lazy_extension.lazy_apply to attach a function " + + "to this module which would be applied after this module is initialized." + ) + # If the module is already initialized, call `super().apply(fn)` to + # run the usual apply logic. + # pyre-ignore[16] + return super().apply(fn) + + # fmt: off + # pyre-ignore[2,3,14,47] + def _infer_parameters(self: _LazyExtensionProtocol, module, input, kwargs): + r"""Infers the size and initializes the parameters according to the + provided input batch. + Given a module that contains parameters that were declared inferrable + using :class:`torch.nn.parameter.ParameterMode.Infer`, runs a forward pass + in the complete module using the provided input to initialize all the parameters + as needed. + The module is set into evaluation mode before running the forward pass in order + to avoid saving statistics or calculating gradients + """ + module.initialize_parameters(*input, **kwargs) + if module.has_uninitialized_params(): + raise RuntimeError('module {} has not been fully initialized'.format(self._get_name())) + module._initialize_hook.remove() + module._load_hook.remove() + delattr(module, '_initialize_hook') + delattr(module, '_load_hook') + if module.cls_to_become is not None: + module.__class__ = module.cls_to_become + # fmt: on + + # fmt: off + # pyre-ignore[2,3] + def _call_impl(self, *input, **kwargs): # noqa: C901 + # pyre-ignore[16] + forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.forward) + # If we don't have any hooks, we want to skip the rest of the logic in + # this function, and just call forward. + # pyre-ignore[16] + if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks + or _global_forward_hooks or _global_forward_pre_hooks): + return forward_call(*input, **kwargs) + # Do not call functions when jit is used + full_backward_hooks, non_full_backward_hooks = [], [] + if self._backward_hooks or _global_backward_hooks: + # pyre-ignore[16] + full_backward_hooks, non_full_backward_hooks = self._get_backward_hooks() + if _global_forward_pre_hooks or self._forward_pre_hooks: + # pyre-ignore[60]: Concatenation not yet support for multiple variadic + # tuples: `*torch.nn.modules.module._global_forward_pre_hooks.values(), + # *self._forward_pre_hooks.values()`. + for hook in (*_global_forward_pre_hooks.values(), *self._forward_pre_hooks.values()): + if len(inspect.signature(hook).parameters) == 3: + result = hook(self, input, kwargs) + else: + result = hook(self, input) + if result is not None: + if not isinstance(result, tuple): + result = (result,) + input = result + + bw_hook = None + if full_backward_hooks: + bw_hook = hooks.BackwardHook(self, full_backward_hooks) + input = bw_hook.setup_input_hook(input) + + result = forward_call(*input, **kwargs) + if _global_forward_hooks or self._forward_hooks: + # pyre-ignore[60]: Concatenation not yet support for multiple variadic + # tuples: `*torch.nn.modules.module._global_forward_hooks.values(), + # *self._forward_hooks.values()`. + for hook in (*_global_forward_hooks.values(), *self._forward_hooks.values()): + hook_result = hook(self, input, result) + if hook_result is not None: + result = hook_result + + if bw_hook: + result = bw_hook.setup_output_hook(result) + + # Handle the non-full backward hooks + if non_full_backward_hooks: + var = result + while not isinstance(var, torch.Tensor): + if isinstance(var, dict): + var = next((v for v in var.values() if isinstance(v, torch.Tensor))) + else: + var = var[0] + # pyre-ignore[16] + grad_fn = var.grad_fn + if grad_fn is not None: + for hook in non_full_backward_hooks: + wrapper = functools.partial(hook, self) + functools.update_wrapper(wrapper, hook) + grad_fn.register_hook(wrapper) + # pyre-ignore[16] + self._maybe_warn_non_full_backward_hook(input, result, grad_fn) + + return result + # fmt: on + + # pyre-ignore[4] + __call__: Callable[..., Any] = _call_impl diff --git a/torchrec/modules/mlp.py b/torchrec/modules/mlp.py new file mode 100644 index 000000000..f567773bb --- /dev/null +++ b/torchrec/modules/mlp.py @@ -0,0 +1,161 @@ +#!/usr/bin/env python3 + +from typing import Callable, List, Optional, Union + +import torch +from torch import nn +from torchrec.modules.activation import SwishLayerNorm +from torchrec.modules.utils import extract_module_or_tensor_callable + + +class Perceptron(torch.nn.Module): + """ + Applies a linear transformation and activation. + + Constructor Args: + in_size (int): number of elements in each input sample. + out_size (int): number of elements in each output sample. + bias (bool): if set to ``False``, the layer will not learn an additive bias. + Default: ``True``. + activation (Union[torch.nn.Module, Callable[[torch.Tensor], torch.Tensor]]): + the activation function to apply to the output of linear transformation. + Default: torch.relu. + device: (Optional[torch.device]). + + Call Args: + input (torch.Tensor): tensor of shape (B, I) where I is number of elements + in each input sample. + + Returns: + output (torch.Tensor): tensor of shape (B, O) where O is number of elements + per channel in each output sample (i.e. `out_size`). + + Example: + >>> batch_size = 3 + >>> in_size = 40 + >>> input = torch.randn(batch_size, in_size) + >>> + >>> out_size = 16 + >>> perceptron = Perceptron(in_size, out_size, bias=True) + >>> + >>> output = perceptron(input) + >>> assert list(output) == [batch_size, out_size] + """ + + def __init__( + self, + in_size: int, + out_size: int, + bias: bool = True, + activation: Union[ + torch.nn.Module, + Callable[[torch.Tensor], torch.Tensor], + ] = torch.relu, + device: Optional[torch.device] = None, + ) -> None: + super().__init__() + torch._C._log_api_usage_once(f"torchrec.modules.{self.__class__.__name__}") + self._out_size = out_size + self._in_size = in_size + self._linear: nn.Linear = nn.Linear( + self._in_size, self._out_size, bias=bias, device=device + ) + self._activation_fn: Callable[[torch.Tensor], torch.Tensor] = activation + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return self._activation_fn(self._linear(input)) + + +class MLP(torch.nn.Module): + """ + Applies a stack of Perceptron modules sequentially (i.e. Multi-Layer Perceptron). + + Constructor Args: + in_size (int): `in_size` of the input + layer_sizes (List[int]): `out_size` of each Perceptron module. + bias (bool): if set to False, the layer will not learn an additive bias. + Default: True. + activation (str, Union[Callable[[], torch.nn.Module], torch.nn.Module, Callable[[torch.Tensor], torch.Tensor]]): + the activation function to apply to the output of linear transformation of each Perceptron module. + If `activation` is a `str`, we currently only support the follow strings, as "relu", "sigmoid", + and "swish_layernorm". + If `activation` is a `Callable[[], torch.nn.Module]`, `activation()` will be called once per Perceptron module + to generate the activation module for that Perceptron module, and the parameters won't be shared + between those activation modules. One use case is when all the activation modules share the same + constructor arguments, but don't share the actual module parameters. + Default: torch.relu. + device: (Optional[torch.device]). + + Call Args: + input (torch.Tensor): tensor of shape (B, I) where I is number of elements + in each input sample. + + Returns: + output (torch.Tensor): tensor of shape (B, O) where O is `out_size` of + the last Perceptron module. + + Example: + >>> batch_size = 3 + >>> in_size = 40 + >>> input = torch.randn(batch_size, in_size) + >>> + >>> layer_sizes = [16, 8, 4] + >>> mlp_module = MLP(in_size, layer_sizes, bias=True) + >>> output = mlp_module(input) + >>> assert list(output.shape) == [batch_size, layer_sizes[-1]] + """ + + def __init__( + self, + in_size: int, + layer_sizes: List[int], + bias: bool = True, + activation: Union[ + str, + Callable[[], torch.nn.Module], + torch.nn.Module, + Callable[[torch.Tensor], torch.Tensor], + ] = torch.relu, + device: Optional[torch.device] = None, + ) -> None: + super().__init__() + + if activation == "relu": + activation = torch.relu + elif activation == "sigmoid": + activation = torch.sigmoid + + if not isinstance(activation, str): + self._mlp: torch.nn.Module = torch.nn.Sequential( + *[ + Perceptron( + layer_sizes[i - 1] if i > 0 else in_size, + layer_sizes[i], + bias=bias, + activation=extract_module_or_tensor_callable(activation), + device=device, + ) + for i in range(len(layer_sizes)) + ] + ) + else: + if activation == "swish_layernorm": + self._mlp: torch.nn.Module = torch.nn.Sequential( + *[ + Perceptron( + layer_sizes[i - 1] if i > 0 else in_size, + layer_sizes[i], + bias=bias, + activation=SwishLayerNorm(layer_sizes[i], device=device), + device=device, + ) + for i in range(len(layer_sizes)) + ] + ) + else: + assert ( + ValueError + ), "This MLP only support str version activation function of relu, sigmoid, and swish_layernorm" + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return self._mlp(input) diff --git a/torchrec/modules/score_learning.py b/torchrec/modules/score_learning.py new file mode 100644 index 000000000..9c519246e --- /dev/null +++ b/torchrec/modules/score_learning.py @@ -0,0 +1,89 @@ +#!/usr/bin/env python3 + +from typing import Dict + +import torch +import torch.nn as nn +from torchrec.sparse.jagged_tensor import ( + KeyedJaggedTensor, +) + + +# pyre-fixme[56]: Pyre was not able to infer the type of the decorator `torch.fx.wrap`. +@torch.fx.wrap +def lengths_range_fill(lengths: torch.Tensor) -> torch.Tensor: + """ + Generate arange list for each length element + Example: + lengths = torch.Tensor([3, 1, 2]) + return torch.Tensor([0, 1, 2, 0, 0, 1]) + """ + seq_list = [torch.arange(start=0, end=i, dtype=torch.int64) for i in lengths] + return torch.cat(seq_list) + + +class PositionWeightsAttacher(nn.Module): + """ + Map id list features to id score list features using each id's + position in the sample. + + Constructor Args: + features_max_length (Dict[str, int]): feature name to max_length mapping. + max_length, a.k.a truncation size, specifies the maximum number of ids + each sample has. For each feature, its position weight parameter size + is max_length. + + Call Args: + features: KeyedJaggedTensor + + Returns: + weighted_features (KeyedJaggedTensor): same as input features with weights + field being populated + + Example: + >>> features_max_length = {"f1": 10, "f2": 3} + pw = PositionWeightsAttacher(features_max_lengths) + + # 0 1 2 <-- batch + # 0 [0,1] None [2] + # 1 [3] [4] [5,6,7] + # ^ + # feature + features = KeyedJaggedTensor.from_offsets_sync( + keys=["f1", "f2"], + values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]), + offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]), + ) + weighted_features = pw(features) + """ + + def __init__( + self, + features_max_length: Dict[str, int], + ) -> None: + super().__init__() + self.features_max_length = features_max_length + self.position_weights = nn.ParameterDict() + for feature_name, max_length in features_max_length.items(): + # pyre-fixme[29]: `Union[nn.Module, torch.Tensor]` is not a function. + self.position_weights[feature_name] = nn.Parameter(torch.ones(max_length)) + + def forward( + self, + features: KeyedJaggedTensor, + ) -> KeyedJaggedTensor: + features_weights = [] + for feature_name, _ in self.features_max_length.items(): + lengths = features[feature_name].lengths() + # TODO(T92151660): replace pt ops with fbgemm's lengths_range_w_truncation_size + # and fast_gather + seq = lengths_range_fill(lengths) + weights = torch.gather(self.position_weights[feature_name], 0, seq) + features_weights.append(weights) + weights = torch.cat(features_weights) + return KeyedJaggedTensor.from_lengths_sync( + keys=features.keys(), + values=features.values(), + lengths=features.lengths(), + weights=weights, + ) diff --git a/torchrec/modules/tests/__init__.py b/torchrec/modules/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/torchrec/modules/tests/test_activation.py b/torchrec/modules/tests/test_activation.py new file mode 100644 index 000000000..09be32cfe --- /dev/null +++ b/torchrec/modules/tests/test_activation.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 + +import unittest + +import torch +from torchrec.fx import Tracer +from torchrec.modules.activation import SwishLayerNorm + + +class TestActivation(unittest.TestCase): + def test_swish_takes_float(self) -> None: + m = SwishLayerNorm([3, 4]) + input = torch.randn(2, 3, 4) + output = m(input) + norm = torch.nn.LayerNorm([3, 4]) + ref_output = input * torch.sigmoid(norm(input)) + self.assertTrue(torch.allclose(output, ref_output)) + + def test_fx_script_swish(self) -> None: + m = SwishLayerNorm(10) + + gm = torch.fx.GraphModule(m, Tracer().trace(m)) + torch.jit.script(gm) + + +if __name__ == "__main__": + unittest.main() diff --git a/torchrec/modules/tests/test_code_quality.py b/torchrec/modules/tests/test_code_quality.py new file mode 100644 index 000000000..42b06b1f6 --- /dev/null +++ b/torchrec/modules/tests/test_code_quality.py @@ -0,0 +1,30 @@ +#!/usr/bin/env python3 + +import inspect +import sys +import unittest + +import torch +import torchrec # noqa +from torchrec.linter.module_linter import MAX_NUM_ARGS_IN_MODULE_CTOR + + +class CodeQualityTest(unittest.TestCase): + def test_num_ctor_args(self) -> None: + classes = inspect.getmembers(sys.modules["torchrec"], inspect.isclass) + for class_name, clazz in classes: + if issubclass(clazz, torch.nn.Module): + num_args_excluding_self = ( + len(inspect.getfullargspec(clazz.__init__).args) - 1 + ) + self.assertLessEqual( + num_args_excluding_self, + MAX_NUM_ARGS_IN_MODULE_CTOR, + "Modules in TorchRec can have no more than {} constructor args, but {} has {}.".format( + MAX_NUM_ARGS_IN_MODULE_CTOR, class_name, num_args_excluding_self + ), + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/torchrec/modules/tests/test_crossnet.py b/torchrec/modules/tests/test_crossnet.py new file mode 100644 index 000000000..c0c91da59 --- /dev/null +++ b/torchrec/modules/tests/test_crossnet.py @@ -0,0 +1,170 @@ +#!/usr/bin/env python3 + +import unittest + +import torch +from torch.fx import GraphModule, Tracer +from torchrec.modules.crossnet import ( + CrossNet, + LowRankCrossNet, + VectorCrossNet, + LowRankMixtureCrossNet, +) + +# unit test for Full Rank CrossNet: CrossNet +class TestCrossNet(unittest.TestCase): + def test_cross_net_numercial_forward(self) -> None: + torch.manual_seed(0) + + batch_size = 3 + num_layers = 20 + in_features = 2 + input = torch.randn(batch_size, in_features) + + # test using vector for crossing + dcn = CrossNet(in_features=in_features, num_layers=num_layers) + output = dcn(input) + expected_output = torch.Tensor( + [ + [2.4481, 2.2710], + [-63.1721, -109.2410], + [1.4030, 1.0054], + ] + ) + self.assertTrue(torch.allclose(output, expected_output, rtol=1e-4, atol=1e-4)) + + def test_fx_script_cross_net(self) -> None: + input = torch.randn(2, 3) + dcn = CrossNet(in_features=3, num_layers=2) + dcn(input) + + # dry-run to initialize lazy module + gm = GraphModule(dcn, Tracer().trace(dcn)) + torch.jit.script(gm) + + +# unit test for Low Rank CrossNet: LowRankCrossNet +class TestLowRankCrossNet(unittest.TestCase): + def test_cross_net_numercial_forward(self) -> None: + torch.manual_seed(0) + + batch_size = 3 + num_layers = 20 + in_features = 2 + input = torch.randn(batch_size, in_features) + + # test using vector for crossing + dcn = LowRankCrossNet( + in_features=in_features, num_layers=num_layers, low_rank=10 + ) + output = dcn(input) + expected_output = torch.Tensor( + [ + [-11.5000, -3.4863], + [-0.2742, -0.3330], + [249.6694, 117.3466], + ] + ) + self.assertTrue(torch.allclose(output, expected_output, rtol=1e-4, atol=1e-4)) + + def test_fx_script_cross_net(self) -> None: + input = torch.randn(2, 3) + dcn = LowRankCrossNet(in_features=3, num_layers=2, low_rank=2) + dcn(input) + + # dry-run to initialize lazy module + gm = GraphModule(dcn, Tracer().trace(dcn)) + torch.jit.script(gm) + + +# unit test for Vector Version CrossNet: VectorCrossNet +class TestVectorCrossNet(unittest.TestCase): + def test_cross_net_numercial_forward(self) -> None: + torch.manual_seed(0) + + batch_size = 3 + num_layers = 20 + in_features = 2 + input = torch.randn(batch_size, in_features) + + # test using vector for crossing + dcn = VectorCrossNet(in_features=in_features, num_layers=num_layers) + output = dcn(input) + expected_output = torch.Tensor( + [ + [1.8289e-04, -3.4827e-05], + [-2.2084e02, 5.7615e01], + [-1.3328e02, -1.7187e02], + ] + ) + self.assertTrue(torch.allclose(output, expected_output, rtol=1e-4, atol=1e-4)) + + def test_fx_script_cross_net(self) -> None: + input = torch.randn(2, 3) + dcn = VectorCrossNet(in_features=3, num_layers=2) + dcn(input) + + # dry-run to initialize lazy module + gm = GraphModule(dcn, Tracer().trace(dcn)) + torch.jit.script(gm) + + +# unit test for Low Rank CrossNet with Mixture of Expert: LowRankMixtureCrossNet +class TestLowRankMixtureCrossNet(unittest.TestCase): + def test_cross_net_numercial_forward(self) -> None: + torch.manual_seed(0) + + batch_size = 3 + num_layers = 20 + in_features = 2 + input = torch.randn(batch_size, in_features) + + # test using vector for crossing + dcn = LowRankMixtureCrossNet( + in_features=in_features, num_layers=num_layers, num_experts=4, low_rank=10 + ) + output = dcn(input) + expected_output = torch.Tensor( + [ + [1.6171, -0.3217], + [-2.7060, 0.5359], + [-1.2054, -1.3132], + ] + ) + self.assertTrue(torch.allclose(output, expected_output, rtol=1e-4, atol=1e-4)) + + def test_cross_net_numercial_forward_1_expert(self) -> None: + torch.manual_seed(0) + + batch_size = 3 + num_layers = 20 + in_features = 2 + input = torch.randn(batch_size, in_features) + + # test using vector for crossing + dcn = LowRankMixtureCrossNet( + in_features=in_features, num_layers=num_layers, num_experts=1, low_rank=10 + ) + output = dcn(input) + expected_output = torch.Tensor( + [ + [3.9203, -0.2686], + [-9.5767, 0.8621], + [-2.5836, -1.8124], + ] + ) + print(output) + self.assertTrue(torch.allclose(output, expected_output, rtol=1e-4, atol=1e-4)) + + def test_fx_script_cross_net(self) -> None: + input = torch.randn(2, 3) + dcn = LowRankMixtureCrossNet(in_features=3, num_layers=2) + dcn(input) + + # dry-run to initialize lazy module + gm = GraphModule(dcn, Tracer().trace(dcn)) + torch.jit.script(gm) + + +if __name__ == "__main__": + unittest.main() diff --git a/torchrec/modules/tests/test_deepfm.py b/torchrec/modules/tests/test_deepfm.py new file mode 100644 index 000000000..9a616d474 --- /dev/null +++ b/torchrec/modules/tests/test_deepfm.py @@ -0,0 +1,166 @@ +#!/usr/bin/env python3 + +import unittest + +import torch +from torchrec.fx import Tracer +from torchrec.modules.deepfm import ( + DeepFM, + FactorizationMachine, +) + + +class TestDeepFM(unittest.TestCase): + def test_deepfm_shape(self) -> None: + + batch_size = 3 + output_dim = 30 + # the input embedding are in torch.Tensor of [batch_size, num_embeddings, embedding_dim] + input_embeddings = [ + torch.randn(batch_size, 2, 64), + torch.randn(batch_size, 2, 32), + torch.randn(batch_size, 3, 100), + torch.randn(batch_size, 5, 120), + ] + in_features = 2 * 64 + 2 * 32 + 3 * 100 + 5 * 120 + dense_module = torch.nn.Sequential( + torch.nn.Linear(in_features, 128), + torch.nn.ReLU(), + torch.nn.Linear(128, 32), + torch.nn.ReLU(), + torch.nn.Linear(32, output_dim), + torch.nn.ReLU(), + ) + deepfm = DeepFM(dense_module=dense_module) + + deep_fm_output = deepfm(input_embeddings) + + self.assertEqual(list(deep_fm_output.shape), [batch_size, output_dim]) + + def test_deepfm_with_lazy_shape(self) -> None: + batch_size = 3 + output_dim = 30 + # the input embedding are in torch.Tensor of [batch_size, num_embeddings, embedding_dim] + input_embeddings = [ + torch.randn(batch_size, 2, 64), + torch.randn(batch_size, 2, 32), + torch.randn(batch_size, 3, 100), + torch.randn(batch_size, 5, 120), + ] + dense_module = torch.nn.Sequential( + torch.nn.LazyLinear(output_dim), + torch.nn.ReLU(), + ) + deepfm = DeepFM(dense_module=dense_module) + + deep_fm_output = deepfm(input_embeddings) + + self.assertEqual(list(deep_fm_output.shape), [batch_size, output_dim]) + + def test_deepfm_numerical_forward(self) -> None: + torch.manual_seed(0) + + batch_size = 3 + output_dim = 2 + # the input embedding are in torch.Tensor of [batch_size, num_embeddings, embedding_dim] + input_embeddings = [ + torch.randn(batch_size, 2, 64), + torch.randn(batch_size, 2, 32), + torch.randn(batch_size, 3, 100), + torch.randn(batch_size, 5, 120), + ] + in_features = 2 * 64 + 2 * 32 + 3 * 100 + 5 * 120 + dense_module = torch.nn.Sequential( + torch.nn.Linear(in_features, 128), + torch.nn.ReLU(), + torch.nn.Linear(128, 32), + torch.nn.ReLU(), + torch.nn.Linear(32, output_dim), + torch.nn.ReLU(), + ) + deepfm = DeepFM(dense_module=dense_module) + + output = deepfm(input_embeddings) + + expected_output = torch.Tensor( + [ + [0.0896, 0.1182], + [0.0675, 0.0972], + [0.0764, 0.0199], + ], + ) + self.assertTrue( + torch.allclose( + output, + expected_output, + rtol=1e-4, + atol=1e-4, + ) + ) + + def test_fx_script_deepfm(self) -> None: + m = DeepFM(dense_module=torch.nn.Linear(4, 1)) + + # dryrun to initialize the input + m([torch.randn(2, 2, 2)]) + gm = torch.fx.GraphModule(m, Tracer().trace(m)) + torch.jit.script(gm) + + +class TestFM(unittest.TestCase): + def test_fm_shape(self) -> None: + + batch_size = 3 + # the input embedding are in torch.Tensor of [batch_size, num_embeddings, embedding_dim] + input_embeddings = [ + torch.randn(batch_size, 2, 64), + torch.randn(batch_size, 2, 32), + torch.randn(batch_size, 3, 100), + torch.randn(batch_size, 5, 120), + ] + + fm = FactorizationMachine() + + fm_output = fm(input_embeddings) + + self.assertEqual(list(fm_output.shape), [batch_size, 1]) + + def test_fm_numerical_forward(self) -> None: + torch.manual_seed(0) + + batch_size = 3 + # the input embedding are in torch.Tensor of [batch_size, num_embeddings, embedding_dim] + input_embeddings = [ + torch.randn(batch_size, 2, 64), + torch.randn(batch_size, 2, 32), + torch.randn(batch_size, 3, 100), + torch.randn(batch_size, 5, 120), + ] + fm = FactorizationMachine() + + output = fm(input_embeddings) + + expected_output = torch.Tensor( + [ + [-577.5231], + [752.7272], + [-509.1023], + ] + ) + self.assertTrue( + torch.allclose( + output, + expected_output, + rtol=1e-4, + atol=1e-4, + ) + ) + + def test_fx_script_fm(self) -> None: + m = FactorizationMachine() + gm = torch.fx.GraphModule(m, Tracer().trace(m)) + torch.jit.script(gm) + + +if __name__ == "__main__": + unittest.main() diff --git a/torchrec/modules/tests/test_embedding_modules.py b/torchrec/modules/tests/test_embedding_modules.py new file mode 100644 index 000000000..bd2960bbc --- /dev/null +++ b/torchrec/modules/tests/test_embedding_modules.py @@ -0,0 +1,264 @@ +#!/usr/bin/env python3 + +import unittest +from typing import Union, Dict + +import torch +import torch.fx +from torchrec.fx import symbolic_trace +from torchrec.modules.embedding_configs import ( + EmbeddingBagConfig, + EmbeddingConfig, +) +from torchrec.modules.embedding_modules import ( + EmbeddingBagCollection, + EmbeddingCollection, +) +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, JaggedTensor + + +class EmbeddingBagCollectionTest(unittest.TestCase): + def test_unweighted(self) -> None: + eb1_config = EmbeddingBagConfig( + name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1"] + ) + eb2_config = EmbeddingBagConfig( + name="t2", embedding_dim=4, num_embeddings=10, feature_names=["f2"] + ) + ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config]) + + # 0 1 2 <-- batch + # 0 [0,1] None [2] + # 1 [3] [4] [5,6,7] + # ^ + # feature + features = KeyedJaggedTensor.from_offsets_sync( + keys=["f1", "f2"], + values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]), + offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]), + ) + + pooled_embeddings = ebc(features) + self.assertEqual(pooled_embeddings.values().size(), (3, 7)) + self.assertEqual(pooled_embeddings.keys(), ["f1", "f2"]) + self.assertEqual(pooled_embeddings.offset_per_key(), [0, 3, 7]) + + def test_shared_tables(self) -> None: + eb_config = EmbeddingBagConfig( + name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1", "f2"] + ) + ebc = EmbeddingBagCollection(tables=[eb_config]) + + # 0 1 2 <-- batch + # 0 [0,1] None [2] + # 1 [3] [4] [5,6,7] + # ^ + # feature + features = KeyedJaggedTensor.from_offsets_sync( + keys=["f1", "f2"], + values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]), + offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]), + ) + + pooled_embeddings = ebc(features) + self.assertEqual(pooled_embeddings.values().size(), (3, 6)) + self.assertEqual(pooled_embeddings.keys(), ["f1", "f2"]) + self.assertEqual(pooled_embeddings.offset_per_key(), [0, 3, 6]) + + def test_shared_features(self) -> None: + eb1_config = EmbeddingBagConfig( + name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1"] + ) + eb2_config = EmbeddingBagConfig( + name="t2", embedding_dim=4, num_embeddings=10, feature_names=["f1"] + ) + ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config]) + + # 0 1 2 <-- batch + # 0 [0,1] None [2] + # 1 [3] [4] [5,6,7] + # ^ + # feature + features = KeyedJaggedTensor.from_offsets_sync( + keys=["f1"], + values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]), + offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]), + ) + + pooled_embeddings = ebc(features) + self.assertEqual(pooled_embeddings.values().size(), (6, 7)) + self.assertEqual(pooled_embeddings.keys(), ["f1@t1", "f1@t2"]) + self.assertEqual(pooled_embeddings.offset_per_key(), [0, 3, 7]) + + def test_weighted(self) -> None: + eb1_config = EmbeddingBagConfig( + name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1", "f3"] + ) + eb2_config = EmbeddingBagConfig( + name="t2", + embedding_dim=4, + num_embeddings=10, + feature_names=["f2"], + ) + ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config], is_weighted=True) + + features = KeyedJaggedTensor.from_offsets_sync( + keys=["f1", "f3", "f2"], + values=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9, 1, 3, 4, 7]), + offsets=torch.tensor([0, 2, 4, 6, 8, 10, 12]), + weights=torch.tensor( + [0.1, 0.2, 0.4, 0.5, 0.4, 0.3, 0.2, 0.9, 0.1, 0.3, 0.4, 0.7] + ), + ) + + pooled_embeddings = ebc(features) + self.assertEqual(pooled_embeddings.values().size(), (2, 10)) + self.assertEqual(pooled_embeddings.keys(), ["f1", "f3", "f2"]) + self.assertEqual(pooled_embeddings.offset_per_key(), [0, 3, 6, 10]) + + def test_fx(self) -> None: + eb1_config = EmbeddingBagConfig( + name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1", "f3"] + ) + eb2_config = EmbeddingBagConfig( + name="t2", + embedding_dim=4, + num_embeddings=10, + feature_names=["f2"], + ) + ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config], is_weighted=True) + + gm = symbolic_trace(ebc) + + features = KeyedJaggedTensor.from_offsets_sync( + keys=["f1", "f3", "f2"], + values=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9, 1, 3, 4, 7]), + offsets=torch.tensor([0, 2, 4, 6, 8, 10, 12]), + weights=torch.tensor( + [0.1, 0.2, 0.4, 0.5, 0.4, 0.3, 0.2, 0.9, 0.1, 0.3, 0.4, 0.7] + ), + ) + + pooled_embeddings = gm(features) + self.assertEqual(pooled_embeddings.values().size(), (2, 10)) + self.assertEqual(pooled_embeddings.keys(), ["f1", "f3", "f2"]) + self.assertEqual(pooled_embeddings.offset_per_key(), [0, 3, 6, 10]) + + # TODO(T89043538): Auto-generate this test. + def test_fx_script(self) -> None: + eb1_config = EmbeddingBagConfig( + name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1"] + ) + eb2_config = EmbeddingBagConfig( + name="t2", embedding_dim=4, num_embeddings=10, feature_names=["f2"] + ) + + ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config]) + + gm = symbolic_trace(ebc) + torch.jit.script(gm) + + def test_duplicate_config_name_fails(self) -> None: + eb1_config = EmbeddingBagConfig( + name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1"] + ) + eb2_config = EmbeddingBagConfig( + name="t1", embedding_dim=4, num_embeddings=10, feature_names=["f2"] + ) + with self.assertRaises(ValueError): + EmbeddingBagCollection(tables=[eb1_config, eb2_config]) + + def test_device(self) -> None: + config = EmbeddingBagConfig( + name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1"] + ) + ebc = EmbeddingBagCollection(tables=[config], device=torch.device("meta")) + self.assertEquals(torch.device("meta"), ebc.embedding_bags["t1"].weight.device) + + +class EmbeddingCollectionTest(unittest.TestCase): + def test_simple(self) -> None: + e1_config = EmbeddingConfig( + name="t1", embedding_dim=2, num_embeddings=10, feature_names=["f1"] + ) + e2_config = EmbeddingConfig( + name="t2", embedding_dim=3, num_embeddings=10, feature_names=["f2"] + ) + ec = EmbeddingCollection(tables=[e1_config, e2_config]) + + # 0 1 2 <-- batch + # 0 [0,1] None [2] + # 1 [3] [4] [5,6,7] + # ^ + # feature + features = KeyedJaggedTensor.from_offsets_sync( + keys=["f1", "f2"], + values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]), + offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]), + ) + feature_embeddings = ec(features) + + self.assertEqual(feature_embeddings["f1"].values().size(), (3, 2)) + self.assertTrue( + torch.allclose(feature_embeddings["f1"].lengths(), torch.tensor([2, 0, 1])) + ) + self.assertEqual(feature_embeddings["f2"].values().size(), (5, 3)) + self.assertTrue( + torch.allclose(feature_embeddings["f2"].lengths(), torch.tensor([1, 1, 3])) + ) + + def test_fx(self) -> None: + e1_config = EmbeddingConfig( + name="t1", embedding_dim=2, num_embeddings=10, feature_names=["f1"] + ) + e2_config = EmbeddingConfig( + name="t2", embedding_dim=3, num_embeddings=10, feature_names=["f2"] + ) + ec = EmbeddingCollection(tables=[e1_config, e2_config]) + + gm = symbolic_trace(ec) + + features = KeyedJaggedTensor.from_offsets_sync( + keys=["f1", "f2"], + values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]), + offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]), + ) + feature_embeddings = gm(features) + self.assertEqual(feature_embeddings["f1"].values().size(), (3, 2)) + self.assertTrue( + torch.allclose(feature_embeddings["f1"].lengths(), torch.tensor([2, 0, 1])) + ) + self.assertEqual(feature_embeddings["f2"].values().size(), (5, 3)) + self.assertTrue( + torch.allclose(feature_embeddings["f2"].lengths(), torch.tensor([1, 1, 3])) + ) + + # TODO(T89043538): Auto-generate this test. + def test_fx_script(self) -> None: + e1_config = EmbeddingConfig( + name="t1", embedding_dim=2, num_embeddings=10, feature_names=["f1"] + ) + e2_config = EmbeddingConfig( + name="t2", embedding_dim=3, num_embeddings=10, feature_names=["f2"] + ) + ec = EmbeddingCollection(tables=[e1_config, e2_config]) + + gm = symbolic_trace(ec) + torch.jit.script(gm) + + def test_duplicate_config_name_fails(self) -> None: + e1_config = EmbeddingConfig( + name="t1", embedding_dim=2, num_embeddings=10, feature_names=["f1"] + ) + e2_config = EmbeddingConfig( + name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f2"] + ) + with self.assertRaises(ValueError): + EmbeddingCollection(tables=[e1_config, e2_config]) + + def test_device(self) -> None: + config = EmbeddingConfig( + name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1"] + ) + eb = EmbeddingCollection(tables=[config], device=torch.device("meta")) + self.assertEquals(torch.device("meta"), eb.embeddings["t1"].weight.device) diff --git a/torchrec/modules/tests/test_lazy_extension.py b/torchrec/modules/tests/test_lazy_extension.py new file mode 100644 index 000000000..1dcab915c --- /dev/null +++ b/torchrec/modules/tests/test_lazy_extension.py @@ -0,0 +1,297 @@ +#!/usr/bin/env python3 + +import inspect +import re +import unittest +from typing import Tuple + +import torch +from torch.nn.modules.lazy import LazyModuleMixin +from torchrec.modules.lazy_extension import ( + LazyModuleExtensionMixin, + lazy_apply, +) + + +def remove_comment(source_code: str) -> str: + result = re.sub(r"\s*#.*", "", str(source_code)) + return result + + +class TestLazyModuleExtensionMixin(unittest.TestCase): + def test_source_code_parity_on_call_impl(self) -> None: + original_call_impl_src = inspect.getsource(torch.nn.Module._call_impl) + lazy_ext_call_impl_src = inspect.getsource(LazyModuleExtensionMixin._call_impl) + + # remove comments + original_call_impl_src = remove_comment(original_call_impl_src) + lazy_ext_call_impl_src = remove_comment(lazy_ext_call_impl_src) + + # reproduce the only change: + old_code = """ + result = hook(self, input) + """ + new_code = """ + if len(inspect.signature(hook).parameters) == 3: + result = hook(self, input, kwargs) + else: + result = hook(self, input) + """ + expected_lazy_ext_call_impl_src = original_call_impl_src.replace( + old_code, new_code + ) + + self.assertEqual( + lazy_ext_call_impl_src, + expected_lazy_ext_call_impl_src, + "Please make sure `LazyModuleExtensionMixin._call_impl` has the same source code " + "as `torch.nn.Module._call_impl` except the expected difference that is checked " + "in this unit test.", + ) + + def test_source_code_parity_on_infer_parameters(self) -> None: + original_infer_parameters_src = inspect.getsource( + LazyModuleMixin._infer_parameters + ) + lazy_ext_infer_parameters_src = inspect.getsource( + LazyModuleExtensionMixin._infer_parameters + ) + + # remove comments + original_infer_parameters_src = remove_comment(original_infer_parameters_src) + lazy_ext_infer_parameters_src = remove_comment(lazy_ext_infer_parameters_src) + + # reproduce the only changes: + expected_lazy_ext_infer_parameters_src = original_infer_parameters_src.replace( + "def _infer_parameters(self: _LazyProtocol, module, input):", + "def _infer_parameters(self: _LazyExtensionProtocol, module, input, kwargs):", + ).replace( + "module.initialize_parameters(*input)", + "module.initialize_parameters(*input, **kwargs)", + ) + + self.assertEqual( + lazy_ext_infer_parameters_src, + expected_lazy_ext_infer_parameters_src, + "Please make sure `LazyModuleExtensionMixin._infer_parameters` has the same source " + "code as `LazyModuleMixin._infer_parameters` except the expected difference that " + "is checked in this unit test.", + ) + + def test_forward_pre_hook_self_function_with_input_only(self) -> None: + class TestModule(LazyModuleExtensionMixin, torch.nn.Module): + """ + Create this unit test to make sure the old way of initialize self hook function + is enabled, with the hook definition as: + valid_input_only_hook(self, module, input) + + if we run the TestModule as: + >>> m = TestModule() + >>> output = m() + >>> expected_output = torch.zeros(2, 2) + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.register_forward_pre_hook(self.valid_input_only_hook) + + def valid_input_only_hook(self, module, input): + self.output = torch.zeros(2, 2) + + def initialize_parameters(self) -> None: + return None + + def forward(self) -> torch.Tensor: + return self.output + + m = TestModule() + + # test for self function registeration with register_forward_pre_hook + output_forward = m() + self.assertTrue( + torch.allclose(output_forward, torch.zeros(2, 2)), + "Please make sure forward function is executed as expected.", + ) + + def test_forward_pre_hook_global_function_with_input_only(self) -> None: + class TestModule(LazyModuleExtensionMixin, torch.nn.Module): + """ + Create this unit test to make sure the old way of insert hook function is enabled, + with the hook definition as: + valid_input_only_hook(self, module, input) + + if we run the TestModule as: + >>> def input_only_hook(module, input_tuple): + >>> return input_tuple[0] + 1 + >>> + >>> m = TestModule() + >>> m.register_forward_pre_hook(input_only_hook) + >>> output = m(torch.zeros(2, 2)) + >>> expected_output = torch.ones(2, 2) + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def initialize_parameters(self, input) -> None: + return None + + def forward(self, input) -> torch.Tensor: + return input + + def input_only_hook( + module: torch.nn.Module, input: Tuple[torch.Tensor, ...] + ) -> torch.Tensor: + # input is tuple + return input[0] + 1 + + m = TestModule() + # pyre-fixme[29] + m.register_forward_pre_hook(input_only_hook) + output = m(torch.zeros(2, 2)) + self.assertTrue(torch.allclose(output, torch.ones(2, 2))) + + def test_lazy_apply(self) -> None: + count_original: int = 0 + count_increment: int = 1 + + class TestModule(LazyModuleExtensionMixin, torch.nn.Module): + def __init__(self): + super().__init__() + self.count = torch.tensor(count_original) + + def initialize_parameters(self, input) -> None: + pass + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return self.count.clone() + + def increment_count(module: torch.nn.Module) -> None: + if isinstance(module, TestModule): + module.count += torch.tensor(count_increment) + + def check_result(m: torch.nn.Module, count_after_first_forward: int) -> None: + # This check ensures that `lazy_apply()` is a delayed operation (i.e. the function is not applied immediately). + for count in m.parameters(): + self.assertTrue(torch.allclose(count, torch.tensor(0))) + + input = torch.tensor(321) + out = m(input) + + # This check ensures that the lazy-applied function is not called before forward function is called. + self.assertTrue(torch.allclose(out, torch.tensor(count_original))) + + # This check ensures that the lazy-applied function is called after forward function is called. + for count in m.parameters(): + self.assertTrue( + torch.allclose(count, torch.tensor(count_after_first_forward)) + ) + + # This check ensures that the lazy-applied function is removed after first forward pass is run. + out = m(input) + self.assertTrue( + torch.allclose(out, torch.tensor(count_after_first_forward)), str(out) + ) + # Since `increment_count` is not run the second time, value of `count` parameter is not changed. + for count in m.parameters(): + self.assertTrue( + torch.allclose(count, torch.tensor(count_after_first_forward)) + ) + + # fmt: off + check_result( + lazy_apply( + TestModule(), + increment_count, + ), + count_after_first_forward=1, + ) + check_result( + lazy_apply( + torch.nn.Sequential( + TestModule(), + TestModule(), + ), + increment_count, + ), + count_after_first_forward=1, + ) + check_result( + lazy_apply( + lazy_apply( + TestModule(), + increment_count, + ), + increment_count, + ), + count_after_first_forward=2, + ) + check_result( + lazy_apply( + lazy_apply( + torch.nn.Sequential( + TestModule(), + TestModule() + ), + increment_count, + ), + increment_count, + ), + count_after_first_forward=2, + ) + check_result( + lazy_apply( + torch.nn.Sequential( + lazy_apply( + TestModule(), + increment_count, + ), + torch.nn.Identity() + ), + increment_count, + ), + count_after_first_forward=2, + ) + # fmt: on + + def test_apply(self) -> None: + class TestModule(LazyModuleExtensionMixin, torch.nn.Module): + def __init__(self): + super().__init__() + self.param = torch.tensor(1.0) + + def initialize_parameters(self, input) -> None: + return None + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return input + + @torch.no_grad() + def init_weights(m: torch.nn.Module) -> None: + if type(m) == TestModule: + m.param.fill_(7.0) + + # Case 1: Running `.apply()` without running first forward pass to + # initialize the module will result in error. + net = torch.nn.Sequential(TestModule(), TestModule()) + with self.assertRaisesRegex(RuntimeError, "has not been initialized"): + net.apply(init_weights) + + # Case 2: Running `.apply()` after running first forward pass will succeed. + net(torch.tensor(2.0)) + net.apply(init_weights) + self.assertTrue(torch.allclose(net[0].param, torch.tensor(7.0))) + + # Case 3: Running `.lazy_apply()` without running first forward pass will succeed, + # and the function will be applied right after first forward pass. + net = torch.nn.Sequential(TestModule(), TestModule()) + net = lazy_apply(net, init_weights) + # pyre-ignore[29] + self.assertTrue(torch.allclose(net[0].param, torch.tensor(1.0))) + net(torch.tensor(2.0)) + # pyre-ignore[29] + self.assertTrue(torch.allclose(net[0].param, torch.tensor(7.0))) + + +if __name__ == "__main__": + unittest.main() diff --git a/torchrec/modules/tests/test_mlp.py b/torchrec/modules/tests/test_mlp.py new file mode 100644 index 000000000..70c89ce15 --- /dev/null +++ b/torchrec/modules/tests/test_mlp.py @@ -0,0 +1,108 @@ +#!/usr/bin/env python3 + +import unittest +from typing import Callable, List, Union + +import hypothesis.strategies as st +import torch +from hypothesis import given, settings +from torch import nn +from torchrec.fx import symbolic_trace +from torchrec.modules.mlp import Perceptron, MLP + + +class TestMLP(unittest.TestCase): + # pyre-ignore[56]: Pyre was not able to infer the type of argument + # to decorator factory `hypothesis.given`. + @given( + has_bias=st.booleans(), + activation=st.sampled_from( + [ + torch.relu, + torch.tanh, + torch.sigmoid, + nn.SiLU(), + ] + ), + ) + @settings(deadline=None) + def test_perceptron_single_channel( + self, + has_bias: bool, + activation: Union[ + torch.nn.Module, + Callable[[torch.Tensor], torch.Tensor], + ], + ) -> None: + batch_size = 3 + + input_dims: List[int] = [40, 30, 20, 10] + input_tensors: List[torch.Tensor] = [ + torch.randn(batch_size, input_dims[0]), # Task 1 + torch.randn(batch_size, input_dims[1]), # Task 2 + torch.randn(batch_size, input_dims[2]), # Task 3 + torch.randn(batch_size, input_dims[3]), # Task 4 + ] + + perceptron_layer_size = 16 + num_tasks = 4 + + perceptron_for_tasks = [ + Perceptron( + input_dims[i], + perceptron_layer_size, + bias=has_bias, + activation=activation, + ) + for i in range(num_tasks) + ] + + # Dry-run with input of a different batch size + dry_run_batch_size = 1 + assert dry_run_batch_size != batch_size + for i in range(num_tasks): + perceptron_for_tasks[i]( + torch.randn(dry_run_batch_size, input_tensors[i].shape[-1]) + ) + + output_tensors = [] + expected_output_tensors = [] + for i in range(len(input_tensors)): + output_tensors.append(perceptron_for_tasks[i](input_tensors[i])) + expected_output_tensors.append( + perceptron_for_tasks[i]._activation_fn( + perceptron_for_tasks[i]._linear(input_tensors[i]) + ) + ) + + for i in range(len(output_tensors)): + self.assertEqual( + list(output_tensors[i].shape), [batch_size, perceptron_layer_size] + ) + self.assertTrue( + torch.allclose(output_tensors[i], expected_output_tensors[i]) + ) + + def test_fx_script_Perceptron(self) -> None: + batch_size = 1 + in_features = 3 + out_features = 5 + m = Perceptron(in_features, out_features) + + # Dry-run to initialize lazy module. + m(torch.randn(batch_size, in_features)) + + gm = symbolic_trace(m) + torch.jit.script(gm) + + def test_fx_script_MLP(self) -> None: + in_features = 3 + layer_sizes = [16, 8, 4] + m = MLP(in_features, layer_sizes) + + gm = symbolic_trace(m) + torch.jit.script(gm) + + +if __name__ == "__main__": + unittest.main() diff --git a/torchrec/modules/tests/test_score_learning.py b/torchrec/modules/tests/test_score_learning.py new file mode 100644 index 000000000..c7a253550 --- /dev/null +++ b/torchrec/modules/tests/test_score_learning.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 + +import unittest + +import torch +from torchrec.fx import Tracer +from torchrec.modules.score_learning import PositionWeightsAttacher +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor + + +class PositionWeightsAttacherTest(unittest.TestCase): + def test_populate_weights(self) -> None: + features_max_length = {"f1": 10, "f2": 3} + pw = PositionWeightsAttacher(features_max_length) + + # 0 1 2 <-- batch + # 0 [0,1] None [2] + # 1 [3] [4] [5,6,7] + # ^ + # feature + features = KeyedJaggedTensor.from_offsets_sync( + keys=["f1", "f2"], + values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]), + offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]), + ) + + weighted_features = pw(features) + self.assertEqual(weighted_features.weights().size(), (8,)) + self.assertEqual(weighted_features["f1"].weights().size(), (3,)) + self.assertEqual(weighted_features["f2"].weights().size(), (5,)) + pw_f1_ref = torch.gather( + pw.state_dict()["position_weights.f1"], 0, torch.tensor([0, 1, 0]) + ) + pw_f1 = weighted_features["f1"].weights().detach() + self.assertTrue(torch.allclose(pw_f1_ref, pw_f1)) + pw_f2_ref = torch.gather( + pw.state_dict()["position_weights.f2"], 0, torch.tensor([0, 0, 0, 1, 2]) + ) + pw_f2 = weighted_features["f2"].weights().detach() + self.assertTrue(torch.allclose(pw_f2_ref, pw_f2)) + + def test_fx_script_PositionWeightsAttacher(self) -> None: + features_max_length = {"f1": 10, "f2": 3} + pw = PositionWeightsAttacher(features_max_length) + + gm = torch.fx.GraphModule(pw, Tracer().trace(pw)) + torch.jit.script(gm) + + +if __name__ == "__main__": + unittest.main() diff --git a/torchrec/modules/utils.py b/torchrec/modules/utils.py new file mode 100644 index 000000000..cfe8cfe20 --- /dev/null +++ b/torchrec/modules/utils.py @@ -0,0 +1,112 @@ +#!/usr/bin/env python3 + +import copy +from typing import Callable, Iterable, Tuple, Union + +import torch + + +def extract_module_or_tensor_callable( + module_or_callable: Union[ + Callable[[], torch.nn.Module], + torch.nn.Module, + Callable[[torch.Tensor], torch.Tensor], + ] +) -> Union[torch.nn.Module, Callable[[torch.Tensor], torch.Tensor]]: + try: + # pyre-ignore[20]: PositionalOnly call expects argument in position 0 + module = module_or_callable() + if isinstance(module, torch.nn.Module): + return module + else: + raise ValueError( + "Expected callable that takes no input to return " + "a torch.nn.Module, but got: {}".format(type(module)) + ) + except TypeError as e: + if "required positional argument" in str(e): + # pyre-ignore[7]: Expected `Union[typing.Callable[[torch.Tensor], torch.Tensor], torch.nn.Module]` + return module_or_callable + raise + + +def get_module_output_dimension( + module: Union[Callable[[torch.Tensor], torch.Tensor], torch.nn.Module], + in_features: int, +) -> int: + input = torch.zeros(1, in_features) + output = module(input) + return output.size(-1) + + +def check_module_output_dimension( + module: Union[Iterable[torch.nn.Module], torch.nn.Module], + in_features: int, + out_features: int, +) -> bool: + """ + Verify that the out_features of a given module or a list of modules matches the specified number. + If a list of modules or a ModuleList is given, recursively check all the submodules. + """ + if isinstance(module, list) or isinstance(module, torch.nn.ModuleList): + return all( + check_module_output_dimension(submodule, in_features, out_features) + for submodule in module + ) + else: + # pyre-fixme[6]: Expected `Union[typing.Callable[[torch.Tensor], + # torch.Tensor], torch.nn.Module]` for 1st param but got + # `Union[Iterable[torch.nn.Module], torch.nn.Module]`. + return get_module_output_dimension(module, in_features) == out_features + + +def init_mlp_weights_xavier_uniform(m: torch.nn.Module) -> None: + if isinstance(m, torch.nn.Linear): + torch.nn.init.xavier_uniform_(m.weight) + # pyre-fixme[16]: Optional type has no attribute `data`. + m.bias.data.fill_(0.0) + + +def construct_modulelist_from_single_module( + module: torch.nn.Module, sizes: Tuple[int, ...] +) -> torch.nn.Module: + """ + Given a single module, construct a (nested) ModuleList of size of sizes by making copies of + the provided module and reinitializing the Linear layers. + """ + if len(sizes) == 1: + return torch.nn.ModuleList( + [ + copy.deepcopy(module).apply(init_mlp_weights_xavier_uniform) + for _ in range(sizes[0]) + ] + ) + else: + # recursively create nested ModuleList + return torch.nn.ModuleList( + [ + construct_modulelist_from_single_module(module, sizes[1:]) + for _ in range(sizes[0]) + ] + ) + + +def convert_list_of_modules_to_modulelist( + modules: Iterable[torch.nn.Module], sizes: Tuple[int, ...] +) -> torch.nn.Module: + assert ( + # pyre-fixme[6]: Expected `Sized` for 1st param but got + # `Iterable[torch.nn.Module]`. + len(modules) + == sizes[0] + ), f"the counts of modules ({len(modules)}) do not match with the required counts {sizes}" + if len(sizes) == 1: + return torch.nn.ModuleList(modules) + else: + # recursively create nested list + return torch.nn.ModuleList( + # pyre-fixme[6]: Expected `Iterable[torch.nn.Module]` for 1st param but + # got `Module`. + convert_list_of_modules_to_modulelist(m, sizes[1:]) + for m in modules + ) diff --git a/torchrec/optim/__init__.py b/torchrec/optim/__init__.py new file mode 100644 index 000000000..dc9864126 --- /dev/null +++ b/torchrec/optim/__init__.py @@ -0,0 +1,11 @@ +#!/usr/bin/env python3 + +from torchrec.optim.clipping import GradientClipping, GradientClippingOptimizer # noqa +from torchrec.optim.fused import FusedOptimizer, FusedOptimizerModule # noqa +from torchrec.optim.keyed import ( + KeyedOptimizer, + CombinedOptimizer, + KeyedOptimizerWrapper, + OptimizerWrapper, +) # noqa +from torchrec.optim.warmup import WarmupPolicy, WarmupStage, WarmupOptimizer # noqa diff --git a/torchrec/optim/clipping.py b/torchrec/optim/clipping.py new file mode 100644 index 000000000..5d3fcc70a --- /dev/null +++ b/torchrec/optim/clipping.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 + +from enum import Enum, unique +from typing import Any, List + +import torch +from torchrec.optim.keyed import OptimizerWrapper, KeyedOptimizer + + +@unique +class GradientClipping(Enum): + NORM = "norm" + VALUE = "value" + NONE = "none" + + +class GradientClippingOptimizer(OptimizerWrapper): + """ + Clips gradients before doing optimization step. + + Constructor Args: + optimizer (KeyedOptimizer): optimizer to wrap + clipping (GradientClipping): how to clip gradients + max_gradient (float): max value for clipping + """ + + def __init__( + self, + optimizer: KeyedOptimizer, + clipping: GradientClipping = GradientClipping.NONE, + max_gradient: float = 0.1, + ) -> None: + super().__init__(optimizer) + self._clipping = clipping + self._max_gradient = max_gradient + + self._params: List[torch.Tensor] = [] + for param_group in self.param_groups: + self._params += list(param_group["params"]) + + # pyre-ignore [2] + def step(self, closure: Any = None) -> None: + if self._clipping == GradientClipping.NORM: + torch.nn.utils.clip_grad_norm_(self._params, self._max_gradient) + elif self._clipping == GradientClipping.VALUE: + torch.nn.utils.clip_grad_value_(self._params, self._max_gradient) + + super().step(closure) diff --git a/torchrec/optim/fused.py b/torchrec/optim/fused.py new file mode 100644 index 000000000..13a5c00cf --- /dev/null +++ b/torchrec/optim/fused.py @@ -0,0 +1,37 @@ +#!/usr/bin/env python3 + +import abc +from typing import Any + +from torch import optim +from torchrec.optim.keyed import KeyedOptimizer + + +class FusedOptimizer(KeyedOptimizer, abc.ABC): + """ + Assumes that weight update is done during backward pass, + thus step() is a no-op. + """ + + @abc.abstractmethod + # pyre-ignore [2] + def step(self, closure: Any = None) -> None: + ... + + @abc.abstractmethod + def zero_grad(self, set_to_none: bool = False) -> None: + ... + + def __repr__(self) -> str: + return optim.Optimizer.__repr__(self) + + +class FusedOptimizerModule(abc.ABC): + """ + Module, which does weight update during backward pass. + """ + + @property + @abc.abstractmethod + def fused_optimizer(self) -> KeyedOptimizer: + ... diff --git a/torchrec/optim/keyed.py b/torchrec/optim/keyed.py new file mode 100644 index 000000000..26435c862 --- /dev/null +++ b/torchrec/optim/keyed.py @@ -0,0 +1,313 @@ +#!/usr/bin/env python3 + +from copy import deepcopy +from typing import ( + Callable, + List, + Mapping, + Set, + Dict, + Any, + Collection, + Tuple, + Union, +) + +import torch +from torch import optim +from torch.distributed._sharded_tensor import ShardedTensor + + +OptimizerFactory = Callable[[List[torch.Tensor]], optim.Optimizer] + + +class KeyedOptimizer(optim.Optimizer): + """ + Takes a dict of parameters and exposes state_dict by parameter key. + """ + + def __init__( + self, + params: Mapping[str, torch.Tensor], + # pyre-ignore [2] + state: Mapping[Any, Any], + param_groups: Collection[Mapping[str, Any]], + ) -> None: + torch._C._log_api_usage_once(f"torchrec.optim.{self.__class__.__name__}") + # pyre-ignore [4] + self.state: Mapping[Any, Any] = state + self.param_groups: Collection[Mapping[str, Any]] = param_groups + self.params = params + self.defaults: Dict[str, Any] = {} + + params_set = set(params.values()) + non_param_state_keys = [key for key in self.state if key not in params_set] + if len(non_param_state_keys) > 0: + raise ValueError( + "All state keys must be params. The following keys are not: {}.".format( + non_param_state_keys + ) + ) + + def state_dict(self) -> Dict[str, Any]: + """ + Returned state and param_groups will contain parameter keys + instead of parameter indices in torch.Optimizer. + This allows for advanced functionality like optimizer re-sharding to be implemented. + """ + + state = self.state + param_groups = self.param_groups + params = self.params + param_to_key = {param: key for key, param in params.items()} + + ret_state = { + param_to_key[param]: state_val for param, state_val in state.items() + } + + ret_groups = [] + for group in param_groups: + param_keys = [] + for param in group["params"]: + param_keys.append(param_to_key[param]) + ret_group = {"params": sorted(param_keys)} + for k, v in group.items(): + if k != "params": + ret_group[k] = deepcopy(v) + ret_groups.append(ret_group) + + return {"state": ret_state, "param_groups": ret_groups} + + def post_load_state_dict(self) -> None: + pass + + def load_state_dict(self, state_dict: Mapping[str, Any]) -> None: + """ + This implementation is much stricter than the one in torch.Optimizer: + it requires implementations to fully initialize their state during first optimization iteration, + and it prohibits loading an empty state into already initialized KeyedOptimizer and vise versa. + Because of introduced strictness it allows us to: + * do compatibility checks for state and param_groups, which improves usability + * avoid state duplication by directly copying into state tensors, e.g. + optimizer.step() # make sure optimizer is initialized + sd = optimizer.state_dict() + load_checkpoint(sd) # copy state directly into tensors, re-shard if needed + optimizer.load_state_dict(sd) # replace param_groups + """ + + new_state = state_dict["state"] + new_param_groups = state_dict["param_groups"] + state = self.state + param_groups = self.param_groups + params = self.params + + # Load state + if len(state) != len(new_state): + raise ValueError( + f"Different parameter count: {len(state)} vs {len(new_state)}" + ) + for param_key, param in params.items(): + if param not in state: + continue + if param_key not in new_state: + raise ValueError(f"Parameter {param_key} not found") + if len(state[param]) != len(new_state[param_key]): + raise ValueError( + f"Different state size: {len(state[param])} vs {len(new_state[param_key])}" + ) + for state_key, state_val in state[param].items(): + if state_key not in new_state[param_key]: + raise ValueError( + f"State key {state_key} not found for param {param_key}" + ) + + new_state_val = new_state[param_key][state_key] + if isinstance(state_val, torch.Tensor): + state_val.detach().copy_(new_state_val) + elif isinstance(state_val, ShardedTensor): + num_shards = len(state_val.local_shards()) + num_new_shards = len(new_state_val.local_shards()) + if num_shards != num_new_shards: + raise ValueError( + f"Different number of shards {num_shards} vs {num_new_shards} for {param_key}/{state_key}" + ) + for shard, new_shard in zip( + state_val.local_shards(), new_state_val.local_shards() + ): + shard.tensor.detach().copy_(new_shard.tensor) + else: + state[param][state_key] = deepcopy(new_state_val) + + # Load param_groups. + if len(param_groups) != len(new_param_groups): + raise ValueError( + f"Different param_groups count: {len(param_groups)} vs {len(new_param_groups)}" + ) + param_to_key = {param: key for key, param in params.items()} + group_map = {} + for group in param_groups: + param_keys = [] + for param in group["params"]: + param_keys.append(param_to_key[param]) + group_map["/".join(sorted(param_keys))] = group + new_group_map = {} + for new_group in new_param_groups: + param_keys = [] + for param_key in new_group["params"]: + param_keys.append(param_key) + new_group_map["/".join(sorted(param_keys))] = new_group + for group_key, group in group_map.items(): + if group_key not in new_group_map: + raise ValueError(f"Group {group_key} not found") + new_group = new_group_map[group_key] + if len(group) != len(new_group): + raise ValueError( + f"Different param_group size: {len(group)} vs {len(new_group)}" + ) + for k, v in group.items(): + if k not in new_group: + raise ValueError(f"Group key {k} not found for group {group_key}") + if k != "params": + group[k] = deepcopy(new_group[k]) + + self.post_load_state_dict() + + # pyre-ignore [2] + def add_param_group(self, param_group: Any) -> None: + raise NotImplementedError() + + +class CombinedOptimizer(KeyedOptimizer): + """ + Combines multiple optimizers into one. + """ + + def __init__( + self, optims: List[Union[KeyedOptimizer, Tuple[str, KeyedOptimizer]]] + ) -> None: + self.defaults: Dict[str, Any] = {} + # Append empty optimizer key if not passed. + self._optims: List[Tuple[str, KeyedOptimizer]] = [] + for key_value in optims: + if isinstance(key_value, KeyedOptimizer): + key_value = ("", key_value) + self._optims.append(key_value) + + all_keys: Set[str] = set() + for opt_key, opt in self._optims: + for param_key in opt.params.keys(): + new_param = CombinedOptimizer._prepend_opt_key(param_key, opt_key) + if new_param in all_keys: + raise ValueError(f"Duplicate param key {new_param}") + all_keys.add(new_param) + + def __repr__(self) -> str: + ret = [] + for _, opt in self._optims: + ret.append(opt.__repr__()) + return ",".join(ret) + + def zero_grad(self, set_to_none: bool = False) -> None: + for _, opt in self._optims: + # pyre-ignore [28] + opt.zero_grad(set_to_none=set_to_none) + + # pyre-ignore [2] + def step(self, closure: Any = None) -> None: + for _, opt in self._optims: + opt.step(closure=closure) + + @property + def optimizers(self) -> List[Tuple[str, KeyedOptimizer]]: + return self._optims + + @staticmethod + def _prepend_opt_key(name: str, opt_key: str) -> str: + return opt_key + ("." if opt_key else "") + name + + @property + def param_groups(self) -> Collection[Mapping[str, Any]]: + return [ + param_group for _, opt in self._optims for param_group in opt.param_groups + ] + + @property + def params(self) -> Mapping[str, torch.Tensor]: + ret = {} + for opt_key, opt in self._optims: + for param_key, param in opt.params.items(): + ret[CombinedOptimizer._prepend_opt_key(param_key, opt_key)] = param + return ret + + @property + # pyre-ignore [3] + def state(self) -> Mapping[torch.Tensor, Any]: + ret = {} + for _, opt in self._optims: + for param, state in opt.state.items(): + ret[param] = state + return ret + + def post_load_state_dict(self) -> None: + for _, opt in self._optims: + opt.post_load_state_dict() + + +class KeyedOptimizerWrapper(KeyedOptimizer): + """ + Takes a dict of parameters and exposes state_dict by parameter key. + """ + + def __init__( + self, + params: Mapping[str, torch.Tensor], + optim_factory: OptimizerFactory, + ) -> None: + self._optimizer: optim.Optimizer = optim_factory(list(params.values())) + super().__init__(params, self._optimizer.state, self._optimizer.param_groups) + + def zero_grad(self, set_to_none: bool = False) -> None: + self._optimizer.zero_grad() + + # pyre-ignore [2] + def step(self, closure: Any = None) -> None: + self._optimizer.step(closure=closure) + + +class OptimizerWrapper(KeyedOptimizer): + def __init__(self, optimizer: KeyedOptimizer) -> None: + self._optimizer = optimizer + self.params: Mapping[str, torch.Tensor] = optimizer.params + # pyre-ignore [4] + self.state: Mapping[Any, Any] = optimizer.state + self.param_groups: Collection[Mapping[str, Any]] = optimizer.param_groups + + def __repr__(self) -> str: + return self._optimizer.__repr__() + + def zero_grad(self, set_to_none: bool = False) -> None: + # pyre-ignore [28] + self._optimizer.zero_grad(set_to_none=set_to_none) + + # pyre-ignore [2] + def step(self, closure: Any = None) -> None: + self._optimizer.step(closure=closure) + + # pyre-ignore [2] + def add_param_group(self, param_group: Any) -> None: + raise NotImplementedError() + + def state_dict(self) -> Dict[str, Any]: + return self._optimizer.state_dict() + + def post_load_state_dict(self) -> None: + self._optimizer.post_load_state_dict() + + def load_state_dict(self, state_dict: Mapping[str, Any]) -> None: + self._optimizer.load_state_dict(state_dict) + # Reassign references because self._optimizer receives new state and param_group + # references after load_state_dict. + self.state = self._optimizer.state + self.param_groups = self._optimizer.param_groups + + self.post_load_state_dict() diff --git a/torchrec/optim/tests/test_clipping.py b/torchrec/optim/tests/test_clipping.py new file mode 100644 index 000000000..6a5a2e18b --- /dev/null +++ b/torchrec/optim/tests/test_clipping.py @@ -0,0 +1,191 @@ +#!/usr/bin/env python3 + +import unittest + +import torch +from torch.autograd import Variable +from torchrec.optim.clipping import GradientClippingOptimizer, GradientClipping +from torchrec.optim.tests.test_utils import DummyKeyedOptimizer + + +class TestGradientClippingOptimizer(unittest.TestCase): + def test_clip_all_gradients_norm(self) -> None: + # Clip all gradients to zero + param_1 = Variable(torch.tensor([1.0, 2.0]), requires_grad=True) + + keyed_optimizer = DummyKeyedOptimizer( + {"param_1": param_1}, {}, [{"params": [param_1]}] + ) + + gradient_clipping_optimizer = GradientClippingOptimizer( + optimizer=keyed_optimizer, max_gradient=0.0, clipping=GradientClipping.NORM + ) + + gradient_clipping_optimizer.zero_grad() + param_1.grad = torch.tensor([1.0, 2.0]) + gradient_clipping_optimizer.step() + + self.assertTrue(torch.equal(param_1.grad, torch.tensor([0.0, 0.0]))) + + def test_clip_no_gradients_norm(self) -> None: + # gradients are too small to be clipped + param_1 = Variable(torch.tensor([1.0, 2.0]), requires_grad=True) + + keyed_optimizer = DummyKeyedOptimizer( + {"param_1": param_1}, {}, [{"params": [param_1]}] + ) + + gradient_clipping_optimizer = GradientClippingOptimizer( + optimizer=keyed_optimizer, max_gradient=1.0, clipping=GradientClipping.NORM + ) + + gradient_clipping_optimizer.zero_grad() + param_1.grad = torch.tensor([0.5, 0.5]) + gradient_clipping_optimizer.step() + + self.assertTrue(torch.equal(param_1.grad, torch.tensor([0.5, 0.5]))) + + def test_clip_partial_gradients_norm(self) -> None: + # test partial clipping + param_1 = Variable(torch.tensor([1.0, 2.0]), requires_grad=True) + + keyed_optimizer = DummyKeyedOptimizer( + {"param_1": param_1}, {}, [{"params": [param_1]}] + ) + + gradient_clipping_optimizer = GradientClippingOptimizer( + optimizer=keyed_optimizer, max_gradient=1.0, clipping=GradientClipping.NORM + ) + + gradient_clipping_optimizer.zero_grad() + + param_1.grad = torch.tensor([2.0, 4.0]) + gradient_clipping_optimizer.step() + + norm = 2.0 ** 2 + 4.0 ** 2 + expected_grad = torch.tensor([2.0, 4.0]) * norm ** (-0.5) + self.assertTrue(torch.allclose(param_1.grad, expected_grad)) + + def test_clip_partial_gradients_norm_multi_params(self) -> None: + # test partial clipping + max_gradient = 2.0 + param_1 = Variable(torch.tensor([1.0, 2.0]), requires_grad=True) + param_2 = Variable(torch.tensor([2.0, 4.0]), requires_grad=True) + + keyed_optimizer = DummyKeyedOptimizer( + {"param_1": param_1, "param_2": param_2}, + {}, + [{"params": [param_1]}, {"params": [param_2]}], + ) + + gradient_clipping_optimizer = GradientClippingOptimizer( + optimizer=keyed_optimizer, + max_gradient=max_gradient, + clipping=GradientClipping.NORM, + ) + + gradient_clipping_optimizer.zero_grad() + + param_1.grad = torch.tensor([2.0, 4.0]) + param_2.grad = torch.tensor([4.0, 8.0]) + + gradient_clipping_optimizer.step() + + print(param_1.grad, param_2.grad) + + norm = (2.0 ** 2 + 4.0 ** 2 + 4.0 ** 2 + 8.0 ** 2) ** (-0.5) + expected_grad_1 = torch.tensor([2.0, 4.0]) * norm * max_gradient + expected_grad_2 = torch.tensor([4.0, 8.0]) * norm * max_gradient + + print(param_1.grad, param_2.grad, expected_grad_1, expected_grad_2) + + self.assertTrue(torch.allclose(param_1.grad, expected_grad_1)) + self.assertTrue(torch.allclose(param_2.grad, expected_grad_2)) + + def test_clip_all_gradients_value(self) -> None: + # Clip all gradients to zero + param_1 = Variable(torch.tensor([1.0, 2.0]), requires_grad=True) + + keyed_optimizer = DummyKeyedOptimizer( + {"param_1": param_1}, {}, [{"params": [param_1]}] + ) + + gradient_clipping_optimizer = GradientClippingOptimizer( + optimizer=keyed_optimizer, max_gradient=0, clipping=GradientClipping.VALUE + ) + + gradient_clipping_optimizer.zero_grad() + param_1.grad = torch.tensor([1.0, 2.0]) + gradient_clipping_optimizer.step() + + self.assertTrue(torch.equal(param_1.grad, torch.tensor([0.0, 0.0]))) + + def test_clip_no_gradients_value(self) -> None: + # gradients are too small to be clipped + param_1 = Variable(torch.tensor([1.0, 2.0]), requires_grad=True) + + keyed_optimizer = DummyKeyedOptimizer( + {"param_1": param_1}, {}, [{"params": [param_1]}] + ) + + gradient_clipping_optimizer = GradientClippingOptimizer( + optimizer=keyed_optimizer, max_gradient=1.0, clipping=GradientClipping.VALUE + ) + + gradient_clipping_optimizer.zero_grad() + param_1.grad = torch.tensor([0.5, 0.5]) + gradient_clipping_optimizer.step() + + self.assertTrue(torch.equal(param_1.grad, torch.tensor([0.5, 0.5]))) + + def test_clip_gradients_value(self) -> None: + # test partial clipping + param_1 = Variable(torch.tensor([1.0, 2.0]), requires_grad=True) + + keyed_optimizer = DummyKeyedOptimizer( + {"param_1": param_1}, {}, [{"params": [param_1]}] + ) + + gradient_clipping_optimizer = GradientClippingOptimizer( + optimizer=keyed_optimizer, max_gradient=1, clipping=GradientClipping.VALUE + ) + + gradient_clipping_optimizer.zero_grad() + + param_1.grad = torch.tensor([2.0, 4.0]) + gradient_clipping_optimizer.step() + + expected_grad = torch.tensor([1.0, 1.0]) + + self.assertTrue(torch.allclose(param_1.grad, expected_grad)) + + def test_clip_partial_gradients_value_multi_params(self) -> None: + # test partial clipping + max_gradient = 2.0 + param_1 = Variable(torch.tensor([1.0, 2.0]), requires_grad=True) + param_2 = Variable(torch.tensor([2.0, 4.0]), requires_grad=True) + + keyed_optimizer = DummyKeyedOptimizer( + {"param_1": param_1, "param_2": param_2}, + {}, + [{"params": [param_1]}, {"params": [param_2]}], + ) + + gradient_clipping_optimizer = GradientClippingOptimizer( + optimizer=keyed_optimizer, + max_gradient=max_gradient, + clipping=GradientClipping.VALUE, + ) + + gradient_clipping_optimizer.zero_grad() + + param_1.grad = torch.tensor([2.0, 4.0]) + param_2.grad = torch.tensor([4.0, 8.0]) + + gradient_clipping_optimizer.step() + + expected_grad_1 = torch.tensor([2.0, 2.0]) + expected_grad_2 = torch.tensor([2.0, 2.0]) + + self.assertTrue(torch.allclose(param_1.grad, expected_grad_1)) + self.assertTrue(torch.allclose(param_2.grad, expected_grad_2)) diff --git a/torchrec/optim/tests/test_keyed.py b/torchrec/optim/tests/test_keyed.py new file mode 100644 index 000000000..926035dae --- /dev/null +++ b/torchrec/optim/tests/test_keyed.py @@ -0,0 +1,212 @@ +#!/usr/bin/env python3 + +import os +import unittest +from typing import Dict, Any + +import torch +import torch.distributed as dist +from torch.autograd import Variable +from torchrec.optim.keyed import ( + CombinedOptimizer, + KeyedOptimizer, + OptimizerWrapper, +) +from torchrec.tests.utils import get_free_port + + +class TestKeyedOptimizer(unittest.TestCase): + def _assert_state_dict_equals( + self, dict1: Dict[str, Any], dict2: Dict[str, Any] + ) -> None: + self.assertEqual(dict1["param_groups"], dict2["param_groups"]) + self.assertEqual( + dict1["state"]["param_2"], + dict2["state"]["param_2"], + ) + torch.testing.assert_close( + dict1["state"]["param_1"]["tensor"], + dict2["state"]["param_1"]["tensor"], + ) + + torch.testing.assert_close( + dict1["state"]["param_1"]["sharded_tensor"].local_shards()[0].tensor, + dict2["state"]["param_1"]["sharded_tensor"].local_shards()[0].tensor, + ) + + def test_load_state_dict(self) -> None: + os.environ["MASTER_ADDR"] = str("localhost") + os.environ["MASTER_PORT"] = str(get_free_port()) + dist.init_process_group("gloo", rank=0, world_size=1) + + # Set up example KeyedOptimizer. + param_1_t, param_2_t = torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0]) + param_1, param_2 = Variable(param_1_t), Variable(param_2_t) + keyed_optimizer = KeyedOptimizer( + {"param_1": param_1, "param_2": param_2}, + { + param_1: { + "one": 1.0, + "tensor": torch.tensor([5.0, 6.0]), + "sharded_tensor": dist._sharded_tensor.full( + # pyre-ignore [28] + dist._sharded_tensor.ChunkShardingSpec( + dim=0, placements=["rank:0/cpu"] + ), + (4,), + fill_value=1.0, + ), + }, + param_2: {"two": 2.0}, + }, + [ + { + "params": [param_1], + "param_group_val_0": 3.0, + "param_group_val_1": 4.0, + }, + { + "params": [param_2], + "param_group_val_0": 5.0, + "param_group_val_1": 6.0, + }, + ], + ) + + # Assert state_dict is as expected. + expected_state_dict = { + "state": { + "param_1": { + "one": 1.0, + "tensor": torch.tensor([5.0, 6.0]), + "sharded_tensor": dist._sharded_tensor.full( + # pyre-ignore [28] + dist._sharded_tensor.ChunkShardingSpec( + dim=0, placements=["rank:0/cpu"] + ), + (4,), + fill_value=1.0, + ), + }, + "param_2": {"two": 2.0}, + }, + "param_groups": [ + { + "params": ["param_1"], + "param_group_val_0": 3.0, + "param_group_val_1": 4.0, + }, + { + "params": ["param_2"], + "param_group_val_0": 5.0, + "param_group_val_1": 6.0, + }, + ], + } + self._assert_state_dict_equals( + expected_state_dict, keyed_optimizer.state_dict() + ) + + # Modify state dict and call load_state_dict. + # pyre-ignore [6] + expected_state_dict["state"]["param_1"]["one"] = 10.0 + # pyre-ignore [6] + expected_state_dict["state"]["param_1"]["tensor"] = torch.tensor([50.0, 60.0]) + # pyre-ignore [6] + expected_state_dict["state"]["param_1"][ + "sharded_tensor" + ] = dist._sharded_tensor.full( + # pyre-ignore [28] + dist._sharded_tensor.ChunkShardingSpec(dim=0, placements=["rank:0/cpu"]), + (4,), + fill_value=10.0, + ) + # pyre-ignore [6] + expected_state_dict["param_groups"][0]["param_group_val_0"] = 8.0 + # pyre-ignore [6] + expected_state_dict["param_groups"][1]["param_group_val_1"] = 9.0 + + keyed_optimizer.load_state_dict(expected_state_dict) + self._assert_state_dict_equals( + expected_state_dict, keyed_optimizer.state_dict() + ) + + def test_non_param_state_key(self) -> None: + with self.assertRaisesRegex(ValueError, "All state keys must be params."): + param_1_t = torch.tensor([1.0, 2.0]) + param_1 = Variable(param_1_t) + KeyedOptimizer( + {"param_1": param_1}, + {param_1: 1.0, "non_param_state_key": 2.0}, + [{"params": [param_1], "param_group_val_0": 3.0}], + ) + + +class TestCombinedOptimizer(unittest.TestCase): + def test_load_state_dict(self) -> None: + # Set up example KeyedOptimizer 1. + param_1_t = torch.tensor([1.0, 2.0]) + param_1 = Variable(param_1_t) + keyed_optimizer_1 = KeyedOptimizer( + {"param_1": param_1}, + {param_1: {"one": 1.0}}, + [{"params": [param_1], "param_group_val_0": 2.0}], + ) + + # Set up example KeyedOptimizer 2. + param_2_t = torch.tensor([-1.0, -2.0]) + param_2 = Variable(param_2_t) + keyed_optimizer_2 = KeyedOptimizer( + {"param_2": param_2}, + {param_2: {"two": -1.0}}, + [{"params": [param_2], "param_group_val_0": -2.0}], + ) + + combined_optimizer = CombinedOptimizer( + [("ko1", keyed_optimizer_1), ("", keyed_optimizer_2)] + ) + + combined_optimizer_state_dict = combined_optimizer.state_dict() + combined_optimizer_state_dict["state"]["ko1.param_1"] = {"one": 999} + combined_optimizer_state_dict["state"]["param_2"] = {"two": 998} + combined_optimizer_state_dict["param_groups"][0]["param_group_val_0"] = 997 + combined_optimizer_state_dict["param_groups"][1]["param_group_val_0"] = 996 + + combined_optimizer.load_state_dict(combined_optimizer_state_dict) + + # Check that optimizers in the combined optimizer have their state and + # param_groups updated. + self.assertEqual(keyed_optimizer_1.state[param_1], {"one": 999}) + self.assertEqual(keyed_optimizer_2.state[param_2], {"two": 998}) + # pyre-ignore[16] + self.assertEqual(keyed_optimizer_1.param_groups[0]["param_group_val_0"], 997) + self.assertEqual(keyed_optimizer_2.param_groups[0]["param_group_val_0"], 996) + + +class TestOptimizerWrapper(unittest.TestCase): + def test_load_state_dict(self) -> None: + param_1_t = torch.tensor([1.0, 2.0]) + param_1 = Variable(param_1_t) + keyed_optimizer = KeyedOptimizer( + {"param_1": param_1}, + {param_1: {"one": 1.0}}, + [{"params": [param_1], "param_group_val_0": 2.0}], + ) + optimizer_wrapper = OptimizerWrapper(keyed_optimizer) + + optimizer_wrapper_state_dict = optimizer_wrapper.state_dict() + optimizer_wrapper_state_dict["state"]["param_1"] = {"one": 999} + optimizer_wrapper_state_dict["param_groups"][0]["param_group_val_0"] = 998 + optimizer_wrapper.load_state_dict(optimizer_wrapper_state_dict) + + # Check that both keyed_optimizer and optimizer_wrapper have their state and + # param_groups updated. + self.assertEqual(keyed_optimizer.state[param_1], {"one": 999}) + self.assertEqual(optimizer_wrapper.state[param_1], {"one": 999}) + # pyre-ignore[16] + self.assertEqual(keyed_optimizer.param_groups[0]["param_group_val_0"], 998) + self.assertEqual(optimizer_wrapper.param_groups[0]["param_group_val_0"], 998) + + +if __name__ == "__main__": + unittest.main() diff --git a/torchrec/optim/tests/test_utils.py b/torchrec/optim/tests/test_utils.py new file mode 100644 index 000000000..0f3dbd81d --- /dev/null +++ b/torchrec/optim/tests/test_utils.py @@ -0,0 +1,14 @@ +#!/usr/bin/env python3 + +from typing import Any + +from torchrec.optim.keyed import KeyedOptimizer + + +class DummyKeyedOptimizer(KeyedOptimizer): + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + # pyre-ignore[2] + def step(self, closure: Any) -> None: + pass # Override NotImplementedError. diff --git a/torchrec/optim/tests/test_warmup.py b/torchrec/optim/tests/test_warmup.py new file mode 100644 index 000000000..7536213f8 --- /dev/null +++ b/torchrec/optim/tests/test_warmup.py @@ -0,0 +1,69 @@ +#!/usr/bin/env python3 + +import unittest +from collections import defaultdict +from typing import Any + +import torch +from torch.autograd import Variable +from torchrec.optim.keyed import KeyedOptimizer +from torchrec.optim.warmup import WarmupOptimizer, WarmupStage, WarmupPolicy + + +class DummyKeyedOptimizer(KeyedOptimizer): + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + # pyre-ignore[2] + def step(self, closure: Any) -> None: + pass # Override NotImplementedError. + + +class TestWarmupOptimizer(unittest.TestCase): + def test_load_state_dict(self) -> None: + def get_optimizer() -> WarmupOptimizer: + param_1_t = torch.tensor([1.0, 2.0]) + param_1 = Variable(param_1_t) + keyed_optimizer = DummyKeyedOptimizer( + {"param_1": param_1}, defaultdict(dict), [{"params": [param_1]}] + ) + warmup_optimizer = WarmupOptimizer( + keyed_optimizer, + stages=[ + WarmupStage( + WarmupPolicy.LINEAR, max_iters=100, value=1e-2, lr_scale=1 + ), + ], + ) + return warmup_optimizer + + warmup_optimizer_1 = get_optimizer() + num_iters = 10 + for _ in range(num_iters): + warmup_optimizer_1.zero_grad() + warmup_optimizer_1.step() + + param_state = list(warmup_optimizer_1.state.values())[0] + self.assertEquals( + param_state["warmup"].tolist()[0], + num_iters, + ) + + warmup_optimizer_2 = get_optimizer() + warmup_optimizer_2.step() + warmup_optimizer_2.zero_grad() + + warmup_optimizer_2.load_state_dict(warmup_optimizer_1.state_dict()) + + self.assertEqual( + warmup_optimizer_1.state_dict()["param_groups"], + warmup_optimizer_2.state_dict()["param_groups"], + ) + torch.testing.assert_close( + warmup_optimizer_1.state_dict()["state"]["__warmup"], + warmup_optimizer_2.state_dict()["state"]["__warmup"], + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/torchrec/optim/warmup.py b/torchrec/optim/warmup.py new file mode 100644 index 000000000..f48cf1ac9 --- /dev/null +++ b/torchrec/optim/warmup.py @@ -0,0 +1,142 @@ +#!/usr/bin/env python3 + +import logging +import math +from dataclasses import dataclass +from enum import Enum, unique +from typing import List, Any, Tuple + +import torch +from torchrec.optim.keyed import OptimizerWrapper, KeyedOptimizer + +logger: logging.Logger = logging.getLogger(__name__) + + +@unique +class WarmupPolicy(Enum): + NONE = "none" + LINEAR = "linear" + CONSTANT = "constant" + POLY = "poly" + STEP = "step" + INVSQRT = "inv_sqrt" # inverse square root + + +@dataclass +class WarmupStage: + policy: WarmupPolicy = WarmupPolicy.LINEAR + max_iters: int = 1 + value: float = 1.0 + lr_scale: float = 1.0 + # used as number denominator for iters in poly decay + # default to max_iters if not set to value > 0 + # also used as stepsize in step decay + # default to 1 if not set to value > 0 + decay_iters: int = -1 + + +def _lr_stages(stages: List[WarmupStage]) -> List[WarmupStage]: + last_stage = WarmupStage(policy=WarmupPolicy.NONE, max_iters=1 << 63, value=1.0) + if len(stages) == 0: + return [last_stage] + + start_iter = 0 + for stage in stages: + assert stage.max_iters > start_iter, ( + f"Max iter of the stage {stage} must be greater than the previous " + f"max iter {start_iter}" + ) + start_iter = stage.max_iters + if stage.decay_iters <= 0: + if stage.policy == WarmupPolicy.STEP: + stage.decay_iters = 1 + else: + stage.decay_iters = stage.max_iters + return stages + [last_stage] + + +def _get_multiplier(stage: WarmupStage, iter: int) -> float: + multiplier = 1.0 + if stage.policy == WarmupPolicy.LINEAR: + multiplier = stage.value + (1.0 - stage.value) * iter / stage.max_iters + elif stage.policy == WarmupPolicy.CONSTANT: + multiplier = stage.value + elif stage.policy == WarmupPolicy.POLY: + multiplier = math.pow(1 - iter / stage.decay_iters, stage.value) + elif stage.policy == WarmupPolicy.STEP: + multiplier = math.pow(stage.value, iter // stage.decay_iters) + elif stage.policy == WarmupPolicy.INVSQRT: + multiplier = 1.0 / math.sqrt(iter) + return multiplier * stage.lr_scale + + +class WarmupOptimizer(OptimizerWrapper): + """ + Adjusts learning rate according to the schedule. + + Constructor Args: + optimizer (KeyedOptimizer): optimizer to wrap + stages (List[WarmupStage]): stages to go through + lr (float): initial learning rate + lr_param (str): learning rate parameter in parameter group. + param_name: Name of fake parameter to hold warmup state. + """ + + def __init__( + self, + optimizer: KeyedOptimizer, + stages: List[WarmupStage], + lr: float = 0.1, + lr_param: str = "lr", + param_name: str = "__warmup", + ) -> None: + super().__init__(optimizer) + self._stages: List[WarmupStage] = _lr_stages(stages) + self._lr_param: str = lr_param + self._lr: float = lr + self._warmup_param: torch.nn.Parameter = torch.nn.Parameter() + # pyre-ignore [16] + self.params[param_name] = self._warmup_param + # for fused optimizer we will do first backward() pass before calling step() + self._set_lr(0, 0) + + def _set_lr(self, iter_: int, stage_id: int) -> None: + lr = self._lr * _get_multiplier(self._stages[stage_id], iter_) + for param_group in self.param_groups: + # pyre-ignore [16] + param_group[self._lr_param] = lr + + def _get_warmup_state(self) -> Tuple[int, int]: + if self._warmup_param in self.state: + iter_, stage_id = self.state[self._warmup_param]["warmup"].tolist() + else: + iter_ = 0 + stage_id = 0 + return iter_, stage_id + + def post_load_state_dict(self) -> None: + iter_, stage_id = self._get_warmup_state() + self._set_lr(iter_, stage_id) + + # pyre-ignore [2] + def step(self, closure: Any = None) -> None: + super().step(closure) + iter_, stage_id = self._get_warmup_state() + + iter_ += 1 + if iter_ > self._stages[stage_id].max_iters and stage_id + 1 < len( + self._stages + ): + stage_id += 1 + logger.info( + "Optimizer finishing " + f"{self._stages[stage_id - 1]} " + "switching to " + f"{self._stages[stage_id]}" + ) + self._set_lr(iter_, stage_id) + + # pyre-ignore [16] + self.state[self._warmup_param] = { + "warmup": torch.tensor([iter_, stage_id], dtype=torch.long) + } diff --git a/torchrec/quant/__init__.py b/torchrec/quant/__init__.py new file mode 100644 index 000000000..31faec675 --- /dev/null +++ b/torchrec/quant/__init__.py @@ -0,0 +1,3 @@ +#!/usr/bin/env python3 + +from torchrec.quant.embedding_modules import EmbeddingBagCollection # noqa diff --git a/torchrec/quant/embedding_modules.py b/torchrec/quant/embedding_modules.py new file mode 100644 index 000000000..9cc69cb56 --- /dev/null +++ b/torchrec/quant/embedding_modules.py @@ -0,0 +1,219 @@ +#!/usr/bin/env python3 + +import copy +from collections import OrderedDict +from typing import Dict, Any, Optional, List, Iterator, Tuple + +import torch +import torch.nn as nn +from fbgemm_gpu.split_embedding_configs import SparseType +from fbgemm_gpu.split_table_batched_embeddings_ops import ( + PoolingMode, + IntNBitTableBatchedEmbeddingBagsCodegen, + EmbeddingLocation, +) +from torch import Tensor +from torchrec.modules.embedding_configs import ( + EmbeddingBagConfig, + PoolingType, + DataType, + DATA_TYPE_NUM_BITS, +) +from torchrec.modules.embedding_modules import ( + EmbeddingBagCollection as OriginalEmbeddingBagCollection, +) +from torchrec.modules.embedding_modules import EmbeddingBagCollectionInterface +from torchrec.sparse.jagged_tensor import ( + KeyedJaggedTensor, + KeyedTensor, +) + +torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") + + +class EmbeddingBagCollection(EmbeddingBagCollectionInterface): + def __init__( + self, + table_name_to_quantized_weights: Dict[str, Tuple[Tensor, Tensor]], + embedding_configs: List[EmbeddingBagConfig], + is_weighted: bool, + device: torch.device, + ) -> None: + def to_pooling_mode(pooling_type: PoolingType) -> PoolingMode: + if pooling_type == PoolingType.SUM: + return PoolingMode.SUM + else: + assert pooling_type == PoolingType.MEAN + return PoolingMode.MEAN + + def to_sparse_type(data_type: DataType) -> SparseType: + if data_type == DataType.FP16: + return SparseType.FP16 + elif data_type == DataType.INT8: + return SparseType.INT8 + elif data_type == DataType.INT4: + return SparseType.INT4 + elif data_type == DataType.INT2: + return SparseType.INT2 + else: + raise ValueError(f"Invalid DataType {data_type}") + + super().__init__() + + self._is_weighted = is_weighted + self._embedding_bag_configs: List[EmbeddingBagConfig] = embedding_configs + self.embedding_bags: nn.ModuleList[nn.Module] = nn.ModuleList() + for emb_config in self._embedding_bag_configs: + emb_module = IntNBitTableBatchedEmbeddingBagsCodegen( + embedding_specs=[ + ( + "", + emb_config.num_embeddings, + emb_config.embedding_dim, + to_sparse_type(emb_config.data_type), + EmbeddingLocation.HOST + if device.type == "cpu" + else EmbeddingLocation.DEVICE, + ) + ], + pooling_mode=to_pooling_mode(emb_config.pooling), + weight_lists=[table_name_to_quantized_weights[emb_config.name]], + device=device, + ) + + self.embedding_bags.append(emb_module) + + def forward( + self, + features: KeyedJaggedTensor, + ) -> KeyedTensor: + keys: List[str] = [] + pooled_embeddings: List[Tensor] = [] + length_per_key: List[int] = [] + for emb_config, emb_module in zip( + self._embedding_bag_configs, self.embedding_bags + ): + for feature_name in emb_config.feature_names: + keys.append(feature_name) + + values = features[feature_name].values() + offsets = features[feature_name].offsets() + weights = features[feature_name].weights_or_none() + pooled_embeddings.append( + emb_module( + indices=values.int(), + offsets=offsets.int(), + per_sample_weights=weights, + ).float() + ) + + length_per_key.append(emb_config.embedding_dim) + + return KeyedTensor( + keys=features.keys(), + values=torch.cat(pooled_embeddings, dim=1), + length_per_key=length_per_key, + ) + + def state_dict( + self, + destination: Optional[Dict[str, Any]] = None, + prefix: str = "", + keep_vars: bool = False, + ) -> Dict[str, Any]: + if destination is None: + destination = OrderedDict() + # pyre-ignore [16] + destination._metadata = OrderedDict() + for emb_config, emb_module in zip( + self._embedding_bag_configs, + self.embedding_bags, + ): + (weight, _) = emb_module.split_embedding_weights(split_scale_shifts=False)[ + 0 + ] + destination[prefix + f"embedding_bags.{emb_config.name}.weight"] = weight + return destination + + def named_buffers( + self, prefix: str = "", recurse: bool = True + ) -> Iterator[Tuple[str, nn.Parameter]]: + state_dict = self.state_dict(prefix=prefix, keep_vars=True) + for key, value in state_dict.items(): + yield key, value + + def _get_name(self) -> str: + return "QuantizedEmbeddingBagCollection" + + @classmethod + def from_float( + cls, module: OriginalEmbeddingBagCollection + ) -> "EmbeddingBagCollection": + assert hasattr( + module, "qconfig" + ), "EmbeddingBagCollection input float module must have qconfig defined" + + def _to_data_type(dtype: torch.dtype) -> DataType: + if dtype == torch.quint8 or dtype == torch.qint8: + return DataType.INT8 + elif dtype == torch.quint4 or dtype == torch.qint4: + return DataType.INT4 + elif dtype == torch.quint2 or dtype == torch.qint2: + return DataType.INT2 + else: + raise Exception(f"Invalid data type {dtype}") + + # pyre-ignore [16] + data_type = _to_data_type(module.qconfig.weight().dtype) + embedding_bag_configs = copy.deepcopy(module.embedding_bag_configs) + for config in embedding_bag_configs: + config.data_type = data_type + + table_name_to_quantized_weights: Dict[str, Tuple[Tensor, Tensor]] = {} + device = torch.device("cpu") + for key, tensor in module.state_dict().items(): + # Extract table name from state dict key. + # e.g. ebc.embedding_bags.t1.weight + splits = key.split(".") + assert splits[-1] == "weight" + table_name = splits[-2] + + num_bits = DATA_TYPE_NUM_BITS[data_type] + device = tensor.device + if tensor.is_meta: + quant_weight = torch.empty( + (tensor.shape[0], (tensor.shape[1] * num_bits) // 8), + device="meta", + dtype=module.qconfig.weight().dtype, + ) + scale_shift = torch.empty( + (tensor.shape[0], 4), + device="meta", + dtype=module.qconfig.weight().dtype, + ) + else: + quant_res = torch.ops.fbgemm.FloatToFusedNBitRowwiseQuantizedSBHalf( + tensor, num_bits + ) + quant_weight, scale_shift = ( + quant_res[:, :-4], + quant_res[:, -4:], + ) + table_name_to_quantized_weights[table_name] = (quant_weight, scale_shift) + + return cls( + table_name_to_quantized_weights, + embedding_bag_configs, + module.is_weighted, + device=device, + ) + + @property + def embedding_bag_configs( + self, + ) -> List[EmbeddingBagConfig]: + return self._embedding_bag_configs + + @property + def is_weighted(self) -> bool: + return self._is_weighted diff --git a/torchrec/quant/tests/test_embedding_modules.py b/torchrec/quant/tests/test_embedding_modules.py new file mode 100644 index 000000000..5a380e4dc --- /dev/null +++ b/torchrec/quant/tests/test_embedding_modules.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python3 + +import unittest + +import torch +from torchrec.modules.embedding_configs import ( + EmbeddingBagConfig, +) +from torchrec.modules.embedding_modules import ( + EmbeddingBagCollection, +) +from torchrec.quant.embedding_modules import ( + EmbeddingBagCollection as QuantEmbeddingBagCollection, +) +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor + + +class EmbeddingBagCollectionTest(unittest.TestCase): + def test_ebc(self) -> None: + eb1_config = EmbeddingBagConfig( + name="t1", embedding_dim=16, num_embeddings=10, feature_names=["f1"] + ) + eb2_config = EmbeddingBagConfig( + name="t2", embedding_dim=16, num_embeddings=10, feature_names=["f2"] + ) + ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config]) + + features = KeyedJaggedTensor( + keys=["f1", "f2"], + values=torch.as_tensor([0, 1]), + lengths=torch.as_tensor([1, 1]), + ) + embeddings = ebc(features) + + # test forward + # pyre-ignore [16] + ebc.qconfig = torch.quantization.QConfig( + activation=torch.quantization.PlaceholderObserver.with_args( + dtype=torch.qint8 + ), + weight=torch.quantization.PlaceholderObserver.with_args(dtype=torch.qint8), + ) + + qebc = QuantEmbeddingBagCollection.from_float(ebc) + quantized_embeddings = qebc(features) + + self.assertEqual(embeddings.keys(), quantized_embeddings.keys()) + self.assertEqual(embeddings["f1"].shape, quantized_embeddings["f1"].shape) + self.assertTrue( + torch.allclose( + embeddings["f1"].cpu(), + quantized_embeddings["f1"].cpu().float(), + atol=1, + ) + ) + self.assertTrue( + torch.allclose( + embeddings["f2"].cpu(), + quantized_embeddings["f2"].cpu().float(), + atol=1, + ) + ) + + # test state dict + state_dict = ebc.state_dict() + quantized_state_dict = qebc.state_dict() + self.assertEqual(state_dict.keys(), quantized_state_dict.keys()) diff --git a/torchrec/sparse/__init__.py b/torchrec/sparse/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py new file mode 100644 index 000000000..b5c70a984 --- /dev/null +++ b/torchrec/sparse/jagged_tensor.py @@ -0,0 +1,1018 @@ +#!/usr/bin/env python3 + +import abc +from typing import Optional, List, Dict, Tuple + +import torch +import torch.fx +from torchrec.types import Pipelineable + +try: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") +except OSError: + pass + + +def _cumsum(o: List[int]) -> List[int]: + ret = [0] * (len(o) + 1) + for i in range(len(o)): + ret[i + 1] = ret[i] + o[i] + return ret + + +def _to_offsets(lengths: torch.Tensor) -> torch.Tensor: + return torch.ops.fbgemm.asynchronous_complete_cumsum(lengths) + + +def _to_lengths(offsets: torch.Tensor) -> torch.Tensor: + return offsets[1:] - offsets[:-1] + + +def _maybe_compute_lengths( + lengths: Optional[torch.Tensor], offsets: Optional[torch.Tensor] +) -> torch.Tensor: + if lengths is None: + assert offsets is not None + lengths = _to_lengths(offsets) + return lengths + + +def _maybe_compute_offsets( + lengths: Optional[torch.Tensor], offsets: Optional[torch.Tensor] +) -> torch.Tensor: + if offsets is None: + assert lengths is not None + offsets = _to_offsets(lengths) + return offsets + + +def _get_weights_or_throw(weights: Optional[torch.Tensor]) -> torch.Tensor: + assert weights is not None, "This (Keyed)JaggedTensor doesn't have weights." + return weights + + +def _assert_offsets_or_lengths_is_provided( + offsets: Optional[torch.Tensor], lengths: Optional[torch.Tensor] +) -> None: + assert offsets is not None or lengths is not None, "Must provide lengths or offsets" + + +def _regroup_keyed_tensors( + keyed_tensors: List["KeyedTensor"], groups: List[List[str]] +) -> List[torch.Tensor]: + # Shortcut for no re-grouping + if len(keyed_tensors) == len(groups): + match = True + for kt, group in zip(keyed_tensors, groups): + if kt.keys() != group: + match = False + break + if match: + return [kt.values() for kt in keyed_tensors] + + embedding_dicts = [keyed_tensor.to_dict() for keyed_tensor in keyed_tensors] + lengths = [keyed_tensor.length_per_key() for keyed_tensor in keyed_tensors] + indices = [keyed_tensor._key_indices() for keyed_tensor in keyed_tensors] + key_dim = keyed_tensors[0].key_dim() + + key_to_idx: dict[str, int] = {} + for (i, keyed_tensor) in enumerate(keyed_tensors): + for key in keyed_tensor.keys(): + key_to_idx[key] = i + + # Rearrange values based on groups with a single torch.cat operation. + cat_input: List[torch.Tensor] = [] + for group in groups: + for name in group: + cat_input.append(embedding_dicts[key_to_idx[name]][name]) + rearranged_values = torch.cat(cat_input, key_dim) + + # Provide views over the rearranged values with a single torch.split operation. + split_lengths: List[int] = [] + for group in groups: + group_length = 0 + for name in group: + group_length += lengths[key_to_idx[name]][indices[key_to_idx[name]][name]] + split_lengths.append(group_length) + + return list(rearranged_values.split(split_lengths, dim=key_dim)) + + +torch.fx.wrap("_regroup_keyed_tensors") + + +def _values_string(values: torch.Tensor, start: int, end: int) -> str: + return "[" + ", ".join([str(value.item()) for value in values[start:end]]) + "]" + + +def _jagged_values_string( + values: torch.Tensor, + offsets: torch.Tensor, + offset_start: int, + offset_end: int, +) -> str: + return ( + "[" + + ", ".join( + [ + _values_string(values, offsets[index], offsets[index + 1]) + for index in range(offset_start, offset_end) + ] + ) + + "]" + ) + + +class JaggedTensorMeta(abc.ABCMeta, torch.fx.ProxyableClassMeta): + pass + + +class JaggedTensor(Pipelineable, metaclass=JaggedTensorMeta): + """Represents an (optionally weighted) jagged tensor + + A `JaggedTensor` is a tensor with a *jagged dimension* which is dimension whose + slices may be of different lengths. See KeyedJaggedTensor for full example. + + Implementation is torch.jit.script-able + """ + + def __init__( + self, + values: torch.Tensor, + weights: Optional[torch.Tensor] = None, + lengths: Optional[torch.Tensor] = None, + offsets: Optional[torch.Tensor] = None, + ) -> None: + self._values: torch.Tensor = values + self._weights: Optional[torch.Tensor] = weights + _assert_offsets_or_lengths_is_provided(offsets, lengths) + if offsets is not None: + _assert_tensor_has_no_elements_or_has_integers(offsets, "offsets") + if lengths is not None: + _assert_tensor_has_no_elements_or_has_integers(lengths, "lengths") + self._lengths: Optional[torch.Tensor] = lengths + self._offsets: Optional[torch.Tensor] = offsets + + @staticmethod + def empty(is_weighted: bool = False) -> "JaggedTensor": + weights = torch.tensor([]) if is_weighted else None + return JaggedTensor( + values=torch.tensor([]), + offsets=torch.tensor([]), + lengths=torch.tensor([]), + weights=weights, + ) + + @staticmethod + def from_dense_lengths( + values: torch.Tensor, + lengths: torch.Tensor, + weights: Optional[torch.Tensor] = None, + ) -> "JaggedTensor": + """ + Constructs `JaggedTensor` from dense values/weights of shape (B, N,). + + Note that `lengths` is still of shape (B,). + """ + mask2d = torch.arange(values.size(1), device=values.device).expand( + values.size(0), -1 + ) < lengths.unsqueeze(-1) + return JaggedTensor( + values=values[mask2d], + weights=weights[mask2d] if weights is not None else None, + lengths=lengths, + ) + + def lengths(self) -> torch.Tensor: + _lengths = _maybe_compute_lengths(self._lengths, self._offsets) + self._lengths = _lengths + return _lengths + + def offsets(self) -> torch.Tensor: + _offsets = _maybe_compute_offsets(self._lengths, self._offsets) + self._offsets = _offsets + return _offsets + + def values(self) -> torch.Tensor: + return self._values + + def weights(self) -> torch.Tensor: + return _get_weights_or_throw(self._weights) + + def weights_or_none(self) -> Optional[torch.Tensor]: + return self._weights + + def to(self, device: torch.device, non_blocking: bool = False) -> "JaggedTensor": + weights = self._weights + lengths = self._lengths + offsets = self._offsets + return JaggedTensor( + values=self._values.to(device, non_blocking=non_blocking), + weights=weights.to(device, non_blocking=non_blocking) + if weights is not None + else None, + lengths=lengths.to(device, non_blocking=non_blocking) + if lengths is not None + else None, + offsets=offsets.to(device, non_blocking=non_blocking) + if offsets is not None + else None, + ) + + # pyre-ignore [56] + @torch.jit.unused + def record_stream(self, stream: torch.cuda.streams.Stream) -> None: + self._values.record_stream(stream) + weights = self._weights + lengths = self._lengths + offsets = self._offsets + if weights is not None: + weights.record_stream(stream) + if lengths is not None: + lengths.record_stream(stream) + if offsets is not None: + offsets.record_stream(stream) + + def __str__(self) -> str: + offsets = self.offsets() + + if self._weights is None: + return ( + "JaggedTensor({\n " + + _jagged_values_string(self._values, offsets, 0, len(offsets) - 1) + + "\n})\n" + ) + + return ( + "JaggedTensor({\n" + + ' "values": ' + + _jagged_values_string(self._values, offsets, 0, len(offsets) - 1) + + ',\n "weights": ' + + _jagged_values_string( + _get_weights_or_throw(self._weights), offsets, 0, len(offsets) - 1 + ) + + "\n})\n" + ) + + +def _assert_tensor_has_no_elements_or_has_integers( + tensor: torch.Tensor, tensor_name: str +) -> None: + assert tensor.numel() == 0 or tensor.dtype in [ + torch.long, + torch.int, + torch.short, + torch.int8, + torch.uint8, + ], "{} must be of integer type, but got {}".format(tensor_name, tensor.dtype) + + +def _maybe_compute_index_per_key( + keys: List[str], + index_per_key: Optional[Dict[str, int]], +) -> Dict[str, int]: + if index_per_key is None: + index_per_key = {key: i for i, key in enumerate(keys)} + return index_per_key + + +def _maybe_compute_stride_kjt( + keys: List[str], + stride: Optional[int], + lengths: Optional[torch.Tensor], + offsets: Optional[torch.Tensor], +) -> int: + if stride is None: + if len(keys) == 0: + stride = 0 + elif offsets is not None: + stride = (offsets.numel() - 1) // len(keys) + elif lengths is not None: + stride = lengths.numel() // len(keys) + else: + stride = 1 + return stride + + +def _maybe_compute_length_per_key( + keys: List[str], + stride: int, + length_per_key: Optional[List[int]], + lengths: Optional[torch.Tensor], + offsets: Optional[torch.Tensor], +) -> List[int]: + if length_per_key is None: + if len(keys) and offsets is not None: + _length: List[int] = ( + torch.sum((offsets[1:] - offsets[:-1]).view(-1, stride), dim=1) + .cpu() + .tolist() + ) + elif len(keys) and lengths is not None: + _length: List[int] = ( + torch.sum(lengths.view(-1, stride), dim=1).cpu().tolist() + ) + else: + _length: List[int] = [] + length_per_key = _length + return length_per_key + + +def _maybe_compute_offset_per_key( + keys: List[str], + stride: int, + length_per_key: Optional[List[int]], + offset_per_key: Optional[List[int]], + lengths: Optional[torch.Tensor], + offsets: Optional[torch.Tensor], +) -> Tuple[List[int], List[int]]: + if length_per_key is None: + _length_per_key: List[int] = _maybe_compute_length_per_key( + keys, stride, length_per_key, lengths, offsets + ) + return _length_per_key, _cumsum(_length_per_key) + elif offset_per_key is None: + return length_per_key, _cumsum(length_per_key) + else: + return length_per_key, offset_per_key + + +def _jagged_tensor_string( + key: str, + values: torch.Tensor, + weights: Optional[torch.Tensor], + offsets: torch.Tensor, + offset_start: int, + offset_end: int, +) -> str: + if weights is None: + return '"{}": '.format(key) + _jagged_values_string( + values, offsets, offset_start, offset_end + ) + + return ( + '"{}"'.format(key) + + ': {\n "values": ' + + _jagged_values_string(values, offsets, offset_start, offset_end) + + ',\n "weights": ' + + _jagged_values_string( + _get_weights_or_throw(weights), offsets, offset_start, offset_end + ) + + "\n }" + ) + + +def _kjt_to_jt_dict( + stride: int, + keys: List[str], + length_per_key: List[int], + values: torch.Tensor, + lengths: torch.Tensor, + offsets: torch.Tensor, + weights: Optional[torch.Tensor], +) -> Dict[str, JaggedTensor]: + jt_dict: Dict[str, JaggedTensor] = {} + values_list = torch.split(values, length_per_key) + lengths_tuple = torch.unbind(lengths.view(-1, stride), dim=0) + if weights is not None: + weights_list = torch.split(weights, length_per_key) + for idx, key in enumerate(keys): + length = lengths_tuple[idx] + offset = _to_offsets(length) + jt_dict[key] = JaggedTensor( + lengths=length, + offsets=offset, + values=values_list[idx], + weights=weights_list[idx], + ) + else: + for idx, key in enumerate(keys): + length = lengths_tuple[idx] + offset = _to_offsets(length) + jt_dict[key] = JaggedTensor( + lengths=length, + offsets=offset, + values=values_list[idx], + ) + return jt_dict + + +def _merge_weights_or_none( + a_weights: Optional[torch.Tensor], + b_weights: Optional[torch.Tensor], +) -> Optional[torch.Tensor]: + assert not ( + (a_weights is None) ^ (b_weights is None) + ), "Can only merge weighted or unweighted KJTs." + if a_weights is None: + return None + # pyre-ignore[6] + return torch.cat([a_weights, b_weights], dim=0) + + +torch.fx.wrap("_merge_weights_or_none") + + +class KeyedJaggedTensor(Pipelineable, metaclass=JaggedTensorMeta): + """Represents an (optionally weighted) keyed jagged tensor + + A `JaggedTensor` is a tensor with a *jagged dimension* which is dimension whose + slices may be of different lengths. Keyed on first dimesion, and jagged on last dimension + + For example: + 0 1 2 <-- dim_1 + "Feature0" [V0,V1] None [V2] + "Feature1" [V3] [V4] [V5,V6,V7] + ^ + dim_0 + + dim_0: keyed dimension (ie. `Feature0`, `Feature1`) + dim_1: optional second dimension (ie. batch size) + dim_2: The jagged dimension which has slice lengths between 0-3 in the above example + + We represent this data with following inputs: + + values: torch.Tensor = [V0, V1, V2, V3, V4, V5, V6, V7], V == any tensor datatype + weights: torch.Tensor = [W0, W1, W2, W3, W4, W5, W6, W7], W == any tensor datatype + lengths: torch.Tensor = [2, 0, 1, 1, 1, 3], representing the jagged slice + offsets: torch.Tensor = [0, 2, 2, 3, 4, 5, 8], offsets from 0 for each jagged slice + keys: List[int] = ["Feature0", "Feature1"], which corresponds to each value of dim_0 + index_per_key: Dict[str, int] = {"Feature0": 0, "Feature1: 1}, index for key in keys + offset_per_key: List[int] = [0, 3, 8], start offset for each key and final offset + + + Implementation is torch.jit.script-able + """ + + def __init__( + self, + keys: List[str], + values: torch.Tensor, + weights: Optional[torch.Tensor] = None, + lengths: Optional[torch.Tensor] = None, + offsets: Optional[torch.Tensor] = None, + stride: Optional[int] = None, + # Below exposed to ensure torch.script-able + length_per_key: Optional[List[int]] = None, + offset_per_key: Optional[List[int]] = None, + index_per_key: Optional[Dict[str, int]] = None, + ) -> None: + self._keys: List[str] = keys + self._values: torch.Tensor = values + self._weights: Optional[torch.Tensor] = weights + if offsets is not None: + _assert_tensor_has_no_elements_or_has_integers(offsets, "offsets") + if lengths is not None: + _assert_tensor_has_no_elements_or_has_integers(lengths, "lengths") + self._lengths: Optional[torch.Tensor] = lengths + self._offsets: Optional[torch.Tensor] = offsets + stride = _maybe_compute_stride_kjt(keys, stride, lengths, offsets) + self._stride: int = stride + + # lazy fields + self._length_per_key: Optional[List[int]] = length_per_key + self._offset_per_key: Optional[List[int]] = offset_per_key + self._index_per_key: Optional[Dict[str, int]] = index_per_key + + @staticmethod + def from_offsets_sync( + keys: List[str], + values: torch.Tensor, + offsets: torch.Tensor, + weights: Optional[torch.Tensor] = None, + stride: Optional[int] = None, + ) -> "KeyedJaggedTensor": + kjt = KeyedJaggedTensor( + keys=keys, + values=values, + weights=weights, + offsets=offsets, + stride=stride, + ) + return kjt.sync() + + @staticmethod + def from_lengths_sync( + keys: List[str], + values: torch.Tensor, + lengths: torch.Tensor, + weights: Optional[torch.Tensor] = None, + stride: Optional[int] = None, + ) -> "KeyedJaggedTensor": + kjt = KeyedJaggedTensor( + keys=keys, + values=values, + weights=weights, + lengths=lengths, + stride=stride, + ) + return kjt.sync() + + @staticmethod + def concat( + a: "KeyedJaggedTensor", + b: "KeyedJaggedTensor", + ) -> "KeyedJaggedTensor": + if a.stride() != b.stride(): + raise ValueError( + f"Can only merge KJTs of the same stride ({a.stride()}, {b.stride()})." + ) + length_per_key = ( + a._length_per_key + b._length_per_key + if a._length_per_key is not None and b._length_per_key is not None + else None + ) + + return KeyedJaggedTensor( + keys=a.keys() + b.keys(), + values=torch.cat([a.values(), b.values()], dim=0), + weights=_merge_weights_or_none(a.weights_or_none(), b.weights_or_none()), + lengths=torch.cat([a.lengths(), b.lengths()], dim=0), + stride=a.stride(), + length_per_key=length_per_key, + ) + + @staticmethod + def empty( + is_weighted: bool = False, device: Optional[torch.device] = None + ) -> "KeyedJaggedTensor": + weights = None + if is_weighted is True: + weights = torch.tensor([], device=device) if device else torch.tensor([]) + + return KeyedJaggedTensor( + keys=[], + values=torch.tensor([], device=device) if device else torch.tensor([]), + weights=weights, + stride=0, + ) + + @staticmethod + def empty_like(kjt: "KeyedJaggedTensor") -> "KeyedJaggedTensor": + return KeyedJaggedTensor( + keys=[], + values=torch.tensor([], device=kjt.device(), dtype=kjt.values().dtype), + weights=None + if kjt.weights_or_none() is None + else torch.tensor([], device=kjt.device(), dtype=kjt.weights().dtype), + lengths=None, + offsets=None, + stride=kjt.stride(), + ) + + def sync(self) -> "KeyedJaggedTensor": + self.length_per_key() + self.offset_per_key() + return self + + def device(self) -> torch.device: + return self._values.device + + def lengths(self) -> torch.Tensor: + _lengths = _maybe_compute_lengths(self._lengths, self._offsets) + self._lengths = _lengths + return _lengths + + def offsets(self) -> torch.Tensor: + _offsets = _maybe_compute_offsets(self._lengths, self._offsets) + self._offsets = _offsets + return _offsets + + def keys(self) -> List[str]: + return self._keys + + def values(self) -> torch.Tensor: + return self._values + + def weights(self) -> torch.Tensor: + return _get_weights_or_throw(self._weights) + + def weights_or_none(self) -> Optional[torch.Tensor]: + return self._weights + + def stride(self) -> int: + return self._stride + + def _key_indices(self) -> Dict[str, int]: + _index_per_key: Dict[str, int] = _maybe_compute_index_per_key( + self._keys, + self._index_per_key, + ) + self._index_per_key = _index_per_key + return _index_per_key + + def length_per_key(self) -> List[int]: + _length_per_key = _maybe_compute_length_per_key( + self._keys, + self.stride(), + self._length_per_key, + self._lengths, + self._offsets, + ) + self._length_per_key = _length_per_key + return _length_per_key + + def offset_per_key(self) -> List[int]: + _length_per_key, _offset_per_key = _maybe_compute_offset_per_key( + self._keys, + self.stride(), + self._length_per_key, + self._offset_per_key, + self._lengths, + self._offsets, + ) + self._length_per_key = _length_per_key + self._offset_per_key = _offset_per_key + return _offset_per_key + + def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]: + split_list: List[KeyedJaggedTensor] = [] + start = 0 + start_offset = 0 + _length_per_key = self.length_per_key() + _offset_per_key = self.offset_per_key() + for segment in segments: + end = start + segment + end_offset = _offset_per_key[end] + keys: List[str] = self._keys[start:end] + if segment == len(self._keys): + # no torch slicing required + split_list.append( + KeyedJaggedTensor( + keys=self._keys, + values=self._values, + weights=self.weights_or_none(), + lengths=self._lengths, + offsets=self._offsets, + stride=self._stride, + length_per_key=self._length_per_key, + offset_per_key=self._offset_per_key, + index_per_key=self._index_per_key, + ) + ) + elif segment == 0: + split_list.append( + KeyedJaggedTensor( + keys=keys, + values=torch.tensor( + [], device=self.device(), dtype=self._values.dtype + ), + weights=None + if self.weights_or_none() is None + else torch.tensor( + [], + device=self.device(), + dtype=self.weights().dtype, + ), + lengths=torch.tensor([], device=self.device(), dtype=torch.int), + offsets=torch.tensor([], device=self.device(), dtype=torch.int), + stride=self._stride, + length_per_key=None, + offset_per_key=None, + index_per_key=None, + ) + ) + else: + split_length_per_key = _length_per_key[start:end] + split_list.append( + KeyedJaggedTensor( + keys=keys, + values=self._values[start_offset:end_offset], + weights=None + if self.weights_or_none() is None + else self.weights()[start_offset:end_offset], + lengths=self.lengths()[ + start * self._stride : end * self._stride + ], + offsets=None, + stride=self._stride, + length_per_key=split_length_per_key, + offset_per_key=None, + index_per_key=None, + ) + ) + start = end + start_offset = end_offset + return split_list + + def permute( + self, indices: List[int], indices_tensor: Optional[torch.Tensor] = None + ) -> "KeyedJaggedTensor": + + if indices_tensor is None: + indices_tensor = torch.tensor( + indices, dtype=torch.int, device=self.device() + ) + + length_per_key = self.length_per_key() + permuted_keys: List[str] = [] + permuted_length_per_key: List[int] = [] + permuted_lengths_sum = 0 + seen: Dict[str, int] = {} + for index in indices: + key = self._keys[index] + count = seen.get(key, 0) + permuted_keys.append(key) + permuted_lengths_sum += length_per_key[index] + permuted_length_per_key.append(length_per_key[index]) + seen[key] = count + 1 + + ( + permuted_lengths, + permuted_values, + permuted_weights, + ) = torch.ops.fbgemm.permute_sparse_data( + indices_tensor, + self.lengths().view(len(self._keys), -1), + self.values(), + self.weights_or_none(), + permuted_lengths_sum, + ) + + kjt = KeyedJaggedTensor( + keys=permuted_keys, + values=permuted_values, + weights=permuted_weights, + lengths=permuted_lengths.view(-1), + offsets=None, + stride=self._stride, + length_per_key=permuted_length_per_key if len(permuted_keys) > 0 else None, + offset_per_key=None, + index_per_key=None, + ) + return kjt + + def __getitem__(self, key: str) -> JaggedTensor: + offset_per_key = self.offset_per_key() + index = self._key_indices()[key] + start_offset = offset_per_key[index] + end_offset = offset_per_key[index + 1] + return JaggedTensor( + values=self._values[start_offset:end_offset], + weights=None + if self.weights_or_none() is None + else self.weights()[start_offset:end_offset], + lengths=self.lengths()[index * self._stride : (index + 1) * self._stride], + offsets=None, + ) + + def to_dict(self) -> Dict[str, JaggedTensor]: + return _kjt_to_jt_dict( + self.stride(), + self.keys(), + self.length_per_key(), + self.values(), + self.lengths(), + self.offsets(), + self.weights_or_none(), + ) + + # pyre-ignore [56] + @torch.jit.unused + def record_stream(self, stream: torch.cuda.streams.Stream) -> None: + self._values.record_stream(stream) + weights = self._weights + lengths = self._lengths + offsets = self._offsets + if weights is not None: + weights.record_stream(stream) + if lengths is not None: + lengths.record_stream(stream) + if offsets is not None: + offsets.record_stream(stream) + + def to( + self, device: torch.device, non_blocking: bool = False + ) -> "KeyedJaggedTensor": + weights = self._weights + lengths = self._lengths + offsets = self._offsets + length_per_key = self._length_per_key + offset_per_key = self._offset_per_key + index_per_key = self._index_per_key + + return KeyedJaggedTensor( + keys=self._keys, + values=self._values.to(device, non_blocking=non_blocking), + weights=weights.to(device, non_blocking=non_blocking) + if weights is not None + else None, + lengths=lengths.to(device, non_blocking=non_blocking) + if lengths is not None + else None, + offsets=offsets.to(device, non_blocking=non_blocking) + if offsets is not None + else None, + stride=self._stride, + length_per_key=length_per_key, + offset_per_key=offset_per_key, + index_per_key=index_per_key, + ) + + def __str__(self) -> str: + if self._offsets is None and self._lengths is None: + return "KeyedJaggedTensor()\n" + offsets = self.offsets() + + step = (len(offsets) - 1) // len(self._keys) + return ( + "KeyedJaggedTensor({\n" + + ",\n".join( + [ + " " + + _jagged_tensor_string( + self._keys[index], + self._values, + self._weights, + offsets, + index * step, + (index + 1) * step, + ) + for index in range(len(self._keys)) + ] + ) + + "\n})\n" + ) + + def pin_memory(self) -> "KeyedJaggedTensor": + weights = self._weights + lengths = self._lengths + offsets = self._offsets + + return KeyedJaggedTensor( + keys=self._keys, + values=self._values.pin_memory(), + weights=weights.pin_memory() if weights is not None else None, + lengths=lengths.pin_memory() if lengths is not None else None, + offsets=offsets.pin_memory() if offsets is not None else None, + stride=self._stride, + length_per_key=self._length_per_key, + offset_per_key=self._offset_per_key, + index_per_key=self._index_per_key, + ) + + +def _maybe_compute_offset_per_key_kt( + length_per_key: List[int], + offset_per_key: Optional[List[int]], +) -> List[int]: + if offset_per_key is None: + offset_per_key = _cumsum(length_per_key) + return offset_per_key + + +def _keyed_values_string(values: torch.Tensor) -> str: + return ( + "[" + + ", ".join([_values_string(value, 0, len(value)) for value in values]) + + "]" + ) + + +class KeyedTensor(Pipelineable, metaclass=JaggedTensorMeta): + """ + KeyedTensor holds a concatenated list of dense tensors + each of which can be accessed by a key. + Keyed dimension can be variable length (length_per_key). + Common use cases uses include storage of pooled embeddings of different dimensions. + + Constructor Args: + keys (List[str]): list of keys + length_per_key (List[int]): length of each key along key dimension + values (torch.Tensor): dense tensor, concatenated typically along key dimension + key_dim (int): key dimension, zero indexed - defaults to 1 (typically B is 0-dimension) + + Implementation is torch.jit.script-able + + + Example: + kt is KeyedTensor holding + + 0 1 2 + "Embedding A" [1,1] [1,1] [1,1] + "Embedding B" [2,1,2] [2,1,2] [2,1,2] + "Embedding C" [3,1,2,3] [3,1,2,3] [3,1,2,3] + >>> tensor_list = [ + torch.tensor([[1,1]] * 3), + torch.tensor([[2,1,2]] * 3), + torch.tensor([[3,1,2,3]] * 3), + ] + >>> keys = ["Embedding A", "Embedding B", "Embedding C"] + >>> kt = KeyedTensor.from_tensor_list(keys, tensor_list) + >>> kt.values() + tensor([[1, 1, 2, 1, 2, 3, 1, 2, 3], + [1, 1, 2, 1, 2, 3, 1, 2, 3], + [1, 1, 2, 1, 2, 3, 1, 2, 3]]) + >>> kt["Embedding B"] + tensor([[2, 1, 2], + [2, 1, 2], + [2, 1, 2]]) + """ + + def __init__( + self, + keys: List[str], + length_per_key: List[int], + values: torch.Tensor, + key_dim: int = 1, + # Below exposed to ensure torch.script-able + offset_per_key: Optional[List[int]] = None, + index_per_key: Optional[Dict[str, int]] = None, + ) -> None: + self._keys = keys + self._length_per_key = length_per_key + self._values = values + self._key_dim = key_dim + + self._offset_per_key: Optional[List[int]] = offset_per_key + self._index_per_key: Optional[Dict[str, int]] = index_per_key + + @staticmethod + def from_tensor_list( + keys: List[str], tensors: List[torch.Tensor], key_dim: int = 1, cat_dim: int = 1 + ) -> "KeyedTensor": + length_per_key = [tensor.shape[key_dim] for tensor in tensors] + return KeyedTensor( + keys=keys, + length_per_key=length_per_key, + values=torch.cat(tensors, dim=cat_dim), + key_dim=key_dim, + ) + + def keys(self) -> List[str]: + return self._keys + + def values(self) -> torch.Tensor: + return self._values + + def key_dim(self) -> int: + return self._key_dim + + def offset_per_key(self) -> List[int]: + _offset_per_key = _maybe_compute_offset_per_key_kt( + self._length_per_key, + self._offset_per_key, + ) + self._offset_per_key = _offset_per_key + return _offset_per_key + + def length_per_key(self) -> List[int]: + return self._length_per_key + + def _key_indices(self) -> Dict[str, int]: + _index_per_key = _maybe_compute_index_per_key( + self._keys, + self._index_per_key, + ) + self._index_per_key = _index_per_key + return _index_per_key + + def __getitem__(self, key: str) -> torch.Tensor: + index = self._key_indices()[key] + start = self.offset_per_key()[index] + length = self._length_per_key[index] + # pyre-ignore [16]: Undefined attribute `torch.Tensor` has no attribute `narrow` + return self._values.narrow(dim=self._key_dim, start=start, length=length) + + def to_dict(self) -> Dict[str, torch.Tensor]: + indices = self._key_indices() + lengths = self._length_per_key + split_values = self._values.split(lengths, dim=self._key_dim) + return {key: split_values[index] for (key, index) in indices.items()} + + @staticmethod + def regroup( + keyed_tensors: List["KeyedTensor"], groups: List[List[str]] + ) -> List[torch.Tensor]: + return _regroup_keyed_tensors(keyed_tensors, groups) + + # pyre-ignore [56] + @torch.jit.unused + def record_stream(self, stream: torch.cuda.streams.Stream) -> None: + self._values.record_stream(stream) + + def to(self, device: torch.device, non_blocking: bool = False) -> "KeyedTensor": + return KeyedTensor( + keys=self._keys, + length_per_key=self._length_per_key, + values=self._values.to(device, non_blocking=non_blocking), + key_dim=self._key_dim, + offset_per_key=self._offset_per_key, + index_per_key=self._index_per_key, + ) + + def __str__(self) -> str: + if len(self._keys) == 0: + return "KeyedTensor()\n" + + return ( + "KeyedTensor({\n" + + ",\n".join( + [ + ' "{}": '.format(key) + _keyed_values_string(self[key]) + for key in self._keys + ] + ) + + "\n})\n" + ) diff --git a/torchrec/sparse/tests/__init__.py b/torchrec/sparse/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/torchrec/sparse/tests/test_jagged_tensor.py b/torchrec/sparse/tests/test_jagged_tensor.py new file mode 100644 index 000000000..5a8ff94b3 --- /dev/null +++ b/torchrec/sparse/tests/test_jagged_tensor.py @@ -0,0 +1,846 @@ +#!/usr/bin/env python3 + + +import unittest +from typing import List + +import torch +from torchrec.sparse.jagged_tensor import ( + JaggedTensor, + KeyedTensor, + KeyedJaggedTensor, +) + + +class TestJaggedTensor(unittest.TestCase): + def test_from_dense_lengths(self) -> None: + values = torch.Tensor( + [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0], [10.0, 11.0, 12.0]] + ) + weights = 12.0 - values + j0 = JaggedTensor.from_dense_lengths( + values=values, + lengths=torch.IntTensor([1, 0, 2, 3]), + ) + self.assertTrue(torch.equal(j0.lengths(), torch.IntTensor([1, 0, 2, 3]))) + self.assertTrue( + torch.equal(j0.values(), torch.Tensor([1.0, 7.0, 8.0, 10.0, 11.0, 12.0])) + ) + self.assertTrue(j0.weights_or_none() is None) + j1 = JaggedTensor.from_dense_lengths( + values=values, + lengths=torch.IntTensor([2, 0, 1, 1]), + weights=weights, + ) + self.assertTrue(torch.equal(j1.lengths(), torch.IntTensor([2, 0, 1, 1]))) + self.assertTrue(torch.equal(j1.values(), torch.Tensor([1.0, 2.0, 7.0, 10.0]))) + self.assertTrue(torch.equal(j1.weights(), torch.Tensor([11.0, 10.0, 5.0, 2.0]))) + + def test_key_lookup(self) -> None: + values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]) + keys = ["index_0", "index_1"] + offsets = torch.IntTensor([0, 2, 2, 3, 4, 5, 8]) + + jag_tensor = KeyedJaggedTensor( + values=values, + keys=keys, + offsets=offsets, + ) + j0 = jag_tensor["index_0"] + j1 = jag_tensor["index_1"] + + self.assertTrue(isinstance(j0, JaggedTensor)) + self.assertTrue(torch.equal(j0.lengths(), torch.IntTensor([2, 0, 1]))) + self.assertTrue(torch.equal(j0.values(), torch.Tensor([1.0, 2.0, 3.0]))) + self.assertTrue(torch.equal(j1.lengths(), torch.IntTensor([1, 1, 3]))) + self.assertTrue( + torch.equal(j1.values(), torch.Tensor([4.0, 5.0, 6.0, 7.0, 8.0])) + ) + + def test_split(self) -> None: + values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]) + keys = ["index_0", "index_1"] + offsets = torch.IntTensor([0, 2, 2, 3, 4, 5, 8]) + + jag_tensor = KeyedJaggedTensor( + values=values, + keys=keys, + offsets=offsets, + ) + j0, j1 = jag_tensor.split([1, 1]) + + self.assertTrue(isinstance(j0, KeyedJaggedTensor)) + self.assertEqual(j0.keys(), ["index_0"]) + self.assertEqual(j1.keys(), ["index_1"]) + self.assertTrue(torch.equal(j0.lengths(), torch.IntTensor([2, 0, 1]))) + self.assertTrue(torch.equal(j0.values(), torch.Tensor([1.0, 2.0, 3.0]))) + self.assertTrue(torch.equal(j1.lengths(), torch.IntTensor([1, 1, 3]))) + self.assertTrue( + torch.equal(j1.values(), torch.Tensor([4.0, 5.0, 6.0, 7.0, 8.0])) + ) + + def test_length_vs_offset(self) -> None: + values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]) + keys = ["index_0", "index_1"] + offsets = torch.IntTensor([0, 0, 2, 2, 3, 4, 5, 5, 8]) + lengths = torch.IntTensor([0, 2, 0, 1, 1, 1, 0, 3]) + + j_offset = KeyedJaggedTensor.from_offsets_sync( + values=values, + keys=keys, + offsets=offsets, + ) + + j_lens = KeyedJaggedTensor.from_lengths_sync( + values=values, + keys=keys, + lengths=lengths, + ) + + self.assertTrue(torch.equal(j_offset.lengths(), j_lens.lengths())) + # TODO: T88149179 + self.assertTrue(torch.equal(j_offset.offsets(), j_lens.offsets().int())) + + def test_concat(self) -> None: + values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]) + keys = ["index_0", "index_1", "index_2"] + lengths = torch.IntTensor([0, 2, 0, 1, 1, 1, 0, 3, 0, 0, 1, 0]) + + kjt_expected = KeyedJaggedTensor.from_lengths_sync( + values=values, + keys=keys, + lengths=lengths, + ) + kjt_actual = KeyedJaggedTensor.concat( + a=KeyedJaggedTensor.from_lengths_sync( + values=values[:4], + keys=keys[:1], + lengths=lengths[:4], + ), + b=KeyedJaggedTensor.from_lengths_sync( + values=values[4:], + keys=keys[1:], + lengths=lengths[4:], + ), + ) + self.assertTrue(torch.equal(kjt_expected.lengths(), kjt_actual.lengths())) + self.assertTrue(torch.equal(kjt_expected.offsets(), kjt_actual.offsets())) + self.assertTrue(torch.equal(kjt_expected.values(), kjt_actual.values())) + + def test_empty(self) -> None: + values = torch.Tensor() + keys = [] + offsets = torch.Tensor() + + KeyedJaggedTensor.from_offsets_sync(values=values, keys=keys, offsets=offsets) + + def test_2d(self) -> None: + values = torch.Tensor([[i * 0.5, i * 1.0, i * 1.5] for i in range(1, 9)]) + keys = ["index_0", "index_1"] + offsets = torch.IntTensor([0, 2, 2, 3, 4, 5, 8]) + + j = KeyedJaggedTensor.from_offsets_sync( + values=values, + keys=keys, + offsets=offsets, + ) + j_0 = j["index_0"] + + self.assertTrue(torch.equal(j_0.lengths(), torch.IntTensor([2, 0, 1]))) + self.assertTrue( + torch.equal( + j_0.values(), + torch.Tensor( + [ + [0.5, 1.0, 1.5], + [1.0, 2.0, 3.0], + [1.5, 3.0, 4.5], + ], + ), + ) + ) + + def test_float_lengths_offsets_throws(self) -> None: + values = torch.rand((7, 3)) + lengths = torch.tensor([3.0, 4.0]) + offsets = torch.tensor([0.0, 3.0, 7.0]) + + with self.assertRaises(AssertionError): + JaggedTensor(values=values, lengths=lengths) + with self.assertRaises(AssertionError): + JaggedTensor(values=values, offsets=offsets) + + def test_to(self) -> None: + j = JaggedTensor( + offsets=torch.tensor([0, 2, 2, 3]), + values=torch.tensor([0.5, 1.0, 1.5]), + weights=torch.tensor([5.0, 10.0, 15.0]), + ) + j2 = j.to(device=torch.device("cpu")) + self.assertTrue(torch.equal(j.offsets(), j2.offsets())) + self.assertTrue(torch.equal(j.lengths(), j2.lengths())) + self.assertTrue(torch.equal(j.values(), j2.values())) + self.assertTrue(torch.equal(j.weights(), j2.weights())) + + # pyre-ignore[56] + @unittest.skipIf( + torch.cuda.device_count() <= 0, + "CUDA is not available", + ) + def test_record_stream(self) -> None: + j = JaggedTensor( + offsets=torch.tensor([0, 2, 2, 3]), + values=torch.tensor([0.5, 1.0, 1.5]), + weights=torch.tensor([5.0, 10.0, 15.0]), + ).to(torch.device("cuda")) + j.record_stream(torch.cuda.current_stream()) + + def test_string_basic(self) -> None: + values = torch.Tensor([1.0]) + offsets = torch.IntTensor([0, 1]) + + jag_tensor = JaggedTensor( + values=values, + offsets=offsets, + ) + + self.assertEqual( + str(jag_tensor), + """\ +JaggedTensor({ + [[1.0]] +}) +""", + ) + + def test_string_values(self) -> None: + values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]) + offsets = torch.IntTensor([0, 2, 2, 3, 4, 5, 8]) + + jag_tensor = JaggedTensor( + values=values, + offsets=offsets, + ) + + self.assertEqual( + str(jag_tensor), + """\ +JaggedTensor({ + [[1.0, 2.0], [], [3.0], [4.0], [5.0], [6.0, 7.0, 8.0]] +}) +""", + ) + + def test_string_weights(self) -> None: + values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]) + weights = torch.Tensor([1.0, 0.5, 1.5, 1.0, 0.5, 1.0, 1.0, 1.5]) + offsets = torch.IntTensor([0, 2, 2, 3, 4, 5, 8]) + + jag_tensor = JaggedTensor( + values=values, + offsets=offsets, + weights=weights, + ) + + self.assertEqual( + str(jag_tensor), + """\ +JaggedTensor({ + "values": [[1.0, 2.0], [], [3.0], [4.0], [5.0], [6.0, 7.0, 8.0]], + "weights": [[1.0, 0.5], [], [1.5], [1.0], [0.5], [1.0, 1.0, 1.5]] +}) +""", + ) + + +class TestKeyedJaggedTensor(unittest.TestCase): + def test_key_lookup(self) -> None: + values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]) + weights = torch.Tensor([1.0, 0.5, 1.5, 1.0, 0.5, 1.0, 1.0, 1.5]) + keys = ["index_0", "index_1"] + offsets = torch.IntTensor([0, 2, 2, 3, 4, 5, 8]) + + jag_tensor = KeyedJaggedTensor( + values=values, + keys=keys, + offsets=offsets, + weights=weights, + ) + j0 = jag_tensor["index_0"] + j1 = jag_tensor["index_1"] + + self.assertTrue(isinstance(j0, JaggedTensor)) + self.assertTrue(torch.equal(j0.lengths(), torch.IntTensor([2, 0, 1]))) + self.assertTrue(torch.equal(j0.weights(), torch.Tensor([1.0, 0.5, 1.5]))) + self.assertTrue(torch.equal(j0.values(), torch.Tensor([1.0, 2.0, 3.0]))) + self.assertTrue(torch.equal(j1.lengths(), torch.IntTensor([1, 1, 3]))) + self.assertTrue( + torch.equal(j1.weights(), torch.Tensor([1.0, 0.5, 1.0, 1.0, 1.5])) + ) + self.assertTrue( + torch.equal(j1.values(), torch.Tensor([4.0, 5.0, 6.0, 7.0, 8.0])) + ) + + def test_split(self) -> None: + values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]) + weights = torch.Tensor([1.0, 0.5, 1.5, 1.0, 0.5, 1.0, 1.0, 1.5]) + keys = ["index_0", "index_1"] + offsets = torch.IntTensor([0, 2, 2, 3, 4, 5, 8]) + + jag_tensor = KeyedJaggedTensor( + values=values, + keys=keys, + offsets=offsets, + weights=weights, + ) + j0, j1 = jag_tensor.split([1, 1]) + + self.assertTrue(isinstance(j0, KeyedJaggedTensor)) + self.assertEqual(j0.keys(), ["index_0"]) + self.assertEqual(j1.keys(), ["index_1"]) + self.assertTrue(torch.equal(j0.lengths(), torch.IntTensor([2, 0, 1]))) + self.assertTrue(torch.equal(j0.weights(), torch.Tensor([1.0, 0.5, 1.5]))) + self.assertTrue(torch.equal(j0.values(), torch.Tensor([1.0, 2.0, 3.0]))) + self.assertTrue(torch.equal(j1.lengths(), torch.IntTensor([1, 1, 3]))) + self.assertTrue( + torch.equal(j1.weights(), torch.Tensor([1.0, 0.5, 1.0, 1.0, 1.5])) + ) + self.assertTrue( + torch.equal(j1.values(), torch.Tensor([4.0, 5.0, 6.0, 7.0, 8.0])) + ) + + def test_zero_split(self) -> None: + values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]) + weights = torch.Tensor([1.0, 0.5, 1.5, 1.0, 0.5, 1.0, 1.0, 1.5]) + keys = ["index_0", "index_1"] + offsets = torch.IntTensor([0, 2, 2, 3, 4, 5, 8]) + + jag_tensor = KeyedJaggedTensor( + values=values, + keys=keys, + offsets=offsets, + weights=weights, + ) + j0, j1 = jag_tensor.split([0, 2]) + + self.assertTrue(isinstance(j0, KeyedJaggedTensor)) + self.assertEqual(j0.keys(), []) + self.assertTrue(torch.equal(j0.lengths(), torch.IntTensor([]))) + self.assertTrue(torch.equal(j0.weights(), torch.Tensor([]))) + self.assertTrue(torch.equal(j0.values(), torch.Tensor([]))) + self.assertEqual(j0.stride(), 3) + + self.assertEqual(j1.keys(), ["index_0", "index_1"]) + self.assertTrue(torch.equal(j1.lengths(), torch.IntTensor([2, 0, 1, 1, 1, 3]))) + self.assertTrue(torch.equal(j1.weights(), weights)) + self.assertTrue(torch.equal(j1.values(), values)) + self.assertEqual(j0.stride(), 3) + + def test_permute_w_weights(self) -> None: + values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]) + weights = torch.Tensor([1.0, 0.5, 1.5, 1.0, 0.5, 1.0, 1.0, 1.5]) + lengths = torch.IntTensor([0, 2, 0, 1, 1, 1, 0, 3, 0]) + keys = ["index_0", "index_1", "index_2"] + + jag_tensor = KeyedJaggedTensor.from_lengths_sync( + values=values, + keys=keys, + lengths=lengths, + weights=weights, + ) + + indices = [1, 0, 2] + permuted_jag_tensor = jag_tensor.permute(indices) + self.assertEqual(permuted_jag_tensor.keys(), ["index_1", "index_0", "index_2"]) + self.assertEqual( + permuted_jag_tensor.offset_per_key(), + [0, 3, 5, 8], + ) + self.assertTrue( + torch.equal( + permuted_jag_tensor.values(), + torch.Tensor([3.0, 4.0, 5.0, 1.0, 2.0, 6.0, 7.0, 8.0]), + ) + ) + self.assertTrue( + torch.equal( + permuted_jag_tensor.lengths(), + torch.IntTensor([1, 1, 1, 0, 2, 0, 0, 3, 0]), + ) + ) + self.assertTrue( + torch.equal( + permuted_jag_tensor.weights(), + torch.Tensor([1.5, 1.0, 0.5, 1.0, 0.5, 1.0, 1.0, 1.5]), + ), + ) + + def test_permute(self) -> None: + values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]) + lengths = torch.IntTensor([0, 2, 0, 1, 1, 1, 0, 3, 0]) + keys = ["index_0", "index_1", "index_2"] + + jag_tensor = KeyedJaggedTensor.from_lengths_sync( + values=values, + keys=keys, + lengths=lengths, + ) + + indices = [1, 0, 2] + permuted_jag_tensor = jag_tensor.permute(indices) + + self.assertEqual(permuted_jag_tensor.keys(), ["index_1", "index_0", "index_2"]) + self.assertEqual( + permuted_jag_tensor.offset_per_key(), + [0, 3, 5, 8], + ) + self.assertTrue( + torch.equal( + permuted_jag_tensor.values(), + torch.Tensor([3.0, 4.0, 5.0, 1.0, 2.0, 6.0, 7.0, 8.0]), + ) + ) + self.assertTrue( + torch.equal( + permuted_jag_tensor.lengths(), + torch.IntTensor([1, 1, 1, 0, 2, 0, 0, 3, 0]), + ) + ) + self.assertEqual(permuted_jag_tensor.weights_or_none(), None) + + def test_permute_duplicates(self) -> None: + values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]) + lengths = torch.IntTensor([0, 2, 0, 1, 1, 1, 0, 3, 0]) + keys = ["index_0", "index_1", "index_2"] + + jag_tensor = KeyedJaggedTensor.from_lengths_sync( + values=values, + keys=keys, + lengths=lengths, + ) + + indices = [1, 0, 2, 1, 1] + permuted_jag_tensor = jag_tensor.permute(indices) + + self.assertEqual( + permuted_jag_tensor.keys(), + ["index_1", "index_0", "index_2", "index_1", "index_1"], + ) + self.assertEqual( + permuted_jag_tensor.offset_per_key(), + [0, 3, 5, 8, 11, 14], + ) + self.assertTrue( + torch.equal( + permuted_jag_tensor.values(), + torch.Tensor( + [ + 3.0, + 4.0, + 5.0, + 1.0, + 2.0, + 6.0, + 7.0, + 8.0, + 3.0, + 4.0, + 5.0, + 3.0, + 4.0, + 5.0, + ] + ), + ) + ) + self.assertTrue( + torch.equal( + permuted_jag_tensor.lengths(), + torch.IntTensor([1, 1, 1, 0, 2, 0, 0, 3, 0, 1, 1, 1, 1, 1, 1]), + ) + ) + self.assertEqual(permuted_jag_tensor.weights_or_none(), None) + + def test_length_vs_offset(self) -> None: + values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]) + weights = torch.Tensor([1.0, 0.5, 1.5, 1.0, 0.5, 1.0, 1.0, 1.5]) + keys = ["index_0", "index_1"] + offsets = torch.IntTensor([0, 0, 2, 2, 3, 4, 5, 5, 8]) + lengths = torch.IntTensor([0, 2, 0, 1, 1, 1, 0, 3]) + + j_offset = KeyedJaggedTensor.from_offsets_sync( + values=values, + keys=keys, + offsets=offsets, + weights=weights, + ) + + j_lens = KeyedJaggedTensor.from_lengths_sync( + values=values, + keys=keys, + lengths=lengths, + weights=weights, + ) + + self.assertTrue(torch.equal(j_offset.lengths(), j_lens.lengths())) + # TO DO: T88149179 + self.assertTrue(torch.equal(j_offset.offsets(), j_lens.offsets().int())) + + def test_2d(self) -> None: + values = torch.Tensor([[i * 0.5, i * 1.0, i * 1.5] for i in range(1, 9)]) + weights = torch.Tensor([1.0, 0.5, 1.5, 1.0, 0.5, 1.0, 1.0, 1.5]) + keys = ["index_0", "index_1"] + offsets = torch.IntTensor([0, 2, 2, 3, 4, 5, 8]) + + j = KeyedJaggedTensor.from_offsets_sync( + values=values, + weights=weights, + keys=keys, + offsets=offsets, + ) + j_0 = j["index_0"] + + self.assertTrue(torch.equal(j_0.lengths(), torch.IntTensor([2, 0, 1]))) + self.assertTrue( + torch.equal( + j_0.values(), + torch.Tensor( + [ + [0.5, 1.0, 1.5], + [1.0, 2.0, 3.0], + [1.5, 3.0, 4.5], + ], + ), + ) + ) + + def test_float_lengths_offsets_throws(self) -> None: + values = torch.rand((7, 3)) + keys = ["f1", "f2"] + # torch.Tensor([3, 4]) also fails + # pyre-fixme[6]: Expected `Optional[typing.Type[torch._dtype]]` for 2nd + # param but got `Type[float]`. + lengths = torch.tensor([3, 4], dtype=float) + # pyre-fixme[6]: Expected `Optional[typing.Type[torch._dtype]]` for 2nd + # param but got `Type[float]`. + offsets = torch.tensor([0, 3, 7], dtype=float) + + with self.assertRaises(AssertionError): + KeyedJaggedTensor.from_lengths_sync( + keys=keys, values=values, lengths=lengths + ) + with self.assertRaises(AssertionError): + KeyedJaggedTensor.from_offsets_sync( + keys=keys, values=values, offsets=offsets + ) + + def test_scriptable(self) -> None: + class MyModule(torch.nn.Module): + def forward(self, input: KeyedJaggedTensor) -> torch.Tensor: + values = input["any"].values() + return values + + m = MyModule() + torch.jit.script(m) + + def test_to(self) -> None: + j = KeyedJaggedTensor.from_offsets_sync( + offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]), + values=torch.arange(8), + weights=torch.arange(8 * 10), + keys=["index_0", "index_1"], + ) + j2 = j.to(device=torch.device("cpu")) + self.assertTrue(torch.equal(j.offsets(), j2.offsets())) + self.assertTrue(torch.equal(j.lengths(), j2.lengths())) + self.assertTrue(torch.equal(j.values(), j2.values())) + self.assertTrue(torch.equal(j.weights(), j2.weights())) + + def test_string_none(self) -> None: + jag_tensor = KeyedJaggedTensor( + torch.Tensor(), + [], + ) + + self.assertEqual( + str(jag_tensor), + """\ +KeyedJaggedTensor() +""", + ) + + def test_string_basic(self) -> None: + values = torch.Tensor([1.0]) + keys = ["key"] + offsets = torch.IntTensor([0, 1]) + + jag_tensor = KeyedJaggedTensor( + values=values, + keys=keys, + offsets=offsets, + ) + + self.assertEqual( + str(jag_tensor), + """\ +KeyedJaggedTensor({ + "key": [[1.0]] +}) +""", + ) + + def test_string_values(self) -> None: + values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]) + keys = ["index_0", "index_1"] + offsets = torch.IntTensor([0, 2, 2, 3, 4, 5, 8]) + + jag_tensor = KeyedJaggedTensor( + values=values, + keys=keys, + offsets=offsets, + ) + + self.assertEqual( + str(jag_tensor), + """\ +KeyedJaggedTensor({ + "index_0": [[1.0, 2.0], [], [3.0]], + "index_1": [[4.0], [5.0], [6.0, 7.0, 8.0]] +}) +""", + ) + + def test_string_weights(self) -> None: + values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]) + weights = torch.Tensor([1.0, 0.5, 1.5, 1.0, 0.5, 1.0, 1.0, 1.5]) + keys = ["index_0", "index_1"] + offsets = torch.IntTensor([0, 2, 2, 3, 4, 5, 8]) + + jag_tensor = KeyedJaggedTensor( + values=values, + keys=keys, + offsets=offsets, + weights=weights, + ) + + self.assertEqual( + str(jag_tensor), + """\ +KeyedJaggedTensor({ + "index_0": { + "values": [[1.0, 2.0], [], [3.0]], + "weights": [[1.0, 0.5], [], [1.5]] + }, + "index_1": { + "values": [[4.0], [5.0], [6.0, 7.0, 8.0]], + "weights": [[1.0], [0.5], [1.0, 1.0, 1.5]] + } +}) +""", + ) + + # pyre-ignore[56] + @unittest.skipIf( + torch.cuda.device_count() <= 0, + "CUDA is not available", + ) + def test_record_stream(self) -> None: + j = KeyedJaggedTensor.from_offsets_sync( + offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]), + values=torch.arange(8), + weights=torch.arange(8 * 10), + keys=["index_0", "index_1"], + ).to(torch.device("cuda")) + j.record_stream(torch.cuda.current_stream()) + + +class TestKeyedTensor(unittest.TestCase): + def test_key_lookup(self) -> None: + tensor_list = [ + torch.Tensor([[1.0, 1.0]]), + torch.Tensor([[2.0, 2.0], [3.0, 3.0]]), + ] + keys = ["dense_0", "dense_1"] + kt = KeyedTensor.from_tensor_list(keys, tensor_list, cat_dim=0, key_dim=0) + self.assertEqual(kt.key_dim(), 0) + + self.assertTrue(torch.equal(kt["dense_0"], tensor_list[0])) + self.assertTrue(torch.equal(kt["dense_1"], tensor_list[1])) + + def test_key_lookup_dim_1(self) -> None: + tensor_list = [ + torch.tensor([[1.0, 1.0]]).T, + torch.tensor([[2.0, 2.0], [3.0, 3.0]]).T, + ] + keys = ["dense_0", "dense_1"] + kt = KeyedTensor.from_tensor_list(keys, tensor_list, key_dim=1) + self.assertEqual(kt.key_dim(), 1) + self.assertTrue(torch.equal(kt["dense_0"], tensor_list[0])) + self.assertTrue(torch.equal(kt["dense_1"], tensor_list[1])) + + def test_to_dict(self) -> None: + tensor_list = [ + torch.Tensor([[1.0, 1.0]]), + torch.Tensor([[2.0, 2.0], [3.0, 3.0]]), + ] + keys = ["dense_0", "dense_1"] + kt = KeyedTensor.from_tensor_list(keys, tensor_list, cat_dim=0, key_dim=0) + self.assertEqual(kt.key_dim(), 0) + + d = kt.to_dict() + for key in keys: + self.assertTrue(torch.equal(kt[key], d[key])) + + def test_to_dict_dim_1(self) -> None: + tensor_list = [ + torch.tensor([[1.0, 1.0]]).T, + torch.tensor([[2.0, 2.0], [3.0, 3.0]]).T, + ] + keys = ["dense_0", "dense_1"] + kt = KeyedTensor.from_tensor_list(keys, tensor_list, key_dim=1) + self.assertEqual(kt.key_dim(), 1) + + d = kt.to_dict() + for key in keys: + self.assertTrue(torch.equal(kt[key], d[key])) + + def test_regroup_single_kt(self) -> None: + tensor_list = [torch.randn(2, 3) for i in range(5)] + key_dim = 1 + keys = ["dense_0", "dense_1", "dense_2", "dense_3", "dense_4"] + kt = KeyedTensor.from_tensor_list(keys, tensor_list, key_dim) + grouped_tensors = KeyedTensor.regroup( + [kt], [["dense_0", "dense_4"], ["dense_1", "dense_3"], ["dense_2"]] + ) + self.assertTrue( + torch.equal( + grouped_tensors[0], torch.cat([tensor_list[0], tensor_list[4]], key_dim) + ) + ) + self.assertTrue( + torch.equal( + grouped_tensors[1], torch.cat([tensor_list[1], tensor_list[3]], key_dim) + ) + ) + self.assertTrue(torch.equal(grouped_tensors[2], tensor_list[2])) + + def test_regroup_multiple_kt(self) -> None: + key_dim = 1 + tensor_list_1 = [torch.randn(2, 3) for i in range(3)] + keys_1 = ["dense_0", "dense_1", "dense_2"] + kt_1 = KeyedTensor.from_tensor_list(keys_1, tensor_list_1, key_dim) + tensor_list_2 = [torch.randn(2, 3) for i in range(2)] + keys_2 = ["sparse_0", "sparse_1"] + kt_2 = KeyedTensor.from_tensor_list(keys_2, tensor_list_2, key_dim) + grouped_tensors = KeyedTensor.regroup( + [kt_1, kt_2], [["dense_0", "sparse_1", "dense_2"], ["dense_1", "sparse_0"]] + ) + self.assertTrue( + torch.equal( + grouped_tensors[0], + torch.cat( + [tensor_list_1[0], tensor_list_2[1], tensor_list_1[2]], key_dim + ), + ) + ) + self.assertTrue( + torch.equal( + grouped_tensors[1], + torch.cat([tensor_list_1[1], tensor_list_2[0]], key_dim), + ) + ) + + def test_regroup_scriptable(self) -> None: + class MyModule(torch.nn.Module): + def forward( + self, inputs: List[KeyedTensor], groups: List[List[str]] + ) -> List[torch.Tensor]: + return KeyedTensor.regroup(inputs, groups) + + m = MyModule() + torch.jit.script(m) + + def test_regroup_fxable(self) -> None: + class MyModule(torch.nn.Module): + def forward( + self, inputs: List[KeyedTensor], groups: List[List[str]] + ) -> List[torch.Tensor]: + return KeyedTensor.regroup(inputs, groups) + + m = MyModule() + + # input + key_dim = 1 + tensor_list_1 = [torch.randn(2, 3) for i in range(3)] + keys_1 = ["dense_0", "dense_1", "dense_2"] + kt_1 = KeyedTensor.from_tensor_list(keys_1, tensor_list_1, key_dim) + tensor_list_2 = [torch.randn(2, 3) for i in range(2)] + keys_2 = ["sparse_0", "sparse_1"] + kt_2 = KeyedTensor.from_tensor_list(keys_2, tensor_list_2, key_dim) + inputs = [kt_1, kt_2] + groups = [["dense_0", "sparse_1", "dense_2"], ["dense_1", "sparse_0"]] + + # ensure that symbolic tracing works + gm = torch.fx.symbolic_trace(m) + results = m(inputs, groups) + traced_results = gm(inputs, groups) + self.assertEqual(len(results), len(traced_results)) + for result, traced_result in zip(results, traced_results): + self.assertTrue(torch.equal(result, traced_result)) + + def test_scriptable(self) -> None: + class MyModule(torch.nn.Module): + def forward(self, input: KeyedTensor) -> torch.Tensor: + values = input["any"].values() + return values + + m = MyModule() + torch.jit.script(m) + + def test_string_none(self) -> None: + jag_tensor = KeyedTensor( + [], + [], + torch.Tensor(), + ) + + self.assertEqual( + str(jag_tensor), + """\ +KeyedTensor() +""", + ) + + def test_string_basic(self) -> None: + tensor_list = [ + torch.tensor([[1.0]]), + ] + keys = ["key"] + kt = KeyedTensor.from_tensor_list(keys, tensor_list, key_dim=0) + + self.assertEqual( + str(kt), + """\ +KeyedTensor({ + "key": [[1.0]] +}) +""", + ) + + def test_string_values(self) -> None: + tensor_list = [ + torch.tensor([[1.0, 1.0]]).T, + torch.tensor([[2.0, 2.0], [3.0, 3.0]]).T, + ] + keys = ["dense_0", "dense_1"] + kt = KeyedTensor.from_tensor_list(keys, tensor_list) + + self.assertEqual( + str(kt), + """\ +KeyedTensor({ + "dense_0": [[1.0], [1.0]], + "dense_1": [[2.0, 3.0], [2.0, 3.0]] +}) +""", + ) diff --git a/torchrec/sparse/tests/tests_utils.py b/torchrec/sparse/tests/tests_utils.py new file mode 100644 index 000000000..b61cfaeb7 --- /dev/null +++ b/torchrec/sparse/tests/tests_utils.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 + +from typing import Optional + +import torch +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor + + +def keyed_jagged_tensor_equals( + kjt1: Optional[KeyedJaggedTensor], kjt2: Optional[KeyedJaggedTensor] +) -> bool: + def _tensor_eq_or_none( + t1: Optional[torch.Tensor], t2: Optional[torch.Tensor] + ) -> bool: + if t1 is None and t2 is None: + return True + elif t1 is None and t2 is not None: + return False + elif t1 is not None and t2 is None: + return False + else: + assert t1 is not None + assert t2 is not None + return torch.equal(t1, t2) and t1.dtype == t2.dtype + + if kjt1 is None and kjt2 is None: + return True + elif kjt1 is None and kjt2 is not None: + return False + elif kjt1 is not None and kjt2 is None: + return False + else: + assert kjt1 is not None + assert kjt2 is not None + return ( + kjt1.keys() == kjt2.keys() + and _tensor_eq_or_none(kjt1.lengths(), kjt2.lengths()) + and _tensor_eq_or_none(kjt1.values(), kjt2.values()) + and _tensor_eq_or_none(kjt1._weights, kjt2._weights) + ) diff --git a/torchrec/tests/__init__.py b/torchrec/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/torchrec/tests/utils.py b/torchrec/tests/utils.py new file mode 100644 index 000000000..e1c6b92fc --- /dev/null +++ b/torchrec/tests/utils.py @@ -0,0 +1,96 @@ +#!/usr/bin/env python3 + +import ctypes +import os +import random +import socket +import time +from contextlib import closing +from functools import wraps +from typing import TypeVar, Callable, Optional + +import numpy as np +import torch +import torch.distributed as dist +from pyre_extensions import ParameterSpecification + +TParams = ParameterSpecification("TParams") +TReturn = TypeVar("TReturn") + + +def get_free_port() -> int: + if socket.has_ipv6: + family = socket.AF_INET6 + address = "localhost6" + else: + family = socket.AF_INET + address = "localhost4" + with socket.socket(family, socket.SOCK_STREAM) as s: + try: + s.bind((address, 0)) + s.listen(0) + with closing(s): + return s.getsockname()[1] + except Exception as e: + raise Exception( + f"Binding failed with address {address} while getting free port {e}" + ) + + +def is_asan() -> bool: + """Determines if the Python interpreter is running with ASAN""" + return hasattr(ctypes.CDLL(""), "__asan_init") + + +def is_tsan() -> bool: + """Determines if the Python interpreter is running with TSAN""" + return hasattr(ctypes.CDLL(""), "__tsan_init") + + +def is_asan_or_tsan() -> bool: + return is_asan() or is_tsan() + + +def skip_if_asan( + func: Callable[TParams, TReturn] +) -> Callable[TParams, Optional[TReturn]]: + """Skip test run if we are in ASAN mode.""" + + @wraps(func) + def wrapper(*args: TParams.args, **kwargs: TParams.kwargs) -> Optional[TReturn]: + if is_asan_or_tsan(): + print("Skipping test run since we are in ASAN mode.") + return + return func(*args, **kwargs) + + return wrapper + + +def skip_if_asan_class(cls: TReturn) -> Optional[TReturn]: + if is_asan_or_tsan(): + print("Skipping test run since we are in ASAN mode.") + return + return cls + + +def init_distributed_single_host( + rank: int, world_size: int, backend: str, local_size: Optional[int] = None +) -> dist.ProcessGroup: + os.environ["LOCAL_WORLD_SIZE"] = str(local_size if local_size else world_size) + os.environ["LOCAL_RANK"] = str(rank % local_size if local_size else rank) + dist.init_process_group(rank=rank, world_size=world_size, backend=backend) + return dist.group.WORLD + + +# pyre-ignore [24] +def seed_and_log(wrapped_func: Callable) -> Callable: + # pyre-ignore [2, 3] + def _wrapper(*args, **kwargs): + seed = int(time.time() * 1000) % (1 << 31) + print(f"Using random seed: {seed}") + torch.manual_seed(seed) + random.seed(seed) + np.random.seed(seed) + return wrapped_func(*args, **kwargs) + + return _wrapper diff --git a/torchrec/types.py b/torchrec/types.py new file mode 100644 index 000000000..a6afc738c --- /dev/null +++ b/torchrec/types.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python3 + +import abc + +import torch + + +class Multistreamable(abc.ABC): + """ + Objects implementing this interface are allowed to be transferred + from one CUDA stream to another. + torch.Tensor and (Keyed)JaggedTensor implement this interface. + """ + + @abc.abstractmethod + def record_stream(self, stream: torch.cuda.streams.Stream) -> None: + """ + See https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html + """ + ... + + +class Pipelineable(Multistreamable): + """ + This interface contains two methods, one for moving an input across devices, + the other one for marking streams that operate the input. + + torch.Tensor implements this interface and we can used it in many applications. + Another example is torchrec.(Keyed)JaggedTensor, which we use as the input to + torchrec.EmbeddingBagCollection, which in turn is often the first layer of many models. + Some models take compound inputs, which should implement this interface. + """ + + @abc.abstractmethod + def to(self, device: torch.device, non_blocking: bool) -> "Pipelineable": + """ + Please be aware that accoarding to https://pytorch.org/docs/stable/generated/torch.Tensor.to.html, + to might return self or a copy of self. So please remember to use `to` with the assignment operator, + for example, `in = in.to(new_device)`. + """ + ...