# armaplus.py
from typing import Callable, Optional

import torch
from torch.nn import Parameter, Dropout, ReLU
from torch_geometric.nn.inits import glorot, zeros

from TDConv import TDConv

class ARMAPlusConv(torch.nn.Module):
    """Single-layer ARMA graph convolution preceded by a TDConv diffusion step."""

    def __init__(self, in_channels: int, out_channels: int, init_t: float,
                 num_stacks: int = 1, num_layers: int = 1,
                 shared_weights: bool = False,
                 act: Optional[Callable] = ReLU(), dropout: float = 0.,
                 bias: bool = True):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_stacks = num_stacks
        self.num_layers = num_layers
        self.act = act
        self.shared_weights = shared_weights  # kept for API parity; moot while num_layers == 1

        # Only a single ARMA layer per stack is supported for now.
        assert num_layers == 1

        # Learnable diffusion-based propagation; init_t sets its initial diffusion time.
        self.diffusion = TDConv(in_channels, init_t)

        K, T, F_in, F_out = num_stacks, num_layers, in_channels, out_channels

        self.init_weight = Parameter(torch.empty(K, F_in, F_out))
        self.root_weight = Parameter(torch.empty(T, K, F_in, F_out))

        # Register the bias only when requested, so the `bias` flag actually
        # takes effect (it was previously ignored).
        if bias:
            self.bias = Parameter(torch.empty(T, K, 1, F_out))
        else:
            self.register_parameter('bias', None)

        self.dropout = Dropout(p=dropout)
        self.reset_parameters()
    def forward(self, x, edge_index, edge_weight=None):
        # Self-loop insertion and symmetric degree normalization (the standard
        # MessagePassing recipe) are intentionally absent here: propagation is
        # delegated entirely to the TDConv diffusion step.
        # NOTE: edge_weight is currently accepted but unused.
        x = self.diffusion(x, edge_index)

        x = x.unsqueeze(-3)            # [N, F_in] -> [1, N, F_in], broadcast over K stacks
        out = x @ self.init_weight     # [K, N, F_out]

        # Root (skip) connection with dropout on the input features.
        root = self.dropout(x)
        out = out + root @ self.root_weight[0]

        if self.bias is not None:
            out = out + self.bias[0]

        if self.act is not None:
            out = self.act(out)

        # Average the K parallel stacks.
        return out.mean(dim=-3)
    def reset_parameters(self):
        glorot(self.init_weight)
        glorot(self.root_weight)
        zeros(self.bias)  # torch_geometric's zeros() is a no-op on None
        self.diffusion.reset_parameters()
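
if __name__ == "__main__":
    # Minimal smoke test -- a sketch, not part of the original module. It
    # assumes TDConv(in_channels, init_t) is importable as above, preserves
    # the feature dimension, and that its forward accepts (x, edge_index).
    # The node count, feature sizes, and edges below are arbitrary.
    num_nodes, in_channels, out_channels = 4, 8, 16
    x = torch.randn(num_nodes, in_channels)
    # A small directed ring graph in COO format: shape [2, num_edges].
    edge_index = torch.tensor([[0, 1, 2, 3],
                               [1, 2, 3, 0]])

    conv = ARMAPlusConv(in_channels, out_channels, init_t=1.0, num_stacks=2)
    out = conv(x, edge_index)
    print(out.shape)  # expected: torch.Size([4, 16])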