LoginSignup
0
0

Pytorchを用いたNTKおよびNNGPの実装方法 (1次元回帰問題)

Last updated at Posted at 2024-06-09

この記事では,Pytorchを用いてNeural Tangent Kernel (NTK) およびNeural Gaussian Process (NNGP) を実装し実験する方法について簡単に紹介します.なお今回は簡単のために2クラス分類にのみ適用可能な実装となっています.

0.データセットとモデルの定義

まずはPytorchを利用してCIFAR10の用意をします.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
train_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

CIFAR10から最初の2クラスのみを抽出したデータセットを作成した上でデータローダーを定義します.

def filter_classes(dataset, class_list):
    # Convert the dataset targets to a tensor
    targets_tensor = torch.tensor(dataset.targets)
    
    # Create a mask for the specified classes
    mask = torch.zeros_like(targets_tensor, dtype=torch.bool)
    for class_ in class_list:
        mask |= (targets_tensor == class_)
    
    # Apply the mask to filter the data and targets
    dataset.data = dataset.data[mask.numpy()]
    dataset.targets = targets_tensor[mask].numpy()
    
    return dataset

train_set = filter_classes(train_set, [0,1])
test_set = filter_classes(test_set, [0,1])
train_loader = torch.utils.data.DataLoader(train_set, batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)

学習には最初のバッチを用います.そのため今回はバッチサイズである128サンプルを利用した学習となります.

x_train, y_train = next(iter(train_loader))
x_train, y_train = x_train.to(device), y_train.to(device)

モデルは3層のMLPモデルを使用します.出力次元は1次元であり,クラスの予測確率を学習します.

class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(32*32*3, 256) 
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 1)

    def forward(self, x):
        x = x.view(-1, 32*32*3)  
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
model = SimpleNN().to(device)

なおNNGPおよびNTKは最終層が線形でありMSE Lossを用いている場合は解析解が存在しますが,Cross Entropy Lossなど出力層にSoftmax関数が含まれる場合には解析解が存在しません.

1. NNGPによるCIFAR10の2クラス分類

NNGPでは中間層はランダム初期値で固定し,最終層のみカーネル法を用いて重みを更新します.そのためL層のMLPモデルを

\boldsymbol{u}_l=\boldsymbol{W}_l \boldsymbol{h}_{l-1}+\boldsymbol{b}_l, \quad \boldsymbol{h}_l=\phi\left(\boldsymbol{u}_l\right) \quad(l=1, \ldots, L)

としたとき,NNGP解は

\boldsymbol{W}_{L}= \left(\boldsymbol{h}_{L-1} \boldsymbol{h}_{L-1}^{\top}+\rho \boldsymbol{I}\right)^{-1} \boldsymbol{h}_{L} \boldsymbol{y}_{\text{train}}, \quad \boldsymbol{W}_{l, 0}(l<L)

と表されます.
なおGaussian Processではガウス分布の分散を用いて未知データに対する予測の不確かさを表すことができますが,無限幅深層学習モデルの研究では基本的には平均値に着目します.

1.1 NNGP解による最終層の重みの更新

まずは最終層への入力$\boldsymbol{h}_{L-1}$を取得する必要があります.
今回はモデルの最終層であるmodel.fc3にforward_hookをかけることで,forward時に最終層への入力を保存します.他のモデルで実験を行う場合は,model.fc3の部分をそのモデルにおける最終層のモジュールに変えてください.

def forward_hook(module, in_data, out_data):
    module.h = in_data[0].detach()
hook = model.fc3.register_forward_hook(forward_hook)
model.zero_grad()
output = model(x_train.float())
hook.remove()

次にNNGP解を計算した上で,最終層の重みを更新します.
calculate_theta_starは最終層のモジュールと$\boldsymbol{y}_{\text{train}}$を入力として,最終層の重みを更新する関数となっています.

def calculate_theta_star(last_module, y_train_float, lambda_val=0.1):
    G = last_module.h
    lambda_I = lambda_val * torch.eye(G.size(1)).to(G.device)
    theta_star = torch.inverse(G.t() @ G + lambda_I) @ G.t() @ y_train_float
    with torch.no_grad():
        last_module.weight.copy_(theta_star.t())
        last_module.bias.copy_(torch.zeros_like(last_module.bias))

y_train_float = y_train.float().unsqueeze(1)
calculate_theta_star(model.fc3, y_train_float)

1.2 Neural Gaussian Processを用いたクラス予測

以下のコードによってCIFAR10の予測確率を求めることができます.ここではクラス1の予測確率が0.5を超えた場合はそのデータをクラス1に分類し,0.5より小さい場合はそのデータをクラス0に分類しています.

def predict_with_nngp(model, x_test):
    model.eval()
    with torch.no_grad():
        test_output = model(x_test.float())
        return (test_output > 0.5).float()

def calculate_accuracy(test_loader, model):
    test_preds_list = []
    test_targets_list = []
    with torch.no_grad():
        for x_batch, y_batch in test_loader:
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)
            preds = predict_with_nngp(model, x_batch)
            test_preds_list.append(preds)
            test_targets_list.append(y_batch)
    test_preds = torch.cat(test_preds_list, dim=0)
    test_targets = torch.cat(test_targets_list, dim=0)
    accuracy = (test_preds.squeeze() == test_targets).float().mean().item()
    return accuracy

accuracy = calculate_accuracy(test_loader, model)
print(f'Test Accuracy: {accuracy * 100:.2f}%')

上のコードを実行した結果,筆者の環境では

Test Accuracy: 71.35%

となりました.

2. NTKによるCIFAR10の2クラス分類

NTKではニューラルネットワークのランダム初期値近傍でのカーネル解を求めます.
ニューラルネットワーク$f(x)$のNTK行列$\Theta(x_1, x_2)$は以下で定義されます.

\Theta(x_1, x_2) = \nabla_\theta f(x_1) \nabla_\theta f(x_2)^{\top}

このNTK行列$\Theta(x_1, x_2)$を用いたカーネル解の出力は,

\langle f(x') \rangle_{\theta_0} = \Theta\left(x^{\prime}, x\right) \left( \Theta(x, x) + \rho \boldsymbol{I} \right)^{-1} \boldsymbol{y}

となります.なお$\langle \cdot \rangle_{\theta_0}$はランダム初期値についての平均を表しています.

2.1.Neural Tangent Kernelの計算

Functorchを用いてNeural Tangent Kernelを計算する関数を定義します.
メモリ消費量を削減するためには,Blockごとに分割した上で計算します.

from torch.func import vmap, jacrev
from functorch import make_functional_with_buffers

def empirical_ntk_jacobian_contraction(fnet_single, params, x1, x2, block_size=32):
    # Compute J(x1)
    jac1 = vmap(jacrev(fnet_single), (None, 0))(params, x1)
    jac1_flat = torch.cat([j.reshape(j.shape[0], -1) for j in jac1], dim=1)
    
    # Compute J(x2)
    jac2 = vmap(jacrev(fnet_single), (None, 0))(params, x2)
    jac2_flat = torch.cat([j.reshape(j.shape[0], -1) for j in jac2], dim=1)

    # Initialize result matrix
    num_x1 = jac1_flat.shape[0]
    num_x2 = jac2_flat.shape[0]
    result = torch.zeros((num_x1, num_x2), device=jac1_flat.device)

    # Block-wise computation to save memory
    for i in range(0, num_x1, block_size):
        for j in range(0, num_x2, block_size):
            block_jac1 = jac1_flat[i:i+block_size]
            block_jac2 = jac2_flat[j:j+block_size]
            result[i:i+block_size, j:j+block_size] = block_jac1 @ block_jac2.T
    return result

そしてempirical_ntk_jacobian_contraction関数を用いてNTK行列を計算します.
上のempirical_ntk_jacobian_contraction関数を用いるには,モデルが関数化されている必要があります.そのためmake_functionalを用いてモデルを関数に変換します.

from functorch import make_functional

fnet, params = make_functional(model)
def fnet_single(params, x):
    return fnet(params, x.unsqueeze(0)).squeeze(0)

ntk = empirical_ntk_jacobian_contraction(fnet_single, params, x_train, x_train)

今回は1次元の回帰問題のため,NTK行列はサンプル数$\times$サンプル数となっています.

print(ntk.size())
# torch.Size([128, 128])

2.2.Neural Tangent Kernelを用いたクラス予測

NTK行列が与えられた時に実際に予測を行う関数を定義します.
$\langle f(x') \rangle_{\theta_0}$は以下で計算することができます.

def predict_with_ntk(model, x_train, y_train, x_test, reg=1e-4):
    # Create a function version of the model
    fnet, params = make_functional(model)
    def fnet_single(params, x):
        return fnet(params, x.unsqueeze(0)).squeeze(0)
        
    # Compute Gram matrix and its inverse
    K_train_train = empirical_ntk_jacobian_contraction(fnet_single, params, x_train, x_train)
    K_train_train += reg * torch.eye(K_train_train.shape[-1], device=device)  # Regularization
    K_train_train_inv = torch.inverse(K_train_train)
    alpha = K_train_train_inv @ y_train.float()

    preds_list = []
    for i in range(0, len(x_test), 128):
        x_test_batch = x_test[i:i+128]
        K_train_test = empirical_ntk_jacobian_contraction(fnet_single, params, x_train, x_test_batch)
        preds_batch = K_train_test.transpose(0, 1) @ alpha
        preds_batch = (preds_batch > 0.5).float() 
        preds_list.append(preds_batch)
    
    return torch.cat(preds_list, dim=0)

上の関数を用いて実際に予測を行うコードは下のとおりです.

test_preds_list = []
test_targets_list = []
with torch.no_grad():
    for x_batch, y_batch in test_loader:
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)
        y_train_float = y_train.float().unsqueeze(1)  # y_trainをfloatに変換してサイズを調整
        preds = predict_with_ntk(model, x_train, y_train_float, x_batch)
        test_preds_list.append(preds)
        test_targets_list.append(y_batch)

test_preds = torch.cat(test_preds_list, dim=0)
test_targets = torch.cat(test_targets_list, dim=0)

accuracy = (test_preds == test_targets.unsqueeze(1)).float().mean().item()
print(f'Test Accuracy: {accuracy * 100:.2f}%')

筆者の環境では

Test Accuracy: 78.85%

となりました.128サンプルでのCIFAR10の2クラス分類の正解率としては問題ない正解率が出ていると言えそうです.

終わりに

今回はPytorchを用いた実装を行いました.しかしNNGPやNTKなどの無限幅深層学習モデルの実験はJAXを用いて開発されたNeural Tangentsがよく用いられます.またPytorchにおいてNTKを効率的に計算するライブラリとしてASDLTorchNTKなどの便利なライブラリも存在します.複雑なモデルを用いた実装を行う場合は,これらのライブラリを使うのが便利です.

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0