7
8

More than 1 year has passed since last update.

【albumentations】データ拡張による精度向上を検証する

Last updated at Posted at 2022-05-15

目次

  1. はじめに
  2. データ拡張
  3. YOLOv5による学習
  4. 結果と考察
  5. 所感
  6. 参考

1. はじめに

物体検知の精度を向上させる方法として、データ拡張(Data augmentation)が存在します。
今回はデータ拡張ライブラリ「albumentations」の習熟もかねて、データ拡張による精度向上の検証を行いました。
使用するデータセットは「Global Wheat Detection」を、物体検出アルゴリズムはYOLOv5を使用します。

1.1 Global Wheat Detection とは

「Global Wheat Detection」は画像から麦の穂の領域を検出し、その精度を競うコンペティションです(開催済み)。
Global Wheat Detection

1.2 YOLOv5とは

代表的な物体検出アルゴリズムであるYOLO(You only Look Once)のver.5です。
YOLOは物体の検出とクラス分類を同時に行うことで、高速化を実現しています。

YOLOの詳細は以下のサイトで非常に分かり易くまとめられています。
物体検出の代表アルゴリズム YOLOシリーズを徹底解説!【AI論文解説】

1.3 動作環境

  • Google Colaboratory Pro

2. データ拡張

データセットのダウンロードとYOLOv5のインストールについては省略します。

2.1 ライブラリのimport

!pip install -U albumentations
!pip install "opencv-python-headless<4.3"
import os
import random
import cv2
import albumentations as A
from matplotlib import pyplot as plt
import copy

2.2 クラスの作成

データ拡張、データ読み込み、データ保存および可視化のメソッドを組み込んだクラスを作成しました。

class Data:
  horizontal_transform = A.Compose([
      A.HorizontalFlip(p=1.0),
  ], bbox_params=A.BboxParams(format='yolo', label_fields=['class_labels']))

  randomsizedcrop_transform = A.Compose([
      A.RandomSizedCrop(min_max_height=[512, 512], height = 1024,  width=1024, p=1.0),
  ], bbox_params=A.BboxParams(format='yolo', label_fields=['class_labels']))

  rotate90_transform = A.Compose([
      A.RandomRotate90(p=1.0),
  ], bbox_params=A.BboxParams(format='yolo', label_fields=['class_labels']))

  rotate180_transform = A.Compose([
      A.RandomRotate90(p=1.0),
      A.RandomRotate90(p=1.0),
  ], bbox_params=A.BboxParams(format='yolo', label_fields=['class_labels']))

  rotate270_transform = A.Compose([
      A.RandomRotate90(p=1.0),
      A.RandomRotate90(p=1.0),
      A.RandomRotate90(p=1.0),
  ], bbox_params=A.BboxParams(format='yolo', label_fields=['class_labels']))

  def __init__(self, image="", bboxes=0, id="", class_labels=[]):
    self.image = image
    self.bboxes = bboxes
    self.label = 0
    self.id = id
    self.class_labels = class_labels

# jpgとtxtからデータをインポートするメソッド
  def importdata(self, imgpath):
    dirpath = os.path.dirname(imgpath)[:-7]
    id = os.path.splitext(os.path.basename(imgpath))[0]
    txtpath = dirpath + f"/labels/{id}.txt"
    
    img = cv2.imread(imgpath)
    self.image = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    bboxes = []
    with open(txtpath) as f:
      for line in f:
        line_list = line.split(" ")
        bbox = line_list[1:]
        bbox = [float(i.replace('\n', '')) for i in bbox]
        bboxes.append(bbox)
    
    self.bboxes = bboxes
    self.label = 0
    self.id = id
    self.class_labels = ["wheat" for i in range(len(bboxes))]

# albumentationsで変換したデータをインポートするメソッド
  def import_transformdata(self, transform_data, origin_data, process):
    self.image = transform_data["image"]
    self.bboxes = transform_data["bboxes"]
    self.label = 0
    self.id = origin_data.id + "_" + process
    self.class_labels = transform_data["class_labels"]
  
# モザイク画像のデータをインポートするメソッド
  def import_mosaicdata(self, img, bboxes, id, class_labels):
    self.image = img
    self.bboxes = bboxes
    self.label = 0
    self.id = id
    self.class_labels = class_labels

# 左右反転処理したデータを返すメソッド
  def horizonflip(self):
    horizon_transformed = Data.horizontal_transform(image=self.image, 
                                                    bboxes=self.bboxes, 
                                                    class_labels=self.class_labels)
    image = horizon_transformed["image"]
    bboxes = horizon_transformed["bboxes"]
    label = 0
    id = self.id + "_hori"
    class_labels = horizon_transformed["class_labels"]
    horizondata = Data(image, bboxes, id, class_labels)
    return horizondata

# ランダムに切り出してリサイズしたデータを返すメソッド
  def randomsizedcrop(self):
    randomsizedcrop_transformed = Data.randomsizedcrop_transform(image=self.image, 
                                                                 bboxes=self.bboxes, 
                                                                 class_labels=self.class_labels)
    image = randomsizedcrop_transformed["image"]
    bboxes = randomsizedcrop_transformed["bboxes"]
    label = 0
    id = self.id + "_hori"
    class_labels = randomsizedcrop_transformed["class_labels"]
    randomsizedcropdata = Data(image, bboxes, id, class_labels)
    return randomsizedcropdata

# 反時計回りに90゚回転したデータを返すメソッド
  def rotate90(self):
    rotate90_transformed = Data.rotate90_transform(image=self.image, 
                                                   bboxes=self.bboxes, 
                                                   class_labels=self.class_labels)
    image =rotate90_transformed["image"]
    bboxes = rotate90_transformed["bboxes"]
    label = 0
    id = self.id + "_rot90"
    class_labels = rotate90_transformed["class_labels"]
    rot90data = Data(image, bboxes, id, class_labels)
    return rot90data

# 反時計回りに180゚回転したデータを返すメソッド
  def rotate180(self):
    rotate180_transformed = Data.rotate180_transform(image=self.image, 
                                                    bboxes=self.bboxes, 
                                                    class_labels=self.class_labels)
    image =rotate180_transformed["image"]
    bboxes = rotate180_transformed["bboxes"]
    label = 0
    id = self.id + "_rot180"
    class_labels = rotate180_transformed["class_labels"]
    rot180data = Data(image, bboxes, id, class_labels)
    return rot180data

# 反時計回りに270゚回転したデータを返すメソッド
  def rotate270(self):
    rotate270_transformed = Data.rotate270_transform(image=self.image, 
                                                     bboxes=self.bboxes, 
                                                     class_labels=self.class_labels)
    image =rotate270_transformed["image"]
    bboxes = rotate270_transformed["bboxes"]
    label = 0
    id = self.id + "_rot270"
    class_labels = rotate270_transformed["class_labels"]
    rot270data = Data(image, bboxes, id, class_labels)
    return rot270data

# 指定のパスにjpgとtxtファイルでデータ保存するメソッド
  def export_data(self, imgdirpath):
    id = self.id
    dirpath = imgdirpath[:-7]
    export_imgpath = imgdirpath + f"/{id}.jpg"
    export_txtpath = dirpath + f"/labels/{id}.txt"

    img = cv2.cvtColor(self.image, cv2.COLOR_RGB2BGR)
    cv2.imwrite(export_imgpath, img)
    
    txt = ""
    for bbox in self.bboxes:
      x_min, y_min, width, height = [i for i in bbox]
      line = f"0 {x_min} {y_min} {width} {height}"
      txt += line + "\n"

    f = open(export_txtpath, 'w')
    f.write(txt) 
    f.close()

# 画像とバウンディングボックスを表示するメソッド
  def visualize(self, img_width, img_height, figsize = (10,10)):

    for bbox in self.bboxes:
      x_mid_nor, y_mid_nor, width_nor, height_nor  = [float(i) for i in bbox]

      width = width_nor * img_width  
      height = height_nor * img_height   

      x_min = x_mid_nor * img_width - width/2   
      y_min = y_mid_nor * img_height - height/2    
      x_max = x_min + width
      y_max = y_min + height

      x_min = int(x_min)
      x_max = int(x_max)
      y_min = int(y_min)
      y_max = int(y_max)

      img = cv2.rectangle(self.image,
                          pt1=(x_min, y_min),
                          pt2=(x_max, y_max),
                          color=(255, 0, 0),
                          thickness=3)
      
    plt.figure(figsize = figsize)
    plt.axis('off')
    plt.imshow(img)

2.3 データセットの読み込み

あらかじめコンペのデータセットからランダムに100のデータを抽出しました。(コード省略)
/content/drive/MyDrive/albumentations/wheat100/imagesにjpgファイルを、
/content/drive/MyDrive/albumentations/wheat100/labelsに対応するtxtファイルを格納しました。(コード省略)

dataset100 = []
imgdirpath = "/content/drive/MyDrive/albumentations/wheat100/images"
for filename in os.listdir(imgdirpath):
  fullpath = imgdirpath + "/" + filename
  data = Data()
  data.importdata(fullpath)
  dataset100.append(data)

# データの可視化
dataset100[0].visualize(1024, 1024)

2.3 データセットの拡張(左右反転)

hori_dataset = [data.horizonflip() for data in dataset100]
dataset = copy.deepcopy(dataset100)
dataset.extend(hori_dataset)

# データの可視化
hori_dataset[0].visualize(1024, 1024)

2.4 データセットの拡張(回転)

rotate90_dataset = [data.rotate90() for data in dataset]    #反時計周りに90゚回転
rotate180_dataset = [data.rotate180() for data in dataset]  #反時計周りに180゚回転
rotate270_dataset = [data.rotate270() for data in dataset]  #反時計周りに270゚回転
dataset.extend(rotate90_dataset)
dataset.extend(rotate180_dataset)
dataset.extend(rotate270_dataset)

# データの可視化
rotate90_dataset[0].visualize(1024, 1024)

2.5 データセットの拡張(切り抜き)

randomsizedcrop_dataset = [data.randomsizedcrop() for data in dataset]
dataset.extend(randomsizedcrop_dataset)

# データの可視化
randomsizedcrop_dataset[0].visualize(1024, 1024)

2.6 データセットの拡張(モザイク画像の作成)

モザイク画像は4種の画像を組み合わせます。
ランダムに切り抜くパターンと、切り抜かず画像を1/2に縮小する2パターンのモザイク画像を作成しました。

# 関数の定義
def generate_mosaicdata(mosaic_group, mode = "noncrop"):
  crop_transform = A.Compose([
  A.RandomCrop(height=512, width=512, p=1.0),
  ], bbox_params=A.BboxParams(format='yolo', min_visibility=0.3, label_fields=['class_labels']))
    
  img_list = [data.image for data in mosaic_group]
  bboxes_list = [data.bboxes for data in mosaic_group]
  classlabels_list = [data.class_labels for data in mosaic_group]
  id_list = [data.id for data in mosaic_group]
  
# ランダムに切り抜く場合の処理
  if not mode == "noncrop":
    pre_transformed_list = [crop_transform(image=img, bboxes=bboxes, class_labels=class_labels) for img, bboxes, class_labels in zip(img_list, bboxes_list, classlabels_list)] 
    transformed_list = []

    for crop_data, origin_data in zip(pre_transformed_list, mosaic_group):
      data = Data()
      data.import_transformdata(crop_data, origin_data, "crop")
      transformed_list.append(data)

    img1, img2, img3, img4 = [data.image for data in transformed_list]  
    bboxes_list = [data.bboxes for data in transformed_list]
    classlabels_list = [data.class_labels for data in transformed_list]
    mod_id = "_".join([id[:7] for id in id_list])
    
# 切り抜かない場合の処理
  else:
    img1, img2, img3, img4 = img_list
    mod_id = "resized_" + "_".join([id[:7] for id in id_list])

  img1_2 = cv2.hconcat([img1, img2])
  img3_4 = cv2.hconcat([img3, img4])
  mod_img = cv2.resize(cv2.vconcat([img1_2, img3_4]), dsize = (1024,1024))

  bboxes1, bboxes2, bboxes3, bboxes4 = bboxes_list
  mod_bboxes1 = []
  for bbox in bboxes1:
    mod_bbox = [i/2 for i in bbox]
    mod_bboxes1.append(mod_bbox)

  mod_bboxes2 = []
  for bbox in bboxes2:
    x, y, width, height = [i/2 for i in bbox]
    mod_x = x + 0.5
    mod_bbox = [mod_x, y, width, height]    
    mod_bboxes2.append(mod_bbox)
  
  mod_bboxes3 = []
  for bbox in bboxes3:
    x, y, width, height = [i/2 for i in bbox]
    mod_y = y + 0.5
    mod_bbox = [x, mod_y, width, height]
    mod_bboxes3.append(mod_bbox)

  mod_bboxes4 = []
  for bbox in bboxes4:
    x, y, width, height = [i/2 for i in bbox]
    mod_x = x + 0.5
    mod_y = y + 0.5
    mod_bbox = [mod_x, mod_y, width, height]
    mod_bboxes4.append(mod_bbox)
  
  mod_bboxes = mod_bboxes1 + mod_bboxes2 + mod_bboxes3 + mod_bboxes4
  mod_classlabels = [cl for classlabels in classlabels_list for cl in classlabels]

  data = Data()
  data.import_mosaicdata(mod_img, mod_bboxes, mod_id, mod_classlabels)

  return data

def generate_mosaicdatalist(mosaic_ori_dataset, num_mosaicimg, mode = "noncrop"):
  mosaic_dataset = copy.deepcopy(mosaic_ori_dataset)
  mosaic_groups = []

  for i in range(num_mosaicimg):
    mosaic_group = random.sample(mosaic_dataset, 4)
    mosaic_groups.append(mosaic_group)
    mosaic_dataset = list(set(mosaic_group) ^ set(mosaic_dataset)) 
    if len(mosaic_dataset) < 4:
      mosaic_dataset = copy.deepcopy(mosaic_ori_dataset)

  mosaic_data = []
  for mosaic_group in mosaic_groups:
    mosaic_data.append(generate_mosaicdata(mosaic_group, mode))
  
  return mosaic_data
# datasetを使用して400枚のモザイク画像を作成する。
noncropmosaic_dataset = generate_4mosaicimg(dataset, 400, mode = "noncrop")
dataset.extend(noncropmosaic_dataset)

# データの可視化(切り抜かない場合)
noncropmosaic_dataset[0].visualize(1024, 1024)
cropmosaic_dataset = generate_4mosaicimg(dataset, 400, mode = "crop")
dataset.extend(cropmosaic_dataset)

# データの可視化(切り抜く場合)
# datasetを使用して400枚のモザイク画像を作成する。
cropmosaic_dataset[0].visualize(1024, 1024)

2.7 データセットの保存

/content/drive/MyDrive/albumentations/wheattrain/imagesにjpgファイルを、
/content/drive/MyDrive/albumentations/wheattrain/labelsに対応するtxtファイルを保存しました。

for data in dataset:
  data.export_data("/content/drive/MyDrive/albumentations/wheattrain/images")

3. YOLOv5による学習

3.1 trainデータセットについて

trainデータセットとして下記のデータ拡張を実施した4データセットを用意しました。

オリジナルデータ:ランダム抽出した100データ
左右反転:オリジナルデータを左右反転したデータ
回転(90,180,270):オリジナルデータと左右反転データを反時計回りに90,180,270゚したデータ
切り抜き:オリジナルデータ、左右反転、回転データをランダムに切り抜いたデータ
モザイク(切抜なし):オリジナルデータ、左右反転、回転データからランダムに4枚選択し、各画像を1/2サイズに縮小して結合したデータ
モザイク(切抜あり):オリジナルデータ、左右反転、回転データからランダムに4枚選択し、各画像からランダムに切り抜いて結合したデータ

データ拡張 dataset1 dataset2 dataset3 dataset4
オリジナルデータ 100 100 100 100
左右反転 0 100 100 100
回転(90,180,270゚) 0 600 600 600
切り抜き 0 0 400 0
モザイク(切抜なし) 0 0 400 0
モザイク(切抜あり) 0 0 0 800
合計(枚) 100 800 1600 1600

3.2 validationデータセットについて

validationデータセットとして、オリジナルデータ以外からランダムに500データを抽出して使用しました。

3.3 学習について

各データセットの学習はearlysoppingを使用して行いました。

4. 結果と考察

4.1 学習結果

validationデータに対する各評価指標の最大値は以下の通りです。

評価指標 dataset1 dataset2 dataset3 dataset4
max_mAP 0.87081 0.90687 0.8818 0.90178
max_precision 0.90592 0.90825 0.88507 0.90869
max_recall 0.8076 0.85322 0.83271 0.85087

max_mAPはdataset2 > dataset4 > dataset3 > dataset1、
max_precisionはdataset4 ≒ dataset2 > dataset1 > dataset3、
max_recallはdataset2 > dataset4 > dataset3 > dataset1となりました。

4.2 考察

  • max_mAP/max_recall
    max_mAPとmax_recallはモザイク画像を含むdataset3・dataset4が高くなると予想していましたが、反転・回転しただけのdataset1が最高値となりました。
    モザイクの各画像については二重に学習することになるので意味がない。
    加えてデータが偏るので、若干過学習ぎみになるといったところでしょうか。
    個人的にはdataset3が最高値になると予想していましたが、下から2番目でした。
    画像の拡大縮小により汎化性能は向上しているが、今回のvalidationデータセットに対する特異的な精度は低下しているのだと思います。

  • max_precision
    max_precisionはオリジナルデータの時点で収束しています。
    datase3が最低値となっており、小麦でない領域を検出していることがmax_mAPの低下の要因だと考えられます。

  • モザイク画像の生成について
    Kaggleではモザイク画像の生成が重要である場合がありますが、本データセットでは精度の向上には寄与しませんでした。
    本データセットは単クラス検出なので、他クラス検出では異なる結果になる可能性があります。
    本筋とは関係ありませんが、bboxの密度が高い領域を抽出してモザイク画像にすることで、データセットの圧縮と学習時間の減少に繋がると感じました。

5. 所感

画像を反転して回転させるだけで精度が確実に向上するので、
学習時間やハードが許容できるのであれば実施しない手はないと思いました。
今回は単クラス検出での検証だったので、他クラス検出であればまた結果が異なってくると思います。
機会があれば検証したいです。

6. 参考

物体検出の代表アルゴリズム YOLOシリーズを徹底解説!【AI論文解説】
YOLOv5 Tutorial - Google Colaboratory (Colab)
Bounding boxes augmentation for object detection

7
8
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
7
8