AtCoder ARC104の4問目をScala、C++、Java、Ruby、Perl、Elixirで解きました。言語ごとの処理速度の比較ができました。
結果
言語 | 結果 |
---|---|
Scala | TLE |
C++ | AC 1830ms |
Java | AC 2995ms |
JavaっぽいScala | AC 2991ms |
Ruby | TLE |
Perl | TLE |
Elixir | TLE |
競技プログラミングにおける言語の選択と処理速度の関係は以下のツイートに言及がありました。
このツイートによると、C++、Javaは大丈夫だけど、他の言語は計算量がシビアな問題には厳しいようです。
競技中は解法を見つけ出すのに時間がかかり、わかったあとはScalaで書いたものの、バグ取りに手間取って、AC(正解)に至りませんでした。競技後も続けたものの、どうしても実行時間制限4秒のTLEを解消できませんでした。
あきらめて同じロジックをC++で書いたらあっさりACを取れました。
そこで、せっかくなのでいろんな言語で書いてみて、AC取れるかを比較した結果です。
Scalaは最初はTLEだったものの、JavaでAC取ったあとにJavaを翻訳したようなScalaを書いたらScalaでもACとれました。
この記事の残りは各言語のソースコードです。
Scala
Array
を使うなど中途半端にJavaっぽいScalaコードですが、TLEを解消できずにあきらめたコードです。
object Main extends App {
val sc = new java.util.Scanner(System.in);
val n, k, m = sc.nextInt();
val table = (0 until n).map(i => Array.fill(i * (i+1) * k / 2 + 2)(0));
table(0)(0) = 1;
(1 until n).foreach { i =>
val j_max = i * (i+1) * k / 2;
val t1 = table(i - 1);
val t2 = table(i);
(0 to j_max).foreach { j =>
val a1 = (((j - (i - 1) * i * k / 2) + i - 1) / i) max 0;
val a2 = (j / i) min k;
val s = (a1 to a2).map { a =>
t1(j - a * i).toLong;
}.sum % m;
t2(j) = s.toInt;
}
}
val table2 = Array.fill(n+1)(0);
(1 to n).foreach { x =>
if (x <= (n + 1) / 2) {
val (n1, n2) = (x - 1, n - x);
val s = (0 to (n1 * (n1+1) * k / 2)).map { i =>
table(n1)(i).toLong * table(n2)(i) % m;
}.sum % m;
val answer = ((m.toLong + s * (k+1) - 1) % m).toInt;
table2(x) = answer;
println("%d".format(answer));
} else {
val answer = table2(n + 1 - x);
println("%d".format(answer));
}
}
}
後述のJavaでAC取れたあと、JavaをScalaに翻訳したようなコードを書いたところ、Javaとほぼ同じ2991msでした。Scalaでこういうコードを書くぐらいならJavaでいいわけで、Scalaは計算量が厳しい問題には向いていないという結論です。Scalaは私にとって書きやすいので今後も使い分ければよいです。
object Main extends App {
val sc = new java.util.Scanner(System.in);
val n, k, m = sc.nextInt();
val table = new Array[Array[Int]](n);
table(0) = new Array[Int](2);
table(0)(0) = 1;
var i: Int = 1;
while (i < n) {
val j_max = i * (i+1) * k / 2;
val t1 = table(i - 1);
val t2 = new Array[Int](i * (i+1) * k / 2 + 2);
table(i) = t2;
var j: Int = 0;
while (j <= j_max) {
val a1 = (((j - (i - 1) * i * k / 2) + i - 1) / i) max 0;
val a2 = (j / i) min k;
var s: Long = 0L;
var a: Int = a1;
while (a <= a2) {
s += t1(j - a * i).toLong;
a += 1;
}
t2(j) = (s % m).toInt;
j += 1;
}
i += 1;
}
val table2 = new Array[Int](n+1);
var x: Int = 1;
while (x <= n) {
if (x <= (n + 1) / 2) {
val (n1, n2) = (x - 1, n - x);
var s: Long = 0L;
var i: Int = 0;
var max = (n1 * (n1+1) * k / 2);
while (i <= max) {
s += table(n1)(i).toLong * table(n2)(i) % m;
i += 1;
}
s = s % m;
val answer = ((m.toLong + s * (k+1) - 1) % m).toInt;
table2(x) = answer;
println("%d".format(answer));
} else {
val answer = table2(n + 1 - x);
println("%d".format(answer));
}
x += 1;
}
}
C++
同じロジックをC++で書き直したところ、1830msでAC取れました。
#include <bits/stdc++.h>
using namespace std;
int main() {
int n, k, m;
cin >> n >> k >> m;
vector<vector<int>> table(n);
table[0] = vector<int>(2);
table[0][0] = 1;
for (int i = 1; i < n; i++) {
auto t1 = table[i-1];
auto t2 = vector<int>(i * (i+1) * k / 2 + 2);
int j_max = i * (i+1) * k / 2;
for (int j = 0; j <= j_max; j++) {
int a1 = ((j - (i - 1) * i * k / 2) + i - 1) / i;
if (a1 < 0) a1 = 0;
int a2 = j / i;
if (a2 > k) a2 = k;
long s = 0;
for (int a = a1; a <= a2; a++) {
s += t1[j - a * i];
}
s = s % m;
t2[j] = (int)s;
}
table[i] = t2;
}
auto table2 = vector<int>((n + 1) / 2 + 1);
for (int x = 1; x <= (n + 1) / 2; x++) {
int n1 = x - 1;
int n2 = n - x;
long s = 0;
for (int i = 0; i <= n1 * (n1+1) * k / 2; i++) {
s += (long)table[n1][i] * table[n2][i] % m;
}
int answer = (int)(((long)m + s * (k+1) - 1) % m);
table2[x] = answer;
printf("%d\n", answer);
}
for (int x = (n + 1) / 2 + 1; x <= n; x++) {
printf("%d\n", table2[n + 1 - x]);
}
}
Java
C++でACが取れ、解法が間違ってないことがわかりましたので、次はJavaで書きました。2995msでAC取れました。C++よりは遅いものの、同じJVMのScalaよりも高速です。JVMでできるということは、ScalaでもJavaっぽい書き方にすればできるはずと考えたところ、ScalaでもACが取れたのは先述のとおりです。
import java.util.Scanner;
class Main {
public static void main(String[] args) {
var sc = new Scanner(System.in);
var n = sc.nextInt();
var k = sc.nextInt();
var m = sc.nextInt();
var table = new int[n][];
table[0] = new int[2];
table[0][0] = 1;
for (int i = 1; i < n; i++) {
var t1 = table[i-1];
var t2 = new int[i * (i+1) * k / 2 + 2];
int j_max = i * (i+1) * k / 2;
for (int j = 0; j <= j_max; j++) {
int a1 = ((j - (i - 1) * i * k / 2) + i - 1) / i;
if (a1 < 0) a1 = 0;
int a2 = j / i;
if (a2 > k) a2 = k;
long s = 0;
for (int a = a1; a <= a2; a++) {
s += t1[j - a * i];
}
s = s % m;
t2[j] = (int)s;
}
table[i] = t2;
}
var table2 = new int[(n + 1) / 2 + 1];
for (int x = 1; x <= (n + 1) / 2; x++) {
int n1 = x - 1;
int n2 = n - x;
long s = 0;
for (int i = 0; i <= n1 * (n1+1) * k / 2; i++) {
s += (long)table[n1][i] * table[n2][i] % m;
}
int answer = (int)(((long)m + s * (k+1) - 1) % m);
table2[x] = answer;
System.out.printf("%d\n", answer);
}
for (int x = (n + 1) / 2 + 1; x <= n; x++) {
System.out.printf("%d\n", table2[n + 1 - x]);
}
}
}
Ruby
TLEでした。
n, k, m = gets.strip.split(" ").map(&:to_i)
table = [[1]]
for i in 1..n-1
t1 = table[i-1]
t2 = []
for j in 0 .. i * (i+1) * k / 2
a1 = ((j - (i - 1) * i * k / 2) + i - 1) / i;
a1 = 0 if a1 < 0
a2 = j / i;
a2 = k if a2 > k
s = 0;
for a in a1 .. a2
s += t1[j - a * i]
end
s = s % m
t2.push(s)
end
table.push(t2)
end
table2 = [0]
for x in 1 .. (n + 1) / 2
n1 = x - 1
n2 = n - x
s = 0
for i in 0 .. n1 * (n1+1) * k / 2
s += table[n1][i] * table[n2][i] % m;
end
answer = (m + s * (k+1) - 1) % m
table2.push(answer)
p answer
end
for x in (n + 1) / 2 + 1 .. n
p table2[n + 1 - x]
end
Perl
TLEでした。
use strict;
use warnings;
use integer;
my $nkm = <STDIN>;
$nkm =~ s/\n\z//;
my ($n, $k, $m) = split(/ /, $nkm);
my $table = [[1]];
for (my $i = 1; $i < $n; $i++) {
my $t1 = $table->[$i - 1];
my $t2 = [];
my $j_max = $i * ($i + 1) * $k / 2;
for (my $j = 0; $j <= $j_max; $j++) {
my $a1 = (($j - ($i - 1) * $i * $k / 2) + $i - 1) / $i;
$a1 = 0 if $a1 < 0;
my $a2 = $j / $i;
$a2 = $k if $a2 > $k;
my $s = 0;
for (my $aa = $a1; $aa <= $a2; $aa++) {
$s += $t1->[$j - $aa * $i];
}
$s = $s % $m;
push(@$t2, $s);
}
push(@$table, $t2);
}
my $table2 = [0];
for (my $x = 1; $x <= ($n + 1) / 2; $x++) {
my $n1 = $x - 1;
my $n2 = $n - $x;
my $s = 0;
for (my $i = 0; $i <= $n1 * ($n1+1) * $k / 2; $i++) {
$s += $table->[$n1]->[$i] * $table->[$n2]->[$i] % $m;
}
my $answer = ($m + $s * ($k+1) - 1) % $m;
push(@$table2, $answer);
printf("%d\n", $answer);
}
for (my $x = ($n + 1) / 2 + 1; $x <= $n; $x++) {
printf("%d\n", $table2->[$n + 1 - $x]);
}
Elixir
TLEでした。
defmodule Main do
def main do
[n, k, m] = IO.read(:line) |> String.trim() |> String.split(" ") |> Enum.map(&String.to_integer/1)
table = Enum.reduce(1 .. n, [[1, 0]], fn i, acc ->
[t1 | _] = acc
j_max = div(i * (i+1) * k, 2);
t2 = Enum.map(0 .. j_max, fn j ->
a1 = max(div((j - div((i - 1) * i * k, 2)) + i - 1, i), 0)
a2 = min(div(j, i), k)
rem(Enum.reduce(a1 .. a2, 0, fn a, acc3 ->
acc3 + Enum.at(t1, j - a * i)
end), m)
end)
[Enum.reverse(t2) | acc]
end) |> Enum.reverse
table2 = (1 .. div(n + 1, 2)) |> Enum.map(fn x ->
[n1, n2] = [x - 1, n - x]
s = Enum.reduce(0 .. div(n1 * (n1 + 1) * k, 2), 0, fn i, acc ->
rem((table |> Enum.at(n1) |> Enum.at(i)) * (table |> Enum.at(n2) |> Enum.at(i)) + acc, m)
end)
rem(m + s * (k+1) - 1, m)
end)
(1 .. n) |> Enum.each(fn x ->
if x <= div(n + 1, 2) do
IO.puts(Enum.at(table2, x - 1))
else
IO.puts(Enum.at(table2, n - x))
end
end)
end
end