47
71

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

scikit-learnを用いた機械学習入門 -データの取得からパラメータ最適化まで

Last updated at Posted at 2016-12-30

概要

ここでは**「はがきに記入された手書きの郵便番号を自動で識別してみよう」**という例題を考えてみます。

初心者向けの記事です。基本的には、scikit-learnのチュートリアルやDocumentをまとめた内容ですが、それ以外の内容も含みます。
データセットはdigitsを、機械学習の手法はSVM(正確にはSVC)を用いることにします。

  • digits : 手書きの数字文字画像データセット
  • SVC : サポートベクターマシンの一種

データセット:digits

digitsは数字ラベルと数字画像データが組となったデータセットです。このラベルと画像の組を、後ほど学習していくわけですね。scikit-learnによりあらかじめデータが用意されているため、誰でも簡単に試してみることができます。

データを読み込む

datasets.load_digits()でデータセットdigitsが読み込めます。

from sklearn import datasets
from matplotlib import pyplot as plt
# from sklearn import datasets

digits = datasets.load_digits()

データの中身を見る

各画像は0~9までの手書き文字画像です。これらの画像は、プログラム上では0~255の値をもつ二次元配列で表されます。

# 画像の配列データ
print(digits.data)
[[  0.   0.   5. ...,   0.   0.   0.]
 [  0.   0.   0. ...,  10.   0.   0.]
 [  0.   0.   0. ...,  16.   9.   0.]
 ..., 
 [  0.   0.   1. ...,   6.   0.   0.]
 [  0.   0.   2. ...,  12.   0.   0.]
 [  0.   0.  10. ...,  12.   1.   0.]]

画像データを配列のまま眺めてもわかりにくいので、画像として表示させてみたいですね。

画像を表示させる前に、まずは、ラベルデータを確認してみます。
下のように、あらかじめ0~9のラベルが正しく付与されています。

# ラベル
print(digits.target)
[0 1 2 ..., 8 9 8]

上の結果を見てみると、例えば先頭から0, 1, 2番目の画像にはラベル0, 1, 2が付与されていますし、後ろから2番目の画像にはラベル9が付与されています。これらの画像を表示するのには、matplotlibが使えます。

#  画像の表示
# number 0
plt.subplot(141), plt.imshow(digits.images[0], cmap = 'gray')
plt.title('number 0'), plt.xticks([]), plt.yticks([])

# number 1
plt.subplot(142), plt.imshow(digits.images[1], cmap = 'gray')
plt.title('numbert 1'), plt.xticks([]), plt.yticks([])

# number 2
plt.subplot(143), plt.imshow(digits.images[2], cmap = 'gray')
plt.title('numbert 2'), plt.xticks([]), plt.yticks([])

# number 9
plt.subplot(144), plt.imshow(digits.images[-2], cmap = 'gray')
plt.title('numbert 9'), plt.xticks([]), plt.yticks([])

plt.show()

output_7_0.png

このように、各画像に正しいラベルが付与されているらしいことが確認できます。

SVMによる画像分類

SVMとは

**SVM(Support Vector Machine)**は、非常に優れた認識性能を持つ教師あり学習手法のひとつです。基本的には、2クラス分類をマージン最大化を基準として行います。もちろん、(2クラス分類を複数回行うことで)多クラス分類にも適用できます。

厳密なSVMでは、分類するデータに重なりがある場合(つまり、すべてのデータを完全に切りわけることができない場合)は、きちんとした分類境界を求めることができません。一方、エラーを許容するSVMはソフトマージンSVMと呼ばれます。誤分類にペナルティCを与えることで、完全に分けられないデータの場合でも、できるだけ誤分類を少なくするような分類境界を引くことができます。

ペナルティCが大きいほど誤りに厳しくなると同時に、**過学習(Overfitting)**を起こしやすくなるので注意が必要です。

(注)過学習とは、学習モデルが訓練データの特定のランダムな(本来学習させたい特徴とは無関係な)特徴にまで適合してしまうことです。過学習が起こった場合、訓練データについての性能は向上しますが、それ以外のデータでは逆に結果が悪くなります。(参考:過剰適合 - Wikipedia)

scikit-learnでのSVM

実は、scikit-learnではSVC、NuSVC、LinearSVCといったやや種類の異なるSVMが存在しています。NuSVCとSVCは良く似た手法ですが、わずかに異なるパラメータセットを持ち、数学的には異なる定式化で表されます。LinearSVCは線形カーネルを用いたSVMであり、そのほかのカーネルを指定することはできません。

今回はSVCを用い、ソフトマージンを適用することにします。やることは、(1)識別器を作成し、(2)データに適用するだけです。

(1)識別器の作成

  • 最後の10件以外で学習モデルを作成
from sklearn import svm

# SVM
clf = svm.SVC(gamma=0.001, C=100.)
clf.fit(digits.data[:-10], digits.target[:-10])
SVC(C=100.0, cache_size=200, class_weight=None, coef0=0.0,
  decision_function_shape=None, degree=3, gamma=0.001, kernel='rbf',
  max_iter=-1, probability=False, random_state=None, shrinking=True,
  tol=0.001, verbose=False)

ここではほとんど指定していませんが、上のようにSVCにはかなり多くのパラメータ(C, cache_size, class_weight, coef0, ...)があることがわかります。
最初はあまり気にせず、デフォルトの設定で大丈夫です。

(2)識別器による画像分類

  • 学習モデルからテストデータの最後の10件を推定

作成した識別器を用いて、実際に画像からラベルを推定します。学習モデルの作成に用いていない、最後の10件のデータで試してみます。

clf.predict(digits.data[-10:])
array([5, 4, 8, 8, 4, 9, 0, 8, 9, 8])

実際のデータを見てみると、

print(digits.target[-10:])
[5 4 8 8 4 9 0 8 9 8]

であり、推定結果と一致しています。

これで、学習したモデルからおおよそ正しい推定が可能であることが確かめられました。パラメータをいろいろと変えて試してみましょう。

識別器の精度評価

識別器の評価指標

識別器の分類精度の評価指標はいくつかありますが、基本的には以下のような指標によって測ることができます。

  • 正解率(Accuracy)
    • 予測のうち、正しく分類されたものの割合
  • 適合率(Precision)
    • 正と予測したデータのうち,実際に正であるものの割合
  • 再現率(Recall)
    • 実際に正であるもののうち,正であると予測されたものの割合
  • F値(F-measure)
    • 適合率と再現率の調和平均

通常は識別器の精度はF値で評価することが多いです。
しかし、実用上は適合率と再現率のどちらを重視するかが異なっていることが多々あります。

適合率と再現率

例えば、工場での部品検査を考えてみましょう。どこも壊れていない部品を間違って「壊れている(エラー)」に分類してしまっても大した問題にはなりません。ところが、壊れている部品を間違って「壊れていない(正)」と分類してしまうと、クレームやリコールの原因になったり、商品によっては人命にかかわることさえあります。
このような場合は、再現率よりも適合率の方が重視されることになります。例えば、「適合率99%+再現率70%」と「適合率80%+再現率99%」では、後者の方がF値は高いですが、実用性は圧倒的に前者が高いという場合があり得るのです。
一方で、データベースの検索などの場合は、適合率よりも再現率が重視されることが多いでしょう。多少間違った検索結果が多く出ても、検索しても見つからないデータが多いよりははるかにましだからです。

パラメータ最適化

これまで、パラメータはなんとなくで適当な値を設定していました。 しかし、これでは要求される分類精度が得られないことも多く、実際にはパラメータの最適化が必須となります。では、識別器の分類精度を高めるためにどのパラメータをどのように設定すればよいのでしょうか。
ひとつずつパラメータを手でチューニングしていっても良いですが、これは非常に大変です。データセットや手法によっては慣習的にこれくらいの値が良い、などの知見が存在する場合もあるようですが、未知のデータセットの場合には使えません。
そこで、よく用いられるのが、グリッドサーチという方法です。簡単に説明すれば、探索範囲のなかでパラメータを変化させながら実際にモデルを学習し、その結果の精度が最も良いパラメータを探すという方法です。
また、求めたパラメータでの学習モデルが過学習を起こしていないかを確認するためには、交差検証法(Cross-validation)を用います。k-交差検証法では、まず、データをk個に分割します。そのうちのk-1個で学習、残った1個で評価するというのを(訓練データ・テストデータを変えながら)k回繰り返し、その平均値で学習モデルを評価するという方法です。このようにすることで、学習モデルの汎化性能を評価することができます。

(注)汎化性能が良いとは、簡単に言えば未知データに対しても学習モデルが適切な識別を行える能力のことです。過学習を起こしている場合、訓練データに対しては高い精度で識別できるのですが、未知データに対しては識別の精度が低下することを思い出してください。

scikit-learnでは、GridSearchCV()を用いてグリッドサーチと交差検証法を簡単に行うことができます。例えば、以下のようなパラメータを指定できます。

  • scoring
    • パラメータ最適化における評価値。今回は'precision'と'recall'を指定
  • cv
    • 交差検証の分割数。10程度のことが多い。計算量が大きくなりすぎる場合や、逆にデータ数が少なすぎる場合は分割数を少なく指定

準備

パラメータ最適化を行う前に、読み込んだデータの形式を変換しておきます。

from sklearn import datasets
from sklearn.cross_validation import train_test_split
from sklearn.grid_search import GridSearchCV
from sklearn.metrics import classification_report
from sklearn.svm import SVC

# Digitsデータセットのロード
digits = datasets.load_digits()
print(len(digits.images))
print(digits.images.shape)
1797
(1797, 8, 8)
# To apply an classifier on this data, we need to flatten the image, to
# turn the data in a (samples, feature) matrix:
n_samples = len(digits.images)
X = digits.images.reshape((n_samples, -1))  # reshape(cols, rows)でcols行rows列に変換(引数の一方が-1だと自動計算)
y = digits.target
print(X.shape)
print(y)
(1797, 64)
[0 1 2 ..., 8 9 8]

グリッドサーチと交差検証法

下記のコードは一見難しそうですが、実際にやっていることは単純です。

  • kernel = "rbf", gamma = 0.001 or 0.0001, C = 1 or 10 or 100 or 1000
  • kernel = "linear", C = 1 or 10 or 100 or 1000

上のような場合わけによる組み合わせをすべて試して、それぞれのprecisionとrecallが最大となるパラメータ(best_params_)を求めているだけですね。(なお、gammaはkernelがrbfの場合のパラメータなので、kernelがlinearの場合は無関係です)

後はグリッドサーチの結果を細かく表示させたり、classification_report()で結果の詳細なレポートを表示させたりしています。

# データセットを訓練用データとテスト用データに分割
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.5, random_state=0)

# クロスバリデーションで最適化したいパラメータをセット
tuned_parameters = [{'kernel': ['rbf'], 'gamma': [1e-3, 1e-4],
                     'C': [1, 10, 100, 1000]},
                    {'kernel': ['linear'], 'C': [1, 10, 100, 1000]}]

scores = ['precision', 'recall']

for score in scores:
    print("# Tuning hyper-parameters for %s" % score)
    print()

    # グリッドサーチと交差検証法
    clf = GridSearchCV(SVC(C=1), tuned_parameters, cv=5,
                       scoring='%s_weighted' % score)
    clf.fit(X_train, y_train)

    print("Best parameters set found on development set:")
    print()
    print(clf.best_params_)
    print()
    print("Grid scores on development set:")
    print()
    for params, mean_score, scores in clf.grid_scores_:
        print("%0.3f (+/-%0.03f) for %r"
              % (mean_score, scores.std() * 2, params))
    print()
    print("Detailed classification report:")
    print()
    print("The model is trained on the full development set.")
    print("The scores are computed on the full evaluation set.")
    print()
    y_true, y_pred = y_test, clf.predict(X_test)
    print(classification_report(y_true, y_pred))
    print()  
# Tuning hyper-parameters for precision

Best parameters set found on development set:

{'gamma': 0.001, 'kernel': 'rbf', 'C': 10}

Grid scores on development set:

0.987 (+/-0.018) for {'gamma': 0.001, 'kernel': 'rbf', 'C': 1}
0.959 (+/-0.030) for {'gamma': 0.0001, 'kernel': 'rbf', 'C': 1}
0.988 (+/-0.018) for {'gamma': 0.001, 'kernel': 'rbf', 'C': 10}
0.982 (+/-0.027) for {'gamma': 0.0001, 'kernel': 'rbf', 'C': 10}
0.988 (+/-0.018) for {'gamma': 0.001, 'kernel': 'rbf', 'C': 100}
0.982 (+/-0.026) for {'gamma': 0.0001, 'kernel': 'rbf', 'C': 100}
0.988 (+/-0.018) for {'gamma': 0.001, 'kernel': 'rbf', 'C': 1000}
0.982 (+/-0.026) for {'gamma': 0.0001, 'kernel': 'rbf', 'C': 1000}
0.974 (+/-0.014) for {'kernel': 'linear', 'C': 1}
0.974 (+/-0.014) for {'kernel': 'linear', 'C': 10}
0.974 (+/-0.014) for {'kernel': 'linear', 'C': 100}
0.974 (+/-0.014) for {'kernel': 'linear', 'C': 1000}

Detailed classification report:

The model is trained on the full development set.
The scores are computed on the full evaluation set.

             precision    recall  f1-score   support

          0       1.00      1.00      1.00        89
          1       0.97      1.00      0.98        90
          2       0.99      0.98      0.98        92
          3       1.00      0.99      0.99        93
          4       1.00      1.00      1.00        76
          5       0.99      0.98      0.99       108
          6       0.99      1.00      0.99        89
          7       0.99      1.00      0.99        78
          8       1.00      0.98      0.99        92
          9       0.99      0.99      0.99        92

avg / total       0.99      0.99      0.99       899


# Tuning hyper-parameters for recall

Best parameters set found on development set:

{'gamma': 0.001, 'kernel': 'rbf', 'C': 10}

Grid scores on development set:

0.986 (+/-0.021) for {'gamma': 0.001, 'kernel': 'rbf', 'C': 1}
0.958 (+/-0.029) for {'gamma': 0.0001, 'kernel': 'rbf', 'C': 1}
0.987 (+/-0.021) for {'gamma': 0.001, 'kernel': 'rbf', 'C': 10}
0.981 (+/-0.029) for {'gamma': 0.0001, 'kernel': 'rbf', 'C': 10}
0.987 (+/-0.021) for {'gamma': 0.001, 'kernel': 'rbf', 'C': 100}
0.981 (+/-0.027) for {'gamma': 0.0001, 'kernel': 'rbf', 'C': 100}
0.987 (+/-0.021) for {'gamma': 0.001, 'kernel': 'rbf', 'C': 1000}
0.981 (+/-0.027) for {'gamma': 0.0001, 'kernel': 'rbf', 'C': 1000}
0.973 (+/-0.015) for {'kernel': 'linear', 'C': 1}
0.973 (+/-0.015) for {'kernel': 'linear', 'C': 10}
0.973 (+/-0.015) for {'kernel': 'linear', 'C': 100}
0.973 (+/-0.015) for {'kernel': 'linear', 'C': 1000}

Detailed classification report:

The model is trained on the full development set.
The scores are computed on the full evaluation set.

             precision    recall  f1-score   support

          0       1.00      1.00      1.00        89
          1       0.97      1.00      0.98        90
          2       0.99      0.98      0.98        92
          3       1.00      0.99      0.99        93
          4       1.00      1.00      1.00        76
          5       0.99      0.98      0.99       108
          6       0.99      1.00      0.99        89
          7       0.99      1.00      0.99        78
          8       1.00      0.98      0.99        92
          9       0.99      0.99      0.99        92

avg / total       0.99      0.99      0.99       899

さて、print(clf.best_params_)の結果から、precision/recallのどちらの観点から見ても'gamma': 0.001, 'kernel': 'rbf', 'C': 10が最良であることがわかります。これでパラメータの最適化ができました。

必要なら、さらに別のkernelを使った場合やSVM以外の学習手法と比較した場合の最適化も試してみましょう。

参考

[1]An introduction to machine learning with scikit-learn — scikit-learn 0.18.1 documentation
http://scikit-learn.org/stable/tutorial/basic/tutorial.html#introduction
[2]Parameter estimation using grid search with cross-validation — scikit-learn 0.18.1 documentation
http://scikit-learn.org/stable/auto_examples/model_selection/grid_search_digits.html#example-model-selection-grid-search-digits-py
[3]1.4. Support Vector Machines — scikit-learn 0.18.1 documentation
http://scikit-learn.org/stable/modules/svm.html
[4]F値 - 機械学習の「朱鷺の杜Wiki」
http://ibisforest.org/index.php?F%E5%80%A4
[5]SVMを使いこなす!チェックポイント8つ - Qiita
http://qiita.com/pika_shi/items/5e59bcf69e85fdd9edb2
[6]Scikit learnよりグリッドサーチによるパラメータ最適化
http://qiita.com/SE96UoC5AfUt7uY/items/c81f7cea72a44a7bfd3a
[7]機械学習のためのベイズ最適化入門|Tech Book Zone Manatee
https://book.mynavi.jp/manatee/detail/id=59393
[8]3.3. Model evaluation: quantifying the quality of predictions — scikit-learn 0.18.1 documentation
http://scikit-learn.org/stable/modules/model_evaluation.html#scoring-parameter
[9]過剰適合 - Wikipedia
https://ja.wikipedia.org/wiki/%E9%81%8E%E5%89%B0%E9%81%A9%E5%90%88

47
71
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
47
71

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?