3
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

ABC404 DをAIのコードを参考にして、Elixirらしい回答を作ってみる

Posted at

はじめに

AIにElixirのコードを書かせてみたところ、よくできてました。

もしかしたら、Elixirらしい良いコードを書いてくれるかもしれないと思い、AtCoderのABC404 Dを解かせてみました。

o3で試したところ、正解するコードを生成してきました。
気になった部分をリファクタリングして、私なりにElixirらしいコードにしてみました。

ABC404のD問題の回答が含まれているので、自分で解いてみたいという型は、読まずに、まず解いてみてください。

問題

前回のAtCoderのD問題を試して見ます。

回答の方針

各動物園に0,1,2回訪問する場合の全探索を行う方針を出してきました。
私がPythonで解いたものと同じ方針です。
このプログラムを見ていきます。

solve関数

全探査機を3進数の値にして行う方法となっています。
再帰を使った探索をもってくるのかと思ったのですが、これもアリですね。

nの値が3進数の値で、0から3 ** n-1までの値を全て探索します。
nの値から、visitsとcostを求めるdecode_state関数を呼び出します。

visitsは、各動物園に訪問した回数を表すリストです。
costは、訪問した動物園のコストの合計を表しています。

全ての動物を2回以上みるの条件を満たしているかを確認してcostを更新します。

solve
    # 動物園の訪問回数を全探索して、最小コストを求める
    defp solve(n, costs, animals) do
        total_states = 3 ** n

        0..(total_states - 1)
        |> Enum.reduce(:infinity, fn state, best ->
          {visits, cost} = decode_state(state, n, costs)

          cond do
            cost >= best ->
              best                                      # 打ち切り

            covers_all?(visits, animals) ->
              min(cost, best)                           # 候補更新

            true ->
              best
          end
        end)
    end

decode_state関数

decode_state関数は、3進数の値を展開して、各動物園の訪問回数とコストを求める関数です。

n桁の3進数なので、reduceを使って、0からn-1までの値を回していきます。
accの3番目の要素に3進数の値を持たせて、remとdivを使って、3進数の値を展開していきます。
この値の初期値を求めたい値、numにしています。

decode_state
    # state を 3 進数に展開して訪問回数リストと費用を返す
    defp decode_state(num, n, costs) do
        Enum.reduce(0..(n - 1), {[], 0, num}, fn idx, {vs, csum, x} ->
            {x, rem3} = {div(x, 3), rem(x, 3)}
            { [rem3 | vs], csum + rem3 * Enum.at(costs, idx), x }
        end)
        |> then(fn {vs, csum, _} -> {Enum.reverse(vs), csum} end)
    end

例えば、decode_state(10, 4, [1, 3, 9, 27])を実行すると

iex(3)> Main.decode_state(11,4,[1,3,9,27])
{[2, 0, 1, 0], 11}

となります。
3進数の値は、2,0,1,0で、コストは11となります。

再帰を使った記述にしてみる

decode_state再帰版
  defp decode_state(num, remaining, costs, visits \\ [], cost_sum \\ 0)

  defp decode_state(_num, 0, _costs, visits, cost_sum) do
      {Enum.reverse(visits), cost_sum}
  end

  defp decode_state(num, remaining, costs, visits , cost_sum) do
      rem3 = rem(num, 3)
      zoo_index = length(costs) - remaining
      cost = rem3 * Enum.at(costs, zoo_index)

      decode_state(
          div(num, 3),
          remaining - 1,
          costs,
          [rem3 | visits],
          cost_sum + cost
      )
  end

デフォルト引数の記述方法がChatGPTは知らなかったようなので、修正してあげました。

もっとわかりやすくならない?

もっとわかりやすくならない?と聞いてみた結果

decode_state3進数文字列版
    @spec decode_state(integer, pos_integer, [integer]) :: {[0 | 1 | 2], integer}
    defp decode_state(state, n, costs) do
      # ① 3 進数の桁列(下位 → 上位)を取得して逆順に
      digits = Integer.digits(state, 3) |> Enum.reverse()

      # ② n 桁に満たなければ左側を 0 でパディング
      visits = List.duplicate(0, n - length(digits)) ++ digits

      # ③ visits と costs を zip しながらコスト合計を計算
      total_cost =
        Enum.zip(visits, costs)
        |> Enum.reduce(0, fn {v, c}, acc -> acc + v * c end)

      {visits, total_cost}
    end

このコードは、一見正しそうなんですが、試してみたら、WAが発生しました。
0のパディングをEnum.reverse()した後に行ってるためでした。おしい。

3進数の変換とtotal_costの計算を分けて行うほうがわかりやすそうなので、この方針に直す事にします。

covers_all?関数

Enum.all?で、それぞれの動物の訪問回数が2回以上かを確認しています。全ての動物でtrue(2回以上)であれば、trueを返します。
最初のコードは次のようなコードでした。

covers_all?
    # すべての動物を 2 回以上見られるか判定
    def covers_all?(visits, animals) do
        Enum.all?(animals, fn zoos ->
            Enum.reduce(zoos, 0, fn z, acc ->
            acc = acc + Enum.at(visits, z)
            if acc >= 2, do: {:halt, acc}, else: {:cont, acc}
            end) >= 2
        end)
    end

AtCoder的には、必須ではない考慮ですが、reduce_whileを使って、訪れた回数が2回になったら、打ち切るようにreduce_whileを使っています。

covers_all?の改良版

このコードでもいいんですが、別の記述にできないか聞いてみます。

covers_all?改良版
    def covers_all?(visits, animals) do
        Enum.all?(animals, fn zoos ->
          zoos
          |> Enum.map(&Enum.at(visits, &1))
          |> Enum.sum() >= 2
        end)
      end

おお。これはわかりやすい。
初見でこのコードは書けないなぁ。

最終的なコード

私好みなコードに整形してみました。
最終的な回答です。
細かい関数にわかれてるので、AtCoderの回答としては、少し冗長かもしれませんが、わかりやすいコードになっていると思います。

defmodule Main do
    def next_token(acc \\ "") do
      case IO.getn(:stdio, "", 1) do
        " " -> acc
        "\n" -> acc
        x -> next_token(acc <> x)
      end
    end
    def input(), do: IO.read(:line) |> String.trim()
    def ii(), do: next_token() |> String.to_integer()
    def li(), do: input() |> String.split(" ") |> Enum.map(&String.to_integer/1)

  # --- 問題固有ロジック ---------------------------------------------------
  @doc """
  全状態 (3^n 通り) を列挙し,条件を満たすものの最小コストを返す
  """
  defp solve(n, costs, animals) do
    for state <- 0..(3 ** n - 1), reduce: :infinity do
      best -> evaluate_state(state, best, n, costs, animals)
    end
  end

  # 各 state を評価して best を更新する reducer
  defp evaluate_state(state, best, n, costs, animals) do
    {visits, cost} = get_visits_and_cost(state, n, costs)

    cond do
      cost >= best               -> best       # すでに最良コストを上回る
      covers_all?(visits, animals) -> cost       # 条件を満たすので最良更新
      true                       -> best       # 条件未達
    end
  end


  # nを3進数に変換。結果は [0 | 1 | 2] で返す
  @spec to_base3(non_neg_integer, pos_integer) :: [0 | 1 | 2]
  @spec to_base3(non_neg_integer, pos_integer, [0 | 1 | 2]) :: [0 | 1 | 2]
  defp to_base3(_state, left, acc \\ [])

  # どちらでもよいけど、LSBが先頭になる順で返す仕様にしておく。
  defp to_base3(_state, 0, acc) do
    Enum.reverse(acc)
  end

  defp to_base3(state, left, acc) do
    to_base3(div(state, 3), left - 1, [rem(state, 3) | acc])
  end

  @spec get_visits_and_cost(integer, pos_integer, [integer]) :: {[0 | 1 | 2], integer}
  defp get_visits_and_cost(state, n, costs) do
    visits = to_base3(state, n)
    total_cost =
      Enum.zip(visits, costs)
      |> Enum.reduce(0, fn {v, c}, acc -> acc + v * c end)

    {visits, total_cost}
  end

  # 全ての動物が2回以上になっているか?
  defp covers_all?(visits, animals) do
    Enum.all?(animals, fn zoos ->
      zoos
      |> Enum.map(&Enum.at(visits, &1))
      |> Enum.sum() >= 2
    end)
  end

  def main do
    [n, m]  = li()
    costs   = li()

    animals =
      for _ <- 1..m do
        [k | a] = li()
        Enum.map(a, &(&1 - 1))        # 0-index 化
      end

    IO.puts(solve(n, costs, animals))
  end
end

まとめ

  • covers_all?の記述は、言われてみればそうなんですが、初見では思いつか無かったので、勉強になりました。
  • 色々な記述方法を提案してもらえたのでその中から、好みのものを選んでみる事ができました。
3
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?