0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

model.summary()でRNNの構成を図にしてみる

Posted at

はじめに

こんにちは、もちもちMAXです。
最近、改めてRNNについて調べていたのですが、モデルの全体像がいまいち分からなくてずっとモヤモヤしていました。雰囲気は分かるけど、どうしてもしっくりくる図と出会えていない、という状態でした。そこでKerasを使ってmodel.summary()でモデルの形状を見比べてみたところようやくスッキリしたので、記事にして残すことにしました。

RNNとは

RNNは再帰型ニューラルネットワーク(Recurrent Neural Network)の略称で、時系列データや連続したデータを処理するために設計されたニューラルネットワークの一種です。RNNの特徴として、隠れ層の出力が次の隠れ層の入力として使われる循環的な構造を持つ、というのがありましてそれを示した図が下記画像となっています。
image.png
引用:https://deepage.net/deep_learning/2017/05/23/recurrent-neural-networks.html

RNNよく分からん

先ほどの図自体は何も間違っていないですし、確かに隠れ層の出力が次の入力に使われています。ただ、こういった図を色々と眺めてみてもどうしてもネットワーク構造がイメージできませんでした。

  1. 隠れ層の次元が増えたらどうなる?
    今までMLP(多層パーセプトロン)を勉強していた時は次元数分のニューロンも図に含まれて書かれていることが多かく、RNNではその辺りが省略されて図示されていたのでよく分からなくなりました。
  2. 入力と出力は自由に次元を変えれる?
    例えば、入力は3次元でtime_step(後述)は5、出力は0,1の2値のモデルとかは作れるの?といったようにモデルのイメージができないので何ができるかも分からずでした。

model.summary()でネットワーク構造見てみよう

私の頭ではネットワーク構造の理解に限界があったので、実際にプログラムを動かして見てみることにしました。とりあえずmodel.summary()が動けば良いので早速作りましょう。

import numpy as np
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import SimpleRNN, Dense
from tensorflow.keras.optimizers import Adam

# データの準備
def generate_sine_wave(sample_size, frequency=0.1):
    x = np.arange(0, sample_size, 1)
    y = np.sin(2 * np.pi * frequency * x)
    return y.reshape(-1, 1)

sample_size = 100
time_steps = 3
total_size = sample_size + time_steps

sine_wave = generate_sine_wave(total_size, time_steps)

# データの前処理
X = []
y = []
for i in range(sample_size):
    X.append(sine_wave[i:i+time_steps])
    y.append(sine_wave[i+time_steps])

X = np.array(X)
y = np.array(y)

# モデルの構築
model = Sequential([
    SimpleRNN(1, activation='tanh', input_shape=(time_steps, 1), return_sequences=False),
    Dense(1)
])

model.compile(optimizer=Adam(learning_rate=0.001), loss='mse')

# モデル構造の出力
model.summary()

このコードでは、Kerasを使ってRNNを構築しています。今のところよく分からない状態ですがひとまずmodel.summary()の結果を見てみましょう。

image.png
simple_rnnのLayerのParamを見ると3になっています。Dense層の基本的なパラメータ数の計算方法は入力ユニット数 × 出力ユニット数 + バイアス(出力ユニット数)なので、ここから考えるとネットワーク構造はこんな感じでしょうか。
これは図で書くとこんな感じのネットワーク構造です。
image.png
入力ユニットとしてはx(t)と1つ前の隠れ層h(t-1)があるので、入力ユニット数は2となります。出力ユニットは隠れ層の次元数と同じなので1となります。以上で計算するとパラメータ数は3となります。ネットワーク構造の隠れ層の左側の線の数と隠れ層のユニット数を合計すればそれでパラメータ数が分かります。あと、流石に文章で書いても分かりにくいので図にしてみました。
image.png
ちなみにこのRNNはいわゆるシンプルなRNNで、3日分の気温を基に次の日の気温を予測する、というのをイメージしてもらえると分かりやすいかと思います。この、3日分というのがtime_stepにあたり3つの時刻分のデータを使うことを意味しています。1つの時刻の単位は日、分、秒などデータによって自由に設定することができます。

次元を変えてみよう

次は隠れ層の次元を増やしてみます。コードでいうと下記のSimpleRNN()の一つ目の引数を変えることになります。

# モデルの構築
model = Sequential([
    SimpleRNN(2, activation='tanh', input_shape=(time_steps, 1), return_sequences=False),  # 1 → 2に
    Dense(1)
])

次元を変えたら早速model.summary()でネットワーク構造を見てみましょう。
image.png
Output Shapeが(None, 2)になっておりParamが8になっています。先ほどのパラメータ数の計算式からネットワーク構造を考えると図のようになるでしょうか。
image.png
パラメータ数が合っているか確認するために隠れ層の左側の線の数と隠れ層のユニット数を足してみましょう。おそらく8になっていると思います。ということで隠れ層の次元が増えると先ほどの図のように入力、出力が繋がるようです。

では、もう一つのパターンとして入力の次元を増やしてみましょう。input_shapeを変えればそれでOKです。

# モデルの構築
model = Sequential([
    SimpleRNN(2, activation='tanh', input_shape=(time_steps, 2), return_sequences=False),  # input_shapeを1 → 2に
    Dense(1)
])

いつも通りmodel.summary()でネットワーク構造を見てみましょう。
image.png
ここまでくると予想できるかと思いますが、パラメータ数は10となっています。ネットワーク構造を図にしてみると分かりやすいですね。線の数と隠れ層のユニット数を数えるとちゃんと10になっているので問題なさそうです。
image.png
ちなみに、例でいうと過去3日分の温度、湿度から次の天気を予想する、などが入力の次元が2でtime_stepが3の状況にあたります。

さて、ここまでmodel.summary()を見ることでようやくRNNについて理解することができました。最後にreturn_sequencesについて簡単に説明して終わります。

return_sequencesについて

RNNなどの時系列データの学習の際、1時刻分の値のみを予測していく方式と複数時刻分を予測する方式の2つがあります。前者はreturn_sequences=Falseで後者はreturn_sequences=Trueとすれば実現できます。実際にやってみましょう。

# モデルの構築
model = Sequential([
    SimpleRNN(2, activation='tanh', input_shape=(time_steps, 2), return_sequences=True),
    Dense(1)
])

ついでにmodel.summary()を見てみましょう。
image.png
パラメータ数は変わっていませんが、Output Shapeに次元が追加されています。つまりネットワーク構造はこんなところでしょうか。
image.png
ネットワーク構造通り、パラメータ数は変わっていないですが出力が3つに増えています。一般的にはこのように複数時刻分を学習する方が精度が良くなる傾向にあるようです。きちんと過去の経緯を理解しながらの方が精度が良くなるのはイメージ通りですね。

さいごに

今回はRNNのネットワーク構造についてmodel.summary()を使いながら説明しました。自分自身きちんと理解できていませんでしたが、コードで確認しながら図に書いてみることでかなり理解できました。やはり手を動かすのが一番ですね。それでは!

最近、ブログと自作アプリのLPを立ち上げたので是非見てってください!

ブログ↓
https://mochinochikimchi.com

アプリLP↓
https://mochinochikimchi.com/applications/mealmotion/index.html

0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?