-
Notifications
You must be signed in to change notification settings - Fork 1.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Attempt to change all Numpy calls to Torch calls #357
base: master
Are you sure you want to change the base?
Changes from 14 commits
59aef83
8bf7524
fa8c8d7
ce52619
73f720b
f96e21f
ef3dcf5
bf9b9db
df9b035
fe6cb92
f58e88c
594bb0c
f4739c2
995550e
02b9451
cc557d1
71f51d0
06efbc4
f2578d7
a4d2750
7058122
1a77f74
f4b759a
15ca2be
fb1b50d
62c1709
045d200
20d7ebd
d9dbc85
2756d71
4200b99
1489ea3
8833ee1
d492c36
5bbdf8f
b8a8a46
77b19da
85f196b
a56647d
922d2d3
5dbc8bc
20ab49f
3eceb84
901391e
2748c5c
b77aa5b
65f1b1f
9c274be
eaaf0a9
dc5db2e
c75bbef
915b99f
0305eec
0da12a0
82f71e7
199815c
cf020cf
9513c86
186c14b
b6d2202
84a0689
18f8d8e
8eaf1b7
b805465
7393e77
193c9f2
a3d327b
15de5ed
9e38599
ed64d06
273bd81
9ed4c9b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -36,7 +36,7 @@ def get_cam_weights(self, | |
target_layers: List[torch.nn.Module], | ||
targets: List[torch.nn.Module], | ||
activations: torch.Tensor, | ||
grads: torch.Tensor) -> np.ndarray: | ||
grads: torch.Tensor) -> torch.Tensor: | ||
raise Exception("Not Implemented") | ||
|
||
def get_cam_image(self, | ||
|
@@ -45,7 +45,7 @@ def get_cam_image(self, | |
targets: List[torch.nn.Module], | ||
activations: torch.Tensor, | ||
grads: torch.Tensor, | ||
eigen_smooth: bool = False) -> np.ndarray: | ||
eigen_smooth: bool = False) -> torch.Tensor: | ||
|
||
weights = self.get_cam_weights(input_tensor, | ||
target_layer, | ||
|
@@ -62,7 +62,7 @@ def get_cam_image(self, | |
def forward(self, | ||
input_tensor: torch.Tensor, | ||
targets: List[torch.nn.Module], | ||
eigen_smooth: bool = False) -> np.ndarray: | ||
eigen_smooth: bool = False) -> torch.Tensor: | ||
|
||
if self.cuda: | ||
input_tensor = input_tensor.cuda() | ||
|
@@ -73,7 +73,7 @@ def forward(self, | |
|
||
outputs = self.activations_and_grads(input_tensor) | ||
if targets is None: | ||
target_categories = np.argmax(outputs.cpu().data.numpy(), axis=-1) | ||
target_categories = torch.argmax(outputs.data, axis=-1) | ||
targets = [ClassifierOutputTarget( | ||
category) for category in target_categories] | ||
|
||
|
@@ -106,10 +106,10 @@ def compute_cam_per_layer( | |
self, | ||
input_tensor: torch.Tensor, | ||
targets: List[torch.nn.Module], | ||
eigen_smooth: bool) -> np.ndarray: | ||
activations_list = [a.cpu().data.numpy() | ||
eigen_smooth: bool) -> torch.Tensor: | ||
activations_list = [a.data | ||
for a in self.activations_and_grads.activations] | ||
grads_list = [g.cpu().data.numpy() | ||
grads_list = [g.data | ||
for g in self.activations_and_grads.gradients] | ||
target_size = self.get_target_width_height(input_tensor) | ||
|
||
|
@@ -130,24 +130,24 @@ def compute_cam_per_layer( | |
layer_activations, | ||
layer_grads, | ||
eigen_smooth) | ||
cam = np.maximum(cam, 0) | ||
cam = torch.maximum(cam, torch.tensor(0)) | ||
scaled = scale_cam_image(cam, target_size) | ||
cam_per_target_layer.append(scaled[:, None, :]) | ||
|
||
return cam_per_target_layer | ||
|
||
def aggregate_multi_layers( | ||
self, | ||
cam_per_target_layer: np.ndarray) -> np.ndarray: | ||
cam_per_target_layer = np.concatenate(cam_per_target_layer, axis=1) | ||
cam_per_target_layer = np.maximum(cam_per_target_layer, 0) | ||
result = np.mean(cam_per_target_layer, axis=1) | ||
cam_per_target_layer: torch.Tensor) -> torch.Tensor: | ||
cam_per_target_layer = torch.cat(cam_per_target_layer, axis=1) | ||
cam_per_target_layer = torch.maximum(cam_per_target_layer, torch.tensor(0)) | ||
result = torch.mean(cam_per_target_layer, axis=1) | ||
return scale_cam_image(result) | ||
|
||
def forward_augmentation_smoothing(self, | ||
input_tensor: torch.Tensor, | ||
targets: List[torch.nn.Module], | ||
eigen_smooth: bool = False) -> np.ndarray: | ||
eigen_smooth: bool = False) -> torch.Tensor: | ||
transforms = tta.Compose( | ||
[ | ||
tta.HorizontalFlip(), | ||
|
@@ -167,18 +167,18 @@ def forward_augmentation_smoothing(self, | |
cam = transform.deaugment_mask(cam) | ||
|
||
# Back to numpy float32, HxW | ||
cam = cam.numpy() | ||
# cam = cam.numpy() | ||
cam = cam[:, 0, :, :] | ||
cams.append(cam) | ||
cams.append(cam) # TODO: Handle this for torch tensors | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Basically pre-initialise a tensor. Ive found that to be drastically faster that lists when dealing with cuda / non-cpu devices |
||
|
||
cam = np.mean(np.float32(cams), axis=0) | ||
cam = torch.mean(cams.to(torch.float32), axis=0) | ||
return cam | ||
|
||
def __call__(self, | ||
input_tensor: torch.Tensor, | ||
targets: List[torch.nn.Module] = None, | ||
aug_smooth: bool = False, | ||
eigen_smooth: bool = False) -> np.ndarray: | ||
eigen_smooth: bool = False) -> torch.Tensor: | ||
|
||
# Smooth the CAM result with test time augmentation | ||
if aug_smooth is True: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,7 +4,7 @@ | |
import cv2 | ||
import numpy as np | ||
import torch | ||
from torchvision.transforms import Compose, Normalize, ToTensor | ||
from torchvision.transforms import Compose, Normalize, ToTensor, Resize | ||
from typing import List, Dict | ||
import math | ||
|
||
|
@@ -158,16 +158,30 @@ def show_factorization_on_image(img: np.ndarray, | |
|
||
|
||
def scale_cam_image(cam, target_size=None): | ||
result = [] | ||
for img in cam: | ||
img = img - np.min(img) | ||
img = img / (1e-7 + np.max(img)) | ||
if target_size is not None: | ||
img = cv2.resize(img, target_size) | ||
result.append(img) | ||
result = np.float32(result) | ||
# Disabled the target_size scaling for now | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
This method is used in a few places in the code base. However for the GradCam class which Im using for timing tests its primarily used for the transpose (AFAIK), which isnt necessary here (or at least Im getting sane results my side). Will need to be fixed if this change to a pure torch approach is used |
||
# It appears to swap the axes dimensions and needs further work for the | ||
# proof of concept | ||
|
||
return result | ||
# if target_size is not None: | ||
# result = torch.zeros([cam.shape[0], target_size[0], target_size[1]]) | ||
# else: | ||
# result = torch.zeros(cam.shape) | ||
|
||
result = torch.zeros(cam.shape) | ||
|
||
for i in range(cam.shape[0]): | ||
img = cam[i] | ||
img = img - torch.min(img) | ||
img = img / (1e-7 + torch.max(img)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The cv2 resize will need work to be done via a torch tensor. Will investigate once I get the concept working |
||
|
||
# if target_size is not None: | ||
# transform = Resize(target_size) | ||
# img = Resize(size = target_size)(img) | ||
|
||
|
||
result[i] = img | ||
|
||
return result.to(torch.float32) | ||
|
||
|
||
def scale_accross_batch_and_channels(tensor, target_size): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
[metadata] | ||
name = grad-cam | ||
version = 1.1.0 | ||
version = 1.4.7 | ||
author = Jacob Gildenblat | ||
author_email = [email protected] | ||
description = Many Class Activation Map methods implemented in Pytorch. Including Grad-CAM, Grad-CAM++, Score-CAM, Ablation-CAM and XGrad-CAM | ||
|
@@ -16,4 +16,4 @@ classifiers = | |
|
||
[options] | ||
packages = find: | ||
python_requires = >=3.6 | ||
python_requires = >=3.6 |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,7 +8,7 @@ | |
|
||
setuptools.setup( | ||
name='grad-cam', | ||
version='1.4.6', | ||
version='1.4.7', | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I just did this so on my compute node I know which version Im working with. In terms of what the real version bump would be Im open to a more major potential bump - as the final version of these changes would be substantial and potentially breaking |
||
author='Jacob Gildenblat', | ||
author_email='[email protected]', | ||
description='Many Class Activation Map methods implemented in Pytorch for classification, segmentation, object detection and more', | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For now Im trying to do minimal changes to create a proof of concept I can run.
Ive left this as a draft PR and its still definitely a WIP
I just find the PR user interface is great to observe changes
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Related to: #356