5
9

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

機械学習エンジニアの第一歩 ~MNIST分類API~

Posted at

はじめに

こんにちは!
昨今は機械学習の流行が続いております.
機械学習を業務レベルで行う職種として,機械学習エンジニアやデータサイエンティストがあるかと思います.
今回は,機械学習エンジニアの業務である画像認識のAPIについて書いてみたいと思います.

なお,他の記事も合わせてご覧いただくと,理解が深まると思うので,載せておきます.

・Python(Flask) でサクッと 機械学習 API を作る←おすすめ
https://qiita.com/fam_taro/items/1464c42324f15d7b8223
・Python と Flask で RESTful API を開発する
https://auth0.com/blog/jp-developing-restful-apis-with-python-and-flask/

対象読者

・機械学習を学んだものの,何に活かせばいいかわからない人
・web系をやっていたが,MLもやり始めた人
・機械学習で何ができるのか知りたい人

全体像

①まず,MNISTを使って,0-9の数字を分類するモデルを作成します.
②数字の写真とともにPOSTリクエストを送ります.
③分類結果と確率とともに,レスポンスを返します.

image.png

準備

必要なライブラリは以下の通りです.

requirement.txt
flask
sklearn
numpy 
scipy 
keras
pillow

①モデルの作成

モデルを作るにあたって,Kerasを用います.
Kerasでは非常に簡単に機械学習のモデルが作れてしまいます.
実際,モデルを構築する部分は下記だけです.

model.py
def construct_network():
    model = Sequential()
    model.add(Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=(28, 28, 1))) 
    model.add(Conv2D(64, (3, 3), activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Dropout(0.25))
    model.add(Flatten())
    model.add(Dense(128, activation='relu'))
    model.add(Dropout(0.25))
    model.add(Dense(10, activation='softmax'))
    return model

画像なので,CNNを用いております.ネットワークの構成は適当です(笑)
畳み込み層のinput_shapeの設定や,最終層の活性化関数にのみ注意を払えば,適当なモデルでも機能します.
また,最適化するところのコードも合わせて,モデリングのコードはこんな感じです.

model.py
import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras.optimizers import Adam
import numpy as np

"""
input: 
x_train (shape = (N1, 28, 28, 1))
y_train (shape = (N1, 10))
x_test (shape = (N2, 28, 28, 1))
y_test (shape = (N2, 10))
"""

# hyper parameter
batch_size = 256
epochs = 5

def construct_network():
    model = Sequential()
    model.add(Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=(28, 28, 1)))
    model.add(Conv2D(64, (3, 3), activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Dropout(0.25))
    model.add(Flatten())
    model.add(Dense(128, activation='relu'))
    model.add(Dropout(0.25))
    model.add(Dense(10, activation='softmax'))
    return model

def train(x_train, y_train, x_test, y_test):
    # model
    model = construct_network()
    model.compile(loss='categorical_crossentropy',
                optimizer=Adam(lr=0.001, beta_1=0.9, beta_2=0.999, epsilon=None, decay=0.0, amsgrad=False),
                metrics=['accuracy'])

    history = model.fit(x_train, y_train,  # 画像とラベルデータ
                        batch_size=batch_size,
                        epochs=epochs,     # エポック数の指定
                        verbose=1, 
                        validation_data=(x_test, y_test))

    score = model.evaluate(x_test, y_test, verbose=0)
    print('Test loss:', score[0])
    print('Test accuracy:', score[1])
    model.save("mnist.h5", include_optimizer=False)

if __name__ == '__main__':
    # read dataset
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    num_classes = 10
    x_train = x_train.reshape(60000, 28, 28, 1)
    x_train = x_train.astype('float32')
    x_test = x_test.reshape(10000, 28, 28, 1)
    x_test = x_test.astype('float32')
    x_train /= 255
    x_test /= 255
    # one-hot encoding
    y_train = keras.utils.to_categorical(y_train, num_classes)
    y_test = keras.utils.to_categorical(y_test, num_classes)
    train(x_train, y_train, x_test, y_test)

画像のshapeやone-hot encodingをしてから,main関数(今回ではtrain関数)にデータを渡しております.
コンパイルしてmodelを最適化したら,最後にモデルを保存します.
詳細は,以下の記事を参考にしてください.

・kerasのmnistのサンプルを読んでみる
https://qiita.com/ash8h/items/29e24fc617b832fba136
・Keras(+Tensorflow)でMNISTしてみる
https://qiita.com/fukuit/items/b3fa460577a0ea139c88

②POSTリクエスト

モデルが構築でき,保存したところで,クライアント側からPOSTリクエストを送ってみます.
ちなみに,まだサーバー側の実装ができていないのでエラーが出ます.

command.sh
curl -F "file=@[filename].jpg" http://localhost:5000/

画像をポストするときはFオプションを使います.
また,ポート番号はFlaskのデフォルトの5000番を使ってます.

③レスポンスを返す

サーバー側の実装です.

app.py
from keras.models import load_model
import flask
import numpy as np
from PIL import Image, ExifTags
from keras import backend as K


app = flask.Flask(__name__)
ALLOWED_EXTENSIONS = set(['png', 'jpg', 'gif'])

def load():
    """modelを読み込む"""
    model = load_model("mnist.h5", compile=False)
    return model

@app.route("/help", methods = ["GET"])
def help():
    response = {"Content-Type": "application/json", 'help': None}
    if flask.request.method == "GET":
        msg = 'exp. curl -F "file=[filename].jpg" "http://localhost:5000/predict/"'
        response["help"] = msg
    return flask.jsonify(response)

def allowed_file(filename):
    """画像型の拡張子になっているか確認する"""
    return '.' in filename and \
        filename.rsplit('.', 1)[1] in ALLOWED_EXTENSIONS

def transform_img(img):
    """読み込んだimageのshapeをMNISTのshape(28, 28)にする"""
    img = img.convert('L')
    width,height = 28, 28
    img = img.resize((width,height), Image.LANCZOS)
    img_array = np.asarray(img).reshape((1, width, height, 1))
    return img_array

def deal_rotation(img):
    """postした画像が回転してしまった場合,元に戻す"""
    
    for orientation in ExifTags.TAGS.keys() : 
        if ExifTags.TAGS[orientation]=='Orientation' : break 
    exif=dict(img._getexif().items())
    if   exif[orientation] == 3 : 
        img=img.rotate(180, expand=True)
    elif exif[orientation] == 6 : 
        img=img.rotate(270, expand=True)
    elif exif[orientation] == 8 : 
        img=img.rotate(90, expand=True)
    return img

@app.route("/predict", methods=["POST"])
def predict():
    model = load()
    response = {"Content-Type": "application/json",
                "result": None, 
                "probability": None}
    if flask.request.method == "POST":
        if flask.request.files["file"]:
            img = Image.open(flask.request.files["file"])
            img = deal_rotation(img)
            img_array = transform_img(img)
            result = model.predict(img_array,verbose=0)
            K.clear_session()
            response["result"] = str(np.argmax(result))
            response["probability"] = str(np.max(result))
    return flask.jsonify(response)
            

if __name__ == '__main__':
    app.run()

curlで送った画像が勝手に反転してしまうことがあるため,deal_rotation関数を用いております.

完成

それでは完成したものを動かしてみます.
まず,flaskをlocalhostで起動して,

command.sh
python app.py

curlでPOSTリクエストをします.

command.sh
curl -F "file=@[filename].jpg" http://localhost:5000/

filenameをしっかりと指定してあげると,

result.json
{"Content-Type":"application/json","probability":"0.99432","result":"0"}

こんな感じで結果が返ってきます!!

注意点

読み込んだ画像をMNISTと同じ(28, 28)にしているため,太いペンで濃く描かないと全く認識してくれませんのでご注意ください(笑)

最後に

機械学習エンジニアのタスクとしてAPIを作ることが多いと思うので,初心者の方はサクッとAPIを作れるように練習してみてください!
間違い・質問等あればお気軽にコメントをいただけると嬉しいです!!

参考文献

本文中に載せなかったものです.

・How to send image to Flask server from curl request
https://stackoverflow.com/questions/41655946/how-to-send-image-to-flask-server-from-curl-request
・How do I convert a numpy array to (and display) an image?
https://stackoverflow.com/questions/2659312/how-do-i-convert-a-numpy-array-to-and-display-an-image
・PIL thumbnail is rotating my image?
https://stackoverflow.com/questions/4228530/pil-thumbnail-is-rotating-my-image

5
9
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
9

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?