1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

torch.nn.TransformerEncoder のイラスト

1
Last updated at Posted at 2026-03-20

torch.nn.TransformerEncoder (のサブモジュール) のイラストです。

[2026-04-09 追記] デコーダも含めた Transformer 全体のイラストは以下です。
Transformer のイラスト - Qiita

参考文献

緑色の矢印: 学習パラメータをもつネットワーク
薄緑の矢印: 学習パラメータをもつネットワーク (オプショナル)
黄色の矢印: ハイパーパラメータで制御されるモジュール

torch.nn.MultiheadAttention

nn.MultiheadAttention.png
ドキュメントによると、torch.nn.MultiheadAttention はあくまで Attention is All You Need 版を実装したものであって、最新の機能はあまり取り込んでいないので、よりよいレイヤーを自作することやより高レベルのライブラリを使用することが推奨されています。

torch.nn.TransformerEncoderLayer

nn.TransformerEncoderLayer.png
ドキュメントには torch.nn.MultiheadAttention 同様の注意書きがあります。
torch.nn.TransformerEncoderLayer はデフォルトではレイヤー正規化1層は残差接続後ですが (Post-LN)、フラグによって入力直後に移動できます (Pre-LN)。このフラグ設置は Pre-LN の方が安定という研究を受けています。

torch.nn.TransformerEncoder

nn.TransformerEncoder.png
ドキュメントには torch.nn.MultiheadAttention 同様の注意書きがあります。
torch.nn.TransformerEncoder は torch.nn.TransformerEncoderLayer を繰り返します。が、すべてのレイヤーを同じパラメータで初期するので、インスタンス作成後に手動で改めて初期化することが推奨されています。
引数 norm にレイヤー正規化層を渡せば最後にレイヤー正規化を追加することもできます (Pre-LN にすると最終出力が正規化されていないので、そのための引数だと思います)。

補足

  • 3 枚の図それぞれのパラメータ数は以下です。
    • (6 + 1) * 18 + (6 + 1) * 6 = 168
    • 168 + (6 + 6) + (6 + 1) * 10 + (10 + 1) * 6 + (6 + 6) = 328
    • 328 + 328 + (6 + 6) = 668
  • 3 枚の図に対応するモデルを生成し、情報を表示するスクリプトは下記です。
    • スクリプトの最後でオブジェクト ID を確認しているように、torch.nn.TransformerEncoder は引数に torch.nn.TransformerEncoderLayer のインスタンスを受け取って層数だけ複製しますが、受け取ったインスタンス自体は使用しません。レイヤー正規化層については受け取ったインスタンス自体を使用します。
スクリプト
import torch

def _print(model):
    print('----- Architecture -----')
    print(model)
    print('----- Trainable Parameters -----')
    for name, param in model.named_parameters():
        print(f'{name}: {tuple(param.shape)}')
    n = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('  -> total:', n, 'parameters')

# torch.nn.MultiheadAttention
attn = torch.nn.MultiheadAttention(embed_dim=6, num_heads=2)
_print(attn)

# torch.nn.TransformerEncoderLayer
layer = torch.nn.TransformerEncoderLayer(
    d_model=6, nhead=2, dim_feedforward=10, batch_first=True,
)
_print(layer)

# torch.nn.TransformerEncoder
norm = torch.nn.LayerNorm(6)
model = torch.nn.TransformerEncoder(
    encoder_layer=layer, num_layers=2, norm=norm,
)
_print(model)

assert id(layer) != id(model.layers[0])
assert id(norm) == id(model.norm)
  1. レイヤー正規化は各特徴量ベクトルを正規化する操作です。入力が「私は、犬が、好き」なら「私は」ベクトルを正規化、「犬が」ベクトルを正規化、「好き」ベクトルを正規化します。

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?