Pythonの機械学習の本を読んだのでちゃんとメモするぞ。
準備
まず、Pythonはよく分からないのでRのreticulate
パッケージを使うことにする。
reticulate
を使うとRからPythonが使用できる。なお、venv
を使用している場合はuse_viertualenv
ではなくuse_python
を使うようだ。ちなみにPythonは3.6.4を入れてある。以下で必要なモジュールは予めインストールしてある。
library(reticulate)
use_python("~/myenv/bin/python")
データを読む
iris
をsklearn.datasets
のload_iris
で読み込む。load_iris
が返すオブジェクトはBunch
クラス、概ね辞書みたいなものらしい。
sk_datasets <- import("sklearn.datasets")
iris_dataset <- sk_datasets$load_iris()
class(iris_dataset)
$> [1] "sklearn.utils.Bunch" "python.builtin.dict" "python.builtin.object"
キーを確認しよう。
names(iris_dataset)
$> [1] "data" "target" "target_names" "DESCR"
$> [5] "feature_names"
DESCR
にデータセットの解説が入っている。
cat(iris_dataset$DESCR)
$> Iris Plants Database
$> ====================
$>
$> Notes
$> -----
$> Data Set Characteristics:
$> :Number of Instances: 150 (50 in each of three classes)
$> :Number of Attributes: 4 numeric, predictive attributes and the class
$> :Attribute Information:
$> - sepal length in cm
$> - sepal width in cm
$> - petal length in cm
$> - petal width in cm
$> - class:
$> - Iris-Setosa
$> - Iris-Versicolour
$> - Iris-Virginica
$> :Summary Statistics:
$>
$> ============== ==== ==== ======= ===== ====================
$> Min Max Mean SD Class Correlation
$> ============== ==== ==== ======= ===== ====================
$> sepal length: 4.3 7.9 5.84 0.83 0.7826
$> sepal width: 2.0 4.4 3.05 0.43 -0.4194
$> petal length: 1.0 6.9 3.76 1.76 0.9490 (high!)
$> petal width: 0.1 2.5 1.20 0.76 0.9565 (high!)
$> ============== ==== ==== ======= ===== ====================
$>
$> :Missing Attribute Values: None
$> :Class Distribution: 33.3% for each of 3 classes.
$> :Creator: R.A. Fisher
$> :Donor: Michael Marshall (MARSHALL%PLU@io.arc.nasa.gov)
$> :Date: July, 1988
$>
$> This is a copy of UCI ML iris datasets.
$> http://archive.ics.uci.edu/ml/datasets/Iris
$>
$> The famous Iris database, first used by Sir R.A Fisher
$>
$> This is perhaps the best known database to be found in the
$> pattern recognition literature. Fisher's paper is a classic in the field and
$> is referenced frequently to this day. (See Duda & Hart, for example.) The
$> data set contains 3 classes of 50 instances each, where each class refers to a
$> type of iris plant. One class is linearly separable from the other 2; the
$> latter are NOT linearly separable from each other.
$>
$> References
$> ----------
$> - Fisher,R.A. "The use of multiple measurements in taxonomic problems"
$> Annual Eugenics, 7, Part II, 179-188 (1936); also in "Contributions to
$> Mathematical Statistics" (John Wiley, NY, 1950).
$> - Duda,R.O., & Hart,P.E. (1973) Pattern Classification and Scene Analysis.
$> (Q327.D83) John Wiley & Sons. ISBN 0-471-22361-1. See page 218.
$> - Dasarathy, B.V. (1980) "Nosing Around the Neighborhood: A New System
$> Structure and Classification Rule for Recognition in Partially Exposed
$> Environments". IEEE Transactions on Pattern Analysis and Machine
$> Intelligence, Vol. PAMI-2, No. 1, 67-71.
$> - Gates, G.W. (1972) "The Reduced Nearest Neighbor Rule". IEEE Transactions
$> on Information Theory, May 1972, 431-433.
$> - See also: 1988 MLC Proceedings, 54-64. Cheeseman et al"s AUTOCLASS II
$> conceptual clustering system finds 3 classes in the data.
$> - Many, many more ...
target_names
には目的変数のクラス名が入っている。
iris_dataset$target_names
$> [1] "setosa" "versicolor" "virginica"
feature_names
には特徴量の説明。
iris_dataset$feature_names
$> [1] "sepal length (cm)" "sepal width (cm)" "petal length (cm)"
$> [4] "petal width (cm)"
データの本体はtarget
とdata
に配列として格納されている。
iris_dataset$target
$> [1] 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
$> [36] 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
$> [71] 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
$> [106] 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
$> [141] 2 2 2 2 2 2 2 2 2 2
head(iris_dataset$data)
$> [,1] [,2] [,3] [,4]
$> [1,] 5.1 3.5 1.4 0.2
$> [2,] 4.9 3.0 1.4 0.2
$> [3,] 4.7 3.2 1.3 0.2
$> [4,] 4.6 3.1 1.5 0.2
$> [5,] 5.0 3.6 1.4 0.2
$> [6,] 5.4 3.9 1.7 0.4
これらの数字や、数字の列がが具体的にどれに対応しているのかを示しているのがtarget_names
とfeature_names
という訳だ。
訓練データとテストデータ
データを分割しよう。分割にはtrain_test_split
関数を使う。擬似乱数列の種を固定するために、random_state
に0を指定する。明示的に整数型を渡すのを忘れないようにしよう。
sk_model_selection <- import("sklearn.model_selection")
tmp <- sk_model_selection$train_test_split(
iris_dataset$data,
iris_dataset$target,
random_state = 0L
)
X_train <- tmp[[1]]
X_test <- tmp[[2]]
y_train <- tmp[[3]]
y_test <- tmp[[4]]
X
は入力、y
は出力を、_train
は訓練データを、_test
はテストデータを表している。入力データの形状だけ確認しておこう。
dim(X_train)
$> [1] 112 4
dim(X_test)
$> [1] 38 4
train_test_split
は標準で全データの25%をテストデータに割り振る。
データを良く観察する
pandas
のscatter_matrix
関数で散布図行列を作成する。(このとき、knitrで.Rmdから.mdを生成しようとしたらエラーが出た。~/.matplotlib/matplotlibrc
を編集して、backend : TkAgg
としたらエラーは出なくなった。ただしプロットは埋め込まれたりしない。参考: Python 3.3でmatplitlibとpylabを使おうとしたら RuntimeError: Python is not installed as a frameworkというエラーが発生したときの解決方法 - Qiita)
pd <- import("pandas")
matplotlib <- import("matplotlib")
iris_dataframe <- pd$DataFrame(X_train, columns=iris_dataset$feature_names)
grr <- pd$scatter_matrix(iris_dataframe, c = y_train, marker = 'o', figsize = c(10, 10),
hist_kwds=list('bins' = 20L))
matplotlib$pyplot$show()
k-最近傍法をやる
まず、neighboars
モジュールのKNeighboarsClassifier
クラスのインスタンスを作成する。
sklearn.neighbors <- import("sklearn.neighbors")
# 整数型のハイパーパラメータはきちんと整数型を明示して指定しよう
knn = sklearn.neighbors$KNeighborsClassifier(n_neighbors = 1L)
訓練セットからモデルを構築する。
knn$fit(X_train, y_train)
$> KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',
$> metric_params=None, n_jobs=1, n_neighbors=1, p=2,
$> weights='uniform')
予測を行う
予測はknn$predict
で行う。np$array
をそのまま与えたらarray.reshape
で直せと怒られたので(観測値が1つだから?)少し加工してある。
np <- import("numpy")
X_new <- np$array(c(5.0, 2.9, 1, 0.2))
X_new <- array_reshape(X_new, c(1, -1))
prediction <- knn$predict(X_new)
prediction
$> [1] 0
iris_dataset$target_names[prediction+1] # Rは1オリジンなので気をつけよう
$> [1] "setosa"
ただ、これでは"setosa"が正しいのかどうか分からない。
モデルの評価
テストデータを使ってテストをしよう。
y_pred <- knn$predict(X_test)
y_pred
$> [1] 2 1 0 2 0 2 0 1 1 1 2 1 1 1 1 0 1 1 0 0 2 1 0 0 2 0 0 1 1 0 2 1 0 2 2
$> [36] 1 0 2
これを正解のクラスであるy_test
と比較すればいい。
mean(y_test == y_pred)
$> [1] 0.9736842
あるいは、knn$score
を使ってもいい。これにはテストデータの入力と出力を与える。
knn$score(X_test, y_test)
$> [1] 0.9736842
感想
ネタでやってみたら思ったよりずっと普通にやれてしまった…reticulate
すごい…
注意点は整数は整数として書くということ、添字がRは1オリジンでPythonは0オリジンというあたり。