C#
VisualStudio
機械学習
Accord.NET

C#で機械学習をやってみた(Accord.NET)

Qiita初投稿。たびびとの服とこんぼうとおなべのふたで戦うプログラマです。

機械学習が大流行しているけど、なんだか敷居が高そう。
調べてたら、Accord.NETというライブラリでぽちぽち遊べそうなので、さっくりためしてみた。
さっくりなので荒い。
でも、Accord.NETの記事って意外と少なくて苦労したし、以前の記事の内容だと今では非推奨になってたりして、とりあえず動いたので公開してみようかなと思った。
興味もって使ってみる人、出てきてくれるといいな。

つくったのは、お絵かきした動物(ねこ・いぬ・とり)を判断して、それぞれの鳴き声を出すデスクトップアプリケーション。
ついでにWPFの勉強もしてみた。

ひらがなのところは適宜読み替えてください。

わけわからないところが多く、ぐぐりまくって解決しました。

まずは環境整備から

とりあえずVisual Studio2017 Communityをつかった。
ツールバーのプロジェクト(P)>NuGetパッケージの管理を選択。
左カラムのオンラインを選択したのち、accordを入力。

Accord.MachineLearning
Accord.Imaging
Accord.Statistics
Accord.IO
Accord.Vision
などもろもろをインストール。
(てきとうにそれらしいものを手当たり次第インストールしたので、何をインストールしたのか具体的に忘れてしまった)

とりあえずAccord.VisionをインストールしないとBagOfVisualWordsでエラーが出るので注意。
ここでけっこう詰まった。

あと、てきとうに3つディレクトリを作成して、自分の描いたいぬ・ねこ・とりのイラストをそれぞれに入れておく。

これで準備OK。

とりあえず見た目をなんとかする

新規プロジェクトでWPFを選択する。今回はWindowを透過して、背景を透過したイメージを用いることで、丸っこい見た目のアプリを作成してみた。

必要なのは、
お絵かきするためのInkCanvas。
(ツールボックスに入ってなかったから、調べて新しくいれた)
絵がかけたときの決定ボタン。
(このボタンを押したら、絵を判定して鳴き声が鳴る)
絵を消すときのボタン。
(必ずしも必要ではないが、書き直す必要があるときに使う)
アプリケーション終了ボタン。

<Window x:Class="おえかきあぷり.MainWindow"
        xmlns="http://schemas.microsoft.com/winfx/2006/xaml/presentation"
        xmlns:x="http://schemas.microsoft.com/winfx/2006/xaml"
        xmlns:d="http://schemas.microsoft.com/expression/blend/2008"
        xmlns:mc="http://schemas.openxmlformats.org/markup-compatibility/2006"
        xmlns:local="clr-namespace:おえかきあぷり"
        mc:Ignorable="d"
        WindowStyle="None"
        AllowsTransparency="True"
        Background="Transparent"
        ResizeMode="CanResizeWithGrip"
        Title="MainWindow" Height="500" Width="500">
    <Viewbox>
        <Grid>
            <Image Height="500" Width="500" Source="image/ばっくぐらうんどいめーじ"/>
            <InkCanvas x:Name="inkCanvas" HorizontalAlignment="Left" Height="108" Margin="138,216,0,0" VerticalAlignment="Top" Width="229"/>
            <Button Content="おわろうかな。" HorizontalAlignment="Left" Margin="234,146,0,0" VerticalAlignment="Top" Width="34" Height="32" Background="#00E2EFD6" Click="Button_Click"/>
            <Button Content="かけたよ!" HorizontalAlignment="Left" Height="83" Margin="218,384,0,0" VerticalAlignment="Top" Width="106" Background="#00DDDDDD" Click="Button_Click_1"/>
            <TextBox x:Name="textBox1" HorizontalAlignment="Left" Height="23" Margin="204,356,0,0" TextWrapping="Wrap" Text="TextBox" VerticalAlignment="Top" Width="120"/>
            <Button Content="もういちどかくよ!" HorizontalAlignment="Left" Height="44" Margin="174,98,0,0" VerticalAlignment="Top" Width="40" Background="#00DDDDDD" Click="Button_Click_2"/>

        </Grid>
    </Viewbox>
</Window>

めいん

・・・どう書いて良いのかわからないので、頭からおっていく。

using System;
using System.Collections.Generic;
using System.Drawing;
using Accord.MachineLearning;
using Accord.MachineLearning.VectorMachines;
using Accord.MachineLearning.VectorMachines.Learning;
using System.Windows;
using System.Windows.Media;:
using System.Windows.Media.Imaging;
using Accord.Imaging;
using System.IO;
using System.Windows.Controls;
using Accord.Statistics.Kernels;
using Accord.IO;

このあたりはもろもろ、エラーが出たらおすすめされるままに追加しておけばOK。

namespace おえかきあぷり
{
    /// <summary>
    /// MainWindow.xaml の相互作用ロジック
    /// </summary>
    public partial class MainWindow : Window
    {
        private MediaPlayer mediaPlayer = new MediaPlayer();
        private List<ImageItem> itemList = new List<ImageItem>();
        private string directoryPath = @"・・・てきとうに作成したディレクトリのパス";
        private int codeWordCount = 200;
        private BagOfVisualWords bagofVW;
        private int classes = 3;
        private MulticlassSupportVectorMachine<ChiSquare> msvm;

        /// <summary>
        /// 学習に利用する画像データ
        /// </summary>
        internal class ImageItem
        {
            public String FileName;
            public Bitmap bmp;
            public int Classification;
            public double[] codeWord;

            public ImageItem()
            {
            }
        }

今回はねこ・いぬ・とりの3つのクラスにわける。

        public MainWindow()
        {
            InitializeComponent();

            // ウィンドウをマウスのドラッグで移動できるようにする
            this.MouseLeftButtonDown += (sender, e) => { this.DragMove(); };

            // 既存の画像データ集を用いて学習を開始する
            StartLearning();
        }
        // 学習を開始する
        private void StartLearning()
        {
            var catsCreator = new ItemsFactory("cats", directoryPath, itemList, 0);
            var birdsCreator = new ItemsFactory("birds", directoryPath, itemList, 1);
            var dogsCreator = new ItemsFactory("dogs", directoryPath, itemList, 2);
            bagofVW = new TrainFactory(codeWordCount, bagofVW, itemList).bagofVW;

            var inputs = new InputFactory(itemList).input;
            var outputs = new OutputFactory(itemList).output;

            msvm = new MulticlassSupportVectorMachine<ChiSquare>(0, new ChiSquare(), classes);

            // 学習アルゴリズムを作成する
            var teacher = new MulticlassSupportVectorLearning<ChiSquare>()
            {
                Learner = (param) => new SequentialMinimalOptimization<ChiSquare>()
                {
                    UseComplexityHeuristic = true,
                    UseKernelEstimation = true
                }
            };

            msvm = teacher.Learn(inputs, outputs);

            var calibration = new MulticlassSupportVectorLearning<ChiSquare>()
            {
                Model = msvm,
                Learner = (param) => new ProbabilisticOutputCalibration<ChiSquare>()
                {
                    Model = param.Model
                }
            };

            calibration.ParallelOptions.MaxDegreeOfParallelism = 1;
            calibration.Learn(inputs, outputs);

            inkCanvas.Strokes.Clear();
            textBox1.Clear();
        }

ディレクトリに入れておいた画像をImageItemクラスに格納して、リストにいれる。

        // 各動物の画像を学習用データに加える
        internal class ItemsFactory
        {
            private string directoryPath;
            private List<ImageItem> itemList;

            public ItemsFactory(string animalName, string directoryPath, List<ImageItem> itemList, int classNum)
            {
                this.directoryPath = directoryPath;
                this.itemList = itemList;

                List<String> list = new List<String>();

                System.IO.DirectoryInfo di = new System.IO.DirectoryInfo(Path.Combine(directoryPath, animalName));
                IEnumerable<System.IO.FileInfo> files =
                    di.EnumerateFiles("*", System.IO.SearchOption.AllDirectories);

                //ファイルを列挙する
                foreach (System.IO.FileInfo f in files)
                {
                    list.Add(f.FullName);
                }

                foreach (String fileName in list)
                {
                    ImageItem ii;
                    ii = new ImageItem();
                    ii.FileName = fileName;
                    FileStream fs;
                    fs = new FileStream(fileName, FileMode.Open, FileAccess.Read);
                    Bitmap source= (Bitmap)System.Drawing.Bitmap.FromStream(fs);
                    fs.Close();
                    ii.bmp = new Bitmap(source);
                    ii.Classification = classNum;
                    itemList.Add(ii);
                }
            }
        }

Bag Of Visual Words分類器の作成。
(特徴抽出して集計している・・・らしい?)

        // 訓練用インプットデータを作成する
        internal class TrainFactory
        {
            public BagOfVisualWords bagofVW;

            public TrainFactory(int codeWordCount, BagOfVisualWords bagofVW, List<ImageItem> itemList)
            {
                BinarySplit binarySplit = new BinarySplit(codeWordCount);
                bagofVW = new BagOfVisualWords(binarySplit);
                List<Bitmap> bitmapList = new List<Bitmap>();

                foreach (ImageItem item in itemList)
                {
                    bitmapList.Add(item.bmp);
                }

                Bitmap[] trainImages = bitmapList.ToArray();
                bagofVW.Learn(trainImages);

                foreach (ImageItem item in itemList)
                {
                    item.codeWord = bagofVW.Transform(item.bmp);
                }

                this.bagofVW = bagofVW;
            }
        }

インプットデータとアウトプットデータの作成。
この場合だと、それぞれの絵のデータと、その絵のクラス(ねこ・いぬ・とり)のデータが1対1に対応している配列をそれぞれ作成する。

        // 学習用インプットデータを作成する
        internal class InputFactory
        {
            public double[][] input;

            public InputFactory(List<ImageItem> list)
            {
                var inputList = new List<double[]>();

                foreach (ImageItem item in list)
                {
                    inputList.Add(item.codeWord);
                }

                input = inputList.ToArray();
            }
        }

        // 学習用アウトプットデータを作成する
        internal class OutputFactory
        {
            public int[] output;

            public OutputFactory(List<ImageItem> list)
            {
                var outputList = new List<int>();

                foreach (ImageItem item in list)
                {
                    outputList.Add(item.Classification);
                }

                output = outputList.ToArray();
            }
        }

        // アプリを終了する
        private void Button_Click(object sender, RoutedEventArgs e)
        {
            this.Close();
        }

ボタンがクリックされたら、絵がねこなのかいぬなのかとりなのか判定して、鳴き声を鳴らす。

        //判定するイラストの読み込みと鳴き声
        private void Button_Click_1(object sender, RoutedEventArgs e)
        {
            Bitmap bitmap = new BitmapFactory(inkCanvas).bitmap;

            double[] codeword = bagofVW.Transform(bitmap);
            int classResult = msvm.Decide(codeword);

            textBox1.Text = "Result:" + Convert.ToString(classResult) + "\r\n";

            String cryStr;

            if (classResult == 0)
            {
                cryStr = @"・・・ねこのなきごえ.mp3";
            }
            else if (classResult == 1)
            {
                cryStr = @"・・・とりのなきごえ.mp3";
            }
            else
            {
                cryStr = @"・・・いぬのなきごえ.mp3";
            }

            Uri cryFile = new Uri(cryStr);
            mediaPlayer.Open(cryFile);
            mediaPlayer.Play();

正しいかどうか簡易にフィードバックをとって再学習。
(ここは新しく学んだところだけ付け加え学習する形に直したい・・・)

            //フィードバック

            // Configure the message box to be displayed
            string messageBoxText = "猫ならYes、犬ならNo、鳥ならCancelを押してください。";
            string caption = "フィードバック";
            MessageBoxButton button = MessageBoxButton.YesNoCancel;
            MessageBoxImage icon = MessageBoxImage.Question;

            // Display message box
            MessageBoxResult result = MessageBox.Show(messageBoxText, caption, button, icon);

            string animalStr = "";

            // Process message box results
            switch (result)
            {
                //猫
                case MessageBoxResult.Yes:
                    animalStr = "cats";
                    break;
                //犬
                case MessageBoxResult.No:
                    animalStr = "dogs";
                    break;
                //鳥
                case MessageBoxResult.Cancel:
                    animalStr = "birds";
                    break;
            }

            string saveDi = Path.Combine(directoryPath, animalStr);
            string imagePath = DateTime.Now.ToString("yyyyMMddhhmmss") + ".jpg";

            string savePath = Path.Combine(saveDi, imagePath);

            SaveImage(savePath);

            //再学習
            StartLearning();
        }

描いた絵はフィードバックをもとにして、それぞれのディレクトリに保存する。

        // InkCanvasを画像として保存する
        private void SaveImage(string file)
        {
            Rect rectBounds = inkCanvas.Strokes.GetBounds();
            DrawingVisual dv = new DrawingVisual();
            DrawingContext dc = dv.RenderOpen();
            dc.PushTransform(new TranslateTransform(-rectBounds.X, -rectBounds.Y));
            dc.DrawRectangle(inkCanvas.Background, null, rectBounds);
            inkCanvas.Strokes.Draw(dc);
            dc.Close();

            RenderTargetBitmap rtb = new RenderTargetBitmap(
                (int)rectBounds.Width, (int)rectBounds.Height,
                96, 96,
                PixelFormats.Default);
            rtb.Render(dv);

            BitmapEncoder enc = new JpegBitmapEncoder();

            if (enc != null)
            {
                enc.Frames.Add(BitmapFrame.Create(rtb));
                System.IO.Stream stream = System.IO.File.Create(file);
                enc.Save(stream);
                stream.Close();
            }
        }


        // お描き画像を消す
        private void Button_Click_2(object sender, RoutedEventArgs e)
        {
            inkCanvas.Strokes.Clear();
        }

↓InkCanvasをBitmapに変換するのが意外と重たかったため、急遽わけた。
SaveImageとかぶってるから、もう少し整理できそうな気はする。

        // InkCanvasをBitMapで返す
        internal class BitmapFactory
        {
            public Bitmap bitmap;
            private InkCanvas inkCanvas;

            public BitmapFactory(InkCanvas inkCanvas)
            {
                this.inkCanvas = inkCanvas;
                double width = inkCanvas.ActualWidth;
                double height = inkCanvas.ActualHeight;
                RenderTargetBitmap bmpCopied = new RenderTargetBitmap((int)Math.Round(width), (int)Math.Round(height), 96, 96, PixelFormats.Default);
                DrawingVisual dv = new DrawingVisual();
                using (DrawingContext dc = dv.RenderOpen())
                {
                    VisualBrush vb = new VisualBrush(inkCanvas);
                    dc.DrawRectangle(vb, null, new Rect(new System.Windows.Point(), new System.Windows.Size(width, height)));
                }
                bmpCopied.Render(dv);
                System.Drawing.Bitmap bitmap;
                using (MemoryStream outStream = new MemoryStream())
                {
                    BitmapEncoder enc = new BmpBitmapEncoder();
                    enc.Frames.Add(BitmapFrame.Create(bmpCopied));
                    enc.Save(outStream);
                    bitmap = new System.Drawing.Bitmap(outStream);
                }

                this.bitmap = bitmap;
            }
        }
    }
}

なかなか楽しく遊べました。まる。