0
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

Keras(Tensorflow)で用いられる様々な行列演算のイメージを実例で理解する

Last updated at Posted at 2020-10-25

#はじめに
Kerasのコードを読むと様々な行列演算に遭遇する。これらの演算の中身を知らないと読み進めることが非常に難しい。今回、私が読んだコードを中心に、Kerasによく出てくる行列演算を実例を元に確認したため共有する。

#環境
Kerasといっても、今回確認した環境は、tensorflowの一部であった時代の古いバージョンである。ただ、行列演算は今とそれほど変わらないと思う。

  • tensoflow 1.14.0

###サンプルを動作させる場合の注意点
本記事のソースコードは以下のインポートを前提としています。動作させる場合はコピペしておいてください。

import numpy as np
import warnings
import os
warnings.simplefilter('ignore')

import tensorflow as tf
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Dense, Input, Dropout, BatchNormalization

#確認方法
確認方法は、 Kerasでの処理をイメージできることを想定し、以下の通りとした。

  1. 演算を行うKerasのモデルを作成

  2. モデルのpredictメソッドに入力を与え、その結果を出力させることで、計算の内容を確認

#やってみよう

##concat
指定した軸に沿ってテンソルのリストを連結する。

ソース

    input1 = Input(shape=(2,))
    input2 = Input(shape=(2,))
    output = tf.concat(axis=1, values=[input1, input2])
    model = Model(inputs=[input1, input2], outputs=[output])
    print(model.summary())
    print(model.predict(x=[np.array([[10, 20], [20, 30]]), np.array([[20, 5], [30, 2]])]))

結果

[[10. 20. 20.  5.]
 [20. 30. 30.  2.]]

##stack
ランクRのテンソルのリストをランクR+1のテンソルに積み上げる。
concatとの大きな違いは、そのまま連結するのではなく、軸を1つ追加した上で連結する点であり、連結後も連結前の情報を取り出すことができるという点だろうか。

###ソース

    # stack
    input1 = Input(shape=(2,))
    input2 = Input(shape=(2,))
    output = tf.stack(axis=1, values=[input1, input2])
    model = Model(inputs=[input1, input2], outputs=[output])
    print(model.summary())
    print(model.predict(x=[np.array([[10, 20], [20, 30]]), np.array([[20, 5], [30, 2]])]))

###結果

[[[10. 20.]
  [20.  5.]]
 [[20. 30.]
  [30.  2.]]]

##expand_dims
添字"axis"でのサイズ1の次元を加える。

###ソース

    input1 = Input(shape=(1,))
    output = tf.expand_dims(input1, axis=1)
    model = Model(inputs=[input1], outputs=[output])
    print(model.summary())
    print(model.predict(x=[np.array([[10], [20], [30]])]))

結果

[[[10.]]
 [[20.]]
 [[30.]]]

squeeze

expand_dimsの逆で、テンソルから添字"axis"での1次元を除く。

ソース

    input1 = Input(shape=(1,1,))
    output = tf.squeeze(input1, axis=1)
    model = Model(inputs=[input1], outputs=[output])
    print(model.summary())
    print(model.predict(x=[np.array([[[10]], [[20]], [[30]]])]))

###結果

[[10.]
 [20.]
 [30.]]

reduce_max

テンソルの次元全体で要素の最大値を計算する。

###ソース
以下は 3x2 の行列に対し、0次元、1次元、2次元で演算してみた。0次元だと次元はそのままで結果がブロードキャストされ、1次元だと結果が全く同じであり、2次元で次数が減るというところが興味深い。

    input1 = Input(shape=(1,2))
    output = tf.reduce_max(input1, axis=0)
    model = Model(inputs=[input1], outputs=[output])
    print(model.summary())
    print(model.predict(x=[np.array([[[10, 20]], [[20, 5]], [[30, 4]]])]))

    input1 = Input(shape=(1,2))
    output = tf.reduce_max(input1, axis=1)
    model = Model(inputs=[input1], outputs=[output])
    print(model.summary())
    print(model.predict(x=[np.array([[[10, 20]], [[20, 5]], [[30, 4]]])]))

    input1 = Input(shape=(1,2))
    output = tf.reduce_max(input1, axis=2)
    model = Model(inputs=[input1], outputs=[output])
    print(model.summary())
    print(model.predict(x=[np.array([[[10, 20]], [[20, 5]], [[30, 4]]])]))

結果

[[30. 20.]
 [30. 20.]
 [30. 20.]]


[[10. 20.]
 [20.  5.]
 [30.  4.]]

[[20.]
 [20.]
 [30.]]

reduce_sum

テンソルの次元全体の要素の合計を計算する。ソースも同様に3x2 の行列に対し、0次元、1次元、2次元で演算してみた。reduce_maxと考え方は同じである。

ソース 

    input1 = Input(shape=(1,2))
    output = tf.reduce_sum(input1, axis=0)
    model = Model(inputs=[input1], outputs=[output])
    print(model.summary())
    print(model.predict(x=[np.array([[[10, 20]], [[20, 5]], [[30, 4]]])]))

    input1 = Input(shape=(1,2))
    output = tf.reduce_sum(input1, axis=1)
    model = Model(inputs=[input1], outputs=[output])
    print(model.summary())
    print(model.predict(x=[np.array([[[10, 20]], [[20, 5]], [[30, 4]]])]))

    input1 = Input(shape=(1,2))
    output = tf.reduce_sum(input1, axis=2)
    model = Model(inputs=[input1], outputs=[output])
    print(model.summary())
    print(model.predict(x=[np.array([[[10, 20]], [[20, 5]], [[30, 4]]])]))

結果

[[60. 29.]
 [60. 29.]
 [60. 29.]]

[[10. 20.]
 [20.  5.]
 [30.  4.]]

[[30.]
 [25.]
 [34.]]

matmul

"Multiplies matrix"の略で、行列aに行列bを乗算して、a * bを生成する。全結合層等によく使われる。

ソース

以下は1x2と2x1の行列を演算した結果、1x1の行列が生成される例である。

    input1 = Input(shape=(2,))
    input2 = Input(shape=(2,1))
    output = tf.matmul(input1, input2)
    model = Model(inputs=[input1, input2], outputs=[output])
    print(model.summary())
    print(model.predict(x=[np.array([[10, 20]]), np.array([[[1], [2]]])]))

###結果

[[[50.]]]

slice

テンソルに対し、開始位置および抽出サイズを指定し、テンソルの一部を抽出する。

###ソース

    input1 = Input(shape=(4,))
    input1_reshape = tf.reshape(input1, [4])
    input2 = Input(shape=(1), dtype=tf.int32)
    input2_reshape = tf.reshape(input2, [1])
    input3 = Input(shape=(1), dtype=tf.int32)
    input3_reshape = tf.reshape(input3, [1])
    output = tf.slice(input1_reshape, input2_reshape, input3_reshape)
    model = Model(inputs=[input1, input2, input3], outputs=[output])
    print(model.summary())
    print(model.predict(x=[np.array([[1, 8, 3, 4]]), np.array([[1]]), np.array([[1]])]))

###結果

[8.]

gather

テンソルに対し、インデックスを指定し、インデックスの要素を取得する。

###ソース
[1,8,3,4]というリストに対し、0番目、3番目のものを取り出したリストと、1番目、2番目のものを取り出したリストを要素とした行列を返している。入力のshapeが確定していないとエラーとなるため、無理やりreshapeしている。

    input1 = Input(shape=(4,))
    input1_reshape = tf.reshape(input1, [4])
    input2 = Input(shape=(2,2), dtype=tf.int32)
    output = tf.gather(input1_reshape, input2)
    model = Model(inputs=[input1, input2], outputs=[output])
    print(model.summary())
    print(model.predict(x=[np.array([[1, 8, 3, 4]]), np.array([[[0, 3],[1, 2]]])]))

###結果

[[[1. 4.]
  [8. 3.]]]

#おわりに

  • 今回は一部の演算のみの掲載ではあるが、多少、読解力があがったのではないかと期待する。
  • Kerasのモデルの入力に与えて計算させるだけでも、行列の次元数が合わない等のエラーにより、中々すんなりとはいかずかなり苦労したが、行列の理解にはおおいに役立った。
  • 特にInputのshape引数の値と、predcitに与える行列の形状を合わせることが重要となる
  • 今回のサンプルは、Kerasにおいて複数のInputをとるModelの作成、予測する場合のミニマムコードともなるので、今後複雑なモデルを作成する場合にも役立つと考える。

#参考

0
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?