-
Notifications
You must be signed in to change notification settings - Fork 20
/
Copy pathranked_loss.py
91 lines (74 loc) · 2.69 KB
/
ranked_loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
# encoding: utf-8
"""
@author: zzg
@contact: [email protected]
"""
import torch
from torch import nn
def normalize_rank(x, axis=-1):
    """Scale `x` so that slices along `axis` have (approximately) unit L2 norm.

    Args:
      x: torch tensor
      axis: dimension along which the L2 norm is taken (default: last)
    Returns:
      torch tensor with the same shape as `x`
    """
    # keepdim=True leaves a size-1 dimension in place, so the division
    # broadcasts across `axis` without an explicit expand.
    length = torch.norm(x, p=2, dim=axis, keepdim=True)
    # The small constant keeps all-zero slices from dividing by zero.
    return x / (length + 1e-12)
def euclidean_dist_rank(x, y):
    """Pairwise Euclidean distance between the rows of `x` and the rows of `y`.

    Uses the expansion ||a - b||^2 = ||a||^2 + ||b||^2 - 2 * a.b so the whole
    distance matrix comes from a single matrix product.

    Args:
      x: torch tensor, with shape [m, d]
      y: torch tensor, with shape [n, d]
    Returns:
      dist: torch tensor, with shape [m, n]; entries are floored at 1e-6
        (the square root of the 1e-12 clamp), never exactly zero
    """
    m, n = x.size(0), y.size(0)
    xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n)
    yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t()
    # beta/alpha are passed by keyword: the old positional form
    # addmm_(beta, alpha, mat1, mat2) is no longer accepted by PyTorch.
    dist = torch.addmm(xx + yy, x, y.t(), beta=1, alpha=-2)
    # Clamp before sqrt: rounding can push the squared distance slightly
    # below zero, and sqrt has an infinite gradient at exactly zero.
    return dist.clamp(min=1e-12).sqrt()
def rank_loss(dist_mat, labels, margin, alpha, tval):
    """Ranked-list style loss averaged over every sample used as an anchor.

    For each anchor row of `dist_mat`:
      * positives (same label, anchor excluded) contribute the hinge
        max(0, d + margin - alpha), averaged over the number of positives;
      * negatives (different label) closer than `alpha` contribute
        (alpha - d), combined as a weighted average with weights
        exp(tval * (alpha - d)).

    Args:
      dist_mat: torch tensor, pair wise distance between samples, shape [N, N]
      labels: torch LongTensor, with shape [N]
      margin: scalar; positives are penalised beyond distance alpha - margin
      alpha: scalar; negatives are penalised below this distance
      tval: scalar temperature of the exponential negative weighting
    Returns:
      the mean over the N anchors of (positive term + negative term)
    """
    assert dist_mat.dim() == 2
    rows, cols = dist_mat.size()
    assert rows == cols

    eps = 1e-5  # guards both averages against empty positive/negative sets
    running = 0.0
    for anchor in range(rows):
        row = dist_mat[anchor]
        anchor_label = labels[anchor]

        pos_mask = labels.eq(anchor_label)
        pos_mask[anchor] = 0  # the anchor is not its own positive
        neg_mask = labels.ne(anchor_label)

        # Positive term: hinge on distances exceeding (alpha - margin).
        pos_hinge = (row[pos_mask] + (margin - alpha)).clamp(min=0.0)
        pos_term = pos_hinge.sum() / float(pos_hinge.size(0) + eps)

        # Negative term: only negatives that violate the alpha boundary.
        neg_dist = row[neg_mask]
        violating = neg_dist[neg_dist < alpha]
        gap = alpha - violating
        weights = torch.exp(tval * gap)
        neg_term = (gap * weights).sum() / (weights.sum() + eps)

        running = running + pos_term + neg_term

    return running * 1.0 / rows
class RankedLoss(object):
    """Callable wrapper around `rank_loss`.

    Named after "Ranked List Loss for Deep Metric Learning" (CVPR 2019).
    Holds the `margin`, `alpha` and `tval` hyper-parameters and forwards
    them to `rank_loss` on every call.
    """

    def __init__(self, margin=None, alpha=None, tval=None):
        # Stored as given; they are only consumed inside __call__.
        self.margin, self.alpha, self.tval = margin, alpha, tval

    def __call__(self, global_feat, labels, normalize_feature=True):
        """Compute the loss for a batch of features.

        Args:
          global_feat: torch tensor of features, one row per sample
          labels: torch tensor of labels, one per sample
          normalize_feature: when true, features are L2-normalised along
            the last dimension before distances are computed
        Returns:
          the value returned by `rank_loss` for the pairwise distance matrix
        """
        feats = (
            normalize_rank(global_feat, axis=-1)
            if normalize_feature
            else global_feat
        )
        pairwise = euclidean_dist_rank(feats, feats)
        return rank_loss(pairwise, labels, self.margin, self.alpha, self.tval)