はじめに
DeepChemのサンプルでは、Hold-Outの事例がほとんどで、Cross-Validationの事例が少ない。(DeepChemを用いたベンチマークであるMoleculeNetもHold-Outによる評価となっている)。しかし、Cross-Validationの方が、検証に多くのデータが使えることもあり、高々1000件程度のデータであれば、迷わずCross-Valdationを使いたい。今回、DeepChemでCross-Validationをする方法を調べたのでメモっておく
環境
- python 3.6
- deepchem 2.2.1.dev54
- rdkit 2019.03.3.0
ソースコード
以下がソースコードだ。Splitterクラスのk_fold_splitメソッドにより分割が得られるので、後はscikit-learnのKFoldと同じようにすればよい。
import argparse
import numpy as np
import deepchem as dc
from sklearn.ensemble import RandomForestRegressor
def get_model():
estimator = RandomForestRegressor()
return dc.models.SklearnModel(estimator, None)
def main():
parser = argparse.ArgumentParser()
parser.add_argument("-train", type=str, required=True, help="trainig data file(csv)")
parser.add_argument("-target_col", type=str, required=True)
parser.add_argument("-smiles_col", type=str, default="smiles")
parser.add_argument("-id_col", type=str, default="id")
parser.add_argument("-cv", type=int, default=5)
args = parser.parse_args()
featurizer = dc.feat.RDKitDescriptors()
# 学習データの読み込み
loader = dc.data.CSVLoader(tasks=[args.target_col],
id_field=args.id_col,
smiles_field=args.smiles_col,
featurizer=featurizer)
dataset = loader.featurize(args.train)
splitter = dc.splits.RandomSplitter(dataset)
transformers = [dc.trans.NormalizationTransformer(transform_y=True, dataset=dataset)]
split_datas = splitter.k_fold_split(dataset, args.cv)
metric = dc.metrics.Metric(dc.metrics.r2_score, np.mean)
train_scores = []
validation_scores = []
# Cross-Validationの実行
for train_set, validation_set in split_datas:
model = get_model()
model.fit(train_set)
train_score = model.evaluate(train_set, [metric], transformers)
train_scores.append(train_score)
validation_score = model.evaluate(validation_set, [metric], transformers)
validation_scores.append(validation_score)
print("score train:{0}, val:{1}".format(train_score, validation_score))
if __name__ == "__main__":
main()
おわりに
手元のscikit-learn版のプログラムと同じデータ、同じアルゴリズムで比較したところほぼ同程度の精度となった。
今回はscikit-learnのRandomForestRegressorをラップしたモデルでやってみたが、DeepChemのアルゴリズムでも同様にできるはずだ。
今後はCross-Validationによるハイパーパラメータ探索なども試してみる予定である。
さて今回でDeepChemを使う準備は概ね整った。次回からは、本格的にDeepChemならではのアルゴリズムを使っていきたい。
参考
-How to implement cross-validation for training set in deepchem? #736