LoginSignup
2
2

More than 3 years have passed since last update.

TensorFlowのOptimizerの違いによる学習推移をアニメーションにした

Last updated at Posted at 2020-09-12

前回の記事で、Optimizerごとの学習推移の例をグラフにしました。

今回はアニメーションを作ってみました。

これです。

image.gif

損失関数の設定

今回の損失関数は $ (x^2+y^2-1)^2 + \frac{1}{8}(x + 1)^2 $ です。グラフにするとこんな感じです。

image.png

牛乳ビンの底をちょっと傾けたような形をしています。前回とほぼ同じ形の関数ですが、 $ x=-1, y=0 $ で最小値 $ 0 $ になるように少し変えました。

$ y=0 $ での断面はこんな感じです。

image.png

勾配降下法のOptimizer

シンプルな勾配降下法、モーメンタム、Adagrad、RMSprop、Adadelta、Adam、自作アルゴリズムを試しました。学習率はそれぞれのOptimizerで最適と思われる値を探しました。

[
  (tf.optimizers.SGD(learning_rate=0.1), "sgd"),
  (tf.optimizers.SGD(learning_rate=0.1, momentum=0.5), "momentum"),
  (tf.optimizers.Adagrad(learning_rate=2.0), "adagrad"),
  (tf.optimizers.RMSprop(learning_rate=0.005), "rmsprop"),
  (tf.optimizers.Adadelta(learning_rate=100), "adadelta"),
  (tf.optimizers.Adam(learning_rate=0.2), "adam"),
  (CustomOptimizer(learning_rate=0.1), "custom"),
]

グラフ

1000ステップ中の損失関数の値推移のグラフです。

image.png

青: シンプルな勾配降下法
橙: モーメンタム
緑: Adagrad
赤: RMSprop。初動が遅い
紫: Adadelta。振動してしまって解にたどり着けない
茶: Adam。最初振動しているが、解に近くなると動かなくなる
桃: 自作アルゴリズム。Adamと同じくらいの速さで解に近づき、その後も解に限りなく近づく

シンプルな勾配降下法

image.png

左はxy平面上での移動の様子です。右は損失関数の値推移です。

モーメンタム

image.png

シンプルな勾配降下法よりは収束が速いです。

Adagrad

image.png

RMSprop

image.png

初動が遅いのですが、振動せずにまっすぐに解に近づきます。

Adadelta

image.png

学習率を調整したのですが、収束しなかったです。

Adam

image.png

ボールが転げ落ちるように解に近づきます。谷で振動はします。

解にある程度近くなると動かなくなってしまうのは、式の分母が0になるのを防ぐための $ \epsilon $ があるためと思われます。

自作アルゴリズム

image.png

Adadeltaと違って限りなく解に近づいていきます。

谷を通り過ぎるとすぐに気がついて立ち止まるので振動はほとんどしません。そのかわり直進しやすいため円弧状の谷ではカーブを曲がれずに立ち止まるのを繰り返します。

2020/09/22追記: 自作アルゴリズムのソースコード → TensorFlowでOptimizerを自作する

処理速度と収束までのステップ数

  • time: 1000ステップの処理にかかった時間
  • counter1: 損失関数が0.0001を初めて下回るまでにかかったステップ数
  • counter2: 損失関数が0.0001を安定的に下回るまでにかかったステップ数
  • loss: 1000ステップ実行後の損失関数の値
sgd
time: 2.369345188140869
counter1: 731
counter2: 731
loss: 5.330099884304218e-05

momentum
time: 2.378025770187378
counter1: 361
counter2: 361
loss: 1.3026465239818208e-05

adagrad
time: 2.489086627960205
counter1: 1000
counter2: 1000
loss: 0.00011992067447863519

rmsprop
time: 3.6131269931793213
counter1: 522
counter2: 522
loss: 2.695291732379701e-05

adadelta
time: 2.5074684619903564
counter1: 1000
counter2: 1000
loss: 0.34417498111724854

adam
time: 2.8565642833709717
counter1: 46
counter2: 96
loss: 1.4833190107310656e-06

custom
time: 3.4037697315216064
counter1: 68
counter2: 72
loss: 1.9912960169676808e-12

counter1とcounter2が1000のOptimizerは1000ステップ処理しても0.0001に達しなかったことを示します。counter1とcounter2が1000未満で同じ値のOptimizerは振動せずに解に近づいていることを示します。

Pythonコード

Google Colaboratoryで実行しました。アニメーションを作るために最初にAPNGというパッケージをインストールします。

!pip install APNG
import time
import numpy as np
import matplotlib.pyplot as plt
import math
import tensorflow as tf
import matplotlib.patches as patches
from apng import APNG
import IPython

opts1 = [(tf.optimizers.SGD(learning_rate=lr), str(lr)) for lr in [0.3, 0.2, 0.1, 0.05]]
opts2 = [(tf.optimizers.SGD(learning_rate=lr, momentum=0.5), str(lr)) for lr in [0.3, 0.2, 0.1, 0.05]]
opts3 = [(tf.optimizers.Adagrad(learning_rate=lr), str(lr)) for lr in [3.0, 2.0, 1.0, 0.5]]
opts4 = [(tf.optimizers.RMSprop(learning_rate=lr), str(lr)) for lr in [0.01, 0.005, 0.003, 0.002]]
opts5 = [(tf.optimizers.Adadelta(learning_rate=lr), str(lr)) for lr in [200, 100, 50, 30]]
opts6 = [(tf.optimizers.Adam(learning_rate=lr), str(lr)) for lr in [0.5, 0.3, 0.2, 0.1]]
opts7 = [(CustomOptimizer(learning_rate=lr), str(lr)) for lr in [0.2, 0.1, 0.05, 0.03]]
opts8 = [
  (tf.optimizers.SGD(learning_rate=0.1), "sgd"),
  (tf.optimizers.SGD(learning_rate=0.1, momentum=0.5), "momentum"),
  (tf.optimizers.Adagrad(learning_rate=2.0), "adagrad"),
  (tf.optimizers.RMSprop(learning_rate=0.005), "rmsprop"),
  (tf.optimizers.Adadelta(learning_rate=100), "adadelta"),
  (tf.optimizers.Adam(learning_rate=0.2), "adam"),
  (CustomOptimizer(learning_rate=0.1), "custom"),
]
opts = opts8

k1x = 1.0
k1y = 1.0
k2 = 1.0

k1x2 = k1x * k1x
k1y2 = k1y * k1y

# 目的となる損失関数
def loss(i):
  x2 = k1x2 * x[i] * x[i]
  y2 = k1y2 * y[i] * y[i]
  r2 = (x2 + y2 - 1.0)
  x1 = k1x * x[i] + 1.0
  ret = r2 * r2 + 0.125 * x1 * x1
  return k2 * ret

thres = k2 * 0.0001

# 最適化する変数
x = []
y = []

# グラフにするための配列
xHistory = []
yHistory = []
lossHistory = []
calculationTime = []
convergenceCounter1 = []
convergenceCounter2 = []

maxLoopCount = 1000
maxLoopCountAnimation = 1000

x_ini = 0.1
y_ini = 2.0

for i in range(len(opts)):
  x.append(tf.Variable(x_ini))
  y.append(tf.Variable(y_ini))
  xHistory.append([])
  yHistory.append([])
  lossHistory.append([])
  convergenceCounter1.append(maxLoopCount)
  convergenceCounter2.append(0)
  start = time.time()
  for loopCount in range(maxLoopCount):
    l = float(loss(i))
    # グラフにするために記録
    xHistory[i].append(float(x[i]))
    yHistory[i].append(float(y[i]))
    lossHistory[i].append(l)
    if (math.isfinite(l) and l < thres and convergenceCounter1[i] >= maxLoopCount):
      convergenceCounter1[i] = loopCount
    if (not math.isfinite(l) or l >= thres):
      convergenceCounter2[i] = loopCount + 1

    # 最適化
    opts[i][0].minimize(lambda: loss(i), var_list = [x[i], y[i]])
  calculationTime.append(time.time() - start)

colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf']

# グラフ化1つ目
plt.rcParams['figure.figsize'] = (6.4, 4.8)
plt.ylim(-10.0, +2.0)
for i in range(len(opts)):
  plt.plot(range(maxLoopCount), np.log10(lossHistory[i]) - np.log10(k2), color=colors[i % len(colors)])
plt.show()

ths = np.linspace(-math.pi, math.pi, 100)
thsx = np.cos(ths) / k1x
thsy = np.sin(ths) / k1y

# グラフ化2つ目以降
plt.rcParams['figure.figsize'] = (16.0, 6.0)
for i in range(len(opts)):
  print(opts[i][1])
  print("time: " + str(calculationTime[i]))
  print("counter1: " + str(convergenceCounter1[i]))
  print("counter2: " + str(convergenceCounter2[i]))
  print("loss: " + str(lossHistory[i][-1]))
  fig, (ax1, ax2) = plt.subplots(ncols=2)
  ax1.set_xlim(-2 / k1x, +2 / k1x)
  ax1.set_ylim(-1.5 / k1y, +1.5 / k1y)
  ax1.plot(thsx, thsy, color="#aaaaaa")
  ax1.add_patch(patches.Ellipse(xy=(-1.0, 0.0), width=0.2 / k1x, height=0.2 / k1y, fc="#cccccc"))
  ax1.plot(xHistory[i], yHistory[i], color=colors[i % len(colors)])
  ax2.set_ylim(-10.0, +2.0)
  ax2.plot(range(maxLoopCount), np.log10(lossHistory[i]) - np.log10(k2), color=colors[i % len(colors)])
  plt.show()

# アニメーション
plt.rcParams['figure.figsize'] = (6.4, 9.6)
ax = plt.axes()
fnames = []
for loopCount in range(maxLoopCountAnimation):
  if not (loopCount <= 150 or loopCount % 3 == 0):
    continue
  plt.cla()
  plt.xlim(-1.5 / k1x, +1.5 / k1x)
  plt.ylim(-2.1 / k1y, +2.1 / k1y)
  plt.text(-1.5 / k1x, -2.1 / k1y, str(loopCount))
  plt.plot(thsx, thsy, color="#aaaaaa")
  ax.add_patch(patches.Ellipse(xy=(-1.0, 0.0), width=0.2 / k1x, height=0.2 / k1y, fc="#cccccc"))
  for i in range(len(opts)):
    if math.isfinite(xHistory[i][loopCount]) and math.isfinite(yHistory[i][loopCount]):
      plt.text(xHistory[i][loopCount], yHistory[i][loopCount], opts[i][1])
      ax.add_patch(patches.Ellipse(xy=(xHistory[i][loopCount], yHistory[i][loopCount]), width=0.05 / k1x, height=0.05 / k1y, fc=colors[i % len(colors)]))
  fname = str(loopCount) + ".png"
  plt.savefig(fname)
  fnames.append(fname)
  if loopCount < 10:
    fnames.append(fname)
  if loopCount == 0:
    fnames.append(fname)
    fnames.append(fname)
plt.cla()

APNG.from_files(fnames, delay=200).save("animation.png")
IPython.display.Image("animation.png")

ステップ150以降のアニメーションは3倍速にしています。

image.gif

リンク

関連する私の記事

2
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
2