6
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

Vision Transformerでダーツの自動計算をやってみる【追記あり】

Last updated at Posted at 2022-08-02

ダーツ練習用ボード買っちゃった

最近、10年振りに友達とダーツしたんだけど、めっちゃ楽しくて、もっと上手くなりたい!と思って、その日の夜、練習用ボードをAmazonで購入しました。

こんな感じで自室にセッティングしてやってます。静音タイプのボードなので、夜でもやれます。
image.png

でもね、やっぱり自動計算できてゲームできると楽しいよね?
でもね、自動計算できる機能付きボードは相当高いんよね・・・。
AI使って、ダーツが刺さったところを認識できれば、点数計算ソフト作れるな・・・。

というわけで、やってみることにしました。

撮影環境

まずダーツボードから手前に130cm、左に60cmのところに高さ160cmでWebカメラを設置。
image.png

位置関係はこんな感じ。
image.png

学習方針と学習データ撮影

まずは、ダーツが1本だけ刺さった状態の画像で学習し、一定以上の認識率を目指すことにしました。
2本3本刺さった状態の画像は、AIが判定する上位3つを取り出せば上手くいくのでは?と考えたからです。

というわけで、ひたすらダーツを1本だけ投げて(もしくは希望部分に刺してw)撮影を行いました。
こんな感じの角度で640x360サイズのJpegで350枚程度保存しました。
WIN_20220802_12_20_47_Pro.jpg

学習データ、テストデータの準備

次に、データフォルダの準備です。
撮影した画像とその画像の内容(刺さっている位置、シングル/ダブル/トリプル種別)のCSVを作成します。
image.png

識別すべきIDを下記のように考えました。全部で64分類になります。

0・・・何も刺さっていない状態
1・・・"1"のシングル部分に刺さっている状態
2・・・"1"のダブル部分に刺さっている状態
3・・・"1"のトリプル部分に刺さっている状態
4・・・"2"のシングル部分に刺さっている状態
5・・・"2"のダブル部分に刺さっている状態
     ・
     ・
     ・
60・・・"20"のトリプル部分に刺さっている状態
61・・・ブルのシングル部分に刺さっている状態
62・・・ブルのダブル部分に刺さっている状態
63・・・アウトボード

学習データ、テストデータ共に、このIDのサブフォルダを作成し、その下に300x300にクリッピングした画像を収めました。下記のようなフォルダ構成です。

root --- train
           +-- 00
                +-- aaa.jpg
           +-- 01
                +-- bbb.jpg
                +-- ccc.jpg
                +-- ddd.jpg
           :
           :

     +-- test
           +-- 00
                +-- aaa.jpg
           :
           :

フォルダと画像作成用のコードは下記のような感じです。

datamake.py
# -*- coding: utf-8 -*-
import os
import csv
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

train_transforms = transforms.Compose(
    [
        transforms.CenterCrop((300, 300)),
    ]
)

class DartsMakeDataset(Dataset):
    def __init__(self, csv_file, org_dir, transform=None):
        self.file_name = []
        self.no = []
        self.rate = []

        self.transform = transform
        with open(csv_file, encoding='utf8') as f:
            reader = csv.reader(f)
            for i, row in enumerate(reader):
                if i > 0:
                    self.file_name.append(os.path.join(org_dir, row[0]))
                    self.no.append(int(row[1]))
                    self.rate.append(int(row[2]))

    def __len__(self):
        return len(self.file_name)

    def __getitem__(self, idx):
        return idx
    
    def create(self, idx, train_dir, num):
        if self.no[idx] == 0:
            id = 0
        elif self.no[idx] < 22:
            id = self.no[idx] * 3 - 3 + self.rate[idx]
        else:
            id = 63
        
        create_path = os.path.join(train_dir, f'{id:02}')
        if not os.path.isdir(create_path):
            os.makedirs(create_path)

        img = Image.open(self.file_name[idx])
        for i in range(int(num)):
            img_transformed = self.transform(img)
            img_transformed.save(os.path.join(create_path, 
                '{}_{}_{:03}.jpg'.format(
                    id, 
                    self.file_name[idx].split('\\')[-1].split('.')[0], 
                    i+1)), 
                quality=95)

org_data = DartsMakeDataset('pic1.csv', 'pic1', transform=train_transforms)
org_len = len(org_data)
for i in range(org_len):
    org_data.create(i, 'train', 1)

ディープラーニングモデル選定

さあ、いよいよAI部分の検討です。
今回は、最近仕事でも素晴らしい性能を発揮したViT(Vision Transformer)を採用することにしました。
コードは
https://farml1.com/vit/
を参考にさせてもらいました。

vit.py
# -*- coding: utf-8 -*-
from __future__ import print_function
import glob
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import seaborn as sns
import timm

batch_size = 8
epochs = 150
lr = 3e-5
gamma = 0.7
device = 'cuda:0'

train_dir = './train'
test_dir = './test'

train_list = glob.glob(os.path.join(train_dir,'**/*.jpg'), recursive=True)
test_list = glob.glob(os.path.join(test_dir, '**/*.jpg'), recursive=True)

train_transforms = transforms.Compose(
    [
        transforms.RandomCrop(280),
        transforms.Resize(384),
        transforms.RandomRotation(degrees=3),
        transforms.ToTensor(),
    ]
)

test_transforms = transforms.Compose(
    [
        transforms.CenterCrop(280),
        transforms.Resize(384),
        transforms.ToTensor(),
    ]
)

class DartsDataset(Dataset):
    def __init__(self, file_list, transform=None):
        self.file_list = file_list
        self.transform = transform

    def __len__(self):
        self.filelength = len(self.file_list)
        return self.filelength

    def __getitem__(self, idx):
        img_path = self.file_list[idx]
        img = Image.open(img_path)
        img_transformed = self.transform(img)

        label = img_path.split("\\")[-2]
        label = int(label)

        return img_transformed, label

train_data = DartsDataset(train_list, transform=train_transforms)
test_data = DartsDataset(test_list, transform=test_transforms)

train_loader = DataLoader(dataset = train_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset = test_data, batch_size=batch_size, shuffle=False)

print(len(train_data), len(train_loader))
print(len(test_data), len(test_loader))

採用するモデルは学習済みのViTが山ほどあって悩みます。

from pprint import pprint
model_names = timm.list_models(pretrained=True)
pprint(model_names)

とやると、そのリストが表示されるので、その中から色々試してみました。
以下、学習部分のコードです。

vit.py続き
model = timm.create_model('swin_base_patch4_window12_384', pretrained=True, num_classes=64)
model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)
scheduler = StepLR(optimizer, step_size=1, gamma=gamma)

train_acc_list = []
test_acc_list = []
train_loss_list = []
test_loss_list = []
best_loss = 999.0

for epoch in range(epochs):
    epoch_loss = 0
    epoch_accuracy = 0

    for data, label in train_loader:
        data = data.to(device)
        label = label.to(device)

        output = model(data)
        loss = criterion(output, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        acc = (output.argmax(dim=1) == label).float().mean()
        epoch_accuracy += acc / len(train_loader)
        epoch_loss += loss / len(train_loader)

    with torch.no_grad():
        epoch_test_accuracy = 0
        epoch_test_loss = 0
        for data, label in test_loader:
            data = data.to(device)
            label = label.to(device)

            test_output = model(data)
            test_loss = criterion(test_output, label)

            acc = (test_output.argmax(dim=1) == label).float().mean()
            epoch_test_accuracy += acc / len(test_loader)
            epoch_test_loss += test_loss / len(test_loader)

    print(
        f"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f} - test_loss : {epoch_test_loss:.4f} - test_acc: {epoch_test_accuracy:.4f}\n"
    )

    if best_loss > epoch_test_loss:
        best_loss = epoch_test_loss
        torch.save(model.state_dict(), 'darts_vit.model')

    train_acc_list.append(epoch_accuracy)
    test_acc_list.append(epoch_test_accuracy)
    train_loss_list.append(epoch_loss)
    test_loss_list.append(epoch_test_loss)

device2 = torch.device('cpu')

train_acc = []
train_loss = []
test_acc = []
test_loss = []

for i in range(epochs):
    train_acc2 = train_acc_list[i].to(device2)
    train_acc3 = train_acc2.clone().numpy()
    train_acc.append(train_acc3)
    
    train_loss2 = train_loss_list[i].to(device2)
    train_loss3 = train_loss2.clone().detach().numpy()
    train_loss.append(train_loss3)
    
    test_acc2 = test_acc_list[i].to(device2)
    test_acc3 = test_acc2.clone().numpy()
    test_acc.append(test_acc3)
    
    test_loss2 = test_loss_list[i].to(device2)
    test_loss3 = test_loss2.clone().numpy()
    test_loss.append(test_loss3)

sns.set()
num_epochs = epochs

fig = plt.subplots(figsize=(12, 4), dpi=80)

ax1 = plt.subplot(1,2,1)
ax1.plot(range(num_epochs), train_acc, c='b', label='train acc')
ax1.plot(range(num_epochs), test_acc, c='r', label='val acc')
ax1.set_xlabel('epoch', fontsize='12')
ax1.set_ylabel('accuracy', fontsize='12')
ax1.set_title('training and val acc', fontsize='14')
ax1.legend(fontsize='12')

ax2 = plt.subplot(1,2,2)
ax2.plot(range(num_epochs), train_loss, c='b', label='train loss')
ax2.plot(range(num_epochs), test_loss, c='r', label='val loss')
ax2.set_xlabel('epoch', fontsize='12')
ax2.set_ylabel('loss', fontsize='12')
ax2.set_title('training and val loss', fontsize='14')
ax2.legend(fontsize='12')
plt.show()

学習結果

150エポック程学習させた結果が下記のグラフです。
学習データの正答率が100%に到達するのに対して、テストデータでは70%程度でした。しかも若干過学習気味です。
このことから、まだまだ学習データが足りない様子がわかるので、頑張って追加撮影します!w
image.png

まぁ、自動採点機能付きのアプリ、という点では、多少間違っても訂正機能があれば使えそうです。
まだダーツ1本での評価なので、この後、2本、3本同時に刺さっている場合の評価は気になります。
上手く学習が進めば、自動採点アプリの完成も夢ではないように思えてきました。

続く・・・w

追記

学習データを272→792に増やして学習させた結果、なんと正答率95%を達成しました!
間違えたデータも、隣の数字だったり、シングルとダブルの間違いだったりと、それなりに納得の結果です。
image.png

というわけで、訂正機能を考慮すれば、ひとまずこれで運用できそうと判断し、次は2本刺さった状態、3本刺さった状態の検証に進みます!

・・・とここで嫌なケースが頭をよぎりました。
当初仮説は、AI出力のTOP3を拾うことで3本の刺さった位置を判定しようと考えてましたが、よく考えると、同じところに2本以上刺さった場合、どう判断すべきなのか・・・これは困ったことにw
もう少しロジック考えます!

続く・・・w

追記2

やはりダメでした。2本目3本目が刺さると、正しく認識してくれません。
というわけで、マルチラベルに切り替えます。

0~63:重複無し3本
64-127:同一箇所2本+別の箇所1本
128~191:同一箇所3本

として、最大3ヶ所のマルチラベル対応に切り替えます。
損失関数はBCELossに変更。one-hot配列からn-hot配列に変更、Accuracyは計算が難しくなるので、Lossのみで判断としました。

さて、学習データが大変。2本刺さっている画像の組み合わせだけでも2000枚以上、3本だと4万枚以上・・・。
もう実際にボードに挿して撮影、というわけにはいかなくなってきました。
そこで、1本画像2枚もしくは3枚を合成して学習データを大量に作ることにしました。

ベースとなる1本画像のダーツが刺さっている矩形の位置を記録します(ひとまず100枚程度)。
くり抜き位置決定用のコードです。

# -*- coding: utf-8 -*-
import os
import csv
from re import A
from PIL import Image, ImageDraw, ImageFont
from torch.utils.data import Dataset
from torchvision import transforms

train_transforms = transforms.Compose(
    [
        transforms.CenterCrop((300, 300)),
    ]
)

class DartsMakeDataset(Dataset):
    def __init__(self, csv_file, org_dir, transform=None):
        self.file_name = []
        self.no = []
        self.rate = []
        self.org_fn = []

        self.transform = transform
        with open(csv_file, encoding='utf8') as f:
            reader = csv.reader(f)
            for i, row in enumerate(reader):
                if i > 0 and len(row) == 3:
                    self.file_name.append(os.path.join(org_dir, row[0]))
                    self.no.append(int(row[1]))
                    self.rate.append(int(row[2]))
                    self.org_fn.append(row[0])

    def __len__(self):
        return len(self.file_name)

    def __getitem__(self, idx):
        return idx
    
    def clipbox(self, idx, outfn):
        if self.no[idx] == 0:
            return
        
        x = 170 + 150 + 8
        y = 30 + 100 + 28
        while True:
            img = Image.open(self.file_name[idx])
            draw = ImageDraw.Draw(img)
            for i in range(7):          
                draw.line((170+i*50, 30, 170+i*50, 330), fill=(255, 0, 0), width=1)
            for i in range(7):          
                draw.line((170, 30+i*50, 470, 30+i*50), fill=(255, 0, 0), width=1)

            draw.rectangle((x, y, x+45, y+25), fill=None, outline=(0, 255, 0))

            for i in range(6):       
                draw.multiline_text((170+i*50, 30), f'{170+i*50}', fill=(0, 0, 0))
            for i in range(5):       
                draw.multiline_text((170, 80+i*50), f'{80+i*50}', fill=(0, 0, 0))

            draw.multiline_text((x,y-10), f'({x},{y})', fill=(255, 0, 255))

            img_crop = img.crop((170,30,170+300,30+300))
            img_resize = img_crop.resize((900,900))
            
            print(f'({x},{y})')
            img_resize.show()
            while True:
                xy = input('x,y(or "ok") >')
                if xy == 'ok':
                    with open(outfn, 'a', encoding='utf8') as f:
                        f.write(f'{self.org_fn[idx]},{self.no[idx]},{self.rate[idx]},{x},{y}\n')
                    return

                lxy = xy.split(',')
                if len(lxy) != 2:
                    continue

                if (not lxy[0].isdigit()) or (not lxy[1].isdigit()):
                    continue

                if int(lxy[0]) < 170 or int(lxy[0]) > 470-45 or int(lxy[1]) < 30 or int(lxy[1]) > 330-25:
                    continue

                x = int(lxy[0])
                y = int(lxy[1])
                break
        
org_data = DartsMakeDataset('mix2.csv', 'pic1', transform=train_transforms)
org_len = len(org_data)
for i in range(org_len):
    org_data.clipbox(i, 'mix2tmp.csv')

あとはランダムに2~3枚を選んで、それぞれの矩形が重ならなければ合成します。
合成用コードです。

# -*- coding: utf-8 -*-
import os
import csv
import random
from PIL import Image, ImageDraw
from torch.utils.data import Dataset
from torchvision import transforms

train_transforms = transforms.Compose(
    [
        transforms.CenterCrop((300, 300)),
    ]
)

class DartsMakeDataset(Dataset):
    def __init__(self, csv_file, org_dir, transform=None):
        self.file_name = []
        self.no = []
        self.rate = []
        self.cx = []
        self.cy = []

        self.transform = transform
        with open(csv_file, encoding='utf8') as f:
            reader = csv.reader(f)
            for i, row in enumerate(reader):
                if i > 0:
                    self.file_name.append(os.path.join(org_dir, row[0]))
                    self.no.append(int(row[1]))
                    self.rate.append(int(row[2]))
                    self.cx.append(int(row[3]))
                    self.cy.append(int(row[4]))

    def __len__(self):
        return len(self.file_name)

    def __getitem__(self, idx):
        return idx
    
    def mix2(self, maxlen, train_dir, num):
        cnt = 1
        while True:
            a = random.randint(0, maxlen-1)
            b = random.randint(0, maxlen-1)
            ax = self.cx[a]
            ay = self.cy[a]
            bx = self.cx[b]
            by = self.cy[b]

            # マスク領域が重なっているか
            if ((ax >= bx and ax <= bx+45) or (ax+45 >= bx and ax+45 <= bx+45)) and ((ay >= by and ay <= by+25) or (ay+25 >= by and ay+25 <= by+25)):
                continue

            if self.no[a] == 0:
                id_a = 0
            elif self.no[a] < 22:
                id_a = self.no[a] * 3 - 3 + self.rate[a]
            else:
                id_a = 63
            
            if self.no[b] == 0:
                id_b = 0
            elif self.no[a] < 22:
                id_b = self.no[b] * 3 - 3 + self.rate[b]
            else:
                id_b = 63

            if id_a == id_b:
                same1 = 2
                id1 = id_a
                same2 = 0
                id2 = 0
                same3 = 0
                id3 = 0
            else:
                same1 = 1
                id1 = id_a
                same2 = 1
                id2 = id_b
                same3 = 0
                id3 = 0

            print(f'{cnt:05} create mix tow image. id: {id_a} + {id_b}')
            
            img_a = Image.open(self.file_name[a])
            img_b = Image.open(self.file_name[b])

            mask = Image.new("L", img_a.size, 0)
            draw = ImageDraw.Draw(mask)
            draw.rectangle((ax, ay, ax+45, ay+25), fill=255)
            img = Image.composite(img_a, img_b, mask)

            img_transformed = self.transform(img)
            img_transformed.save(os.path.join(train_dir, 
                f'{20000+cnt:05}_{same1}_{id1}_{same2}_{id2}_{same3}_{id3}_.jpg'), 
                quality=95)
            cnt += 1
            if cnt == num:
                break

org_data = DartsMakeDataset('mix2.csv', 'pic1', transform=train_transforms)
org_data.mix2(len(org_data), 'test2', 100)

image.png
こんな感じ、割といい感じです。1本画像のダーツ位置は判明しているので、ファイル名に位置情報を含ませます。
そうすることで、学習時にファイル名からマルチラベルを作成可能となります。

さぁ、これから学習開始です。20エポック程度回した時点では、正答率70%を超えてましたので期待できそうです!

続く・・・w

6
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
6
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?