はじめに
ニューラルネットを用いた機械学習の中でも、セグメンテーションは1ピクセルごとのラベルが必要になることから、ラベルのインプットを工夫しないと学習のボトルネックになることがあります。
整形してあるデータセットであれば良いですが、KaggleなどはCSVファイルでラベルが配布される場合がある為、学習時に都度CSVからラベルを作成してモデルにインプットした場合、学習が非常に遅くなります。
そこで予め全てのラベルをインデックスカラーのpngファイルとして保存しておき、学習時には読込みだけを行う方法を調べたので、備忘録として残します。
環境
- Google Colaboratory Pro
コード
モジュールのimport
import numpy as np
from PIL import Image
import torch
from matplotlib import pyplot as plt
初めに与えられたデータからインデックスカラーのNumPy配列を作成します。
以下は例で、背景を0、2番目に大きい四角を1、最も小さい四角を2と指定しています。
この時データ型をintあるいはuintで指定する必要があります。
label = np.array(
[[0,0,0,0,0,0],
[0,1,1,1,1,0],
[0,1,2,2,1,0],
[0,1,2,2,1,0],
[0,1,1,1,1,0],
[0,0,0,0,0,0]], dtype = 'int8'
)
print("label_shape:", label.shape)
plt.imshow(label)
plt.axis("off")
実行結果
label_shape: (6, 6)
label = Image.fromarray(label, mode = "P")
label.save("./label.png")
画像を読み込む。ワンホットエンコーディングは分類したいクラス数+1で実行する。
例では分類したいのが2種の四角なので、2+1(1は背景分)で実行しています。
label = Image.open("./label.png")
label = np.asarray(label) #NumPy配列に変換
label_tensor = torch.from_numpy(label.astype(np.float32)).clone() #tensorに変換
label_onehot = torch.nn.functional.one_hot(label_tensor.long(), num_classes=3)
print("label_shape:",label.shape)
print(label)
plt.imshow(label)
plt.axis("off")
label_shape: (6, 6)
[[0 0 0 0 0 0]
[0 1 1 1 1 0]
[0 1 2 2 1 0]
[0 1 2 2 1 0]
[0 1 1 1 1 0]
[0 0 0 0 0 0]]
print("label_tensor_shape:",label_tensor.size())
print(label_tensor)
plt.imshow(label_tensor)
plt.axis("off")
label_tensor_shape: torch.Size([6, 6])
tensor([[0., 0., 0., 0., 0., 0.],
[0., 1., 1., 1., 1., 0.],
[0., 1., 2., 2., 1., 0.],
[0., 1., 2., 2., 1., 0.],
[0., 1., 1., 1., 1., 0.],
[0., 0., 0., 0., 0., 0.]])
print("label_tensor_shape:",label_onehot.size())
print(label_onehot)
label_tensor_shape: torch.Size([6, 6, 3])
tensor([[[1, 0, 0],
[1, 0, 0],
[1, 0, 0],
[1, 0, 0],
[1, 0, 0],
[1, 0, 0]],
[[1, 0, 0],
[0, 1, 0],
[0, 1, 0],
[0, 1, 0],
[0, 1, 0],
[1, 0, 0]],
[[1, 0, 0],
[0, 1, 0],
[0, 0, 1],
[0, 0, 1],
[0, 1, 0],
[1, 0, 0]],
[[1, 0, 0],
[0, 1, 0],
[0, 0, 1],
[0, 0, 1],
[0, 1, 0],
[1, 0, 0]],
[[1, 0, 0],
[0, 1, 0],
[0, 1, 0],
[0, 1, 0],
[0, 1, 0],
[1, 0, 0]],
[[1, 0, 0],
[1, 0, 0],
[1, 0, 0],
[1, 0, 0],
[1, 0, 0],
[1, 0, 0]]])
後はreshapeやcropなどを行いモデルにインプットするだけです。