-
Notifications
You must be signed in to change notification settings - Fork 90
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactor asa tests for better isolation (#690)
* rev docs, module strings; refactor tests for better isolation (skip if extras not installed)
- Loading branch information
Showing
5 changed files
with
216 additions
and
161 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
"""test class against quilt.asa.plot""" | ||
import os | ||
|
||
import numpy as np | ||
import pytest | ||
|
||
from quilt.tools import command | ||
from .utils import QuiltTestCase, try_require | ||
|
||
if not try_require('quilt[img]'): | ||
# pylint: disable=unexpected-keyword-arg | ||
pytest.skip( | ||
"only test if [img] extras installed", | ||
allow_module_level=True) | ||
|
||
# pylint: disable=no-self-use | ||
class ImportTest(QuiltTestCase): | ||
# the following two lines must happen first | ||
import matplotlib as mpl | ||
mpl.use('Agg') # specify a backend so headless unit tests don't barf | ||
|
||
def test_asa_plot(self): | ||
from quilt.asa.img import plot | ||
|
||
mydir = os.path.dirname(__file__) | ||
build_path = os.path.join(mydir, './build_img.yml') | ||
command.build('foo/imgtest', build_path) | ||
pkg = command.load('foo/imgtest') | ||
# expect no exceptions on root | ||
pkg(asa=plot()) | ||
# pylint: disable=no-member | ||
# expect no exceptions on GroupNode with only DF children | ||
pkg.dataframes(asa=plot()) | ||
# expect no exceptions on GroupNode with mixed children | ||
pkg.mixed(asa=plot()) | ||
# expect no exceptions on dir of images | ||
pkg.mixed.img(asa=plot()) | ||
pkg.mixed.img(asa=plot(formats=['jpg', 'png'])) | ||
# assert images != filtered, 'Expected only .jpg and .png images' | ||
# expect no exceptions on single images | ||
pkg.mixed.img.sf(asa=plot()) | ||
pkg.mixed.img.portal(asa=plot()) | ||
|
||
def _are_similar(self, ima, imb, error=0.01): | ||
"""predicate to see if images differ by less than | ||
the given error; uses mean squared error; see also | ||
https://www.pyimagesearch.com/2014/09/15/python-compare-two-images/ | ||
ima, imb: PIL.Image instances | ||
""" | ||
ima_ = np.array(ima).astype('float') | ||
imb_ = np.array(imb).astype('float') | ||
assert ima_.shape == imb_.shape, 'ima and imb must have same shape' | ||
# pylint: disable=invalid-name | ||
for x, y, _ in (ima_.shape, imb_.shape): | ||
assert x > 0 and y > 0, \ | ||
'unexpected image dimension: {}'.format((x, y)) | ||
# sum of normalized channel differences squared | ||
error_ = np.sum(((ima_ - imb_)/255) ** 2) | ||
# normalize by total number of samples | ||
error_ /= float(ima_.shape[0] * imb_.shape[1]) | ||
|
||
return error_ < error | ||
|
||
def test_asa_plot_output(self): | ||
from PIL import Image | ||
from matplotlib import pyplot as plt | ||
|
||
from quilt.asa.img import plot | ||
|
||
mydir = os.path.dirname(__file__) | ||
build_path = os.path.join(mydir, 'build_img.yml') | ||
command.build('foo/imgtest', build_path) | ||
pkg = command.load('foo/imgtest') | ||
|
||
outfile = os.path.join('.', 'temp-plot.png') | ||
# pylint: disable=no-member | ||
pkg.mixed.img(asa=plot(figsize=(10, 10))) | ||
# size * dpi = 1000 x 1000 pixels | ||
plt.savefig(outfile, dpi=100, format='png', transparent=False) | ||
|
||
ref_path = os.path.join(mydir, 'data', 'ref-asa-plot.png') | ||
|
||
ref_img = Image.open(ref_path) | ||
tst_img = Image.open(outfile) | ||
|
||
assert self._are_similar(ref_img, tst_img), \ | ||
'render differs from reference: {}'.format(ref_img) | ||
|
||
def test_asa_plot_formats_output(self): | ||
from PIL import Image | ||
from matplotlib import pyplot as plt | ||
|
||
from quilt.asa.img import plot | ||
|
||
mydir = os.path.dirname(__file__) | ||
build_path = os.path.join(mydir, 'build_img.yml') | ||
command.build('foo/imgtest', build_path) | ||
pkg = command.load('foo/imgtest') | ||
|
||
outfile = os.path.join('.', 'temp-formats-plot.png') | ||
|
||
# pylint: disable=no-member | ||
pkg.mixed.img(asa=plot(figsize=(10, 10), formats=['png'])) | ||
# size * dpi = 1000 x 1000 pixels | ||
plt.savefig(outfile, dpi=100, format='png', transparent=False) | ||
|
||
ref_path = os.path.join(mydir, 'data', 'ref-asa-formats.png') | ||
|
||
ref_img = Image.open(ref_path) | ||
tst_img = Image.open(outfile) | ||
|
||
assert self._are_similar(ref_img, tst_img), \ | ||
'render differs from reference: {}'.format(ref_img) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
"""test class against quilt.asa.torch""" | ||
import os | ||
|
||
import pytest | ||
from six import string_types | ||
|
||
from quilt.tools import command | ||
from quilt.nodes import DataNode | ||
from .utils import QuiltTestCase, try_require | ||
|
||
if not try_require('quilt[img,pytorch,torchvision]'): | ||
# pylint: disable=unexpected-keyword-arg | ||
pytest.skip("only test if [img,pytorch,torchvision] extras installed", | ||
allow_module_level=True) | ||
|
||
# pylint: disable=no-self-use | ||
class ImportTest(QuiltTestCase): | ||
def test_asa_pytorch(self): | ||
"""test asa.torch interface by converting a GroupNode with asa=""" | ||
from torchvision.transforms import Compose, CenterCrop, ToTensor, Resize | ||
from torch.utils.data import Dataset | ||
from PIL import Image | ||
from torch import Tensor | ||
|
||
from quilt.asa.pytorch import dataset | ||
# pylint: disable=missing-docstring | ||
# helper functions to simulate real pytorch dataset usage | ||
def calculate_valid_crop_size(crop_size, upscale_factor): | ||
return crop_size - (crop_size % upscale_factor) | ||
|
||
def node_parser(node): | ||
path = node() | ||
if isinstance(path, string_types): | ||
img = Image.open(path).convert('YCbCr') | ||
chan, _, _ = img.split() | ||
return chan | ||
else: | ||
raise TypeError('Expected string path to an image fragment') | ||
|
||
def input_transform(crop_size, upscale_factor): | ||
return Compose([ | ||
CenterCrop(crop_size), | ||
Resize(crop_size // upscale_factor), | ||
ToTensor(), | ||
]) | ||
|
||
def target_transform(crop_size): | ||
def _inner(img): | ||
img_ = img.copy() | ||
return Compose([ | ||
CenterCrop(crop_size), | ||
ToTensor(), | ||
])(img_) | ||
return _inner | ||
# pylint: disable=protected-access | ||
def is_image(node): | ||
"""file extension introspection on Quilt nodes""" | ||
if isinstance(node, DataNode): | ||
filepath = node._meta.get('_system', {}).get('filepath') | ||
if filepath: | ||
return any( | ||
filepath.endswith(extension) | ||
for extension in [".png", ".jpg", ".jpeg"]) | ||
# end helper functions | ||
|
||
mydir = os.path.dirname(__file__) | ||
build_path = os.path.join(mydir, 'build_img.yml') | ||
command.build('foo/torchtest', build_path) | ||
pkg = command.load('foo/torchtest') | ||
|
||
upscale_factor = 3 | ||
crop_size = calculate_valid_crop_size(256, upscale_factor) | ||
# pylint: disable=no-member | ||
my_dataset = pkg.mixed.img(asa=dataset( | ||
include=is_image, | ||
node_parser=node_parser, | ||
input_transform=input_transform(crop_size, upscale_factor), | ||
target_transform=target_transform(crop_size) | ||
)) | ||
assert isinstance(my_dataset, Dataset), \ | ||
'expected type {}, got {}'.format(type(Dataset), type(my_dataset)) | ||
|
||
assert my_dataset.__len__() == 2, \ | ||
'expected two images in mixed.img, got {}'.format(my_dataset.__len__()) | ||
|
||
for i in range(my_dataset.__len__()): | ||
tens = my_dataset.__getitem__(i) | ||
assert all((isinstance(x, Tensor) for x in tens)), \ | ||
'Expected all torch.Tensors in tuple, got {}'.format(tens) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.