
Text Classification: A Word-Average Model with Residual Connection and Self-Attention

This is part of a series on text classification that implements text classification with a range of methods, from simple to complex.

We use the Stanford Sentiment Treebank movie-review dataset (Socher et al. 2013). The dataset can be downloaded here:

Link: dataset

Extraction code: yeqw

For the full code, see: 文本分類 (the linked code matches this post).

Text classification: the self-attention mechanism

Building on the earlier word average model and word average with attention model, we extend the approach by adding self-attention.

We define another sentence model, this time based on self-attention:

$$\alpha_{ts} = emb(x_t)^T\, emb(x_s)$$

$$\alpha_t \propto \exp\Big\{\sum_s \alpha_{ts}\Big\}$$

$$h_{self} = \sum_t \alpha_t\, emb(x_t)$$

The probability that the sentence expresses positive sentiment is

$$\sigma(W^T h_{self})$$

The weight of each word is the sum of the dot products of its embedding with the embeddings of all the other words, normalized with a softmax. Unlike the word average with attention model, this model introduces no extra parameter vector u.

Another variant also adds the plain average of the word vectors to the self-attention vector, which amounts to a residual connection:

$$\sigma(W^T (h_{self} + h_{avg}))$$
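
As a minimal sketch of the three formulas above for a single sentence (emb_x and W below are made-up placeholders for illustration, not the parameters trained later):

import torch
import torch.nn.functional as F

emb_x = torch.randn(7, 100)                       # [seq_len, embed_dim], stands in for emb(x_1..x_T)

alpha_ts = emb_x @ emb_x.t()                      # pairwise dot products, [seq_len, seq_len]
alpha_t = F.softmax(alpha_ts.sum(dim=-1), dim=0)  # softmax over the row sums, [seq_len]
h_self = (alpha_t.unsqueeze(-1) * emb_x).sum(0)   # weighted sum of embeddings, [embed_dim]

h_avg = emb_x.mean(dim=0)                         # plain word average, used by the residual variant
W = torch.randn(100)                              # stand-in classifier weights
prob_positive = torch.sigmoid(W @ (h_self + h_avg))

The model implemented below is a token-level variant of this idea: it computes the full [seq_len, seq_len] attention matrix, applies it to the (linearly transformed) embeddings, adds the original embeddings back as a residual, and sums over the sequence.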

本文我們将實作 加residual connection,self-attention的 詞向量平均模型。

import random
from collections import Counter
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import math
USE_CUDA = torch.cuda.is_available()
device = torch.device('cuda' if USE_CUDA else 'cpu')
           

Reading the data

with open('senti.train.tsv','r', encoding='utf-8') as rf:
    lines = rf.readlines()
print(lines[:10])
           
['hide new secretions from the parental units\t0\n', 'contains no wit , only labored gags\t0\n', 'that loves its characters and communicates something rather beautiful about human nature\t1\n', 'remains utterly satisfied to remain the same throughout\t0\n', 'on the worst revenge-of-the-nerds clichés the filmmakers could dredge up\t0\n', "that 's far too tragic to merit such superficial treatment\t0\n", 'demonstrates that the director of such Hollywood blockbusters as Patriot Games can still turn out a small , personal film with an emotional wallop .\t1\n', 'of saucy\t1\n', "a depressed fifteen-year-old 's suicidal poetry\t0\n", "are more deeply thought through than in most ` right-thinking ' films\t1\n"]
           
def read_corpus(path):
    sentences = []
    labels = []
    with open(path,'r', encoding='utf-8') as f:
        for line in f:
            sentence, label = line.split('\t')
            sentences.append(sentence.lower().split())
            labels.append(label[0])  # label is '0\n' or '1\n'; keep only the digit
    return sentences, labels
           
# file names assumed from the training file above; the dev/test splits follow the same naming
train_path, dev_path, test_path = 'senti.train.tsv', 'senti.dev.tsv', 'senti.test.tsv'

train_sentences, train_labels = read_corpus(train_path)
dev_sentences, dev_labels = read_corpus(dev_path)
test_sentences, test_labels = read_corpus(test_path)
           
(['contains', 'no', 'wit', ',', 'only', 'labored', 'gags'], '0')
           

Building the vocabulary

def build_vocab(sentences, word_size=20000):
    c = Counter()
    for sent in sentences:
        for word in sent:
            c[word] += 1
    print('Total number of distinct words in the corpus:', len(c))
    words_most_common = c.most_common(word_size)
    ## add the <pad> and <unk> special tokens at the front of the vocabulary
    idx2word = ['<pad>','<unk>'] + [item[0] for item in words_most_common]
    word2dix = {w:i for i, w in enumerate(idx2word)}
    return idx2word, word2dix
           
WORD_SIZE=20000
idx2word, word2dix = build_vocab(train_sentences, word_size=WORD_SIZE)
           
Total number of distinct words in the corpus: 14828
           
print(idx2word[:10])

['<pad>', '<unk>', 'the', ',', 'a', 'and', 'of', '.', 'to', "'s"]
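
As a quick check of the mapping (the indices follow the idx2word list printed above; 'zzzzz' is just an assumed out-of-vocabulary string):

print(word2dix['the'])                           # 2: frequent words get small indices
print(word2dix.get('zzzzz', word2dix['<unk>']))  # 1: unknown words fall back to <unk>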
           

Building batches

def numeralization(sentences, labels, word2idx):
    '''Convert sentences from lists of words to lists of indices.'''
    numeral_sent = [[word2idx.get(w, word2idx['<unk>']) for w in s] for s in sentences]
    numeral_label = [int(label) for label in labels]
    return list(zip(numeral_sent, numeral_label))
           
num_train_data = numeralization(train_sentences, train_labels, word2dix)
num_test_data = numeralization(test_sentences, test_labels, word2dix)
num_dev_data = numeralization(dev_sentences, dev_labels, word2dix)

           
def convert2tensor(batch_sentences):
    '''Pad a batch of index lists into a [batch_size, max_len] LongTensor (index 0 = <pad>).'''
    lengths = [len(s) for s in batch_sentences]
    max_len = max(lengths)
    batch_size = len(batch_sentences)
    batch = torch.zeros(batch_size, max_len, dtype=torch.long)
    for i, l in enumerate(lengths):
        batch[i, :l] = torch.tensor(batch_sentences[i])
    return batch
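
For reference, the same padding can also be done with torch.nn.utils.rnn.pad_sequence; this sketch is equivalent to convert2tensor above (assuming <pad> has index 0, as in our vocabulary):

from torch.nn.utils.rnn import pad_sequence

def convert2tensor_alt(batch_sentences, pad_idx=0):
    '''Pad a batch of index lists to the length of the longest sentence.'''
    tensors = [torch.tensor(s, dtype=torch.long) for s in batch_sentences]
    return pad_sequence(tensors, batch_first=True, padding_value=pad_idx)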
           
def generate_batch(numeral_sentences_labels, batch_size=32):
    '''Shuffle the indexed data and split it into padded batches on the target device.'''
    batches = []
    num_sample = len(numeral_sentences_labels)
    random.shuffle(numeral_sentences_labels)
    numeral_sent = [n[0] for n in numeral_sentences_labels]
    numeral_label = [n[1] for n in numeral_sentences_labels]
    for start in range(0, num_sample, batch_size):
        end = min(start + batch_size, num_sample)
        batch_sentences = numeral_sent[start:end]
        batch_labels = numeral_label[start:end]
        batch_sent_tensor = convert2tensor(batch_sentences)
        batch_label_tensor = torch.tensor(batch_labels, dtype=torch.float)
        batches.append((batch_sent_tensor.to(device), batch_label_tensor.to(device)))
    return batches
           

Building the model

class AVGSelfAttnModel(nn.Module):
    def __init__(self, vocab_size, embed_dim, output_size, pad_idx):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)
        initrange = 0.1
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.qkv = nn.Linear(embed_dim, embed_dim, bias=False)
        self.fc = nn.Linear(embed_dim, output_size, bias=False)
        
    def forward(self, text):
        ## [batch_size, seq_len] -> [batch_size, seq_len, embed_dim]
        embed = self.embedding(text)
        ## [batch_size, seq_len, embed_dim] -> [batch_size, seq_len, embed_dim]
        x = self.qkv(embed) 
        ## self-attention over the tokens of each sentence
        h_attn = self.attention(x)
        ## residual connection: add the original embeddings back
        h_attn = h_attn + embed
        ## optional layer norm (compare the results with and without it)
#         h_attn = self.layer_norm(h_attn)
        ## sum the attended embeddings over the sequence to get the sentence representation
        h_attn = torch.sum(h_attn, dim=1)  # [batch_size, embed_dim]
        out = self.fc(h_attn)
        return out
    
    def attention(self, x):
        d_k = x.size(-1)
        ## [batch_size, seq_len, embed_dim] x [batch_size, embed_dim, seq_len] -> [batch_size, seq_len, seq_len]
        score = torch.matmul(x, x.transpose(-2, -1))/math.sqrt(d_k)
        ## attention weights attn: [batch_size, seq_len, seq_len]
        attn = F.softmax(score, dim=-1) 
        ## context vectors attn_x: [batch_size, seq_len, embed_dim]
        attn_x = torch.matmul(attn, x)
        return attn_x
    
    def layer_norm(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        x_lm = (x - mean) / (std + 1e-6)  # small eps avoids division by zero
        return x_lm

    def get_embed_weigth(self):
        return self.embedding.weight.data
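
Note that the attention above also attends over <pad> positions, so they contribute to the sentence representation. A possible extension, not part of the original model, is a padding-aware attention that masks them out before the softmax; a sketch, assuming a boolean mask built as pad_mask = (text == PAD_IDX):

def masked_attention(x, pad_mask):
    '''Scaled dot-product self-attention that ignores <pad> positions.
    x: [batch_size, seq_len, embed_dim]; pad_mask: [batch_size, seq_len], True at <pad>.'''
    d_k = x.size(-1)
    score = torch.matmul(x, x.transpose(-2, -1)) / math.sqrt(d_k)
    # block attention *to* padded key positions
    score = score.masked_fill(pad_mask.unsqueeze(1), float('-inf'))
    attn = F.softmax(score, dim=-1)
    attn_x = torch.matmul(attn, x)
    # zero out the context *of* padded query positions so they do not affect the later sum
    attn_x = attn_x.masked_fill(pad_mask.unsqueeze(-1), 0.0)
    return attn_x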
           
VOCAB_SIZE = len(word2dix)
EMBEDDING_DIM = 100
OUTPUT_SIZE = 1
PAD_IDX = word2dix['<pad>']
           
model = AVGSelfAttnModel(vocab_size=VOCAB_SIZE,
                 embed_dim=EMBEDDING_DIM,
                 output_size=OUTPUT_SIZE, 
                 pad_idx=PAD_IDX)
model.to(device)
           
AVGSelfAttnModel(
  (embedding): Embedding(14830, 100, padding_idx=0)
  (qkv): Linear(in_features=100, out_features=100, bias=False)
  (fc): Linear(in_features=100, out_features=1, bias=False)
)
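
As a quick sanity check of the forward pass, a made-up dummy batch of 4 sentences with 12 token indices each should yield one logit per sentence:

dummy_batch = torch.randint(0, VOCAB_SIZE, (4, 12)).to(device)
with torch.no_grad():
    print(model(dummy_batch).shape)   # expected: torch.Size([4, 1])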
           

Defining the loss function and optimizer

criterion = nn.BCEWithLogitsLoss()
criterion = criterion.to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-4)
           

Training the model

def get_accuracy(output, label):
    ## output: [batch_size] logits
    y_hat = torch.round(torch.sigmoid(output)) ## map logits to 0/1 predictions
    correct = (y_hat == label).float()
    acc = correct.sum()/len(correct)
    return acc
           
def evaluate(batch_data, model, criterion, get_accuracy):
    model.eval()
    num_epoch = epoch_loss = epoch_acc = 0
    with torch.no_grad():
        for text, label in batch_data:
            out = model(text).squeeze(1)
            loss = criterion(out, label)
            acc = get_accuracy(out, label)
            num_epoch +=1 
            epoch_loss += loss.item()
            epoch_acc += acc.item()
    
    return epoch_loss/num_epoch, epoch_acc/num_epoch          
           
def train(batch_data, model, criterion, optimizer, get_accuracy):
    model.train()
    num_epoch = epoch_loss = epoch_acc = 0
    for text, label in batch_data:
        model.zero_grad()
        out = model(text).squeeze(1)
        loss = criterion(out, label)
        acc = get_accuracy(out, label)
        loss.backward()
        optimizer.step()
        num_epoch +=1 
        epoch_loss += loss.item()
        epoch_acc += acc.item()
    
    return epoch_loss/num_epoch, epoch_acc/num_epoch
        
           
NUM_EPOCH = 30
best_valid_acc = -1

dev_data = generate_batch(num_dev_data)
for epoch in range(NUM_EPOCH):
    train_data = generate_batch(num_train_data)
    train_loss, train_acc = train(train_data, model, criterion, optimizer, get_accuracy)
    valid_loss, valid_acc = evaluate(dev_data, model, criterion, get_accuracy)
    if valid_acc > best_valid_acc:
        best_valid_acc = valid_acc
        torch.save(model.state_dict(),'self-attn-model.pt')
    
    print(f'Epoch: {epoch+1:02} :')
    print(f'\t Train Loss: {train_loss:.4f} | Train Acc: {train_acc*100:.2f}%')
    print(f'\t Valid Loss: {valid_loss:.4f} | Valid Acc: {valid_acc*100:.2f}%')
    
           
Epoch: 01 :
	 Train Loss: 0.5429 | Train Acc: 72.38%
	 Valid Loss: 0.4695 | Valid Acc: 78.12%
Epoch: 02 :
	 Train Loss: 0.2947 | Train Acc: 88.60%
	 Valid Loss: 0.5573 | Valid Acc: 79.02%
Epoch: 03 :
	 Train Loss: 0.2277 | Train Acc: 91.26%
	 Valid Loss: 0.6375 | Valid Acc: 79.80%
Epoch: 04 :
	 Train Loss: 0.1964 | Train Acc: 92.50%
	 Valid Loss: 0.7260 | Valid Acc: 80.25%
Epoch: 05 :
	 Train Loss: 0.1759 | Train Acc: 93.27%
	 Valid Loss: 0.7696 | Valid Acc: 82.25%
Epoch: 06 :
	 Train Loss: 0.1642 | Train Acc: 93.81%
	 Valid Loss: 0.8865 | Valid Acc: 80.58%
Epoch: 07 :
	 Train Loss: 0.1538 | Train Acc: 94.13%
	 Valid Loss: 0.9686 | Valid Acc: 79.35%
Epoch: 08 :
	 Train Loss: 0.1461 | Train Acc: 94.53%
	 Valid Loss: 0.9697 | Valid Acc: 81.81%
Epoch: 09 :
	 Train Loss: 0.1409 | Train Acc: 94.63%
	 Valid Loss: 1.1235 | Valid Acc: 79.46%
Epoch: 10 :
	 Train Loss: 0.1356 | Train Acc: 94.89%
	 Valid Loss: 1.1045 | Valid Acc: 81.14%
Epoch: 11 :
	 Train Loss: 0.1326 | Train Acc: 95.05%
	 Valid Loss: 1.2394 | Valid Acc: 80.13%
Epoch: 12 :
	 Train Loss: 0.1296 | Train Acc: 95.11%
	 Valid Loss: 1.3044 | Valid Acc: 79.35%
Epoch: 13 :
	 Train Loss: 0.1265 | Train Acc: 95.18%
	 Valid Loss: 1.4154 | Valid Acc: 79.02%
Epoch: 14 :
	 Train Loss: 0.1242 | Train Acc: 95.28%
	 Valid Loss: 1.4540 | Valid Acc: 79.35%
Epoch: 15 :
	 Train Loss: 0.1219 | Train Acc: 95.36%
	 Valid Loss: 1.5596 | Valid Acc: 78.91%
Epoch: 16 :
	 Train Loss: 0.1208 | Train Acc: 95.40%
	 Valid Loss: 1.5866 | Valid Acc: 78.68%
Epoch: 17 :
	 Train Loss: 0.1190 | Train Acc: 95.48%
	 Valid Loss: 1.6453 | Valid Acc: 78.35%
Epoch: 18 :
	 Train Loss: 0.1175 | Train Acc: 95.51%
	 Valid Loss: 1.6904 | Valid Acc: 79.35%
Epoch: 19 :
	 Train Loss: 0.1170 | Train Acc: 95.59%
	 Valid Loss: 1.7406 | Valid Acc: 79.24%
Epoch: 20 :
	 Train Loss: 0.1160 | Train Acc: 95.57%
	 Valid Loss: 1.8767 | Valid Acc: 77.01%
Epoch: 21 :
	 Train Loss: 0.1149 | Train Acc: 95.67%
	 Valid Loss: 1.8612 | Valid Acc: 78.68%
Epoch: 22 :
	 Train Loss: 0.1142 | Train Acc: 95.62%
	 Valid Loss: 1.9032 | Valid Acc: 78.46%
Epoch: 23 :
	 Train Loss: 0.1126 | Train Acc: 95.68%
	 Valid Loss: 1.9864 | Valid Acc: 77.90%
Epoch: 24 :
	 Train Loss: 0.1118 | Train Acc: 95.78%
	 Valid Loss: 2.0475 | Valid Acc: 76.67%
Epoch: 25 :
	 Train Loss: 0.1113 | Train Acc: 95.76%
	 Valid Loss: 2.0904 | Valid Acc: 77.79%
Epoch: 26 :
	 Train Loss: 0.1100 | Train Acc: 95.85%
	 Valid Loss: 2.1268 | Valid Acc: 77.01%
Epoch: 27 :
	 Train Loss: 0.1105 | Train Acc: 95.75%
	 Valid Loss: 2.1717 | Valid Acc: 77.90%
Epoch: 28 :
	 Train Loss: 0.1092 | Train Acc: 95.88%
	 Valid Loss: 2.2729 | Valid Acc: 77.46%
Epoch: 29 :
	 Train Loss: 0.1091 | Train Acc: 95.79%
	 Valid Loss: 2.3031 | Valid Acc: 78.01%
Epoch: 30 :
	 Train Loss: 0.1082 | Train Acc: 95.95%
	 Valid Loss: 2.3582 | Valid Acc: 77.34%
           
## load the best checkpoint saved during training before evaluating on the test set
model.load_state_dict(torch.load('self-attn-model.pt'))

<All keys matched successfully>
           
test_data = generate_batch(num_test_data)
test_loss, test_acc = evaluate(test_data, model, criterion, get_accuracy)
print(f'Test Loss: {test_loss:.4f} |  Test Acc: {test_acc*100:.2f}%')
           
Test Loss: 0.6522 |  Test Acc: 81.61%
           
