9
8

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

Semantic Segmentationの実装

Last updated at Posted at 2021-11-15

#目次
・学習済みモデルの取得
・学習済みモデルによる推論

##学習済みモデルの取得
PyTorchで実装の学習済みモデルの取得する手段として、
・torchvision.models
・PyTorch Hub
・GitHub
がありますが、PyTorchHubで試してみます。
PyTorch Hub は研究の再現性を容易にするために設計された、事前学習モデルを公開するモデルレポジトリです。

こちらでは下記のモデルが利用可能です。
#####・U-Net for BRAIN MRI
#####・DeeplabV3 ResNet101
#####・FCN ResNet101
U-Net は Brain MRI segmentation dataset での学習済みモデルであり、DeeplabV3 と FCN に関しては、torchvision と同様です。

こちらでは、生成モデルや、自然言語モデル等、torchvision よりも幅広い領域の学習済みモデルを探すことができます。

また、PyTorch Hub の各モデルのページでは、モデルの読み込み方や、どのような入力データが必要で、どのような出力となるか、などの詳細も記載されています。
##学習済みモデルによる推論
・データの用意
 ・データの前処理
 ・モデルの入力の形状にリサイズ
・推論
 ・モデルの用意
 ・推論
・後処理
 ・one-hot 表現 -> マスク画像化
 ・カラーパレットで可視化
###サンプルデータの用意

img

000001.jpg

#元のサイズに戻すためにサイズを取得
tmp = np.array(img)
h,w = tmp.shape[0],tmp.shape[1]

(900,1600)
###画像の前処理
学習済みモデルは、(Batch_size, 3, H, W) のインプットが想定されています。また、こちらのモデルに関しては下記の条件が指定されています。
H, W は224 pixel 以上
輝度は mean = [0.485, 0.456, 0.406] , std = [0.229, 0.224, 0.225] で標準化

preprocess = transforms.Compose([
    transforms.Resize(320),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# 前処理
img = preprocess(img)
# (3, H, W) -> (1, 3, H, W)
img_batch = img.unsqueeze(0)
print(img.size())
print(img_batch.size())

torch.Size([3, 320, 568])
torch.Size([1, 3, 320, 568])
###モデルのダウンロード

model = torch.hub.load('pytorch/vision:v0.5.0', 'deeplabv3_resnet101', pretrained=True)
# 推論モード
model.eval()

###推論

# device へ転送
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)
img_batch = img_batch.to(device)
# 推論
output = model(img_batch)['out']

# もとの大きさに戻す
output = F.upsample(output, size=(h, w), mode='bilinear')
output.size() #(batch_size, n_classes, H, W)

torch.Size([1, 21, 900, 1600])

# output[0] の形状は (c, h, w)
# dim=0 で channel 方向で argmax を実行するよう指定
out = torch.argmax(output[0], dim=0) 
out.size()

torch.Size([900, 1600])

###カラーパレットで可視化
これに色をつけて表示するには、クラスごとに (R, G, B) の値を割り当てる必要があります。seaborn の color_palette メソッドで、クラスごとにカラーパレット(RGB)を割り当て色付きの可視化を行います。

```Python
def give_color(img, n_classes):
  colors = sns.color_palette(n_colors=n_classes)
  color_mask = np.zeros((img.shape[0], img.shape[1], 3))
  for c in range(n_classes):
    c_bool = (img == c)
    color_mask[:, :, 0] += (c_bool * colors[c][0])
    color_mask[:, :, 1] += (c_bool * colors[c][1])
    color_mask[:, :, 2] += (c_bool * colors[c][2])
  return color_mask

マスク画像を torch.tensor 型から numpy.ndarray 型に変換します。

mask = out.cpu().detach().numpy()
mask.shape

(900, 1600)

color_mask = give_color(mask, 21)
plt.imshow(color_mask);

image.png

###マスク画像の重ね合わせ

from PIL import Image

#元画像に透過して重ね合わせる
origin = Image.open(path)
out = Image.fromarray(np.uint8(color_mask*255))
mask = Image.new('L', origin.size, 128)

im = Image.composite(origin, out, mask)
im

image.png

9
8
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
9
8

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?