LoginSignup
2
3

深層学習と深層展開

Last updated at Posted at 2023-11-02

はじめに

昨今注目されている深層学習の応用例として深層展開を紹介する。
深層展開は既存の反復アルゴリズムに対して、深層学習の技術を用いることでアルゴリズムの性能を向上させる手法である。

基礎知識

深層学習

深層学習ではニューラルネットワークを用いたものが一般的である。ニューラルネットワークに含まれるパラメータを学習することにより目標のタスクを達成する手法である。
数学的には入力$x$から出力$y$を得る関数と捉えられる。
引用:Stevens, E., Antiga, L.,, Viehmann, T. (2020). Deep Learning with PyTorch.

深層展開

深層展開は反復アルゴリズムを展開し$1$反復を$1$層とみなす。この層に含まれるパラメータを、確率的勾配降下法や誤差逆伝播法のような深層学習の学習技術を用いることで最適化を行いアルゴリズムの性能向上を目指す手法である。

深層展開の簡単な例

ここでは簡単例として最急降下法を考える。
最急降下法は関数の最小値を求める最適化手法の一つであり、更新式は以下のようになる。

x^{(t+1)} = x^{(t)} - \alpha \nabla f(x^{(K)})

この式は、単純に勾配の方向に点を動かしていくことを意味している。
関数

f(x) = (x-1)^2

のとき最小値は$0(x=1)$となる。
また一回微分は

f^{\prime}(x) = 2(x-1)

となる。
これを最急降下法を用いて求めると、$x^{(0)}=0$を初期値、$\alpha=0.1$とすると

\begin{align}
x^{(1)} & = x^{(0)} - 0.1 * f^{\prime}(x^{(0)}) , x^{(1)} = 0.2 \\
x^{(2)} & = x^{(1)} - 0.1 * f^{\prime}(x^{(1)}) , x^{(2)} = 0.36 \\
& \vdots \\
x^{(24)} & = x^{(23)} - 0.1 * f^{\prime}(x^{(23)}) , x^{(24)} = 0.99527 \\
\end{align}

と続いていき、勾配の大きさが$10^{-2}$未満を停止条件とすると$24$回必要であった。
下の図の赤い点は探索点を表しており、最小値に方向へ収束していることがわかる。(ソースコードは最下部)
GD.png
反復回数をできるだけ減らし、計算量を削減するためにアルゴリズム内のパラメータ$\alpha$の調整が必要である。
この$\alpha$を自動的に調整するために深層展開を用いる。
下の図は深層学習(左)と深層展開(右)の対応関係を示している。このように反復$1$回を$1$層とみなすことでアルゴリズムに含まれるパラメータの学習を可能にしている。

deepunfolding (2).png
実際に深層展開を適用してパラメータ$\alpha$を調節した最急降下法は$10^{-2}$を停止条件とすると$2$回で停止した。つまりパラメータの調整次第ではアルゴリズムの性能を最大限引き出すことが可能である。
実際に深層展開によって調節されたパラメータ$\alpha$の値は次の図のようになった。
Untitled.png

まとめ

深層展開の概要について説明した。深層展開は既存の数学的な知識を用いながら深層学習技術を適用することで、アルゴリズムの性能を向上させることが可能であることがわかっていただけたのではないでしょうか。反復アルゴリズムを活用する場面は数多くあるため、応用範囲は広いと考えられる。興味を持っていただいた方は関連文献を読んでいただければ幸いです。

関連文献

和田山正, モデルベース深層学習と深層展開, 森北出版, 2023
深層学習から深層展開の最新研究まで日本語でわかりやすくまとめられている。

A. Balatsoukas-Stimming and C. Studer, "Deep Unfolding for Communications Systems: A Survey and Some New Directions," 2019 IEEE International Workshop on Signal Processing Systems (SiPS), Nanjing, China, 2019, pp. 266-271, doi: 10.1109/SiPS47522.2019.9020494.
無線通信に深層展開を適用した手法についてのサーベイ論文
こちらの論文から各手法の詳細を辿ることができる。

M. Miyoshi, T. Nishimura, T. Sato, T. Ohgane, Y. Ogawa and J. Hagiwara,
"Parameter-Learned AMP for MIMO Signal Detection," 2022 IEEE VTS Asia Pacific Wireless Communications Symposium (IEEE VTS APWCS 2022), Th-5G1, Aug. 2022.

ソースコード(Python)

最急降下法

import numpy as np
import matplotlib.pyplot as plt

def func(x):
    return (x - 1)**2

def dfunc(x):
    df = 2*(x-1)
    df_abs = np.abs(df)
    return df, df_abs

fig = plt.figure()
ax = fig.add_subplot(1,1,1)

ax.set_xlabel("x", fontsize = 16)
ax.set_ylabel("y", fontsize = 16)
ax.set_xlim([0, 2])
ax.set_ylim([0, 1.0])
x = np.arange(-5,5, 0.1)
ax.plot(x, func(x), color = "gray", zorder = 1)

x = 0
alpha = 0.1
eps = 0.01
itr_max = 100

for itr in range(1, itr_max):
    x = x - alpha * dfunc(x)[0]
    print(x)
    ax.scatter(x, func(x), s = 10, color = "red", zorder = 2)
    if dfunc(x)[1] < eps:
        break

result = np.array([x, dfunc(x)[1]])
print("勾配降下回数",itr)

深層展開を適用した最急降下法

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import math
import matplotlib.pyplot as plt

init_val = 0.1 #学習可能ステップサイズパラメータの初期値
itr = 5 # 勾配法の反復数
batch_size = 50 # mini batch size

def func(x):
    return (x - 1)**2

def dfunc(x):
    df = 2*(x-1)
    return df

class DUGD(nn.Module):
    def __init__(self, num_itr):
        super(DUGD, self).__init__()
        self.alpha = nn.Parameter(init_val*torch.ones(num_itr)) #学習可能ステップサイズパラメータ
    def forward(self, num_itr, bs):
        s = (torch.rand(bs)*10-5) # ランダムな初期探索点
        for i in range(num_itr):
            s = s - self.alpha[i] * dfunc(s)
        return s

model = DUGD(itr)
opt   = optim.Adam(model.parameters(), lr=0.001)
loss_func = nn.MSELoss()
solution = torch.ones(batch_size)
for gen in range(itr):
  for i in range(10000):
      opt.zero_grad()
      x_hat = model(gen + 1, batch_size)
      loss  = loss_func(x_hat, solution)
      loss.backward()
      opt.step()
  print(loss)
alpha = model.state_dict()['alpha'].to('cpu').detach().numpy().copy()
print(alpha)

fig = plt.figure()
ax = fig.add_subplot(1,1,1)

ax.set_xlabel("x", fontsize = 16)
ax.set_ylabel("y", fontsize = 16)
ax.set_xlim([0, 2])
ax.set_ylim([0, 1.0])
x = np.arange(-5,5, 0.1)
ax.plot(x, func(x), color = "gray", zorder = 1)

x = 0
eps = 0.01

for i in range(1, itr):
    x = x - alpha[i] * dfunc(x)
    print(x)
    ax.scatter(x, func(x), s = 10, color = "red", zorder = 2)
    if np.abs(func(x)) < eps:
        break

print("勾配降下回数",i+1)
2
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
3