LoginSignup
4
5

More than 1 year has passed since last update.

torch.distributionsのバッチサイズについて

Last updated at Posted at 2021-06-30

深層強化学習の方策や生成モデルの作成で誤差逆伝播可能なサンプリングを行いたいことがあると思います.
torch.distributions.Distributionは誤差逆伝播可能なサンプリングを行うことができ,sampleが通常のサンプリングで,rsampleが誤差逆伝播可能なサンプリングを行うメソッドです.

torch.distributionsではサンプリング時にサイズを指定するので,入力のバッチサイズと出力のバッチサイズが見かけ上異なることがあります.そこで今回は,torch.distributionsのDistributionクラスが出力するバッチサイズについて確認します.

import torch
import matplotlib.pyplot as plt

以下の例では,平均・分散ともにバッチサイズは3であり,各確率変数の次元は2です.VAE(変分オートエンコーダ)では,平均と分散をバッチサイズごとに用意し,サンプリングそれぞれひとつづつ行い,結果として全体のサンプルサイズがバッチサイズと等しくなります.

one_mean1 = torch.tensor([1,1]).float()
one_mean2 = torch.tensor([2,2]).float()
one_mean3 = torch.tensor([3,3]).float()

mean = torch.stack([one_mean1, one_mean2, one_mean3], dim=0)
print("maen:",mean)

one_var1 = torch.tensor([0.1, 0.1]).float()
one_var2 = torch.tensor([0.1, 0.1]).float()
one_var3 = torch.tensor([0.1, 0.1]).float()

var = torch.stack([one_var1, one_var2, one_var3], dim=0)  # 対角成分のみ
print("var:", var)

var_matrix = var[:,:,None] * torch.eye(var.size(1))[None,:,:]
print("var_matrix:", var_matrix)
maen: tensor([[1., 1.],
        [2., 2.],
        [3., 3.]])
var: tensor([[0.1000, 0.1000],
        [0.1000, 0.1000],
        [0.1000, 0.1000]])
var_matrix: tensor([[[0.1000, 0.0000],
         [0.0000, 0.1000]],

        [[0.1000, 0.0000],
         [0.0000, 0.1000]],

        [[0.1000, 0.0000],
         [0.0000, 0.1000]]])

torch.distributionsのクラスはバッチを考慮してパラメータの指定ができます.

multi_normal = torch.distributions.MultivariateNormal(loc=mean, covariance_matrix=var_matrix)

バッチサイズについてはbatch_sizeでアクセスできます.

print(multi_normal.batch_shape)
torch.Size([3])

一方確率変数の次元についてはevent_shapeからアクセスできます.

print(multi_normal.event_shape)
torch.Size([2])

Distributionクラスのsampleメソッドはバッチを考慮してサンプリングできるので,バッチサイズ=すべてのサンプルサイズとなるVAEでは,サンプルサイズを指定しません.

samples = multi_normal.sample()
print(samples.shape)
torch.Size([3, 2])

一つの平均・分散ごとに複数のサンプルを指定するときは,sampleメソッドでサンプルサイズを指定できます.以下の例ではバッチ次元の三つそれぞれで100個のサンプルをサンプリングしています.

samples = multi_normal.sample(torch.Size([100]))
print(samples.shape)
torch.Size([100, 3, 2])

以下のように,二つ目がもとのバッチ次元になっていることが分かります.

fig, ax = plt.subplots()
for i in range(samples.size(1)):  # バッチ次元
    each_samples = samples[:,i,:]
    each_samples_array = each_samples.numpy()
    ax.scatter(each_samples_array[:,0], each_samples_array[:,1])

以上の例では,わざわざ単位行列を利用して分散を行列にする必要がありました.しかし一般にVAEでは分散行列は対角行列であるので,各確率変数の次元は独立です.このように独立性がある場合,torch.distributions.Independentを利用すれば,一次元の変数の確率分布を用いて多次元の変数を表すことができます.

one_mean1 = torch.tensor([1,1]).float()
one_mean2 = torch.tensor([2,2]).float()
one_mean3 = torch.tensor([3,3]).float()

mean = torch.stack([one_mean1, one_mean2, one_mean3], dim=0)
print("maen:",mean)

one_sigma1 = torch.sqrt(torch.tensor([0.1, 0.1]))
one_sigma2 = torch.sqrt(torch.tensor([0.1, 0.1]))
one_sigma3 = torch.sqrt(torch.tensor([0.1, 0.1]))

sigma = torch.stack([one_sigma1, one_sigma2, one_sigma3], dim=0)  # 対角成分のみ
print("sigma:", sigma)
maen: tensor([[1., 1.],
        [2., 2.],
        [3., 3.]])
sigma: tensor([[0.3162, 0.3162],
        [0.3162, 0.3162],
        [0.3162, 0.3162]])
normal = torch.distributions.Normal(loc=mean, scale=sigma)

このmean・varはバッチサイズが3であるが,Normalは一変数の確率分布であるので,バッチサイズが間違って認識されてしまいます.

print(normal.batch_shape)
torch.Size([3, 2])
print(normal.event_shape)
torch.Size([])

以下で独立な確率変数を定義できます.二つ目の引数はバッチ次元における新しい確率変数次元に対応するインデックスです.元のバッチ次元が(3,2)であるので,二つ目を確率変数次元と指定しています.

iid_normal = torch.distributions.Independent(normal, 1) 
print(iid_normal.batch_shape)
torch.Size([3])
print(iid_normal.event_shape)
torch.Size([2])

以下の例でもバッチ次元の三つそれぞれで100個のサンプルをサンプリングしています.

samples = iid_normal.sample(torch.Size([100]))
print(samples.shape)
torch.Size([100, 3, 2])
fig, ax = plt.subplots()
for i in range(samples.size(1)):  # バッチ次元
    each_samples = samples[:,i,:]
    each_samples_array = each_samples.numpy()
    ax.scatter(each_samples_array[:,0], each_samples_array[:,1])

4
5
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
5