From 74421aef4bdc2857cae3f427b6a9458d91633791 Mon Sep 17 00:00:00 2001 From: Piotr Czapla Date: Wed, 14 Mar 2018 16:05:14 +0100 Subject: [PATCH] Moved method that depends on self.y to the class that has it The self.y is being used in resize_imps and get_n. It is assumed that it has `len()` defined so it is a list of an array. So FilesArrayDataset seems to be a better match for the methods. --- fastai/dataset.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/fastai/dataset.py b/fastai/dataset.py index 6a7bc0c89..7aaac1ea7 100644 --- a/fastai/dataset.py +++ b/fastai/dataset.py @@ -209,14 +209,9 @@ class FilesDataset(BaseDataset): def __init__(self, fnames, transform, path): self.path,self.fnames = path,fnames super().__init__(transform) - def get_n(self): return len(self.y) def get_sz(self): return self.transform.sz def get_x(self, i): return open_image(os.path.join(self.path, self.fnames[i])) - def resize_imgs(self, targ, new_path): - dest = resize_imgs(self.fnames, targ, self.path, new_path) - return self.__class__(self.fnames, self.y, self.transform, dest) - def denorm(self,arr): """Reverse the normalization done to a batch of images. @@ -232,10 +227,14 @@ def __init__(self, fnames, y, transform, path): self.y=y assert(len(fnames)==len(y)) super().__init__(fnames, transform, path) + def get_n(self): return len(self.y) def get_y(self, i): return self.y[i] def get_c(self): return self.y.shape[1] if len(self.y.shape)>1 else 0 + def resize_imgs(self, targ, new_path): + dest = resize_imgs(self.fnames, targ, self.path, new_path) + return self.__class__(self.fnames, self.y, self.transform, dest) class FilesIndexArrayDataset(FilesArrayDataset): def get_c(self): return int(self.y.max())+1