Skip to content

Commit 3fbe03f

Browse files
authored
Support train hifigan (babysor#83)
* support train hifigan
1 parent 222e302 commit 3fbe03f

File tree

8 files changed

+274
-11
lines changed

8 files changed

+274
-11
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,5 @@
1717
*.sh
1818
synthesizer/saved_models/*
1919
vocoder/saved_models/*
20+
cp_hifigan/*
2021
!vocoder/saved_models/pretrained/*

README-CN.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,12 @@
5858
* 预处理数据:
5959
`python vocoder_preprocess.py <datasets_root>`
6060

61-
* 训练声码器:
61+
* 训练wavernn声码器:
6262
`python vocoder_train.py mandarin <datasets_root>`
6363

64+
* 训练hifigan声码器:
65+
`python vocoder_train.py mandarin <datasets_root> hifigan`
66+
6467
### 3. 启动工具箱
6568
然后您可以尝试使用工具箱:
6669
`python demo_toolbox.py -d <datasets_root>`

README.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,12 @@ Code:aid4
6161
* Preprocess the data:
6262
`python vocoder_preprocess.py <datasets_root>`
6363

64-
* Train the vocoder:
64+
* Train the wavernn vocoder:
6565
`python vocoder_train.py mandarin <datasets_root>`
6666

67+
* Train the hifigan vocoder:
68+
`python vocoder_train.py mandarin <datasets_root> hifigan`
69+
6770
### 3. Launch the Toolbox
6871
You can then try the toolbox:
6972

toolbox/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -361,9 +361,10 @@ def init_vocoder(self):
361361
# Select vocoder based on model name
362362
if model_fpath.name[0] == "g":
363363
vocoder = gan_vocoder
364-
self.ui.log("vocoder is hifigan")
364+
self.ui.log("set hifigan as vocoder")
365365
else:
366366
vocoder = rnn_vocoder
367+
self.ui.log("set wavernn as vocoder")
367368

368369
self.ui.log("Loading the vocoder %s... " % model_fpath)
369370
self.ui.set_loading(1)

vocoder/hifigan/meldataset.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
def get_dataset_filelist(a):
    """Build the train/validation wav-file split for vocoder training.

    Lists every file in ``a.input_wavs_dir``, shuffles it (NOTE: uses the
    global `random` state, so the split is nondeterministic unless the
    caller seeds it), and holds out the last ~5% as validation.

    Args:
        a: object with an ``input_wavs_dir`` attribute (path to the wavs).

    Returns:
        (training_files, validation_files): two lists of full file paths.
    """
    files = os.listdir(a.input_wavs_dir)
    random.shuffle(files)
    files = [os.path.join(a.input_wavs_dir, f) for f in files]
    # Guard against tiny datasets: with < 20 files int(len*0.05) is 0 and
    # files[:-0] would be the EMPTY list (no training data at all) while
    # files[-0:] would be everything. Always hold out at least one file.
    n_validation = max(1, int(len(files) * 0.05))
    training_files = files[:-n_validation]
    validation_files = files[-n_validation:]

    return training_files, validation_files
9191

vocoder/hifigan/train.py

Lines changed: 240 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,240 @@
1+
import warnings
2+
warnings.simplefilter(action='ignore', category=FutureWarning)
3+
import itertools
4+
import os
5+
import time
6+
import argparse
7+
import json
8+
import torch
9+
import torch.nn.functional as F
10+
from torch.utils.tensorboard import SummaryWriter
11+
from torch.utils.data import DistributedSampler, DataLoader
12+
import torch.multiprocessing as mp
13+
from torch.distributed import init_process_group
14+
from torch.nn.parallel import DistributedDataParallel
15+
from vocoder.hifigan.env import AttrDict, build_env
16+
from vocoder.hifigan.meldataset import MelDataset, mel_spectrogram, get_dataset_filelist
17+
from vocoder.hifigan.models import Generator, MultiPeriodDiscriminator, MultiScaleDiscriminator, feature_loss, generator_loss,\
18+
discriminator_loss
19+
from vocoder.hifigan.utils import plot_spectrogram, scan_checkpoint, load_checkpoint, save_checkpoint
20+
21+
torch.backends.cudnn.benchmark = True
22+
23+
24+
def train(rank, a, h):
    """Train the HiFi-GAN vocoder (Generator + MPD/MSD discriminators).

    Designed to be launched once per GPU; rank 0 is the only process that
    prints, checkpoints, writes TensorBoard summaries and runs validation.

    Args:
        rank: process/GPU index (0 for the logging/checkpointing process).
        a: argparse.Namespace from vocoder_train.py; must provide
           ``models_dir``, ``run_id`` and ``syn_dir`` (pathlib.Path-like).
           Fixed training hyper-parameters are attached onto it below.
        h: AttrDict of HiFi-GAN hyper-parameters loaded from the JSON config
           (num_gpus, seed, learning_rate, segment_size, ...).
    """

    # Checkpoints go to <models_dir>/<run_id>_hifigan; hard-coded schedule
    # values below follow the upstream HiFi-GAN training recipe.
    a.checkpoint_path = a.models_dir.joinpath(a.run_id+'_hifigan')
    a.checkpoint_path.mkdir(exist_ok=True)
    a.training_epochs = 3100
    a.stdout_interval = 5
    a.checkpoint_interval = 25000
    a.summary_interval = 5000
    a.validation_interval = 1000
    # Fine-tuning mode: MelDataset loads precomputed synthesizer mels from
    # base_mels_path instead of recomputing them from the audio.
    a.fine_tuning = True

    a.input_wavs_dir = a.syn_dir.joinpath("audio")
    a.input_mels_dir = a.syn_dir.joinpath("mels")

    if h.num_gpus > 1:
        # Multi-GPU: join the process group before building DDP-wrapped models.
        init_process_group(backend=h.dist_config['dist_backend'], init_method=h.dist_config['dist_url'],
                           world_size=h.dist_config['world_size'] * h.num_gpus, rank=rank)

    torch.cuda.manual_seed(h.seed)
    device = torch.device('cuda:{:d}'.format(rank))

    generator = Generator(h).to(device)
    mpd = MultiPeriodDiscriminator().to(device)
    msd = MultiScaleDiscriminator().to(device)

    if rank == 0:
        print(generator)
        os.makedirs(a.checkpoint_path, exist_ok=True)
        print("checkpoints directory : ", a.checkpoint_path)

    # Resume support: look for the latest generator ('g_*') and
    # discriminator/optimizer ('do_*') checkpoints in the run directory.
    # NOTE(review): cp_g/cp_do are only bound inside this isdir() guard;
    # the mkdir above ensures the path exists, so the guard always fires.
    if os.path.isdir(a.checkpoint_path):
        cp_g = scan_checkpoint(a.checkpoint_path, 'g_')
        cp_do = scan_checkpoint(a.checkpoint_path, 'do_')

    steps = 0
    if cp_g is None or cp_do is None:
        # Fresh run: no state to restore.
        state_dict_do = None
        last_epoch = -1
    else:
        # Restore models, step counter and epoch from the checkpoint pair.
        state_dict_g = load_checkpoint(cp_g, device)
        state_dict_do = load_checkpoint(cp_do, device)
        generator.load_state_dict(state_dict_g['generator'])
        mpd.load_state_dict(state_dict_do['mpd'])
        msd.load_state_dict(state_dict_do['msd'])
        steps = state_dict_do['steps'] + 1
        last_epoch = state_dict_do['epoch']

    if h.num_gpus > 1:
        # Wrap AFTER loading state dicts so key names stay un-prefixed.
        generator = DistributedDataParallel(generator, device_ids=[rank]).to(device)
        mpd = DistributedDataParallel(mpd, device_ids=[rank]).to(device)
        msd = DistributedDataParallel(msd, device_ids=[rank]).to(device)

    # One optimizer for the generator, one shared by both discriminators.
    optim_g = torch.optim.AdamW(generator.parameters(), h.learning_rate, betas=[h.adam_b1, h.adam_b2])
    optim_d = torch.optim.AdamW(itertools.chain(msd.parameters(), mpd.parameters()),
                                h.learning_rate, betas=[h.adam_b1, h.adam_b2])

    if state_dict_do is not None:
        optim_g.load_state_dict(state_dict_do['optim_g'])
        optim_d.load_state_dict(state_dict_do['optim_d'])

    # Exponential LR decay, resumed at the checkpointed epoch.
    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=h.lr_decay, last_epoch=last_epoch)
    scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=h.lr_decay, last_epoch=last_epoch)

    training_filelist, validation_filelist = get_dataset_filelist(a)

    # print(training_filelist)
    # exit()

    trainset = MelDataset(training_filelist, h.segment_size, h.n_fft, h.num_mels,
                          h.hop_size, h.win_size, h.sampling_rate, h.fmin, h.fmax, n_cache_reuse=0,
                          shuffle=False if h.num_gpus > 1 else True, fmax_loss=h.fmax_for_loss, device=device,
                          fine_tuning=a.fine_tuning, base_mels_path=a.input_mels_dir)

    # With DDP the sampler shuffles/shards; single-GPU shuffles in the dataset.
    train_sampler = DistributedSampler(trainset) if h.num_gpus > 1 else None

    train_loader = DataLoader(trainset, num_workers=h.num_workers, shuffle=False,
                              sampler=train_sampler,
                              batch_size=h.batch_size,
                              pin_memory=True,
                              drop_last=True)

    if rank == 0:
        # Validation set: no segmenting/shuffling (the two False positional
        # args), batch size 1. Only rank 0 validates and logs.
        validset = MelDataset(validation_filelist, h.segment_size, h.n_fft, h.num_mels,
                              h.hop_size, h.win_size, h.sampling_rate, h.fmin, h.fmax, False, False, n_cache_reuse=0,
                              fmax_loss=h.fmax_for_loss, device=device, fine_tuning=a.fine_tuning,
                              base_mels_path=a.input_mels_dir)
        validation_loader = DataLoader(validset, num_workers=1, shuffle=False,
                                       sampler=None,
                                       batch_size=1,
                                       pin_memory=True,
                                       drop_last=True)

        sw = SummaryWriter(os.path.join(a.checkpoint_path, 'logs'))

    generator.train()
    mpd.train()
    msd.train()
    for epoch in range(max(0, last_epoch), a.training_epochs):
        if rank == 0:
            start = time.time()
            print("Epoch: {}".format(epoch+1))

        if h.num_gpus > 1:
            # Re-seed the distributed sampler so each epoch shuffles differently.
            train_sampler.set_epoch(epoch)

        for i, batch in enumerate(train_loader):
            if rank == 0:
                start_b = time.time()
            # x: input mel, y: target waveform, y_mel: mel of the target
            # (computed with fmax_for_loss) used for the L1 mel loss.
            x, y, _, y_mel = batch
            x = torch.autograd.Variable(x.to(device, non_blocking=True))
            y = torch.autograd.Variable(y.to(device, non_blocking=True))
            y_mel = torch.autograd.Variable(y_mel.to(device, non_blocking=True))
            y = y.unsqueeze(1)

            y_g_hat = generator(x)
            y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size,
                                          h.fmin, h.fmax_for_loss)

            # ---- Discriminator step (generator output detached) ----
            optim_d.zero_grad()

            # MPD
            y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g_hat.detach())
            loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g)

            # MSD
            y_ds_hat_r, y_ds_hat_g, _, _ = msd(y, y_g_hat.detach())
            loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g)

            loss_disc_all = loss_disc_s + loss_disc_f

            loss_disc_all.backward()
            optim_d.step()

            # Generator
            optim_g.zero_grad()

            # L1 Mel-Spectrogram Loss (45 is the weighting from the HiFi-GAN paper)
            loss_mel = F.l1_loss(y_mel, y_g_hat_mel) * 45

            # Adversarial + feature-matching losses through both discriminators.
            y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd(y, y_g_hat)
            y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = msd(y, y_g_hat)
            loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
            loss_fm_s = feature_loss(fmap_s_r, fmap_s_g)
            loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)
            loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)
            loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel

            loss_gen_all.backward()
            optim_g.step()

            if rank == 0:
                # STDOUT logging
                if steps % a.stdout_interval == 0:
                    with torch.no_grad():
                        mel_error = F.l1_loss(y_mel, y_g_hat_mel).item()

                    print('Steps : {:d}, Gen Loss Total : {:4.3f}, Mel-Spec. Error : {:4.3f}, s/b : {:4.3f}'.
                          format(steps, loss_gen_all, mel_error, time.time() - start_b))

                # checkpointing
                if steps % a.checkpoint_interval == 0 and steps != 0:
                    # 'g_*' holds the generator; 'do_*' holds discriminators,
                    # both optimizers and the step/epoch counters for resume.
                    checkpoint_path = "{}/g_{:08d}.pt".format(a.checkpoint_path, steps)
                    save_checkpoint(checkpoint_path,
                                    {'generator': (generator.module if h.num_gpus > 1 else generator).state_dict()})
                    checkpoint_path = "{}/do_{:08d}".format(a.checkpoint_path, steps)
                    save_checkpoint(checkpoint_path,
                                    {'mpd': (mpd.module if h.num_gpus > 1
                                             else mpd).state_dict(),
                                     'msd': (msd.module if h.num_gpus > 1
                                             else msd).state_dict(),
                                     'optim_g': optim_g.state_dict(), 'optim_d': optim_d.state_dict(), 'steps': steps,
                                     'epoch': epoch})

                # Tensorboard summary logging
                # (mel_error is always bound here: summary_interval is a
                # multiple of stdout_interval, so the block above ran too.)
                if steps % a.summary_interval == 0:
                    sw.add_scalar("training/gen_loss_total", loss_gen_all, steps)
                    sw.add_scalar("training/mel_spec_error", mel_error, steps)

                # Validation
                if steps % a.validation_interval == 0:  # and steps != 0:
                    generator.eval()
                    torch.cuda.empty_cache()
                    val_err_tot = 0
                    with torch.no_grad():
                        for j, batch in enumerate(validation_loader):
                            x, y, _, y_mel = batch
                            y_g_hat = generator(x.to(device))
                            y_mel = torch.autograd.Variable(y_mel.to(device, non_blocking=True))
                            y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate,
                                                          h.hop_size, h.win_size,
                                                          h.fmin, h.fmax_for_loss)
                            # val_err_tot += F.l1_loss(y_mel, y_g_hat_mel).item()
                            # NOTE(review): the accumulation above is commented
                            # out, so val_err below is always 0 — the logged
                            # validation/mel_spec_error is meaningless as-is.

                            # Log audio + spectrogram figures for the first 5 clips.
                            if j <= 4:
                                if steps == 0:
                                    sw.add_audio('gt/y_{}'.format(j), y[0], steps, h.sampling_rate)
                                    sw.add_figure('gt/y_spec_{}'.format(j), plot_spectrogram(x[0]), steps)

                                sw.add_audio('generated/y_hat_{}'.format(j), y_g_hat[0], steps, h.sampling_rate)
                                y_hat_spec = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels,
                                                             h.sampling_rate, h.hop_size, h.win_size,
                                                             h.fmin, h.fmax)
                                sw.add_figure('generated/y_hat_spec_{}'.format(j),
                                              plot_spectrogram(y_hat_spec.squeeze(0).cpu().numpy()), steps)

                        val_err = val_err_tot / (j+1)
                        sw.add_scalar("validation/mel_spec_error", val_err, steps)

                    generator.train()

            steps += 1

        # Per-epoch LR decay for both optimizers.
        scheduler_g.step()
        scheduler_d.step()

        if rank == 0:
            print('Time taken for epoch {} is {} sec\n'.format(epoch + 1, int(time.time() - start)))

vocoder/vocoder_dataset.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from torch.utils.data import Dataset
22
from pathlib import Path
3-
from vocoder import audio
4-
import vocoder.hparams as hp
3+
from vocoder.wavernn import audio
4+
import vocoder.wavernn.hparams as hp
55
import numpy as np
66
import torch
77

vocoder_train.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
from utils.argutils import print_args
22
from vocoder.wavernn.train import train
3+
from vocoder.hifigan.train import train as train_hifigan
4+
from vocoder.hifigan.env import AttrDict
35
from pathlib import Path
46
import argparse
7+
import json
58

69

710
if __name__ == "__main__":
@@ -18,6 +21,9 @@
1821
parser.add_argument("datasets_root", type=str, help= \
1922
"Path to the directory containing your SV2TTS directory. Specifying --syn_dir or --voc_dir "
2023
"will take priority over this argument.")
24+
parser.add_argument("vocoder_type", type=str, default="wavernn", help= \
25+
"Choose the vocoder type for train. Defaults to wavernn"
26+
"Now, Support <hifigan> and <wavernn> for choose")
2127
parser.add_argument("--syn_dir", type=str, default=argparse.SUPPRESS, help= \
2228
"Path to the synthesizer directory that contains the ground truth mel spectrograms, "
2329
"the wavs and the embeds. Defaults to <datasets_root>/SV2TTS/synthesizer/.")
@@ -37,9 +43,9 @@
3743
"model.")
3844
parser.add_argument("-f", "--force_restart", action="store_true", help= \
3945
"Do not load any saved model and restart from scratch.")
46+
parser.add_argument("--config", type=str, default="vocoder/hifigan/config_16k_.json")
4047
args = parser.parse_args()
4148

42-
# Process the arguments
4349
if not hasattr(args, "syn_dir"):
4450
args.syn_dir = Path(args.datasets_root, "SV2TTS", "synthesizer")
4551
args.syn_dir = Path(args.syn_dir)
@@ -50,7 +56,16 @@
5056
args.models_dir = Path(args.models_dir)
5157
args.models_dir.mkdir(exist_ok=True)
5258

53-
# Run the training
5459
print_args(args, parser)
55-
train(**vars(args))
56-
60+
61+
# Process the arguments
62+
if args.vocoder_type == "wavernn":
63+
# Run the training wavernn
64+
train(**vars(args))
65+
elif args.vocoder_type == "hifigan":
66+
with open(args.config) as f:
67+
json_config = json.load(f)
68+
h = AttrDict(json_config)
69+
train_hifigan(0, args, h)
70+
71+

0 commit comments

Comments
 (0)