
【PyTorch】The behavior of torch.einsum


Introduction

Einsum can express a wide variety of matrix operations. Ordinary matrix products and inner products impose constraints on the shapes of the operands, but with einsum you can compute over matrices of any shape simply by writing subscripts.
That flexibility makes its behavior hard to follow, so in this article I re-implement the same computations with for loops to make the behavior easier to understand.

torch.einsum official documentation
https://pytorch.org/docs/stable/generated/torch.einsum.html
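As a quick point of comparison (my own addition, not from the original article), the ordinary matrix product can itself be written with einsum: the shared subscript k must have matching sizes in both operands and is summed over because it does not appear in the output.

import torch

# A minimal sketch: "ik,kj->ij" is a plain matrix product.
x = torch.arange(1, 7).reshape(2, 3)   # shape (2, 3)
y = torch.arange(1, 13).reshape(3, 4)  # shape (3, 4)
print(torch.einsum("ik,kj->ij", x, y))
print(torch.matmul(x, y))  # same result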

Results of running einsum

Let's apply einsum to a, of shape (4×3), and b, of shape (2×4).

import torch

a = torch.tensor(range(1, 13)).reshape(4, 3)  # shape (4, 3)
b = torch.tensor(range(1, 9)).reshape(2, 4)   # shape (2, 4)
print(a)
print(b)
# Neither k nor c appears in the output subscripts "nv", so both indices are summed over.
einsum1 = torch.einsum("nk,cv->nv", (a, b))  # shape (4, 4)
einsum2 = torch.einsum("nk,cv->nv", (b, a))  # shape (2, 3)
print(einsum1)
print(einsum2)

Output
a and b

tensor([[ 1,  2,  3],
        [ 4,  5,  6],
        [ 7,  8,  9],
        [10, 11, 12]])
tensor([[1, 2, 3, 4],
        [5, 6, 7, 8]])

einsum1 and einsum2

tensor([[ 36,  48,  60,  72],
        [ 90, 120, 150, 180],
        [144, 192, 240, 288],
        [198, 264, 330, 396]])
tensor([[220, 260, 300],
        [572, 676, 780]])

Implementation with for loops

import numpy as np

a = np.array(range(1, 13)).reshape(4, 3)
b = np.array(range(1, 9)).reshape(2, 4)
print(a)
print(b)

print("einsum('nk,cv->nv', (a, b))")
# einsum.shape = (n, v) = (4, 4)
einsum = np.zeros((4, 4))
for n in range(a.shape[0]):
    for k in range(a.shape[1]):
        for c in range(b.shape[0]):
            for v in range(b.shape[1]):
                einsum[n, v] += a[n, k] * b[c, v]
print(einsum)

print("einsum('nk,cv->nv', (b, a))")
# einsum.shape = (n, v) = (2, 3)
einsum = np.zeros((2, 3))
for n in range(b.shape[0]):
    for k in range(b.shape[1]):
        for c in range(a.shape[0]):
            for v in range(a.shape[1]):
                einsum[n, v] += b[n, k] * a[c, v]
print(einsum)

Output

[[ 1  2  3]
 [ 4  5  6]
 [ 7  8  9]
 [10 11 12]]
[[1 2 3 4]
 [5 6 7 8]]
einsum('nk,cv->nv', (a, b))
[[ 36.  48.  60.  72.]
 [ 90. 120. 150. 180.]
 [144. 192. 240. 288.]
 [198. 264. 330. 396.]]
einsum('nk,cv->nv', (b, a))
[[220. 260. 300.]
 [572. 676. 780.]]
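As a cross-check (my own addition), note that neither k nor c appears on the right-hand side of "nk,cv->nv", so both indices are summed independently and the result is simply the outer product of a's row sums and b's column sums:

import numpy as np

a = np.arange(1, 13).reshape(4, 3)
b = np.arange(1, 9).reshape(2, 4)

# Row sums of a have shape (4,), column sums of b have shape (4,);
# their outer product has shape (4, 4) and matches einsum1 above.
expected = np.outer(a.sum(axis=1), b.sum(axis=0))
print(np.array_equal(np.einsum("nk,cv->nv", a, b), expected))  # True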

Convenient usage

The for-loop implementation above used a deliberately unusual subscript pattern just to make the behavior clear. In practice, I think einsum is more likely to be used in situations like the following.

Multiplying every element of a batch by the same matrix

For example, suppose you have a batch of size 4 whose elements are 4×3 matrices, and you want to apply some operation to each of those 4×3 matrices.
Below is an example that zeroes out the upper half of each 4×3 matrix.

import torch

a = torch.tensor(range(1, 13)).reshape(4, 3)
batch4_a = torch.stack((a, a, a, a))
print(batch4_a)

upper_mask = torch.cat([torch.zeros(2, 3), torch.ones(2, 3)])
print(upper_mask)

# All indices appear in the output subscripts, so nothing is summed;
# upper_mask (4, 3) is applied to every element of batch4_a (4, 4, 3).
einsum = torch.einsum("bij,ij->bij", (batch4_a, upper_mask))
print(einsum)

Output
batch4_a

tensor([[[ 1,  2,  3],
         [ 4,  5,  6],
         [ 7,  8,  9],
         [10, 11, 12]],

        [[ 1,  2,  3],
         [ 4,  5,  6],
         [ 7,  8,  9],
         [10, 11, 12]],

        [[ 1,  2,  3],
         [ 4,  5,  6],
         [ 7,  8,  9],
         [10, 11, 12]],

        [[ 1,  2,  3],
         [ 4,  5,  6],
         [ 7,  8,  9],
         [10, 11, 12]]])

upper_mask

tensor([[0., 0., 0.],
        [0., 0., 0.],
        [1., 1., 1.],
        [1., 1., 1.]])

einsum

tensor([[[ 0.,  0.,  0.],
         [ 0.,  0.,  0.],
         [ 7.,  8.,  9.],
         [10., 11., 12.]],

        [[ 0.,  0.,  0.],
         [ 0.,  0.,  0.],
         [ 7.,  8.,  9.],
         [10., 11., 12.]],

        [[ 0.,  0.,  0.],
         [ 0.,  0.,  0.],
         [ 7.,  8.,  9.],
         [10., 11., 12.]],

        [[ 0.,  0.,  0.],
         [ 0.,  0.,  0.],
         [ 7.,  8.,  9.],
         [10., 11., 12.]]])
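For this particular pattern, einsum is equivalent to a plain broadcasted elementwise product. A minimal sketch of that alternative (my own addition), using the same batch4_a and upper_mask as above:

import torch

a = torch.tensor(range(1, 13)).reshape(4, 3)
batch4_a = torch.stack((a, a, a, a))
upper_mask = torch.cat([torch.zeros(2, 3), torch.ones(2, 3)])

# upper_mask (shape (4, 3)) is broadcast along the batch dimension of
# batch4_a (shape (4, 4, 3)), giving the same values as the einsum above.
masked = batch4_a * upper_mask
print(masked)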