LoginSignup
0
2

More than 3 years have passed since last update.

Pytorchで学習処理をパッケージ化

Posted at

はじめに

画像分類問題の練習のために友人と5人でSignateの練習問題に投稿したのですが、その時に自作した機能を紹介します。

機能自作の経緯

1. 実装開始の前に

集まったメンバはAIについて勉強はしていましたが普段からプログラミングしているわけではなかっので、画像分類問題は初めてのメンバもいたり、スキルレベルもまちまちでした。
なので、お互い教え合いながらやっていく方針で取り組みました。
また、色々事情もあり取り組む期間は2週間という特殊な制約もありました。

5人で同じ問題に取り組むにあたって同じ設定で実装をしないと教え合うことが難しいので、まず以下のことを決めました。

  • 開発言語はPython
  • DL用FWはPytorch
  • Google Colaboratoryで実行

FWについては私がPytorchを使っていたこと、もう2人が練習中であっためTensorflowではなくPytorchを使うことにしました。

2. 各メンバで学習処理を実装開始

Resnetを転移学習してベースモデルを作ることをまずは目指して実装を始めました。
Pytorchを使ったことある方はご存知と思いますが、以下のことを実装する必要があります。

  • Datasetの作成
  • DataLoaderの作成
  • Model定義
  • 誤差関数、最適化関数の定義
  • 学習ループの作成
  • バリデーション

意気揚々と取り組み始めましたが、最初のデータセット用のリストや画像読込、ちょっとした前処理などで結構苦戦しました。
みんな早く学習処理回したかったのですが、相談しながらやってるとModel定義まで中々時間がかかってしまいました。

3. 学習処理のパッケージ化

このままだと2週間の期間内に終わらなさそうだったので学習処理部分をクラス化してチームメンバに配布することにしました。
ということで、実装する機能についての要件を簡単に検討して作成しました。

機能要件

メンバの状況を見つつ、取り合えず以下のような要件を決めました。

  • ソースコードをコピペじゃなくって、パッケージ化してimportさせる
  • 学習処理は隠ぺいしてパラメータ与えてメソッド一発で学習してくれる
  • モデルは色々選びたいので別クラスにして入力にする
  • モデル保存、バリデーションといっためんどくさい処理もやってくれる

とにかく色々めんどくさいをまとめて、処理してくれてくれるものとしました。
また、pytorch-ligntningとかcatalystなどの充実したパッケージもあるんですが、機能が多いのと使い方を覚えないといけないので必要な機能のみを自作するということとしました。

その結果、myTorchLib.pyというファイルにTrainerクラスを実装して共有しました。

実装したコード紹介

以下がTrainerクラスの実装です。
急いで作ったので微妙な部分あるかもですが、甘い目で見て頂けたらと思います。


class Trainer():
  def __init__(self, model, criterion, optimizer, params=None ):

    if params is None:
      params = {
          "EPOCH" : 5
      }

    self.__model = model
    self.__criterion = criterion
    self.__optimizer = optimizer
    self.__params = params

    # model 保存用
    self.__current_loss = None

    # デフォルトはCPU設定
    self.__device = "cpu"

  # gpuへ転送されたtorch.tensorのデータをnumpy(cpu)形式に変換
  def __to_numpy(self, x):
    return x.to("cpu").detach().numpy().copy()

  # バッチ単位のlossのaverage計算のため
  def __avarage_loss(self, loss):
    return np.sum(loss) / len(np.ravel(loss))

  # モデルの保存
  def __save_model(self, epoch, loss=None, save_model_name="model", save_dir="./", save_method = "best_model"):
    # モデル保存用のパス作成
    if save_method == "best_model":
      file_name = save_dir + save_model_name + "_best.pth"
      # lossが一番低いモデルを保存
      if self.__current_loss is None:
        torch.save(self.__model.state_dict(), file_name)
        self.__current_loss = loss
      elif loss < self.__current_loss:
        torch.save(self.__model.state_dict(), file_name)
        self.__current_loss = loss
    elif save_method == "epoch":
      file_name = save_dir + save_model_name + "_E_" + str(epoch) + ".pth"
      # Epoch毎に保存
      torch.save(self.__model.state_dict(), file_name)

  # 1 ループの処理
  def __training_step(self, inputs, labels):
    outputs = self.__model(inputs)
    loss = self.__criterion(outputs, labels)
    loss.backward()
    return outputs, loss.item()

  # バリデーション
  def __validation_step(self, inputs, labels):
    with torch.no_grad():
      outputs = self.__model(inputs)
      loss = self.__criterion(outputs, labels)
    return outputs, loss.item()

  # 1 epoch の処理
  def __train_valid_loop(self, loader: data.DataLoader, phase: str ="train"):
    loss_data = []

    for inputs, labels in tqdm(loader):

      inputs = inputs.to(self.__device)
      labels = labels.to(self.__device)

      if phase == "train":
        self.__optimizer.zero_grad()
        preds, loss = self.__training_step(inputs, labels)
        self.__optimizer.step()
      else:
        preds, loss = self.__validation_step(inputs, labels)

      loss_data.append(loss)

    out_loss = self.__avarage_loss(loss_data)
    return out_loss

  # デバイス設定:(device = "cuda" if torch.cuda.is_available() else "cpu")の結果を渡す
  def to(self, device):
    self.__device = device

  # 学習処理実行用メソッド(外部公開用)
  def train(self, train_loader: data.DataLoader, valid_loader: data.DataLoader=None, save_model_name="model", save_dir="./", save_method = "best_model"):
    train_loss_log = []
    valid_loss_log = []
    for epoch in range(self.__params["EPOCH"]):
      # training
      self.__model.train()
      train_loss = self.__train_valid_loop(train_loader)
      train_loss_log.append(train_loss)
      print("EPOCH : ", epoch, " training loss : ", train_loss)

      # validation
      if valid_loader != None:
        self.__model.eval()
        valid_loss = self.__train_valid_loop(valid_loader, phase="valid")
        valid_loss_log.append(valid_loss)
        print("EPOCH : ", epoch, " validation loss : ", valid_loss)

      self.__save_model(epoch, train_loss, save_model_name, save_dir, save_method)
    return train_loss_log, valid_loss_log

コード解説

まず、別ファイル(myTorchLib.py)として共有したので以下のように配置してnotebookからTrainerクラスをインポートするようにしました。

┣ Lib
┃ ┗ myTorchLib.py
┗ basemodel.ipynb

インポートするときは以下のようにします。

import Lib.myTorchLib as mylib

Trainerクラスが公開する機能は以下の2つです。

  • train():学習処理
  • to():gpu, cpu設定

学習処理の実行部分のみを公開して、学習ループの実処理部分は非公開にしました。
これによってDataset, DataLoader, modelを作ってtrainメソッドを呼び出せば学習、バリデーションが回ってくれるようになります。
学習処理を実行するときは、Trainerクラスのインスタンスをを作成します。
この時、モデル(model)、誤差関数(criterion)、最適化関数(optimizer)、パラメータ(params)を指定します。
今のところparamsにはEPOCH(エポック数)しか対応していません。
そして、train_dataloader, valid_dataloader等を指定してtrainメソッドを実行すると学習処理が回ります。

params = {
    "EPOCH": 10 # epoch数を指定
}
trainer = mylib.Trainer(model=model, criterion=criterion, optimizer=optimizer, params=params)
trainer.to(device)
trainer.train(
    train_dataloader, # 学習データ用のDataLoader(必須)
    valid_dataloader, # バリデーション用のDataLoader(未指定であればバリデーションなし)
    save_model_name="resnet", # 保存モデル名
    save_dir="./model/", # 保存場所
    save_method="best_model" # 保存方法('best_model' : lossが一番低いモデルを保存, "epoch":毎エポック保存、それ以外の文字列はモデル保存なし)
    )

共有後は自分でも使いつつ、モデル保存処理や各エポック毎のlossの出力などの機能を追加しました。
trainメソッドについてですが、
必須なのはtrain_dataloaderのみです。
valid_dataloader未設定の場合はバリデーションなしで動きます。
save_model_name、save_dir、save_methodはモデル保存のための設定なので、未設定の場合はデフォルト設定で処理されます。

これによって学習処理部分はメソッドを呼び出すだけとなり、一緒に取り組んだメンバはmodelや前処理の実装に集中できるようになりました。

使ってみた感想

今回は画像分類問題に対して転移学習で学習する利用ケースだけだったので限定した機能だけの実装で作ることができましたし、使うことができました。
これによって、各メンバが色んなモデルを試すことが出来ましたし短期的ですが役立つことができたと思います。
機能が少ないことが逆に良かったのかもしれません。

正直なところ、catalystなどを超絶に劣化させたようなものなのでこれから使うのであればcatalystやpytorch-lightningを使った方が良いと思います。
ただ、自分の作ったものを人に使ってもらえて感想をもらって改造していけるのはとてもありがたく嬉しいことだなと思いました。

最後に、私の実装したことが何かお役に立てば幸いです。

参考情報

0
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
2