Help us understand the problem. What is going on with this article?

DNNに匹敵する新たな学習器であるDeep Forestの学習

More than 1 year has passed since last update.

Deep Learningの闇ではなく、Deep Forestの茂みの中に進みたいと思います。

DNNの代替となるかもしれないアルゴリズムにDeep Forestというものがあります。
Deep Forest :Deep Neural Networkの代替へ向けての記事を読んだり、論文を読んだりしてみると、どうやらランダムフォレストと呼ばれる決定木の集合体を複数個使い、幅および深さの方向に対してたくさん並べていくことで、Deepな構造にしているようです。

ランダムフォレストについてはこの記事が参考になります。pythonのモジュールであるscikit-learnと呼ばれる機械学習ツールにおけるランダムフォレストの説明となっていますが、今回走らせるコードもPythonで、しかもscikit-learnを使っているので、めちゃくちゃ参考になります。

Python以外にもR言語での実装があるようです。(Deep Forestの実装コード事例)

コードを手に入れる

Deep Forestを1から作るのは大変です。迷子になっちゃいます。
というわけで、Pythonで実装されているDeep Forestをgithubから手に入れます。

https://github.com/leopiney/deep-forest

学習からテストまでは、READMEを見ると上手く行きます。
正答率もそこそこな感じで悪くなさそうです。ちなみに、CPUのみの対応なので、CPU使用率がかなり高くなります。

モデルの保存

一通り体験できてお腹いっぱいなのですが、学習したモデルの保存ができないというのが少し不便だったので、以下の2つのメンバ関数ををdeep_forest.pyのMGCForestクラスに追加します。

deep_forest.py
class MGCForest():

    :
    :
    :

    def save_model(self):
        # save multi-grained scanner
        for mgs_instance in self.mgs_instances:
            stride_ratio = mgs_instance.stride_ratio
            folds = mgs_instance.folds
            for i, estimator in enumerate(mgs_instance.estimators):
                joblib.dump(estimator, 'model/mgs_submodel_%.4f_%d_%d.pkl' % (stride_ratio, folds, i + 1)) 

        # save cascade forest
        for n_level, one_level_estimators in enumerate(self.c_forest.levels):
            for i, estimator in enumerate(one_level_estimators):
                joblib.dump(estimator, 'model/cforest_submodel_%d_%d.pkl' % (n_level + 1, i + 1))

    def load_model(self):
        # load multi-grained scanner
        for mgs_instance in self.mgs_instances:
            stride_ratio = '%.4f' % mgs_instance.stride_ratio
            folds = mgs_instance.folds
            for i in range(len(mgs_instance.estimators)):
                model_name = 'model/mgs_submodel_%s_%d_%d.pkl' % (stride_ratio, folds, i + 1)
                print('load model: {}'.format(model_name))
                mgs_instance.estimators[i] = joblib.load(model_name)

        # load cascade forest
        model_files = glob.glob('model/cforest_submodel_*.pkl')
        model_files.sort()
        max_level = 0
        model_dict = dict()
        for model_name in model_files:
            model_subname = re.sub('model/cforest_submodel_', '', model_name)
            model_level = int(model_subname.split('_')[0])
            if max_level < model_level:
                max_level = model_level

            if model_level not in model_dict.keys():
                model_dict[model_level] = list()
            print('load model: {}'.format(model_name))
            model_dict[model_level].append(joblib.load(model_name))

        self.c_forest.levels = list()
        for n_level in range(1, max_level + 1):
            self.c_forest.levels.append(model_dict[n_level])

        n_classes_ = self.c_forest.levels[0][0].n_classes_
        self.c_forest.classes = np.unique(np.arange(n_classes_))

fit関数で学習をした後にsave_model関数を呼べばモデルパラメータをmodelディレクトリに保存してくれます(save_model関数を呼ぶ前にmodelディレクトリを作成し、中身を空にしておいてください)。
学習済みのモデルパラメータを読み込みたいときはload_model関数を呼べばOKです。

複数のランダムフォレストによってDeep Forestはできあがっているのですが、モデル保存時は、各ランダムフォレストに対して、それぞれパラメータファイルを作成し保存する必要があります。なので、modelディレクトリの中には複数のpklファイルができています。

元々ランダムフォレストは、各特徴量の値の範囲に違いがあっても影響を受けないモデルであるという利点があります。ニューラルネットワークはそうではないので、各特徴量の値の範囲を0から1に正規化したりする必要があります。したがって、画像だけでなく、他の様々な特徴も組み合わせたいときにはこのDeep Forestがその威力を大いに奮ってくれるのではないかと期待しています。

Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Comments
Sign up for free and join this conversation.
If you already have a Qiita account
Why do not you register as a user and use Qiita more conveniently?
You need to log in to use this function. Qiita can be used more conveniently after logging in.
You seem to be reading articles frequently this month. Qiita can be used more conveniently after logging in.
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away