Skip to content

Commit

Permalink
styleEDL
Browse files Browse the repository at this point in the history
  • Loading branch information
Wangjii committed May 1, 2023
0 parents commit 173e565
Show file tree
Hide file tree
Showing 27 changed files with 2,132 additions and 0 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
*.pyc
runs/
logs/
.vscode/
21 changes: 21 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@

## 运行
```bash
python main.py --tag
```
+ `--tag` 指定保存目录的前缀,默认为cache



## 添加新的模型

1.`network`文件夹下新建新的文件
1.`network``__init__`中的`models`加入新的网络模型
1.`configs`文件夹中新建对应的配置文件


## 添加新的trainer

1.`trainer`文件夹中新建文件,继承`basetrainer`
1.`trainer``__init__`中的`trainers`加入新的trainer
1. 运行时,由配置文件中的`trainer`指定
65 changes: 65 additions & 0 deletions configs/base.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
model: backbone
trainer: backbone

# ===========================
# train
seed: 330
gpu_id: 0
epochs: 90
batch_size: 8


image_size: 448
isNormalize: True
num_workers: 12

dataset: Twitter_LDL
num_classes: 8
data_path: /media/Harddisk_A/emotion_dataset/


resume_path:
save_interval: 2000
save_mark: 0.4
display_interval: 20

# ===========================
# optimizer
momentum: 0.1

# ===========================
# learning rate
lr: 0.01

# ===========================
# scheduler
scheduler: scheduler_cos

# stepLR
scheduler_stepLR:
step_size: 15
gamma: 0.5

# MultiStepLR
scheduler_multi:
milestones: [10, 20, 50]
gamma: 0.5

# ExponentialLR
scheduler_exp:
gamma: 0.5

# CosineAnnealingLR
scheduler_cos:
T_max: 40
eta_min: 0

# CyclicLR
scheduler_cyclic:
max_lr: 0.05
up: 10
down: 10

# lambda
scheduler_lambda:
lr_lambda: None
28 changes: 28 additions & 0 deletions configs/train.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
model: gmg
trainer: gmg
seed: 330


mu: 0.6
lambda: 0.6
parts: 2


gpu_id: 1
batch_size: 16

epochs: 90

lr: 0.005
scheduler: scheduler_multi
scheduler_multi:
milestones: [10, 30, 50, 70]
gamma: 0.1

scheduler_cos:
T_max: 40
eta_min: 0

momentum: 0.1
weight_decay: 1e-7
# resume_path: ./logs/Mar-01_14:45:22_final/epoch_55.pth
1 change: 1 addition & 0 deletions dataset/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .dataset import Dataset_LDL
48 changes: 48 additions & 0 deletions dataset/dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# -*- encoding: utf-8 -*-

import os
import torch
from PIL import Image, ImageFile
from torch.utils.data import Dataset

ImageFile.LOAD_TRUNCATED_IMAGES = True


class Dataset_LDL(Dataset):
def __init__(
self,
root_path,
mode,
transforms,
):
super().__init__()
self.transforms = transforms
self.images_path = []
self.labels = []
self.cls = []

if mode == "train":
file_name = 'ground_truth_train.txt'
elif mode == "test":
file_name = 'ground_truth_test.txt'

# 读入文件
with open(os.path.join(root_path, file_name), 'r') as f:
file = f.readlines()

for line in file:
temp = line.rstrip('\n').rstrip(' ').split(' ')
self.images_path.append(os.path.join(root_path, "images", temp[0]))
label = [eval(i) for i in temp[1:-1]]
self.cls.append(eval(temp[-1]))
self.labels.append(label)

def __getitem__(self, index):
original_img = Image.open(self.images_path[index]).convert('RGB')
label = torch.FloatTensor(self.labels[index])
cls = self.cls[index]
original_img = self.transforms(original_img)
return original_img, label, cls

def __len__(self) -> int:
return len(self.images_path)
176 changes: 176 additions & 0 deletions envs.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
name: torch
channels:
- pytorch
- defaults
dependencies:
- _libgcc_mutex=0.1=main
- _openmp_mutex=5.1=1_gnu
- asttokens=2.0.5=pyhd3eb1b0_0
- backcall=0.2.0=pyhd3eb1b0_0
- blas=1.0=mkl
- bzip2=1.0.8=h7b6447c_0
- ca-certificates=2022.10.11=h06a4308_0
- charset-normalizer=2.0.4=pyhd3eb1b0_0
- cudatoolkit=11.3.1=h2bc3f7f_2
- decorator=5.1.1=pyhd3eb1b0_0
- executing=0.8.3=pyhd3eb1b0_0
- ffmpeg=4.2.2=h20bf706_0
- fftw=3.3.9=h27cfd23_1
- filelock=3.6.0=pyhd3eb1b0_0
- freetype=2.12.1=h4a9f257_0
- ftfy=5.8=py_0
- giflib=5.2.1=h7b6447c_0
- gmp=6.2.1=h295c915_3
- gnutls=3.6.15=he1e5248_0
- huggingface_hub=0.2.1=pyhd3eb1b0_0
- importlib_metadata=4.11.3=hd3eb1b0_0
- intel-openmp=2021.4.0=h06a4308_3561
- jpeg=9e=h7f8727e_0
- jupyter_client=7.3.5=py310h06a4308_0
- jupyter_core=4.11.1=py310h06a4308_0
- lame=3.100=h7b6447c_0
- lcms2=2.12=h3be6417_0
- ld_impl_linux-64=2.38=h1181459_1
- lerc=3.0=h295c915_0
- libdeflate=1.8=h7f8727e_5
- libffi=3.3=he6710b0_2
- libgcc-ng=11.2.0=h1234567_1
- libgfortran-ng=11.2.0=h00389a5_1
- libgfortran5=11.2.0=h1234567_1
- libgomp=11.2.0=h1234567_1
- libidn2=2.3.2=h7f8727e_0
- libopus=1.3.1=h7b6447c_0
- libpng=1.6.37=hbc83047_0
- libsodium=1.0.18=h7b6447c_0
- libstdcxx-ng=11.2.0=h1234567_1
- libtasn1=4.16.0=h27cfd23_0
- libtiff=4.4.0=hecacb30_0
- libunistring=0.9.10=h27cfd23_0
- libuuid=1.0.3=h7f8727e_2
- libuv=1.40.0=h7b6447c_0
- libvpx=1.7.0=h439df22_0
- libwebp=1.2.4=h11a3e52_0
- libwebp-base=1.2.4=h5eee18b_0
- lz4-c=1.9.3=h295c915_1
- mkl=2021.4.0=h06a4308_640
- mkl_fft=1.3.1=py310hd6ae3a3_0
- mkl_random=1.2.2=py310h00e6091_0
- ncurses=6.3=h5eee18b_3
- nettle=3.7.3=hbbd107a_1
- numpy-base=1.23.3=py310h8e6c178_0
- openh264=2.1.1=h4ff587b_0
- openssl=1.1.1s=h7f8727e_0
- packaging=21.3=pyhd3eb1b0_0
- parso=0.8.3=pyhd3eb1b0_0
- pexpect=4.8.0=pyhd3eb1b0_3
- pickleshare=0.7.5=pyhd3eb1b0_1003
- prompt-toolkit=3.0.20=pyhd3eb1b0_0
- ptyprocess=0.7.0=pyhd3eb1b0_2
- pure_eval=0.2.2=pyhd3eb1b0_0
- pycparser=2.21=pyhd3eb1b0_0
- pygments=2.11.2=pyhd3eb1b0_0
- pyopenssl=22.0.0=pyhd3eb1b0_0
- python=3.10.6=haa1d7c7_1
- python-dateutil=2.8.2=pyhd3eb1b0_0
- pytorch=1.11.0=py3.10_cuda11.3_cudnn8.2.0_0
- pytorch-mutex=1.0=cuda
- readline=8.2=h5eee18b_0
- sacremoses=0.0.43=pyhd3eb1b0_0
- six=1.16.0=pyhd3eb1b0_1
- sqlite=3.39.3=h5082296_0
- stack_data=0.2.0=pyhd3eb1b0_0
- threadpoolctl=2.2.0=pyh0d69192_0
- tk=8.6.12=h1ccaba5_0
- traitlets=5.1.1=pyhd3eb1b0_0
- typing_extensions=4.3.0=py310h06a4308_0
- tzdata=2022f=h04d1e81_0
- wcwidth=0.2.5=pyhd3eb1b0_0
- wheel=0.37.1=pyhd3eb1b0_0
- x264=1!157.20191217=h7b6447c_0
- xz=5.2.6=h5eee18b_0
- yaml=0.2.5=h7b6447c_0
- zeromq=4.3.4=h2531618_0
- zlib=1.2.13=h5eee18b_0
- zstd=1.5.2=ha4553b6_0
- pip:
- absl-py==1.3.0
- bottleneck==1.3.5
- brotlipy==0.7.0
- cachetools==5.2.0
- certifi==2022.12.7
- cffi==1.15.1
- click==8.0.4
- contourpy==1.0.7
- cryptography==38.0.1
- cycler==0.11.0
- debugpy==1.5.1
- easydict==1.10
- entrypoints==0.4
- fonttools==4.39.2
- fvcore==0.1.5.post20221221
- google-auth==2.15.0
- google-auth-oauthlib==0.4.6
- grpcio==1.51.1
- idna==3.4
- importlib-metadata==4.11.3
- iopath==0.1.10
- ipykernel==6.15.2
- ipython==8.4.0
- jedi==0.18.1
- jieba==0.42.1
- joblib==1.1.1
- jupyter-client==7.3.5
- jupyter-core==4.11.1
- kiwisolver==1.4.4
- markdown==3.4.1
- markupsafe==2.1.1
- matplotlib==3.7.1
- matplotlib-inline==0.1.6
- mkl-fft==1.3.1
- mkl-random==1.2.2
- mkl-service==2.4.0
- nest-asyncio==1.5.5
- numexpr==2.8.3
- numpy==1.23.3
- nvidia-ml-py3==7.352.0
- oauthlib==3.2.2
- pandas==1.4.4
- pillow==9.2.0
- pip==22.2.2
- portalocker==2.7.0
- protobuf==3.20.3
- psutil==5.9.0
- pyasn1==0.4.8
- pyasn1-modules==0.2.8
- pyparsing==3.0.9
- pysocks==1.7.1
- pytz==2022.1
- pyyaml==6.0
- pyzmq==23.2.0
- regex==2022.7.9
- requests==2.28.1
- requests-oauthlib==1.3.1
- rsa==4.9
- scikit-learn==1.1.1
- scipy==1.9.1
- setuptools==65.5.0
- tabulate==0.9.0
- tensorboard==2.11.0
- tensorboard-data-server==0.6.1
- tensorboard-plugin-wit==1.8.1
- termcolor==2.2.0
- tokenizers==0.11.4
- torch==1.11.0
- torchaudio==0.11.0
- torchsummary==1.5.1
- torchvision==0.12.0
- tornado==6.2
- tqdm==4.64.1
- transformers==4.18.0
- typing-extensions==4.3.0
- urllib3==1.26.12
- werkzeug==2.2.2
- yacs==0.1.8
- yapf==0.32.0
- zipp==3.8.0
prefix: /home/wj/.conda/envs/mmsa
44 changes: 44 additions & 0 deletions loss/EMD.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import torch
import torch.nn as nn


def single_emd_loss(p, q, r=2):
"""
Earth Mover's Distance of one sample
Args:
p: true distribution of shape num_classes × 1
q: estimated distribution of shape num_classes × 1
r: norm parameter
"""
assert p.shape == q.shape, "Length of the two distribution must be the same"
length = p.shape[0]
emd_loss = 0.0
for i in range(1, length + 1):
emd_loss += torch.abs(sum(p[:i] - q[:i]))**r
return (emd_loss / length)**(1. / r)


def emd_loss(p, q, r=2):
"""
Earth Mover's Distance on a batch
Args:
p: true distribution of shape mini_batch_size × num_classes × 1
q: estimated distribution of shape mini_batch_size × num_classes × 1
r: norm parameters
"""
assert p.shape == q.shape, "Shape of the two distribution batches must be the same."
mini_batch_size = p.shape[0]
loss_vector = []
for i in range(mini_batch_size):
loss_vector.append(single_emd_loss(p[i], q[i], r=r))
return sum(loss_vector) / mini_batch_size


class EMDLoss(nn.Module):
def __init__(self) -> None:
super(EMDLoss, self).__init__()

def forward(self, x, y):
return emd_loss(x, y)
3 changes: 3 additions & 0 deletions loss/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .adv_div import AdvDivLoss, ProgressiveCircularLoss
from .EMD import EMDLoss
from .share_specific import SharedAndSpecificLoss
Loading

0 comments on commit 173e565

Please sign in to comment.