Force labels to be int64; no observable difference in speed or VRAM compared to switching to int32 or int16
gcervantes8 committed Dec 20, 2023
1 parent 6a23658 commit 9a51b12
Showing 1 changed file with 1 addition and 4 deletions.
5 changes: 1 addition & 4 deletions src/data/data_load.py
@@ -69,8 +69,6 @@ def get_num_classes(data_config):
     data_loader = data_loader_from_config(data_config)
     return len(data_loader.dataset.classes)
 
-def to_int16(label):
-    return torch.tensor(label, dtype=torch.int16)
 
 def create_data_loader(data_dir: str, image_height: int, image_width: int, image_dtype=torch.float16, using_gpu=False,
                        batch_size=1, n_workers=1):
@@ -80,9 +78,8 @@ def create_data_loader(data_dir: str, image_height: int, image_width: int, image
         transforms.ToDtype(image_dtype, scale=True),  # Float16 is tiny bit faster, and bit more VRAM. Strange.
         transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
     ])
-    label_transform = to_int16
     try:
-        data_set = torch_data_set.ImageFolder(root=data_dir, transform=data_transform, target_transform=label_transform)
+        data_set = torch_data_set.ImageFolder(root=data_dir, transform=data_transform)
     except FileNotFoundError:
         raise FileNotFoundError('Data directory provided should contain directories that have images in them, '
                                 'directory provided: ' + data_dir)
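The change leans on PyTorch defaults rather than an explicit cast: with no target_transform, ImageFolder yields plain Python int labels, and the DataLoader's default collation promotes them to int64 tensors, which is also the dtype torch.nn.CrossEntropyLoss expects for class-index targets. A minimal sketch to verify this (not part of the commit; the dataset path is a placeholder):

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Placeholder path: any directory laid out as class subfolders of images works.
data_set = datasets.ImageFolder(root='path/to/images', transform=transforms.ToTensor())
loader = DataLoader(data_set, batch_size=4)

images, labels = next(iter(loader))
print(labels.dtype)  # torch.int64 -- default collation builds int64 tensors from Python int labels
assert labels.dtype == torch.int64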
