3
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

einsum は割と早い?

Last updated at Posted at 2020-10-04

数パターンで行列・ベクトル演算時間を比較

やるつもりなかったけど、やり始めたら色々試してしまった。。。勿体ないのでメモ。

  \boldsymbol{A}_k = \left(
    \begin{array}{c}
      a_{k1} \\
      a_{k2} \\
      \vdots \\
      a_{kn}
    \end{array}
  \right) \\

\boldsymbol B_k = \left(
\begin{array}{ccccc}
a_{k11} & \cdots & a_{k1i} & \cdots & a_{k1n} \\
 \vdots & \ddots &         &        & \vdots \\
a_{ki1} &        & a_{kii} &        & a_{kin} \\
 \vdots &        &         & \ddots & \vdots \\
a_{kn1} & \cdots & a_{kni} & \cdots & a_{knn} \\
\end{array}
\right) \\

の場合に、

\boldsymbol{A}_k^T \boldsymbol{B}_k \\
\boldsymbol{A}_k^T \boldsymbol{B}_k \boldsymbol{A}_k \\
\boldsymbol{B}_k \boldsymbol{A}_k \boldsymbol{A}_k^T \boldsymbol{B}_k

の計算を$k$全部に対して一括で計算してみた。
これだけやっても numpy の 行列計算はややっこしい。
einsum が割と早く、慣れれば書きやすいかもしれない。
(それでも、結構書き間違えて間違った計算してしまったけど)

Ak^T Bk の場合

\boldsymbol{A}_k^T \boldsymbol{B}_k \\

別の書き方したら、変わるかもしれないけど、こんな感じ。

  1. この場合 einsum が圧倒的に早い。(func3)
  2. 次点は matmul か numba + 式。(func4, func2)
  3. 最後は numba + for文。(func1)

numbaはコンパイルの制約がある割に、早くない。書き方が悪い?
(list から np.array に変換とか、np.newaxis が使えないとか)

###コード

import numba
from numba import njit

A = np.random.random((257, 6))
B = np.random.random((257, 6, 6))

@njit(cache=True)
def func1(A, B):
    I, J, K = np.shape(B)
    C = np.zeros((I, K))
    for i in range(I):
        C[i] = A[i].dot(B[i])
    return C
    #return np.array([A[i].dot(B[i]) for i in range(len(A))])

@njit(cache=True)
def func2(A, B):
    return (np.expand_dims(A, -1) * B).sum(1)
def func2a(A, B):
    return (np.expand_dims(A, -1) * B).sum(1)
def func3(A, B):
    return np.einsum('km,kmn->kn', A, B)
def func4(A, B):
    return np.matmul(np.expand_dims(A, 1), B).squeeze()

C1 = func1(A, B)
C2 = func2(A, B)
C3 = func3(A, B)
C4 = func4(A, B)
print("allclose", np.allclose(C1, C2), np.allclose(C1, C3), np.allclose(C1, C4))

print("func1")
%timeit func1(A, B)
print("func2")
%timeit func2(A, B)
print("func2a")
%timeit func2a(A, B)
print("func3")
%timeit func3(A, B)
print("func4")
%timeit func4(A, B)

###実行結果。

allclose True True True
func1
94.9 µs ± 3.41 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
func2
43.3 µs ± 2.47 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
func2a
114 µs ± 1.53 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
func3
20.1 µs ± 907 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
func4
42.8 µs ± 2.87 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

Ak^T Bk Akの場合

\boldsymbol{A}_k^T \boldsymbol{B}_k \boldsymbol{A}_k \\

einsumが遅くなってきた。それでもnumba + for文よりは早いけど。

  1. numba + 式(func2)
  2. einsum (func3)
  3. 最後は numba + for文。(func1)

einsumは最適化オプションがあるみたいなので指定してみたけど、遅くなった。
あと、matmulを使った書き方はわからなかった。

###コード

import numba
from numba import jit, njit, prange

A = np.random.random((257, 6))
B = np.random.random((257, 6, 6))

@njit(cache=True)
def func1(A, B):
    I, J, K = np.shape(B)
    C = np.zeros(I)
    for i in range(I):
        C[i] = A[i].dot(B[i]).dot(A[i])
    return C

@njit(cache=True)
def func2(A, B):
    return ((np.expand_dims(A, -1) * B).sum(1) * A).sum(1)
def func2a(A, B):
    return ((np.expand_dims(A, -1) * B).sum(1) * A).sum(1)
def func3(A, B):
    return np.einsum('km,kmn,kn->k', A, B, A)
def func3a(A, B):
    return np.einsum('km,kmn,kn->k', A, B, A, optimize=True)
C1 = func1(A, B)
C2 = func2(A, B)
C3 = func3(A, B)
print("allclose", np.allclose(C1, C2), np.allclose(C1, C3))

print("func1")
%timeit func1(A, B)
print("func2")
%timeit func2(A, B)
print("func2a")
%timeit func2a(A, B)
print("func3")
%timeit func3(A, B)
print("func3a")
%timeit func3a(A, B)

実行結果

allclose True True
func1
101 µs ± 1.5 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
func2
45.4 µs ± 1.11 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
func2a
120 µs ± 1.45 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
func3
56.2 µs ± 500 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
func3a
139 µs ± 3.17 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

Bk Ak Ak^T Bk の場合

\boldsymbol{B}_k \boldsymbol{A}_k \boldsymbol{A}_k^T \boldsymbol{B}_k

einsumはさらに遅くなった。最適化すると大分早くなるけど、numba + 式の方が早い。

  1. numba + 式(func2)
  2. einsum (func3)
  3. 最後は numba + for文。(func1)

einsum_path で計算順序の最適化ができるみたい。詳しくは見ていないけど。

import numba
from numba import jit, njit, prange

A = np.random.random((257, 6))
B = np.random.random((257, 6, 6))

@njit(cache=True)
def func1(A, B):
    I, J, K = np.shape(B)
    C = np.zeros((I, J, K))
    for i in range(I):
        C[i] = np.outer(B[i] @ A[i], A[i] @ B[i])
    return C

@njit(cache=True)
def func2(A, B):
    return np.expand_dims((B * np.expand_dims(A, 1)).sum(2), -1) * np.expand_dims((np.expand_dims(A, -1) * B).sum(1), 1)
def func2a(A, B):
    return np.expand_dims((B * np.expand_dims(A, 1)).sum(2), -1) * np.expand_dims((np.expand_dims(A, -1) * B).sum(1), 1)
def func3(A, B):
    return np.einsum('kab,kb,kc,kcd->kad', B, A, A, B, optimize=['einsum_path', (0, 1), (0, 1), (0, 1)])
def func3a(A, B):
    return np.einsum('kab,kb,kc,kcd->kad', B, A, A, B, optimize=True)

C1 = func1(A, B)
C2 = func2(A, B)
C3 = func3(A, B)
print("allclose", np.allclose(C1, C2), np.allclose(C1, C3))

print("func1")
%timeit func1(A, B)
print("func2")
%timeit func2(A, B)
print("func2a")
%timeit func2a(A, B)
print("func3")
%timeit func3(A, B)
print("func3a")
%timeit func3a(A, B)

###実行結果

allclose True True
func1
335 µs ± 7.15 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
func2
97.8 µs ± 1.03 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
func2a
246 µs ± 20.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
func3
154 µs ± 3.01 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
func3a
250 µs ± 6.65 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
3
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?