LoginSignup
4
2

More than 3 years have passed since last update.

【Python】みんな大好きscikit-learnを使って手書き数字認識

Last updated at Posted at 2020-12-08

はじめに

こんにちは!!
最近やりたいことが多すぎて困っているヨシキです。今回は2回目となるPython一体何が出来んねんの記事になります!(早くも2回目笑)
機械学習というホットな話題にも触れられる内容となってますので、是非最後まで読んで一緒に実装していただけると飛んで喜びます。

事前知識

今回扱う内容は実践的内容ですので事前にscikit-learnなどに関して知りたいという方はこちらのサイトを参考にしてみてください。とても分かりやすかったです!!

環境

今回は簡単のためGoogleColaboratoryを使ってやっていきます。
GoogleColaboratoryをまだ使ったことがないよ!って人はこちらの私の記事、もしくはこちらの記事を参考にしてみてください。

実装

では!早速やっていきましょう!
まずは、諸々のインポート、Colabの設定と手書き数字データセットのロードをしていきます。
ここでGoogle driveにアクセスしてマイドライブ内に好きな名前でフォルダを1つ作っておきましょう!下記コードのDigitRecognitionの部分を作っていただいたフォルダ名と対応させてください。

import.py
from sklearn import datasets
import matplotlib.pyplot as plt
import numpy as np
from sklearn import decomposition
from sklearn import svm, metrics
import sklearn.svm as svm
from sklearn.metrics import confusion_matrix

from google import colab
colab.drive.mount('/content/gdrive')

#Directory setting
b_dir='gdrive/My Drive/DigitRecognition/' 

#digitdigitsというデータセットをロードする
digits=datasets.load_digits()

次にどんなデータを扱うのかを見てみましょう。

digits.py
ROWS_COUNT=4
COLUMNS_COUNT=5

DIGIT_GRAPH_COUNT=ROWS_COUNT*COLUMNS_COUNT

subfig=[]

x=np.linspace(-1,1,10)
fig=plt.figure(figsize=(12,9))

for i in range (DIGIT_GRAPH_COUNT):
  subfig.append(fig.add_subplot(ROWS_COUNT, COLUMNS_COUNT,i+1))
  subfig[i].imshow(digits.images[i],interpolation='nearest',cmap='viridis')
  plt.axis('off')
fig.subplots_adjust(wspace=0.3, hspace=0.3)
plt.show()
#手書き数字のデータをロードして特徴、ラベルを抽出
all_features=digits.data
gt_labels=digits.target

以下のような出力があればOKです。
image.png

次に3次元空間にデータの描画を行っていきたいと思います。
この際に、数字ごとに色分け、主成分分析により8×8次元を3次元に落とす作業などを並行してやっていきたいと思います。

plot.py
def getcolor(color):
  if color==0:
    return 'red'
  elif color==1:
    return 'orange'
  elif color==2:
    return 'yellow'
  elif color==3:
    return 'greenyellow'
  elif color==4:
    return 'green'
  elif color==5:
    return 'cyan'
  elif color==6:
    return 'blue'
  elif color==7:
    return 'navy'
  elif color==8:
    return 'purple'
  else:
    return 'black'
#主成分分析により次元削減
pca=decomposition.PCA(n_components=3)

#主成分分析により64次元データを3次元データに変換
three_features=pca.fit_transform(all_features)

#figureオブジェクト作成サイズを決める
fig=plt.figure(figsize=(12,9))

subfig=fig.add_subplot(111,projection='3d')
#教師データに対応する色のリスト用意
colors=list(map(getcolor,gt_labels))

#三次元空間へのデータの色付き描画を行う
subfig.scatter(three_features[:,0],three_features[:,1], three_features[:,2],s=50, c=colors, alpha=0.3)
plt.show()
fig.savefig(b_dir+'3d.png')

実行して以下の出力が得られたらOKです。
image.png

なんかかっこいいですね笑

次はデータの前処理を行っていきます。
以下で設定しいるハイパーパラメータは、私がこれが最良だろうと思ったものですので、よりよいものが見つかったよ!という人がいましたら教えてください笑

preprocessing.py
num_samples=len(digits.images)
data=digits.images.reshape((num_samples,-1))

#ハイパーパラメータ
h_para=3
#データの分類
denom=2 

#学習用データと教師データ 
train_features=data[:num_samples//denom] 
train_gt_labels=digits.target[:num_samples//denom]

#検証用データと教師データ
test_features=data[num_samples//2:]
test_gt_labels=digits.target[num_samples//2:]

次で最後です!いよいよsvmで学習させて結果を出力させてみます。

result.py
#modelの定義
model=svm.SVC(C=h_para,kernel='rbf' )

#SVMの学習
model.fit(train_features,train_gt_labels)
#学習結果の確認
expected=test_gt_labels

predicted=model.predict(test_features)

cm = confusion_matrix(expected,predicted)

print("Confusion matrix: \n %s"% cm)

csv_conf=b_dir+'confusion_trainSize'+str(denom)+'.csv'
np.savetxt(csv_conf,cm)
print(np.sum(np.diag(cm))/np.sum(cm))

以下のような出力が得られましたでしょうか!?
結果は96.8%とまずまずの結果なのではないでしょうか!
是非ハイパーパラメータを変えて実行して、違いを体感してみてください!
ここで「おいお前!Confusion matrixってなんや!」となったそこのあなた!安心してください。下の考察パートでしっかり解説させていただきます。

result.txt
Confusion matrix: 
 [[87  0  0  0  1  0  0  0  0  0]
 [ 0 89  0  0  0  0  0  0  1  1]
 [ 1  0 84  1  0  0  0  0  0  0]
 [ 0  0  0 82  0  3  0  1  5  0]
 [ 0  0  0  0 88  0  0  0  0  4]
 [ 0  0  0  0  0 88  1  0  0  2]
 [ 0  1  0  0  0  0 90  0  0  0]
 [ 0  0  0  0  0  1  0 88  0  0]
 [ 0  0  0  0  0  1  0  1 86  0]
 [ 0  0  0  1  0  1  0  1  0 89]]
0.9688542825361512

考察

混同行列(Confusion matrix)とは

混同行列とは、左上を原点としてx軸を本当の正解、y軸をモデルの推定結果と定めたとき(逆でも可)に出力された値を行列化したものを言います。なので今回の混同行列で説明すると、上から斜めに87,89とありますが、それは0を0と、1を1と認識した個数となります。間違っているものに注目すると2行目の9列目10列目なんかは1を8と認識してしまった個数が1個、1を9と認識してしまった個数が1個と行列から簡単に見て取ることが出来ます。このような点が一目でわかるので混同行列は優秀ですね!

カーネルに関して

result.pyのモデルの定義をしているところでkernelという言葉が出てきています。これはカーネル関数といって、様々な種類があり、代表的なところで言うと線形カーネル(linear)、多項式カーネル(poly)、RBFカーネル(rbf),シグモイドカーネル(sigmoid)というものがあります。データの特徴などによって変えていくものなのですが、基本的にRBFカーネルが万能という所感です。
なので比較してみるのが一番いいですが、困ったらRBFカーネルを使いましょう。

おわりに

ここまで読んでくださった方ありがとうございました!
カーネルなど考えていくと、とても深く難しい概念が多い分野ですが、まずは自分の手で実装してみて肌で感じるというのが良いと思います!
ここら辺の座学は難しいですからね!考えていくのも面白いんですけどね笑
最後に、なにか至らぬ点や疑問点、ミスがありましたらコメントください。これからも機械学習などの記事をたくさん書いていくつもりなので良かったらフォローお願いします!

参考文献

GoogleColaboratory導入方法①
GoogleColaboratory導入方法②
scikit-learnに関して
svmに関して

4
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
2