4
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

GNNを使って競馬の予測AIの作成にチャレンジしてみた

Posted at

はじめに

不労所得での生活は全人類の夢です。そんな夢を叶えるためにグラフニューラルネットワークを使って競馬の予測AIにチャレンジしてみました。なお、私自身もそんなにGNNに詳しいわけではないので間違えがあれば教えていただけると助かります。

GNN(グラフニューラルネットワーク)とは

そもそもGNNって何よって人のために簡単な説明をすると

この画像のようにノード(まる)とエッジ(線)からなる構造のことをグラフ構造といいます。
グラフニューラルネットワークとはこの各ノードがもつ情報を伝言ゲームのように隣のノードに伝えていくことで、各ノードの情報を伝搬し、学習するモデルになります。
わかりにくい場合はCNNのグラフ版と思っていてくれればそれで大丈夫です。

詳しい解説を知りたい人はこちら

が非常にわかりやすいです。

GNNを使う目的

競馬のレースをグラフとして考えた時に

このような馬をノード、馬同士の関係性をエッジとしたようなグラフとすることで、全ての馬の情報を考慮した上での勝つ馬が予測できるのではないかと考えてGNNを採用しました。

アプローチ方法

今回GNNでの予測方法としてのアプローチ方法としてグラフ分類を考えました。

最大の馬数を18頭と仮定したときに、正解ラベルとして一位になった馬の番号を与えることで、18クラスのグラフ分類問題として設定しました。

ひとまず期待値などは考えず、一位の馬を予測することだけを目的としました。

全体の流れ

  1. レース情報と馬の情報をnetkeiba.comからスクレイピング
  2. 全体のデータを前処理
  3. 各レースごとにグラフを作成
  4. GCNを使って学習を行いモデルを作成
  5. 最終的な予測精度を確認する

スクレイピング

qiitaには優秀な人たちがすでに素晴らしいスクレイピングのコードを上げてくださっているのでここでは割愛します。
スクレイピングの範囲は2008〜2021の13年間としました。

利用する特徴量

特徴量名 概要 エンコード方法
number 枠番 数値
horse_num 馬番 数値
horse_name 馬名 ordinary encoding
jockey 騎手 ordinary encoding
trainer 調教師 ordinary encoding
sex 性別 onehot encoding
old 年齢 数値
weight 斤量 数値
horse_weight 馬体重 数値
horse_weight_diff 馬体重の変化量 数値
win_rate 単勝オッズ 数値
popularity 人気 数値
sub_title 何歳以上何百万以下など onehot encoding
round 何ラウンド目か 数値
race_type 芝かダートか onehot encoding
condition 馬場の状態 onehot encoding
weather 天気 onehot encoding
race_line inかoutか onehot encoding
distance 距離 onehot encoding
location どこ開催か onehot encoding
leg 脚質 数値
total winning その時点での獲得賞金 数値
last day 前回から何日経過したか 数値

前処理

基本的にはエンコード方法の通りに行っています。

グラフの作り方

基本的にはレースごとに前処理をした後networkxを使ってグラフに変換します。

前処理にはcategory_encodersを利用しました。

後で考えると全体に前処理してからグラフにしたほうが良かった気がする。

使用モデル

グラフニューラルネットワークのライブラリであるDeep Graph Libraryを使用します。
訓練データを2020年までのデータ、テストデータを2021年のデータとしました。
モデルのコードはDGLのグラフ分類のサンプルコードをそのまま流用しました。

model.py
import dgl
import dgl.nn.pytorch as dglnn
import torch.nn as nn
import torch.nn.functional as F

class GNNClassifier(nn.Module):
    # 馬の最大頭数は18頭なので分類は18classとする
    def __init__(self, in_feat=in_feat, hidden_feat=hidden_feat, n_classifier=18):
        super(GATClassifier, self).__init__()
        self.conv1 = dglnn.GraphConv(in_feat, hidden_feat)
        self.conv2 = dglnn.GraphConv(hidden_feat, hidden_feat)
        self.fc1 = nn.Linear(hidden_feat, n_classifier)

    def forward(self, g, h):
        # Apply graph convolution and activation.
        h = F.relu(self.conv1(g, h))
        h = F.relu(self.conv2(g, h))
        with g.local_scope():
            g.ndata['h'] = h
            # Calculate graph representation by average readout.
            hg = dgl.mean_nodes(g, 'h')
            x = self.fc1(hg)
            return x

また、訓練のコードは

train.py
import torch
import torch.nn as nn
import torch.nn.functional as F

import dgl
from dgl.dataloading import GraphDataLoader

def train(dataset, model, criterion, optim, batch_size, epochs, device):

    train_dataloader = GraphDataLoader(dataset, batch_size=batch_size, drop_last=True, shuffle=True)

    model = model.to(device)
    criterion = criterion
    optim = optim

    min_loss = 1e9

    model.train()

    for epoch in tqdm(range(epochs)):
        running_loss = 0
        total = 0
        correct = 0
        for batched_graph, labels in train_dataloader:
            optim.zero_grad()

            feats = batched_graph.ndata['feat'].to(device)
            outputs = model(batched_graph, feats)
            labels = labels.to(device)

            loss = criterion(outputs, labels)

            running_loss += loss.item()

            loss.backward()
            optim.step()

        print('epoch :{}'.format(epoch+1))
        print('loss     :{}'.format(running_loss))

        if min_loss >= running_loss:
            model_path = './data/weight'
            torch.save(model.state_dict(), model_path)

動かした結果

最終的な予測精度は7%程度ととても使い物にならない精度を叩き出しました。

原因として考えられるものとしては

  • グラフの作り方が悪い
    • レースごとに前処理してからグラフにするのではなく全体を前処理してからグラフにしてみる
  • 前処理に問題がある
    • めんどくさくて正規化標準化してないのでおそらくしたほうがいい
  • 正解率の計算に問題がある
    • GNNの正解率の計算ぶっちゃけよくわからない

この辺が原因じゃないかなと思ってます。

今後試してみること

  • 正規化標準化する
  • グラフの作り方の見直し
  • グラフ分類するのではなくノード分類としてノードごとの順位を予測してみる
  • 18頭のレースのみのデータのみに絞ってみる

終わりに

それでもGNNはいつか最強の精度を叩き出してくれると信じています。

この記事はまだ途中なので今後も更新していきます。
よろしければ引き続き見守ってくださると嬉しいです。

更新情報

  • 2021/12/23 記事をアップ
4
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?