-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathloss.py
More file actions
83 lines (63 loc) · 2.57 KB
/
loss.py
File metadata and controls
83 lines (63 loc) · 2.57 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import torch
import torch.nn as nn
import torch.nn.functional as F
def imq_kernel(X: torch.Tensor,
               Y: torch.Tensor,
               h_dim: int,
               *,
               scales=(.1, .2, .5, 1., 2., 5., 10.)) -> torch.Tensor:
    """Multi-scale inverse-multiquadratic (IMQ) kernel MMD statistic between X and Y.

    For every ``scale`` in ``scales`` the kernel
    ``k(a, b) = C / (C + ||a - b||^2)`` with ``C = 2 * h_dim * scale`` is
    evaluated on all sample pairs, and the per-scale statistics are summed.

    Normalisation: the within-sample term (diagonal excluded) is divided by
    ``batch_size - 1`` and the cross term by ``batch_size``, so the returned
    value is ``batch_size`` times the usual unbiased MMD^2 estimate. This
    scaling is kept as-is so loss weights tuned against it remain valid.

    Args:
        X: 2-D tensor ``(batch_size, features)``.
        Y: 2-D tensor with the same batch size, feature count and device as ``X``.
        h_dim: only used to set the kernel constant ``C``; presumably the
            latent dimensionality — confirm against callers.
        scales: iterable of positive multipliers for ``C``, one kernel per
            entry. Defaults to the original fixed set of seven scales.

    Returns:
        0-d tensor on ``X``'s device holding the summed statistic.

    Raises:
        ValueError: if the batch has fewer than 2 samples (the within-sample
            term would divide by zero) or X and Y have different batch sizes.
    """
    batch_size = X.size(0)
    if batch_size < 2:
        raise ValueError(
            f"imq_kernel needs at least 2 samples per batch, got {batch_size}")
    if Y.size(0) != batch_size:
        raise ValueError(
            "X and Y must have the same batch size, "
            f"got {batch_size} and {Y.size(0)}")

    norms_x = X.pow(2).sum(1, keepdim=True)  # (batch_size, 1)
    norms_y = Y.pow(2).sum(1, keepdim=True)  # (batch_size, 1)
    # ||a - b||^2 = ||a||^2 + ||b||^2 - 2<a, b>; clamp the tiny negative values
    # that floating-point cancellation can produce for near-identical rows.
    dists_x = (norms_x + norms_x.t() - 2 * torch.mm(X, X.t())).clamp(min=0)
    dists_y = (norms_y + norms_y.t() - 2 * torch.mm(Y, Y.t())).clamp(min=0)
    dists_c = (norms_x + norms_y.t() - 2 * torch.mm(X, Y.t())).clamp(min=0)

    # Mask that zeroes self-pairs. Built once, on X's own device and dtype,
    # so it works for CPU tensors on CUDA machines and for any GPU index.
    off_diag = 1 - torch.eye(batch_size, device=X.device, dtype=X.dtype)

    stats = X.new_zeros(())
    for scale in scales:
        C = 2 * h_dim * 1.0 * scale
        res1 = C / (C + dists_x) + C / (C + dists_y)
        res1 = (off_diag * res1).sum() / (batch_size - 1)
        res2 = (C / (C + dists_c)).sum() * 2. / batch_size
        stats = stats + (res1 - res2)
    return stats
def rbf_kernel(X: torch.Tensor,
               Y: torch.Tensor,
               h_dim: int,
               *,
               scales=(.1, .2, .5, 1., 2., 5., 10.)) -> torch.Tensor:
    """Multi-scale Gaussian (RBF) kernel MMD statistic between X and Y.

    For every ``scale`` in ``scales`` the kernel
    ``k(a, b) = exp(-C * ||a - b||^2)`` with ``C = 2 * h_dim / scale`` is
    evaluated on all sample pairs, and the per-scale statistics are summed.

    Normalisation: the within-sample term (diagonal excluded) is divided by
    ``batch_size - 1`` and the cross term by ``batch_size``, so the returned
    value is ``batch_size`` times the usual unbiased MMD^2 estimate. This
    scaling is kept as-is so loss weights tuned against it remain valid.

    NOTE(review): ``C`` multiplies the squared distance, so a larger
    ``h_dim`` (or smaller ``scale``) gives a *narrower* kernel. In
    ``imq_kernel``, C instead grows the kernel width. Confirm this asymmetry
    is intended.

    Args:
        X: 2-D tensor ``(batch_size, features)``.
        Y: 2-D tensor with the same batch size, feature count and device as ``X``.
        h_dim: only used to set the kernel constant ``C``; presumably the
            latent dimensionality — confirm against callers.
        scales: iterable of positive divisors for ``C``, one kernel per
            entry. Defaults to the original fixed set of seven scales.

    Returns:
        0-d tensor on ``X``'s device holding the summed statistic.

    Raises:
        ValueError: if the batch has fewer than 2 samples (the within-sample
            term would divide by zero) or X and Y have different batch sizes.
    """
    batch_size = X.size(0)
    if batch_size < 2:
        raise ValueError(
            f"rbf_kernel needs at least 2 samples per batch, got {batch_size}")
    if Y.size(0) != batch_size:
        raise ValueError(
            "X and Y must have the same batch size, "
            f"got {batch_size} and {Y.size(0)}")

    norms_x = X.pow(2).sum(1, keepdim=True)  # (batch_size, 1)
    norms_y = Y.pow(2).sum(1, keepdim=True)  # (batch_size, 1)
    # ||a - b||^2 = ||a||^2 + ||b||^2 - 2<a, b>; clamp the tiny negative values
    # that floating-point cancellation can produce (exp would exceed 1).
    dists_x = (norms_x + norms_x.t() - 2 * torch.mm(X, X.t())).clamp(min=0)
    dists_y = (norms_y + norms_y.t() - 2 * torch.mm(Y, Y.t())).clamp(min=0)
    dists_c = (norms_x + norms_y.t() - 2 * torch.mm(X, Y.t())).clamp(min=0)

    # Mask that zeroes self-pairs. Built once, on X's own device and dtype,
    # so it works for CPU tensors on CUDA machines and for any GPU index.
    off_diag = 1 - torch.eye(batch_size, device=X.device, dtype=X.dtype)

    stats = X.new_zeros(())
    for scale in scales:
        C = 2 * h_dim * 1.0 / scale
        res1 = torch.exp(-C * dists_x) + torch.exp(-C * dists_y)
        res1 = (off_diag * res1).sum() / (batch_size - 1)
        res2 = torch.exp(-C * dists_c).sum() * 2. / batch_size
        stats = stats + (res1 - res2)
    return stats
def jenson_shannon_divergence(net_1_logits: torch.Tensor,
                              net_2_logits: torch.Tensor,
                              *,
                              dim: int = 0) -> torch.Tensor:
    """Jensen-Shannon divergence between the softmax distributions of two logit tensors.

    With ``P1 = softmax(net_1_logits)``, ``P2 = softmax(net_2_logits)`` (both
    along ``dim``) and ``M = (P1 + P2) / 2``, this returns
    ``0.5 * (KL(P1 || M) + KL(P2 || M))`` in nats. Each KL is summed over
    ``dim`` and the result is averaged over every remaining position, i.e.
    over all distributions in the tensor. The value lies in ``[0, ln 2]``.

    ``F.kl_div(input, target)`` computes ``KL(target || exp(input))``, so
    passing the nets' log-probabilities as ``input`` (as an earlier version
    did) yields ``KL(M || P)``, which is not the JS divergence. The KL terms
    are therefore written out explicitly.

    Args:
        net_1_logits: unnormalised logits of the first distribution(s).
        net_2_logits: logits of the second distribution(s); must broadcast
            against ``net_1_logits``.
        dim: axis holding the class/outcome dimension. Defaults to 0, matching
            the original behaviour. NOTE(review): for ``(batch, classes)``
            inputs this should presumably be ``-1`` — confirm with callers.

    Returns:
        0-d tensor with the mean JS divergence.
    """
    import math

    log_p1 = F.log_softmax(net_1_logits, dim=dim)
    log_p2 = F.log_softmax(net_2_logits, dim=dim)
    # log M computed in log space: stays finite even where both probabilities
    # underflow to 0, which would otherwise give 0 * -inf = nan below.
    log_m = torch.logaddexp(log_p1, log_p2) - math.log(2.0)
    # KL(P || M) = sum_i p_i * (log p_i - log m_i); entries with p_i == 0
    # contribute exactly 0 because log p_i is finite.
    kl_1 = (log_p1.exp() * (log_p1 - log_m)).sum(dim=dim)
    kl_2 = (log_p2.exp() * (log_p2 - log_m)).sum(dim=dim)
    return (0.5 * (kl_1 + kl_2)).mean()