LoginSignup
55
50

More than 3 years have passed since last update.

機械学習のconfigを可読性高く設定する方法

Last updated at Posted at 2019-03-17

はじめに

機械学習のハイパーパラメータがたくさんあって、そのリストをどんな風に可読性高く持つか、というのはよく出てくる問題だと思います。
MATLABだったら構造体で持っておけばいいか、と気楽に思ったりしますがPythonではどうしましょう。
今回は全部で下記5パターンを紹介しますので、もし肌に合うものがあったら使ってみてください。

番号 用いる方法
1 class
2 collections.namedtuple
3 typing.NamedTuple
4 dataclasses
5 tensorflow.app.flags
6 sklearn.utils.Bunch(2020.10.17追記)

1:classを用いる方法

python
class Setting:
    def __init__(self, epoch, batch_size, path, flg=True):
        self.epoch = epoch
        self.batch_size = batch_size
        self.path = path
        self.flg = flg

config = Setting(epoch=50, batch_size=128, path='./data')
print('epoch:', config.epoch)
print('batch size:', config.batch_size)
print('path:', config.path)
print('flg:', config.flg)

config.epoch = 100  # 書き換えも可能
print('--epoch:', config.epoch)
実行結果
epoch: 50
batch size: 128
path: ./data
flg: True
--epoch: 100

この方法はPythonのVersionへの依存性も低くて、割と汎用的です。
書くのが少し長いですが、僕は割と愛用しています。

2:collections.namedtupleを用いる方法

python
import collections
Setting = collections.namedtuple(
    'Setting',
    ['epoch',
     'batch_size',
     'path',
     'flg'
    ]
)

config = Setting(epoch=50, batch_size=128, path='./data', flg=True)
print('epoch:', config.epoch)
print('batch size:', config.batch_size)
print('path:', config.path)
print('flg:', config.flg)

print('--epoch:', config[0])  # インデックスでアクセス可能
for item in config:  # iterでもアクセス可能
    print('--', item)
実行結果
epoch: 50
batch size: 128
path: ./data
flg: True
--epoch: 50
-- 50
-- 128
-- ./data
-- True

こちらがcollections.namedtupleを用いる方法です。
immutableなため、クラス変数のように書き換えはできません。
しかしTupleなため、インデックスやイテレータでのアクセスが可能です。

3:typing.NamedTupleを用いる方法

python
import typing
class Setting(typing.NamedTuple):
    epoch: int
    batch_size: int
    path: str
    flg: bool

config = Setting(epoch=50, batch_size=128, path='./data', flg=True)
print('epoch:', config.epoch)
print('batch size:', config.batch_size)
print('path:', config.path)
print('flg:', config.flg)

print('--epoch:', config[0])  # インデックスでアクセス可能
for item in config:  # iterでもアクセス可能
    print('--', item)
実行結果
epoch: 50
batch size: 128
path: ./data
flg: True
--epoch: 50
-- 50
-- 128
-- ./data
-- True

こちらはtyping.NamedTupleを用いる方法です。基本的には2のcollections.namedtupleと一緒。
python3.6から導入されたもので可読性が高いです。tupleで使うなら、collectionsかtypingでクラスとするか、になると思いますが、僕はこちらが気に入っています。

4:dataclassesを用いる方法

python
from dataclasses import dataclass
@dataclass
class Setting3:
    epoch: int
    batch_size: int
    path: str
    flg: bool = True  # デフォルト定義もOK

config = Setting(epoch=50, batch_size=128, path='./data')
print('epoch:', config.epoch)
print('batch size:', config.batch_size)
print('path:', config.path)
print('flg:', config.flg)

config.epoch = 100  # 書き換えも可能
print('--epoch:', config.epoch)
console
epoch: 50
batch size: 128
path: ./data
flg: True
--epoch: 100

こちらはpython 3.7から導入されたdataclass。可読性も高く、非常に使いやすいです。こちらの方が言うように、スタンダードになるかもしれません。一般的になったらこれを推したい。ただ、バージョン的に3.7以降対応なので、まだ使いづらい部分もあるかもしれません。

5:tensorflow.flagsを用いる方法

python
import tensorflow as tf
def Setting():
    flags = tf.app.flags
    FLAGS = flags.FLAGS
    flags.DEFINE_integer('epoch', None, 'description')  # 初期値を決めないことも可能
    flags.DEFINE_integer('batch_size', 128, '')
    flags.DEFINE_string('path', './data', '')
    flags.DEFINE_boolean('flg', True, '')
    return FLAGS

config = Setting()
print('epoch:', config.epoch)
print('batch size:', config.batch_size)
print('path:', config.path)
print('flg:', True)

config.epoch = 100  # 書き換え可能
print('--epoch:', config.epoch)
実行結果
epoch: None
batch size: 128
path: ./data
flg: True
--epoch: 100

機械学習用のパラメータまとめたいんでしょ?そんなん用意してありますよ、という感じでtensorflowにはこれがあります。
正直1〜4までに紹介したものを使わなくても、パラメータだけならこれだけでもいいと思ってます。

7:sklearn.utils.Bunchを用いる方法

python
from sklearn.utils import Bunch

config = Bunch()
config.epoch = 50
config.batch_size = 128
config.path = './data'

print(config)
for k, v in config.items():
    print(k, v)
{'epoch': 50, 'batch_size': 128, 'path': './data'}
epoch 50
batch_size 128
path ./data

最近はこればかり使っています。
アクセスはconfig.keyでも、辞書のようにconfig["key"]でもできます。
もちろん辞書なのでjsonで設定も保存しておけますし、万能だと思います。

jsonとして設定を保存
import json
with open('./config.json', 'w') as f:
    json.dump(config, f, indent=4)

まとめ

どれでも良い。
と言うと投げやりなので、、、

クラスに書けば一緒にセッティング用のMethodも用意できる。
イテレーティブに回すならタプルが便利。
とか、特徴を踏まえて、自分の使い勝手に合うやつを使えばいいと思います。

ちなみに、今回コード内に記載したような値をJSON,Yamlなどで外から与えてやって、試行毎にそれを保存してやれば、もっと管理しやすくなると思います。

おしまい。

55
50
2

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
55
50