LoginSignup
9
8

More than 1 year has passed since last update.

【Hydra+wandb+etc..】研究室で利用している機械学習実験環境について(追記予定です.)

Last updated at Posted at 2021-12-04

はじめに

TDU_データ科学・機械学習研究室Advent Calendar
5日目

記事書ききれてないです.すみません,もう少し解説加えたものを近日中にアップします!!

普段僕はComputer Visionの研究をしています.
研究ではCNNを用いて伝統文様の解析をしています.昨今の深層学習ではGPUを利用することがほぼマストで.かくいう僕も研究では研究室のGPUマシンを利用しています.リモート(大学のGPU)マシンでの画像関係の実験は難しいです.具体的に言うと,画像は大量にあるので生データを逐一見るのは難しい.

そこで,僕は普段GPUマシン上に利用している実験環境について語ります.

  • パラメータ管理: hydra
  • ログ(Accuracy, loss, GPUモニタリング)管理: wandb
  • 画像の管理
    • sshfs
    • streamlit(オリジナルアプリ)
  • データ管理
    • csv管理
    • dvc(Data Version Control)

hydra

hydraはFacebookが作成しているパラメータ管理ツールです.configをyaml形式で書く形です.
実験結果のoutputフォルダが自動で作成されて,パラメータも自動保存されるのが便利です.
別のパラメータ管理方法(argparseとか)からの移行コストは少しかかるかもしれませんが,一度使い始めたら以降は楽だと思います.

hydra.utils.instantiate()を使うと画像Augmentationの条件をconfigで管理することが楽になります.
hydra.utils.instantiate()について,詳しくは公式ドキュメントを確認してください.

僕の場合configに以下のような項目を追加しています.
trainデータとvalidationデータで分けてaugmentationを定義します.

config.yaml
### ~~~一部省略~~~
aug:
  train_dataset:
    mixup:
      use_mixup: true
      mixup_alpha: 0.2
    augmentation_list:
      - _target_: albumentations.RandomBrightnessContrast
        brightness_limit: 0.2
        contrast_limit: 0.2
        p: 0.5
      - _target_: albumentations.Flip
        always_apply: true
      - _target_: albumentations.Normalize
        p: 1.0
      - _target_: albumentations.pytorch.transforms.ToTensorV2
        always_apply: true
  val_dataset:
    augmentation_list:
      - _target_: albumentations.Normalize
        p: 1.0
      - _target_: albumentations.pytorch.transforms.ToTensorV2
        always_apply: true

学習プログラムのほうでconfig.yamlを読み込み,Augmentationには

import albumentations as A
from hydra.utils import instantiate


def load_augmentation_object(input_list):
    if isinstance(input_list, A.Compose):
        return input_list
    try:
        aug_list = [instantiate(i, _recursive_=False) for i in input_list]
    except:
        aug_list = input_list   
    return A.Compose(aug_list)

train_transform = load_augmentation_object(cfg.aug.train_dataset.augmentation_list)
val_transform = load_augmentation_object(cfg.aug.val_dataset.augmentation_list)

## 以降Dataloaderにtransformを追加

wandb

実験のパラメータと結果(learning curve)などを管理できるツールです.
mlflow, commet,neptuneなど類似品はいろいろあります.気になる人はいろいろ試してみるといいと思います.

ぼく個人としては以下の点が気に入っています.

  1. ローカルにサーバを立てる必要がない
  2. tensorboardが使える(頑張れば)
  3. 公式Youtubeがいろいろ面白い

1. ローカルにサーバを立てる必要がない

Neptuneなどにもある機能ですね.
wandbにユーザ登録後に手元のマシンからログインすれば簡単にオンラインで実験結果を覗くことができます.

2. tensorboardが使える(頑張れば)

Tensorboardといえば,wandbと同じく実験管理ツールですが,モデルのweight distribution機能やデータセットの可視化(PCA, t-SNE)などの機能が優れているので両方使いたいユーザさんはいるかと思います.
tensorboard単体であればローカルにサーバを立てる必要がありますが,wandbを利用している場合,オンラインでtensorboardが利用できます.導入も楽ちんです.

import wandb
wandb.init(sync_tensorboard=True)

あとはtensorboardSummaryWriterを用いてログを吐き出せば使えます.

3. 公式Youtubeが面白い

wandbの公式Youtubeはチュートリアルの内容ももちろんありますが,それ以外にも論文読みであったり,機械学習全般で役立ちそうな情報が手に入ります.
image.png

研究室での使い方

研究室では学生個人でアカウントを作り,Alert機能を使って研究室のslackにマシンを動かしたら通知が行くようにしています.
個人のトライアンドエラーが可視化されるので,見ていて面白いです.
通知を送る際にユーザがどのマシンを使っているのかも出すようにすれば,マシン利用での衝突を避けることにもつながります.

import wandb
import torch
from wandb import AlertLevel

device_name = torch.cuda.get_device_name()
current_device = torch.cuda.current_device()

text = f"""
GPU: {device_name}\n
指定デバイス: {current_device}\n
"""
wandb.alert(
    title = 'GPU情報',
    text = text,
    level = AlertLevel.INFO
)

データの管理

csvで画像管理

dvc

DVC(Data Version Control)はデータセット版Gitのようなものです.
データセットのバージョン管理が可能になります.

使い方

pip install dvc
mkdir dvc_example


git init
dvc init

お試し用のデータをDLする

dvc get https://github.com/iterative/dataset-registry \
get-started/data.xml -o data/data.xml

addする
dvc add data/data.xml

addするとディレクトリ内にdata.xml.dvcというファイルが作成される
中身はこんなん

outs:
- md5: a304afb96060aad90176268345e10355
size: 37891850
path: data.xml

dvcファイルにデータ情報を追記したらgitで管理できるように変更履歴をaddする

git add data/data.xml.dvc data/.gitignore
git commit -m "Add raw data"

google driveにファイルをあげるように設定する

google driveの共有リンクの一番最後のドメインの文字列をコピーする

dvc remote add myremote gdrive://0AIac4JZqHhKmUk9PDA

google driveにファイルをあげようとするとエラー
pydirveが入っていないからのようだ

pip install dvc[gdrive]
で問題なくできた.

dvc pushしたときに出てくるURLを叩いて,verificationを完了すれば,キャッシュファイルが送られる.

画像データセットに対するData Versioning

おわりに

後輩へのメモとして記事執筆してみました.実験管理は扱っているデータ,自分の性格など人によってあうあわないが大きいものだと思います.
僕もかなりいろいろな方の記事を参考にして,もがきながら今の形に落ち着きました.そうした方への感謝とともに,これを読んだ方が自分なりの実験管理方法してみるきっかけにこの記事がなってくれたらうれしいです.

参考資料

  • 実験管理について考える
  • IQ1でもできる!機械学習研究の実験管理!
  • Pythonによる機械学習実験の管理 Toshihiro Kamishima
9
8
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
9
8