天天看點

TensorFlow學習記錄:讀取ckpt模型裡面的張量名字和讀取pb模型裡面的張量名字

一般情況下,我們得到一個模型後都想知道模型裡面的張量,下面分别從ckpt模型和pb模型中讀取裡面的張量名字。

1.讀取ckpt模型裡面的張量

首先,ckpt模型需包含以下檔案,一個都不能少

TensorFlow學習記錄:讀取ckpt模型裡面的張量名字和讀取pb模型裡面的張量名字

然後編寫代碼,将所有張量的名字都儲存到tensor_name_list_ckpt.txt檔案中

import tensorflow as tf

#直接讀取圖的結構,不需要手動重新定義 
meta_graph = tf.train.import_meta_graph("model.ckpt.meta")

with tf.Session()as sess:
	meta_graph.restore(sess,"D:/Face_recognition_github/20180402-114759/model.ckpt")

	tensor_name_list = [tensor.name for tensor in tf.get_default_graph().as_graph_def().node]
	with open("tensor_name_list_ckpt.txt",'a+')as f:
		for tensor_name in tensor_name_list:
			f.write(tensor_name+"\n")
			# print(tensor_name,'\n')
		f.close()
           

運作結果截圖(部分)

TensorFlow學習記錄:讀取ckpt模型裡面的張量名字和讀取pb模型裡面的張量名字

2.讀取pb模型裡面的張量

需要一個pb檔案

TensorFlow學習記錄:讀取ckpt模型裡面的張量名字和讀取pb模型裡面的張量名字

編寫代碼

import tensorflow as tf

model_path = "D:/Face_recognition_github/20180402-114759/20180402-114759.pb"

with tf.gfile.FastGFile(model_path,'rb')as f:
	graph_def = tf.GraphDef()
	graph_def.ParseFromString(f.read())
	tf.import_graph_def(graph_def,name='')

	tensor_name_list = [tensor.name for tensor in tf.get_default_graph().as_graph_def().node]
	with open('tensor_name_list_pb.txt','a')as t:
		for tensor_name in tensor_name_list:
			t.write(tensor_name+'\n')
			print(tensor_name,'\n')
		t.close()
           

順便再檢視pb模型裡面的張量的屬性(ckpt模型的操作類似),儲存到txt檔案中

import tensorflow as tf

model_path = "/home/boss/Study/face_recognition_flask/20180402-114759/model.pb"

with tf.gfile.FastGFile(model_path,'rb')as f:
	graph_def = tf.GraphDef()
	graph_def.ParseFromString(f.read())
	tf.import_graph_def(graph_def,name='')

	# tensor_name_list = [tensor.name for tensor in tf.get_default_graph().as_graph_def().node]
	# with open('tensor_name_list_pb.txt','a')as t:
	# 	for tensor_name in tensor_name_list:
	# 		t.write(tensor_name+'\n')
	# 		print(tensor_name,'\n')
	# 	t.close()
	with tf.Session()as sess:
		op_list = sess.graph.get_operations()
		with open("model裡面張量的屬性.txt",'a+')as f:
			for index,op in enumerate(op_list):
				f.write(str(op.name)+"\n")                   #張量的名稱
				f.write(str(op.values())+"\n")              #張量的屬性
           

運作結果截圖(部分)

TensorFlow學習記錄:讀取ckpt模型裡面的張量名字和讀取pb模型裡面的張量名字

繼續閱讀