Problem details:

```
Caused by op u'save/RestoreV2', defined at:
  File "demo.py", line 25, in <module>
    result_dict = news_demo.newsAggreg({image_path})
  File "/home/rszj/liutao/news_aggreg/news_demo.py", line 32, in newsAggreg
    predict = news_predict.run(images_path)
  File "/home/rszj/liutao/news_aggreg/news_predict.py", line 179, in run
    saver = tf.train.Saver(restore_dict) # when you want to save model
  File "/home/rszj/liutao/virtualenv/liutao_py2/mpy2tf1.2/local/lib/python2.7/site-packages/tensorflow/python/training/saver.py", line 1139, in __init__
    self.build()
  File "/home/rszj/liutao/virtualenv/liutao_py2/mpy2tf1.2/local/lib/python2.7/site-packages/tensorflow/python/training/saver.py", line 1170, in build
    restore_sequentially=self._restore_sequentially)
  File "/home/rszj/liutao/virtualenv/liutao_py2/mpy2tf1.2/local/lib/python2.7/site-packages/tensorflow/python/training/saver.py", line 691, in build
    restore_sequentially, reshape)
  File "/home/rszj/liutao/virtualenv/liutao_py2/mpy2tf1.2/local/lib/python2.7/site-packages/tensorflow/python/training/saver.py", line 407, in _AddRestoreOps
    tensors = self.restore_op(filename_tensor, saveable, preferred_shard)
  File "/home/rszj/liutao/virtualenv/liutao_py2/mpy2tf1.2/local/lib/python2.7/site-packages/tensorflow/python/training/saver.py", line 247, in restore_op
    [spec.tensor.dtype])[0])
  File "/home/rszj/liutao/virtualenv/liutao_py2/mpy2tf1.2/local/lib/python2.7/site-packages/tensorflow/python/ops/gen_io_ops.py", line 640, in restore_v2
    dtypes=dtypes, name=name)
  File "/home/rszj/liutao/virtualenv/liutao_py2/mpy2tf1.2/local/lib/python2.7/site-packages/tensorflow/python/framework/op_def_library.py", line 767, in apply_op
    op_def=op_def)
  File "/home/rszj/liutao/virtualenv/liutao_py2/mpy2tf1.2/local/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 2506, in create_op
    original_op=self._default_original_op, op_def=op_def)
  File "/home/rszj/liutao/virtualenv/liutao_py2/mpy2tf1.2/local/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1269, in __init__
    self._traceback = _extract_stack()
NotFoundError (see above for traceback): Key LSTM/basic_lstm_cell/bias not found in checkpoint
     [[Node: save/RestoreV2 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/cpu:0"](_arg_save/Const_0_0, save/RestoreV2/tensor_names, save/RestoreV2/shape_and_slices)]]
     [[Node: save/RestoreV2_26/_101 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/gpu:0", send_device="/job:localhost/replica:0/task:0/cpu:0", send_device_incarnation=1, tensor_name="edge_212_save/RestoreV2_26", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/gpu:0"]()]]
```
Reference 1: a fix based on the following blog post
"Error when loading a directory with saver.restore in TensorFlow 1.x"
In the original news_predict.py source on Ubuntu:
```python
saver = tf.train.Saver(restore_dict)
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
saver.restore(sess, r'model/model.ckpt')
```
Under TF 1.0 this runs correctly, with the following result:
Following that blog post, I changed the code above on Ubuntu to:
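```python
saver = tf.train.Saver(restore_dict)
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
# Pull the ckpt path out into its own variable.
# Note: tf.train.latest_checkpoint expects a checkpoint *directory*; if the path is
# not a directory containing a "checkpoint" file, it returns None.
module_file = tf.train.latest_checkpoint('news_tf_model/model.ckpt')
if module_file is not None:  # added check on the ckpt path
    saver.restore(sess, module_file)
```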
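With this change, TF 1.0 and TF 1.2 both produce the following result:
Analysis of Reference 1:
The check `if module_file is not None` only tests whether a checkpoint was found. It does nothing about the underlying LSTM error. When no checkpoint is found, `saver.restore(sess, module_file)` never runs, so the model predicts with freshly initialized weights and the result comes out empty. Also note that `tf.train.latest_checkpoint` expects a checkpoint directory. Given `news_tf_model/model.ckpt`, which is a file prefix rather than a directory holding a `checkpoint` file, it returns `None`, which is likely why the restore was skipped here.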
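A stricter alternative to the `is not None` guard is to check that the checkpoint files exist and raise if they do not. A bad path then fails at load time instead of silently producing empty predictions. The sketch below assumes the V2 checkpoint layout that `tf.train.Saver` writes by default in TF 1.x (`<prefix>.index` plus `<prefix>.data-*` files), and it also accepts a legacy single-file V1 checkpoint. `require_checkpoint` is a hypothetical helper, not a TensorFlow API.

```python
import os


def require_checkpoint(prefix):
    """Return `prefix` if a checkpoint exists there, otherwise raise.

    Hypothetical helper. A V2 checkpoint saved under 'model.ckpt' consists of
    'model.ckpt.index' and 'model.ckpt.data-*' files; a legacy V1 checkpoint
    is a single file named exactly `prefix`.
    """
    if os.path.isfile(prefix + ".index") or os.path.isfile(prefix):
        return prefix
    # IOError is used so the same code works on Python 2.7 and Python 3.
    raise IOError("no checkpoint found at prefix %r" % prefix)
```

In news_predict.py this would be used as `saver.restore(sess, require_checkpoint('news_tf_model/model.ckpt'))`. With a valid path, the restore actually runs, and the real `NotFoundError` about the LSTM key surfaces instead of being hidden.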
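Problem analysis:
Reading the message carefully, the error is "Key LSTM/basic_lstm_cell/bias not found in checkpoint". The graph built under TF 1.2 asks for an LSTM bias variable that the checkpoint does not contain, so the problem must lie in how the LSTM bias is named. I printed the `restore_dict` passed to `saver = tf.train.Saver(restore_dict)` and found that the parameter names differ between TF 1.0 and TF 1.2, as shown in the table below: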
| TF 1.0 | TF 1.2 |
|---|---|
| `lstm/basic_lstm_cell/weights` | `lstm/basic_lstm_cell/kernel` |
| `lstm/basic_lstm_cell/biases` | `lstm/basic_lstm_cell/bias` |
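So the problem really was in the LSTM: TF 1.0 and TF 1.2 name the LSTM variables differently. To use a ckpt trained under TF 1.0 with TF 1.2, these two parameters have to be mapped to each other.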
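To spot such mismatches directly, you can compare the keys stored in the checkpoint with the keys the Saver will look up. The sketch below is plain Python, so it runs without TensorFlow. The commented lines show, as an assumption, how the two name lists could be obtained in TF 1.x using `tf.train.NewCheckpointReader` (the same reader used later in this post) and the `restore_dict` from news_predict.py. `diff_variable_names` is a hypothetical helper.

```python
def diff_variable_names(ckpt_names, graph_names):
    """Compare checkpoint keys with the keys a Saver will try to restore.

    Returns (missing, unused): keys the Saver asks for that the checkpoint
    lacks, and checkpoint keys that nothing in the graph asks for.
    """
    ckpt = set(ckpt_names)
    graph = set(graph_names)
    missing = sorted(graph - ckpt)
    unused = sorted(ckpt - graph)
    return missing, unused


# Assumed TF 1.x usage (not executed here):
#   reader = tf.train.NewCheckpointReader('model/model.ckpt')
#   ckpt_names = reader.get_variable_to_shape_map().keys()
#   graph_names = restore_dict.keys()   # the names the Saver restores by
#   print(diff_variable_names(ckpt_names, graph_names))
```

A pair such as `.../biases` in `unused` next to `.../bias` in `missing` points straight at a rename like the one in the table.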
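There are two ways to map the parameters. The first is to rename the two LSTM variables inside the ckpt model. The second is to map them at predict time to the LSTM variable names of the TF version being used.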
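Both methods rely on the same two-entry name mapping from the table. Here is a minimal sketch of that mapping in both directions. It assumes the two `basic_lstm_cell` variables are the only ones that changed (as in this model). It keeps the scope prefix, so it handles `LSTM/...` as well as `lstm/...`. The function name and the version strings `'1.0'`/`'1.2'` are illustrative, not part of TensorFlow.

```python
# TF 1.0 suffix -> TF 1.2 suffix for BasicLSTMCell variables (from the table).
TF10_TO_TF12 = {
    "basic_lstm_cell/weights": "basic_lstm_cell/kernel",
    "basic_lstm_cell/biases": "basic_lstm_cell/bias",
}
TF12_TO_TF10 = {new: old for old, new in TF10_TO_TF12.items()}


def translate_lstm_name(name, to_version):
    """Rewrite an LSTM variable name into the naming of `to_version`.

    The scope prefix (e.g. 'LSTM/') is preserved; names that do not end in a
    known suffix are returned unchanged.
    """
    if to_version == "1.2":
        table = TF10_TO_TF12
    elif to_version == "1.0":
        table = TF12_TO_TF10
    else:
        raise ValueError("to_version must be '1.0' or '1.2'")
    for old_suffix, new_suffix in table.items():
        # Match a whole path component so 'biases' is never mistaken for 'bias'.
        if name == old_suffix or name.endswith("/" + old_suffix):
            return name[: len(name) - len(old_suffix)] + new_suffix
    return name
```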
Reference 2:
First, the first method. Following the blog post "Testing image captioning (model/im2txt) on TensorFlow 1.0", I modified the code as follows:
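```python
import tensorflow as tf


# The variable names differ between versions, so they must be renamed.
# Note: im2txt uses the scope "lstm"; the model in this post uses "LSTM"
# (see the error message), so adjust the keys below to match your model.
def RenameCkpt(checkpoint_path, output_path):
    # 1.0.1 name : 1.2.1 name
    vars_to_rename = {
        "lstm/basic_lstm_cell/weights": "lstm/basic_lstm_cell/kernel",
        "lstm/basic_lstm_cell/biases": "lstm/basic_lstm_cell/bias",
    }
    new_checkpoint_vars = {}
    reader = tf.train.NewCheckpointReader(checkpoint_path)
    # Build the copies in a fresh graph so they cannot clash with a loaded model.
    with tf.Graph().as_default():
        for old_name in reader.get_variable_to_shape_map():
            new_name = vars_to_rename.get(old_name, old_name)
            new_checkpoint_vars[new_name] = tf.Variable(reader.get_tensor(old_name))
        init = tf.global_variables_initializer()
        saver = tf.train.Saver(new_checkpoint_vars)  # dict keys become the saved names
        with tf.Session() as sess:
            sess.run(init)
            saver.save(sess, output_path)
    print("checkpoint file rename successful... ")


# Original call site:
# RenameCkpt(FLAGS.checkpoint_path,
#            "/home/ndscbigdata/work/change/tf/gan/im2txt/ckpt/newmodel.ckpt-2000000")
```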
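The method above renames `lstm/basic_lstm_cell/weights` and `lstm/basic_lstm_cell/biases` in the ckpt model to `lstm/basic_lstm_cell/kernel` and `lstm/basic_lstm_cell/bias`. The renamed ckpt works only on 1.2.1. For the same reason as before, the renamed parameters no longer match 1.0.1, so it can no longer run there.
Reference 3:
Given the above, I tried method two. The straightforward approach is to rename `lstm/basic_lstm_cell/kernel` and `lstm/basic_lstm_cell/bias` inside `restore_dict`: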
```python
restore_dict = {}
for i in variables[:]:  # the first is global step
    # restore_dict[i.name.replace(':0', '')] = i
    if i.name.replace(':0', '') == 'LSTM/basic_lstm_cell/biases':
        print('LSTM/basic_lstm_cell/bias========================================')
        # The shape argument was left empty ([,]) in the original; i.get_shape() is assumed.
        restore_dict[i.name.replace('LSTM/basic_lstm_cell/biases:0',
                                    'LSTM/basic_lstm_cell/bias')] = tf.get_variable(
            'LSTM/basic_lstm_cell/bias', shape=i.get_shape())
    elif i.name.replace(':0', '') == 'LSTM/basic_lstm_cell/weights':
        print('LSTM/basic_lstm_cell/kernel========================================')
        restore_dict[i.name.replace('LSTM/basic_lstm_cell/weights:0',
                                    'LSTM/basic_lstm_cell/kernel')] = tf.get_variable(
            'LSTM/basic_lstm_cell/kernel', shape=i.get_shape())
    else:
        restore_dict[i.name.replace(':0', '')] = i
```
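I expected this to work with both 1.0.1 and 1.2.1, but a problem appeared. Predicting multiple labels for the same image under tf1.0.1 and tf1.2.1 gives the results in the two figures below.
Figure 1: result under tf1.2.1 (this is the correct result)
Figure 2: result under tf1.0.1 (only the second label is predicted; the first one is lost)
As for why the label is lost: in my tests, renaming "lstm/basic_lstm_cell/weights" and "lstm/basic_lstm_cell/biases" to match the current TF version had no effect. You can confirm this by commenting out the following two lines of code and running the program: the result is the same as above.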
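A different sketch of method two follows. It does not create new variables with `tf.get_variable`. New variables are not the ones the LSTM cell actually reads, which may be one reason the renaming above had no visible effect. Instead, it keys the Saver by the names actually present in the checkpoint and points each key at the existing graph variable. It is plain Python so it can be checked without TensorFlow. `build_restore_map` is hypothetical, and it assumes only the two `basic_lstm_cell` variables differ between versions.

```python
# Both directions of the BasicLSTMCell rename (TF 1.0 <-> TF 1.2).
_TF10_TO_TF12 = {
    "basic_lstm_cell/weights": "basic_lstm_cell/kernel",
    "basic_lstm_cell/biases": "basic_lstm_cell/bias",
}
_ALIASES = dict(_TF10_TO_TF12)
_ALIASES.update({new: old for old, new in _TF10_TO_TF12.items()})


def _other_version_name(name):
    """Return the other TF version's name for an LSTM variable, or None."""
    for suffix, other in _ALIASES.items():
        if name == suffix or name.endswith("/" + suffix):
            return name[: len(name) - len(suffix)] + other
    return None


def build_restore_map(graph_names, ckpt_names):
    """Return {checkpoint_key: graph_variable_name}, one entry per graph variable.

    Uses the graph name itself when the checkpoint has it, otherwise the
    other-version alias; raises KeyError if neither is in the checkpoint.
    """
    ckpt = set(ckpt_names)
    restore_map = {}
    for name in graph_names:
        if name in ckpt:
            restore_map[name] = name
            continue
        alt = _other_version_name(name)
        if alt is None or alt not in ckpt:
            raise KeyError("%s (or its other-version alias) not in checkpoint" % name)
        restore_map[alt] = name
    return restore_map
```

Under TF 1.x one could then write, as an untested assumption: `var_by_name = {v.op.name: v for v in tf.global_variables()}`, `ckpt_names = tf.train.NewCheckpointReader(path).get_variable_to_shape_map()`, and `saver = tf.train.Saver({k: var_by_name[g] for k, g in build_restore_map(var_by_name, ckpt_names).items()})`. `tf.train.Saver` accepts a dict from checkpoint names to variables, so the same graph can restore a checkpoint saved with either naming. Graph-only variables that the checkpoint never stored raise `KeyError` and would need to be filtered out first.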
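Summary:
In short, when a Saver is used across TF 1.0.1 and 1.2.1, every parameter in the ckpt model must correspond one-to-one with a parameter in the `restore_dict` used to initialize the saver. The two LSTM parameters `lstm/basic_lstm_cell/weights` and `lstm/basic_lstm_cell/biases` are the most likely to break this correspondence. Because they were renamed between versions, the names in the ckpt no longer match those in the custom `restore_dict` in the prediction code, and this error is raised.
Addendum:
I asked about this issue on the official TensorFlow GitHub repository, and team member skye pointed me to the following reference: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/rnn/python/tools/checkpoint_convert.py
This file contains the full checkpoint_convert code, which handles converting variable names between different TF versions.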