0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

ID3の実装

Posted at

ID3の実装

  • この記事では決定木分析の基本的なアルゴリズムであるID3をfrom scratchで実装します
  • コードは主にこちらを参考にさせて頂いております

ID3とは

  • 簡単に言えば、情報利得が最も大きくなる特徴量で分岐する、ということを繰り返して木を作成する方法です
  • 詳しいアルゴリズムはwikipediaから見ていただければと思います

ID3の実装

ID3.py
import math
from collections import deque

class Node:
    """自身と他のノードについての情報を持つ"""
    def __init__(self):
        self.value = None  #ブランチを分割する特徴量
        self.next = None #次のノード
        self.childs = None #ノードから分かれるブランチ
        
class DecisionTreeClassifier:
    def __init__(self,X,feature_names,labels):
        self.X = X  #特徴量
        self.feature_names = feature_names  #特徴量のカラム名
        self.labels = labels   #目的変数
        self.labelCategories = list(set(labels)) #カテゴリーのユニーク値(Yes,No)
        #カテゴリーごとの数
        self.labelCategoriesCount = [list(labels).count(x) for x in self.labelCategories]
        self.node = None
        self.entropy = self._get_entropy([x for x in range(len(self.labels))])
        
    def _get_entropy(self,x_ids):
        """エントロピーの計算
        
        :param x_ids: インスタンスのidを格納したリスト(0~999まで)
        
        :return エントロピー
        """
        labels = [self.labels[i] for i in x_ids] #目的変数のリスト[yes,yes,no,...,]
        #カテゴリーごとの数のリスト 例:[451,549]
        label_count = [labels.count(x) for x in self.labelCategories]
        
        #各カテゴリーごとにエントロピーを計算し、それらを合算
        entropy = sum([-count / len(x_ids) * math.log(count / len(x_ids), 2) if count else 0 for count in label_count])
        
        return entropy
    
    def _get_information_gain(self,x_ids,feature_id):
        """与えられた特徴について、情報利得を計算する
        
        :param x_ids: インスタンスのidを格納したリスト(初期は0~999まで)
        :param feature_id: 特徴量のid(初期は0,1,2のうちいずれか)
        """
        #現在の情報利得を計算
        info_gain = self._get_entropy(x_ids)
        # 特徴量の値のリスト([small,large,large,...,])みたいな
        x_features = [self.X[x][feature_id] for x in x_ids]
        #特徴量のユニーク値をとる
        feature_vals = list(set(x_features))
        #各特徴量の値ごとの数のリスト([22,38,45,33])みたいな
        feature_vals_count = [x_features.count(x) for x in feature_vals]
        # 各特徴量のユニーク値について、その特徴量の値と合致するx_idをリスト形式でまとめている [[1,3,4,...,],[2,5,...,]]みたいな
        feature_vals_id = [
            [x_ids[i]
            for i,x in enumerate(x_features)
            if x == y
            ]
            for y in feature_vals
        ]
        
        #情報利得を計算する
        info_gain = info_gain -  sum([val_counts / len(x_ids) * self._get_entropy(val_ids)
                                     for val_counts, val_ids in zip(feature_vals_count, feature_vals_id)])
        
        
        return info_gain
        
    def _get_feature_max_information(self,x_ids,feature_ids):
        """情報利得を最大化する特徴量を見つける
        
        :param x_ids: インスタンスのidを格納したリスト(初期は0~999まで)
        :param feature_ids: 特徴量のids(初期は0,1,2)
        
        :returns 情報利得を最大化する特徴量の名称とそのid
        """
        #各特徴量ごとにエントロピーを取得
        feature_entropy = [self._get_information_gain(x_ids, feature_id) for feature_id in feature_ids]
        #最も情報利得の高い特徴量のidを取得
        max_id = feature_ids[feature_entropy.index(max(feature_entropy))]
        
        return self.feature_names[max_id], max_id
    
    def id3(self):
        """ID3アルゴリズムを動かすインターフェース
        
        """
        x_ids = [x for x in range(len(self.X))]
        feature_ids = [x for x in range(len(self.feature_names))]
        self.node = self._id3_recv(x_ids, feature_ids,self.node)
        print('')
        
        
    def _id3_recv(self,x_ids,feature_ids,node):
        """ID3アルゴリズム本体 何らかの基準まで再帰的に関数を実装する
        
        :param x_ids: インスタンスのidを格納したリスト(初期は0~999まで)
        :param faeture_ids 特徴量のids(初期は0,1,2)
        :param node ノードクラスのインスタンス
        
        :returns 
        """
        if not node:
            node = Node()
        labels_in_features = [self.labels[x] for x in x_ids]
        #もし全てのサンプルが同じクラスなら、ノードを返す
        if len(set(labels_in_features)) == 1:
            node.value = self.labels[x_ids[0]]
            return node
        
        #もしもう計算する特徴量がなければ、最も可能性の高いノードを返す
        if len(feature_ids) == 0:
            node.value = max(set(labels_in_features),key=labels_in_features.count)
            return node
        
        #最も情報利得の大きい特徴を選択する
        best_feature_name,best_feature_id = self._get_feature_max_information(x_ids,feature_ids)
        node.value = best_feature_name
        node.childs = []
        #各インスタンスの選ばれた特徴量の値
        feature_values = list(set(self.X[x][best_feature_id] for x in x_ids))
        
        for value in feature_values:
            child = Node()
            child.value = value
            node.childs.append(child)
            # 特徴量の値に合致するインスタンスのid
            child_x_ids = [x for x in x_ids if self.X[x][best_feature_id] == value]
            if not child_x_ids:
                child.next = max(set(labels_in_features), key=labels_in_features.count)
                print('')
            else:
                if feature_ids and best_feature_id in feature_ids:
                    to_remove = feature_ids.index(best_feature_id)
                    feature_ids.pop(to_remove) #もう既に使った特徴量は削除
                child.next = self._id3_recv(child_x_ids,feature_ids,child.next)
        return node
    
    
    def printTree(self):
        if not self.node:
            return
        nodes = deque()
        nodes.append(self.node)
        while len(nodes) > 0:
            node = nodes.popleft()
            print(node.value)
            if node.childs:
                for child in node.childs:
                    print('({})'.format(child.value))
                    nodes.append(child.next)
            elif node.next:
                print(node.next)        

サンプル

test.py
import numpy as np
import pandas as pd
from collections import deque

# generate some data
# define features and target values
data = {
    'wind_direction': ['N', 'S', 'E', 'W'],
    'tide': ['Low', 'High'],
    'swell_forecasting': ['small', 'medium', 'large'],
    'good_waves': ['Yes', 'No']
}

# create an empty dataframe
data_df = pd.DataFrame(columns=data.keys())

np.random.seed(42)
# randomnly create 1000 instances
for i in range(1000):
    data_df.loc[i, 'wind_direction'] = str(np.random.choice(data['wind_direction'], 1)[0])
    data_df.loc[i, 'tide'] = str(np.random.choice(data['tide'], 1)[0])
    data_df.loc[i, 'swell_forecasting'] = str(np.random.choice(data['swell_forecasting'], 1)[0])
    data_df.loc[i, 'good_waves'] = str(np.random.choice(data['good_waves'], 1)[0])

tree_clf = DecisionTreeClassifier(X=X, feature_names=feature_names, labels=y)
# run algorithm id3 to build a tree
tree_clf.id3()
tree_clf.printTree()

出力

wind_direction
(W)
(N)
(S)
(E)
swell_forecasting
(medium)
(large)
(small)
No
Yes
No
tide
(High)
(Low)
No
Yes
No
No
0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?