天天看點

wav2vec 2.0 模型架構與代碼淺談

作者:小牆神遊

前言

是由Facebook 人工智能實驗室(Facebook AI Research Paris, FARP)與2020年發表在Advances in Neural Information Processing Systems 期刊的語音領域的論文。該論文提出的無監督語音表示架構:其核心就相當于是一種語音特征的提取器,這種提取器可以提取很多的音頻通用特征,而這些通用特征是可以運用在語音領域中微調後的下遊任務如:語音識别,聲紋識别,多輪對話。。。

模型架構

這裡我以Pytorch中的英文模型和騰訊遊戲知幾AI團隊與西工大ASLP組聯合釋出的為例子進行講解。wav2vec 2.0 模型的整體架構圖如下:

wav2vec 2.0 模型架構與代碼淺談

wav2vec 2.0 模型整體架構圖

其中的seq_num 表示幀長,Conv1d表示一維卷積網絡;Fp32GroupNorm表示實際上是Pytorch中的GroupNorm,不過fairseq裡面的層名字設定是這樣的;PQ就是乘積量化;FC就是全連接配接層;LN就是LayerNorm;Self-Attention表示的是注意力頭,每一層有8個;FFN前饋神經網絡實際上也是一種全連接配接網絡其内置次元是3072;PAD-MASK 就是掩碼層,估計就是0 值掩碼,掩碼政策:選6.5%的時間步長作為開始序号,每個序号後面的十個時間步長被掩碼。

中文&英文預訓練模型差異

雖然整體架構就如上圖所示,但是實際上因為調參和使用神經網絡架構不同的原因會有些許差別。比如:Dropout層使用數量和LayerNorm層使用順序等。舉個特征提取層的例子,以下是中文預訓練時的卷積特征提取層的部分模型:

Wav2Vec2Model(
  (feature_extractor): ConvFeatureExtractionModel(
    (conv_layers): ModuleList(
      (0): Sequential(
        (0): Conv1d(1, 512, kernel_size=(10,), stride=(5,), bias=False)
        (1): Dropout(p=0.0, inplace=False)
        (2): Fp32GroupNorm(512, 512, eps=1e-05, affine=True)
        (3): GELU()
      )
      (1): Sequential(
        (0): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
        (1): Dropout(p=0.0, inplace=False)
        (2): GELU()
      )
           

以下是英文預訓練時的卷積特征提取層的部分模型:

Wav2Vec2Model(
  (feature_extractor): FeatureExtractor(
    (conv_layers): ModuleList(
      (0): ConvLayerBlock(
        (layer_norm): GroupNorm(512, 512, eps=1e-05, affine=True)
        (conv): Conv1d(1, 512, kernel_size=(10,), stride=(5,), bias=False)
      )
      (1): ConvLayerBlock(
        (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
      )
           

很顯然可以看到英文的預訓練的GELU激活函數沒有使用,但是實際上作者在論文中提到過他是用了GELU函數的。我認為這是AI神經網絡的工程師們在實作不同語音預訓練的時候不同的調參政策。

wav2vec 2.0 模型架構與代碼淺談

代碼解析

首先是Pytorch代碼解析,由于暫時還沒做好中文預訓練模型下的自動語音識别(Automatic Speech Recognition,ASR)任務,暫時先談談英文的。

1. pytorch下的ASR任務代碼demo講解

首先是代碼的部分的展示,如下面的代碼塊所示:

import torchaudio
import torch
SPEECH_FILE = 'resources/En-1.wav'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
model = bundle.get_model().to(device)  # 擷取模型
print(model)
labels = bundle.get_labels()  # 擷取字典标簽
waveform, sample_rate = torchaudio.load(SPEECH_FILE)
waveform = waveform.to(device)

with torch.inference_mode():
  features, _ = model.extract_features(waveform)

with torch.inference_mode():
  emission, _ = model(waveform)


class GreedyCTCDecoder(torch.nn.Module):
  def __init__(self, labels, blank=0):
    super().__init__()
    self.labels = labels
    self.blank = blank

  def forward(self, emission: torch.Tensor) -> str:
    indices = torch.argmax(emission, dim=-1)  # [num_seq,]
    indices = torch.unique_consecutive(indices, dim=-1)
    indices = [i for i in indices if i != self.blank]
    return ''.join([self.labels[i] for i in indices])


decoder = GreedyCTCDecoder(labels=bundle.get_labels())
transcript = decoder(emission[0])
print(transcript)
           

其中先導入torch、torchaudio子產品,torchaudio中的pipelines是為了友善執行某一項具體任務而講訓練過程中對音頻的處理以及對模型推理之後的機率分布處理操作或其他後置處理封裝在一起的API。features=[1,seq_num,768]是12層Transfomer-Encode特征提取之後得到的最終結果。GreedyCTCDecoder類是一個簡單的CTC對齊的方法,他是用來對機率最大的标簽進行比對的得到想要的結果。

2. HuggingFace的ASR任務

HuggingFace代碼部分的展示,如下面的代碼塊所示:

import soundfile as sf
import torch
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor


path_audio = 'resources/En-1.wav'
processor = Wav2Vec2Processor.from_pretrained("C:/Users/12046/.cache/huggingface/transformers")
model = Wav2Vec2ForCTC.from_pretrained("C:/Users/12046/.cache/huggingface/transformers")   # 用于ASR等,32維

audio_input, sample_rate = sf.read(path_audio)  # (31129,)
input_values = processor(audio_input, sampling_rate=sample_rate, return_tensors="pt").input_values  # torch.Size([1, 31129])

logits = model(input_values).logits     # torch.Size([1, 97, 32])
predicted_ids = torch.argmax(logits, dim=-1)    # torch.Size([1, 97])

transcription = processor.decode(predicted_ids[0])  # ASR的解碼結果
print(transcription)

           

HuggingFace經常是用的就是transfomers這個庫了,實際上這裡的processor用的是和pytorch代碼中的GreedyCTCDecoder一個思想,不過因為他裡面沒有pytorch中友善的pipelines是以不能直接進行解碼操作,需要一個後置處理的類進行解碼操作。

總結

本篇部落格,對wav2vec2.0 的模型架構和代碼進行了簡單的講解,希望看到這篇部落格并且在進行語音方向學習的人有所收貨lol。

wav2vec 2.0 模型架構與代碼淺談

pytorch代碼運作結果圖

wav2vec 2.0 模型架構與代碼淺談

HuggingFace-代碼運作結果圖

繼續閱讀