LoginSignup
30
30

PyTorchコーディング時の実装負担を低減させるテンプレートコード

Last updated at Posted at 2023-07-08

はじめに

 機械学習コードに用いられるPyTorchコーディング時の実装負担低減を目的として、テンプレートコードを作成してみました。本記事では具体的な使用方法を記載します。(テンプレートコード部分の実装は文献1を参考にさせていただきました)

ソースコード

 下記リンクからアクセス可能です。
 テンプレートコードはframeworkディレクトリに、ユーザー実装部分はusrディレクトリに、それぞれ格納されています。

使用方法

 本テンプレートコードを使用することにより、自作のデータセットおよびモデル(+損失関数等)を用意するだけで、簡単にモデルの学習やテストができるようになります。ここでは、具体的な使用方法について説明します。

1. データセットの構築

 初めに、データセットを読み込むためのクラスを作成します。データセットはユーザー定義であることから、データ形状等に制約はありません。一方で、後述するデータローダー作成時に本クラスを参照する都合上、(1)テンプレートコード内で定義されているDataクラスを継承すること、(2)データセットの要素数、入力データ、出力ラベルを格納するための変数(_len, _data, _target)を定義すること、の2点を遵守する必要があります。

 一例として、CIFAR10データセット2を読み込むためのデータセットクラス実装例を示します。

CIFAR10.py
import sys
import glob
import pickle
import numpy as np
import pandas as pd

from framework import Data

#
# class to load CIFAR10
class CIFAR10(Data.Data):

  # initializer
  def __init__(self, filepath):
    # initialize member function
    self._len = 0
    self._data = np.empty((0, 3 * 32 * 32))
    self._target = []
    # load list of file path
    filepath_list = glob.glob(filepath + "*")
    if (not filepath_list):
      print("error : failed to load data file.")
      sys.exit(0)
    # load data
    for _, filepath in enumerate(filepath_list):
      dict = pickle.load(open(filepath, "rb"), encoding="bytes")
      self._len += len(dict[b"labels"])
      self._data = np.append(self._data, dict[b"data"], axis=0)
      self._target.extend(dict[b"labels"])
    self._data = self._data.astype("float32").reshape(-1, 3, 32, 32)/255.0

 ここで、ユーザー定義クラスではDataクラスを継承していますが、これによりデータローダー構築時に使用するゲッター等の関数群をユーザー側で実装する必要がなくなっています。
 使用する際には、クラスのインスタンスを作成するだけでOKです。

 cifar10_train = CIFAR10.CIFAR10("./data/cifar10/data_batch")
 cifar10_test = CIFAR10.CIFAR10("./data/cifar10/test_batch")

2. DataLoaderの構築

 PyTorchには、データのデータのミニバッジを作成してくれるDataLoaderクラスが存在しますが、本テンプレートコードにはこれに対応するDataManagerクラスが定義されています。
 使用例は下記の通りです。インスタンス作成時には、学習・評価・テストデータのほかに、バッジサイズやデータローディング時の並列数等のオプションを指定することができます。

data_manager = DataManager.DataManager(train_data=cifar10_train, test_data=cifar10_test, 
                                       batch_size=batch_size, num_workers=num_workers)

2. 学習用クラスの構築

 モデル学習・推論を実施する際には、ネットワーク・損失関数・評価関数・最適化アルゴリズム等の要素を使用します。本テンプレートコードでは、これらの要素を一括管理し、学習や推論を簡単に実施するためのNetWorkクラスが定義されています。
 使用例は下記の通りです。インスタンス作成時には、上記要素に加え、使用するデバイスの種類(CPU or GPU)や使用するGPU数等のオプションを指定することができます。

  net = Network.Network(model=model, criterion=criterion,
                        evaluator=evaluator, optimizer=optimizer, 
                        data_manager=data_manager,
                        device=device, device_ids=device_ids, non_blocking=True)

3. 学習・評価・推論

 学習・評価・推論を実施するための関数は、先述したNetWorkクラスにメンバ関数として定義されています。これらメソッドはインスタンス生成時の引数を参照して実施されるため、メソッド呼び出し時には引数渡し等の操作は必要はありません。
 学習メソッドの呼び出し例は下記の通りです。

for _ in range(num_epochs):
  net.train()

 評価・推論メソッドの呼び出し例は下記の通りです。

net.eval()
net.test()

4.実装例

 最後に、CIFAR10データセット用の線形結合ネットワークを学習・推論させるコードの実装例を下記に示します。通常だとデータセット構築部分や学習部分等の行数が長くなり可読性が低下しがちですが、本テンプレートコードを用いることにより全体の見通しが良くなりました。

CIFAR10.ipynb
import torch
import torch.nn as nn

from framework import DataManager
from framework import Network
from framework import Timer

from usr import CIFAR10
from usr import Linear
from usr import CorrectRate

if __name__ == "__main__":
  # define variables
  device = "cuda"
  device_ids = [0]
  batch_size = 100
  num_workers = 1
  num_epochs = 20
  learning_rate = 0.001

  # create data set
  cifar10_train = CIFAR10.CIFAR10("./data/cifar10/data_batch")
  cifar10_test = CIFAR10.CIFAR10("./data/cifar10/test_batch")
  
  # create data manager
  data_manager = DataManager.DataManager(train_data=cifar10_train, test_data=cifar10_test, batch_size=batch_size, num_workers=num_workers)

  # define user-defined network
  model_kwargs = {"in_units" : 3*32*32, "mid_units" : [100, 100, 100, 100], "out_units" : 10}
  model = Linear.Linear(**model_kwargs)

  # define user-defined criterion
  criterion = nn.CrossEntropyLoss()

  # define user-defined optimizer
  optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

  # define user-defined evaluator
  evaluator = CorrectRate.CorrectRate()

  # define model
  net = Network.Network(model=model, criterion=criterion, evaluator=evaluator, optimizer=optimizer, 
                        data_manager=data_manager, device=device, device_ids=device_ids, non_blocking=True)
  
  # training
  with Timer.Timer() as timer:
    for _ in range(num_epochs):
      net.train()
  
  # test
  net.test()

 下記のような出力が得られます。学習・推論ともに正常にできています。

train mode : epoch =     1, loss = 1.904e-02
train mode : epoch =     2, loss = 1.710e-02
train mode : epoch =     3, loss = 1.636e-02
train mode : epoch =     4, loss = 1.581e-02
train mode : epoch =     5, loss = 1.532e-02
train mode : epoch =     6, loss = 1.498e-02
train mode : epoch =     7, loss = 1.469e-02
train mode : epoch =     8, loss = 1.440e-02
train mode : epoch =     9, loss = 1.417e-02
train mode : epoch =    10, loss = 1.397e-02
time : 17.194sec
test mode : loss = 1.454e-02, eval = 4.835e-01

参考文献

  1. PyTorchのテンプレコードを用意してどんなデータセットにも楽々ディープラーニング

  2. The CIFAR-10 dataset

30
30
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
30
30