當訓練好DeeplabV3+模型後,生成了.ckpt檔案,下一步希望利用模型進行真實的場景預測,通用的做法生成.pb檔案。這樣做的好處是:1. 将變量轉換為常量,減小模型,2. 便于其他語言調用(.pb檔案可直接被C/C++/Java/NodeJS等讀取)。
運作 export_model.py 生成模型
利用官方代碼檔案export_model.py生成 frozen_inference_graph.pb 檔案,利用該檔案進行預測。這裡需要注意的是:必須知道模型的input和output,這可以通過檢視代碼獲得。
python export_model.py \
--checkpoint_path="./checkpoint_1/model.ckpt-518495" \ # 訓練得到的ckpt檔案
--export_path="./output_model/frozen_inference_graph.pb" \ # 需要導出的模型名稱
--model_variant="xception_65" \
--atrous_rates=6 \
--atrous_rates=12 \
--atrous_rates=18 \
--output_stride=16 \
--decoder_output_stride=4 \
--num_classes=3 \
--crop_size=1440 \ # 需要預測圖檔的大小,如果預測的圖像大小比該值大,将報錯
--crop_size=1440 \
--inference_scales=1.0
源碼清楚顯示:input_name是'ImageTensor',shape是[1, None, None, 3],資料類型是tf.uint8,你也可以在此處更改資料類型,output_name是 'SemanticPredictions'。知道了input和outp,就可以進行預測了。
# export_model.py部分代碼
# Input name of the exported model.
_INPUT_NAME = 'ImageTensor'
# Output name of the exported model.
_OUTPUT_NAME = 'SemanticPredictions'
def _create_input_tensors():
"""Creates and prepares input tensors for DeepLab model.
This method creates a 4-D uint8 image tensor 'ImageTensor' with shape
[1, None, None, 3]. The actual input tensor name to use during inference is
'ImageTensor:0'.
"""
# input_preprocess takes 4-D image tensor as input.
input_image = tf.placeholder(tf.uint8, [1, None, None, 3], name=_INPUT_NAME)
預測單張圖檔
利用生成的.pb檔案預測新的圖檔:
1. 讀取圖檔并轉換為uint8,shape為[1, None, None, 3]格式;
2. 讀取.pb檔案,指明輸入和輸出;
3.求輸出,輸出的label為0, 1, 2…,是以看上出全黑;
4. 結果後處理,這一步就因人而異了
import tensorflow as tf
from keras.preprocessing.image import load_img, img_to_array
img = load_img(img_path) # 輸入預測圖檔的url
img = img_to_array(img)
img = np.expand_dims(img, axis=0).astype(np.uint8) # uint8是之前導出模型時定義的
# 加載模型
sess = tf.Session()
with open("frozen_inference_graph.pb", "rb") as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
output = tf.import_graph_def(graph_def, input_map={"ImageTensor:0": img},
return_elements=["SemanticPredictions:0"])
# input_map 就是指明 輸入是什麼;
# return_elements 就是指明輸出是什麼;兩者在前面已介紹
result = sess.run(output)
print(result[0].shape) # (1, height, width)
結果展示:
我的工況是一個三分類問題,輸入圖檔1040X868,在個人筆記本上,預測比較慢:40s,部署在伺服器上,0.4s.
Tips-注意TF版本:
在預測時,經常出現記憶體溢出的問題,但模型隻有157MB,記憶體為16GB,一直不得解。原Tensorflow是1.8.0,從Github某地方下載下傳,CUDA用9.1版本。後來下載下傳官方Tensorflow1.8.0,CUDA支援9.0版本,不得不重新安裝CUDA,記憶體溢出問題消失。是以:請從正規管道下載下傳軟體。