# 天天看点
#
# 神经网络模型图绘制 (Plotting neural-network model diagrams)

# -*- coding: utf-8 -*-
# @Time    : 2019/11/12 10:46
# @Author  : Chicker
# @FileName: vis_model.py
# @Software: PyCharm
# @Blog    :http://blog.csdn.net/u010105243/article/

import os
import keras
from keras_bert import get_model
import pydot_ng

import os

os.environ["PATH"] += os.pathsep + 'D:/bin/'

# Directory containing this script; every rendered diagram is written here.
current_path = os.path.dirname(os.path.abspath(__file__))

# Keyword arguments for the keras-bert get_model() calls below. The small and
# big configurations differ in depth, head count and width; in both, the
# feed-forward width is 4x the embedding width.
_SMALL_BERT_KWARGS = {
    'token_num': 30000,
    'pos_num': 512,
    'transformer_num': 12,
    'head_num': 12,
    'embed_dim': 768,
    'feed_forward_dim': 768 * 4,
}
_BIG_BERT_KWARGS = {
    'token_num': 30000,
    'pos_num': 512,
    'transformer_num': 24,
    'head_num': 16,
    'embed_dim': 1024,
    'feed_forward_dim': 1024 * 4,
}


def _summarize_and_plot(net, filename):
    """Print a layer summary of *net* and save its graph as an image.

    Args:
        net: A Keras model.
        filename: Image file name, resolved relative to ``current_path``.

    Returns:
        The full path the diagram was written to.
    """
    net.summary(line_length=120)
    path = os.path.join(current_path, filename)
    keras.utils.plot_model(net, show_shapes=True, to_file=path)
    return path


# Models as returned by get_model() with its default ``training`` setting.
model = get_model(**_SMALL_BERT_KWARGS)
output_path = _summarize_and_plot(model, 'bert_small.png')

model = get_model(**_BIG_BERT_KWARGS)
output_path = _summarize_and_plot(model, 'bert_big.png')

# With training=False, get_model() hands back the (inputs, outputs) tensors
# instead of a Model, so wrap them here. NOTE: despite the output file name,
# this is the training=False variant of the small configuration, not a model
# with trained weights.
inputs, outputs = get_model(training=False, **_SMALL_BERT_KWARGS)
model = keras.models.Model(inputs=inputs, outputs=outputs)
# Placeholder compile settings; nothing in this script trains the model.
model.compile(optimizer='adam', loss='mse', metrics={})
output_path = _summarize_and_plot(model, 'bert_trained.png')