はじめに
Einsumは、様々な行列の演算ができます。通常、行列積や内積の計算では、行列の形に制約がありますが、Einsumは、添え字を使ってどんな形の行列でも計算ができてしまいます。
それゆえ、挙動が理解しにくいです。そこで、ここではfor文で実装して、挙動を分かりやすくしてみました。
torch.einsum 公式ドキュメント
https://pytorch.org/docs/stable/generated/torch.einsum.html
Einsumの実行結果
(4×3)のaと、(2×4)のbにEinsumを使ってみます。
import torch
a = torch.tensor(range(1, 13)).reshape(4, 3)
b = torch.tensor(range(1, 9)).reshape(2, 4)
print(a)
print(b)
einsum1 = torch.einsum("nk,cv->nv", (a, b))
einsum2 = torch.einsum("nk,cv->nv", (b, a))
print(einsum1)
print(einsum2)
実行結果
aとb
tensor([[ 1, 2, 3],
[ 4, 5, 6],
[ 7, 8, 9],
[10, 11, 12]])
tensor([[1, 2, 3, 4],
[5, 6, 7, 8]])
einsum1とeinsum2
tensor([[ 36, 48, 60, 72],
[ 90, 120, 150, 180],
[144, 192, 240, 288],
[198, 264, 330, 396]])
tensor([[220, 260, 300],
[572, 676, 780]])
for文での実装
import numpy as np
a = np.array(range(1, 13)).reshape(4, 3)
b = np.array(range(1, 9)).reshape(2, 4)
print(a)
print(b)
print("einsum('nk,cv->nv', (a, b))")
# einsum.shape = (n, v) = (4, 4)
einsum = np.zeros((4, 4))
for n in range(a.shape[0]):
for k in range(a.shape[1]):
for c in range(b.shape[0]):
for v in range(b.shape[1]):
einsum[n, v] += a[n, k] * b[c, v]
print(einsum)
print("einsum('nk,cv->nv', (b, a))")
# einsum.shape = (n, v) = (2, 3)
einsum = np.zeros((2, 3))
for n in range(b.shape[0]):
for k in range(b.shape[1]):
for c in range(a.shape[0]):
for v in range(a.shape[1]):
einsum[n, v] += b[n, k] * a[c, v]
print(einsum)
実行結果
[[ 1 2 3]
[ 4 5 6]
[ 7 8 9]
[10 11 12]]
[[1 2 3 4]
[5 6 7 8]]
einsum('nk,cv->nv', (a, b))
[[ 36. 48. 60. 72.]
[ 90. 120. 150. 180.]
[144. 192. 240. 288.]
[198. 264. 330. 396.]]
einsum('nk,cv->nv', (b, a))
[[220. 260. 300.]
[572. 676. 780.]]
便利な使い方
for文で実装したものは、挙動を理解するために特殊な使い方で行いました。実際には、以下のような場合で使うことがあるかと思います。
バッチごとに同じ行列をかけたい
例えば、バッチサイズ4で4×3の行列があるとしたときに、この4×3の行列に対して、何か処理を行いたい場合があります。
以下では、4×3の行列の上半分を0にする例を紹介します。
import torch
a = torch.tensor(range(1, 13)).reshape(4, 3)
batch4_a = torch.stack((a, a, a, a))
print(batch4_a)
upper_mask = torch.cat([torch.zeros(2, 3), torch.ones(2, 3)])
print(upper_mask)
einsum = torch.einsum("bij,ij->bij", (batch4_a, upper_mask))
print(einsum)
実行結果
batch4_a
tensor([[[ 1, 2, 3],
[ 4, 5, 6],
[ 7, 8, 9],
[10, 11, 12]],
[[ 1, 2, 3],
[ 4, 5, 6],
[ 7, 8, 9],
[10, 11, 12]],
[[ 1, 2, 3],
[ 4, 5, 6],
[ 7, 8, 9],
[10, 11, 12]],
[[ 1, 2, 3],
[ 4, 5, 6],
[ 7, 8, 9],
[10, 11, 12]]])
upper_mask
tensor([[0., 0., 0.],
[0., 0., 0.],
[1., 1., 1.],
[1., 1., 1.]])
einsum
tensor([[[ 0., 0., 0.],
[ 0., 0., 0.],
[ 7., 8., 9.],
[10., 11., 12.]],
[[ 0., 0., 0.],
[ 0., 0., 0.],
[ 7., 8., 9.],
[10., 11., 12.]],
[[ 0., 0., 0.],
[ 0., 0., 0.],
[ 7., 8., 9.],
[10., 11., 12.]],
[[ 0., 0., 0.],
[ 0., 0., 0.],
[ 7., 8., 9.],
[10., 11., 12.]]])