
Commit

detr
lisb20 committed Jul 27, 2023
1 parent 55bcb24 commit 338b1b9
Showing 5 changed files with 38 additions and 21 deletions.
Binary file modified NLOS_detr/data/__pycache__/dataset.cpython-39.pyc
4 changes: 2 additions & 2 deletions NLOS_detr/data/dataset.py
@@ -37,8 +37,8 @@ def __init__(self,
tmp_dir = [os.path.join(str(peo), d) for d in os.listdir(os.path.join(self.dataset_dir, str(peo)))]
self.dirs += tmp_dir[:600]

print('dirs: ', len(self.dirs))
pdb.set_trace()
# print('dirs: ', len(self.dirs))
# pdb.set_trace()

if use_fileclient:
self.npy_loader = npy_loader()
42 changes: 28 additions & 14 deletions NLOS_detr/models.py
@@ -184,8 +184,15 @@ def __init__(self, pretrained: bool = True,
for m in init_modules:
m.apply(self._init_weights)

def forward(self, Ix: tuple):
I, x_gt = Ix # (B, C, T, H, W)
def forward(self, Ix: tuple):
I, x_gt, freeze = Ix # (B, C, T, H, W)
if freeze:
for para in self.transformer.parameters():
para.requires_grad = False
else:
for para in self.transformer.parameters():
para.requires_grad = True

delta_I = torch.sub(I[:, :, 1:], I[:, :, :-1]).float()
B, T = x_gt.shape[:2] ## (b,T,6)
if self.warmup_frames > 0:
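For reference, a minimal sketch of the gradient gate this hunk adds, assuming a model whose DETR-style decoder is held in `self.transformer` as in the diff; the standalone helper name is illustrative:

import torch.nn as nn

def set_transformer_frozen(model: nn.Module, freeze: bool) -> None:
    # Mirrors the `freeze` flag unpacked from Ix: when True, no gradients
    # flow into the transformer, so only the remaining parameters update.
    for para in model.transformer.parameters():
        para.requires_grad = not freeze

Since `requires_grad` only needs to change when the flag changes, toggling it once per epoch in the trainer would be equivalent to toggling it on every forward call as the diff does.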
@@ -232,9 +239,9 @@ def forward(self, Ix: tuple):
cost_matrix = torch.zeros((npeo,npeo), device=x_gt.device, requires_grad=False)

pred_matched = torch.zeros((B,T,npeo * 2), device=x_gt.device, requires_grad=False)
x_loss_tot = 0
v_loss_tot = 0
m_loss_tot = 0
x_loss_tot = torch.tensor(0.0, device=x_gt.device)
v_loss_tot = torch.tensor(0.0, device=x_gt.device)
m_loss_tot = torch.tensor(0.0, device=x_gt.device)

class_gt = torch.zeros((B, npeo, self.max_peo + 1), device=x_gt.device) ## class: idx==0 -> empty; only distinguishes whether a person is present
valid_peo = []
@@ -256,23 +263,30 @@
pred_peo.append(npeo - torch.sum(p==0))
for i in range(npeo):
for j in range(npeo):
x_loss = self.criterion(x_pred[b,:,2*i:2*i+2], x_gt[b,self.warmup_frames:,2*j:2*j+2])
v_loss = self.compute_v_loss(x_pred[b,:,2*i:2*i+2], x_gt[b,self.warmup_frames:,2*j:2*j+2])
if not freeze:
x_loss = self.criterion(x_pred[b,:,2*i:2*i+2], x_gt[b,self.warmup_frames:,2*j:2*j+2])
v_loss = self.compute_v_loss(x_pred[b,:,2*i:2*i+2], x_gt[b,self.warmup_frames:,2*j:2*j+2])
# m_loss = self.criterion(class_pred[b,i,:], class_gt[b,j,:])
m_loss = self.cross_entropy(class_pred[b,i,:], class_gt[b,j,:])
cost_matrix[i,j] = x_loss * self.x_alpha + v_loss * self.v_alpha + m_loss * self.m_alpha

m_loss = self.cross_entropy(class_pred[b,i,:], class_gt[b,j,:])
cost_matrix[i,j] = x_loss * self.x_alpha + v_loss * self.v_alpha + m_loss * self.m_alpha
else:
m_loss = self.cross_entropy(class_pred[b,i,:], class_gt[b,j,:])
cost_matrix[i,j] = m_loss * self.m_alpha
match_res = linear_sum_assignment(cost_matrix.cpu().detach().numpy())
match.append(match_res[1])

cnt = 0
for b in range(B):
m = match[b]
for i in range(npeo):
m_loss_tot += self.cross_entropy(class_pred[b,i,:], class_gt[b,m[i],:])
pred_matched[b,:,2*i:2*i+2] = x_pred[b,:,2*m[i]:2*m[i]+2]
x_loss_tot += self.criterion(x_pred[b,:,2*i:2*i+2], x_gt[b,self.warmup_frames:,2*m[i]:2*m[i]+2])
v_loss_tot += self.compute_v_loss(x_pred[b,:,2*i:2*i+2], x_gt[b,self.warmup_frames:,2*m[i]:2*m[i]+2])
if not freeze:
m_loss_tot += self.cross_entropy(class_pred[b,i,:], class_gt[b,m[i],:])
pred_matched[b,:,2*i:2*i+2] = x_pred[b,:,2*m[i]:2*m[i]+2]
x_loss_tot += self.criterion(x_pred[b,:,2*i:2*i+2], x_gt[b,self.warmup_frames:,2*m[i]:2*m[i]+2])
v_loss_tot += self.compute_v_loss(x_pred[b,:,2*i:2*i+2], x_gt[b,self.warmup_frames:,2*m[i]:2*m[i]+2])
else:
m_loss_tot += self.cross_entropy(class_pred[b,i,:], class_gt[b,m[i],:])
pred_matched[b,:,2*i:2*i+2] = x_pred[b,:,2*m[i]:2*m[i]+2]
cnt += 1 if valid_peo[b] == pred_peo[b] else 0

return (x_loss_tot,v_loss_tot,m_loss_tot), (pred_matched, cnt/B)
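The matching loop above builds a pairwise cost matrix per sample and solves it with `linear_sum_assignment` (SciPy's Hungarian solver). A self-contained sketch of that step, with a plain L1 distance standing in for `self.criterion`:

import torch
from scipy.optimize import linear_sum_assignment

def match_tracks(x_pred: torch.Tensor, x_gt: torch.Tensor) -> list:
    # x_pred, x_gt: (T, 2 * npeo) per-frame (x, y) coordinates for npeo tracks.
    npeo = x_pred.shape[-1] // 2
    cost = torch.zeros(npeo, npeo, device=x_pred.device)
    for i in range(npeo):
        for j in range(npeo):
            # Stand-in for self.criterion: mean L1 distance over the sequence.
            cost[i, j] = (x_pred[:, 2*i:2*i+2] - x_gt[:, 2*j:2*j+2]).abs().mean()
    row_ind, col_ind = linear_sum_assignment(cost.detach().cpu().numpy())
    # col_ind[i] is the ground-truth track assigned to prediction i, i.e. what
    # the diff keeps via match.append(match_res[1]).
    return col_ind.tolist()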
Binary file modified NLOS_detr/utils/__pycache__/trainer.cpython-39.pyc
13 changes: 8 additions & 5 deletions NLOS_detr/utils/trainer.py
@@ -353,7 +353,10 @@ def train(self, epoch):

if self.train_cfgs['amp']:
with autocast():
(x_loss, v_loss, m_loss), (preds,acc) = self.model((inputs, labels))
if epoch < 15:
(x_loss, v_loss, m_loss), (preds,acc) = self.model((inputs, labels, True))
else:
(x_loss, v_loss, m_loss), (preds,acc) = self.model((inputs, labels, False))

loss = x_loss * self.train_cfgs['x_loss_alpha'] + \
v_loss * self.train_cfgs['v_loss_alpha'] + \
@@ -460,9 +463,9 @@ def val(self, epoch):
with torch.no_grad():
if self.train_cfgs['amp']:
with autocast():
(x_loss, v_loss, m_loss), (preds,acc) = self.model((inputs, labels))
(x_loss, v_loss, m_loss), (preds,acc) = self.model((inputs, labels, False))
else:
(x_loss, v_loss, m_loss), preds = self.model((inputs, labels))
(x_loss, v_loss, m_loss), preds = self.model((inputs, labels, False))

x_loss = x_loss.detach().clone()
v_loss = v_loss.detach().clone()
@@ -542,9 +545,9 @@ def test_plot(self, epoch, phase: str):
with torch.no_grad():
if self.train_cfgs['amp']:
with autocast():
loss, (pred_routes,_) = self.model((frames, gt_routes))
loss, (pred_routes,_) = self.model((frames, gt_routes, False))
else:
loss, (pred_routes,_) = self.model((frames, gt_routes))
loss, (pred_routes,_) = self.model((frames, gt_routes, False))

for idx, (gt, pred, map_size) in enumerate(zip(
gt_routes.cpu().numpy(), pred_routes.cpu().numpy(), map_sizes.cpu().numpy())):
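Taken together, the trainer changes encode a two-phase schedule: while `epoch < 15` the model is called with `freeze=True`, so only the classification cost drives the matching and the loss, and from epoch 15 onward, as well as in `val()` and `test_plot()`, it is called with `freeze=False`. A compact restatement of that call convention, where the helper and the `FREEZE_EPOCHS` constant are illustrative names for the hard-coded threshold:

FREEZE_EPOCHS = 15  # matches the hard-coded `epoch < 15` check in train()

def freeze_flag(epoch: int, training: bool = True) -> bool:
    # The transformer stays frozen only during the early training epochs;
    # validation and plotting always run with freeze=False.
    return training and epoch < FREEZE_EPOCHS

# usage sketch inside the training loop:
#   (x_loss, v_loss, m_loss), (preds, acc) = self.model((inputs, labels, freeze_flag(epoch)))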
