13
9

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

【入門編】WandBの基本 〜登録から実装、slack通知まで〜

Last updated at Posted at 2023-05-29

※ この記事は、WandBの基本的な使い方を紹介しています。WandBの応用についてはこちらで紹介しているのでご参照ください:

WandBとは

Weights and Biases (WandB)は、機械学習のための実験管理プラットフォームです。機械学習モデルのtrainやtestの実行中に生じる全てのデータを追跡・可視化することができます。また実験の終了をメールやslackなどで通知することも可能です。基本的には無料で使えます。

ちなみに発音については、筆者の周りでは「ワンダブ」と言う人と「ワンディービー」と言う人で別れてます。

wandb_demo_experiments.gif
(上図はhttps://docs.wandb.ai/quickstartより)

Project・Group・Jobtype・Runの概念

WandBでは、最大4つの階層で実験管理をします。最上級のものがProject、その次がGroup、その次がJobtype、その次がRunとなります。無論、Projectの直下にRunを置くことも可能なので、必要に応じて実験管理の粒度を変えることができます。
下図では、group-demoというProject下に、exp_1 ~ exp_5 の5つの実験Groupがあります。exp_5を展開すると更にrollout, eval2, eval, optimizerと4種類のJobtypeがあり、更に optimizer を展開するとeffortless-serenity-22, still-star-21の2種類のRunがある状態です。なお、ここで表示されているRunの名前は自動生成されたものですが、自分で指定することも可能です。

(実際にデモを触りたい方はこちらから:https://wandb.ai/carey/group-demo)

Screen_Shot_2023-05-09_at_18_27_10.png

WandBの登録とセットアップ

ここではwandbのアカウント登録から、実際に使うまでの流れを説明していきます。内容は https://docs.wandb.ai/quickstart の流れに沿って記述しています。

Step 1 アカウント登録

公式ページの「利用開始」ボタンから、アカウント登録をする。

Step 2 ライブラリのインストール

WandBを使いたい環境で、WandBをインストールする。

$ pip install wandb

(ノートブックを使っている場合は !pip install wandb をセルで実行する。)

Step 3 WandBへのログイン

コマンドライン上でWandBにログインをする。

$ wandb login

すると次のようなメッセージが表示されるので、https://wandb.ai/authorizeからAPIキーをコピーし、コマンドライン上に貼り付ける。

wandb: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
wandb: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit: 

WandBを実際に使ってみる

Open In Colab

ここでは、実際にpythonでwandbを使っていきます。MNIST分類を事例にやっていきます。

モデルの作成

MNISTを解くための単純な全結合層だけのネットワークを定義します。

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms

# Mnist用の単純なネットワークの定義
class Net1(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 100)
        self.fc2 = nn.Linear(100, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        return x

データの準備

MNISTデータセットを準備します。

import torchvision.datasets as datasets

# データセット準備
trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())
testset = datasets.MNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())

trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=False, num_workers=2)

wandbを使わない場合の学習

wandbを使わない場合の学習コードを一応載せておきます。

def train(net, trainloader, optimizer, criterion, epochs):
    for epoch in range(epochs):
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            inputs, labels = data
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f"Epoch {epoch+1} loss: {running_loss/len(trainloader)}")

# 各種定義
net1 = Net1()
optimizer = optim.Adam(net1.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

config = {
    'net':net1,
    'trainloader':trainloader,
    'optimizer':optimizer,
    'criterion':criterion,
    'epochs': 10,
}

# 学習
train(**config)

wandbを使う場合の学習

以下が本記事の要となる、wandbを使う場合の学習コードです。以下で各行の意味について具体的に説明していきます。

import wandb
from wandb import AlertLevel

def train_with_wandb(net, trainloader, optimizer, criterion, epochs):
    for epoch in range(epochs):
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            inputs, labels = data
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f"Epoch {epoch+1} loss: {running_loss/len(trainloader)}")
        # ② wandbに記録したい項目のログ記録
        wandb.log({"loss": running_loss/len(trainloader), "epoch": epoch})

net1 = Net1()
optimizer = optim.Adam(net1.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

config_dict = {
    'net':net1,
    'trainloader':trainloader,
    'optimizer':optimizer,
    'criterion':criterion,
    'epochs': 10,
}

# ① wandbをinitializeした上で学習
with wandb.init(project='wandb_tutorial', group='tutorial',config=config_dict,name='testrun'):

    train_with_wandb(**config_dict)
    
    # ③ wandbを使ったアラート通知
    wandb.alert(
        title='wandb_tutorial',
        text='<@slack id> net1の学習が終わりました!',
        level=AlertLevel.INFO
    )
    
# ④ wandbのrunを終了
wandb.finish()

① wandbをinitializeした上で学習

with wandb.init(project='wandb_tutorial', group='tutorial',config=config,name='testrun'):

ここではwandbのセッションをinitializeしています。各引数の意味は次の通りです。

  • project: プロジェクト名(ここでは wandb_tutorial)
  • group: プロジェクト内でのグループ名(ここでは tutorial)
  • config: wandbに記録させたい実験のconfig。ここでは、モデルやエポック数に関する情報を含んだ辞書config_dictを使っている。他にも、簡単な実験の説明文や、checkpointの保存先などを記載することも可能。
  • name:wandbのrunの名前(ここではtestrun)

他にも、必要に応じて

  • job_type:ジョブの種類
  • tags:タグの配列
  • notes:実験の説明文(gitのコメント的なもの)

などを追加することも可能。(詳細はドキュメンテーション参照)

② wandbに記録したい項目のログ記録

wandb.log( {"loss": running_loss/len(trainloader),
           "epoch": epoch} )

ここでは、epochごとにロスの記録をしています。これにより、wandb上でepochあたりのロスを記録することができます。縦軸・横軸は自由に差し替え可能なので、他にもepochあたりのvalidation errorやaccuracyなどを記録することも可能です。

wandb.init()の前にwandb.log()をするとエラーを吐くので、注意してください。

③ wandbを使ったアラート通知

wandb.alert(
    title='wandb_tutorial',
    text='<@slack_id> net1の学習が終わりました!',
    level=AlertLevel.INFO
)

実験終了後、その旨をアラート通知するようにします。通知はメールとslackの両方またはどちらかだけに送信することが可能です。通知先の設定に関しては、設定画面(https://wandb.ai/settings)の「Alerts」項目の「wandb.alert()」から設定できます。(slackと連携する場合は、wandbアプリをワークスペースにインストールする必要があります:)

Slack通知は通知設定で指定されたワークスペースの指定されたチャンネルに、次のような形で送られます。slack_id を自分のslack idに置き換えれば、@通知が送られてきます。青いリンクを踏めば、該当runのページに飛びます。

またAlertLevelAlertLevel.INFOAlertLevel.WARNAlertLevel.ERRORの3種類があるので、try/except構文でエラー時にエラー通知を送ったりと、必要に応じて通知の種類も変えることができます。

④ wandbのrunを終了

wandbのrun sessionを終了し、記録した情報を全てwandb側にアップロードします。このアップロード処理はpythonプロセスの終了と同時に(原則)自動的に行われるので、厳密には書く必要がありませんが、書いた方が丁寧です。

# ④ wandbのrunを終了
wandb.finish()

まとめ

  • WandBは実験管理用のサービス
  • wandb.init()で記録先を指定し、
  • wandb.log()で記録したい情報を記録する
  • 更にwandb.alert()でアラート通知をslack等に送る
import wandb
from wandb import AlertLevel

def train_with_wandb(net, trainloader, optimizer, criterion, epochs):
    for epoch in range(epochs):
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            inputs, labels = data
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f"Epoch {epoch+1} loss: {running_loss/len(trainloader)}")
        # ② wandbに記録したい項目のログ記録
        wandb.log({"loss": running_loss/len(trainloader), "epoch": epoch})

net1 = Net1()
optimizer = optim.Adam(net1.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

config_dict = {
    'net':net1,
    'trainloader':trainloader,
    'optimizer':optimizer,
    'criterion':criterion,
    'epochs': 10,
}

# ① wandbをinitializeした上で学習
with wandb.init(project='wandb_tutorial', group='tutorial',config=config_dict,name='testrun'):

    train_with_wandb(**config_dict)
    
    # ③ wandbを使ったアラート通知
    wandb.alert(
        title='wandb_tutorial',
        text='<@slack id> net1の学習が終わりました!',
        level=AlertLevel.INFO
    )
13
9
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
13
9

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?