LoginSignup
37
37

More than 3 years have passed since last update.

個人開発したディープラーニングモデルの公開方法

Posted at

はじめに

近年のディープラーニングブームの要因の一つは,tensorflow,kerasやpytorchといったフレームワークが整備され,誰でも簡単に動かせるようになったからだと思います。google colabなどの登場のおかげで,個人でもこんなこと出来たら面白いなって思ったことを容易に試してみることができる時代になりました。
試したらやっぱり公開してみたくなりますよね。
そこで公開方法についてまとめてみました。

サーバーで処理させる

image.png

VPSやAWS,GCPといったサーバーで処理させる方法です。クライアント側でブラウザから処理したいもの(例えば画像)をアップロードしてもらい,サーバー側で処理した結果をブラウザに表示することで,学習したモデルを公開することが出来ます。ディープラーニングのフレームワークはだいたいpythonの物が多いのでflaskなどでサーバーを立てると,受け取った処理対象をそのままpytorchなどで処理して結果を返すことが出来ます。

メリット

サーバー側で処理するので,クライアントのスペックなどを考える必要ありません。サーバーにインストールしているライブラリなら何でも使えます。前処理や後処理で各種ライブラリを使いたい場合はサーバー側で処理させるしかない場合も多いかと思います。

lyrics2chords
歌詞からコード進行の生成 Attentionによる可視化
これは私の作ってみた「歌詞からコード進行を生成するモデル」なのですが,ユーザーが入力した歌詞の前処理として形態素解析を行っています。形態素解析にはsudachiというライブラリを使っていますが,このような前処理を行うにはサーバーで処理させるのが一番簡単でした。

デメリット

サーバーの負担が大きいです。ディープラーニングのモデルであれば推論とはいえ計算量は少なくないため,もしたくさんのアクセスが同時にあったりするとしょぼいサーバーでは処理しきれなくなってしまいます。
まあ,私が作ったものにそんなアクセスがあった試しはないですが。。。
多くのアクセスを処理できるようにするにはサーバーを増強すればいいわけですが,そうなると費用がたくさん発生することになるため個人開発でちょっと公開する目的では辛いと思います。サーバーサイドのノウハウも必要になりそうです。

クライアントで処理させる

image.png

クライアント,つまりブラウザで処理させる方法です。tensorflow.jsやONNX.jsが登場したため,各種フレームワークで学習させたモデルをJavascriptで動かすことが出来ます。図のようにブラウザに学習済みモデルを渡してクライアントのPCやスマホ上で推論させてしまいます。

メリット

サーバーの役割はhtmlファイルと一緒に学習済みモデルを渡してやるだけなので負荷を気にする必要があまりないです。

デメリット

基本的にJavascriptで処理することになるため前処理などを別途書かなければいけません。ライブラリがあればいいですがpythonに比べるとデータ処理に関して充実しているとは言えないと思います。(詳しくないのでよくわからないのですが。)
また,ONNX.jsを動かしてみてわかったのですがすべてのニューラルネットワークの処理に対応しているわけではないようです。

オート般若心経
効率よく徳を稼げるように,般若心経の一文字目を書くと同じ筆跡で続きを書いてくれるようにした
上記では画像生成のモデルをONNX.jsで動かしていますが,ConvTransposeが使えませんでした。画像生成ではConvTransposeはよく使われるので,早く対応してほしいところです。concatを駆使すれば制限はありますが通常のconv2Dで等価な処理を書くことは出来ます。
処理速度は書き方が悪いのかcpuモードではあまり早くはないです。webGLで動かせるみたいなのですが,今ところwindows上のブラウザ限定らしく処理速度は未確認です。

あとモデルの処理内容が丸わかりなONNXファイルをユーザーにダウンロードさせることになり,javascriptのコードも丸見えなので,基本すべて公開してしまうことになるのも場合によってはデメリットかもしれません。

google colabを使う

Colaboratory上で学習したモデルをngrokを使って簡易デモする
これは常時公開するというよりは,ちょっと作ってみたのを友人に見せたいってときに使える方法です。基本的には「サーバーで処理させる」と一緒なのですが,ngrokを使うことで一時的に外部からcolab上のプロセスにアクセスできるURLを発行することが出来ます。

メリット

GPUが使えるので結構重い処理でも問題なく出来てしまう。

デメリット

colabのノートブック自体が最大12時間しか連続して動かせませんし,ngrokの制限もあるので,不特定多数の人に公開する目的では使えません。

websocketを使ったリアルタイムデモ

モデルの公開方法とはすこし違いますが,colabのGPUを使ってカメラ映像をリアルタイムに処理する方法を見つけたのでメモしておきます。

colab.ipynb
!wget https://bin.equinox.io/c/4VmDzA7iaHb/ngrok-stable-linux-amd64.zip
!unzip ngrok-stable-linux-amd64.zip
!pip install bottle
!pip install bottle_websocket
!pip install gevent

get_ipython().system_raw('./ngrok http 6006 &')
! curl -s http://localhost:4040/api/tunnels | python3 -c \
    "import sys, json; print(json.load(sys.stdin)['tunnels'][0]['public_url'])"

import numpy as np
import cv2
import json
import bottle
import gevent
from bottle.ext.websocket import GeventWebSocketServer
from bottle.ext.websocket import websocket

app = bottle.Bottle()
@app.route('/', apply=[websocket])
def wsbin(ws):
    while True:
        body = ws.receive()
        if not body:
            break

        #文字列を画像にデコード
        data_np = np.frombuffer(body, dtype='uint8')
        decimg = cv2.imdecode(data_np, 3)

        #############何かの処理###############
        out_img = decimg
        #############何かの処理###############

        #文字列にエンコード
        _, encimg = cv2.imencode(".jpg", out_img, [int(cv2.IMWRITE_JPEG_QUALITY), 50])
        img_str = encimg.tostring()
        ws.send(img_str)

app.run(host='0.0.0.0', port=6006, server=GeventWebSocketServer)

上記をcolab側で動かすと
https://XXXXXXXXXX.ngrok.io
みたいなURLが発行されるので,httpsをwssに変えて下記のカメラを使う側のPCのコードの最初に貼り付けます。

mypc.py
import cv2
import numpy as np
import websocket

#WSするURL設定
url = 'wss://XXXXXXXXXX.ngrok.io'
ws = websocket.create_connection(url)
# 画像サイズが大きいと送受信に時間がかかるので適宜サイズ調整したほうが良い
cap = cv2.VideoCapture(0) 
while True: 
    try:
        #カメラ画像取得
        ret, frame = cap.read()
        #反転する場合
        if 1:frame = np.fliplr(frame)

        #文字列にエンコード
        _, encimg = cv2.imencode(".jpg", frame, [int(cv2.IMWRITE_JPEG_QUALITY), 50])
        img_str = encimg.tostring()

        #送受信
        ws.send_binary(img_str)
        body = ws.recv()

        #文字列を画像にデコード
        data_np = np.frombuffer(body, dtype='uint8')
        decimg = cv2.imdecode(data_np, 3)

        cv2.imshow("window", decimg)
        k = cv2.waitKey(1) & 0xFF
        #qで終了
        if k == ord('q'):
            break
    except:
        print('Error')

これを使ってUGATITというアニメ画像に変換するモデルを動かしてみた結果です。

まとめ

開発したディープラーニングモデルを公開する方法をまとめてみました。
これらは私が今までやってみた方法ですので他のやりかたもあるかもしれません。
他のよい方法があればコメント頂けたらと思います。

37
37
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
37
37