-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy patheval_sEMG.py
More file actions
103 lines (90 loc) · 3.26 KB
/
eval_sEMG.py
File metadata and controls
103 lines (90 loc) · 3.26 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
# Last Modified: 2021.08.31
import numpy as np
import torch
import librosa
from scipy.stats.stats import pearsonr
import scipy.signal
def cal_snr(clean, enhanced, dtype='numpy'):
    """Signal-to-noise ratio (dB) of ``enhanced`` relative to the reference ``clean``.

    The noise is ``enhanced - clean``. For each signal, SNR = 10*log10(sum(clean**2) / sum(noise**2)),
    computed along the last axis. When the inputs have leading (batch) axes, the per-signal
    SNRs are averaged. Both backends now use the same axis convention, so 1-D and batched
    2-D inputs work with either backend.

    Args:
        clean: Reference signal(s).
        enhanced: Estimated signal(s), same shape as ``clean``.
        dtype: ``'numpy'`` (default) for array-likes. Any other value uses torch ops and
            expects tensors.

    Returns:
        SNR in dB rounded to 3 decimals. A perfect reconstruction (zero noise power)
        yields ``inf``.
    """
    if dtype == 'numpy':
        # Promote to float64: squaring integer sample arrays can silently overflow.
        clean = np.asarray(clean, dtype=np.float64)
        enhanced = np.asarray(enhanced, dtype=np.float64)
        noise = enhanced - clean
        noise_pw = np.sum(noise * noise, axis=-1)
        signal_pw = np.sum(clean * clean, axis=-1)
        SNR = np.mean(10 * np.log10(signal_pw / noise_pw))
    else:
        noise = enhanced - clean
        # Reduce over the last dim (was dim 1, which failed for 1-D tensors).
        noise_pw = torch.sum(noise * noise, -1)
        signal_pw = torch.sum(clean * clean, -1)
        SNR = torch.mean(10 * torch.log10(signal_pw / noise_pw)).item()
    return round(SNR, 3)
def cal_rmse(clean, enhanced, dtype='numpy'):
    """Root-mean-square error between ``enhanced`` and ``clean`` over all elements.

    Args:
        clean: Reference signal(s).
        enhanced: Estimated signal(s), same shape as ``clean``.
        dtype: ``'numpy'`` (default) for numpy arrays. Any other value uses torch ops on
            tensors and converts the result with ``.item()``.

    Returns:
        RMSE rounded to 6 decimals.
    """
    if dtype != 'numpy':
        squared_error = torch.square(enhanced - clean)
        return round(torch.sqrt(torch.mean(squared_error)).item(), 6)
    squared_error = (enhanced - clean) ** 2
    return round(np.sqrt(squared_error.mean()), 6)
def cal_prd(clean, enhanced, dtype='numpy'):
    """Percentage root-mean-square difference (PRD) between ``enhanced`` and ``clean``.

    PRD = 100 * sqrt(sum((enhanced - clean)**2) / sum(clean**2)), with both sums taken
    over all elements.

    Args:
        clean: Reference signal(s).
        enhanced: Estimated signal(s), same shape as ``clean``.
        dtype: ``'numpy'`` (default) for numpy arrays. Any other value uses torch ops on
            tensors and converts the result with ``.item()``.

    Returns:
        PRD in percent, rounded to 3 decimals.
    """
    if dtype != 'numpy':
        residual_energy = torch.sum(torch.square(enhanced - clean))
        reference_energy = torch.sum(torch.square(clean))
        ratio = torch.div(residual_energy, reference_energy)
        return round(torch.mul(torch.sqrt(ratio), 100).item(), 3)
    ratio = np.sum((enhanced - clean) ** 2) / np.sum(clean ** 2)
    return round(np.sqrt(ratio) * 100, 3)
def cal_R2(clean, enhanced):
    """Squared Pearson correlation between ``clean`` and ``enhanced``.

    Returns:
        ``r**2`` rounded to 3 decimals, where ``r`` is the first element of
        ``scipy.stats.pearsonr(clean, enhanced)``.
    """
    r = pearsonr(clean, enhanced)[0]
    r_squared = r ** 2
    return round(r_squared, 3)
def cal_CC(clean, enhanced):
    """First output of ``np.correlate(clean, enhanced)`` in ``'valid'`` mode, rounded to 3 decimals.

    For equal-length real 1-D inputs this is the un-normalized lag-0 cross-correlation,
    i.e. ``sum(clean * enhanced)``. It is NOT a normalized correlation coefficient.
    """
    correlation = np.correlate(clean, enhanced, mode='valid')
    lag_zero = correlation[0]
    return round(lag_zero, 3)
def cal_ARV(emg, win=1000):
    """Average rectified value (mean of ``|emg|``) over consecutive, non-overlapping windows.

    Args:
        emg: Signal (array-like). Windows are taken along the first axis.
        win: Window length in samples (default 1000, the previously hard-coded value).
            The final window may be shorter if the length is not a multiple of ``win``.

    Returns:
        numpy array with one mean absolute value per window. It is empty for empty input.

    Raises:
        ValueError: If ``win`` is not positive. A negative step would otherwise silently
            produce no windows.
    """
    if win <= 0:
        raise ValueError(f"win must be positive, got {win}")
    rectified = np.abs(np.asarray(emg))
    return np.array([rectified[start:start + win].mean()
                     for start in range(0, rectified.shape[0], win)])
def cal_KR(x):
    """Quantile-based tail-width measure of ``x``, offset so a Gaussian gives about 0.

    Steps:
      1. ``x`` is standardized with ``normalize``.
      2. It is histogrammed into 1000 uniform bins over [-5, 5]. Samples outside that
         range are dropped by ``np.histogram``.
      3. The empirical CDF is formed from the histogram.
      4. The width of the 2.5-97.5 % range is divided by the width of the 25-75 %
         (interquartile) range.

    Widths are measured in bin indices. Because bins are uniform, the ratio equals the
    ratio of value widths. The constant 2.91 is approximately that ratio for a normal
    distribution (3.92 / 1.349), so heavier tails give positive values.

    Returns:
        The offset ratio. It is undefined (division by zero) if the interquartile
        indices coincide.
    """
    bins = np.linspace(-5, 5, 1001)
    pdf, _ = np.histogram(normalize(x), bins, density=True)
    # Normalize by the in-range mass so the CDF ends at exactly 1.
    cdf = np.cumsum(pdf) / np.sum(pdf)
    central_95_width = find_nearest(cdf, 0.975) - find_nearest(cdf, 0.025)
    interquartile_width = find_nearest(cdf, 0.75) - find_nearest(cdf, 0.25)
    KR = central_95_width / interquartile_width - 2.91
    return KR
def cal_MF(emg, stimulus, sr=1000, fmin=10):
    """Mean frequency of each STFT frame of ``emg``, kept only where ``stimulus`` > 0.

    The spectrogram is the plain STFT magnitude from ``make_spectrum`` (n_fft=256,
    hop 32). For each frame, the mean frequency is sum(f * |S|) / sum(|S|) over bins with
    f >= ``fmin``, weighted by magnitude, not power. With the defaults (sr=1000, fmin=10)
    this covers 10-500 Hz.

    Args:
        emg: Signal passed to ``make_spectrum``.
        stimulus: Per-sample indicator. Every 32nd sample (the STFT hop) selects a frame.
            Assumes it is sampled like ``emg`` (TODO confirm with callers).
        sr: Sampling rate in Hz used to label FFT bins (default 1000, the former
            hard-coded value).
        fmin: Lowest frequency in Hz included in the average (default 10).

    Returns:
        numpy array of mean frequencies (Hz) for the selected frames. A frame with zero
        in-range magnitude yields nan.
    """
    n_fft = 256  # must match make_spectrum's n_fft
    hop = 32     # must match make_spectrum's hop_length
    freq = librosa.fft_frequencies(sr=sr, n_fft=n_fft)
    # freq is ascending, so this is the first bin with freq >= fmin.
    start = int(np.searchsorted(freq, fmin))
    freq = np.expand_dims(freq[start:], 1)
    spec = make_spectrum(emg, feature_type=None)[0][start:, :]
    weighted_f = np.sum(freq * spec, 0)
    spec_column_pow = np.sum(spec, 0)
    MF = weighted_f / spec_column_pow
    MF = [MF[i] for i, v in enumerate(stimulus[::hop]) if v > 0]
    return np.array(MF)
def find_nearest(array, value):
    """Return the index of the element of ``array`` closest to ``value``.

    Distance is the absolute difference. Ties resolve to the lowest (flattened) index,
    as ``argmin`` does.
    """
    gap = abs(array - value)
    nearest = np.argmin(gap)
    return nearest
def normalize(x):
    """Standardize ``x`` to zero mean and unit population standard deviation.

    Uses the global mean and ``np.std`` (ddof=0) over all elements. A constant input
    divides by zero and produces nan/inf.
    """
    mu = x.mean()
    sigma = np.std(x)
    centered = x - mu
    return centered / sigma
def make_spectrum(y=None, is_slice=False, feature_type='logmag', mode=None, FRAMELENGTH=None,
                  SHIFT=None, _max=None, _min=None):
    """Compute an STFT-based spectral feature of signal ``y``.

    The STFT configuration is fixed: ``n_fft=256``, ``hop_length=32``,
    ``win_length=128``, a Hamming window and ``center=True``.

    Args:
        y: Signal passed to ``librosa.stft``.
        is_slice, FRAMELENGTH, SHIFT: Accepted for interface compatibility; unused here.
        feature_type: Selects the feature:
            * ``'logmag'``: ``log(1 + |D|)``.
            * ``'lps'``: ``log10(|D|**2)``; ``-inf`` wherever a bin is exactly zero.
            * anything else: the raw magnitude ``|D|``.
        mode: Selects the scaling:
            * ``'mean_std'``: standardize with the global mean/std of the whole matrix.
            * ``'minmax'``: rescale so ``_min`` maps to 0 and ``_max`` maps to 1.
            * anything else: no scaling.
        _max, _min: Optional bounds for ``'minmax'``. When None they are taken from the
            feature matrix itself.

    Returns:
        Tuple ``(Sxx, phase, len(y))``, where ``phase`` is ``exp(1j * angle(D))``
        (unit-magnitude complex phase of the STFT).
    """
    D = librosa.stft(y, center=True, n_fft=256, hop_length=32, win_length=128,
                     window=scipy.signal.get_window('hamming', 128))
    phase = np.exp(1j * np.angle(D))
    D = np.abs(D)

    # Feature type
    if feature_type == 'logmag':
        Sxx = np.log1p(D)
    elif feature_type == 'lps':
        Sxx = np.log10(D ** 2)
    else:
        Sxx = D

    # Normalization mode
    if mode == 'mean_std':
        # Global standardization over the whole matrix (not per frequency bin).
        Sxx = normalize(Sxx)
    elif mode == 'minmax':
        # Minimum -> 0, maximum -> 1. The bounds were previously swapped, which inverted
        # the scale, and caller-supplied bounds were ignored.
        lo = np.min(Sxx) if _min is None else _min
        hi = np.max(Sxx) if _max is None else _max
        Sxx = (Sxx - lo) / (hi - lo)
    return Sxx, phase, len(y)