LoginSignup
2
2

More than 3 years have passed since last update.

SageMaker で MXNet モデルをデプロイ・推論する際のデバッグ

Last updated at Posted at 2020-03-28

TL;DR

  • SageMaker で MXNet モデルをデプロイして推論する際は、try & except で囲んで、エラーを返すようにする
  • 何もしないとエラーが出ても python のエラーメッセージが出ず、ただ 500 Internal Server Error がでるので、デバッグがつらい

MXNet モデル用の推論コード

モデルを読み込むための model_fn が必須で、これに加えて前処理、予測、後処理を担当する input_fn, predict_fn, output_fn をオプションで実装できます (doc)。input_fn, predict_fn, output_fn を一つの transform_fn で書くこともできます (doc)。個人的には transform_fn のほうが好み。

先日こんな推論用のコードを書きました。これはここの日本語サンプルを参考にしたものです。

def model_fn(model_dir):
    net = gluon.nn.SymbolBlock.imports(model_dir+ '/model/detector-symbol.json', 
                                       ['data'], 
                                       model_dir+ '/model/detector-0000.params')
    return net


def transform_fn(net, data, input_content_type, output_content_type):
    data = json.loads(data)
    nda = mx.nd.array(data)
    class_IDs, scores, bounding_boxs = net(nda)

    output_list = []
    for i in range(class_IDs.shape[0]):
        exist_IDs = np.where(class_IDs[i,:,0].asnumpy() >= 0)
        output = {
            "class_ids": class_IDs[i,exist_IDs].asnumpy().tolist(),
            "scores": scores[i,exist_IDs].asnumpy().tolist(),
            "bbox": bounding_boxs[i,exist_IDs].asnumpy().tolist()
        }
        output_list.append(output)

    response_body = json.dumps(output_list)
    return response_body, output_content_type

これで推論してみたんですが、うんともすんとも言わない。実はこのスクリプトの最初のところでimport jsonし忘れていたんです。しかし、ログはInternal Server Error を表す500しか残っていません。ローカルモードで推論を実行するとこんな感じです。

algo-1-yixtf_1  | 2020-03-28 15:03:25,278 [INFO ] W-9001-model ACCESS_LOG - /172.18.0.1:40518 "POST /invocations HTTP/1.1" 500 414

いやさすがにこれだけでは、import 忘れに気づかないです。MXNet の推論用のコンテナが MMS に切り替わってからこんな感じです。もとは python の API サーバだったので、python のエラーはそのままログに出ていましたね。どうすればいいでしょう。

エラーを明示的に返すようにしましょう

transform_fnを try & except で囲んで、エラーの内容もしっかり返せるようにしておきましょう。

def transform_fn(net, data, input_content_type, output_content_type):
    try:
        data = json.loads(data)
        nda = mx.nd.array(data)
        class_IDs, scores, bounding_boxs = net(nda)

        output_list = []
        for i in range(class_IDs.shape[0]):
            exist_IDs = np.where(class_IDs[i,:,0].asnumpy() >= 0)
            output = {
                "class_ids": class_IDs[i,exist_IDs].asnumpy().tolist(),
                "scores": scores[i,exist_IDs].asnumpy().tolist(),
                "bbox": bounding_boxs[i,exist_IDs].asnumpy().tolist()
            }
            output_list.append(output)

        response_body = json.dumps(output_list)
        return response_body, output_content_type

    except Exception as e:
        print(e)
        return json.dumps(str(e)), output_content_type

ログはこんな感じで出力されます。

algo-1-yixtf_1  | 2020-03-28 15:03:25,278 [INFO ] W-9001-model-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - name 'json' is not defined
algo-1-yixtf_1  | 2020-03-28 15:03:25,278 [INFO ] W-9001-model ACCESS_LOG - /172.18.0.1:40518 "POST /invocations HTTP/1.1" 500 414

確かに name 'json' is not defined と言ってきているので、ピンと来る方は json インポートし忘れたことに気づきますね。

json.dumps(str(e)) でエラーの内容をクライアントに返しているので、推論をリクエストした側でもエラー内容を受け取ることができますよ。エラーメッセージの表示などに利用できると思います。

2
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
2