6
Help us understand the problem. What are the problem?

More than 3 years have passed since last update.

posted at

updated at

~ HIKAKIN or SEIKIN ? ~ PyTorchでCNN

はじめに

皆さんはUUUMのHIKAKINとSEIKINを知ってますか?
二人は兄弟でどちらもチャンネル登録者は100万人を超える大人気YouTuberです。

image.png

今回はの二人が投稿したサムネイルを分類するCNNを試します。
画像収集は、以前作成したYouTuberデータセットを使用しました。
YouTuberデータセット公開してみた

本記事でしたこと

  • サムネイルがHIKAKIN,SEIKINのどちらが投稿したものかをCNNで分類
  • 「ResNetをFine-Tuning」 vs. 「ResNetをFine-Tuning (fc層以外の重みを固定)」で比較

本記事は、PyTorchのチュートリアル"Transfer Learning tutorial"でやってることとほぼ同じで、対象データだけ変えてる感じです。ですので、ResNet, Fine-Tuning、実装を詳細に確認したい場合は、以下のリンクをおススメします!
Transfer Learning tutorial - PyTorch Tutorials 0.4.0 documentation
PyTorch (8) Transfer Learning (Ants and Bees) - 人工知能に関する断創録

データ紹介

HIKAKIN画像数: トレーニング451枚, テスト50枚
SEIKIN画像数: トレーニング443枚, テスト50枚

↓が画像例です。二人とも自分の顔を載せてくれているものが多いので、有難い

HIKAKIN画像
image.png

SEIKIN画像
image.png

ResNetをFine-Tuning

やってること
- データ拡張
- Imagenetで学習したResNetを読み込み
- ResNetのfc層の出力を2クラスへ変更
- Fine-Tuning

データ読み込み(データ拡張の変換定義も含む)

data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'valid': transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
                  for x in ['train', 'valid']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x],
                                              batch_size=4,
                                              shuffle=True,
                                              num_workers=4) for x in ['train', 'valid']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'valid']}
class_names = image_datasets['train'].classes

モデルの学習フローを定義

use_gpu = torch.cuda.is_available()
def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        for phase in ['train', 'valid']:
            if phase == 'train':
                scheduler.step()
                model.train(True)   # training mode
            else:
                model.train(False)  # evaluate mode

            running_loss = 0.0
            running_corrects = 0
            for data in dataloaders[phase]:
                inputs, labels = data
                if use_gpu:
                    inputs = Variable(inputs.cuda())
                    labels = Variable(labels.cuda())
                else:
                    inputs, labels = Variable(inputs), Variable(labels)
                optimizer.zero_grad()

                outputs = model(inputs)
                _, preds = torch.max(outputs.data, 1)
                loss = criterion(outputs, labels)
                if phase == 'train':
                    loss.backward()
                    optimizer.step()

                running_loss += loss.data[0] * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            epoch_loss = running_loss / float( dataset_sizes[phase] )
            epoch_acc = running_corrects.double() / float( dataset_sizes[phase] )

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            if phase == 'valid' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val acc: {:.4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model

モデルの定義

model_ft = models.resnet18(pretrained=True)
num_features = model_ft.fc.in_features; print(num_features);
model_ft.fc = nn.Linear(num_features, 2)

if use_gpu: model_ft = model_ft.cuda()
criterion = nn.CrossEntropyLoss()
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, num_epochs=25)

学習結果

Epoch 0/24
----------
train Loss: 0.7782 Acc: 0.6130
valid Loss: 0.5999 Acc: 0.7700

...

Epoch 24/24
----------
train Loss: 0.3515 Acc: 0.8501
valid Loss: 0.3065 Acc: 0.9100

Training complete in 4m 42s
Best val acc: 0.9100

ResNetをFine-Tuning (fc層以外の重みを固定)

こちらの違いは、以下の点です。
- for param in model_conv.parameters(): param.requires_grad = Falseで、パラメータを固定
- optimizer_conv = optim.SGD(model_conv.fc.parameters(), lr=0.001, momentum=0.9)で、更新対象を追加したfc層のみとする

モデルの定義

model_conv = torchvision.models.resnet18(pretrained=True)
for param in model_conv.parameters(): param.requires_grad = False
num_features = model_conv.fc.in_features
model_conv.fc = nn.Linear(num_features, 2)

if use_gpu: model_conv = model_conv.cuda()
criterion = nn.CrossEntropyLoss()
optimizer_conv = optim.SGD(model_conv.fc.parameters(), lr=0.001, momentum=0.9)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)
model_conv = train_model(model_conv, criterion, optimizer_conv, exp_lr_scheduler, num_epochs=25)

学習結果

Epoch 0/24
----------
train Loss: 0.7433 Acc: 0.6085
valid Loss: 0.5631 Acc: 0.7000

...

Epoch 24/24
----------
train Loss: 0.5323 Acc: 0.7360
valid Loss: 0.3660 Acc: 0.8600

Training complete in 2m 9s
Best val acc: 0.8600

実験結果

分類精度

0.9100 ... ResNetをFine-Tuning
0.8600 ... ResNetをFine-Tuning(fc層以外の重みを固定)

サムネイル分類では、fc層以外の重みを固定しないほうが精度が良いことが確認できました。公式チュートリアルの題材で扱っているAnts vs. Beesでは、重みを固定するほうが精度が良かったので、その逆になりました。ImageNetの分類とサムネイル分類が違う特徴を必要とするタスクであるということだったのですよね...

Confusion Matrix

Cofusion Matrixを確認したところ、均等に誤っている感じでした。

import pandas as pd
import seaborn as sn
from sklearn.metrics import confusion_matrix
cmx = confusion_matrix(trues_all, preds_all)
cmx_df = pd.DataFrame( cmx, index=class_names, columns=class_names )
sn.heatmap( cmx_df, cmap='Blues', annot=True)

image.png

誤分類結果

誤分類が発生しているサムネイルを確認してみます。

image.png
image.png
image.png

今回の実験では、9/100が誤りで、ほとんどが無理じゃねって感じですね...、誤り理由としては、

  • 二人でコラボしてるサムネイル
  • HIKAKINが映ってるのにSEIKINの動画
  • 詐欺写メ

などなど...

つぶやき

最近PyTorch含め色んなツールのチュートリアルで勉強してるんですが、チュートリアルで使ってるデータそのままダウンロードしてやるだけだと何となく同じような結果が出るだけで面白くなくて身につかない気がしてます。なので、今後やったことをYouTuberデータセットで試すというのをやる。

Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Sign upLogin
6
Help us understand the problem. What are the problem?