0
0

More than 1 year has passed since last update.

【Pytorch】torch.einsumのequationの確認

Posted at

はじめに

torch.einsumという関数は、"ij,jk->ik"というequationを入力する必要がある。このequationがどういう動きをするのかよくわからなかったため、簡単な例で確認していく。

equationの入力のi,jの意味

equationに入っているiとかjは、入力した配列の次元数に合わせて用意される。入力する配列が2次元配列ならiやjの文字が2つ用意される。もし、2次元配列が入力されているのにも関わらず、1つしか文字を指定しなかった場合エラーとなる。

>>>a = torch.arange(4).reshape(2,2)
tensor([[0, 1],
       [2, 3]])

>>>torch.einsum("ij",a)
tensor([[0, 1],
        [2, 3]])

>>torch.einsum("i",a)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/local/lib/python3.10/dist-packages/torch/functional.py", line 378, in einsum
    return _VF.einsum(equation, operands)  # type: ignore[attr-defined]
RuntimeError: einsum(): the number of subscripts in the equation (1) does not match the number of dimensions (2) for operand 0 and no ellipsis was given

3次元配列なら3つ文字が用意される。

>>> b = torch.arange(8).reshape(2,2,2)
tensor([[[0, 1],
         [2, 3]],

        [[4, 5],
         [6, 7]]])

>>> torch.einsum("ijk",b)
tensor([[[0, 1],
         [2, 3]],

        [[4, 5],
         [6, 7]]])

>>> torch.einsum("ij",b)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/local/lib/python3.10/dist-packages/torch/functional.py", line 378, in einsum
    return _VF.einsum(equation, operands)  # type: ignore[attr-defined]
RuntimeError: einsum(): the number of subscripts in the equation (2) does not match the number of dimensions (3) for operand 0 and no ellipsis was given

equatiuonの"->"の意味について

equationに現れる"->"は出力を指定するために使用される。"->"の右側が出力の形式の指定である。
以下には、添え字の順番をそのままにした場合と、添え字の順番を反転させた場合の例を示した。

>>> a = torch.arange(4).reshape(2,2)
tensor([[0, 1],
        [2, 3]])

>>> torch.einsum("ij->ij",a)
tensor([[0, 1],
        [2, 3]])

>>> torch.einsum("ij->ji",a)
tensor([[0, 2],
        [1, 3]])

入力配列を$a$、出力配列を$b$とすると、"ij->ij"は

b[i][j] = a[i][j]

iとjを反転させた"ij-ji"は

b[j][i] = a[i][j]

を計算していると考えることができる。

equationの入力の","の意味について

equationにおいて","は入力された配列(オペランド)ごとに添え字の指定を分けるために使用する。
配列を2つ入力した場合、対応する要素は乗算される。簡単な例を以下に示す。

>>> a = torch.arange(3)
tensor([0, 1, 2])

>>> b = torch.arange(3)+1
tensor([1, 2, 3])

>>> torch.einsum("i,j->ij",a,b)
tensor([[0, 0, 0],
        [1, 2, 3],
        [2, 4, 6]])

上記例では、"i,j->ij"を指定しているが、これは以下の式を計算することとなる。

c[i][j] = a[i]*b[j]

2次元配列と1次元配列の場合は2次元配列の方は文字を2つ、1次元配列の方は文字を1つ指定する必要がある。

>>> a = torch.arange(2)
tensor([0, 1])

>>> b = torch.arange(4).reshape(2,2)
tensor([[1, 2],
        [3, 4]])

>>> torch.einsum("i,jk->ijk",a,b)
tensor([[[0, 0],
         [0, 0]],

        [[0, 1],
         [2, 3]]])

上記プログラムは以下の式を計算している。

c[i][j][k] = a[i]*b[j][k]

もちろん、入力の添え字の数が反対になるとエラーとなる

>>> torch.einsum("jk,i->ijk",a,b)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/local/lib/python3.10/dist-packages/torch/functional.py", line 378, in einsum
    return _VF.einsum(equation, operands)  # type: ignore[attr-defined]
RuntimeError: einsum(): the number of subscripts in the equation (2) does not match the number of dimensions (1) for operand 0 and no ellipsis was given

ちなみに添え字は同じ文字を使用することもできる。以下の例では、入力にiを2つ指定している。

>>> a = torch.arange(3)
tensor([0, 1, 2])

>>> b = torch.arange(3)+1
tensor([1, 2, 3])

>>> torch.einsum("i,i->i",a,b)
tensor([0, 2, 6])

これは以下の計算をしていることとなる。

c[i] = a[i]*b[i]

equationの添え字に同じ文字を複数回使う

上で示したように、equationには同じ文字を複数回使用することができる。
別の例を示す。

>>> a = torch.arange(4).reshape(2,2)
tensor([[0, 1],
        [2, 3]])

>>> torch.einsum("ii->i",a)
tensor([0, 3])

計算している式は

b[i] = a[i][i]

なんとなくお分かりだと思うが、同じ文字を複数回使う場合は同じ大きさである必要がある。例えば、下のように正方行列でない場合はエラーとなる。

>>> a = torch.arange(6).reshape(2,3)
tensor([[0, 1, 2],
        [3, 4, 5]])

>>> torch.einsum("ii->i",a)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/local/lib/python3.10/dist-packages/torch/functional.py", line 378, in einsum
    return _VF.einsum(equation, operands)  # type: ignore[attr-defined]
RuntimeError: einsum(): subscript i is repeated for operand 0 but the sizes don't match, 3 != 2

出力で指定されない添え字が入力の添え字に含まれている場合

出力に含まれない添え字が入力に含まれる場合、その含まれない添え字に関しては合計が計算される。
以下に例を示す。

>>> a = torch.arange(4).reshape(2,2)
tensor([[0, 1],
        [2, 3]])

>>>torch.einsum("ij->i",a)
tensor([1, 5])

上記例ではjが出力("->"の右側)に含まれていない。そのため、jに関して合計され出力の1要素の値が決定されている。
入力配列を$a$、出力配列を$b$とすると、

b[0] = a[0][0]+a[0][1] = 0 + 1
b[1] = a[1][0]+a[1][1] = 2 + 3

という計算をしていることとなる。
それでは、出力の文字がiではなくjの場合はどうなるだろうか?

>>> a = torch.arange(4).reshape(2,2)
tensor([[0, 1],
        [2, 3]])

>>> torch.einsum("ij->j",a)
tensor([2, 4])

jの場合は以下の数式を計算していることとなる。

b[0] = a[0][0]+a[1][0] = 0 + 2
b[1] = a[0][1]+a[1][1] = 1 + 3

torch.einsumによるバッチごとの行列積の計算

上記までのことが理解できれば、バッチごとに行列積を計算する"bij,bjk->bik"を理解できる。

>>> A = torch.arange(8).reshape(2,2,2)
tensor([[[0, 1],
         [2, 3]],

        [[4, 5],
         [6, 7]]])

>>> B = torch.arange(8).reshape(2,2,2) + 1
tensor([[[1, 2],
         [3, 4]],

        [[5, 6],
         [7, 8]]])

>>> torch.einsum("bij,bjk->bik",A,B)
tensor([[[ 3,  4],
         [11, 16]],

        [[55, 64],
         [79, 92]]])

上記の例では出力にjが含まれていない。そのため、jに関しては合計が計算されることとなる。入力配列をA,B、出力行列をCとすると、

C[b][i][k] = \sum_{j}A[b][i][j]*B[b][j][k]

を計算していることとなる。これにより、バッチごとの行列計算がtorch.einsumでどのように実現されているのかを理解することができた。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0