一、下載下傳和運作
https://github.com/tensorflow/models 頁面即可下載下傳
具體項目是 models/tutorials/image/cifar10_estimator/
$ curl -o cifar--python.tar.gz https://www.cs.toronto.edu/~kriz/cifar--python.tar.gz
$ tar xzf cifar--python.tar.gz
$ python generate_cifar10_tfrecords.py --input_dir=/home/jqh/jiangqiuhua/Learning_Data/cifar-/cifar--batches-py --output_dir=/home/jqh/jiangqiuhua/Learning_Data/cifar-/
python cifar10_main.py --data_dir=/home/jqh/jiangqiuhua/Learning_Data/cifar- \
--model_dir=/tmp/cifar10 \
--is_cpu_ps=True \
--force_gpu_compatible=True \
--num_gpus= \
--train_steps=
$ tensorboard --logdir=/tmp/cifar10
二、代碼分析
1.cifar10_main.py
1.1 指令行參數處理
tf.flags.FLAGS定義在/usr/local/lib/python2.7/dist-packages/tensorflow/python/platform檔案夾下flags.py中。
FLAGS = _FlagValues()
class _FlagValues(object):
def _parse_flags(self, args=None):
result, unparsed = _global_parser.parse_known_args(args=args)
def __getattr__(self, name):
if not parsed:
self._parse_flags()
python中如下代碼的作用。
if __name__ = "__main__": #使用這種方式保證了,如果此檔案被其它檔案import的時候,不會執行main中的代碼
tf.app.run() #解析指令行參數,調用main函數 main(sys.argv)
在tf.app.run()中
flags_passthrough = f._parse_flags(args=args)
1.2 訓練和評估
1)訓練和評估輸入
train_input_fn = functools.partial(input_fn, subset='train',
num_shards=FLAGS.num_gpus)
eval_input_fn = functools.partial(input_fn, subset='eval',
num_shards=FLAGS.num_gpus)
functools.partial的作用就是表明train_input_fn函數就是帶了train和FLAGS.num_gpus參數的input_fn函數。
2)Session配置
sess_config = tf.ConfigProto()
sess_config.allow_soft_placement = True
sess_config.log_device_placement = FLAGS.log_device_placement
sess_config.intra_op_parallelism_threads = FLAGS.num_intra_threads
sess_config.inter_op_parallelism_threads = FLAGS.num_inter_threads
sess_config.gpu_options.force_gpu_compatible = FLAGS.force_gpu_compatible
3)Estimator配置
config = tf.estimator.RunConfig()
config = config.replace(session_config=sess_config)
classifier = tf.estimator.Estimator(
model_fn=_resnet_model_fn, model_dir=FLAGS.model_dir, config=config)
4)訓練和評估
classifier.train(input_fn=train_input_fn,
steps=train_steps,
hooks=hooks)
eval_results = classifier.evaluate(
input_fn=eval_input_fn,
steps=eval_steps)