1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

KaggleのHouse Sales in King County, USAのデータセットを使って、機械学習を行い、APIサーバーにするまでの手順

Last updated at Posted at 2020-11-01

##はじめに
KaggleのHouse Sales in King County, USAのデータセットを使って、XGboost機械学習で学習モデルを生成して、その学習モデルをFlaskでAPIサーバーにするというのをやりました。この機械学習のAPIサーバーは、主に4つの手順で行っています。最初にHouse Sailsのデータを把握するために、EDA(Explanatory Data Analysis)探索的データ解析を行って、データの状況を把握します。次に、機械学習で学習させるためのデータになるように前処理を行っています。その次に、機械学習で学習モデルを生成します。今回は、XGboostを使用しています。最後に、FlasでAPIサーバーを実装していきます。

このプログラムを実行するのに必要な環境

Anaconda、XGBoost、joblib、Flask、flask-corsなどのライブラリがインストールされている。

この機械学習によるAPIサーバーの実装は以下の4つの処理によって行っています。

  • House Sailsのデータを把握する(EDA)
  • データセットの前処理を行う
  • 機械学習で学習モデルを作る
  • flaskでAPIサーバーを作る

ライブラリとデータセットの読み込み

まずは、必要なラブラリとKaggleからダウンロードしてきたHouse Sailsのデータセットを読み込んできます。set_optionを指定すると、表示するカラム数を指定することができます。

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
pd.set_option('display.max_columns', 4)
df = pd.read_csv('house_sales/kc_house_data.csv')
df.head()
id date ... sqft_living15 sqft_lot15
0 7129300520 20141013T000000 ... 1340 5650
1 6414100192 20141209T000000 ... 1690 7639
2 5631500400 20150225T000000 ... 2720 8062
3 2487200875 20141209T000000 ... 1360 5000
4 1954400510 20150218T000000 ... 1800 7503
5 rows × 21 columns

sqft_livingのヒストグラムを表示してみる。

まずは、特徴量のうち、sqft_livingのヒストグラムを表示しています。これを見えると、大きな値があるので、標準正規分布の形からはずれていることがわかります。

plt.figure(figsize = (12,8))
plt.hist(df["sqft_living"])
plt.savefig('House_Sales_Explanatory Data Analysis_hist01.png', bbox_inches='tight')

House_Sales_Explanatory Data Analysis_hist01.png

priceのヒストグラムを表示してみる。

今回予測する値となる、priceについてもヒストグラムを表示してみます。これも、先ほどと同様に、大きな値、外れ値があるため、左によった形状をしています。

plt.figure(figsize = (12,8))
plt.hist(df["price"])
plt.savefig('House_Sales_Explanatory Data Analysis_hist02.png', bbox_inches='tight')

House_Sales_Explanatory Data Analysis_hist02.png

四分位範囲の処理を行って外れ値を削除する

この関数は、四分位範囲の処理をしています。これによって、外れ値となるデータを削除します。

def outlier_iqr(df, columns = None):
    if columns == None:
        columns = df.columns
        
    for col in columns:
        q1 = df[col].describe()['25%']
        q3 = df[col].describe()['75%']
        
        iqr = q3 - q1
        
        outlier_min = q1 - iqr * 1.5
        outlier_max = q3 + iqr * 1.5
        
        df = df[(df[col] >= outlier_min) & (df[col] <= outlier_max)]
        
    return df
df_1 = outlier_iqr(df, columns = ['price'])
df_1.shape

(20454, 21)

再度priceをヒストグラムで表示してみる。

四分位範囲の処理を行ったpriceのヒストグラムをもう一度表示してみます。今度は、外れ値を削除したので、正規分布に近い形状のヒストグラムになっています。shapeでデータの形状を確認してみると、先ほどよりは多少データの数は減っていますが、大きくは減っていないことも確認できます。

plt.figure(figsize = (12,8))
plt.hist(df_1["price"])
plt.savefig('House_Sales_Explanatory Data Analysis_hist03.png', bbox_inches='tight')

House_Sales_Explanatory Data Analysis_hist03.png

df.shape

(21613, 21)

数値データのヒストグラムを全部表示してみる。

一応、確認のために、他のデータのヒストグラムの形状も確認してみます。概ね、バランスのとれたヒストグラムの形状になっていることが確認できます。

fig, axes = plt.subplots(2,3, figsize = (18, 12))

axes.ravel()[0].hist(df_1["sqft_living"])
axes.ravel()[1].hist(df_1["sqft_above"])
axes.ravel()[2].hist(df_1["sqft_basement"])
axes.ravel()[3].hist(df_1["lat"])
axes.ravel()[4].hist(df_1["long"])
axes.ravel()[5].hist(df_1["sqft_living15"])

axes.ravel()[0].set_title("sqft_living")
axes.ravel()[1].set_title("sqft_above")
axes.ravel()[2].set_title("sqft_basement")
axes.ravel()[3].set_title("lat")
axes.ravel()[4].set_title("long")
axes.ravel()[5].set_title("sqft_living15")

plt.savefig('House_Sales_Explanatory Data Analysis_hist04.png', bbox_inches='tight')

House_Sales_Explanatory Data Analysis_hist04.png

不必要な特徴量を削除する。

ここでは、学習の際に不要となる特徴量を削除します。ここでは、id、date、sqft_lot、sqft_lot15、zipcodeを削除しています。

df_1 = df_1.drop(columns = ['id', 'date', 'sqft_lot','sqft_lot15','zipcode'])

建築年を築年数に変える。

yr_builtという特徴量は、建築年なので、建物が建てられた年がデータにあります。このままだと、学習データとして扱うのは難しいので、築年数とした新たな特徴量を加えています。また、yr_renovatedも改装された年になっているので、これも改装されてからの年数に変えています。

df_1["age"] = 2020 - df_1["yr_built"]
df_1.loc[(df_1['yr_renovated'] == 0), 'yr_renovated'] = 2020

数値の特徴量を標準化する。

数値のデータは、ここではStandardScalerを使って標準化しています。

from sklearn.preprocessing import StandardScaler
num_feature = ['sqft_living', 'sqft_above', 'sqft_basement', 'lat', 'long', 'sqft_living15']

for col in num_feature:
    scaler = StandardScaler()
    df_1[col] = scaler.fit_transform(np.array(df_1[col].values).reshape(-1, 1))

再度、数値データのヒストグラムを全部表示してみる。

一応、再度、数値データのヒストグラムを確認しています。先ほど、表示した形状と変わりがないことが確認できます。

fig, axes = plt.subplots(2,3, figsize = (18, 12))

axes.ravel()[0].hist(df_1["sqft_living"])
axes.ravel()[1].hist(df_1["sqft_above"])
axes.ravel()[2].hist(df_1["sqft_basement"])
axes.ravel()[3].hist(df_1["lat"])
axes.ravel()[4].hist(df_1["long"])
axes.ravel()[5].hist(df_1["sqft_living15"])

axes.ravel()[0].set_title("sqft_living")
axes.ravel()[1].set_title("sqft_above")
axes.ravel()[2].set_title("sqft_basement")
axes.ravel()[3].set_title("lat")
axes.ravel()[4].set_title("long")
axes.ravel()[5].set_title("sqft_living15")

plt.savefig('House_Sales_Explanatory Data Analysis_hist05.png', bbox_inches='tight')

House_Sales_Explanatory Data Analysis_hist05.png

正解データをcsvで保存する

最後に、機械学習用のデータとしてCSVに一旦保存しておきます。

df_price = df_1["price"]
df_price.to_csv('House_Sales_Explanatory_Price.csv')

学習データをcsvで保存する

学習データは、不必要なデータを削除して、カテゴリカルなデータは、get_dummiesでダミー変数に変換したものをcsvとして保存しています。

df_1 = df_1.drop(columns = ['price', 'yr_built', 'yr_renovated'])
df_1 = pd.get_dummies(df_1, columns = ['bedrooms', 'bathrooms', 'floors', 'waterfront', 'view', 'condition', 'grade', 'age', 'renovated_age'], drop_first = True)
df_1.to_csv('House_Sales_Explanatory_Preprocessing.csv')

前処理を行ったデータを使って機械学習を行う。

ここでは、先ほどまでに行った前処理をしたデータを使って機械学習をしていきます。なので、別のプロジェクトとしておいたほうが良いです。

必要なライブラリをインポートする

あらためて今回の処理に必要なライブラリをインポートしてきます。

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

CSVファイルをpandasとして読み込む

先ほど、保存した前処理済みのCSVファイルをpandasとして読み込みます。Unnamedという不要な特徴量があったので、それは削除しておきます。set_optionで表示される特徴量は4つにしています。

pd.set_option('display.max_columns', 4)
df = pd.read_csv('House_Sales_Explanatory_Preprocessing.csv')
df = df.drop(columns = ['Unnamed: 0'])
df.head()
sqft_living sqft_above ... renovated_age_80 renovated_age_86
0 -1.026685 -0.725963 ... 0
1 0.769106 0.635702 ... 0
2 -1.556379 -1.289885 ... 0
3 -0.018975 -0.904768 ... 0
4 -0.380717 -0.038253 ... 0

CSVファイルをpandasとして読み込む(price)

正解データとなるpriceも同じにように読み込んできます。

price
0 221900.0
1 538000.0
2 180000.0
3 604000.0
4 510000.0
df_price = pd.read_csv('House_Sales_Explanatory_Price.csv', header=None, names=['price'])
df_price.head()

使用する特徴量だけのデータフレームを作る。

今回の機械学習では、事前の学習の結果、前処理を行ったすべてのデータを使った学習モデルと、重要度の高くない特徴量を削除したデータを使った学習モデルでの学習精度に大きな差が出なかったこともあって、以下の特長量だけを使うことにしています。基本的には数値系の特長量はすべて使い、カテゴリカルな特長量に関しては、gradeのみを残して、それ以外のカテゴリカルなデータを削除しています。理由としては、最終的に機械学習モデルを使ったアプリケーションを開発する際に、複数のカテゴリカルなデータのフロントエンドの実装に手間がかかると予想したことが主な理由です。

df = df[["sqft_living","sqft_above","sqft_basement","lat","long","sqft_living15","grade_3","grade_4","grade_5","grade_6","grade_7"
        ,"grade_8","grade_9","grade_10","grade_11","grade_12"]]
sqft_living sqft_above ... grade_11 grade_12
0 -1.026685 -0.725963 ... 0 0
1 0.769106 0.635702 ... 0 0
2 -1.556379 -1.289885 ... 0 0
3 -0.018975 -0.904768 ... 0 0
4 -0.380717 -0.038253 ... 0 0
5 rows × 16 columns

XGboostで学習モデルを生成する

ここでは、XGboostをインポートして、機械学習を行っています。パラーメータは、ほぼデフォルトのままです。

import xgboost as xgb
X_train, X_test, y_train, y_test = train_test_split(df, df_price, random_state = 0)
params = {
    'silent' : 1,
    'max_depth' : 6,
    'min_chiled_weight' : 1,
    'eta' : 0.1,
    'tree_method' : 'exact',
    'objective' : 'reg:linear',
    'eval_metric' : 'rmse',
    'predictor' : 'cpu_predictor'
}

dtrain = xgb.DMatrix(X_train, label = y_train)
dtest = xgb.DMatrix(X_test, label = y_test)

model = xgb.train(params = params,
                 dtrain = dtrain,
                 num_boost_round = 200,
                 early_stopping_rounds = 10,
                 evals = [(dtest, 'test')])
[0]	test-rmse:471544
Will train until test-rmse hasn't improved in 10 rounds.
[1]	test-rmse:427350
[2]	test-rmse:387757
[3]	test-rmse:352314
[4]	test-rmse:320602
[5]	test-rmse:292132
[6]	test-rmse:266667
[7]	test-rmse:244148
[8]	test-rmse:223983
[9]	test-rmse:206046
[10]	test-rmse:190112
[11]	test-rmse:176111
[12]	test-rmse:163754
[13]	test-rmse:152820
[14]	test-rmse:143269
[15]	test-rmse:134879
[16]	test-rmse:127772
[17]	test-rmse:121362
[18]	test-rmse:115939
[19]	test-rmse:111405
[20]	test-rmse:107280
[21]	test-rmse:103750
[22]	test-rmse:100928
[23]	test-rmse:98446.5
[24]	test-rmse:96280.3
[25]	test-rmse:94419.2
[26]	test-rmse:92933.6
[27]	test-rmse:91644.1
[28]	test-rmse:90581.3
[29]	test-rmse:89422.8
[30]	test-rmse:88575.7
[31]	test-rmse:88038.8
[32]	test-rmse:87254.6
[33]	test-rmse:86857.1
[34]	test-rmse:86527.8
[35]	test-rmse:86238.3
[36]	test-rmse:85950
[37]	test-rmse:85705
[38]	test-rmse:85532.4
[39]	test-rmse:85346.7
[40]	test-rmse:85204.1
[41]	test-rmse:85058.9
[42]	test-rmse:84926.7
[43]	test-rmse:84845.4
[44]	test-rmse:84671.9
[45]	test-rmse:84539.6
[46]	test-rmse:84380.6
[47]	test-rmse:84287.2
[48]	test-rmse:84254.7
[49]	test-rmse:84168.9
[50]	test-rmse:84106.6
[51]	test-rmse:83858.5
[52]	test-rmse:83829.8
[53]	test-rmse:83809.5
[54]	test-rmse:83726
[55]	test-rmse:83704.2
[56]	test-rmse:83650.4
[57]	test-rmse:83422.6
[58]	test-rmse:83405.8
[59]	test-rmse:83281
[60]	test-rmse:83293.6
[61]	test-rmse:83289.4
[62]	test-rmse:83251.9
[63]	test-rmse:83237.5
[64]	test-rmse:83055.6
[65]	test-rmse:83051.9
[66]	test-rmse:82938.8
[67]	test-rmse:82932.7
[68]	test-rmse:82933.2
[69]	test-rmse:82859
[70]	test-rmse:82829.6
[71]	test-rmse:82840.5
[72]	test-rmse:82823
[73]	test-rmse:82827.4
[74]	test-rmse:82834.6
[75]	test-rmse:82845.9
[76]	test-rmse:82839.4
[77]	test-rmse:82828.5
[78]	test-rmse:82829.7
[79]	test-rmse:82651.8
[80]	test-rmse:82660
[81]	test-rmse:82637.3
[82]	test-rmse:82514.6
[83]	test-rmse:82497.6
[84]	test-rmse:82484.7
[85]	test-rmse:82486.3
[86]	test-rmse:82486.8
[87]	test-rmse:82496
[88]	test-rmse:82491.4
[89]	test-rmse:82486.6
[90]	test-rmse:82290.3
[91]	test-rmse:82265.1
[92]	test-rmse:82261.5
[93]	test-rmse:82236.5
[94]	test-rmse:82236.4
[95]	test-rmse:82111.9
[96]	test-rmse:82111.1
[97]	test-rmse:82111.3
[98]	test-rmse:82108
[99]	test-rmse:82097.1
[100]	test-rmse:82077.4
[101]	test-rmse:82041.9
[102]	test-rmse:82040
[103]	test-rmse:82042.6
[104]	test-rmse:82044.2
[105]	test-rmse:82033.7
[106]	test-rmse:82041.1
[107]	test-rmse:82028.4
[108]	test-rmse:82030.7
[109]	test-rmse:82036.4
[110]	test-rmse:82028.6
[111]	test-rmse:82020.3
[112]	test-rmse:82025.5
[113]	test-rmse:82024.9
[114]	test-rmse:82034
[115]	test-rmse:82025.2
[116]	test-rmse:81957.5
[117]	test-rmse:81950.9
[118]	test-rmse:81959.8
[119]	test-rmse:81936.7
[120]	test-rmse:81935.9
[121]	test-rmse:81937
[122]	test-rmse:81945.8
[123]	test-rmse:81894.8
[124]	test-rmse:81885.2
[125]	test-rmse:81899.3
[126]	test-rmse:81877
[127]	test-rmse:81875.7
[128]	test-rmse:81859.6
[129]	test-rmse:81849.7
[130]	test-rmse:81851.2
[131]	test-rmse:81839.4
[132]	test-rmse:81850.8
[133]	test-rmse:81846
[134]	test-rmse:81836.2
[135]	test-rmse:81827.2
[136]	test-rmse:81832.3
[137]	test-rmse:81859.6
[138]	test-rmse:81856.6
[139]	test-rmse:81850
[140]	test-rmse:81847.6
[141]	test-rmse:81842.8
[142]	test-rmse:81794.5
[143]	test-rmse:81803.8
[144]	test-rmse:81829.3
[145]	test-rmse:81815.9
[146]	test-rmse:81813.6
[147]	test-rmse:81741
[148]	test-rmse:81728.8
[149]	test-rmse:81714.4
[150]	test-rmse:81708.6
[151]	test-rmse:81592.3
[152]	test-rmse:81621.7
[153]	test-rmse:81624.8
[154]	test-rmse:81629.3
[155]	test-rmse:81615.7
[156]	test-rmse:81617.7
[157]	test-rmse:81613.9
[158]	test-rmse:81612.9
[159]	test-rmse:81594.9
[160]	test-rmse:81595.1
[161]	test-rmse:81581.7
[162]	test-rmse:81595.3
[163]	test-rmse:81603.8
[164]	test-rmse:81601.2
[165]	test-rmse:81600.5
[166]	test-rmse:81552.3
[167]	test-rmse:81557.6
[168]	test-rmse:81565.5
[169]	test-rmse:81566.6
[170]	test-rmse:81581.9
[171]	test-rmse:81570.5
[172]	test-rmse:81571.8
[173]	test-rmse:81569.4
[174]	test-rmse:81494.3
[175]	test-rmse:81476.3
[176]	test-rmse:81454
[177]	test-rmse:81422.6
[178]	test-rmse:81426.1
[179]	test-rmse:81410.8
[180]	test-rmse:81425.1
[181]	test-rmse:81418.2
[182]	test-rmse:81419.4
[183]	test-rmse:81409.6
[184]	test-rmse:81392.1
[185]	test-rmse:81389.3
[186]	test-rmse:81391.1
[187]	test-rmse:81414.5
[188]	test-rmse:81369.9
[189]	test-rmse:81368.3
[190]	test-rmse:81358.4
[191]	test-rmse:81347.7
[192]	test-rmse:81355.4
[193]	test-rmse:81349.2
[194]	test-rmse:81343
[195]	test-rmse:81346.3
[196]	test-rmse:81345.5
[197]	test-rmse:81374.6
[198]	test-rmse:81358.5
[199]	test-rmse:81359.4

グリッドサーチのためのパラメータを生成する

もう少し、モデルの精度がないかを試すために、グリッドサーチをやっています。最初に、グリッドサーチをするためのパラメータを生成しています。

gridsearch_params = [
    (max_depth, eta)
    for max_depth in [6, 7, 8]
    for eta in [0.1, 0.05, 0.01]
]
gridsearch_params

[(6, 0.1),
(6, 0.05),
(6, 0.01),
(7, 0.1),
(7, 0.05),
(7, 0.01),
(8, 0.1),
(8, 0.05),
(8, 0.01)]

もっとも精度の良かったパラメータを計算する

ここでは、どの組み合わせのパラメータがもっとも精度が高くなるかを計算しています。結果は、Best params (8, 0.01)となりました。

min_rmse = float('Inf')

best_param = []

for max_depth, eta in gridsearch_params:
    print('max_depth = {}, eta = {}'.format(max_depth, eta))
    
    params['max_depth'] = max_depth
    params['eta'] = eta
    
    cv_results = xgb.cv(
        params,
        dtrain,
        num_boost_round = 1000,
        seed = 0,
        nfold = 5,
        metrics = {'rmse'},
        early_stopping_rounds = 5
    )
    
    mean_rmse = cv_results['test-rmse-mean'].min()
    boost_rounds = cv_results['test-rmse-mean'].argmin()
    print('RMSE {} for {} rounds'.format(mean_rmse, boost_rounds))
    if mean_rmse < min_rmse:
        min_rmse = mean_rmse
        best_param = (max_depth, eta)

print('Best params {}, RMSE {}'.format(best_param, min_rmse))
max_depth = 6, eta = 0.1
RMSE 81689.0296874 for 123 rounds
max_depth = 6, eta = 0.05
RMSE 81545.2953126 for 267 rounds
max_depth = 6, eta = 0.01
RMSE 82118.7765624 for 999 rounds
max_depth = 7, eta = 0.1
RMSE 81372.990625 for 161 rounds
max_depth = 7, eta = 0.05
RMSE 81372.7171876 for 202 rounds
max_depth = 7, eta = 0.01
RMSE 81308.89999979999 for 999 rounds
max_depth = 8, eta = 0.1
RMSE 81277.4515624 for 96 rounds
max_depth = 8, eta = 0.05
RMSE 81155.2687498 for 201 rounds
max_depth = 8, eta = 0.01
RMSE 81080.3156252 for 849 rounds
Best params (8, 0.01), RMSE 81080.3156252

再度パラメータを変更して学習モデルを生成する

先ほどのグリッドサーチによって、計算されたパラメータを使って再度学習モデルを生成します。

params['max_depth'] = 8
params['eta'] = 0.01

model = xgb.train(params = params,
                 dtrain = dtrain,
                 num_boost_round = 1000,
                 early_stopping_rounds = 5,
                 evals = [(dtest, 'test')])
[0]	test-rmse:515961
Will train until test-rmse hasn't improved in 5 rounds.
[1]	test-rmse:511040
[2]	test-rmse:506173
[3]	test-rmse:501356
[4]	test-rmse:496588
[5]	test-rmse:491874
[6]	test-rmse:487210
[7]	test-rmse:482584
[8]	test-rmse:478012
[9]	test-rmse:473483
[10]	test-rmse:468999
[11]	test-rmse:464566
[12]	test-rmse:460175
[13]	test-rmse:455822
[14]	test-rmse:451526
[15]	test-rmse:447271
[16]	test-rmse:443066
[17]	test-rmse:438893
[18]	test-rmse:434766
[19]	test-rmse:430688
[20]	test-rmse:426657
[21]	test-rmse:422664
[22]	test-rmse:418718
[23]	test-rmse:414801
[24]	test-rmse:410941
[25]	test-rmse:407098
[26]	test-rmse:403307
[27]	test-rmse:399547
[28]	test-rmse:395829
[29]	test-rmse:392152
[30]	test-rmse:388512
[31]	test-rmse:384923
[32]	test-rmse:381353
[33]	test-rmse:377832
[34]	test-rmse:374343
[35]	test-rmse:370902
[36]	test-rmse:367481
[37]	test-rmse:364100
[38]	test-rmse:360764
[39]	test-rmse:357461
[40]	test-rmse:354187
[41]	test-rmse:350960
[42]	test-rmse:347750
[43]	test-rmse:344572
[44]	test-rmse:341426
[45]	test-rmse:338314
[46]	test-rmse:335245
[47]	test-rmse:332213
[48]	test-rmse:329190
[49]	test-rmse:326207
[50]	test-rmse:323269
[51]	test-rmse:320358
[52]	test-rmse:317476
[53]	test-rmse:314619
[54]	test-rmse:311801
[55]	test-rmse:309001
[56]	test-rmse:306233
[57]	test-rmse:303498
[58]	test-rmse:300790
[59]	test-rmse:298113
[60]	test-rmse:295470
[61]	test-rmse:292855
[62]	test-rmse:290269
[63]	test-rmse:287704
[64]	test-rmse:285165
[65]	test-rmse:282666
[66]	test-rmse:280184
[67]	test-rmse:277741
[68]	test-rmse:275313
[69]	test-rmse:272907
[70]	test-rmse:270532
[71]	test-rmse:268190
[72]	test-rmse:265866
[73]	test-rmse:263562
[74]	test-rmse:261295
[75]	test-rmse:259046
[76]	test-rmse:256816
[77]	test-rmse:254625
[78]	test-rmse:252448
[79]	test-rmse:250296
[80]	test-rmse:248164
[81]	test-rmse:246063
[82]	test-rmse:243969
[83]	test-rmse:241918
[84]	test-rmse:239880
[85]	test-rmse:237853
[86]	test-rmse:235862
[87]	test-rmse:233881
[88]	test-rmse:231939
[89]	test-rmse:230004
[90]	test-rmse:228095
[91]	test-rmse:226206
[92]	test-rmse:224346
[93]	test-rmse:222499
[94]	test-rmse:220670
[95]	test-rmse:218861
[96]	test-rmse:217075
[97]	test-rmse:215311
[98]	test-rmse:213567
[99]	test-rmse:211827
[100]	test-rmse:210120
[101]	test-rmse:208421
[102]	test-rmse:206737
[103]	test-rmse:205088
[104]	test-rmse:203457
[105]	test-rmse:201827
[106]	test-rmse:200212
[107]	test-rmse:198636
[108]	test-rmse:197085
[109]	test-rmse:195530
[110]	test-rmse:194010
[111]	test-rmse:192494
[112]	test-rmse:191011
[113]	test-rmse:189524
[114]	test-rmse:188077
[115]	test-rmse:186631
[116]	test-rmse:185212
[117]	test-rmse:183809
[118]	test-rmse:182411
[119]	test-rmse:181043
[120]	test-rmse:179675
[121]	test-rmse:178325
[122]	test-rmse:177006
[123]	test-rmse:175698
[124]	test-rmse:174401
[125]	test-rmse:173124
[126]	test-rmse:171857
[127]	test-rmse:170612
[128]	test-rmse:169374
[129]	test-rmse:168161
[130]	test-rmse:166952
[131]	test-rmse:165766
[132]	test-rmse:164596
[133]	test-rmse:163434
[134]	test-rmse:162284
[135]	test-rmse:161156
[136]	test-rmse:160034
[137]	test-rmse:158914
[138]	test-rmse:157821
[139]	test-rmse:156745
[140]	test-rmse:155676
[141]	test-rmse:154628
[142]	test-rmse:153593
[143]	test-rmse:152568
[144]	test-rmse:151559
[145]	test-rmse:150558
[146]	test-rmse:149572
[147]	test-rmse:148603
[148]	test-rmse:147644
[149]	test-rmse:146701
[150]	test-rmse:145766
[151]	test-rmse:144831
[152]	test-rmse:143911
[153]	test-rmse:143000
[154]	test-rmse:142102
[155]	test-rmse:141215
[156]	test-rmse:140345
[157]	test-rmse:139482
[158]	test-rmse:138627
[159]	test-rmse:137799
[160]	test-rmse:136970
[161]	test-rmse:136155
[162]	test-rmse:135347
[163]	test-rmse:134549
[164]	test-rmse:133771
[165]	test-rmse:132997
[166]	test-rmse:132246
[167]	test-rmse:131489
[168]	test-rmse:130746
[169]	test-rmse:130024
[170]	test-rmse:129296
[171]	test-rmse:128587
[172]	test-rmse:127886
[173]	test-rmse:127192
[174]	test-rmse:126505
[175]	test-rmse:125824
[176]	test-rmse:125160
[177]	test-rmse:124501
[178]	test-rmse:123857
[179]	test-rmse:123216
[180]	test-rmse:122583
[181]	test-rmse:121954
[182]	test-rmse:121339
[183]	test-rmse:120737
[184]	test-rmse:120148
[185]	test-rmse:119561
[186]	test-rmse:118982
[187]	test-rmse:118408
[188]	test-rmse:117840
[189]	test-rmse:117286
[190]	test-rmse:116739
[191]	test-rmse:116198
[192]	test-rmse:115670
[193]	test-rmse:115143
[194]	test-rmse:114633
[195]	test-rmse:114128
[196]	test-rmse:113628
[197]	test-rmse:113133
[198]	test-rmse:112648
[199]	test-rmse:112167
[200]	test-rmse:111694
[201]	test-rmse:111232
[202]	test-rmse:110769
[203]	test-rmse:110309
[204]	test-rmse:109870
[205]	test-rmse:109429
[206]	test-rmse:109001
[207]	test-rmse:108584
[208]	test-rmse:108159
[209]	test-rmse:107745
[210]	test-rmse:107338
[211]	test-rmse:106934
[212]	test-rmse:106543
[213]	test-rmse:106161
[214]	test-rmse:105774
[215]	test-rmse:105404
[216]	test-rmse:105032
[217]	test-rmse:104666
[218]	test-rmse:104306
[219]	test-rmse:103951
[220]	test-rmse:103605
[221]	test-rmse:103256
[222]	test-rmse:102918
[223]	test-rmse:102581
[224]	test-rmse:102258
[225]	test-rmse:101929
[226]	test-rmse:101614
[227]	test-rmse:101305
[228]	test-rmse:101001
[229]	test-rmse:100687
[230]	test-rmse:100393
[231]	test-rmse:100106
[232]	test-rmse:99803.7
[233]	test-rmse:99521.5
[234]	test-rmse:99228
[235]	test-rmse:98952.9
[236]	test-rmse:98687
[237]	test-rmse:98407.6
[238]	test-rmse:98145.9
[239]	test-rmse:97895.6
[240]	test-rmse:97630.3
[241]	test-rmse:97373.5
[242]	test-rmse:97131.5
[243]	test-rmse:96879.7
[244]	test-rmse:96638.5
[245]	test-rmse:96409.2
[246]	test-rmse:96174.5
[247]	test-rmse:95950.8
[248]	test-rmse:95724
[249]	test-rmse:95504.3
[250]	test-rmse:95286
[251]	test-rmse:95063.2
[252]	test-rmse:94852.8
[253]	test-rmse:94646.3
[254]	test-rmse:94438.7
[255]	test-rmse:94227.9
[256]	test-rmse:94032
[257]	test-rmse:93828.1
[258]	test-rmse:93637
[259]	test-rmse:93447.4
[260]	test-rmse:93264
[261]	test-rmse:93072.1
[262]	test-rmse:92886.1
[263]	test-rmse:92699.4
[264]	test-rmse:92519.7
[265]	test-rmse:92341.1
[266]	test-rmse:92158.9
[267]	test-rmse:91984.2
[268]	test-rmse:91818.9
[269]	test-rmse:91667.4
[270]	test-rmse:91508.6
[271]	test-rmse:91340.9
[272]	test-rmse:91179.8
[273]	test-rmse:91036.4
[274]	test-rmse:90880.2
[275]	test-rmse:90730.6
[276]	test-rmse:90586.1
[277]	test-rmse:90440.3
[278]	test-rmse:90301.1
[279]	test-rmse:90168.4
[280]	test-rmse:90031.9
[281]	test-rmse:89908.5
[282]	test-rmse:89775.1
[283]	test-rmse:89654.3
[284]	test-rmse:89526.7
[285]	test-rmse:89395.2
[286]	test-rmse:89275.8
[287]	test-rmse:89160.1
[288]	test-rmse:89035.6
[289]	test-rmse:88924.4
[290]	test-rmse:88812.4
[291]	test-rmse:88696.1
[292]	test-rmse:88588.3
[293]	test-rmse:88483.3
[294]	test-rmse:88367.4
[295]	test-rmse:88265.7
[296]	test-rmse:88159
[297]	test-rmse:88060.3
[298]	test-rmse:87956.8
[299]	test-rmse:87859.5
[300]	test-rmse:87763.9
[301]	test-rmse:87660.7
[302]	test-rmse:87573.6
[303]	test-rmse:87475.7
[304]	test-rmse:87378.3
[305]	test-rmse:87287.8
[306]	test-rmse:87194.3
[307]	test-rmse:87113.9
[308]	test-rmse:87024.7
[309]	test-rmse:86936.5
[310]	test-rmse:86847.3
[311]	test-rmse:86761.9
[312]	test-rmse:86679.8
[313]	test-rmse:86612.5
[314]	test-rmse:86528
[315]	test-rmse:86449.6
[316]	test-rmse:86374.7
[317]	test-rmse:86297.3
[318]	test-rmse:86216.9
[319]	test-rmse:86147
[320]	test-rmse:86085.8
[321]	test-rmse:86018.1
[322]	test-rmse:85941.5
[323]	test-rmse:85878.8
[324]	test-rmse:85815.2
[325]	test-rmse:85755.3
[326]	test-rmse:85691.4
[327]	test-rmse:85631.7
[328]	test-rmse:85554.2
[329]	test-rmse:85478.2
[330]	test-rmse:85420.4
[331]	test-rmse:85355
[332]	test-rmse:85282.9
[333]	test-rmse:85212.7
[334]	test-rmse:85156.1
[335]	test-rmse:85089.2
[336]	test-rmse:85042.1
[337]	test-rmse:84977.9
[338]	test-rmse:84916.3
[339]	test-rmse:84865
[340]	test-rmse:84819.4
[341]	test-rmse:84764.9
[342]	test-rmse:84698.7
[343]	test-rmse:84655.8
[344]	test-rmse:84595.1
[345]	test-rmse:84546.1
[346]	test-rmse:84496.5
[347]	test-rmse:84446.9
[348]	test-rmse:84401.1
[349]	test-rmse:84349.7
[350]	test-rmse:84312.6
[351]	test-rmse:84263.9
[352]	test-rmse:84217.4
[353]	test-rmse:84176.9
[354]	test-rmse:84126.8
[355]	test-rmse:84081.5
[356]	test-rmse:84037.6
[357]	test-rmse:84001.1
[358]	test-rmse:83961.8
[359]	test-rmse:83922.8
[360]	test-rmse:83884.8
[361]	test-rmse:83842.4
[362]	test-rmse:83805.7
[363]	test-rmse:83771.6
[364]	test-rmse:83738.9
[365]	test-rmse:83701.5
[366]	test-rmse:83668
[367]	test-rmse:83633.7
[368]	test-rmse:83591.7
[369]	test-rmse:83552.1
[370]	test-rmse:83514.7
[371]	test-rmse:83479.3
[372]	test-rmse:83440.2
[373]	test-rmse:83412.3
[374]	test-rmse:83380.3
[375]	test-rmse:83346.3
[376]	test-rmse:83309.6
[377]	test-rmse:83272.6
[378]	test-rmse:83243.7
[379]	test-rmse:83211.3
[380]	test-rmse:83184.4
[381]	test-rmse:83151.7
[382]	test-rmse:83119.6
[383]	test-rmse:83089.4
[384]	test-rmse:83056.3
[385]	test-rmse:83023.5
[386]	test-rmse:82994.4
[387]	test-rmse:82964.4
[388]	test-rmse:82936.3
[389]	test-rmse:82907.3
[390]	test-rmse:82873.7
[391]	test-rmse:82845.4
[392]	test-rmse:82816.8
[393]	test-rmse:82790.7
[394]	test-rmse:82766
[395]	test-rmse:82740.9
[396]	test-rmse:82719.9
[397]	test-rmse:82695.1
[398]	test-rmse:82672.3
[399]	test-rmse:82647.9
[400]	test-rmse:82629.9
[401]	test-rmse:82602
[402]	test-rmse:82581.2
[403]	test-rmse:82562.3
[404]	test-rmse:82541
[405]	test-rmse:82524.3
[406]	test-rmse:82504
[407]	test-rmse:82490.7
[408]	test-rmse:82472
[409]	test-rmse:82448
[410]	test-rmse:82424.8
[411]	test-rmse:82408.9
[412]	test-rmse:82395.4
[413]	test-rmse:82373.6
[414]	test-rmse:82358.9
[415]	test-rmse:82336.1
[416]	test-rmse:82322.6
[417]	test-rmse:82301.7
[418]	test-rmse:82282.6
[419]	test-rmse:82268.4
[420]	test-rmse:82253.9
[421]	test-rmse:82229.1
[422]	test-rmse:82207
[423]	test-rmse:82188.9
[424]	test-rmse:82176.5
[425]	test-rmse:82170.7
[426]	test-rmse:82157
[427]	test-rmse:82151.2
[428]	test-rmse:82139.1
[429]	test-rmse:82126.4
[430]	test-rmse:82108.7
[431]	test-rmse:82098.1
[432]	test-rmse:82087.3
[433]	test-rmse:82075.5
[434]	test-rmse:82063.7
[435]	test-rmse:82054
[436]	test-rmse:82039.1
[437]	test-rmse:82027.3
[438]	test-rmse:82014.6
[439]	test-rmse:82005.3
[440]	test-rmse:81993.7
[441]	test-rmse:81984.7
[442]	test-rmse:81973.3
[443]	test-rmse:81955.5
[444]	test-rmse:81943.4
[445]	test-rmse:81932.8
[446]	test-rmse:81918.6
[447]	test-rmse:81909.1
[448]	test-rmse:81899.2
[449]	test-rmse:81886.5
[450]	test-rmse:81873.7
[451]	test-rmse:81863
[452]	test-rmse:81854.2
[453]	test-rmse:81842.5
[454]	test-rmse:81831.3
[455]	test-rmse:81821.2
[456]	test-rmse:81811.4
[457]	test-rmse:81804.7
[458]	test-rmse:81789.8
[459]	test-rmse:81784.3
[460]	test-rmse:81779.4
[461]	test-rmse:81771.3
[462]	test-rmse:81756.4
[463]	test-rmse:81751.9
[464]	test-rmse:81739.6
[465]	test-rmse:81730.1
[466]	test-rmse:81719.8
[467]	test-rmse:81710.2
[468]	test-rmse:81701.1
[469]	test-rmse:81689.9
[470]	test-rmse:81685
[471]	test-rmse:81675.6
[472]	test-rmse:81670.9
[473]	test-rmse:81659.8
[474]	test-rmse:81651.6
[475]	test-rmse:81641.8
[476]	test-rmse:81632.5
[477]	test-rmse:81629.2
[478]	test-rmse:81619.2
[479]	test-rmse:81611
[480]	test-rmse:81608
[481]	test-rmse:81599.1
[482]	test-rmse:81588.6
[483]	test-rmse:81578.9
[484]	test-rmse:81573.4
[485]	test-rmse:81570.2
[486]	test-rmse:81558.9
[487]	test-rmse:81554.5
[488]	test-rmse:81544.7
[489]	test-rmse:81533.8
[490]	test-rmse:81526.6
[491]	test-rmse:81518.9
[492]	test-rmse:81512.2
[493]	test-rmse:81498.3
[494]	test-rmse:81495.6
[495]	test-rmse:81488.1
[496]	test-rmse:81478.6
[497]	test-rmse:81469.1
[498]	test-rmse:81463
[499]	test-rmse:81462.4
[500]	test-rmse:81454.6
[501]	test-rmse:81453.7
[502]	test-rmse:81450.8
[503]	test-rmse:81443
[504]	test-rmse:81434
[505]	test-rmse:81430.1
[506]	test-rmse:81427.5
[507]	test-rmse:81421.4
[508]	test-rmse:81421.2
[509]	test-rmse:81415.5
[510]	test-rmse:81413
[511]	test-rmse:81409.5
[512]	test-rmse:81396.6
[513]	test-rmse:81395.1
[514]	test-rmse:81395.3
[515]	test-rmse:81392.2
[516]	test-rmse:81391.4
[517]	test-rmse:81388.2
[518]	test-rmse:81383.8
[519]	test-rmse:81379.3
[520]	test-rmse:81379.8
[521]	test-rmse:81376.9
[522]	test-rmse:81376.9
[523]	test-rmse:81375.1
[524]	test-rmse:81370.2
[525]	test-rmse:81365.3
[526]	test-rmse:81364.3
[527]	test-rmse:81363.6
[528]	test-rmse:81362.5
[529]	test-rmse:81358.9
[530]	test-rmse:81354.8
[531]	test-rmse:81353.9
[532]	test-rmse:81355.2
[533]	test-rmse:81355.2
[534]	test-rmse:81356.4
[535]	test-rmse:81356.3
[536]	test-rmse:81352.6
[537]	test-rmse:81347.5
[538]	test-rmse:81347.6
[539]	test-rmse:81349.5
[540]	test-rmse:81350.3
[541]	test-rmse:81351.2
[542]	test-rmse:81351.6
Stopping. Best iteration:
[537]	test-rmse:81347.5

重要度の高い特徴量を表示する。

先ほど学習したモデルのうち、その特徴量が重要度が高かったのかを、表示してみます。これを見ると、long、lat、sqft_livingなどの特徴量の重要度が高いことがわかります。

fig, ax = plt.subplots(figsize = (12,12))
xgb.plot_importance(model, max_num_features = 12, height = 0.8, ax = ax)
plt.savefig('house_sails_feature_importance03.png', bbox_inches='tight')

house_sails_feature_importance03.png

学習したモデルの精度を計測してみる。

学習したモデルの精度を計測してみます。r2_scoreでは、約0.847という精度となっています。

from sklearn.metrics import r2_score
preds = model.predict(dtest)
r2 = r2_score(y_test, preds)
print(r2)

0.8473346069012444

学習モデルを保存する

最後に、joblibを使って学習モデルをpklファイルとして保存しておきます。これで、機械学習で学習した学習モデルができました。

from sklearn.externals import joblib
joblib.dump(model, 'house_sales_model.pkl')

['house_sales_model.pkl']

FlaskでAPIサーバーを作る

ここでは、先ほど機械学習で生成した学習モデルをAPIサーバーにするということをやっています。APIサーバー開発にはPythonのマイクロサービス系のフレームワークである、Flaskを使用しています。開発の流れとしては、condaで仮想環境を構築し、簡易的なAPIサーバーをテストして、そこにXGBoostで作った学習モデルを載せるという流れになります。

condaで仮想環境を構築する

仮想環境は、Anacondaのcondaを使用します。ターミナルでアプリ開発用のフォルダ(ここでは、titanic_api)を作り、そのフォルダ内に移動します。そしたらconda createで仮想環境を生成し、conda activateで仮想環境をアクティブな状態にします。

mkdir housesails_api
cd housesails_api
conda create -n housesailsenv
conda activate housesailsenv

FlaskでAPIを開発する

FlaskでAPIサーバーを開発するために、最初に簡易的なAPIサーバーを作ってテストしてみます。先ほど作ったフォルダ内に以下のようなフォルダとファイルを作ります。ファイルにはそれぞれ以下のようなコードを書いて、APIサーバーを起動して、curlから通信できれば簡易的なAPIサーバーテストの成功です。

ターミナルに必要なフォルダとファイルを生成する。

以下のような階層になるようにフォルダとファイルを作ります。空ファイルを作るならtouchコマンドなどを使用すると便利です。

housesails_api
├── api
│   ├── __init__.py
│   └── views
│       └── user.py
├── housesails_app.py
└── house_sales_model.pkl

作成したファイルにコードを書く

先ほど作成したファイルに以下のようにコードを書きます。簡易的なAPIサーバーをテストするために必要なファイルは、api/views/user.py、api/init.py、titanic_app.pyの三つです。ターミナルで書く場合はvim、GUIで書く場合はAtomなどを使用すると便利です。

api/views/user.py
from flask import Blueprint, request, make_response, jsonify

# ルーティング設定
user_router = Blueprint('user_router', __name__)

# パスとHTTPメソッドを指定
@user_router.route('/users', methods=['GET'])
def get_user_list():

  return make_response(jsonify({
    'users': [
       {
         'id': 1,
         'name': 'John'
       }
     ]
  }))
api/__init__.py
from flask import Flask, make_response, jsonify
from .views.user import user_router

def create_app():

  app = Flask(__name__)
  app.register_blueprint(user_router, url_prefix='/api')

  return app

app = create_app()
housesails_app.py
import json

from flask import Flask
from flask import request
from flask import abort

import pandas as pd
from sklearn.externals import joblib
import xgboost as xgb

model = joblib.load("house_sales_model.pkl")

app = Flask(__name__)

# Get headers for payload
headers = ['sqft_living','sqft_above','sqft_basement','lat','long','sqft_living15','grade_3','grade_4','grade_5','grade_6','grade_7','grade_8','grade_9','grade_10','grade_11','grade_12']

@app.route('/house_sails', methods=['POST'])
def housesails():
    if not request.json:
        abort(400)
    payload = request.json['data']
    values = [float(i) for i in payload.split(',')]
    data1 = pd.DataFrame([values], columns=headers, dtype=float)
    predict = model.predict(xgb.DMatrix(data1))
    return json.dumps(str(predict[0]))


if __name__ == "__main__":
    app.run(debug=True, port=5000)

curlでAPI通信テストをする

コードを書き換えたら、改めて、python housesails_app.py で、APIサーバーを起動します。APIサーバーが起動したら、以下のようにcurlコマンドで通信テストをしています。送ったJSONデータに対して、小数点1以下の値が返ってきたら成功です。これで、機械学習で生成した学習モデルをAPIサーバーにすることができました。

curl http://localhost:5000/house_sails -s -X POST -H "Content-Type: application/json" -d '{"data": "-1.026685, -0.725963, -0.652987, -0.323607, -0.307144, -0.946801, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0"}'

認証機能をつける

最後に、APIサーバーに認証機能をつけます。今回はbasic認証を実装しています。ライブラリにHTTPBasicAuthがインストールされている必要があります。先ほどのhousesails_app.pyのコードに# BasicAuthと書かれているところを追記することで実装できます。

housesails_app.py
import json

from flask import Flask
from flask import request
from flask import abort
from flask_httpauth import HTTPBasicAuth

import pandas as pd
from sklearn.externals import joblib
import xgboost as xgb

model = joblib.load("house_sales_model.pkl")

app = Flask(__name__)

# BasicAuth
auth = HTTPBasicAuth()

users = {
    "user01": "password01",
    "user02": "password02"
}

@auth.get_password
def get_pw(username):
    if username in users:
        return users.get(username)
    return None

# Get headers for payload
headers = ['sqft_living','sqft_above','sqft_basement','lat','long','sqft_living15','grade_3','grade_4','grade_5','grade_6','grade_7','grade_8','grade_9','grade_10','grade_11','grade_12']

@app.route('/house_sails', methods=['POST'])

# BasicAuth
@auth.login_required

def housesails():
    if not request.json:
        abort(400)
    payload = request.json['data']
    values = [float(i) for i in payload.split(',')]
    data1 = pd.DataFrame([values], columns=headers, dtype=float)
    predict = model.predict(xgb.DMatrix(data1))
    return json.dumps(str(predict[0]))


if __name__ == "__main__":
    app.run(debug=True, port=5000)

再度curlでAPI通信テストをする

再度、python housesails_app.py で、APIサーバーを起動します。APIサーバーが起動したら、以下のようにcurlコマンドで通信テストをしています。--user user01:password01というのを追記することで認証されます。これで通信がうまくいえば成功です。

curl http://localhost:5000/house_sails --user user01:password01 -s -X POST -H "Content-Type: application/json" -d '{"data": "-1.026685, -0.725963, -0.652987, -0.323607, -0.307144, -0.946801, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0"}'
1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?