LoginSignup
2
0

More than 3 years have passed since last update.

ARC104のD Multiset MeanをScala、Java、C++、Ruby、Perl、Elixirで解く

Posted at

AtCoder ARC104の4問目をScala、C++、Java、Ruby、Perl、Elixirで解きました。言語ごとの処理速度の比較ができました。

問題へのリンク

結果

言語 結果
Scala TLE
C++ AC 1830ms
Java AC 2995ms
JavaっぽいScala AC 2991ms
Ruby TLE
Perl TLE
Elixir TLE

競技プログラミングにおける言語の選択と処理速度の関係は以下のツイートに言及がありました。

このツイートによると、C++、Javaは大丈夫だけど、他の言語は計算量がシビアな問題には厳しいようです。

競技中は解法を見つけ出すのに時間がかかり、わかったあとはScalaで書いたものの、バグ取りに手間取って、AC(正解)に至りませんでした。競技後も続けたものの、どうしても実行時間制限4秒のTLEを解消できませんでした。

あきらめて同じロジックをC++で書いたらあっさりACを取れました。

そこで、せっかくなのでいろんな言語で書いてみて、AC取れるかを比較した結果です。

Scalaは最初はTLEだったものの、JavaでAC取ったあとにJavaを翻訳したようなScalaを書いたらScalaでもACとれました。

この記事の残りは各言語のソースコードです。

Scala

Arrayを使うなど中途半端にJavaっぽいScalaコードですが、TLEを解消できずにあきらめたコードです。

object Main extends App {
  val sc = new java.util.Scanner(System.in);
  val n, k, m = sc.nextInt();

  val table = (0 until n).map(i => Array.fill(i * (i+1) * k / 2 + 2)(0));
  table(0)(0) = 1;
  (1 until n).foreach { i =>
    val j_max = i * (i+1) * k / 2;
    val t1 = table(i - 1);
    val t2 = table(i);
    (0 to j_max).foreach { j =>
      val a1 = (((j - (i - 1) * i * k / 2) + i - 1) / i) max 0;
      val a2 = (j / i) min k;
      val s = (a1 to a2).map { a =>
        t1(j - a * i).toLong;
      }.sum % m;
      t2(j) = s.toInt;
    }
  }

  val table2 = Array.fill(n+1)(0);
  (1 to n).foreach { x =>
    if (x <= (n + 1) / 2) {
      val (n1, n2) = (x - 1, n - x);
      val s = (0 to (n1 * (n1+1) * k / 2)).map { i =>
        table(n1)(i).toLong * table(n2)(i) % m;
      }.sum % m;
      val answer = ((m.toLong + s * (k+1) - 1) % m).toInt;
      table2(x) = answer;
      println("%d".format(answer));
    } else {
      val answer = table2(n + 1 - x);
      println("%d".format(answer));
    }
  }
}

後述のJavaでAC取れたあと、JavaをScalaに翻訳したようなコードを書いたところ、Javaとほぼ同じ2991msでした。Scalaでこういうコードを書くぐらいならJavaでいいわけで、Scalaは計算量が厳しい問題には向いていないという結論です。Scalaは私にとって書きやすいので今後も使い分ければよいです。

object Main extends App {
  val sc = new java.util.Scanner(System.in);
  val n, k, m = sc.nextInt();

  val table = new Array[Array[Int]](n);
  table(0) = new Array[Int](2);
  table(0)(0) = 1;
  var i: Int = 1;
  while (i < n) {
    val j_max = i * (i+1) * k / 2;
    val t1 = table(i - 1);
    val t2 = new Array[Int](i * (i+1) * k / 2 + 2);
    table(i) = t2;
    var j: Int = 0;
    while (j <= j_max) {
      val a1 = (((j - (i - 1) * i * k / 2) + i - 1) / i) max 0;
      val a2 = (j / i) min k;
      var s: Long = 0L;
      var a: Int = a1;
      while (a <= a2) {
        s += t1(j - a * i).toLong;
        a += 1;
      }
      t2(j) = (s % m).toInt;
      j += 1;
    }
    i += 1;
  }

  val table2 = new Array[Int](n+1);
  var x: Int = 1;
  while (x <= n) {
    if (x <= (n + 1) / 2) {
      val (n1, n2) = (x - 1, n - x);
      var s: Long = 0L;
      var i: Int = 0;
      var max = (n1 * (n1+1) * k / 2);
      while (i <= max) {
        s += table(n1)(i).toLong * table(n2)(i) % m;
        i += 1;
      }
      s = s % m;
      val answer = ((m.toLong + s * (k+1) - 1) % m).toInt;
      table2(x) = answer;
      println("%d".format(answer));
    } else {
      val answer = table2(n + 1 - x);
      println("%d".format(answer));
    }
    x += 1;
  }
}

C++

同じロジックをC++で書き直したところ、1830msでAC取れました。

#include <bits/stdc++.h>
using namespace std;

int main() {
  int n, k, m;
  cin >> n >> k >> m;

  vector<vector<int>> table(n);
  table[0] = vector<int>(2);
  table[0][0] = 1;
  for (int i = 1; i < n; i++) {
    auto t1 = table[i-1];
    auto t2 = vector<int>(i * (i+1) * k / 2 + 2);
    int j_max = i * (i+1) * k / 2;
    for (int j = 0; j <= j_max; j++) {
      int a1 = ((j - (i - 1) * i * k / 2) + i - 1) / i;
      if (a1 < 0) a1 = 0;
      int a2 = j / i;
      if (a2 > k) a2 = k;
      long s = 0;
      for (int a = a1; a <= a2; a++) {
        s += t1[j - a * i];
      }
      s = s % m;
      t2[j] = (int)s;
    }
    table[i] = t2;
  }

  auto table2 = vector<int>((n + 1) / 2 + 1);
  for (int x = 1; x <= (n + 1) / 2; x++) {
    int n1 = x - 1;
    int n2 = n - x;
    long s = 0;
    for (int i = 0; i <= n1 * (n1+1) * k / 2; i++) {
      s += (long)table[n1][i] * table[n2][i] % m;
    }
    int answer = (int)(((long)m + s * (k+1) - 1) % m);
    table2[x] = answer;
    printf("%d\n", answer);
  }
  for (int x = (n + 1) / 2 + 1; x <= n; x++) {
    printf("%d\n", table2[n + 1 - x]);
  }
}

Java

C++でACが取れ、解法が間違ってないことがわかりましたので、次はJavaで書きました。2995msでAC取れました。C++よりは遅いものの、同じJVMのScalaよりも高速です。JVMでできるということは、ScalaでもJavaっぽい書き方にすればできるはずと考えたところ、ScalaでもACが取れたのは先述のとおりです。

import java.util.Scanner;

class Main {
  public static void main(String[] args) {
    var sc = new Scanner(System.in);
    var n = sc.nextInt();
    var k = sc.nextInt();
    var m = sc.nextInt();

    var table = new int[n][];
    table[0] = new int[2];
    table[0][0] = 1;
    for (int i = 1; i < n; i++) {
      var t1 = table[i-1];
      var t2 = new int[i * (i+1) * k / 2 + 2];
      int j_max = i * (i+1) * k / 2;
      for (int j = 0; j <= j_max; j++) {
        int a1 = ((j - (i - 1) * i * k / 2) + i - 1) / i;
        if (a1 < 0) a1 = 0;
        int a2 = j / i;
        if (a2 > k) a2 = k;
        long s = 0;
        for (int a = a1; a <= a2; a++) {
          s += t1[j - a * i];
        }
        s = s % m;
        t2[j] = (int)s;
      }
      table[i] = t2;
    }

    var table2 = new int[(n + 1) / 2 + 1];
    for (int x = 1; x <= (n + 1) / 2; x++) {
      int n1 = x - 1;
      int n2 = n - x;
      long s = 0;
      for (int i = 0; i <= n1 * (n1+1) * k / 2; i++) {
        s += (long)table[n1][i] * table[n2][i] % m;
      }
      int answer = (int)(((long)m + s * (k+1) - 1) % m);
      table2[x] = answer;
      System.out.printf("%d\n", answer);
    }
    for (int x = (n + 1) / 2 + 1; x <= n; x++) {
      System.out.printf("%d\n", table2[n + 1 - x]);
    }
  }
}

Ruby

TLEでした。

n, k, m = gets.strip.split(" ").map(&:to_i)

table = [[1]]
for i in 1..n-1
  t1 = table[i-1]
  t2 = []
  for j in 0 .. i * (i+1) * k / 2
    a1 = ((j - (i - 1) * i * k / 2) + i - 1) / i;
    a1 = 0 if a1 < 0
    a2 = j / i;
    a2 = k if a2 > k
    s = 0;
    for a in a1 .. a2
      s += t1[j - a * i]
    end
    s = s % m
    t2.push(s)
  end
  table.push(t2)
end

table2 = [0]
for x in 1 .. (n + 1) / 2
  n1 = x - 1
  n2 = n - x
  s = 0
  for i in 0 .. n1 * (n1+1) * k / 2
    s += table[n1][i] * table[n2][i] % m;
  end
  answer = (m + s * (k+1) - 1) % m
  table2.push(answer)
  p answer
end
for x in (n + 1) / 2 + 1 .. n
  p table2[n + 1 - x]
end

Perl

TLEでした。

use strict;
use warnings;
use integer;

my $nkm = <STDIN>;
$nkm =~ s/\n\z//;
my ($n, $k, $m) = split(/ /, $nkm);

my $table = [[1]];
for (my $i = 1; $i < $n; $i++) {
    my $t1 = $table->[$i - 1];
    my $t2 = [];
    my $j_max = $i * ($i + 1) * $k / 2;
    for (my $j = 0; $j <= $j_max; $j++) {
        my $a1 = (($j - ($i - 1) * $i * $k / 2) + $i - 1) / $i;
        $a1 = 0 if $a1 < 0;
        my $a2 = $j / $i;
        $a2 = $k if $a2 > $k;
        my $s = 0;
        for (my $aa = $a1; $aa <= $a2; $aa++) {
            $s += $t1->[$j - $aa * $i];
        }
        $s = $s % $m;
        push(@$t2, $s);
    }
    push(@$table, $t2);
}

my $table2 = [0];
for (my $x = 1; $x <= ($n + 1) / 2; $x++) {
    my $n1 = $x - 1;
    my $n2 = $n - $x;
    my $s = 0;
    for (my $i = 0; $i <= $n1 * ($n1+1) * $k / 2; $i++) {
        $s += $table->[$n1]->[$i] * $table->[$n2]->[$i] % $m;
    }
    my $answer = ($m + $s * ($k+1) - 1) % $m;
    push(@$table2, $answer);
    printf("%d\n", $answer);
}
for (my $x = ($n + 1) / 2 + 1; $x <= $n; $x++) {
    printf("%d\n", $table2->[$n + 1 - $x]);
}

Elixir

TLEでした。

defmodule Main do
  def main do
    [n, k, m] = IO.read(:line) |> String.trim() |> String.split(" ") |> Enum.map(&String.to_integer/1)
    table = Enum.reduce(1 .. n, [[1, 0]], fn i, acc ->
      [t1 | _] = acc
      j_max = div(i * (i+1) * k, 2);
      t2 = Enum.map(0 .. j_max, fn j ->
        a1 = max(div((j - div((i - 1) * i * k, 2)) + i - 1, i), 0)
        a2 = min(div(j, i), k)
        rem(Enum.reduce(a1 .. a2, 0, fn a, acc3 ->
          acc3 + Enum.at(t1, j - a * i)
        end), m)
      end)
      [Enum.reverse(t2) | acc]
    end) |> Enum.reverse

    table2 = (1 .. div(n + 1, 2)) |> Enum.map(fn x ->
      [n1, n2] = [x - 1, n - x]
      s = Enum.reduce(0 .. div(n1 * (n1 + 1) * k, 2), 0, fn i, acc ->
        rem((table |> Enum.at(n1) |> Enum.at(i)) * (table |> Enum.at(n2) |> Enum.at(i)) + acc, m)
      end)
      rem(m + s * (k+1) - 1, m)
    end)
    (1 .. n) |> Enum.each(fn x ->
      if x <= div(n + 1, 2) do
        IO.puts(Enum.at(table2, x - 1))
      else
        IO.puts(Enum.at(table2, n - x))
      end
    end)
  end
end
2
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
0