-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathvisualization.py
49 lines (44 loc) · 1.68 KB
/
visualization.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import pygame
import torch
import numpy as np
from dataclasses import dataclass
@dataclass
class LayerVisualizer:
    """Show a layer's weights as a live grid of tiles in a pygame window.

    Each neuron's weight vector is drawn as a small ``Sy x Sx`` tile and the
    tiles are arranged in a ``Ky x Kx`` grid. Pixel intensities are scaled
    symmetrically around zero: ``-max|w|`` maps to 0, ``0`` to mid-grey and
    ``+max|w|`` to (almost) 255.

    Attributes:
        weights: the layer's weight tensor. Accepted shapes are
            ``(num_neurons, num_inputs)`` or
            ``(num_neurons, kernel_height, kernel_width, depth)``. After
            construction this attribute holds a detached 5-D view
            ``(Ky, Kx, Sy, Sx, depth)`` sharing storage with the tensor that
            was passed in, so ``update()`` reflects in-place changes to it.
            ``update()`` only works when ``depth`` is 1 or 3.
        window_width: minimum side length in pixels of the square window; it
            is enlarged to at least ``Kx * Sx`` (one pixel per image column).
    """
    weights: torch.Tensor
    window_width: int = 250

    @staticmethod
    def _grid_shape(count):
        """Split *count* items into ``(rows, cols)`` with ``rows = floor(sqrt(count))``.

        ``rows * cols == count`` only when such a factorisation exists; the
        caller must check. ``count == 0`` yields ``(0, 0)`` instead of
        dividing by zero.
        """
        rows = int(np.sqrt(count))
        return rows, (count // rows if rows else 0)

    def __post_init__(self):
        """Validate the weight shape, build the grid view and open the window.

        Raises:
            ValueError: if ``weights`` is neither 2-D nor 4-D.
            AssertionError: if the inputs or the neurons are empty or cannot
                be laid out as a ``floor(sqrt(n)) x (n // floor(sqrt(n)))`` grid.
        """
        # Keep a view of weights which is more suitable for visualization.
        weights = self.weights.detach()
        if weights.dim() == 2:
            num_neurons, num_inputs = weights.shape
            depth = 1
        elif weights.dim() == 4:
            # NOTE(review): channels-last layout is assumed here, whereas
            # torch.nn.Conv2d stores (out, in, kH, kW) -- confirm callers
            # permute conv weights before passing them in.
            num_neurons, kernel_height, kernel_width, depth = weights.shape
            # The tile shape is derived from kH*kW alone, not from (kH, kW).
            num_inputs = kernel_height * kernel_width
        else:
            raise ValueError(
                f"Expected 2-D or 4-D weights, got shape {tuple(weights.shape)}")

        self._Sy, self._Sx = self._grid_shape(num_inputs)
        self._Ky, self._Kx = self._grid_shape(num_neurons)
        # Both grids must tile exactly; otherwise view() below would fail
        # with an opaque size-mismatch error.
        assert (num_inputs > 0 and num_neurons > 0
                and self._Sy * self._Sx == num_inputs
                and self._Ky * self._Kx == num_neurons), \
            f"Inputs and outputs should be square, got {num_inputs=}, {num_neurons=}"
        self.weights = weights.view(self._Ky, self._Kx, self._Sy, self._Sx, depth)

        # Initialize visualization window.
        pygame.init()
        self.window_width = max(self.window_width, self._Kx * self._Sx)
        self._screen = pygame.display.set_mode((self.window_width, self.window_width))
        pygame.display.set_caption("Weights")

    def _weights_image(self):
        """Return the weights as a ``(Ky*Sy, Kx*Sx, 3)`` uint8 image array.

        Row ``ky*Sy + sy`` and column ``kx*Sx + sx`` hold input
        ``sy*Sx + sx`` of neuron ``ky*Kx + kx``. Single-channel weights are
        repeated into three identical channels. Values are mapped linearly
        from ``[-max|w|, +max|w|]`` to ``[0, 255]`` with truncation.
        """
        tiles = self.weights.permute((0, 2, 1, 3, 4))
        if tiles.shape[-1] == 1:
            tiles = tiles.tile((1, 1, 1, 1, 3))
        rgb = tiles.reshape(self._Ky * self._Sy, self._Kx * self._Sx, 3).cpu().numpy()
        vmax = np.amax(np.abs(rgb))
        # Symmetric scaling keeps zero at mid-grey; the epsilon avoids a
        # division by zero when every weight is zero.
        scaled = (rgb + vmax) / (2 * vmax + 1e-10)
        return (255 * scaled).astype(np.uint8)

    def update(self):
        """Redraw the window from the current values of ``self.weights``."""
        # Process pending OS events so the window stays responsive even when
        # the caller never runs its own pygame event loop.
        pygame.event.pump()
        surface = pygame.surfarray.make_surface(self._weights_image())
        surface = pygame.transform.scale(surface, (self.window_width, self.window_width))
        # surfarray indexes pixels as [x, y], i.e. transposed relative to the
        # array's [row, col]; a -90 degree rotation followed by a horizontal
        # mirror is exactly a transpose, restoring the intended layout.
        surface = pygame.transform.rotate(surface, -90)
        surface = pygame.transform.flip(surface, True, False)
        self._screen.blit(surface, (0, 0))
        pygame.display.flip()