Javaで競技プログラミングをするときの罠とテク

  • 54
    いいね
  • 0
    コメント

はじめに

この記事は「Java競技プログラミング」の前半部分を元に書き直したものです。後半部分は「Javaで競技プログラミングをするときによく使う標準ライブラリ」です。

今までJavaで競技プログラミングをしていて自分がつまづいたところや、知って役にたったと思ったことをまとめました。
備忘録でもありますが、これを読んだ方がJavaの罠を回避してもらえれば嬉しいです。
主にJavaで簡単な問題が解ける人を対象とします。基本的な文法については触れません。

記事中のコードは適宜import文やmain関数の部分を省略します。

提出

多くのオンラインジャッジではクラス名がMainである必要があり、またデフォルトパッケージでないと正常に実行されません。
AtCoderでは、デフォルトパッケージではない場合は結果がコンパイルエラーではなく、ランタイムエラーとなるので注意。
基本的な提出コードは以下のようになります。

//package hoge; のようなパッケージ宣言は提出前に除く
//インポート文など
public class Main { //クラス名はMain
    public static void main(String[] args) {
        //コード
    }
}

ただし、TopCoderのように提出がこの形式でないシステムもあります。

入力

Scannerが遅い

参考:http://d.hatena.ne.jp/wata_orz/20090914/1252902159

入力が標準入力で渡されるオンラインジャッジでは、Scannerを使うと楽にパースが出来ます。例えば1つの整数の入力は次のようにします。

Scanner sc = new Scanner(System.in);
int n = sc.nextInt();

この方法は低速でメモリ消費も大きいので注意が必要です。パースをInteger.parseInt()で行うことで簡単に高速化ができます。

Scanner sc = new Scanner(System.in);
int n = Integer.parseInt(sc.next());

さらに高速化が必要な場合は、Scannerの代わりになる処理を自前で実装すると良いです。オンラインのコンテストでは大抵自作ライブラリの使用が認められているので、高速なライブラリを作ってしまいましょう。
私は、入力される要素がおよそ$10^5$個以上になる場合やTLEが心配な場合にライブラリを貼っています。

私のライブラリの一部抜粋です。uwiさんのライブラリを参考にしました。

class FastScanner {
    private final InputStream in = System.in;
    private final byte[] buffer = new byte[1024];
    private int ptr = 0;
    private int buflen = 0;
    private boolean hasNextByte() {
        if (ptr < buflen) {
            return true;
        }else{
            ptr = 0;
            try {
                buflen = in.read(buffer);
            } catch (IOException e) {
                e.printStackTrace();
            }
            if (buflen <= 0) {
                return false;
            }
        }
        return true;
    }
    private int readByte() { if (hasNextByte()) return buffer[ptr++]; else return -1;}
    private static boolean isPrintableChar(int c) { return 33 <= c && c <= 126;}
    public boolean hasNext() { while(hasNextByte() && !isPrintableChar(buffer[ptr])) ptr++; return hasNextByte();}
    public String next() {
        if (!hasNext()) throw new NoSuchElementException();
        StringBuilder sb = new StringBuilder();
        int b = readByte();
        while(isPrintableChar(b)) {
            sb.appendCodePoint(b);
            b = readByte();
        }
        return sb.toString();
    }
    public long nextLong() {
        if (!hasNext()) throw new NoSuchElementException();
        long n = 0;
        boolean minus = false;
        int b = readByte();
        if (b == '-') {
            minus = true;
            b = readByte();
        }
        if (b < '0' || '9' < b) {
            throw new NumberFormatException();
        }
        while(true){
            if ('0' <= b && b <= '9') {
                n *= 10;
                n += b - '0';
            }else if(b == -1 || !isPrintableChar(b)){
                return minus ? -n : n;
            }else{
                throw new NumberFormatException();
            }
            b = readByte();
        }
    }
    public int nextInt() {
        long nl = nextLong();
        if (nl < Integer.MIN_VALUE || nl > Integer.MAX_VALUE) throw new NumberFormatException();
        return (int) nl;
    }
    public double nextDouble() { return Double.parseDouble(next());}
}

ベンチマーク

入力処理についてベンチマークを取りました。$10^9$以下の正の整数$10^7$個が改行区切りで与えられるので、入力を読んでその和を求めます。FastScannerは上記のライブラリです。
ソースコード:https://ideone.com/DBFV5a

実行時間(ms)
Scanner.nextInt() 10591
Integer.parseInt(sc.next()) 5300
FastScanner.nextInt() 988

このように、自前で実装するとScannerの10倍くらい速くできます。

Scanner.nextLine()

Scanner.next()Scanner.nextInt()などで行末に到達した後にScanner.nextLine()を呼び出すと、次の行が読まれるのではなく、空文字列が帰ってくるので注意しましょう。どうしても使い分けが必要なのでなければ、next系かnextLineのどちらかのみを使用するのが良いでしょう。

結果の整形と出力

文字列結合

+演算子を使うことで文字列の結合をすることができます。しかし、たくさんの文字列を繰り返し処理などで結合したい場合はStringBuilderを用いた方が高速です。
文字列結合をした後出力するだけならば、直接出力するのも手でしょう。

ベンチマーク: https://ideone.com/z0F775

    //遅い
    public static String repeatString1(String s,int n) {
        String ret = "";
        for(int i=0;i<n;i++) {
            ret += s;
        }
        return ret;
    }
    //速い    
    public static String repeatString2(String s,int n) {
        StringBuilder sb = new StringBuilder();
        for(int i=0;i<n;i++) {
            sb.append(s);
        }
        return sb.toString();
    }

出力

標準出力に文字列を書き出すには普通System.out.printlnを用いますが、デフォルトではオートフラッシュが有効なため低速です。
解決策の一つとして、PrintWriterを用いて自動フラッシュをしないようにする方法があります。この場合、プログラムの終了まえにフラッシュを忘れないよう注意してください。

PrintWriter out = new PrintWriter(System.out);
int n = 100;
for(int i=0;i<n;i++) {
    out.println("hoge");
}
out.flush();

小数のフォーマット

出力形式として小数点以下の桁数を指定される場合があります。
DecimalFormatString.formatSystem.out.printfなどでフォーマットができます。doubleBigDecimalも同じようにフォーマットが可能です。

public class Main {
    public static void main(String[] args) {
        DecimalFormat df = new DecimalFormat("0.00000");
        System.out.println(df.format(Math.PI*10));
        System.out.println(String.format("%.5f",Math.PI*10));
        System.out.printf("%.5f\n",Math.PI*10);

        BigDecimal x = BigDecimal.ONE.divide(new BigDecimal("13"), 40, RoundingMode.HALF_EVEN);
        DecimalFormat df2 = new DecimalFormat("0.000000000000000000000000000000");
        System.out.println(df2.format(x));
        System.out.println(String.format("%.30f",x));
        System.out.printf("%.30f\n",x);
    }
}

出力
31.41593
31.41593
31.41593
0.076923076923076923076923076923
0.076923076923076923076923076923
0.076923076923076923076923076923

また、指数を含まない表記にしたい場合があります。小数点以下の桁数を指定するか、BigDecimalの場合はBigDecimal.toPlainStringも使えます。

public class Main {
    public static void main(String[] args) {
        BigDecimal x = BigDecimal.ONE.divide(new BigDecimal("13"), 60, RoundingMode.HALF_EVEN);
        x = x.scaleByPowerOfTen(20).setScale(-10, RoundingMode.HALF_EVEN);
        System.out.println(x);
        System.out.printf("%.0f\n",x);
        System.out.println(x.toPlainString());

        double y = Math.PI * Math.pow(10, 10);
        System.out.println(y);
        System.out.printf("%.0f\n",y);
    }
}
出力
7.69230769E+18
7692307690000000000
7692307690000000000
3.141592653589793E10
31415926536

doubleの高速なフォーマット

uwiさんのコメントを参考にしました。ありがとうございます。
doubleを大量に出力する場合、System.out.printfでは低速です。DecimalFormatの方がやや高速で、自前でフォーマットすると更に高速化できます。
以下のソースコードはuwiさんのコメントからです。

    public static String dtos(double x, int n) {
        StringBuilder sb = new StringBuilder();
        if(x < 0){
            sb.append('-');
            x = -x;
        }
        x += Math.pow(10, -n)/2;
//      if(x < 0){ x = 0; }
        sb.append((long)x);
        sb.append(".");
        x -= (long)x;
        for(int i = 0;i < n;i++){
            x *= 10;
            sb.append((int)x);
            x -= (int)x;
        }
        return sb.toString();
    }

ベンチマークを取ってみました。Math.random()で長さ$10^6$個の配列を生成した後、配列を改行区切りで出力する時間を計測しました。
ソースコード:https://ideone.com/uRElnn

実行時間(ms)
printf 2121
String.format 2264
DecimalFormat.format 1446
dtos 543

スタックオーバーフロー

配列外参照などのミスが無さそうなのにREになる場合、スタックオーバーフローの可能性が考えられます。
スレッドのデフォルトのスタックサイズはJVMのオプションによって決まり、これを設定していない場合は環境ごとのデフォルト値(320KB~1024KB)が用いられます。
引数の数によって前後しますが、再帰の深さが数万程度になるとよくスタックオーバーフローを起こします。
C++と比べるとスタックオーバーフローしやすいので注意が必要です。

    public static void main(String[] args) {
        System.out.println(recursive(65000)); //(自分の環境では)スタックオーバーフロー
    }
    static int recursive(int n) {
        if (n == 0) {
            return n;
        }else{
            return recursive(n-1) + 1;
        }
    }
}

スタックの拡張

スタックの拡張は、スタックサイズを指定してスレッドを作り実行することで可能です。
実装するメソッドはRunnable.run()ですが呼び出すのはThread.start()であることに注意してください。間違えてrun()を呼び出した場合、スタックを拡張せず実行することになるので気づきにくいです。

public class Main implements Runnable { //Runnableを実装する
    public static void main(String[] args) {
        new Thread(null, new Main(), "", 16 * 1024 * 1024).start(); //16MBスタックを確保して実行
    }
    public void run() {
        //ここに処理を書く
    }
}

MLE

ほとんど何もしないプログラムでもJVMがメモリを持って行ったり、全てのオブジェクトがヘッダに8byte使っているなど、C++と比べると消費メモリが増えやすい傾向があります。

配列の消費メモリを小さくする

参考:http://www.limy.org/program/java/memory_usage.html

int[1000000][2]int[2][1000000]では、後者のほうが消費メモリが少なく、この例では1/8程度です。
このような例はフラグを持たせたDPでよく登場します。要素数がとても少ない次元は左のほうに持ってくるようにしましょう。

配列に入れたい値の範囲に応じて適切な型を選択すれば、メモリを節約できる可能性があります。プリミティブ型の消費メモリは以下のようになります。

消費メモリ(byte)
byte / boolean 1
short / char 2
int / float 4
long / double 8

データ構造を使い回す

入力に複数のテストケースが含まれている場合など、配列を何度も生成したくなる場合があります。(AOJのICPC系の問題に多い)

public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        while(true) {
            int n = sc.nextInt();
            if (n == 0) break;
            int[] a = new int[n];
            for(int i=0;i<n;i++) {
                a[i] = sc.nextInt();
            }
            //メインの処理
        }
    }
}

配列をいちいちnewで確保するのはメモリの消費量を増やすので、予め確保しておきます。処理時間も短くなります。
初期化したい場合はArrays.fillを使います。

public class Main {
    public static final int MAX_N = 10000;
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int[] a = new int[MAX_N];
        while(true) {
            int n = sc.nextInt();
            if (n == 0) break;
            for(int i=0;i<n;i++) {
                a[i] = sc.nextInt();
            }
            //メインの処理
        }
    }
}

ArrayListなどのCollectionの実装クラスの場合はclear()が使えます。

ガベージコレクション(GC)を実行する

入力に複数のテストケースが含まれるような場合で、配列などの使い回しも面倒な場合は、テストケース毎にGCを行いメモリを開放する手があります。
System.gc()を呼び出すことでGCを実行することができます。実行時間は長くなるので注意して下さい。

public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        while(true) {
            int n = sc.nextInt();
            int[] a = new int[n];
            if (n == 0) break;
            for(int i=0;i<n;i++) {
                a[i] = sc.nextInt();
            }
            //メインの処理
            System.gc(); //1テストケースごとにメモリを解放
        }
    }
}

演算の罠

Javaに限った話ではありませんが、基本的な演算に思わぬ落とし穴があったりします。

負の数を含む割り算

結果が負になる整数の割り算をした場合、切り捨てられる方向は絶対値が小さくなる方です。
切り捨ても含めた計算式を考えてコードに落としこむときに間違えやすいです。

public class Main {
    public static void main(String[] args) {
        for(int i=-4;i<=4;i++) {
            System.out.println(i + " / 3 = " + (i/3));
        }
    }
}
出力
-4 / 3 = -1
-3 / 3 = -1
-2 / 3 = 0
-1 / 3 = 0
0 / 3 = 0
1 / 3 = 0
2 / 3 = 0
3 / 3 = 1
4 / 3 = 1

doubleintへのキャストでも同じように絶対値が小さくなる方に丸めが行われます。
必要に応じてMath.floorMath.ceilを使いましょう。

除算と剰余算は遅い

除算と剰余算は他の演算や基本操作(例えば+,-,*,<<,比較など)と比べてとても低速です。
「○○を$10^9+7$で割ったあまりを答えよ」系の問題では、剰余算の数をどれだけ減らせるかが実行時間に大きく響いてきます。

例えばmod $10^9+7$上のフィボナッチ数列を素直に実装すると次のようになります。

public class Main {
    public static final long MOD = 1000000007;
    public static void main(String[] args) {
        long stime = System.nanoTime();
        int n = 10000000;
        long[] fib = new long[n+1];
        fib[0] = 0; fib[1] = 1;
        for(int i=2;i<=n;i++) {
            fib[i] = (fib[i-2] + fib[i-1]) % MOD; //ここが遅い
        }
        System.out.println((System.nanoTime() - stime) / 1000000 + " ms");
        System.out.println(fib[n]);
    }
}

fib[i]には$0$以上$10^9+7$未満の値が入るので、(fib[i-2] + fib[i-1])は$2 \times (10^9+7)$`未満であるはずです。
したがって、剰余算を使わなくても$10^9+7$以上の時に$10^9+7$を引けばよさそうです。

public class Main {
    public static final long MOD = 1000000007;
    public static void main(String[] args) {
        long stime = System.nanoTime();
        int n = 10000000;
        long[] fib = new long[n+1];
        fib[0] = 0; fib[1] = 1;
        for(int i=2;i<=n;i++) {
            fib[i] = fib[i-2] + fib[i-1];
            if (fib[i] >= MOD) fib[i] -= MOD; //条件分岐と減算に変更
        }
        System.out.println((System.nanoTime() - stime) / 1000000 + " ms");
        System.out.println(fib[n]);
    }
}

このように剰余算の回数を減らす工夫をすることで高速化が見込めます。
ideoneで実行した所、後者の方が526ms,191msで2倍以上高速という結果になりましたが、ideoneが32ビット環境というのが大きく効いていると思います。
手元の64ビット環境ではそれぞれ92ms,68msでした。
ソースコード:https://ideone.com/Eq7PFx

ビットシフトの罠

ビットシフトのシフトする量は左辺がintだとmod32、longだとmod64が取られます。

public class Main {
    public static void main(String[] args) {
        for(int i=30;i<=33;i++) {
            System.out.println(1<<i);
        }
    }
}
出力
1073741824
-2147483648
1
2

また、左辺がintで右辺がlongの時、結果はintになるので注意。(他の二項演算子と異なります)

public class Main {
    public static void main(String[] args) {
        long a = 1  << 60 ; //int
        long b = 1  << 60L; //これもint
        long c = 1L << 60 ;
        long d = 1L << 60L;
        System.out.println("1  << 60  :" + a);
        System.out.println("1  << 60L :" + b);
        System.out.println("1L << 60  :" + c);
        System.out.println("1L << 60L :" + d);
    }
}
出力
1  << 60  :268435456
1  << 60L :268435456
1L << 60  :1152921504606846976
1L << 60L :1152921504606846976

Javaでは論理右シフトは>>>演算子を、算術右シフトでは>>を使います。

public class Main {
    public static void main(String[] args) {
        System.out.println(0xFFFFFFFF >>  16);
        System.out.println(0xFFFFFFFF >>> 16);
    }
}
出力
-1
65535

longからdoubleへの変換

longからdoubleへのキャストは暗黙的に行えるが、longはdoubleの仮数部に収まらないので誤差が発生します。

次の例では、long型の大きい奇数$10^{16}+1$をMath.pow(double,double)の第二引数に渡しています。暗黙の型変換が起こりdoubleにキャストされますが、ここで丸め誤差が発生し$10^{16}$になってしまいます。

public class Main {
    public static void main(String[] args) {
        System.out.println(Math.pow(-1, 10000000000000001L));
        System.out.printf("%.0f\n",(double) 10000000000000001L);
    }
}
出力
1.0
10000000000000000

BigDecimalのオーバーフロー

BigDecimalは任意精度符号付き小数を扱うクラスです。
BigDecimalの指数部はint型なので指数がオーバーフローするような計算、つまり非常に大きい値や0にとても近い値になるような場合に例外を発生させます。

その他の罠

オートボクシング / アンボクシング

プリミティブ型とラッパ型間の変換は、オートボクシング / アンボクシング機能によって暗黙的に行われます。
しかしながら、ボクシングが多く発生すると実行速度に影響が出るので注意しましょう。
特に、コレクションはボクシングを回避するように自前で実装することで高速化できることがあります。

ArrayDeque<Integer>と自作キューを、$10^7$回のofferpollをした場合で比較してみます。

public class Main {
    public static final int N = 10000000;
    public static void main(String[] args) {
        Queue<Integer> q1 = new ArrayDeque<>(N);
        long stime = System.nanoTime();
        for(int i=0;i<N;i++) q1.offer(i);
        for(int i=0;i<N;i++) q1.poll();
        System.out.println((System.nanoTime() - stime) / 1000000 + " ms");

        IntQueue q2 = new IntQueue(N);
        stime = System.nanoTime();
        for(int i=0;i<N;i++) q2.offer(i);
        for(int i=0;i<N;i++) q2.poll();
        System.out.println((System.nanoTime() - stime) / 1000000 + " ms");
    }
}
class IntQueue {
    protected int[] a;
    protected int head,tail;
    public IntQueue(int max) {
        a = new int[max];
    }
    public void offer(int x) {
        a[tail++] = x;
    }
    public int poll() {
        return a[head++];
    }
}
実行結果
1788 ms
27 ms

このように、標準ライブラリと自前実装で大きく差がつきました。実装の違いはあるもの、実行時間の差はボクシングの有無にあります。
ボクシングの回数が$10^7$を超えたら普通の制限時間ではTLEするでしょう。$10^6$を超えたら警戒した方がよさそうです。

比較

==を用いてオブジェクトを比較すると、参照の比較になります。次の例のようにStringのインスタンスを2つ生成して比較するとfalseとなります。
一方、equalsを用いて比較した場合、Stringequalsをオーバーライドしているので、文字列の中身で比較します。
オブジェクト同士の比較には基本的にequalsを使いましょう。

public class Main {
    public static void main(String[] args) {
        String a = new String("abc");
        String b = new String("abc");
        System.out.println(a + " " + b);
        System.out.println(a == b);
        System.out.println(a.equals(b));
    }
}
出力
abc abc
false
true

ラッパークラスの比較

ラッパークラス同士を比較するときも、必ずequalsを使うようにしましょう。
ラッパー型の変数を宣言していなくても、標準ライブラリの返り値などに現れるので注意しましょう。
(ArrayList<Integer>を使う場合など)

public class Main {
    public static void main(String[] args) {
        Integer wrapped1 = Integer.valueOf(123456);
        Integer wrapped2 = Integer.valueOf(123456);
        int primitive1 = 123456;
        int primitive2 = 123456;

        //大丈夫
        System.out.println(primitive1 == primitive2);
        System.out.println(primitive1 == wrapped1);
        System.out.println(wrapped1 == primitive1);

        //参照の比較になるのでダメ
        System.out.println(wrapped1 == wrapped2);
    }
}

ソート

プリミティブ型(int,long,doubleなど)の配列のソートをArrays.sort()ですると、最悪計算量が$ \Theta (n^2)$になります。
これはプリミティブ型の配列をソートする際のアルゴリズムがクイックソート(の改良版)であるためで、ほとんどの場合は問題になりません。
ただし、作問側が意地悪な場合や、参加者がテストケースを作れるTopcoderやCodeforcesでは狙われる可能性があります。
対策としては、ラッパ型(Integer,Long,Doubleなど)の配列でソートする方法があります。
Java6とJava7,8ではクイックソートのアルゴリズムが異なるので、攻撃方法も違うそうです。
参考:http://codeforces.com/blog/entry/4827 , http://codeforces.com/blog/entry/7108

WAが実はRE

キャッチされない例外でプログラムが終了しても、REと判定されないケースがあります。
まず、終了コードが0以外でも正常終了とみなすジャッジがあります。(例:SPOJ)
この場合はJavaではおそらくどうしようもないです。

終了コードが0以外の時はREとするジャッジでも、注意するべき状況があります。
次のコードはスタックの拡張のためにスレッドを立てて実行する例で、AtcoderやyukicoderではREと判定されません。

public class Main implements Runnable {
    public static void main(String[] args) {
        new Thread(null, new Main(), "", 8L * 1024 * 1024).start();
    }

    public void run() {
        throw new RuntimeException();
    }
}

ここに1行追加すると、REと判定されるようになります。

public class Main implements Runnable {

    public static void main(String[] args) {
        Thread.setDefaultUncaughtExceptionHandler((t,e)->System.exit(1));
        new Thread(null, new Main(), "", 8L * 1024 * 1024).start();
    }

    public void run() {
        throw new RuntimeException();
    }
}

例外が発生した時にSystem.exit(1)で終了コードを指定して終了するようにします。これでスレッドを立てても大丈夫です。