

A small text classification demo based on the IMDB dataset

I plan to put together Keras implementations of various NLP text algorithms; the first one is text classification.

I ran it directly on Colab, because I'm too broke for hardware but still want the GPU experience...

The approach here follows an expert's write-up.

First, download the data.

from keras.datasets import imdb
import keras
(train_x, train_y), (test_x, test_y) = keras.datasets.imdb.load_data(num_words=20000)
print("Training entries: {}, labels: {}".format(len(train_x), len(train_y)))
           
Training entries: 25000, labels: 25000
           

Take a look at the first training example:

[1, 14, 22, 16, 43, 530, 973, 1622, 1385, 65, 458, 4468, 66, 3941, 4, 173, 36, 256, 5, 25, 100, 43, 838, 112, 50, 670, 2, 9, 35, 480, 284, 5, 150, 4, 172, 112, 167, 2, 336, 385, 39, 4, 172, 4536, 1111, 17, 546, 38, 13, 447, 4, 192, 50, 16, 6, 147, 2025, 19, 14, 22, 4, 1920, 4613, 469, 4, 22, 71, 87, 12, 16, 43, 530, 38, 76, 15, 13, 1247, 4, 22, 17, 515, 17, 12, 16, 626, 18, 19193, 5, 62, 386, 12, 8, 316, 8, 106, 5, 4, 2223, 5244, 16, 480, 66, 3785, 33, 4, 130, 12, 16, 38, 619, 5, 25, 124, 51, 36, 135, 48, 25, 1415, 33, 6, 22, 12, 215, 28, 77, 52, 5, 14, 407, 16, 82, 10311, 8, 4, 107, 117, 5952, 15, 256, 4, 2, 7, 3766, 5, 723, 36, 71, 43, 530, 476, 26, 400, 317, 46, 7, 4, 12118, 1029, 13, 104, 88, 4, 381, 15, 297, 98, 32, 2071, 56, 26, 141, 6, 194, 7486, 18, 4, 226, 22, 21, 134, 476, 26, 480, 5, 144, 30, 5535, 18, 51, 36, 28, 224, 92, 25, 104, 4, 226, 65, 16, 38, 1334, 88, 12, 16, 283, 5, 16, 4472, 113, 103, 32, 15, 16, 5345, 19, 178, 32]
           

Here each word has been converted to an index; let's convert them back and take a look.

word_index = imdb.get_word_index()
# indices are shifted by 3 so that 0-3 can be reserved for the special tokens below
word2id = {k: (v + 3) for k, v in word_index.items()}
word2id['<PAD>'] = 0
word2id['<START>'] = 1
word2id['<UNK>'] = 2
word2id['<UNUSED>'] = 3

id2word = {v:k for k, v in word2id.items()}
def get_words(sent_ids):
    return ' '.join([id2word.get(i, '?') for i in sent_ids])
sent = get_words(train_x[0])
print(sent)
           

The output is the review in readable sentence form:

<START> this film was just brilliant casting location scenery story direction everyone's really suited the part they played and you could just imagine being there robert <UNK> is an amazing actor and now the same being director <UNK> father came from the same scottish island as myself so i loved the fact there was a real connection with this film the witty remarks throughout the film were great it was just brilliant so much that i bought the film as soon as it was released for retail and would recommend it to everyone to watch and the fly fishing was amazing really cried at the end it was so sad and you know what they say if you cry at a film it must have been good and this definitely was also congratulations to the two little boy's that played the <UNK> of norman and paul they were just brilliant children are often left out of the praising list i think because the stars that play them all grown up are such a big profile for the whole film but these children are amazing and should be praised for what they have done don't you think the whole story was so lovely because it was true and was someone's life after all that was shared with us all
           

However, the reviews differ in length, while Keras requires inputs of equal length, so the data needs to be padded.
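For reference, the original post prints the lengths of the first two training reviews before padding (only the output is shown below); the check would look something like this:

print('len: ', len(train_x[0]), len(train_x[1]))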

len:  218 189
           
train_x = keras.preprocessing.sequence.pad_sequences(
    train_x, value=word2id['<PAD>'],
    padding='post', maxlen=256
)
test_x = keras.preprocessing.sequence.pad_sequences(
    test_x, value=word2id['<PAD>'],
    padding='post', maxlen=256
)
print(train_x[0])
print('len: ',len(train_x[0]), len(train_x[1]))
           
[    1    14    22    16    43   530   973  1622  1385    65   458  4468
    66  3941     4   173    36   256     5    25   100    43   838   112
    50   670     2     9    35   480   284     5   150     4   172   112
   167     2   336   385    39     4   172  4536  1111    17   546    38
    13   447     4   192    50    16     6   147  2025    19    14    22
     4  1920  4613   469     4    22    71    87    12    16    43   530
    38    76    15    13  1247     4    22    17   515    17    12    16
   626    18 19193     5    62   386    12     8   316     8   106     5
     4  2223  5244    16   480    66  3785    33     4   130    12    16
    38   619     5    25   124    51    36   135    48    25  1415    33
     6    22    12   215    28    77    52     5    14   407    16    82
 10311     8     4   107   117  5952    15   256     4     2     7  3766
     5   723    36    71    43   530   476    26   400   317    46     7
     4 12118  1029    13   104    88     4   381    15   297    98    32
  2071    56    26   141     6   194  7486    18     4   226    22    21
   134   476    26   480     5   144    30  5535    18    51    36    28
   224    92    25   104     4   226    65    16    38  1334    88    12
    16   283     5    16  4472   113   103    32    15    16  5345    19
   178    32     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0     0     0     0     0
     0     0     0     0]
len:  256 256
           

You can see that every sequence now has length 256, and the positions beyond the original text are filled with 0.

Set up the model structure: an Embedding layer, a GlobalAveragePooling1D layer that averages the embedding vectors over the sequence, and two Dense layers. Also note that when stacking layers, imports via tensorflow.python.keras and imports via keras must not be mixed.

from keras.layers import Dense,Embedding,GlobalAveragePooling1D
model = keras.Sequential()
model.add(Embedding(20000,32))
model.add(GlobalAveragePooling1D())
model.add(Dense(32, activation='relu'))
model.add(Dense(1, activation='sigmoid'))
model.summary()
model.compile(optimizer='adam',
             loss='binary_crossentropy',
             metrics=['accuracy'])

           
Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
embedding_2 (Embedding)      (None, None, 32)          640000    
_________________________________________________________________
global_average_pooling1d_2 ( (None, 32)                0         
_________________________________________________________________
dense_3 (Dense)              (None, 32)                1056      
_________________________________________________________________
dense_4 (Dense)              (None, 1)                 33        
=================================================================
Total params: 641,089
Trainable params: 641,089
Non-trainable params: 0
_________________________________________________________________
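A quick sanity check on the summary: the Embedding layer holds 20000 × 32 = 640,000 weights, the hidden Dense layer 32 × 32 + 32 = 1,056, and the output layer 32 + 1 = 33, which gives the 641,089 total shown above.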
           

Split off a validation set and train the model.

x_val = train_x[:10000]
x_train = train_x[10000:]

y_val = train_y[:10000]
y_train = train_y[10000:]

history = model.fit(x_train,y_train,
                   epochs=40, batch_size=512,
                   validation_data=(x_val, y_val),
                   verbose=1)

result = model.evaluate(test_x, test_y)
print(result)
           
Epoch 1/40
15000/15000 [==============================] - 0s 32us/step - loss: 0.6908 - accuracy: 0.6207 - val_loss: 0.6866 - val_accuracy: 0.7085
......
......
......
Epoch 38/40
15000/15000 [==============================] - 0s 20us/step - loss: 0.0237 - accuracy: 0.9975 - val_loss: 0.3760 - val_accuracy: 0.8782
Epoch 39/40
15000/15000 [==============================] - 0s 21us/step - loss: 0.0220 - accuracy: 0.9979 - val_loss: 0.3816 - val_accuracy: 0.8770
Epoch 40/40
15000/15000 [==============================] - 0s 20us/step - loss: 0.0205 - accuracy: 0.9979 - val_loss: 0.3867 - val_accuracy: 0.8772
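As a sketch that is not part of the original post, here is one way to score a new review with the trained model, reusing the word2id mapping and the same padding settings as above (the sample sentence is made up):

def encode_review(text, maxlen=256):
    # map words to ids; unknown or out-of-vocabulary words become <UNK> (id 2)
    ids = [word2id['<START>']]
    for w in text.lower().split():
        idx = word2id.get(w, word2id['<UNK>'])
        ids.append(idx if idx < 20000 else word2id['<UNK>'])
    return keras.preprocessing.sequence.pad_sequences(
        [ids], value=word2id['<PAD>'], padding='post', maxlen=maxlen)

sample = encode_review("this film was just brilliant and the casting was amazing")
print(model.predict(sample))  # closer to 1 means positive, closer to 0 means negative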
           
import matplotlib.pyplot as plt
history_dict = history.history
history_dict.keys()
acc = history_dict['accuracy']
val_acc = history_dict['val_accuracy']
loss = history_dict['loss']
val_loss = history_dict['val_loss']
epochs = range(1, len(acc)+1)

plt.plot(epochs, loss, 'bo', label='train loss')
plt.plot(epochs, val_loss, 'b', label='val loss')
plt.title('Train and val loss')
plt.xlabel('Epochs')
plt.ylabel('loss')
plt.legend()
plt.show()
           
[Figure: training and validation loss]
plt.clf()   # clear figure

plt.plot(epochs, acc, 'bo', label='Training acc')
plt.plot(epochs, val_acc, 'b', label='Validation acc')
plt.title('Training and validation accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()

plt.show()
           
[Figure: training and validation accuracy]
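One thing the curves make obvious is overfitting: training accuracy climbs to about 0.998 while validation accuracy stalls around 0.877 and validation loss creeps back up. A minimal sketch, assuming the same model and data split as above, of stopping training early on validation loss with a standard Keras callback:

early_stop = keras.callbacks.EarlyStopping(
    monitor='val_loss', patience=3, restore_best_weights=True)

history = model.fit(x_train, y_train,
                    epochs=40, batch_size=512,
                    validation_data=(x_val, y_val),
                    callbacks=[early_stop],
                    verbose=1)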
