Apache DataFusion には UDF(ユーザー定義関数)の仕組みがあり、任意の処理を関数として定義して SQL 等で利用できます。
UDF には下記のような種類がありますが、今回は Scalar UDF を作成してみました。
- Scalar
- Window
- Aggregate
- Table
なお、UDF の実現方法としては 2通り用意されています。
定義方法 | 備考 |
---|---|
create_udf 関数利用 | 簡易的で制限あり |
ScalarUDFImpl トレイト実装 | Low Level API で複雑な処理も実現可能 |
UDF作成
ここでは単純な UDF を作成します。
今回使用した Cargo.toml はこのような内容です。
[dependencies]
datafusion = "46.0.1"
tokio = { version = "1.0", features = ["rt-multi-thread"] }
1. create_udf による UDF 定義1
create_udf 関数の内容はこのようになっており、戻り値の ScalarUDF を SessionContext
へ register_udf
する事で使用できるようになります。
create_udf(関数名, 引数の型, 戻り値の型, 変動性, 実装関数) -> ScalarUDF
例えば、2つの文字列を引数にとってそれを連結する関数 concat
を create_udf で定義するとこのようになります。
use datafusion::arrow::datatypes::DataType;
use datafusion::common::plan_err;
use datafusion::error::Result;
use datafusion::logical_expr::{ColumnarValue, Volatility};
use datafusion::prelude::*;
use datafusion::scalar::ScalarValue;
use std::sync::Arc;
// UDF実装
fn concat(args: &[ColumnarValue]) -> Result<ColumnarValue> {
let (a, b) = (&args[0], &args[1]);
match (a, b) {
(
ColumnarValue::Scalar(ScalarValue::Utf8(a)),
ColumnarValue::Scalar(ScalarValue::Utf8(b)),
) => {
let r = a.clone().and_then(|x| b.clone().map(|y| x + &y));
Ok(ColumnarValue::Scalar(ScalarValue::Utf8(r)))
}
_ => plan_err!("unsupported arg types"),
}
}
#[tokio::main]
async fn main() -> Result<()> {
// UDF作成
let udf = create_udf(
"concat",
vec![DataType::Utf8, DataType::Utf8],
DataType::Utf8,
Volatility::Immutable,
Arc::new(concat),
);
let ctx = SessionContext::new();
// UDF登録
ctx.register_udf(udf);
run_query(&ctx, "SELECT concat('ab', 'cdef')").await?;
run_query(&ctx, "SELECT concat(1234, 56)").await?;
run_query(&ctx, "SELECT concat([1, 2, 3], ['a', 'b'])").await?;
// この処理はエラーになる
run_query(
&ctx,
"SELECT concat(unnest(['a1', 'b22', 'c333']), '-sample')",
)
.await?;
Ok(())
}
async fn run_query(ctx: &SessionContext, query: &str) -> Result<()> {
let df = ctx.sql(query).await?;
df.show().await?;
Ok(())
}
実行結果はこうなります。
数値や配列等もスカラーな文字列として連結するのがポイントです。
ただし、テーブルの行データや unnest
の結果は ColumnarValue::Scalar
にはマッチしないためエラーになります。
+---------------------------------+
| concat(Utf8("ab"),Utf8("cdef")) |
+---------------------------------+
| abcdef |
+---------------------------------+
+-------------------------------+
| concat(Int64(1234),Int64(56)) |
+-------------------------------+
| 123456 |
+-------------------------------+
+--------------------------------------------------------------------------------+
| concat(make_array(Int64(1),Int64(2),Int64(3)),make_array(Utf8("a"),Utf8("b"))) |
+--------------------------------------------------------------------------------+
| [1, 2, 3][a, b] |
+--------------------------------------------------------------------------------+
Error: Plan("unsupported arg types")
2. create_udf による UDF 定義2
unnest の結果を処理するには ColumnarValue::Array
にマッチする処理を追加します。
例えば、ColumnarValue::Array の個々の要素へ第二引数の文字列を連結する処理はこのようになります。
...省略
fn concat(args: &[ColumnarValue]) -> Result<ColumnarValue> {
let (a, b) = (&args[0], &args[1]);
match (a, b) {
...省略
(ColumnarValue::Array(a), ColumnarValue::Scalar(ScalarValue::Utf8(b))) => {
let r = a
.as_string::<i32>()
.iter()
.map(|x| x.and_then(|x| b.clone().map(|y| x.to_string() + &y)))
.collect::<StringArray>();
Ok(ColumnarValue::Array(Arc::new(r)))
}
_ => plan_err!("unsupported arg types"),
}
}
...省略
実行結果はこうなります。
...省略
+---------------------------------------------------------------------------------+
| concat(UNNEST(make_array(Utf8("a1"),Utf8("b22"),Utf8("c333"))),Utf8("-sample")) |
+---------------------------------------------------------------------------------+
| a1-sample |
| b22-sample |
| c333-sample |
+---------------------------------------------------------------------------------+
3. ScalarUDFImpl トレイト実装
同様の処理を ScalarUDFImpl
トレイトで実装すると、例えばこのようになります。
UDF の引数の型によって戻り値の型を変化させたり等、create_udf
を使うよりも柔軟な対応が可能となるのが特徴です。
use datafusion::arrow::array::{AsArray, StringArray};
use datafusion::arrow::datatypes::DataType;
use datafusion::common::plan_err;
use datafusion::error::Result;
use datafusion::logical_expr::{
ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, Volatility,
};
use datafusion::prelude::*;
use datafusion::scalar::ScalarValue;
use std::sync::Arc;
#[derive(Debug, Clone)]
struct Concat {
signature: Signature,
}
impl Concat {
fn new() -> Self {
Self {
signature: Signature::uniform(
2,
vec![DataType::Utf8, DataType::Utf8],
Volatility::Immutable,
),
}
}
}
// UDF実装
impl ScalarUDFImpl for Concat {
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn name(&self) -> &str {
"concat"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(DataType::Utf8)
}
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
if args.args.len() == 2 {
let a = args.args.first();
let b = args.args.last();
match (a, b) {
(
Some(ColumnarValue::Scalar(ScalarValue::Utf8(a))),
Some(ColumnarValue::Scalar(ScalarValue::Utf8(b))),
) => {
let r = a.clone().and_then(|x| b.clone().map(|y| x + &y));
Ok(ColumnarValue::Scalar(ScalarValue::Utf8(r)))
}
(
Some(ColumnarValue::Array(a)),
Some(ColumnarValue::Scalar(ScalarValue::Utf8(b))),
) => {
let r = a
.as_string::<i32>()
.iter()
.map(|x| x.and_then(|x| b.clone().map(|y| x.to_string() + &y)))
.collect::<StringArray>();
Ok(ColumnarValue::Array(Arc::new(r)))
}
_ => plan_err!("unsupported arg types"),
}
} else {
plan_err!("unsupported arg types")
}
}
}
#[tokio::main]
async fn main() -> Result<()> {
// UDF作成
let udf = ScalarUDF::from(Concat::new());
let ctx = SessionContext::new();
ctx.register_udf(udf);
...省略
}
...省略
実行結果は同じです。
+---------------------------------+
| concat(Utf8("ab"),Utf8("cdef")) |
+---------------------------------+
| abcdef |
+---------------------------------+
+-------------------------------+
| concat(Int64(1234),Int64(56)) |
+-------------------------------+
| 123456 |
+-------------------------------+
+--------------------------------------------------------------------------------+
| concat(make_array(Int64(1),Int64(2),Int64(3)),make_array(Utf8("a"),Utf8("b"))) |
+--------------------------------------------------------------------------------+
| [1, 2, 3][a, b] |
+--------------------------------------------------------------------------------+
+---------------------------------------------------------------------------------+
| concat(UNNEST(make_array(Utf8("a1"),Utf8("b22"),Utf8("c333"))),Utf8("-sample")) |
+---------------------------------------------------------------------------------+
| a1-sample |
| b22-sample |
| c333-sample |
+---------------------------------------------------------------------------------+