2
2

[timm] MNISTデータセットを多ラベル問題として解く

Posted at

TL;DR

MNISTデータセットの一部を用いて多ラベル分類モデルの訓練を行います。モデルの構築と訓練にはPyTorchを使用しています。MNISTデータセットは元々は、1ラベル多クラス分類問題ですが、多ラベル2クラス分類問題に変換しているところがポイントです。
主なステップは以下の通りです:

  1. データの読み込み: MNISTデータセットの一部をロード
  2. 多ラベルの生成: 素数(2、3、5、7)で割り切れるかどうかに基づいて複数のラベルを生成
  3. データの可視化: サンプル画像とそれに関連する多ラベルを可視化
  4. データの前処理: 画像とラベルをPyTorchのテンソルに変換し、DataLoaderを作成
  5. モデルの構築: timmライブラリを使用して、ResNet18に基づくモデルを構築
  6. 訓練: AdamオプティマイザとBCEWithLogitsLossを使用してモデルを訓練
  7. 評価: テストセットでモデルの性能を評価

この記事執筆のモチベーションは下記です

  • PyTorchを使用した多ラベル画像分類の基礎的な理解を手助けしたい
  • timmのベーシックな実装例を紹介したい

この記事とソースコードは、KaggleのCodeコミュニティを通して、下記のリンクで共有しています

ソースコード全文
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
import timm

# Load a subset of the MNIST dataset
def load_mnist_data(using_ratio=0.2):
    # Load MNIST dataset
    (train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
    # Calculate the subset size for training and testing data
    train_size = int(train_images.shape[0] * using_ratio)
    test_size = int(test_images.shape[0] * using_ratio)
    # Return the subset of the dataset
    return train_images[:train_size], train_labels[:train_size], test_images[:test_size], test_labels[:test_size]

# Create multi-labels based on divisibility by prime numbers
def create_multilabels(labels):
    divisors = [2, 3, 5, 7]
    return np.array([[1 if label % divisor == 0 else 0 for divisor in divisors] for label in labels])

# Visualize sample images and their multi-labels
def visualize_samples(images, labels, multilabels):
    # Get the first occurrence index of each label (0-9)
    sample_indices = [np.where(labels == i)[0][0] for i in range(10)]
    # Create a 2x5 grid for visualization
    fig, axes = plt.subplots(2, 5, figsize=(12, 6))
    # Display each sample image and its multi-labels
    for i, idx in enumerate(sample_indices):
        axes[i // 5, i % 5].imshow(images[idx].reshape(28, 28), cmap='gray')
        axes[i // 5, i % 5].set_title(f"Label: {labels[idx]}")
        multilabel_str = ', '.join(map(str, multilabels[idx]))
        axes[i // 5, i % 5].text(0, 32, f"Multi-labels: [{multilabel_str}]")
        axes[i // 5, i % 5].axis('off')
    plt.tight_layout()
    plt.show()

# Preprocess the data and create a DataLoader
def preprocess_data(images, multilabels):
    images = images.reshape(-1, 28, 28, 1).astype('float32') / 255
    images = np.transpose(images, (0, 3, 1, 2))
    images = torch.tensor(images, dtype=torch.float32)
    multilabels = torch.tensor(multilabels, dtype=torch.float32)
    dataset = TensorDataset(images, multilabels)
    loader = DataLoader(dataset, batch_size=128, shuffle=True)
    return loader
    
    
# Build a custom PyTorch model using timm library
class MyImageModel(nn.Module):
    def __init__(self, model_name: str, pretrained: bool, hidden_dim: int, out_dim: int):
        super(MyImageModel, self).__init__()
        # Initialize the backbone model
        self.backbone = timm.create_model(model_name, pretrained=pretrained, num_classes=0)
        self.in_features = self.backbone.num_features
        # Define the head of the model
        self.head = nn.Sequential(
            nn.Linear(self.in_features, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_dim, out_dim)
        )
        
    def forward(self, x):
        # Adjust the number of channels
        x = x.repeat(1, 3, 1, 1)
        h = self.backbone(x)
        y = self.head(h)
        return y



# Train the PyTorch model
def train_pytorch_model(model, train_loader, epochs=3):
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters())
    
    for epoch in range(epochs):
        for batch_idx, (data, target) in enumerate(train_loader):
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            
            if batch_idx % 10 == 0:
                print(f"Epoch [{epoch+1}/{epochs}], Step [{batch_idx+1}/{len(train_loader)}], Loss: {loss.item():.4f}")
                
# Evaluate the PyTorch model
def evaluate_pytorch_model(model, test_loader):
    model.eval()
    correct = 0
    total = 0
    criterion = nn.BCEWithLogitsLoss()
    total_loss = 0
    
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            loss = criterion(output, target)
            total_loss += loss.item()
            
            predicted = (torch.sigmoid(output) > 0.5).float()
            total += target.size(0)
            correct += (predicted == target).sum().item()
    
    accuracy = correct / (total * 4)  # 4 multi-labels
    avg_loss = total_loss / len(test_loader)
    print(f"Test Loss: {avg_loss:.4f}, Test Accuracy: {accuracy * 100:.2f}%")


def main():
    # 1. Dataset
    train_images, train_labels, test_images, test_labels = load_mnist_data()
    train_multilabels = create_multilabels(train_labels)
    test_multilabels = create_multilabels(test_labels)
    print(train_images.shape, train_labels.shape, test_images.shape, test_labels.shape)
    # (12000, 28, 28) (12000,) (2000, 28, 28) (2000,)

    # 2. Visualization
    visualize_samples(train_images, train_labels, train_multilabels)

    # 3. Modeling
    # PyTorch model(resnet18d by timm)
    train_loader = preprocess_data(train_images, train_multilabels)
    test_loader = preprocess_data(test_images, test_multilabels)
    pt_model = MyImageModel(model_name='resnet18d', pretrained=False, hidden_dim=128, out_dim=4)
    train_pytorch_model(pt_model, train_loader)
    evaluate_pytorch_model(pt_model, test_loader)

if __name__ == "__main__":
    main()

ソースコードのざっくり解説

データの読み込み

def load_mnist_data(using_ratio=0.2):
    ...

この関数は、TensorFlowのKeras APIからMNISTデータセットを読み込み、訓練データとテストデータの一部を返します。using_ratioパラメータで、使用するデータの割合を指定できます。

多ラベルの生成

def create_multilabels(labels):
    divisors = [2, 3, 5, 7]
    return np.array([[1 if label % divisor == 0 else 0 for divisor in divisors] for label in labels])

この関数は、各画像(数字)が2, 3, 5, 7で割り切れるかどうかに基づいて多ラベルを生成します。例えば、ラベルが6の場合、2と3で割り切れるので、その多ラベルは[1, 1, 0, 0]になります。

データの可視化

def visualize_samples(images, labels, multilabels):
    ...

この関数は、訓練データの一部とそれに対応する多ラベルを下記のように可視化するためのものです。

image.png

データの前処理

def preprocess_data(images, multilabels):
    ...
    return loader

この関数は、画像データと多ラベルデータをPyTorchのTensorに変換し、DataLoaderを生成します。

モデルの構築

class MyImageModel(nn.Module):
    ...

timmライブラリを使用して、カスタムの画像分類モデルを定義しています。ここでは、事前訓練されていないResNet18をベースとしています。

訓練

def train_pytorch_model(model, train_loader, epochs=3):
    ...

この関数は、AdamオプティマイザとBCEWithLogitsLoss損失関数を用いて、モデルを訓練します。

評価

def evaluate_pytorch_model(model, test_loader):
    ...

この関数は、訓練されたモデルの性能をテストデータセットで評価します。


参考文献


2
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
2