diff --git a/.gitignore b/.gitignore index 1c725d0..05efca2 100644 --- a/.gitignore +++ b/.gitignore @@ -20,7 +20,9 @@ __pycache__ /env/ /tests/imgs/merged/ /tests/imgs/large/ +/tests/imgs/stitching_result.csv /tests/img_folder/ + htmlcov *.egg-info .coverage diff --git a/src/multifocal_stitching/__init__.py b/src/multifocal_stitching/__init__.py index 1a50845..e91d624 100644 --- a/src/multifocal_stitching/__init__.py +++ b/src/multifocal_stitching/__init__.py @@ -1,3 +1,5 @@ -#from .stitching import * +from .stitching import stitch +from .merge_imgs import merge +from .utils import read_img -#__all__ = [] +__all__ = [stitch, merge, read_img] diff --git a/src/multifocal_stitching/__main__.py b/src/multifocal_stitching/__main__.py index 485c4e2..6669845 100644 --- a/src/multifocal_stitching/__main__.py +++ b/src/multifocal_stitching/__main__.py @@ -1,15 +1,18 @@ import csv -from .stitching import add_stitching_args, stitch +from .stitching import stitch from .utils import get_default_parser, get_filenames, get_name, get_full_path, pairwise, read_img -from .merge_imgs import add_merge_args, merge_imgs +from .merge_imgs import merge_and_save CSV_HEADER = ['Img 1', 'Img 2', 'X offset', 'Y offset', 'Corr Value', 'Area', 'r', 'use_win'] def main(): - parser = add_stitching_args(add_merge_args(get_default_parser())) - args = parser.parse_args() - img_names = sorted(get_filenames(args)) - with open(get_full_path(args, args.stitching_result), 'w') as outfile: + args = get_default_parser().parse_args() + if args.imgs is not None: + assert len(args.imgs) >= 2, 'Can only stitch two or more images!' + img_names = [get_full_path(args.dir, img) for img in args.imgs] + else: + img_names = sorted(get_filenames(args)) + with open(get_full_path(args.dir, args.stitching_result), 'w') as outfile: writer = csv.writer(outfile, delimiter=',') writer.writerow(CSV_HEADER) for img_names in pairwise(img_names): @@ -29,8 +32,12 @@ def main(): writer.writerow([img_name1, img_name2, dx, dy, result.corr_coeff, result.area, result.best_r, result.best_win]) if not args.no_merge: - res_dir = get_full_path(args, args.result_dir, mkdir=True) - merge_imgs(args, res_dir, img_name1, img_name2, dx, dy) + if args.verbose: + print('Merging:', img_name1, img_name2) + res_dir = get_full_path(args.dir, args.result_dir, mkdir=True) + merge_and_save(args.dir, res_dir, img_name1, img_name2, dx, dy, + resize_factor=args.resize_factor, + save_gif=args.save_gif,) if __name__=='__main__': main() diff --git a/src/multifocal_stitching/merge_imgs.py b/src/multifocal_stitching/merge_imgs.py index b57dcf4..7c9c541 100644 --- a/src/multifocal_stitching/merge_imgs.py +++ b/src/multifocal_stitching/merge_imgs.py @@ -1,36 +1,37 @@ import os from PIL import Image +from typing import Tuple from .utils import * -def add_merge_args(parser): - parser.add_argument('-s', '--stitching_result', - help='Stitching result csv file', - default='stitching_result.csv') - parser.add_argument('-d', '--result_dir', - help='Directory to save merged files', - default='merged') - parser.add_argument('-r', '--exclude_reverse', - help='Whether to additionally include img2 on top of img1', - action='store_true') - return parser - -def merge_imgs(args, res_dir, img1, img2, dx, dy): - if args.verbose: - print('Merging:', img1, img2) - i1, i2 = [Image.open(get_full_path(args,img)) for img in (img1, img2)] - dx, dy = map(round_int, (dx, dy)) +def merge(i1: Image, i2: Image, dx:int, dy:int) -> Tuple[Image, Image]: + assert i1.size == i2.size, "Images must be same size!" W, H = i1.size new_W, new_H = W + abs(dx), H + abs(dy) + i1_x = -dx if dx < 0 else 0 i1_y = -dy if dy < 0 else 0 i2_x = dx if dx > 0 else 0 i2_y = dy if dy > 0 else 0 + res = Image.new(mode='RGB', size=(new_W, new_H)) res.paste(i1, (i1_x, i1_y)) res.paste(i2, (i2_x, i2_y)) - res_path = os.path.join(res_dir, - f'{os.path.splitext(img1)[0]}__{os.path.splitext(img2)[0]}.jpg') - res.save(res_path) - if not args.exclude_reverse: - res.paste(i1, (i1_x, i1_y)) - res.save(res_path[:-4] + '_r.jpg') + + res_r = res.copy() + res_r.paste(i1, (i1_x, i1_y)) + + return res, res_r + +def merge_and_save(base_dir:str, res_dir:str, img_name1:str, img_name2:str, + dx:int, dy:int, resize_factor:int=1, save_gif:bool=False): + i1, i2 = [Image.open(get_full_path(base_dir, i)) for i in (img_name1, img_name2)] + res = merge(i1, i2, dx, dy) + W, H = res[0].size + res_resized = [r.resize((W // resize_factor, H // resize_factor), Image.LANCZOS) + for r in res] + for i, r in enumerate(res_resized): + base1, base2 = [os.path.splitext(n)[0] for n in (img_name1, img_name2)] + r.save(get_full_path(res_dir, f'{base1}__{base2}_{i}.jpg')) + if save_gif: + res_resized[0].save(get_full_path(res_dir, f'{base1}__{base2}.png'), save_all=True, + append_images=res_resized[1:], duration=500, loop=0) diff --git a/src/multifocal_stitching/stitching.py b/src/multifocal_stitching/stitching.py index 745efe3..fe07b22 100644 --- a/src/multifocal_stitching/stitching.py +++ b/src/multifocal_stitching/stitching.py @@ -8,7 +8,6 @@ from typing import Any, Tuple, Generator from .utils import * -from .merge_imgs import add_merge_args, merge_imgs def get_filter_mask(img: np.ndarray, r: int) -> np.ndarray: x, y = img.shape @@ -125,33 +124,3 @@ def candidate_stitches(img1: np.ndarray, img2: np.ndarray, def stitch(img1: np.ndarray, img2: np.ndarray, **kwargs) -> StitchingResult: return max(candidate_stitches(img1, img2, **kwargs), key=lambda r: r.corr_coeff) - -def add_stitching_args(parser): - parser.add_argument('--ext', - help='Filename extension of images', - default='.jpg') - parser.add_argument('--no_merge', - help='Disable generating merged images', - action='store_true') - parser.add_argument('--workers', type=int, - help='Number of CPU threads to use in FFT', - default=2) - parser.add_argument('--min_overlap', type=float, - help='Set lower limit for overlapping region as a fraction of total image area', - default=0.125) - parser.add_argument('--early_term_thresh', type=float, - help='Stop searching when correlation is above this value', - default=0.7) - parser.add_argument('--use_wins', nargs="+", type=int, - help='Whether to try using Hanning window', - default=(0,)) - parser.add_argument('--peak_cutoff_std', type=float, - help='Number of standard deviations below max value to use for peak finding', - default=1) - parser.add_argument('--peaks_dist_threshold', type=float, - help='Distance to consider as part of same cluster when finding peak centroid', - default=25) - parser.add_argument('--filter_radii', nargs="+", type=int, - default=(100,50,20), - help='Low-pass filter radii to try, smaller matches coarser/out-of-focus features') - return parser diff --git a/src/multifocal_stitching/utils.py b/src/multifocal_stitching/utils.py index 2243e06..be1f345 100644 --- a/src/multifocal_stitching/utils.py +++ b/src/multifocal_stitching/utils.py @@ -15,8 +15,8 @@ def pairwise(iterable): def get_filenames(args): return glob.glob(os.path.join(args.dir, f'*{args.ext}')) -def get_full_path(args, filename, mkdir=False): - path = os.path.join(args.dir, filename) +def get_full_path(base_dir, filename, mkdir=False): + path = os.path.join(base_dir, filename) if mkdir and not os.path.exists(path): os.makedirs(path) return path @@ -24,12 +24,13 @@ def get_full_path(args, filename, mkdir=False): def get_name(path): return os.path.split(path)[-1] -def round_int(s): - return int(round(float(s),0)) - def round_int_np(x): return np.round(x).astype('int') +def read_img(filename): + assert os.path.isfile(filename), f'Fild not found: {filename}' + return np.float64(cv2.imread(filename, cv2.IMREAD_GRAYSCALE)) + def get_default_parser(): parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument('dir', @@ -37,8 +38,47 @@ def get_default_parser(): parser.add_argument("-v", "--verbose", help="Increase output verbosity", action="store_true") + parser.add_argument('--ext', + help='Filename extension of images', + default='.jpg') + parser.add_argument('--imgs', nargs="+", type=str, + help='Stitch only provided images in provided order, otherwise ' + 'will run in batch mode over all images in directory', + default=None) + parser.add_argument('--no_merge', + help='Disable generating merged images', + action='store_true') + parser.add_argument('--workers', type=int, + help='Number of CPU threads to use in FFT', + default=2) + parser.add_argument('--min_overlap', type=float, + help='Set lower limit for overlapping region as a fraction of total image area', + default=0.125) + parser.add_argument('--early_term_thresh', type=float, + help='Stop searching when correlation is above this value', + default=0.7) + parser.add_argument('--use_wins', nargs="+", type=int, + help='Whether to try using Hanning window', + default=(0,)) + parser.add_argument('--peak_cutoff_std', type=float, + help='Number of standard deviations below max value to use for peak finding', + default=1) + parser.add_argument('--peaks_dist_threshold', type=float, + help='Distance to consider as part of same cluster when finding peak centroid', + default=25) + parser.add_argument('--filter_radii', nargs="+", type=int, + help='Low-pass filter radii to try, smaller matches coarser/out-of-focus features', + default=(100,50,20)) + parser.add_argument('--stitching_result', + help='Stitching result csv file', + default='stitching_result.csv') + parser.add_argument('--result_dir', type=str, + help='Directory to save merged files', + default='merged') + parser.add_argument('--resize_factor', type=int, + help='Whether to resize the images saved by a factor', + default=1) + parser.add_argument('--save_gif', + help='Whether to save a gif alternating between the merged files', + action="store_true") return parser - -def read_img(filename): - assert os.path.isfile(filename), f'Fild not found: {filename}' - return np.float64(cv2.imread(filename, cv2.IMREAD_GRAYSCALE)) diff --git a/tests/test_stitching.py b/tests/test_stitching.py index 1f4a6fd..5772d68 100644 --- a/tests/test_stitching.py +++ b/tests/test_stitching.py @@ -5,20 +5,14 @@ import shutil import csv from multifocal_stitching.stitching import stitch -from multifocal_stitching.utils import read_img -from multifocal_stitching.merge_imgs import merge_imgs +from multifocal_stitching.utils import read_img, get_full_path +from multifocal_stitching.merge_imgs import merge_and_save from multifocal_stitching.__main__ import main as cli from multifocal_stitching.__main__ import CSV_HEADER def coord_is_close(res, val, tol=5): assert np.linalg.norm(np.array(res.coord) - np.array(val), 1) <= tol -def get_full_path(base_dir, filename, mkdir=False): - path = os.path.join(base_dir, filename) - if mkdir and not os.path.exists(path): - os.makedirs(path) - return path - class TestCLI(unittest.TestCase): def setUp(self): self.base_dir = 'tests/img_folder' @@ -45,16 +39,10 @@ def test_cli(self): self.assertEqual(int(best_r), 50) self.assertEqual(int(best_win), 0) - merged_name = os.path.join( - self.base_dir, - 'merged', - 'high_freq_features_1_small__high_freq_features_2_small.jpg' - ) - merged_r_name = os.path.join( - self.base_dir, - 'merged', - 'high_freq_features_1_small__high_freq_features_2_small_r.jpg' - ) + merged_name, merged_r_name = [os.path.join( + self.base_dir, 'merged', + f'high_freq_features_1_small__high_freq_features_2_small_{i}.jpg') + for i in range(2)] self.assertTrue(os.path.isfile(merged_name)) self.assertTrue(os.path.isfile(merged_r_name)) merged = read_img(merged_name) @@ -62,7 +50,7 @@ def test_cli(self): self.assertEqual(merged.shape, merged_r.shape) self.assertEqual(merged.shape, (2655, 6314)) -class TestStitching(unittest.TestCase): +class TestStitch(unittest.TestCase): def setUp(self): self.base_dir = 'tests/imgs' @@ -71,6 +59,8 @@ def stitch_name(self, name): res = stitch(*[read_img(get_full_path(self.base_dir, name)) for name in names]) res_dir = get_full_path(self.base_dir, 'merged', mkdir=True) dx, dy = res.coord + merge_and_save(self.base_dir, res_dir, names[0], names[1], dx, dy, + resize_factor=8, save_gif=True) return res def test_stitching_high_freq_features(self):