-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 173e565
Showing
27 changed files
with
2,132 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
*.pyc | ||
runs/ | ||
logs/ | ||
.vscode/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
|
||
## 运行 | ||
```bash | ||
python main.py --tag | ||
``` | ||
+ `--tag` 指定保存目录的前缀,默认为cache | ||
|
||
|
||
|
||
## 添加新的模型 | ||
|
||
1. 在`network`文件夹下新建新的文件 | ||
1. 在`network`下`__init__`中的`models`加入新的网络模型 | ||
1. 在`configs`文件夹中新建对应的配置文件 | ||
|
||
|
||
## 添加新的trainer | ||
|
||
1. 在`trainer`文件夹中新建文件,继承`basetrainer` | ||
1. 在`trainer`下`__init__`中的`trainers`加入新的trainer | ||
1. 运行时,由配置文件中的`trainer`指定 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
model: backbone | ||
trainer: backbone | ||
|
||
# =========================== | ||
# train | ||
seed: 330 | ||
gpu_id: 0 | ||
epochs: 90 | ||
batch_size: 8 | ||
|
||
|
||
image_size: 448 | ||
isNormalize: True | ||
num_workers: 12 | ||
|
||
dataset: Twitter_LDL | ||
num_classes: 8 | ||
data_path: /media/Harddisk_A/emotion_dataset/ | ||
|
||
|
||
resume_path: | ||
save_interval: 2000 | ||
save_mark: 0.4 | ||
display_interval: 20 | ||
|
||
# =========================== | ||
# optimizer | ||
momentum: 0.1 | ||
|
||
# =========================== | ||
# learning rate | ||
lr: 0.01 | ||
|
||
# =========================== | ||
# scheduler | ||
scheduler: scheduler_cos | ||
|
||
# stepLR | ||
scheduler_stepLR: | ||
step_size: 15 | ||
gamma: 0.5 | ||
|
||
# MultiStepLR | ||
scheduler_multi: | ||
milestones: [10, 20, 50] | ||
gamma: 0.5 | ||
|
||
# ExponentialLR | ||
scheduler_exp: | ||
gamma: 0.5 | ||
|
||
# CosineAnnealingLR | ||
scheduler_cos: | ||
T_max: 40 | ||
eta_min: 0 | ||
|
||
# CyclicLR | ||
scheduler_cyclic: | ||
max_lr: 0.05 | ||
up: 10 | ||
down: 10 | ||
|
||
# lambda | ||
scheduler_lambda: | ||
lr_lambda: None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
model: gmg | ||
trainer: gmg | ||
seed: 330 | ||
|
||
|
||
mu: 0.6 | ||
lambda: 0.6 | ||
parts: 2 | ||
|
||
|
||
gpu_id: 1 | ||
batch_size: 16 | ||
|
||
epochs: 90 | ||
|
||
lr: 0.005 | ||
scheduler: scheduler_multi | ||
scheduler_multi: | ||
milestones: [10, 30, 50, 70] | ||
gamma: 0.1 | ||
|
||
scheduler_cos: | ||
T_max: 40 | ||
eta_min: 0 | ||
|
||
momentum: 0.1 | ||
weight_decay: 1e-7 | ||
# resume_path: ./logs/Mar-01_14:45:22_final/epoch_55.pth |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .dataset import Dataset_LDL |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
# -*- encoding: utf-8 -*- | ||
|
||
import os | ||
import torch | ||
from PIL import Image, ImageFile | ||
from torch.utils.data import Dataset | ||
|
||
ImageFile.LOAD_TRUNCATED_IMAGES = True | ||
|
||
|
||
class Dataset_LDL(Dataset): | ||
def __init__( | ||
self, | ||
root_path, | ||
mode, | ||
transforms, | ||
): | ||
super().__init__() | ||
self.transforms = transforms | ||
self.images_path = [] | ||
self.labels = [] | ||
self.cls = [] | ||
|
||
if mode == "train": | ||
file_name = 'ground_truth_train.txt' | ||
elif mode == "test": | ||
file_name = 'ground_truth_test.txt' | ||
|
||
# 读入文件 | ||
with open(os.path.join(root_path, file_name), 'r') as f: | ||
file = f.readlines() | ||
|
||
for line in file: | ||
temp = line.rstrip('\n').rstrip(' ').split(' ') | ||
self.images_path.append(os.path.join(root_path, "images", temp[0])) | ||
label = [eval(i) for i in temp[1:-1]] | ||
self.cls.append(eval(temp[-1])) | ||
self.labels.append(label) | ||
|
||
def __getitem__(self, index): | ||
original_img = Image.open(self.images_path[index]).convert('RGB') | ||
label = torch.FloatTensor(self.labels[index]) | ||
cls = self.cls[index] | ||
original_img = self.transforms(original_img) | ||
return original_img, label, cls | ||
|
||
def __len__(self) -> int: | ||
return len(self.images_path) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,176 @@ | ||
name: torch | ||
channels: | ||
- pytorch | ||
- defaults | ||
dependencies: | ||
- _libgcc_mutex=0.1=main | ||
- _openmp_mutex=5.1=1_gnu | ||
- asttokens=2.0.5=pyhd3eb1b0_0 | ||
- backcall=0.2.0=pyhd3eb1b0_0 | ||
- blas=1.0=mkl | ||
- bzip2=1.0.8=h7b6447c_0 | ||
- ca-certificates=2022.10.11=h06a4308_0 | ||
- charset-normalizer=2.0.4=pyhd3eb1b0_0 | ||
- cudatoolkit=11.3.1=h2bc3f7f_2 | ||
- decorator=5.1.1=pyhd3eb1b0_0 | ||
- executing=0.8.3=pyhd3eb1b0_0 | ||
- ffmpeg=4.2.2=h20bf706_0 | ||
- fftw=3.3.9=h27cfd23_1 | ||
- filelock=3.6.0=pyhd3eb1b0_0 | ||
- freetype=2.12.1=h4a9f257_0 | ||
- ftfy=5.8=py_0 | ||
- giflib=5.2.1=h7b6447c_0 | ||
- gmp=6.2.1=h295c915_3 | ||
- gnutls=3.6.15=he1e5248_0 | ||
- huggingface_hub=0.2.1=pyhd3eb1b0_0 | ||
- importlib_metadata=4.11.3=hd3eb1b0_0 | ||
- intel-openmp=2021.4.0=h06a4308_3561 | ||
- jpeg=9e=h7f8727e_0 | ||
- jupyter_client=7.3.5=py310h06a4308_0 | ||
- jupyter_core=4.11.1=py310h06a4308_0 | ||
- lame=3.100=h7b6447c_0 | ||
- lcms2=2.12=h3be6417_0 | ||
- ld_impl_linux-64=2.38=h1181459_1 | ||
- lerc=3.0=h295c915_0 | ||
- libdeflate=1.8=h7f8727e_5 | ||
- libffi=3.3=he6710b0_2 | ||
- libgcc-ng=11.2.0=h1234567_1 | ||
- libgfortran-ng=11.2.0=h00389a5_1 | ||
- libgfortran5=11.2.0=h1234567_1 | ||
- libgomp=11.2.0=h1234567_1 | ||
- libidn2=2.3.2=h7f8727e_0 | ||
- libopus=1.3.1=h7b6447c_0 | ||
- libpng=1.6.37=hbc83047_0 | ||
- libsodium=1.0.18=h7b6447c_0 | ||
- libstdcxx-ng=11.2.0=h1234567_1 | ||
- libtasn1=4.16.0=h27cfd23_0 | ||
- libtiff=4.4.0=hecacb30_0 | ||
- libunistring=0.9.10=h27cfd23_0 | ||
- libuuid=1.0.3=h7f8727e_2 | ||
- libuv=1.40.0=h7b6447c_0 | ||
- libvpx=1.7.0=h439df22_0 | ||
- libwebp=1.2.4=h11a3e52_0 | ||
- libwebp-base=1.2.4=h5eee18b_0 | ||
- lz4-c=1.9.3=h295c915_1 | ||
- mkl=2021.4.0=h06a4308_640 | ||
- mkl_fft=1.3.1=py310hd6ae3a3_0 | ||
- mkl_random=1.2.2=py310h00e6091_0 | ||
- ncurses=6.3=h5eee18b_3 | ||
- nettle=3.7.3=hbbd107a_1 | ||
- numpy-base=1.23.3=py310h8e6c178_0 | ||
- openh264=2.1.1=h4ff587b_0 | ||
- openssl=1.1.1s=h7f8727e_0 | ||
- packaging=21.3=pyhd3eb1b0_0 | ||
- parso=0.8.3=pyhd3eb1b0_0 | ||
- pexpect=4.8.0=pyhd3eb1b0_3 | ||
- pickleshare=0.7.5=pyhd3eb1b0_1003 | ||
- prompt-toolkit=3.0.20=pyhd3eb1b0_0 | ||
- ptyprocess=0.7.0=pyhd3eb1b0_2 | ||
- pure_eval=0.2.2=pyhd3eb1b0_0 | ||
- pycparser=2.21=pyhd3eb1b0_0 | ||
- pygments=2.11.2=pyhd3eb1b0_0 | ||
- pyopenssl=22.0.0=pyhd3eb1b0_0 | ||
- python=3.10.6=haa1d7c7_1 | ||
- python-dateutil=2.8.2=pyhd3eb1b0_0 | ||
- pytorch=1.11.0=py3.10_cuda11.3_cudnn8.2.0_0 | ||
- pytorch-mutex=1.0=cuda | ||
- readline=8.2=h5eee18b_0 | ||
- sacremoses=0.0.43=pyhd3eb1b0_0 | ||
- six=1.16.0=pyhd3eb1b0_1 | ||
- sqlite=3.39.3=h5082296_0 | ||
- stack_data=0.2.0=pyhd3eb1b0_0 | ||
- threadpoolctl=2.2.0=pyh0d69192_0 | ||
- tk=8.6.12=h1ccaba5_0 | ||
- traitlets=5.1.1=pyhd3eb1b0_0 | ||
- typing_extensions=4.3.0=py310h06a4308_0 | ||
- tzdata=2022f=h04d1e81_0 | ||
- wcwidth=0.2.5=pyhd3eb1b0_0 | ||
- wheel=0.37.1=pyhd3eb1b0_0 | ||
- x264=1!157.20191217=h7b6447c_0 | ||
- xz=5.2.6=h5eee18b_0 | ||
- yaml=0.2.5=h7b6447c_0 | ||
- zeromq=4.3.4=h2531618_0 | ||
- zlib=1.2.13=h5eee18b_0 | ||
- zstd=1.5.2=ha4553b6_0 | ||
- pip: | ||
- absl-py==1.3.0 | ||
- bottleneck==1.3.5 | ||
- brotlipy==0.7.0 | ||
- cachetools==5.2.0 | ||
- certifi==2022.12.7 | ||
- cffi==1.15.1 | ||
- click==8.0.4 | ||
- contourpy==1.0.7 | ||
- cryptography==38.0.1 | ||
- cycler==0.11.0 | ||
- debugpy==1.5.1 | ||
- easydict==1.10 | ||
- entrypoints==0.4 | ||
- fonttools==4.39.2 | ||
- fvcore==0.1.5.post20221221 | ||
- google-auth==2.15.0 | ||
- google-auth-oauthlib==0.4.6 | ||
- grpcio==1.51.1 | ||
- idna==3.4 | ||
- importlib-metadata==4.11.3 | ||
- iopath==0.1.10 | ||
- ipykernel==6.15.2 | ||
- ipython==8.4.0 | ||
- jedi==0.18.1 | ||
- jieba==0.42.1 | ||
- joblib==1.1.1 | ||
- jupyter-client==7.3.5 | ||
- jupyter-core==4.11.1 | ||
- kiwisolver==1.4.4 | ||
- markdown==3.4.1 | ||
- markupsafe==2.1.1 | ||
- matplotlib==3.7.1 | ||
- matplotlib-inline==0.1.6 | ||
- mkl-fft==1.3.1 | ||
- mkl-random==1.2.2 | ||
- mkl-service==2.4.0 | ||
- nest-asyncio==1.5.5 | ||
- numexpr==2.8.3 | ||
- numpy==1.23.3 | ||
- nvidia-ml-py3==7.352.0 | ||
- oauthlib==3.2.2 | ||
- pandas==1.4.4 | ||
- pillow==9.2.0 | ||
- pip==22.2.2 | ||
- portalocker==2.7.0 | ||
- protobuf==3.20.3 | ||
- psutil==5.9.0 | ||
- pyasn1==0.4.8 | ||
- pyasn1-modules==0.2.8 | ||
- pyparsing==3.0.9 | ||
- pysocks==1.7.1 | ||
- pytz==2022.1 | ||
- pyyaml==6.0 | ||
- pyzmq==23.2.0 | ||
- regex==2022.7.9 | ||
- requests==2.28.1 | ||
- requests-oauthlib==1.3.1 | ||
- rsa==4.9 | ||
- scikit-learn==1.1.1 | ||
- scipy==1.9.1 | ||
- setuptools==65.5.0 | ||
- tabulate==0.9.0 | ||
- tensorboard==2.11.0 | ||
- tensorboard-data-server==0.6.1 | ||
- tensorboard-plugin-wit==1.8.1 | ||
- termcolor==2.2.0 | ||
- tokenizers==0.11.4 | ||
- torch==1.11.0 | ||
- torchaudio==0.11.0 | ||
- torchsummary==1.5.1 | ||
- torchvision==0.12.0 | ||
- tornado==6.2 | ||
- tqdm==4.64.1 | ||
- transformers==4.18.0 | ||
- typing-extensions==4.3.0 | ||
- urllib3==1.26.12 | ||
- werkzeug==2.2.2 | ||
- yacs==0.1.8 | ||
- yapf==0.32.0 | ||
- zipp==3.8.0 | ||
prefix: /home/wj/.conda/envs/mmsa |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
import torch | ||
import torch.nn as nn | ||
|
||
|
||
def single_emd_loss(p, q, r=2): | ||
""" | ||
Earth Mover's Distance of one sample | ||
Args: | ||
p: true distribution of shape num_classes × 1 | ||
q: estimated distribution of shape num_classes × 1 | ||
r: norm parameter | ||
""" | ||
assert p.shape == q.shape, "Length of the two distribution must be the same" | ||
length = p.shape[0] | ||
emd_loss = 0.0 | ||
for i in range(1, length + 1): | ||
emd_loss += torch.abs(sum(p[:i] - q[:i]))**r | ||
return (emd_loss / length)**(1. / r) | ||
|
||
|
||
def emd_loss(p, q, r=2): | ||
""" | ||
Earth Mover's Distance on a batch | ||
Args: | ||
p: true distribution of shape mini_batch_size × num_classes × 1 | ||
q: estimated distribution of shape mini_batch_size × num_classes × 1 | ||
r: norm parameters | ||
""" | ||
assert p.shape == q.shape, "Shape of the two distribution batches must be the same." | ||
mini_batch_size = p.shape[0] | ||
loss_vector = [] | ||
for i in range(mini_batch_size): | ||
loss_vector.append(single_emd_loss(p[i], q[i], r=r)) | ||
return sum(loss_vector) / mini_batch_size | ||
|
||
|
||
class EMDLoss(nn.Module): | ||
def __init__(self) -> None: | ||
super(EMDLoss, self).__init__() | ||
|
||
def forward(self, x, y): | ||
return emd_loss(x, y) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .adv_div import AdvDivLoss, ProgressiveCircularLoss | ||
from .EMD import EMDLoss | ||
from .share_specific import SharedAndSpecificLoss |
Oops, something went wrong.