76
78

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

PyTorch のネットワーク?クラス?ポイント 5 個抑えれば大丈夫!(Python 基礎_特にクラス_を飛ばして学び始めてしまった方向け)

Last updated at Posted at 2021-02-22

Python で最初につまづくポイントの 1 つがクラスだと思います。
私は最初はずっと Keras(Functional API) でディープラーニングを実装していました。
そしてクラスの理解を疎かにしたまま PyTorch へ手を出し、苦労した経験があります。
きっと同じような経験をした方もいるのではないでしょうか(Keras でも Subclassing API に慣れている方は大丈夫かな。)

また、機械学習・ディープラーニングを学び始めてクラスを使ったネットワーク構築でつまづいた方も多いかと思います。

本記事がおすすめの方

  • 最近、機械学習、ディープラーニングについて学び始めた方
  • TensorFlow、Keras を利用していたが PyTorch も挑戦してみたい方
  • PyTorch でネットワークを組んでみたがクラスの理解につまづいた方

クラスの最低限の基礎を習得して、フレームワークを上手く使いこなせるお役に立てれば幸いです。

PyTorch では一般的にクラスを用いて(オブジェクト指向プログラミングで)ネットワークを組んでいきます。
オブジェクト指向プログラミング
なぜクラスを使用するのかについて、クラスの基礎から見ていきましょう。

最低限覚えておくと役立つポイントに『ポイント』と記述しますので本記事のポイントを抑えられればネットワーク構築に役立つはずです。
※ PyTorch のネットワーク構築のための必要最低限のポイントとします。

クラスの基礎

クラスは設計図に例えられることが多いです。(イメージ:紙面上に書かれているだけの状態)

ポイント①:インスタンス化

この紙面上の設計図を実体化することを インスタンス化 と呼びます。

空のクラス
class Class:
    pass # →何も無し
インスタンス化
instance = Class()

上記でクラスをインスタンス化できますが、何も処理できない空のクラスです。

Helloと出力するクラス
class Class:
    def method(self):
        print('Hello')
確認
instance = Class() # インスタンス化
instance.method()
>>> Hello

ポイント②:self

メソッドを記述する際は最低 1 つ引数が必要です。
引数無しだとエラーになります。(Python のルールなので受け入れましょう。)
self はインスタンス自身を表すもので、慣習的に使用されるものです。

しかし、上記のクラスは汎用的なプログラムとは呼べません。

ポイント③:__init__

より汎用的なプログラムを組み、インスタンス化する際に、必要なメソッドが __init__ です。

例えば、以下でデータ(変数)を生み出すクラスを組んでみます。

変数を生み出すクラス
class Variable:
    def __init__(self, data):
        self.data = data
データ用意
import numpy as np
data = np.array(1.0)
確認_1
x = Variable(data) # インスタンス化
print(x.data)
>>> 1.0
確認_2
x.data = np.array(2.0)
print(x.data)
>>> 2.0

このように、クラスを用いることで汎用性を高めることができます。

ポイント④:__call__

さらに __call__ メソッドと呼ばれる特殊メソッドなるものがあります。こちらは関数のように使用することができます。

callメソッドを加えた関数
class Variable:
    def __init__(self, data):
        self.data = data
        
    def __call__(self):
        y = self.data ** 2 # データを二乗する
        return y
確認_1
x = Variable(np.array(2.0)) # インスタンス化
print(x.data)
>>> 2.0
確認_2__call__メソッド(インスタンスそのものを関数のように呼び出すことができる)
x()
>>> 4.0

ポイント⑤:継承

また以下のようにすべてに共通するクラス(Function:基底クラスという位置づけ)を継承したクラスを実装することで、関数同士を組み合わせることができ、より汎用性が高まります。

以下の関数を実装します。
ゼロから作るディープラーニング③フレームワーク偏より一部引用します。

$$
y=(e^{x^2})^2
$$

class Variable:
    def __init__(self, data):
        self.data = data
        
class Function: # 基底クラス
    def __call__(self, input):
        x = input.data # データ取得
        y = self.forward(x) # 計算処理
        output = Variable(y) # Variable として取得
        return output
    
    def forward(self, x):
        raise NotImplementedError() # 組み込み例外(割愛)
        
class Square(Function): # 基底クラスを継承して二乗
    def forward(self, x):
        return x ** 2
    
class Exp(Function): # 基底クラスを継承して対数変換
    def forward(self, x):
        return np.exp(x)
確認
A = Square()
B = Exp()
C = Square()

x = Variable(np.array(0.5))
a = A(x)
b = B(a)
y = C(b)

print(y.data)
>>> 1.648721270700128

要はクラスを用いるとコード量を減らしつつ、汎用的なプログラムを組めるわけです。

PyTorch のネットワーク

ニューラルネットワークは複数の関数がつながったひとつの合成関数とみなすことができます。
アルゴリズム的には順伝播逆伝播を繰り返して、目的関数の最適化のためにモデルのパラメータを調整していくものです。
順伝播・逆伝播ではそれぞれ線形変換・非線形変換・微分が行われており、これらをクラスを用いて構築することで汎用性の高いプログラムにできます。逆伝播(バックプロパゲーション)は合成関数の微分を用いて、誤差を出力→入力方向(逆方向)に伝播していきます。
ここで大事なことが Dfine-by-Run(動的計算グラフ)と呼ばれる、計算のつながりを、計算を行うタイミングで作る仕組みです。これを「動的計算グラフ」とも呼びます。

ノード数や層の数、最適化手法、学習係数など人間側で決めなければならないハイパーパラメータに合わせて柔軟に処理できるようにしてくれたものがフレームワークというわけですね。

フレームワーク(PyTorch)にはこれらの処理をまとめたモジュールが用意されています。
このモジュールを使用しない場合にはひとつひとつの処理をフルスクラッチで記述する必要がでてきます。
モジュールを継承したクラス(ネットワーク計算の設計図)を組むことでコード量を減らすことができます。

以下の PyTorch 公式チュートリアルを参考にネットワークを組みます。
https://pytorch.org/docs/stable/generated/torch.nn.Module.html

import torch.nn as nn
import torch.nn.functional as F
基底クラスとして使う
nn.Module
>>> torch.nn.modules.module.Module

まずは、モジュール torch.nn.modules.module.Module 以下(nn.Module)を継承します。

モジュールの中身の確認ができます
help(nn.Module)
ネットワーク(例)
class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(30, 10)
        self.fc2 = nn.Linear(10, 2)

    def forward(self, x):
        h = self.fc1(x)
        h = F.relu(h)
        h = self.fc2(h)
        return h
ネットワークアーキテクチャ確認_1
net = Net() # インスタンス化
print(net)
>>>
Net(
  (fc1): Linear(in_features=30, out_features=10, bias=True)
  (fc2): Linear(in_features=10, out_features=2, bias=True)
)

このようなネットワークアーキテクチャを組むことができました。
ポイント①〜⑤がすべて使われています。
__init__nn.Module をインスタンス化しています。

__call__ メソッドは?と思った方

基底クラスの forward メソッド(nn.Module.forward)を使用しており、__call__ メソッドとして forward が機能するようにモジュール側で実装されています。
チュートリアルにも『Defines the computation performed at every call.』と説明されています。

super().__init__() は Net クラスの __init__ メソッドを実行しています。
基底クラスを継承してメソッドに機能を追加・変更をすることができます。
これをオーバーライドといいます。

super() 内の Net, self は省略することも可能です。

省略Ver
class Net(nn.Module):

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(30, 10)
        self.fc2 = nn.Linear(10, 2)

    def forward(self, x):
        h = self.fc1(x)
        h = F.relu(h)
        h = self.fc2(h)
        return h
ネットワークアーキテクチャ確認_2
net = Net() # インスタンス化
print(net)
>>> 
Net(
  (fc1): Linear(in_features=30, out_features=10, bias=True)
  (fc2): Linear(in_features=10, out_features=2, bias=True)
)

以下のように引数を利用することで、

引数設定Ver
class Net(nn.Module):

    def __init__(self, n_mid=None):
        super().__init__()
        self.fc1 = nn.Linear(30, n_mid)
        self.fc2 = nn.Linear(n_mid, 2)


    def forward(self, x):
        h = self.fc1(x)
        h = F.relu(h)
        h = self.fc2(h)
        return h
ネットワークアーキテクチャ確認_3
net = Net(n_mid=5) # インスタンス化(中間層のノード数を 5 に)
print(net)
>>>
Net(
  (fc1): Linear(in_features=30, out_features=5, bias=True)
  (fc2): Linear(in_features=5, out_features=2, bias=True)
)

インスタンス化と同時に、ネットワーク構造を変えたりすることができます。
(中間層のノード数を 5 に変更している)

ここからモデルの訓練の際には以下のような処理を書いていく必要があります。

訓練
for epoch in range(max_epoch):

    for batch in train_loader:

        x, t = batch
        x = x.to(device)
        t = t.to(device)
        y = net(x)# __call__ メソッドとして使用可能
        loss = F.cross_entropy(y, t)

        y_label = torch.argmax(y, dim=1)
        accuracy = (y_label == t).sum().float() / len(t)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

そして訓練後にはモデルの性能評価処理を以下のように記述していきます。(例:分類問題)

正解率を計算する関数
def calc_accuracy(data_loader):

    with torch.no_grad():
        total = 0
        correct = 0.0

        for batch in data_loader:
            x, t = batch
            x = x.to(device)
            t = t.to(device)
            y = net(x) # __call__ メソッドとして使用可能

            y_label = torch.argmax(y, dim=1)
            total += len(t)
            correct += (y_label == t).sum()

        accuracy = correct / total

    return accuracy

PyTorch Lightning

先程紹介したように、nn.Module を継承したネットワークを 『生 PyTorch』と呼んだりもします。(身内ネタかもです。)

生 PyTorch の記述をさらに簡略化してくれるものをラッパーと呼びます。
以下のように、pl.LightningModule を継承します。1.0 系にメジャーアップデートされ、内部のメソッドが充実し、さらに使いやすくなりました。

PyTorch_Lightning
import pytorch_lightning as pl
from pytorch_lightning.metrics.functional import accuracy

class Net(pl.LightningModule):

    def __init__(self):
        super().__init__()

        self.fc1 = nn.Linear(30, 10)
        self.fc2 = nn.Linear(10, 2)

    # 順伝播
    def forward(self, x):
        h = self.fc1(x)
        h = F.relu(h)
        h = self.fc2(h)
        return h

    # 訓練データに対する処理
    def training_step(self, batch, batch_idx):
        x, t = batch
        y = self(x)
        loss = F.cross_entropy(y, t)
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log('train_acc', accuracy(y, t), on_step=True, on_epoch=True, prog_bar=True)
        return loss


    # 検証データに対する処理
    def validation_step(self, batch, batch_idx):
        x, t = batch
        y = self(x)
        loss = F.cross_entropy(y, t)
        self.log('val_loss', loss, on_step=False, on_epoch=True)
        self.log('val_acc', accuracy(y, t), on_step=False, on_epoch=True)
        return loss


    # テストデータに対する処理
    def test_step(self, batch, batch_idx):
        x, t = batch
        y = self(x)
        loss = F.cross_entropy(y, t)
        self.log('test_loss', loss, on_step=False, on_epoch=True)
        self.log('test_acc', accuracy(y, t), on_step=False, on_epoch=True)
        return loss

    # 最適化手法
    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=0.01)
        return optimizer

訓練・ログ確認・結果確認・検証も以下のように簡略化して実行できます。

訓練・ログ確認・結果確認・検証
pl.seed_everything(0)

# 訓練の実行
net = Net()
trainer = pl.Trainer(max_epochs=30, gpus=1, deterministic=True)
trainer.fit(net, train_loader, val_loader)

# 訓練ログ確認(TensorBoard)
%tensorboard --logdir lightning_logs/

# 訓練結果の確認
trainer.callback_metrics

# テストデータ検証
trainer.test(test_dataloaders=test_loader)

私は、ディープラーニングの処理の細かさを残しながら、汎用性の高い PyTorch Lightning が圧倒的に使いやすいです。
GPU への転送処理が不要なことや、logger の使いやすさ、Optuna との相性など
PyTorch の基礎を抑えておけばすごく便利ですので、是非使ってみてください。

参考

torch.nn.Module
PyTorchLightning/pytorch-lightning
ゼロつく③

76
78
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
76
78

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?