はじめに
今回は、NeurIPS 2023 にて発表された "TIES-MERGING: Resolving Interference When Merging Models" を紹介します。この論文は、複数のタスクに特化したモデルをマージしてマルチタスクモデルを生成する手法を提案しています。Sakana-AI の進化的モデルマージで採用されていたので詳しく内容を読んでみました。
参考文献:
P. Yadav, D. Tam, L. Choshen, C. Raffel, M. Bansal. "TIES-MERGING: Resolving Interference When Merging Models". NeurIPS 2023.
1 論文の概要
この論文では、モデルマージにおけるパラメータ干渉を解消するための新しい手法 TIES-MERGING (TRIM, ELECT SIGN, MERGE) を提案しています。従来のモデルマージ手法は、単純なパラメータ平均や重み付けに頼っていましたが、これによりモデル間のパラメータ値が干渉し、性能が低下する問題がありました。そこで提案手法では以下の 3 ステップでこの問題を解決しています。
TIES-MERGING
- Trim: パラメータの変化量が小さい値を切り捨て
- Elect Sign: 符号衝突部のパラメータの符号決定
- Merge: 符号が一致するパラメータのみを平均化してマージ
このアプローチにより既存手法を大幅に上回る性能を示しました。
🐣 論文中のパラメータはモデルの重みと考えてよさそうです
2 関連研究
関連研究として、タスク固有のモデルをマージする方法や、パラメータの干渉を軽減する研究が挙げられます。本論文はこれらを基に新たな視点を導入し、マージにおけるパラメータ干渉を直接解決するアプローチを提案しています。
Loss Landscape and Weight Interpolation に関する研究
- 異なる訓練で得られたモデルのパラメータを補間可能にする条件 (mode connectivity) についての研究
- 同じ初期化から始まるモデル間では補間が可能であることを発見
Model Merging and Different Use Cases に関する研究
- 同じ事前学習モデルを用いる場合、順列対称性を考慮せずにモデルをマージ可能
- 単一タスクの性能向上、OOD (Out-Of-Domain) 一般化、マルチタスク化、連合学習、モデル圧縮など、幅広い応用が研究されている
Weight Interpolation に関する研究
- Frankle ら: 最適化経路の一部が共有されている場合、補間可能であることを示した
- Neyshabur ら: 異なる最適化経路のモデル間では精度が劣化することを示した
- Entezari ら: 順列対称性を考慮すれば、同じアーキテクチャ・データセットのモデルは線形に接続可能との仮説を立てた
- Ainsworth ら、Singh ら、Wang ら: 順列や Optimal Transport を用いて補間性能を向上させる手法を提案
Model Merging に関する研究
- RegMean: 各層の局所的な線形回帰で解を求めるが、モデルサイズに比例する統計量が必要
- Fisher Merging: パラメータの重要度を Fisher Information Matrix で測定するが、複数モデルでの改善が限定的
- Task Arithmetic: タスクベクトルの加算による単純なマージ手法
- Ortiz-Jiménez ら: 事前学習による重みの分離性に基づく理論的分析を提示
補足: 順列対称性 (permutation symmetry) の重要性
ニューラルネットワークでは、各層のノードの順序をシャッフルしても全体の機能は変わらないという「順列対称性」という性質があります。この性質により、異なる初期化から学習した複数のモデルの重みを単純にマージすることは、同じ機能を持つノードが異なる位置に配置されている可能性があるため困難です。
しかし先行研究では、同一の事前学習モデルから始まる場合は順列対称性を考慮する必要がないと仮定できると報告されています。
TIES-Merging は、この「同一の事前学習モデルから始まる」という条件を活用し、順列対称性の問題を回避しています。
🐣 おばあさん細胞の場所は違っても機能は一緒ということなのね
3 提案手法
タスクベクトル (task vector) による変化の表現
TIES-Merging では、各タスクにおけるモデルの適応を定量化するために タスクベクトル という概念を用います。これは、以下のように定義されます。
\tau_t = \theta_{t,\text{ft}} - \theta_{\text{init}}
ここで、各記号は次の意味を持ちます。
- $\theta_{t,\text{ft}}$: タスク $t$ に対してファインチューニング後のモデルパラメータ
- $\theta_{\text{init}}$: ファインチューニング前の事前学習済みモデルの初期パラメータ
- $\tau_t$: タスク $t$ における タスクベクトル (モデルパラメータの変化量)
このタスクベクトル $\tau_t$ は、事前学習モデルから特定のタスクに適応する際に、どのようにパラメータが変化したかを直接的に示します。この表現を用いることで、異なるタスクにおけるモデルの変化を比較したり、マージの方向性を設計するための指針を得ることが可能となります。
🐣 ここが重要! 重みそのものの値は関係ありません
パラメータ干渉の問題
タスクベクトルの成分を比較した分析により、異なるタスク間でのパラメータ干渉が性能低下の原因となることが明らかになりました。この干渉は、次の 2 種類に分類されます。
🐣 重み全体の差分がタスクベクトルでその成分が一つの重みの差分です。タスクベクトルは名前の通りタスクごとに一つしかありません
1. 冗長なパラメータによる干渉
-
概要:
多くのタスクベクトル成分 (各重みの変化量) は性能への影響が小さいにも関わらず、単純な平均では重要な変化が冗長な変化によって打ち消されてしまいます。 -
具体例:
ベースモデルから異なるタスクでファインチューニングされた 3 つのモデルを考えます。タスク $A$, $B$, $C$ それぞれのタスクベクトル (ベースモデルとの重みの差のベクトル)\boldsymbol{\tau}_A, \boldsymbol{\tau}_B, \boldsymbol{\tau}_C
の成分
\tau_{A,i}, \tau_{B,i}, \tau_{C,i}
を比較します。
- タスク $A$ の成分 $\tau_{A,i} = +0.5$ (重要な変化)
- タスク $B$, $C$ の成分 $\tau_{B,i} = -0.2$, $\tau_{C,i} = -0.2$ (冗長な変化)
単純平均を取ると:
\tau_i^{\text{merged}} = \frac{\tau_{A,i} + \tau_{B,i} + \tau_{C,i}}{3} = \frac{0.5 - 0.2 - 0.2}{3} = +0.03
この結果、タスク $A$ にとって重要だった変化が他の冗長な変化によって大幅に弱められてしまいます。
2. 符号の衝突による干渉
-
概要:
同じパラメータでも、異なるタスク間で変化の方向 (符号) が異なる場合、単純な平均では正負の変化が相殺され、どちらのタスクにも不適切な値となります。 -
具体例:
タスク $A$, $B$ のタスクベクトル\boldsymbol{\tau}_A, \boldsymbol{\tau}_B
の成分
\tau_{A,i}, \tau_{B,i}
を比較します。
- タスク $A$ の成分 $\tau_{A,i} = +0.6$
- タスク $B$ の成分 $\tau_{B,i} = -0.4$
単純平均を取ると:
\tau_i^{\text{merged}} = \frac{\tau_{A,i} + \tau_{B,i}}{2} = \frac{0.6 - 0.4}{2} = +0.1
この結果、どちらのタスクにとっても適切でない値 (+0.1) となり、性能が低下します。
これらの干渉は、タスク数が増えるほど深刻になり、マージされたモデルの性能低下の主な要因となりえます。
図3、図4が示す内容
図3 (左): パフォーマンスと上位パラメータの依存性
図4 (右): Trimming 後における符号の衝突の頻度
(出典:P. Yadav et al., "TIES-MERGING: Resolving Interference When Merging Models," 2023)
1. 通常通りファインチューニング
各タスクに対してモデルをファインチューニングし、タスク固有のパラメータ変化を取得します。
2. ベースモデルからの変化量を計算
タスクベクトルを次の式で計算します:
\boldsymbol{\tau}_t = \boldsymbol{\theta}_t^{\text{ft}} - \boldsymbol{\theta}_{\text{init}}
ここで:
- $\boldsymbol{\theta}_t^{\text{ft}}$: タスク $t$ におけるファインチューニング後の全パラメータ (重み全体の集合)
- $\boldsymbol{\theta}_{\text{init}}$: ベースモデルの初期パラメータ (重み全体の集合)
- $\boldsymbol{\tau}_t$: タスクベクトル (タスク $t$ における全パラメータの変化を記述したベクトル)
タスクベクトルの一成分 $\tau_{t,i}$ は、個々の重み (スカラー値) の変化を表しており、一つの重みに対応します。
3. 変化量の大きい上位 $k\%$ のみを保持
タスクベクトルの各成分 $\tau_{t,i}$ について、絶対値 (ノルム) が大きい上位 $k\%$ の成分のみを保持し、それ以外の成分はベースモデルの値 (差分 0) に戻します。この操作により、タスクベクトルの一部の成分が 0 になります。
-
具体的な操作
-
各成分 $\tau_{t,i}$ の絶対値を計算します:
|\tau_{t,i}|
-
全成分を絶対値で降順にソートし、その中から 上位 $k\%$ の成分を選択します。
-
上位 $k\%$ に含まれない成分は 0 にリセットされます (ベースモデルの初期値のに戻るという意味) 。トリミング後のタスクベクトルは次のように定義されます:
\tau_{t,i}^{\text{trimmed}} = \begin{cases} \tau_{t,i}, & \text{if } |\tau_{t,i}| \text{ is in top } k\% \text{ of } |\boldsymbol{\tau}_t|, \\ 0, & \text{otherwise}. \end{cases}
-
4. 性能を評価
トリミング後のタスクベクトルを基にモデルを再構成し、その性能を評価します。
結果の解釈
-
上位 20% の変化量のみを保持しても性能は維持される
図3では、上位 20% の大きな変化を持つ成分を保持した場合のパフォーマンスが、すべての成分を保持した場合とほぼ同等であることが示されています。 つまり、下位 80% の成分が性能にほとんど寄与しない冗長な要素であることが示唆されています。🐣 下位 80% の重みをゼロにするわけではなく、変化量をゼロにしてベースモデルの値に戻すということです
-
符号の衝突は Trimming 後も発生する
図4 では、タスクベクトルをトリミングした後でも符号の衝突が発生することが示されています。また、マージするモデル数が増えるにつれて衝突の頻度が増加するため、TIES-Merging の次のステップである Elect Sign の重要性が強調されています。
4 TIES-MERGING: TRIM, ELECT SIGN & MERGE
TIES-Merging は以下の 3 つのステップから構成されています:
- Trim: 冗長な変化を削減して、重要な変化のみを保持
- Elect Sign: 各パラメータの方向性(符号)を統一
- Disjoint Merge: 統一した符号に基づき平均を計算してマージ
最終的なマージタスクベクトル $\boldsymbol{\tau}_m$ は、これらのステップを通じて構築されます。
1. Trim(冗長な変化の削減)
タスク $t$ におけるタスクベクトル $\boldsymbol{\tau}_t$ の成分について:
- 上位 $k\%$ の成分(絶対値が大きい変化)を保持。
- 残りの成分は 0 にリセット(差分を無効化、ベースモデル値に戻す)。
これにより、各タスクのタスクベクトルがスパース構造になります。
\boldsymbol{\tau}_t^{\text{trimmed}} = \left[ \tau_{t,1}^{\text{trimmed}}, \tau_{t,2}^{\text{trimmed}}, \dots, \tau_{t,P}^{\text{trimmed}} \right]
- ** Trim の結果**: 各タスクベクトルの中で、ファインチューニングによって大きく変化した重みに対応する成分のみが残ります。
2. Elect Sign(符号の選択)
各タスクに $t$ に対応するタスクベクトル $\boldsymbol{\tau}_t^{\text{trimmed}}$ を用いて、各パラメータ $p$ (位置 $p$ の重み) に対する「支配的な符号」を決定します。
符号 $\gamma_\mathrm{m}^{(p)}$ の選択基準 ($_\mathrm{m}$ はマージを表すラベル):
- 全タスクのトリミング後タスクベクトル成分 $\tau_{t,p}^{\text{trimmed}}$ の符号(方向)を合計。
- 合計値の符号を符号関数 $\text{sgn}$ に基づき決定。
具体例:
- タスク 1: $\tau_{1,p}^{\text{trimmed}} = +0.3$
- タスク 2: $\tau_{2,p}^{\text{trimmed}} = -0.1$
- タスク 3: $\tau_{3,p}^{\text{trimmed}} = +0.2$
→ 合計値 $\sum = +0.4$ の符号は $+1$
→ 支配的な符号 $\gamma_\mathrm{m}^{(p)} = +1$
これによってファインチューニング後、タスク間で異なる方向に重みが変化した個所のマージ後の向きが決定します。
3. Disjoint Merge(符号方向の平均計算)
符号 $\gamma_\mathrm{m}^{(p)}$ が決まった後、方向が一致するタスクベクトル成分 (重み) だけを平均します。
具体的には:
- 各パラメータ $p$ に対し、符号 $\gamma_\mathrm{m}^{(p)}$ と一致する成分だけを選び出す:
\mathcal{A}^{(p)} = \left\{ t \mid \text{sgn}(\tau_{t,p}^{\text{trimmed}}) = \gamma_\mathrm{m}^{(p)} \right\}
- 一致した成分の平均を計算:
\tau_\mathrm{m}^{(p)} = \frac{1}{|\mathcal{A}^{(p)}|} \sum_{t \in \mathcal{A}^{(p)}} \tau_{t,p}^{\text{trimmed}}
- 平均値 $\tau_\mathrm{m}^{(p)}$ を使って、マージタスクベクトルの成分 $\boldsymbol{\tau}_\mathrm{m}$ を構成。
もしもすべてのタスクで変化の向きが同じであり、Trim で削られてもいなければこれは単に各タスクによる変化を平均しているだけです。Trim が適用されている場合、そこは 0 として計算され符号が衝突している場合は支配的な符号と逆の符号の部分は 0 で計算されます。
結果として、マージで用いるタスクベクトルは以下となります。( $P$ は重みの数)
\boldsymbol{\tau}_\mathrm{m} = \left[ \tau_\mathrm{m}^{(1)}, \tau_\mathrm{m}^{(2)}, \dots, \tau_\mathrm{m}^{(P)} \right]
最終モデルの構築
最終的に、マージで用いるタスクベクトル $\boldsymbol{\tau}_\mathrm{m}$ を使用して、ベースモデルを更新します:
\boldsymbol{\theta}_\mathrm{m} = \boldsymbol{\theta}_{\text{init}} + \lambda \boldsymbol{\tau}_\mathrm{m}
🐣 $t$ のような変数はイタリック $\mathrm{m}$ のようなラベルはローマン体ですね。論文では $m$ になっていますがそれだと変数のように見えてわかりにくいです
5 数値実験と結果
提案手法 TIES-MERGING は、以下の点で従来手法を上回りました。
- NLP および画像処理タスクで、最も強力な従来手法を 2.3% 上回る性能を達成
- トップ 20% パラメータを用いたマージ設定では、バリデーションセットがなくても高い性能を維持
詳細な数値については論文 Table 1 および Table 2 を参照してください。
6 まとめ
TIES-MERGING は、モデルマージ時に発生するパラメータ干渉を軽減する効果的な手法です。パラメータ変化のトリミングと符号の調整によって、性能向上が実現されました。特に、符号の衝突を解消する工程がマージ性能に与える重要性を示しています。
おわりに
この論文は、モデルマージに新たな視点をもたらしました。個人的には、同一の事前学習モデルを用いることで順列対称性の問題を回避している点に感銘を受けました。また符号の衝突回避についてはまだまだ研究できることが多そうです。ではまた次の記事でお会いしましょう。