
Unet-Keras: A Complete Training Workflow

Preface

U-Net is a classic network design that is widely used in image segmentation tasks. Many newer methods build on it, incorporating more recent design ideas, and they often achieve good results even on small datasets.

Articles on the U-Net Family

For an introduction to the U-Net family of models, see: https://zhuanlan.zhihu.com/p/57530767

That article covers U-Net, 3D U-Net, TernausNet, Res-UNet and Dense U-Net, MultiResUNet, R2U-Net, Attention U-Net, and other models. As for how effective these methods really are, that still needs to be verified in follow-up experiments.

Code Implementation

Importing Libraries

#coding=utf-8
import matplotlib
matplotlib.use("Agg")  # non-interactive backend: render figures to files
import matplotlib.pyplot as plt
import numpy as np
from keras.models import Model
from keras.layers import Conv2D, MaxPooling2D, UpSampling2D, Input
from keras.layers.merge import concatenate
from keras.preprocessing.image import img_to_array
from keras.callbacks import ModelCheckpoint
from sklearn.preprocessing import LabelEncoder
from PIL import Image
import cv2
import random
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # train on GPU 0
seed = 7
np.random.seed(seed)
random.seed(seed)  # the train/val split below uses random.shuffle, so seed Python's random as well
           

Loading and Normalizing Images

def load_img(path, grayscale=False):
    if grayscale:
        # label masks are read as-is (assumed to already contain class values, e.g. 0/1)
        img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    else:
        # source images are scaled to [0, 1]
        img = cv2.imread(path)
        img = np.array(img, dtype="float") / 255.0
    return img
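
As a quick sanity check (a minimal sketch, not in the original post; '0.png' is a hypothetical filename, so substitute any file that actually exists under the dataset directories used below), we can verify that an image/label pair loads with the expected shapes and value ranges:

sample = '0.png'
img = load_img('./unet_train/building/src/' + sample)
lbl = load_img('./unet_train/building/label/' + sample, grayscale=True)
print(img.shape, img.min(), img.max())   # expected: (256, 256, 3) 0.0 1.0
print(lbl.shape, np.unique(lbl))         # expected: (256, 256) with values like [0 1]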
           

Splitting the Dataset

filepath ='./unet_train/building/'  

def get_train_val(val_rate = 0.25):
    train_url = []    
    train_set = []
    val_set  = []
    for pic in os.listdir(filepath + 'src'):
        train_url.append(pic)
    random.shuffle(train_url)
    total_num = len(train_url)
    val_num = int(val_rate * total_num)
    for i in range(len(train_url)):
        if i < val_num:
            val_set.append(train_url[i]) 
        else:
            train_set.append(train_url[i])
    return train_set,val_set

train_set,val_set = get_train_val()
# check the sizes of the two splits
len(train_set), len(val_set)
           

Generating Batches of Data

# data generator for training
def generateData(batch_size, data):
    while True:
        train_data = []
        train_label = []
        batch = 0
        for i in range(len(data)):
            url = data[i]
            batch += 1
            img = load_img(filepath + 'src/' + url)
            img = img_to_array(img)
            train_data.append(img)
            label = load_img(filepath + 'label/' + url, grayscale=True)
            label = img_to_array(label)
            train_label.append(label)
            if batch % batch_size == 0:
                # collected a full batch
                train_data = np.array(train_data)
                train_label = np.array(train_label)
                yield (train_data, train_label)
                train_data = []
                train_label = []
                batch = 0
           
# data generator for validation
def generateValidData(batch_size, data):
    while True:
        valid_data = []
        valid_label = []
        batch = 0
        for i in range(len(data)):
            url = data[i]
            batch += 1
            img = load_img(filepath + 'src/' + url)
            img = img_to_array(img)
            valid_data.append(img)
            label = load_img(filepath + 'label/' + url, grayscale=True)
            label = img_to_array(label)
            valid_label.append(label)
            if batch % batch_size == 0:
                valid_data = np.array(valid_data)
                valid_label = np.array(valid_label)
                yield (valid_data, valid_label)
                valid_data = []
                valid_label = []
                batch = 0
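
One improvement worth considering (my own sketch, not part of the original code): the generators above iterate over the samples in a fixed order, so every epoch sees identical batches. Reshuffling the list at the start of each pass usually helps training:

def generateDataShuffled(batch_size, data):
    # same as generateData, but reshuffles the sample order on every pass
    while True:
        random.shuffle(data)
        train_data, train_label = [], []
        for url in data:
            train_data.append(img_to_array(load_img(filepath + 'src/' + url)))
            train_label.append(img_to_array(load_img(filepath + 'label/' + url, grayscale=True)))
            if len(train_data) == batch_size:
                yield (np.array(train_data), np.array(train_label))
                train_data, train_label = [], []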
           

Defining the Model Architecture

(Figure: U-Net architecture diagram; original image: https://i.loli.net/2021/03/18/8koDuElSyvxc7Hd.png)

# Define a model with the Keras functional API
img_w = 256
img_h = 256
def unet():
    inputs = Input((img_w, img_h, 3))
    
    # convolution + pooling (encoder)
    conv1 = Conv2D(32, (3, 3), activation="relu", padding="same")(inputs)
    conv1 = Conv2D(32, (3, 3), activation="relu", padding="same")(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)

    conv2 = Conv2D(64, (3, 3), activation="relu", padding="same")(pool1)
    conv2 = Conv2D(64, (3, 3), activation="relu", padding="same")(conv2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)

    conv3 = Conv2D(128, (3, 3), activation="relu", padding="same")(pool2)
    conv3 = Conv2D(128, (3, 3), activation="relu", padding="same")(conv3)
    pool3 = MaxPooling2D(pool_size=(2, 2), )(conv3)

    conv4 = Conv2D(256, (3, 3), activation="relu", padding="same")(pool3)
    conv4 = Conv2D(256, (3, 3), activation="relu", padding="same")(conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)

    # bottleneck: convolution only, no pooling
    conv5 = Conv2D(512, (3, 3), activation="relu", padding="same")(pool4)
    conv5 = Conv2D(512, (3, 3), activation="relu", padding="same")(conv5)
    
    # start of the upsampling (decoder) path
    up6 = concatenate([UpSampling2D(size=(2, 2))(conv5), conv4], axis=-1)
    conv6 = Conv2D(256, (3, 3), activation="relu", padding="same")(up6)
    conv6 = Conv2D(256, (3, 3), activation="relu", padding="same")(conv6)
    
    up7 = concatenate([UpSampling2D(size=(2, 2))(conv6), conv3], axis=-1)
    conv7 = Conv2D(128, (3, 3), activation="relu", padding="same")(up7)
    conv7 = Conv2D(128, (3, 3), activation="relu", padding="same")(conv7)

    up8 = concatenate([UpSampling2D(size=(2, 2))(conv7), conv2], axis=-1)
    conv8 = Conv2D(64, (3, 3), activation="relu", padding="same")(up8)
    conv8 = Conv2D(64, (3, 3), activation="relu", padding="same")(conv8)

    up9 = concatenate([UpSampling2D(size=(2, 2))(conv8), conv1], axis=-1)
    conv9 = Conv2D(32, (3, 3), activation="relu", padding="same")(up9)
    conv9 = Conv2D(32, (3, 3), activation="relu", padding="same")(conv9)
    
    # We are training a binary segmentation model, so we use a sigmoid activation; a multi-class model would use softmax instead
    conv10 = Conv2D(1, (1, 1), activation="sigmoid")(conv9)
    #conv10 = Conv2D(n_label, (1, 1), activation="softmax")(conv9)

    model = Model(inputs=inputs, outputs=conv10)
    model.compile(optimizer='Adam', loss='binary_crossentropy', metrics=['accuracy'])
    return model

unet_model = unet()
unet_model.summary()
           
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            (None, 256, 256, 3)  0                                            
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 256, 256, 32) 896         input_1[0][0]                    
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 256, 256, 32) 9248        conv2d_1[0][0]                   
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D)  (None, 128, 128, 32) 0           conv2d_2[0][0]                   
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, 128, 128, 64) 18496       max_pooling2d_1[0][0]            
__________________________________________________________________________________________________
conv2d_4 (Conv2D)               (None, 128, 128, 64) 36928       conv2d_3[0][0]                   
__________________________________________________________________________________________________
max_pooling2d_2 (MaxPooling2D)  (None, 64, 64, 64)   0           conv2d_4[0][0]                   
__________________________________________________________________________________________________
conv2d_5 (Conv2D)               (None, 64, 64, 128)  73856       max_pooling2d_2[0][0]            
__________________________________________________________________________________________________
conv2d_6 (Conv2D)               (None, 64, 64, 128)  147584      conv2d_5[0][0]                   
__________________________________________________________________________________________________
max_pooling2d_3 (MaxPooling2D)  (None, 32, 32, 128)  0           conv2d_6[0][0]                   
__________________________________________________________________________________________________
conv2d_7 (Conv2D)               (None, 32, 32, 256)  295168      max_pooling2d_3[0][0]            
__________________________________________________________________________________________________
conv2d_8 (Conv2D)               (None, 32, 32, 256)  590080      conv2d_7[0][0]                   
__________________________________________________________________________________________________
max_pooling2d_4 (MaxPooling2D)  (None, 16, 16, 256)  0           conv2d_8[0][0]                   
__________________________________________________________________________________________________
conv2d_9 (Conv2D)               (None, 16, 16, 512)  1180160     max_pooling2d_4[0][0]            
__________________________________________________________________________________________________
conv2d_10 (Conv2D)              (None, 16, 16, 512)  2359808     conv2d_9[0][0]                   
__________________________________________________________________________________________________
up_sampling2d_1 (UpSampling2D)  (None, 32, 32, 512)  0           conv2d_10[0][0]                  
__________________________________________________________________________________________________
concatenate_1 (Concatenate)     (None, 32, 32, 768)  0           up_sampling2d_1[0][0]            
                                                                 conv2d_8[0][0]                   
__________________________________________________________________________________________________
conv2d_11 (Conv2D)              (None, 32, 32, 256)  1769728     concatenate_1[0][0]              
__________________________________________________________________________________________________
conv2d_12 (Conv2D)              (None, 32, 32, 256)  590080      conv2d_11[0][0]                  
__________________________________________________________________________________________________
up_sampling2d_2 (UpSampling2D)  (None, 64, 64, 256)  0           conv2d_12[0][0]                  
__________________________________________________________________________________________________
concatenate_2 (Concatenate)     (None, 64, 64, 384)  0           up_sampling2d_2[0][0]            
                                                                 conv2d_6[0][0]                   
__________________________________________________________________________________________________
conv2d_13 (Conv2D)              (None, 64, 64, 128)  442496      concatenate_2[0][0]              
__________________________________________________________________________________________________
conv2d_14 (Conv2D)              (None, 64, 64, 128)  147584      conv2d_13[0][0]                  
__________________________________________________________________________________________________
up_sampling2d_3 (UpSampling2D)  (None, 128, 128, 128 0           conv2d_14[0][0]                  
__________________________________________________________________________________________________
concatenate_3 (Concatenate)     (None, 128, 128, 192 0           up_sampling2d_3[0][0]            
                                                                 conv2d_4[0][0]                   
__________________________________________________________________________________________________
conv2d_15 (Conv2D)              (None, 128, 128, 64) 110656      concatenate_3[0][0]              
__________________________________________________________________________________________________
conv2d_16 (Conv2D)              (None, 128, 128, 64) 36928       conv2d_15[0][0]                  
__________________________________________________________________________________________________
up_sampling2d_4 (UpSampling2D)  (None, 256, 256, 64) 0           conv2d_16[0][0]                  
__________________________________________________________________________________________________
concatenate_4 (Concatenate)     (None, 256, 256, 96) 0           up_sampling2d_4[0][0]            
                                                                 conv2d_2[0][0]                   
__________________________________________________________________________________________________
conv2d_17 (Conv2D)              (None, 256, 256, 32) 27680       concatenate_4[0][0]              
__________________________________________________________________________________________________
conv2d_18 (Conv2D)              (None, 256, 256, 32) 9248        conv2d_17[0][0]                  
__________________________________________________________________________________________________
conv2d_19 (Conv2D)              (None, 256, 256, 1)  33          conv2d_18[0][0]                  
==================================================================================================
Total params: 7,846,657
Trainable params: 7,846,657
Non-trainable params: 0
__________________________________________________________________________________________________
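
The summary confirms the architecture; as an extra check (not in the original post), a forward pass on a dummy batch verifies the input/output shapes:

dummy = np.zeros((1, 256, 256, 3), dtype="float32")
print(unet_model.predict(dummy).shape)  # expected: (1, 256, 256, 1)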
           

Defining Training Parameters

EPOCHS = 10
BS = 16
img_w = 256
img_h = 256
# the multi-class variant would need one extra class for the background
#n_label = 4+1
n_label = 1

# left over from the multi-class setup; unused in the binary case
classes = [0., 1., 2., 3., 4.]
labelencoder = LabelEncoder()
labelencoder.fit(classes)

# save the best model (by validation accuracy) to disk
modelcheck = ModelCheckpoint("unet_buildings.h5", monitor='val_acc', save_best_only=True, mode='max')
callback_list = [modelcheck]  # renamed from `callable`, which shadows the Python builtin
# split the dataset
train_set, val_set = get_train_val()
train_numb = len(train_set)
valid_numb = len(val_set)
print("the number of train data is", train_numb)
print("the number of val data is", valid_numb)
           
the number of train data is 7500
the number of val data is 2500
           
# train the model
H = unet_model.fit_generator(generator=generateData(BS, train_set), steps_per_epoch=train_numb//BS,
                epochs=EPOCHS, verbose=1, validation_data=generateValidData(BS, val_set),
                validation_steps=valid_numb//BS, callbacks=callback_list, max_queue_size=1)
           

Epoch 1/10
468/468 [==============================] - 119s 254ms/step - loss: 0.4980 - acc: 0.7637 - val_loss: 0.4918 - val_acc: 0.7745
Epoch 2/10
468/468 [==============================] - 90s 193ms/step - loss: 0.4643 - acc: 0.7738 - val_loss: 0.4176 - val_acc: 0.7979
Epoch 3/10
468/468 [==============================] - 90s 193ms/step - loss: 0.4255 - acc: 0.7914 - val_loss: 0.4193 - val_acc: 0.7959
Epoch 4/10
468/468 [==============================] - 90s 192ms/step - loss: 0.4138 - acc: 0.7993 - val_loss: 0.4452 - val_acc: 0.7958
Epoch 5/10
468/468 [==============================] - 90s 193ms/step - loss: 0.4028 - acc: 0.8057 - val_loss: 0.3850 - val_acc: 0.8144
Epoch 6/10
468/468 [==============================] - 90s 193ms/step - loss: 0.3840 - acc: 0.8135 - val_loss: 0.3695 - val_acc: 0.8237
Epoch 7/10
468/468 [==============================] - 90s 193ms/step - loss: 0.3781 - acc: 0.8199 - val_loss: 0.3870 - val_acc: 0.8212
Epoch 8/10
468/468 [==============================] - 90s 193ms/step - loss: 0.3666 - acc: 0.8245 - val_loss: 0.3726 - val_acc: 0.8256
Epoch 9/10
468/468 [==============================] - 90s 193ms/step - loss: 0.3524 - acc: 0.8321 - val_loss: 0.3526 - val_acc: 0.8333
Epoch 10/10
468/468 [==============================] - 90s 193ms/step - loss: 0.3412 - acc: 0.8386 - val_loss: 0.3347 - val_acc: 0.8431
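
As an aside, fit_generator is deprecated in TensorFlow 2-era Keras; there the same training run would be written with fit, which accepts generators directly (a sketch, assuming a TF2-style Keras install):

H = unet_model.fit(generateData(BS, train_set), steps_per_epoch=train_numb//BS,
                   epochs=EPOCHS, verbose=1,
                   validation_data=generateValidData(BS, val_set),
                   validation_steps=valid_numb//BS, callbacks=callback_list)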
           

Plotting Training and Validation Accuracy and Loss

# plot the training loss and accuracy
plt.style.use("ggplot")
plt.figure()
N = EPOCHS
plt.plot(np.arange(0, N), H.history["loss"], label="train_loss")
plt.plot(np.arange(0, N), H.history["val_loss"], label="val_loss")
plt.plot(np.arange(0, N), H.history["acc"], label="train_acc")
plt.plot(np.arange(0, N), H.history["val_acc"], label="val_acc")
plt.title("Training Loss and Accuracy on U-Net Satellite Seg")
plt.xlabel("Epoch #")
plt.ylabel("Loss/Accuracy")
plt.legend(loc="lower left")
# save the figure before calling show(); with the Agg backend, show() does not display anything
plt.savefig("unet_building.png")
plt.show()

(Figure: training and validation loss/accuracy curves)
           

Model Prediction

from keras.models import load_model
# test image filenames and directory
TEST_SET = ['1.png','2.png', '3.png']
image_dir = "./test_image/"
           
# load the trained model
print("[INFO] loading network...")
model = load_model("unet_buildings.h5")  # must match the filename passed to ModelCheckpoint above
# tile size and sliding-window stride
image_size = 256
stride = 256

for n in range(len(TEST_SET)):
    path = TEST_SET[n]
    #load the image
    image = cv2.imread(image_dir + path)
    h,w,_ = image.shape
    
    # pad the image so that its height and width are multiples of 256
    # (note: this always adds a full extra tile of padding, even when the size already divides evenly)
    padding_h = (h//stride + 1) * stride
    padding_w = (w//stride + 1) * stride
    padding_img = np.zeros((padding_h, padding_w, 3), dtype=np.uint8)
    padding_img[0:h, 0:w, :] = image[:, :, :]
    # normalize exactly as during training; leaving this out badly skews the predictions
    padding_img = padding_img.astype("float") / 255.0
    padding_img = img_to_array(padding_img)
    
    print('src:',padding_img.shape)
    mask_whole = np.zeros((padding_h, padding_w),dtype=np.uint8)
    print(padding_h, padding_w)
    for i in range(padding_h//stride):
        for j in range(padding_w//stride):
            crop = padding_img[i*stride:(i*stride+image_size), j*stride:(j*stride+image_size), :3]
            ch,cw,_ = crop.shape
            if ch != 256 or cw != 256:
                print('invalid size!')
                continue
            
            # predict on each cropped tile
            crop = np.expand_dims(crop, axis=0)
            pred = model.predict(crop, verbose=2)
            # threshold the sigmoid probabilities at 0.5; casting raw probabilities
            # straight to uint8 would truncate everything below 1.0 to 0
            pred = (pred.reshape((256, 256)) > 0.5).astype(np.uint8)
            mask_whole[i*stride:i*stride+image_size, j*stride:j*stride+image_size] = pred[:, :]
    
    cv2.imwrite('./predict/20pre' + str(n+1) + '.png', mask_whole[0:h, 0:w])
    print("Finished predicting image %d" % (n + 1))
    visualize = mask_whole * 255
    # stack the channels as (R, G, B) = (255, 255, 0), i.e. yellow, for predicted pixels
    result = np.array([visualize, visualize, np.zeros((padding_h, padding_w), dtype=np.uint8)])
    result = result.transpose(1, 2, 0)
    print(result.shape)

    # the array now has shape (padding_h, padding_w, 3)
    img = Image.fromarray(result).convert('RGB')  # convert the array back to an image
    img.save('./predict/20pre' + str(n+1) + '.tif')
           
(Figure: predicted building segmentation result)
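
To put a number on how good these masks are (my own sketch, not part of the original post; it assumes a ground-truth mask gt_mask with 0/1 values and the same size as mask_whole[0:h, 0:w]), a simple intersection-over-union check could look like this:

def binary_iou(pred_mask, gt_mask):
    # both masks are expected to contain only 0/1 values
    inter = np.logical_and(pred_mask == 1, gt_mask == 1).sum()
    union = np.logical_or(pred_mask == 1, gt_mask == 1).sum()
    return inter / union if union > 0 else 1.0

# gt_mask = load_img('./test_label/1.png', grayscale=True)  # hypothetical ground-truth path
# print("IoU:", binary_iou(mask_whole[0:h, 0:w], gt_mask))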

Reflections and Summary

Feature Extraction

The U-Net model makes full use of image features at different levels, which gives it strong learning and representation ability, but it relies on multi-stage cascades of convolutional neural networks. These cascaded architectures extract regions of interest and then make dense predictions. Repeatedly extracting the same low-level features in this way leads to excessive, redundant use of computational resources.

Well Suited to Small-Object Segmentation

U-Net-style networks have consistently performed well on small-object segmentation tasks. I suspect this is precisely thanks to their repeated computation over low-level features (color, texture, and so on): much like an attention mechanism, they pay more attention to low-level features, which helps them pick up information about small objects.

In remote sensing imagery, most target land-cover classes are small objects; maybe that is why U-Net is so popular in Kaggle remote sensing segmentation competitions.

Poor Boundary Accuracy

Due to disk and memory constraints, I trained on only 10,000 images, so the accuracy is not very high; the local results look just about passable. You know how it is!

In addition, the U-Net model loses some precision during pooling, so the building boundaries here come out rather blurry, or rather, smoothed. A senior labmate suggested that an FPN or other regularization methods might help with this. Perhaps so.

(Figures: predicted building boundaries, showing the smoothed edges discussed above)
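
One concrete thing to try (my own sketch, not something tested in this post): replace the fixed nearest-neighbour UpSampling2D layers in the decoder with learnable Conv2DTranspose layers, which sometimes produce crisper boundaries:

from keras.layers import Conv2DTranspose

def up_block(x, skip, filters):
    # learnable 2x upsampling, then concatenation with the encoder skip connection
    up = Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding="same")(x)
    return concatenate([up, skip], axis=-1)

# inside unet(), e.g.:  up6 = up_block(conv5, conv4, 256)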
