LoginSignup
6
9

More than 5 years have passed since last update.

OpenCVのSVMで学習を行う小さめのコード

Last updated at Posted at 2018-08-21

OpenCVのSVMを利用して学習を行う。

# -*- coding: utf-8 -*
import glob
import cv2
import numpy as np

# Generate the image array: load each image and flatten its pixels into one row.
def creeate_image_array(img_paths):
    """Load images and return them as a float32 sample matrix.

    Each file is read with ``cv2.imread`` (default flags, i.e. 3-channel
    BGR) and flattened into one row of ``height * width * 3`` values, so each
    row of the result is one sample for ``cv2.ml.ROW_SAMPLE``.

    Args:
        img_paths: Iterable of image file paths.

    Returns:
        ``np.ndarray`` of dtype float32 with one row per image. An empty
        input yields an empty 1-D array of shape ``(0,)``.

    Raises:
        OSError: If OpenCV cannot read one of the files.
        ValueError: If the images do not all have the same number of pixels.
    """
    images = []
    for img_path in img_paths:
        image = cv2.imread(img_path)
        # cv2.imread does not raise on failure; it returns None instead.
        if image is None:
            raise OSError('cv2.imread could not read image: %s' % img_path)
        # Flatten raw pixels (rows x cols x 3 channels) into one feature vector.
        reshaped_image = image.reshape(image.shape[0] * image.shape[1] * 3)
        # Every sample must have the same length to form a 2-D matrix.
        if images and reshaped_image.size != images[0].size:
            raise ValueError(
                'image %s has %d values but previous images have %d; '
                'all images must have the same size'
                % (img_path, reshaped_image.size, images[0].size))
        images.append(reshaped_image)
    return np.array(images, np.float32)

def main():
    """Train an RBF-kernel C-SVC on raw-pixel images and save it as XML.

    Images matching ./data/positive/* get label 1. Images matching
    ./data/negative/* get label 0. The trained model is written to
    svm_trained_data.xml in the current working directory.

    Raises:
        ValueError: If either the positive or the negative set is empty.
        RuntimeError: If OpenCV reports that training failed.
    """
    # positive and negative image paths
    pos_img_paths = glob.glob('./data/positive/*')
    neg_img_paths = glob.glob('./data/negative/*')
    # A two-class C-SVC needs at least one sample of each class; fail early
    # with a clear message instead of a shape error from np.r_ / OpenCV.
    if not pos_img_paths or not neg_img_paths:
        raise ValueError(
            'need at least one image in both ./data/positive and '
            './data/negative (found %d positive, %d negative)'
            % (len(pos_img_paths), len(neg_img_paths)))

    # load images: one flattened float32 row per image
    positive_images = creeate_image_array(pos_img_paths)
    negative_images = creeate_image_array(neg_img_paths)
    # np.r_ stacks the two sample matrices vertically (positives first).
    images = np.r_[positive_images, negative_images]

    # generate labels: 1 = positive, 0 = negative, in the same order as `images`
    positive_labels = np.ones(len(pos_img_paths), np.int32)
    negative_labels = np.zeros(len(neg_img_paths), np.int32)
    # Wrapped in an outer list, so the responses have shape (1, n_samples).
    labels = np.array([np.r_[positive_labels, negative_labels]])

    # SVM
    svm = cv2.ml.SVM_create()
    svm.setType(cv2.ml.SVM_C_SVC)
    svm.setKernel(cv2.ml.SVM_RBF)
    # NOTE(review): with raw 0-255 pixel features, gamma=10 drives the RBF
    # kernel exp(-gamma * |x - y|^2) to ~0 between any two distinct images;
    # consider scaling the features or tuning gamma/C (e.g. svm.trainAuto).
    svm.setGamma(10)
    svm.setC(10)
    # Only the iteration-count criterion is enabled, so training stops after
    # 100 iterations and the epsilon value is not used.
    svm.setTermCriteria((cv2.TERM_CRITERIA_COUNT, 100, 1.e-06))
    # StatModel.train returns False on failure instead of raising.
    if not svm.train(images, cv2.ml.ROW_SAMPLE, labels):
        raise RuntimeError('SVM training failed')
    svm.save('svm_trained_data.xml')


if __name__ == '__main__':
    main()

無加工の画像（画素値をそのまま1次元に並べたもの）を特徴量として学習を行っている。そのため、学習・テストに使う画像はすべて同じサイズである必要がある。
出力した学習モデルを利用して分類をテストするコードは下記のようになる。

# -*- coding: utf-8 -*
import glob
import cv2
import numpy as np

# Generate the image array: load each image and flatten its pixels into one row.
def creeate_image_array(img_paths):
    """Load images and return them as a float32 sample matrix.

    Each file is read with ``cv2.imread`` (default flags, i.e. 3-channel
    BGR) and flattened into one row of ``height * width * 3`` values, so each
    row of the result is one sample for ``cv2.ml.ROW_SAMPLE``.

    Args:
        img_paths: Iterable of image file paths.

    Returns:
        ``np.ndarray`` of dtype float32 with one row per image. An empty
        input yields an empty 1-D array of shape ``(0,)``.

    Raises:
        OSError: If OpenCV cannot read one of the files.
        ValueError: If the images do not all have the same number of pixels.
    """
    images = []
    for img_path in img_paths:
        image = cv2.imread(img_path)
        # cv2.imread does not raise on failure; it returns None instead.
        if image is None:
            raise OSError('cv2.imread could not read image: %s' % img_path)
        # Flatten raw pixels (rows x cols x 3 channels) into one feature vector.
        reshaped_image = image.reshape(image.shape[0] * image.shape[1] * 3)
        # Every sample must have the same length to form a 2-D matrix.
        if images and reshaped_image.size != images[0].size:
            raise ValueError(
                'image %s has %d values but previous images have %d; '
                'all images must have the same size'
                % (img_path, reshaped_image.size, images[0].size))
        images.append(reshaped_image)
    return np.array(images, np.float32)

def main():
    """Classify the images in ./data/test_tagert/ with the saved SVM model.

    Loads svm_trained_data.xml, predicts a label for every test image, and
    prints "yes <path>" when the predicted label is 1 and "no <path>"
    otherwise.
    """
    # set test target images
    # NOTE(review): 'test_tagert' looks like a typo for 'test_target'; verify
    # the actual directory name on disk before renaming either one.
    test_img_paths = glob.glob('./data/test_tagert/*')
    if not test_img_paths:
        print("no test images found in ./data/test_tagert/")
        return

    # loading
    svm = cv2.ml.SVM_load("svm_trained_data.xml")

    # test: features must be built exactly as in training (raw flattened
    # pixels); presumably the images must also match the training image size.
    images = creeate_image_array(test_img_paths)
    # predict returns (retval, results); results holds one row per sample.
    predicted = svm.predict(images)
    result = predicted[1]
    for (i, img_path) in enumerate(test_img_paths):
        if result[i][0] == 1.0:
            print("yes", img_path)
        else:
            print("no", img_path)


if __name__ == '__main__':
    main()

参考

https://github.com/opencv/opencv/blob/master/samples/python/letter_recog.py#L62
https://algorithm.joho.info/programming/python/hog-svm-classifier-py/

6
9
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
6
9