Writing a Residual Network with TFLearn

Posted at 2016-08-11

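What is a Residual Network?

The model that won ILSVRC 2015 (a worldwide general image recognition competition).
Compared with architectures such as VGG Net, it reportedly needs less computation, and accuracy tends to improve simply by making the network deeper.
See the following for details: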

Deep Residual Learning (ILSVRC2015 winner)
[Survey]Deep Residual Learning for Image Recognition
keras-resnet
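
A minimal sketch of the core idea, in plain Python rather than TFLearn: a residual block outputs x + F(x), so a block whose learned part F contributes nothing simply passes its input through. The functions `zero_fn` and `half_fn` are hypothetical stand-ins for the learned layers, not anything from the references above.

```python
def residual_block(x, residual_fn):
    """Return x + F(x), the basic residual connection (element-wise)."""
    fx = residual_fn(x)
    if len(fx) != len(x):
        raise ValueError("F(x) must have the same shape as x")
    return [xi + fi for xi, fi in zip(x, fx)]


# Hypothetical residual functions, for illustration only.
def zero_fn(x):
    return [0.0 for _ in x]      # F has learned "nothing"


def half_fn(x):
    return [0.5 * v for v in x]  # F adds half of the input


x = [1.0, -2.0, 3.0]
print(residual_block(x, zero_fn))  # [1.0, -2.0, 3.0]
print(residual_block(x, half_fn))  # [1.5, -3.0, 4.5]
```

The shape check is the reason the downsampling blocks later in this article need special handling: once F changes the size or channel count, x has to be reshaped before the addition.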

Installation

TensorFlow 0.9 or later is required (see TFLearn Installation).

$ pip install tflearn

To install the latest version:

$ pip install git+https://github.com/tflearn/tflearn.git

Residual Network

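TFLearn's layers already implement Residual Block and Residual Bottleneck, so all that is needed is to use them.

2016/8/13: Fixed a mistake in how I had written the Residual Bottleneck. An error occurs when downsample=True; I plan to fix the code once I understand the cause. (Looking at TFLearn's Residual Block implementation, when downsample=True it applies Strides=2 in the first Conv2D and then AvgPool2D at the end. Is that correct?)
2016/8/15: At first I had only looked at the Keras and TFLearn implementations, but comparing several implementations in the article "Deep Residual Learning(ResNet)の実装から比較するディープラーニングフレームワーク" (a comparison of deep learning frameworks through their ResNet implementations) revealed the cause.
When downsampling, there are apparently two ways to handle the original feature map:
1. Shrink the spatial size with KernelSize=1, Strides=2 pooling and make up the missing channels with zero padding
2. Change the shape with a KernelSize=1, Strides=2 convolution
The Keras implementation takes the latter approach and the TFLearn implementation the former, but at least in the TensorFlow version I use (0.9.0), AveragePooling is subject to a KernelSize >= Strides constraint, which caused the error. I set KernelSize = Strides for AveragePooling and implemented both approaches, as shown below, so that they can be switched. The implementation introduced as the TensorFlow one used MaxPooling with KernelSize = Strides.
That said, both 1 and 2 seem to throw information away rather crudely. Since this works anyway, perhaps each block learns to reduce the size little by little?
2016/8/20: The TensorFlow implementation used MaxPooling, and that seems closer to the original intent, so I switched to MaxPooling and ran various experiments. I then upgraded to TensorFlow 0.10.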
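
To make the two shortcut variants above concrete, here is a shape-only sketch in plain Python (no TensorFlow needed). The function name `shortcut_after_downsample` is hypothetical; the padding arithmetic mirrors the `(out_channels - in_channels)//2` used in the code below, and the weight count assumes a 1x1 convolution with a bias term.

```python
def shortcut_after_downsample(in_shape, out_channels, strides, mode):
    """Return (shortcut shape [H, W, C], extra trainable weights) for one downsampling block."""
    h, w, c = in_shape
    # Both variants shrink H and W by the stride ('same' padding rounds up).
    h, w = -(-h // strides), -(-w // strides)
    if mode == 'padding':
        # Pooling keeps the channel count, so zeros are added on both sides
        # of the channel axis, as in the tf.pad call of the article's code.
        ch = (out_channels - c) // 2
        return [h, w, c + 2 * ch], 0
    if mode == 'shortcut':
        # A 1x1 convolution produces out_channels directly but adds weights
        # (c * out_channels kernel entries plus out_channels biases).
        return [h, w, out_channels], c * out_channels + out_channels
    raise ValueError("mode should be 'padding' or 'shortcut'")


print(shortcut_after_downsample([8, 8, 64], 128, 2, 'padding'))   # ([4, 4, 128], 0)
print(shortcut_after_downsample([8, 8, 64], 128, 2, 'shortcut'))  # ([4, 4, 128], 8320)
# With an odd channel difference the symmetric padding falls one channel short:
print(shortcut_after_downsample([8, 8, 3], 8, 2, 'padding'))      # ([4, 4, 7], 0)
```

The last call shows a limitation of the symmetric padding: it only reaches the target channel count when the difference is even, which holds for the 64 to 128 doubling used here.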

cifar10.py
# -*- coding: utf-8 -*-

from __future__ import (
    absolute_import,
    division,
    print_function
)

from six.moves import range

import tensorflow as tf
import tflearn
from tflearn.datasets import cifar10

nb_first_filter = 64
reputation_list = [8, 8]
# 'basic' => Residual Block, 'deep' => Residual Bottleneck
residual_mode = 'basic'
# 'padding' => Zero Padding, 'shortcut' => Projection Shortcut
downsample_mode = 'padding'

nb_class = 10

(X_train, y_train), (X_test, y_test) = cifar10.load_data()
y_train = tflearn.data_utils.to_categorical(y_train, 10)
y_test = tflearn.data_utils.to_categorical(y_test, 10)

# Real-time data preprocessing
img_prep = tflearn.ImagePreprocessing()
img_prep.add_featurewise_zero_center(per_channel=True)

# Real-time data augmentation
img_aug = tflearn.ImageAugmentation()
img_aug.add_random_flip_leftright()
img_aug.add_random_crop([32, 32], padding=4)

def residual_net(inputs, nb_first_filter, reputation_list, residual_mode='basic', activation='relu'):
    net = tflearn.conv_2d(inputs, nb_first_filter, 7, strides=2)
    net = tflearn.batch_normalization(net)
    net = tflearn.activation(net, activation)
    net = tflearn.max_pool_2d(net, 3, strides=2)
    for i, nb_shortcut in enumerate(reputation_list):
        if i == 0:
            if residual_mode == 'basic':
                net = tflearn.residual_block(net, nb_shortcut, nb_first_filter, activation=activation)
            elif residual_mode == 'deep':
                net = tflearn.residual_bottleneck(net, nb_shortcut, nb_first_filter, nb_first_filter * 4, activation=activation)
            else:
                raise Exception('Residual mode should be basic/deep')
        else:
            nb_filter = nb_first_filter * 2**i
            if residual_mode == 'basic':
                net = tflearn.residual_block(net, 1, nb_filter, activation=activation, downsample=True)
                net = tflearn.residual_block(net, nb_shortcut - 1, nb_filter, activation=activation)
            else:
                net = tflearn.residual_bottleneck(net, 1, nb_filter, nb_filter * 4, activation=activation, downsample=True)
                net = tflearn.residual_bottleneck(net, nb_shortcut - 1, nb_filter, nb_filter * 4, activation=activation)
    net = tflearn.batch_normalization(net)
    net = tflearn.activation(net, activation)
    net = tflearn.global_avg_pool(net)
    net = tflearn.fully_connected(net, nb_class, activation='softmax')
    return net

net = tflearn.input_data(shape=[None, 32, 32, 3], data_preprocessing=img_prep, data_augmentation=img_aug)
net = residual_net(net, nb_first_filter, reputation_list, residual_mode=residual_mode)
net = tflearn.regression(net, optimizer='adam', loss='categorical_crossentropy')
model = tflearn.DNN(net, checkpoint_path='model_resnet_cifar10',
                    max_checkpoints=10, tensorboard_verbose=0)

model.fit(X_train, y_train, n_epoch=200, validation_set=(X_test, y_test),
          snapshot_epoch=False, snapshot_step=500,
          show_metric=True, batch_size=128, shuffle=True,
          run_id='resnet_cifar10')
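# For TensorFlow 0.9
# NOTE: these overrides (and the two assignments after them) must run before
# residual_net() is called above; otherwise the stock TFLearn layers are used.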
def residual_block(incoming, nb_blocks, out_channels, downsample=False,
                   downsample_strides=2, activation='relu', batch_norm=True,
                   bias=True, weights_init='variance_scaling', bias_init='zeros',
                   regularizer='L2', weight_decay=0.0001, trainable=True,
                   restore=True, reuse=False, scope=None, name='ResidualBlock'):
    resnet = incoming
    in_channels = incoming.get_shape().as_list()[-1]

    with tf.variable_op_scope([incoming], scope, name, reuse=reuse) as scope:
        name = scope.name

        for i in range(nb_blocks):

            identity = resnet

            if not downsample:
                downsample_strides = 1

            if batch_norm:
                resnet = tflearn.batch_normalization(resnet)
            resnet = tflearn.activation(resnet, activation)

            resnet = tflearn.conv_2d(resnet, out_channels, 3, downsample_strides,
                                     'same', 'linear', bias, weights_init,
                                     bias_init, regularizer, weight_decay,
                                     trainable, restore)

            if downsample_mode == 'original':
                if downsample_strides > 1 or in_channels != out_channels:
                    identity = resnet
                    in_channels = out_channels

            if batch_norm:
                resnet = tflearn.batch_normalization(resnet)
            resnet = tflearn.activation(resnet, activation)

            resnet = tflearn.conv_2d(resnet, out_channels, 3, 1, 'same',
                                     'linear', bias, weights_init, bias_init,
                                     regularizer, weight_decay, trainable,
                                     restore)

            if downsample_mode == 'padding':
                # Downsampling
                if downsample_strides > 1:
                    identity = tflearn.max_pool_2d(identity, downsample_strides, downsample_strides)

                # Projection to new dimension
                if in_channels != out_channels:
                    ch = (out_channels - in_channels)//2
                    identity = tf.pad(identity, [[0, 0], [0, 0], [0, 0], [ch, ch]])
                    in_channels = out_channels
            elif downsample_mode == 'shortcut':
                if downsample_strides > 1 or in_channels != out_channels:
                    identity = tflearn.conv_2d(identity, out_channels, 1, downsample_strides, 'same')
                    in_channels = out_channels
            elif downsample_mode == 'original':
                pass
            else:
                raise Exception('Downsample mode should be padding/shortcut/original')

            resnet = resnet + identity

    return resnet

def residual_bottleneck(incoming, nb_blocks, bottleneck_size, out_channels,
                        downsample=False, downsample_strides=2, activation='relu',
                        batch_norm=True, bias=True, weights_init='variance_scaling',
                        bias_init='zeros', regularizer='L2', weight_decay=0.0001,
                        trainable=True, restore=True, reuse=False, scope=None,
                        name="ResidualBottleneck"):
    resnet = incoming
    in_channels = incoming.get_shape().as_list()[-1]

    with tf.variable_op_scope([incoming], scope, name, reuse=reuse) as scope:
        name = scope.name

        for i in range(nb_blocks):

            identity = resnet

            if not downsample:
                downsample_strides = 1

            if batch_norm:
                resnet = tflearn.batch_normalization(resnet)
            resnet = tflearn.activation(resnet, activation)

            resnet = tflearn.conv_2d(resnet, bottleneck_size, 1,
                                     downsample_strides, 'valid', 'linear', bias,
                                     weights_init, bias_init, regularizer,
                                     weight_decay, trainable, restore)

            if downsample_mode == 'original':
                if downsample_strides > 1 or in_channels != out_channels:
                    identity = resnet
                    in_channels = out_channels

            if batch_norm:
                resnet = tflearn.batch_normalization(resnet)
            resnet = tflearn.activation(resnet, activation)

            resnet = tflearn.conv_2d(resnet, bottleneck_size, 3, 1, 'same',
                                     'linear', bias, weights_init, bias_init,
                                     regularizer, weight_decay, trainable,
                                     restore)

            if batch_norm:
                resnet = tflearn.batch_normalization(resnet)
            resnet = tflearn.activation(resnet, activation)

            resnet = tflearn.conv_2d(resnet, out_channels, 1, 1, 'valid',
                                     'linear', bias, weights_init, bias_init,
                                     regularizer, weight_decay, trainable,
                                     restore)

            if downsample_mode == 'padding':
                # Downsampling
                if downsample_strides > 1:
                    identity = tflearn.max_pool_2d(identity, downsample_strides, downsample_strides)

                # Projection to new dimension
                if in_channels != out_channels:
                    ch = (out_channels - in_channels)//2
                    identity = tf.pad(identity, [[0, 0], [0, 0], [0, 0], [ch, ch]])
                    in_channels = out_channels
            elif downsample_mode == 'shortcut':
                if downsample_strides > 1 or in_channels != out_channels:
                    identity = tflearn.conv_2d(identity, out_channels, 1, downsample_strides, 'same')
                    in_channels = out_channels
            elif downsample_mode == 'original':
                pass
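            else:
                raise Exception('Downsample mode should be padding/shortcut/original')

            # Add the shortcut to the residual path, as residual_block does.
            # NOTE: with downsample_mode='original' the shapes only match
            # when bottleneck_size == out_channels.
            resnet = resnet + identity

    return resnet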

tflearn.residual_block = residual_block
tflearn.residual_bottleneck = residual_bottleneck

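Reference: residual_network_cifar10.py

Experiments

A few points were unclear to me, so I ran some tests.
The dataset is CIFAR10.

The baseline for comparison is the code above with:
residual_mode = 'basic'
downsample_mode = 'padding'

Result
Accuracy = 0.8516

Kernel size of the first convolution

I wanted to see whether 7 was simply found to work well empirically or whether the right value depends on the image size.
I changed the first kernel size so that it equals the size of the final feature map.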

    # net = tflearn.conv_2d(inputs, nb_first_filter, 7, strides=2)
    side = inputs.get_shape().as_list()[1]
    first_kernel = side // 2**(len(reputation_list) + 1)
    net = tflearn.conv_2d(inputs, nb_first_filter, first_kernel, strides=2)

Result
Accuracy = 0.8506

The accuracy is almost unchanged, so what is that 7 actually for?
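
The kernel size used in this experiment follows from how often the network halves the feature map. The sketch below traces the side length for the configuration in this article, assuming 'same' padding everywhere (so each stride-2 layer rounds the size up) and one downsampling step per entry of `reputation_list` after the first. The helper name `feature_map_sizes` is hypothetical.

```python
def feature_map_sizes(side, reputation_list):
    """Side length of the feature map after each stride-2 stage."""
    sizes = [side]
    side = -(-side // 2)   # first conv_2d with strides=2
    sizes.append(side)
    side = -(-side // 2)   # max_pool_2d with strides=2
    sizes.append(side)
    for _ in reputation_list[1:]:
        side = -(-side // 2)   # first block of each later stage downsamples
        sizes.append(side)
    return sizes


reputation_list = [8, 8]
sizes = feature_map_sizes(32, reputation_list)
first_kernel = 32 // 2 ** (len(reputation_list) + 1)

print(sizes)         # [32, 16, 8, 4]
print(first_kernel)  # 4
```

So for 32x32 CIFAR10 images the experiment swaps the 7x7 kernel for a 4x4 one, which matches the final 4x4 feature map.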

How the feature map is shrunk before the first Residual Block

I wanted to see why MaxPooling is used only here, and whether a convolution would work as well.

    # net = tflearn.max_pool_2d(net, 3, strides=2)
    net = tflearn.conv_2d(net, nb_first_filter, 3, strides=2)
    net = tflearn.batch_normalization(net)
    net = tflearn.activation(net, activation)

Result
Accuracy = 0.8634

There seems to be a small gain, but the amount of computation also increases slightly.
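
A rough estimate of that extra cost, as a plain-Python sketch: max pooling has no weights, while the replacement 3x3 convolution does. The numbers assume a 16x16x64 input to this layer (32x32 images after the first stride-2 convolution), 'same' padding, a bias term, and they ignore the added batch normalization. The helper name `conv2d_cost` is hypothetical.

```python
def conv2d_cost(in_side, in_ch, out_ch, kernel, strides, bias=True):
    """Return (output side, trainable weights, multiply-accumulates) of one conv layer."""
    out_side = -(-in_side // strides)              # 'same' padding rounds up
    weights = kernel * kernel * in_ch * out_ch
    if bias:
        weights += out_ch
    # One kernel-sized dot product per output position and output channel
    # (padded border positions are counted as well).
    macs = out_side * out_side * out_ch * kernel * kernel * in_ch
    return out_side, weights, macs


print(conv2d_cost(16, 64, 64, 3, 2))  # (8, 36928, 2359296)
```

Under these assumptions the swap adds about 37k weights and 2.4M multiply-accumulates per image, which is small next to the sixteen residual blocks that follow.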

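How downsampling is implemented

downsample_mode = 'shortcut'

According to the following, using a convolution for the downsample makes optimization harder, but the original method cannot be implemented straightforwardly (in TensorFlow 0.9), so I compare the convolution with MaxPooling.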

[Survey]Identity Mappings in Deep Residual Networks

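Result
Accuracy = 0.8385

Difference between Residual Block and Bottleneck

residual_mode = 'deep'

Result
Accuracy = 0.8333

Would the result change if the network were made much deeper?

What happens if the final layers are fully connected?

I replaced GlobalAveragePooling with two fully connected layers of 512 nodes each.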

    # net = tflearn.global_avg_pool(net)
    net = tflearn.fully_connected(net, 512)
    net = tflearn.batch_normalization(net)
    net = tflearn.activation(net, activation)
    net = tflearn.fully_connected(net, 512)
    net = tflearn.batch_normalization(net)
    net = tflearn.activation(net, activation)

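Result
Accuracy = 0.8412

Training accuracy was higher with this variant while test accuracy was lower than the baseline (0.8412 vs. 0.8516), so GlobalAveragePooling seems to be the better choice.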
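
One plausible reason for the higher training accuracy is the number of weights in the head. The sketch below compares the two heads, assuming the final feature map of the baseline is 4x4x128 (two stages with 64 and 128 filters on 32x32 inputs), that every fully connected layer has a bias, and ignoring batch normalization parameters. The helper name `dense_params` is hypothetical.

```python
def dense_params(n_in, n_out):
    """Weights plus biases of one fully connected layer."""
    return n_in * n_out + n_out


side, channels, nb_class = 4, 128, 10

# Global average pooling leaves one value per channel before the classifier.
gap_head = dense_params(channels, nb_class)

# Without pooling, the 4x4x128 map is flattened before the two 512-node layers.
fc_head = (dense_params(side * side * channels, 512)
           + dense_params(512, 512)
           + dense_params(512, nb_class))

print(gap_head)  # 1290
print(fc_head)   # 1316874
```

The fully connected head has roughly a thousand times more parameters, which would make it easier to fit the training set without generalizing better.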

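TensorFlow 0.10.0 (the original Residual Net)

Result
Accuracy = 0.8520

Stochastic Depth

I also wanted to implement Stochastic Depth, which has a dropout-like effect on Residual Nets, but it looked as though a lot would have to be implemented in raw TensorFlow, so I am skipping it this time.

2016/8/20: Implemented Stochastic Depth.

See the following for Stochastic Depth: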

[Survey]Deep Networks with Stochastic Depth
stochastic_depth_keras
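
As a quick picture of the 'linear' and 'uniform' modes used in the code below, this plain-Python sketch lists the per-block skip thresholds. It assumes, as in that code, that a block runs when a uniform random number in [0, 1) exceeds its threshold, so the chance of running a block is 1 minus the threshold. The helper name `skip_thresholds` is hypothetical.

```python
from __future__ import division


def skip_thresholds(reputation_list, stochastic_skip, mode='linear'):
    """Skip threshold for every residual block, in network order."""
    total = sum(reputation_list)
    thresholds = []
    for current in range(1, total + 1):
        if mode == 'linear':
            # Later blocks are skipped more often, up to stochastic_skip.
            thresholds.append(stochastic_skip * current / total)
        elif mode == 'uniform':
            thresholds.append(stochastic_skip)
        else:
            thresholds.append(0.0)   # 'none': never skip
    return thresholds


thresholds = skip_thresholds([8, 8], 0.5)
expected_blocks = sum(1 - t for t in thresholds)

print(thresholds[0])    # 0.03125
print(thresholds[-1])   # 0.5
print(expected_blocks)  # 11.75
```

With these settings about 11.75 of the 16 blocks are active on average per training step, and the last block is skipped half of the time.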

residual_network_with_stochastic_depth.py
# -*- coding: utf-8 -*-

from __future__ import (
    absolute_import,
    division,
    print_function
)

from six.moves import range

import tensorflow as tf
import tflearn
from tflearn.datasets import cifar10

nb_first_filter = 64
reputation_list = [8, 8]
# 'basic' => Residual Block, 'deep' => Residual Bottleneck
residual_mode = 'basic'
# 'linear' => Linear Decay, 'uniform' => Uniform, 'none' => None
stochastic_depth_mode = 'linear'
stochastic_skip = 0.5

nb_class = 10

(X_train, y_train), (X_test, y_test) = cifar10.load_data()
y_train = tflearn.data_utils.to_categorical(y_train, 10)
y_test = tflearn.data_utils.to_categorical(y_test, 10)

# Real-time data preprocessing
img_prep = tflearn.ImagePreprocessing()
img_prep.add_featurewise_zero_center(per_channel=True)

# Real-time data augmentation
img_aug = tflearn.ImageAugmentation()
img_aug.add_random_flip_leftright()
img_aug.add_random_crop([32, 32], padding=4)

def addBlock(incoming, bottleneck_size, out_channels, threshold=0.0,
             residual_mode='basic', downsample=False, downsample_strides=2,
             activation='relu', batch_norm=True, bias=True,
             weights_init='variance_scaling', bias_init='zeros',
             regularizer='L2', weight_decay=0.0001, trainable=True,
             restore=True, reuse=False, scope=None):
    if residual_mode == 'basic':
        residual_path = tflearn.residual_block(
                        incoming, 1, out_channels, downsample=downsample,
                        downsample_strides=downsample_strides,
                        activation=activation, batch_norm=batch_norm, bias=bias,
                        weights_init=weights_init, bias_init=bias_init,
                        regularizer=regularizer, weight_decay=weight_decay,
                        trainable=trainable, restore=restore, reuse=reuse,
                        scope=scope)
    else:
        residual_path = tflearn.residual_bottleneck(
                        incoming, 1, bottleneck_size, out_channels,
                        downsample=downsample,
                        downsample_strides=downsample_strides,
                        activation=activation, batch_norm=batch_norm, bias=bias,
                        weights_init=weights_init, bias_init=bias_init,
                        regularizer=regularizer, weight_decay=weight_decay,
                        trainable=trainable, restore=restore, reuse=reuse,
                        scope=scope)
    if downsample:
        in_channels = incoming.get_shape().as_list()[-1]

        with tf.variable_op_scope([incoming], scope, 'Downsample', 
                                  reuse=reuse) as scope:
            name = scope.name
            # Downsampling
            inference = tflearn.avg_pool_2d(incoming, 1, downsample_strides)
            # Projection to new dimension
            if in_channels != out_channels:
                ch = (out_channels - in_channels)//2
                inference = tf.pad(inference, [[0, 0], [0, 0], [0, 0], [ch, ch]])
            # Track activations.
            tf.add_to_collection(tf.GraphKeys.ACTIVATIONS, inference)
        # Add attributes to Tensor to easy access weights
        inference.scope = scope
        # Track output tensor.
        tf.add_to_collection(tf.GraphKeys.LAYER_TENSOR + '/' + name, inference)

        skip_path = inference
    else:
        skip_path = incoming

    p = tf.random_uniform([1])[0]

    return tf.cond(p > threshold, lambda: residual_path, lambda: skip_path)

def residual_net(inputs, nb_first_filter, reputation_list, downsample_strides=2,
                 activation='relu', batch_norm=True, bias=True,
                 weights_init='variance_scaling', bias_init='zeros',
                 regularizer='L2', weight_decay=0.0001, trainable=True,
                 restore=True, reuse=False, scope=None, residual_mode='basic',
                 stochastic_depth_mode='linear', stochastic_skip=0.0,
                 is_training=True):
    if not is_training:
        stochastic_depth_mode = 'none'
        stochastic_skip = 0.0
    side = inputs.get_shape().as_list()[1]
    first_kernel = side // 2**(len(reputation_list) + 1)

    net = tflearn.conv_2d(inputs, nb_first_filter, first_kernel, strides=2)
    net = tflearn.batch_normalization(net)
    net = tflearn.activation(net, activation)

    net = tflearn.max_pool_2d(net, 3, strides=2)

    block_total = sum(reputation_list)
    block_current = 0
    for i, nb_block in enumerate(reputation_list):
        nb_filter = nb_first_filter * 2**i

        assert stochastic_depth_mode in ['linear', 'uniform', 'none'], 'Stochastic depth mode should be linear/uniform/none'
        assert residual_mode in ['basic', 'deep'], 'Residual mode should be basic/deep'
        for j in range(nb_block):
            block_current += 1

            if stochastic_depth_mode == 'linear':
                threshold = stochastic_skip * block_current / block_total
            elif stochastic_depth_mode == 'uniform':
                threshold = stochastic_skip
            else:
                # 'none': never skip a block
                threshold = 0.0

            bottleneck_size = nb_filter
            if residual_mode == 'basic':
                out_channels = nb_filter
            else:
                out_channels = nb_filter * 4
            if i != 0 and j == 0:
                downsample = True
            else:
                downsample = False
            net = addBlock(net, bottleneck_size, out_channels,
                           downsample=downsample, threshold=threshold,
                           residual_mode=residual_mode,
                           downsample_strides=downsample_strides,
                           activation=activation, batch_norm=batch_norm,
                           bias=bias, weights_init=weights_init,
                           bias_init=bias_init, regularizer=regularizer,
                           weight_decay=weight_decay, trainable=trainable,
                           restore=restore, reuse=reuse, scope=scope)
    net = tflearn.batch_normalization(net)
    net = tflearn.activation(net, activation)
    net = tflearn.global_avg_pool(net)
    net = tflearn.fully_connected(net, nb_class, activation='softmax')
    return net

inputs = tflearn.input_data(shape=[None, 32, 32, 3],
                            data_preprocessing=img_prep,
                            data_augmentation=img_aug)
net = residual_net(inputs, nb_first_filter, reputation_list,
                   residual_mode=residual_mode,
                   stochastic_depth_mode=stochastic_depth_mode,
                   stochastic_skip=stochastic_skip)
net = tflearn.regression(net, optimizer='adam', loss='categorical_crossentropy')
model = tflearn.DNN(net, checkpoint_path='model_resnet_cifar10',
                    max_checkpoints=10, tensorboard_verbose=0)

model.fit(X_train, y_train, n_epoch=200, snapshot_epoch=False, snapshot_step=500,
          show_metric=True, batch_size=128, shuffle=True, run_id='resnet_cifar10')
residual_network_with_stochastic_depth_test.py
inputs = tflearn.input_data(shape=[None, 32, 32, 3])
net_test = residual_net(inputs, nb_first_filter, reputation_list,
                        residual_mode=residual_mode,
                        stochastic_depth_mode=stochastic_depth_mode,
                        is_training=False)
model_test = tflearn.DNN(net_test)  # build the model from the test graph
model_test.load('model_resnet_cifar10-xxxxx') # set the latest number
print(model_test.evaluate(X_test, y_test))

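tf.cond apparently executes both branches regardless of whether the condition is true, so it does not seem to reduce the amount of computation.
Here is_training has to differ between training and validation/test, but I could not find a way to do that within the same step or epoch, so the test is run separately.
Training appeared to be working, but partway through the behavior became strange (possibly a GPU memory problem), so I stopped it and tried testing with the model saved up to that point.
By default TFLearn seems to write everything into the same Graph, so building the test model in the same file changes the layer names and the weights cannot be loaded properly.
Using get_weights/set_weights or tf.Graph would probably solve this, but going that far loses the convenience of TFLearn, so running the test from a separate file is the quickest way.
Running the test code above raised an error and the evaluation could not be completed.
predict seemed to work without problems, so is this a TFLearn bug?
For now, I am giving up on this.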
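
On the tf.cond remark above: in graph-mode TensorFlow, operations created outside the two branch functions run no matter which branch is chosen, and in addBlock both paths are built before the lambdas merely return them. The following is a plain-Python analogy of that difference, not TensorFlow code; `cond`, `expensive_residual` and `cheap_skip` are hypothetical stand-ins, and whether TFLearn layers can be built inside the branch functions on TensorFlow 0.9/0.10 is not something I have verified.

```python
calls = []


def expensive_residual(x):
    calls.append('residual')
    return x + 1


def cheap_skip(x):
    calls.append('skip')
    return x


def cond(pred, true_fn, false_fn):
    """Only the selected branch function is called."""
    return true_fn() if pred else false_fn()


# Pattern A (like addBlock): both results exist before cond is reached.
residual_path = expensive_residual(10)
skip_path = cheap_skip(10)
out_a = cond(False, lambda: residual_path, lambda: skip_path)
calls_a = list(calls)

# Pattern B: the work is created inside the branch functions.
del calls[:]
out_b = cond(False, lambda: expensive_residual(10), lambda: cheap_skip(10))
calls_b = list(calls)

print(calls_a)  # ['residual', 'skip']
print(calls_b)  # ['skip']
```

Both patterns return the same value, but only pattern B avoids the work of the branch that was not taken.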

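Aside

Looking at the code, my impression is that the network gradually gathers the necessary information at particular positions (discarding what is unnecessary) and the 1x1 pooling then collects it.
If so, I wondered, as an amateur, whether extending the 1x1 pooling to the filter direction would sort the necessary filters from the unnecessary ones and save on the number of filters, so I tried it.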

residual_network_with_kernel_pooling.py
# -*- coding: utf-8 -*-

from __future__ import (
    absolute_import,
    division,
    print_function
)

from six.moves import range

import tensorflow as tf
import tflearn
from tflearn.datasets import cifar10

nb_filter = 64
reputation_list = [8, 8]
# 'basic' => Residual Block, 'deep' => Residual Bottleneck
residual_mode = 'basic'

nb_class = 10

(X_train, y_train), (X_test, y_test) = cifar10.load_data()
y_train = tflearn.data_utils.to_categorical(y_train, 10)
y_test = tflearn.data_utils.to_categorical(y_test, 10)

# Real-time data preprocessing
img_prep = tflearn.ImagePreprocessing()
img_prep.add_featurewise_zero_center(per_channel=True)

# Real-time data augmentation
img_aug = tflearn.ImageAugmentation()
img_aug.add_random_flip_leftright()
img_aug.add_random_crop([32, 32], padding=4)

def avg_1x1pool_2d_all(incoming, kernel_size, strides, padding='same',
                       name='Avg1x1Pool2DAll'):
    input_shape = tflearn.utils.get_incoming_shape(incoming)
    assert len(input_shape) == 4, "Incoming Tensor shape must be 4-D"

    if isinstance(kernel_size, int):
        kernel = [1, kernel_size, kernel_size, kernel_size]
    elif isinstance(kernel_size, (tuple, list)):
        # Build the kernel from kernel_size (not from strides).
        if len(kernel_size) == 3:
            kernel = [1, kernel_size[0], kernel_size[1], kernel_size[2]]
        elif len(kernel_size) == 4:
            kernel = [kernel_size[0], kernel_size[1], kernel_size[2], kernel_size[3]]
        else:
            raise Exception("kernel_size length error: " + str(len(kernel_size))
                            + ", only a length of 3 or 4 is supported.")
    else:
        raise Exception("kernel_size format error: " + str(type(kernel_size)))
    if isinstance(strides, int):
        strides = [1, strides, strides, strides]
    elif isinstance(strides, (tuple, list)):
        if len(strides) == 3:
            strides = [1, strides[0], strides[1], strides[2]]
        elif len(strides) == 4:
            strides = [strides[0], strides[1], strides[2], strides[3]]
        else:
            raise Exception("strides length error: " + str(len(strides))
                            + ", only a length of 3 or 4 is supported.")
    else:
        raise Exception("strides format error: " + str(type(strides)))
    padding = tflearn.utils.autoformat_padding(padding)

    with tf.name_scope(name) as scope:
        inference = tf.nn.avg_pool(incoming, kernel, strides, padding)

        # Track activations.
        tf.add_to_collection(tf.GraphKeys.ACTIVATIONS, inference)

    # Add attributes to Tensor to easy access weights
    inference.scope = scope

    # Track output tensor.
    tf.add_to_collection(tf.GraphKeys.LAYER_TENSOR + '/' + name, inference)

    return inference

def residual_block(incoming, nb_blocks, downsample=False, downsample_strides=2,
                   activation='relu', batch_norm=True, bias=True,
                   weights_init='variance_scaling', bias_init='zeros',
                   regularizer='L2', weight_decay=0.0001, trainable=True,
                   restore=True, reuse=False, scope=None, name="ResidualBlock"):
    resnet = incoming
    in_channels = incoming.get_shape().as_list()[-1]

    with tf.variable_op_scope([incoming], scope, name, reuse=reuse) as scope:
        name = scope.name

        for i in range(nb_blocks):

            identity = resnet

            if not downsample:
                downsample_strides = 1

            if batch_norm:
                resnet = tflearn.batch_normalization(resnet)
            resnet = tflearn.activation(resnet, activation)

            resnet = tflearn.conv_2d(resnet, in_channels, 3, downsample_strides,
                                     'same', 'linear', bias, weights_init,
                                     bias_init, regularizer, weight_decay,
                                     trainable, restore)

            if batch_norm:
                resnet = tflearn.batch_normalization(resnet)
            resnet = tflearn.activation(resnet, activation)

            resnet = tflearn.conv_2d(resnet, in_channels, 3, 1, 'same',
                                     'linear', bias, weights_init, bias_init,
                                     regularizer, weight_decay, trainable,
                                     restore)

            # Downsampling
            if downsample_strides > 1:
                identity = avg_1x1pool_2d_all(identity, 1, downsample_strides)

                # Projection to new dimension
                current_channels = identity.get_shape().as_list()[-1]
                ch = (in_channels - current_channels)//2
                identity = tf.pad(identity, [[0, 0], [0, 0], [0, 0], [ch, ch]])

            resnet = resnet + identity

        # Track activations.
        tf.add_to_collection(tf.GraphKeys.ACTIVATIONS, resnet)

    # Add attributes to Tensor to easy access weights.
    resnet.scope = scope

    # Track output tensor.
    tf.add_to_collection(tf.GraphKeys.LAYER_TENSOR + '/' + name, resnet)

    return resnet

def residual_bottleneck(incoming, nb_blocks, out_channels, downsample=False,
                        downsample_strides=2, activation='relu',
                        batch_norm=True, bias=True,
                        weights_init='variance_scaling', bias_init='zeros',
                        regularizer='L2', weight_decay=0.0001, trainable=True,
                        restore=True, reuse=False, scope=None,
                        name="ResidualBottleneck"):
    resnet = incoming
    in_channels = incoming.get_shape().as_list()[-1]

    with tf.variable_op_scope([incoming], scope, name, reuse=reuse) as scope:
        name = scope.name

        for i in range(nb_blocks):

            identity = resnet

            if not downsample:
                downsample_strides = 1

            if batch_norm:
                resnet = tflearn.batch_normalization(resnet)
            resnet = tflearn.activation(resnet, activation)

            resnet = tflearn.conv_2d(resnet, in_channels, 1, downsample_strides,
                                     'valid', 'linear', bias, weights_init,
                                     bias_init, regularizer, weight_decay,
                                     trainable, restore)

            if batch_norm:
                resnet = tflearn.batch_normalization(resnet)
            resnet = tflearn.activation(resnet, activation)

            resnet = tflearn.conv_2d(resnet, in_channels, 3, 1, 'same',
                                     'linear', bias, weights_init, bias_init,
                                     regularizer, weight_decay, trainable,
                                     restore)

            resnet = tflearn.conv_2d(resnet, out_channels, 1, 1, 'valid',
                                     activation, bias, weights_init, bias_init,
                                     regularizer, weight_decay, trainable,
                                     restore)

            # Downsampling
            if downsample_strides > 1:
                identity = avg_1x1pool_2d_all(identity, 1, downsample_strides)

            # Projection to new dimension
            current_channels = identity.get_shape().as_list()[-1]
            ch = (out_channels - current_channels)//2
            identity = tf.pad(identity, [[0, 0], [0, 0], [0, 0], [ch, ch]])

            resnet = resnet + identity

        # Track activations.
        tf.add_to_collection(tf.GraphKeys.ACTIVATIONS, resnet)

    # Add attributes to Tensor to easy access weights.
    resnet.scope = scope

    # Track output tensor.
    tf.add_to_collection(tf.GraphKeys.LAYER_TENSOR + '/' + name, resnet)

    return resnet

tflearn.residual_block = residual_block
tflearn.residual_bottleneck = residual_bottleneck

def residual_net(inputs, nb_filter, reputation_list, residual_mode='basic', activation='relu'):
    net = tflearn.conv_2d(inputs, nb_filter, 7, strides=2)
    net = tflearn.batch_normalization(net)
    net = tflearn.activation(net, activation)
    net = tflearn.max_pool_2d(net, 3, strides=2)

    assert residual_mode in ['basic', 'deep'], 'Residual mode should be basic/deep'
    for i, nb_block in enumerate(reputation_list):
        for j in range(nb_block):
            downsample = True if i != 0 and j == 0 else False

            if residual_mode == 'basic':
                net = tflearn.residual_block(net, 1, activation=activation,
                                             downsample=downsample)
            else:
                net = tflearn.residual_bottleneck(net, 1, nb_filter * 4,
                                                  activation=activation,
                                                  downsample=downsample)
    net = tflearn.batch_normalization(net)
    net = tflearn.activation(net, activation)
    net = tflearn.global_avg_pool(net)

    net = tflearn.fully_connected(net, nb_class, activation='softmax')
    return net

net = tflearn.input_data(shape=[None, 32, 32, 3], data_preprocessing=img_prep, data_augmentation=img_aug)
net = residual_net(net, nb_filter, reputation_list, residual_mode=residual_mode)
net = tflearn.regression(net, optimizer='adam', loss='categorical_crossentropy')
model = tflearn.DNN(net, checkpoint_path='model_resnet_cifar10',
                    max_checkpoints=10, tensorboard_verbose=0)

model.fit(X_train, y_train, n_epoch=200, validation_set=(X_test, y_test),
          snapshot_epoch=False, snapshot_step=500,
          show_metric=True, batch_size=128, shuffle=True,
          run_id='resnet_cifar10')

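Result

ValueError: Current implementation does not support strides in the batch and depth dimensions.

TensorFlow told me off for trying something it does not support...
I don't know much about it, but maybe something similar could be done with Maxout?
For now, I am giving up on this.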
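
For reference, pooling with kernel size 1 and stride 2 averages over single-element windows, so it amounts to keeping every second element. The plain-Python sketch below applies that subsampling along height, width and channels of a nested list, which is what the failed experiment was aiming for. The function name `pool_1x1_hwc` is hypothetical, and expressing this as a strided slice in TensorFlow is only a possibility I have not tested.

```python
def pool_1x1_hwc(image, strides):
    """Keep every `strides`-th element along H, W and C of a nested list [H][W][C]."""
    return [[pixel[::strides] for pixel in row[::strides]]
            for row in image[::strides]]


# Each value encodes its own position as 100*h + 10*w + c.
image = [[[100 * h + 10 * w + c for c in range(4)]
          for w in range(4)]
         for h in range(4)]

pooled = pool_1x1_hwc(image, 2)
print(pooled)  # [[[0, 2], [20, 22]], [[200, 202], [220, 222]]]
```

A 4x4x4 input becomes 2x2x2, with the odd-indexed rows, columns and channels simply dropped.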
