-
Notifications
You must be signed in to change notification settings - Fork 32
/
Copy pathmisc.py
318 lines (240 loc) · 10.2 KB
/
misc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
"""Miscellaneous functions and classes for the geocalib_inference package."""
import functools
import inspect
import logging
from typing import Callable, List
import numpy as np
import torch
logger = logging.getLogger(__name__)
# mypy: ignore-errors
def autocast(func: Callable) -> Callable:
"""Cast the inputs of a TensorWrapper method to PyTorch tensors if they are numpy arrays.
Use the device and dtype of the wrapper.
Args:
func (Callable): Method of a TensorWrapper class.
Returns:
Callable: Wrapped method.
"""
@functools.wraps(func)
def wrap(self, *args):
device = torch.device("cpu")
dtype = None
if isinstance(self, TensorWrapper):
if self._data is not None:
device = self.device
dtype = self.dtype
elif not inspect.isclass(self) or not issubclass(self, TensorWrapper):
raise ValueError(self)
cast_args = []
for arg in args:
if isinstance(arg, np.ndarray):
arg = torch.from_numpy(arg)
arg = arg.to(device=device, dtype=dtype)
cast_args.append(arg)
return func(self, *cast_args)
return wrap
class TensorWrapper:
"""Wrapper for PyTorch tensors."""
_data = None
@autocast
def __init__(self, data: torch.Tensor):
"""Wrapper for PyTorch tensors."""
self._data = data
@property
def shape(self) -> torch.Size:
"""Shape of the underlying tensor."""
return self._data.shape[:-1]
@property
def device(self) -> torch.device:
"""Get the device of the underlying tensor."""
return self._data.device
@property
def dtype(self) -> torch.dtype:
"""Get the dtype of the underlying tensor."""
return self._data.dtype
def __getitem__(self, index) -> torch.Tensor:
"""Get the underlying tensor."""
return self.__class__(self._data[index])
def __setitem__(self, index, item):
"""Set the underlying tensor."""
self._data[index] = item.data
def to(self, *args, **kwargs):
"""Move the underlying tensor to a new device."""
return self.__class__(self._data.to(*args, **kwargs))
def cpu(self):
"""Move the underlying tensor to the CPU."""
return self.__class__(self._data.cpu())
def cuda(self):
"""Move the underlying tensor to the GPU."""
return self.__class__(self._data.cuda())
def pin_memory(self):
"""Pin the underlying tensor to memory."""
return self.__class__(self._data.pin_memory())
def float(self):
"""Cast the underlying tensor to float."""
return self.__class__(self._data.float())
def double(self):
"""Cast the underlying tensor to double."""
return self.__class__(self._data.double())
def detach(self):
"""Detach the underlying tensor."""
return self.__class__(self._data.detach())
def numpy(self):
"""Convert the underlying tensor to a numpy array."""
return self._data.detach().cpu().numpy()
def new_tensor(self, *args, **kwargs):
"""Create a new tensor of the same type and device."""
return self._data.new_tensor(*args, **kwargs)
def new_zeros(self, *args, **kwargs):
"""Create a new tensor of the same type and device."""
return self._data.new_zeros(*args, **kwargs)
def new_ones(self, *args, **kwargs):
"""Create a new tensor of the same type and device."""
return self._data.new_ones(*args, **kwargs)
def new_full(self, *args, **kwargs):
"""Create a new tensor of the same type and device."""
return self._data.new_full(*args, **kwargs)
def new_empty(self, *args, **kwargs):
"""Create a new tensor of the same type and device."""
return self._data.new_empty(*args, **kwargs)
def unsqueeze(self, *args, **kwargs):
"""Create a new tensor of the same type and device."""
return self.__class__(self._data.unsqueeze(*args, **kwargs))
def squeeze(self, *args, **kwargs):
"""Create a new tensor of the same type and device."""
return self.__class__(self._data.squeeze(*args, **kwargs))
@classmethod
def stack(cls, objects: List, dim=0, *, out=None):
"""Stack a list of objects with the same type and shape."""
data = torch.stack([obj._data for obj in objects], dim=dim, out=out)
return cls(data)
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
"""Support torch functions."""
if kwargs is None:
kwargs = {}
return cls.stack(*args, **kwargs) if func is torch.stack else NotImplemented
class EuclideanManifold:
"""Simple euclidean manifold."""
@staticmethod
def J_plus(x: torch.Tensor) -> torch.Tensor:
"""Plus operator Jacobian."""
return torch.eye(x.shape[-1]).to(x)
@staticmethod
def plus(x: torch.Tensor, delta: torch.Tensor) -> torch.Tensor:
"""Plus operator."""
return x + delta
class SphericalManifold:
"""Implementation of the spherical manifold.
Following the derivation from 'Integrating Generic Sensor Fusion Algorithms with Sound State
Representations through Encapsulation of Manifolds' by Hertzberg et al. (B.2, p. 25).
Householder transformation following Algorithm 5.1.1 (p. 210) from 'Matrix Computations' by
Golub et al.
"""
@staticmethod
def householder_vector(x: torch.Tensor) -> torch.Tensor:
"""Return the Householder vector and beta.
Algorithm 5.1.1 (p. 210) from 'Matrix Computations' by Golub et al. (Johns Hopkins Studies
in Mathematical Sciences) but using the nth element of the input vector as pivot instead of
first.
This computes the vector v with v(n) = 1 and beta such that H = I - beta * v * v^T is
orthogonal and H * x = ||x||_2 * e_n.
Args:
x (torch.Tensor): [..., n] tensor.
Returns:
torch.Tensor: v of shape [..., n]
torch.Tensor: beta of shape [...]
"""
sigma = torch.sum(x[..., :-1] ** 2, -1)
xpiv = x[..., -1]
norm = torch.norm(x, dim=-1)
if torch.any(sigma < 1e-7):
sigma = torch.where(sigma < 1e-7, sigma + 1e-7, sigma)
logger.warning("sigma < 1e-7")
vpiv = torch.where(xpiv < 0, xpiv - norm, -sigma / (xpiv + norm))
beta = 2 * vpiv**2 / (sigma + vpiv**2)
v = torch.cat([x[..., :-1] / vpiv[..., None], torch.ones_like(vpiv)[..., None]], -1)
return v, beta
@staticmethod
def apply_householder(y: torch.Tensor, v: torch.Tensor, beta: torch.Tensor) -> torch.Tensor:
"""Apply Householder transformation.
Args:
y (torch.Tensor): Vector to transform of shape [..., n].
v (torch.Tensor): Householder vector of shape [..., n].
beta (torch.Tensor): Householder beta of shape [...].
Returns:
torch.Tensor: Transformed vector of shape [..., n].
"""
return y - v * (beta * torch.einsum("...i,...i->...", v, y))[..., None]
@classmethod
def J_plus(cls, x: torch.Tensor) -> torch.Tensor:
"""Plus operator Jacobian."""
v, beta = cls.householder_vector(x)
H = -torch.einsum("..., ...k, ...l->...kl", beta, v, v)
H = H + torch.eye(H.shape[-1]).to(H)
return H[..., :-1] # J
@classmethod
def plus(cls, x: torch.Tensor, delta: torch.Tensor) -> torch.Tensor:
"""Plus operator.
Equation 109 (p. 25) from 'Integrating Generic Sensor Fusion Algorithms with Sound State
Representations through Encapsulation of Manifolds' by Hertzberg et al. but using the nth
element of the input vector as pivot instead of first.
Args:
x: point on the manifold
delta: tangent vector
"""
eps = 1e-7
# keep norm is not equal to 1
nx = torch.norm(x, dim=-1, keepdim=True)
nd = torch.norm(delta, dim=-1, keepdim=True)
# make sure we don't divide by zero in backward as torch.where computes grad for both
# branches
nd_ = torch.where(nd < eps, nd + eps, nd)
sinc = torch.where(nd < eps, nd.new_ones(nd.shape), torch.sin(nd_) / nd_)
# cos is applied to last dim instead of first
exp_delta = torch.cat([sinc * delta, torch.cos(nd)], -1)
v, beta = cls.householder_vector(x)
return nx * cls.apply_householder(exp_delta, v, beta)
@torch.jit.script
def J_vecnorm(vec: torch.Tensor) -> torch.Tensor:
"""Compute the jacobian of vec / norm2(vec).
Args:
vec (torch.Tensor): [..., D] tensor.
Returns:
torch.Tensor: [..., D, D] Jacobian.
"""
D = vec.shape[-1]
norm_x = torch.norm(vec, dim=-1, keepdim=True).unsqueeze(-1) # (..., 1, 1)
if (norm_x == 0).any():
norm_x = norm_x + 1e-6
xxT = torch.einsum("...i,...j->...ij", vec, vec) # (..., D, D)
identity = torch.eye(D, device=vec.device, dtype=vec.dtype) # (D, D)
return identity / norm_x - (xxT / norm_x**3) # (..., D, D)
@torch.jit.script
def J_focal2fov(focal: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
"""Compute the jacobian of the focal2fov function."""
return -4 * h / (4 * focal**2 + h**2)
@torch.jit.script
def J_up_projection(uv: torch.Tensor, abc: torch.Tensor, wrt: str = "uv") -> torch.Tensor:
"""Compute the jacobian of the up-vector projection.
Args:
uv (torch.Tensor): Normalized image coordinates of shape (..., 2).
abc (torch.Tensor): Gravity vector of shape (..., 3).
wrt (str, optional): Parameter to differentiate with respect to. Defaults to "uv".
Raises:
ValueError: If the wrt parameter is unknown.
Returns:
torch.Tensor: Jacobian with respect to the parameter.
"""
if wrt == "uv":
c = abc[..., 2][..., None, None, None]
return -c * torch.eye(2, device=uv.device, dtype=uv.dtype).expand(uv.shape[:-1] + (2, 2))
elif wrt == "abc":
J = uv.new_zeros(uv.shape[:-1] + (2, 3))
J[..., 0, 0] = 1
J[..., 1, 1] = 1
J[..., 0, 2] = -uv[..., 0]
J[..., 1, 2] = -uv[..., 1]
return J
else:
raise ValueError(f"Unknown wrt: {wrt}")