天天看點

TensorFlow Object Detection API 源碼分析之 model_main.py

# Train和Eval的主函數,也是API的入口函數
"""Binary to run train and evaluation on object detection model."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# 類似 argparse參數設定的python包
from absl import flags  

import tensorflow as tf

from object_detection import model_hparams
from object_detection import model_lib

# Command-line flags. main() requires the first two
# (model_dir, pipeline_config_path); the rest are optional.
#   model_dir: output directory where event and checkpoint files are written.
#   pipeline_config_path: path to the pipeline config file.
#   num_train_steps / num_eval_steps: number of training / evaluation steps.
#   hparams_overrides: comma-separated hparam_name=value pairs passed to
#     model_hparams.create_hparams() to override hyperparameters.
#   checkpoint_dir: if set, the binary runs in eval-only mode on the
#     checkpoint(s) in this directory.
#   run_once: in eval-only mode, evaluate a single time instead of
#     continuously.
#   eval_training_data: in eval-only mode, evaluate on the training data
#     instead of the validation data.

flags.DEFINE_string(
    'model_dir', None, 'Path to output model directory '
    'where event and checkpoint files will be written.')
flags.DEFINE_string('pipeline_config_path', None, 'Path to pipeline config '
                    'file.')
flags.DEFINE_integer('num_train_steps', None, 'Number of train steps.')
# The help text previously said "train steps" here, a copy/paste slip.
flags.DEFINE_integer('num_eval_steps', None, 'Number of eval steps.')
flags.DEFINE_string(
    'hparams_overrides', None, 'Hyperparameter overrides, '
    'represented as a string containing comma-separated '
    'hparam_name=value pairs.')
flags.DEFINE_string(
    'checkpoint_dir', None, 'Path to directory holding a checkpoint.  If '
    '`checkpoint_dir` is provided, this binary operates in eval-only mode, '
    'writing resulting metrics to `model_dir`.')
flags.DEFINE_boolean(
    'run_once', False, 'If running in eval-only mode, whether to run just '
    'one round of eval vs running continuously (default).'
)
flags.DEFINE_boolean('eval_training_data', False,
                     'If training data should be evaluated for this job.')
FLAGS = flags.FLAGS


def main(unused_argv):
  """Trains and/or evaluates an object detection model, driven by flags.

  If --checkpoint_dir is set, only evaluation runs: once (--run_once) or
  continuously. Otherwise training and evaluation run together via
  tf.estimator.train_and_evaluate.

  Args:
    unused_argv: Remaining command-line arguments (ignored).

  Raises:
    ValueError: If --model_dir or --pipeline_config_path is unset, or if
      --run_once is set and --checkpoint_dir contains no checkpoint.
  """
  # flags.mark_flag_as_required() registered here would never fire: absl runs
  # validators when flags are parsed or assigned, and parsing has already
  # finished by the time main() runs. Enforce the requirement explicitly.
  for required_flag in ('model_dir', 'pipeline_config_path'):
    if getattr(FLAGS, required_flag) is None:
      raise ValueError('Flag --{} must be set.'.format(required_flag))

  config = tf.estimator.RunConfig(model_dir=FLAGS.model_dir)

  # Build the Estimator together with every input function and step count
  # that the train/eval paths below may need.
  train_and_eval_dict = model_lib.create_estimator_and_inputs(
      run_config=config,
      hparams=model_hparams.create_hparams(FLAGS.hparams_overrides),
      pipeline_config_path=FLAGS.pipeline_config_path,
      train_steps=FLAGS.num_train_steps,
      eval_steps=FLAGS.num_eval_steps)
  estimator = train_and_eval_dict['estimator']
  train_input_fn = train_and_eval_dict['train_input_fn']
  eval_input_fn = train_and_eval_dict['eval_input_fn']
  eval_on_train_input_fn = train_and_eval_dict['eval_on_train_input_fn']
  predict_input_fn = train_and_eval_dict['predict_input_fn']
  train_steps = train_and_eval_dict['train_steps']
  eval_steps = train_and_eval_dict['eval_steps']

  if FLAGS.checkpoint_dir:
    # Eval-only mode: evaluate checkpoint(s) from --checkpoint_dir.
    if FLAGS.eval_training_data:
      name = 'training_data'
      input_fn = eval_on_train_input_fn
    else:
      name = 'validation_data'
      input_fn = eval_input_fn
    if FLAGS.run_once:
      checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
      if checkpoint_path is None:
        # With checkpoint_path=None, Estimator.evaluate falls back to the
        # estimator's own model_dir, i.e. it would not evaluate anything from
        # --checkpoint_dir. Fail loudly instead.
        raise ValueError('No checkpoint found in --checkpoint_dir={}'.format(
            FLAGS.checkpoint_dir))
      estimator.evaluate(input_fn,
                         eval_steps,
                         checkpoint_path=checkpoint_path)
    else:
      # Watch --checkpoint_dir (not --model_dir) for new checkpoints, as the
      # --checkpoint_dir help text promises.
      model_lib.continuous_eval(estimator, FLAGS.checkpoint_dir, input_fn,
                                eval_steps, train_steps, name)
  else:
    # Usual path: train and periodically evaluate.
    train_spec, eval_specs = model_lib.create_train_and_eval_specs(
        train_input_fn,
        eval_input_fn,
        eval_on_train_input_fn,
        predict_input_fn,
        train_steps,
        eval_steps,
        eval_on_train_data=False)

    # Currently only a single Eval Spec is allowed.
    tf.estimator.train_and_evaluate(estimator, train_spec, eval_specs[0])


if __name__ == '__main__':
  # Register required flags before tf.app.run() parses argv. absl checks flag
  # validators during parsing, so a missing flag is reported as a usage error
  # up front rather than failing later inside model construction. Doing it
  # only under __main__ keeps the module importable without these flags.
  flags.mark_flag_as_required('model_dir')
  flags.mark_flag_as_required('pipeline_config_path')
  tf.app.run()

           

繼續閱讀