
Keras Code Analysis of GAN and Wasserstein GAN

https://github.com/tdeboissiere/DeepLearningImplementations

The code uses the MNIST dataset.

train_GAN.py:

        gen_loss = 100
        disc_loss = 100

        # Start training
        print("Start training")
        for e in range(nb_epoch):
            # one epoch
            # Initialize progbar and batch counter
            progbar = generic_utils.Progbar(epoch_size)
            batch_counter = 1
            start = time.time()

            for X_real_batch in data_utils.gen_batch(X_real_train, batch_size):
                # Randomly draw one batch from X_real_train; the number of iterations of this
                # loop is controlled by the `break` below, which fires once
                # batch_counter >= n_batch_per_epoch
                # Create a batch to feed the discriminator model
                X_disc, y_disc = data_utils.get_disc_batch(X_real_batch,
                                                           generator_model,
                                                           batch_counter,
                                                           batch_size,
                                                           noise_dim,
                                                           noise_scale=noise_scale,
                                                           label_smoothing=label_smoothing,
                                                           label_flipping=label_flipping)
                # get_disc_batch builds one batch of data for the discriminator;
                # X_disc and y_disc are the data and the labels respectively.
                # Unlike the WGAN version, data_utils.get_disc_batch here alternates on the
                # parity of batch_counter, so real and generated data are fed in turn.


                # Update the discriminator
                disc_loss = discriminator_model.train_on_batch(X_disc, y_disc)
                # one discriminator update
                # Create a batch to feed the generator model
                X_gen, y_gen = data_utils.get_gen_batch(batch_size, noise_dim, noise_scale=noise_scale)
                # sample one batch of noise for the generator
                # Freeze the discriminator
                discriminator_model.trainable = False
                gen_loss = DCGAN_model.train_on_batch(X_gen, y_gen)
                # one generator update
                # Unfreeze the discriminator
                discriminator_model.trainable = True

                batch_counter += 1
                progbar.add(batch_size, values=[("D logloss", disc_loss),
                                                ("G logloss", gen_loss)])

                # Save images for visualization
                if batch_counter % 100 == 0:
                    data_utils.plot_generated_batch(X_real_batch, generator_model,
                                                    batch_size, noise_dim, image_dim_ordering)

                if batch_counter >= n_batch_per_epoch:
                    break
# Summary: 1. The discriminator is fed one batch of real/generated data and updated once;
#          2. The generator is fed one batch of sampled noise (labelled as real) and updated once;
#          3. The generator and the discriminator are updated the same number of times.
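
For reference, here is a minimal sketch of what a parity-alternating get_disc_batch could look like. The body below is an assumption pieced together from the comments above (the noise sampling, the label shapes, and the label_smoothing / label_flipping handling are guesses), not the repository's exact implementation.

import numpy as np

def get_disc_batch_sketch(X_real_batch, generator_model, batch_counter, batch_size,
                          noise_dim, noise_scale=0.5, label_smoothing=False, label_flipping=0):
    # Alternate on the parity of batch_counter: one call returns generated images,
    # the next returns real images.
    if batch_counter % 2 == 0:
        noise = np.random.normal(scale=noise_scale, size=(batch_size, noise_dim))
        X_disc = generator_model.predict(noise)                  # generated images
        y_disc = np.zeros((X_disc.shape[0], 1))                  # labelled "fake"
    else:
        X_disc = X_real_batch                                     # real images
        y_disc = np.ones((X_disc.shape[0], 1))                   # labelled "real"
        if label_smoothing:
            y_disc = np.random.uniform(0.9, 1.0, size=y_disc.shape)
    if label_flipping > 0 and np.random.rand() < label_flipping:
        y_disc = 1 - y_disc                                       # occasionally flip the labels
    return X_disc, y_disc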
           

train_WGAN.py

    ############################
    # Compile models
    ############################
    generator_model.compile(loss='mse', optimizer=opt_G)
    discriminator_model.trainable = False
    DCGAN_model.compile(loss=models.wasserstein, optimizer=opt_G)
    discriminator_model.trainable = True
    discriminator_model.compile(loss=models.wasserstein, optimizer=opt_D)
    # Compile the models. DCGAN_model is generator_model followed by the discriminator with
    # discriminator_model.trainable = False, while discriminator_model itself is compiled
    # with discriminator_model.trainable = True.
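
As a rough sketch of how DCGAN_model is typically wired (assumed here, using the same Keras 1 functional API as the excerpts; the input name and the assumption that noise_dim is an integer are mine, not the repository's):

from keras.models import Model
from keras.layers import Input

def build_dcgan_sketch(generator_model, discriminator_model, noise_dim):
    # Chain the generator and the discriminator: the combined model maps noise -> D(G(z)).
    noise_input = Input(shape=(noise_dim,), name="dcgan_noise_input")
    generated_image = generator_model(noise_input)
    dcgan_score = discriminator_model(generated_image)
    return Model(input=noise_input, output=dcgan_score, name="DCGAN")

Because discriminator_model.trainable is set to False before DCGAN_model.compile, a call to DCGAN_model.train_on_batch only updates the generator's weights.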

    # Global iteration counter for generator updates
    gen_iterations = 0

    #################
    # Start training
    ################
    for e in range(nb_epoch):  # nb_epoch: total number of training epochs
        # Initialize progbar and batch counter
        progbar = generic_utils.Progbar(epoch_size)
        batch_counter = 0
        start = time.time()

        while batch_counter < n_batch_per_epoch:  # n_batch_per_epoch: number of batches per epoch
            # keep feeding batches until the per-epoch batch limit is reached
            if gen_iterations < 25 or gen_iterations % 500 == 0:
                disc_iterations = 100
            else:
                disc_iterations = kwargs["disc_iterations"]
            # At the start, while gen_iterations < 25, each generator update is preceded by
            # 100 critic updates, one batch per critic update;
            # once gen_iterations >= 25, each generator update is preceded by 5 critic
            # updates (disc_iterations defaults to 5);
            # whenever gen_iterations is divisible by 500, the critic is again updated
            # 100 times for that single generator update.



            ###################################
            # 1) Train the critic / discriminator
            ###################################
            list_disc_loss_real = []
            list_disc_loss_gen = []
            for disc_it in range(disc_iterations):
                # the critic is updated disc_iterations times per batch_counter step
                # Clip discriminator weights
                for l in discriminator_model.layers:
                    weights = l.get_weights()
                    weights = [np.clip(w, clamp_lower, clamp_upper) for w in weights]
                    l.set_weights(weights)
                X_real_batch = next(data_utils.gen_batch(X_real_train, batch_size))

                # Create a batch to feed the discriminator model
                X_disc_real, X_disc_gen = data_utils.get_disc_batch(X_real_batch,
                                                                    generator_model,
                                                                    batch_counter,
                                                                    batch_size,
                                                                    noise_dim,
                                                                    noise_scale=noise_scale)

                # Update the discriminator
                disc_loss_real = discriminator_model.train_on_batch(X_disc_real, -np.ones(X_disc_real.shape[0]))
                # real data fed with label -1
                disc_loss_gen = discriminator_model.train_on_batch(X_disc_gen, np.ones(X_disc_gen.shape[0]))
                # generated data fed with label +1
                # The loss is the Wasserstein loss. For real samples the critic's prediction
                # should be as high as possible, yet the loss is minimized; with y_true = -1,
                # mean(y_true * y_pred) = mean(-y_pred), so minimizing it is exactly what we want.
                # For generated samples the prediction should be as low as possible, and with
                # y_true = +1, mean(y_true * y_pred) = mean(y_pred), which again fits minimization.
                list_disc_loss_real.append(disc_loss_real)
                list_disc_loss_gen.append(disc_loss_gen)
            # Sample a batch of noise to feed the generator
            X_gen = data_utils.get_gen_batch(batch_size, noise_dim, noise_scale=noise_scale)
            # After disc_iterations critic updates, the generator is updated once
            # Freeze the discriminator
            discriminator_model.trainable = False
            gen_loss = DCGAN_model.train_on_batch(X_gen, -np.ones(X_gen.shape[0]))
            # Unfreeze the discriminator
            discriminator_model.trainable = True

            gen_iterations += 1
            batch_counter += 1
# Summary:
# 1. In each critic iteration, the weights of every layer are clipped first; then one batch
#    of real data is fed and the critic updated, then one batch of generated data is fed
#    and the critic updated again.
# 2. The critic is iterated like this several times, each time on a freshly drawn batch.
# 3. After those critic updates, the generator is fed one batch of sampled noise and
#    updated once.
# 4. After each generator update, batch_counter is incremented by 1, until it reaches the
#    number of batches per epoch; the critic therefore sees more batches than this count.
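
The per-layer weight clipping in step 1 can also be written as a small standalone helper; a minimal sketch, assuming the usual symmetric clamp bounds from the WGAN paper (-0.01 and 0.01) as defaults:

import numpy as np

def clip_discriminator_weights(discriminator_model, clamp_lower=-0.01, clamp_upper=0.01):
    # Clamp every weight tensor of every layer into [clamp_lower, clamp_upper];
    # this is how the original WGAN enforces the Lipschitz constraint on the critic.
    for layer in discriminator_model.layers:
        weights = [np.clip(w, clamp_lower, clamp_upper) for w in layer.get_weights()]
        layer.set_weights(weights)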
           

models_WGAN.py

def generator_upsampling(noise_dim, img_dim, bn_mode, model_name="generator_upsampling", dset="mnist"):
    # Noise input and reshaping
    x = Dense(f * start_dim * start_dim, input_dim=noise_dim)(gen_input)
    x = Reshape(reshape_shape)(x)
    x = BatchNormalization(mode=bn_mode, axis=bn_axis)(x)
    x = Activation("relu")(x)
    # 100-dimensional noise, fully connected to 512*7*7 units, reshaped into 512 feature maps
    # of size 7*7, followed by BN and ReLU
    # Up to this point there is no difference from models_GAN
    # Upscaling blocks: Upsampling2D->Conv2D->ReLU->BN->Conv2D->ReLU
    for i in range(nb_upconv):
        x = UpSampling2D(size=(2, 2))(x)
        nb_filters = int(f / (2 ** (i + 1)))
        x = Convolution2D(nb_filters, 3, 3, border_mode="same", init=conv2D_init)(x)
        x = BatchNormalization(mode=bn_mode, axis=bn_axis)(x)
        x = Activation("relu")(x)
        x = Convolution2D(nb_filters, 3, 3, border_mode="same", init=conv2D_init)(x)
        x = Activation("relu")(x)
    # Two upsampling blocks; after each upsample there are two 3*3 convolutions with ReLU;
    # the convolutions are padded, so the feature-map size is unchanged
    # Last Conv to get the output image
    x = Convolution2D(output_channels, 3, 3, name="gen_conv2d_final",
                      border_mode="same", activation='tanh', init=conv2D_init)(x)
    # final 3*3 convolution with tanh, giving the 28*28 output image
    generator_model = Model(input=[gen_input], output=[x], name=model_name)
    visualize_model(generator_model)
 # Compared with the plain GAN, the extra piece is the conv2D_init weight initialization
    return generator_model
 # Generator summary:
 # 1. The input noise is 100-dimensional.
 # 2. A fully connected layer produces 512*7*7 units, reshaped into 512 feature maps of 7*7.
 # 3. One upsampling step gives 512*14*14; a convolution with 256 3*3 kernels keeps the
 #    feature-map size, followed by another convolution with 256 3*3 kernels.
 # 4. Another upsampling step gives 256*28*28, then a convolution with 128 3*3 kernels,
 #    repeated once, both producing 128*28*28.
 # 5. The final convolution has a single feature map, giving 1*28*28.
 # Overall architecture:
 # 100 -> 7*7*512 -> upsample 14*14*512 -> conv(3*3) 14*14*256 -> conv(3*3) 14*14*256
 # -> upsample 28*28*256 -> conv(3*3) 28*28*128 -> conv(3*3) 28*28*128
 # -> conv(3*3) 28*28*1
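
The shape progression above can be checked with a short calculation (values f = 512, start_dim = 7, nb_upconv = 2 taken from the comments; channels-first layout assumed):

f, start_dim, nb_upconv = 512, 7, 2
size, channels = start_dim, f
print("after dense/reshape:", (channels, size, size))          # (512, 7, 7)
for i in range(nb_upconv):
    size *= 2                                                    # UpSampling2D(size=(2, 2))
    channels = int(f / (2 ** (i + 1)))                           # 256, then 128 filters
    print("after block %d:" % (i + 1), (channels, size, size))   # (256, 14, 14), then (128, 28, 28)
print("after final conv:", (1, size, size))                     # (1, 28, 28)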
def discriminator(img_dim, bn_mode, model_name="discriminator"):
    # Get the list of number of conv filters
    # (first layer starts with 64), filters are subsequently doubled
    nb_conv = int(np.floor(np.log(min_s // 4) / np.log(2)))  # number of conv layers, from this formula
    list_f = [64 * min(8, (2 ** i)) for i in range(nb_conv)]
    # The first conv layer has 64 filters; each following layer doubles the count, up to a factor of 8
    # This list definition differs from the plain GAN version
    # First conv with 2x2 strides
    x = Convolution2D(list_f[0], 3, 3, subsample=(2, 2), name="disc_conv2d_1",
                      border_mode="same", bias=False, init=conv2D_init)(disc_input)
    x = BatchNormalization(mode=bn_mode, axis=bn_axis)(x)
    x = LeakyReLU()(x)
    # first convolution, with 2x2 strides
    # Conv blocks: Conv2D(2x2 strides)->BN->LReLU
    for i, f in enumerate(list_f[1:]):
        name = "disc_conv2d_%s" % (i + 2)
        x = Convolution2D(f, 3, 3, subsample=(2, 2), name=name, border_mode="same", bias=False, init=conv2D_init)(x)
        x = BatchNormalization(mode=bn_mode, axis=bn_axis)(x)
        x = LeakyReLU()(x)
    # The differences from the plain GAN are the conv2D_init initialization and bias=False
    # Each of these conv layers is followed by BatchNorm and LeakyReLU

    # Last convolution
    #x = Convolution2D(1, 3, 3, name="last_conv", border_mode="same", bias=False, init=conv2D_init)(x)
    x = Convolution2D(1, 3, 3, name="last_conv", border_mode="same", bias=False)(x)
    # Average pooling
    x = GlobalAveragePooling2D()(x)
 # The final conv layer has a single feature map; a global average over it gives the score
 # (the plain GAN ends with a softmax instead).
 # Discriminator summary:
 # 1. The input is 28*28.
 # 2. The first layer has 64 kernels of size 3*3, stride subsample=(2, 2), padding "same";
 #    the output is 14*14*64 (parameters given in tf conv2d terms).
 # 3. The second layer has 128 kernels of size 3*3, stride subsample=(2, 2), padding "same";
 #    the output is 7*7*128.
 # 4. The last conv layer has a single 3*3 kernel, padding "same"; the output is 7*7*1,
 #    which is then globally averaged into a single score.
 # Overall architecture: three conv layers
 # 28*28 -> conv(3*3, stride 2*2) 14*14*64 -> conv(3*3, stride 2*2) 7*7*128
 # -> conv(3*3) 7*7*1 -> global pooling -> 1
    discriminator_model = Model(input=[disc_input], output=[x], name=model_name)
    visualize_model(discriminator_model)

    return discriminator_model
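
For MNIST the formula for nb_conv works out to exactly the two strided conv layers described above; a quick check (taking min_s = 28 as the smallest spatial dimension of the input is my assumption):

import numpy as np

min_s = 28                                               # smallest spatial dimension of the input
nb_conv = int(np.floor(np.log(min_s // 4) / np.log(2)))  # floor(log(7) / log(2)) = 2
list_f = [64 * min(8, (2 ** i)) for i in range(nb_conv)]
print(nb_conv, list_f)                                   # 2 [64, 128]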
def wasserstein(y_true, y_pred):

    # return K.mean(y_true * y_pred) / K.mean(y_true)
    return K.mean(y_true * y_pred)
    # All Keras losses are minimized. For real samples we want to maximize
    # true_score = K.mean(y_pred); since y_true = -1 here, that maximization becomes a minimization.
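
The sign bookkeeping can be checked numerically; a small NumPy sketch of the same mean(y_true * y_pred) expression, with made-up critic scores:

import numpy as np

def wasserstein_np(y_true, y_pred):
    # NumPy mirror of the Keras loss above, only for checking the signs
    return np.mean(y_true * y_pred)

critic_on_real = np.array([2.0, 3.0])    # hypothetical critic scores on real samples
critic_on_fake = np.array([-1.0, 0.5])   # hypothetical critic scores on generated samples

# Real samples carry label -1: loss = -mean(scores), so minimizing pushes the real scores up.
print(wasserstein_np(-np.ones(2), critic_on_real))   # -2.5
# Generated samples carry label +1: loss = mean(scores), so minimizing pushes the fake scores down.
print(wasserstein_np(np.ones(2), critic_on_fake))    # -0.25
# The generator is trained with label -1 on its own samples, so its loss pushes D(G(z)) up.
print(wasserstein_np(-np.ones(2), critic_on_fake))   # 0.25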
           
