Incorrect behavior for out-of-range casts to FP8E4M3FN #260

Open
yuanyao-nv opened this issue Jan 25, 2025 · 7 comments

@yuanyao-nv

The following calculation returns NaN, which is inconsistent with ONNX Runtime, where the same cast returns the maximum representable value, 448:

np.asarray([3e10]).astype(float8_e4m3fn)
@yuanyao-nv
Author

A related question: is the saturation option exposed to the user in any way? See the details at https://onnx.ai/onnx/technical/float8.html#cast

@yuanyao-nv
Author

Actually there seems to be an inconsistency in the saturation behavior. The switch happens between 465 and 464:

>>> np.float32(465).astype(float8_e4m3fn)
nan
>>> np.float32(464).astype(float8_e4m3fn)
448

@jakevdp
Collaborator

jakevdp commented Jan 27, 2025

I think this is working as expected: if we compare the requested value to the spacing between representable values, we're rounding down within the first half of the spacing and rounding up in the second half. Consider this:

import numpy as np
import ml_dtypes
import warnings
warnings.simplefilter('ignore', RuntimeWarning)

def check(typ):
  name = typ.__name__
  print(f'type: {name}')
  largest = typ(ml_dtypes.finfo(typ).max)
  print(f'  largest value: {largest}')
  prev = np.nextafter(largest, typ(0))
  print(f'  previous value: {prev}')
  spacing = np.spacing(prev)
  print(f'  spacing: {spacing}')

  val1 = largest + spacing // 2 - 1
  val2 = val1 + 1
  val3 = val1 + 2

  print(f'  {name}({val1}) = {typ(val1)}')
  print(f'  {name}({val2}) = {typ(val2)}')
  print(f'  {name}({val3}) = {typ(val3)}')

check(np.float16)
check(ml_dtypes.float8_e5m2)
check(ml_dtypes.float8_e4m3fn)

type: float16
  largest value: 65504.0
  previous value: 65472.0
  spacing: 32.0
  float16(65519.0) = 65504.0
  float16(65520.0) = inf
  float16(65521.0) = inf
type: float8_e4m3fn
  largest value: 448
  previous value: 416
  spacing: 32
  float8_e4m3fn(463.0) = 448
  float8_e4m3fn(464.0) = 448
  float8_e4m3fn(465.0) = nan

Keep in mind that float8_e4m3fn has no representation for inf, so it returns nan.

Comparing the two, there is perhaps some kind of off-by-one error that I don't totally understand (in float16, max + spacing // 2 rounds up, while in float8_e4m3fn, max + spacing // 2 rounds down) but aside from that the behavior seems more-or-less consistent. What do you think?

@yuanyao-nv
Author

@jakevdp The off-by-one might have to do with rounding to even and truncation to the target mantissa bit width. It needs some checking.
But this is somewhat orthogonal to my point about enabling the saturation behavior. As tabulated in the ONNX link above, there are two modes for casting out-of-range values to fp8. What ml_dtypes has implemented is the non-saturating behavior. The one that's more useful for ML quantization applications is the saturating behavior, since NaNs propagate through the neural network and render all subsequent tensors unusable.
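
One way to check the round-to-even idea is to look at the least significant mantissa bit of each type's largest finite value. The sketch below is illustrative only: it assumes ml_dtypes rounds to nearest with ties to even, and `max_has_even_mantissa` is a hypothetical helper, not part of ml_dtypes.

```python
import numpy as np
import ml_dtypes

def max_has_even_mantissa(dtype):
    """Hypothetical helper: True if the largest finite value of `dtype`
    has a 0 in its least significant mantissa bit."""
    largest = np.array([ml_dtypes.finfo(dtype).max], dtype=dtype)
    # Reinterpret the raw bits as an unsigned integer of the same width.
    uint = np.uint8 if largest.itemsize == 1 else np.uint16
    bits = int(largest.view(uint)[0])
    return (bits & 1) == 0

for typ in (np.float16, ml_dtypes.float8_e5m2, ml_dtypes.float8_e4m3fn):
    print(typ.__name__, max_has_even_mantissa(typ))
```

If the tie-to-even assumption holds, a value exactly halfway past the maximum rounds toward the neighbor with the even mantissa. The float8_e4m3fn maximum (448, bit pattern 0x7E) is even, so 464 would stay at 448, while the float16 maximum (65504, bit pattern 0x7BFF) is odd, so 65520 would round up and overflow.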

@jakevdp
Collaborator

jakevdp commented Jan 27, 2025

Ah, sorry I misunderstood you – so your aim here is that overflowing values for dtypes with no inf representation should round to the largest representable value rather than resulting in NaN.

@hawkinsp – what do you think about this?

@hawkinsp
Collaborator

I'm not sure. I think the natural thing to do would be to expose a custom saturating_cast ufunc, which has the semantics you mention.

That said, I'm not sure what behavior .astype should have: saturating or non-saturating?
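
Until something like that exists, a saturating cast can be approximated in user code by clamping to the target type's finite range before calling .astype. This is a minimal sketch, not an ml_dtypes API: the `saturating_cast` below is a hypothetical plain Python function rather than a ufunc, and it clamps infinities to the finite maximum as well, which may not match every row of the ONNX table for every fp8 variant.

```python
import numpy as np
import ml_dtypes

def saturating_cast(x, dtype):
    """Hypothetical helper: clamp to the finite range of `dtype`, then cast."""
    x = np.asarray(x, dtype=np.float32)
    fmax = float(ml_dtypes.finfo(dtype).max)
    # np.clip passes NaN through unchanged and maps +/-inf to +/-fmax.
    clipped = np.clip(x, -fmax, fmax)
    return clipped.astype(dtype)

values = [3e10, -3e10, 465.0, 1.0]
print(saturating_cast(values, ml_dtypes.float8_e4m3fn))
```

Clamping in float32 before the cast means no out-of-range value ever reaches the rounding step, so only genuine NaN inputs produce NaN outputs.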

@yuanyao-nv
Author

Yes, as long as both options are exposed, that should suffice. As for which should be the default for .astype(), ONNX has picked the saturating mode as its default: https://onnx.ai/onnx/operators/onnx__Cast.html#attributes
