# TensorFlow Object Detection API 源码分析之 model_main.py
# Train和Eval的主函数,也是API的入口函数
"""Binary to run train and evaluation on object detection model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# 类似 argparse参数设置的python包
from absl import flags
import tensorflow as tf
from object_detection import model_hparams
from object_detection import model_lib
# Command-line flags. --model_dir and --pipeline_config_path are required.
#   model_dir: output directory where checkpoints and event files are written.
#   pipeline_config_path: path to the pipeline config file.
#   num_train_steps / num_eval_steps: number of train / eval steps.
#   hparams_overrides: comma-separated hparam_name=value pairs, passed to
#     model_hparams.create_hparams to override hyperparameters.
#   checkpoint_dir: if set, the binary runs in eval-only mode on the latest
#     checkpoint in this directory, writing metrics to model_dir.
#   run_once: in eval-only mode, evaluate a single time instead of
#     continuously (the default).
#   eval_training_data: in eval-only mode, evaluate on the training data
#     instead of the validation data.
flags.DEFINE_string(
    'model_dir', None, 'Path to output model directory '
    'where event and checkpoint files will be written.')
flags.DEFINE_string('pipeline_config_path', None, 'Path to pipeline config '
                    'file.')
flags.DEFINE_integer('num_train_steps', None, 'Number of train steps.')
# Help text previously (incorrectly) read "Number of train steps.".
flags.DEFINE_integer('num_eval_steps', None, 'Number of eval steps.')
flags.DEFINE_string(
    'hparams_overrides', None, 'Hyperparameter overrides, '
    'represented as a string containing comma-separated '
    'hparam_name=value pairs.')
flags.DEFINE_string(
    'checkpoint_dir', None, 'Path to directory holding a checkpoint. If '
    '`checkpoint_dir` is provided, this binary operates in eval-only mode, '
    'writing resulting metrics to `model_dir`.')
flags.DEFINE_boolean(
    'run_once', False, 'If running in eval-only mode, whether to run just '
    'one round of eval vs running continuously (default).'
)
flags.DEFINE_boolean('eval_training_data', False,
                     'If training data should be evaluated for this job.')
FLAGS = flags.FLAGS
def main(unused_argv):
    """Build the detection estimator from the pipeline config, then train/eval.

    The mode is chosen from flags:
      * --checkpoint_dir set: eval-only. Evaluates the training data
        (--eval_training_data) or the validation data, either once
        (--run_once) or continuously via model_lib.continuous_eval.
      * otherwise: trains and evaluates with tf.estimator.train_and_evaluate.

    Args:
      unused_argv: Remaining command-line arguments from tf.app.run; ignored.

    Raises:
      ValueError: If --model_dir or --pipeline_config_path is not set.
    """
    # absl checks required-flag validators when flags are parsed. That has
    # already happened by the time tf.app.run() calls main(), so registering
    # them here (as before) enforced nothing. Check explicitly instead.
    for required_flag in ('model_dir', 'pipeline_config_path'):
        if getattr(FLAGS, required_flag) is None:
            raise ValueError(
                'Flag --{} must have a value other than None.'.format(
                    required_flag))

    config = tf.estimator.RunConfig(model_dir=FLAGS.model_dir)

    # The returned dict holds the estimator; the train, eval, eval-on-train
    # and predict input functions; and the train/eval step counts.
    train_and_eval_dict = model_lib.create_estimator_and_inputs(
        run_config=config,
        hparams=model_hparams.create_hparams(FLAGS.hparams_overrides),
        pipeline_config_path=FLAGS.pipeline_config_path,
        train_steps=FLAGS.num_train_steps,
        eval_steps=FLAGS.num_eval_steps)
    estimator = train_and_eval_dict['estimator']
    train_input_fn = train_and_eval_dict['train_input_fn']
    eval_input_fn = train_and_eval_dict['eval_input_fn']
    eval_on_train_input_fn = train_and_eval_dict['eval_on_train_input_fn']
    predict_input_fn = train_and_eval_dict['predict_input_fn']
    train_steps = train_and_eval_dict['train_steps']
    eval_steps = train_and_eval_dict['eval_steps']

    if FLAGS.checkpoint_dir:
        # Eval-only mode: choose which dataset to evaluate.
        if FLAGS.eval_training_data:
            name = 'training_data'
            input_fn = eval_on_train_input_fn
        else:
            name = 'validation_data'
            input_fn = eval_input_fn
        if FLAGS.run_once:
            # NOTE(review): tf.train.latest_checkpoint returns None when no
            # checkpoint is found; Estimator.evaluate then falls back to the
            # latest checkpoint in model_dir. Confirm this is intended.
            estimator.evaluate(input_fn,
                               eval_steps,
                               checkpoint_path=tf.train.latest_checkpoint(
                                   FLAGS.checkpoint_dir))
        else:
            model_lib.continuous_eval(estimator, FLAGS.model_dir, input_fn,
                                      eval_steps, train_steps, name)
    else:
        # Normal path: train and periodically evaluate.
        train_spec, eval_specs = model_lib.create_train_and_eval_specs(
            train_input_fn,
            eval_input_fn,
            eval_on_train_input_fn,
            predict_input_fn,
            train_steps,
            eval_steps,
            eval_on_train_data=False)
        # Currently only a single EvalSpec is allowed, so pass the first one.
        tf.estimator.train_and_evaluate(estimator, train_spec, eval_specs[0])
if __name__ == '__main__':
    # Register required flags before tf.app.run() parses argv, so absl
    # validates them at parse time; validators added after parsing are not
    # checked.
    flags.mark_flag_as_required('model_dir')
    flags.mark_flag_as_required('pipeline_config_path')
    tf.app.run()