Edited at

KaggleのTitanic問題をFactorization Machinesで解く

More than 1 year has passed since last update.

Factorization MachinesでKaggleのTitanic問題を解いてみたので、メモとして解いた手順をここに残そうと思います。


Factorization Machinesとは

Factorization Machines (FM)は機械学習の手法の1つです。ざっくり説明すると、疎な学習データに強くて、特徴量の相互作用をうまく取り込める手法です。詳しいことはGoogle先生などで調べてください。

https://github.com/ibayer/fastFM

にあるライブラリを使います。scikit-learn APIがあるので、scikit-learnを使う場合と同じようにして、学習・予測を行うことができます。インストールは

$ pip install fastFM

で行います。


FMでTitanic問題を解く


準備

まずはデータをpandas.DataFrame形式で取り込みます。

import pandas as pd

df_train = pd.read_csv('/path/to/train.csv')
df_test = pd.read_csv('/path/to/test.csv')

学習データ・テストデータから学習で使う特徴量を作る関数を定義します。返り値の型はpandas.DataFrameです。疎な学習データを意識して作ってますが、特徴量の選定は適当です。

def make_feature(df):

# one-hot
sex = pd.get_dummies(df['Sex'])
pclass = pd.get_dummies(df['Pclass'])
emb = pd.get_dummies(df['Embarked'])

parch = df['Parch']

# non-null: 1 null: 0
cabin = (~df['Cabin'].isnull()).astype(int)

df_feature = pd.concat(
[
sex,
pclass,
parch,
cabin,
emb
],
axis=1
)

return df_feature


学習

FMを使って学習します。今回のタスクは2値分類であることを考慮して、Logit Classification with SGD Solverを使います。

まずモデルを定義します。パラメータはドキュメントに載っていた値をそのまま使用しています。

from fastFM import sgd

model = sgd.FMClassification(n_iter=1000, init_stdev=0.1, l2_reg_w=0,
l2_reg_V=0, rank=2, step_size=0.1)

次にこのモデルに学習データを流し込みますが、


  • 学習データの入力データは疎行列でしないといけない

  • 学習データの出力データは{-1, 1}にしないいけない

ということに注意しなければなりません。これを踏まえて、学習は次のように行います。

import scipy as sp

x_train = make_feature(df_train).values
x_train = sp.sparse.csr_matrix(x_train)

y_train = df_train['Survived'].values # {0, 1}
y_train = 2 * y_train - 1 # {0, 1} -> {-1, 1}

model.fit(x_train, y_train)


予測

あとはテストデータを使って予測するだけです。ただし、予測結果は{-1, 1}なので、Kaggleに提出する際には{0, 1}に戻さないといけません。

x_test = make_feature(df_test).values

x_test = sp.sparse.csr_matrix(x_test)

y_pred = (model.predict(x_test) + 1) / 2
y_pred = y_pred.astype(int)

最後にCSVに保存して、提出の準備完了です。

df_pred = pd.DataFrame(

{
'PassengerId': df_test['PassengerId'],
'Survived': y_pred
}
)

df_pred.to_csv('submission.csv', index=False)


どれくらいスコアが出るの?

スコアは0.75598でした。しょぼい…