From 338b1b97aeb50eee6223ff49b4565bad7c7fb7a6 Mon Sep 17 00:00:00 2001 From: lisb20 <2692465509@qq.com> Date: Thu, 27 Jul 2023 19:27:38 +0800 Subject: [PATCH] detr --- .../data/__pycache__/dataset.cpython-39.pyc | Bin 3261 -> 3197 bytes NLOS_detr/data/dataset.py | 4 +- NLOS_detr/models.py | 42 ++++++++++++------ .../utils/__pycache__/trainer.cpython-39.pyc | Bin 15641 -> 15705 bytes NLOS_detr/utils/trainer.py | 13 +++--- 5 files changed, 38 insertions(+), 21 deletions(-) diff --git a/NLOS_detr/data/__pycache__/dataset.cpython-39.pyc b/NLOS_detr/data/__pycache__/dataset.cpython-39.pyc index 02ebd99c379b494f930fa2f5c1c402ca89d72fcb..0ce2a7cb3784d67f34190bcce54bf8c6451d8291 100644 GIT binary patch delta 317 zcmdlh`B#EBk(ZZ?0SE-1A58hYk@pxQzbw9dx0rKM^NN&# zhHEkwX>N{SJ<7yrFjL-MY79p*QqV25iqQB5D2)fs^czJlw`*O}-ayV@^P18VpUT)r;-)Hj!;{XDDU!qaR zpk@!XwLb1k7}40{*b~C2r;tx-^-&KsEWWkcRFHYLz2XJl{QIa`WPZVV}vg<1^eM z1)$;A$g;?;2-${mbP9HiZ+mxCqHdSk3JHl_O~~j2nUk%NePJ9xCE{l4!jXt5kv+k2 jk`P?v!c_0P<6n`iFg75FR8_Lb8D^tNtt!kqzgqeSGlF78 diff --git a/NLOS_detr/data/dataset.py b/NLOS_detr/data/dataset.py index 83bcc8a..426e827 100644 --- a/NLOS_detr/data/dataset.py +++ b/NLOS_detr/data/dataset.py @@ -37,8 +37,8 @@ def __init__(self, tmp_dir = [os.path.join(str(peo), d) for d in os.listdir(os.path.join(self.dataset_dir, str(peo)))] self.dirs += tmp_dir[:600] - print('dirs: ', len(self.dirs)) - pdb.set_trace() + # print('dirs: ', len(self.dirs)) + # pdb.set_trace() if use_fileclient: self.npy_loader = npy_loader() diff --git a/NLOS_detr/models.py b/NLOS_detr/models.py index ab80acb..38806eb 100644 --- a/NLOS_detr/models.py +++ b/NLOS_detr/models.py @@ -184,8 +184,15 @@ def __init__(self, pretrained: bool = True, for m in init_modules: m.apply(self._init_weights) - def forward(self, Ix: tuple): - I, x_gt = Ix # (B, C, T, H, W) + def forward(self, Ix: tuple,): + I, x_gt,freeze = Ix # (B, C, T, H, W) + if freeze: + for para in self.transformer.parameters(): + para.requires_grad = False + else: + for para in self.transformer.parameters(): + para.requires_grad = True + delta_I = torch.sub(I[:, :, 1:], I[:, :, :-1]).float() B, T = x_gt.shape[:2] ## (b,T,6) if self.warmup_frames > 0: @@ -232,9 +239,9 @@ def forward(self, Ix: tuple): cost_matrix = torch.zeros((npeo,npeo), device=x_gt.device, requires_grad=False) pred_matched = torch.zeros((B,T,npeo * 2), device=x_gt.device, requires_grad=False) - x_loss_tot = 0 - v_loss_tot = 0 - m_loss_tot = 0 + x_loss_tot = torch.tensor(0.0,device=x_gt.device) + v_loss_tot = torch.tensor(0.0,device=x_gt.device) + m_loss_tot = torch.tensor(0.0,device=x_gt.device) class_gt = torch.zeros((B, npeo, self.max_peo + 1), device=x_gt.device) ## class: idx==0 -> empty 只分为有无人 valid_peo = [] @@ -256,12 +263,15 @@ def forward(self, Ix: tuple): pred_peo.append(npeo - torch.sum(p==0)) for i in range(npeo): for j in range(npeo): - x_loss = self.criterion(x_pred[b,:,2*i:2*i+2], x_gt[b,self.warmup_frames:,2*j:2*j+2]) - v_loss = self.compute_v_loss(x_pred[b,:,2*i:2*i+2], x_gt[b,self.warmup_frames:,2*j:2*j+2]) + if not freeze: + x_loss = self.criterion(x_pred[b,:,2*i:2*i+2], x_gt[b,self.warmup_frames:,2*j:2*j+2]) + v_loss = self.compute_v_loss(x_pred[b,:,2*i:2*i+2], x_gt[b,self.warmup_frames:,2*j:2*j+2]) # m_loss = self.criterion(class_pred[b,i,:], class_gt[b,j,:]) - m_loss = self.cross_entropy(class_pred[b,i,:], class_gt[b,j,:]) - cost_matrix[i,j] = x_loss * self.x_alpha + v_loss * self.v_alpha + m_loss * self.m_alpha - + m_loss = self.cross_entropy(class_pred[b,i,:], class_gt[b,j,:]) + cost_matrix[i,j] = x_loss * self.x_alpha + v_loss * self.v_alpha + m_loss * self.m_alpha + else: + m_loss = self.cross_entropy(class_pred[b,i,:], class_gt[b,j,:]) + cost_matrix[i,j] = m_loss * self.m_alpha match_res = linear_sum_assignment(cost_matrix.cpu().detach().numpy()) match.append(match_res[1]) @@ -269,10 +279,14 @@ def forward(self, Ix: tuple): for b in range(B): m = match[b] for i in range(npeo): - m_loss_tot += self.cross_entropy(class_pred[b,i,:], class_gt[b,m[i],:]) - pred_matched[b,:,2*i:2*i+2] = x_pred[b,:,2*m[i]:2*m[i]+2] - x_loss_tot += self.criterion(x_pred[b,:,2*i:2*i+2], x_gt[b,self.warmup_frames:,2*m[i]:2*m[i]+2]) - v_loss_tot += self.compute_v_loss(x_pred[b,:,2*i:2*i+2], x_gt[b,self.warmup_frames:,2*m[i]:2*m[i]+2]) + if not freeze: + m_loss_tot += self.cross_entropy(class_pred[b,i,:], class_gt[b,m[i],:]) + pred_matched[b,:,2*i:2*i+2] = x_pred[b,:,2*m[i]:2*m[i]+2] + x_loss_tot += self.criterion(x_pred[b,:,2*i:2*i+2], x_gt[b,self.warmup_frames:,2*m[i]:2*m[i]+2]) + v_loss_tot += self.compute_v_loss(x_pred[b,:,2*i:2*i+2], x_gt[b,self.warmup_frames:,2*m[i]:2*m[i]+2]) + else: + m_loss_tot += self.cross_entropy(class_pred[b,i,:], class_gt[b,m[i],:]) + pred_matched[b,:,2*i:2*i+2] = x_pred[b,:,2*m[i]:2*m[i]+2] cnt += 1 if valid_peo[b] == pred_peo[b] else 0 return (x_loss_tot,v_loss_tot,m_loss_tot), (pred_matched, cnt/B) diff --git a/NLOS_detr/utils/__pycache__/trainer.cpython-39.pyc b/NLOS_detr/utils/__pycache__/trainer.cpython-39.pyc index 36466ef06e46fea5caf2746555463c42d39ccb71..89e6bb66a09e9f119348c5f0b1082c3258574552 100644 GIT binary patch delta 1744 zcmZ9MTWlLe6ozMZoweg@e2cw~?{S>Cxz!|1OcR<)sRAk?;?hb*N@**N9MuIQ7K@5# zwh^?Ud5ID> zy!)RSUzq)3)_ua|vJ?2bNIzhwkG|#pD6MD~@p;;$Q{r&@(%42NaF#3`p@dUrJw=wO zv5%7~;V$lGYEiAxX=Pe9eWu?GOe<#4Y%Rxc*KI{rjP-W(8@maWpGSe*$4SZ0nspK& z+%L5qb%^kQj2z%P4H$=Nl+tx?YakMG9fmzoGyuzov+g6HE&m1rB}sZ_gIsi zCY61n(tU+4h)B;F`ii*FGv9uOk_A!#P5Wpi18Q7TZj+42m9qd&<=*#?o9S#@H%D?aYv&Z5F>ta(u}Jr#c`AdjD&+gHe{3UQEo zn1j0@QkZj5IZ0uAxfbEqc9}la#|Z2S_f$?XKE?Ol*ur%=Ex9eOX{W`&@X=f=zPFs| zaa{EnLeB5bz;ISZNW3?Em{Q>>gfwZVSM(Gfpt_hX#MOc_w6;=skkY)^y!&0cU%b5c j&2ZUPc86^q<1*ZyqZ7HjaR?jT5T{0>G`Y4g@*@2YB2}c! delta 1684 zcmY*ZYiOHQ7(VAqnrm}u?q8crljhR2N!zSl8?kLe z*;e<%V&#A^N+IbS4sYA&bmBl{2#UhgU$Tim#QqVTI#IfMo9}xPU7L?Q-}Aob{ygXX zuFRg8wNKe>76QKyPJGDDPM))W98<39#8xESv@T2->FXfbg;?hdu8 zBt)t%ImW;mkz0;|NJJ9_81rc2*qn{o7hne^wGpZWrt+31Yv+b0>)<9}L1+zeGq5nQ2sa*R zQs1h(PT{!LT6-OjO-e!dr&*^;e1s{}igvWpQ|hfHAe&*)x>5~;Y*`-0QAdtK*1)JL zdoJs=cg1=#3sQfZ91-CZr!R_?)N9ryN{Ds|T$xRaNX92FrmXAb?oEZB6~*2==%{$T z_Y_?a`!mndHE?M$I!noMQt~hZ=t{K%6>r0@<2(FP_ip#A9N5-+AsdDHu4H$QfPXN< zZ7E&b|UQi(`F@k6S=_vsB&+Z4q=7G`I0qKr2@O z?Vx`rbAYZ+(7(ajK>r4Ff&QJ$4Xqx~zrnoR*knF&s6Wy>4)RB({4f)&5@Y^Rf^$h8 z01mTBu8K4L9dqp}nI<6j+q^?&+o6ZtrHnd8Nu{e4uO%Q#vLIL%1Z`40D2;}dWv>v7 zft^Vn1`d0ZJObR(R?=N9SRCt<YG< zBC+cfb%@V)?cMhS_CAR4BZ3!U0>OZA4dEw*FA;u5_ys|aa2??zggnBpVrnp$-$Z%? z0Y6UK7Q*ice<1vc;79liVIRWZ2>*z81}ErEu{oH_e~Jk+7LoB%?}Yq^zh8y}$s=pY%q=TqRR^(`JD7#rAX1ohNqK_8)oERV`dy|+ zdKdx6ZTF-V?J^$!i~rzz>Ckf!UNy{2iOR@i+KVzs9}i-$#}K4vw@eCP&L_4;9;8Iv zoA(=}x=!(A{%-0M@8`q1VP$Cj>->I7`$T5X1$vv<*mEW@Z7SOXCWp2L!_#ykozot` RiLQw^3PBoK|D^CT{THd(p{D=< diff --git a/NLOS_detr/utils/trainer.py b/NLOS_detr/utils/trainer.py index 994175f..4704349 100644 --- a/NLOS_detr/utils/trainer.py +++ b/NLOS_detr/utils/trainer.py @@ -353,7 +353,10 @@ def train(self, epoch): if self.train_cfgs['amp']: with autocast(): - (x_loss, v_loss, m_loss), (preds,acc) = self.model((inputs, labels)) + if epoch < 15: + (x_loss, v_loss, m_loss), (preds,acc) = self.model((inputs, labels, True)) + else: + (x_loss, v_loss, m_loss), (preds,acc) = self.model((inputs, labels, False)) loss = x_loss * self.train_cfgs['x_loss_alpha'] + \ v_loss * self.train_cfgs['v_loss_alpha'] + \ @@ -460,9 +463,9 @@ def val(self, epoch): with torch.no_grad(): if self.train_cfgs['amp']: with autocast(): - (x_loss, v_loss, m_loss), (preds,acc) = self.model((inputs, labels)) + (x_loss, v_loss, m_loss), (preds,acc) = self.model((inputs, labels, False)) else: - (x_loss, v_loss, m_loss), preds = self.model((inputs, labels)) + (x_loss, v_loss, m_loss), preds = self.model((inputs, labels, False)) x_loss = x_loss.detach().clone() v_loss = v_loss.detach().clone() @@ -542,9 +545,9 @@ def test_plot(self, epoch, phase: str): with torch.no_grad(): if self.train_cfgs['amp']: with autocast(): - loss, (pred_routes,_) = self.model((frames, gt_routes)) + loss, (pred_routes,_) = self.model((frames, gt_routes, False)) else: - loss, (pred_routes,_) = self.model((frames, gt_routes)) + loss, (pred_routes,_) = self.model((frames, gt_routes, False)) for idx, (gt, pred, map_size) in enumerate(zip( gt_routes.cpu().numpy(), pred_routes.cpu().numpy(), map_sizes.cpu().numpy())):