LoginSignup
5
2

1. AlphaFold2をコードで理解する

AlphaFold2(以下AlphaFold)とはGoogle DeepMindが開発したタンパク質構造予測モデルで、アミノ酸配列からタンパク質の立体構造を予測します。2021年に発表されたHighly accurate protein structure prediction with AlphaFoldでアルゴリズムが説明されていて、ソースコードも公開されています。

アルゴリズムの概要を理解することを目的とした記事はいくつか既にあるので、この記事ではPyTorchで各ブロックを簡易的に実装し、どのような処理が行われているかをコードベースで理解することを目的とします。

1.2 AlphaFoldアルゴリズム概要

AlphaFoldではMSA(Multiple Sequene Alignmet)表現という入力アミノ酸配列に似ている配列を探してきて並べたものをもとにした特徴量と、ペア表現という入力配列の各残基間の関係性を表現するような特徴量の2つを使って特徴抽出します。

MSA表現、ペア表現をEvoformerブロックと呼ばれるモジュールに通して洗練し、その結果をStructure Moduleと呼ばれるモジュールに通して立体構造の座標情報を出力します。最終的に5種類の誤差関数を使って正解となる立体構造との誤差を最小化するように学習を行います。
スクリーンショット 2023-12-23 183559.png

1.3 実装内容

あくまで理解を目的としているため、本物のAlphaFoldのアルゴリズムからは以下の点は省略して実装しました。

  • MSA取得部分の処理
  • 学習時MSA表現にマスクをかける処理
  • Extra MSA処理
  • 出力結果を再度ネットワークに入力し直すRecycling処理
  • 側鎖の座標予測
  • 誤差関数

PyTorchによる実装は既にあるため、そちらも参考にしつつ本記事の実装に上記を追加することで全体を自分で実装することは可能だと思います。

アルゴリズムは大きくわけると以下3ブロックになるのでそれぞれについて説名していきます。

  1. MSA表現、ペア表現の埋め込み
  2. Evoformerブロック
  3. Struture Module

コードはgithubにまとめてあります

2. Input Embedding

入力配列とMSA表現からEvoformerブロックの入力となるMSA表現、ペア表現を作成します。Recycling処理やテンプレートなどからの入力も同時にありますが簡単のため本実装では省略します。
スクリーンショット 2023-12-23 190812.png
スクリーンショット 2023-12-23 190716.png
スクリーンショット 2023-12-23 190741.png

最初に必要なライブラリをインポートしておきます

import torch
import torch.nn as nn
from dataclasses import dataclass
from einops import rearrange, repeat

入力埋め込みモジュールです。

class InputEmbedder(nn.Module):
  """
  Algorithm3: Embeddings for initial representations
  """
  def __init__(self, channel_size_feat, channel_size_msa, channel_size_pair, sequence_size, sequence_num):
    super().__init__()
    self.to_a = nn.Linear(channel_size_feat, channel_size_pair)
    self.to_b = nn.Linear(channel_size_feat, channel_size_pair)
    self.to_m1 = nn.Linear(channel_size_feat, channel_size_msa)
    self.to_m2 = nn.Linear(channel_size_feat, channel_size_msa)
    self.pos_proj = nn.Linear(65, channel_size_pair)
    self.sequence_size = sequence_size
    self.sequence_num = sequence_num

  def forward(self, target_feat, residue_index, msa_feat):
    a = self.to_a(target_feat)
    b = self.to_b(target_feat)
    a = repeat(a, 'b s c -> b r s c', r=self.sequence_size)
    b = repeat(b, 'b s c -> b s r c', r=self.sequence_size)
    z = a + b
    z += self.relpos(residue_index)

    target_feat = repeat(target_feat, 'b r c -> b n r c', n=self.sequence_num)
    m = self.to_m1(msa_feat) + self.to_m2(target_feat)
    return m, z

  def relpos(self, residue_index, vbins=(torch.arange(65)-32)):
    """
    Algorithm4: Relative position enccoding
    """
    d_left = repeat(residue_index, 'b r -> b i r', i=self.sequence_size)
    d_right = repeat(residue_index, 'b r -> b r i', i=self.sequence_size)
    d = d_left - d_right
    p = self.one_hot(d, vbins)
    p = p.to(torch.float32)
    p = self.pos_proj(p)
    return p

  def one_hot(self, x, vbins):
    """
    Algorithm5: One-hot ecoding with nearest bin
    """
    bin_size = vbins.shape[0]
    b, r1, r2 = x.shape
    x = repeat(x, 'b r1 r2 -> b r1 r2 d', d=bin_size)
    vbins = repeat(vbins, 'd -> b r1 r2 d', b=b, r1=r1, r2=r2)
    index = torch.argmin(torch.abs(x - vbins), dim=-1)
    index = index.flatten()

    p = torch.zeros_like(x, dtype=x.dtype)
    p = rearrange(p, 'b r1 r2 v -> (b r1 r2) v')
    p[index] = 1
    p = rearrange(p, '(b r1 r2) v -> b r1 r2 v', r1=self.sequence_size, r2=self.sequence_size)
    return p

試しにこのブロックに入力してみるため、まず定数を整理するため以下のconfigを作ります。

@dataclass(frozen=True)
class Config:
  batch_size:int = 2 #バッチサイズ
  sequence_size:int = 100 #アミノ酸残基数(数字は適当)
  sequence_num:int = 7 #MSAで取ってくる配列数+1(+1は入力配列の分)
  channel_size_feat: int = 23 #入力MSAのチャンネルサイズ(アミノ酸20種類+不明(1)+ギャップ(1)+マスク(1)=23)
  channel_size_msa:int = 10 #MSA表現チャンネルサイズ(数字は適当)
  channel_size_pair:int = 10 #ペア表現チャンネルサイズ(数字は適当)

入力配列に対するMSAは取得してきたものとして、以下のように疑似的に入力配列、MSA、残基インデックス配列を用意しておきます。

config = Config()
residue_index = repeat(torch.arange(config.sequence_size), 's -> b s', b=config.batch_size)
msa_feat = torch.randn(
    config.batch_size,
    config.sequence_num,
    config.sequence_size,
    config.channel_size_feat
)
target_feat = torch.randn(
    config.batch_size,
    config.sequence_size,
    config.channel_size_feat
)

これらをInputEmbedderに入力するとMSA表現、ペア表現に対する最初の埋め込みを得ます。

embed = InputEmbedder(config.channel_size_feat, config.channel_size_msa, config.channel_size_pair, config.sequence_size, config.sequence_num)
print(f"{target_feat.shape=}")
print(f"{residue_index.shape=}")
print(f"{msa_feat.shape=}")
print()
msa_repr, pair_repr = embed(target_feat, residue_index, msa_feat)
print(f"{msa_repr.shape=}")
print(f"{pair_repr.shape=}")
#出力結果
target_feat.shape=torch.Size([2, 100, 23])
residue_index.shape=torch.Size([2, 100])
msa_feat.shape=torch.Size([2, 7, 100, 23])

msa_repr.shape=torch.Size([2, 7, 100, 10])
pair_repr.shape=torch.Size([2, 100, 100, 10])

3. Evoformer

Evoformerブロックは下図のようにMSA表現を処理する部分(図の上段)とペア表現を処理する部分(図の下段)に分かれており、それぞれについてみていきます。
スクリーンショット 2023-12-23 193514.png

3.1 Process for MSA representation

3.1.1 MSARowWiseAttentionWithPairBias

まずMSA表現の行方向にアテンションをとって特徴抽出します。
スクリーンショット 2023-12-23 194210.png
スクリーンショット 2023-12-23 194225.png

class MSARowWiseAttentionWithPairBias(nn.Module):
  """
  Algorithm7: MSA row-wise gated self-attetion with pair bias
  """
  def __init__(self, sequence_num, sequence_size, channel_size_msa, channel_size_pair, hidden_channel=32, nhead=8):
    super().__init__()
    self.layernorm = nn.LayerNorm(sequence_size)
    self.to_q = nn.Linear(channel_size_msa, nhead*hidden_channel, bias=False)
    self.to_k = nn.Linear(channel_size_msa, nhead*hidden_channel, bias=False)
    self.to_v = nn.Linear(channel_size_msa, nhead*hidden_channel, bias=False)
    self.to_bias = nn.Linear(channel_size_pair, nhead)
    self.to_g = nn.Linear(channel_size_msa, nhead*hidden_channel)
    self.last_layer = nn.Linear(nhead*hidden_channel, channel_size_msa)

    self.nhead = nhead
    self.hidden_channel = hidden_channel
    self.sequence_size = sequence_size

    self.sigmoid = nn.Sigmoid()
    self.softmax = nn.Softmax(dim=3)
    self.scale = 1 / (hidden_channel ** 0.5)

  def forward(self, msa_repr, pair_repr):
    #Input Projecction
    x = rearrange(msa_repr, 'b s r c -> b c s r')
    x = self.layernorm(x)
    x = rearrange(x, 'b c s r -> b s r c')
    tbatch = x.shape[1]
    x = rearrange(x, 'b h w c -> (b h) w c') #here should be changed in column-wise attention
    q = self.to_q(x)
    k = self.to_k(x)
    v = self.to_v(x)
    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.nhead), (q, k, v))

    bias = self.to_bias(pair_repr)
    bias = rearrange(bias, 'b i j h -> b h i j')
    bias = repeat(bias, 'b i j h -> (a b) i j h', a=tbatch)

    g = self.sigmoid(self.to_g(x))
    g = rearrange(g, 'b i (h d) -> b h i d', d = self.hidden_channel)

    #Attention
    dots = torch.einsum('b h i d, b h j d -> b h i j', q, k)
    attn = self.softmax(self.scale * dots + bias)
    out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
    out = g * out

    #Output projetion
    out = rearrange(out, '(b h) n w d -> b h w n d', h=tbatch)
    out = torch.concat([out[:,:,:,i,:] for i in range(out.shape[3])], dim=-1) #out.shape = (b,h,w,n*d)
    out = self.last_layer(out) #out.shape = (b,h,w,c)

    return out

3.1.2 MSAColumnWiseAttention

次にMSA表現の列方向にアテンションをとって特徴抽出します。
スクリーンショット 2023-12-23 194525.png
スクリーンショット 2023-12-23 194541.png

class MSAColumnWiseAttention(nn.Module):
  """
  Algorithm8: MSA column-wise gated self-attention
  """
  def __init__(self, sequence_num, sequence_size, channel_size_msa, channel_size_pair, hidden_channel=32, nhead=8):
    super().__init__()
    self.layernorm = nn.LayerNorm(sequence_num)
    self.to_q = nn.Linear(channel_size_msa, nhead*hidden_channel, bias=False)
    self.to_k = nn.Linear(channel_size_msa, nhead*hidden_channel, bias=False)
    self.to_v = nn.Linear(channel_size_msa, nhead*hidden_channel, bias=False)
    self.to_bias = nn.Linear(channel_size_pair, nhead)
    self.to_g = nn.Linear(channel_size_msa, nhead*hidden_channel)
    self.last_layer = nn.Linear(nhead*hidden_channel, channel_size_msa)

    self.nhead = nhead
    self.hidden_channel = hidden_channel
    self.sequence_size = sequence_size

    self.sigmoid = nn.Sigmoid()
    self.softmax = nn.Softmax(dim=3)
    self.scale = 1 / (hidden_channel ** 0.5)

  def forward(self, msa_repr):
    #Input Projecction
    msa_repr = rearrange(msa_repr, 'b s r c -> b r s c')
    x = rearrange(msa_repr, 'b s r c -> b c s r')
    x = self.layernorm(x)
    x = rearrange(x, 'b c s r -> b s r c')
    tbatch = x.shape[1]
    x = rearrange(x, 'b h w c -> (b h) w c') #here should be changed in column-wise attention
    q = self.to_q(x)
    k = self.to_k(x)
    v = self.to_v(x)
    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.nhead), (q, k, v))

    g = self.sigmoid(self.to_g(x))
    g = rearrange(g, 'b i (h d) -> b h i d', d = self.hidden_channel)

    #Attention
    dots = torch.einsum('b h i d, b h j d -> b h i j', q, k)
    attn = self.softmax(self.scale * dots)
    out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
    out = g * out

    #Output projetion
    out = rearrange(out, '(b h) n w d -> b h w n d', h=tbatch)
    out = torch.concat([out[:,:,:,i,:] for i in range(out.shape[3])], dim=-1) #out.shape = (b,h,w,n*d)
    out = self.last_layer(out) #out.shape = (b,h,w,c)
    out = rearrange(out, 'b h w c -> b w h c')

    return out

3.1.3 Transition

最後に全結合層に通します。
スクリーンショット 2023-12-23 194730.png
スクリーンショット 2023-12-23 194745.png

class Trasition(nn.Module):
  """
  Algorithm9: Transition layer in the MSA stack
  """
  def __init__(self, channel_size_msa, n=4):
    super().__init__()
    self.norm = nn.LayerNorm(channel_size_msa)
    self.linear1 = nn.Linear(channel_size_msa, n*channel_size_msa)
    self.relu = nn.ReLU()
    self.linear2 = nn.Linear(n*channel_size_msa, channel_size_msa)

  def forward(self, msa_repr):
    x = self.norm(msa_repr)
    x = self.linear1(x)
    x = self.relu(x)
    x = self.linear2(x)

    return x

3.2 Process for Pair represetation

ペア表現の方の処理は大きく以下お3つが必要で、下2つはそれぞれoutgoingとincoming、startingとendingの2種類ありますが似ているので2種類ずつまとめて実装します。

  • OuterProductMean
  • TriangularMultiplicativeUpdate
  • TriangularGatedSelfAttention

3.2.1 OuterProductMean

MSA表現からの情報を受け取るため、MSA表現各列同士のテンソル積をとってペア表現の形状に変形します。
スクリーンショット 2023-12-23 195327.png
スクリーンショット 2023-12-23 195339.png

class OuterProductMean(nn.Module):
  """
  Algorithm10: Outer product mean
  """
  def __init__(self, sequence_num, sequence_size, channel_size_msa, channel_size_pair, hidden_channel=32):
    super().__init__()
    self.norm = nn.LayerNorm(channel_size_msa)
    self.left_proj = nn.Linear(channel_size_msa, hidden_channel)
    self.right_proj = nn.Linear(channel_size_msa, hidden_channel)
    self.out_proj = nn.Linear(hidden_channel**2, channel_size_pair)

    self.sequence_size = sequence_size
  def forward(self, msa_repr):
    x = rearrange(msa_repr, 'b s r c -> (b r) s c')
    x = self.norm(x)
    left = self.left_proj(x) #left.shape = (b,s,d)
    right = self.right_proj(x)
    left = repeat(left, '(b r) s d -> b r s d', r=self.sequence_size)
    right = repeat(right, '(b r) s d -> b r s d', r=self.sequence_size)
    outer = rearrange(left, 'b r s d -> b r () s d') * rearrange(right, 'b r s d -> b () r s d')
    outer = rearrange(outer, 'b a c s d -> b a c s d ()') * rearrange(outer, 'b a c s d -> b a c s () d')
    outer = outer.mean(dim=3)
    outer = rearrange(outer, 'b i j d e -> b i j (d e)')
    z = self.out_proj(outer)
    return z

3.2.2 TriangularMultiplicativeUpdate

各残基間の関係性を学習するため、ペア表現の2ノードijから伸びる任意のノードへの情報からijの特徴量を更新します。これがoutgoingの場合です。incomingでは逆に任意のノードから伸びるijへの枝からの更新を受けます。

文章で書くとよくわからず、理解が未だに曖昧な点ですがコードにすると(設計の意図はともかく)どんな処理をしているかは理解できます。outgoingとincomingは2種類まとめて実装します。

スクリーンショット 2023-12-23 195709.png
スクリーンショット 2023-12-23 195718.png
スクリーンショット 2023-12-23 195730.png

class TriangularMultiplicativeUpdate(nn.Module):
  """
  Algorithm 11: Triangular multiplicative update using 'outgoing' edges
  Algorithm 12: Triangular multiplicative update using 'incoming' edges
  """
  def __init__(self, channel_size_pair, hidden_channel=128, mode='outgoing'):
    super().__init__()
    self.norm = nn.LayerNorm(channel_size_pair)
    self.to_g1 = nn.Linear(channel_size_pair, hidden_channel)
    self.to_g2 = nn.Linear(channel_size_pair, hidden_channel)
    self.to_g3 = nn.Linear(channel_size_pair, channel_size_pair)
    self.left_proj = nn.Linear(channel_size_pair, hidden_channel)
    self.right_proj = nn.Linear(channel_size_pair, hidden_channel)
    self.norm2 = nn.LayerNorm(channel_size_pair)
    self.out_proj = nn.Linear(hidden_channel, channel_size_pair)
    self.sigmoid = nn.Sigmoid()

    assert mode in ['outgoing', 'incoming'], 'mode must be either "outgoing" or "incoming"'
    self.mode = mode

  def forward(self, pair_repr):
    z = self.norm(pair_repr)
    g1 = self.sigmoid(self.to_g1(z))
    g2 = self.sigmoid(self.to_g2(z))
    left = g1 * self.left_proj(z)
    right = g2 * self.right_proj(z)
    if self.mode == 'outgoing':
      x = torch.einsum('b i k d, b j k d -> b i j d', left, right)
    elif self.mode == 'incoming':
      x = torch.einsum('b k i d, b k j d -> b i j d', left, right)
    x = self.sigmoid(self.to_g3(z)) * self.norm2(self.out_proj(x))

    return x

3.2.3 TriangularGatedSelfAttention

TriangularMultiplicativeUpdateをもう少し複雑にしたような処理で、ペア表現の行方向と列方向にアテンションをとります。こちらもstarting nodeを基準とする場合とending nodeを基準とする場合で2種類ありますがまとめて実装します。

class TriangularGatedSelfAttention(nn.Module):
  """
  Algorithm 13: Triangular gated self-attetion around starting node
  Algorithm 14: Triangular gated self-attetion around ending node
  """
  def __init__(self, channel_size_pair, sequence_size, hidden_channel=32, nhead=4, around='starting'):
    super().__init__()
    self.norm = nn.LayerNorm(channel_size_pair)
    self.to_q = nn.Linear(channel_size_pair, nhead*hidden_channel, bias=False)
    self.to_k = nn.Linear(channel_size_pair, nhead*hidden_channel, bias=False)
    self.to_v = nn.Linear(channel_size_pair, nhead*hidden_channel, bias=False)
    self.to_bias = nn.Linear(channel_size_pair, nhead, bias=False)
    self.to_g = nn.Linear(channel_size_pair, nhead*hidden_channel)
    self.out_proj = nn.Linear(nhead*hidden_channel, channel_size_pair)

    self.sigmoid = nn.Sigmoid()
    self.softmax = nn.Softmax(dim=4)

    assert around in ['starting', 'ending'], 'around must be either "starting" or "ending"'
    self.around = around
    self.nhead = nhead
    self.sequence_size = sequence_size
    self.scale = 1 / (hidden_channel**0.5)

  def forward(self, pair_repr):
    #Input Projections
    z = self.norm(pair_repr)
    q, k, v = self.to_q(z), self.to_k(z), self.to_v(z)
    q, k, v = map(lambda t: rearrange(t, 'b h w (n d) -> b n h w d', n = self.nhead), (q, k, v))
    bias = self.to_bias(z) #bias.shape = (b, h, w, n)
    bias = rearrange(bias, 'b h w n -> b n h w')
    if self.around=='starting':
      bias = repeat(bias, 'b n h w -> b n k h w', k=self.sequence_size)
    elif self.around=='ending':
      bias = repeat(bias, 'b n h w -> b n h k w', k=self.sequence_size)
    g = self.sigmoid(self.to_g(z))
    g = rearrange(g, 'b h w (n d) -> b n h w d', n=self.nhead)

    #Attetion
    if self.around == 'starting':
      dots = self.scale * torch.einsum('b n i j d, b n i k d -> b n i j k', q, k) + bias
    elif self.around == 'ending':
      dots = self.scale * torch.einsum('b n i j d, b n k j d -> b n i j k', q, k) + bias
    dots = self.softmax(dots)
    if self.around == 'starting':
      out = g * torch.einsum('b n i j k, b n i j d -> b n i j d', dots, v)
    elif self.around == 'ending':
      out = g * torch.einsum('b n i j k, b n k j d -> b n i j d', dots, v)
    out = rearrange(out, 'b n i j d -> b i j (n d)')

    #Outer Projection
    z = self.out_proj(out)

    return z

3.3 Evoformerブロックの処理全体

ここまで実装したことをもとにEvoformerブロック全体を実装します。
入力としてMSA表現、ペア表現をとり、更新されたMSA表現、ペア表現、ターゲット配列を出力します。
スクリーンショット 2023-12-23 201222.png

class EvoformerBlock(nn.Module):
  """
  Algorithm6: Evoformer stack
  """
  def __init__(self, sequence_num, sequence_size, channel_size_msa, channel_size_pair, depth=5):
    super().__init__()
    self.msa_row = MSARowWiseAttentionWithPairBias(
      sequence_num,
      sequence_size,
      channel_size_msa,
      channel_size_pair
      )
    self.msa_column = MSAColumnWiseAttention(
      sequence_num,
      sequence_size,
      channel_size_msa,
      channel_size_pair
    )
    self.trans_msa = Trasition(
      channel_size_msa,
    )
    self.outer_prod = OuterProductMean(
      sequence_num,
      sequence_size,
      channel_size_msa,
      channel_size_pair,
    )
    self.outgoing = TriangularMultiplicativeUpdate(
      channel_size_pair,
      mode='outgoing'
    )
    self.incoming = TriangularMultiplicativeUpdate(
      channel_size_pair,
      mode='incoming'
    )
    self.starting = TriangularGatedSelfAttention(
      channel_size_pair,
      sequence_size,
      around='starting'
    )
    self.ending = TriangularGatedSelfAttention(
      channel_size_pair,
      sequence_size,
      around='ending'
    )
    self.trans_pair = Trasition(
      channel_size_pair,
    )
    self.depth = depth
    self.dropout15 = nn.Dropout(p=0.15)
    self.dropout25 = nn.Dropout(p=0.25)
    self.to_s = nn.Linear(channel_size_msa, channel_size_msa)
  def forward(self, msa_repr, pair_repr):
    x = msa_repr
    z = pair_repr
    for i in range(self.depth):
      #MSA stack
      x += self.dropout15(self.msa_row(x, z))
      x += self.msa_column(x)
      x += self.trans_msa(x)

      #Communication
      z += self.outer_prod(x)

      #Pair stack
      z += self.dropout25(self.outgoing(z))
      z += self.dropout25(self.incoming(z))
      z += self.dropout25(self.starting(z))
      z += self.dropout25(self.ending(z))
      z += self.trans_pair(z)
    #Extract the sigle represetation
    s = self.to_s(x[:,0,:,:])
    return x, z, s

3.3.1 入出力形状確認

ここまでの実装を簡易的に確認するため、乱数で生成したmsa_reprpair_reprを入力してみてデータの形状を確認します。

evoformer = EvoformerBlock(config.sequence_num, config.sequence_size, config.channel_size_msa, config.channel_size_pair)
msa_repr, pair_reppr, targe_seq = evoformer(msa_repr, pair_repr)
print(f"{msa_repr.shape=}")
print(f"{pair_repr.shape=}")
print(f"{target_seq.shape=}")
#出力結果
msa_repr.shape=torch.Size([2, 7, 100, 10])
pair_repr.shape=torch.Size([2, 100, 100, 10])
target_seq.shape=torch.Size([2, 100, 10])

4. Structure Module

いよいよMSA表現、ペア表現から構造予測を行うモジュールです。ターゲット配列、ペア表現、主鎖の位置を入力にとって回転行列と並進ベクトルを予測、主鎖の位置を更新します。さらにそれをもとに側鎖の位置も決定します。
スクリーンショット 2023-12-23 202238.png
本記事では簡単のため以下は省略しました。

  • IPAモジュール(既存パッケージから流用)
  • テンプレート構造の入力
  • 側鎖構造の予測(出力は主鎖のみにとどめる)

4.1 インポート

必要なパッケージをインポートします。

from invariant_point_attention import IPABlock
import torch.nn.functional as F

4.2 Structure Module実装

論文の図の通りに実装します。quaternionから回転行列を作る関数やquaternionを合成する処理は既存パッケージの実装をもとに利用しました。

def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:
    """
    Convert rotations given as quaternions to rotation matrices.

    Args:
        quaternions: quaternions with real part first,
            as tensor of shape (..., 4).

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).
    """
    r, i, j, k = torch.unbind(quaternions, -1)
    # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
    two_s = 2.0 / (quaternions * quaternions).sum(-1)

    o = torch.stack(
        (
            1 - two_s * (j * j + k * k),
            two_s * (i * j - k * r),
            two_s * (i * k + j * r),
            two_s * (i * j + k * r),
            1 - two_s * (i * i + k * k),
            two_s * (j * k - i * r),
            two_s * (i * k - j * r),
            two_s * (j * k + i * r),
            1 - two_s * (i * i + j * j),
        ),
        -1,
    )
    return o.reshape(quaternions.shape[:-1] + (3, 3))

def quaternion_raw_multiply(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """
    Multiply two quaternions.
    Usual torch rules for broadcasting apply.

    Args:
        a: Quaternions as tensor of shape (..., 4), real part first.
        b: Quaternions as tensor of shape (..., 4), real part first.

    Returns:
        The product of a and b, a tensor of quaternions shape (..., 4).
    """
    aw, ax, ay, az = torch.unbind(a, -1)
    bw, bx, by, bz = torch.unbind(b, -1)
    ow = aw * bw - ax * bx - ay * by - az * bz
    ox = aw * bx + ax * bw + ay * bz - az * by
    oy = aw * by - ax * bz + ay * bw + az * bx
    oz = aw * bz + ax * by - ay * bx + az * bw
    return torch.stack((ow, ox, oy, oz), -1)

class StructureModule(nn.Module):
  def __init__(self, channel_size_msa, heads=8, scalar_key_dim=16, scalar_value_dim=16, point_key_dim=4, point_value_dim=4):
    super().__init__()
    self.ipa_block = IPABlock(
        dim = channel_size_msa,
        heads = heads,
        scalar_key_dim = scalar_key_dim,
        scalar_value_dim = scalar_value_dim,
    )
    self.out_proj = nn.Linear(channel_size_msa, 6)

  def forward(self, target_seq, pair_repr, translations, quaternions):
    rotations = quaternion_to_matrix(quaternions)
    single_repr = self.ipa_block(
        target_seq,
        pairwise_repr = pair_repr,
        rotations = rotations,
        translations = translations,
    )

    # update quaternion and translation
    updates = self.out_proj(single_repr)
    quaternion_update, translation_update = updates.chunk(2, dim=-1)
    quaternion_update = F.pad(quaternion_update, (1, 0), value = 1.)

    quaternions = quaternion_raw_multiply(quaternions, quaternion_update)
    translations = translations + torch.einsum('b n c, b n c r -> b n r', translation_update, rotations)

    return translations, quaternions

4.3 Backbone構造予測

実装したStructure ModuleをもとにBackboneの構造予測を簡易的に行います。テンプレートやRecycligを省略しているので最初のtranslationは乱数で初期化しています。

ここらへんは本物から簡単化しすぎて解釈を間違えてるところがあるかもしれません。

class PredictStructure(nn.Module):
  def __init__(self, channel_size_msa, depth=2):
    super().__init__()
    self.structure_module = StructureModule(channel_size_msa)
    self.depth = depth
    self.to_points = nn.Linear(channel_size_msa, 3)

  def forward(self, msa_repr, pair_repr):
    b, _, n, _ = msa_repr.shape
    translations = torch.randn(b, n, 3)
    quaternions = repeat(torch.tensor([1., 0., 0., 0.]), 'd -> b n d', b = b, n = n)
    single_repr = msa_repr[:,0,:,:]
    for i in range(self.depth):
      translations, quaternions = self.structure_module(single_repr, pair_repr, translations, quaternions)
    points_local = self.to_points(single_repr)
    rotations = quaternion_to_matrix(quaternions)
    coords = torch.einsum('b n c, b n c d -> b n d', points_local, rotations) + translations

    return coords

4.4 入出力形状確認

実装したStructure ModuleにMSA表現とペア表現を入力してみて、出力として座標を取り出してみます。

predict = PredictStructure(config.channel_size_msa)
print(f"{msa_repr.shape=}")
print(f"{pair_repr.shape=}")
print()
coords = predict(msa_repr, pair_repr)
print(f"{coords.shape=}")
#出力結果
msa_repr.shape=torch.Size([2, 7, 100, 10])
pair_repr.shape=torch.Size([2, 100, 100, 10])

coords.shape=torch.Size([2, 100, 3])

5. まとめ

以上AlphaFoldの各モジュールをPyTorchで実装してみました。最初に説明したように実際にはMSAのマスクやExtra MSA、誤差関数、側鎖予測など省略した処理が多々あるので全部実装するにはまだまだ足りないですが、アルゴリズムをシンプルなコードで理解する上では参考になると思います。

5
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
2