From 173e5654dfe4ed32a577eccf0486e9370e2d816d Mon Sep 17 00:00:00 2001
From: wangji
Date: Mon, 1 May 2023 15:13:43 +0800
Subject: [PATCH] styleEDL

---
 .gitignore                    |   4 +
 README.md                     |  21 +++
 configs/base.yaml             |  65 +++++++
 configs/train.yaml            |  28 +++
 dataset/__init__.py           |   1 +
 dataset/dataset.py            |  48 ++++++
 envs.yaml                     | 176 +++++++++++++++++++
 loss/EMD.py                   |  44 +++++
 loss/__init__.py              |   3 +
 loss/adv_div.py               | 106 ++++++++++++
 loss/share_specific.py        |  37 ++++
 main.py                       |  89 ++++++++++
 network/GMG.py                | 313 +++++++++++++++++++++++++++++
 network/__init__.py           |  13 ++
 scheduler/__init__.py         |  47 +++++
 trainer/__init__.py           |  10 ++
 trainer/basetrainer.py        | 292 +++++++++++++++++++++++++++
 trainer/gmgtrainer.py         | 155 +++++++++++++++++
 transforms/__init__.py        |  30 ++++
 transforms/multiscale_crop.py |  94 ++++++++++
 transforms/normalize.py       |  29 ++++
 utils/__init__.py             |   5 +
 utils/metric.py               | 316 ++++++++++++++++++++++++++++++++++
 utils/opt_yaml.py             |  88 ++++++++++
 utils/options.py              |  41 +++++
 utils/tools_torch.py          |  15 ++
 utils/writer.py               |  62 +++++++
 27 files changed, 2132 insertions(+)
 create mode 100644 .gitignore
 create mode 100644 README.md
 create mode 100644 configs/base.yaml
 create mode 100644 configs/train.yaml
 create mode 100644 dataset/__init__.py
 create mode 100644 dataset/dataset.py
 create mode 100644 envs.yaml
 create mode 100644 loss/EMD.py
 create mode 100644 loss/__init__.py
 create mode 100644 loss/adv_div.py
 create mode 100644 loss/share_specific.py
 create mode 100644 main.py
 create mode 100644 network/GMG.py
 create mode 100644 network/__init__.py
 create mode 100644 scheduler/__init__.py
 create mode 100644 trainer/__init__.py
 create mode 100644 trainer/basetrainer.py
 create mode 100644 trainer/gmgtrainer.py
 create mode 100644 transforms/__init__.py
 create mode 100644 transforms/multiscale_crop.py
 create mode 100644 transforms/normalize.py
 create mode 100644 utils/__init__.py
 create mode 100644 utils/metric.py
 create mode 100644 utils/opt_yaml.py
 create mode 100644 utils/options.py
 create mode 100644 utils/tools_torch.py
 create mode 100644 utils/writer.py

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..36d1541
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,4 @@
+*.pyc
+runs/
+logs/
+.vscode/
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..ded3e93
--- /dev/null
+++ b/README.md
@@ -0,0 +1,21 @@
+
+## Run
+```bash
+python main.py --tag <tag_name>
+```
++ `--tag` sets the prefix of the output directory; defaults to `cache`
+
+
+## Adding a new model
+
+1. Create a new file under the `network` folder
+1. Register the new network in the `models` dict of `network/__init__.py`
+1. Create a matching config file in the `configs` folder
+
+
+## Adding a new trainer
+
+1. Create a new file in the `trainer` folder, subclassing `basetrainer`
+1. Register the new trainer in the `trainers` dict of `trainer/__init__.py` (sketch below)
+1. At runtime, the trainer is selected by the `trainer` field of the config file
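+
+A minimal sketch of the registration step (the `mytrainer` module and `MyTrainer`
+class are illustrative, not part of this patch):
+```python
+# trainer/__init__.py
+from .gmgtrainer import gmgTrainer
+from .mytrainer import MyTrainer  # hypothetical new trainer
+
+trainers = {"gmg": gmgTrainer, "my": MyTrainer}
+```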
diff --git a/configs/base.yaml b/configs/base.yaml
new file mode 100644
index 0000000..07357dc
--- /dev/null
+++ b/configs/base.yaml
@@ -0,0 +1,65 @@
+model: backbone
+trainer: backbone
+
+# ===========================
+# train
+seed: 330
+gpu_id: 0
+epochs: 90
+batch_size: 8
+
+
+image_size: 448
+isNormalize: True
+num_workers: 12
+
+dataset: Twitter_LDL
+num_classes: 8
+data_path: /media/Harddisk_A/emotion_dataset/
+
+
+resume_path:
+save_interval: 2000
+save_mark: 0.4
+display_interval: 20
+
+# ===========================
+# optimizer
+momentum: 0.1
+
+# ===========================
+# learning rate
+lr: 0.01
+
+# ===========================
+# scheduler
+scheduler: scheduler_cos
+
+# stepLR
+scheduler_stepLR:
+    step_size: 15
+    gamma: 0.5
+
+# MultiStepLR
+scheduler_multi:
+    milestones: [10, 20, 50]
+    gamma: 0.5
+
+# ExponentialLR
+scheduler_exp:
+    gamma: 0.5
+
+# CosineAnnealingLR
+scheduler_cos:
+    T_max: 40
+    eta_min: 0
+
+# CyclicLR
+scheduler_cyclic:
+    max_lr: 0.05
+    up: 10
+    down: 10
+
+# lambda
+scheduler_lambda:
+    lr_lambda: null
diff --git a/configs/train.yaml b/configs/train.yaml
new file mode 100644
index 0000000..938bd7b
--- /dev/null
+++ b/configs/train.yaml
@@ -0,0 +1,28 @@
+model: gmg
+trainer: gmg
+seed: 330
+
+
+mu: 0.6
+lambda: 0.6
+parts: 2
+
+
+gpu_id: 1
+batch_size: 16
+
+epochs: 90
+
+lr: 0.005
+scheduler: scheduler_multi
+scheduler_multi:
+    milestones: [10, 30, 50, 70]
+    gamma: 0.1
+
+scheduler_cos:
+    T_max: 40
+    eta_min: 0
+
+momentum: 0.1
+weight_decay: 1e-7
+# resume_path: ./logs/Mar-01_14:45:22_final/epoch_55.pth
diff --git a/dataset/__init__.py b/dataset/__init__.py
new file mode 100644
index 0000000..d9ee325
--- /dev/null
+++ b/dataset/__init__.py
@@ -0,0 +1 @@
+from .dataset import Dataset_LDL
diff --git a/dataset/dataset.py b/dataset/dataset.py
new file mode 100644
index 0000000..65e984e
--- /dev/null
+++ b/dataset/dataset.py
@@ -0,0 +1,48 @@
+# -*- encoding: utf-8 -*-
+
+import os
+import torch
+from PIL import Image, ImageFile
+from torch.utils.data import Dataset
+
+ImageFile.LOAD_TRUNCATED_IMAGES = True
+
+
+class Dataset_LDL(Dataset):
+    def __init__(
+        self,
+        root_path,
+        mode,
+        transforms,
+    ):
+        super().__init__()
+        self.transforms = transforms
+        self.images_path = []
+        self.labels = []
+        self.cls = []
+
+        if mode == "train":
+            file_name = 'ground_truth_train.txt'
+        elif mode == "test":
+            file_name = 'ground_truth_test.txt'
+        else:
+            raise ValueError('mode must be "train" or "test"')
+
+        # read the ground-truth file: <image> <p_1 ... p_C> <dominant class>
+        with open(os.path.join(root_path, file_name), 'r') as f:
+            file = f.readlines()
+
+        for line in file:
+            temp = line.rstrip('\n').rstrip(' ').split(' ')
+            self.images_path.append(os.path.join(root_path, "images", temp[0]))
+            label = [float(i) for i in temp[1:-1]]
+            self.cls.append(int(float(temp[-1])))
+            self.labels.append(label)
+
+    def __getitem__(self, index):
+        original_img = Image.open(self.images_path[index]).convert('RGB')
+        label = torch.FloatTensor(self.labels[index])
+        cls = self.cls[index]
+        original_img = self.transforms(original_img)
+        return original_img, label, cls
+
+    def __len__(self) -> int:
+        return len(self.images_path)
diff --git a/envs.yaml b/envs.yaml new file mode 100644 index 0000000..ed3e5ea --- /dev/null +++ b/envs.yaml @@ -0,0 +1,176 @@ +name: torch +channels: + - pytorch + - defaults +dependencies: + - _libgcc_mutex=0.1=main + - _openmp_mutex=5.1=1_gnu + - asttokens=2.0.5=pyhd3eb1b0_0 + - backcall=0.2.0=pyhd3eb1b0_0 + - blas=1.0=mkl + - bzip2=1.0.8=h7b6447c_0 + - ca-certificates=2022.10.11=h06a4308_0 + - charset-normalizer=2.0.4=pyhd3eb1b0_0 + - 
cudatoolkit=11.3.1=h2bc3f7f_2 + - decorator=5.1.1=pyhd3eb1b0_0 + - executing=0.8.3=pyhd3eb1b0_0 + - ffmpeg=4.2.2=h20bf706_0 + - fftw=3.3.9=h27cfd23_1 + - filelock=3.6.0=pyhd3eb1b0_0 + - freetype=2.12.1=h4a9f257_0 + - ftfy=5.8=py_0 + - giflib=5.2.1=h7b6447c_0 + - gmp=6.2.1=h295c915_3 + - gnutls=3.6.15=he1e5248_0 + - huggingface_hub=0.2.1=pyhd3eb1b0_0 + - importlib_metadata=4.11.3=hd3eb1b0_0 + - intel-openmp=2021.4.0=h06a4308_3561 + - jpeg=9e=h7f8727e_0 + - jupyter_client=7.3.5=py310h06a4308_0 + - jupyter_core=4.11.1=py310h06a4308_0 + - lame=3.100=h7b6447c_0 + - lcms2=2.12=h3be6417_0 + - ld_impl_linux-64=2.38=h1181459_1 + - lerc=3.0=h295c915_0 + - libdeflate=1.8=h7f8727e_5 + - libffi=3.3=he6710b0_2 + - libgcc-ng=11.2.0=h1234567_1 + - libgfortran-ng=11.2.0=h00389a5_1 + - libgfortran5=11.2.0=h1234567_1 + - libgomp=11.2.0=h1234567_1 + - libidn2=2.3.2=h7f8727e_0 + - libopus=1.3.1=h7b6447c_0 + - libpng=1.6.37=hbc83047_0 + - libsodium=1.0.18=h7b6447c_0 + - libstdcxx-ng=11.2.0=h1234567_1 + - libtasn1=4.16.0=h27cfd23_0 + - libtiff=4.4.0=hecacb30_0 + - libunistring=0.9.10=h27cfd23_0 + - libuuid=1.0.3=h7f8727e_2 + - libuv=1.40.0=h7b6447c_0 + - libvpx=1.7.0=h439df22_0 + - libwebp=1.2.4=h11a3e52_0 + - libwebp-base=1.2.4=h5eee18b_0 + - lz4-c=1.9.3=h295c915_1 + - mkl=2021.4.0=h06a4308_640 + - mkl_fft=1.3.1=py310hd6ae3a3_0 + - mkl_random=1.2.2=py310h00e6091_0 + - ncurses=6.3=h5eee18b_3 + - nettle=3.7.3=hbbd107a_1 + - numpy-base=1.23.3=py310h8e6c178_0 + - openh264=2.1.1=h4ff587b_0 + - openssl=1.1.1s=h7f8727e_0 + - packaging=21.3=pyhd3eb1b0_0 + - parso=0.8.3=pyhd3eb1b0_0 + - pexpect=4.8.0=pyhd3eb1b0_3 + - pickleshare=0.7.5=pyhd3eb1b0_1003 + - prompt-toolkit=3.0.20=pyhd3eb1b0_0 + - ptyprocess=0.7.0=pyhd3eb1b0_2 + - pure_eval=0.2.2=pyhd3eb1b0_0 + - pycparser=2.21=pyhd3eb1b0_0 + - pygments=2.11.2=pyhd3eb1b0_0 + - pyopenssl=22.0.0=pyhd3eb1b0_0 + - python=3.10.6=haa1d7c7_1 + - python-dateutil=2.8.2=pyhd3eb1b0_0 + - pytorch=1.11.0=py3.10_cuda11.3_cudnn8.2.0_0 + - pytorch-mutex=1.0=cuda + - readline=8.2=h5eee18b_0 + - sacremoses=0.0.43=pyhd3eb1b0_0 + - six=1.16.0=pyhd3eb1b0_1 + - sqlite=3.39.3=h5082296_0 + - stack_data=0.2.0=pyhd3eb1b0_0 + - threadpoolctl=2.2.0=pyh0d69192_0 + - tk=8.6.12=h1ccaba5_0 + - traitlets=5.1.1=pyhd3eb1b0_0 + - typing_extensions=4.3.0=py310h06a4308_0 + - tzdata=2022f=h04d1e81_0 + - wcwidth=0.2.5=pyhd3eb1b0_0 + - wheel=0.37.1=pyhd3eb1b0_0 + - x264=1!157.20191217=h7b6447c_0 + - xz=5.2.6=h5eee18b_0 + - yaml=0.2.5=h7b6447c_0 + - zeromq=4.3.4=h2531618_0 + - zlib=1.2.13=h5eee18b_0 + - zstd=1.5.2=ha4553b6_0 + - pip: + - absl-py==1.3.0 + - bottleneck==1.3.5 + - brotlipy==0.7.0 + - cachetools==5.2.0 + - certifi==2022.12.7 + - cffi==1.15.1 + - click==8.0.4 + - contourpy==1.0.7 + - cryptography==38.0.1 + - cycler==0.11.0 + - debugpy==1.5.1 + - easydict==1.10 + - entrypoints==0.4 + - fonttools==4.39.2 + - fvcore==0.1.5.post20221221 + - google-auth==2.15.0 + - google-auth-oauthlib==0.4.6 + - grpcio==1.51.1 + - idna==3.4 + - importlib-metadata==4.11.3 + - iopath==0.1.10 + - ipykernel==6.15.2 + - ipython==8.4.0 + - jedi==0.18.1 + - jieba==0.42.1 + - joblib==1.1.1 + - jupyter-client==7.3.5 + - jupyter-core==4.11.1 + - kiwisolver==1.4.4 + - markdown==3.4.1 + - markupsafe==2.1.1 + - matplotlib==3.7.1 + - matplotlib-inline==0.1.6 + - mkl-fft==1.3.1 + - mkl-random==1.2.2 + - mkl-service==2.4.0 + - nest-asyncio==1.5.5 + - numexpr==2.8.3 + - numpy==1.23.3 + - nvidia-ml-py3==7.352.0 + - oauthlib==3.2.2 + - pandas==1.4.4 + - pillow==9.2.0 + - pip==22.2.2 + - portalocker==2.7.0 + - protobuf==3.20.3 + - psutil==5.9.0 
+ - pyasn1==0.4.8 + - pyasn1-modules==0.2.8 + - pyparsing==3.0.9 + - pysocks==1.7.1 + - pytz==2022.1 + - pyyaml==6.0 + - pyzmq==23.2.0 + - regex==2022.7.9 + - requests==2.28.1 + - requests-oauthlib==1.3.1 + - rsa==4.9 + - scikit-learn==1.1.1 + - scipy==1.9.1 + - setuptools==65.5.0 + - tabulate==0.9.0 + - tensorboard==2.11.0 + - tensorboard-data-server==0.6.1 + - tensorboard-plugin-wit==1.8.1 + - termcolor==2.2.0 + - tokenizers==0.11.4 + - torch==1.11.0 + - torchaudio==0.11.0 + - torchsummary==1.5.1 + - torchvision==0.12.0 + - tornado==6.2 + - tqdm==4.64.1 + - transformers==4.18.0 + - typing-extensions==4.3.0 + - urllib3==1.26.12 + - werkzeug==2.2.2 + - yacs==0.1.8 + - yapf==0.32.0 + - zipp==3.8.0 +prefix: /home/wj/.conda/envs/mmsa
diff --git a/loss/EMD.py b/loss/EMD.py
new file mode 100644
index 0000000..04f4f5c
--- /dev/null
+++ b/loss/EMD.py
@@ -0,0 +1,44 @@
+import torch
+import torch.nn as nn
+
+
+def single_emd_loss(p, q, r=2):
+    """
+    Earth Mover's Distance of one sample
+
+    Args:
+        p: true distribution of shape num_classes × 1
+        q: estimated distribution of shape num_classes × 1
+        r: norm parameter
+    """
+    assert p.shape == q.shape, "Length of the two distributions must be the same"
+    length = p.shape[0]
+    emd_loss = 0.0
+    for i in range(1, length + 1):
+        # |CDF_p(i) - CDF_q(i)|**r, since sum(p[:i]) is the CDF at bin i
+        emd_loss += torch.abs(sum(p[:i] - q[:i]))**r
+    return (emd_loss / length)**(1. / r)
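+
+# Illustrative sanity check (not part of the training path): for p = [1, 0] and
+# q = [0, 1] with r = 2, the loss above is ((1**2 + 0**2) / 2)**0.5 ≈ 0.707.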
+
+
+def emd_loss(p, q, r=2):
+    """
+    Earth Mover's Distance on a batch
+
+    Args:
+        p: true distribution of shape mini_batch_size × num_classes × 1
+        q: estimated distribution of shape mini_batch_size × num_classes × 1
+        r: norm parameter
+    """
+    assert p.shape == q.shape, "Shape of the two distribution batches must be the same."
+    mini_batch_size = p.shape[0]
+    loss_vector = []
+    for i in range(mini_batch_size):
+        loss_vector.append(single_emd_loss(p[i], q[i], r=r))
+    return sum(loss_vector) / mini_batch_size
+
+
+class EMDLoss(nn.Module):
+    def __init__(self) -> None:
+        super(EMDLoss, self).__init__()
+
+    def forward(self, x, y):
+        return emd_loss(x, y)
diff --git a/loss/__init__.py b/loss/__init__.py
new file mode 100644
index 0000000..2e92239
--- /dev/null
+++ b/loss/__init__.py
@@ -0,0 +1,3 @@
+from .adv_div import AdvDivLoss, ProgressiveCircularLoss
+from .EMD import EMDLoss
+from .share_specific import SharedAndSpecificLoss
diff --git a/loss/adv_div.py b/loss/adv_div.py
new file mode 100644
index 0000000..f55119a
--- /dev/null
+++ b/loss/adv_div.py
@@ -0,0 +1,106 @@
+import math
+
+import torch.nn as nn
+import torch
+
+
+def Adv_hook(module, grad_in, grad_out):
+    # gradient reversal: flip the sign of the gradient flowing into fc_pre so
+    # that it is trained adversarially against the diversity objective
+    return (grad_in[0] * (-1), grad_in[1])
+
+
+class AdvDivLoss(nn.Module):
+    """
+    Adversarial diversity loss over the part features.
+    x: feature map of shape (B*parts, 256, H, W)
+    """
+    def __init__(self, parts=4):
+        super(AdvDivLoss, self).__init__()
+        self.parts = parts
+
+        self.fc_pre = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
+                                    nn.Flatten(1),
+                                    nn.Linear(256, 128, bias=False))
+        self.fc = nn.Sequential(nn.BatchNorm1d(128), nn.ReLU(),
+                                nn.Linear(128, 128), nn.BatchNorm1d(128))
+        self.fc_pre.register_backward_hook(Adv_hook)
+
+    def forward(self, x):
+        x = nn.functional.normalize(x)
+        x = self.fc_pre(x)
+        x = self.fc(x)
+        x = nn.functional.normalize(x)
+        out = 0
+        num = int(x.size(0) / self.parts)
+        # mean pairwise distance between the embeddings of different parts
+        for i in range(self.parts):
+            for j in range(self.parts):
+                if i < j:
+                    out += ((x[i * num:(i + 1) * num, :] -
+                             x[j * num:(j + 1) * num, :]).norm(
+                                 dim=1, keepdim=True)).mean()
+        return out * 2 / (self.parts * (self.parts - 1))
+
+
+class Circular_structured(nn.Module):
+    """
+    Maps a label distribution onto a circular emotion structure.
+    x: a batch of distributions, shape (N, C)
+    """
+    def __init__(self):
+        super(Circular_structured, self).__init__()
+        self.pi = math.pi
+
+    def cal_pi(self, x):
+        x = torch.where((x >= 0.5 * self.pi) & (x < 1.5 * self.pi),
+                        torch.ones_like(x), torch.zeros_like(x))
+        return x
+
+    def forward(self, x):
+        N, C = x.shape
+        j = torch.arange(1, C + 1, device=x.device).repeat((N, 1))
+        thetaj = self.pi * (j * 2 - 1) / 8.0  # assumes C == 8 emotion classes
+        rj = torch.ones_like(x)
+
+        theta_ji = thetaj * x
+        r_ji = rj * x
+        x_ji = r_ji * torch.cos(theta_ji)
+        y_ji = r_ji * torch.sin(theta_ji)
+
+        x_i = x_ji.sum(1)
+        y_i = y_ji.sum(1)
+
+        r_i = torch.sqrt(x_i ** 2 + y_i ** 2)
+        theta_i = torch.arctan(y_i / x_i)
+
+        # polarity indicator: 1 inside [0.5*pi, 1.5*pi), else 0
+        p_i = torch.where((theta_i < 1.5 * self.pi) & (theta_i >= 0.5 * self.pi),
+                          torch.ones_like(theta_i), torch.zeros_like(theta_i))
+
+        e_i = (p_i, theta_i, r_i)
+        return e_i
+
+
+class ProgressiveCircularLoss(nn.Module):
+    """
+    Progressive circular loss: circular-structure distance blended with KL.
+    """
+    def __init__(self, mu=0.5):
+        super(ProgressiveCircularLoss, self).__init__()
+        self.mu = mu
+        self.klloss = nn.KLDivLoss(reduction='batchmean')
+        self.cs = Circular_structured()
+
+    def forward(self, x, y):
+        x_ = self.cs(x)
+        y_ = self.cs(y)
+        L_pc = ((y_[0] - x_[0]) ** 2 + (y_[1] - x_[1]) ** 2) * y_[2]
+        l_pc = L_pc.mean()
+        # nn.KLDivLoss expects log-probabilities as its first argument
+        l_kl = self.klloss(x, y)
+        return self.mu * l_pc + (1 - self.mu) * l_kl
+
+
+if __name__ == "__main__":
+    cs_loss = Circular_structured()
+    pc_loss = ProgressiveCircularLoss()
+    x = torch.tensor([[-0.2313, -0.1209, 0.1813, 0.0177, 0.2263, 0.3111, -0.2915, -0.2925],
+                      [-0.2313, -0.1209, 0.1813, 0.0177, 0.2263, 0.3111, -0.2915, -0.2925]])
+    y = torch.tensor([[-1.2313, -0.1209, 0.1813, 0.0377, 0.2263, 0.1111, -0.2915, -0.2125],
+                      [-0.5313, -0.3209, 1.1813, 0.1177, 0.2263, 0.3111, -0.1915, -0.2925]])
+    loss = cs_loss(x)
+    print(loss)
+
+    loss = pc_loss(x, y)
+    print(loss)
diff --git a/loss/share_specific.py b/loss/share_specific.py
new file mode 100644
index 0000000..94af666
--- /dev/null
+++ b/loss/share_specific.py
@@ -0,0 +1,37 @@
+# -*- encoding: utf-8 -*-
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class SharedAndSpecificLoss(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.fc1 = nn.Linear(128, 64)
+        self.fc2 = nn.Linear(64, 3)
+        self.softmax = nn.Softmax(dim=-1)
+        self.loss = nn.CrossEntropyLoss()
+
+    def forward(self, shared1, specific1, shared2, specific2, shared3,
+                specific3, label):
+        # orthogonality between the shared and specific embeddings of each view
+        orth_1 = torch.bmm(shared1.unsqueeze(1),
+                           specific1.unsqueeze(1).transpose(1, 2))
+        orth_1 = torch.norm(orth_1)
+
+        orth_2 = torch.bmm(shared2.unsqueeze(1),
+                           specific2.unsqueeze(1).transpose(1, 2))
+        orth_2 = torch.norm(orth_2)
+
+        orth_3 = torch.bmm(shared3.unsqueeze(1),
+                           specific3.unsqueeze(1).transpose(1, 2))
+        orth_3 = torch.norm(orth_3)
+
+        shared = torch.cat([shared1, shared2, shared3], dim=0)
+
+        out = self.fc1(shared)
+        out = self.fc2(out)
+        out = self.softmax(out)
+        cls = self.loss(out, label.long())
+
+        return orth_1 + orth_2 + orth_3 + cls
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..b748242
--- /dev/null
+++ b/main.py
@@ -0,0 +1,89 @@
+# -*- encoding: utf-8 -*-
+import os
+import traceback
+from collections import OrderedDict
+import torch
+from trainer import get_trainer
+from utils import parse_args, seed_torch
+from utils import read_yaml, write_yaml
+from utils import create_summary, create_logger, clear_log, save_del
+
+if __name__ == "__main__":
+    # ================ 1. config ====================
+    opt = OrderedDict(vars(parse_args()))
+
+    # precedence: CLI args > resume config > train/specific config > model config > base.yaml
+    if opt['resume_path']:
+        config = read_yaml(os.path.join(opt['resume_path'], "config.yaml"),
+                           isResume=True)
+    else:
+        config = read_yaml('base.yaml')
+        if opt['specific_cfg']:
+            config_train = read_yaml(opt['specific_cfg'])
+        else:
+            config_train = read_yaml('train.yaml')
+        for k, v in config_train.items():
+            config[k] = v
+
+        config_model = read_yaml(os.path.join(config['model'] + '.yaml'))
+        for k, v in config_model.items():
+            config[k] = v
+
+    for k, v in config.items():
+        if k not in opt.keys() or opt[k] is None:
+            opt[k] = v
+
+    # ================ 2. log file ====================
+    if opt['tag'] == 'cache':
+        clear_log('cache')
+
+    writer, path = create_summary(opt['tag'])
+    logger = create_logger(path)
+    logger.name = __name__
+
+    opt['path'] = path
+
+    # ================ 3. device ====================
+    seed_torch(opt['seed'])
+    # device setting
+    if torch.cuda.is_available():
+        opt['device'] = 'cuda:' + str(opt['gpu_id'])
+    else:
+        opt['device'] = 'cpu'
+
+    # save config
+    write_yaml(opt['path'], opt)
+
+    # ================ 4. start to train ====================
+    # 1. create trainer
+    trainer = get_trainer(opt["trainer"])(opt, logger, writer)
+
+    # 2. data loader
+    trainer.set_dataloader()
+
+    # 3. model
+    trainer.set_model()
+
+    # 4. optimizer
+    trainer.set_optimizer()
+
+    # 5. lr scheduler
+    trainer.set_scheduler()
+
+    # 6. metrics
+    trainer.meters()
+
+    # 7. loss
+    trainer.set_loss()
+
+    # 8. resume
+    if opt["resume_path"]:
+        trainer.load_checkpoint()
+
+    # 9. training
+    try:
+        trainer.train()
+    except KeyboardInterrupt:
+        save_del(opt['path'])
+    except Exception:
+        traceback.print_exc()
+        save_del(opt['path'])
diff --git a/network/GMG.py b/network/GMG.py
new file mode 100644
index 0000000..e49cac2
--- /dev/null
+++ b/network/GMG.py
@@ -0,0 +1,313 @@
+import torch
+import torch.nn as nn
+from torchvision import models
+import torch.nn.functional as F
+import numpy as np
+import os
+
+
+class DynamicGraphConvolution(nn.Module):
+    def __init__(self, in_features, out_features, num_nodes):
+        super(DynamicGraphConvolution, self).__init__()
+
+        self.static_weight = nn.Sequential(
+            nn.Conv1d(in_features, out_features, 1), nn.LeakyReLU(0.2))
+
+        self.gap = nn.AdaptiveAvgPool1d(1)
+        self.conv_global = nn.Conv1d(in_features, in_features, 1)
+        self.bn_global = nn.BatchNorm1d(in_features)
+        self.relu = nn.LeakyReLU(0.2)
+
+        self.conv_create_co_mat = nn.Conv1d(in_features * 2, num_nodes, 1)
+        self.dynamic_weight = nn.Conv1d(in_features, out_features, 1)
+
+    def forward_static_gcn(self, x, adj):
+        x = torch.matmul(adj, x.transpose(1, 2))
+        x = self.static_weight(x.transpose(1, 2))
+        return x
+
+    def forward_construct_dynamic_graph(self, x):
+        ### Model global representations ###
+        x_glb = self.gap(x)
+        x_glb = self.conv_global(x_glb)
+        x_glb = self.bn_global(x_glb)
+        x_glb = self.relu(x_glb)
+        x_glb = x_glb.expand(x_glb.size(0), x_glb.size(1), x.size(2))
+
+        ### Construct the dynamic correlation matrix ###
+        x = torch.cat((x_glb, x), dim=1)
+        dynamic_adj = self.conv_create_co_mat(x)
+        dynamic_adj = torch.softmax(dynamic_adj, dim=-1)  # A_d
+        return dynamic_adj
+
+    def forward_dynamic_gcn(self, x, dynamic_adj):
+        x = torch.matmul(x, dynamic_adj)
+        x = self.relu(x)
+        x = self.dynamic_weight(x)
+        x = self.relu(x)
+        return x
+
+    def forward(self, x, adj):
+        """ D-GCN module
+
+        Shape:
+        - Input: (B, C_in, N)   # C_in: 980*parts, N: num_classes
+        - Output: (B, C_out, N) # C_out: 980*parts, N: num_classes
+        """
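+        # `adj` is the static label co-occurrence matrix loaded in GMG.__init__
+        # from <data_path>/<dataset>/twitter.npy; the dynamic graph below is
+        # re-estimated for every sample.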
+        out_static = self.forward_static_gcn(x, adj)
+        x = x + out_static  # residual
+        dynamic_adj = self.forward_construct_dynamic_graph(x)
+
+        x = self.forward_dynamic_gcn(x, dynamic_adj)
+        return x
+
+
+class HighDivModule(nn.Module):
+    def __init__(self, in_channels, order=1):
+        super(HighDivModule, self).__init__()
+        self.order = order
+        self.inter_channels = in_channels // 8 * 2
+        for j in range(self.order):
+            for i in range(j + 1):
+                name = 'order' + str(
+                    self.order) + '_' + str(j + 1) + '_' + str(i + 1)
+                setattr(
+                    self, name,
+                    nn.Sequential(
+                        nn.Conv2d(in_channels,
+                                  self.inter_channels,
+                                  1,
+                                  padding=0,
+                                  bias=False)))
+        for i in range(self.order):
+            name = 'convb' + str(self.order) + '_' + str(i + 1)
+            setattr(
+                self, name,
+                nn.Sequential(
+                    nn.Conv2d(self.inter_channels,
+                              in_channels,
+                              1,
+                              padding=0,
+                              bias=False), nn.Sigmoid()))
+
+    def forward(self, x):
+        y = []
+        for j in range(self.order):
+            for i in range(j + 1):
+                name = 'order' + str(
+                    self.order) + '_' + str(j + 1) + '_' + str(i + 1)
+                layer = getattr(self, name)
+                y.append(layer(x))
+        y_ = []
+        cnt = 0
+        for j in range(self.order):
+            y_temp = 1
+            for i in range(j + 1):
+                # order-(j+1) interaction: elementwise product of j+1 projections
+                y_temp = y_temp * y[cnt]
+                cnt += 1
+            y_.append(F.relu(y_temp))
+
+        y__ = 0
+        for i in range(self.order):
+            name = 'convb' + str(self.order) + '_' + str(i + 1)
+            layer = getattr(self, name)
+            y__ += layer(y_[i])
+        out = x * y__ / self.order
+        return out
+
+
+class GMG(nn.Module):
+    def __init__(self, opt):
+        super().__init__()
+        self.parts = opt["parts"]
+        self.class_num = opt["num_classes"]
+        self.lambda_ = opt['lambda']
+        self.resnet50 = models.resnet50(pretrained=True)
+
+        self.g1_sample = nn.UpsamplingNearest2d(size=(224, 224))
+        self.g2_sample = nn.UpsamplingNearest2d(size=(224, 224))
+        self.g3_sample = nn.UpsamplingNearest2d(size=(224, 224))
+
+        self.x_up = nn.UpsamplingNearest2d(size=(28, 28))
+
+        self.gram_ln = nn.LayerNorm((3, 224, 224))
+        self.gram_conv1 = nn.Conv2d(in_channels=3,
+                                    out_channels=8,
+                                    kernel_size=7,
+                                    padding=3)
+        self.gram_conv1_ln = nn.LayerNorm(224)
+        self.gram_relu1 = nn.ReLU()
+        self.gram_pool1 = nn.MaxPool2d(2)
+
+        self.gram_conv2 = nn.Conv2d(in_channels=8,
+                                    out_channels=16,
+                                    kernel_size=7,
+                                    padding=3)
+        self.gram_conv2_ln = nn.LayerNorm(112)
+        self.gram_relu2 = nn.ReLU()
+        self.gram_pool2 = nn.MaxPool2d(2)
+
+        self.x_conv = nn.Conv2d(2048, 256, 1)
+        self.x_relu = nn.ReLU()
+
+        self.l2_conv = nn.Conv2d(1024, 256, 1)
+        self.l2_relu = nn.ReLU()
+
+        self.softmax = nn.Softmax(dim=1)
+        self.head = nn.Conv2d(512, 8, 1)
+        # MHN
+        for i in range(self.parts):
+            name = 'HIGH' + str(i)
+            setattr(self, name, HighDivModule(512, i + 1))
+            # High-order feature-map channel attention over h*w spatial locations.
+            # paper: https://ieeexplore.ieee.org/document/9009039/
+            # link: https://zhuanlan.zhihu.com/p/104380548
+
+        for i in range(self.parts):
+            name = 'classifier' + str(i)
+            setattr(self, name,
+                    nn.Sequential(nn.Conv2d(512, self.class_num, 1)))
+
+        for i in range(self.parts):
+            name = 'classifier2' + str(i)
+            setattr(self, name,
+                    nn.Sequential(nn.Conv2d(320, self.class_num, 1)))
+
+        # GCN
+        self.adj = torch.from_numpy(
+            np.load(
+                os.path.join(opt["data_path"], opt['dataset'], "twitter.npy"))
+        ).float().to(opt['device'])
+
+        self.gcn = DynamicGraphConvolution(980 * self.parts,
+                                           980 * self.parts, 8)  # style-GCN
+
+    @staticmethod
+    def _cal_gram(feature):
+        # Gram matrix: (B, C, H*W) @ (B, H*W, C) -> (B, C, C)
+        feature = feature.view(feature.shape[0], feature.shape[1], -1)
+        feature = torch.matmul(feature, feature.transpose(-1, -2))
+        return feature
+
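+    # Gram matrices capture channel co-activation (texture) statistics and are a
+    # standard style descriptor (cf. neural style transfer); three backbone
+    # scales are resized to a common 224x224 grid and fused below.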
+    def _gram_forward(self, g1, g2, g3):
+        g1 = self._cal_gram(g1).unsqueeze(1)  # b * c1 * wh -> b * 1 * c1 * c1
+        g2 = self._cal_gram(g2).unsqueeze(1)  # b * c2 * wh -> b * 1 * c2 * c2
+        g3 = self._cal_gram(g3).unsqueeze(1)  # b * c3 * wh -> b * 1 * c3 * c3
+        g1 = self.g1_sample(g1)  # b * 1 * 224 * 224
+        g2 = self.g2_sample(g2)  # b * 1 * 224 * 224
+        g3 = self.g3_sample(g3)  # b * 1 * 224 * 224
+        g = torch.cat([g1, g2, g3], dim=1)  # b * 3 * 224 * 224
+
+        g = self.gram_ln(g)
+        g = self.gram_conv1(g)
+        g = self.gram_relu1(g)
+        g = self.gram_conv1_ln(g)
+        g = self.gram_pool1(g)  # b * 8 * 112 * 112
+        g = self.gram_conv2(g)
+        g = self.gram_relu2(g)
+        g = self.gram_conv2_ln(g)
+        g = self.gram_pool2(g)  # b * 16 * 56 * 56
+
+        g1 = g.view(g.shape[0], -1, 14, 14)  # b * (16*4*4) * 14 * 14
+        g1 = g1.repeat(self.parts, 1, 1, 1)  # (b*parts) * 256 * 14 * 14
+
+        g2 = g.view(g.shape[0], -1, 28, 28)  # b * (16*2*2) * 28 * 28
+        g2 = g2.repeat(self.parts, 1, 1, 1)  # (b*parts) * 64 * 28 * 28
+        return g1, g2
+
+    def forward(self, x):
+        x = self.resnet50.conv1(x)
+        x = self.resnet50.bn1(x)
+        g1 = x
+        x = self.resnet50.relu(x)
+        x = self.resnet50.maxpool(x)
+        x = self.resnet50.layer1(x)
+        g2 = x
+        x = self.resnet50.layer2(x)
+        g3 = x  # b, c, h, w
+
+        # High-Order Attention
+        xx = []
+        for i in range(self.parts):
+            name = 'HIGH' + str(i)
+            layer = getattr(self, name)
+            xx.append(layer(x))
+
+        x = torch.cat(xx, 0)  # (B*parts), 512, 56, 56
+
+        # high-order attention applied to the output of layer2
+        x = self.resnet50.layer3(x)  # (B*parts), 1024, 28, 28
+
+        # 1x1 conv (+ ReLU) on the layer3 output to obtain l2 / fc2
+        l2 = self.l2_conv(x)  # (B*parts)*256*28*28
+        fc2 = l2  # -> diversity-loss branch
+        l2 = self.l2_relu(l2)
+
+        x = self.resnet50.layer4(x)  # (B*parts)*2048*14*14
+
+        x = self.x_conv(x)  # (B*parts)*256*14*14
+        fc1 = x  # -> diversity-loss branch
+        x = self.x_relu(x)
+        xxxx = x
+
+        # Style Representation: two-scale style features
+        g1, g2 = self._gram_forward(g1, g2, g3)
+        # inputs:  (B, 64, 224, 224), (B, 256, 112, 112), (B, 512, 56, 56)
+        # outputs: (B*parts, 256, 14, 14), (B*parts, 64, 28, 28)
+
+        xx1 = torch.cat([x, g1], dim=1)  # (B*parts)*512*14*14
+        num = int(xx1.size(0) / self.parts)  # batch size
+
+        y_1 = []
+        for i in range(self.parts):
+            name = 'classifier' + str(i)
+            layer = getattr(self, name)
+            x = layer(xx1[i * num:(i + 1) * num, :])
+            y_1.append(x.flatten(2))
+        # y_1: list of (B, 8, 196), len == parts
+
+        # FPN: fuse the two pyramid levels
+        xxxx = self.x_up(xxxx)  # (B*parts)*256*28*28
+        l2 = xxxx + l2
+        xx2 = torch.cat([l2, g2], dim=1)  # (B*parts)*(256+64)*28*28
+
+        y_2 = []
+        for i in range(self.parts):
+            name = 'classifier2' + str(i)
+            layer = getattr(self, name)
+            x = layer(xx2[i * num:(i + 1) * num, :])
+            y_2.append(x.flatten(2))
+        # y_2: list of (B, 8, 784), len == parts
+
+        # cat
+        # 1. concatenate the two FPN scales
+        y_ = []
+        for i in range(len(y_1)):
+            y_.append(torch.cat([y_1[i], y_2[i]], dim=-1))
+        # y_: list of (B, 8, 980), len == parts
+
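+        # mixed attention below: p = softmax(mean_j y[..., j] + lambda * max_j y[..., j]);
+        # the mean is the first-order logit, the max adds a lambda-weighted peak response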
+        # 2. MHN: concatenate all part branches as the GCN input
+        yy_ = y_[0]
+        for i in range(1, len(y_)):
+            yy_ = torch.cat((yy_, y_[i]), dim=-1)
+        # yy_: (B, 8, 980*parts)
+
+        # mixed attention
+        for i in range(len(y_)):
+            base_logit = torch.mean(y_[i], dim=2)   # first order (B, 8)
+            att_logit = torch.max(y_[i], dim=2)[0]  # second order (B, 8)
+            y_[i] = self.softmax(base_logit + self.lambda_ * att_logit)
+
+        # GCN
+        yy_ = yy_.transpose(1, 2)      # (B, 980*parts, 8)
+        yy_ = self.gcn(yy_, self.adj)  # adj: (8, 8)
+        yy = yy_.transpose(1, 2)       # (B, 8, 980*parts)
+        base_logit = torch.mean(yy, dim=2)   # (B, 8)
+        att_logit = torch.max(yy, dim=2)[0]  # (B, 8)
+        gcn = self.softmax(base_logit + self.lambda_ * att_logit)
+
+        return y_, gcn, fc1, fc2
diff --git a/network/__init__.py b/network/__init__.py
new file mode 100644
index 0000000..f4dbcaa
--- /dev/null
+++ b/network/__init__.py
@@ -0,0 +1,13 @@
+from .GMG import GMG
+
+models = {
+    "gmg": GMG,  # ours
+}
+
+
+def get_model(model: str):
+    if model in models:
+        return models[model]
+    else:
+        raise Exception('No such model: "%s", available: {%s}.' %
+                        (model, '|'.join(models.keys())))
diff --git a/scheduler/__init__.py b/scheduler/__init__.py
new file mode 100644
index 0000000..a1489b8
--- /dev/null
+++ b/scheduler/__init__.py
@@ -0,0 +1,47 @@
+from torch import optim
+
+
+def get_scheduler(opt, optimizer):
+    # stepLR
+    if opt['scheduler'] == "scheduler_stepLR":
+        scheduler = optim.lr_scheduler.StepLR(
+            optimizer,
+            step_size=opt['scheduler_stepLR']['step_size'],
+            gamma=opt['scheduler_stepLR']['gamma'])
+
+    # MultiStepLR
+    elif opt['scheduler'] == "scheduler_multi":
+        scheduler = optim.lr_scheduler.MultiStepLR(
+            optimizer,
+            milestones=opt['scheduler_multi']['milestones'],
+            gamma=opt['scheduler_multi']['gamma'])
+
+    # ExponentialLR
+    elif opt['scheduler'] == "scheduler_exp":
+        scheduler = optim.lr_scheduler.ExponentialLR(
+            optimizer, gamma=opt['scheduler_exp']['gamma'])
+
+    # CosineAnnealingLR
+    elif opt['scheduler'] == "scheduler_cos":
+        scheduler = optim.lr_scheduler.CosineAnnealingLR(
+            optimizer,
+            T_max=opt['scheduler_cos']['T_max'],
+            eta_min=opt['scheduler_cos']['eta_min'])
+
+    # CyclicLR
+    elif opt['scheduler'] == "scheduler_cyclic":
+        scheduler = optim.lr_scheduler.CyclicLR(
+            optimizer,
+            base_lr=opt['lr'],
+            max_lr=opt['scheduler_cyclic']['max_lr'],
+            step_size_up=opt['scheduler_cyclic']['up'],
+            step_size_down=opt['scheduler_cyclic']['down'])
+
+    # LambdaLR
+    elif opt['scheduler'] == 'scheduler_lambda':
+        if opt['scheduler_lambda']['lr_lambda'] is None:
+            raise NotImplementedError("lr_lambda must be defined")
+        scheduler = optim.lr_scheduler.LambdaLR(
+            optimizer, opt['scheduler_lambda']['lr_lambda'])
+
+    else:
+        scheduler = None
+
+    return scheduler
diff --git a/trainer/__init__.py b/trainer/__init__.py
new file mode 100644
index 0000000..7452e37
--- /dev/null
+++ b/trainer/__init__.py
@@ -0,0 +1,10 @@
+from .gmgtrainer import gmgTrainer
+
+trainers = {"gmg": gmgTrainer}
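+# maps the config's `trainer` field to a trainer class; new trainers subclass
+# BaseTrainer and are registered in the dict above (see README)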
+
+
+def get_trainer(trainer: str):
+    if trainer in trainers:
+        return trainers[trainer]
+    else:
+        raise Exception('No such trainer: "%s", available: {%s}.' %
+                        (trainer, '|'.join(trainers.keys())))
diff --git a/trainer/basetrainer.py b/trainer/basetrainer.py
new file mode 100644
index 0000000..1f33eef
--- /dev/null
+++ b/trainer/basetrainer.py
@@ -0,0 +1,292 @@
+# -*- encoding: utf-8 -*-
+import os
+import time
+
+from numpy import inf
+
+from utils import AverageMeter, write_result, LDL_measurement
+from dataset import Dataset_LDL
+from transforms import get_transforms
+from scheduler import get_scheduler
+from network import get_model
+
+import torch
+from torch.optim import SGD
+from torch.utils.data import DataLoader
+
+
+class BaseTrainer():
+    def __init__(self, opt, logger, writer) -> None:
+        self.opt = opt
+        self.logger = logger
+        self.writer = writer
+        self.logger.name = __name__
+
+        # data
+        self.dataloader_train = None
+        self.dataloader_test = None
+
+        # model
+        self.model = None
+
+        # optimizer
+        self.optimizer = None
+        self.scheduler = None
+
+        # loss
+        self.loss = None
+
+        # statistics
+        self.meters_dict = {}
+
+        # train
+        # self.save_mark = inf
+        self.start_epoch = 0
+        self.train_steps = 0
+        self.test_steps = 0
+
+    def set_model(self):
+        self.model = get_model(self.opt['model'])(self.opt)
+
+    def set_dataloader(self):
+        # train
+        transforms_train = get_transforms(self.opt['image_size'],
+                                          'train',
+                                          self.opt['dataset'],
+                                          isNormalize=True)
+        dataset_train = Dataset_LDL(
+            os.path.join(self.opt['data_path'], self.opt['dataset']), 'train',
+            transforms_train)
+        self.dataloader_train = DataLoader(dataset_train,
+                                           batch_size=self.opt['batch_size'],
+                                           num_workers=self.opt['num_workers'],
+                                           shuffle=True,
+                                           drop_last=True)
+
+        # test
+        transforms_test = get_transforms(self.opt['image_size'],
+                                         'test',
+                                         self.opt['dataset'],
+                                         isNormalize=True)
+        dataset_test = Dataset_LDL(
+            os.path.join(self.opt['data_path'], self.opt['dataset']), 'test',
+            transforms_test)
+        self.dataloader_test = DataLoader(dataset_test,
+                                          batch_size=self.opt['batch_size'],
+                                          num_workers=self.opt['num_workers'])
+
+    def set_optimizer(self):
+        self.optimizer = SGD(self.model.parameters(),
+                             lr=self.opt['lr'],
+                             momentum=self.opt['momentum'])
+
+    def set_scheduler(self):
+        self.scheduler = get_scheduler(self.opt, self.optimizer)
+
+    def set_loss(self):
+        raise NotImplementedError
+
+    def train(self):
+        self.model.to(self.opt['device'])
+        self.loss.to(self.opt['device'])
+        for epoch in range(self.opt['epochs']):
+            epoch += self.start_epoch
+            # train
+            self.model.train()
+            for k, v in self.meters_dict.items():
+                self.meters_dict[k].reset()
+            self.train_epoch(epoch)
+            self.writer_train(epoch)
+
+            # learning rate
+            if self.scheduler:
+                self.scheduler.step()
+            self.writer.add_scalar(
+                "train/lr",
+                self.optimizer.state_dict()['param_groups'][0]['lr'],
+                global_step=epoch)
+            # test
+            self.model.eval()
+            for k, v in self.meters_dict.items():
+                self.meters_dict[k].reset()
+            self.test_epoch(epoch)
+            self.writer_test(epoch)
+
+            checkpoint = {
+                'epoch': epoch,
+                'state_dict': self.model.state_dict(),
+                'optimizer_state_dict': self.optimizer.state_dict(),
+                'train_steps': self.train_steps,
+                'test_steps': self.test_steps,
+            }
+            self.save_checkpoint(checkpoint, 'epoch')
+
+    def test(self):
+        raise NotImplementedError
+
+    def train_epoch(self, epoch):
+        start_time = time.time()
+        for i, (inputs, labels, cls) in enumerate(self.dataloader_train):
+            inputs = inputs.to(self.opt['device'])
+            labels = labels.to(self.opt['device'])
+            cls = cls.to(self.opt['device'])
+
+            outputs = self.model(inputs)
+            loss = self.loss(outputs, labels)
+
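+            # AverageMeter keeps scalar running means; LDL_measurement
+            # accumulates the eight distribution-distance metrics per batch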
+            self.meters_dict['loss'].update(loss.item(), inputs.size(0))
+            self.meters_dict['ldl'].update(outputs, labels, inputs.size(0))
+
+            prediction = torch.max(outputs, 1)[1]
+            self.meters_dict['acc'].update(
+                sum(prediction == cls) / inputs.size(0))
+
+            self.optimizer.zero_grad()
+            loss.backward()
+            self.optimizer.step()
+
+            batch_time = time.time() - start_time
+            self.meters_dict['batch_time'].update(batch_time)
+            start_time = time.time()
+
+            if i % self.opt['display_interval'] == 0:
+                print(
+                    '{}, {} Epoch, {} Iter, Loss: {:.4f}, Batch time: {:.4f}, Acc: {:.4f}'
+                    .format(time.strftime('%Y-%m-%d %H:%M:%S'), epoch, i,
+                            self.meters_dict['loss'].value(),
+                            self.meters_dict['batch_time'].value(),
+                            self.meters_dict['acc'].value()))
+                self.writer.add_scalar('train/acc',
+                                       self.meters_dict['acc'].value(),
+                                       global_step=self.train_steps)
+                self.writer.add_scalar('loss/train_loss',
+                                       self.meters_dict['loss'].value(),
+                                       global_step=self.train_steps)
+
+            self.train_steps += 1
+
+    def test_epoch(self, epoch):
+        start_time = time.time()
+        for i, (inputs, labels, cls) in enumerate(self.dataloader_test):
+            inputs = inputs.to(self.opt['device'])
+            labels = labels.to(self.opt['device'])
+            cls = cls.to(self.opt['device'])
+
+            with torch.no_grad():
+                outputs = self.model(inputs)
+                loss = self.loss(outputs, labels)
+
+            self.meters_dict['loss'].update(loss.item(), inputs.size(0))
+            self.meters_dict['ldl'].update(outputs, labels, inputs.size(0))
+
+            prediction = torch.max(outputs, 1)[1]
+            self.meters_dict['acc'].update(
+                sum(prediction == cls) / inputs.size(0))
+
+            batch_time = time.time() - start_time
+            self.meters_dict['batch_time'].update(batch_time)
+            start_time = time.time()
+            if i % self.opt['display_interval'] == 0:
+                print(
+                    '{}, {} Epoch, {} Iter, Loss: {:.4f}, Batch time: {:.4f}, Acc: {:.4f}'
+                    .format(time.strftime('%Y-%m-%d %H:%M:%S'), epoch, i,
+                            self.meters_dict['loss'].value(),
+                            self.meters_dict['batch_time'].value(),
+                            self.meters_dict['acc'].value()))
+
+                self.writer.add_scalar('test/acc',
+                                       self.meters_dict['acc'].value(),
+                                       global_step=self.test_steps)
+                self.writer.add_scalar('loss/test_loss',
+                                       self.meters_dict['loss'].value(),
+                                       global_step=self.test_steps)
+            self.test_steps += 1
+
+    def meters(self):
+        self.meters_dict['loss'] = AverageMeter('loss')
+        self.meters_dict['ldl'] = LDL_measurement(self.opt['num_classes'])
+        self.meters_dict['batch_time'] = AverageMeter('batch_time')
+        self.meters_dict['acc'] = AverageMeter('acc')
+
+    def writer_train(self, epoch):
+        loss = self.meters_dict['loss'].average()
+        ldl = self.meters_dict['ldl'].average()
+        acc = self.meters_dict['acc'].average()
+
+        self.logger.info(
+            'Train: {epoch}\tLoss: {loss:.4f}\tkldiv: {kldiv:.4f}\tCosine: {Cosine:.4f}\tCheb: {Cheb:.4f}\tintersection: {intersection:.4f}'
+            .format(epoch=epoch,
+                    loss=loss,
+                    kldiv=ldl['klDiv'],
+                    Cosine=ldl['cosine'],
+                    Cheb=ldl['chebyshev'],
+                    intersection=ldl['intersection']))
+        self.logger.info("Acc:{}".format(acc))
+        self.writer.add_scalar("train/acc", acc, epoch)
+
+        self.writer.add_scalar("train/Loss", loss, epoch)
+        self.writer.add_scalar("train/KLDiv", ldl['klDiv'], epoch)
+        self.writer.add_scalar("train/Cosine", ldl['cosine'], epoch)
+        self.writer.add_scalar("train/intersection", ldl['intersection'],
+                               epoch)
+        self.writer.add_scalar("train/chebyshev", ldl['chebyshev'], epoch)
+        self.writer.add_scalar("train/clark", ldl['clark'], epoch)
+        self.writer.add_scalar("train/canberra", ldl['canberra'], epoch)
+        self.writer.add_scalar("train/squareChord", ldl['squareChord'], epoch)
+        self.writer.add_scalar("train/sorensendist", ldl['sorensendist'],
+                               epoch)
+        write_result(self.opt['path'], epoch, acc, ldl, 'train')
+
+    def writer_test(self, epoch):
+        loss = self.meters_dict['loss'].average()
+        ldl = self.meters_dict['ldl'].average()
+        acc = self.meters_dict['acc'].average()
+
+        self.logger.info(
+            'Test: {epoch}\tLoss: {loss:.4f}\tkldiv: {kldiv:.4f}\tCosine: {Cosine:.4f}\tCheb: {Cheb:.4f}\tintersection: {intersection:.4f}'
+            .format(epoch=epoch,
+                    loss=loss,
+                    kldiv=ldl['klDiv'],
+                    Cosine=ldl['cosine'],
+                    Cheb=ldl['chebyshev'],
+                    intersection=ldl['intersection']))
+
+        self.logger.info("Acc:{}".format(acc))
+        self.writer.add_scalar("test/acc", acc, epoch)
+
+        self.writer.add_scalar("test/Loss", loss, epoch)
+        self.writer.add_scalar("test/KLDiv", ldl['klDiv'], epoch)
+        self.writer.add_scalar("test/Cosine", ldl['cosine'], epoch)
+        self.writer.add_scalar("test/intersection", ldl['intersection'], epoch)
+        self.writer.add_scalar("test/chebyshev", ldl['chebyshev'], epoch)
+        self.writer.add_scalar("test/clark", ldl['clark'], epoch)
+        self.writer.add_scalar("test/canberra", ldl['canberra'], epoch)
+        self.writer.add_scalar("test/squareChord", ldl['squareChord'], epoch)
+        self.writer.add_scalar("test/sorensendist", ldl['sorensendist'], epoch)
+
+        write_result(self.opt['path'], epoch, acc, ldl, 'test')
+
+    def load_checkpoint(self):
+        checkpoint = torch.load(self.opt["resume_path"])
+        # epoch
+        self.start_epoch = checkpoint['epoch']
+
+        # model: load only keys whose shapes match (partial resume)
+        model_dict = self.model.state_dict()
+        for k, v in checkpoint['state_dict'].items():
+            if k in model_dict and v.shape == model_dict[k].shape:
+                model_dict[k] = v
+        self.model.load_state_dict(model_dict)
+
+        # optimizer
+        # self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+
+        # others
+        # self.save_mark = checkpoint['best_score']
+        self.train_steps = checkpoint['train_steps']
+        self.test_steps = checkpoint['test_steps']
+
+    def save_checkpoint(self, checkpoint, save_name):
+        save_path = os.path.join('logs', self.opt['path'], save_name + '.pth')
+        torch.save(checkpoint, save_path, _use_new_zipfile_serialization=True)
diff --git a/trainer/gmgtrainer.py b/trainer/gmgtrainer.py
new file mode 100644
index 0000000..5d56c97
--- /dev/null
+++ b/trainer/gmgtrainer.py
@@ -0,0 +1,155 @@
+# -*- encoding: utf-8 -*-
+import time
+import torch
+import torch.nn as nn
+from torch.optim import SGD
+
+from loss import AdvDivLoss
+from .basetrainer import BaseTrainer
+
+
+class gmgTrainer(BaseTrainer):
+    def __init__(self, opt, logger, writer):
+        super().__init__(opt, logger, writer)
+        logger.name = __name__
+        self.softmax = nn.Softmax(dim=1)
+
+    def set_loss(self):
+        self.loss = nn.KLDivLoss(reduction='batchmean')
+
+        self.loss_div_1 = AdvDivLoss(self.opt['parts'])
+        self.loss_div_1.to(self.opt['device'])
+
+        self.loss_div_2 = AdvDivLoss(self.opt['parts'])
+        self.loss_div_2.to(self.opt['device'])
+
+        self.loss_mse = nn.MSELoss()
+
+    def set_optimizer(self):
+        self.optimizer = SGD(self.model.parameters(),
+                             lr=self.opt['lr'],
+                             momentum=self.opt['momentum'],
+                             weight_decay=float(self.opt['weight_decay']))
+
+    def train_epoch(self, epoch):
+        start_time = time.time()
+        for i, (inputs, labels, cls) in enumerate(self.dataloader_train):
+            inputs = inputs.to(self.opt['device'])
+            labels = labels.to(self.opt['device'])
+            cls = cls.to(self.opt['device'])
+            outputs, gcn, fc1, fc2 = self.model(inputs)
+
+            # average the per-part predictions
+            result2 = outputs[0]
+            for j in range(1, self.opt['parts']):
+                result2 = outputs[j] + result2
+            result2 /= self.opt['parts']
+
+            result = self.opt['mu'] * result2 + (1 - self.opt['mu']) * gcn
+
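+            # mu balances the part-average (CNN) branch against the GCN branch;
+            # the shipped config uses mu = 0.6, so the part branch dominates slightly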
+            self.meters_dict['ldl'].update(result, labels, inputs.size(0))
+            prediction = torch.max(result, 1)[1]
+            self.meters_dict['acc'].update(
+                sum(prediction == cls) / inputs.size(0))
+
+            # nn.KLDivLoss expects log-probabilities as its first argument
+            loss_dis_1 = self.loss(torch.log(gcn), labels)
+
+            loss_dis_2 = self.loss(torch.log(outputs[0]), labels)
+            for j in range(1, self.opt['parts']):
+                loss_dis_2 += self.loss(torch.log(outputs[j]), labels)
+
+            loss_dis_2 /= self.opt['parts']
+
+            loss_dis = (loss_dis_1 + loss_dis_2) / 2  # L_pred
+
+            loss_div_1 = self.loss_div_1(fc1)
+            loss_div_2 = self.loss_div_2(fc2)
+
+            loss_div = (loss_div_1 + loss_div_2) / 2
+
+            # rescale the diversity term so its magnitude tracks loss_dis
+            loss = loss_dis + loss_div / (loss_div / loss_dis).detach()
+
+            self.meters_dict['loss'].update(loss.item(), inputs.size(0))
+
+            self.optimizer.zero_grad()
+            loss.backward()
+            self.optimizer.step()
+
+            batch_time = time.time() - start_time
+            self.meters_dict['batch_time'].update(batch_time)
+            start_time = time.time()
+
+            if i % self.opt['display_interval'] == 0:
+                print(
+                    '{}, {} Epoch, {} Iter, Loss: {:.4f}, Batch time: {:.4f}, Acc: {:.4f}'
+                    .format(time.strftime('%Y-%m-%d %H:%M:%S'), epoch, i,
+                            self.meters_dict['loss'].value(),
+                            self.meters_dict['batch_time'].value(),
+                            self.meters_dict['acc'].value()))
+                self.writer.add_scalar('train/acc',
+                                       self.meters_dict['acc'].value(),
+                                       global_step=self.train_steps)
+                self.writer.add_scalar('loss/train_loss',
+                                       self.meters_dict['loss'].value(),
+                                       global_step=self.train_steps)
+            self.train_steps += 1
+
+    def test_epoch(self, epoch):
+        start_time = time.time()
+        for i, (inputs, labels, cls) in enumerate(self.dataloader_test):
+            inputs = inputs.to(self.opt['device'])
+            labels = labels.to(self.opt['device'])
+            cls = cls.to(self.opt['device'])
+
+            with torch.no_grad():
+                outputs, gcn, fc1, fc2 = self.model(inputs)
+
+                result2 = outputs[0]
+                for j in range(1, self.opt['parts']):
+                    result2 = outputs[j] + result2
+                result2 /= self.opt['parts']
+
+                result = self.opt['mu'] * result2 + (1 - self.opt['mu']) * gcn
+
+                self.meters_dict['ldl'].update(result, labels, inputs.size(0))
+                prediction = torch.max(result, 1)[1]
+                self.meters_dict['acc'].update(
+                    sum(prediction == cls) / inputs.size(0))
+
+                loss_dis = self.loss(torch.log(outputs[0]), labels)
+                loss_dis += self.loss(torch.log(gcn), labels)
+                for j in range(1, self.opt['parts']):
+                    loss_dis += self.loss(torch.log(outputs[j]), labels)
+
+                loss_dis /= self.opt['parts'] * 2
+
+                loss_div = self.loss_div_1(fc1)
+                loss_div2 = self.loss_div_2(fc2)
+
+                loss_divv = (loss_div + loss_div2) / 2
+
+                loss = loss_dis + loss_divv / (loss_divv / loss_dis).detach()
+
+            self.meters_dict['loss'].update(loss.item(), inputs.size(0))
+
+            batch_time = time.time() - start_time
+            self.meters_dict['batch_time'].update(batch_time)
+            start_time = time.time()
+            if i % self.opt['display_interval'] == 0:
+                print(
+                    '{}, {} Epoch, {} Iter, Loss: {:.4f}, Batch time: {:.4f}, Acc: {:.4f}'
+                    .format(time.strftime('%Y-%m-%d %H:%M:%S'), epoch, i,
+                            self.meters_dict['loss'].value(),
+                            self.meters_dict['batch_time'].value(),
+                            self.meters_dict['acc'].value()))
+
+                self.writer.add_scalar('test/acc',
+                                       self.meters_dict['acc'].value(),
+                                       global_step=self.test_steps)
+                self.writer.add_scalar('loss/test_loss',
+                                       self.meters_dict['loss'].value(),
+                                       global_step=self.test_steps)
+            self.test_steps += 1
diff --git a/transforms/__init__.py b/transforms/__init__.py
new file mode 100644
index 0000000..e25cc0f
--- /dev/null
+++ b/transforms/__init__.py
@@ -0,0 +1,30 @@
+from .multiscale_crop import MultiScaleCrop
+from .normalize import Normalize
+
+from torchvision import transforms
+
+
+def get_transforms(image_size, mode, dataset, isNormalize=True):
+    transforms_list = []
+
+    # resize (+ random crop/flip augmentation, training only)
+    if mode == 'train':
+        transforms_list.append(
+            transforms.Resize((image_size + 64, image_size + 64)))
+        transforms_list.append(
+            MultiScaleCrop(image_size,
+                           scales=(1.0, 0.875, 0.75, 0.66, 0.5),
+                           max_distort=2))
+        transforms_list.append(transforms.RandomHorizontalFlip(p=0.5))
+    else:
+        transforms_list.append(transforms.Resize((image_size, image_size)))
+
+    # ToTensor
+    transforms_list.append(transforms.ToTensor())
+
+    # normalize
+    if isNormalize:
+        transforms_list.append(Normalize(dataset))
+
+    return transforms.Compose(transforms_list)
diff --git a/transforms/multiscale_crop.py b/transforms/multiscale_crop.py
new file mode 100644
index 0000000..39f61cf
--- /dev/null
+++ b/transforms/multiscale_crop.py
@@ -0,0 +1,94 @@
+# -*- encoding: utf-8 -*-
+
+import random
+from PIL import Image
+
+
+class MultiScaleCrop(object):
+    def __init__(self,
+                 input_size,
+                 scales=None,
+                 max_distort=1,
+                 fix_crop=True,
+                 more_fix_crop=True):
+        self.scales = scales if scales is not None else [1, .875, .75, .66]
+        self.max_distort = max_distort
+        self.fix_crop = fix_crop
+        self.more_fix_crop = more_fix_crop
+        self.input_size = input_size if not isinstance(input_size, int) else [
+            input_size, input_size
+        ]
+        self.interpolation = Image.BILINEAR
+
+    def __call__(self, img):
+        im_size = img.size
+        crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size)
+        crop_img_group = img.crop(
+            (offset_w, offset_h, offset_w + crop_w, offset_h + crop_h))
+        ret_img_group = crop_img_group.resize(
+            (self.input_size[0], self.input_size[1]), self.interpolation)
+        return ret_img_group
+
+    def _sample_crop_size(self, im_size):
+        image_w, image_h = im_size[0], im_size[1]
+
+        # find a crop size
+        base_size = min(image_w, image_h)
+        crop_sizes = [int(base_size * x) for x in self.scales]
+        crop_h = [
+            self.input_size[1] if abs(x - self.input_size[1]) < 3 else x
+            for x in crop_sizes
+        ]
+        crop_w = [
+            self.input_size[0] if abs(x - self.input_size[0]) < 3 else x
+            for x in crop_sizes
+        ]
+
+        pairs = []
+        for i, h in enumerate(crop_h):
+            for j, w in enumerate(crop_w):
+                if abs(i - j) <= self.max_distort:
+                    pairs.append((w, h))
+
+        crop_pair = random.choice(pairs)
+        if not self.fix_crop:
+            w_offset = random.randint(0, image_w - crop_pair[0])
+            h_offset = random.randint(0, image_h - crop_pair[1])
+        else:
+            w_offset, h_offset = self._sample_fix_offset(
+                image_w, image_h, crop_pair[0], crop_pair[1])
+
+        return crop_pair[0], crop_pair[1], w_offset, h_offset
+
+    def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h):
+        offsets = self.fill_fix_offset(self.more_fix_crop, image_w, image_h,
+                                       crop_w, crop_h)
+        return random.choice(offsets)
+
+    @staticmethod
+    def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h):
+        w_step = (image_w - crop_w) // 4
+        h_step = (image_h - crop_h) // 4
+
+        ret = list()
+        ret.append((0, 0))                    # upper left
+        ret.append((4 * w_step, 0))           # upper right
+        ret.append((0, 4 * h_step))           # lower left
+        ret.append((4 * w_step, 4 * h_step))  # lower right
+        ret.append((2 * w_step, 2 * h_step))  # center
+
+        if more_fix_crop:
+            ret.append((0, 2 * h_step))           # center left
+            ret.append((4 * w_step, 2 * h_step))  # center right
+            ret.append((2 * w_step, 4 * h_step))  # lower center
+            ret.append((2 * w_step, 0 * h_step))  # upper center
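+
+            # quarter-position crops, only sampled when more_fix_crop is enabled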
+            ret.append((1 * w_step, 1 * h_step))  # upper left quarter
+            ret.append((3 * w_step, 1 * h_step))  # upper right quarter
+            ret.append((1 * w_step, 3 * h_step))  # lower left quarter
+            ret.append((3 * w_step, 3 * h_step))  # lower right quarter
+
+        return ret
+
+    def __str__(self):
+        return self.__class__.__name__
diff --git a/transforms/normalize.py b/transforms/normalize.py
new file mode 100644
index 0000000..e873531
--- /dev/null
+++ b/transforms/normalize.py
@@ -0,0 +1,29 @@
+# -*- encoding: utf-8 -*-
+from torchvision import transforms
+
+
+class Normalize(object):
+    def __init__(self, dataset) -> None:
+        super().__init__()
+        # per-dataset channel statistics (train split)
+        if dataset == "Emotion6":
+            self.means = [0.41779748, 0.38421513, 0.34800839]
+            self.stdevs = [0.23552664, 0.22541416, 0.21950753]
+        elif dataset == 'Flickr_LDL':
+            self.means = [0.43735039, 0.39944456, 0.36520021]
+            self.stdevs = [0.24785846, 0.23636487, 0.23396503]
+        elif dataset == 'Twitter_LDL':
+            self.means = [0.49303343, 0.4541828, 0.43356296]
+            self.stdevs = [0.25708641, 0.2484328, 0.24492859]
+        elif dataset == 'general':
+            self.means = [0.5, 0.5, 0.5]
+            self.stdevs = [0.5, 0.5, 0.5]
+        else:
+            self.means = [0.49276434, 0.45391981, 0.43331505]
+            self.stdevs = [0.25703167, 0.24834259, 0.24485385]
+
+    def __call__(self, image):
+        return transforms.Normalize(self.means, self.stdevs)(image)
+
+    def __str__(self) -> str:
+        return self.__class__.__name__
diff --git a/utils/__init__.py b/utils/__init__.py
new file mode 100644
index 0000000..dfa8471
--- /dev/null
+++ b/utils/__init__.py
@@ -0,0 +1,5 @@
+from .options import parse_args
+from .tools_torch import seed_torch
+from .opt_yaml import read_yaml, write_yaml, write_result
+from .writer import create_logger, create_summary, clear_log, save_del
+from .metric import AverageMeter, LDL_measurement
diff --git a/utils/metric.py b/utils/metric.py
new file mode 100644
index 0000000..2699151
--- /dev/null
+++ b/utils/metric.py
@@ -0,0 +1,316 @@
+import math
+import torch
+import numpy as np
+
+
+class LDL_measurement(object):
+    '''
+    computes and stores the average LDL measurements
+    '''
+    def __init__(self, num_classes):
+        self.reset()
+        self.num_classes = num_classes
+        self.current_size = 1
+
+    def reset(self):
+        self.current_result = {
+            'klDiv': 0,
+            'cosine': 0,
+            'intersection': 0,
+            'chebyshev': 0,
+            'clark': 0,
+            'canberra': 0,
+            'squareChord': 0,
+            'sorensendist': 0
+        }
+        self.sum = {
+            'klDiv': 0,
+            'cosine': 0,
+            'intersection': 0,
+            'chebyshev': 0,
+            'clark': 0,
+            'canberra': 0,
+            'squareChord': 0,
+            'sorensendist': 0
+        }
+        self.count = 0
+
+    def update(self, output, target, n=1):
+        self.current_size = n
+
+        self.current_result['klDiv'] = self.KLDiv(output, target)
+        self.current_result['cosine'] = self.cosine(output, target)
+        self.current_result['intersection'] = self.intersection(output, target)
+        self.current_result['chebyshev'] = self.chebyshev(output, target)
+        self.current_result['clark'] = self.clark(output, target)
+        self.current_result['canberra'] = self.canberra(output, target)
+        self.current_result['squareChord'] = self.squareChord(output, target)
+        self.current_result['sorensendist'] = self.sorensendist(output, target)
+
+        # accumulate batch sums
+        self.count += n
+        for key in self.sum:
+            self.sum[key] += self.current_result[key]
+
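+    # each metric below returns a batch *sum*, so the averages here are
+    # per-image means over everything seen since the last reset()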
+    def average(self):
+        avg = {}
+        for key in self.sum.keys():
+            avg[key] = self.sum[key] / self.count
+        return avg
+
+    def value(self):
+        current_avg = {}
+        for key in self.current_result.keys():
+            current_avg[key] = self.current_result[key] / self.current_size
+        return current_avg
+
+    def KLDiv(self, output, target):
+        distribution_predict = output.cpu().detach().numpy()
+        distribution_real = target.cpu().detach().numpy()
+        batch_KL = np.nansum(distribution_real *
+                             np.log(distribution_real / distribution_predict))
+        return batch_KL
+
+    def cosine(self, output, target):
+        distribution_predict = output.cpu().detach().numpy()
+        distribution_real = target.cpu().detach().numpy()
+        return np.sum(
+            np.sum(distribution_real * distribution_predict, 1) /
+            (np.sqrt(np.sum(distribution_real**2, 1)) *
+             np.sqrt(np.sum(distribution_predict**2, 1))))
+
+    def intersection(self, output, target):
+        distribution_predict = output.cpu().detach().numpy()
+        distribution_real = target.cpu().detach().numpy()
+
+        concat = np.dstack((distribution_predict, distribution_real))
+        concat = np.min(concat, -1)
+
+        return np.sum(concat)
+
+    def chebyshev(self, output, target):
+        distribution_predict = output.cpu().detach().numpy()
+        distribution_real = target.cpu().detach().numpy()
+        return np.sum(
+            np.max(np.abs(distribution_real - distribution_predict), 1))
+
+    def clark(self, output, target):
+        distribution_predict = output.cpu().detach().numpy()
+        distribution_real = target.cpu().detach().numpy()
+        numerator = (distribution_real - distribution_predict)**2
+        denominator = (distribution_real + distribution_predict)**2
+
+        return np.nansum(
+            np.sqrt(
+                np.nansum(numerator / denominator, axis=1) / self.num_classes))
+
+    def canberra(self, output, target):
+        distribution_predict = output.cpu().detach().numpy()
+        distribution_real = target.cpu().detach().numpy()
+        numerator = np.abs(distribution_real - distribution_predict)
+        denominator = distribution_real + distribution_predict
+        return np.nansum(numerator / denominator) / self.num_classes
+
+    def squareChord(self, output, target):
+        distribution_predict = output.cpu().detach().numpy()
+        distribution_real = target.cpu().detach().numpy()
+        numerator = (np.sqrt(distribution_real) -
+                     np.sqrt(distribution_predict))**2
+        denominator = np.nansum(numerator)
+        return denominator
+
+    def sorensendist(self, output, target):
+        distribution_predict = output.cpu().detach().numpy()
+        distribution_real = target.cpu().detach().numpy()
+        numerator = np.sum(np.abs(distribution_real - distribution_predict), 1)
+        denominator = np.sum(distribution_real + distribution_predict, 1)
+        return np.nansum(numerator / denominator)
+
+
+class AverageMeter(object):
+    """Computes and stores the average and current value"""
+    def __init__(self, name, fmt=':f'):
+        self.name = name
+        self.fmt = fmt
+        self.reset()
+
+    def reset(self):
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.count = 0
+
+    def update(self, val, n=1):
+        self.val = val
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count
+
+    def average(self):
+        return self.avg
+
+    def value(self):
+        return self.val
+
+    def __str__(self):
+        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
+        return fmtstr.format(**self.__dict__)
+
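+
+# The multi-label AP meter below is not exported via utils/__init__.py and is
+# unused by the LDL trainers; it is kept for optional multi-label evaluation.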
+class AveragePrecisionMeter(object):
+    """
+    The APMeter measures the average precision per class.
+    The APMeter is designed to operate on `NxK` Tensors `output` and
+    `target`, and optionally a `Nx1` Tensor weight where (1) the `output`
+    contains model output scores for `N` examples and `K` classes that ought to
+    be higher when the model is more convinced that the example should be
+    positively labeled, and smaller when the model believes the example should
+    be negatively labeled (for instance, the output of a sigmoid function); (2)
+    the `target` contains only values 0 (for negative examples) and 1
+    (for positive examples); and (3) the `weight` ( > 0) represents weight for
+    each sample.
+    """
+    def __init__(self, difficult_examples=True):
+        super(AveragePrecisionMeter, self).__init__()
+        self.reset()
+        self.difficult_examples = difficult_examples
+
+    def reset(self):
+        """Resets the meter with empty member variables"""
+        self.scores = torch.FloatTensor(torch.FloatStorage())
+        self.targets = torch.LongTensor(torch.LongStorage())
+        self.filenames = []
+
+    def add(self, output, target, filename):
+        """
+        Args:
+            output (Tensor): NxK tensor that for each of the N examples
+                indicates the probability of the example belonging to each of
+                the K classes, according to the model. The probabilities should
+                sum to one over all classes
+            target (Tensor): binary NxK tensor that encodes which of the K
+                classes are associated with the N-th input
+                (eg: a row [0, 1, 0, 1] indicates that the example is
+                associated with classes 2 and 4)
+            filename (list): the source file of each example, recorded
+                alongside the scores
+        """
+        if not torch.is_tensor(output):
+            output = torch.from_numpy(output)
+        if not torch.is_tensor(target):
+            target = torch.from_numpy(target)
+
+        if output.dim() == 1:
+            output = output.view(-1, 1)
+        else:
+            assert output.dim() == 2, \
+                'wrong output size (should be 1D or 2D with one column \
+                per class)'
+
+        if target.dim() == 1:
+            target = target.view(-1, 1)
+        else:
+            assert target.dim() == 2, \
+                'wrong target size (should be 1D or 2D with one column \
+                per class)'
+
+        if self.scores.numel() > 0:
+            assert target.size(1) == self.targets.size(1), \
+                'dimensions for output should match previously added examples.'
+
+        # make sure storage is of sufficient size (grow geometrically by 1.5x)
+        if self.scores.storage().size() < self.scores.numel() + output.numel():
+            new_size = math.ceil(self.scores.storage().size() * 1.5)
+            self.scores.storage().resize_(int(new_size + output.numel()))
+            self.targets.storage().resize_(int(new_size + output.numel()))
+
+        # store scores and targets
+        offset = self.scores.size(0) if self.scores.dim() > 0 else 0
+        self.scores.resize_(offset + output.size(0), output.size(1))
+        self.targets.resize_(offset + target.size(0), target.size(1))
+        self.scores.narrow(0, offset, output.size(0)).copy_(output)
+        self.targets.narrow(0, offset, target.size(0)).copy_(target)
+
+        self.filenames += filename  # record filenames
+
+    def value(self):
+        """Returns the model's average precision for each class
+        Return:
+            ap (FloatTensor): 1xK tensor, with avg precision for each class k
+        """
+        if self.scores.numel() == 0:
+            return 0
+        ap = torch.zeros(self.scores.size(1))
+        # compute average precision for each class
+        for k in range(self.scores.size(1)):
+            # sort scores
+            scores = self.scores[:, k]
+            targets = self.targets[:, k]
+            # compute average precision
+            ap[k] = AveragePrecisionMeter.average_precision(
+                scores, targets, self.difficult_examples)
+        return ap
+
+    @staticmethod
+    def average_precision(output, target, difficult_examples=True):
+        # sort examples by descending score
+        sorted, indices = torch.sort(output, dim=0, descending=True)
+
+        # Computes prec@i
+        pos_count = 0.
+        total_count = 0.
+        precision_at_i = 0.
+        for i in indices:
+            label = target[i]
+            if difficult_examples and label == 0:
+                continue
+            if label == 1:
+                pos_count += 1
+            total_count += 1
+            if label == 1:
+                precision_at_i += pos_count / total_count
+        precision_at_i /= pos_count
+        return precision_at_i
+
+    def overall(self):
+        if self.scores.numel() == 0:
+            return 0
+        scores = self.scores.cpu().numpy()
+        targets = self.targets.clone().cpu().numpy()
+        targets[targets == -1] = 0
+        return self.evaluation(scores, targets)
+
+    def overall_topk(self, k):
+        targets = self.targets.clone().cpu().numpy()
+        targets[targets == -1] = 0
+        n, c = self.scores.size()
+        scores = np.zeros((n, c)) - 1
+        index = self.scores.topk(k, 1, True, True)[1].cpu().numpy()
+        tmp = self.scores.cpu().numpy()
+        for i in range(n):
+            for ind in index[i]:
+                scores[i, ind] = 1 if tmp[i, ind] >= 0 else -1
+        return self.evaluation(scores, targets)
+
+    def evaluation(self, scores_, targets_):
+        n, n_class = scores_.shape
+        Nc, Np, Ng = np.zeros(n_class), np.zeros(n_class), np.zeros(n_class)
+        for k in range(n_class):
+            scores = scores_[:, k]
+            targets = targets_[:, k]
+            targets[targets == -1] = 0
+            Ng[k] = np.sum(targets == 1)
+            Np[k] = np.sum(scores >= 0)
+            Nc[k] = np.sum(targets * (scores >= 0))
+        Np[Np == 0] = 1
+        OP = np.sum(Nc) / np.sum(Np)
+        OR = np.sum(Nc) / np.sum(Ng)
+        OF1 = (2 * OP * OR) / (OP + OR)
+
+        CP = np.sum(Nc / Np) / n_class
+        CR = np.sum(Nc / Ng) / n_class
+        CF1 = (2 * CP * CR) / (CP + CR)
+        return OP, OR, OF1, CP, CR, CF1
diff --git a/utils/opt_yaml.py b/utils/opt_yaml.py
new file mode 100644
index 0000000..1c9394d
--- /dev/null
+++ b/utils/opt_yaml.py
@@ -0,0 +1,88 @@
+# -*- encoding: utf-8 -*-
+import os
+import yaml
+from collections import OrderedDict
+
+
+def read_yaml(cfg_path, isResume=False):
+    if not isResume:
+        path = os.path.join('configs', cfg_path)
+        if os.path.exists(path):
+            with open(path, 'r', encoding='utf-8') as f:
+                cfg = ordered_yaml_load(f.read())
+        else:
+            cfg = {}
+    else:
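+        # on resume, cfg_path is the full path of the config.yaml that
+        # write_yaml() saved into the run's log directory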
+
+
+def write_yaml(path, cfg):
+    del_list = []
+
+    # drop the scheduler settings other than the selected one, so the saved
+    # config only records the scheduler actually used
+    for k, v in cfg.items():
+        if 'scheduler_' in k and k != cfg['scheduler']:
+            del_list.append(k)
+
+    for k in del_list:
+        del cfg[k]
+
+    with open(os.path.join('logs', path, 'config.yaml'), 'w',
+              encoding='utf-8') as f:
+        ordered_yaml_dump(cfg,
+                          f,
+                          default_flow_style=False,
+                          allow_unicode=True,
+                          indent=4)
+
+
+def write_result(path, epoch, acc, ldl, mode):
+    if mode == 'train':
+        name = 'train.txt'
+    elif mode == 'test':
+        name = 'test.txt'
+    with open(os.path.join('logs', path, name), 'a+') as fid:
+        if epoch == 0:
+            fid.write(
+                '{:^10},{:^10},{:^20},{:^20},{:^20},{:^20},{:^20},{:^20},{:^20},{:^20}\n'
+                .format('epoch', 'acc', 'klDiv', 'cosine', 'intersection',
+                        'chebyshev', 'squareChord', 'sorensendist', 'canberra',
+                        'clark'))
+        fid.write(
+            '{:5}Epoch,{:10},{:20},{:20},{:20},{:20},{:20},{:20},{:20},{:20}\n'
+            .format(epoch, acc, ldl['klDiv'], ldl['cosine'],
+                    ldl['intersection'], ldl['chebyshev'], ldl['squareChord'],
+                    ldl['sorensendist'], ldl['canberra'], ldl['clark']))
+
+
+def ordered_yaml_load(stream,
+                      Loader=yaml.SafeLoader,
+                      object_pairs_hook=OrderedDict):
+    class OrderedLoader(Loader):
+        pass
+
+    def _construct_mapping(loader, node):
+        loader.flatten_mapping(node)
+        return object_pairs_hook(loader.construct_pairs(node))
+
+    OrderedLoader.add_constructor(
+        yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG, _construct_mapping)
+    return yaml.load(stream, OrderedLoader)
+
+
+def ordered_yaml_dump(data,
+                      stream=None,
+                      Dumper=yaml.SafeDumper,
+                      object_pairs_hook=OrderedDict,
+                      **kwds):
+    class OrderedDumper(Dumper):
+        pass
+
+    def _dict_representer(dumper, data):
+        return dumper.represent_mapping(
+            yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG, data.items())
+
+    OrderedDumper.add_representer(object_pairs_hook, _dict_representer)
+    return yaml.dump(data, stream, OrderedDumper, **kwds)
diff --git a/utils/options.py b/utils/options.py
new file mode 100644
index 0000000..9f9c49e
--- /dev/null
+++ b/utils/options.py
@@ -0,0 +1,41 @@
+# -*- encoding: utf-8 -*-
+
+import argparse
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument('--tag',
+                        '-t',
+                        type=str,
+                        default='cache',
+                        help='folder name to save the outputs')
+
+    parser.add_argument('--batch_size',
+                        '-b',
+                        type=int,
+                        help="input batch size")
+
+    parser.add_argument('--mu',
+                        type=float,
+                        help="balance parameter between cnn and gcn")
+
+    parser.add_argument('--lambda',
+                        type=float,
+                        help="balance parameter between mean and max")
+
+    parser.add_argument('--resume_path',
+                        '-r',
+                        type=str,
+                        help="path of the checkpoint to resume from")
+
+    parser.add_argument('--specific_cfg',
+                        '-s',
+                        type=str,
+                        default=None,
+                        help="path of an extra config file to load")
+
+    parser.add_argument('--gpu-id', type=int, help="GPU index")
+
+    return parser.parse_args()
diff --git a/utils/tools_torch.py b/utils/tools_torch.py
new file mode 100644
index 0000000..77b8b30
--- /dev/null
+++ b/utils/tools_torch.py
@@ -0,0 +1,15 @@
+# -*- encoding: utf-8 -*-
+import os
+import torch
+import random
+import numpy as np
+
+
+def seed_torch(seed):
+    random.seed(seed)
+    os.environ['PYTHONHASHSEED'] = str(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)  # also seed every GPU, not just the current one
+    torch.backends.cudnn.benchmark = False
+    torch.backends.cudnn.deterministic = True
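+
+
+# Usage sketch (illustrative): call once at process startup, before the
+# model and DataLoader workers are created, e.g. seed_torch(cfg['seed']);
+# the configs in this patch define seed: 330.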
diff --git a/utils/writer.py b/utils/writer.py
new file mode 100644
index 0000000..e82a289
--- /dev/null
+++ b/utils/writer.py
@@ -0,0 +1,62 @@
+# -*- encoding: utf-8 -*-
+import os
+import shutil
+import logging
+from datetime import datetime
+from torch.utils.tensorboard import SummaryWriter
+
+
+def create_summary(tag):
+    if tag != 'cache':
+        name = datetime.now().strftime('%b-%d_%H:%M:%S') + '_' + tag
+    else:
+        name = 'cache'
+
+    writer_dir = os.path.join("./runs", name)
+    if not os.path.exists(writer_dir):
+        os.makedirs(writer_dir)
+
+    writer = SummaryWriter(writer_dir)
+    return writer, name
+
+
+def create_logger(name):
+    logger = logging.getLogger()
+    logger.setLevel(level=logging.INFO)
+
+    file_path = os.path.join('./logs', name)
+    if not os.path.exists(file_path):
+        os.makedirs(file_path)
+
+    handler = logging.FileHandler(os.path.join(file_path, 'train.log'),
+                                  encoding="utf-8")
+    handler.setFormatter(
+        logging.Formatter(
+            "%(asctime)s - %(name)s - %(levelname)s - %(message)s"))
+
+    console = logging.StreamHandler()
+    console.setLevel(logging.INFO)
+    console.setFormatter(logging.Formatter("%(message)s"))
+
+    logger.addHandler(handler)
+    logger.addHandler(console)
+
+    return logger
+
+
+def clear_log(name):
+    # remove this run's TensorBoard and log directories;
+    # shutil.rmtree replaces the original shell call to `rm -r`
+    for path in ['runs', 'logs']:
+        p = os.path.join(path, name)
+        if os.path.isdir(p):
+            shutil.rmtree(p)
+
+
+def save_del(name):
+    while True:
+        ans = input("\nKeep the results of this training run? (yes/no) ")
+        if ans == 'no':
+            clear_log(name)
+            break
+        elif ans == 'yes':
+            break
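+
+
+# Typical lifecycle (editor's sketch; `args` and the training loop are
+# assumed to live in main.py):
+#   writer, name = create_summary(args.tag)  # TensorBoard run under ./runs/<name>
+#   logger = create_logger(name)             # mirrors output to ./logs/<name>/train.log
+#   ...train, calling writer.add_scalar(...) and logger.info(...)...
+#   save_del(name)  # interactively keep or discard this run's artifacts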