
Drawing GoogLeNet with caffe's draw_net

Posted at 2015-12-05

Drawing GoogLeNet with Caffe's draw_net produces a huge, cluttered graph (for example, this one or this one), so this post simplifies the drawing.

Modify caffe/python/caffe/draw.py as follows:
diff --git a/python/caffe/draw.py b/python/caffe/draw.py
index a002b60..2fb0606 100644
--- a/python/caffe/draw.py
+++ b/python/caffe/draw.py
@@ -75,31 +75,37 @@ def get_layer_label(layer, rankdir):
         separator = '\\n'
 
     if layer.type == 'Convolution' or layer.type == 'Deconvolution':
-        # Outer double quotes needed or else colon characters don't parse
-        # properly
-        node_label = '"%s%s(%s)%skernel size: %d%sstride: %d%spad: %d"' %\
-                     (layer.name,
-                      separator,
-                      layer.type,
-                      separator,
-                      layer.convolution_param.kernel_size,
-                      separator,
-                      layer.convolution_param.stride,
-                      separator,
-                      layer.convolution_param.pad)
-    elif layer.type == 'Pooling':
-        pooling_types_dict = get_pooling_types_dict()
-        node_label = '"%s%s(%s %s)%skernel size: %d%sstride: %d%spad: %d"' %\
-                     (layer.name,
-                      separator,
-                      pooling_types_dict[layer.pooling_param.pool],
-                      layer.type,
-                      separator,
-                      layer.pooling_param.kernel_size,
-                      separator,
-                      layer.pooling_param.stride,
-                      separator,
-                      layer.pooling_param.pad)
+        separator = '\\n'
+
+#     if layer.type == 'Convolution' or layer.type == 'Deconvolution':
+#         # Outer double quotes needed or else colon characters don't parse
+#         # properly
+#         node_label = '"%s%s(%s)%skernel size: %d%sstride: %d%spad: %d"' %\
+#                      (layer.name,
+#                       separator,
+#                       layer.type,
+#                       separator,
+#                       layer.convolution_param.kernel_size[0] if len(layer.convolution_param.kernel_size._values) else 1,
+#                       separator,
+#                       layer.convolution_param.stride[0] if len(layer.convolution_param.stride._values) else 1,
+#                       separator,
+#                       layer.convolution_param.pad[0] if len(layer.convolution_param.pad._values) else 0)
+#     elif layer.type == 'Pooling':
+#         pooling_types_dict = get_pooling_types_dict()
+#         node_label = '"%s%s(%s %s)%skernel size: %d%sstride: %d%spad: %d"' %\
+#                      (layer.name,
+#                       separator,
+#                       pooling_types_dict[layer.pooling_param.pool],
+#                       layer.type,
+#                       separator,
+#                       layer.pooling_param.kernel_size,
+#                       separator,
+#                       layer.pooling_param.stride,
+#                       separator,
+#                       layer.pooling_param.pad)
+#     else:
+    if layer.type == 'InnerProduct' or layer.type == '':
+        node_label = '"%s%s(%s)"' % (layer.name, '\\n', 'full connect')
     else:
         node_label = '"%s%s(%s)"' % (layer.name, separator, layer.type)
     return node_label
@@ -140,6 +146,7 @@ def get_pydot_graph(caffe_net, rankdir, label_edges=True):
     pydot_edges = []
     for layer in caffe_net.layer:
         node_label = get_layer_label(layer, rankdir)
+        if layer.type == 'Dropout' or layer.type == 'ReLU': continue
         node_name = "%s_%s" % (layer.name, layer.type)
         if (len(layer.bottom) == 1 and len(layer.top) == 1 and
            layer.bottom[0] == layer.top[0]):
@@ -159,10 +166,10 @@ def get_pydot_graph(caffe_net, rankdir, label_edges=True):
                                 'label': edge_label})
         for top_blob in layer.top:
             pydot_nodes[top_blob + '_blob'] = pydot.Node('%s' % (top_blob))
-            if label_edges:
-                edge_label = get_edge_label(layer)
-            else:
-                edge_label = '""'
+#            if label_edges:
+#                edge_label = get_edge_label(layer)
+#            else:
+            edge_label = '""'
             pydot_edges.append({'src': node_name,
                                 'dst': top_blob + '_blob',
                                 'label': edge_label})
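To see what the patch changes without installing Caffe or pydot, the label and node-filtering logic can be re-implemented over plain Python objects. This is a minimal sketch, not the code in draw.py: the `Layer` dataclass and the helpers `patched_layer_label` and `drawable_layers` are hypothetical stand-ins for Caffe's protobuf layer messages and the loop in `get_pydot_graph`. It assumes the unpatched rule in draw.py that vertical layouts (`TB`/`BT`) separate words with a space and horizontal layouts with a Graphviz `\n` escape.

```python
from dataclasses import dataclass
from typing import List


# Hypothetical stand-in for a caffe LayerParameter message: only the
# fields that the patched label and filtering logic read.
@dataclass
class Layer:
    name: str
    type: str


# Layer types that the patched get_pydot_graph skips with `continue`.
SKIPPED_TYPES = {'Dropout', 'ReLU'}


def patched_layer_label(layer: Layer, rankdir: str) -> str:
    """Mimic the patched get_layer_label: no kernel/stride/pad details."""
    # Assumed unpatched rule: vertical layouts use a space, horizontal
    # layouts use a Graphviz newline escape (a backslash followed by 'n').
    separator = ' ' if rankdir in ('TB', 'BT') else '\\n'
    if layer.type in ('Convolution', 'Deconvolution'):
        # The patch forces a line break inside convolution labels.
        separator = '\\n'
    if layer.type in ('InnerProduct', ''):
        # Fully connected layers are relabelled "full connect".
        return '"%s%s(%s)"' % (layer.name, '\\n', 'full connect')
    return '"%s%s(%s)"' % (layer.name, separator, layer.type)


def drawable_layers(layers: List[Layer]) -> List[Layer]:
    """Drop the layer types that the patched loop never turns into nodes."""
    return [layer for layer in layers if layer.type not in SKIPPED_TYPES]


if __name__ == '__main__':
    # GoogLeNet-style layer names, used here only as sample data.
    net = [
        Layer('conv1/7x7_s2', 'Convolution'),
        Layer('conv1/relu_7x7', 'ReLU'),
        Layer('pool1/3x3_s2', 'Pooling'),
        Layer('pool5/drop_7x7_s1', 'Dropout'),
        Layer('loss3/classifier', 'InnerProduct'),
    ]
    for layer in drawable_layers(net):
        print(patched_layer_label(layer, 'BT'))
```

The third change in the patch, replacing every layer-to-top-blob edge label with `'""'`, is not modelled here; it simply removes the blob-shape text from the edges.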

All that's left is to run it. Note that I manually added the fc layers from train_val.prototxt to deploy.prototxt.

./caffe/python/draw_net.py deploy.prototxt googlenet-deploy.pdf --rankdir 'BT'

This produces a diagram like the following.
[Figure: googlenet-deploy-all.png]
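The step of copying the fc layers from train_val.prototxt into deploy.prototxt by hand could also be scripted. The following is a rough, hypothetical text-level sketch, not part of Caffe: `extract_layer_blocks` and `append_layers` are illustrative names, and the brace matching assumes balanced braces outside quoted strings, top-level blocks written as `layer {` at the start of a line (old-style `layers {` files are not matched), and that the first `name:` in each block is the layer name. A more robust route would parse both files with Caffe's `caffe_pb2.NetParameter` and `google.protobuf.text_format`.

```python
import re
from typing import Dict, Iterable


def extract_layer_blocks(prototxt: str) -> Dict[str, str]:
    """Map layer name -> text of each top-level `layer { ... }` block.

    Hypothetical helper: assumes balanced braces that never appear inside
    quoted strings, and that a block's first `name:` is the layer name.
    """
    blocks = {}
    for match in re.finditer(r'(?m)^layer\s*\{', prototxt):
        depth = 0
        # Walk forward from "layer", counting braces until the block closes.
        for end in range(match.start(), len(prototxt)):
            ch = prototxt[end]
            if ch == '{':
                depth += 1
            elif ch == '}':
                depth -= 1
                if depth == 0:
                    block = prototxt[match.start():end + 1]
                    name = re.search(r'name:\s*"([^"]*)"', block)
                    if name:
                        blocks[name.group(1)] = block
                    break
    return blocks


def append_layers(deploy: str, train_val: str, names: Iterable[str]) -> str:
    """Return `deploy` with the named layer blocks from `train_val` appended."""
    wanted = list(names)
    source = extract_layer_blocks(train_val)
    missing = [n for n in wanted if n not in source]
    if missing:
        raise KeyError('layers not found in train_val: %s' % missing)
    extra = '\n'.join(source[n] for n in wanted)
    return deploy.rstrip('\n') + '\n' + extra + '\n'


if __name__ == '__main__':
    # Hypothetical sample input; real files would be read with open(...).read().
    train_val = (
        'layer {\n'
        '  name: "my_fc"\n'
        '  type: "InnerProduct"\n'
        '  bottom: "my_pool"\n'
        '  top: "my_fc"\n'
        '  inner_product_param { num_output: 10 }\n'
        '}\n'
    )
    deploy = 'name: "MyNet"\n'
    print(append_layers(deploy, train_val, ['my_fc']))
```

Appending at the end keeps the deploy file's existing layer order intact; whether the copied layers' `bottom` blobs exist in the deploy network still has to be checked by hand.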
