# simple_lstm.py
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

# Provides ARData and fixed_ar_coefficients (module not shown here)
from generate_data import *

# Based on:
# https://github.com/jessicayung/blog-code-snippets/blob/master/lstm-pytorch/lstm-baseline.py
#####################
# Set parameters
#####################
# Data params
noise_var = 0
num_datapoints = 500
test_size = 0.2
num_train = int((1-test_size) * num_datapoints)
# Network params
input_size = 20
# If `per_element` is True, the LSTM reads in one timestep at a time.
per_element = True
if per_element:
    lstm_input_size = 1
else:
    lstm_input_size = input_size
# Size of the LSTM hidden layers
h1 = 32
output_dim = 1
num_layers = 2
learning_rate = 1e-3
num_epochs = 200
dtype = torch.float
#####################
# Generate data
#####################
data = ARData(num_datapoints, num_prev=input_size, test_size=test_size, noise_var=noise_var, coeffs=fixed_ar_coefficients[input_size])
print(data.X_train.shape)
print(data.y_train.shape)
print(data.X_test.shape)
print(data.y_test.shape)
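# -------------------------------------------------------------------
# generate_data.py is not shown in this file. As a rough, hypothetical
# sketch of the idea behind ARData (an assumption, not the actual
# implementation): an AR(p) process draws each new point as a fixed
# linear combination of the previous p points, plus optional Gaussian
# noise. The helper below only illustrates that idea.
# -------------------------------------------------------------------
def _sketch_ar_series(coeffs, n, noise_var=0.0, seed=0):
    """Hypothetical AR(p) generator: n points from the given coefficients."""
    rng = np.random.default_rng(seed)
    p = len(coeffs)
    series = list(rng.standard_normal(p))  # random warm-up values
    for _ in range(n - p):
        nxt = sum(c * x for c, x in zip(coeffs, series[-p:]))
        if noise_var > 0:
            nxt += rng.normal(scale=noise_var ** 0.5)
        series.append(nxt)
    return np.array(series)
# e.g. _sketch_ar_series(fixed_ar_coefficients[input_size], num_datapoints)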
# make training and test sets in torch
X_train = torch.from_numpy(data.X_train).type(torch.Tensor)
X_test = torch.from_numpy(data.X_test).type(torch.Tensor)
y_train = torch.from_numpy(data.y_train).type(torch.Tensor).view(-1)
y_test = torch.from_numpy(data.y_test).type(torch.Tensor).view(-1)
# Reshape to (seq_len, batch, input_size) as nn.LSTM expects; transpose
# rather than view so each sample's lag window stays intact along the
# sequence dimension (see the check below).
X_train = X_train.t().unsqueeze(2).contiguous()
X_test = X_test.t().unsqueeze(2).contiguous()
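# -------------------------------------------------------------------
# Why transpose rather than view (illustrative): view() reinterprets
# the underlying memory, so each sample's lag window would be split
# across what the LSTM reads as timesteps, while t() keeps the window
# intact along dim 0.
# -------------------------------------------------------------------
_demo = torch.arange(6.).reshape(2, 3)                       # 2 samples, 3 lags
assert torch.equal(_demo.t()[:, 0], _demo[0])                # window preserved
assert not torch.equal(_demo.view(3, 2)[:, 0], _demo[0])     # view scrambles it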
class LSTM(nn.Module):
    def __init__(self, input_dim, hidden_dim, batch_size, output_dim=1,
                 num_layers=2):
        super(LSTM, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.batch_size = batch_size
        self.num_layers = num_layers
        # Define the LSTM layer
        self.lstm = nn.LSTM(self.input_dim, self.hidden_dim, self.num_layers)
        # Define the output layer
        self.linear = nn.Linear(self.hidden_dim, output_dim)

    def init_hidden(self):
        # This is what we'll initialise our hidden state as
        return (torch.zeros(self.num_layers, self.batch_size, self.hidden_dim),
                torch.zeros(self.num_layers, self.batch_size, self.hidden_dim))

    def forward(self, input):
        # Forward pass through LSTM layer.
        # Shape of lstm_out: (seq_len, batch_size, hidden_dim)
        # Shape of self.hidden: (a, b), where a and b both
        # have shape (num_layers, batch_size, hidden_dim).
        # Note: since no hidden state is passed to self.lstm, PyTorch
        # initialises it to zeros on every call.
        lstm_out, self.hidden = self.lstm(input.view(len(input), self.batch_size, -1))
        # Only take the output from the final timestep.
        # Pass the entirety of lstm_out to the next layer instead if this
        # were a seq2seq prediction.
        y_pred = self.linear(lstm_out[-1].view(self.batch_size, -1))
        return y_pred.view(-1)
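# -------------------------------------------------------------------
# Quick sanity check of the model's input/output shapes (illustrative
# only; the tiny hidden size and batch size here are arbitrary):
# -------------------------------------------------------------------
_check_model = LSTM(input_dim=1, hidden_dim=8, batch_size=4, output_dim=1, num_layers=2)
_check_out = _check_model(torch.zeros(20, 4, 1))  # (seq_len, batch, input_dim)
assert _check_out.shape == (4,)  # one scalar prediction per sequence in the batch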
# Reference hyperparameters (note: lstm_input_size = 20 corresponds to
# the per_element = False setting above):
'''
lstm_input_size = 20
h1 = 32
output_dim = 1
num_layers = 2
learning_rate = 1e-3
num_epochs = 200
num_train = 400
batch_size = 400
'''
model = LSTM(lstm_input_size, h1, batch_size=num_train, output_dim=output_dim, num_layers=num_layers)
print(model)

# reduction='sum' replaces the deprecated size_average=False argument;
# both sum squared errors over the batch instead of averaging.
loss_fn = torch.nn.MSELoss(reduction='sum')

optimiser = torch.optim.Adam(model.parameters(), lr=learning_rate)
#####################
# Train model
#####################
hist = np.zeros(num_epochs)

for t in range(num_epochs):
    # Clear stored gradient
    model.zero_grad()

    # Initialise hidden state.
    # Don't do this if you want your LSTM to be stateful.
    # (Note: forward() above does not actually consume model.hidden,
    # so this line has no effect on the computation here.)
    model.hidden = model.init_hidden()

    # Forward pass
    y_pred = model(X_train)

    loss = loss_fn(y_pred, y_train)
    if t % 100 == 0:
        print("Epoch ", t, "MSE: ", loss.item())
    hist[t] = loss.item()

    # Zero out gradients, else they accumulate between epochs
    # (redundant with model.zero_grad() above, but harmless).
    optimiser.zero_grad()

    # Backward pass
    loss.backward()

    # Update parameters
    optimiser.step()
#####################
# Plot preds and performance
#####################
plt.plot(y_pred.detach().numpy(), label="Preds")
plt.plot(y_train.detach().numpy(), label="Data")
plt.legend()
plt.show()
plt.plot(hist, label="Training loss")
plt.legend()
plt.show()
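# -------------------------------------------------------------------
# A possible follow-up, not part of the original script: evaluate on
# the held-out test set. The model's stored batch_size must match the
# test batch, so it is updated before the forward pass (this relies on
# batch_size only being read inside forward(), which holds for the
# class above).
# -------------------------------------------------------------------
with torch.no_grad():
    model.batch_size = X_test.shape[1]  # dim 1 is the batch dimension
    y_test_pred = model(X_test)
    print("Test MSE per point:", (loss_fn(y_test_pred, y_test) / len(y_test)).item())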