|
| 1 | +import warnings |
| 2 | +warnings.simplefilter(action='ignore', category=FutureWarning) |
| 3 | +import itertools |
| 4 | +import os |
| 5 | +import time |
| 6 | +import argparse |
| 7 | +import json |
| 8 | +import torch |
| 9 | +import torch.nn.functional as F |
| 10 | +from torch.utils.tensorboard import SummaryWriter |
| 11 | +from torch.utils.data import DistributedSampler, DataLoader |
| 12 | +import torch.multiprocessing as mp |
| 13 | +from torch.distributed import init_process_group |
| 14 | +from torch.nn.parallel import DistributedDataParallel |
| 15 | +from vocoder.hifigan.env import AttrDict, build_env |
| 16 | +from vocoder.hifigan.meldataset import MelDataset, mel_spectrogram, get_dataset_filelist |
| 17 | +from vocoder.hifigan.models import Generator, MultiPeriodDiscriminator, MultiScaleDiscriminator, feature_loss, generator_loss,\ |
| 18 | + discriminator_loss |
| 19 | +from vocoder.hifigan.utils import plot_spectrogram, scan_checkpoint, load_checkpoint, save_checkpoint |
| 20 | + |
# Let cuDNN auto-tune conv algorithms; fast here because segment sizes are fixed.
torch.backends.cudnn.benchmark = True
| 22 | + |
| 23 | + |
def train(rank: int, a, h) -> None:
    """Run the HiFi-GAN fine-tuning loop on one process / GPU.

    Args:
        rank: GPU/process index; rank 0 also does printing, checkpointing,
            TensorBoard logging, and validation.
        a: argparse-style namespace. Must provide ``models_dir``, ``run_id``
            and ``syn_dir``; this function adds checkpoint paths and the
            fixed training schedule onto it (mutates ``a`` in place).
        h: hyperparameter config (AttrDict loaded from the HiFi-GAN JSON):
            num_gpus, seed, learning_rate, adam betas, lr_decay, STFT/mel
            settings, batch_size, num_workers, dist_config, ...

    Side effects: creates ``<models_dir>/<run_id>_hifigan``, writes ``g_*``
    and ``do_*`` checkpoints there, and TensorBoard logs under ``logs/``.
    """

    # Derive run-local paths and pin the fixed training schedule on the namespace.
    a.checkpoint_path = a.models_dir.joinpath(a.run_id+'_hifigan')
    a.checkpoint_path.mkdir(exist_ok=True)
    a.training_epochs = 3100
    a.stdout_interval = 5
    a.checkpoint_interval = 25000
    a.summary_interval = 5000
    a.validation_interval = 1000
    a.fine_tuning = True  # train on synthesizer-generated mels, not ground-truth mels

    a.input_wavs_dir = a.syn_dir.joinpath("audio")
    a.input_mels_dir = a.syn_dir.joinpath("mels")

    if h.num_gpus > 1:
        # One process per GPU; `rank` doubles as the CUDA device index below.
        init_process_group(backend=h.dist_config['dist_backend'], init_method=h.dist_config['dist_url'],
                           world_size=h.dist_config['world_size'] * h.num_gpus, rank=rank)

    torch.cuda.manual_seed(h.seed)
    device = torch.device('cuda:{:d}'.format(rank))

    generator = Generator(h).to(device)
    mpd = MultiPeriodDiscriminator().to(device)
    msd = MultiScaleDiscriminator().to(device)

    if rank == 0:
        print(generator)
        os.makedirs(a.checkpoint_path, exist_ok=True)
        print("checkpoints directory : ", a.checkpoint_path)

    # NOTE(review): if this isdir() check ever failed, cp_g/cp_do would be
    # unbound below (NameError). The mkdir() above makes it always true here.
    if os.path.isdir(a.checkpoint_path):
        cp_g = scan_checkpoint(a.checkpoint_path, 'g_')
        cp_do = scan_checkpoint(a.checkpoint_path, 'do_')

    steps = 0
    if cp_g is None or cp_do is None:
        # Fresh run: nothing to restore.
        state_dict_do = None
        last_epoch = -1
    else:
        # Resume: restore model weights here; optimizer state is restored
        # further down, after the optimizers are constructed.
        state_dict_g = load_checkpoint(cp_g, device)
        state_dict_do = load_checkpoint(cp_do, device)
        generator.load_state_dict(state_dict_g['generator'])
        mpd.load_state_dict(state_dict_do['mpd'])
        msd.load_state_dict(state_dict_do['msd'])
        steps = state_dict_do['steps'] + 1
        last_epoch = state_dict_do['epoch']

    if h.num_gpus > 1:
        generator = DistributedDataParallel(generator, device_ids=[rank]).to(device)
        mpd = DistributedDataParallel(mpd, device_ids=[rank]).to(device)
        msd = DistributedDataParallel(msd, device_ids=[rank]).to(device)

    # Both discriminators share a single optimizer, as in reference HiFi-GAN.
    optim_g = torch.optim.AdamW(generator.parameters(), h.learning_rate, betas=[h.adam_b1, h.adam_b2])
    optim_d = torch.optim.AdamW(itertools.chain(msd.parameters(), mpd.parameters()),
                                h.learning_rate, betas=[h.adam_b1, h.adam_b2])

    if state_dict_do is not None:
        optim_g.load_state_dict(state_dict_do['optim_g'])
        optim_d.load_state_dict(state_dict_do['optim_d'])

    # last_epoch restores the exponential-decay position when resuming.
    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=h.lr_decay, last_epoch=last_epoch)
    scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=h.lr_decay, last_epoch=last_epoch)

    training_filelist, validation_filelist = get_dataset_filelist(a)

    # shuffle is delegated to DistributedSampler in the multi-GPU case.
    trainset = MelDataset(training_filelist, h.segment_size, h.n_fft, h.num_mels,
                          h.hop_size, h.win_size, h.sampling_rate, h.fmin, h.fmax, n_cache_reuse=0,
                          shuffle=False if h.num_gpus > 1 else True, fmax_loss=h.fmax_for_loss, device=device,
                          fine_tuning=a.fine_tuning, base_mels_path=a.input_mels_dir)

    train_sampler = DistributedSampler(trainset) if h.num_gpus > 1 else None

    train_loader = DataLoader(trainset, num_workers=h.num_workers, shuffle=False,
                              sampler=train_sampler,
                              batch_size=h.batch_size,
                              pin_memory=True,
                              drop_last=True)

    if rank == 0:
        # Validation set: full (unsegmented, unshuffled) utterances, batch of 1.
        validset = MelDataset(validation_filelist, h.segment_size, h.n_fft, h.num_mels,
                              h.hop_size, h.win_size, h.sampling_rate, h.fmin, h.fmax, False, False, n_cache_reuse=0,
                              fmax_loss=h.fmax_for_loss, device=device, fine_tuning=a.fine_tuning,
                              base_mels_path=a.input_mels_dir)
        validation_loader = DataLoader(validset, num_workers=1, shuffle=False,
                                       sampler=None,
                                       batch_size=1,
                                       pin_memory=True,
                                       drop_last=True)

        sw = SummaryWriter(os.path.join(a.checkpoint_path, 'logs'))

    generator.train()
    mpd.train()
    msd.train()
    for epoch in range(max(0, last_epoch), a.training_epochs):
        if rank == 0:
            start = time.time()
            print("Epoch: {}".format(epoch+1))

        if h.num_gpus > 1:
            # Re-seed the sampler so each epoch sees a different shard order.
            train_sampler.set_epoch(epoch)

        for i, batch in enumerate(train_loader):
            if rank == 0:
                start_b = time.time()
            # x: input mel, y: target waveform, y_mel: mel of the target
            # (computed with fmax_for_loss) used for the L1 mel loss.
            x, y, _, y_mel = batch
            # NOTE(review): torch.autograd.Variable is a legacy no-op wrapper
            # in modern PyTorch; plain .to(device) would be equivalent.
            x = torch.autograd.Variable(x.to(device, non_blocking=True))
            y = torch.autograd.Variable(y.to(device, non_blocking=True))
            y_mel = torch.autograd.Variable(y_mel.to(device, non_blocking=True))
            y = y.unsqueeze(1)  # (B, T) -> (B, 1, T) channel dim for the discriminators

            y_g_hat = generator(x)
            y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size,
                                          h.fmin, h.fmax_for_loss)

            # ---- Discriminator step (generator output detached) ----
            optim_d.zero_grad()

            # MPD
            y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g_hat.detach())
            loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g)

            # MSD
            y_ds_hat_r, y_ds_hat_g, _, _ = msd(y, y_g_hat.detach())
            loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g)

            loss_disc_all = loss_disc_s + loss_disc_f

            loss_disc_all.backward()
            optim_d.step()

            # ---- Generator step ----
            optim_g.zero_grad()

            # L1 Mel-Spectrogram Loss (weight 45, per the HiFi-GAN paper)
            loss_mel = F.l1_loss(y_mel, y_g_hat_mel) * 45

            y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd(y, y_g_hat)
            y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = msd(y, y_g_hat)
            loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
            loss_fm_s = feature_loss(fmap_s_r, fmap_s_g)
            loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)
            loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)
            loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel

            loss_gen_all.backward()
            optim_g.step()

            if rank == 0:
                # STDOUT logging
                if steps % a.stdout_interval == 0:
                    with torch.no_grad():
                        mel_error = F.l1_loss(y_mel, y_g_hat_mel).item()

                    print('Steps : {:d}, Gen Loss Total : {:4.3f}, Mel-Spec. Error : {:4.3f}, s/b : {:4.3f}'.
                          format(steps, loss_gen_all, mel_error, time.time() - start_b))

                # checkpointing
                # NOTE(review): the do_ checkpoint is saved WITHOUT the .pt
                # suffix the g_ checkpoint gets — confirm scan_checkpoint's
                # glob matches both forms when resuming.
                if steps % a.checkpoint_interval == 0 and steps != 0:
                    checkpoint_path = "{}/g_{:08d}.pt".format(a.checkpoint_path, steps)
                    save_checkpoint(checkpoint_path,
                                    {'generator': (generator.module if h.num_gpus > 1 else generator).state_dict()})
                    checkpoint_path = "{}/do_{:08d}".format(a.checkpoint_path, steps)
                    save_checkpoint(checkpoint_path,
                                    {'mpd': (mpd.module if h.num_gpus > 1
                                             else mpd).state_dict(),
                                     'msd': (msd.module if h.num_gpus > 1
                                             else msd).state_dict(),
                                     'optim_g': optim_g.state_dict(), 'optim_d': optim_d.state_dict(), 'steps': steps,
                                     'epoch': epoch})

                # Tensorboard summary logging
                # NOTE(review): relies on stdout_interval (5) dividing
                # summary_interval (5000) so mel_error is always bound/fresh here.
                if steps % a.summary_interval == 0:
                    sw.add_scalar("training/gen_loss_total", loss_gen_all, steps)
                    sw.add_scalar("training/mel_spec_error", mel_error, steps)

                # Validation
                if steps % a.validation_interval == 0:  # and steps != 0:
                    generator.eval()
                    torch.cuda.empty_cache()
                    val_err_tot = 0
                    with torch.no_grad():
                        for j, batch in enumerate(validation_loader):
                            x, y, _, y_mel = batch
                            y_g_hat = generator(x.to(device))
                            y_mel = torch.autograd.Variable(y_mel.to(device, non_blocking=True))
                            y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate,
                                                          h.hop_size, h.win_size,
                                                          h.fmin, h.fmax_for_loss)
                            # NOTE(review): accumulation is commented out, so
                            # validation/mel_spec_error below always logs 0 —
                            # presumably disabled due to a length mismatch
                            # between y_mel and y_g_hat_mel; confirm intent.
#                           val_err_tot += F.l1_loss(y_mel, y_g_hat_mel).item()

                            # Log audio/spectrograms for the first 5 utterances;
                            # ground truth only once, at step 0.
                            if j <= 4:
                                if steps == 0:
                                    sw.add_audio('gt/y_{}'.format(j), y[0], steps, h.sampling_rate)
                                    sw.add_figure('gt/y_spec_{}'.format(j), plot_spectrogram(x[0]), steps)

                                sw.add_audio('generated/y_hat_{}'.format(j), y_g_hat[0], steps, h.sampling_rate)
                                # Display spectrogram uses h.fmax (not fmax_for_loss).
                                y_hat_spec = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels,
                                                             h.sampling_rate, h.hop_size, h.win_size,
                                                             h.fmin, h.fmax)
                                sw.add_figure('generated/y_hat_spec_{}'.format(j),
                                              plot_spectrogram(y_hat_spec.squeeze(0).cpu().numpy()), steps)

                        val_err = val_err_tot / (j+1)
                        sw.add_scalar("validation/mel_spec_error", val_err, steps)

                    generator.train()

            steps += 1

        # Decay learning rates once per epoch, not per step.
        scheduler_g.step()
        scheduler_d.step()

        if rank == 0:
            print('Time taken for epoch {} is {} sec\n'.format(epoch + 1, int(time.time() - start)))