From 09a8d10c9e2b210800abd768d7c9877dfddd3744 Mon Sep 17 00:00:00 2001
From: Fabian Herzog
Date: Tue, 18 Jun 2024 21:12:27 +0200
Subject: [PATCH] Add optional wandb logging and update paths in option.py
 and lmbn_config.yaml

---
 engine_v3.py     | 103 ++++++++++++++++++++++-------------
 lmbn_config.yaml |  12 +++--
 loss/__init__.py | 138 ++++++++++++++++++++++++-----------------------
 option.py        |   4 ++
 4 files changed, 147 insertions(+), 110 deletions(-)

diff --git a/engine_v3.py b/engine_v3.py
index 74d9ff91e..0f1412c1c 100644
--- a/engine_v3.py
+++ b/engine_v3.py
@@ -3,8 +3,13 @@
 from utils.functions import evaluation
 from utils.re_ranking import re_ranking, re_ranking_gpu
 
+try:
+    import wandb
+except ImportError:
+    wandb = None
 
-class Engine():
+
+class Engine:
     def __init__(self, args, model, optimizer, scheduler, loss, loader, ckpt):
         self.args = args
         self.train_loader = loader.train_loader
@@ -19,23 +24,27 @@ def __init__(self, args, model, optimizer, scheduler, loss, loader, ckpt):
         self.scheduler = scheduler
         self.loss = loss
 
-        self.lr = 0.
-        self.device = torch.device('cpu' if args.cpu else 'cuda')
+        self.lr = 0.0
+        self.device = torch.device("cpu" if args.cpu else "cuda")
 
         if torch.cuda.is_available():
-            self.ckpt.write_log('[INFO] GPU: ' + torch.cuda.get_device_name(0))
+            self.ckpt.write_log("[INFO] GPU: " + torch.cuda.get_device_name(0))
 
         self.ckpt.write_log(
-            '[INFO] Starting from epoch {}'.format(self.scheduler.last_epoch + 1))
+            "[INFO] Starting from epoch {}".format(self.scheduler.last_epoch + 1)
+        )
 
-    def train(self):
+        if args.wandb and wandb is not None:
+            wandb.init(project=args.wandb_name)
 
+    def train(self):
         epoch = self.scheduler.last_epoch
         lr = self.scheduler.get_last_lr()[0]
         if lr != self.lr:
             self.ckpt.write_log(
-                '[INFO] Epoch: {}\tLearning rate: {:.2e} '.format(epoch + 1, lr))
+                "[INFO] Epoch: {}\tLearning rate: {:.2e} ".format(epoch + 1, lr)
+            )
             self.lr = lr
         self.loss.start_log()
         self.model.train()
@@ -53,11 +62,19 @@ def train(self):
             loss.backward()
             self.optimizer.step()
 
-            self.ckpt.write_log('\r[INFO] [{}/{}]\t{}/{}\t{}'.format(
-                epoch + 1, self.args.epochs,
-                batch + 1, len(self.train_loader),
-                self.loss.display_loss(batch)),
-                end='' if batch + 1 != len(self.train_loader) else '\n')
+            self.ckpt.write_log(
+                "\r[INFO] [{}/{}]\t{}/{}\t{}".format(
+                    epoch + 1,
+                    self.args.epochs,
+                    batch + 1,
+                    len(self.train_loader),
+                    self.loss.display_loss(batch),
+                ),
+                end="" if batch + 1 != len(self.train_loader) else "\n",
+            )
+
+            if self.args.wandb and wandb is not None:
+                wandb.log(self.loss.get_loss_dict(batch))
 
         self.scheduler.step()
         self.loss.end_log(len(self.train_loader))
@@ -65,17 +82,18 @@
 
     def test(self):
         epoch = self.scheduler.last_epoch
-        self.ckpt.write_log('\n[INFO] Test:')
+        self.ckpt.write_log("\n[INFO] Test:")
         self.model.eval()
 
         self.ckpt.add_log(torch.zeros(1, 6))
 
         with torch.no_grad():
-
             qf, query_ids, query_cams = self.extract_feature(
-                self.query_loader, self.args)
+                self.query_loader, self.args
+            )
             gf, gallery_ids, gallery_cams = self.extract_feature(
-                self.test_loader, self.args)
+                self.test_loader, self.args
+            )
 
             if self.args.re_rank:
                 # q_g_dist = np.dot(qf, np.transpose(gf))
@@ -87,8 +105,7 @@ def test(self):
             # cosine distance
             dist = 1 - torch.mm(qf, gf.t()).cpu().numpy()
 
-            r, m_ap = evaluation(
-                dist, query_ids, gallery_ids, query_cams, gallery_cams, 50)
+            r, m_ap = evaluation(dist, query_ids, gallery_ids, query_cams, gallery_cams, 50)
 
             self.ckpt.log[-1, 0] = epoch
             self.ckpt.log[-1, 1] = m_ap
@@ -99,22 +116,34 @@ def test(self):
             best = self.ckpt.log.max(0)
             self.ckpt.write_log(
-                '[INFO] mAP: {:.4f} rank1: {:.4f} rank3: {:.4f} rank5: {:.4f} rank10: {:.4f} (Best: {:.4f} @epoch {})'.format(
-                    m_ap,
-                    r[0], r[2], r[4], r[9],
-                    best[0][1], self.ckpt.log[best[1][1], 0]
-                ), refresh=True
+                "[INFO] mAP: {:.4f} rank1: {:.4f} rank3: {:.4f} rank5: {:.4f} rank10: {:.4f} (Best: {:.4f} @epoch {})".format(
+                    m_ap, r[0], r[2], r[4], r[9], best[0][1], self.ckpt.log[best[1][1], 0]
+                ),
+                refresh=True,
             )
 
             if not self.args.test_only:
-
-                self._save_checkpoint(epoch, r[0], self.ckpt.dir, is_best=(
-                    self.ckpt.log[best[1][1], 0] == epoch))
+                self._save_checkpoint(
+                    epoch,
+                    r[0],
+                    self.ckpt.dir,
+                    is_best=(self.ckpt.log[best[1][1], 0] == epoch),
+                )
                 self.ckpt.plot_map_rank(epoch)
 
+            if self.args.wandb and wandb is not None:
+                wandb.log(
+                    {
+                        "mAP": m_ap,
+                        "rank1": r[0],
+                        "rank3": r[2],
+                        "rank5": r[4],
+                        "rank10": r[9],
+                    }
+                )
+
     def fliphor(self, inputs):
-        inv_idx = torch.arange(inputs.size(
-            3) - 1, -1, -1).long()  # N x C x H x W
+        inv_idx = torch.arange(inputs.size(3) - 1, -1, -1).long()  # N x C x H x W
         return inputs.index_select(3, inv_idx)
 
     def extract_feature(self, loader, args):
@@ -128,8 +157,7 @@ def extract_feature(self, loader, args):
                 f1 = outputs.data.cpu()
                 # flip
-                inputs = inputs.index_select(
-                    3, torch.arange(inputs.size(3) - 1, -1, -1))
+                inputs = inputs.index_select(3, torch.arange(inputs.size(3) - 1, -1, -1))
                 input_img = inputs.to(self.device)
                 outputs = self.model(input_img)
                 f2 = outputs.data.cpu()
 
@@ -137,7 +165,8 @@
                 ff = f1 + f2
                 if ff.dim() == 3:
                     fnorm = torch.norm(
-                        ff, p=2, dim=1, keepdim=True)  # * np.sqrt(ff.shape[2])
+                        ff, p=2, dim=1, keepdim=True
+                    )  # * np.sqrt(ff.shape[2])
                     ff = ff.div(fnorm.expand_as(ff))
                     ff = ff.view(ff.size(0), -1)
 
@@ -175,13 +204,13 @@ def _parse_data_for_eval(self, data):
     def _save_checkpoint(self, epoch, rank1, save_dir, is_best=False):
         self.ckpt.save_checkpoint(
             {
-                'state_dict': self.model.state_dict(),
-                'epoch': epoch,
-                'rank1': rank1,
-                'optimizer': self.optimizer.state_dict(),
-                'log': self.ckpt.log,
+                "state_dict": self.model.state_dict(),
+                "epoch": epoch,
+                "rank1": rank1,
+                "optimizer": self.optimizer.state_dict(),
+                "log": self.ckpt.log,
                 # 'scheduler': self.scheduler.state_dict(),
             },
             save_dir,
-            is_best=is_best
+            is_best=is_best,
         )
diff --git a/lmbn_config.yaml b/lmbn_config.yaml
index cbbebda9d..a14a728d8 100644
--- a/lmbn_config.yaml
+++ b/lmbn_config.yaml
@@ -7,7 +7,7 @@ batchtest: 32
 beta1: 0.9
 beta2: 0.999
 bnneck: false
-config: ''
+config: ""
 cosine_annealing: false
 cpu: false
 cuhk03_labeled: false
@@ -15,10 +15,10 @@ cutout: false
 dampening: 0
 data_test: market1501
 data_train: market1501
-datadir: /content/ReIDataset/
+datadir: /media/CityFlow/reid/reid/
 decay_type: step_50_80_110
 drop_block: false
-epochs: 140
+epochs: 2
 epsilon: 1.0e-08
 feat_inference: after
 feats: 512
@@ -33,8 +33,8 @@ margin: 0.7
 model: LMBN_n
 momentum: 0.9
 nGPU: 1
-nThread: 4
-nep_id: ''
+nThread: 8
+nep_id: ""
 nesterov: false
 num_anchors: 1
 num_classes: 751
@@ -52,3 +52,5 @@ w_ratio: 1.0
 warmup: constant
 weight_decay: 0.0005
 width: 128
+wandb: false
+wandb_name: ""
diff --git a/loss/__init__.py b/loss/__init__.py
index d5d363a07..0c30be9a7 100644
--- a/loss/__init__.py
+++ b/loss/__init__.py
@@ -3,7 +3,8 @@
 from importlib import import_module
 
 import matplotlib
-matplotlib.use('Agg')
+
+matplotlib.use("Agg")
 import matplotlib.pyplot as plt
 
 import torch
@@ -17,46 +18,49 @@
 from loss.center_loss import CenterLoss
 
 
-class LossFunction():
+class LossFunction:
     def __init__(self, args, ckpt):
         super(LossFunction, self).__init__()
-        ckpt.write_log('[INFO] Making loss...')
+        ckpt.write_log("[INFO] Making loss...")
 
         self.nGPU = args.nGPU
         self.args = args
         self.loss = []
-        for loss in args.loss.split('+'):
-            weight, loss_type = loss.split('*')
-            if loss_type == 'CrossEntropy':
+        for loss in args.loss.split("+"):
+            weight, loss_type = loss.split("*")
+            if loss_type == "CrossEntropy":
                 if args.if_labelsmooth:
                     loss_function = CrossEntropyLabelSmooth(
-                        num_classes=args.num_classes)
-                    ckpt.write_log('[INFO] Label Smoothing On.')
+                        num_classes=args.num_classes
+                    )
+                    ckpt.write_log("[INFO] Label Smoothing On.")
                 else:
                     loss_function = nn.CrossEntropyLoss()
-            elif loss_type == 'Triplet':
+            elif loss_type == "Triplet":
                 loss_function = TripletLoss(args.margin)
-            elif loss_type == 'GroupLoss':
+            elif loss_type == "GroupLoss":
                 loss_function = GroupLoss(
-                    total_classes=args.num_classes, max_iter=args.T, num_anchors=args.num_anchors)
-            elif loss_type == 'MSLoss':
+                    total_classes=args.num_classes,
+                    max_iter=args.T,
+                    num_anchors=args.num_anchors,
+                )
+            elif loss_type == "MSLoss":
                 loss_function = MultiSimilarityLoss(margin=args.margin)
-            elif loss_type == 'Focal':
-                loss_function = FocalLoss(reduction='mean')
-            elif loss_type == 'OSLoss':
+            elif loss_type == "Focal":
+                loss_function = FocalLoss(reduction="mean")
+            elif loss_type == "OSLoss":
                 loss_function = OSM_CAA_Loss()
-            elif loss_type == 'CenterLoss':
+            elif loss_type == "CenterLoss":
                 loss_function = CenterLoss(
-                    num_classes=args.num_classes, feat_dim=args.feats)
+                    num_classes=args.num_classes, feat_dim=args.feats
+                )
 
-            self.loss.append({
-                'type': loss_type,
-                'weight': float(weight),
-                'function': loss_function
-            })
+            self.loss.append(
+                {"type": loss_type, "weight": float(weight), "function": loss_function}
+            )
 
         if len(self.loss) > 1:
-            self.loss.append({'type': 'Total', 'weight': 0, 'function': None})
+            self.loss.append({"type": "Total", "weight": 0, "function": None})
 
         self.log = torch.Tensor()
 
@@ -64,63 +68,57 @@ def compute(self, outputs, labels):
         losses = []
 
         for i, l in enumerate(self.loss):
-            if l['type'] in ['CrossEntropy']:
-
+            if l["type"] in ["CrossEntropy"]:
                 if isinstance(outputs[0], list):
-                    loss = [l['function'](output, labels)
-                            for output in outputs[0]]
+                    loss = [l["function"](output, labels) for output in outputs[0]]
                 elif isinstance(outputs[0], torch.Tensor):
-                    loss = [l['function'](outputs[0], labels)]
+                    loss = [l["function"](outputs[0], labels)]
                 else:
-                    raise TypeError(
-                        'Unexpected type: {}'.format(type(outputs[0])))
+                    raise TypeError("Unexpected type: {}".format(type(outputs[0])))
 
                 loss = sum(loss)
-                effective_loss = l['weight'] * loss
+                effective_loss = l["weight"] * loss
                 losses.append(effective_loss)
                 self.log[-1, i] += effective_loss.item()
-            elif l['type'] in ['Triplet', 'MSLoss']:
+            elif l["type"] in ["Triplet", "MSLoss"]:
                 if isinstance(outputs[-1], list):
-                    loss = [l['function'](output, labels)
-                            for output in outputs[-1]]
+                    loss = [l["function"](output, labels) for output in outputs[-1]]
                 elif isinstance(outputs[-1], torch.Tensor):
-                    loss = [l['function'](outputs[-1], labels)]
+                    loss = [l["function"](outputs[-1], labels)]
                 else:
-                    raise TypeError(
-                        'Unexpected type: {}'.format(type(outputs[-1])))
+                    raise TypeError("Unexpected type: {}".format(type(outputs[-1])))
                 loss = sum(loss)
-                effective_loss = l['weight'] * loss
+                effective_loss = l["weight"] * loss
                 losses.append(effective_loss)
                 self.log[-1, i] += effective_loss.item()
-            elif l['type'] in ['GroupLoss']:
+            elif l["type"] in ["GroupLoss"]:
                 if isinstance(outputs[-1], list):
-                    loss = [l['function'](output[0], labels, output[1])
-                            for output in zip(outputs[-1], outputs[0][:3])]
+                    loss = [
+                        l["function"](output[0], labels, output[1])
+                        for output in zip(outputs[-1], outputs[0][:3])
+                    ]
                 elif isinstance(outputs[-1], torch.Tensor):
-                    loss = [l['function'](outputs[-1], labels)]
+                    loss = [l["function"](outputs[-1], labels)]
                 else:
-                    raise TypeError(
-                        'Unexpected type: {}'.format(type(outputs[-1])))
+                    raise TypeError("Unexpected type: {}".format(type(outputs[-1])))
                 loss = sum(loss)
-                effective_loss = l['weight'] * loss
+                effective_loss = l["weight"] * loss
                 losses.append(effective_loss)
                 self.log[-1, i] += effective_loss.item()
-            elif l['type'] in ['CenterLoss']:
+            elif l["type"] in ["CenterLoss"]:
                 if isinstance(outputs[-1], list):
-                    loss = [l['function'](output, labels)
-                            for output in outputs[-1]]
+                    loss = [l["function"](output, labels) for output in outputs[-1]]
                 elif isinstance(outputs[-1], torch.Tensor):
-                    loss = [l['function'](outputs[-1], labels)]
+                    loss = [l["function"](outputs[-1], labels)]
                 else:
-                    raise TypeError(
-                        'Unexpected type: {}'.format(type(outputs[-1])))
+                    raise TypeError("Unexpected type: {}".format(type(outputs[-1])))
 
                 loss = sum(loss)
-                effective_loss = l['weight'] * loss
+                effective_loss = l["weight"] * loss
                 losses.append(effective_loss)
                 self.log[-1, i] += effective_loss.item()
@@ -144,30 +142,37 @@ def display_loss(self, batch):
         n_samples = batch + 1
         log = []
         for l, c in zip(self.loss, self.log[-1]):
-            log.append('[{}: {:.6f}]'.format(l['type'], c / n_samples))
+            log.append("[{}: {:.6f}]".format(l["type"], c / n_samples))
+
+        return "".join(log)
 
-        return ''.join(log)
+    def get_loss_dict(self, batch):
+        n_samples = batch + 1
+        loss_dict = {}
+        for l, c in zip(self.loss, self.log[-1]):
+            loss_dict[l["type"]] = c.item() / n_samples
+        return loss_dict
 
     def plot_loss(self, apath, epoch):
         axis = np.linspace(1, epoch, epoch)
         for i, l in enumerate(self.loss):
-            label = '{} Loss'.format(l['type'])
+            label = "{} Loss".format(l["type"])
             fig = plt.figure()
             plt.title(label)
             plt.plot(axis, self.log[:, i].numpy(), label=label)
             plt.legend()
-            plt.xlabel('Epochs')
-            plt.ylabel('Loss')
+            plt.xlabel("Epochs")
+            plt.ylabel("Loss")
             plt.grid(True)
-            plt.savefig('{}/loss_{}.jpg'.format(apath, l['type']))
+            plt.savefig("{}/loss_{}.jpg".format(apath, l["type"]))
             plt.close(fig)
 
     # Following codes not being used
 
     def step(self):
         for l in self.get_loss_module():
-            if hasattr(l, 'scheduler'):
+            if hasattr(l, "scheduler"):
                 l.scheduler.step()
 
     def get_loss_module(self):
@@ -177,22 +182,19 @@ def get_loss_module(self):
             return self.loss_module.module
 
     def save(self, apath):
-        torch.save(self.state_dict(), os.path.join(apath, 'loss.pt'))
-        torch.save(self.log, os.path.join(apath, 'loss_log.pt'))
+        torch.save(self.state_dict(), os.path.join(apath, "loss.pt"))
+        torch.save(self.log, os.path.join(apath, "loss_log.pt"))
 
     def load(self, apath, cpu=False):
         if cpu:
-            kwargs = {'map_location': lambda storage, loc: storage}
+            kwargs = {"map_location": lambda storage, loc: storage}
         else:
            kwargs = {}
 
-        self.load_state_dict(torch.load(
-            os.path.join(apath, 'loss.pt'),
-            **kwargs
-        ))
-        self.log = torch.load(os.path.join(apath, 'loss_log.pt'))
+        self.load_state_dict(torch.load(os.path.join(apath, "loss.pt"), **kwargs))
+        self.log = torch.load(os.path.join(apath, "loss_log.pt"))
 
         for l in self.loss_module:
-            if hasattr(l, 'scheduler'):
+            if hasattr(l, "scheduler"):
                 for _ in range(len(self.log)):
                     l.scheduler.step()
diff --git a/option.py b/option.py
index 0485d7e19..f56ba96be 100755
--- a/option.py
+++ b/option.py
@@ -127,6 +127,10 @@
 # parser.add_argument("--resume", action='store_true', help='whether resume training from specific checkpoint')
 # parser.add_argument('--save_models', action='store_true', help='save all intermediate models')
 
+# for wandb
+parser.add_argument('--wandb', action='store_true', help='use wandb')
+parser.add_argument('--wandb_name', type=str, default='', help='wandb project name')
+
 args = parser.parse_args()
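-- 
Below is a minimal, self-contained sketch of the optional-wandb pattern this
patch introduces (guarded import, flag-gated init, dict-based logging). The
flag names mirror the new --wandb/--wandb_name options; the standalone script
shape and the metric values are illustrative placeholders, not code from this
repository.

import argparse

# Guarded import: wandb stays an optional dependency, as in engine_v3.py.
try:
    import wandb
except ImportError:
    wandb = None

parser = argparse.ArgumentParser()
parser.add_argument('--wandb', action='store_true', help='use wandb')
parser.add_argument('--wandb_name', type=str, default='', help='wandb project name')
args = parser.parse_args()

# Initialize only when the flag is set AND the package is importable,
# mirroring the guard in Engine.__init__; calling wandb.log() without a
# prior wandb.init() would raise an error.
use_wandb = args.wandb and wandb is not None
if use_wandb:
    wandb.init(project=args.wandb_name)

# Log a metrics dict per step; in Engine.train()/test() this dict comes from
# loss.get_loss_dict(batch) and from the evaluation results.
if use_wandb:
    wandb.log({"mAP": 0.0, "rank1": 0.0})  # placeholder values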