天天看點

語義分割3__基于Mobile網絡的Unet模型詳解以及訓練自己的Unet模型(劃分斑馬線)

文章目錄

    • 1. 什麼是Unet模型
    • 2. Unet模型的代碼實作
      • 2.1、主幹模型Mobilenet
      • 2.2、Unet的Decoder解碼部分
    • 3. 代碼測試
    • 4. 訓練檔案詳解
    • 5. LOSS函數的組成
    • 6. 訓練代碼
      • 6.1、檔案存放方式
      • 6.2、訓練檔案
      • 6.3、預測檔案
    • 7. 訓練結果

Unet

是一個語義分割模型,其主要執行過程與其它語義分割模型類似,首先利用卷積進行下采樣,然後提取出一層又一層的特征,利用這一層又一層的特征,其再進行上采樣,最後得出一個每個像素點對應其種類的圖像。

看如下這幅圖我們大概可以看出個是以然來:

在進行

Segnet

的詳解的時候我們知道,其隻選了一個h*w壓縮了四次的特征層進行三次上采樣得到最後的結果。

但是

Unet

不一樣,其利用到了壓縮了二、三、四次的特征層,最後輸出圖像分割的結果(可以選擇是否需要壓縮了一次的特征層)。

語義分割3__基于Mobile網絡的Unet模型詳解以及訓練自己的Unet模型(劃分斑馬線)

具體的網絡結構如下,左邊的順序從上向下傳播,右邊的順序從下向上傳播:

語義分割3__基于Mobile網絡的Unet模型詳解以及訓練自己的Unet模型(劃分斑馬線)

其主要的過程就是,将h*w被壓縮了四次的f4進行一次上采樣後與f3進行concatenate,然後再進行一次上采樣與f2進行concatenate,然後再進行一次上采樣(這裡可以選擇是否與f1進行concatenate),最後利用卷積輸出filter為nclasses的圖像。(一共進行三次上采樣)

Unet

模型的代碼分為兩部分。

該部分用于特征提取,實際上就是正常的

mobilenet

結構。

from keras.models import *
from keras.layers import *
import keras.backend as K
import keras

IMAGE_ORDERING = 'channels_last'

def relu6(x):
	return K.relu(x, max_value=6)

def _conv_block(inputs, filters, alpha, kernel=(3, 3), strides=(1, 1)):

	channel_axis = 1 if IMAGE_ORDERING == 'channels_first' else -1
	filters = int(filters * alpha)
	x = ZeroPadding2D(padding=(1, 1), name='conv1_pad', data_format=IMAGE_ORDERING  )(inputs)
	x = Conv2D(filters, kernel , data_format=IMAGE_ORDERING  ,
										padding='valid',
										use_bias=False,
										strides=strides,
										name='conv1')(x)
	x = BatchNormalization(axis=channel_axis, name='conv1_bn')(x)
	return Activation(relu6, name='conv1_relu')(x)

def _depthwise_conv_block(inputs, pointwise_conv_filters, alpha,
													depth_multiplier=1, strides=(1, 1), block_id=1):

	channel_axis = 1 if IMAGE_ORDERING == 'channels_first' else -1
	pointwise_conv_filters = int(pointwise_conv_filters * alpha)

	x = ZeroPadding2D((1, 1) , data_format=IMAGE_ORDERING , name='conv_pad_%d' % block_id)(inputs)
	x = DepthwiseConv2D((3, 3) , data_format=IMAGE_ORDERING ,
														 padding='valid',
														 depth_multiplier=depth_multiplier,
														 strides=strides,
														 use_bias=False,
														 name='conv_dw_%d' % block_id)(x)
	x = BatchNormalization(
			axis=channel_axis, name='conv_dw_%d_bn' % block_id)(x)
	x = Activation(relu6, name='conv_dw_%d_relu' % block_id)(x)

	x = Conv2D(pointwise_conv_filters, (1, 1), data_format=IMAGE_ORDERING ,
										padding='same',
										use_bias=False,
										strides=(1, 1),
										name='conv_pw_%d' % block_id)(x)
	x = BatchNormalization(axis=channel_axis,
																name='conv_pw_%d_bn' % block_id)(x)
	return Activation(relu6, name='conv_pw_%d_relu' % block_id)(x)

def get_mobilenet_encoder( input_height=224 ,  input_width=224 , pretrained='imagenet' ):

	alpha=1.0
	depth_multiplier=1
	dropout=1e-3


	img_input = Input(shape=(input_height,input_width , 3 ))


	x = _conv_block(img_input, 32, alpha, strides=(2, 2))
	x = _depthwise_conv_block(x, 64, alpha, depth_multiplier, block_id=1) 
	f1 = x

	x = _depthwise_conv_block(x, 128, alpha, depth_multiplier,
														strides=(2, 2), block_id=2)  
	x = _depthwise_conv_block(x, 128, alpha, depth_multiplier, block_id=3) 
	f2 = x

	x = _depthwise_conv_block(x, 256, alpha, depth_multiplier,
														strides=(2, 2), block_id=4)  
	x = _depthwise_conv_block(x, 256, alpha, depth_multiplier, block_id=5) 
	f3 = x

	x = _depthwise_conv_block(x, 512, alpha, depth_multiplier,
														strides=(2, 2), block_id=6) 
	x = _depthwise_conv_block(x, 512, alpha, depth_multiplier, block_id=7) 
	x = _depthwise_conv_block(x, 512, alpha, depth_multiplier, block_id=8) 
	x = _depthwise_conv_block(x, 512, alpha, depth_multiplier, block_id=9) 
	x = _depthwise_conv_block(x, 512, alpha, depth_multiplier, block_id=10) 
	x = _depthwise_conv_block(x, 512, alpha, depth_multiplier, block_id=11) 
	f4 = x 

	x = _depthwise_conv_block(x, 1024, alpha, depth_multiplier,
														strides=(2, 2), block_id=12)  
	x = _depthwise_conv_block(x, 1024, alpha, depth_multiplier, block_id=13) 
	f5 = x 

	return img_input , [f1 , f2 , f3 , f4 , f5 ]
           

這一部分對應着上面

Unet

模型中的解碼部分。

其關鍵就是把獲得的特征重新映射到比較大的圖中的每一個像素點,用于每一個像素點的分類。

from keras.models import *
from keras.layers import *
from nets.mobilenet import get_mobilenet_encoder


IMAGE_ORDERING = 'channels_last'
MERGE_AXIS = -1


def _unet( n_classes , encoder , l1_skip_conn=True,  input_height=416, input_width=608  ):

	img_input , levels = encoder( input_height=input_height ,  input_width=input_width )
	[f1 , f2 , f3 , f4 , f5 ] = levels 

	o = f4
	# 26,26,512
	o = ( ZeroPadding2D( (1,1) , data_format=IMAGE_ORDERING ))(o)
	o = ( Conv2D(512, (3, 3), padding='valid', data_format=IMAGE_ORDERING))(o)
	o = ( BatchNormalization())(o)

	# 52,52,512
	o = ( UpSampling2D( (2,2), data_format=IMAGE_ORDERING))(o)
	# 52,52,768
	o = ( concatenate([ o ,f3],axis=MERGE_AXIS )  )
	o = ( ZeroPadding2D( (1,1), data_format=IMAGE_ORDERING))(o)
	# 52,52,256
	o = ( Conv2D( 256, (3, 3), padding='valid', data_format=IMAGE_ORDERING))(o)
	o = ( BatchNormalization())(o)

	# 104,104,256
	o = ( UpSampling2D( (2,2), data_format=IMAGE_ORDERING))(o)
	# 104,104,384
	o = ( concatenate([o,f2],axis=MERGE_AXIS ) )
	o = ( ZeroPadding2D((1,1) , data_format=IMAGE_ORDERING ))(o)
	# 104,104,128
	o = ( Conv2D( 128 , (3, 3), padding='valid' , data_format=IMAGE_ORDERING ) )(o)
	o = ( BatchNormalization())(o)
	# 208,208,128
	o = ( UpSampling2D( (2,2), data_format=IMAGE_ORDERING))(o)
	
	if l1_skip_conn:
		o = ( concatenate([o,f1],axis=MERGE_AXIS ) )

	o = ( ZeroPadding2D((1,1)  , data_format=IMAGE_ORDERING ))(o)
	o = ( Conv2D( 64 , (3, 3), padding='valid'  , data_format=IMAGE_ORDERING ))(o)
	o = ( BatchNormalization())(o)

	o =  Conv2D( n_classes , (3, 3) , padding='same', data_format=IMAGE_ORDERING )( o )
	
	# 将結果進行reshape
	o = Reshape((int(input_height/2)*int(input_width/2), -1))(o)
	o = Softmax()(o)
	model = Model(img_input,o)

	return model


def mobilenet_unet( n_classes ,  input_height=224, input_width=224 , encoder_level=3):

	model =  _unet( n_classes , get_mobilenet_encoder ,  input_height=input_height, input_width=input_width  )
	model.model_name = "mobilenet_unet"
	return model
           

将上面兩個代碼分别儲存為

mobilenet.py

unet.py

。按照如下方式存儲:

語義分割3__基于Mobile網絡的Unet模型詳解以及訓練自己的Unet模型(劃分斑馬線)

此時我們運作

test.py

的代碼:

from nets.unet import mobilenet_unet


model = mobilenet_unet(2,416,416)
model.summary()
           

如果沒有出錯的話就會得到如下的結果:

語義分割3__基于Mobile網絡的Unet模型詳解以及訓練自己的Unet模型(劃分斑馬線)

其模型比

Segnet

稍微大一點。 到這裡就完成了基于

Mobile

模型的

Unet

的搭建。

這個要從訓練檔案講起。語義分割模型訓練的檔案分為兩部分。

第一部分是原圖,像這樣:

語義分割3__基于Mobile網絡的Unet模型詳解以及訓練自己的Unet模型(劃分斑馬線)

第二部分标簽,像這樣:

語義分割3__基于Mobile網絡的Unet模型詳解以及訓練自己的Unet模型(劃分斑馬線)

當你們看到這個标簽的時候你們會說,我靠,你給我看的什麼辣雞,全黑的算什麼标簽,其實并不是這樣的,這個标簽看起來全黑,但是實際上在斑馬線的部分其RGB三個通道的值都是1。

其實給你們換一個圖你們就可以更明顯的看到了。

這是

voc

資料集中語義分割的訓練集中的一幅圖:

語義分割3__基于Mobile網絡的Unet模型詳解以及訓練自己的Unet模型(劃分斑馬線)

這是它的标簽。

語義分割3__基于Mobile網絡的Unet模型詳解以及訓練自己的Unet模型(劃分斑馬線)

為什麼這裡的标簽看起來就清楚的多呢,因為在voc中,其一共需要分21類,是以火車的RGB的值可能都大于10了,當然看得見。

是以,在訓練集中,如果像本文一樣分兩類,那麼背景的RGB就是000,斑馬線的RGB就是111,如果分多類,那麼還會存在222,333,444這樣的。這說明其屬于不同的類。

關于

loss

函數的組成我們需要看兩個

loss

函數的組成部分,第一個是預測結果。

# 此時輸出為h_input/2,w_input/2,nclasses
o =  Conv2D( n_classes , (3, 3) , padding='same', data_format=IMAGE_ORDERING )( o )
# 将結果進行reshape
o = Reshape((int(input_height/2)*int(input_width/2), -1))(o)
o = Softmax()(o)
model = Model(img_input,o)
           

其首先利用filter為n_classes的卷積核進行卷積,此時輸出為h_input/2,w_input/2,nclasses,對應着每一個h*w像素點上的種類。之後利用Softmax估計屬于每一個種類的機率。

其最後預測

y_pre

其實就是每一個像素點屬于哪一個種類的機率。

第二個是真實值,真實值是這樣處理的。

# 從檔案中讀取圖像
img = Image.open(r".\dataset2\png" + '/' + name)
img = img.resize((int(WIDTH/2),int(HEIGHT/2)))
img = np.array(img)
seg_labels = np.zeros((int(HEIGHT/2),int(WIDTH/2),NCLASSES))
for c in range(NCLASSES):
    seg_labels[: , : , c ] = (img[:,:,0] == c ).astype(int)
seg_labels = np.reshape(seg_labels, (-1,NCLASSES))
Y_train.append(seg_labels)
           

其将

png

圖先進行

resize

resize

後其大小與預測

y_pre

h*w

相同,然後讀取每一個像素點屬于什麼種類,并存入。

其最後真實

y_true

其實就是每一個像素點确實屬于哪個種類。

最後

loss

函數的組成就是

y_true

y_pre

的交叉熵。

大家可以在我的

github

上下載下傳完整的代碼。

https://github.com/bubbliiiing/Semantic-Segmentation

資料集的連結為:

連結:https://pan.baidu.com/s/1uzwqLaCXcWe06xEXk1ROWw

提取碼:pp6w

如圖所示:

語義分割3__基于Mobile網絡的Unet模型詳解以及訓練自己的Unet模型(劃分斑馬線)

其中

img

img_out

是測試檔案。

訓練檔案如下:

from nets.unet import mobilenet_unet
from keras.optimizers import Adam
from keras.callbacks import TensorBoard, ModelCheckpoint, ReduceLROnPlateau, EarlyStopping
from PIL import Image
import keras
from keras import backend as K
import numpy as np

NCLASSES = 2
HEIGHT = 416
WIDTH = 416

def generate_arrays_from_file(lines,batch_size):
    # 擷取總長度
    n = len(lines)
    i = 0
    while 1:
        X_train = []
        Y_train = []
        # 擷取一個batch_size大小的資料
        for _ in range(batch_size):
            if i==0:
                np.random.shuffle(lines)
            name = lines[i].split(';')[0]
            # 從檔案中讀取圖像
            img = Image.open(r".\dataset2\jpg" + '/' + name)
            img = img.resize((WIDTH,HEIGHT))
            img = np.array(img)
            img = img/255
            X_train.append(img)

            name = (lines[i].split(';')[1]).replace("\n", "")
            # 從檔案中讀取圖像
            img = Image.open(r".\dataset2\png" + '/' + name)
            img = img.resize((int(WIDTH/2),int(HEIGHT/2)))
            img = np.array(img)
            seg_labels = np.zeros((int(HEIGHT/2),int(WIDTH/2),NCLASSES))
            for c in range(NCLASSES):
                seg_labels[: , : , c ] = (img[:,:,0] == c ).astype(int)
            seg_labels = np.reshape(seg_labels, (-1,NCLASSES))
            Y_train.append(seg_labels)

            # 讀完一個周期後重新開始
            i = (i+1) % n
        yield (np.array(X_train),np.array(Y_train))

def loss(y_true, y_pred):
    crossloss = K.binary_crossentropy(y_true,y_pred)
    loss = 4 * K.sum(crossloss)/HEIGHT/WIDTH
    return loss

if __name__ == "__main__":
    log_dir = "logs/"
    # 擷取model
    model = mobilenet_unet(n_classes=NCLASSES,input_height=HEIGHT, input_width=WIDTH)
    # model.summary()
    BASE_WEIGHT_PATH = ('https://github.com/fchollet/deep-learning-models/'
										'releases/download/v0.6/')
    model_name = 'mobilenet_%s_%d_tf_no_top.h5' % ( '1_0' , 224 )
   
    weight_path = BASE_WEIGHT_PATH + model_name
    weights_path = keras.utils.get_file(model_name, weight_path )
    print(weight_path)
    model.load_weights(weights_path,by_name=True,skip_mismatch=True)

    # model.summary()
    # 打開資料集的txt
    with open(r".\dataset2\train.txt","r") as f:
        lines = f.readlines()

    # 打亂行,這個txt主要用于幫助讀取資料來訓練
    # 打亂的資料更有利于訓練
    np.random.seed(10101)
    np.random.shuffle(lines)
    np.random.seed(None)

    # 90%用于訓練,10%用于估計。
    num_val = int(len(lines)*0.1)
    num_train = len(lines) - num_val

    # 儲存的方式,1世代儲存一次
    checkpoint_period = ModelCheckpoint(
                                    log_dir + 'ep{epoch:03d}-loss{loss:.3f}-val_loss{val_loss:.3f}.h5',
                                    monitor='val_loss', 
                                    save_weights_only=True, 
                                    save_best_only=True, 
                                    period=1
                                )
    # 學習率下降的方式,val_loss三次不下降就下降學習率繼續訓練
    reduce_lr = ReduceLROnPlateau(
                            monitor='val_loss', 
                            factor=0.5, 
                            patience=3, 
                            verbose=1
                        )
    # 是否需要早停,當val_loss一直不下降的時候意味着模型基本訓練完畢,可以停止
    early_stopping = EarlyStopping(
                            monitor='val_loss', 
                            min_delta=0, 
                            patience=10, 
                            verbose=1
                        )

    # 交叉熵
    model.compile(loss = loss,
            optimizer = Adam(lr=1e-3),
            metrics = ['accuracy'])
    batch_size = 2
    print('Train on {} samples, val on {} samples, with batch size {}.'.format(num_train, num_val, batch_size))
    
    # 開始訓練
    model.fit_generator(generate_arrays_from_file(lines[:num_train], batch_size),
            steps_per_epoch=max(1, num_train//batch_size),
            validation_data=generate_arrays_from_file(lines[num_train:], batch_size),
            validation_steps=max(1, num_val//batch_size),
            epochs=50,
            initial_epoch=0,
            callbacks=[checkpoint_period, reduce_lr])

    model.save_weights(log_dir+'last1.h5')
           

預測檔案如下:

from nets.unet import mobilenet_unet
from PIL import Image
import numpy as np
import random
import copy
import os

random.seed(0)
class_colors = [[0,0,0],[0,255,0]]
NCLASSES = 2
HEIGHT = 416
WIDTH = 416


model = mobilenet_unet(n_classes=NCLASSES,input_height=HEIGHT, input_width=WIDTH)
model.load_weights("logs/ep015-loss0.070-val_loss0.076.h5")

imgs = os.listdir("./img")

for jpg in imgs:

    img = Image.open("./img/"+jpg)
    old_img = copy.deepcopy(img)
    orininal_h = np.array(img).shape[0]
    orininal_w = np.array(img).shape[1]

    img = img.resize((WIDTH,HEIGHT))
    img = np.array(img)
    img = img/255
    img = img.reshape(-1,HEIGHT,WIDTH,3)
    pr = model.predict(img)[0]

    pr = pr.reshape((int(HEIGHT/2), int(WIDTH/2),NCLASSES)).argmax(axis=-1)

    seg_img = np.zeros((int(HEIGHT/2), int(WIDTH/2),3))
    colors = class_colors

    for c in range(NCLASSES):
        seg_img[:,:,0] += ( (pr[:,: ] == c )*( colors[c][0] )).astype('uint8')
        seg_img[:,:,1] += ((pr[:,: ] == c )*( colors[c][1] )).astype('uint8')
        seg_img[:,:,2] += ((pr[:,: ] == c )*( colors[c][2] )).astype('uint8')

    seg_img = Image.fromarray(np.uint8(seg_img)).resize((orininal_w,orininal_h))

    image = Image.blend(old_img,seg_img,0.3)
    image.save("./img_out/"+jpg)
           

原圖:

語義分割3__基于Mobile網絡的Unet模型詳解以及訓練自己的Unet模型(劃分斑馬線)

繼續閱讀