LoginSignup
1
0

Google の Magika を Python から C#に移植するまでの過程 (Day 2 / 7)

Last updated at Posted at 2024-03-04

Google のファイル判定プログラム Magika を Python から C# に移植する過程を共有する記事の第2回目です。

前回は Python で書かれた Magika のコードを見てどのように動作しているのか概要を理解することから始めました。今回はその内容をもとに実際に C# のコードを書いてみます。

目次

C# で 概念実証コードを書いてみる

自分は Python も C# 全然詳しくないですし、ONNX 形式の機械学習モデルを読み込んで推論させるのもやったことがないので、そもそも本当に C# で同じことができるの?ということから不透明な状態でした。ですので、いきなり移植作業を始めるのではなく、まずは概念実証コードを書いてみることにします。

最初のとっかかりの部分なので初学者の方でも追っかけられるよう、詳細に書いていきます。

開発環境を整える

まずは C# の開発環境を整えます。ちなみにすべての作業は Windows 11 で行っています。

  • .NET 8.0 SDK
    こちらから NET 8.0 SDK をダウンロードしてインストールします。

  • Visual Studio Code
    がっつり開発するなら Visual Studio 2022 のほうが良いのかもしれませんが、慣れていないので Visual Studio Code を使っていきます。ダウンロードはこちらから。

  • Visual Studio Code の C# 拡張機能
    Visual Studio Code には C# の開発をサポートする拡張機能があります。これをインストールしておいたほうが便利でした。

  • Visual Studio 2015-2022 Visual C++ 再頒布可能パッケージ
    ONNX Runtime を動かすのに必要です。こちらからダウンロードしてインストールします。

新規プロジェクトを作成する

最終的には C# のクラスライブラリとして作成していきたいのですが、まずは簡単なコンソールアプリで作ってみます。

コンソールアプリの新規プロジェクトを作成します。コマンドプロンプトを開いて

dotnet new console -o SampleConsoleApp
code .\SampleConsoleApp

これで新規のプロジェクトが作成され、Visual Studio Code で開かれます。Program.cs というファイルが自動作成されているので、これの内容を書き換えていくことになります。

day1-img-01.png

依存ライブラリの追加

プログラムコードを書いていく前に、プログラムから ONNX Runtime を使えるよう必要なライブラリを追加していきます。
コマンドプロンプトでプロジェクトのディレクトリに移動して、以下のコマンドを実行するだけです。

dotnet add package Microsoft.ML.OnnxRuntime

ONNXモデルをコピーしてくる

オリジナルの Magika レポジトリ内の /python/magika/models/standard_v1/model.onnx ファイルをダウンロードして、Program.cs と同じディレクトリにコピーしておきます。

プログラムコードを書く

ようやく準備が整いました。オリジナルのmagia.pyのコードを見ながらProgram.cs の内容を書き換えていきます。

まず、適当なファイルを読み込んで先頭 512 バイト、中間部分の 512 バイト、末尾の 512 バイトを抽出して連結した 1536 バイトの配列を用意する部分を書いてみます。本来は関数を分けたり、変数を外部から読み込むなどの処理を行うところですが、今回は簡単にすべての処理をMain関数内に書いて、変数もすべて固定値でベタ書きしていきます。

using Microsoft.ML.OnnxRuntime;
using Microsoft.ML.OnnxRuntime.Tensors;

namespace SampleConsoleApp;

class Program
{
    static void Main(string[] args)
    {
      // 定数
        const int beg_size = 512, mid_size = 512, end_size = 512;
        const int padding_token = 256;
        int[] beg_ints, mid_ints, end_ints;
        ReadOnlySpan<int> pad_ints = Enumerable.Repeat<int>(padding_token, 512).ToArray().AsSpan();

        // ファイル読み込み
        byte[] input;
        var fileInfo = new FileInfo(@"C:\Users\test\Desktop\sample.zip");  //サンプルファイルは用意して保存しておく
        // とりあえずファイル全部読み込んじゃう。本番では部分的に読み込むなど工夫する
        input = File.ReadAllBytes(fileInfo.FullName);
        // byte配列をint配列に変換
        ReadOnlySpan<int> inputInt = input.Select(x => (int)x).ToArray().AsSpan();

        // 先頭の512バイトを取得
        if (beg_size < input.Length)
        {
            beg_ints = inputInt[..beg_size].ToArray();
        }
        else
        {
            var padding_size = beg_size - input.Length;
            beg_ints = ArrayConcat2(inputInt, pad_ints[..padding_size]);
        }
        // 中間の512バイトを取得
        var mid_idx = input.Length / 2;
        if (mid_size < input.Length)
        {
            var left_idx = mid_idx - mid_size / 2;
            var right_idx = mid_idx + mid_size / 2;
            if (mid_size % 2 != 0) right_idx += 1;
            mid_ints = inputInt[left_idx..right_idx].ToArray();
        }
        else
        {
            var padding_size = mid_size - input.Length;
            var left_padding_size = padding_size / 2;
            var right_padding_size = padding_size / 2;
            if (padding_size % 2 != 0) right_padding_size += 1;
            mid_ints = ArrayConcat3(pad_ints[..left_padding_size], inputInt, pad_ints[..right_padding_size]);
        }
        // 末尾の512バイトを取得
        if (input.Length > end_size)
        {
            end_ints = inputInt[^end_size..].ToArray();
        }
        else
        {
            var padding_size = end_size - input.Length;
            end_ints = ArrayConcat2(pad_ints[..padding_size], inputInt);
        }
        // コンソールに出力
        int[] output = ArrayConcat3<int>(beg_ints, mid_ints, end_ints);
        output.ToList().ForEach(x => Console.Write(x + ", "));
    }

    static Type[] ArrayConcat2<Type>(ReadOnlySpan<Type> s1, ReadOnlySpan<Type> s2)
    {
        var array = new Type[s1.Length + s2.Length];
        s1.CopyTo(array);
        s2.CopyTo(array.AsSpan(s1.Length));
        return array;
    }

    static Type[] ArrayConcat3<Type>(ReadOnlySpan<Type> s1, ReadOnlySpan<Type> s2, ReadOnlySpan<Type> s3)
    {
        var array = new Type[s1.Length + s2.Length + s3.Length];
        s1.CopyTo(array);
        s2.CopyTo(array.AsSpan(s1.Length));
        s3.CopyTo(array.AsSpan(s1.Length + s2.Length));
        return array;
    }
}

出来たらコマンドプロンプトでdotnet runで実行し、想定通りの出力が得られているか確認します。

> dotnet run
80, 75, 3, 4, 20, 0, 0, 0, 8, 0, ...

なんとなくいい感じにできているようです。

続いて、キモとなる ONNX Runtime を使って機械学習モデルを読み込んで推論を行う部分を書いていく...のですが、C# で ONNX Runtime を使う方法を解説しているWebページがあまり見つからずに苦戦しました。最終的には以下のページを参考にさせてもらいました。

Python の場合は単に float 型の NDArray を入力として渡せばよいようですが、C# の場合は float 型の配列を一度Tensor<float>型に変換し、NamedOnnxValue.CreateFromTensor<float>("bytes", input)NamedOnnxValue型にさらに変換したものをList<NamedOnnxValue>に格納して入力パラメータとして渡す必要があります。

もともとのファイルはバイナリであり byte 型の配列でしたので、byte[]int[]float[]Tensor<float>NamedOnnxValueList<NamedOnnxValue>と次々に変換していく必要があります。

正直なぜこんなに面倒なのかよくわからないと思いつつも、サンプルコードなどを見ながら試行錯誤した結果、以下のようなコードで推論を実行して出力を得ることができました。

// 先のコードのMain関数の続きから書いていきます

// int配列をfloat配列に変換
float[] floatInput = ArrayConcat3<int>(beg_ints, mid_ints, end_ints).Select(x => (float)x).ToArray();

// 機械学習モデルを読み込んで推論のためのセッションを作成
var modelPath = @".\model.onnx";
using var session = new InferenceSession(modelPath);

// ONNXモデルに入力するためのNamedOnnxValueを作成
Tensor<float> X = new DenseTensor<float>(floatInput, [1, 1536], false);
//
var inputs = new List<NamedOnnxValue>() {
    NamedOnnxValue.CreateFromTensor<float>("bytes", X)
};

// 推論を実行
var inferenceResults = session.Run(inputs);

// 結果をコンソールに出力
inferenceResults[0].AsEnumerable<float>().ToList().ForEach(x => Console.Write($"{x:N8}" + "\n"));

先ほどと同様にdotnet runで実行し、出力を確認します。113 行に渡ってずらーっと浮動小数点の数値が出力されているのが見えると思います。

> dotnet run
0.00000000
0.00000020
...
0.00000000
0.99997282
0.00000000

出力された数値はほとんどが 0 に近い値ですが、その中で一番最後から 2 番目の値だけが 0.99997282 と 1 に近い値になっていました。前回説明した通り、これらの数値はそれぞれがファイル形式と 1:1 で対応しており、数値が大きいほどそのファイル形式の確率が高いということになります。

では最後から 2 番目の数値がどのファイル形式に対応しているのか確認しましょう。オリジナルの Magika レポジトリの/python/magika/models/standard_v1/model_config.json内のtarget_labels_spaceというリストがそれです。このリストの中で一番最後から 2 番目の要素を見ると、それが"zip"であることがわかります。

"target_labels_space": [
  "ai",
  "apk",
  ....
  "yaml",
  "zip",
  "zlibstream"
],

使用したサンプルファイルは ZIP ファイルでしたので、どうやら正しく推論ができているようです。

続く

なんとか C# でオリジナルの Magika と同様の推論を行い結果を得ることができました。C# 移植への手ごたえをつかむことができましたので、次回から Magika の Python コードを C# にゴリゴリと移植していく作業を進めていきます。

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0