記事の内容と目的
深層学習のコードを触っている際に、torch.einsum 関数が用いられており、理解に苦しんだので、備忘のため記事にまとめておきます。
torch.einsum とは?
einsum はアインシュタインの縮約記号に基づいて、テンソルの演算をする関数であり、テンソルに対する演算の多くはこの関数を用いて実装することが可能です。
(可能と言っているだけで推奨しているわけではありません)
torch.einsum の記法
torch.einsum では、以下のような記法で関数が記述されます。
torch.einsum(equation, *operands)
equation 部分は (入力テンソルの形状)→(出力テンソルの形状) という形で、
例えば、3 × 4 のテンソルと 4 × 5 のテンソルの行列積を計算する場合には
ij,jk->ik (34,45->35)
のように記述します。
ここの文字については、どんな文字を用いても問題ありません。
(ab,bc->ac と書いても st,tu->su と書いても同様に動作します)
operands の部分は、演算に使用されるテンソルを順に与えます。
例に出した行列積の例だと
import torch
A = torch.tensor(range(0, 12)).reshape(3, 4)
B = torch.tensor(range(0, 20)).reshape(4, 5)
matmul = torch.einsum("ij,jk->ik", A, B)
print(A)
print(B)
print(matmul)
------------------------------------------------------------------------
result:
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]])
tensor([[ 0, 1, 2, 3, 4],
[ 5, 6, 7, 8, 9],
[10, 11, 12, 13, 14],
[15, 16, 17, 18, 19]])
tensor([[ 70, 76, 82, 88, 94],
[190, 212, 234, 256, 278],
[310, 348, 386, 424, 462]])
のように実装すると良いです。
torch.einsum の動作
equation で用いられる次元の表記に基づいて、動作を示します。
入力オペランドの共通する次元のうち、出力に含まれない次元(例えば、torch.einsum("ij,jk->ik", $A, B$) における j)は縮約されます。
縮約されるとは、縮約される次元方向に和を取ることを指します。
例えば、以下のような記述の場合、$j$方向に縮約されるので、演算処理は
import torch
# A と B の例
A = torch.tensor([[1, 2, 3, 4],
[5, 6, 7, 8],
[9,10,11,12]])
B = torch.tensor([[2, 1, 1, 3],
[1, 2, 1, 2],
[1, 1, 2, 1]])
# einsum("ij,ij->i") の計算過程
# j次元(列方向)が縮約される
result = torch.einsum("ij,ij->i", A, B)
print(result) # 結果: [19, 40, 53]
result[0] = A[0,0]*B[0,0] + A[0,1]*B[0,1] + A[0,2]*B[0,2] + A[0,3]*B[0,3]
= 1*2 + 2*1 + 3*1 + 4*3 = 2 + 2 + 3 + 12 = 19
result[1] = A[1,0]*B[1,0] + A[1,1]*B[1,1] + A[1,2]*B[1,2] + A[1,3]*B[1,3]
= 5*1 + 6*2 + 7*1 + 8*2 = 5 + 12 + 7 + 16 = 40
result[2] = A[2,0]*B[2,0] + A[2,1]*B[2,1] + A[2,2]*B[2,2] + A[2,3]*B[2,3]
= 9*1 + 10*1 + 11*2 + 12*1 = 9 + 10 + 22 + 12 = 53
# 結果: [19, 40, 53]
入力オペランドの共通する次元のうち、出力に含まれる次元は、縮約の対象とならず放置されます。例えば、バッチ内の全サンプルに対して行列積を計算するような処理は、以下のように記述されます。
A = torch.randn(8, 3, 4)
B = torch.randn(8, 4, 5)
result = torch.einsum("bij,bjk->bik", A, B)
この処理において、入力、出力形状の両方に含まれている次元 $b$ はこの計算において放置されます。
その上で、$ij$, $jk$ の成分に対して $j$ 方向にテンソルが縮約され、バッチ毎の行列積を計算する処理と同等の処理となっています。
torch.einsum で計算できる量
ここからは、torch.einsum を用いて計算できる演算のうち簡単目なものをいくつか紹介します。
外積
外積とは
「各要素に対して、他方のすべての要素との積を計算する演算」
のことです。
# ベクトルの外積
a = torch.randn(3)
b = torch.randn(4)
result = torch.einsum("i,j->ij", a, b) # (3, 4)
# 等価: torch.outer(a, b)
# 高次元外積
A = torch.randn(2, 3)
B = torch.randn(4, 5)
result = torch.einsum("ij,kl->ijkl", A, B) # (2, 3, 4, 5)
アダマール積(要素積)
アダマール積とは、同一の形状の行列の組において、全ての要素に対して要素積を計算し返す演算のことです。
# 要素積
A = torch.randn(3, 4)
B = torch.randn(3, 4)
result = torch.einsum("ij,ij->ij", A, B) # (3, 4)
# 等価: A * B
# 要素積 + 次元縮約
result = torch.einsum("ij,ij->i", A, B) # (3,) 行ごとの内積
result = torch.einsum("ij,ij->j", A, B) # (4,) 列ごとの内積
result = torch.einsum("ij,ij->", A, B) # スカラー 全体の内積
終わりに
今回は、torch.einsum の動作について解説しました。
torch.einsum は使い方次第で非常に様々な演算を計算できる関数であることがわかりましたね。
それではみなさんの良いエンジニアリングライフを!