forked from gzhu06/Y-vector
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtdnn.py
More file actions
160 lines (127 loc) · 5.27 KB
/
tdnn.py
File metadata and controls
160 lines (127 loc) · 5.27 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
# reference from cvqluu using unfold method,
# (Based on my experience, it's faster than directly using dilated CNN)
# https://github.com/cvqluu/TDNN/blob/master/tdnn.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm
class TDNNLayer(nn.Module):
    '''
    Time-delay neural network layer as defined by
    https://www.danielpovey.com/files/2015_interspeech_multisplice.pdf

    Rather than a dilated 1-D convolution, the input is unfolded into
    overlapping temporal context windows and one shared Linear layer is
    applied per window (empirically faster than a dilated Conv1d).

    context_size and dilation determine the frames covered, e.g.:
        context size 5, dilation 1  ->  [-2, -1, 0, 1, 2]
        context size 3, dilation 2  ->  [-2, 0, 2]
        context size 1, dilation 1  ->  [0]
    '''
    def __init__(self, input_dim, output_dim,
                 context_size, dilation=1):
        super(TDNNLayer, self).__init__()
        self.context_size = context_size
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.dilation = dilation
        # Single affine map shared across all temporal context windows.
        self.kernel = nn.Linear(input_dim * context_size, output_dim)

    def forward(self, inputs):
        '''
        input:  (batch, input_dim, seq_len), channel-first
        output: (batch, output_dim, new_seq_len), channel-first;
                new_seq_len = seq_len - (context_size - 1) * dilation
        '''
        _, feat_dim, _ = inputs.shape
        assert (feat_dim == self.input_dim), 'Input dimension was wrong. Expected ({}), got ({})'.format(self.input_dim, feat_dim)
        # Treat the (features, time) plane as a one-channel image and cut
        # one patch per temporal context window: convolution = unfold +
        # matmul.
        windows = F.unfold(inputs.unsqueeze(1),
                           (self.input_dim, self.context_size),
                           stride=(self.input_dim, 1),
                           dilation=(1, self.dilation))
        # (N, input_dim*context_size, new_t) -> (N, new_t, input_dim*context_size)
        windows = windows.transpose(1, 2)
        out = self.kernel(windows)   # matmul: the shared affine transform
        return out.transpose(1, 2)   # back to channel-first
class TDNNBlock(nn.Module):
    '''
    A TDNNLayer wrapped in PReLU + GroupNorm, with an optional 1x1
    bottleneck conv in front and an optional max-pooled skip branch.

    The skip pool uses the same kernel/dilation as the TDNN layer so
    both branches shrink in time identically.
    '''
    def __init__(self, input_dim, bn_dim,
                 skip, context_size, dilation=1,
                 bottleneck=False):
        super(TDNNBlock, self).__init__()
        self.bottleneck = bottleneck
        if bottleneck:
            # Compress to bn_dim with a 1x1 conv, TDNN maps back to input_dim.
            self.bnconv1d = nn.Conv1d(input_dim, bn_dim, 1)
            self.nonlinear1 = nn.PReLU()
            self.norm1 = nn.GroupNorm(1, bn_dim, eps=1e-08)
            self.tdnnblock = TDNNLayer(bn_dim, input_dim, context_size, dilation)
        else:
            self.tdnnblock = TDNNLayer(input_dim, input_dim, context_size, dilation)
        self.nonlinear2 = nn.PReLU()
        self.norm2 = nn.GroupNorm(1, input_dim, eps=1e-08)
        self.skip = skip
        if self.skip:
            self.skip_out = nn.MaxPool1d(kernel_size=context_size,
                                         stride=1, dilation=dilation)

    def forward(self, x):
        '''
        input:  (batch, input_dim, seq_len), channel-first
        output: (batch, input_dim, new_seq_len); when skip is enabled,
                an (out, skip) pair of equally-shaped tensors
        '''
        if self.bottleneck:
            out = self.norm1(self.nonlinear1(self.bnconv1d(x)))
        else:
            out = x
        out = self.norm2(self.nonlinear2(self.tdnnblock(out)))
        if not self.skip:
            return out
        # Skip branch pools the raw block input, not the bottlenecked path.
        return out, self.skip_out(x)
class TDNN(nn.Module):
    def __init__(self, filter_dim, input_dim, bn_dim,
                 skip, context_size=3, layer=9, stack=1,
                 bottleneck=False):
        '''
        Stacked TDNN blocks with exponentially growing dilation.

        filter_dim: kept for interface compatibility (unused here; the
            commented-out bottleneck front-end in the original consumed it)
        input_dim: channel dimension flowing through every block
        bn_dim: bottleneck width inside each block (when bottleneck=True)
        skip: if True, each block also yields a max-pooled skip branch
            that is added to the block output
        context_size: temporal context of every TDNN layer
        layer: blocks per stack; block i uses dilation 2**i
        stack: number of repeated stacks
        '''
        super(TDNN, self).__init__()
        self.skip = skip
        # TDNN blocks for feature extraction
        self.receptive_field = 0
        self.tdnn = nn.ModuleList([])
        for s in range(stack):
            for i in range(layer):
                # BUG FIX: context_size was hard-coded to 3 here, so the
                # constructor argument was silently ignored (and the
                # receptive-field accounting below disagreed with the
                # actual blocks whenever context_size != 3).
                self.tdnn.append(TDNNBlock(input_dim, bn_dim, self.skip,
                                           context_size=context_size,
                                           dilation=2**i,
                                           bottleneck=bottleneck))
                if i == 0 and s == 0:
                    self.receptive_field += context_size
                else:
                    self.receptive_field += (context_size - 1) * 2 ** i
        print("Receptive field: {:3d} frames.".format(self.receptive_field))

    def forward(self, x):
        '''
        input:  (batch, input_dim, seq_len), channel-first
        output: (batch, input_dim, new_seq_len)
        '''
        # BUG FIX: the original fed the raw input x to every block in the
        # skip path (so only the last block's result survived) and read an
        # undefined `output` on the first iteration of the non-skip path
        # (NameError). The blocks are now chained, consistent with the
        # receptive-field computation in __init__.
        output = x
        for block in self.tdnn:
            if self.skip:
                output, skip = block(output)
                output = output + skip
            else:
                output = block(output)
        return output