名詞解釋:
ResNet基本結構
繼續分
關於ResNet50
文章中具體結構
注意:
代碼:
import keras.layers as KL
from keras.models import Model
import keras.backend as K
import tensorflow as tf
from keras.utils import np_utils
from keras.utils.vis_utils import plot_model
import numpy as np
import os
from keras.datasets import mnist
def building_block(filters, block, *, first_stride=2):
    """Return a closure that applies one ResNet bottleneck block to a tensor.

    Main path: 1x1 conv -> BN -> ReLU -> 3x3 conv ('same') -> BN -> ReLU ->
    1x1 conv with ``filters * 4`` channels -> BN. The main path is added to a
    shortcut, and the sum goes through a final ReLU.

    Args:
        filters: number of filters in the first two convolutions; the block
            outputs ``filters * 4`` channels.
        block: index of this block within its stage. Block 0 uses
            ``first_stride`` and a projection shortcut (strided 1x1 conv + BN),
            so the shortcut matches the main path's shape. Any other block
            uses stride 1 and an identity shortcut, so its input must already
            have ``filters * 4`` channels.
        first_stride: stride used by block 0 (default 2, the original
            behaviour). Pass 1 for a stage that should not downsample.

    Returns:
        A function ``f(x)`` that adds the block's layers on top of tensor
        ``x`` and returns the resulting tensor.
    """
    # Only the first block of a stage uses a non-unit stride.
    stride = first_stride if block == 0 else 1

    def f(x):
        # Main (bottleneck) path. axis=3 is the channel axis for
        # channels-last 4D tensors.
        y = KL.Conv2D(filters=filters, kernel_size=(1, 1), strides=stride)(x)
        y = KL.BatchNormalization(axis=3)(y)
        y = KL.Activation('relu')(y)
        y = KL.Conv2D(filters=filters, kernel_size=(3, 3), padding='same')(y)
        y = KL.BatchNormalization(axis=3)(y)
        y = KL.Activation('relu')(y)
        # Expand to filters * 4 channels; the ReLU is applied after the add.
        y = KL.Conv2D(filters=filters * 4, kernel_size=(1, 1))(y)
        y = KL.BatchNormalization(axis=3)(y)

        # Shortcut path: project with a strided 1x1 conv when this block
        # changes shape, otherwise pass the input through unchanged.
        if block == 0:
            shortcut = KL.Conv2D(filters=filters * 4, kernel_size=(1, 1),
                                 strides=stride)(x)
            shortcut = KL.BatchNormalization(axis=3)(shortcut)
        else:
            shortcut = x

        # Merge main path and shortcut, then apply the final activation.
        y = KL.Add()([y, shortcut])
        return KL.Activation('relu')(y)

    return f
def ResNet_Extractor(Xtrain, Ytrain, Xtest, Ytest, *, epochs=6,
                     save_path='resnetMnist.h5'):
    """Build, train and save a small bottleneck-ResNet classifier.

    Architecture:
    - Stem: a 3x3 'same' conv with 64 filters, followed by ReLU.
    - Body: three stages of ``building_block`` bottlenecks (2, 3 and 4
      blocks), with the base filter count doubling after each stage
      (64, 128, 256).
    - Head: 2x2 average pooling, flatten, and a softmax dense layer.

    Args:
        Xtrain: training images, indexed as (sample, ...image dims).
        Ytrain: one-hot training labels, shape (samples, classes).
        Xtest: validation images, same layout as ``Xtrain``.
        Ytest: one-hot validation labels, same layout as ``Ytrain``.
        epochs: number of training epochs (default 6).
        save_path: file the trained model is saved to.

    Returns:
        The trained keras ``Model``.
    """
    # Take the input shape and class count from the data instead of
    # hard-coding (28, 28, 1) and 10. For the reshaped, one-hot MNIST arrays
    # built in main() these give the same values.
    input_shape = tuple(Xtrain.shape[1:])
    num_classes = Ytrain.shape[-1]

    # Stem. Named `inputs` so the builtin `input` is not shadowed.
    inputs = KL.Input(shape=input_shape)
    x = KL.Conv2D(filters=64, kernel_size=(3, 3), padding='same')(inputs)
    x = KL.Activation('relu')(x)

    # Body: the number of bottleneck blocks in each stage.
    filters = 64
    blocks_per_stage = [2, 3, 4]
    for stage, block_num in enumerate(blocks_per_stage):
        print('--stage--', stage, '----')
        for block_id in range(block_num):
            print('---block--', block_id, '----')
            x = building_block(filters=filters, block=block_id)(x)
        filters *= 2

    # Head: pool, flatten and classify.
    x = KL.AveragePooling2D(pool_size=(2, 2))(x)
    x = KL.Flatten()(x)
    outputs = KL.Dense(units=num_classes, activation='softmax')(x)

    model = Model(inputs=inputs, outputs=outputs)
    model.compile(loss='categorical_crossentropy',
                  optimizer='adam',
                  metrics=['accuracy'])
    model.fit(Xtrain, Ytrain,
              epochs=epochs,
              verbose=1,
              validation_data=(Xtest, Ytest))
    model.save(save_path)
    return model
def main(path='C:/Users/Administrator/Desktop/keras代碼/mnist.npz'):
    """Load MNIST from a local ``.npz`` archive, preprocess it and train.

    Args:
        path: ``.npz`` file containing the arrays ``x_train``, ``y_train``,
            ``x_test`` and ``y_test``. Defaults to the original hard-coded
            location.

    Returns:
        The trained model returned by ``ResNet_Extractor``.
    """
    # For .npz archives np.load returns an NpzFile that keeps the file open
    # until it is closed, so use it as a context manager. Indexing it reads
    # each array fully into memory, so the arrays remain valid after close.
    with np.load(path) as data:
        Xtrain, Ytrain = data['x_train'], data['y_train']
        Xtest, Ytest = data['x_test'], data['y_test']

    # Add the trailing channel axis used by the Conv2D layers and divide by
    # 255 (assumes 0-255 pixel values, which maps them into [0, 1]).
    Xtrain = Xtrain.reshape(-1, 28, 28, 1) / 255.
    Xtest = Xtest.reshape(-1, 28, 28, 1) / 255.

    # One-hot encode the labels into 10 classes for categorical_crossentropy.
    Ytrain = np_utils.to_categorical(Ytrain, 10)
    Ytest = np_utils.to_categorical(Ytest, 10)

    return ResNet_Extractor(Xtrain, Ytrain, Xtest, Ytest)
# Run training only when executed as a script, not when the module is imported.
if __name__ == '__main__':
    main()
結果: