目前時間是2021/6/16,matchzoo作為非常強大的文本比對庫,目前未更新到TF2.4以上版本,無法使用本機3090加速,為此我将源碼反向推導實作,使用TF2.4形式實作這些模型
"""
作者英俊
QQ 2227495940
所有權:西安建築科技大學草堂校區 信控樓704實驗室
"""
"暫定隻能扣13個模型出來"
'暫定隻能扣13個模型出來'
導入第三方庫包
# 導入tf 架構以及tf架構裡面的keras,之後的反推模型需要依賴這兩個庫
import tensorflow as tf #導入TF
from tensorflow import keras # 導入keras庫
from tensorflow.keras import backend as K # 導入背景
print(tf.__version__)
2.4.0
import matchzoo as mz
# 檢查目前可以支援那些模型,但是由于本人的研究水準,目前應該隻能剝離16種模型進行加速
mz.models.list_available()
Using TensorFlow backend.
[matchzoo.models.naive.Naive,
matchzoo.models.dssm.DSSM,
matchzoo.models.cdssm.CDSSM,
matchzoo.models.dense_baseline.DenseBaseline,
matchzoo.models.arci.ArcI,
matchzoo.models.arcii.ArcII,
matchzoo.models.match_pyramid.MatchPyramid,
matchzoo.models.knrm.KNRM,
matchzoo.models.duet.DUET,
matchzoo.models.drmmtks.DRMMTKS,
matchzoo.models.drmm.DRMM,
matchzoo.models.anmm.ANMM,
matchzoo.models.mvlstm.MVLSTM,
matchzoo.contrib.models.match_lstm.MatchLSTM,
matchzoo.contrib.models.match_srnn.MatchSRNN,
matchzoo.contrib.models.hbmp.HBMP,
matchzoo.contrib.models.esim.ESIM,
matchzoo.contrib.models.bimpm.BiMPM,
matchzoo.contrib.models.diin.DIIN,
matchzoo.models.conv_knrm.ConvKNRM]
資料讀取
import pandas as pd
# 讀取資料,将資料集加載進來,并且檢查
data_df = pd.read_csv("data/atec_nlp_sim_train_all.csv", sep="\t", header=None,
encoding="utf-8-sig", names=["sent1", "sent2", "label"])
# 擷取首部和尾部
data_df.head(10).append(data_df.tail(5))
sent1 | sent2 | label | |
---|---|---|---|
1 | 怎麼更改花呗手機号碼 | 我的花呗是以前的手機号碼,怎麼更改成現在的支付寶的号碼手機号 | 1 |
2 | 也開不了花呗,就這樣了?完事了 | 真的嘛?就是花呗付款 | |
3 | 花呗當機以後還能開通嗎 | 我的條件可以開通花呗借款嗎 | |
4 | 如何得知關閉借呗 | 想永久關閉借呗 | |
5 | 花呗掃碼付錢 | 二維碼掃描可以用花呗嗎 | |
6 | 花呗逾期後不能分期嗎 | 我這個 逾期後還完了 最低還款 後 能分期嗎 | |
7 | 花呗分期清空 | 花呗分期查詢 | |
8 | 借呗逾期短信通知 | 如何購買花呗短信通知 | |
9 | 借呗即将到期要還的賬單還能分期嗎 | 借呗要分期還,是嗎 | |
10 | 花呗為什麼不能支付手機交易 | 花呗透支了為什麼不可以繼續用了 | |
102473 | 花呗分期還一期後能用嗎 | 分期是還花呗嗎 | |
102474 | 我的支付寶手機号和花呗手機号不一樣怎麼辦 | 支付寶上的手機号,怎麼和花呗上的不一樣 | 1 |
102475 | 借呗這個月的分期晚幾天還可以嗎 | 借呗分期後可以更改分期時間嗎 | |
102476 | 我怎麼沒有花呗臨時額度了 | 花呗有零時額度嗎 | |
102477 | 怎麼授權芝麻信用給花呗 | 花呗授權聯系人怎麼授權 |
(102477, 3)
處理資料
import sklearn # 導入sklearn庫
from sklearn.model_selection import train_test_split #可以将資料集劃分訓練集/測試集/驗證集
# 為了防止運作速度緩慢,從總的資料集中抽取3500個樣本進行實驗
sent1=data_df.sent1.values[:3501]
sent2=data_df.sent2.values[:3501]
label=data_df.label.values[:3501]
# 這裡是訓練集
sent1_=sent1[:2501]
sent2_=sent2[:2501]
label_=label[:2501]
# 這裡是驗證集
_sent1=sent1[2501:]
_sent2=sent2[2501:]
_label=label[2501:]
# 将訓練集轉換成matchzoo需要的亞子
train_dev_data=pd.DataFrame()
train_dev_data['id_left']=range(2501)
train_dev_data['text_left']=sent1_
train_dev_data['id_right']=range(2501)
train_dev_data['text_right']=sent2_
train_dev_data['label']=label_
# 将測試集轉換成matchzoo需要的亞子
test_data=pd.DataFrame()
test_data['id_left']=range(1000)
test_data['text_left']=_sent1
test_data['id_right']=range(1000)
test_data['text_right']=_sent2
# test_data['label']=_label
# 擷取字典
from collections import Counter
c = Counter()
sent_data = data_df["sent1"].values + data_df["sent2"].values
for d in sent_data:
c.update(d)
word_counts = sorted(dict(c).items(), key=lambda x: x[1], reverse=True)
print(word_counts[:10])
# 擷取實作字典和idx的互相轉換
vocab_words = ["<PAD>", "<UNK>"]
for w, c in word_counts:
vocab_words.append(w)
vocab2id = {w: i for i, w in enumerate(vocab_words)}
id2vocab = {i: w for i, w in enumerate(vocab_words)}
print("vocab size: ", len(vocab2id))
print(list(vocab2id.items())[:5])
print(list(id2vocab.items())[:5])
# 儲存字典
with open("vocab.txt", "w", encoding="utf8") as f:
for w, i in vocab2id.items():
f.write(w+"\n")
# 文本轉換為字典
def sent2index(vocab2id, words):
return [vocab2id[w] if w in vocab2id else vocab2id["<UNK>"] for w in words]
# 将訓練集轉換成數字格式
train_dev_data["text_left"] = train_dev_data["text_left"].apply(lambda x: sent2index(vocab2id, x))
train_dev_data["text_right"] = train_dev_data["text_right"].apply(lambda x: sent2index(vocab2id, x))
# 将測試機轉換成數字形式
test_data["text_left"] = test_data["text_left"].apply(lambda x: sent2index(vocab2id, x))
test_data["text_right"] =test_data["text_right"].apply(lambda x: sent2index(vocab2id, x))
[('呗', 211063), ('花', 151025), ('麼', 83328), ('還', 80050), ('借', 69825), ('我', 67036), ('款', 62302), ('的', 61108), ('了', 56689), ('用', 52685)]
vocab size: 2175
[('<PAD>', 0), ('<UNK>', 1), ('呗', 2), ('花', 3), ('麼', 4)]
[(0, '<PAD>'), (1, '<UNK>'), (2, '呗'), (3, '花'), (4, '麼')]
id_left | text_left | id_right | text_right | label | |
---|---|---|---|---|---|
[1515, 15, 4, 187, 129, 3, 2, 57, 73, 43, 60] | [7, 9, 3, 2, 18, 23, 52, 9, 57, 73, 43, 60, 14... | 1 | |||
1 | 1 | [160, 31, 13, 10, 3, 2, 14, 95, 66, 89, 10, 20... | 1 | [564, 9, 179, 200, 95, 18, 3, 2, 25, 8] | |
2 | 2 | [3, 2, 155, 132, 23, 51, 5, 21, 31, 36, 16] | 2 | [7, 9, 243, 213, 22, 23, 31, 36, 3, 2, 6, 8, 16] | |
3 | 3 | [76, 85, 260, 227, 69, 96, 6, 2] | 3 | [92, 459, 142, 69, 96, 6, 2] | |
4 | 4 | [3, 2, 231, 60, 25, 34] | 4 | [180, 271, 60, 231, 679, 22, 23, 11, 3, 2, 16] | |
5 | 5 | [3, 2, 71, 19, 51, 13, 21, 29, 19, 16] | 5 | [7, 66, 37, 53, 71, 19, 51, 5, 124, 10, 53, 93... | |
6 | 6 | [3, 2, 29, 19, 94, 680] | 6 | [3, 2, 29, 19, 131, 261] | |
7 | 7 | [6, 2, 71, 19, 236, 58, 36, 227] | 7 | [76, 85, 175, 86, 3, 2, 236, 58, 36, 227] | |
8 | 8 | [6, 2, 569, 464, 41, 19, 48, 5, 9, 46, 78, 5, ... | 8 | [6, 2, 48, 29, 19, 5, 14, 18, 16] | |
9 | 9 | [3, 2, 26, 17, 4, 13, 21, 33, 25, 57, 73, 144,... | 9 | [3, 2, 401, 33, 10, 26, 17, 4, 13, 22, 23, 383... | |
2491 | 2491 | [3, 2, 5, 8, 107, 20, 97, 7, 9, 46, 78, 107, 2... | 2491 | [54, 11, 3, 2, 87, 49, 9, 107, 20, 97, 88, 44,... | 1 |
2492 | 2492 | [3, 2, 33, 25, 366, 374] | 2492 | [27, 28, 6, 2, 366, 374] | |
2493 | 2493 | [3, 2, 29, 19, 10, 21, 40, 52, 5, 8, 16] | 2493 | [3, 2, 29, 19, 15, 4, 174, 20, 5, 8] | |
2494 | 2494 | [3, 2, 22, 23, 115, 104, 350, 369, 16] | 2494 | [3, 2, 121, 42, 20, 30, 22, 23, 38, 350, 369, ... | |
2495 | 2495 | [57, 73, 70, 74, 63, 7, 9, 3, 2, 339, 505, 365... | 2495 | [7, 9, 3, 2, 70, 167, 14, 26, 17, 4, 7, 9, 33,... | |
2496 | 2496 | [3, 2, 75, 98, 29, 19, 10, 35, 213, 242, 240, ... | 2496 | [3, 2, 46, 78, 29, 19, 10, 35, 84, 32, 5, 94, ... | |
2497 | 2497 | [7, 18, 3, 2, 59, 108, 6] | 2497 | [3, 2, 22, 23, 356, 6, 61, 107, 20, 30, 16] | |
2498 | 2498 | [3, 2, 78, 13, 379, 10, 147, 102, 59, 194] | 2498 | [3, 2, 38, 102, 194, 253, 343] | |
2499 | 2499 | [3, 2, 71, 19, 111, 82, 56, 35, 183, 287, 16] | 2499 | [3, 2, 238, 12, 12, 12, 14, 71, 19, 12, 12, 12... | |
2500 | 2500 | [11, 3, 2, 59, 9, 34, 86, 145, 151, 50, 109, 9... | 2500 | [3, 2, 32, 238, 34, 14, 50, 9, 8, 147, 102, 59... |
id_left | text_left | id_right | text_right | |
---|---|---|---|---|
[27, 28, 6, 2, 128, 43, 5, 8] | [27, 28, 6, 2, 5, 8, 103, 19, 10] | |||
1 | 1 | [7, 31, 36, 3, 2, 51, 53, 20, 30, 5, 18, 12, 1... | 1 | [7, 32, 24, 31, 36, 3, 2, 16] |
2 | 2 | [6, 2, 18, 195, 44, 5, 16] | 2 | [6, 2, 417, 424, 141, 37, 44, 105, 48, 5, 16, ... |
3 | 3 | [3, 2, 29, 19, 14, 128, 19, 199, 82] | 3 | [3, 2, 22, 23, 29, 12, 12, 12, 19, 9, 16] |
4 | 4 | [682, 840, 36, 68, 122, 13, 21, 31, 36, 3, 2, ... | 4 | [18, 13, 18, 682, 840, 36, 68, 1395, 32, 24, 3... |
5 | 5 | [7, 126, 52, 9, 46, 43, 18, 22, 23, 11, 3, 2, 9] | 5 | [7, 9, 3, 2, 61, 38, 22, 23, 11, 16] |
6 | 6 | [26, 17, 4, 6, 2, 48, 124, 503, 362, 482] | 6 | [76, 85, 124, 503, 6, 2, 40, 20, 362, 482] |
7 | 7 | [3, 2, 417, 424, 45, 65, 34, 99, 22, 23, 91, 2... | 7 | [3, 2, 29, 12, 12, 12, 19, 111, 82, 45, 65] |
8 | 8 | [26, 17, 4, 7, 9, 162, 35, 37, 46, 43, 75, 98,... | 8 | [14, 7, 66, 37, 33, 25, 39, 26, 17, 4, 11, 13,... |
9 | 9 | [3, 2, 144, 290, 69, 96, 10, 15, 4, 5, 48, 5, 8] | 9 | [3, 2, 144, 290, 69, 96, 17, 4, 184, 201] |
990 | 990 | [27, 28, 6, 2, 155, 132, 505, 4, 202, 155] | 990 | [6, 2, 155, 132, 45, 142, 21, 132, 155] |
991 | 991 | [66, 228, 34, 7, 3, 2, 32, 24, 253, 46] | 991 | [50, 8, 41, 3, 2, 34, 15, 4, 32, 41, 46, 176, ... |
992 | 992 | [3, 2, 22, 160, 38, 86, 57, 73, 33, 25, 16, 20... | 992 | [3, 2, 33, 25, 24, 57, 854, 49, 16] |
993 | 993 | [7, 125, 296, 296, 629, 213, 172, 59, 108, 74,... | 993 | [21, 11, 3, 2, 53, 113, 18, 38, 7, 9, 368, 108... |
994 | 994 | [7, 92, 31, 36, 58, 11, 55, 3, 2, 33, 25] | 994 | [3, 2, 58, 11, 55, 47, 34, 31, 36, 10, 15, 4, 11] |
995 | 995 | [7, 9, 47, 8, 60, 13, 21, 31, 36, 3, 2, 97, 58... | 995 | [47, 34, 60, 21, 11, 3, 2, 25, 8] |
996 | 996 | [3, 2, 13, 304, 72, 5, 8, 56, 79, 72, 5, 8, 16] | 996 | [3, 2, 5, 8, 51, 217, 79, 72, 45, 62, 10] |
997 | 997 | [61, 38, 32, 24, 27, 28, 6, 2, 10, 16] | 997 | [6, 2, 160, 32, 24] |
998 | 998 | [70, 37, 44, 3, 2, 11, 10, 12, 12, 12, 14, 66,... | 998 | [3, 2, 20, 30, 12, 12, 12, 14, 118, 21, 29, 19... |
999 | 999 | [76, 181, 403, 437, 257, 32, 24, 330, 272, 79,... | 999 | [3, 2, 40, 20, 403, 437, 237, 122, 32, 24] |
# 文本的長度
max_len = 15
# 字典的長度
vocab_size = len(vocab2id)
# 詞向量長度
embedding_size = 128
# 確定text_left和text_right對齊
from tensorflow.keras.preprocessing.sequence import pad_sequences
## 訓練集對齊
sent1_datas = train_dev_data.text_left.values.tolist()
sent2_datas = train_dev_data.text_right.values.tolist()
labels = train_dev_data.label.values.tolist()
train_sent1=pad_sequences(sent1_datas, maxlen=max_len)
train_sent2 = pad_sequences(sent2_datas, maxlen=max_len)
## 測試集對齊
test_sent1_datas = test_data.text_left.values.tolist()
test_sent2_datas = test_data.text_right.values.tolist()
test_sent1=pad_sequences(test_sent1_datas, maxlen=max_len)
test_sent2 = pad_sequences(test_sent2_datas, maxlen=max_len)
# 劃分訓練 測試資料集
count = len(labels)
# idx1, idx2 = int(count*0.8), int(count*0.9)
idx1= int(count*0.8)
sent1_train, sent2_train = train_sent1[:idx1], train_sent2[:idx1]
sent1_val, sent2_val = train_sent1[idx1:], train_sent2[idx1:]
# sent1_test, sent2_test = sent1_datas[idx2:], sent2_datas[idx2:]
train_labels, val_labels= labels[:idx1], labels[idx1:]
print("train data: ", len(sent1_train), len(sent2_train), len(train_labels))
print("val data: ", len(sent1_val), len(sent2_val), len(val_labels))
# print("test data: ", len(sent1_test), len(sent2_test), len(test_labels))
import numpy as np # 将list轉換成array
train_labels=np.array(train_labels)
val_labels=np.array(val_labels)
train data: 2000 2000 2000
val data: 501 501 501
train_labels
array([1, 0, 0, ..., 0, 1, 0])