LoginSignup
7
1

More than 1 year has passed since last update.

NVIDIA Triton Inference Serverで推論サーバーを作成してみた。

Last updated at Posted at 2022-12-07

Triton Inference Serverで推論サーバーを作成してみた。

はじめに

こんにちは! @nyanchu22 です。
この記事は、ラクス Advent Calendar 2022の8日目です。
普段は機械学習を使用したプロダクトでMLOps領域を担当しています。
最近気になっているNVIDIA Triton Inference Serverを試してみたので、
NVIDIA Triton Inference Serverについて書きたいと思います。

忙しい人/MLOpsあまり興味ない人/実装興味ない人のためのまとめ

  • Trition Inference Serverは存在するほとんどの機械学習フレームワークに対応
  • オンライン推測及びバッチ推論を高速で実行できる
  • Python Backendを使用してpythonのコードでも記載できる。
  • cloud(AWS,GCP)を含むさまざまな環境で動作する。

複数の機械学習モデルを運用するプロダクトは一度検討する価値があるよ!!

内容

  • Triton Inference Serverとは
  • Triton Inference Serverに興味を持った理由
  • 実装

Triton Inference Serverとは

Triton Inference Serverは機械学習モデルの標準化および、高速でスケーラブルな推論環境を提供するOOSです。

全体図:github repositoryより引用
image.png

Triton Inference Serverの特徴

  • ほとんどの機械学習フレームワークに対応(1番のメリットと言っても過言ではない)

    • TensorFlow
    • TensorRT™
    • PyTorch
    • MXNet
    • Python
    • ONNX
    • XGBoost
    • scikit-learn
    • RandomForest
    • OpenVINO
    • custom C++
  • 高性能推論サーバ

    • CPUとGPUの両方に最適化
    • dynamic batching(動的バッチ処理)に対応
    • 複数モデルの同時実行が可能
    • Model Ensemble(パイプラインのDAG記述が可能)
    • gRPC, HTTP/REST APIをサポート
    • GPUやメモリ、レイテンシーなどの使用状況や各種メトリクスの監視が可能
  • さまざまな環境で稼働可能(コンテナ化)

    • k8s
    • Google Cloud Platform
    • Amazon Web Service

ちなみにONNX RuntimeもTriton Inference Serverに近い思想を持っていると思います。
さまざまなフレームワークの機械学習モデルをONNX形式に変換することでONNX Runtime上で稼働することが可能となります。しかし、変換できないモデルが一定数存在しているなど、さまざまな機械学習フレームワークを使いたいという点においてはTriton Inference Serverの方がより適任であると考えています。

Triton Inference Serverに興味を持った理由

複数の機械学習フレームワークを使用したい

複数の機械学習フレームワークを使って推論システムを作成したいが、フレームワーク毎に合わせて推論サーバーを構築するようなコストがかかることはしたくない。現状使用している以外の機械学習フレームワークを使いたいという未来の要望にも応えられるようにしたい。

複数モデルの複雑なパイプラインを構築したい

複数モデルを使用する推論システムにおいて、モデルと前処理を繋ぎ合わせ一つのパイプラインとして実行したい。一つのパイプラインにまとめることのメリットは、クライアント側で前処理リクエストと推論リクエストを分けて管理する必要がなくなる点と、前処理から推論へのデータ受け渡しのオーバーヘッドがなくなることなどがある。

github repositoryより引用
下記の図の例では、画像の入力からセグメンテーションと分類を行う処理を一つのパイプラインとして実行している。
image.png

バッチ推論処理を楽に構築したい。

pythonのフレームワークを用いて推論サーバーを構築する際は、バッチ予測処理用のエンドポイントをリアルタイム処理用のエンドポイントとは別で作成する必要が出てくるケースがあると思います。バッチ予測処理用のエンドポイントはリアルタイム処理とは違い、大量のデータを捌く為の実装が必要となりコストがかかるケースがあります。Triton Inference Serverには動的バッチ処理の機能が備わっており、個々の推論要求をまとめることで簡単にスループットを向上させることが可能である為、バッチ予測処理用のエンドポイントを手軽に用意できそう。

実装してみた。

実装環境

MacOS Monetery version 12.5
Docker version 20.10.5
python 3.10.8
Poetry 1.2.2

python package version

tensorflow 2.11.0
Pillow 9.3.0

想定状況(自分の想像です。)

**Triton Inference Serverを構築して問題が解決できそうなシチュエーションを考えてみました。

複数の機械学習モデルを使用するアプリケーション。
モデル毎に使用する機械学習フレームワークが異なっている為、機械学習モデル毎に異なるエンドポイントを作成し、マイクロサービスアーキテクチャを採用している。

このようなシチュエーションにおいてTriton Inference Serverを導入したらどう変化するか考えてみます。

なんとなくのイメージ図
micro_service_architecture_without_triton.png

*table_data_model: categorical変数(ex.年齢や性別情報)を使って何かしらの予測をするモデル
*image_model: 画像を使用して何かしらの予測をするモデル。
*multi_modal_model: table_data_model, image_modelの予測結果を使ってサービスで提供する何かしらの予測を行うモデル

問題点

  1. 機械学習モデル毎にサーバーが増えることでメンテナンスコストが高い
  2. マイクロサービスアーキテクチャを採用することで、データ受け渡しのオーバーヘッドが生じておりパフォーマンスに影響を及ぼしている。

目指したい理想郷

  1. 異なる機械学習フレームワーク間の違いを吸収して、同じサーバー上で運用したい。
  2. マイクロサービスアーキテクチャは運用コストが高く、パフォーマンスに影響を及ぼしているので辞めたい。
  3. 機械学習モデルの予測や前処理、後処理の一連の流れを簡単に管理したい。

なんとなくのイメージ図
monolith_with_triton.png

コード実装

triton Inference Serverの実装をやっていきます。
Python Backendを使用したensemble_modelの参考としてこちらのgithubコードを参照させて頂きました。

Python Backend
The Triton backend for Python. The goal of Python backend is to let you serve models written in > Python by Triton Inference Server without having to write any C++ code.

Python BackendとはPythonのコードをTriton Inference Serverで動作させる為の機能です。

ディレクトリ構成図

.
├── model_repository
│   ├── ensemble_multi_modal_model
│   │   ├── 1
│   │   └── config.pbtxt
│   ├── image_model
│   │   ├── 1
│   │   │   └── model
│   │   ├── config.pbtxt
│   │   └── image_model_exporter.py
│   ├── image_preprocess
│   │   ├── 1
│   │   │   └── model.py
│   │   └── config.pbtxt
│   ├── multi_modal_model
│   │   ├── 1
│   │   │   └── model
│   │   ├── config.pbtxt
│   │   └── multi_modal_model_exporter.py
│   ├── table_data_model
│   │   ├── 1
│   │   │   └── model
│   │   ├── config.pbtxt
│   │   └── table_data_model_exporter.py
│   └── table_data_preprocess
│       ├── 1
│       │   └── model.py
│       └── config.pbtxt
├── multi_modal_client.py
└── triton-server.Dockerfile

ディレクトリ構成の説明

├── model_repository
│   ├── MODEL_NAME -> モデルの名前になります。モデルの名前は処理を呼び出すときに使用されます。
│   │   ├── VERSION -> モデルのバージョンを記載します。
│   │   │   └── model -> モデルやPython backendの実装を配置
│   │   ├── config.pbtxt -> モデルの入力や出力、処理に関する設定情報をこちらに記載します。
│   │   └── image_model_exporter.py -> モデルを生成する処理を記載するファイル。
...
├── multi_modal_client.py -> 予測リクエストを作成するクライアント
└── triton-server.Dockerfile -> 必要なパッケージを含んだdocker imageをビルドする為のDockerfile

Docker Imageのビルド

リリースノートを参照して、環境に合うdocker imageをbaseにして必要なpackageをinstallしたdocker imageを最初にビルドします。

triton-server.Dockerfile
FROM nvcr.io/nvidia/tritonserver:22.11-py3

RUN pip install tensorflow pillow # 機械学習モデルや前処理に必要なパッケージをダウンロード
$ docker build -f triton-server.Dockerfile -t triton-server .

TableDataModelを作成

ここでは簡単なテーブルデータを使用するモデルを作成し、ONNX形式に変換してみます。

Trition Inference Serverの検証が目的の為、機械学習モデルは特に意味のない、学習もしないモデルを作成して使用します

table_data_model_exporter.py
from tensorflow import keras

class TableDataModel(keras.Model):
    table_data_cols = {"categorical_col_1": 3, "categorical_col_2": 3}
    def __init__(self):
        inputs = []
        for name, vocab_num in self.table_data_cols.items():
            inputs.append(keras.layers.Input(shape=(vocab_num), name=name))
        concatenate= keras.layers.Concatenate(name="concatenate_layer")(inputs)
        x = keras.layers.Dense(64, activation="relu", name="dense")(concatenate)
        output = keras.layers.Dense(10, name="table_model_output")(x)

        super(TableDataModel, self).__init__(
            inputs=inputs, outputs=output
        )

if __name__ == "__main__":
    model = TableDataModel()
    model.save("./tmp/model/)
cd model_repository/table_data_model
python table_data_model_exporter.py
python -m tf2onnx.convert --saved-model ./tmp/model/ --output ./1/model/model.onnx  --verbose

ONNX形式のモデルはNetronというGUIツールで可視化することが可能です。
ONNXを使用する際は必須のツールみたいです。

Netronで可視化したモデル
model.onnx.png

TableDataModelのconfig.pbtxt
config.pbtxt
name: "table_data_model"
platform: "onnxruntime_onnx"
input [
{
    name: "categorical_col_1"
    data_type: TYPE_FP32
    dims: [ 3 ]
},
{
    name: "categorical_col_2"
    data_type: TYPE_FP32
    dims: [ 3 ]
}
]
output[
{
    name: "table_model_output"
    data_type: TYPE_FP32
    dims: [ 10 ]
}
]
default_model_filename: "model"
table_data_preprocess
model_repository/table_data_preprocess/1/model.py
import numpy as np
import json
import triton_python_backend_utils as pb_utils


class TritonPythonModel:
    def initialize(self, args):
        self.model_config = model_config = json.loads(args['model_config'])

    def execute(self, requests):
        output0_dtype = np.float32
        output1_dtype = np.float32

        responses = []

        def preprocess_one_hot(value):
            vocabs = [1, 2, 3]
            return np.array([1 if vocab == value else 0 for vocab in vocabs])

        for request in requests:
            in_0 = pb_utils.get_input_tensor_by_name(request, "table_data_input_1")
            in_1 = pb_utils.get_input_tensor_by_name(request, "table_data_input_2")
            table_data_0 = in_0.as_numpy()
            table_data_1 = in_1.as_numpy()

            table_data_0_out = np.expand_dims(preprocess_one_hot(table_data_0), axis=0)
            table_data_1_out = np.expand_dims(preprocess_one_hot(table_data_1), axis=0)

            out_tensor_0 = pb_utils.Tensor("table_data_output_1", table_data_0_out.astype(output0_dtype))
            out_tensor_1 = pb_utils.Tensor("table_data_output_2", table_data_1_out.astype(output1_dtype))

            inference_response = pb_utils.InferenceResponse(
                output_tensors=[out_tensor_0, out_tensor_1,])
            responses.append(inference_response)

        return responses

    def finalize(self):
        print('ImagePreprocess Cleaning up...')
table_data_preprocessのconfig.pbtxt
config.pbtxt
name: "table_data_preprocess"
backend: "python"
input [
{
    name: "table_data_input_1"
    data_type: TYPE_INT8
    dims: [ 1 ]
},
{
    name: "table_data_input_2"
    data_type: TYPE_INT8
    dims: [ 1 ]
}
]

output [
{
    name: "table_data_output_1"
    data_type: TYPE_FP32
    dims: [ 3 ]
},
{
    name: "table_data_output_2"
    data_type: TYPE_FP32
    dims: [ 3 ]
}
]

image_modelの実装

image_model_exporter.py
from tensorflow import keras
from tensorflow.keras.layers import (
    Conv2D,
    Flatten,
)


class MultiModelModel(keras.Model):
    def __init__(self):
        image_input = keras.layers.Input(shape=(3, 256, 256), name="image")
        conved_img = Conv2D(filters=3, kernel_size=(3, 3), name="conv_1")(image_input)
        flatten_img = Flatten(name="flatten_image")(conved_img)
        densed_img = keras.layers.Dense(64, activation="relu", name="img_dense")(flatten_img)

        output = keras.layers.Dense(10, name="output")(densed_img)

        super(MultiModelModel, self).__init__(
            inputs=[image_input], outputs=output
        )

if __name__ == "__main__":
    model = MultiModelModel()
    model.save("./1/model")
config.pbtxt
name: "image_model"
platform: "tensorflow_savedmodel"
max_batch_size: 256
input [
{
    name: "image"
    data_type: TYPE_FP32
    dims: [ 3, 256, 256 ]
}
]
output [
{
    name: "output"
    data_type: TYPE_FP32
    dims: [ 10 ]
}
]

instance_group [
  {
    count: 1
    kind: KIND_GPU # GPUを使用
  }
]
default_model_filename: "model"

image_preprocessの実装

model.py
import numpy as np
import json
import io
import triton_python_backend_utils as pb_utils
import numpy as np
from PIL import Image


class TritonPythonModel:
    def initialize(self, args):
        self.model_config = json.loads(args['model_config'])

        output0_config = pb_utils.get_output_config_by_name(
            self.model_config, "image_preprocess_output")

        self.output0_dtype = pb_utils.triton_string_to_numpy(
            output0_config['data_type'])

    def execute(self, requests):
        output0_dtype = self.output0_dtype

        responses = []

        for request in requests:
            in_0 = pb_utils.get_input_tensor_by_name(request, "image_preprocess_input")
            img = in_0.as_numpy()

            image = Image.open(io.BytesIO(img.tobytes()))
            resized_image = image.resize((256, 256))
            img_channel_first = np.transpose(resized_image, (2, 0, 1))
            img_out = np.array(img_channel_first)
            expand_dim_img_out = np.expand_dims(img_out, axis=0)
            out_tensor_0 = pb_utils.Tensor("image_preprocess_output",
                                           expand_dim_img_out.astype(output0_dtype))

            inference_response = pb_utils.InferenceResponse(
                output_tensors=[out_tensor_0])
            responses.append(inference_response)

        return responses

    def finalize(self):
        print('ImagePreprocess Cleaning up...')
config.pbtxt
name: "image_preprocess"
backend: "python"
max_batch_size: 256
input [
{
    name: "image_preprocess_input"
    data_type: TYPE_UINT8
    dims: [ -1 ]
}
]

output [
{
    name: "image_preprocess_output"
    data_type: TYPE_FP32
    dims: [ 3, 256, 256 ]
}
]

instance_group [
  {
    count: 2 # 並列実行処理
    kind: KIND_CPU
  }
]

上記のinstance_groupは並列実行の設定となっており、2つのインスタンスのpython_backendを並列に実行できるようになります。これによって大量のリクエストが来てもスループットの低下を防ぐことができます。

並列処理のイメージ
concurrency_image_preprocess.png

multi_modal_modelの実装

multi_modal_model_exporter.py
from tensorflow import keras


class MultiModelModel(keras.Model):
    def __init__(self):
        image_input = keras.layers.Input(shape=(10,), name="image")
        table_data_input = keras.layers.Input(shape=(10,), name="table_data")
        concatenated = keras.layers.Concatenate(name="concatenate_layer")([image_input] + [table_data_input])

        x = keras.layers.Dense(3, activation="relu", name="dense")(concatenated)
        output = keras.layers.Dense(1, name="output")(x)

        super(MultiModelModel, self).__init__(
            inputs=[table_data_input] + [image_input], outputs=output
        )

if __name__ == "__main__":
    model = MultiModelModel()
    model.save("./1/model/")
config.pbtxt
name: "multi_modal_model"
platform: "tensorflow_savedmodel"
max_batch_size: 256

# dynamic bathingを有効化する。
dynamic_batching {
    preferred_batch_size: [ 4 ] # まとめて予測したいバッチサイズ
    max_queue_delay_microseconds: 100000 # 0.1s
}

input [
{
    name: "image"
    data_type: TYPE_FP32
    dims: [ 10 ]
},
{
    name: "table_data"
    data_type: TYPE_FP32
    dims: [ 10 ]
}
]
output [
{
    name: "output"
    data_type: TYPE_FP32
    dims: [ 1 ]
}
]

instance_group [
  {
    count: 1
    kind: KIND_GPU
  }
]
default_model_filename: "model"

multi_modal_modelでは動的バッチ処理を有効化しています。
これによりmax_queue_deply_microsecondsの時間内に到着したリクエストはpreferred_batch_sizeの範囲で一つにまとめられて同時に処理されるようになります。
動的バッチ処理のイメージは下記になります。

記事より参照
image.png

ensemble_multi_modal_modelの設定

実際に実装した

  • table_data_model
  • table_data_preprocess
  • image_model
  • image_preprocess
  • multi_modal_model

の各ステップをつなぎ合わせてPipelineとするDAGの実装を行なっていきたいと思います。
やることは簡単でconfig.pbtxtを一つ用意するだけです。

model_repository/ensemble_multi_modal_model/config.pbtxt
name: "ensemble_multi_modal_model"
platform: "ensemble"
max_batch_size: 256

input [
{
    name: "table_data_input_1"
    data_type: TYPE_INT8
    dims: [ 1 ]
},
{
    name: "table_data_input_2"
    data_type: TYPE_INT8
    dims: [ 1 ]
},
{
    name: "image_preprocess_input"
    data_type: TYPE_UINT8
    dims: [ -1 ]
}
]
output [
{
    name: "multi_modal_output"
    data_type: TYPE_FP32
    dims: [ 1 ]
}
]
ensemble_scheduling {
    step [
        # table data modal
        {
            model_name: "table_data_preprocess"
            model_version: 1
            input_map {
                key: "table_data_input_1"
                value: "table_data_input_1"
            }
            input_map {
                key: "table_data_input_2"
                value: "table_data_input_2"
            }
            output_map {
                key: "table_data_output_1"
                value: "preprocessed_table_data_1"
            }
            output_map {
                key: "table_data_output_2"
                value: "preprocessed_table_data_2"
            }
        },
        {
            model_name: "table_data_model"
            model_version: 1
            input_map {
                key: "categorical_col_1"
                value: "preprocessed_table_data_1"
            }
            input_map {
                key: "categorical_col_2"
                value: "preprocessed_table_data_2"
            }
            output_map {
                key: "table_model_output"
                value: "table_data_model_output"
            }
        },

        # image modal
        {
            model_name: "image_preprocess"
            model_version: 1
            input_map {
                key: "image_preprocess_input"
                value: "image_preprocess_input"
            }
            output_map {
                key: "image_preprocess_output"
                value: "preprocessed_image"
            }
        },
        {
            model_name: "image_model"
            model_version: 1
            input_map {
                key: "image"
                value: "preprocessed_image"
            }
            output_map {
                key: "output"
                value: "image_model_output"
            }
        },

        # multi modal model
        {
        model_name: "multi_modal_model"
        model_version: 1
        input_map {
            key: "image"
            value: "image_model_output"
        }
        input_map {
            key: "table_data"
            value: "table_data_model_output"
        }
        output_map {
            key: "output"
            value: "multi_modal_output"
        }
        }
    ]
}

サーバーの立ち上げ

モデル、前処理の実装とensemble modelの設定が済んだのでサーバーを立ち上げていきます。
最初に作成したdocker imageを使用し、ローカルのモデルコードをマウントして起動させてあげるだけで大丈夫です!
楽ちんですね。

docker run --shm-size=1g -it --rm -p8000:8000 -p8001:8001 -p8002:8002 -v $(pwd)/model_repository/models/ triton-server tritonserver --model-repository=/models

Triton Inference Server上に各モデルが展開されるログ情報が流れてくると思いますが、下記のように全てのモデルステータスがREADYとなり、最後に3つのサーバがスタートしていることが表示されれば大丈夫です。

.
.
.
+----------------------------+---------+--------+
| Model                      | Version | Status |
+----------------------------+---------+--------+
| ensemble_multi_modal_model | 1       | READY  |
| image_model                | 1       | READY  |
| image_preprocess           | 1       | READY  |
| multi_modal_model          | 1       | READY  |
| table_data_model           | 1       | READY  |
| table_data_preprocess      | 1       | READY  |
+----------------------------+---------+--------+
.
.
.

Started GRPCInferenceService at 0.0.0.0:8001
Started HTTPService at 0.0.0.0:8000
Started Metrics Service at 0.0.0.0:8002

クライアントの実行

サーバーが無事に立ち上がったのでクライアントを実行していきたいと思います。
Tritonで公式に提供しているクライアントライブラリを利用します。
ビルド済みバイナリ、pip経由、ビルド済みdockerコンテナイメージが存在していますが、
今回はコンテナを利用してクライアントを実行したいと思います。

client.pyの実装

client.py
import numpy as np
import tritonclient.grpc as triton_client
import argparse


def load_image(img_path: str):
    """
    Loads an encoded image as an array of bytes.
    """
    return np.fromfile(img_path, dtype='uint8')


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name",
                        type=str,
                        required=False,
                        default="ensemble_multi_modal_model",
                        help="Model name")
    parser.add_argument("--image",
                        type=str,
                        required=False,
                        help="Path to the image")
    parser.add_argument("--url",
                        type=str,
                        required=False,
                        default="localhost:8001",
                        help="Inference server URL. Default is localhost:8001.")
    args = parser.parse_args()

    client = triton_client.InferenceServerClient(url=args.url)

    inputs = []
    outputs = []

    # table_data
    table_data_input1_name = "table_data_input_1"
    table_data_input2_name = "table_data_input_2"
    inputs.append(triton_client.InferInput(table_data_input1_name, (1, 1), "INT8"))
    inputs.append(triton_client.InferInput(table_data_input2_name, (1, 1), "INT8"))
    inputs[0].set_data_from_numpy(np.array([[1]]).astype(np.int8))
    inputs[1].set_data_from_numpy(np.array([[2]]).astype(np.int8))

    # image
    input_name = "image_preprocess_input"
    image_data = load_image(args.image)
    image_data = np.expand_dims(image_data, axis=0)
    inputs.append(triton_client.InferInput(input_name, image_data.shape, "UINT8"))
    inputs[2].set_data_from_numpy(image_data)

    # output
    output_name = "multi_modal_output"
    outputs.append(triton_client.InferRequestedOutput(output_name))

    results = client.infer(model_name=args.model_name,
                                  inputs=inputs,
                                  outputs=outputs)

    output = results.as_numpy(output_name)
    print(output)

githubのサンプルコードからテスト用画像をダウンロードします。

$ wget https://raw.githubusercontent.com/triton-inference-server/server/main/qa/images/mug.jpg -O "mug.jpg"

client docker containerを起動して予測を行なってみます。

$ docker run --rm --net=host -v $(pwd):/workspace/ nvcr.io/nvidia/tritonserver:22.11-py3-sdk python multi_modal_client.py --image mug.jpg
$ [[-64.25122]]

無事に予測が成功して何かしらの予測値が返ってくることが確認できました。

結論

Triton Inference Serverを使用して簡単な実装を行なってみました。
意外に簡単に実装できたなと感じています。
もっと複雑なアーキテクチャのアプリケーションでは実装がより大変になるのかもしれませんね。
本番の環境においてどれほど有能かはまだわかりませんが、下記の3つの点で大変優れた機械学習推論サーバーだと感じました。

  • ほとんどの機械学習フレームワークに対応
  • 高性能推論サーバ
    • 並列実行
    • 動的バッチ処理
    • model ensemble
  • さまざまな環境で稼働可能

参考

  1. https://medium.com/nvidiajapan/gpu-for-inference-easy-deploy-by-triton-inference-server-fd2980514af2
  2. https://github.com/triton-inference-server/python_backend/tree/main/examples/preprocessing
7
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
7
1