[Python] Implementing an SVM (Support Vector Machine) with scikit-learn

Posted at 2021-02-07

#Purpose
Build a model that classifies handwritten digits (0 to 9) with an SVM.

#Env

  • OS : Mac
  • Module : scikit-learn
  • Lang : Python

#Get the digits dataset from scikit-learn
scikit-learn ships with a dataset of handwritten digit images.
You can load it and look at its contents as follows.

from sklearn import datasets

# Read handwritten dataset
digits = datasets.load_digits()

# You can check the image as you want with matplotlib
import matplotlib.pyplot as plt
plt.matshow(digits.images[0], cmap="Greys")
plt.show()

Sample (0 and 8)
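
Before going further, it can help to see how the dataset is laid out. The following is a minimal sketch that only prints the shapes of the attributes returned by `load_digits()`; it is an illustration, not part of the original steps.

```python
from sklearn import datasets

digits = datasets.load_digits()

# 1797 images, each an 8x8 grid of pixel intensities
print(digits.images.shape)   # (1797, 8, 8)
# The same images flattened to 64 features per sample
print(digits.data.shape)     # (1797, 64)
# One integer label per image
print(digits.target.shape)   # (1797,)
# The ten class labels
print(digits.target_names)   # [0 1 2 3 4 5 6 7 8 9]
```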

#From image to array
An SVM cannot take images as input directly, so each image has to be converted into a numeric array. The dataset already provides this as digits.data, where each 8x8 image is flattened into a row of 64 values.
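
To make the conversion concrete, here is a small sketch that flattens the images by hand and compares the result with `digits.data`. The variable name `flattened` is only for illustration.

```python
import numpy as np
from sklearn import datasets

digits = datasets.load_digits()

# Flatten each 8x8 image into a 64-element row vector
flattened = digits.images.reshape(len(digits.images), -1)

print(flattened.shape)                          # (1797, 64)
# The hand-made array matches the one the dataset provides
print(np.array_equal(flattened, digits.data))   # True
```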

# array of each image
X = digits.data
# label of each image
y = digits.target

# split train and test data
# Train data (even)
X_train, y_train = X[0::2], y[0::2]
# Test data (odd)
X_test, y_test = X[1::2], y[1::2]

# you can inspect the full data (before the split)
print(X)
print(y)

Contents of the data (one array per image)
Each number is the darkness of one pixel, from 0 (white) to 16 (black).

[[ 0.  0.  5. ...  0.  0.  0.]
 [ 0.  0.  0. ... 10.  0.  0.]
 [ 0.  0.  0. ... 16.  9.  0.]
 ...
 [ 0.  0.  1. ...  6.  0.  0.]
 [ 0.  0.  2. ... 12.  0.  0.]
 [ 0.  0. 10. ... 12.  1.  0.]]

Contents of the data (labels)

[0 1 2 ... 8 9 8]
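
The even/odd split above is simple and reproducible. As an alternative, a shuffled split could be made with `train_test_split`; the sketch below assumes a 50/50 split to mirror the one above, and `random_state=0` is an arbitrary seed chosen for illustration. The results later in this article come from the even/odd split, not from this one.

```python
from sklearn import datasets
from sklearn.model_selection import train_test_split

digits = datasets.load_digits()
X, y = digits.data, digits.target

# Hold out half of the samples, shuffled and stratified by label
# so that every digit appears in both halves in similar proportion.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.5, stratify=y, random_state=0
)

print(X_train.shape, X_test.shape)
```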

#Make a model
Train a model on the training data.

# select SVM
from sklearn import svm
clf = svm.SVC(gamma=0.001)

# train the model with train data
clf.fit(X_train, y_train)
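
`SVC` uses an RBF kernel by default, and `gamma` is that kernel's coefficient. The value 0.001 is the one used above; the sketch below shows one possible way to compare a few values with cross-validation on the training half. The candidate values and the `results` dictionary are assumptions made for illustration.

```python
from sklearn import datasets, svm
from sklearn.model_selection import cross_val_score

digits = datasets.load_digits()
X_train, y_train = digits.data[0::2], digits.target[0::2]

results = {}
# Candidate values are illustrative, not a recommendation
for gamma in (0.0001, 0.001, 0.01):
    clf = svm.SVC(gamma=gamma)
    # 5-fold cross-validation on the training data only
    scores = cross_val_score(clf, X_train, y_train, cv=5)
    results[gamma] = scores.mean()
    print(f"gamma={gamma}: mean accuracy {results[gamma]:.3f}")
```

Because the test data is not touched here, it can still be used for the final evaluation.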

#Test
Evaluate the trained model on the test data.

Compute the accuracy:

accuracy = clf.score(X_test, y_test)
print(f"Accuracy:{accuracy}")
Accuracy:0.9866369710467706

Get the prediction for each test sample:

predicted = clf.predict(X_test)
print(predicted)
[1 3 9 7 9 1 3 5 7 9 1 3 5 7 9..... 9 4 7 3 1 0 2 8 0...]
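
The predictions can be compared with the true labels to find the images the model got wrong. The following is a minimal sketch that repeats the steps above so it runs on its own; the variable name `wrong` is only for illustration.

```python
import numpy as np
from sklearn import datasets, svm

digits = datasets.load_digits()
X, y = digits.data, digits.target
X_train, y_train = X[0::2], y[0::2]
X_test, y_test = X[1::2], y[1::2]

clf = svm.SVC(gamma=0.001).fit(X_train, y_train)
predicted = clf.predict(X_test)

# Positions in the test set where the prediction differs from the label
wrong = np.flatnonzero(predicted != y_test)
print(f"{len(wrong)} of {len(y_test)} test images misclassified")

# Show the first few mistakes
for i in wrong[:5]:
    print(f"test index {i}: predicted {predicted[i]}, actual {y_test[i]}")
```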

You can also check per-class details (precision, recall, F1-score):

import sklearn.metrics as metrics
print("classification report")
print(metrics.classification_report(y_test, predicted))
classification report
              precision    recall  f1-score   support

           0       1.00      0.99      0.99        88
           1       0.98      1.00      0.99        89
           2       1.00      1.00      1.00        91
           3       1.00      0.98      0.99        93
           4       0.99      1.00      0.99        88
           5       0.98      0.97      0.97        91
           6       0.99      1.00      0.99        90
           7       0.99      1.00      0.99        91
           8       0.97      0.97      0.97        86
           9       0.98      0.97      0.97        91

    accuracy                           0.99       898
   macro avg       0.99      0.99      0.99       898
weighted avg       0.99      0.99      0.99       898
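
The report shows how often each digit is misclassified, but not which digits are confused with each other. A confusion matrix is one way to see that. The sketch below repeats the steps above so it runs on its own.

```python
from sklearn import datasets, metrics, svm

digits = datasets.load_digits()
X, y = digits.data, digits.target
X_train, y_train = X[0::2], y[0::2]
X_test, y_test = X[1::2], y[1::2]

clf = svm.SVC(gamma=0.001).fit(X_train, y_train)
predicted = clf.predict(X_test)

# Rows are the true labels, columns are the predicted labels
cm = metrics.confusion_matrix(y_test, predicted)
print(cm)
```

Values on the diagonal are correct predictions; any non-zero value off the diagonal is a pair of digits the model mixed up.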