0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

Lightning Flashで誰でも簡単にディープラーニングを!

Posted at

はじめに

この記事のターゲット

  • ディープラーニングに興味があって、とにかく何か試してみたい
  • ディープラーニングをやりたくてPyTorchを見てみたけど全くわからない

Lightning Flash(以下Flash)とは?

PyTorchのラッパー(PyTorch Lightning)のラッパーです。PyTorchから見ればPyTorch Lightningでもある程度簡単に描くことができるのですが、それでも初心者には理解に時間がかかります。FlashはPyTorch Lightningからさらに簡潔にコーディングすることができるので、初心者にはうってつけのフレームワークになっています。

サンプル:画像分類

ディープラーニング(特にCNN)を学ぶときにほとんどの人が最初にやるであろう画像分類をやってみましょう。
参考:https://lightning-flash.readthedocs.io/en/latest/reference/image_classification.html

Flashインストール

画像分類では画像用のFlashをインストールします。

$ pip install lightning-flash[image]

使っているパソコンにグラボが搭載していて、GPUを使いたい場合はインストールされたPyTorchをGPU版に置き換えます。

$ pip uninstall torch torchvision
$ pip install torch==[アンインストールされたtorchのバージョン] torchvision==[torchvisionのバージョン] \
  --extra-index-url https://download.pytorch.org/whl/[cuda version]

[cuda version]は現在はcu113cu116が使われることが多いです。

学習コードの作成

今回の学習コードは以下のようになります。

train.py
from torch.cuda import device_count
from flash import Trainer
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData, ImageClassifier

def main():
    download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "./data")
    datamodule = ImageClassificationData.from_folders(
        train_folder="./data/hymenoptera_data/train/",
        val_folder="./data/hymenoptera_data/val/",
        test_folder="./data/hymenoptera_data/test/",
        batch_size=16
    )
    model = ImageClassifier(backbone="resnet34", labels=datamodule.labels)
    trainer = Trainer(max_epochs=5, gpus=device_count())
    trainer.finetune(model, datamodule=datamodule)
    trainer.test(model, datamodule=datamodule)
    trainer.save_checkpoint("finetuned_resnet34.pt")

if __name__ == "__main__":
    main()

21行で学習コードを書くことができました。説明は少し後にあります。

学習結果の表示

テスト結果はtrainer.testを実行したときに確認できますが、学習途中の推移を見たいときはTensorBoardを使用します。

$ tensorboard --logdir ./lightning-flash/version_[ディレクトリの中で1番大きい数字]

実行したらブラウザで http://localhost:6006/ を開いて、学習結果を確認します。参考に、自分が学習した結果を載せておきます。
検証データの正解率は75%ほどで、前処理をしていないこと、エポック数が少ないことを考えるとこんなもんかって感じです。

image.png

予測

次に予測コードを作成しましょう。

predict.py
from flash import Trainer
from flash.image import ImageClassificationData, ImageClassifier

def main():
    datamodule = ImageClassificationData.from_folders(
        test_folder="./data/hymenoptera_data/test/",
        predict_folder="./data/hymenoptera_data/predict/",
        batch_size=16
    )
    model = ImageClassifier.load_from_checkpoint("finetuned_resnet34.pt")
    trainer = Trainer()
    prediction = trainer.predict(model, datamodule=datamodule)
    for pred_batch in prediction:
        for pred in pred_batch:
            print(f"{pred['metadata']['filepath']}: {datamodule.labels[pred['preds'].argmax()]}")

if __name__ == "__main__":
    main()

これを実行すると以下のような出力が得られます。(ants、beesは学習によって変わる)

./data/hymenoptera_data/predict/1247887232_edcb61246c.jpg: ants
./data/hymenoptera_data/predict/153783656_85f9c3ac70.jpg: ants
./data/hymenoptera_data/predict/170652283_ecdaff5d1a.jpg: ants
./data/hymenoptera_data/predict/2039585088_c6f47c592e.jpg: ants
./data/hymenoptera_data/predict/220376539_20567395d8.jpg: bees
./data/hymenoptera_data/predict/239161491_86ac23b0a3.jpg: ants
./data/hymenoptera_data/predict/319494379_648fb5a1c6.jpg: ants
./data/hymenoptera_data/predict/477437164_bc3e6e594a.jpg: ants
./data/hymenoptera_data/predict/488272201_c5aa281348.jpg: ants
./data/hymenoptera_data/predict/72100438_73de9f17af.jpg: ants

コードの説明

学習コード

インポート部

from torch.cuda import device_count
from flash import Trainer
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData, ImageClassifier

インポートするだけなので特に言うことはありませんが、1行目は利用できるGPUの数を数える関数です。CPU版なら0、GPU版なら接続しているグラボの数が出ると思います。

データ部

download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "./data")
datamodule = ImageClassificationData.from_folders(
    train_folder="./data/hymenoptera_data/train/",
    val_folder="./data/hymenoptera_data/val/",
    test_folder="./data/hymenoptera_data/test/",
    batch_size=16
)

download_dataでインターネットからデータをダウンロード、zipを展開します。その後、./data/hymenoptera_data/内にtrainvaltestpredictの4つのディレクトリができているので、そのうち学習に使うtrainval、テストに使うtestをそれぞれ指定します。
また、今回バッチサイズは16に設定しました。もしメモリに乗り切らなかったら4や8など小さな値を指定してください。

モデル部

model = ImageClassifier(backbone="resnet34", labels=datamodule.labels)

今回はResNet34を使用します。このモデルの詳細は他の人の解説を見てください。画像分類には最終的にいくつのクラスに分類するかの情報が必要なため、datamodule.labelslabelsに指定します。この中身は["ants", "bees"]とリストでした。

トレーナー部

trainer = Trainer(max_epochs=5, gpus=device_count())
trainer.finetune(model, datamodule=datamodule)
trainer.test(model, datamodule=datamodule)
trainer.save_checkpoint("finetuned_resnet34.pt")

Trainerクラスではエポック数とGPUの数を指定します。先ほど言ったように、利用できるGPUの数はtorch.cuda.device_count()で取得できます。明らかに数が分かる場合は数字に置き換えたり、GPUを使わない場合はgpusを省略できたりします。
ファインチューニングはfinetune、テストはtestモジュールで実行できます。その際にモデルとデータモジュールを指定します。
それらが終わったら、後で予測に使用するためにsave_checkpointで保存しましょう。名前は好きに決めてください。

予測コード

インポート部

from flash import Trainer
from flash.image import ImageClassificationData, ImageClassifier

今回は予測用のデータの数が多くないので、GPUの利用をしません。枚数が少ない時はむしろCPUで動作させた方が速い時があります(ロードなどの関係)。

データ部

datamodule = ImageClassificationData.from_folders(
    test_folder="./data/hymenoptera_data/test/",
    predict_folder="./data/hymenoptera_data/predict/",
    batch_size=16
)

予測用のデータはpredict_folderで指定します。予測をするだけならこれだけで十分なのですが、このディレクトリを見てみると他とは違ってants, beesという名前のディレクトリに入っておらず、画像が入っています。このようにラベルを持たないため、datamodule.labelsを使用するためにtest_folderも指定しました。

モデル部

model = ImageClassifier.load_from_checkpoint("finetuned_resnet34.pt")

ファインチューニングして保存したものを使いたいので、load_from_checkpointメソッドでロードします。

トレーナー部

trainer = Trainer()
prediction = trainer.predict(model, datamodule=datamodule)
for pred_batch in prediction:
    for pred in pred_batch:
        print(f"{pred['metadata']['filepath']}: {datamodule.labels[pred['preds'].argmax()]}")

Traineでの予測用のメソッドはpredictです。これによってデータモジュールのpredict_folder内の画像がどのクラスに分類されるか予測することができます。
予測結果は2次元のリストで返ってきて、リストの中身は様々な情報が入っている辞書型変数です。この中のpredsに予測結果が入っています。
予測結果はtensor([ 0.7435, -0.6141])といった具合に、クラスごとに数値を出していて、それが一番大きなクラスが予測された分類結果ということになります。なのでargmax()で値が一番大きなインデックスを取り、item()でTensorからただの整数に変換しています。
metadataキーにあるfilepathと一緒に表示することで、どの画像がどのクラスに分類するかを見やすくしています。

余談(Flash Zero)

Flashではコードを書かなくてもモデルの学習ができる機能(Flash Zero)があります。今回の学習コードをノーコードで実行する場合には以下のようにします。

$ flash image_classification --model.backbone resnet34 --trainer.max_epochs 5 --trainer.gpus 1 \
    from_folders --train_folder data/hymenoptera_data/train/ --val_folder data/hymenoptera_data/val/ \
    --test_folder data/hymenoptera_data/test/ --batch_size 16

これを実行すると同じようにプログレスバー表示されて学習が進み、終了するとimage_classification_model.ptが生成されます。

おわりに

今回はLightning Flashを使って簡単にディープラーニングのモデルの学習を行いました。
好評だったら他の例も載せていくので、いいねをつけていただけると励みになります。
特に、自然言語処理はサンプルが英語のデータのものしかないので、日本語バージョンのものを公開していきたいと思います。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?