Java
マルチスレッド

CountDownLatchを使って、複数スレッド間の同期を取る

やること

Java1.5から追加された平行処理支援ライブラリ「java.util.concurrentパッケージ」に含まれているCountDownLatchクラスを使って、複数スレッド間の同期を取ってみる。

CountDownLatchの仕組み

CountDownLatchを一言で言うと、他のスレッドの処理を待機する仕組みである。

CountDownLatchはある整数を”カウント”として保持している。
CountDownLatchのawaitメソッドを呼び出したスレッドは、このカウントが0にならない間、処理を停止する。
別のスレッドは、CountDonwLatchのcountDownメソッドを呼び出す。このメソッドは呼び出される度に、CountDownLatchのカウントを1減らす。
何度かcountDownメソッドが呼び出されて、CountDownLatchのカウントが0になると、awaitメソッドを呼び出して停止していたスレッドは処理を再開する。

簡単な例

問題

ここに同期を取りたいスレッドが2つあるとする。
それぞれ、Aスレッド、Bスレッドとしよう。
Aスレッドは実行中にタスクAを行い、Bスレッドは実行中にタスクBを行う。
さて、この2つのスレッドはそれぞれ別に動作するが、Aスレッドで行われるタスクAが完了した後、BスレッドのタスクBが開始されるようにしたい。

スレッド タスク 順序性
Aスレッド タスクA タスクBに先行してタスクAを行いたい
Bスレッド タスクB タスクAの終了後、タスクBを行いたい

解決策

この問題を解決するには、AスレッドがタスクAの処理が終わったことを何らかの方法でBスレッドに通知し、Bスレッドはその通知を受け取ってから、タスクBを実行するようにすれば良い。

スレッド 役割
Aスレッド BスレッドにタスクAの完了を通知する
Bスレッド タスクAの完了を待機し、AスレッドからタスクAの完了通知を受け取ったらタスクBを開始する

CountDownLatchを使った解決策の実装

1. CountDownLatchの生成

CountDownLatchクラスのインスタンスを生成する。
この時、カウントを1に設定する。理由は、処理待ちするタスクが1つだからである。
生成したインスタンスは、AスレッドとBスレッドで共有する必要があるので、それぞれのスレッドの生成時に引数で渡し、2つのスレッドからアクセスできるようにしておく。

2. Bスレッドが待機する。

BスレッドのタスクBの直前で、CountDownLatchインスタンスのawaitメソッドを呼び出し、Bスレッドの処理を停止する。

3. AスレッドがBスレッドにタスクAの完了を通知する。

AスレッドのタスクAが終わった直後に、CountDownLatchインスタンスのcountDownメソッドを呼び出す。
この呼び出しにより、CountDownLatchのカウントが1から0になる。
カウントが0になったので、awaitメソッドを呼び出して停止していたBスレッドの処理が再開する。
これで、AスレッドのタスクAの完了をBスレッドのタスクBに通知できた。

スレッドが多い場合

スレッドが2つではなく、もっと多い場合について見てみる。

この節では、説明上、用語を以下にしている。

用語 意味
実行スレッド タスクを行うスレッド
待機スレッド 実行スレッドのタスク完了を待つスレッド

◇ 実行スレッド1個:待機スレッドN個

実行スレッド1個、待機スレッドN個の場合を考えてみる。

  1. CountDownLatchを初期化する。このとき、カウント数を1に設定する。
  2. 待機スレッドをN個起動する。各スレッドは、あるところまで処理が進むと、awaitメソッドを呼び出し、処理を中断する。
  3. 実行スレッドがcountDonwメソッドを呼び出す。
  4. 待機していたN個のスレッドが同時に処理を再開する。

◇ 実行スレッドN個:待機スレッド1個

実行スレッドN個、待機スレッド1個の場合を考えてみる。

  1. CountDownLatchを初期化する。このとき、カウント数をNに設定する。
  2. 待機スレッドはCountDownLatchのawaitメソッドを呼び出す。これで待機スレッドは一時停止する。
  3. 全ての実行スレッドは、タスクが完了したタイミングでCountDownLatchのcountDownメソッドを呼び出す。
  4. 実行スレッドN個のタスクが全て終わると、CountDownLatchのカウントはゼロになり、待機スレッドの処理が再開する。

サンプルプログラム

実際にプログラムを書いてみる。

実装する仕様

複数スレッドに計算をさせ、計算結果を最後に合計する。
この設計では、各スレッドの計算が全て終わってから合計処理を行う必要があるので、合計処理はスレッドの計算完了を待機しなければならない。

ソースコード

メインスレッド

このスレッドからワーカースレッドを呼び出し、計算をさせる。
全てのワーカースレッドの計算が完了したら、ワーカースレッドから計算結果を取り出し、合計値を算出、出力する。

メインスレッド
package sample;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

public class Main {

    // 実行するタスクの数
    private static final int NUM_OF_TASK = 5;

    public static void main(String[] args) {

        ExecutorService service = Executors.newCachedThreadPool();
        CountDownLatch latch = new CountDownLatch(NUM_OF_TASK);
        List<SampleTask> tasks = new ArrayList<SampleTask>();

        // タスクを準備
        for (int i = 0; i < NUM_OF_TASK; i++) {

            // タスクに番号を振る。
            int no = i;

            // タスクの処理時間を1〜10秒でランダムに決める。
            int lifeTime = (int)(Math.random() * 9 + 1);

            tasks.add(new SampleTask(no, lifeTime,latch));
        }

        // タスクを起動する。
        for (SampleTask task : tasks) {
            service.submit(task);
        }

        System.out.println("タスクの起動完了");

        try {
            // ラッチのカウントが0になるのを待機する。
            System.out.println("ラッチを使って完了を待機");
            latch.await();
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
        System.out.println("全てのタスクの処理が完了したため、処理を再開");

        // タスクの計算結果を集計する。
        int sum = 0;
        for (SampleTask task : tasks) {
            System.out.println("No." + task.getNo() + " result:" + task.getResult());
            sum += task.getResult();
        }
        System.out.println("summary:" + sum);
        service.shutdown();
    }
}

ワーカースレッド

このスレッドに計算をさせる。
実際にする計算は、スレッドごとに振った番号に2を掛けて返すだけである。
しかし、スレッドごとに処理の完了時刻がまちまちになるようにしたいので、ランダムな時間スリープする処理を入れてある。

ワーカースレッド
package sample;

import java.text.SimpleDateFormat;
import java.util.Calendar;
import java.util.concurrent.CountDownLatch;

public class SampleTask implements Runnable {

    private static final SimpleDateFormat SDF = new SimpleDateFormat("HH:mm:ss");

    private int no;
    private int time;
    private CountDownLatch latch;
    private int result;

    public int getNo() {
        return no;
    }

    public int getResult() {
        return result;
    }

    public SampleTask(int no, int time, CountDownLatch latch) {
        this.no = no;
        this.time = time;
        this.latch = latch;
    }

    @Override
    public void run() {

        System.out.println(" No." + no + " START 処理時間:" + time + " 現在時刻:" + SDF.format(Calendar.getInstance().getTime()));

        try {
            // 重い処理に見せかけるため、スリープさせる。
            Thread.sleep(time * 1000);
        } catch (InterruptedException e) {
            e.printStackTrace();
        }

        // 実際にやる処理は、noに2を掛けるだけ
        result = no * 2;

        System.out.println(" No." + no +  " END   処理時間:" + time + " 現在時刻:" + SDF.format(Calendar.getInstance().getTime()));

        // 処理が終わったらラッチのカウントを1減らす
        latch.countDown();
    }
}

実行結果

コンソール出力
 No.2 START 処理時間:3 現在時刻:13:10:14
 No.1 START 処理時間:1 現在時刻:13:10:14
 No.0 START 処理時間:8 現在時刻:13:10:14
 No.3 START 処理時間:5 現在時刻:13:10:14
 No.4 START 処理時間:8 現在時刻:13:10:14
タスクの起動完了
ラッチを使って完了を待機
 No.1 END   処理時間:1 現在時刻:13:10:15
 No.2 END   処理時間:3 現在時刻:13:10:17
 No.3 END   処理時間:5 現在時刻:13:10:19
 No.0 END   処理時間:8 現在時刻:13:10:22
 No.4 END   処理時間:8 現在時刻:13:10:22
全てのタスクの処理が完了したため、処理を再開
No.0 result:0
No.1 result:2
No.2 result:4
No.3 result:6
No.4 result:8
summary:20

ワーカースレッドの起動後、メインスレッドがワーカースレッドの処理完了を待機しているのが分かる。
また、ワーカースレッドは最大8秒かけて処理を完了させているが、メインスレッドは、全てのワーカースレッドが処理を終わらせてから、メインスレッド内の合計処理を実行しており、正しく同期が取れていることが確認できる。

サンプルコード格納場所

https://github.com/nogitsune413/CountDownLatchSample

参考

Java8 標準API CountDownLatch

確認環境

Java 1.8