Help us understand the problem. What is going on with this article?

患者カルテデータを使って機械学習をやってみるよ!【scikit-learn】

はじめに

「患者データで機械学習の研究始めるよ!」って人の研究チュートリアルしたいです。
scikit-learnの糖尿病データセットでいろんな回帰モデルを紹介します。

読み飛ばす背景。。

臨床データで機械学習したい研究って結構あるとおもいます。

けど、ありがちなのが、「倫理審査が終わるまでデータがもらえない=>もらってからデスマーチうぉぉぉ=>orz」のパターン。。

今回は練習用に、使用可能な患者公開データの紹介から、機械学習を使った回帰予測モデルの構築まで紹介してきます。

使用可能なデータセット

いくつか使えそうなカルテデータを紹介します。

sklearn.datasets.load_diabetes
灯台下暗し。scikit-learnのサンプルデータセットに糖尿病患者のデータセットが入ってました。もっと早く知っていれば。。

Indian Liver Patient Records
UCI Machine Learningで配布されてる、肝疾患のデータセット。
(https://archive.ics.uci.edu/ml/datasets/ILPD+(Indian+Liver+Patient+Dataset))

Chronic Kidney Disease dataset
UCI Machine Learningで配布されてる、肝疾患のデータセット。
(https://archive.ics.uci.edu/ml/datasets/Chronic_Kidney_Disease)

他にあったら教えて下さい。。

本題

一番お手軽に使えるsklearn.datasets.load_diabetesを使っていきます。
441名の糖尿病患者カルテデータセットで、タスクは「カルテデータから1年後の症状の進行を予測する」というものです。

今回はGoogle Colaboratory(もしくはJupyterNotebook)でやっていきます。

データの読み込み

まずは、データを読み込んで見てみます。

import sys,re,time,os,pickle
import pandas as pd
import numpy as np
from sklearn import datasets
import matplotlib.pyplot as plt
import seaborn as sns

#データセットの読み込み
data = datasets.load_diabetes()
display(data)

▼出力

{'DESCR': '.. _diabetes_dataset:\n\nDiabetes dataset\n----------------\n\nTen baseline variables, age, sex, body mass index, average blood\npressure, and six blood serum measurements were obtained for each of n =\n442 diabetes patients, as well as the response of interest, a\nquantitative measure of disease progression one year after baseline.\n\n**Data Set Characteristics:**\n\n  :Number of Instances: 442\n\n  :Number of Attributes: First 10 columns are numeric predictive values\n\n  :Target: Column 11 is a quantitative measure of disease progression one year after baseline\n\n  :Attribute Information:\n      - Age\n      - Sex\n      - Body mass index\n      - Average blood pressure\n      - S1\n      - S2\n      - S3\n      - S4\n      - S5\n      - S6\n\nNote: Each of these 10 feature variables have been mean centered and scaled by the standard deviation times `n_samples` (i.e. the sum of squares of each column totals 1).\n\nSource URL:\nhttps://www4.stat.ncsu.edu/~boos/var.select/diabetes.html\n\nFor more information see:\nBradley Efron, Trevor Hastie, Iain Johnstone and Robert Tibshirani (2004) "Least Angle Regression," Annals of Statistics (with discussion), 407-499.\n(https://web.stanford.edu/~hastie/Papers/LARS/LeastAngle_2002.pdf)',
 'data': array([[ 0.03807591,  0.05068012,  0.06169621, ..., -0.00259226,
          0.01990842, -0.01764613],
        ...,
        [-0.04547248, -0.04464164, -0.0730303 , ..., -0.03949338,
         -0.00421986,  0.00306441]]),
 'data_filename': '/usr/local/lib/python3.6/dist-packages/sklearn/datasets/data/diabetes_data.csv.gz',
 'feature_names': ['age', 'sex', 'bmi', 'bp', 's1', 's2', 's3', 's4',
's5','s6'],
 'target': array([151.,  75., 141., 206., 135.,  97., 138.,  63., 110., 310., 101.,
         69., 179., 185., 118., 171., 166., 144.,  97., 168.,  68.,  49.,...

検査値が"data",その列名が"feature_names",予測する1月後の数値が"target"に入ってますね。
データフレーム化して処理していきます。

data_df = pd.DataFrame(data.data,columns=data.feature_names)
data_df['target']=data.target
display(data_df)

▼出力
0910_qiita1.png

それぞれの列の特徴を見ていきます。pandas profilingを使うと簡単にデータの特徴がわかります。超便利。

import pandas_profiling as pdp
pdp.ProfileReport(data_df)

▼出力
0910_qiita2.png
▼相関図も見れます
0910_qiita3.png

欠損値がないのでデータは全部使えますね。

相関図でbmi,s4,s5あたりが関連しそうな因子だとわかります

データの分割

train : test=8:2
にして機械学習モデルを構築・評価していきます。

from sklearn.model_selection import train_test_split

y =data_df['target'] 
X = data_df.drop('target',axis=1)

#hold_out法:train:test=0.8:0.2にして分割
(X_train, X_test, y_train, y_test) = train_test_split(X, y, test_size=0.2, random_state=44,
)

※機械学習モデルの選択

ここまでで機械学習モデルを構築する準備ができました!さて、どのモデルがこの問題には最適でしょうか?
公式様がガイド出してるので参考にします。(右上らへんのregressionの部分)

モデルの選択

いくつかの特徴は重要なような、そうでもないような。。

公式様の教えに従いはElasticNetとSVRの2つ、あと定番のRandomForest(※EnsembleRegressorの1種)をやってみます。

モデルの構築

ではモデルの構築をやってきます。とはいえscikit-learnを使えば一瞬です。。神。。

#モデルの構築
from sklearn.linear_model import ElasticNet
from sklearn.svm import SVR
from sklearn.ensemble import RandomForestRegressor
#Elastic Net
el_model = ElasticNet(alpha=0.0001, l1_ratio=0.5)
el_model.fit(X_train, y_train)

#Linear SVR
svr_model = SVR(kernel='linear', C=1000, epsilon=0.2)
svr_model.fit(X_train, y_train)

#RandomForest
rf_model = RandomForestRegressor(n_estimators=100,criterion='mse', max_depth=6,bootstrap=0.9)
rf_model.fit(X_train, y_train)

予測・評価

学習済みのモデルに、分割しておいたテストデータを入れて評価します。
今回はR2という評価指標を用いました(最大は1.0)

#testデータの予測
el_pred = el_model.predict(X_test)
svr_pred = svr_model.predict(X_test)
rf_pred = rf_model.predict(X_test)
from sklearn.metrics import r2_score

#r2scoreによる評価
el_score = r2_score(y_test, el_pred)
svr_score = r2_score(y_test, svr_pred)
rf_score = r2_score(y_test, rf_pred)

print("Score_ElasticNet:{} ".format(el_score))
print("Score_SVR(liner):{} ".format(svr_score))
print("Score_RandomForest:{} ".format(rf_score))

▼出力

Score_ElasticNet:0.5212070906081345 
Score_SVR(liner):0.5175245841065237 
Score_RandomForest:0.43607586214409555 

きちんと評価までできてますね!今回評価した中だと、ElasticNetが最も良いモデル(R2score:0.52)となりました。

まとめ

今回はscikit-learnの糖尿病データセットを用いてか回帰問題の機械学習モデルの構築・評価を行ってみました。

スコアも得られましたが、まだまだやることはあると思います。例えば...

・hold-outではなく交差検証で妥当な検証を行う
・パラメータチューニングをする
・特徴量の処理を行う
・他のモデルを試してみる

などなど...。それでは皆様、良いカルテ研究を!

※おまけ

scikit-learn以外の手法も簡単にやってみました。
LightGBMという勾配ブースティングの手法。あとはニューラルネットワーク君です。

#LightGBMによる回帰
import lightgbm as lgb

(train_X, val_X, train_y, val_y) = train_test_split(X_train, y_train, test_size=0.1, random_state=0,
)
lgb_train = lgb.Dataset(train_X, train_y)
lgb_eval  = lgb.Dataset(val_X, val_y, reference=lgb_train)

#回帰の場合、objectiveをregressionにする
params = {
    'task': 'train',
    'boosting_type': 'gbdt',
    'objective': 'regression',
    'feature_fraction': 0.5,
    'bagging_fraction': 0.7,
    'bagging_freq': 7,
    "max_depth": 6, 
    "num_leaves":32,
    "max_bin": 512,
    "learning_rate": 0.01, 
    "verbose":0
  }

#validationデータを使って学習
gbm = lgb.train(params,lgb_train,valid_sets=lgb_eval,
                num_boost_round=10000, early_stopping_rounds=100)

#予測
gbm_pred = gbm.predict(X_test, num_iteration=gbm.best_iteration)

print("LGBM_r2score:{}".format(r2_score(y_test,gbm_pred)))

▼出力

LGBM_r2score:0.530247205736787

一番高いスコアが出ましたね!次はNN...

#kerasによるNN
from keras.layers import Dense, Dropout, Activation
from keras import optimizers
from keras.models import Model,Sequential
from keras.callbacks import EarlyStopping

#隠れ層の数、ユニット数の設定
hidden_layer=5
hidden_units = 32

#モデルの構築
model = Sequential()
model.add(Dense(hidden_units))
model.add(Activation('relu'))
model.add(Dropout(0.2))

for n in range(hidden_layer):
  model.add(Dense(hidden_units-(n*4)))
  model.add(Activation('relu'))
model.add(Dense(1))

sgd = optimizers.Adam(lr=0.002,decay=1e-9)
model.compile(loss='mean_squared_error',optimizer=sgd,)

#過学習しないように設定
early_stopping =EarlyStopping(monitor='val_loss', patience=5)

#学習
hist = model.fit(train_X.values, np.log1p(train_y.values),
        validation_data=(val_X, val_y), 
        epochs=10000, batch_size=1, shuffle=True, 
        callbacks=[early_stopping])

#予測
nn_preds = model.predict(X_test.values, batch_size=1, verbose=0)

print("NN_r2score:{}".format(r2_score(y_test,np.expm1(nn_preds))))

▼出力

NN_r2score:0.4928342032547228

NNはあんまりですが、パラメータ調節すればまだ上がると思います。

参考サイト

Why do not you register as a user and use Qiita more conveniently?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Comments
Sign up for free and join this conversation.
If you already have a Qiita account
Why do not you register as a user and use Qiita more conveniently?
You need to log in to use this function. Qiita can be used more conveniently after logging in.
You seem to be reading articles frequently this month. Qiita can be used more conveniently after logging in.
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away