天天看點

GAN學習 | DCGAN TensorFlow代碼解讀(1):main.pymain.py 代碼解讀小白總結

目錄

main.py 代碼解讀

line11-30  封裝外部傳入的參數

tf.app.flags的了解

line33-48 定義主函數

pprint.PrettyPrinter() 的了解

 tf.ConfigProto() 的了解

line 33-105 主函數中定義session

小白總結

代碼塊學習

Github: DCGAN-Master

本文針對DCGAN的TensorFlow master代碼進行解讀,并針對MNIST手寫資料進行實驗。

main.py 代碼解讀

line1-8 import 庫、import自己py檔案中的函數和模型

import os
import scipy.misc
import numpy as np
import tensorflow as tf

# 從自定義的檔案中import
from model import DCGAN
from utils import pp, visualize, to_json, show_all_variables
           

line11-30  封裝外部傳入的參數

flags = tf.app.flags  # 外部參數傳遞
flags.DEFINE_integer("epoch", 25, "Epoch to train [25]")
flags.DEFINE_float("learning_rate", 0.0002, "Learning rate of for adam [0.0002]")
flags.DEFINE_float("beta1", 0.5, "Momentum term of adam [0.5]")
flags.DEFINE_float("train_size", np.inf, "The size of train images [np.inf]")  # np.inf??
flags.DEFINE_integer("batch_size", 64, "The size of batch images [64]")
flags.DEFINE_integer("input_height", 108, "The size of image to use (will be center cropped). [108]")
flags.DEFINE_integer("input_width", None, "The size of image to use (will be center cropped). If None, same value as input_height [None]")
flags.DEFINE_integer("output_height", 64, "The size of the output images to produce [64]")
flags.DEFINE_integer("output_width", None, "The size of the output images to produce. If None, same value as output_height [None]")
flags.DEFINE_string("dataset", "mnist", "The name of dataset [celebA, mnist, lsun]")
flags.DEFINE_string("input_fname_pattern", "*.jpg", "Glob pattern of filename of input images [*]")
flags.DEFINE_string("checkpoint_dir", "checkpoint", "Directory name to save the checkpoints [checkpoint]")
flags.DEFINE_string("data_dir", "./data", "Root directory of dataset [data]")
flags.DEFINE_string("sample_dir", "samples", "Directory name to save the image samples [samples]")
flags.DEFINE_boolean("train", False, "True for training, False for testing [False]")
flags.DEFINE_boolean("crop", False, "True for training, False for testing [False]")
flags.DEFINE_boolean("visualize", False, "True for visualizing, False for nothing [False]")
flags.DEFINE_integer("generate_test_images", 100, "Number of images to generate during test. [100]")
FLAGS = flags.FLAGS
           

tf.app.flags的了解

函數本身

  • tf.app.flags,用于接受指令行傳遞參數,flags可以幫助我們通過指令行來動态的更改代碼中的參數。比如,在這個py檔案中,首先定義了一些參數,然後将參數統一儲存到變量FLAGS中,相當于指派,後邊調用這些參數的時候直接使用FLAGS參數即可。
  • 基本參數類型有三種flags.DEFINE_integer、flags.DEFINE_float、flags.DEFINE_boolean。
  • 第一個是參數名稱,第二個參數是預設值,第三個是參數描述

使用過程

  1. 調用flags = tf.app.flags
  2. 參數定義說明:flags.DEFINE_integer、flags.DEFINE_float、flags.DEFINE_boolean
  3. FLAGS = flags.FLAGS
  4. tf.app.run

line33-48 定義主函數

def main(_):
  pp.pprint(flags.FLAGS.__flags)  # 列印所有外部定義的flag包含的參數

  if FLAGS.input_width is None:
    FLAGS.input_width = FLAGS.input_height
  if FLAGS.output_width is None:
    FLAGS.output_width = FLAGS.output_height

  if not os.path.exists(FLAGS.checkpoint_dir):
    os.makedirs(FLAGS.checkpoint_dir)
  if not os.path.exists(FLAGS.sample_dir):
    os.makedirs(FLAGS.sample_dir)

  # gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.333)
  run_config = tf.ConfigProto()  # 用在建立session的時候,用來對session進行參數配置
  run_config.gpu_options.allow_growth=True  # 配置設定器将不會指定所有的GPU記憶體,而是根據需求增長 
           

在model.py中 line18 定義了類pp = pprint.PrettyPrinter() ,是以在主函數中可以直接用pprint以解釋器可以解析的輸入形式列印python資料結構。

pprint.PrettyPrinter() 的了解

pprint子產品:列印出任何python資料結構類和方法

使用過程

  1. pp = pprint.PrettyPrinter()
  2. pp.pprint(data)  #具體函數

 tf.ConfigProto() 的了解

用于對session進行參數配置, 當allow_growth設定為True時,配置設定器将不會指定所有的GPU記憶體,而是根據需求增長。(還有其他可選參數可以進行設定)

line 33-105 主函數中定義session

with tf.Session(config=run_config) as sess:
    if FLAGS.dataset == 'mnist':  # 直接調用外部參數
      dcgan = DCGAN(
          sess,
          input_width=FLAGS.input_width,
          input_height=FLAGS.input_height,
          output_width=FLAGS.output_width,
          output_height=FLAGS.output_height,
          batch_size=FLAGS.batch_size,
          sample_num=FLAGS.batch_size,
          y_dim=10,
          z_dim=FLAGS.generate_test_images,
          dataset_name=FLAGS.dataset,
          input_fname_pattern=FLAGS.input_fname_pattern,
          crop=FLAGS.crop,
          checkpoint_dir=FLAGS.checkpoint_dir,
          sample_dir=FLAGS.sample_dir,
          data_dir=FLAGS.data_dir)
    else:
      dcgan = DCGAN(
          sess,
          input_width=FLAGS.input_width,
          input_height=FLAGS.input_height,
          output_width=FLAGS.output_width,
          output_height=FLAGS.output_height,
          batch_size=FLAGS.batch_size,
          sample_num=FLAGS.batch_size,
          z_dim=FLAGS.generate_test_images,
          dataset_name=FLAGS.dataset,
          input_fname_pattern=FLAGS.input_fname_pattern,
          crop=FLAGS.crop,
          checkpoint_dir=FLAGS.checkpoint_dir,
          sample_dir=FLAGS.sample_dir,
          data_dir=FLAGS.data_dir)

    show_all_variables()  # 顯示所有變量

    if FLAGS.train:
      dcgan.train(FLAGS)
    else:
      if not dcgan.load(FLAGS.checkpoint_dir)[0]:
        raise Exception("[!] Train a model first, then run test mode")
      

    # to_json("./web/js/layers.js", [dcgan.h0_w, dcgan.h0_b, dcgan.g_bn0],
    #                 [dcgan.h1_w, dcgan.h1_b, dcgan.g_bn1],
    #                 [dcgan.h2_w, dcgan.h2_b, dcgan.g_bn2],
    #                 [dcgan.h3_w, dcgan.h3_b, dcgan.g_bn3],
    #                 [dcgan.h4_w, dcgan.h4_b, None])

    # Below is codes for visualization
    OPTION = 1
    visualize(sess, dcgan, FLAGS, OPTION)  # model中定義的visualize function

if __name__ == '__main__':
  tf.app.run()
           

小白總結

代碼塊學習

(1)flag的使用

繼續閱讀