LoginSignup
9
5

More than 1 year has passed since last update.

PyTorchのTensorのmax, minの仕様が不便

Last updated at Posted at 2020-09-06

はじめに

PyTorchのTensorは基本的にNumPyのndarrayのように扱えますが、ところどころに異なる部分があります。

特にTensorの最大値、最小値を求めるmax(), min()はよく使うのにNumPyと挙動が異なっていて扱いづらいです。
この違いを緩和するためにNumPyと同じようなスタイルでTensorの最大値、最小値を求められる関数を書きました。

PyTorchのmax()とNumPyのmax()の違い

Tensorまたはndarrayの全体の最大値を求めたい場合は両者の違いはありません。
つまり、以下の式は両方とも(型は違いますが)同じ値が返ってきます。

PyTorch
torch.tensor([1,2,3,4,5]).max()
NumPy
np.array([1,2,3,4,5]).max()

しかし、NumPyで言うところのaxisを指定する場合、すなわち特定の軸に沿って最大値を求めたい場合の挙動が異なっています。

NumPyではaxisに複数の軸をタプルなどでまとめて指定でき、例えば画像のチャンネル毎に最大値を求めるといった使い方ができます。

NumPy
x = np.array([[[8, 2, 2],
               [6, 2, 3],
               [8, 2, 4]],
              [[8, 4, 9],
               [0, 3, 9],
               [5, 5, 3]],
              [[6, 5, 5],
               [4, 8, 0],
               [1, 6, 0]]])
x.max(axis=(1,2))    # -> [8, 9, 8]

しかし、Tensorのmax()では一度に一つの軸しか指定できません(しかも引数名がaxisではなくdim)。

更に、Tensorのmax()では最大値とその位置を示すインデックス(=argmax)がタプルで返ってきます1

この仕様がとても扱いづらく、上記のNumPyと同じことを行いたい場合は返り値の0番目を取り出してmaxを重ねがけするという、とても不格好なことをしなくてはなりません。

PyTorch
x = torch.tensor([[[8, 2, 2],
               [6, 2, 3],
               [8, 2, 4]],
              [[8, 4, 9],
               [0, 3, 9],
               [5, 5, 3]],
              [[6, 5, 5],
               [4, 8, 0],
               [1, 6, 0]]])
x.max(dim=2)[0].max(dim=1)[0]

また、上記のコードのようにする場合、dimを指定する順番によってはmax()を行う度に軸のインデックスがずれていくため、意図した通りの軸の最大値が求まらないといったバグの温床にもなりえます。

書いた関数

from typing import Sequence, Union
from torch import Tensor

def tensor_max(x: Tensor,
               axis: Union[int, Sequence[int], None] = None,
               keepdims: bool = False
               ) -> Tensor:
    if axis is None:
        axis = range(x.ndim)
    elif isinstance(axis, int):
        axis = [axis]
    else:
        axis = sorted(axis)

    for ax in axis[::-1]:
        x = x.max(dim=ax, keepdim=keepdims)[0]

    return x

def tensor_min(x: Tensor,
               axis: Union[int, Sequence[int], None] = None,
               keepdims: bool = False
               ) -> Tensor::
    if axis is None:
        axis = range(x.ndim)
    elif isinstance(axis, int):
        axis = [axis]
    else:
        axis = sorted(axis)

    for ax in axis[::-1]:
        x = x.min(dim=ax, keepdim=keepdims)[0]

    return x

例えばxの最大値を求めたいときはtensor_max(x)のように書きます。
引数にはNumPyと同様にaxis, keepdimsを指定できます。

axisを複数指定した場合でも、降順でmax (min)をかけていくため、指定する順番によって結果が変化することはありません。

これを使うことで、PyTorchで複数の軸に渡る最大値、最小値もシンプルに書けるようになります。

x = torch.tensor([[[8, 2, 2],
               [6, 2, 3],
               [8, 2, 4]],
              [[8, 4, 9],
               [0, 3, 9],
               [5, 5, 3]],
              [[6, 5, 5],
               [4, 8, 0],
               [1, 6, 0]]])
x.max(dim=2)[0].max(dim=1)[0]    # before
tensor_max(x, axis=(1,2))        # after

おわりに

PyTorchは便利ですが、Tensorの仕様がNumPyと似ているようで微妙に異なるので戸惑うことが多いです。
統一されてほしいものです。

  1. https://pytorch.org/docs/stable/generated/torch.max.html

9
5
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
9
5