LoginSignup
1
0

More than 1 year has passed since last update.

Sparkのレコード内でarrayの直積をとって集約したいときの罠と対策

Last updated at Posted at 2021-03-23

[追記] Spark 3.1.2 で修正されたみたいです。
本記事の対策2と同様に、AtomicIntegerでグローバルカウンター作ってますね。
https://github.com/apache/spark/pull/31887


Spark 3でレコード内のarrayを操作する transform filter aggregate などが追加されました。たとえば、

case class AB(as: Seq[Int], bs: Seq[Int])
val ds = Seq(AB(Seq(1, 2, 3), Seq(-1, -2)), AB(Seq(4, 5), Seq(-3, -4, -5))).toDS()
as bs
[1, 2, 3] [-1, -2]
[4, 5] [-3, -4, -5]

というデータについて、

ds.select(
  transform(filter($"as", _ % 2 === 1), _ + 1)
).show(false)
transform(filter(as, lambdafunction(((x % 2) = 1), x)), lambdafunction((x + 1), x))
[2, 4]
[6]

ということができます。 explode したりUDFに頼ることなくarrayを操作できるので便利ですね!

ここでカラムaとカラムbの直積(CROSS JOIN)を取ってみます。とりあえず素朴に書いてみると、

ds.select(
  transform($"as", a => transform($"bs", b => array(a, b)))
).show(false)
transform(as, lambdafunction(transform(bs, lambdafunction(array(x, x), x)), x))
[[[-1, -1], [-2, -2]], [[-1, -1], [-2, -2]], [[-1, -1], [-2, -2]]]
[[[3, -3], [-4, -4], [-5, -5]], [[-3, -3], [-4, -4], [-5, -5]]]

……なんかバグってますね。

原因

なぜこんなことが起こるのかtransformの実装を見てみましょう。

def transform(column: Column, f: Column => Column): Column = withExpr {
  ArrayTransform(column.expr, createLambda(f))
}
private def createLambda(f: Column => Column) = {
  val x = UnresolvedNamedLambdaVariable(Seq("x"))
  val function = f(Column(x)).expr
  LambdaFunction(function, Seq(x))
}

transform が依存している createLambdaUnresolvedNamedLambdaVariable(Seq("x")) とあります。つまりラムダ関数の変数名が x に固定されているため、これをネストするとシャドーイングが起こってしまうということのようです。
Scalaのコレクション処理で書くとこんな状況です。一つ目のラムダ関数の変数 x はどこからも参照されていません。

as.map(x => bs.map(x => (x, x)))

対策

ラムダ関数の変数名を外から与えられるように createLambdatransform の亜種を実装します。

def createLambdaBind(variableName: String, f: Column => Column): LambdaFunction = {
  val x = UnresolvedNamedLambdaVariable(Seq(variableName))
  val function = f(new Column(x)).expr
  LambdaFunction(function, Seq(x))
}

def transform_bind(column: Column, variableName: String, f: Column => Column): Column =
  new Column(ArrayTransform(column.expr, createLambdaBind(variableName, f)))

これを使えば、

ds.select(
  transform_bind($"as","a", a => transform_bind($"bs", "b", b => array(a, b)))
).show(false)
transform(as, lambdafunction(transform(bs, lambdafunction(array(a, b), b)), a))
[[[1, -1], [1, -2]], [[2, -1], [2, -2]], [[3, -1], [3, -2]]]
[[[4, -3], [4, -4], [4, -5]], [[5, -3], [5, -4], [5, -5]]]

と問題なくCROSS JOINできます。ということはカラムaとカラムbの二重ループができるということなので、たとえば、

ds.select(
  array_min(transform_bind($"as", "a", a => array_min(transform_bind($"bs", "b", b => a * b))))
).show(false)
array_min(transform(as, lambdafunction(array_min(transform(bs, lambdafunction((a * b), b))), a)))
-6
-25

のようにして集約することもできます。Scalaのコレクション処理で書くと、

as.map(a => bs.map(b => a * b).min).min

と同等の計算ができています。これで色々できそうですね。

対策2

変数名をいちいち与えるのが面倒だったり、複雑なarray操作をモジュール化していてバッティングしないようにするのも難しい場合は、自動で変数名を払いだしてほしくなります。
簡易的ですが、グローバルカウンターを用意して x0 x1 x2 と振っていくとすると、こんな実装になります。

val lambdaVariableCounter = new AtomicLong()

def createLambdaSafe(f: Column => Column): LambdaFunction = {
  val x = UnresolvedNamedLambdaVariable(Seq(f"x${lambdaVariableCounter.getAndIncrement()}"))
  val function = f(new Column(x)).expr
  LambdaFunction(function, Seq(x))
}

def transform_safe(column: Column, f: Column => Column): Column =
  new Column(ArrayTransform(column.expr, createLambdaSafe(f)))

別解

ちなみにSparkのネイティブ関数のみで対応したい場合、ちょっとややこしいですが、

ds.select(
  transform($"as", a => zip_with(array_repeat(a, size($"bs")), $"bs", array(_, _)))
).show(false)
transform(as, lambdafunction(zip_with(array_repeat(x, size(bs)), bs, lambdafunction(array(x, y), x, y)), x))
[[[1, -1], [1, -2]], [[2, -1], [2, -2]], [[3, -1], [3, -2]]]
[[[4, -3], [4, -4], [4, -5]], [[5, -3], [5, -4], [5, -5]]]

という方法もあります。

実行環境

build.sbt
scalaVersion := "2.12.13"
javacOptions ++= Seq("-source", "14", "-target", "14")

val sparkVersion = "3.1.1"

libraryDependencies ++= Seq(
  "org.apache.spark" %% "spark-core" % sparkVersion,
  "org.apache.spark" %% "spark-sql" % sparkVersion
)
project/build.properties
sbt.version=1.4.8
1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0