Skip to content

Commit

Permalink
MaisiVAE: Auto-cast GroupNorm, deprecate norm_float16
Browse files Browse the repository at this point in the history
Signed-off-by: John Zielke <[email protected]>
  • Loading branch information
johnzielke committed Feb 4, 2025
1 parent 8dcb9dc commit 81865e4
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 30 deletions.
36 changes: 14 additions & 22 deletions monai/apps/generation/maisi/networks/autoencoderkl_maisi.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from monai.networks.blocks.spatialattention import SpatialAttentionBlock
from monai.networks.nets.autoencoderkl import AEKLResBlock, AutoencoderKL
from monai.utils.type_conversion import convert_to_tensor
from monai.utils.deprecate_utils import deprecated_arg

# Set up logging configuration
logger = logging.getLogger(__name__)
Expand All @@ -33,7 +34,7 @@ def _empty_cuda_cache(save_mem: bool) -> None:
torch.cuda.empty_cache()
return


@deprecated_arg("norm_float16", since="1.5.0", removed="1.7.0")
class MaisiGroupNorm3D(nn.GroupNorm):
"""
Custom 3D Group Normalization with optional print_info output.
Expand All @@ -43,7 +44,7 @@ class MaisiGroupNorm3D(nn.GroupNorm):
num_channels: Number of channels for the group norm.
eps: Epsilon value for numerical stability.
affine: Whether to use learnable affine parameters, default to `True`.
norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`.
norm_float16: Deprecated argument.
print_info: Whether to print information, default to `False`.
save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
"""
Expand All @@ -59,14 +60,15 @@ def __init__(
save_mem: bool = True,
):
super().__init__(num_groups, num_channels, eps, affine)
self.norm_float16 = norm_float16
self.print_info = print_info
self.save_mem = save_mem

def forward(self, input: torch.Tensor) -> torch.Tensor:
if self.print_info:
logger.info(f"MaisiGroupNorm3D with input size: {input.size()}")

target_dtype = input.dtype

if len(input.shape) != 5:
raise ValueError("Expected a 5D tensor")

Expand All @@ -75,13 +77,10 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:

inputs = []
for i in range(input.size(1)):
array = input[:, i : i + 1, ...].to(dtype=torch.float32)
array = input[:, i : i + 1, ...]
mean = array.mean([2, 3, 4, 5], keepdim=True)
std = array.var([2, 3, 4, 5], unbiased=False, keepdim=True).add_(self.eps).sqrt_()
if self.norm_float16:
inputs.append(((array - mean) / std).to(dtype=torch.float16))
else:
inputs.append((array - mean) / std)
inputs.append(((array - mean) / std).to(dtype=target_dtype))

del input
_empty_cuda_cache(self.save_mem)
Expand Down Expand Up @@ -375,7 +374,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.conv(x)
return x


@deprecated_arg("norm_float16", since="1.5.0", removed="1.7.0")
class MaisiResBlock(nn.Module):
"""
Residual block consisting of a cascade of 2 convolutions + activation + normalisation block, and a
Expand Down Expand Up @@ -417,7 +416,6 @@ def __init__(
num_channels=in_channels,
eps=norm_eps,
affine=True,
norm_float16=norm_float16,
print_info=print_info,
save_mem=save_mem,
)
Expand All @@ -439,7 +437,6 @@ def __init__(
num_channels=out_channels,
eps=norm_eps,
affine=True,
norm_float16=norm_float16,
print_info=print_info,
save_mem=save_mem,
)
Expand Down Expand Up @@ -500,7 +497,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
out_tensor: torch.Tensor = convert_to_tensor(out)
return out_tensor


@deprecated_arg("norm_float16", since="1.5.0", removed="1.7.0")
class MaisiEncoder(nn.Module):
"""
Convolutional cascade that downsamples the image into a spatial latent space.
Expand All @@ -520,7 +517,7 @@ class MaisiEncoder(nn.Module):
use_flash_attention: If True, use flash attention for a memory efficient attention mechanism.
num_splits: Number of splits for the input tensor.
dim_split: Dimension of splitting for the input tensor.
norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`.
norm_float16: Deprecated argument.
print_info: Whether to print information, default to `False`.
save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
"""
Expand Down Expand Up @@ -591,7 +588,6 @@ def __init__(
out_channels=output_channel,
num_splits=num_splits,
dim_split=dim_split,
norm_float16=norm_float16,
print_info=print_info,
save_mem=save_mem,
)
Expand Down Expand Up @@ -660,7 +656,6 @@ def __init__(
num_channels=num_channels[-1],
eps=norm_eps,
affine=True,
norm_float16=norm_float16,
print_info=print_info,
save_mem=save_mem,
)
Expand Down Expand Up @@ -689,7 +684,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
_empty_cuda_cache(self.save_mem)
return x


@deprecated_arg("norm_float16", since="1.5.0", removed="1.7.0")
class MaisiDecoder(nn.Module):
"""
Convolutional cascade upsampling from a spatial latent space into an image space.
Expand All @@ -710,7 +705,7 @@ class MaisiDecoder(nn.Module):
use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder.
num_splits: Number of splits for the input tensor.
dim_split: Dimension of splitting for the input tensor.
norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`.
norm_float16: Deprecated argument.
print_info: Whether to print information, default to `False`.
save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
"""
Expand Down Expand Up @@ -809,7 +804,6 @@ def __init__(
out_channels=block_out_ch,
num_splits=num_splits,
dim_split=dim_split,
norm_float16=norm_float16,
print_info=print_info,
save_mem=save_mem,
)
Expand Down Expand Up @@ -848,7 +842,6 @@ def __init__(
num_channels=block_in_ch,
eps=norm_eps,
affine=True,
norm_float16=norm_float16,
print_info=print_info,
save_mem=save_mem,
)
Expand Down Expand Up @@ -878,6 +871,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return x


@deprecated_arg("norm_float16", since="1.5.0", removed="1.7.0")
class AutoencoderKlMaisi(AutoencoderKL):
"""
AutoencoderKL with custom MaisiEncoder and MaisiDecoder.
Expand All @@ -901,7 +895,7 @@ class AutoencoderKlMaisi(AutoencoderKL):
use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder.
num_splits: Number of splits for the input tensor.
dim_split: Dimension of splitting for the input tensor.
norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`.
norm_float16: Deprecated argument.
print_info: Whether to print information, default to `False`.
save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
"""
Expand Down Expand Up @@ -964,7 +958,6 @@ def __init__(
use_flash_attention=use_flash_attention,
num_splits=num_splits,
dim_split=dim_split,
norm_float16=norm_float16,
print_info=print_info,
save_mem=save_mem,
)
Expand All @@ -985,7 +978,6 @@ def __init__(
use_convtranspose=use_convtranspose,
num_splits=num_splits,
dim_split=dim_split,
norm_float16=norm_float16,
print_info=print_info,
save_mem=save_mem,
)
25 changes: 17 additions & 8 deletions tests/test_autoencoderkl_maisi.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,28 +75,37 @@
else:
CASES = CASES_NO_ATTENTION

# Precisions exercised by the parameterized tests; bfloat16 is only included
# when the tests run on a CUDA device.
test_dtypes = [torch.float32, torch.float16]
if device.type == "cuda":
    test_dtypes.append(torch.bfloat16)

# One entry per (dtype, case) pair, dtype-major order: each case list is
# extended with the dtype as its final element so it can be passed as the
# trailing `dtype` argument of the parameterized test methods.
DTYPE_CASES = [case + [dtype] for dtype in test_dtypes for case in CASES]


class TestAutoencoderKlMaisi(unittest.TestCase):

@parameterized.expand(CASES)
def test_shape(self, input_param, input_shape, expected_shape, expected_latent_shape):
net = AutoencoderKlMaisi(**input_param).to(device)
@parameterized.expand(DTYPE_CASES)
def test_shape(self, input_param, input_shape, expected_shape, expected_latent_shape, dtype):
net = AutoencoderKlMaisi(**input_param).to(device=device,dtype=dtype)
with eval_mode(net):
result = net.forward(torch.randn(input_shape).to(device))
result = net.forward(torch.randn(input_shape).to(device=device,dtype=dtype))
self.assertEqual(result[0].shape, expected_shape)
self.assertEqual(result[1].shape, expected_latent_shape)
self.assertEqual(result[2].shape, expected_latent_shape)

@parameterized.expand(CASES)
@parameterized.expand(DTYPE_CASES)
@SkipIfBeforePyTorchVersion((1, 11))
def test_shape_with_convtranspose_and_checkpointing(
self, input_param, input_shape, expected_shape, expected_latent_shape
self, input_param, input_shape, expected_shape, expected_latent_shape, dtype
):
input_param = input_param.copy()
input_param.update({"use_checkpointing": True, "use_convtranspose": True})
net = AutoencoderKlMaisi(**input_param).to(device)
net = AutoencoderKlMaisi(**input_param).to(device=device,dtype=dtype)
with eval_mode(net):
result = net.forward(torch.randn(input_shape).to(device))
result = net.forward(torch.randn(input_shape).to(device=device,dtype=dtype))
self.assertEqual(result[0].shape, expected_shape)
self.assertEqual(result[1].shape, expected_latent_shape)
self.assertEqual(result[2].shape, expected_latent_shape)
Expand Down

0 comments on commit 81865e4

Please sign in to comment.