OpenCV Deep Learningモジュール cv::dnn の紹介

  • 33
    Like
  • 0
    Comment
More than 1 year has passed since last update.

Google Summer of Code (GSoC) 2015で発表され、opencv_contrib レポジトリに実装が公開された cv::dnn モジュールの紹介をします。

It would be cool if OpenCV could load and run deep networks trained with popular DNN packages like Caffe, Theano or Torch. - Ideas Page for OpenCV Google Summer of Code 2015 (GSoC 2015)

この cv::dnn モジュールですが、2015/12/21にリリースされたOpenCV 3.1にさっそく取り込まれたようです。

導入手順やCaffeのC++インタフェースとの比較など技術詳細についてはブログの方にまとめてありますので興味があれば。

ここでは簡単なサンプルコードを載せておきます。

画像分類処理 サンプルコード

cv::dnnモジュールを使って、Caffe用の学習済みモデルの読み込みと画像分類(一般物体認識)処理を試してみます。

DSC01498.jpg これ入力

// opencv_dnn_test.cpp
#include <iostream>
#include <fstream>
#include <opencv2/dnn.hpp>
#include <opencv2/imgproc.hpp>
#include <opencv2/highgui.hpp>

using namespace std;

int main(int argc, char** argv) {
  // ImageNet Caffeリファレンスモデル
  string protoFile = "bvlc_reference_caffenet/deploy.prototxt";
  string modelFile = "bvlc_reference_caffenet/bvlc_reference_caffenet.caffemodel";
  // 画像ファイル
  string imageFile = (argc > 1) ? argv[1] : "images/cat.jpg";
  // Caffeモデルの読み込み
  cv::Ptr<cv::dnn::Importer> importer;
  try {
    importer = cv::dnn::createCaffeImporter(protoFile, modelFile);
  } catch(const cv::Exception& e) {
    cerr << e.msg << endl;
    exit(-1);
  }
  cv::dnn::Net net;
  importer->populateNet(net);
  importer.release();
  // テスト用の入力画像ファイルの読み込み
  cv::Mat img = cv::imread(imageFile);
  if(img.empty()) {
    cerr << "can't read image: " << imageFile << endl;
    exit(-1);
  }
  try {
    // 入力画像をリサイズ
    int cropSize = 224;
    cv::resize(img, img, cv::Size(cropSize, cropSize));
    // Caffeで扱うBlob形式に変換 (実体はcv::Matのラッパークラス)
    const cv::dnn::Blob inputBlob = cv::dnn::Blob(img);
    // 入力層に画像を入力
    net.setBlob(".data", inputBlob);
    // フォワードパス(順伝播)の計算
    net.forward();
    // 出力層(Softmax)の出力を取得, ここに予測結果が格納されている
    const cv::dnn::Blob prob = net.getBlob("prob");
    // Blobオブジェクト内部のMatオブジェクトへの参照を取得
    // ImageNet 1000クラス毎の確率(32bits浮動小数点値)が格納された1x1000の行列(ベクトル)
    const cv::Mat probMat = prob.matRefConst();
    // 確率(信頼度)の高い順にソートして、上位5つのインデックスを取得
    cv::Mat sorted(probMat.rows, probMat.cols, CV_32F);
    cv::sortIdx(probMat, sorted, CV_SORT_EVERY_ROW|CV_SORT_DESCENDING);
    cv::Mat topk = sorted(cv::Rect(0, 0, 5, 1));
    // カテゴリ名のリストファイル(synset_words.txt)を読み込み
    // データ例: categoryList[951] = "lemon";
    vector<string> categoryList;
    string category;
    ifstream fs("synset_words.txt");
    if(!fs.is_open()) {
      cerr << "can't read file" << endl;
      exit(-1);
    }
    while(getline(fs, category)) {
      if(category.length()) {
        categoryList.push_back(category.substr(category.find(' ') + 1));
      }
    }
    fs.close();
    // 予測したカテゴリと確率(信頼度)を出力
    cv::Mat_<int>::const_iterator it = topk.begin<int>();
    while(it != topk.end<int>()) {
      cout << categoryList[*it] << " : " << probMat.at<float>(*it) * 100 << " %" << endl;
      ++it;
    }
  } catch(const cv::Exception& e) {
    cerr << e.msg << endl;
  }
  return 0;
}

実行結果

$ g++ opencv_dnn_test.cpp -o opencv_dnn_test `pkg-config --cflags opencv` `pkg-config --libs opencv`
$ ./opencv_dnn_test
Attempting to upgrade input file specified using deprecated transformation parameters: bvlc_reference_caffenet/bvlc_reference_caffenet.caffemodel
Successfully upgraded file specified using deprecated data transformation parameters.
Note that future Caffe releases will only support transform_param messages for transformation fields.
Attempting to upgrade input file specified using deprecated V1LayerParameter: bvlc_reference_caffenet/bvlc_reference_caffenet.caffemodel
Successfully upgraded file specified using deprecated V1LayerParameter

Net Outputs(1):
prob
Siamese cat, Siamese : 93.9703 %
Egyptian cat : 4.23627 %
tabby, tabby cat : 0.365742 %
lynx, catamount : 0.19613 %
hamster : 0.184294 %

うちの猫はシャム猫と認識されました。雑種です。

cv::dnn::Blob クラスのコンストラクタは cv::InputArray を入力として受け付けているので、std::vector<cv::Mat> を作って渡せば複数の画像を対象にバッチ処理することもできます。

cv::dnn::Blob blob = cv::dnn::Blob(image);  // imageは cv::InputArray
std::cout << "blob shape: " << blob.shape() << std::endl;

// 実行結果例
// データ数、チャンネル数、画像幅、画像高さ
blob shape: [1, 3, 224, 224]

std::vector<cv::Mat> images;  // 画像データのリストを準備
// ... std::vector#push_back で cv::Mat を5つ追加

cv::dnn::Blob blob = cv::dnn:Blob(images);
std::cout << "blob shape: " << blob.shape() << std::endl;

// 実行結果例
blob shape: [5, 3, 224, 224]

複数の画像を入力とした場合のSoftmax(prob)層のBlobは [入力データ数、クラス数] の大きさの二次元データになり、クラス毎の確率(信頼度)が格納されています。確率上位5位までの予測結果を出力する場合も上述のサンプルコードと同様に書くことができます。現在の実装だとCaffeのArgmaxレイヤーが読み込めないみたいなのでソートする処理も別途書く必要がありました。

画像特徴量の抽出

Softmax(prob)層の出力から確率を得るのではなく、中間層の出力を特徴量として抽出したい場合も簡単です。cv::dnn::Net#getBlob メソッドに対象のレイヤー名を指定するだけでOK。レイヤー名は読み込むprototxtファイルを参照してください。

net.forward();  // cv::dnn::Net#forward メソッドでフォワードパス(順伝播)の計算
cv::dnn::Blob blob = net.getBlob("fc7");  // 全結合層 fc7 (InnerProduct)のBlobを取得
std::cout << "blob shape: " << blob.shape() << std::endl;  // blob shape: [1, 4096] (4096次元の特徴量を抽出)
const cv::Mat feature = blob.matRefConst();  // 抽出した特徴量を cv::Mat として取得(参照が返る)

抽出した特徴量を使って適当な分類器を作るのも簡単です。ここではSVM(cv::flag_ml::SVM)で学習するサンプルを載せます。

// feature(cv::Mat): 特徴量, trainLabel(cv::Mat): 正解ラベル
cv::Ptr<cv::ml::TrainData> data =
      cv::ml::TrainData::create(feature, cv::ml::ROW_SAMPLE, trainLabel, false);
cv::Ptr<cv::ml::SVM> clf = cv::ml::SVM::create();
clf->setType(cv::ml::SVM::C_SVC);
clf->setKernel(cv::ml::SVM::LINEAR);
clf->trainAuto(data, 5);  // グリッドサーチ + 交差検証(5-fold)で学習
clf->save("cnn_svm_model.yml");  // モデルをYAMLファイルとしてディスクに保存

cv::flag_ml::SVM#trainAuto メソッドはとても便利。学習したモデルは保存しておくと良いと思います。

ただし、cv::dnnモジュール単体ではモデルの学習は未サポートとのことです。

Functionality of this module is designed only for forward pass computations (i. e. network testing). A network training is in principle not supported.

CaffeやTorch本体のインストールは学習済みモデルを利用するだけであれば必要ありません。

OpenCVはマルチプラットフォーム対応ライブラリなので、cv::dnnモジュールを使ったアプリケーションも比較的簡単に作れるかと思います。なにより他の画像処理関連ライブラリと比べてOSSコミュニティが大きいのでサポートも期待できるという安心感もあります。

ここで紹介したサンプルコードに加え、SVMではなくロジスティック回帰(cv::flag_ml::LogisticRegression)を使って学習したり、cv::datasetsモジュールを使ってMNISTデータセットを読み込むサンプルなどもGitHubに上げてありますので興味があれば参考にしていただければと思います。