LoginSignup
56
69

More than 5 years have passed since last update.

DNNに匹敵する新たな学習器であるDeep Forestの学習

Last updated at Posted at 2017-07-11

Deep Learningの闇ではなく、Deep Forestの茂みの中に進みたいと思います。

DNNの代替となるかもしれないアルゴリズムにDeep Forestというものがあります。
Deep Forest :Deep Neural Networkの代替へ向けての記事を読んだり、論文を読んだりしてみると、どうやらランダムフォレストと呼ばれる決定木の集合体を複数個使い、幅および深さの方向に対してたくさん並べていくことで、Deepな構造にしているようです。

ランダムフォレストについてはこの記事が参考になります。pythonのモジュールであるscikit-learnと呼ばれる機械学習ツールにおけるランダムフォレストの説明となっていますが、今回走らせるコードもPythonで、しかもscikit-learnを使っているので、めちゃくちゃ参考になります。

Python以外にもR言語での実装があるようです。(Deep Forestの実装コード事例)

コードを手に入れる

Deep Forestを1から作るのは大変です。迷子になっちゃいます。
というわけで、Pythonで実装されているDeep Forestをgithubから手に入れます。

学習からテストまでは、READMEを見ると上手く行きます。
正答率もそこそこな感じで悪くなさそうです。ちなみに、CPUのみの対応なので、CPU使用率がかなり高くなります。

モデルの保存

一通り体験できてお腹いっぱいなのですが、学習したモデルの保存ができないというのが少し不便だったので、以下の2つのメンバ関数ををdeep_forest.pyのMGCForestクラスに追加します。

deep_forest.py
class MGCForest():

    :
    :
    :

    def save_model(self):
        # save multi-grained scanner
        for mgs_instance in self.mgs_instances:
            stride_ratio = mgs_instance.stride_ratio
            folds = mgs_instance.folds
            for i, estimator in enumerate(mgs_instance.estimators):
                joblib.dump(estimator, 'model/mgs_submodel_%.4f_%d_%d.pkl' % (stride_ratio, folds, i + 1)) 

        # save cascade forest
        for n_level, one_level_estimators in enumerate(self.c_forest.levels):
            for i, estimator in enumerate(one_level_estimators):
                joblib.dump(estimator, 'model/cforest_submodel_%d_%d.pkl' % (n_level + 1, i + 1))

    def load_model(self):
        # load multi-grained scanner
        for mgs_instance in self.mgs_instances:
            stride_ratio = '%.4f' % mgs_instance.stride_ratio
            folds = mgs_instance.folds
            for i in range(len(mgs_instance.estimators)):
                model_name = 'model/mgs_submodel_%s_%d_%d.pkl' % (stride_ratio, folds, i + 1)
                print('load model: {}'.format(model_name))
                mgs_instance.estimators[i] = joblib.load(model_name)

        # load cascade forest
        model_files = glob.glob('model/cforest_submodel_*.pkl')
        model_files.sort()
        max_level = 0
        model_dict = dict()
        for model_name in model_files:
            model_subname = re.sub('model/cforest_submodel_', '', model_name)
            model_level = int(model_subname.split('_')[0])
            if max_level < model_level:
                max_level = model_level

            if model_level not in model_dict.keys():
                model_dict[model_level] = list()
            print('load model: {}'.format(model_name))
            model_dict[model_level].append(joblib.load(model_name))

        self.c_forest.levels = list()
        for n_level in range(1, max_level + 1):
            self.c_forest.levels.append(model_dict[n_level])

        n_classes_ = self.c_forest.levels[0][0].n_classes_
        self.c_forest.classes = np.unique(np.arange(n_classes_))

fit関数で学習をした後にsave_model関数を呼べばモデルパラメータをmodelディレクトリに保存してくれます(save_model関数を呼ぶ前にmodelディレクトリを作成し、中身を空にしておいてください)。
学習済みのモデルパラメータを読み込みたいときはload_model関数を呼べばOKです。

複数のランダムフォレストによってDeep Forestはできあがっているのですが、モデル保存時は、各ランダムフォレストに対して、それぞれパラメータファイルを作成し保存する必要があります。なので、modelディレクトリの中には複数のpklファイルができています。

元々ランダムフォレストは、各特徴量の値の範囲に違いがあっても影響を受けないモデルであるという利点があります。ニューラルネットワークはそうではないので、各特徴量の値の範囲を0から1に正規化したりする必要があります。したがって、画像だけでなく、他の様々な特徴も組み合わせたいときにはこのDeep Forestがその威力を大いに奮ってくれるのではないかと期待しています。

56
69
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
56
69