forked from pytorch/torchrec
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fbshipit-source-id: 8f9686235729bb0aa9e03e3dbf73f74e75932b3f
- Loading branch information
0 parents
commit dfdbee8
Showing
128 changed files
with
28,284 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
name: pre-commit | ||
|
||
on: | ||
push: | ||
branches: [master] | ||
pull_request: | ||
|
||
jobs: | ||
pre-commit: | ||
runs-on: ubuntu-latest | ||
steps: | ||
- name: Setup Python | ||
uses: actions/setup-python@v2 | ||
with: | ||
python-version: 3.8 | ||
architecture: x64 | ||
- name: Checkout Torchrec | ||
uses: actions/checkout@v2 | ||
- name: Run pre-commit | ||
uses: pre-commit/action@v2.0.3 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
repos: | ||
- repo: https://github.com/pre-commit/pre-commit-hooks | ||
rev: v4.0.1 | ||
hooks: | ||
- id: check-toml | ||
- id: check-yaml | ||
exclude: packaging/.* | ||
- id: end-of-file-fixer | ||
|
||
- repo: https://github.com/omnilib/ufmt | ||
rev: v1.3.0 | ||
hooks: | ||
- id: ufmt | ||
additional_dependencies: | ||
- black == 21.9b0 | ||
- usort == 0.6.4 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
[tool.usort] | ||
|
||
first_party_detection = false |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
#!/usr/bin/env python3 | ||
|
||
from setuptools import setup, find_packages | ||
|
||
# Minimal setup configuration: package discovery is delegated to setuptools,
# skipping any package whose name matches the "*tests" pattern.
_PACKAGES = find_packages(exclude=("*tests",))

setup(name="torchrec", packages=_PACKAGES)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
#!/usr/bin/env python3 | ||
|
||
import os | ||
|
||
import torchx.specs as specs | ||
from torchx.components.base import torch_dist_role | ||
from torchx.specs.api import Resource | ||
|
||
|
||
def test_installation() -> specs.AppDef:
    """Build the TorchX app definition for the installation smoke test.

    The app has a single "trainer" role whose entrypoint is
    ``test_installation_main.py`` resolved against the current working
    directory, and whose image is ``/data/home/<user>``.

    Returns:
        specs.AppDef: app named "test_installation" with one trainer role.

    Raises:
        Whatever ``getpass.getuser()`` raises when ``USER`` is unset and no
        user name can be determined from the environment or the OS.
    """
    # Local import: only needed for the fallback below.
    import getpass

    entrypoint = os.path.join(os.getcwd(), "test_installation_main.py")

    # Prefer $USER as before; fall back to the OS lookup so that an unset
    # variable can no longer produce the bogus image path "/data/home/None".
    user = os.environ.get("USER") or getpass.getuser()
    image = f"/data/home/{user}"

    trainer = torch_dist_role(
        name="trainer",
        image=image,
        # AWS p4d instance (https://aws.amazon.com/ec2/instance-types/p4/).
        resource=Resource(
            cpu=96,
            gpu=8,
            memMB=-1,
        ),
        num_replicas=1,
        entrypoint=entrypoint,
        nproc_per_node="1",
        rdzv_backend="c10d",
        args=[],
    )
    return specs.AppDef(name="test_installation", roles=[trainer])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
#!/usr/bin/env python3 | ||
|
||
import os
import sys
from typing import Iterator, List, Tuple

import torch
import torch.distributed as dist
from torchrec import EmbeddingBagCollection
from torchrec import KeyedJaggedTensor
from torchrec.distributed.model_parallel import DistributedModelParallel
from torchrec.models.dlrm import DLRM
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.optim.keyed import KeyedOptimizerWrapper
|
||
|
||
class RandomIterator(Iterator):
    """Endless iterator producing random (dense, sparse, labels) batches.

    Each ``next()`` call draws a fresh batch; ``StopIteration`` is never
    raised by this class.
    """

    def __init__(
        self, batch_size: int, num_dense: int, num_sparse: int, num_embeddings: int
    ) -> None:
        """Store the batch geometry.

        Args:
            batch_size: number of examples per batch.
            num_dense: number of dense (float) features per example.
            num_sparse: number of sparse features; their keys are
                "feature0" ... "feature{num_sparse - 1}".
            num_embeddings: exclusive upper bound of the generated sparse ids.
        """
        self.batch_size = batch_size
        self.num_dense = num_dense
        self.num_sparse = num_sparse
        self.sparse_keys = [f"feature{idx}" for idx in range(self.num_sparse)]
        self.num_embeddings = num_embeddings
        # Every (feature, example) pair gets a fixed number of ids.
        self.num_ids_per_feature = 3
        self.num_ids_to_generate = (
            self.num_sparse * self.num_ids_per_feature * self.batch_size
        )

    def __next__(self) -> Tuple[torch.Tensor, KeyedJaggedTensor, torch.Tensor]:
        """Return one random batch as ``(float_features, sparse_features, labels)``.

        ``float_features`` has shape ``(batch_size, num_dense)``, ``labels``
        holds ``batch_size`` integers drawn from {0, 1}, and
        ``sparse_features`` is built from evenly spaced offsets so that every
        bag contains ``num_ids_per_feature`` ids.
        """
        float_features = torch.randn(
            self.batch_size,
            self.num_dense,
        )
        labels = torch.randint(
            low=0,
            high=2,
            size=(self.batch_size,),
        )
        # Draw ids from the whole id space [0, num_embeddings). The previous
        # bound (num_sparse) ignored num_embeddings and only ever produced
        # the first few row indices.
        sparse_ids = torch.randint(
            high=self.num_embeddings,
            size=(self.num_ids_to_generate,),
        )
        # Offsets 0, k, 2k, ..., total: one boundary per bag plus the end.
        offsets = torch.arange(
            0,
            self.num_ids_to_generate + 1,
            self.num_ids_per_feature,
            dtype=torch.int32,
        )
        sparse_features = KeyedJaggedTensor.from_offsets_sync(
            keys=self.sparse_keys,
            values=sparse_ids,
            offsets=offsets,
        )
        return (float_features, sparse_features, labels)
|
||
|
||
def main(argv: List[str]) -> None:
    """Train a small DLRM on random data for a few steps as a smoke test.

    Builds one embedding table per sparse feature, wraps the model in
    ``DistributedModelParallel`` and runs ten optimisation steps on batches
    from ``RandomIterator``.

    Args:
        argv: command-line arguments (currently unused).

    Raises:
        KeyError: if the ``LOCAL_RANK`` environment variable is not set.
        RuntimeError: if CUDA is not available.
    """
    batch_size = 1024
    num_dense = 1000
    num_sparse = 20
    num_embeddings = 1000000

    # One table per sparse feature; table N serves "featureN".
    configs = [
        EmbeddingBagConfig(
            name=f"table{idx}",
            embedding_dim=64,
            num_embeddings=num_embeddings,
            feature_names=[f"feature{idx}"],
        )
        for idx in range(num_sparse)
    ]

    rank = int(os.environ["LOCAL_RANK"])
    if not torch.cuda.is_available():
        raise RuntimeError("Cuda not available")
    device = torch.device(f"cuda:{rank}")
    backend = "nccl"
    torch.cuda.set_device(device)

    if not dist.is_initialized():
        dist.init_process_group(backend=backend)

    model = DLRM(
        embedding_bag_collection=EmbeddingBagCollection(
            tables=configs, device=torch.device("meta")
        ),
        dense_in_features=num_dense,
        dense_arch_layer_sizes=[500, 64],
        over_arch_layer_sizes=[32, 16, 1],
        dense_device=device,
    )
    model = DistributedModelParallel(
        module=model,
        device=device,
    )
    optimizer = KeyedOptimizerWrapper(
        dict(model.named_parameters()),
        lambda params: torch.optim.SGD(params, lr=0.01),
    )

    random_iterator = RandomIterator(batch_size, num_dense, num_sparse, num_embeddings)
    loss_fn = torch.nn.BCEWithLogitsLoss()
    for _ in range(10):
        dense_features, sparse_features, labels = next(random_iterator)
        dense_features = dense_features.to(device)
        sparse_features = sparse_features.to(device)
        labels = labels.to(device)

        # Clear gradients BEFORE the forward/backward pass. Zeroing them
        # between backward() and step(), as was done previously, discards the
        # freshly computed gradients so the step has nothing to apply.
        optimizer.zero_grad()
        output = model(dense_features, sparse_features)
        loss = loss_fn(output.squeeze(), labels.float())
        # BCEWithLogitsLoss uses mean reduction by default, so the loss is
        # already a scalar and can be back-propagated directly.
        loss.backward()
        optimizer.step()

    print(
        "\033[92m" + "Successfully ran a few epochs for DLRM. Installation looks good!"
    )


if __name__ == "__main__":
    main(sys.argv[1:])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
#!/usr/bin/env python3 | ||
|
||
import torchrec.distributed # noqa | ||
import torchrec.quant # noqa | ||
from torchrec.fx import tracer # noqa | ||
from torchrec.modules.embedding_configs import ( # noqa | ||
EmbeddingBagConfig, | ||
EmbeddingConfig, | ||
DataType, | ||
PoolingType, | ||
) | ||
from torchrec.modules.embedding_modules import ( # noqa | ||
EmbeddingBagCollection, | ||
EmbeddingCollection, | ||
EmbeddingBagCollectionInterface, | ||
) # noqa | ||
from torchrec.modules.score_learning import PositionWeightsAttacher # noqa | ||
from torchrec.sparse.jagged_tensor import ( # noqa | ||
JaggedTensor, | ||
KeyedJaggedTensor, | ||
KeyedTensor, | ||
) | ||
from torchrec.types import Pipelineable, Multistreamable # noqa |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
#!/usr/bin/env python3 | ||
|
||
import torchrec.datasets.criteo # noqa | ||
import torchrec.datasets.movielens # noqa | ||
import torchrec.datasets.random # noqa | ||
import torchrec.datasets.utils # noqa |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
#!/usr/bin/env python3 | ||
|
||
from typing import ( | ||
Iterator, | ||
Any, | ||
Callable, | ||
Dict, | ||
Iterable, | ||
List, | ||
Optional, | ||
Union, | ||
) | ||
|
||
import torch | ||
import torch.utils.data.datapipes as dp | ||
from torch.utils.data import IterDataPipe | ||
from torchrec.datasets.utils import LoadFiles, ReadLinesFromCSV, safe_cast | ||
|
||
|
||
# Column layout: one label column, then the integer columns, then the
# categorical columns.
INT_FEATURE_COUNT = 13
CAT_FEATURE_COUNT = 26
DEFAULT_LABEL_NAME = "label"
DEFAULT_INT_NAMES: List[str] = [f"int_{idx}" for idx in range(INT_FEATURE_COUNT)]
DEFAULT_CAT_NAMES: List[str] = [f"cat_{idx}" for idx in range(CAT_FEATURE_COUNT)]
DEFAULT_COLUMN_NAMES: List[str] = (
    [DEFAULT_LABEL_NAME] + DEFAULT_INT_NAMES + DEFAULT_CAT_NAMES
)

# One caster per column, in the same order as DEFAULT_COLUMN_NAMES: the label
# and the integer columns are cast to int (fallback 0), the categorical
# columns to str (fallback "").
COLUMN_TYPE_CASTERS: List[Callable[[Union[int, str]], Union[int, str]]] = [
    lambda val: safe_cast(val, int, 0) for _ in range(1 + INT_FEATURE_COUNT)
] + [lambda val: safe_cast(val, str, "") for _ in range(CAT_FEATURE_COUNT)]
|
||
|
||
def _default_row_mapper(example: List[str]) -> Dict[str, Union[int, str]]:
    """Map one split TSV row to a ``{column name: cast value}`` dict.

    Values are matched to column names and casters from the right-hand end,
    so a row with fewer values than ``DEFAULT_COLUMN_NAMES`` is aligned with
    the trailing columns and the leading ones are left out.
    """
    names_from_end = reversed(DEFAULT_COLUMN_NAMES)
    casters_from_end = reversed(COLUMN_TYPE_CASTERS)
    row: Dict[str, Union[int, str]] = {}
    for raw in reversed(example):
        column = next(names_from_end)
        cast = next(casters_from_end)
        row[column] = cast(raw)
    return row
|
||
|
||
class CriteoIterDataPipe(IterDataPipe):
    """Iterable datapipe over the rows of tab-delimited Criteo files.

    Args:
        paths: paths of the files to read.
        row_mapper: optional function applied to each split row; when falsy,
            rows are yielded as produced by ``ReadLinesFromCSV``.
        open_kw: extra keyword arguments forwarded to ``LoadFiles``.
    """

    def __init__(
        self,
        paths: Iterable[str],
        *,
        # pyre-ignore[2]
        row_mapper: Optional[Callable[[List[str]], Any]] = _default_row_mapper,
        # pyre-ignore[2]
        **open_kw,
    ) -> None:
        self.paths = paths
        self.row_mapper = row_mapper
        self.open_kw: Any = open_kw  # pyre-ignore[4]

    # pyre-ignore[3]
    def __iter__(self) -> Iterator[Any]:
        """Yield rows, restricted to this worker's share of the paths."""
        info = torch.utils.data.get_worker_info()
        if info is None:
            # Not inside a DataLoader worker: read every path.
            shard = self.paths
        else:
            # Round-robin split: worker k takes paths k, k + n, k + 2n, ...
            shard = (
                candidate
                for (position, candidate) in enumerate(self.paths)
                if position % info.num_workers == info.id
            )
        rows = ReadLinesFromCSV(
            LoadFiles(shard, mode="r", **self.open_kw), delimiter="\t"
        )
        if self.row_mapper:
            rows = dp.iter.Mapper(rows, self.row_mapper)
        yield from rows
|
||
|
||
def criteo_terabyte(
    paths: Iterable[str],
    *,
    # pyre-ignore[2]
    row_mapper: Optional[Callable[[List[str]], Any]] = _default_row_mapper,
    # pyre-ignore[2]
    **open_kw,
) -> IterDataPipe:
    """`Criteo 1TB Click Logs <https://ailab.criteo.com/download-criteo-1tb-click-logs-dataset/>`_ Dataset

    Args:
        paths (Iterable[str]): local paths to the TSV files that make up the
            Criteo 1TB dataset.
        row_mapper (Optional[Callable[[List[str]], Any]]): function to apply
            to each split TSV line.
        open_kw: options to pass to underlying invocation of
            iopath.common.file_io.PathManager.open.

    Returns:
        IterDataPipe: a ``CriteoIterDataPipe`` over all of the given files.

    Example:
        >>> datapipe = criteo_terabyte(
        >>>     ("/home/datasets/criteo/day_0.tsv", "/home/datasets/criteo/day_1.tsv")
        >>> )
        >>> datapipe = dp.iter.Batcher(datapipe, 100)
        >>> datapipe = dp.iter.Collator(datapipe)
        >>> batch = next(iter(datapipe))
    """
    datapipe = CriteoIterDataPipe(paths, row_mapper=row_mapper, **open_kw)
    return datapipe
|
||
|
||
def criteo_kaggle(
    path: str,
    *,
    # pyre-ignore[2]
    row_mapper: Optional[Callable[[List[str]], Any]] = _default_row_mapper,
    # pyre-ignore[2]
    **open_kw,
) -> IterDataPipe:
    """`Kaggle/Criteo Display Advertising <https://www.kaggle.com/c/criteo-display-ad-challenge/>`_ Dataset

    Args:
        path (str): local path to train or test dataset file.
        row_mapper (Optional[Callable[[List[str]], Any]]): function to apply
            to each split TSV line.
        open_kw: options to pass to underlying invocation of
            iopath.common.file_io.PathManager.open.

    Returns:
        IterDataPipe: a ``CriteoIterDataPipe`` over the single given file.

    Example:
        >>> train_datapipe = criteo_kaggle(
        >>>     "/home/datasets/criteo_kaggle/train.txt",
        >>> )
        >>> example = next(iter(train_datapipe))
        >>> test_datapipe = criteo_kaggle(
        >>>     "/home/datasets/criteo_kaggle/test.txt",
        >>> )
        >>> example = next(iter(test_datapipe))
    """
    # The datapipe expects an iterable of paths, so wrap the single file.
    single_file = (path,)
    return CriteoIterDataPipe(single_file, row_mapper=row_mapper, **open_kw)
Oops, something went wrong.