5
6

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

AlphaFold2でリサイクル中の構造を全て出力

Last updated at Posted at 2021-11-16

概要

タンパク質立体構造予測手法AlphaFold2はリサイクル(recycling)によって反復的に構造を改善します.
そしてリサイクルの最後の構造を出力します.

リサイクル数を変えたときに予測構造がどう変わるのかを調べるために複数のリサイクル数の構造が必要になりました (例えばリサイクル数1,2,3の構造3つ)
公式実装 及びColabFold ではリサイクル最後の構造しか出力されないので取得したいリサイクル数だけ繰り返し予測を実行する必要がありました.(そもそも本家実装そのままではリサイクル数を指定できない.ColabFoldではリサイクル数の指定が可能.)

例えばリサイクル数が1,2,3の構造を取得したいときはリサイクル数1を指定して実行,2を指定して実行,3を指定して実行,と合計3回実行する必要がありました.
リサイクル数3を指定して構造を予測する途中でリサイクル数が1,2の時の構造も予測しているのに,わざわざ別々に実行して合計6サイクル(1+2+3)回すのは非常に効率が悪いです.

そこで,リサイクル途中の全ての構造を出力できるようにしました!
例えばリサイクル3を指定した場合には途中の1,2リサイクルの構造も出力されます

この記事ではAlphaFold2のリサイクルを司っているコードの簡単な説明とどの部分を変更すればリサイクル途中の構造を出力できるようになるかを記します.

細かい解説を書いたら長くなってしまったので修正箇所だけ知りたい方はこちらまで飛ばしてください.

一応ローカルで実行できるようにしたレポジトリの方も公開します.(@Ag_smithさんのlocalcolabfoldをベースに改変しています.)

AlphaFoldで使用されているライブラリ

AlphaFoldはjax+haiku(+tensorflow)で書かれています.

jaxとはDeepMind製のNumPyと自動微分及びGPU/TPUサポートを組み合わせた数値計算ライブラリです.

haikuとはjax向けの深層学習ライブラリです.

今回の修正はほぼjaxに関する部分です.

予測モジュール

どこを修正すればいいのか説明する前に簡単にAlphaFoldのリサイクルに関する重要な部分を抜粋してその流れを説明します.

なお本記事はAlphaFold version2.0.1を元にしています.2021-11-5に公開されたversion2.1.1.とは細部が若干異なりますが,基本的には重要な部分は同じです.

main部分

予測の呼び出しと予測構造のpdbファイルへの出力を行なっている部分を抜粋しています.

(本家版ではrun_alphafold.pyのpredict_structure関数140行目以降, colabfoldだとcolabfold_alphafold.pyのrun_alphafold関数の737行目以降に相当)

run_alphafold.py
# Copyright 2021 DeepMind Technologies Limited

# 特徴量生成
processed_feature_dict = model_runner.process_features(
    feature_dict, random_seed=random_seed)

# 予測実行
prediction_result = model_runner.predict(processed_feature_dict)

# Get mean pLDDT confidence metric.
plddt = prediction_result['plddt']
plddts[model_name] = np.mean(plddt)

# Add the predicted LDDT in the b-factor column.
# Note that higher predicted LDDT value means higher model confidence.
plddt_b_factors = np.repeat(
    plddt[:, None], residue_constants.atom_type_num, axis=-1)
# 予測結果をProteinデータクラスに変換
unrelaxed_protein = protein.from_prediction(
    features=processed_feature_dict,
    result=prediction_result,
    b_factors=plddt_b_factors)

# 予測結果をpdbファイルに書き込む
unrelaxed_pdb_path = os.path.join(output_dir, f'unrelaxed_{model_name}.pdb')
with open(unrelaxed_pdb_path, 'w') as f:
  f.write(protein.to_pdb(unrelaxed_protein))

model_runner.predict()で予測の実行をして予測結果を受け取り,それをタンパク質データ構造に変換して最終的にpdbファイルに出力を行なっています.

alphafold/model/model.pyのRunModelクラス

上述したmodel_runner.predict()に当たる部分です.

  • predict関数

    alphafold/model/model.py
    # Copyright 2021 DeepMind Technologies Limited
    def predict(self, feat: features.FeatureDict) -> Mapping[str, Any]:
      """Makes a prediction by inferencing the model on the provided features.
      Args:
        feat: A dictionary of NumPy feature arrays as output by
          RunModel.process_features.
      Returns:
        A dictionary of model outputs.
      """
      # 初期化
      self.init_params(feat)
      logging.info('Running predict with shape(feat) = %s',
                   tree.map_structure(lambda x: x.shape, feat))
      # 予測メイン部分
      result = self.apply(self.params, jax.random.PRNGKey(0), feat)
    
      # This block is to ensure benchmark timings are accurate. Some blocking is
      # already happening when computing get_confidence_metrics, and this ensures
      # all outputs are blocked on.
      jax.tree_map(lambda x: x.block_until_ready(), result)
      # 信頼スコア(pLDDT)の変換
      result.update(get_confidence_metrics(result))
      logging.info('Output shape was %s',
                   tree.map_structure(lambda x: x.shape, result))
      return result
    

    self.apply 部分で予測の実行をしています.

  • apply関数

    applyはinit内で定義されています.

    alphafold/model/model.py
    # Copyright 2021 DeepMind Technologies Limited
    def __init__(self,
                 config: ml_collections.ConfigDict,
                 params: Optional[Mapping[str, Mapping[str, np.ndarray]]] = None):
      self.config = config
      self.params = params
    
      # applyの本体
      # module.pyのAlphaFoldクラスを呼び出す
      def _forward_fn(batch):
        model = modules.AlphaFold(self.config.model)
        return model(
            batch,
            is_training=False,
            compute_loss=False,
            ensemble_representations=True)
    
      # applyの定義
      # haikuによる関数をモデルに変換して,jaxでjitコンパイルを行なっている
      self.apply = jax.jit(hk.transform(_forward_fn).apply)
      # モデルの初期化のための関数
      self.init = jax.jit(hk.transform(_forward_fn).init)
    

    applyの内部ではmodel/module.pyのAlphaFoldクラスを呼んでいます.これがAlphaFoldの予測モデルの本体です.

alphafold/model/module.pyのAlphaFoldクラス

このクラスが予測モデルのメインになります.ここでリサイクリングしています.

クラス全体を載せますが,長くて分かりにくいので後で重要な点について説明していきます.

alphafold/model/module.py
# Copyright 2021 DeepMind Technologies Limited
class AlphaFold(hk.Module):
  """AlphaFold model with recycling.
  Jumper et al. (2021) Suppl. Alg. 2 "Inference"
  """

  def __init__(self, config, name='alphafold'):
    super().__init__(name=name)
    self.config = config
    self.global_config = config.global_config

  def __call__(
      self,
      batch,
      is_training,
      compute_loss=False,
      ensemble_representations=False,
      return_representations=False):
    """Run the AlphaFold model.
    Arguments:
      batch: Dictionary with inputs to the AlphaFold model.
      is_training: Whether the system is in training or inference mode.
      compute_loss: Whether to compute losses (requires extra features
        to be present in the batch and knowing the true structure).
      ensemble_representations: Whether to use ensembling of representations.
      return_representations: Whether to also return the intermediate
        representations.
    Returns:
      When compute_loss is True:
        a tuple of loss and output of AlphaFoldIteration.
      When compute_loss is False:
        just output of AlphaFoldIteration.
      The output of AlphaFoldIteration is a nested dictionary containing
      predictions from the various heads.
    """

    impl = AlphaFoldIteration(self.config, self.global_config)
    batch_size, num_residues = batch['aatype'].shape

    def get_prev(ret):
      new_prev = {
          'prev_pos':
              ret['structure_module']['final_atom_positions'],
          'prev_msa_first_row': ret['representations']['msa_first_row'],
          'prev_pair': ret['representations']['pair'],
      }
      return jax.tree_map(jax.lax.stop_gradient, new_prev)

    # リサイクルごとに繰り返し呼び出される関数
    # AlphaFoldIterationクラスを呼び出してリサイクルごとの処理を行なっている
    def do_call(prev,
                recycle_idx,
                compute_loss=compute_loss):
      if self.config.resample_msa_in_recycling:
        num_ensemble = batch_size // (self.config.num_recycle + 1)
        def slice_recycle_idx(x):
          start = recycle_idx * num_ensemble
          size = num_ensemble
          return jax.lax.dynamic_slice_in_dim(x, start, size, axis=0)
        ensembled_batch = jax.tree_map(slice_recycle_idx, batch)
      else:
        num_ensemble = batch_size
        ensembled_batch = batch

      non_ensembled_batch = jax.tree_map(lambda x: x, prev)

      return impl(
          ensembled_batch=ensembled_batch,
          non_ensembled_batch=non_ensembled_batch,
          is_training=is_training,
          compute_loss=compute_loss,
          ensemble_representations=ensemble_representations)

    # リサイクル処理
    if self.config.num_recycle:
      emb_config = self.config.embeddings_and_evoformer
      prev = {
          'prev_pos': jnp.zeros(
              [num_residues, residue_constants.atom_type_num, 3]),
          'prev_msa_first_row': jnp.zeros(
              [num_residues, emb_config.msa_channel]),
          'prev_pair': jnp.zeros(
              [num_residues, num_residues, emb_config.pair_channel]),
      }

      if 'num_iter_recycling' in batch:
        # Training time: num_iter_recycling is in batch.
        # The value for each ensemble batch is the same, so arbitrarily taking
        # 0-th.
        num_iter = batch['num_iter_recycling'][0]

        # Add insurance that we will not run more
        # recyclings than the model is configured to run.
        num_iter = jnp.minimum(num_iter, self.config.num_recycle)
      else:
        # Eval mode or tests: use the maximum number of iterations.
        num_iter = self.config.num_recycle

      # 各リサイクルごとの予測を実行する関数
      body = lambda x: (x[0] + 1,  # pylint: disable=g-long-lambda
                        get_prev(do_call(x[1], recycle_idx=x[0],
                                         compute_loss=False)))
      if hk.running_init():
        # When initializing the Haiku module, run one iteration of the
        # while_loop to initialize the Haiku modules used in `body`.
        _, prev = body((0, prev))
      else:
        _, prev = hk.while_loop(
            lambda x: x[0] < num_iter,
            body,
            (0, prev))
    else:
      prev = {}
      num_iter = 0

    # 最終リサイクル
    ret = do_call(prev=prev, recycle_idx=num_iter)
    if compute_loss:
      ret = ret[0], [ret[1]]

    if not return_representations:
      del (ret[0] if compute_loss else ret)['representations']  # pytype: disable=unsupported-operands
    return ret

リサイクリングのメイン部分(body関数)

alphafold/model/module.py
# 各リサイクルごとの予測を実行するbody関数
# x = (recycle_idx, prev) (リサイクルindexと1サイクル前の予測構造)
# 1サイクル前の予測構造をサイクルごとの予測の処理を行うdo_callに渡して予測を実行
# get_prevで1サイクル前の構造を表すデータ形式に予測結果を変換
body = lambda x: (x[0] + 1,
                  get_prev(do_call(x[1], recycle_idx=x[0],
                               compute_loss=False)))
# 初期化
if hk.running_init():
  # When initializing the Haiku module, run one iteration of the
  # while_loop to initialize the Haiku modules used in `body`.
  _, prev = body((0, prev))

# メイン
else:
  # haiku.while_loop(jax.lax.while_loopと同義)でbody関数をリサイクル数だけ呼び出す
  # 最後のリサイクル構造がprevに渡される
  _, prev = hk.while_loop(
      lambda x: x[0] < num_iter,
      body,
      (0, prev))

# 最終リサイクル
# 最後にもう一度do_callが呼ばれて,その返り値がRunModel.predictに返却される
# 合計でリサイクル数+1回 do_callを呼び出しているが最初の1回目の呼び出しはリサイクルにはカウントしない
ret = do_call(prev=prev, recycle_idx=num_iter)
if compute_loss:
  ret = ret[0], [ret[1]]

if not return_representations:
  del (ret[0] if compute_loss else ret)['representations']  # pytype: disable=unsupported-operands
return ret

この部分がリサイクリングの肝になっている部分です.

haiku.while_loop(=jax.lax.while_loop)で繰り返しbody関数を呼びリサイクリングを行なっています.

jax.lax.while_loopの中身をpython実装で表すと以下のような関数になります

jax.lax.while_loop
def while_loop(cond_fun, body_fun, init_val):
  val = init_val
  while cond_fun(val):
    val = body_fun(val)
  return val

ループの継続条件を表す関数(cond_fun)とメインの処理関数(body_fun)と初期値(init_val)を受け取り,条件を満たすまで反復的にbody_funを実行します.jax.lax.while_loopではこの処理をコンパイルして高速に実行してくれます.

今回の例だと初期構造(0で座標情報を埋めたもの)を受け取りリサイクル数が指定値に達するまで繰り返しbody関数を呼び出しています.

body関数内のdo_callで更にmodule.AlphaFoldIterationクラスという各サイクルごとの予測を実行するクラスを呼び出していますが,リサイクル数にはあまり関係がないので割愛します.

全体の流れ

リサイクリングに関係する部分の簡単な流れをまとめます.

  1. main部分からmodel.RunModel.predict呼び出し

  2. model.RunModel.predictからmodule.AlphaFoldクラス呼び出し

    (AlphaFoldクラスの呼び出しは高速に実行されるようにコンパイルされる)

  3. module.AlphaFoldクラスでリサイクリングして最終リサイクル後の予測結果をmodel.RunModel.predictに返す.

  4. model.RunModel.predictで予測結果を受け取り信頼性スコアの変換を行い,main部分に結果を返す

  5. main部分で結果を受け取りpdbに出力する

上記のようにmain→model.RunModel.predictmodule.AlphaFoldと呼び出されて予測が行われます.

module.AlphaFoldはjaxのjitによってコンパイルされます.

修正箇所

ここからが本題のリサイクル途中の構造を出力する方法になります.

上述したmain部分,model.RunModel.predict関数,module.AlphaFoldクラスの3箇所を修正します.

module.AlphaFold

まずはリサイクリングを司るmodule.AlphaFoldを修正します.

元々haiku.while_loop (jax.lax.while_loop)で反復的にbody関数を呼び出し最終リサイクルの結果を取得していました.

alphafold/model/module.py
# module.AlphaFoldの主な処理
body = lambda x: (x[0] + 1, 
                        get_prev(do_call(x[1], recycle_idx=x[0],
                                         compute_loss=False)))

_, prev = hk.while_loop( # ループ
      lambda x: x[0] < num_iter,
      body,
      (0, prev))

ret = do_call(prev=prev, recycle_idx=num_iter) # 最終サイクル
return ret

この部分を修正して最終リサイクル構造だけでなく,リサイクル途中の結果も全て取得するようにします.

haiku.while_loopでは最終ループの結果しか受け取ることができませんが,ループの途中結果も全て受け取ることができるhaiku.scan (jax.lax.scan) という関数が存在しました.

この関数をpython実装で表すと以下のような関数になります.

jax.lax.scan
# f: メインの処理を行う関数, init: 初期値, 
# xs: fの第二引数に渡す値を持つリスト, length: ループの長さ(xs is Noneの場合のみ有効)
def scan(f, init, xs, length=None):
  if xs is None:
    xs = [None] * length
  carry = init
  ys = [] # 出力を保存するリスト
  for x in xs:
    # carryは反復的にfに入力され,yはリストに保存される
    carry, y = f(carry, x)
    ys.append(y)
  return carry, np.stack(ys) # 最終的なcarryとstackされたyのリストが返される

メインの関数の出力を次のループの入力に使用でき,かつ出力をリストに保存し最終的に受け取ることが可能です..

繰り返しbody関数を実行して反復的に予測を実行しつつ各サイクルの出力結果も取得したいという今回の要望にぴったりの関数です.

このjax.lax.scanに合わせて途中の予測結果全てを返すようにコードを修正すると以下のようになります.

alphafold/model/module.py
def body(x): # lambda式だとしんどいので普通の関数にした
  n, prev = x
  ret = do_call(prev, recycle_idx=n, compute_loss=False) # do_callの戻り値を保存
  prev_ = get_prev(ret)
  return n+1, prev_, ret

def body_scan(carry, x): # For jax.lax.scan # xは使用しない
  n, prev, ret = body(carry)
  # return carry, y
  return (n, prev), (n, ret) # 1つ目の出力が次のbody_scan呼び出し時の入力に使用され,2つ目の出力はリストに保存される
	
# 最終サイクルの出力及び途中の出力がjax.numpy.stackによってstackされたタプルが返される
recycles, prev, result_tuple = hk.scan(body_scan, (0, prev), None, length=num_iter)

haiku.scanによって最終サイクルの出力及びサイクル途中の出力がstackされタプルになったもの(result_tupleと命名)が返されます.(ただしこのタプルは特殊な形式になっており単純なリストのように各サイクルの結果が簡単に取れるようにはなっていないので注意)

あとは修正前のコードと同様に最終サイクルの出力を取得するためにprevdo_call関数に通してその戻り値とリサイクル数とresult_tupleをRunModel.predictに返します.

alphafold/model/module.py
ret = do_call(prev=prev, recycle_idx=num_iter) # 最終サイクル
return ret, recycles, result_tuple # (最終サイクルの結果,リサイクル数,途中サイクルの結果が結合されたもの)

ここでresult_tupleから各サイクルの予測結果(ret)を取り出していないのはGPUの必要メモリを減らすためです.この関数の内部で結果をTupleから取り出すよりもModel.RunModel.predict内で取り出した方がGPUメモリ使用量が減りました.

model.RunModel.predict

続いてmodule.AlphaFoldから最終サイクルの予測結果を受け取っていたmodel.RunModel.predictを修正します.

元々は最終リサイクルの結果だけ受け取っていたところを,最終サイクルの結果,リサイクル数及び途中サイクルの結果が結合されたタプルを受け取るようにします.

その後受け取ったタプルをサイクルごとの予測結果が入ったリストに変換をして,main部分に返します.

alphafold/model/model.py
def predict(self,
            feat: features.FeatureDict,
            ret_all_cycle: bool = False, # 全ての出力を返すかどうか
            random_seed=0) -> Union[Tuple[Mapping[str, Any], int, float], List[Tuple[Mapping[str, Any], int, float]]]:
  self.init_params(feat)
  logging.info('Running predict with shape(feat) = %s',
               tree.map_structure(lambda x: x.shape, feat))
  # ここでAlphaFoldクラスを呼び出し,最終リサイクル結果とリサイクル数のタプル,途中サイクルの結果が結合されたものを受け取る
  result, recycles, result_tuple = self.apply(self.params, jax.random.PRNGKey(random_seed), feat)

  # 各サイクルごとの予測結果の信頼性スコアを変換する関数
  def iter_result(result: Tuple[Mapping[str, Any], int]) -> Tuple[Mapping[str, Any], int]:
    result_ = result[0]
    jax.tree_map(lambda x: x.block_until_ready(), result_)
    result_.update(get_confidence_metrics(result_))
    logging.info('Output shape was %s',
                tree.map_structure(lambda x: x.shape, result_))
    return result

  if not ret_all_cycle:  # 最後のサイクルの予測結果のみ返す場合
    ret = iter_result((result, recycles))
    return ret
  else:  # 全サイクルの予測結果を返す場合
    def get_each_result(i): # result_tupleから各サイクルのrecycle, retを取り出す関数
      recycles = result_tuple[0][i]
      # haiku.scanで辞書は各キーごとにstackされるので以下のように取り出さなければいけない
      result = {key1: {key2: val2[i + 1] for key2, val2 in val1.items()} for key1, val1 in result_tuple[1].items()}
      return result, recycles
    # 最終サイクルを除いた各サイクルの予測結果を取り出す
    result_list = [get_each_result(i) for i in range(recycles - 1)]
    result_list.append(result) # 最終サイクルの予測結果を追加
    ret_list = [iter_result(result) for result in result_list] # 各サイクルごとに信頼性スコアの変換
    return ret_list

main部分

main部分では予測結果のリストを受け取り,一つ一つをpdbファイルに出力するようにします.

run_alphafold.py
# 予測実行
prediction_result_list = model_runner.predict(processed_feature_dict, ret_all_cycle=True)

# 一つ一つの予測結果を処理
for prediction_result, recycle in prediction_result_list:
	# Get mean pLDDT confidence metric.
	plddt = prediction_result['plddt']
	plddts[f'{model_name}_{recycle}'] = np.mean(plddt)
	
	# Add the predicted LDDT in the b-factor column.
	# Note that higher predicted LDDT value means higher model confidence.
	plddt_b_factors = np.repeat(
	    plddt[:, None], residue_constants.atom_type_num, axis=-1)
	# 予測結果をProteinデータクラスに変換
	unrelaxed_protein = protein.from_prediction(
	    features=processed_feature_dict,
	    result=prediction_result,
	    b_factors=plddt_b_factors)
	
	# 予測結果をpdbファイルに書き込む
	unrelaxed_pdb_path = os.path.join(output_dir, f'unrelaxed_{model_name}_recycle_{recycle}.pdb')
	with open(unrelaxed_pdb_path, 'w') as f:
	  f.write(protein.to_pdb(unrelaxed_protein))

以上でリサイクル途中の構造全てを出力することが可能です.

修正箇所まとめ

AlphaFold version2.0.1に変更を適用した際のdiffを記します.

alphafold/model/module.py
diff --git a/alphafold/model/modules.py b/alphafold/model/modules.py
index 794597f..ecdec5f 100644
--- a/alphafold/model/modules.py
+++ b/alphafold/model/modules.py
@@ -365,21 +365,30 @@ class AlphaFold(hk.Module):
         # Eval mode or tests: use the maximum number of iterations.
         num_iter = self.config.num_recycle

-      body = lambda x: (x[0] + 1,  # pylint: disable=g-long-lambda
-                        get_prev(do_call(x[1], recycle_idx=x[0],
-                                         compute_loss=False)))
+      def body(x):
+        n, prev = x
+        ret = do_call(prev, recycle_idx=n, compute_loss=False)
+        prev_ = get_prev(ret)
+        return n+1, prev_, ret
+
+      def body_scan(carry, x):  # For jax.lax.scan
+        n, prev, ret = body(carry)
+        if not return_representations:
+          del ret['representations']
+        return (n, prev), (n, ret)
+
       if hk.running_init():
         # When initializing the Haiku module, run one iteration of the
         # while_loop to initialize the Haiku modules used in `body`.
-        _, prev = body((0, prev))
+        recycles, prev, _ = body((0, prev))
+        result_tuple = ()
       else:
-        _, prev = hk.while_loop(
-            lambda x: x[0] < num_iter,
-            body,
-            (0, prev))
+        (recycles, prev), result_tuple = hk.scan(body_scan, (0, prev), None, length=num_iter)
     else:
       prev = {}
       num_iter = 0
+      recycles = 0
+      result_tuple = ()

     ret = do_call(prev=prev, recycle_idx=num_iter)
     if compute_loss:
@@ -387,7 +396,7 @@ class AlphaFold(hk.Module):

     if not return_representations:
       del (ret[0] if compute_loss else ret)['representations']  # pytype: disable=unsupported-operands
-    return ret
+    return ret, recycles, result_tuple
alphafold/model/model.py
diff --git a/alphafold/model/model.py b/alphafold/model/model.py
index 66addeb..a810fd0 100644
--- a/alphafold/model/model.py
+++ b/alphafold/model/model.py
@@ -13,7 +13,7 @@
 # limitations under the License.

 """Code for constructing the model."""
-from typing import Any, Mapping, Optional, Union
+from typing import Any, Mapping, Optional, Union, List, Tuple

 from absl import logging
 from alphafold.common import confidence
@@ -117,7 +117,7 @@ class RunModel:
     logging.info('Output shape was %s', shape)
     return shape

-  def predict(self, feat: features.FeatureDict) -> Mapping[str, Any]:
+  def predict(self, feat: features.FeatureDict, ret_all_cycle: bool) -> List[Tuple[Mapping[str, Any], int]]:
     """Makes a prediction by inferencing the model on the provided features.

     Args:
@@ -125,17 +125,34 @@ class RunModel:
         RunModel.process_features.

     Returns:
-      A dictionary of model outputs.
+      A List of Tuple of a dictionary of model outputs and the recycle number.
     """
     self.init_params(feat)
     logging.info('Running predict with shape(feat) = %s',
                  tree.map_structure(lambda x: x.shape, feat))
-    result = self.apply(self.params, jax.random.PRNGKey(0), feat)
+    result, recycles, result_tuple = self.apply(self.params, jax.random.PRNGKey(0), feat)
     # This block is to ensure benchmark timings are accurate. Some blocking is
     # already happening when computing get_confidence_metrics, and this ensures
     # all outputs are blocked on.
-    jax.tree_map(lambda x: x.block_until_ready(), result)
-    result.update(get_confidence_metrics(result))
-    logging.info('Output shape was %s',
-                 tree.map_structure(lambda x: x.shape, result))
-    return result
+
+    def iter_result(result):
+      result_ = result[0]
+      jax.tree_map(lambda x: x.block_until_ready(), result_)
+      result_.update(get_confidence_metrics(result_))
+      logging.info('Output shape was %s',
+                   tree.map_structure(lambda x: x.shape, result_))
+      return result
+
+    if not ret_all_cycle:  # Return last cycles
+      ret = iter_result((result, recycles))
+      return ret
+    else:  # Return all cycles
+      def get_each_result(i):
+        recycles = result_tuple[0][i]
+        result = {key1: {key2: val2[i + 1] for key2, val2 in val1.items()} for key1, val1 in result_tuple[1].items()}
+        return result, recycles
+
+      result_list = [get_each_result(i) for i in range(recycles - 1)]
+      result_list.append(result)
+      ret_list = [iter_result(result) for result in result_list]
+      return ret_list
run_alphafold.py
diff --git a/run_alphafold.py b/run_alphafold.py
index 6f1a690..f99f762 100644
--- a/run_alphafold.py
+++ b/run_alphafold.py
@@ -146,7 +146,7 @@ def predict_structure(
     timings[f'process_features_{model_name}'] = time.time() - t_0

     t_0 = time.time()
-    prediction_result = model_runner.predict(processed_feature_dict)
+    prediction_result_list = model_runner.predict(processed_feature_dict, ret_all_cycle=True)
     t_diff = time.time() - t_0
     timings[f'predict_and_compile_{model_name}'] = t_diff
     logging.info(
@@ -158,39 +158,41 @@ def predict_structure(
       model_runner.predict(processed_feature_dict)
       timings[f'predict_benchmark_{model_name}'] = time.time() - t_0

-    # Get mean pLDDT confidence metric.
-    plddt = prediction_result['plddt']
-    plddts[model_name] = np.mean(plddt)
-
-    # Save the model outputs.
-    result_output_path = os.path.join(output_dir, f'result_{model_name}.pkl')
-    with open(result_output_path, 'wb') as f:
-      pickle.dump(prediction_result, f, protocol=4)
-
-    # Add the predicted LDDT in the b-factor column.
-    # Note that higher predicted LDDT value means higher model confidence.
-    plddt_b_factors = np.repeat(
-        plddt[:, None], residue_constants.atom_type_num, axis=-1)
-    unrelaxed_protein = protein.from_prediction(
-        features=processed_feature_dict,
-        result=prediction_result,
-        b_factors=plddt_b_factors)
-
-    unrelaxed_pdb_path = os.path.join(output_dir, f'unrelaxed_{model_name}.pdb')
-    with open(unrelaxed_pdb_path, 'w') as f:
-      f.write(protein.to_pdb(unrelaxed_protein))
-
-    # Relax the prediction.
-    t_0 = time.time()
-    relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
-    timings[f'relax_{model_name}'] = time.time() - t_0
+    for prediction_result, recycle in prediction_result_list:
+      model_name_recycle = f'{model_name}_recycle_{recycle}'
+      # Get mean pLDDT confidence metric.
+      plddt = prediction_result['plddt']
+      plddts[model_name_recycle] = np.mean(plddt)
+
+      # Save the model outputs.
+      result_output_path = os.path.join(output_dir, f'result_{model_name_recycle}.pkl')
+      with open(result_output_path, 'wb') as f:
+        pickle.dump(prediction_result, f, protocol=4)
+
+      # Add the predicted LDDT in the b-factor column.
+      # Note that higher predicted LDDT value means higher model confidence.
+      plddt_b_factors = np.repeat(
+          plddt[:, None], residue_constants.atom_type_num, axis=-1)
+      unrelaxed_protein = protein.from_prediction(
+          features=processed_feature_dict,
+          result=prediction_result,
+          b_factors=plddt_b_factors)
+
+      unrelaxed_pdb_path = os.path.join(output_dir, f'unrelaxed_{model_name_recycle}.pdb')
+      with open(unrelaxed_pdb_path, 'w') as f:
+        f.write(protein.to_pdb(unrelaxed_protein))
+
+      # Relax the prediction.
+      t_0 = time.time()
+      relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
+      timings[f'relax_{model_name_recycle}'] = time.time() - t_0

-    relaxed_pdbs[model_name] = relaxed_pdb_str
+      relaxed_pdbs[f'{model_name_recycle}'] = relaxed_pdb_str

-    # Save the relaxed PDB.
-    relaxed_output_path = os.path.join(output_dir, f'relaxed_{model_name}.pdb')
-    with open(relaxed_output_path, 'w') as f:
-      f.write(relaxed_pdb_str)
+      # Save the relaxed PDB.
+      relaxed_output_path = os.path.join(output_dir, f'relaxed_{model_name_recycle}.pdb')
+      with open(relaxed_output_path, 'w') as f:
+        f.write(relaxed_pdb_str)

実際の実装

実際に私がlocalcolabfoldをベースに途中リサイクル構造を出力できるようにしたレポジトリを公開しておきます.

リサイクル数だけでなく他の箇所も少しいじっておりlocalcolabfoldと完全に同じではないのでご注意ください.
またMultimerには対応していないのでご注意ください.

実行環境

localcolabfold (commit 0129393)
実行環境にはColabFoldをローカルで実行できるようにしたlocalcolabfoldを使用しました
本家実装ではリサイクル数を指定できませんが,ColabFold及びlocalcolabfoldでは指定が可能です
AlphaFold (version 2.0.1)
ColabFold (commit 9546d8f)

レポジトリ

問題点

実行時間

複数リサイクルを返すので実行時間が遅くなります.
きちんと測定はしていませんが,最終リサイクルだけを返す実装よりも1.5倍程度遅くなりました.

GPU memory不足

リサイクル途中の全ての構造を返すようにしているので,本家よりも必要メモリが多いです.
そのため長い配列でリサイクルの数が多いとGPUメモリが不足で実行できない場合があります.

私の環境(Nvidia P100, 16GB) だと500残基のターゲットに対してensemble 8, recycle 10を指定したらメモリ不足で実行できませんでした.(ensembleなし, recycle 10なら700残基までは実行できることを確認)

GPUメモリ不足の際には以下の記事の内容を実行すれば解消するかもしれません.
GPU memory allocation - JAX documentation

まとめ

AlphaFoldの予測実行時にリサイクル最後の予測構造だけでなく,途中構造も出力できるようにするために変更すべき点を記しました.

またローカル環境でColabFoldを実行できるlocalColabFoldを改変してリサイクル途中の構造を出力できるようにしたレポジトリを公開しています.
注意点として残基数が多いターゲットでかつリサイクル数が多いとGPUメモリ不足になる場合があります.

コードの改善案やGPUメモリを削減できるいい実装があったらぜひ教えてください! 他にも高速化が可能であればぜひ教えてください.

5
6
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
6

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?