天天看點

【論文筆記】圖像修複Learning Joint Spatial-Temporal Transformations for Video Inpainting一、項目介紹二、STTN代碼如下:三、損失函數三、訓練流程

論文位址:https://arxiv.org/abs/2007.10247

源碼位址:GitHub - researchmm/STTN: [ECCV'2020] STTN: Learning Joint Spatial-Temporal Transformations for Video Inpainting

一、項目介紹

        當下SOTA的方法大多采用注意模型,通過搜尋參考幀中缺失的内容來完成一幀,并進一步逐幀完成整個視訊。然而,這些方法在空間和時間次元上的注意結果可能會不一緻,這往往會導緻視訊中的模糊和時間僞影。

        本文提出時空轉換網絡STTN(Spatial-Temporal Transformer Network)。具體來說,是通過自注意機制同時填補所有輸入幀中的缺失區域,并提出通過時空對抗性損失來優化STTN。為了展示該模型的優越性,我們使用标準的靜止掩模和更真實的運動物體掩模進行了定量和定性的評價。

二、STTN

         模型輸入是圖像幀序列和masks序列,圖像幀序列經過Encoder、Mask經過scale變化成原來的1/4,然後一起送入Spatial-Temporal Transformer子產品;Spatial-Temporal Transformer子產品由8個TransformerBlock組成;最後Decoder子產品負責将特征還原成圖像幀序列。STTN的整體結構圖如下:

【論文筆記】圖像修複Learning Joint Spatial-Temporal Transformations for Video Inpainting一、項目介紹二、STTN代碼如下:三、損失函數三、訓練流程

圖1

1.Encoder

        Frame-Level Encoder幀級編碼器,通過疊加二維卷積層來建構的,目的是為每一幀的低級别像素的深度特征,就是四個卷積層提取單幀圖像特征,要素不多,結構圖如下:

【論文筆記】圖像修複Learning Joint Spatial-Temporal Transformations for Video Inpainting一、項目介紹二、STTN代碼如下:三、損失函數三、訓練流程

圖2

代碼如下:

# 位置model/sttn.py
self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, channel, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        )
           

2.Spatial-Temporal Transformer Network

           這是STTN的核心部分,通過一個多頭 patch-based attention子產品沿着空間和時間次元進行搜尋。transformer的不同頭部計算不同尺度上對空間patch的注意力。這樣的設計允許我們處理由複雜的運動引起的外觀變化。例如,對大尺寸的patch(例如,幀大小H×W)旨在修複固定的背景;對小尺寸的patch(如H/10×W/10)有助于在視訊的任意位置捕捉移動的前景資訊。

(1)TranformerBlock

        TransformerBlock由Embedding、Matching和Attending組成,代碼中Matching和Attending被放在一起合成了MultiHeadedAttention。輸入是幀序列特征和masks。

        幀序列的特征平分成四部分,每個部分經過Embedding映射為四種尺度的Key、Query、Value,進而對應不同尺度的patch。masks經過變換也變成四個尺度。将四個尺度的Key、Query、Value和四個尺度masks分别送入MultiHeadedAttention,然後将結果Concat到一起,經過FeedForward層進一步分特征融合,得到融合了時間次元上不同尺度空間patch的特征。結構圖如下:

【論文筆記】圖像修複Learning Joint Spatial-Temporal Transformations for Video Inpainting一、項目介紹二、STTN代碼如下:三、損失函數三、訓練流程

 圖3

代碼如下:

# 位置model/sttn.py
class TransformerBlock(nn.Module):
    """
    Transformer = MultiHead_Attention + Feed_Forward with sublayer connection
    """

    def __init__(self, patchsize, hidden=128):
        super().__init__()
        self.attention = MultiHeadedAttention(patchsize, d_model=hidden)
        self.feed_forward = FeedForward(hidden)

    def forward(self, x):
        x, m, b, c = x['x'], x['m'], x['b'], x['c']
        x = x + self.attention(x, m, b, c)
        x = x + self.feed_forward(x)
        return {'x': x, 'm': m, 'b': b, 'c': c}
           

(2)KQV Formatting

        圖3中的KQV Formatting結構如下圖:

【論文筆記】圖像修複Learning Joint Spatial-Temporal Transformations for Video Inpainting一、項目介紹二、STTN代碼如下:三、損失函數三、訓練流程

圖4

        TranformerBlock輸入的幀序列特征,被平分成四個部分,每個部分經過變換,變成四種尺度patch的特征。

        代碼如下:

# 位置model/sttn.py
query = query.view(b, t, d_k, out_h, height, out_w, width)
query = query.permute(0, 1, 3, 5, 2, 4, 6).contiguous().view(
                b,  t*out_h*out_w, d_k*height*width)
key = key.view(b, t, d_k, out_h, height, out_w, width)
key = key.permute(0, 1, 3, 5, 2, 4, 6).contiguous().view(
                b,  t*out_h*out_w, d_k*height*width)
value = value.view(b, t, d_k, out_h, height, out_w, width)
value = value.permute(0, 1, 3, 5, 2, 4, 6).contiguous().view(
                b,  t*out_h*out_w, d_k*height*width)
           

(3)Mask Formatting

        KQV Formatting将幀序列變成四種尺度,masks也需要對應的變成四種尺度,結構如下:

【論文筆記】圖像修複Learning Joint Spatial-Temporal Transformations for Video Inpainting一、項目介紹二、STTN代碼如下:三、損失函數三、訓練流程

 圖5

代碼如下:

# 位置model/sttn.py
mm = m.view(b, t, 1, out_h, height, out_w, width)
mm = mm.permute(0, 1, 3, 5, 2, 4, 6).contiguous().view(
                b,  t*out_h*out_w, height*width)
mm = (mm.mean(-1) > 0.5).unsqueeze(1).repeat(1, t*out_h*out_w, 1)
           

(4)Attention

        圖3中的Attention層其實包括了論文中的Matching和Attending,結構圖如下:

【論文筆記】圖像修複Learning Joint Spatial-Temporal Transformations for Video Inpainting一、項目介紹二、STTN代碼如下:三、損失函數三、訓練流程

 圖6

        圖6中的K*Q/sqrt(Q.size(-1))是在計算各個patch的相似性,對應論文中公式,第i個斑塊與第j個patch的相似性記為::

【論文筆記】圖像修複Learning Joint Spatial-Temporal Transformations for Video Inpainting一、項目介紹二、STTN代碼如下:三、損失函數三、訓練流程

         圖6中的masked_fill(Mask, -1e9)是将圖像中的損壞部分mask掉,意思是隻學習圖像中完整的部分,壞的就不要學習了。

         論文中的Attention對應圖6中的matmul,負責計算相關patches的value權重和得到輸出patch的query。公式如下:

【論文筆記】圖像修複Learning Joint Spatial-Temporal Transformations for Video Inpainting一、項目介紹二、STTN代碼如下:三、損失函數三、訓練流程

代碼如下:

# 位置model/sttn.py
class Attention(nn.Module):
    """
    Compute 'Scaled Dot Product Attention
    """

    def forward(self, query, key, value, m):
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.size(-1))
        scores.masked_fill(m, -1e9)
        p_attn = F.softmax(scores, dim=-1)
        p_val = torch.matmul(p_attn, value)
        return p_val, p_attn
           

3.Decoder

         frame-level decoder: 幀級解碼器,把特征解碼成幀。期間特征圖經過了兩次的膨脹,中間穿插幾個2d卷積,整體過程有點像Encoder倒過來,結構圖如下:

【論文筆記】圖像修複Learning Joint Spatial-Temporal Transformations for Video Inpainting一、項目介紹二、STTN代碼如下:三、損失函數三、訓練流程

 圖7

代碼如下:

# 位置model/sttn.py
self.decoder = nn.Sequential(
            deconv(channel, 128, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            deconv(64, 64, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1)
        )
           

三、損失函數

        本文使用GAN來對模型進行優化,G模型選擇了一個像素級的重建損失即L1Loss,D網絡使用T-PatchGAN來優化。

1.G模型損失函數

        G模型圖像破壞區域的L1Loss:

【論文筆記】圖像修複Learning Joint Spatial-Temporal Transformations for Video Inpainting一、項目介紹二、STTN代碼如下:三、損失函數三、訓練流程

        G模型圖像有效區域的L1Loss:

【論文筆記】圖像修複Learning Joint Spatial-Temporal Transformations for Video Inpainting一、項目介紹二、STTN代碼如下:三、損失函數三、訓練流程

        STTN的對抗性損失: ​​​

【論文筆記】圖像修複Learning Joint Spatial-Temporal Transformations for Video Inpainting一、項目介紹二、STTN代碼如下:三、損失函數三、訓練流程

        上式看上去很複雜,其實就是将恢複的圖像送入D模型,然後送入損失函數(可選nsgan、lsgan、hinge)

        總結上面三個式子,得出G模型的損失函數,其中三個權重官方推薦

【論文筆記】圖像修複Learning Joint Spatial-Temporal Transformations for Video Inpainting一、項目介紹二、STTN代碼如下:三、損失函數三、訓練流程
【論文筆記】圖像修複Learning Joint Spatial-Temporal Transformations for Video Inpainting一、項目介紹二、STTN代碼如下:三、損失函數三、訓練流程

2.D網絡的損失函數

        對抗性的損失在提高視訊繪制的感覺品質和時空一緻性方面顯示出了良好的效果。公式如下:

【論文筆記】圖像修複Learning Joint Spatial-Temporal Transformations for Video Inpainting一、項目介紹二、STTN代碼如下:三、損失函數三、訓練流程

         看山去還是很複雜,其實就是将原圖和複原圖分别送入損失函數(可選nsgan、lsgan、hinge),然後求和,代碼中是取均值,不過應該影響不大。

三、訓練流程

        下面是我根據官方代碼梳理的整個訓練過程:

        1.從資料集選取資料,同時為選取的資料随機帶有破壞圖案的masks

        2.根據masks将原圖的破壞部分變成0,得到masked_frame

        3.将masked_frame和masks送入G模型(生成模型,即STTN),得出估計pred_img

        4.根據pred_img修複圖像,得到comp_img

        5.将原圖和comp_img分别送入D模型,分别得到輸出的特征 real_vid_feat和fake_vid_feat

        6.使用real_vid_feat和fake_vid_feat對D模型進行優化(損失函數可選nsgan、lsgan、hinge)

        7.使用原圖、comp_img和gen_vid_feat對G模型進行優化(L1Loss)

代碼如下:

# 位置core/trainer.py
      def _train_epoch(self, pbar):
        device = self.config['device']

        for frames, masks in self.train_loader:
            self.adjust_learning_rate()
            self.iteration += 1
            frames, masks = frames.to(device), masks.to(device)
            b, t, c, h, w = frames.size()
            masked_frame = (frames * (1 - masks).float())
            # 将masked_frame和masks送入G模型(生成模型,即STTN),得出估計pred_img
            pred_img = self.netG(masked_frame, masks)
            frames = frames.view(b*t, c, h, w)
            masks = masks.view(b*t, 1, h, w)
            # 根據pred_img修複圖像,得到comp_img
            comp_img = frames*(1.-masks) + masks*pred_img
            gen_loss = 0
            dis_loss = 0
            # 将原圖和comp_img分别送入D模型,分别得到輸出的特征 real_vid_feat和fake_vid_feat
            real_vid_feat = self.netD(frames)
            fake_vid_feat = self.netD(comp_img.detach())
            # 計算D網絡的損失
            dis_real_loss = self.adversarial_loss(real_vid_feat, True, True)
            dis_fake_loss = self.adversarial_loss(fake_vid_feat, False, True)
            dis_loss += (dis_real_loss + dis_fake_loss) / 2
            self.add_summary(
                self.dis_writer, 'loss/dis_vid_fake', dis_fake_loss.item())
            self.add_summary(
                self.dis_writer, 'loss/dis_vid_real', dis_real_loss.item())
            self.optimD.zero_grad()
            dis_loss.backward()
            # 使用real_vid_feat和fake_vid_feat對D模型進行優化
            self.optimD.step()

            # G模型的對抗性損失
            gen_vid_feat = self.netD(comp_img)
            gan_loss = self.adversarial_loss(gen_vid_feat, True, False)
            gan_loss = gan_loss * self.config['losses']['adversarial_weight']
            gen_loss += gan_loss
            self.add_summary(
                self.gen_writer, 'loss/gan_loss', gan_loss.item())

            # G模型圖像破壞區域的L1Loss
            hole_loss = self.l1_loss(pred_img*masks, frames*masks)
            hole_loss = hole_loss / torch.mean(masks) * self.config['losses']['hole_weight']
            gen_loss += hole_loss 
            self.add_summary(
                self.gen_writer, 'loss/hole_loss', hole_loss.item())
            # G模型圖像有效區域的L1Loss
            valid_loss = self.l1_loss(pred_img*(1-masks), frames*(1-masks))
            valid_loss = valid_loss / torch.mean(1-masks) * self.config['losses']['valid_weight']
            gen_loss += valid_loss 
            self.add_summary(
                self.gen_writer, 'loss/valid_loss', valid_loss.item())
            
            self.optimG.zero_grad()
            gen_loss.backward()
            # 使用原圖、comp_img和gen_vid_feat對G模型進行優化
            self.optimG.step()

            # 日志
            if self.config['global_rank'] == 0:
                pbar.update(1)
                pbar.set_description((
                    f"d: {dis_loss.item():.3f}; g: {gan_loss.item():.3f};"
                    f"hole: {hole_loss.item():.3f}; valid: {valid_loss.item():.3f}")
                )

            # saving models
            if self.iteration % self.train_args['save_freq'] == 0:
                self.save(int(self.iteration//self.train_args['save_freq']))
            if self.iteration > self.train_args['iterations']:
                break
           

        接下來代碼中有些重點,需要簡單說明一下:

1.準備資料集

        項目中用到Davis或youtube-vos資料集,兩個資料集其實都是為segmentation任務設計的,代碼中都隻使用圖像資料,不使用标注資料。我們以davis資料集為例,davis資料集由90個視訊組成,每個視訊已經拆幀成圖檔,資料集下載下傳完每個視訊一個檔案夾,但是程式需要每個視訊這圖檔打成zip檔案,下面的程式可以用來完成這個工作:

import os
import zipfile


def zipDir(dirpath, out_full_name):
    zipname = zipfile.ZipFile(out_full_name, 'w', zipfile.ZIP_DEFLATED)
    for path, dirnames, filenames in os.walk(dirpath):
        fpath= path.replace(dirpath, '')

        for filename in filenames:
            zipname.write(os.path.join(path, filename), os.path.join(fpath, filename))
    zipname.close()


if __name__=="__main__":
    org_dir = r'datasets/davis/JPEGImages_org'
    zip_dir = r'datasets/davis/JPEGImages'
    g = os.walk(org_dir)
    for path, dir_list, file_list in g:
        for dir_name in dir_list:
            input_path = os.path.join(path, dir_name)
            output_path = os.path.join(zip_dir, dir_name+'.zip')
            print(input_path, '\n', output_path)
            zipDir(input_path, output_path)
           

2.資料選取政策

        資料是從90個視訊中随機挑一個,然後在這個視訊中選取sample_length張圖檔,最終每個視訊都會選取一個圖檔組,在論文中提到有兩種資料選取政策,就是下面這個公式:

【論文筆記】圖像修複Learning Joint Spatial-Temporal Transformations for Video Inpainting一、項目介紹二、STTN代碼如下:三、損失函數三、訓練流程

         其中

【論文筆記】圖像修複Learning Joint Spatial-Temporal Transformations for Video Inpainting一、項目介紹二、STTN代碼如下:三、損失函數三、訓練流程

代表以t為中心n為半徑的連續幀序列,代碼實作是50%機率用一個長度為sample_length的框随機滑動選取;

【論文筆記】圖像修複Learning Joint Spatial-Temporal Transformations for Video Inpainting一、項目介紹二、STTN代碼如下:三、損失函數三、訓練流程

表示從以s采樣率的視訊

【論文筆記】圖像修複Learning Joint Spatial-Temporal Transformations for Video Inpainting一、項目介紹二、STTN代碼如下:三、損失函數三、訓練流程

中均勻采樣的遠處幀,代碼中并未使用這種方式,而是50%機率随機選取幀,這樣也許是為了解決緩解資料不夠多的問題。

        選圖檔組的代碼如下:

# 位置:core/dataset.py
def get_ref_index(length, sample_length):
    # 50%機率随機選取幀
    if random.uniform(0, 1) > 0.5:
        ref_index = random.sample(range(length), sample_length)
        ref_index.sort()
    else:
    # 50%機率用一個長度為sample_length的框随機滑動選取
        pivot = random.randint(0, length-sample_length)
        ref_index = [pivot+i for i in range(sample_length)]
    return ref_index
           

3.生成随機masks

        有了圖檔組,還需要為每個圖檔組随機生成masks。其中0代表背景,1代表破壞部分。代碼如下,注釋已經很清楚:

# 位置:core/utils.py
def create_random_shape_with_random_motion(video_length, imageHeight=240, imageWidth=432):
    # 生成的破壞圖案寬高占原圖的1/3到100%
    height = random.randint(imageHeight//3, imageHeight-1)
    width = random.randint(imageWidth//3, imageWidth-1)
    # 生成不規則的破壞圖案
    edge_num = random.randint(6, 8)
    ratio = random.randint(6, 8)/10
    region = get_random_shape(
        edge_num=edge_num, ratio=ratio, height=height, width=width)
    region_width, region_height = region.size
    # 随機放置破壞圖案
    x, y = random.randint(
        0, imageHeight-region_height), random.randint(0, imageWidth-region_width)
    velocity = get_random_velocity(max_speed=3)
    m = Image.fromarray(np.zeros((imageHeight, imageWidth)).astype(np.uint8))
    m.paste(region, (y, x, y+region.size[0], x+region.size[1]))
    masks = [m.convert('L')]
    # 50%機率所有的mask一樣
    if random.uniform(0, 1) > 0.5:
        return masks*video_length
    # 50%機率mask中的破壞圖案會移動
    for _ in range(video_length-1):
        x, y, velocity = random_move_control_points(
            x, y, imageHeight, imageWidth, velocity, region.size, maxLineAcceleration=(3, 0.5), maxInitSpeed=3)
        m = Image.fromarray(
            np.zeros((imageHeight, imageWidth)).astype(np.uint8))
        m.paste(region, (y, x, y+region.size[0], x+region.size[1]))
        masks.append(m.convert('L'))
    return masks
           

繼續閱讀