天天看點

Tensorflow |(5)模型儲存與恢複、自定義指令行參數

Tensorflow |(1)初識Tensorflow

Tensorflow |(2)張量的階和資料類型及張量操作

Tensorflow |(3)變量的的建立、初始化、儲存和加載

Tensorflow |(4)名稱域、圖 和會話

Tensorflow |(5)模型儲存與恢複、自定義指令行參數

模型儲存與恢複、自定義指令行參數、

在我們訓練或者測試過程中,總會遇到需要儲存訓練完成的模型,然後從中恢複繼續我們的測試或者其它使用。模型的儲存和恢複也是通過tf.train.Saver類去實作,它主要通過将Saver類添加OPS儲存和恢複變量到checkpoint。它還提供了運作這些操作的便利方法。

tf.train.Saver(var_list=None, reshape=False, sharded=False, max_to_keep=5, keep_checkpoint_every_n_hours=10000.0, name=None, restore_sequentially=False, saver_def=None, builder=None, defer_build=False, allow_empty=False, write_version=tf.SaverDef.V2, pad_step_number=False)

var_list:指定将要儲存和還原的變量。它可以作為一個dict或一個清單傳遞.

max_to_keep:訓示要保留的最近檢查點檔案的最大數量。建立新檔案時,會删除較舊的檔案。如果無或0,則保留所有檢查點檔案。預設為5(即保留最新的5個檢查點檔案。)

keep_checkpoint_every_n_hours:多久生成一個新的檢查點檔案。預設為10,000小時

儲存

儲存我們的模型需要調用Saver.save()方法。save(sess, save_path, global_step=None),checkpoint是專有格式的二進制檔案,将變量名稱映射到張量值。

————————————————

版權聲明:本文為CSDN部落客「DrugAI」的原創文章,遵循CC 4.0 BY-SA版權協定,轉載請附上原文出處連結及本聲明。

原文連結:

https://blog.csdn.net/u012325865/article/details/104346708

import tensorflow as tf
 
a = tf.Variable([[1.0,2.0]],name="a")
b = tf.Variable([[3.0],[4.0]],name="b")
c = tf.matmul(a,b)
 
saver=tf.train.Saver()
with tf.Session() as sess:
    tf.global_variables_initializer().run()
    print(sess.run(c))
    saver.save(sess, '/tmp/ckpt/test/matmul')      

我們可以看儲存了什麼檔案

Tensorflow |(5)模型儲存與恢複、自定義指令行參數

在多次訓練的時候可以指定多少間隔生成檢查點檔案

saver.save(sess, '/tmp/ckpt/test/matmu', global_step=0) ==> filename: 'matmu-0'
 
saver.save(sess, '/tmp/ckpt/test/matmu', global_step=1000) ==> filename: 'matmu-1000'      

恢複

恢複模型的方法是restore(sess, save_path),save_path是以前儲存參數的路徑,我們可以使用tf.train.latest_checkpoint來擷取最近的檢查點檔案(也惡意直接寫檔案目錄)

import tensorflow as tf
 
a = tf.Variable([[1.0,2.0]],name="a")
b = tf.Variable([[3.0],[4.0]],name="b")
c = tf.matmul(a,b)
 
saver=tf.train.Saver(max_to_keep=1)
with tf.Session() as sess:
    tf.global_variables_initializer().run()
    print(sess.run(c))
    saver.save(sess, '/tmp/ckpt/test/matmul')
 
    # 恢複模型
    model_file = tf.train.latest_checkpoint('/tmp/ckpt/test/')
    saver.restore(sess, model_file)
    print(sess.run([c], feed_dict={a: [[5.0,6.0]], b: [[7.0],[8.0]]}))      

自定義指令行參數

tf.app.run(),預設調用main()函數,運作程式。main(argv)必須傳一個參數。

tf.app.flags,它支援應用從指令行接受參數,可以用來指定叢集配置等。在tf.app.flags下面有各種定義參數的類型

DEFINE_string(flag_name, default_value, docstring)

DEFINE_integer(flag_name, default_value, docstring)

DEFINE_boolean(flag_name, default_value, docstring)

DEFINE_float(flag_name, default_value, docstring)

第一個也就是參數的名字,路徑、大小等等。第二個參數提供具體的值。第三個參數是說明文檔

tf.app.flags.FLAGS,在flags有一個FLAGS标志,它在程式中可以調用到我們前面具體定義的flag_name.

import tensorflow as tf
 
FLAGS = tf.app.flags.FLAGS
 
tf.app.flags.DEFINE_string('data_dir', '/tmp/tensorflow/mnist/input_data',
                           """資料集目錄""")
tf.app.flags.DEFINE_integer('max_steps', 2000,
                            """訓練次數""")
tf.app.flags.DEFINE_string('summary_dir', '/tmp/summary/mnist/convtrain',
                           """事件檔案目錄""")
 
 
def main(argv):
    print(FLAGS.data_dir)
    print(FLAGS.max_steps)
    print(FLAGS.summary_dir)
    print(argv)
 
 
if __name__=="__main__":
    tf.app.run()      

繼續閱讀