26
20

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

pytorchで画像を多クラス多ラベル分類

Last updated at Posted at 2020-02-15

はじめに

  • タイトルの通りの事をやってみた。一通り出来たのでそのメモとして。
  • 内容に深く触れられない理由があり、ちょい雑になってる部分もあり。
    • 例えばその変数どこで宣言したの?的なのがあるかも
    • 元が notebook で切り貼り加工したのでマジックコマンドがあったり
  • 所々に参考リンク入れてます

やったこと

  • 画像の多クラス多ラベル分類。
  • 「この画像はクラスAだね。この画像はAとBが該当するね。」みたいなイメージ。
  • pytorch を使ってみたかったので実装は pytorchで。
  • そんなん当たり前やん!こんなコメント要る?みたいなのが散見するのは未熟であるため。

フォルダ構成

以下のようにしました。が、これは正直ベストではないかも。
調査もそこそこに走り出してしまったのでこうなったけど。

フォルダ構成
.
├── data
│   ├── labels        // イメージとラベルの組み合わせjson置き場
│   │     ├── A.json
│   │     ├── B.json
│   │     └── 他たくさんの json
│   └── images        // jpg画像置き場。学習用と検証用混在
│         ├── A.jpg
│         ├── B.jpg
│         └── 他たくさんの jpg
├── model             // モデル保存先
└── predict           // 予測したい画像置き場として設置

ちなみに labels 配下の json の中身はこんな内容。
キーが画像名、バリューがクラス情報(1 or 0)になっている。

サンプル
# A.json
{
    "A": {
        "ラベルA": 1,
        "ラベルB": 1,
        "ラベルC": 0
    }
}

# B.json
{
    "B": {
        "ラベルA": 0,
        "ラベルB": 0,
        "ラベルC": 1
    }
}

コード

色々準備

# ref: https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#sphx-glr-beginner-blitz-cifar10-tutorial-py

from PIL import Image
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import numpy as np
import pathlib
import random


# GPUあれば使う
def check_cuda():
    return 'cuda:0' if torch.cuda.is_available() else 'cpu'
device = torch.device(check_cuda())

# 学習データ、テストデータ分割
image_set = {pathlib.Path(i).stem for i in pathlib.Path('data/images').glob('*.jpg')}
n_data = len(image_set)
traindata_rate = 0.7
train_idx = random.sample(range(n_data), int(n_data*traindata_rate))

_trainset = {}
_testset = {}
for i, _tuple in enumerate(image_set.items()):
    k, v = _tuple    
    if i in train_idx:
        _trainset[k] = v
    else :
        _testset[k] = v

Transform

# ref: https://qiita.com/takurooo/items/e4c91c5d78059f92e76d
trfm = transforms.Compose([
    transforms.Resize((100, 100)),    # image size --> (100, 100)
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

Dataset

class MultiLabelDataSet(torch.utils.data.Dataset):
    def __init__(self, labels, image_dir='./data/images', ext='.jpg', transform=None):
        self.labels = labels
        self.image_dir = image_dir
        self.ext = ext
        self.transform = transform

        self.keys = list(labels.keys())
        self.vals = list(labels.values())

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        image_path = f'{self.image_dir}/{self.keys[idx]}{self.ext}'
        image_array = Image.open(image_path)
        if self.transform:
            image = self.transform(image_array)
        else:
            image = torch.Tensor(np.transpose(image_array, (2, 0, 1)))/255  # for 0~1 scaling
            
        label = torch.Tensor(list(self.vals[idx].values()))

        return {'image': image, 'label': label}

DataLoader

batch_size = 8

trainset = MultiLabelDataSet(_trainset, transform=trfm)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=False, num_workers=2)

testset = MultiLabelDataSet(_testset, transform=trfm)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)

classes = ['A', 'B', 'C'...] みたいな

データチェック

import matplotlib.pyplot as plt
%matplotlib inline

# functions to show an image
def imshow(img):
    plt.imshow(np.transpose(img, (1, 2, 0)))
    plt.show()

# sample data
dataiter = iter(trainloader)
tmp = dataiter.next()
images = tmp['image']
labels = tmp['label']

# print images
imshow(torchvision.utils.make_grid(images))

モデル

レイヤーとかチャンネル数は適当…
BCEWithLogitsLoss を使うのでシグモイドは噛まさない(ググったらそう言ってた)

class MultiClassifier(nn.Module):
    def __init__(self):
        super(MultiClassifier, self).__init__()

        self.ConvLayer1 = nn.Sequential(
            # ref(H_out & W_out): https://pytorch.org/docs/stable/nn.html#conv2d
            nn.Conv2d(3, 32, 3),
            nn.MaxPool2d(2),
            nn.ReLU(),
            )

        self.ConvLayer2 = nn.Sequential(
            nn.Conv2d(32, 64, 3),
            nn.MaxPool2d(2),
            nn.ReLU(),
            )

        self.ConvLayer3 = nn.Sequential(
            nn.Conv2d(64, 128, 3),
            nn.MaxPool2d(2),
            nn.ReLU(),
            )    

        self.ConvLayer4 = nn.Sequential(
            nn.Conv2d(128, 256, 3),
            nn.MaxPool2d(2),
            nn.ReLU(),
            nn.Dropout(0.2, inplace=True),
            )    

        self.Linear1 = nn.Linear(256 * 4 * 4, 2048)
        self.Linear2 = nn.Linear(2048, 1024)
        self.Linear3 = nn.Linear(1024, 512)
        self.Linear4 = nn.Linear(512, len(classes))


    def forward(self, x):
        x = self.ConvLayer1(x)
        x = self.ConvLayer2(x)
        x = self.ConvLayer3(x)
        x = self.ConvLayer4(x)
#         print(x.size())
        x = x.view(-1, 256 * 4 * 4)
        x = self.Linear1(x)
        x = self.Linear2(x)
        x = self.Linear3(x)
        x = self.Linear4(x)
        return x

def try_gpu(target):
    if check_cuda():
        device = torch.device(check_cuda())
        target.to(device)

model = MultiClassifier()
try_gpu(model)

トレーニング

criterion で急に pos_weight という変数が出てくるが、これは正のクラス正解時の重み付けのため。
https://pytorch.org/docs/stable/nn.html#torch.nn.BCEWithLogitsLoss

そういう操作が不要なら指定なしでOK。自分は正解時の重みを増やしたかったので指定した。
詳細は ref としてリンクを貼っているのでそちらにて 自分は説明を逃げる

# ref: https://medium.com/@thevatsalsaglani/training-and-deploying-a-multi-label-image-classifier-using-pytorch-flask-reactjs-and-firebase-c39c96f9c427
import numpy as np
from pprint import pprint
from torch.autograd import Variable
import torch.optim as optim

# ref: https://discuss.pytorch.org/t/bceloss-vs-bcewithlogitsloss/33586
# ref: https://discuss.pytorch.org/t/weights-in-bcewithlogitsloss/27452
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
try_gpu(criterion)

optimizer = optim.SGD(model.parameters(), lr = 0.005, momentum = 0.9)

def pred_acc(original, predicted):
    # ref: https://pytorch.org/docs/stable/torch.html#module-torch
    return torch.round(predicted).eq(original).sum().numpy()/len(original)


def fit_model(epochs, model, dataloader, phase='training', volatile = False):
    if phase == 'training':
        model.train()
        
    if phase == 'validataion':
        model.eval()
        volatile = True
        
    running_loss = []
    running_acc = []
    for i, data in enumerate(dataloader):
        inputs, target = Variable(data['image']), Variable(data['label'])
        
        # for GPU
        if device != 'cpu':
            inputs, target = inputs.to(device), target.to(device)

        if phase == 'training':
            optimizer.zero_grad()  # 勾配初期化

        ops = model(inputs)
         acc_ = []
         for j, d in enumerate(ops):
             acc = pred_acc(torch.Tensor.cpu(target[j]), torch.Tensor.cpu(d))
             acc_.append(acc)

        loss = criterion(ops, target)
        running_loss.append(loss.item())
        running_acc.append(np.asarray(acc_).mean())
        
        if phase == 'training':
            loss.backward()  # 誤差逆伝播
            optimizer.step() # パラメータ更新

    total_batch_loss = np.asarray(running_loss).mean()
    total_batch_acc = np.asarray(running_acc).mean()

    if epochs % 10 == 0:
        pprint(f"[{phase}] Epoch: {epochs}, loss: {total_batch_loss}.")
        pprint(f"[{phase}] Epoch: {epochs}, accuracy: {total_batch_acc}.")
    
    return total_batch_loss, total_batch_acc


from tqdm import tqdm

num = 50
best_val = 99
trn_losses = []; trn_acc = []
val_losses = []; val_acc = []
for idx in tqdm(range(1, num+1)):
    trn_l, trn_a = fit_model(idx, model, trainloader)
    val_l, val_a = fit_model(idx, model, testloader, phase='validation')
    trn_losses.append(trn_l); trn_acc.append(trn_a)
    val_losses.append(val_l); val_acc.append(val_a)

    if best_val > val_l:
        torch.save(model.state_dict(), f'model/best_model.pth')
        best_val = val_l
        best_idx = idx

予測

def get_tensor(img):
    tfms = transforms.Compose([
        transforms.Resize((100, 100)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
    return tfms(Image.open(img)).unsqueeze(0)

def predict(img, label_lst, model):
    tnsr = get_tensor(img)
    op = model(tnsr)  # Predict result(float)
    op_b = torch.round(op) # Rounding result(0 or 1)
    op_b_np = torch.Tensor.cpu(op_b).detach().numpy()
    preds = np.where(op_b_np == 1)[1]  # result == 1
    
    sigs_op = torch.Tensor.cpu(torch.round((op)*100)).detach().numpy()[0]
    o_p = np.argsort(torch.Tensor.cpu(op).detach().numpy())[0][::-1]  # label index order by score desc
    
    # anser label
    label = [label_lst[i] for i in preds]
    
    # all result
    arg_s = {label_lst[int(j)] : sigs_op[int(j)] for j in o_p}

    return label, dict(arg_s.items())


model = MultiClassifier()
model.load_state_dict(torch.load(f'model/best_model.pth', map_location=torch.device('cpu')))
model = model.eval()    # 推論モードに切り替え

target = 'XXXXXX'
img = Image.open(f'predict/{target}.jpg').resize((100, 100))
plt.imshow(img)

_, all_result = predict(f'predict/{target}.jpg', classes, model)
print('predict top5: ', *sorted(all_result.items(), key=lambda x: x[1], reverse=True)[:5])

最後に

実装は以上。

Data Augmentation(これとか実装も楽そう)、モデル磨き込み、
評価時の適切な weight 設定など行えばまだまだ精度向上の余地があると思う。

とりあえず作りたいのは出来たので自己満。










…参考までに。自分はこんな感じで予測するものをつくりました。

sample.png

※完全にプライベートの作品であり、所属組織を代表するものではありません

26
20
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
26
20

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?