LoginSignup
8

More than 5 years have passed since last update.

Java8のStreamで順列を生成

Last updated at Posted at 2015-05-16

与えられたリストからn個の要素を取り出す順列をストリームとして生成します。

    static <T> List<T> cons(T head, List<T> tail) {
        List<T> list = new ArrayList<>();
        list.add(head);
        list.addAll(tail);
        return list;
    }

    static <T> List<T> remove(T e, List<T> list) {
        List<T> newList = new ArrayList<>(list);
        newList.remove(e);
        return newList;
    }

    public static <E> Stream<List<E>> permutations(int n, List<E> list) {
        if (n <= 0) return Stream.of(new ArrayList<>());
        return list.stream()
            .flatMap(h -> permutations(n - 1, remove(h, list))
                .map(t -> cons(h, t)));
    }

リストの中から要素をひとつ取り出します(h)。残りの要素(remove(h, list))の順列をすべて求めて、得られたそれぞれのリストの先頭に取り出した要素(h)を追加(cons(h, t))します。これで取り出した要素を先頭とするすべての順列が求まります。すべての要素についてこれを行うとすべての順列が得られます。

cons()はLisp系の言語や関数型言語をやっている人から見ると笑いものですね。

リスト[1, 2, 3]の中から2つ取る順列を生成してみます。

    @Test
    public void testPermutationsN() {
        permutations(2, Arrays.asList(1, 2, 3))
            .forEach(System.out::println);
    }

結果は以下のようになります。

[1, 2]
[1, 3]
[2, 1]
[2, 3]
[3, 1]
[3, 2]

これを使って覆面算「SEND + MORE = MONEY」を解いてみます。
問題に表れる英字はS, E, N, D, M, O, R, Yの8種類なので、リスト[0, 1, ... , 9]から8個を取り出す順列を得て、フィルターで式を満たすかどうかをチェックします。SとMは先頭に現れるので0にはなりません。

    int number(int... args) {
        return IntStream.of(args).reduce(0, (a, b) -> 10 * a + b);
    }

    boolean check(int s, int e, int n, int d, int m, int o, int r, int y) {
        if (s == 0 || m == 0) return false;
        int send = number(s, e, n, d);
        int more = number(m, o, r, e);
        int money = number(m, o, n, e, y);
        if (send + more != money) return false;
        System.out.printf("%s + %s = %s%n", send, more, money);
        return true;
    }

    @Test
    public void testSendMoreMoney() {
        permutations(8, Arrays.asList(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
            .filter(l -> check(
                l.get(0), l.get(1), l.get(2), l.get(3),
                l.get(4), l.get(5), l.get(6), l.get(7)))
            .forEach(System.out::println);
    }

結果は以下のようになります。

9567 + 1085 = 10652
[9, 5, 6, 7, 1, 0, 8, 2]

私の環境ではpermutations(8, Arrays.asList(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))の後に.parallel()をつけると少しだけ速くなりました。

リストのストリームではなくリストのリストを返す場合は以下のようになります。

    public static <E> List<List<E>> permutationsList(int n, List<E> list) {
        if (n <= 0) return Arrays.asList(Arrays.asList());
        List<List<E>> result = new ArrayList<>();
        for (E head : list)
            for (List<E> tail : permutationsList(n - 1, remove(head, list)))
                result.add(cons(head, tail));
        return result;
    }

こっちの方が分かりやすいように思えるのは古い人間だからなのでしょうか?

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
8