# dataset.py — forked from lingjivoo/OpenGraphAU
import numpy as np
import random
from PIL import Image
from torch.utils.data import Dataset
import os
def make_dataset(image_list, label_list, au_relation=None):
    """Pair each image path with its label row (and optional AU-relation row).

    Args:
        image_list: list of image path strings (whitespace is stripped).
        label_list: 2-D array-like; row i is the label vector for image i.
        au_relation: optional 2-D array-like; row i is the AU-relation
            vector for image i. When given, each sample is a 3-tuple.

    Returns:
        A list of ``(path, label_row)`` or ``(path, label_row, au_row)`` tuples.
    """
    samples = []
    for idx, raw_path in enumerate(image_list):
        entry = (raw_path.strip(), label_list[idx, :])
        if au_relation is not None:
            entry = entry + (au_relation[idx, :],)
        samples.append(entry)
    return samples
def pil_loader(path):
    """Open the image at *path* and return it converted to RGB.

    Both the raw file handle and the PIL image are closed deterministically
    via context managers; the converted copy that is returned stays valid.
    """
    with open(path, 'rb') as handle, Image.open(handle) as image:
        return image.convert('RGB')
def default_loader(path):
    """Default image loader used by the datasets: PIL-based RGB loading."""
    return pil_loader(path)
class HybridDataset(Dataset):
    """Hybrid AU dataset reading image paths and labels from list files.

    Expects under ``root_path``:
      * ``img/`` — the image folder,
      * ``list/hybrid_{phase}_img_path.txt`` — one image path per line,
      * ``list/hybrid_{phase}_label.txt`` — label matrix (np.loadtxt),
      * ``list/hybrid_train_AU_relation.txt`` — AU-relation matrix,
        loaded only for ``phase='train'`` with ``stage=2``.

    Args:
        root_path: dataset root directory.
        phase: one of ``'train'``, ``'val'``, ``'test'``.
        transform: optional callable applied to each loaded PIL image.
        stage: 1 or 2; stage 2 training samples also carry AU relations.
        loader: callable ``path -> image``; defaults to ``default_loader``.
            (The default is resolved at call time so the class does not
            depend on definition order within the module.)
    """

    def __init__(self, root_path, phase='train', transform=None, stage=1, loader=None):
        assert stage > 0 and stage <= 2, 'The stage num must be restricted from 1 to 2'
        assert phase in ['train', 'val', 'test'], 'phase must be train, val or test'
        self._root_path = root_path
        self._phase = phase
        self._stage = stage
        self._transform = transform
        # Late-bind the default loader (backward compatible with loader=default_loader).
        self.loader = loader if loader is not None else default_loader
        self.img_folder_path = os.path.join(root_path, 'img')

        # All three phases share the same file layout; only the name stem differs.
        list_dir = os.path.join(root_path, 'list')
        image_list = self._read_lines(
            os.path.join(list_dir, 'hybrid_%s_img_path.txt' % phase))
        label_list = np.loadtxt(
            os.path.join(list_dir, 'hybrid_%s_label.txt' % phase))

        if phase == 'train' and stage == 2:
            # Stage-2 training additionally attaches per-sample AU relations.
            au_relation = np.loadtxt(
                os.path.join(list_dir, 'hybrid_train_AU_relation.txt'))
            self.data_list = make_dataset(image_list, label_list, au_relation)
        else:
            self.data_list = make_dataset(image_list, label_list)

    @staticmethod
    def _read_lines(path):
        """Read all lines of a text file, closing the handle deterministically.

        The original code used ``open(path).readlines()`` which leaked the
        file handle until garbage collection.
        """
        with open(path) as f:
            return f.readlines()

    def __getitem__(self, index):
        """Return ``(img, label)`` — plus ``au_relation`` for stage-2 training."""
        if self._stage == 2 and self._phase == 'train':
            img_path, label, au_relation = self.data_list[index]
        else:
            img_path, label = self.data_list[index]
            au_relation = None
        img = self.loader(os.path.join(self.img_folder_path, img_path))
        if self._transform is not None:
            img = self._transform(img)
        if au_relation is None:
            return img, label
        return img, label, au_relation

    def __len__(self):
        """Number of samples in the selected split."""
        return len(self.data_list)