LoginSignup
3
3

More than 5 years have passed since last update.

手書き数字データの機械学習テンプレ

Posted at

実践力を身につける Pythonの教科書のサンプルをいじっただけのやつ

python3 digits.py ${fileName}で適当な数字画像を投げ込めば予測してくれる。

スクリーンショット 2017-05-27 9.36.24.png

digits.py
import os, sys, math
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets, model_selection, svm, metrics
from sklearn.externals import joblib
from PIL import Image

# モデルデータファイル名
DIGITS_PKL = "digit-clf.pkl"

# 手書き数字データを読み込む
digits = datasets.load_digits()
# クロスバリデーション
# データをランダムに訓練用とテスト用にわける
data_train, data_test, label_train, label_test = \
    model_selection.train_test_split(digits.data, digits.target)

# 予測モデルを作成する
def create_model():
    # モデル構築
    clf = svm.SVC(gamma=0.001)
    # clf = svm.LinearSVC()
    # from sklearn.ensemble import RandomForestClassifier
    # clf = RandomForestClassifier()
    # 学習
    clf.fit(data_train, label_train)
    # 予測モデルを保存
    joblib.dump(clf, DIGITS_PKL)
    print("予測モデルを保存しました=", DIGITS_PKL)
    return clf

# 予測モデルを選定する
def select_model():
    # モデルファイルを読み込む
    if not os.path.exists(DIGITS_PKL):
        clf = create_model() # モデルがなければ生成
    clf = joblib.load(DIGITS_PKL)
    return clf

# データから数字を予測する
def predict_digits(data,clf):
    n = clf.predict([data])
    print("判定結果=", n)

# 手書き数字画像を8x8グレイスケールのデータ配列に変換
def image_to_data(imagefile):
    image = Image.open(imagefile).convert('L') # グレイスケール変換
    image = image.resize((8, 8), Image.ANTIALIAS)
    img = np.asarray(image, dtype=float)
    img = np.floor(16 - 16 * (img / 256)) # 行例演算
    # 変換後の画像を表示
    plt.imshow(img)
    plt.gray()
    plt.show()

    img = img.flatten()
    print("img=",img)
    return img

# モデルを評価する
def evaluate_model(clf):
    predict = clf.predict(data_test)
    return predict

# 予測からレポートをつくる
def show_report(predict, clf):
    ac_score = metrics.accuracy_score(label_test, predict)
    cl_report = metrics.classification_report(label_test, predict)
    print('分類機の情報=', clf)
    print('正解率=', ac_score)
    print('レポート=', cl_report)
    # precision:精度, recall:再現率(正解率),
    # f1-score:精度と再現率の調和平均, support:正解ラベルのデータ数

def main():
    # コマンドライン引数を得る
    if len(sys.argv) <= 1:
        print("USAGE:")
        print("python3 predict_digit.py imagefile")
        return
    imagefile = sys.argv[1]
    data = image_to_data(imagefile)
    clf = select_model();
    predict_digits(data,clf)
    show_report(evaluate_model(clf),clf)

if __name__ == '__main__':
    main()
結果
img= [ 0.  0.  0.  0.  0.  0.  0.  0.  1.  9.  7.  7.  7.  7.  2.  0.  1.  8.
  0.  1.  0.  0.  0.  0.  1.  6.  0.  0.  0.  0.  0.  0.  1.  9.  5.  6.
  5.  1.  0.  0.  0.  4.  3.  3.  4.  8.  1.  0.  0.  0.  0.  0.  2.  9.
  2.  0.  0.  3.  8.  8.  8.  2.  0.  0.]
判定結果= [5]
分類機の情報= SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,
  decision_function_shape=None, degree=3, gamma=0.001, kernel='rbf',
  max_iter=-1, probability=False, random_state=None, shrinking=True,
  tol=0.001, verbose=False)
正解率= 0.993333333333
レポート=              precision    recall  f1-score   support

          0       1.00      1.00      1.00        38
          1       1.00      1.00      1.00        48
          2       1.00      1.00      1.00        40
          3       0.98      0.98      0.98        47
          4       1.00      1.00      1.00        54
          5       0.98      0.98      0.98        47
          6       0.98      1.00      0.99        46
          7       1.00      1.00      1.00        42
          8       1.00      1.00      1.00        47
          9       1.00      0.98      0.99        41

avg / total       0.99      0.99      0.99       450
3
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
3