天天看点

tensorflow加载预训练模型的部分参数方法

在训练时,修改了部分网络结构,但其他的网络结果仍一样,又不想重新开始训练,

这就需要加载已训练的好的模型的参数。下面是在训练中需要加载部分参数模型,

以此记录:

if args.checkpoint and os.path.isdir(args.checkpoint):

            logger.info('Restore from checkpoint...')

           #获取网络中所以可以加载的参数

            variables = tf.contrib.framework.get_variables_to_restore()

           #删除BackBone层中的参数

            variables_to_resotre = [v for v in variables if v.name.split('/')[0]!='BackBone']

            #构建剩余部分参数的saver

            saver = tf.train.Saver(variables_to_resotre)

            saver.restore(sess, tf.train.latest_checkpoint(args.checkpoint))

            logger.info('Restore from checkpoint...Done')

当在神经网络中增加几层,或者修改小部分,可以这样设置:

import tensorflow.contrib.slim as slim

  variables_to_restore = slim.get_variables_to_restore()

            CKPT_FILE=last_checkpoint_path+'/model-38000'

            load_fn = slim.assign_from_checkpoint_fn(CKPT_FILE,variables_to_restore,ignore_missing_vars=True)

            load_fn(sess)

   参考资料:

https://blog.csdn.net/runningwei/article/details/85677793

https://blog.csdn.net/b876144622/article/details/79962727

继续阅读