10
1

More than 1 year has passed since last update.

深層学習ライブラリ DeepCTR で映画のレコメンドを試す

Last updated at Posted at 2021-12-24

はじめに

こんにちは。株式会社エイアイ・フィールドの友野と申します。
この記事はエイアイ・フィールドアドベントカレンダー2021の最終日の記事です。
この記事では、Python で使用できる深層学習ライブラリ DeepCTR を使用して映画のレコメンドを行ってみたいと思います。

DeepCTR とは

DeepCTR はCTR予測やレコメンドのタスクでよく用いられる深層学習モデルを簡単に呼び出して使用できるライブラリになります。実装は Tensorflow で行われており、2021/12月現在、実装されているモデルは30種類近くにもなります。リポジトリの更新も活発に行われているようです。

以下、公式ドキュメントの説明を機械翻訳したものです。

DeepCTRは、深層学習ベースのCTRモデルで、カスタムモデルを簡単に構築するために使用できるコアコンポーネントレイヤーとともに、使いやすく、モジュール化され、拡張可能なパッケージです。
tf.keras.Modelのようなインターフェイスを提供し、素早く実験ができる。
大規模データや分散学習のためのtensorflow estimatorインターフェイスを提供する。
tf 1.xとtf 2.xの両方に対応しています。

また、派生版として Pytorch 実装版の DeepCTR-Torch やマッチングタスクに特化した DeepMatch といったライブラリもあります。

動かしてみる

ここからは DeepCTR を実際に使用してみます。実行環境は以下になります。

  • Google Colaboratory
  • DeepCTR 0.9.0
  • Tensorflow 2.4.0

今回は DeepCTR の DeepFM というモデルを使用し、 MovieLens 100K データセットのユーザーレビュー点数(最低1点,最高5点)が4以上かどうかを予測する2値分類タスクを行います。また、学習させたモデルを使用してユーザーに映画をレコメンドしてみます。
使用したソースコードの全文は以下に公開しているので興味のある方はご覧ください。

また、公式ドキュメントのチュートリアルも参考になるので、使用の際はぜひご覧ください。
(以下のコードも公式ドキュメントの内容に少し改変を加えたものになります)

データセットの読み込み

データセットをダウンロードし、読み込みを行います。
users, movies, ratings にはそれぞれ、ユーザー情報、映画情報、ユーザーの映画に対するレビュー点数の情報が記入されています。

users = pd.read_csv(
    'ml-100k/u.user', sep='|', names=users_cols, encoding='latin-1')

ratings = pd.read_csv(
    'ml-100k/u.data', sep='\t', names=ratings_cols, encoding='latin-1')

movies = pd.read_csv(
    'ml-100k/u.item', sep='|', names=movies_cols, encoding='latin-1')

users

|    |   user_id |   age | sex   | occupation   |   zip_code |
|---:|----------:|------:|:------|:-------------|-----------:|
|  0 |         1 |    24 | M     | technician   |      85711 |
|  1 |         2 |    53 | F     | other        |      94043 |
|  2 |         3 |    23 | M     | writer       |      32067 |
|  3 |         4 |    24 | M     | technician   |      43537 |
|  4 |         5 |    33 | F     | other        |      15213 |

movies

|    |   movie_id | title             | release_date   |   video_release_date | imdb_url                                               |   genre_unknown |   Action |   Adventure |   Animation |   Children |   Comedy |   Crime |   Documentary |   Drama |   Fantasy |   Film-Noir |   Horror |   Musical |   Mystery |   Romance |   Sci-Fi |   Thriller |   War |   Western |   year |
|---:|-----------:|:------------------|:---------------|---------------------:|:-------------------------------------------------------|----------------:|---------:|------------:|------------:|-----------:|---------:|--------:|--------------:|--------:|----------:|------------:|---------:|----------:|----------:|----------:|---------:|-----------:|------:|----------:|-------:|
|  0 |          1 | Toy Story (1995)  | 01-Jan-1995    |                  nan | http://us.imdb.com/M/title-exact?Toy%20Story%20(1995)  |               0 |        0 |           0 |           1 |          1 |        1 |       0 |             0 |       0 |         0 |           0 |        0 |         0 |         0 |         0 |        0 |          0 |     0 |         0 |   1995 |
|  1 |          2 | GoldenEye (1995)  | 01-Jan-1995    |                  nan | http://us.imdb.com/M/title-exact?GoldenEye%20(1995)    |               0 |        1 |           1 |           0 |          0 |        0 |       0 |             0 |       0 |         0 |           0 |        0 |         0 |         0 |         0 |        0 |          1 |     0 |         0 |   1995 |
|  2 |          3 | Four Rooms (1995) | 01-Jan-1995    |                  nan | http://us.imdb.com/M/title-exact?Four%20Rooms%20(1995) |               0 |        0 |           0 |           0 |          0 |        0 |       0 |             0 |       0 |         0 |           0 |        0 |         0 |         0 |         0 |        0 |          1 |     0 |         0 |   1995 |
|  3 |          4 | Get Shorty (1995) | 01-Jan-1995    |                  nan | http://us.imdb.com/M/title-exact?Get%20Shorty%20(1995) |               0 |        1 |           0 |           0 |          0 |        1 |       0 |             0 |       1 |         0 |           0 |        0 |         0 |         0 |         0 |        0 |          0 |     0 |         0 |   1995 |
|  4 |          5 | Copycat (1995)    | 01-Jan-1995    |                  nan | http://us.imdb.com/M/title-exact?Copycat%20(1995)      |               0 |        0 |           0 |           0 |          0 |        0 |       1 |             0 |       1 |         0 |           0 |        0 |         0 |         0 |         0 |        0 |          1 |     0 |         0 |   1995 |

ratings

|    |   user_id |   movie_id |   rating |   unix_timestamp |
|---:|----------:|-----------:|---------:|-----------------:|
|  0 |       196 |        242 |        3 |      8.81251e+08 |
|  1 |       186 |        302 |        3 |      8.91718e+08 |
|  2 |        22 |        377 |        1 |      8.78887e+08 |
|  3 |       244 |         51 |        2 |      8.80607e+08 |
|  4 |       166 |        346 |        1 |      8.86398e+08 |

前処理・データ分割

モデルに入力するための前処理を行い、データを学習用とテスト用に分割します。
前処理では公開年カラムの生成、users, movies, ratingsデータを結合、2値分類用ラベルの作成、欠損値処理、数値変数の標準化、カテゴリ変数のラベルエンコーディングを行っています。

def preprocess(users, movies, ratings):
    ratings["rating"] = ratings["rating"].apply(lambda x: float(x)) 
    movies["year"] = movies['release_date'].apply(lambda x: str(x).split('-')[-1]) #公開年カラムの生成

    # データのマージ
    data = ratings.merge(movies, on='movie_id').merge(users, on='user_id')

    # 2値分類用ラベルカラム作成
    data[target] = data['rating'] >= 4.0
    data[target] = data[target].astype(int)

    # 欠損値埋め
    data[sparse_features] = data[sparse_features].fillna('-1', )
    data[dense_features] = data[dense_features].fillna(0, )

    # エンコーディング
    ss = StandardScaler()
    data[dense_features] = ss.fit_transform(data[dense_features])
    oe = OrdinalEncoder()
    data[sparse_features] = oe.fit_transform(data[sparse_features])
    data[sparse_features] = data[sparse_features].astype(int)
    return data, ss, oe

data, ss, oe = preprocess(users, movies, ratings)

# データ分割
train, test = train_test_split(data, test_size=0.1, random_state=SEED)
print('train shape: ', train.shape)
print('test shape: ', test.shape)

前処理後のデータは以下のような形になります。

|    |   user_id |   movie_id |   rating |   unix_timestamp | title                                                    | release_date   |   video_release_date | imdb_url                                                                                                  |   genre_unknown |   Action |   Adventure |   Animation |   Children |   Comedy |   Crime |   Documentary |   Drama |   Fantasy |   Film-Noir |   Horror |   Musical |   Mystery |   Romance |   Sci-Fi |   Thriller |   War |   Western |   year |     age |   sex |   occupation |   zip_code |   label |
|---:|----------:|-----------:|---------:|-----------------:|:---------------------------------------------------------|:---------------|---------------------:|:----------------------------------------------------------------------------------------------------------|----------------:|---------:|------------:|------------:|-----------:|---------:|--------:|--------------:|--------:|----------:|------------:|---------:|----------:|----------:|----------:|---------:|-----------:|------:|----------:|-------:|--------:|------:|-------------:|-----------:|--------:|
|  0 |       195 |        241 |        3 |        881250949 | Kolya (1996)                                             | 24-Jan-1997    |                  nan | http://us.imdb.com/M/title-exact?Kolya%20(1996)                                                           |               0 |        0 |           0 |           0 |          0 |        1 |       0 |             0 |       0 |         0 |           0 |        0 |         0 |         0 |         0 |        0 |          0 |     0 |         0 |     69 | 1.38638 |     1 |           20 |      55105 |       0 |
|  1 |       195 |        256 |        2 |        881251577 | Men in Black (1997)                                      | 04-Jul-1997    |                  nan | http://us.imdb.com/M/title-exact?Men+in+Black+(1997)                                                      |               0 |        1 |           1 |           0 |          0 |        1 |       0 |             0 |       0 |         0 |           0 |        0 |         0 |         0 |         0 |        1 |          0 |     0 |         0 |     69 | 1.38638 |     1 |           20 |      55105 |       0 |
|  2 |       195 |        110 |        4 |        881251793 | Truth About Cats & Dogs, The (1996)                      | 26-Apr-1996    |                  nan | http://us.imdb.com/M/title-exact?Truth%20About%20Cats%20&%20Dogs,%20The%20(1996)                          |               0 |        0 |           0 |           0 |          0 |        1 |       0 |             0 |       0 |         0 |           0 |        0 |         0 |         0 |         1 |        0 |          0 |     0 |         0 |     68 | 1.38638 |     1 |           20 |      55105 |       1 |
|  3 |       195 |         24 |        4 |        881251955 | Birdcage, The (1996)                                     | 08-Mar-1996    |                  nan | http://us.imdb.com/M/title-exact?Birdcage,%20The%20(1996)                                                 |               0 |        0 |           0 |           0 |          0 |        1 |       0 |             0 |       0 |         0 |           0 |        0 |         0 |         0 |         0 |        0 |          0 |     0 |         0 |     68 | 1.38638 |     1 |           20 |      55105 |       1 |
|  4 |       195 |        381 |        4 |        881251843 | Adventures of Priscilla, Queen of the Desert, The (1994) | 01-Jan-1994    |                  nan | http://us.imdb.com/M/title-exact?Adventures%20of%20Priscilla,%20Queen%20of%20the%20Desert,%20The%20(1994) |               0 |        0 |           0 |           0 |          0 |        1 |       0 |             0 |       1 |         0 |           0 |        0 |         0 |         0 |         0 |        0 |          0 |     0 |         0 |     66 | 1.38638 |     1 |           20 |      55105 |       1 |

モデル入力用のDenseFeat, SparseFeatの生成

次にモデルに入力するためのオブジェクトを生成します。
DeepCTR では使用する各変数について、数値変数の場合は DenseFeat , カテゴリ変数の場合は SparseFeat オブジェクトを生成する必要があります。
DenseFeat に指定した特徴はモデル内部で全結合層に、SparseFeat に指定した特徴はモデル内部で Embedding 層に入力され、学習が行われるようです。

# 使用する各変数についてSparseFeat, DenseFeat オブジェクトを作成
sparse_columns = [SparseFeat(feat, vocabulary_size=len(data[feat].unique()), embedding_dim=4) for feat in sparse_features]
dense_columns = [DenseFeat(feat, 1, ) for feat in dense_features]

モデル定義、学習

モデルのインスタンスを作成し学習を行います。モデルの引数には以下を指定する必要があります。

  • dnn_feature_columns: モデル内部のDNN層に使用する変数 (SparseFeat, DenseFeatのリスト)
  • linear_feature_columns: モデル内部の線形層に使用する変数 (SparseFeat, DenseFeatのリスト)
  • task: 解きたいタスク名

dnn_feature_columns, linear_feature_columns に指定する変数ですが、ドキュメントを見る限り、基本的に入力するSparseFeat, DenseFeatをすべて渡してしまっていいと思われます。今回は2値分類のため task には binary を指定します。

fixlen_feature_columns = sparse_columns + dense_columns
dnn_feature_columns = fixlen_feature_columns #DNN層に使用するカラム
linear_feature_columns = fixlen_feature_columns #線形層に使用するカラム
feature_names = get_feature_names(linear_feature_columns + dnn_feature_columns)
print(feature_names)

# 2値分類なので task に binary を指定
model = DeepFM(linear_feature_columns, dnn_feature_columns, task='binary')
model.compile("adam", loss='binary_crossentropy', metrics=['accuracy'], )
model.summary()

このあとは tensorflow でモデルを作成するときと同様の流れですが、モデルの fit 時に与えるデータの形式は少し注意が必要です。キーを特徴名、値を pd.Series とする辞書を生成して渡しています。

# モデル入力用データ作成
# key = 変数名、value= pd.Series型の辞書
train_model_input = {name: train[name] for name in feature_names}
test_model_input = {name: test[name] for name in feature_names}

# EarlyStopping などの callback も使用可能
es = EarlyStopping(patience=3, monitor='val_loss', )
history = model.fit(train_model_input, train[target].values,
                    batch_size=256, epochs=30, verbose=2, validation_split=0.1, callbacks=[es] )

学習が終わったらテストデータの loss値と AUCの値を確認してみます。

pred_ans = model.predict(test_model_input, batch_size=256)
print("test LogLoss", round(log_loss(test[target].values, pred_ans), 4))
print("test AUC", round(roc_auc_score(test[target].values, pred_ans), 4))

結果は以下になりました。何かしらの学習はできていそうですね。

  • test LogLoss 0.5531
  • test AUC 0.7859

ユーザー予測結果の確認

学習したモデルを使用して特定ユーザーに対して映画をレコメンドしてみます。
今回結果を確認するユーザーの情報は以下です。39歳、男性、エンジニアとなっています。

|    |   user_id |   age | sex   | occupation   |   zip_code |
|---:|----------:|------:|:------|:-------------|-----------:|
| 24 |        25 |    39 | M     | engineer     |      55107 |

ユーザーの視聴作品は以下です。(ratingが高い順に20件を表示しています)

|    | title                                                  |   rating |
|---:|:-------------------------------------------------------|---------:|
|  0 | Return of the Jedi (1983)                              |        5 |
|  1 | Raiders of the Lost Ark (1981)                         |        5 |
|  2 | Vertigo (1958)                                         |        5 |
|  3 | Silence of the Lambs, The (1991)                       |        5 |
|  4 | Wallace & Gromit: The Best of Aardman Animation (1996) |        5 |
|  5 | Toy Story (1995)                                       |        5 |
|  6 | Wrong Trousers, The (1993)                             |        5 |
|  7 | Close Shave, A (1995)                                  |        5 |
|  8 | Star Wars (1977)                                       |        5 |
|  9 | Back to the Future (1985)                              |        5 |
| 10 | Phenomenon (1996)                                      |        5 |
| 11 | Grand Day Out, A (1992)                                |        5 |
| 12 | Philadelphia Story, The (1940)                         |        5 |
| 13 | 39 Steps, The (1935)                                   |        5 |
| 14 | Contact (1997)                                         |        5 |
| 15 | Birdcage, The (1996)                                   |        5 |
| 16 | American President, The (1995)                         |        4 |
| 17 | Breakfast at Tiffany's (1961)                          |        4 |
| 18 | Around the World in 80 Days (1956)                     |        4 |
| 19 | Secret of Roan Inish, The (1994)                       |        4 |

モデル予測結果は以下のようになりました。(予測値が高い順に20件を表示)

|      | title                                     |     pred |
|-----:|:------------------------------------------|---------:|
|  514 | Boot, Das (1981)                          | 0.941268 |
|   95 | Terminator 2: Judgment Day (1991)         | 0.93625  |
| 1018 | Die xue shuang xiong (Killer, The) (1989) | 0.927317 |
|  317 | Schindler's List (1993)                   | 0.926411 |
|   78 | Fugitive, The (1993)                      | 0.922977 |
|  478 | Vertigo (1958)                            | 0.922477 |
|   49 | Star Wars (1977)                          | 0.922225 |
|  171 | Empire Strikes Back, The (1980)           | 0.921925 |
|  163 | Abyss, The (1989)                         | 0.915995 |
|  407 | Close Shave, A (1995)                     | 0.910564 |
|  168 | Wrong Trousers, The (1993)                | 0.909609 |
|  180 | Return of the Jedi (1983)                 | 0.905442 |
|  683 | In the Line of Fire (1993)                | 0.905395 |
|  565 | Clear and Present Danger (1994)           | 0.903256 |
|  854 | Diva (1981)                               | 0.902645 |
|  497 | African Queen, The (1951)                 | 0.901646 |
|  173 | Raiders of the Lost Ark (1981)            | 0.900039 |
|  209 | Indiana Jones and the Last Crusade (1989) | 0.899919 |
| 1420 | My Crazy Life (Mi vida loca) (1993)       | 0.899704 |
|  398 | Three Musketeers, The (1993)              | 0.898667 |

タイトルが英語なので分かりづらいですが、ユーザーが視聴した作品がいくつかモデル予測の上位に入っています。また、ユーザーが高評価したスターウォーズ関連の映画や、SF、スリラーチックな映画も予測上位にあり、ある程度妥当な結果になっていそうです。

まとめ

DeepCTRを使用してモデルの学習、映画レコメンドを行ってみました。
CTR予測やレコメンドタスクに有用なライブラリかと思いますので、興味が湧いた方は是非試してみてください。

10
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
10
1