はじめに
2024年2月28日、Microsoft より BitNet b1.58 というモデルの論文1が公開されました。
この論文は量子化されたモデルを学習する手法を提案しており、精度は落とさずに推論の速度やメモリ消費量を改善できるという内容です。
このモデルの中で重みは {0, -1, 1}
の3値で表現されており、にも関わらず学習ができているという点に感銘を受け、以下に紹介するオプティマイザを作成しました。
(BitNet を実装するという話ではありません。オプティマイザでこんな事ができるよという話です)
TernaryOpt
3値で量子化されたパラメータを学習できる pytorch のオプティマイザを作りました。これを TernaryOpt と呼ぶことにします。
このオプティマイザを使用して学習したすべての重みは 0
, -γ
, γ
のいずれか値をとります。ここで γ
は分散を調整するための値です。
TernaryOpt を使用する利点として、パラメータの保存方法を工夫すればモデルファイルのサイズ軽量化が可能になります。
TernaryOpt は既存のオプティマイザを継承することで定義されます。基底クラスとなるオプティマイザは任意に差し替えることができます。
TernaryOpt はただのオプティマイザであるため、既存のモデルに適用することが容易です。
以下が TernaryOpt の全文です。
from torch import optim
class TernaryOpt(optim.Adam): # <- 基底クラスは任意に置き換え可能
def __init__(self, *args, **kwargs):
self.norm = kwargs.pop('norm', True)
super(TernaryOpt, self).__init__(*args, **kwargs)
def step(self, closure = None) -> None:
for group in self.param_groups:
for p in group['params']:
if 'raw' in self.state[p]:
p.data = self.state[p]['raw']
super(TernaryOpt, self).step(closure)
for group in self.param_groups:
for p in group['params']:
self.state[p]['raw'] = p.data
if self.norm:
gamma = p.data.abs().mean()
p.data = (p.data / gamma).round().clamp(-1, 1) * gamma
else:
p.data = p.data.round().clamp(-1, 1)
アルゴリズム
このオプティマイザは量子化された重みを直接最適化しません。量子化前の重みが raw
として保持されていて、基底クラスのオプティマイザが raw
を最適化します。そして TernaryOpt の最適化の結果として raw
を量子化して返します。
量子化は単に raw.round().clamp(-1, 1)
として定義されます。
norm オプション
TernaryOpt は論文2に従い、量子化前後にスケーリングを行います。これは学習を安定させる効果があります。
スケーリングをオフにしたい場合は、 TernaryOpt(norm = False)
のように使用してください。
実験
PyTorch の Mnist のサンプルコードをベースに TernaryOpt を適用している例です。
以下のように僅かな変更で元のコードと切替可能にできます。
(ここでは Adadelta を基底クラスとしているので TernaryAdadelta
という名前にしています)
if use_ternary:
model = Net(bias = False).to(device)
optimizer = TernaryAdadelta(model.parameters(), lr=args.lr)
else:
model = Net(bias = True).to(device)
optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
学習結果
オリジナル
Test set: Average loss: 0.0259, Accuracy: 9922/10000 (99%)
TernaryOpt
Test set: Average loss: 0.0331, Accuracy: 9895/10000 (99%)
使用時の注意
埋め込みベクトルやバイアスまで量子化されてしまうことに注意してください。
パラメータごとに違うオプティマイザを使用したい場合はパラメータごとのオプションを使ってください。
Tips: バイアスは削除しても問題ないかもしれません3。
-
The Era of 1-bit LLMs:
All Large Language Models are in 1.58 Bits https://arxiv.org/abs/2402.17764 ↩ ↩2 -
BitNet: Scaling 1-bit Transformers for
Large Language Models
https://arxiv.org/abs/2310.11453 ↩