-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
122 lines (105 loc) · 4.95 KB
/
utils.py
File metadata and controls
122 lines (105 loc) · 4.95 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import torch
import numpy as np
from PIL import Image
from torchvision import transforms
from torch.utils.data.dataset import Dataset
import random
import matplotlib.pyplot as plt
import os
import math
import torch.nn as nn
from skimage import measure
import torch.nn.functional as F
import os
from torch.nn import init
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
def weights_init_xavier(m):
    """Xavier-initialize the weights of plain Conv2d layers.

    Intended for ``net.apply(weights_init_xavier)``. Layers whose class name
    contains 'SplAtConv2d' are deliberately skipped (they handle their own
    initialization elsewhere — confirm with the model code).

    Args:
        m: a module visited by ``nn.Module.apply``.
    """
    classname = m.__class__.__name__
    if classname.find('Conv2d') != -1 and classname.find('SplAtConv2d') == -1:
        # Fix: init.xavier_normal (no underscore) is deprecated and removed
        # in recent PyTorch releases; use the in-place variant.
        init.xavier_normal_(m.weight.data)
def weights_init_kaiming(m):
    """Kaiming-initialize conv/linear weights; BatchNorm gets N(1, 0.02)
    weights and zero bias.

    Meant to be passed to ``net.apply(weights_init_kaiming)``.
    """
    name = type(m).__name__
    if 'Conv' in name:
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
    elif 'Linear' in name:
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
    elif 'BatchNorm' in name:
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)
class Get_gradient_nopadding(nn.Module):
    """Per-pixel gradient magnitude of the first input channel using fixed
    3x3 central-difference kernels.

    Input:  x of shape (N, C, H, W); only channel 0 is used.
    Output: tensor of shape (N, 1, H, W) equal to sqrt(gv^2 + gh^2 + 1e-6).
    """

    def __init__(self):
        super(Get_gradient_nopadding, self).__init__()
        kernel_v = torch.tensor([[0., -1., 0.],
                                 [0.,  0., 0.],
                                 [0.,  1., 0.]]).unsqueeze(0).unsqueeze(0)
        kernel_h = torch.tensor([[ 0., 0., 0.],
                                 [-1., 0., 1.],
                                 [ 0., 0., 0.]]).unsqueeze(0).unsqueeze(0)
        # Fix: the original wrapped these in nn.Parameter(...).cuda(), which
        # hard-codes the GPU (crashes on CPU-only machines) and leaves the
        # tensors unregistered, so .to()/.cuda() on the module never moved
        # them. register_buffer keeps them non-trainable, moves them with the
        # module, and records them in state_dict.
        self.register_buffer('weight_v', kernel_v)
        self.register_buffer('weight_h', kernel_h)

    def forward(self, x):
        x0 = x[:, 0].unsqueeze(1)
        x0_v = F.conv2d(x0, self.weight_v, padding=1)
        x0_h = F.conv2d(x0, self.weight_h, padding=1)
        # epsilon keeps sqrt differentiable where the gradient is zero
        return torch.sqrt(torch.pow(x0_v, 2) + torch.pow(x0_h, 2) + 1e-6)
def random_crop(img, mask, patch_size):
    """Crop the same random patch_size x patch_size window from img and mask.

    Inputs smaller than patch_size are zero-padded on the bottom/right first.
    Returns (img_patch, mask_patch) as 2-D arrays of shape
    (patch_size, patch_size).
    """
    h, w = img.shape
    if min(h, w) < patch_size:
        pad = ((0, max(h, patch_size) - h), (0, max(w, patch_size) - w))
        img = np.pad(img, pad, mode='constant')
        mask = np.pad(mask, pad, mode='constant')
        h, w = img.shape
    # draw the vertical offset first, then the horizontal one
    top = random.randint(0, h - patch_size)
    left = random.randint(0, w - patch_size)
    return (img[top:top + patch_size, left:left + patch_size],
            mask[top:top + patch_size, left:left + patch_size])
def Normalized(img, img_norm_cfg):
    """Standardize an image using the dataset statistics in img_norm_cfg.

    img_norm_cfg is a dict with scalar 'mean' and 'std' entries
    (see get_img_norm_cfg).
    """
    mean = img_norm_cfg['mean']
    std = img_norm_cfg['std']
    return (img - mean) / std
def Denormalization(img, img_norm_cfg):
    """Invert Normalized(): map a standardized image back to raw intensities."""
    mean = img_norm_cfg['mean']
    std = img_norm_cfg['std']
    return img * std + mean
def get_img_norm_cfg(dataset_name, dataset_dir):
    """Return dict(mean=..., std=...) for normalizing dataset_name.

    Known benchmarks use precomputed statistics. For any other dataset the
    mean/std are computed as the average of per-image mean/std over every
    image listed in the train and test index files under dataset_dir.

    Note: the 'NUDT-SIRST-Sea' entry must be checked before 'NUDT-SIRST',
    which is a substring of it — the table order preserves that.
    """
    presets = [
        ('NUAA-SIRST', dict(mean=101.06385040283203, std=34.619606018066406)),
        ('NUDT-SIRST-Sea', dict(mean=43.62403869628906, std=18.91838264465332)),
        ('NUDT-SIRST', dict(mean=107.80905151367188, std=33.02274703979492)),
        ('IRSTD-1K', dict(mean=87.4661865234375, std=39.71953201293945)),
        ('SIRST4', dict(mean=62.10432052612305, std=23.96998405456543)),
    ]
    for key, cfg in presets:
        if key in dataset_name:
            return cfg
    # Unknown dataset: average per-image statistics over train + test splits.
    with open(dataset_dir + '/img_idx/train_' + dataset_name + '.txt', 'r') as f:
        img_list = f.read().splitlines()
    with open(dataset_dir + '/img_idx/test_' + dataset_name + '.txt', 'r') as f:
        img_list = img_list + f.read().splitlines()
    img_dir = dataset_dir + '/images/'
    mean_list = []
    std_list = []
    for img_pth in img_list:
        img = np.array(Image.open(img_dir + img_pth).convert('I'), dtype=np.float32)
        mean_list.append(img.mean())
        std_list.append(img.std())
    return dict(mean=float(np.array(mean_list).mean()),
                std=float(np.array(std_list).mean()))
def get_optimizer(net, optimizer_name, scheduler_name, optimizer_settings, scheduler_settings):
    """Build an optimizer and LR scheduler for net.

    Args:
        net: the model whose parameters are optimized.
        optimizer_name: 'Adam', 'Adagrad', or 'SGD'.
        scheduler_name: 'MultiStepLR' or 'CosineAnnealingLR'.
        optimizer_settings: dict with 'lr' (used by Adam/Adagrad).
        scheduler_settings: dict with 'step'/'gamma' (MultiStepLR) or
            'epochs'/'min_lr' (CosineAnnealingLR).

    Returns:
        (optimizer, scheduler)

    Raises:
        ValueError: on an unrecognized optimizer or scheduler name (the
            original silently fell through to an UnboundLocalError).
    """
    if optimizer_name == 'Adam':
        optimizer = torch.optim.Adam(net.parameters(), lr=optimizer_settings['lr'])
    elif optimizer_name == 'Adagrad':
        optimizer = torch.optim.Adagrad(net.parameters(), lr=optimizer_settings['lr'])
    elif optimizer_name == 'SGD':
        # NOTE(review): these SGD hyper-parameters are hard-coded and ignore
        # optimizer_settings['lr']. Kept as-is to preserve existing training
        # behavior, but confirm this override is intentional.
        optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=0.0001)
    else:
        raise ValueError('Unknown optimizer: %s' % optimizer_name)
    if scheduler_name == 'MultiStepLR':
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=scheduler_settings['step'], gamma=scheduler_settings['gamma'])
    elif scheduler_name == 'CosineAnnealingLR':
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=scheduler_settings['epochs'], eta_min=scheduler_settings['min_lr'])
    else:
        raise ValueError('Unknown scheduler: %s' % scheduler_name)
    return optimizer, scheduler