OpenCVのSVMを利用して学習を行う。
# -*- coding: utf-8 -*-
import glob
import cv2
import numpy as np
# generate image array
def creeate_image_array(img_paths):
    """Load images and flatten each one into a row of a float32 array.

    Every path is read with ``cv2.imread`` and its pixel array is flattened
    into a 1-D vector; the vectors are stacked in the order of ``img_paths``.

    Args:
        img_paths: Iterable of image file paths.

    Returns:
        ``np.ndarray`` of dtype ``float32`` with one flattened image per
        row. An empty ``img_paths`` yields an empty array.

    Raises:
        OSError: If a path cannot be read as an image.
    """
    images = []
    for img_path in img_paths:
        image = cv2.imread(img_path)
        if image is None:
            # cv2.imread reports failure by returning None instead of
            # raising. Skipping the file would silently misalign the rows
            # with the caller's path list, so fail loudly instead.
            raise OSError("could not read image: {}".format(img_path))
        # Flatten all pixels and channels into a single feature vector.
        images.append(image.reshape(-1))
    # NOTE(review): presumably all images share the same dimensions;
    # otherwise the rows differ in length and no 2-D array can be built.
    return np.array(images, np.float32)
def main():
    """Train an OpenCV SVM on the raw pixels of positive/negative images.

    Reads every file matching ``./data/positive/*`` and
    ``./data/negative/*``, labels them 1 and 0 respectively, trains a
    C-SVC with an RBF kernel and saves the model to
    ``svm_trained_data.xml``.
    """
    # Collect the sample files of both classes.
    positive_files = glob.glob('./data/positive/*')
    negative_files = glob.glob('./data/negative/*')

    # Sample matrix: positive rows first, then negative rows.
    samples = np.concatenate((
        creeate_image_array(positive_files),
        creeate_image_array(negative_files),
    ))

    # Responses in the same order as the samples (1 = positive,
    # 0 = negative), shaped as a single row of int32 values.
    ones = np.ones(len(positive_files), np.int32)
    zeros = np.zeros(len(negative_files), np.int32)
    responses = np.concatenate((ones, zeros))[np.newaxis, :]

    # Configure the classifier.
    classifier = cv2.ml.SVM_create()
    classifier.setType(cv2.ml.SVM_C_SVC)
    classifier.setKernel(cv2.ml.SVM_RBF)
    classifier.setGamma(10)
    classifier.setC(10)
    # Termination is count-based with a maximum of 100 iterations.
    classifier.setTermCriteria((cv2.TERM_CRITERIA_COUNT, 100, 1.e-06))

    # Each row of `samples` is one training sample.
    classifier.train(samples, cv2.ml.ROW_SAMPLE, responses)
    classifier.save('svm_trained_data.xml')


# Run the training only when executed as a script.
if __name__ == '__main__':
    main()
無加工の画像を利用して学習を行っている。
出力した学習モデルを利用して分類をテストするコードは下記のようになる。
# -*- coding: utf-8 -*-
import glob
import cv2
import numpy as np
# generate image array
def creeate_image_array(img_paths):
    """Load images and flatten each one into a row of a float32 array.

    Every path is read with ``cv2.imread`` and its pixel array is flattened
    into a 1-D vector; the vectors are stacked in the order of ``img_paths``.

    Args:
        img_paths: Iterable of image file paths.

    Returns:
        ``np.ndarray`` of dtype ``float32`` with one flattened image per
        row. An empty ``img_paths`` yields an empty array.

    Raises:
        OSError: If a path cannot be read as an image.
    """
    images = []
    for img_path in img_paths:
        image = cv2.imread(img_path)
        if image is None:
            # cv2.imread reports failure by returning None instead of
            # raising. Skipping the file would silently misalign the rows
            # with the caller's path list, so fail loudly instead.
            raise OSError("could not read image: {}".format(img_path))
        # Flatten all pixels and channels into a single feature vector.
        images.append(image.reshape(-1))
    # NOTE(review): presumably all images share the same dimensions;
    # otherwise the rows differ in length and no 2-D array can be built.
    return np.array(images, np.float32)
def main():
    """Classify the test images with the previously trained SVM.

    Loads the model from ``svm_trained_data.xml``, predicts a label for
    every file matching ``./data/test_tagert/*`` and prints ``yes`` plus
    the path when the predicted label equals 1.0, ``no`` plus the path
    otherwise.
    """
    # NOTE(review): "test_tagert" looks like a typo of "test_target"; the
    # path is left untouched because it has to match the directory on
    # disk -- confirm before renaming.
    test_img_paths = glob.glob('./data/test_tagert/*')

    # Load the trained model.
    svm = cv2.ml.SVM_load("svm_trained_data.xml")

    # Predict on the flattened test images (the feature array built right
    # here, not an undefined variable).
    images = creeate_image_array(test_img_paths)
    # predict() returns a pair; the second element holds one row per input
    # sample with the predicted label in its first column.
    _, results = svm.predict(images)

    for img_path, row in zip(test_img_paths, results):
        if row[0] == 1.0:
            print("yes", img_path)
        else:
            print("no", img_path)


# Run the classification only when executed as a script.
if __name__ == '__main__':
    main()
参考
https://github.com/opencv/opencv/blob/master/samples/python/letter_recog.py#L62
https://algorithm.joho.info/programming/python/hog-svm-classifier-py/