LoginSignup
27
17

BitNet

Last updated at Posted at 2024-02-29

※BitNetそのものを検証した記事ではなく,BitNetで使われているレイヤーを使った簡単なNNモデルを作成し検証しています.

はじめに

今回はBitNetの紹介になります.BitNetは量子化を考慮した学習を行うアーキテクチャとなっており,他の量子化手法と比べ競争力のある性能を保ちつつ,メモリ消費量を大幅に抑えています.
今回の検証ではBitNetで使われているBitLinearレイヤーを使いLanguage Modelではなくテーブルデータのための簡単なNNモデルを作り学習を行えるか検証しています.

記事に誤り等ありましたらご指摘いただけますと幸いです。

目次

1. BitNet

ライセンス:MIT

リポジトリ:https://github.com/kyegomez/BitNet
(今回検証に用いた実装であり,公式の実装ではないです)

論文:https://arxiv.org/abs/2310.11453

BitNetは8bit量子化やFP16 Transformerなどの量子化を行う新しいアプローチであり,Language Modelにおける実験の結果,メモリ消費を大幅に削減しつつ,従来の量子化手法に比べ競争力のある性能レベルを達成しています.

要点

  • 主な貢献はBitLinearというレイヤーであり,それをMulti-Head AttentionおよびFeed-Forward Networkにも適用して作成したものがBitNet
  • 8bitで順伝搬と逆伝播を計算し,重みの更新時は[-1,0,1]の値を使う
  • 行列計算を乗算を使わずに加減算のみで計算を行う
  • BitNetは量子化を考慮した学習方法となっており,従来の量子化のアプローチである学習済みのモデルを量子化する方法と異なり,今回の結果は量子化を考慮した学習方法の可能性を示している
  • スケーリング則にも従うため今後はさらにBitNetをスケールアップしていく予定

実装

  • BitLinear
    BitLinearの計算フロー,具体的な計算方法は論文を参照してください.

スクリーンショット 2024-02-29 12.54.00.png

from bitnet import BitLinear

layer = BitLinear(512, 400)
  • BitNetTransformer
    BitNetのMulti-Head Attention部分のアーキテクチャおよびコードです.

スクリーンショット 2024-02-29 13.05.25.png

from bitnet import BitNetTransformer

bitnet = BitNetTransformer(
    num_tokens=20000,
    dim=512,
    depth=6,
    dim_head=64,
    heads=8,
    ff_mult=4,
)
  • BitFeedForward
    BitNetのFeedForward部分のアーキテクチャおよびコードです.

スクリーンショット 2024-02-29 13.05.50.png

from bitnet.bitffn import BitFeedForward

ff = BitFeedForward(512)
  • BitNetのモデルの推論
    通常のPytorchのモデルと同様に保存および呼び出しを行うことができます.
from bitnet import BitNetInference

bitnet = BitNetInference()
bitnet.load_model('../model_checkpoint.pth') #Download model
output_str = bitnet.generate("The dog jumped over the ", 512)
print(output_str)

2. BitLinearの検証

今回はTitanicデータセット(テーブルデータ)でBitNetの全結合層(BitLinear)を使い学習ができるかを検証しました.すぐに試したい方はColabのリンクから実行してみてください

Open In Colab

以下のコマンドでインストール

pip install bitnet

必要なパッケージ

import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader, TensorDataset

# 今回必要なもの
from bitnet import BitLinear

データセットの前処理.TitanicデータセットをNNモデルに入力できる形式に変更します.この辺りは特に工夫などしていません.

df = sns.load_dataset('titanic')

df = df[['survived', 'pclass', 'sex', 'age', 'sibsp', 'parch', 'fare']]
df['sex'] = df['sex'].map({'male': 0, 'female': 1})
df.dropna(inplace=True)

X = df.drop('survived', axis=1).values
y = df['survived'].values

scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)

X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.2, random_state=42)

X_train_tensor = torch.tensor(X_train, dtype=torch.float)
y_train_tensor = torch.tensor(y_train, dtype=torch.long)
X_test_tensor = torch.tensor(X_test, dtype=torch.float)
y_test_tensor = torch.tensor(y_test, dtype=torch.long)

train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

BitNetを使った簡単なNNモデルの実装です.以下のようにBitLinearの層を追加することで簡単に組み込むことができます.

class BitNetModel(nn.Module):
    def __init__(self):
        super(BitNetModel, self).__init__()
        self.fc1 = BitLinear(6, 64)  # BitLinearのレイヤー
        self.relu = nn.ReLU()
        self.fc2 = BitLinear(64, 2)    # BitLinearのレイヤー

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x

model = BitNetModel()

以下のコードで学習を行います.今回は100エポックで学習を回しました.

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

num_epochs = 100
for epoch in range(num_epochs):
    for inputs, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

    if (epoch+1) % 10 == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

with torch.no_grad():
    correct = 0
    total = 0
    outputs = model(X_test_tensor)
    _, predicted = torch.max(outputs.data, 1)
    total += y_test_tensor.size(0)
    correct += (predicted == y_test_tensor).sum().item()

print(f'Accuracy: {100 * correct / total:.2f}%')

以下は結果になります.特徴量作成などしっかり行っておらず特に工夫していなかったので精度は77%ぐらいになりました.BitNetも同じぐらいの精度が出ており通常のモデルと同様に学習が行えることが確認できました.

  • 通常のNNモデルのAccuracy:77.62
  • BitLinearに置き換えたNNモデルのAccuracy:76.22

4. おわりに

今回はBitNetの紹介と簡単な検証を行いました.今回紹介したOSSのBitNet(BitLinear)はPytorchのNNモデルの実装に簡単に組み込むことができ,通常のモデルと同じように学習を行うことができました.今後はLanguage Modelなどを実装し,どれぐらいの精度を出すことができるか検証していければと考えています.

最後までお読みいただきありがとうございます.記事に誤り等ありましたらご指摘いただけますと幸いです。

5. 参考文献

27
17
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
27
17