天天看点

TensorFlow学习记录:读取ckpt模型里面的张量名字和读取pb模型里面的张量名字

一般情况下,我们得到一个模型后都想知道模型里面的张量,下面分别从ckpt模型和pb模型中读取里面的张量名字。

1.读取ckpt模型里面的张量

首先,ckpt模型需包含以下文件,一个都不能少

TensorFlow学习记录:读取ckpt模型里面的张量名字和读取pb模型里面的张量名字

然后编写代码,将所有张量的名字都保存到tensor_name_list_ckpt.txt文件中

import tensorflow as tf

#直接读取图的结构,不需要手动重新定义 
meta_graph = tf.train.import_meta_graph("model.ckpt.meta")

with tf.Session()as sess:
	meta_graph.restore(sess,"D:/Face_recognition_github/20180402-114759/model.ckpt")

	tensor_name_list = [tensor.name for tensor in tf.get_default_graph().as_graph_def().node]
	with open("tensor_name_list_ckpt.txt",'a+')as f:
		for tensor_name in tensor_name_list:
			f.write(tensor_name+"\n")
			# print(tensor_name,'\n')
		f.close()
           

运行结果截图(部分)

TensorFlow学习记录:读取ckpt模型里面的张量名字和读取pb模型里面的张量名字

2.读取pb模型里面的张量

需要一个pb文件

TensorFlow学习记录:读取ckpt模型里面的张量名字和读取pb模型里面的张量名字

编写代码

import tensorflow as tf

model_path = "D:/Face_recognition_github/20180402-114759/20180402-114759.pb"

with tf.gfile.FastGFile(model_path,'rb')as f:
	graph_def = tf.GraphDef()
	graph_def.ParseFromString(f.read())
	tf.import_graph_def(graph_def,name='')

	tensor_name_list = [tensor.name for tensor in tf.get_default_graph().as_graph_def().node]
	with open('tensor_name_list_pb.txt','a')as t:
		for tensor_name in tensor_name_list:
			t.write(tensor_name+'\n')
			print(tensor_name,'\n')
		t.close()
           

顺便再查看pb模型里面的张量的属性(ckpt模型的操作类似),保存到txt文件中

import tensorflow as tf

model_path = "/home/boss/Study/face_recognition_flask/20180402-114759/model.pb"

with tf.gfile.FastGFile(model_path,'rb')as f:
	graph_def = tf.GraphDef()
	graph_def.ParseFromString(f.read())
	tf.import_graph_def(graph_def,name='')

	# tensor_name_list = [tensor.name for tensor in tf.get_default_graph().as_graph_def().node]
	# with open('tensor_name_list_pb.txt','a')as t:
	# 	for tensor_name in tensor_name_list:
	# 		t.write(tensor_name+'\n')
	# 		print(tensor_name,'\n')
	# 	t.close()
	with tf.Session()as sess:
		op_list = sess.graph.get_operations()
		with open("model里面张量的属性.txt",'a+')as f:
			for index,op in enumerate(op_list):
				f.write(str(op.name)+"\n")                   #张量的名称
				f.write(str(op.values())+"\n")              #张量的属性
           

运行结果截图(部分)

TensorFlow学习记录:读取ckpt模型里面的张量名字和读取pb模型里面的张量名字

继续阅读