Skip to content

Commit

Permalink
MaisiVAE: Auto-cast GroupNorm, deprecate norm_float16
Browse files Browse the repository at this point in the history
Signed-off-by: John Zielke <[email protected]>
  • Loading branch information
johnzielke committed Feb 4, 2025
1 parent 8dcb9dc commit 8888a48
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 26 deletions.
32 changes: 14 additions & 18 deletions monai/apps/generation/maisi/networks/autoencoderkl_maisi.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from monai.networks.blocks import Convolution
from monai.networks.blocks.spatialattention import SpatialAttentionBlock
from monai.networks.nets.autoencoderkl import AEKLResBlock, AutoencoderKL
from monai.utils.deprecate_utils import deprecated_arg
from monai.utils.type_conversion import convert_to_tensor

# Set up logging configuration
Expand All @@ -34,6 +35,7 @@ def _empty_cuda_cache(save_mem: bool) -> None:
return


@deprecated_arg("norm_float16", since="1.5.0", removed="1.7.0")
class MaisiGroupNorm3D(nn.GroupNorm):
"""
Custom 3D Group Normalization with optional print_info output.
Expand All @@ -43,7 +45,7 @@ class MaisiGroupNorm3D(nn.GroupNorm):
num_channels: Number of channels for the group norm.
eps: Epsilon value for numerical stability.
affine: Whether to use learnable affine parameters, default to `True`.
norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`.
norm_float16: Deprecated argument.
print_info: Whether to print information, default to `False`.
save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
"""
Expand All @@ -59,14 +61,15 @@ def __init__(
save_mem: bool = True,
):
super().__init__(num_groups, num_channels, eps, affine)
self.norm_float16 = norm_float16
self.print_info = print_info
self.save_mem = save_mem

def forward(self, input: torch.Tensor) -> torch.Tensor:
if self.print_info:
logger.info(f"MaisiGroupNorm3D with input size: {input.size()}")

target_dtype = input.dtype

if len(input.shape) != 5:
raise ValueError("Expected a 5D tensor")

Expand All @@ -75,13 +78,10 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:

inputs = []
for i in range(input.size(1)):
array = input[:, i : i + 1, ...].to(dtype=torch.float32)
array = input[:, i : i + 1, ...]
mean = array.mean([2, 3, 4, 5], keepdim=True)
std = array.var([2, 3, 4, 5], unbiased=False, keepdim=True).add_(self.eps).sqrt_()
if self.norm_float16:
inputs.append(((array - mean) / std).to(dtype=torch.float16))
else:
inputs.append((array - mean) / std)
inputs.append(((array - mean) / std).to(dtype=target_dtype))

del input
_empty_cuda_cache(self.save_mem)
Expand Down Expand Up @@ -376,6 +376,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return x


@deprecated_arg("norm_float16", since="1.5.0", removed="1.7.0")
class MaisiResBlock(nn.Module):
"""
Residual block consisting of a cascade of 2 convolutions + activation + normalisation block, and a
Expand Down Expand Up @@ -417,7 +418,6 @@ def __init__(
num_channels=in_channels,
eps=norm_eps,
affine=True,
norm_float16=norm_float16,
print_info=print_info,
save_mem=save_mem,
)
Expand All @@ -439,7 +439,6 @@ def __init__(
num_channels=out_channels,
eps=norm_eps,
affine=True,
norm_float16=norm_float16,
print_info=print_info,
save_mem=save_mem,
)
Expand Down Expand Up @@ -501,6 +500,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return out_tensor


@deprecated_arg("norm_float16", since="1.5.0", removed="1.7.0")
class MaisiEncoder(nn.Module):
"""
Convolutional cascade that downsamples the image into a spatial latent space.
Expand All @@ -520,7 +520,7 @@ class MaisiEncoder(nn.Module):
use_flash_attention: If True, use flash attention for a memory efficient attention mechanism.
num_splits: Number of splits for the input tensor.
dim_split: Dimension of splitting for the input tensor.
norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`.
norm_float16: Deprecated argument.
print_info: Whether to print information, default to `False`.
save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
"""
Expand Down Expand Up @@ -591,7 +591,6 @@ def __init__(
out_channels=output_channel,
num_splits=num_splits,
dim_split=dim_split,
norm_float16=norm_float16,
print_info=print_info,
save_mem=save_mem,
)
Expand Down Expand Up @@ -660,7 +659,6 @@ def __init__(
num_channels=num_channels[-1],
eps=norm_eps,
affine=True,
norm_float16=norm_float16,
print_info=print_info,
save_mem=save_mem,
)
Expand Down Expand Up @@ -690,6 +688,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return x


@deprecated_arg("norm_float16", since="1.5.0", removed="1.7.0")
class MaisiDecoder(nn.Module):
"""
Convolutional cascade upsampling from a spatial latent space into an image space.
Expand All @@ -710,7 +709,7 @@ class MaisiDecoder(nn.Module):
use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder.
num_splits: Number of splits for the input tensor.
dim_split: Dimension of splitting for the input tensor.
norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`.
norm_float16: Deprecated argument.
print_info: Whether to print information, default to `False`.
save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
"""
Expand Down Expand Up @@ -809,7 +808,6 @@ def __init__(
out_channels=block_out_ch,
num_splits=num_splits,
dim_split=dim_split,
norm_float16=norm_float16,
print_info=print_info,
save_mem=save_mem,
)
Expand Down Expand Up @@ -848,7 +846,6 @@ def __init__(
num_channels=block_in_ch,
eps=norm_eps,
affine=True,
norm_float16=norm_float16,
print_info=print_info,
save_mem=save_mem,
)
Expand Down Expand Up @@ -878,6 +875,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return x


@deprecated_arg("norm_float16", since="1.5.0", removed="1.7.0")
class AutoencoderKlMaisi(AutoencoderKL):
"""
AutoencoderKL with custom MaisiEncoder and MaisiDecoder.
Expand All @@ -901,7 +899,7 @@ class AutoencoderKlMaisi(AutoencoderKL):
use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder.
num_splits: Number of splits for the input tensor.
dim_split: Dimension of splitting for the input tensor.
norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`.
norm_float16: Deprecated argument.
print_info: Whether to print information, default to `False`.
save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
"""
Expand Down Expand Up @@ -964,7 +962,6 @@ def __init__(
use_flash_attention=use_flash_attention,
num_splits=num_splits,
dim_split=dim_split,
norm_float16=norm_float16,
print_info=print_info,
save_mem=save_mem,
)
Expand All @@ -985,7 +982,6 @@ def __init__(
use_convtranspose=use_convtranspose,
num_splits=num_splits,
dim_split=dim_split,
norm_float16=norm_float16,
print_info=print_info,
save_mem=save_mem,
)
26 changes: 18 additions & 8 deletions tests/test_autoencoderkl_maisi.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,28 +75,38 @@
else:
CASES = CASES_NO_ATTENTION

# Dtypes exercised by the parameterized tests below; float32 is always included.
test_dtypes = [torch.float32]
# Reduced-precision dtypes are added when the test device is CUDA.
# NOTE(review): indentation is lost in this rendering, so it is unclear whether the
# float16 append is also guarded by the CUDA check — confirm against the real file.
if device.type == "cuda":
test_dtypes.append(torch.bfloat16)
test_dtypes.append(torch.float16)

# One entry per (dtype, case) pair: each element of CASES with the dtype appended
# as a trailing element, which the tests receive as their final `dtype` argument.
DTYPE_CASES = []
for dtype in test_dtypes:
for case in CASES:
DTYPE_CASES.append(case + [dtype])


class TestAutoencoderKlMaisi(unittest.TestCase):

@parameterized.expand(CASES)
def test_shape(self, input_param, input_shape, expected_shape, expected_latent_shape):
net = AutoencoderKlMaisi(**input_param).to(device)
@parameterized.expand(DTYPE_CASES)
def test_shape(self, input_param, input_shape, expected_shape, expected_latent_shape, dtype):
net = AutoencoderKlMaisi(**input_param).to(device=device, dtype=dtype)
with eval_mode(net):
result = net.forward(torch.randn(input_shape).to(device))
result = net.forward(torch.randn(input_shape).to(device=device, dtype=dtype))
self.assertEqual(result[0].shape, expected_shape)
self.assertEqual(result[1].shape, expected_latent_shape)
self.assertEqual(result[2].shape, expected_latent_shape)

@parameterized.expand(CASES)
@parameterized.expand(DTYPE_CASES)
@SkipIfBeforePyTorchVersion((1, 11))
def test_shape_with_convtranspose_and_checkpointing(
self, input_param, input_shape, expected_shape, expected_latent_shape
self, input_param, input_shape, expected_shape, expected_latent_shape, dtype
):
input_param = input_param.copy()
input_param.update({"use_checkpointing": True, "use_convtranspose": True})
net = AutoencoderKlMaisi(**input_param).to(device)
net = AutoencoderKlMaisi(**input_param).to(device=device, dtype=dtype)
with eval_mode(net):
result = net.forward(torch.randn(input_shape).to(device))
result = net.forward(torch.randn(input_shape).to(device=device, dtype=dtype))
self.assertEqual(result[0].shape, expected_shape)
self.assertEqual(result[1].shape, expected_latent_shape)
self.assertEqual(result[2].shape, expected_latent_shape)
Expand Down

0 comments on commit 8888a48

Please sign in to comment.