28
13

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

Gopher道場Advent Calendar 2018

Day 7

GoでTensorFlowのAPIを使ってみた話

Last updated at Posted at 2018-12-07

やること

最近機械学習(TensorFlow + Keras)とGo言語を学習する機会があったのでTensorFlowのGoのAPIを試しに使ってみたいと思います。

実際にやることとしては
PythonでMNIST手書き数字データを学習させ
TensorFlow GoのAPIから学習済みデータを取得し
入力された手書き数字データに対して予測結果を返すシンプルなJSON APIサーバーを作りたいと思います。

そもそもTensorFlowって何?

Googleが開発した機械学習、数値解析ライブラリです、内部ではデータフローグラフを利用しており、ニューラルネットワークなどの複雑なネットワークをわかりやすく記述する事ができます。
Python, Go, C, C++など複数の言語でAPIを提供しています。
ただしモデルの構築、学習が出来るのはPythonだけです。

Goで扱える形で学習済みモデルを保存する

今回は以前書かせて頂いたこちらの記事のプログラムを使って学習済みモデルを作り、Goで保存した学習済みモデルを呼び出すようにしたいと思います。
https://qiita.com/yasuno0327/items/2b978d60eade1334ad2c

TensorFlow Go APIでは学習済みモデルの呼び出しはProtocol Buffersの形式でなくてはいけません。
まず、上の記事のプログラムを改修して学習済みモデルをProtocol Buffersの形式で保存できるようにする必要があります。

具体的には次のコードを追加します。

mnist.py
import tensorflow as tf
from keras import backend

sess = tf.Session()
backend.set_session(sess)
mnist.py
from tensorflow.python.saved_model import builder as save_model_builder
from tensorflow.python.saved_model.signature_def_utils import predict_signature_def
from tensorflow.python.saved_model import tag_constants

# Save model
builder = save_model_builder.SavedModelBuilder('mymnist')
builder.add_meta_graph_and_variables(sess, ['mnisttag'])
builder.save()
sess.close()

上記のようにする事によりmymnistディレクトリが生成され、その中に学習済みモデルに関するデータがProtocol Buffersの形式で保存されます。

mymnist
├── saved_model.pb
└── variables
    ├── variables.data-00000-of-00001
    └── variables.index

GoでServeする

次に本題のGoのTensorFlow APIを使って入力された画像データを予測したいと思います
まず学習済みモデルを呼び出します。

client.go
model, err := tensorflow.LoadSavedModel("mymnist", []string{"mnisttag"}, nil)
if err != nil {
  return "", err
}
defer model.Session.Close()

tensorflow.LoadSavedModelを使うことにより指定したディレクトリ内の学習済みモデルを呼び出せます。
戻ってきたmodelには呼び出した学習済みモデルを元に生成されたtensorflow.SavedModelが帰ってきます。

画像データをTensorに変換する

与えられた画像データのクラスを予測するためには画像データをTensorFlowで扱えるTensorに変換し、学習済みモデルに与えてSessionを走らせる必要があります。
こちらはPythonで学習させた時に最初にinputさせているTensorと同じ形に整形して与えなければいけません。

client.go
tensor, err := tensorflow.NewTensor(imageBuffer.String())

上記のようにするimageデータを元に1次元のimageのtensor(*tensorflow.Tensor)が帰ってきます。

しかし、今回は画像を28x28x1(height x width x channels)にリサイズする必要があるので次のような処理をする必要があります。

client.go
graph, input, output, err := makeTransFormImageGraph(format)
session, err := tensorflow.NewSession(graph, nil)
defer session.Close()
normalized, err := session.Run(
  map[tensorflow.Output]*tensorflow.Tensor{input: tensor},
  []tensorflow.Output{output},
  nil,
)

ここではmakeTransFormImageGraphで画像リサイズ用のグラフを構築してSession.Runでリサイズ処理をしています。
makeTransFormImageGraphを見ると

client.go
// inputするimageの情報を返す [batch size][width][height][channels]
func makeTransFormImageGraph(format string) (graph *tensorflow.Graph, input, output tensorflow.Output, err error) {
	const (
		Height, Width = 28, 28
	)
	s := op.NewScope()
	input = op.Placeholder(s, tensorflow.String) // inputはstringで渡ってくる
	var decode tensorflow.Output
        // Jpegをデコードする
	decode = op.DecodeJpeg(s, input, op.DecodeJpegChannels(1)) //0,1だけなので1

	// tensorにbatch sizeを加える
	decodeWithBatch := op.ExpandDims(
		s,
		op.Cast(s, decode, tensorflow.Float),
		op.Const(s.SubScope("make_batch"), int32(0)),
	)
	// imageを28x28にリサイズ
	output = op.ResizeBilinear(
		s,
		decodeWithBatch,
		op.Const(s.SubScope("size"), []int32{Height, Width}),
	)
	graph, err = s.Finalize()
	return graph, input, output, err
}

少しわかりにくいですがop.NewScopeでその命令のスコープを定義しています。
これを関数に渡すと渡したScope内で実行されます。
op.Const(s.SubScope("size"), []int32{Height, Width})などの新しいOutput情報を返す処理などはs.SubScopeを使って新しいスコープを作成し実行します。

ここで定義したグラフを実行すると1x28x28x1にリサイズする事ができます。

client.go
func ConvertImageToTensor(imageBuffer *bytes.Buffer, format string) (*tensorflow.Tensor, error) {
	format = "jpeg"
	tensor, err := tensorflow.NewTensor(imageBuffer.String())
	if err != nil {
		return nil, err
	}
	graph, input, output, err := makeTransFormImageGraph(format)
	if err != nil {
		return nil, err
	}
	session, err := tensorflow.NewSession(graph, nil)
	if err != nil {
		return nil, err
	}
	defer session.Close()
	normalized, err := session.Run(
		map[tensorflow.Output]*tensorflow.Tensor{input: tensor},
		[]tensorflow.Output{output},
		nil)
	if err != nil {
		return nil, err
	}
	return normalized[0], nil
}

渡されたテンソルから結果を出力する。

ここで先ほど取得した学習済みモデルをリサイズした画像に対して走らせると結果を出力できます。

client.go
func Recognition(tensor *tensorflow.Tensor) (string, error) {
	var probability float64
	// tf.saved_model.builder in Pythonで構築したモデルを呼び出す
	model, err := tensorflow.LoadSavedModel("mymnist", []string{"mnisttag"}, nil)
	if err != nil {
		return "", err
	}
	defer model.Session.Close()

	result, err := model.Session.Run(
		map[tensorflow.Output]*tensorflow.Tensor{
			model.Graph.Operation("conv2d_1_input").Output(0): tensor,
		},
		[]tensorflow.Output{
			model.Graph.Operation("dense_2/Softmax").Output(0),
		},
		nil,
	)

	if err != nil {
		return "", err
	}
	labels := []string{"0", "1", "2", "3", "4", "5", "6", "7", "8", "9"}
	probabilities := result[0].Value().([][]float32)[0]
	max := 0
	for i, v := range probabilities {
		if float64(v) > probability {
			probability = float64(probabilities[i])
			max = i
		}
	}
	return labels[max], nil
}

probabilitiesにはそのクラスである確率が帰ってくるのでその最大値を取るラベルを返せば完成です!
今回のモデルではSoftmax関数を使っているため0,1しか帰ってきませんが....笑

以上です!
少し走り書きで書いてしまったため間違えてる所があるかもしれません。
ご指摘頂けるとありがたいです!

ソースコードはこちらに置いてあります。

28
13
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
28
13

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?