天天看點

Tensorflow2.0學習筆記------堆疊模型

早有耳聞Tensorflow2.0相比1.x版本是重大飛躍,需要趕緊給自己充充電!

所有相關代碼測試均在colab中進行,不會配置的同學可以參見https://blog.csdn.net/hesongzefairy/article/details/105411219

2.0版本內建了keras後,使用tf.keras來搭建網絡,完全繼承了keras的優勢之處

Step1:導入tf.keras(這這個導入在Pycharm會顯示黃色,其實不是報錯,是Pycharm的bug)

import tensorflow as tf
from tensorflow.keras import layers
import numpy as np
           

Step2:使用堆疊模型tf.keras.Sequential()建構一個四層網絡

model = tf.keras.Sequential()
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(10, activation='softmax'))
           

Step3:設定訓練流程

model.compile(optimizer=tf.keras.optimizers.Adam(0.001),
             loss=tf.keras.losses.categorical_crossentropy,
             metrics=[tf.keras.metrics.categorical_accuracy])
           

Step4:制作資料集

train_x = np.random.random((1000, 100))
train_y = np.random.random((1000, 10))

val_x = np.random.random((200, 100))
val_y = np.random.random((200, 10))
           

Step5:訓練

model.fit(train_x, train_y, epochs=10, batch_size=100,
          validation_data=(val_x, val_y))
           
Tensorflow2.0學習筆記------堆疊模型

Step6:改進資料集的儲存方式tf.data并重新訓練

dataset = tf.data.Dataset.from_tensor_slices((train_x, train_y))
dataset = dataset.batch(32)
dataset = dataset.repeat()

val_dataset = tf.data.Dataset.from_tensor_slices((val_x, val_y))
val_dataset = val_dataset.batch(32)
val_dataset = val_dataset.repeat()

model.fit(dataset, epochs=10, steps_per_epoch=30,
          validation_data=val_dataset, validation_steps=3)
           

Step7:模型評估與測試(未使用tf.data和使用tf.data分别測試)

test_x = np.random.random((2000, 100))
test_y = np.random.random((2000, 10))
model.evaluate(test_x, test_y, batch_size=32)

test_data = tf.data.Dataset.from_tensor_slices((test_x, test_y))
test_data = test_data.batch(32).repeat()
model.evaluate(test_data, steps=30)

# predict
result = model.predict(test_x, batch_size=32)
print(result)
           
Tensorflow2.0學習筆記------堆疊模型

參考資料:

https://zhuanlan.zhihu.com/p/58825020

https://colab.research.google.com/notebooks/intro.ipynb#scrollTo=-Rh3-Vt9Nev9

https://www.tensorflow.org/

繼續閱讀