LoginSignup
23
14

More than 3 years have passed since last update.

einsum(アインシュタインの縮約記法)による多次元配列の掛け算

Posted at

einsumを使うと多次元配列の掛け算が簡単に行えます。
記述方法には、癖がありますが、覚えてしまえば難しくありません。
einsumには、他にもいろいろな演算を行うことができますが、ここでは、多次元配列の掛け算について書きます。

2次元配列

アダマール積

各要素の積

\begin{pmatrix}
1 & 2 & 3 \\
4 & 5 & 6 \\
\end{pmatrix}
×
\begin{pmatrix}
1 & 2 & 3 \\
4 & 5 & 6 \\
\end{pmatrix}
=
\begin{pmatrix}
1 & 4 & 9 \\
16 & 25 & 36 \\
\end{pmatrix}
x = np.array([[1.,2.,3.],[4.,5.,6.]])
y = np.array([[1.,2.,3.],[4.,5.,6.]])

演算子

x*y
array([[ 1.,  4.,  9.],
       [16., 25., 36.]])

ループ

要素ごとに計算すると、以下のようになります。

# 要素ごとに計算
z = np.zeros((2,3))
for i in range(2):
    for j in range(3):
        z[i,j] += x[i,j] * y[i,j]
z
array([[ 1.,  4.,  9.],
       [16., 25., 36.]])

einsum

for文中の計算式の以下の部分の添え字をそのまま書けば良いです。
z[i,j] += x[i,j] * y[i,j]
上の例では、以下のように書きます。
xの添え字,yの添え字->zの添え字
(添え字に利用する文字な何でも構いません。)

# einsumで計算
np.einsum("ij,ij->ij", x, y)

同じ結果になりました。

array([[ 1.,  4.,  9.],
       [16., 25., 36.]])

内積

\begin{pmatrix}
1 & 2 & 3 \\
4 & 5 & 6 \\
\end{pmatrix}
\begin{pmatrix}
1 & 2 & 3 & 4 \\
5 & 6 & 7 & 8 \\
9 & 10 & 11 & 12 \\
\end{pmatrix}
=
\begin{pmatrix}
38 & 44 & 50 & 56 \\
83 & 98 & 113 & 128 \\
\end{pmatrix}
x = np.array([[1.,2.,3.],[4.,5.,6.]])
y = np.array([[1.,2.,3.,4.],[5.,6.,7.,8.],[9.,10.,11.,12.]])

演算

2次元の配列では、np.dotまたはnp.matmulで内積を計算できます。

dot

# ドット積で計算
np.dot(x, y)
array([[ 38.,  44.,  50.,  56.],
       [ 83.,  98., 113., 128.]])

matmul

# matmulで計算
np.matmul(x, y)
array([[ 38.,  44.,  50.,  56.],
       [ 83.,  98., 113., 128.]])

ループ

# 要素ごとに計算
z = np.zeros((2,4))
for i in range(3):
    for j in range(2):
        for k in range(4):
            z[j,k] += x[j,i] * y[i,k]
z
array([[ 38.,  44.,  50.,  56.],
       [ 83.,  98., 113., 128.]])

einsum

for文中の計算式の添え字をそのまま書けばよいのでした。
z[j,k] += x[j,i] * y[i,k]

# einsumで計算
np.einsum("ji,ik->jk", x, y)
array([[ 38.,  44.,  50.,  56.],
       [ 83.,  98., 113., 128.]])

3次元配列

2次元配列のバッチ処理を考えます。

バッチの次元が、1次元目の場合の内積

x = np.array([[[1.,2.,3.],[4.,5.,6.]],[[1.,2.,3.],[4.,5.,6.]]])
print("x.shape=", x.shape)
print("x=")
print(x)
x.shape= (2, 2, 3)
x=
[[[1. 2. 3.]
  [4. 5. 6.]]

 [[1. 2. 3.]
  [4. 5. 6.]]]
y = np.array([[[1.,2.,3.,4.],[5.,6.,7.,8.],[9.,10.,11.,12.]],[[1.,2.,3.,4.],[5.,6.,7.,8.],[9.,10.,11.,12.]]])
print("y.shape=", y.shape)
print("y=")
print(y)
y.shape= (2, 3, 4)
y=
[[[ 1.  2.  3.  4.]
  [ 5.  6.  7.  8.]
  [ 9. 10. 11. 12.]]

 [[ 1.  2.  3.  4.]
  [ 5.  6.  7.  8.]
  [ 9. 10. 11. 12.]]]

内積の結果は、以下となります。

array([[[ 38.,  44.,  50.,  56.],
        [ 83.,  98., 113., 128.]],

       [[ 38.,  44.,  50.,  56.],
        [ 83.,  98., 113., 128.]]])

演算

dot

# ドット積で計算
np.dot(x, y)
array([[[[ 38.,  44.,  50.,  56.],
         [ 38.,  44.,  50.,  56.]],

        [[ 83.,  98., 113., 128.],
         [ 83.,  98., 113., 128.]]],


       [[[ 38.,  44.,  50.,  56.],
         [ 38.,  44.,  50.,  56.]],

        [[ 83.,  98., 113., 128.],
         [ 83.,  98., 113., 128.]]]])

正しく計算できませんでした。

matmul

バッチが1次元目の場合、np.matmulで計算できます。

# matmulで計算
np.matmul(x, y)
array([[[ 38.,  44.,  50.,  56.],
        [ 83.,  98., 113., 128.]],

       [[ 38.,  44.,  50.,  56.],
        [ 83.,  98., 113., 128.]]])

ループ

# 要素ごとに計算
z = np.zeros((2,2,4))
for i in range(2):
    for j in range(3):
        for k in range(2):
            for l in range(4):
                z[i,k,l] += x[i,k,j] * y[i,j,l]
z
array([[[ 38.,  44.,  50.,  56.],
        [ 83.,  98., 113., 128.]],

       [[ 38.,  44.,  50.,  56.],
        [ 83.,  98., 113., 128.]]])

einsum

3次元でも同じです。for文中の計算式の添え字をそのまま書きます。
z[i,k,l] += x[i,k,j] * y[i,j,l]

# einsumで計算
np.einsum("ikj,ijl->ikl", x, y)
array([[[ 38.,  44.,  50.,  56.],
        [ 83.,  98., 113., 128.]],

       [[ 38.,  44.,  50.,  56.],
        [ 83.,  98., 113., 128.]]])

バッチの次元が、2次元目の場合の内積

RNNなどで、バッチの次元を2次元目としたい場合があります。

xt = x.transpose(1,0,2)
print("xt.shape=", xt.shape)
print("xt=")
xt
xt.shape= (2, 2, 3)
xt=
array([[[1., 2., 3.],
        [1., 2., 3.]],

       [[4., 5., 6.],
        [4., 5., 6.]]])
yt = y.transpose(1,0,2)
print("yt.shape=", yt.shape)
print("yt=")
yt
yt.shape= (3, 2, 4)
yt=
array([[[ 1.,  2.,  3.,  4.],
        [ 1.,  2.,  3.,  4.]],

       [[ 5.,  6.,  7.,  8.],
        [ 5.,  6.,  7.,  8.]],

       [[ 9., 10., 11., 12.],
        [ 9., 10., 11., 12.]]])

内積の結果は、以下となります。

array([[[ 38.,  44.,  50.,  56.],
        [ 38.,  44.,  50.,  56.]],

       [[ 83.,  98., 113., 128.],
        [ 83.,  98., 113., 128.]]])

演算

dot

# ドット積で計算
np.dot(xt, yt)
ValueError                                Traceback (most recent call last)
<ipython-input-24-a174c5fa02ae> in <module>
      1 # ドット積で計算
----> 2 np.dot(xt, yt)

<__array_function__ internals> in dot(*args, **kwargs)

ValueError: shapes (2,2,3) and (3,2,4) not aligned: 3 (dim 2) != 2 (dim 1)

エラーとなりました。

matmul

# matmulで計算
np.matmul(xt, yt)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-25-281cba2a720e> in <module>
      1 # matmulで計算
----> 2 np.matmul(xt, yt)

ValueError: matmul: Input operand 1 has a mismatch in its core dimension 0, with gufunc signature (n?,k),(k,m?)->(n?,m?) (size 2 is different from 3)

こちらもエラーとなりました。

ループ

# 要素ごとに計算
zt = np.zeros((2,2,4))
for i in range(2):
    for j in range(3):
        for k in range(2):
            for l in range(4):
                zt[k,i,l] += xt[k,i,j] * yt[j,i,l]
zt
array([[[ 38.,  44.,  50.,  56.],
        [ 38.,  44.,  50.,  56.]],

       [[ 83.,  98., 113., 128.],
        [ 83.,  98., 113., 128.]]])

結果を転置し、バッチを1次元目とすると結果が同じとなります。

zt.transpose(1,0,2)
array([[[ 38.,  44.,  50.,  56.],
        [ 83.,  98., 113., 128.]],

       [[ 38.,  44.,  50.,  56.],
        [ 83.,  98., 113., 128.]]])

einsum

for文中の計算式の添え字をそのままですね。
zt[k,i,l] += xt[k,i,j] * yt[j,i,l]

# einsumで計算
np.einsum("kij,jil->kil", xt, yt)
array([[[ 38.,  44.,  50.,  56.],
        [ 38.,  44.,  50.,  56.]],

       [[ 83.,  98., 113., 128.],
        [ 83.,  98., 113., 128.]]])

このようにeinsumを利用すれば、簡単に多次元配列の演算を行うことができます。
転置やtransposeにより変換を行い計算することもできますが、einsumを利用すればそのまま計算することができます。
einsumでは演算に時間がかかる場合もあるようなので、もし、性能的に問題なければ、einsumを利用すれば非常にシンプルに記述することができます。

23
14
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
23
14