テンプレートを用いてバージョン管理
モデルの学習用のコンフィグのバージョン管理に便利そうなライブラリ:Jinja2( Python用のテンプレエンジン)があるので、実際の使い方を想定して、今回はYAMLと組み合わせたものを共有します。
流れとしては、⇓の画像のようになります。
- テンプレートを作成
- テンプレートからコンフィグファイルを作成
- 作成したコンフィグファイルを用いてモデルを学習
実装
1. Jinja2 Templateの作成: template.yaml.j2
テンプレに入力できるようにしたい項目は{{}}で囲うと認識されるようになります。
DATASET:
TEST_SET: '{{TEST_SET_DIR | default('./test')}}'
TRAIN_SET: '{{TRAIN_SET_DIR | default('./train')}}'
MODEL:
NAME: FCN
TRAIN:
BATCH_SIZE: {{BATCH_SIZE | default(4)}}
EPOCHS: {{EPOCHS | default(25)}}
TEST:
BATCH_SIZE: {{BATCH_SIZE | default(4)}}
EPOCHS: {{EPOCHS | default(25)}}
2. Configuration YAML file をテンプレートから作成
from jinja2 import Template
#出力ファイル名
render_to ='configured.yaml'
# 先ほど作成したjinja2 templateを開く
with open('model_config.yaml.j2') as file_:
template = Template(file_.read())
# モデル用のハイパーパラメータを設定;辞書形式で指定します
# GUIを作成して入力できるようにしたら尚楽かも?
param_config = {
"TEST_SET": "./test_set",
"TRAIN_SET": "./train_set",
"BASE_SIZE": "16",
"EPOCHS": "25",
}
#テンプレートを用いてコンフィグファイルを作成
rendered_conf = template.render(param_config)
#書き出し
with open(render_to, 'w') as f:
f.write(rendered_conf)
作成されたYaml FILE
DATASET:
TEST_SET: './test'
TRAIN_SET: './train'
MODEL:
NAME: FCN
TRAIN:
BATCH_SIZE: 4
EPOCHS: 25
TEST:
BATCH_SIZE: 4
EPOCHS: 25
3. Configuration YAML file の情報をモデル学習用コードから取得
def main():
#Open config yaml file
conf_yaml ='config.yaml'
try:
with open (conf_yaml, 'r') as file:
config = yaml.safe_load(file)
except Exception as e:
raise('Error reading the config file')
#各設定パラメータはキーを指定して取得できます。
batch_size = config['TRAIN']['BATCH_SIZE']
#取得した情報を用いて学習を開始
train_model(batch_size)
if __name__ == "__main__":
main()
Writed by F.K(20代・入社3年目)