
Road Information Extraction from Remote Sensing Imagery with U-Net

1. Paper Reading

The original paper is "U-Net: Convolutional Networks for Biomedical Image Segmentation", available at https://arxiv.org/abs/1505.04597. The network is a U-shaped encoder-decoder split into a downsampling path and an upsampling path. The downsampling path follows a typical convolutional architecture: at each level a max-pooling step halves the spatial size of the feature map while the convolutions double the number of channels (see the code below or the paper for the per-layer details). The upsampling path applies transposed convolutions (ConvTranspose) to the downsampled features, concatenating the matching encoder features at each level, until the original resolution is restored. The architecture is shown in Figure 1.1:


Figure 1.1 The U-Net architecture

2. Code Implementation

The code is split into three .py files: the data-preprocessing module dataset.py, the network model module unet.py, and main.py.

# dataset.py
from torch.utils.data import Dataset
import PIL.Image as Image
import os

def make_dataset(root):
    """Collect (image, mask) path pairs named 000.png / 000_mask.png, 001.png / 001_mask.png, ..."""
    imgs = []
    n = len(os.listdir(root)) // 2  # half the files are images, half are masks
    for i in range(n):
        img = os.path.join(root, "%03d.png" % i)
        mask = os.path.join(root, "%03d_mask.png" % i)
        imgs.append((img, mask))
    return imgs

class LiverDataset(Dataset):
    def __init__(self, root, transform=None, target_transform=None):
        self.imgs = make_dataset(root)
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        x_path, y_path = self.imgs[index]
        img_x = Image.open(x_path)
        img_y = Image.open(y_path)
        if self.transform is not None:
            img_x = self.transform(img_x)
        if self.target_transform is not None:
            img_y = self.target_transform(img_y)
        return img_x, img_y

    def __len__(self):
        return len(self.imgs)

dataset.py provides two pieces: make_dataset collects the (image, label) path pairs, and LiverDataset wraps them as a torch Dataset so they can be fed to a DataLoader. The class keeps the name LiverDataset, but it simply loads whatever %03d.png / %03d_mask.png pairs sit under root.
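
Before training, it is worth iterating the dataset once as a sanity check; a minimal sketch, assuming the data_road\train layout used below:

# a minimal sanity-check sketch; assumes data_road/train holds 000.png / 000_mask.png pairs
from torch.utils.data import DataLoader
from torchvision.transforms import transforms
from dataset import LiverDataset

ds = LiverDataset(r'data_road\train', transform=transforms.ToTensor(),
                  target_transform=transforms.ToTensor())
loader = DataLoader(ds, batch_size=2, shuffle=True)
x, y = next(iter(loader))
# expect [2, 3, 512, 512] for images; the mask channel count depends on how the label PNGs were saved
print(x.shape, y.shape)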

# unet.py
import torch
from torch import nn

class DoubleConv(nn.Module):
    """(3x3 conv -> BN -> ReLU) twice; padding=1 keeps the spatial size unchanged."""
    def __init__(self, in_ch, out_ch):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, input):
        return self.conv(input)

class Unet(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(Unet, self).__init__()
        # encoder: each 2x2 max-pool halves the feature map, each DoubleConv doubles the channels
        self.conv1 = DoubleConv(in_ch, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = DoubleConv(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.conv4 = DoubleConv(256, 512)
        self.pool4 = nn.MaxPool2d(2)
        self.conv5 = DoubleConv(512, 1024)
        # decoder: each transposed conv doubles the spatial size again; the DoubleConv input
        # channels are doubled because the matching encoder features are concatenated
        self.up6 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.conv6 = DoubleConv(1024, 512)
        self.up7 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.conv7 = DoubleConv(512, 256)
        self.up8 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.conv8 = DoubleConv(256, 128)
        self.up9 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv9 = DoubleConv(128, 64)
        # 1x1 conv maps 64 channels to the number of output classes (here one logit per pixel)
        self.conv10 = nn.Conv2d(64, out_ch, 1)

    def forward(self, x):
        # encoder
        c1 = self.conv1(x)
        p1 = self.pool1(c1)
        c2 = self.conv2(p1)
        p2 = self.pool2(c2)
        c3 = self.conv3(p2)
        p3 = self.pool3(c3)
        c4 = self.conv4(p3)
        p4 = self.pool4(c4)
        c5 = self.conv5(p4)
        # decoder with skip connections (concatenation along the channel dimension)
        up_6 = self.up6(c5)
        merge6 = torch.cat([up_6, c4], dim=1)
        c6 = self.conv6(merge6)
        up_7 = self.up7(c6)
        merge7 = torch.cat([up_7, c3], dim=1)
        c7 = self.conv7(merge7)
        up_8 = self.up8(c7)
        merge8 = torch.cat([up_8, c2], dim=1)
        c8 = self.conv8(merge8)
        up_9 = self.up9(c8)
        merge9 = torch.cat([up_9, c1], dim=1)
        c9 = self.conv9(merge9)
        c10 = self.conv10(c9)
        return c10
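unet.py implements the architecture from Figure 1.1. A quick shape check confirms that the network maps an input tile back to its original resolution; a minimal sketch (with four 2x poolings, the input side length must be a multiple of 16, which is why 512x512 tiles are used later):

# shape_check.py -- a minimal sanity-check sketch, not part of the training pipeline
import torch
from unet import Unet

net = Unet(3, 1)
x = torch.randn(1, 3, 512, 512)  # one RGB tile
with torch.no_grad():
    y = net(x)
print(y.shape)  # torch.Size([1, 1, 512, 512]): one raw logit per pixel
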
# main.py
import torch
from torch.utils.data import DataLoader
from torch import nn, optim
from torchvision.transforms import transforms
from unet import Unet
from dataset import LiverDataset
import numpy as np
import cv2
import os
from tensorboardX import SummaryWriter

# use CUDA when available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

x_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
# the mask only needs to be converted to a tensor
y_transforms = transforms.ToTensor()

def train_model(model, criterion, optimizer, dataload, num_epochs=3):
    writer = SummaryWriter(r'model_record')
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        dt_size = len(dataload.dataset)
        epoch_loss = 0
        step = 0
        for x, y in dataload:
            step += 1
            inputs = x.to(device)
            labels = y.to(device)
            # zero the parameter gradients
            optimizer.zero_grad()
            # forward + backward + update
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            # len(dataload) is the number of batches per epoch, so the global
            # step keeps increasing across epochs
            writer.add_scalar('train loss', loss.item(), global_step=step + epoch * len(dataload))
            print("%d/%d,train_loss:%0.3f" % (step, (dt_size - 1) // dataload.batch_size + 1, loss.item()))
        print("epoch %d loss:%0.3f" % (epoch, epoch_loss / step))
    # save the weights from the final epoch
    torch.save(model.state_dict(), 'weights_%d.pth' % epoch)
    return model

# train the model
def train():
    model = Unet(3, 1).to(device)
    batch_size = 2
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters())
    liver_dataset = LiverDataset(r'data_road\train', transform=x_transforms, target_transform=y_transforms)
    dataloaders = DataLoader(liver_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    train_model(model, criterion, optimizer, dataloaders)

# write out the model's predictions and compute the mean IoU
def test():
    model = Unet(3, 1).to(device)
    # map_location=device makes the checkpoint load with or without a GPU
    model.load_state_dict(torch.load(r'weights_5.pth', map_location=device))
    liver_dataset = LiverDataset(r'data_road\val1', transform=x_transforms, target_transform=y_transforms)
    dataloaders = DataLoader(liver_dataset, batch_size=1)
    model.eval()
    with torch.no_grad():
        all_IoU = 0.0
        record = 0
        for x, target in dataloaders:
            # the model outputs raw logits (the sigmoid lives inside
            # BCEWithLogitsLoss during training), so apply sigmoid before thresholding
            y = torch.sigmoid(model(x.to(device))).cpu()
            img_y = torch.squeeze(y).numpy()
            img_target = torch.squeeze(target).numpy()
            img_y[img_y > 0.3] = 255
            img_y[img_y <= 0.3] = 0
            # IoU = |prediction AND ground truth| / |prediction OR ground truth|
            pred = img_y == 255
            gt = img_target == 1
            intersection = np.logical_and(pred, gt).sum()
            union = np.logical_or(pred, gt).sum()
            iou = intersection / union if union > 0 else 0.0
            all_IoU += iou
            pathname = "%03d_predict.png" % record
            cv2.imwrite(os.path.join(r'data_road/result1', pathname), img_y)
            record += 1
            print(iou)
        print(all_IoU / len(dataloaders))

if __name__ == '__main__':
    train()
    test()

main.py is a bit of a grab bag: I put the training and testing functions in the same file, so when training just run train() on its own and comment out test(). The loss function is nn.BCEWithLogitsLoss(); if you are interested, swap in other losses and compare the results. One point worth noting is the 0.3 threshold in the lines img_y[img_y > 0.3] = 255 and img_y[img_y <= 0.3] = 0: it binarizes the per-pixel sigmoid probability, and a suitable segmentation threshold has to be chosen case by case. The optimizer is Adam.
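
Rather than fixing 0.3 by hand, one way to pick the threshold is to sweep candidates on a validation image and keep the one with the best IoU; a minimal sketch (pred and gt are assumed to be the sigmoid probability map and the 0/1 mask from test() above):

# threshold sweep -- a minimal sketch; `pred` is the sigmoid probability map,
# `gt` the binary ground-truth mask, both numpy arrays of the same shape
import numpy as np

def best_threshold(pred, gt, candidates=np.arange(0.1, 0.9, 0.05)):
    best_t, best_iou = 0.5, 0.0
    for t in candidates:
        mask = pred > t
        union = np.logical_or(mask, gt).sum()
        iou = np.logical_and(mask, gt).sum() / union if union > 0 else 0.0
        if iou > best_iou:
            best_t, best_iou = t, iou
    return best_t, best_iou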

The dataset is Massachusetts Roads, available at https://www.cs.toronto.edu/~vmnih/data/; the files can be batch-downloaded with a simple crawler, and if there is demand I can post a cloud-drive link for the dataset in the comments below. One more detail (pitfall): this U-Net needs image sizes that are integer multiples of 16, such as 512 or 1024, because of the four 2x poolings. The downloaded images and labels therefore have to be batch-resampled; nearest-neighbor and bilinear interpolation both work and make little difference. In the end I cut the dataset into 512*512 tiles, because of graphics memory. My training results are shown below.
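
For reference, a minimal sketch of this batch tiling, assuming the raw images and labels sit in raw/images and raw/labels with matching file names (both directory names are hypothetical):

# preprocess.py -- a minimal sketch; raw/ directory layout and matching file names are assumptions
import os
import PIL.Image as Image

def crop_tiles(img_dir, mask_dir, out_dir, tile=512):
    os.makedirs(out_dir, exist_ok=True)
    idx = 0
    for name in sorted(os.listdir(img_dir)):
        img = Image.open(os.path.join(img_dir, name))
        mask = Image.open(os.path.join(mask_dir, name))  # assumes the label shares the file name
        w, h = img.size
        for top in range(0, h - tile + 1, tile):
            for left in range(0, w - tile + 1, tile):
                box = (left, top, left + tile, top + tile)
                # save with the %03d.png / %03d_mask.png naming that make_dataset expects
                img.crop(box).save(os.path.join(out_dir, "%03d.png" % idx))
                mask.crop(box).save(os.path.join(out_dir, "%03d_mask.png" % idx))
                idx += 1

crop_tiles('raw/images', 'raw/labels', r'data_road\train')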

3. Results and Discussion

First, I did not train for many iterations, nor did I use a learning-rate schedule, and the Massachusetts Roads dataset itself has plenty of pitfalls (it needs a lot of preprocessing, such as removing corrupted samples; anyone who uses it will know), so the final segmentation quality is mediocre, just enough to show the network runs end to end. See Figures 3.1, 3.2 and 3.3.


Figure 3.1 Training Loss


Figure 3.2 Mean IoU


Figure 3.3 From left to right: original image, ground truth, extraction result

Discussion: this post briefly ran U-Net on road information extraction from remote sensing imagery. The result meets expectations but is not as strong as hoped, for two reasons: (1) some samples in the original Massachusetts Roads dataset deviate considerably, and errors in the ground truth propagate into training; (2) the loss design is not ideal; see the papers on loss design for segmentation and information extraction in remote sensing, a topic I am still studying myself.
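
As one direction, a common choice in segmentation work is to combine BCE with a Dice term, which directly targets region overlap; a minimal sketch (not the loss used above, just an illustration):

# dice loss -- a minimal sketch of one common alternative, not what was used above
import torch
from torch import nn

class BCEDiceLoss(nn.Module):
    def __init__(self, smooth=1.0):
        super(BCEDiceLoss, self).__init__()
        self.bce = nn.BCEWithLogitsLoss()
        self.smooth = smooth

    def forward(self, logits, targets):
        probs = torch.sigmoid(logits)
        inter = (probs * targets).sum()
        # Dice coefficient measures overlap; 1 - dice becomes the penalty term
        dice = (2 * inter + self.smooth) / (probs.sum() + targets.sum() + self.smooth)
        return self.bce(logits, targets) + (1 - dice)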

Comments and discussion are welcome.