天天看點

Faster-RCNN物體檢測---(3)ResNet搭建

名詞解釋:

Faster-RCNN物體檢測---(3)ResNet搭建

ResNet基本結構

Faster-RCNN物體檢測---(3)ResNet搭建

繼續分

Faster-RCNN物體檢測---(3)ResNet搭建

關於ResNet50

Faster-RCNN物體檢測---(3)ResNet搭建

文章中具體結構

Faster-RCNN物體檢測---(3)ResNet搭建

注意:

Faster-RCNN物體檢測---(3)ResNet搭建

代碼:

import keras.layers as KL
from keras.models import Model
import keras.backend as K
import tensorflow as tf
from keras.utils import np_utils
from keras.utils.vis_utils import plot_model
import numpy as np
import os

from keras.datasets import mnist

def building_block(filters, block):
    """Return a closure that applies one bottleneck residual block.

    Main path: 1x1 conv -> BN -> ReLU -> 3x3 conv ('same') -> BN -> ReLU ->
    1x1 conv to ``filters * 4`` channels -> BN. The main path is added to the
    shortcut and the sum goes through a final ReLU.

    Args:
        filters: Channel count of the two inner convolutions. The block's
            output has ``filters * 4`` channels.
        block: Index of this block within its stage. Block 0 uses stride 2 in
            its first 1x1 conv and a projection shortcut (1x1 conv with the
            same stride + BN) so both branches have matching shapes. Every
            other index uses stride 1 and an identity shortcut.

    Returns:
        A function ``f(x)`` that maps a Keras tensor to the block's output.
    """
    # Only the first block of a stage downsamples and changes channel count.
    is_first = block == 0
    stride = 2 if is_first else 1

    def f(x):
        # Main path: reduce -> 3x3 -> expand.
        y = KL.Conv2D(filters=filters, kernel_size=(1, 1), strides=stride)(x)
        y = KL.BatchNormalization(axis=3)(y)
        y = KL.Activation('relu')(y)

        y = KL.Conv2D(filters=filters, kernel_size=(3, 3), padding='same')(y)
        y = KL.BatchNormalization(axis=3)(y)
        y = KL.Activation('relu')(y)

        # Expand to filters * 4 channels; the ReLU is applied after the add.
        y = KL.Conv2D(filters=filters * 4, kernel_size=(1, 1))(y)
        y = KL.BatchNormalization(axis=3)(y)

        # Shortcut path.
        if is_first:
            # Projection shortcut: match the main path's stride and channels.
            shortcut = KL.Conv2D(filters=filters * 4, kernel_size=(1, 1), strides=stride)(x)
            shortcut = KL.BatchNormalization(axis=3)(shortcut)
        else:
            # Identity shortcut: x must already have filters * 4 channels
            # and the main path's spatial size, or the Add below fails.
            shortcut = x

        # Merge both paths, then apply the final activation.
        y = KL.Add()([y, shortcut])
        return KL.Activation('relu')(y)

    return f

def ResNet_Extractor(Xtrain, Ytrain, Xtest, Ytest, *, input_shape=(28, 28, 1),
                     num_classes=10, blocks_per_stage=(2, 3, 4), base_filters=64,
                     epochs=6, save_path='resnetMnist.h5'):
    """Build, train and save a small bottleneck ResNet classifier.

    Architecture:
    - Stem: a 3x3 conv ('same') with ``base_filters`` channels, then ReLU.
    - Body: one stage per entry of ``blocks_per_stage``, made of that many
      ``building_block``s. The inner width starts at ``base_filters`` and
      doubles after each stage. The first block of every stage uses stride 2.
    - Head: 2x2 average pooling, flatten, and a softmax Dense layer with
      ``num_classes`` units.

    For the default 28x28 input, the three stride-2 stages shrink the
    feature map 28 -> 14 -> 7 -> 4 before pooling. Smaller inputs or more
    stages may shrink it below the 2x2 pooling window.

    Args:
        Xtrain, Ytrain: Training inputs and one-hot labels. Categorical
            cross-entropy expects ``num_classes`` label columns.
        Xtest, Ytest: Validation inputs and one-hot labels.
        input_shape: Shape of one input sample, channels last.
        num_classes: Number of output classes.
        blocks_per_stage: Number of residual blocks in each stage.
        base_filters: Channel count of the stem and of the first stage's
            inner convolutions.
        epochs: Number of training epochs.
        save_path: Where ``model.save`` writes the trained model. A falsy
            value skips saving.

    Returns:
        The trained Keras ``Model``.
    """
    # Stem. Named `inputs` to avoid shadowing the builtin `input`.
    inputs = KL.Input(list(input_shape))
    x = KL.Conv2D(filters=base_filters, kernel_size=(3, 3), padding='same')(inputs)
    x = KL.Activation('relu')(x)

    # Body: lay out the residual stages; width doubles after each stage.
    filters = base_filters
    for stage, block_num in enumerate(blocks_per_stage):
        print('--stage--', stage, '----')
        for block_id in range(block_num):
            print('---block--', block_id, '----')
            x = building_block(filters=filters, block=block_id)(x)
        filters *= 2

    # Head: pooling + classifier.
    x = KL.AveragePooling2D(pool_size=(2, 2))(x)
    x = KL.Flatten()(x)
    outputs = KL.Dense(units=num_classes, activation='softmax')(x)

    model = Model(inputs=inputs, outputs=outputs)
    model.compile(loss='categorical_crossentropy',
                  optimizer='adam',
                  metrics=['accuracy'])
    model.fit(
        Xtrain, Ytrain,
        epochs=epochs,
        verbose=1,
        validation_data=(Xtest, Ytest),
    )
    if save_path:
        model.save(save_path)
    return model

def main(path='C:/Users/Administrator/Desktop/keras代碼/mnist.npz'):
    """Load MNIST from a local ``.npz`` archive, preprocess it, and train the ResNet.

    Args:
        path: Path to an ``.npz`` archive with ``x_train``, ``y_train``,
            ``x_test`` and ``y_test`` arrays.

    Returns:
        The trained model returned by ``ResNet_Extractor``.
    """
    # np.load on an .npz returns an NpzFile that holds the file open.
    # Indexing it reads each array into memory, so closing the archive
    # afterwards is safe.
    with np.load(path) as data:
        Xtrain, Ytrain = data['x_train'], data['y_train']
        Xtest, Ytest = data['x_test'], data['y_test']

    # Add a trailing channel axis (channels last) and scale pixel values.
    # Assumes 8-bit pixels in [0, 255], so the result is in [0, 1].
    Xtrain = Xtrain.reshape(-1, 28, 28, 1) / 255.
    Xtest = Xtest.reshape(-1, 28, 28, 1) / 255.

    # One-hot encode the 10 digit labels for categorical cross-entropy.
    Ytrain = np_utils.to_categorical(Ytrain, 10)
    Ytest = np_utils.to_categorical(Ytest, 10)

    return ResNet_Extractor(Xtrain, Ytrain, Xtest, Ytest)

# Train only when run as a script, not when this module is imported.
if __name__ == '__main__':
    main()

結果:

Faster-RCNN物體檢測---(3)ResNet搭建
Faster-RCNN物體檢測---(3)ResNet搭建

繼續閱讀