
Using early stopping and checkpoints in TensorFlow

Stopping training early with EarlyStopping and saving the model with ModelCheckpoint

When training a model, if the number of epochs is set too high, we often want to stop training once the loss stops decreasing or the accuracy stops improving, so that we obtain the model without wasting training time. In TensorFlow this can be done with the EarlyStopping callback to halt training and the ModelCheckpoint callback to save the model:

import tensorflow as tf
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics
import pandas as pd
import numpy as np
from scipy import sparse
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint  # import the EarlyStopping and ModelCheckpoint callbacks
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0"

batchsz = 256
....
....
.... # read your data with pandas and build your own x_train, y_train, x_val, y_val, x_test, y_test
db = tf.data.Dataset.from_tensor_slices((x_train, y_train))
db = db.shuffle(6000).batch(batchsz)
ds_val = tf.data.Dataset.from_tensor_slices((x_val, y_val))
ds_val = ds_val.batch(batchsz)
db_test = tf.data.Dataset.from_tensor_slices((x_test, y_test))
db_test = db_test.batch(batchsz)

sample = next(iter(db))
print(sample[0].shape, sample[1].shape)

network = Sequential([layers.Conv1D(512, kernel_size=5, strides=3,padding='same',activation=tf.nn.relu),
                     layers.Dropout(0.4),
                     layers.Conv1D(256, kernel_size=5, strides=3,padding='same',activation=tf.nn.relu),
                     layers.Dropout(0.4),
                     layers.Conv1D(128, kernel_size=5, strides=3,padding='same',activation=tf.nn.relu),
                     layers.Dropout(0.4),
                     layers.Flatten(),
                     layers.Dense(128, activation=tf.nn.relu),
                     layers.Dropout(0.4),
                     layers.Dense(1, activation='sigmoid')])  # build your own network

network.build(input_shape=(None, 2000,4))
network.summary()

# Configure early stopping: monitor the validation accuracy and stop if it improves by less than 0.001 for 8 epochs
early_stopping = EarlyStopping(monitor='val_accuracy', min_delta=0.001, patience=8)  # the metric is named 'val_acc' in older Keras versions

# Configure the checkpoint: save the model as 'conv.h5', monitor the validation accuracy, mode='max' keeps the epoch with the highest val accuracy, save_best_only=True saves only the best model
checkpoint = ModelCheckpoint('conv.h5', monitor='val_accuracy', mode='max', verbose=1, save_best_only=True)

network.compile(optimizer=optimizers.Adam(learning_rate=0.001),
                loss='binary_crossentropy',
                metrics=['accuracy']
               )

# pass the early-stopping and checkpoint callbacks to fit()
network.fit(db, epochs=100, validation_data=ds_val, validation_steps=2,callbacks=[early_stopping,checkpoint]) 
network.evaluate(db_test)

# test: predict on one batch from the test set
sample = next(iter(db_test))
x = sample[0]
y = sample[1] 
pred = network.predict(x)  
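
Because save_best_only=True means 'conv.h5' is only written when the validation accuracy improves, the weights left in memory at the end of training are not necessarily the best ones. A minimal sketch (assuming training has finished and 'conv.h5' was written) for reloading the best checkpoint and evaluating it on the test set:

from tensorflow.keras.models import load_model

best_model = load_model('conv.h5')     # reload the best model written by ModelCheckpoint
best_model.evaluate(db_test)           # evaluate the restored weights on the test set
pred = best_model.predict(sample[0])   # predictions from the best checkpoint

Alternatively, passing restore_best_weights=True to EarlyStopping restores the best weights into the model itself when training stops.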
           

If you want to learn more about early stopping, see Jason Brownlee's article Use Early Stopping to Halt the Training of Neural Networks At the Right Time.
