問題詳情:

Caused by op u’save/RestoreV2’, defined at:
File “demo.py”, line 25, in
result_dict = news_demo.newsAggreg({image_path})
File “/home/rszj/liutao/news_aggreg/news_demo.py”, line 32, in newsAggreg
predict = news_predict.run(images_path)
File “/home/rszj/liutao/news_aggreg/news_predict.py”, line 179, in run
saver = tf.train.Saver(restore_dict) # when you want to save model
File “/home/rszj/liutao/virtualenv/liutao_py2/mpy2tf1.2/local/lib/python2.7/site-packages/tensorflow/python/training/saver.py”, line 1139, in init
self.build()
File “/home/rszj/liutao/virtualenv/liutao_py2/mpy2tf1.2/local/lib/python2.7/site-packages/tensorflow/python/training/saver.py”, line 1170, in build
restore_sequentially=self._restore_sequentially)
File “/home/rszj/liutao/virtualenv/liutao_py2/mpy2tf1.2/local/lib/python2.7/site-packages/tensorflow/python/training/saver.py”, line 691, in build
restore_sequentially, reshape)
File “/home/rszj/liutao/virtualenv/liutao_py2/mpy2tf1.2/local/lib/python2.7/site-packages/tensorflow/python/training/saver.py”, line 407, in _AddRestoreOps
tensors = self.restore_op(filename_tensor, saveable, preferred_shard)
File “/home/rszj/liutao/virtualenv/liutao_py2/mpy2tf1.2/local/lib/python2.7/site-packages/tensorflow/python/training/saver.py”, line 247, in restore_op
[spec.tensor.dtype])[0])
File “/home/rszj/liutao/virtualenv/liutao_py2/mpy2tf1.2/local/lib/python2.7/site-packages/tensorflow/python/ops/gen_io_ops.py”, line 640, in restore_v2
dtypes=dtypes, name=name)
File “/home/rszj/liutao/virtualenv/liutao_py2/mpy2tf1.2/local/lib/python2.7/site-packages/tensorflow/python/framework/op_def_library.py”, line 767, in apply_op
op_def=op_def)
File “/home/rszj/liutao/virtualenv/liutao_py2/mpy2tf1.2/local/lib/python2.7/site-packages/tensorflow/python/framework/ops.py”, line 2506, in create_op
original_op=self._default_original_op, op_def=op_def)
File “/home/rszj/liutao/virtualenv/liutao_py2/mpy2tf1.2/local/lib/python2.7/site-packages/tensorflow/python/framework/ops.py”, line 1269, in init
self._traceback = _extract_stack()
NotFoundError (see above for traceback): Key LSTM/basic_lstm_cell/bias not found in checkpoint
[[Node: save/RestoreV2 = RestoreV2[dtypes=[DT_FLOAT], _device=”/job:localhost/replica:0/task:0/cpu:0”](_arg_save/Const_0_0, save/RestoreV2/tensor_names, save/RestoreV2/shape_and_slices)]]
[[Node: save/RestoreV2_26/_101 = _Recvclient_terminated=false, recv_device=”/job:localhost/replica:0/task:0/gpu:0”, send_device=”/job:localhost/replica:0/task:0/cpu:0”, send_device_incarnation=1, tensor_name=”edge_212_save/RestoreV2_26”, tensor_type=DT_FLOAT, _device=”/job:localhost/replica:0/task:0/gpu:0”]]
參考一:參考下面這篇部落格進行解決
tensorflow1.x版本加載saver.restore目錄報錯
在 ubuntu源代碼 news_predict.py中
saver = tf.train.Saver(restore_dict)
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
saver.restore(sess, r'model/model.ckpt')
TF1.0正确運作結果如下
ubuntu中修改上述代碼如下:
saver = tf.train.Saver(restore_dict)
init = tf.global_variables_initializer()
savsess = tf.Session()
sess.run(init)
module_file = tf.train.latest_checkpoint('news_tf_model/model.ckpt') #ckpt路徑抽調出來
if module_file is not None: # 添加一個判斷語句,判斷ckpt的路徑檔案
saver.restore(sess, module_file)
TF1.0和TF1.2運作結果全部是下面情況
參考一分析:
“if module_file is not None”該判斷僅僅是做了一個檔案是否存在的判斷,并沒有從根本上解決LSTM的報錯問題,而代碼不執行“ saver.restore(sess, module_file)”,就造成最後得到的結果為空了。
問題分析:
其實,仔細檢視提示,會發現,報錯的是指出“Key LSTM/basic_lstm_cell/bias not found in checkpoint”,那必然是LSTM中的bias定義出現了問題。是以筆者列印了saver = tf.train.Saver(restore_dict)中的“restore_dict”,發現TF1.0中和TF1.2中參數存在差異如下表
TF1.0 | TF1.2 |
---|---|
lstm/basic_lstm_cell/weights | lstm/basic_lstm_cell/kernel |
lstm/basic_lstm_cell/biases | lstm/basic_lstm_cell/bias |
原來問題确實出現在了LSTM上了,TF1.0和TF1.2的LSTM竟然在命名上出現了差異,好吧,看來要在TF1.2上使用TF1.0訓練好的ckpt模型,必須要對應LSTM的上面兩個參數了。
要對應參數其實有兩種辦法,第一種,修改ckpt模型中LSTM兩個變量名;第二種,在predict時,做符合TF版本的LSTM變量名的對應。
參考二:
接下來,先介紹第一種方法,根據 基于tensorflow 1.0的圖像叙事功能測試(model/im2txt) 這篇部落格的内容,修改代碼如下
# 由于版本不同,需要進行修改
def RenameCkpt():
# 1.0.1 : 1.2.1
vars_to_rename = {
"lstm/basic_lstm_cell/weights": "lstm/basic_lstm_cell/kernel",
"lstm/basic_lstm_cell/biases": "lstm/basic_lstm_cell/bias",
}
new_checkpoint_vars = {}
reader = tf.train.NewCheckpointReader(FLAGS.checkpoint_path)
for old_name in reader.get_variable_to_shape_map():
if old_name in vars_to_rename:
new_name = vars_to_rename[old_name]
else:
new_name = old_name
new_checkpoint_vars[new_name] = tf.Variable(reader.get_tensor(old_name))
init = tf.global_variables_initializer()
saver = tf.train.Saver(new_checkpoint_vars)
with tf.Session() as sess:
sess.run(init)
saver.save(sess, "/home/ndscbigdata/work/change/tf/gan/im2txt/ckpt/newmodel.ckpt-2000000")
print("checkpoint file rename successful... ")
上述方法是修改ckpt模型中的lstm/basic_lstm_cell/kernel 和 lstm/basic_lstm_cell/bias,修改完成後的ckpt僅僅能夠在1.2.1上正常運作,同樣因為參數名修改了變得版本不對應,而無法在1.0.1上運作。
參考三:
根據以上描述,筆者想到了方法二,按照正常邏輯,修改restore_dict中的lstm/basic_lstm_cell/kernel 和 lstm/basic_lstm_cell/bias
restore_dict = {}
for i in variables[:]: # the first is global step
#restore_dict[i.name.replace(':0', '')] = i
if i.name.replace(':0', '')=='LSTM/basic_lstm_cell/biases':
print('LSTM/basic_lstm_cell/bias========================================')
restore_dict[i.name.replace('LSTM/basic_lstm_cell/biases:0',
'LSTM/basic_lstm_cell/bias')] = tf.get_variable('LSTM/basic_lstm_cell/bias',[,])
elif i.name.replace(':0', '')=='LSTM/basic_lstm_cell/weights':
print('LSTM/basic_lstm_cell/kernel========================================')
restore_dict[i.name.replace('LSTM/basic_lstm_cell/weights:0',
'LSTM/basic_lstm_cell/kernel')] = tf.get_variable('LSTM/basic_lstm_cell/kernel',[, ])
else:
restore_dict[i.name.replace(':0', '')] = i
原本以為可以相容1.0.1和1.2.1版本了,但是出現一個問題,對同一張圖檔分别在tf1.0.1和tf1.2.1兩個版本下進行多标簽預測,見如下兩圖
圖1—-tf1.2.1環境下運作結果(這是正确的結果)
圖2 —-tf1.0.1環境下運作結果(發現隻能預測第二個标簽,第一個丢失了)
至于為何丢失的問題,我在做測試中,發現,盡管修改了對應于目前tf版本的 “lstm/basic_lstm_cell/weights” 和”lstm/basic_lstm_cell/biases”,但是并沒有起到作用,這個可以通過注釋下面兩行代碼運作程式,發現也是上述結果
總結:
總的來說,1.0.1和1.2.1在使用saver的時候,存在着ckpt模型參數和saver初始化restore_dict中的參數的一一對應的情況,其中以LSTM中的兩個參數:lstm/basic_lstm_cell/weights 和 lstm/basic_lstm_cell/biases 容易出現因為版本的不同,ckpt與預測代碼中自定義的restore_dict中兩個參數不比對的情況,就會報出本錯誤。
附加:
上述問題,筆者有在github-tensorflow官方進行問題提問,成員 skye 給了筆者一個位址作為參考,位址如下:https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/rnn/python/tools/checkpoint_convert.py
這個位址中給出了checkpoint_convert的詳細代碼,内涵不同版本之間不同命名的轉化問題。