はじめに
torch.einsumという関数は、"ij,jk->ik"というequationを入力する必要がある。このequationがどういう動きをするのかよくわからなかったため、簡単な例で確認していく。
equationの入力のi,jの意味
equationに入っているiとかjは、入力した配列の次元数に合わせて用意される。入力する配列が2次元配列ならiやjの文字が2つ用意される。もし、2次元配列が入力されているのにも関わらず、1つしか文字を指定しなかった場合エラーとなる。
>>>a = torch.arange(4).reshape(2,2)
tensor([[0, 1],
[2, 3]])
>>>torch.einsum("ij",a)
tensor([[0, 1],
[2, 3]])
>>torch.einsum("i",a)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/usr/local/lib/python3.10/dist-packages/torch/functional.py", line 378, in einsum
return _VF.einsum(equation, operands) # type: ignore[attr-defined]
RuntimeError: einsum(): the number of subscripts in the equation (1) does not match the number of dimensions (2) for operand 0 and no ellipsis was given
3次元配列なら3つ文字が用意される。
>>> b = torch.arange(8).reshape(2,2,2)
tensor([[[0, 1],
[2, 3]],
[[4, 5],
[6, 7]]])
>>> torch.einsum("ijk",b)
tensor([[[0, 1],
[2, 3]],
[[4, 5],
[6, 7]]])
>>> torch.einsum("ij",b)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/usr/local/lib/python3.10/dist-packages/torch/functional.py", line 378, in einsum
return _VF.einsum(equation, operands) # type: ignore[attr-defined]
RuntimeError: einsum(): the number of subscripts in the equation (2) does not match the number of dimensions (3) for operand 0 and no ellipsis was given
equatiuonの"->"の意味について
equationに現れる"->"は出力を指定するために使用される。"->"の右側が出力の形式の指定である。
以下には、添え字の順番をそのままにした場合と、添え字の順番を反転させた場合の例を示した。
>>> a = torch.arange(4).reshape(2,2)
tensor([[0, 1],
[2, 3]])
>>> torch.einsum("ij->ij",a)
tensor([[0, 1],
[2, 3]])
>>> torch.einsum("ij->ji",a)
tensor([[0, 2],
[1, 3]])
入力配列を$a$、出力配列を$b$とすると、"ij->ij"は
b[i][j] = a[i][j]
iとjを反転させた"ij-ji"は
b[j][i] = a[i][j]
を計算していると考えることができる。
equationの入力の","の意味について
equationにおいて","は入力された配列(オペランド)ごとに添え字の指定を分けるために使用する。
配列を2つ入力した場合、対応する要素は乗算される。簡単な例を以下に示す。
>>> a = torch.arange(3)
tensor([0, 1, 2])
>>> b = torch.arange(3)+1
tensor([1, 2, 3])
>>> torch.einsum("i,j->ij",a,b)
tensor([[0, 0, 0],
[1, 2, 3],
[2, 4, 6]])
上記例では、"i,j->ij"を指定しているが、これは以下の式を計算することとなる。
c[i][j] = a[i]*b[j]
2次元配列と1次元配列の場合は2次元配列の方は文字を2つ、1次元配列の方は文字を1つ指定する必要がある。
>>> a = torch.arange(2)
tensor([0, 1])
>>> b = torch.arange(4).reshape(2,2)
tensor([[1, 2],
[3, 4]])
>>> torch.einsum("i,jk->ijk",a,b)
tensor([[[0, 0],
[0, 0]],
[[0, 1],
[2, 3]]])
上記プログラムは以下の式を計算している。
c[i][j][k] = a[i]*b[j][k]
もちろん、入力の添え字の数が反対になるとエラーとなる
>>> torch.einsum("jk,i->ijk",a,b)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/usr/local/lib/python3.10/dist-packages/torch/functional.py", line 378, in einsum
return _VF.einsum(equation, operands) # type: ignore[attr-defined]
RuntimeError: einsum(): the number of subscripts in the equation (2) does not match the number of dimensions (1) for operand 0 and no ellipsis was given
ちなみに添え字は同じ文字を使用することもできる。以下の例では、入力にiを2つ指定している。
>>> a = torch.arange(3)
tensor([0, 1, 2])
>>> b = torch.arange(3)+1
tensor([1, 2, 3])
>>> torch.einsum("i,i->i",a,b)
tensor([0, 2, 6])
これは以下の計算をしていることとなる。
c[i] = a[i]*b[i]
equationの添え字に同じ文字を複数回使う
上で示したように、equationには同じ文字を複数回使用することができる。
別の例を示す。
>>> a = torch.arange(4).reshape(2,2)
tensor([[0, 1],
[2, 3]])
>>> torch.einsum("ii->i",a)
tensor([0, 3])
計算している式は
b[i] = a[i][i]
なんとなくお分かりだと思うが、同じ文字を複数回使う場合は同じ大きさである必要がある。例えば、下のように正方行列でない場合はエラーとなる。
>>> a = torch.arange(6).reshape(2,3)
tensor([[0, 1, 2],
[3, 4, 5]])
>>> torch.einsum("ii->i",a)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/usr/local/lib/python3.10/dist-packages/torch/functional.py", line 378, in einsum
return _VF.einsum(equation, operands) # type: ignore[attr-defined]
RuntimeError: einsum(): subscript i is repeated for operand 0 but the sizes don't match, 3 != 2
出力で指定されない添え字が入力の添え字に含まれている場合
出力に含まれない添え字が入力に含まれる場合、その含まれない添え字に関しては合計が計算される。
以下に例を示す。
>>> a = torch.arange(4).reshape(2,2)
tensor([[0, 1],
[2, 3]])
>>>torch.einsum("ij->i",a)
tensor([1, 5])
上記例ではjが出力("->"の右側)に含まれていない。そのため、jに関して合計され出力の1要素の値が決定されている。
入力配列を$a$、出力配列を$b$とすると、
b[0] = a[0][0]+a[0][1] = 0 + 1
b[1] = a[1][0]+a[1][1] = 2 + 3
という計算をしていることとなる。
それでは、出力の文字がiではなくjの場合はどうなるだろうか?
>>> a = torch.arange(4).reshape(2,2)
tensor([[0, 1],
[2, 3]])
>>> torch.einsum("ij->j",a)
tensor([2, 4])
jの場合は以下の数式を計算していることとなる。
b[0] = a[0][0]+a[1][0] = 0 + 2
b[1] = a[0][1]+a[1][1] = 1 + 3
torch.einsumによるバッチごとの行列積の計算
上記までのことが理解できれば、バッチごとに行列積を計算する"bij,bjk->bik"を理解できる。
>>> A = torch.arange(8).reshape(2,2,2)
tensor([[[0, 1],
[2, 3]],
[[4, 5],
[6, 7]]])
>>> B = torch.arange(8).reshape(2,2,2) + 1
tensor([[[1, 2],
[3, 4]],
[[5, 6],
[7, 8]]])
>>> torch.einsum("bij,bjk->bik",A,B)
tensor([[[ 3, 4],
[11, 16]],
[[55, 64],
[79, 92]]])
上記の例では出力にjが含まれていない。そのため、jに関しては合計が計算されることとなる。入力配列をA,B、出力行列をCとすると、
C[b][i][k] = \sum_{j}A[b][i][j]*B[b][j][k]
を計算していることとなる。これにより、バッチごとの行列計算がtorch.einsumでどのように実現されているのかを理解することができた。