Skip to content

Commit

Permalink
Major Refactor
Browse files Browse the repository at this point in the history
- Added types
- Removed args input for stitching functions
- Added test for CLI
  • Loading branch information
yuanchenyang committed Jul 27, 2023
1 parent 9d18603 commit be2f6d9
Show file tree
Hide file tree
Showing 6 changed files with 140 additions and 54 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ __pycache__
/env/
/tests/imgs/merged/
/tests/imgs/large/
/tests/img_folder/
htmlcov
*.egg-info
.coverage
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "multifocal-stitching"
version = "0.1.1"
version = "0.2"
description = "Algorithms and tools for stitching microscopy images taken at different focal lengths"
readme = "README.md"
requires-python = ">=3.7"
Expand Down
21 changes: 17 additions & 4 deletions src/multifocal_stitching/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,31 @@
from .utils import get_default_parser, get_filenames, get_name, get_full_path, pairwise, read_img
from .merge_imgs import add_merge_args, merge_imgs

CSV_HEADER = ['Img 1', 'Img 2', 'X offset', 'Y offset', 'Corr Value', 'Area', 'r', 'use_win']

def main():
parser = add_stitching_args(add_merge_args(get_default_parser()))
args = parser.parse_args()
img_names = sorted(get_filenames(args))
with open(get_full_path(args, args.stitching_result), 'w') as outfile:
writer = csv.writer(outfile, delimiter=',')
writer.writerow(['Img 1', 'Img 2', 'X offset', 'Y offset', 'Corr Value', 'Area', 'r', 'use_win'])
writer.writerow(CSV_HEADER)
for img_names in pairwise(img_names):
if args.verbose: print('Stitching', *img_names)
corr, res, (dx, dy), val, area, r, use_win = stitch(args, *map(read_img, img_names))
if args.verbose:
print('Stitching', *img_names)
result = stitch(*map(read_img, img_names),
use_wins = args.use_wins,
workers = args.workers,
peak_cutoff_std = args.peak_cutoff_std,
peaks_dist_threshold = args.peaks_dist_threshold,
filter_radii = args.filter_radii,
min_overlap = args.min_overlap,
early_term_thresh = args.early_term_thresh,
verbose = args.verbose)
dx, dy = result.coord
img_name1, img_name2 = map(get_name, img_names)
writer.writerow([img_name1, img_name2, dx, dy, corr, area, r, use_win])
writer.writerow([img_name1, img_name2, dx, dy, result.corr_coeff,
result.area, result.best_r, result.best_win])
if not args.no_merge:
res_dir = get_full_path(args, args.result_dir, mkdir=True)
merge_imgs(args, res_dir, img_name1, img_name2, dx, dy)
Expand Down
13 changes: 0 additions & 13 deletions src/multifocal_stitching/merge_imgs.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,3 @@ def merge_imgs(args, res_dir, img1, img2, dx, dy):
if not args.exclude_reverse:
res.paste(i1, (i1_x, i1_y))
res.save(res_path[:-4] + '_r.jpg')

def main():
parser = add_merge_args(get_default_parser())
args = parser.parse_args()
res_dir = get_full_path(args, args.result_dir, mkdir=True)
with open(get_full_path(args, args.stitching_result)) as csvfile:
reader = csv.reader(csvfile)
next(reader) # skip header row
for img1, img2, dx, dy, *_ in reader:
merge_imgs(args, res_dir, img1, img2, dx, dy)

if __name__=='__main__':
main()
88 changes: 60 additions & 28 deletions src/multifocal_stitching/stitching.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,27 +5,29 @@
from itertools import product
from scipy import fft
from sklearn.cluster import AgglomerativeClustering
from typing import Any, Tuple, Generator

from .utils import *
from .merge_imgs import add_merge_args, merge_imgs

def get_filter_mask(img, r):
def get_filter_mask(img: np.ndarray, r: int) -> np.ndarray:
x, y = img.shape
mask = np.zeros((x, y), dtype="uint8")
cv2.circle(mask, (y//2, x//2), r, 255, -1)
return mask

def apply_filter(fft, img, filter_mask):
def apply_filter(fft: Any, img: np.ndarray, filter_mask: np.ndarray) -> np.ndarray:
res = fft.fftshift(img)
res[filter_mask == 0] = 0
return fft.ifftshift(res)

def corr(a1, a2):
def corr(a1: np.ndarray, a2: np.ndarray) -> float:
if len(a1) == 0 or len(a2) == 0:
return 0
return np.corrcoef(a1, a2)[0,1]

def get_overlap(img1, img2, coords, min_overlap=0.):
def get_overlap(img1: np.ndarray, img2: np.ndarray,
coords: Tuple[int], min_overlap: float=0.) -> Tuple[float,int]:
dx, dy = coords
assert img1.shape == img2.shape
Y, X = img1.shape
Expand All @@ -42,20 +44,23 @@ def get_overlap(img1, img2, coords, min_overlap=0.):
f1, f2 = s1.flatten(), s2.flatten()
return corr(f1, f2), area

def centroids(coords, labels):
def centroids(coords: np.ndarray, labels: np.ndarray) -> Generator[np.ndarray, None, None]:
for c in range(labels.max()+1):
yield round_int_np(coords[labels == c].mean(axis=0))

def get_peak_centroids(args, res):
#yield round_int_np(np.unravel_index(np.argmax(res), res.shape))
#cutoff = res > (res.mean() + args.peak_cutoff_std * res.std())
cutoff = res > (res.max() - args.peak_cutoff_std * res.std())
def get_peak_centroids(res: np.ndarray,
peak_cutoff_std: float,
peaks_dist_threshold: float) -> Generator[np.ndarray, None, None]:
''' Cluster peaks that are within `peak_cutoff_std` standard deviations
below maximum peak, then yields centroid of clusters
'''
cutoff = res > (res.max() - peak_cutoff_std * res.std())
if cutoff.sum() > 2:
X = np.argwhere(cutoff)
labels = AgglomerativeClustering(
n_clusters=None,
linkage='single',
distance_threshold=args.peaks_dist_threshold
distance_threshold=peaks_dist_threshold
).fit(X).labels_
cents = list(centroids(X, labels))
yield from sorted(cents, key=lambda coord: res[tuple(coord)])
Expand All @@ -67,32 +72,59 @@ def get_peak_centroids(args, res):
['corr_coeff', 'corr', 'coord', 'val', 'area', 'best_r', 'best_win']
)

def candidate_stitches(args, img1, img2):
assert img1.shape == img2.shape
win = cv2.createHanningWindow(img1.T.shape, cv2.CV_64F)
def print_stitching_result(r: StitchingResult):
dx, dy = r.coord
print(f'dx:{dx: 5} dy:{dy: 5} corr:{r.corr_coeff:+f} area:{r.area: 9} r:{r.best_r: 3}')

def candidate_stitches(img1: np.ndarray, img2: np.ndarray,
use_wins: Tuple[int] = (0,),
workers: int = 2,
peak_cutoff_std: float = 1,
peaks_dist_threshold: float = 25,
filter_radii: Tuple[int] = (100,50,20),
min_overlap: float = 0.125,
early_term_thresh: float = 0.7,
verbose: bool = False,
) -> Generator[StitchingResult, None, None]:
assert img1.shape == img2.shape, 'Images must be of same size!'
assert len(img1.shape) == 2, 'Image must be 2D array (one color channel)'
Y, X = img1.shape
for use_win in args.use_wins:

# Create window if required
if 1 in use_wins:
win = cv2.createHanningWindow(img1.T.shape, cv2.CV_64F)

for use_win in use_wins:
# 1. Compute FFT of input images
f1, f2 = [fft.fft2(img * win if use_win else img,
norm='ortho', workers=args.workers)
norm='ortho', workers=workers)
for img in (img1, img2)]
for r in args.filter_radius:

for r in filter_radii:
# 2. Apply low-pass filter to images in frequency domain
mask = get_filter_mask(img1, r)
G1, G2 = [apply_filter(fft, f, mask) for f in (f1, f2)]

# 3. Compute cross power spectrum
R = G1 * np.ma.conjugate(G2)
R /= np.absolute(R)
res = fft.ifft2(R, img1.shape, norm='ortho', workers=args.workers)
for dy, dx in get_peak_centroids(args, res):

# 4. Obtain cross correlation in spatial domain by taking inverse FFT
res = fft.ifft2(R, img1.shape, norm='ortho', workers=workers)

# 5. Group peaks and find centroids of groups
for dy, dx in get_peak_centroids(res, peak_cutoff_std, peaks_dist_threshold):
for dX, dY in product((dx, -X+dx), (dy, -Y+dy)):
coef, area = get_overlap(img1, img2, (dX, dY),
min_overlap=args.min_overlap)
if args.verbose:
print(f'dx:{dX: 5} dy:{dY: 5} corr:{coef:+f} area:{area: 9} r:{r: 3}')
yield StitchingResult(coef, res, (dX, dY), res[dY, dX], area, r, use_win)
if coef >= args.early_term_thresh:
coef, area = get_overlap(img1, img2, (dX, dY), min_overlap=min_overlap)
result = StitchingResult(coef, res, (dX, dY), res[dY, dX], area, r, use_win)
if verbose:
print_stitching_result(result)
yield result
if coef >= early_term_thresh:
return

def stitch(args, img1, img2):
return max(candidate_stitches(args, img1, img2), key=lambda r: r.corr_coeff)
def stitch(img1: np.ndarray, img2: np.ndarray, **kwargs) -> StitchingResult:
return max(candidate_stitches(img1, img2, **kwargs), key=lambda r: r.corr_coeff)

def add_stitching_args(parser):
parser.add_argument('--ext',
Expand All @@ -104,7 +136,7 @@ def add_stitching_args(parser):
parser.add_argument('--workers', type=int,
help='Number of CPU threads to use in FFT',
default=2)
parser.add_argument('--min_overlap', type=int,
parser.add_argument('--min_overlap', type=float,
help='Set lower limit for overlapping region as a fraction of total image area',
default=0.125)
parser.add_argument('--early_term_thresh', type=float,
Expand All @@ -119,7 +151,7 @@ def add_stitching_args(parser):
parser.add_argument('--peaks_dist_threshold', type=float,
help='Distance to consider as part of same cluster when finding peak centroid',
default=25)
parser.add_argument('--filter_radius', nargs="+", type=int,
parser.add_argument('--filter_radii', nargs="+", type=int,
default=(100,50,20),
help='Low-pass filter radii to try, smaller matches coarser/out-of-focus features')
return parser
69 changes: 61 additions & 8 deletions tests/test_stitching.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,76 @@
import unittest
import numpy as np
from multifocal_stitching.stitching import *
from multifocal_stitching.utils import *
import os
import sys
import shutil
import csv
from multifocal_stitching.stitching import stitch
from multifocal_stitching.utils import read_img
from multifocal_stitching.merge_imgs import merge_imgs
from multifocal_stitching.__main__ import main as cli
from multifocal_stitching.__main__ import CSV_HEADER

def coord_is_close(res, val, tol=5):
assert np.linalg.norm(np.array(res.coord) - np.array(val), 1) <= tol

def get_full_path(base_dir, filename, mkdir=False):
path = os.path.join(base_dir, filename)
if mkdir and not os.path.exists(path):
os.makedirs(path)
return path

class TestCLI(unittest.TestCase):
def setUp(self):
self.base_dir = 'tests/img_folder'
shutil.rmtree(self.base_dir, ignore_errors=True)
os.makedirs(self.base_dir)
shutil.copy('tests/imgs/high_freq_features_1_small.jpg', self.base_dir)
shutil.copy('tests/imgs/high_freq_features_2_small.jpg', self.base_dir)
sys.argv = ['', self.base_dir]

def test_cli(self):
cli()
csvfilename = get_full_path(self.base_dir, 'stitching_result.csv')
self.assertTrue(os.path.isfile(csvfilename))
with open(csvfilename) as csvfile:
reader = csv.reader(csvfile)
self.assertEqual(next(reader), CSV_HEADER)
img_name1, img_name2, dx, dy, corr_coeff, area, best_r, best_win = next(reader)
self.assertEqual(img_name1, 'high_freq_features_1_small.jpg')
self.assertEqual(img_name2, 'high_freq_features_2_small.jpg')
self.assertAlmostEqual(int(dx), 2474, delta=2)
self.assertAlmostEqual(int(dy), 495, delta=2)
self.assertAlmostEqual(float(corr_coeff), 0.5548805792236229)
self.assertEqual(int(area), 2274390)
self.assertEqual(int(best_r), 50)
self.assertEqual(int(best_win), 0)

merged_name = os.path.join(
self.base_dir,
'merged',
'high_freq_features_1_small__high_freq_features_2_small.jpg'
)
merged_r_name = os.path.join(
self.base_dir,
'merged',
'high_freq_features_1_small__high_freq_features_2_small_r.jpg'
)
self.assertTrue(os.path.isfile(merged_name))
self.assertTrue(os.path.isfile(merged_r_name))
merged = read_img(merged_name)
merged_r = read_img(merged_r_name)
self.assertEqual(merged.shape, merged_r.shape)
self.assertEqual(merged.shape, (2655, 6314))

class TestStitching(unittest.TestCase):
def setUp(self):
parser = add_stitching_args(add_merge_args(get_default_parser()))
self.args = parser.parse_args(['tests/imgs'])
self.base_dir = 'tests/imgs'

def stitch_name(self, name):
names = [f'{name}_{ext}_small.jpg' for ext in '12']
img_names = [get_full_path(self.args, name) for name in names]
res = stitch(self.args, *map(read_img, img_names))
res_dir = get_full_path(self.args, self.args.result_dir, mkdir=True)
res = stitch(*[read_img(get_full_path(self.base_dir, name)) for name in names])
res_dir = get_full_path(self.base_dir, 'merged', mkdir=True)
dx, dy = res.coord
merge_imgs(self.args, res_dir, names[0], names[1], dx, dy)
return res

def test_stitching_high_freq_features(self):
Expand Down

0 comments on commit be2f6d9

Please sign in to comment.