-
Notifications
You must be signed in to change notification settings - Fork 18
Expand file tree
/
Copy path torch_funcs.py
More file actions
65 lines (47 loc) · 1.49 KB
/
torch_funcs.py
File metadata and controls
65 lines (47 loc) · 1.49 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
# --------------------------------------------------------
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License
# --------------------------------------------------------
import torch
import torch.nn as nn
from torch import Tensor
def grl_hook(coeff):
    """Build a gradient-reversal hook.

    Returns a function suitable for ``Tensor.register_hook`` that scales
    the incoming gradient by ``-coeff`` (gradient-reversal-layer trick
    used in domain-adversarial training).
    """
    def hook(grad):
        # Clone first so the caller's gradient buffer is never mutated.
        return grad.clone().mul(-coeff)
    return hook
def entropy_func(x: Tensor) -> Tensor:
    """Per-row Shannon entropy of a batch of probability vectors.

    :param x: [N, C] tensor; rows are assumed to be probabilities
    :return: [N] tensor of entropies
    """
    eps = 1e-5  # keeps log() finite when a probability is exactly 0
    return torch.sum(-x * torch.log(x + eps), dim=1)
def get_relation(x: Tensor) -> Tensor:
    """Pairwise cosine similarity between the L feature vectors of each batch.

    :param x: [B, L, C]
    :return: [B, L, L] cosine-similarity matrix per batch element
    """
    dots = torch.bmm(x, x.transpose(1, 2))            # [B, L, L] inner products
    norms = x.norm(dim=-1, keepdim=True)              # [B, L, 1]
    # Floor the norm product to avoid division by zero for all-zero rows.
    denom = torch.bmm(norms, norms.transpose(1, 2)).clamp(min=1e-8)
    return dots / denom
# initialization used only for HDA
def init_weights_fc(m):
    """Kaiming-normal init for a Linear layer with negative slope a=100.

    Intended for ``model.apply``; modules other than ``nn.Linear`` are
    left untouched.
    """
    if isinstance(m, nn.Linear):
        nn.init.kaiming_normal_(m.weight, a=100)
        # Fix: guard against bias=False layers — zeros_(None) would raise.
        if m.bias is not None:
            nn.init.zeros_(m.bias)
def init_weights_fc0(m):
    """Kaiming-normal init for a Linear layer (default negative slope a=0).

    Intended for ``model.apply``; modules other than ``nn.Linear`` are
    left untouched.
    """
    if isinstance(m, nn.Linear):
        nn.init.kaiming_normal_(m.weight)
        # Fix: guard against bias=False layers — zeros_(None) would raise.
        if m.bias is not None:
            nn.init.zeros_(m.bias)
def init_weights_fc1(m):
    """Kaiming-normal init for a Linear layer with negative slope a=1.

    Intended for ``model.apply``; modules other than ``nn.Linear`` are
    left untouched.
    """
    if isinstance(m, nn.Linear):
        nn.init.kaiming_normal_(m.weight, a=1)
        # Fix: guard against bias=False layers — zeros_(None) would raise.
        if m.bias is not None:
            nn.init.zeros_(m.bias)
def init_weights_fc2(m):
    """Kaiming-normal init for a Linear layer with negative slope a=2.

    Intended for ``model.apply``; modules other than ``nn.Linear`` are
    left untouched.
    """
    if isinstance(m, nn.Linear):
        nn.init.kaiming_normal_(m.weight, a=2)
        # Fix: guard against bias=False layers — zeros_(None) would raise.
        if m.bias is not None:
            nn.init.zeros_(m.bias)