-
Notifications
You must be signed in to change notification settings - Fork 20
/
Copy pathmulti_similarity_loss.py
executable file
·82 lines (63 loc) · 2.89 KB
/
multi_similarity_loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: [email protected]
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.
import torch
from torch import nn
# from ret_benchmark.losses.registry import LOSS
# @LOSS.register('ms_loss')
class MultiSimilarityLoss(nn.Module):
    """Multi-similarity loss computed over all pairs inside a batch.

    For every anchor (row of the batch) the cosine similarities to the other
    samples are split into same-label (positive) and different-label
    (negative) pairs. Only the "informative" pairs are kept:

    * negatives whose similarity plus ``margin`` exceeds the hardest
      (lowest-similarity) positive of that anchor, and
    * positives whose similarity minus ``margin`` is below the hardest
      (highest-similarity) negative of that anchor.

    The kept pairs are then combined with a soft (log-sum-exp style) weighting
    around ``thresh``, and the per-anchor terms are summed and divided by the
    batch size.

    Args:
        margin: Slack used when selecting informative pairs.
        thresh: Similarity offset subtracted inside both exponentials.
        scale_pos: Scale applied to the positive-pair term.
        scale_neg: Scale applied to the negative-pair term.
    """

    def __init__(self, margin=0.1, thresh=0.5, scale_pos=2.0, scale_neg=40.0):
        super().__init__()
        self.thresh = thresh
        self.margin = margin
        self.scale_pos = scale_pos
        self.scale_neg = scale_neg

    def forward(self, feats, labels):
        """Return the scalar loss for a batch.

        Args:
            feats: Tensor whose first dimension is the batch; rows are
                L2-normalised along ``dim=1`` before similarities are taken.
            labels: Tensor with one label per row of ``feats`` (its first
                dimension must equal ``feats.size(0)``).

        Returns:
            A 0-dim tensor. When no anchor yields both an informative positive
            and an informative negative, a fresh zero tensor with
            ``requires_grad=True`` is returned so ``backward()`` is still
            callable.

        Raises:
            AssertionError: If ``feats`` and ``labels`` disagree on batch size.
        """
        assert feats.size(0) == labels.size(0), \
            f"feats.size(0): {feats.size(0)} is not equal to labels.size(0): {labels.size(0)}"
        batch_size = feats.size(0)

        feats = nn.functional.normalize(feats, p=2, dim=1)
        # Pairwise cosine similarities, shape (batch_size, batch_size).
        sim_mat = torch.matmul(feats, feats.t())
        # Positives with similarity >= 1 - epsilon are dropped; this removes
        # the anchor's similarity to itself (and exact duplicates).
        epsilon = 1e-5

        # Broadcasted label comparison: entry [i, j] tells whether samples i
        # and j share a label. Both masks are built explicitly so no boolean
        # negation is needed inside the loop.
        col_labels = labels.reshape(batch_size, 1)
        row_labels = labels.reshape(1, batch_size)
        same_label = col_labels == row_labels
        diff_label = col_labels != row_labels

        losses = []
        for i in range(batch_size):
            sims = sim_mat[i]
            pos_all = sims[same_label[i]]
            pos_all = pos_all[pos_all < 1 - epsilon]
            neg_all = sims[diff_label[i]]

            # Without at least one positive and one negative there is nothing
            # to mine; min()/max() on an empty tensor would raise.
            if pos_all.numel() == 0 or neg_all.numel() == 0:
                continue

            # Keep only pairs that violate the margin w.r.t. the hardest pair
            # of the opposite kind.
            neg_pair = neg_all[neg_all + self.margin > pos_all.min()]
            pos_pair = pos_all[pos_all - self.margin < neg_all.max()]
            if neg_pair.numel() == 0 or pos_pair.numel() == 0:
                continue

            # Weighting step.
            pos_loss = 1.0 / self.scale_pos * torch.log(
                1 + torch.sum(torch.exp(-self.scale_pos * (pos_pair - self.thresh))))
            neg_loss = 1.0 / self.scale_neg * torch.log(
                1 + torch.sum(torch.exp(self.scale_neg * (neg_pair - self.thresh))))
            losses.append(pos_loss + neg_loss)

        if not losses:
            return torch.zeros([], requires_grad=True, device=feats.device)

        # Skipped anchors still count in the denominator, as before.
        return torch.stack(losses).sum() / batch_size