トランスフォーマーやBERTなどの現代的な深層学習モデルにおいて、セルフアテンションは非常に重要なメカニズムです。この記事では、セルフアテンションの基本概念と、PyTorchを使った具体的な実装方法について解説します。
セルフアテンションとは
セルフアテンションは、シーケンス内の各要素が同じシーケンス内の他の要素との関連性を考慮して重み付けを行うメカニズムです。これにより、特定の位置の要素が、シーケンス内の他のどの要素と関連が強いかを学習することができます。
従来のRNN(Recurrent Neural Network)やLSTM(Long Short-Term Memory)が時系列に沿って情報を処理するのに対し、セルフアテンションは全ての位置を一度に参照できるため、長距離依存関係の捕捉に優れています。
セルフアテンションの計算プロセスは以下の図のように表現できます:
PyTorchによるセルフアテンションの実装
以下はPyTorchを使用したセルフアテンションモジュールの基本的な実装です。
class SelfAttention(nn.Module):
"""セルフアテンションモジュール"""
def __init__(self, hidden_dim: int, num_heads: int = 4):
"""
初期化
Args:
hidden_dim: 隠れ層の次元
num_heads: アテンションヘッド数
"""
super().__init__()
self.attention = nn.MultiheadAttention(hidden_dim, num_heads)
self.norm = nn.LayerNorm(hidden_dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
順伝播
Args:
x: 入力テンソル (batch_size, seq_len, hidden_dim)
Returns:
torch.Tensor: 出力テンソル (batch_size, seq_len, hidden_dim)
"""
# 入力の順番を変更 (batch_size, seq_len, hidden_dim) -> (seq_len, batch_size, hidden_dim)
x_t = x.transpose(0, 1)
# セルフアテンションの適用
attn_out, _ = self.attention(x_t, x_t, x_t)
# 残差接続と正規化
attn_out = attn_out + x_t
attn_out = self.norm(attn_out)
# 出力の順番を戻す (seq_len, batch_size, hidden_dim) -> (batch_size, seq_len, hidden_dim)
return attn_out.transpose(0, 1)
この実装について、詳しく見ていきましょう。
コード解説
クラスの初期化
def __init__(self, hidden_dim: int, num_heads: int = 4):
super().__init__()
self.attention = nn.MultiheadAttention(hidden_dim, num_heads)
self.norm = nn.LayerNorm(hidden_dim)
ここでは2つの重要なコンポーネントを初期化しています:
-
MultiheadAttention: PyTorchの
nn.MultiheadAttentionクラスを使用しています。このクラスは複数のアテンションヘッドを並列に計算します。各ヘッドは異なる表現サブ空間に注目することができるため、モデルの表現能力が向上します。 -
LayerNorm: 正規化層として
nn.LayerNormを使用しています。これにより、ネットワークの学習が安定し、勾配消失や勾配爆発の問題を緩和できます。
順伝播(forward)の実装
def forward(self, x: torch.Tensor) -> torch.Tensor:
# 入力の順番を変更
x_t = x.transpose(0, 1)
# セルフアテンションの適用
attn_out, _ = self.attention(x_t, x_t, x_t)
# 残差接続と正規化
attn_out = attn_out + x_t
attn_out = self.norm(attn_out)
# 出力の順番を戻す
return attn_out.transpose(0, 1)
このメソッドでは以下の処理を行っています:
-
テンソルの転置: PyTorchの
MultiheadAttentionは入力形状として(seq_len, batch_size, hidden_dim)を期待するため、入力テンソルxの形状を(batch_size, seq_len, hidden_dim)から変換します。 -
セルフアテンションの計算:
self.attentionメソッドを同じテンソルx_tに対して3回(クエリ、キー、バリュー)適用します。これがセルフアテンションの本質的な部分です。このメソッドは注意重みも返しますが(2番目の戻り値)、ここでは使用していないので_として無視しています。 -
残差接続: 注意出力を元の入力に足し合わせます。これにより、勾配が直接伝播できるパスが提供され、深いネットワークの学習が容易になります。
-
レイヤー正規化: 注意出力と入力の和に対して正規化を適用します。これにより、モデルの学習が安定します。
-
テンソルの再転置: 出力テンソルを元の形状
(batch_size, seq_len, hidden_dim)に戻します。
マルチヘッドアテンションとは
このコードではnum_heads引数を使用してマルチヘッドアテンションを設定しています。マルチヘッドアテンションとは、同じ入力に対して複数のアテンション計算を並列に行い、異なる表現空間で関連性を学習する仕組みです。
例えば、4つのヘッドを持つマルチヘッドアテンションでは、隠れ層の次元を4分割し、各部分に対して別々のアテンション計算を行います。これにより、モデルは異なる観点から入力データの関係性を捉えることができます。
以下の図は、マルチヘッドアテンションの概念を示しています:
残差接続と正規化の重要性
トランスフォーマーアーキテクチャの重要な特徴の一つに、残差接続(Residual Connection)と層正規化(Layer Normalization)があります。
# 残差接続と正規化
attn_out = attn_out + x_t
attn_out = self.norm(attn_out)
残差接続は、入力信号が変換を経由せずに直接出力に影響を与えるショートカットパスを提供します。これにより、深いネットワークでも勾配が消失することなく効率的に学習できます。
層正規化は、入力の平均と分散を調整して、学習プロセスを安定させる役割を果たします。これらの技術により、より深いネットワークを効果的に訓練することが可能になりました。
実際の使用例
このセルフアテンションモジュールは以下のように使用できます:
# モデルのインスタンス化
hidden_dim = 256
num_heads = 4
self_attention = SelfAttention(hidden_dim, num_heads)
# 入力データの作成(バッチサイズ=32、シーケンス長=10、隠れ層次元=256)
batch_size = 32
seq_len = 10
x = torch.randn(batch_size, seq_len, hidden_dim)
# 順伝播の実行
output = self_attention(x)
print(output.shape) # torch.Size([32, 10, 256])
応用と発展
このシンプルなセルフアテンション実装を基に、以下のような発展的な機能を追加することができます:
- 位置エンコーディング: シーケンスの位置情報を組み込むため
- ドロップアウト: オーバーフィッティングを防ぐため
- フィードフォワードネットワーク: トランスフォーマーの完全なブロックを構築するため
- マスキング: 自己回帰モデルや可変長シーケンスの処理のため
まとめ
セルフアテンションは現代の深層学習、特に自然言語処理において不可欠なメカニズムです。本記事では、PyTorchを使ってセルフアテンションモジュールを実装し、その内部動作について解説しました。
このシンプルな実装を理解することで、BERT、GPT、その他のトランスフォーマーベースのアーキテクチャがどのように動作するかの基礎を把握することができたのではないでしょうか