-
Notifications
You must be signed in to change notification settings - Fork 18
Expand file tree
/
Copy pathloss.py
More file actions
130 lines (102 loc) · 4.95 KB
/
loss.py
File metadata and controls
130 lines (102 loc) · 4.95 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
# --------------------------------------------------------
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License
# --------------------------------------------------------
import torch
from torch import Tensor
import torch.nn as nn
from utils.torch_funcs import grl_hook, entropy_func
class WeightBCE(nn.Module):
    """Per-sample weighted binary cross-entropy on probabilities (not logits)."""

    def __init__(self, epsilon: float = 1e-8) -> None:
        super(WeightBCE, self).__init__()
        # Small constant inside the logs to avoid log(0).
        self.epsilon = epsilon

    def forward(self, x: Tensor, label: Tensor, weight: Tensor) -> Tensor:
        """
        :param x: [N, 1] predicted probabilities
        :param label: [N, 1] binary targets (0/1)
        :param weight: [N, 1] per-sample weights
        :return: scalar — weighted BCE summed over samples, divided by 2
        """
        target = label.float()
        pos_term = target * torch.log(x + self.epsilon)
        neg_term = (1 - target) * torch.log(1 - x + self.epsilon)
        weighted = -(pos_term + neg_term) * weight.float()
        return torch.sum(weighted) / 2.
def d_align_uda(softmax_output: Tensor, features: Tensor = None, d_net=None,
                coeff: float = None, ent: bool = False):
    """Domain-adversarial alignment loss for UDA (one source, one target).

    Assumes the batch is the concatenation [source; target] with equal halves,
    so the first N rows are source (label 1) and the last N target (label 0).

    :param softmax_output: [2N, C] class probabilities for source+target samples
    :param features: optional [2N, D] features; if given, fed to d_net instead
        of softmax_output
    :param d_net: domain discriminator, called as d_net(input, coeff=coeff)
    :param coeff: gradient-reversal coefficient forwarded to d_net and to the
        entropy hook when ent=True
    :param ent: if True, weight samples by exp(-entropy) (entropy conditioning)
    :return: scalar weighted-BCE domain loss
    """
    loss_func = WeightBCE()
    d_input = softmax_output if features is None else features
    d_output = torch.sigmoid(d_net(d_input, coeff=coeff))
    batch_size = softmax_output.size(0) // 2
    # Fix: create labels on the input's device instead of hard-coding .cuda(),
    # so the loss also runs on CPU-only builds; GPU behavior is unchanged.
    labels = torch.tensor([[1]] * batch_size + [[0]] * batch_size,
                          dtype=torch.long,
                          device=softmax_output.device)  # 2N x 1
    if ent:
        # Entropy conditioning: confident samples get larger weight, and the
        # gradient through the weight is reversed (grl_hook) like in CDAN+E.
        entropy = entropy_func(softmax_output)
        entropy.register_hook(grl_hook(coeff))
        entropy = torch.exp(-entropy)
        source_mask = torch.ones_like(entropy)
        source_mask[batch_size:] = 0
        source_weight = entropy * source_mask
        target_mask = torch.ones_like(entropy)
        target_mask[:batch_size] = 0
        target_weight = entropy * target_mask
        # Normalize so each domain's weights sum to 1.
        weight = source_weight / torch.sum(source_weight).detach().item() + \
            target_weight / torch.sum(target_weight).detach().item()
    else:
        weight = torch.ones_like(labels).float() / batch_size
    # Fix: invoke the module (__call__) rather than .forward() directly, so
    # registered hooks are honored; the computed value is identical.
    loss_alg = loss_func(d_output, labels, weight.view(-1, 1))
    return loss_alg
def d_align_msda(softmax_output: Tensor, features: Tensor = None, d_net=None,
                 coeff: float = None, ent: bool = False, batchsizes: list = None):
    """Domain-adversarial alignment loss for multi-source DA.

    Assumes the batch is the concatenation [source; target] where the first
    batchsizes[0] rows are source (label 1) and the next batchsizes[1] are
    target (label 0); the discriminator output is scored with cross-entropy.

    :param softmax_output: [B_S + B_T, C] class probabilities
    :param features: optional features; if given, fed to d_net instead of
        softmax_output
    :param d_net: domain discriminator, called as d_net(input, coeff=coeff);
        must return [B_S + B_T, 2] logits
    :param coeff: gradient-reversal coefficient forwarded to d_net and to the
        entropy hook when ent=True
    :param ent: if True, weight samples by exp(-entropy)
    :param batchsizes: [B_S, B_T] sizes of the source and target sub-batches
    :return: scalar weighted cross-entropy domain loss
    """
    # Fix: None sentinel instead of a mutable default argument ([]).
    if batchsizes is None:
        batchsizes = []
    d_input = softmax_output if features is None else features
    d_output = d_net(d_input, coeff=coeff)
    # Fix: build labels on the input's device instead of hard-coding .cuda(),
    # so the loss also runs on CPU-only builds; GPU behavior is unchanged.
    labels = torch.cat(
        (torch.ones(batchsizes[0], dtype=torch.long),
         torch.zeros(batchsizes[1], dtype=torch.long)), 0
    ).to(softmax_output.device)  # [B_S + B_T]
    if ent:
        # Entropy conditioning with reversed gradient through the weight.
        entropy = entropy_func(softmax_output)
        entropy.register_hook(grl_hook(coeff))
        entropy = torch.exp(-entropy)
        source_mask = torch.ones_like(entropy)
        source_mask[batchsizes[0]:] = 0
        source_weight = entropy * source_mask
        target_mask = torch.ones_like(entropy)
        target_mask[:batchsizes[0]] = 0
        target_weight = entropy * target_mask
        # Normalize so each domain's weights sum to 1.
        weight = source_weight / torch.sum(source_weight).detach().item() + \
            target_weight / torch.sum(target_weight).detach().item()
    else:
        weight = torch.ones_like(labels).float() / softmax_output.shape[0]
    loss_ce = nn.CrossEntropyLoss(reduction='none')(d_output, labels)
    loss_alg = torch.sum(weight * loss_ce)
    return loss_alg
class MMD(nn.Module):
    """Multi-kernel Maximum Mean Discrepancy with a Gaussian (RBF) kernel family."""

    def __init__(self, kernel_mul=2.0, kernel_num=5):
        super(MMD, self).__init__()
        self.kernel_num = kernel_num   # number of RBF kernels in the mixture
        self.kernel_mul = kernel_mul   # geometric spacing between bandwidths
        self.fix_sigma = None          # optional fixed bandwidth (None = data-driven)

    def _gaussian_kernel(self, source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
        """Sum of RBF kernel matrices over all row pairs of [source; target]."""
        total = torch.cat([source, target], dim=0)
        n_samples = total.size(0)
        # Pairwise squared Euclidean distances via broadcasting, [n, n].
        diff = total.unsqueeze(0) - total.unsqueeze(1)
        l2_dist = diff.pow(2).sum(2)
        if fix_sigma:
            bandwidth = fix_sigma
        else:
            # Data-driven base bandwidth: mean off-diagonal squared distance.
            bandwidth = torch.sum(l2_dist.data) / (n_samples ** 2 - n_samples)
        bandwidth = bandwidth / (kernel_mul ** (kernel_num // 2))
        # Geometric ladder of bandwidths centered on the base value.
        kernel_mats = [torch.exp(-l2_dist / (bandwidth * kernel_mul ** k))
                       for k in range(kernel_num)]
        return sum(kernel_mats)

    def forward(self, source, target):
        """Biased MMD^2 estimate between a source batch and a target batch."""
        n = source.size(0)  # number of used samples
        kernels = self._gaussian_kernel(source, target,
                                        kernel_mul=self.kernel_mul,
                                        kernel_num=self.kernel_num,
                                        fix_sigma=self.fix_sigma)
        xx = kernels[:n, :n].mean()
        yy = kernels[n:, n:].mean()
        xy = kernels[:n, n:].mean()
        yx = kernels[n:, :n].mean()
        return xx + yy - xy - yx