9
4

3値量子化パラメータを学習できるオプティマイザを作った【PyTorch】

Last updated at Posted at 2024-03-16

はじめに

2024年2月28日、Microsoft より BitNet b1.58 というモデルの論文1が公開されました。
この論文は量子化されたモデルを学習する手法を提案しており、精度は落とさずに推論の速度やメモリ消費量を改善できるという内容です。

このモデルの中で重みは {0, -1, 1} の3値で表現されており、にも関わらず学習ができているという点に感銘を受け、以下に紹介するオプティマイザを作成しました。

(BitNet を実装するという話ではありません。オプティマイザでこんな事ができるよという話です)

TernaryOpt

3値で量子化されたパラメータを学習できる pytorch のオプティマイザを作りました。これを TernaryOpt と呼ぶことにします。
このオプティマイザを使用して学習したすべての重みは 0, , γ のいずれか値をとります。ここで γ は分散を調整するための値です。

TernaryOpt を使用する利点として、パラメータの保存方法を工夫すればモデルファイルのサイズ軽量化が可能になります。

TernaryOpt は既存のオプティマイザを継承することで定義されます。基底クラスとなるオプティマイザは任意に差し替えることができます。

TernaryOpt はただのオプティマイザであるため、既存のモデルに適用することが容易です。

以下が TernaryOpt の全文です。

from torch import optim

class TernaryOpt(optim.Adam): # <- 基底クラスは任意に置き換え可能
    def __init__(self, *args, **kwargs):
        self.norm = kwargs.pop('norm', True)

        super(TernaryOpt, self).__init__(*args, **kwargs)

    def step(self, closure = None) -> None:
        for group in self.param_groups:
            for p in group['params']:
                if 'raw' in self.state[p]:
                    p.data = self.state[p]['raw']

        super(TernaryOpt, self).step(closure)

        for group in self.param_groups:
            for p in group['params']:
                self.state[p]['raw'] = p.data
                if self.norm:
                    gamma = p.data.abs().mean()
                    p.data = (p.data / gamma).round().clamp(-1, 1) * gamma
                else:
                    p.data = p.data.round().clamp(-1, 1)

アルゴリズム

このオプティマイザは量子化された重みを直接最適化しません。量子化前の重みが raw として保持されていて、基底クラスのオプティマイザが raw を最適化します。そして TernaryOpt の最適化の結果として raw を量子化して返します。
量子化は単に raw.round().clamp(-1, 1) として定義されます。

norm オプション

TernaryOpt は論文2に従い、量子化前後にスケーリングを行います。これは学習を安定させる効果があります。
スケーリングをオフにしたい場合は、 TernaryOpt(norm = False) のように使用してください。

実験

PyTorch の Mnist のサンプルコードをベースに TernaryOpt を適用している例です。

以下のように僅かな変更で元のコードと切替可能にできます。
(ここでは Adadelta を基底クラスとしているので TernaryAdadelta という名前にしています)

    if use_ternary:
        model = Net(bias = False).to(device)
        optimizer = TernaryAdadelta(model.parameters(), lr=args.lr)
    else:
        model = Net(bias = True).to(device)
        optimizer = optim.Adadelta(model.parameters(), lr=args.lr)

学習結果

オリジナル

Test set: Average loss: 0.0259, Accuracy: 9922/10000 (99%)

TernaryOpt

Test set: Average loss: 0.0331, Accuracy: 9895/10000 (99%)

使用時の注意

埋め込みベクトルやバイアスまで量子化されてしまうことに注意してください。
パラメータごとに違うオプティマイザを使用したい場合はパラメータごとのオプションを使ってください。
Tips: バイアスは削除しても問題ないかもしれません3

  1. The Era of 1-bit LLMs:
    All Large Language Models are in 1.58 Bits https://arxiv.org/abs/2402.17764 2

  2. BitNet: Scaling 1-bit Transformers for
    Large Language Models
    https://arxiv.org/abs/2310.11453

  3. 1removes all biases と書いてある

9
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
9
4