天天看点

Tensorflow Estimator做迁移学习(Transfer Learning)

在TensorFlow官方ResNet模型实现分析中我们分析了基于Estimator的模型实现与运行的基本方法。除此之外,这份源码还提供了神经网络中常用的一种手段——迁移学习(Transfer Learning)的实现。

迁移学习

取决于具体的任务,从零开始训练一个深度神经网络有时需要海量的数据才能得到较好的效果。如果你手头的数据有限,又想采用神经网络作为解决方案,可以尝试一下迁移学习。

举一个例子:你负责维护工厂的一条自动化生产线,在传送带上有10种不同的零件随机经过。工业照相机可以逐一捕捉完整的零件图像,但是需要你来根据零件类型调整后续的机械手动作。现在可用于训练的零件图像非常有限,而你手头正好有一个使用大量数据训练好的ImageNet图像分类神经网络模型。如何充分利用这两点是一个典型的迁移学习应用场景。

迁移学习迁移了什么

深度神经网络的结构存在层级。对于卷积神经网络CNN来说,不同层级的卷积层所表现出的特征提取也呈现层级性。具体来说,底层的卷积层对于低阶特征较为敏感,例如边缘、团块等;随着层级的升高,提取的特征越来越抽象。这种随层级变化的特征提取能力是迁移学习的基础。它保证了当任务具备相似性时,例如分类1024种不同的自然物体与分类10种不同的零件,已经训练好的神经网络的特征提取层可以“迁移”到新的分类任务中来继续承担特征提取的功能。

迁移学习的具体的做法

常用的做法包括:

  1. “冻结”特征提取部分。
  2. 使用新数据训练末端负责输出分类的若干全连接层。

TensorFlow如何实现

官方的ResNet模型实现提供了迁移学习的功能。只需要指定

--pretrained_model_checkpoint_path

--fine_tune

这两个flag就可以实现。

具体到代码中,首先在载入模型时要跳过最终的dense层。

if flags_obj.pretrained_model_checkpoint_path is not None:
    warm_start_settings = tf.estimator.WarmStartSettings(
        flags_obj.pretrained_model_checkpoint_path,
        vars_to_warm_start='^(?!.*dense)')
           

参数vars_to_warm_start采用正则表达式的方式过滤掉了最后的全连接层。

然后在根据梯度更新参数时,过滤掉不需要更新的部分。

grad_vars = optimizer.compute_gradients(loss)
      if fine_tune:
        grad_vars = _dense_grad_filter(grad_vars)
      minimize_op = optimizer.apply_gradients(grad_vars, global_step)
           

这里的_dense_grad_filter的实现如下:

def _dense_grad_filter(gvs):
      """Only apply gradient updates to the final layer.

      This function is used for fine tuning.

      Args:
        gvs: list of tuples with gradients and variable info
      Returns:
        filtered gradients so that only the dense layer remains
      """
      return [(g, v) for g, v in gvs if 'dense' in v.name]
           

这种实现方法是根据node的name属性来实现的。所以在改造网络的时候,注意自己添加的node name不要与之冲突。

参考

迁移学习用于图像识别的Tensorflow实现

https://yinguobing.com/tensorflow-transfer-learning/

tensorflow estimator 使用hook实现finetune

https://github.com/tensorflow/tensorflow/issues/10155

https://medium.com/@utsumuki_neko/using-inception-v3-from-tensorflow-hub-for-transfer-learning-a931ff884526

https://github.com/tensorflow/tensorflow/issues/14713

https://stackoverflow.com/questions/46423956/load-checkpoint-and-finetuning-using-tf-estimator-estimator

TensorFlow如何实现Transfer Learning

TensorFlow 迁移学习识花实战案例

基于Tensorflow高阶API构建大规模分布式深度学习模型系列之自定义Estimator(以文本分类CNN模型为例)

继续阅读