LoginSignup
29
13

More than 3 years have passed since last update.

[PyTorch]torch.max()でちょっと迷ったこと

Last updated at Posted at 2019-12-07

はじめに

PyTorchのドキュメントでちょっと理解できてなかったから残す程度
コードはドキュメントから

torch.max()の使い方

1. 1Dテンソルの場合

a = torch.randn(1, 3)
a
tensor([[ 0.6763,  0.7445, -2.2369]])
torch.max(a)
tensor(0.7445)

うん,一番簡単一次元配列の最大値の要素を返してくれる

2Dテンソルの場合

a = torch.randn(4, 4)
a
tensor([[-1.2360, -0.2942, -0.1222,  0.8475],
        [ 1.1949, -1.1127, -2.2379, -0.6702],
        [ 1.5717, -0.9207,  0.1297, -1.8768],
        [-0.6172,  1.0036, -0.6060, -0.2432]])
torch.max(a, 1)
torch.return_types.max(values=tensor([0.8475, 1.1949, 1.5717, 1.0036]), indices=tensor([3, 0, 0, 1]))

これの第二引数がよくわかんなかったけど
numpyでゆうところのaxisでした.
なので個人的には

a = torch.randn(4, 4)
a
tensor([[-1.2360, -0.2942, -0.1222,  0.8475],
        [ 1.1949, -1.1127, -2.2379, -0.6702],
        [ 1.5717, -0.9207,  0.1297, -1.8768],
        [-0.6172,  1.0036, -0.6060, -0.2432]])
axis = 1
torch.max(a, axis)
torch.return_types.max(values=tensor([0.8475, 1.1949, 1.5717, 1.0036]), indices=tensor([3, 0, 0, 1]))

としたほうがわかりやすいですね.
torch.max(a, axis)みたいな使い方はクラス分類で使われますね~
ちなみに,自分用ですけど,axisは軸ですよ!(axis=0:col, axis=1:row)

おわりに

まだまだライブラリと数学に振り回されているので早めにうまく使えるようになりたいです

29
13
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
29
13