1. はじめに
みなさん、初めまして。メディア研究開発センターの山内です。今年、新卒で朝日新聞社に入社し、現在は広告へのAI適応や音声認識の研究をしています。入社してからあっという間の1年でした。入社する前は四元数ニューラルネットワークについて研究していました。四元数?? なんじゃそりゃと思った方が大半だと思います(僕もそうでした)。
四元数は4次元に拡張した複素数を用いた代数学です。実はこの四元数、画像認識において、特にカラー画像(3チャンネルの画像)で、実数CNNを凌駕したという論文が2017年に出ました。いきなり四元数畳み込みニューラルネットワークと言われても...となるので、今回は四元数の基本的な概念から四元数ニューラルネットワークの紹介、そして四元数畳み込みニューラルネットワークの仕組みから実験までを紹介したいと思います。
論文:Quaternion Convolutional Neural Networks
2. 四元数
上記の通り、四元数は4次元に拡張した複素数です。基本的には以下の式で表されます。
$$q = a + bi + cj + dk$$
ここで、$a, b, c, d$は実数、$i, j, k$は、それぞれ独立した複素数です。四元数の強みは、3次元空間における回転表現が非常に長けていることです。今回はあまり紹介しませんが、コンピューターグラフィックス(CG)やロボット工学・人工衛星の姿勢制御の計算に使われております。
四元数には以下の重要な要素があります。
\begin{align}
i^2 &= j^2 = k^2 = ijk = -1 \\
ij &= -ji = k \\
jk &= -kj=i \\
ki &=-ik=j
\end{align}
上記の式より、四元数の積には非可換な性質を持っています。つまり、掛け算の順序によって計算の結果が変わってきてしまいます。
そして、このパラメターを用いたニューラルネットワークを四元数ニューラルネットワーク(Quaternion Nueral Netowrk) と呼びます。これは次の章で紹介していきます。
3. 四元数ニューラルネットワーク(Quaternion Nueral Network)
3.1 高次元ニューラルネットワーク
四元数ニューラルネットワークは高次元ニューラルネットワーク(High Dimension Neural Network)の一種です。高次元ニューラルネットワークとは、実数よりも大きなパラメータを用いて表現したネットワークを指します。代表的なのは
- 四元数ニューラルネットワーク(4次元)
- 複素ニューラルネットワーク(2次元)
などがあります。これらは一般的にクリフォード代数で表現されます。クリフォード代数とは、高次元の空間での計算(回転や反転など)を、単純な掛け算のルールで扱えるようにした数学的な体系です。
ニューラルネットワークが深層化していく昨今とは反対に、パラメータの次元を上げていこう!というのが高次元ニューラルネットワークです。また、双対数や双曲数を用いたニューラルネットワークも提案されています。
3.2 四元数のフォワード計算
ここからは四元数ニューラルネットワークのあれこれを紹介していきます。(長いので四元数NNと呼びます) 四元数NNは全てのパラメータが四元数なのでフォワード計算は以下の式になります。$w$は重み、$x$は入力値、$b$はバイアスです。
\begin{align}
U &= \sum{wx} + b \\
&= \sum{{(w_1+w_2i+w_3j+w_4k)(x_1+x_2i+x_3j+x_4k)}} \\[5pt]
& \qquad + {b_1+b_2i+b_3j+b_4k} \\
\end{align}
$w, x, b$がそれぞれ四元数なので、これを展開して$real, i, j, k$でまとめればフォワード計算ができます。計16回の掛け算と虚数の掛け算が混ざるので結構めんどくさいです
頑張ってまとめると$U$は以下の式になります。
\begin{align}
U &= \sum(w_1 x_1 - w_2 x_2 - w_3 x_3 - w_4 x_4) + b_1 \\
&\qquad + \left\{ \sum(w_1 x_2 + w_2 x_1 + w_3 x_4 - w_4 x_3)+b_2 \right\}i \\
&\qquad + \left\{\sum\{w_1 x_3 - w_2 x_4 + w_3 x_1 + w_4 x_2)+b_3 \right\}j \\
&\qquad + \left\{\sum(w_1 x_4 + w_2 x_3 - w_3 x_2 + w_4 x_1)+b_4 \right\}k \\[10pt]
& = U_1 + U_2i + U_3j + U_4k
\end{align}
それぞれの行の結果が$U$の実数と各虚数に相当します。
3.3 四元数の活性化関数
四元数NNの活性化関数は特に難しい数式とかはなく、4つの要素に実数の活性化関数を当てるだけで表現できます。
活性化関数$f(q)$に$U$を代入した値$H$は以下のようになります。
\begin{align}
H = f(U_1) + f(U_2)i + f(U_3)j + f(U_4)k
\end{align}
つまり、四元数版のReLU関数を使いたい時は、今まで使ってきたReLU関数を各成分に通せばいいだけなのです。 このタイプの活性化関数をSplit型活性化関数といいます。
(ここからは飛ばしてもokです。)-----------------------------------
では、なぜこのような関数を使うのでしょうか。そもそも四元数とは複素数の一種なので、複素関数の微分を使います。ここで次の問題が生じます。
- 複素関数で微分可能であるためにはコーシー・リーマンの方程式を満たす必要があるため、使える関数が縛られる
- 当時(複素NNが出た1992年)では有界な活性化関数が主流であった。(非有界活性化関数のReLUの登場は2011年)
- シグモイド関数(有界)を微分すると$z=k\pi \quad(k=\pm1,\pm3, \pm5 ....)$で関数が無限大になり学習がうまくいかなかった
参考 : Tohru Nitta "An Extension of the Back-Propagation Algorithm to Complex Numbers"
ここで、split型にすると実部と虚部はそれぞれ独立した実数の関数と見なすことができます。それにより偏微分が可能になり学習もできるようになりました。また実装の容易さなどを考慮して、四元数NN・複素NNともに、このsplit型を使うのが主流となっています。しかしながら数式的にはあまり美しくないことや、複素数の性質を最大限に活かせていないので、いかに活性化関数を複素関数で表すことができるかが重要です(複素NNだと、複素関数型のReLU関数も提案されています)。
3.3 四元数バックプロパゲーション
四元数NNのバックプロパゲーションは以下の論文で提案されました
Tohru Nitta "A Quaternary Version of the Back-propagation Algorithm"
簡単に説明すると、4つのパラメータをそれぞれ偏微分し、それぞれの和をとるだけです。損失関数を$E_p$, ある層$m$と$n$の重みを$w_{nm} = w_{nm}^a + w_{nm}^bi + w_{nm}^cj + w_{nm}^dk $、学習率を$\epsilon$とすると、修正項$\nabla w_{nm}$は
\begin{align}
\nabla w_{nm} = - \epsilon \left\{ \frac{\partial E_p}{\partial w_{nm}^a} + \frac{\partial E_p}{\partial w_{nm}^b}i + \frac{\partial E_p}{\partial w_{nm}^c}j + \frac{\partial E_p}{\partial w_{nm}^d}k \right\}
\end{align}
となります。仮に計算しようとすると、実数ニューラルネットワークの4倍の項があるので、かなりの数の偏微分を行う必要があります。しかし、pytorchやtensorflowなどのフレームワークを使えば、重みを設定してフォワードの式さえ書けばあとは全部やってくれます。本当に感謝しかありません。学生時代は手計算で数式を求めて、それをnumpyとforで回して学習していました.....。
4. 四元数畳み込みニューラルネットワーク
4.1 概要
ここらは、今回の記事の四元数畳み込みニューラルネットワークの概要を紹介して行きたいと思います。まず、この論文で提案されたモデルの概要をまとめます。
-
従来の実数値CNNを拡張した四元数CNN(QCNN) を提案。カラー画像の3チャンネル(RGB)を四元数として3次元空間の回転で表現
-
カラー画像分類タスク・ノイズ除去のタスクで実数値CNNよりも少ないパラメータ数で高精度を実現(グレースケール画像では実数モデルと同等の精度)
左が実数CNN(従来のモデル)で、右が提案された四元数CNNです。実数CNNはRGBの3次元の情報を独立して処理を行い、最後に和をとるなどして特徴量を表現しています。これは、各チャネルを独立して変換してから加算するため、チャネル間の複雑な相関関係を表現するのが難しいという課題がありました。一方で四元数CNNは3次元のRGB情報を四元数の$i, j, k$で構築された3次元空間に投影し、色空間内での回転と大きさの変換として処理します。これにより、3次元の色ベクトルとして連続的に操作可能になり、実数CNNよりも表現豊かな色彩情報を捉えることができます。つまり、実数CNNが各チャンネルを「バラバラに」処理するのに対し、QCNNは色を「一体として」処理します。
四元数で3次元のチャンネルを扱うときは以下の式で表されます。
$$0 + Ri + Gj + Bk$$
実数成分を0にし、$i, j, k$の純虚数部分で処理を行います。この手法はRGBにとどまらず、3次元の座標などでもよく使われます。
4.2 四元数の畳み込み
メインとなる畳み込みの処理は以下の数式で行われます(かなり複雑なので、色彩情報を捉える回転行列を作り、畳み込んでいるんだなという認識でokです)。
- $s_{ll'}$は実数のスケーリング係数で$s_{ll'} \sim U\left[-\frac{\sqrt{6}}{\sqrt{n_j + n_{j+1}}}, \frac{\sqrt{6}}{\sqrt{n_j + n_{j+1}}}\right]$の一様分布で生成
- $\hat{w_{ll'}}=s_{ll}'(cos(θ_{ll'}/2) + sin(θ_{ll}'/2)μ)$
($μ$ はグレー軸を表す単位ベクトル$\frac{\sqrt{3}}{3}(i + j + k)$) - $\hat{w_{ll}^*}$は、$\hat{w}_{ll'}$の共役四元数
- $\theta_{ll'}$ は回転角 ([-π, π]の範囲)
- $a_{(k+l)(k'+l')}$は入力画像のピクセル位置
ここで、$f_1$, $f_2$, $f_3$ は以下のように定義されます。
- $f_1=\frac{1}{3}+\frac{2}{3}\cos\theta_{ll′}$
- $f_2 = \frac{1}{3} - \frac{2}{3}\cos(\theta_{ll'} - \frac{\pi}{3})$
- $f_2 = \frac{1}{3} - \frac{2}{3}\cos(\theta_{ll'} + \frac{\pi}{3})$
論文には具体的な説明がありませんでしたが、推察するにどうやらここで色の色彩関係を表現する回転行列を作成しているようです。この回転行列と入力値をpytorchでいうnn.ConvNdに入れることで、畳み込み演算ができます。また、全結合層もこの回転行列と入力値をnn.Linearに入れることで実装できます。
今回注意することは、この論文での四元数NNはカラー画像のタスクに特化しているということです。従来の四元数NNの実装は上記で書いたフォワード計算・バックプロパゲーションによって実装されているのがほとんどです。
5. 実験
今回、論文の分類タスクと同様にCIFAR10を用いた、四元数CNNと実数CNNの比較実験を行いました(ノイズ除去の実験は行っておりません)。
四元数CNNを実数CNNのおよそ半分に設定し、両モデル4つの畳み込み層と2つの全結合層の計6層のニューラルネットワークを構築しました。また、訓練データとテストデータを8:2に分割し、損失関数はクロスエントロピー誤差を使用しました。
具体的なパラメータは以下のとおりです。
四元数CNN | 実数CNN | |
---|---|---|
学習率 | 1e-6 | 1e-6 |
バッチサイズ | 64 | 64 |
epoch数 | 100 | 100 |
パラメータ | 5.2M | 11.2M |
学習曲線とテスト結果を示します。
lossは四元数CNN, 実数CNNとも収束しており学習ができています。また、val_accもepochが進むにつれて上がってきています。
次にテストデータで実験してみると、四元数CNNがわずかに実数CNNを上回る結果となりました。
四元数CNN | 実数CNN | |
---|---|---|
Test acc | 0.880 | 0.876 |
Test loss | 0.383 | 0.536 |
パラメータ数が半分でも実数CNNを超えることができました。
6. まとめ
今回、四元数の紹介から実験結果までを紹介しました。結果的に、四元数CNNは実数CNNの半分のパラメータで性能を上回りました。実際に論文でも実数よりも少ないパラメータでより良い性能を出しています。これはパラメータ数至上主義の今のLLM時代において、タスクは限られてしまいますが、「パラメータ数が少なくて済むが精度は変わらない、むしろ良い性能が出る」といったポテンシャルが四元数ニューラルネットワークは含まれているのではないでしょうか。四元数をはじめ、高次元ニューラルネットワークはニッチな研究分野で、知らない人が大半だと思いますが、面白い性質や実数を超える性能を出すことがあります。皆さんもぜひ「高次元の世界」へ足を広げてみてはいかがでしょうか!