SQL
Presto
BrainPadDay 13

PrestoでUDFを使う

この記事は?

分散クエリエンジンPrestoのUDFの使い方を紹介します。大体のやり方は公式ドキュメントのDeveloper Guideを見れば書いてあるのですが、完全な例はソース読めとのことなので1、そちらを参考にしつつ実際にUDFを作って動かすまでにやることをまとめました。

概略

Prestoはコネクター、型、関数、アクセス制御をプラグインとして定義するためのService provider interface (SPI) を提供しています2。このSPIを利用すれば、基本的にはやることは↓だけです。

  1. presto-spiのインターフェースに従ってプラグインを実装
  2. 実装したプラグインとメタ情報が入ったjarファイルを作成
  3. jarファイルをpresto-server/plugin下にコピー

以下で具体的な手順を見ていきます。なお、今回使ったPrestoのバージョンは0.189です。

実際

実装

presto-spiで定義されたプラグインのインターフェースcom.facebook.presto.spi.Pluginを実装し、getFunctions()メソッドでこの後実装する関数 (の集合) を返します。

MyUdfPlugin.java
package com.example.presto.udf;

import com.facebook.presto.spi.Plugin;
import com.google.common.collect.ImmutableSet;
import java.util.Set;

public class MyUdfPlugin implements Plugin
{
    @Override
    public Set<Class<?>> getFunctions()
    {
        return ImmutableSet.<Class<?>>builder()
                .add(MyFunctions.class)
                .build();
    }
}

関数の実装はpresto-spiのアノテーションを使って行います。アノテーションで関数名やSQL型を指定してやれば、メソッドをUDFとして定義できます。スカラー関数と集約関数とがそれぞれ定義できますが、手始めにシンプルな例として、整数を引数にとって文字列を返すスカラー関数check_fizz_buzzを定義します。VARCHAR型はStringではなくSliceとして扱うことに注意してください。

MyFunctions.java
package com.example.presto.udf;

import com.facebook.presto.spi.function.ScalarFunction;
import com.facebook.presto.spi.function.Description;
import com.facebook.presto.spi.function.SqlType;
import com.facebook.presto.spi.function.SqlNullable;
import com.facebook.presto.spi.type.StandardTypes;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;

public final class MyFunctions
{
    @ScalarFunction("check_fizz_buzz")
    @Description("Returns 'Fizz' if n is divisible by 3, or else ...")
    @SqlType(StandardTypes.VARCHAR)
    public static Slice check(@SqlNullable @SqlType(StandardTypes.BIGINT) Long num)
    {
        if (num % 15 == 0) { return Slices.utf8Slice("FizzBuzz"); }
        else if (num % 5 == 0) { return Slices.utf8Slice("Buzz"); }
        else if (num % 3 == 0) { return Slices.utf8Slice("Fizz"); }
        else { return Slices.utf8Slice(Long.toString(num)); }
    }
}

ビルド

プラグインを使うために、実装したクラスとメタ情報をjarファイルに固めます。Mavenを使う場合、プロジェクトのファイル構成は例えばこんな具合になります。

.
|-- pom.xml
`-- src
    `-- main
        |-- java
        |   `-- com
        |       `-- example
        |           `-- presto
        |               `-- udf
        |                   |-- MyFunctions.java
        |                   `-- MyUdfPlugin.java
        `-- resources
            `-- META-INF
                `-- services
                    `-- com.facebook.presto.spi.Plugin

MERA-INF/serviceにはプラグインの実装をServiceLoaderに知らせるための情報として、ファイルにクラスの名前を書き、インターフェースのパスをファイル名に付けて保存します。

com.facebook.presto.spi.Plugin
com.example.presto.udf.MyUdfPlugin

pom.xmlにはpresto-spiへの依存と、その他プラグインの実装に必要なライブラリへの依存を書いておきます。

pom.xml
<?xml version="1.0"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>com.example</groupId>
    <artifactId>presto-udf</artifactId>
    <version>1</version>
    <packaging>jar</packaging>

    <dependencies>
        <dependency>
            <groupId>com.facebook.presto</groupId>
            <artifactId>presto-spi</artifactId>
            <version>0.189</version>
            <scope>provided</scope>
        </dependency>
        <dependency>
            <groupId>io.airlift</groupId>
            <artifactId>slice</artifactId>
            <version>0.32</version>
            <scope>provided</scope>
        </dependency>
        <dependency>
            <groupId>com.google.guava</groupId>
            <artifactId>guava</artifactId>
            <version>21.0</version>
        </dependency>
    </dependencies>
</project>

デプロイ

presto-server/plugin以下にディレクトリを作成して、ビルドしたjarをコピーします。また、プラグインが依存する他のjarファイルも一緒に置いておく必要があります。コピーしたらPrestoを再起動しましょう。

$ mkdir ${presto_install_dir}/presto-server/plugin/my-udf
$ cp target/presto-udf-1.jar ${presto_install_dir}/presto-server/plugin/my-udf/
$ cp target/dependency/guava-21.0.jar ${presto_install_dir}/presto-server/plugin/my-udf/
$ ${presto_install_dir}/presto-server/bin/launcher restart

クラスタ上のすべてのノードでこの作業を実行したらデプロイ完了です。

動作確認

presto> select check_fizz_buzz(1);
 _col0 
-------
 1     
(1 row)

presto> select check_fizz_buzz(3);
 _col0 
-------
 Fizz  
(1 row)

presto> select n, check_fizz_buzz(n) from (select 1) cross join unnest(sequence(1, 20)) as t(n);
 n  |  _col1   
----+----------
  1 | 1        
  2 | 2        
  3 | Fizz     
  4 | 4        
  5 | Buzz     
  6 | Fizz     
  7 | 7        
  8 | 8        
  9 | Fizz     
 10 | Buzz     
 11 | 11       
 12 | Fizz     
 13 | 13       
 14 | 14       
 15 | FizzBuzz 
 16 | 16       
 17 | 17       
 18 | Fizz     
 19 | 19       
 20 | Buzz     
(20 rows)

集約関数の例

集約関数はメソッドではなくクラスで定義されます。以下は文字列に含まれる 'z' の数の合計を数える例です。
InputFunction, OutputFunction, CombineFunctionに対応するメソッドをそれぞれ実装します。

CountZ.java
package com.example.presto.udf;

import com.facebook.presto.spi.block.BlockBuilder;
import com.facebook.presto.spi.function.InputFunction;
import com.facebook.presto.spi.function.CombineFunction;
import com.facebook.presto.spi.function.OutputFunction;
import com.facebook.presto.spi.function.SqlType;
import com.facebook.presto.spi.function.AggregationFunction;
import com.facebook.presto.spi.type.StandardTypes;
import io.airlift.slice.Slice;
import java.util.stream.IntStream;

import static com.facebook.presto.spi.type.BigintType.BIGINT;

@AggregationFunction("count_z")
public class CountZ
{
    @InputFunction
    public static void input(LongState state, @SqlType(StandardTypes.VARCHAR) Slice slice)
    {
        IntStream chars = slice.toStringUtf8().chars();
        state.setLong(state.getLong() + chars.filter(c -> c == 'z').count());
    }

    @CombineFunction
    public static void combine(LongState state, LongState otherState)
    {
        state.setLong(state.getLong() + otherState.getLong());
    }

    @OutputFunction(StandardTypes.BIGINT)
    public static void output(LongState state, BlockBuilder out)
    {
        out.writeLong(state.getLong());
    }
}

これらのメソッドで処理しているのは、集約の状態を表すLongStateオブジェクトです。定義はAccumulatorStateを継承して行えます。

LongState.java
package com.example.presto.udf;

import com.facebook.presto.spi.function.AccumulatorState;

public interface LongState extends AccumulatorState
{
    long getLong();

    void setLong(long value);
}

使用例です。

presto> select count_z(check_fizz_buzz(n)) from (select 1) cross join unnest(sequence(1, 20)) as t(n);
 _col0
-------
    20
(1 row)

presto> select n % 5 mod, count_z(check_fizz_buzz(n)) from (select 1) cross join unnest(sequence(1, 20)) as t(n) group by n % 5;
 mod | _col1
-----+-------
   2 |     2
   3 |     4
   4 |     2
   0 |    10
   1 |     2
(5 rows)

参考

presto:sample> with t as (
            ->   select
            ->     id,
            ->     species,
            ->     features(sepallength, sepalwidth, petallength, petalwidth) f
            ->   from iris)
            -> select
            ->   evaluate_classifier_predictions(species, classify(f, model))
            -> from t,
            ->   (select learn_classifier(species, f) model from t where id % 2 = 0)
            -> where id % 2 = 1;
           _col0            
----------------------------
 Accuracy: 72/75 (96.00%)   
 Class 'virginica'          
 Precision: 23/24 (95.83%)  
 Recall: 23/25 (92.00%)     
 Class 'setosa'             
 Precision: 25/25 (100.00%) 
 Recall: 25/25 (100.00%)    
 Class 'versicolor'         
 Precision: 24/26 (92.31%)  
 Recall: 24/25 (96.00%)     

(1 row)

  1. 今回の例ではpresto-mlという機械学習プラグインを大いに参考にしました。このプラグイン、隠し機能として?Prestoの配布アーカイブに最初から入っており、Prestoに繋いだデータソースを使ってお手軽に学習、予測、評価ができるようです(上の例参照)。裏側はlibsvmで、学習は分散処理されるわけではなさそうなのでそこは注意。 

  2. 特にコネクターに関しては、Prestoがサポートしているコネクターのほとんどがプラグインとして実装されているようです