0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

Transformerのフィードフォワードネットワーク(FFN)を書き換えてみる

0
Posted at

1. はじめに

Transformerの中で全パラメータ数の6割程を占めているフィードフォワードネットワーク(以下、FFN)の動作について興味があったので試してみました。
一般的に、Attentionが「どの情報に注目するか(注意)」を決め、「その情報で何を思い出し、何に変換するか(記憶・変換)」がFFNに格納されると言われていますが、
FFNの値を変えたり構造そのものを変えることでモデルの動作がどう変わるかを確認したいと思います。
(この記事は書き換え方法の説明のみとなります)

2. テスト内容

Transformer LLM における層単位の FFN 層の重要度検証
https://www.anlp.jp/proceedings/annual_meeting/2025/pdf_dir/P2-8.pdf

上記論文によると、「中間から後方の層にFFN を集中配置することで,複数の下流タスクでベースラインの性能を上回る結果を得た」とのことで、これに近いことを試してみたいと思います。
ベースに使うモデルは論文と同じLlamaとし、HuggingFaceのtransformers内のクラスを継承して一部上書きする方法で確認します。

3. 変更箇所

transformersでLlama(LlamaModel, LlamaForCausalLM)をロードした時に呼び出されるコードは以下の「modeling_llama.py」内にあります。
(models内にはLlama以外のモデルのコードもあります。)
modeling_llama

上記にコード内の各クラスの関係を簡単に図にまとめると以下のような構造になっているようでした。
image.png

コードを見ると、LlamaModelからModuleListで各レイヤーを構成する際にLlamaDecoderLayerを呼び出しているようなので、ここで引数に使うconfigをレイヤーごとに個別に指定すればできそうでした。

4. モデルの書き換え

4-1. 必要なもの

transformersとtorchがインストール済みであれば問題ありません。

python

from transformers import LlamaConfig, LlamaForCausalLM
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel
import copy
from torch import nn

4-2. Model部分

今回作成する子クラスでは名前の頭に「My~」を付けました。
Model生成時に使うconfigにはgetattrで値を自由に追加できるため、今回は「intermediate_size_list」という名前で追加しました。
後の確認でテキスト生成させたいので、MyLlamaForCausalLMも作成しました。

class MyLlamaModel(LlamaModel):
    def __init__(self, config):
        super().__init__(config)
        
        # configから引数を読み込む default_listは値の指定がない場合に使われる
        default_list = [config.intermediate_size for _ in range(config.num_hidden_layers)] # list: [intermediate_size, intermediate_size, ...]
        intermediate_size_list = getattr(config, "intermediate_size_list", default_list)
        
        # レイヤごとにサイズを指定してlayersを上書き
        self.layers = nn.ModuleList()
        layer_config = copy.deepcopy(config)
        for layer_idx in range(config.num_hidden_layers):
            layer_config.intermediate_size = intermediate_size_list[layer_idx]
            self.layers.append(LlamaDecoderLayer(layer_config, layer_idx))


class MyLlamaForCausalLM(LlamaForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        
        # 元のself.modelをMyLlamaModelで上書き
        self.model = MyLlamaModel(config)

4-3. config部分

FFNのサイズを倍率のようなもの(scaleとしました)で元のサイズから変更できるようにします。
レイヤーの数は前方、中盤、後方で3等分できるように12層とし、論文を参考に元の合計サイズと一致するように後方のみ大きくしそれ以外は小さくしました。

scale = [0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,1.5,1.5,1.5,1.5,]
print(len(scale))
# 12
print(sum(scale))
# 12.0 ⇒合計サイズは元と同じ

モデルの形状をconfigに書いていきます。
今回追加するintermediate_size_list以外の値に特に理由はなく、単に0.2Bほどの総パラメータ数となるように決めました。

config = LlamaConfig(
    vocab_size=32000,
    hidden_size=1024,
    intermediate_size=2816,
    num_hidden_layers=12,
    num_attention_heads=16,
    num_key_value_heads=16,
    max_position_embeddings=2048,
    rms_norm_eps=1e-6,
    rope_theta=10000.0,
    tie_word_embeddings=False,
    intermediate_size_list = [int(i * 2816) for i in scale], # 2816は元のintermediate_size
)

print(config.intermediate_size_list)
# [2112, 2112, 2112, 2112, 2112, 2112, 2112, 2112, 4224, 4224, 4224, 4224]

5. 動作確認

5-1. modelの作成、確認

上記で作成したModel部分とconfig部分を使ってmodelを作成していきます。

model = MyLlamaForCausalLM(config)

print(sum(p.numel() for p in model.parameters()))
# 219702272 パラメータ総数

print(model)
##### 以下出力 #####
MyLlamaForCausalLM(
  (model): MyLlamaModel(
    (embed_tokens): Embedding(32000, 1024)
    (layers): ModuleList(
      (0-7): 8 x LlamaDecoderLayer(                                             # 後方以外の8レイヤー
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=1024, out_features=1024, bias=False)
          (k_proj): Linear(in_features=1024, out_features=1024, bias=False)
          (v_proj): Linear(in_features=1024, out_features=1024, bias=False)
          (o_proj): Linear(in_features=1024, out_features=1024, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=1024, out_features=2112, bias=False)  # outが指定した2112に減っている
          (up_proj): Linear(in_features=1024, out_features=2112, bias=False)    # outが指定した2112に減っている
          (down_proj): Linear(in_features=2112, out_features=1024, bias=False)  # inが指定した2112に減っている
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm((1024,), eps=1e-06)
        (post_attention_layernorm): LlamaRMSNorm((1024,), eps=1e-06)
      )
      (8-11): 4 x LlamaDecoderLayer(                                            # 後方4レイヤー
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=1024, out_features=1024, bias=False)
          (k_proj): Linear(in_features=1024, out_features=1024, bias=False)
          (v_proj): Linear(in_features=1024, out_features=1024, bias=False)
          (o_proj): Linear(in_features=1024, out_features=1024, bias=False)
        )
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=1024, out_features=4224, bias=False)  # outが指定した4224に増えている
          (up_proj): Linear(in_features=1024, out_features=4224, bias=False)    # outが指定した4224に増えている
          (down_proj): Linear(in_features=4224, out_features=1024, bias=False)  # inが指定した4224に増えている
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm((1024,), eps=1e-06)
        (post_attention_layernorm): LlamaRMSNorm((1024,), eps=1e-06)
      )
    )
    (norm): LlamaRMSNorm((1024,), eps=1e-06)
    (rotary_emb): LlamaRotaryEmbedding()
  )
  (lm_head): Linear(in_features=1024, out_features=32000, bias=False)
)

各レイヤーのFFNがconfigで指定した通りのサイズに変わっていることが確認できました。

5-2. 推論テスト

import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-step-50K-105b")
# Tokenizerは既存のLlama系のものを使いました
# vocab_sizeはモデルと一致している必要があるようです

model.eval()

prompt = "The capital of Japan is"
inputs = tokenizer(prompt, return_tensors="pt")

with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=20)
# Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.

print(tokenizer.decode(out[0]))
# <s> The capital of Japan is Inform Inform Inform Inform Inform Inform tres??????? Inform Inform Inform Inform Inform
# 意味が分からない文章になるかと思いますが、全く学習していない状態なので問題ありません

6. まとめ

FFNの構造を変えたい場合、HuggingFaceのtransformersを使うことで簡単に試すことができました。
Llama以外のモデルについてもいくつか見てみましたが、このMLP部分の実装はほとんど変わらないようなので、他のモデルでも適用できそうです。
他にも、同じ要領で、例えばAttention部分を書き換えたりすることもできるかと思います。

実行環境

Ubuntu 24.04.3 LTS
Python 3.12.3
torch 2.9.1+cu129
transformers 4.57.3

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?