Java
java8
Stream
StreamAPI

Java8 StreamAPI reduce() で入力要素とは異なる型にして返したい

More than 1 year has passed since last update.

Javaだとうまくいかない?

例えば、文字列の長さの合計を計算したい場合、Pythonだとこうします。

>>> fruits = [ 'apple', 'orange', 'kiwi' ]
>>> reduce(lambda sum, elm: sum + len(elm), fruits, 0)
15

map で要素のサイズに変換して reduce だろという意見はおいておいて下さい (笑

Javaだとこう...かと思いきや、エラーになります。 (´・ω・`)

List<String> fruits = Arrays.asList("apple", "orange", "kiwi");
fruits.stream()
        .reduce(0, (sum, elm) -> sum + elm.length()); // Integer: 15 ?
-------------------------------------------------------------
COMPILATION ERROR : 
-------------------------------------------------------------
Javatest.java:[33,8] error: no suitable method found for reduce(int,(sum,elm)-[...]gth())
1 error
-------------------------------------------------------------

引数2個の reduce() は、要素の型と同じ型の戻り値をかえす

普段使っている? list.stream().reduce(...) は引数2個 (初期値 と lambda) のやつです。

java/util/stream/Stream.java
T reduce(T identity,
         BinaryOperator<T> accumulator);

要素の型と戻り値の型が同じ T なので、型を合わせないとコンパイルが通りません。

List<String> fruits = Arrays.asList("apple", "orange", "kiwi");
String result = fruits.stream()
        .reduce("I like ", (sum, elm) -> sum + ", " + elm)); // String: "I like apple, orange, kiwi, "

引数3個の reduce() は、要素の型と異なる型の戻り値をかえせる

実は引数3個の list.stream().reduce(...) があります。これを使うと要素の型 T と異なる型 U の値を返す事ができます。

java/util/stream/Stream.java
<U> U reduce(U identity,
             BiFunction<U, ? super T, U> accumulator,
             BinaryOperator<U> combiner);

例えば以下の様に書けます。sum=int, elm=String, sum1=int, sum2=int です。

List<String> fruits = Arrays.asList("apple", "orange", "kiwi");
int result = fruits.stream().reduce(
        0,                                 // 初期値
        (sum, elm) -> sum + elm.length(),  // accumulator. 中間生成物を作る.
        (sum1, sum2) -> sum1 + sum2);      // combiner. 中間生成物どうしをマージする.

accumulator (第2引数) と combiner (第3引数) を指定しています。何故2つもLambdaが必要かと言うと、並列実行 .parallel() される場合に必要になります。

例えば、accumulator (sum, elm) -> sum + elm.length() が並列で実行されると中間生成物 (int型) が複数できる事になります。
このint型の中間生成物同士をマージする為に combiner (sum1, sum2) -> sum1 + sum2 が必要になります。

例えば、2-threadで並列実行された場合、以下の様になるだろうと考えられます。

実際の処理の流れとは少々異なりますが、簡単化して書いています

まず、全ての要素 "apple", "orange", "kiwi" を accumulator (第2引数) で処理し、中間生成物を作ります。

  1. 初期値 + 要素1 ==> 中間生成物1
    • (0, "apple") -> 5
  2. 初期値 + 要素2 ==> 中間生成物2
    • (0, "orange") -> 6
  3. 中間生成物1 + 要素3 ==> 中間生成物3
    • (5, "kiwi") -> 9

生成された中間生成物を全て combiner (第3引数) でマージし、一つの値に集約します。

  1. 中間生成物2 + 中間生成物3 ==> 中間生成物4
    • (6, 9) -> 15

最後に残った 中間生成物4=15 が最終的な戻り値となります。

参考