Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MaisiVAE: Auto-cast GroupNorm, deprecate norm_float16 #8326

Open
wants to merge 1 commit into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 14 additions & 18 deletions monai/apps/generation/maisi/networks/autoencoderkl_maisi.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from monai.networks.blocks import Convolution
from monai.networks.blocks.spatialattention import SpatialAttentionBlock
from monai.networks.nets.autoencoderkl import AEKLResBlock, AutoencoderKL
from monai.utils.deprecate_utils import deprecated_arg
from monai.utils.type_conversion import convert_to_tensor

# Set up logging configuration
Expand All @@ -34,6 +35,7 @@ def _empty_cuda_cache(save_mem: bool) -> None:
return


@deprecated_arg("norm_float16", since="1.5.0", removed="1.7.0")
class MaisiGroupNorm3D(nn.GroupNorm):
"""
Custom 3D Group Normalization with optional print_info output.
Expand All @@ -43,7 +45,7 @@ class MaisiGroupNorm3D(nn.GroupNorm):
num_channels: Number of channels for the group norm.
eps: Epsilon value for numerical stability.
affine: Whether to use learnable affine parameters, default to `True`.
norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`.
norm_float16: Deprecated argument.
print_info: Whether to print information, default to `False`.
save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
"""
Expand All @@ -59,14 +61,15 @@ def __init__(
save_mem: bool = True,
):
super().__init__(num_groups, num_channels, eps, affine)
self.norm_float16 = norm_float16
self.print_info = print_info
self.save_mem = save_mem

def forward(self, input: torch.Tensor) -> torch.Tensor:
if self.print_info:
logger.info(f"MaisiGroupNorm3D with input size: {input.size()}")

target_dtype = input.dtype

if len(input.shape) != 5:
raise ValueError("Expected a 5D tensor")

Expand All @@ -75,13 +78,10 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:

inputs = []
for i in range(input.size(1)):
array = input[:, i : i + 1, ...].to(dtype=torch.float32)
array = input[:, i : i + 1, ...]
mean = array.mean([2, 3, 4, 5], keepdim=True)
std = array.var([2, 3, 4, 5], unbiased=False, keepdim=True).add_(self.eps).sqrt_()
if self.norm_float16:
inputs.append(((array - mean) / std).to(dtype=torch.float16))
else:
inputs.append((array - mean) / std)
inputs.append(((array - mean) / std).to(dtype=target_dtype))

del input
_empty_cuda_cache(self.save_mem)
Expand Down Expand Up @@ -376,6 +376,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return x


@deprecated_arg("norm_float16", since="1.5.0", removed="1.7.0")
class MaisiResBlock(nn.Module):
"""
Residual block consisting of a cascade of 2 convolutions + activation + normalisation block, and a
Expand Down Expand Up @@ -417,7 +418,6 @@ def __init__(
num_channels=in_channels,
eps=norm_eps,
affine=True,
norm_float16=norm_float16,
print_info=print_info,
save_mem=save_mem,
)
Expand All @@ -439,7 +439,6 @@ def __init__(
num_channels=out_channels,
eps=norm_eps,
affine=True,
norm_float16=norm_float16,
print_info=print_info,
save_mem=save_mem,
)
Expand Down Expand Up @@ -501,6 +500,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return out_tensor


@deprecated_arg("norm_float16", since="1.5.0", removed="1.7.0")
class MaisiEncoder(nn.Module):
"""
Convolutional cascade that downsamples the image into a spatial latent space.
Expand All @@ -520,7 +520,7 @@ class MaisiEncoder(nn.Module):
use_flash_attention: If True, use flash attention for a memory efficient attention mechanism.
num_splits: Number of splits for the input tensor.
dim_split: Dimension of splitting for the input tensor.
norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`.
norm_float16: Deprecated argument.
print_info: Whether to print information, default to `False`.
save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
"""
Expand Down Expand Up @@ -591,7 +591,6 @@ def __init__(
out_channels=output_channel,
num_splits=num_splits,
dim_split=dim_split,
norm_float16=norm_float16,
print_info=print_info,
save_mem=save_mem,
)
Expand Down Expand Up @@ -660,7 +659,6 @@ def __init__(
num_channels=num_channels[-1],
eps=norm_eps,
affine=True,
norm_float16=norm_float16,
print_info=print_info,
save_mem=save_mem,
)
Expand Down Expand Up @@ -690,6 +688,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return x


@deprecated_arg("norm_float16", since="1.5.0", removed="1.7.0")
class MaisiDecoder(nn.Module):
"""
Convolutional cascade upsampling from a spatial latent space into an image space.
Expand All @@ -710,7 +709,7 @@ class MaisiDecoder(nn.Module):
use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder.
num_splits: Number of splits for the input tensor.
dim_split: Dimension of splitting for the input tensor.
norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`.
norm_float16: Deprecated argument.
print_info: Whether to print information, default to `False`.
save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
"""
Expand Down Expand Up @@ -809,7 +808,6 @@ def __init__(
out_channels=block_out_ch,
num_splits=num_splits,
dim_split=dim_split,
norm_float16=norm_float16,
print_info=print_info,
save_mem=save_mem,
)
Expand Down Expand Up @@ -848,7 +846,6 @@ def __init__(
num_channels=block_in_ch,
eps=norm_eps,
affine=True,
norm_float16=norm_float16,
print_info=print_info,
save_mem=save_mem,
)
Expand Down Expand Up @@ -878,6 +875,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return x


@deprecated_arg("norm_float16", since="1.5.0", removed="1.7.0")
class AutoencoderKlMaisi(AutoencoderKL):
"""
AutoencoderKL with custom MaisiEncoder and MaisiDecoder.
Expand All @@ -901,7 +899,7 @@ class AutoencoderKlMaisi(AutoencoderKL):
use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder.
num_splits: Number of splits for the input tensor.
dim_split: Dimension of splitting for the input tensor.
norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`.
norm_float16: Deprecated argument.
print_info: Whether to print information, default to `False`.
save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
"""
Expand Down Expand Up @@ -964,7 +962,6 @@ def __init__(
use_flash_attention=use_flash_attention,
num_splits=num_splits,
dim_split=dim_split,
norm_float16=norm_float16,
print_info=print_info,
save_mem=save_mem,
)
Expand All @@ -985,7 +982,6 @@ def __init__(
use_convtranspose=use_convtranspose,
num_splits=num_splits,
dim_split=dim_split,
norm_float16=norm_float16,
print_info=print_info,
save_mem=save_mem,
)
26 changes: 18 additions & 8 deletions tests/test_autoencoderkl_maisi.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,28 +75,38 @@
else:
CASES = CASES_NO_ATTENTION

test_dtypes = [torch.float32]
if device.type == "cuda":
test_dtypes.append(torch.bfloat16)
test_dtypes.append(torch.float16)

DTYPE_CASES = []
for dtype in test_dtypes:
for case in CASES:
DTYPE_CASES.append(case + [dtype])


class TestAutoencoderKlMaisi(unittest.TestCase):

@parameterized.expand(CASES)
def test_shape(self, input_param, input_shape, expected_shape, expected_latent_shape):
net = AutoencoderKlMaisi(**input_param).to(device)
@parameterized.expand(DTYPE_CASES)
def test_shape(self, input_param, input_shape, expected_shape, expected_latent_shape, dtype):
net = AutoencoderKlMaisi(**input_param).to(device=device, dtype=dtype)
with eval_mode(net):
result = net.forward(torch.randn(input_shape).to(device))
result = net.forward(torch.randn(input_shape).to(device=device, dtype=dtype))
self.assertEqual(result[0].shape, expected_shape)
self.assertEqual(result[1].shape, expected_latent_shape)
self.assertEqual(result[2].shape, expected_latent_shape)

@parameterized.expand(CASES)
@parameterized.expand(DTYPE_CASES)
@SkipIfBeforePyTorchVersion((1, 11))
def test_shape_with_convtranspose_and_checkpointing(
self, input_param, input_shape, expected_shape, expected_latent_shape
self, input_param, input_shape, expected_shape, expected_latent_shape, dtype
):
input_param = input_param.copy()
input_param.update({"use_checkpointing": True, "use_convtranspose": True})
net = AutoencoderKlMaisi(**input_param).to(device)
net = AutoencoderKlMaisi(**input_param).to(device=device, dtype=dtype)
with eval_mode(net):
result = net.forward(torch.randn(input_shape).to(device))
result = net.forward(torch.randn(input_shape).to(device=device, dtype=dtype))
self.assertEqual(result[0].shape, expected_shape)
self.assertEqual(result[1].shape, expected_latent_shape)
self.assertEqual(result[2].shape, expected_latent_shape)
Expand Down
Loading