天天看點

Tensorflow2 使用經典的模型

import tensorflow as tf
from tensorflow.keras import applications

"""
keras.applications 中共有以下模型
DenseNet121(...)
DenseNet169(...)
DenseNet201(...)
InceptionResNetV2(...)
InceptionV3(...)
MobileNet(...)
MobileNetV2(...)
NASNetLarge(...)
NASNetMobile(...)
ResNet101(...)
ResNet101V2(...)
ResNet152(...)
ResNet152V2(...)
ResNet50(...)
ResNet50V2(...)
VGG16(...)
VGG19(...)
Xception(...)
下面僅以NASNetLarge為例
"""

# 第一種情況自定義輸入大小 這時不能加載imagenet預訓練參數  
IMG_SHAPE = (224,224,3)                                          
base_model = applications.NASNetLarge(input_shape=IMG_SHAPE, include_top=False, 
					weights=None, pooling='avg')

base_model.trainable = True # 參數可訓練

# 第二種情況采用API自定義的輸入大小,這時可以加載imagenet預訓練參數                                           
base_model = applications.NASNetLarge(input_shape=None, include_top=False, 
					weights='imagenet', pooling='avg')
					
base_model.trainable = False # 參數可訓練 這個看你自己的想法

# 第三種情況 不對網絡進行修改
base_model = applications.NASNetLarge(input_shape=None, include_top=True, 
					weights='imagenet', pooling=None)
					
base_model.trainable = True # 參數可訓練 這個看你自己的想法

# 看一下模型的結構
base_model.summary()

# 大多數 我們用的還是第二種情況
           

參考網址:

https://tensorflow.google.cn/api_docs/python/tf/keras/applications

https://www.tensorflow.org/tutorials/images/transfer_learning

https://keras.io/applications/