import torch
import torch.nn.functional as F
import numpy as np


def numpy_to_torch(a: np.ndarray):
    """Convert an HxWxC numpy image to a 1xCxHxW float torch tensor."""
    return torch.from_numpy(a).float().permute(2, 0, 1).unsqueeze(0)


def torch_to_numpy(a: torch.Tensor):
    """Convert a 1xCxHxW torch tensor back to an HxWxC numpy image."""
    return a.squeeze(0).permute(1, 2, 0).numpy()
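

# A minimal round-trip sketch of the two converters above (the 480x640x3
# random image is an illustrative placeholder, not data from this module):
#
#   img = np.random.rand(480, 640, 3).astype(np.float32)
#   t = numpy_to_torch(img)      # shape: (1, 3, 480, 640)
#   back = torch_to_numpy(t)     # shape: (480, 640, 3)
#   assert np.allclose(img, back)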


def sample_patch_transformed(im, pos, scale, image_sz, transforms, is_mask=False):
    """Extract transformed image samples.

    args:
        im: Image.
        pos: Center position for extraction.
        scale: Image scale to extract features from.
        image_sz: Size to resize the image samples to before extraction.
        transforms: A set of image transforms to apply.
        is_mask: Whether the input is a segmentation mask (masks are zero-padded and
            resized with nearest-neighbor interpolation to keep labels intact).
    """

    # Get image patch
    im_patch, _ = sample_patch(im, pos, scale*image_sz, image_sz, is_mask=is_mask)

    # Apply transforms
    im_patches = torch.cat([T(im_patch, is_mask=is_mask) for T in transforms])

    return im_patches
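

# A hedged usage sketch for sample_patch_transformed. The transform objects are
# assumed to be callables taking (patch, is_mask=...), matching how they are
# invoked above; the augmentation class names below are illustrative, not
# definitions from this module:
#
#   transforms = [augmentation.Identity(), augmentation.FlipHorizontal()]
#   patches = sample_patch_transformed(im, pos=torch.Tensor([240., 320.]), scale=1.0,
#                                      image_sz=torch.Tensor([288., 288.]),
#                                      transforms=transforms)
#   # patches: (len(transforms), C, 288, 288)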


def sample_patch_multiscale(im, pos, scales, image_sz, mode: str = 'replicate', max_scale_change=None):
    """Extract image patches at multiple scales.

    args:
        im: Image.
        pos: Center position for extraction.
        scales: Image scales to extract image patches from.
        image_sz: Size to resize the image samples to.
        mode: How to treat image borders: 'replicate' (default), 'inside' or 'inside_major'.
        max_scale_change: Maximum allowed scale change when using 'inside' and 'inside_major' mode.
    """
    if isinstance(scales, (int, float)):
        scales = [scales]

    # Get image patches
    patch_iter, coord_iter = zip(*(sample_patch(im, pos, s*image_sz, image_sz, mode=mode,
                                                max_scale_change=max_scale_change) for s in scales))
    im_patches = torch.cat(list(patch_iter))
    patch_coords = torch.cat(list(coord_iter))

    return im_patches, patch_coords
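

# A minimal sketch of multi-scale extraction, e.g. for a tracker's scale search
# (the scale factors below are illustrative):
#
#   patches, coords = sample_patch_multiscale(im, pos=torch.Tensor([240., 320.]),
#                                             scales=[0.96, 1.0, 1.04],
#                                             image_sz=torch.Tensor([288., 288.]))
#   # patches: (3, C, 288, 288); coords: (3, 4) rows of (y0, x0, y1, x1)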


def sample_patch(im: torch.Tensor, pos: torch.Tensor, sample_sz: torch.Tensor, output_sz: torch.Tensor = None,
                 mode: str = 'replicate', max_scale_change=None, is_mask=False):
    """Sample an image patch.

    args:
        im: Image.
        pos: Center position of the crop.
        sample_sz: Size to crop.
        output_sz: Size to resize to.
        mode: How to treat image borders: 'replicate' (default), 'inside' or 'inside_major'.
        max_scale_change: Maximum allowed scale change when using 'inside' and 'inside_major' mode.
        is_mask: Whether the input is a segmentation mask; masks are zero-padded and resized
            with nearest-neighbor interpolation to keep labels intact.
    """
    # Copy the position and convert to integer (long) pixel coordinates
    posl = pos.long().clone()

    pad_mode = mode

    # Get new sample size if forced inside the image
    if mode == 'inside' or mode == 'inside_major':
        pad_mode = 'replicate'
        im_sz = torch.Tensor([im.shape[2], im.shape[3]])
        shrink_factor = (sample_sz.float() / im_sz)
        if mode == 'inside':
            shrink_factor = shrink_factor.max()
        elif mode == 'inside_major':
            shrink_factor = shrink_factor.min()
        shrink_factor.clamp_(min=1, max=max_scale_change)
        sample_sz = (sample_sz.float() / shrink_factor).long()
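        # Worked example (illustrative): for a 480x640 image with sample_sz = [700, 700]
        # and mode='inside', shrink_factor = max(700/480, 700/640) ~= 1.458, so the crop
        # shrinks to [480, 480] and is guaranteed to fit inside the image.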

    # Compute the integer pre-downsampling factor
    if output_sz is not None:
        resize_factor = torch.min(sample_sz.float() / output_sz.float()).item()
        df = max(int(resize_factor - 0.1), 1)
    else:
        df = 1

    sz = sample_sz.float() / df     # new size after downsampling
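    # Worked example (illustrative): sample_sz = [520, 520] with output_sz = [128, 128]
    # gives resize_factor = 4.0625, so df = max(int(4.0625 - 0.1), 1) = 3 and the image
    # is first decimated 3x before the final resize.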

    # Do downsampling by strided slicing
    if df > 1:
        os = posl % df                                     # offset
        posl = (posl - os) // df                           # new position (floor division keeps it a LongTensor)
        im2 = im[..., os[0].item()::df, os[1].item()::df]  # downsample
    else:
        im2 = im

    # Compute size to crop
    szl = torch.max(sz.round(), torch.Tensor([2])).long()

    # Extract top-left and bottom-right coordinates (floor division keeps them LongTensors)
    tl = posl - (szl - 1) // 2
    br = posl + szl // 2 + 1
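    # Worked example (illustrative): posl = 240 with szl = 200 gives
    # tl = 240 - 99 = 141 and br = 240 + 101 = 341, a crop of exactly szl = 200 rows.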

    # Shift the crop so it lies inside the image
    if mode == 'inside' or mode == 'inside_major':
        im2_sz = torch.LongTensor([im2.shape[2], im2.shape[3]])
        shift = (-tl).clamp(0) - (br - im2_sz).clamp(0)
        tl += shift
        br += shift

        # If the crop is still larger than the image, center it on the image
        outside = ((-tl).clamp(0) + (br - im2_sz).clamp(0)) // 2
        shift = (-tl - outside) * (outside > 0).long()
        tl += shift
        br += shift

    # Get image patch, replicate-padding (or zero-padding for masks) where the crop
    # extends outside the image; negative pad values crop the tensor
    if not is_mask:
        im_patch = F.pad(im2, (-tl[1].item(), br[1].item() - im2.shape[3], -tl[0].item(), br[0].item() - im2.shape[2]), pad_mode)
    else:
        im_patch = F.pad(im2, (-tl[1].item(), br[1].item() - im2.shape[3], -tl[0].item(), br[0].item() - im2.shape[2]))

    # Get image coordinates of the patch, in original-image pixels
    patch_coord = df * torch.cat((tl, br)).view(1, 4)

    if output_sz is None or (im_patch.shape[-2] == output_sz[0] and im_patch.shape[-1] == output_sz[1]):
        return im_patch.clone(), patch_coord

    # Resample to the requested output size
    if not is_mask:
        im_patch = F.interpolate(im_patch, output_sz.long().tolist(), mode='bilinear', align_corners=False)
    else:
        im_patch = F.interpolate(im_patch, output_sz.long().tolist(), mode='nearest')

    return im_patch, patch_coord
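

if __name__ == '__main__':
    # Minimal self-test sketch. The image size, position and crop sizes below are
    # illustrative assumptions, not values used elsewhere in this module.
    img = np.random.rand(480, 640, 3).astype(np.float32)
    im = numpy_to_torch(img)

    pos = torch.Tensor([240., 320.])         # (y, x) crop center
    sample_sz = torch.Tensor([200., 200.])   # crop size in the image
    output_sz = torch.Tensor([128., 128.])   # size to resize the crop to

    patch, coord = sample_patch(im, pos, sample_sz, output_sz)
    print('patch shape:', tuple(patch.shape))   # expected: (1, 3, 128, 128)
    print('patch coord:', coord.tolist())       # (y0, x0, y1, x1) in image pixels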