Pytorchで3次元以上の配列の掛け算を行う必要に迫られたので、そのやり方をメモします。
torch.matmul(tensor1,tensor2)
上記の関数を使います。
一見して行列同士の掛け算を行う関数に思いますが、3次元以上の配列でも演算を行ってくれるようです。
3次元以上の配列の積といっても、テンソル積のようなものではなく、ミニバッチの行列演算を行ってくれるものです。
使用例(3次元配列と2次元配列)
$M$個の行列をまとめたミニバッチt1 $(M,p,q)$と$M$個のベクトルをまとめたミニバッチt2$(M,p)$がそれぞれあるとします。
このときミニバッチ中の各サンプルに対して次のように計算したいとします。
$y_i = x_iW_i\ \ \ \ \cdots (1)$
where $i = 1,\dots,M$
$x_i$はt1に含まれる、shape$(1,p)$の行ベクトル,$W_i$はt2に含まれるshape$(p,q)$の行列です。
もしW_iが一つだけで、t1がshape$(p,q)$の二次元配列なら上記の計算のためには行列演算として、t2*t1を計算しておけばよいのですが、いまshapeは$(M,p,q)$なので、行列演算できません。
そこでtorch.matmul()を使います。
torch.matmul()では引数の配列のうち、下二つの次元を行列とし、残りをサンプルとして計算を行います。
(1)式においてすべての$i$についての掛け算をやってくれます。
コード例です。
import torch
t1 = torch.FloatTensor(3,2,2)
for i in range(3):
t1[i,:,:] = torch.eye(2)*2**i
print('t1:\n',t1)
'''
t1:
tensor([[[ 1., 0.],
[ 0., 1.]],
[[ 2., 0.],
[ 0., 2.]],
[[ 4., 0.],
[ 0., 4.]]])
'''
t2 = torch.FloatTensor(3,2)
for i in range(3):
t2[i,:] = 2**(2-i)
print('t2:\n',t2)
'''
t2:
tensor([[ 4., 4.],
[ 2., 2.],
[ 1., 1.]])
'''
t2 = t2.unsqueeze(1)
t2t1 = torch.matmul(t2,t1)
print('t2t1:\n',t1t2)
print('t2t1.shape:\n',t1t2.shape)
'''
t1t2:
tensor([[[ 4., 4.]],
[[ 4., 4.]],
[[ 4., 4.]]])
t1t2.shape:
torch.Size([3, 1, 2])
'''
t1は(3,2,2)のテンソル、t2は(3,2)のテンソルです。
コード中の計算としては、[4,4],[2,2],[1,1]のベクトルを1,2,4倍しているだけです。
1次元目はサンプルとして、各サンプルにおける(1)式を計算したいとします。
まず、テンソルの次元はあっていないといけないので、t2を3次元にします。
行ベクトルにしたいので、unsqueezeで間に(3,2)の真ん中に1次元追加して(3,1,2)にします。
そのつぎにtorch.matmul(t2,t1)とすると、下2次元を使って(1,2),(2,2)の行列の掛け算とみなして、ミニバッチt2,t1に含まれる各サンプルに対してその行列計算をしてくれます。
帰ってくる配列のshapeは(3,1,2)です。
4次元以上
下2次元を行列とみて、残りの次元をサンプルと見てミニバッチ計算する点は同じです。
と言い切っておきますが、各サンプルに対してどう計算されているか少しよくわかりませんでした。必要に迫られなかったので今回はここまでとしました。
詳しくはテキストを読んでいただきたいです。
なんでやった?
重みやバイアス自体にノイズを加えて計算したかったためです。