FastSpeech2 Code Reading Notes: Model Construction

Author: 語音刺客

Building the FastSpeech2 model mainly involves two files: fastspeech2.py and modules.py, both under the model directory.

1. model/modules.py

This file defines the Variance Adaptor, which is composed of the Duration Predictor, the Length Regulator, the Pitch Predictor, and the Energy Predictor. The detailed, annotated code is shown below.
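
At a high level, the forward pass of the Variance Adaptor follows the flow sketched below (a simplified summary of the code that follows, assuming the default phoneme-level pitch and energy configuration):

 # Variance Adaptor data flow (summary sketch, not runnable training code)
 # x: encoder output, [batch, src_len, encoder_hidden]
 #
 # log_d = duration_predictor(x)                         # per-phoneme log duration
 # x = x + pitch_embedding(bucketize(pitch, bins))       # phoneme-level pitch
 # x = x + energy_embedding(bucketize(energy, bins))     # phoneme-level energy
 # x, mel_len = length_regulator(x, durations, max_len)  # expand to frame level
 # -> x: [batch, mel_len, encoder_hidden], handed on to the decoder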

import os
 import json
 import copy
 import math
 from collections import OrderedDict
 
 import torch
 import torch.nn as nn
 import numpy as np
 import torch.nn.functional as F
 
 from utils.tools import get_mask_from_lengths, pad
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 # The complete Variance Adaptor
 class VarianceAdaptor(nn.Module):
     """Variance Adaptor"""
 
     def __init__(self, preprocess_config, model_config):
         super(VarianceAdaptor, self).__init__()
         self.duration_predictor = VariancePredictor(model_config)
         self.length_regulator = LengthRegulator()
         self.pitch_predictor = VariancePredictor(model_config)
         self.energy_predictor = VariancePredictor(model_config)
 
         # Feature level for pitch and energy: phoneme-level or frame-level
         self.pitch_feature_level = preprocess_config["preprocessing"]["pitch"]["feature"]
         self.energy_feature_level = preprocess_config["preprocessing"]["energy"]["feature"]
         assert self.pitch_feature_level in ["phoneme_level", "frame_level"]
         assert self.energy_feature_level in ["phoneme_level", "frame_level"]
 
         # Quantization scheme for pitch and energy
         pitch_quantization = model_config["variance_embedding"]["pitch_quantization"]
         energy_quantization = model_config["variance_embedding"]["energy_quantization"]
         n_bins = model_config["variance_embedding"]["n_bins"]
         assert pitch_quantization in ["linear", "log"]
         assert energy_quantization in ["linear", "log"]
 
         # Load the pitch and energy statistics (min/max) saved during preprocessing
         with open(
             os.path.join(preprocess_config["path"]["preprocessed_path"], "stats.json")
         ) as f:
             stats = json.load(f)
             pitch_min, pitch_max = stats["pitch"][:2]
             energy_min, energy_max = stats["energy"][:2]
 
         # With "log" quantization the bin edges are spaced logarithmically,
         # which suits raw (un-normalized) values; normalized features
         # normally use "linear" spacing.
         if pitch_quantization == "log":
             # torch.exp() computes e^{input}
             # torch.linspace(x, y, num) returns num evenly spaced values between x and y
             self.pitch_bins = nn.Parameter(
                 torch.exp(
                     torch.linspace(np.log(pitch_min), np.log(pitch_max), n_bins - 1)  # n_bins - 1 = 255 boundaries by default
                 ),
                 requires_grad=False,
             )
         else:
             self.pitch_bins = nn.Parameter(
                 torch.linspace(pitch_min, pitch_max, n_bins - 1),
                 requires_grad=False,
             )
         if energy_quantization == "log":
             self.energy_bins = nn.Parameter(
                 torch.exp(
                     torch.linspace(np.log(energy_min), np.log(energy_max), n_bins - 1)
                 ),
                 requires_grad=False,
             )
         else:
             self.energy_bins = nn.Parameter(
                 torch.linspace(energy_min, energy_max, n_bins - 1),
                 requires_grad=False,
             )
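         # Note: the n_bins - 1 boundaries above define n_bins buckets for
         # torch.bucketize, matching the embedding tables of size n_bins below.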
         # Embedding tables for quantized pitch and energy
         self.pitch_embedding = nn.Embedding(
             n_bins, model_config["transformer"]["encoder_hidden"]
         )
         self.energy_embedding = nn.Embedding(
             n_bins, model_config["transformer"]["encoder_hidden"]
         )
 
     # Compute the pitch embedding
     def get_pitch_embedding(self, x, target, mask, control):
         prediction = self.pitch_predictor(x, mask)  # value predicted by the pitch predictor
         if target is not None:  # target available: training, compute the embedding from the target
             embedding = self.pitch_embedding(torch.bucketize(target, self.pitch_bins))
         else:  # no target: inference, compute the embedding from the prediction
             prediction = prediction * control   # control scales the predicted pitch
             embedding = self.pitch_embedding(torch.bucketize(prediction, self.pitch_bins))
         return prediction, embedding  # prediction feeds the training loss; embedding is added to x downstream
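     # Illustration of the quantization trick (hypothetical values):
     #   bins = torch.tensor([100., 200., 300.])                 # n_bins - 1 boundaries
     #   torch.bucketize(torch.tensor([90., 150., 310.]), bins)  # -> tensor([0, 1, 3])
     # Each continuous value becomes a bucket index in [0, n_bins - 1], which
     # then indexes the nn.Embedding table.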
 
     # Compute the energy embedding
     def get_energy_embedding(self, x, target, mask, control):
         prediction = self.energy_predictor(x, mask)
         if target is not None:
             embedding = self.energy_embedding(torch.bucketize(target, self.energy_bins))
         else:
             prediction = prediction * control
             embedding = self.energy_embedding(torch.bucketize(prediction, self.energy_bins))
         return prediction, embedding
 
     def forward(
         self,
         x,
         src_mask,
         mel_mask=None,
         max_len=None,
         pitch_target=None,
         energy_target=None,
         duration_target=None,
         p_control=1.0,
         e_control=1.0,
         d_control=1.0,
     ):
 
         log_duration_prediction = self.duration_predictor(x, src_mask)  # predicted log duration for each phoneme
         if self.pitch_feature_level == "phoneme_level":
             pitch_prediction, pitch_embedding = self.get_pitch_embedding(
                 x, pitch_target, src_mask, p_control
             )
             x = x + pitch_embedding  # add the pitch embedding
         if self.energy_feature_level == "phoneme_level":
             energy_prediction, energy_embedding = self.get_energy_embedding(
                 x, energy_target, src_mask, e_control
             )
             x = x + energy_embedding  # add the energy embedding
 
         if duration_target is not None:  # duration_target available: training
             x, mel_len = self.length_regulator(x, duration_target, max_len)  # expand x with the target durations
             duration_rounded = duration_target
         else:  # inference
             # Build duration_rounded from log_duration_prediction and use it to expand x.
             # torch.clamp() clamps every element of the input into [min, max] and returns a new tensor
             # torch.round() rounds to the nearest integer
             duration_rounded = torch.clamp(
                 (torch.round(torch.exp(log_duration_prediction) - 1) * d_control),
                 min=0,
             )
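             # Numeric example: the duration predictor is trained against
             # log(d + 1) in this repo, so a prediction of log(5) ≈ 1.609 maps
             # back to exp(1.609) - 1 ≈ 4 frames; with d_control = 1.5 this
             # becomes 4 * 1.5 = 6 frames (slower speech; d_control < 1 speeds it up).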
             x, mel_len = self.length_regulator(x, duration_rounded, max_len)
             mel_mask = get_mask_from_lengths(mel_len)
 
         if self.pitch_feature_level == "frame_level":
             pitch_prediction, pitch_embedding = self.get_pitch_embedding(
                 x, pitch_target, mel_mask, p_control
             )
             x = x + pitch_embedding
         if self.energy_feature_level == "frame_level":
             energy_prediction, energy_embedding = self.get_energy_embedding(
                 x, energy_target, mel_mask, e_control
             )
             x = x + energy_embedding
 
         return (  # the three predictions below are used later to compute losses
             x,
             pitch_prediction,
             energy_prediction,
             log_duration_prediction,
             duration_rounded,
             mel_len,
             mel_mask,
         )
 
 # Length Regulator
 class LengthRegulator(nn.Module):
     """Length Regulator"""
 
     def __init__(self):
         super(LengthRegulator, self).__init__()
 
     # Adjust the length of the input phoneme sequence x
     def LR(self, x, duration, max_len):
         """
         Align the phoneme sequence with the mel-spectrogram length using the phoneme durations.
         @param x: phoneme sequence after the FFT blocks, [batch_size, max_sequence_len, encoder_dim]
         @param duration: phoneme duration matrix, [batch_size, max_sequence_len]
         @param max_len: maximum mel-spectrogram length in the batch
         @return: length-adjusted phoneme sequence, [batch_size, max_len, encoder_dim]
         """
         output = list()
         mel_len = list()
         for batch, expand_target in zip(x, duration):
             expanded = self.expand(batch, expand_target)  # fully length-expanded phoneme sequence for one utterance
             output.append(expanded)
             mel_len.append(expanded.shape[0])  # record the mel length, used later to build masks

         # Pad to max_len if it is given, otherwise to the longest sequence in output
         if max_len is not None:
             output = pad(output, max_len)
         else:
             output = pad(output)
 
         return output, torch.LongTensor(mel_len).to(device)
 
     def expand(self, batch, predicted):
         """
         Expand one phoneme sequence according to its per-phoneme durations.
         @param batch: phoneme sequence of one utterance's text, [max_sequence_len, encoder_dim]
         @param predicted: duration of each phoneme in the sequence, length max_sequence_len
         @return: expanded phoneme sequence whose length matches the mel spectrogram
         """
 
         out = list()
 
         for i, vec in enumerate(batch):
             expand_size = predicted[i].item()  # duration of phoneme i, i.e. how many times to repeat it
             out.append(vec.expand(max(int(expand_size), 0), -1))  # repeat the representation vec of phoneme i expand_size times
         out = torch.cat(out, 0)  # concatenate the whole phoneme sequence
 
         return out
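     # Example: with durations [2, 1, 3] for 3 phonemes, expand() repeats the
     # corresponding rows 2, 1 and 3 times, yielding 2 + 1 + 3 = 6 frame-level
     # vectors that line up with the mel spectrogram.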
 
     def forward(self, x, duration, max_len):
         output, mel_len = self.LR(x, duration, max_len)
         return output, mel_len
 
 
 class VariancePredictor(nn.Module):
     """Duration, Pitch and Energy Predictor"""
 
     def __init__(self, model_config):
         super(VariancePredictor, self).__init__()
 
         self.input_size = model_config["transformer"]["encoder_hidden"]  # input size, 256
         self.filter_size = model_config["variance_predictor"]["filter_size"]  # filter (output) size, 256
         self.kernel = model_config["variance_predictor"]["kernel_size"]  # kernel size, 3
         self.conv_output_size = model_config["variance_predictor"]["filter_size"]
         self.dropout = model_config["variance_predictor"]["dropout"]
 
         # A convolution stack with activations and regularization: [Conv1D + ReLU + LayerNorm + Dropout] x 2
         self.conv_layer = nn.Sequential(
             OrderedDict(
                 [
                     (
                         "conv1d_1",
                         Conv(
                             self.input_size,
                             self.filter_size,
                             kernel_size=self.kernel,
                             padding=(self.kernel - 1) // 2,
                         ),
                     ),
                     ("relu_1", nn.ReLU()),
                     ("layer_norm_1", nn.LayerNorm(self.filter_size)),
                     ("dropout_1", nn.Dropout(self.dropout)),
                     (
                         "conv1d_2",
                         Conv(
                             self.filter_size,
                             self.filter_size,
                             kernel_size=self.kernel,
                             padding=1,
                         ),
                     ),
                     ("relu_2", nn.ReLU()),
                     ("layer_norm_2", nn.LayerNorm(self.filter_size)),
                     ("dropout_2", nn.Dropout(self.dropout)),
                 ]
             )
         )
 
         self.linear_layer = nn.Linear(self.conv_output_size, 1)
 
     def forward(self, encoder_output, mask):
         out = self.conv_layer(encoder_output)  # [Conv1D + ReLU + LayerNorm + Dropout] x 2
         out = self.linear_layer(out)  # final linear projection
         out = out.squeeze(-1)  # the linear layer outputs size 1 in the last dimension; squeeze it away

         if mask is not None:  # zero out the masked (padding) positions
             out = out.masked_fill(mask, 0.0)
 
         return out
 
 # Custom 1-D convolution module
 class Conv(nn.Module):
     """
     Convolution Module
     """
 
     def __init__(
         self,
         in_channels,
         out_channels,
         kernel_size=1,
         stride=1,
         padding=0,
         dilation=1,
         bias=True,
         w_init="linear",
     ):
         """
         :param in_channels: dimension of input
         :param out_channels: dimension of output
         :param kernel_size: size of kernel
         :param stride: size of stride
         :param padding: size of padding
         :param dilation: dilation rate
         :param bias: boolean. if True, bias is included.
         :param w_init: str. intended weight-initialization scheme (not used in this implementation).
         """
         super(Conv, self).__init__()
 
         self.conv = nn.Conv1d(
             in_channels,
             out_channels,
             kernel_size=kernel_size,
             stride=stride,
             padding=padding,
             dilation=dilation,
             bias=bias,
         )
 
     def forward(self, x):
         x = x.contiguous().transpose(1, 2)  # [batch, len, channels] -> [batch, channels, len] for nn.Conv1d
         x = self.conv(x)
         x = x.contiguous().transpose(1, 2)  # back to [batch, len, channels]

         return x
            
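The Length Regulator boils down to row repetition followed by concatenation. Here is a minimal, self-contained sketch of the expand step (toy shapes and hypothetical data, separate from the module above):

 import torch

 hidden = 4
 phonemes = torch.randn(3, hidden)    # 3 phoneme vectors, [3, 4]
 durations = torch.tensor([2, 1, 3])  # frames per phoneme

 frames = torch.cat(
     [vec.expand(int(d), -1) for vec, d in zip(phonemes, durations)], dim=0
 )
 print(frames.shape)  # torch.Size([6, 4]): 2 + 1 + 3 = 6 frame-level vectors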

2. model/fastspeech2.py

This file assembles the Encoder, Decoder, PostNet, and Variance Adaptor modules into the complete FastSpeech2 model.

import os
 import json
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
 from transformer import Encoder, Decoder, PostNet
 from .modules import VarianceAdaptor
 from utils.tools import get_mask_from_lengths
 
 
 class FastSpeech2(nn.Module):
     """ FastSpeech2 """
 
     def __init__(self, preprocess_config, model_config):
         super(FastSpeech2, self).__init__()
         self.model_config = model_config
 
         self.encoder = Encoder(model_config)  # encoder, before the variance adaptor
         self.variance_adaptor = VarianceAdaptor(preprocess_config, model_config)
         self.decoder = Decoder(model_config)  # decoder, after the variance adaptor
         self.mel_linear = nn.Linear(
             model_config["transformer"]["decoder_hidden"],  # 256
             preprocess_config["preprocessing"]["mel"]["n_mel_channels"],  # 80
         )
         self.postnet = PostNet()
 
         self.speaker_emb = None
         # Multi-speaker case
         if model_config["multi_speaker"]:  # True
             # load the speaker table
             with open(
                 os.path.join(preprocess_config["path"]["preprocessed_path"], "speakers.json"), "r"
             ) as f:
                 n_speaker = len(json.load(f))
             # build the speaker embedding table
             self.speaker_emb = nn.Embedding(
                 n_speaker,
                 model_config["transformer"]["encoder_hidden"],  # 256
             )
 
     def forward(
         self,
         speakers,
         texts,
         src_lens,
         max_src_len,
         mels=None,
         mel_lens=None,
         max_mel_len=None,
         p_targets=None,
         e_targets=None,
         d_targets=None,
         p_control=1.0,  # control coefficients for pitch, energy, and duration
         e_control=1.0,
         d_control=1.0,
     ):
         src_masks = get_mask_from_lengths(src_lens, max_src_len)  # mask for the source text sequence
         mel_masks = (
             get_mask_from_lengths(mel_lens, max_mel_len)
             if mel_lens is not None
             else None
         )  # mask for the mel-spectrogram sequence
 
         output = self.encoder(texts, src_masks)  # encode

         if self.speaker_emb is not None:  # if a speaker embedding exists, add it to the encoder output
             output = output + self.speaker_emb(speakers).unsqueeze(1).expand(
                 -1, max_src_len, -1
             )
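             # Shape note: speaker_emb(speakers) is [batch, hidden]; unsqueeze(1)
             # gives [batch, 1, hidden], and expand broadcasts it to
             # [batch, max_src_len, hidden], so the same speaker vector is added
             # at every phoneme position.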
 
         # Run the Variance Adaptor
         (
             output,
             p_predictions,
             e_predictions,
             log_d_predictions,
             d_rounded,
             mel_lens,
             mel_masks,
         ) = self.variance_adaptor(
             output,
             src_masks,
             mel_masks,
             max_mel_len,
             p_targets,
             e_targets,
             d_targets,
             p_control,
             e_control,
             d_control,
         )
 
         output, mel_masks = self.decoder(output, mel_masks)  # decode
         output = self.mel_linear(output)  # linear projection to the 80 mel channels

         postnet_output = self.postnet(output) + output  # residual refinement with the PostNet
 
         return (
             output,
             postnet_output,
             p_predictions,
             e_predictions,
             log_d_predictions,
             d_rounded,
             src_masks,
             mel_masks,
             src_lens,
             mel_lens,
         )
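
Both files rely on get_mask_from_lengths from utils.tools; its convention (True at padded positions) is what lets masked_fill(mask, 0.0) in VariancePredictor zero out padding. A minimal sketch of the expected behavior (an illustrative re-implementation, not necessarily the repo's exact code):

 import torch

 def get_mask_from_lengths_sketch(lengths, max_len=None):
     # lengths: [batch]; returns a bool mask that is True at padding positions
     if max_len is None:
         max_len = int(lengths.max().item())
     ids = torch.arange(max_len).unsqueeze(0)  # [1, max_len]
     return ids >= lengths.unsqueeze(1)        # [batch, max_len]

 print(get_mask_from_lengths_sketch(torch.tensor([2, 4])))
 # tensor([[False, False,  True,  True],
 #         [False, False, False, False]])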