※ この記事は、WandBの基本的な使い方を紹介しています。WandBの応用についてはこちらで紹介しているのでご参照ください:
WandBとは
Weights and Biases (WandB)は、機械学習のための実験管理プラットフォームです。機械学習モデルのtrainやtestの実行中に生じる全てのデータを追跡・可視化することができます。また実験の終了をメールやslackなどで通知することも可能です。基本的には無料で使えます。
ちなみに発音については、筆者の周りでは「ワンダブ」と言う人と「ワンディービー」と言う人で別れてます。
(上図はhttps://docs.wandb.ai/quickstartより)
Project・Group・Jobtype・Runの概念
WandBでは、最大4つの階層で実験管理をします。最上級のものがProject、その次がGroup、その次がJobtype、その次がRunとなります。無論、Projectの直下にRunを置くことも可能なので、必要に応じて実験管理の粒度を変えることができます。
下図では、group-demo
というProject下に、exp_1 ~ exp_5
の5つの実験Groupがあります。exp_5
を展開すると更にrollout, eval2, eval, optimizer
と4種類のJobtypeがあり、更に optimizer
を展開するとeffortless-serenity-22, still-star-21
の2種類のRunがある状態です。なお、ここで表示されているRunの名前は自動生成されたものですが、自分で指定することも可能です。
(実際にデモを触りたい方はこちらから:https://wandb.ai/carey/group-demo)
WandBの登録とセットアップ
ここではwandbのアカウント登録から、実際に使うまでの流れを説明していきます。内容は https://docs.wandb.ai/quickstart の流れに沿って記述しています。
Step 1 アカウント登録
公式ページの「利用開始」ボタンから、アカウント登録をする。
Step 2 ライブラリのインストール
WandBを使いたい環境で、WandBをインストールする。
$ pip install wandb
(ノートブックを使っている場合は !pip install wandb
をセルで実行する。)
Step 3 WandBへのログイン
コマンドライン上でWandBにログインをする。
$ wandb login
すると次のようなメッセージが表示されるので、https://wandb.ai/authorize
からAPIキーをコピーし、コマンドライン上に貼り付ける。
wandb: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
wandb: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:
WandBを実際に使ってみる
ここでは、実際にpythonでwandbを使っていきます。MNIST分類を事例にやっていきます。
モデルの作成
MNISTを解くための単純な全結合層だけのネットワークを定義します。
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
# Mnist用の単純なネットワークの定義
class Net1(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(28 * 28, 100)
self.fc2 = nn.Linear(100, 10)
def forward(self, x):
x = x.view(-1, 28 * 28)
x = self.fc1(x)
x = F.relu(x)
x = self.fc2(x)
return x
データの準備
MNISTデータセットを準備します。
import torchvision.datasets as datasets
# データセット準備
trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())
testset = datasets.MNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=False, num_workers=2)
wandbを使わない場合の学習
wandbを使わない場合の学習コードを一応載せておきます。
def train(net, trainloader, optimizer, criterion, epochs):
for epoch in range(epochs):
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
inputs, labels = data
optimizer.zero_grad()
outputs = net(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
print(f"Epoch {epoch+1} loss: {running_loss/len(trainloader)}")
# 各種定義
net1 = Net1()
optimizer = optim.Adam(net1.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
config = {
'net':net1,
'trainloader':trainloader,
'optimizer':optimizer,
'criterion':criterion,
'epochs': 10,
}
# 学習
train(**config)
wandbを使う場合の学習
以下が本記事の要となる、wandbを使う場合の学習コードです。以下で各行の意味について具体的に説明していきます。
import wandb
from wandb import AlertLevel
def train_with_wandb(net, trainloader, optimizer, criterion, epochs):
for epoch in range(epochs):
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
inputs, labels = data
optimizer.zero_grad()
outputs = net(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
print(f"Epoch {epoch+1} loss: {running_loss/len(trainloader)}")
# ② wandbに記録したい項目のログ記録
wandb.log({"loss": running_loss/len(trainloader), "epoch": epoch})
net1 = Net1()
optimizer = optim.Adam(net1.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
config_dict = {
'net':net1,
'trainloader':trainloader,
'optimizer':optimizer,
'criterion':criterion,
'epochs': 10,
}
# ① wandbをinitializeした上で学習
with wandb.init(project='wandb_tutorial', group='tutorial',config=config_dict,name='testrun'):
train_with_wandb(**config_dict)
# ③ wandbを使ったアラート通知
wandb.alert(
title='wandb_tutorial',
text='<@slack id> net1の学習が終わりました!',
level=AlertLevel.INFO
)
# ④ wandbのrunを終了
wandb.finish()
① wandbをinitializeした上で学習
with wandb.init(project='wandb_tutorial', group='tutorial',config=config,name='testrun'):
ここではwandbのセッションをinitializeしています。各引数の意味は次の通りです。
-
project
: プロジェクト名(ここでは wandb_tutorial) -
group
: プロジェクト内でのグループ名(ここでは tutorial) -
config
: wandbに記録させたい実験のconfig。ここでは、モデルやエポック数に関する情報を含んだ辞書config_dictを使っている。他にも、簡単な実験の説明文や、checkpointの保存先などを記載することも可能。 -
name
:wandbのrunの名前(ここではtestrun)
他にも、必要に応じて
-
job_type
:ジョブの種類 -
tags
:タグの配列 -
notes
:実験の説明文(gitのコメント的なもの)
などを追加することも可能。(詳細はドキュメンテーション参照)
② wandbに記録したい項目のログ記録
wandb.log( {"loss": running_loss/len(trainloader),
"epoch": epoch} )
ここでは、epochごとにロスの記録をしています。これにより、wandb上でepochあたりのロスを記録することができます。縦軸・横軸は自由に差し替え可能なので、他にもepochあたりのvalidation errorやaccuracyなどを記録することも可能です。
wandb.init()の前にwandb.log()をするとエラーを吐くので、注意してください。
③ wandbを使ったアラート通知
wandb.alert(
title='wandb_tutorial',
text='<@slack_id> net1の学習が終わりました!',
level=AlertLevel.INFO
)
実験終了後、その旨をアラート通知するようにします。通知はメールとslackの両方またはどちらかだけに送信することが可能です。通知先の設定に関しては、設定画面(https://wandb.ai/settings)の「Alerts」項目の「wandb.alert()」から設定できます。(slackと連携する場合は、wandbアプリをワークスペースにインストールする必要があります:)
Slack通知は通知設定で指定されたワークスペースの指定されたチャンネルに、次のような形で送られます。slack_id
を自分のslack idに置き換えれば、@通知が送られてきます。青いリンクを踏めば、該当runのページに飛びます。
またAlertLevel
はAlertLevel.INFO
、AlertLevel.WARN
、AlertLevel.ERROR
の3種類があるので、try/except構文でエラー時にエラー通知を送ったりと、必要に応じて通知の種類も変えることができます。
④ wandbのrunを終了
wandbのrun sessionを終了し、記録した情報を全てwandb側にアップロードします。このアップロード処理はpythonプロセスの終了と同時に(原則)自動的に行われるので、厳密には書く必要がありませんが、書いた方が丁寧です。
# ④ wandbのrunを終了
wandb.finish()
まとめ
- WandBは実験管理用のサービス
-
wandb.init()
で記録先を指定し、 -
wandb.log()
で記録したい情報を記録する - 更に
wandb.alert()
でアラート通知をslack等に送る
import wandb
from wandb import AlertLevel
def train_with_wandb(net, trainloader, optimizer, criterion, epochs):
for epoch in range(epochs):
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
inputs, labels = data
optimizer.zero_grad()
outputs = net(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
print(f"Epoch {epoch+1} loss: {running_loss/len(trainloader)}")
# ② wandbに記録したい項目のログ記録
wandb.log({"loss": running_loss/len(trainloader), "epoch": epoch})
net1 = Net1()
optimizer = optim.Adam(net1.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
config_dict = {
'net':net1,
'trainloader':trainloader,
'optimizer':optimizer,
'criterion':criterion,
'epochs': 10,
}
# ① wandbをinitializeした上で学習
with wandb.init(project='wandb_tutorial', group='tutorial',config=config_dict,name='testrun'):
train_with_wandb(**config_dict)
# ③ wandbを使ったアラート通知
wandb.alert(
title='wandb_tutorial',
text='<@slack id> net1の学習が終わりました!',
level=AlertLevel.INFO
)