LoginSignup
1
0

More than 3 years have passed since last update.

PyYAML で array を merge する

Posted at

yaml in データ分析

身近なところでは、機械学習・データ分析周りの設定を yaml で書くのが流行っています(主にKedroを使っています)。
なるべくDRY(don't repeat yourself)にすべく共通の設定はアンカー(&) を使っているのですが、そこで問題になるのが yamlの仕様で、mapping は merge できるが array は merge できないというものです。
これは yaml の仕様としてはサポートしない、というのが yaml チームの見解のようです(https://github.com/yaml/yaml/issues/35 がIssue として立ち上がり、度々Openされその度 Closeされているのが見て取れます)。

yaml で困る具体例

具体的には、以下のような場面で困ります。

common_features: &common
  - member_reward_program_status
  - member_is_subscribing

transaction_features: &transaction
  - num_transactions
  - average_transaction_amount
  - time_since_last_transaction

next_product_to_buy:
  model_to_use: xgboost
  feature_whitelist:
    - *common
    - *transaction
    - last_product_bought
    - applied_to_campaign
  target: propensity

複数のfeatureの塊があったとして、それを組み合わせてモデルを作る場合を考えます。
欲しいものとしては feature_whitelist の中身が

[
  'member_reward_program_status', 
  'member_is_subscribing', 
  'num_transactions', 
  'average_transaction_amount', 
  'time_since_last_transaction', 
  'last_product_bought', 
  'applied_to_campaign'
]

になることなんですが、上の設定だと下のようなネストしたリストになってしまいます。

[
  [
    'member_reward_program_status', 
    'member_is_subscribing', 
  ],
  [
    'num_transactions', 
    'average_transaction_amount', 
    'time_since_last_transaction', 
  ],
  'last_product_bought', 
  'applied_to_campaign'
]

その他の解決法

上の問題を解決するだけであればなんでもいいので、例えばネストしたリストをフラットにする とか、リストじゃなくて辞書型として定義してmergeする、などがあります。

# 辞書型の例
feature_a: &feature_a
  age: 
feature_b: &feature_b
  price:
use_features:
  <<: *feature_a
  <<: *feature_b

使い方は下のようになります。

# > params['use_features'].keys()
dictkeys(['age', 'price'])

また同じ yaml 側で解決する場合も、package を選べる場合は PyYAML の fork である ruamel.yamlを使っても実現できます。

yaml の tag を定義する

今回は Kedro の機能を拡張するために使いたいという背景がありました。
Kedro は TemplatedConfig を読み込む際にanyconfig を使っており、anyconfig 自体は PyYAML にも ruamel.yaml にも対応しているようですが、Kedro サイドで PyYAML を requirements として指定しているので、PyYAML で実現する方法を考えます。

公式のDocs にも自前タグの実装についてある程度の解説はあるので、それを参考にしつつ、タグ用の constructor を定義します。

import yaml

yaml.add_constructor("!flatten", construct_flat_list)

def construct_flat_list(loader: yaml.Loader, node: yaml.Node) -> List[str]:
    """Make a flat list, should be used with '!flatten'

    Args:
        loader: Unused, but necessary to pass to `yaml.add_constructor`
        node: The passed node to flatten
    """
    return list(flatten_sequence(node))

def flatten_sequence(sequence: yaml.Node) -> Iterator[str]:
    """Flatten a nested sequence to a list of strings
        A nested structure is always a SequenceNode
    """
    if isinstance(sequence, yaml.ScalarNode):
        yield sequence.value
        return
    if not isinstance(sequence, yaml.SequenceNode):
        raise TypeError(f"'!flatten' can only flatten sequence nodes, not {sequence}")
    for el in sequence.value:
        if isinstance(el, yaml.SequenceNode):
            yield from flatten_sequence(el)
        elif isinstance(el, yaml.ScalarNode):
            yield el.value
        else:
            raise TypeError(f"'!flatten' can only take scalar nodes, not {el}")

PyYAML は Python のオブジェクトを作成する手前で yaml をPyYAML のオブジェクトにパースした document を作るのですが、その document では array は全て yaml.SequenceNode として、値は yaml.ScalarNode として保存されているので、上のコードで再起的に値だけを取り出すことができます。
機能を確認するためのテストコードは以下のようになります。!flatten の tag をつけることで、ネストされた array をフラットな array に変換できます。

import pytest
def test_flatten_yaml():
    # single nest
    param_string = """
    bread: &bread
      - toast
      - loafs
    chicken: &chicken
      - *bread
    midnight_meal: !flatten
      - *chicken
      - *bread
    """
    params = yaml.load(param_string)
    assert sorted(params["midnight_meal"]) == sorted(
        ["toast", "loafs", "toast", "loafs"]
    )

    # double nested
    param_string = """
    bread: &bread
      - toast
      - loafs
    chicken: &chicken
      - *bread
    dinner: &dinner
      - *chicken
      - *bread
    midnight_meal_long:
      - *chicken
      - *bread
      - *dinner
    midnight_meal: !flatten
      - *chicken
      - *bread
      - *dinner
    """
    params = yaml.load(param_string)
    assert sorted(params["midnight_meal"]) == sorted(
        ["toast", "loafs", "toast", "loafs", "toast", "loafs", "toast", "loafs"]
    )

    # doesn't work with mappings
    param_string = """
    bread: &bread
      - toast
      - loafs
    chicken: &chicken
      meat: breast
    midnight_meal: !flatten
      - *chicken
      - *bread
    """
    with pytest.raises(TypeError):
        yaml.load(param_string)

参考になれば幸いです。

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0