アンサンブル学習とは
アンサンブル学習とは、複数の学習器の予測結果を組み合わせて最終的な回答を生成する機械学習の一手法です。単純な方法としては、各学習器の予測結果から多数決をとって予測結果を出力する方法があります。アンサンブル学習に用いられる各学習器は弱学習器とよばれ、単体では精度のよいものでなくても、複数組み合わせることで、精度の高いモデルを構成することができます。
高度なアンサンブル学習の手法としては、
- ブースティング
- バギング
- ブレンディング
- スタッキング
などが有名ですが、ここではその話はしません。私はどちらかというと、実装も単純で理論も明快な「多数決」によるアンサンブル学習が好きだからです。先日もKaggleのTitanicチュートリアルで四苦八苦しながら、小手先のワザとして「多数決」を使って精度向上させていました。今日は、簡単な例題を実行しながら、「多数決」の持つポテンシャルの「凄み」のようなものを一緒に体験できたらと思っています。
多数決を用いたアンサンブル学習
多数決を用いたアンサンブル学習の概念を図にすると、下記のようになります。
例えば、Titanicチュートリアルのような、「生存か死亡」を「1か0」で分類するようなタスクで考えてみましょう。分類器A~Cは、それぞれ別々に学習させて出来上がった、精度はほどほどの分類器(=弱学習器)です。「別々に学習させる」方法には、いろいろなバリエーションがあります。
- (A)同じ学習データを使って、異なるモデル(Random Forest、SVM、Logistic Regression等)を使って作った分類器A~C
- (B)同じモデルだが、分割した学習データで異なる学習をほどこした分類器A~C
- (C)同じ学習データ、同じ学習モデルを使うが、利用する特徴量の組み合わせやモデルのパラメータを変えて生成した分類器A~C
などなど。実装の手間としては、(A) > (B) > (C)になるでしょうか。これら、分類器A~Cにテストデータをそれぞれ入力し、それぞれ別の分類結果A~Cを生成します。そして、得られた3つの分類結果を見ながら、3つの分類結果のうち過半数を超えたものを、最終的な分類結果として統合していきます。
例えば、分類器A~Cの分類結果が
- 分類器A :[0, 0, 1, 0, 1]
- 分類器B :[0, 1, 0, 0. 1]
- 分類器C :[1, 1, 0, 0, 1]
だった場合、それぞれのリストの要素について多数決をとるので、 - 投票結果:[0, 1, 0, 0, 1]
になります。なんとなく「3人寄れば文殊の知恵」という感じがしますね。
さて、雰囲気としては「なんとなく良くなりそう」と思えるかもしれませんが、どのくらい良くなるもんなんでしょうか。正解率0.7くらいの凡人が3人寄り集まって投票したところで、そんなに正解に近づけるものなんでしょうか。また、3人よりも多くの凡人が集まるとどうなるでしょうか。凡人なのだから何人寄り集まっても大したことはできないのか、「集合知」のように、凡人の意見も集めてみると案外正しい結果になるのでしょうか。そこで、簡単なプログラムを書いて、確かめてみましょう。
多数決によるアンサンブル学習の実装
概念図の説明ではTitanicの例を持ち出しましたが、実際の分類タスクだと「特徴量の選択がどうのこうの」「データの傾向がどうのこうの」「欠損値がどうのこうの」といった議論に目が奪われてしまいます。そこで、ここでは、単純にランダムに生成される0, 1のリストを正解(Ground Truth)データとして、これを多数決で当てにいく、という体で実験をしていきます。つまり、正解を求めるために使えるドメイン知識や特徴量は何もなく、ただひたすらに、[0, 0, 1, 0, 1, ..., 0, 0]といった、各分類器の出力結果をうまく活用して正解を導き出す、という試みになります。
まずは、正解データ(Ground Truth)を乱数で生成します。最初は少ない次元数で実験して、後から次元数を増やしてどうなるか確認したいので、次元数をDIMとして設定することにします。また、Titanicの例では死亡率が6割程度なので、それをちょっと意識して、0/1を等確率ではなく、6割が0、4割が1になるように生成します。その際の確率をPROBで指定します。
# 正解データ(Ground Truth)をランダムに決める(0/1の生成確率は6:4にする)
import random
DIM = 20
PROB = 60
def make_ground_truth(k, prob):
return [0 if random.random() * 100 < prob else 1 for x in range(k)]
gt = make_ground_truth(DIM, PROB)
gt
結果はこんな感じになりますね。
[0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1]
次に、弱学習器の結果を生成します。まずは弱学習器の数は3つから初めましょうか。また、正解率0.75くらいの弱学習器にしたいので、50%の確率で正解データをコピーして、50%の確率で0か1をランダムに決定することにします。(これで正解率の期待値が0.5 + 0.5 × 0.5 = 0.75になります)この確率はACCで指定することにします。
# 弱学習器の出力を生成する(平均正解率を0.75程度とする)
ACC = 50
def make_trial_list(k, n, gt, acc):
trial_list = []
for i in range(n):
data = [gt[x] if random.random() * 100 < acc else random.randint(0, 1) for x in range(k)]
trial_list.append(data)
return trial_list
trial_list = make_trial_list(DIM, 3, gt, ACC)
trial_list
結果はこんな感じ。
[[0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 0],
[0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1],
[0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0]]
それでは、3つの弱学習器の精度を求めてみましょう。
# 各弱学習器の精度を計算する
def eval_acc(gt, trial_data):
score = 0
for i in range(len(gt)):
if gt[i] == trial_data[i]:
score += 1
return score / len(gt)
def get_acc_list(gt, trial_list):
acc_list = [eval_acc(gt, trial_data) for trial_data in trial_list]
return acc_list
acc_list = get_acc_list(gt, trial_list)
acc_list
結果はこうなりました。まあ、ほどほどの凡人が集まっている感じですね。
[0.8, 0.75, 0.65]
それでは、3つの弱学習器の結果を多数決で統合しましょう。
# 多数決で結果を統合する
def merge_trial_list(trial_list):
res = trial_list[0].copy()
for i in range(len(trial_list[0])):
score = 0
for j in range(len(trial_list)):
score += trial_list[j][i]
if score > len(trial_list) / 2:
res[i] = 1
else:
res[i] = 0
return res
predict = merge_trial_list(trial_list)
predict
多数決で統合された分類結果は下記の通り。
[0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0]
いよいよ、正解データに対する多数決の正解率を求めてみましょう。
eval_acc関数がそのまま使えますね。
# 統合された分類結果の正解率を求める
mrgd_acc = eval_acc(gt, predict)
mrgd_acc
果たして結果は...?
0.8
でした!
「え~!なにそれ?」と思った方、ご心配なく。私もがっかりしましたから。「弱学習器Aの正解率は0.8だったのだから、何も良くなってないじゃん!」と思うのも仕方ありません。ただ、「やっぱり凡人が集まっても凡人なりの結果しか出ない」と結論づけるのは、ちょっと待ってください。そう、今回はあくまで、「3つの弱学習器を統合した結果」です。n=3を、もっとどんどん増やしていったらどうなるでしょうか?
そこで、弱学習器の数を、3つから初めて、段階的にMAX_UNIT個まで増やしながら精度を評価するプログラムを実装しましょう。こんな感じです。
# 弱学習器を3つ~MAX_UNIT個まで増やしながら精度を評価する
import numpy as np
import matplotlib.pyplot as plt
MAX_UNIT = 10
def main():
gt = make_ground_truth(DIM, PROB)
print("ground truth:", gt)
mrgd_acc_list = []
for i in range(3, MAX_UNIT + 1):
trial_list = make_trial_list(DIM, i, gt, ACC)
acc_list = get_acc_list(gt, trial_list)
print("\nnumber of units = ", i)
for i in range(len(trial_list)):
print(" unit", i, ": ", trial_list[i])
print(" acc_list =", acc_list)
avrg_acc = sum(acc_list)/len(acc_list)
print(' avrg_acc = {:.3f}'.format(avrg_acc))
predict = merge_trial_list(trial_list)
mrgd_acc = eval_acc(gt, predict)
mrgd_acc_list.append(mrgd_acc)
print(" mrgd_res:", predict)
print(' mrgd_acc = {:.3f}'.format(mrgd_acc))
plot_x = np.array([i for i in range(3, MAX_UNIT + 1)])
plot_y = np.array(mrgd_acc_list)
plt.plot(plot_x, plot_y)
main()
結果は下記のようになります。今までの処理の流れをそのまま表示しています。
項目の意味は下記の通りです。
ground truth:正解データ
unit 0~n:弱学習器の分類結果
acc_list:各弱学習器の正解率
avrg_acc:弱学習器の平均正解率
mrgd_res:多数決で統合された分類結果
mrgd_acc:多数決で統合された分類結果の正解率
ground truth: [0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0]
number of units = 3
unit 0 : [0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0]
unit 1 : [1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 0]
unit 2 : [0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1]
acc_list = [0.6, 0.65, 0.55]
avrg_acc = 0.600
mrgd_res: [0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0]
mrgd_acc = 0.750
number of units = 4
unit 0 : [0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0]
unit 1 : [0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0]
unit 2 : [0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 0]
unit 3 : [0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0]
acc_list = [0.85, 0.7, 0.85, 0.75]
avrg_acc = 0.787
mrgd_res: [0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 0]
mrgd_acc = 0.900
number of units = 5
unit 0 : [0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1]
unit 1 : [0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0]
unit 2 : [0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 1]
unit 3 : [0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0]
unit 4 : [1, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0, 1]
acc_list = [0.55, 0.8, 0.9, 0.85, 0.7]
avrg_acc = 0.760
mrgd_res: [0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 0, 0, 1]
mrgd_acc = 0.900
number of units = 6
unit 0 : [0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 1]
unit 1 : [0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0]
unit 2 : [1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0]
unit 3 : [0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1]
unit 4 : [0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0]
unit 5 : [1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0]
acc_list = [0.85, 0.85, 0.6, 0.65, 0.8, 0.55]
avrg_acc = 0.717
mrgd_res: [0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0]
mrgd_acc = 0.900
number of units = 7
unit 0 : [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0]
unit 1 : [0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0]
unit 2 : [0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0]
unit 3 : [0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0]
unit 4 : [1, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0]
unit 5 : [1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0]
unit 6 : [0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 0, 1]
acc_list = [0.85, 0.75, 0.6, 0.65, 0.85, 0.85, 0.75]
avrg_acc = 0.757
mrgd_res: [0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0]
mrgd_acc = 1.000
number of units = 8
unit 0 : [0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0]
unit 1 : [0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0]
unit 2 : [0, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0]
unit 3 : [1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0]
unit 4 : [1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0]
unit 5 : [1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0]
unit 6 : [0, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 0]
unit 7 : [1, 1, 1, 1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1]
acc_list = [0.9, 0.75, 0.8, 0.75, 0.75, 0.85, 0.7, 0.6]
avrg_acc = 0.762
mrgd_res: [0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0]
mrgd_acc = 0.950
number of units = 9
unit 0 : [0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0]
unit 1 : [0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 1]
unit 2 : [0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0]
unit 3 : [1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0]
unit 4 : [0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1]
unit 5 : [0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1]
unit 6 : [0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 0, 1]
unit 7 : [1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0]
unit 8 : [1, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 1]
acc_list = [0.8, 0.85, 0.8, 0.7, 0.75, 0.65, 0.8, 0.85, 0.45]
avrg_acc = 0.739
mrgd_res: [0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1]
mrgd_acc = 0.900
number of units = 10
unit 0 : [1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1]
unit 1 : [1, 1, 0, 1, 1, 1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0]
unit 2 : [1, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 0, 0, 1]
unit 3 : [0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0]
unit 4 : [0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0]
unit 5 : [0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0]
unit 6 : [0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0]
unit 7 : [0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1]
unit 8 : [0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0]
unit 9 : [0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0]
acc_list = [0.55, 0.65, 0.8, 0.75, 0.6, 0.7, 0.75, 0.75, 0.75, 0.65]
avrg_acc = 0.695
mrgd_res: [0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0]
mrgd_acc = 0.950
どうでしょうか?弱学習器の数が増えるにしたがって、多数決で統合された分類結果の正解率mrgd_accもどんどん向上していってますね。10個の弱学習器を使った場合は、弱学習器の平均の正解率は0.695しかないのに、多数決の結果の正解率は0.95にもなっています。いやあ、素晴らしい。
弱学習器の数に対する精度の変化をグラフにすると、下記のようになりました。
(横軸が弱学習器の数、縦軸が正解率)
弱学習器が3つの時の正解率は0.75で、さほど効果が出ていなくても、弱学習器を5つ以上にすると、安定して精度が0.9を超えていますね。
さて、ここまでは理解を簡単にするために、分類結果の次元数を20くらいに制限して実験してきました。Titanicのタスクでは、約400人分の生死を判別するため、分類結果の次元数は400くらいになります。果たして、それぐらい次元数が増えても、多数決は効果があるのでしょうか?
先ほどのソースコードの、DIMを400に、MAX_UNITを20に変更して実験してみましょう。下記のようなグラフが得られました。若干サチって来ますが、弱学習器を増やすほど正解率が向上していますね。
個々の弱学習器の能力は貧弱なのに、たくさん寄せ集めるだけでこんなに鮮やかに正解率が向上するなんてすばらしいですよね。感動します。これは、個別の弱学習器の出力には、エラーが含まれていても、そのエラーの分布がきれいに分散していると、エラー同士がお互い打ち消しあって、正しい結果が浮かび上がってくるのだと思います。
ただ、これが成立するのは本当に「各弱学習器が独立に学習し、出力結果がきれいに分散している」場合に限られます。実際には、使っている特徴量の影響や、モデルのくせによって、各弱学習器の出力に偏りが存在しており、それがバイアスとなって精度向上の妨げになってしまうので、なるべく「均等に分散するような弱学習器」を作っていくのが、「多数決」によるアンサンブル学習のコツになるかと思います。
イメージだけでは、なかなかそこまで精度向上するとは思えないのですが、実際に確かめてみると「原理的にはこんなに向上するんだ」ということが分かって面白いですね。Kaggleでも「最後の一押し」として、「多数決」によるアンサンブル学習を取り入れると良いと思います。