LoginSignup
0
0

onnx runtime を使って GPT を c++ から使ってみる on Windows

Last updated at Posted at 2023-10-10

なぜもう一度?

以前にrinna GPTモデルをonnxに変換し、C++から使ってみるの記事で、onnx モデルを C++ から使ったのですが、このときは Windows.AI.MachineLearning API を使いました。
ところが、rinna/japanese-gpt-neox-3.6b 等の大きなモデルを optimum-cli で export すると、*.onnx のほかに、*.onnx_data という外部データが出力されます。どうも、このファイルの読み込みが Windows.AI.MachineLearning API では失敗し、回避策が見当たりませんでした。

エラーログにはこんなのが出ます。

onnxruntime::Initializer::Initializer status.IsOK() was false. ReadExternalRawData() failed: initializer.cc:23 onnxruntime::Initializer::ReadExternalRawData model_path must not be empty. Ensure that a path is provided when the model is created or loaded.'.

邪推するに、*.onnx_data を読み込むときに、必要な設定がなされていない感じです。

おそらく、Windows.AI.MachineLearning は、onnx runtime を OS 標準にして wrapper API を準備したものだと思われるので…

本家 onnx runtime 使う

vs だと、導入は簡単で

  1. nuget から Microsoft.ML.OnnxRuntime をインストール
  2. onnxruntime_cxx_api.h をインクルード
  3. onnxruntime.lib をリンク

だけでした。

モデル input/output を調べましょう

onnx モデルとしては、同じなので念のためですが、モデルの input/output を確認しました。

// モデルの読み込みとセッションの作成
Ort::Env env;
Ort::SessionOptions sessionOptions;
m_session = Ort::Session(env, modelFileName.c_str(), sessionOptions);

// 入力データの確認
for (size_t i = 0; i < m_session.GetInputCount(); ++i) {
    auto inputName = m_session.GetInputNameAllocated(i, alloc);
    auto inputType = m_session.GetInputTypeInfo(i);
    auto shapeInfo = inputType.GetTensorTypeAndShapeInfo();
    auto shape = shapeInfo.GetShape();
    auto elementType = shapeInfo.GetElementType();
}

// 出力データの確認
for (size_t i = 0; i < m_session.GetOutputCount(); ++i) {
    auto outputName = m_session.GetOutputNameAllocated(i, alloc);
    auto outputType = m_session.GetOutputTypeInfo(i);
    auto shapeInfo = outputType.GetTensorTypeAndShapeInfo();
    auto shape = shapeInfo.GetShape();
    auto elementType = shapeInfo.GetElementType();
}

結果は以下となります。 -1 は入力時の可変サイズということでしょう。
入力

name:input_ids: shape:[-1,-1], type:INT64
name:attention_mask: shape:[-1,-1], type:INT64

出力 (他にもありますが、以下だけでいいかと思われます)

name:logits: shape:[-1,-1,32000], type:FLOAT

onnx 使ってみる

ということで、モデルを実行してみましょう。
あまりきれいではないサンプルコードはここにあります。

モデルの読み込み

Ort::Env env;
Ort::SessionOptions sessionOptions;
m_session = Ort::Session(env, modelFileName.c_str(), sessionOptions);

入力として std::vector tokens; でトークン列が来るとして、入力データの作成とバインド

// attention mask ソースデータの作成
const auto tokenSize = static_cast<int64_t>(tokens.size());
std::vector<int64_t> attentionMask(tokens.size(), 1LL);

// Ort::Value Tensor の作成
auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
auto inputShape = std::array<int64_t, 2> { 1, tokenSize }; 
auto idTensor = Ort::Value::CreateTensor<int64_t>(memory_info, tokens.data(), tokens.size(), inputShape.data(), inputShape.size());
auto maskTensor = Ort::Value::CreateTensor<int64_t>(memory_info, attentionMask.data(), attentionMask.size(), inputShape.data(), inputShape.size());

// 入力に設定
auto ioBinding = Ort::IoBinding(m_session);
ioBinding.BindInput("input_ids", idTensor);
ioBinding.BindInput("attention_mask", maskTensor);

出力データの作成とバインド (32000は、トークンIDのバリエーション数になります。本来なら、出力の logits の3次元目から持ってくるべきでしょう)

auto outputShape = std::array<int64_t, 3> { 1, tokenSize,  32000 };
std::vector<float> logits(tokens.size() * 32000, 0.0f);
auto outputTensor = Ort::Value::CreateTensor<float>(memory_info, logits.data(), logits.size(), outputShape.data(), outputShape.size());

ioBinding.BindOutput("logits", outputTensor);

あとは、モデルを実行します

auto runOptions = Ort::RunOptions();
m_session.Run(runOptions, ioBinding);

そうすると、std::vector<float> logits; に結果が入っているため、ここを使って必要な計算 (softmax など) をすればよいかと思います。

大きいモデルを実行してみる

さて、rinna/japanese-gpt-neox-3.6b を使ってみましょう。
3.6b は、36億 (3,600,000,000) パラメータということでしょうね。onnx にコンバートすると、14.4GB のデータが出てきます。3.6 * sizeof(float) で、14.4 なのでちょうどあってます。

また、(私のローカルの環境では、) モデルの読み込みに 14GB 以上のメモリを確保するし、モデルの読み込みには10秒程度かかるのですが、一回の計算は 500ms 程度で終わりました。

恐れ多くも AT*K 様と比較してみる

実は、"mozcの変換にRinna GPTモデルを使ってみる" のサンプルで使っている「部屋で犬を/飼う」「店で犬を/買う」等は、AT*K 様は難なく変換してくださいます。
想像するに、(時代的に)、おそらく確定単語(自立語)をワードバッグとして、変換候補との共起確率を計算しているのではないかと思います。
したがって、若干意地悪に、「助詞のみで時制を表しニュアンスを変える」「係り受けが複雑なケース」などでは対応しないようです。

助詞のみで時制を制御

確定部 入力 AT*K 結果
昨日犬を かった 飼った
昨日から犬を かった 飼った

係り受けを複雑にしたケース

確定部 入力 AT*K が間違う結果
店で買った犬を かっている 買っている
店で、友達が飼っている同じ犬を かった 飼った

ところが rinna/japanese-gpt-neox-3.6b は、素晴らしく確率を計算してくれます。(まだmozcに組み込んでないので、候補の全文確率比較ですが)

評価文 確率 sum(log(softmax)) Best
昨日犬を買った -24.128618
昨日犬を飼った -24.528269
昨日犬をかった -27.541098
昨日犬を勝った -33.042194
評価文 確率 sum(log(softmax)) Best
昨日から犬を買った -29.213923
昨日から犬を飼った -26.120699
昨日から犬をかった -31.969265
昨日から犬を勝った -37.458824
評価文 確率 sum(log(softmax)) Best
店で買った犬を買っている -32.549412
店で買った犬を飼っている -28.445147
店で買った犬をかっている -34.370369
店で買った犬を勝っている -38.278229
評価文 確率 sum(log(softmax)) Best
店で、友達が飼っている同じ犬を買った -48.479939
店で、友達が飼っている同じ犬を飼った -50.967255
店で、友達が飼っている同じ犬をかった -54.449730
店で、友達が飼っている同じ犬を勝った -56.815639

まとめ

  • 使用 API を ONNX Runtime に変更した (Windows.AI.MachineLearning で外部データ読み込み問題を回避できなかったので)
  • rinna/japanese-gpt-neox-3.6b を使ったときのメモリ使用量と計算時間を確認してみた
  • 仮名漢字変換のヘルパとして rinna/japanese-gpt-neox-3.6b を使ったときの精度をサンプルで試してみた。
0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0