Skip to content

Commit

Permalink
Moved method that depends on self.y to the class that has it
Browse files Browse the repository at this point in the history
The self.y is being used in resize_imps and get_n. It is assumed that
it has `len()` defined so it is a list of an array.

So FilesArrayDataset seems to be a better match for the methods.
  • Loading branch information
Piotr Czapla committed Mar 14, 2018
1 parent 00a363b commit 74421ae
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions fastai/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,14 +209,9 @@ class FilesDataset(BaseDataset):
def __init__(self, fnames, transform, path):
self.path,self.fnames = path,fnames
super().__init__(transform)
def get_n(self): return len(self.y)
def get_sz(self): return self.transform.sz
def get_x(self, i): return open_image(os.path.join(self.path, self.fnames[i]))

def resize_imgs(self, targ, new_path):
dest = resize_imgs(self.fnames, targ, self.path, new_path)
return self.__class__(self.fnames, self.y, self.transform, dest)

def denorm(self,arr):
"""Reverse the normalization done to a batch of images.
Expand All @@ -232,10 +227,14 @@ def __init__(self, fnames, y, transform, path):
self.y=y
assert(len(fnames)==len(y))
super().__init__(fnames, transform, path)
def get_n(self): return len(self.y)
def get_y(self, i): return self.y[i]
def get_c(self):
return self.y.shape[1] if len(self.y.shape)>1 else 0

def resize_imgs(self, targ, new_path):
dest = resize_imgs(self.fnames, targ, self.path, new_path)
return self.__class__(self.fnames, self.y, self.transform, dest)

class FilesIndexArrayDataset(FilesArrayDataset):
def get_c(self): return int(self.y.max())+1
Expand Down

0 comments on commit 74421ae

Please sign in to comment.