search
LoginSignup
3

More than 5 years have passed since last update.

posted at

手書き数字データの機械学習テンプレ

実践力を身につける Pythonの教科書のサンプルをいじっただけのやつ

python3 digits.py ${fileName}で適当な数字画像を投げ込めば予測してくれる。

スクリーンショット 2017-05-27 9.36.24.png

digits.py
import os, sys, math
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets, model_selection, svm, metrics
from sklearn.externals import joblib
from PIL import Image

# モデルデータファイル名
DIGITS_PKL = "digit-clf.pkl"

# 手書き数字データを読み込む
digits = datasets.load_digits()
# クロスバリデーション
# データをランダムに訓練用とテスト用にわける
data_train, data_test, label_train, label_test = \
    model_selection.train_test_split(digits.data, digits.target)

# 予測モデルを作成する
def create_model():
    # モデル構築
    clf = svm.SVC(gamma=0.001)
    # clf = svm.LinearSVC()
    # from sklearn.ensemble import RandomForestClassifier
    # clf = RandomForestClassifier()
    # 学習
    clf.fit(data_train, label_train)
    # 予測モデルを保存
    joblib.dump(clf, DIGITS_PKL)
    print("予測モデルを保存しました=", DIGITS_PKL)
    return clf

# 予測モデルを選定する
def select_model():
    # モデルファイルを読み込む
    if not os.path.exists(DIGITS_PKL):
        clf = create_model() # モデルがなければ生成
    clf = joblib.load(DIGITS_PKL)
    return clf

# データから数字を予測する
def predict_digits(data,clf):
    n = clf.predict([data])
    print("判定結果=", n)

# 手書き数字画像を8x8グレイスケールのデータ配列に変換
def image_to_data(imagefile):
    image = Image.open(imagefile).convert('L') # グレイスケール変換
    image = image.resize((8, 8), Image.ANTIALIAS)
    img = np.asarray(image, dtype=float)
    img = np.floor(16 - 16 * (img / 256)) # 行例演算
    # 変換後の画像を表示
    plt.imshow(img)
    plt.gray()
    plt.show()

    img = img.flatten()
    print("img=",img)
    return img

# モデルを評価する
def evaluate_model(clf):
    predict = clf.predict(data_test)
    return predict

# 予測からレポートをつくる
def show_report(predict, clf):
    ac_score = metrics.accuracy_score(label_test, predict)
    cl_report = metrics.classification_report(label_test, predict)
    print('分類機の情報=', clf)
    print('正解率=', ac_score)
    print('レポート=', cl_report)
    # precision:精度, recall:再現率(正解率),
    # f1-score:精度と再現率の調和平均, support:正解ラベルのデータ数

def main():
    # コマンドライン引数を得る
    if len(sys.argv) <= 1:
        print("USAGE:")
        print("python3 predict_digit.py imagefile")
        return
    imagefile = sys.argv[1]
    data = image_to_data(imagefile)
    clf = select_model();
    predict_digits(data,clf)
    show_report(evaluate_model(clf),clf)

if __name__ == '__main__':
    main()
結果
img= [ 0.  0.  0.  0.  0.  0.  0.  0.  1.  9.  7.  7.  7.  7.  2.  0.  1.  8.
  0.  1.  0.  0.  0.  0.  1.  6.  0.  0.  0.  0.  0.  0.  1.  9.  5.  6.
  5.  1.  0.  0.  0.  4.  3.  3.  4.  8.  1.  0.  0.  0.  0.  0.  2.  9.
  2.  0.  0.  3.  8.  8.  8.  2.  0.  0.]
判定結果= [5]
分類機の情報= SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,
  decision_function_shape=None, degree=3, gamma=0.001, kernel='rbf',
  max_iter=-1, probability=False, random_state=None, shrinking=True,
  tol=0.001, verbose=False)
正解率= 0.993333333333
レポート=              precision    recall  f1-score   support

          0       1.00      1.00      1.00        38
          1       1.00      1.00      1.00        48
          2       1.00      1.00      1.00        40
          3       0.98      0.98      0.98        47
          4       1.00      1.00      1.00        54
          5       0.98      0.98      0.98        47
          6       0.98      1.00      0.99        46
          7       1.00      1.00      1.00        42
          8       1.00      1.00      1.00        47
          9       1.00      0.98      0.99        41

avg / total       0.99      0.99      0.99       450

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
What you can do with signing up
3