LoginSignup
13
1

[AtCoder Nx] ABC087B - Coins [Elixir Livebook]

Last updated at Posted at 2023-11-20

はじめに

AtCoder で Elixir の Nx が使えるようになったので再入門するシリーズです

以前、 Nx なしで解いたのはこちら

今回実装したノートブックはこちら

問題

入出力例

以下のようになれば OK です

"""
2
2
2
100
"""
|> Main.solve()
|> then(&(&1 == 2))
"""
5
1
0
150
"""
|> Main.solve()
|> then(&(&1 == 0))
"""
30
40
50
6000
"""
|> Main.solve()
|> then(&(&1 == 213))

Nx なしの回答

内包表現を用いて条件を満たす組み合わせを探索し、その数を出しています

defmodule Main do
  def main do
    :stdio
    |> IO.read(:all)
    |> solve()
    |> IO.puts()
  end

  defp split_lines(lines) do
    lines
    |> String.trim()
    |> String.split("\n")
  end

  def solve(input) do
    [num_500, num_100, num_50, total] =
      input
      |> split_lines()
      |> Enum.map(&String.to_integer/1)

    total_50 = total / 50

    for p_500 <- 0..num_500,
        p_100 <- 0..num_100,
        p_50 <- 0..num_50,
        10 * p_500 + 2 * p_100 + p_50 == total_50 do
      {p_500, p_100, p_50}
    end
    |> Enum.count()
  end
end

実行時間は 804 msec でした

Nx ありの回答

堂々と全ての組み合わせを計算し、合計金額が X 円になった組み合わせの数を出します

今回の場合、3種類の硬貨の全ての組み合わせを計算するため、例えば枚数が以下のような場合を考えます

  • 500 円玉 2 枚
  • 100 円玉 2 枚
  • 50 円玉 2 枚

ここから各硬貨を何枚かずつ使う組み合わせを考えます

まず 500 円玉と 100 円玉だけで考えてみます
組み合わせは以下のように 9 通り考えられます

  • 500 * 0 + 100 * 0
  • 500 * 0 + 100 * 1
  • 500 * 0 + 100 * 2
  • 500 * 1 + 100 * 0
  • 500 * 1 + 100 * 1
  • 500 * 1 + 100 * 2
  • 500 * 2 + 100 * 0
  • 500 * 2 + 100 * 1
  • 500 * 2 + 100 * 2

これは 0 から 2 までの整数の組み合わせを考えているのと同じです

Nx.add() を使うと、全ての和の組み合わせを計算できます

[0, 1, 2]  と [0, 1, 2] の全ての和の組み合わせは以下のように計算します

\begin{bmatrix}
0 & 1 & 2
\end{bmatrix}
^\top
+
\begin{bmatrix}
0 & 1 & 2
\end{bmatrix}
=
\begin{bmatrix}
0 \\
1 \\
2
\end{bmatrix}
+
\begin{bmatrix}
0 & 1 & 2
\end{bmatrix}
=
\begin{bmatrix}
0 && 1 && 2 \\
1 && 2 && 3 \\
2 && 3 && 4
\end{bmatrix}

ここでは以下のようなイメージの計算をしています

- 0 1 2
0 0+0 0+1 0+2
1 1+0 1+1 1+2
2 2+0 2+1 2+2

これは次元が増えても同じように計算できます

つまり、片方の次元を増やして転置して足せば良いだけです

500 円玉の枚数を n としたとき、使った枚数毎の金額を表す行列は以下のコードで表せます

Nx.iota(x) = [0, 1, 2, ..., x - 1]

{n + 1}
|> Nx.iota()
|> Nx.multiply(500)

これらを利用して、最終的なコードは以下のようになります

defmodule Main do
  import Nx.Defn

  def main do
    :stdio
    |> IO.read(:all)
    |> solve()
    |> IO.puts()
  end

  defp split_lines(lines) do
    lines
    |> String.trim()
    |> String.split("\n")
  end

  defn calc(tensor_500, tensor_100, tensor_50, total) do
    Nx.multiply(tensor_500, 500)
    |> Nx.new_axis(0)
    |> Nx.transpose()
    |> Nx.add(Nx.multiply(tensor_100, 100))
    |> Nx.new_axis(0)
    |> Nx.transpose()
    |> Nx.add(Nx.multiply(tensor_50, 50))
    |> Nx.equal(total)
    |> Nx.sum()
  end

  def solve(input) do
    [num_500, num_100, num_50, total] =
      input
      |> split_lines()
      |> Enum.map(&String.to_integer/1)

    tensor_500 = Nx.iota({num_500 + 1})
    tensor_100 = Nx.iota({num_100 + 1})
    tensor_50 = Nx.iota({num_50 + 1})

    calc(tensor_500, tensor_100, tensor_50, total)
    |> Nx.to_number()
  end
end

実行時間は 1718 msec でした

まとめ

Nx を使って組み合わせの問題を解きました

動的計画法よりも、やっていること自体は直観的で単純です

行列演算らしくできたので満足です

13
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
13
1