Compatibility between eli5 and scikit-survival

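I split my data into train, validation, and test sets and fitted a random survival forest (RSF). I was able to compute the concordance index, but because of a compatibility problem between eli5 and scikit-survival I cannot compute feature importance.
My code is shown below. I am running it on Google Colaboratory.

Things get harder once survival functions are involved, but how is everyone working around the compatibility problem between eli5 and scikit-survival?
Feature importance is essential for writing a paper.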

!pip install scikit-survival

import pandas as pd
from sklearn.model_selection import train_test_split
from sksurv.ensemble import RandomSurvivalForest
from sklearn.model_selection import GridSearchCV
from sksurv.metrics import concordance_index_censored
import numpy as np

Load your data

data = pd.read_excel('your_data_file.xlsx')

Preparing the data

target = np.array([(e == 2, t) for e, t in zip(data['Event'], data['DFS'])], dtype=[('Event', '?'), ('DFS', '<f8')])
features = data.drop(columns=['Event', 'DFS'])

Splitting the data

X_train, X_test, y_train, y_test = train_test_split(features, target, test_size=0.2, random_state=42)
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.25, random_state=42)

Hyperparameter tuning

param_grid = {
    # 'auto' is omitted: recent scikit-learn / scikit-survival releases reject it
    'max_features': ['sqrt', 'log2'],
    'max_depth': [None, 10, 20, 30],
    'min_samples_leaf': [1, 2, 4],
    'n_estimators': [100, 200, 300],
    'min_samples_split': [2, 5, 10]
}
rsf = RandomSurvivalForest(random_state=42)
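# No scoring argument: GridSearchCV then falls back to RandomSurvivalForest.score,
# which is the concordance index. 'roc_auc' expects class labels and cannot score
# a structured (event, time) survival target.
grid_search = GridSearchCV(rsf, param_grid, cv=5, n_jobs=-1)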
grid_search.fit(X_train, y_train)

Best parameters and model

best_params = grid_search.best_params_
best_rsf = grid_search.best_estimator_

Evaluation on test set

prediction = best_rsf.predict(X_test)
c_index = concordance_index_censored(y_test['Event'], y_test['DFS'], prediction)

Output best parameters and concordance index

print("Best Parameters:", best_params)
print("Concordance Index on Test Set:", c_index[0])
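To make the reported number easier to interpret, here is a minimal sketch of Harrell's concordance index in plain Python. It is a simplified illustration, not the scikit-survival implementation: the function name `harrell_c_index` and the toy data are assumptions, and pairs with tied event times are simply skipped, whereas `concordance_index_censored` handles them separately.

```python
def harrell_c_index(events, times, risk_scores):
    """Simplified Harrell's C-index (hypothetical helper, ties in time skipped).

    A pair (i, j) is comparable when subject i had an event and
    times[i] < times[j]. The pair is concordant when the subject who
    failed earlier also has the higher predicted risk score.
    """
    concordant = 0.0
    comparable = 0
    n = len(times)
    for i in range(n):
        if not events[i]:
            continue  # a censored subject cannot be the earlier member of a pair
        for j in range(n):
            if times[i] < times[j]:
                comparable += 1
                if risk_scores[i] > risk_scores[j]:
                    concordant += 1.0
                elif risk_scores[i] == risk_scores[j]:
                    concordant += 0.5  # tied predictions count as half
    if comparable == 0:
        raise ValueError("no comparable pairs")
    return concordant / comparable


# Toy data: higher risk score should mean shorter time to event.
toy_events = [True, False, True, True]
toy_times = [2.0, 4.0, 6.0, 8.0]
perfect = harrell_c_index(toy_events, toy_times, [0.9, 0.3, 0.5, 0.1])
reversed_ranking = harrell_c_index(toy_events, toy_times, [0.1, 0.5, 0.3, 0.9])
uninformative = harrell_c_index(toy_events, toy_times, [0.5, 0.5, 0.5, 0.5])
print(perfect, reversed_ranking, uninformative)
```

A value of 1.0 means every comparable pair is ranked correctly, 0.5 corresponds to uninformative predictions, and 0.0 means the ranking is fully reversed.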

!pip install eli5

# Assign to rsf (not sf) so that the model fitted and inspected below uses the tuned parameters
rsf = RandomSurvivalForest(random_state=42, **best_params)
rsf.fit(X_train, y_train)

Feature Importance

import eli5  # needed for eli5.show_weights below
from eli5.sklearn import PermutationImportance
perm = PermutationImportance(rsf, n_iter=15, random_state=42)
perm.fit(X_test, y_test)
eli5.show_weights(perm, feature_names=X_test.columns.tolist())
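Permutation importance does not depend on eli5 itself: it only needs a way to score a fitted model on held-out data. The following is a minimal, library-free sketch of the idea, assuming a hypothetical `score_fn(X, y)` callable (for a fitted scikit-survival model it could wrap `model.score`, which reports the concordance index). The function name, the list-of-lists data layout, and the toy scorer are all assumptions made for illustration.

```python
import random


def permutation_importance_sketch(score_fn, X, y, n_repeats=15, seed=42):
    """Mean drop in score when each feature column is shuffled.

    score_fn(X, y) -> float is a hypothetical scoring callable where
    higher is better. X is a list of rows, each row a list of features.
    """
    rng = random.Random(seed)
    baseline = score_fn(X, y)
    n_features = len(X[0])
    importances = []
    for col in range(n_features):
        drops = []
        for _ in range(n_repeats):
            # Shuffle one column while leaving the others untouched
            shuffled = [row[col] for row in X]
            rng.shuffle(shuffled)
            X_perm = [row[:col] + [value] + row[col + 1:]
                      for row, value in zip(X, shuffled)]
            drops.append(baseline - score_fn(X_perm, y))
        importances.append(sum(drops) / n_repeats)
    return importances


# Toy check: feature 0 determines y, feature 1 is irrelevant.
toy_X = [[float(i), float((i * 7) % 5)] for i in range(20)]
toy_y = [row[0] for row in toy_X]


def toy_score(X, y):
    # Negative mean absolute error of "predict feature 0" (higher is better)
    return -sum(abs(row[0] - target) for row, target in zip(X, y)) / len(y)


toy_importances = permutation_importance_sketch(toy_score, toy_X, toy_y)
print(toy_importances)
```

scikit-learn provides the same technique as `sklearn.inspection.permutation_importance`. Because that function relies only on the estimator's `score` method, a call such as `permutation_importance(rsf, X_test, y_test, n_repeats=15, random_state=42)` may work as a substitute for the eli5 step, with the results available in `importances_mean` and `importances_std`; this is untested against the dataset above.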
