皆さん、スプラトゥーンは好きですか?私はスプラトゥーンが好きでほぼ毎日やってます。そんなとき、「AIで敵を検出できたら面白そうだな」と思いました。てなわけで2023年1月に出たYOLOv8を使って、スプラトゥーン3で敵を検出するモデルを作ってみました。
使用したもの
- labelImg
- google colaboratory
- YOLOv8
- 訓練データ888枚 + ラベルデータ
- 検証データ208枚 + ラベルデータ
- テストデータ100枚
フォルダー構成は以下のようにしました。↓
objectdetection/
└detection.yaml
├train/
└image.png
└label.txt
├val/
└image.png
└label.txt
├test/
└image.png
アノテーション
まず初めに訓練データをたくさん集めてlabelImgを使ってアノテーションしていきます。labelImgとは画像に写っている学習させたい検出物を四角で囲ってラベルを作成するツールのことです。四角で囲った部分の座標がテキストドキュメントとして保存されるようになっています。今回は検証データも用意してあるのでそちらにもアノテーションしていきます。
labelImgのインストール方法は以下のサイトを参考にしました。↓
https://www.tankobucreate.com/labelimg_install/
こんな感じでアノテーションしていきます。↓
labelImg
labelImgの使い方は簡単です。「Open Dir」を押して使用する画像が保存されているディレクトリを選択します。
「Change Save Dir」を押してアノテーションした場所の座標を保存するディレクトリを選択します。
※YOLOの事前学習モデルを使用する場合は下の画像の赤線で囲ってある所を「YOLO」にしてください。
ヘッダーの「view」を開くと一番上に「Auto Save Mode」があるのでそれを押すと、自動で保存されるようになります。
以下のように選択されたディレクトリにアノテーションの座標がテキストファイルで保存されます。
保存されたテキストファイルはこんな感じです。
画像へのアノテーションが終わったら、その画像と座標が入力されているテキストファイルをgoogle driveにアップロードします。
それとyamlファイルが必要なのでそれも作成していきます。私は以下のようにVScodeで作成しました。↓
- 「nc」には検出したい物の数を入力します。今回は敵だけを検出したいので「1」と入力しています。
- 「names」には検出したい物の名前を入力します。
学習
それでは学習に入ります。まず、yolov8をインストールしていきます。
!git clone https://github.com/ultralytics/ultralytics
%cd ultralytics
!pip install -r requirements.txt
次に以下のようにすれば学習が開始されます。学習が終わると一番精度が良いパラメーター(best.pt)が/content/ultralytics/runs/detect/train/weightsに自動で保存されるようになっています。
from ultralytics import YOLO
model = YOLO("yolov8m.pt")
results = model.train(data="ここには先ほど作成したyamlファイルパスを入力", epochs=200, batch=20, imgsz=1280, rect=True)
results = model.val()
以下のように学習が進んでいきます。
※google colaboratoryを使うときは有料版にしないとメモリ不足で学習できなくなります。
Epoch GPU_mem box_loss cls_loss dfl_loss Instances Size
1/200 18.5G 2.204 3.967 1.859 8 1280: 100%|██████████| 45/45 [01:30<00:00, 2.02s/it]
Class Images Instances Box(P R mAP50 mAP50-95): 100%|██████████| 6/6 [00:53<00:00, 8.95s/it]
all 208 281 0.348 0.27 0.223 0.0894
Epoch GPU_mem box_loss cls_loss dfl_loss Instances Size
2/200 21.7G 2.123 2.638 1.953 9 1280: 100%|██████████| 45/45 [00:15<00:00, 2.87it/s]
Class Images Instances Box(P R mAP50 mAP50-95): 100%|██████████| 6/6 [00:04<00:00, 1.25it/s]
all 208 281 0.00926 0.342 0.00544 0.0022
Epoch GPU_mem box_loss cls_loss dfl_loss Instances Size
3/200 21.7G 2.261 2.646 2.101 9 1280: 100%|██████████| 45/45 [00:15<00:00, 2.94it/s]
Class Images Instances Box(P R mAP50 mAP50-95): 100%|██████████| 6/6 [00:04<00:00, 1.20it/s]
all 208 281 0.299 0.221 0.151 0.0557
結果は...?
学習が終わったらテストデータでモデルを試します。
/content/ultralytics/runs/detect/train/weightsに保存されたbest.ptを使います。その結果が以下の画像になります。
model = YOLO("/content/ultralytics/runs/detect/train/weights/best.pt")
results = model.predict(source="用意したテストデータ", data="yamlファイルパス", save=True, imgsz=1280, rect=True)
まとめ
今回、yolov8を使ってスプラトゥーン3で敵検出をやってみましたが、ちゃんと近くにいる敵や遠くにいる敵を検出出来ていたり、誤検出やそもそも検出しなかったりなどバラつきがありました。しかもまだ動画では試したことがないので動画でも試してみようと思います。また、何かアドバイスがあれば教えていただけると嬉しいです。