0. はじめに
この記事では「ソート」を「微分」する方法と、ソートの一般化について説明します。ソートの一般化を考えることで、quantile関数の「微分」も計算できるようになります。
また、実際にソートの微分と一般化を利用して、least quantile regressionという機械学習のタスクの実験を行います。
最近Googleが発表し注目されている次の論文の内容の紹介になります。
Differentiable Ranks and Sorting using Optimal Transport
輸送問題を微分する記事と似たようなテクニックを用いるので一部説明がかぶります。
ソートの微分?
我々がソートを利用するとき、次のような2つの関数、ソート関数$S(x)$とランク関数$R(x)$の形で利用することが多いと思います。
x=(x_1,\dots ,x_n) \in R^n \\
S(x)=(x_{\sigma _1},\dots ,x_{\sigma _n}) \ \ \ \ \ \ x_{\sigma _1}\leq x_{\sigma_2}\leq \dots \leq x_{\sigma _n} \\
R(x)=(rank(x_1),\dots ,rank(x_n)) \ \ (=\sigma^{-1})
例えば$S((3.4, 2.3, -1))=(-1,2.3,3.4)$、$R((3.4, 2.3, -1))=(3,2,1)$です。
本記事では「ソート」の「微分」として、
$$
\frac{\partial S(x)}{\partial x_i}や\frac{\partial R(x)}{\partial x_j}
$$
(のような何か)を考えることにします。
※$S(x),R(x)$はそのままでは微分可能ではありません。
ソートの微分の応用
ソートの「微分」ができると何がうれしいのでしょうか。
一番の利点は、上記論文にあげられているようにNeural Networkの出力をソートして何かを解くタイプのタスクがend to end で学習できることです。
例えば
- 数字の書かれた複数の画像が入力されたとき、数字の大きい順に画像のidを返す
- 商品に対してレコメンドスコアを付与し、ランキングロスを直接微分してレコメンドモデルを学習する
- NLPのタスクで、微分可能なビームサーチdecoderを利用する
といった問題が考えられます。
また、ソートが微分できることを利用すると、後述するようにquantile関数も微分できるようになるので「回帰誤差のn%値を最小化する」least quantile regression というタスクが直接目的関数を微分する勾配降下法を利用して解けるようになります。
記号
- $O^n$
$\{ x|x\in R_n,x_1 \leq x_2 \leq \dots \leq x_n \}$ - $R^n_+$
$\{ x|x\in R_n,x_i \geq 0,i=1,\dots ,n \}$ - $R^n_{+,*}$
$\{ x|x\in R_n,x_i > 0,i=1,\dots ,n \}$ - $1_n$
$(1,...,1) \in R^n$ - $\Sigma_n$
$\{x|x \in R^n_+, \sum x_i=1 \}$ - $U(a,b)$
$a \in R^n_+, b \in R^m_+, U(a,b)= $ $\{ P|P\in R^{n\times m}_+, P1_m=a, P^T1 _n=b \} $ - $\langle P,Q\rangle$
$P\in R^{n\times m},Q\in R^{n\times m},\langle P,Q\rangle=\sum_{i,j}P_{i,j}Q_{i,j}$ - $diag(a)$
$a \in R^n , diag(a):$ 対角成分が$a$である$n\times n$の対角行列
1. ソートと輸送問題と微分
さて、$S(x),R(x)$はそのままでは微分可能ではありません。さらに、例えば$x$を$R^n$からランダムにサンプリングした場合、$x_1$を少し増減させた程度では$rank(x_1)$は変化しないはずです。つまり、「微分っぽいなにか」が定義できたとしても常に0になってしまう気がします。
これらの問題を解決してソートの微分を計算するために次のようなステップを踏みます。
- ソートをある種の輸送問題とみなす。
- この輸送問題に正則化項を加えた問題を考える。(唯一の最適解が存在する。)
- 正則化項つき輸送問題は、微分可能な形で最適解を求める近似アルゴリズム(Shinkhornアルゴリズム)が存在する。
- 2をShinkhornアルゴリズムで解き、出力を微分する。
以下、順に説明します。
ソートと輸送問題
輸送問題
輸送問題はその名の通り、複数の工場から複数の店舗への最適な商品の輸送の仕方を決める問題です。
各工場と店舗間の配送には配送料に応じたコストがかかり、総配送コストを最小化するように各輸送量を決めます。
数式で書くと、$a\in R^n_+,b\in R^m_+, C\in R^{n\times m}_+$が与えられたとき 、$\langle P,C\rangle$を最小化する$P\in U(a,b)$を求めることです。
この記事では
$L_C(a,b)=min_{P\in U(a,b)}\langle P,C\rangle$
と書きます。
$a$が工場の供給量、$b$が店舗の需要量、$C_{i,j}$が工場$i$店舗$j$間の単位量当たりの輸送コスト、$P_{i,j}$が工場$i$店舗$j$間の輸送量、$\langle P,C\rangle$が輸送の総コストですね。最適な輸送を行った場合の総コストが$L_C(a,b)$です。
特別な輸送問題とソートの対応
さて、次のような単純な輸送問題を考えます。
- 工場と店舗の数は同じで$n$個
- 工場と店舗は1直線上に並んでいる
- $x,y\in R _n$をそれぞれ工場と店舗の座標とする。
- 各店舗の位置は$y _1 < y _2 < \dots < y _n$という関係を満たしているとする。($y\in O^n$)
- 工場店舗間の距離が遠くなるにしたがって配送コストが大きくなる。微分可能な非負の狭義凸関数$h$を利用して$C _{i,j}=h(y _j -x _i)$とかける。
- 工場の供給量と店舗の需要量はすべて同じで$1/n$、つまり$a=b=1 _n /n$
この時、ソートと輸送問題を結びつける次の命題が成り立ちます。
命題 1.
上記の状況のもと、輸送問題$L_C(a,b)=min_{P\in U(a,b)}\langle P,C\rangle$の最適解の一つを$P _* $とする。このとき次が成り立つ。
R(x)=n^2 P_* \hat{b} \\
S(x)=n P_*^T x \\
ここで$\hat{b}=(b_1, b_1+b_2, \dots ,\sum b)^T = (1/n, 2/n,\dots, 1)^T$
実際、次のような輸送問題を考えてみます。
工場id | 工場座標 | 供給量 | 店舗id | 店舗座標 | 需要量 | |
---|---|---|---|---|---|---|
1 | 2 | 1/3 | a | 0 | 1/3 | |
2 | 1 | 1/3 | b | 1 | 1/3 | |
3 | 0 | 1/3 | c | 2 | 1/3 |
輸送コスト(=距離の二乗)
工場\店舗 | a | b | c |
---|---|---|---|
1 | 4 | 1 | 0 |
2 | 1 | 0 | 1 |
3 | 0 | 1 | 4 |
最適な輸送量は次のようになるはずです。
工場\店舗 | a | b | c |
---|---|---|---|
1 | 0 | 0 | 1/3 |
2 | 0 | 1/3 | 0 |
3 | 1/3 | 0 | 0 |
これらを命題の式に代入してみます。
3^2 \left(
\begin{array}{ccc}
0 & 0 & 1/3 \\
0 & 1/3 & 0 \\
1/3 & 0 & 0
\end{array}
\right)
\left(
\begin{array}{ccc}
1/3 \\
2/3 \\
1
\end{array}
\right) =
\left(
\begin{array}{ccc}
3 \\
2 \\
1
\end{array}
\right) = R(
\left(
\begin{array}{ccc}
2 \\
1 \\
0
\end{array}
\right)
)
3 \left(
\begin{array}{ccc}
0 & 0 & 1/3 \\
0 & 1/3 & 0 \\
1/3 & 0 & 0
\end{array}
\right)
\left(
\begin{array}{ccc}
2 \\
1 \\
0
\end{array}
\right) =
\left(
\begin{array}{ccc}
0 \\
1 \\
2
\end{array}
\right) = S(
\left(
\begin{array}{ccc}
2 \\
1 \\
0
\end{array}
\right)
)
命題の式が成り立っていることが確認できます。
輸送問題と微分
前章では特別な輸送問題の解を利用して、$S(x),R(x)$が書き下せることを確認しました。したがって、輸送問題の解(命題 1の$P _*$)の$C _{i,j}$による微分が計算できれば、$S(x),R(x)$の$x _i$による微分も計算できることになります。この$P _*$自身は微分可能ではないのですが、$P _*$の近似解を微分可能な形で求める方法が存在します。
すなわち、まず次のような"正則化項"つきの輸送問題を考えます。
つまり、輸送量のエントロピーを
$H(P)=-\sum_{i,j}P_{i,j}(log(P_{i,j})-1)$
として、もとの問題の代わりに
$L_C^{\epsilon}(a,b)=min_{P\in U(a,b)}\langle P,C\rangle - \epsilon H(P)$ ★
を考えます。
$\epsilon \to 0$でこの正則化項つき輸送問題の解は、元の輸送問題の解に収束します。
また、次のShinkhornアルゴリズムによって、近似解を微分可能な形で求めることができます。
Shinkhornアルゴリズム
init $u=u^0,v=v^0,l=0$, calc K;
while $l$ < MAX_ITER:
$\ \ \ \ u=a/(Kv)$
$\ \ \ \ v=b/(K^Tu)$
$\ \ \ \ l++$
$P=diag(u)Kdiag(v)$
return $P$
MAX_ITER$\to \infty$とするとShinkhornアルゴリズムの出力は★の最適解に収束します。
このShinkhornアルゴリズムに関しては
で詳しく解説しました。
PyTorchによる実装
PyTorchでShinkhornアルゴリズムを実装し、それを利用してソートの微分を計算してみます。
import torch
from torch import nn
# Shinkhornアルゴリズム
class OTLayer(nn.Module):
def __init__(self, epsilon):
super(OTLayer,self).__init__()
self.epsilon = epsilon
def forward(self, C, a, b, L):
K = torch.exp(-C/self.epsilon)
u = torch.ones_like(a)
v = torch.ones_like(b)
l = 0
while l < L:
u = a / torch.mv(K,v)
v = b / torch.mv(torch.t(K),u)
l += 1
return u, v, u.view(-1,1)*(K * v.view(1,-1))
# sort & rank
class SortLayer(nn.Module):
def __init__(self, epsilon):
super(SortLayer,self).__init__()
self.ot = OTLayer(epsilon)
def forward(self, x, L):
l = x.shape[0]
y = (x.min() + (torch.arange(l, dtype=torch.float) * x.max() / l)).detach()
C = ( y.repeat((l,1)) - torch.t(x.repeat((l,1))) ) **2
a = torch.ones_like(x) / l
b = torch.ones_like(y) / l
_, _, P = self.ot(C, a, b, L)
b_hat = torch.cumsum(b, dim=0)
return l**2 * torch.mv(P, b_hat), l * torch.mv(torch.t(P), x)
sl = SortLayer(0.1)
x = torch.tensor([2., 8., 1.], requires_grad=True)
r, s = sl(x, 10)
print(r,s)
tensor([2.0500, 3.0000, 0.9500], grad_fn=<MulBackward0>) tensor([1.0500, 2.0000, 8.0000], grad_fn=<MulBackward0>)
(微分の計算)
r[0].backward()
print(x.grad)
tensor([ 6.5792e-06, 0.0000e+00, -1.1853e-20])
2. ソートの一般化とquantile関数
ソートの一般化
前章では特別な輸送問題の解を用いてソート関数とランク関数$S(x),R(x)$が書き下せることを見ました。このソートに対応する特別な輸送問題は、
- 工場の数と店舗の数は同じ
- 工場の供給量と店舗の需要量はすべて同じ
としていました。しかし、一般には、工場と店舗の数が異なっても、工場や店舗によって供給量や需要量が異なっても、輸送問題を考えることができます。これら一般の輸送問題の解に対応する$S(x),R(x)$を考えれば、ソートやランクの一般化ができるはずです。
このような考えのもと、Differentiable Ranks and Sorting using Optimal Transportでは次のような一般化されたソート関数とランク関数である、K(Kantorovich)ソートとKランクが導入されました。
定義 1. KソートとKランク
任意の $ x\in R^n,y\in O^n,a \in \Sigma_n, b\in \Sigma_m $ と狭義凸関数 $h$ に対して、輸送問題
$$
L_C(a,b)=min_{P\in U(a,b)}\langle P,C\rangle, C _{i,j}=h(y _j - x _i)
$$
の最適解の一つを$P _*$とおく。
このとき$x$のKソートとKランクをそれぞれ次のように定義する。
\hat{S} (a,x,b,y)=diag(b^{-1})P^T _* x \\
\hat{R} (a,x,b,y)=n* diag(a^{-1})P_* \hat{b}
ここで$a^{-1},b^{-1}$はベクトルの各要素の逆数をとったものを表す。
$P_* \in U(a,b)$なので
P_* 1_n = a \\
P^T_* 1_m = b
を満たします。したがって$\hat S,\hat R$の$diag(b^{-1})P^T _*$と$diag(a^{-1})P
_*$の部分は$P^T _*,P _*$の行の和がすべて1になるように正規化しているとみなせます。また$a=b=1 _n/n$とすると元の$S,R$になるので、ソート関数とランク関数の自然な拡張になっています。
また、$b=1_n/n$のとき$n\hat b$は「ランク」を順に並べたもの$(1,2,\dots,n)$になることに注意してください。したがって一般の$b$に対応する$\hat b$は「ランク」の一般化とみなせ、実数値をとります。この点を考慮に入れると
- $\hat S _i$は$i$番目の順位(ランク)に割り当てられる$x _i$の線形結合
- $\hat R _j$は$x _j$が割り当てられる一般化された順位(ランク=$\hat b$)の線形結合
になっていることも確認できます。
通常のソート、ランクとKソート、Kランクの違いを下のような模式図で表してみました。KソートとKランクは$a=1 _5/5,b=(0.48,0.16,0.36)$の場合を表しています。
quantile関数
quantile関数とは、
$$
q(x,\tau)=xの\tau %点
$$
なる関数$q$のことです。実はKソート$\hat S$を利用してquantile関数を効率的に計算することができます。
例えば、$x\in R^n$をデータ点の集合、$a=1_n/n, y=(0,1/2,1), b=(0.29,0.02,0.69)$なる状況でKソート$\hat S (a,x,b,y)$を考えてみます。
上の図のように、xの下位30%は$y_1$に、上位60%は$y_3$に対応付けられ、xの30%点近くの点のみ$y_2$に対応付けられると想像できます。
したがってquantile関数はKソートとある程度小さい値tを利用して、
$$
q(x,\tau ;t)=\hat S (1_n/n, x, (\tau /100 -t/2,t,1-\tau /100-t/2),(0,1/2,1))[2]
$$
とかけると期待できます。前章で説明したようにShinkhornアルゴリズムを利用すれば$\hat S$の微分可能な近似を求めることができるので、quantile関数の微分も計算できることになります。
また、上記のKソートは$O(nl)$($l$はShinkhornアルゴリズムのiteration数)の計算オーダーで済むことに注意してください。
PyTorchによる実装
実際にPyTorchでKソートを実装してquantile関数を計算してみます。前章のOTLayerを利用します。
class QuantileLayer(nn.Module):
def __init__(self, epsilon):
super(QuantileLayer,self).__init__()
self.ot = OTLayer(epsilon)
self.y = torch.Tensor([0.,0.5,1.])
def forward(self, x, tau, t, L):
l = x.shape[0]
C = ( self.y.repeat((l,1)) - torch.t(x.repeat((3,1))) ) **2
a = torch.ones_like(x) / l
b = torch.Tensor([tau-t/2, t, 1-tau-t/2])
_, _, P = self.ot(C, a, b, L)
b_hat = torch.cumsum(b, dim=0)
return (torch.mv(torch.t(P), x) / b)[1]
(30%点を求める)
import numpy as np
np.random.seed(47)
x = np.random.rand(1000)
quantile = QuantileLayer(0.1)
print(quantile(torch.tensor(x,dtype=torch.float), 0.3, 0.1, 10))
tensor(0.3338)
3. ソートの微分と一般化の応用
ソートの微分と一般化の機械学習への簡単な応用として、least quantile regression というタスクを、quantile関数を用いた目的関数を直接微分し勾配法を実行することで解いてみます。
least quantile regression
普通の線形回帰では予測誤差の「平均」を最小化するようにモデルを最適化しますが、least quantile regression は誤差の「n%値」を最小化するようにモデルを学習します。これには誤差の「50%値」である中央値を最小化するタスクも含まれます。
例えば
- データのノイズがガウス分布でなく、教師値が特定の方向へ大きくずれたりする。
といった状況で、外れ値的なデータを無視して、誤差の「中央値」を最小化したい、といった場合に有用です。
この記事では、「平均」と「中央値」の違いを説明するのによく用いられる年収のデータを用いてleast quantile regressionの実験を行います。上場企業の平均年齢と平均年収のデータを利用して、年齢から年収を予測するモデルを作成してみます。本記事で解説した手法を用いて「誤差の中央値」を直接微分して勾配降下法を実行することでモデルを学習します。
データはyutakikuchiさんがgithubで公開しているものを使用させていただきました。
実験
データから平均年齢45才以下の企業を抽出して利用します。(より通常の線形回帰との差異を観察しやすくするためです。)
おおむね年齢と年収は比例しているようですが、ノイズは正規分布でなく、年収が大きい方向へ分散が広がっていることが確認できます。通常の線形回帰モデルは外れ値に引きずられ、勾配を大きく見積もってしまうと予想できます。
本記事で開設した手法で、誤差の「中央値」(=50%点)を直接微分することで線形モデルを学習するコードは次のようになります。(データ整形の部分は省略してあります。)
(Sinkhorn アルゴリズムを実行するlayer)
import torch
from torch import nn
import torch.optim as optim
class OTLayer(nn.Module):
def __init__(self, epsilon):
super(OTLayer,self).__init__()
self.epsilon = epsilon
def forward(self,
C, # batch_size, n, m
a, # batch_size, n
b, # batch_size, m
L):
bs = C.shape[0]
K = torch.exp(-C/self.epsilon) # batch_size, n, m
u = torch.ones_like(a).view(bs,-1,1) # batch_size, n, 1
v = torch.ones_like(b).view(bs,-1,1) # batch_size, m, 1
l = 0
u = a.view(bs,-1,1) / (torch.bmm(K,v) + 1e-8) # batch_size, n, 1
v = b.view(bs,-1,1) / (torch.bmm(torch.transpose(K, 1, 2),u) + 1e-8)# batch_size, m, 1
l += 1
return u * K * v.view(bs,1,-1) # batch_size, n, m
(quantile関数を計算するlayer)
class QuantileLayer(nn.Module):
def __init__(self, epsilon, y):
super(QuantileLayer,self).__init__()
self.ot = OTLayer(epsilon)
self.y = y.detach()
def forward(self, x, # batch_size, seq_len
tau, t, L):
bs = x.shape[0]
seq_len = x.shape[1]
C = ( self.y.repeat((bs,seq_len,1)) - x.unsqueeze(-1).expand(bs,seq_len,3) ) **2 # batch_size, seq_len, 3
a = torch.ones_like(x) / seq_len # batch_size, seq_len
b = torch.Tensor([tau-t/2, t, 1-tau-t/2]).expand([bs, 3]) # batch_size, 3
P = self.ot(C, a, b, L) # batch_size, seq_len, 3
k_sort = torch.bmm(
torch.transpose(P,1,2), # batch_size, 3, seq_len
x.unsqueeze(-1) # batch_size, seq_len, 1
).view(-1) / b.view(-1) # 3,
return k_sort[1]
(年齢=ageと年収=incomeのデータ準備)
import pandas as pd
data = pd.DataFrame({"age":ages, "income":incomes})
data_2 = data[data.age <= 45]
ppd_data = (data_2- data_2.mean(axis=0))/data_2.std(axis=0)
X = torch.Tensor(ppd_data.age.values.reshape(-1,1))
ans = torch.Tensor(ppd_data.income.values.reshape(-1,1))
(学習実行)
model = nn.Linear(1,1)
loss = nn.MSELoss(reduction='none')
y = [0, ppd_data.income.max()/4., ppd_data.income.max()/2.]
quantile = QuantileLayer(0.1, torch.Tensor(y))
optimizer = optim.Adam(model.parameters(), lr=0.1)
MAX_ITER = 100
for i in range(MAX_ITER):
optimizer.zero_grad()
pred = model(X)
loss_value = loss(pred, ans).view(1,-1) # 1, seq_len( = data size)
# 誤差の中央値を計算
quantile_loss = quantile(loss_value, 0.5, 0.1, 10)
print(quantile_loss)
quantile_loss.backward()
optimizer.step()
学習したモデルと、誤差の「平均」を最小化する通常の回帰モデルのフィッティング結果を比較したものは、次のようになります。中央値を最小化(=quantile)したモデルの方が外れ値データに引きずられにくいことが確認できます。
4. まとめ
この記事ではソートを微分する方法と、ソートの一般化として微分可能な形でquantile関数を計算する方法を説明しました。また簡単な応用としてleast quantile regression というタスクを勾配法で解き、誤差の「平均値」の最小化するモデルと「中央値」を最小化するモデルの違いを観察しました。
冒頭で紹介したように、ソートの微分はビームサーチの微分などより幅広い応用が見込め、今後さまざまな論文で見かけることになるかもしれません。