この記事はSupershipグループ Advent Calendar 2022 の22日目の記事になります。
はじめに
初めまして! アドベントカレンダー初参加の伊藤です!
最近Apache Airflowに触れる機会があり、動的タスクとJinjaテンプレートについて学んだので書いていきます!
今回の内容
- Airflowの動的タスク
- Airflowは前のタスクの出力に基づいて次に実行するタスクを繰り返すことができる。
- Ver. 2.3.0以上ならDynamic Task Mappingが便利
- Ver. 2.3.0未満ならBranchOperator × TaskGropが便利
- AirflowのJinjaテンプレート
- AirflowはJinjaテンプレート機能を活用しており変数・マクロ・フィルターをテンプレートで使用可能
- 組み込みのフィルタ以外にも自分でカスタムすることができる。
- Jinjaなのでレンダリング時の型に注意が必要
1. Airflowの動的タスク
Airflowは、実行する処理をTaskと呼ばれる単位で整理することができ、DAG(Directed Acyclic Graph)はそのタスクをまとめて、タスク間の依存関係などを整理することができます。現在のAirflow(Ver.2.5.0)では、このDAGとTaskを動的に生成したい際に、それぞれDynamic DAG Generation と Dynamic Task Mapping が用意されています。
- Dynamic DAG Generation:パラメータを変えた複数のDAGを動的に生成可能
- Dynamic Task Mapping :DAGの実行時に前のタスクに応じて複数のタスクを実行可能
特にDynamic Task Mappingは、DAGを定義する際に事前にタスクの数を宣言しなくても良いので、Pysparkなどで引数を変えたジョブを複数実行したいが、実行するジョブの数がDAGを定義する際に決まらない(別のタスクで決めたい)時に便利な機能!ただ、Ver.2.3.0以降じゃないと使えない...
そこで、前のタスクに応じて複数のタスクを実行する際にDynamic Task Mapping を使った場合と使わないでもそれっぽくできる方法を書いていきます!
Dynamic Task Mapping
ここでは、サポートベクターマシン(SVM)のパラメータを変えてscikit-learnが提供するIrisデータセットを学習するタスクを実行するDAGを作成したいと思います。
- まずは複数回実行したい学習用のコード(
iris_svm.py
)を用意します。 - 次に、最初に実行するTaskで
iris_svm.py
に渡す引数を用意します。[{c: 0.1, gamma: 0.1}, {c:0.1, gamma: 1}, ...]
- 次に、前に実行されたパラメータをもとに
iris_svm.py
を実行するTaskを作成します。-
PythonOperator.partial(...).expand(op_kwargs=...)
のop_kwargs
に直前のTaskを渡します。 -
partial()
は変更しない部分、expand()
は変更する部分を意味します。
-
- 最後に実行するTaskでは、複数回実行された学習の結果を表示します。
from sklearn import datasets
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score
def run(c, gamma):
"""Irisデータセットを使ったSVMの学習
Args:
c: 正則化パラメータ
gamma: ガウシアンカーネルの幅
"""
iris = datasets.load_iris() # iris dataset の読み込み
X_train, X_test, y_train, y_test = train_test_split(iris.data[:, :2], iris.target, test_size=0.2, random_state=111)
model = SVC(kernel='rbf', C=c, gamma=gamma)
model.fit(X_train, y_train) # 学習
Y_pred = model.predict(X_test) # 予測
return round(f1_score(y_true=y_test, y_pred=Y_pred, average="macro"), 3) # F値
from datetime import datetime, timedelta
from airflow import DAG, XComArg
from airflow.operators.python import PythonOperator
def make_params():
params = [{"c": c, "gamma": g} for c in [0.1, 1, 10] for g in [0.1, 1, 10]]
return params
def submit_job(c, gamma, **context):
import iris_svm
print(f"C = {c}, gamma = {gamma}")
f1_score = iris_svm.run(c=c, gamma=gamma)
print(f"f1_score = {f1_score}")
result = (c, gamma, f1_score)
context["task_instance"].xcom_push(key="result", value=result)
def summary(**context):
result = context["task_instance"].xcom_pull(task_ids="submit_job", key="result")
for c, gamma, f1_score in result:
print(f"C = {c}, gamma = {gamma}, f1_score = {f1_score}")
with DAG(
'example_dynamic',
default_args={
"retries": 1,
"retry_delay": timedelta(minutes=5),
"depends_on_past": False,
},
description='A simple tutorial DAG',
schedule_interval="@once",
start_date=datetime(2021, 1, 1),
catchup=False,
tags=['example'],
) as dag:
task_make_params = PythonOperator(
task_id="make_params",
python_callable=make_params,
dag=dag
)
task_job = PythonOperator.partial(
task_id="submit_job",
python_callable=submit_job,
dag=dag
).expand(op_kwargs=XComArg(task_make_params))
task_summary = PythonOperator(
task_id="summary",
python_callable=summary,
dag=dag,
)
(
task_make_params >> task_job >> task_summary
)
DAGを実行すると...
- AirflowのGraphから
Task: Submit_job
が9回実行されていることが確認できます。
-
Task: summary
のログにはパラメータ毎の学習結果が一覧表示されています。 - Dynamic Task Mappingを使うとコード量が少なくシンプルに実装が出来ています。
INFO - C = 0.1, gamma = 0.1, f1_score = 0.482
INFO - C = 0.1, gamma = 1, f1_score = 0.829
INFO - C = 0.1, gamma = 10, f1_score = 0.482
INFO - C = 1, gamma = 0.1, f1_score = 0.829
INFO - C = 1, gamma = 1, f1_score = 0.853
INFO - C = 1, gamma = 10, f1_score = 0.822
INFO - C = 10, gamma = 0.1, f1_score = 0.792
INFO - C = 10, gamma = 1, f1_score = 0.810
INFO - C = 10, gamma = 10, f1_score = 0.853
BranchPythonOperator × TaskGrop
Dynamic Task Mappingを使わない方法として、実行回数を定義しておき、For文でタスクを実行するという方法があります。しかし、この方法では実行回数に満たない繰り返しに対応できません。そうした際にBranchPythonOperatorを使うこと回数に満たない繰り返しにも対応することができます。
例えば、実行回数を10回としてDAGを作成しつつも、前のパラメータを作成するタスクで9回分のパラメータしか作成されなかった場合、BranchPythonOperatorで1回ごとにDummyのTaskを実行するか、iris_svm.py
を実行するTaskを実行するかを分岐させることで最後の10回目はDummyのTaskが実行されるようになります。
また、この方法ではAirflowのグラフからTaskの関係性が視認しづらくなるため、TaskGropを使い1つのグループであることを示すと視覚的にわかりやすくなります。
with TaskGroup("train", tooltip="Tasks Train") as tg_train:
for i in range(10):
task_branch = BranchPythonOperator(task_id=f"branch_{i}", python_callable=branch, op_kwargs={"index": i})
task_dummy = DummyOperator(task_id=f"dummy_{i}", dag=dag)
task_job = PythonOperator(
task_id=f"submit_job_{i}",
python_callable=submit_job,
op_kwargs={
"index": i,
"params": "{{ task_instance.xcom_pull(task_ids='make_params', key='return_value') }}",
},
dag=dag,
)
(task_branch >> [task_dummy, task_job])
- Jinjaテンプレートを使って値を渡した場合、値が文字列(str)になるため
eval()
を使ってListに変換する必要があります。
def submit_job(index, params, **context):
params = eval(params)
c, gamma = params[index]["c"], params[index]["gamma"]
...
- Dynamic Task Mappingを利用した場合は、XComから1回のPullですべての学習結果を取得できましたが、For文で複数のタスクを実行しているためタスクごとにPullする必要があります。
- また、後続タスクにトリガールール(
ALL_DONE
など)を指定して、TaskGroup内で実行したすべてのTaskが完了してから後続タスクを実行するといった依存関係を書く必要があります。
def summary(**context):
params = context["task_instance"].xcom_pull(task_ids="make_params", key="return_value")
for i in range(len(params)):
c, gamma, f1_score = context["task_instance"].xcom_pull(task_ids=f"train.submit_job_{i}", key="result")
print(f"C = {c}, gamma = {gamma}, f1_score = {f1_score}")
task_summary = PythonOperator(
task_id="summary",
python_callable=summary,
dag=dag,
trigger_rule=TriggerRule.ALL_DONE
)
2. AriflowのJinjaテンプレート
AirflowはJinjaテンプレート機能を活用しているため、変数・マクロ・フィルターをテンプレートで使用することができます。例えば、{{ ds_nodash }}
は DAGを実行したときの時間 {{ task_instance.task_id }}
は 実行中のタスクインスタンスのタスクIDとなります。また、変数以外にもJinjaのフィルタやマクロを利用することができます。
The Apache Software Foundation 2022."Templates reference".ApacheAirflow
Filter, Macro
- Filterでは、Jinjaの組み込みフィルタや値のフォーマットに使用できるフィルタを利用することが可能です。
- TaskGroup内のTaskIDが
TaskGroup名.TaskID
と ドットで連結されるのですが、GoogleのDataprocにジョブを投入するオペレータと互換性がなく、エラーになってしまうバグが報告されています。既に修正されているようですが、何かしら不都合が生じた際にフィルターが使えるのは便利かもしれません。
apach/airflow."Issue:DataprocJobBaseOperator not compatible with TaskGroups".GitHub
{{ task_instance.task_id | replace('.', '_') }}
- Macroでは、Airflowが用意している日数を加算・減算する関数の他にも、DAGの引数
user_defined_macros
に自分で定義した関数を渡すことで、テンプレート内から呼び出して使うことができます。
def add_world(str):
return str + " World!"
with airflow.DAG(
...
user_defined_macros={"add_world": add_world},
) as dag:
task = PythonOperator(
task_id="HelloWorld",
python_callable=lambda x: print(x),
dag=dag,
op_kwargs={"x": "{{ add_world('Hello')}}"}
)
Hello World!
テンプレートが返す値の型
便利な反面、Jinjaだから起きる問題もあります。
- デフォルトの設定では、テンプレートで取得される値は文字列としてレンダリングされるようになっているため、
{'1001': 301.27, '1002': 433.21, '1003': 502.22}
を{{ti.xcom_pull(...)}}"}
で後続タスクに渡した場合、文字列になります。 - 正しく辞書型として渡したい場合は、
render_template_as_native_obj
をTrueにする必要があります。
The Apache Software Foundation 2022."Operators".ApacheAirflow
with DAG(
...
render_template_as_native_obj=True,
) as dag:
おわりに
今回、Airflowのバージョンの関係でDynamic Task Mappingが使えない場合に、BranchPythonOperator × TaskGropを使ってそれっぽくする方法を書きましたが、もっと良い方法をがないか調査中です!!
Airflowは癖が強い印象を持つ一方で、使っていくと沢山のことをを学べるので楽しいと思いました。Airflowを触ってみて、ワークフローエンジンに興味が出てきたので他のものも学んでいきたいと思います。今回の記事が誰かの役に立てれば幸いです。
ぜひ、他の日のSupershipグループ Advent Calendar 2022 もご覧ください!!
Supershipではプロダクト開発やサービス開発に関わる人を絶賛募集しております。
ご興味がある方は以下リンクよりご確認ください。
Supership株式会社 採用サイト
是非ともよろしくお願いいたします。
参考文献(サイト)
Airflowについて学ぶ上で参考になりました。