-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathexample.py
More file actions
156 lines (121 loc) · 4.94 KB
/
example.py
File metadata and controls
156 lines (121 loc) · 4.94 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
# -*- coding: UTF-8 -*-
"""
@author: Luca Bondi ([email protected])
@author: Paolo Bestagini ([email protected])
@author: Nicolò Bonettini ([email protected])
Politecnico di Milano 2018
"""
import os
from glob import glob
from multiprocessing import cpu_count, Pool
import numpy as np
from PIL import Image
import time
import prnu
import sys
import matplotlib.pyplot as plt
import argparse
# Command-line interface. Arguments are parsed once, at module load, and the
# resulting namespace is read by main() through the module-level name `args`.
parser = argparse.ArgumentParser(
    description='This program extracts camera fingerprint using VDNet and VDID and compares them with the original implementation',
)

# Which denoiser to use; the value is forwarded to prnu.define_param().
parser.add_argument(
    "-denoiser",
    default='original',
    help="[original (default) | vdnet | vdid]",
)

# Boolean switches: absent -> False, present -> True.
parser.add_argument(
    "-rm_zero_mean",
    action='store_true',
    default=False,
    help='Removes zero mean normalization',
)
parser.add_argument(
    "-rm_wiener",
    action='store_true',
    default=False,
    help='Removes Wiener filter',
)

args = parser.parse_args()
# Shape passed to prnu.cut_ctr for every image before further processing.
_CROP_SHAPE = (512, 512, 3)
# PCE value above which a fingerprint/residual pair is counted as a positive
# detection (also drawn as the horizontal line in the per-device plots).
_PCE_THRESHOLD = 60.


def _load_center_crop(img_path):
    """
    Load an image and return its crop from prnu.cut_ctr, or None if unusable.
    An image is rejected (with a printed message) when its array is not uint8
    or does not have three dimensions.
    :param img_path: path of the image file to read
    :return: cropped array, or None when the image was rejected
    """
    # The context manager closes the underlying file as soon as the pixel
    # data has been copied into the array.
    with Image.open(img_path) as im:
        im_arr = np.asarray(im)
    if im_arr.dtype != np.uint8:
        print('Error while reading image: {}'.format(img_path))
        return None
    if im_arr.ndim != 3:
        print('Image is not RGB: {}'.format(img_path))
        return None
    return prnu.cut_ctr(im_arr, _CROP_SHAPE)


def _rate(hits, misses):
    """
    Return hits / (hits + misses), or NaN when there are no samples at all.
    :param hits: count forming the numerator
    :param misses: count of the complementary outcome
    :return: float ratio, NaN if hits + misses == 0
    """
    total = hits + misses
    if not total:
        return float('nan')
    return float(hits) / total


def main():
    """
    Main example script. Load a subset of flatfield and natural images from Dresden.
    For each device compute the fingerprint from all the flatfield images.
    For each natural image compute the noise residual.
    Check the detection performance obtained with cross-correlation and PCE.
    One bar plot of PCE values per device is written to plots/<denoiser>/,
    which is created if missing. Images that are not uint8 or not 3-dimensional
    are skipped, for both flatfield and natural sets.
    :return:
    """
    start = time.time()

    denoiser = args.denoiser
    remove_zero_m = args.rm_zero_mean
    remove_wiener = args.rm_wiener
    prnu.define_param(denoiser, remove_zero_m, remove_wiener)
    print('Denoiser: ' + denoiser)
    print('Remove zero mean: ' + str(remove_zero_m))
    print('Remove wiener: ' + str(remove_wiener) + '\n')

    # The device label is the file name up to its last underscore.
    ff_dirlist = np.array(sorted(glob('test/data/ff/*.jpg')))
    ff_device = np.array([os.path.split(i)[1].rsplit('_', 1)[0] for i in ff_dirlist])
    nat_dirlist = np.array(sorted(glob('test/data/nat/*.jpg')))
    nat_device = np.array([os.path.split(i)[1].rsplit('_', 1)[0] for i in nat_dirlist])

    print('Computing fingerprints')
    fingerprint_device = sorted(np.unique(ff_device))
    k = []
    for device in fingerprint_device:
        imgs = []
        for img_path in ff_dirlist[ff_device == device]:
            im_cut = _load_center_crop(img_path)
            if im_cut is not None:
                imgs.append(im_cut)
        k.append(prnu.extract_multiple_aligned(imgs, processes=1))
    k = np.stack(k, 0)

    print('Computing residuals')
    w = []
    nat_valid = []
    for img_path in nat_dirlist:
        im_cut = _load_center_crop(img_path)
        nat_valid.append(im_cut is not None)
        if im_cut is not None:
            w.append(prnu.extract_single(im_cut))
    # Drop the labels of skipped images so nat_device stays aligned with w.
    nat_device = nat_device[np.array(nat_valid, dtype=bool)]
    w = np.stack(w, 0)

    # Computing Ground Truth
    gt = prnu.gt(fingerprint_device, nat_device)

    print('Computing cross correlation')
    cc_aligned_rot = prnu.aligned_cc(k, w)['cc']

    print('Computing statistics cross correlation')
    stats_cc = prnu.stats(cc_aligned_rot, gt)

    print('Computing PCE')
    pce_rot = np.zeros((len(fingerprint_device), len(nat_device)))

    # savefig does not create missing directories, so make sure it exists.
    plot_dir = os.path.join('plots', denoiser)
    os.makedirs(plot_dir, exist_ok=True)

    for fingerprint_idx, fingerprint_k in enumerate(k):
        device_name = str(fingerprint_device[fingerprint_idx])
        tn, tp, fp, fn = 0, 0, 0, 0
        pce_values = []
        natural_indices = []
        for natural_idx, natural_w in enumerate(w):
            cc2d = prnu.crosscorr_2d(fingerprint_k, natural_w)
            prnu_pce = prnu.pce(cc2d)['pce']
            pce_rot[fingerprint_idx, natural_idx] = prnu_pce
            pce_values.append(prnu_pce)
            natural_indices.append(natural_idx)

            same_device = fingerprint_device[fingerprint_idx] == nat_device[natural_idx]
            detected = prnu_pce > _PCE_THRESHOLD
            if same_device and detected:
                tp += 1
            elif same_device:
                fn += 1
            elif detected:
                fp += 1
            else:
                tn += 1

        # NaN (instead of ZeroDivisionError) when a device has no positive
        # or no negative natural images.
        tpr = _rate(tp, fn)
        fpr = _rate(fp, tn)

        plt.title('PRNU for ' + device_name + ' - ' + denoiser)
        plt.xlabel('query images')
        plt.ylabel('PRNU')
        plt.bar(natural_indices, pce_values)
        plt.text(0.85, 0.85, 'TPR: ' + str(round(tpr, 2)) + '\nFPR: ' + str(round(fpr, 2)),
                 fontsize=10, color='k',
                 ha='left', va='bottom',
                 transform=plt.gca().transAxes)
        plt.axhline(y=_PCE_THRESHOLD, color='r', linestyle='-')
        plt.xticks(natural_indices)
        plt.savefig(os.path.join(plot_dir, device_name + '.png'))
        # Clear the figure so the next device starts from an empty canvas.
        plt.clf()

    print('Computing statistics on PCE')
    stats_pce = prnu.stats(pce_rot, gt)

    print('AUC on CC {:.2f}'.format(stats_cc['auc']))
    print('AUC on PCE {:.2f}'.format(stats_pce['auc']))

    end = time.time()
    elapsed = int(end - start)
    print('Elapsed time: ' + str(elapsed) + ' seconds')
# Run the example only when executed as a script, not when imported.
if __name__ == "__main__":
    main()