0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

グラフニューラルネットワークを使って有機化合物の事前学習と物性予測を試してみた

Last updated at Posted at 2024-05-30

はじめに

有機化合物データセットであるQM9データセットを用いて事前学習モデルを構築し、有機化合物の物性を予測するモデルを構築するために以下のことを試みました。
1.事前学習モデルとしてQM9データセットを用いたHOMO・LUMO予測モデルを構築する。
2.構築した事前学習モデルをHugging Faceにアップロードし、簡単にダウンロードできるようにする。
3.Hugging Faceから作成した事前学習モデルをダウンロードして有機化合物の物性予測を行うモデルを構築する。
4.構築した物性予測モデルの評価を行う。

コードはこちら
github
・google colablatory(完全版簡易版)

事前学習モデルの作成(HOMO・LUMO予測)

今回は事前学習モデルとしてまずQM9データセットを用いたsmilesを入力としHOMO準位とLUMO準位の二値予測を行うモデルを構築しました。
モデルはグラフニューラルネットワークを採用し、GCN, GIN, TopK_GCN, set2set_NMP(Neural Message Passing)を用いました。

データの読み込み

まずこちらのサイトからダウンロードしたQM9データセットから学習データとして分子をsmiles形式で抜き出し、ラベルとしてHOMOとLUMOのデータを抜き出しました。

学習データとして読み込んだsmilesはrdkitによりMolオブジェクトへの変換および三次元へのたたきおこしを行った後、以下のようにpytorch_geometric.data.Dataの形式でグラフデータに変換しました。

gnn_dataset.py
def mol2geodata_for_QM9(mol,h_y,l_y):
    smile = Chem.MolToSmiles(mol)
    atom_features =[get_atom_features(atom) for atom in mol.GetAtoms()]
    atom_features = np.array(atom_features)
    num_atom_features=len(atom_features[0])
    atom_features = torch.FloatTensor(atom_features).view(-1, len(atom_features[0]))

    edge_list,num_bond_features = get_edge_features(mol)
    edge_list=sorted(edge_list)
    
    edge_indices=[e for e,v in edge_list]
    edge_attributes=[v for e,v in edge_list]
    edge_indices = torch.tensor(edge_indices)
    edge_indices = edge_indices.t().to(torch.long).view(2, -1)
    edge_attributes = np.array(edge_attributes)
    edge_attributes = torch.FloatTensor(edge_attributes)
    #print(num_atom_features,num_bond_features)
    return TorchGeometricData(x=atom_features, edge_index=edge_indices, edge_attr=edge_attributes, num_atom_features=num_atom_features,num_bond_features=num_bond_features,smiles=smile, h_y=h_y, l_y=l_y)

詳しくはソースを見てください。

また、ラベルとして読み込んだHOMO・LUMOのデータは平均や分散を計算し、それらの値を用いてラベルのノーマライズを行いました。

モデルの定義

各モデルはpytorch_lightningによって記述しました。
また、今回はtorch_geometricを用いたモデルをHugging Faceに上げるため、こちらを参考にtorch_geometric.nn.model_hubPyGModelHubMixinを用いたクラスを定義しました。

PL_GNN_to_Hug.py
class PL_Basic_GNN(PL_BasicGNNs, PyGModelHubMixin):
    def __init__(self,model_name, dataset_name, model_kwargs):
        PL_BasicGNNs.__init__(self,**model_kwargs)
        PyGModelHubMixin.__init__(self, model_name,
            dataset_name, model_kwargs)

ここでPL_BasicGNNsPL_TopKmodelPL_Set2Setmodelは自作モデルです。
モデルの構造はそれぞれ異なりますが、ファインチューニングの際に回帰分析だけでなく多クラスの分類問題にも対応できるようfineune_dimtaskというパラメータに適当な値を入力することで最終層の出力を適切な形に変更したり、損失関数を分類ならCrossEntropyLossを、回帰ならMSELossを使用するようにしてあります。事前学習モデルの作成時はHOMO・LUMOの値予測なのでMSELossを用いていました。

PL_BasicGNN_models.py
    if model_name == 'GCN':
        self.model=GCN(in_channels=in_channels, hidden_channels=hidden_channels, num_layers=num_layers, out_channels=out_channels, dropout=dropout)
    elif model_name == 'GIN':
            self.model=GIN(in_channels=in_channels, hidden_channels=hidden_channels, num_layers=num_layers, out_channels=out_channels, dropout=dropout)
    elif model_name == 'GraphSAGE':
        self.model=GraphSAGE(in_channels=in_channels, hidden_channels=hidden_channels, num_layers=num_layers, out_channels=out_channels, dropout=dropout)
                
    self.set2set = Set2Set(out_channels, processing_steps=3)
    self.lin1 = Linear(2 * out_channels, out_channels)
    self.lin2 = Linear(out_channels, 2)
    if finetune_dim != 0: #fintue_dimを指定した際に最終層を変更
        print('Fintune model!')
        self.f_lin = Linear(out_channels, finetune_dim)
        
 def forward(self, data):
        h = self.model.forward(x=data.x, edge_index=data.edge_index, edge_attr=data.edge_attr)
        h = self.set2set(h, data.batch)
        h = F.relu(self.lin1(h))
        if self.model_type == 'finetune':
            out = self.f_lin(h)
        else:
            out = self.lin2(h)
        return out

詳しくはソースを見てください。

学習開始

これらのモデルに対して適切な引数を与えてモデルを定義し、ラベルのノーマライズに必要な平均と分散の値を追加しました。
そして、結果やモデルの保存に必要な引数をcallbacksに与え、trainerを定義し、trainer.fit()により学習を開始させました。

Hugging Faceの準備

モデルのHugging Faceへの保存にはHugging Faceのトークンなどが必要になり、それらの取得にはHugging Faceのアカウントが必要です。
Hugging FaceのアカウントはHugging FaceのHPのSing Upから作成することが可能です。
HuggingFaceTop.png
アカウント作成後、登録したアドレスに確認メールが送られてくるので、そのメールに記載されているリンクを開き、メールアドレスの確認を行えばアカウントの登録は完了です。

そして、アカウントロゴのメニューより、(1)Settingを開き、(2)Access Tokensの項から、(3)New Tokenをクリックすることでトークンが発行されます。
HuggingFaceMakeToken.png

また、作成したモデルをHugging Faceに保存するためにはあらかじめモデルを保存するレポジトリを作成しておく必要があります。
先ほど同様にアカウントロゴのメニューより、(1)+ New Modelを開き、(2)作成するモデルの名前(これがレポジトリの名前になります。)と必要があればライセンスを入力し、モデルを公開するかどうかの選択をします。(この設定は後からも変更可能です。)そして、(3)Create modelをクリックすることでモデルレポジトリを作成することができます。
ここで設定したOwner/Model nameがレポジトリのIDになります。
HuggingFaceMakeModel.png
作成したモデルレポジトリは自身のプロフィールにて確認することができます。

Hugging Faceへのpush

学習を完了したモデルはsave_pretrained()を用いてHugging Faceに保存しました。
ここではHugging Faceの準備で作成したトークンとレポジトリのIDが必要になります。
save_directoryには保存するモデルのweightが保存されている自身のレポジトリを指定し、push_to_hubTrueにすることでモデルをpushできます。repo_idにはモデルを保存するHugging FaceのレポジトリのIDをtokenには自身のトークンを指定します。

PL_GNN_to_Hug.py
pl_model.save_pretrained(
        save_directory=state_save_dir, #wightの保存先
        push_to_hub=True, #Hugging Faceにモデルをpushする
        repo_id=repo_id, #Hugging Faceの保存先のレポジトリID
        token=my_token #自身のHugging Faceのトークン
     )

事前学習結果

batch_size = 512valid_size = 0.1, test_size = 0.1, max_epoch = 200に固定して各モデルを比較しました。

このモデルの評価を行うためテストデータを用いて予測を行いました。

PL_GNN_to_Hug.py
trainer.test(ckpt_path='best', dataloaders=test_loader) 

学習の結果や予測値の評価は学習曲線とR二乗値により評価しました。
pretrain_results.png
いずれのモデルも良く学習ができており、R二乗値についても全体でみればほぼ1.0を、HOMO、LUMOそれぞれでみても0.85以上の高い値で予測できていることが分かりました。

物性予測のためのファインチューニング

続いて作成した事前学習モデルを読み込み目的のデータセットに対してファインチューニングを行い物性予測を行いました。

また、ファインチューニングについてはgoogle colablatory(完全版簡易版)でも実装しています。

Hugging Faceからの読み込み

Hugging Faceからの読み込みには事前学習の際と同様にこちらを参考にして行いました。
まず、事前学習の際と同様にモデルのクラスを定義しました。
この際にファインチューニングを行うにあたり変更する必要のある変数をクラス内で再定義する必要があったため、二度手間にはなりますが、model_kwargsにそれぞれ必要な変数を代入しています。

PL_GNN_from_Hug.py
class PL_Basic_GNN(PL_BasicGNNs, PyGModelHubMixin):
    def __init__(self, dataset_name, model_name, model_kwargs):
        #正常に読み込まれないため再定義
        model_kwargs['finetune_dim'] = 2 
        model_kwargs['task'] = 'classification'
        model_kwargs['model_type'] = 'finetune'
        model_kwargs['model_name'] = model_name #変更しないでください。
        model_kwargs['in_channels'] = 81 #変更しないでください。
        PL_BasicGNNs.__init__(self,**model_kwargs)
        PyGModelHubMixin.__init__(self, model_name, dataset_name, model_kwargs)

その後、必要な引数を与えてモデルを呼び出し、from_pretrained()を用いてHugging Faceから事前学習モデルのパラメータを読み込みました。

PL_GNN_from_Hug.py
model_name = 'GCN'
task = 'classification'
model_type = 'finetune'
finetune_dim = 2
pl_gnn = PL_Basic_GNN(model_name=model_name, dataset_name='QM9',model_kwargs=dict(model_name=model_name, task=task, model_type=model_type, in_channels=num_atom_features, finetune_dim=finetune_dim))
#pretraine済みモデルの読み込み
pl_model = pl_gnn.from_pretrained(repo_id,model_name=model_name,dataset_name='QM9')

ファインチューニング

ファインチューニング用のデータセットはAmes試験データ水溶性に関するデータの二種類のデータセットを利用しました。
Ames試験データに関しては2クラス分類を水溶性に関するデータには3クラス分類と単回帰分析を行いました。
水溶性に関するデータにおける回帰分析では事前学習の際と同様にラベルのノーマライズを行いました。
その後、事前学習モデルを読み込み、事前学習と同様にtrainerを定義し、trainer.fit()により学習を開始させました。

予測結果

ファインチューニングではbatch_size = 128valid_size = 0.1, test_size = 0.1, max_epoch = 300に固定して各モデルを比較しました。
それぞれの場合において学習の結果は次のようになりました。
fintuen_loss.png
いずれの結果でも順調にlossが減少し、学習が進んでいることが分かります。

続いて予測結果の評価を行いました。
まずは、Ames試験データに対する結果はこちらになります。
Ames.png
モデルによってばらつきはありますが、75~83%くらいの正解率で予測することが出来ました。

次に、水溶性に関するデータの分類の結果がこちらになります。
solubility_clas.png
こちらはデータ数に対してデータのばらつきが大きかったためか、予測精度はあまり上がらず、最も良いものでも80%弱の正解率にとどまりました。

最後に水溶性に関するデータについて回帰分析を行った結果を示します。
solubility_rgr.png
こちらの結果はモデルによって大きく差が出ました。
TopK_GCNのモデルではR二乗値は0.6にも届かなかったのに対し、set2set_NMPのモデルでは0.85と高い値を示し、うまく予測することが出来ました。

まとめ

QM9データセットを用いてHOMOとLUMOを予測するモデルを作成し、そのモデルを化合物の物性予測のための事前学習モデルとしてHugging Faceにアップロードしました。そして、実際にHugging Faceからそのモデルを読みこみファインチューニングを行って物性予測も行なってみました。
いくつかの種類のモデルで検討を行いましたが、HOMO・LUMO予測モデルはいずれのモデルでも高いレベルで予測することが出来ました。
物性予測モデルについてはモデルによってはあまり予測性能が良くないものもありましたが、Hugging Faceから読みこみファインチューニングを行うだけで容易に物性予測モデルを構築することが出来ました。

(この記事は研究室インターンで取り組みました:https://kojima-r.github.io/kojima/)

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?