42
18

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

ベイジアンなFactorization MachineでMovieLens 100k, 1MのSOTAと張り合う

Posted at

TL;DR

この論文を読んで、今更Bayesian Factorization Machineに興味を持って、再現しようとしたけどlibFMが辛かった。

ので自分でライブラリを作って遊んでたら、結構いい結果が出て、世に出回ってるSOTAアルゴリズムの結果と張り合うことができた。せっかくなので使ってほしい。

はじめに

みなさん行列分解してますか?私は行列分解が好きです!
行列分解とかFactorization Machine (FM, [Rendle, '10]) は10年くらい前にテクニックとしては確立していたんですが、2019年になって
On the Difficulty of Evaluating Baselines: A Study on Recommender Systems
という、なかなか凄い組み合わせ(みなさんGoogle Researchにいらっしゃるんですね。。アベンジャーズのようだ)による論文が出ました。ここでは

  • MovieLens10Mでよくベースラインとして引用されていたベイジアン行列分解(BPMF)のRMSEスコアはどうも不当に低かった
  • ちゃんと潜在次元数などを設定したら、すべての論文を上回るパフォーマンスに
  • さらに、SVD++や日付変数などもFactorization Machineで考慮したら、もっと良くなった。

ことが報告されております。(なお、論文の主旨は「行列分解最強!」ではなく、Netflixとの対比で「リサーチコミュニティでは新規性のみが重視されるのが問題やで」ということのようです)

というわけで、さっそく論文の結果を試してみたい!と思ったのですが、論文のすべての計算に使われているlibFMが使いづらい。。。Pythonラッパーもあるのですが、肝心のRelational Data Format(後述)に対応していなかったり、そもそも仕組みが別プロセスでlibFMを呼び出すだけだったり、モデルの保存が点推定だけだったり、とちょっと私的には難がある感じでした。fastFMもRelational Data formatに対応してない。。。

ということで、これはもう自分で作るしかないと思い立って作ったのが"俺のFactorization Machine" ことmyFMです。

Bayesian Factorization Machine について

myFMの機能について説明する前に、まずは Bayesian Factorization Machine ([Freudenthaler et al., '11]についておさらいしたいと思います。

Bayesian Factorization Machineでは、Gibbsサンプリングを用いた学習を行います 。
このため、通常の機械学習モデルとはやや毛色が異なり、「学習結果のパラメータ」とは、
FMの単一のパラメータ(点推定)ではなく、事後分布からのサンプル(FMのパラメータn_sample個分)になります。

  • 利点
    • チューンするパラメータが潜在次元数($r$と書きます)とイテレーション数 n_iterだけ!
      => 正則化係数などのハイパーパラメータは自動的に決定される
      また、n_iter $\simeq$ n_sample は大きければ大きいほどよいはずです。
    • n_sample個の事後サンプルは、各々は割とノイジーで弱いモデルだが、
      十分大きなn_sampleに対する事後平均をとると多くの場合SGDより高精度な結果が得られる
  • 欠点
    • 予測結果モデルは、単一のパラメータではなく事後分布からのサンプル n_sample 個になる。
      そのため、モデルを保存するには同じ次元のSGDやALS解と比べて n_sample倍だけメモリやストレージを消費する。また、予測に n_sample 倍だけ時間がかかる(ただし並列化可能)。
    • SGDと比べると(定数倍)遅いことが多い

こうした利点から、本気で精度を出したい場合にはBayesian FMを用いる利点は大きいでしょう。

パラメータを保存しようとすると若干メモリなどを消費しますが、あらかじめ予測したいデータが定まっているバッチ処理などだったら、各MCMCイテレーションで予測を行えば節約ができます。

Relational Format

FMの計算量は入力スパース行列$X$の非零要素数$N_Z(X)$に比例します。ので、通常の行列分解に、ユーザー・アイテムの属性情報を入れていくと計算量がどんどん増大してしまいます。さらに、SVD++([Koren, '08])のように「ユーザーが接触したアイテムを全部持ってきて特徴量にする!」ということをし始めると計算量は爆発的に増大してしまいます。

この状況に対処してSVD++みたいなことを効率よく行うために、[Rendle, '13]では、Relational Data Formatを提案しました。これは映画ID=1に対応する補助情報が

ジャンル=[アクション], 公開年=1988, この映画を見たユーザー=[42, 87, 137, ...]

だったとしたら、$X$に映画1は常にこの組み合わせで現れる、という事実をうまく使って計算量を劇的に減らす手法です。SVD++のようにimplicit feedbackを活用するなら必須と言える技法ですが、今のところオリジナルのlibFMにしか実装されていないようです。

myFMの機能

MCMCをサポートしているFactorization Machineのライブラリの対応機能を表にまとめるとこんな感じでしょうか。

libFM fastFM myFM
Python ラッパー △(非公式)
変数のグルーピング ×
Relational data format x
事後サンプルの保存 x x
SGDのサポート ×

「変数のグルーピング」とは、例えばユーザーIDに対応するベクトルたちとアイテムIDに対応するベクトルたちでは分散(正則化係数)が違うかも、ということを想定することです。

SGD系のアルゴリズムは今はTensorflowで簡単に実装できるだろう、ということで度外視していますが、Pythonから簡単に使えてlibFMと同等の機能があるのはmyFMだけ!という心づもりで開発しました。

myFMのインストール

Linux/Macだと、ちゃんとした(C++11対応してる)コンパイラがあれば

pip install git+https://github.com/tohtsky/myFM

でインストールされるはずです。ちなみに、上のコマンドは勝手にEigenをダウンロードして使います。

MovieLens100kのコード例

GitHubのReadmeにあるコードを転載しますが、

import numpy as np
from sklearn.preprocessing import OneHotEncoder
from sklearn import metrics

import myfm

# このデータ読み出しクラスは `examples/` 内で定義されている
from movielens100k_data import MovieLens100kDataManager
data_manager = MovieLens100kDataManager()
df_train, df_test = data_manager.load_rating(fold=3) # Note the dependence on the fold

def test_myfm(df_train, df_test, rank=8, grouping=None, n_iter=100, samples=95):
    explanation_columns = ['user_id', 'movie_id']
    ohe = OneHotEncoder(handle_unknown='ignore')
    X_train = ohe.fit_transform(df_train[explanation_columns])
    X_test = ohe.transform(df_test[explanation_columns])
    y_train = df_train.rating.values
    y_test = df_test.rating.values
    fm = myfm.MyFMRegressor(rank=rank, random_seed=114514)

    if grouping:
        # User ID と Movie IDに別の分散を割り当てる
        grouping = [ i for i, category in enumerate(ohe.categories_) for _ in category]
        assert len(grouping) == X_train.shape[1]

    # Gibbs サンプリングしてサンプルパラメータを貯める
    fm.fit(X_train, y_train, grouping=grouping, n_iter=n_iter, n_kept_samples=samples)

    # 事後平均を計算
    prediction = fm.predict(X_test)
    rmse = ((y_test - prediction) ** 2).mean() ** .5
    mae = np.abs(y_test - prediction).mean()
    print('rmse={rmse}, mae={mae}'.format(rmse=rmse, mae=mae))
    return fm

# basic regression
test_myfm(df_train, df_test, rank=8);
# rmse=0.90321, mae=0.71164

# with grouping
fm = test_myfm(df_train, df_test, rank=8, grouping=True)
# rmse=0.89594, mae=0.70481

という感じで動きます。手持ちのラップトップで3秒程度で学習・予測が終わりました。
なんちゃってsklearn風にメソッドは作ったのですが、中では、学習時にGibbsサンプリング + 予測時に事後平均の計算が行われています。

出てくる精度を例えばSurpriseのトップページにある結果と比べてみますと、RMSE値0.8954 vs 0.9320でmyFMが圧勝していることが分かります。やったぜ。

myFMでSOTAと戦う

MovieLens 100k

ここによるとMovieLens 100kのRMSE値のSOTAは 0.905らしいですが、これには既に上の単純な行列分解だけで勝ててしまっています。また、RLFMという手法を提案した10年前の論文だと既に(補助情報込みで)RMSE=0.896とか出ています。新しいSOTAの方が劣化してますね。。。

myFMで通常のユーザーID、映画IDに加えて、ユーザー映画の補助情報とSVD++風の特徴量、さらに評価の日時情報を加味したRMSE値は一番難しかったfold=1で 0.886, fold=3では前人未踏(?)の0.87台にまで精度を上げることができました!

コードはこちらにありまして、Relational Data Formatで如何に速度が改善されるかも示されています。

結局MovieLens 100kについても、10年以上SOTAは変わっていなかったということでしょうか。

MovieLens 1M

MovieLens 100kと同様にここの結果とmyFMによる推論結果を戦わせました。一番いい結果を出していたSparse FCはハイブリッドではないらしいので、SVD++ライクな特徴量と時刻情報のみを加えた結果がこちらになります。

RMSE値=0.821が出て、一応Sparse FCを上回っています!ただし、今回は次元をチューニングをしなければならず、正直交差検証が面倒だったので、次元数=32に落ち着くまでテストスコアを見ながらチューニングしてしまいました。。ただ、適当にやってもGraph Convolution系のものには勝てました。

感想

自分で実装してみると、Factorization Machineが共役分布の範囲内で如何によく考えられたモデルであるか、を痛感させられて、非常に楽しむことが出来ました。これの実装を理解することが、例えばPRMLの11章の最高の演習問題なのでは、という感じですね。

それにしてもSOTAを追いかけるって文献調査的にも厳しいですね。。こういうところ以外で勝負したいと思いました。

42
18
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
42
18

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?