-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmade.py
52 lines (40 loc) · 1.54 KB
/
made.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from linear_masked import LinearMasked
class MADE(nn.Module):
    """MADE model for binary image dataset.

    Three ``LinearMasked`` layers of width ``input_dim`` separated by ReLUs.
    The masks installed by :meth:`apply_masks` use the natural ordering
    ``0, 1, ..., input_dim - 1``. Output ``d`` can therefore only depend on
    inputs ``0 .. d-1``, and output 0 depends on no input at all.

    Args:
        input_dim: Number of input units. It is also the hidden-layer width
            and the number of outputs.
        use_cuda: If True, ``self.device`` is ``torch.device('cuda')``;
            otherwise it is None. The network itself is not moved to this
            device here.
    """

    def __init__(self, input_dim, use_cuda=True):
        super().__init__()
        self.input_dim = input_dim
        self.device = torch.device('cuda') if use_cuda else None
        self.net = nn.Sequential(
            LinearMasked(input_dim, input_dim), nn.ReLU(),
            LinearMasked(input_dim, input_dim), nn.ReLU(),
            LinearMasked(input_dim, input_dim),
        )
        self.apply_masks()

    def forward(self, x):
        """Run ``x`` through the masked network and return its raw output.

        NOTE(review): no activation is applied to the last layer. The output
        is presumably logits for a Bernoulli likelihood; confirm against the
        training loss.
        """
        return self.net(x)

    def set_mask(self, mask):
        """Store ``mask`` (a NumPy array) as a uint8 tensor in ``self.mask``.

        The tensor is moved to ``self.device``. NOTE(review): ``forward`` never
        reads ``self.mask``. The per-layer masks are installed by
        ``apply_masks`` through ``LinearMasked.set_mask`` instead.
        """
        self.mask = torch.from_numpy(mask.astype(np.uint8)).to(self.device)

    @staticmethod
    def _build_masks(input_dim):
        """Return ``[m1, m2, m3]``, the boolean masks for the three layers.

        Each mask is ``(input_dim, input_dim)``. Entry ``[j, i]`` is True when
        unit ``j`` of that layer may receive input from unit ``i`` of the
        previous layer. That row-major orientation matches ``nn.Linear``'s
        ``(out_features, in_features)`` weight. It assumes ``LinearMasked``
        applies the mask elementwise to such a weight; confirm in
        ``linear_masked``.
        """
        # Degree assigned to every unit of each layer (natural ordering).
        order_in = np.arange(input_dim)
        order_h1 = np.arange(input_dim)
        order_h2 = np.arange(input_dim)
        # Output d models input d, so outputs share the input ordering.
        order_out = order_in

        # A hidden unit may see previous-layer units of degree <= its own.
        m1 = order_in[None, :] <= order_h1[:, None]
        m2 = order_h1[None, :] <= order_h2[:, None]
        # Output d may only see last-hidden units of degree strictly < d.
        # This strict inequality is what makes the model autoregressive.
        # It must compare the *last hidden* layer against the *output*
        # ordering. The original code compared order2 with order3, which only
        # coincided because every ordering was arange.
        m3 = order_h2[None, :] < order_out[:, None]
        return [m1, m2, m3]

    def apply_masks(self):
        """Install the autoregressive masks into every ``LinearMasked`` layer.

        Masks are assigned in module order: first layer gets ``m1``, and so on.

        Raises:
            RuntimeError: If ``self.net`` does not contain exactly one
                ``LinearMasked`` layer per mask. ``zip`` would otherwise
                silently leave a layer unmasked.
        """
        masks = self._build_masks(self.input_dim)
        layers = [mod for mod in self.net.modules()
                  if isinstance(mod, LinearMasked)]
        if len(layers) != len(masks):
            raise RuntimeError(
                f"expected {len(masks)} LinearMasked layers, "
                f"found {len(layers)}"
            )
        for layer, mask in zip(layers, masks):
            layer.set_mask(mask)