diff --git a/NLOS_detr/data/__pycache__/dataset.cpython-39.pyc b/NLOS_detr/data/__pycache__/dataset.cpython-39.pyc index 02ebd99..0ce2a7c 100644 Binary files a/NLOS_detr/data/__pycache__/dataset.cpython-39.pyc and b/NLOS_detr/data/__pycache__/dataset.cpython-39.pyc differ diff --git a/NLOS_detr/data/dataset.py b/NLOS_detr/data/dataset.py index 83bcc8a..426e827 100644 --- a/NLOS_detr/data/dataset.py +++ b/NLOS_detr/data/dataset.py @@ -37,8 +37,8 @@ def __init__(self, tmp_dir = [os.path.join(str(peo), d) for d in os.listdir(os.path.join(self.dataset_dir, str(peo)))] self.dirs += tmp_dir[:600] - print('dirs: ', len(self.dirs)) - pdb.set_trace() + # print('dirs: ', len(self.dirs)) + # pdb.set_trace() if use_fileclient: self.npy_loader = npy_loader() diff --git a/NLOS_detr/models.py b/NLOS_detr/models.py index ab80acb..38806eb 100644 --- a/NLOS_detr/models.py +++ b/NLOS_detr/models.py @@ -184,8 +184,15 @@ def __init__(self, pretrained: bool = True, for m in init_modules: m.apply(self._init_weights) - def forward(self, Ix: tuple): - I, x_gt = Ix # (B, C, T, H, W) + def forward(self, Ix: tuple,): + I, x_gt,freeze = Ix # (B, C, T, H, W) + if freeze: + for para in self.transformer.parameters(): + para.requires_grad = False + else: + for para in self.transformer.parameters(): + para.requires_grad = True + delta_I = torch.sub(I[:, :, 1:], I[:, :, :-1]).float() B, T = x_gt.shape[:2] ## (b,T,6) if self.warmup_frames > 0: @@ -232,9 +239,9 @@ def forward(self, Ix: tuple): cost_matrix = torch.zeros((npeo,npeo), device=x_gt.device, requires_grad=False) pred_matched = torch.zeros((B,T,npeo * 2), device=x_gt.device, requires_grad=False) - x_loss_tot = 0 - v_loss_tot = 0 - m_loss_tot = 0 + x_loss_tot = torch.tensor(0.0,device=x_gt.device) + v_loss_tot = torch.tensor(0.0,device=x_gt.device) + m_loss_tot = torch.tensor(0.0,device=x_gt.device) class_gt = torch.zeros((B, npeo, self.max_peo + 1), device=x_gt.device) ## class: idx==0 -> empty 只分为有无人 valid_peo = [] @@ -256,12 +263,15 @@ def forward(self, Ix: tuple): pred_peo.append(npeo - torch.sum(p==0)) for i in range(npeo): for j in range(npeo): - x_loss = self.criterion(x_pred[b,:,2*i:2*i+2], x_gt[b,self.warmup_frames:,2*j:2*j+2]) - v_loss = self.compute_v_loss(x_pred[b,:,2*i:2*i+2], x_gt[b,self.warmup_frames:,2*j:2*j+2]) + if not freeze: + x_loss = self.criterion(x_pred[b,:,2*i:2*i+2], x_gt[b,self.warmup_frames:,2*j:2*j+2]) + v_loss = self.compute_v_loss(x_pred[b,:,2*i:2*i+2], x_gt[b,self.warmup_frames:,2*j:2*j+2]) # m_loss = self.criterion(class_pred[b,i,:], class_gt[b,j,:]) - m_loss = self.cross_entropy(class_pred[b,i,:], class_gt[b,j,:]) - cost_matrix[i,j] = x_loss * self.x_alpha + v_loss * self.v_alpha + m_loss * self.m_alpha - + m_loss = self.cross_entropy(class_pred[b,i,:], class_gt[b,j,:]) + cost_matrix[i,j] = x_loss * self.x_alpha + v_loss * self.v_alpha + m_loss * self.m_alpha + else: + m_loss = self.cross_entropy(class_pred[b,i,:], class_gt[b,j,:]) + cost_matrix[i,j] = m_loss * self.m_alpha match_res = linear_sum_assignment(cost_matrix.cpu().detach().numpy()) match.append(match_res[1]) @@ -269,10 +279,14 @@ def forward(self, Ix: tuple): for b in range(B): m = match[b] for i in range(npeo): - m_loss_tot += self.cross_entropy(class_pred[b,i,:], class_gt[b,m[i],:]) - pred_matched[b,:,2*i:2*i+2] = x_pred[b,:,2*m[i]:2*m[i]+2] - x_loss_tot += self.criterion(x_pred[b,:,2*i:2*i+2], x_gt[b,self.warmup_frames:,2*m[i]:2*m[i]+2]) - v_loss_tot += self.compute_v_loss(x_pred[b,:,2*i:2*i+2], x_gt[b,self.warmup_frames:,2*m[i]:2*m[i]+2]) + if not freeze: + m_loss_tot += self.cross_entropy(class_pred[b,i,:], class_gt[b,m[i],:]) + pred_matched[b,:,2*i:2*i+2] = x_pred[b,:,2*m[i]:2*m[i]+2] + x_loss_tot += self.criterion(x_pred[b,:,2*i:2*i+2], x_gt[b,self.warmup_frames:,2*m[i]:2*m[i]+2]) + v_loss_tot += self.compute_v_loss(x_pred[b,:,2*i:2*i+2], x_gt[b,self.warmup_frames:,2*m[i]:2*m[i]+2]) + else: + m_loss_tot += self.cross_entropy(class_pred[b,i,:], class_gt[b,m[i],:]) + pred_matched[b,:,2*i:2*i+2] = x_pred[b,:,2*m[i]:2*m[i]+2] cnt += 1 if valid_peo[b] == pred_peo[b] else 0 return (x_loss_tot,v_loss_tot,m_loss_tot), (pred_matched, cnt/B) diff --git a/NLOS_detr/utils/__pycache__/trainer.cpython-39.pyc b/NLOS_detr/utils/__pycache__/trainer.cpython-39.pyc index 36466ef..89e6bb6 100644 Binary files a/NLOS_detr/utils/__pycache__/trainer.cpython-39.pyc and b/NLOS_detr/utils/__pycache__/trainer.cpython-39.pyc differ diff --git a/NLOS_detr/utils/trainer.py b/NLOS_detr/utils/trainer.py index 994175f..4704349 100644 --- a/NLOS_detr/utils/trainer.py +++ b/NLOS_detr/utils/trainer.py @@ -353,7 +353,10 @@ def train(self, epoch): if self.train_cfgs['amp']: with autocast(): - (x_loss, v_loss, m_loss), (preds,acc) = self.model((inputs, labels)) + if epoch < 15: + (x_loss, v_loss, m_loss), (preds,acc) = self.model((inputs, labels, True)) + else: + (x_loss, v_loss, m_loss), (preds,acc) = self.model((inputs, labels, False)) loss = x_loss * self.train_cfgs['x_loss_alpha'] + \ v_loss * self.train_cfgs['v_loss_alpha'] + \ @@ -460,9 +463,9 @@ def val(self, epoch): with torch.no_grad(): if self.train_cfgs['amp']: with autocast(): - (x_loss, v_loss, m_loss), (preds,acc) = self.model((inputs, labels)) + (x_loss, v_loss, m_loss), (preds,acc) = self.model((inputs, labels, False)) else: - (x_loss, v_loss, m_loss), preds = self.model((inputs, labels)) + (x_loss, v_loss, m_loss), preds = self.model((inputs, labels, False)) x_loss = x_loss.detach().clone() v_loss = v_loss.detach().clone() @@ -542,9 +545,9 @@ def test_plot(self, epoch, phase: str): with torch.no_grad(): if self.train_cfgs['amp']: with autocast(): - loss, (pred_routes,_) = self.model((frames, gt_routes)) + loss, (pred_routes,_) = self.model((frames, gt_routes, False)) else: - loss, (pred_routes,_) = self.model((frames, gt_routes)) + loss, (pred_routes,_) = self.model((frames, gt_routes, False)) for idx, (gt, pred, map_size) in enumerate(zip( gt_routes.cpu().numpy(), pred_routes.cpu().numpy(), map_sizes.cpu().numpy())):