SIGNATEの練習問題「鋳造製品の欠陥検出」に取り組みました。
https://signate.jp/competitions/406
![鋳造製品の欠陥検出.jpeg](https://qiita-user-contents.imgix.net/https%3A%2F%2Fqiita-image-store.s3.ap-northeast-1.amazonaws.com%2F0%2F2451452%2F8a25620c-dc9b-c2d1-6200-f0121f1e28c9.jpeg?ixlib=rb-4.0.0&auto=format&gif-q=60&q=75&s=70eb6473cea30874dd2f6887827a82b3)
# 1. ライブラリ
import cv2
import numpy as np
import pandas as pd
import torch
import torchvision
from torchvision.datasets import ImageFolder
from torchvision import transforms, models
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
%matplotlib inline
def to_display_image(img, mean=None, std=None):
    """Convert a channel-first image into a channel-last numpy array for display.

    Args:
        img: image laid out as (channels, height, width) -- either a torch
            tensor (any device, with or without grad) or anything that
            ``np.asarray`` accepts.
        mean: optional per-channel means that were passed to
            ``transforms.Normalize``.
        std: optional per-channel standard deviations that were passed to
            ``transforms.Normalize``.

    When both ``mean`` and ``std`` are given, the normalization is undone
    (``x * std + mean``) and the result is clipped to [0, 1]. Without that,
    matplotlib would silently clip the normalized values and show wrong colours.

    Returns:
        numpy array of shape (height, width, channels).
    """
    if isinstance(img, torch.Tensor):
        # np.transpose cannot consume CUDA tensors or tensors that require grad.
        img = img.detach().cpu().numpy()
    hwc = np.transpose(np.asarray(img), (1, 2, 0))
    if mean is not None and std is not None:
        # Broadcasts the per-channel statistics over the trailing channel axis.
        hwc = np.clip(hwc * np.asarray(std) + np.asarray(mean), 0.0, 1.0)
    return hwc


def imshow(img, mean=None, std=None):
    """Draw a channel-first image with matplotlib.

    Args:
        img: (channels, height, width) tensor or array.
        mean: optional per-channel means used for normalization; with ``std``
            it restores the original colours before drawing.
        std: optional per-channel standard deviations used for normalization.
    """
    plt.imshow(to_display_image(img, mean=mean, std=std))
# Use the first CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
# 2. 訓練データセットの読み込み
# Output size handed to cv2.resize. OpenCV reads dsize as (width, height);
# the value is square, so the order does not matter here.
shape = (300, 300)


def resize(img, shape=shape):
    """Resize a numpy image to ``shape`` using cv2.resize's default interpolation.

    The default argument is bound once, at definition time, to the
    module-level ``shape`` above.
    """
    resized = cv2.resize(img, shape)
    return resized


trans = transforms.Compose(
    [
        # Loaded image -> numpy array so OpenCV can work on it (presumably an RGB
        # PIL image from ImageFolder's default loader -- confirm if a custom
        # loader is ever used).
        np.array,
        resize,
        # HWC uint8 in [0, 255] -> CHW float tensor in [0, 1].
        transforms.ToTensor(),
        # Standard ImageNet per-channel mean / std.
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# ImageFolder takes class labels from the sub-directory names under ./train/.
train_imgs = ImageFolder('./train/', transform=trans)
# 3. 学習用と評価用のデータに分割
batch_size = 16
# Seed for the train/validation split. Without a fixed generator every run
# draws a different split, so validation accuracies from different runs
# (e.g. after swapping the backbone) are not directly comparable.
split_seed = 0
data = train_imgs
# 80 % of the images for training, the remainder for validation.
train_size = int(0.8 * len(data))
validation_size = len(data) - train_size
data_size = {"train": train_size, "validation": validation_size}
data_train, data_validation = torch.utils.data.random_split(
    data,
    [train_size, validation_size],
    generator=torch.Generator().manual_seed(split_seed),
)
# Only the training batches are shuffled; validation order does not affect accuracy.
train_loader = DataLoader(data_train, batch_size=batch_size, shuffle=True)
validation_loader = DataLoader(data_validation, batch_size=batch_size, shuffle=False)
dataloaders = {"train": train_loader, "validation": validation_loader}
# 4. モデルの定義
import torch.nn as nn
from torchvision import transforms, models
import torch.optim as optim
class Resnet(nn.Module):
    """Two-class classifier on top of an ImageNet-pretrained EfficientNet-B0.

    Despite its name, the backbone is ``efficientnet_b0`` (the notebook first
    used resnet18). The class name and the ``resnet`` / ``fc`` attribute names
    are kept so that saved state_dicts such as ``cast.model`` still load.

    The backbone's 1000-way ImageNet output feeds a new ``1000 -> 2`` linear
    layer. The loss (``criterion``) and an Adam optimizer over every parameter,
    backbone included, are stored on the module for use by the training loop.
    """

    def __init__(self):
        super().__init__()
        self.resnet = self._load_backbone()
        self.fc = nn.Linear(1000, 2)  # 2-class head
        self.criterion = nn.CrossEntropyLoss()
        # Created after both sub-modules exist so it covers all of their parameters.
        self.optimizer = optim.Adam(self.parameters())

    @staticmethod
    def _load_backbone():
        """Return EfficientNet-B0 initialised with ImageNet-1k weights."""
        weights_enum = getattr(models, "EfficientNet_B0_Weights", None)
        if weights_enum is None:
            # torchvision < 0.13 only understands the legacy boolean flag.
            return models.efficientnet_b0(pretrained=True)
        # torchvision >= 0.13 deprecates `pretrained=True`, which it maps to
        # IMAGENET1K_V1; request those weights explicitly instead.
        return models.efficientnet_b0(weights=weights_enum.IMAGENET1K_V1)

    def forward(self, x):
        """Return raw (unnormalised) scores for the 2 classes for images ``x``."""
        features = self.resnet(x)
        return self.fc(features)
def train(model, train_loader):
    """Run one training epoch and return ``(accuracy_percent, mean_loss)``.

    ``model`` must expose ``criterion`` and ``optimizer`` attributes (as the
    ``Resnet`` class does). Batches are moved to the module-level ``device``.
    Accuracy is measured on the predictions made before each optimizer step.
    ``mean_loss`` is the average per-sample loss over the epoch.

    Raises:
        ZeroDivisionError: if ``train_loader`` yields no samples.
    """
    # Put dropout / batch-norm layers into training mode.
    model.train()
    # Counters for correct predictions, summed loss and number of samples.
    total_correct = 0
    total_loss = 0.0
    total_data_len = 0
    for batch_imgs, batch_labels in train_loader:
        batch_imgs = batch_imgs.to(device)
        batch_labels = batch_labels.to(device)

        outputs = model(batch_imgs)
        _, pred_labels = torch.max(outputs, 1)

        model.optimizer.zero_grad()
        loss = model.criterion(outputs, batch_labels)
        loss.backward()
        model.optimizer.step()

        # Bookkeeping.
        n = len(batch_labels)
        # One reduction per batch instead of a per-element Python loop; each
        # element-wise `if tensor == tensor` on CUDA forces a host-device sync.
        total_correct += (pred_labels == batch_labels).sum().item()
        # Assumes the criterion averages over the batch (reduction='mean', the
        # CrossEntropyLoss default). Weight by the batch size so that dividing
        # by the sample count below yields a true per-sample mean.
        total_loss += loss.item() * n
        total_data_len += n
    accuracy = total_correct / total_data_len * 100
    loss = total_loss / total_data_len
    return accuracy, loss
def predict(model, batch_imgs):
    """Return the index of the highest-scoring class for every image in the batch.

    Gradient tracking is disabled, so evaluation-only callers do not build an
    autograd graph they never use. The model's train/eval mode is left
    untouched; callers should call ``model.eval()`` first.
    """
    with torch.no_grad():
        outputs = model(batch_imgs)
    # Argmax over the class dimension (dim 1 of a batched score tensor).
    _, pred_labels = torch.max(outputs, dim=1)
    return pred_labels
def test(model, data_loader):
    """Return the classification accuracy (in %) of ``model`` on ``data_loader``.

    The model is switched to eval mode and left there. Batches are moved to
    the module-level ``device``, and no gradients are recorded.

    Raises:
        ZeroDivisionError: if ``data_loader`` yields no samples.
    """
    model.eval()
    total_data_len = 0
    total_correct = 0
    # Evaluation only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        for batch_imgs, batch_labels in data_loader:
            batch_imgs = batch_imgs.to(device)
            batch_labels = batch_labels.to(device)
            pred_labels = predict(model, batch_imgs)
            # Count matches with one tensor reduction per batch instead of a
            # per-element loop (each element check on CUDA syncs the device).
            total_correct += (pred_labels == batch_labels).sum().item()
            total_data_len += len(pred_labels)
    acc = 100.0 * total_correct / total_data_len
    return acc
# 5. モデルの学習
# Number of epochs; identical to the original `range(1, 100)` (epochs 1..99).
num_epochs = 99
model = Resnet()
train_acc = []
val_acc = []
model = model.to(device)
for i in range(1, num_epochs + 1):
    acc, loss = train(model, train_loader)
    train_acc.append(acc)
    acc1 = test(model, validation_loader)
    val_acc.append(acc1)
    print(f'[epoch{i} train_loss: {loss:.4f}, train_acc: {acc:.2f} %, val_acc: {acc1:.2f} %]')

# Plot against the 1-based epoch numbers printed above rather than list indices.
epochs = range(1, len(train_acc) + 1)
plt.plot(epochs, train_acc, label="train_accuracy")
plt.plot(epochs, val_acc, label="validation_accuracy")
plt.legend()
plt.xlabel("epoch")
plt.ylabel("accuracy")
plt.savefig("accuracy.jpeg")
plt.show()

# Saves the weights from the final epoch (not the best validation epoch).
model_path = 'cast.model'
torch.save(model.state_dict(), model_path)
# 6. モデルの出力
# ImageFolder needs at least one sub-directory under ./test/; the folder-derived
# label is ignored below (presumably all test images sit in a dummy class
# folder -- confirm against the data layout).
test_imgs = torchvision.datasets.ImageFolder('./test/', transform=trans)
pred_label = []
# Make inference independent of whatever mode earlier cells left the model in:
# dropout / batch-norm must behave as in evaluation here.
model.eval()
with torch.no_grad():
    # One image at a time, in dataset index order.
    for i in range(len(test_imgs)):
        test_img, _ = test_imgs[i]
        # Add a batch dimension of 1 before the forward pass.
        input_img = torch.unsqueeze(test_img, 0).to(device)
        label_tensor = predict(model, input_img)
        # No .detach() needed: nothing here tracks gradients.
        label = label_tensor.cpu().numpy()
        # Each entry is a one-element list holding the predicted class index.
        pred_label.append(list(label))
# 感想
・転移学習に使うモデルを resnet18 から efficientnet_b0 に変更したところ、
validation_accuracy の収束が良くなりました（コード中のクラス名は `Resnet` のままですが、中身は efficientnet_b0 です）。