なぜもう一度?
以前にrinna GPTモデルをonnxに変換し、C++から使ってみるの記事で、onnx モデルを C++ から使ったのですが、このときは Windows.AI.MachineLearning API を使いました。
ところが、rinna/japanese-gpt-neox-3.6b 等の大きなモデルを optimum-cli で export すると、*.onnx
のほかに、*.onnx_data
という外部データが出力されます。どうも、このファイルの読み込みが Windows.AI.MachineLearning API では失敗し、回避策が見当たりませんでした。
エラーログにはこんなのが出ます。
onnxruntime::Initializer::Initializer status.IsOK() was false. ReadExternalRawData() failed: initializer.cc:23 onnxruntime::Initializer::ReadExternalRawData model_path must not be empty. Ensure that a path is provided when the model is created or loaded.'.
邪推するに、*.onnx_data
を読み込むときに、必要な設定がなされていない感じです。
おそらく、Windows.AI.MachineLearning は、onnx runtime を OS 標準にして wrapper API を準備したものだと思われるので…
本家 onnx runtime 使う
vs だと、導入は簡単で
- nuget から Microsoft.ML.OnnxRuntime をインストール
-
onnxruntime_cxx_api.h
をインクルード - onnxruntime.lib をリンク
だけでした。
モデル input/output を調べましょう
onnx モデルとしては、同じなので念のためですが、モデルの input/output を確認しました。
// モデルの読み込みとセッションの作成
Ort::Env env;
Ort::SessionOptions sessionOptions;
m_session = Ort::Session(env, modelFileName.c_str(), sessionOptions);
// 入力データの確認
for (size_t i = 0; i < m_session.GetInputCount(); ++i) {
auto inputName = m_session.GetInputNameAllocated(i, alloc);
auto inputType = m_session.GetInputTypeInfo(i);
auto shapeInfo = inputType.GetTensorTypeAndShapeInfo();
auto shape = shapeInfo.GetShape();
auto elementType = shapeInfo.GetElementType();
}
// 出力データの確認
for (size_t i = 0; i < m_session.GetOutputCount(); ++i) {
auto outputName = m_session.GetOutputNameAllocated(i, alloc);
auto outputType = m_session.GetOutputTypeInfo(i);
auto shapeInfo = outputType.GetTensorTypeAndShapeInfo();
auto shape = shapeInfo.GetShape();
auto elementType = shapeInfo.GetElementType();
}
結果は以下となります。 -1 は入力時の可変サイズということでしょう。
入力
name:input_ids: shape:[-1,-1], type:INT64
name:attention_mask: shape:[-1,-1], type:INT64
出力 (他にもありますが、以下だけでいいかと思われます)
name:logits: shape:[-1,-1,32000], type:FLOAT
onnx 使ってみる
ということで、モデルを実行してみましょう。
あまりきれいではないサンプルコードはここにあります。
モデルの読み込み
Ort::Env env;
Ort::SessionOptions sessionOptions;
m_session = Ort::Session(env, modelFileName.c_str(), sessionOptions);
入力として std::vector tokens; でトークン列が来るとして、入力データの作成とバインド
// attention mask ソースデータの作成
const auto tokenSize = static_cast<int64_t>(tokens.size());
std::vector<int64_t> attentionMask(tokens.size(), 1LL);
// Ort::Value Tensor の作成
auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
auto inputShape = std::array<int64_t, 2> { 1, tokenSize };
auto idTensor = Ort::Value::CreateTensor<int64_t>(memory_info, tokens.data(), tokens.size(), inputShape.data(), inputShape.size());
auto maskTensor = Ort::Value::CreateTensor<int64_t>(memory_info, attentionMask.data(), attentionMask.size(), inputShape.data(), inputShape.size());
// 入力に設定
auto ioBinding = Ort::IoBinding(m_session);
ioBinding.BindInput("input_ids", idTensor);
ioBinding.BindInput("attention_mask", maskTensor);
出力データの作成とバインド (32000は、トークンIDのバリエーション数になります。本来なら、出力の logits の3次元目から持ってくるべきでしょう)
auto outputShape = std::array<int64_t, 3> { 1, tokenSize, 32000 };
std::vector<float> logits(tokens.size() * 32000, 0.0f);
auto outputTensor = Ort::Value::CreateTensor<float>(memory_info, logits.data(), logits.size(), outputShape.data(), outputShape.size());
ioBinding.BindOutput("logits", outputTensor);
あとは、モデルを実行します
auto runOptions = Ort::RunOptions();
m_session.Run(runOptions, ioBinding);
そうすると、std::vector<float> logits;
に結果が入っているため、ここを使って必要な計算 (softmax など) をすればよいかと思います。
大きいモデルを実行してみる
さて、rinna/japanese-gpt-neox-3.6b を使ってみましょう。
3.6b は、36億 (3,600,000,000) パラメータということでしょうね。onnx にコンバートすると、14.4GB のデータが出てきます。3.6 * sizeof(float) で、14.4 なのでちょうどあってます。
また、(私のローカルの環境では、) モデルの読み込みに 14GB 以上のメモリを確保するし、モデルの読み込みには10秒程度かかるのですが、一回の計算は 500ms 程度で終わりました。
恐れ多くも AT*K 様と比較してみる
実は、"mozcの変換にRinna GPTモデルを使ってみる" のサンプルで使っている「部屋で犬を/飼う」「店で犬を/買う」等は、AT*K 様は難なく変換してくださいます。
想像するに、(時代的に)、おそらく確定単語(自立語)をワードバッグとして、変換候補との共起確率を計算しているのではないかと思います。
したがって、若干意地悪に、「助詞のみで時制を表しニュアンスを変える」「係り受けが複雑なケース」などでは対応しないようです。
助詞のみで時制を制御
確定部 | 入力 | AT*K 結果 |
---|---|---|
昨日犬を | かった | 飼った |
昨日から犬を | かった | 飼った |
係り受けを複雑にしたケース
確定部 | 入力 | AT*K が間違う結果 |
---|---|---|
店で買った犬を | かっている | 買っている |
店で、友達が飼っている同じ犬を | かった | 飼った |
ところが rinna/japanese-gpt-neox-3.6b は、素晴らしく確率を計算してくれます。(まだmozcに組み込んでないので、候補の全文確率比較ですが)
評価文 | 確率 sum(log(softmax)) | Best |
---|---|---|
昨日犬を買った | -24.128618 | ○ |
昨日犬を飼った | -24.528269 | |
昨日犬をかった | -27.541098 | |
昨日犬を勝った | -33.042194 |
評価文 | 確率 sum(log(softmax)) | Best |
---|---|---|
昨日から犬を買った | -29.213923 | |
昨日から犬を飼った | -26.120699 | ○ |
昨日から犬をかった | -31.969265 | |
昨日から犬を勝った | -37.458824 |
評価文 | 確率 sum(log(softmax)) | Best |
---|---|---|
店で買った犬を買っている | -32.549412 | |
店で買った犬を飼っている | -28.445147 | ○ |
店で買った犬をかっている | -34.370369 | |
店で買った犬を勝っている | -38.278229 |
評価文 | 確率 sum(log(softmax)) | Best |
---|---|---|
店で、友達が飼っている同じ犬を買った | -48.479939 | ○ |
店で、友達が飼っている同じ犬を飼った | -50.967255 | |
店で、友達が飼っている同じ犬をかった | -54.449730 | |
店で、友達が飼っている同じ犬を勝った | -56.815639 |
まとめ
- 使用 API を ONNX Runtime に変更した (Windows.AI.MachineLearning で外部データ読み込み問題を回避できなかったので)
- rinna/japanese-gpt-neox-3.6b を使ったときのメモリ使用量と計算時間を確認してみた
- 仮名漢字変換のヘルパとして rinna/japanese-gpt-neox-3.6b を使ったときの精度をサンプルで試してみた。