1
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

3D MNISTで3DCNN練習

Posted at

はじめに

3Dの情報を使って機械学習をしていきたいので、とりあえず3DMNISTを使って練習してみました。
Kaggleのノートブックでの実装になります。ローカルでの動作は確認していません。

実装

こちらの実装を参考にしました。
https://www.kaggle.com/code/michaelcripman/3d-mnist-basic-cnn-adorable-visualisations/comments

データ読み込み

Kaggle上で3DMNISTのデータが公開されているので、それを使います。
https://www.kaggle.com/datasets/daavoo/3d-mnist/data

# Read the flattened voxel vectors and their labels from the Kaggle 3D MNIST HDF5 file.
with h5py.File("../input/3d-mnist/full_dataset_vectors.h5", "r") as dataset_file:
    # Slicing with [:] copies each dataset into memory, so the arrays remain
    # usable after the file is closed.
    x_train, y_train = dataset_file["X_train"][:], dataset_file["y_train"][:]
    x_test, y_test = dataset_file["X_test"][:], dataset_file["y_test"][:]

# Report the shapes before any reshaping is applied.
print(f"Before x_train size:{x_train.shape}")
print(f"Before x_test size:{x_test.shape}")

"""出力
Before x_train size:(10000, 4096)
Before x_test size:(2000, 4096)
"""

データ変換

3DのCNNをしたいので3次元に直します。x_trainの中身は1行に4096(16x16x16)要素入っているので、16x16x16の3次元に変換します。
3次元の配列がちょっとイメージしにくかったのですが、@ken_yoshiさんの下記の記事で改めてイメージできました。

def prepare_points(tensor, data_config, threshold=False):
    """Optionally binarize flattened voxel rows, then reshape them into 3D grids.

    Args:
        tensor: NumPy array whose first axis indexes samples; each sample must
            contain y_shape * x_shape * z_shape values.
        data_config: Mapping providing 'y_shape', 'x_shape' and 'z_shape', plus
            'threshold', 'upper' and 'lower' when ``threshold`` is truthy.
        threshold: When truthy, values strictly greater than
            data_config['threshold'] become data_config['upper'] and all
            other values become data_config['lower'].

    Returns:
        Array of shape (n_samples, y_shape, x_shape, z_shape).
    """
    if threshold:
        # Values above the threshold map to 'upper'; the original had the two
        # branches swapped, which inverted the binarized volume.
        tensor = np.where(
            tensor > data_config['threshold'],
            data_config['upper'],
            data_config['lower']
        )
    return tensor.reshape((
        tensor.shape[0],
        data_config['y_shape'],  # = 16
        data_config['x_shape'],  # = 16
        data_config['z_shape'],  # = 16
    ))
    
# Reshape both splits into 3D grids; threshold keeps its default (False), so
# voxel values are left unchanged.
x_train, x_test = (prepare_points(split, data_config) for split in (x_train, x_test))
print(f"After x_train size:{x_train.shape}")
print(f"After x_test size:{x_test.shape}")

"""出力
After x_train size:(10000, 16, 16, 16)
After x_test size:(2000, 16, 16, 16)
"""

可視化

3次元に変換したのでいくつか見てみます。

# Show three randomly chosen training volumes as 3D scatter plots.
for index in range(3):
    # randrange excludes the upper bound; random.randint(0, N) could return N
    # itself and index one past the end of x_train.
    plot_idx = random.randrange(x_train.shape[0])
    plot_img_3d = x_train[plot_idx]
    plot_label = y_train[plot_idx]

    # Scale voxel values to integers (0-255 for inputs in [0, 1]) and keep only
    # voxels with a positive value. np.nonzero returns coordinates in the same
    # row-major order as nested x/y/z loops.
    intensities = (plot_img_3d * 255).astype(int)
    xs, ys, zs = np.nonzero(intensities > 0)
    plot_df = pd.DataFrame({"x": xs, "y": ys, "z": zs, "val": intensities[xs, ys, zs]})

    fig = go.Figure(data=[go.Scatter3d(x=plot_df['x'], y=plot_df['y'], z=plot_df['z'],
                                       mode='markers',
                                       # index of the first non-zero entry of the label vector
                                       text=f"current label: {plot_label.nonzero()[0][0]}",
                                       marker=dict(
                                           # grey level taken from the voxel intensity
                                           color=[f'rgb({c}, {c}, {c})' for c in plot_df['val']],
                                           size=6,
                                           colorscale='Plotly3',
                                           opacity=0.8))])
    fig.show()

スクリーンショット 2024-05-08 102841.png

モデル

畳み込み3層+全結合層のモデルを作ります。扱うデータが3次元になるだけで、実装の流れは2次元のCNNと同じです(Conv2dなどをConv3dなどに置き換えます)。

class CNN3D(nn.Module):
    """Three-stage 3D CNN followed by a single linear layer with 10 outputs.

    Each stage applies Conv3d (3x3x3, 'same' padding), BatchNorm3d,
    2x2x2 max pooling and ReLU, in that order.

    Args:
        in_channels: Number of channels of the input volume.
    """

    def __init__(self, in_channels):
        super().__init__()
        kernel = (3, 3, 3)
        # Parameter-free layers shared by all three stages.
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool3d((2, 2, 2))
        # Stage 1: in_channels -> 16
        self.conv1 = nn.Conv3d(in_channels, 16, kernel_size=kernel, padding='same')
        self.bn1 = nn.BatchNorm3d(16)
        # Stage 2: 16 -> 32
        self.conv2 = nn.Conv3d(16, 32, kernel_size=kernel, padding='same')
        self.bn2 = nn.BatchNorm3d(32)
        # Stage 3: 32 -> 16
        self.conv3 = nn.Conv3d(32, 16, kernel_size=kernel, padding='same')
        self.bn3 = nn.BatchNorm3d(16)
        self.flat = nn.Flatten()
        # 128 input features = 16 channels * 2*2*2 voxels, i.e. a 16x16x16
        # input halved by each of the three poolings.
        self.ln = nn.Linear(128, 10)

    def forward(self, x):
        """Run the three conv stages, flatten, and return the linear layer output."""
        stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
        )
        for conv, norm in stages:
            x = self.relu(self.maxpool(norm(conv(x))))
        return self.ln(self.flat(x))

学習

weight_decayはL2正則化(重み減衰)の強さを決める係数で、過学習を抑制する効果があります。
学習部分のコードも2次元の場合と変わりません。

# Build the model on the same device the batches are moved to; without this
# the forward pass fails whenever `device` is a GPU.
model = CNN3D(1).to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=train_config['lr'], weight_decay=train_config['weight_decay'])
history = {'train_loss': []}
# Log roughly five times per epoch; max(1, ...) avoids a modulo-by-zero when
# there are fewer than five batches.
log_interval = max(1, len(x_train) // train_config['batch_size'] // 5)
for epoch in range(train_config['epochs']):
    # Running sums over the current logging window. The batch counter is reset
    # here as well, so batches left over from the previous epoch do not dilute
    # the first average of this one.
    train_loss = 0
    acc = 0
    n = 0

    model.train()

    for i, data in enumerate(train_loader):
        inputs, labels = data['train'].to(device), data['target'].to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

        # Targets are one-hot vectors, so argmax recovers the class index.
        out_label = torch.argmax(outputs, dim=1)
        label = torch.argmax(labels, dim=1)
        acc += (out_label == label).sum().item() / (len(labels))
        n += 1
        if i % log_interval == log_interval - 1:
            print(f"epoch:{epoch+1}  index:{i+1}  train_loss:{train_loss/n:.3f}  ACC:{acc/n:.3f}")
            n = 0
            train_loss = 0
            acc = 0

    model.eval()
    # No gradients are needed for evaluation.
    with torch.no_grad():
        for data_ in test_loader:
            inputs_, labels_ = data_['train'].to(device), data_['target'].to(device)
            output = model(inputs_)
            loss_ = loss_fn(output, labels_)
            acc_ = (torch.argmax(output, dim=1) == torch.argmax(labels_, dim=1)).sum().item() / (len(labels_))
            # CrossEntropyLoss already averages over the batch, so the loss is
            # reported as is (it used to be divided by the size of the last
            # training batch).
            print(f"EVAL_LOSS:{loss_.item():.3f} EVAL_ACC:{acc_:.3f}")

"""出力
epoch:1  index:62  train_loss:2.109  ACC:0.274
epoch:1  index:124  train_loss:1.633  ACC:0.474
epoch:1  index:186  train_loss:1.367  ACC:0.543
epoch:1  index:248  train_loss:1.255  ACC:0.582
epoch:1  index:310  train_loss:1.148  ACC:0.618
EVAL_LOSS:0.071 EVAL_ACC:0.622
epoch:2  index:62  train_loss:1.070  ACC:0.592
epoch:2  index:124  train_loss:1.058  ACC:0.648
epoch:2  index:186  train_loss:0.988  ACC:0.668
epoch:2  index:248  train_loss:1.048  ACC:0.640
epoch:2  index:310  train_loss:1.031  ACC:0.645
EVAL_LOSS:0.064 EVAL_ACC:0.627
epoch:3  index:62  train_loss:0.892  ACC:0.652
epoch:3  index:124  train_loss:0.975  ACC:0.662
epoch:3  index:186  train_loss:0.974  ACC:0.667
epoch:3  index:248  train_loss:0.858  ACC:0.702
epoch:3  index:310  train_loss:0.892  ACC:0.697
EVAL_LOSS:0.063 EVAL_ACC:0.640
epoch:4  index:62  train_loss:0.748  ACC:0.692
epoch:4  index:124  train_loss:0.845  ACC:0.710
epoch:4  index:186  train_loss:0.786  ACC:0.737
epoch:4  index:248  train_loss:0.850  ACC:0.715
epoch:4  index:310  train_loss:0.837  ACC:0.722
EVAL_LOSS:0.055 EVAL_ACC:0.695
epoch:5  index:62  train_loss:0.689  ACC:0.732
epoch:5  index:124  train_loss:0.709  ACC:0.764
epoch:5  index:186  train_loss:0.757  ACC:0.731
epoch:5  index:248  train_loss:0.708  ACC:0.754
epoch:5  index:310  train_loss:0.735  ACC:0.746
EVAL_LOSS:0.057 EVAL_ACC:0.686
epoch:6  index:62  train_loss:0.588  ACC:0.760
epoch:6  index:124  train_loss:0.627  ACC:0.790
epoch:6  index:186  train_loss:0.675  ACC:0.777
epoch:6  index:248  train_loss:0.662  ACC:0.768
epoch:6  index:310  train_loss:0.650  ACC:0.772
EVAL_LOSS:0.066 EVAL_ACC:0.648
epoch:7  index:62  train_loss:0.532  ACC:0.785
epoch:7  index:124  train_loss:0.521  ACC:0.835
epoch:7  index:186  train_loss:0.591  ACC:0.803
epoch:7  index:248  train_loss:0.586  ACC:0.800
epoch:7  index:310  train_loss:0.612  ACC:0.796
EVAL_LOSS:0.058 EVAL_ACC:0.690
epoch:8  index:62  train_loss:0.449  ACC:0.815
epoch:8  index:124  train_loss:0.463  ACC:0.848
epoch:8  index:186  train_loss:0.517  ACC:0.827
epoch:8  index:248  train_loss:0.552  ACC:0.815
epoch:8  index:310  train_loss:0.552  ACC:0.809
EVAL_LOSS:0.053 EVAL_ACC:0.717
epoch:9  index:62  train_loss:0.407  ACC:0.828
epoch:9  index:124  train_loss:0.424  ACC:0.856
epoch:9  index:186  train_loss:0.435  ACC:0.860
epoch:9  index:248  train_loss:0.452  ACC:0.846
epoch:9  index:310  train_loss:0.511  ACC:0.830
EVAL_LOSS:0.077 EVAL_ACC:0.625
epoch:10  index:62  train_loss:0.368  ACC:0.842
epoch:10  index:124  train_loss:0.364  ACC:0.888
epoch:10  index:186  train_loss:0.384  ACC:0.873
epoch:10  index:248  train_loss:0.416  ACC:0.862
epoch:10  index:310  train_loss:0.408  ACC:0.865
EVAL_LOSS:0.058 EVAL_ACC:0.704
"""

終わりに

GNNとかも実装してみようと思います。

参考記事

Kaggle 3D MNIST https://www.kaggle.com/datasets/daavoo/3d-mnist/data

全コード

import numpy as np
import polars as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from sklearn.metrics import r2_score
from torchvision import models, transforms
import h5py
import matplotlib.pyplot as plt
import seaborn as sns
import plotly
import sklearn
import pandas as pd
import random
import tqdm
import plotly.express as px
import plotly.graph_objects as go
from dataclasses import dataclass
import os
from sklearn.preprocessing import OneHotEncoder 

# Preprocessing settings: binarization threshold with its two output values,
# and the dimensions of the voxel grid each sample is reshaped to.
data_config = dict(
    threshold=0.2,
    upper=1,
    lower=0,
    x_shape=16,
    y_shape=16,
    z_shape=16,
)

# Training hyperparameters passed to the optimizer, the data loader and the
# epoch loop.
train_config = dict(
    lr=0.001,
    epochs=10,
    batch_size=32,
    weight_decay=0.001,
)

# Read the flattened voxel vectors and their labels from the Kaggle 3D MNIST HDF5 file.
with h5py.File("../input/3d-mnist/full_dataset_vectors.h5", "r") as dataset_file:
    # Slicing with [:] copies each dataset into memory, so the arrays remain
    # usable after the file is closed.
    x_train, y_train = dataset_file["X_train"][:], dataset_file["y_train"][:]
    x_test, y_test = dataset_file["X_test"][:], dataset_file["y_test"][:]

# Report the shapes before any reshaping is applied.
print(f"Before x_train size:{x_train.shape}")
print(f"Before x_test size:{x_test.shape}")

def prepare_points(tensor, data_config, threshold=False):
    """Optionally binarize flattened voxel rows, then reshape them into 3D grids.

    Args:
        tensor: NumPy array whose first axis indexes samples; each sample must
            contain y_shape * x_shape * z_shape values.
        data_config: Mapping providing 'y_shape', 'x_shape' and 'z_shape', plus
            'threshold', 'upper' and 'lower' when ``threshold`` is truthy.
        threshold: When truthy, values strictly greater than
            data_config['threshold'] become data_config['upper'] and all
            other values become data_config['lower'].

    Returns:
        Array of shape (n_samples, y_shape, x_shape, z_shape).
    """
    if threshold:
        # Values above the threshold map to 'upper'; the original had the two
        # branches swapped, which inverted the binarized volume.
        tensor = np.where(
            tensor > data_config['threshold'],
            data_config['upper'],
            data_config['lower']
        )
    return tensor.reshape((
        tensor.shape[0],
        data_config['y_shape'],
        data_config['x_shape'],
        data_config['z_shape'],
    ))

# Reshape the flattened samples into 3D grids (threshold keeps its default of
# False, so values are not binarized) and report shapes/values before and after.
print(f"Before x_train size:{x_train.shape}")
print(f"Before x_test size:{x_test.shape}")
print(f"Before Unique value:{np.unique(x_train)}")
x_train = prepare_points(x_train, data_config)
x_test = prepare_points(x_test, data_config)
print(f"After x_train size:{x_train.shape}")
print(f"After x_test size:{x_test.shape}")
print(f"After Unique value:{np.unique(x_train)}")

# One-hot encode the labels. The encoder is fitted on the training labels only
# and then reused for the test labels, so both splits share the same column
# layout (re-fitting on the test split could yield different columns if a
# class were missing there).
encoder = OneHotEncoder(sparse_output=False)
y_train = encoder.fit_transform(y_train.reshape((y_train.shape[0], 1)))
y_test = encoder.transform(y_test.reshape((y_test.shape[0], 1)))
print(f"y_train size:{y_train.shape}")
print(f"y_test size:{y_test.shape}")

# Show three randomly chosen training volumes as 3D scatter plots.
for index in range(3):
    # randrange excludes the upper bound; random.randint(0, N) could return N
    # itself and index one past the end of x_train.
    plot_idx = random.randrange(x_train.shape[0])
    plot_img_3d = x_train[plot_idx]
    plot_label = y_train[plot_idx]

    # Scale voxel values to integers (0-255 for inputs in [0, 1]) and keep only
    # voxels with a positive value. np.nonzero returns coordinates in the same
    # row-major order as nested x/y/z loops.
    intensities = (plot_img_3d * 255).astype(int)
    xs, ys, zs = np.nonzero(intensities > 0)
    plot_df = pd.DataFrame({"x": xs, "y": ys, "z": zs, "val": intensities[xs, ys, zs]})

    fig = go.Figure(data=[go.Scatter3d(x=plot_df['x'], y=plot_df['y'], z=plot_df['z'],
                                       mode='markers',
                                       # index of the first non-zero entry of the one-hot label
                                       text=f"current label: {plot_label.nonzero()[0][0]}",
                                       marker=dict(
                                           # grey level taken from the voxel intensity
                                           color=[f'rgb({c}, {c}, {c})' for c in plot_df['val']],
                                           size=6,
                                           colorscale='Plotly3',
                                           opacity=0.8))])
    fig.show()

class CNN3D(nn.Module):
    """Three-stage 3D CNN followed by a single linear layer with 10 outputs.

    Each stage applies Conv3d (3x3x3, 'same' padding), BatchNorm3d,
    2x2x2 max pooling and ReLU, in that order.

    Args:
        in_channels: Number of channels of the input volume.
    """

    def __init__(self, in_channels):
        super().__init__()
        kernel = (3, 3, 3)
        # Parameter-free layers shared by all three stages.
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool3d((2, 2, 2))
        # Stage 1: in_channels -> 16
        self.conv1 = nn.Conv3d(in_channels, 16, kernel_size=kernel, padding='same')
        self.bn1 = nn.BatchNorm3d(16)
        # Stage 2: 16 -> 32
        self.conv2 = nn.Conv3d(16, 32, kernel_size=kernel, padding='same')
        self.bn2 = nn.BatchNorm3d(32)
        # Stage 3: 32 -> 16
        self.conv3 = nn.Conv3d(32, 16, kernel_size=kernel, padding='same')
        self.bn3 = nn.BatchNorm3d(16)
        self.flat = nn.Flatten()
        # 128 input features = 16 channels * 2*2*2 voxels, i.e. a 16x16x16
        # input halved by each of the three poolings.
        self.ln = nn.Linear(128, 10)

    def forward(self, x):
        """Run the three conv stages, flatten, and return the linear layer output."""
        stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
        )
        for conv, norm in stages:
            x = self.relu(self.maxpool(norm(conv(x))))
        return self.ln(self.flat(x))

class CustomDataset(Dataset):
    """Dataset pairing each sample array with its target array.

    Items are dicts with two float32 tensors: 'train' (the sample with a new
    leading axis of size 1 added) and 'target'.
    """

    def __init__(self, train, target):
        self.train = train
        self.target = target

    def __len__(self):
        return len(self.train)

    def __getitem__(self, index):
        # Add a leading axis to the sample and cast both arrays to float32
        # before wrapping them as tensors.
        sample = np.expand_dims(self.train[index], 0).astype(np.float32)
        target = self.target[index].astype(np.float32)
        return {
            'train': torch.from_numpy(sample),
            'target': torch.from_numpy(target),
        }

# Wrap the training arrays and batch them in shuffled mini-batches.
train_dataset = CustomDataset(x_train, y_train)
# `dataset` was never defined; inspect the dataset that was just built.
print('TRAIN Dataset size:', train_dataset[0]['train'].shape)
train_loader = DataLoader(train_dataset, batch_size=train_config['batch_size'], shuffle=True)

# The whole test split is served as a single batch, so shuffling is pointless.
test_dataset = CustomDataset(x_test, y_test)
print('TEST Dataset size:', test_dataset[0]['train'].shape)
test_loader = DataLoader(test_dataset, batch_size=len(x_test), shuffle=False)

# Use the first CUDA device when available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(f"Now, using {device}")

# Build the model on the same device the batches are moved to; without this
# the forward pass fails whenever `device` is a GPU.
model = CNN3D(1).to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=train_config['lr'], weight_decay=train_config['weight_decay'])
history = {'train_loss': []}
# Log roughly five times per epoch; max(1, ...) avoids a modulo-by-zero when
# there are fewer than five batches.
log_interval = max(1, len(x_train) // train_config['batch_size'] // 5)
for epoch in range(train_config['epochs']):
    # Running sums over the current logging window. The batch counter is reset
    # here as well, so batches left over from the previous epoch do not dilute
    # the first average of this one.
    train_loss = 0
    acc = 0
    n = 0

    model.train()

    for i, data in enumerate(train_loader):
        inputs, labels = data['train'].to(device), data['target'].to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

        # Targets are one-hot vectors, so argmax recovers the class index.
        out_label = torch.argmax(outputs, dim=1)
        label = torch.argmax(labels, dim=1)
        acc += (out_label == label).sum().item() / (len(labels))
        n += 1
        if i % log_interval == log_interval - 1:
            print(f"epoch:{epoch+1}  index:{i+1}  train_loss:{train_loss/n:.3f}  ACC:{acc/n:.3f}")
            n = 0
            train_loss = 0
            acc = 0

    model.eval()
    # No gradients are needed for evaluation.
    with torch.no_grad():
        for data_ in test_loader:
            inputs_, labels_ = data_['train'].to(device), data_['target'].to(device)
            output = model(inputs_)
            loss_ = loss_fn(output, labels_)
            acc_ = (torch.argmax(output, dim=1) == torch.argmax(labels_, dim=1)).sum().item() / (len(labels_))
            # CrossEntropyLoss already averages over the batch, so the loss is
            # reported as is (it used to be divided by the size of the last
            # training batch).
            print(f"EVAL_LOSS:{loss_.item():.3f} EVAL_ACC:{acc_:.3f}")

1
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?