36
46

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

TensorFlowで学習してモデルファイルを小さくしてコマンドラインアプリを作るシンプルな流れ

Last updated at Posted at 2017-04-14

たぶんこれが一番シンプルだと思います。

ソースは YusukeSuzuki/tensorflow_minimum_template にあります。

TensorFlowのチュートリアル MNIST For ML Beginners を終えてTensorBoardを使ったりなんかモデルを書いてみたりすることに慣れて、さてどうしようかというレベルの人向けです。

やることとやらないこと

この記事では以下のことをやります。

  • TensorFlowで何かを学習する。
  • 学習データであるチェックポイントファイルから最適化器のデータを取り除いて小さくする。
  • 小さくしたチェックポイントファイルをgraph_def ProtocolBuffersファイルに変換する。
  • graph_defからグラフを読み込んでコマンドラインアプリで利用する。

この記事では以下のことはやりません。

  • なんかすごいニューラルネットワークの作り方
    • プログラムの仕方であって理論の話ではないので
    • model.pyを目的のネットワークモデルに書き換えてその他のスクリプトの入出力の仕様を合わせれば同様の流れでコマンドラインアプリを作れると思います。
    • GANやRNNなど複雑なネットワークの場合は改修も複雑になると思います。
  • iOSとかAndroidとかで学習データを使う方法
  • setup.pyまで含めたワークフローの作り方。Python的にはここまでやるべきなのでしょうが主題がボケるので割愛。setup.pyに慣れていれば難しくないと思います。

参考リンク

モデルの定義

コンボリューション層を使った画像のオートエンコーダです。つまり画像を入力してコンボリューション層を通して元の画像と同じ出力をしようとします。オートエンコーダ単体にはあまり意味はなく初期学習や最近ではGANの一部として使われます。正解データを別に作成する作業が不要で学習でき、また学習の進み具合が目視で確認できるので例示に適しているため採用しました。

細かい解説はソースに全部あります。

# モデルのモジュールです。
# シンプルで学習の結果が見た目にわかりやすいオートエンコーダです。
# このモデル自体はなにかができるわけではありません。
# アプリケーション作成の説明用のものです。
import tensorflow as tf
import math

IMAGE_SUMMARY_MAX_OUTPUTS = 3

def inference(images):
    # 推論部分
    # フォーマットは'NHWC'を想定しています。
    kernel_size = 5
    out = images

    # あとで推論に関わる変数だけ取り出すために全体をスコープでくくります。
    with tf.variable_scope('inference'):
        # あとはレイヤーを重ねるだけ。
        # モデルの解説ではないので詳しいことは省略します。
        with tf.variable_scope('conv1'):
            prev_shape = out.get_shape().as_list()
            stddev = math.sqrt(2 / (kernel_size * kernel_size*prev_shape[3]))
            w = tf.get_variable(initializer=tf.truncated_normal_initializer(stddev=stddev),
                shape=[kernel_size, kernel_size, prev_shape[3], 32], name='weight')
            out = tf.nn.conv2d(out, w, padding='SAME', strides=[1,1,1,1])
            out = tf.nn.relu(out)
        with tf.variable_scope('downscale'):
            prev_shape = out.get_shape().as_list()
            stddev = math.sqrt(2 / (kernel_size * kernel_size*prev_shape[3]))
            w = tf.get_variable(initializer=tf.truncated_normal_initializer(stddev=stddev),
                shape=[kernel_size, kernel_size, prev_shape[3], 64], name='weight')
            out = tf.nn.conv2d(out, w, padding='SAME', strides=[1,2,2,1])
            out = tf.nn.relu(out)
        with tf.variable_scope('upscale'):
            reverse_shape = prev_shape
            prev_shape = out.get_shape().as_list()
            stddev = math.sqrt(2 / (kernel_size * kernel_size*prev_shape[3]))
            w = tf.get_variable(initializer=tf.truncated_normal_initializer(stddev=stddev),
                shape=[kernel_size, kernel_size, 32, prev_shape[3]], name='weight')
            out = tf.nn.conv2d_transpose(out, w, padding='SAME', strides=[1,2,2,1],
                output_shape=reverse_shape)
            out = tf.nn.relu(out)
        with tf.variable_scope('output'):
            prev_shape = out.get_shape().as_list()
            stddev = math.sqrt(2 / (kernel_size * kernel_size*prev_shape[3]))
            w = tf.get_variable(initializer=tf.truncated_normal_initializer(stddev=stddev),
                shape=[kernel_size, kernel_size, prev_shape[3], 3], name='weight')
            out = tf.nn.conv2d(out, w, padding='SAME', strides=[1,1,1,1])
            out = tf.nn.relu(out)

        # ログを出しておきます。
        # コレクションにデフォルト、意味的カテゴリ、型的カテゴリを追加しておくことで
        # ログ出しの切り分けがしやすくなると思います。
        # GPUではログ出しができないのでデバイスを指定すべきですがそれはtf.Sessionの
        # コンフィグで一括で行います。
        tf.summary.image('inference', out, max_outputs=IMAGE_SUMMARY_MAX_OUTPUTS,
            collections=[tf.GraphKeys.SUMMARIES, 'inference', 'image'])

    return out

def loss(label, inference):
    # ロスの定義です。自乗誤差基準です。
    with tf.name_scope('loss'):
        out = tf.squared_difference(label, inference)
        out_mean = tf.reduce_mean(out)
        tf.summary.scalar(
            'loss', out_mean, collections=[tf.GraphKeys.SUMMARIES, 'loss', 'scalar'])
    return out

def train(loss, global_step):
    # 学習の定義です。Adamオプティマイザ
    # 勾配の計算と適用を分割しています。勾配のヒストグラムをログに出せば
    # 学習中に勾配が消失していないか確認できます。
    # ただしログ出力が重くなりがちなので1000イテレーションに1回くらいが良いでしょう。
    # そこまでやらない場合は opt.minimize() で十分です。
    with tf.name_scope('train'):
        opt = tf.train.AdamOptimizer(1e-5)
        grads = opt.compute_gradients(loss)
        out = opt.apply_gradients(grads, global_step=global_step)

        for g, u in grads:
            if g is None:
                continue
            tf.summary.histogram(u.name, g,
                collections=[tf.GraphKeys.SUMMARIES, 'train', 'histogram'])

    return out

1. 学習コマンド

train.pyを使用してモデルを学習します。学習のセッションをそのまま保存するとファイルサイズはかなり大きくなります。

$ python train.py --samplesdir-/path/to/sample/jpg/directory --num-iterations-3000
$ ls -lh model_training/
total 29M
-rw-rw-r-- 1 user user  115 Apr 11 00:58 checkpoint
-rw-rw-r-- 1 user user 1.3M Apr 11 00:57 model-0.data-00000-of-00001
-rw-rw-r-- 1 user user  788 Apr 11 00:57 model-0.index
-rw-rw-r-- 1 user user  14M Apr 11 00:57 model-0.meta
-rw-rw-r-- 1 user user 1.3M Apr 11 00:58 model-1000.data-00000-of-00001
-rw-rw-r-- 1 user user  788 Apr 11 00:58 model-1000.index
-rw-rw-r-- 1 user user  14M Apr 11 00:58 model-1000.meta

ソースは以下のようになります。

# TensorFlowだと独自のコマンドライン引数パーサを使用していますがここでは
# python3標準のArgumentParserを使用します。
# TensorFlowが独自のものを使用するのはおそらくPython 2/3両対応のためでしょう。
from argparse import ArgumentParser
from pathlib import Path

import tensorflow as tf

# 自分のモデルをインポートしておきます
import simple_autoencoder.model as model
import simple_autoencoder.utils as utils

# 学習におけるイテレーション数など各種パラメータのデフォルト値です。
# 基本的にすべてコマンドライン引数で上書きできるようにしておきます。
DEFAULT_NUM_ITERATIONS=100000
DEFAULT_MINIBATCH_SIZE=16
DEFAULT_LOGDIR='./logs'
DEFAULT_SAMPLESDIR='./samples'
DEFAULT_INPUT_MODEL='model_training'
DEFAULT_OUTPUT_MODEL='model_training'

DEFAULT_INPUT_THREADS=8
DEFAULT_INPUT_QUEUE_MIN=2000
DEFAULT_INPUT_QUEUE_MAX=10000

DEFAULT_INPUT_WIDTH=128
DEFAULT_INPUT_HEIGHT=128

def create_argument_parser():
    # コマンドライン引数のパーサを作ります。
    # まず大体のアプリケーションで共通して使う引数を設定します。
    parser = ArgumentParser()
    parser.add_argument('-n','--num-iterations', type=int, default=DEFAULT_NUM_ITERATIONS)
    parser.add_argument('-m','--minibatch-size', type=int, default=DEFAULT_MINIBATCH_SIZE)
    parser.add_argument('-l','--logdir', type=str, default=DEFAULT_LOGDIR)
    parser.add_argument('-s','--samplesdir', type=str, default=DEFAULT_SAMPLESDIR)
    parser.add_argument('-i','--input-model', type=str, default=DEFAULT_INPUT_MODEL)
    parser.add_argument('-o','--output-model', type=str, default=DEFAULT_OUTPUT_MODEL)


    parser.add_argument('--input-threads', type=int, default=DEFAULT_INPUT_THREADS)
    parser.add_argument('--input-queue_min', type=int, default=DEFAULT_INPUT_QUEUE_MIN)
    parser.add_argument('--input-queue-max', type=int, default=DEFAULT_INPUT_QUEUE_MAX)
    return parser

def add_application_arguments(parser):
    # 今回のアプリケーション(オートエンコーダ)固有の引数を追加します。
    # ここらへんの切り分けはほとんど趣味です。
    parser.add_argument('--width', type=int, default=DEFAULT_INPUT_WIDTH)
    parser.add_argument('--height', type=int, default=DEFAULT_INPUT_HEIGHT)
    return parser

def main():
    # メインです。
    # コマンドライン引数のパーサを作ってパースして本処理に投げるだけ
    parser = create_argument_parser()
    parser = add_application_arguments(parser)
    args = parser.parse_args()
    proc(args)

def proc(args):
    # 学習サンプルの画像ファイルのリストを取得します。
    # まずはディレクトリを開く
    sample_dir = Path(args.samplesdir)

    # ディレクトリの直下からjpgファイルのパスを取得します。
    # Path.globを使用すればもうちょっと凝ったことができます。
    file_list = [p for p in sample_dir.iterdir() if p.suffix == '.jpg']
    file_list = list(map(str, file_list))

    # 学習のグラフを作っていきます。
    with tf.Graph().as_default(), tf.device('/cpu:0'):
        # まずはグローバルステップ。
        # 複数回に分けて学習するときには独自に作っておくと便利です。
        with tf.variable_scope('global'):
            global_step = tf.get_variable(
                'global_step', shape=[], initializer=tf.constant_initializer(0, dtype=tf.int64))

        # 次に画像ファイルの読み込みキューを作ります。
        # ここらへんはチュートリアルでもよく出てきます。
        with tf.name_scope('input'):
            filename_queue = tf.train.string_input_producer(file_list)
            reader = tf.WholeFileReader()
            images = tf.train.shuffle_batch(
                [utils.read_image_op(filename_queue, reader, args.height, args.width)],
                args.minibatch_size, args.input_queue_max, args.input_queue_min,
                num_threads=args.input_threads)

        # 推論と学習を組んでいきます。
        # GPUを1つ使用します。CPUしかない場合はここを書き換えることになります。
        # ここもコマンドライン引数にしてしまってもよいでしょう。
        with tf.device('/gpu:0'):
            with tf.variable_scope('model'):
                # 推論
                out_images = model.inference(images)
                # あとで使用するために推論に別名をつけておきます
                out_images = tf.identity(out_images, name='out_node')
                # ロスの計算
                loss = model.loss(images, out_images)
                # 学習オペレーションの取得
                train_op = model.train(loss, global_step)

        # ログ出力オペレーションを取得しておきます。
        log_op = tf.summary.merge_all()

        # セッションを作成します。
        # allow_soft_placement=True しておけばGPUでsummaryオペレーションを作っても
        # 善きに計らってくれます。
        # XLAなどの最適化もここで指定しますが割愛。
        config_proto = tf.ConfigProto(
            allow_soft_placement=True, log_device_placement=False)
        sess = tf.Session( config=config_proto)

        # ログ書き出しオブジェクトを作ります。
        # 実際にはもう少し詳しくディレクトリを分けたほうがよいでしょう
        # 例: ./logs/train, ./logs/test など
        writer = tf.summary.FileWriter('./logs')

        # tf.Variable を初期化します。
        sess.run(tf.global_variables_initializer())

        # チェックポイントファイルのIOオブジェクトを作ります。
        saver = tf.train.Saver()

        # モデル出力先ディレクトリを作成します。
        training_dir = Path(args.output_model)
        training_dir.mkdir(parents=True, exist_ok=True)

        # 以前に学習したモデルが存在すればそれをレストアして続きの学習とします。
        latest_checkpoint = tf.train.latest_checkpoint(str(args.input_model))

        if latest_checkpoint:
            saver.restore(sess, latest_checkpoint)

        # ログにグラフ構造を書き出しておきます。
        # こうしておくことでTensorBoardのGraphタブからグラフの構造を確認できます。
        writer.add_graph(tf.get_default_graph())

        # 入力にキューを使用しているのでスレッドをスタートします。
        coordinator = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coordinator)

        # 指定イテレーション分学習を行います
        for i in range(args.num_iterations):
            # 既定グローバルステップごとにモデルを書き出してバックアップとします
            gs = tf.train.global_step(sess, global_step)
            if gs % 5000 == 0:
                saver.save(sess, str(training_dir/'model'), global_step=gs)

            # 10回ごとにロスなどのログを取得して書き出します。
            # scalarのログであればそれほど重くはありませんが画像やヒストグラムは
            # 書き出し処理が重くなりがちなので回数は適宜調整したほうがよいでしょう。
            # モデルや規模によりますが勾配のヒストグラムは500回から1000回ごとで十分だと思います。
            if gs % 10 == 0:
                print("global_step = {}".format(gs))
                _, logs =  sess.run([train_op, log_op])
                writer.add_summary(logs, i)
            else:
                _ = sess.run([train_op])

        # イテレーションが終了したら完成モデルを書き出します。
        gs = tf.train.global_step(sess, global_step)
        saver.save(sess, str(training_dir/'model'), global_step=gs)

        # 開放処理はプロセスにまかせます

if __name__ == '__main__':
    main()

2. モデルのチェックポイントファイルを小さくする

convert.pyを使って最適化器の変数などを含まないチェックポイントファイルに変換します。

$ python convert.py
$ ls -lh model/
total 448K
-rw-rw-r-- 1 user user   67 Apr 11 01:00 checkpoint
-rw-rw-r-- 1 user user 419K Apr 11 01:00 model.data-00000-of-00001
-rw-rw-r-- 1 user user  311 Apr 11 01:00 model.index
-rw-rw-r-- 1 user user  18K Apr 11 01:00 model.meta

ソースは以下になります。考え方としては学習したtf.Variableの値を使ってグラフを組み直して保存するという感じです。やろうと思えば学習した値を使ってまったく別のグラフを作ることもできると思います。

# おおまかな構造はtrain.pyと同じです
from argparse import ArgumentParser
from pathlib import Path

import tensorflow as tf

import simple_autoencoder.model as model
import simple_autoencoder.utils as utils

# コマンドラインアプリでは1回1枚の画像を変換することにしてミニバッチサイズを1にします
DEFAULT_MINIBATCH_SIZE=1
DEFAULT_INPUT_MODEL='model_training'
DEFAULT_OUTPUT_MODEL='model'

DEFAULT_INPUT_WIDTH=128
DEFAULT_INPUT_HEIGHT=128

def create_argument_parser():
    # 学習はしないのでインプットキュー関係の引数はなくなります
    parser = ArgumentParser()
    parser.add_argument('-m','--minibatch-size', type=int, default=DEFAULT_MINIBATCH_SIZE)
    parser.add_argument('-i','--input-model', type=str, default=DEFAULT_INPUT_MODEL)
    parser.add_argument('-o','--output-model', type=str, default=DEFAULT_OUTPUT_MODEL)

    return parser

def add_application_arguments(parser):
    parser.add_argument('--width', type=int, default=DEFAULT_INPUT_WIDTH)
    parser.add_argument('--height', type=int, default=DEFAULT_INPUT_HEIGHT)
    return parser

def main():
    parser = create_argument_parser()
    parser = add_application_arguments(parser)
    args = parser.parse_args()
    proc(args)

def proc(args):
    # 学習した変数の値を使いつつコマンドライン用の別のグラフを作ります。

    with tf.Graph().as_default(), tf.device('/cpu:0'):
        # 学習時はファイルリストから画像を読み込みtf.train.shuffle_batchで
        # モデルに渡していました。
        # アプリでの利用のためにtf.placeholderでの画像渡しに変えておきます。
        with tf.name_scope('input'):
            images = tf.placeholder(name='images_placeholder',
                shape=[args.minibatch_size, args.height, args.width, 3], dtype=tf.float32)
            tf.add_to_collection('images_placeholder',images)

        # 推論モデルの作成
        # CPUのみを使用したい場合はここを書き換えてください。
        # 本来ならコマンドラインオプションにするところ。
        with tf.device('/gpu:0'):
            with tf.variable_scope('model'):
                # 推論のみなのでロスや学習のオペレーションは作りません
                out_images = model.inference(images)
                # アプリで使用するため推論に名前をつけておきます
                out_images = tf.identity(out_images, name='out_node')
                tf.add_to_collection('inference_op',out_images)

        # 推論に関わる変数のみを書き出すために名前で変数を集めておきます。
        # もしかしたら必要ないかも知れない。
        inference_variables = [
            v for v in tf.trainable_variables() if v.name.count('/inference/')]

        # セッション作成
        config_proto = tf.ConfigProto(
            allow_soft_placement=True, log_device_placement=False)
        sess = tf.Session( config=config_proto)

        # 変数を初期化
        sess.run(tf.global_variables_initializer())

        # 学習したチェックポイントファイルから変数を復元します。
        inference_saver = tf.train.Saver(var_list=inference_variables)
        latest_checkpoint = tf.train.latest_checkpoint(args.input_model)

        if latest_checkpoint:
            inference_saver.restore(sess, latest_checkpoint)
        else:
            raise EnvironmentError('no checkpoint file')

        # あらためて推論のみのグラフをチェックポイントファイルとして書き出します。
        # graph_defファイルも書き出しておきます。
        converted_dir = Path(args.output_model)
        converted_dir.mkdir(parents=True, exist_ok=True)
        tf.train.write_graph(sess.graph.as_graph_def(), str(converted_dir), 'model.pb')
        inference_saver.save(sess, str(converted_dir/'model'))

if __name__ == '__main__':
    main()

3. チェックポイントファイルをProtocolBuffersファイルに変換する

TensorFlowのソース/パッケージに付属するfreeze_graph.pyを使用してグラフ(構造)と変数(データ)を一つのgraph_def ProtocolBuffersファイルにまとめます。freeze_graph.pyは内部でVariableConstに変換してくれます。

$ bash make_graph_pb.sh
$ ls -lh graph.pb
-rw-rw-r-- 1 user user 422K Apr 12 00:40 graph.pb

シェルスクリプトのソースです。使用しているTensorFlowバージョンのfreeze_graph.pyを見つけて呼ぶだけ。

#!/bin/bash

# tensorflowのインストールパスを取得します。
# pipでインストールしていることを想定しています。
# ネイティブでもpyenvでも取得できるはず。多分。
# 使用しているTensorFlowのバージョンと一致したfreeze_graph.pyが必要なためです。
tfpath=`pip show tensorflow | grep "Location: \(.\+\)$" | sed 's/Location: //'`

# tensorflowパッケージに付属しているfreeze_graph.pyを使用してチェックポイントファイルから
# graph_def ProtocolBuffersファイルに変換します。
python $tfpath/tensorflow/python/tools/freeze_graph.py \
    --input_checkpoint model/model --output_node_names=model/out_node \
    --output_graph=graph.pb --input_graph=model/model.pb

4. コマンドラインアプリに組み込んで使う

ここまでくればモデルの使用はシンプルです。execute.pyが今回の「モデルを使用するアプリケーションになります。graph_defファイルを読み込んでグラフを作成し画像ファイルに適用するアプリケーションです。

$ python execute.py /path/to/your/image/jpg/file.jpg
$ your_favorite_image_viewer out.jpg

ソースは以下になります。読み込みコードはほとんどテンプレなので一旦は覚えてしまうのが速いでしょう(いずれはちゃんと理解しましょう)。

from argparse import ArgumentParser
from pathlib import Path

# 画像の入出力にpillowを使用します
import tensorflow as tf
import numpy as np
from PIL import Image

DEFAULT_INPUT_MODEL='model'
DEFAULT_INPUT_GRAPHDEF='graph.pb'

DEFAULT_INPUT_WIDTH=128
DEFAULT_INPUT_HEIGHT=128

def create_argument_parser():
    # 学習モデルのグラフを指定します
    # 入出力がテンソルの名前も含めて同一仕様であれば差し替えもできます
    parser = ArgumentParser()
    parser.add_argument('-i','--input-graphdef', type=str, default=DEFAULT_INPUT_GRAPHDEF)
    return parser

def add_application_arguments(parser):
    # 変換する画像の指定
    parser.add_argument('imagefiles', nargs='+', type=str)
    return parser

def main():
    parser = create_argument_parser()
    parser = add_application_arguments(parser)
    args = parser.parse_args()
    proc(args)

def proc(args):
    # セッションを作って
    config_proto = tf.ConfigProto(
        allow_soft_placement=True, log_device_placement=False)
    sess = tf.Session(config=config_proto)

    # graph_defファイルを読み込んでデフォルトグラフにします。
    with tf.gfile.FastGFile(args.input_graphdef, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        _ = tf.import_graph_def(graph_def, name='')

    # 入力のtf.placeholderを取得します
    images_placeholder = tf.get_default_graph().get_tensor_by_name('input/images_placeholder:0')
    # 推論オペレーションを取得します
    inference_op = tf.get_default_graph().get_tensor_by_name('model/out_node:0')

    i = 0

    for imagepath in args.imagefiles:
        # 画像を読み込んでサイズ1のミニバッチの形式にします。
        input_image = Image.open(imagepath)
        input_image = input_image.resize((DEFAULT_INPUT_WIDTH, DEFAULT_INPUT_HEIGHT))
        input_image = np.expand_dims(np.asarray(input_image), axis=0) / 255

        # 入力画像をplaceholderに仕込んで推論オペレーションを実行します。
        out_images = sess.run(
            [inference_op], feed_dict={images_placeholder: input_image})

        # 画像の値を調整してミニバッチから単一画像行列に変換して保存します。
        out_images = np.multiply(out_images, 255)
        out_images = np.squeeze(out_images, axis=(0,1))
        out_image = Image.fromarray(np.uint8(out_images))
        out_image.save('out_{}.jpg'.format(i))
        
        i += 1

if __name__ == '__main__':
    main()

おわりに

ここまで理解して出来るようになれば「学習は出来たけどファイル大きいしうまくコマンドにまとめられない」という初心者にありがちな行き詰まりは解消されると思います。私はこの記事を書くことで解消しました。

graph_defファイルの読み込みは単純なのでPythonで書くならウェブアプリへの組み込みもすぐにできると思います。

For your happy TensorFlow days.

36
46
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
36
46

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?