forked from facebookarchive/loop
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgenerate.py
More file actions
153 lines (114 loc) · 4.32 KB
/
generate.py
File metadata and controls
153 lines (114 loc) · 4.32 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import nltk
import argparse
import random
import numpy as np
from string import digits
import torch
from torch.autograd import Variable
from model import Loop
from data import NpzFolder
from utils import generate_merlin_wav
# Command-line interface: generation takes either a dataset sample (--npz)
# or free text (--text), plus a target speaker id and a trained checkpoint.
parser = argparse.ArgumentParser(description='PyTorch Phonological Loop \
Generation')
parser.add_argument('--npz', type=str, default='',
                    help='Dataset sample to generate.')
parser.add_argument('--text', default='',
                    type=str, help='Free text to generate.')
parser.add_argument('--spkr', default=0,
                    type=int, help='Speaker id.')
parser.add_argument('--checkpoint', default='checkpoints/vctk/lastmodel.pth',
                    type=str, help='Model used for generation.')
parser.add_argument('--gpu', default=-1,
                    type=int, help='GPU device ID, use -1 for CPU.')

# init
args = parser.parse_args()

if args.gpu >= 0:
    # Route all subsequent CUDA allocations to the requested device.
    torch.cuda.set_device(args.gpu)
def text2phone(text, char2code):
    """Convert free text into a tensor of phoneme codes.

    Each word is looked up in the CMU pronouncing dictionary; when a word
    has several pronunciations one is chosen at random. Stress digits are
    stripped from the ARPAbet symbols (e.g. 'AH0' -> 'ah') before mapping
    through char2code.

    Args:
        text: whitespace-separated words; a word absent from cmudict
            raises KeyError.
        char2code: mapping from lowercase phoneme string to integer code.

    Returns:
        1-D torch.LongTensor of phoneme codes.
    """
    cmudict = nltk.corpus.cmudict.dict()
    # Build one reusable deletion table for the stress digits. The original
    # two-argument form str.translate(None, digits) is Python 2 only and
    # raises TypeError on Python 3.
    remove_digits = str.maketrans('', '', digits)

    phones = []
    for word in text.split():
        # cmudict keys are lowercase, so lowercase the word first; this
        # makes capitalized input ("Hello world") work instead of KeyError.
        phones += random.choice(cmudict[word.lower()])

    phones = [str(ph).lower().translate(remove_digits) for ph in phones]
    return torch.LongTensor([char2code[ph] for ph in phones])
def trim_pred(out, attn):
    """Trim trailing frames where the attention has effectively died out.

    Scans the attention backwards for the last output frame whose total
    attention mass (per first column) exceeds 0.5, and cuts both the
    output and the attention just before that frame.
    """
    # Total absolute attention mass per output frame.
    mass = attn.abs().sum(1).data

    stop = 0
    for idx in reversed(range(mass.size(0))):
        if mass[idx][0] > 0.5:
            stop = idx
            break

    return out[:stop, :], attn[:stop, :]
def npy_loader_phonemes(path):
    """Load one .npz dataset sample.

    Returns a (phonemes, audio) pair: the phoneme id sequence as a
    torch.LongTensor and the acoustic features as a torch tensor.
    """
    sample = np.load(path)
    phonemes = torch.from_numpy(sample['phonemes'].astype('int64'))
    audio = torch.from_numpy(sample['audio_features'])
    return phonemes, audio
def main():
    """Generate speech from a trained Loop model.

    Rebuilds the model from the checkpoint's saved training args, then
    synthesizes either a dataset sample (--npz) or free text (--text) for
    speaker --spkr, writing wav output under a 'results' directory next
    to the checkpoint.
    """
    # Load the weights onto the CPU regardless of where they were saved;
    # they are moved to the GPU below only if one was requested.
    weights = torch.load(args.checkpoint,
                         map_location=lambda storage, loc: storage)
    # args.pth stores the argparse namespace the model was trained with.
    opt = torch.load(os.path.dirname(args.checkpoint) + '/args.pth')
    train_args = opt[0]

    train_dataset = NpzFolder(train_args.data + '/numpy_features')
    char2code = train_dataset.dict
    spkr2code = train_dataset.speakers

    norm_path = train_args.data + '/norm_info/norm.dat'
    train_args.noise = 0  # disable training-time noise at inference

    model = Loop(train_args)
    model.load_state_dict(weights)
    if args.gpu >= 0:
        model.cuda()
    model.eval()

    if args.spkr not in range(len(spkr2code)):
        print('ERROR: Unknown speaker id: %d.' % args.spkr)
        return

    txt, feat, spkr, output_fname = None, None, None, None
    # NOTE: the original compared with "args.npz is not ''", an identity
    # check that only works through CPython string interning (and warns on
    # Python 3.8+); plain truthiness is the correct test here.
    if args.npz:
        txt, feat = npy_loader_phonemes(args.npz)

        txt = Variable(txt.unsqueeze(1), volatile=True)
        feat = Variable(feat.unsqueeze(1), volatile=True)
        spkr = Variable(torch.LongTensor([args.spkr]), volatile=True)

        # Strip the '.npz' extension for the output name.
        fname = os.path.basename(args.npz)[:-4]
        output_fname = fname + '.gen_' + str(args.spkr)
    elif args.text:
        txt = text2phone(args.text, char2code)
        # Placeholder acoustic features; only the allocated shape matters
        # when generating from free text.
        feat = torch.FloatTensor(500, 63)
        spkr = torch.LongTensor([args.spkr])

        txt = Variable(txt.unsqueeze(1), volatile=True)
        feat = Variable(feat.unsqueeze(1), volatile=True)
        spkr = Variable(spkr, volatile=True)

        fname = args.text.replace(' ', '_')
        output_fname = fname + '.gen_' + str(args.spkr)
    else:
        print('ERROR: Must supply npz file path or text as source.')
        return

    if args.gpu >= 0:
        txt = txt.cuda()
        feat = feat.cuda()
        spkr = spkr.cuda()

    out, attn = model([txt, spkr], feat)
    out, attn = trim_pred(out, attn)

    output_dir = os.path.join(os.path.dirname(args.checkpoint), 'results')
    # exist_ok avoids the check-then-create race of the original code.
    os.makedirs(output_dir, exist_ok=True)

    generate_merlin_wav(out.data.cpu().numpy(),
                        output_dir,
                        output_fname,
                        norm_path)

    if args.npz:
        # Also vocode the original features for side-by-side comparison.
        output_orig_fname = os.path.basename(args.npz)[:-4] + '.orig'
        generate_merlin_wav(feat[:, 0, :].data.cpu().numpy(),
                            output_dir,
                            output_orig_fname,
                            norm_path)
# Script entry point: only run generation when executed directly.
if __name__ == '__main__':
    main()