天天看點

風格遷移0-09:stylegan-源碼無死角解讀(5)-Discriminator網絡詳解

以下連結是個人關于stylegan所有見解,如有錯誤歡迎大家指出,我會第一時間糾正,如有興趣可以加微信:17575010159 互相讨論技術。若是幫助到了你什麼,一定要記得點贊奧!因為這是對我最大的鼓勵。

風格遷移0-00:stylegan-目錄-史上最全:https://blog.csdn.net/weixin_43013761/article/details/100895333

源碼注釋

在貼出源碼注釋之前,我想先把Discriminator的結構貼出來吧。如果大家已經跑過該程式,應該已經看到過很多次下面這段列印了:

D                     Params    OutputShape          WeightShape     
---                   ---       ---                  ---             
images_in             -         (?, 3, 1024, 1024)   -               
labels_in             -         (?, 0)               -               
lod                   -         ()                   -               
FromRGB_lod0          64        (?, 16, 1024, 1024)  (1, 1, 3, 16)   
1024x1024/Conv0       2320      (?, 16, 1024, 1024)  (3, 3, 16, 16)  
1024x1024/Conv1_down  4640      (?, 32, 512, 512)    (3, 3, 16, 32)  
Downscale2D           -         (?, 3, 512, 512)     -               
FromRGB_lod1          128       (?, 32, 512, 512)    (1, 1, 3, 32)   
Grow_lod0             -         (?, 32, 512, 512)    -               
512x512/Conv0         9248      (?, 32, 512, 512)    (3, 3, 32, 32)  
512x512/Conv1_down    18496     (?, 64, 256, 256)    (3, 3, 32, 64)  
Downscale2D_1         -         (?, 3, 256, 256)     -               
FromRGB_lod2          256       (?, 64, 256, 256)    (1, 1, 3, 64)   
Grow_lod1             -         (?, 64, 256, 256)    -               
256x256/Conv0         36928     (?, 64, 256, 256)    (3, 3, 64, 64)  
256x256/Conv1_down    73856     (?, 128, 128, 128)   (3, 3, 64, 128) 
Downscale2D_2         -         (?, 3, 128, 128)     -               
FromRGB_lod3          512       (?, 128, 128, 128)   (1, 1, 3, 128)  
Grow_lod2             -         (?, 128, 128, 128)   -               
128x128/Conv0         147584    (?, 128, 128, 128)   (3, 3, 128, 128)
128x128/Conv1_down    295168    (?, 256, 64, 64)     (3, 3, 128, 256)
Downscale2D_3         -         (?, 3, 64, 64)       -               
FromRGB_lod4          1024      (?, 256, 64, 64)     (1, 1, 3, 256)  
Grow_lod3             -         (?, 256, 64, 64)     -               
64x64/Conv0           590080    (?, 256, 64, 64)     (3, 3, 256, 256)
64x64/Conv1_down      1180160   (?, 512, 32, 32)     (3, 3, 256, 512)
Downscale2D_4         -         (?, 3, 32, 32)       -               
FromRGB_lod5          2048      (?, 512, 32, 32)     (1, 1, 3, 512)  
Grow_lod4             -         (?, 512, 32, 32)     -               
32x32/Conv0           2359808   (?, 512, 32, 32)     (3, 3, 512, 512)
32x32/Conv1_down      2359808   (?, 512, 16, 16)     (3, 3, 512, 512)
Downscale2D_5         -         (?, 3, 16, 16)       -               
FromRGB_lod6          2048      (?, 512, 16, 16)     (1, 1, 3, 512)  
Grow_lod5             -         (?, 512, 16, 16)     -               
16x16/Conv0           2359808   (?, 512, 16, 16)     (3, 3, 512, 512)
16x16/Conv1_down      2359808   (?, 512, 8, 8)       (3, 3, 512, 512)
Downscale2D_6         -         (?, 3, 8, 8)         -               
FromRGB_lod7          2048      (?, 512, 8, 8)       (1, 1, 3, 512)  
Grow_lod6             -         (?, 512, 8, 8)       -               
8x8/Conv0             2359808   (?, 512, 8, 8)       (3, 3, 512, 512)
8x8/Conv1_down        2359808   (?, 512, 4, 4)       (3, 3, 512, 512)
Downscale2D_7         -         (?, 3, 4, 4)         -               
FromRGB_lod8          2048      (?, 512, 4, 4)       (1, 1, 3, 512)  
Grow_lod7             -         (?, 512, 4, 4)       -               
4x4/MinibatchStddev   -         (?, 513, 4, 4)       -               
4x4/Conv              2364416   (?, 512, 4, 4)       (3, 3, 513, 512)
4x4/Dense0            4194816   (?, 512)             (8192, 512)     
4x4/Dense1            513       (?, 1)               (512, 1)        
scores_out            -         (?, 1)               -               
---                   ---       ---                  ---             
Total                 23087249   
           

下面是源碼的注釋:

#----------------------------------------------------------------------------
# Discriminator used in the StyleGAN paper.

def D_basic(
    images_in,                          # First input: Images [minibatch, channel, height, width].
    labels_in,                          # Second input: Labels [minibatch, label_size].
    num_channels        = 1,            # Number of input color channels. Overridden based on dataset.
    resolution          = 32,           # Input resolution. Overridden based on dataset.
    label_size          = 0,            # Dimensionality of the labels, 0 if no labels. Overridden based on dataset.
    fmap_base           = 8192,         # Overall multiplier for the number of feature maps.
    fmap_decay          = 1.0,          # log2 feature map reduction when doubling the resolution.
    fmap_max            = 512,          # Maximum number of feature maps in any layer.
    nonlinearity        = 'lrelu',      # Activation function: 'relu', 'lrelu',
    use_wscale          = True,         # Enable equalized learning rate?
    mbstd_group_size    = 4,            # Group size for the minibatch standard deviation layer, 0 = disable.
    mbstd_num_features  = 1,            # Number of features for the minibatch standard deviation layer.
    dtype               = 'float32',    # Data type to use for activations and outputs.
    fused_scale         = 'auto',       # True = fused convolution + scaling, False = separate ops, 'auto' = decide automatically.
    blur_filter         = [1,2,1],      # Low-pass filter to apply when resampling activations. None = no filtering.
    structure           = 'auto',       # 'fixed' = no progressive growing, 'linear' = human-readable, 'recursive' = efficient, 'auto' = select automatically.
    is_template_graph   = False,        # True = template graph constructed by the Network class, False = actual evaluation.
    **_kwargs):                         # Ignore unrecognized keyword args.
    """Build the StyleGAN discriminator graph and return its score tensor.

    The input image is mapped to feature maps by a 1x1 `FromRGB` convolution
    and then passed through one `block` per resolution, from `resolution`
    down to 4x4. The 4x4 block optionally applies the minibatch-stddev layer
    and ends in two dense layers; the last one has `max(label_size, 1)`
    outputs and no activation.

    How the per-resolution blocks are wired depends on `structure`:
      * 'fixed'     - a plain chain of blocks; the `lod` variable is not used.
      * 'linear'    - every block is built; after each block the result is
                      blended (`tflib.lerp_clip`) with `FromRGB` of the
                      downscaled input image, weighted by `lod_in - lod`.
      * 'recursive' - the same kind of blending, expressed as nested
                      `tf.cond` branches keyed on `lod_in`.
      * 'auto'      - 'linear' if `is_template_graph` else 'recursive'.

    Side effects: calls `set_shape` on `images_in` and `labels_in`, and
    gets/creates the non-trainable variable 'lod' in the current variable
    scope.

    Returns:
        A tensor named 'scores_out' with dtype `dtype`. When `label_size` is
        non-zero the per-label outputs are multiplied by `labels_in` and
        summed over axis 1 (with `keepdims=True`).
    """

    # e.g. resolution=1024 (input (?, 3, 1024, 1024)) gives resolution_log2 = 10.
    # The assert rejects resolutions that are not a power of two or are below 4.
    resolution_log2 = int(np.log2(resolution))
    assert resolution == 2**resolution_log2 and resolution >= 4

    # Number of feature maps for a given stage:
    # fmap_base / 2**(stage * fmap_decay), truncated to int and capped at fmap_max.
    def nf(stage): return min(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_max)

    # Apply blur2d with blur_filter; pass x through unchanged when blur_filter is None/empty.
    def blur(x): return blur2d(x, blur_filter) if blur_filter else x

    # Resolve 'auto', then look up the activation function and the gain that is
    # forwarded to the conv/dense helpers. An unknown nonlinearity raises KeyError.
    if structure == 'auto': structure = 'linear' if is_template_graph else 'recursive'
    act, gain = {'relu': (tf.nn.relu, np.sqrt(2)), 'lrelu': (leaky_relu, np.sqrt(2))}[nonlinearity]

    # Pin the static shape of the image input to [N, num_channels, resolution, resolution].
    images_in.set_shape([None, num_channels, resolution, resolution])
    # Pin the static shape of the label input to [N, label_size].
    labels_in.set_shape([None, label_size])

    # Cast both inputs to the working dtype.
    images_in = tf.cast(images_in, dtype)
    labels_in = tf.cast(labels_in, dtype)

    # Current level-of-detail, read from the non-trainable variable 'lod' (initialised to 0.0).
    # It is compared against each stage's lod index (resolution_log2 - res) below to blend/select stages.
    # NOTE(review): the variable is presumably updated by the training loop outside this function - confirm.
    lod_in = tf.cast(tf.get_variable('lod', initializer=np.float32(0.0), trainable=False), dtype)
    scores_out = None

    # Building blocks.
    # fromrgb: 1x1 convolution + bias + activation that turns an image into nf(res-1) feature maps.
    # `res` is the log2 of the spatial resolution the layer works at (not the resolution itself);
    # the variable scope is named after the matching lod index, resolution_log2 - res.
    def fromrgb(x, res): # res = 2..resolution_log2
        with tf.variable_scope('FromRGB_lod%d' % (resolution_log2 - res)):
            return act(apply_bias(conv2d(x, fmaps=nf(res-1), kernel=1, gain=gain, use_wscale=use_wscale)))

    # block: one discriminator stage, scoped as '<2**res>x<2**res>'.
    #   res >= 3: 3x3 conv (Conv0), then blur followed by conv2d_downscale2d (Conv1_down).
    #   res == 2: optional minibatch stddev layer, 3x3 conv, Dense0, and Dense1 which
    #             emits max(label_size, 1) values with gain=1 and no activation.
    def block(x, res): # res = 2..resolution_log2
        with tf.variable_scope('%dx%d' % (2**res, 2**res)):
            if res >= 3: # 8x8 and up
                with tf.variable_scope('Conv0'):
                    x = act(apply_bias(conv2d(x, fmaps=nf(res-1), kernel=3, gain=gain, use_wscale=use_wscale)))
                with tf.variable_scope('Conv1_down'):
                    x = act(apply_bias(conv2d_downscale2d(blur(x), fmaps=nf(res-2), kernel=3, gain=gain, use_wscale=use_wscale, fused_scale=fused_scale)))
            else: # 4x4
                if mbstd_group_size > 1:
                    x = minibatch_stddev_layer(x, mbstd_group_size, mbstd_num_features)
                with tf.variable_scope('Conv'):
                    x = act(apply_bias(conv2d(x, fmaps=nf(res-1), kernel=3, gain=gain, use_wscale=use_wscale)))
                with tf.variable_scope('Dense0'):
                    x = act(apply_bias(dense(x, fmaps=nf(res-2), gain=gain, use_wscale=use_wscale)))
                with tf.variable_scope('Dense1'):
                    x = apply_bias(dense(x, fmaps=max(label_size, 1), gain=1, use_wscale=use_wscale))
            return x

    # Fixed structure: simple and efficient, but does not support progressive growing.
    # Straight chain: FromRGB at full resolution, then every block from resolution_log2 down to 2.
    if structure == 'fixed':
        x = fromrgb(images_in, resolution_log2)
        for res in range(resolution_log2, 2, -1):
            x = block(x, res)
        scores_out = block(x, 2)

    # Linear structure: simple but inefficient.
    # Start at full resolution and step down. At every step the input image is downscaled as well,
    # fed through its own FromRGB, and blended with the block output via lerp_clip(x, y, lod_in - lod).
    if structure == 'linear':
        img = images_in
        x = fromrgb(img, resolution_log2)
        for res in range(resolution_log2, 2, -1):
            lod = resolution_log2 - res
            x = block(x, res)
            img = downscale2d(img)
            y = fromrgb(img, res - 1)
            with tf.variable_scope('Grow_lod%d' % lod):
                x = tflib.lerp_clip(x, y, lod_in - lod)
        scores_out = block(x, 2)

    # Recursive structure: complex but efficient.
    # cset(cur, cond, new) wraps two graph-building lambdas in a tf.cond: `new` when cond holds, else `cur`.
    # grow(res, lod) builds the stage for log2-resolution `res`, whose lod index is `lod`:
    #   * its input is fromrgb(downscale2d(images_in, 2**lod), res), or - when lod_in < lod -
    #     the output of the next higher-resolution stage, grow(res + 1, lod - 1);
    #   * its output is block(input, res), or - when lod_in > lod and res > 2 - that output
    #     lerp-ed with fromrgb(downscale2d(images_in, 2**(lod+1)), res - 1), weight lod_in - lod.
    # Construction starts at the 4x4 stage (res=2, lod=resolution_log2 - 2) and recurses upward.
    # NOTE(review): the "efficient" label presumably comes from tf.cond only evaluating the
    # selected branch at run time - confirm against the TensorFlow version in use.
    if structure == 'recursive':
        def cset(cur_lambda, new_cond, new_lambda):
            return lambda: tf.cond(new_cond, new_lambda, cur_lambda)
        def grow(res, lod):
            x = lambda: fromrgb(downscale2d(images_in, 2**lod), res)
            if lod > 0: x = cset(x, (lod_in < lod), lambda: grow(res + 1, lod - 1))
            x = block(x(), res); y = lambda: x
            if res > 2: y = cset(y, (lod_in > lod), lambda: tflib.lerp(x, fromrgb(downscale2d(images_in, 2**(lod+1)), res - 1), lod_in - lod))
            return y()
        scores_out = grow(2, resolution_log2 - 2)

    # Label conditioning from "Which Training Methods for GANs do actually Converge?"
    # Skipped when label_size is 0 (the default), i.e. for unlabeled datasets.
    if label_size:
        with tf.variable_scope('LabelSwitch'):
            scores_out = tf.reduce_sum(scores_out * labels_in, axis=1, keepdims=True)
    # The network summary lists scores_out as (?, 1) for the unlabeled case. Dense1 has no
    # activation, so this is a raw score rather than a probability.
    # NOTE(review): how the score is interpreted depends on the loss function, which is not shown here.
    assert scores_out.dtype == tf.as_dtype(dtype)
    scores_out = tf.identity(scores_out, name='scores_out')
    return scores_out
           

這樣的結構真的很簡單,一句話總結:輸入圖檔,然後通過一系列的卷積、激活、全連接操作,最後得到一個值,這個值就是判別器對該圖檔是否為真實圖檔給出的評分(注意最後一層Dense1之後沒有接激活函數,所以它是未歸一化的得分,而不是直接的機率值)。

到這裡,我們對整個網絡的分析就已經完成了,下面就會進入我們核心的核心了,那就是損失函數的講解。如果覺得我寫的部落格對大家有所幫助,希望大家能給我點點贊,感謝大家一直以來的關注。