天天看點

tensorboard ckpt pb 模型的輸出節點_FlyAI小課堂:Tensorflow-模型和資料的儲存和載入...

tensorboard ckpt pb 模型的輸出節點_FlyAI小課堂:Tensorflow-模型和資料的儲存和載入...
tensorboard ckpt pb 模型的輸出節點_FlyAI小課堂:Tensorflow-模型和資料的儲存和載入...
在開始學習之前推薦大家可以多在FlyAI競賽服務平台多參加訓練和競賽,以此來提升自己的能力。FlyAI是為AI開發者提供資料競賽并支援GPU離線訓練的一站式服務平台。每周免費提供項目開源算法樣例,支援算法能力變現以及快速的疊代算法模型。

目錄

  1. 基本方法
  2. 不需重新定義網絡結構的方法
  3. saved_model方式

附件一:sklearn上的用法

一、基本方法

1.1 儲存

  • 定義變量
  • 使用saver.save()方法儲存
import tensorflow as tf
import numpy as np
W = tf.Variable([[1,1,1],[2,2,2]],dtype = tf.float32,name='w')
b = tf.Variable([[0,1,2]],dtype = tf.float32,name='b')
 
init = tf.initialize_all_variables()
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(init)
    save_path = saver.save(sess,"save/model.ckpt")
           

1.2 載入

  • 定義變量
  • 使用saver.restore()方法載入
import tensorflow as tf
import numpy as np
W = tf.Variable(tf.truncated_normal(shape=(2,3)),dtype = tf.float32,name='w')
b = tf.Variable(tf.truncated_normal(shape=(1,3)),dtype = tf.float32,name='b')
 
saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess,"save/model.ckpt")
           

1.3 說明

  1. 建立saver時,可以指定需要存儲的tensor,如果沒有指定,則全部儲存;

2. 預設情況下:saver.save(sess,"save/model.ckpt")産生4個檔案:

tensorboard ckpt pb 模型的輸出節點_FlyAI小課堂:Tensorflow-模型和資料的儲存和載入...

checkpoint檔案儲存最新的模型;

model.ckpt.data 以字典的形式儲存權重偏置項等訓練參數

model.ckpt.index:存儲訓練好的參數索引

model.ckpt.meta : 元檔案(meta) 中儲存了MetaGraphDef 的持久化資料,即模型資料,計算圖的網絡結構資訊,完整的graph、variables、operation、collection。

3. 如何知道tensor的名字,最好是定義tensor的時候就指定名字,如上面代碼中的name='w',如果你沒有定義name,tensorflow也會設定name,隻不過這個name就是根據你的tensor或者操作的性質。是以最好還是自己定義好name。

【說明:這種方法不友善的在于,在使用模型的時候,必須把模型的結構重新定義一遍,然後載入對應名字的變量的值。但是很多時候我們都更希望能夠讀取一個檔案然後就直接使用模型,而不是還要把模型重新定義一遍。是以就需要使用另一種方法。】

二、不需重新定義網絡結構的方法

tf.train.import_meta_graph

import_meta_graph(

meta_graph_or_file,

clear_devices=False,

import_scope=None,

**kwargs

)

這個方法可以從檔案中将儲存的graph的所有節點加載到目前的default graph中,并傳回一個saver。也就是說,我們在儲存的時候,除了将變量的值儲存下來,其實還有将對應graph中的各種節點儲存下來,是以模型的結構也同樣被儲存下來了。

比如我們想要儲存計算最後預測結果的y,則應該在訓練階段将它添加到collection中。具體代碼如下:

2.1 儲存

和1.1一樣,保持不變

2.2 載入

import tensorflow as tf
import numpy as np
# W = tf.Variable(tf.truncated_normal(shape=(2,3)),dtype = tf.float32,name='w')
# b = tf.Variable(tf.truncated_normal(shape=(1,3)),dtype = tf.float32,name='b')
 
# saver = tf.train.Saver()
with tf.Session() as sess:
    new_saver = tf.train.import_meta_graph("save/model.ckpt.meta")
    new_saver.restore(sess, "save/model.ckpt")
           

【個人了解:model.ckpt.meta : 儲存了TensorFlow計算圖的網絡結構資訊,import_meta_graph("save/model.ckpt.meta")這句拉取了結構,故不用重新定義。】

三、saved_model方式

實作了 (y = x + b)當輸入一個x 那麼輸出的結果y就等于輸入x加上b。

3.1 儲存
# Author:yifan
import os
import tensorflow as tf # 以下所有代碼預設導入
from tensorflow.python.saved_model.signature_def_utils_impl import predict_signature_def
# 儲存模型路徑
PATH = './models'
# 建立一個變量
one = tf.Variable(3.0)
# 建立一個占位符,在 Tensorflow 中需要定義 placeholder 的 type ,一般為 float32 形式
num = tf.placeholder(tf.float32,name='input')
# 建立一個加法步驟,注意這裡并沒有直接計算
sum = tf.add(num,one,name='output')
# 初始化變量,如果定義Variable就必須初始化
init = tf.global_variables_initializer()
# 建立會話sess
with tf.Session() as sess:
    sess.run(init)
    print(sess.run(sum, feed_dict={num: 5.0}))
    # #儲存SavedModel模型,使用以下三句
    builder = tf.saved_model.builder.SavedModelBuilder(PATH)
    signature = predict_signature_def(inputs={'input':num}, outputs={'output':sum})
    builder.add_meta_graph_and_variables(sess,[tf.saved_model.tag_constants.SERVING],signature_def_map={'predict': signature})
    builder.save()
           
說明:
  1. tf.saved_model.builder.SavedModelBuilder:該方法的參數是傳入用于儲存模型的目錄名,目錄不用預先建立
  2. predict_signature_def:将輸入節點、輸出節點和名字(sig_name)傳入,生成一個簽名對象。傳入的參數為輸入和輸出以及他們的name。
  3. add_meta_graph_and_variables:将簽名加入到模型中
  4. 第一個參數傳入的是Session它包含了目前graph(圖)和Variables(變量)。
  5. 第二個參數是給目前需要儲存的MetaGraph 一個标簽,标簽名可以自定義,在之後載入模型的時候,需要根據這個标簽名去查找對應的MetaGraphDef,找不到就會報如 RuntimeError: MetaGraphDef associated with tags 'foo' could not be found in SavedModel這樣的錯。

---- 标簽也可以選用系統定義好的參數,tf.saved_model.tag_constants.SERVING與 tf.saved_model.tag_constants.TRAINING等。

運作結果:8.0,和儲存的模型:
tensorboard ckpt pb 模型的輸出節點_FlyAI小課堂:Tensorflow-模型和資料的儲存和載入...
  1. 執行完成後會在目前項目的目錄下生成models檔案夾,裡面包含variables檔案夾以及saved_model.pb檔案。
  2. variables儲存所有變量資訊,
  3. saved_model.pb用于儲存模型結構等資訊,含圖形結構。

注意:目前目錄下不可以存在models檔案夾,否則會報錯。

3.2 載入

# Author:yifan
import tensorflow as tf # 以下所有代碼預設導入
PATH = './models'
with tf.Session() as sess:
  tf.saved_model.loader.load(sess, ["serve"], PATH)
#一種載入變量的方式:
  in_x =tf.saved_model.loader.load(sess, ["serve"], PATH).signature_def['predict'].inputs['input'].name
#另一種載入變量的方式:
# in_x = sess.graph.get_tensor_by_name('input:0')     #加載輸入變量
  y = sess.graph.get_tensor_by_name('output:0')       #加載輸出變量
  scores = sess.run(y, feed_dict={in_x: 3.})
  print(scores)
           
說明:
  1. tf.saved_model.loader.load方法加載模型,第二個參數["serve"]為TAG标簽與存模型時候指定的字段相同(tf.saved_model.tag_constants.SERVING = "serve",本文中調用了tf的定義),第三個參數為模型路徑;
  2. tf.saved_model.loader.load(sess, ["serve"], PATH).signature_def['predict'].inputs['input'].name:用signature_def方法從導入的模型中提取簽名。和3)作用是一樣的。
  3. sess.graph.get_tensor_by_name:加載輸入輸出變量,注意這裡的變量name都需要加上":0",如"input"變為"input:0"
  4. 最後像之前那樣sess.run(),feed喂入資料,這裡輸入了個3.0。
結果:

6.0

3.3 檢視模型的Signature簽名

傳統的導入 需要用get_tensor_by_name , 這樣就需要記錄tensor的name熟悉,很麻煩。通過signature,我們可以指定變量的别名,友善存取。但如果我們拿到了别人的含有signature一個SavedModel模型而且并不知道"标簽"那麼怎麼調用呢?

---Tensorflow官方已經為我們準備好了一個腳本,tensorflow下的

saved_model_cli.py

檔案可以幫到。

我們可以'WIN+R'輸入'cmd'然後回車打開你的CMD,然後指定路徑到你的模型目錄下,運作:

saved_model_cli show --dir=./ --all

tensorboard ckpt pb 模型的輸出節點_FlyAI小課堂:Tensorflow-模型和資料的儲存和載入...
列印出的資訊中我們就可以看到模型的輸入/輸出的名稱、資料類型、shape以及方法名稱。

附件一:sklearn上的用法

儲存參數:

from sklearn.externals import joblib

joblib.dump((centres, des_list,img_features), "imgs_features.pkl", compress=3)

讀取參數:

centres, des_list, img_features = joblib.load("imgs_features.pkl") #讀取儲存的特征

參考文章

【1】

TensorFlow 模型儲存/載入的兩種方法​blog.csdn.net

【2】

【tensorflow】儲存模型、再次加載模型等操作_I am what i am-CSDN部落格_儲存模型​blog.csdn.net

tensorboard ckpt pb 模型的輸出節點_FlyAI小課堂:Tensorflow-模型和資料的儲存和載入...

【3】

TensorFlow saved_model 子產品​blog.csdn.net

【4】

Tensorflow學習筆記(二)模型的儲存與加載(一 )​blog.csdn.net

tensorboard ckpt pb 模型的輸出節點_FlyAI小課堂:Tensorflow-模型和資料的儲存和載入...
更多精彩内容請通路FlyAI-AI競賽服務平台;為AI開發者提供資料競賽并支援GPU離線訓練的一站式服務平台;每周免費提供項目開源算法樣例,支援算法能力變現以及快速的疊代算法模型。 挑戰者,都在FlyAI!!!

繼續閱讀