数年前に作ったリポジトリの torch==1.7.0 を torch==1.13.0 に置き換えていて気付いた torch.solve の仕様変更に関するメモ。
Pytorchでは、線型方程式 $\mathbf{A} \mathbf{X} = \mathbf{B}$ を解く(行列$\mathbf{A}$ と 行列orベクトル$\mathbf{B}$ から解 $\mathbf{X}$ を求める)ソルバーが提供されている。
torch<=1.7 では torch.solve が利用可能で、torch>=1.8 からは torch.linalg.solve が新しく利用できる。
ただし、
torch.solve と torch.linalg.solve では引数と返り値が異なる
ことに注意する。
torch==1.7.0
X, LU = torch.solve(B, A)
torch==1.8.0
X = torch.linalg.solve(A, B)
LU = torch.lu(A) # LU分解したい場合
そしてさらに
torch==1.9 から torch.solve の利用が非推奨
になり、
torch>=1.12 では torch.linalg.solve に一本化されて torch.solve は使えない
ことに注意。
