-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
lisb20
committed
Jul 26, 2023
1 parent
4ebba60
commit 55bcb24
Showing
23 changed files
with
1,577 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
model_configs: | ||
model_name: PAC_Net | ||
pretrained: true | ||
rnn_type: gru | ||
rnn_hdim: 128 | ||
v_loss: true | ||
warm_up: 32 | ||
max_peo: 3 | ||
|
||
dataset_configs: | ||
dataset_root: /mnt/petrelfs/share_data/lisibo/NLOS/data_render_mot/ # fill your dataset path here! | ||
# dataset_root: ./dataset/real_shot_new | ||
# dataset_root: ../dataset/render | ||
data_type: real_shot | ||
train_ratio: 0.8 | ||
route_len: 128 | ||
# total_len: 250 | ||
noise_factor: 0 | ||
noisy: false | ||
max_peo: 3 | ||
|
||
loader_kwargs: | ||
num_workers: 8 | ||
pin_memory: true | ||
prefetch_factor: 8 | ||
persistent_workers: true | ||
|
||
train_configs: | ||
project_name: my_project # fill your wandb project name here! | ||
# resume: True | ||
# resume_path: 2023_07_03_15_56_48/ | ||
resume: false | ||
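resume_path: None # NOTE: YAML reads a bare None as the string "None", not null; write null (or ~) for an empty value | ||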
batch_size: 16 | ||
seed: 1026 | ||
device: cuda:0 | ||
amp: true | ||
v_loss_alpha: 500 | ||
x_loss_alpha: 1 | ||
m_loss_alpha: 100 | ||
|
||
loss_total_alpha: 1000 | ||
|
||
optim_kwargs: | ||
optimizer: AdamW | ||
lr: 3.0e-4 | ||
weight_decay: 2.0e-3 | ||
|
||
schedule_configs: | ||
schedule_type: cosine | ||
max_epoch: 120 | ||
cos_T: 70 | ||
cos_iters: 1 | ||
cos_mul: 2 | ||
|
||
distributed_configs: | ||
distributed: false | ||
gpu_ids: 0 | ||
device_ids: 1 | ||
world_size: 1 | ||
local_rank: 0 | ||
port: 6666 | ||
|
||
log_configs: | ||
log_dir: log # fill your log dir here! | ||
save_epoch_interval: 5 | ||
snapshot_interval: 100 |
Empty file.
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
import os | ||
import random | ||
|
||
import numpy as np | ||
import torch | ||
from torch.utils.data import Dataset | ||
from torch.utils.data.dataset import random_split | ||
from scipy.io import loadmat | ||
import pdb | ||
|
||
from .loader import npy_loader | ||
|
||
|
||
class TrackingDataset(Dataset):
    """Video clips paired with per-person 2-D routes.

    The dataset root is expected to contain the sub-directories ``1``, ``2``
    and ``3``. Every entry inside them is treated as one sequence directory
    holding a video ``.npy`` file and a ``route.mat`` file with the keys
    ``route`` and ``map_size``.

    Each item is a randomly positioned window of ``route_len`` frames, so
    repeated access to the same index may return different windows.
    """

    # Sub-directories of the dataset root that are scanned.
    PEOPLE_COUNTS = (1, 2, 3)
    # At most this many sequence directories are kept per sub-directory.
    MAX_SEQS_PER_COUNT = 600

    def __init__(self,
                 dataset_root: str,
                 data_type: str = 'render',
                 route_len: int = 128,
                 use_fileclient: bool = False,
                 noisy: bool = True,
                 max_peo: int = 3,
                 **kwags) -> None:
        """Index the sequence directories below ``dataset_root``.

        Args:
            dataset_root: Directory containing the ``1``/``2``/``3`` folders.
            data_type: ``'render'`` or ``'real_shot'``; selects the video file
                name (only ``'render'`` has a noisy variant).
            route_len: Number of frames in every returned window.
            use_fileclient: Load ``.npy`` files through ``npy_loader`` instead
                of ``np.load``.
            noisy: For ``'render'`` data, read ``video_128_noisy.npy``.
            max_peo: Routes are padded to ``2 * max_peo`` columns.
            **kwags: Ignored; lets callers pass a full config dict.

        Raises:
            ValueError: If ``data_type`` is not a supported value.
        """
        self.dataset_root = dataset_root
        self.max_peo = max_peo
        self.num_frames = route_len
        if data_type == 'render':
            self.total_len = 256
            self.npy_name = 'video_128_noisy.npy' if noisy else 'video_128.npy'
        elif data_type == 'real_shot':
            self.total_len = 256
            self.npy_name = 'video_128.npy'
        else:
            # Fail here instead of with an AttributeError on first item access.
            raise ValueError(
                f"unsupported data_type {data_type!r}; "
                "expected 'render' or 'real_shot'"
            )

        self.dataset_dir = dataset_root
        self.dirs = []
        for peo in self.PEOPLE_COUNTS:
            sub_dir = str(peo)
            # NOTE(review): os.listdir returns entries in arbitrary order, so
            # the subset kept by the cap below depends on the filesystem.
            entries = os.listdir(os.path.join(self.dataset_dir, sub_dir))
            self.dirs += [os.path.join(sub_dir, d)
                          for d in entries[:self.MAX_SEQS_PER_COUNT]]

        print('dirs: ', len(self.dirs))

        if use_fileclient:
            self.npy_loader = npy_loader()
            self.load_npy = self.npy_loader.get_item
        else:
            self.load_npy = np.load

    def __len__(self):
        """Return the number of indexed sequence directories."""
        return len(self.dirs)

    def __getitem__(self, idx):
        """Return ``(video, route, map_size)`` for a random window of a sequence.

        Returns:
            video: Tensor sliced along axis 1 to the chosen window
                (the original comment gives the layout as (3, T, H, W)).
            route: Float tensor of shape ``(route_len, 2 * max_peo)``. People
                are ordered by the sum of their first coordinate over the
                window; unused people slots are filled with 0.5.
            map_size: Float tensor, the stored ``map_size`` tiled ``max_peo``
                times along axis 1 (stored shape is (1, 2) per the original
                comment -- TODO confirm).

        Raises:
            ValueError: If the padded route does not have the expected shape.
        """
        seq_dir = os.path.join(self.dataset_dir, self.dirs[idx])
        video = self.load_npy(os.path.join(seq_dir, self.npy_name))

        start_frame = random.randint(0, self.total_len - self.num_frames)
        stop_frame = start_frame + self.num_frames
        video = video[:, start_frame:stop_frame]

        mat_file = loadmat(os.path.join(seq_dir, 'route.mat'))

        route = mat_file['route'][start_frame:stop_frame]
        route = route.reshape((route.shape[0], -1))  # (T, 2n)

        # Reorder people by the sum of their x coordinate over the window so
        # the output slots have a consistent assignment. sorted() is stable,
        # so ties keep their original relative order.
        npeo = route.shape[1] // 2
        order = sorted(range(npeo), key=lambda p: sum(route[:, 2 * p]))
        reordered = np.zeros(route.shape)
        for slot, person in enumerate(order):
            reordered[:, 2 * slot:2 * slot + 2] = route[:, 2 * person:2 * person + 2]
        route = reordered

        # Pad missing people with 0.5 so every sample has 2 * max_peo columns.
        padding = np.ones((route.shape[0], self.max_peo * 2 - route.shape[1])) * 0.5
        route = np.concatenate((route, padding), axis=1)
        if route.shape != (self.num_frames, self.max_peo * 2):
            raise ValueError(
                f"route in {seq_dir!r} has shape {route.shape}, expected "
                f"{(self.num_frames, self.max_peo * 2)}"
            )

        # Repeat map_size once per people slot so it lines up with route.
        map_size = np.tile(mat_file['map_size'], (1, self.max_peo))

        return (torch.from_numpy(video),
                torch.from_numpy(route).float(),
                torch.from_numpy(map_size).float())
|
||
|
||
def split_dataset(phase: str = 'train', train_ratio: float = 0.8, **kwargs):
    """Build a ``TrackingDataset`` and return it whole or split for training.

    Args:
        phase: ``'train'`` returns the result of ``random_split`` with a
            train part and a validation part; ``'test'`` returns the full
            dataset.
        train_ratio: Fraction of samples assigned to the train part (the
            train size is truncated to an integer, the rest is validation).
        **kwargs: Forwarded unchanged to ``TrackingDataset``.

    Raises:
        ValueError: If ``phase`` is neither ``'train'`` nor ``'test'``.
    """
    # Validate before scanning the dataset so a typo fails fast instead of
    # silently returning None.
    if phase not in ('train', 'test'):
        raise ValueError(f"unknown phase {phase!r}; expected 'train' or 'test'")

    full_dataset = TrackingDataset(**kwargs)
    if phase == 'test':
        return full_dataset

    train_size = int(len(full_dataset) * train_ratio)
    val_size = len(full_dataset) - train_size
    return random_split(full_dataset, [train_size, val_size])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
import io | ||
import os | ||
from typing import Tuple, Union | ||
|
||
import numpy as np | ||
import mmcv | ||
import torch | ||
from torchvision.io import read_image | ||
from torchvision.transforms.functional import resize | ||
|
||
|
||
def load_frames(
        root: str,
        frame_range: Union[None, Tuple[int, int]] = None,
        output_size: Union[None, Tuple[int, int]] = None,
        rgb_only=True
) -> torch.Tensor:
    """Load the ``.png`` files of a directory into one tensor.

    Args:
        root: Directory that is listed for files ending in ``.png``; the
            names are processed in sorted order.
        frame_range: Optional ``(start, stop)`` slice applied to the sorted
            file list.
        output_size: If given, the stacked frames are passed to ``resize``
            with this size.
        rgb_only: When the first image has 4 channels, keep only the first
            three channels of every image.

    Returns:
        Tensor laid out as (T, C, H, W).
    """
    png_names = sorted(name for name in os.listdir(root) if name.endswith('.png'))
    if frame_range is not None:
        first, last = frame_range[0], frame_range[1]
        png_names = png_names[first:last]
    png_paths = [os.path.join(root, name) for name in png_names]

    # The first image decides the buffer geometry for the whole clip.
    channels, height, width = read_image(png_paths[0]).shape
    drop_alpha = channels == 4 and rgb_only
    kept_channels = 3 if drop_alpha else channels

    clip = torch.zeros((len(png_names), kept_channels, height, width))
    for index, path in enumerate(png_paths):
        image = read_image(path)  # (C, H, W)
        clip[index] = image[:3] if drop_alpha else image

    if output_size is None:
        return clip
    return resize(clip, size=output_size)
|
||
|
||
class npy_loader(object):
    """Read ``.npy`` files through an mmcv ``FileClient`` on the petrel backend."""

    def __init__(self):
        """Create the file client used by ``get_item``."""
        self.file_client = mmcv.fileio.FileClient(backend='petrel')

    def get_item(self, file_path: str):
        """Fetch ``file_path`` via the file client and decode it with ``np.load``."""
        raw_bytes = self.file_client.get(file_path)
        # np.load needs a file-like object, so wrap the fetched bytes in memory.
        stream = io.BytesIO(raw_bytes)
        with stream:
            array = np.load(stream)
        return array
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
from torch import Tensor | ||
from torchvision.transforms.functional import resize | ||
|
||
|
||
def sub_mean(frames: Tensor) -> Tensor:
    """Subtract the mean over axis 0 from every entry along axis 0."""
    return frames - frames.mean(dim=0, keepdim=True)
|
||
|
||
def diff(frames: Tensor) -> Tensor:
    """Return the difference between consecutive entries along axis 0.

    The result has one entry fewer along axis 0 than the input.
    """
    later, earlier = frames[1:], frames[:-1]
    return later - earlier
|
||
|
||
def normalize(frame: Tensor):
    """Min-max scale ``frame`` so its values span [0, 1].

    A frame whose values are all equal has zero range; dividing by it would
    produce NaN everywhere, so such a frame is mapped to all zeros instead.
    """
    lowest = frame.min()
    value_range = frame.max() - lowest
    # Replace a zero range by 1 so a constant frame yields 0 rather than 0/0.
    value_range = value_range.masked_fill(value_range == 0, 1)
    return (frame - lowest) / value_range
|
||
|
||
def resize_video(frames: Tensor, bias_ratio: float = None, output_size: tuple = (128, 128)) -> Tensor:
    """Crop an H-wide window along the width of each frame, then optionally resize.

    Args:
        frames: Tensor unpacked as (T, C, H, W).
        bias_ratio: If given, the centred crop window is shifted left by
            ``int(W * bias_ratio)`` pixels (a negative ratio shifts right).
        output_size: Size passed to ``resize``; ``None`` skips resizing.

    Returns:
        The cropped (and possibly resized) frames.
    """
    T, C, H, W = frames.shape
    crop_idx = (W - H) // 2
    if bias_ratio is not None:
        crop_idx -= int(W * bias_ratio)
    # Keep the window inside the frame: a negative or too-large offset would
    # otherwise slice an empty or truncated region.
    crop_idx = max(0, min(crop_idx, W - H))
    output_frames = frames[:, :, :, crop_idx:crop_idx + H]
    if output_size is not None:
        output_frames = resize(output_frames, size=output_size)

    return output_frames
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
import argparse | ||
import os | ||
|
||
import torch | ||
import torch.multiprocessing as mp | ||
import yaml | ||
import pdb | ||
|
||
|
||
def main(cfg):
    """Prepare distributed settings from ``cfg`` and launch the worker(s).

    The world size is the number of entries in
    ``cfg['distributed_configs']['device_ids']``. With more than one entry a
    process per device is spawned; otherwise ``worker`` runs in this process.
    ``cfg`` is updated in place (``distributed``, ``world_size`` and the
    per-process ``loader_kwargs['batch_size']``).

    Args:
        cfg: Nested configuration dict with at least ``distributed_configs``,
            ``log_configs``, ``loader_kwargs`` and ``train_configs``.
    """
    dist_cfgs = cfg['distributed_configs']

    os.makedirs(cfg['log_configs']['log_dir'], exist_ok=True)

    # device_ids is a comma-separated string when set from the command line,
    # but a config file can supply a bare integer or a list; accept all three.
    device_ids = dist_cfgs['device_ids']
    if isinstance(device_ids, (list, tuple)):
        world_size = len(device_ids)
    else:
        world_size = len(str(device_ids).split(','))

    dist_cfgs['distributed'] = world_size > 1
    dist_cfgs['world_size'] = world_size
    # Split the global batch evenly across processes.
    cfg['loader_kwargs']['batch_size'] = cfg['train_configs']['batch_size'] // world_size

    print("Allocating workers...")
    if dist_cfgs['distributed']:
        mp.spawn(worker, nprocs=world_size, args=(cfg,))
    else:
        worker(0, cfg)
|
||
|
||
def worker(rank, cfg):
    """Run training in one process.

    Selects CUDA device ``rank``, records ``rank`` as
    ``cfg['distributed_configs']['local_rank']``, then builds and runs a
    ``Trainer_tracking`` with ``cfg``.
    """
    torch.cuda.set_device(rank)
    dist_section = cfg['distributed_configs']
    dist_section['local_rank'] = rank

    # Imported here rather than at module level, as in the original code.
    from utils.trainer import Trainer_tracking
    Trainer_tracking(cfg).run()
|
||
|
||
if __name__ == '__main__':
    # ---- command-line interface -------------------------------------------
    cli = argparse.ArgumentParser()

    cli.add_argument('--cfg_file', type=str, default='default')

    # NOTE(review): --model_name is required but its value is never written
    # into the config below; model_configs.model_name comes from the YAML file.
    cli.add_argument('--model_name', type=str, required=True)
    cli.add_argument('--warm_up', type=int, default=32)
    cli.add_argument('--pretrained', action="store_true")
    cli.add_argument('--rnn_hdim', type=int, default=128)

    cli.add_argument('-b', '--batch_size', type=int, default=32)
    cli.add_argument('--v_loss_alpha', type=float, default=500)
    cli.add_argument('--loss_total_alpha', type=float, default=1000)
    cli.add_argument('-r', '--resume', action='store_true', help='load previously saved checkpoint')

    cli.add_argument('-lr_b', '--lr_backbone', type=float, default=3e-4)
    cli.add_argument('-wd', '--weight_decay', type=float, default=2.0e-3)

    cli.add_argument('-T', '--cos_T', type=int, default=70)

    cli.add_argument('-g', '--gpu_ids', type=lambda x: x.replace(" ", ""), default='0',
                     help='available gpu ids')
    cli.add_argument('--port', type=str, default='6666', help='port number of distributed init')

    args = cli.parse_args()

    # ---- load the YAML config ----------------------------------------------
    config_file = os.path.join('configs', f'{args.cfg_file}.yaml')
    print(f'Reading config file: {config_file}')
    with open(config_file, 'r') as stream:
        config = yaml.load(stream, Loader=yaml.FullLoader)

    # ---- apply command-line values -----------------------------------------
    # Every assignment below is unconditional, so the argparse defaults replace
    # the corresponding YAML values whenever a flag is omitted.
    config['model_configs'].update(
        warm_up=args.warm_up,
        pretrained=args.pretrained,
        rnn_hdim=args.rnn_hdim,
    )

    # The loaded route length is extended by the warm-up frame count.
    config['dataset_configs']['route_len'] += args.warm_up

    config['train_configs'].update(
        batch_size=args.batch_size,
        v_loss_alpha=args.v_loss_alpha,
        loss_total_alpha=args.loss_total_alpha,
        resume=args.resume,
    )

    config['optim_kwargs'].update(
        lr=args.lr_backbone,
        weight_decay=args.weight_decay,
    )

    config['schedule_configs']['cos_T'] = args.cos_T

    config['distributed_configs'].update(
        device_ids=args.gpu_ids,
        port=args.port,
    )

    main(config)
## seed 1026 | ||
# srun -p optimal --quotatype=auto --gres=gpu:1 -J NLOS_lisibo python train.py --model_name PAC-Net --pretrained --warm_up 32 -b 16 |
Empty file.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Oops, something went wrong.