#0.はじめに
Pytorchを書くのってKerasに比べて結構大変じゃないですか?(まぁ逆にそれが魅力でもあるんですが)
そんな時は、お試しでPyTorch Lightningを使ってみるのはどうでしょうか?
この手の記事って小難しくて途中で読むの挫折するケースが多いと素がポンコツな私自身は感じるので、本記事はいつも通り初心者向けに平易な形で解説していきます。
※PyTorchの基礎部分は解説してないので、それはある程度分かっている前提なのはあしからず。
・Github
・公式ドキュメンテーション
- 動作環境
- OS : Windows10 pro
- Python : 3.8.3// Miniconda 4.9.1
- (py)torch:1.7.1
- PyTorch Lightning:1.3.6 ※バージョンによってかなり差異がある模様
- jupyter notebook
#1.導入
pip install pytorch-lightning
で導入できた。
#2.CIFER10の画像分類に組み込んで使う
###2-1.インポートモジュールとバージョン確認
# ! pip install pytorch-lightning
import os
import torch
from torch import nn,optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
import pytorch_lightning as pl
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
#pytorchとpytorch_lightningのバージョン確認
print(torch.__version__)
print(pl.__version__)
1.7.1
1.3.6
###2-2.データローダー作成まで
"""通常のPytorch同様な処理を行う(適当に)"""
#前処理+Tensor化
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(
(0.5, 0.5, 0.5), # RGB 平均
(0.5, 0.5, 0.5)# RGB 標準偏差
)
])
#CIFAR10データセットをダウンロード(train+test用) 直下のdataフォルダに格納する
train = torchvision.datasets.CIFAR10(root='./data', train=True,download=True, transform=transform)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,download=False, transform=transform)
#更にtest用のデータセットをvalidation,test用へ分ける
n_val, n_test = int(len(testset)*0.9), int(len(testset)*0.1) #valid9割、test1割
val, test = torch.utils.data.random_split(testset, [n_val, n_test])
#データローダーをそれぞれ準備
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True, num_workers=2)
valloader = torch.utils.data.DataLoader(val, batch_size=32,shuffle=False, num_workers=2)
testloader = torch.utils.data.DataLoader(test, batch_size=32,shuffle=False, num_workers=2)
#データ数確認
print(len(train))
print(len(val))
print(len(test))
50000
9000
1000
###2-3.LightningModuleでクラス定義
PyTorch Lightningでは、ニューラルネット/損失関数/オプティマイザを1つのLightningModuleを継承したクラスにまとめて定義
する。細かい中身はコメント参照。
※なお、この中に上で作成したデータローダーを入れることもできるが今回はそれをしていない。(なぜか私はエラーになってしまったので)
"""
・torch.nn.functional形式
・self.log()はLightningModuleのlogでデフォルトはTensorBoard形式
・(必須)と(無くてもいいoption)があるので色々試してほしい
"""
class My_litmodel(pl.LightningModule):
"""★(必須)initは通常のPytorchと同じ★"""
def __init__(self):
super(My_litmodel, self).__init__()
# 畳み込み層の定義
self.conv1 = nn.Conv2d(3, 6, 5) # コンボリューション1
self.conv2 = nn.Conv2d(6, 16, 5) # コンボリューション2
# 全結合層の定義
self.fc1 = nn.Linear(16 * 5 * 5, 120) # 全結合
self.fc2 = nn.Linear(120, 84) # 全結合
self.fc3 = nn.Linear(84, 10) # CIFAR10のクラス数が10なので出力を10にする
#プーリング層の定義
self.pool = nn.MaxPool2d(2, 2) # maxプーリング
"""★(必須)forwardも通常のPytorchと同じでいい★"""
def forward(self, x):
x = self.pool(F.relu(self.conv1(x))) # conv1~relu~pool
x = self.pool(F.relu(self.conv2(x))) # conv2~relu~pool
x = x.view(-1, 16 * 5 * 5) # 平坦化(1次元化する)
x = F.relu(self.fc1(x)) # fc1~relu
x = F.relu(self.fc2(x)) # fc2~relu
x = self.fc3(x) #fc3~10分類
return x
"""★(必須)学習設定(training_step)★"""
def training_step(self, batch, batch_idx):
img, label = batch
out = self(img) #これでforward部分を呼び出す ※self.forward(img)でもOK
loss = F.cross_entropy(out, label) #nn.CrossEntropyLoss()でもいい
#logの設定。Tensorbordで見たいように設定する
self.log("train_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
return loss
"""
★(必須)バリデーション設定(validation_step)★
stepは各イテレーションごとの結果
"""
def validation_step(self, batch, batch_idx):
img, label = batch
pred = self(img) #これでforward部分を呼び出す
val_loss = F.cross_entropy(pred, label) #loss計算
# 正解率(acc)の算出
pred_label = torch.argmax(pred, dim=1)
val_acc = torch.sum(label == pred_label) * 1.0 / len(label)
#logの設定
self.log('val_loss', val_loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
self.log('val_acc', val_acc, on_step=False, on_epoch=True, prog_bar=True, logger=True)
results = {'val_loss': val_loss, 'val_acc': val_acc }
return results
"""
★(無くてもいいoption)バリデーションループの値出力★
validation_endはエポック毎の集計
"""
def validation_end(self, outputs):
avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
avg_acc = torch.stack([x['val_acc'] for x in outputs]).mean()
results = {'val_loss': avg_loss, 'val_acc': avg_acc}
return results
"""
★(無くてもいいoption)test設定★
※log名称以外中身はvalidation_stepと同じ
"""
def test_step(self, batch, batch_idx): #testの実行;最終的なlossとaccの確認
img, label = batch
pred = self(img) #これでforward部分を呼び出す
test_loss = F.cross_entropy(pred, label)
# 正解率の算出
pred_label = torch.argmax(pred, dim=1)
test_acc = torch.sum(label == pred_label) * 1.0 / len(label)
#logの設定
self.log('test_loss', test_loss, prog_bar=True, logger=True)
self.log('test_acc', test_acc, on_step=False, on_epoch=True, prog_bar=True, logger=True)
results = {'test_loss': test_loss, 'test_acc': test_acc}
return results
"""
★(無くてもいいoption)テストループの値出力
※log名称以外中身はvalidation_endと同じ
"""
def test_end(self, outputs):
avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
avg_acc = torch.stack([x['test_acc'] for x in outputs]).mean()
results = {'test_loss': avg_loss, 'test_acc': avg_acc}
return results
"""★(必須)オプティマイザーの設定★"""
def configure_optimizers(self):
return optim.Adam(self.parameters(), lr=0.001) #オプティマイザ
###2-4.学習
2-3で定義したクラスを使用して学習をする。
もちろんfor文なんて書かなくていいし、GPUへのデータ転送等も書かなくていいので楽ちん。
この部分が楽できるのがPyTorch Lightningと通常のPyTorchの大きな違いである
Trainerには引数オプションがいくつかあるが、詳細はドキュメント参照。
#乱数初期値固定
pl.seed_everything(0)
#作成したクラスのインスタンス化 ※GPU使用でも「.to(device)」は不要
net = My_litmodel()
"""
Trainerのインスタンス化
gpus:GPUを何個使用するか? ※torch.device("cuda:0")は不要
max_epochs:エポック数指定
"""
trainer = Trainer(gpus=1, max_epochs=20)
#Kerasチックにfit()で学習をスタート
trainer.fit(net, trainloader, valloader)
###2-4.結果確認
学習が終わったら結果を確認する。
#self.logを設定した内容が表示される
print(trainer.callback_metrics)
{'val_loss': tensor(1.0951, device='cuda:0'),
'val_acc': tensor(0.6247, device='cuda:0'),
'train_loss': tensor(1.1489, device='cuda:0')}
次に、lossやaccの経過を表示する。
デフォルトのlogがtensorboard
なので、jupyterからはマジックコマンドで呼び出す。
※logはデフォルトだと直下に「lightning_logs」というフォルダが出来ていて、その中に存在
"""
マジックコマンドでtensorboardをJupyter内で呼び出す
確認したい学習時のlossをRunsから選択すれば、Notebook内で推移が確認できる
"""
%reload_ext tensorboard
%tensorboard --logdir ./lightning_logs --bind_all --port 6006
###2-5.テストデータを確認
まずは全テストデータの正解率がどんなもんか?を確認しておく。
#まとめて一気に確認 ※Class内のtest_stepで定義した部分が使用されている
test_result = trainer.test(test_dataloaders=testloader)
print(test_result)
[{'test_loss': 1.0554693937301636, 'test_acc': 0.6190000176429749}]
うーん、6割なので全然ですね。
まぁ本記事はやり方だけ学べればいいので、次は1枚の画像に関しての予測をしてみる。
#CIFER10のラベルをlistで用意
classes = ('plane', 'car', 'bird', 'cat',
'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
#1個ずつ確認
index = 50 #何番目のテストデータセットを確認したいか?を指定
#データセットから__getitem__で中身を1個取得し、viewで入力形式を変更(1枚を付け加える)
test_picture = test.__getitem__(index)[0].view(1, 3, 32, 32)
test_picture_label = test.__getitem__(index)[1]
#取り出した画像を予測させる
test_picture_pred = net(test_picture)
print(test_picture_pred)
print("正解ラベルは" + str(classes[test_picture_label]))
# 予測ラベルを表示
print("予測ラベルは" + str(classes[torch.argmax(test_picture_pred)]))
tensor([[-0.4176, 4.6303, -2.1115, -2.8112, -4.2020, -4.5353, -6.8572, -5.9206,
0.7465, 1.0103]], grad_fn=<AddmmBackward>)
正解ラベルはcar
予測ラベルはcar
こんな感じで1枚1枚の結果確認もできる。
###2-6.モデル保存と読み出し
通常と変わらない
#保存
torch.save(net.state_dict(), 'cifer.pt')
#読み出し
net.load_state_dict(torch.load('cifer.pt'))
#3.EarlyStoppingを2通りで実装する
以前以下のような記事を書いたが、実はPyTorch Lightningだともっと簡単に使うことが出来る。
Trainerのインスタンス化の個所を以下のように変更するだけで簡単に使用できる。
なお、方法は2パターンあるがearly_stop_callback=Trueだと色々設定できないのでcallbackの方がいい?
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
"""
★EarlyStoppingのcallbackパラメータ★
monitor:何を監視するか
patience :何回更新しないと発動するか?
verbose:進捗表示の有無
mode:lossがどうしたらカウントするか?min一択(maxはよくわからん)
"""
#callback利用の場合
early_stop_callback = EarlyStopping(
monitor='val_loss',
patience=1,
verbose=True,
mode='min'
)
#callback利用の場合
trainer = Trainer(gpus=1, max_epochs=10000, callbacks=[early_stop_callback])
"""early_stop_callback利用の場合は以下1行だけでOK"""
# trainer = Trainer(gpus=1, max_epochs=10000, early_stop_callback=True)
trainer.fit(model, trainloader, valloader)
以下のように、毎Epochの経緯も表示設定すれば確認可能。
#4.おわりに
今回はPyTorch Lightningを使ってみました。
やはり自分でforiループ書かなくていいのは便利だし、callbackが手軽に利用できるのも魅力だと感じました。だんだんKerasの方が簡単!なんて言えなくなる時代が近づいているのかもですね。
それでは引き続きよきPyTorchライフを!
<おまけ>
公式のチュートリアルにもCIFER10の画像分類があって、こっちは94%の精度らしいですし。
中読んで理解できる人は読んでみてもいいかもしれません。
cifar10_normalizationってのがあるんですね。こっち使えばよかったかも・・