Skip to content

Commit

Permalink
Refactor asa tests for better isolation (#690)
Browse files Browse the repository at this point in the history
* rev docs, module strings; refactor tests for better isolation (skip if extras not installed)
  • Loading branch information
akarve authored Jul 11, 2018
1 parent e569e17 commit 1168e7d
Show file tree
Hide file tree
Showing 5 changed files with 216 additions and 161 deletions.
114 changes: 114 additions & 0 deletions compiler/quilt/test/test_asa_plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
"""test class against quilt.asa.plot"""
import os

import numpy as np
import pytest

from quilt.tools import command
from .utils import QuiltTestCase, try_require

if not try_require('quilt[img]'):
# pylint: disable=unexpected-keyword-arg
pytest.skip(
"only test if [img] extras installed",
allow_module_level=True)

# pylint: disable=no-self-use
class ImportTest(QuiltTestCase):
# the following two lines must happen first
import matplotlib as mpl
mpl.use('Agg') # specify a backend so headless unit tests don't barf

def test_asa_plot(self):
from quilt.asa.img import plot

mydir = os.path.dirname(__file__)
build_path = os.path.join(mydir, './build_img.yml')
command.build('foo/imgtest', build_path)
pkg = command.load('foo/imgtest')
# expect no exceptions on root
pkg(asa=plot())
# pylint: disable=no-member
# expect no exceptions on GroupNode with only DF children
pkg.dataframes(asa=plot())
# expect no exceptions on GroupNode with mixed children
pkg.mixed(asa=plot())
# expect no exceptions on dir of images
pkg.mixed.img(asa=plot())
pkg.mixed.img(asa=plot(formats=['jpg', 'png']))
# assert images != filtered, 'Expected only .jpg and .png images'
# expect no exceptions on single images
pkg.mixed.img.sf(asa=plot())
pkg.mixed.img.portal(asa=plot())

def _are_similar(self, ima, imb, error=0.01):
"""predicate to see if images differ by less than
the given error; uses mean squared error; see also
https://www.pyimagesearch.com/2014/09/15/python-compare-two-images/
ima, imb: PIL.Image instances
"""
ima_ = np.array(ima).astype('float')
imb_ = np.array(imb).astype('float')
assert ima_.shape == imb_.shape, 'ima and imb must have same shape'
# pylint: disable=invalid-name
for x, y, _ in (ima_.shape, imb_.shape):
assert x > 0 and y > 0, \
'unexpected image dimension: {}'.format((x, y))
# sum of normalized channel differences squared
error_ = np.sum(((ima_ - imb_)/255) ** 2)
# normalize by total number of samples
error_ /= float(ima_.shape[0] * imb_.shape[1])

return error_ < error

def test_asa_plot_output(self):
from PIL import Image
from matplotlib import pyplot as plt

from quilt.asa.img import plot

mydir = os.path.dirname(__file__)
build_path = os.path.join(mydir, 'build_img.yml')
command.build('foo/imgtest', build_path)
pkg = command.load('foo/imgtest')

outfile = os.path.join('.', 'temp-plot.png')
# pylint: disable=no-member
pkg.mixed.img(asa=plot(figsize=(10, 10)))
# size * dpi = 1000 x 1000 pixels
plt.savefig(outfile, dpi=100, format='png', transparent=False)

ref_path = os.path.join(mydir, 'data', 'ref-asa-plot.png')

ref_img = Image.open(ref_path)
tst_img = Image.open(outfile)

assert self._are_similar(ref_img, tst_img), \
'render differs from reference: {}'.format(ref_img)

def test_asa_plot_formats_output(self):
from PIL import Image
from matplotlib import pyplot as plt

from quilt.asa.img import plot

mydir = os.path.dirname(__file__)
build_path = os.path.join(mydir, 'build_img.yml')
command.build('foo/imgtest', build_path)
pkg = command.load('foo/imgtest')

outfile = os.path.join('.', 'temp-formats-plot.png')

# pylint: disable=no-member
pkg.mixed.img(asa=plot(figsize=(10, 10), formats=['png']))
# size * dpi = 1000 x 1000 pixels
plt.savefig(outfile, dpi=100, format='png', transparent=False)

ref_path = os.path.join(mydir, 'data', 'ref-asa-formats.png')

ref_img = Image.open(ref_path)
tst_img = Image.open(outfile)

assert self._are_similar(ref_img, tst_img), \
'render differs from reference: {}'.format(ref_img)
89 changes: 89 additions & 0 deletions compiler/quilt/test/test_asa_torch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
"""test class against quilt.asa.torch"""
import os

import pytest
from six import string_types

from quilt.tools import command
from quilt.nodes import DataNode
from .utils import QuiltTestCase, try_require

if not try_require('quilt[img,pytorch,torchvision]'):
# pylint: disable=unexpected-keyword-arg
pytest.skip("only test if [img,pytorch,torchvision] extras installed",
allow_module_level=True)

# pylint: disable=no-self-use
class ImportTest(QuiltTestCase):
def test_asa_pytorch(self):
"""test asa.torch interface by converting a GroupNode with asa="""
from torchvision.transforms import Compose, CenterCrop, ToTensor, Resize
from torch.utils.data import Dataset
from PIL import Image
from torch import Tensor

from quilt.asa.pytorch import dataset
# pylint: disable=missing-docstring
# helper functions to simulate real pytorch dataset usage
def calculate_valid_crop_size(crop_size, upscale_factor):
return crop_size - (crop_size % upscale_factor)

def node_parser(node):
path = node()
if isinstance(path, string_types):
img = Image.open(path).convert('YCbCr')
chan, _, _ = img.split()
return chan
else:
raise TypeError('Expected string path to an image fragment')

def input_transform(crop_size, upscale_factor):
return Compose([
CenterCrop(crop_size),
Resize(crop_size // upscale_factor),
ToTensor(),
])

def target_transform(crop_size):
def _inner(img):
img_ = img.copy()
return Compose([
CenterCrop(crop_size),
ToTensor(),
])(img_)
return _inner
# pylint: disable=protected-access
def is_image(node):
"""file extension introspection on Quilt nodes"""
if isinstance(node, DataNode):
filepath = node._meta.get('_system', {}).get('filepath')
if filepath:
return any(
filepath.endswith(extension)
for extension in [".png", ".jpg", ".jpeg"])
# end helper functions

mydir = os.path.dirname(__file__)
build_path = os.path.join(mydir, 'build_img.yml')
command.build('foo/torchtest', build_path)
pkg = command.load('foo/torchtest')

upscale_factor = 3
crop_size = calculate_valid_crop_size(256, upscale_factor)
# pylint: disable=no-member
my_dataset = pkg.mixed.img(asa=dataset(
include=is_image,
node_parser=node_parser,
input_transform=input_transform(crop_size, upscale_factor),
target_transform=target_transform(crop_size)
))
assert isinstance(my_dataset, Dataset), \
'expected type {}, got {}'.format(type(Dataset), type(my_dataset))

assert my_dataset.__len__() == 2, \
'expected two images in mixed.img, got {}'.format(my_dataset.__len__())

for i in range(my_dataset.__len__()):
tens = my_dataset.__getitem__(i)
assert all((isinstance(x, Tensor) for x in tens)), \
'Expected all torch.Tensors in tuple, got {}'.format(tens)
161 changes: 0 additions & 161 deletions compiler/quilt/test/test_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,16 @@
import platform
import time

# the following two lines must happen first
import matplotlib as mpl
mpl.use('Agg') # specify a backend so renderer doesn't barf
# pylint: disable=wrong-import-position
from PIL import Image
import numpy as np
import pandas as pd
import pytest
from matplotlib import pyplot as plt
from six import string_types

from quilt.tools import command
from quilt.nodes import DataNode, GroupNode
from quilt.tools.const import PACKAGE_DIR_NAME
from quilt.tools.package import Package
from quilt.tools.store import PackageStore, StoreException
from quilt.asa.img import plot
from .utils import patch, QuiltTestCase

# pylint: disable=protected-access
Expand Down Expand Up @@ -492,160 +485,6 @@ def test_lambda(node, hashes):
assert pkg.dataframes(asa=test_lambda) is testdata
assert pkg(asa=test_lambda) is testdata

# pylint: disable=no-member
def test_asa_plot(self):
mydir = os.path.dirname(__file__)
build_path = os.path.join(mydir, './build_img.yml')
command.build('foo/imgtest', build_path)
pkg = command.load('foo/imgtest')
# expect no exceptions on root
pkg(asa=plot())
# expect no exceptions on GroupNode with only DF children
pkg.dataframes(asa=plot())
# expect no exceptions on GroupNode with mixed children
pkg.mixed(asa=plot())
# expect no exceptions on dir of images
pkg.mixed.img(asa=plot())
pkg.mixed.img(asa=plot(formats=['jpg', 'png']))
# assert images != filtered, 'Expected only .jpg and .png images'
# expect no exceptions on single images
pkg.mixed.img.sf(asa=plot())
pkg.mixed.img.portal(asa=plot())

def _are_similar(self, ima, imb, error=0.01):
"""predicate to see if images differ by less than
the given error; uses mean squared error; see also
https://www.pyimagesearch.com/2014/09/15/python-compare-two-images/
ima, imb: PIL.Image instances
"""
ima_ = np.array(ima).astype('float')
imb_ = np.array(imb).astype('float')
assert ima_.shape == imb_.shape, 'ima and imb must have same shape'
for x, y, _ in (ima_.shape, imb_.shape):
assert x > 0 and y > 0, 'unexpected image dimension: {}'.format(shape)
# sum of normalized channel differences squared
error_ = np.sum(((ima_ - imb_)/255) ** 2)
# normalize by total number of samples
error_ /= float(ima_.shape[0] * imb_.shape[1])

return error_ < error

# pylint: disable=no-member
def test_asa_plot_output(self):
mydir = os.path.dirname(__file__)
build_path = os.path.join(mydir, 'build_img.yml')
command.build('foo/imgtest', build_path)
pkg = command.load('foo/imgtest')

outfile = os.path.join('.', 'temp-plot.png')
pkg.mixed.img(asa=plot(figsize=(10, 10)))
# size * dpi = 1000 x 1000 pixels
plt.savefig(outfile, dpi=100, format='png', transparent=False)

ref_path = os.path.join(mydir, 'data', 'ref-asa-plot.png')

ref_img = Image.open(ref_path)
tst_img = Image.open(outfile)

assert self._are_similar(ref_img, tst_img), \
'render differs from reference: {}'.format(ref_img)

# pylint: disable=no-member
def test_asa_plot_formats_output(self):
mydir = os.path.dirname(__file__)
build_path = os.path.join(mydir, 'build_img.yml')
command.build('foo/imgtest', build_path)
pkg = command.load('foo/imgtest')

outfile = os.path.join('.', 'temp-formats-plot.png')
pkg.mixed.img(asa=plot(figsize=(10, 10), formats=['png']))
# size * dpi = 1000 x 1000 pixels
plt.savefig(outfile, dpi=100, format='png', transparent=False)

ref_path = os.path.join(mydir, 'data', 'ref-asa-formats.png')

ref_img = Image.open(ref_path)
tst_img = Image.open(outfile)

assert self._are_similar(ref_img, tst_img), \
'render differs from reference: {}'.format(ref_img)


@pytest.mark.xfail(platform.system() in ['Windows'], reason=(
"infeasible to install pytorch on appveyor (even with conda)"
))
def test_asa_pytorch(self):
"""test asa.torch interface by converting a GroupNode with asa="""
from torchvision.transforms import Compose, CenterCrop, ToTensor, Resize
from torch.utils.data import Dataset
from torch import Tensor

from quilt.asa.pytorch import dataset
# pylint: disable=missing-docstring
# helper functions to simulate real pytorch dataset usage
def calculate_valid_crop_size(crop_size, upscale_factor):
return crop_size - (crop_size % upscale_factor)

def node_parser(node):
path = node()
if isinstance(path, string_types):
img = Image.open(path).convert('YCbCr')
chan, _, _ = img.split()
return chan
else:
raise TypeError('Expected string path to an image fragment')

def input_transform(crop_size, upscale_factor):
return Compose([
CenterCrop(crop_size),
Resize(crop_size // upscale_factor),
ToTensor(),
])

def target_transform(crop_size):
def _inner(img):
img_ = img.copy()
return Compose([
CenterCrop(crop_size),
ToTensor(),
])(img_)
return _inner

def is_image(node):
"""file extension introspection on Quilt nodes"""
if isinstance(node, DataNode):
filepath = node._meta.get('_system', {}).get('filepath')
if filepath:
return any(
filepath.endswith(extension)
for extension in [".png", ".jpg", ".jpeg"])
# end helper functions

mydir = os.path.dirname(__file__)
build_path = os.path.join(mydir, 'build_img.yml')
command.build('foo/torchtest', build_path)
pkg = command.load('foo/torchtest')

upscale_factor = 3
crop_size = calculate_valid_crop_size(256, upscale_factor)
my_dataset = pkg.mixed.img(asa=dataset(
include=is_image,
node_parser=node_parser,
input_transform=input_transform(crop_size, upscale_factor),
target_transform=target_transform(crop_size)
))
assert isinstance(my_dataset, Dataset), \
'expected type {}, got {}'.format(type(Dataset), type(my_dataset))

assert my_dataset.__len__() == 2, \
'expected two images in mixed.img, got {}'.format()

for i in range(my_dataset.__len__()):
tens = my_dataset.__getitem__(i);
assert all((isinstance(x, Tensor) for x in tens)), \
'Expected all torch.Tensors in tuple, got {}'.format(tens)

def test_memory_only_datanode_asa(self):
testdata = "justatest"
def test_lambda(node, hashes):
Expand Down
Loading

0 comments on commit 1168e7d

Please sign in to comment.