はじめに
AtCoder で Elixir の Nx が使えるようになったので再入門するシリーズです
以前、 Nx なしで解いたのはこちら
今回実装したノートブックはこちら
問題
入出力例
以下のようになれば OK です
"""
20 2 5
"""
|> Main.solve()
|> then(&(&1 == 84))
"""
10 1 2
"""
|> Main.solve()
|> then(&(&1 == 13))
"""
100 4 16
"""
|> Main.solve()
|> then(&(&1 == 4554))
Nx なしの回答
N に上限があるため、 1 の位から 10000 の位まで各桁の数値を計算して足しています
defmodule Main do
def main do
:stdio
|> IO.read(:all)
|> solve()
|> IO.puts()
end
defp split_words(words) do
String.split(words, " ")
end
def solve(input) do
[n, a, b] =
input
|> String.trim()
|> split_words()
|> Enum.map(&String.to_integer/1)
1..n
|> Enum.filter(fn num ->
sum =
div(num, 10000) +
(num |> div(1000) |> rem(10)) +
(num |> div(100) |> rem(10)) +
(num |> div(10) |> rem(10)) +
rem(num, 10)
sum >= a && sum <= b
end)
|> Enum.sum()
end
end
実行時間は 1404 msec でした
Nx ありの回答
各桁の数値を行列演算で求めます
つまり、 1234 を [1, 2, 3, 4] の行列に変換します
例として、 0 から 1234 について、桁に分解してみます
Nx.iota
を使い、 0 から 1234 までの行列を作ります
nums = Nx.iota({1235})
実行結果
#Nx.Tensor<
s64[1235]
EXLA.Backend<host:0, 0.558030851.3674079243.224401>
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, ...]
>
次に、桁計算のために、各桁の元になる行列(10^0 から 10^4)を用意します
本来、 reverse する必要はありませんが、見た目を直観的にするために reverse しています
digit = Nx.pow(10, Nx.iota({5})) |> Nx.reverse()
実行結果
#Nx.Tensor<
s64[5]
EXLA.Backend<host:0, 0.558030851.3674079243.224404>
[10000, 1000, 100, 10, 1]
>
以下のようにして、すべての数について各桁に分解できます
sums =
nums
|> Nx.new_axis(0)
|> Nx.transpose()
|> Nx.quotient(digit)
|> Nx.remainder(10)
ABC087B Coins のときに解説した通り、 |> Nx.new_axis(0) |> Nx.transpose()
を使うとすべての組み合わせを計算できます
Nx.quotient
は整数での割り算(商)で、 Nx.remainder
は割り算の余りです
Nx なしのときと同じように、 各桁の 10^n で割ってから、 10 で割った余りを計算することで各桁の値を計算しています
例えば 1234 の 100 の位を計算すると、 1234/100 = 12
12%10=2
で 2 を求めることができます
イメージとしては、以下の計算を 0 から 1234 について一気に実行します
1234 の各桁 = [(1234/1000)%10, (1234/100)%10, (1234/10)%10, (1234/1)%10]
結果を視覚的にわかりやすくするため、 995 から 1004 の範囲で表示してみます
Nx.slice
で各次元について指定した範囲を切り抜きます
sums =
nums
|> Nx.new_axis(0)
|> Nx.transpose()
|> Nx.quotient(digit)
|> Nx.remainder(10)
|> Nx.slice([995, 0], [10, 5])
実行結果
#Nx.Tensor<
s64[10][5]
EXLA.Backend<host:0, 0.558030851.3674079243.224426>
[
[0, 0, 9, 9, 5],
[0, 0, 9, 9, 6],
[0, 0, 9, 9, 7],
[0, 0, 9, 9, 8],
[0, 0, 9, 9, 9],
[0, 1, 0, 0, 0],
[0, 1, 0, 0, 1],
[0, 1, 0, 0, 2],
[0, 1, 0, 0, 3],
[0, 1, 0, 0, 4]
]
>
このように、 995 = [0, 0, 9, 9, 5]
、 1004 = [0, 1, 0, 0, 4]
としてすべて各桁に分割できています
あとは条件として、各桁の値の合計が A 以上、 B 以下になる値を抽出し、その合計を計算します
実装の全文は以下のとおりです
defmodule Main do
import Nx.Defn
def main do
:stdio
|> IO.read(:all)
|> solve()
|> IO.puts()
end
defp split_words(words) do
String.split(words, " ")
end
defn calc(nums, digit, a, b) do
nums
|> Nx.new_axis(0)
|> Nx.transpose()
|> Nx.quotient(digit)
|> Nx.remainder(10)
|> Nx.sum(axes: [1])
|> then(&Nx.logical_and(Nx.greater_equal(&1, a), Nx.less_equal(&1, b)))
|> Nx.multiply(nums)
|> Nx.sum()
end
def solve(input) do
[n, a, b] =
input
|> String.trim()
|> split_words()
|> Enum.map(&String.to_integer/1)
nums = Nx.iota({n + 1})
digit = Nx.pow(10, Nx.iota({5}))
calc(nums, digit, a, b)
|> Nx.to_number()
end
end
実行時間は 1014 msec でした
まとめ
桁の値という、少し特殊な計算も行列演算で実装できました
ループで処理しているようなところは行列演算に代替可能なので、まだまだ色んな応用ができそうですね