This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

[bc-breaking] clarify public API of float8_experimental #330

Closed
wants to merge 1 commit

8 changes: 4 additions & 4 deletions README.md
@@ -35,10 +35,10 @@ We provide two per-tensor scaling strategies: dynamic and delayed. See https://
This is the most accurate recipe as every tensor is scaled dynamically.

```python
-from float8_experimental.float8_linear_utils import (
+from float8_experimental import (
    convert_to_float8_training,
+    precompute_float8_dynamic_scale_for_fsdp,
)
-from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp

# create model
m = Model(...)
@@ -82,11 +82,11 @@ for _ in range(N_ITER):
This is theoretically the most performant recipe as it minimizes memory reads.

```python
-from float8_experimental.float8_linear_utils import (
+from float8_experimental import (
    convert_to_float8_training,
    sync_float8_amax_and_scale_history,
+    TensorScalingType,
)
-from float8_experimental.float8_linear import TensorScalingType

# create model
m = Model(...)
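
For orientation, here is a minimal sketch of the dynamic-scaling recipe as it reads with the new top-level imports. Only the `float8_experimental` names come from this diff; the toy model, optimizer, dimensions, and iteration count are made up for illustration, and actually running the fp8 matmuls assumes supported hardware.

```python
import torch

from float8_experimental import (
    convert_to_float8_training,
    precompute_float8_dynamic_scale_for_fsdp,  # new top-level location per this PR
)

# hypothetical toy model; any stack of nn.Linear layers illustrates the flow
m = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.Linear(64, 32)).cuda()

# swap eligible nn.Linear modules for float8 training (dynamic scaling)
convert_to_float8_training(m)

optimizer = torch.optim.SGD(m.parameters(), lr=0.1)
for _ in range(10):
    optimizer.zero_grad()
    y = m(torch.randn(16, 64, device="cuda"))
    y.sum().backward()
    optimizer.step()
    # when training with FSDP2 float8 all-gather, the README's recipe also calls
    # precompute_float8_dynamic_scale_for_fsdp(m) at this point in the loop
```
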
14 changes: 10 additions & 4 deletions float8_experimental/__init__.py
@@ -10,13 +10,18 @@
    TensorScalingType,
)
from float8_experimental.float8_linear import Float8Linear
-from float8_experimental.float8_linear_utils import convert_to_float8_training
+from float8_experimental.float8_linear_utils import (
+    convert_to_float8_training,
+    linear_requires_sync,
+    sync_float8_amax_and_scale_history,
+)
from float8_experimental.float8_tensor import (
    Float8Tensor,
    GemmInputRole,
    LinearMMConfig,
    ScaledMMConfig,
)
+from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp

# Needed to load Float8Tensor with weights_only = True
from torch.serialization import add_safe_globals
@@ -30,7 +35,8 @@
"Float8TensorCastConfig",
# top level UX
"convert_to_float8_training",
# TODO(future): remove Float8Tensor and Float8Linear from public API
"Float8Tensor",
"Float8Linear",
"linear_requires_sync",
"sync_float8_amax_and_scale_history",
"precompute_float8_dynamic_scale_for_fsdp",
# note: Float8Tensor and Float8Linear are not public APIs
]
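
To make the resulting surface concrete, a small illustrative check of the reshaped `__all__`. The names are taken from this diff; the script itself is just a sketch, not part of the PR.

```python
import float8_experimental as f8e

# helpers promoted to the public, top-level API by this change
public = {
    "convert_to_float8_training",
    "linear_requires_sync",
    "sync_float8_amax_and_scale_history",
    "precompute_float8_dynamic_scale_for_fsdp",
}
assert public.issubset(set(f8e.__all__))

# Float8Tensor and Float8Linear remain importable for use inside the repo,
# but per the note above they are intentionally left out of __all__
assert "Float8Tensor" not in f8e.__all__
assert "Float8Linear" not in f8e.__all__
```
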
4 changes: 4 additions & 0 deletions float8_experimental/float8_linear.py
@@ -150,6 +150,10 @@ def __init__(self, history_len: int = 16, scale_fn_name: str = "max"):

class Float8Linear(torch.nn.Linear):
    """
+    Note: this is **not** a public API and is only intended to be used
+    inside of this repository. Please file an issue if you would benefit
+    from this being a public API.
+
    A wrapper around a `torch.nn.Linear` module which does fp8 compute, and tracks
    scales in way friendly to delayed scaling.
    """
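
The `history_len` and `scale_fn_name` arguments in the hunk context above belong to the delayed-scaling machinery that `Float8Linear` tracks internally. As a rough, library-independent sketch of the idea (not the repository's implementation), delayed scaling derives the fp8 scale from a rolling history of recent amax values rather than only from the current tensor:

```python
import torch

class DelayedScaleTracker:
    """Conceptual sketch only: keep a rolling amax history and derive the
    e4m3 scale from it, instead of recomputing it from scratch each step."""

    def __init__(self, history_len: int = 16):
        # mirrors the `history_len` default shown in the hunk above
        self.amax_history = torch.zeros(history_len)

    def update_and_get_scale(self, x: torch.Tensor) -> torch.Tensor:
        # push the newest observed amax into the history
        self.amax_history = torch.roll(self.amax_history, 1)
        self.amax_history[0] = x.abs().max()
        # `scale_fn_name="max"` analogue: reduce the history with max
        amax = self.amax_history.max()
        return torch.finfo(torch.float8_e4m3fn).max / amax.clamp(min=1e-12)
```
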
4 changes: 4 additions & 0 deletions float8_experimental/float8_tensor.py
@@ -252,6 +252,10 @@ def backward(ctx, g):

class Float8Tensor(torch.Tensor):
    """
+    Note: this is **not** a public API and is only intended to be used
+    inside of this repository. Please file an issue if you would benefit
+    from this being a public API.
+
    A Python-only Float8 tensor subclass. Contains:
    * `_data`: the underlying e4m3 or e5m2 data
    * `_scale`: the scale used to scale the original fp32 tensor. We multiply
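
For readers unfamiliar with the `_data` / `_scale` pairing described in this docstring, here is a tiny, library-independent illustration of per-tensor scaling into e4m3 using only stock PyTorch. This is a conceptual analogy, not how `Float8Tensor` is constructed.

```python
import torch

x = torch.randn(16, 16, dtype=torch.float32)

# choose a per-tensor scale so the largest magnitude maps to the e4m3 max
e4m3_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0
scale = e4m3_max / x.abs().max().clamp(min=1e-12)

data = (x * scale).to(torch.float8_e4m3fn)    # analogous to `_data`
x_roundtrip = data.to(torch.float32) / scale  # dequantized approximation of x
```
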