-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathloss.py
108 lines (88 loc) · 3.9 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import torch
import torch.nn as nn
from utils import iou
class YOLOLoss(nn.Module):
    """Sum-of-squares detection loss in the style of YOLO v1.

    The last dimension of both predictions and labels is laid out as::

        [0:C]                     class scores
        [C + 5*b]                 confidence of box slot b
        [C + 5*b + 1 : C + 5*b + 5]   four box coordinates of box slot b

    Only box slot 0 of ``labels`` is read as ground truth.  The last two
    box coordinates are square-rooted before the regression loss, so they
    are presumably width and height -- confirm against the dataset code.

    Args:
        S: Grid size; the image is divided into ``S x S`` cells.
        B: Number of predicted boxes per grid cell (must be >= 1).
        C: Number of classes.
        lambda_coord: Weight of the box-coordinate term.
        lambda_noobj: Weight of the no-object confidence term.
    """

    def __init__(self, S=7, B=2, C=20, lambda_coord=5, lambda_noobj=0.5):
        super().__init__()
        self.mse = nn.MSELoss(reduction='sum')
        self.S = S
        self.B = B
        self.C = C
        self.lambda_coord = lambda_coord
        self.lambda_noobj = lambda_noobj

    def forward(self, predictions, labels):
        """Calculate the total loss, summed over the batch.

        Args:
            predictions: Network output with ``S*S*(B*5+C)`` values per
                sample; it is reshaped to ``(batch_size, S, S, B*5+C)``.
            labels: Ground truth, ``(batch_size, S, S, B*5+C)``.

        Returns:
            Scalar tensor:
            ``lambda_coord * box + object + lambda_noobj * no_object + class``.
        """
        S, B, C = self.S, self.B, self.C
        # predictions shape: (batch_size, S, S, B*5+C)
        predictions = predictions.reshape(-1, S, S, B * 5 + C)

        # exists_box shape: (batch_size, S, S, 1); the label confidence slot
        # doubles as the "an object lives in this cell" mask.
        exists_box = labels[..., C:C + 1]
        label_box = labels[..., C + 1:C + 5]

        # Split out every predicted box slot instead of hard-coding two.
        pred_confs = []
        pred_boxes = []
        ious = []
        for b in range(B):
            start = C + 5 * b
            box = predictions[..., start + 1:start + 5]
            pred_confs.append(predictions[..., start:start + 1])
            pred_boxes.append(box)
            # iou() comes from utils; assumed to return (batch_size, S, S, 1)
            # as the original comments state -- TODO confirm.
            ious.append(iou(box, label_box))

        # ious shape: (B, batch_size, S, S, 1)
        ious = torch.stack(ious, dim=0)
        # For every grid cell pick the box slot with the highest IoU.
        # iou_max / best_box_index shape: (batch_size, S, S, 1)
        iou_max, best_box_index = torch.max(ious, dim=0)
        # The IoU is only used as a regression *target*; it must not
        # back-propagate into the predicted coordinates.
        iou_max = iou_max.detach()

        # Gather the responsible box / confidence of each cell.
        index = best_box_index.unsqueeze(0)  # (1, batch_size, S, S, 1)
        best_box = torch.gather(
            torch.stack(pred_boxes, dim=0), 0,
            index.expand(-1, -1, -1, -1, 4)
        ).squeeze(0)
        best_conf = torch.gather(
            torch.stack(pred_confs, dim=0), 0, index
        ).squeeze(0)

        # ------------------------------
        # coordinate regression loss
        # ------------------------------
        # Both tensors have shape (batch_size, S, S, 4) and are zeroed in
        # cells without an object.
        box_predictions = exists_box * best_box
        box_labels = exists_box * label_box
        # Square-root the last two coordinates.  Predictions may be negative,
        # so take sqrt of the magnitude and restore the sign; the epsilon
        # keeps the sqrt gradient finite at zero.
        pred_wh = box_predictions[..., 2:4]
        pred_wh = torch.sign(pred_wh) * torch.sqrt(torch.abs(pred_wh + 1e-6))
        label_wh = torch.sqrt(box_labels[..., 2:4])
        # Re-assemble out of place rather than writing into a slice of a
        # tensor that is part of the autograd graph.
        box_predictions = torch.cat([box_predictions[..., 0:2], pred_wh], dim=-1)
        box_labels = torch.cat([box_labels[..., 0:2], label_wh], dim=-1)
        box_loss = self.mse(box_predictions, box_labels)

        # -------------------------------------
        # object confidence regression loss
        # -------------------------------------
        # Target confidence of the responsible box is label confidence * IoU.
        object_loss = self.mse(
            exists_box * best_conf,
            exists_box * labels[..., C:C + 1] * iou_max
        )

        # -------------------------------------
        # no-object confidence regression loss
        # -------------------------------------
        # Every box slot of an empty cell is penalised.
        no_obj_mask = 1 - exists_box
        no_obj_target = no_obj_mask * labels[..., C:C + 1]
        no_obj_loss = sum(
            self.mse(no_obj_mask * conf, no_obj_target) for conf in pred_confs
        )

        # -------------------------------------
        # class probability regression loss
        # -------------------------------------
        class_loss = self.mse(
            exists_box * predictions[..., :C],
            exists_box * labels[..., :C]
        )

        loss = (
            self.lambda_coord * box_loss
            + object_loss
            + self.lambda_noobj * no_obj_loss
            + class_loss
        )
        return loss