はじめに
AIにElixirのコードを書かせてみたところ、よくできてました。
もしかしたら、Elixirらしい良いコードを書いてくれるかもしれないと思い、AtCoderのABC404 Dを解かせてみました。
o3で試したところ、正解するコードを生成してきました。
気になった部分をリファクタリングして、私なりにElixirらしいコードにしてみました。
ABC404のD問題の回答が含まれているので、自分で解いてみたいという型は、読まずに、まず解いてみてください。
問題
前回のAtCoderのD問題を試して見ます。
回答の方針
各動物園に0,1,2回訪問する場合の全探索を行う方針を出してきました。
私がPythonで解いたものと同じ方針です。
このプログラムを見ていきます。
solve関数
全探査機を3進数の値にして行う方法となっています。
再帰を使った探索をもってくるのかと思ったのですが、これもアリですね。
nの値が3進数の値で、0から3 ** n-1までの値を全て探索します。
nの値から、visitsとcostを求めるdecode_state関数を呼び出します。
visitsは、各動物園に訪問した回数を表すリストです。
costは、訪問した動物園のコストの合計を表しています。
全ての動物を2回以上みるの条件を満たしているかを確認してcostを更新します。
# 動物園の訪問回数を全探索して、最小コストを求める
defp solve(n, costs, animals) do
total_states = 3 ** n
0..(total_states - 1)
|> Enum.reduce(:infinity, fn state, best ->
{visits, cost} = decode_state(state, n, costs)
cond do
cost >= best ->
best # 打ち切り
covers_all?(visits, animals) ->
min(cost, best) # 候補更新
true ->
best
end
end)
end
decode_state関数
decode_state関数は、3進数の値を展開して、各動物園の訪問回数とコストを求める関数です。
n桁の3進数なので、reduceを使って、0からn-1までの値を回していきます。
accの3番目の要素に3進数の値を持たせて、remとdivを使って、3進数の値を展開していきます。
この値の初期値を求めたい値、numにしています。
# state を 3 進数に展開して訪問回数リストと費用を返す
defp decode_state(num, n, costs) do
Enum.reduce(0..(n - 1), {[], 0, num}, fn idx, {vs, csum, x} ->
{x, rem3} = {div(x, 3), rem(x, 3)}
{ [rem3 | vs], csum + rem3 * Enum.at(costs, idx), x }
end)
|> then(fn {vs, csum, _} -> {Enum.reverse(vs), csum} end)
end
例えば、decode_state(10, 4, [1, 3, 9, 27])を実行すると
iex(3)> Main.decode_state(11,4,[1,3,9,27])
{[2, 0, 1, 0], 11}
となります。
3進数の値は、2,0,1,0で、コストは11となります。
再帰を使った記述にしてみる
defp decode_state(num, remaining, costs, visits \\ [], cost_sum \\ 0)
defp decode_state(_num, 0, _costs, visits, cost_sum) do
{Enum.reverse(visits), cost_sum}
end
defp decode_state(num, remaining, costs, visits , cost_sum) do
rem3 = rem(num, 3)
zoo_index = length(costs) - remaining
cost = rem3 * Enum.at(costs, zoo_index)
decode_state(
div(num, 3),
remaining - 1,
costs,
[rem3 | visits],
cost_sum + cost
)
end
デフォルト引数の記述方法がChatGPTは知らなかったようなので、修正してあげました。
もっとわかりやすくならない?
もっとわかりやすくならない?と聞いてみた結果
@spec decode_state(integer, pos_integer, [integer]) :: {[0 | 1 | 2], integer}
defp decode_state(state, n, costs) do
# ① 3 進数の桁列(下位 → 上位)を取得して逆順に
digits = Integer.digits(state, 3) |> Enum.reverse()
# ② n 桁に満たなければ左側を 0 でパディング
visits = List.duplicate(0, n - length(digits)) ++ digits
# ③ visits と costs を zip しながらコスト合計を計算
total_cost =
Enum.zip(visits, costs)
|> Enum.reduce(0, fn {v, c}, acc -> acc + v * c end)
{visits, total_cost}
end
このコードは、一見正しそうなんですが、試してみたら、WAが発生しました。
0のパディングをEnum.reverse()した後に行ってるためでした。おしい。
3進数の変換とtotal_costの計算を分けて行うほうがわかりやすそうなので、この方針に直す事にします。
covers_all?関数
Enum.all?で、それぞれの動物の訪問回数が2回以上かを確認しています。全ての動物でtrue(2回以上)であれば、trueを返します。
最初のコードは次のようなコードでした。
# すべての動物を 2 回以上見られるか判定
def covers_all?(visits, animals) do
Enum.all?(animals, fn zoos ->
Enum.reduce(zoos, 0, fn z, acc ->
acc = acc + Enum.at(visits, z)
if acc >= 2, do: {:halt, acc}, else: {:cont, acc}
end) >= 2
end)
end
AtCoder的には、必須ではない考慮ですが、reduce_whileを使って、訪れた回数が2回になったら、打ち切るようにreduce_whileを使っています。
covers_all?の改良版
このコードでもいいんですが、別の記述にできないか聞いてみます。
def covers_all?(visits, animals) do
Enum.all?(animals, fn zoos ->
zoos
|> Enum.map(&Enum.at(visits, &1))
|> Enum.sum() >= 2
end)
end
おお。これはわかりやすい。
初見でこのコードは書けないなぁ。
最終的なコード
私好みなコードに整形してみました。
最終的な回答です。
細かい関数にわかれてるので、AtCoderの回答としては、少し冗長かもしれませんが、わかりやすいコードになっていると思います。
defmodule Main do
def next_token(acc \\ "") do
case IO.getn(:stdio, "", 1) do
" " -> acc
"\n" -> acc
x -> next_token(acc <> x)
end
end
def input(), do: IO.read(:line) |> String.trim()
def ii(), do: next_token() |> String.to_integer()
def li(), do: input() |> String.split(" ") |> Enum.map(&String.to_integer/1)
# --- 問題固有ロジック ---------------------------------------------------
@doc """
全状態 (3^n 通り) を列挙し,条件を満たすものの最小コストを返す
"""
defp solve(n, costs, animals) do
for state <- 0..(3 ** n - 1), reduce: :infinity do
best -> evaluate_state(state, best, n, costs, animals)
end
end
# 各 state を評価して best を更新する reducer
defp evaluate_state(state, best, n, costs, animals) do
{visits, cost} = get_visits_and_cost(state, n, costs)
cond do
cost >= best -> best # すでに最良コストを上回る
covers_all?(visits, animals) -> cost # 条件を満たすので最良更新
true -> best # 条件未達
end
end
# nを3進数に変換。結果は [0 | 1 | 2] で返す
@spec to_base3(non_neg_integer, pos_integer) :: [0 | 1 | 2]
@spec to_base3(non_neg_integer, pos_integer, [0 | 1 | 2]) :: [0 | 1 | 2]
defp to_base3(_state, left, acc \\ [])
# どちらでもよいけど、LSBが先頭になる順で返す仕様にしておく。
defp to_base3(_state, 0, acc) do
Enum.reverse(acc)
end
defp to_base3(state, left, acc) do
to_base3(div(state, 3), left - 1, [rem(state, 3) | acc])
end
@spec get_visits_and_cost(integer, pos_integer, [integer]) :: {[0 | 1 | 2], integer}
defp get_visits_and_cost(state, n, costs) do
visits = to_base3(state, n)
total_cost =
Enum.zip(visits, costs)
|> Enum.reduce(0, fn {v, c}, acc -> acc + v * c end)
{visits, total_cost}
end
# 全ての動物が2回以上になっているか?
defp covers_all?(visits, animals) do
Enum.all?(animals, fn zoos ->
zoos
|> Enum.map(&Enum.at(visits, &1))
|> Enum.sum() >= 2
end)
end
def main do
[n, m] = li()
costs = li()
animals =
for _ <- 1..m do
[k | a] = li()
Enum.map(a, &(&1 - 1)) # 0-index 化
end
IO.puts(solve(n, costs, animals))
end
end
まとめ
- covers_all?の記述は、言われてみればそうなんですが、初見では思いつか無かったので、勉強になりました。
- 色々な記述方法を提案してもらえたのでその中から、好みのものを選んでみる事ができました。