LoginSignup
14
2

[AtCoder Nx] ABC081B - Shift only [Elixir Livebook]

Last updated at Posted at 2023-11-20

はじめに

AtCoder で Elixir の Nx が使えるようになったので再入門するシリーズです

以前、 Nx なしで解いたのはこちら

今回実装したノートブックはこちら

問題

入出力例

以下のようになれば OK です

"""
3
8 12 40
"""
|> Main.solve()
|> then(&(&1 == 2))
"""
4
5 6 8 10
"""
|> Main.solve()
|> then(&(&1 == 0))

Nx なしの回答

与えられた各数値について 2 で何回割れるかを計算し、その最小値が答えになります

$1 \le A_i \le 10^9$ の条件から 2^1 から 2^29 の数値を計算しておき、割り切れた数の最大値 = 2 で割れる数となります

defmodule Main do
  def main do
    :stdio
    |> IO.read(:all)
    |> solve()
    |> IO.puts()
  end

  defp split_lines(lines) do
    lines
    |> String.trim()
    |> String.split("\n")
  end

  defp split_words(words) do
    String.split(words, " ")
  end

  defp num_of_power(input, powers) do
    powers
    |> Enum.map(fn {index, num} ->
      case rem(input, num) do
        0 ->
          index
        _ ->
          0
      end
    end)
    |> Enum.max()
  end

  def solve(input) do
    nums =
      input
      |> split_lines()
      |> Enum.at(1)
      |> split_words()
      |> Enum.map(&String.to_integer/1)

    powers =
      Enum.to_list(1..29)
      |> Enum.map(fn power ->
        {power, 2 |> :math.pow(power) |> round() }
      end)
 
    nums
    |> Enum.map(&num_of_power(&1, powers))
    |> Enum.min()
  end
end

実行時間は 1572 msec でした

Nx ありの回答

Nx を使う場合、主に Enum.map で繰り返し処理しているところを行列演算に置き換えることになります

まず、 2 の n 乗の計算をしているところ

後の処理との整合性のため、 0 から 29 で計算します

Enum.to_list(0..29)
|> Enum.map(fn power ->
    {power, 2 |> :math.pow(power) |> round() }
end)

これを Nx を使って計算してみましょう

まず、 Nx.iota({30}) は以下のような値になります

#Nx.Tensor<
  s64[30]
  EXLA.Backend<host:0, 0.2046932374.739901452.194665>
  [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29]
>

この行列に対して Nx.pow(2, n) とすることで、 0 から 29 までの 2 の累乗を全て計算します

Nx.pow(2, Nx.iota({30}))

実行結果

#Nx.Tensor<
  s64[30]
  EXLA.Backend<host:0, 0.2046932374.739901452.194664>
  [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288, 1048576, 2097152, 4194304, 8388608, 16777216, 33554432, 67108864, 134217728, 268435456, 536870912]
>

行列演算はこのように一気に並列で計算するため、 GPU やマルチコアの CPU を使うことで高速化可能です

続いて、 Nx を使わない場合、 2 で割れる数は以下のように求めていました

powers
|> Enum.map(fn {index, num} ->
  case rem(input, num) do
    0 ->
      index
    _ ->
      0
  end
end)
|> Enum.max()

これを Nx を使うようにする場合、以下のようになります

input_tensor
|> Nx.new_axis(0)
|> Nx.transpose()
|> Nx.remainder(powers)
|> Nx.reverse(axes: [1])
|> Nx.argmin(axis: 1)
|> then(&Nx.subtract(29, &1))
|> Nx.reduce_min()

例えば input_tensor を [8, 12, 40] とすると、以下のように変形しています

Nx.new_axis(0) で次元を増やします

[[8, 12, 40]]

Nx.transpose() で縦横を入れ替えます

\begin{bmatrix}
8 & 12 & 40
\end{bmatrix}
^\top
=
\begin{bmatrix}
8 \\
12 \\
40
\end{bmatrix}

Nx.remainder(powers) で 2 の n 乗との余りを一気に計算します

\begin{bmatrix}
8 \\
12 \\
40
\end{bmatrix}
\bmod
\begin{bmatrix}
1 & 2 & 4 & 8 & 16 & ... & 2^{29}
\end{bmatrix}
=
\begin{bmatrix}
0 & 0 & 0 & 0 & 8 & 8 & 8 & ... & 8 \\
0 & 0 & 0 & 4 & 12 & 12 & 12 & ... & 12 \\
0 & 0 & 0 & 0 & 8 & 8 & 40 & ... & 40
\end{bmatrix}

以下のようなイメージの計算をしています(A%BはAをBで割った余り)

- 1 2 4 ... 2^29
8 8%1 8%2 8%4 ... 8%2^29
12 12%1 12%2 12%4 ... 12%2^29
40 40%1 40%2 40%4 ... 40%2^29

Nx.reverse(axes: [1]) で各行の順序を逆にします

\begin{bmatrix}
8 & ... & 8 & 8 & 8 & 0 & 0 & 0 & 0 \\
12 & ... & 12 & 12 & 12 & 4 & 0 & 0 & 0 \\
40 & ... & 40 & 8 & 8 & 0 & 0 & 0 & 0
\end{bmatrix}

Nx.argmin(axis: 1) で各行について最小値(0)が最初に登場する列番号を取得します

さっき順序を逆にしたため、実質的には 29 - (0が最後に出た位置) を求めたことになります

\begin{bmatrix}
26 \\
27 \\
26
\end{bmatrix}

then(&Nx.subtract(29, &1)) をすることで 29 - (29 - (0が最後に出た位置)) = 0が最後に出た位置 を取得します

29 -
\begin{bmatrix}
26 \\
27 \\
26
\end{bmatrix}
=
\begin{bmatrix}
3 \\
2 \\
3
\end{bmatrix}

Nx.reduce_min() でそのうちの最小値(この例では 2)を取得します

この一連の計算で、各値に対する 2 で割れる最大数を求めたことになります

回答の全文はこちら

defmodule Main do
  import Nx.Defn

  def main do
    :stdio
    |> IO.read(:all)
    |> solve()
    |> IO.puts()
  end

  defp split_lines(lines) do
    lines
    |> String.trim()
    |> String.split("\n")
  end

  defp split_words(words) do
    String.split(words, " ")
  end

  defn get_powers() do
    2 ** Nx.iota({30})
  end

  defn calc(input_tensor, powers) do
    input_tensor
    |> Nx.new_axis(0)
    |> Nx.transpose()
    |> Nx.remainder(powers)
    |> Nx.reverse(axes: [1])
    |> Nx.argmin(axis: 1)
    |> then(&Nx.subtract(29, &1))
    |> Nx.reduce_min()
  end

  def solve(input) do
    nums =
      input
      |> split_lines()
      |> Enum.at(1)
      |> split_words()
      |> Enum.map(&String.to_integer/1)
      |> Nx.tensor()

    powers = get_powers()

    nums
    |> calc(powers)
    |> Nx.to_number()
  end
end

Nx の計算については defn で定義することにより、高速化しています

また、累乗の計算は defn 内だと ** を使って表されるため、簡略化できます

実行時間は 1009 msec でした

まとめ

今回の問題は元々が繰り返し(Enum.map) を多く使っていたため、 Nx を有効活用できたと思います

上手く書けると Enum.map 以上にシンプルに表現できるのが気持ちいいですね

14
2
2

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
14
2