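I have been fitting a random survival forest (RSF) with the data split into train, validation, and test sets. I can compute the concordance index, but I cannot get feature importance because eli5 and scikit-survival are not compatible.
My code is below. I am running it on Google Colaboratory.
Things get tricky once survival functions are involved. How have you worked around the compatibility problem between eli5 and scikit-survival?
Feature importance is essential for the paper I am writing.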
!pip install scikit-survival
import pandas as pd
from sklearn.model_selection import train_test_split
from sksurv.ensemble import RandomSurvivalForest
from sklearn.model_selection import GridSearchCV
from sksurv.metrics import concordance_index_censored
import numpy as np
# Load your data
data = pd.read_excel('your_data_file.xlsx')
# Prepare the data: structured target with a boolean event indicator (Event == 2 counts as the event) and the DFS time
target = np.array([(e == 2, t) for e, t in zip(data['Event'], data['DFS'])], dtype=[('Event', '?'), ('DFS', '<f8')])
features = data.drop(columns=['Event', 'DFS'])
# Split the data: 20% test, then 25% of the remainder as validation (60/20/20 overall)
X_train, X_test, y_train, y_test = train_test_split(features, target, test_size=0.2, random_state=42)
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.25, random_state=42)
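# Hyperparameter tuning
# 'auto' has been removed from max_features in recent scikit-learn / scikit-survival releases; None uses all features
param_grid = {
    'max_features': ['sqrt', 'log2', None],
    'max_depth': [None, 10, 20, 30],
    'min_samples_leaf': [1, 2, 4],
    'n_estimators': [100, 200, 300],
    'min_samples_split': [2, 5, 10]
}
rsf = RandomSurvivalForest(random_state=42)
# scoring is left unset so candidates are ranked by rsf.score (Harrell's concordance index);
# 'roc_auc' is a classification metric and cannot handle the structured (Event, DFS) target
grid_search = GridSearchCV(rsf, param_grid, cv=5, n_jobs=-1)
grid_search.fit(X_train, y_train)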
# Best parameters and model (GridSearchCV refits the best candidate on all of X_train)
best_params = grid_search.best_params_
best_rsf = grid_search.best_estimator_
# Evaluation on test set: predict() returns risk scores (higher = higher predicted risk)
prediction = best_rsf.predict(X_test)
c_index = concordance_index_censored(y_test['Event'], y_test['DFS'], prediction)
# Output best parameters and concordance index (the first element of the returned tuple)
print("Best Parameters:", best_params)
print("Concordance Index on Test Set:", c_index[0])
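!pip install eli5
import eli5
from eli5.sklearn import PermutationImportance
# Refit with the best parameters (equivalent to best_rsf above; the original line assigned to "sf" and then fitted the untuned "rsf")
rsf = RandomSurvivalForest(random_state=42, **best_params)
rsf.fit(X_train, y_train)
# Feature importance: permutation importance on the test set, scored with rsf.score (concordance index)
perm = PermutationImportance(rsf, n_iter=15, random_state=42)
perm.fit(X_test, y_test)
eli5.show_weights(perm, feature_names=X_test.columns.tolist())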