-
Notifications
You must be signed in to change notification settings - Fork 74
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Dense matching based line matcher #87
Draft
B1ueber2y
wants to merge
30
commits into
main
Choose a base branch
from
features/dense_matching
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from 7 commits
Commits
Show all changes
30 commits
Select commit
Hold shift + click to select a range
7490bec
dense matcher based line matching
B1ueber2y 237e215
formatting.
B1ueber2y 009b0db
add a test script.
B1ueber2y b95f8c5
minor.
B1ueber2y 15e88b3
formatting.
B1ueber2y 62ed189
Merge branch 'main' into features/dense_matching
B1ueber2y 4acbf88
fix formatting with ruff.
B1ueber2y 786211c
add tiny roma.
B1ueber2y be25734
Merge branch 'main' into features/dense_matching
B1ueber2y 7248ddd
minor.
B1ueber2y 1207f5f
Fix different input/output conventions for tiny RoMa
71b9a07
formattin.
B1ueber2y 9df66fc
refactor. set overlap threshold to 0.2
B1ueber2y 9d43541
Merge branch 'main' into features/dense_matching
B1ueber2y 8f2b91a
Minor simplifications
rpautrat b98ade9
RoMa mode in config
rpautrat 2e967af
One-to-many matching
rpautrat f20ed21
Mutual nearest neighbors + small fixes
rpautrat d355819
Merge branch 'main' into features/dense_matching
B1ueber2y 9141712
merge and fix linting issues.
B1ueber2y c8bb7d8
Merge branch 'main' into features/dense_matching
B1ueber2y bb5b73a
Merge branch 'main' into features/dense_matching
B1ueber2y ae78791
Merge branch 'main' into features/dense_matching
B1ueber2y 9402743
Make Gluestick install editable
b7a7b0f
Update the dense matcher configuration
8b5af66
Format
rpautrat 48782c1
Merge branch 'main' into features/dense_matching
B1ueber2y 12c9684
Revert editable GlueStick
rpautrat bf729ed
Merge branch 'features/dense_matching' of github.com:cvg/limap into f…
rpautrat adc1047
Merge branch 'main' into features/dense_matching
B1ueber2y File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .extractor import DenseNaiveExtractor | ||
from .matcher import RoMaLineMatcher |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .roma import RoMa |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import os | ||
|
||
import torch | ||
|
||
|
||
class BaseDenseMatcher: | ||
def __init__(self): | ||
pass | ||
|
||
def to_normalized_coordinates(self, coords, h, w): | ||
""" | ||
coords: (..., 2) in the order x, y | ||
""" | ||
coords_x = 2 / w * coords[..., 0] - 1 | ||
coords_y = 2 / h * coords[..., 1] - 1 | ||
return torch.stack([coords_x, coords_y], axis=-1) | ||
|
||
def to_unnormalized_coordinates(self, coords, h, w): | ||
""" | ||
Inverse operation of `to_normalized_coordinates` | ||
""" | ||
coords_x = (coords[..., 0] + 1) * w / 2 | ||
coords_y = (coords[..., 1] + 1) * h / 2 | ||
return torch.stack([coords_x, coords_y], axis=-1) | ||
|
||
def get_sample_thresh(self): | ||
""" | ||
return sample threshold | ||
""" | ||
raise NotImplementedError | ||
|
||
def get_warpping_symmetric(self, img1, img2): | ||
""" | ||
return warp_1to2 ([-1, 1]), cert_1to2, warp_2to1([-1, 1]), cert_2to1 | ||
""" | ||
raise NotImplementedError |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
import os | ||
|
||
import romatch | ||
from PIL import Image | ||
|
||
from .base import BaseDenseMatcher | ||
|
||
|
||
class RoMa(BaseDenseMatcher): | ||
def __init__(self, mode="outdoor", device="cuda"): | ||
super(RoMa).__init__() | ||
if mode == "outdoor": | ||
self.model = romatch.roma_outdoor(device=device, coarse_res=560) | ||
elif mode == "indoor": | ||
self.model = romatch.roma_indoor(device=device, coarse_res=560) | ||
|
||
def get_sample_thresh(self): | ||
return self.model.sample_thresh | ||
|
||
def get_warpping_symmetric(self, img1, img2): | ||
warp, certainty = self.model.match( | ||
Image.fromarray(img1), Image.fromarray(img2) | ||
) | ||
N = 864 | ||
return ( | ||
warp[:, :N, 2:], | ||
certainty[:, :N], | ||
warp[:, N:, :2], | ||
certainty[:, N:], | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
import os | ||
|
||
import numpy as np | ||
|
||
import limap.util.io as limapio | ||
|
||
from ..base_detector import BaseDetector, BaseDetectorOptions | ||
|
||
|
||
class DenseNaiveExtractor(BaseDetector): | ||
def __init__(self, options=BaseDetectorOptions(), device=None): | ||
super().__init__(options) | ||
|
||
def get_module_name(self): | ||
return "dense_naive" | ||
|
||
def get_descinfo_fname(self, descinfo_folder, img_id): | ||
fname = os.path.join(descinfo_folder, f"descinfo_{img_id}.npz") | ||
return fname | ||
|
||
def save_descinfo(self, descinfo_folder, img_id, descinfo): | ||
limapio.check_makedirs(descinfo_folder) | ||
fname = self.get_descinfo_fname(descinfo_folder, img_id) | ||
limapio.save_npz(fname, descinfo) | ||
|
||
def read_descinfo(self, descinfo_folder, img_id): | ||
fname = self.get_descinfo_fname(descinfo_folder, img_id) | ||
descinfo = limapio.read_npz(fname) | ||
return descinfo | ||
|
||
def extract(self, camview, segs): | ||
img = camview.read_image(set_gray=self.set_gray) | ||
lines = segs[:, :4].reshape(-1, 2, 2) | ||
scores = segs[:, -1] * np.sqrt( | ||
np.linalg.norm(segs[:, :2] - segs[:, 2:4], axis=1) | ||
) | ||
descinfo = { | ||
"camview": camview, | ||
"image_shape": img.shape, | ||
"lines": lines, | ||
"scores": scores, | ||
} | ||
return descinfo |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,209 @@ | ||
import os | ||
from typing import NamedTuple | ||
|
||
import numpy as np | ||
import torch | ||
import torch.nn.functional as F | ||
|
||
import limap.util.io as limapio | ||
|
||
from ..base_matcher import BaseMatcher, BaseMatcherOptions | ||
|
||
|
||
class BaseDenseLineMatcherOptions(NamedTuple): | ||
n_samples: int = 21 | ||
segment_percentage_th: float = 0.5 | ||
device = "cuda" | ||
pixel_th: float = 10.0 | ||
|
||
|
||
class BaseDenseLineMatcher(BaseMatcher): | ||
def __init__( | ||
self, | ||
extractor, | ||
dense_matcher, | ||
dense_options=BaseDenseLineMatcherOptions(), | ||
options=BaseMatcherOptions(), | ||
): | ||
super().__init__(extractor, options) | ||
assert self.extractor.get_module_name() == "dense_naive" | ||
self.dense_matcher = dense_matcher | ||
self.dense_options = dense_options | ||
assert self.dense_options.n_samples >= 2 | ||
|
||
def get_module_name(self): | ||
raise NotImplementedError | ||
|
||
def match_pair(self, descinfo1, descinfo2): | ||
if self.topk == 0: | ||
return self.match_segs_with_descinfo(descinfo1, descinfo2) | ||
else: | ||
return self.match_segs_with_descinfo_topk( | ||
descinfo1, descinfo2, topk=self.topk | ||
) | ||
|
||
def compute_distance_one_direction( | ||
self, descinfo1, descinfo2, warp_1to2, cert_1to2 | ||
): | ||
# get point samples along lines | ||
segs1 = torch.from_numpy(descinfo1["lines"]).to( | ||
self.dense_options.device | ||
) | ||
n_segs1 = segs1.shape[0] | ||
ratio = torch.arange( | ||
0, | ||
1 + 0.5 / (self.dense_options.n_samples - 1), | ||
1.0 / (self.dense_options.n_samples - 1), | ||
).to(self.dense_options.device) | ||
ratio = ratio[:, None].repeat(1, 2) | ||
coords_1 = ratio * segs1[:, [0], :].repeat( | ||
1, self.dense_options.n_samples, 1 | ||
) + (1 - ratio) * segs1[:, [1], :].repeat( | ||
1, self.dense_options.n_samples, 1 | ||
) | ||
coords_1 = coords_1.reshape(-1, 2) | ||
coords = self.dense_matcher.to_normalized_coordinates( | ||
coords_1, descinfo1["image_shape"][0], descinfo1["image_shape"][1] | ||
) | ||
coords_to_2 = F.grid_sample( | ||
warp_1to2.permute(2, 0, 1)[None], | ||
coords[None, None], | ||
align_corners=False, | ||
mode="bilinear", | ||
)[0, :, 0].mT | ||
coords_to_2 = self.dense_matcher.to_unnormalized_coordinates( | ||
coords_to_2, | ||
descinfo2["image_shape"][0], | ||
descinfo2["image_shape"][1], | ||
) | ||
cert_to_2 = F.grid_sample( | ||
cert_1to2[None, None, ...], | ||
coords[None, None], | ||
align_corners=False, | ||
mode="bilinear", | ||
)[0, 0, 0] | ||
cert_to_2 = cert_to_2.reshape(-1, self.dense_options.n_samples) | ||
|
||
# get projections | ||
segs2 = torch.from_numpy(descinfo2["lines"]).to( | ||
self.dense_options.device | ||
) | ||
n_segs2 = segs2.shape[0] | ||
starts2, ends2 = segs2[:, 0, :], segs2[:, 1, :] | ||
directions = ends2 - starts2 | ||
directions /= torch.norm(directions, dim=1, keepdim=True) | ||
starts2_proj = (starts2 * directions).sum(1) | ||
ends2_proj = (ends2 * directions).sum(1) | ||
|
||
# get line equations | ||
starts_homo = torch.cat([starts2, torch.ones_like(segs2[:, [0], 0])], 1) | ||
ends_homo = torch.cat([ends2, torch.ones_like(segs2[:, [0], 0])], 1) | ||
lines2_homo = torch.cross(starts_homo, ends_homo) | ||
lines2_homo /= torch.norm(lines2_homo[:, :2], dim=1)[:, None].repeat( | ||
1, 3 | ||
) | ||
|
||
# compute distance | ||
coords_to_2_homo = torch.cat( | ||
[coords_to_2, torch.ones_like(coords_to_2[:, [0]])], 1 | ||
) | ||
coords_proj = torch.matmul(coords_to_2, directions.T) | ||
dists = torch.abs(torch.matmul(coords_to_2_homo, lines2_homo.T)) | ||
overlap = torch.where( | ||
coords_proj > starts2_proj, | ||
torch.ones_like(dists), | ||
torch.zeros_like(dists), | ||
) | ||
overlap = torch.where( | ||
coords_proj < ends2_proj, overlap, torch.zeros_like(dists) | ||
) | ||
dists = dists.reshape( | ||
n_segs1, self.dense_options.n_samples, n_segs2 | ||
).permute(0, 2, 1) | ||
overlap = ( | ||
overlap.reshape(n_segs1, self.dense_options.n_samples, n_segs2) | ||
.permute(0, 2, 1) | ||
.to(torch.bool) | ||
) | ||
|
||
# get active lines for each target | ||
sample_thresh = self.dense_matcher.get_sample_thresh() | ||
good_sample = cert_to_2 > sample_thresh | ||
good_sample = torch.logical_and( | ||
good_sample[:, None, :].repeat(1, overlap.shape[1], 1), overlap | ||
) | ||
sample_weight = good_sample.to(torch.float) | ||
sample_weight_sum = sample_weight.sum(2) | ||
sample_weight[sample_weight_sum > 0] /= sample_weight_sum[ | ||
sample_weight_sum > 0 | ||
][:, None].repeat(1, sample_weight.shape[2]) | ||
# is_active = ( | ||
# sample_weight_sum | ||
# > self.dense_options.segment_percentage_th | ||
# * self.dense_options.n_samples | ||
# ) | ||
|
||
# get weighted dists | ||
weighted_dists = (dists * sample_weight).sum(2) | ||
weighted_dists[weighted_dists == 0] = 10000.0 | ||
return weighted_dists, sample_weight_sum / self.dense_options.n_samples | ||
|
||
def match_segs_with_descinfo(self, descinfo1, descinfo2): | ||
img1 = descinfo1["camview"].read_image() | ||
img2 = descinfo2["camview"].read_image() | ||
( | ||
warp_1to2, | ||
cert_1to2, | ||
warp_2to1, | ||
cert_2to1, | ||
) = self.dense_matcher.get_warpping_symmetric(img1, img2) | ||
|
||
# compute distance and overlap | ||
dists_1to2, overlap_1to2 = self.compute_distance_one_direction( | ||
descinfo1, descinfo2, warp_1to2, cert_1to2 | ||
) | ||
dists_2to1, overlap_2to1 = self.compute_distance_one_direction( | ||
descinfo2, descinfo1, warp_2to1, cert_2to1 | ||
) | ||
# overlap = torch.maximum(overlap_1to2, overlap_2to1.T) | ||
dists = torch.where( | ||
overlap_1to2 > overlap_2to1.T, dists_1to2, dists_2to1.T | ||
) | ||
|
||
# match: one-way nearest neighbor | ||
# TODO: one-to-many matching | ||
inds_1, inds_2 = torch.nonzero( | ||
dists | ||
== dists.min(dim=-1, keepdim=True).values | ||
* (dists <= self.dense_options.pixel_th), | ||
as_tuple=True, | ||
) | ||
inds_1 = inds_1.detach().cpu().numpy() | ||
inds_2 = inds_2.detach().cpu().numpy() | ||
matches_t = np.stack([inds_1, inds_2], axis=1) | ||
return matches_t | ||
|
||
def match_segs_with_descinfo_topk(self, descinfo1, descinfo2, topk=10): | ||
raise NotImplementedError | ||
|
||
|
||
class RoMaLineMatcher(BaseDenseLineMatcher): | ||
def __init__( | ||
self, | ||
extractor, | ||
mode="outdoor", | ||
dense_options=BaseDenseLineMatcherOptions(), | ||
options=BaseMatcherOptions(), | ||
): | ||
from .dense_matcher import RoMa | ||
|
||
roma_matcher = RoMa(mode=mode, device=dense_options.device) | ||
super().__init__( | ||
extractor, | ||
roma_matcher, | ||
dense_options=dense_options, | ||
options=options, | ||
) | ||
|
||
def get_module_name(self): | ||
return "dense_roma" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
romatch is not included in requirements.txt or dependencies
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about we making a separate installation, as we now already make HAWP, LBD separately installed to reduce dependencies.