-
Notifications
You must be signed in to change notification settings - Fork 71
Expand file tree
/
Copy pathlayers.py
More file actions
439 lines (382 loc) · 17.7 KB
/
layers.py
File metadata and controls
439 lines (382 loc) · 17.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
import math
import torch
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module
from torch import nn
import torch.nn.functional as F
class GraphConvolutionBS(Module):
    """
    GCN Layer with BN, Self-loop and Res connection.

    The forward pass computes ``adj @ (input @ weight)``, optionally adds the
    self-loop term ``input @ self_weight`` and the bias, optionally applies
    batch normalization, then the activation. When ``res`` is set, ``input``
    is added to the activated output.
    """

    def __init__(self, in_features, out_features, activation=lambda x: x, withbn=True, withloop=True, bias=True,
                 res=False):
        """
        Initial function.
        :param in_features: the input feature dimension.
        :param out_features: the output feature dimension.
        :param activation: the activation function, default is the identity.
        :param withbn: using batch normalization.
        :param withloop: using self feature modeling.
        :param bias: enable bias.
        :param res: enable res connections. The input is added to the activated output in
                    forward(), so the two shapes must be compatible (typically
                    in_features == out_features).
        """
        super(GraphConvolutionBS, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.sigma = activation
        self.res = res

        # Parameter setting. Optional parameters are registered as None so the
        # attribute always exists and forward() can test it with ``is not None``.
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        if withloop:
            self.self_weight = Parameter(torch.FloatTensor(in_features, out_features))
        else:
            self.register_parameter("self_weight", None)

        if withbn:
            self.bn = torch.nn.BatchNorm1d(out_features)
        else:
            # The batch-norm slot holds a sub-module, not a parameter. Keeping it
            # as a plain None attribute (instead of a None *parameter*) lets a
            # module still be assigned to ``self.bn`` later on.
            self.bn = None

        if bias:
            self.bias = Parameter(torch.FloatTensor(out_features))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        """
        Re-initialize weight, self_weight and bias uniformly in [-stdv, stdv],
        with stdv = 1 / sqrt(number of weight columns).
        """
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.self_weight is not None:
            stdv = 1. / math.sqrt(self.self_weight.size(1))
            self.self_weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input, adj):
        """
        Apply the graph convolution.
        :param input: the node feature matrix, multiplied with the layer weights.
        :param adj: the adjacency matrix, multiplied with the projected features via torch.spmm.
        :return: the activated output, plus the input when res is enabled.
        """
        support = torch.mm(input, self.weight)
        output = torch.spmm(adj, support)

        # Self-loop
        if self.self_weight is not None:
            output = output + torch.mm(input, self.self_weight)

        if self.bias is not None:
            output = output + self.bias

        # BN
        if self.bn is not None:
            output = self.bn(output)

        # Res
        if self.res:
            return self.sigma(output) + input
        return self.sigma(output)

    def __repr__(self):
        return "{} ({} -> {})".format(self.__class__.__name__, self.in_features, self.out_features)
class GraphBaseBlock(Module):
    """
    The base block for Multi-layer GCN / ResGCN / Dense GCN.

    Stacks ``nbaselayer`` GraphConvolutionBS layers and merges their output
    with the block input (or, in dense mode, with the accumulated inputs of
    all layers) according to ``aggrmethod``.
    """

    def __init__(self, in_features, out_features, nbaselayer,
                 withbn=True, withloop=True, activation=F.relu, dropout=True,
                 aggrmethod="concat", dense=False):
        """
        The base block for constructing DeepGCN model.
        :param in_features: the input feature dimension.
        :param out_features: the hidden feature dimension.
        :param nbaselayer: the number of layers in the base block.
        :param withbn: using batch normalization in graph convolution.
        :param withloop: using self feature modeling in graph convolution.
        :param activation: the activation function, default is ReLu.
        :param dropout: the dropout ratio, handed to F.dropout as its probability.
                        NOTE(review): the default is the bool True; presumably callers always
                        pass a float ratio - confirm.
        :param aggrmethod: the aggregation function for baseblock, can be "concat", "add" and "nores".
                           For "resgcn", the default is "add", for others the default is "concat".
        :param dense: enable dense connection. Any falsy value (False, None, 0) disables it.
        :raises RuntimeError: if aggrmethod is "add" and in_features differs from out_features.
        :raises NotImplementedError: if aggrmethod is not one of "concat", "add", "nores".
        """
        super(GraphBaseBlock, self).__init__()
        self.in_features = in_features
        self.hiddendim = out_features
        self.nhiddenlayer = nbaselayer
        self.activation = activation
        self.aggrmethod = aggrmethod
        self.dense = dense
        self.dropout = dropout
        self.withbn = withbn
        self.withloop = withloop

        # Resolve the output width first so an invalid configuration fails
        # before any layer is allocated. ``dense`` is tested by truthiness,
        # exactly as forward() does, so None behaves like False.
        if self.aggrmethod == "concat":
            if dense:
                self.out_features = in_features + out_features * nbaselayer
            else:
                self.out_features = in_features + out_features
        elif self.aggrmethod == "add":
            if in_features != self.hiddendim:
                raise RuntimeError("The dimension of in_features and hiddendim should be matched in add model.")
            self.out_features = out_features
        elif self.aggrmethod == "nores":
            self.out_features = out_features
        else:
            raise NotImplementedError("The aggregation method only support 'concat','add' and 'nores'.")

        self.hiddenlayers = nn.ModuleList()
        self.__makehidden()

    def __makehidden(self):
        """Build the stack of graph convolutions; only the first one maps from in_features."""
        for i in range(self.nhiddenlayer):
            if i == 0:
                layer = GraphConvolutionBS(self.in_features, self.hiddendim, self.activation, self.withbn,
                                           self.withloop)
            else:
                layer = GraphConvolutionBS(self.hiddendim, self.hiddendim, self.activation, self.withbn,
                                           self.withloop)
            self.hiddenlayers.append(layer)

    def _doconcat(self, x, subx):
        """
        Merge two feature tensors according to aggrmethod.
        :param x: the running result, or None when nothing has been accumulated yet.
        :param subx: the tensor to merge in.
        :return: subx if x is None; otherwise the column-wise concatenation ("concat"),
                 the sum ("add") or x unchanged ("nores").
        """
        if x is None:
            return subx
        if self.aggrmethod == "concat":
            return torch.cat((x, subx), 1)
        elif self.aggrmethod == "add":
            return x + subx
        elif self.aggrmethod == "nores":
            return x

    def forward(self, input, adj):
        """
        Run the stacked graph convolutions and aggregate the result.
        :param input: the node feature matrix.
        :param adj: the adjacency matrix passed to every graph convolution.
        :return: the last layer output merged with the block input, or with the
                 accumulated layer inputs when dense is enabled.
        """
        x = input
        denseout = None
        # denseout accumulates the input of every layer (i.e. the results of all previous levels).
        for gc in self.hiddenlayers:
            denseout = self._doconcat(denseout, x)
            x = gc(x, adj)
            x = F.dropout(x, self.dropout, training=self.training)

        if not self.dense:
            return self._doconcat(x, input)
        return self._doconcat(x, denseout)

    def get_outdim(self):
        """Return the feature dimension recorded for the block output."""
        return self.out_features

    def __repr__(self):
        return "%s %s (%d - [%d:%d] > %d)" % (self.__class__.__name__,
                                              self.aggrmethod,
                                              self.in_features,
                                              self.hiddendim,
                                              self.nhiddenlayer,
                                              self.out_features)
class MultiLayerGCNBlock(Module):
    """
    Multi-Layer GCN with same hidden dimension.

    A wrapper around GraphBaseBlock that is always built with
    aggrmethod="nores" and dense=False.
    """

    def __init__(self, in_features, out_features, nbaselayer,
                 withbn=True, withloop=True, activation=F.relu, dropout=True,
                 aggrmethod=None, dense=None):
        """
        The multiple layer GCN block.
        :param in_features: the input feature dimension.
        :param out_features: the hidden feature dimension.
        :param nbaselayer: the number of layers in the base block.
        :param withbn: using batch normalization in graph convolution.
        :param withloop: using self feature modeling in graph convolution.
        :param activation: the activation function, default is ReLu.
        :param dropout: the dropout ratio.
        :param aggrmethod: not applied.
        :param dense: not applied.
        """
        super(MultiLayerGCNBlock, self).__init__()
        self.model = GraphBaseBlock(in_features=in_features,
                                    out_features=out_features,
                                    nbaselayer=nbaselayer,
                                    withbn=withbn,
                                    withloop=withloop,
                                    activation=activation,
                                    dropout=dropout,
                                    dense=False,
                                    aggrmethod="nores")

    def forward(self, input, adj):
        """Delegate to the wrapped block."""
        # Call the module itself rather than .forward so registered hooks run.
        return self.model(input, adj)

    def get_outdim(self):
        """Return the output feature dimension of the wrapped block."""
        return self.model.get_outdim()

    def __repr__(self):
        # The aggregation method lives on the wrapped block; this wrapper
        # never defines its own ``aggrmethod`` attribute.
        return "%s %s (%d - [%d:%d] > %d)" % (self.__class__.__name__,
                                              self.model.aggrmethod,
                                              self.model.in_features,
                                              self.model.hiddendim,
                                              self.model.nhiddenlayer,
                                              self.model.out_features)
class ResGCNBlock(Module):
    """
    The multiple layer GCN with residual connection block.

    A wrapper around GraphBaseBlock that is always built with
    aggrmethod="add" and dense=False.
    """

    def __init__(self, in_features, out_features, nbaselayer,
                 withbn=True, withloop=True, activation=F.relu, dropout=True,
                 aggrmethod=None, dense=None):
        """
        The multiple layer GCN with residual connection block.
        :param in_features: the input feature dimension.
        :param out_features: the hidden feature dimension.
        :param nbaselayer: the number of layers in the base block.
        :param withbn: using batch normalization in graph convolution.
        :param withloop: using self feature modeling in graph convolution.
        :param activation: the activation function, default is ReLu.
        :param dropout: the dropout ratio.
        :param aggrmethod: not applied.
        :param dense: not applied.
        """
        super(ResGCNBlock, self).__init__()
        self.model = GraphBaseBlock(in_features=in_features,
                                    out_features=out_features,
                                    nbaselayer=nbaselayer,
                                    withbn=withbn,
                                    withloop=withloop,
                                    activation=activation,
                                    dropout=dropout,
                                    dense=False,
                                    aggrmethod="add")

    def forward(self, input, adj):
        """Delegate to the wrapped block."""
        # Call the module itself rather than .forward so registered hooks run.
        return self.model(input, adj)

    def get_outdim(self):
        """Return the output feature dimension of the wrapped block."""
        return self.model.get_outdim()

    def __repr__(self):
        # The aggregation method lives on the wrapped block; this wrapper
        # never defines its own ``aggrmethod`` attribute.
        return "%s %s (%d - [%d:%d] > %d)" % (self.__class__.__name__,
                                              self.model.aggrmethod,
                                              self.model.in_features,
                                              self.model.hiddendim,
                                              self.model.nhiddenlayer,
                                              self.model.out_features)
class DenseGCNBlock(Module):
    """
    The multiple layer GCN with dense connection block.

    A wrapper around GraphBaseBlock that is always built with dense=True
    and forwards the requested aggrmethod.
    """

    def __init__(self, in_features, out_features, nbaselayer,
                 withbn=True, withloop=True, activation=F.relu, dropout=True,
                 aggrmethod="concat", dense=True):
        """
        The multiple layer GCN with dense connection block.
        :param in_features: the input feature dimension.
        :param out_features: the hidden feature dimension.
        :param nbaselayer: the number of layers in the base block.
        :param withbn: using batch normalization in graph convolution.
        :param withloop: using self feature modeling in graph convolution.
        :param activation: the activation function, default is ReLu.
        :param dropout: the dropout ratio.
        :param aggrmethod: the aggregation function for the output. For denseblock, default is "concat".
        :param dense: default is True, cannot be changed (the argument is ignored).
        """
        super(DenseGCNBlock, self).__init__()
        self.model = GraphBaseBlock(in_features=in_features,
                                    out_features=out_features,
                                    nbaselayer=nbaselayer,
                                    withbn=withbn,
                                    withloop=withloop,
                                    activation=activation,
                                    dropout=dropout,
                                    dense=True,
                                    aggrmethod=aggrmethod)

    def forward(self, input, adj):
        """Delegate to the wrapped block."""
        # Call the module itself rather than .forward so registered hooks run.
        return self.model(input, adj)

    def get_outdim(self):
        """Return the output feature dimension of the wrapped block."""
        return self.model.get_outdim()

    def __repr__(self):
        # The aggregation method lives on the wrapped block; this wrapper
        # never defines its own ``aggrmethod`` attribute.
        return "%s %s (%d - [%d:%d] > %d)" % (self.__class__.__name__,
                                              self.model.aggrmethod,
                                              self.model.in_features,
                                              self.model.hiddendim,
                                              self.model.nhiddenlayer,
                                              self.model.out_features)
class InecptionGCNBlock(Module):
    """
    The multiple layer GCN with inception connection block.

    Holds ``nbaselayer`` parallel branches, where the j-th branch (0-based)
    chains j + 1 graph convolutions. Every branch starts from the block input
    and its result is merged into the running output with ``aggrmethod``.
    """

    def __init__(self, in_features, out_features, nbaselayer,
                 withbn=True, withloop=True, activation=F.relu, dropout=True,
                 aggrmethod="concat", dense=False):
        """
        The multiple layer GCN with inception connection block.
        :param in_features: the input feature dimension.
        :param out_features: the hidden feature dimension.
        :param nbaselayer: the number of branches in the block.
        :param withbn: using batch normalization in graph convolution.
        :param withloop: using self feature modeling in graph convolution.
        :param activation: the activation function, default is ReLu.
        :param dropout: the dropout ratio, handed to F.dropout.
        :param aggrmethod: the aggregation function for the block, can be "concat" and "add".
        :param dense: not applied. The default is False, cannot be changed.
        :raises RuntimeError: if aggrmethod is "add" and in_features differs from out_features.
        :raises NotImplementedError: if aggrmethod is neither "concat" nor "add".
        """
        super(InecptionGCNBlock, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.hiddendim = out_features
        self.nbaselayer = nbaselayer
        self.activation = activation
        self.aggrmethod = aggrmethod
        self.dropout = dropout
        self.withbn = withbn
        self.withloop = withloop
        self.midlayers = nn.ModuleList()
        self.__makehidden()

        # Branches are built above; now settle the advertised output width.
        if self.aggrmethod not in ("concat", "add"):
            raise NotImplementedError("The aggregation method only support 'concat', 'add'.")
        if self.aggrmethod == "add":
            if in_features != self.hiddendim:
                raise RuntimeError("The dimension of in_features and hiddendim should be matched in 'add' model.")
            self.out_features = out_features
        else:
            # "concat": the input plus one hidden-width slice per branch.
            self.out_features = in_features + out_features * nbaselayer

    def __makehidden(self):
        """Create the branches; the branch of depth d chains d graph convolutions."""
        for depth in range(1, self.nbaselayer + 1):
            branch = nn.ModuleList()
            width_in = self.in_features
            for _ in range(depth):
                branch.append(GraphConvolutionBS(width_in, self.hiddendim, self.activation, self.withbn,
                                                 self.withloop))
                # Only the first convolution of a branch maps from in_features.
                width_in = self.hiddendim
            self.midlayers.append(branch)

    def forward(self, input, adj):
        """
        Run every branch on the block input and merge the branch outputs.
        :param input: the node feature matrix.
        :param adj: the adjacency matrix passed to every graph convolution.
        :return: the input merged with the output of each branch.
        """
        merged = input
        for branch in self.midlayers:
            hidden = input
            for conv in branch:
                hidden = F.dropout(conv(hidden, adj), self.dropout, training=self.training)
            merged = self._doconcat(merged, hidden)
        return merged

    def get_outdim(self):
        """Return the feature dimension recorded for the block output."""
        return self.out_features

    def _doconcat(self, x, subx):
        """Merge subx into x: column-wise concatenation for "concat", sum for "add"."""
        method = self.aggrmethod
        if method == "concat":
            return torch.cat((x, subx), 1)
        if method == "add":
            return x + subx
        return None

    def __repr__(self):
        fields = (self.__class__.__name__, self.aggrmethod, self.in_features,
                  self.hiddendim, self.nbaselayer, self.out_features)
        return "%s %s (%d - [%d:%d] > %d)" % fields
class Dense(Module):
    """
    Simple Dense layer, Do not consider adj.

    Computes ``sigma(BN(input @ weight + bias))``. The ``adj`` argument of
    forward() is accepted but never read.
    """

    def __init__(self, in_features, out_features, activation=lambda x: x, bias=True, res=False):
        """
        Initial function.
        :param in_features: the input feature dimension.
        :param out_features: the output feature dimension.
        :param activation: the activation function, default is the identity.
        :param bias: enable bias.
        :param res: stored on the layer; forward() does not read it.
        """
        super(Dense, self).__init__()
        self.in_features, self.out_features = in_features, out_features
        self.sigma = activation
        self.res = res
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        self.bn = nn.BatchNorm1d(out_features)
        if not bias:
            self.register_parameter('bias', None)
        else:
            self.bias = Parameter(torch.FloatTensor(out_features))
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize weight and bias uniformly in [-b, b], b = 1 / sqrt(number of weight columns)."""
        bound = 1. / math.sqrt(self.weight.size(1))
        for param in (self.weight, self.bias):
            if param is not None:
                param.data.uniform_(-bound, bound)

    def forward(self, input, adj):
        """
        Apply the dense transformation.
        :param input: the feature matrix.
        :param adj: not used.
        :return: the activated, batch-normalized projection of input.
        """
        projected = torch.mm(input, self.weight)
        if self.bias is not None:
            projected = projected + self.bias
        return self.sigma(self.bn(projected))

    def __repr__(self):
        return "{} ({} -> {})".format(self.__class__.__name__, self.in_features, self.out_features)