多クラス分類をxgboostで解く
解決したいこと
signateの林型の分類の練習問題を、xgboostで解いてみようとしています。
コードを複数のサイトから引用して組み合わせてみたのですが、うまくいきません。
主にこのページを参考にさせていただきました
https://qiita.com/predora005/items/19aebcf3aa05946c7cf4
発生している問題・エラー
raise XGBoostError(py_str(_LIB.XGBGetLastError()))
xgboost.core.XGBoostError: [09:06:34] d:\bld\xgboost-split_1645118015404\work\src\objective\multiclass_obj.cu:120: SoftmaxMultiClassObj: label must be in [0, num_class).
該当するソースコード
#訓練データとテストデータの取得
from sklearn.model_selection import train_test_split
forest_data = pd.DataFrame(train, columns=["Elevation", "Aspect", "Slope", "Horizontal_Distance_To_Hydrology","Vertical_Distance_To_Hydrology", "Horizontal_Distance_To_Roadways", "Hillshade_9am",
"Hillshade_Noon","Hillshade_3pm","Horizontal_Distance_To_Fire_Points","Wilderness_Area1","Wilderness_Area2","Wilderness_Area3","Wilderness_Area4",
"Soil_Type1",
"Soil_Type2", "Soil_Type3", "Soil_Type4", "Soil_Type5", "Soil_Type6", "Soil_Type7", "Soil_Type8","Soil_Type9", "Soil_Type10", "Soil_Type11", "Soil_Type12",
"Soil_Type13", "Soil_Type14", "Soil_Type15", "Soil_Type16", "Soil_Type17", "Soil_Type18", "Soil_Type19", "Soil_Type20", "Soil_Type21", "Soil_Type22", "Soil_Type23", "Soil_Type24",
"Soil_Type25", "Soil_Type26", "Soil_Type27", "Soil_Type28", "Soil_Type29", "Soil_Type30", "Soil_Type31", "Soil_Type32", "Soil_Type33", "Soil_Type34", "Soil_Type35", "Soil_Type36",
"Soil_Type37", "Soil_Type38", "Soil_Type39", "Soil_Type40"])
print(forest_data)
forest_target=pd.Series(train.iloc[:, -1])
print(forest_target)
#訓練データとテストデータの取得
from sklearn.model_selection import train_test_split
train_x, test_x, train_y, test_y = train_test_split(forest_data,
forest_target,
test_size=0.2,
shuffle=True)
#xgboost用の型に変換する
dtrain = xgb.DMatrix(train_x, label=train_y)
#パラメータの設定 max_depth:木の最大深度 eta:学習率 objective:学習目的 num_class:クラス数
param = {'max_depth': 2, 'eta': 1, 'objective': 'multi:softmax', 'num_class': 7}
#学習
num_round = 10
bst = xgb.train(param, dtrain, num_round)
#予測
dtest = xgb.DMatrix(test_x)
pred = bst.predict(dtest)
#精度の確認
from sklearn.metrics import accuracy_score
score = accuracy_score(test_y, pred)
print('score:{0:.4f}'.format(score))
#重要度の可視化
xgb.plot_importance(bst)
plt.show()
自分で試したこと
エラーコードを調べてみたところクラスの数が間違っているのではないかとあったため確認したのですが、
やはりクラスは7種類のようです。
解決方法をご教授いただけますでしょうか。
追記:
クラス数を8にしたら最後まで動きました。
7つの分類なのになぜなんでしょうか・・・・
0