2
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

KerasのPermuteレイヤーの挙動確認

Last updated at Posted at 2021-08-17

こんにちは。つんあーです。

KerasのPermuteレイヤーなのですが、
「たぶん転置的な動きをしてくれるんだよなぁ」とふんわり理解しているものの、
実際の挙動をちゃんと把握していなかったので、簡単に試してみました。

さくっと構築

(モデルっていうようなものではないですけどね)

shape=[10, 2]のTensorを、Permuteレイヤーを通すことでshape=[2, 10]に変換する
という想定で定義しました。

import tensorflow as tf
from tensorflow.keras import layers, models

model = models.Sequential(name='Permute_sample')
model.add(layers.Permute(input_shape=(10, 2), dims=(2, 1)))

model.compile(optimizer='adam', loss='binary_crossentropy', metrics='acc')
model.summary()


# 出力
# Model: "Permute_sample"
# _________________________________________________________________
# Layer (type)                 Output Shape              Param #   
# =================================================================
# permute_1 (Permute)         (None, 2, 10)             0         
# =================================================================
# Total params: 0
# Trainable params: 0
# Non-trainable params: 0
# _________________________________________________________________

ご存知のことかと思いますが、暗黙に追加されているOutput ShapeNoneにはバッチサイズが入ります。

[10, 2]形式の入力データを用意します。
バッチサイズとして頭に1次元追加してshape=[1, 20, 2]となるようにしました。

array = np.arange(20).reshape(10, 2)
input = array.reshape((1, 10, 2))
print(input)


# 出力
# [[[ 0  1]
#   [ 2  3]
#   [ 4  5]
#   [ 6  7]
#   [ 8  9]
#   [10 11]
#   [12 13]
#   [14 15]
#   [16 17]
#   [18 19]]]

これをモデルに入力してみます。

result = model.predict(input)
print(result)

# [[[ 0.  2.  4.  6.  8. 10. 12. 14. 16. 18.]
#   [ 1.  3.  5.  7.  9. 11. 13. 15. 17. 19.]]]

2階と3階 が転置されたTensorが出力されました。
shapeとしては、バッチサイズを含め[1, 2, 10]となります。

結果的に、普通に転置した状態になりましたね。
(dtypeを特に指定しなかったので、floatになっている点だけ注意が必要かもしれませんね)

ちなみに

numpy.transposeの場合はこんな感じ。
明示的にしたかったのでわざわざaxes引数を付けていますが、指定なしでも同様の結果になります。

transposed = np.transpose(array, axes=(1, 0))
print(transposed)


# 出力
# [[ 0  2  4  6  8 10 12 14 16 18]
#  [ 1  3  5  7  9 11 13 15 17 19]]

numpy.reshapeだとこう。

reshaped = np.reshape(array, newshape=(2,10))
print(reshaped)


# 出力
# [[ 0  1  2  3  4  5  6  7  8  9]
#  [10 11 12 13 14 15 16 17 18 19]]

出力shape自体は[2, 10]で一致しますが、
中身は全然別物にですね。
Reshapeレイヤーでも、もちろんこの結果になります。

余談

Weblioによると、それぞれ意味は

transpose
=> 置き換える、入れ替える、転置する、言い直す、移調する、移項する、変換する

permute
=> (…を)変更する、交換する、(…を)並べ換える、順列する

とのことです。
大体同じことを言っていたのですね。

転置行列のことは transposed matrix とか言ったりすることもあるのでなんとなくtransposeの方が馴染みがあるような気がするのですが、
意味的にも挙動的にも同じ概念を指しているということがわかったので今後は安心して使えそうですね。

おしまい。

2
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?