天天看點

Resnet Cifar-10調試

一、下載下傳和運作

https://github.com/tensorflow/models 頁面即可下載下傳
具體項目是 models/tutorials/image/cifar10_estimator/
$ curl -o cifar--python.tar.gz https://www.cs.toronto.edu/~kriz/cifar--python.tar.gz
$ tar xzf cifar--python.tar.gz
$ python generate_cifar10_tfrecords.py --input_dir=/home/jqh/jiangqiuhua/Learning_Data/cifar-/cifar--batches-py --output_dir=/home/jqh/jiangqiuhua/Learning_Data/cifar-/
python cifar10_main.py --data_dir=/home/jqh/jiangqiuhua/Learning_Data/cifar- \
                         --model_dir=/tmp/cifar10 \
                         --is_cpu_ps=True \
                         --force_gpu_compatible=True \
                         --num_gpus= \
                         --train_steps=
$ tensorboard --logdir=/tmp/cifar10
           

二、代碼分析

1.cifar10_main.py

1.1 指令行參數處理

tf.flags.FLAGS定義在/usr/local/lib/python2.7/dist-packages/tensorflow/python/platform檔案夾下flags.py中。

FLAGS = _FlagValues()
class _FlagValues(object):
    def _parse_flags(self, args=None):
        result, unparsed = _global_parser.parse_known_args(args=args)
    def __getattr__(self, name):
        if not parsed:
              self._parse_flags()
           

python中如下代碼的作用。

if __name__ = "__main__": #使用這種方式保證了,如果此檔案被其它檔案import的時候,不會執行main中的代碼
    tf.app.run() #解析指令行參數,調用main函數 main(sys.argv)

在tf.app.run()中
     flags_passthrough = f._parse_flags(args=args)
           
1.2 訓練和評估

1)訓練和評估輸入

train_input_fn = functools.partial(input_fn, subset='train',
                                  num_shards=FLAGS.num_gpus)
eval_input_fn = functools.partial(input_fn, subset='eval',
                                 num_shards=FLAGS.num_gpus)
           

functools.partial的作用就是表明train_input_fn函數就是帶了train和FLAGS.num_gpus參數的input_fn函數。

2)Session配置

sess_config = tf.ConfigProto()
sess_config.allow_soft_placement = True
sess_config.log_device_placement = FLAGS.log_device_placement
sess_config.intra_op_parallelism_threads = FLAGS.num_intra_threads
sess_config.inter_op_parallelism_threads = FLAGS.num_inter_threads
sess_config.gpu_options.force_gpu_compatible = FLAGS.force_gpu_compatible
           

3)Estimator配置

config = tf.estimator.RunConfig()
config = config.replace(session_config=sess_config)
classifier = tf.estimator.Estimator(
    model_fn=_resnet_model_fn, model_dir=FLAGS.model_dir, config=config)
           

4)訓練和評估

classifier.train(input_fn=train_input_fn,
                 steps=train_steps,
                 hooks=hooks)
eval_results = classifier.evaluate(
    input_fn=eval_input_fn,
    steps=eval_steps)
           

繼續閱讀