Problem description
Keras can visualize a model as a graph, as shown below:
![](https://img.laitimes.com/img/_0nNw4CM6IyYiwiM6ICdiwiI0gTMx81dsQWZ4lmZf1GLlpXazVmcvwFciV2dsQXYtJ3bm9CX9s2RkBnVHFmb1clWvB3MaVnRtp1XlBXe0xCMy81dvRWYoNHLwEzX5xCMx8FesU2cfdGLwMzX0xiRGZkRGZ0Xy9GbvNGLpZTY1EmMZVDUSFTU4VFRR9Fd4VGdsYTMfVmepNHLrJXYtJXZ0F2dvwVZnFWbp1zczV2YvJHctM3cv1Ce-cmbw5iNwQDOxADZ3U2N3QmMmhTNzYzX5UTM0YTM2IzLcBTMyIDMy8CXn9Gbi9CXzV2Zh1WavwVbvNmLvR3YxUjLyM3Lc9CX6MHc0RHaiojIsJye.png)
import keras
from keras import models, layers
from keras.applications import VGG16
# Convolutional base of VGG16 without the classifier head
conv_base = VGG16(weights='imagenet', include_top=False, input_shape=(128, 128, 3))
smodel = models.Sequential()
# The whole VGG16 model is added as a single layer
smodel.add(conv_base)
smodel.add(layers.Flatten())
smodel.add(layers.Dense(256, activation='relu'))
smodel.add(layers.Dense(1, activation='sigmoid'))
keras.utils.plot_model(smodel, to_file='model.png', show_shapes=True, show_layer_names=True, rankdir='LR')
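As the figure shows, the VGG16 Model object is drawn as a single summary node. How can its detailed internal structure be displayed?
Solution
When keras.utils.plot_model runs, it iterates over the layers of the model and uses attributes such as each layer's name. In the example above, the attribute it uses is the name of conv_base (a Model object). For details, see the source code of keras.utils.vis_utils.model_to_dot().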
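The traversal described above can be illustrated without Keras. The following is a minimal sketch, not the actual model_to_dot() code: the Layer and Model classes and the node_names helper are hypothetical stand-ins that only mimic the name and layers attributes. It shows why a nested model contributes a single node, while its layers contribute one node each when they are added individually.

```python
class Layer:
    """Hypothetical stand-in for a Keras layer: it only carries a name."""

    def __init__(self, name):
        self.name = name


class Model(Layer):
    """Hypothetical stand-in for a Keras Model: a Layer that holds other layers."""

    def __init__(self, name, layers=None):
        super().__init__(name)
        self.layers = list(layers or [])

    def add(self, layer):
        self.layers.append(layer)


def node_names(model):
    """Return one node name per top-level layer, without descending into nested models."""
    return [layer.name for layer in model.layers]


# Toy "conv_base" with three inner layers (names chosen for illustration only)
conv_base = Model("vgg16", [Layer("block1_conv1"), Layer("block1_conv2"), Layer("block1_pool")])

# Case 1: the whole model is added as one layer
nested = Model("sequential_nested")
nested.add(conv_base)
nested.add(Layer("flatten"))

# Case 2: the inner layers are added one by one
flat = Model("sequential_flat")
for layer in conv_base.layers:
    flat.add(layer)
flat.add(Layer("flatten"))

print(node_names(nested))  # ['vgg16', 'flatten']
print(node_names(flat))    # ['block1_conv1', 'block1_conv2', 'block1_pool', 'flatten']
```

In the first case the traversal only ever sees the name of the nested model, which matches the single VGG16 node in the figure.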
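Note the following inheritance chain (subclass -> parent class):
Model -> Container -> Layer -> object.
Because a Model is itself a Layer, adding conv_base to a Sequential model registers it as one layer, so it is drawn as one node. Adding the layers of conv_base one by one, as in the code below, makes each of them a top-level layer with its own node.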
import keras
from keras import models, layers
from keras.applications import VGG16
conv_base = VGG16(weights='imagenet', include_top=False, input_shape=(128, 128, 3))
smodel = models.Sequential()
# Add each VGG16 layer individually so that every layer gets its own node
for m in conv_base.layers:
    smodel.add(m)
smodel.add(layers.Flatten())
smodel.add(layers.Dense(256, activation='relu'))
smodel.add(layers.Dense(1, activation='sigmoid'))
keras.utils.plot_model(smodel, to_file='model.png', show_shapes=True, show_layer_names=True, rankdir='TB')