Qiita Teams that are logged in
You are not logged in to any team

Log in to Qiita Team
Community
OrganizationAdvent CalendarQiitadon (β)
Service
Qiita JobsQiita ZineQiita Blog
12
Help us understand the problem. What is going on with this article?
@yasuno0327

GoでTensorFlowのAPIを使ってみた話

More than 1 year has passed since last update.

やること

最近機械学習(TensorFlow + Keras)とGo言語を学習する機会があったのでTensorFlowのGoのAPIを試しに使ってみたいと思います。

実際にやることとしては
PythonでMNIST手書き数字データを学習させ
TensorFlow GoのAPIから学習済みデータを取得し
入力された手書き数字データに対して予測結果を返すシンプルなJSON APIサーバーを作りたいと思います。

そもそもTensorFlowって何?

Googleが開発した機械学習、数値解析ライブラリです、内部ではデータフローグラフを利用しており、ニューラルネットワークなどの複雑なネットワークをわかりやすく記述する事ができます。
Python, Go, C, C++など複数の言語でAPIを提供しています。
ただしモデルの構築、学習が出来るのはPythonだけです。

Goで扱える形で学習済みモデルを保存する

今回は以前書かせて頂いたこちらの記事のプログラムを使って学習済みモデルを作り、Goで保存した学習済みモデルを呼び出すようにしたいと思います。
https://qiita.com/yasuno0327/items/2b978d60eade1334ad2c

TensorFlow Go APIでは学習済みモデルの呼び出しはProtocol Buffersの形式でなくてはいけません。
まず、上の記事のプログラムを改修して学習済みモデルをProtocol Buffersの形式で保存できるようにする必要があります。

具体的には次のコードを追加します。

mnist.py
import tensorflow as tf
from keras import backend

sess = tf.Session()
backend.set_session(sess)
mnist.py
from tensorflow.python.saved_model import builder as save_model_builder
from tensorflow.python.saved_model.signature_def_utils import predict_signature_def
from tensorflow.python.saved_model import tag_constants

# Save model
builder = save_model_builder.SavedModelBuilder('mymnist')
builder.add_meta_graph_and_variables(sess, ['mnisttag'])
builder.save()
sess.close()

上記のようにする事によりmymnistディレクトリが生成され、その中に学習済みモデルに関するデータがProtocol Buffersの形式で保存されます。

mymnist
├── saved_model.pb
└── variables
    ├── variables.data-00000-of-00001
    └── variables.index

GoでServeする

次に本題のGoのTensorFlow APIを使って入力された画像データを予測したいと思います
まず学習済みモデルを呼び出します。

client.go
model, err := tensorflow.LoadSavedModel("mymnist", []string{"mnisttag"}, nil)
if err != nil {
  return "", err
}
defer model.Session.Close()

tensorflow.LoadSavedModelを使うことにより指定したディレクトリ内の学習済みモデルを呼び出せます。
戻ってきたmodelには呼び出した学習済みモデルを元に生成されたtensorflow.SavedModelが帰ってきます。

画像データをTensorに変換する

与えられた画像データのクラスを予測するためには画像データをTensorFlowで扱えるTensorに変換し、学習済みモデルに与えてSessionを走らせる必要があります。
こちらはPythonで学習させた時に最初にinputさせているTensorと同じ形に整形して与えなければいけません。

client.go
tensor, err := tensorflow.NewTensor(imageBuffer.String())

上記のようにするimageデータを元に1次元のimageのtensor(*tensorflow.Tensor)が帰ってきます。

しかし、今回は画像を28x28x1(height x width x channels)にリサイズする必要があるので次のような処理をする必要があります。

client.go
graph, input, output, err := makeTransFormImageGraph(format)
session, err := tensorflow.NewSession(graph, nil)
defer session.Close()
normalized, err := session.Run(
  map[tensorflow.Output]*tensorflow.Tensor{input: tensor},
  []tensorflow.Output{output},
  nil,
)

ここではmakeTransFormImageGraphで画像リサイズ用のグラフを構築してSession.Runでリサイズ処理をしています。
makeTransFormImageGraphを見ると

client.go
// inputするimageの情報を返す [batch size][width][height][channels]
func makeTransFormImageGraph(format string) (graph *tensorflow.Graph, input, output tensorflow.Output, err error) {
    const (
        Height, Width = 28, 28
    )
    s := op.NewScope()
    input = op.Placeholder(s, tensorflow.String) // inputはstringで渡ってくる
    var decode tensorflow.Output
        // Jpegをデコードする
    decode = op.DecodeJpeg(s, input, op.DecodeJpegChannels(1)) //0,1だけなので1

    // tensorにbatch sizeを加える
    decodeWithBatch := op.ExpandDims(
        s,
        op.Cast(s, decode, tensorflow.Float),
        op.Const(s.SubScope("make_batch"), int32(0)),
    )
    // imageを28x28にリサイズ
    output = op.ResizeBilinear(
        s,
        decodeWithBatch,
        op.Const(s.SubScope("size"), []int32{Height, Width}),
    )
    graph, err = s.Finalize()
    return graph, input, output, err
}

少しわかりにくいですがop.NewScopeでその命令のスコープを定義しています。
これを関数に渡すと渡したScope内で実行されます。
op.Const(s.SubScope("size"), []int32{Height, Width})などの新しいOutput情報を返す処理などはs.SubScopeを使って新しいスコープを作成し実行します。

ここで定義したグラフを実行すると1x28x28x1にリサイズする事ができます。

client.go
func ConvertImageToTensor(imageBuffer *bytes.Buffer, format string) (*tensorflow.Tensor, error) {
    format = "jpeg"
    tensor, err := tensorflow.NewTensor(imageBuffer.String())
    if err != nil {
        return nil, err
    }
    graph, input, output, err := makeTransFormImageGraph(format)
    if err != nil {
        return nil, err
    }
    session, err := tensorflow.NewSession(graph, nil)
    if err != nil {
        return nil, err
    }
    defer session.Close()
    normalized, err := session.Run(
        map[tensorflow.Output]*tensorflow.Tensor{input: tensor},
        []tensorflow.Output{output},
        nil)
    if err != nil {
        return nil, err
    }
    return normalized[0], nil
}

渡されたテンソルから結果を出力する。

ここで先ほど取得した学習済みモデルをリサイズした画像に対して走らせると結果を出力できます。

client.go
func Recognition(tensor *tensorflow.Tensor) (string, error) {
    var probability float64
    // tf.saved_model.builder in Pythonで構築したモデルを呼び出す
    model, err := tensorflow.LoadSavedModel("mymnist", []string{"mnisttag"}, nil)
    if err != nil {
        return "", err
    }
    defer model.Session.Close()

    result, err := model.Session.Run(
        map[tensorflow.Output]*tensorflow.Tensor{
            model.Graph.Operation("conv2d_1_input").Output(0): tensor,
        },
        []tensorflow.Output{
            model.Graph.Operation("dense_2/Softmax").Output(0),
        },
        nil,
    )

    if err != nil {
        return "", err
    }
    labels := []string{"0", "1", "2", "3", "4", "5", "6", "7", "8", "9"}
    probabilities := result[0].Value().([][]float32)[0]
    max := 0
    for i, v := range probabilities {
        if float64(v) > probability {
            probability = float64(probabilities[i])
            max = i
        }
    }
    return labels[max], nil
}

probabilitiesにはそのクラスである確率が帰ってくるのでその最大値を取るラベルを返せば完成です!
今回のモデルではSoftmax関数を使っているため0,1しか帰ってきませんが....笑

以上です!
少し走り書きで書いてしまったため間違えてる所があるかもしれません。
ご指摘頂けるとありがたいです!

ソースコードはこちらに置いてあります。

12
Help us understand the problem. What is going on with this article?
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
shinonomeinc
東京理科大学発ベンチャー。提携大学内にソフトウェア研究所を組織し、学生向けのTech教育を提供しています。

Comments

No comments
Sign up for free and join this conversation.
Sign Up
If you already have a Qiita account Login
12
Help us understand the problem. What is going on with this article?