3
1

More than 3 years have passed since last update.

拡張しやすいStream APIの代替を作ってみる

Posted at

はじめに

Stream APIとはJava SE 8から追加されたデータの列を処理するAPIです。今までコレクションに対して行っていた煩雑な処理をわかりやすいコードで記述することが可能になります。しかし実際に使ってみるといくつか気になる点があります。

  • 終端処理で配列化するときは.toArray()でいいのに、リスト化するときは.collect(Collectors.toList())とする必要があります。
  • .collect()がわかりにくいです。(特にCollectors.groupingBy()など)終端処理を自分で追加できるように、このような実装になったと思いますが、わかりやすさを犠牲にしていると思います。
  • IntStreamLongStreamはありますが、ByteStreamCharStreamはありません。Collection APIにはプリミティブ型を直接格納するコレクションはないのに、なぜStream APIにはそれらを追加したのでしょうか。しかも中途半端に。
  • 中間処理はすべてインスタンスメソッドなので新たな中間処理を追加できません。

これらを改善する新たなAPIを作ってみたいと思います。方針は以下のとおりです。

  • APIの拡張を容易にするためにすべてのAPIをインスタンスメソッドではなく、スタティックメソッドで実装します。
  • プリミティブ型を直接操作するAPIは実装しません。
  • 簡略化のために並列化はあきらめます。
  • FunctionインタフェースやComparatorインタフェースなどは既存のものをできる限り流用します。

この記事で記述したすべてのコードは拡張しやすいStream APIの代替を作ってみるにあります。

Streamに代わるインタフェース

Stream APIの代わりになるものを作るとなると、Streamインタフェースに相当するものを定義する必要があります。新たなインタフェースを定義することもできますがここではIterableインタフェースを使います。CollectionインタフェースはIterableインタフェースを実装しているので、list.stream()のような操作なしで直接的にIterableを取り出すことができます。

中間処理

map

最初にmapを実装してみます。Iterable<T>を受け取って結果としてIterable<U>を返します。List<U>Iterable<U>を実装しているので、List<U>を返してもよいと考えるとこんな実装ができます。

public static <T, U> Iterable<U> map(Function<T, U> mapper, Iterable<T> source) {
    List<U> result = new ArrayList<>();
    for (T element : source)
        result.add(mapper.apply(element));
    return result;
}

しかし、これではsourceに100万の要素があると100万件のArrayListが作成されてしまいます。中間処理では要素が必要になったときにmapperを適用した結果を逐次返すようにする必要があります。そのためには逐次的にmappaerを適用するIterator<U>を実装しなければなりません。Iterator<U>を実装すれば、それを返すIterable<U>を実装するのは簡単です。Iterable<U>Iterator<U> iterator()のみが定義された関数型インタフェースだからです。

public static <T, U> Iterable<U> map(Function<T, U> mapper, Iterable<T> source) {
    return () -> new Iterator<U>() {

        final Iterator<T> iterator = source.iterator();

        @Override
        public boolean hasNext() {
            return iterator.hasNext();
        }

        @Override
        public U next() {
            return mapper.apply(iterator.next());
        }

    };
}

テストする前に終端処理toListを定義しておきます。Iterable<T>を受け取ってList<T>を返します。拡張for文が使えるので簡単です。

public static <T> List<T> toList(Iterable<T> source) {
    List<T> result = new ArrayList<>();
    for (T element : source)
        result.add(element);
    return result;
}

テストはこんな感じになります。比較のためStream APIを使ったコードも記述してみます。

@Test
void testMap() {
    List<String> actual = toList(
        map(String::toUpperCase,
            List.of("a", "b", "c")));
    List<String> expected = List.of("A", "B", "C");
    assertEquals(actual, expected);

    List<String> stream = List.of("a", "b", "c").stream()
        .map(String::toUpperCase)
        .collect(Collectors.toList());
    assertEquals(stream, expected);
}

Stream APIではインスタンスメソッドの連鎖で処理を行うため、上から下に向かって順次処理する感じですが、スタティックメソッドを使っているので、記述の順序が逆になります。少し違和感があるかもしれません。

filter

次にfilterを実装してみます。mapの時と同じように逐次的に処理する必要があるので、Iterator<T>を実装する無名のインナークラスを定義します。mapは入力と出力が1対1に対応するので単純ですが、filter`の場合はそうではないので少しめんどうです。

public static <T> Iterable<T> filter(Predicate<T> selector, Iterable<T> source) {
    return () -> new Iterator<T>() {

        final Iterator<T> iterator = source.iterator();
        boolean hasNext = advance();
        T next;

        boolean advance() {
            while (iterator.hasNext())
                if (selector.test(next = iterator.next()))
                    return true;
            return false;
        }

        @Override
        public boolean hasNext() {
            return hasNext;
        }

        @Override
        public T next() {
            T result = next;
            hasNext = advance();
            return result;
        }
    };
}

filterが呼び出された時点でselectorの条件を満たすものがあるかどうかを先読みして調べておく必要があります。満たすものがあれば、それをインスタンス変数nextに保存しておいて、next()が呼ばれた時点でそれを返します。同時に次のselectorを満たす要素を探しておきます。
テストしてみます。整数の列から偶数だけを取り出して10倍したものの数列を求めます。

@Test
void testFilter() {
    List<Integer> actual = toList(
        map(i -> i * 10,
            filter(i -> i % 2 == 0,
                List.of(0, 1, 2, 3, 4, 5))));
    List<Integer> expected = List.of(0, 20, 40);
    assertEquals(expected, actual);

    List<Integer> stream = List.of(0, 1, 2, 3, 4, 5).stream()
        .filter(i -> i % 2 == 0)
        .map(i -> i * 10)
        .collect(Collectors.toList());
    assertEquals(stream, actual);
}

Streamの場合は終端処理を行ってしまうと、そのStreamは再利用できません。Iterableの場合は途中結果を保存しておいて、あとからそれを再利用することもできます。

@Test
void testSaveFilter() {
    Iterable<Integer> saved;
    List<Integer> actual = toList(
        map(i -> i * 10,
            saved = filter(i -> i % 2 == 0,
                List.of(0, 1, 2, 3, 4, 5))));
    List<Integer> expected = List.of(0, 20, 40);
    assertEquals(expected, actual);

    assertEquals(List.of(0, 2, 4), toList(saved));
}

終端処理

終端処理は前述のtoList()のように、拡張for文を使って結果を保存するだけなので簡単です。

toMap

toMap()は要素からキーと値を取り出すFunctionを使ってMapを作るだけです。

public static <T, K, V> Map<K, V> toMap(Function<T, K> keyExtractor,
    Function<T, V> valueExtractor, Iterable<T> source) {
    Map<K, V> result = new LinkedHashMap<>();
    for (T element : source)
        result.put(keyExtractor.apply(element), valueExtractor.apply(element));
    return result;
}

groupingBy

次に単純なgroupingByを実装してみます。要素からキーを取り出してMapにするだけです。キーが重複する要素はリストに詰め込みます。

public static <T, K> Map<K, List<T>> groupingBy(Function<T, K> keyExtractor,
    Iterable<T> source) {
    Map<K, List<T>> result = new LinkedHashMap<>();
    for (T e : source)
        result.computeIfAbsent(keyExtractor.apply(e), k -> new ArrayList<>()).add(e);
    return result;
}

テストしてみます。文字列の長さでグループ化する例です。

@Test
public void testGroupingBy() {
    Map<Integer, List<String>> actual = groupingBy(String::length,
        List.of("one", "two", "three", "four", "five"));
    Map<Integer, List<String>> expected = Map.of(
        3, List.of("one", "two"),
        5, List.of("three"),
        4, List.of("four", "five"));
    assertEquals(expected, actual);

    Map<Integer, List<String>> stream =
        List.of("one", "two", "three", "four", "five").stream()
        .collect(Collectors.groupingBy(String::length));
    assertEquals(stream, actual);
}

次はキーでグループ化したあと重複した要素を集約するgroupingByです。

static <T, K, V> Map<K, V> groupingBy(Function<T, K> keyExtractor,
    Function<Iterable<T>, V> valueAggregator, Iterable<T> source) {
    return toMap(Entry::getKey, e -> valueAggregator.apply(e.getValue()),
        groupingBy(keyExtractor, source).entrySet());
}

同一キーを持つ要素をvalueAggregatorで集約します。
テストのために終端処理をもう一つ定義しておきます。

public static <T> long count(Iterable<T> source) {
    long count = 0;
    for (@SuppressWarnings("unused")
    T e : source)
        ++count;
    return count;
}

以下は文字列の長さでグループ化して、同一文字列長の文字列の数を数えます。

@Test
public void testGroupingByCount() {
    Map<Integer, Long> actual = groupingBy(String::length, s -> count(s),
        List.of("one", "two", "three", "four", "five"));
    Map<Integer, Long> expected = Map.of(3, 2L, 5, 1L, 4, 2L);
    assertEquals(expected, actual);

    Map<Integer, Long> stream = List.of("one", "two", "three", "four", "five").stream()
        .collect(Collectors.groupingBy(String::length, Collectors.counting()));
    assertEquals(stream, actual);
}

Stream APIでは実装が難しいもの

最後にStream APIでは実装が難しいものをを実装してみましょう。いずれも中間処理です。Stream APIにおける中間操作はすべてインスタンスメソッドなので、容易に拡張できません。

zip

zipは二つのデータ列の先頭から1件ごとにマッチさせて、一つのデータ列にする処理です。二つの入力列の長さが異なる場合は、短い方に合わせて長い方の残りは無視します。

static <T, U, V> Iterable<V> zip(BiFunction<T, U, V> zipper, Iterable<T> source1,
    Iterable<U> source2) {
    return () -> new Iterator<V>() {

        final Iterator<T> iterator1 = source1.iterator();
        final Iterator<U> iterator2 = source2.iterator();

        @Override
        public boolean hasNext() {
            return iterator1.hasNext() && iterator2.hasNext();
        }

        @Override
        public V next() {
            return zipper.apply(iterator1.next(), iterator2.next());
        }

    };
}

整数と文字列の並びを文字列の並びにするテストです。

@Test
void testZip() {
    List<String> actual = toList(
        zip((x, y) -> x + "-" + y,
            List.of(0, 1, 2),
            List.of("zero", "one", "two")));
    List<String> expected = List.of("0-zero", "1-one", "2-two");
    assertEquals(expected, actual);
}

cumulative

要素を累積する中間処理です。reduce()に似ていますが、reduce()が終端処理で一つの値を返すのに対し、cumulativeは列を返します。

public static <T, U> Iterable<U> cumulative(U unit, BiFunction<U, T, U> function,
    Iterable<T> source) {
    return () -> new Iterator<U>() {

        Iterator<T> iterator = source.iterator();
        U accumlator = unit;

        @Override
        public boolean hasNext() {
            return iterator.hasNext();
        }

        @Override
        public U next() {
            return accumlator = function.apply(accumlator, iterator.next());
        }

    };
}

先頭からの部分和を求めるテストです。

@Test
public void testCumalative() {
    List<Integer> actual = toList(
        cumulative(0, (x, y) -> x + y,
            List.of(0, 1, 2, 3, 4, 5)));
    List<Integer> expected = List.of(0, 1, 3, 6, 10, 15);
    assertEquals(expected, actual);
}

flatMap

flatMap()はStream APIにもありますが、実装方法がわかりにくいので、ここに載せておきます。

public static <T, U> Iterable<U> flatMap(Function<T, Iterable<U>> flatter, Iterable<T> source) {
    return () -> new Iterator<U>() {

        final Iterator<T> parent = source.iterator();
        Iterator<U> child = null;
        boolean hasNext = advance();
        U next;

        boolean advance() {
            while (true) {
                if (child == null) {
                    if (!parent.hasNext())
                        return false;
                    child = flatter.apply(parent.next()).iterator();
                }
                if (child.hasNext()) {
                    next = child.next();
                    return true;
                }
                child = null;
            }
        }

        @Override
        public boolean hasNext() {
            return hasNext;
        }

        @Override
        public U next() {
            U result = next;
            hasNext = advance();
            return result;
        }

    };
}

以下はテストです。各要素を2個ずつの並びにに膨らませます。

@Test
public void testFlatMap() {
    List<Integer> actual = toList(
        flatMap(i -> List.of(i, i),
            List.of(0, 1, 2, 3)));
    List<Integer> expected = List.of(0, 0, 1, 1, 2, 2, 3, 3);
    assertEquals(expected, actual);

    List<Integer> stream = List.of(0, 1, 2, 3).stream()
        .flatMap(i -> Stream.of(i, i))
        .collect(Collectors.toList());
    assertEquals(stream, actual);
}

最後に

Stream APIに対応するすべてのAPIを実装したわけではありませんが、意外と簡単に相当品が作れることがわかりました。

3
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
1