はじめに
機械学習コードに用いられるPyTorchコーディング時の実装負担低減を目的として、テンプレートコードを作成してみました。本記事では具体的な使用方法を記載します。(テンプレートコード部分の実装は文献1を参考にさせていただきました)
ソースコード
下記リンクからアクセス可能です。
テンプレートコードはframeworkディレクトリに、ユーザー実装部分はusrディレクトリに、それぞれ格納されています。
使用方法
本テンプレートコードを使用することにより、自作のデータセットおよびモデル(+損失関数等)を用意するだけで、簡単にモデルの学習やテストができるようになります。ここでは、具体的な使用方法について説明します。
1. データセットの構築
初めに、データセットを読み込むためのクラスを作成します。データセットはユーザー定義であることから、データ形状等に制約はありません。一方で、後述するデータローダー作成時に本クラスを参照する都合上、(1)テンプレートコード内で定義されているDataクラスを継承すること、(2)データセットの要素数、入力データ、出力ラベルを格納するための変数(_len, _data, _target)を定義すること、の2点を遵守する必要があります。
一例として、CIFAR10データセット2を読み込むためのデータセットクラス実装例を示します。
import sys
import glob
import pickle
import numpy as np
import pandas as pd
from framework import Data
#
# class to load CIFAR10
class CIFAR10(Data.Data):
# initializer
def __init__(self, filepath):
# initialize member function
self._len = 0
self._data = np.empty((0, 3 * 32 * 32))
self._target = []
# load list of file path
filepath_list = glob.glob(filepath + "*")
if (not filepath_list):
print("error : failed to load data file.")
sys.exit(0)
# load data
for _, filepath in enumerate(filepath_list):
dict = pickle.load(open(filepath, "rb"), encoding="bytes")
self._len += len(dict[b"labels"])
self._data = np.append(self._data, dict[b"data"], axis=0)
self._target.extend(dict[b"labels"])
self._data = self._data.astype("float32").reshape(-1, 3, 32, 32)/255.0
ここで、ユーザー定義クラスではDataクラスを継承していますが、これによりデータローダー構築時に使用するゲッター等の関数群をユーザー側で実装する必要がなくなっています。
使用する際には、クラスのインスタンスを作成するだけでOKです。
cifar10_train = CIFAR10.CIFAR10("./data/cifar10/data_batch")
cifar10_test = CIFAR10.CIFAR10("./data/cifar10/test_batch")
2. DataLoaderの構築
PyTorchには、データのデータのミニバッジを作成してくれるDataLoaderクラスが存在しますが、本テンプレートコードにはこれに対応するDataManagerクラスが定義されています。
使用例は下記の通りです。インスタンス作成時には、学習・評価・テストデータのほかに、バッジサイズやデータローディング時の並列数等のオプションを指定することができます。
data_manager = DataManager.DataManager(train_data=cifar10_train, test_data=cifar10_test,
batch_size=batch_size, num_workers=num_workers)
2. 学習用クラスの構築
モデル学習・推論を実施する際には、ネットワーク・損失関数・評価関数・最適化アルゴリズム等の要素を使用します。本テンプレートコードでは、これらの要素を一括管理し、学習や推論を簡単に実施するためのNetWorkクラスが定義されています。
使用例は下記の通りです。インスタンス作成時には、上記要素に加え、使用するデバイスの種類(CPU or GPU)や使用するGPU数等のオプションを指定することができます。
net = Network.Network(model=model, criterion=criterion,
evaluator=evaluator, optimizer=optimizer,
data_manager=data_manager,
device=device, device_ids=device_ids, non_blocking=True)
3. 学習・評価・推論
学習・評価・推論を実施するための関数は、先述したNetWorkクラスにメンバ関数として定義されています。これらメソッドはインスタンス生成時の引数を参照して実施されるため、メソッド呼び出し時には引数渡し等の操作は必要はありません。
学習メソッドの呼び出し例は下記の通りです。
for _ in range(num_epochs):
net.train()
評価・推論メソッドの呼び出し例は下記の通りです。
net.eval()
net.test()
4.実装例
最後に、CIFAR10データセット用の線形結合ネットワークを学習・推論させるコードの実装例を下記に示します。通常だとデータセット構築部分や学習部分等の行数が長くなり可読性が低下しがちですが、本テンプレートコードを用いることにより全体の見通しが良くなりました。
import torch
import torch.nn as nn
from framework import DataManager
from framework import Network
from framework import Timer
from usr import CIFAR10
from usr import Linear
from usr import CorrectRate
if __name__ == "__main__":
# define variables
device = "cuda"
device_ids = [0]
batch_size = 100
num_workers = 1
num_epochs = 20
learning_rate = 0.001
# create data set
cifar10_train = CIFAR10.CIFAR10("./data/cifar10/data_batch")
cifar10_test = CIFAR10.CIFAR10("./data/cifar10/test_batch")
# create data manager
data_manager = DataManager.DataManager(train_data=cifar10_train, test_data=cifar10_test, batch_size=batch_size, num_workers=num_workers)
# define user-defined network
model_kwargs = {"in_units" : 3*32*32, "mid_units" : [100, 100, 100, 100], "out_units" : 10}
model = Linear.Linear(**model_kwargs)
# define user-defined criterion
criterion = nn.CrossEntropyLoss()
# define user-defined optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# define user-defined evaluator
evaluator = CorrectRate.CorrectRate()
# define model
net = Network.Network(model=model, criterion=criterion, evaluator=evaluator, optimizer=optimizer,
data_manager=data_manager, device=device, device_ids=device_ids, non_blocking=True)
# training
with Timer.Timer() as timer:
for _ in range(num_epochs):
net.train()
# test
net.test()
下記のような出力が得られます。学習・推論ともに正常にできています。
train mode : epoch = 1, loss = 1.904e-02
train mode : epoch = 2, loss = 1.710e-02
train mode : epoch = 3, loss = 1.636e-02
train mode : epoch = 4, loss = 1.581e-02
train mode : epoch = 5, loss = 1.532e-02
train mode : epoch = 6, loss = 1.498e-02
train mode : epoch = 7, loss = 1.469e-02
train mode : epoch = 8, loss = 1.440e-02
train mode : epoch = 9, loss = 1.417e-02
train mode : epoch = 10, loss = 1.397e-02
time : 17.194sec
test mode : loss = 1.454e-02, eval = 4.835e-01