0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

【PyTorch】残差ブロックの仕組み

Posted at

この記事では、残差ブロックとは何か、なぜ必要なのか、そしてPyTorchでどのように実装するかをコード例を挙げて解説します。

残差ブロックとは?

残差ブロックは、主に「スキップ接続」と呼ばれる特別な接続方法を持つニューラルネットワークの構成要素です。

簡単に言うと、入力をそのまま出力に足すという仕組みです。これだけだと「何が特別なの?」と思うかもしれませんが、この単純な仕組みがディープラーニングモデルの性能を劇的に改善します。

残差ブロックの基本構造:入力が処理パスを通過すると同時に、スキップ接続を通じて出力に直接足し合わされる

なぜ残差ブロックが必要?

ディープラーニングでは、ネットワークを深くすればするほど表現力が高まると考えられていました。しかし、単純にレイヤーを重ねるだけでは、勾配消失問題(Vanishing Gradient Problem) が発生し、うまく学習できなくなってしまいます。

残差ブロックを使用すると、入力から出力への「近道(ショートカット)」ができるため、勾配がこの近道を通って流れることができます。これにより勾配消失問題が緩和され、非常に深いネットワークでも効率的に学習できるようになります。

PyTorchでの実装例

それでは、具体的なPyTorchの実装例を見てみましょう。

import torch
import torch.nn as nn
import torch.nn.functional as F

class ResidualBlock(nn.Module):
    """残差ブロック"""
    
    def __init__(self, hidden_dim: int, dropout: float = 0.1):
        """
        初期化
        
        Args:
            hidden_dim: 隠れ層の次元
            dropout: ドロップアウト率
        """
        super().__init__()
        
        self.layer1 = nn.Linear(hidden_dim, hidden_dim)
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.layer2 = nn.Linear(hidden_dim, hidden_dim)
        self.bn2 = nn.BatchNorm1d(hidden_dim)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        順伝播
        
        Args:
            x: 入力テンソル (batch_size, hidden_dim)
            
        Returns:
            torch.Tensor: 出力テンソル (batch_size, hidden_dim)
        """
        # スキップ接続のために入力を保存
        residual = x
        
        # バッチサイズが1かつ評価モードの場合の特別処理
        batch_size = x.size(0)
        is_single_batch = batch_size == 1 and not self.training
        
        if is_single_batch:
            # バッチを一時的に複製して、BatchNormが動作するようにする
            x_batch = torch.cat([x, x], dim=0)
            
            out = self.layer1(x_batch)
            out = self.bn1(out)
            out = F.relu(out)
            out = self.dropout(out)
            
            out = self.layer2(out)
            out = self.bn2(out)
            
            out += torch.cat([residual, residual], dim=0)  # 残差接続
            out = F.relu(out)
            
            # 最初のバッチだけを取り出す
            return out[:1]
        else:
            # 通常の処理(バッチサイズ >= 2)
            out = self.layer1(x)
            out = self.bn1(out)
            out = F.relu(out)
            out = self.dropout(out)
            
            out = self.layer2(out)
            out = self.bn2(out)
            
            out += residual  # ここが残差接続!
            out = F.relu(out)
            
            return out

コードの解説

1. 初期化メソッド (__init__)

def __init__(self, hidden_dim: int, dropout: float = 0.1):
    super().__init__()
    
    self.layer1 = nn.Linear(hidden_dim, hidden_dim)
    self.bn1 = nn.BatchNorm1d(hidden_dim)
    self.layer2 = nn.Linear(hidden_dim, hidden_dim)
    self.bn2 = nn.BatchNorm1d(hidden_dim)
    self.dropout = nn.Dropout(dropout)

ここでは以下のコンポーネントを初期化しています:

  • 2つの全結合層(Linear): 入力と出力の次元が同じ(hidden_dim
  • 2つのバッチ正規化層(BatchNorm1d): 学習を安定させる
  • ドロップアウト層(Dropout): 過学習を防ぐ

2. 順伝播メソッド (forward)

def forward(self, x: torch.Tensor) -> torch.Tensor:
    # スキップ接続のために入力を保存
    residual = x

最初に、入力xresidual変数に保存します。これが後で足し合わされる「スキップ接続」の入力になります。

3. バッチサイズのチェック

batch_size = x.size(0)
is_single_batch = batch_size == 1 and not self.training

ここで特別なチェックが行われています。バッチ正規化(BatchNorm)は、バッチサイズが1の場合に問題を起こすことがあります。そのため、バッチサイズが1で、かつ評価モード(not self.training)の場合は特別な処理を行います。

4. 通常の処理フロー

out = self.layer1(x)
out = self.bn1(out)
out = F.relu(out)
out = self.dropout(out)

out = self.layer2(out)
out = self.bn2(out)

out += residual  # ここが残差接続!
out = F.relu(out)

処理の流れは:

  1. 第1層の全結合層
  2. バッチ正規化
  3. ReLU活性化関数
  4. ドロップアウト
  5. 第2層の全結合層
  6. バッチ正規化
  7. 残差接続(スキップ接続): out += residual
  8. ReLU活性化関数

最も重要なのは7番目のステップです。ここで、変換された出力に元の入力を足しています。これが「残差接続」と呼ばれる部分です。

5. バッチサイズ=1の特別処理

if is_single_batch:
    # バッチを一時的に複製して、BatchNormが動作するようにする
    x_batch = torch.cat([x, x], dim=0)
    
    # 処理...
    
    # 最初のバッチだけを取り出す
    return out[:1]

バッチサイズが1の場合は注意が必要です。バッチ正規化を正常に機能させるために入力を複製してバッチサイズを2にします。処理後、最初のサンプルだけを取り出して返します。

バッチサイズ=1の特殊処理:評価モード時にバッチを複製してバッチ正規化を正常に機能させる工夫

この特殊処理は、評価時(not self.trainingTrueのとき)に単一サンプルを処理する際の安定性を大きく向上させます。バッチ正規化は統計値(平均・分散)を使用するため、バッチサイズが1だと正確な統計値が計算できません。サンプルを複製することでこの問題を解決しています。

残差ブロックの使用例

残差ブロックはモデルの中で以下のように使用できます:

class MyModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        
        self.input_layer = nn.Linear(input_dim, hidden_dim)
        self.residual1 = ResidualBlock(hidden_dim)
        self.residual2 = ResidualBlock(hidden_dim)
        self.output_layer = nn.Linear(hidden_dim, output_dim)
    
    def forward(self, x):
        x = self.input_layer(x)
        x = F.relu(x)
        
        x = self.residual1(x)
        x = self.residual2(x)
        
        x = self.output_layer(x)
        return x

まとめ

残差ブロックは以下の特徴を持ちます:

  1. スキップ接続により勾配消失問題を緩和
  2. より深いネットワークの学習を可能に
  3. バッチ正規化で学習を安定化
  4. ドロップアウトで過学習を防止
  5. バッチサイズ=1のエッジケースにも対応

残差ブロックを使いこなすことで、より深く、より表現力の高いニューラルネットワークを構築できるようになります。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?