天天看點

keras模型可視化pydot-ng 和 graphviz安裝問題(ubuntu)

方法一:

keras.utils.vis_utils

子產品提供了畫出Keras模型的函數(利用graphviz)

然而模型可視化過程會報錯誤:

from keras.utils import plot_model
plot_model(model, to_file='model.png')
           

keras文檔給出的解決方法:

pip install pydot-ng & brew install graphviz
           

安裝時會提醒你添加環境變量:

You may want to update following environments after installed linuxbrew.

  PATH, MANPATH, INFOPATH 
           

打開.bashrc:

在最後添加提示的環境變量即可

如果已經安裝

.linuxbrew

,若提示錯誤,可以把

.linuxbrew

删除再繼續安裝

詳細homebrew在Linux下的使用讨論及Linuxbrew安裝方法

方法二 :

打開keras可視化代碼:

def _check_pydot():
    try:
        # Attempt to create an image of a blank graph
        # to check the pydot/graphviz installation.
        pydot.Dot.create(pydot.Dot())
    except Exception:
        # pydot raises a generic Exception here,
        # so no specific class can be caught.
        raise ImportError('Failed to import pydot. You must install pydot'
                          ' and graphviz for `pydotprint` to work.')
           

可自行pip安裝:

sudo apt-get install graphviz
sudo pip install pydot-ng
           

注意需要先安裝

graphviz

再裝

pydot-ng

可視化結果

随便寫了一個2層LSTM的網絡:

from keras.models import Model
from keras.layers import LSTM, Activation, Input
import numpy as np
from keras.utils.vis_utils import plot_model

data_dim = 
timesteps = 
num_classes = 

inputs = Input(shape=(,))
lstm1 = LSTM(, return_sequences=True)(inputs)
lstm2 = LSTM( , return_sequences=True)(lstm1)
outputs = Activation('softmax')(lstm2)
model = Model(inputs=inputs,outputs=outputs)
model.compile(loss='categorical_crossentropy',
              optimizer='rmsprop',
              metrics=['accuracy'])

x_train = np.random.random((, timesteps, data_dim))
y_train = np.random.random((, timesteps, num_classes))

x_val = np.random.random((, timesteps, data_dim))
y_val = np.random.random((, timesteps, num_classes))

model.fit(x_train, y_train,
          batch_size=, epochs=,
          validation_data=(x_val, y_val))
#模型可視化
plot_model(model, to_file='model.png')
x = np.arange().reshape(,,)
a = model.predict(x,batch_size=)
print a
           

結果:

keras模型可視化pydot-ng 和 graphviz安裝問題(ubuntu)

繼續閱讀