LoginSignup
0
0

ProteinBERTのモデル性能評価

Last updated at Posted at 2024-02-03

※このブログはAidemy Premiumのカリキュラムの一環で、受講修了条件を満たすために公開しています

はじめに

機械学習・ディープラーニングの初学者です。
今までGithubにあるコードを動かした経験がないため、興味のある分野のコードを動かすことにしました。
私はバイオ系の研究者でタンパク質や抗体に関する研究を行っています。
最近、大規模言語モデルを使った研究が様々な分野で行われていて、バイオ分野でもタンパク質の構造や機能を予測するという研究が行われています。
大規模言語モデルとはchatgptの訓練にも使われている言語モデルですが、タンパク質のアミノ酸配列(20種類のアルファベットで表記可能)を言語とみなすことで、タンパク質の特徴量を掴む訓練をすることができます。
Githubでこのような言語モデルがないか検索した結果、ProteinBERTというモデルを見つけたので、実装しました。
ただし、時間やコストの制約上、タンパク質の大規模言語モデル自体を1から作ることはせずに、事前学習済みのモデルを使ったfinetuningを行いました。

※事前学習:大規模なデータセットやタスクで訓練されるプロセス
※finetuning:事前学習済みのモデルを特定のタスクやデータセットに適用し、追加の訓練や微調整を行うプロセス

目的

GithubにあるProteinBERTを実装してみる。
パラメータによりモデル性能がどう変化するのか調べる。

手順

①Google colaboratryでProteinBERTを実装する
②パラメータを変更して、モデル性能を評価する

結果

以下のGithubページからProteinBERTを実装する。
https://github.com/nadavbra/protein_bert

初めに以下のコードを実行したが、エラーが出てしまった。

git submodule init
git submodule update
python setup.py install

そこで、以下を初めに実行したところ、解決した。これはGitレポジトリを初期化するコマンドで、最初に設定が必要らしい。

!git init

次に以下で、Gitのレポジトリを複製。

!git clone https://github.com/nadavbra/protein_bert

また、protein_bertフォルダへ移動。

cd protein_bert

そして、サブモジュールの初期化とダウンロード、pythonパッケージのインストールを行う。

!git submodule init
!git submodule update
!python setup.py install

必要なpythonモジュールをインポートする。

import os

import pandas as pd
from IPython.display import display

from tensorflow import keras

from sklearn.model_selection import train_test_split

データが保管されているprotein_benchmarksフォルダへ移動する。

cd protein_benchmarks
Benchmarksの場所を定義する。何回か使用するため。
BENCHMARKS_DIR = ‘/content/protein_bert/protein_benchmarks'

Protein_bertフォルダへ戻る。

cd ..

protein_bertから必要なモジュールをインポートする。
シグナルペプチドかの2値判定データでfinetuneを行う。
*シグナルペプチドは、細胞内や細胞間で情報を伝達するための小さなペプチド(アミノ酸の連鎖)です。通常、シグナルペプチドは細胞内シグナル伝達経路を活性化または抑制するために、特定の受容体と相互作用します。これにより、細胞は外部からの刺激に反応し、適切な生理学的応答を引き起こすことができます。

from proteinbert import OutputType, OutputSpec, FinetuningModelGenerator, load_pretrained_model, finetune, evaluate_by_len
from proteinbert.conv_and_global_attention_model import get_model_with_hidden_layers_as_outputs

BENCHMARK_NAME = 'signalP_binary'

# A local (non-global) binary output
OUTPUT_TYPE = OutputType(False, 'binary')
UNIQUE_LABELS = [0, 1]
OUTPUT_SPEC = OutputSpec(OUTPUT_TYPE, UNIQUE_LABELS)

# Loading the dataset

train_set_file_path = os.path.join(BENCHMARKS_DIR, '%s.train.csv' % BENCHMARK_NAME)
train_set = pd.read_csv(train_set_file_path).dropna().drop_duplicates()
train_set, valid_set = train_test_split(train_set, stratify = train_set['label'], test_size = 0.1, random_state = 0)

test_set_file_path = os.path.join(BENCHMARKS_DIR, '%s.test.csv' % BENCHMARK_NAME)
test_set = pd.read_csv(test_set_file_path).dropna().drop_duplicates()

print(f'{len(train_set)} training set records, {len(valid_set)} validation set records, {len(test_set)} test set records.')

事前学習済みのモデルをロードしてfinetuneを行う。finetuneにとても時間がかかった。Google colaboのCPUでずっと動かしていたためと判明。GPUに切り替えて計算するも数時間はかかりそうだった。そこで、colaboの有料会員となりプログラム実行したところ、15分くらいで終了した。

# Loading the pre-trained model and fine-tuning it on the loaded dataset

pretrained_model_generator, input_encoder = load_pretrained_model()

# get_model_with_hidden_layers_as_outputs gives the model output access to the hidden layers (on top of the output)
model_generator = FinetuningModelGenerator(pretrained_model_generator, OUTPUT_SPEC, pretraining_model_manipulation_function = \
        get_model_with_hidden_layers_as_outputs, dropout_rate = 0.5)

training_callbacks = [
    keras.callbacks.ReduceLROnPlateau(patience = 1, factor = 0.25, min_lr = 1e-05, verbose = 1),
    keras.callbacks.EarlyStopping(patience = 2, restore_best_weights = True),
]

finetune(model_generator, input_encoder, OUTPUT_SPEC, train_set['seq'], train_set['label'], valid_set['seq'], valid_set['label'], \
        seq_len = 512, batch_size = 32, max_epochs_per_stage = 40, lr = 1e-04, begin_with_frozen_pretrained_layers = True, \
        lr_with_frozen_pretrained_layers = 1e-02, n_final_epochs = 1, final_seq_len = 1024, final_lr = 1e-05, callbacks = training_callbacks)

# Evaluating the performance on the test-set

results, confusion_matrix = evaluate_by_len(model_generator, input_encoder, OUTPUT_SPEC, test_set['seq'], test_set['label'], \
        start_seq_len = 512, start_batch_size = 32)

print('Test-set performance:')
display(results)

print('Confusion matrix:')
display(confusion_matrix)

モデル性能を算出した。
AUCは0から1の間の値をとる精度指標の一つで、1だと精度100%.
今回のモデルのAUCはほとんど1のため、相当精度が高い。

Test-set performance:(バッチサイズ32)
image.png
# records AUC
Model seq len
512 4152 0.995787
All 4152 0.995787

混同行列(Confusion Matrix)は、実際のクラスと予測されたクラスの組み合わせを示し、モデルがどれだけ正確に分類を示す。
混同行列の対角線上の要素は、正しい予測の数を示し、対角線以外の要素は誤った予測の数を示す。
Confusion matrix:(バッチサイズ32)
image.png
0 1
0 3436 42
1 26 648

パラメータのバッチサイズを32→64へ変更した。バッチサイズを2倍にすることで、1回のトレーニングで処理されるサンプル数が2倍になる。
バッチサイズを増やすことで、各サンプルのノイズが平均化されるため、学習が安定し、過学習を減らすことができる。
デメリットは、個々のサンプルの特徴を見落としてしまうかもしれないこと。

計算の結果、精度はほとんど変わらなかったが、バッチサイズ32と比べて僅かにAUC精度が落ちていた。

Test-set performance:(バッチサイズ64)

image.png
# records AUC
Model seq len
512 4152 0.995151
All 4152 0.995151

Confusion matrix:(バッチサイズ64)
image.png
0 1
0 3448 30
1 31 643

パラメータのバッチサイズを32→16へ変更した。
計算の結果、こちらも精度はほとんど変わらなかったが、バッチサイズ32,64と比較して僅かにAUC精度が上がっていた。
計算時間を測っておけば良かったと後から気づいたが、体感でバッチサイズ16の計算時間が長かったと思う。
サンプルをある程度まとめて調べていくより、細かく調べていく方が精度は上がりそうなので、イメージにも合致した結果だった。
ただ、バッチサイズを小さくすると計算時間がかかるため、単純に小さければ良いというわけではなさそう。
計算をした後に、バッチサイズやパラメータについて調べていくと、学習率など他のパラメータと合わせて最適なパラメータを調整していてなかなか奥が深そうだった。

Test-set performance:(バッチサイズ16)
image.png
# records AUC
Model seq len
512 4152 0.995609
All 4152 0.995609

Confusion matrix:(バッチサイズ16)
image.png
0 1
0 3439 39
1 29 645

感想

google colaboを使って、初めてGithubのモデルを動かしてみた。また、一部のパラメータを変更して予測精度の比較を行った。
今回は元々のモデル性能が非常に高いためかパラメータ違いでほとんど精度の差は見られなかったものの、僅かな精度の違いや計算時間の違いから、パラメータ設定の感覚を掴むことはできた。
当初公開されているプログラムを実行することは難しくないと思っていたが、環境の違いによってエラーがたくさん出たりして、想像以上に難しかった。今回学んだことを活かしてGithubにある興味のあるプログラムをどんどん動かしていきたい。また、コードの基礎的な理解が不足していて、初見でコードの概要すらつかめないため、「ゼロから作るDeep Learning」を写経してみる。
(他に良い勉強方法などあれば教えてもらえると嬉しいです)

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0