天天看点

Estimator详解Estimator

Estimator

Estimator是tensorflow推出的一个High level的API,用于简化机器学习

Estimator的优点

  1. 开发方便
  2. 方便整合其它tensorflow高阶api
  3. 单机和分布式代码一致
  4. 不必关系一些机器学习中的细节,比如(loading model、saving model、loging)等

Estimator概述

Estimator是谷歌为了简化和规范化机器学习编程过程而提出来的,它封装了以下几个部分:

  • 训练( training )
  • 评估 ( evaluation )
  • 预测 ( prediction )
  • 模型输出 ( export for serving )

    其中,前三部分都是再model_fn函数中定义的。

model_fn

model_fn函数用于构建自定义的模型及训练、预测方法

import tensorflow as tf
def model_fn(features, labels, mode, params):
"""features,labels,mode为固定参数,其中features,labels是通过input_fn传输过来的,mode则是estimator传过来用于判断训练、预测、测试过程的,不同的过程需要返回不同的结构。
"""
lr = params['lr']
try:
  init_checkpoint = params['init_checkpoint']
except KeyError:
  init_checkpoint = None

x = features['inputs']
y = features['labels]

# ##########这里定义自己的网络模型##########
pre = tf.layers.dense(x, 1)
loss = tf.reduce_mean(tf.pow(pre - y, 2), name='loss')
# #######################################

# 加载预训练模型
assignment_map = dict()
if init_checkpoint:
  for var in tf.train.list_variables(init_checkpoint): # 存放checkpoint的变量名称和shape
    assignment_map[var[0]] = var[0]
  tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

# 定义你训练过程要做的事情
if mode == tf.estimator.ModeKeys.Train:
  optimizer = tf.train.AdamOptimizer(lr)
  train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
  output_spec = tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

# 定义测试过程
elif mode == tf.estimator.ModeKeys.EVAL:
  metrics = {'eval_loss': loss}
  output_spec = tf.estimator.EstimatorSpec(mode, loss=loss, eval_metric_ops=metrics)

# 定义预测过程
elif mode == tf.estimator.ModeKeys.PREDICT:
  predictions = {'predictions': pre}
  output_spec = tf.estimator.EstimatorSpec(mode, predictions=predictions)

else:
  raise TypeError

return output_spec
           

input_fn

input_fn用于将数据喂给模型。

输入函数是返回一个tf.data.Dataset对象,该对象输出包含features、label的元组。

其中, features是一个python字典,包含:1. 每个特征名称作为键 2. 每个特征下的值的数组作为数组的值。 label则是每个样本的标签的数组。

input_fn可以采用Dataset Api 来处理返回数据。

Dataset Api

tf.Data API用来读取、预处理数据。tf.data.Dataset是一种对数据处理的抽象,用来再tensorflow中表示要处理的数据。

一个dataset只能通过一下两种方法来得到:

  1. 从内存或者文件等数据源中创建dataset
  2. 从两外一个或者多个dataset中转换而来。

dataset常用的方法有:

  • tf.data.Dataset.from_tensors()
  • tf.data.Dataset.from_tensor_slices()
  • tf.data.TFRecordDataset()
  • tf.data.TextLineDataset()
  • Dataset.map()
  • Dataset.batch()
  • Dataset.filter()
def input_fn_bulider(input_file, batch_size, is_training):
  name_to_features = {
    'inputs': tf.FixedLenFeature([3], tf.float32),
    'labels': tf.FixedLenFeature([], tf.float32)
  }

  def input_fn(params):
    d = tf.data.TFRecordDataset(inputs_file)
    if is_training:
      d = d.repeat()
      d = d.shuffle()
    
    # 构建和返回dataset
    # map_and_batch是将map 和 batch 结合起来
    d = d.apply(tf.contrib.data.map_and_batch(lambda x: tf.parse_single_example(x, name_to_features), batch_size=batch_size))

    return d

  return input_fn

           

执行Estimator

if __name__ == '_main_':
  tf.logging.set_verbosity(tf.logging.INFO)
  runConfig = tf.estimator.RunConfig(save_checkpoints_setps=1,
                                    log_step_count_steps=1)
  
  estimator = tf.estimator.Estimator(model_fn, model_dir='you_save_path',
                                    config=runConfig, params={'lr': 0.01})
  # log_step_count_steps控制的只是loss的global_step的输出
  # 还可以通过tf.train.LoggingTensorHook自定义更多的输出
  # tensors是要输出的内容, 输入一个字典,key为变量名称,value为要计算的tensor的name
  logging_hook = tf.train.LoggingTensorHook(every_n_iter=1,
                                            tensors={'loss': 'loss'})
  # logging_hook需要再model_fn中设置
  # tf.estimator.EstimatorSpec(
  #    ...params...
  #    training_hooks = [logging_hook]
  #  )
  input_fn = input_fn_builder('test.tfrecord', batch_size=1, is_traing=True)
  estimator.train(input_fn, max_steps=1000)
  
           

参考

TensorFlow estimator详细介绍,实现模型的高效训练

TensorFlow Estimator 教程之----快速入门

tensorflow官方教程

继续阅读