はじめに
前回の記事ではYOLOv3に論文でも使用されたCOCO datasetを使ってdatasetを作成しました。今回からモデル構造の作成に取り掛かります。YOLOv3では画像の特徴を抽出するために複数の畳み込み層と残差ブロックを使用します。本記事では畳み込み層を作成する関数と残差ブロックのclassの作成を解説していきます。
フォルダ構成
作業フォルダ
┠ COCO (COCO datasetのダウンロードの項で作成します)
┠ dataset
┠ ┗ cocodataset.py (前回作成)
┗ models
┗ yolov3.py (今回作成するファイル)
使用ライブラリ
本記事内で使用するライブラリは以下となります。本記事では使用しませんが、動作確認の使用する前回作成したCOCOdatasetもimportしておきます。その際に必要となるモジュール検索パスの追加もここで記述してあります。
import os
import sys
from pathlib import Path
import torch.nn as nn
FILE = Path(__file__).resolve() # fileまでのパス
ROOT = FILE.parents[1] # 作業ディレクトリまでのパス
if str(ROOT) not in sys.path: sys.path.append(str(ROOT))
from dataset.cocodataset import COCODataset
畳み込み層
初めに畳み込み層を作成する関数を実装します。ここで言う畳み込み層は、畳み込みした画像をBatchNormalizeし、活性化関数を通す一連の流れを一括りにして畳み込み層と呼んでいます。この関数はYOLOv3全体のモデルを作成するときやこの後説明する残差ブロックのところで使用します。[入力channel数, 出力channel数, フィルターサイズ, ストライド]の4つを入力とし,出力として畳み込みのモジュール(レイヤー)を返します。
def add_conv(in_ch, out_ch, ksize, stride):
layer = nn.Sequential()
pad = (ksize - 1) // 2
layer.add_module('conv', nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=ksize, stride=stride, padding=pad, bias=False))
layer.add_module('batch_norm', nn.BatchNorm2d(out_ch))
layer.add_module('leaky', nn.LeakyReLU(0.1))
return layer
残差ブロック
続いて残差ブロックのclassを作成します。このclassをインスタンス化する際には[入力channel数, nblocks, shortucut]の3つを入力として与えます。図で表すと以下となります。
入力が二つの畳み込み層を通ります。出てきた出力hのchannel数は入力と同じになる用設計されています。shortcutがTrueの場合は、畳み込みする前の入力xと出力hを足し合わせたものが出力として得られます。さらに得られた出力を再度入力として層に入力します。これをnblock回繰り返すのがこの残差ブロックとなります。コードは以下となります。
class resblock(nn.Module):
def __init__(self, in_ch, nblocks=1, shortcut=True):
super().__init__()
self.shortcut = shortcut
self.module_list = nn.ModuleList()
for _ in range(nblocks):
resblock_one = nn.ModuleList()
resblock_one.append(add_conv(in_ch, in_ch//2, 1, 1))
resblock_one.append(add_conv(in_ch//2, in_ch, 3, 1))
self.module_list.append(resblock_one)
def forward(self, x):
for module in self.module_list:
h = x
for res in module:
h = res(h)
x = x + h if self.shortcut else h
return x
まとめ
ここまで読んでいただきありがとうございます。今回は畳み込み層を作成する関数の定義と残差ブロック層のclassを作成しました。次回はYOLOv3に含まれるlayerの肝となるYOLO Layerの実装を解説したいと思います。