wandbとは
wandbそのものの始め方などは過去の記事を見てくださるとありがたいです。簡単に言うと、実験結果やログをWeb上に保存することができます。結果を保存したり可視化したりしなくても、web上で確認できるので便利です。
以前は、パラメータサーチをwandbでするドキュメントが少なさ過ぎてちょっと無理って状態だったんですが、いつのまにか充実してたので簡単に記事にします。
簡単に自分のパラメータサーチの結果
結果のtableの真ん中にあるsweepのzvpkx0spというsweep idをクリックするとsweep結果をまとめたページを見れます。
どのパラメータが精度(自由に設定可能)に影響しているかが勝手に計算されてる。
input_channelとbatch_sizeが精度をあげるのに大きく貢献していることがわかる。
Parallel Coordinatesというらしい?。各探索項目でこういうのができるんですね。これどう使うかというと、軸を入れ替えることができます。num_channelは今回Listで渡していて、可視化上うまく処理できていないので消しちゃいましょう。消すときは右上のペンマークの「eddt panel」から消すことができます。
そうするとinput_channelから各accに向かって線が伸びていますね。input_channelが高いほうがaccが高めの位置を通過していることがわかります。色は一番右側のaccが高いものほど明るい色になっていて、input_channelが高め、batch_sizeが高めのケースがよさそうなことがわかります。一方で一番高いのはinput_channelが32の時ではなく16のところを通過してそうですよね。軸を入れ替えてみましょう。入れ替えはinput_channelを掴んで左にドラックするだけです。
するとbatch_sizeとinput_channelの順番を入れ替えられます。こう見るとinput_channel中くらいでbatch_size高めがいいのかな?みたいな分析ができます
ここにあるのが2パラ+精度なので、まぁ大したこともないんですが、10パラあって、パラメータサーチ前に大体の探索範囲を絞りたい!ってときにこういう可視化があるととてもはかどります。
簡単な紹介
公式のipynb見たほうが多分早いですが、数行なので簡単に説明します。
まずはsweep用のconfigを用意します。
method: 'random'
metric:
name: 'acc'
goal: 'maximize'
parameters:
batch_size:
values: [256, 128, 64, 32]
input_channel:
values: [32, 16, 8, 4, 1]
num_channels:
values: [[100, 100], [100], [100, 100, 100]]
methodは探索方法で[grid, random, baysian]から選べます。これおそらくrepeat randomなのでrandomでもほぼ網羅的に探索できると思います。
metricは何でパラメータを評価するかで、nameの「acc」は僕が決めた変数名です。あとでwandb.log({"acc": accuracy})と記録しています。goalはその変数が小さいほうがいいのか、大きいほうがいいのかを指定します。
parametersはみたままです。Listでも大丈夫でした。
上とは別に普通のconfigを用意しておきます。
epochs: 30
lr: 0.01
batch_size: 32
input_channel: 8
num_channels: [100]
これに関してはyamlじゃなくてもいいんですが、最終的にdict形式のconfigが必要です。
次にpyファイルをいじります。自分の場合は大体こんな感じの構成でした
def main():
config = load_config()
wandb.init()
for epoch in config.epochs:
train_loss = train()
accuracy = test()
wandb.log("epoch": epoch, "loss": train_loss, "acc": accuracy)
main()
これをsweep用に変えます。大体こんな感じ
def sweep():
sweep_config = load_config("config_sweep.yaml")
sweep_id = wandb.sweep(config, project="sample")
wandb.agent(sweep_id, main)
def main():
config = load_config("config_default.yaml")
wandb.init(project="sample", config=config)
config = wandb.config
for epoch in config.epochs:
train_loss = train()
accuracy = test()
wandb.log("epoch": epoch, "loss": train_loss, "acc": accuracy)
#main()
sweep()
先ほどのconfig_sweep.yamlを読み込んでdict形式のものをwandb.sweepに渡します。
これをwandb.agentに渡せばいいです。
configはwandb.configから取得するようにすると、ハイパラ探索したいとこだけいい感じに変わってくれる
自分も昨日見つけたばかりなのでもう少しマシなコーディングあったかも。。。では!
注意点
見つけ次第更新
method: "random"
はおそらく、「何回やったら終わる」みたいな設定がない。今回自分は、同じパラメータが二週目以降に入っていて、途中でkillしました。終わるための関数かクラス作っておいてもいいかも
wandb.sweep(config)で渡すのはdict
実は自分は普段、addictのDict形式(これはconfig["epochs"]の代わりに、config.epochsと読み込めるようにする)を使っていて、それを渡したらエラーになりました。