
Paper Reading and PyTorch Source Code Walkthrough: Image Inpainting via Generative Multi-Column Convolutional Neural Networks

1. Motivation

1. To extract suitable features from the image, the paper proposes a generator with multiple parallel convolutional branches; each branch uses a different receptive field, so the image is decomposed into components at different receptive-field levels.

2. To find patches similar to the missing region, the paper introduces an implicit diversified Markov random field (ID-MRF) term.

3. Because a missing region admits many plausible completions, a new confidence-driven reconstruction loss (similar to a spatially discounted loss) is proposed, which constrains the generated content according to its spatial location within the missing region.

2. Method

[Figure: the GMCNN framework, consisting of a generator with three parallel branches and a shared decoder, global and local discriminators, and a pretrained VGG network used to compute the ID-MRF loss during training]

Training is end-to-end. The input is the corrupted image X together with the mask M; pixels in the missing region of X are filled with 0. M is a binary mask in which 0 marks known pixels and 1 marks the missing region.
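A minimal sketch of this input convention (mirroring what initVariables does in the code of Section 3.3; the shapes and the hole position are only illustrative):

import torch

B, H, W = 2, 256, 256
gt = torch.rand(B, 3, H, W) * 2 - 1      # a batch of ground-truth images scaled to [-1, 1]
mask = torch.zeros(B, 1, H, W)
mask[:, :, 64:192, 64:192] = 1           # binary mask M: 1 = missing region, 0 = known pixels
im_in = gt * (1 - mask)                  # corrupted image X: hole pixels set to 0
gin = torch.cat((im_in, mask), dim=1)    # 4-channel generator input (masked RGB + mask)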

2.1 Network Architecture

As shown in the figure above, the framework contains three sub-networks: a generator, global and local discriminators, and a pretrained VGG network used to compute the ID-MRF loss. At test time only the generator is used.

The generator contains three parallel encoder-decoder convolutional branches that extract features at different levels from the input (the corrupted image concatenated with the mask M), plus a shared decoder. The feature maps from the three branches (upsampled to the original image size) are concatenated and fed into the shared decoder, which maps the combined features back to the natural image space, i.e. performs the inpainting. As shown in Figure 2, the three branches use different receptive fields for feature extraction. Different receptive fields and strides would leave the branches with feature maps of different sizes, which could not be concatenated directly, so bilinear interpolation is used to upsample them back to a common size.

Although the three branches look independent of each other, they influence one another through the shared decoder.
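To make the multi-column idea concrete, here is a minimal, self-contained sketch (my simplification, not the paper's exact architecture; the real generator in Section 3.3 has far more layers, dilated convolutions and reflection padding). Each branch uses a different kernel size, downsamples twice, and is bilinearly upsampled back to the input resolution so the three feature maps can be concatenated and decoded by a shared decoder:

import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyBranch(nn.Module):
    # one encoder branch; a larger kernel gives a larger receptive field
    def __init__(self, k):
        super().__init__()
        self.conv1 = nn.Conv2d(4, 16, kernel_size=k, stride=2, padding=k // 2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=k, stride=2, padding=k // 2)

    def forward(self, x):
        h = F.elu(self.conv1(x))
        h = F.elu(self.conv2(h))
        # bilinear upsampling back to the input resolution so the branches can be concatenated
        return F.interpolate(h, scale_factor=4, mode='bilinear', align_corners=False)

class TinyMultiColumn(nn.Module):
    def __init__(self):
        super().__init__()
        self.branches = nn.ModuleList([TinyBranch(k) for k in (7, 5, 3)])
        self.decoder = nn.Conv2d(32 * 3, 3, kernel_size=3, padding=1)  # shared decoder

    def forward(self, x):
        feats = torch.cat([b(x) for b in self.branches], dim=1)
        return torch.tanh(self.decoder(feats))

print(TinyMultiColumn()(torch.randn(1, 4, 256, 256)).shape)  # torch.Size([1, 3, 256, 256])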

2.2 ID-MRF Regularization

This part tackles semantic structure matching while avoiding the heavy computation of iterative MRF optimization. The plan is to use the MRF regularizer only during training. ID-MRF minimizes, in feature space, the difference between the generated (inpainted) content and its nearest neighbors taken from the corresponding ground-truth image. Because it is used only during training, the complete ground-truth image is available, which yields high-quality nearest neighbors and gives the network an appropriate constraint.

To compute the ID-MRF loss, one could simply use a direct similarity measure (e.g. cosine similarity) to find the nearest neighbors of the patches in the generated content. But this tends to produce smooth structures, because a flat region easily matches many similar patterns, which quickly reduces structural diversity. Instead, a relative distance measure [17,16,22] is adopted to model the relation between local features and the target feature set; it can recover subtle details, as shown in Figure 3(b).

[Figure 3: inpainting results; (b) shows the subtle details recovered with the relative distance measure]

Specifically, let $Y_g^*$ denote the generated content for the missing region, and let $Y_g^{*L}$ and $Y^L$ denote the layer-$L$ features of $Y_g^*$ and of the ground-truth image $Y$, extracted from a pretrained model.

For patches $v$ and $s$ taken from $Y_g^{*L}$ and $Y^L$ respectively, the relative similarity of $v$ to $s$ is defined as:

$$RS(v,s)=\exp\left(\frac{\mu(v,s)}{\max_{r\in\rho_s(Y^L)}\mu(v,r)+\epsilon}\,\Big/\,h\right)\tag{1}$$

Note: $Y$ is the ground-truth image.

Here $\mu(\cdot,\cdot)$ is the cosine similarity, $r\in\rho_s(Y^L)$ means that $r$ ranges over the patches of $Y^L$ excluding $s$, and $h$ and $\epsilon$ are two positive constants. If $v$ is more similar to $s$ than to the other patches in $Y^L$, $RS(v,s)$ becomes large.

Next, $RS(v,s)$ is normalized as:

$$\overline{RS}(v,s)=RS(v,s)\,\Big/\sum_{r\in\rho_s(Y^L)}RS(v,r)\tag{2}$$

Finally, based on Eq. (2), the ID-MRF loss between $Y_g^{*L}$ and $Y^L$ is defined as:

$$\mathcal{L}_M(L)=-\log\left(\frac{1}{Z}\sum_{s\in Y^L}\max_{v\in Y_g^{*L}}\overline{RS}(v,s)\right)\tag{3}$$

Here $Z$ is a normalization constant. For each patch $s\in Y^L$, $v'=\arg\max_{v\in Y_g^{*L}}\overline{RS}(v,s)$.

This means that $v'$ is closer to $s$ than the other patches in $Y_g^{*L}$ are. An extreme case is that all the patches in $Y_g^{*L}$ are very close to one single patch $s$; then, for every other patch $r$, $\max_{v\in Y_g^{*L}}\overline{RS}(v,r)$ is small, and $\mathcal{L}_M(L)$ is accordingly large.

On the other hand, when the patches in $Y_g^{*L}$ are close to different candidates in $Y^L$, each patch $r$ in $Y^L$ has a unique nearest neighbor in $Y_g^{*L}$. As a result, $\overline{RS}(v',r)$ becomes large and $\mathcal{L}_M(L)$ becomes small.

From this perspective, minimizing $\mathcal{L}_M(L)$ encourages each patch in $Y_g^{*L}$ to approach a different patch in $Y^L$, which diversifies the generated content.

An obvious benefit of this scheme is that it improves the similarity between the feature distributions of $Y_g^{*L}$ and $Y^L$. By minimizing the ID-MRF loss, not only does each local neural patch find a corresponding candidate texture in $Y^L$, but the two feature distributions are also pulled closer, which helps capture the variation of complex textures.

Our final ID-MRF loss is computed on several feature layers of VGG19. Following common practice [5,14], conv4_2 is used to describe the image's semantic structure, and conv3_2 together with conv4_2 are used to describe its texture:

$$\mathcal{L}_{mrf}=\mathcal{L}_M(conv4\_2)+\sum_{t=3}^{4}\mathcal{L}_M(conv\,t\_2)\tag{4}$$
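A toy numerical sketch of Eqs. (1)-(3) (my illustration; in the actual implementation, IDMRFLoss in Section 3.4, the patches are 1x1 positions of VGG19 feature maps, the cosine similarities are computed with a patch convolution, and the normalization there runs over all candidate patches rather than excluding $s$):

import torch
import torch.nn.functional as F

def id_mrf_toy(gen_patches, gt_patches, h=0.5, eps=1e-5):
    # gen_patches: (V, C) features from Y_g^{*L}; gt_patches: (S, C) features from Y^L
    mu = F.normalize(gen_patches, dim=1) @ F.normalize(gt_patches, dim=1).t()  # cosine similarity mu(v, s)
    S = mu.size(1)
    max_other = torch.empty_like(mu)
    for s in range(S):  # max similarity of v to the patches r != s
        max_other[:, s] = torch.cat([mu[:, :s], mu[:, s + 1:]], dim=1).max(dim=1).values
    rs = torch.exp((mu / (max_other + eps)) / h)      # Eq. (1): relative similarity RS(v, s)
    rs_bar = rs / rs.sum(dim=1, keepdim=True)         # Eq. (2): normalize over the ground-truth patches
    # Eq. (3): for each gt patch s keep its best-matching generated patch, average, take -log
    return -torch.log(rs_bar.max(dim=0).values.mean())

gen = torch.randn(6, 64)   # 6 patches from the inpainted feature map
gt = torch.randn(4, 64)    # 4 patches from the ground-truth feature map
print(id_mrf_toy(gen, gt).item())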

2.3 Information Fusion

  1. Spatial (confidence-driven) reconstruction loss

    Missing pixels close to the region boundary should be constrained more strongly than those far from it (see the sketch after this list).

  2. Adversarial loss

    Implemented with the improved WGAN (WGAN-GP).
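A rough sketch of the confidence-driven weighting idea (my approximation; the repository implements it in ConfidenceDrivenMaskLayer, whose kernel size, sigma and number of iterations may differ). Known pixels have confidence 1, and confidence is propagated into the hole by repeated Gaussian filtering, so missing pixels near the boundary receive larger reconstruction-loss weights than pixels deep inside the hole:

import torch
import torch.nn.functional as F

def confidence_weight(mask, iters=5, ksize=15, sigma=4.0):
    # mask: (B, 1, H, W) float tensor with 1 = missing, 0 = known; returns per-pixel loss weights
    ax = torch.arange(ksize, dtype=torch.float32) - (ksize - 1) / 2
    g1d = torch.exp(-ax.pow(2) / (2 * sigma ** 2))
    kernel = g1d[:, None] * g1d[None, :]
    kernel = (kernel / kernel.sum()).view(1, 1, ksize, ksize)   # normalized Gaussian kernel

    known = 1 - mask                 # confidence of known pixels is 1
    weight = torch.zeros_like(mask)
    conf = known
    for _ in range(iters):
        blurred = F.conv2d(conf, kernel, padding=ksize // 2)    # propagate confidence inward
        weight = torch.maximum(weight, blurred * mask)          # only weight the missing pixels
        conf = known + weight
    return weight

mask = torch.zeros(1, 1, 128, 128)
mask[:, :, 32:96, 32:96] = 1
w = confidence_weight(mask)   # largest just inside the hole boundary, smaller toward the center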

2.4 Final Loss Function

The final objective combines the confidence-driven reconstruction loss, the ID-MRF loss and the adversarial loss:

$$\mathcal{L}=\mathcal{L}_c+\lambda_{mrf}\,\mathcal{L}_{mrf}+\lambda_{adv}\,\mathcal{L}_{adv}$$

2.5 Training Procedure

The network is first trained with the reconstruction loss only, i.e. with $\lambda_{mrf}$ and $\lambda_{adv}$ set to 0, in order to stabilize the later adversarial training.

After the generator G converges, we set $\lambda_{mrf}=0.05$ and $\lambda_{adv}=0.001$ and fine-tune until convergence. Training uses the Adam optimizer [13] with a small fixed learning rate (1e-5 by default in the PyTorch code below), $\beta_1=0.5$, $\beta_2=0.9$, and a batch size of 16. In the implementation below, the two phases correspond to running train.py first with --pretrain_network 1 and then with --pretrain_network 0 while loading the pre-trained weights via --load_model_dir.

3. PyTorch Source Code Walkthrough of GMCNN

3.1 Training configuration: train_options.py

import argparse
import os
import time

class TrainOptions:
    def __init__(self):
        self.parser = argparse.ArgumentParser()
        self.initialized = False

    def initialize(self):
        # experiment specifics
        self.parser.add_argument('--dataset', type=str, default='Celebhq',help='dataset of the experiment.')
        #self.parser.add_argument('--data_file', type=str, default='', help='the file storing training image paths')
        self.parser.add_argument('--data_file', type=str, default='/root/workspace/pyproject/inpainting_gmcnn-master/pytorch/util/celeba_256_train.txt', help='the file storing training image paths')# this file lists the absolute path of every training image
        
        self.parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0  0,1,2')
        self.parser.add_argument('--checkpoint_dir', type=str, default='./checkpoints', help='models are saved here')
       # self.parser.add_argument('--load_model_dir', type=str, default='', help='pretrained models are given here')
        self.parser.add_argument('--load_model_dir', type=str, default='/root/workspace/pyproject/inpainting_gmcnn-master/pytorch/checkpoints/20210509-164655_GMCNN_Celebhq_b8_s256x256_gc32_dc64_randmask-rect_pretrain', help='pretrained models are given here')
        self.parser.add_argument('--phase', type=str, default='train')

        # input/output sizes
       # self.parser.add_argument('--batch_size', type=int, default=16, help='input batch size')
        self.parser.add_argument('--batch_size', type=int, default=8, help='input batch size')

        # for setting inputs
        self.parser.add_argument('--random_crop', type=int, default=1,
                                 help='using random crop to process input image when '
                                      'the required size is smaller than the given size')
        self.parser.add_argument('--random_mask', type=int, default=1)
        self.parser.add_argument('--mask_type', type=str, default='rect')
        self.parser.add_argument('--pretrain_network', type=int, default=0)# 1 = pre-training phase (generator trained with the reconstruction losses only); 0 = fine-tuning with the ID-MRF and adversarial losses added
        self.parser.add_argument('--lambda_adv', type=float, default=1e-3)
        self.parser.add_argument('--lambda_rec', type=float, default=1.4)
        self.parser.add_argument('--lambda_ae', type=float, default=1.2)
        self.parser.add_argument('--lambda_mrf', type=float, default=0.05)
        self.parser.add_argument('--lambda_gp', type=float, default=10)
        self.parser.add_argument('--random_seed', type=bool, default=False)
        self.parser.add_argument('--padding', type=str, default='SAME')
        self.parser.add_argument('--D_max_iters', type=int, default=5)# number of discriminator updates performed before each generator update (see optimize_parameters in net.py)
        self.parser.add_argument('--lr', type=float, default=1e-5, help='learning rate for training')

        self.parser.add_argument('--train_spe', type=int, default=1000)
        self.parser.add_argument('--epochs', type=int, default=40)
        self.parser.add_argument('--viz_steps', type=int, default=5)
        self.parser.add_argument('--spectral_norm', type=int, default=1)

        self.parser.add_argument('--img_shapes', type=str, default='256,256,3',
                                 help='given shape parameters: h,w,c or h,w')
        self.parser.add_argument('--mask_shapes', type=str, default='128,128',
                                 help='given mask parameters: h,w')
        self.parser.add_argument('--max_delta_shapes', type=str, default='32,32')
        self.parser.add_argument('--margins', type=str, default='0,0')


        # for generator
        self.parser.add_argument('--g_cnum', type=int, default=32,
                                 help='# of generator filters in first conv layer')
        self.parser.add_argument('--d_cnum', type=int, default=64,
                                 help='# of discriminator filters in first conv layer')

        # for id-mrf computation
        self.parser.add_argument('--vgg19_path', type=str, default='vgg19_weights/imagenet-vgg-verydeep-19.mat')
        # for instance-wise features
        self.initialized = True

    def parse(self):
        if not self.initialized:
            self.initialize()
        self.opt = self.parser.parse_args()

        self.opt.dataset_path = self.opt.data_file

        str_ids = self.opt.gpu_ids.split(',')
        self.opt.gpu_ids = []
        for str_id in str_ids:
            id = int(str_id)
            if id >= 0:
                self.opt.gpu_ids.append(str(id))

        assert self.opt.random_crop in [0, 1]
        self.opt.random_crop = True if self.opt.random_crop == 1 else False

        assert self.opt.random_mask in [0, 1]
        self.opt.random_mask = True if self.opt.random_mask == 1 else False

        assert self.opt.pretrain_network in [0, 1]
        self.opt.pretrain_network = True if self.opt.pretrain_network == 1 else False

        assert self.opt.spectral_norm in [0, 1]
        self.opt.spectral_norm = True if self.opt.spectral_norm == 1 else False

        assert self.opt.padding in ['SAME', 'MIRROR']

        assert self.opt.mask_type in ['rect', 'stroke']

        str_img_shapes = self.opt.img_shapes.split(',')
        self.opt.img_shapes = [int(x) for x in str_img_shapes]

        str_mask_shapes = self.opt.mask_shapes.split(',')
        self.opt.mask_shapes = [int(x) for x in str_mask_shapes]

        str_max_delta_shapes = self.opt.max_delta_shapes.split(',')
        self.opt.max_delta_shapes = [int(x) for x in str_max_delta_shapes]

        str_margins = self.opt.margins.split(',')
        self.opt.margins = [int(x) for x in str_margins]

        # model name and date
        self.opt.date_str = time.strftime('%Y%m%d-%H%M%S')
        self.opt.model_name = 'GMCNN'
        self.opt.model_folder = self.opt.date_str + '_' + self.opt.model_name
        self.opt.model_folder += '_' + self.opt.dataset
        self.opt.model_folder += '_b' + str(self.opt.batch_size)
        self.opt.model_folder += '_s' + str(self.opt.img_shapes[0]) + 'x' + str(self.opt.img_shapes[1])
        self.opt.model_folder += '_gc' + str(self.opt.g_cnum)
        self.opt.model_folder += '_dc' + str(self.opt.d_cnum)

        self.opt.model_folder += '_randmask-' + self.opt.mask_type if self.opt.random_mask else ''
        self.opt.model_folder += '_pretrain' if self.opt.pretrain_network else ''

        if os.path.isdir(self.opt.checkpoint_dir) is False:
            os.mkdir(self.opt.checkpoint_dir)

        self.opt.model_folder = os.path.join(self.opt.checkpoint_dir, self.opt.model_folder)
        if os.path.isdir(self.opt.model_folder) is False:
            os.mkdir(self.opt.model_folder)

        # set gpu ids
        if len(self.opt.gpu_ids) > 0:
            os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(self.opt.gpu_ids)

        args = vars(self.opt)

        print('------------ Options -------------')
        for k, v in sorted(args.items()):
            print('%s: %s' % (str(k), str(v)))
        print('-------------- End ----------------')

        return self.opt

           

3.2 Training script: train.py

import os
from torch.utils.data import DataLoader
from torchvision import transforms
import torchvision.utils as vutils
from tensorboardX import SummaryWriter
from data.data import InpaintingDataset, ToTensor
from model.net import InpaintingModel_GMCNN
from options.train_options import TrainOptions
from util.utils import getLatest
import tqdm

config = TrainOptions().parse()# parse the training configuration / hyperparameters
print("training config:", config)


print('loading data........')
# load the dataset from the file listing absolute image paths
dataset = InpaintingDataset(config.dataset_path, '', transform=transforms.Compose([
    ToTensor()# images are converted to tensors with values in [0, 1]
]))


# build the batched data loader over the dataset
dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True, num_workers=4, drop_last=True)
print('data load end.........')

print('configuring model..')
ourModel = InpaintingModel_GMCNN(in_channels=4, opt=config)# instantiate a GMCNN inpainting model from the training options


ourModel.print_networks()# print the network architecture


if config.load_model_dir != '':
    print('Loading pretrained model from {}'.format(config.load_model_dir))
    ourModel.load_networks(getLatest(os.path.join(config.load_model_dir, '*.pth')))
    print('Loading done.')
# ourModel = torch.nn.DataParallel(ourModel).cuda()
print('model setting up..')
print('training initializing..')


writer = SummaryWriter(log_dir=config.model_folder)# tensorboardX summary writer for logging

cnt = 0# counts how many batches have been processed
#config.epochs=30
for epoch in range(config.epochs):

    for i, data in enumerate(dataloader):
        gt = data['gt'].cuda()
        # normalize to values between -1 and 1,
        gt = gt / 127.5 - 1

        data_in = {'gt': gt}
        ourModel.setInput(data_in)# feed this batch into the model
        ourModel.optimize_parameters()# run one optimization step on this batch

        if (i+1) % config.viz_steps == 0:                   #viz_steps=5
            ret_loss = ourModel.get_current_losses()# fetch the losses computed on the current batch
            if config.pretrain_network is False:
                print(
                    '[%d, %5d] G_loss: %.4f (rec: %.4f, ae: %.4f, adv: %.4f, mrf: %.4f), D_loss: %.4f'
                    % (epoch + 1, i + 1, ret_loss['G_loss'], ret_loss['G_loss_rec'], ret_loss['G_loss_ae'],
                       ret_loss['G_loss_adv'], ret_loss['G_loss_mrf'], ret_loss['D_loss']))

                writer.add_scalar('adv_loss', ret_loss['G_loss_adv'], cnt)
                writer.add_scalar('D_loss', ret_loss['D_loss'], cnt)
                writer.add_scalar('G_mrf_loss', ret_loss['G_loss_mrf'], cnt)
            else:
                print('[%d, %5d] G_loss: %.4f (rec: %.4f, ae: %.4f)'
                      % (epoch + 1, i + 1, ret_loss['G_loss'], ret_loss['G_loss_rec'], ret_loss['G_loss_ae']))

            #wm,将各種損失的值添加到日志類writer中,cnt是訓練了第多少個batch_size
            writer.add_scalar('G_loss', ret_loss['G_loss'], cnt)
            writer.add_scalar('reconstruction_loss', ret_loss['G_loss_rec'], cnt)
            writer.add_scalar('autoencoder_loss', ret_loss['G_loss_ae'], cnt)

            # images contains three kinds of visualizations
            images = ourModel.get_current_visuals_tensor()

            im_completed = vutils.make_grid(images['completed'], normalize=True, scale_each=True)# completed (inpainted) images
            im_input = vutils.make_grid(images['input'], normalize=True, scale_each=True)# masked input images
            im_gt = vutils.make_grid(images['gt'], normalize=True, scale_each=True)# ground-truth images

            # wm,将訓練過程中産生的圖添加到日志類writer中,cnt是訓練了第多少個batch_size
            writer.add_image('gt', im_gt, cnt)
            writer.add_image('input', im_input, cnt)
            writer.add_image('completed', im_completed, cnt)

            # save the model every train_spe (1000) batches
            if (i+1) % config.train_spe == 0:
                print('saving model ..')
                ourModel.save_networks(epoch+1)
        cnt += 1
    ourModel.save_networks(epoch+1)# save the model at the end of each epoch

writer.export_scalars_to_json(os.path.join(config.model_folder, 'GMCNN_scalars.json'))
writer.close()

           

3.3 Building the GMCNN network: net.py

import torch
import torch.nn as nn
import torch.nn.functional as F
from model.basemodel import BaseModel
from model.basenet import BaseNet
from model.loss import WGANLoss, IDMRFLoss
from model.layer import init_weights, PureUpsampling, ConfidenceDrivenMaskLayer, SpectralNorm
import numpy as np

# generative multi-column convolutional neural net
# 1. The GMCNN generator (inpainting network): three branches with different receptive fields for feature extraction
class GMCNN(BaseNet):
    def __init__(self, in_channels, out_channels, cnum=32, act=F.elu, norm=F.instance_norm, using_norm=False):
        super(GMCNN, self).__init__()
        self.act = act
        self.using_norm = using_norm
        if using_norm is True:
            self.norm = norm
        else:
            self.norm = None
        ch = cnum

        # network structure
        self.EB1 = []# encoder branch 1
        self.EB2 = []# encoder branch 2
        self.EB3 = []# encoder branch 3
        self.decoding_layers = []# the shared decoder layers

        self.EB1_pad_rec = []
        self.EB2_pad_rec = []
        self.EB3_pad_rec = []

        self.EB1.append(nn.Conv2d(in_channels, ch, kernel_size=7, stride=1))

        self.EB1.append(nn.Conv2d(ch, ch * 2, kernel_size=7, stride=2))
        self.EB1.append(nn.Conv2d(ch * 2, ch * 2, kernel_size=7, stride=1))

        self.EB1.append(nn.Conv2d(ch * 2, ch * 4, kernel_size=7, stride=2))
        self.EB1.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=7, stride=1))
        self.EB1.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=7, stride=1))

        self.EB1.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=7, stride=1, dilation=2))
        self.EB1.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=7, stride=1, dilation=4))
        self.EB1.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=7, stride=1, dilation=8))
        self.EB1.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=7, stride=1, dilation=16))

        self.EB1.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=7, stride=1))
        self.EB1.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=7, stride=1))

        self.EB1.append(PureUpsampling(scale=4))

        self.EB1_pad_rec = [3, 3, 3, 3, 3, 3, 6, 12, 24, 48, 3, 3, 0]
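        # Each entry above is the reflection-padding width applied before the corresponding layer
        # of EB1 (pad = dilation * (kernel_size - 1) / 2 for the 7x7 convs), so every conv keeps
        # the spatial size; only the two stride-2 convs downsample, and the final
        # PureUpsampling(scale=4) (pad entry 0) brings the feature map back to the input size.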

        self.EB2.append(nn.Conv2d(in_channels, ch, kernel_size=5, stride=1))

        self.EB2.append(nn.Conv2d(ch, ch * 2, kernel_size=5, stride=2))
        self.EB2.append(nn.Conv2d(ch * 2, ch * 2, kernel_size=5, stride=1))

        self.EB2.append(nn.Conv2d(ch * 2, ch * 4, kernel_size=5, stride=2))
        self.EB2.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=5, stride=1))
        self.EB2.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=5, stride=1))

        self.EB2.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=5, stride=1, dilation=2))
        self.EB2.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=5, stride=1, dilation=4))
        self.EB2.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=5, stride=1, dilation=8))
        self.EB2.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=5, stride=1, dilation=16))

        self.EB2.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=5, stride=1))
        self.EB2.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=5, stride=1))

        self.EB2.append(PureUpsampling(scale=2, mode='nearest'))
        self.EB2.append(nn.Conv2d(ch * 4, ch * 2, kernel_size=5, stride=1))
        self.EB2.append(nn.Conv2d(ch * 2, ch * 2, kernel_size=5, stride=1))
        self.EB2.append(PureUpsampling(scale=2))
        self.EB2_pad_rec = [2, 2, 2, 2, 2, 2, 4, 8, 16, 32, 2, 2, 0, 2, 2, 0]

        self.EB3.append(nn.Conv2d(in_channels, ch, kernel_size=3, stride=1))

        self.EB3.append(nn.Conv2d(ch, ch * 2, kernel_size=3, stride=2))
        self.EB3.append(nn.Conv2d(ch * 2, ch * 2, kernel_size=3, stride=1))

        self.EB3.append(nn.Conv2d(ch * 2, ch * 4, kernel_size=3, stride=2))
        self.EB3.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=3, stride=1))
        self.EB3.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=3, stride=1))

        self.EB3.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=3, stride=1, dilation=2))
        self.EB3.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=3, stride=1, dilation=4))
        self.EB3.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=3, stride=1, dilation=8))
        self.EB3.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=3, stride=1, dilation=16))

        self.EB3.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=3, stride=1))
        self.EB3.append(nn.Conv2d(ch * 4, ch * 4, kernel_size=3, stride=1))

        self.EB3.append(PureUpsampling(scale=2, mode='nearest'))
        self.EB3.append(nn.Conv2d(ch * 4, ch * 2, kernel_size=3, stride=1))
        self.EB3.append(nn.Conv2d(ch * 2, ch * 2, kernel_size=3, stride=1))
        self.EB3.append(PureUpsampling(scale=2, mode='nearest'))
        self.EB3.append(nn.Conv2d(ch * 2, ch, kernel_size=3, stride=1))
        self.EB3.append(nn.Conv2d(ch, ch, kernel_size=3, stride=1))

        self.EB3_pad_rec = [1, 1, 1, 1, 1, 1, 2, 4, 8, 16, 1, 1, 0, 1, 1, 0, 1, 1]

        self.decoding_layers.append(nn.Conv2d(ch * 7, ch // 2, kernel_size=3, stride=1))
        self.decoding_layers.append(nn.Conv2d(ch // 2, out_channels, kernel_size=3, stride=1))

        self.decoding_pad_rec = [1, 1]

        self.EB1 = nn.ModuleList(self.EB1)# wrap the layer lists into ModuleLists
        self.EB2 = nn.ModuleList(self.EB2)
        self.EB3 = nn.ModuleList(self.EB3)
        self.decoding_layers = nn.ModuleList(self.decoding_layers)

        # padding operations
        padlen = 49
        self.pads = [0] * padlen
        for i in range(padlen):
            self.pads[i] = nn.ReflectionPad2d(i)
        self.pads = nn.ModuleList(self.pads)

    def forward(self, x):#将一張圖檔複制三份,分别送入三個分支
        x1, x2, x3 = x, x, x
        for i, layer in enumerate(self.EB1):
            pad_idx = self.EB1_pad_rec[i]
            x1 = layer(self.pads[pad_idx](x1))# reflection-pad the feature map, then convolve
            if self.using_norm:
                x1 = self.norm(x1)
            if pad_idx != 0:
                x1 = self.act(x1)# output feature map of branch 1

        for i, layer in enumerate(self.EB2):
            pad_idx = self.EB2_pad_rec[i]
            x2 = layer(self.pads[pad_idx](x2))
            if self.using_norm:
                x2 = self.norm(x2)
            if pad_idx != 0:
                x2 = self.act(x2)# output feature map of branch 2

        for i, layer in enumerate(self.EB3):
            pad_idx = self.EB3_pad_rec[i]
            x3 = layer(self.pads[pad_idx](x3))
            if self.using_norm:
                x3 = self.norm(x3)
            if pad_idx != 0:
                x3 = self.act(x3)# output feature map of branch 3

        x_d = torch.cat((x1, x2, x3), 1)# concatenate the outputs of the three branches

        # shared decoder
        x_d = self.act(self.decoding_layers[0](self.pads[self.decoding_pad_rec[0]](x_d)))
        x_d = self.decoding_layers[1](self.pads[self.decoding_pad_rec[1]](x_d))
        x_out = torch.clamp(x_d, -1, 1)# clamp the output to [-1, 1]

        return x_out# returns a batch of images as a tensor with values in [-1, 1]


# return one dimensional output indicating the probability of realness or fakeness
# 2. Basic discriminator module
class Discriminator(BaseNet):
    def __init__(self, in_channels, cnum=32, fc_channels=8*8*32*4, act=F.elu, norm=None, spectral_norm=True):
        super(Discriminator, self).__init__()
        self.act = act
        self.norm = norm
        self.embedding = None
        self.logit = None

        ch = cnum
        self.layers = []
        if spectral_norm:
            self.layers.append(SpectralNorm(nn.Conv2d(in_channels, ch, kernel_size=5, padding=2, stride=2)))
            self.layers.append(SpectralNorm(nn.Conv2d(ch, ch * 2, kernel_size=5, padding=2, stride=2)))
            self.layers.append(SpectralNorm(nn.Conv2d(ch * 2, ch * 4, kernel_size=5, padding=2, stride=2)))
            self.layers.append(SpectralNorm(nn.Conv2d(ch * 4, ch * 4, kernel_size=5, padding=2, stride=2)))
            self.layers.append(SpectralNorm(nn.Linear(fc_channels, 1)))# the final linear layer outputs a single score: high for real images, low for inpainted ones
        else:
            self.layers.append(nn.Conv2d(in_channels, ch, kernel_size=5, padding=2, stride=2))
            self.layers.append(nn.Conv2d(ch, ch * 2, kernel_size=5, padding=2, stride=2))
            self.layers.append(nn.Conv2d(ch*2, ch*4, kernel_size=5, padding=2, stride=2))
            self.layers.append(nn.Conv2d(ch*4, ch*4, kernel_size=5, padding=2, stride=2))
            self.layers.append(nn.Linear(fc_channels, 1))# the final linear layer outputs a single score: high for real images, low for inpainted ones

        self.layers = nn.ModuleList(self.layers)# wrap the layer list into a ModuleList

    def forward(self, x):
        for layer in self.layers[:-1]:
            x = layer(x)
            if self.norm is not None:
                x = self.norm(x)
            x = self.act(x)
        self.embedding = x.view(x.size(0), -1)# flatten the conv feature maps into a vector

        self.logit = self.layers[-1](self.embedding)
        return self.logit# a single score: high for real images, low for inpainted ones



# 3. Combined discriminator: a global and a local discriminator built from the basic module above; they differ only in input size, i.e. in the length of the flattened vector fed to the last layer
class GlobalLocalDiscriminator(BaseNet):
    def __init__(self, in_channels, cnum=32, g_fc_channels=16*16*32*4, l_fc_channels=8*8*32*4, act=F.elu, norm=None,
                 spectral_norm=True):
        super(GlobalLocalDiscriminator, self).__init__()
        self.act = act
        self.norm = norm

        self.global_discriminator = Discriminator(in_channels=in_channels, fc_channels=g_fc_channels, cnum=cnum,
                                                  act=act, norm=norm, spectral_norm=spectral_norm)
        self.local_discriminator = Discriminator(in_channels=in_channels, fc_channels=l_fc_channels, cnum=cnum,
                                                 act=act, norm=norm, spectral_norm=spectral_norm)

    def forward(self, x_g, x_l):
        x_global = self.global_discriminator(x_g)
        x_local = self.local_discriminator(x_l)
        return x_global, x_local# returns the global score and the local score


from util.utils import generate_mask


# 4. The full GMCNN inpainting model assembled from the modules above
class InpaintingModel_GMCNN(BaseModel):
    def __init__(self, in_channels, act=F.elu, norm=None, opt=None):
        super(InpaintingModel_GMCNN, self).__init__()
        self.opt = opt
        self.init(opt)
        # confidence-driven loss-weight mask: close to 1 in the known region, and inside the hole the weight decays (in a Gaussian manner) with distance from the boundary
        self.confidence_mask_layer = ConfidenceDrivenMaskLayer()
        # instantiate the generator
        self.netGM = GMCNN(in_channels, out_channels=3, cnum=opt.g_cnum, act=act, norm=norm).cuda() # three parallel branches + a shared decoder, moved to the GPU

        init_weights(self.netGM)# initialize the generator weights

        self.model_names = ['GM']
        if self.opt.phase == 'test':
            return

        self.netD = None
        #wm,将生成器的網絡參數,放入Adam優化器中
        self.optimizer_G = torch.optim.Adam(self.netGM.parameters(), lr=opt.lr, betas=(0.5, 0.9))
        self.optimizer_D = None

        self.wganloss = None
        self.recloss = nn.L1Loss()
        self.aeloss = nn.L1Loss()
        self.mrfloss = None

        self.lambda_adv = opt.lambda_adv# weight of the adversarial loss
        self.lambda_rec = opt.lambda_rec# weight of the reconstruction loss
        self.lambda_ae = opt.lambda_ae
        self.lambda_gp = opt.lambda_gp# gradient-penalty weight (WGAN-GP)
        self.lambda_mrf = opt.lambda_mrf# weight of the ID-MRF loss

        self.G_loss = None
        self.G_loss_reconstruction = None
        self.G_loss_mrf = None
        self.G_loss_adv, self.G_loss_adv_local = None, None
        self.G_loss_ae = None
        self.D_loss, self.D_loss_local = None, None
        self.GAN_loss = None

        self.gt, self.gt_local = None, None
        self.mask, self.mask_01 = None, None
        self.rect = None

        self.im_in, self.gin = None, None

        self.completed, self.completed_local = None, None
        self.completed_logit, self.completed_local_logit = None, None
        self.gt_logit, self.gt_local_logit = None, None

        self.pred = None

        # if this is not the pre-training phase (pre-training = reconstruction losses only), a discriminator is also needed:
        if self.opt.pretrain_network is False:
            if self.opt.mask_type == 'rect':
                self.netD = GlobalLocalDiscriminator(3, cnum=opt.d_cnum, act=act,
                                                     g_fc_channels=opt.img_shapes[0]//16*opt.img_shapes[1]//16*opt.d_cnum*4,
                                                     l_fc_channels=opt.mask_shapes[0]//16*opt.mask_shapes[1]//16*opt.d_cnum*4,
                                                     spectral_norm=self.opt.spectral_norm).cuda()
            else:
                self.netD = GlobalLocalDiscriminator(3, cnum=opt.d_cnum, act=act,
                                                     spectral_norm=self.opt.spectral_norm,
                                                     g_fc_channels=opt.img_shapes[0]//16*opt.img_shapes[1]//16*opt.d_cnum*4,
                                                     l_fc_channels=opt.img_shapes[0]//16*opt.img_shapes[1]//16*opt.d_cnum*4).cuda()
            init_weights(self.netD)# initialize the discriminator
            self.optimizer_D = torch.optim.Adam(filter(lambda x: x.requires_grad, self.netD.parameters()), lr=opt.lr,
                                                betas=(0.5, 0.9))# Adam optimizer over the discriminator parameters
            self.wganloss = WGANLoss()# instantiate the WGAN loss
            self.mrfloss = IDMRFLoss()# instantiate the ID-MRF loss

    # initialize variables and build the generator input for the current batch
    def initVariables(self):
        self.gt = self.input['gt']# a batch of ground-truth images
        mask, rect = generate_mask(self.opt.mask_type, self.opt.img_shapes, self.opt.mask_shapes)# generate the mask and the position of the rectangular hole
        self.mask_01 = torch.from_numpy(mask).cuda().repeat([self.opt.batch_size, 1, 1, 1])# 0 = known region, 1 = missing region; converted from numpy to tensor
        self.mask = self.confidence_mask_layer(self.mask_01)# confidence weights used by the reconstruction loss

        if self.opt.mask_type == 'rect':
            self.rect = [rect[0, 0], rect[0, 1], rect[0, 2], rect[0, 3]]
            # crop the local (hole-region) ground-truth patch
            self.gt_local = self.gt[:, :, self.rect[0]:self.rect[0] + self.rect[1],self.rect[2]:self.rect[2] + self.rect[3]]
        else:
            self.gt_local = self.gt

        self.im_in = self.gt * (1 - self.mask_01)# keep the known region, set the hole to 0
        self.gin = torch.cat((self.im_in, self.mask_01), 1)# the 4-channel input fed to the inpainting network

    # forward pass computing the generator losses
    def forward_G(self):
        self.G_loss_reconstruction = self.recloss(self.completed * self.mask, self.gt.detach() * self.mask)# confidence-weighted reconstruction loss between the completed image and the ground truth
        self.G_loss_reconstruction = self.G_loss_reconstruction / torch.mean(self.mask_01)

        self.G_loss_ae = self.aeloss(self.pred * (1 - self.mask_01), self.gt.detach() * (1 - self.mask_01))# L1 loss on the known region between the raw prediction and the ground truth
        self.G_loss_ae = self.G_loss_ae / torch.mean(1 - self.mask_01)

        self.G_loss = self.lambda_rec * self.G_loss_reconstruction + self.lambda_ae * self.G_loss_ae# weighted sum of the two reconstruction terms

        if self.opt.pretrain_network is False:# when fine-tuning, also compute the adversarial and ID-MRF losses
            # discriminator
            self.completed_logit, self.completed_local_logit = self.netD(self.completed, self.completed_local)# global and local discriminator scores for the completed image

            self.G_loss_mrf = self.mrfloss((self.completed_local+1)/2.0, (self.gt_local.detach()+1)/2.0)# ID-MRF loss (inputs rescaled from [-1, 1] to [0, 1])
            self.G_loss = self.G_loss + self.lambda_mrf * self.G_loss_mrf# add the weighted ID-MRF loss

            self.G_loss_adv = -self.completed_logit.mean()# global adversarial loss for the generator
            self.G_loss_adv_local = -self.completed_local_logit.mean()# local adversarial loss for the generator
            self.G_loss = self.G_loss + self.lambda_adv * (self.G_loss_adv + self.G_loss_adv_local)# total generator loss


    # forward pass computing the discriminator losses
    def forward_D(self):
        self.completed_logit, self.completed_local_logit = self.netD(self.completed.detach(), self.completed_local.detach())# global and local scores for the (detached) completed image
        self.gt_logit, self.gt_local_logit = self.netD(self.gt, self.gt_local)# global and local scores for the ground-truth image
        # hinge loss
        self.D_loss_local = nn.ReLU()(1.0 - self.gt_local_logit).mean() + nn.ReLU()(1.0 + self.completed_local_logit).mean()# hinge loss of the local discriminator
        self.D_loss = nn.ReLU()(1.0 - self.gt_logit).mean() + nn.ReLU()(1.0 + self.completed_logit).mean()# hinge loss of the global discriminator

        self.D_loss = self.D_loss + self.D_loss_local

    # backward pass for the generator
    def backward_G(self):
        self.G_loss.backward()
    # backward pass for the discriminator
    def backward_D(self):
        self.D_loss.backward(retain_graph=True)


    # one full optimization step on the current batch
    def optimize_parameters(self):
        self.initVariables()

        self.pred = self.netGM(self.gin)# run the generator on the masked input
        self.completed = self.pred * self.mask_01 + self.gt * (1 - self.mask_01)# paste the ground-truth known region back over the prediction to obtain the final completion

        if self.opt.mask_type == 'rect':
            self.completed_local = self.completed[:, :, self.rect[0]:self.rect[0] + self.rect[1],
                                   self.rect[2]:self.rect[2] + self.rect[3]]
        else:
            self.completed_local = self.completed

        if self.opt.pretrain_network is False:# when fine-tuning, first update the discriminator D_max_iters times
            for i in range(self.opt.D_max_iters):
                self.optimizer_D.zero_grad()# zero the discriminator gradients
                self.optimizer_G.zero_grad()# zero the generator gradients
                self.forward_D()# discriminator forward pass
                self.backward_D()# discriminator backward pass
                self.optimizer_D.step()# update the discriminator parameters

        self.optimizer_G.zero_grad()# zero the generator gradients
        self.forward_G()# generator forward pass
        self.backward_G()# generator backward pass
        self.optimizer_G.step()# update the generator parameters

    # return the current losses as a dict
    def get_current_losses(self):
        l = {'G_loss': self.G_loss.item(), 'G_loss_rec': self.G_loss_reconstruction.item(),
             'G_loss_ae': self.G_loss_ae.item()}# in the pre-training phase only the reconstruction losses exist

        if self.opt.pretrain_network is False:
            l.update({'G_loss_adv': self.G_loss_adv.item(),
                      'G_loss_adv_local': self.G_loss_adv_local.item(),
                      'D_loss': self.D_loss.item(),
                      'G_loss_mrf': self.G_loss_mrf.item()})
        return l

    # current input, ground-truth and completed images as numpy arrays
    def get_current_visuals(self):
        return {'input': self.im_in.cpu().detach().numpy(), 'gt': self.gt.cpu().detach().numpy(),
                'completed': self.completed.cpu().detach().numpy()}

    # current input, ground-truth and completed images as tensors
    def get_current_visuals_tensor(self):
        return {'input': self.im_in.cpu().detach(), 'gt': self.gt.cpu().detach(),
                'completed': self.completed.cpu().detach()}


    # inpaint a given image and mask at test time
    def evaluate(self, im_in, mask):
        im_in = torch.from_numpy(im_in).type(torch.FloatTensor).cuda() / 127.5 - 1
        mask = torch.from_numpy(mask).type(torch.FloatTensor).cuda()
        im_in = im_in * (1-mask)
        xin = torch.cat((im_in, mask), 1)
        ret = self.netGM(xin) * mask + im_in * (1-mask)
        ret = (ret.cpu().detach().numpy() + 1) * 127.5
        return ret.astype(np.uint8)
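As a quick sanity check on the generator defined above, the following snippet (my addition; it assumes the repository layout so that model.net is importable) verifies that a 4-channel 256x256 input yields a 3-channel 256x256 output clamped to [-1, 1]:

import torch
from model.net import GMCNN   # assumes the repository's package layout

net = GMCNN(in_channels=4, out_channels=3, cnum=32)
x = torch.randn(1, 4, 256, 256)            # masked RGB image + binary mask
with torch.no_grad():
    y = net(x)
print(y.shape, y.min().item(), y.max().item())   # torch.Size([1, 3, 256, 256]), values within [-1, 1]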

           

3.4 Common loss modules: loss.py (including the ID-MRF loss)

import torch
import torch.nn as nn
import torch.autograd as autograd
import torch.nn.functional as F
from model.layer import VGG19FeatLayer
from functools import reduce

class WGANLoss(nn.Module):
    def __init__(self):
        super(WGANLoss, self).__init__()

    def __call__(self, input, target):
        d_loss = (input - target).mean()
        g_loss = -input.mean()
        return {'g_loss': g_loss, 'd_loss': d_loss}


def gradient_penalty(xin, yout, mask=None):
    gradients = autograd.grad(yout, xin, create_graph=True,
                              grad_outputs=torch.ones(yout.size()).cuda(), retain_graph=True, only_inputs=True)[0]
    if mask is not None:
        gradients = gradients * mask
    gradients = gradients.view(gradients.size(0), -1)
    gp = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gp


def random_interpolate(gt, pred):
    batch_size = gt.size(0)
    alpha = torch.rand(batch_size, 1, 1, 1).cuda()
    # alpha = alpha.expand(gt.size()).cuda()
    interpolated = gt * alpha + pred * (1 - alpha)
    return interpolated


class IDMRFLoss(nn.Module):
    def __init__(self, featlayer=VGG19FeatLayer):
        super(IDMRFLoss, self).__init__()
        self.featlayer = featlayer()
        self.feat_style_layers = {'relu3_2': 1.0, 'relu4_2': 1.0}
        self.feat_content_layers = {'relu4_2': 1.0}
        self.bias = 1.0
        self.nn_stretch_sigma = 0.5
        self.lambda_style = 1.0
        self.lambda_content = 1.0

    def sum_normalize(self, featmaps):
        reduce_sum = torch.sum(featmaps, dim=1, keepdim=True)
        return featmaps / reduce_sum

    def patch_extraction(self, featmaps):
        patch_size = 1
        patch_stride = 1
        patches_as_depth_vectors = featmaps.unfold(2, patch_size, patch_stride).unfold(3, patch_size, patch_stride)
        self.patches_OIHW = patches_as_depth_vectors.permute(0, 2, 3, 1, 4, 5)
        dims = self.patches_OIHW.size()
        self.patches_OIHW = self.patches_OIHW.view(-1, dims[3], dims[4], dims[5])
        return self.patches_OIHW

    def compute_relative_distances(self, cdist):
        epsilon = 1e-5
        div = torch.min(cdist, dim=1, keepdim=True)[0]
        relative_dist = cdist / (div + epsilon)
        return relative_dist

    def exp_norm_relative_dist(self, relative_dist):
        scaled_dist = relative_dist
        dist_before_norm = torch.exp((self.bias - scaled_dist)/self.nn_stretch_sigma)
        self.cs_NCHW = self.sum_normalize(dist_before_norm)
        return self.cs_NCHW

    def mrf_loss(self, gen, tar):
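        # Hedged reading of how this maps back to Section 2.2: the 1x1-patch convolution below
        # computes the cosine similarities mu(v, s); compute_relative_distances() produces the
        # relative distances (the distance-form counterpart of Eq. (1));
        # exp_norm_relative_dist() applies the exponential and the normalization of Eq. (2);
        # and the max / mean / -log at the end corresponds to Eq. (3).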
        meanT = torch.mean(tar, 1, keepdim=True)
        gen_feats, tar_feats = gen - meanT, tar - meanT

        gen_feats_norm = torch.norm(gen_feats, p=2, dim=1, keepdim=True)
        tar_feats_norm = torch.norm(tar_feats, p=2, dim=1, keepdim=True)

        gen_normalized = gen_feats / gen_feats_norm
        tar_normalized = tar_feats / tar_feats_norm

        cosine_dist_l = []
        BatchSize = tar.size(0)

        for i in range(BatchSize):
            tar_feat_i = tar_normalized[i:i+1, :, :, :]
            gen_feat_i = gen_normalized[i:i+1, :, :, :]
            patches_OIHW = self.patch_extraction(tar_feat_i)

            cosine_dist_i = F.conv2d(gen_feat_i, patches_OIHW)
            cosine_dist_l.append(cosine_dist_i)
        cosine_dist = torch.cat(cosine_dist_l, dim=0)
        cosine_dist_zero_2_one = - (cosine_dist - 1) / 2
        relative_dist = self.compute_relative_distances(cosine_dist_zero_2_one)
        rela_dist = self.exp_norm_relative_dist(relative_dist)
        dims_div_mrf = rela_dist.size()
        k_max_nc = torch.max(rela_dist.view(dims_div_mrf[0], dims_div_mrf[1], -1), dim=2)[0]
        div_mrf = torch.mean(k_max_nc, dim=1)
        div_mrf_sum = -torch.log(div_mrf)
        div_mrf_sum = torch.sum(div_mrf_sum)
        return div_mrf_sum

    def forward(self, gen, tar):
        gen_vgg_feats = self.featlayer(gen)
        tar_vgg_feats = self.featlayer(tar)

        style_loss_list = [self.feat_style_layers[layer] * self.mrf_loss(gen_vgg_feats[layer], tar_vgg_feats[layer]) for layer in self.feat_style_layers]
        self.style_loss = reduce(lambda x, y: x+y, style_loss_list) * self.lambda_style
        # reduce accumulates (sums) the per-layer losses
        content_loss_list = [self.feat_content_layers[layer] * self.mrf_loss(gen_vgg_feats[layer], tar_vgg_feats[layer]) for layer in self.feat_content_layers]
        self.content_loss = reduce(lambda x, y: x+y, content_loss_list) * self.lambda_content

        return self.style_loss + self.content_loss


class StyleLoss(nn.Module):
    def __init__(self, featlayer=VGG19FeatLayer, style_layers=None):
        super(StyleLoss, self).__init__()
        self.featlayer = featlayer()
        if style_layers is not None:
            self.feat_style_layers = style_layers
        else:
            self.feat_style_layers = {'relu2_2': 1.0, 'relu3_2': 1.0, 'relu4_2': 1.0}

    def gram_matrix(self, x):
        b, c, h, w = x.size()
        feats = x.view(b * c, h * w)
        g = torch.mm(feats, feats.t())
        return g.div(b * c * h * w)

    def _l1loss(self, gen, tar):
        return torch.abs(gen-tar).mean()

    def forward(self, gen, tar):
        gen_vgg_feats = self.featlayer(gen)
        tar_vgg_feats = self.featlayer(tar)
        style_loss_list = [self.feat_style_layers[layer] * self._l1loss(self.gram_matrix(gen_vgg_feats[layer]), self.gram_matrix(tar_vgg_feats[layer])) for
                           layer in self.feat_style_layers]
        style_loss = reduce(lambda x, y: x + y, style_loss_list)
        return style_loss


class ContentLoss(nn.Module):
    def __init__(self, featlayer=VGG19FeatLayer, content_layers=None):
        super(ContentLoss, self).__init__()
        self.featlayer = featlayer()
        if content_layers is not None:
            self.feat_content_layers = content_layers
        else:
            self.feat_content_layers = {'relu4_2': 1.0}

    def _l1loss(self, gen, tar):
        return torch.abs(gen-tar).mean()

    def forward(self, gen, tar):
        gen_vgg_feats = self.featlayer(gen)
        tar_vgg_feats = self.featlayer(tar)
        content_loss_list = [self.feat_content_layers[layer] * self._l1loss(gen_vgg_feats[layer], tar_vgg_feats[layer]) for
                             layer in self.feat_content_layers]
        content_loss = reduce(lambda x, y: x + y, content_loss_list)
        return content_loss


class TVLoss(nn.Module):
    def __init__(self):
        super(TVLoss, self).__init__()

    def forward(self, x):
        h_x, w_x = x.size()[2:]
        h_tv = torch.abs(x[:, :, 1:, :] - x[:, :, :h_x-1, :])
        w_tv = torch.abs(x[:, :, :, 1:] - x[:, :, :, :w_x-1])
        loss = torch.sum(h_tv) + torch.sum(w_tv)
        return loss
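For reference, here is a hedged usage sketch of the WGAN-GP helpers defined above (random_interpolate and gradient_penalty); the critic and data below are toy placeholders, and a GPU is assumed because those helpers call .cuda() internally. Note that in the net.py excerpt of Section 3.3 the discriminator is actually trained with a hinge loss, so these helpers are not invoked there:

import torch
import torch.nn as nn

netD = nn.Sequential(nn.Flatten(), nn.Linear(3 * 64 * 64, 1)).cuda()   # toy critic
real = torch.rand(4, 3, 64, 64).cuda()
fake = torch.rand(4, 3, 64, 64).cuda()

x_hat = random_interpolate(real, fake)      # random point between a real and a fake sample
x_hat.requires_grad_(True)
gp = gradient_penalty(x_hat, netD(x_hat))   # ((||grad_x_hat D(x_hat)||_2 - 1)^2).mean()
d_loss = netD(fake).mean() - netD(real).mean() + 10.0 * gp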
           

4. References

4.1 The original paper

Image Inpainting via Generative Multi-column Convolutional Neural Networks

4.2 Source code

https://github.com/shepnerd/inpainting_gmcnn
