背景
バスケットボールにおける人物検出の研究をしていてコート検出を行う必要がありました。そこで文献を参考にしながらAutoEncoderを実装したいと思います。
Classification of Basketball Actions through Deep Learning Techniques
論文リンク:
Classification of Basketball Actions through Deep Learning Techniques
今回、論文中から畳み込みオートエンコーダーによるコート検出を参考にしました。
構造
論文を参考に4層のEncoder/8層のDecoderで構成します。
モデルを見ると浅いEncoderと深いDecoderの非対称な構造になっています。これはBinary segmentationタスクは高度な処理は必要ないことから、Encoder部分は浅く、代わりに複合化のためのDecoder部分は精確な出力を行うため、Encoder部よりも多層になっているそうです。
論文を参考に畳み込みオートエンコーダーを実装してみます。
import torch.nn as nn
import torch.nn.functional as F
class ConvAutoEncoder(nn.Module):
def __init__(self):
super(ConvAutoEncoder,self).__init__()
#inputsize(720*480*3)
#Encoder
self.conv1=nn.Conv2d(3,64,kernel_size=(3,3),stride=(2,2))
self.conv2=nn.Conv2d(64,128,kernel_size=(3,3),stride=(2,2))
self.conv3=nn.Conv2d(128,256,kernel_size=(3,3),stride=(2,2))
self.conv4=nn.Conv2d(256,512,kernel_size=(3,3),stride=(2,2))
#Decoder
self.t_conv1=nn.ConvTranspose2d(512,256,kernel_size=(3,3),stride=(2,2))
self.t_conv2=nn.ConvTranspose2d(256,256,kernel_size=(3,3),stride=(1,1),padding=1)
self.t_conv3=nn.ConvTranspose2d(256,256,kernel_size=(3,3),stride=(1,1),padding=1)
self.t_conv4=nn.ConvTranspose2d(256,128,kernel_size=(3,3),stride=(2,2))
self.t_conv5=nn.ConvTranspose2d(128,128,kernel_size=(3,3),stride=(1,1),padding=1)
self.t_conv6=nn.ConvTranspose2d(128,128,kernel_size=(3,3),stride=(1,1),padding=1)
self.t_conv7=nn.ConvTranspose2d(128,64,kernel_size=(3,3),stride=(2,2))
self.t_conv8=nn.ConvTranspose2d(64,1,kernel_size=(3,3),stride=(2,2),output_padding=1)
def forward(self,x):
x=F.relu(self.conv1(x))
x=F.relu(self.conv2(x))
x=F.relu(self.conv3(x))
x=F.relu(self.conv4(x))
x=F.relu(self.t_conv1(x))
x=F.relu(self.t_conv2(x))
x=F.relu(self.t_conv3(x))
x=F.relu(self.t_conv4(x))
x=F.relu(self.t_conv5(x))
x=F.relu(self.t_conv6(x))
x=F.relu(self.t_conv7(x))
x=self.t_conv8(x)
return x
カラー画像を白黒画像にするためEncoder部で入力/出力でチャネル数3→1になるように構成します。(参考文献のままのフィルタサイズ/ストライドだと計算があわなかったため、自分で調整しました。)
データセットの準備
今回のモデルは入力を試合画像とし、教師画像をコートのみをマスクした2値化画像として学習を行います。本論文ではこのデータセットを配布していなかったため、自分でデータセットを用意する必要があります。
そこでコートの座標を指定し、csvとして保存するアノテーションツールを作成しました
。
「onMouse」 はマウスの動作によって座標を指定、点と線を描画する関数です。
「save_point_list」は指定した座標を保存する関数です。
## マウス処理
def onMouse(event,x,y,flag,params):
"""
左クリック : ポイントを追加 |右クリック : ポイント削除 |Enter : 次の画像へ
"""
raw_img=params["img"]
wname=params["wname"]
point_list=params["point_list"]
frame=params["frame"]
total_frame=params["total_frame"]
##クリックイベント
##左クリックでポイント追加
if event == cv2.EVENT_LBUTTONDOWN:
point_list.append([x,y])
##右クリックでポイント削除
if event==cv2.EVENT_RBUTTONDOWN:
point_list.pop(-1)
#レーダーの作成、描画
img=raw_img.copy()
h,w=img.shape[0],img.shape[1]
cv2.line(img,(x,0),(x,h),(255,0,0),1)
cv2.line(img,(0,y),(w,y),(255,0,0),1)
##点、線の描画
for i in range(len(point_list)):
#各ポイントリストの座標を取得
cv2.circle(img,(point_list[i][0],point_list[i][1]),3,(0,0,255),3)
if 0<i:
cv2.line(img, (point_list[i][0], point_list[i][1]),
(point_list[i-1][0], point_list[i-1][1]), (0, 255, 0), 2)
if 0<len(point_list):
cv2.line(img, (x, y),
(point_list[len(point_list)-1][0], point_list[len(point_list)-1][1]), (0, 255, 0), 2)
#座標情報をテキストで出力
cv2.putText(img,"({0},{1})".format(x,y),(0,20),cv2.FONT_HERSHEY_PLAIN,1,(255, 255, 255), 1, cv2.LINE_AA)
cv2.putText(img,"{}/{} FRAME".format(frame,total_frame),(0,50),cv2.FONT_HERSHEY_SIMPLEX,1,(0, 0, 200), 2, cv2.LINE_AA)
cv2.imshow(wname, img)
##取得した座標情報を保存
def save_point_list(path,point_list):
f=open(path,"w")
for p in point_list:
f.write(str(p[0])+","+str(p[1])+"\n")
f.close()
def main(target_path,save_path):
PATH=target_path+'*'
PATH=sorted(glob.glob(PATH))
SAVE_PATH=save_path
for i,path in enumerate(PATH):
#画像の読み込み
print(path)
img=cv2.imread(path)
wname="MouseEvent"
point_list=[]
frame=i+1
total_frame=len(PATH)
params={"img":img,
"wname":wname,
"point_list": point_list,
"frame":frame,
"total_frame":total_frame
}
cv2.namedWindow(wname,cv2.WINDOW_NORMAL)
cv2.setMouseCallback(wname,onMouse,params)
cv2.imshow(wname,img)
cv2.waitKey(0)
cv2.destroyAllWindows()
## 取得したポイントをcsvに保存
csv_path=os.path.join(SAVE_PATH,'frame-{}.csv'.format(i))
save_point_list(csv_path, point_list)
print("Save csv file:", csv_path)
'BinaryAnnotool.py'
import sys
import os
import argparse
import cv2
import glob
parser=argparse.ArgumentParser()
parser.add_argument('-tar')
parser.add_argument('-save')
if __name__=="__main__":
args=parser.parse_args()
target_path=args.tar
save_path=args.save
main(target_path,save_path)
----------------------------------------------------------
$ python Binary-Annotool.py -tar '入力画像のあるディレクトリのpath' -save 'csvを保存するディレクトリのpath'
BinaryAnnotool.py の実行結果は下図のようになります。
アノテーションを行い、各試合画像のコートの座標を記録し、cv2.fillPolyでコートのマスク画像を生成します。
また、試合映像/マスク画像に対し、-30度~30度の回転をランダムに行いデータ拡張を行いました。
結果
下図のようにコートのみを上手くマスクできていることがわかります。
参考
[元論文]
1.Simone Francia( April 2018) :Classification of Basketball Actions through Deep Learning Techniques
[座標保存用のプログラム]
2.Kaggle note OpenCVのonMouseを使った座標取得プログラム : https://kagglenote.com/misc/get-point-with-opencv/