この記事では,Pytorchを用いてNeural Tangent Kernel (NTK) およびNeural Gaussian Process (NNGP) を実装し実験する方法について簡単に紹介します.なお今回は簡単のために2クラス分類にのみ適用可能な実装となっています.
0.データセットとモデルの定義
まずはPytorchを利用してCIFAR10の用意をします.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
train_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
CIFAR10から最初の2クラスのみを抽出したデータセットを作成した上でデータローダーを定義します.
def filter_classes(dataset, class_list):
# Convert the dataset targets to a tensor
targets_tensor = torch.tensor(dataset.targets)
# Create a mask for the specified classes
mask = torch.zeros_like(targets_tensor, dtype=torch.bool)
for class_ in class_list:
mask |= (targets_tensor == class_)
# Apply the mask to filter the data and targets
dataset.data = dataset.data[mask.numpy()]
dataset.targets = targets_tensor[mask].numpy()
return dataset
train_set = filter_classes(train_set, [0,1])
test_set = filter_classes(test_set, [0,1])
train_loader = torch.utils.data.DataLoader(train_set, batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)
学習には最初のバッチを用います.そのため今回はバッチサイズである128サンプルを利用した学習となります.
x_train, y_train = next(iter(train_loader))
x_train, y_train = x_train.to(device), y_train.to(device)
モデルは3層のMLPモデルを使用します.出力次元は1次元であり,クラスの予測確率を学習します.
class SimpleNN(nn.Module):
def __init__(self):
super(SimpleNN, self).__init__()
self.fc1 = nn.Linear(32*32*3, 256)
self.fc2 = nn.Linear(256, 256)
self.fc3 = nn.Linear(256, 1)
def forward(self, x):
x = x.view(-1, 32*32*3)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
model = SimpleNN().to(device)
なおNNGPおよびNTKは最終層が線形でありMSE Lossを用いている場合は解析解が存在しますが,Cross Entropy Lossなど出力層にSoftmax関数が含まれる場合には解析解が存在しません.
1. NNGPによるCIFAR10の2クラス分類
NNGPでは中間層はランダム初期値で固定し,最終層のみカーネル法を用いて重みを更新します.そのためL層のMLPモデルを
\boldsymbol{u}_l=\boldsymbol{W}_l \boldsymbol{h}_{l-1}+\boldsymbol{b}_l, \quad \boldsymbol{h}_l=\phi\left(\boldsymbol{u}_l\right) \quad(l=1, \ldots, L)
としたとき,NNGP解は
\boldsymbol{W}_{L}= \left(\boldsymbol{h}_{L-1} \boldsymbol{h}_{L-1}^{\top}+\rho \boldsymbol{I}\right)^{-1} \boldsymbol{h}_{L} \boldsymbol{y}_{\text{train}}, \quad \boldsymbol{W}_{l, 0}(l<L)
と表されます.
なおGaussian Processではガウス分布の分散を用いて未知データに対する予測の不確かさを表すことができますが,無限幅深層学習モデルの研究では基本的には平均値に着目します.
1.1 NNGP解による最終層の重みの更新
まずは最終層への入力$\boldsymbol{h}_{L-1}$を取得する必要があります.
今回はモデルの最終層であるmodel.fc3にforward_hookをかけることで,forward時に最終層への入力を保存します.他のモデルで実験を行う場合は,model.fc3の部分をそのモデルにおける最終層のモジュールに変えてください.
def forward_hook(module, in_data, out_data):
module.h = in_data[0].detach()
hook = model.fc3.register_forward_hook(forward_hook)
model.zero_grad()
output = model(x_train.float())
hook.remove()
次にNNGP解を計算した上で,最終層の重みを更新します.
calculate_theta_starは最終層のモジュールと$\boldsymbol{y}_{\text{train}}$を入力として,最終層の重みを更新する関数となっています.
def calculate_theta_star(last_module, y_train_float, lambda_val=0.1):
G = last_module.h
lambda_I = lambda_val * torch.eye(G.size(1)).to(G.device)
theta_star = torch.inverse(G.t() @ G + lambda_I) @ G.t() @ y_train_float
with torch.no_grad():
last_module.weight.copy_(theta_star.t())
last_module.bias.copy_(torch.zeros_like(last_module.bias))
y_train_float = y_train.float().unsqueeze(1)
calculate_theta_star(model.fc3, y_train_float)
1.2 Neural Gaussian Processを用いたクラス予測
以下のコードによってCIFAR10の予測確率を求めることができます.ここではクラス1の予測確率が0.5を超えた場合はそのデータをクラス1に分類し,0.5より小さい場合はそのデータをクラス0に分類しています.
def predict_with_nngp(model, x_test):
model.eval()
with torch.no_grad():
test_output = model(x_test.float())
return (test_output > 0.5).float()
def calculate_accuracy(test_loader, model):
test_preds_list = []
test_targets_list = []
with torch.no_grad():
for x_batch, y_batch in test_loader:
x_batch, y_batch = x_batch.to(device), y_batch.to(device)
preds = predict_with_nngp(model, x_batch)
test_preds_list.append(preds)
test_targets_list.append(y_batch)
test_preds = torch.cat(test_preds_list, dim=0)
test_targets = torch.cat(test_targets_list, dim=0)
accuracy = (test_preds.squeeze() == test_targets).float().mean().item()
return accuracy
accuracy = calculate_accuracy(test_loader, model)
print(f'Test Accuracy: {accuracy * 100:.2f}%')
上のコードを実行した結果,筆者の環境では
Test Accuracy: 71.35%
となりました.
2. NTKによるCIFAR10の2クラス分類
NTKではニューラルネットワークのランダム初期値近傍でのカーネル解を求めます.
ニューラルネットワーク$f(x)$のNTK行列$\Theta(x_1, x_2)$は以下で定義されます.
\Theta(x_1, x_2) = \nabla_\theta f(x_1) \nabla_\theta f(x_2)^{\top}
このNTK行列$\Theta(x_1, x_2)$を用いたカーネル解の出力は,
\langle f(x') \rangle_{\theta_0} = \Theta\left(x^{\prime}, x\right) \left( \Theta(x, x) + \rho \boldsymbol{I} \right)^{-1} \boldsymbol{y}
となります.なお$\langle \cdot \rangle_{\theta_0}$はランダム初期値についての平均を表しています.
2.1.Neural Tangent Kernelの計算
Functorchを用いてNeural Tangent Kernelを計算する関数を定義します.
メモリ消費量を削減するためには,Blockごとに分割した上で計算します.
from torch.func import vmap, jacrev
from functorch import make_functional_with_buffers
def empirical_ntk_jacobian_contraction(fnet_single, params, x1, x2, block_size=32):
# Compute J(x1)
jac1 = vmap(jacrev(fnet_single), (None, 0))(params, x1)
jac1_flat = torch.cat([j.reshape(j.shape[0], -1) for j in jac1], dim=1)
# Compute J(x2)
jac2 = vmap(jacrev(fnet_single), (None, 0))(params, x2)
jac2_flat = torch.cat([j.reshape(j.shape[0], -1) for j in jac2], dim=1)
# Initialize result matrix
num_x1 = jac1_flat.shape[0]
num_x2 = jac2_flat.shape[0]
result = torch.zeros((num_x1, num_x2), device=jac1_flat.device)
# Block-wise computation to save memory
for i in range(0, num_x1, block_size):
for j in range(0, num_x2, block_size):
block_jac1 = jac1_flat[i:i+block_size]
block_jac2 = jac2_flat[j:j+block_size]
result[i:i+block_size, j:j+block_size] = block_jac1 @ block_jac2.T
return result
そしてempirical_ntk_jacobian_contraction関数を用いてNTK行列を計算します.
上のempirical_ntk_jacobian_contraction関数を用いるには,モデルが関数化されている必要があります.そのためmake_functionalを用いてモデルを関数に変換します.
from functorch import make_functional
fnet, params = make_functional(model)
def fnet_single(params, x):
return fnet(params, x.unsqueeze(0)).squeeze(0)
ntk = empirical_ntk_jacobian_contraction(fnet_single, params, x_train, x_train)
今回は1次元の回帰問題のため,NTK行列はサンプル数$\times$サンプル数となっています.
print(ntk.size())
# torch.Size([128, 128])
2.2.Neural Tangent Kernelを用いたクラス予測
NTK行列が与えられた時に実際に予測を行う関数を定義します.
$\langle f(x') \rangle_{\theta_0}$は以下で計算することができます.
def predict_with_ntk(model, x_train, y_train, x_test, reg=1e-4):
# Create a function version of the model
fnet, params = make_functional(model)
def fnet_single(params, x):
return fnet(params, x.unsqueeze(0)).squeeze(0)
# Compute Gram matrix and its inverse
K_train_train = empirical_ntk_jacobian_contraction(fnet_single, params, x_train, x_train)
K_train_train += reg * torch.eye(K_train_train.shape[-1], device=device) # Regularization
K_train_train_inv = torch.inverse(K_train_train)
alpha = K_train_train_inv @ y_train.float()
preds_list = []
for i in range(0, len(x_test), 128):
x_test_batch = x_test[i:i+128]
K_train_test = empirical_ntk_jacobian_contraction(fnet_single, params, x_train, x_test_batch)
preds_batch = K_train_test.transpose(0, 1) @ alpha
preds_batch = (preds_batch > 0.5).float()
preds_list.append(preds_batch)
return torch.cat(preds_list, dim=0)
上の関数を用いて実際に予測を行うコードは下のとおりです.
test_preds_list = []
test_targets_list = []
with torch.no_grad():
for x_batch, y_batch in test_loader:
x_batch, y_batch = x_batch.to(device), y_batch.to(device)
y_train_float = y_train.float().unsqueeze(1) # y_trainをfloatに変換してサイズを調整
preds = predict_with_ntk(model, x_train, y_train_float, x_batch)
test_preds_list.append(preds)
test_targets_list.append(y_batch)
test_preds = torch.cat(test_preds_list, dim=0)
test_targets = torch.cat(test_targets_list, dim=0)
accuracy = (test_preds == test_targets.unsqueeze(1)).float().mean().item()
print(f'Test Accuracy: {accuracy * 100:.2f}%')
筆者の環境では
Test Accuracy: 78.85%
となりました.128サンプルでのCIFAR10の2クラス分類の正解率としては問題ない正解率が出ていると言えそうです.
終わりに
今回はPytorchを用いた実装を行いました.しかしNNGPやNTKなどの無限幅深層学習モデルの実験はJAXを用いて開発されたNeural Tangentsがよく用いられます.またPytorchにおいてNTKを効率的に計算するライブラリとしてASDLやTorchNTKなどの便利なライブラリも存在します.複雑なモデルを用いた実装を行う場合は,これらのライブラリを使うのが便利です.