0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

pythonでGANの勉強2 変分オートエンコーダ・畳み込みVAE

Posted at

pythonでGANの勉強をしていきたいと思います。

自分の勉強のメモとなります。
コードが見づらかったり、正しくない場合があるかもしれません。
実装については怪しい部分がありますので何か気づいたらご指摘いただければと思います。

前回はオートエンコーダについて学びました。

今回は、変分オートエンコーダについて勉強していきます。
「はじめてのディープラーニング」をもとにkerasやpytorchでも実装してみるという流れとなります。

変分オートエンコーダ (VAE)

変分オートエンコーダでは、潜在空間は学習された平均値と標準偏差を持つ正規分布として表現される。
まずエンコーダにより入力から平均ベクトル$\mu$と分散ベクトル$\sigma$を求める。
これらを基に潜在変数$z$が確率的にサンプリングされ、$z$からデコーダにより元のデータが再現される。
この$z$を調整することで連続的に変化するデータを生成できる。

潜在変数のサンプリング

潜在変数とは入力の特徴をエンコーダを使ってより低い次元に減らしたものである。
平均値$\mu$と標準偏差$\sigma$を出力し、これらを使った正規分布により潜在変数$z$をサンプリングする。

Reparametrization Trick

サンプリングする処理はそのままではバックプロパゲーションを適用できない。
そこで、Reparametrization Trickという方法では、平均値0,標準偏差1の正規分布からサンプリングされた$\epsilon$を使って以下のように潜在変数を定義する。

z=\mu+\epsilon\sigma

この式を利用することで、バックプロパゲーションが適用できるようになる。

誤差

「潜在変数がどれだけ発散しているか」(正則化項)$E_{reg}$と「出力が入力からどれだけずれているかを表す再構成誤差」$E_{rec}$を合わせてVAEの誤差とする。
$$
E=E_{rec}+E_{reg}
$$

再構成誤差

$$
E_{rec}=\frac{1}{h}\sum_{i=1}^h\sum_{j=1}^m(-x_{ij}\log y_{ij}-(1-x_{ij})\log (1-y_{ij}))
$$
$x_{ij}$はVAEの入力、$y_{ij}$はVAEの出力、$h$はバッチサイズ、$m$は入力層、出力層のニューロン数となる。
ここで、$\Sigma$を省略して以下の通りに表す。
$$
e_{rec}=-x\log y-(1-x)\log (1-y)
$$
この誤差は「交差エントロピー」と呼ばれる。

正則化項

$$
E_{rec}=\frac{1}{h}\sum_{i=1}^h\sum_{k=1}^n-\frac{1}{2}(1+\log\sigma_{ik}^2-\mu_{ik}^2-\sigma_{ik}^2)
$$
$h$はバッチサイズ、$n$は潜在変数の数、$\sigma_{ij}$は標準偏差、$\mu_{ij}$は平均値である。
$\Sigma$を省略して以下の通りに表す。
$$
e_{rec}=-\frac{1}{2}(1+\log\sigma^2-\mu^2-\sigma^2)
$$

実装

平均値、標準偏差を出力する層

標準偏差に関しては層の出力を標準偏差の2乗(分散)の対数を表すこととする。

class ParamsLayer(BaseLayer):
    def __init__(self, n_upper, n):
        self.w = np.random.randn(n_upper, n) * np.sqrt(2/n_upper)
        self.b = np.zeros(n)
    
    def forward(self, x):
        self.x = x
        u = np.dot(x, self.w) + self.b
        self.y = u
    
    def backward(self, grad_y):
        delta = grad_y
        
        self.grad_w = np.dot(self.x.T, delta)
        self.grad_b = np.sum(delta, axis=0)
        self.grad_x = np.dot(delta, self.w.T)

サンプリング層

順伝搬はReparametrization Trickの式
$$
z=\mu+\epsilon\sigma
$$
に基づいて行われる。これを書き換えると、
$$
z=\mu+\epsilon\exp\frac{\phi}{2}
$$
となる。
逆伝播を考えるため、誤差を以下の形で表す。

\begin{align}
E&=E_{rec}+E_{reg}\\
E_{rec}&=\sum_{i=1}^h\sum_{j=1}^m(-x_{ij}\log y_{ij}-(1-x_{ij})\log (1-y_{ij}))\\
E_{reg}&=\sum_{i=1}^h\sum_{j=1}^m-\frac{1}{2}(1+\log\sigma_{ik}^2-\mu_{ik}^2-\sigma_{ik}^2)\\
&=\sum_{i=1}^h\sum_{j=1}^m-\frac{1}{2}(1+\phi_{ik}-\mu_{ik}^2-\exp\phi_{ik})
\end{align}

$\mu$での微分は次のようになる。

\begin{align}
\frac{\partial E}{\partial \mu}&=\frac{\partial}{\partial \mu}(E_{rec}+E_{reg})\\
&=\frac{\partial E_{rec}}{\partial z}\frac{\partial z}{\partial \mu}+\frac{\partial E_{reg}}{\partial \mu}\\
&=\frac{\partial E_{rec}}{\partial z}+\mu
\end{align}

$\frac{\partial E_{rec}}{\partial z}$はデコーダからの逆伝播で得ることができる。
$\phi$による微分は次のようになる。

\begin{align}
\frac{\partial E}{\partial \phi}&=\frac{\partial}{\partial \phi}(E_{rec}+E_{reg})\\
&=\frac{\partial E_{rec}}{\partial z}\frac{\partial z}{\partial \phi}+\frac{\partial E_{reg}}{\partial \phi}\\
&=\frac{\partial E_{rec}}{\partial z}\frac{\epsilon}{2}\exp\frac{\phi}{2}-\frac{1}{2}(1-\exp\phi)
\end{align}
class LatentLayer:
    def forward(self, mu, log_var):
        self.mu = mu
        self.log_var = log_var
        
        self.epsilon = np.random.randn(*log_var.shape)
        self.z = mu + self.epsilon*np.exp(log_var/2)
    
    def backward(self, grad_z):
        self.grad_mu = grad_z + self.mu
        self.grad_log_var = grad_z*self.epsilon/2*np.exp(self.log_var/2)-0.5*(1-np.exp(self.log_var))

出力層

逆伝播では、

\begin{align}
\delta&=\frac{\partial E}{\partial u}\\
&=\frac{\partial E}{\partial y}\frac{\partial y}{\partial u}\\
&=\frac{\partial}{\partial y}(E_{rec}+E_{reg})y(1-y)\\
&=(-\frac{t}{y}+\frac{1-t}{1-y})y(1-y)\\
&=-t(1-y)+(1-t)y\\
&=y-t
\end{align}

と計算できる。

class BaseLayer:
    def update(self, eta):
        self.w -= eta * self.grad_w
        self.b -= eta * self.grad_b

class MiddleLayer(BaseLayer):
    def __init__(self, n_upper, n):
        self.w = np.random.randn(n_upper, n) * np.sqrt(2/n_upper)
        self.b = np.zeros(n)
    
    def forward(self, x):
        self.x = x
        self.u = np.dot(x, self.w) + self.b
        self.y = np.where(self.u <= 0, 0, self.u) # ReLU
    
    def backward(self,grad_y):
        delta = grad_y * np.where(self.u <=0 , 0, 1)
        
        self.grad_w = np.dot(self.x.T, delta)
        self.grad_b = np.sum(delta, axis=0)
        self.grad_x = np.dot(delta, self.w.T)

class OutputLayer(BaseLayer):
    def __init__(self, n_upper, n):
        self.w = np.random.randn(n_upper, n) * np.sqrt(2/n_upper)
        self.b = np.zeros(n)
    
    def forward(self, x):
        self.x = x
        u = np.dot(x, self.w) + self.b
        self.y = 1/(1+np.exp(-u))
    
    def backward(self, t):
        delta = self.y - t
        
        self.grad_w = np.dot(self.x.T, delta)
        self.grad_b = np.sum(delta, axis=0)
        self.grad_x = np.dot(delta, self.w.T)
def forward_propagation(x_mb):
    middle_layer_enc.forward(x_mb)
    mu_layer.forward(middle_layer_enc.y)
    log_var_layer.forward(middle_layer_enc.y)
    z_layer.forward(mu_layer.y, log_var_layer.y)

    middle_layer_dec.forward(z_layer.z)
    output_layer.forward(middle_layer_dec.y)

def backpropagation(t_mb):
    output_layer.backward(t_mb)
    middle_layer_dec.backward(output_layer.grad_x)
    
    z_layer.backward(middle_layer_dec.grad_x)
    log_var_layer.backward(z_layer.grad_log_var)
    mu_layer.backward(z_layer.grad_mu)
    middle_layer_enc.backward(mu_layer.grad_x+log_var_layer.grad_x)

def update_params():
    middle_layer_enc.update(eta)
    mu_layer.update(eta)
    log_var_layer.update(eta)
    
    middle_layer_dec.update(eta)
    output_layer.update(eta)

学習を実行します。

img_size = 8
n_in_out = img_size * img_size
n_mid = 16
n_z = 2

eta = 0.001
epochs = 201
batch_size = 32
interval = 20

digits_data = datasets.load_digits()
x_train = np.asarray(digits_data.data)
x_train /= 15
t_train = digits_data.target

middle_layer_enc = MiddleLayer(n_in_out, n_mid)
mu_layer = ParamsLayer(n_mid,  n_z)
log_var_layer = ParamsLayer(n_mid,  n_z)
z_layer = LatentLayer()

middle_layer_dec = MiddleLayer(n_z, n_mid)
output_layer = OutputLayer(n_mid, n_in_out)

def get_rec_error(y, t):
    eps = 1e-7
    return -np.sum(t*np.log(y+eps) + (1-t)*np.log(1-y+eps)) / len(y)

def get_reg_error(mu, log_var):
    return -np.sum(1 + log_var - mu**2 - np.exp(log_var)) / len(mu)

rec_error_record = []
reg_error_record = []
total_error_record = []
n_batch = len(x_train) // batch_size
for i in range(epochs):
    
    index_random = np.arange(len(x_train))
    np.random.shuffle(index_random)
    for j in range(n_batch):
        mb_index = index_random[j*batch_size : (j+1)*batch_size]
        x_mb = x_train[mb_index, :]
        
        forward_propagation(x_mb)
        backpropagation(x_mb)
        
        update_params()
    
    forward_propagation(x_train)
    
    rec_error = get_rec_error(output_layer.y, x_train)
    reg_error = get_reg_error(mu_layer.y, log_var_layer.y)
    total_error = rec_error + reg_error
    
    rec_error_record.append(rec_error)
    reg_error_record.append(reg_error)
    total_error_record.append(total_error)
    
    if i%interval == 0:
        print("Epoch:",i,
              "Rec_error:",rec_error,
              "Reg_error:",reg_error,
              "Total_error:",total_error)

plt.plot(range(1, len(rec_error_record)+1), rec_error_record, label="Rec_error")
plt.plot(range(1, len(reg_error_record)+1), reg_error_record, label="Reg_error")
plt.plot(range(1, len(total_error_record)+1), total_error_record, label="Total_error")

plt.legend()
plt.xlabel("Epochs")
plt.ylabel("Error")
plt.show()

image.png

今回は2つの潜在変数を用いているので、平面上にプロットを行います。

forward_propagation(x_train)

plt.figure(figsize=(8, 8))
for i in range(10):
    zt = z_layer.z[t_train==i]
    z_1 = zt[:, 0]
    z_2 = zt[:, 1]
    marker = "$"+str(i)+"$"
    plt.scatter(z_2.tolist(), z_1.tolist(), marker=marker, s=75)

plt.xlabel("z_2")
plt.ylabel("z_1")
plt.xlim(-3,3)
plt.ylim(-3,3)
plt.grid()
plt.show()

image.png

2つの潜在変数により同じ数字が固まってクラスタを形成していることが分かります。
次に潜在変数を連続的に変化させて出力の変化を確認してみます。

n_img = 16
img_size_spaced = img_size + 2

matrix_image = np.zeros((img_size_spaced*n_img,
                         img_size_spaced*n_img))

z_1 = np.linspace(3, -3, n_img)
z_2 = np.linspace(3, -3, n_img)

for i, z1 in enumerate(z_1):
    for j, z2 in enumerate(z_2):
        x = np.array([float(z1), float(z2)])
        middle_layer_dec.forward(x)
        output_layer.forward(middle_layer_dec.y)
        image = output_layer.y.reshape(img_size, img_size)
        top = i*img_size_spaced
        left = j*img_size_spaced
        matrix_image[top:top+img_size,
                     left:left+img_size] = image
    
plt.figure(figsize=(8,8))
plt.imshow(matrix_image.tolist(), cmap="Greys_r")
plt.tick_params(labelbottom=False, labelleft=False, bottom=False, left=False)

image.png

徐々に数字が変わっていくことが確認できます。

pytorch

まずはpytorchで実装を行います。
必要なライブラリのインポートをしておきます。

import torch
import torch.nn as nn
import torch.optim as optimizers
from torch.utils.data import Dataset, DataLoader, TensorDataset
import torch.nn.functional as F

データの準備をします。
使用したデータは前回のオートエンコーダのときと同じものです。

digits_data = datasets.load_digits()
x_train = np.asarray(digits_data.data)
x_train /= 15
x_train = x_train.reshape(-1, 1, 64)
y = digits_data.target

x_train = torch.tensor(x_train, dtype=torch.float32)
y_train = torch.tensor(y, dtype=torch.float64)

train = TensorDataset(x_train, y_train)
train_dataloader = DataLoader(train, batch_size=32, shuffle=True)

エンコーダとデコーダの実装します。
エンコーダでは平均と分散の2種類を出力するようにします。
デコーダは2つの潜在変数を受け取り入力と同じ形で出力します。

class Encoder(nn.Module):
    def __init__(self, device='cpu'):
        super().__init__()
        self.device = device
        self.l1 = nn.Linear(8*8, 16)
        self.l_mean = nn.Linear(16, 2)
        self.l_var = nn.Linear(16, 2)
    
    def forward(self, x):
        h = self.l1(x)
        h = torch.relu(h)
        mean = self.l_mean(h)
        var = self.l_var(h)
        var = F.softplus(var)
        
        return mean, var
    
class Decoder(nn.Module):
    def __init__(self, device='cpu'):
        super().__init__()
        self.device = device
        self.l1 = nn.Linear(2, 16)
        self.l2 = nn.Linear(16, 8*8)
    
    def forward(self, x):
        h = self.l1(x)
        h = torch.relu(h)
        h = self.l2(h)
        y = torch.sigmoid(h)
        
        return y

次にVAE本体の実装をします。
ここで、サンプリング層と損失関数の定義を行います。

class VAE(nn.Module):
    def __init__(self, device='cpu'):
        super().__init__()
        self.device = device
        self.encoder = Encoder(device=device)
        self.decoder = Decoder(device=device)
    
    def forward(self, x):
        mean, var = self.encoder(x)
        z = self.reparameterize(mean, var)
        y = self.decoder(z)
        return y, z
    
    def reparameterize(self, mean, var):
        eps = torch.randn(mean.size()).to(self.device)
        z = mean + torch.sqrt(var) * eps
        return z
    
    def lower_bound(self, x):
        mean, var = self.encoder(x)
        z = self.reparameterize(mean, var)
        y = self.decoder(z)
        
        reconst =  - torch.mean(torch.sum(x * torch.log(y)
                                       + (1 - x) * torch.log(1 - y),
                                       dim=1))
        kl = - 1/2 * torch.mean(torch.sum(1
                                          + torch.log(var)
                                          - mean**2
                                          - var, dim=1))
        
        L = reconst + kl
        
        return L

学習の設定を行います。

model = VAE(device=device).to(device)
criterion = model.lower_bound
optimizer = optimizers.Adam(model.parameters(), lr=0.001)

学習を実行します。

epochs = 200
train_loss_record = []
for epoch in range(epochs):
    train_loss = 0.
    for (x, _) in train_dataloader:
        x = x.to(device)
        model.train()
        loss = criterion(x)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    
    train_loss /= len(train_dataloader)
    train_loss_record.append(train_loss)
    
    if epoch%20 == 0:
        print("Epoch: {}, Loss: {:3f}".format(epoch+1, train_loss))

学習の様子を可視化します。

plt.plot(range(1, len(train_loss_record)+1), train_loss_record)
plt.xlabel("Epochs")
plt.ylabel("Error")
plt.show()

image.png

次に潜在変数を連続的に変化させて出力の確認をします。

model.eval()

n_img = 16
img_size_spaced = img_size + 2

matrix_image = np.zeros((img_size_spaced*n_img,
                         img_size_spaced*n_img))

z_1 = np.linspace(3, -3, n_img)
z_2 = np.linspace(3, -3, n_img)


for i, z1 in enumerate(z_1):
    for j, z2 in enumerate(z_2):
        x = torch.Tensor([float(z1), float(z2)])
        image = model.decoder(x).detach().numpy().reshape(img_size, img_size)
        top = i*img_size_spaced
        left = j*img_size_spaced
        matrix_image[top:top+img_size,
                     left:left+img_size] = image

plt.figure(figsize=(8,8))
plt.imshow(matrix_image.tolist(), cmap="Greys_r")
plt.tick_params(labelbottom=False, labelleft=False, bottom=False, left=False)

image.png

潜在変数をプロットしてうまく学習できているか確認します。

x_train = torch.tensor(x_train, dtype=torch.float32)

z = model.forward(x_train)[1]

plt.figure(figsize=(8, 8))
for i in range(10):
    zt = z[t_train==i]
    z_1 = zt[:, 0]
    z_2 = zt[:, 1]
    marker = "$"+str(i)+"$"
    plt.scatter(z_2.tolist(), z_1.tolist(), marker=marker, s=75)

plt.xlabel("z_2")
plt.ylabel("z_1")
plt.xlim(-3,3)
plt.ylim(-3,3)
plt.grid()
plt.show()

image.png

keras

同様に実装を行います。

まずは、必要なライブラリのインポートをします。

from keras import backend as K
from keras.layers import Lambda
from keras import backend, metrics
import tensorflow as tf

エンコーダとデコーダの実装をします。

def sampling(args):
    z_mean, z_log_var = args
    epsilon = K.random_normal(shape=(2,), mean=0., stddev=1.)
    return z_mean + K.sqrt(z_log_var) * epsilon

class Encoder(Model):
    def __init__(self):
        super().__init__()
        self.l1 = Dense(16, activation='relu')
        self.l_mean = Dense(2)
        self.l_var = Dense(2, activation='softplus')
    
    def call(self, x):
        h = self.l1(x)
        mean = self.l_mean(h)
        var = self.l_var(h)
        
        return mean, var

class Decoder(Model):
    def __init__(self):
        super().__init__()
        self.l1 = Dense(16, activation='relu')
        self.l2 = Dense(64, activation='sigmoid')
    
    def call(self, x):
        h = self.l1(x)
        h = self.l2(h)
        return h

VAE本体の実装をします。
誤差をadd_lossを使って定義しています。
「WARNING」が出てしまうので、あまりよくないかもしれません。

class VAE(Model):
    def __init__(self):
        self.is_placeholder = True
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()
        self.lambda1 = Lambda(lambda x: sampling(x), output_shape=(2,))
    
    def call(self, x):
        mean_var = self.encoder(x)
        self.mean = mean_var[0]
        self.var = mean_var[1]
        z = self.lambda1(mean_var)
        y = self.decoder(z)

        xent_loss = - K.sum(x * K.log(y) + (1 - x) * K.log(1 - y), axis=1)
        kl_loss = - K.sum(1 + K.log(self.var) - K.square(self.mean) - self.var, axis=1)
        vae_loss = K.mean(xent_loss) + 0.5*K.mean(kl_loss)

        self.add_loss(vae_loss)
        return y, z

モデルを定義します。

model = VAE()
model.compile(optimizer='adam')

学習を実行します。

digits_data = datasets.load_digits()
x_train = np.asarray(digits_data.data)
x_train /= x_train.max()

history = model.fit(x_train,
                    epochs=200,
                    batch_size=32,
                    shuffle=True)

学習の様子を可視化します。

plt.plot(range(1, len(history.history['loss'])+1), history.history['loss'])
plt.xlabel("Epochs")
plt.ylabel("Error")
plt.show()

image.png

潜在変数を連続的に変化させて出力の変化を確認します。

n_img = 16
img_size_spaced = img_size + 2

matrix_image = np.zeros((img_size_spaced*n_img,
                         img_size_spaced*n_img))

z_1 = np.linspace(3, -3, n_img)
z_2 = np.linspace(3, -3, n_img)


for i, z1 in enumerate(z_1):
    for j, z2 in enumerate(z_2):
        x = np.float64(np.array([[z1,z2]]))
        image = np.array(model.decoder(x)).reshape(img_size, img_size)
        top = i*img_size_spaced
        left = j*img_size_spaced
        matrix_image[top:top+img_size,
                     left:left+img_size] = image

plt.figure(figsize=(8,8))
plt.imshow(matrix_image.tolist(), cmap="Greys_r")
plt.tick_params(labelbottom=False, labelleft=False, bottom=False, left=False)

image.png

潜在変数をプロットして様子を見てみます。

#x_train = torch.tensor(x_train, dtype=torch.float32)

z = model(x_train)[1]

plt.figure(figsize=(8, 8))
for i in range(10):
    zt = np.array(z[t_train==i])
    z_1 = zt[:, 0]
    z_2 = zt[:, 1]
    marker = "$"+str(i)+"$"
    plt.scatter(z_2.tolist(), z_1.tolist(), marker=marker, s=75)

plt.xlabel("z_2")
plt.ylabel("z_1")
plt.xlim(-3,3)
plt.ylim(-3,3)
plt.grid()
plt.show()

image.png

畳み込みオートエンコーダ

今までは入力は画像をベクトル化したものを使用していましたが、
ここでは画像のまま入力できる畳み込み処理をいれた変分オートエンコーダの実装を行います。

実装の内容自体はさきほど実装した変分オートエンコーダに畳み込み処理を追加するだけです。
pytorchとkerasのみで実装を行います。

pytorch

エンコーダでは「Conv2d」を使ってチャンネルの増加と画像サイズの削減を行い、
デコーダでは「ConvTranspose2d」を使ってチャンネルの削減と画像サイズの増加を行い元のサイズに戻します。

class ConvEncoder(nn.Module):
    def __init__(self, device='cpu'):
        super().__init__()
        self.device = device
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1)
        self.l1 = nn.Linear(32*2*2, 16)
        self.l_mean = nn.Linear(16, 2)
        self.l_var = nn.Linear(16, 2)
    
    def forward(self, x):
        h = self.conv1(x)
        h = self.conv2(h)
        h = h.view(h.shape[0], -1)
        h = self.l1(h)
        h = torch.relu(h)
        mean = self.l_mean(h)
        var = self.l_var(h)
        var = F.softplus(var)
        
        return mean, var
    
class ConvDecoder(nn.Module):
    def __init__(self, device='cpu'):
        super().__init__()
        self.device = device
        self.l1 = nn.Linear(2, 16)
        self.l2 = nn.Linear(16, 32*2*2)
        self.conv1 = nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1)
        self.conv2 = nn.ConvTranspose2d(16, 1, kernel_size=4, stride=2, padding=1)
    
    def forward(self, x):
        h = self.l1(x)
        h = torch.relu(h)
        h = self.l2(h)
        h = torch.relu(h)
        h = h.view(-1, 32, 2, 2)
        h = self.conv1(h)
        h = torch.relu(h)
        h = self.conv2(h)
        y = torch.sigmoid(h)
        
        return y

VAE本体はほぼ変わりません。
損失を計算する際に、画像をベクトル化します。

class ConvVAE(nn.Module):
    def __init__(self, device='cpu'):
        super().__init__()
        self.device = device
        self.encoder = ConvEncoder(device=device)
        self.decoder = ConvDecoder(device=device)
    
    def forward(self, x):
        mean, var = self.encoder(x)
        z = self.reparameterize(mean, var)
        y = self.decoder(z)
        return y, z
    
    def reparameterize(self, mean, var):
        eps = torch.randn(mean.size()).to(self.device)
        z = mean + torch.sqrt(var) * eps
        return z
    
    def lower_bound(self, x):
        mean, var = self.encoder(x)
        z = self.reparameterize(mean, var)
        y = self.decoder(z)
        x = x.reshape(1,-1)
        y = y.reshape(1,-1)
        
        reconst =  - torch.mean(torch.sum(x * torch.log(y)
                                       + (1 - x) * torch.log(1 - y),
                                       dim=1))
        kl = - 1/2 * torch.mean(torch.sum(1
                                          + torch.log(var)
                                          - mean**2
                                          - var, dim=1))
        
        L = reconst + kl
        
        return L

データは画像のままで使用します。

digits_data = datasets.load_digits()
x_train = np.asarray(digits_data.data)
x_train /= x_train.max()
x_train = x_train.reshape(-1,1, 8,8)
y = digits_data.target

x_train = torch.tensor(x_train, dtype=torch.float32)
y_train = torch.tensor(y, dtype=torch.float64)

train = TensorDataset(x_train, y_train)
train_dataloader = DataLoader(train, batch_size=64, shuffle=True)

学習の設定をします。

device = None

model = ConvVAE(device=device).to(device)
criterion = model.lower_bound
optimizer = optimizers.Adam(model.parameters(), lr=0.001)

学習を実行します。

epochs = 300
train_loss_record = []
for epoch in range(epochs):
    train_loss = 0.
    for (x, _) in train_dataloader:
        x = x.to(device)
        model.train()
        loss = criterion(x)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    
    train_loss /= len(train_dataloader)
    train_loss_record.append(train_loss)
    
    if epoch%20 == 0:
        print("Epoch: {}, Loss: {:3f}".format(epoch+1, train_loss))

学習の様子を可視化します。

plt.plot(range(1, len(train_loss_record)+1), train_loss_record)
plt.xlabel("Epochs")
plt.ylabel("Error")
plt.show()

image.png

潜在変数を連続的に変化させたときの出力を確認します。

model.eval()

n_img = 16
img_size=8
img_size_spaced = img_size + 2

matrix_image = np.zeros((img_size_spaced*n_img,
                         img_size_spaced*n_img))

z_1 = np.linspace(3, -3, n_img)
z_2 = np.linspace(3, -3, n_img)


for i, z1 in enumerate(z_1):
    for j, z2 in enumerate(z_2):
        x = torch.Tensor([float(z1), float(z2)])
        image = model.decoder(x).detach().numpy().reshape(img_size, img_size)
        top = i*img_size_spaced
        left = j*img_size_spaced
        matrix_image[top:top+img_size,
                     left:left+img_size] = image

plt.figure(figsize=(8,8))
plt.imshow(matrix_image.tolist(), cmap="Greys_r")
plt.tick_params(labelbottom=False, labelleft=False, bottom=False, left=False)

image.png

潜在変数をプロットしたときの様子を確認します。

x_train = torch.tensor(x_train, dtype=torch.float32)

z = model.forward(x_train)[1]

plt.figure(figsize=(8, 8))
for i in range(10):
    zt = z[y==i]
    z_1 = zt[:, 0]
    z_2 = zt[:, 1]
    marker = "$"+str(i)+"$"
    plt.scatter(z_2.tolist(), z_1.tolist(), marker=marker, s=75)

plt.xlabel("z_2")
plt.ylabel("z_1")
plt.xlim(-3,3)
plt.ylim(-3,3)
plt.grid()
plt.show()

image.png

keras

kerasもほぼ同様です。

from keras.layers import Lambda, Conv2D, Conv2DTranspose, Flatten, Reshape, Dense, Input

エンコーダでは「Conv2D」、デコーダでは「Conv2DTranspose」を使って実装を行います。

def sampling(args):
    z_mean, z_log_var = args
    epsilon = K.random_normal(shape=(2,), mean=0., stddev=1.)
    return z_mean + K.sqrt(z_log_var) * epsilon

class ConvEncoder(Model):
    def __init__(self):
        super().__init__()
        self.conv1 = Conv2D(16, kernel_size=3, strides=2, padding='same')
        self.conv2 = Conv2D(32, kernel_size=3, strides=2, padding='same')
        self.flat = Flatten()
        self.l1 = Dense(16, activation='relu')
        self.l_mean = Dense(2)
        self.l_var = Dense(2, activation='softplus')
    
    def call(self, x):
        h = self.conv1(x)
        h = self.conv2(h)
        h = self.flat(h)
        h = self.l1(h)
        mean = self.l_mean(h)
        var = self.l_var(h)
        
        return mean, var

class ConvDecoder(Model):
    def __init__(self):
        super().__init__()
        self.l1 = Dense(16, activation='relu')
        self.l2 = Dense(32*2*2, activation='relu')
        self.reshape = Reshape((2, 2, 32))
        self.conv1 = Conv2DTranspose(16, kernel_size=3, strides=2, padding='same', activation='relu')
        self.conv2 = Conv2DTranspose(1, kernel_size=3, strides=2, padding='same', activation='sigmoid')
    
    def call(self, x):
        h = self.l1(x)
        h = self.l2(h)
        h = self.reshape(h)
        h = self.conv1(h)
        h = self.conv2(h)
        return h

本体の実装をします。
損失の計算をするときは画像をベクトル化します。

class ConvVAE(Model):
    def __init__(self):
        self.is_placeholder = True
        super().__init__()
        self.encoder = ConvEncoder()
        self.decoder = ConvDecoder()
        self.lambda1 = Lambda(lambda x: sampling(x), output_shape=(2,))
    
    def call(self, x):
        mean_var = self.encoder(x)
        self.mean = mean_var[0]
        self.var = mean_var[1]
        z = self.lambda1(mean_var)
        y = self.decoder(z)
        
        x_ = Reshape((-1, 1))(x)
        y_ = Reshape((-1, 1))(y)
        xent_loss = - K.sum(x_ * K.log(y_) + (1 - x_) * K.log(1 - y_), axis=1)
        kl_loss = - K.sum(1 + K.log(self.var) - K.square(self.mean) - self.var, axis=1)
        vae_loss = K.mean(xent_loss) + 0.5*K.mean(kl_loss)

        self.add_loss(vae_loss)
        return y, z

モデルを定義します。

model = ConvVAE()
model.compile(optimizer='adam')

データは画像のまま入力し、学習を行います。

digits_data = datasets.load_digits()
x_train = np.asarray(digits_data.data)
x_train /= x_train.max()
x_train = x_train.reshape(-1,8,8,1)

history = model.fit(x_train,
                    epochs=200,
                    batch_size=32,
                    shuffle=True)

学習の様子を可視化します。

plt.plot(range(1, len(history.history['loss'])+1), history.history['loss'])
plt.xlabel("Epochs")
plt.ylabel("Error")
plt.show()

image.png

潜在変数を連続的に変化させたときの出力の変化を確認します。

n_img = 16
img_size_spaced = img_size + 2

matrix_image = np.zeros((img_size_spaced*n_img,
                         img_size_spaced*n_img))

z_1 = np.linspace(3, -3, n_img)
z_2 = np.linspace(3, -3, n_img)


for i, z1 in enumerate(z_1):
    for j, z2 in enumerate(z_2):
        x = np.float64(np.array([[z1,z2]]))
        image = np.array(model.decoder(x)).reshape(img_size, img_size)
        top = i*img_size_spaced
        left = j*img_size_spaced
        matrix_image[top:top+img_size,
                     left:left+img_size] = image

plt.figure(figsize=(8,8))
plt.imshow(matrix_image.tolist(), cmap="Greys_r")
plt.tick_params(labelbottom=False, labelleft=False, bottom=False, left=False)

image.png

最後に潜在変数をプロットしたときの分布の確認を行います。

z = model(x_train)[1]

plt.figure(figsize=(8, 8))
for i in range(10):
    zt = np.array(z[y==i])
    z_1 = zt[:, 0]
    z_2 = zt[:, 1]
    marker = "$"+str(i)+"$"
    plt.scatter(z_2.tolist(), z_1.tolist(), marker=marker, s=75)

plt.xlabel("z_2")
plt.ylabel("z_1")
plt.xlim(-3,3)
plt.ylim(-3,3)
plt.grid()
plt.show()

image.png

ところどこと淡しい部分はありますが、概ね期待する結果が得られました。

参考文献

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?