3
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

勾配ブースティング木による侵入検知分類器の作成

Last updated at Posted at 2019-03-12

勾配ブースティング木とは

##アンサンブル学習法
アンサンブル学習は、弱学習器と呼ばれる分類性能の低い分類器を複数組み合わせ、それらの判定結果を集約し多数決をとったものを出力する方法です。

例えばあるアンケートを100人に実施し、そのアンケートのある質問に「はい」または「いいえ」で答えるようなものを考える。どちらかが多くなれば、多いほうの意見を採択するのが普通です。このように、ある問題に対して答えが2種類に分かれるとき,それは機械学習の分野では「2値分類」もしくは「2クラス分類」といいます。

アンサンブル学習には、バギング、ブースティング、スタッキングの3種類(しか知らないのですが他にあるんでしょうか…?)があり、アンサンブル学習で有名なランダムフォレストはバギングにあたります。勾配ブースティング木はその名の通りブースティングにあたります。

##勾配ブースティング木
弱学習器に決定木を用いていることからこのような名前が付けられています。先に述べたように、勾配ブースティング木はブースティングにあたるのですが、では一体ブースティングは他の方法と何が違うのでしょうか?

###ブースティングの概要
これを理解するには、バギングやスタッキングの概要を大まかに知っていなければなりません。ので先に大まかに説明します。
####バギング
入力されたトレーニングデータセットから、重複を許してブートストラップ標本を弱学習器の数だけ作成し、これを各弱学習器に与えたあと、すべての弱学習器の多数決投票を行って最終結果を得ます。しかしこれは分類の時のみであって、回帰の時は各弱学習器の平均をとって最終結果を得ます。
####スタッキング
長くなるのでこちらの記事を参照してください。
####ブースティング
ブースティングはバギングのように弱学習器を複数重ねてそれぞれがデータセットを学習する学習器を形成しますが、バギングと大きく異なるのは「各弱学習器がつながっている」ということです。具体的には、前の弱学習器の誤分類情報を優先的に学習し,その弱学習器の誤分類情報を次の弱学習器に学習させる…というようにして弱学習器の数だけこの操作を行います。ブースティングではパラメータ設定がシビアな反面、バギングよりも汎化性能が高くなる傾向があります。

#機械学習は特徴量選択と学習器パラメータ調整が命
といいつつも、ある程度適当にやってもそれなりの数字が出るのがアンサンブル学習のすごいところ。
わたしは機械学習を勉強していく中でその点に惹かれて勾配ブースティングを試すに至ったというでたらめな動機ですが、大目に見てください・・・^^;

#学習及び検証用データセット
##Kyoto 2016 Dataset
攻撃検知性能の評価のために適したデータセットを用いなければ意味がない。
データの収集期間の長さや、最新の攻撃傾向を反映できているデータセットを探した結果、前記事でも紹介した「Kyoto 2016 Dataset」を用いることにします。特徴量は全部で24種類ありますが、そのうち用いたのは論文内で用いられていた特徴量のみに絞ります。

#実験方法
##計算機仕様

種類 名前
OS Ubuntu 18.04.1 LTS
CPU Intel Core i9-9900K
RAM 64GB
RAMは私のプログラムの書き方がヘタクソ過ぎて、最大30GBくらい食ったような…
##実装
XGBoostを使いました。作業ディレクトリの中に学習・検証したい期間のデータセットを入れておく⇛読み込ませて整形⇛学習⇛検証という流れになっています。
学習及び検証時に使うデータには、攻撃サンプル1万件と正常サンプル1万件をランダムに抽出し、ばらつきを防ぐために10回試行しました。
分類精度に関する指標として、正解率(Accuracy)、適合率(Precision)、検知率(True Positive Rate : TPR)、誤検知率(False Positive Rate : FPR)を算出しました。
各指標の定義式は下のとおりです。
Accuracy=\frac{TP+TN}{TP+FP+TN+FN}\\
Precision=\frac{TP}{TP+FP}\\
TPR=\frac{TP}{TP+FN}\\
FPR=\frac{FP}{FP+TN}\\

Github

Githubに載せたコードは同時にランダムフォレストとscikit-learnのGradientBoostingClassifierクラスの結果も算出されるようになっています。

gbt.py
import numpy as np
import gc
import os
import sys
import time
import random
import xgboost as xgb
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report
from sklearn.metrics import roc_auc_score, precision_recall_curve, auc, roc_curve

#===== Designation of target year / month =====#
monthA = ["200611", "200612"]                  #
monthB = ["200701", "200702"]                  #
monthC = ["200703", "200704"]                  #
monthD = ["200705", "200706"]                  #
monthE = ["200707", "200708"]                  #
monthF = ["200709", "200710"]                  #
monthG = ["200711", "200712"]                  #
monthH = ["200801", "200802"]                  #
monthI = ["200803", "200804"]                  #
monthJ = ["200805", "200806"]                  #
monthK = ["200807", "200808"]                  #
monthL = ["200809", "200810"]                  #
monthM = ["200811", "200812"]                  #
monthN = ["201311", "201312"]                  #
monthO = ["201401", "201402"]                  #
monthP = ["201403", "201404"]                  #
monthQ = ["201405", "201406"]                  #
monthR = ["201407", "201408"]                  #
monthS = ["201409", "201410"]                  #
monthT = ["201411", "201412"]                  #
monthU = ["201501", "201502"]                  #
monthV = ["201503", "201504"]                  #
monthW = ["201505", "201506"]                  #
monthX = ["201507", "201508"]                  #
monthY = ["201509", "201510"]                  #
monthZ = ["201511", "201512"]                  #
#==============================================#



#================================= CREATE DATA =================================#
def create_data(month):                                                         #
    pos = []                                                                    #
    neg = []                                                                    #
    for mon in month:                                                           #
        for i in range(1, 32):                                                  #
            data = mon + ('%02d.txt' %i)                                        #
            if os.path.exists(data):                                            #
                for line in open(data, 'rb'):                                   #
                    ls = line.split()                                           #
                    a = int(ls[17])     #                                       #
                    b = float(ls[0])    #                                       #
                    c = float(ls[2])    #                                       #
                    d = float(ls[3])    #                                       #
                    e = float(ls[4])    #                                       #
                    f = float(ls[5])    #                                       #
                    g = float(ls[6])    #                                       #
                    h = float(ls[7])    #                                       #
                    i = float(ls[8])    #                                       #
                    j = float(ls[9])    #                                       #
                    k = float(ls[10])   #                                       #
                    l = float(ls[11])   #                                       #
                    m = float(ls[12])   #                                       #
                    if a > 0:                                                   #
                        a = 1                                                   #
                        pos.append([b, c, d, e, f, g, h, i, j, k, l, m])        #
                    else:                                                       #
                        a = 0                                                   #
                        neg.append([b, c, d, e, f, g, h, i, j, k, l, m])        #
    random.shuffle(pos)                                                         #
    random.shuffle(neg)                                                         #
    return pos, neg                                                             #
#===============================================================================#

process_start = time.time()
start = time.time()
pos, neg = create_data(monthT)
end = time.time()
print("===== Training Data Reading Time =====")
print("{0}[s]".format(end - start))
start = time.time()
post, negt = create_data(monthU)
end = time.time()
print("===== Test Data Reading Time =====")
print("{0}[s]".format(end - start))

xacc = 0
xpre = 0 
xtpr = 0
xfpr = 0
xf1 = 0
xtime = 0

xgb = xgb.XGBClassifier(learning_rate =0.45,
                        n_estimators=90,
                        max_depth=6,
                        min_child_weight=1,
                        gamma=0,
                        subsample=0.75,
                        colsample_bytree=0.8,
                        objective= 'binary:logistic',
                        nthread=4,                   
                        scale_pos_weight=1,
                        seed=27)

for i in range(10):
    print("\n------------------------- TRIAL {0} -------------------------".format(i + 1))
    
    random.shuffle(pos)
    random.shuffle(neg)
    X = neg[0:10000] + pos[0:10000]
    X_train = np.array(X)
    y = [0]*10000 + [1]*10000
    y_train = np.array(y)

    random.shuffle(post)
    random.shuffle(negt)
    Xt = negt[0:10000] + post[0:10000]
    X_test = np.array(Xt)
    yt = [0]*10000 + [1]*10000
    y_test = np.array(yt)

#============================== XGBoost ============================#
    xgb = xgb.fit(X_train, y_train)                                 #
    y_train_pred = xgb.predict(X_train)                             #
    xtest_start = time.time()                                       #
    y_test_pred = xgb.predict(X_test)                               #
    xtest_end = time.time()                                         #
    xgb_train = accuracy_score(y_train, y_train_pred)               #
    tp, fn, fp, tn = confusion_matrix(y_test, y_test_pred).ravel()  #
    TP = float(tp)                                                  #
    FP = float(fp)                                                  #
    TN = float(tn)                                                  #
    FN = float(fn)                                                  #
    acc = 100*(TP+TN)/(TP+FP+TN+FN)                                 #
    pre = 100*TP/(TP+FP)                                            #
    tpr = 100*TP/(TP+FN)                                            #
    fpr = 100*FP/(FP+TN)                                            #
    F1 = 2*pre*tpr/(pre+tpr)                                        #
    timer = xtest_end - xtest_start                                 #
    xacc += acc                                                     #
    xpre += pre                                                     #
    xtpr += tpr                                                     #
    xfpr += fpr                                                     #
    xf1 += F1                                                       #
    xtime += timer                                                  #
    print("\n===== XGBoost =====")                                  #
    print("TP\tFP\tTN\tFN")                                         #
    print("{0}\t{1}\t{2}\t{3}".format(tp, fp, tn, fn))              #
    print("\nXGBoost train = {0}".format(xgb_train))                #
    print("XGBoost Accuracy = {0}".format(acc))                     #
    print("XGBoost Precision = {0}".format(pre))                    #
    print("XGBoost Recall (TPR) = {0}".format(tpr))                 #
    print("XGBoost FPR = {0}".format(fpr))                          #
    print("XGBoost F1-Score = {0}".format(F1))                      #
    print("Test time = {0}".format(timer))                          #
#===================================================================#

print("\n\n= = = = = = = = = = RESULT = = = = = = = = = =")

print("\n==== XGBoost ====")
print("\tAccuracy = {0}".format(xacc/10))
print("\tPrecision = {0}".format(xpre/10))
print("\tTrue Positive Rate = {0}".format(xtpr/10))
print("\tFalse Positive Rate = {0}".format(xfpr/10))
print("\tF1-Score = {0}".format(xf1/10))
print("\tTest time = {0}".format(xtime/10))

print("\n= = = = = = = = = = = = = = = = = = = = = = =")

process_end = time.time()
print("\n===== PROCESS TIME =====")
time = process_end - process_start
minute = int(time / 60)
second = time % 60
print("{0}:{1}".format(minute, second))

#実験結果

正解率 適合率 検知率 誤検知率
96.22% 95.45% 97.46% 5.10%

各期間で学習し、各期間で検証した平均値です。内訳は省きましたが、かなりいい精度で分類できました。

#展望
特徴量選択をもうちょっと頑張ってみようかなと思います。

3
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?