12
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

Airflowの動的タスクとJinjaテンプレートについて学んだこと

Last updated at Posted at 2022-12-21

この記事はSupershipグループ Advent Calendar 2022 の22日目の記事になります。

はじめに

初めまして! アドベントカレンダー初参加の伊藤です!
最近Apache Airflowに触れる機会があり、動的タスクとJinjaテンプレートについて学んだので書いていきます!

今回の内容

  1. Airflowの動的タスク
    • Airflowは前のタスクの出力に基づいて次に実行するタスクを繰り返すことができる。
    • Ver. 2.3.0以上ならDynamic Task Mappingが便利
    • Ver. 2.3.0未満ならBranchOperator × TaskGropが便利
  2. AirflowのJinjaテンプレート
    • AirflowはJinjaテンプレート機能を活用しており変数・マクロ・フィルターをテンプレートで使用可能
    • 組み込みのフィルタ以外にも自分でカスタムすることができる。
    • Jinjaなのでレンダリング時の型に注意が必要

1. Airflowの動的タスク

Airflowは、実行する処理をTaskと呼ばれる単位で整理することができ、DAG(Directed Acyclic Graph)はそのタスクをまとめて、タスク間の依存関係などを整理することができます。現在のAirflow(Ver.2.5.0)では、このDAGとTaskを動的に生成したい際に、それぞれDynamic DAG Generation と Dynamic Task Mapping が用意されています。

  • Dynamic DAG Generation:パラメータを変えた複数のDAGを動的に生成可能
  • Dynamic Task Mapping :DAGの実行時に前のタスクに応じて複数のタスクを実行可能

特にDynamic Task Mappingは、DAGを定義する際に事前にタスクの数を宣言しなくても良いので、Pysparkなどで引数を変えたジョブを複数実行したいが、実行するジョブの数がDAGを定義する際に決まらない(別のタスクで決めたい)時に便利な機能!ただ、Ver.2.3.0以降じゃないと使えない...

そこで、前のタスクに応じて複数のタスクを実行する際にDynamic Task Mapping を使った場合と使わないでもそれっぽくできる方法を書いていきます!

Dynamic Task Mapping

ここでは、サポートベクターマシン(SVM)のパラメータを変えてscikit-learnが提供するIrisデータセットを学習するタスクを実行するDAGを作成したいと思います。

  • まずは複数回実行したい学習用のコード(iris_svm.py)を用意します。
  • 次に、最初に実行するTaskでiris_svm.pyに渡す引数を用意します。
    • [{c: 0.1, gamma: 0.1}, {c:0.1, gamma: 1}, ...]
  • 次に、前に実行されたパラメータをもとにiris_svm.pyを実行するTaskを作成します。
    • PythonOperator.partial(...).expand(op_kwargs=...)op_kwargsに直前のTaskを渡します。
    • partial()は変更しない部分、expand()は変更する部分を意味します。
  • 最後に実行するTaskでは、複数回実行された学習の結果を表示します。
iris_svm.py

from sklearn import datasets
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score


def run(c, gamma):
    """Irisデータセットを使ったSVMの学習
    Args:
        c: 正則化パラメータ
        gamma: ガウシアンカーネルの幅
    """
    iris = datasets.load_iris()  # iris dataset の読み込み
    X_train, X_test, y_train, y_test = train_test_split(iris.data[:, :2], iris.target, test_size=0.2, random_state=111)
    model = SVC(kernel='rbf', C=c, gamma=gamma)
    model.fit(X_train, y_train)  # 学習
    Y_pred = model.predict(X_test)  # 予測
    return round(f1_score(y_true=y_test, y_pred=Y_pred, average="macro"), 3)  # F値
DAG
from datetime import datetime, timedelta
from airflow import DAG, XComArg
from airflow.operators.python import PythonOperator


def make_params():
    params = [{"c": c, "gamma": g} for c in [0.1, 1, 10] for g in [0.1, 1, 10]]
    return params


def submit_job(c, gamma, **context):

    import iris_svm
    print(f"C = {c}, gamma = {gamma}")
    f1_score = iris_svm.run(c=c, gamma=gamma)
    print(f"f1_score = {f1_score}")
    result = (c, gamma, f1_score)
    context["task_instance"].xcom_push(key="result", value=result)


def summary(**context):
    result = context["task_instance"].xcom_pull(task_ids="submit_job", key="result")
    for c, gamma, f1_score in result:
        print(f"C = {c}, gamma = {gamma}, f1_score = {f1_score}")


with DAG(
    'example_dynamic',
    default_args={
        "retries": 1,
        "retry_delay": timedelta(minutes=5),
        "depends_on_past": False,
    },
    description='A simple tutorial DAG',
    schedule_interval="@once",
    start_date=datetime(2021, 1, 1),
    catchup=False,
    tags=['example'],
) as dag:

    task_make_params = PythonOperator(
        task_id="make_params",
        python_callable=make_params,
        dag=dag
    )

    task_job = PythonOperator.partial(
        task_id="submit_job",
        python_callable=submit_job,
        dag=dag
    ).expand(op_kwargs=XComArg(task_make_params))

    task_summary = PythonOperator(
        task_id="summary",
        python_callable=summary,
        dag=dag,
    )

    (
        task_make_params >> task_job >> task_summary
    )

DAGを実行すると...

  • AirflowのGraphからTask: Submit_jobが9回実行されていることが確認できます。
    スクリーンショット 2022-12-20 3.31.36.png
  • Task: summaryのログにはパラメータ毎の学習結果が一覧表示されています。
  • Dynamic Task Mappingを使うとコード量が少なくシンプルに実装が出来ています。
Task: summary
INFO - C = 0.1, gamma = 0.1,  f1_score = 0.482
INFO - C = 0.1, gamma = 1,    f1_score = 0.829
INFO - C = 0.1, gamma = 10,   f1_score = 0.482
INFO - C = 1,   gamma = 0.1,  f1_score = 0.829
INFO - C = 1,   gamma = 1,    f1_score = 0.853
INFO - C = 1,   gamma = 10,   f1_score = 0.822
INFO - C = 10,  gamma = 0.1,  f1_score = 0.792
INFO - C = 10,  gamma = 1,    f1_score = 0.810
INFO - C = 10,  gamma = 10,   f1_score = 0.853

BranchPythonOperator × TaskGrop

Dynamic Task Mappingを使わない方法として、実行回数を定義しておき、For文でタスクを実行するという方法があります。しかし、この方法では実行回数に満たない繰り返しに対応できません。そうした際にBranchPythonOperatorを使うこと回数に満たない繰り返しにも対応することができます。

例えば、実行回数を10回としてDAGを作成しつつも、前のパラメータを作成するタスクで9回分のパラメータしか作成されなかった場合、BranchPythonOperatorで1回ごとにDummyのTaskを実行するか、iris_svm.pyを実行するTaskを実行するかを分岐させることで最後の10回目はDummyのTaskが実行されるようになります。
また、この方法ではAirflowのグラフからTaskの関係性が視認しづらくなるため、TaskGropを使い1つのグループであることを示すと視覚的にわかりやすくなります。

展開前のグラフ展開後のグラフ

Task: submit_job
with TaskGroup("train", tooltip="Tasks Train") as tg_train:
        for i in range(10):
            task_branch = BranchPythonOperator(task_id=f"branch_{i}", python_callable=branch, op_kwargs={"index": i})
            task_dummy = DummyOperator(task_id=f"dummy_{i}", dag=dag)
            task_job = PythonOperator(
                task_id=f"submit_job_{i}",
                python_callable=submit_job,
                op_kwargs={
                    "index": i,
                    "params": "{{ task_instance.xcom_pull(task_ids='make_params', key='return_value') }}",
                },
                dag=dag,
            )
            (task_branch >> [task_dummy, task_job])
  • Jinjaテンプレートを使って値を渡した場合、値が文字列(str)になるためeval()を使ってListに変換する必要があります。
submit_job関数
def submit_job(index, params, **context):
    params = eval(params)
    c, gamma = params[index]["c"], params[index]["gamma"]
    ...
  • Dynamic Task Mappingを利用した場合は、XComから1回のPullですべての学習結果を取得できましたが、For文で複数のタスクを実行しているためタスクごとにPullする必要があります。
  • また、後続タスクにトリガールール(ALL_DONEなど)を指定して、TaskGroup内で実行したすべてのTaskが完了してから後続タスクを実行するといった依存関係を書く必要があります。
summary
def summary(**context):
    params = context["task_instance"].xcom_pull(task_ids="make_params", key="return_value")
    for i in range(len(params)):
        c, gamma, f1_score = context["task_instance"].xcom_pull(task_ids=f"train.submit_job_{i}", key="result")
        print(f"C = {c}, gamma = {gamma}, f1_score = {f1_score}")

task_summary = PythonOperator(
        task_id="summary",
        python_callable=summary,
        dag=dag,
        trigger_rule=TriggerRule.ALL_DONE
    )

2. AriflowのJinjaテンプレート

AirflowはJinjaテンプレート機能を活用しているため、変数・マクロ・フィルターをテンプレートで使用することができます。例えば、{{ ds_nodash }}は DAGを実行したときの時間 {{ task_instance.task_id }}は 実行中のタスクインスタンスのタスクIDとなります。また、変数以外にもJinjaのフィルタやマクロを利用することができます。
The Apache Software Foundation 2022."Templates reference".ApacheAirflow

Filter, Macro

  • Filterでは、Jinjaの組み込みフィルタや値のフォーマットに使用できるフィルタを利用することが可能です。
  • TaskGroup内のTaskIDがTaskGroup名.TaskIDと ドットで連結されるのですが、GoogleのDataprocにジョブを投入するオペレータと互換性がなく、エラーになってしまうバグが報告されています。既に修正されているようですが、何かしら不都合が生じた際にフィルターが使えるのは便利かもしれません。
    apach/airflow."Issue:DataprocJobBaseOperator not compatible with TaskGroups".GitHub
タスクIDの置換
{{ task_instance.task_id | replace('.', '_') }}
  • Macroでは、Airflowが用意している日数を加算・減算する関数の他にも、DAGの引数user_defined_macrosに自分で定義した関数を渡すことで、テンプレート内から呼び出して使うことができます。
DAG
def add_world(str):
    return str + " World!" 

with airflow.DAG(
    ...
    user_defined_macros={"add_world": add_world},
) as dag:
    
    task = PythonOperator(
        task_id="HelloWorld",
        python_callable=lambda x: print(x),
        dag=dag,
        op_kwargs={"x": "{{ add_world('Hello')}}"}
    )
実行結果
Hello World!

テンプレートが返す値の型

便利な反面、Jinjaだから起きる問題もあります。

  • デフォルトの設定では、テンプレートで取得される値は文字列としてレンダリングされるようになっているため、{'1001': 301.27, '1002': 433.21, '1003': 502.22}{{ti.xcom_pull(...)}}"}で後続タスクに渡した場合、文字列になります。
  • 正しく辞書型として渡したい場合は、render_template_as_native_objTrueにする必要があります。
    The Apache Software Foundation 2022."Operators".ApacheAirflow
DAG
with DAG(
    ...
    render_template_as_native_obj=True,
) as dag:

おわりに

今回、Airflowのバージョンの関係でDynamic Task Mappingが使えない場合に、BranchPythonOperator × TaskGropを使ってそれっぽくする方法を書きましたが、もっと良い方法をがないか調査中です!!
Airflowは癖が強い印象を持つ一方で、使っていくと沢山のことをを学べるので楽しいと思いました。Airflowを触ってみて、ワークフローエンジンに興味が出てきたので他のものも学んでいきたいと思います。今回の記事が誰かの役に立てれば幸いです。

ぜひ、他の日のSupershipグループ Advent Calendar 2022 もご覧ください!!

Supershipではプロダクト開発やサービス開発に関わる人を絶賛募集しております。
ご興味がある方は以下リンクよりご確認ください。
Supership株式会社 採用サイト
是非ともよろしくお願いいたします。

参考文献(サイト)

Airflowについて学ぶ上で参考になりました。

12
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
12
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?