4
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

Rustで機械学習(Linfa)

Last updated at Posted at 2022-05-02

はじめに

Rustにおけるsklearn的なものと言えばSmartCoreとLinfaが有名らしいが、2022/5現在SmartCoreは開発が滞っているっぽい(GitHubの最終コミットが半年前)ので、Linfaを触ってみることにした。今回はとりあえずのテスト。

環境

  • Ubuntu 20.04 LTS on WSL2 (Windows 10)
  • Rust:
$ cargo --version
cargo 1.60.0 (d1fd9fe 2022-03-01)
$ rustc --version
rustc 1.60.0 (7737e0b5c 2022-04-04)

クイックスタート

アルゴリズムはナイーブベイズ(二値分類)、データセットはlinfa-datasetsに含まれているUCIのワインデータを使う。

Cargo.toml

cargo new <project_name>でプロジェクトを作成し、生成されたCargo.tomlファイルに下記のように追記する。なお、linfaの"features"においてopenblasを指定しているが、WindowsとmacOSはBLASのバックエンドがIntel MKLのみらしい(Intel MKLをバックエンドで使う場合は、この記事が参考になる? ... 未確認)。

Cargo.toml
[package]
name = "rust-bayes"
version = "0.1.0"
edition = "2021"

 # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
linfa = { version = "0.5.0", features=["openblas-system"] }
linfa-datasets = { version = "0.5.0", features = ["winequality"] }
linfa-bayes = { version = "0.5.0" }

main.rs

公式のexampleからほぼそのまま拝借した。

main.rs
use linfa::metrics::ToConfusionMatrix;
use linfa::traits::{Fit, Predict};
use linfa_bayes::{GaussianNb, Result};


fn main() -> Result<()> {
    // データセットを読み込み、ターゲット(ワインの品質の評価値)を二値に変換
    //     品質は、0から10の間で評価されており、0が最低で10が最高
    let (train, valid) = linfa_datasets::winequality()
        .map_targets(|x| if *x > 6 { "good" } else { "bad" })
        .split_with_ratio(0.9);

    // モデルの訓練
    let model = GaussianNb::params().fit(&train)?;

    // 推論
    let pred = model.predict(&valid);

    // 混同行列の計算
    let cm = pred.confusion_matrix(&valid)?;

    // 混同行列と精度の出力(MCCはマシューズ相関係数)
    println!("{:?}", cm);
    println!("accuracy {}, MCC {}", cm.accuracy(), cm.mcc());
    
    Ok(())
}

実行結果

~/rust/rust-bayes$ cargo run
   Compiling rust-bayes v0.1.0 (/home/ezoalbus/rust/rust-bayes)
    Finished dev [unoptimized + debuginfo] target(s) in 2.10s
     Running `target/debug/rust-bayes`

classes    | good       | bad       
good       | 10         | 7         
bad        | 12         | 130       

accuracy 0.8805031, MCC 0.45080975

途中、"error: linking with `cc` failed: exit status: 1" が出たので、sudo apt install libopenblas-devなどをした。

Titanicデータセットで実行

データはKaggleから落としてくる("train.csv"のみ使用)。
https://www.kaggle.com/competitions/titanic/data

main.rs
use linfa::prelude::*;
use linfa::metrics::ToConfusionMatrix;
use linfa::traits::{Fit, Predict};
use linfa_bayes::GaussianNb;
use linfa_bayes::Result as NBResult;

use polars::prelude::*;
use polars::prelude::Result as PolarResult;


fn read_csv2df(path: &str) -> PolarResult<DataFrame> {
    // csvを読み込んで、DataFrameでreturn
    CsvReader::from_path(path)?
            .has_header(true)
            .finish()
}

fn split_x_y(df: &DataFrame) -> (PolarResult<DataFrame>, PolarResult<DataFrame>) {
    // 特徴量とターゲットに分割
    let target = df.select(vec!["Survived"]);
    // 運賃の情報だけを特徴量として使う
    let features = df.select(vec!["Fare"]);
    (features, target)
}

fn main() -> NBResult<()> {
    //  trainの読み込み
    let train_path = "./data/train.csv";
    let train_df = read_csv2df(&train_path).unwrap();

    // 特徴量とターゲットに分割
    let (train_x, train_y) = split_x_y(&train_df);

    // 前処理
    let train_x = train_x.unwrap().to_ndarray::<Float64Type>().unwrap();
    let train_y = train_y.unwrap().to_ndarray::<Int64Type>().unwrap();
    let train_y = train_y.map(|x| if *x == 1 {"Survived"} else {"Not Survived"});

    // DatasetBaseとしてまとめ、trainとvalidに分割する
    let (train, valid) = DatasetBase::new(
        train_x, 
        train_y
    ).split_with_ratio(0.8);
    
    // 訓練
    let model = GaussianNb::params().fit(&train)?;

    // 推論
    let pred = model.predict(&valid);
    
    // 混同行列の計算
    let cm = pred.confusion_matrix(&valid)?;

    // 混同行列と精度の出力(MCCはマシューズ相関係数)
    println!("{:?}", cm);
    println!("accuracy {}, MCC {}", cm.accuracy(), cm.mcc());
    
    Ok(())
}

実行結果

~/rust/linfa/rust-titanic$ cargo run
   Compiling rust-titanic v0.1.0 (/home/ezoalbus/rust/linfa/rust-titanic)
    Finished dev [unoptimized + debuginfo] target(s) in 12.20s
     Running `target/debug/rust-titanic`

classes    | Not Survived | Survived  
Not Survived | 113        | 2         
Survived   | 48         | 15        

accuracy 0.71910113, MCC 0.35908055

おわり

参考

4
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?