Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FP8 + FSDP2 + torch.compile examples for PyTorch Lightning and Fabric #20440

Merged
merged 13 commits into from
Nov 26, 2024
108 changes: 107 additions & 1 deletion docs/source-fabric/advanced/compile.rst
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,115 @@ always exclude the first call to ``forward()`` from your measurements, since it
Compile median time: 0.0185 seconds
Speedup: 1.4x


----

**********************************************
Apply torch.compile with ModelParallelStrategy
**********************************************

:func:`torch.compile` can also be invoked as part of the `parallelize_fn` argument of :class:`~lightning.fabric.strategies.model_parallel.ModelParallelStrategy`.

This is particularly handy when :func:`torch.compile` is used in combination with the `torch.distributed.tensor` API.

Here is an example:

.. code-block:: python

import lightning as L
import torch
import torch.nn as nn
import torch.nn.functional as F
from lightning.pytorch.demos import Transformer
from lightning.fabric.strategies.model_parallel import ModelParallelStrategy
from torch.distributed._composable.fsdp.fully_shard import fully_shard
from torch.distributed.device_mesh import DeviceMesh

def parallelize(model: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
for module in model.modules():
if isinstance(module, (torch.nn.TransformerEncoderLayer, torch.nn.TransformerDecoderLayer)):
fully_shard(module, mesh=device_mesh)

fully_shard(model, mesh=device_mesh)

return torch.compile(model)

def train():
L.seed_everything(42)

with torch.device("meta"):
model = Transformer(
vocab_size=50257,
nlayers=16,
nhid=4096,
ninp=1024,
nhead=32,
)

strategy = ModelParallelStrategy(data_parallel_size=4, tensor_parallel_size=1, parallelize_fn=parallelize)

fabric = L.Fabric(precision="bf16-true", strategy=strategy)
fabric.launch()

model = fabric.setup(model)

The advantage here is that `parallelize` is called when sharding the model,
so :func:`torch.compile` is guaranteed to run on model shards and capture distributed operations.

Also, when using other libraries like `torch ao <https://github.com/pytorch/ao>`_
that need to be applied in a similar fashion, it's easy to reason about the sequence of calls
needed to achieve the equivalent of `compile(distributed(quantized(model)))`:

.. code-block:: python

import lightning as L
import torch
import torch.nn as nn
import torch.nn.functional as F
from lightning.pytorch.demos import Transformer
from torch.distributed._composable.fsdp.fully_shard import fully_shard
from torch.distributed.device_mesh import DeviceMesh
from torchao.float8 import Float8LinearConfig, convert_to_float8_training

def parallelize(model: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
float8_config = Float8LinearConfig(
pad_inner_dim=True,
)

def module_filter_fn(mod: torch.nn.Module, fqn: str):
return fqn != "decoder"

convert_to_float8_training(model, config=float8_config, module_filter_fn=module_filter_fn)

for module in model.modules():
if isinstance(module, (torch.nn.TransformerEncoderLayer, torch.nn.TransformerDecoderLayer)):
fully_shard(module, mesh=device_mesh)

fully_shard(model, mesh=device_mesh)

return torch.compile(model)

def train():
L.seed_everything(42)

with torch.device("meta"):
model = Transformer(
vocab_size=50257,
nlayers=16,
nhid=4096,
ninp=1024,
nhead=32,
)

strategy = ModelParallelStrategy(data_parallel_size=4, tensor_parallel_size=1, parallelize_fn=parallelize)

fabric = L.Fabric(precision="bf16-true", strategy=strategy)
fabric.launch()

model = fabric.setup(model)

For a full example, see our `FP8 Distributed Transformer example <https://github.com/Lightning-AI/lightning/blob/master/examples/fabric/fp8_distributed_transformer>`_.

----

******************
Avoid graph breaks
Expand Down
120 changes: 118 additions & 2 deletions docs/source-pytorch/advanced/compile.rst
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,122 @@ always exclude the first call to ``forward()``/``*_step()`` from your measuremen

----

**************************************
Apply torch.compile in configure_model
**************************************

:func:`torch.compile` can also be invoked as part of the :meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_model` hook.

This is particularly handy when :func:`torch.compile` is used in combination with :class:`~lightning.pytorch.strategies.model_parallel.ModelParallelStrategy`.

Here is an example:

.. code-block:: python

import lightning as L
import torch
import torch.nn as nn
import torch.nn.functional as F
from lightning.pytorch.demos import Transformer
from lightning.pytorch.strategies.model_parallel import ModelParallelStrategy
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed._composable.fsdp.fully_shard import fully_shard

class LanguageModel(L.LightningModule):
def __init__(self, vocab_size):
super().__init__()
self.vocab_size = vocab_size
self.model = None

def configure_model(self):
if self.model is not None:
return

with torch.device("meta"):
model = Transformer(
vocab_size=self.vocab_size,
nlayers=16,
nhid=4096,
ninp=1024,
nhead=32,
)

for module in model.modules():
if isinstance(module, (nn.TransformerEncoderLayer, nn.TransformerDecoderLayer)):
fully_shard(module, mesh=self.device_mesh)

fully_shard(model, mesh=self.device_mesh)

self.model = torch.compile(model)

def training_step(self, batch):
input, target = batch
output = self.model(input, target)
loss = F.nll_loss(output, target.view(-1))
self.log("train_loss", loss)
return loss

def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=1e-4)

The advantage here is that `configure_model` is called when sharding the model,
so :func:`torch.compile` is guaranteed to run on model shards and capture distributed operations.

Also, when using other libraries like `torch ao <https://github.com/pytorch/ao>`_
that need to be applied in a similar fashion, it's easy to reason about the sequence of calls
needed to achieve the equivalent of `compile(distributed(quantized(model)))`:

.. code-block:: python

import lightning as L
import torch
import torch.nn as nn
import torch.nn.functional as F
from lightning.pytorch.demos import Transformer
from lightning.pytorch.strategies.model_parallel import ModelParallelStrategy
from torch.distributed._composable.fsdp.fully_shard import fully_shard
from torch.distributed.device_mesh import DeviceMesh
from torchao.float8 import Float8LinearConfig, convert_to_float8_training

class LanguageModel(L.LightningModule):
def __init__(self, vocab_size):
super().__init__()
self.vocab_size = vocab_size
self.model = None

def configure_model(self):
if self.model is not None:
return

with torch.device("meta"):
model = Transformer(
vocab_size=self.vocab_size,
nlayers=16,
nhid=4096,
ninp=1024,
nhead=32,
)

float8_config = Float8LinearConfig(
pad_inner_dim=True,
)

def module_filter_fn(mod: torch.nn.Module, fqn: str):
return fqn != "decoder"

convert_to_float8_training(model, config=float8_config, module_filter_fn=module_filter_fn)

for module in model.modules():
if isinstance(module, (nn.TransformerEncoderLayer, nn.TransformerDecoderLayer)):
fully_shard(module, mesh=self.device_mesh)

fully_shard(model, mesh=self.device_mesh)

self.model = torch.compile(model)

For a full example, see our `FP8 Distributed Transformer example <https://github.com/Lightning-AI/lightning/blob/master/examples/pytorch/fp8_distributed_transformer>`_.

----

******************
Avoid graph breaks
Expand Down Expand Up @@ -253,8 +369,8 @@ Limitations

There are a few limitations you should be aware of when using ``torch.compile`` **in conjunction with the Trainer**:

* The Trainer currently does not reapply ``torch.compile`` over DDP/FSDP, meaning distributed operations can't benefit from speed ups at the moment.
This limitation will be lifted in the future.
* The Trainer currently does not reapply ``torch.compile`` over :class:`~lightning.pytorch.strategies.DDPStrategy` and :class:`~lightning.pytorch.strategies.FSDPStrategy`, meaning distributed operations can't benefit from speed ups at the moment.
This limitation can be avoided by using :class:`~lightning.pytorch.strategies.model_parallel.ModelParallelStrategy`, as described in `Apply torch.compile in configure_model`_ above.

* In some cases, using ``self.log()`` in your LightningModule will cause compilation errors.
Until addressed, you can work around these issues by applying ``torch.compile`` to the submodule(s) of your LightningModule rather than to the entire LightningModule at once.
Expand Down
39 changes: 39 additions & 0 deletions examples/fabric/fp8_distributed_transformer/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
## Distributed, Low-Precision Transformer Example

This example shows how to use `ModelParallelStrategy` in `Fabric` to train a Transformer model minimizing memory usage, maximizing throughput, and distributing load across multiple GPUs.

### Training Large Models and Memory Requirements

One of the main challenges when training large models, like large language models (LLMs), is dealing with their memory footprint. LLMs can be so large that weights, activations, gradients and optimizer state don't fit a single GPU, so that they need to be distributed across multiple GPUs, and across multiple machines. There are multiple ways of distributing computations, among which fully-sharded data parallelism (FSDP) and tensor parallelism (TP).

An additional way of reducing memory requirements is representing floating point numbers in weights and activations in low numerical precision, such as 16-bit (`bfloat16`), or 8-bit (`fp8`). This leads to savings in memory usage, as well as memory bandwidth usage (fewer bytes transferred from device memory to GPU cores in unit time).

Roughly, reducing precision to `fp8` for linear layers can lead to 2x reduction in memory requirements and 1.6x improvement in throughput. Support for `fp8` weights and activations requires recent GPUs - Hopper, Ada Lovelace and above (e.g. H100, L4, L40).

The introduction of tensor subclasses in PyTorch brought two new APIs that can be used to achieve memory savings and distributed training (as well as inference) in combination:

- [torch ao](https://github.com/pytorch/ao) to execute linear layers in low numerical precision (`fp8` and other quantized formats)
- [dtensors](https://pytorch.org/docs/stable/distributed.tensor.html) to distribute models across GPUs, by combining TP and FSDP (referred to FSDP2 in PyTorch)

Notably, `torch ao` introduces quantization and dequantization operations in the model that may result in slow-downs if not optimized. Using `torch.compile` after `torch ao` recovers performance by generating optimized kernels for those operations.

### Vanilla Transformer Example

This example shows how to train a vanilla Transformer model using `fp8` precision and the FSDP2 distributed strategy, and then optimize the resulting model through `torch.compile`.

Specifically, we employ the `ModelParallelStrategy`, and use the `configure_model` hook to distribute the model using the PyTorch DTensor API.
In the same hook we also pass the model through the `torch ao` API (prior to FSDP2), as well as `torch.compile` (after FSDP2).

The resulting code follows the PyTorch API closely, while also taking advantage of the rest of PyTorch Lightning.

To execute the code directly just run:

```bash
python train.py
```

### A Note on torch.compile

Note that PyTorch Lightning also supports calling `torch.compile` on a `LightningModule` and passing it to the `Trainer`.

While this works for simple cases, in order to get the most out of the combination of the latest distributed, quantization, and compile PyTorch API's, we recommend invoking `torch.compile` at the end of the `configure_model` hook, as shown in this example.
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
torchao>=0.7.0
100 changes: 100 additions & 0 deletions examples/fabric/fp8_distributed_transformer/train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import lightning as L
import torch
import torch.nn as nn
import torch.nn.functional as F
from lightning.fabric.strategies import ModelParallelStrategy
from lightning.pytorch.demos import Transformer, WikiText2
from torch.distributed._composable.fsdp.fully_shard import fully_shard
from torch.distributed.device_mesh import DeviceMesh
from torch.utils.data import DataLoader
from torchao.float8 import Float8LinearConfig, convert_to_float8_training
from tqdm import tqdm


def configure_model(model: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
float8_config = Float8LinearConfig(
# pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly # noqa
pad_inner_dim=True,
)

def module_filter_fn(mod: torch.nn.Module, fqn: str):
# we skip the decoder because it typically vocabulary size
# is not divisible by 16 as required by float8
return fqn != "decoder"

convert_to_float8_training(model, config=float8_config, module_filter_fn=module_filter_fn)

for module in model.modules():
if isinstance(module, (torch.nn.TransformerEncoderLayer, torch.nn.TransformerDecoderLayer)):
fully_shard(module, mesh=device_mesh)

fully_shard(model, mesh=device_mesh)

return torch.compile(model)


def train():
L.seed_everything(42)

batch_size = 8
micro_batch_size = 1

max_steps = 100

dataset = WikiText2()
dataloader = DataLoader(dataset, num_workers=8, batch_size=micro_batch_size)

with torch.device("meta"):
model = Transformer(
vocab_size=dataset.vocab_size,
nlayers=16,
nhid=4096,
ninp=1024,
nhead=32,
)

strategy = ModelParallelStrategy(data_parallel_size=4, tensor_parallel_size=1, parallelize_fn=configure_model)

fabric = L.Fabric(precision="bf16-true", strategy=strategy)
fabric.launch()

model = fabric.setup(model)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
optimizer = fabric.setup_optimizers(optimizer)

dataloader = fabric.setup_dataloaders(dataloader)

iterable = tqdm(enumerate(dataloader), total=len(dataloader)) if fabric.is_global_zero else enumerate(dataloader)

steps = 0

for i, batch in iterable:
input, target = batch

is_accumulating = i % (batch_size // micro_batch_size) != 0

with fabric.no_backward_sync(model, enabled=is_accumulating):
output = model(input, target)
loss = F.nll_loss(output, target.view(-1))
fabric.backward(loss)

if not is_accumulating:
fabric.clip_gradients(model, optimizer, max_norm=1.0)
optimizer.step()
optimizer.zero_grad()
steps += 1

if fabric.is_global_zero:
iterable.set_postfix_str(f"train_loss={loss.item():.2f}")

if steps == max_steps:
break

fabric.print(torch.cuda.memory_summary())


if __name__ == "__main__":
torch.set_float32_matmul_precision("high")

train()
Loading
Loading