LoginSignup
5
3

Rust で式テンプレート

Last updated at Posted at 2024-05-14

C++er が Rust を始める前に, Rust でどこまでできるのかを確認すべく,試行錯誤したことを記事にしました.

式テンプレートについては高橋晶氏のブログ記事がわかりやすい内容になっています.

本記事では,スカラーベクトル積を例に, Rust での式テンプレートがどのようなものになるのかを確認しています.

結論の概要は以下になります.

  • Rust で式テンプレートは実現できる.
  • ただし,オペレータオーバーロードを使ったかっこいい書き方を実現しようとするとかなりの手間.

式テンプレートの実装

トレイトの定義

ベクトル型のトレイトを定義します.(C++er 的に解釈すると,trait は基底になりうるかつインスタンス化できないクラスで,struct は final つきのクラスという感じでしょうか)
すべてのベクトル型は,ベクトルの次元を返す dimension と, i 番目の要素を返す get というメソッドを持つものとします.

せっかくなのでベクトルの要素の型をジェネリックパラメータとしたいところですが,今回は簡単のためf64 としています.

trait Vector {
    fn dimension(&self) -> usize;
    fn get(&self, i: usize) -> f64; // 後述の理由で [] 演算子を使わず,戻り値を参照にしていない
}

素朴なベクトル型の実装

1 つの配列をメンバに持つ素朴なベクトル型を実装します.

struct DenseVector {
    array: Vec<f64>,
}

impl Vector for DenseVector {
    fn dimension(&self) -> usize {
        self.array.len()
    }
    fn get(&self, i: usize) -> f64 {
        self.array[i]
    }
}

スカラーベクトル積の実装

簡単な例として,ベクトルのスカラー倍を表す型を実装します.
この型は,スカラー値とベクトルへの参照を保持しておき,get が呼ばれたタイミングでスカラー値と i 番目の要素の積を計算して返します.

struct ScalarVectorProduction<'a, V>
    where V: Vector
{
    scalar: f64,
    vector: &'a V,
}

impl<'a, V> Vector for ScalarVectorProduction<'a, V>
    where V: Vector
{
    fn dimension(&self) -> usize {
        self.vector.dimension()
    }
    fn get(&self, i: usize) -> f64 { // この関数の戻り値は参照にできない
        self.scalar * self.vector.get(i)
    }
}

(ライフタイム注釈をちゃんと理解しないまま使っていますが,コンパイルできて動いているので良しとします. Rust のコンパイラは親切なメッセージを吐いてくれるのでありがたいですね.)

【余談】なぜベクトルの要素の取得に [] 演算子を使わないのか

Rust の [] 演算子は戻り値が参照型に制限されてしまうようで,今回はスカラーベクトル積の戻り値が一時オブジェクトになるため,get というメソッドを用意して戻り値の型を値としています.

掛け算の実装

スカラー値とベクトルから,スカラーベクトル積型を構築して返す関数 product を実装します.
この関数 1 つで,スカラーと Vector トレイトを実装している任意の型との積を計算することができます.

fn product<V>(scalar: f64, vector: &V) -> ScalarVectorProduction<V>
    where V: Vector
{
    ScalarVectorProduction{scalar:scalar, vector:vector}
}

実行例

小さなデータでスカラー・ベクトル積を実現できていることを確認します.

fn print_vector<V: Vector>(vector: &V)
{
    for i in 0..vector.dimension() {
        print!("{}:{}, ", i, vector.get(i));
    }
    print!("\n");
}

fn main() {
    // 密ベクトルを構築
    let v = DenseVector{array:vec![4.0, 2.0]};
    print_vector(&v);
    // 2 倍する
    let w = product(2.0, &v);
    print_vector(&w);
    // 5 倍する
    let x = product(5.0, &w);
    print_vector(&x);
    // 0.1 倍する
    let y = product(0.1, &x);
    print_vector(&y);
}

出力

0:4, 1:2, 
0:8, 1:4, 
0:40, 1:20, 
0:4, 1:2, 

【余談】式テンプレートが保持している参照のライフタイムについて

C++ であれば,式テンプレートクラスがベクトルクラスの参照を保持している場合に,「式テンプレートクラスより先にベクトルクラスが破棄されてしまわないか」と常に気にする必要がありますが, Rust ではそのようなコードをコンパイル時にエラーにしてくれます.

fn main() {
    let u;
    {
        // 密ベクトルを構築
        let v = DenseVector{array:vec![4.0, 2.0]};
        u = product(2.0, &v);
        print_vector(&u);
    }
    // print_vector(&u);  // コメントアウトを外すとコンパイルエラー
}

↑ここまではうまく行った話
↓ここからはうまく行かなかった話

オペレーターオーバーロード

ベクトルの計算なので product などの関数呼び出しではなく,オペレーター(+, -, * など)を使って書けるとかっこいいですね.

Rust では * 演算子のオーバーロードには std::ops::Mul というトレイトを使い,左辺値のメソッドとして実装します.

コンパイルできないコード

任意の Vector を右辺値に取る f64 のメソッドとして, product と同じように以下のように書けそうな気がしますが,このコードはコンパイルできません.

関数でできていたことを実現できない理由はわかりませんが,そういう仕様のようです.

impl<'a, V> std::ops::Mul<&'a V> for f64  // なんでかエラーになる
    where V: Vector
{
    type Output = ScalarVectorProduction<'a, V>;
    fn mul(self, rhs: &'a V) -> Self::Output {
        ScalarVectorProduction{scalar:self, vector:rhs}
    }
}

コンパイルできるコード

右辺値を具体的な型または 具体的な型<パラメータ> として,以下のよう実装することで*演算子をオーバーロードできます.

impl<'a> std::ops::Mul<&'a DenseVector> for f64
{
    type Output = ScalarVectorProduction<'a, DenseVector>;
    fn mul(self, rhs: &'a DenseVector) -> Self::Output {
        ScalarVectorProduction{scalar:self, vector:rhs}
    }
}

impl<'a, V> std::ops::Mul<&'a ScalarVectorProduction<'a, V>> for f64
    where V: Vector
{
    type Output = ScalarVectorProduction<'a, V>;
    fn mul(self, rhs: &ScalarVectorProduction<'a, V>) -> Self::Output {
        ScalarVectorProduction{scalar:self * rhs.scalar, vector:rhs.vector}
    }
}

実行例

fn main() {
    // 密ベクトルを構築
    let v = DenseVector{array:vec![4.0, 2.0]};
    print_vector(&v);
    // 2 倍する
    let w = 2.0 * &v;
    print_vector(&w);
    // 5 倍する
    let x = 5.0 * &w;
    print_vector(&x);
    // 0.1 倍する
    let y = 0.1 * &x;
    print_vector(&y);
}

* 演算子を使ってスカラーベクトル積を記述できるようになりました.
ただ,本格的な式テンプレートライブラリを作ろうとすると型の数が多くなり,更にベクトル同士の和や行列ベクトル積を実現しようとすると,オペレーターオーバーロードのための実装が型の組み合わせの数だけ必要になってしまうので,今回紹介した方法はあまり現実的でないかもしれません.

まとめ

  • Rust で式テンプレートは実現できる.
    • オブジェクトの生存期間の間違いをコンパイラが指摘してくれるので,C++ と比べて安全なコードを書くことができる.
    • オペレーターではなく関数を使用するぶんには, C++ のテンプレート + コンセプトに近い書き方ができる.
  • オペレータオーバーロードを使ったかっこいい書き方を実現しようとするとかなりの手間.
    • C++ ではオペレータは通常の関数とほとんど違わないが, Rust ではオペレータオーバーロードの制限が強い?
    • マクロを使えば手間は緩和されるかもしれないがコンパイル時間も気になる.
    • 複数の式テンプレートクラスをラップする共通の型を用意するなど,型の数を減らす工夫は考えられるかもしれない.

追記

C++ でのコンセプトによる制限と違って, Rust でのトレイトによる制限では同名の関数を定義できないようです.
上記で実装した product は名前が良くなく, product_scalar_vector の方が良さそうです.

全コード

main.rs
trait Vector {
    fn dimension(&self) -> usize;
    fn get(&self, i: usize) -> f64;
}

struct DenseVector {
    array: Vec<f64>,
}

impl Vector for DenseVector {
    fn dimension(&self) -> usize {
        self.array.len()
    }
    fn get(&self, i: usize) -> f64 {
        self.array[i]
    }
}

struct ScalarVectorProduction<'a, V>
    where V: Vector
{
    scalar: f64,
    vector: &'a V,
}

impl<'a, V> Vector for ScalarVectorProduction<'a, V>
    where V: Vector
{
    fn dimension(&self) -> usize {
        self.vector.dimension()
    }
    fn get(&self, i: usize) -> f64 {
        self.scalar * self.vector.get(i)
    }
}

fn product<V>(scalar: f64, vector: &V) -> ScalarVectorProduction<V>
    where V: Vector
{
    ScalarVectorProduction{scalar:scalar, vector:vector}
}

// コンパイルできないコード
// impl<'a, V> std::ops::Mul<&'a V> for f64
//     where V: Vector
// {
//     type Output = ScalarVectorProduction<'a, V>;
//     fn mul(self, rhs: &'a V) -> Self::Output {
//         ScalarVectorProduction{scalar:self, vector:rhs}
//     }
// }

impl<'a> std::ops::Mul<&'a DenseVector> for f64
{
    type Output = ScalarVectorProduction<'a, DenseVector>;
    fn mul(self, rhs: &'a DenseVector) -> Self::Output {
        Self::Output{scalar:self, vector:rhs}
    }
}

impl<'a, V> std::ops::Mul<&'a ScalarVectorProduction<'a, V>> for f64
    where V: Vector
{
    type Output = ScalarVectorProduction<'a, V>;
    fn mul(self, rhs: &ScalarVectorProduction<'a, V>) -> Self::Output {
        Self::Output{scalar:self * rhs.scalar, vector:rhs.vector}
    }
}

fn print_vector<V: Vector>(vector: &V)
{
    for i in 0..vector.dimension() {
        print!("{}:{}, ", i, vector.get(i));
    }
    print!("\n");
}

fn main() {
    // 密ベクトルを構築
    let v = DenseVector{array:vec![4.0, 2.0]};
    print_vector(&v);
    // 2 倍する
    let w = 2.0 * &v;
    print_vector(&w);
    // 5 倍する
    let x = 5.0 * &w;
    print_vector(&x);
    // 0.1 倍する
    let y = 0.1 * &x;
    print_vector(&y);
}
5
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
3