0
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

【グラフニューラルネットワーク】PyTorch Geometricで異種混合グラフのリンク予測をやってみた

Posted at

経緯

Link Prediction on Heterogeneous Graphs with PyGを読んだので、理解を深めるために別のデータセットでリンク予測にチャレンジしてみました。

この記事で扱う問題

利用するデータセット

データセットはPyGに収録されているTaobaoデータセットを利用します。データの詳しい内容は提供元のサイトに記載されています。

このデータセットにはユーザ、商品、カテゴリのノードがあり、ユーザと商品間のエッジはユーザの行動を表し、商品とカテゴリ間のエッジは商品が特定カテゴリに属することを意味します。ユーザの行動には「閲覧」「購入」「カート追加」「お気に入り」があり、各アクションにはタイムスタンプも付与されています。

問題設定

2017/11/25 1:00 ~ 2017/12/3 0:59の購入履歴を使い、2017/12/3 1:00 ~ 2017/12/4 0:59に発生する購入アクションを予想します。
なお、簡単のためにカテゴリノードや他のユーザ行動エッジは利用しないことにします。

主要ライブラリのバージョン

torch==2.4.1+cu124
torch-geometric==2.6.1

ライブラリのインポートなど

まずはライブラリのインポートとGPUの準備などをしておきます。

import pandas as pd
import numpy as np
from tqdm import tqdm
from datetime import datetime
from sklearn.model_selection import train_test_split

# PyTorch
import torch
from torch import nn
from torch.optim import Adam
from torch.nn import functional as F

# PyG
from torch_geometric import transforms as T
from torch_geometric.data import HeteroData
from torch_geometric.datasets import Taobao
from torch_geometric.loader import LinkNeighborLoader
from torch_geometric.utils import negative_sampling
from torch_geometric.nn import SAGEConv, to_hetero

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

データの取得

最初に実行するときはデータセットのダウンロード処理が走るので少し時間がかかります。Taobaoの引数にはデータセットをダウンロード先を指定します。

dataset = Taobao(root="./data/Taodao")
raw_data = dataset[0]
print(raw_data)

出力は以下のようになります。異種混合グラフ(HeteroData)であることがわかります。

HeteroData(
  user={ num_nodes=987991 },
  item={ num_nodes=4161138 },
  category={ num_nodes=9437 },
  (user, to, item)={
    edge_index=[2, 100095182],
    time=[100095182],
    behavior=[100095182],
  },
  (item, to, category)={ edge_index=[2, 4162555] }
)

ノードやエッジには辞書型のデータと同様の方法でアクセスすることができます。

raw_data["user"]  # userノード
raw_data["user", "to", "item"]  # user->(to)->itemエッジ

特にエッジについてedge_indexを見るとユーザから商品への対応を確認することができます。例えば[[0,1,2],[3,4,5]]であれば、ユーザ0→商品3、ユーザ1→商品4、ユーザ2→商品5の3本のエッジを意味します。

print(raw_data["user", "to", "item"].edge_index)
出力
tensor([[      0,       0,       0,  ...,  970447,  970447,  970447],
        [1827766, 1880345, 2076699,  ..., 2939548, 1534057, 2978718]])

データの加工

今回はユーザ行動のエッジの内で購入アクションのものだけを利用します。データセット提供元のサイトを確認すると(user, to, item)のbehaviorがユーザ行動で(pv, buy, cart, fav)の列挙型で表現されているようです。

user2item_edge = raw_data["user", "to", "item"]
buy_mask = user2item_edge.behavior == 2

# 購入アクションのエッジのみ取得
user2item_idx = user2item_edge.edge_index[:, buy_mask]
user2item_time = user2item_edge.time[buy_mask]

次にtimestampを使ってデータ分割用のマスクを作成します。

  • 訓練データ:2017/11/25 1:00~ 2017/12/1 0:59
  • 訓練ラベル:2017/12/2 1:00 ~ 2017/12/3 0:59
  • テストデータ:2017/12/3 1:00 ~ 2017/12/4 0:59

今回はモデルのアーキテクチャやパラメータの調整はしないので検証データは作成しません。(学習時の性能チェック用に検証データを作成しておくべきだった...)

# 2017/12/2 01:00 ~ 2017/12/3 0:59はlabelデータ
label_border = datetime.timestamp(datetime(2017, 12, 2, 1, 0, 0))
# 2017/12/3 01:00以降はtestデータ
test_border = datetime.timestamp(datetime(2017, 12, 3, 1, 0, 0))

# masks
train_mask = user2item_time < label_border
train_label_mask = (user2item_time >= label_border) & (user2item_time < test_border)
test_mask = user2item_time >= test_border

データの作成

訓練用データセットとテスト用データセットを作成します。ノードは時間経過で増えないものと仮定して訓練データおよびテストデータで共通とします。元のデータセットではノードにIDは割り振られておらず、ノード数のみ持っていたので連番で付与します。

# ノード数を取得
num_src_nodes = raw_data["user"].num_nodes
num_dst_nodes = raw_data["item"].num_nodes
# Nodes
user_node_idx = torch.arange(num_src_nodes)
item_node_idx = torch.arange(num_dst_nodes)

訓練データを作成します。ノードは先ほど作成したものを設定し、エッジは(user, buy, item)という名前で登録します。edge_indexおよびedge_label_indexにはユーザから商品へのエッジを表すtorch.tensorを設定します。

今回の問題はリンク予測なので予測対象のエッジを付与する必要があり、edge_label_indexに設定します。そして、これらが正例なのか負例なのかを表現したものがedge_labelです。ここではすべて実際に存在するエッジをedge_label_indexに設定しているのですべて正例です。

train_data = HeteroData()
train_data["user"].node_id = user_node_idx
train_data["item"].node_id = item_node_idx

train_data["user", "buy", "item"].edge_index = user2item_idx[:, train_mask]
train_data["user", "buy", "item"].edge_label = torch.ones(train_label_mask.sum())
train_data["user", "buy", "item"].edge_label_index = user2item_idx[:, train_label_mask]

# 無向グラフ化
undirected_transformer = T.ToUndirected()
train_data = undirected_transformer(train_data)

テストデータを作成します。おおよそ訓練データと同様です。
テストデータを使うときは2017/12/3 0:59までのデータが出そろっている状態と考えることができるので、グラフのエッジには訓練データとそのラベルを含めることができます。

test_data = HeteroData()
test_data["user"].node_id = user_node_idx
test_data["item"].node_id = item_node_idx

test_data["user", "buy", "item"].edge_index = user2item_idx[:, train_mask+train_label_mask]
test_data["user", "buy", "item"].edge_label = torch.ones(test_mask.sum()) 
test_data["user", "buy", "item"].edge_label_index = user2item_idx[:, test_mask]

# 無向グラフ化
test_data = undirected_transformer(test_data)

データローダの作成

作成したデータを一気にモデルに入れるとメモリ不足になってしまったので、データローダを作成してミニバッチ学習できるようにします。ただし、グラフデータの場合を単純に分割できないため専用のデータローダを使います。今回はリンク予測なのでLinkNeighborLoaderを利用します。

また、上記のデータ作成ではedge_labelはすべて1でしたが、そのままでは負例が学習できないのでネガティブサンプリング(グラフ上に存在しないエッジを取得してラベル値を0とする)をします。
テストデータに対しては不要かもしれませんが、これを指定しないとラベル情報が消えてしまう(?)のでとりあえずつけています。

train_edge_label_idx = train_data["user", "buy", "item"].edge_label_index
train_loader = LinkNeighborLoader(
    data=train_data,
    num_neighbors=[20,10],  # サンプリングする近傍の数(user:20, item:10)
    neg_sampling_ratio=2.0,  # ネガティブサンプリングする割合
    edge_label_index=(("user", "buy", "item"), train_edge_label_idx),
    batch_size=512,
    shuffle=True
)

test_edge_label_idx = test_data["user", "buy", "item"].edge_label_index
test_loader = LinkNeighborLoader(
    data=test_data,
    num_neighbors=[20,10],  # サンプリングする近傍の数(user:20, item:10)
    neg_sampling_ratio=1.0,  # ネガティブサンプリングする割合
    edge_label_index=(("user", "buy", "item"), test_edge_label_idx),
    batch_size=1024,
    shuffle=True
)

モデルの定義

ここはPyGによる異種グラフのリンク予測チュートリアルからほとんど変えていません。今回はエッジに特徴量をつけていないのでModel.forwardの一部を変更しました。

ユーザと商品には特徴量がないので、モデル内で埋め込みレイヤーを設定しています。次のようなフローで予測をします。

  1. 入力ノードの埋め込みベクトルを取得
  2. 埋め込みベクトルおよびエッジを使って畳み込み
  3. ユーザベクトルと商品ベクトルの内積を計算
class GNN(torch.nn.Module):
    def __init__(self, hidden_channels):
        super().__init__()
        self.conv1 = SAGEConv(hidden_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, hidden_channels)
    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        x = self.conv2(x, edge_index)
        return x

class Classifier(torch.nn.Module):
    def forward(self, x_user, x_item, edge_label_index):
        edge_feat_user = x_user[edge_label_index[0]]
        edge_feat_item = x_item[edge_label_index[1]]
        return (edge_feat_user * edge_feat_item).sum(dim=-1)

class Model(torch.nn.Module):
    def __init__(self, hidden_channels, src_dim, dst_dim, metadata):
        super().__init__()
        self.user_emb = torch.nn.Embedding(src_dim, hidden_channels)
        self.item_emb = torch.nn.Embedding(dst_dim, hidden_channels)

        self.gnn = GNN(hidden_channels)
        self.gnn = to_hetero(self.gnn, metadata=metadata)  # 異種混合グラフに対応させる
        self.classifier = Classifier()

    def forward(self, data):
        x_dict = {
            "user": self.user_emb(data["user"].node_id),
            "item": self.item_emb(data["item"].node_id),
        }

        x_dict = self.gnn(x_dict, data.edge_index_dict)
        pred = self.classifier(
            x_dict["user"], 
            x_dict["item"],
            data["user", "buy", "item"].edge_label_index
        )
        
        return pred

# モデルのインスタンス化
metadata = train_data.metadata()
model = Model(hidden_channels=32, src_dim=num_src_nodes, dst_dim=num_dst_nodes, metadata=metadata)
model = model.to(device)

学習

このフェーズはPyTorchを使った学習と同様です。(本当は検証データを作成して汎化性能もウォッチしておくべきですが今回は省略しています。)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
epochs = 5
for epoch in range(1, epochs+1):
    total_loss = 0
    total_examples = 0
    epoch_preds = []
    epoch_labels = []
    for batch in tqdm(train_loader):
        optimizer.zero_grad()
        batch = batch.to(device)
        pred = model(batch)
        label = batch["user", "buy", "item"].edge_label
        loss = F.binary_cross_entropy_with_logits(pred, label)
        loss.backward()
        optimizer.step()
        
        total_loss += float(loss) * pred.shape[0]
        total_examples += pred.shape[0]
        
        # epoch全体の予測とラベルを記録
        epoch_preds.append(pred)
        epoch_labels.append(label)

    # 評価
    preds = torch.cat(epoch_preds, dim=0).cpu().detach().numpy() > 0
    labels = torch.cat(epoch_labels, dim=0).cpu().numpy()
    acc = (preds == labels).mean()
    
    print(f"Epoch: {epoch}, Train Loss: {total_loss/total_examples:.4f}, Train ACC: {acc:.4f}")

最後にテストデータを使った評価をします。データローダ作成時に記載しましたが、諸事情でテストデータに対してもネガティブサンプリングをかけてしまっているので、評価時に正例に対する予測だけを抽出して正解率を計算しています。(実質Recall)

test_preds = []
test_labels = []
for batch in tqdm(test_loader):
    with torch.no_grad():
        batch = batch.to(device)
        test_preds.append(model(batch))
        test_labels.append(batch["user", "buy", "item"].edge_label)
        
preds = torch.cat(test_preds, dim=0).cpu().numpy() > 0
labels = torch.cat(test_labels, dim=0).cpu().numpy()

# テストデータに対してもネガティブサンプリングしているので正例に対する予測を抽出
preds = preds[labels.astype(bool)]
acc = preds.mean()
print(f"Test ACC:{acc:.4f}")
出力
Test ACC:0.6570

感想

参考にした記事の内容はある程度すんなり理解できたので、データセットを変えてもそれほど苦労はしないだろうと思っていましたが、実際にやってみるとわからないことが多く躓きました。特に今回は時系列を考慮して、データ分割にtransforms.RandomLinkSplitを使わなかったのですが、そうすると後続のデータローダ作成時に設定するedge_label_indexがどこにもなく「これはどこから取ってきたらよいのか」「そもそもこれは何なのか」となりました。
やはり自分で考えてアウトプットしてみるというのは大切だと改めて思いました。正直に言うと、まだ分かっていないこともあるので勉強を進めていこうと思います。(例えば訓練時にネガティブサンプリングをしたけど将来発生するであろう正例を負例として学習してしまわないのか?など)

記事の内容についてアドバイスなどあればお気軽にコメントしていただけると嬉しいです!
ここまでお読みいただきありがとうございました!

0
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?