はじめに
AtCoder で Elixir の Nx が使えるようになったので再入門するシリーズです
以前、 Nx なしで解いたのはこちら
今回実装したノートブックはこちら
問題
入出力例
以下のようになれば OK です
"""
3
8 12 40
"""
|> Main.solve()
|> then(&(&1 == 2))
"""
4
5 6 8 10
"""
|> Main.solve()
|> then(&(&1 == 0))
Nx なしの回答
与えられた各数値について 2 で何回割れるかを計算し、その最小値が答えになります
$1 \le A_i \le 10^9$ の条件から 2^1 から 2^29 の数値を計算しておき、割り切れた数の最大値 = 2 で割れる数となります
defmodule Main do
def main do
:stdio
|> IO.read(:all)
|> solve()
|> IO.puts()
end
defp split_lines(lines) do
lines
|> String.trim()
|> String.split("\n")
end
defp split_words(words) do
String.split(words, " ")
end
defp num_of_power(input, powers) do
powers
|> Enum.map(fn {index, num} ->
case rem(input, num) do
0 ->
index
_ ->
0
end
end)
|> Enum.max()
end
def solve(input) do
nums =
input
|> split_lines()
|> Enum.at(1)
|> split_words()
|> Enum.map(&String.to_integer/1)
powers =
Enum.to_list(1..29)
|> Enum.map(fn power ->
{power, 2 |> :math.pow(power) |> round() }
end)
nums
|> Enum.map(&num_of_power(&1, powers))
|> Enum.min()
end
end
実行時間は 1572 msec でした
Nx ありの回答
Nx を使う場合、主に Enum.map
で繰り返し処理しているところを行列演算に置き換えることになります
まず、 2 の n 乗の計算をしているところ
後の処理との整合性のため、 0 から 29 で計算します
Enum.to_list(0..29)
|> Enum.map(fn power ->
{power, 2 |> :math.pow(power) |> round() }
end)
これを Nx を使って計算してみましょう
まず、 Nx.iota({30})
は以下のような値になります
#Nx.Tensor<
s64[30]
EXLA.Backend<host:0, 0.2046932374.739901452.194665>
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29]
>
この行列に対して Nx.pow(2, n)
とすることで、 0 から 29 までの 2 の累乗を全て計算します
Nx.pow(2, Nx.iota({30}))
実行結果
#Nx.Tensor<
s64[30]
EXLA.Backend<host:0, 0.2046932374.739901452.194664>
[1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288, 1048576, 2097152, 4194304, 8388608, 16777216, 33554432, 67108864, 134217728, 268435456, 536870912]
>
行列演算はこのように一気に並列で計算するため、 GPU やマルチコアの CPU を使うことで高速化可能です
続いて、 Nx を使わない場合、 2 で割れる数は以下のように求めていました
powers
|> Enum.map(fn {index, num} ->
case rem(input, num) do
0 ->
index
_ ->
0
end
end)
|> Enum.max()
これを Nx を使うようにする場合、以下のようになります
input_tensor
|> Nx.new_axis(0)
|> Nx.transpose()
|> Nx.remainder(powers)
|> Nx.reverse(axes: [1])
|> Nx.argmin(axis: 1)
|> then(&Nx.subtract(29, &1))
|> Nx.reduce_min()
例えば input_tensor を [8, 12, 40]
とすると、以下のように変形しています
Nx.new_axis(0)
で次元を増やします
[[8, 12, 40]]
Nx.transpose()
で縦横を入れ替えます
\begin{bmatrix}
8 & 12 & 40
\end{bmatrix}
^\top
=
\begin{bmatrix}
8 \\
12 \\
40
\end{bmatrix}
Nx.remainder(powers)
で 2 の n 乗との余りを一気に計算します
\begin{bmatrix}
8 \\
12 \\
40
\end{bmatrix}
\bmod
\begin{bmatrix}
1 & 2 & 4 & 8 & 16 & ... & 2^{29}
\end{bmatrix}
=
\begin{bmatrix}
0 & 0 & 0 & 0 & 8 & 8 & 8 & ... & 8 \\
0 & 0 & 0 & 4 & 12 & 12 & 12 & ... & 12 \\
0 & 0 & 0 & 0 & 8 & 8 & 40 & ... & 40
\end{bmatrix}
以下のようなイメージの計算をしています(A%BはAをBで割った余り)
- | 1 | 2 | 4 | ... | 2^29 |
---|---|---|---|---|---|
8 | 8%1 | 8%2 | 8%4 | ... | 8%2^29 |
12 | 12%1 | 12%2 | 12%4 | ... | 12%2^29 |
40 | 40%1 | 40%2 | 40%4 | ... | 40%2^29 |
Nx.reverse(axes: [1])
で各行の順序を逆にします
\begin{bmatrix}
8 & ... & 8 & 8 & 8 & 0 & 0 & 0 & 0 \\
12 & ... & 12 & 12 & 12 & 4 & 0 & 0 & 0 \\
40 & ... & 40 & 8 & 8 & 0 & 0 & 0 & 0
\end{bmatrix}
Nx.argmin(axis: 1)
で各行について最小値(0)が最初に登場する列番号を取得します
さっき順序を逆にしたため、実質的には 29 - (0が最後に出た位置)
を求めたことになります
\begin{bmatrix}
26 \\
27 \\
26
\end{bmatrix}
then(&Nx.subtract(29, &1))
をすることで 29 - (29 - (0が最後に出た位置))
= 0が最後に出た位置
を取得します
29 -
\begin{bmatrix}
26 \\
27 \\
26
\end{bmatrix}
=
\begin{bmatrix}
3 \\
2 \\
3
\end{bmatrix}
Nx.reduce_min()
でそのうちの最小値(この例では 2)を取得します
この一連の計算で、各値に対する 2 で割れる最大数を求めたことになります
回答の全文はこちら
defmodule Main do
import Nx.Defn
def main do
:stdio
|> IO.read(:all)
|> solve()
|> IO.puts()
end
defp split_lines(lines) do
lines
|> String.trim()
|> String.split("\n")
end
defp split_words(words) do
String.split(words, " ")
end
defn get_powers() do
2 ** Nx.iota({30})
end
defn calc(input_tensor, powers) do
input_tensor
|> Nx.new_axis(0)
|> Nx.transpose()
|> Nx.remainder(powers)
|> Nx.reverse(axes: [1])
|> Nx.argmin(axis: 1)
|> then(&Nx.subtract(29, &1))
|> Nx.reduce_min()
end
def solve(input) do
nums =
input
|> split_lines()
|> Enum.at(1)
|> split_words()
|> Enum.map(&String.to_integer/1)
|> Nx.tensor()
powers = get_powers()
nums
|> calc(powers)
|> Nx.to_number()
end
end
Nx の計算については defn
で定義することにより、高速化しています
また、累乗の計算は defn
内だと **
を使って表されるため、簡略化できます
実行時間は 1009 msec でした
まとめ
今回の問題は元々が繰り返し(Enum.map
) を多く使っていたため、 Nx を有効活用できたと思います
上手く書けると Enum.map
以上にシンプルに表現できるのが気持ちいいですね