diff --git a/src/data/data_load.py b/src/data/data_load.py index 2265593..33e7d78 100644 --- a/src/data/data_load.py +++ b/src/data/data_load.py @@ -69,8 +69,6 @@ def get_num_classes(data_config): data_loader = data_loader_from_config(data_config) return len(data_loader.dataset.classes) -def to_int16(label): - return torch.tensor(label, dtype=torch.int16) def create_data_loader(data_dir: str, image_height: int, image_width: int, image_dtype=torch.float16, using_gpu=False, batch_size=1, n_workers=1): @@ -80,9 +78,8 @@ def create_data_loader(data_dir: str, image_height: int, image_width: int, image transforms.ToDtype(image_dtype, scale=True), # Float16 is tiny bit faster, and bit more VRAM. Strange. transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), ]) - label_transform = to_int16 try: - data_set = torch_data_set.ImageFolder(root=data_dir, transform=data_transform, target_transform=label_transform) + data_set = torch_data_set.ImageFolder(root=data_dir, transform=data_transform) except FileNotFoundError: raise FileNotFoundError('Data directory provided should contain directories that have images in them, ' 'directory provided: ' + data_dir)