2
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

【Python初学者】Batch Normalizationがとてつもなく強力な件。(DeepLearning精度向上)

Last updated at Posted at 2021-11-09

#はじめに

この記事の対象者は
「python初学者・機械学習初学者」向けです。
つまり私です。

#今回のケース

sklearnのload_breast_cancerを用いて、
別で用意したサンプルデータに対して
予測を行うための、学習モデルを作ります。

そのモデルに対し
Batch Normalizationの採用の可否で
正解率の結果が大きく変わることを
見ていきたいと思います!!

(本記事では予測値は出しません。)

#実装

PyTorch Lightning でネットワーク・学習の手順を定義します。

class Net(pl.LightningModule):

    def __init__(self):
        super().__init__()
        
        self.bn = nn.BatchNorm1d(30)
        self.fc1 = nn.Linear(30, 10)
        self.fc2 = nn.Linear(10, 2)


    def forward(self, x):
        h = self.bn(x)
        h = self.fc1(h)
        h = F.relu(h)
        h = self.fc2(h)
        return h


    # 訓練データに対する処理
    def training_step(self, batch, batch_idx):
        x, t = batch
        y = self(x)
        loss = F.cross_entropy(y, t)
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log('train_acc', accuracy(y.softmax(dim=-1), t), on_step=True, on_epoch=True, prog_bar=True)
        return loss


    # 検証データに対する処理
    def validation_step(self, batch, batch_idx):
        x, t = batch
        y = self(x)
        loss = F.cross_entropy(y, t)
        self.log('val_loss', loss, on_step=False, on_epoch=True)
        self.log('val_acc', accuracy(y.softmax(dim=-1), t), on_step=False, on_epoch=True)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=0.01)
        return optimizer

学習の実行

# GPU を含めた乱数のシードを固定
pl.seed_everything(0)

# 学習の実行
net = Net()
logger = CSVLogger(save_dir='logs', name='my_exp')
trainer = pl.Trainer(max_epochs=30, deterministic=True, logger=logger)
trainer.fit(net, train_loader, val_loader)

学習の結果を出します。

trainer.callback_metrics

{'train_acc': tensor(0.6364),
'train_acc_epoch': tensor(0.6364),
'train_acc_step': tensor(0.7500),
'train_loss': tensor(0.6559),
'train_loss_epoch': tensor(0.6559),
'train_loss_step': tensor(0.5971),
'val_acc': tensor(0.6140),
'val_loss': tensor(0.6672)}

以上のように結果が出ました。
検証データの**accuracy(正解率)**については

'val_acc': tensor(0.6140)
61.4% となっており、

けして良い学習結果とは言えません。

#ネットワークに「Batch Normalization」を組み込んでみた。

class Net(pl.LightningModule):

    def __init__(self):
        super().__init__()
        
        self.bn = nn.BatchNorm1d(30) ←これです
        self.fc1 = nn.Linear(30, 10)
        self.fc2 = nn.Linear(10, 2)


    def forward(self, x):
        h = self.bn(x) ←これです
        h = self.fc1(h)
        h = F.relu(h)
        h = self.fc2(h)
        return h

※以下同様

trainer.callback_metrics

{'train_acc': tensor(0.9455),
'train_acc_epoch': tensor(0.9455),
'train_acc_step': tensor(0.9000),
'train_loss': tensor(0.1633),
'train_loss_epoch': tensor(0.1633),
'train_loss_step': tensor(0.1284),
'val_acc': tensor(0.9912),
'val_loss': tensor(0.0839)}

検証データの**accuracy(正解率)**が

'val_acc': tensor(0.9912)
99.12%

となっていました!
正解率が爆増しています!

これは衝撃を受けました。
すごい発明ですね。

#Batch Normalizationってなに?

精度が向上してすごい!!
ってなりますが、Batch Normalizationってなんなのか?

Batch Normalization(以下Batch Norm)は
2015年に提案された割と最近の手法ではあるのですが
多くの研究者や技術者に広く使われているそうです。
機械学習のコンペティションでもBatch Normを活用して
優れた結果を達成している例が多くみられるそうです。

メリットとしては以下のようになります。
・学習を早く進行できる。
・初期値にそれほど依存しない。
・過学習を抑制する。

Batch Normは、Batchという名前の通り
ミニバッチごとに正規化します。
簡単に言うとデータのスケールを扱いやすいものに整えることです。

正規化を行わない場合、
例えば、ミニバッチをランダムで決めると、
大きい数字のものが集まったり、または逆に小さい数字が集まったり
ミニバッチごとのスケールにばらつきが生じることがあります。

それをBatch Normを使用することで
適度なスケールにしバッチ間の差の広がりを抑え、
学習をスムーズにできる!

そんな手法です!!

(数学理論は小難しいので、自分が分かるように割愛しました!!)
※詳しくは文献およびネットで検索

#参考文献

ゼロから作るDeep Learning――Pythonで学ぶディープラーニングの理論と実装
https://www.oreilly.co.jp/books/9784873117584/

2
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?