11
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

RecBole: Sequential Model を APIとして実装する

Last updated at Posted at 2022-04-24

こんにちは〜

最近はこちらの記事 などでひたすら RecBole (推薦モデルをたくさん試せるやつ)をいじっています。色んなデータに色んなモデルを当ててみると色々発見があってたのしい。

さて、そんな RecBole ですがあくまでも実験用としての色合いが強く、production 利用するための記述はそれほど充実していません。この記事ではそんな応用の一用途である API 実装について書きます。

前提

Sequential Model を RecBole で学習させ、これを使いたいなと思ったとします。SHANとかSINEとか。
これらはモデルの原理を考えればユーザIDは必要なく、アイテムの履歴だけがあれば良いのです。が、RecBole 側が用意している topk アイテムを出力する関数 full_sort_topk では uid_series が args として必要となっており、ユーザIDが必要な挙動になっています。

def full_sort_topk(uid_series, model, test_data, k, device=None):
    """Calculate the top-k items' scores and ids for each user in uid_series.
    Note:
        The score of [pad] and history items will be set into -inf.
    Args:
        uid_series (numpy.ndarray): User id series.
        model (AbstractRecommender): Model to predict.
        test_data (FullSortEvalDataLoader): The test_data of model.
        k (int): The top-k items.
        device (torch.device, optional): The device which model will run on. Defaults to ``None``.
            Note: ``device=None`` is equivalent to ``device=torch.device('cpu')``.
    Returns:
        tuple:
            - topk_scores (torch.Tensor): The scores of topk items.
            - topk_index (torch.Tensor): The index of topk items, which is also the internal ids of items.
    """
    scores = full_sort_scores(uid_series, model, test_data, device)
    return torch.topk(scores, k)

こうなっている理由は、実は他のモデル: General Recommender などがIDを必要とするのでそれらとインタフェースを合わせているから、というだけです。もう少し正確に言うと、各ユーザごとの topk を見たいときしか使わないでしょう、という意図かな...
SINE のモデル実装コードを見に行ってみると、forward関数に通しているのは item_sequence だけであることがわかります。

    def full_sort_predict(self, interaction):
        item_seq = interaction[self.ITEM_SEQ]
        item_seq_len = interaction[self.ITEM_SEQ_LEN]
        seq_output = self.forward(item_seq, item_seq_len)
        test_items_emb = self.item_embedding.weight
        scores = torch.matmul(seq_output, test_items_emb.transpose(0, 1))  # [B, n_items]
        return scores

Sequential Model を API に実装しよう

Sequential Model はこのように、知らないユーザ(学習時に見たことがないユーザID)であってもある程度行動履歴があれば予測を返すことができます。なので、このモデルをAPIとして提供できればある程度応用価値がありそうです。

ということで、RecBole が提供している full_sort_topk 関数を使わずに、user_id を渡すのではなくアイテム履歴だけを渡して予測を返させるコードを考えてみます。

Interactionにデータを詰める

SINE.full_sort_predict を見てみると、入力は Interaction というクラスになっています。これは RecBole の中で使われるデータ格納クラスです。以下のような形で初期化することができます。

input_interaction = Interaction(
    {
        "variable": ["example"],
    }
)

この Interaction の中に必要なデータを詰めて model.full_sort_predict に渡せばいけそうです。

SINE が必要としているのは item_list と item_length なので、それを用意すると以下のようになります。

item_sequence = ["item1", "item2", "item3"]
item_length = len(item_sequence)
pad_length = 50  # pre-defined by recbole

padded_item_sequence = torch.nn.functional.pad(
    torch.tensor(dataset.token2id(dataset.iid_field, item_sequence)),
    (0, pad_length - item_length),
    "constant",
    0,
)

input_interaction = Interaction(
    {
        "item_list": padded_item_sequence.reshape(1, -1),
        "item_length": torch.tensor([item_length]),
    }
)

SINE 学習時にアイテム履歴の最大長を50にしたので、その長さになるように zero padding しています。
また、dataset ですが、これはRecBole でモデルを復元するとついてくるもので、いろいろな便利メソッドを持っています。ここでは item_field に従って、アイテム文字列を SINE モデル内で使う internal_id(int) に変換しています。

モデルの読み込み

RecBole が提供しているモデル読み込み関数は load_data_and_model しかないのですが、これは学習時に使ったデータセットなども同じ設定で復元してくれます。この過程で上述した dataset などが用意されます。
https://github.com/RUCAIBox/RecBole/blob/master/recbole/quick_start/quick_start.py#L102

が、load_data_and_model は学習時に使った raw_data が学習時と全く同じ場所に配置されていないと ValueError: Neither [dataset/your_dataset] exists in the devicenor [your_dataset] a known dataset name. というエラーを吐いて止まります...
API を実行するためだけにデータを APIと同じインスタンスに配備するのは流石に辛いので、どうにかします。

load_data_and_model の中身を見ていくと

def load_data_and_model(model_file):
    r"""Load filtered dataset, split dataloaders and saved model.
    Args:
        model_file (str): The path of saved model file.
    Returns:
        tuple:
            - config (Config): An instance object of Config, which record parameter information in :attr:`model_file`.
            - model (AbstractRecommender): The model load from :attr:`model_file`.
            - dataset (Dataset): The filtered dataset.
            - train_data (AbstractDataLoader): The dataloader for training.
            - valid_data (AbstractDataLoader): The dataloader for validation.
            - test_data (AbstractDataLoader): The dataloader for testing.
    """
    checkpoint = torch.load(model_file)
    config = checkpoint['config']
    init_seed(config['seed'], config['reproducibility'])
    init_logger(config)
    logger = getLogger()
    logger.info(config)

    dataset = create_dataset(config)
    logger.info(dataset)
    train_data, valid_data, test_data = data_preparation(config, dataset)

    init_seed(config['seed'], config['reproducibility'])
    model = get_model(config['model'])(config, train_data.dataset).to(config['device'])
    model.load_state_dict(checkpoint['state_dict'])
    model.load_other_parameter(checkpoint.get('other_parameter'))

    return config, model, dataset, train_data, valid_data, test_data

となっています。ここでほしいのはモデルファイルだけなので、下三行付近を見てみると、 model = get_model(config['model'])(config, train_data.dataset).to(config['device']) で dataset が要求されているだけで、学習データがないと重みを読み込めないといったことは流石になさそうです。
更に見ていくと、SINE 読み込み時に要求されるのは dataset の中でも user_num, item_num を渡すためだけに使われているのでここでも学習データ自体が必要なわけではなさそうです。

import cloudpickle
from recbole.data import create_dataset
import torch


checkpoint = torch.load(model_file_path)
config = checkpoint["config"]
dataset = create_dataset(config)
cloudpickle.dump(dataset, open("output/dataset.pkl", "wb"))

ということで、今回はこんな感じで事前に dataset を作りこれを pkl に吐いて API実行時参照できる場所に置いておくことにしてみました。
config から作成する dataset は interaction 全てを保持しているわけではなく、user, item の external id ←→ internal id の対応表、user_num, item_num などの基本的な情報しか持っておらず、比較的軽量なので問題ないかと思われます。
このスクリプト自体は学習データがある場所で実行する必要があります。

で、これらを考慮した結果、ダイエットしたモデル読み込み関数は以下のようになります。

def load_model(model_file: str, dataset_file: str) -> Tuple[{Your Model Class}, SequentialDataset]:
    with open(dataset_file, "rb") as f:
        dataset = cloudpickle.load(f)

    checkpoint = torch.load(model_file)
    config = checkpoint["config"]
    init_seed(config["seed"], config["reproducibility"])
    model = get_model(config["model"])(config, dataset).to(config["device"])
    model.load_state_dict(checkpoint["state_dict"])
    model.load_other_parameter(checkpoint.get("other_parameter"))
    return model, dataset

これで、学習済モデルの checkpoint ファイルと dataset.pkl がありさえすれば予測に必要なモデルデータが復元できるようになりました。

API全体を実装していく

あとはこれを使って、以下をやっていくだけです!

  1. モデルの読み込み
  2. リクエストからモデル入力となるアイテム列・topk を受け取る
  3. 2のデータを Interaction クラスに詰める
  4. model.full_sort_predict に渡して全アイテムに対するスコアを得る
  5. topk に応じて argsort して上位k件を取得
  6. internal id から external id に直して実際のアイテムIDとして予測を返す

FastAPIを使った実装例ですが、最終的な全体像はこんな感じになります。

from typing import List, Tuple

import numpy as np
import torch
from fastapi.applications import FastAPI
from pydantic import BaseModel
from recbole.data import create_dataset
from recbole.data.dataset.sequential_dataset import SequentialDataset
from recbole.data.interaction import Interaction
from recbole.model.sequential_recommender.sine import SINE
from recbole.utils import get_model, init_seed

app = FastAPI(docs_url=None, redoc_url=None)


def load_model(model_file: str) -> Tuple[SINE, SequentialDataset]:
    checkpoint = torch.load(model_file)
    config = checkpoint["config"]
    init_seed(config["seed"], config["reproducibility"])
    dataset = create_dataset(config)
    model = get_model(config["model"])(config, dataset).to(config["device"])
    model.load_state_dict(checkpoint["state_dict"])
    model.load_other_parameter(checkpoint.get("other_parameter"))
    return model, dataset


model, dataset = load_model(
    model_file="saved/{your_model_checkpoint}.pth"
    dataset_file="outout/dataset.pkl"
)


class ItemHistory(BaseModel):
    sequence: List[str]
    topk: int


class RecommendedItems(BaseModel):
    score_list: List[float]
    item_list: List[str]


@app.get("/hello")
def health_check() -> str:
    """
    Health check endpoint
    """

    return "Hello Sequential Recommendation api"


@app.post("/v1/sine/user_to_item", response_model=RecommendedItems)
def sine_user_to_item(item_history: ItemHistory):
    item_history_dict = item_history.dict()
    item_sequence = item_history_dict["sequence"]
    item_length = len(item_sequence)
    pad_length = 50  # pre-defined by recbole

    padded_item_sequence = torch.nn.functional.pad(
        torch.tensor(dataset.token2id(dataset.iid_field, item_sequence)),
        (0, pad_length - item_length),
        "constant",
        0,
    )

    input_interaction = Interaction(
        {
            "item_list": padded_item_sequence.reshape(1, -1),
            "item_length": torch.tensor([item_length]),
        }
    )
    scores = model.full_sort_predict(input_interaction.to(model.device))
    scores = scores.view(-1, dataset.item_num)
    scores[:, 0] = -np.inf  # pad item score -> -inf
    topk_score, topk_iid_list = torch.topk(scores, item_history_dict["topk"])

    predicted_score_list = topk_score.tolist()[0]
    predicted_item_list = dataset.id2token(
        dataset.iid_field, topk_iid_list.tolist()
    ).tolist()

    recommended_items = {
        "score_list": predicted_score_list,
        "item_list": predicted_item_list,
    }
    return recommended_items

おつかれさまでした〜

RecBole は主目的こそ実験用、という感じですが中身を見ていくときれいに整備されていて、やってる事自体はそれほど複雑なことになっていません。なので、中身を見れば大体自分がやりたいようにやるためにはどうすればよいか、は結構わかりやすい方だと思います。

ということで、本体に PR を出すほどではないけどやろうとしてみると色々内部実装を見たりいじったりする必要があるという内容でした。
どなたかのお役に立てれば幸いです。

11
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
11
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?