はじめに
AtCoder で Elixir の Nx が使えるようになったので再入門するシリーズです
以前、 Nx なしで解いたのはこちら
今回実装したノートブックはこちら
問題
入出力例
以下のようになれば OK です
"""
2
2
2
100
"""
|> Main.solve()
|> then(&(&1 == 2))
"""
5
1
0
150
"""
|> Main.solve()
|> then(&(&1 == 0))
"""
30
40
50
6000
"""
|> Main.solve()
|> then(&(&1 == 213))
Nx なしの回答
内包表現を用いて条件を満たす組み合わせを探索し、その数を出しています
defmodule Main do
def main do
:stdio
|> IO.read(:all)
|> solve()
|> IO.puts()
end
defp split_lines(lines) do
lines
|> String.trim()
|> String.split("\n")
end
def solve(input) do
[num_500, num_100, num_50, total] =
input
|> split_lines()
|> Enum.map(&String.to_integer/1)
total_50 = total / 50
for p_500 <- 0..num_500,
p_100 <- 0..num_100,
p_50 <- 0..num_50,
10 * p_500 + 2 * p_100 + p_50 == total_50 do
{p_500, p_100, p_50}
end
|> Enum.count()
end
end
実行時間は 804 msec でした
Nx ありの回答
堂々と全ての組み合わせを計算し、合計金額が X 円になった組み合わせの数を出します
今回の場合、3種類の硬貨の全ての組み合わせを計算するため、例えば枚数が以下のような場合を考えます
- 500 円玉 2 枚
- 100 円玉 2 枚
- 50 円玉 2 枚
ここから各硬貨を何枚かずつ使う組み合わせを考えます
まず 500 円玉と 100 円玉だけで考えてみます
組み合わせは以下のように 9 通り考えられます
- 500 * 0 + 100 * 0
- 500 * 0 + 100 * 1
- 500 * 0 + 100 * 2
- 500 * 1 + 100 * 0
- 500 * 1 + 100 * 1
- 500 * 1 + 100 * 2
- 500 * 2 + 100 * 0
- 500 * 2 + 100 * 1
- 500 * 2 + 100 * 2
これは 0 から 2 までの整数の組み合わせを考えているのと同じです
Nx.add()
を使うと、全ての和の組み合わせを計算できます
[0, 1, 2] と [0, 1, 2] の全ての和の組み合わせは以下のように計算します
\begin{bmatrix}
0 & 1 & 2
\end{bmatrix}
^\top
+
\begin{bmatrix}
0 & 1 & 2
\end{bmatrix}
=
\begin{bmatrix}
0 \\
1 \\
2
\end{bmatrix}
+
\begin{bmatrix}
0 & 1 & 2
\end{bmatrix}
=
\begin{bmatrix}
0 && 1 && 2 \\
1 && 2 && 3 \\
2 && 3 && 4
\end{bmatrix}
ここでは以下のようなイメージの計算をしています
- | 0 | 1 | 2 |
---|---|---|---|
0 | 0+0 | 0+1 | 0+2 |
1 | 1+0 | 1+1 | 1+2 |
2 | 2+0 | 2+1 | 2+2 |
これは次元が増えても同じように計算できます
つまり、片方の次元を増やして転置して足せば良いだけです
500 円玉の枚数を n としたとき、使った枚数毎の金額を表す行列は以下のコードで表せます
Nx.iota(x)
= [0, 1, 2, ..., x - 1]
{n + 1}
|> Nx.iota()
|> Nx.multiply(500)
これらを利用して、最終的なコードは以下のようになります
defmodule Main do
import Nx.Defn
def main do
:stdio
|> IO.read(:all)
|> solve()
|> IO.puts()
end
defp split_lines(lines) do
lines
|> String.trim()
|> String.split("\n")
end
defn calc(tensor_500, tensor_100, tensor_50, total) do
Nx.multiply(tensor_500, 500)
|> Nx.new_axis(0)
|> Nx.transpose()
|> Nx.add(Nx.multiply(tensor_100, 100))
|> Nx.new_axis(0)
|> Nx.transpose()
|> Nx.add(Nx.multiply(tensor_50, 50))
|> Nx.equal(total)
|> Nx.sum()
end
def solve(input) do
[num_500, num_100, num_50, total] =
input
|> split_lines()
|> Enum.map(&String.to_integer/1)
tensor_500 = Nx.iota({num_500 + 1})
tensor_100 = Nx.iota({num_100 + 1})
tensor_50 = Nx.iota({num_50 + 1})
calc(tensor_500, tensor_100, tensor_50, total)
|> Nx.to_number()
end
end
実行時間は 1718 msec でした
まとめ
Nx を使って組み合わせの問題を解きました
動的計画法よりも、やっていること自体は直観的で単純です
行列演算らしくできたので満足です