Minimal transformer examples
lantiga committed Nov 22, 2024
1 parent 8ce5287 commit c6695f8
Showing 4 changed files with 200 additions and 0 deletions.
Empty file.
106 changes: 106 additions & 0 deletions examples/fabric/fp8_fsdp2_compile/train.py
@@ -0,0 +1,106 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.distributed._composable.fsdp.fully_shard import fully_shard
from torch.distributed.device_mesh import DeviceMesh

from torchao.float8 import convert_to_float8_training, Float8LinearConfig

import lightning as L
from lightning.fabric.strategies import ModelParallelStrategy
from lightning.pytorch.demos import Transformer, WikiText2

from tqdm import tqdm


def configure_model(model: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
    float8_config = Float8LinearConfig(
        # pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly
        pad_inner_dim=True,
    )

    def module_filter_fn(mod: torch.nn.Module, fqn: str):
        # skip the decoder, because its vocabulary size is typically
        # not divisible by 16 as required by float8
        if fqn == "decoder":
            return False
        return True

    convert_to_float8_training(model, config=float8_config, module_filter_fn=module_filter_fn)

    # apply FSDP2: shard each transformer block first, then the root module
    for module in model.modules():
        if isinstance(module, (torch.nn.TransformerEncoderLayer, torch.nn.TransformerDecoderLayer)):
            fully_shard(module, mesh=device_mesh)

    fully_shard(model, mesh=device_mesh)

    model = torch.compile(model)

    return model


def train():
    L.seed_everything(42)

    batch_size = 8
    micro_batch_size = 1

    dataset = WikiText2()
    dataloader = DataLoader(dataset, num_workers=8, batch_size=micro_batch_size)

    with torch.device("meta"):
        model = Transformer(
            vocab_size=dataset.vocab_size,
            nlayers=16,
            nhid=4096,
            ninp=1024,
            nhead=32,
        )

    strategy = ModelParallelStrategy(
        data_parallel_size=4,
        tensor_parallel_size=1,
        parallelize_fn=configure_model
    )

    fabric = L.Fabric(precision="bf16-true", strategy=strategy)
    fabric.launch()

    model = fabric.setup(model)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    optimizer = fabric.setup_optimizers(optimizer)

    dataloader = fabric.setup_dataloaders(dataloader)

    iterable = tqdm(enumerate(dataloader), total=len(dataloader)) if fabric.is_global_zero else enumerate(dataloader)

    for i, batch in iterable:
        input, target = batch

        # accumulate gradients over batch_size // micro_batch_size micro-batches
        is_accumulating = i % (batch_size // micro_batch_size) != 0

        with fabric.no_backward_sync(model, enabled=is_accumulating):
            output = model(input, target)
            loss = F.nll_loss(output, target.view(-1))
            fabric.backward(loss)

        if not is_accumulating:
            fabric.clip_gradients(model, optimizer, max_norm=1.0)
            optimizer.step()
            optimizer.zero_grad()

        if fabric.is_global_zero:
            iterable.set_postfix_str(f"train_loss={loss.item():.2f}")

        if i // (batch_size // micro_batch_size) > 100:
            break

    fabric.print(torch.cuda.memory_summary())


if __name__ == "__main__":
    torch.set_float32_matmul_precision('high')

    train()
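
For reference, here is a minimal standalone sketch of the torchao float8 conversion used above, isolated from Fabric and FSDP2. The toy model and its dimensions are illustrative assumptions, not part of this commit; actually training the converted layers requires an fp8-capable GPU (compute capability 8.9+).

import torch.nn as nn
from torchao.float8 import Float8LinearConfig, convert_to_float8_training


class ToyModel(nn.Module):
    def __init__(self, hidden: int = 256, vocab_size: int = 1000):
        super().__init__()
        # hidden dims are multiples of 16, as float8 kernels require
        self.proj = nn.Linear(hidden, hidden)
        # named "decoder" so the filter below skips it, mirroring the examples
        self.decoder = nn.Linear(hidden, vocab_size)

    def forward(self, x):
        return self.decoder(self.proj(x))


def module_filter_fn(mod: nn.Module, fqn: str) -> bool:
    # convert every nn.Linear except the decoder head
    return fqn != "decoder"


model = ToyModel()
convert_to_float8_training(model, config=Float8LinearConfig(pad_inner_dim=True), module_filter_fn=module_filter_fn)
print(model)  # "proj" is now a Float8Linear, "decoder" stays a plain nn.Linear
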
Empty file.
94 changes: 94 additions & 0 deletions examples/pytorch/fp8_fsdp2_compile/train.py
@@ -0,0 +1,94 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.distributed._composable.fsdp.fully_shard import fully_shard

from torchao.float8 import convert_to_float8_training, Float8LinearConfig

import lightning as L
from lightning.pytorch.strategies import ModelParallelStrategy
from lightning.pytorch.demos import Transformer, WikiText2


class LanguageModel(L.LightningModule):
    def __init__(self, vocab_size):
        super().__init__()
        self.vocab_size = vocab_size
        self.model = None

    def configure_model(self):
        if self.model is not None:
            return

        with torch.device("meta"):
            model = Transformer(
                vocab_size=self.vocab_size,
                nlayers=16,
                nhid=4096,
                ninp=1024,
                nhead=32,
            )

        float8_config = Float8LinearConfig(
            # pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly
            pad_inner_dim=True,
        )

        def module_filter_fn(mod: torch.nn.Module, fqn: str):
            # skip the decoder, because its vocabulary size is typically
            # not divisible by 16 as required by float8
            if fqn == "decoder":
                return False
            return True

        convert_to_float8_training(model, config=float8_config, module_filter_fn=module_filter_fn)

        # apply FSDP2: shard each transformer block first, then the root module
        for module in model.modules():
            if isinstance(module, (nn.TransformerEncoderLayer, nn.TransformerDecoderLayer)):
                fully_shard(module, mesh=self.device_mesh)

        fully_shard(model, mesh=self.device_mesh)

        self.model = torch.compile(model)

    def training_step(self, batch):
        input, target = batch
        output = self.model(input, target)
        loss = F.nll_loss(output, target.view(-1))
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-4)


def train():
    L.seed_everything(42)

    dataset = WikiText2()
    train_dataloader = DataLoader(dataset, num_workers=8, batch_size=1)

    model = LanguageModel(vocab_size=dataset.vocab_size)

    mp_strategy = ModelParallelStrategy(
        data_parallel_size=4,
        tensor_parallel_size=1,
    )

    trainer = L.Trainer(
        strategy=mp_strategy,
        max_steps=100,
        precision="bf16-true",
        accumulate_grad_batches=8
    )

    trainer.fit(model, train_dataloader)

    trainer.print(torch.cuda.memory_summary())


if __name__ == "__main__":
    torch.set_float32_matmul_precision('high')

    train()
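
And a minimal standalone sketch of the FSDP2 sharding pattern both examples rely on: fully_shard is applied to each transformer block first and to the root module last, then the sharded model is compiled. This assumes a torchrun launch with one CUDA device per process; the toy encoder and its dimensions are illustrative assumptions, not part of this commit.

import os

import torch
import torch.nn as nn
from torch.distributed._composable.fsdp.fully_shard import fully_shard
from torch.distributed.device_mesh import init_device_mesh

# assumes: torchrun --nproc_per_node=<num_gpus> sketch.py
torch.distributed.init_process_group(backend="nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

mesh = init_device_mesh("cuda", (torch.distributed.get_world_size(),))

model = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=256, nhead=8, batch_first=True),
    num_layers=4,
).cuda()

# shard each block individually so its parameters are gathered and freed per layer,
# then shard the root module to cover the remaining parameters
for layer in model.layers:
    fully_shard(layer, mesh=mesh)
fully_shard(model, mesh=mesh)

model = torch.compile(model)
out = model(torch.randn(2, 16, 256, device="cuda"))
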
