0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

DataFusionでUDFを作成してみる

Posted at

Apache DataFusion には UDF(ユーザー定義関数)の仕組みがあり、任意の処理を関数として定義して SQL 等で利用できます。

UDF には下記のような種類がありますが、今回は Scalar UDF を作成してみました。

  • Scalar
  • Window
  • Aggregate
  • Table

なお、UDF の実現方法としては 2通り用意されています。

定義方法 備考
create_udf 関数利用 簡易的で制限あり
ScalarUDFImpl トレイト実装 Low Level API で複雑な処理も実現可能

UDF作成

ここでは単純な UDF を作成します。
今回使用した Cargo.toml はこのような内容です。

Cargo.toml
[dependencies]
datafusion = "46.0.1"
tokio = { version = "1.0", features = ["rt-multi-thread"] }

1. create_udf による UDF 定義1

create_udf 関数の内容はこのようになっており、戻り値の ScalarUDF を SessionContextregister_udf する事で使用できるようになります。

create_udf(関数名, 引数の型, 戻り値の型, 変動性, 実装関数) -> ScalarUDF

例えば、2つの文字列を引数にとってそれを連結する関数 concat を create_udf で定義するとこのようになります。

src/sample1.rs
use datafusion::arrow::datatypes::DataType;
use datafusion::common::plan_err;
use datafusion::error::Result;
use datafusion::logical_expr::{ColumnarValue, Volatility};
use datafusion::prelude::*;
use datafusion::scalar::ScalarValue;

use std::sync::Arc;
// UDF実装
fn concat(args: &[ColumnarValue]) -> Result<ColumnarValue> {
    let (a, b) = (&args[0], &args[1]);

    match (a, b) {
        (
            ColumnarValue::Scalar(ScalarValue::Utf8(a)),
            ColumnarValue::Scalar(ScalarValue::Utf8(b)),
        ) => {
            let r = a.clone().and_then(|x| b.clone().map(|y| x + &y));
            Ok(ColumnarValue::Scalar(ScalarValue::Utf8(r)))
        }
        _ => plan_err!("unsupported arg types"),
    }
}

#[tokio::main]
async fn main() -> Result<()> {
    // UDF作成
    let udf = create_udf(
        "concat",
        vec![DataType::Utf8, DataType::Utf8],
        DataType::Utf8,
        Volatility::Immutable,
        Arc::new(concat),
    );

    let ctx = SessionContext::new();
    // UDF登録
    ctx.register_udf(udf);

    run_query(&ctx, "SELECT concat('ab', 'cdef')").await?;
    run_query(&ctx, "SELECT concat(1234, 56)").await?;
    run_query(&ctx, "SELECT concat([1, 2, 3], ['a', 'b'])").await?;

    // この処理はエラーになる
    run_query(
        &ctx,
        "SELECT concat(unnest(['a1', 'b22', 'c333']), '-sample')",
    )
    .await?;

    Ok(())
}

async fn run_query(ctx: &SessionContext, query: &str) -> Result<()> {
    let df = ctx.sql(query).await?;

    df.show().await?;

    Ok(())
}

実行結果はこうなります。

数値や配列等もスカラーな文字列として連結するのがポイントです。
ただし、テーブルの行データや unnest の結果は ColumnarValue::Scalar にはマッチしないためエラーになります。

実行結果1
+---------------------------------+
| concat(Utf8("ab"),Utf8("cdef")) |
+---------------------------------+
| abcdef                          |
+---------------------------------+
+-------------------------------+
| concat(Int64(1234),Int64(56)) |
+-------------------------------+
| 123456                        |
+-------------------------------+
+--------------------------------------------------------------------------------+
| concat(make_array(Int64(1),Int64(2),Int64(3)),make_array(Utf8("a"),Utf8("b"))) |
+--------------------------------------------------------------------------------+
| [1, 2, 3][a, b]                                                                |
+--------------------------------------------------------------------------------+
Error: Plan("unsupported arg types")

2. create_udf による UDF 定義2

unnest の結果を処理するには ColumnarValue::Array にマッチする処理を追加します。

例えば、ColumnarValue::Array の個々の要素へ第二引数の文字列を連結する処理はこのようになります。

src/sample2.rs
...省略

fn concat(args: &[ColumnarValue]) -> Result<ColumnarValue> {
    let (a, b) = (&args[0], &args[1]);

    match (a, b) {
        ...省略
        (ColumnarValue::Array(a), ColumnarValue::Scalar(ScalarValue::Utf8(b))) => {
            let r = a
                .as_string::<i32>()
                .iter()
                .map(|x| x.and_then(|x| b.clone().map(|y| x.to_string() + &y)))
                .collect::<StringArray>();
            Ok(ColumnarValue::Array(Arc::new(r)))
        }
        _ => plan_err!("unsupported arg types"),
    }
}

...省略

実行結果はこうなります。

実行結果2
...省略
+---------------------------------------------------------------------------------+
| concat(UNNEST(make_array(Utf8("a1"),Utf8("b22"),Utf8("c333"))),Utf8("-sample")) |
+---------------------------------------------------------------------------------+
| a1-sample                                                                       |
| b22-sample                                                                      |
| c333-sample                                                                     |
+---------------------------------------------------------------------------------+

3. ScalarUDFImpl トレイト実装

同様の処理を ScalarUDFImpl トレイトで実装すると、例えばこのようになります。

UDF の引数の型によって戻り値の型を変化させたり等、create_udf を使うよりも柔軟な対応が可能となるのが特徴です。

src/sample3.rs
use datafusion::arrow::array::{AsArray, StringArray};
use datafusion::arrow::datatypes::DataType;
use datafusion::common::plan_err;
use datafusion::error::Result;
use datafusion::logical_expr::{
    ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, Volatility,
};
use datafusion::prelude::*;
use datafusion::scalar::ScalarValue;

use std::sync::Arc;

#[derive(Debug, Clone)]
struct Concat {
    signature: Signature,
}

impl Concat {
    fn new() -> Self {
        Self {
            signature: Signature::uniform(
                2,
                vec![DataType::Utf8, DataType::Utf8],
                Volatility::Immutable,
            ),
        }
    }
}
// UDF実装
impl ScalarUDFImpl for Concat {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn name(&self) -> &str {
        "concat"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        Ok(DataType::Utf8)
    }

    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        if args.args.len() == 2 {
            let a = args.args.first();
            let b = args.args.last();

            match (a, b) {
                (
                    Some(ColumnarValue::Scalar(ScalarValue::Utf8(a))),
                    Some(ColumnarValue::Scalar(ScalarValue::Utf8(b))),
                ) => {
                    let r = a.clone().and_then(|x| b.clone().map(|y| x + &y));
                    Ok(ColumnarValue::Scalar(ScalarValue::Utf8(r)))
                }
                (
                    Some(ColumnarValue::Array(a)),
                    Some(ColumnarValue::Scalar(ScalarValue::Utf8(b))),
                ) => {
                    let r = a
                        .as_string::<i32>()
                        .iter()
                        .map(|x| x.and_then(|x| b.clone().map(|y| x.to_string() + &y)))
                        .collect::<StringArray>();
                    Ok(ColumnarValue::Array(Arc::new(r)))
                }
                _ => plan_err!("unsupported arg types"),
            }
        } else {
            plan_err!("unsupported arg types")
        }
    }
}

#[tokio::main]
async fn main() -> Result<()> {
    // UDF作成
    let udf = ScalarUDF::from(Concat::new());

    let ctx = SessionContext::new();
    ctx.register_udf(udf);

    ...省略
}

...省略

実行結果は同じです。

実行結果3
+---------------------------------+
| concat(Utf8("ab"),Utf8("cdef")) |
+---------------------------------+
| abcdef                          |
+---------------------------------+
+-------------------------------+
| concat(Int64(1234),Int64(56)) |
+-------------------------------+
| 123456                        |
+-------------------------------+
+--------------------------------------------------------------------------------+
| concat(make_array(Int64(1),Int64(2),Int64(3)),make_array(Utf8("a"),Utf8("b"))) |
+--------------------------------------------------------------------------------+
| [1, 2, 3][a, b]                                                                |
+--------------------------------------------------------------------------------+
+---------------------------------------------------------------------------------+
| concat(UNNEST(make_array(Utf8("a1"),Utf8("b22"),Utf8("c333"))),Utf8("-sample")) |
+---------------------------------------------------------------------------------+
| a1-sample                                                                       |
| b22-sample                                                                      |
| c333-sample                                                                     |
+---------------------------------------------------------------------------------+
0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?