Edited at

[Sparkメモ] 重複排除の書き方いろいろ


はじめに

お仕事でSparkを2年ほど使っていまして、だいぶノウハウが溜まってきたのとどこかに記録として残しておきたいので少しづつですがポストしていこうと思います。

記載しているコードはインタラクティブシェル(bin/spark-shell)で確認しています。

Sparkのバージョンが古いですが最新版でも軽微な修正で対応できると思います。


環境、前提


  • macOS Sierra(10.12.6)

  • Apache Spark 1.6+

  • Scala 2.10+


重複排除


テストデータ

val words = Array("one", "two", "two", "three", "three", "three")

val wordsDF = sc.parallelize(words).toDF("word")


実行&結果

scala> val words = Array("one", "two", "two", "three", "three", "three")

words: Array[String] = Array(one, two, two, three, three, three)

scala> val wordsDF = sc.parallelize(words).toDF("word")
wordsDF: org.apache.spark.sql.DataFrame = [word: string]

scala> wordsDF.printSchema
root
|-- word: string (nullable = true)

scala> wordsDF.show
+-----+
| word|
+-----+
| one|
| two|
| two|
|three|
|three|
|three|
+-----+


distinct

一番簡単な方法ですね。

wordsDF.distinct


実行&結果

scala> val df = wordsDF.distinct

df: org.apache.spark.sql.DataFrame = [word: string]

scala> df.printSchema
root
|-- word: string (nullable = true)

scala> df.show
+-----+
| word|
+-----+
|three|
| two|
| one|
+-----+


distinct(SparkSQL ver)

先と同じdistinctですが、こちらはSparkSQLのselect distinctになります。

wordsDF.registerTempTable("words")

sqlContext.sql("select distinct word from words")


実行&結果

scala> wordsDF.registerTempTable("words")

scala> val df = sqlContext.sql("select distinct word from words")
df: org.apache.spark.sql.DataFrame = [word: string]

scala> df.printSchema
root
|-- word: string (nullable = true)

scala> df.show
+-----+
| word|
+-----+
|three|
| two|
| one|
+-----+


groupBy().agg()

groupBY().agg()を使った方法です。

最後のselect($"word")groupBy().agg()の結果を確認していただくと理由がわかると思います。

wordsDF.groupBy($"word").agg($"word").select($"word")


実行&結果

scala> val df = wordsDF.groupBy($"word").agg($"word").select($"word")

df: org.apache.spark.sql.DataFrame = [word: string]

scala> df.printSchema
root
|-- word: string (nullable = true)

scala> df.show
+-----+
| word|
+-----+
|three|
| two|
| one|
+-----+


reduceByKey、aggregateByKeyで重複排除


テストデータ

val users = Array(("userA", "male"), ("userA", "T"), ("userA", "male"), ("userB", "female"), ("userB", "F2"))

val usersDF = sc.parallelize(users).toDF("user_id", "persona")


実行&結果

scala> val users = Array(

| ("userA", "male"), ("userA", "T"), ("userA", "male"),
| ("userB", "female"), ("userB", "F2"))
users: Array[(String, String)] = Array((userA,male), (userA,T), (userA,male), (userB,female), (userB,F2))

scala> val usersDF = sc.parallelize(users).toDF("user_id", "persona")
usersDF: org.apache.spark.sql.DataFrame = [user_id: string, persona: string]

scala> usersDF.printSchema
root
|-- user_id: string (nullable = true)
|-- persona: string (nullable = true)

scala> usersDF.show
+-------+-------+
|user_id|persona|
+-------+-------+
| userA| male|
| userA| T|
| userA| male|
| userB| female|
| userB| F2|
+-------+-------+


reductByKey()

ポイントはSet()を使って重複を排除しているところですね。

難点は毎回Setのインスタンスを作ってしまうというね。

usersDF.map(row => (row.getString(0), Set(row.getString(1)))).reduceByKey(_ ++ _)


実行&結果

scala> val rdd = usersDF.map(row => (row.getString(0), Set(row.getString(1)))).reduceByKey(_ ++ _)

rdd: org.apache.spark.rdd.RDD[(String, scala.collection.immutable.Set[String])] = ShuffledRDD[75] at reduceByKey at <console>:31

scala> val df = rdd.map { case (user_id, persona) => (user_id, persona.toSeq) }.toDF("user_id", "persona")
df: org.apache.spark.sql.DataFrame = [user_id: string, persona: array<string>]

scala> df.printSchema
root
|-- user_id: string (nullable = true)
|-- persona: array (nullable = true)
| |-- element: string (containsNull = true)

scala> df.show
+-------+------------+
|user_id| persona|
+-------+------------+
| userA| [male, T]|
| userB|[female, F2]|
+-------+------------+


aggregateByKey()

こちらもSet()を使っていますが、最初に初期化しているところですね。

import scala.collection.mutable

usersDF
  .map(row => (row.getString(0), Set(row.getString(1))))
  .aggregateByKey(mutable.Set.empty[String])((x, y) => x ++= y, (x, y) => x ++= y)

※コードが長いため途中で折り返しています。


実行&結果

scala> val rdd = usersDF.map(row => (row.getString(0), Set(row.getString(1)))).aggregateByKey(mutable.Set.empty[String])((x, y) => x ++= y, (x, y) => x ++= y)

rdd: org.apache.spark.rdd.RDD[(String, scala.collection.mutable.Set[String])] = ShuffledRDD[87] at aggregateByKey at <console>:32

scala> val df = rdd.map { case (user_id, persona) => (user_id, persona.toSeq) }.toDF("user_id", "persona")
df: org.apache.spark.sql.DataFrame = [user_id: string, persona: array<string>]

scala> df.printSchema
root
|-- user_id: string (nullable = true)
|-- persona: array (nullable = true)
| |-- element: string (containsNull = true)

scala> df.show
+-------+------------+
|user_id| persona|
+-------+------------+
| userA| [T, male]|
| userB|[F2, female]|
+-------+------------+

長々とお付き合いいただきありがとうございました。