15
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

ElixirAdvent Calendar 2022

Day 22

ねえパパ、numpyのnp.max(a, axis=1)、Nxに無いの?

Last updated at Posted at 2022-11-23

やりたい事

Nxで、各列の最大値を値とする行列を作りたい

numpyのnp.max(a, axis=1)に相当する処理です

Nx.maxは二つの行列を与えて、各要素の大きい方をとるものなので、ちょっとちがいました。

argmaxで最大値のindexを調べて、take_along_axisで値を取り出せばできそう。

argmax

iex(1)> a = Nx.tensor([[1,3,2],[2,1,5]])   
#Nx.Tensor<
  s64[2][3]
  [
    [1, 3, 2],
    [2, 1, 5]
  ]
>
iex(2)> i = Nx.argmax(a, axis: 1, keep_axis: true)    
#Nx.Tensor<
  s64[2][1]
  [
    [1],
    [2]
  ]
>

take_along_axis

iex(3)> Nx.take_along_axis(a, i, axis: 1)
#Nx.Tensor<
  s64[2][1]
  [
    [3],
    [5]
  ]
>
iex(4)> 

結論

argmxとtake_along_axisを組み合わせてできる

iex(1)> a = Nx.tensor([[1,3,2],[2,1,5]])   
iex(2)> Nx.take_along_axis(a, Nx.argmax(a, axis: 1, keep_axis: true), axis: 1)
#Nx.Tensor<
  s64[2][1]
  [
    [3],
    [5]
  ]
>
iex(8)> 

結論(最新)

Nx.reduce_maxを発見しました。これでできます。

Nx.reduce_max(a, axes: [1])

結論(最新)(最新)

結論と同じ形状のtensorを得るには、keep_axes: trueが必要でした。

Nx.reduce_max(a, axes: [1], keep_axes: true)
15
3
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
15
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?