4
6

【TorchSharp】C#で機械学習 ①手書き文字認識アプリの実装

Last updated at Posted at 2024-09-01

0. はじめに

 機械学習に関するプログラムは,基本的にpythonにより実装されます.しかしC#(.Net)で作成済みのプログラムに機械学習の機能を追加したい場合や,GUIも作成したい場合など,C#(.Net)で機械学習の処理を実装しなければならない状況もあると思います.C#用の機械学習ライブラリはいくつかありますが,.NET Foundationに組み込まれているTorchSharpが無難な選択肢だと思います.Pytorchベースなので,Pytorchに精通している方はTorchSharpも使いこなせると思います.しかし,TorchSharpに関する情報が少なかったため,本稿ではTorchSharp + C#(.Net 8.0)による教師あり学習の実装例として,手書き数字のクラス分類アプリケーションをGUI付で作成する例を紹介します. 教師データの学習に加えて,以下のようにGUIにキャンバスを設置して,描かれた数字を認識する機能を持つアプリケーションを作成します.

image.png

ホームページでは同一記事をソースコードと学習用データセットをGitHub経由でダウンロードできます. HP: https://kkaneko-lab.com/?p=285

1. プロジェクトの作成

 今回はWPFでGUIも作成したいので,.Net WPFアプリケーションを選択します.
image.png

2. パッケージの準備

 最新のTorchSharpをNugetでインストールします.GPUを使って学習を行いたいので,同じバージョンのTorchSharp-cuda-windowsも同様にインストールします.また,本プログラムではBitmapを扱うため,System.Drawing.Commonも追加しました.
image.png

3. 機械学習モデルの実装

 以下のように,機械学習モデルのクラス MLModel を実装しました.メンバ変数としてモデルを構成する各層をメンバ変数として記述します.各層の次元は,コンストラクタで初期化するようにしました.全結合層の入力次元を把握するために,ダミーの入力データを用いて,畳み込み層の出力”dammyConvOutput”を計算しています.
 if (torch.cuda.is_available()) _device = CUDA;では,GPUが使用可能かを判定し,使用できる場合はGPUで学習を実行するようにコーディングしました._device = CUDAであれば,this.to(_device);でモデルがGPUに転送されます.
 コンストラクタに加えて,順伝播の関数”forward”,バッチ学習を行う”TrainOnBatch”,学習後に推論を行うための”Predict”の3つの関数も定義しています.それぞれの関数は以下の通りです.

forward・・・親クラスの関数をoverrideしています.単純に各層の出力を順番に計算していくだけです.今回はクラス分類を行うので,出力層の活性化関数はSoftmaxとしております.

TrainOnBatch・・・学習データをバッチに分割して,まずvar predicted = this.forward(input);で順伝播します.その後教師データとの差分をvar error = loss.forward(predicted, output);で計算することで,計算グラフが構築されるので,error.backward();でグラフを辿って逆伝播します.

Predict・・・学習後に呼び出す推論用の関数です.Tensor化した画像がどのクラスに分類されるか予測します.戻り値は予測されるクラスのインデックスと確率です.

MLModel.cs
using TorchSharp;
using TorchSharp.Modules;
using static TorchSharp.torch;
using static TorchSharp.torch.nn;
using static TorchSharp.torch.nn.functional;

namespace TorchSharpSupervisedLearning
{
    public class MLModel : Module<Tensor, Tensor>
    {
        #region メンバ変数
        /// <summary>
        /// 畳み込み層1
        /// </summary>
        private Conv2d _conv1;
        /// <summary>
        /// 畳み込み層2
        /// </summary>
        private Conv2d _conv2;
        /// <summary>
        /// 全結合層1
        /// </summary>
        private Linear _linear1;
        /// <summary>
        /// 全結合層2
        /// </summary>
        private Linear _linear2;
        /// <summary>
        /// 隠れ層のサイズ
        /// </summary>
        private int _hiddenLayerSize = 32;

        /// <summary>
        /// デバイス(CPU or GPU)
        /// </summary>
        private Device _device = CPU;
        #endregion

        /// <summary>
        /// コンストラクタ
        /// </summary>
        /// <param name="inputSize">入力する画像のサイズ</param>
        /// <param name="outputSize">出力するベクトルのサイズ</param>
        public MLModel(int[] inputSize, int outputSize) : base("CNN")
        {
            //ダミーの入力データを作成(各層の次元の初期化に使用)
            Tensor dammyInput = zeros([inputSize[0], inputSize[1]]).unsqueeze(0).unsqueeze(0);
            //畳み込み層の初期化
            _conv1 = Conv2d(in_channels: 1, out_channels: 16, kernelSize: 8, stride: 2);
            _conv2 = Conv2d(in_channels: 16, out_channels: 16, kernelSize: 8, stride: 2);
            //ダミーの畳み込み層の出力
            Tensor dammyConvOutput = _conv1.forward(dammyInput); //畳み込み層1
            dammyConvOutput = _conv2.forward(dammyConvOutput); //畳み込み層2
            dammyConvOutput = flatten(dammyConvOutput, start_dim: 1); //平滑化
            //全結合層の初期化
            _linear1 = Linear(inputSize: dammyConvOutput.shape[1], outputSize: _hiddenLayerSize);
            _linear2 = Linear(inputSize: _hiddenLayerSize, outputSize: outputSize);

            //コンポーネントの登録
            RegisterComponents();


            //GPUを使用できるか
            if (torch.cuda.is_available()) _device = CUDA; //GPUを活用
            //デバイスに転送
            this.to(_device); 
        }

        /// <summary>
        /// 順伝播処理のオーバーライド
        /// </summary>
        /// <param name="input">入力データ</param>
        /// <returns></returns>
        public override Tensor forward(Tensor input)
        {
            //畳み込み
            var x = relu(_conv1.forward(input)); //活性化関数はReLUを使用
            x = relu(_conv2.forward(x));
            //平坦化
            x = torch.flatten(x, start_dim: 1);
            //全結合層
            x = relu(_linear1.forward(x));
            x = softmax(_linear2.forward(x), dim:1); //クラス分類のためSoftmax関数
            return x;
        }

        /// <summary>
        /// バッチ学習
        /// </summary>
        /// <param name="dataset">教師データのリスト</param>
        /// <param name="epochCount">エポック数</param>
        /// <param name="batchSize">バッチサイズ</param>
        public void TrainOnBatch(List<(Tensor input, Tensor output)> dataset, int epochCount, int batchSize)
        {
            //オプティマザの初期化
            var optimizer = optim.Adam(parameters: this.parameters(), lr: 0.001); //学習率を0.001に設定
            //損失関数
            var loss = CrossEntropyLoss();

            //エポック数だけ繰り返し
            for (int epoch = 0; epoch < epochCount; epoch++)
            {
                //バッチ取り出し
                var batcheArray = Utility.GetBatch(dataset, batchSize);

                //バッチの繰り返し
                for (int batch = 0; batch < batcheArray.Length; batch++)
                {
                    //入力データ
                    var input = batcheArray[batch].input.to(_device);
                    //出力データ
                    var output = batcheArray[batch].output.to(_device);
                    
                    //オプティマイザの勾配を初期化
                    optimizer.zero_grad();
                    //推論
                    var predicted = this.forward(input);
                    //残差
                    var error = loss.forward(predicted, output);
                    //逆伝播
                    error.backward();
                    optimizer.step();

                    Console.WriteLine(error.ToSingle());
                }

                //メモリ解放
                GC.Collect();
            }
        }

        /// <summary>
        /// 推論
        /// </summary>
        /// <param name="input">単一の入力データ</param>
        /// <returns></returns>
        public (int index, float probability) Predict(Tensor input) 
        {
            //順伝播
            Tensor output =  this.forward(input.unsqueeze(0).to(_device)).squeeze(0);
            //配列に変換
            float[] array = new float[output.shape[0]];
            for(int i = 0; i < array.Length; i++) array[i] = output[i].ToSingle();
            //最大値をとるインデックスを取得
            int maxIndex = Array.IndexOf(array, array.Max());
            return (maxIndex, array[maxIndex]);
        }
    }
}

 TorchSharpのDataLoaderを使えばバッチ学習をよりスマートに実装できそうですが,今回はUtility.csにバッチ生成用のコードを独自に実装しました.そのほかにもUtility.csには,Bitmap→Tensorの変換なども記載しております.

Utility.cs
using System.Drawing;
using System.IO;
using System.Windows.Controls;
using System.Windows.Media;
using System.Windows.Media.Imaging;
using TorchSharp;
using static TorchSharp.torch;

namespace TorchSharpSupervisedLearning
{
    public static class Utility
    {
        /// <summary>
        /// 学習データの読み込み
        /// </summary>
        /// <param name="folderPath">フォルダパス</param>
        /// <returns></returns>
        public static (List<(Tensor input, Tensor output)> dataset, string[] labels) LoadDataset(string folderPath, int[] imageSize)
        {
            //サブフォルダパスを取得
            var subFolders = System.IO.Directory.GetDirectories(folderPath, "*", System.IO.SearchOption.TopDirectoryOnly);
            //サブフォルダ名をラベルとして認識
            string[] labels = subFolders.Select(x => Path.GetFileName(x)).ToArray();
            //戻り値となる学習データセット
            var dataset = new List<(Tensor input, Tensor output)>();

            //ラベルの繰り返し
            for (int labelIndex = 0; labelIndex < labels.Length; labelIndex++)
            {
                //このラベル(フォルダ)内の画像のパスを取得
                string[] files = Directory.GetFiles(subFolders[labelIndex]);

                //パスの繰り返し
                foreach (var path in files)
                {
                    //画像を読み込み,リサイズする
                    Bitmap bmp = ResizeBitmap(new Bitmap(path), imageSize[0], imageSize[1]);
                    //画像をTensorに変換
                    Tensor imageTensor = BitmapToTensor(bmp);

                    //ラベルを配列化
                    float[] labelArray = new float[labels.Length];
                    labelArray[labelIndex] = 1;
                    //Tensorに変換
                    Tensor labelTensor = tensor(rawArray: labelArray, dimensions: new long[] { labelArray.Length }, dtype: float32);

                    //リストに追加
                    dataset.Add((imageTensor, labelTensor));
                }
            }


            return (dataset, labels);
        }

        /// <summary>
        /// Bitmapのリサイズ
        /// </summary>
        /// <param name="originalBitmap"></param>
        /// <param name="width"></param>
        /// <param name="height"></param>
        /// <returns></returns>
        static Bitmap ResizeBitmap(Bitmap originalBitmap, int width, int height)
        {
            // 指定されたサイズで新しいBitmapを作成
            Bitmap resizedBitmap = new Bitmap(width, height);
            // Graphicsオブジェクトを使ってリサイズ
            using (Graphics graphics = Graphics.FromImage(resizedBitmap))
            {
                //背景の初期化
                graphics.Clear(System.Drawing.Color.White);
                //補間モードを指定
                graphics.InterpolationMode = System.Drawing.Drawing2D.InterpolationMode.High;
                //新しいサイズで描画
                graphics.DrawImage(originalBitmap, 0, 0, width, height);
                return resizedBitmap;
            }

        }

        /// <summary>
        /// BitmapをTensorに変換
        /// </summary>
        /// <param name="bitmap">Bitmap</param>
        /// <returns></returns>
        public static Tensor BitmapToTensor(Bitmap bitmap)
        {
            //データの次元
            int width = bitmap.Width;
            int height = bitmap.Height;
            int channels = 1; // 白黒のため
            // ピクセルデータを格納する配列を作成
            float[] imageData = new float[width * height * channels];

            // Bitmapからピクセルデータを取得
            for (int y = 0; y < height; y++)
            {
                for (int x = 0; x < width; x++)
                {
                    System.Drawing.Color pixel = bitmap.GetPixel(x, y);
                    int index = (y * width + x) * channels;
                    imageData[index] = (pixel.R + pixel.G + pixel.B) / 3 / 255.0f;
                }
            }

            // Tensorを作成 (チャンネル、高さ、幅の順序で配置)
            var tensor = torch.tensor(rawArray: imageData, dimensions: new long[] { channels, height, width }, dtype: float32);
            return tensor;
        }

        /// <summary>
        /// 訓練データをバッチごとに分割して返す
        /// </summary>
        /// <param name="dataset">訓練データセット</param>
        /// <param name="batchSize">バッチサイズ</param>
        /// <returns></returns>
        public static (Tensor input, Tensor output)[] GetBatch(List<(Tensor input, Tensor output)> dataset, int batchSize)
        {
            //バッチに格納するデータの順番
            int[] order = Enumerable.Range(0, dataset.Count).ToArray();
            //バッチ数
            int batchCount = (int)Math.Ceiling((decimal)dataset.Count / batchSize);
            //バッチの集合(リスト)
            var batchList = new List<(Tensor input, Tensor output)>();

            //バッチの繰り返し
            for (int batchIndex = 0; batchIndex < batchCount; batchIndex++)
            {
                //バッチ
                List<(Tensor, Tensor)> batch = new List<(Tensor, Tensor)>();

                //データの繰り返し
                for (int i = 0; i < batchSize; i++)
                {
                    //データのインデックス
                    int dataIndex = batchIndex * batchSize + i;
                    //インデックスが範囲内のとき,バッチに追加
                    if (dataIndex < dataset.Count) batch.Add(dataset[order[dataIndex]]);
                }

                //バッチ内のTensorを1つに結合
                Tensor tempInput = torch.stack(batch.Select(x => x.Item1));
                Tensor tempOutput = torch.stack(batch.Select(x => x.Item2));
                //バッチリストに格納
                batchList.Add((tempInput, tempOutput));
            }

            //配列に変換して返す
            return batchList.ToArray();
        }

        /// <summary>
        /// InkCanvasをBitmapに変換
        /// </summary>
        /// <param name="canvas">InkCanvas</param>
        /// <param name="bitmapSize">Bitmapのサイズ</param>
        /// <returns></returns>
        public static Bitmap InkCanvasToBitmap(InkCanvas canvas, int[] bitmapSize)
        {
            // InkCanvasのサイズを取得
            int width = (int)canvas.ActualWidth;
            int height = (int)canvas.ActualHeight;
            // RenderTargetBitmapに変換
            RenderTargetBitmap renderBitmap = new RenderTargetBitmap(width, height, 96, 96, PixelFormats.Pbgra32);
            renderBitmap.Render(canvas);

            //Bitmapに変換
            using (var stream = new MemoryStream())
            {
                PngBitmapEncoder encoder = new PngBitmapEncoder();
                encoder.Frames.Add(BitmapFrame.Create(renderBitmap));
                encoder.Save(stream);
                Bitmap bmp = new Bitmap(stream);
                //リサイズして返す
                return ResizeBitmap(bmp, bitmapSize[0], bitmapSize[1]);
            }
        }
    }
}

4. GUIの作成

 手書きの数字について,学習と推論の両方の機能を備えたアプリケーションを作成したいので,以下のようなGUIを作成しました.Trainボタンをクリックすると学習データを読み込み学習を開始します.学習中のLossの推移は,GUIと同時に立ち上がるコンソールに表示するようにしました.GUIの上部の白い正方形領域はInkCanvasで,ここにマウスでドラッグして数字を描き,PredictボタンをクリックするとGUI上に推論結果が表示されます.またResetボタンでInkCanvasを白紙に戻せます.

image.png

 ボタンクリックなど各種イベントは以下のように実装しております.エポック数50,バッチサイズ128としてTrainボタンをクリックすると学習を行います.

MainWindow.xaml.cs
using Microsoft.Win32;
using SkiaSharp;
using System.Drawing;
using System.IO;
using System.Text;
using System.Windows;
using System.Windows.Controls;
using System.Windows.Data;
using System.Windows.Documents;
using System.Windows.Input;
using System.Windows.Media;
using System.Windows.Media.Imaging;
using System.Windows.Navigation;
using System.Windows.Shapes;
using TorchSharp;
using static TorchSharp.torch;

namespace TorchSharpSupervisedLearning
{
    /// <summary>
    /// Interaction logic for MainWindow.xaml
    /// </summary>
    public partial class MainWindow : Window
    {
        public MainWindow()
        {
            InitializeComponent();

            //キャンバスのペンサイズを設定
            cnvDrawingArea.DefaultDrawingAttributes.Width = 15;
            cnvDrawingArea.DefaultDrawingAttributes.Height = 15;

            //推論ボタンを機能停止
            btnPredict.IsEnabled = false;
        }

        /// <summary>
        /// 機械学習モデル
        /// </summary>
        MLModel _model;
        /// <summary>
        /// データのラベル
        /// </summary>
        string[] _labels;
        /// <summary>
        /// 画像のサイズ
        /// </summary>
        int[] _imageSize = [128, 128];

        /// <summary>
        /// Canvasのリセットボタン
        /// </summary>
        /// <param name="sender"></param>
        /// <param name="e"></param>
        private void btnReset_Click(object sender, RoutedEventArgs e)
        {
            cnvDrawingArea.Strokes.Clear(); //ストロークのクリア
        }

        /// <summary>
        /// 推論ボタン
        /// </summary>
        /// <param name="sender"></param>
        /// <param name="e"></param>
        private void btnPredict_Click(object sender, RoutedEventArgs e)
        {
            //CanvasをBitmapに変換
            Bitmap bitmap = Utility.InkCanvasToBitmap(cnvDrawingArea, _imageSize);
            //Tensorに変換して推論
            (int labelIndex, float probability) = _model.Predict(Utility.BitmapToTensor(bitmap));
            //フォームに推論結果と確立を表示
            txtPredicted.Text = "Predicted:  " + _labels[labelIndex] + string.Format(" ({0} %)", probability * 100);
        }

        /// <summary>
        /// 学習ボタン
        /// </summary>
        /// <param name="sender"></param>
        /// <param name="e"></param>
        private void btnTrain_Click(object sender, RoutedEventArgs e)
        {
            //フォルダ選択ダイアログ
            OpenFolderDialog dialog = new OpenFolderDialog();
            dialog.ShowDialog();//表示

            //フォルダが1つだけ選択されたとき
            if(!dialog.Multiselect && dialog.FolderName != null)
            {
                //フォルダ名を取得
                string folderName = dialog.FolderName;
                //訓練用データセットとラベル
                (var dataset, _labels) = Utility.LoadDataset(folderName, _imageSize);
                //モデルの初期化
                _model = new MLModel(inputSize: _imageSize, outputSize: _labels.Length);//ラベル数から出力次元を決定
                //バッチ学習
                _model.TrainOnBatch(dataset: dataset, epochCount: 50, batchSize: 128);
                //推論ボタンを有効にする
                btnPredict.IsEnabled = true; 
            }
        }


     
    }
}

XAMLは以下の通り.

<Window x:Class="TorchSharpSupervisedLearning.MainWindow"
        xmlns="http://schemas.microsoft.com/winfx/2006/xaml/presentation"
        xmlns:x="http://schemas.microsoft.com/winfx/2006/xaml"
        xmlns:d="http://schemas.microsoft.com/expression/blend/2008"
        xmlns:mc="http://schemas.openxmlformats.org/markup-compatibility/2006"
        xmlns:local="clr-namespace:TorchSharpSupervisedLearning"
        mc:Ignorable="d"
        Title="MainWindow" Height="500" Width="309"
        Background="Gray">
    <Grid>
        <StackPanel Orientation="Vertical">
            <Canvas Background="Transparent" Margin="10" Width="257" Height="257">
                <InkCanvas x:Name="cnvDrawingArea" Width="256" Height="256" Margin="0" Background="White"/>
            </Canvas>
            <TextBlock x:Name="txtPredicted" Text="null" Height="32" Width="250" Foreground="Yellow" TextAlignment="Center" FontSize="18" FontStyle="Normal"/>
            <Button x:Name="btnReset" Content="Reset" Height="32" Width="150" Margin="5" Click="btnReset_Click"/>
            <Button x:Name="btnPredict" Content="Predict" Height="32" Width="150" Margin="5" Click="btnPredict_Click"/>
            <Button x:Name="btnTrain" Content="Train" Height="32" Width="150" Margin="5" Click="btnTrain_Click"/>
        </StackPanel>
    </Grid>
</Window>

5. アプリケーションの実行

 以下の動画のように操作することで,①手書き数字の学習,②手書き文字の推論,③Canvasのリセットを行えます.

①Trainボタンをクリックするとフォルダ選択ダイアログが立ち上がるので,学習用の画像データセット(ソースコードに付属のTrainningDataフォルダ)を選択すると自動で学習が開始します.動画では学習中のLossが右のコンソールに出力されており,Lossが減少していることがわかります.
②InkCanvasに数字を描いた状態でPredictボタンをクリックすると,InkCanvasの下に黄色文字で推論結果が表示されます. 動画では,精度よく推論できていることがわかります.
③ResetボタンでInkCanvasをリセットできます.

このようにTorchSharpを用いることで,C#でも十分に機械学習が行えます.

ソースコードはHPを参照:https://kkaneko-lab.com/?p=285

4
6
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
6