はじめに
こんにちは!
昨今は機械学習の流行が続いております.
機械学習を業務レベルで行う職種として,機械学習エンジニアやデータサイエンティストがあるかと思います.
今回は,機械学習エンジニアの業務である画像認識のAPIについて書いてみたいと思います.
なお,他の記事も合わせてご覧いただくと,理解が深まると思うので,載せておきます.
・Python(Flask) でサクッと 機械学習 API を作る←おすすめ
https://qiita.com/fam_taro/items/1464c42324f15d7b8223
・Python と Flask で RESTful API を開発する
https://auth0.com/blog/jp-developing-restful-apis-with-python-and-flask/
対象読者
・機械学習を学んだものの,何に活かせばいいかわからない人
・web系をやっていたが,MLもやり始めた人
・機械学習で何ができるのか知りたい人
全体像
①まず,MNISTを使って,0-9の数字を分類するモデルを作成します.
②数字の写真とともにPOSTリクエストを送ります.
③分類結果と確率とともに,レスポンスを返します.
準備
必要なライブラリは以下の通りです.
flask
sklearn
numpy
scipy
keras
pillow
①モデルの作成
モデルを作るにあたって,Kerasを用います.
Kerasでは非常に簡単に機械学習のモデルが作れてしまいます.
実際,モデルを構築する部分は下記だけです.
def construct_network():
model = Sequential()
model.add(Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=(28, 28, 1)))
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
model.add(Flatten())
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.25))
model.add(Dense(10, activation='softmax'))
return model
画像なので,CNNを用いております.ネットワークの構成は適当です(笑)
畳み込み層のinput_shapeの設定や,最終層の活性化関数にのみ注意を払えば,適当なモデルでも機能します.
また,最適化するところのコードも合わせて,モデリングのコードはこんな感じです.
import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras.optimizers import Adam
import numpy as np
"""
input:
x_train (shape = (N1, 28, 28, 1))
y_train (shape = (N1, 10))
x_test (shape = (N2, 28, 28, 1))
y_test (shape = (N2, 10))
"""
# hyper parameter
batch_size = 256
epochs = 5
def construct_network():
model = Sequential()
model.add(Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=(28, 28, 1)))
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
model.add(Flatten())
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.25))
model.add(Dense(10, activation='softmax'))
return model
def train(x_train, y_train, x_test, y_test):
# model
model = construct_network()
model.compile(loss='categorical_crossentropy',
optimizer=Adam(lr=0.001, beta_1=0.9, beta_2=0.999, epsilon=None, decay=0.0, amsgrad=False),
metrics=['accuracy'])
history = model.fit(x_train, y_train, # 画像とラベルデータ
batch_size=batch_size,
epochs=epochs, # エポック数の指定
verbose=1,
validation_data=(x_test, y_test))
score = model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])
model.save("mnist.h5", include_optimizer=False)
if __name__ == '__main__':
# read dataset
(x_train, y_train), (x_test, y_test) = mnist.load_data()
num_classes = 10
x_train = x_train.reshape(60000, 28, 28, 1)
x_train = x_train.astype('float32')
x_test = x_test.reshape(10000, 28, 28, 1)
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
# one-hot encoding
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)
train(x_train, y_train, x_test, y_test)
画像のshapeやone-hot encodingをしてから,main関数(今回ではtrain関数)にデータを渡しております.
コンパイルしてmodelを最適化したら,最後にモデルを保存します.
詳細は,以下の記事を参考にしてください.
・kerasのmnistのサンプルを読んでみる
https://qiita.com/ash8h/items/29e24fc617b832fba136
・Keras(+Tensorflow)でMNISTしてみる
https://qiita.com/fukuit/items/b3fa460577a0ea139c88
②POSTリクエスト
モデルが構築でき,保存したところで,クライアント側からPOSTリクエストを送ってみます.
ちなみに,まだサーバー側の実装ができていないのでエラーが出ます.
curl -F "file=@[filename].jpg" http://localhost:5000/
画像をポストするときはFオプションを使います.
また,ポート番号はFlaskのデフォルトの5000番を使ってます.
③レスポンスを返す
サーバー側の実装です.
from keras.models import load_model
import flask
import numpy as np
from PIL import Image, ExifTags
from keras import backend as K
app = flask.Flask(__name__)
ALLOWED_EXTENSIONS = set(['png', 'jpg', 'gif'])
def load():
"""modelを読み込む"""
model = load_model("mnist.h5", compile=False)
return model
@app.route("/help", methods = ["GET"])
def help():
response = {"Content-Type": "application/json", 'help': None}
if flask.request.method == "GET":
msg = 'exp. curl -F "file=[filename].jpg" "http://localhost:5000/predict/"'
response["help"] = msg
return flask.jsonify(response)
def allowed_file(filename):
"""画像型の拡張子になっているか確認する"""
return '.' in filename and \
filename.rsplit('.', 1)[1] in ALLOWED_EXTENSIONS
def transform_img(img):
"""読み込んだimageのshapeをMNISTのshape(28, 28)にする"""
img = img.convert('L')
width,height = 28, 28
img = img.resize((width,height), Image.LANCZOS)
img_array = np.asarray(img).reshape((1, width, height, 1))
return img_array
def deal_rotation(img):
"""postした画像が回転してしまった場合,元に戻す"""
for orientation in ExifTags.TAGS.keys() :
if ExifTags.TAGS[orientation]=='Orientation' : break
exif=dict(img._getexif().items())
if exif[orientation] == 3 :
img=img.rotate(180, expand=True)
elif exif[orientation] == 6 :
img=img.rotate(270, expand=True)
elif exif[orientation] == 8 :
img=img.rotate(90, expand=True)
return img
@app.route("/predict", methods=["POST"])
def predict():
model = load()
response = {"Content-Type": "application/json",
"result": None,
"probability": None}
if flask.request.method == "POST":
if flask.request.files["file"]:
img = Image.open(flask.request.files["file"])
img = deal_rotation(img)
img_array = transform_img(img)
result = model.predict(img_array,verbose=0)
K.clear_session()
response["result"] = str(np.argmax(result))
response["probability"] = str(np.max(result))
return flask.jsonify(response)
if __name__ == '__main__':
app.run()
curlで送った画像が勝手に反転してしまうことがあるため,deal_rotation関数を用いております.
完成
それでは完成したものを動かしてみます.
まず,flaskをlocalhostで起動して,
python app.py
curlでPOSTリクエストをします.
curl -F "file=@[filename].jpg" http://localhost:5000/
filenameをしっかりと指定してあげると,
{"Content-Type":"application/json","probability":"0.99432","result":"0"}
こんな感じで結果が返ってきます!!
注意点
読み込んだ画像をMNISTと同じ(28, 28)にしているため,太いペンで濃く描かないと全く認識してくれませんのでご注意ください(笑)
最後に
機械学習エンジニアのタスクとしてAPIを作ることが多いと思うので,初心者の方はサクッとAPIを作れるように練習してみてください!
間違い・質問等あればお気軽にコメントをいただけると嬉しいです!!
参考文献
本文中に載せなかったものです.
・How to send image to Flask server from curl request
https://stackoverflow.com/questions/41655946/how-to-send-image-to-flask-server-from-curl-request
・How do I convert a numpy array to (and display) an image?
https://stackoverflow.com/questions/2659312/how-do-i-convert-a-numpy-array-to-and-display-an-image
・PIL thumbnail is rotating my image?
https://stackoverflow.com/questions/4228530/pil-thumbnail-is-rotating-my-image