手作りデータセットで学習させてみる
前回は、PCA を使って
• 離散データ(顔画像)が
• 統計的な「連続空間」として扱えるようになる
という話をしてきました。
ここではもう少し踏み込んで、実際にミニモデルを学習させて挙動を眺めてみようと思います。
テーマはとても単純で:
• 画像の中にある「四角」が、3×3 に区切ったどのマスにいるかを当てるクイズ
これを CNN と Transformer を用いて解けるように学習させようと思います。モデルに 視認 させようって試みです。
0.やることのざっくりイメージ
• 32×32 の小さな白黒画像を自前で生成する
• 画像の中に 1 個だけ白い四角を描く
• 画像を 3×3 に分割して「どのマスに四角がいるか?」を 9クラス分類
• CNN / Transformer をそれぞれ学習させて挙動を見る
CNN と Transformerでは得意分野が違うのでどういう結果になるでしょうか・・・
1. マスあてクイズのデータセットを自前で作る
まずは PyTorch の Dataset クラスとして、「3×3 マスのどこかに四角がいる画像」を生成するクラス を用意します。画像ファイルを読み込むようなデータセットではなくて、その場で画像と同等のテンソルを作る ”データセット” です。
import numpy as np
import torch
from torch.utils.data import Dataset
import random
class SquareDataset(Dataset):
def __init__(self, n_samples=3000, img_size=32, square_size=6, invert_prob=0.5):
self.n_samples = n_samples
self.img_size = img_size
self.square_size = square_size
self.invert_prob = invert_prob # 白黒反転させる確率(頑丈さアップ用)
self.data = []
self.labels = []
self._generate()
def _generate(self):
H = W = self.img_size
region_size = H // 3 # 32 → 10(端は少し余るけど気にしない)
for _ in range(self.n_samples):
img = np.zeros((H, W), dtype=np.float32) + 0.1
label = random.randint(0, 8)
gy = label // 3 # 0,1,2 (row)
gx = label % 3 # 0,1,2 (col)
y_min = gy * region_size
y_max = (gy + 1) * region_size - self.square_size - 1
x_min = gx * region_size
x_max = (gx + 1) * region_size - self.square_size - 1
y_min = max(0, y_min)
x_min = max(0, x_min)
y_max = max(y_min, min(H - self.square_size - 1, y_max))
x_max = max(x_min, min(W - self.square_size - 1, x_max))
y0 = random.randint(y_min, y_max)
x0 = random.randint(x_min, x_max)
img[y0:y0 + self.square_size, x0:x0 + self.square_size] = 1.0
if random.random() < self.invert_prob:
img = 1.0 - img
self.data.append(img)
self.labels.append(label)
self.data = np.stack(self.data, axis=0) # (N, H, W)
self.labels = np.array(self.labels, dtype=np.int64) # (N,)
def __len__(self):
return self.n_samples
def __getitem__(self, idx):
img = self.data[idx] # (H, W)
img = img[None, :, :] # (1, H, W) チャンネル次元を追加
x = torch.from_numpy(img) # float32 tensor
y = torch.tensor(self.labels[idx]) # int64
return x, y
これでいくつかのxを視覚化できるようにしてみると、
こんな感じで32x32の画像とみなせるテンソルで、その中に6x6の何かしろ白い四角を作って、白黒(or黒白)塗りになるようにするってのをランダムで作るようになってます。これがxです。label となってるのがyで[0,8]で出力されます。左上から右へいって折り返して中央左に来てって順番で全体を3x3のグリッドに区切り、順に0〜8としています。r はrow(行)、 c はcoloumn(列)です。実際には赤い点線のグリッドはxには含まれてません。これはあくまでpresentation用です。
これはDatasetクラスなので、これのインスタンスを作ってイテレーター的に呼び出すと __getitem__ 関数が毎回呼び出されることになって、xとyが返却されます。
今このデータセットクラスは
• 32x32 の画像の中に 1 個だけ白い四角を描く。
• 画像を 3x3 で区切ったとき、どのマスにいるか (0〜8) をラベルにする。
こんな感じです。
これで、例えば
train_ds = SquareDataset(n_samples=5000, img_size=32, square_size=6)
のようにして、9クラス分類用のデータセットが手に入ります。
2. CNN と Transformer を用意する
2-1. すごく小さい CNN
まずはごく普通の 2層 CNN を用意します。
import torch.nn as nn
import torch.nn.functional as F
class TinyCNN(nn.Module):
def __init__(self, num_classes=9):
super().__init__()
self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
self.pool = nn.MaxPool2d(2, 2) # 32 → 16 → 8
self.fc1 = nn.Linear(32 * 8 * 8, 64)
self.fc2 = nn.Linear(64, num_classes)
def forward(self, x):
x = F.relu(self.conv1(x)) # (B,16,32,32)
x = self.pool(x) # (B,16,16,16)
x = F.relu(self.conv2(x)) # (B,32,16,16)
x = self.pool(x) # (B,32,8,8)
x = x.view(x.size(0), -1) # (B, 32*8*8)
x = F.relu(self.fc1(x))
x = self.fc2(x) # (B,9)
return x
これは、入力として(b,1,32,32) を期待してます。それをまずはself.conv1が受け止めてます。bはバッチチャンネルと呼ばれるSquareDatasetが出力するxをいくつかひとまとめに扱うための次元です。MLでは実際の自分たちの扱うものをいくつかまとめて扱うみたいなことをこのb次元でやってます。
このクラスをインスタンス化するとforward部分が__call__のように扱われて、たとえば
model = TinyCNN()
z = model(x)
みたいにすると(b,9)のテンソルが返却されます。この時点ではそれ以上のものでも以下のものでもまだないです。nn.Conv2dとnn.Linearは学習可能パラメーターと呼ばれるものすごく雑に言ってある特定の値を持ってる行列で、それに伝達される 勾配 に反応してその行列の構成している値(つまりパラメーター)を更新していきます。
(B,1,32,32)
→ (B,16,32,32)
→ (B,16,16,16)
→ (B,32,16,16)
→ (B,32,8,8)
→ (B, 32*8*8)
→ (B,9)
という流れになっていて、32x32のグレースケールのテンソルが最終的には出力範囲(-∞,∞)の値を9個得ることになります。この時点では(1,32,32) -> (9,)とするだけのもです。
*勾配に関しては機会あればまたどこかで詳しくみたいと思ってます。
2-2. すごく小さい Transformer
今回は本格的な ViT ではなく、パッチを切って TransformerEncoder に通すだけのミニ版です。
class TinyViT(nn.Module):
def __init__(self, img_size=32, patch_size=4,
dim=64, depth=2, num_heads=4, num_classes=9):
super().__init__()
self.img_size = img_size
self.patch_size = patch_size
num_patches = (img_size // patch_size) ** 2 # 32/4=8 → 8*8=64
self.num_patches = num_patches
# patched token 4x4
self.patch_embed = nn.Conv2d(
in_channels=1,
out_channels=dim,
kernel_size=patch_size,
stride=patch_size,
) # (B,1,32,32) → (B,dim,8,8)
self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, dim))
encoder_layer = nn.TransformerEncoderLayer(
d_model=dim,
nhead=num_heads,
dim_feedforward=dim * 4,
batch_first=True,
)
self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=depth)
self.head = nn.Linear(dim, num_classes)
self._init_weights()
def _init_weights(self):
nn.init.trunc_normal_(self.pos_embed, std=0.02)
nn.init.trunc_normal_(self.cls_token, std=0.02)
def forward(self, x):
# x: (B,1,H,W)
x = self.patch_embed(x) # (B,dim,8,8)
B, C, H, W = x.shape
x = x.flatten(2).transpose(1, 2) # (B, N=64, dim=C)
cls = self.cls_token.expand(B, -1, -1) # (B,1,dim)
x = torch.cat([cls, x], dim=1) # (B,65,dim)
x = x + self.pos_embed[:, : x.size(1)] # pos embed
x = self.encoder(x) # (B,65,dim)
cls_out = x[:, 0] # CLS token
logits = self.head(cls_out) # (B,num_classes)
return logits
ここでも同じく、入力として(b,1,32,32) を期待してます。Transformerではある程度意味のある塊をトークンとして扱います。なのでこの入力を一旦ざっくりと4x4のパッチにして(この時点でそのパッチが縦横8個ずつあるグリッドタイルのような感じになる)、それを一旦64チャンネルに拡張してます。この作業をここではself.patch_embedが担っていて、これの正体はConv2dです。TinyCNNの方でも出てきた学習可能パラメーターを持ってるモジュールです。学習可能パラメーターは、nn.Conv2d, nn.TransformerEncoder, nn.Linear, nn.Parameterに含まれてます。実際には、nn.TransformerEncoderはいくつかのnn.Linearとnn.LayerNormを組み合わせた機能的モジュールみたいな感じです。Transformerはそこそこやってることが複雑なのでここでは詳しい説明を割愛しますが、機会あればこれもまたどこかで説明したと思います。
テンソルは、
(B,1,32,32)
→ (B,64,8,8)
→ (B,64,64)
→ (B,65,64)
→ (B,65,64)
→ (B,64)
→ (B,9)
という流れになってます。最終的な返却値の出力範囲はこちらも同じく(-∞,∞)となります。
また、ここではパッチトークンとクラストークンというあとから追加したもう一つのトークンがそれぞれ特徴チャンネルを持ってますが、この全てのトークン数x特徴チャンネル(つまり構成要素数)分の適当な値としてnn.Parameterで学習可能な値を付与しています。Transformerでいうところのpositional embedなのですが、今回はタスクがタスクなので空間位置情報を特には気にしてません。画像生成などの Dense Prediction と呼ばれるタスクにおいてはそういうものが必要になってきます。
ここまで揃えばあとは学習です。(次回につづく
