こちらの記事で触れられているサンプルノートブックをウォークスルーします。
翻訳版はこちらです。
こちらのクラスターを使います。ランタイム14.0以降が必要です。
Python User-Defined Table Functions (UDTFs)
イントロダクション
Python user-defined table function (UDTF)は、出力を単一のスカラー結果値ではなく、テーブル全体を返却する新たなタイプの関数です。登録すると、SQLクエリーのFROM
句で使用できるようになります。
Python UDTFの作成
from pyspark.sql.functions import udtf
# UDTFクラスの定義、必要となる `eval` メソッドの実装
@udtf(returnType="num: int, squared: int")
class SquareNumbers:
"""(num, squared)ペアのシーケンスを生成"""
def eval(self, start: int, end: int):
for num in range(start, end + 1):
yield (num, num * num)
PythonでのPython UDTFの利用
from pyspark.sql.functions import lit
SquareNumbers(lit(1), lit(3)).show()
+---+-------+
|num|squared|
+---+-------+
| 1| 1|
| 2| 4|
| 3| 9|
+---+-------+
SQLでのPython UDTFの利用
# Spark SQLで利用できるためにUDTFを登録
spark.udtf.register("square_numbers", SquareNumbers)
spark.sql("SELECT * FROM square_numbers(1, 3)").show()
+---+-------+
|num|squared|
+---+-------+
| 1| 1|
| 2| 4|
| 3| 9|
+---+-------+
Arrow最適化Python UDTF
Apache ArrowはJavaプロセスとPythonプロセス間の効率的なデータ転送を実現する、インメモリの列指向データフォーマットです。
UDTFが大量の行を出力する際に、劇的にパフォーマンスをブーストします。Arrow最適化はuseArrow=True
を用いることで有効化することができます。
@udtf(returnType="num: int, squared: int", useArrow=True)
class SquareNumbersArrow:
def eval(self, start: int, end: int):
for num in range(start, end + 1):
yield (num, num * num)
# 通常のPython UDTF
SquareNumbers(lit(0), lit(10000000)).show()
+---+-------+
|num|squared|
+---+-------+
| 0| 0|
| 1| 1|
| 2| 4|
| 3| 9|
| 4| 16|
| 5| 25|
| 6| 36|
| 7| 49|
| 8| 64|
| 9| 81|
| 10| 100|
| 11| 121|
| 12| 144|
| 13| 169|
| 14| 196|
| 15| 225|
| 16| 256|
| 17| 289|
| 18| 324|
| 19| 361|
+---+-------+
only showing top 20 rows
# Arrow最適化Python UDTF
SquareNumbersArrow(lit(0), lit(10000000)).show()
+---+-------+
|num|squared|
+---+-------+
| 0| 0|
| 1| 1|
| 2| 4|
| 3| 9|
| 4| 16|
| 5| 25|
| 6| 36|
| 7| 49|
| 8| 64|
| 9| 81|
| 10| 100|
| 11| 121|
| 12| 144|
| 13| 169|
| 14| 196|
| 15| 225|
| 16| 256|
| 17| 289|
| 18| 324|
| 19| 361|
+---+-------+
only showing top 20 rows
LangChainを用いた現実世界のユースケース
Python UDTFとOpenAI、LangChainを組み合わせた、より面白い例を深掘りしていきましょう。
%pip install openai langchain
私はGPT-4を呼び出せなかったので、GPT-3.5にしています。シークレットでAPIキーを管理しています。
from langchain.chains import LLMChain
from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate
from pyspark.sql.functions import lit, udtf
#MODEL_NAME = "gpt-4"
MODEL_NAME = "gpt-3.5-turbo-instruct"
#MODEL_NAME = "gpt-4-turbo"
API_KEY = dbutils.secrets.get("demo-token-takaaki.yayoi", "openai_api_key")
@udtf(returnType="keyword: string")
class KeywordsGenerator:
"""
LLMを用いてトピックに関するカンマ区切りのキーワードのリストを生成します。
キーワードのみを出力します。
"""
def __init__(self):
print(MODEL_NAME)
llm = OpenAI(model_name=MODEL_NAME, openai_api_key=API_KEY)
prompt = PromptTemplate(
input_variables=["topic"],
template="generate a couple of comma separated keywords about {topic}. Output only the keywords."
)
self.chain = LLMChain(llm=llm, prompt=prompt)
def eval(self, topic: str):
response = self.chain.run(topic)
keywords = [keyword.strip() for keyword in response.split(",")]
for keyword in keywords:
yield (keyword, )
KeywordsGenerator(lit("apache spark")).show(truncate=False)
おおー。
+---------------------+
|keyword |
+---------------------+
|Apache Spark |
|big data processing |
|distributed computing|
|machine learning |
|real-time analytics |
|data streaming |
|parallel computing |
|data science |
|data engineering |
|data analysis |
+---------------------+
日本語だとどうですかね。
KeywordsGenerator(lit("日本")).show(truncate=False)
結果が英語になってしまいました。
+----------------+
|keyword |
+----------------+
|Japan |
|Japanese culture|
|Tokyo |
|sushi |
|Mount Fuji |
|samurai |
|anime |
|cherry blossoms |
+----------------+
プロンプトを変えましょう。半角カンマなどUDTFとうまく動作するように工夫しています。
@udtf(returnType="keyword: string")
class KeywordsGenerator:
"""
LLMを用いてトピックに関するカンマ区切りのキーワードのリストを生成します。
キーワードのみを出力します。
"""
def __init__(self):
print(MODEL_NAME)
llm = OpenAI(model_name=MODEL_NAME, openai_api_key=API_KEY)
prompt = PromptTemplate(
input_variables=["topic"],
template="{topic}に関する半角カンマ区切りのキーワードをいくつか生成してください。生成結果にはキーワードのみを含めてください。"
)
self.chain = LLMChain(llm=llm, prompt=prompt)
def eval(self, topic: str):
response = self.chain.run(topic)
keywords = [keyword.strip() for keyword in response.split(",")]
for keyword in keywords:
yield (keyword, )
KeywordsGenerator(lit("日本文化")).show(truncate=False)
動きました!
+--------+
|keyword |
+--------+
|神道 |
|歌舞伎 |
|茶道 |
|侍 |
|浮世絵 |
|和服 |
|武士道 |
|花見 |
|お盆 |
|ふろしき|
+--------+
Sparkと複雑な処理を組み合わせて、複数行を返却させるような処理では使えそうですね。