U-Net Keras: A Complete Training Workflow

Preface

U-Net is a classic network design that is widely used in image segmentation tasks. Many newer methods build on it and fold in more recent design ideas, and it often achieves good results even on small datasets.

Articles on the U-Net family

For an overview of the U-Net family of models, see: https://zhuanlan.zhihu.com/p/57530767

That article covers U-Net, 3D U-Net, TernausNet, Res-UNet and Dense U-Net, MultiResUNet, R2U-Net, Attention U-Net, and other models; how effective these methods actually are still needs to be verified in follow-up experiments.

Code Implementation

Importing libraries

# coding=utf-8
import matplotlib
matplotlib.use("Agg")  # headless backend so figures can be saved without a display
import matplotlib.pyplot as plt
import numpy as np
from keras.models import Model
from keras.layers import Conv2D, MaxPooling2D, UpSampling2D, Input
from keras.layers.merge import concatenate
from keras.utils.np_utils import to_categorical
from keras.preprocessing.image import img_to_array
from keras.callbacks import ModelCheckpoint
from sklearn.preprocessing import LabelEncoder
from PIL import Image
import cv2
import random
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # pin training to GPU 0
seed = 7
np.random.seed(seed)
random.seed(seed)  # also seed the random module used by random.shuffle below
           

Loading and normalizing images

def load_img(path, grayscale=False):
    if grayscale:
        # label maps: keep the raw integer values (assumed 0/1 class indices), no normalization
        img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    else:
        # source images: scale RGB values to [0, 1]
        img = cv2.imread(path)
        img = np.array(img, dtype="float") / 255.0
    return img
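A quick sanity check helps confirm the ranges and shapes this function returns; the file names below are hypothetical placeholders, so substitute any image/label pair from your own dataset:

# hypothetical sample paths -- replace with real files from your dataset
sample_img = load_img('./unet_train/building/src/0.png')
sample_lbl = load_img('./unet_train/building/label/0.png', grayscale=True)
print(sample_img.shape, sample_img.max())  # e.g. (256, 256, 3), max 1.0 after scaling
print(sample_lbl.shape, sample_lbl.max())  # e.g. (256, 256), raw label values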
           

Splitting the dataset

filepath = './unet_train/building/'

def get_train_val(val_rate=0.25):
    train_url = []
    train_set = []
    val_set = []
    for pic in os.listdir(filepath + 'src'):
        train_url.append(pic)
    random.shuffle(train_url)
    total_num = len(train_url)
    val_num = int(val_rate * total_num)
    for i in range(len(train_url)):
        if i < val_num:
            val_set.append(train_url[i])
        else:
            train_set.append(train_url[i])
    return train_set, val_set

train_set, val_set = get_train_val()
# check the sizes of the resulting splits
len(train_set), len(val_set)
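As an aside, since scikit-learn is already a dependency, the same 75/25 split can be done in one line with train_test_split; this is just an alternative sketch (not the code used below), with random_state pinned for reproducibility:

from sklearn.model_selection import train_test_split

files = os.listdir(filepath + 'src')
train_set_alt, val_set_alt = train_test_split(files, test_size=0.25, random_state=seed)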
           

Batch data generators

# data generator for training
def generateData(batch_size, data=[]):
    while True:
        train_data = []
        train_label = []
        batch = 0
        for i in range(len(data)):
            url = data[i]
            batch += 1
            img = load_img(filepath + 'src/' + url)
            img = img_to_array(img)
            train_data.append(img)
            label = load_img(filepath + 'label/' + url, grayscale=True)
            label = img_to_array(label)
            train_label.append(label)
            if batch % batch_size == 0:
                train_data = np.array(train_data)
                train_label = np.array(train_label)
                yield (train_data, train_label)
                train_data = []
                train_label = []
                batch = 0

# data generator for validation
def generateValidData(batch_size, data=[]):
    while True:
        valid_data = []
        valid_label = []
        batch = 0
        for i in range(len(data)):
            url = data[i]
            batch += 1
            img = load_img(filepath + 'src/' + url)
            img = img_to_array(img)
            valid_data.append(img)
            label = load_img(filepath + 'label/' + url, grayscale=True)
            label = img_to_array(label)
            valid_label.append(label)
            if batch % batch_size == 0:
                valid_data = np.array(valid_data)
                valid_label = np.array(valid_label)
                yield (valid_data, valid_label)
                valid_data = []
                valid_label = []
                batch = 0
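Before training, it is worth pulling a single batch to check that the generator yields what the model expects; a minimal sanity check, assuming the dataset above is in place:

# images should come out as (batch, 256, 256, 3) and labels as (batch, 256, 256, 1)
x_batch, y_batch = next(generateData(4, train_set))
print(x_batch.shape, y_batch.shape)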
           

Defining the model architecture

(Figure: U-Net architecture diagram; original image: https://i.loli.net/2021/03/18/8koDuElSyvxc7Hd.png)

# build the network with the Keras functional API
img_w = 256
img_h = 256

def unet():
    inputs = Input((img_w, img_h, 3))

    # contracting path: two 3x3 convolutions, then 2x2 max pooling
    conv1 = Conv2D(32, (3, 3), activation="relu", padding="same")(inputs)
    conv1 = Conv2D(32, (3, 3), activation="relu", padding="same")(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)

    conv2 = Conv2D(64, (3, 3), activation="relu", padding="same")(pool1)
    conv2 = Conv2D(64, (3, 3), activation="relu", padding="same")(conv2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)

    conv3 = Conv2D(128, (3, 3), activation="relu", padding="same")(pool2)
    conv3 = Conv2D(128, (3, 3), activation="relu", padding="same")(conv3)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)

    conv4 = Conv2D(256, (3, 3), activation="relu", padding="same")(pool3)
    conv4 = Conv2D(256, (3, 3), activation="relu", padding="same")(conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)

    # bottleneck: convolutions only, no pooling
    conv5 = Conv2D(512, (3, 3), activation="relu", padding="same")(pool4)
    conv5 = Conv2D(512, (3, 3), activation="relu", padding="same")(conv5)

    # expanding path: upsample and concatenate the matching skip connection
    up6 = concatenate([UpSampling2D(size=(2, 2))(conv5), conv4], axis=-1)
    conv6 = Conv2D(256, (3, 3), activation="relu", padding="same")(up6)
    conv6 = Conv2D(256, (3, 3), activation="relu", padding="same")(conv6)

    up7 = concatenate([UpSampling2D(size=(2, 2))(conv6), conv3], axis=-1)
    conv7 = Conv2D(128, (3, 3), activation="relu", padding="same")(up7)
    conv7 = Conv2D(128, (3, 3), activation="relu", padding="same")(conv7)

    up8 = concatenate([UpSampling2D(size=(2, 2))(conv7), conv2], axis=-1)
    conv8 = Conv2D(64, (3, 3), activation="relu", padding="same")(up8)
    conv8 = Conv2D(64, (3, 3), activation="relu", padding="same")(conv8)

    up9 = concatenate([UpSampling2D(size=(2, 2))(conv8), conv1], axis=-1)
    conv9 = Conv2D(32, (3, 3), activation="relu", padding="same")(up9)
    conv9 = Conv2D(32, (3, 3), activation="relu", padding="same")(conv9)

    # binary segmentation, so a 1x1 convolution with a sigmoid output;
    # a multi-class model would use softmax instead
    conv10 = Conv2D(1, (1, 1), activation="sigmoid")(conv9)
    #conv10 = Conv2D(n_label, (1, 1), activation="softmax")(conv9)

    model = Model(inputs=inputs, outputs=conv10)
    model.compile(optimizer='Adam', loss='binary_crossentropy', metrics=['accuracy'])
    return model

unet_model = unet()
unet_model.summary()
           
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            (None, 256, 256, 3)  0                                            
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 256, 256, 32) 896         input_1[0][0]                    
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 256, 256, 32) 9248        conv2d_1[0][0]                   
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D)  (None, 128, 128, 32) 0           conv2d_2[0][0]                   
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, 128, 128, 64) 18496       max_pooling2d_1[0][0]            
__________________________________________________________________________________________________
conv2d_4 (Conv2D)               (None, 128, 128, 64) 36928       conv2d_3[0][0]                   
__________________________________________________________________________________________________
max_pooling2d_2 (MaxPooling2D)  (None, 64, 64, 64)   0           conv2d_4[0][0]                   
__________________________________________________________________________________________________
conv2d_5 (Conv2D)               (None, 64, 64, 128)  73856       max_pooling2d_2[0][0]            
__________________________________________________________________________________________________
conv2d_6 (Conv2D)               (None, 64, 64, 128)  147584      conv2d_5[0][0]                   
__________________________________________________________________________________________________
max_pooling2d_3 (MaxPooling2D)  (None, 32, 32, 128)  0           conv2d_6[0][0]                   
__________________________________________________________________________________________________
conv2d_7 (Conv2D)               (None, 32, 32, 256)  295168      max_pooling2d_3[0][0]            
__________________________________________________________________________________________________
conv2d_8 (Conv2D)               (None, 32, 32, 256)  590080      conv2d_7[0][0]                   
__________________________________________________________________________________________________
max_pooling2d_4 (MaxPooling2D)  (None, 16, 16, 256)  0           conv2d_8[0][0]                   
__________________________________________________________________________________________________
conv2d_9 (Conv2D)               (None, 16, 16, 512)  1180160     max_pooling2d_4[0][0]            
__________________________________________________________________________________________________
conv2d_10 (Conv2D)              (None, 16, 16, 512)  2359808     conv2d_9[0][0]                   
__________________________________________________________________________________________________
up_sampling2d_1 (UpSampling2D)  (None, 32, 32, 512)  0           conv2d_10[0][0]                  
__________________________________________________________________________________________________
concatenate_1 (Concatenate)     (None, 32, 32, 768)  0           up_sampling2d_1[0][0]            
                                                                 conv2d_8[0][0]                   
__________________________________________________________________________________________________
conv2d_11 (Conv2D)              (None, 32, 32, 256)  1769728     concatenate_1[0][0]              
__________________________________________________________________________________________________
conv2d_12 (Conv2D)              (None, 32, 32, 256)  590080      conv2d_11[0][0]                  
__________________________________________________________________________________________________
up_sampling2d_2 (UpSampling2D)  (None, 64, 64, 256)  0           conv2d_12[0][0]                  
__________________________________________________________________________________________________
concatenate_2 (Concatenate)     (None, 64, 64, 384)  0           up_sampling2d_2[0][0]            
                                                                 conv2d_6[0][0]                   
__________________________________________________________________________________________________
conv2d_13 (Conv2D)              (None, 64, 64, 128)  442496      concatenate_2[0][0]              
__________________________________________________________________________________________________
conv2d_14 (Conv2D)              (None, 64, 64, 128)  147584      conv2d_13[0][0]                  
__________________________________________________________________________________________________
up_sampling2d_3 (UpSampling2D)  (None, 128, 128, 128 0           conv2d_14[0][0]                  
__________________________________________________________________________________________________
concatenate_3 (Concatenate)     (None, 128, 128, 192 0           up_sampling2d_3[0][0]            
                                                                 conv2d_4[0][0]                   
__________________________________________________________________________________________________
conv2d_15 (Conv2D)              (None, 128, 128, 64) 110656      concatenate_3[0][0]              
__________________________________________________________________________________________________
conv2d_16 (Conv2D)              (None, 128, 128, 64) 36928       conv2d_15[0][0]                  
__________________________________________________________________________________________________
up_sampling2d_4 (UpSampling2D)  (None, 256, 256, 64) 0           conv2d_16[0][0]                  
__________________________________________________________________________________________________
concatenate_4 (Concatenate)     (None, 256, 256, 96) 0           up_sampling2d_4[0][0]            
                                                                 conv2d_2[0][0]                   
__________________________________________________________________________________________________
conv2d_17 (Conv2D)              (None, 256, 256, 32) 27680       concatenate_4[0][0]              
__________________________________________________________________________________________________
conv2d_18 (Conv2D)              (None, 256, 256, 32) 9248        conv2d_17[0][0]                  
__________________________________________________________________________________________________
conv2d_19 (Conv2D)              (None, 256, 256, 1)  33          conv2d_18[0][0]                  
==================================================================================================
Total params: 7,846,657
Trainable params: 7,846,657
Non-trainable params: 0
__________________________________________________________________________________________________
           

Setting the training parameters

EPOCHS = 10
BS = 16
img_w = 256
img_h = 256
# for the multi-class variant there would be 4 classes plus one background class
# n_label = 4 + 1
n_label = 1

classes = [0., 1., 2., 3., 4.]

labelencoder = LabelEncoder()
labelencoder.fit(classes)

# checkpoint: keep only the weights with the best validation accuracy
modelcheck = ModelCheckpoint("unet_buildings.h5", monitor='val_acc', save_best_only=True, mode='max')
callbacks_list = [modelcheck]
# split the dataset
train_set, val_set = get_train_val()
train_numb = len(train_set)
valid_numb = len(val_set)
print("the number of train data is", train_numb)
print("the number of val data is", valid_numb)
           
the number of train data is 7500
the number of val data is 2500
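Note that labelencoder and to_categorical go unused in this binary run; they only matter for the multi-class variant (n_label > 1 with a softmax output). As a sketch of how a label map would be one-hot encoded in that case, here is a hypothetical helper that is not part of the original pipeline:

# multi-class only: map raw label values to class indices, then one-hot encode
def encode_labels(label_img, n_classes):
    flat = labelencoder.transform(label_img.ravel())      # raw values -> indices 0..n-1
    onehot = to_categorical(flat, num_classes=n_classes)  # shape (H*W, n_classes)
    return onehot.reshape(label_img.shape[0], label_img.shape[1], n_classes)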
           
# train the model
H = unet_model.fit_generator(generator=generateData(BS, train_set), steps_per_epoch=train_numb // BS,
                             epochs=EPOCHS, verbose=1,
                             validation_data=generateValidData(BS, val_set),
                             validation_steps=valid_numb // BS,
                             callbacks=callbacks_list, max_queue_size=1)
           

Epoch 1/10
468/468 [==============================] - 119s 254ms/step - loss: 0.4980 - acc: 0.7637 - val_loss: 0.4918 - val_acc: 0.7745
Epoch 2/10
468/468 [==============================] - 90s 193ms/step - loss: 0.4643 - acc: 0.7738 - val_loss: 0.4176 - val_acc: 0.7979
Epoch 3/10
468/468 [==============================] - 90s 193ms/step - loss: 0.4255 - acc: 0.7914 - val_loss: 0.4193 - val_acc: 0.7959
Epoch 4/10
468/468 [==============================] - 90s 192ms/step - loss: 0.4138 - acc: 0.7993 - val_loss: 0.4452 - val_acc: 0.7958
Epoch 5/10
468/468 [==============================] - 90s 193ms/step - loss: 0.4028 - acc: 0.8057 - val_loss: 0.3850 - val_acc: 0.8144
Epoch 6/10
468/468 [==============================] - 90s 193ms/step - loss: 0.3840 - acc: 0.8135 - val_loss: 0.3695 - val_acc: 0.8237
Epoch 7/10
468/468 [==============================] - 90s 193ms/step - loss: 0.3781 - acc: 0.8199 - val_loss: 0.3870 - val_acc: 0.8212
Epoch 8/10
468/468 [==============================] - 90s 193ms/step - loss: 0.3666 - acc: 0.8245 - val_loss: 0.3726 - val_acc: 0.8256
Epoch 9/10
468/468 [==============================] - 90s 193ms/step - loss: 0.3524 - acc: 0.8321 - val_loss: 0.3526 - val_acc: 0.8333
Epoch 10/10
468/468 [==============================] - 90s 193ms/step - loss: 0.3412 - acc: 0.8386 - val_loss: 0.3347 - val_acc: 0.8431
           

Plotting the training and validation loss and accuracy

# plot the training loss and accuracy
plt.style.use("ggplot")
plt.figure()
N = EPOCHS
plt.plot(np.arange(0, N), H.history["loss"], label="train_loss")
plt.plot(np.arange(0, N), H.history["val_loss"], label="val_loss")
plt.plot(np.arange(0, N), H.history["acc"], label="train_acc")
plt.plot(np.arange(0, N), H.history["val_acc"], label="val_acc")
plt.title("Training Loss and Accuracy on U-Net Satellite Seg")
plt.xlabel("Epoch #")
plt.ylabel("Loss/Accuracy")
plt.legend(loc="lower left")
# save the figure before calling show(), which may clear the current figure
plt.savefig("unet_building.png")
plt.show()

(Figure: training and validation loss/accuracy curves)
           

Model Prediction

from keras.models import load_model
# names and directory of the test images
TEST_SET = ['1.png', '2.png', '3.png']
image_dir = "./test_image/"
           
# load the trained model (the filename must match your saved checkpoint)
print("[INFO] loading network...")
model = load_model("unet_buildings20.h5")
# tile size and sliding-window stride
image_size = 256
stride = 256

for n in range(len(TEST_SET)):
    path = TEST_SET[n]
    # load the image
    image = cv2.imread(image_dir + path)
    h, w, _ = image.shape

    # pad the image so both sides become multiples of 256
    padding_h = (h // stride + 1) * stride
    padding_w = (w // stride + 1) * stride
    padding_img = np.zeros((padding_h, padding_w, 3), dtype=np.uint8)
    padding_img[0:h, 0:w, :] = image[:, :, :]
    # normalize to [0, 1], matching the training preprocessing
    padding_img = padding_img.astype("float") / 255.0
    padding_img = img_to_array(padding_img)

    print('src:', padding_img.shape)
    mask_whole = np.zeros((padding_h, padding_w), dtype=np.uint8)
    print(padding_h, padding_w)
    for i in range(padding_h // stride):
        for j in range(padding_w // stride):
            crop = padding_img[i*stride:(i*stride + image_size), j*stride:(j*stride + image_size), :3]
            ch, cw, _ = crop.shape
            if ch != 256 or cw != 256:
                print('invalid size!')
                continue

            # predict on the current tile
            crop = np.expand_dims(crop, axis=0)
            pred = model.predict(crop, verbose=2)
            # threshold the sigmoid probabilities at 0.5 to get a binary mask
            pred = (pred.reshape((256, 256)) > 0.5).astype(np.uint8)
            mask_whole[i*stride:i*stride + image_size, j*stride:j*stride + image_size] = pred[:, :]

    cv2.imwrite('./predict/20pre' + str(n + 1) + '.png', mask_whole[0:h, 0:w])
    print("Prediction for image %d finished" % (n + 1))
    visualize = mask_whole * 255
    # stack (R, G, B) = (mask, mask, 0) so buildings render yellow (255, 255, 0)
    result = np.array([visualize, visualize, np.zeros((padding_h, padding_w), dtype=np.uint8)])
    result = result.transpose(1, 2, 0)
    print(result.shape)

    # convert the (h, w, 3) array back into an image and save it
    img = Image.fromarray(result).convert('RGB')
    img.save('./predict/20pre' + str(n + 1) + '.tif')
           
(Figure: example prediction result)

Reflections and Summary

Feature extraction

The U-Net model makes full use of image features at different levels, which gives it strong learning and representation ability, but it relies on multi-stage cascades of convolutions: these cascaded stages extract regions of interest and then make dense predictions. Repeatedly extracting the same low-level features this way leads to excessive, redundant use of computational resources.

Well suited to small-object segmentation

U-Net-style networks have consistently performed well on small-object segmentation tasks. My guess is that this is precisely thanks to that repeated computation of low-level features (color, texture, and so on): much like an attention mechanism, the network pays more attention to low-level features, which helps it pick up information about small objects.

In remote-sensing imagery, most target land-cover classes are small objects, and perhaps that is why U-Net has been so popular in Kaggle remote-sensing segmentation competitions.

Poor boundary accuracy

Because of disk and memory limits, I trained on only 10,000 images, so the accuracy is not very high; locally the results look passable. Let's just say: if you know, you know!

Moreover, because the U-Net model loses some precision in its pooling layers, the building boundaries here come out blurry, or rather over-smoothed. A senior labmate suggested that an FPN or other regularization methods might help; perhaps so.
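One standard remedy is the overlap-tile strategy from the original U-Net paper: predict on overlapping tiles and average the probabilities where tiles overlap, so each pixel is classified with more surrounding context and the seams between tiles soften. Below is a minimal sketch under the same setup as the prediction loop above; the helper name predict_overlap and the stride of 128 are my own choices, not part of the original code.

def predict_overlap(model, padding_img, image_size=256, stride=128):
    # padding_img: normalized float image whose sides are multiples of image_size
    h, w, _ = padding_img.shape
    prob = np.zeros((h, w), dtype=np.float32)  # accumulated sigmoid probabilities
    cnt = np.zeros((h, w), dtype=np.float32)   # number of tiles covering each pixel
    for i in range(0, h - image_size + 1, stride):
        for j in range(0, w - image_size + 1, stride):
            crop = padding_img[i:i + image_size, j:j + image_size, :]
            pred = model.predict(np.expand_dims(crop, axis=0))[0, :, :, 0]
            prob[i:i + image_size, j:j + image_size] += pred
            cnt[i:i + image_size, j:j + image_size] += 1
    prob /= np.maximum(cnt, 1)                 # average overlapping predictions
    return (prob > 0.5).astype(np.uint8)       # binarize at 0.5 as before

Mirror-padding the borders before tiling, as in the original paper, would further reduce artifacts at the image edges.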
