AlphaFold -Structure Moduleについて-
AlphaFoldのstructure moduleについて勉強したことをまとめるために記事を書きました.間違っていたら(僕がしょげてしまわないようなるべく優しく)教えてください.
記事中の各algorithmは論文のSupplementary fileから引用しています.
論文へのリンク
https://www.nature.com/articles/s41586-021-03819-2#Sec20
Structure module全体の流れ
Structure moduleではEvoformer部分で抽出したMSAのうち入力配列に相当するアミノ酸配列の特徴量 $s$とpair representation 特徴量 $z$ を入力として全原子の座標 $\mathbf{x}$ と残基ごとの予測信頼度スコアplDDTを出力する.
Structure moduleは8つの重みを共有したレイヤから構成されている.
各レイヤはアミノ酸配列の特徴量 $s$ および残基ごとに定義された座標系$T_i$(オブジェクト座標に相当,これを定義しておくと正解構造との誤差を計算する際にRMSDとは異なる,キラリティを認識できる手法FAPEでの評価ができるため嬉しい,また残基ごとにCA原子の向きが決まれば文献値をもとに結合角,結合長を用いて同一残基中の他の原子を表すことができ,考慮すべき自由度がねじれ角 torsion angle のみになって嬉しい)にプロットされた3D representation を更新する.なお $T_i$ は$T_i=(R_i,\mathbf{t_i}$)で定義され,各残基に対して定められた座標系をグローバル座標系(全体に対して定められた基準となる座標系,ワールド座標に相当)に対して重ね合わせる回転を表す回転行列 $R_i$ と並進移動を表すベクトル $\mathbf{t}_i$ の組となっている.すなわち
\begin{aligned}
\mathbf{x}_{global}&=T_i\circ\mathbf{x}_{local}\\
&=R_i\mathbf{x}_{local}+\mathbf{t}_i
\end{aligned}
が成り立つ(この計算では座標系が変換されただけで点の位置が変化したわけではないことに注意).
Structure moduleの最初で $T_i$ は $T_i=(\mathbf{I}, \mathbf{0})$ と初期化される.各レイヤ中で具体的には
- {$s_i$}をIPAにより更新する
- 更新された{$s_i$}を用いて$T_i$を更新する
- {$s_i$}を用いて側鎖の原子位置を予測する
- 誤差計算で用いられる$\mathcal{L}_{aux}$を計算する
以上の操作が行われる(4は学習時のみ).
IPA について
まずよくある普通のattention について
attentionわかるよって人は読み飛ばしてください,逆にちゃんと勉強したい人も他にいい記事があるのでそちらを読んでみてください.
attentionは深層学習で不可欠な手法で,AlphaFoldの実装中で言えばEvoformer部分にも使われている.attentionhはベクトルを要素にもつ入力系列$\mathbf{x}$を要素間の関係などの情報を取り込んだ対応するベクトル系列$\mathbf{y}$に変換する関数である.クエリ$Q$,キー$K$,値$V$,とし,クエリとキーの次元数を$d_k$としてattentionが行う計算は
$$
Attention(Q,K,V)=softmax(\frac{QK^T}{\sqrt{d_k}})V
$$
以上の式で表すことができる.これの意味するところを確認していく.まず入力されたクエリとキーとの類似度を内積($\mathbf{a}^T\mathbf{b}=||\mathbf{a}||||\mathbf{b}||\cos\theta$であり,方向が似ているベクトルであればcosの値が大きくなる)により計算し,次元数でスケールをする.内積の結果得られる類似度が足して1になるようsoftmax関数を適用する.その後キーに対応した値に類似度で重み付けし,値の荷重和を取ることでクエリに対する値を得,これを出力とする.
attentionは、クエリとキーが似ているかどうかで、どの要素の値を読み込むかどうか(すなわちどの要素に注意を払うか)を制御・決定している。学習の結果、その要素から値を読み込んだ方がよければ対応するクエリとキーは近づくように更新され、読み込まない方が良ければクエリとキーは離れるように更新される。こうして、どの要素を読み込むかどうか、データのルーティングが学習で自動決定される。
*正確にはEvoformer部分も含めて実際に使われているのはMulti-head Attentionというattentionを拡張したものです.行われている処理の本質は上に示したattentionなので気になる人は調べてみてください.
本題のIPAについて
IPAでは上記のattentionにアミノ酸配列由来の特徴量 $s$ を基本としてpair representation, および 変換 $T_i$ を取り入れる構成になっている.以下は論文のSupplementary Algorithm22から
IPAは特徴量 $s$ を立体構造の情報を加味して更新するが,この時にタンパク質の構造はglobal座標系の元で並進,回転を行っても全体として構造が変わらないということをattentionに知らせつつ,それぞれのlocalな座標系 $T_i, T_j$ にプロットされたクエリ $\mathbf{q}_i$ とキー $\mathbf{k}_j$ の関係性を $s$ の更新に使いたいという意図を叶えるように構成されている.具体的にはlocalな座標系$T_i$が絡む計算では変換$T_i$によりあらかじめglobalな座標系に変換してからL2ノルムをとる,逆変換 $T_i^{-1}$ を作用させることでglobalな座標系での操作に変えてあげるという方法を取っている.
以下AlphaFold論文のサプリ中の擬似コードをみていく
$$
\mathbf{q}_i^h,\mathbf{k}_i^h,\mathbf{v}_i^h=LinearNoBias(\mathbf{s}_i)
$$
$$
\mathbf{q}_i^{hp},\mathbf{k}_i^{hp}=LinearNoBias(\mathbf{s}_i)
$$
$$
\mathbf{v}_i^{hp}=LinearNoBias(\mathbf{s}_i)
$$
$$
b_{ij}^h = LinearNoBias(\mathbf{z}_{ij})
$$
1番目の式ではアミノ酸配列由来の特徴量$s_i$を線形変換しattentionの入力であるクエリ,キー,値を得ている.2番目,3番目の式は同様にアミノ酸配列由来の特徴量 $s_i$ から線形変換によりIPAの式中第3項で利用する局所座標上の点として$\mathbf{q}_i^{hp}$,$\mathbf{k}_i^{hp}$,$\mathbf{v}_i^{hp}$を得ている.4番目の式ではpair representation $\mathbf{z}_{ij}$ から次式の入力であるpair representation $b_{ij}$を線形変換により得ている.
a^h_{ij}=softmax_k(w_L(\frac{1}{\sqrt{c}}\mathbf{q}_i^{hT}\mathbf{k}_i^h+b_{ij}^h-\frac{\gamma^hw_C}{2}\sum_p\parallel T_i\circ\mathbf{q}_i^{hp}-T_j\circ\mathbf{k}_j^{hp}\parallel^2))
attentionの式中にみられた$softmax(\frac{1}{\sqrt{d}_k}\mathbf{Q}^T\mathbf{K})$部分に相当する計算式にpair representation $b_{ij}$, および 変換 $T_i$ に由来する項を導入している.第二項でpair representation由来の情報を利用し,第三項で局所座標をグローバルな座標系に変換し,$ \mathbf{q}_i, \mathbf{k}_j$間の距離の情報を取り入れている.これらをsoftmax関数に入れてクエリとキーの関係性を取り出している.最後に $a_{ij}$ と$\mathbf{z}_{ij}, \mathbf{v}_j$ との積をとり,連結し線形変換することで更新した $s$ を得ている.(global座標系における不変性の証明はサプリ中で丁寧に計算されているのでそちらをみてね)
なお一連の操作は概念図としては以下のようになる
(論文のSupplementary Fig.8から引用)
Backbone update について
次にIPAで更新された $s_i$ を用いてlocalな座標系 $T_i$ の更新を行う.localな座標系$T_i$ はそれをglobalな座標系に重ね合わせる変換として
$$
T_i=(R_i, \mathbf{t}_i)
$$
で表されるため,回転成分 $R_i$ と並進成分 $\mathbf{t}_i$ をそれぞれ更新する.三次元での回転を表す方法としては回転行列,ロドリゲスの回転公式,オイラー角,四元数といった4つの方法があるが,backbone updateでは特徴量 $s$ を四元数の各成分に線形変換し,単位四元数に正規化,単位四元数を回転行列に直すという手法を用いている.
四元数から回転行列へ変換するところの補足
一般に四元数は
$$
i^2=j^2=k^2=-1
$$
$$
ij=-ji=k
$$
$$
jk=-kj=i
$$
$$
ki=-ik=j
$$
を満たす虚数単位i, j, kを用いて
$$
q = a+bi+cj+dk
$$
と表すことができる.
この四元数qに対して絶対値と共軛四元数をそれぞれ以下のように定める
$$
\parallel q\parallel=\sqrt{a^2+b^2+c^2+d^2}
$$
$$
q^*=a-bi-cj-dk
$$
この時四元数rを$r=r_1i+r_2j+r_3k$とし四元数$r'$を
$$
r'=qrq^*
$$
とすると四元数i, j, kの性質から
\begin{aligned}
\begin{split}
r'&=(a+bi+cj+dk)(r_1i+r_2j+r_3k)(a-bi-cj-dk)\\
&=\left\{(a^2+b^2-c^2-d^2)r_1+2(bd-ad)r_2+2(ac+bd)r_3\right\}i\\
&\quad+\left\{2(ad+bc)r_1+(a^2-b^2+c^2-d^2)r_2+2(-ab+cd)r_3\right\}j
&\quad+\left\{2(bd-ac)r_1+2(cd+ab)r_2+(a^2-b^2-c^2+d^2)r_3\right\}k
\end{split}
\end{aligned}
と表せる.ここで$r'=r'_0+r'_1i+r'_2j+r'_3k$とおくと
$$
r'_0=0
$$
$$
r'_1=(a^2+b^2-c^2-d^2)r_1+2(bd-ad)r_2+2(ac+bd)r_3
$$
$$
r'_2=2(ad+bc)r_1+(a^2-b^2+c^2-d^2)r_2+2(-ab+cd)r_3
$$
$$
r'_3=2(bd-ac)r_1+2(cd+ab)r_2+(a^2-b^2-c^2+d^2)r_3
$$
のように書くことができて,これを行列の形式でかくと
\begin{aligned}
\begin{pmatrix}
r'_1 \\
r'_2 \\
r'_3
\end{pmatrix}
&=\begin{pmatrix}
a^2+b^2-c^2-d^2 & 2(bd-ad) & 2(ac+bd) \\
2(ad+bc) & a^2-b^2+c^2-d^2 & 2(-ab+cd) \\
2(bd-ac) & 2(cd+ab) & a^2-b^2-c^2+d^2
\end{pmatrix}
\begin{pmatrix}
r_1 \\
r_2 \\
r_3
\end{pmatrix}\\
&=R
\begin{pmatrix}
r_1 \\
r_2 \\
r_3
\end{pmatrix}
\end{aligned}
となり,$qrq^*$ は $r$ の虚数成分からなる3次元ベクトルに行列$R$をかけることに相当する.面倒くさくなってしまったので証明は省きますが,$\parallel q\parallel=1$ の時 $R$ は回転行列になります.
サプリ中のAlgorithm 23は上記の計算を行い$T_i$の更新を行っています.
全原子の座標の計算について
Structure module全体の流れの説明の中で全原子の座標の計算でやっていることを若干ネタバレしてしまったが,原子の座標は$T_i$ つまりCA原子について座標と向きが予測できているので文献値をもとに残基中の他の原子を生やす操作で求められる.
主鎖のCA原子の向きが決まっていることでCA原子と結合しているC,N,CB原子の座標は一意に定まる.また側鎖の原子の座標はねじれ角を次々定めていけば定めることができる.
AlphaFoldではこのねじれ角の依存性に着目して各残基の各原子を原子種,角の依存性および対称性で分類している.
(論文のSupplementary Table.2から引用)
ねじれ角も予測で求める.アミノ酸配列の特徴量 $s$ をResNet(残差接続と全結合層と活性化関数を組み合わせた深層学習のモデル)の入力として$\mathbf{\alpha}_i$を得る.ここで求めているのは各原子についてのねじれ角であるが,この$\mathbf{\alpha}_i$は直接の角度$\theta$ではなく平面上にプロットされた点として表され,正規化することで単位円上の点として表現される($\theta$で定義すると$[0,2\pi]$で不連続になるため扱いづらい,またこのあと回転行列を用いて各原子の座標を求める際に,平面上の点として表現することで三角関数を用いた計算を省くことができて嬉しいらしい).
残基ごとに定められたlocalな座標と予測ねじれ角,および文献値に由来する結合長,結合角に基づいて各原子をlocalな座標系にプロットしていく.これを実現するために残基ごとに定義されたlocalな座標系に対してさらに原子ごとに定義されねじれ角の軸がx軸と一致するようなlocalな座標系を定義する.表中の$X_1$ に属する原子を例に取ると,まず予測ねじれ角$\alpha$に基づいてCA-CB軸を回転させる.次にCB-CGがx軸となるように定義された座標系をCAが原点となるよう残基ごとに定義されたlocalな座標系に変換する.最後にこのlocalな座標系をglobalな座標系に変換する$T_i$を作用させることでglobalな座標系で原子の位置を定めることができる.$X_2, X_3, \dots$ に属する原子についても次々に次々に座標変換を適用していくことでglobalな座標系での座標を定めることができる.
ねじれ角に応じて原子を回転させるalgorithm 25 の回転行列Rは二次元平面の単位円上で定義された$\alpha$をそのまま利用できる形式になっている.
その他言及しなかった諸問題について
ここまでで予測構造がどのように作られるかみてきたが,深層学習では最終的に正解構造と比較して誤差を出す必要がある.誤差の計算方法について詳しくは誤差関数FAPEについての節に書いてある.正解構造に対するlocalな座標系の決め方はCA-Nベクトル,CA-Cベクトルに対してグラムシュミットの方法を使い正規直行座標を定め,直行な基底同士のcross積を計算することで座標系を定義している.
またアミノ酸残基にのなかには180°対称性をもつものがある(表中で四角で囲まれている原子).これらは正解構造の原子の名前づけを変更することで適切に誤差が計算できるように対応している.
原子間の結合長,結合角は文献値で固定化されているが,誤差を計算する予測構造に対してはamber力場で緩和することにより,物理的に正しそうな結合長,結合角になるよう計算される.
以上がAlphaFoldのstructure moduleでやっていることです,気が向いたらstructure module以外の部分についてもちゃんと勉強します.