[追記] Spark 3.1.2 で修正されたみたいです。
本記事の対策2と同様に、AtomicIntegerでグローバルカウンター作ってますね。
https://github.com/apache/spark/pull/31887
罠
Spark 3でレコード内のarrayを操作する transform
filter
aggregate
などが追加されました。たとえば、
case class AB(as: Seq[Int], bs: Seq[Int])
val ds = Seq(AB(Seq(1, 2, 3), Seq(-1, -2)), AB(Seq(4, 5), Seq(-3, -4, -5))).toDS()
as | bs |
---|---|
[1, 2, 3] | [-1, -2] |
[4, 5] | [-3, -4, -5] |
というデータについて、
ds.select(
transform(filter($"as", _ % 2 === 1), _ + 1)
).show(false)
transform(filter(as, lambdafunction(((x % 2) = 1), x)), lambdafunction((x + 1), x)) |
---|
[2, 4] |
[6] |
ということができます。 explode
したりUDFに頼ることなくarrayを操作できるので便利ですね!
ここでカラムaとカラムbの直積(CROSS JOIN)を取ってみます。とりあえず素朴に書いてみると、
ds.select(
transform($"as", a => transform($"bs", b => array(a, b)))
).show(false)
transform(as, lambdafunction(transform(bs, lambdafunction(array(x, x), x)), x)) |
---|
[[[-1, -1], [-2, -2]], [[-1, -1], [-2, -2]], [[-1, -1], [-2, -2]]] |
[[[3, -3], [-4, -4], [-5, -5]], [[-3, -3], [-4, -4], [-5, -5]]] |
……なんかバグってますね。
原因
なぜこんなことが起こるのかtransform
の実装を見てみましょう。
def transform(column: Column, f: Column => Column): Column = withExpr {
ArrayTransform(column.expr, createLambda(f))
}
private def createLambda(f: Column => Column) = {
val x = UnresolvedNamedLambdaVariable(Seq("x"))
val function = f(Column(x)).expr
LambdaFunction(function, Seq(x))
}
transform
が依存している createLambda
に UnresolvedNamedLambdaVariable(Seq("x"))
とあります。つまりラムダ関数の変数名が x
に固定されているため、これをネストするとシャドーイングが起こってしまうということのようです。
Scalaのコレクション処理で書くとこんな状況です。一つ目のラムダ関数の変数 x
はどこからも参照されていません。
as.map(x => bs.map(x => (x, x)))
対策
ラムダ関数の変数名を外から与えられるように createLambda
と transform
の亜種を実装します。
def createLambdaBind(variableName: String, f: Column => Column): LambdaFunction = {
val x = UnresolvedNamedLambdaVariable(Seq(variableName))
val function = f(new Column(x)).expr
LambdaFunction(function, Seq(x))
}
def transform_bind(column: Column, variableName: String, f: Column => Column): Column =
new Column(ArrayTransform(column.expr, createLambdaBind(variableName, f)))
これを使えば、
ds.select(
transform_bind($"as","a", a => transform_bind($"bs", "b", b => array(a, b)))
).show(false)
transform(as, lambdafunction(transform(bs, lambdafunction(array(a, b), b)), a)) |
---|
[[[1, -1], [1, -2]], [[2, -1], [2, -2]], [[3, -1], [3, -2]]] |
[[[4, -3], [4, -4], [4, -5]], [[5, -3], [5, -4], [5, -5]]] |
と問題なくCROSS JOINできます。ということはカラムaとカラムbの二重ループができるということなので、たとえば、
ds.select(
array_min(transform_bind($"as", "a", a => array_min(transform_bind($"bs", "b", b => a * b))))
).show(false)
array_min(transform(as, lambdafunction(array_min(transform(bs, lambdafunction((a * b), b))), a))) |
---|
-6 |
-25 |
のようにして集約することもできます。Scalaのコレクション処理で書くと、
as.map(a => bs.map(b => a * b).min).min
と同等の計算ができています。これで色々できそうですね。
対策2
変数名をいちいち与えるのが面倒だったり、複雑なarray操作をモジュール化していてバッティングしないようにするのも難しい場合は、自動で変数名を払いだしてほしくなります。
簡易的ですが、グローバルカウンターを用意して x0
x1
x2
と振っていくとすると、こんな実装になります。
val lambdaVariableCounter = new AtomicLong()
def createLambdaSafe(f: Column => Column): LambdaFunction = {
val x = UnresolvedNamedLambdaVariable(Seq(f"x${lambdaVariableCounter.getAndIncrement()}"))
val function = f(new Column(x)).expr
LambdaFunction(function, Seq(x))
}
def transform_safe(column: Column, f: Column => Column): Column =
new Column(ArrayTransform(column.expr, createLambdaSafe(f)))
別解
ちなみにSparkのネイティブ関数のみで対応したい場合、ちょっとややこしいですが、
ds.select(
transform($"as", a => zip_with(array_repeat(a, size($"bs")), $"bs", array(_, _)))
).show(false)
transform(as, lambdafunction(zip_with(array_repeat(x, size(bs)), bs, lambdafunction(array(x, y), x, y)), x)) |
---|
[[[1, -1], [1, -2]], [[2, -1], [2, -2]], [[3, -1], [3, -2]]] |
[[[4, -3], [4, -4], [4, -5]], [[5, -3], [5, -4], [5, -5]]] |
という方法もあります。
実行環境
scalaVersion := "2.12.13"
javacOptions ++= Seq("-source", "14", "-target", "14")
val sparkVersion = "3.1.1"
libraryDependencies ++= Seq(
"org.apache.spark" %% "spark-core" % sparkVersion,
"org.apache.spark" %% "spark-sql" % sparkVersion
)
sbt.version=1.4.8