1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

TensorFlowのモデルを可視化する。(Windows)

Last updated at Posted at 2022-01-07

#はじめに
TensorFlowのネットワークモデルを可視化する方法を書いておきます。
kerasにbase.summary()という素晴らしい命令がありますが、テキスト出力です。
一目でネットワークの構造が見えるグラフで表示したい時に有用な方法です。

#準備するもの

pydotとgraphvizを設置する必要があります。

Linuxだと両方ともpipで設置が可能ですが、
WindowsではgraphvizはWindows installerで設置する必要があります。

pip install

pip install pydot
pip install pydotplus

#windows installer
こちらの記事を参考にして、設置してください。

(金子研究室)Windows で Graphviz のインストール

#例題

VGG16モデルを可視化してみます。

from tensorflow.keras.applications.vgg16 import VGG16
from tensorflow.keras.utils import plot_model

#学習済みモデル VGG16
base_model = VGG16(weights='imagenet')

#summary()で可視化
print(base_model.summary())

#グラフで可視化。
plot_model(base_model, show_shapes=True)

#結果

コードのあるフォルダーにmodel.pngが作成されます。

model.png

参考として、print(base_model.summary())の結果はこちらです。

Model: "vgg16"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_1 (InputLayer)        [(None, 224, 224, 3)]     0         
                                                                 
 block1_conv1 (Conv2D)       (None, 224, 224, 64)      1792      
                                                                 
 block1_conv2 (Conv2D)       (None, 224, 224, 64)      36928     
                                                                 
 block1_pool (MaxPooling2D)  (None, 112, 112, 64)      0         
                                                                 
 block2_conv1 (Conv2D)       (None, 112, 112, 128)     73856     
                                                                 
 block2_conv2 (Conv2D)       (None, 112, 112, 128)     147584    
                                                                 
 block2_pool (MaxPooling2D)  (None, 56, 56, 128)       0         
                                                                 
 block3_conv1 (Conv2D)       (None, 56, 56, 256)       295168    
                                                                 
 block3_conv2 (Conv2D)       (None, 56, 56, 256)       590080    
                                                                 
 block3_conv3 (Conv2D)       (None, 56, 56, 256)       590080    
                                                                 
 block3_pool (MaxPooling2D)  (None, 28, 28, 256)       0         
                                                                 
 block4_conv1 (Conv2D)       (None, 28, 28, 512)       1180160   
                                                                 
 block4_conv2 (Conv2D)       (None, 28, 28, 512)       2359808   
                                                                 
 block4_conv3 (Conv2D)       (None, 28, 28, 512)       2359808   
                                                                 
 block4_pool (MaxPooling2D)  (None, 14, 14, 512)       0         
                                                                 
 block5_conv1 (Conv2D)       (None, 14, 14, 512)       2359808   
                                                                 
 block5_conv2 (Conv2D)       (None, 14, 14, 512)       2359808   
                                                                 
 block5_conv3 (Conv2D)       (None, 14, 14, 512)       2359808   
                                                                 
 block5_pool (MaxPooling2D)  (None, 7, 7, 512)         0         
                                                                 
 flatten (Flatten)           (None, 25088)             0         
                                                                 
 fc1 (Dense)                 (None, 4096)              102764544 
                                                                 
 fc2 (Dense)                 (None, 4096)              16781312  
                                                                 
 predictions (Dense)         (None, 1000)              4097000   
                                                                 
=================================================================
Total params: 138,357,544
Trainable params: 138,357,544
Non-trainable params: 0
_________________________________________________________________

#参考資料

  1. tf.keras.utils の plot_model で Failed to import pydot. You must install pydot and graphviz for pydotprint to work と言われた時【TensorFlow】
1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?