Skip to content

Commit

Permalink
fix: ➖ remove boundary loss (#88)
Browse files Browse the repository at this point in the history
  • Loading branch information
jgrss authored Aug 15, 2024
1 parent e25241a commit 8576928
Show file tree
Hide file tree
Showing 4 changed files with 0 additions and 117 deletions.
1 change: 0 additions & 1 deletion src/cultionet/enums/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ class Destinations(StrEnum):


class LossTypes(StrEnum):
BOUNDARY = "BoundaryLoss"
CLASS_BALANCED_MSE = "ClassBalancedMSELoss"
TANIMOTO_COMPLEMENT = "TanimotoComplementLoss"
TANIMOTO = "TanimotoDistLoss"
Expand Down
1 change: 0 additions & 1 deletion src/cultionet/losses/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from .losses import (
BoundaryLoss,
ClassBalancedMSELoss,
CombinedLoss,
LossPreprocessing,
Expand Down
109 changes: 0 additions & 109 deletions src/cultionet/losses/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,6 @@
import torch.nn as nn
import torch.nn.functional as F

try:
from kornia.contrib import distance_transform
except ImportError:
distance_transform = None

try:
import torch_topological.nn as topnn
except ImportError:
topnn = None


class LossPreprocessing(nn.Module):
def __init__(
Expand Down Expand Up @@ -381,105 +371,6 @@ def forward(
)


class BoundaryLoss(nn.Module):
"""Boundary (surface) loss.
Reference:
https://github.com/LIVIAETS/boundary-loss
"""

def __init__(self):
super().__init__()

assert distance_transform is not None

def fill_distances(
self,
distances: torch.Tensor,
targets: torch.Tensor,
):
dt = distance_transform(
F.pad(
(targets == 2).long().unsqueeze(1).float(),
pad=(
21,
21,
21,
21,
),
),
kernel_size=21,
h=0.1,
).squeeze(dim=1)[:, 21:-21, 21:-21]
dt /= dt.max()

idist = torch.where(
targets == 2, 0, torch.where(targets == 1, distances, 0)
)
idist = torch.where(targets > 0, idist, dt)

return idist

def forward(
self,
probs: torch.Tensor,
distances: torch.Tensor,
targets: torch.Tensor,
) -> torch.Tensor:
"""Performs a single forward pass.
Args:
probs: Predicted probabilities, shaped (B x H x W).
distances: Ground truth distance transform, shaped (B x H x W).
targets: Ground truth labels, shaped (B x H x W).
Returns:
Loss (float)
"""
distances = self.fill_distances(distances, targets)

return torch.einsum("bhw, bhw -> bhw", distances, 1.0 - probs).mean()


class TopologyLoss(nn.Module):
def __init__(self):
super().__init__()

if topnn is not None:
self.loss_func = topnn.SummaryStatisticLoss(
"total_persistence", p=2
)
self.cubical = topnn.CubicalComplex(dim=3)

def forward(
self,
inputs: torch.Tensor,
targets: torch.Tensor,
mask: T.Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Performs a single forward pass.
Args:
inputs: Predictions (probabilities) from model.
targets: Ground truth values.
"""
if mask is None:
targets = targets * mask
inputs = inputs * mask

persistence_information_target = self.cubical(targets)
persistence_information_target = [persistence_information_target[0]]

persistence_information = self.cubical(inputs)
persistence_information = [persistence_information[0]]

loss = self.loss_func(
persistence_information, persistence_information_target
)

return loss


class ClassBalancedMSELoss(nn.Module):
r"""
References:
Expand Down
6 changes: 0 additions & 6 deletions src/cultionet/models/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -850,9 +850,6 @@ def __init__(
self.edge_class = num_classes

self.loss_dict = {
LossTypes.BOUNDARY: {
"classification": cnetlosses.BoundaryLoss(),
},
LossTypes.CLASS_BALANCED_MSE: {
"classification": cnetlosses.ClassBalancedMSELoss(),
},
Expand Down Expand Up @@ -1042,9 +1039,6 @@ def __init__(
self.edge_class = num_classes

self.loss_dict = {
LossTypes.BOUNDARY: {
"classification": cnetlosses.BoundaryLoss(),
},
LossTypes.CLASS_BALANCED_MSE: {
"classification": cnetlosses.ClassBalancedMSELoss(),
},
Expand Down

0 comments on commit 8576928

Please sign in to comment.