fewshot 学習
1年位ゆっくりやってfew shot 学習に取り組んでました。他の事もやってましたが。
特に取り組んだ手法はPT+MAPという手法。
1年前に参考サイト1で見て、 一番正解率がいいという安直な理由でこの手法をコピーしてみようと思いました。
実装自体はほぼ自作した結果、論文の正解率より少し低くなってしまいましたが、とりあえずいい勉強にはなりました。
手法のざっくり解説
まずはfew shot 学習の問題設定を示します。
- $D_{base}\cap D_{novel}=0$:問題データセット全体の集合
- $D_{base}$:K個の異なったクラスが存在し、すべてにラベルがつけられたデータセット、データの数は十分に大きい。
- $D_{novel}$:W個の異なったクラスが存在し(w-ways)、各クラスにはs個のラベルがついたデータ(s-shot)とq個のラベルのついていないデータが存在する。$D_{novel}$の中のデータの合計は$w(s+q)$であり、この中でラべルのついていないwq個
のデータのクラスを推定するのがfew shot学習の問題設定となります。
一般的な機械学習ではsは大きな数字を使う事が多いですが、このsを小さい状態で学習させる事が
できれば、小さいデータ(=労力)で汎化性能を得ることができるモデルということができます。
そして、この論文の手法自体はよくあるニューラルネットワーク(NN)の転移学習の一種です。
①$D_{base}$でNNを学習させる。
②学習したNNに対して、$D_{novel}$のデータ全てを入力し後ろ側の層で出力を取り出す。
ーこの後にPTを適用するのですが、PTは正の値の集合にしか適用できないため、この時にはreluの出力を取り出します。
③取り出した出力$\boldsymbol{V}_{kl}\in \boldsymbol{R}^{n \times w(s+q)}$の各特長量出力$\boldsymbol{V}_k = \boldsymbol{v}$に対してPTを適用する。
ーここで$k$は$0<k\leq{n}$である各特徴量のインデックスであり、$l$は$0<l\leq{w(q+s)}$である$D_{novel}$中の各クラスを表すインデックスです。
PTは以下の関数$f$を適用します。
$$
f(\boldsymbol{v})=\begin{cases}
\frac{(\boldsymbol{v}+\epsilon)^{\beta}}{||(\boldsymbol{v}+\epsilon)^{\beta}||_{2}} & {if}\quad b\neq 0\
\frac{\log (\boldsymbol{v}+\epsilon)}{|| \log (\boldsymbol{v}+\epsilon)||_{2}} & {if}\quad b=0
\end{cases}
$$
$\epsilon$は小さな数であり、ゼロ除算を防ぎます。
割り算の部分を実行していないPTもありますが、割り算部分で行っている事は単純な正規化です。
NNの出力分布は左図のように偏った分布になることが多いため、PTを用いて右図のようにガウス分布に近づけさせます。
④PTの出力の中でラベルの付いたデータの出力から平均を求め、それをクラスセンター$c_j$とする。
⑤この$c_j$に対して全ての出力を$f_i$とし$\boldsymbol{L}_{ij}=|\boldsymbol{f}_{i}-\boldsymbol{c}_{j}|^{2}$を計算する。
ここで$i$は$0<i\leq{wq}$である$D_{novel}$中でラベルのついていないデータのインデックスであり、
$j$は$0<j\leq{w}$である$D_{novel}$中の各ラベルのインデックスです。
⑥$L_ij$を輸送コストとしてsinkhorn法で輸送行列を求める。
輸送行列は次の式を最小化する$M^*$で表されます。
$$
M^{*} = \text{argmin}_{M\in U(\boldsymbol{p},\boldsymbol{q})} \sum_{ij}M_{ij}L_{ij}+\lambda H(M)
$$
ここで$U(\boldsymbol{p},\boldsymbol{q})=\{M \in[0,1]_+^{wq \times w}|M1_w=\boldsymbol{p},M^T1_{wq}=\boldsymbol{q}\}$
であり、輸送行列の制約を表します。また$H(M)=-\sum_{ij} M_{ij}logM_{ij}$であり、輸送行列のエントロピーを表します。sinkhorn法で求められた輸送行列$M$の各$M_i$は$d_i \in D_{novel}$がどれだけクラスjに割り当てされるかを表す行列となります。sinkhorn法について詳しくは他記事12にお願いしましょう。何より私もよく分かっていないので笑
⑦クラスセンターを更新していくために、まずは各クラスに割り当てられたベクトルの平均を計算する。
$$
\boldsymbol{\mu}_{j}=g(\boldsymbol{M}^{*},j)=
$$
$$
\frac{\sum_{i=1}^{wq}M^*_{ij}\boldsymbol{f}+\sum_{\boldsymbol{f}\in \boldsymbol{f}_s\ ,\ l(\boldsymbol{f}=j)}\boldsymbol{f}}{s+\sum_{i=1}^{wq}M^*_{ij}}
$$
ここで$\hat{l}(\boldsymbol{f}_{i})$はデータ$f_i$の割り当てられたラベルを表します
この平均を用いて下記の式でクラスセンターを更新します。
$$
\boldsymbol{c_j}=\boldsymbol{c_j} - \alpha (\boldsymbol{\mu_j} - \boldsymbol{c_j})
$$
ここで$0<\alpha \leq 1$は学習率です。
⑧ ⑤から⑦を任意の回数繰り返して、最後の更新で求まった$\boldsymbol{M}$に対して
$$
\hat{l}(\boldsymbol{d}_{i})=\text{argmax}_j(\boldsymbol{M}^*_{ij})
$$
として、各データのラベルを求める。
コード
この手法の特徴的な所はPTとsinkhorn法にあると思うので、それらのコードを示しておきます。
PT
def pt(v, beta):
eps = 1e-6
v = v - tf.math.reduce_min(v)
#like relu
# v = tf.where(v>0.0, v, 0)
#relu
v_plus_eps = v + eps
if beta == 0:
return tf.math.log(v_plus_eps) / tf.norm(v_plus_eps)
else:
v_plus_eps = tf.math.pow(v_plus_eps, beta)
ans = v_plus_eps / tf.norm(v_plus_eps)
return ans
元論文ではreluの出力を取り出すことでPTを適用できるようにしていましたが、今回の実装では
単純に最小値を減算して、全ての値を0以上にしました。この方法でもそれなりの正解率となっていますし、
PTのバージョン違いみたいなもので負の値を取れるようになる変換手法も存在します。このあたりは色々とバリエーション
があるかと思います。
PT
def sinkhorn_stab(C, a, b, L, epsilon):
K = tf.math.exp(-C / epsilon)
f = tf.ones(C.shape[0])
g = tf.ones(C.shape[1])
def softmin(C, axis):
z = tf.reduce_min(C, axis)
z_sub_C = z - C if z.shape[0] == C.shape[1] else tf.transpose(z - tf.transpose(C))
return z - epsilon * tf.reduce_logsumexp(z_sub_C / epsilon, axis=axis)
# 安定なsoftmin
def softmin2(C, axis):
return -epsilon * tf.reduce_logsumexp(-C / epsilon, axis=axis)
# 不安定なsoftmin
def S(C, f, g):
return tf.transpose(tf.transpose(C - g) - f)
l = 0
while l < L:
f = softmin(S(C, f, g), axis=1) + f + epsilon * tf.math.log(a)
g = softmin(S(C, f, g), axis=0) + g + epsilon * tf.math.log(b)
l += 1
P = tf.matmul(K, tf.linalg.diag(tf.math.exp(g / epsilon)))
P = tf.matmul(tf.linalg.diag(tf.math.exp(f / epsilon)), P)
return f, g, P
sinkhorn法にも数値的に安定な方法 2と不安定な方法3があるようですが、不安定な方法の方が実行速度は速いようです。
安定な方法はsoftminを計算するとき最小値を求める必要があるため、ソートが必要になり時間がかかります。
しかし、不安定な方法は本当に不安定でほとんどの場合で$M^*$がnanとなってしまいました。 そのため、後の実験では完全に不安定な方法を使用しています。
結果
cifar100を使って試した結果、5-wayで正解率が80~85%位は行くようになりました。
自作の実装コード
下記がハイパーパラメータをoptunaで試行してみた結果です。縦軸がクラスセンターの更新回数、横軸がPTのハイパーパラメータ、$\beta$の値となっています。
$\beta$が0.5あたりの所が一番正解率が高い結果となっています。
また、論文ではクラスセンターを更新する手法をとっていましたが、更新せずとも一回でラベルを求めた場合が一番の結果となりました。
結果図からは、むしろ更新を行うと結果が悪化する場合もあるようです。
論文の結果では最高正解率90%まで行っているようですので、事前学習が足りないとかあるんでしょうか?
80~85%でも高い方だとは思うのでとりあえずは満足してしまいました笑
考察
PTを採用するという点は面白いと思います。こんな手法があるとは知りませんでした。確かにNNの出力分布はそのままでは歪すぎな印象でそのためにbatch_normalizationなどの正規化で分布の調整が行われています。PTも微分可能なため、batch_normの代わりに使うというのも面白いかもしれません。ただし、ハイパーパラメータがあるため、そこの調整が必要ですが。
sinkhorn法も面白い手法だと思います。参考サイト3にもあるのですが、この手法は微分もできるため、この手法の途中でさらに学習することも可能です。さらにgpuとも相性が良いです。不安定な方法もメインは行列演算ですし、安定な方法も総和を取る所 4や最小値を探す所5はgpuで行う事で高速化が見込めます。
しかし、この手法には明確な弱点もあります。sinkhorn法では各クラスのデータ数が同じであるという仮定を置いています。この仮定は強力すぎて、各ラベルのデータ数が不均衡な場合に精度が悪化する可能性がありますし、現実のデータはその傾向が強いでしょう。
あとがき
few-shot leaningについてですが、やはり少ないデータで学習が終える事ができれば、それに越した事はないのですが、やはり事前知識というものはあった方がいいのだと思います。
生まれたての鳥が動いている物体を親と認識することも、視界の中で動いている物体が親になりうる可能性が一番高いからという事前知識があっての物だと思うので、何らかのバイアスはやはり必要なのだと思います。実際、GPT-3などの手法は大きな事前知識で未知のデータに対応しようというアプローチです。老人か若者、どちらが対応力がありそうかと聞かれれば、やはり老人の方が対応力がありそうな気がします。最終目標はARK6を解くことで、ARKのデータだけでは少なすぎるような感じもしますので、ほかのデータセットで事前知識を手に入れるという事が必要になってくると考えています。
合わせて、fewshot学習ではやはりデータ間での比較が必要なのかなと思います。今のdeepleaningで各データ間を比較するもので一番よくつかわれているのはbatch normだと思いますが、batch normだけでは単純すぎるような気もするのでもっと凝った手法が今後出てくるのかなーと。
参考論文