轉自:https://blog.csdn.net/amanfromearth/article/details/79155926#commentBox
在使用Tensorflow做讀取并finetune的時候,發現在讀取官方給的inception_v3預訓練模型總是出現各種錯誤,現記錄其正确的讀取方式和各種錯誤做法:
關鍵代碼如下:
import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow.contrib.slim.python.slim.nets import inception_v3
from research.slim.preprocessing import inception_preprocessing
Pretrained_model_dir = '/Users/apple/tensorflow_model/models-master/research/slim/pre_train/inception_v3.ckpt'
image_size = 299
# 讀取網絡
with slim.arg_scope(inception_v3.inception_v3_arg_scope()):
imgPath = 'test.jpg'
testImage_string = tf.gfile.FastGFile(imgPath, 'rb').read()
testImage = tf.image.decode_jpeg(testImage_string, channels=3)
processed_image = inception_preprocessing.preprocess_image(testImage, image_size, image_size, is_training=False)
processed_images = tf.expand_dims(processed_image, 0)
logits, end_points = inception_v3.inception_v3(processed_images, num_classes=128, is_training=False)
w1 = tf.Variable(tf.truncated_normal([128, 5], stddev=tf.sqrt(1/128)))
b1 = tf.Variable(tf.zeros([5]))
logits = tf.nn.leaky_relu(tf.matmul(logits, w1) + b1)
with tf.Session() as sess:
# 先初始化所有變量,避免有些變量未讀取而産生錯誤
init = tf.global_variables_initializer()
sess.run(init)
#加載預訓練模型
print('Loading model check point from {:s}'.format(Pretrained_model_dir))
#這裡的exclusions是不需要讀取預訓練模型中的Logits,因為預設的類别數目是1000,當你的類别數目不是1000的時候,如果還要讀取的話,就會報錯
exclusions = ['InceptionV3/Logits',
'InceptionV3/AuxLogits']
#建立一個清單,包含除了exclusions之外所有需要讀取的變量
inception_except_logits = slim.get_variables_to_restore(exclude=exclusions)
#建立一個從預訓練模型checkpoint中讀取上述清單中的相應變量的參數的函數
init_fn = slim.assign_from_checkpoint_fn(Pretrained_model_dir, inception_except_logits,ignore_missing_vars=True)
#運作該函數
init_fn(sess)
print('Loaded.')
out = sess.run(logits)
print(out.shape)
print(out)
其中可能會出現的錯誤如下:
錯誤1
- 1
- 2
- 3
原因:
預訓練模型中的類别數class_num=1000,這裡輸入的class_num=5,當讀取完整模型的時候當然會出錯。
解決方案:
選擇不讀取包含類别數的Logits層和AuxLogits層:
- 1
- 2
錯誤2
Tensor name “xxxx” not found in checkpoint files
- 1
- 2
- 3
- 4
這裡的Tensor name可以是所有inception_v3中變量的名字,出現這種情況的各種原因和解決方案是:
1.建立圖的時候沒有用arg_scope,是這樣建立的:
- 1
解決方案:
在這裡加上arg_scope,裡面調用的是庫中自帶的inception_v3_arg_scope
- 1
- 2
2.在讀取checkpoint的時候未初始化所有變量,即未運作
- 1
- 2
這樣會導緻有一些checkpoint中不存在的變量未被初始化,比如使用Momentum時的每一層的Momentum參數等。
3.使用
slim.assign_from_checkpoint_fn()
函數時,沒有添加
ignore_missing_vars=True
屬性,由于預設ignore_missing_vars=False,是以,當使用非SGD的optimizer的時候(如Momentum、RMSProp等)時,會提示Momentum或者RMSProp的參數在checkpoint中無法找到,如:
使用Momentum時:
- 1
- 2
- 3
- 4
使用RMSProp時:
- 1
- 2
- 3
- 4
解決方法很簡單,就是把ignore_missing_vars=True
- 1
注意:一定要在之前的步驟都完成之後才能設成True,不然如果變量名稱全部出錯的話,會忽視掉checkpoint中所有的變量,進而不讀取任何參數。
以上就是我碰見的問題,希望有所幫助。