1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

深層学習を用いたトランスクリプトームの分類(最適化編)

Posted at

BioinfoにおけるPandas,Matplotlibの基礎

この記事は、Pythonで実践 生命科学データの機械学習 (https://www.yodosha.co.jp/yodobook/book/9784758122634/) の内容を含んでいます。

私は現在バイオインフォマティクス研究室に所属する学生です。
勉強した事をアウトプットする場として用いていますため、何卒ご理解のほどよろしくお願いいたします。
(>人<;)

PyTorchでニューラルネットワークを構築する

⇧前回の記事では、深層学習の実装を行ないました。
今回は、過学習などの課題点を解決する為に勉強したことを記していきます。

過学習への対策

EarlyStopping(アーリーストッピング)とは?

過学習対策の一つとして、コードは、EarlyStopping(アーリーストッピング) という手法を実装しています。
機械学習では、モデルがデータを学習しすぎると、未知のデータに対してうまく動作しなくなることがあります。
EarlyStopping は、学習の途中で「もう十分だ!」と判断して、早めに訓練を止める仕組み です。


EarlyStopping クラスの役割

このクラスは、検証データ(val_loss)の誤差を監視しながら、以下のように動作します:

  1. val_loss(検証誤差)が改善し続ける間は学習を続ける。
  2. val_loss が一定回数(patience)改善しなければ、訓練を停止する。
  3. 最も良かったモデル(val_lossが最小だった時点)を保存し、最終的にそのモデルを採用。

🏗 クラスの構造

class EarlyStopping:

このクラスを使えば、学習の途中で「もう十分!」と判断して止めることができます。


1️⃣ __init__ メソッド(クラスの初期設定)

クラスが作られるときに、必要なパラメータをセットします。

def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print):
引数 説明
patience val_loss が何回連続で悪化したらストップするか(デフォルトは 7 回)
verbose True にすると、モデルの保存時にメッセージを表示(デフォルトは False)
delta 改善とみなす最小変化量(小さな改善を無視するため)
path 保存するモデルのファイル名(デフォルトは 'checkpoint.pt'
trace_func 処理状況を表示する関数(デフォルトは print

クラス内の変数

self.patience = patience
self.verbose = verbose
self.counter = 0  # 何回連続でval_lossが悪化したかをカウント
self.best_score = None  # 最も良いスコア(最小val_lossのスコア)
self.early_stop = False  # True になったら学習停止
self.val_loss_min = np.Inf  # 最小の val_loss(初期値は無限大)
self.delta = delta
self.path = path
self.trace_func = trace_func

💡 ポイント

  • counter は、何回連続で val_loss が悪化したかを記録
  • best_score は、最小の val_loss に対応するスコア
  • early_stop = False にして、学習を続けるかどうかのフラグを作る。

2️ __call__ メソッド(学習ごとに呼び出される)

def __call__(self, val_loss, model):
  • val_loss現在の検証誤差
  • model現在のモデル

このメソッドが 学習のたびに呼び出され、EarlyStopping の判定を行う


2-1. score を計算(val_loss の最小化 → score の最大化)

score = -val_loss
  • val_loss小さいほど良い(誤差が少ないほど精度が高い)。
  • しかし、best_score は「大きいほど良い値」として扱うため、score = -val_loss に変換

2-2. scorebest_score + delta を下回る(改善なし)

elif score < self.best_score + self.delta:
    self.counter += 1  # 改善がなかった回数をカウント
    self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
    
    if self.counter >= self.patience:
        self.early_stop = True  # patience を超えたら学習停止
  • 改善がない(スコアがほとんど変わらない)場合
    • counter を増やす(「何回連続で改善なし」かを記録)。
    • counterpatience(デフォルトは7回)を超えたら、early_stop = True にして 学習を終了

3️ save_checkpoint メソッド(モデルを保存する)

def save_checkpoint(self, val_loss, model):
  • 改善があった場合にモデルを保存する関数
if self.verbose:
    self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
  • verboseTrue なら、どのくらい改善したかを表示
torch.save(model.state_dict(), self.path)
  • 現在のモデルのパラメータを保存
  • self.val_loss_min = val_loss最小 val_loss を更新

実際の使い方

この EarlyStopping クラスを使って、訓練ループに組み込む方法を見てみましょう。

early_stopping = EarlyStopping(patience=5, verbose=True)

for epoch in range(100):
    train(...)  # 訓練データで学習
    val_loss = validate(...)  # 検証データで損失を計算

    early_stopping(val_loss, model)  # EarlyStopping を実行

    if early_stopping.early_stop:  # もし停止判定なら
        print("Early stopping!")
        break  # ループを抜ける

まとめると、

  1. val_loss(検証誤差)が改善しなくなったら学習をストップする仕組み
  2. 一定回数(patience)連続で改善しなければ訓練を終了
  3. 最も良かったモデル(val_loss が最小だったとき)を保存
  4. 過学習を防ぎ、訓練時間を短縮できる!

ハイパラメーターの最適化

最適なオプティマイザーを選択(get_optimizer())の解説

この関数 get_optimizer() は、最適なオプティマイザー(最適化アルゴリズム)を選ぶ ためのものです。
この選択は、モデルの学習速度や最終的な精度に大きな影響を与える ため、とても重要です。


まず「オプティマイザー」って何?

オプティマイザーとは、モデルのパラメータ(重み)をどのように更新するかを決めるもの です。
機械学習では、ニューラルネットワークの学習に 「勾配降下法」 という手法を使います。

簡単に言うと…

  1. モデルが予測を行う
  2. 誤差(損失関数の値)を計算
  3. 勾配を計算(どの方向にパラメータを変えれば良いか)
  4. オプティマイザーがパラメータを更新
  5. これを繰り返し、誤差を減らしながらモデルを最適化

オプティマイザーによって、この「パラメータの更新方法」が変わります。
どのオプティマイザーを選ぶかによって、学習速度や精度が変わる ので、最適なものを選ぶことが重要です。


🛠 get_optimizer() のコード

def get_optimizer(trial, model):
    optimizer_names = ['Adam', 'RMSprop']  # 候補のオプティマイザー
    optimizer_name = trial.suggest_categorical('optimizer', optimizer_names)  
    weight_decay = trial.suggest_float('weight_decay', 1e-8, 1e-2, log=True)

    if optimizer_name == 'Adam':
        Adam_lr = trial.suggest_float('Adam_lr', 1e-5, 1e-1, log=True)
        optimizer = optim.Adam(model.parameters(), lr=Adam_lr, weight_decay=weight_decay)
    else:
        RMSprop_lr = trial.suggest_float('RMSprop_lr', 1e-5, 1e-1, log=True)
        optimizer = optim.RMSprop(model.parameters(), lr=RMSprop_lr, weight_decay=weight_decay)

    return optimizer

** どのオプティマイザーを使うか選択**

optimizer_names = ['Adam', 'RMSprop']  # 候補のオプティマイザー
optimizer_name = trial.suggest_categorical('optimizer', optimizer_names)

ここでは、最適化の候補として AdamRMSprop の2つを用意しています。
trial.suggest_categorical('optimizer', optimizer_names) を使うことで、optunaどちらのオプティマイザーが良いか試行錯誤させる ことができます。

オプティマイザー 特徴
Adam(Adaptive Moment Estimation) 学習率を自動調整するので万能・収束が早い
RMSprop(Root Mean Square Propagation) 学習率の変化を適応的に調整・勾配消失を防ぐ

どちらが最適かは、データやモデルの構造によって異なるため、自動的に最適なものを選ぶ ために optuna を使います。


weight_decay(重み減衰)を決める

weight_decay = trial.suggest_float('weight_decay', 1e-8, 1e-2, log=True)

weight_decay とは?

  • 重み(パラメータ)を少しずつ小さくする ことで、過学習を防ぐためのテクニック。
  • これを L2正則化 とも呼びます。

trial.suggest_float() の意味

  • trial.suggest_float('weight_decay', 1e-8, 1e-2, log=True)
    1e-8(とても小さい値)〜 1e-2(少し大きめ) の範囲で、optuna が最適な weight_decay を探してくれる。

Adam(アダム)オプティマイザーを選んだ場合

if optimizer_name == 'Adam':
    Adam_lr = trial.suggest_float('Adam_lr', 1e-5, 1e-1, log=True)
    optimizer = optim.Adam(model.parameters(), lr=Adam_lr, weight_decay=weight_decay)

Adam の特徴

  • 学習率を自動調整 する → 初心者でも使いやすく、一般的なタスクで良い結果が出やすい!
  • 慣性のようなもの(モーメント)を考慮 して、パラメータの更新方向を滑らかにする。

学習率(lr)の決め方

  • trial.suggest_float('Adam_lr', 1e-5, 1e-1, log=True)
    • 1e-5(0.00001)〜 1e-1(0.1) の間で optuna が最適な学習率を探索。

get_optimizer() のまとめ

  • Adam or RMSprop を選択できる
  • optuna を使って最適なオプティマイザーを自動的に選ぶ
  • 学習率 (lr) も optuna で最適化
  • weight_decay(正則化)も調整して、過学習を防ぐ
1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?