はじめに
PyTorchのドキュメントでちょっと理解できてなかったから残す程度
コードはドキュメントから
torch.max()の使い方
1. 1Dテンソルの場合
a = torch.randn(1, 3)
a
tensor([[ 0.6763, 0.7445, -2.2369]])
torch.max(a)
tensor(0.7445)
うん,一番簡単一次元配列の最大値の要素を返してくれる
2Dテンソルの場合
a = torch.randn(4, 4)
a
tensor([[-1.2360, -0.2942, -0.1222, 0.8475],
[ 1.1949, -1.1127, -2.2379, -0.6702],
[ 1.5717, -0.9207, 0.1297, -1.8768],
[-0.6172, 1.0036, -0.6060, -0.2432]])
torch.max(a, 1)
torch.return_types.max(values=tensor([0.8475, 1.1949, 1.5717, 1.0036]), indices=tensor([3, 0, 0, 1]))
これの第二引数がよくわかんなかったけど
numpyでゆうところのaxisでした.
なので個人的には
a = torch.randn(4, 4)
a
tensor([[-1.2360, -0.2942, -0.1222, 0.8475],
[ 1.1949, -1.1127, -2.2379, -0.6702],
[ 1.5717, -0.9207, 0.1297, -1.8768],
[-0.6172, 1.0036, -0.6060, -0.2432]])
axis = 1
torch.max(a, axis)
torch.return_types.max(values=tensor([0.8475, 1.1949, 1.5717, 1.0036]), indices=tensor([3, 0, 0, 1]))
としたほうがわかりやすいですね.
torch.max(a, axis)
みたいな使い方はクラス分類で使われますね~
ちなみに,自分用ですけど,axisは軸ですよ!(axis=0:col, axis=1:row
)
おわりに
まだまだライブラリと数学に振り回されているので早めにうまく使えるようになりたいです