##はじめに
KaggleのHouse Sales in King County, USAのデータセットを使って、XGboost機械学習で学習モデルを生成して、その学習モデルをFlaskでAPIサーバーにするというのをやりました。この機械学習のAPIサーバーは、主に4つの手順で行っています。最初にHouse Sailsのデータを把握するために、EDA(Explanatory Data Analysis)探索的データ解析を行って、データの状況を把握します。次に、機械学習で学習させるためのデータになるように前処理を行っています。その次に、機械学習で学習モデルを生成します。今回は、XGboostを使用しています。最後に、FlasでAPIサーバーを実装していきます。
このプログラムを実行するのに必要な環境
Anaconda、XGBoost、joblib、Flask、flask-corsなどのライブラリがインストールされている。
この機械学習によるAPIサーバーの実装は以下の4つの処理によって行っています。
- House Sailsのデータを把握する(EDA)
- データセットの前処理を行う
- 機械学習で学習モデルを作る
- flaskでAPIサーバーを作る
ライブラリとデータセットの読み込み
まずは、必要なラブラリとKaggleからダウンロードしてきたHouse Sailsのデータセットを読み込んできます。set_optionを指定すると、表示するカラム数を指定することができます。
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
pd.set_option('display.max_columns', 4)
df = pd.read_csv('house_sales/kc_house_data.csv')
df.head()
id | date | ... | sqft_living15 | sqft_lot15 | |
---|---|---|---|---|---|
0 | 7129300520 | 20141013T000000 | ... | 1340 | 5650 |
1 | 6414100192 | 20141209T000000 | ... | 1690 | 7639 |
2 | 5631500400 | 20150225T000000 | ... | 2720 | 8062 |
3 | 2487200875 | 20141209T000000 | ... | 1360 | 5000 |
4 | 1954400510 | 20150218T000000 | ... | 1800 | 7503 |
5 rows × 21 columns |
sqft_livingのヒストグラムを表示してみる。
まずは、特徴量のうち、sqft_livingのヒストグラムを表示しています。これを見えると、大きな値があるので、標準正規分布の形からはずれていることがわかります。
plt.figure(figsize = (12,8))
plt.hist(df["sqft_living"])
plt.savefig('House_Sales_Explanatory Data Analysis_hist01.png', bbox_inches='tight')
priceのヒストグラムを表示してみる。
今回予測する値となる、priceについてもヒストグラムを表示してみます。これも、先ほどと同様に、大きな値、外れ値があるため、左によった形状をしています。
plt.figure(figsize = (12,8))
plt.hist(df["price"])
plt.savefig('House_Sales_Explanatory Data Analysis_hist02.png', bbox_inches='tight')
四分位範囲の処理を行って外れ値を削除する
この関数は、四分位範囲の処理をしています。これによって、外れ値となるデータを削除します。
def outlier_iqr(df, columns = None):
if columns == None:
columns = df.columns
for col in columns:
q1 = df[col].describe()['25%']
q3 = df[col].describe()['75%']
iqr = q3 - q1
outlier_min = q1 - iqr * 1.5
outlier_max = q3 + iqr * 1.5
df = df[(df[col] >= outlier_min) & (df[col] <= outlier_max)]
return df
df_1 = outlier_iqr(df, columns = ['price'])
df_1.shape
(20454, 21)
再度priceをヒストグラムで表示してみる。
四分位範囲の処理を行ったpriceのヒストグラムをもう一度表示してみます。今度は、外れ値を削除したので、正規分布に近い形状のヒストグラムになっています。shapeでデータの形状を確認してみると、先ほどよりは多少データの数は減っていますが、大きくは減っていないことも確認できます。
plt.figure(figsize = (12,8))
plt.hist(df_1["price"])
plt.savefig('House_Sales_Explanatory Data Analysis_hist03.png', bbox_inches='tight')
df.shape
(21613, 21)
数値データのヒストグラムを全部表示してみる。
一応、確認のために、他のデータのヒストグラムの形状も確認してみます。概ね、バランスのとれたヒストグラムの形状になっていることが確認できます。
fig, axes = plt.subplots(2,3, figsize = (18, 12))
axes.ravel()[0].hist(df_1["sqft_living"])
axes.ravel()[1].hist(df_1["sqft_above"])
axes.ravel()[2].hist(df_1["sqft_basement"])
axes.ravel()[3].hist(df_1["lat"])
axes.ravel()[4].hist(df_1["long"])
axes.ravel()[5].hist(df_1["sqft_living15"])
axes.ravel()[0].set_title("sqft_living")
axes.ravel()[1].set_title("sqft_above")
axes.ravel()[2].set_title("sqft_basement")
axes.ravel()[3].set_title("lat")
axes.ravel()[4].set_title("long")
axes.ravel()[5].set_title("sqft_living15")
plt.savefig('House_Sales_Explanatory Data Analysis_hist04.png', bbox_inches='tight')
不必要な特徴量を削除する。
ここでは、学習の際に不要となる特徴量を削除します。ここでは、id、date、sqft_lot、sqft_lot15、zipcodeを削除しています。
df_1 = df_1.drop(columns = ['id', 'date', 'sqft_lot','sqft_lot15','zipcode'])
建築年を築年数に変える。
yr_builtという特徴量は、建築年なので、建物が建てられた年がデータにあります。このままだと、学習データとして扱うのは難しいので、築年数とした新たな特徴量を加えています。また、yr_renovatedも改装された年になっているので、これも改装されてからの年数に変えています。
df_1["age"] = 2020 - df_1["yr_built"]
df_1.loc[(df_1['yr_renovated'] == 0), 'yr_renovated'] = 2020
数値の特徴量を標準化する。
数値のデータは、ここではStandardScalerを使って標準化しています。
from sklearn.preprocessing import StandardScaler
num_feature = ['sqft_living', 'sqft_above', 'sqft_basement', 'lat', 'long', 'sqft_living15']
for col in num_feature:
scaler = StandardScaler()
df_1[col] = scaler.fit_transform(np.array(df_1[col].values).reshape(-1, 1))
再度、数値データのヒストグラムを全部表示してみる。
一応、再度、数値データのヒストグラムを確認しています。先ほど、表示した形状と変わりがないことが確認できます。
fig, axes = plt.subplots(2,3, figsize = (18, 12))
axes.ravel()[0].hist(df_1["sqft_living"])
axes.ravel()[1].hist(df_1["sqft_above"])
axes.ravel()[2].hist(df_1["sqft_basement"])
axes.ravel()[3].hist(df_1["lat"])
axes.ravel()[4].hist(df_1["long"])
axes.ravel()[5].hist(df_1["sqft_living15"])
axes.ravel()[0].set_title("sqft_living")
axes.ravel()[1].set_title("sqft_above")
axes.ravel()[2].set_title("sqft_basement")
axes.ravel()[3].set_title("lat")
axes.ravel()[4].set_title("long")
axes.ravel()[5].set_title("sqft_living15")
plt.savefig('House_Sales_Explanatory Data Analysis_hist05.png', bbox_inches='tight')
正解データをcsvで保存する
最後に、機械学習用のデータとしてCSVに一旦保存しておきます。
df_price = df_1["price"]
df_price.to_csv('House_Sales_Explanatory_Price.csv')
学習データをcsvで保存する
学習データは、不必要なデータを削除して、カテゴリカルなデータは、get_dummiesでダミー変数に変換したものをcsvとして保存しています。
df_1 = df_1.drop(columns = ['price', 'yr_built', 'yr_renovated'])
df_1 = pd.get_dummies(df_1, columns = ['bedrooms', 'bathrooms', 'floors', 'waterfront', 'view', 'condition', 'grade', 'age', 'renovated_age'], drop_first = True)
df_1.to_csv('House_Sales_Explanatory_Preprocessing.csv')
前処理を行ったデータを使って機械学習を行う。
ここでは、先ほどまでに行った前処理をしたデータを使って機械学習をしていきます。なので、別のプロジェクトとしておいたほうが良いです。
必要なライブラリをインポートする
あらためて今回の処理に必要なライブラリをインポートしてきます。
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
CSVファイルをpandasとして読み込む
先ほど、保存した前処理済みのCSVファイルをpandasとして読み込みます。Unnamedという不要な特徴量があったので、それは削除しておきます。set_optionで表示される特徴量は4つにしています。
pd.set_option('display.max_columns', 4)
df = pd.read_csv('House_Sales_Explanatory_Preprocessing.csv')
df = df.drop(columns = ['Unnamed: 0'])
df.head()
sqft_living | sqft_above | ... | renovated_age_80 | renovated_age_86 |
---|---|---|---|---|
0 | -1.026685 | -0.725963 | ... | 0 |
1 | 0.769106 | 0.635702 | ... | 0 |
2 | -1.556379 | -1.289885 | ... | 0 |
3 | -0.018975 | -0.904768 | ... | 0 |
4 | -0.380717 | -0.038253 | ... | 0 |
CSVファイルをpandasとして読み込む(price)
正解データとなるpriceも同じにように読み込んできます。
price | |
---|---|
0 | 221900.0 |
1 | 538000.0 |
2 | 180000.0 |
3 | 604000.0 |
4 | 510000.0 |
df_price = pd.read_csv('House_Sales_Explanatory_Price.csv', header=None, names=['price'])
df_price.head()
使用する特徴量だけのデータフレームを作る。
今回の機械学習では、事前の学習の結果、前処理を行ったすべてのデータを使った学習モデルと、重要度の高くない特徴量を削除したデータを使った学習モデルでの学習精度に大きな差が出なかったこともあって、以下の特長量だけを使うことにしています。基本的には数値系の特長量はすべて使い、カテゴリカルな特長量に関しては、gradeのみを残して、それ以外のカテゴリカルなデータを削除しています。理由としては、最終的に機械学習モデルを使ったアプリケーションを開発する際に、複数のカテゴリカルなデータのフロントエンドの実装に手間がかかると予想したことが主な理由です。
df = df[["sqft_living","sqft_above","sqft_basement","lat","long","sqft_living15","grade_3","grade_4","grade_5","grade_6","grade_7"
,"grade_8","grade_9","grade_10","grade_11","grade_12"]]
sqft_living | sqft_above | ... | grade_11 | grade_12 | |
---|---|---|---|---|---|
0 | -1.026685 | -0.725963 | ... | 0 | 0 |
1 | 0.769106 | 0.635702 | ... | 0 | 0 |
2 | -1.556379 | -1.289885 | ... | 0 | 0 |
3 | -0.018975 | -0.904768 | ... | 0 | 0 |
4 | -0.380717 | -0.038253 | ... | 0 | 0 |
5 rows × 16 columns |
XGboostで学習モデルを生成する
ここでは、XGboostをインポートして、機械学習を行っています。パラーメータは、ほぼデフォルトのままです。
import xgboost as xgb
X_train, X_test, y_train, y_test = train_test_split(df, df_price, random_state = 0)
params = {
'silent' : 1,
'max_depth' : 6,
'min_chiled_weight' : 1,
'eta' : 0.1,
'tree_method' : 'exact',
'objective' : 'reg:linear',
'eval_metric' : 'rmse',
'predictor' : 'cpu_predictor'
}
dtrain = xgb.DMatrix(X_train, label = y_train)
dtest = xgb.DMatrix(X_test, label = y_test)
model = xgb.train(params = params,
dtrain = dtrain,
num_boost_round = 200,
early_stopping_rounds = 10,
evals = [(dtest, 'test')])
[0] test-rmse:471544
Will train until test-rmse hasn't improved in 10 rounds.
[1] test-rmse:427350
[2] test-rmse:387757
[3] test-rmse:352314
[4] test-rmse:320602
[5] test-rmse:292132
[6] test-rmse:266667
[7] test-rmse:244148
[8] test-rmse:223983
[9] test-rmse:206046
[10] test-rmse:190112
[11] test-rmse:176111
[12] test-rmse:163754
[13] test-rmse:152820
[14] test-rmse:143269
[15] test-rmse:134879
[16] test-rmse:127772
[17] test-rmse:121362
[18] test-rmse:115939
[19] test-rmse:111405
[20] test-rmse:107280
[21] test-rmse:103750
[22] test-rmse:100928
[23] test-rmse:98446.5
[24] test-rmse:96280.3
[25] test-rmse:94419.2
[26] test-rmse:92933.6
[27] test-rmse:91644.1
[28] test-rmse:90581.3
[29] test-rmse:89422.8
[30] test-rmse:88575.7
[31] test-rmse:88038.8
[32] test-rmse:87254.6
[33] test-rmse:86857.1
[34] test-rmse:86527.8
[35] test-rmse:86238.3
[36] test-rmse:85950
[37] test-rmse:85705
[38] test-rmse:85532.4
[39] test-rmse:85346.7
[40] test-rmse:85204.1
[41] test-rmse:85058.9
[42] test-rmse:84926.7
[43] test-rmse:84845.4
[44] test-rmse:84671.9
[45] test-rmse:84539.6
[46] test-rmse:84380.6
[47] test-rmse:84287.2
[48] test-rmse:84254.7
[49] test-rmse:84168.9
[50] test-rmse:84106.6
[51] test-rmse:83858.5
[52] test-rmse:83829.8
[53] test-rmse:83809.5
[54] test-rmse:83726
[55] test-rmse:83704.2
[56] test-rmse:83650.4
[57] test-rmse:83422.6
[58] test-rmse:83405.8
[59] test-rmse:83281
[60] test-rmse:83293.6
[61] test-rmse:83289.4
[62] test-rmse:83251.9
[63] test-rmse:83237.5
[64] test-rmse:83055.6
[65] test-rmse:83051.9
[66] test-rmse:82938.8
[67] test-rmse:82932.7
[68] test-rmse:82933.2
[69] test-rmse:82859
[70] test-rmse:82829.6
[71] test-rmse:82840.5
[72] test-rmse:82823
[73] test-rmse:82827.4
[74] test-rmse:82834.6
[75] test-rmse:82845.9
[76] test-rmse:82839.4
[77] test-rmse:82828.5
[78] test-rmse:82829.7
[79] test-rmse:82651.8
[80] test-rmse:82660
[81] test-rmse:82637.3
[82] test-rmse:82514.6
[83] test-rmse:82497.6
[84] test-rmse:82484.7
[85] test-rmse:82486.3
[86] test-rmse:82486.8
[87] test-rmse:82496
[88] test-rmse:82491.4
[89] test-rmse:82486.6
[90] test-rmse:82290.3
[91] test-rmse:82265.1
[92] test-rmse:82261.5
[93] test-rmse:82236.5
[94] test-rmse:82236.4
[95] test-rmse:82111.9
[96] test-rmse:82111.1
[97] test-rmse:82111.3
[98] test-rmse:82108
[99] test-rmse:82097.1
[100] test-rmse:82077.4
[101] test-rmse:82041.9
[102] test-rmse:82040
[103] test-rmse:82042.6
[104] test-rmse:82044.2
[105] test-rmse:82033.7
[106] test-rmse:82041.1
[107] test-rmse:82028.4
[108] test-rmse:82030.7
[109] test-rmse:82036.4
[110] test-rmse:82028.6
[111] test-rmse:82020.3
[112] test-rmse:82025.5
[113] test-rmse:82024.9
[114] test-rmse:82034
[115] test-rmse:82025.2
[116] test-rmse:81957.5
[117] test-rmse:81950.9
[118] test-rmse:81959.8
[119] test-rmse:81936.7
[120] test-rmse:81935.9
[121] test-rmse:81937
[122] test-rmse:81945.8
[123] test-rmse:81894.8
[124] test-rmse:81885.2
[125] test-rmse:81899.3
[126] test-rmse:81877
[127] test-rmse:81875.7
[128] test-rmse:81859.6
[129] test-rmse:81849.7
[130] test-rmse:81851.2
[131] test-rmse:81839.4
[132] test-rmse:81850.8
[133] test-rmse:81846
[134] test-rmse:81836.2
[135] test-rmse:81827.2
[136] test-rmse:81832.3
[137] test-rmse:81859.6
[138] test-rmse:81856.6
[139] test-rmse:81850
[140] test-rmse:81847.6
[141] test-rmse:81842.8
[142] test-rmse:81794.5
[143] test-rmse:81803.8
[144] test-rmse:81829.3
[145] test-rmse:81815.9
[146] test-rmse:81813.6
[147] test-rmse:81741
[148] test-rmse:81728.8
[149] test-rmse:81714.4
[150] test-rmse:81708.6
[151] test-rmse:81592.3
[152] test-rmse:81621.7
[153] test-rmse:81624.8
[154] test-rmse:81629.3
[155] test-rmse:81615.7
[156] test-rmse:81617.7
[157] test-rmse:81613.9
[158] test-rmse:81612.9
[159] test-rmse:81594.9
[160] test-rmse:81595.1
[161] test-rmse:81581.7
[162] test-rmse:81595.3
[163] test-rmse:81603.8
[164] test-rmse:81601.2
[165] test-rmse:81600.5
[166] test-rmse:81552.3
[167] test-rmse:81557.6
[168] test-rmse:81565.5
[169] test-rmse:81566.6
[170] test-rmse:81581.9
[171] test-rmse:81570.5
[172] test-rmse:81571.8
[173] test-rmse:81569.4
[174] test-rmse:81494.3
[175] test-rmse:81476.3
[176] test-rmse:81454
[177] test-rmse:81422.6
[178] test-rmse:81426.1
[179] test-rmse:81410.8
[180] test-rmse:81425.1
[181] test-rmse:81418.2
[182] test-rmse:81419.4
[183] test-rmse:81409.6
[184] test-rmse:81392.1
[185] test-rmse:81389.3
[186] test-rmse:81391.1
[187] test-rmse:81414.5
[188] test-rmse:81369.9
[189] test-rmse:81368.3
[190] test-rmse:81358.4
[191] test-rmse:81347.7
[192] test-rmse:81355.4
[193] test-rmse:81349.2
[194] test-rmse:81343
[195] test-rmse:81346.3
[196] test-rmse:81345.5
[197] test-rmse:81374.6
[198] test-rmse:81358.5
[199] test-rmse:81359.4
グリッドサーチのためのパラメータを生成する
もう少し、モデルの精度がないかを試すために、グリッドサーチをやっています。最初に、グリッドサーチをするためのパラメータを生成しています。
gridsearch_params = [
(max_depth, eta)
for max_depth in [6, 7, 8]
for eta in [0.1, 0.05, 0.01]
]
gridsearch_params
[(6, 0.1),
(6, 0.05),
(6, 0.01),
(7, 0.1),
(7, 0.05),
(7, 0.01),
(8, 0.1),
(8, 0.05),
(8, 0.01)]
もっとも精度の良かったパラメータを計算する
ここでは、どの組み合わせのパラメータがもっとも精度が高くなるかを計算しています。結果は、Best params (8, 0.01)となりました。
min_rmse = float('Inf')
best_param = []
for max_depth, eta in gridsearch_params:
print('max_depth = {}, eta = {}'.format(max_depth, eta))
params['max_depth'] = max_depth
params['eta'] = eta
cv_results = xgb.cv(
params,
dtrain,
num_boost_round = 1000,
seed = 0,
nfold = 5,
metrics = {'rmse'},
early_stopping_rounds = 5
)
mean_rmse = cv_results['test-rmse-mean'].min()
boost_rounds = cv_results['test-rmse-mean'].argmin()
print('RMSE {} for {} rounds'.format(mean_rmse, boost_rounds))
if mean_rmse < min_rmse:
min_rmse = mean_rmse
best_param = (max_depth, eta)
print('Best params {}, RMSE {}'.format(best_param, min_rmse))
max_depth = 6, eta = 0.1
RMSE 81689.0296874 for 123 rounds
max_depth = 6, eta = 0.05
RMSE 81545.2953126 for 267 rounds
max_depth = 6, eta = 0.01
RMSE 82118.7765624 for 999 rounds
max_depth = 7, eta = 0.1
RMSE 81372.990625 for 161 rounds
max_depth = 7, eta = 0.05
RMSE 81372.7171876 for 202 rounds
max_depth = 7, eta = 0.01
RMSE 81308.89999979999 for 999 rounds
max_depth = 8, eta = 0.1
RMSE 81277.4515624 for 96 rounds
max_depth = 8, eta = 0.05
RMSE 81155.2687498 for 201 rounds
max_depth = 8, eta = 0.01
RMSE 81080.3156252 for 849 rounds
Best params (8, 0.01), RMSE 81080.3156252
再度パラメータを変更して学習モデルを生成する
先ほどのグリッドサーチによって、計算されたパラメータを使って再度学習モデルを生成します。
params['max_depth'] = 8
params['eta'] = 0.01
model = xgb.train(params = params,
dtrain = dtrain,
num_boost_round = 1000,
early_stopping_rounds = 5,
evals = [(dtest, 'test')])
[0] test-rmse:515961
Will train until test-rmse hasn't improved in 5 rounds.
[1] test-rmse:511040
[2] test-rmse:506173
[3] test-rmse:501356
[4] test-rmse:496588
[5] test-rmse:491874
[6] test-rmse:487210
[7] test-rmse:482584
[8] test-rmse:478012
[9] test-rmse:473483
[10] test-rmse:468999
[11] test-rmse:464566
[12] test-rmse:460175
[13] test-rmse:455822
[14] test-rmse:451526
[15] test-rmse:447271
[16] test-rmse:443066
[17] test-rmse:438893
[18] test-rmse:434766
[19] test-rmse:430688
[20] test-rmse:426657
[21] test-rmse:422664
[22] test-rmse:418718
[23] test-rmse:414801
[24] test-rmse:410941
[25] test-rmse:407098
[26] test-rmse:403307
[27] test-rmse:399547
[28] test-rmse:395829
[29] test-rmse:392152
[30] test-rmse:388512
[31] test-rmse:384923
[32] test-rmse:381353
[33] test-rmse:377832
[34] test-rmse:374343
[35] test-rmse:370902
[36] test-rmse:367481
[37] test-rmse:364100
[38] test-rmse:360764
[39] test-rmse:357461
[40] test-rmse:354187
[41] test-rmse:350960
[42] test-rmse:347750
[43] test-rmse:344572
[44] test-rmse:341426
[45] test-rmse:338314
[46] test-rmse:335245
[47] test-rmse:332213
[48] test-rmse:329190
[49] test-rmse:326207
[50] test-rmse:323269
[51] test-rmse:320358
[52] test-rmse:317476
[53] test-rmse:314619
[54] test-rmse:311801
[55] test-rmse:309001
[56] test-rmse:306233
[57] test-rmse:303498
[58] test-rmse:300790
[59] test-rmse:298113
[60] test-rmse:295470
[61] test-rmse:292855
[62] test-rmse:290269
[63] test-rmse:287704
[64] test-rmse:285165
[65] test-rmse:282666
[66] test-rmse:280184
[67] test-rmse:277741
[68] test-rmse:275313
[69] test-rmse:272907
[70] test-rmse:270532
[71] test-rmse:268190
[72] test-rmse:265866
[73] test-rmse:263562
[74] test-rmse:261295
[75] test-rmse:259046
[76] test-rmse:256816
[77] test-rmse:254625
[78] test-rmse:252448
[79] test-rmse:250296
[80] test-rmse:248164
[81] test-rmse:246063
[82] test-rmse:243969
[83] test-rmse:241918
[84] test-rmse:239880
[85] test-rmse:237853
[86] test-rmse:235862
[87] test-rmse:233881
[88] test-rmse:231939
[89] test-rmse:230004
[90] test-rmse:228095
[91] test-rmse:226206
[92] test-rmse:224346
[93] test-rmse:222499
[94] test-rmse:220670
[95] test-rmse:218861
[96] test-rmse:217075
[97] test-rmse:215311
[98] test-rmse:213567
[99] test-rmse:211827
[100] test-rmse:210120
[101] test-rmse:208421
[102] test-rmse:206737
[103] test-rmse:205088
[104] test-rmse:203457
[105] test-rmse:201827
[106] test-rmse:200212
[107] test-rmse:198636
[108] test-rmse:197085
[109] test-rmse:195530
[110] test-rmse:194010
[111] test-rmse:192494
[112] test-rmse:191011
[113] test-rmse:189524
[114] test-rmse:188077
[115] test-rmse:186631
[116] test-rmse:185212
[117] test-rmse:183809
[118] test-rmse:182411
[119] test-rmse:181043
[120] test-rmse:179675
[121] test-rmse:178325
[122] test-rmse:177006
[123] test-rmse:175698
[124] test-rmse:174401
[125] test-rmse:173124
[126] test-rmse:171857
[127] test-rmse:170612
[128] test-rmse:169374
[129] test-rmse:168161
[130] test-rmse:166952
[131] test-rmse:165766
[132] test-rmse:164596
[133] test-rmse:163434
[134] test-rmse:162284
[135] test-rmse:161156
[136] test-rmse:160034
[137] test-rmse:158914
[138] test-rmse:157821
[139] test-rmse:156745
[140] test-rmse:155676
[141] test-rmse:154628
[142] test-rmse:153593
[143] test-rmse:152568
[144] test-rmse:151559
[145] test-rmse:150558
[146] test-rmse:149572
[147] test-rmse:148603
[148] test-rmse:147644
[149] test-rmse:146701
[150] test-rmse:145766
[151] test-rmse:144831
[152] test-rmse:143911
[153] test-rmse:143000
[154] test-rmse:142102
[155] test-rmse:141215
[156] test-rmse:140345
[157] test-rmse:139482
[158] test-rmse:138627
[159] test-rmse:137799
[160] test-rmse:136970
[161] test-rmse:136155
[162] test-rmse:135347
[163] test-rmse:134549
[164] test-rmse:133771
[165] test-rmse:132997
[166] test-rmse:132246
[167] test-rmse:131489
[168] test-rmse:130746
[169] test-rmse:130024
[170] test-rmse:129296
[171] test-rmse:128587
[172] test-rmse:127886
[173] test-rmse:127192
[174] test-rmse:126505
[175] test-rmse:125824
[176] test-rmse:125160
[177] test-rmse:124501
[178] test-rmse:123857
[179] test-rmse:123216
[180] test-rmse:122583
[181] test-rmse:121954
[182] test-rmse:121339
[183] test-rmse:120737
[184] test-rmse:120148
[185] test-rmse:119561
[186] test-rmse:118982
[187] test-rmse:118408
[188] test-rmse:117840
[189] test-rmse:117286
[190] test-rmse:116739
[191] test-rmse:116198
[192] test-rmse:115670
[193] test-rmse:115143
[194] test-rmse:114633
[195] test-rmse:114128
[196] test-rmse:113628
[197] test-rmse:113133
[198] test-rmse:112648
[199] test-rmse:112167
[200] test-rmse:111694
[201] test-rmse:111232
[202] test-rmse:110769
[203] test-rmse:110309
[204] test-rmse:109870
[205] test-rmse:109429
[206] test-rmse:109001
[207] test-rmse:108584
[208] test-rmse:108159
[209] test-rmse:107745
[210] test-rmse:107338
[211] test-rmse:106934
[212] test-rmse:106543
[213] test-rmse:106161
[214] test-rmse:105774
[215] test-rmse:105404
[216] test-rmse:105032
[217] test-rmse:104666
[218] test-rmse:104306
[219] test-rmse:103951
[220] test-rmse:103605
[221] test-rmse:103256
[222] test-rmse:102918
[223] test-rmse:102581
[224] test-rmse:102258
[225] test-rmse:101929
[226] test-rmse:101614
[227] test-rmse:101305
[228] test-rmse:101001
[229] test-rmse:100687
[230] test-rmse:100393
[231] test-rmse:100106
[232] test-rmse:99803.7
[233] test-rmse:99521.5
[234] test-rmse:99228
[235] test-rmse:98952.9
[236] test-rmse:98687
[237] test-rmse:98407.6
[238] test-rmse:98145.9
[239] test-rmse:97895.6
[240] test-rmse:97630.3
[241] test-rmse:97373.5
[242] test-rmse:97131.5
[243] test-rmse:96879.7
[244] test-rmse:96638.5
[245] test-rmse:96409.2
[246] test-rmse:96174.5
[247] test-rmse:95950.8
[248] test-rmse:95724
[249] test-rmse:95504.3
[250] test-rmse:95286
[251] test-rmse:95063.2
[252] test-rmse:94852.8
[253] test-rmse:94646.3
[254] test-rmse:94438.7
[255] test-rmse:94227.9
[256] test-rmse:94032
[257] test-rmse:93828.1
[258] test-rmse:93637
[259] test-rmse:93447.4
[260] test-rmse:93264
[261] test-rmse:93072.1
[262] test-rmse:92886.1
[263] test-rmse:92699.4
[264] test-rmse:92519.7
[265] test-rmse:92341.1
[266] test-rmse:92158.9
[267] test-rmse:91984.2
[268] test-rmse:91818.9
[269] test-rmse:91667.4
[270] test-rmse:91508.6
[271] test-rmse:91340.9
[272] test-rmse:91179.8
[273] test-rmse:91036.4
[274] test-rmse:90880.2
[275] test-rmse:90730.6
[276] test-rmse:90586.1
[277] test-rmse:90440.3
[278] test-rmse:90301.1
[279] test-rmse:90168.4
[280] test-rmse:90031.9
[281] test-rmse:89908.5
[282] test-rmse:89775.1
[283] test-rmse:89654.3
[284] test-rmse:89526.7
[285] test-rmse:89395.2
[286] test-rmse:89275.8
[287] test-rmse:89160.1
[288] test-rmse:89035.6
[289] test-rmse:88924.4
[290] test-rmse:88812.4
[291] test-rmse:88696.1
[292] test-rmse:88588.3
[293] test-rmse:88483.3
[294] test-rmse:88367.4
[295] test-rmse:88265.7
[296] test-rmse:88159
[297] test-rmse:88060.3
[298] test-rmse:87956.8
[299] test-rmse:87859.5
[300] test-rmse:87763.9
[301] test-rmse:87660.7
[302] test-rmse:87573.6
[303] test-rmse:87475.7
[304] test-rmse:87378.3
[305] test-rmse:87287.8
[306] test-rmse:87194.3
[307] test-rmse:87113.9
[308] test-rmse:87024.7
[309] test-rmse:86936.5
[310] test-rmse:86847.3
[311] test-rmse:86761.9
[312] test-rmse:86679.8
[313] test-rmse:86612.5
[314] test-rmse:86528
[315] test-rmse:86449.6
[316] test-rmse:86374.7
[317] test-rmse:86297.3
[318] test-rmse:86216.9
[319] test-rmse:86147
[320] test-rmse:86085.8
[321] test-rmse:86018.1
[322] test-rmse:85941.5
[323] test-rmse:85878.8
[324] test-rmse:85815.2
[325] test-rmse:85755.3
[326] test-rmse:85691.4
[327] test-rmse:85631.7
[328] test-rmse:85554.2
[329] test-rmse:85478.2
[330] test-rmse:85420.4
[331] test-rmse:85355
[332] test-rmse:85282.9
[333] test-rmse:85212.7
[334] test-rmse:85156.1
[335] test-rmse:85089.2
[336] test-rmse:85042.1
[337] test-rmse:84977.9
[338] test-rmse:84916.3
[339] test-rmse:84865
[340] test-rmse:84819.4
[341] test-rmse:84764.9
[342] test-rmse:84698.7
[343] test-rmse:84655.8
[344] test-rmse:84595.1
[345] test-rmse:84546.1
[346] test-rmse:84496.5
[347] test-rmse:84446.9
[348] test-rmse:84401.1
[349] test-rmse:84349.7
[350] test-rmse:84312.6
[351] test-rmse:84263.9
[352] test-rmse:84217.4
[353] test-rmse:84176.9
[354] test-rmse:84126.8
[355] test-rmse:84081.5
[356] test-rmse:84037.6
[357] test-rmse:84001.1
[358] test-rmse:83961.8
[359] test-rmse:83922.8
[360] test-rmse:83884.8
[361] test-rmse:83842.4
[362] test-rmse:83805.7
[363] test-rmse:83771.6
[364] test-rmse:83738.9
[365] test-rmse:83701.5
[366] test-rmse:83668
[367] test-rmse:83633.7
[368] test-rmse:83591.7
[369] test-rmse:83552.1
[370] test-rmse:83514.7
[371] test-rmse:83479.3
[372] test-rmse:83440.2
[373] test-rmse:83412.3
[374] test-rmse:83380.3
[375] test-rmse:83346.3
[376] test-rmse:83309.6
[377] test-rmse:83272.6
[378] test-rmse:83243.7
[379] test-rmse:83211.3
[380] test-rmse:83184.4
[381] test-rmse:83151.7
[382] test-rmse:83119.6
[383] test-rmse:83089.4
[384] test-rmse:83056.3
[385] test-rmse:83023.5
[386] test-rmse:82994.4
[387] test-rmse:82964.4
[388] test-rmse:82936.3
[389] test-rmse:82907.3
[390] test-rmse:82873.7
[391] test-rmse:82845.4
[392] test-rmse:82816.8
[393] test-rmse:82790.7
[394] test-rmse:82766
[395] test-rmse:82740.9
[396] test-rmse:82719.9
[397] test-rmse:82695.1
[398] test-rmse:82672.3
[399] test-rmse:82647.9
[400] test-rmse:82629.9
[401] test-rmse:82602
[402] test-rmse:82581.2
[403] test-rmse:82562.3
[404] test-rmse:82541
[405] test-rmse:82524.3
[406] test-rmse:82504
[407] test-rmse:82490.7
[408] test-rmse:82472
[409] test-rmse:82448
[410] test-rmse:82424.8
[411] test-rmse:82408.9
[412] test-rmse:82395.4
[413] test-rmse:82373.6
[414] test-rmse:82358.9
[415] test-rmse:82336.1
[416] test-rmse:82322.6
[417] test-rmse:82301.7
[418] test-rmse:82282.6
[419] test-rmse:82268.4
[420] test-rmse:82253.9
[421] test-rmse:82229.1
[422] test-rmse:82207
[423] test-rmse:82188.9
[424] test-rmse:82176.5
[425] test-rmse:82170.7
[426] test-rmse:82157
[427] test-rmse:82151.2
[428] test-rmse:82139.1
[429] test-rmse:82126.4
[430] test-rmse:82108.7
[431] test-rmse:82098.1
[432] test-rmse:82087.3
[433] test-rmse:82075.5
[434] test-rmse:82063.7
[435] test-rmse:82054
[436] test-rmse:82039.1
[437] test-rmse:82027.3
[438] test-rmse:82014.6
[439] test-rmse:82005.3
[440] test-rmse:81993.7
[441] test-rmse:81984.7
[442] test-rmse:81973.3
[443] test-rmse:81955.5
[444] test-rmse:81943.4
[445] test-rmse:81932.8
[446] test-rmse:81918.6
[447] test-rmse:81909.1
[448] test-rmse:81899.2
[449] test-rmse:81886.5
[450] test-rmse:81873.7
[451] test-rmse:81863
[452] test-rmse:81854.2
[453] test-rmse:81842.5
[454] test-rmse:81831.3
[455] test-rmse:81821.2
[456] test-rmse:81811.4
[457] test-rmse:81804.7
[458] test-rmse:81789.8
[459] test-rmse:81784.3
[460] test-rmse:81779.4
[461] test-rmse:81771.3
[462] test-rmse:81756.4
[463] test-rmse:81751.9
[464] test-rmse:81739.6
[465] test-rmse:81730.1
[466] test-rmse:81719.8
[467] test-rmse:81710.2
[468] test-rmse:81701.1
[469] test-rmse:81689.9
[470] test-rmse:81685
[471] test-rmse:81675.6
[472] test-rmse:81670.9
[473] test-rmse:81659.8
[474] test-rmse:81651.6
[475] test-rmse:81641.8
[476] test-rmse:81632.5
[477] test-rmse:81629.2
[478] test-rmse:81619.2
[479] test-rmse:81611
[480] test-rmse:81608
[481] test-rmse:81599.1
[482] test-rmse:81588.6
[483] test-rmse:81578.9
[484] test-rmse:81573.4
[485] test-rmse:81570.2
[486] test-rmse:81558.9
[487] test-rmse:81554.5
[488] test-rmse:81544.7
[489] test-rmse:81533.8
[490] test-rmse:81526.6
[491] test-rmse:81518.9
[492] test-rmse:81512.2
[493] test-rmse:81498.3
[494] test-rmse:81495.6
[495] test-rmse:81488.1
[496] test-rmse:81478.6
[497] test-rmse:81469.1
[498] test-rmse:81463
[499] test-rmse:81462.4
[500] test-rmse:81454.6
[501] test-rmse:81453.7
[502] test-rmse:81450.8
[503] test-rmse:81443
[504] test-rmse:81434
[505] test-rmse:81430.1
[506] test-rmse:81427.5
[507] test-rmse:81421.4
[508] test-rmse:81421.2
[509] test-rmse:81415.5
[510] test-rmse:81413
[511] test-rmse:81409.5
[512] test-rmse:81396.6
[513] test-rmse:81395.1
[514] test-rmse:81395.3
[515] test-rmse:81392.2
[516] test-rmse:81391.4
[517] test-rmse:81388.2
[518] test-rmse:81383.8
[519] test-rmse:81379.3
[520] test-rmse:81379.8
[521] test-rmse:81376.9
[522] test-rmse:81376.9
[523] test-rmse:81375.1
[524] test-rmse:81370.2
[525] test-rmse:81365.3
[526] test-rmse:81364.3
[527] test-rmse:81363.6
[528] test-rmse:81362.5
[529] test-rmse:81358.9
[530] test-rmse:81354.8
[531] test-rmse:81353.9
[532] test-rmse:81355.2
[533] test-rmse:81355.2
[534] test-rmse:81356.4
[535] test-rmse:81356.3
[536] test-rmse:81352.6
[537] test-rmse:81347.5
[538] test-rmse:81347.6
[539] test-rmse:81349.5
[540] test-rmse:81350.3
[541] test-rmse:81351.2
[542] test-rmse:81351.6
Stopping. Best iteration:
[537] test-rmse:81347.5
重要度の高い特徴量を表示する。
先ほど学習したモデルのうち、その特徴量が重要度が高かったのかを、表示してみます。これを見ると、long、lat、sqft_livingなどの特徴量の重要度が高いことがわかります。
fig, ax = plt.subplots(figsize = (12,12))
xgb.plot_importance(model, max_num_features = 12, height = 0.8, ax = ax)
plt.savefig('house_sails_feature_importance03.png', bbox_inches='tight')
学習したモデルの精度を計測してみる。
学習したモデルの精度を計測してみます。r2_scoreでは、約0.847という精度となっています。
from sklearn.metrics import r2_score
preds = model.predict(dtest)
r2 = r2_score(y_test, preds)
print(r2)
0.8473346069012444
学習モデルを保存する
最後に、joblibを使って学習モデルをpklファイルとして保存しておきます。これで、機械学習で学習した学習モデルができました。
from sklearn.externals import joblib
joblib.dump(model, 'house_sales_model.pkl')
['house_sales_model.pkl']
FlaskでAPIサーバーを作る
ここでは、先ほど機械学習で生成した学習モデルをAPIサーバーにするということをやっています。APIサーバー開発にはPythonのマイクロサービス系のフレームワークである、Flaskを使用しています。開発の流れとしては、condaで仮想環境を構築し、簡易的なAPIサーバーをテストして、そこにXGBoostで作った学習モデルを載せるという流れになります。
condaで仮想環境を構築する
仮想環境は、Anacondaのcondaを使用します。ターミナルでアプリ開発用のフォルダ(ここでは、titanic_api)を作り、そのフォルダ内に移動します。そしたらconda createで仮想環境を生成し、conda activateで仮想環境をアクティブな状態にします。
mkdir housesails_api
cd housesails_api
conda create -n housesailsenv
conda activate housesailsenv
FlaskでAPIを開発する
FlaskでAPIサーバーを開発するために、最初に簡易的なAPIサーバーを作ってテストしてみます。先ほど作ったフォルダ内に以下のようなフォルダとファイルを作ります。ファイルにはそれぞれ以下のようなコードを書いて、APIサーバーを起動して、curlから通信できれば簡易的なAPIサーバーテストの成功です。
ターミナルに必要なフォルダとファイルを生成する。
以下のような階層になるようにフォルダとファイルを作ります。空ファイルを作るならtouchコマンドなどを使用すると便利です。
housesails_api
├── api
│ ├── __init__.py
│ └── views
│ └── user.py
├── housesails_app.py
└── house_sales_model.pkl
作成したファイルにコードを書く
先ほど作成したファイルに以下のようにコードを書きます。簡易的なAPIサーバーをテストするために必要なファイルは、api/views/user.py、api/init.py、titanic_app.pyの三つです。ターミナルで書く場合はvim、GUIで書く場合はAtomなどを使用すると便利です。
from flask import Blueprint, request, make_response, jsonify
# ルーティング設定
user_router = Blueprint('user_router', __name__)
# パスとHTTPメソッドを指定
@user_router.route('/users', methods=['GET'])
def get_user_list():
return make_response(jsonify({
'users': [
{
'id': 1,
'name': 'John'
}
]
}))
from flask import Flask, make_response, jsonify
from .views.user import user_router
def create_app():
app = Flask(__name__)
app.register_blueprint(user_router, url_prefix='/api')
return app
app = create_app()
import json
from flask import Flask
from flask import request
from flask import abort
import pandas as pd
from sklearn.externals import joblib
import xgboost as xgb
model = joblib.load("house_sales_model.pkl")
app = Flask(__name__)
# Get headers for payload
headers = ['sqft_living','sqft_above','sqft_basement','lat','long','sqft_living15','grade_3','grade_4','grade_5','grade_6','grade_7','grade_8','grade_9','grade_10','grade_11','grade_12']
@app.route('/house_sails', methods=['POST'])
def housesails():
if not request.json:
abort(400)
payload = request.json['data']
values = [float(i) for i in payload.split(',')]
data1 = pd.DataFrame([values], columns=headers, dtype=float)
predict = model.predict(xgb.DMatrix(data1))
return json.dumps(str(predict[0]))
if __name__ == "__main__":
app.run(debug=True, port=5000)
curlでAPI通信テストをする
コードを書き換えたら、改めて、python housesails_app.py で、APIサーバーを起動します。APIサーバーが起動したら、以下のようにcurlコマンドで通信テストをしています。送ったJSONデータに対して、小数点1以下の値が返ってきたら成功です。これで、機械学習で生成した学習モデルをAPIサーバーにすることができました。
curl http://localhost:5000/house_sails -s -X POST -H "Content-Type: application/json" -d '{"data": "-1.026685, -0.725963, -0.652987, -0.323607, -0.307144, -0.946801, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0"}'
認証機能をつける
最後に、APIサーバーに認証機能をつけます。今回はbasic認証を実装しています。ライブラリにHTTPBasicAuthがインストールされている必要があります。先ほどのhousesails_app.pyのコードに# BasicAuthと書かれているところを追記することで実装できます。
import json
from flask import Flask
from flask import request
from flask import abort
from flask_httpauth import HTTPBasicAuth
import pandas as pd
from sklearn.externals import joblib
import xgboost as xgb
model = joblib.load("house_sales_model.pkl")
app = Flask(__name__)
# BasicAuth
auth = HTTPBasicAuth()
users = {
"user01": "password01",
"user02": "password02"
}
@auth.get_password
def get_pw(username):
if username in users:
return users.get(username)
return None
# Get headers for payload
headers = ['sqft_living','sqft_above','sqft_basement','lat','long','sqft_living15','grade_3','grade_4','grade_5','grade_6','grade_7','grade_8','grade_9','grade_10','grade_11','grade_12']
@app.route('/house_sails', methods=['POST'])
# BasicAuth
@auth.login_required
def housesails():
if not request.json:
abort(400)
payload = request.json['data']
values = [float(i) for i in payload.split(',')]
data1 = pd.DataFrame([values], columns=headers, dtype=float)
predict = model.predict(xgb.DMatrix(data1))
return json.dumps(str(predict[0]))
if __name__ == "__main__":
app.run(debug=True, port=5000)
再度curlでAPI通信テストをする
再度、python housesails_app.py で、APIサーバーを起動します。APIサーバーが起動したら、以下のようにcurlコマンドで通信テストをしています。--user user01:password01というのを追記することで認証されます。これで通信がうまくいえば成功です。
curl http://localhost:5000/house_sails --user user01:password01 -s -X POST -H "Content-Type: application/json" -d '{"data": "-1.026685, -0.725963, -0.652987, -0.323607, -0.307144, -0.946801, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0"}'