続・手作りデータセットで学習させてみる
前回はデータセットとモデルを作成したので、今回は学習を行い評価をしようかと思ってます。
3. 学習ループ
学習はデータセットから取り出したデータをモデルにfeedして返却値を得ます。今回の場合、モデルは9個の(-∞,∞)の値を返却します。これを 確率 として使いたいので実際には[0,1]にして返却すればいいのですが、ここにはMLの学習のトリックが存在するため、未加工の値 logits で返却させています。
1. モデルの出力(logits)は確率ではない
モデルは (B, 9) の生スコアを返します。(範囲は -∞〜+∞)
この時点では何も “確率” ではありません。
2. CrossEntropyLoss は内部で softmax をかけて確率(っぽいもの)にする
例えばこんな感じ、
• logits = [ -3, 1, 5 ] # これがモデルからの出力値
• softmax = [ 0.002, 0.018, 0.980 ] # 0〜1 に変換
※softmax の数値は説明用に丸めています。
softmaxは入力の配列の要素数に応じて合計1になるように返します。なので出力が確率っぽくなります。あえて重ねて補足ですが、softmax はあくまで “確率分布の形をした値” を返すだけで、
確率そのものではありません。(統計学的な意味の確率ではない)
3. 正解ラベル y は「正解クラスの番号」だけを持つ
9クラスなら 0〜8 の整数。
• y = 2
これだけで OK。
4. loss は “正解クラスの確率” p だけを取り出して作る
softmax の中から、y が示す位置の確率を抜き出す:
モデルからの出力は9個の値の配列だけど正解ラベルは一つってなってて、正解ラベルは整数値なのでindexとして用いて、スライシングで取り出すって感じ。
• p = softmax[y]
もし y=2 なら:
• p = softmax[2] # 0.980
5. CrossEntropyLoss は -log(p) を loss として返す
損失は通常0に近づけたくて、softmaxは[0,1]で正解なら1。これが表してるのは直感的には自信度っぽいもの(softmaxが返すものがこれだとなんとなく 1 - p をすれば良いって感じるんですが、値を強調したいので、
• loss = -log(p)
を使います。xが[0,1]の時のlog(x)の出力範囲は (-inf, 0] です。
p が大きい → loss は小さい(ご褒美)
p が小さい → loss は大きい(強いペナルティ)
これを踏まえて以下が学習コードです。
from torch.utils.data import DataLoader
import torch.optim as optim
import torch
def train_one_model(model, train_ds, test_ds,
epochs=5, batch_size=64, lr=1e-3, device=None, name="model"):
if device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_ds, batch_size=batch_size)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)
for ep in range(1, epochs + 1):
model.train()
total_loss = 0.0
for x, y in train_loader:
x, y = x.to(device), y.to(device)
optimizer.zero_grad()
logits = model(x)
loss = criterion(logits, y)
loss.backward()
optimizer.step()
total_loss += loss.item() * x.size(0)
avg_loss = total_loss / len(train_loader.dataset)
model.eval()
correct = 0
total = 0
with torch.no_grad():
for x, y in test_loader:
x, y = x.to(device), y.to(device)
logits = model(x)
pred = logits.argmax(dim=1)
correct += (pred == y).sum().item()
total += y.size(0)
acc = correct / total if total > 0 else 0.0
print(f"[{name}] epoch {ep}/{epochs} - loss: {avg_loss:.4f}, acc: {acc:.3f}")
4. 実際に CNN と ViT を走らせてみる
最後に、さっきの前回のデータセットとモデルとを組み合わせて実行してみます。
def main():
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device:", device)
# dataset
train_ds = SquareDataset(n_samples=5000, img_size=32)
test_ds = SquareDataset(n_samples=1000, img_size=32)
# Tiny CNN
cnn = TinyCNN(num_classes=9)
train_one_model(
model=cnn,
train_ds=train_ds,
test_ds=test_ds,
epochs=5,
batch_size=64,
lr=1e-3,
device=device,
name="TinyCNN"
)
# Tiny ViT
vit = TinyViT(img_size=32, patch_size=4, dim=64, depth=2, num_heads=4, num_classes=9)
train_one_model(
model=vit,
train_ds=train_ds,
test_ds=test_ds,
epochs=5,
batch_size=64,
lr=1e-3,
device=device,
name="TinyViT"
)
if __name__ == "__main__":
main()
これを実際に走らせてみると・・・(学習時間は10秒もかかりません)
[TinyCNN] epoch 1/5 - loss: 0.7125, acc: 1.000
[TinyCNN] epoch 2/5 - loss: 0.0012, acc: 1.000
[TinyCNN] epoch 3/5 - loss: 0.0005, acc: 1.000
[TinyCNN] epoch 4/5 - loss: 0.0003, acc: 1.000
[TinyCNN] epoch 5/5 - loss: 0.0002, acc: 1.000
[TinyViT] epoch 1/5 - loss: 2.2280, acc: 0.117
[TinyViT] epoch 2/5 - loss: 2.2107, acc: 0.105
[TinyViT] epoch 3/5 - loss: 2.2124, acc: 0.097
[TinyViT] epoch 4/5 - loss: 2.2057, acc: 0.108
[TinyViT] epoch 5/5 - loss: 2.2014, acc: 0.202
引くほどぶっちぎりでCNNの圧勝!!
5. 考察
CNNの圧勝理由:帰納バイアスの威力
->これ言っちゃうと the「元も子もない」なのですが、帰納バイアスつまりモデルの成り立ちだとか、モデルに組み込まれた特性や前提を指してますが、つまり、
• CNNはそのために作られてる
• 局所特徴に強い
• パターン(局所的な形)を効率よく捉えられる
• 分類タスク向き
があげられます。
- 局所性の仮定がドンピシャ
- 四角形は「局所的なパターン」(エッジの集まり)
- CNNの畳み込みは局所特徴の検出に最適化されている
- 3x3カーネルで四角のエッジを即座に捉えられる - 位置不変性が有利
- 「四角がどこにあっても四角」という知識がCNNに組み込まれている
- 同じフィルタを画像全体にスライドさせるので、位置に関わらず検出可能 - パラメータ効率
- CNNは少ないパラメータで局所パターンを学習
- 1エポックで必要な特徴を獲得済み
ViTが苦戦する理由:ゼロからの学習
- 事前知識ゼロ
- 「隣のピクセルは関連性が高い」すら知らない
- 全パッチ間の関係をAttentionでゼロから学習する必要がある - パッチ単位の処理
- 4x4パッチに分割 → 6x6の四角が2〜4パッチにまたがる
→ ViT は四角の“境界”そのものを見れなくなる
→ つまり“四角”の認識に必要な局所エッジが壊れる
- 「このパッチとあのパッチが同じ物体」を学習するのに時間がかかる - データ量不足
- ViTは大規模データで事前学習してから本領発揮
- 5000サンプルは少なすぎる(ImageNet-21kは1400万枚)
結論
• ViTは「万能だが素朴」、CNNは「専門特化で効率的」
小規模・単純タスクでは、CNNの帰納バイアス(局所性・位置不変性)が圧倒的に有利。ViTは大規模データと複雑なタスクで真価を発揮する。と言ったところでしょうか。