Help us understand the problem. What is going on with this article?

MolGAN : An implicit generative model for small molecular graphs

More than 1 year has passed since last update.

はじめに

「機械学習の数理 Advent Calendar 2018」の18日目の記事です。

MolGAN : An implicit generative model for small molecular graphsについて読みましたので、纏めたいと思います。機械学習の数理というよりむしろ、手法の応用な所だと思いますがご容赦頂ければと思います。

SMILESとは

本論文の説明に当たって機械学習以外に1つだけ知識が必要となりますので説明します。WikipediaによるとSMILES (Simplified Molecular Input Line Entry System)は分子の化学構造をASCII符号の英数字で文字列構造の曖昧性のない表記方法である、となっています。

例えば、ニコチンは構造式だと下図になります。


ニコチン.PNG

化学式 : C10H14N2
Smiles : CN1CCC[C@H]1c2cccnc2

このグラフ構造を文字列で表したものがsmilesです。このような文字列表記にすることで、機械学習を用いることが容易になります。

ここでのグラフ構造ではノードが{C, N, H}の組み合わせであり、エッジは結合{単結合、二重結合}を表しています。この場合だとノードは化学式の添え字個数から26, エッジは構造から24となります。このようなグラフは隣接する化合物間の距離でカウントすることで、それぞれの距離での隣接行列(Adjacency Matrix)を求めることができます。

MolGAN概要

generative adversarial network (GANs) + Deep Deterministic Policy Gradient (DDPG, 強化学習)を用いて、目的の分子構造を自動生成していくようなアルゴリズムとなっています。MolGANの特徴は、有効な分子が生成されるところです。

MolGAN_fig1.PNG

この図は、Generatorで分子構造を生成し、DiscriminatorでデータセットとGeneratorを識別します。この時、Generatorで生成された分子はReward Networkで評価されるアルゴリズムとなっています。

分子構造(すなわちSMILES)を生成させるような手法は、様々な方法があります。下記の方法が良く用いられる方法です。

  • RNN (LSTMやGRU)
  • VAE
  • GAN

強化学習では、上記の生成モデルに加えてモンテカルロ木探索(MCTS)やActor Criticも組み合わせて用いられます。

各種理論

SMILESでも述べた通り、化学構造は無向グラフで表すことができます。グラフG, エッジE, ノードV、それぞれの原子ノードを$\upsilon_i \in V$とします。また、Vの属性としてT次元のone-hot vector $x_i$とすると、これは原子の種類(C、H、Nなど)を表します。各原子の結合はエッジ$(\upsilon_i, \upsilon_j) \in E$と表され、結合の種類$y \in {1, ..., Y}$に関連付けられます。

GANのような生成モデルを創薬のために用いるためには、合成可能かつ有用な活性、物性値を持った化合物を生成する必要がありますが、これには強化学習がよく用いられています。

Generative adversarial networks (GANs)

GANの詳しい説明は、Adventカレンダー17日目のkzkadcさんのGANと損失関数の計算についてまとめたを参照いただくとして、ここでは軽く説明します。

  • Generative model $G_\theta$ : 事前分布から新しいデータを生成するモデル
  • Discriminative model $D_\phi$ : $G_\theta$の分布から生成されたものかサンプルから来たものかを識別するモデル

これらの目的関数は2 playersにおけるmini maxゲームです。Stochastic Gradient Descent (SGD)で同時に学習できます。

\min_\theta \max_\phi \mathbb{E}_{x\sim p_{data} (x)} [ \log D_\phi (x)] + \mathbb{E}_{x \sim p_z (z)} [ \log \big(1 - D_\phi  (G_\theta(z)) \big)] 

MolGANでは、ミニバッチdiscriminationと収束性の高いimporved WGANを使用してます。

WGAN

WGANs (Arjovsky et al., 2017)は、2つの確率分布間で定義されたEarth Mover距離 (これはWasserstein距離の1次元の場合)の近似を最小化させるようなモデルです。1次のWasserstein距離の最小化は、Kantrovich-Rubinsteinの双対表現により以下に示す最大化問題を解くことで近似できます。

D_{W} [ p || q] = \frac{1}{K} \sup_{||f||_L < K} \mathbb{E}_{x \sim p(x)} [f(x)] - \mathbb{E}_{x \sim q(x)} [f(x)]

ここで、WGANの例では、pは経験分布でqがgenerator分布です。Lipscitz関数は、関数の連続微分可能性より強い形の連続性を有するような関数です。グラフ上の任意の2点を結ぶ直線の傾きの絶対値はある定数Kを超えないような関数です。すなわち、この上界(supremum)がK-Lipschitz定数です。

d_Y (f(x_1), f(x_2)) \leq K d_X (x_1, x_2) \hspace{1cm}(\forall x_1, x_2 \in X)

さらに連続性を一般化させたものはヘルダー関数です。

WGANでWasserstein距離が用いられているのはJensen-Shannon divergenceよりも最適化しやすいためです。パラメタライズされた関数$\{f_w\}_{w \in \mathcal{W}}$に対して次のよう定式化します。

\max_{w \in \mathcal{W}} \mathbb{E}_{x \sim p_{r}(x)} [f_w(x)] - \mathbb{E}_{z \sim p_{\theta}(z)} [f_w (g_\theta (z))]

このリプシッツ関数$f_w$をニューラルネットワークで近似していくことになります。これがDiscriminatorであり、Wasserstein距離の計算を表しています。$g_\theta$がGeneratorです。Wasserstein距離を最小化させるようなパラメータ$\theta$を求めるためにWasserstein距離を$\theta$で微分します。

\begin{align}
\nabla_{\boldsymbol \theta}W( P_r, P_{\theta}) &= -\mathbb{E}_{z \sim p(z)}\left[\nabla_{\boldsymbol \theta} f(g_{\boldsymbol \theta}(z))\right] \\
&\simeq \frac{1}{M}\sum_{m=1}^M \nabla_{\boldsymbol \theta}f\left(g_{\boldsymbol \theta}(z^{(m)})\right)
\end{align}

Mはバッチ数です。また$f_w$をリプシッツ関数にするため、w ← clip(w, -0.01, 0.01)とさせます。$(w, \theta)$いずれの学習にもRMSPropを用います。

improved WGAN(Gulrajani et al. (2017))では、gradient penaltyとして1-リプシッツ連続の緩い束縛条件を導入し、WGANからgradient clippingを用いてさらに改善しています。WGAN-gpとも呼ばれていますね。generatorに関するlossはWGANと同じまま、discriminatorに関する誤差を下式のように修正しています。

L(x^{(i)} , g_\theta (z^{(i)}; w)) = - f_w (x^{(i)}) + f_w (g_\theta (z^{(i)})) + \alpha (|| \nabla_\hat{x}^{(i)} f_w (\hat{x}^{(i)}) || - 1)^2

第1項+第2項がオリジナルのWGANであり、第3項目の勾配ペナルティ(gp)正則化が特徴です。

ここで、$\alpha$はハイパーパラメータ、$\hat{x}^{(i)}$は$z^{(i)} \sim p_z(z)$における$x^{(i)} \sim p_{data}(x)$と$g_\theta (z^{(i)})$間のサンプリングされた線形和となります。

\hat{x}^{(i)} = \varepsilon x^{(i)} + (1-\varepsilon) g_\theta (z^{(i)}) \\
\varepsilon \sim \mathcal{U}(0,1)

Deep Deterministic policy gradient (DDPG)

強化学習を行うためのアルゴリズムは多々ありますが、DDPGは方策勾配法の一つです。より良い探索のために確率的な行動を取る方策を用いる一方で、決定的な目的方策を推定します。これにより学習がより容易となります。

方策はパラメータ$\theta$をもつような

\pi_\theta (s) = p_{\theta} (a|s)

で定義されます。これは状態$s$での行動$a$を選択する条件付き確率分布を表しています。一方、deterministic policyは下式のようになります。

\mu_\theta(s) = a

これは状態$s$で行動$a$を出力するという決定的方策となっています。

DDPGの概略図を示します。
DDPGャ.PNG

https://medium.com/@deshpandeshrinath/how-to-train-your-cheetah-with-deep-reinforcement-learning-14855518f916

Actor networkでは、Critic networkによって更新された行動価値関数$Q_w(s, a)$に基づき、方策$p_\theta (a_t | s_t)$の$\theta$を更新し、この方策から行動を決定します。このときは交差エントロピーを最小化(尤度の最大化)させるように学習させます。Critic networkではActorの行動を批判します。すなわち、状態$s_t$, 行動$a_t$での報酬予測値(行動価値関数の$Q(s_t, a_t)$値)と即時報酬の2乗誤差(TD誤差)を計算し、この勾配を用いてQ関数のパラメータ$w$を更新します。それぞれ交互にactorとcriticを学習させモデルを更新していくものです。式で書くと下式のようになります。

\nabla_\theta = \alpha \nabla_\theta  \log{ \pi_\theta (a_t | s_t)} \hat{Q}_w(s_{t},  a_t) \\
\nabla_w = \beta \big(R(s_t, a_t) + \gamma \hat{Q}_w (s_{t+1}, a_{t+1}) - \hat{Q}_w (s_t, a_t)\big) \nabla_{w} \hat{Q}_w (s_t, a_t)

1式目は方策勾配定理によって導かれたものです。2式目はTD Learningから導かれた式です。

DDPGの特徴

  1. Experience Replay
    DQNでも用いられる方法ですが、一般に学習を進めていくと、価値関数や方策関数の値は同じような値をとり相関が高くなってしまいます。真のQ関数(Critic)の近似値のバリアンスが大きくなり、学習が進まなくなるため、学習中にエージェントの経験をバッファーに保存し、シャッフルし用いる工夫が施されています。

  2. Target network
    直接、TD誤差から得られた勾配と共にactorとcriticのニューラルネットワークの重みを更新すると発散するか学習が進みません。そのためTD誤差計算のターゲットを生成するネットワークを用いることで学習の安定性が増すことが知られています。TD targetを$y_i$とすると、critic networkのloss functionは下式となります。これを最小化することでcritic networkを更新します。

y_i =  r_i + \gamma Q' (s_{i+1}, \mu' (s_{i+1} | \theta^{\mu'}) | \theta^{Q'}) \\
L = \frac{1}{N} \sum_i (y_i - Q(s_i, a_i| \theta^Q))^2

ここで、Nはリプレイバッファからサンプリングされたミニバッチのサイズです。TD誤差の予測値である$y_i$(TD target)は、「即時報酬$r_i$」と「critic networkから計算される報酬予測値(Q'値)」に時間割引率$\gamma$を掛けた和で計算されます。このQ'値を求めるときは行動$a_{t+1}$の代わりに決定的方策$\mu'$が用いられているのが特徴です。また、critic networkの損失関数Lは先ほどのTD target $y_i$と行動$a_i$での$Q(s_i, a_i | \theta^Q)$値からMSEを計算し、最小化することで$\theta$を更新します。

一方、Actor networkではDeterministic Policy Gradientからのサンプリングを用いて重みを更新します。確率的方策勾配$\nabla_\theta \mu (a | s, \theta)$は方策のパフォーマンスの勾配であり、deterministic policy gradientはMarkov Decision Processでは下式となります。

\nabla_\theta \mu (a | s, \theta)  \approx  E_{\mu'} [ \nabla_{a} Q(s, a | \theta^Q) |_{s=s_t, a=\mu (s_t)}  \nabla_{\theta^\mu} \mu (s | \theta^{\mu} ) |_{s=s_t} ] 

期待値の方策項はactionに対する分布でないです。必要なのは、ミニバッチに対する平均、パラメータに関するactor networkの結果の勾配を乗じた行動に対するcritic networkの結果の勾配だけです。

一方、本論文では、方策はgenerator $G_\theta$であり、インプットとしてサンプル$z$をとっています。これは強化学習での状態$s$に対応しています。アウトプットは分子グラフであり、行動a = Gに対応しています。また、Experience replay, target networksは用いられていません。そのため、即時報酬を予測できる学習可能かつ微分可能な報酬関数$R_\psi$を導入しています。これは分子の合成可能なスコアによって定義された平均2乗誤差関数です。

モデル

MolGANのアーキテクチャは主に3つから成っています。

  • generator $ G_{\theta}$
  • discriminator $ D_{\phi}$,
  • reward network $R_{\psi}$

分子グラフ

無向グラフの構造Gを次のように定義します。原子情報を表すノードをT次元のワンホットベクトル$x_i$とし、その特徴行列は$X = [x_1, ..., x_n]^T \in \mathbb{R}^{N×T}$です。原子の結合を表すための隣接テンソルは$A \in \mathbb{R}^{N×N×Y}$です。ここで、$A_{ij} \in \mathbb{R}^Y$でi, j間のエッジの種類({単結合、二重結合、三重結合、結合無し})のワンホットベクトルです。

MolGAN_fig2.PNG

generatorは事前分布から隣接テンソルA(エッジ)とアノテーション行列X(ノード)を生成します。次にカテゴリカルもしくはgumbel softmaxのサンプリングによりスパースで離散の$\tilde{A}$と$\tilde{X}$がそれぞれ得られます。これらのグラフが化合物情報に相当します。これらは最終的にRelational-GCNを使うことで、ノードの並び方によらず特徴量を抽出できるとともにdiscriminatorとreward networkの入力に使用されます。

詳細

generatorは事前分布からサンプリングされ、分子を表すannotated graph Gが生成されます。ノードとエッジは、それぞれ原子と結合数でアノテーションされており、discriminatorはデータセットとGから生成したサンプルを識別します。$G_\theta, D_\phi$はimproved WGANを用いて学習され、generatorは経験分布と合うように学習することで最終的な結果として有効な分子が出力されるアルゴリズムとなっています。

次にreward networkですが、サンプルの価値関数を近似するために用いられ、強化学習を用いて微分不可能な指標に対し分子構造生成を最適化します。データセットや$G$のサンプルは$\mathcal{R}$のインプットですが、discriminatorと異なりスコア(生成された分子の水に対する溶解性など)を与えられます。reword networkはRDKitを用いて得られたスコアから学習します。有効でない分子の場合は報酬は与えられないようになっています(zero rewards)。

DiscriminatorはWGANのlossと強化学習のlossの線形和を用いて学習しています。

L(\theta) = \lambda L_{WGAN} + (1-\lambda) L_{RL}

ここで$\lambda \in [0,1 ]$はGANと強化学習での生成モデルに対する2つコンポーネント間のトレードオフ(目的特性を重視するか有効な分子生成を重視するか)を調整するハイパーパラメータとなっています。$\lambda=0$なら完全な強化学習であり、$\lambda=1$ならGANによる生成モデルです。

Generator

$G_\phi (z)$は$z \sim \mathcal{N}(0,I), z \in \mathbb{R}^D$からサンプリングされたD次元のベクトルであり、アウトプットはグラフです。

Discriminatorとreward network

Discriminatorのreward networkのインプットはグラフ構造であり、アウトプットはスカラーです。ネットワークは同じアーキテクチャですがパラメータは共有しません。graph convolution層でグラフ隣接テンソルAを用いてノードのシグナルXを畳込みます。複数のエッジに対応するgraph convolution networkのRelational-GCN (Schlichtkrull et al., 2017)のモデルとしています。これはKnowledge baseのグラフであるので、すべてのレイヤーでノードの特徴表現は下式で畳込/逆伝播させることができます。

h_i^{'(l+1)} = f_s^{(l)} (h_i^{'l} , x_i) + \sum_{j=1}^N \sum_{y=1}^Y \frac{A_{ijy}}{|\mathcal{N_i}|}f_y^{(l)}(h_j^{(l)}, x_j) \\
h_i^{(l+1)} = tanh(h_i^{'(l+1)})

ここで、$h_i^{(l)}$はレイヤー$l$でのノード$i$のシグナルです。$f_{s}^{(l)}$はレイヤー間の自己結合のようにふるまう線形変換関数です。$f_{y}^{l}$はエッジに特異的なアフィン変換です。$N_i$はノードiの隣接するセットになります。

graphの畳み込みを通じて各層は伝搬した後、ノードをベクターレベルのグラフ表現に埋め込むことができます。これは、Gated Graph Sequence Neural Network (GGS-NN)に従い行います。Readoutとも呼ばれます。

h'_g = \sum_{v in V} \sigma(i(h_v^L, x_v)) \odot \tanh(j(h_v^L, x_v)) \\
h_g = \tanh (h'_g) \\
\sigma = \frac{1}{1+exp(-x)}

RGCNはMessage Passing Neural Network(MPNN)の特殊な例であると考えることができます。

また、さらに一般化されたフレームワークはDeepMindのgraph_netsです。

tensorflowでは次のようになります。MPNNと対応できるように少し変更しています。

import tensorflow as tf


def mpnn_rgcn(inputs, units, training, activation, dropout_rate=0.,):
    adjacency_tensor, hidden_tensor, node_tensor = inputs
    graph_convolution_units, auxiliary_units = units

    # message passing
    with tf.variable_scope('graph_convolutions'):
        for unit in graph_convolution_units:
            adj = tf.transpose(adjacency_tensor[:, :, :, 1:], (0, 3, 1, 2))
            message = tf.concat((hidden_tensor, node_tensor), -1) if hidden_tensor is not None else node_tensor
            hidden_tensor = tf.stack([tf.layers.dense(inputs=message, units=unit) for _ in range(adj.shape[1])], 1)
            hidden_tensor = tf.matmul(adj, hidden_tensor)
            hidden_tensor = tf.reduce_sum(hidden_tensor, 1) + tf.layers.dense(inputs=message, units=units)
            hidden_tensor = activation(hidden_tensor) if activation is not None else hidden_tensor
            hidden_tensor = tf.layers.dropout(hidden_tensor, dropout_rate, training=training)

    # readout
    _, hidden_tensor0, node_tensor = inputs
    with tf.variable_scope('graph_aggregation'):
        if hidden_tensor0 is not None:
            message = tf.concat((hidden_tensor, hidden_tensor0, node_tensor), -1)
        else:
            message = tf.concat((hidden_tensor, node_tensor), -1)

        i = tf.layers.dense(message, units=auxiliary_units, activation=tf.nn.sigmoid)
        j = tf.layers.dense(message, units=auxiliary_units, activation=activation)
        output = tf.reduce_sum(i * j, 1)
        output = activation(output) if activation is not None else output
        output = tf.layers.dropout(output, dropout_rate, training=training)

    return output

実験条件

Generator architecture

  • 最大ノード N = 9,
  • 原子タイプ T = 5 (C, O, N, Fと one padding), 
  • 結合タイプ Y = 5 (単, 二重, 三重結合, リング, 結合無し)
  • generatorは正規分布$\mathcal{N}(0,I)$から32次元のベクトルをサンプリング。
  • 3-layer MLP of [128, 256, 512]
  • 活性化関数$\tanh$

Discriminator and reward network architecture

Generatorの出力は512次元になるためノード、エッジはそれぞれ5*9 =45次元, 5*9*9=405次元にした後(5, 9), (5, 9, 9)にreshapeします。また、隣接テンソルは対称性があるので平均化させておきます。このグラフ構造をRelational GCN encoder with 2 layers [64, 32]で畳み込み計算します。その後128次元のグラフ表現にし、2 layer MLP [128, 1] with tanhで計算しています。reward networkではsigmoid functionを用いています。

評価指標

Samanta et al (2018) : validity , novelty, and uniqueness.

  • Validity : valid数と全ての生成された分子数の比
  • Novelty : データセットでない有効なサンプルセットと有効な全サンプル数の比
  • Uniqueness : ユニークサンプルと有効サンプルの比

物性予測の場合、独自の予測器を作成しないといけないかもしれませんね。

結果

実際の論文のソースコードはgithubに公開されています。実際に実装してみたかったのですが、時間が足りませんでした・・・10 epoch試した結果を示します。

2018-12-17 20:38:53 Validation --> {'NP score': 0.926981851250385,
 'QED score': 0.49032351626089926,
 'SA score': 0.3339037990758661,
 'diversity score': 0.5755544722836816,
 'drugcandidate score': 0.4007257305767061,
 'la': 1.0,
 'logP score': 0.3475431831460153,
 'loss D': -114.99824,
 'loss G': 64.350525,
 'loss RL': -0.6873395,
 'loss V': 0.6542503,
 'novel score': 72.6204128440367,
 'unique score': 5.848623853211009,
 'valid score': 69.760000705719}

30 min位計算してみましたが、論文で出ているようなvalidスコアには全然なりませんでした。もう少し試したみたいですね。

MolGANの問題

強化学習をしているとmode collapseが発生します。また、正しいグラフはあまり生成されません。生成モデルから生成したグラフの隣接行列が正しいものでないため、2分子が生成されてしまいました。

グラフの学習方法を工夫する必要があるかもしれません。

おわりに

MolGANの要素技術についてまとめました。
生成モデル(WGANs)、強化学習(DDPG)、グラフ畳み込み(relational-GCN)と様々な方法が用いられています。こういった機械学習の方法を組み合わせ、活性の高い医薬品や新規化合物が合成できたら良いですね。

要素技術の実装部分があまり追えていないので、こちらも実装していきたいのと強化学習をどんどん試していきたいと思います。

Reference

tsugar
リサーチエンジニア。ケモ(マテリアルズ)インフォマティクスの研究をしています。業務外では機械学習を用いたシステムトレード(EA)に関心があります。
https://udnp.hatenablog.com/
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away