目的
- ファイルからパラメータを読み込んで、引数に追加するデコレータを作成したので、使い方をシェアしたい
- DLなどのモデルのパラメータ読み込みを簡単にしたい、というのがモチベーション
- main 関数から引数を連携するコードが面倒なので・・・
- DLなどのモデルのパラメータ読み込みを簡単にしたい、というのがモチベーション
-
omegaconf
便利なので、知らない人には知って欲しい- 他にも便利なのがあったら、ぜひ教えて欲しい!
環境
- Colaboratory
コードサンプル(サマリ)
- 準備
-
omegaconf
をインストール
-
%%bash
pip install omegaconf
- デコレータの準備
import functools
from omegaconf import OmegaConf
def add_args(params_file: str, as_default: bool = False) -> callable:
@functools.wraps(add_args)
def _decorator(f: callable) -> callable:
@functools.wraps(f)
def _wrapper(*args, **kwargs) -> None:
cfg_params = OmegaConf.load(params_file)
if as_default:
cfg_params.update(kwargs)
kwargs = cfg_params
else:
kwargs.update(cfg_params)
return f(*args, **kwargs)
return _wrapper
return _decorator
- 読み込むパラメータファイルを用意する(yaml or json)
-
omegaconf
が、yaml, json に対応
-
%%bash
cat <<__YML__ > params.yml
n_encoder_layer: 3
n_decoder_layer: 5
n_heads: 4
n_embedding: 16
__YML__
:
echo "===== [ params.yml ] ====="
cat params.yml
echo "====="
- 呼び出し
@add_args("params.yml")
def use_params(a, b, n_encoder_layer, n_decoder_layer, n_heads, n_embedding):
assert a == 0.25
assert b == "world"
assert n_encoder_layer == 3
assert n_decoder_layer == 5
assert n_heads == 4
assert n_embedding == 16
use_params(a=0.25, b="world")
print("OK")
ここで、use_params()
関数に指定しているのは、a
, b
だけだという点です。
また、以下のように、as_default=True
をデコレータ引数に指定すると、params.yml
の設定をデフォルトとして読み込み、プログラムで上書きすることもできます。(ちなみに、as_default=False
(デコレータのデフォルト)の場合は、プログラムで指定した実引数よりも、設定ファイルの直を優先します。)
@add_args("params.yml", as_default=True)
def use_params(n_encoder_layer, n_decoder_layer, n_heads, n_embedding):
assert n_encoder_layer == 128 # notice !!
assert n_decoder_layer == 5
assert n_heads == 4
assert n_embedding == 16
use_params(n_encoder_layer=128)
print("OK")
- その他
- クラスの
__init__
にデコレートできるのでぜひお試しを -
omegaconf
では、環境変数や設定ファイル内の直を変数として参照できる-
omegaconf
の詳細は、[こちら] (https://omegaconf.readthedocs.io/en/2.0_branch/usage.html) 参照
-
- クラスの
残課題
- 毎回同じコード書くのも微妙なので、pip install できるようにしたいところ