0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

ニューラルネットワークの予測の不確実性(stochastic variational inference・実装編)

Last updated at Posted at 2023-07-22

はじめに

ニューラルネットワークの予測の不確実性を算出する手法を検証します。
Jupyter Notebookは下記にあります。

概要

  • 連続値を予測する回帰のためのニューラルネットワークを構築
  • stochastic variational inference で予測の不確実性を算出

Stochastic variational inference

手法の概要は下記記事を参照してください。

Stochastic variational inferenceでは、各ミニバッチで近似分布$q(\theta)$からパラメータ$\theta$をサンプリングし、平均二乗誤差と事前分布と近似分布のKLダイバージェンスを最小化します。
$$
\mathcal{L} = \frac{1}{n} \sum_{i=1}^n (y_n - f(x_n;\theta))^2 + \alpha D_{KL}[q(\theta)||p(\theta)]
$$
ここで、$\alpha$は、2つの項のバランスを調整するパラメータです。

実装

1. ライブラリのインポート

必要なライブラリをインポートします。

import sys
import os

import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

!pip install bayesian-torch
from bayesian_torch.models.dnn_to_bnn import dnn_to_bnn, get_kl_loss

2. 実行環境の確認

使用するライブラリのバージョンや、GPU環境を確認します。
Google Colaboratoryで実行した際の例になります。

print('Python:', sys.version)
print('PyTorch:', torch.__version__)
!nvidia-smi
実行結果
Python: 3.10.6 (main, May 29 2023, 11:10:38) [GCC 11.3.0]
PyTorch: 2.0.1+cu118
Sat Jul 22 07:06:23 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.105.17   Driver Version: 525.105.17   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   41C    P8    11W /  70W |      0MiB / 15360MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+

3. データセットの作成

sinカーブに従うデータを作成します。
ただし、学習には[-1,1]の範囲のデータは使用しません。

def make_dataset(seed, plot=0, batch_size=64):
    np.random.seed(seed)

    x_true = np.linspace(-4, 4, 100)
    y_true = np.sin(x_true)

    x = np.concatenate([np.random.uniform(-4, -1, 100), np.random.uniform(1, 4, 100)])
    y = np.sin(x)

    # データをPyTorchのテンソルに変換
    x = torch.from_numpy(x).float().view(-1, 1)
    y = torch.from_numpy(y).float().view(-1, 1)
    x_true = torch.from_numpy(x_true).float().view(-1, 1)

    # グラフを描画
    if plot == 1:
        plt.plot(x_true, y_true)
        plt.scatter(x, y)
        plt.xlabel('x')
        plt.ylabel('y')
        plt.show()

    dataset = torch.utils.data.TensorDataset(x, y)
    data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
    return data_loader

_ = make_dataset(0, plot=1)

sin.png

ニューラルネットワークの定義

今回は3層の全結合ニューラルネットワークを用います。

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(1, 100)
        self.fc2 = nn.Linear(100, 100)
        self.fc3 = nn.Linear(100, 1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

4. 学習

Bayesian Torch を用いてstochastic variational inferenceの学習を行います。
flipoutとreoarameterizationを選択できますが、今回はflipoutを使用します。

batch_size=64
data_loader = make_dataset(0, batch_size=batch_size)

model = Net()
const_bnn_prior_parameters = {
        "prior_mu": 0.0,
        "prior_sigma": 1.0,
        "posterior_mu_init": 0.0,
        "posterior_rho_init": -3.0,
        "type": "Flipout",  # Flipout or Reparameterization,
        "moped_enable": False, 
        }
dnn_to_bnn(model, const_bnn_prior_parameters)

device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
model = model.to(device)

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

model.train()
num_epochs = 1000
for epoch in range(num_epochs):
    for inputs, targets in data_loader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        y_pred = model(inputs)

        # KLダイバージェンスの計算
        kl = get_kl_loss(model)

        # 損失の計算
        loss = criterion(y_pred, targets) + 1e-4 * kl / batch_size

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

5. 予測

予測の平均と標準偏差を描画します。

x_true = np.linspace(-4, 4, 100)
y_true = np.sin(x_true)
x_true = torch.from_numpy(x_true).float().view(-1, 1)

model.eval()
x_true = x_true.to(device)
y_preds = []
for _ in range(100):
    with torch.no_grad():
        y_pred = model(x_true)
    y_preds.append(y_pred.to('cpu').detach().numpy())

x_true = x_true.to('cpu')
y_preds = np.array(y_preds)
y_mean = np.mean(y_preds, axis=0)
y_std = np.sqrt(np.var(y_preds, axis=0) + 1e-4/2)


# グラフの描画
plt.figure(figsize=(8,4))

plt.subplot(121)
plt.plot(x_true, y_true, label='True Function')
plt.plot(x_true, y_mean, label='Mean Prediction')
plt.fill_between(x_true.flatten(), y_mean.flatten() - y_std.flatten(), y_mean.flatten() + y_std.flatten(), alpha=0.3, label='Uncertainty')
plt.xlabel('x')
plt.ylabel('y')
plt.legend()


plt.subplot(122)
plt.plot(x_true, y_std)
plt.xlabel('x')
plt.ylabel('Std')

plt.tight_layout()
plt.show()

plt.show()

svi.png

おわりに

今回の結果

予測の不確実性は、x最小値および最大値付近とデータが含まれない[-1,1]の範囲で大きくなっています。
データ数が少なく、予測が不確実と考えられる領域と、予測の標準偏差が大きい領域が一致しているため、想定通り予測の不確実性が算出できていると考えられます。

次にやること

予測の不確実性を算出する他の手法も検証したいと思います。

参考資料

  • D. P. Kingma and M. Welling, Auto-Encoding Variational Bayes, ICLR, 2014
  • Y. Wen et al., Flipout: Efficient Pseudo-Independent Weight Perturbations on Mini-batches, ICLR, 2018.
  • Bayesian Torch
    https://github.com/IntelLabs/bayesian-torch
0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?