63
42

More than 5 years have passed since last update.

Depthwise (separable) convolutionとか色々な畳込みの処理時間を比較してみる

Last updated at Posted at 2017-05-02

はじめに

Kerasの作者@fcholletさんのCVPR'17論文XceptionとGoogleのMobileNets論文を読んだにて紹介したdepthwise (separable) convolutionとpointwise convolutionは、畳み込みのカーネルを空間方向とチャネル方向に分離することで、パラメータ数と計算時間を削減していた。
似たようなアプローチとして、Inception V31では空間方向の畳み込みを縦方向と横方向の畳み込みに分離し、畳み込みの受容野を維持しながら、パラメータを削減を行っている。具体的には1x77x1の畳み込みを利用して7x7の畳み込みを近似している。Inception V72では1x33x1の畳み込みが使われている。
ここでは、いくつかの種類の畳み込みの実際の処理時間を比較する。コードは下記にあります。
https://github.com/yu4u/conv-benchmark

(追記1)Keras/TensorFlow実装でも検証したので、結果とコードを追記した。それを受けて、結論や考察も修正を行っている。

(追記2)PyTorchでcudnn.benchmark = Trueおよびcudnn.fastest = Trueのオプションを追加した結果を追加。

(追記3)Dilated convolutionの結果を追記

結論だけ先に書くと、depthwise convolutionは理論上の計算量と実際の処理時間がかなり乖離しているものの、CPU環境であればある程度高速化が見込める。空間方向の畳み込みの分割もCPU環境であればほぼ理論値通りの処理時間の削減につながり、効率的。
一方、GPUの場合は、どちらもあまり効果がなさそうな結果となった(かなりtoy modelなので結論に責任は負いません)。

TensorFlowの効率的な実装とやらでどの程度早くなるんだろうとMobileNetsの論文3を見直してみたら、計算量(Million Mult-Adds)でしか比較がなくて、ちょっとアレー??ってなった。

畳み込みの比較

下記の畳み込みを比較する。

  • conv3x3:最も基本的な3x3の畳み込み
  • conv1x3, conv3x1:3x3の畳み込みを縦と横に分離した畳み込み
  • conv3x3sep:depthwiseの3x3の畳み込み
  • conv1x1:1x1の畳み込み (pointwise convolution)
  • conv5x5:ちょっと大きめの畳み込み
  • conv3x3dilated:dilation=2の3x3の畳み込み

アプローチとしては、上記のいずれかの畳み込みを利用して、入力画像サイズを32x32とした16層のCNNを構築し、バッチサイズを32としてランダムな入力を100回forwardした時間を計測することを、チャネル数を8から64まで変化させて繰り返し行なった。活性化層やBN層は加えていない。学習もしていないため、パラメータは初期値4
PyTorchで実装した。コードは下の方。PyTorch初心者なので変なことしたらツッコミください。

結果

CPUとGPUで処理時間の比較を行った。環境はUbuntu 16.04, CPU: i7-7700 3.60GHz、GPU: GeForce GTX1080。

Summary

conv1x1 conv3x1 conv1x3 conv3x3sep conv3x3 conv5x5 conv3x3dilated
Keras CPU 6.736 14.133 14.043 7.184 43.700 118.898 49.442
Keras GPU 1.135 1.525 1.440 1.556 1.571 2.848 2.008
PyTorch CPU 6.956 17.209 16.916 16.480 50.636 133.781 111.480
PyTorch GPU 0.102 0.180 0.186 1.951 0.230 1.024 0.484

PyTorch

PyTorchでは、Conv2dのパラメータgroupsに入力フィルタ数を指定することでdepthwiseな畳み込みが実現できる。この引数は元々、入力をチャネル方向にgroups (e.g. 2) 分割して、それぞれ異なる畳み込みを行うことを想定したもので、入力フィルタ数まで分割されるような用途はあまり想定されていないと思われる。

CPU

pytorch_cpu.png

conv1x1 conv3x1 conv1x3 conv3x3sep conv3x3 conv5x5 conv3x3dilated
processing time [sec] 6.956 17.209 16.916 16.480 50.636 133.781 111.480
vs 3x3 0.137 0.340 0.334 0.325 1.000 2.642 2.202
theoretical complexity 0.111 0.333 0.333 0.016 1.000 2.778 1.000

上記のテーブルは、チャネル数64のとした場合の処理時間[秒]と、conv3x3の処理時間を1とした際の処理時間、計算量のみを考慮した理論値の処理時間比をまとめている。

まず、畳み込みのサイズに関しては、理論値通り畳み込みのカーネルサイズが小さくなると、それに応じて高速化されることが分かる。なので、5x5の畳み込みを3x3の畳み込み2つに置き換えたり、3x3の畳み込みを3x1と1x3の畳み込みに置き換えたりすることは高速化につながることが分かる。

一方、depthwise (separable) な畳み込みは理論値と多く乖離して遅くなっている。1x1については(1x1のdepthwise畳み込みとか意味ないだろというツッコミはさておき)逆に遅くなっている始末である。これは恐らくフレームワークの実装の問題で、Cholletさんも論文5で下記のように言及しているので、depthwise専用の実装で高速化されるのではないかと思う。

... This is made practical by the efficient depthwise convolution implementation available in TensorFlow.

ちなみにPyTorchの実装ではseparable convolutionを利用したが、これは例えば2分割とかそういうレベルでの利用を前提としたもので、完全にdepthwiseな利用は想定していないのだと思われる。

とはいえ、conv1x1+conv3x3sepでもconv3x3よりも処理時間はかなり短くなっているので、利用する価値はありそう(精度低下との兼ね合いがどうなるかはここでは検証できていない)。

GPU

pytorch_gpu.png

conv1x1 conv3x1 conv1x3 conv3x3sep conv3x3 conv5x5
処理時間[秒] 0.102 0.173 0.169 3.786 0.230 1.108
conv3x3比 0.441 0.750 0.733 16.447 1.000 4.816
理論値 0.111 0.333 0.333 0.016 1.000 2.778

CPUの場合は、カーネルサイズを小さくした際の高速化が理論値に近い値で実現できていたが、GPUの場合はそうではない。GPUの特性上、生半可に小さなカーネルを重ねて使うくらいなら、まとめて計算したほうが効率が良いということなのだろう。なので3x3の畳み込みを3x1と1x3の畳み込みに置き換えるのは得策ではなさそう。ただ、5x5に関しては、理論値以上に時間がかかっているので、こちらは3x3の畳み込み2つに置き換えるのは良さそうである。

Depthwiseに関しては、とりあえず使わないほうが良さそうなレベルで遅くなっている。Depthwiseの実装が遅いと言うより、通常の畳み込みが最適化されまくっていて早いという感じかな…?

cudnn.benchmark = Trueとした場合

conv1x1 conv3x1 conv1x3 conv3x3sep conv3x3 conv5x5
processing time [sec] 0.096 0.173 0.169 1.716 0.229 0.984
vs 3x3 0.418 0.753 0.735 7.485 1.000 4.291
theoretical complexity 0.111 0.333 0.333 0.016 1.000 2.778

少し早くなった。遅かったconv3x3sepが2倍くらい早くなったがそれでも遅い。

cudnn.benchmark = Truecudnn.fastest = True(メモリを気にせず高速化するオプション)とした場合

conv1x1 conv3x1 conv1x3 conv3x3sep conv3x3 conv5x5 conv3x3dilated
processing time [sec] 0.102 0.180 0.186 1.951 0.230 1.024 0.484
vs 3x3 0.444 0.780 0.809 8.464 1.000 4.446 2.101
theoretical complexity 0.111 0.333 0.333 0.016 1.000 2.778 1.000

あまり変わらず

Keras/TensorFlow

Kerasでは、TensorFlowのdepthwise_conv2dを利用してdepthwiseな畳み込みが実現できる。Kerasでは直接I/Fが提供されていないので、直接tfバックエンドを利用する必要がある。なお、KerasにはSeparableConvolution2Dは存在し、これはdepthwise畳み込みを行った後、更にpointwise畳み込みを行うものである。
ここではKerasでMobileNetsを実装されている方の、depthwise_conv2dのKerasラッパーを利用した。
https://github.com/rcmalli/keras-mobilenet

CPU

keras_cpu.png

conv1x1 conv3x1 conv1x3 conv3x3sep conv3x3 conv5x5 conv3x3dilated
processing time [sec] 6.736 14.133 14.043 7.184 43.700 118.898 49.442
vs 3x3 0.154 0.323 0.321 0.164 1.000 2.721 1.131
theoretical complexity 0.111 0.333 0.333 0.016 1.000 2.778 1.000

CPUの傾向は、PyTorchと同様だが、conv3x3sepが早い。計算量的にはconv1x1に対して無視できるはずであるが、ここでの処理時間ではほぼ同等となっている。Tensorが (height, width, channel) の順序になっていると効率的な計算ができなそうな気がするので、その辺りで限界があるのだろうか。

GPU

keras_gpu.png

conv1x1 conv3x1 conv1x3 conv3x3sep conv3x3 conv5x5 conv3x3dilated
processing time [sec] 1.135 1.525 1.440 1.556 1.571 2.848 2.008
vs 3x3 0.722 0.971 0.916 0.990 1.000 1.812 1.278
theoretical complexity 0.111 0.333 0.333 0.016 1.000 2.778 1.000

GPUだと5x5ですらフィルタを分割しないほうが良さそうな傾向。1サンプルなので、何とも言えないが、PyTorchも(5x5以外は)同様の傾向なので、素直に普通のconv3x3に使っとけということだろうか。当然CPUで推論するモデルをGPUで学習する場合は別の話だけど。

所感

というわけで、depthwise convolutionは現状の実装ではGPUでは高速化に寄与しなさそう。パラメータ数の削減に関しては、当然理論値通り削減できるので、恩恵を預かることはできる。

CPUであれば計算量のオーダーまでは行かないものの、ある程度高速化が見込めるので、精度との兼ね合いで使っていけそう。

元の論文では、separable convolutionとpointwise convolutionは計算量的にpointwise convolutionがボトルネックとしていたが、現状圧倒的にseparable convolutionのほうがボトルネックとなっている。この辺り、どのくらいまで実装で高速化できるのだろうか。メモリの持ち方的に、チャネル軸が先とか後とかで変わるのかしらん。
空間方向の畳込みの分割に関しても、CPUであれば高速化に寄与する。高速化が必要なケースは非GPUだと思うので、やりましょう。

コード

最新のコードは下記を参照

PyTorch

%matplotlib inline 
import torch
import torch.nn as nn
from torch.autograd import Variable
import time
import matplotlib.pyplot as plt
import torch.backends.cudnn as cudnn


def get_conv(filter_num, param):
    return nn.Conv2d(filter_num, filter_num, param["kernel_size"],
                     groups=filter_num if param["sep"] else 1,
                     padding=param["padding"], bias=False)

image_size = 32
batch_num = 32
layer_num = 16
filter_nums = [8, 16, 32, 64]
use_gpu = True
# use_gpu = False

params = {
    "conv5x5": {"kernel_size": 5, "sep": False, "padding": 2},
    "conv3x3": {"kernel_size": 3, "sep": False, "padding": 1},
    "conv3x3sep": {"kernel_size": 3, "sep": True, "padding": 1},
    "conv1x1": {"kernel_size": 1, "sep": False, "padding": 0},
    "conv1x3": {"kernel_size": (1, 3), "sep": False, "padding": (0, 1)},
    "conv3x1": {"kernel_size": (3, 1), "sep": False, "padding": (1, 0)}
}

results = {}

for name, param in params.items():
    timings = []

    for filter_num in filter_nums:
        layers = [get_conv(filter_num, param) for _ in range(layer_num)]
        model = nn.Sequential(*layers)

        if use_gpu:
            cudnn.benchmark = True
            cudnn.fastest = True
            input = Variable(torch.randn(batch_num, filter_num, image_size, image_size).cuda())
            model.cuda()            
        else:
            input = Variable(torch.randn(batch_num, filter_num, image_size, image_size))

        out = model(input)

        start = time.time()

        for i in range(100):
            out = model(input)

        elapsed_time = time.time() - start
        timings.append(elapsed_time)

    results[name] = timings

Keras

%matplotlib inline 
from keras.models import Sequential
from keras.layers.convolutional import Conv2D
import numpy as np
import time
import matplotlib.pyplot as plt
from depthwise_conv2d import DepthwiseConv2D
import tensorflow as tf


def get_conv(filter_num, param):
    return param["conv"](filter_num, param["kernel_size"], padding='same', use_bias=False)

image_size = 32
batch_num = 32
layer_num = 16
filter_nums = [8, 16, 32, 64]
device_mode = "/gpu:0"
# device_mode = "/cpu:0"

params = {
    "conv5x5": {"kernel_size": (5, 5), "conv": Conv2D},
    "conv3x3": {"kernel_size": (3, 3), "conv": Conv2D},
    "conv3x3sep": {"kernel_size": (3, 3), "conv": DepthwiseConv2D},
    "conv1x1": {"kernel_size": (1, 1), "conv": Conv2D},
    "conv1x3": {"kernel_size": (1, 3), "conv": Conv2D},
    "conv3x1": {"kernel_size": (3, 1), "conv": Conv2D}
}

results = {}

for name, param in params.items():
    timings = []

    for filter_num in filter_nums:
        with tf.device(device_mode):
            layers = [get_conv(filter_num, param) for _ in range(layer_num)]
            input_shape = (image_size, image_size, filter_num)
            model = Sequential()
            model.add(param["conv"](filter_num, param["kernel_size"], padding='same',
                                          input_shape=input_shape, use_bias=False))

            for _ in range(layer_num - 1):
                model.add(get_conv(filter_num, param))

            input = np.random.randn(batch_num, image_size, image_size, filter_num)
            out = model.predict(input, batch_size=batch_num)

            start = time.time()

            for i in range(100):
                out = model.predict(input, batch_size=batch_num)

            elapsed_time = time.time() - start
            timings.append(elapsed_time)

    results[name] = timings

グラフ

for name in params.keys():
    plt.plot(filter_nums, results[name], label=name, marker="o")

plt.xlabel("number of filters")
plt.ylabel("processing time")
plt.legend()

処理時間

convs = ["conv1x1", "conv3x1", "conv1x3", "conv3x3sep", "conv3x3", "conv5x5"]

print("||" + "|".join(convs) + "|")
print("|:-:|:-:|:-:|:-:|:-:|:-:|:-:|")
print("|processing time [sec]|" + "|".join(["{:0.3f}".format(results[n][-1]) for n in convs]) + "|")
print("|vs 3x3|" + "|".join(["{:0.3f}".format(results[n][-1]/results["conv3x3"][-1]) for n in convs]) + "|")
print("|theoretical complexity|0.111|0.333|0.333|0.016|1.000|2.778|")

参考文献


  1. C. Szegedy, V. Vanhoucke, S. Ioffe, J. Shlens, Z. Wojna, "Rethinking the Inception Architecture for Computer Vision," in Proc. of CVPR, 2016. 

  2. http://lsun.cs.princeton.edu/slides/Christian.pdf 

  3. A. G. Howard, M. Zhu, B. Chen, D. Kalenichenko, W. Wang, T. Weyand, M. Andreetto, and H. Adam, "MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications," in arXiv:1704.04861, 2017. 

  4. https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/conv.py#L39 

  5. F. Chollet, "Xception: Deep Learning with Depthwise Separable Convolutions," in Proc. of CVPR, 2017. 

63
42
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
63
42