diff --git a/code/datasets/scene_dataset.py b/code/datasets/scene_dataset.py index 7e6ffca..5275bed 100644 --- a/code/datasets/scene_dataset.py +++ b/code/datasets/scene_dataset.py @@ -13,7 +13,8 @@ def __init__(self, data_dir, img_res, scene_id='scan0', - cam_file=None + cam_file=None, + views=None ): self.instance_dir = os.path.join('../data', data_dir, '{0}'.format(scene_id)) @@ -31,15 +32,20 @@ def __init__(self, mask_dir = '{0}/mask'.format(self.instance_dir) mask_paths = sorted(utils.glob_imgs(mask_dir)) - self.n_images = len(image_paths) + if views: + self.views = views + image_paths = [image_paths[v] for v in self.views] + mask_paths = [mask_paths[v] for v in self.views] + else: + self.views = list(range(len(image_paths))) self.cam_file = '{0}/cameras.npz'.format(self.instance_dir) if cam_file is not None: self.cam_file = '{0}/{1}'.format(self.instance_dir, cam_file) camera_dict = np.load(self.cam_file) - scale_mats = [camera_dict['scale_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)] - world_mats = [camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)] + scale_mats = [camera_dict['scale_mat_%d' % idx].astype(np.float32) for idx in self.views] + world_mats = [camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in self.views] self.intrinsics_all = [] self.pose_all = [] @@ -63,7 +69,7 @@ def __init__(self, self.object_masks.append(torch.from_numpy(object_mask).bool()) def __len__(self): - return self.n_images + return len(self.views) def __getitem__(self, idx): uv = np.mgrid[0:self.img_res[0], 0:self.img_res[1]].astype(np.int32) @@ -119,8 +125,8 @@ def get_scale_mat(self): def get_gt_pose(self, scaled=False): # Load gt pose without normalization to unit sphere camera_dict = np.load(self.cam_file) - world_mats = [camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)] - scale_mats = [camera_dict['scale_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)] + world_mats = [camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in self.views] + scale_mats = [camera_dict['scale_mat_%d' % idx].astype(np.float32) for idx in self.views] pose_all = [] for scale_mat, world_mat in zip(scale_mats, world_mats): @@ -137,8 +143,8 @@ def get_pose_init(self): # get noisy initializations obtained with the linear method cam_file = '{0}/cameras_linear_init.npz'.format(self.instance_dir) camera_dict = np.load(cam_file) - scale_mats = [camera_dict['scale_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)] - world_mats = [camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)] + scale_mats = [camera_dict['scale_mat_%d' % idx].astype(np.float32) for idx in self.views] + world_mats = [camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in self.views] init_pose = [] for scale_mat, world_mat in zip(scale_mats, world_mats):