Update wandb configuration and paths in option.py and lmbn_config.yaml
fubel committed Jun 18, 2024
1 parent b7c6831 commit 09a8d10
Showing 4 changed files with 147 additions and 110 deletions.
103 changes: 66 additions & 37 deletions engine_v3.py
@@ -3,8 +3,13 @@
 from utils.functions import evaluation
 from utils.re_ranking import re_ranking, re_ranking_gpu
 
+try:
+    import wandb
+except ImportError:
+    wandb = None
+
 
-class Engine():
+class Engine:
     def __init__(self, args, model, optimizer, scheduler, loss, loader, ckpt):
         self.args = args
         self.train_loader = loader.train_loader
@@ -19,23 +24,27 @@ def __init__(self, args, model, optimizer, scheduler, loss, loader, ckpt):
         self.scheduler = scheduler
         self.loss = loss
 
-        self.lr = 0.
-        self.device = torch.device('cpu' if args.cpu else 'cuda')
+        self.lr = 0.0
+        self.device = torch.device("cpu" if args.cpu else "cuda")
 
         if torch.cuda.is_available():
-            self.ckpt.write_log('[INFO] GPU: ' + torch.cuda.get_device_name(0))
+            self.ckpt.write_log("[INFO] GPU: " + torch.cuda.get_device_name(0))
 
         self.ckpt.write_log(
-            '[INFO] Starting from epoch {}'.format(self.scheduler.last_epoch + 1))
+            "[INFO] Starting from epoch {}".format(self.scheduler.last_epoch + 1)
+        )
 
-    def train(self):
+        if args.wandb and wandb is not None:
+            wandb.init(project=args.wandb_name)
+
+    def train(self):
         epoch = self.scheduler.last_epoch
         lr = self.scheduler.get_last_lr()[0]
 
         if lr != self.lr:
             self.ckpt.write_log(
-                '[INFO] Epoch: {}\tLearning rate: {:.2e} '.format(epoch + 1, lr))
+                "[INFO] Epoch: {}\tLearning rate: {:.2e} ".format(epoch + 1, lr)
+            )
             self.lr = lr
         self.loss.start_log()
         self.model.train()
@@ -53,29 +62,38 @@ def train(self):
             loss.backward()
             self.optimizer.step()
 
-            self.ckpt.write_log('\r[INFO] [{}/{}]\t{}/{}\t{}'.format(
-                epoch + 1, self.args.epochs,
-                batch + 1, len(self.train_loader),
-                self.loss.display_loss(batch)),
-                end='' if batch + 1 != len(self.train_loader) else '\n')
+            self.ckpt.write_log(
+                "\r[INFO] [{}/{}]\t{}/{}\t{}".format(
+                    epoch + 1,
+                    self.args.epochs,
+                    batch + 1,
+                    len(self.train_loader),
+                    self.loss.display_loss(batch),
+                ),
+                end="" if batch + 1 != len(self.train_loader) else "\n",
+            )
+
+            if wandb is not None:
+                wandb.log(self.loss.get_loss_dict(batch))
 
         self.scheduler.step()
         self.loss.end_log(len(self.train_loader))
         # self._save_checkpoint(epoch, 0., self.ckpt.dir, is_best=True)
 
     def test(self):
         epoch = self.scheduler.last_epoch
-        self.ckpt.write_log('\n[INFO] Test:')
+        self.ckpt.write_log("\n[INFO] Test:")
         self.model.eval()
 
         self.ckpt.add_log(torch.zeros(1, 6))
 
         with torch.no_grad():
-
             qf, query_ids, query_cams = self.extract_feature(
-                self.query_loader, self.args)
+                self.query_loader, self.args
+            )
             gf, gallery_ids, gallery_cams = self.extract_feature(
-                self.test_loader, self.args)
+                self.test_loader, self.args
+            )
 
             if self.args.re_rank:
                 # q_g_dist = np.dot(qf, np.transpose(gf))
@@ -87,8 +105,7 @@ def test(self):
                 # cosine distance
                 dist = 1 - torch.mm(qf, gf.t()).cpu().numpy()
 
-            r, m_ap = evaluation(
-                dist, query_ids, gallery_ids, query_cams, gallery_cams, 50)
+            r, m_ap = evaluation(dist, query_ids, gallery_ids, query_cams, gallery_cams, 50)
 
             self.ckpt.log[-1, 0] = epoch
             self.ckpt.log[-1, 1] = m_ap
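
Both branches of the test routine produce a query-by-gallery distance matrix. The cosine shortcut dist = 1 - torch.mm(qf, gf.t()) is valid because extract_feature L2-normalizes every feature vector, so the inner product of unit-norm rows equals cosine similarity. A small self-contained check of that identity (random data and shapes are illustrative only):

import torch

# Illustrative sizes: 4 query and 6 gallery embeddings of dimension 8.
qf = torch.nn.functional.normalize(torch.randn(4, 8), p=2, dim=1)
gf = torch.nn.functional.normalize(torch.randn(6, 8), p=2, dim=1)

# On unit-norm rows the inner product is cosine similarity, so this
# matches the cosine-distance branch in Engine.test().
dist = 1 - torch.mm(qf, gf.t()).numpy()
print(dist.shape)  # (4, 6)
assert (dist >= -1e-6).all() and (dist <= 2 + 1e-6).all()  # cosine distance lies in [0, 2]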
@@ -99,22 +116,34 @@ def test(self):
             best = self.ckpt.log.max(0)
 
             self.ckpt.write_log(
-                '[INFO] mAP: {:.4f} rank1: {:.4f} rank3: {:.4f} rank5: {:.4f} rank10: {:.4f} (Best: {:.4f} @epoch {})'.format(
-                    m_ap,
-                    r[0], r[2], r[4], r[9],
-                    best[0][1], self.ckpt.log[best[1][1], 0]
-                ), refresh=True
+                "[INFO] mAP: {:.4f} rank1: {:.4f} rank3: {:.4f} rank5: {:.4f} rank10: {:.4f} (Best: {:.4f} @epoch {})".format(
+                    m_ap, r[0], r[2], r[4], r[9], best[0][1], self.ckpt.log[best[1][1], 0]
+                ),
+                refresh=True,
             )
 
         if not self.args.test_only:
-
-            self._save_checkpoint(epoch, r[0], self.ckpt.dir, is_best=(
-                self.ckpt.log[best[1][1], 0] == epoch))
+            self._save_checkpoint(
+                epoch,
+                r[0],
+                self.ckpt.dir,
+                is_best=(self.ckpt.log[best[1][1], 0] == epoch),
+            )
             self.ckpt.plot_map_rank(epoch)
 
+        if wandb is not None:
+            wandb.log(
+                {
+                    "mAP": m_ap,
+                    "rank1": r[0],
+                    "rank3": r[2],
+                    "rank5": r[4],
+                    "rank10": r[9],
+                }
+            )
+
     def fliphor(self, inputs):
-        inv_idx = torch.arange(inputs.size(
-            3) - 1, -1, -1).long()  # N x C x H x W
+        inv_idx = torch.arange(inputs.size(3) - 1, -1, -1).long()  # N x C x H x W
         return inputs.index_select(3, inv_idx)
 
     def extract_feature(self, loader, args):
@@ -128,16 +157,16 @@ def extract_feature(self, loader, args):
 
             f1 = outputs.data.cpu()
             # flip
-            inputs = inputs.index_select(
-                3, torch.arange(inputs.size(3) - 1, -1, -1))
+            inputs = inputs.index_select(3, torch.arange(inputs.size(3) - 1, -1, -1))
             input_img = inputs.to(self.device)
             outputs = self.model(input_img)
             f2 = outputs.data.cpu()
 
             ff = f1 + f2
             if ff.dim() == 3:
                 fnorm = torch.norm(
-                    ff, p=2, dim=1, keepdim=True)  # * np.sqrt(ff.shape[2])
+                    ff, p=2, dim=1, keepdim=True
+                )  # * np.sqrt(ff.shape[2])
                 ff = ff.div(fnorm.expand_as(ff))
             ff = ff.view(ff.size(0), -1)

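For context on the hunk above: extract_feature runs each batch twice, once as-is and once horizontally flipped, sums the two embeddings, and L2-normalizes the result. The flip itself is just an index_select over the width axis, as fliphor shows. A minimal sketch of that flip (batch shape is illustrative):

import torch

def flip_horizontal(x):
    # Reverse the width axis of an N x C x H x W batch, as fliphor does.
    inv_idx = torch.arange(x.size(3) - 1, -1, -1).long()
    return x.index_select(3, inv_idx)

batch = torch.randn(2, 3, 4, 4)
flipped = flip_horizontal(batch)
assert torch.equal(flipped[..., 0], batch[..., -1])  # first column is the old last column
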
@@ -175,13 +204,13 @@ def _parse_data_for_eval(self, data):
     def _save_checkpoint(self, epoch, rank1, save_dir, is_best=False):
         self.ckpt.save_checkpoint(
             {
-                'state_dict': self.model.state_dict(),
-                'epoch': epoch,
-                'rank1': rank1,
-                'optimizer': self.optimizer.state_dict(),
-                'log': self.ckpt.log,
+                "state_dict": self.model.state_dict(),
+                "epoch": epoch,
+                "rank1": rank1,
+                "optimizer": self.optimizer.state_dict(),
+                "log": self.ckpt.log,
                 # 'scheduler': self.scheduler.state_dict(),
             },
             save_dir,
-            is_best=is_best
+            is_best=is_best,
         )
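
Taken together, the engine_v3.py changes treat wandb as an optional dependency: the import is wrapped in try/except so the module still loads when the package is absent, wandb.init() is gated on both args.wandb and a successful import, and the wandb.log() calls check only the import. A minimal standalone sketch of the same guard pattern (project name and metric values are illustrative, not from the repository):

try:
    import wandb  # optional experiment tracker
except ImportError:
    wandb = None  # degrade to no-op logging

def log_metrics(metrics, enabled=True):
    # Forward a dict of scalars to wandb only when the package is
    # installed, mirroring the `wandb is not None` checks in Engine.
    if enabled and wandb is not None:
        wandb.log(metrics)

if __name__ == "__main__":
    if wandb is not None:
        wandb.init(project="lmbn-demo")  # hypothetical project name
    log_metrics({"mAP": 0.81, "rank1": 0.93})  # illustrative values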
12 changes: 7 additions & 5 deletions lmbn_config.yaml
@@ -7,18 +7,18 @@ batchtest: 32
 beta1: 0.9
 beta2: 0.999
 bnneck: false
-config: ''
+config: ""
 cosine_annealing: false
 cpu: false
 cuhk03_labeled: false
 cutout: false
 dampening: 0
 data_test: market1501
 data_train: market1501
-datadir: /content/ReIDataset/
+datadir: /media/CityFlow/reid/reid/
 decay_type: step_50_80_110
 drop_block: false
-epochs: 140
+epochs: 2
 epsilon: 1.0e-08
 feat_inference: after
 feats: 512
@@ -33,8 +33,8 @@ margin: 0.7
 model: LMBN_n
 momentum: 0.9
 nGPU: 1
-nThread: 4
-nep_id: ''
+nThread: 8
+nep_id: ""
 nesterov: false
 num_anchors: 1
 num_classes: 751
@@ -52,3 +52,5 @@ w_ratio: 1.0
 warmup: constant
 weight_decay: 0.0005
 width: 128
+wandb: false
+wandb_name: ""
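
The two new keys at the end wire the config into the engine: wandb toggles tracking and wandb_name is the project passed to wandb.init(project=args.wandb_name). The commit message says matching options were added to option.py, but that file's diff did not load on this page, so the following argparse sketch is an assumption about what those definitions might look like (flag names inferred from the YAML keys, everything else hypothetical):

import argparse

parser = argparse.ArgumentParser(description="LMBN training options")
# Hypothetical declarations mirroring the new lmbn_config.yaml keys;
# the real definitions live in option.py, which is not shown above.
parser.add_argument("--wandb", action="store_true",
                    help="enable Weights & Biases logging")
parser.add_argument("--wandb_name", type=str, default="",
                    help="project name forwarded to wandb.init()")

args = parser.parse_args(["--wandb", "--wandb_name", "lmbn-reid"])
print(args.wandb, args.wandb_name)  # -> True lmbn-reid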
