LoginSignup
9
0

Elixir ABC330問題Dのギリギリを攻めてみる

Last updated at Posted at 2023-12-09

AtCoder ABC330をElixirで復習してみよう で問題DがTLEで解けませんでした。

詳しく調べてみます

テストデータ作成

2023/12/8現在またテストケースが公開されていないので、TLEになったテストケースをローカルで実行してみることはできません。
単純にN=2000のランダムデータを生成して時間計測してみます。
テストデータ作成プログラム

defmodule CreateData do

  def print_line(f,n) do
    str =
      for i <- 1..n, reduce: "" do
      acc -> acc <> Enum.random(["o","x"])
      end
    IO.write(f,str)
    IO.write(f,"\n")
  end

  def main() do
    n = 2000
    {:ok,f} = File.open("in_4.txt",[:write])
    IO.write(f,n)
    IO.write(f,"\n")
    for i <- 1..n do
      print_line(f,n)
    end
    File.close(f)
  end

end

TLEになったプログラム

オリジナル
defmodule Main do
    import Bitwise
    def next_token(acc \\ "") do
        case IO.getn(:stdio, "", 1) do
          " " -> acc
          "\n" -> acc
          x -> next_token(acc <> x)
        end
    end
    def input(), do: IO.read(:line) |> String.trim()
    def ii(), do: next_token() |> String.to_integer()
    def li(), do: input() |> String.split(" ") |> Enum.map(&String.to_integer/1)

    def count_char(line,c) do
        Enum.count(line, fn x -> x == c end)
    end

    def count_o(s) do
      Enum.map(s, fn line -> count_char(line, ?o) end)
    end

    def transpose(s) do
        Enum.zip(s) |> Enum.map(&Tuple.to_list/1)
    end


    def main() do
        n = ii()
        s = for _ <- 1..n, do: input()|> String.to_charlist()
        cnt_h = count_o(s)
        cnt_v = count_o(transpose(s))

        flat_s = List.flatten(s)
        flat_cnt_v = List.duplicate(cnt_v, n) |> List.flatten()
        flat_cnt_h = Enum.flat_map(cnt_h, fn x -> List.duplicate(x, n) end)
        for {c, h, v} <- Enum.zip([flat_s, flat_cnt_h, flat_cnt_v]), reduce: 0 do
            acc -> if c == ?o, do: acc + (h-1) * (v-1), else: acc
        end
        |> IO.puts()
    end

end

atcoder-toolsでテストケースを実行してみた結果です。
in_4.txtがN=2000の場合です

image.png

20秒。これでは、TLEですね。

検証1 N^2のループだけで測定

結果をreduceで求めてる部分がN^2回実行されて時間のかかる部分です。この部分だけのプログラムにしてみます。
入力データやcnt_h cnt_vも固定値で試してみます

検証1
    def main() do
        n = ii()
        flat_s = List.duplicate(?o,n*n)
        flat_cnt_v = List.duplicate(0,n*n)
        flat_cnt_h = List.duplicate(0,n*n)
        for {c, h, v} <- Enum.zip([flat_s, flat_cnt_h, flat_cnt_v]), reduce: 0 do
            acc -> if c == ?o, do: acc + (h-1) * (v-1), else: acc
        end
        IO.puts(1997793203536)
    end
プログラム 実行時間 ms
オリジナル 20449
検証1 7511

もしかして、List.duplicateって遅いのか?

iex(1)> a = List.duplicate(0,4000000)
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...]
iex(2)> length(a)
4000000

大丈夫です。iexで実行してみましたが、即完了します。これは問題ないです。

Enum.zipを使って、N=4000000のリストを捜査するのは間に合わないらしい。

検証2 Mapにしてみる

頼りにしていたEnum.zipが、意外と遅かったです。
この方法が遅いのは予想外でした。
元記事に@piacerexさんからのコメントにあったMapによる方法に方針変更して試してみます。

keyの値をリストのindex値としたMapに変換する関数作ってみました。

    def list_to_map(list) do
        list
        |> Enum.with_index()
        |> Enum.into(%{}, fn {v,i}->{i,v} end)
    end

リストの要素アクセスはO(N)なので避けたいとき、list_to_map()でできたmapを使えばO(logN)にできます。

検証2
    def list_to_map(list) do
        list
        |> Enum.with_index()
        |> Enum.into(%{}, fn {v,i}->{i,v} end)
    end

    def create_cnt_h(n,s) do
        for i <- 0..n-1 do
            for j <- 0..n-1, reduce: 0 do
              acc -> if s[i][j] == ?o, do: acc+1, else: acc
            end
        end
        |> list_to_map()
        # |> IO.inspect(label: "cnt_h")
    end

    def create_cnt_v(n,s) do
        for i <- 0..n-1 do
            for j <- 0..n-1, reduce: 0 do
                acc -> if s[j][i] == ?o, do: acc+1, else: acc
            end
        end
        |> list_to_map()
        # |> IO.inspect(label: "cnt_v")
    end

    def load_data() do
        n = ii()
        s = for _ <- 1..n, do: (input()|> String.to_charlist()|> list_to_map())
        {n,list_to_map(s)}
    end

    def get_ans(n,s,cnt_h, cnt_v) do
        for i <- 0..n-1, j <- 0..n-1, reduce: 0 do
            acc -> if s[i][j] == ?o do
            acc + (cnt_h[i] - 1) * (cnt_v[j] - 1)
            else
            acc
            end
        end
    end

    def main() do

        {n,s} = load_data()
        cnt_h = create_cnt_h(n, s)
        cnt_v = create_cnt_v(n, s)

        get_ans(n, s, cnt_h, cnt_v)
        |> IO.puts()
    end
プログラム 実行時間 ms
オリジナル 20449
検証1 7511
検証2 4796

おお!かなり速くなりました。もう一歩。

検証3 Tupleにしてみる

Tupleは、O(1)で値の参照ができるので、MapのO(NlogN)よりは速くなるはずです。
Enumで使えないとか、elem()で書かないといけないとか、ちょっとダルいですが、やってみます。

defmodule Main do
    import Bitwise
    def next_token(acc \\ "") do
        case IO.getn(:stdio, "", 1) do
          " " -> acc
          "\n" -> acc
          x -> next_token(acc <> x)
        end
    end
    def input(), do: IO.read(:line) |> String.trim()
    def ii(), do: next_token() |> String.to_integer()
    def li(), do: input() |> String.split(" ") |> Enum.map(&String.to_integer/1)

    def create_cnt_h(n,s) do
        for i <- 0..n-1 do
            for j <- 0..n-1, reduce: 0 do
              acc -> if s|>elem(i)|>elem(j) == ?o, do: acc+1, else: acc
            end
        end
        |> List.to_tuple()
        # |> IO.inspect(label: "cnt_h")
    end

    def create_cnt_v(n,s) do
        for i <- 0..n-1 do
            for j <- 0..n-1, reduce: 0 do
                acc -> if s|>elem(j)|>elem(i) == ?o, do: acc+1, else: acc
            end
        end
        |> List.to_tuple()
        # |> IO.inspect(label: "cnt_v")
    end

    def load_data() do
        n = ii()
        s = for _ <- 1..n, do: (input()|> String.to_charlist()|> List.to_tuple())
        {n,List.to_tuple(s)}
    end

    def get_ans(n,s,cnt_h, cnt_v) do
        for i <- 0..n-1, j <- 0..n-1, reduce: 0 do
            acc -> if s|>elem(i)|>elem(j) == ?o do
            acc + (elem(cnt_h,i) - 1) * (elem(cnt_v,j) - 1)
            else
            acc
            end
        end
    end

    def main() do

        {n,s} = load_data()
        cnt_h = create_cnt_h(n, s)
        cnt_v = create_cnt_v(n, s)

        get_ans(n, s, cnt_h, cnt_v)
        |> IO.puts()
    end

end
プログラム 実行時間 ms
オリジナル 20449
検証1 7511
検証2 4796
検証2 1149

これなら、ACできるのでは?

AtCoderのジャッジにトライ

残念ながら、TLEでした。
test_04.txt, test_11.txtが 2000msを切り、通るようになりました。

ローカル環境でのテストと差があるのは、elixirコマンドでコンパイルして実行しているので軽い事があるとおもいます。
AtCoderのジャッジは、何もしない処理を実行しても、800msくらいかかるので、その差だと思います。

小手先の技でなんとかしてみる

O(N^2)では処理できてて計算量的には合格なんですが、あとちょっとなので、記述方法を変えて速くならないか、ぎりぎりを攻めてみます。
いろいろやってみた結果、次のコードが最速でした
for文のreduceより、Enum.redudeのほうが僅かに速かったです。

TLEx8まで減らせましたが、限界かなぁ。

defmodule Main do
    import Bitwise
    def next_token(acc \\ "") do
        case IO.getn(:stdio, "", 1) do
          " " -> acc
          "\n" -> acc
          x -> next_token(acc <> x)
        end
    end
    def input(), do: IO.read(:line) |> String.trim()
    def ii(), do: next_token() |> String.to_integer()
    def li(), do: input() |> String.split(" ") |> Enum.map(&String.to_integer/1)

    def create_cnt_h(n,s) do
        for i <- 0..n-1  do
            Enum.reduce(0..n-1,0, fn j, acc ->if s|>elem(i)|>elem(j) == ?o, do: acc+1, else: acc end)
        end
        |> List.to_tuple()
        # |> IO.inspect(label: "cnt_h")
    end

    def create_cnt_v(n,s) do
        for i <- 0..n-1 do
            Enum.reduce(0..n-1,0, fn j, acc ->if s|>elem(j)|>elem(i) == ?o, do: acc+1, else: acc end)
        end
        |> List.to_tuple()
        # |> IO.inspect(label: "cnt_v")
    end

    def load_data() do
        n = ii()
        s = for _ <- 1..n, do: (input()|> String.to_charlist()|> List.to_tuple())
        {n,List.to_tuple(s)}
    end

    def get_ans(n,s,cnt_h, cnt_v) do
        Enum.reduce(0..n-1,0,fn i,acc ->
            s_i = elem(s,i)
            Enum.reduce(0..n-1,acc, fn j,acc ->
            if elem(s_i,j) == ?o do
                acc + (elem(cnt_h,i) - 1) * (elem(cnt_v,j) - 1)
                else
                acc
                end
           end)
        end)
    end

    def main() do

        {n,s} = load_data()
        cnt_h = create_cnt_h(n, s)
        cnt_v = create_cnt_v(n, s)

        get_ans(n, s, cnt_h, cnt_v)
        |> IO.puts()
    end

end

image.png

番外

掛け算の回数を減らしたら速くなるかな?とおもって、jのループを足し算だけにしたバージョンも試してみました。
有意な違いはなく、かえって遅かったです。昔は、コプロセッサ・・・(省略)。

    def get_ans(n,s,cnt_h, cnt_v) do
        Enum.reduce(0..n-1,0,fn i,acc ->
            s_i = elem(s,i)
            cnt_h_i = elem(cnt_h,i) - 1
            sum_j = Enum.reduce(0..n-1,0 , fn j,acc ->
                if elem(s_i,j)==1 do
                   acc + (elem(cnt_v,j) - 1)
                else
                  acc
                end
           end)
           acc + cnt_h_i * sum_j
        end)
    end
9
0
3

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
9
0