はじめに
どうも、エンジニア志望の理系大学生NMSです。
最近、画像分類で使用されるAIアルゴリズムを調査することがあり、それをまとめたので、いい機会だからQiita記事でもアウトプットしてみようと思い立ちました。
お役に立てたら幸いです。
画像分類(多クラス)において使用されるAIアルゴリズム
機械学習(Machine Learning)
ロジスティック回帰
このモデルは、線形回帰で使われる線形多項式(目標の値を予測するための式)
y(x, w)=w_{0}+w_{11}(x)+w_{22}(x)+ \cdots +w_{MM}(x)=\sum_{i=0}^{M} w_{ii}(x)
に、ロジスティック関数(もしくはシグモイド関数とも呼ばれます)を適用することで、分類問題を解けるようにしただけのものです。
k-近傍法
k-近傍法は、分類が未知の新しいデータに対して直近のk点を探し、それらの多数決でそのデータの分類を決めるというモデルです。
サポートベクターマシン (SVM)
このモデルでは「マージン最大化」がキーワードです。これはデータをクラスごとに分けるときに、境界とその近くの点との距離がなるべく大きくなるように境界を決めるというものです。これにより、未知のデータに対しても、精度良く予測ができるようになります(これを汎化性能が高いと言います)。
また、データの次元が大きくなってしまっても識別精度が高い点や、最適化すべきパラメータが少ない点などは、SVMの大きなメリットといえるでしょう。ただし、学習データの増加に伴い計算量が膨大になる点や、2クラス分類に特化している点、スケーリングが必要になる点などはデメリットといえます。
決定木
決定木とは、ツリー状に条件分岐を繰り返すことによって分類するクラスを予測するモデルです。枝分かれ部分の条件は、データの変数を使って作ります。
アンサンブル学習
決定木は、条件分岐を繰り返して予測するツリー上のモデルでしたが、このモデルの精度を高める方法に、アンサンブル学習というものがあります。
アンサンブル学習とはモデルを複数用意して、各モデルの出力をまとめることで、より良い予測をさせようというものです。
-
バギング(Bootstrap AGGregatING)
バギングとはデータセットから重複を許してランダムに抽出した訓練データを複数用意し、それぞれ独立にモデルを学習させて出力を各モデルの平均や多数決により決定する手法です。
バギングによって予測精度が不安定なモデルの性能を向上させて過学習(学習データに対してあまりに忠実に適合しすぎて、未知データに対しては適合できていない、汎化できていない状態)を抑えることができます。 -
ブースティング
バギングとの違いは、バギングがモデルを独立に増やしていくのに対し、ブースティングは逐次的にモデルを増やす(あるモデルが間違えた問題に対して正解しているモデルを追加していくなど、前のモデルの結果を受けて次に足し合わせるモデルを決める)という点です。これによって、バギングよりも高い精度を得られるようになりました。
ただし、モデルの数を増やすと汎化性能が高まるバギングと違い、ブースティングの場合はモデルを増やしすぎてしまうと過学習に陥ってしまうので、どこまで増やすかには注意する必要があります。
また、ブースティングはバギングと違って並列処理ができないので、時間も多くかかってしまいます。しかし、ブースティングを使用した勾配ブースティングマシン(XGBoost等)は、機械学習アルゴリズムの中で唯一ディープラーニングの精度に匹敵することがあります。 -
スタッキング
タッキングとは複数のモデルを積み上げる手法で、ブースティングと一緒に用いられることが多いモデルです。
仕組みとしては、まず一層目にいくつかのモデルを用意し、それぞれのモデルからの出力を、次の層(これもまた選ばれたいくつかのモデルからなります)への入力として用い、その出力をまた次の層への入力とし〜…を繰り返すことで、最終的な予測値を得るというものです。
過学習などに気をつけながら、どのモデルをどの層に使うかなどを考えなければならず、非常に難しい手法となっていますが、その分とても強力です。
ランダムフォレスト
ランダムフォレストとは、アンサンブル学習のバギングを基本として、複数の少しずつ異なる決定木を集めたものです。決定木単体では過学習してしまうという欠点があり、ランダムフォレストはこの問題に対応する方法の1つです。ランダムに元のデータからサンプリングしているため、各決定木はそれぞれのデータを過学習している状態で構築されます。
異なった方向に過学習している決定木を大量に作成し、それぞれの決定木で得られた結果を平均したり多数決をとったりすることで、過学習の度合いを減らし、より汎化性能が高くなっているモデルです。
ニューラルネットワーク(NN)
ニューラルネットワークの構造は、複数の入力を受け取り、それをもとに出力する「ノード」が、いくつも連なって一つの層を成し、その層がさらに何層にも積み上がることで、複雑な問題に対しても高い精度で予測できるようにするというものです。
層は基本的に、入力層、隠れ層、出力層に分けられます。このうち、隠れ層の層の数を増やす(層を深くする)ことで、より困難な問題が解決できるようになるという事例がいくつも認められていて、これがいわゆるディープラーニングへとつながります。
深層学習(Deep Learning)
CNN(Convolutional Neural Network, 畳み込みニューラルネットワーク)
CNNの層には、畳み込み層、ReLu層、プーリング層、完全接続層の4つの主要なタイプがあります。その中でも畳み込み層とプーリング層が大きな影響を及ぼしていて、畳み込み層では、さまざまなフィルタと特徴量を使った計算(畳み込み演算)により、何パターンかの新しい画像データを作り出す。これに対してプーリング層では、簡単に言えば、画像の解像度を下げて抽象化した画像データを作り出します。そして、また畳み込み層の処理を実行し……と繰り返していくことで、元画像は特徴を残したままデータ量が少なくなり、最終的に何の画像なのかを分類できるようになります。
現在、画像認識におけるディープラーニングではCNNが一般的に使用されていています。主なCNNモデルの構造を以下に示します。
モデル名 | ネットワーク構造 |
---|---|
AlexNet | 畳み込み層が5層で、そのうちのいくつかにはMaxPooling層があります。また、出力層には全結合層3層が使用されており、合計で8層により構成されています。 |
VGG | CNNの層を深くすることで認識性能を向上させるというシンプルな構成です。16層(VGG16)と19層(VGG19)があります。大きな特徴としては、3×3フィルターのみの畳み込みを採用している点です。これは、構造をシンプルにするだけではなく、表現能力を上げて、パラメータ数を削減する効果があります。 |
GoogLeNet | 22層で構成されていますが、全結合層が1つで、代わりに全体の平均プーリング層を利用することで最終出力を得ていて、全結合層はその結果を渡されているという構成になっています。過学習を抑制することができます。 また、ネットワークの途中から分岐させたサブネットワークにおいてもクラス分類を行う手法を用いていて、これにより勾配消失(活性化関数の勾配がゼロに近づくことによって、ネットワークの重み付けの修正ができなくなり、学習が進まなくなること)の防止・学習の効率化・ネットワークの正則化が実現されています。 |
ResNet | 152層という非常に深い層を実現した表現力の高いネットワークです。あまりにも深いネットワークは効率的な学習が困難であったが、ResNetでは通常のネットワークのように、入力xの畳み込みを行って出力される関数F(x)を学習するのではなく、入力xをショートカットし、H(x)=F(x)+xを学習するというShortcut Connectionを使用したResidual block(残差ブロック)を導入することで、非常に深いネットワークにおいても効率的に学習ができるようになった。 |
DenseNet | DenseNetは、ResNetを改善したモデルになります。Shortcut Connectionをたくさん導入することで、各層の間が密(dense)に結合している構造を持つことから、DenseNetと呼ばれています。従来よりコンパクトなモデルになりますが、高い性能を持っているのが特徴です。 |
GNN(Graph Neural Network, グラフニューラルネットワーク)
データをグラフ(ノード(または頂点)と2つのノードをリンクするエッジの2つの部分で構成される一種のデータ構造 )で表現し、そのデータに対して推論を実行する一種の深層学習アプローチです。GNNはGCNによって画像に使用できます。
- GCN(Graph Convolutional Networkグラフ畳み込みネットワーク)
CNNの場合では、グリッドで区切った局所的なデータを処理(対象の上下左右斜めの8方向からの情報を畳み込んでいます)し、複数の層を重ねたうえで対象の全体像を把握できます。一方で、グラフでは対象に関係するデータ点を結合して(畳み込んで)データ集合として取り扱うことができます。これにより、最先端のパフォーマンスとともに柔軟性を提供しています。
グラフの適用例は広く、生体内での複雑な化合物の作用もグラフですし、神経回路もグラフ構造、はたまた、人間関係、商品のやり取りなど、何かしらの繋がりは全てグラフで記述できます。つまりCNNと比較すると、CNNは画像などが専門で、GCNはもっと複雑なデータも扱うことができるということです。
まとめ
今回の調査によって、画像分類においては、現状CNNが一番最適なんだなと思いました。
機械学習に限定するなら、ブースティングを使用した勾配ブースティングマシン(XGBoost等)が良さそう。
また、CNNモデルの中では、ResNetやDenseNetが良さげかな。
あとがき
最後までご覧いただきありがとうございました。
少しでも良いと感じていただけましたらLGTMしていただけますと幸いです。
もし間違ってるところがあったら、遠慮なくコメントしてご教示ください。
参考・引用記事
-
機械学習アルゴリズム
-
深層学習アルゴリズム
- CNN
- GNN