AtCoder ABC330をElixirで復習してみよう で問題DがTLEで解けませんでした。
詳しく調べてみます
テストデータ作成
2023/12/8現在またテストケースが公開されていないので、TLEになったテストケースをローカルで実行してみることはできません。
単純にN=2000のランダムデータを生成して時間計測してみます。
テストデータ作成プログラム
defmodule CreateData do
def print_line(f,n) do
str =
for i <- 1..n, reduce: "" do
acc -> acc <> Enum.random(["o","x"])
end
IO.write(f,str)
IO.write(f,"\n")
end
def main() do
n = 2000
{:ok,f} = File.open("in_4.txt",[:write])
IO.write(f,n)
IO.write(f,"\n")
for i <- 1..n do
print_line(f,n)
end
File.close(f)
end
end
TLEになったプログラム
defmodule Main do
import Bitwise
def next_token(acc \\ "") do
case IO.getn(:stdio, "", 1) do
" " -> acc
"\n" -> acc
x -> next_token(acc <> x)
end
end
def input(), do: IO.read(:line) |> String.trim()
def ii(), do: next_token() |> String.to_integer()
def li(), do: input() |> String.split(" ") |> Enum.map(&String.to_integer/1)
def count_char(line,c) do
Enum.count(line, fn x -> x == c end)
end
def count_o(s) do
Enum.map(s, fn line -> count_char(line, ?o) end)
end
def transpose(s) do
Enum.zip(s) |> Enum.map(&Tuple.to_list/1)
end
def main() do
n = ii()
s = for _ <- 1..n, do: input()|> String.to_charlist()
cnt_h = count_o(s)
cnt_v = count_o(transpose(s))
flat_s = List.flatten(s)
flat_cnt_v = List.duplicate(cnt_v, n) |> List.flatten()
flat_cnt_h = Enum.flat_map(cnt_h, fn x -> List.duplicate(x, n) end)
for {c, h, v} <- Enum.zip([flat_s, flat_cnt_h, flat_cnt_v]), reduce: 0 do
acc -> if c == ?o, do: acc + (h-1) * (v-1), else: acc
end
|> IO.puts()
end
end
atcoder-toolsでテストケースを実行してみた結果です。
in_4.txtがN=2000の場合です
20秒。これでは、TLEですね。
検証1 N^2のループだけで測定
結果をreduceで求めてる部分がN^2回実行されて時間のかかる部分です。この部分だけのプログラムにしてみます。
入力データやcnt_h cnt_vも固定値で試してみます
def main() do
n = ii()
flat_s = List.duplicate(?o,n*n)
flat_cnt_v = List.duplicate(0,n*n)
flat_cnt_h = List.duplicate(0,n*n)
for {c, h, v} <- Enum.zip([flat_s, flat_cnt_h, flat_cnt_v]), reduce: 0 do
acc -> if c == ?o, do: acc + (h-1) * (v-1), else: acc
end
IO.puts(1997793203536)
end
プログラム | 実行時間 ms |
---|---|
オリジナル | 20449 |
検証1 | 7511 |
もしかして、List.duplicateって遅いのか?
iex(1)> a = List.duplicate(0,4000000)
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...]
iex(2)> length(a)
4000000
大丈夫です。iexで実行してみましたが、即完了します。これは問題ないです。
Enum.zipを使って、N=4000000のリストを捜査するのは間に合わないらしい。
検証2 Mapにしてみる
頼りにしていたEnum.zipが、意外と遅かったです。
この方法が遅いのは予想外でした。
元記事に@piacerexさんからのコメントにあったMapによる方法に方針変更して試してみます。
keyの値をリストのindex値としたMapに変換する関数作ってみました。
def list_to_map(list) do
list
|> Enum.with_index()
|> Enum.into(%{}, fn {v,i}->{i,v} end)
end
リストの要素アクセスはO(N)なので避けたいとき、list_to_map()でできたmapを使えばO(logN)にできます。
def list_to_map(list) do
list
|> Enum.with_index()
|> Enum.into(%{}, fn {v,i}->{i,v} end)
end
def create_cnt_h(n,s) do
for i <- 0..n-1 do
for j <- 0..n-1, reduce: 0 do
acc -> if s[i][j] == ?o, do: acc+1, else: acc
end
end
|> list_to_map()
# |> IO.inspect(label: "cnt_h")
end
def create_cnt_v(n,s) do
for i <- 0..n-1 do
for j <- 0..n-1, reduce: 0 do
acc -> if s[j][i] == ?o, do: acc+1, else: acc
end
end
|> list_to_map()
# |> IO.inspect(label: "cnt_v")
end
def load_data() do
n = ii()
s = for _ <- 1..n, do: (input()|> String.to_charlist()|> list_to_map())
{n,list_to_map(s)}
end
def get_ans(n,s,cnt_h, cnt_v) do
for i <- 0..n-1, j <- 0..n-1, reduce: 0 do
acc -> if s[i][j] == ?o do
acc + (cnt_h[i] - 1) * (cnt_v[j] - 1)
else
acc
end
end
end
def main() do
{n,s} = load_data()
cnt_h = create_cnt_h(n, s)
cnt_v = create_cnt_v(n, s)
get_ans(n, s, cnt_h, cnt_v)
|> IO.puts()
end
プログラム | 実行時間 ms |
---|---|
オリジナル | 20449 |
検証1 | 7511 |
検証2 | 4796 |
おお!かなり速くなりました。もう一歩。
検証3 Tupleにしてみる
Tupleは、O(1)で値の参照ができるので、MapのO(NlogN)よりは速くなるはずです。
Enumで使えないとか、elem()で書かないといけないとか、ちょっとダルいですが、やってみます。
defmodule Main do
import Bitwise
def next_token(acc \\ "") do
case IO.getn(:stdio, "", 1) do
" " -> acc
"\n" -> acc
x -> next_token(acc <> x)
end
end
def input(), do: IO.read(:line) |> String.trim()
def ii(), do: next_token() |> String.to_integer()
def li(), do: input() |> String.split(" ") |> Enum.map(&String.to_integer/1)
def create_cnt_h(n,s) do
for i <- 0..n-1 do
for j <- 0..n-1, reduce: 0 do
acc -> if s|>elem(i)|>elem(j) == ?o, do: acc+1, else: acc
end
end
|> List.to_tuple()
# |> IO.inspect(label: "cnt_h")
end
def create_cnt_v(n,s) do
for i <- 0..n-1 do
for j <- 0..n-1, reduce: 0 do
acc -> if s|>elem(j)|>elem(i) == ?o, do: acc+1, else: acc
end
end
|> List.to_tuple()
# |> IO.inspect(label: "cnt_v")
end
def load_data() do
n = ii()
s = for _ <- 1..n, do: (input()|> String.to_charlist()|> List.to_tuple())
{n,List.to_tuple(s)}
end
def get_ans(n,s,cnt_h, cnt_v) do
for i <- 0..n-1, j <- 0..n-1, reduce: 0 do
acc -> if s|>elem(i)|>elem(j) == ?o do
acc + (elem(cnt_h,i) - 1) * (elem(cnt_v,j) - 1)
else
acc
end
end
end
def main() do
{n,s} = load_data()
cnt_h = create_cnt_h(n, s)
cnt_v = create_cnt_v(n, s)
get_ans(n, s, cnt_h, cnt_v)
|> IO.puts()
end
end
プログラム | 実行時間 ms |
---|---|
オリジナル | 20449 |
検証1 | 7511 |
検証2 | 4796 |
検証2 | 1149 |
これなら、ACできるのでは?
AtCoderのジャッジにトライ
残念ながら、TLEでした。
test_04.txt, test_11.txtが 2000msを切り、通るようになりました。
ローカル環境でのテストと差があるのは、elixirコマンドでコンパイルして実行しているので軽い事があるとおもいます。
AtCoderのジャッジは、何もしない処理を実行しても、800msくらいかかるので、その差だと思います。
小手先の技でなんとかしてみる
O(N^2)では処理できてて計算量的には合格なんですが、あとちょっとなので、記述方法を変えて速くならないか、ぎりぎりを攻めてみます。
いろいろやってみた結果、次のコードが最速でした
for文のreduceより、Enum.redudeのほうが僅かに速かったです。
TLEx8まで減らせましたが、限界かなぁ。
defmodule Main do
import Bitwise
def next_token(acc \\ "") do
case IO.getn(:stdio, "", 1) do
" " -> acc
"\n" -> acc
x -> next_token(acc <> x)
end
end
def input(), do: IO.read(:line) |> String.trim()
def ii(), do: next_token() |> String.to_integer()
def li(), do: input() |> String.split(" ") |> Enum.map(&String.to_integer/1)
def create_cnt_h(n,s) do
for i <- 0..n-1 do
Enum.reduce(0..n-1,0, fn j, acc ->if s|>elem(i)|>elem(j) == ?o, do: acc+1, else: acc end)
end
|> List.to_tuple()
# |> IO.inspect(label: "cnt_h")
end
def create_cnt_v(n,s) do
for i <- 0..n-1 do
Enum.reduce(0..n-1,0, fn j, acc ->if s|>elem(j)|>elem(i) == ?o, do: acc+1, else: acc end)
end
|> List.to_tuple()
# |> IO.inspect(label: "cnt_v")
end
def load_data() do
n = ii()
s = for _ <- 1..n, do: (input()|> String.to_charlist()|> List.to_tuple())
{n,List.to_tuple(s)}
end
def get_ans(n,s,cnt_h, cnt_v) do
Enum.reduce(0..n-1,0,fn i,acc ->
s_i = elem(s,i)
Enum.reduce(0..n-1,acc, fn j,acc ->
if elem(s_i,j) == ?o do
acc + (elem(cnt_h,i) - 1) * (elem(cnt_v,j) - 1)
else
acc
end
end)
end)
end
def main() do
{n,s} = load_data()
cnt_h = create_cnt_h(n, s)
cnt_v = create_cnt_v(n, s)
get_ans(n, s, cnt_h, cnt_v)
|> IO.puts()
end
end
番外
掛け算の回数を減らしたら速くなるかな?とおもって、jのループを足し算だけにしたバージョンも試してみました。
有意な違いはなく、かえって遅かったです。昔は、コプロセッサ・・・(省略)。
def get_ans(n,s,cnt_h, cnt_v) do
Enum.reduce(0..n-1,0,fn i,acc ->
s_i = elem(s,i)
cnt_h_i = elem(cnt_h,i) - 1
sum_j = Enum.reduce(0..n-1,0 , fn j,acc ->
if elem(s_i,j)==1 do
acc + (elem(cnt_v,j) - 1)
else
acc
end
end)
acc + cnt_h_i * sum_j
end)
end