0
0

Pythonユーザー定義テーブル関数(UDTF)サンプルノートブックのウォークスルー

Posted at

こちらの記事で触れられているサンプルノートブックをウォークスルーします。

翻訳版はこちらです。

こちらのクラスターを使います。ランタイム14.0以降が必要です。
Screenshot 2024-04-16 at 17.07.59.png

Python User-Defined Table Functions (UDTFs)

イントロダクション

Python user-defined table function (UDTF)は、出力を単一のスカラー結果値ではなく、テーブル全体を返却する新たなタイプの関数です。登録すると、SQLクエリーのFROM句で使用できるようになります。

Python UDTFの作成

from pyspark.sql.functions import udtf

# UDTFクラスの定義、必要となる `eval` メソッドの実装
@udtf(returnType="num: int, squared: int")
class SquareNumbers:
    """(num, squared)ペアのシーケンスを生成"""
    def eval(self, start: int, end: int):
        for num in range(start, end + 1):
            yield (num, num * num)

PythonでのPython UDTFの利用

from pyspark.sql.functions import lit

SquareNumbers(lit(1), lit(3)).show()
+---+-------+
|num|squared|
+---+-------+
|  1|      1|
|  2|      4|
|  3|      9|
+---+-------+

SQLでのPython UDTFの利用

# Spark SQLで利用できるためにUDTFを登録
spark.udtf.register("square_numbers", SquareNumbers)
spark.sql("SELECT * FROM square_numbers(1, 3)").show()
+---+-------+
|num|squared|
+---+-------+
|  1|      1|
|  2|      4|
|  3|      9|
+---+-------+

Arrow最適化Python UDTF

Apache ArrowはJavaプロセスとPythonプロセス間の効率的なデータ転送を実現する、インメモリの列指向データフォーマットです。

UDTFが大量の行を出力する際に、劇的にパフォーマンスをブーストします。Arrow最適化はuseArrow=Trueを用いることで有効化することができます。

@udtf(returnType="num: int, squared: int", useArrow=True)
class SquareNumbersArrow:
    def eval(self, start: int, end: int):
        for num in range(start, end + 1):
            yield (num, num * num)
# 通常のPython UDTF
SquareNumbers(lit(0), lit(10000000)).show()
+---+-------+
|num|squared|
+---+-------+
|  0|      0|
|  1|      1|
|  2|      4|
|  3|      9|
|  4|     16|
|  5|     25|
|  6|     36|
|  7|     49|
|  8|     64|
|  9|     81|
| 10|    100|
| 11|    121|
| 12|    144|
| 13|    169|
| 14|    196|
| 15|    225|
| 16|    256|
| 17|    289|
| 18|    324|
| 19|    361|
+---+-------+
only showing top 20 rows

約23秒かかっています。
Screenshot 2024-04-16 at 17.10.57.png

# Arrow最適化Python UDTF
SquareNumbersArrow(lit(0), lit(10000000)).show()
+---+-------+
|num|squared|
+---+-------+
|  0|      0|
|  1|      1|
|  2|      4|
|  3|      9|
|  4|     16|
|  5|     25|
|  6|     36|
|  7|     49|
|  8|     64|
|  9|     81|
| 10|    100|
| 11|    121|
| 12|    144|
| 13|    169|
| 14|    196|
| 15|    225|
| 16|    256|
| 17|    289|
| 18|    324|
| 19|    361|
+---+-------+
only showing top 20 rows

14秒に削減されています。
Screenshot 2024-04-16 at 17.11.42.png

LangChainを用いた現実世界のユースケース

Python UDTFとOpenAI、LangChainを組み合わせた、より面白い例を深掘りしていきましょう。

%pip install openai langchain

私はGPT-4を呼び出せなかったので、GPT-3.5にしています。シークレットでAPIキーを管理しています。

from langchain.chains import LLMChain
from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate
from pyspark.sql.functions import lit, udtf

#MODEL_NAME = "gpt-4"
MODEL_NAME = "gpt-3.5-turbo-instruct"
#MODEL_NAME = "gpt-4-turbo"
API_KEY = dbutils.secrets.get("demo-token-takaaki.yayoi", "openai_api_key")

@udtf(returnType="keyword: string")
class KeywordsGenerator:
    """
    LLMを用いてトピックに関するカンマ区切りのキーワードのリストを生成します。
    キーワードのみを出力します。
    """
    def __init__(self):
        print(MODEL_NAME)
        llm = OpenAI(model_name=MODEL_NAME, openai_api_key=API_KEY)
        prompt = PromptTemplate(
            input_variables=["topic"],
            template="generate a couple of comma separated keywords about {topic}. Output only the keywords."
        )
        self.chain = LLMChain(llm=llm, prompt=prompt)

    def eval(self, topic: str):
        response = self.chain.run(topic)
        keywords = [keyword.strip() for keyword in response.split(",")]
        for keyword in keywords:
            yield (keyword, )
KeywordsGenerator(lit("apache spark")).show(truncate=False)

おおー。

+---------------------+
|keyword              |
+---------------------+
|Apache Spark         |
|big data processing  |
|distributed computing|
|machine learning     |
|real-time analytics  |
|data streaming       |
|parallel computing   |
|data science         |
|data engineering     |
|data analysis        |
+---------------------+

日本語だとどうですかね。

KeywordsGenerator(lit("日本")).show(truncate=False)

結果が英語になってしまいました。

+----------------+
|keyword         |
+----------------+
|Japan           |
|Japanese culture|
|Tokyo           |
|sushi           |
|Mount Fuji      |
|samurai         |
|anime           |
|cherry blossoms |
+----------------+

プロンプトを変えましょう。半角カンマなどUDTFとうまく動作するように工夫しています。

@udtf(returnType="keyword: string")
class KeywordsGenerator:
    """
    LLMを用いてトピックに関するカンマ区切りのキーワードのリストを生成します。
    キーワードのみを出力します。
    """
    def __init__(self):
        print(MODEL_NAME)
        llm = OpenAI(model_name=MODEL_NAME, openai_api_key=API_KEY)
        prompt = PromptTemplate(
            input_variables=["topic"],
            template="{topic}に関する半角カンマ区切りのキーワードをいくつか生成してください。生成結果にはキーワードのみを含めてください。"
        )
        self.chain = LLMChain(llm=llm, prompt=prompt)

    def eval(self, topic: str):
        response = self.chain.run(topic)
        keywords = [keyword.strip() for keyword in response.split(",")]
        for keyword in keywords:
            yield (keyword, )
KeywordsGenerator(lit("日本文化")).show(truncate=False)

動きました!

+--------+
|keyword |
+--------+
|神道    |
|歌舞伎  |
|茶道    |
|侍      |
|浮世絵  |
|和服    |
|武士道  |
|花見    |
|お盆    |
|ふろしき|
+--------+

Sparkと複雑な処理を組み合わせて、複数行を返却させるような処理では使えそうですね。

はじめてのDatabricks

はじめてのDatabricks

Databricks無料トライアル

Databricks無料トライアル

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0