LoginSignup
9
9

More than 5 years have passed since last update.

[Java 8+] マップをマージする

Last updated at Posted at 2017-03-18

やること

複数のマップを1つのマップにマージすることを考えます。
最終的には以下のようなユーティリティメソッドに一般化します。

// Map<K, V>[] を Map<K, V> にマージ
public static <K, V> Map<K, V> merge(BiFunction<? super V, ? super V, ? extends V> mergeFunction, Map<K, V>... maps)

// Map<K, V>[] を Map<K, R> にマージ
public static <K, V, R> Map<K, R> merge(Collector<V, ?, R> mergeCollector, Map<K, V>... maps)

// Map<K, List<V>>[] を Map<K, List<V>> にマージ
public static <K, V> Map<K, List<V>> merge(Map<K, List<V>>... maps)

基本的な考え方

Map<K, V>[]Stream<Entry<K, V>> に展開してコレクトしますが、キーの一致するエントリーが複数存在する場合の処理はケースバイケースで考えなければなりません。
大きく分けて次の2パターンが考えられます。

  • Collectors#toMap() を使ったマージ : キーの一致した2値をマージ関数で処理して格納するのを繰り返す
  • Collectors#grounpingBy() を使ったマージ : キーの一致した値を下位コレクタで一括処理する

なお、マージ対象の Map が2つだけの場合はどちらでもやれることは変わりませんので、単純な Collectors#toMap() の使用をおすすめします。

Collectors#toMap() を使ったマージ

Map#merge() と同様に、既に格納されている値に新たな値をマージする、という操作の繰り返しでマージを行える場合は Collectors#toMap() を利用できます。

Map<String, Integer> m1 = new HashMap<>();
m1.put("a", 1);
m1.put("b", 2);
m1.put("c", 3);
Map<String, Integer> m2 = new HashMap<>();
m2.put("a", 100);
m2.put("b", 200);
m2.put("c", 300);

// キーが一致する場合値の和を取る
Map<String, Integer> m3 =
    Stream.of(m1, m2)
          .flatMap(m -> m.entrySet().stream())
          .collect(Collectors.toMap(Entry::getKey, Entry::getValue, Integer::sum));

Map#merge() の第3引数が BiFunction なのに対し Collectors#toMap() の第3引数は (サブインタフェースの) BinaryOperator でやや制限がありますが、特に問題はないでしょう。既に Map#merge() に渡している BiFunction が存在する場合は、BiFunction#apply() を参照すればそのまま適用できます。

// Map#merge() の第3引数の型
BiFunction<? super V, ? super V, ? extends V> remappingFunction = ...;

// Collectors#toMap() の第3引数の型に変換
BinaryOperator<V> mergeFunction = remappingFunction::apply;

Collectors#grounpingBy() を使ったマージ

同一のキーに対する値が3つ以上あり、それらの平均をマージ後の値としたいような場合は前述の方法では誤差が出ます。また、インタフェースが BinaryOperator ですので元の値と結果の値で型が異なるような集約を行うこともできません。
このような場合は Collectors#groupingBy() でキーごとの値をコレクションに蓄積し、フィニッシャで結果に変換するような方法を取る必要があります。

// キーが一致する場合値の平均を取る
Map<String, Double> m4 =
    Stream.of(m1, m2)
          .flatMap(m -> m.entrySet().stream())
          .collect(Collectors.groupingBy(Entry::getKey, Collectors.averagingInt(Entry::getValue)));

Map<K, List<V>> をマージする例

もう少し具体的な例として、Map<K, List<V>> 形式のデータを上記2つの方法でマージしてみます。

Map<String, List<String>> lm1 = new HashMap<>();
lm1.put("a", Arrays.asList("a1","a2","a3"));
lm1.put("b", Arrays.asList("b1","b2","b3"));
lm1.put("c", Arrays.asList("c1","c2","c3"));
Map<String, List<String>> lm2 = new HashMap<>();
lm2.put("a", Arrays.asList("a3","a4","a5"));
lm2.put("e", Arrays.asList("e1","e2","e3"));
lm2.put("f", Arrays.asList("f1","f2","f3"));

// リストをマージする関数
// メソッド化しておけばメソッド参照で済む
BinaryOperator<List<String>> mergeList = (left, right) -> Stream.of(left, right).flatMap(List::stream).collect(Collectors.toList());

// この mergeList を Collector でしか使わないのであれば下記実装の方が高速かつ省メモリでしょう
// BinaryOperator<List<String>> mergeList = (l, r) -> {l.addAll(r); return l;};

// toMap() でマージ
Map<String, List<String>> lm3 =
    Stream.of(lm1, lm2)
          .flatMap(m -> m.entrySet().stream())
          .collect(Collectors.toMap(Entry::getKey, Entry::getValue, mergeList));

// Stream<List<E>> を List<E> にマージするコレクタ
// Stream<E> を List<E> に集約する Collectors#toList() とは異なる
Collector<List<String>, ?, List<String>> collector = Collector.of(ArrayList::new, List::addAll, mergeList, Characteristics.IDENTITY_FINISH);

// groupingBy() でマージ
Map<String, List<String>> lm4 =
    Stream.of(lm1, lm2)
          .flatMap(m -> m.entrySet().stream())
          .collect(Collectors.groupingBy(Entry::getKey, Collectors.mapping(Entry::getValue, collector)));

一般化

最後に、上記を一般化してユーティリティクラスにまとめたものを例示しておきます。

import static java.util.stream.Collectors.*;

import java.util.*;
import java.util.Map.Entry;
import java.util.function.*;
import java.util.stream.*;

public class MapMerger {

    // Map<K, V> の配列を Entry<K, V> のストリームに展開
    @SafeVarargs
    private static <K, V> Stream<Entry<K, V>> flatten(Map<K, V>... maps){
        return Arrays.stream(maps).flatMap(map -> map.entrySet().stream());
    }

    // マップを異なる型のマップにマージ
    @SafeVarargs
    public static <K, V, R> Map<K, R> merge(Collector<V, ?, R> collector, Map<K, V>... maps){
        return flatten(maps).collect(groupingBy(Entry::getKey, mapping(Entry::getValue, collector)));
    }

    // マップを同じ型のマップにマージ
    @SafeVarargs
    public static <K, V> Map<K, V> merge(BiFunction<? super V, ? super V, ? extends V> mergeFunction, Map<K, V>... maps) {
        return flatten(maps).collect(toMap(Entry::getKey, Entry::getValue, mergeFunction::apply));
        // コレクタ版の実装を活用するなら
//        Function<List<V>, V> finisher = values -> values.stream().reduce(mergeFunction::apply).get();
//        Collector<V, ?, V> collector = Collector.of(ArrayList::new, List::add, MapMerger::mergeIntoList, finisher, Characteristics.IDENTITY_FINISH);
//        return merge(collector, maps);
    }

    // コレクションをリストにマージ
    @SafeVarargs
    private static <E> List<E> mergeIntoList(Collection<E>... collections){
        return Arrays.stream(collections).flatMap(Collection::stream).collect(toList());
    }

    // Map<K, List<V>> をマージ
    @SafeVarargs
    public static <K, V> Map<K, List<V>> merge(Map<K, List<V>>... maps){
        return merge(MapMerger::mergeIntoList, maps);
    }

    // 使用例
    public static void main(String[] args) {
        // サンプルデータ1
        Map<String, Integer> m1 = new HashMap<>();
        m1.put("a", 1);
        m1.put("b", 2);
        m1.put("c", 3);
        Map<String, Integer> m2 = new HashMap<>();
        m2.put("a", 100);
        m2.put("b", 200);
        m2.put("c", 300);

        // 和
        Map<String, Integer> m3 = merge(Integer::sum, m1, m2);
        Map<String, Integer> m4 = merge(summingInt(Integer::intValue), m1, m2);
        // 平均
        Map<String, Double> m5 = merge(averagingInt(Integer::intValue), m1, m2);

        // サンプルデータ2
        Map<String, List<String>> lm1 = new HashMap<>();
        lm1.put("a", Arrays.asList("a1","a2","a3"));
        lm1.put("b", Arrays.asList("b1","b2","b3"));
        lm1.put("c", Arrays.asList("c1","c2","c3"));
        Map<String, List<String>> lm2 = new HashMap<>();
        lm2.put("a", Arrays.asList("a3","a4","a5"));
        lm2.put("e", Arrays.asList("e1","e2","e3"));
        lm2.put("f", Arrays.asList("f1","f2","f3"));

        // 同型マージ
        Map<String, List<String>> lm3 = merge(lm1, lm2);
        Map<String, List<String>> lm4 = merge(MapMerger::mergeIntoList, lm1, lm2);

        // Map<String, String> へマージ
        Collector<List<String>, ?, String> collector = Collector.<List<String>, List<String>, String>of(
                ArrayList::new,
                List::addAll,
                MapMerger::mergeIntoList,
                List::toString);
        Map<String, String> lm5 = merge(collector, lm1, lm2);

        // 出力
        Stream.of(m3, m4, m5).forEach(System.out::println);
        Stream.of(lm3, lm4, lm5).forEach(System.out::println);
    }
}
9
9
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
9
9