LoginSignup
10
6

More than 3 years have passed since last update.

BERTでアイテムベースレコメンドを試みる

Last updated at Posted at 2020-03-24

はじめに

動機
ユーザの行動ログを一つの文章と見立て、文章の文脈を捉えることが可能なBERTで特徴を抽出することで
これまでの協調フィルタとは異なったレコメンドが可能ではないかと考えました。

概要
ECサイトに訪問したユーザの同一セッション内の閲覧ログを一つの文章と捉えてBERTでベクトル化します。
クエリとなる閲覧ログを元に類似の閲覧ログをfaissを使って検索します。

方法
1. Google Merchanise Store の閲覧ログをBigQuery (BQ) の公開データセットから取得
2. 取得した閲覧ログでBERTを学習
3. 訓練したモデルで閲覧ログをベクトル化
4. 似た傾向の閲覧ログをfaissを使って検索

環境
Google Colaboratory のTPU環境を使用して実装します。

免責事項
本記事は、技術検証と学習のために作成したもので実用性は保証しません。。m(*_ _)m

データの用意

BQの公開データからGAのサンプルログを取得

クエリを作成して必要なるデータだけを抽出したベースのデータセットを作成します。
データの詳細は→https://support.google.com/analytics/answer/3437719?hl=ja
bq2020-03-21.png

BigQuery Storage APIを使ってBQからデータを取得

ここからcolabを使用します。
まずはユーザの認証を行います。

from google.colab import auth
auth.authenticate_user()

bqstorage apiをインストールします。


#インストール後にRuntimeの再起動が必要です
pip install --upgrade google-cloud-bigquery[bqstorage,pandas]

BQのベースとなるデータセットからデータフレーム形式でデータを取得します。

%%bigquery df --use_bqstorage_api --project {プロジェクトID}
CREATE TEMP FUNCTION replace_space(x string)
RETURNS string
LANGUAGE js AS """
  return x.replace(/ /g,'-');
""";
SELECT 
  visitId
  , productSKU
  , v2ProductName
  , v2ProductCategory
FROM(
  SELECT  
    visitId
    , row_number() over(partition by visitId order by date desc ) as itm_cnt
    , replace_space(product.v2ProductName) as v2ProductName
    , product.productSKU
    , product.v2ProductCategory 
  FROM 
    {データセット}
  WHERE date >= '20170701'
)
WHERE itm_cnt >= 3 --閲覧商品数3以上とします
limit 10000 --テスト実行はデータを限定します
;

BERTの入力ファイルを作成

sample.txt

#ユーザの閲覧ログを一つ文字列として結合します。
data = (
          df.groupby('visitId')['v2ProductName']
          .apply(list)
          .apply(lambda x:sorted(x))
          .apply(' '.join)
       )
#sample.txt
sample = pd.DataFrame(data).reset_index()

vocab.txt

#以下を追加します
# [PAD]
# [UNK]
# [CLS]
# [SEP]
# [MASK]
vocab_ = pd.DataFrame({'vocab': ['[PAD]', '[UNK]', '[CLS]','[SEP]','[MASK]']})
vocab1 = pd.DataFrame({'vocab':df['v2ProductName'].unique()})
vocab = vocab_.append(vocab1)

ファイル出力

#File output
sample['v2ProductName'].to_csv('./sample.txt', sep='\t' , header=None, index=False, encoding='utf8')
vocab.to_csv('./vocab.txt', sep='\t', header=None, index=False, encoding='utf8')

BERTの学習

BERTソースコードを取得

extract_features.pyをTPUで実装するためにissueに従ってコードを差し替えます。
BERTをlocalへダウンロード → コードを修正 → GCSアップロード → colabへダウンロード


!git clone https://github.com/google-research/bert

# こちらを修正→ https://github.com/google-research/bert/pull/758 
# GCS→colabへのダウンロードコマンド
!gsutil cp -r gs://{Path}/bert ./

学習データの作成(MASK処理)

!python ./bert/create_pretraining_data.py --input_file=./sample.txt --output_file=./tf_examples.tfrecord --vocab_file=./vocab.txt --do_lower_case=True --max_seq_length=128 --max_predictions_per_seq=20 --masked_lm_prob=0.15 --random_seed=12345 --dupe_factor=5

モデル定義

こちらを参考にさせて頂きました。
https://github.com/FeiSun/BERT4Rec/blob/master/bert_train/bert_config_ml-20m_256.json

import json

config = {
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 256,
  "initializer_range": 0.02,
  "intermediate_size": 1024,
  "max_position_embeddings": 200,
  "num_attention_heads": 8,
  "num_hidden_layers": 2,
  "type_vocab_size": 2,
  "vocab_size": len(vocab.index),
}

with open('bert_config.json', 'w') as f:
    json.dump(config, f, ensure_ascii=False)

モデルの学習

TPUで実行するためにモデルの出力先をGCSにします。
${TPU_NAME}がTPUのアドレスが格納された環境変数です。


!python bert/run_pretraining.py --input_file=./tf_examples.tfrecord --output_dir=gs://{GCS PATH}/model --do_train=True --do_eval=True --bert_config_file=./bert_config.json  --train_batch_size=16 --max_seq_length=128 --max_predictions_per_seq=20 --num_train_steps=100 --num_warmup_steps=50 --lerning_rate=2e-5 --use_tpu=True --tpu_name=${TPU_NAME}

閲覧ログのベクトル化

モデルの学習に使った閲覧ログデータから重複を除き入力ファイルとします。

sample_test = sample.drop_duplicates()
pd.DataFrame(sample_test['v2ProductName']).to_csv('./sample_test.txt', sep='\t')

閲覧ログデータをベクトル表現します。

!python ./bert/extract_features.py --input_file=./sample_test.txt --output_file=gs://${GCS PATH}/output.jsonl --vocab_file=./vocab.txt --bert_config_file=bert_config.json --init_checkpoint=gs://${GCS PATH}/model --use_tpu=True --master=${TPU_NAME} --do_lower_case False --layers -2  

faissを使った近傍検索

faissをインストール

#!pip install faiss-gpu --no-cache
!pip install faiss-cpu --no-cache

ベクトル化したデータをGCSからダウンロードします。

!gsutil cp  gs://${GCS PATH}/output.jsonl ./

ベクトル化した閲覧ログをfaissに格納

こちらのサイトを参考にさせて頂きました。
https://datanerd.hateblo.jp/entry/2019/03/01/134310

import json
import numpy as np
import faiss

file ="output.jsonl"
data = []

with open(file) as f:
  for line in f:
    tmp = json.loads(line)
    tmp = np.mean([d['layers'][0]['values'] for d in tmp['features']], axis=0)
    data.append(tmp)

index = faiss.IndexFlatL2(data[0].shape[0])
index.add(np.array(data, dtype=np.float32))
faiss.write_index(index, "faiss_bert.faiss")

近傍検索

クエリとなる閲覧ログと類似の閲覧ログからレコメンド結果として20商品を取得します。

def get_rec_itm(input_file,query_num,itm_cnt):
    with open(input_file) as f:
        data = [line.strip() for line in f]

    print("[query log:", data[query_num], "]\n")
    index = faiss.read_index("./faiss_bert.faiss")
    vec = index.reconstruct(query_num)
    D,I = index.search(np.array([vec]), k=50)

    #get 20 unique items 
    res = ' '.join([data[i] for i in I[0]])
    res=sorted(set(res.split(' ')), key=res.index)
    return(res[:itm_cnt])

「ベクトル化に使用した入力ファイル」、「クエリとするログの行番号(+1)」、「取得する商品数」を引数にします。

get_rec_itm("sample_test.txt",42,20)

出力結果

# クエリログ
[query log: 41  25L-Classic-Rucksack Google-Alpine-Style-Backpack Waterproof-Backpack ]

# レコメンド商品
["726\tGoogle-Men's-Performance-Full-Zip-Jacket-Black",
 '682\t25L-Classic-Rucksack',
 '25L-Classic-Rucksack',
 'Google-Alpine-Style-Backpack',
 'Waterproof-Backpack',
 "131\tGoogle-Women's-Short-Sleeve-Hero-Tee-White",
 "268\tGoogle-Men's-Long-Sleeve-Raglan-Ocean-Blue",
 "Google-Men's-Long-Sleeve-Raglan-Ocean-Blue",
 "170\tGoogle-Men's-Watershed-Full-Zip-Hoodie-Grey",
 "Google-Men's-Watershed-Full-Zip-Hoodie-Grey",
 '41\t25L-Classic-Rucksack',
 '396\t25L-Classic-Rucksack',
 '702\tGoogle-Toddler-Raglan-Shirt-Blue-Heather/Navy',
 '257\tGoogle-Alpine-Style-Backpack',
 'Google-Rucksack',
 '510\tGoogle-Car-Clip-Phone-Holder',
 'Google-Car-Clip-Phone-Holder',
 'Waze-Mobile-Phone-Vent-Mount',
 '239\tGoogle-Stylus-Pen-w/-LED-Light',
 '305\tGift-Card--$100.00']
10
6
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
10
6