SplineCNNとは
CVPR2018で採択されたSplineCNN: Fast Geometric Deep Learning with Continuous B-Spline Kernelsで提案されている、B-スプライン基底を用いた連続カーネルによりノードの特徴量を集約する手法です。
ノードと記載していることからわかるように、Geometric Deep Learningの分野の手法です。
なお、SplineCNNはPyTorch geometricにおいて実装されているため、簡単に使用できます。
B-スプライン基底とは
この論文を読むまでB-スプライン基底について知らなかったので、こちらなどを参照しつつ理解を深めました。
B-スプライン基底は主にコンピューターグラフィックスの分野で複数の曲線を組み合わせた1つの曲線(=スプライン曲線)を描くために使用される基底(関数)のようです。
単純増加のノットベクトルti、出力値に影響を与える範囲となる次数nを使用し、スカラ値を引数に取って、スカラ値を返します。
例えば要素が7個の一様ノットベクトルを使用した2次のB-スプライン基底は以下の通りです。
(上では次数としてnを使いましたが、nだと何かとややこしいので以降は次数としてdegreeを使用します。)
import numpy as np
import seaborn as sns
def b_spline_basis(knot_vector:np.ndarray, t:float, degree:int, i:int)->float:
if degree == 0:
if knot_vector[i] <= t and t < knot_vector[i+1]:
v = 1
else:
v = 0
else:
w1 = ((t-knot_vector[i])*b_spline_basis(knot_vector, t, degree-1, i))/(knot_vector[i+degree]-knot_vector[i])
w2 = ((knot_vector[i+degree+1]-t)*b_spline_basis(knot_vector, t, degree-1, i+1))/(knot_vector[i+degree+1]-knot_vector[i+1])
v = w1+w2
return v
n_knots = 7
knot_vector = np.linspace(0.0, 1.0, n_knots)
t_vals = np.linspace(0.0, 1.0, 100)
degree = 2
for i in range(n_knots-degree-1):
outputs = map(lambda t: b_spline_basis(knot_vector, t, degree, i), t_vals)
sns.lineplot(x=t_vals, y=outputs)
扱うデータの形状
有向グラフ:G = (V, E, U)
\\
ノード:V = \{1, ..., N\}
\\
エッジ:E \subseteq V \times V
\\
d次元の疑似座標(エッジの特徴量):U \in [0, 1]^{N \times N \times d}
特徴量の集約方法
-
ノード間の位置関係を表すベクトル(疑似座標)を0~1に標準化する。
疑似座標の例 - (1)より
この際、円状や球状にマッピングされるグラフでは、ノード間の位置情報として角度(θ)を使用する場合が考えられます。この時、θを使用する疑似座標の次元について、閉じたB-スプライン基底(closed B-spline)を使用することで、θ=0とθ=2πで使用する重みを同じにして訓練することが可能になります。これによりデータの特性に応じた訓練ができるメリットに加え、訓練の対象となるパラメーターを削減するという効果も発生します。 -
学習するパラメーターとなる重み( product( kernel_size ) * 入力特徴量数( Min ) * 出力特徴量数( Mout ) )を初期化する。
つまり、重みは入力特徴量毎、出力特徴量毎に異なる重みを使用します。 -
疑似座標の各次元ごとに算出したB-スプライン基底の出力値の積と重みの積を得る。(カーネル関数)
-
一様ノットベクトルtを用いたm次のB-スプライン基底関数を考えます。
この時の t[ i ] ≦ x ≦ t[ i ]+1の出力値を0≦x≦1の区間に写像した関数を利用します。
前掲の degree=2 のB-スプライン基底の出力結果を利用すると以下のようなイメージです。上記の例のようにm=2であれば、以下の3つの関数が得られます。
青線:0.5x^2 - x + 0.5 \\ 黄線:-x^2 + x + 0.5 \\ 緑線:0.5x^2
xs = np.linspace(0.0, 1.0, 100) a = map(lambda x: 0.5*(x**2)-x+0.5, xs) b = map(lambda x: -(x**2)+x+0.5, xs) c = map(lambda x: 0.5*(x**2), xs) sns.lineplot(x=xs, y=a) sns.lineplot(x=xs, y=b) sns.lineplot(x=xs, y=c)
-
得られた関数毎に、疑似座標の各次元の値を与え、一つのエッジ毎に( degree + 1 )**次元数 個の出力値を得ます。
また、この際に疑似座標の値(xとする)毎にどの重みを適用するか、B-スプライン基底の形状に合わせてマッピングを行います。※1
因みに、可能性は低いですが、適用する重みを疑似座標に基づいて特定するため、使用されない重みが発生する場合も考えられます。開いたB-スプライン基底(open B-spline)の場合
x=0.0とx=1.0が最も離れた重みを使用するようにマッピングを行います。
(例) kernel_size = 5
degree = 2
x = 0.0 → weight_index = (0, 1, 2) ※2
x = 0.5 → weight_index = (1, 2, 3)
x = 1.0 → weight_index = (3, 4, 0)
閉じたB-スプライン基底(closed B-spline)の場合
x=0.0とx=1.0で同じ重みが使用されるようにマッピングを行います。
(例) kernel_size = 5
degree = 2
x = 0.0 → weight_index = (0, 1, 2)
x = 0.5 → weight_index = (2, 3, 4)
x = 1.0 → weight_index = (0, 1, 2)※1 詳細な重みのマッピング方法について学ぶには、公式の実装コードを確認するのが一番早いと思います。(L86~106)
※2 「1.」に示したように、degree+1個の関数が得られるため疑似座標の各次元に対して適用する重みのインデックスもdegree+1個が得られます。
-
-
「2.」で得られた値とマッピングされた重みの積を得ます。
-
「3.」で得た値とノードの特徴量と重みの積を求め、ノード毎に積の合計値を得る処理をMout回行います。
これで各ノードの特徴量を集約した後のノード( ノード数 * Mout )が得られます。
備忘録的なもの
- 公式の実装コードを確認したところ、4次以上のB-スプライン基底を利用したSplineCNNは行えない模様。ただ2次→3次の曲線の時点でさえグラフの形状に大きな差異がないので、そこまで次数を大きくしてもメリットが無いから実装されてないのかもしれない。
- open B-splineを使用する場合、疑似座標が0.0の場合にマッピングされる1つ目の重みのインデックスと疑似座標が1.0の場合にマッピングされるkernel_size目の重みのインデックスが同一になるため、学習後の重みを描画すると点対称に近い重みの分布になると考えられる。一方で、closed B-splineを使用する場合は疑似座標の値が0.0の時と0.5の時に、それぞれ重みがカーネルの両端にマッピングされることになるため、非対称な分布になりやすいと考えられる。そのため、転移学習を行うケースを想定した場合は、open/closed B-splineのどちらを疑似座標の各次元で使用しているかが、転移学習先のグラフの構造と一致していないと上手く転移学習が行えない気がする。(そもそも、open/closed B-splineの使用状態が各次元で一致していたとしても、異なるデータ、タスクで転移学習がうまくいくかは不明)
- SplineCNNを使用する場合において、適当なカーネルサイズを決める指針がいまいち分かっていない。疑似座標のカーディナリティが低い場合は比較的小さいカーネルサイズにして、そうでない場合は比較的大きめのカーネルサイズにした方が良さそうな気がするが試してみないとわからない。。(今度試す。)
出典
(1) SplineCNN: Fast Geometric Deep Learning with Continuous B-Spline Kernels
Matthias Fey, Jan Eric Lenssen, Frank Weichert, Heinrich Müller
https://arxiv.org/abs/1711.08920
(2) コンピュータグラフィックス基礎 第6回 曲線・曲面の表現「Bスプライン曲線」
三谷 純
https://mitani.cs.tsukuba.ac.jp/lecture/2020/cg_basics/06/06_slides.pdf