7
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

Rustで行列演算を実装した話 ーーQiita

Last updated at Posted at 2020-12-16

この記事は 限界開発鯖アドベントカレンダー 17日目の記事です

読む前に

この記事における行列演算は、高校生で習う2X2行列の逆行列の演算までを指します。
また、今回のコードはここにあります

発端

Rustが書きたい...
行列演算わかりずれぇ...

そうだ、Rustで、書こう

おしながき

  • 基礎を作る
  • 行列作成
  • 単位行列作成
  • 加算
  • 乗算(外積)
  • 正則チェック(2X2のみ)
  • 逆行列(2X2)のみ

基礎を作る

やはり最初はデータを保存する型を作らなくてはならない。
というわけで最初にstructドーン!!!

#[derive(Debug, PartialEq)]
pub struct Matrix<T>
where
    T: MatrixNumber,
{
    matrix: Vec<Vec<T>>,
}

MatrixNumber

急に変なトレイトが出てきた。
これは、行列演算に利用する数値の型を定義するもので、以下の性質を持つ。

  • 四則演算ができる
  • signed
  • コピー可能

今回はこれを満たすためにnum_traitsを利用した。
定義は以下の通りである。

pub trait MatrixNumber
where
    Self: num_traits::Signed + std::marker::Copy,
{
}

というわけで、上のMatrixは任意の行列に利用可能な数値型の二次元配列を表している。

また、最低限のgetterを作成する。

impl<T> Matrix<T>
where
    T: MatrixNumber,
{
    fn get_row(&self) -> usize {
        self.matrix.len()
    }

    fn get_column(&self) -> usize {
        self.matrix[0].len()
    }

    fn get_value(&self, row: usize, column: usize) -> T {
        self.matrix[row - 1][column - 1]
    }
}

このように

  • 行取得
  • 列取得
  • 特定の場所の値取得

のgetterを定義した

行列作成

行列の型があっても作れなきゃ意味がない。
そのためcreateを実装する。

fn create(row: usize, column: usize, row_matrix: Vec<T>) -> Result<Self, MatrixError> {
        let length = row * column;
        if length != row_matrix.len() {
            return Err(MatrixError::InvalidLength {
                expected: length,
                found: row_matrix.len(),
            });
        }

        let mut matrix = vec![];
        for m in 0..row {
            matrix.push(vec![]);
            for n in 0..column {
                matrix[m].push(row_matrix[m * column + n]);
            }
        }

        Ok(Matrix { matrix })
    }

結構ややこしいコードになった。

引数

  • 行列の値が入ったVector

中身の解説

受け取った一次元のVectorを行、列の情報をもとに二次元配列として変換してstructに包んで返す。

作成可能な条件

行列の値が入ったVectorの長さが行*列の値と同じことである。
ちなみに、これを守らなかった場合Errが返ってくる。

単位行列作成

かなり重要なものなので早いうちに作れるようにしたい。
というわけで早めに作った。

fn create_unit_matrix(n: usize) -> Self {
        let mut matrix = vec![];
        for m in 0..n {
            matrix.push(vec![]);
            for l in 0..n {
                matrix[m].push(if m == l { T::one() } else { T::zero() });
            }
        }

        Matrix { matrix }
    }

引数

  • 単位行列の大きさ

中身の解説

列m, 行lの値をi_lmとするときに

  • m = lなら1
  • それ以外なら0

となる行列を作成する。

加算

基本的な行列の作成関数ができたので、演算を実装していく。
最初はaddを実装する。
本当はopts::Addを実装したいのだが、Errが返る以上ただのメソッドになっている。

fn add(&self, y: &Self) -> Result<Self, MatrixError> {
        let row = self.get_row();
        let column = self.get_column();
        if row != y.get_row() || column != y.get_column() {
            return Err(MatrixError::CannotCalculate {
                x_row: row,
                x_column: column,
                y_row: y.get_row(),
                y_column: y.get_column(),
            });
        }

        let mut matrix = vec![];

        for m in 0..row {
            matrix.push(vec![]);
            for n in 0..column {
                matrix[m].push(self.get_value(m + 1, n + 1) + y.get_value(m + 1, n + 1));
            }
        }

        Ok(Matrix { matrix })
    }

引数

  • 自身の参照
  • 加算する値の参照

中身の解説

まず、行列の大きさの比較をする。
もし異なるのであればその時点でErrを返す。

その後、行列の同じ場所の値を加算した行列を新しく作って返している。

演算可能な条件

行列のサイズが同じであること。

乗算(外積)

続いて、乗算を実装する。
これもopts::Mulの実装はErrが返る以上不可能である。

fn cross(&self, y: &Self) -> Result<Self, MatrixError> {
        let row = self.get_row();
        let column = y.get_column();

        if self.get_column() != y.get_row() {
            return Err(MatrixError::CannotCalculate {
                x_row: self.get_row(),
                x_column: self.get_column(),
                y_row: y.get_row(),
                y_column: y.get_column(),
            });
        }

        let z = self.get_column();

        let mut matrix = vec![];

        for m in 1..row + 1 {
            matrix.push(vec![]);
            for n in 1..column + 1 {
                let mut ans = T::zero();
                for l in 1..z + 1 {
                    ans = ans + self.get_value(m, l) * y.get_value(l, n);
                }
                matrix[m - 1].push(ans);
            }
        }

        Ok(Matrix { matrix })
    }

引数

  • 自身の参照
  • 掛け合わせる値の参照

中身の解説

初めに、自身の参照の列と掛け合わせる参照の行が同一であるか確認する。
もし異なればErrを返す。

(mXz)X(zXn)においてi_mnは左の行列のm行と右の行列のn列の値となるため、それを計算し、行列を生成している。

演算可能な条件

自身の参照の列数と掛け合わせる値の行数が一致していること

正則チェック(2X2)のみ

これは次に紹介する逆行列(2X2)を実装するための機能である。

fn check_regular(&self) -> Result<bool, MatrixError> {
        if self.get_row() != 2 || self.get_column() != 2 {
            Err(MatrixError::NonSupportedMatrixShape {
                row: self.get_row(),
                column: self.get_column(),
            })
        } else {
            Ok(self.get_value(1, 1) * self.get_value(2, 2)
                - self.get_value(1, 2) * self.get_value(2, 1)
                != T::zero())
        }
    }

引数

  • 自身の参照

中身の解説

まず2X2行列であるか確認する。
もしそれ以外ならErrを返す。
その後、(1,1) * (2, 2) - (1,2) * (2,1)の値が0であるか確認する。
これが0でなければtrue、0ならばfalseを返す。

演算可能な条件

2X2行列であること。
これ以外の場合は利用している条件が成り立たない。

逆行列(2X2)

本日のメインディッシュ
これが目的でこの行列演算を実装していた。

fn inverse_matrix(&self) -> Result<Self, MatrixError> {
        match self.check_regular() {
            Ok(status) => {
                if !status {
                    Err(MatrixError::ZeroDeterminant {})
                } else {
                    let determinant = self.get_value(1, 1) * self.get_value(2, 2)
                        - self.get_value(1, 2) * self.get_value(2, 1);
                    let new_matrix = vec![
                        vec![
                            self.get_value(2, 2) / determinant,
                            self.get_value(2, 1) / determinant * -T::one(),
                        ],
                        vec![
                            self.get_value(1, 2) / determinant * -T::one(),
                            self.get_value(1, 1) / determinant,
                        ],
                    ];
                    Ok(Self { matrix: new_matrix })
                }
            }
            Err(err) => Err(err),
        }
    }

引数

  • 自身の参照

中身の解説

まず、パターンマッチで正則チェッカーがErrを吐く(2X2以外の行列)の場合を弾いている。
その後、正則でなければErrを返している。
2X2かつ正則な際以下の行列を作成する

\begin{pmatrix}
a & b \\
c & d
\end{pmatrix}

をもとの行列として、

\frac{1}{a*d-b*c}
\left(
\begin{matrix}
d & -b\\
-c & a
\end{matrix}
\right)

を作成し、返している

演算可能な条件

2X2行列であること
正則であること

実装した感想

  • num_traitsが便利すぎる
  • ここまで抽象的にコードを書くことがなかなかないのでいい経験になった
  • Rustのtrait,where,implが強すぎる

あとがき

今までRustを書いたことがない人には本当にいい言語なのでぜひ一度こんな感じの軽いプログラムで触ってほしい。

7
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
7
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?