0
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

木と線形系における機械学習の可視化

Posted at

使用するアルゴリズム

  • リッジ回帰
  • ラッソ回帰
  • 重回帰
  • 決定木
  • ランダムフォレスト
  • エキストラツリー

以上を対象に,重要度評価をして円グラフに可視化する.

大まかなプログラムの流れ

  1. 各アルゴリズムをforで回しながらインスタンス化
  2. 重要度を算出
  3. plt.add_subplotでそれぞれを円グラフに

ソースコード

Xyを定義してるのを前提としています.

main
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)


models = (Ridge(alpha=0.01),
         Lasso(alpha=0.01), 
         LinearRegression(),
         DecisionTreeRegressor(random_state=0,max_depth=4),
         RandomForestRegressor(max_features=3,random_state=0,max_depth=14, n_estimators=100),
         ExtraTreesRegressor(random_state=0,n_estimators=100,max_depth=5)
        )

fig = plt.figure(figsize = (24,12))

for num,model in enumerate(models):
    feature = Regression(model)
    plot_model(feature,X.collumns)
fig.tight_layout()
plt.savefig('figure.png',bbox_inches='tight',pad_inches=0.05)

ここに,Regression()はモデルを引き渡し,重要度評価をしてくれる関数として定義しています.
try文で木かそれ以外かを場合分けしてるけど,これが最適とは思いにくい.

Regression(model)
def Regression(model):
    model.fit(X_train, y_train)
    try:
        x = model.feature_importances_ #木ならこっち
    except:
        x = model.coef_  # リッジ,ラッソ,重回帰ならこっち
    
    y_pred = model.predict(X_test)
    
    kf = KFold(n_splits=5, shuffle=True, random_state=1)
    result = cross_val_score(model,X,y,scoring='neg_root_mean_squared_error',cv=kf)
    
    print(str(model))
    print("RMSE:", -1*result.mean())
    print("-"*40)
    return abs(x)

続いて,plot_model()は重要度とラベルを引数にとって可視化する関数.

plot_model(x,labels)
def plot_model(x,labels):
    fig.patch.set_facecolor('white')
    idx = np.argsort(-x)
    ax = fig.add_subplot(2, 3, num+1)
    
    ax.pie(x[idx], autopct=lambda p:'{:.1f}%'.format(p) if p>=4.0 else '',
           pctdistance=0.8, startangle=90, counterclock=False,
           # labeldistance=1.2,
           # radius=0.8,
           textprops={'color': "black",'fontweight': 'bold'} # 'size': 'small'
          )
    ax.set_title(str(model))
    ax.legend(labels[idx],fancybox=True,loc='center left',bbox_to_anchor=(0.9,0.5))

そんでできるのがコレ.凡例は一部を隠しています(テキトーに用意するのが面倒くさかった).
図1.png

参考

0
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?