
BartModel Source Code Analysis

1. The GenerationMixin Class

The source code of this class links to this blog post: https://huggingface.co/blog/how-to-generate . It is very helpful for understanding generation. I summarize the content of that post as follows:

  • The autoregressive assumption: the probability of a whole sentence is simply the product of conditional probabilities.
  • The length of the generated sentence is determined dynamically.
  • The post covers several decoding strategies:

    Greedy Search

    Beam Search

    Top-K sampling

    Top-p sampling

    With greedy search and beam search, a repeated-generation problem can show up, like this:

Output:

I enjoy walking with my cute dog, but I’m not sure if I’ll ever be able to walk with my dog. I’m not sure if I’ll ever be able to walk with my dog.

I’m not sure if I’ll…

One way to fix this repetition problem is the n-gram strategy: constrain how many times the same n-gram is allowed to appear in the generated text.

Beam search is not well suited to open-ended generation. The post stays fairly basic, but it is easy to follow.
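The strategies above map directly onto standard arguments of model.generate(). A minimal sketch (facebook/bart-base is just an example checkpoint; any seq2seq model with a language-model head works the same way):

from transformers import BartForConditionalGeneration, BartTokenizer

# Illustrative only: these are standard generate() arguments that correspond to
# the strategies above; facebook/bart-base is only an example checkpoint.
tok = BartTokenizer.from_pretrained("facebook/bart-base")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

inputs = tok("I enjoy walking with my cute dog", return_tensors="pt")

# beam search, plus the n-gram trick against repetition
beam_out = model.generate(
    **inputs, num_beams=5, no_repeat_ngram_size=2, max_length=50, early_stopping=True
)

# top-k / top-p sampling
sample_out = model.generate(
    **inputs, do_sample=True, top_k=50, top_p=0.95, max_length=50
)

print(tok.decode(beam_out[0], skip_special_tokens=True))
print(tok.decode(sample_out[0], skip_special_tokens=True))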


Today, while using Bart for generation, I looked into the model.generate() method and found that the base class PreTrainedModel inherits from GenerationMixin, and that class is the base class for the generation methods. Let's look at the source first. I have to say, it is really long... but the core of it is the while loop below:

[screenshot: the main generation loop inside GenerationMixin]
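Since the screenshot is gone, here is a hand-rolled greedy loop with the same overall shape: run the encoder, then repeatedly run the decoder, take the argmax, and append the new token. This is a simplified sketch, not the literal GenerationMixin source; facebook/bart-base is only an example checkpoint, and unlike generate() it re-runs the encoder every step instead of caching encoder_outputs.

import torch
from transformers import BartForConditionalGeneration, BartTokenizer

tok = BartTokenizer.from_pretrained("facebook/bart-base")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-base").eval()

enc = tok("UN Chief says there is no military solution in Syria", return_tensors="pt")
decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]])
past = None

with torch.no_grad():
    while decoder_input_ids.shape[1] < 20:
        out = model(
            input_ids=enc.input_ids,
            attention_mask=enc.attention_mask,
            # once a cache exists, only the newest decoder token has to be fed in
            decoder_input_ids=decoder_input_ids if past is None else decoder_input_ids[:, -1:],
            past_key_values=past,
            use_cache=True,
        )
        past = out.past_key_values
        next_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        decoder_input_ids = torch.cat([decoder_input_ids, next_token], dim=-1)
        if next_token.item() == model.config.eos_token_id:
            break

print(tok.decode(decoder_input_ids[0], skip_special_tokens=True))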

Next, let's look at prepare_inputs_for_generation.
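Roughly speaking, BART's prepare_inputs_for_generation does something like the following paraphrased sketch (not the literal source; the _sketch suffix marks the function as mine): once a cache exists, only the most recent decoder token is fed in, and encoder_outputs and the cache are threaded through unchanged.

def prepare_inputs_for_generation_sketch(
    decoder_input_ids,
    past_key_values=None,
    attention_mask=None,
    encoder_outputs=None,
    use_cache=None,
    **kwargs,
):
    # Once a cache exists, only the most recent decoder token needs to be re-embedded;
    # everything before it already lives in past_key_values.
    if past_key_values is not None:
        decoder_input_ids = decoder_input_ids[:, -1:]

    return {
        "input_ids": None,                  # the encoder has already run; encoder_outputs is reused
        "encoder_outputs": encoder_outputs,
        "past_key_values": past_key_values,
        "decoder_input_ids": decoder_input_ids,
        "attention_mask": attention_mask,
        "use_cache": use_cache,
    }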

2. Parameter (Vocabulary) Tying

While training a Bart-based model, I found that the training loss went down nicely, but at generate() time the output consisted entirely of the same token, which was very strange. The loss went down like this:

[screenshot: training loss curve]

But the generated pred looked like this:

[screenshot: generated predictions, all the same token]

The model I defined is named MybartModel, and its parameters are loaded from the pretrained checkpoint. The code is as follows:

[screenshot: MybartModel definition loading pretrained parameters]

But I was very confused by the repeated-token problem above, because I had no idea what was going on, until a senior labmate pointed out that I had not constrained the vocabulary: simply loading the parameters only guarantees they agree at initialization, not that they stay consistent during training. In other words, the following two parameters must be kept consistent:

[screenshot: the two weight matrices that must stay consistent]

而這個保持一緻的實作是在 from_pretrained() 中完成的:

[screenshot: the weight-tying code reached from from_pretrained()]
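The tying itself is simple in concept: in Bart's case the vocabulary projection (lm_head) shares its weight tensor with the shared token embedding, so the two cannot drift apart during training. A small check of this on the stock model (facebook/bart-base assumed as the checkpoint):

from transformers import BartForConditionalGeneration

# The LM head and the shared token embedding point at the same storage,
# so updating one updates the other.
model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

print(model.lm_head.weight.data_ptr() == model.model.shared.weight.data_ptr())  # True

# If you wrap BartModel yourself and add your own nn.Linear as the output layer,
# you have to tie it yourself, e.g. (my_lm_head / bart are hypothetical names):
#     my_lm_head.weight = bart.shared.weight
# otherwise the two matrices drift apart during training even if they start out equal.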

I will analyze the concrete details later. Two days afterwards I finally solved the problem. The root cause of this bug: I did not understand the difference between BartForConditionalGeneration and BartModel, so I copied BartModel directly and thereby dropped part of the original model, which is why I could not get correct generation results.
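For intuition, here is a rough paraphrase of what BartForConditionalGeneration adds on top of BartModel (simplified: the real class also keeps a final_logits_bias buffer and all the generation plumbing; the class name below is my own). Copying only BartModel drops exactly this vocabulary projection.

from torch import nn
from transformers import BartConfig, BartModel

class MyBartForGenerationSketch(nn.Module):
    # A paraphrased sketch, not the HF implementation.
    def __init__(self, config: BartConfig):
        super().__init__()
        self.model = BartModel(config)                  # encoder + decoder, returns hidden states only
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        self.lm_head.weight = self.model.shared.weight  # tie to the shared token embedding

    def forward(self, input_ids, attention_mask=None, decoder_input_ids=None, **kwargs):
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            **kwargs,
        )
        # project the decoder hidden states back onto the vocabulary
        return self.lm_head(outputs.last_hidden_state)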

The way I discovered the problem was:

@staticmethod
    def _reorder_cache(past, beam_idx):
        reordered_past = ()
        for layer_past in past:
            # cached cross_attention states don't have to be reordered -> they are always the same
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
            )
        return reordered_past      

What is the purpose and logic of this code?

During beam search, the cached self-attention key/value states (layer_past[:2]) have to be rearranged with index_select so that each surviving beam carries the history of the beam it was branched from. The cached cross-attention states (layer_past[2:]) come from the encoder and are the same for every beam, so they do not need to be reordered.
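A tiny, self-contained illustration of what that index_select does (dummy shapes, not tied to BART):

import torch

# Dummy shapes: a cached key of shape (num_beams, num_heads, seq_len, head_dim).
cached_key = torch.arange(2 * 1 * 3 * 2, dtype=torch.float).view(2, 1, 3, 2)

# Suppose beam search decides that both surviving beams continue from old beam 1.
beam_idx = torch.tensor([1, 1])

# index_select(0, beam_idx) copies the cache rows so that every new beam carries
# the history of the beam it branched from.
reordered = cached_key.index_select(0, beam_idx)
print(torch.equal(reordered[0], cached_key[1]))  # True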

BartDecoderLayer

Now let's talk about BartDecoderLayer, the basic building block of the decoder, and see how it works:

class BartDecoderLayer(nn.Module):
    def __init__(self, config: BartConfig):
        super().__init__()
        self.embed_dim = config.d_model

        self.self_attn = BartAttention(
            embed_dim=self.embed_dim,
            num_heads=config.decoder_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
        )
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout

        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.encoder_attn = BartAttention(
            self.embed_dim,
            config.decoder_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
        )
        self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
        self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = True,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            encoder_hidden_states (`torch.FloatTensor`):
                cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
            encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
                (Question: how does this encoder_attention_mask differ from the attention_mask above?)
            layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
                `(encoder_attention_heads,)`.
            cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
                size `(decoder_attention_heads,)`.
            past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        residual = hidden_states

        # Self Attention
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        # add present self-attn cache to positions 1,2 of present_key_value tuple
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            past_key_value=self_attn_past_key_value,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
        )
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)

        # Cross-Attention Block
        cross_attn_present_key_value = None
        cross_attn_weights = None
        if encoder_hidden_states is not None:
            residual = hidden_states

            # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
            hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
                hidden_states=hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                layer_head_mask=cross_attn_layer_head_mask,
                past_key_value=cross_attn_past_key_value,
                output_attentions=output_attentions,
            )
            hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
            hidden_states = residual + hidden_states
            hidden_states = self.encoder_attn_layer_norm(hidden_states)

            # add cross-attn to positions 3,4 of present_key_value tuple
            present_key_value = present_key_value + cross_attn_present_key_value

        # Fully Connected
        residual = hidden_states
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
        hidden_states = self.fc2(hidden_states)
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states
        hidden_states = self.final_layer_norm(hidden_states)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights, cross_attn_weights)

        if use_cache:
            outputs += (present_key_value,)

        return outputs      

The main questions are:

  • What is this code supposed to implement?
  • What is the relationship between hidden_states and encoder_hidden_states?

    hidden_states is the embedding obtained from the input_ids fed into the decoder, while encoder_hidden_states is the encoder's output, which the cross-attention block consumes as its key/value states (see the shape sketch after this list).

  • What is the difference between attention_mask and encoder_attention_mask?
  • What is past_key_value for?
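A shape-level sketch that answers most of these questions (dummy sizes, plain torch, not the HF API):

import torch

# Shape-level sketch of what BartDecoderLayer.forward receives (dummy sizes).
bsz, tgt_len, src_len, d_model = 2, 5, 7, 16

# hidden_states: the embeddings of the decoder input_ids, i.e. what this layer refines.
hidden_states = torch.randn(bsz, tgt_len, d_model)

# encoder_hidden_states: the encoder's output, consumed by the cross-attention block
# as its key/value states.
encoder_hidden_states = torch.randn(bsz, src_len, d_model)

# attention_mask: an additive causal mask over the decoder sequence, shape
# (bsz, 1, tgt_len, tgt_len): 0 where attending is allowed, a large negative
# value where it is not.
causal = torch.triu(torch.full((tgt_len, tgt_len), float("-inf")), diagonal=1)
attention_mask = causal[None, None].expand(bsz, 1, tgt_len, tgt_len)

# encoder_attention_mask: an additive padding mask over the encoder sequence,
# shape (bsz, 1, tgt_len, src_len); it hides padded source tokens, not future ones.
encoder_attention_mask = torch.zeros(bsz, 1, tgt_len, src_len)

# past_key_value: a 4-tuple (self_k, self_v, cross_k, cross_v), each of shape
# (bsz, num_heads, cached_len, head_dim), reused across generation steps.
print(attention_mask.shape, encoder_attention_mask.shape)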

BartAttention

Let's first understand CrossAttention. The source is as follows:

class BartAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
    
    # Just a reshape: since this is multi-head attention, reshape into (bsz, num_heads, seq_len, head_dim)
    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None

        bsz, tgt_len, _ = hidden_states.size()

        # get query proj
        query_states = self.q_proj(hidden_states) * self.scaling
        # get key, value proj
        
        # Cross-attention with a cache: from the second generated token onwards this branch is taken.
        # The encoder output does not change between decoding steps, so the key_states/value_states
        # computed at the first step are simply reused.
        if is_cross_attention and past_key_value is not None:
            # reuse k,v, cross_attentions
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        
        # (1) cross-attention during training
        # (2) cross-attention at the first decoding step during inference
        elif is_cross_attention:
            # cross_attentions
            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)

        
        elif past_key_value is not None:
            # reuse k, v, self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        else:
            # self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
        
        # Save the key/value states, depending on which kind of attention this is
        if self.is_decoder:
            # Case 1: if cross_attention, save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to the cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)

            # Case 2: if uni-directional self-attention (decoder), save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)

            # Case 3: if encoder bi-directional self-attention, `past_key_value` is always `None`
            past_key_value = (key_states, value_states)
        
        proj_shape = (bsz * self.num_heads, -1, self.head_dim)  # Why reshape again? torch.bmm needs 3-D tensors, so batch and heads are flattened together
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.view(*proj_shape)
        value_states = value_states.view(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
            )
        
        # After the matmul, attn_weights has size (bsz * num_heads, tgt_len, src_len)
        if attention_mask is not None:
            # Check attention_mask. For decoder self-attention without a cache this is (bsz, 1, tgt_len, tgt_len);
            # src_len is used because with a cache, or in cross-attention, the key length differs from tgt_len
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            # In the decoder (teacher forcing during training), later tokens must be masked out, so when computing the current token it must not see anything after it
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if layer_head_mask is not None:
            if layer_head_mask.size() != (self.num_heads,):
                raise ValueError(
                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
                )
            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if output_attentions:
            # this operation is a bit awkward, but it's required to
            # make sure that attn_weights keeps its gradient.
            # In order to do so, attn_weights have to be reshaped
            # twice and have to be reused in the following
            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
        else:
            attn_weights_reshaped = None

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
        # Once attn_probs is computed, multiply with V to get the hidden states for each position
        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
            )
        # reshape back
        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)

        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
        # partitioned across GPUs when using tensor-parallelism.
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
        # Why a final out_proj? It mixes information across the heads and maps the concatenated heads back to embed_dim
        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped, past_key_value      

You can see that the attention_mask actually looks like the following (a lower-triangular pattern: the upper triangle marks the positions to be masked out):

[screenshot: the additive causal attention mask]
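Since the screenshot is missing, a few lines reproduce the pattern it showed (tgt_len = 4 assumed for illustration; the real mask uses a very large negative number rather than a literal -inf):

import torch

tgt_len = 4
# 0 where attention is allowed, -inf above the diagonal (the masked positions).
mask = torch.triu(torch.full((tgt_len, tgt_len), float("-inf")), diagonal=1)
print(mask)
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])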

Both cross-attention and self-attention go through the code above (BartAttention implements all of these attention variants in one class). The decoder has two kinds of attention and the encoder has one, so there are three kinds in total, and the logic of the code above changes depending on whether it is running as cross- or self-attention. Below is a detailed walk-through of the computation in cross-attention.

  • Its key_states and value_states are both taken from past_key_value.

    Why is attn_weights here not a square matrix? Because in cross-attention the query length (1 per step during generation) differs from the key length (the encoder sequence length).

    The shape of attn_probs is shown below:
  • [screenshot: shape of attn_probs]
  • The shape of value_states is shown below:
  • [screenshot: shape of value_states]
  • The q fed into cross-attention has shape (10, 1, 1024) and becomes (160, 1, 64); the key has shape (160, 1024, 64), and the value has shape (160, 1024, 64). (A worked check of these shapes follows this list.)

    q comes from the decoder; k and v come from the encoder.

  • The logic of past_key_value.
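A quick check of those numbers (assuming bart-large-style sizes: d_model = 1024, 16 heads, head_dim = 1024 // 16 = 64, a batch of 10, encoder length 1024):

import torch

bsz, num_heads, head_dim, src_len = 10, 16, 64, 1024

q = torch.randn(bsz * num_heads, 1, head_dim)        # (160, 1, 64)
k = torch.randn(bsz * num_heads, src_len, head_dim)  # (160, 1024, 64)
v = torch.randn(bsz * num_heads, src_len, head_dim)  # (160, 1024, 64)

attn_weights = torch.bmm(q, k.transpose(1, 2))       # (160, 1, 1024) -- not square
attn_probs = attn_weights.softmax(dim=-1)
out = torch.bmm(attn_probs, v)                       # (160, 1, 64)
print(attn_weights.shape, out.shape)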

What the Encoder-Decoder Really Looks Like

The diagram we usually see looks like this:

[figure: the usual encoder-decoder diagram]

But that diagram is not accurate enough. For a generation model, the accurate structure is the following:

[figure: encoder-decoder with a cross-attention block in every decoder layer]

The figure above shows that in the decoder, every layer actually has its own cross-attention.

Always keep in mind that the decoder's goal is to produce the next generated token, and because generation is autoregressive, each decoder pass yields the hidden state of a single token; precisely, its shape is (bsz, 1, 1024/768), where 1024/768 is the hidden dimension.

What is use_cache for?

You can refer to this reply: https://discuss.huggingface.co/t/what-is-the-purpose-of-use-cache-in-decoder/958 . Let me explain briefly:

use_cache is only used at generate() time, not during training.
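A small sketch of what the cache actually contains, assuming facebook/bart-base and the tuple-based cache format used in the transformers version this post reads (newer releases may wrap it in a Cache object):

import torch
from transformers import BartForConditionalGeneration

model = BartForConditionalGeneration.from_pretrained("facebook/bart-base").eval()

input_ids = torch.tensor([[0, 31414, 232, 2]])   # dummy encoder input
decoder_input_ids = torch.tensor([[2, 0]])       # dummy decoder prefix

with torch.no_grad():
    out = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids, use_cache=True)

past = out.past_key_values
print(len(past))         # one entry per decoder layer (6 for bart-base)
print(len(past[0]))      # 4 tensors: self-attn k, v and cross-attn k, v
print(past[0][0].shape)  # (bsz, num_heads, decoder_len, head_dim)
print(past[0][2].shape)  # (bsz, num_heads, encoder_len, head_dim)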

The BartModel Source

[screenshot: the encoder_outputs check in BartModel.forward]

Why does the code check encoder_outputs first here? My guess was: if this is the first decoder layer, it needs to go through self.encoder, while later decoder layers can directly reuse the value computed earlier.

But I feel this understanding is not right, because the reuse of encoder_outputs would happen inside the decoder, while this piece of code lives in BartModel.
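My current reading (hedged): the check lets generate() run the encoder exactly once and then pass the precomputed encoder_outputs back in on every decoding step, so the encoder branch is skipped from then on. The same pattern can be used by hand (facebook/bart-base assumed as the checkpoint):

import torch
from transformers import BartModel, BartTokenizer

tok = BartTokenizer.from_pretrained("facebook/bart-base")
model = BartModel.from_pretrained("facebook/bart-base").eval()

enc_in = tok("A source sentence.", return_tensors="pt")
with torch.no_grad():
    encoder_outputs = model.get_encoder()(**enc_in)      # the encoder runs only here

    decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]])
    out = model(
        encoder_outputs=encoder_outputs,                 # forward sees encoder_outputs is not None and skips the encoder
        attention_mask=enc_in.attention_mask,
        decoder_input_ids=decoder_input_ids,
    )

print(out.last_hidden_state.shape)  # (1, 1, d_model)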

Generation is very slow, but training speed is normal.
