-
Notifications
You must be signed in to change notification settings - Fork 8
Expand file tree
/
Copy pathnetwork.py
More file actions
108 lines (97 loc) · 3.79 KB
/
network.py
File metadata and controls
108 lines (97 loc) · 3.79 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import numpy as np
import torch
import torch.nn as nn
import torchvision
from torchvision import models
from torch.autograd import Variable
import math
import pdb
import torch.nn.utils.weight_norm as weightNorm
from collections import OrderedDict
def init_weights(m):
    """Initialize a layer's parameters in place, dispatched on the layer's class name.

    Intended for use with ``nn.Module.apply``:
      - Conv2d / ConvTranspose2d: Kaiming-uniform weight, zero bias.
      - BatchNorm*:               weight ~ N(1.0, 0.02), zero bias.
      - Linear:                   Xavier-normal weight, zero bias.
    Layers of any other type are left untouched.

    Fix: layers constructed with ``bias=False`` (or ``affine=False`` for
    BatchNorm) carry ``None`` parameters; the original called
    ``nn.init.zeros_(None)`` and crashed. Guards added.
    """
    classname = m.__class__.__name__
    if classname.find('Conv2d') != -1 or classname.find('ConvTranspose2d') != -1:
        nn.init.kaiming_uniform_(m.weight)
        if m.bias is not None:  # conv layers built with bias=False have no bias
            nn.init.zeros_(m.bias)
    elif classname.find('BatchNorm') != -1:
        if m.weight is not None:  # affine=False BatchNorm has no learnable params
            nn.init.normal_(m.weight, 1.0, 0.02)
            nn.init.zeros_(m.bias)
    elif classname.find('Linear') != -1:
        nn.init.xavier_normal_(m.weight)
        if m.bias is not None:  # Linear(..., bias=False) has no bias
            nn.init.zeros_(m.bias)
# Registry mapping short backbone names to their torchvision constructors.
res_dict = dict(
    resnet18=models.resnet18,
    resnet34=models.resnet34,
    resnet50=models.resnet50,
    resnet101=models.resnet101,
    resnet152=models.resnet152,
    resnext50=models.resnext50_32x4d,
    resnext101=models.resnext101_32x8d,
)
class ResBase(nn.Module):
    """Feature extractor wrapping a torchvision ResNet/ResNeXt backbone.

    Reuses every stage of the pretrained network except the final ``fc``
    classifier; ``forward`` returns the flattened pooled features and
    ``in_features`` records their dimensionality for downstream heads.
    """

    # Backbone stages reused verbatim, in execution order.
    _PARTS = ("conv1", "bn1", "relu", "maxpool",
              "layer1", "layer2", "layer3", "layer4", "avgpool")

    def __init__(self, res_name, pretrain=True):
        super(ResBase, self).__init__()
        backbone = res_dict[res_name](pretrained=pretrain)
        for part in self._PARTS:
            setattr(self, part, getattr(backbone, part))
        # Width of the dropped fc layer == width of the pooled features.
        self.in_features = backbone.fc.in_features

    def forward(self, x):
        out = x
        for part in self._PARTS:
            out = getattr(self, part)(out)
        # (N, C, 1, 1) -> (N, C)
        return out.view(out.size(0), -1)
class feat_bootleneck(nn.Module):
    """Bottleneck head projecting backbone features to a smaller dimension.

    The ``type`` flag selects which post-projection layers are applied:
    "ori" (projection only), "bn", "bn_relu", or "bn_relu_drop" — each
    option cumulatively adds BatchNorm, ReLU, and Dropout(0.5).
    NOTE(review): class name keeps the original "bootleneck" spelling so
    external callers are unaffected.
    """
    def __init__(self, feature_dim, bottleneck_dim=256, type="ori"):
        super(feat_bootleneck, self).__init__()
        self.bn = nn.BatchNorm1d(bottleneck_dim, affine=True)
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(p=0.5)
        self.bottleneck = nn.Linear(feature_dim, bottleneck_dim)
        self.bottleneck.apply(init_weights)
        self.type = type

    def forward(self, x):
        out = self.bottleneck(x)
        if self.type in ("bn", "bn_relu", "bn_relu_drop"):
            out = self.bn(out)
        if self.type in ("bn_relu", "bn_relu_drop"):
            out = self.relu(out)
        if self.type == "bn_relu_drop":
            out = self.dropout(out)
        return out
class feat_classifier(nn.Module):
    """Final classification layer over bottleneck features.

    ``type`` selects the head:
      - "wn":     weight-normalized linear layer.
      - "linear": plain linear layer.
      - other:    bias-free cosine-style head — both the weight rows and the
                  incoming features are L2-normalized before the projection.
    """
    def __init__(self, class_num, bottleneck_dim=256, type="linear"):
        super(feat_classifier, self).__init__()
        self.type = type
        if type == 'wn':
            layer = weightNorm(nn.Linear(bottleneck_dim, class_num), name="weight")
            layer.apply(init_weights)
        elif type == 'linear':
            layer = nn.Linear(bottleneck_dim, class_num)
            layer.apply(init_weights)
        else:
            layer = nn.Linear(bottleneck_dim, class_num, bias=False)
            nn.init.xavier_normal_(layer.weight)
        self.fc = layer

    def forward(self, x):
        if self.type in ('wn', 'linear'):
            return self.fc(x)
        # Cosine head: normalize weight rows and input rows, then project.
        unit_w = torch.nn.functional.normalize(self.fc.weight, dim=1, p=2)
        unit_x = torch.nn.functional.normalize(x, dim=1, p=2)
        return torch.nn.functional.linear(unit_x, unit_w)
class feat_classifier_simpl(nn.Module):
    """Minimal classifier head: a single Xavier-initialized linear layer."""
    def __init__(self, class_num, feat_dim):
        super(feat_classifier_simpl, self).__init__()
        self.fc = nn.Linear(feat_dim, class_num)
        nn.init.xavier_normal_(self.fc.weight)

    def forward(self, x):
        return self.fc(x)