はじめに
PyTorchのTensorは基本的にNumPyのndarrayのように扱えますが、ところどころに異なる部分があります。
特にTensorの最大値、最小値を求めるmax()
, min()
はよく使うのにNumPyと挙動が異なっていて扱いづらいです。
この違いを緩和するためにNumPyと同じようなスタイルでTensorの最大値、最小値を求められる関数を書きました。
PyTorchのmax()
とNumPyのmax()
の違い
Tensorまたはndarrayの全体の最大値を求めたい場合は両者の違いはありません。
つまり、以下の式は両方とも(型は違いますが)同じ値が返ってきます。
torch.tensor([1,2,3,4,5]).max()
np.array([1,2,3,4,5]).max()
しかし、NumPyで言うところのaxis
を指定する場合、すなわち特定の軸に沿って最大値を求めたい場合の挙動が異なっています。
NumPyではaxisに複数の軸をタプルなどでまとめて指定でき、例えば画像のチャンネル毎に最大値を求めるといった使い方ができます。
x = np.array([[[8, 2, 2],
[6, 2, 3],
[8, 2, 4]],
[[8, 4, 9],
[0, 3, 9],
[5, 5, 3]],
[[6, 5, 5],
[4, 8, 0],
[1, 6, 0]]])
x.max(axis=(1,2)) # -> [8, 9, 8]
しかし、Tensorのmax()
では一度に一つの軸しか指定できません(しかも引数名がaxis
ではなくdim
)。
更に、Tensorのmax()
では最大値とその位置を示すインデックス(=argmax)がタプルで返ってきます1。
この仕様がとても扱いづらく、上記のNumPyと同じことを行いたい場合は返り値の0番目を取り出してmaxを重ねがけするという、とても不格好なことをしなくてはなりません。
x = torch.tensor([[[8, 2, 2],
[6, 2, 3],
[8, 2, 4]],
[[8, 4, 9],
[0, 3, 9],
[5, 5, 3]],
[[6, 5, 5],
[4, 8, 0],
[1, 6, 0]]])
x.max(dim=2)[0].max(dim=1)[0]
また、上記のコードのようにする場合、dimを指定する順番によってはmax()を行う度に軸のインデックスがずれていくため、意図した通りの軸の最大値が求まらないといったバグの温床にもなりえます。
書いた関数
from typing import Sequence, Union
from torch import Tensor
def tensor_max(x: Tensor,
axis: Union[int, Sequence[int], None] = None,
keepdims: bool = False
) -> Tensor:
if axis is None:
axis = range(x.ndim)
elif isinstance(axis, int):
axis = [axis]
else:
axis = sorted(axis)
for ax in axis[::-1]:
x = x.max(dim=ax, keepdim=keepdims)[0]
return x
def tensor_min(x: Tensor,
axis: Union[int, Sequence[int], None] = None,
keepdims: bool = False
) -> Tensor::
if axis is None:
axis = range(x.ndim)
elif isinstance(axis, int):
axis = [axis]
else:
axis = sorted(axis)
for ax in axis[::-1]:
x = x.min(dim=ax, keepdim=keepdims)[0]
return x
例えばx
の最大値を求めたいときはtensor_max(x)
のように書きます。
引数にはNumPyと同様にaxis
, keepdims
を指定できます。
axis
を複数指定した場合でも、降順でmax (min)をかけていくため、指定する順番によって結果が変化することはありません。
これを使うことで、PyTorchで複数の軸に渡る最大値、最小値もシンプルに書けるようになります。
x = torch.tensor([[[8, 2, 2],
[6, 2, 3],
[8, 2, 4]],
[[8, 4, 9],
[0, 3, 9],
[5, 5, 3]],
[[6, 5, 5],
[4, 8, 0],
[1, 6, 0]]])
x.max(dim=2)[0].max(dim=1)[0] # before
tensor_max(x, axis=(1,2)) # after
おわりに
PyTorchは便利ですが、Tensorの仕様がNumPyと似ているようで微妙に異なるので戸惑うことが多いです。
統一されてほしいものです。