LoginSignup
1
1

More than 3 years have passed since last update.

DeepChemでCross-Validationをやってみる

Last updated at Posted at 2019-10-24

はじめに

DeepChemのサンプルでは、Hold-Outの事例がほとんどで、Cross-Validationの事例が少ない。(DeepChemを用いたベンチマークであるMoleculeNetもHold-Outによる評価となっている)。しかし、Cross-Validationの方が、検証に多くのデータが使えることもあり、高々1000件程度のデータであれば、迷わずCross-Valdationを使いたい。今回、DeepChemでCross-Validationをする方法を調べたのでメモっておく

環境

  • python 3.6
  • deepchem 2.2.1.dev54
  • rdkit 2019.03.3.0

ソースコード

以下がソースコードだ。Splitterクラスのk_fold_splitメソッドにより分割が得られるので、後はscikit-learnのKFoldと同じようにすればよい。

DeelChemCrossValExample.py
import argparse
import numpy as np
import deepchem as dc
from sklearn.ensemble import RandomForestRegressor

def get_model():
    estimator = RandomForestRegressor()
    return dc.models.SklearnModel(estimator, None)

def main():

    parser = argparse.ArgumentParser()
    parser.add_argument("-train", type=str, required=True, help="trainig data file(csv)")
    parser.add_argument("-target_col", type=str, required=True)
    parser.add_argument("-smiles_col", type=str, default="smiles")
    parser.add_argument("-id_col", type=str, default="id")
    parser.add_argument("-cv", type=int, default=5)

    args = parser.parse_args()

    featurizer = dc.feat.RDKitDescriptors()

    # 学習データの読み込み
    loader = dc.data.CSVLoader(tasks=[args.target_col],
                               id_field=args.id_col,
                               smiles_field=args.smiles_col,
                               featurizer=featurizer)

    dataset = loader.featurize(args.train)
    splitter = dc.splits.RandomSplitter(dataset)

    transformers = [dc.trans.NormalizationTransformer(transform_y=True, dataset=dataset)]

    split_datas = splitter.k_fold_split(dataset, args.cv)
    metric = dc.metrics.Metric(dc.metrics.r2_score, np.mean)
    train_scores = []
    validation_scores = []

    # Cross-Validationの実行
    for train_set, validation_set in split_datas:
        model = get_model()
        model.fit(train_set)
        train_score = model.evaluate(train_set, [metric], transformers)
        train_scores.append(train_score)
        validation_score = model.evaluate(validation_set, [metric], transformers)
        validation_scores.append(validation_score)
        print("score train:{0}, val:{1}".format(train_score, validation_score))


if __name__ == "__main__":
    main()

おわりに

手元のscikit-learn版のプログラムと同じデータ、同じアルゴリズムで比較したところほぼ同程度の精度となった。
今回はscikit-learnのRandomForestRegressorをラップしたモデルでやってみたが、DeepChemのアルゴリズムでも同様にできるはずだ。
今後はCross-Validationによるハイパーパラメータ探索なども試してみる予定である。
さて今回でDeepChemを使う準備は概ね整った。次回からは、本格的にDeepChemならではのアルゴリズムを使っていきたい。

参考

-How to implement cross-validation for training set in deepchem? #736

1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1