
Which is correct for the input shape of Keras Conv1D: channels first or channels last?

Posted at 2017-04-01

Introduction

This is a memo discussing whether channels first or channels last is the correct input-shape order for Keras Conv1D, prompted by Hirofumi Yashima's post
「Keras 1d-CNN 1次元畳み込みニューラルネットワーク で 単変量回帰タスク を 行って成功した件」 (on successfully running a univariate regression task with a Keras 1-D convolutional neural network).

My environment

OS: Windows 10 Home
Python 3.5.3 (Anaconda, 64-bit)
tensorflow 1.0.0
Keras 2.0.2
.keras.json

{
    "image_dim_ordering": "th",
    "backend": "tensorflow",
    "epsilon": 1e-07,
    "floatx": "float32"
}

Because of this, note first that behavior may differ depending on whether the Keras environment is 1.x or 2.x.
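For reference (an assumption based on the Keras 2 documentation, not verified in this environment): Keras 2 renamed the keras.json key from the Keras-1-style image_dim_ordering to image_data_format, so a Keras-2-style config would typically look like this:

```json
{
    "image_data_format": "channels_last",
    "backend": "tensorflow",
    "epsilon": 1e-07,
    "floatx": "float32"
}
```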

Channels first or channels last?

As quoted below, I had commented that the order is channels first.

I see, thank you!
This only covers the TensorFlow backend, but the source code confirms exactly what you pointed out.

https://github.com/fchollet/keras/blob/master/keras/backend/tensorflow_backend.py#L2819

tensorflow_backend.py#L2819
def conv1d(x, kernel, strides=1, padding='valid',
           data_format=None, dilation_rate=1):
    if data_format == 'channels_last':
        tf_data_format = 'NWC'
    else:
        tf_data_format = 'NCW'

Since the default data_format is None, tf_data_format becomes NCW.
In other words, the shape would be (None, number of variables, 3 periods of data).
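The branch above can be retraced with a minimal sketch (plain Python mirroring the quoted snippet, not the actual Keras function):

```python
# Minimal retrace of the quoted branch in tensorflow_backend.py:
# when data_format is None (the default), the else branch selects 'NCW'
# (batch, channels, width), i.e. channels first.
def pick_tf_data_format(data_format=None):
    if data_format == 'channels_last':
        return 'NWC'
    else:
        return 'NCW'

print(pick_tf_data_format())                 # -> NCW
print(pick_tf_data_format('channels_last'))  # -> NWC
```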

However, looking further into the rest of Conv1D, data_format is set to 'channels_last' when Conv1D is initialized. (I don't fully understand how super() behaves here...)

convolutional.py#L233
class Conv1D(_Conv):
 .
 .
        super(Conv1D, self).__init__(
            rank=1,
            filters=filters,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
            data_format='channels_last',

Since data_format is set this way when Conv1D is instantiated, I believe data_format ends up as channels_last.
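The super() behavior can be illustrated with a stripped-down sketch (hypothetical stand-in classes, not the real Keras ones): the subclass hard-codes data_format when delegating to the parent constructor, so every instance ends up with 'channels_last' regardless of the parent's default.

```python
# Hypothetical stand-ins for Keras's _Conv / Conv1D, showing only how
# super().__init__ forwards the hard-coded data_format argument.
class _Conv(object):
    def __init__(self, rank, data_format=None):
        self.rank = rank
        self.data_format = data_format

class Conv1D(_Conv):
    def __init__(self):
        # Conv1D passes data_format='channels_last' explicitly,
        # overriding the parent's default of None.
        super(Conv1D, self).__init__(rank=1, data_format='channels_last')

layer = Conv1D()
print(layer.data_format)  # -> channels_last
```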

convolutional.py#L148
    def call(self, inputs):
        if self.rank == 1:
            outputs = K.conv1d(
                inputs,
                self.kernel,
                strides=self.strides[0],
                padding=self.padding,
                data_format=self.data_format,
                dilation_rate=self.dilation_rate[0])

In other words, I think the correct input-shape order is channels last: (None, data length, number of variables).

Also, in hayatoy's post 「TensorFlow (ディープラーニング)で為替(FX)の予測をしてみる CNN編」 (predicting forex (FX) rates with TensorFlow, CNN edition), a single variable with 24 days of data is fed in as shown below, which also suggests channels last.

CNN
input_shape = (24, 1)
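As a sketch of that layout (using numpy as a stand-in for Conv1D; the variable names are my own): with channels last, 24 days of a single variable are arranged as (timesteps, channels) = (24, 1), and a size-3 'same'-padded convolution keeps the time axis at length 24.

```python
import numpy as np

# 24 days of one variable, arranged channels last: (timesteps, channels).
series = np.arange(24.0).reshape(24, 1)

# 'same'-style 1-D convolution with a single size-3 kernel, via numpy;
# the time axis stays at length 24, as Conv1D with padding='same' would.
kernel = np.array([1.0, 1.0, 1.0])
out = np.convolve(series[:, 0], kernel, mode='same')
print(series.shape, out.shape)  # -> (24, 1) (24,)
```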

I also verified this with a program.

Program description

DATA_SAMPLES is the sample size, CHANNELS the number of variables, and DATA_LENGTH the length of each series.
Concretely, the program prepares 20 sets of 10 months of one-dimensional data and applies a 1-D CNN with 50 filters of size 3.
Because PADDING is set to 'same', the Conv1D output should keep the input length of 10 and produce the 50 filters.
The Conv1D output with channels first is

conv1d_1 (Conv1D) (None, 1, 50)

The Conv1D output with channels last is

conv1d_2 (Conv1D) (None, 10, 50)

With channels first, Conv1D collapses the data length to 1 and outputs the 50 filters.
With channels last, as expected, the data length is 10 and the number of filters is 50.

import numpy as np
from keras.models import Model, Sequential
from keras.layers import Dense, Activation, Flatten, Conv1D, Input
from keras.utils import np_utils
from keras.utils.vis_utils import plot_model
from keras.optimizers import Adam
import pydot_ng as pydot

Using TensorFlow backend.

DATA_SAMPLES = 20
CHANNELS = 1
DATA_LENGTH = 10
KERNEL_SIZE = 3
FILTERS = 50

data = np.arange(DATA_SAMPLES * CHANNELS * DATA_LENGTH) * 10
data_channels_first = data.reshape(DATA_SAMPLES, CHANNELS, DATA_LENGTH)
data_channels_last = data_channels_first.transpose(0,2,1)

seikai = np.arange(DATA_SAMPLES).reshape(-1, 1)
print(data_channels_first.shape, data_channels_last.shape)

#print(data_channels_first)
#print(data_channels_last)

(20, 1, 10) (20, 10, 1)

def make_model(shape):
    input_layer = Input(batch_shape=shape)
    conv_1d_output_layer = Conv1D(FILTERS, KERNEL_SIZE, padding='same')(input_layer)
    flatten_output_layer = Flatten()(conv_1d_output_layer)
    prediction_result = Dense(1)(flatten_output_layer)

    model = Model(inputs=input_layer, outputs=prediction_result)
    return model
model_channels_first = make_model(shape=(None, CHANNELS, DATA_LENGTH))
model_channels_last  = make_model(shape=(None, DATA_LENGTH, CHANNELS))
model_channels_first.compile(optimizer="adam", loss='mse', metrics=['accuracy'])
print('channels first model')
print(model_channels_first.summary())

print('\n\n')

model_channels_last.compile(optimizer="adam", loss='mse', metrics=['accuracy'])
print('channels last model')
print(model_channels_last.summary())

channels first model
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         (None, 1, 10)             0         
_________________________________________________________________
conv1d_1 (Conv1D)            (None, 1, 50)             1550      
_________________________________________________________________
flatten_1 (Flatten)          (None, 50)                0         
_________________________________________________________________
dense_1 (Dense)              (None, 1)                 51        
=================================================================
Total params: 1,601.0
Trainable params: 1,601.0
Non-trainable params: 0.0
_________________________________________________________________
None



channels last model
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_2 (InputLayer)         (None, 10, 1)             0         
_________________________________________________________________
conv1d_2 (Conv1D)            (None, 10, 50)            200       
_________________________________________________________________
flatten_2 (Flatten)          (None, 500)               0         
_________________________________________________________________
dense_2 (Dense)              (None, 1)                 501       
=================================================================
Total params: 701.0
Trainable params: 701.0
Non-trainable params: 0.0
_________________________________________________________________
None
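The Param # column in the two summaries can be reproduced by hand, which shows what each layout made Conv1D treat as channels (a back-of-the-envelope check, using the standard count kernel_size × in_channels × filters, plus filters for the biases):

```python
kernel_size, filters = 3, 50

# Input (None, 1, 10): read as channels last, the final axis (10) is
# taken as the channel count, so each kernel spans all 10 "channels".
params_first = kernel_size * 10 * filters + filters
print(params_first)  # -> 1550, matching conv1d_1

# Input (None, 10, 1): the single variable sits on the last axis,
# so there is just 1 channel per kernel.
params_last = kernel_size * 1 * filters + filters
print(params_last)   # -> 200, matching conv1d_2
```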
model_channels_first.fit(data_channels_first, seikai, batch_size=600, epochs=200)
Epoch 1/200
20/20 [==============================] - 1s - loss: 30414.6836 - acc: 0.0000e+00
Epoch 2/200
20/20 [==============================] - 0s - loss: 9868.0898 - acc: 0.0000e+00
Epoch 3/200
20/20 [==============================] - 0s - loss: 312.4087 - acc: 0.0500
Epoch 4/200
20/20 [==============================] - 0s - loss: 3033.8848 - acc: 0.0000e+00
Epoch 5/200
20/20 [==============================] - 0s - loss: 9813.7725 - acc: 0.0000e+00
Epoch 6/200
20/20 [==============================] - 0s - loss: 12278.8740 - acc: 0.0000e+00
Epoch 7/200
20/20 [==============================] - 0s - loss: 9633.3096 - acc: 0.0000e+00
Epoch 8/200
20/20 [==============================] - 0s - loss: 4913.5479 - acc: 0.0000e+00
Epoch 9/200
20/20 [==============================] - 0s - loss: 1164.7600 - acc: 0.0000e+00
Epoch 10/200
20/20 [==============================] - 0s - loss: 36.5087 - acc: 0.0500

(omitted)

Epoch 190/200
20/20 [==============================] - 0s - loss: 17.9575 - acc: 0.0500
Epoch 191/200
20/20 [==============================] - 0s - loss: 17.9285 - acc: 0.0500
Epoch 192/200
20/20 [==============================] - 0s - loss: 17.8995 - acc: 0.0500
Epoch 193/200
20/20 [==============================] - 0s - loss: 17.8703 - acc: 0.0500
Epoch 194/200
20/20 [==============================] - 0s - loss: 17.8411 - acc: 0.0500
Epoch 195/200
20/20 [==============================] - 0s - loss: 17.8119 - acc: 0.0500
Epoch 196/200
20/20 [==============================] - 0s - loss: 17.7828 - acc: 0.0500
Epoch 197/200
20/20 [==============================] - 0s - loss: 17.7535 - acc: 0.0500
Epoch 198/200
20/20 [==============================] - 0s - loss: 17.7242 - acc: 0.0500
Epoch 199/200
20/20 [==============================] - 0s - loss: 17.6949 - acc: 0.0500
Epoch 200/200
20/20 [==============================] - 0s - loss: 17.6655 - acc: 0.0500





<keras.callbacks.History at 0x1d1bf2cd240>
model_channels_last.fit(data_channels_last, seikai, batch_size=600, epochs=200)
Epoch 1/200
20/20 [==============================] - 0s - loss: 969.5367 - acc: 0.0000e+00
Epoch 2/200
20/20 [==============================] - 0s - loss: 2515.6497 - acc: 0.0000e+00
Epoch 3/200
20/20 [==============================] - 0s - loss: 575.2820 - acc: 0.0000e+00
Epoch 4/200
20/20 [==============================] - 0s - loss: 387.8024 - acc: 0.0500
Epoch 5/200
20/20 [==============================] - 0s - loss: 1396.0237 - acc: 0.0000e+00
Epoch 6/200
20/20 [==============================] - 0s - loss: 523.0975 - acc: 0.0500
Epoch 7/200
20/20 [==============================] - 0s - loss: 25.1633 - acc: 0.0000e+00
Epoch 8/200
20/20 [==============================] - 0s - loss: 682.5659 - acc: 0.0000e+00
Epoch 9/200
20/20 [==============================] - 0s - loss: 753.8232 - acc: 0.0000e+00
Epoch 10/200
20/20 [==============================] - 0s - loss: 151.1371 - acc: 0.0000e+00
Epoch 11/200
20/20 [==============================] - 0s - loss: 69.5525 - acc: 0.0500
(omitted)
Epoch 190/200
20/20 [==============================] - 0s - loss: 0.0113 - acc: 1.0000
Epoch 191/200
20/20 [==============================] - 0s - loss: 0.0112 - acc: 1.0000
Epoch 192/200
20/20 [==============================] - 0s - loss: 0.0110 - acc: 1.0000
Epoch 193/200
20/20 [==============================] - 0s - loss: 0.0109 - acc: 1.0000
Epoch 194/200
20/20 [==============================] - 0s - loss: 0.0107 - acc: 1.0000
Epoch 195/200
20/20 [==============================] - 0s - loss: 0.0106 - acc: 1.0000
Epoch 196/200
20/20 [==============================] - 0s - loss: 0.0104 - acc: 1.0000
Epoch 197/200
20/20 [==============================] - 0s - loss: 0.0103 - acc: 1.0000
Epoch 198/200
20/20 [==============================] - 0s - loss: 0.0101 - acc: 1.0000
Epoch 199/200
20/20 [==============================] - 0s - loss: 0.0100 - acc: 1.0000
Epoch 200/200
20/20 [==============================] - 0s - loss: 0.0098 - acc: 1.0000





<keras.callbacks.History at 0x1d33647eef0>
predicted_channels_first = model_channels_first.predict(data_channels_first) 
predicted_channels_last  = model_channels_last.predict(data_channels_last) 
predicted_channels_first
array([[  8.09207153],
       [  8.46234894],
       [  8.83262348],
       [  9.20289421],
       [  9.57317162],
       [  9.94345951],
       [ 10.31372738],
       [ 10.6840086 ],
       [ 11.05428028],
       [ 11.42456627],
       [ 11.79482174],
       [ 12.16512108],
       [ 12.53540611],
       [ 12.90565395],
       [ 13.2759409 ],
       [ 13.64623451],
       [ 14.01647854],
       [ 14.38675022],
       [ 14.7570734 ],
       [ 15.12734604]], dtype=float32)
predicted_channels_last
array([[  0.18978588],
       [  1.17511916],
       [  2.16048503],
       [  3.14581513],
       [  4.13118029],
       [  5.11650562],
       [  6.10181856],
       [  7.08713865],
       [  8.07252884],
       [  9.05784416],
       [ 10.04322052],
       [ 11.02850723],
       [ 12.01392174],
       [ 12.99925709],
       [ 13.98452473],
       [ 14.96988583],
       [ 15.95536232],
       [ 16.94049072],
       [ 17.92590523],
       [ 18.91131401]], dtype=float32)

In closing

Many thanks to Hirofumi Yashima, who always shares so much useful information.
