More than 3 years have passed since last update.

A super-basic introduction to data science with the iris dataset

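#Motivation
Before tackling document classification in natural language processing, I wanted to try an ordinary classification task first.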

#Setup
I use the well-known __iris__ dataset.
I split it with sklearn into training, validation, and test sets, then measure accuracy.
The model is the familiar __lightgbm__.
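
To make the split concrete, here is a minimal sketch of how two chained `train_test_split` calls divide 150 rows. It assumes the same `test_size` and `random_state` values as the code in this post; `rows` is only a stand-in for the real data, and the variable names are made up for illustration.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Stand-in for the 150 iris rows; only the row count matters here.
rows = np.arange(150)

# First split: 70% for training, 30% held out.
train_rows, holdout_rows = train_test_split(rows, test_size=0.3, random_state=123)
# Second split: halve the held-out rows into validation and test.
eval_rows, test_rows = train_test_split(holdout_rows, test_size=0.5, random_state=123)

split_sizes = (len(train_rows), len(eval_rows), len(test_rows))
print(split_sizes)
```

The result is roughly a 70/15/15 split. Each row lands in exactly one of the three sets, so the test rows are never seen during training or early stopping.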

#Code

import numpy as np
from pandas import DataFrame
import pandas as pd
import lightgbm as lgb
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

iris = load_iris() #Load the iris dataset

data = iris["data"]
target = iris["target"].reshape(-1, 1) #Column vector so it can be concatenated with data
dataset = np.concatenate([data, target], axis=1)
features = iris['feature_names'] + ["target"] #Build a new list so iris['feature_names'] is not mutated
df = DataFrame(dataset, columns=features)

#Check the contents
print(df)

#Split the data into training, validation, and test sets (roughly 70/15/15)
train_set, test_set = train_test_split(df, test_size = 0.3, random_state = 123)
eval_set, test_set = train_test_split(test_set, test_size = 0.5, random_state = 123)

#Separate the features (explanatory variables) from the target variable
x_train, y_train = train_set.drop("target", axis=1), train_set["target"]
x_eval, y_eval = eval_set.drop("target", axis=1), eval_set["target"]
x_test, y_test = test_set.drop("target", axis=1), test_set["target"]


#Set the lightgbm parameters
params = {
    'task': 'train', #Train a model
    'num_class': 3, #The iris target has three classes
    'boosting_type': 'gbdt',
    'objective': 'multiclass', #Multiclass classification
    'metric': 'multi_logloss', #Multiclass log loss
    'verbose': -1
}

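#Wrap the inputs in lightgbm Dataset objects
lgb_train = lgb.Dataset(x_train, label=y_train)
lgb_eval = lgb.Dataset(x_eval, label=y_eval, reference=lgb_train) #Validation data shares the training set's binning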

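#Train
model = lgb.train(params=params, #Parameters
                train_set=lgb_train, #Training data
                num_boost_round=100, #Maximum number of boosting rounds
                valid_sets=[lgb_eval], #Validation data
                callbacks=[lgb.early_stopping(stopping_rounds=10)]) #Early stopping (the early_stopping_rounds argument was removed in lightgbm 4.0)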

#Predict on the test data and check accuracy
y_pred = model.predict(x_test, num_iteration=model.best_iteration)
y_pred = np.argmax(y_pred, axis=1) #Pick the class with the highest predicted probability

#Plain accuracy
print(accuracy_score(y_test, y_pred))
#Recall, precision, etc.
print(classification_report(y_test, y_pred))
#Confusion matrix
print(confusion_matrix(y_test, y_pred))
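
For the multiclass objective, `model.predict` returns one probability per class for each sample, which is why the code above takes an `argmax` before scoring. The sketch below walks through that step by hand with NumPy only. The probabilities and labels are made up for illustration and are not output from the model in this post.

```python
import numpy as np

# Made-up class probabilities: one row per sample, one column per class.
probs = np.array([
    [0.8, 0.1, 0.1],
    [0.2, 0.3, 0.5],
    [0.1, 0.7, 0.2],
])
true_labels = np.array([0, 2, 2])  # made-up ground truth

# The predicted class is the column with the highest probability in each row.
pred_labels = np.argmax(probs, axis=1)

# Accuracy is the fraction of samples whose predicted class matches the truth.
accuracy = np.mean(pred_labels == true_labels)

# Confusion matrix: rows are true classes, columns are predicted classes.
n_classes = probs.shape[1]
confusion = np.zeros((n_classes, n_classes), dtype=int)
for t, p in zip(true_labels, pred_labels):
    confusion[t, p] += 1

print(pred_labels, accuracy)
print(confusion)
```

The diagonal of the confusion matrix counts correct predictions, so dividing its sum by the number of samples gives the same accuracy.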

#References

http://wordroid.sblo.jp/article/180966841.html
