39
12

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

[PyTorch] torch.bmmよりも速く、batchごとに内積を計算する方法があった話

Last updated at Posted at 2018-09-07

PyTorchでbatchごとに内積を計算したい!

導入

Deepなモデルを構築していると、attentionの計算をする時など、ベクトル同士の内積をバッチごとに計算したい時があると思います。

(一応ベクトルの内積を計算する用のメソッドtorch.dotはありますが、これは1Dしかサポートしていません。)

僕はそんな時、いつもtorch.bmmを使っていました。

しかし、最近torch.bmmよりも要素積を取ってから和を取った方が計算が速いと言うポストをforumで見つけたので実際に実験しました。

参考:Dot product batch-wise

実験

環境

python: 3.6.0
pytorch: 0.4.1

設定

実験をする前に実験の設定をする。

(500, 500)のサイズの2つのtorch.Tensor(それぞれaとbとおく)同士のバッチごとの内積を計算して、(500, 1)のサイズのtorch.Tensorを出力を計算するまでの時間を計測する。

今回比較する計算方法は以下の2つ

  • (a*b).sum(1, keepdim=True)
  • torch.bmm(a.unsqueeze(1), b.unsqueeze(2)).squeeze(2)

念のため結果が同じになることを確認

In [1]: import torch

In [2]: a = torch.randn(500, 500, dtype=torch.float, device='cpu')

In [3]: b = torch.randn(500, 500, dtype=torch.float, device='cpu')

# 最初の5要素が同じことと、出力の次元が同じになることを確認
In [4]: (a*b).sum(1, keepdim=True)[:5]
Out[4]: 
tensor([[ 19.5090],
        [-11.9383],
        [ 12.4870],
        [ 14.5671],
        [-49.7218]])

In [5]: torch.bmm(a.unsqueeze(1), b.unsqueeze(2)).squeeze(2)[:5]
Out[5]: 
tensor([[ 19.5090],
        [-11.9383],
        [ 12.4870],
        [ 14.5671],
        [-49.7218]])

では早速実験していく!

①まずはCPUで

In [6]: %timeit (a*b).sum(1, keepdim=True)
26.4 µs ± 1.66 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

In [7]: %timeit torch.bmm(a.unsqueeze(1), b.unsqueeze(2)).squeeze(2)
964 µs ± 619 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)

確かにtorch.bmmの方が40倍以上遅い!

②次はGPU

In [8]: a = a.cuda()

In [9]: b = b.cuda()

In [10]: %timeit (a*b).sum(1, keepdim=True)
25.9 µs ± 57.3 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

In [11]: %timeit torch.bmm(a.unsqueeze(1), b.unsqueeze(2)).squeeze(2)
608 µs ± 88.9 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

GPUでもはtorch.bmmの方が遅い・・・!

同じ計算をしているのにここまで差が開くとは正直驚きです。

③backwardは?

とはいえ微分の時間も含めたら逆転するかも?と思い、先ほどのように(500, 1)のテンソルにしてからsumを取って微分するまでの時間を計測します。

In [12]: a.requires_grad = True

In [13]: b.requires_grad = True

In [14]: a
Out[14]: 
tensor([[ 0.0782, -0.5167,  0.0012,  ..., -0.7461, -1.6369, -0.6827],
        [-1.0358, -0.9696, -0.7222,  ..., -0.6055,  0.6633,  0.0502],
        [ 0.9818,  0.4592,  0.4723,  ...,  1.0542,  1.0862,  0.5680],
        ...,
        [-0.4467,  0.8315,  0.6506,  ...,  0.1161,  0.5799, -1.2523],
        [ 1.3338, -0.5920, -1.4002,  ...,  0.0069,  1.0878, -0.9324],
        [-0.0193, -2.2005,  0.3563,  ...,  0.3481,  0.1945, -0.8756]],
       device='cuda:0', requires_grad=True)

In [15]: b
Out[15]: 
tensor([[-0.4584, -1.1394, -0.3559,  ...,  0.7641,  0.0415,  0.2294],
        [-0.6206,  0.3149,  1.6382,  ..., -0.3534, -0.3121, -0.4797],
        [-0.9203,  0.0587,  0.5146,  ..., -1.4103, -0.5372,  0.3373],
        ...,
        [-0.3318, -1.6943,  0.2874,  ..., -0.5378, -1.6260,  0.7773],
        [-0.0560,  0.6894, -0.7104,  ..., -2.6248,  0.4128,  1.3808],
        [ 2.1583, -1.6799,  0.3402,  ...,  2.2380,  1.7078, -1.7916]],
       device='cuda:0', requires_grad=True)

In [16]: %timeit (a*b).sum(1, keepdim=True).sum().backward()
163 µs ± 13.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

In [17]: %timeit torch.bmm(a.unsqueeze(1), b.unsqueeze(2)).squeeze(2).sum().backward()
1.13 ms ± 15.3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

微分を含めた合計時間もtorch.bmmの惨敗でした。

一応、②と③の差分をとって、微分にかかった時間も比較

  • (a*b).sum(1, keepdim=True)
    • 163 µs - 25.9 µs = 137.1 µs
  • torch.bmm(a.unsqueeze(1), b.unsqueeze(2)).squeeze(2)
    • 1.13 ms - 608 µs = 552 µs

②と③での違う試行における経過時間同士の引き算は厳密にはアウトな気がしますが、それでも微分だけにかかった時間は4倍近く差がついているように見えます。

結果

〜GPU〜

計算方法 t(内積) t(内積+微分) t(微分)
要素積+和 25.9 µs 163 µs (137.1 µs)
torch.bmm 608 µs 1.13 ms (552 µs)

〜CPU〜

計算方法 t(内積)
要素積+和 26.4 µs
torch.bmm 964 µs

torch.bmmの惨敗ですね!!
なんてこったい

蛇足

今度暇があったらtorch.bmmのソースコードを見に行こうと思う。

その時何かわかったらまたここで報告します〜

では

39
12
2

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
39
12

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?