-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathddp_test_nerf.py
executable file
·137 lines (112 loc) · 5.15 KB
/
ddp_test_nerf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
#!/usr/bin/env python3
import torch
# import torch.nn as nn
import torch.optim
import torch.distributed
# from torch.nn.parallel import DistributedDataParallel as DDP
import torch.multiprocessing
import numpy as np
import os
# from collections import OrderedDict
# from ddp_model import NerfNet
import time
from data_loader_split import load_event_data_split
from utils import mse2psnr, colorize_np, to8b
import imageio
from ddp_train_nerf import config_parser, setup_logger, setup, cleanup, render_single_image, create_nerf
import logging
from nerf_sample_ray_split import CameraManager
# Module-level logger shared by the functions below. It is named after
# ``__package__``; logging.getLogger returns the root logger when that
# name is None or empty.
logger = logging.getLogger(__package__)
def _output_filename(sampler, idx):
    """Return the file name under which the render for ``sampler`` is saved.

    Falls back to a zero-padded index (``000012.png``) when the sampler has
    no ``img_path``. When the image path has no '.' in its base name, the
    base name of ``rgb_path`` is used instead.
    """
    if sampler.img_path is None:
        return '{:06d}.png'.format(idx)

    fname = os.path.basename(sampler.img_path)
    if '.' not in fname:
        if sampler.rgb_path is not None:
            fname = os.path.basename(sampler.rgb_path)
        else:
            # os.path.basename(None) raises TypeError; fall back to the
            # alternative that was left commented out in the original code.
            fname = fname + '.png'
    return fname


def _save_render_outputs(outputs, sampler, out_dir, fname):
    """Write the images of one rendered level to ``out_dir``.

    ``outputs`` is a mapping that holds at least 'rgb', 'fg_rgb' and
    'fg_depth' entries exposing ``.numpy()``; 'fg_ldist' is written only if
    present. When the sampler has an ``img_path``, the PSNR against
    ``sampler.get_img()`` is logged before the image is saved.
    """
    rgb = outputs['rgb'].numpy()
    # compute psnr if ground-truth is available
    if sampler.img_path is not None:
        gt_im = sampler.get_img()
        psnr = mse2psnr(np.mean((gt_im - rgb) * (gt_im - rgb)))
        logger.info('{}: psnr={}'.format(fname, psnr))
    imageio.imwrite(os.path.join(out_dir, fname), to8b(rgb))

    fg_rgb = to8b(outputs['fg_rgb'].numpy())
    imageio.imwrite(os.path.join(out_dir, 'fg_' + fname), fg_rgb)

    # NOTE: saving of the background outputs (bg_rgb / bg_depth) was
    # commented out in the original code, so they are not written here.

    fg_depth = colorize_np(outputs['fg_depth'].numpy(),
                           cmap_name='jet', append_cbar=True)
    imageio.imwrite(os.path.join(out_dir, 'fg_depth_' + fname), to8b(fg_depth))

    if 'fg_ldist' in outputs:
        fg_ldist = colorize_np(outputs['fg_ldist'].numpy(),
                               cmap_name='jet', append_cbar=True)
        imageio.imwrite(os.path.join(out_dir, 'fg_ldist_' + fname),
                        to8b(fg_ldist))


def ddp_test_nerf(rank, args):
    """Render every requested split with a trained model (one worker process).

    Intended as the target of ``torch.multiprocessing.spawn``: every process
    runs the full loop and calls ``render_single_image``, but only rank 0
    creates the output directory and writes images.

    Args:
        rank: index of this process; also used as the CUDA device index when
            querying GPU memory.
        args: parsed configuration. ``args.N_rand`` and ``args.chunk_size``
            are overwritten here according to the GPU memory size.

    Images that already exist in the output directory are skipped, so an
    interrupted run can be resumed.
    """
    ###### set up multi-processing
    setup(rank, args.world_size)
    ###### set up logger
    setup_logger()

    ###### decide chunk size according to gpu memory
    total_mem_gb = torch.cuda.get_device_properties(rank).total_memory / 1e9
    if total_mem_gb > 14:
        logger.info('setting batch size according to 24G gpu')
        args.N_rand = 1024
        args.chunk_size = 8192
    else:
        logger.info('setting batch size according to 12G gpu')
        args.N_rand = 512
        args.chunk_size = 4096

    ###### create network and wrap in ddp; each process should do this
    camera_mgr = CameraManager(learnable=False)
    start, models = create_nerf(rank, args, camera_mgr, load_camera_mgr=False)

    render_splits = [x.strip() for x in args.render_splits.strip().split(',')]

    # start testing
    for split in render_splits:
        out_dir = os.path.join(args.basedir, args.expname,
                               'render_{}_{:06d}'.format(split, start))
        if rank == 0:
            os.makedirs(out_dir, exist_ok=True)

        ###### load data and create ray samplers; each process should do this
        ray_samplers = load_event_data_split(
            args.datadir, args.scene, split,
            camera_mgr=models['camera_mgr'], max_winsize=1,
            use_ray_jitter=args.use_ray_jitter, is_colored=args.is_colored,
            polarity_offset=args.polarity_offset,
            skip=args.testskip, cycle=args.is_cycled,
            is_rgb_only=args.is_rgb_only)

        for idx, sampler in enumerate(ray_samplers):
            ### each process should do this; but only main process merges the results
            fname = _output_filename(sampler, idx)

            if os.path.isfile(os.path.join(out_dir, fname)):
                logger.info('Skipping {}'.format(fname))
                continue

            time0 = time.time()
            ret = render_single_image(rank, args.world_size, models, sampler,
                                      args.chunk_size, start)
            dt = time.time() - time0

            if rank == 0:  # only main process should do this
                logger.info('Rendered {} in {} seconds'.format(fname, dt))
                # only save last level
                _save_render_outputs(ret[-1], sampler, out_dir, fname)

            torch.cuda.empty_cache()

    # clean up for multi-processing
    cleanup()
def test():
    """Parse the configuration and launch one rendering process per GPU.

    ``args.world_size == -1`` means "use every visible CUDA device"; the
    value is then replaced by ``torch.cuda.device_count()``.

    Raises:
        RuntimeError: if the resulting world size is smaller than 1 (for
            example when no CUDA device is visible), since there would be
            no worker process to do the rendering.
    """
    parser = config_parser()
    args = parser.parse_args()
    logger.info(parser.format_values())

    if args.world_size == -1:
        args.world_size = torch.cuda.device_count()
        logger.info('Using # gpus: {}'.format(args.world_size))

    # Fail loudly instead of asking spawn to start zero processes.
    if args.world_size < 1:
        raise RuntimeError(
            'world_size must be at least 1 (got {}); '
            'are any CUDA devices visible?'.format(args.world_size))

    torch.multiprocessing.spawn(ddp_test_nerf,
                                args=(args,),
                                nprocs=args.world_size,
                                join=True)
# Script entry point: configure logging in the parent process, then run test().
if __name__ == '__main__':
    setup_logger()
    test()