1
2

More than 1 year has passed since last update.

【SIGNATE】鋳造製品の欠陥検出

Last updated at Posted at 2022-01-23

SIGNATEの練習問題「鋳造製品の欠陥検出」に取り組みました。
https://signate.jp/competitions/406

鋳造製品の欠陥検出.jpeg

1. ライブラリ

import cv2
import numpy as np
import pandas as pd

import torch
import torchvision
from torchvision.datasets import ImageFolder
from torchvision import transforms, models
from torch.utils.data import DataLoader

import matplotlib.pyplot as plt
%matplotlib inline

def imshow(img):
    """Show an image whose first axis should be displayed as the last one.

    The array is transposed with axes (1, 2, 0) before it is handed to
    ``plt.imshow``; presumably the input is channel-first (C, H, W) — confirm
    against callers.
    """
    reordered = np.transpose(img, (1, 2, 0))
    plt.imshow(reordered)

# Use the first CUDA device when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

2. 訓練データセットの読み込み

# Default target size passed to cv2.resize for every image.
shape = (300, 300)

def resize(img, shape=shape):
    """Return ``img`` resized to ``shape`` with ``cv2.resize``.

    ``shape`` defaults to the module-level ``shape`` as it was when this
    function was defined.
    """
    resized = cv2.resize(img, shape)
    return resized

# Preprocessing applied to every image, in order.
trans = transforms.Compose(
    [
        np.array,               # convert the loaded image to a numpy array
        resize,                 # resize to the default `shape`
        transforms.ToTensor(),  # convert the array to a tensor
        # Per-channel normalization; these values are presumably the ImageNet
        # statistics used by torchvision's pretrained models — TODO confirm.
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Training images are read from ./train/ with the pipeline above applied.
train_imgs = ImageFolder(root='./train/', transform=trans)

3. 学習用と評価用のデータに分割

batch_size = 16
data = train_imgs

# 80 % of the samples go to training; whatever remains is used for validation.
train_size = int(0.8 * len(data))
validation_size = len(data) - train_size
data_size = dict(train=train_size, validation=validation_size)

# Random (unseeded) split into the two subsets.
data_train, data_validation = torch.utils.data.random_split(
    data,
    [train_size, validation_size],
)

# Only the training loader shuffles its samples.
train_loader = DataLoader(data_train, batch_size=batch_size, shuffle=True)
validation_loader = DataLoader(
    data_validation,
    batch_size=batch_size,
    shuffle=False,
)
dataloaders = dict(train=train_loader, validation=validation_loader)

4. モデルの定義

import torch.nn as nn
from torchvision import transforms, models
import torch.optim as optim

class Resnet(nn.Module):
    """Two-class classifier: a pretrained ``efficientnet_b0`` plus a linear head.

    Despite the class and attribute names, the backbone created here is
    ``models.efficientnet_b0``. The loss function and the optimizer are stored
    on the instance as ``criterion`` and ``optimizer``.
    """

    def __init__(self):
        super().__init__()
        # Backbone; kept under the attribute name `resnet`.
        self.resnet = models.efficientnet_b0(pretrained=True)
        # Head taking 1000 input features and producing 2 class scores
        # (binary classification).
        self.fc = nn.Linear(in_features=1000, out_features=2)
        self.criterion = nn.CrossEntropyLoss()
        # Adam over every parameter registered above (backbone and head).
        self.optimizer = optim.Adam(self.parameters())

    def forward(self, x):
        """Run ``x`` through the backbone, then the linear head."""
        backbone_out = self.resnet(x)
        return self.fc(backbone_out)

def train(model, train_loader):
    """Train ``model`` for one pass over ``train_loader``.

    Args:
        model: A module exposing ``optimizer`` and ``criterion`` attributes.
        train_loader: Iterable yielding ``(batch_imgs, batch_labels)`` pairs.

    Returns:
        ``(accuracy, loss)``: accuracy in percent and the average loss per
        sample over the whole pass.

    Raises:
        ZeroDivisionError: If ``train_loader`` yields no samples.
    """
    # Make it explicit that the model is in training mode.
    model.train()
    # Running totals: correct predictions, summed loss, samples seen.
    total_correct = 0
    total_loss = 0.0
    total_data_len = 0
    for batch_imgs, batch_labels in train_loader:
        batch_imgs = batch_imgs.to(device)
        batch_labels = batch_labels.to(device)

        outputs = model(batch_imgs)
        _, pred_labels = torch.max(outputs, 1)
        model.optimizer.zero_grad()
        loss = model.criterion(outputs, batch_labels)
        loss.backward()
        model.optimizer.step()

        # Aggregate. `batch_len` avoids shadowing the module-level batch_size.
        batch_len = len(batch_labels)
        total_data_len += batch_len
        total_correct += (pred_labels == batch_labels).sum().item()
        # The criterion is assumed to return the batch *mean* (the default
        # reduction of nn.CrossEntropyLoss), so weight it by the batch length;
        # otherwise dividing by the sample count below understates the loss.
        total_loss += loss.item() * batch_len
    accuracy = total_correct / total_data_len * 100
    loss = total_loss / total_data_len
    return accuracy, loss

def predict(model, batch_imgs):
    """Return the index of the highest score along dim 1 for each input.

    Args:
        model: Callable applied to ``batch_imgs`` to obtain the scores.
        batch_imgs: Input passed to ``model`` unchanged.

    Returns:
        Tensor of predicted indices, one per row of the model output.
    """
    # Only the argmax indices are returned, so gradients are never needed;
    # disabling autograd avoids building a graph for the forward pass.
    with torch.no_grad():
        outputs = model(batch_imgs)
    _, pred_labels = torch.max(outputs, dim=1)
    return pred_labels

def test(model, data_loader):
    """Compute the accuracy of ``model`` over ``data_loader``.

    The model is switched to evaluation mode and left in that mode.

    Args:
        model: Module passed to ``predict`` for every batch.
        data_loader: Iterable yielding ``(batch_imgs, batch_labels)`` pairs.

    Returns:
        Accuracy in percent.

    Raises:
        ZeroDivisionError: If ``data_loader`` yields no samples.
    """
    model.eval()
    total_data_len = 0
    total_correct = 0
    # Evaluation never back-propagates, so skip autograd bookkeeping.
    with torch.no_grad():
        for batch_imgs, batch_labels in data_loader:
            batch_imgs = batch_imgs.to(device)
            batch_labels = batch_labels.to(device)

            pred_labels = predict(model, batch_imgs)
            # Aggregate with one tensor comparison per batch.
            total_data_len += len(pred_labels)
            total_correct += (pred_labels == batch_labels).sum().item()
    acc = 100.0 * total_correct / total_data_len
    return acc

5. モデルの学習

# Build the model and move it to the selected device.
model = Resnet().to(device)

# Per-epoch accuracy histories used for the plot below.
train_acc, val_acc = [], []

# range(1, 100) gives epochs 1 through 99.
for i in range(1, 100):
    acc, loss = train(model, train_loader)
    train_acc.append(acc)

    acc1 = test(model, validation_loader)
    val_acc.append(acc1)

    print(f'[epoch{i} train_loss: {loss:.4f}, train_acc: {acc:.2f} %, val_acc: {acc1:.2f} %]')

# Plot both accuracy curves, save the figure, then display it.
plt.plot(train_acc, label="train_accuracy")
plt.plot(val_acc, label="validation_accuracy")
plt.legend()
plt.xlabel("epoch")
plt.ylabel("accuracy")
plt.savefig("accurcay.jpeg")
plt.show()

accurcay.jpeg

5. モデルの保存

# Save only the model's state_dict (not the module object) to disk.
model_path = 'cast.model'
torch.save(obj=model.state_dict(), f=model_path)

6. モデルの出力（テストデータの予測）

# Test images are read from ./test/ with the same preprocessing as training.
test_imgs = torchvision.datasets.ImageFolder('./test/', transform=trans)

# One single-element list of predicted labels per test image.
pred_label = []

# Switch to evaluation mode explicitly instead of relying on an earlier
# test() call having left the model in that mode.
model.eval()
with torch.no_grad():
    for i in range(len(test_imgs)):
        # The second item returned by the dataset is not used here.
        test_img, _ = test_imgs[i]
        # Add a leading dimension of size 1 so the single image forms a batch.
        input_img = torch.unsqueeze(test_img, 0)

        input_img = input_img.to(device)

        label_tensor = predict(model, input_img)
        # Under no_grad the result carries no graph, so no detach() is needed.
        label = label_tensor.cpu().numpy()
        pred_label.append(list(label))

感想

・転移学習に用いるモデルを resnet18 から efficientnet_b0 に変更したところ、
 validation_accuracy の収束が改善しました。

1
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
2