紹介する論文
タイトル: AWQ: ACTIVATION-AWARE WEIGHT QUANTIZATION FOR
ON-DEVICE LLM COMPRESSION AND ACCELERATION
学会: MLSys2024
AWQは大規模言語モデル向け量子化手法です。Llama2の配布する際にしばしば見かける量子化方法です。論文では、TinyChatという対話システムに適用して評価することで、有用性を評価しています。
考え方
AWQはわずかな"重要な重み"は量子化せずに残すようにして、残りを量子化したいという考え方が基本になっています。
重要な重みは、Weightの重みではなく、Activationの大きさで判断します。Activationの絶対値が大きなチャンネルの重みは重要であると判断します。
わずかな"重要な重み"はチャンネル単位で選びます。
しかし、重要な重みを単純に量子化せずに処理すると混合精度演算になってしまい、ハードウェア実装した場合に遅くなってしまいます(Figure 2b)。その代わりに整数精度計算で実現する方法を提案しています。
整数精度計算でやりたいことを実現
AQWでは、重要なチャンネルの重みに対しては、$w$を$sw (s>1)$とスケーリングし、その代わりに入力$x$を$x/s$にすることで辻褄を合わせる方法を提案しています。この場合の量子化誤差の変化量を解析します。
重み$\mathbf{w}$を量子化する時の量子化幅を$\Delta=\max(|\mathbf{w}|)/2^{N-1}$とします。この時、量子化関数を$Q(\cdot)$とすると、以下のようになります。
Q(\mathbf{w}) = \Delta\cdot\mathrm{Round}(\frac{\mathbf{w}}{\Delta})
重み$\mathbf{w}$のうちの1つ$w$をスケーリングした後の量子化幅を$\Delta'$とすると、スケーリング後の値について以下が成り立ちます。
Q(w\cdot s)\cdot\frac{x}{s} = \Delta'\cdot\mathrm{Round}(\frac{ws}{\Delta'})\cdot x\cdot\frac{1}{s}
ここで、著者らは実験的に以下が成り立つことを見出しています。
(1) $\mathrm{Round}(\cdot)$の誤差$\mathtt{RoundErr}(\cdot)$は中身が変わっても同じである。$\mathtt{RoundErr}(\cdot)$は[0, 0.5]の一様分布なので、期待値は0.25であるため。
(2) 重み$\mathbf{w}$のうちの1つ$w$をスケーリングしても、$\max(|\mathbf{w}|)$は変わらない。 つまり$\Delta'\approx\Delta.$ おそらく、重要なチャンネル以外の重みが$\max(|\mathbf{w}|)$に支配的なのではないかと思います。
(3) スケーリング前の重みを$w$、対応する入力を$x$とする。Activation計算はFloatで行うこと仮定すると、$\Delta$や$x$はFloat型のため誤差がないと考えて、下記が成り立つ。
\mathtt{Err}(Q(w)x) = \Delta\cdot\mathtt{RoundErr}(\frac{w}{\Delta})\cdot x
\mathtt{Err}(Q(w\cdot s)(\frac{x}{s})) = \Delta'\cdot\mathtt{RoundErr}(\frac{ws}{\Delta'})\cdot x\cdot\frac{1}{s}
スケーリング後の誤差のスケーリング前の誤差に対する比率は
\mathtt{Err}(Q(w\cdot s)(\frac{x}{s})) / \mathtt{Err}(Q(w)x) = \frac{\Delta'}{\Delta}\cdot\frac{1}{s}
ここで、$\Delta'\approx\Delta$と$s>1$より、スケーリング後の誤差はスケーリング前の誤差より小さくなる。
最適なスケーリング係数を決める
AWQでは、スケーリング係数$s$は、すべてのチャンネルに対して最適な値を決めます。形式的には、各チャンネルの最適なスケーリング係数$\mathbf{s}^*$は、下記を少量のキャリブレーションデータで解いて決めることになります。
\mathbf{s}^* = \mathrm{argmin}_{\mathbf{s}}\mathcal{L}(\mathbf{s})
\mathcal{L}(\mathbf{s}) = ||Q(\mathbf{W}\cdot\mathrm{diag}(\mathbf{s}))\cdot(\mathrm{diag}(\mathbf{s})^{-1}\cdot\mathbf{X}) - \mathbf{WX}||
AWQでは、計算量を削減するために、探索空間を非常にシンプルにしています。
\mathbf{s}^* = \mathbf{s_{\mathbf{X}}}^{\alpha^*},
\alpha^* = \mathrm{argmin}_{\alpha}\mathcal{L}(\mathbf{s_{\mathbf{X}}}^{\alpha})
ここで、$\mathbf{s_{\mathbf{X}}}$はチャンネルごとのActivationの絶対値の平均です。$\alpha\in[0, 1]$が最適化パラメータで、グリッドサーチで最適値を求めます。$\mathbf{s_{\mathbf{X}}}^{\alpha}$は$\mathbf{s_{\mathbf{X}}}$の各要素を$\alpha$乗するという意味です。
評価
WikiText-2の言語タスクによる評価結果です。LLaMAとLlama-2を従来手法のGPTQと比較して改善されていることが示されています。