3
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

scikit-learnによる機械学習実装 ~ digits編 ~

Posted at

初めに

内容

scikit-learnを用いた機械学習を行います。
データセットはscikit-learnから提供されるデータセットを使用します。

本記事では「Digits」(数字の手書き文字)を取り扱います。

irisデータセットの記事はこちらです
bostonデータセットの記事はこちらです
diabetsデータセットの記事はこちらです
その他のデータセットは別の記事にて更新しようと思います。

対象

主に初心者向けです。が、入門向けではないです。

本記事は友人に向けて作成してますが、
「機械学習したい」と思っている方にも読んでもらえるように書いています。

実装に関しての解説は記載しますが、アルゴリズムなどの解説は(基本的に)しませんのでご理解のほどを。

環境

  • Google Colaboratory
  • Python 3.6.9
  • sklearn 0.22.2
  • pandas 1.1.5
  • matplotlib 3.2.2

ローカルで環境をそろえるのは難しいので(Dockerやクラウドを使えばできますが)今回はColaboratoryを用います。
上記のバージョンに合わせるとローカル環境でも実行できます。


以降は実際にpythonを用いながらになるのでColaboratoryの用意をお願いします。

Digitsデータセットについて

0~9の手書き数字が書かれた画像データセットです。
(正確には画像ではなく数値の配列)

一つの手書き数字は、1チャネルで表されます。(モノクロ画像)

一つの手書き数字は、8×8のサイズで構成され、1次元の配列に格納されています。
(横向きにスライスした画像を一列につなげるイメージで、64個の数値 の配列から構成される)

言葉だと少しわかりにくいので、可視化の項目で実際に見てみましょう。

説明変数

今回はpandas.DataFrameを使用しません。

from sklearn.datasets import load_digits
digits = load_digits()
data = digits.data
print("全体 :", data.shape) # => 全体 : (1797, 64)
print("1行目(1つ目のデータ) :", data[0].shape) # => 1行目(1つ目のデータ) : (64,)

レコード数は1797件ですね。
説明変数は64。
すなわち、1ピクセルが一つの説明変数になります。

可視化してみる

まず、一つ目のデータをそのまま見てみましょう

print(data[0])

image.png
1次元のベクトルになっているのがはっきりわかりますね。
可視化してみます。

import matplotlib.pyplot as plt
# plt.imshow(digits.images[0], interpolation='nearest')
plt.imshow(digits.images[0], cmap=plt.cm.gray_r, interpolation='nearest')
# plt.imshow(digits.images[0], cmap='plasma', interpolation='bicubic')
plt.show()

8×8ピクセルなので見ずらいですね...
image.png
コメントアウトされている者もぜひ見てください。見やすくなりますよ!

目的変数

形を確認します。

target = digits.target
target.shape # => (1797,)

こちらは説明変数の一つのレコードが、何の数字を表しているか、というデータですね。
つまり今回は分類問題になります。

こんな感じのが格納されてます。(全部見たいならfor文とかで...)

print(target) # => [0 1 2 ... 8 9 8]

機械学習する

前処理

データセットの分割をします。おなじみの処理ですね。

from sklearn.model_selection import train_test_split
x_train, x_test, y_train, y_test = train_test_split(data, target, test_size=0.2, random_state=123)
print('x_train :', x_train.shape) # => x_train : (1437, 64)
print('y_train :', y_train.shape) # => y_train : (1437,)
print('x_test :', x_test.shape) # => x_test : (360, 64)
print('y_test :', y_test.shape) # => y_test : (360,)

学習

今回はk近傍法を使用して分類します。
sklearn.neighborsのKNeighborsClassifierをインポートしましょう。
10クラスに分類したいのでハイパーパラメータとして n_neighbors=10 を指定しましょう。

from sklearn.neighbors import KNeighborsClassifier

model = KNeighborsClassifier(n_neighbors=10)
model.fit(x_train, y_train)

推論

from sklearn.metrics import accuracy_score, recall_score, precision_score

pred = model.predict(x_test)
accuracy_score(y_test, pred) # => 0.9861111111111112

98%を超えてるのでよしとしましょう。

CNNを利用すると簡単に100%の精度が出ると思います。
ぜひこちらの記事を...
頑張ればscikit-learnでも100%が出るので暇な方はチャレンジを。(時間の無駄だと思います)

ちなみに、predの中身は次のようになります。

image.png
「64の説明変数から予測した、0~9のうち1つの目的変数」が入力したデータ分返っていますね。

最後に

データは予測して終わりではなく、「予測したデータを利用して何かを作る」ことのほうが重要です。
ぜひ作業時間を作って、アプリケーションにしてみてください。(colaboratory上で動く関数程度でもいいです)
将来の糧になるはずですので。

3
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?