TL;DR
- SageMaker で MXNet モデルをデプロイして推論する際は、try & except で囲んで、エラーを返すようにする
- 何もしないとエラーが出ても python のエラーメッセージが出ず、ただ 500 Internal Server Error がでるので、デバッグがつらい
MXNet モデル用の推論コード
モデルを読み込むための model_fn
が必須で、これに加えて前処理、予測、後処理を担当する input_fn, predict_fn, output_fn
をオプションで実装できます (doc)。input_fn, predict_fn, output_fn
を一つの transform_fn
で書くこともできます (doc)。個人的には transform_fn
のほうが好み。
先日こんな推論用のコードを書きました。これはここの日本語サンプルを参考にしたものです。
def model_fn(model_dir):
net = gluon.nn.SymbolBlock.imports(model_dir+ '/model/detector-symbol.json',
['data'],
model_dir+ '/model/detector-0000.params')
return net
def transform_fn(net, data, input_content_type, output_content_type):
data = json.loads(data)
nda = mx.nd.array(data)
class_IDs, scores, bounding_boxs = net(nda)
output_list = []
for i in range(class_IDs.shape[0]):
exist_IDs = np.where(class_IDs[i,:,0].asnumpy() >= 0)
output = {
"class_ids": class_IDs[i,exist_IDs].asnumpy().tolist(),
"scores": scores[i,exist_IDs].asnumpy().tolist(),
"bbox": bounding_boxs[i,exist_IDs].asnumpy().tolist()
}
output_list.append(output)
response_body = json.dumps(output_list)
return response_body, output_content_type
これで推論してみたんですが、うんともすんとも言わない。実はこのスクリプトの最初のところでimport json
し忘れていたんです。しかし、ログはInternal Server Error を表す500しか残っていません。ローカルモードで推論を実行するとこんな感じです。
algo-1-yixtf_1 | 2020-03-28 15:03:25,278 [INFO ] W-9001-model ACCESS_LOG - /172.18.0.1:40518 "POST /invocations HTTP/1.1" 500 414
いやさすがにこれだけでは、import 忘れに気づかないです。MXNet の推論用のコンテナが MMS に切り替わってからこんな感じです。もとは python の API サーバだったので、python のエラーはそのままログに出ていましたね。どうすればいいでしょう。
エラーを明示的に返すようにしましょう
transform_fn
を try & except で囲んで、エラーの内容もしっかり返せるようにしておきましょう。
def transform_fn(net, data, input_content_type, output_content_type):
try:
data = json.loads(data)
nda = mx.nd.array(data)
class_IDs, scores, bounding_boxs = net(nda)
output_list = []
for i in range(class_IDs.shape[0]):
exist_IDs = np.where(class_IDs[i,:,0].asnumpy() >= 0)
output = {
"class_ids": class_IDs[i,exist_IDs].asnumpy().tolist(),
"scores": scores[i,exist_IDs].asnumpy().tolist(),
"bbox": bounding_boxs[i,exist_IDs].asnumpy().tolist()
}
output_list.append(output)
response_body = json.dumps(output_list)
return response_body, output_content_type
except Exception as e:
print(e)
return json.dumps(str(e)), output_content_type
ログはこんな感じで出力されます。
algo-1-yixtf_1 | 2020-03-28 15:03:25,278 [INFO ] W-9001-model-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - name 'json' is not defined
algo-1-yixtf_1 | 2020-03-28 15:03:25,278 [INFO ] W-9001-model ACCESS_LOG - /172.18.0.1:40518 "POST /invocations HTTP/1.1" 500 414
確かに name 'json' is not defined
と言ってきているので、ピンと来る方は json インポートし忘れたことに気づきますね。
json.dumps(str(e))
でエラーの内容をクライアントに返しているので、推論をリクエストした側でもエラー内容を受け取ることができますよ。エラーメッセージの表示などに利用できると思います。