天天看點

TensorFlow .ckpt 模型的儲存與恢複.ckpt一、TensorFlow正常模型加載方法

.ckpt

.ckpt 全稱為 checkpoint,代表着一個檢查點,即為 model 訓練過程中的一個快照,可能是在訓練開始,也可能是在訓練完成。

tensorflow新版本不會生成.ckpt檔案,你隻要将這四個檔案放入一個檔案夾并命名,測試時直接調用這個檔案夾就行了,這就相當于舊版本的.ckpt檔案。

一、TensorFlow正常模型加載方法

儲存模型

tf.train.Saver()類,.save(sess, ckpt檔案目錄)方法

TensorFlow .ckpt 模型的儲存與恢複.ckpt一、TensorFlow正常模型加載方法

var_list是字典形式{變量名字元串: 變量符号},相對應的restore也根據同樣形式的字典将ckpt中的字元串對應的變量加載給程式中的符号。

加載模型

當我們基于checkpoint檔案(ckpt)加載參數時,實際上我們使用Saver.restore取代了initializer的初始化

TensorFlow .ckpt 模型的儲存與恢複.ckpt一、TensorFlow正常模型加載方法

checkpoint檔案會記錄儲存資訊,通過它可以定位最新儲存的模型:

ckpt = tf.train.get_checkpoint_state('./model/')
print(ckpt.model_checkpoint_path)
           
TensorFlow .ckpt 模型的儲存與恢複.ckpt一、TensorFlow正常模型加載方法

.meta檔案儲存了目前圖結構

.index檔案儲存了目前參數名

.data檔案儲存了目前參數值

1. 不加載圖結構,隻加載參數

由于實際上我們參數儲存的都是Variable變量的值,是以其他的參數值(例如batch_size)等,我們在restore時可能希望修改,但是圖結構在train時一般就已經确定了,是以我們可以使用tf.Graph().as_default()建立一個預設圖(建議使用上下文環境),利用這個新圖修改和變量無關的參值大小,進而達到目的。

'''
使用原網絡儲存的模型加載到自己重新定義的圖上
可以使用python變量名加載模型,也可以使用節點名
'''
import AlexNet as Net
import AlexNet_train as train
import random
import tensorflow as tf
 
IMAGE_PATH = './flower_photos/daisy/5673728_71b8cb57eb.jpg'
 
with tf.Graph().as_default() as g:
 
    x = tf.placeholder(tf.float32, [1, train.INPUT_SIZE[0], train.INPUT_SIZE[1], 3])
    y = Net.inference_1(x, N_CLASS=5, train=False)
 
    with tf.Session() as sess:
        # 程式前面得有 Variable 供 save or restore 才不報錯
        # 否則會提示沒有可儲存的變量
        saver = tf.train.Saver()
 
        ckpt = tf.train.get_checkpoint_state('./model/')
        img_raw = tf.gfile.FastGFile(IMAGE_PATH, 'rb').read()
        img = sess.run(tf.expand_dims(tf.image.resize_images(
            tf.image.decode_jpeg(img_raw),[224,224],method=random.randint(0,3)),0))
 
        if ckpt and ckpt.model_checkpoint_path:
            print(ckpt.model_checkpoint_path)
            saver.restore(sess,'./model/model.ckpt-0')
            global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
            res = sess.run(y, feed_dict={x: img})
            print(global_step,sess.run(tf.argmax(res,1)))
           

2. 加載圖結構和參數

'''
直接使用使用儲存好的圖
無需加載python定義的結構,直接使用節點名稱加載模型
由于節點形狀已經定下來了,是以有不便之處,placeholder定義batch後單張傳會報錯
現階段不推薦使用,以後如果了解深入了可能會找到使用方法
'''
import AlexNet_train as train
import random
import tensorflow as tf
 
IMAGE_PATH = './flower_photos/daisy/5673728_71b8cb57eb.jpg'
 
 
ckpt = tf.train.get_checkpoint_state('./model/')                          # 通過檢查點檔案鎖定最新的模型
saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path +'.meta')   # 載入圖結構,儲存在.meta檔案中
 
with tf.Session() as sess:
    saver.restore(sess,ckpt.model_checkpoint_path)                        # 載入參數,參數儲存在兩個檔案中,不過restore會自己尋找
 
    img_raw = tf.gfile.FastGFile(IMAGE_PATH, 'rb').read()
    img = sess.run(tf.image.resize_images(
        tf.image.decode_jpeg(img_raw), train.INPUT_SIZE, method=random.randint(0, 3)))
    imgs = []
    for i in range(128):
       imgs.append(img)
    print(sess.run(tf.get_default_graph().get_tensor_by_name('fc3:0'),feed_dict={'Placeholder:0': imgs}))
 
    '''
    img = sess.run(tf.expand_dims(tf.image.resize_images(
        tf.image.decode_jpeg(img_raw), train.INPUT_SIZE, method=random.randint(0, 3)), 0))
    print(img)
    imgs = []
    for i in range(128):
        imgs.append(img)
    print(sess.run(tf.get_default_graph().get_tensor_by_name('conv1:0'),
                   feed_dict={'Placeholder:0':img}))
           

注意,在所有兩種方式中都可以通過調用節點名稱使用節點輸出張量,節點.name屬性傳回節點名稱。

3. 簡化版本

# 連同圖結構一同加載
ckpt = tf.train.get_checkpoint_state('./model/')
saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path +'.meta')
with tf.Session() as sess:
    saver.restore(sess,ckpt.model_checkpoint_path)
             
# 隻加載資料,不加載圖結構,可以在新圖中改變batch_size等的值
# 不過需要注意,Saver對象執行個體化之前需要定義好新的圖結構,否則會報錯
saver = tf.train.Saver()
with tf.Session() as sess:
    ckpt = tf.train.get_checkpoint_state('./model/')
    saver.restore(sess,ckpt.model_checkpoint_path)
           

參考:疊加态的貓-『TensorFlow』模型載入方法彙總

繼續閱讀