LoginSignup
1
1

More than 1 year has passed since last update.

TensorFlowのモデルを可視化する。(Windows)

Last updated at Posted at 2022-01-07

はじめに

TensorFlowのネットワークモデルを可視化する方法を書いておきます。
kerasにbase.summary()という素晴らしい命令がありますが、テキスト出力です。
一目でネットワークの構造が見えるグラフで表示したい時に有用な方法です。

準備するもの

pydotとgraphvizを設置する必要があります。

Linuxだと両方ともpipで設置が可能ですが、
WindowsではgraphvizはWindows installerで設置する必要があります。

pip install

pip install pydot
pip install pydotplus

windows installer

こちらの記事を参考にして、設置してください。

(金子研究室)Windows で Graphviz のインストール

例題

VGG16モデルを可視化してみます。

from tensorflow.keras.applications.vgg16 import VGG16
from tensorflow.keras.utils import plot_model

#学習済みモデル VGG16
base_model = VGG16(weights='imagenet')

#summary()で可視化
print(base_model.summary())

#グラフで可視化。
plot_model(base_model, show_shapes=True)

結果

コードのあるフォルダーにmodel.pngが作成されます。

model.png

参考として、print(base_model.summary())の結果はこちらです。

Model: "vgg16"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_1 (InputLayer)        [(None, 224, 224, 3)]     0         

 block1_conv1 (Conv2D)       (None, 224, 224, 64)      1792      

 block1_conv2 (Conv2D)       (None, 224, 224, 64)      36928     

 block1_pool (MaxPooling2D)  (None, 112, 112, 64)      0         

 block2_conv1 (Conv2D)       (None, 112, 112, 128)     73856     

 block2_conv2 (Conv2D)       (None, 112, 112, 128)     147584    

 block2_pool (MaxPooling2D)  (None, 56, 56, 128)       0         

 block3_conv1 (Conv2D)       (None, 56, 56, 256)       295168    

 block3_conv2 (Conv2D)       (None, 56, 56, 256)       590080    

 block3_conv3 (Conv2D)       (None, 56, 56, 256)       590080    

 block3_pool (MaxPooling2D)  (None, 28, 28, 256)       0         

 block4_conv1 (Conv2D)       (None, 28, 28, 512)       1180160   

 block4_conv2 (Conv2D)       (None, 28, 28, 512)       2359808   

 block4_conv3 (Conv2D)       (None, 28, 28, 512)       2359808   

 block4_pool (MaxPooling2D)  (None, 14, 14, 512)       0         

 block5_conv1 (Conv2D)       (None, 14, 14, 512)       2359808   

 block5_conv2 (Conv2D)       (None, 14, 14, 512)       2359808   

 block5_conv3 (Conv2D)       (None, 14, 14, 512)       2359808   

 block5_pool (MaxPooling2D)  (None, 7, 7, 512)         0         

 flatten (Flatten)           (None, 25088)             0         

 fc1 (Dense)                 (None, 4096)              102764544 

 fc2 (Dense)                 (None, 4096)              16781312  

 predictions (Dense)         (None, 1000)              4097000   

=================================================================
Total params: 138,357,544
Trainable params: 138,357,544
Non-trainable params: 0
_________________________________________________________________

参考資料

  1. tf.keras.utils の plot_model で Failed to import pydot. You must install pydot and graphviz for pydotprint to work と言われた時【TensorFlow】
1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1