1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

【Kaggle】Google Colab × WandB で学習状況をスマホで確認したい!

Posted at

【Kaggle】Google Colab × WandB で学習状況をスマホで確認したい!

はじめに

本記事では,WandB(Weights & Biases)を用いて,スマホからでも Colab の学習状況を確認する方法を紹介します.

モチベーション

Kaggle のコンペ用に Colab でモデルの学習を開始しました.
ですが,データセットのサイズが大きくなると,学習がサチるまでに数時間以上かかることもあります.

時々学習の進捗を確認したいのですが,Colab 上でログを見るのはちょっと面倒で見づらい……
ということで,「スマホからでも可視化結果を確認する」 を試してみました.

WandB とは

Weights & Biases(WandB) とは,AI モデルの学習管理やログの可視化ができるクラウドベースのツールです.
TensorBoard のように,学習中のロス・精度・ハイパーパラメータなどを自動で記録・グラフ化できます.
オンプレでも可能なようですが,今回はクラウドで手軽にお試ししてみます.

料金については,有料プランもありますが,ちょっと使ってみる分には無料枠でも可能かと思います.
詳細はこちらを参照.

準備

今回使用するツール

  • Google Colab
  • WandB

WandB 登録

まずは,WandB にアクセスし,アカウントを作成します.
Python から WandB を使用するために API キーが必要になるので,
API キーが表示されたら控えておきましょう.(設定から確認もできます)
アカウント自体はメールアドレスや Google アカウントなどで簡単に作成できます.

Colab 上での WandB 設定

ここからは Colab 上での作業に移ります.

Google Colab 上で WandB を使うために以下の 3 ステップを実行します.

  1. wandb インポート

    import wandb
    

    Colab では元々 wandb はインストールされているはずなので,pip install は不要です.

  2. wandb ログイン

    wandb.login()
    

    セル上で上記を実行すると API キーを入力するフォームが出現します.
    wandb.002.jpeg

    こちらに先ほどコピーしておいた API キーを入力し,エンターを押します.
    wandb.001.jpeg
    上記のように,"Currently logged in as: ~~","True" が表示されたらログイン成功です.
    このログインはセッションをリスタートしなければ継続して使用できます.

  3. wandb 初期化

    # ==== wandb 初期化 ====
    wandb.init(
        project="mnist-cnn-demo",
        name="practice", 
        )
    

    init() の引数 project は wandb 上のプロジェクトの名前を指定します.指定した名前のプロジェクトがない場合は自動で作成されます.name はプロジェクト内の runs の名前を指定します.
    プロジェクト名・runs 名は任意で OK ですが,WandB のダッシュボード上で整理されるので意味のある名前にしておくと便利です.

学習準備

今回は wandb の可視化確認用に MNIST + 数層の CNN を構築します.
こちらは今回の記事の主旨ではないので,サラッと記載します.

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from sklearn.metrics import precision_score, recall_score, f1_score


# ==== デバイス ====
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ==== データ ====
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset  = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=2048, shuffle=True)
valid_loader = torch.utils.data.DataLoader(test_dataset, batch_size=2048, shuffle=False)

# ==== モデル ====
class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2)
        )
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, 128), nn.ReLU(),
            nn.Linear(128, 10)
        )

    def forward(self, x):
        x = self.conv(x)
        return self.fc(x)

# モデル・損失関数・最適化関数作成
model = SimpleCNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# ==== 学習ループ ====
for epoch in range(30):
    model.train()
    train_loss, train_correct, train_total = 0, 0, 0

    for x, t in train_loader:
        x, t = x.to(device), t.to(device)
        optimizer.zero_grad()
        y = model(x)
        loss = criterion(y, t)
        loss.backward()
        optimizer.step()

        train_loss += loss.item() * x.size(0)
        train_correct += (y.argmax(1) == t).sum().item()
        train_total += x.size(0)

    train_acc = train_correct / train_total
    train_loss /= train_total

    # ==== 検証 ====
    model.eval()
    val_loss, val_correct, val_total = 0, 0, 0
    all_preds, all_labels = [], []

    with torch.no_grad():
        for x, t in valid_loader:
            x, t = x.to(device), t.to(device)
            y = model(x)
            loss = criterion(y, t)

            val_loss += loss.item() * x.size(0)
            val_correct += (y.argmax(1) == t).sum().item()
            val_total += x.size(0)

            all_preds.extend(y.argmax(1).cpu().numpy())
            all_labels.extend(t.cpu().numpy())

    val_acc = val_correct / val_total
    val_loss /= val_total

    precision = precision_score(all_labels, all_preds, average='macro')
    recall = recall_score(all_labels, all_preds, average='macro')
    f1 = f1_score(all_labels, all_preds, average='macro')

    # ==== wandb ログ ====
    wandb.log({
        "train_loss": train_loss,
        "train_acc": train_acc,
        "val_loss": val_loss,
        "val_acc": val_acc,
        "val_precision": precision,
        "val_recall": recall,
        "val_f1_score": f1
    }, step=epoch)

    print(f"Epoch {epoch+1}: Train Loss={train_loss:.4f}, Val Loss={val_loss:.4f}, F1={f1:.4f}")
wandb.finish()

上記コードのポイントは下記の 2 箇所です.

# ==== wandb ログ ====
wandb.log({
    "train_loss": train_loss,
    "train_acc": train_acc,
    "val_loss": val_loss,
    "val_acc": val_acc,
    "val_precision": precision,
    "val_recall": recall,
    "val_f1_score": f1
}, step=epoch)

上記で,学習・評価の Loss・Accuracy・Precision・Recall・F1 をエポックごとに wandb に書き込んでいきます.

wandb.finish()

これで wandb のログ書き込みを終了します.
これがなくてもログは送られますが,run の完了が正しく記録されない場合があるため,入れておくのがおすすめです.

実行

上記セルを実行します.
Colab 上で下記のようなログが表示されたら実行できています.

Epoch 1: Train Loss=1.1301, Val Loss=0.4320, F1=0.8693
Epoch 2: Train Loss=0.3341, Val Loss=0.2428, F1=0.9269
Epoch 3: Train Loss=0.2143, Val Loss=0.1629, F1=0.9496
Epoch 4: Train Loss=0.1512, Val Loss=0.1175, F1=0.9638
Epoch 5: Train Loss=0.1173, Val Loss=0.0941, F1=0.9700
Epoch 6: Train Loss=0.0964, Val Loss=0.0810, F1=0.9747
Epoch 7: Train Loss=0.0808, Val Loss=0.0677, F1=0.9783
Epoch 8: Train Loss=0.0697, Val Loss=0.0582, F1=0.9827

WandB で確認

まずは PC のブラウザから確認してみます.
wandb1.png

ログ書き込みを行った val の各評価指標や train の loss など表示されました.

次に,目的のスマホからも確認してみます.
wandb.003.jpeg

スマホからも描画を確認できました.

まとめ

今回は WandB を使って,Google Colab 上の学習ログをスマホから確認する方法を紹介しました.

少しのコード追加だけでロスや精度などの可視化ができ,
スマホからでもリアルタイムで進捗を確認できるのは非常に便利でした.

特に,TensorBoard などをすでに使用している方にとっては,数行の追加で確認できるようになるので,とても有用かと思います.

Kaggle などで長時間学習を走らせているときに,
Colab を開かずに進捗チェックできるので,今後重宝しそうです.

学習の見える化・効率化に活用してみてください.

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?