在訓練中使用 EarlyStopping 提前終止訓練，並使用 ModelCheckpoint 儲存模型
訓練模型時，如果 epoch 數設定得太多，我們常常希望在 loss 不再下降或 accuracy 不再提升時就提前終止訓練並保留最佳模型，以免浪費訓練時間。這時可以使用 TensorFlow（Keras）中的 EarlyStopping 回呼（callback）提前終止訓練，並用 ModelCheckpoint 回呼儲存模型：
import tensorflow as tf
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics
import pandas as pd
import numpy as np
from scipy import sparse
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint #導入early stopping,checkpoint
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # expose only GPU 0 to TensorFlow
batchsz = 256  # mini-batch size used when batching the datasets
# TODO: load your data (e.g. with pandas) and build
#   x_train, y_train, x_val, y_val, x_test, y_test
# here. The model in this example is built with input_shape=(None, 2000, 4)
# and a single sigmoid output, so presumably each sample is a (2000, 4) array
# with a binary label -- adapt to your own data.
# Training data: shuffle through a 6000-element buffer, then batch.
db = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(6000)
    .batch(batchsz)
)
# Validation and test data are only batched (kept in their original order).
ds_val = tf.data.Dataset.from_tensor_slices((x_val, y_val)).batch(batchsz)
db_test = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(batchsz)

# Sanity check: pull one training batch and print the feature/label shapes.
sample = next(iter(db))
features_batch, labels_batch = sample
print(features_batch.shape, labels_batch.shape)
# Feature extractor: three strided Conv1D stages (512 -> 256 -> 128 filters,
# kernel 5, stride 3, 'same' padding, ReLU), each followed by 40% dropout.
conv_stages = [
    layer
    for filters in (512, 256, 128)
    for layer in (
        layers.Conv1D(filters, kernel_size=5, strides=3, padding='same',
                      activation=tf.nn.relu),
        layers.Dropout(0.4),
    )
]
# Classifier head: flatten, a 128-unit ReLU layer with dropout, and a single
# sigmoid unit that outputs one probability per sample.
classifier_head = [
    layers.Flatten(),
    layers.Dense(128, activation=tf.nn.relu),
    layers.Dropout(0.4),
    layers.Dense(1, activation='sigmoid'),
]
# Swap in your own architecture here as needed.
network = Sequential(conv_stages + classifier_head)
# Batch dimension left open; each sample has shape (2000, 4).
network.build(input_shape=(None, 2000, 4))
network.summary()
# Early stopping: monitor VALIDATION accuracy and stop once it has failed to
# improve by at least 0.001 for 8 consecutive epochs.
# Note: with metrics=['accuracy'], TF2 Keras logs the validation metric as
# 'val_accuracy'. The TF1-era name 'val_acc' is never present in the logs, so
# a callback monitoring it only warns and never stops or saves anything.
# restore_best_weights=True rolls the weights back to the best epoch when
# early stopping triggers, so evaluate() below scores the best model rather
# than the (possibly worse) weights from the final epoch.
early_stopping = EarlyStopping(monitor='val_accuracy', min_delta=0.001,
                               patience=8, mode='max',
                               restore_best_weights=True)
# Checkpoint: save the full model to 'conv.h5' every time val_accuracy reaches
# a new maximum (mode='max' together with save_best_only=True).
# The keyword is `mode`; the original `model='max'` typo was either ignored or
# rejected, depending on the Keras version.
checkpoint = ModelCheckpoint('conv.h5', monitor='val_accuracy', mode='max',
                             verbose=1, save_best_only=True)
# `learning_rate` replaces the deprecated (and in newer Keras, removed) `lr`.
network.compile(optimizer=optimizers.Adam(learning_rate=0.001),
                loss='binary_crossentropy',
                metrics=['accuracy'])
# Register both callbacks with fit(). Validation runs over the whole ds_val
# each epoch. Limiting it to validation_steps=2 would score only the first two
# (unshuffled) batches, so early stopping and checkpointing would be decided on
# a small fixed subset.
network.fit(db, epochs=100, validation_data=ds_val,
            callbacks=[early_stopping, checkpoint])
# Final score on the held-out test set.
network.evaluate(db_test)
# Smoke test: run the trained network on the first batch of the test set.
sample = next(iter(db_test))
x, y = sample  # features and labels of that batch
pred = network.predict(x)