
Trying Out Cross-Validation with DeepChem

Introduction

Most of the DeepChem examples use hold-out validation, and examples of cross-validation are scarce (MoleculeNet, the benchmark built on DeepChem, is also evaluated with a hold-out split). Cross-validation, however, lets more of the data be used for validation, so for a dataset of at most around 1,000 records I would use cross-validation without hesitation. I looked into how to do cross-validation with DeepChem, so this post is a note on the approach.

Environment

  • python 3.6
  • deepchem 2.2.1.dev54
  • rdkit 2019.03.3.0

Source Code

The source code is below. The k_fold_split method of the Splitter class returns the folds, and from there the loop works just like scikit-learn's KFold (a minimal comparison sketch follows after the listing).

DeepChemCrossValExample.py
import argparse
import numpy as np
import deepchem as dc
from sklearn.ensemble import RandomForestRegressor

def get_model():
    estimator = RandomForestRegressor()
    return dc.models.SklearnModel(estimator, None)

def main():

    parser = argparse.ArgumentParser()
    parser.add_argument("-train", type=str, required=True, help="trainig data file(csv)")
    parser.add_argument("-target_col", type=str, required=True)
    parser.add_argument("-smiles_col", type=str, default="smiles")
    parser.add_argument("-id_col", type=str, default="id")
    parser.add_argument("-cv", type=int, default=5)

    args = parser.parse_args()

    # Featurize molecules with RDKit's physicochemical descriptors
    featurizer = dc.feat.RDKitDescriptors()

    # Load the training data from the CSV file
    loader = dc.data.CSVLoader(tasks=[args.target_col],
                               id_field=args.id_col,
                               smiles_field=args.smiles_col,
                               featurizer=featurizer)

    dataset = loader.featurize(args.train)
    splitter = dc.splits.RandomSplitter(dataset)

    transformers = [dc.trans.NormalizationTransformer(transform_y=True, dataset=dataset)]

    # k_fold_split returns a list of (train, validation) dataset pairs
    split_datas = splitter.k_fold_split(dataset, args.cv)
    metric = dc.metrics.Metric(dc.metrics.r2_score, np.mean)
    train_scores = []
    validation_scores = []

    # Run cross-validation: train and evaluate a fresh model on each fold
    for train_set, validation_set in split_datas:
        model = get_model()
        model.fit(train_set)
        train_score = model.evaluate(train_set, [metric], transformers)
        train_scores.append(train_score)
        validation_score = model.evaluate(validation_set, [metric], transformers)
        validation_scores.append(validation_score)
        print("score train:{0}, val:{1}".format(train_score, validation_score))


if __name__ == "__main__":
    main()
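
For reference, here is a minimal sketch of the "same as scikit-learn's KFold" pattern mentioned above, written directly with scikit-learn. The arrays X and y are hypothetical NumPy arrays of descriptors and target values; they are not produced by the script above.

import numpy as np
from sklearn.model_selection import KFold
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import r2_score

def sklearn_cross_validation(X, y, n_splits=5):
    # X, y: hypothetical NumPy arrays of descriptors and target values
    kf = KFold(n_splits=n_splits, shuffle=True)
    scores = []
    for train_idx, val_idx in kf.split(X):
        model = RandomForestRegressor()
        model.fit(X[train_idx], y[train_idx])
        scores.append(r2_score(y[val_idx], model.predict(X[val_idx])))
    return np.mean(scores)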

Closing Remarks

Compared against my existing scikit-learn version of the program with the same data and the same algorithm, the accuracy was roughly the same.
This time I used a model wrapping scikit-learn's RandomForestRegressor, but the same approach should work with DeepChem's own algorithms.
Going forward, I also plan to try hyperparameter search with cross-validation; a rough sketch of what that could look like is included below.
With that, the preparation for using DeepChem is mostly complete. From the next post, I want to start working with algorithms that are unique to DeepChem.
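
As a first idea for that hyperparameter search, below is a rough sketch of how a simple grid search could be layered on top of the same k_fold_split loop. The function name, the illustrative parameter grid, and the way the single score is pulled out of the evaluate() result dict are my own choices for illustration, not part of the script above; dataset, splitter, and transformers are assumed to be the objects built in main().

import numpy as np
import deepchem as dc
from sklearn.ensemble import RandomForestRegressor

def cv_r2(dataset, splitter, transformers, n_estimators, cv=5):
    # Mean validation R^2 over the folds for one candidate n_estimators value
    metric = dc.metrics.Metric(dc.metrics.r2_score, np.mean)
    fold_scores = []
    for train_set, validation_set in splitter.k_fold_split(dataset, cv):
        model = dc.models.SklearnModel(RandomForestRegressor(n_estimators=n_estimators), None)
        model.fit(train_set)
        score = model.evaluate(validation_set, [metric], transformers)
        fold_scores.append(list(score.values())[0])  # single metric -> single value
    return np.mean(fold_scores)

# Pick the candidate with the best mean validation R^2 (illustrative grid)
# best_n = max([10, 100, 300], key=lambda n: cv_r2(dataset, splitter, transformers, n))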

References

  • How to implement cross-validation for training set in deepchem? #736
