
27 Image Segmentation: 1. Image Segmentation; 2. Implementing Image Segmentation; 3. Overview of Deep Learning Segmentation Models; 4. Training U-Net for Portrait Matting

1. Image Segmentation

1.1 What is image segmentation?


Image segmentation: classifying every pixel in an image.

1.2 Types of image segmentation


Types of image segmentation:

  • Superpixel segmentation: a small number of superpixels replaces a large number of pixels; often used for image preprocessing (a minimal sketch follows this list)
    • A superpixel is a group of many pixels with similar properties, e.g., each white patch in the top-left figure
  • Semantic segmentation: per-pixel classification; individual instances cannot be distinguished
  • Instance segmentation: segments individual objects; pixel-level object detection
    • Only objects of interest are segmented, e.g., the people in the figure
  • Panoptic segmentation: semantic segmentation combined with instance segmentation
    • every individual instance is distinguished
    • every pixel is classified
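
A minimal sketch of superpixel segmentation (using scikit-image's SLIC, which is not part of the original lesson; the sample image and parameter values are illustrative only):

from skimage.data import astronaut
from skimage.segmentation import slic, mark_boundaries
import matplotlib.pyplot as plt

img = astronaut()                                                     # sample RGB image, (512, 512, 3)
segments = slic(img, n_segments=200, compactness=10, start_label=1)   # per-pixel superpixel label
print(segments.shape, segments.max())                                 # (512, 512), roughly 200 superpixels

plt.imshow(mark_boundaries(img, segments))                            # visualize superpixel boundaries
plt.show()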

2. Implementing Image Segmentation

2.1 How does a model segment an image?


Image classification: the output is a 1-D vector in which each component represents a class.

Image segmentation: the output is a 3-D tensor; each location on the 2-D plane corresponds to a vector along the third dimension, and each component of that vector corresponds to a class.


The model takes an image, i.e., a 3-D tensor, as input and also outputs a 3-D tensor.

With the PASCAL VOC dataset there are 21 classes, listed below (a small shape sketch follows the list):

classes = ['__background__',
           'aeroplane', 'bicycle', 'bird', 'boat',
           'bottle', 'bus', 'car', 'cat', 'chair',
           'cow', 'diningtable', 'dog', 'horse',
           'motorbike', 'person', 'pottedplant',
           'sheep', 'sofa', 'train', 'tvmonitor']
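
A minimal shape sketch (with made-up logits, not real model output) of how a per-pixel label map is obtained from a 21-channel score tensor:

import torch

num_classes, H, W = 21, 4, 6
logits = torch.randn(num_classes, H, W)   # one 21-dim score vector per pixel
label_map = logits.argmax(dim=0)          # per-pixel class index, shape (H, W)
print(label_map.shape)                    # torch.Size([4, 6])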
           

2.2 torch.hub

PyTorch Hub: PyTorch's model repository, offering a large number of models for developers to call.

model = torch.hub.load(github, model, *args, **kwargs)
model = torch.hub.load('pytorch/vision', 'deeplabv3_resnet101', pretrained=True)
           

Purpose: load a pre-trained model.

Main arguments:

  • github: str, repository name, in the form <repo_owner/repo_name[:tag_name]>, e.g. pytorch/vision
  • model: str, model name
  • pretrained: whether to load the pre-trained weights

torch.hub.list(): lists the models provided by the repository specified by the github argument.

torch.hub.help(): shows the docstring of a model entry point, i.e., which arguments the model accepts.
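
A small sketch of these two helpers (run against the real pytorch/vision repo, so network access is assumed; the printed output is abbreviated here):

import torch

entrypoints = torch.hub.list('pytorch/vision')                   # model names exposed by hubconf.py
print(entrypoints[:5])

print(torch.hub.help('pytorch/vision', 'deeplabv3_resnet101'))   # docstring describing the arguments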

2.3 Code example

# -*- coding: utf-8 -*-
"""
# @file name  : seg_demo.py
# @author     : TingsongYu https://github.com/TingsongYu
# @date       : 2019-11-22
# @brief      : image segmentation with DeepLab-V3 via torch.hub
"""

import os
import time
import torch.nn as nn
import torch
import numpy as np
import torchvision.transforms as transforms
from PIL import Image
from matplotlib import pyplot as plt

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if __name__ == "__main__":

    path_img = os.path.join(BASE_DIR, "demo_img1.png")
    # path_img = os.path.join(BASE_DIR, "demo_img2.png")
    # path_img = os.path.join(BASE_DIR, "demo_img3.png")

    # config
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # 1. load data & model
    input_image = Image.open(path_img).convert("RGB")
    model = torch.hub.load('pytorch/vision', 'deeplabv3_resnet101', pretrained=True)
    model.eval()

    # 2. preprocess
    input_tensor = preprocess(input_image)
    input_bchw = input_tensor.unsqueeze(0)

    # 3. to device
    if torch.cuda.is_available():
        input_bchw = input_bchw.to(device)
        model.to(device)

    # 4. forward
    with torch.no_grad():
        tic = time.time()
        print("input img tensor shape:{}".format(input_bchw.shape))
        output_4d = model(input_bchw)['out']
        output = output_4d[0]
        print("pass: {:.3f}s use: {}".format(time.time() - tic, device))
        print("output img tensor shape:{}".format(output.shape))
    output_predictions = output.argmax(0)

    # 5. visualization
    palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
    colors = torch.as_tensor([i for i in range(21)])[:, None] * palette
    colors = (colors % 255).numpy().astype("uint8")

    # plot the semantic segmentation predictions of 21 classes in each color
    r = Image.fromarray(output_predictions.byte().cpu().numpy()).resize(input_image.size)
    r.putpalette(colors)
    plt.subplot(121).imshow(r)
    plt.subplot(122).imshow(input_image)
    plt.show()

    # appendix
    classes = ['__background__',
                       'aeroplane', 'bicycle', 'bird', 'boat',
                       'bottle', 'bus', 'car', 'cat', 'chair',
                       'cow', 'diningtable', 'dog', 'horse',
                       'motorbike', 'person', 'pottedplant',
                       'sheep', 'sofa', 'train', 'tvmonitor']

           

Results:


Segmentation result for the first image:


Segmentation result for the second image:


Segmentation result for the third image:


As the figure above shows, when the dog's head is pasted onto the cat image, the whole object is recognized as a dog, so the classification gets confused.


If the heads of the cat and the dog are swapped, the segmentation model misclassifies them, which suggests that the head pixels play the dominant role in segmentation and classification.

Note: image segmentation does not simply classify each pixel in isolation; it also takes the surrounding context into account and aggregates information from every region of the image.

3. Overview of Deep Learning Image Segmentation Models

3.1 FCN

FCN: "Fully Convolutional Networks for Semantic Segmentation", 2014


Main contribution:

Uses fully convolutional layers to achieve pixelwise prediction.

Problem it solves:

The limitation of linear (fully connected) layers: a network that contains linear layers requires a fixed input size, yet at inference time images may arrive at different sizes, so such a network cannot adapt well to varying input resolutions. FCN replaces the linear layers with convolutional layers, which makes pixelwise prediction possible; a small sketch of this replacement follows.
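
A minimal sketch (a toy example under my own assumptions, not the original FCN code) of swapping a fully connected classifier head for a 1x1 convolution, so the head works on feature maps of any size and keeps the spatial layout:

import torch
import torch.nn as nn

num_classes, channels = 21, 512

fc_head = nn.Linear(channels * 7 * 7, num_classes)           # only accepts one fixed feature-map size (7x7)
conv_head = nn.Conv2d(channels, num_classes, kernel_size=1)  # accepts any H x W

feat = torch.randn(1, channels, 12, 16)                      # feature map of an arbitrary size
scores = conv_head(feat)                                     # per-location class scores
print(scores.shape)                                          # torch.Size([1, 21, 12, 16])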

3.2 U-net

U-Net: "U-Net: Convolutional Networks for Biomedical Image Segmentation"


Network structure:

The network can be split down the middle: the left half is called the encoder and the right half the decoder; the whole model forms a U shape, hence the name U-Net.

Main characteristics:

  • Produces good results even on fairly small datasets (U-Net was originally designed for medical image segmentation, where datasets are usually small)
  • The encoder's feature maps are passed to the decoder for feature fusion, i.e., the gray arrows in the figure above
  • The input and output sizes do not have to match

Input and output:

  • Input size: 572x572x1
    • the actual region is the 388x388 yellow box in the figure; the border area comes from mirror padding
    • the input cell images are grayscale, so there is 1 channel
  • Output size: 388x388x2
    • cell segmentation is only a binary classification of each pixel, so just 2 output channels are needed

Main contribution:

Established the basic structure of the U-Net family of segmentation models: feature fusion between the encoder and the decoder.

3.3 The DeepLab family

3.3.1 DeepLab V1

DeepLab V1: "Semantic image segmentation with deep convolutional nets and fully connected CRFs"


Main characteristics:

  1. Atrous (dilated) convolution: enlarges the receptive field (a small sketch follows this list)
  2. CRF: uses a CRF to post-process the mask, yielding a finer mask output
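
A small sketch (toy numbers, not DeepLab code) showing that dilation enlarges the receptive field without adding parameters or changing the output size:

import torch
import torch.nn as nn

x = torch.randn(1, 64, 32, 32)

conv_std = nn.Conv2d(64, 64, kernel_size=3, padding=1)                  # 3x3 receptive field
conv_atrous = nn.Conv2d(64, 64, kernel_size=3, padding=2, dilation=2)   # 5x5 receptive field, same kernel

print(conv_std(x).shape, conv_atrous(x).shape)          # both torch.Size([1, 64, 32, 32])
print(sum(p.numel() for p in conv_std.parameters()) ==
      sum(p.numel() for p in conv_atrous.parameters())) # True: identical parameter count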

3.3.2 DeepLab V2

DeepLab V2: "DeepLab: Semantic Image Segmentation with Deep Convolutional Nets, Atrous Convolution, and Fully Connected CRFs"


Main characteristic:

ASPP (Atrous Spatial Pyramid Pooling): addresses the multi-scale problem.

Spatial pyramid pooling at 4 different scales (atrous rates) is applied to the input feature map.

3.3.3 DeepLab V3

DeepLab V3: "Rethinking Atrous Convolution for Semantic Image Segmentation"

Main characteristics:

  1. Atrous convolutions in cascade (serial)
  2. ASPP in parallel

Four approaches to the multi-scale problem:


(a) Feed the image at multiple scales, obtain multi-scale feature maps, then fuse them.

(b) The encoder-decoder approach, i.e., the one used in U-Net.

(c) The atrous convolution approach used in DeepLab V1: the left is ordinary convolution and the right is atrous convolution, whose output keeps the feature-map resolution and therefore preserves more precise spatial information.

(d) Spatial pyramid pooling: kernels with different rates extract features of different scales from the feature map, which are then combined into the final feature map.

DeepLab V3 improves on the last two approaches, proposing cascaded atrous convolutions and a parallel ASPP.

Cascaded atrous convolution structure:


The top part is the conventional convolution structure.

The bottom part is the cascaded atrous convolution structure, which keeps the output feature map at a higher resolution and therefore provides more precise information for segmentation (see the sketch below).
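
A toy sketch (my own numbers) contrasting the two: a strided convolution halves the resolution, while a dilated convolution with stride 1 keeps it, which is what the cascaded atrous blocks exploit:

import torch
import torch.nn as nn

x = torch.randn(1, 128, 64, 64)

downsample = nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1)                   # 64 -> 32
keep_resolution = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=2, dilation=2)  # 64 -> 64

print(downsample(x).shape)        # torch.Size([1, 128, 32, 32])
print(keep_resolution(x).shape)   # torch.Size([1, 128, 64, 64])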

Parallel ASPP structure:


The ASPP layer extracts features in parallel with several different atrous rates, concatenates the results, and finally applies a 1x1 convolution to produce the final features; a minimal sketch follows.
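
A minimal sketch of a parallel-ASPP-style head (a simplified illustration under my own assumptions, not the official DeepLab implementation; the rates and channel sizes are made up, and the image-level pooling branch of the real ASPP is omitted):

import torch
import torch.nn as nn

class SimpleASPP(nn.Module):
    """Parallel atrous branches -> concat -> 1x1 conv (simplified)."""
    def __init__(self, in_ch=256, out_ch=256, rates=(1, 6, 12, 18)):
        super().__init__()
        self.branches = nn.ModuleList([
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=r, dilation=r)
            for r in rates
        ])
        self.project = nn.Conv2d(out_ch * len(rates), out_ch, kernel_size=1)

    def forward(self, x):
        feats = [branch(x) for branch in self.branches]   # every branch keeps the spatial size
        return self.project(torch.cat(feats, dim=1))      # fuse with a 1x1 convolution

x = torch.randn(1, 256, 33, 33)
print(SimpleASPP()(x).shape)   # torch.Size([1, 256, 33, 33])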

3.3.4 DeepLab V3+

DeepLab V3+: "Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation"


Network structure:

Encoder: the backbone uses the atrous convolutions from V1; the ASPP structure from V3 then produces feature maps at different scales, which are concatenated and passed through a 1x1 convolution to obtain a single feature map, called the encoding.

Decoder: takes the encoder's encoding and upsamples it 4x; in parallel, a feature map from the encoder's atrous-convolution stage is passed through a 1x1 convolution; the two are concatenated, and the resulting feature map is upsampled by another 4x to produce the final prediction.

Main characteristics:

  • Combines the atrous convolutions and the parallel ASPP structure used in V3
  • Adds an encoder-decoder structure on top of DeepLab V3

3.4 Image segmentation survey

"Deep Semantic Segmentation of Natural and Medical Images: A Review", 2019


Image segmentation resources:

https://github.com/shawnbit/unet-family

https://github.com/yassouali/pytorch_segmentation

4. Training U-Net for Portrait Matting


Data source: https://github.com/PetroWu/AutoPortraitMatting

Input: 3-channel RGB image

Output: 1-channel mask

4.1 U-Net code: network structure

unet.py:

from collections import OrderedDict

import torch
import torch.nn as nn


class UNet(nn.Module):

    def __init__(self, in_channels=3, out_channels=1, init_features=32):
        super(UNet, self).__init__()

        features = init_features
        self.encoder1 = UNet._block(in_channels, features, name="enc1")
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder2 = UNet._block(features, features * 2, name="enc2")
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder3 = UNet._block(features * 2, features * 4, name="enc3")
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder4 = UNet._block(features * 4, features * 8, name="enc4")
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.bottleneck = UNet._block(features * 8, features * 16, name="bottleneck")

        self.upconv4 = nn.ConvTranspose2d(
            features * 16, features * 8, kernel_size=2, stride=2
        )
        self.decoder4 = UNet._block((features * 8) * 2, features * 8, name="dec4")
        self.upconv3 = nn.ConvTranspose2d(
            features * 8, features * 4, kernel_size=2, stride=2
        )
        self.decoder3 = UNet._block((features * 4) * 2, features * 4, name="dec3")
        self.upconv2 = nn.ConvTranspose2d(
            features * 4, features * 2, kernel_size=2, stride=2
        )
        self.decoder2 = UNet._block((features * 2) * 2, features * 2, name="dec2")
        self.upconv1 = nn.ConvTranspose2d(
            features * 2, features, kernel_size=2, stride=2
        )
        self.decoder1 = UNet._block(features * 2, features, name="dec1")

        self.conv = nn.Conv2d(
            in_channels=features, out_channels=out_channels, kernel_size=1
        )

    def forward(self, x):
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(self.pool1(enc1))
        enc3 = self.encoder3(self.pool2(enc2))
        enc4 = self.encoder4(self.pool3(enc3))

        bottleneck = self.bottleneck(self.pool4(enc4))

        dec4 = self.upconv4(bottleneck)
        dec4 = torch.cat((dec4, enc4), dim=1)
        dec4 = self.decoder4(dec4)

        dec3 = self.upconv3(dec4)
        dec3 = torch.cat((dec3, enc3), dim=1)
        dec3 = self.decoder3(dec3)

        dec2 = self.upconv2(dec3)
        dec2 = torch.cat((dec2, enc2), dim=1)
        dec2 = self.decoder2(dec2)

        dec1 = self.upconv1(dec2)
        dec1 = torch.cat((dec1, enc1), dim=1)
        dec1 = self.decoder1(dec1)

        return torch.sigmoid(self.conv(dec1))

    @staticmethod
    def _block(in_channels, features, name):
        return nn.Sequential(
            OrderedDict(
                [
                    (
                        name + "conv1",
                        nn.Conv2d(
                            in_channels=in_channels,
                            out_channels=features,
                            kernel_size=3,
                            padding=1,
                            bias=False,
                        ),
                    ),
                    (name + "norm1", nn.BatchNorm2d(num_features=features)),
                    (name + "relu1", nn.ReLU(inplace=True)),
                    (
                        name + "conv2",
                        nn.Conv2d(
                            in_channels=features,
                            out_channels=features,
                            kernel_size=3,
                            padding=1,
                            bias=False,
                        ),
                    ),
                    (name + "norm2", nn.BatchNorm2d(num_features=features)),
                    (name + "relu2", nn.ReLU(inplace=True)),
                ]
            )
        )

           

Notes:

The U-Net structure consists of an encoder part and a decoder part, each built from the basic UNet._block, where UNet._block is {conv + norm + relu} x 2.
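
A quick sanity check (my own snippet; it assumes unet.py is importable from the current directory, so adjust the import to tools.unet if needed). With padding=1 convolutions and four 2x poolings, any input whose height and width are divisible by 16 yields a mask of the same spatial size:

import torch
from unet import UNet   # hypothetical local import path

net = UNet(in_channels=3, out_channels=1, init_features=32)
x = torch.randn(1, 3, 224, 224)       # 224 is divisible by 16
with torch.no_grad():
    y = net(x)
print(y.shape)                        # torch.Size([1, 1, 224, 224]); values in (0, 1) from the sigmoid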

4.2 Training code

2_unet_portrait_matting.py:

# -*- coding: utf-8 -*-
"""
# @file name  : unet_portrait_matting.py
# @author     : TingsongYu https://github.com/TingsongYu
# @date       : 2019-11-25
# @brief      : train unet
"""

import os
import time
import torch.nn as nn
import torch
import numpy as np
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
import torch.optim as optim
import torchvision.models as models
from tools.common_tools import set_seed
from tools.my_dataset import PortraitDataset
from tools.unet import UNet

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

set_seed()  # set the random seed


def compute_dice(y_pred, y_true):
    """
    :param y_pred: 4-d tensor, value = [0,1]
    :param y_true: 4-d tensor, value = [0,1]
    :return: dice coefficient, 2 * |y_pred ∩ y_true| / (|y_pred| + |y_true|)
    """
    y_pred, y_true = np.array(y_pred), np.array(y_true)
    y_pred, y_true = np.round(y_pred).astype(int), np.round(y_true).astype(int)
    return np.sum(y_pred[y_true == 1]) * 2.0 / (np.sum(y_pred) + np.sum(y_true))


if __name__ == "__main__":

    # config
    LR = 0.01
    BATCH_SIZE = 8
    max_epoch = 1   # 400
    start_epoch = 0
    lr_step = 150
    val_interval = 3
    checkpoint_interval = 20
    vis_num = 10
    mask_thres = 0.5

    train_dir = os.path.join(BASE_DIR, "..", "..", "data", "PortraitDataset", "train")
    valid_dir = os.path.join(BASE_DIR, "..", "..", "data", "PortraitDataset", "valid")

    # step 1
    train_set = PortraitDataset(train_dir)
    valid_set = PortraitDataset(valid_dir)

    train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
    valid_loader = DataLoader(valid_set, batch_size=1, shuffle=True, drop_last=False)

    # step 2
    net = UNet(in_channels=3, out_channels=1, init_features=32)   # init_features is 64 in the standard U-Net
    net.to(device)

    # step 3
    loss_fn = nn.MSELoss()
    # step 4
    optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr_step, gamma=0.1)

    # step 5
    train_curve = list()
    valid_curve = list()
    train_dice_curve = list()
    valid_dice_curve = list()
    for epoch in range(start_epoch, max_epoch):

        train_loss_total = 0.
        train_dice_total = 0.

        net.train()
        for iter, (inputs, labels) in enumerate(train_loader):

            if torch.cuda.is_available():
                inputs, labels = inputs.to(device), labels.to(device)

            # forward
            outputs = net(inputs)

            # backward
            optimizer.zero_grad()
            loss = loss_fn(outputs, labels)
            loss.backward()

            optimizer.step()

            # print
            train_dice = compute_dice(outputs.ge(mask_thres).cpu().data.numpy(), labels.cpu())
            train_dice_curve.append(train_dice)
            train_curve.append(loss.item())

            train_loss_total += loss.item()

            print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] running_loss: {:.4f}, mean_loss: {:.4f} "
                  "running_dice: {:.4f} lr:{}".format(epoch, max_epoch, iter + 1, len(train_loader), loss.item(),
                                    train_loss_total/(iter+1), train_dice, scheduler.get_lr()))

        scheduler.step()

        if (epoch + 1) % checkpoint_interval == 0:
            checkpoint = {"model_state_dict": net.state_dict(),
                          "optimizer_state_dict": optimizer.state_dict(),
                          "epoch": epoch}
            path_checkpoint = "./checkpoint_{}_epoch.pkl".format(epoch)
            torch.save(checkpoint, path_checkpoint)

        # validate the model
        if (epoch+1) % val_interval == 0:

            net.eval()
            valid_loss_total = 0.
            valid_dice_total = 0.

            with torch.no_grad():
                for j, (inputs, labels) in enumerate(valid_loader):
                    if torch.cuda.is_available():
                        inputs, labels = inputs.to(device), labels.to(device)

                    outputs = net(inputs)
                    loss = loss_fn(outputs, labels)

                    valid_loss_total += loss.item()

                    valid_dice = compute_dice(outputs.ge(mask_thres).cpu().data, labels.cpu())
                    valid_dice_total += valid_dice

                valid_loss_mean = valid_loss_total/len(valid_loader)
                valid_dice_mean = valid_dice_total/len(valid_loader)
                valid_curve.append(valid_loss_mean)
                valid_dice_curve.append(valid_dice_mean)

                print("Valid:\t Epoch[{:0>3}/{:0>3}] mean_loss: {:.4f} dice_mean: {:.4f}".format(
                    epoch, max_epoch, valid_loss_mean, valid_dice_mean))

    # visualization
    with torch.no_grad():
        for idx, (inputs, labels) in enumerate(valid_loader):
            if idx > vis_num:
                break
            if torch.cuda.is_available():
                inputs, labels = inputs.to(device), labels.to(device)
            outputs = net(inputs)
            pred = outputs.ge(mask_thres)

            mask_pred = outputs.ge(0.5).cpu().data.numpy().astype("uint8")

            img_hwc = inputs.cpu().data.numpy()[0, :, :, :].transpose((1, 2, 0)).astype("uint8")
            plt.subplot(121).imshow(img_hwc)
            mask_pred_gray = mask_pred.squeeze() * 255
            plt.subplot(122).imshow(mask_pred_gray, cmap="gray")
            plt.show()
            plt.pause(0.5)
            plt.close()

    # plot curve
    train_x = range(len(train_curve))
    train_y = train_curve

    train_iters = len(train_loader)
    valid_x = np.arange(1, len(
        valid_curve) + 1) * train_iters * val_interval  # valid records one loss per epoch, so map record points to iteration indices
    valid_y = valid_curve

    plt.plot(train_x, train_y, label='Train')
    plt.plot(valid_x, valid_y, label='Valid')

    plt.legend(loc='upper right')
    plt.ylabel('loss value')
    plt.xlabel('Iteration')
    plt.title("Plot in {} epochs".format(max_epoch))
    plt.show()

    # dice curve
    train_x = range(len(train_dice_curve))
    train_y = train_dice_curve

    train_iters = len(train_loader)
    valid_x = np.arange(1, len(
        valid_dice_curve) + 1) * train_iters * val_interval  # valid records one value per epoch, so map record points to iteration indices
    valid_y = valid_dice_curve

    plt.plot(train_x, train_y, label='Train')
    plt.plot(valid_x, valid_y, label='Valid')

    plt.legend(loc='upper right')
    plt.ylabel('dice value')
    plt.xlabel('Iteration')
    plt.title("Plot in {} epochs".format(max_epoch))
    plt.show()
    torch.cuda.empty_cache()

           

Results after one epoch of training:


4.3 Inference code

Testing with a model trained for 400 epochs:

# -*- coding: utf-8 -*-
"""
# @file name  : portrait_inference.py
# @author     : TingsongYu https://github.com/TingsongYu
# @date       : 2019-11-25
# @brief      : inference portrait
"""

import os
import time
import random
import torch.nn as nn
import torch
import numpy as np
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
import torch.optim as optim
import torchvision.models as models
from tools.common_tools import set_seed
from tools.my_dataset import PortraitDataset
from tools.unet import UNet

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

set_seed()  # set the random seed


def compute_dice(y_pred, y_true):
    """
    :param y_pred: 4-d tensor, value = [0,1]
    :param y_true: 4-d tensor, value = [0,1]
    :return:
    """
    y_pred, y_true = np.array(y_pred), np.array(y_true)
    y_pred, y_true = np.round(y_pred).astype(int), np.round(y_true).astype(int)
    return np.sum(y_pred[y_true == 1]) * 2.0 / (np.sum(y_pred) + np.sum(y_true))


def get_img_name(img_dir, format="jpg"):
    """
    Get the names of files with the given format (extension) under img_dir
    :param img_dir: str
    :param format: str
    :return: list
    """
    file_names = os.listdir(img_dir)
    img_names = list(filter(lambda x: x.endswith(format), file_names))
    img_names = list(filter(lambda x: not x.endswith("matte.png"), img_names))

    if len(img_names) < 1:
        raise ValueError("{} contains no files in {} format".format(img_dir, format))
    return img_names


def get_model(m_path):

    unet = UNet(in_channels=3, out_channels=1, init_features=32)
    checkpoint = torch.load(m_path, map_location="cpu")

    # remove the 'module.' prefix added by nn.DataParallel
    from collections import OrderedDict
    new_state_dict = OrderedDict()
    for k, v in checkpoint['model_state_dict'].items():
        namekey = k[7:] if k.startswith('module.') else k
        new_state_dict[namekey] = v

    unet.load_state_dict(new_state_dict)

    return unet


if __name__ == "__main__":

    img_dir = os.path.join(BASE_DIR, "..", "..", "data", "PortraitDataset", "valid")
    model_path = "checkpoint_399_epoch.pkl"
    time_total = 0
    num_infer = 5
    mask_thres = .5

    # 1. data
    img_names = get_img_name(img_dir, format="png")
    random.shuffle(img_names)
    num_img = len(img_names)

    # 2. model
    unet = get_model(model_path)
    unet.to(device)
    unet.eval()

    for idx, img_name in enumerate(img_names):
        if idx > num_infer:
            break

        path_img = os.path.join(img_dir, img_name)
        # path_img = "C:\\Users\\Administrator\\Desktop\\Andrew-wu.png"
        #
        # step 1/4 : path --> img_chw
        img_hwc = Image.open(path_img).convert('RGB')
        img_hwc = img_hwc.resize((224, 224))
        img_arr = np.array(img_hwc)
        img_chw = img_arr.transpose((2, 0, 1))

        # step 2/4 : img --> tensor
        img_tensor = torch.tensor(img_chw).to(torch.float)
        img_tensor.unsqueeze_(0)
        img_tensor = img_tensor.to(device)

        # step 3/4 : tensor --> features
        time_tic = time.time()
        outputs = unet(img_tensor)
        time_toc = time.time()

        # step 4/4 : visualization
        pred = outputs.ge(mask_thres)
        mask_pred = outputs.ge(0.5).cpu().data.numpy().astype("uint8")

        img_hwc = img_tensor.cpu().data.numpy()[0, :, :, :].transpose((1, 2, 0)).astype("uint8")
        plt.subplot(121).imshow(img_hwc)
        mask_pred_gray = mask_pred.squeeze() * 255
        plt.subplot(122).imshow(mask_pred_gray, cmap="gray")
        plt.show()
        # plt.pause(0.5)
        plt.close()

        time_s = time_toc - time_tic
        time_total += time_s

        print('{:d}/{:d}: {} {:.3f}s '.format(idx + 1, num_img, img_name, time_s))

           

Results:

