-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathgcn_utils.py
70 lines (51 loc) · 2.09 KB
/
gcn_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class InstanceGCN(nn.Module):
    """Message-passing GCN over a joint trigger/entity instance graph.

    The graph's first ``t_n`` nodes are triggers and the remaining
    ``e_n`` nodes are entities.  Each of the ``num_layers`` rounds sends
    messages along the four directed relation types (T->T, T->E, E->T,
    E->E), then blends the candidate state with the previous state via a
    sigmoid "forget" gate (a GRU-like highway update).

    NOTE(review): every relation Linear is sized
    ``trigger_dim x trigger_dim`` yet consumes entity states too, so the
    module implicitly assumes ``entity_dim == trigger_dim`` — confirm
    with callers.
    """

    def __init__(self, trigger_dim, entity_dim, num_layers=2):
        super(InstanceGCN, self).__init__()
        self.num_layers = num_layers
        self.trigger_dim = trigger_dim
        self.entity_dim = entity_dim
        # All unary node states share the trigger dimensionality.
        self.unary_dim = self.trigger_dim
        # One Linear per layer for each directed relation type.
        # (Construction order is kept stable: it fixes the RNG stream
        # used for weight initialization.)
        self.T_T = nn.ModuleList()
        self.T_E = nn.ModuleList()
        self.E_T = nn.ModuleList()
        self.E_E = nn.ModuleList()
        for _ in range(self.num_layers):
            self.T_T.append(nn.Linear(self.unary_dim, self.unary_dim))
            self.T_E.append(nn.Linear(self.unary_dim, self.unary_dim))
            self.E_T.append(nn.Linear(self.unary_dim, self.unary_dim))
            self.E_E.append(nn.Linear(self.unary_dim, self.unary_dim))
        # Forget gates: sigmoid over the concatenation [new ; old].
        self.f_t = nn.Sequential(
            nn.Linear(self.trigger_dim * 2, self.trigger_dim),
            nn.Sigmoid()
        )
        self.f_e = nn.Sequential(
            nn.Linear(self.entity_dim * 2, self.entity_dim),
            nn.Sigmoid()
        )

    def forward(self, A, T, E, t_n, e_n):
        """Propagate node states over adjacency ``A``.

        Args:
            A: [bs, t_n + e_n, t_n + e_n] adjacency, trigger rows first.
            T: [bs, t_n, d] trigger representations.
            E: [bs, e_n, d] entity representations.
            t_n: number of trigger nodes.
            e_n: number of entity nodes.

        Returns:
            Tuple ``(T, E)`` of updated states with unchanged shapes.
        """
        split = t_n
        total = t_n + e_n
        # Rows of A whose messages flow INTO triggers / INTO entities.
        rows_to_trig = A[:, :split, :]      # [bs, t_n, total]
        rows_to_ent = A[:, split:total, :]  # [bs, e_n, total]
        for layer in range(self.num_layers):
            # Candidate trigger states: aggregate from triggers + entities.
            trig_msg = self.T_T[layer](torch.bmm(rows_to_trig[:, :, :split], T))
            trig_msg = trig_msg + self.T_E[layer](torch.bmm(rows_to_trig[:, :, split:total], E))
            cand_T = F.relu(trig_msg)
            # Candidate entity states, computed from the PRE-update T and E.
            ent_msg = self.E_T[layer](torch.bmm(rows_to_ent[:, :, :split], T))
            ent_msg = ent_msg + self.E_E[layer](torch.bmm(rows_to_ent[:, :, split:total], E))
            cand_E = F.relu(ent_msg)
            # Gated residual blend of candidate and previous states.
            gate_T = self.f_t(torch.cat([cand_T, T], dim=2))  # [bs, t_n, d]
            T = gate_T * cand_T + (1 - gate_T) * T
            gate_E = self.f_e(torch.cat([cand_E, E], dim=2))  # [bs, e_n, d]
            E = gate_E * cand_E + (1 - gate_E) * E
        return T, E