Python
MachineLearning
matrix-factorization
NMF

欠損値のある行列の非負値行列因子分解

欠損値のある行列の非負値行列因子分解

非負値行列因子分解(NMF)はQiitaでもいくつか紹介されていますが、
欠損値のある場合についての情報が見当たらなかったので、紹介します。

参考文献

Lee, D. D., & Seung, H. S. (n.d.). Algorithms for Non-negative Matrix Factorization.

http://r9y9.github.io/blog/2013/07/27/nmf-euclid/

http://sap.ist.i.kyoto-u.ac.jp/members/yoshii/slides/mus91-tutorial-nmf.pdf

問題設定

NMFでは、二次元の計測データYについて以下のコスト関数を最小化する非負の行列 W および H を求めます。

  • 二乗距離
    $$
    \left|| \mathrm{Y} - \mathrm{WH} \right||^2
    = \frac{1}{2}\sum_{i j} \left\{
    y_{i j} - \sum_k w_{i k} h_{k j}
    \right\}^2
    \tag{1}
    $$

  • 一般化KL距離
    $$
    \mathrm{KL}\left[\mathrm{Y} || \mathrm{WH} \right]
    = \sum_{i, j} \left\{
    y_{i j} \log\frac{y_{i j}}{\sum_k w_{i k} h_{k j}}
    - y_{i j} + \sum_k w_{i k}h_{k j}
    \right\}
    \tag{2}
    $$

欠損値があるデータYの場合を考えます。
上式の総和の部分を欠損でない要素の組
$\mathcal{I}=\{(i_0, j_0), (i_1, j_1)... \}$
についての総和に置き換えます。

  • 二乗距離
    $$
    \begin{aligned}
    \left|| \mathrm{Y} - \mathrm{WH} \right||^2
    &= \frac{1}{2}\sum_{i, j \in \mathcal{I}} \left\{
    y_{i j} - \sum_k w_{i k} h_{k j}
    \right\}^2 \\
    &= \frac{1}{2}\sum_{i, j}\left\{
    \left(y_{i j} - \sum_k w_{i k} h_{k j}\right)m_{i,j}
    \right\}^2
    \end{aligned}
    \tag{3}
    $$

  • 一般化KL距離
    $$
    \begin{aligned}
    \mathrm{KL}\left[\mathrm{Y} || \mathrm{WH} \right]
    &= \sum_{i, j\in \mathcal{I}} \left\{
    y_{i j} \log\frac{y_{i j}}{\sum_k w_{i k} h_{k j}}
    - y_{i j} + \sum_k w_{i k}h_{k j}
    \right\}\\
    &= \sum_{i, j} \left\{
    \left(y_{i j} \log\frac{y_{i j}}{\sum_k w_{i k} h_{k j}}
    - y_{i j} + \sum_k w_{i k}h_{k j}\right)m_{i j}
    \right\}
    \end{aligned}
    \tag{4}
    $$

ここでそれぞれの二行目の式は、有効な要素に関する総和をマスク行列
$\mathrm{M}=\{m_{i j}\}$ で置き換えたものです。
マスク行列M は0と1からなる行列で、Yの欠損値に対応する要素には0、それ以外には1 が入っています。

更新式

二乗距離、一般化KL距離を採用した時のそれぞれの更新式は以下のようになります。
なお、これらの導出については後ほどで述べます。

二乗距離(欠損値なし)

$$
w_{i, k} \gets w_{i k}
\frac{
\sum_j y_{i j} h_{k j}
}
{
\sum_{j k'} w_{i k'} h_{k' j} h_{k j}
}
$$

$$
h_{k j} \gets h_{k j}
\frac{
\sum_i w_{i k} y_{i j}
}
{
\sum_{i k'} w_{i k} w_{i k'} h_{k' j}
}
$$

Python では、行列積と要素演算を使うと非常に簡単に記述できます。

def update_squared_loss(Y, W, H):
  """  Update W and H with squared loss  """
  W = W * np.dot(Y, H.T) / np.dot(np.dot(W, H), H.T)
  H = H * np.dot(W.T, Y) / np.dot(W.T, np.dot(W, H))
  return W, H

二乗距離(欠損値あり)

$$
w_{i k} \gets w_{i k}
\frac{
\sum_j y_{i j} m_{i j} h_{k j}
}
{
\sum_{j k'} w_{i k'} h_{k' j} h_{k j} m_{i j}
}
$$

$$
h_{k j} \gets h_{k j}
\frac{
\sum_i w_{i k} y_{i j} m_{i j}
}
{
\sum_{i k'} w_{i k} w_{i k'} h_{k' j} m_{i j}
}
$$

Python 擬似コード

def update_squared_loss_w_mask(Y, W, H, M):
  """  Update W and H with squared loss with missing values """
  W = W * np.dot(Y * M, H.T) / np.dot(np.dot(W, H) * M, H.T)
  H = H * np.dot(W.T, Y * M) / np.dot(W.T, np.dot(W, H) * M)
  return W, H

一般化KL距離(欠損値なし)

$$
w_{i k} \gets w_{i k}
\frac{
\sum_j \frac{y_{i j}}{\sum_{k'}w_{i k'} h_{k' j}} h_{k j}
}
{
\sum_{j} h_{k j}
}
$$

$$
h_{k j} \gets h_{k j}
\frac{
\sum_i \frac{y_{i j}}{\sum_{k'}w_{i k'} h_{k' j}} w_{i k}
}
{
\sum_{i} w_{i k}
}
$$

Python 擬似コードは以下のとおりです。

def update_kl_loss(Y, W, H):
  """  Update W and H with squared loss  """
  W = W * np.dot(Y / np.dot(W, H), H.T) / np.sum(H, axis=1, keepdims=True)
  H = H * np.dot(W.T, Y / np.dot(W, H)) / np.sum(W, axis=0, keepdims=True)
  return W, H

一般化KL距離(欠損値あり)

$$
w_{i k} \gets w_{i k}
\frac{
\sum_j \frac{y_{i j} m_{i j}}{\sum_{k'}w_{i k'} h_{k' j}} h_{k j}
}
{
\sum_{j} h_{k j} m_{i j}
}
$$

$$
h_{k j} \gets h_{k j}
\frac{
\sum_i \frac{y_{i j} m_{i j}}{\sum_{k'}w_{i k'} h_{k' j}} w_{i k}
}
{
\sum_{i} w_{i k} m_{i j}
}
$$

Python 擬似コードは以下のとおりです。

def update_kl_loss_w_mask(Y, W, H, M):
  """  Update W and H with squared loss  """
  W = W * np.dot(Y * M / np.dot(W, H), H.T) / np.dot(M, H.T)
  H = H * np.dot(W.T, Y * M / np.dot(W, H)) / np.sum(W.T, M)
  return W, H

更新式の導出

二乗距離(欠損値なし)

まず、コスト関数(1)を W に関して最小化することにします。
Wに関するコスト関数を F(W) とおくと、

$$
\begin{aligned}
\mathrm{F(W)} &= \frac{1}{2}\sum_{i j} \left\{
y_{i j} - \sum_k w_{i k} h_{k j}
\right\}^2 \\
&= \frac{1}{2}\sum_{i j} \left\{
y_{i j}^2 - 2 y_{i j} \sum_k w_{i k} h_{k j}
+ \left( \sum_k w_{i k} h_{k j} \right)^2
\right\}
\end{aligned}
\tag{5}
$$
のように書けます。

方針としては、F(W) を上から抑える単純な関数Gを求め、その関数について W を最適化します。
具体的には、以下の性質を持つ関数 G(W; W') を探します。

$$
\mathrm{
F(W) = G(W; W), \;\;\; F(W) \leq G(W; W')
}
\tag{6}
$$

上から抑える関数のイメージについては、
チュートリアル:非負値行列因子分解
などを参考にしてください。

ここで、Jensenの不等式を思い出します。
Jensenの不等式では 下に凸の関数 $f(x)$ に対して、 $\sum_k r_k = 1$ となる $r_k$ を用いて

$$
f\left(\sum_k r_k x_k\right)
\leq
\sum_k r_k f(x_k)
$$

の関係が成り立ちます。
これを式5の第3項に適用すると $x^2$ が下に凸であることから、

$$
\begin{aligned}
\left( \sum_k w_{i k} h_{k j} \right)^2
&= \left( \sum_k r_{i j k} \frac{w_{i k} h_{k j}}{r_{i j k}} \right)^2 \\
&\leq \sum_k r_{i j k} \left(\frac{w_{i k} h_{k j}}{r_{i j k}} \right)^2
\end{aligned}
$$

を示せます。
なお $r_{i j k}$ は $\sum_k r_{i j k} = 1$ を満たす定数です。
天下り的ですが、
$$
r_{i j k} = \frac{w_{i k'} h_{k j}}{\sum_{k'} w_{i k'} h_{k' j}}
$$
としたときに、以下が式6を満たすことを示します。
$$
\begin{aligned}
\mathrm{G(W; W')}=
\frac{1}{2}\sum_{i, j} \left\{
y_{i j}^2 - 2 y_{i j} \sum_k w_{i k} h_{k j}
+ \sum_k r_{i j k} \left(\frac{w_{i k} h_{k j}}{r_{i j k}}\right)^2
\right\}
\end{aligned}
\tag{7}
$$

まず、 $\mathrm{F(W) \leq G(W; W')}$ は上に述べた通り、Jensenの不等式から示すことができます。
また $\mathrm{F(W) = G(W; W)}$ も第3項を約分すると示せます。
そのため式7を最小化するWを用いると、$\mathrm{F(W)}$ は必ず $\mathrm{F(W')}$ より小さくなることがわかります。

式7を最小化するWは、勾配がゼロの点から求めます。

$$
\begin{aligned}
\frac{\partial \mathrm{G(W; W')}}{w_{i k}} = 0
\;\;\to\;\;
w_{i k} &= \frac{\sum_j h_{k j}}{\sum_{k, j} \frac{h_{k j}^2}{r_{i j k}}}
= w'_{i k} \frac{\sum_j y_{i j} h_{k j}}
{\sum_{j, k'} w'_{i k} h_{k' j} h_{k j}}
\end{aligned}
$$
となって、上で述べた更新式が示せます。

また、式の対称性を考えるとHに関する更新式も同様に示せます。

二乗距離(欠損値あり)

欠損値ありの場合も、マスク行列を用いることで同様に更新式を導くことができます。

$$
\begin{aligned}
\mathrm{G(W; W')} =
\frac{1}{2}\sum_{i, j} \left\{
y_{i j}^2 m_{i j}^2 - 2 y_{i j} \sum_k w_{i k} h_{k j} m_{i j}
+ \sum_k {r_{i j k}}\left(
\frac{w_{i k} h_{k j}}{r_{i j k}}
\right)^2 m_{i j}
\right\}
\end{aligned}
$$

上式が式6の関係を満たすため、勾配がゼロになる点を探すと

$$
\begin{aligned}
\frac{\partial \mathrm{G(W; W')}}{w_{i k}} = 0
\;\;\to\;\;
w_{i k} &= \frac{\sum_j h_{k j} m_{i j}}
{\sum_{k j} \frac{h_{k j}^2}{r_{i j k}} m_{i j}}
= w'_{i k} \frac{\sum_j y_{i j} m_{i j} h_{k j}}
{\sum_{j k'} w'_{i, k} h_{k' j} h_{k j} m_{i j}}
\end{aligned}
$$

となり、上の更新式を示せます。

一般化KL距離(欠損値なし)

同様に、コスト関数(2)を W に関して最小化することにします。
Wに関するコスト関数を F(W) とおくと、

$$
\begin{aligned}
\mathrm{F(W)} &=
\left\{
y_{i j} \log y_{i j} - y_{i j} \log \sum_k w_{i k} h_{k j}
+ \sum_k w_{i k} h_{k j}
\right\}
\end{aligned}
$$

右辺括弧内第二項を上から抑える関数を探すことにします。
$-\log x$ が下に凸であることから、
$$
-\log \sum_k w_{i k} h_{k j} \leq
-\sum_k r_{i j k} \log \frac{w_{i k} h_{k j}}{r_{i j k}}
$$
を満たします。
なお $r_{i j k}$ は $\sum_k r_{i j k} = 1$ を満たす定数であり、
$$
r_{i j k} = \frac{w_{i k} h_{k j}}{\sum_{k'} w_{i k'} h_{k' j}}
$$
としたときに等号が成立します。

そのため、以下の $\mathrm{G(W; W')}$ は式6の関係をを満たします。
$$
\begin{aligned}
\mathrm{G(W; W')} &=
\left\{
y_{i j} \log y_{i j}
- y_{i j} \sum_k r_{i j k} \log \frac{w_{i k} h_{k j}}{r_{i j k}}
+ \sum_k w_{i k} h_{k j}
\right\}
\end{aligned}
$$

上式の勾配をゼロにする W とすることが、コスト関数を減少させる更新式を求めることになります。
$$
\begin{aligned}
\frac{\partial \mathrm{G(W; W')}}{w_{i k}} = 0
\;\;\to\;\;
w_{i k} &= \frac{\sum_j y_{i j} r_{i j k}}
{\sum_j h_{k j}}
= w'_{i k} \frac{\sum_j \frac{y_{i j}}{\sum_{k'} w'_{i k'} h_{k' j}} h_{k j}}
{\sum_j h_{k j}}
\end{aligned}
$$

一般化KL距離(欠損値あり)

これも上と同様に、マスク行列を用いることで更新式を得ることができます。

$$
\begin{aligned}
\mathrm{G(W; W')} &=
\left\{
y_{i j} m_{i j} \log y_{i j}
- y_{i j} m_{i j} \sum_k r_{i j k} \log \frac{w_{i k} h_{k j}}{r_{i j k}}
+ m_{i j} \sum_k w_{i k} h_{k j}
\right\}
\end{aligned}
$$

上式が式6の関係を満たすため、勾配がゼロになる点を探すと

$$
\begin{aligned}
\frac{\partial \mathrm{G(W; W')}}{w_{i k}} = 0
\;\;\to\;\;
w_{i k} &= \frac{\sum_j y_{i j} m_{i j} r_{i j k}}
{\sum_j h_{k j} m_{i j}}
= w'_{i k} \frac{\sum_j \frac{y_{i j} m_{i j}}
{\sum_{k'} w'_{i k'} h_{k' j}} h_{k j}}
{\sum_j h_{k j} m_{i j}}
\end{aligned}
$$

となり、上の更新式を示せます。

まとめ

NMFを協調フィルタリングに用いるためには、欠損値を扱える必要が有りますが、
簡単にアクセスできる情報がなかったので、ここにまとめました。

欠損値を扱うことができると、交差検証法も使えることになります。
因子の数を決める手段の1つになるのではないでしょうか。