
[TensorFlow 2.0] Training on the Fashion MNIST dataset

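Dataset overview

This walkthrough uses the Fashion MNIST dataset, which contains 70,000 grayscale images in 10 categories. Each image shows a single article of clothing at low resolution (28 x 28 pixels), as shown below (image from the official TensorFlow documentation):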

[Figure: sample Fashion MNIST images, from the official TensorFlow documentation]

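Each image is a 28x28 NumPy array with pixel values ranging from 0 to 255. The labels are an array of integers ranging from 0 to 9. Each label corresponds to the class of clothing the image represents: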

Label  Class
0      T-shirt/top
1      Trouser
2      Pullover
3      Dress
4      Coat
5      Sandal
6      Shirt
7      Sneaker
8      Bag
9      Ankle boot
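
The dataset itself stores only the integer labels, so a small lookup is handy when printing predictions or titling plots. The following is a minimal sketch; `CLASS_NAMES` and `label_to_name` are illustrative names that do not appear in the original code, and the list assumes that the position of each name equals its label in the table above.

```python
# Class names in label order, copied from the table above
# (index 0 is T-shirt/top, index 9 is Ankle boot).
CLASS_NAMES = [
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
]


def label_to_name(label):
    """Return the class name for an integer label in the range 0 to 9."""
    if not 0 <= label < len(CLASS_NAMES):
        raise ValueError(f"label must be in 0..9, got {label}")
    return CLASS_NAMES[label]


print(label_to_name(9))
```

The same list can be indexed with a predicted label, for example the position of the largest value in one row of the model's softmax output.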

Code

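import tensorflow as tf
import pandas as pd
import matplotlib.pyplot as plt
print(tf.__version__)
# tf.test.is_gpu_available() is deprecated in later TensorFlow 2.x releases,
# which recommend tf.config.list_physical_devices('GPU') instead
print(tf.test.is_gpu_available())
# Load the Fashion MNIST dataset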

fashion_mnist = tf.keras.datasets.fashion_mnist

(X_train_all, Y_train_all),(X_test, Y_test) = fashion_mnist.load_data()

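# Scale pixel values from [0, 255] to [0, 1]
X_train_all = X_train_all/255
X_test = X_test/255

# Hold out the first 5,000 training samples as a validation set,
# so that accuracy is checked after every epoch
x_valid, x_train  = X_train_all[:5000], X_train_all[5000:]
y_valid, y_train  = Y_train_all[:5000], Y_train_all[5000:]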

# Build the model with tf.keras.Sequential
# relu: y = max(0, x), i.e. the larger of 0 and x
# softmax: turns the output vector into a probability distribution, e.g. for x = [x1, x2, x3],
#                                     y = [e^x1/sum, e^x2/sum, e^x3/sum],
#                                     sum = e^x1+e^x2+e^x3

model = tf.keras.models.Sequential(
    [
        tf.keras.layers.Flatten(input_shape=(28,28)), # Flatten unrolls each 28x28 input into a 1-D array of 784 values
        tf.keras.layers.Dense(256,activation='relu'), # fully connected layer with ReLU activation
        tf.keras.layers.Dropout(0.2),                 # dropout to reduce overfitting
        tf.keras.layers.Dense(10, activation='softmax') # output layer, one unit per class
    ]
)

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy', # cross-entropy loss for integer labels
              metrics=['accuracy'])

model.summary() # print the model summary

# history records the loss and metric values logged during training
history = model.fit(x_train, y_train, epochs=5,
                    validation_data=(x_valid,y_valid))

print('history:',history.history)

# Plot the values stored in history
pd.DataFrame(history.history).plot(figsize=(8, 5))
plt.grid(True)
plt.ylim(0,1)
plt.show()

model.evaluate(X_test,  Y_test, verbose=2)
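
The comments in the listing describe relu and softmax only as formulas, and the loss is named without being spelled out. The NumPy sketch below shows what those three pieces compute on a tiny input. It is an illustration, not the Keras implementation: the function names are chosen here for readability, softmax subtracts the row maximum for numerical stability (which leaves the result unchanged), and the loss omits the small clipping that Keras applies to probabilities.

```python
import numpy as np


def relu(x):
    # y = max(0, x), applied element-wise
    return np.maximum(0, x)


def softmax(x):
    # Subtracting the row maximum avoids overflow in exp()
    # and does not change the resulting probabilities.
    shifted = x - np.max(x, axis=-1, keepdims=True)
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=-1, keepdims=True)


def sparse_categorical_crossentropy(labels, probs):
    # For each sample, take the probability assigned to the true
    # integer label, then average the negative log over the batch.
    labels = np.asarray(labels)
    picked = probs[np.arange(len(labels)), labels]
    return float(np.mean(-np.log(picked)))


logits = np.array([[1.0, 2.0, 3.0],
                   [1.0, 1.0, 1.0]])
probs = softmax(logits)

print(relu(np.array([-2.0, 0.0, 3.0])))
print(probs)
print(probs.sum(axis=-1))  # each row sums to 1
print(sparse_categorical_crossentropy([2, 0], probs))
```

The "sparse" variant is the right match for this dataset because the labels are plain integers from 0 to 9 rather than one-hot vectors.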



           

Model structure

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
flatten (Flatten)            (None, 784)               0
_________________________________________________________________
dense (Dense)                (None, 256)               200960
_________________________________________________________________
dropout (Dropout)            (None, 256)               0
_________________________________________________________________
dense_1 (Dense)              (None, 10)                2570
=================================================================
Total params: 203,530
Trainable params: 203,530
Non-trainable params: 0
_________________________________________________________________
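
The parameter counts in the summary can be checked by hand: a Dense layer has one weight per input-unit pair plus one bias per unit, while Flatten and Dropout have no trainable parameters. A short sketch of that arithmetic follows; `dense_params` is a helper name introduced here for illustration.

```python
def dense_params(n_inputs, n_units):
    """Weights (n_inputs * n_units) plus one bias per unit."""
    return n_inputs * n_units + n_units


flatten_outputs = 28 * 28                       # 784 values per image
hidden = dense_params(flatten_outputs, 256)     # dense layer
output = dense_params(256, 10)                  # dense_1 layer
total = hidden + output

print(hidden, output, total)  # 200960 2570 203530
```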
           

Training process

Train on 55000 samples, validate on 5000 samples
Epoch 1/5
55000/55000 [==============================] - 6s 106us/sample - loss: 0.5183 - accuracy: 0.8162 - val_loss: 0.3885 - val_accuracy: 0.8598
Epoch 2/5
55000/55000 [==============================] - 5s 95us/sample - loss: 0.3908 - accuracy: 0.8570 - val_loss: 0.3656 - val_accuracy: 0.8696
Epoch 3/5
55000/55000 [==============================] - 5s 95us/sample - loss: 0.3585 - accuracy: 0.8697 - val_loss: 0.3203 - val_accuracy: 0.8836
Epoch 4/5
55000/55000 [==============================] - 5s 95us/sample - loss: 0.3358 - accuracy: 0.8767 - val_loss: 0.3326 - val_accuracy: 0.8796
Epoch 5/5
55000/55000 [==============================] - 5s 98us/sample - loss: 0.3237 - accuracy: 0.8808 - val_loss: 0.3297 - val_accuracy: 0.8824
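
One way to read this log is to look for the epoch at which the validation metrics were best. The sketch below does this with values copied by hand from the log above; in the real script the same lists would come from `history.history`, and the variable names here are illustrative.

```python
# Values copied from the training log above, one entry per epoch.
val_loss = [0.3885, 0.3656, 0.3203, 0.3326, 0.3297]
val_accuracy = [0.8598, 0.8696, 0.8836, 0.8796, 0.8824]

# Epochs are numbered from 1 in the log, so add 1 to the list index.
best_loss_epoch = min(range(len(val_loss)), key=val_loss.__getitem__) + 1
best_acc_epoch = max(range(len(val_accuracy)), key=val_accuracy.__getitem__) + 1

print(best_loss_epoch, best_acc_epoch)  # 3 3
```

In this run both validation metrics peak at epoch 3 and are slightly worse afterwards, while the training loss keeps falling.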
           