LoginSignup
0
2

More than 1 year has passed since last update.

factorization machines をスクラッチ実装で大解剖

Last updated at Posted at 2023-04-26

factorization machines

\hat{y}(\boldsymbol{x}) := w_{0} + \sum_{i=1}^n w_i x_i + \sum_{i=1}^{n-1} \sum_{j=i+1}^{n} <\boldsymbol{v}_i,\boldsymbol{v}_j>  x_ix_j\hspace{50pt}(1)  

・$w_0$: バイアス項
・$w_i$: パラメータ
・$<\boldsymbol{v}_i,\boldsymbol{v}_j>$: 交互作用パラメータ (それぞれのベクトル内積)

(1)のような定式になります。回帰モデルにちょっと変わった交互作用項がつくイメージ。(1)3項目は変形すると計算量を$O(D^2)$から$O(kD)$まで落とすことができ、線形計算可能です。

変形すると、

\hat{y}(\boldsymbol{x}) := w_{0} + \sum_{i=1}^n w_i x_i + \frac{1}{2}\sum_{f=1}^k \Biggl(\biggl(\sum_{i=1}^n v_{i,f}x_i\biggr)^2-\sum_{i=1}^n v_{i,f}^2x_i^2\Biggr)\hspace{50pt}(2)

$k$は因子数です。因子行列$V$から2本のベクトルを順に内積計算してくのではなく、要素を一つ一つ取り出して計算してくイメージですね。(ここの変形まだ理解不足...)

matrix factorizationとの比較

image.png

FMは評価値行列とともにユーザー情報(年齢、性別...)やアイテム情報(商品名、コンテンツタイプ...)もモデルに組み込める。例えば、$a_1$に年齢。$a_2$に評価した日時などを組み込める。
評価値履歴をOne-hot エンコーディングすることで、ユーザー、アイテム情報とともに、定式(1)3項目で交互作用、レコメンドでいう協調具合いを学習し、最適な$V$を得る。

また、特徴量がスパースな場合FMは分類問題にもよく使われるらしい。この場合、SVMよりも上手く交互作用項を学習できるらしい。(理解不足)

movielensの評価値予測問題に落とし込んでみる

導入

はじめに定式(1)をみたとき、なにが、matrix factorizationを一般化されているのかイマイチ理解できなかった。その原因は定式(1)3項目にあった。スクラッチでモデルを実装することで、アルゴリズムの理解を深める。

まずは、目的関数を定式化する。 今回はmovielens評価値が1~5までの値を取ることから、回帰問題とし、損失関数は二乗誤差とする。

argmin_{\theta} \hspace{10pt}J = \sum_{i}^n \biggl(y_i - \hat{y_i}(\boldsymbol{x}) \biggl)^2 \hspace{40pt} (3)

この$J$を最小にするようなパラメータ$\theta$を見つけましょう。
そして、最適化アルゴリズムには、勾配降下法を用いましょう。


\theta^{new} := \theta^{old} - \eta \frac{\partial }{\partial \theta}J \hspace{40pt} (4)

$\eta$は学習率です。そして、

\frac{\partial }{\partial \theta}J = -2 \sum_{i}^n \frac{\partial \hat{y_i}(\boldsymbol{x})}{\partial \theta} \biggl(y_i - \hat{y_i}(\boldsymbol{x}) \biggl) \hspace{40pt} (5)

そして、

\frac{\partial \hat{y}(\boldsymbol{x})}{\partial \theta}= \begin{cases}
    1 & (if \ \theta \ is \ w_0) \\
    x_i & (if \ \theta \ is \ w_i) \\
    x_i \sum_{j=1}^n v_{j,f}x_j - v_{i,f}x_i^2 & (if \ \theta \ is \ v_{i,f})
  \end{cases} \hspace{40pt} (6)

定式(2)を各パラメータで偏微分した形です。最適な$\theta$を見つけるとは、つまり$w_0$,$w_i$,$v_{i,f}$を見つけると同義です。((6)でのiはサンプルインデックスではなく、特徴量インデックスなので注意)。1,2項目の偏微分は比較的簡単です。3項目について、

\frac{\partial \hat{y}(\boldsymbol{x})}{\partial v_{i,f}} = \frac{1}{2}\biggl(2x_i\sum_{j=1}^n v_{j,f}x_j\biggr)-2x_i^2 v_{i,f}\\
=  x_i \sum_{j=1}^n v_{j,f}x_j - v_{i,f}x_i^2

となりますね。
そしたら、(6)の各勾配を(5)にぶち込み、(4)の更新式にも代入すると、

w_0^{new} := w_0^{old} + 2\eta \sum_{j=1}^{n}e_{j}\\
w_i^{new} := w_i^{old} + 2\eta \boldsymbol{x} \boldsymbol{e}\\
v_{i,f}^{new} := v_{i,f}^{old} + 2\eta\biggl( x_i \sum_{j=1}^n v_{j,f}x_j - v_{i,f}x_i^2\biggr)

となります($e$は評価値と予測値の誤差($y-\hat{y}$)であることに注意。変形してます!)。定式(5)より更新式の符号が変わります。また、各パラメータは独立に更新できるのもポイントです。(MFは$\boldsymbol{p}$と$\boldsymbol{q}$の更新が依存してる。)
そして、$v_{i,f}$の更新に注目してみると、ある特徴量$x_{i}$が0のとき、更新されずパスされてます。$x_i$が存在するとき、$v$と$x$の内積値を使い更新。$x_i$が0ではないときにだけ、$v_{i,f}$を更新することで上手く交互作用を学習できそうです。また、スパースな特徴量でも計算が重くならないのがこの更新式からわかると思います。

スクラッチ実装

スクラッチでと謳ってますが、ベクトル、行列計算はnumpyを使います。

class FM:
    def __init__(
        self, 
        epochs=50, 
        n_factors=3, 
        learning_rate=1e-6,
        random_seed=1234,
    ):
        self.epochs = epochs
        self.n_factors = n_factors
        self.learning_rate = learning_rate
        self.random_seed = random_seed
    
    def fit(self, train, test) -> None:
        
        train_X, train_y = train
        test_X, test_y = test
        
        train_size, self.n_features = train_X.shape
        self.w0 = 0.0
        self.wi = np.zeros(self.n_features)
        
        np.random.seed(self.random_seed)
        self.V = np.random.normal(
            scale=1/np.sqrt(self.n_factors),
            size=(self.n_factors, self.n_features)
        )
        
        batch_size = round(train_size/8)
        
        # バッチSGDで学習する
        self.train_loss, self.test_loss = [], []
        for _ in tqdm(range(self.epochs)):
            
            batch_index = np.random.permutation(train_size)[:batch_size]
            batch_X, batch_y = train_X[batch_index], train_y[batch_index]
            
            batch_y_hat = np.array([self.predict(x) for x in batch_X])
            error = batch_y - batch_y_hat                        
            
            for column_index in range(batch_size):
                
                e =  error[column_index]
                
                # update w0
                w0_grad = 1.0
                self.w0 += 2*self.learning_rate*e*w0_grad
                
                x_i = batch_X[column_index,:]
                
                for feature in range(self.n_features):
                    
                    x = batch_X[column_index, feature]
                    if x == 0.0:
                        continue
                    
                    # update wi
                    wi_grad = x
                    self.wi[feature] += 2*self.learning_rate*e*wi_grad
                                         
                    for factor in range(self.n_factors):
                        
                        # update V    
                        V_grad = x * (self.V[factor,:] @ x_i) - self.V[factor, feature] * (x ** 2)
                        self.V[factor, feature] += 2*self.learning_rate*e*V_grad           
            
            train_y_hat = np.array([self.predict(x) for x in train_X])
            train_rmse = mean_squared_error(train_y, train_y_hat, squared=False)
            self.train_loss.append(train_rmse)
            
            test_y_hat = np.array([self.predict(x) for x in test_X])
            test_rmse = mean_squared_error(test_y, test_y_hat, squared=False)
            self.test_loss.append(test_rmse)
        
    def predict(self, x):    
        # 2項目
        linear_out = x@self.wi    
        # 3項目
        factor_out = 0.5*np.sum(np.array([(self.V[factor,:]@x)**2 - (self.V[factor,:]**2)@(x**2) for factor in range(self.n_factors)]))
        
        return self.w0 + linear_out + factor_out
        
    def plot_train_curve(self, ylim_min=0.0, ylim_max=6.0):
        plt.subplots(1, figsize=(8,6))
        plt.plot(np.arange(len(self.train_loss)), self.train_loss, label=f"Train RMSE (last: {self.train_loss[-1]:.2f})", linewidth=3)
        plt.plot(np.arange(len(self.test_loss)), self.test_loss, label=f"Test RMSE, (last: {self.test_loss[-1]:.2f})", linewidth=3)
   
        plt.title("Train/Test Curves", fontdict=dict(size=20))
        plt.xlabel("Number of Epochs", fontdict=dict(size=20))
        plt.ylabel("Root Mean Squared Error", fontdict=dict(size=20))
        plt.ylim([ylim_min, ylim_max])
        plt.tight_layout()
        plt.legend(loc="best", fontsize=20)
        plt.show()

fit,predictメソッドでの計算部分は導入の通りになります。予測式第三項は線形に変換してます。また今回はバッチSGDで学習してます。

import pandas as pd

rating_df = pd.read_csv("../data/movielens/train_rating.csv")
rating_df.head()

スクリーンショット 2023-04-26 18.28.07.png

np.random.seed(1234)
index = rating_df.index.values
np.random.shuffle(index)
index = index[:1000]

rating_df = rating_df.iloc[index]
X = rating_df[["user_id","movie_id", "timestamp"]].values
y = rating_df["rating"].values

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1234)

X_all = np.vstack((X_train, X_test))

# ワンホットエンコーディングする
encoder = OneHotEncoder()
encoded_X_all = encoder.fit_transform(X_all[:,:2]).toarray()

timestamp = X_all[:,-1:]
scaler = StandardScaler()
timestamp = scaler.fit_transform(timestamp)

X_all = np.concatenate([encoded_X_all, timestamp], axis=1)
# 再度trainデータとtestデータに分割する
X_train = X_all[:len(X_train)]
X_test = X_all[len(X_train):]

movielensデータセットから1000のサンプルを抽出して、正規化したtimestampを評価値に加えてFMにもたせます。trainデータとtestデータに分けて学習していきます。

model = FM(
    epochs=100,
    learning_rate=5e-4, 
    n_factors=500
)
train, test = (X_train, y_train), (X_test, y_test)
model.fit(train, test)
model.plot_train_curve(ylim_min=0.0, ylim_max=4.0)

スクリーンショット 2023-04-26 18.38.58.png

train, testデータともに二乗誤差の平均が1.0を下回る感じです。60エポックらへんから過学習し始めていることがわかります。
matrix factorizationに比べると、収束が早く安定感がある気がします(正則化必要?ってくらい笑)。交互作用項が結構大きい役割をしていると考えられます。

以上!

0
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
2