36
24

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

TensorFlowに挫折した僕がPyTorchで圧倒的にシンプルなGCNを実装した話

Last updated at Posted at 2021-01-31

#はじめに
化学構造を対象としたGraph Convolutional Network(GCN)に関して、Convolutional Networks on Graphs for Learning Molecular Fingerprints とい論文が広く知られている。2015年に発表された論文であるが、化合物系のGCNの論文には必ず引用されていると言っていい程、有名な論文である。
既に様々なライブラリに実装されているが、今回PyTorchの力を借りてこの論文を自力で実装したので詳細を共有したい。

#モチベーション

論文の著者による実装コードは https://github.com/HIPS/neural-fingerprint で公開されている。

また、化合物によるディープラーニングライブラリとして DeepChemChainer Chemistry などいくつか存在するが、これらのライブラリに論文の手法は搭載されている。

これらの元コードがあるにもかかわらず、PyTorchで実装するに至ったモチベーションは以下の通りである。

  • 論文の手法を忠実に実装した上で、原子に独自の特徴を追加したり、予測に影響する部分構造の可視化をするなどの拡張をしたかった。
  • 論文の著者が書いたコードは Python2 ベースであり、古いライブラリを用いておりカスタマイズが難しい。
  • DeepChem は TensorFlow ベースということもあり、コードの理解が難しく挫折した(本記事のタイトルはこれによる)。
  • Chainer Chemistory は Chainer が開発終了になった。
  • 最後に、PyTorch で複雑なモデルを構築するスキルを習得したかった (実はこれが一番デカイ)。

#論文の概要
本論文は様々な記事でも解説が試みられているため、詳細はそちらを見ていただくとして、ここでは実装の説明に必要な範囲に限定して論文の内容を説明する。

分子の化学構造から特徴ベクトルを生成する従来手法として、ある一定のルールに基づき特定の長さの0または1のビットの並びに変換するフィンガープリントが知られている。本論文を一言でいうと、このフィンガープリントをGCNにより実現したものになる。

##既存手法の説明

既存の分子のフィンガープリントであるCircular fingerprints のアルゴリズムについて論文 (https://arxiv.org/pdf/1509.09292.pdf) より引用し、これに沿って説明する。
image.png
初期化処理として、全て0からなる長さSのフィンガープリントを用意しておく。
また、分子中の各原子について特徴ベクトルを生成しておく(特徴ベクトルについては、論文アルゴリズムの説明の箇所で具体的に説明する)。

その後、指定した層の数だけ以下①~⑤を繰り返す。
① 各原子について、自身の特徴ベクトルと隣接する周りの原子との特徴ベクトルを concatinate する(7, 8行目)
② ハッシュ関数を施し、その結果を新たな原子の特徴とする(9行目)(※注1)。
③ ②のハッシュ関数の戻り値に対し、フィンガープリントの長さ S でmod演算を施し変数iに格納する(10行目)。i は 0~(フィンガープリント長S-1) の間の整数になる。
④ フィンガープリントの i 番目のビットに1を立てる(11行目)。
⑤ ①~④を全原子分行う。

最終的に得られた$f$が、分子のフィンガープリントとなる。

何層か処理を繰り返す度に、隣の原子、さらに隣の隣の原子といった形で、周囲の情報が次々に原子に埋め込まれ、ハッシュ関数により少し引数が異なるだけでも違う値が得られるため、構造が少し違うだけでも異なるフィンガープリントを作成することができる。

一方で重複なくあらゆる可能な部分構造を区別するためには、非常に大きな長さのフィンガープリントを用意しておく必要があり、長さが足りない場合には異なる構造が同じインデックスに割り当てられられ、情報が欠落してしまうという欠点がある。

※注1) 10行目では$r_a$がスカラー表記であるが、9行目ではベクトル表記っぽく見えるため、ここの解釈は誤っているかもしれない。

論文アルゴリズムの説明

次にGCNによるフィンガープリントのアルゴリズムも、同様に論文(https://arxiv.org/pdf/1509.09292.pdf) より引用し説明する。
image.png

特徴ベクトルの生成

1、2行目はニューラルネットの重み、フィンガープリントの初期化処理である。
3、4行目の以下は、従来手法と同様、全原子に対する特徴ベクトルを生成している。

r_a ← g(a) 

原子の特徴ベクトルについて具体的にはイメージが湧かないため、論文の実装コードを説明する。
論文の実装コードでは以下の仕様となっており、SMILESという分子の文字列表現から、RDKitを用いて特徴ベクトルを生成している。

###原子の特徴ベクトル(全62ビット)

項目 説明 ビット数
原子番号 C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na','Ca', 'Fe', 'As', 'Al', 'I', 'B', 'V', 'K', 'Tl', 'Yb', 'Sb', 'Sn', 'Ag', 'Pd', 'Co', 'Se', 'Ti', 'Zn', 'H', 'Li', 'Ge', 'Cu', 'Au', 'Ni', 'Cd', 'In', 'Mn', 'Zr', 'Cr', 'Pt', 'Hg', 'Pb', 'Unknown'のワンホットエンコーディング 44
次数 0~5のワンホットエンコーディング 6
水素数 0~4のワンホットエンコーディング 5
電荷 0~5のワンホットエンコーディング 6
芳香属性か 0 or 1 1

また、論文には記載がないが、論文の実装コードでは以下のような結合の特徴ベクトルも利用されているため紹介する。

###結合の特徴ベクトル(全6ビット)

項目 説明 ビット数
単結合か 0 or 1 1
二重結合か 0 or 1 1
三重結合か 0 or 1 1
芳香環結合か 0 or 1 1
共役結合か 0 or 1 1
環に属する結合か 0 or 1 1

5行目は1層からR層まで繰り返し処理を行うことを示している。
6行目からは、各層において全原子に対して行う処理が記載されている。
以下、各原子に対する処理を説明する。

畳み込み処理

以下7行目では、原子に結合している全ての原子の特徴ベクトルを取り出している。

r_1,...,r_N = neighbors(a)

ここでNは原子の次数によって変わり、結合数が3の原子であれば、N=3となり、$ r_1, r_2, r_3$となる。

以下の8行目の処理では、結合する全原子の特徴ベクトルを合計したものを、原子自身の特徴ベクトルに加算し ベクトル$v$に格納している。結合する全原子の特徴ベクトルを合計しているのは、結合する原子の並びに影響を受けないモデルにするためで、並びを変えても同じ値になる演算としてSUMを採用したとのことである(精度を無視していうなら、平均でもMAXでもよいことになる)。

v ← r_a + \sum_{i=1}^{N}r_i

なお、論文の実装コードでは、原子間の結合ベクトル(6ビット)も concatinate した上で上記の処理を行っている。このあたり、論文のアルゴリズムと実装コードが微妙に異なっているため、詳細を知りたい場合は実装コードを読むことをお勧めする。

以下9行目では、ベクトル$v$と、層毎次数毎に用意している重み行列の行列積をとった後、活性化関数を施している。これは既存手法におけるハッシュ関数に対応している。

r_a ←  σ(vH_{N}^{L})

これにより、隣接する原子の特徴が対象原子に埋め込まれる。以降、何層か処理を繰り返す度に、隣の原子、さらに隣の隣の原子といった形で、周囲の情報が次々に原子に埋め込まれる(畳み込み)。

プーリング処理

以下10行目は、原子の特徴ベクトルを、フィンガープリントのサイズに線形変換し、Softmax演算を施している。

i ← softmax(r_aW_L)

これは、既存手法においてフィンガープリントの特定のビットに1を立てる操作に対応している。

以下11行目は上で得られたフィンガープリントと同じサイズのベクトルを、フィンガープリント$f$に足している。

f ← f + i

全層の処理が終わった時点の$f$が、最終的な分子のフィンガープリントとなる。

このようにノードである原子の全ての特徴ベクトルを、分子全体の単一の固定長のフィンガープリントに結合している操作は、標準の畳み込みニューラルネットワークのプーリング操作に類似していると論文に書かれている。

その他のまとめ

ここまでがフィンガープリントに着目した従来手法と論文手法の説明であるが、この後の実装の説明に入る前に、その他論文に記載されていることや補足を以下に記載する。

  • 従来手法の場合は学習が不要であるのに対し、論文手法は重み行列のパラメータを学習させる必要がある。特徴ベクトルのサイズを$F$、フィンガープリントのサイズを$L$とすると、学習が必要な重み行列は以下の通りである。
  • 畳み込み処理では、畳み込み層毎に$F$ x $F$の重み行列が、原子の持ちうる結合の数(有機化合物では5まで)の分だけ必要。
  • プーリング処理では、畳み込み層毎に$F$ x $L$のサイズ の重み行列が必要
  • 従来手法は、大きなランダムな重みとした場合の、論文手法の特殊なケースといえる。その理由は、畳み込み層の活性化関数(tanh)は、入力の重みが大きい場合ステップ関数に近づくため、単純なハッシュ関数とみなすことができるためである。
  • 従来手法のハッシュ関数を、tanh等による滑らかな活性化関数に置き換えたことにより汎化性能の向上が示唆された。その理由として、あまり重要でない局所的な部分構造の変化であったとしても、従来手法ではハッシュ関数により異なるフィンガープリントになっていたが、論文手法では活性化状態が変わらないように学習させることができるからである。
  • フィンガープリントを生成した後は、それを特徴とみなした通常の機械学習と同じ流れになる。本論文の実装コードでは多層パーセプトロンに接続している。

次からはPyTorchで実装したものの説明を行う。

#作ったもの
こんな感じ

  • SDF、SMILESともに入力とするコマンドラインプログラムとした。
    論文の実装コードはSMILESのみ入力となっていたが、独自の原子の特徴を取り込めるようにするためSDFの方が都合がよいのでSDFにも対応した。
  • 分子からの原子や結合の特徴ベクトルの生成やグラフデータの持ち方は、論文の実装コードをほとんど流用した。逆にいうとPyTorchのおかげでほとんど手を加える必要がなかった。
  • 論文の実装コード同様にミニバッチ学習に対応した。即ち分子のリストが入力として与えられたときに、分子のリストをミニバッチ単位に分割し、その単位毎に勾配計算を行い重みを更新するようにした。
  • 入力データを学習とテストを8:2に自動で分け、学習データで予測モデルを作成し検証データでテストを行い予測精度を出力する。
  • 今のところ分類は2値分類しかできない。
  • 作成した予測モデル指定して予測を行う機能はない。
  • GPUによる学習が可能(今のところ1GPUのみ利用可能)。
  • Early Stoppingに対応。
  • axによるハイパーパラメータチューニングが可能。
  • Tensor Boardによる学習の可視化が可能。

#環境
今回使ったライブラリは以下の通りである。

  • PyTorch 1.7.0
  • RDKit 2020.09.2
  • scikit-learn 0.23.2
  • pytorch-lightning 1.1.4
  • ax-platform 0.1.19
  • tensorboard 2.4.1 (TensorBoard可視化時に必要)

#ソースの解説
ソースは、https://github.com/kimisyo/simple-GCNで公開しているが、ここではざっと解説してみる。

フォルダ構成

フォルダ構成は今後の拡張性も考慮し、こんな感じにした。

|-- run.py #実行ファイル
|-- simpleGCN
    |-- data  #データ読み込み
    |   |-- dataloader.py
    |   |-- dataset.py
    |-- feat   #特徴ベクトル生成
    |   |-- features.py
    |   |-- mol_graph.py
    |-- layers # モデル内レイヤー定義
    |   |-- graph_conv_layer.py
    |-- models # モデル定義
    |   |-- graph_conv_model.py

以下、各パーツを解説していく。

データ読み込み

データローダクラスの作成

SMILES用とSDF用それぞれのデータを取り込むデータローダークラスを作成した。構造データをRDKitのMolオブジェクトに変換し、ラベルと共にmols, labels_listという変数に格納している。Molオブジェクトに変換している理由は、原子の特徴なども取り込む際にMolオブジェクトにその情報を保持させておくと、将来的の機能追加の際に改修が少ないだろうと考えたためである。

dataloader.py
class SMILESDataLoader:

    def __init__(self, csv_path, smiles_col="smiles", label_props=["label"]):
        # CSVの読み込み
        df = pd.read_csv(csv_path)

        # SMILESの読み込みとMOLへの変換, ラベルの読み込み
        mols = []
        labels_list = []
        for i, samples in enumerate(df[smiles_col].values):
            from rdkit import Chem
            mol = Chem.MolFromSmiles(samples)
            if mol is not None:
                mols.append(mol)
            else:
                mols.append(None)

            labels = [None] * len(label_props)
            for j, label in enumerate(label_props):
                labels[j] = df[label].values[i]

            labels_list.append(labels)

        self.mols = mols
        self.labels_list = labels_list


class SDFDataLoader:

    def __init__(self, sdf_path, label_props=["label"]):

        from rdkit import Chem
        sdf_sup = Chem.SDMolSupplier(sdf_path)

        mols = []
        labels_list = []
        for i, mol in enumerate(sdf_sup):
            labels = []
            if not mol:
                continue

            mols.append(mol)

            labels = [None] * len(label_props)
            for j, label in enumerate(label_props):
                try:
                    labels[j] = float(mol.GetProp(label))
                except Exception as ex:
                    labels[j] = None

            labels_list.append(labels)

        self.mols = mols
        self.labels_list = labels_list

データセットクラスの作成

次にPyTorch用のデータセットクラスとして、以下のようにMolオブジェクトおよび、ラベルを返すものを作った。

dataset.py

class MoleculeDataset(data.Dataset):

    def __init__(self, mol_list, label_list):
        self.mol_list = mol_list
        self.label_list = label_list

    def __len__(self):
        return len(self.mol_list)

    def __getitem__(self, index):
        return self.mol_list[index], self.label_list[index]

ここは特に工夫はない。

モデル入力用データの作成

###原子および結合の特徴ベクトルの生成
ここはほぼ論文の実装コードから流用している。1点、論文はSMILESから原子、結合の特徴量を生成しているが、今回はMolオブジェクトを保持するようにしたため、Molオブジェクトから特徴量を生成するよう変更した。ソースではfeaturesディレクトリの下のgraph_mol.pyおよびfeatures.pyが該当する。
これらに関する詳しい説明は省くが、MolGraphクラスのgraph_from_mol_tupleメソッドを適用した結果得られる特徴ベクトル等が格納されたデータ構造について解説しておく。

###モデル入力用データ構造

graph_from_mol_tupleにより、分子の数だけMolオブジェクトが格納されたリストを与えると、分子グラフを表現するための各種データが格納されたディクショナリが生成される。

以下分子グラフの各種データを説明する。

atome_features

原子数 x 62のサイズの特徴ベクトルである。全分子の原子がまとめられていることに留意されたい。
また、booleanになっているが、モデルの入力に与える際にtorch.Tensorのfloat型に変換している。

 array([[False, False, False, ..., False,False, False],
       [False, False,  True, ..., False, False, False],
       [False, False,  True, ..., False, False, False],
       ...,
       [ True, False, False, ..., False, False, False],
       [ True, False, False, ..., False, False, False],
       [ True, False, False, ..., False, False, False]])

bond_features

結合数 x 6 のサイズの特徴ベクトルである。atom_features同様、全分子の結合がまとめられている。

array([[ True, False, False, False, False, False],
       [False, False, False,  True,  True,  True],
       [False, False, False,  True,  True,  True],
       ...,
       [False, False, False,  True,  True,  True],
       [False, False, False,  True,  True,  True],
       [False, False, False,  True,  True,  True]])

atom_list

分子数分からなるPythonのリスト(のリスト)である。各行は分子に対応しており、各行のリストは、その分子がどの原子から構成されるかを示している。原子はatom_featuresにおけるインデックスで指定されている。原子毎に畳み込みを行った結果を、分子単位に集約する際に用いられる。

 [
  [0, 1507, 344, 345, 1508, 346, 1509, 1, 347, 1510, 1511, 1512, 348, 349, 350, 351, 352, 1513, 353, 1514, 2, 1515, 354, 355, 356, 357], 
  [3, 358, 359, 360, 361, 1516, 4, 362, 363, 1517, 364, 365, 366, 367, 1518, 1519, 5, 368, 369, 1520, 370, 371, 372, 373, 374], 
  [6, 375, 1521, 376, 377, 378, 1522, 379, 380, 381, 1523, 382, 383, 1524, 384, 1525, 385, 1526, 7, 1527, 1528, 8, 386, 387, 388, 1529, 9, 1530, 10, 11, 389, 390], 
  [12, 2179, 13, 391,392, 1531, 14, 1532, 393, 1533, 394, 1534, 395, 1535, 396, 397, 398, 399, 400], 
  [15, 2180, 16, 1536, 17, 401, 1537, 402, 403, 1538, 404, 1539, 18, 2181, 19, 20, 405, 406, 407, 21, 2182, 22, 23, 24],
   ...

 ]

重要なポイントとして、分子毎に原子数が異なるので、numpy配列ではなくPythonのリストになっている点である。
論文の実装では、プーリングする際にこのリストをfor分で回して原子の結果からから分子の結果を生成している。
DeepChemでは、これをTensorflowで計算できるようにデータの持ち方を変えているため、実装が煩雑となっているようである。
今回は、めんどくさかったので実装の分かりやすさを優先し、このままのデータ構造で進めることとした。

atom_neighbors

次数(結合数)別に、原子の隣接原子が格納される。なお、言い忘れたが今回論文の実装と同様、原子の次数は5次までと想定している。

次数=1の原子用の配列

次数1の場合は、次数1の原子数 x 1 の配列となる。
隣接原子は atom_featuresにおけるインデックスで指定している。

 array([[1507],
       [1509],
       [1514],
        ...,
       [2172]]) 
次数=2の原子用の配列

次数2の場合は、次数2の原子数 x 2 の配列となる。
隣接原子は atom_featuresにおけるインデックスで指定している。

 array([[1507,  345],
       [ 344, 1508],
       [1508, 1509],
       ...,
       [1503, 2175],
       [2174, 2178],
       [2178, 2172]])
次数=3の原子用の配列

次数3の場合は、次数3の原子数 x 3 の配列となる。
隣接原子は atom_featuresにおけるインデックスで指定している。

array([[   0,  344,  353],
       [ 345,  346, 1513],
       [ 346,    1,  347],
       ...,
       [1496, 1497, 1503],
       [1497, 1498, 1502],
       [1505, 1506, 2173]])
次数=4の原子用の配列

次数4の場合は、次数4の原子数 x 4 の配列となる。
隣接原子は atom_featuresにおけるインデックスで指定している。

array([[  12,   13,  391,  395],
       [  15,   16, 1536, 2182],
       [1538,   19,   20,  405],
       [2180,   22,   23,   24],
       [1566,   35,   36,   37],
      ...,
       [2079,  312,  313,  314]])
 

5次のデータは今回なかったので掲載略。いまだ試してません(笑)。

bond_neighbors

次数別(結合数)別に、原子が保持する結合が格納される。結合は、bond_featuresにおけるインデックスで指定している。
それ以外はatom_neighborsにおける次数別の原子と構造は同じなので、説明を省略する。

rdkit_ix

各原子のRDKitのMolオブジェクト内おけるインデックス。これは可視化等のために利用するが、今回使っていないので説明を省略する。

atom_labels

各原子のラベル。これは原子ラベルの予測のために利用する予定で実装しているが、今回使っていないので説明を省略する。

ミニバッチ生成ための仕組み

さてデータの読み込みと、モデル入力用データの作成が終わったので、これをモデルに与えられるようなミニバッチ生成を検討する。

これまでに作成したPyTorchデータセットであるMoleculeDatasetは、RDKitのMolオブジェクトと分子ラベルのリストを返却するが、モデル入力用データ構造の方は原子特徴ベクトルのリストやatom_list、次数別の原子の隣接原子、結合等になっており、形状が大きく異なっている。またデータセットは分子単位で返すが、モデルは原子単位で学習を行う必要がある。
このように、データセットからイテレーションで得られるデータと、モデルの入力に与えるデータが大きく異なる点が、化合物におけるGCNの難しいところである。
しかし、PyTorchならばcollate_fnという仕組みを使えば、簡単にデータセットから受け取ったデータを任意の形式に変換することができる。
詳しくはPytorchのcollate_fnを使ってみるを見てほしい(丸投げ)。

バッチから受け取ったデータを変換して返却する関数を、PyTorchのデータローダのcollate_fn引数に与えることで、好みの形にバッチを変換できるのである。百聞は意見に如かず。その関数の実装を見てみよう。

dataset.py
def gcn_collate_fn(batch):

    mols = []
    labels = []
    atom_labels = []
    atom_features = []

    for i, (mol, label, *atom_data) in enumerate(batch):
        mols.append(mol)
        labels.append(label)
        if len(atom_data) > 0:
            atom_labels.append(atom_data[0])
            if len(atom_data) > 1:
                atom_features.append(atom_data[1])

    molgraph = graph_from_mol_tuple(mols, atom_labels, atom_features)
    arrayrep = {'atom_features' : molgraph.feature_array('atom'),
                'bond_features' : molgraph.feature_array('bond'),
                'atom_list'     : molgraph.neighbor_list('molecule', 'atom'), # List of lists.
                'rdkit_ix': molgraph.rdkit_ix_array(),
                'atom_labels': molgraph.labels_array()
                }  # For plotting only.

    for degree in degrees:
        arrayrep[('atom_neighbors', degree)] = \
            np.array(molgraph.neighbor_list(('atom', degree), 'atom'), dtype=int)
        arrayrep[('bond_neighbors', degree)] = \
            np.array(molgraph.neighbor_list(('atom', degree), 'bond'), dtype=int)

    #print(arrayrep)
    return arrayrep, labels

引数で受け取ったbatchに、MoleculeDatasetからのMolオブジェクト、分子ラベルのタップルのリストが格納されているので、さっきのgraph_from_mol_tuple関数を使って分子グラフ情報(array_rep)を生成し、それを分子ラベルとペアでそれを返却している。

あとは、この関数をPyTorchのDataLoaderのcollate_fn引数に与えるだけある。ここまでに作成したDataLoader, DataSetを用いてデータを読み込み、scikit-learnのtrain_test_splitでデータを分割し、データセットを生成する流れはこんな感じになる。

run.py
    if params["file_type"] == "sdf":
        loader = SDFDataLoader(params["train_data_file"], label_props=params["label_cols"])
    else:
        loader = SMILESDataLoader(params["train_data_file"], smiles_col=params["smiles_col"], label_props=params["label_cols"])

    from sklearn.model_selection import train_test_split
    X_train, X_test, y_train, y_val = train_test_split(loader.mols, loader.labels_list, shuffle=True, train_size=0.8, random_state=params["random_seed"])

    molecule_dataset_train = MoleculeDataset(X_train, y_train)
    data_loader_train = data.DataLoader(molecule_dataset_train, batch_size=params["batch_size"], shuffle=False, collate_fn=gcn_collate_fn)

    molecule_dataset_test = MoleculeDataset(X_test, y_val)
    data_loader_val = data.DataLoader(molecule_dataset_test, batch_size=params["batch_size"], shuffle=False, collate_fn=gcn_collate_fn)

元の論文の実装のデータ構造をcollate_fnで吸収でき、非常に可読性が高くなっている。collate_fnに手動で関数を設定しないといけないのはダサいが、標準のDataLoaderを使っているので目を瞑ろう。

次からいよいよモデルを作っていく。

モデル構築

畳み込みレイヤー

周辺の原子の情報を対象原子に畳み込む畳み込みレイヤーは、PyTorchのモジュールとして実現した。入力はミニバッチから得られる、対象ミニバッチの全分子から生成された分子グラフ情報(array_rep)ディクショナリ、前のレイヤーからの原子特徴ベクトルのリスト、および結合の特徴ベクトルのリストであり、出力はそのレイヤーにおける新たな原子の特徴ベクトルである。

graph_conv_layer.py
class GraphConvLayer(nn.Module):

    def __init__(self,
               in_features,
               out_features,
               activation=nn.ReLU(),
               normalize=True
                 ):

        super().__init__()

        self.degrees = [0, 1, 2, 3, 4, 5]
        self.activation = activation
        self.normalize = normalize

        self.in_features = in_features
        self.out_channel = out_features

        self.self_activation = nn.Linear(in_features, out_features, bias=False)
        self.bias = nn.Parameter(torch.Tensor(np.zeros(out_features)))

        self.degree_liner_layer_list = []
        for k in range(len(self.degrees)):
            self.degree_liner_layer_list.append(nn.Linear(in_features + 6, out_features, bias=False))

        self.degree_liner_layer_list = torch.nn.ModuleList(self.degree_liner_layer_list)

        if self.normalize:
            self.batch_norm = torch.nn.BatchNorm1d(out_features)

    def forward(self, array_rep, atom_features, bond_features):
        self_activations = self.self_activation(atom_features)
        activations_by_degree = []

        for i, degree in enumerate(self.degrees):
            # indexとして利用するためlongに変換する
            atom_neighbors_list = torch.tensor(array_rep[('atom_neighbors', degree)], dtype=torch.long)
            bond_neighbors_list = torch.tensor(array_rep[('bond_neighbors', degree)], dtype=torch.long)

            if len(atom_neighbors_list) > 0:
                stacked_neighbors_tensor = torch.cat([atom_features[atom_neighbors_list], bond_features[bond_neighbors_list]], axis=2)
                summed_neighbors_tensor = torch.sum(stacked_neighbors_tensor, axis=1)
                activations = self.degree_liner_layer_list[i](summed_neighbors_tensor)
                activations_by_degree.append(activations)

        neighbour_activations = torch.cat(activations_by_degree, axis=0)
        total_activations = neighbour_activations + self_activations + self.bias

        if self.activation:
            total_activations = self.activation(total_activations)

        if self.normalize:
            total_activations = self.batch_norm(total_activations)

        return total_activations

論文の実装コードにほぼ1対1に対応する形で実装できた。Activation, BatchNormalizationなども論文の実装に沿っている。
実装上の注意点は以下の通りである。

  • for i, degree in enumerate(self.degrees):のループは次数毎に、対象次数を持つ原子で畳み込みを行っている処理である。ループ後、torch.cat(activations_by_degree, axis=0)で全次数分の原子をまとめて1つのリストにしている。
  • atom_features, bond_features の添え字にatom_neighbors, bond_neighborsを指定し、対象原子の周辺の特徴ベクトルを取得している箇所 torch.cat([atom_features[atom_neighbors_list], bond_features[bond_neighbors_list]], axis=2)は、このindexに指定するTensorをLong型に変換しておく必要がある。
  • PyTorchではよくある話のようだが、nn.Module型のオブジェクトをリストとして保持する際は、torch.nn.ModuleListに変換する必要がある。これをしないと、学習可能なパラメータとして認識してくれず、勾配が計算されない。

###プーリングレイヤー
ハッシュ化、原子情報を分子単位に集約するプーリングレイヤーについても、PyTorchのモジュールとして実現した。入力はミニバッチから得られる、対象ミニバッチの全分子から生成された分子グラフ情報(array_rep)ディクショナリ、前のレイヤーからの原子特徴ベクトルであり、出力はそのレイヤーにおける分子のフィンガープリントである。

graph_conv_layer.py
class GraphPoolingLayer(nn.Module):

    def __init__(self,
                 in_features,
                 out_features):

        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.liner_layer = nn.Linear(in_features, out_features, bias=True)

    def forward(self, array_rep, atom_features):

        tmp = self.liner_layer(atom_features)
        atom_outputs = nn.Softmax(dim=1)(tmp)

        xs = []
        for idx_list in array_rep['atom_list']:
            idx_list = torch.tensor(idx_list, dtype=torch.long)
            xs.append(torch.unsqueeze(torch.sum(atom_outputs[idx_list], axis=0), dim=0))

        layer_output = torch.cat(xs, axis=0)
        return layer_output

forwardメソッドの部分は、論文の実装コードにほぼ1対1に対応している。
実装上の注意点は以下の通りである。

  • 原子の単位のフィンガープリントを分子単位のフィンガープリントに集約する際に for idx_list in array_rep['atom_list']:で分子毎にループを回しながら、分子を構成する原子を集めている。論文の実装通りではあるが、このループはCPU演算となるため、ここが性能のボトルネックになる可能性がある。

モデルクラスの作成

これまでのレイヤーを全て組み合わせ、Pytorch Lightningを継承してモデルクラスを実装した。

graph_conv_model.py
class GraphConvModel(pl.LightningModule):

    def __init__(self,
                 device_ext=torch.device('cpu'),
                 task="regression",
                 atom_features_size=62,
                 conv_layer_sizes=[20, 20, 20],
                 fingerprints_size=50,
                 mlp_layer_sizes=[100, 1],
                 activation=nn.ReLU(),
                 normalize=True,
                 lr=0.01,
                 ):

        super().__init__()

        self.device_ext = device_ext
        self.task = task
        self.activation = activation
        self.normalize = normalize
        self.number_atom_features = atom_features_size
        self.graph_conv_layer_list = nn.ModuleList()
        self.graph_pooling_layer_list = nn.ModuleList()
        self.mlp_layer_list = nn.ModuleList()
        self.batch_norm_list = nn.ModuleList()
        self.lr = lr

        prev_layer_size = atom_features_size

        for i, layer_size in enumerate(conv_layer_sizes):
            self.graph_conv_layer_list.append(GraphConvLayer(prev_layer_size, layer_size))
            self.graph_pooling_layer_list.append(GraphPoolingLayer(layer_size, fingerprints_size))
            prev_layer_size = layer_size

        prev_layer_size = fingerprints_size
        for i, layer_size in enumerate(mlp_layer_sizes):
            self.mlp_layer_list.append(torch.nn.Linear(prev_layer_size, layer_size, bias=True))
            prev_layer_size = layer_size

        if normalize:
            for i, mlp_layer in enumerate(self.mlp_layer_list):
                if i < len(self.mlp_layer_list) -1 :
                    self.batch_norm_list.append(torch.nn.BatchNorm1d(mlp_layer.out_features))

    def forward(self, array_rep, atom_features, bond_features):

        all_layer_fps = []
        for graph_conv_layer, graph_pooling_layer in zip(self.graph_conv_layer_list, self.graph_pooling_layer_list):
            #畳み込み処理
            atom_features = graph_conv_layer(array_rep, atom_features, bond_features)
            #プーリング処理
            fingerprint = graph_pooling_layer(array_rep, atom_features)
            #このレイヤーでのフィンガープリントを一旦保存
            all_layer_fps.append(torch.unsqueeze(fingerprint, dim=0))

        #全レイヤーのフィンガープリントを加算し最終的な分子フィンガープリントを生成
        layer_output = torch.cat(all_layer_fps, axis=0)
        layer_output = torch.sum(layer_output, axis=0)

        # 分子フィンガープリントを入力とした多層パーセプトロンによる処理
        x = layer_output.float()
        for i, mlp_layer in enumerate(self.mlp_layer_list):
            x = mlp_layer(x)
            if i < len(self.mlp_layer_list) - 1:
                if self.activation:
                    x = self.activation(x)
                if self.normalize:
                    x = self.batch_norm_list[i](x)

        return x

    def training_step(self, batch, batch_idx):

        array_rep, labels_list = batch
        atom_features = array_rep['atom_features']
        bond_features = array_rep['bond_features']

        atom_features = torch.tensor(atom_features, dtype=torch.float)
        bond_features = torch.tensor(bond_features, dtype=torch.float)

        if self.task == "regression":
            labels = torch.tensor(labels_list, dtype=torch.float)
        else:
            labels = torch.tensor(labels_list, dtype=torch.float)

        atom_features = atom_features.to(self.device_ext)
        bond_features = bond_features.to(self.device_ext)
        labels = labels.to(self.device_ext)

        y_pred = self(array_rep, atom_features, bond_features)

        if self.task == "regression":
            loss = F.mse_loss(y_pred, labels)
            # https://github.com/pytorch/ignite/issues/453
            var_y = torch.var(labels, unbiased=False)
            r2 = 1.0 - F.mse_loss(y_pred, labels, reduction="mean") / var_y
        else:
            loss = F.binary_cross_entropy_with_logits(y_pred, labels)

            y_pred_proba = torch.sigmoid(y_pred)
            y_pred = y_pred_proba > 0.5
            acc = accuracy(y_pred, labels)
            y_pred_proba = y_pred_proba.to('cpu').detach().numpy().tolist()
            from sklearn import metrics
            try:
                rocauc = metrics.roc_auc_score(labels_list, y_pred_proba)
            except Exception as es:
                rocauc = 0

        if self.task == "regression":
            ret = {'loss': loss, 'train_r2': r2}
        else:
            ret = {'loss': loss, 'train_acc': acc, 'train_rocauc': torch.tensor(rocauc, dtype=torch.float)}

        return ret

ここも論文の実装コードにほぼ1対1に対応している。
PyTorch Lightningを継承しているため、training_stepメソッドに、ミニバッチから取得した時の学習コードを書いている。
入力データである atom_features, bond_features は boolean型になっているため、これをPyTorchのTensor.float型に変換し、forwardに与えている。
入力データは手動でGPU指定を行っている (モデル内のデータは PyTorch Lighitningが一括でやってくれる)。
r2の指標はPyTorchに見当たらなかったため、自前で作成している。

学習の実行

これで全てのお膳建てが整ったので、以下のようにモデルクラスをインスタンス化し、データローダクラスと共にPyTorch Lighitningのfitメソッドに与えればよい。

run.py

      model = GraphConvModel(
         device_ext=device,
         task=params["task"],
         conv_layer_sizes=params["conv_layer_sizes"],
         fingerprints_size=params["fingerprints_size"],
         mlp_layer_sizes=params["mlp_layer_sizes"],
         lr=params["lr"]
        )

        trainer = pl.Trainer(
             max_epochs=params["num_epochs"],
             gpus=gpu
        )
        trainer.fit(model, data_loader_train, data_loader_val)

#使い方
基本的な使い方は、パラメータをPythonのディクショナリ形式のconfigファイルで設定し、それをrun.pyの引数与えて実行する。これは、同じ化合物DeepLearningライブラリであるOpenChemを参考にしている。設定内容がファイルとして残るので超便利である。

設定ファイルは、configディレクトリに回帰と分類の例を入れている。

設定ファイルの例

以下は回帰の設定ファイル例である。

regression.py
model_params = {
    'task': 'regression', # 'regression' or 'classification'
    'random_seed': 42,
    'num_epochs': 30,
    'batch_size': 80,
    'file_type': 'smi', # 'sdf' (SDF) or 'smi' (SMILES)
    'train_data_file': 'input/Lipophilicity.csv', # input file path
    'label_cols': ['exp'], # target parameter label
    'smiles_col': 'smiles', # smiles column (required when 'file_type' is 'smi')
    'use_gpu': True, # use gpu or not
    'atom_features_size': 62, # fixed value
    'conv_layer_sizes': [20,20,20],  # convolution layer sizes
    'fingerprints_size': 50, # finger print size
    'mlp_layer_sizes': [100, 1], # multi layer perceptron sizes
    'lr': 0.01, #learning late
    'metrics': 'r2', # the metrics for 'check_point' , 'early_stopping', 'hyper'
    'minimize': False, # True if you want to minimize the 'metrics'
    'check_point':
        {
        "dirpath": 'model', # model save path
        "save_top_k": 3, # save top k metrics model
        },
    'early_stopping': # see https://pytorch-lightning.readthedocs.io/en/stable/generated/pytorch_lightning.callbacks.EarlyStopping.html
        {
        "min_delta": 0.00,
        "patience": 5,
        "verbose": True,
    },
    'hyper':
        {
            'trials': 10,
            'parameters':
                [
                    {'name': 'batch_size', 'type': 'range', 'bounds': [50, 300], 'value_type': 'int'},
                    {'name': 'conv_layer_width', 'type': 'range', 'bounds': [1, 4], 'value_type': 'int'},
                    {'name': 'conv_layer_size', 'type': 'range', 'bounds': [5, 100], 'value_type': 'int'},
                    {'name': 'fingerprint_size', 'type': 'range', 'bounds': [30, 100], 'value_type': 'int'},
                    {'name': 'mlp_layer_size', 'type': 'range', 'bounds': [30, 100], 'value_type': 'int'},
                    {'name': 'lr', 'type': 'range', 'bounds': [0.001, 0.1], 'value_type': 'float'},
                ]
        }
}
  • 分類はtaskclassificationになる以外は回帰と同様である。
  • 入力ファイル種別はfile_typesdf(SDFの場合) かsmi (SMILESの場合)を指定し、入力ファイルパスはtrain_data_fileで指定する。
  • GPUを使いたい場合は、use_gpuをTrueにする。
  • check_pointで指定した内容で、モデルの保存が行われる。dirpathで保存場所を、 save_top_kで最も検証精度が良いものを何件保存するかを指定できる。またその際の指標は、merticsで設定し、その精度が大きければよいか小さければよいかは、minimizeで指定する。この指標は、以下のearly_stopping、ハイパーパラメータ探索の指標と共通である。
  • Early Stoppingをやりたい場合は、run.pyの引数に-esオプションを指定し、その詳細は設定ファイルのearly_stoppingで指定する。パラメータの意味は、こちらを参照のこと。
  • ハイパーパラメータ探索をやりたい場合は、run.pyの引数に-hyperオプションを指定し、試行回数や探索範囲は、設定ファイルのhyperで指定する。探索範囲の指定方法は、axのドキュメントを見てくれ(丸投げ)。なお、ハイパーパラメータ探索を指定しない場合は、model_params直下で指定したパラメータが使われる。

実行方法

通常の実行

実行は以下のように-configオプションに設定ファイルを指定する。
この場合、model_params直下で指定したパラメータを使って学習を行い、全エポック終了すると学習/検証の指標出力する。
なお、検証精度は最終エポックの出力であるため、途中で最も高かった検証精度を知りたい場合は、check_pointの指定で保存したモデルのファイル名に精度が刻印されているのでそれを見てほしい。ここはイケてない作りである。
また、-esオプションを指定するとEarly Stoppingが行われる。

$ python run.py -config config/regression_qiita.py
Global seed set to 42
GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name                     | Type       | Params
--------------------------------------------------------
0 | activation               | ReLU       | 0
1 | graph_conv_layer_list    | ModuleList | 16.6 K
2 | graph_pooling_layer_list | ModuleList | 3.1 K
3 | mlp_layer_list           | ModuleList | 5.2 K
4 | batch_norm_list          | ModuleList | 200
--------------------------------------------------------
25.2 K    Trainable params
0         Non-trainable params
25.2 K    Total params
Epoch 2: 100%|████████████████████████████| 210/210 [00:12<00:00, 17.14it/s, loss=0.856, v_num=99]
train loss=0.8300003409385681
train r2=0.27188658714294434
validation loss=1.0553514957427979
validation r2=0.33314740657806396

ハイパーパラメータ探索

バッチサイズ、GraphConvolution層の深さ(層の数)、Convoluation層のサイズ(全層で共通)、フィンガープリントサイズ、多層パーセプトロンの中間層(1層固定)のサイズ、学習率を指定してハイパーパラメータ探索をすることができる。回帰に設定できる評価指標はlossとr2、分類の場合にはloss, accuracy, aucrocとなっている。
以下のようにrun.pyに-hyperオプションをつけて実行する。
-esオプションを指定すると各試行でEarly Stoppingが行われる。なお、ハイパーパラメータ探索の場合、途中のモデルの保存は行われない

$ python run.py -config config/regression_hyper.py -hyper

Tensor Boardによる学習の可視化

実行ディレクトリの直下に"lightning_logs"というフォルダが生成される。
学習中にそのフォルダを指定して、Tensor Boardを起動することで学習状況を可視化することができる。

Tensor Boardをインストールし、学習中に別のターミナルから以下を実行する。

tensorboard --logdir lightning_logs/

ブラウザでhttp://localhost:6006/ を叩くと以下のように学習状況を見ることができる。

image.png

TensorBoardとの連携は、DeepChemだと途方に暮れる感じだったので、これができたのは嬉しい。気が付いたらPyTorch Lightning君が勝手にやってくれていたという感じで、感謝である。

#検証
###検証データ
ここでは既存のデータを使って正しく実装できているかを確認する。利用したデータは MoleculeNet から取得可能な回帰用データである Lipophilicity を利用する。http://moleculenet.ai/datasets-1 からダウンロード可能である。
http://moleculenet.ai/latest-results によると予測精度は以下の通りとなっており、論文の実装である graphconvreg は R2 で 0.66 程度の精度となっている。

image.png

###検証
今回、自作ライブラリを使って以下の条件で予測モデルの作成を行った。

  • バッチサイズ 80(固定)
  • Graph Convolution 層の深さ 1, 2, 3, 4で確認
  • Graph Convolution 層のサイズ 20, 50, 80で確認
  • フィンガープリントのサイズ 50(固定)
  • 多層パーセプトロンのサイズ 100(固定)
  • 学習率 0.01(固定)
  • Early Stopping 5回連続検証データの r2 の値が改善しない場合終了
  • 精度 検証データのr2が値が最も高かったエポックにおける検証データの r2 を採用

結果は以下の通りとなった。
image.png

###結果の考察

  • 検証の条件が全く同じではないため、MoleculeNetと単純に比較はできないが、同等の精度が得られている。
  • Graph Convolution の層の深さ3, Graph Convolution 層のサイズを20とした時に最高精度0.714をたたき出した。
  • 多少、上下はあるものの Graph Convolution の層を深くすると精度が向上する傾向にある。

#今後の展望
将来的には以下の機能の搭載を考えている。考えるだけでもワクワクしてくる。

  • 予測機能
  • 予測に影響のあった部分構造の可視化。これは論文の実装コードに書かれているので、それを参考に実装したい。
  • 多値分類機能
  • 原子の特徴やラベルを取り込めるようにし、分子の予測だけでなく原子の予測も行えるようにしたい。
  • 複数の目的変数を指定してのマルチタスク学習機能
  • 他手法の取り込み。分子内での原子の位置関係の考慮することで、より精度があげられるタスクがあると考えており、特にアノテーションやメッセージパッシングニューラルネットワーク(MPNN)に興味がある。

#おわりに

  • TensorFlowで書かれたライブラリであるDeepChemの解析には断念したが、PyTorchでは1から容易に実装できた。
  • PyTorch Lightningも非常に便利だった。Early Stopping、途中で最も検証制度の良かったモデルを保存する機能、Tensorbordとの連携などは全てこのライブラリのおかげで、何も考えることなく実現できた。
  • GPUとCPUで大幅な性能差がない。手元ではせいぜい2~3倍程度である。これはPyTorchのTensorだけなく、自前のforループやリスト操作による部分のオーバヘッドが原因と考えている。可能な限りPyTorchの演算に移植すればよいが、メンテナンス性、可読性が失われる可能性があり今後の課題である。
  • 実装した時は「シンプルすぎる!」と思ったが人に伝えるとなると中々大変で、この記事を書くだけで丸3日かかってしまった。この論文を書いた人は本当にすごいと思った。

#参考

36
24
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
36
24

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?