LoginSignup
16
1

[AtCoder Nx] ABC083B - Some Sums [Elixir Livebook]

Last updated at Posted at 2023-11-22

はじめに

AtCoder で Elixir の Nx が使えるようになったので再入門するシリーズです

以前、 Nx なしで解いたのはこちら

今回実装したノートブックはこちら

問題

入出力例

以下のようになれば OK です

"""
20 2 5
"""
|> Main.solve()
|> then(&(&1 == 84))
"""
10 1 2
"""
|> Main.solve()
|> then(&(&1 == 13))
"""
100 4 16
"""
|> Main.solve()
|> then(&(&1 == 4554))

Nx なしの回答

N に上限があるため、 1 の位から 10000 の位まで各桁の数値を計算して足しています

defmodule Main do
  def main do
    :stdio
    |> IO.read(:all)
    |> solve()
    |> IO.puts()
  end

  defp split_words(words) do
    String.split(words, " ")
  end

  def solve(input) do
    [n, a, b] =
      input
      |> String.trim()
      |> split_words()
      |> Enum.map(&String.to_integer/1)

    1..n
    |> Enum.filter(fn num ->
      sum =
        div(num, 10000) +
          (num |> div(1000) |> rem(10)) +
          (num |> div(100) |> rem(10)) +
          (num |> div(10) |> rem(10)) +
          rem(num, 10)

      sum >= a && sum <= b
    end)
    |> Enum.sum()
  end
end

実行時間は 1404 msec でした

Nx ありの回答

各桁の数値を行列演算で求めます

つまり、 1234 を [1, 2, 3, 4] の行列に変換します

例として、 0 から 1234 について、桁に分解してみます

Nx.iota を使い、 0 から 1234 までの行列を作ります

nums = Nx.iota({1235})

実行結果

#Nx.Tensor<
  s64[1235]
  EXLA.Backend<host:0, 0.558030851.3674079243.224401>
  [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, ...]
>

次に、桁計算のために、各桁の元になる行列(10^0 から 10^4)を用意します

本来、 reverse する必要はありませんが、見た目を直観的にするために reverse しています

digit = Nx.pow(10, Nx.iota({5})) |> Nx.reverse()

実行結果

#Nx.Tensor<
  s64[5]
  EXLA.Backend<host:0, 0.558030851.3674079243.224404>
  [10000, 1000, 100, 10, 1]
>

以下のようにして、すべての数について各桁に分解できます

sums =
  nums
  |> Nx.new_axis(0)
  |> Nx.transpose()
  |> Nx.quotient(digit)
  |> Nx.remainder(10)

ABC087B Coins のときに解説した通り、 |> Nx.new_axis(0) |> Nx.transpose() を使うとすべての組み合わせを計算できます

Nx.quotient は整数での割り算(商)で、 Nx.remainder は割り算の余りです
Nx なしのときと同じように、 各桁の 10^n で割ってから、 10 で割った余りを計算することで各桁の値を計算しています
例えば 1234 の 100 の位を計算すると、 1234/100 = 12 12%10=2 で 2 を求めることができます

イメージとしては、以下の計算を 0 から 1234 について一気に実行します

1234 の各桁 = [(1234/1000)%10, (1234/100)%10, (1234/10)%10, (1234/1)%10]

結果を視覚的にわかりやすくするため、 995 から 1004 の範囲で表示してみます
Nx.slice で各次元について指定した範囲を切り抜きます

sums =
  nums
  |> Nx.new_axis(0)
  |> Nx.transpose()
  |> Nx.quotient(digit)
  |> Nx.remainder(10)
  |> Nx.slice([995, 0], [10, 5])

実行結果

#Nx.Tensor<
  s64[10][5]
  EXLA.Backend<host:0, 0.558030851.3674079243.224426>
  [
    [0, 0, 9, 9, 5],
    [0, 0, 9, 9, 6],
    [0, 0, 9, 9, 7],
    [0, 0, 9, 9, 8],
    [0, 0, 9, 9, 9],
    [0, 1, 0, 0, 0],
    [0, 1, 0, 0, 1],
    [0, 1, 0, 0, 2],
    [0, 1, 0, 0, 3],
    [0, 1, 0, 0, 4]
  ]
>

このように、 995 = [0, 0, 9, 9, 5] 、 1004 = [0, 1, 0, 0, 4] としてすべて各桁に分割できています

あとは条件として、各桁の値の合計が A 以上、 B 以下になる値を抽出し、その合計を計算します

実装の全文は以下のとおりです

defmodule Main do
  import Nx.Defn

  def main do
    :stdio
    |> IO.read(:all)
    |> solve()
    |> IO.puts()
  end

  defp split_words(words) do
    String.split(words, " ")
  end

  defn calc(nums, digit, a, b) do
    nums
    |> Nx.new_axis(0)
    |> Nx.transpose()
    |> Nx.quotient(digit)
    |> Nx.remainder(10)
    |> Nx.sum(axes: [1])
    |> then(&Nx.logical_and(Nx.greater_equal(&1, a), Nx.less_equal(&1, b)))
    |> Nx.multiply(nums)
    |> Nx.sum()
  end

  def solve(input) do
    [n, a, b] =
      input
      |> String.trim()
      |> split_words()
      |> Enum.map(&String.to_integer/1)

    nums = Nx.iota({n + 1})
    digit = Nx.pow(10, Nx.iota({5}))

    calc(nums, digit, a, b)
    |> Nx.to_number()
  end
end

実行時間は 1014 msec でした

まとめ

桁の値という、少し特殊な計算も行列演算で実装できました

ループで処理しているようなところは行列演算に代替可能なので、まだまだ色んな応用ができそうですね

16
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
16
1