1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

【セグメンテーション】ラベルデータをインデックスカラーのpngファイルとして保存・読込みする

Posted at

はじめに

ニューラルネットを用いた機械学習の中でも、セグメンテーションは1ピクセルごとのラベルが必要になることから、ラベルのインプットを工夫しないと学習のボトルネックになることがあります。
整形してあるデータセットであれば良いですが、KaggleなどはCSVファイルでラベルが配布される場合がある為、学習時に都度CSVからラベルを作成してモデルにインプットした場合、学習が非常に遅くなります。
そこで予め全てのラベルをインデックスカラーのpngファイルとして保存しておき、学習時には読込みだけを行う方法を調べたので、備忘録として残します。

環境

  • Google Colaboratory Pro

コード

モジュールのimport

import numpy as np
from PIL import Image
import torch
from matplotlib import pyplot as plt

初めに与えられたデータからインデックスカラーのNumPy配列を作成します。
以下は例で、背景を0、2番目に大きい四角を1、最も小さい四角を2と指定しています。
この時データ型をintあるいはuintで指定する必要があります。

label = np.array(
    [[0,0,0,0,0,0],
    [0,1,1,1,1,0],
    [0,1,2,2,1,0],
    [0,1,2,2,1,0],
    [0,1,1,1,1,0],
    [0,0,0,0,0,0]], dtype = 'int8'
)

print("label_shape:", label.shape)
plt.imshow(label)
plt.axis("off")
実行結果
label_shape: (6, 6)

label.png
画像をPILのパレットモードで保存

label = Image.fromarray(label, mode = "P")
label.save("./label.png")

画像を読み込む。ワンホットエンコーディングは分類したいクラス数+1で実行する。
例では分類したいのが2種の四角なので、2+1(1は背景分)で実行しています。

label = Image.open("./label.png")
label = np.asarray(label)   #NumPy配列に変換
label_tensor = torch.from_numpy(label.astype(np.float32)).clone()   #tensorに変換
label_onehot = torch.nn.functional.one_hot(label_tensor.long(), num_classes=3)
print("label_shape:",label.shape)
print(label)
plt.imshow(label)
plt.axis("off")
label_shape: (6, 6)
[[0 0 0 0 0 0]
 [0 1 1 1 1 0]
 [0 1 2 2 1 0]
 [0 1 2 2 1 0]
 [0 1 1 1 1 0]
 [0 0 0 0 0 0]]

label.png

print("label_tensor_shape:",label_tensor.size())
print(label_tensor)
plt.imshow(label_tensor)
plt.axis("off")
label_tensor_shape: torch.Size([6, 6])
tensor([[0., 0., 0., 0., 0., 0.],
        [0., 1., 1., 1., 1., 0.],
        [0., 1., 2., 2., 1., 0.],
        [0., 1., 2., 2., 1., 0.],
        [0., 1., 1., 1., 1., 0.],
        [0., 0., 0., 0., 0., 0.]])

label.png

print("label_tensor_shape:",label_onehot.size())
print(label_onehot)
label_tensor_shape: torch.Size([6, 6, 3])
tensor([[[1, 0, 0],
         [1, 0, 0],
         [1, 0, 0],
         [1, 0, 0],
         [1, 0, 0],
         [1, 0, 0]],

        [[1, 0, 0],
         [0, 1, 0],
         [0, 1, 0],
         [0, 1, 0],
         [0, 1, 0],
         [1, 0, 0]],

        [[1, 0, 0],
         [0, 1, 0],
         [0, 0, 1],
         [0, 0, 1],
         [0, 1, 0],
         [1, 0, 0]],

        [[1, 0, 0],
         [0, 1, 0],
         [0, 0, 1],
         [0, 0, 1],
         [0, 1, 0],
         [1, 0, 0]],

        [[1, 0, 0],
         [0, 1, 0],
         [0, 1, 0],
         [0, 1, 0],
         [0, 1, 0],
         [1, 0, 0]],

        [[1, 0, 0],
         [1, 0, 0],
         [1, 0, 0],
         [1, 0, 0],
         [1, 0, 0],
         [1, 0, 0]]])

後はreshapeやcropなどを行いモデルにインプットするだけです。

1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?