BitNetの学習をquantization-aware training（QAT）という。
要するに量子化済み重みでも推論において精度低下しないための学習である。
しかし、公開されている実装を見ると、この量子化重みをいちから（スクラッチで）学習しているように思う。
さて、あらかじめ任意の「学習済み重み」が存在する場合に、正解ラベルを与えずに学習が出来るか試したい。「学習済み重みの推論」と「量子化後重みの推論」の2個の推論結果のMSELossを取って「量子化前重み」を学習させることが可能かどうかである。
一般的なBitNetは「量子化前重み」と「量子化後重み」の二種類があり、「量子化前重み」を学習する。このためモデルに「学習済み重み」を読み込んでも「量子化前重み」の更新によって「学習済み重み」をいずれ忘れてしまう。
自分が考えたいのは「学習済み重み」、「量子化前重み」、「量子化後重み」があり、「量子化前重み」を更新しても「学習済み重み」は更新しない方針である。
この場合、正解ラベルは不要であり、これは自己教師学習（SimCLRなど）の考え方に近い（学習済みモデルの出力を教師として使う点では知識蒸留にも近い）。
コード
まず最初に「学習済み重み」と「量子化前重み」を定義し、これらを任意のタイミングで同期させる。強化学習のonline/targetモデルの学習を参考にすると、下記のように書ける。
class bitNet(nn.Module):
    """Excerpt: MLP holding the "trained" weights (w0) and a trainable copy (w1).

    Parts of the class are elided with ``...``; the complete class appears in
    the full listing of this article.
    """
    def __init__(self, in_features, out_features):
        super(bitNet, self).__init__()
        # w0: the "trained" full-precision weights (3 x [LayerNorm -> Linear]).
        self.w0 = nn.Module()
        self.w0.fc1 = nn.Linear(in_features, 784)
        self.w0.fc2 = nn.Linear(784, 784)
        self.w0.fc3 = nn.Linear(784, out_features)
        self.w0.norm1 = nn.LayerNorm(in_features)
        self.w0.norm2 = nn.LayerNorm(784)
        self.w0.norm3 = nn.LayerNorm(784)
        # w1: an independent deep copy that holds the pre-quantization weights.
        # The intent is that only w1 is trained while w0 stays fixed.
        self.w1 = copy.deepcopy(self.w0)
        ...  # remaining attributes are elided in this excerpt
    def w0_forward(self, x):
        """Full-precision forward pass through w0; returns the last linear output."""
        x = self.w0.norm1(x)
        x = self.w0.fc1(x)
        x = F.relu(x)
        x = self.w0.norm2(x)
        x = self.w0.fc2(x)
        x = F.relu(x)
        x = self.w0.norm3(x)
        x = self.w0.fc3(x)
        return x
    def w1_forward(self, x):
        """Same full-precision forward pass, but through w1 (no quantization)."""
        x = self.w1.norm1(x)
        x = self.w1.fc1(x)
        x = F.relu(x)
        x = self.w1.norm2(x)
        x = self.w1.fc2(x)
        x = F.relu(x)
        x = self.w1.norm3(x)
        x = self.w1.fc3(x)
        return x
    def sync_forward(self, x):
        """Copy w0's weights into w1 (like syncing an RL target network), then run w1."""
        self.w1.load_state_dict(self.w0.state_dict())
        y = self.w1_forward(x)
        return y
次に「量子化済み重みで推論」を実行してみる。
「学習済み重みの推論」と「量子化済み重みでの推論」を求め、このMSELossからself.w1の重みを学習させる。
def quant_forward(self, x):
    """Forward pass through w1 with fake-quantized activations and weights.

    Every layer quantizes its input with ``quant_x`` and its weight with
    ``quant_w``.  ``(q - v).detach() + v`` is the straight-through estimator:
    the forward value is ``q`` while the gradient flows to ``v`` unchanged, so
    the latent (pre-quantization) w1 weights receive gradients.
    Returns the output of the last linear layer (no softmax).
    """
    # layer 1
    x = self.w1.norm1(x)
    qx = self.quant_x(x)
    qx = (qx - x).detach() + x
    w = self.w1.fc1.weight
    qw = self.quant_w(w)
    qw = (qw - w).detach() + w
    x = torch.nn.functional.linear(qx, qw, self.w1.fc1.bias)
    x = F.relu(x)
    # layer 2
    x = self.w1.norm2(x)
    qx = self.quant_x(x)
    qx = (qx - x).detach() + x
    w = self.w1.fc2.weight
    qw = self.quant_w(w)
    qw = (qw - w).detach() + w
    x = torch.nn.functional.linear(qx, qw, self.w1.fc2.bias)
    x = F.relu(x)
    # layer 3 (no activation)
    x = self.w1.norm3(x)
    qx = self.quant_x(x)
    qx = (qx - x).detach() + x
    w = self.w1.fc3.weight
    qw = self.quant_w(w)
    qw = (qw - w).detach() + w
    x = torch.nn.functional.linear(qx, qw, self.w1.fc3.bias)
    return x
def train_ssl(self, x):
    """One label-free training step: pull the quantized w1 output toward w0's.

    The w0 output is detached, so it acts as a fixed target (no labels needed).
    Returns the quantized prediction.
    """
    y_true = (self.w0_forward(x)).detach()
    y_pred = self.quant_forward(x)
    self.train_backward2(y_pred, y_true)
    return y_pred
def train_backward2(self, y_pred, y_true):
    """Run one optimizer step on the MSE between y_pred and y_true.

    Stores the scalar loss in ``self.loss``.
    NOTE(review): before PyTorch 2.0, ``zero_grad()`` filled existing grads with
    zeros instead of None, so Adam would keep moving already-frozen parameters
    through momentum; ``zero_grad(set_to_none=True)`` avoids that.
    """
    loss = nn.MSELoss()(y_pred, y_true)
    self.optimizer.zero_grad()
    loss.backward()
    self.optimizer.step()
    self.loss = loss.item()
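この学習モデルを「学習済みモデル」と同期した後、正しい正解ラベルを与えずに学習させた。通常の学習（check_type=5）は正解ラベルを与えないと上手く行かないが、「学習済みモデルの推論」に近づける学習（check_type=7）では、正しい正解ラベルを与えなくても精度が出ている。check_type=6,8では、中間層で二つの推論を分割したままにせず、推論途中で足し合わせる（交差させる）ことを意識したが、なぜ学習できているのかはよく分からない。なお、label_zero=Trueのcheck_type=5,6のacc=0.098は、すべてのテスト画像を「0」と予測した場合の正解率（MNISTテストセット10000枚中、数字0は980枚）と一致しており、全ラベルを0にした学習でモデルが常に0を出力するようになったと考えられる。また、check_type=6と8の違いは5と7の違いと同じく損失関数（正解ラベルに対するNLLか、学習済みモデルの出力に対するMSEか）であり、ラベル無しで学習できるかどうかはこの違いで決まっている。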
# Excerpt of the stage-2 training loop (parts elided with "...").
for i, (images, labels) in enumerate(train_loader):
    ...
    if label_zero:
        # ground-truth labels are not used !! every label is replaced by class 0,
        # so only label-independent losses (check_type 7/8) can still learn.
        labels = torch.zeros_like(labels)
    outputs = model(images, labels)
--------------------------------------------------------
label_zero= True , check_type= 5 , acc= 0.098 , last_loss= 0.0001803367819353298
label_zero= False , check_type= 5 , acc= 0.9722 , last_loss= 0.0004097218212361137
label_zero= True , check_type= 6 , acc= 0.098 , last_loss= 7.17351681892372e-05
label_zero= False , check_type= 6 , acc= 0.9711 , last_loss= 0.0008830680602540573
label_zero= True , check_type= 7 , acc= 0.9663 , last_loss= 0.0005115935863927006
label_zero= False , check_type= 7 , acc= 0.9683 , last_loss= 0.0005283699582641323
label_zero= True , check_type= 8 , acc= 0.9628 , last_loss= 0.0023707961219052473
label_zero= False , check_type= 8 , acc= 0.9632 , last_loss= 0.0024202951254944007
まとめ
BitNetの追加学習には学習元のデータが必要だが、SDXLでは学習データ（多数の画像データ）を用意するのは簡単ではない。一方、訓練済み重みと等価になるよう量子化済み重みを学習するのであれば、学習ラベル（SDXLでは画像）なしでも学習できるかもしれない。ただし、入力となるpromptは必要である。
全コード
import torch
import torch.nn as nn
import copy
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
class bitNet(nn.Module):
    """Three-layer MLP used to test label-free BitNet QAT on MNIST.

    The module holds two sub-networks with identical structure
    (``norm1 -> fc1 -> ReLU -> norm2 -> fc2 -> ReLU -> norm3 -> fc3``):

    * ``w0``: the full-precision "trained" weights.
    * ``w1``: the latent (pre-quantization) weights.  ``quant_forward`` and
      ``quant_forward2`` ternarize them on the fly and train them through a
      straight-through estimator.

    ``forward`` dispatches on ``self.forward_type``; the ``train_*`` paths run
    one optimizer step themselves and store the scalar loss in ``self.loss``.

    Args:
        in_features: size of the input vector.
        out_features: number of output logits.
        hidden_features: width of the two hidden layers (default 784).
        lr: Adam learning rate (default 0.00025).
    """

    def __init__(self, in_features, out_features, hidden_features=784, lr=0.00025):
        super().__init__()
        self.w0 = nn.Module()
        self.w0.fc1 = nn.Linear(in_features, hidden_features)
        self.w0.fc2 = nn.Linear(hidden_features, hidden_features)
        self.w0.fc3 = nn.Linear(hidden_features, out_features)
        self.w0.norm1 = nn.LayerNorm(in_features)
        self.w0.norm2 = nn.LayerNorm(hidden_features)
        self.w0.norm3 = nn.LayerNorm(hidden_features)
        self.w1 = copy.deepcopy(self.w0)
        # 10e-6 == 1e-5: clip margin in quant_x and floor of the quantization scales.
        self.eps = 10e-6
        # One optimizer over both sub-networks; the training script selects which
        # one actually moves by toggling requires_grad.
        self.optimizer = torch.optim.Adam(self.parameters(), lr=lr)
        self.loss = 0.0
        self.forward_type = 0

    @staticmethod
    def _ste(q, x):
        """Straight-through estimator: forward value ``q``, gradient taken from ``x``."""
        return (q - x).detach() + x

    def quant_x(self, x):
        """Fake-quantize activations to 8 bits with one per-tensor absmax scale.

        Values are mapped to ``round(x / gamma)`` clipped to
        ``[-127 + eps, 127 - eps]`` and scaled back by ``gamma = max|x| / 127``.
        ``max|x|`` is floored at ``self.eps`` so an all-zero tensor returns zeros
        instead of NaN (0 / 0).
        """
        gamma = torch.abs(x).max().clamp(min=self.eps) / 127.0
        x = torch.clamp(torch.round(x / gamma), -127.0 + self.eps, 127.0 - self.eps)
        return x * gamma

    def quant_w(self, w):
        """Fake-quantize a weight tensor to the ternary set {-beta, 0, +beta}.

        ``w`` is centred on its mean ``alpha`` and divided by ``beta``, the mean
        absolute deviation from ``alpha``; the result is rounded and clipped to
        [-1, 1].  ``alpha`` is not added back.  ``beta`` is floored at
        ``self.eps`` so a constant tensor returns zeros instead of NaN.
        """
        alpha = w.mean()
        beta = torch.abs(w - alpha).mean().clamp(min=self.eps)
        w = torch.clamp(torch.round((w - alpha) / beta), -1.0, 1.0)
        return w * beta

    def _quant_linear(self, x, fc):
        """Apply linear layer ``fc`` with fake-quantized input and weight (STE on both)."""
        qx = self._ste(self.quant_x(x), x)
        w = fc.weight
        qw = self._ste(self.quant_w(w), w)
        return F.linear(qx, qw, fc.bias)

    @staticmethod
    def _dense_forward(net, x):
        """Full-precision forward pass through sub-network ``net`` (w0 or w1)."""
        x = F.relu(net.fc1(net.norm1(x)))
        x = F.relu(net.fc2(net.norm2(x)))
        return net.fc3(net.norm3(x))

    def w0_forward(self, x):
        """Full-precision forward pass through w0; returns logits."""
        return self._dense_forward(self.w0, x)

    def w1_forward(self, x):
        """Full-precision forward pass through w1 (no quantization); returns logits."""
        return self._dense_forward(self.w1, x)

    def quant_forward(self, x):
        """Forward pass through w1 with quantized activations and weights; returns logits."""
        x = F.relu(self._quant_linear(self.w1.norm1(x), self.w1.fc1))
        x = F.relu(self._quant_linear(self.w1.norm2(x), self.w1.fc2))
        return self._quant_linear(self.w1.norm3(x), self.w1.fc3)

    def quant_forward2(self, x):
        """Quantized w1 forward pass that is "crossed" with w0 after every layer.

        For each layer, the w0 output ``x_t`` is computed from the current
        (crossed) input, and the quantized w1 output ``x_q`` becomes
        ``(x_q - x_t).detach() + x_q``.  Gradients therefore flow only through
        the quantized w1 path, while the forward value is ``2 * x_q - x_t``.

        NOTE(review): the usual straight-through form would be
        ``(x_t - x_q).detach() + x_q`` (forward value ``x_t``).  The expression
        is kept as-is because the published check_type 6/8 results depend on it.
        """
        layers = (
            (self.w0.norm1, self.w0.fc1, self.w1.norm1, self.w1.fc1, True),
            (self.w0.norm2, self.w0.fc2, self.w1.norm2, self.w1.fc2, True),
            (self.w0.norm3, self.w0.fc3, self.w1.norm3, self.w1.fc3, False),
        )
        for norm_t, fc_t, norm_q, fc_q, use_relu in layers:
            x_t = fc_t(norm_t(x))
            x_q = self._quant_linear(norm_q(x), fc_q)
            if use_relu:
                x_t = F.relu(x_t)
                x_q = F.relu(x_q)
            x = (x_q - x_t).detach() + x_q
        return x

    def sync_forward(self, x):
        """Copy w0's weights into w1, then run the full-precision w1 forward pass."""
        self.w1.load_state_dict(self.w0.state_dict())
        return self.w1_forward(x)

    def train_normal(self, x, y_true):
        """One supervised NLL step on w0; returns log-probabilities."""
        y_pred = F.log_softmax(self.w0_forward(x), dim=1)
        self.train_backward(y_pred, y_true)
        return y_pred

    def train_quant(self, x, y_true):
        """One supervised NLL step on quantized w1; returns log-probabilities."""
        y_pred = F.log_softmax(self.quant_forward(x), dim=1)
        self.train_backward(y_pred, y_true)
        return y_pred

    def train_quant2(self, x, y_true):
        """One supervised NLL step on the crossed quantized forward; returns log-probabilities."""
        y_pred = F.log_softmax(self.quant_forward2(x), dim=1)
        self.train_backward(y_pred, y_true)
        return y_pred

    def train_ssl(self, x):
        """One label-free step: MSE between quantized w1 logits and detached w0 logits."""
        y_true = self.w0_forward(x).detach()
        y_pred = self.quant_forward(x)
        self.train_backward2(y_pred, y_true)
        return y_pred

    def train_ssl2(self, x):
        """Like ``train_ssl`` but using the crossed forward ``quant_forward2``."""
        y_true = self.w0_forward(x).detach()
        y_pred = self.quant_forward2(x)
        self.train_backward2(y_pred, y_true)
        return y_pred

    def _optimize(self, loss):
        """Backpropagate ``loss``, take one optimizer step and store its value."""
        # set_to_none=True explicitly: zero-filled grads (the default before
        # PyTorch 2.0) would let Adam keep moving frozen parameters via momentum.
        self.optimizer.zero_grad(set_to_none=True)
        loss.backward()
        self.optimizer.step()
        self.loss = loss.item()

    def train_backward(self, y_pred, y_true):
        """Optimizer step on the NLL loss (``y_pred`` must be log-probabilities)."""
        self._optimize(F.nll_loss(y_pred, y_true))

    def train_backward2(self, y_pred, y_true):
        """Optimizer step on the mean-squared error between ``y_pred`` and ``y_true``."""
        self._optimize(F.mse_loss(y_pred, y_true))

    def forward(self, x, y_true):
        """Dispatch on ``self.forward_type``.

        Inference paths (return log-probabilities):
            0: w0, 1: w1 unquantized, 2: w1 quantized, 3: sync w1 <- w0, then w1.
        Training paths (run one optimizer step):
            4: NLL on w0, 5: NLL on quantized w1, 6: NLL on crossed quantized w1
            (these return log-probabilities);
            7: MSE of quantized w1 vs w0, 8: same with the crossed forward
            (these ignore ``y_true`` and return raw logits).

        Raises:
            ValueError: if ``forward_type`` is not in 0-8.
        """
        t = self.forward_type
        if t == 0:
            return F.log_softmax(self.w0_forward(x), dim=1)
        if t == 1:
            return F.log_softmax(self.w1_forward(x), dim=1)
        if t == 2:
            return F.log_softmax(self.quant_forward(x), dim=1)
        if t == 3:
            return F.log_softmax(self.sync_forward(x), dim=1)
        if t == 4:
            return self.train_normal(x, y_true)
        if t == 5:
            return self.train_quant(x, y_true)
        if t == 6:
            return self.train_quant2(x, y_true)
        if t == 7:
            return self.train_ssl(x)
        if t == 8:
            return self.train_ssl2(x)
        raise ValueError(f"unknown forward_type: {t!r}")
# Experiment: for each check_type, (1) train the full-precision weights w0 with
# labels, then (2) sync w1 <- w0 and train w1 with either real or all-zero labels,
# and finally (3) evaluate the quantized w1 on the MNIST test set.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
batch_size = 256
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download = True)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor(), download = True)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
# Evaluation order does not affect accuracy, so the test set is not shuffled.
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)
for check_type in [5, 6, 7, 8]:
    for label_zero in [True, False]:
        model = bitNet(784, 10).to(device)
        model.train()

        # Stage 1: ordinary supervised training of the full-precision weights w0.
        model.w0.requires_grad_(True)
        model.w1.requires_grad_(False)
        model.forward_type = 4
        for images, labels in train_loader:
            images, labels = images.view(-1, 28 * 28).to(device), labels.to(device)
            model(images, labels)

        # Stage 2: freeze w0 and train the latent weights w1 with `check_type`.
        model.w0.requires_grad_(False)
        model.w1.requires_grad_(True)
        train_loss = 0.0
        num_train = 0
        for i, (images, labels) in enumerate(train_loader):
            images, labels = images.view(-1, 28 * 28).to(device), labels.to(device)
            if i == 0:
                # The first batch only copies w0 into w1. No optimizer step runs,
                # so model.loss still holds the stage-1 loss and must not be counted.
                model.forward_type = 3
                model(images, labels)
                continue
            model.forward_type = check_type
            if label_zero:
                # labels are not used !!
                labels = torch.zeros_like(labels)
            model(images, labels)
            train_loss += model.loss
            num_train += len(labels)

        # Stage 3: evaluate the quantized w1 (forward_type 2) on the test set.
        correct = 0
        total = 0
        model.eval()
        model.forward_type = 2
        with torch.no_grad():
            for images, labels in test_loader:
                images, labels = images.view(-1, 28 * 28).to(device), labels.to(device)
                predicted = model(images, labels).argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        acc = correct / total
        # last_loss = (sum of per-batch mean losses) / (number of samples), i.e. the
        # same normalisation as the published results.
        print('label_zero=', label_zero, ', check_type=', check_type, ', acc=', acc, ', last_loss=', train_loss/float(num_train))
        del model
--------------------------------------------------------
label_zero= True , check_type= 5 , acc= 0.098 , last_loss= 0.0001803367819353298
label_zero= False , check_type= 5 , acc= 0.9722 , last_loss= 0.0004097218212361137
label_zero= True , check_type= 6 , acc= 0.098 , last_loss= 7.17351681892372e-05
label_zero= False , check_type= 6 , acc= 0.9711 , last_loss= 0.0008830680602540573
label_zero= True , check_type= 7 , acc= 0.9663 , last_loss= 0.0005115935863927006
label_zero= False , check_type= 7 , acc= 0.9683 , last_loss= 0.0005283699582641323
label_zero= True , check_type= 8 , acc= 0.9628 , last_loss= 0.0023707961219052473
label_zero= False , check_type= 8 , acc= 0.9632 , last_loss= 0.0024202951254944007
(参考):SDXLのBitNetの自己教師学習
いちおうSDXLで「学習済み重み」と「BitNetの量子化後重み」の自己教師学習を出来るか書いてみたコードを残しておくが、エラーが出て上手く動かなかった。
モデル重みの8%程度を占めるto_qのレイヤーをbitlinearに置換する。
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
とあり、学習時（forward_type==3）に失敗する。forward_type==0〜2の推論では問題ない。勾配が何故か消えているので、SDXLのpipeline内の@torch.no_grad()あたりが原因かもしれない（そうであれば、bitlinear内で損失を計算する部分をtorch.enable_grad()で囲んで計算グラフを作らせる必要があると思われる）。
forward_type==2の結果は、前回の1.58bit量子化によるto_qの変換結果と等しい。
import torch
import torch.nn as nn
import copy
import torch.optim as optim
import bitsandbytes as bnb
# Module-level switch read by bitlinear.forward and callback_train:
# 0 = w0, 1 = w1, 2 = quantized w1, 3 = quantized w1 + self-supervised loss.
# (A `global` statement at module level is a no-op and would leave the name
# undefined until the script first assigns it.)
forward_type = 0
class bitlinear(nn.Module):
    """Replacement for a bias-free linear layer that can learn BitNet weights.

    ``w0.fc`` wraps the original weight ``param`` (frozen); ``w1.fc`` is a
    trainable deep copy whose ternarized version is used by ``quant_forward``.
    The active path is chosen by the module-level global ``forward_type``.

    Args:
        param: the original ``nn.Parameter`` weight, shape (out_features, in_features).
    """

    def __init__(self, param):
        super().__init__()
        out_features, in_features = param.shape
        self.w0 = nn.Module()
        # nn.Linear(in, out) stores its weight as (out, in); build it with matching
        # dimensions so in_features/out_features describe the layer correctly.
        self.w0.fc = nn.Linear(in_features, out_features, bias=False,
                               dtype=param.dtype, device=param.device)
        self.w0.fc.weight = param
        # deepcopy already yields identical values, so no load_state_dict is needed.
        self.w1 = copy.deepcopy(self.w0)
        self.w0.fc.requires_grad_(False)
        self.w1.fc.requires_grad_(True)
        # 10e-6 == 1e-5: clip margin in quant_x and floor of the quantization scales.
        self.eps = 10e-6
        self.loss = 0.0

    def quant_x(self, x):
        """Fake-quantize activations to 8 bits with one per-tensor absmax scale.

        ``max|x|`` is floored at ``self.eps`` so an all-zero input returns zeros
        instead of NaN.
        """
        gamma = torch.abs(x).max().clamp(min=self.eps) / 127.0
        x = torch.clamp(torch.round(x / gamma), -127.0 + self.eps, 127.0 - self.eps)
        return x * gamma

    def quant_w(self, w):
        """Fake-quantize weights to {-beta, 0, +beta} around their mean.

        ``beta`` (mean absolute deviation) is floored at ``self.eps`` so a
        constant weight returns zeros instead of NaN.
        """
        alpha = w.mean()
        beta = torch.abs(w - alpha).mean().clamp(min=self.eps)
        w = torch.clamp(torch.round((w - alpha) / beta), -1.0, 1.0)
        return w * beta

    def w0_forward(self, x):
        """Forward pass with the original (frozen) weight."""
        return self.w0.fc(x)

    def w1_forward(self, x):
        """Forward pass with the trainable latent weight, unquantized."""
        return self.w1.fc(x)

    def quant_forward(self, x):
        """Forward pass with quantized input and ternary weight (straight-through)."""
        qx = self.quant_x(x)
        qx = (qx - x).detach() + x
        w = self.w1.fc.weight
        qw = self.quant_w(w)
        qw = (qw - w).detach() + w
        return torch.nn.functional.linear(qx, qw)

    def train_ssl(self, x):
        """Store the MSE between quantized and original outputs in ``self.loss``.

        Returns the original-weight output, so the running generation is not
        affected by the quantized path.  The optimizer step happens elsewhere
        (from the loss stored here).
        """
        y_true = self.w0_forward(x).detach()
        # If this runs inside a torch.no_grad() region (e.g. a pipeline __call__
        # decorated with it), no graph is recorded and backward() later fails with
        # "element 0 of tensors does not require grad". Re-enable autograd locally;
        # the input is detached so only w1 receives gradients.
        with torch.enable_grad():
            y_pred = self.quant_forward(x.detach())
            self.loss = nn.MSELoss()(y_pred, y_true)
        return y_true

    def forward(self, x):
        """Dispatch on the global ``forward_type`` (0: w0, 1: w1, 2: quantized, 3: train).

        Raises:
            ValueError: if ``forward_type`` is not in 0-3.
        """
        if forward_type == 0:
            return self.w0_forward(x)
        if forward_type == 1:
            return self.w1_forward(x)
        if forward_type == 2:
            return self.quant_forward(x)
        if forward_type == 3:
            return self.train_ssl(x)
        raise ValueError(f"unknown forward_type: {forward_type!r}")
def model_replace(pipe):
    """Replace the attn1/attn2 ``to_q`` projections of the listed UNet blocks with bitlinear.

    Each entry of ``layout`` is (block, number of attentions, number of
    transformer blocks per attention). Every ``to_q`` in that grid is rebuilt
    from its current weight. ``pipe`` is modified in place and returned.
    """
    unet = pipe.unet
    layout = (
        (unet.down_blocks[1], 2, 2),
        (unet.down_blocks[2], 2, 10),
        (unet.up_blocks[0], 3, 10),
        (unet.up_blocks[1], 3, 2),
        (unet.mid_block, 1, 10),
    )
    for block, n_attn, n_tf in layout:
        for a in range(n_attn):
            for t in range(n_tf):
                tblock = block.attentions[a].transformer_blocks[t]
                for attn in (tblock.attn1, tblock.attn2):
                    attn.to_q = bitlinear(attn.to_q.weight)
    return pipe
def get_loss(pipe):
    """Sum the ``loss`` attribute of every replaced attn1/attn2 ``to_q`` module.

    Visits the same blocks, in the same order, as model_replace. The sum starts
    from 0.0, so the result is a float if no term is a tensor.
    """
    unet = pipe.unet
    layout = (
        (unet.down_blocks[1], 2, 2),
        (unet.down_blocks[2], 2, 10),
        (unet.up_blocks[0], 3, 10),
        (unet.up_blocks[1], 3, 2),
        (unet.mid_block, 1, 10),
    )

    def per_layer_losses():
        for block, n_attn, n_tf in layout:
            for a in range(n_attn):
                for t in range(n_tf):
                    tblock = block.attentions[a].transformer_blocks[t]
                    yield tblock.attn1.to_q.loss
                    yield tblock.attn2.to_q.loss

    return sum(per_layer_losses(), 0.0)
def callback_train(pipe, step_index, timestep, callback_kwargs):
    """Step-end callback: while training (forward_type == 3), take one optimizer step.

    Uses the per-layer losses collected by get_loss and the module-level
    ``optimizer``. Always returns ``callback_kwargs`` unchanged.
    """
    if forward_type != 3:
        return callback_kwargs
    total_loss = get_loss(pipe)
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()
    return callback_kwargs
from diffusers import DiffusionPipeline
import numpy as np
import os

model_id = './stable-diffusion-xl-base-1.0/'
output_path = './bitnet_ssl/'
# image.save() does not create missing directories.
os.makedirs(output_path, exist_ok=True)
prompt = "a photo of an astronaut riding a horse on mars"
seed = 42
generator = torch.Generator(device="cuda")
generator = generator.manual_seed(seed)
# No .to("cuda") before enable_model_cpu_offload(): the offload hooks handle
# device placement themselves, so moving the whole pipeline first only uses GPU memory.
pipe = DiffusionPipeline.from_pretrained(model_id, use_safetensors=True, torch_dtype=torch.float16, variant="fp16")
pipe.enable_model_cpu_offload()
pipe = model_replace(pipe)

# Collect the trainable (w1) copies of the replaced to_q weights; the frozen w0
# copies have requires_grad=False and are skipped.
train_params = []
for name, param in pipe.unet.named_parameters():
    if len(param.size()) == 2 and 'to_q' in name:
        print(name, param.shape)
        if param.requires_grad:
            train_params.append(param)
print(pipe.unet.down_blocks[1].attentions[0].transformer_blocks[0].attn1.to_q)
print(pipe.unet.down_blocks[1].attentions[0].transformer_blocks[0].attn1.to_q.w0.fc)
print(pipe.unet.down_blocks[1].attentions[0].transformer_blocks[0].attn1.to_q.w1.fc)
# Alternative with more memory: torch.optim.Adam(train_params, lr=0.00025)
optimizer = bnb.optim.PagedAdamW8bit(train_params, lr=0.00025)
torch.cuda.empty_cache()

# One image per forward_type: 0 = w0, 1 = w1, 2 = quantized w1, 3 = train while generating.
for i in range(4):
    forward_type = i
    generator = generator.manual_seed(seed)
    image = pipe(prompt=prompt, generator=generator, callback_on_step_end=callback_train).images[0]
    image.save(output_path + "img_ssl_forward_type%d.png" % (forward_type))

# Alternate: save the current quantized result, then run one training generation.
for i in range(100):
    forward_type = 2
    generator = generator.manual_seed(seed)
    image = pipe(prompt=prompt, generator=generator).images[0]
    image.save(output_path + "img_ssl_train_num_%d.png" % (i))
    forward_type = 3
    generator = generator.manual_seed(seed)
    image = pipe(prompt=prompt, generator=generator, callback_on_step_end=callback_train).images[0]