やりたい事
Nxで、各列の最大値を値とする行列を作りたい
numpyのnp.max(a, axis=1)に相当する処理です
Nx.maxは二つの行列を与えて、各要素の大きい方をとるものなので、ちょっとちがいました。
argmaxで最大値のindexを調べて、take_along_axisで値を取り出せばできそう。
argmax
iex(1)> a = Nx.tensor([[1,3,2],[2,1,5]])
#Nx.Tensor<
s64[2][3]
[
[1, 3, 2],
[2, 1, 5]
]
>
iex(2)> i = Nx.argmax(a, axis: 1, keep_axis: true)
#Nx.Tensor<
s64[2][1]
[
[1],
[2]
]
>
take_along_axis
iex(3)> Nx.take_along_axis(a, i, axis: 1)
#Nx.Tensor<
s64[2][1]
[
[3],
[5]
]
>
iex(4)>
結論
argmxとtake_along_axisを組み合わせてできる
iex(1)> a = Nx.tensor([[1,3,2],[2,1,5]])
iex(2)> Nx.take_along_axis(a, Nx.argmax(a, axis: 1, keep_axis: true), axis: 1)
#Nx.Tensor<
s64[2][1]
[
[3],
[5]
]
>
iex(8)>
結論(最新)
Nx.reduce_maxを発見しました。これでできます。
Nx.reduce_max(a, axes: [1])
結論(最新)(最新)
結論と同じ形状のtensorを得るには、keep_axes: trueが必要でした。
Nx.reduce_max(a, axes: [1], keep_axes: true)