Skip to content

Commit

Permalink
Refactored merge and added gif generation
Browse files Browse the repository at this point in the history
  • Loading branch information
yuanchenyang committed Jul 28, 2023
1 parent 928ea3d commit 0113329
Show file tree
Hide file tree
Showing 7 changed files with 103 additions and 92 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ __pycache__
/env/
/tests/imgs/merged/
/tests/imgs/large/
/tests/imgs/stitching_result.csv
/tests/img_folder/

htmlcov
*.egg-info
.coverage
6 changes: 4 additions & 2 deletions src/multifocal_stitching/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#from .stitching import *
from .stitching import stitch
from .merge_imgs import merge
from .utils import read_img

#__all__ = []
__all__ = [stitch, merge, read_img]
23 changes: 15 additions & 8 deletions src/multifocal_stitching/__main__.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
import csv
from .stitching import add_stitching_args, stitch
from .stitching import stitch
from .utils import get_default_parser, get_filenames, get_name, get_full_path, pairwise, read_img
from .merge_imgs import add_merge_args, merge_imgs
from .merge_imgs import merge_and_save

CSV_HEADER = ['Img 1', 'Img 2', 'X offset', 'Y offset', 'Corr Value', 'Area', 'r', 'use_win']

def main():
parser = add_stitching_args(add_merge_args(get_default_parser()))
args = parser.parse_args()
img_names = sorted(get_filenames(args))
with open(get_full_path(args, args.stitching_result), 'w') as outfile:
args = get_default_parser().parse_args()
if args.imgs is not None:
assert len(args.imgs) >= 2, 'Can only stitch two or more images!'
img_names = [get_full_path(args.dir, img) for img in args.imgs]
else:
img_names = sorted(get_filenames(args))
with open(get_full_path(args.dir, args.stitching_result), 'w') as outfile:
writer = csv.writer(outfile, delimiter=',')
writer.writerow(CSV_HEADER)
for img_names in pairwise(img_names):
Expand All @@ -29,8 +32,12 @@ def main():
writer.writerow([img_name1, img_name2, dx, dy, result.corr_coeff,
result.area, result.best_r, result.best_win])
if not args.no_merge:
res_dir = get_full_path(args, args.result_dir, mkdir=True)
merge_imgs(args, res_dir, img_name1, img_name2, dx, dy)
if args.verbose:
print('Merging:', img_name1, img_name2)
res_dir = get_full_path(args.dir, args.result_dir, mkdir=True)
merge_and_save(args.dir, res_dir, img_name1, img_name2, dx, dy,
resize_factor=args.resize_factor,
save_gif=args.save_gif,)

if __name__=='__main__':
main()
47 changes: 24 additions & 23 deletions src/multifocal_stitching/merge_imgs.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,37 @@
import os
from PIL import Image
from typing import Tuple
from .utils import *

def add_merge_args(parser):
parser.add_argument('-s', '--stitching_result',
help='Stitching result csv file',
default='stitching_result.csv')
parser.add_argument('-d', '--result_dir',
help='Directory to save merged files',
default='merged')
parser.add_argument('-r', '--exclude_reverse',
help='Whether to additionally include img2 on top of img1',
action='store_true')
return parser

def merge_imgs(args, res_dir, img1, img2, dx, dy):
if args.verbose:
print('Merging:', img1, img2)
i1, i2 = [Image.open(get_full_path(args,img)) for img in (img1, img2)]
dx, dy = map(round_int, (dx, dy))
def merge(i1: Image, i2: Image, dx:int, dy:int) -> Tuple[Image, Image]:
assert i1.size == i2.size, "Images must be same size!"
W, H = i1.size
new_W, new_H = W + abs(dx), H + abs(dy)

i1_x = -dx if dx < 0 else 0
i1_y = -dy if dy < 0 else 0
i2_x = dx if dx > 0 else 0
i2_y = dy if dy > 0 else 0

res = Image.new(mode='RGB', size=(new_W, new_H))
res.paste(i1, (i1_x, i1_y))
res.paste(i2, (i2_x, i2_y))
res_path = os.path.join(res_dir,
f'{os.path.splitext(img1)[0]}__{os.path.splitext(img2)[0]}.jpg')
res.save(res_path)
if not args.exclude_reverse:
res.paste(i1, (i1_x, i1_y))
res.save(res_path[:-4] + '_r.jpg')

res_r = res.copy()
res_r.paste(i1, (i1_x, i1_y))

return res, res_r

def merge_and_save(base_dir:str, res_dir:str, img_name1:str, img_name2:str,
dx:int, dy:int, resize_factor:int=1, save_gif:bool=False):
i1, i2 = [Image.open(get_full_path(base_dir, i)) for i in (img_name1, img_name2)]
res = merge(i1, i2, dx, dy)
W, H = res[0].size
res_resized = [r.resize((W // resize_factor, H // resize_factor), Image.LANCZOS)
for r in res]
for i, r in enumerate(res_resized):
base1, base2 = [os.path.splitext(n)[0] for n in (img_name1, img_name2)]
r.save(get_full_path(res_dir, f'{base1}__{base2}_{i}.jpg'))
if save_gif:
res_resized[0].save(get_full_path(res_dir, f'{base1}__{base2}.png'), save_all=True,
append_images=res_resized[1:], duration=500, loop=0)
31 changes: 0 additions & 31 deletions src/multifocal_stitching/stitching.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from typing import Any, Tuple, Generator

from .utils import *
from .merge_imgs import add_merge_args, merge_imgs

def get_filter_mask(img: np.ndarray, r: int) -> np.ndarray:
x, y = img.shape
Expand Down Expand Up @@ -125,33 +124,3 @@ def candidate_stitches(img1: np.ndarray, img2: np.ndarray,

def stitch(img1: np.ndarray, img2: np.ndarray, **kwargs) -> StitchingResult:
return max(candidate_stitches(img1, img2, **kwargs), key=lambda r: r.corr_coeff)

def add_stitching_args(parser):
parser.add_argument('--ext',
help='Filename extension of images',
default='.jpg')
parser.add_argument('--no_merge',
help='Disable generating merged images',
action='store_true')
parser.add_argument('--workers', type=int,
help='Number of CPU threads to use in FFT',
default=2)
parser.add_argument('--min_overlap', type=float,
help='Set lower limit for overlapping region as a fraction of total image area',
default=0.125)
parser.add_argument('--early_term_thresh', type=float,
help='Stop searching when correlation is above this value',
default=0.7)
parser.add_argument('--use_wins', nargs="+", type=int,
help='Whether to try using Hanning window',
default=(0,))
parser.add_argument('--peak_cutoff_std', type=float,
help='Number of standard deviations below max value to use for peak finding',
default=1)
parser.add_argument('--peaks_dist_threshold', type=float,
help='Distance to consider as part of same cluster when finding peak centroid',
default=25)
parser.add_argument('--filter_radii', nargs="+", type=int,
default=(100,50,20),
help='Low-pass filter radii to try, smaller matches coarser/out-of-focus features')
return parser
58 changes: 49 additions & 9 deletions src/multifocal_stitching/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,30 +15,70 @@ def pairwise(iterable):
def get_filenames(args):
return glob.glob(os.path.join(args.dir, f'*{args.ext}'))

def get_full_path(args, filename, mkdir=False):
path = os.path.join(args.dir, filename)
def get_full_path(base_dir, filename, mkdir=False):
path = os.path.join(base_dir, filename)
if mkdir and not os.path.exists(path):
os.makedirs(path)
return path

def get_name(path):
return os.path.split(path)[-1]

def round_int(s):
return int(round(float(s),0))

def round_int_np(x):
return np.round(x).astype('int')

def read_img(filename):
assert os.path.isfile(filename), f'Fild not found: {filename}'
return np.float64(cv2.imread(filename, cv2.IMREAD_GRAYSCALE))

def get_default_parser():
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('dir',
help='Base directory')
parser.add_argument("-v", "--verbose",
help="Increase output verbosity",
action="store_true")
parser.add_argument('--ext',
help='Filename extension of images',
default='.jpg')
parser.add_argument('--imgs', nargs="+", type=str,
help='Stitch only provided images in provided order, otherwise '
'will run in batch mode over all images in directory',
default=None)
parser.add_argument('--no_merge',
help='Disable generating merged images',
action='store_true')
parser.add_argument('--workers', type=int,
help='Number of CPU threads to use in FFT',
default=2)
parser.add_argument('--min_overlap', type=float,
help='Set lower limit for overlapping region as a fraction of total image area',
default=0.125)
parser.add_argument('--early_term_thresh', type=float,
help='Stop searching when correlation is above this value',
default=0.7)
parser.add_argument('--use_wins', nargs="+", type=int,
help='Whether to try using Hanning window',
default=(0,))
parser.add_argument('--peak_cutoff_std', type=float,
help='Number of standard deviations below max value to use for peak finding',
default=1)
parser.add_argument('--peaks_dist_threshold', type=float,
help='Distance to consider as part of same cluster when finding peak centroid',
default=25)
parser.add_argument('--filter_radii', nargs="+", type=int,
help='Low-pass filter radii to try, smaller matches coarser/out-of-focus features',
default=(100,50,20))
parser.add_argument('--stitching_result',
help='Stitching result csv file',
default='stitching_result.csv')
parser.add_argument('--result_dir', type=str,
help='Directory to save merged files',
default='merged')
parser.add_argument('--resize_factor', type=int,
help='Whether to resize the images saved by a factor',
default=1)
parser.add_argument('--save_gif',
help='Whether to save a gif alternating between the merged files',
action="store_true")
return parser

def read_img(filename):
assert os.path.isfile(filename), f'Fild not found: {filename}'
return np.float64(cv2.imread(filename, cv2.IMREAD_GRAYSCALE))
28 changes: 9 additions & 19 deletions tests/test_stitching.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,14 @@
import shutil
import csv
from multifocal_stitching.stitching import stitch
from multifocal_stitching.utils import read_img
from multifocal_stitching.merge_imgs import merge_imgs
from multifocal_stitching.utils import read_img, get_full_path
from multifocal_stitching.merge_imgs import merge_and_save
from multifocal_stitching.__main__ import main as cli
from multifocal_stitching.__main__ import CSV_HEADER

def coord_is_close(res, val, tol=5):
assert np.linalg.norm(np.array(res.coord) - np.array(val), 1) <= tol

def get_full_path(base_dir, filename, mkdir=False):
path = os.path.join(base_dir, filename)
if mkdir and not os.path.exists(path):
os.makedirs(path)
return path

class TestCLI(unittest.TestCase):
def setUp(self):
self.base_dir = 'tests/img_folder'
Expand All @@ -45,24 +39,18 @@ def test_cli(self):
self.assertEqual(int(best_r), 50)
self.assertEqual(int(best_win), 0)

merged_name = os.path.join(
self.base_dir,
'merged',
'high_freq_features_1_small__high_freq_features_2_small.jpg'
)
merged_r_name = os.path.join(
self.base_dir,
'merged',
'high_freq_features_1_small__high_freq_features_2_small_r.jpg'
)
merged_name, merged_r_name = [os.path.join(
self.base_dir, 'merged',
f'high_freq_features_1_small__high_freq_features_2_small_{i}.jpg')
for i in range(2)]
self.assertTrue(os.path.isfile(merged_name))
self.assertTrue(os.path.isfile(merged_r_name))
merged = read_img(merged_name)
merged_r = read_img(merged_r_name)
self.assertEqual(merged.shape, merged_r.shape)
self.assertEqual(merged.shape, (2655, 6314))

class TestStitching(unittest.TestCase):
class TestStitch(unittest.TestCase):
def setUp(self):
self.base_dir = 'tests/imgs'

Expand All @@ -71,6 +59,8 @@ def stitch_name(self, name):
res = stitch(*[read_img(get_full_path(self.base_dir, name)) for name in names])
res_dir = get_full_path(self.base_dir, 'merged', mkdir=True)
dx, dy = res.coord
merge_and_save(self.base_dir, res_dir, names[0], names[1], dx, dy,
resize_factor=8, save_gif=True)
return res

def test_stitching_high_freq_features(self):
Expand Down

0 comments on commit 0113329

Please sign in to comment.