挨拶
こんにちは!
沖縄に住む機械学習エンジニア、ジェイクです。
今日の投稿はこのプロジェクトについてです。
https://github.com/birosjh/pytorch_ssd
概要
最近仕事では、物体検知案件とよく関わったりしていて、Single Shot Detector (SSD)とかYOLOなどのアルゴリズムをどんどんと使っています。
基本的にお客様のモデルを扱っていて、いつも具体的な目的があるので必要なところを修正して向上して返しています。これはこれでとても勉強になるんですが、最近はやっぱり好きなようにコードをいじったりいろいろ実験できるモデルが欲しくなってきましたので、SSDを作り始めました。そこからは「どうせ作るなら」と思って、今まであまり触られていないPyTorchで作ってみようと思いました。まだまだ途中のものですが、今日は一部を紹介したいと思っています。
その一部はPyTorchのDatasetです。
PyTorch Dataset
機械学習をやったことあるならいい経験だったかどうかわからないんですが、データ読み込みと関わったことあるはずです。
様々な読み込み方法はあるんですが、一般的なcsvデータだったらPandasなどを使用できます。画像処理の場合は、深層学習ライブラリは基本的に分類用のデータローダーを提供しています。ただ、場合によって提供されているデータローダーでやりたいことができません。その時は、カスタムデータローダーを作成するのが定番です。
PyTorchでは、Datasetというクラスを継承したら簡単にカスタムデータローダーを作成できます。それをこれからご紹介したいと思っています。
私の場合はSSDを構築しているので、SSD用のコード例で説明します。ちなみにPyTorchドキュメンテーションをみたい人はここを見てください
びっくりするほど簡単ですが、Datasetを継承したクラスはこの二つの関数しか必要ないです:__len__()
と __getitem__
__len__()
この関数は、Dataset内の全ての項目の数を表します。
__getitem__()
この関数は、一項目を取得する時に実行する処理です。画像などは、メモリが大きいので最初から全てを読み込むんではなくて、モデルが必要とする時にしか読み込みません。モデルがバッチを処理して次のバッチを準備する時にこの関数が呼ばれます。この段階で、画像を読み込み、アノテーションをロードすることが普通です。拡張機能(Augmentation) もこの段階で導入できます。
class MyDataset(Dataset):
def __len__(self):
# 全ての項目のリスト(例えば画像名のリスト)
return len(self.my_list_of_items)
def __getitem__(self, idx):
# 呼ばれる時にidx引数でtensorを渡される場合があるので、これはそのためです
if torch.is_tensor(idx):
idx = idx.tolist()
################################
# ここで画像読み込みなどを行う
################################
本当にこれだけです。その後は、Dataloaderに渡せば自動的にバッチしてくれて、そのままモデルに入れられます。
dataset = MyDataset()
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)
ただ、この説明だと少しもの足りないなので、SSDの場合はどうするかを見せたいと思っています。SSDに入力するデータは、画像とそれぞれ画像で現れる物体の領域です。領域はバウンディングボックスという4点のx,yポイントで指定されている箱のことです。つまり、__getitem__()
が呼ばれると一つの画像とその画像のバウンディングボックスをペアとして返すのが理想的です。
まずは、画像の初期化周りを見ておきましょう。今回使っているデータでjsonファイルを使って画像名とバウンディングボックス情報を読み込んでいます。 ...
ってところは単純に説明に必要ないコードを削除しています。隠しているわけじゃないよ〜(見たいなら上記のgithubリンクへどうぞ)
class MyDataset(Dataset):
def __init__(self, annotations_file, transform):
self.transform = transform
...
# アノテーションが入っているJSONファイルを読み込みます
annotation_json = json.loads(annotations_file)
# 画像名をキー、バウンディングボックスの配列がバリューのディクショナリーが返されます
self.image_annotations = self.prep_images_and_annotations(
annotation_json
)
# ここはidxで画像名を取得できるように上記のディクショナリーのキーを配列に変換しています
# __getitem__()が呼ばれる時に引数で渡されているidxをこれに入れて画像名を取得できます
self.images_list = list(self.image_annotations.keys())
...
簡単なのでさっと __len__
対応しましょう。単純に画像名のリストの長さを返しているだけです。
def __len__(self):
# 全ての画像名のリストの長さを返す
return len(self.images_list)
そして最後に __getitem__()
が残っています。ここで画像を一つ取得し、バウンディングボックスとペアとして返します。
def __getitem__(self, idx):
# ここで画像を読み込みます
# 渡されているidxをimages_listに入れて、画像名を取得して、それの画像を読み込みます
image_file = self.images_list[idx]
path_to_file = os.path.join(self.image_directory, image_file)
image = cv2.imread(path_to_file)
# 取得された画像のボックスをディクショナリーから取得します
labels = self.image_annotations[image_file]
# ここで拡張機能をやろうと思えばできます
# バウンディングボックスが得意なimgaugライブラリを使っています
if self.transform:
labels = BoundingBoxesOnImage(labels, shape=(self.height, self.width))
image, labels = self.resize_transformation(
image=image,
bounding_boxes=labels
)
labels = labels.bounding_boxes
labels = np.array([np.append(box.coords.flatten(), box.label) for box in labels])
...
return (image, labels)
最初に自らカスタムデータロードーを実装しないといけなかった時は、とても難しいことだと勘違いしていて少し抵抗があったんですが、本当は割とわかりやすい仕組みです。しかも、PyTorchのDatasetを利用すればこんなに簡単に実装できます。
最後に
この内容が役に立つかどうかわからないんですが、ここまで読んでくれてありがとうございました!
PyTorch Datasetでは、今回紹介できてない他の部分もいくつかあります。今後プロジェクトに使う機会があれば改めて投稿します!