LoginSignup
2
2

More than 1 year has passed since last update.

RustのTrait基礎

Posted at

これには誤りが書いている可能性があるので注意してください

traitの概要

トレイトの基本は以下のページを参照する

:point_right: https://doc.rust-jp.rs/book-ja/ch10-02-traits.html

:point_right: https://doc.rust-jp.rs/book-ja/ch19-03-advanced-traits.html

:point_right: https://zenn.dev/mebiusbox/books/22d4c1ed9b0003/viewer/497a21

トレイトは、Rustコンパイラに、特定の型に存在し、他の型と共有できる機能について知らせます。
トレイトを使用すると、共通の振る舞いを抽象的に定義できます。トレイト境界を使用すると、あるジェネリックが、特定の振る舞いをもつあらゆる型になり得ることを指定できます。

簡単に言ってしまうと「型に共通のふるまいを定義できる」ということになる

型というのが重要で、Rustでは構造体(struct)や列挙子(enum)を使い、新しい型を簡単に作ることができる

これで何が良いかというと、型が異なるがやりたいことは同じときに処理を変えて、見た目が似ているコードを書くことができるようになる

trait HelloWorld {
    fn hello() -> Self;
    fn print_func(self) -> Self;
    fn main_func(self) -> Self;
    fn source_code(self) -> Self;
}

struct Java(String);
struct C(String);

// ---- 省略 ---- //

fn main() {
    // ここではJavaかCしか変わっていない
    // だが、実行結果はそれぞれのhello, Langになる
    println!("{}", Java::hello().print_func().main_func().source_code().0);
    println!("{}",    C::hello().print_func().main_func().source_code().0);
}
/* 実行結果
Class HelloWorld {
    public static void main(String[] args){
        System.out.println!("hello, Java");
    }
}
#include<stdin.h>

int main(void){
    printf("hello, C");
}
*/

main()にあるようにJavaやC以外は変わっていないが、出力結果がそれぞれの言語の形になっている

処理の部分を意図的に隠しているが、例えば新しくpythonの構造体が作られたが、どう処理されているかが分からないとする

だとしてもPython::hello().print_func().main_func().source_code().0と書けば良いのがすぐにわかる

このような感じで各型で似たふるまいを定義することで、処理の中身だけ各型ごとに変えることができる

u8だろうが、i64だろうがAdd():+という演算子を使ったらu8同士で足し算できるし、i64どうしでも同様に足し算できる

同様に、それら型を内包した新しい構造体を作ってAdd():+を使ったら勝手に中身を足し算すれば良いという発想になる

型ごとに共通化できる振る舞いを定義さえできれば、それを使う側から見れば似たようなコードを書けば良いことが分かる

実装する側も意味的に共通なコードを書くことで、書くコードに制限をつけて誤った返り値や引数を取らないようにすることができる

既存のTraitを使ってみる

そうなると先ほど述べたような四則演算や比較といったことを実装したければ、毎度のごとくtraitを書く必要が出てくる

そこでRustではderiveというAttribte(属性)を使ったり、四則演算は演算子のオーバーロードをすることで可能になる

deriveできるのは以下の通り

属性 簡易説明
Copy 所有権を移動せずに、複製できる
Clone コピーを介して&TからTを作る
Debug {:?}が使えるようになる
PartialEq ==と!=が使えるようになる
Eq ==, !=の厳密な定義
PartialOrd <,>,<=,>=が使えるようになる
Ord <,>,<=,>=の厳密な定義
Default デフォルトの値を使えるようになる
Hash &Tからハッシュを計算し,Hashとかが使えるようになる

deriveできるのはこれだけではなく、cfgやtestなど便利なものがたくさんある(https://doc.rust-lang.org/reference/attributes.html)

数学的にも座標同士の大小比較できないが、PartialEq(半同値)とPartialOrd(半順序)をつけることで比較演算子で比較をすることもできる

use std::ops::{Add, Sub};

#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)]
struct Point {
    x: i32,
    y: i32,
}
// +を使えるようにする
impl Add for Point {
    type Output = Self;

    fn add(self, other: Self) -> Self {
        // どう足し算するかを決める
        Self {x: self.x + other.x, y: self.y + other.y}
    }
}
// -を使えるようにする
impl Sub for Point {
    type Output = Self;

    fn sub(self, other: Self) -> Self {
        // どう引き算するかを決める
        Self {x: self.x - other.x, y: self.y - other.y}
        // 意図的にx, yを入れ替えて計算することもできる
        // Self {x: self.x - other.y, y: self.y - other.x}
    }
}

impl Point {
    fn new(x: i32, y: i32) -> Self {
        Self { x, y }
    }
}

これらPointを使うだけなら以下のように簡単に記述できるようになる

fn main() {
    let a = Point::new(10, 10);
    let b = Point::new(-1, -1);
    println!("{:?}", a + b); // Debugをderiveしているおかげで{:?}が使えるようになっている
    println!("{:?}", a - b); // +-の演算子も新しく作ったPoint構造体で使える
    println!("a > b: {}", a > b); // ここでa, bを使えるのはCopyのおかげ
    println!("a < b: {}", a <= b); // Copyを使えるようにするためにCloneがある
}
/*
Point { x: 9, y: 9 }
Point { x: 11, y: 11 }
a > b: true
a < b: false
*/

このように使う側からしたらa.get_x() + b.get_x()みたいな値を取り出して変更してみたいな書き方をしなくてもいつもの足し算や比較のように書くことができる

Traitの継承

deriveしているのも継承して処理を上書きすることもできる

当たり前だが、引数と返り値はtraitが指定しているもの以外はできない

比較の仕方や表示のさせ方を変える場合も、継承させるだけで簡単に実装できる

use std::cmp::Ordering;
use std::fmt;
use std::ops::{Add, Sub};

// ---- 省略 --- //

impl PartialOrd for Point {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        // x, yそれぞれで比較してそれを利用できる
        let x = self.x.partial_cmp(&other.x).unwrap();
        let y = self.y.partial_cmp(&other.y).unwrap();
        match (x, y) {
            // 適当なので注意
            (Ordering::Greater, _) => Some(Ordering::Greater),
            (_, Ordering::Greater) => Some(Ordering::Greater),
            (Ordering::Equal, Ordering::Equal) => Some(Ordering::Equal),
            _ => Some(Ordering::Less),
        }
    }
}

impl fmt::Debug for Point {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // fmt.debug_tuple()もある
        write!(f, "Point({}, {})", &self.x, &self.y)
    }
}

fn main(){
    let a = Point::new(10, 10);
    let b = Point::new(-1, -1);
    println!("{:?}", a + b);
    println!("{:?}", a - b);
}
// Point(9, 9)
// Point(11, 11)

さらに複数の定義を上乗せする

// PartialOrdを先に定義しているため、それを使うだけで良い
// さらに厳密にしたければ再度上書きもできる
impl Ord for Point {
    fn cmp(&self, other: &Self) -> Ordering {
        self.partial_cmp(other).unwrap()
    }
}

struct Points {
    points: Vec<Point>,
}

impl Iterator for Points {
    type Item = Point;

    // next()は必ず実装しないといけない
    fn next(&mut self) -> Option<Self::Item> {
        self.points.pop()
    }
}

impl Points {
    fn new(points: Vec<Point>) -> Self {
        Self { points }
    }
}

VectorをsortできるようにOrdを実装し、複数のPointを持つPointsを作った

このように、Pointが複数ある場合も、いままでのVectorやIteratorと同じメソッドを使って各Pointを操作することができる

書く側もsortさせたければOrdをつけたり継承させれば良いし、Iteratorを作りたければ継承させればいい

その上で、引数や返り値を満たして簡単に自分の作った構造体で今までと同じ操作が可能なものが作れるようになる

fn main() {
    let a = Point::new(10, 10);
    let b = Point::new(-1, -1);
    let c = Point::new(1, 1);
    let d = Point::new(-1, 1);
    let e = Point::new(5, -2);
    let f = Point::new(-10, 10);

    let mut points = vec![a, b, c, d, e, f];
    let iter_points = points.iter();
    println!("Before");
    for p in iter_points {
        println!("{:?}", p);
    }

    points.sort(); // Ordのおかげでソートができる
    println!("After");
    let mut points = Points::new(points);
    // Iteratorのおかげでnext()が使える
    // `while let`を使わなくてもfor p in points {}が普通に使える
    while let Some(p) = points.next() {
        println!("{:?}", p);
    }
}

PointのフィールドをGenerics型にする

今まではPointのフィールドの型はi32に固定されていた

これを今度はGenerics型にしていろんな型を取れるようにする

#[derive(Clone, Copy)]
struct Point<T> {
    x: T,
    y: T,
}

impl<T: Add<Output=T>> Add for Point<T> {
    type Output = Self;

    fn add(self, other: Self) -> Self {
        Self {
            x: self.x + other.x,
            y: self.y + other.y,
        }
    }
}
impl<T: Sub<Output=T>> Sub for Point<T> {
    type Output = Self;

    fn sub(self, other: Self) -> Self {
        Self {
            x: self.x - other.x,
            y: self.y - other.y,
        }
    }
}

impl <T: fmt::Display> fmt::Debug for Point <T>{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "Point({}, {})", &self.x, &self.y)
    }
}

変更点は、Tという具体的な型を受け取るための引数が必要になる

このTは別にUでも何でもいいがimport numpy as npと書くのがお決まりのように、Typeの頭文字としてT、複数取るときはTの次のU、関数を取るときはFなどのお決まりがある

これにより、様々な型を受け取ることができるようになる

fn main(){
    let a: Point<u32> = Point::new(10, 10);
    let b: Point<u32> = Point::new(10, 10);
    println!("{:?}", a + b);
    println!("{:?}", a - b);

    let c: Point<f32> = Point::new(5.4, -1.05);
    let d: Point<f32> = Point::new(1.01, -5.1);
    println!("{:?}", c + d);
    println!("{:?}", c - d);

    let e: Point<i32> = Point::new(1, -1);
    let f: Point<i32> = Point::new(1, -1);
    println!("{:?}", e + f);
    println!("{:?}", e - f);
}
/*
Point(20, 20)
Point(0, 0)
Point(6.41, -6.1499996)      
Point(4.3900003, 4.05)       
Point(2, -2)
Point(0, 0)
*/

なぜこれがコンパイルが通るかというと、u8やf32といったprimitive型は基本手的にAddやSubをimplしているからである

なので、単純な足し算や引き算をする場合にはTを指定するだけでいい

ではStringはどうかというと、上記のもとの演算子(+, -)ではなくpush_str()を使っているため分けて書かないといけない

同様に、&str同士を足し算したければCow<>を使わないといけないためこれも分ける必要がある

impl Add<Point<&str>> for Point<String> {
    type Output = Point<String>;

    fn add(mut self, other: Point<&str>) -> Point<String> {
        self.x.push_str(other.x);
        self.y.push_str(other.y);
        Point {
            x: self.x,
            y: self.y,
        }
    }
}

impl<'a> Add<Point<&'a str>> for Point<Cow<'a, str>> {
    type Output = Point<Cow<'a, str>>;

    fn add(self, other: Point<&'a str>) -> Point<Cow<'a, str>> {
        Point {
            x: self.x + other.x,
            y: self.y + other.y,
        }
    }
}

だが、実際に動かすところだとほぼ似たような形、Point::new(T)して、+-をしているだけになる

もちろん-演算子を文字列でしたければ処理を別途考える必要がある

fn main() {
    let a: Point<u32> = Point::new(10, 10);
    let b: Point<u32> = Point::new(10, 10);
    println!("{:?}", a + b);
    println!("{:?}", a - b);

    let c: Point<f32> = Point::new(5.4, -1.05);
    let d: Point<f32> = Point::new(1.01, -5.1);
    println!("{:?}", c + d);
    println!("{:?}", c - d);

    let e: Point<i32> = Point::new(1, -1);
    let f: Point<i32> = Point::new(1, -1);
    println!("{:?}", e + f);
    println!("{:?}", e - f);

    let g: Point<Cow<str>> = Point::new(Cow::from("Hello,"), Cow::from("Hi"));
    let h: Point<Cow<str>> = Point::new(Cow::from(" world"), Cow::from(" there"));
    println!("{:?}", g + h); // Point(Hello, world, Hi there)

    let g: Point<String> = Point::new("Hello,".to_owned(), "Hi".to_owned());
    let h: Point<&str> = Point::new(" world", " there");
    println!("{:?}", g + h); // Point(Hello, world, Hi there)
}

終わりに

正直私もまだ使いこなせていない

ただ、traitの意味や使い方が分かってくるとドキュメントのstructs、traitsを見ればだいたいどうすれば良いかが分かってくる(わかりやすいわけではない)

closureのFnOnceだったり、Sizedのようなmaker traitなどさらに奥深い場所もあるので初心者のうちはこれらを知っておけばとりあえずは良いんじゃあないかと思う

自分で作った型を演算子で計算したりイテレートできるようになってくると表現の幅が広がり、よりRustっぽいコードが書けるようになってくる

Point
use std::cmp::*;
use std::fmt;
use std::ops::{Add, Sub};

#[derive(PartialEq, Eq)]
struct Point {
    x: i32,
    y: i32,
}

impl Add for Point {
    type Output = Self;

    fn add(self, other: Self) -> Self {
        // どのフィールドとどのフィールドを足し算するかを決める
        Self {
            x: self.x + other.x,
            y: self.y + other.y,
        }
    }
}

impl Sub for Point {
    type Output = Self;

    fn sub(self, other: Self) -> Self {
        // どのフィールドとどのフィールドを引き算するかを決める
        Self {
            x: self.x - other.x,
            y: self.y - other.y,
        }
    }
}

impl Point {
    fn new(x: i32, y: i32) -> Self {
        Self { x, y }
    }
}

impl PartialOrd for Point {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        let x = self.x.partial_cmp(&other.x).unwrap();
        let y = self.y.partial_cmp(&other.y).unwrap();
        match (x, y) {
            // 適当なので注意
            (Ordering::Greater, _) => Some(Ordering::Greater),
            (_, Ordering::Greater) => Some(Ordering::Greater),
            (Ordering::Equal, Ordering::Equal) => Some(Ordering::Equal),
            _ => Some(Ordering::Less),
        }
    }
}

impl Ord for Point {
    fn cmp(&self, other: &Self) -> Ordering {
        self.partial_cmp(other).unwrap()
    }
}

impl fmt::Debug for Point {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "Point({}, {})", &self.x, &self.y)
    }
}

struct Points {
    points: Vec<Point>,
}

impl Points {
    fn new(points: Vec<Point>) -> Self {
        Self { points }
    }
}

impl Iterator for Points {
    type Item = Point;

    fn next(&mut self) -> Option<Self::Item> {
        self.points.pop()
    }
}
Point Generics
use std::borrow::Cow;
use std::{fmt};
use std::ops::{Add, Sub};

#[derive(Clone, Copy)]
struct Point<T> {
    x: T,
    y: T,
}

impl<T: Add<Output=T>> Add for Point<T> {
    type Output = Self;

    fn add(self, other: Self) -> Self {
        Self {
            x: self.x + other.x,
            y: self.y + other.y,
        }
    }
}

impl<'a> Add<Point<&'a str>> for Point<Cow<'a, str>> {
    type Output = Point<Cow<'a, str>>;

    fn add(self, other: Point<&'a str>) -> Point<Cow<'a, str>> {
        Point {
            x: self.x + other.x,
            y: self.y + other.y,
        }
    }
}

impl Add<Point<&str>> for Point<String> {
    type Output = Point<String>;

    fn add(mut self, other: Point<&str>) -> Point<String> {
        self.x.push_str(other.x);
        self.y.push_str(other.y);
        Point {
            x: self.x,
            y: self.y,
        }
    }
}

impl<T: Sub<Output=T>> Sub for Point<T> {
    type Output = Self;

    fn sub(self, other: Self) -> Self {
        Self {
            x: self.x - other.x,
            y: self.y - other.y,
        }
    }
}

impl <T: fmt::Display> fmt::Debug for Point <T>{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "Point({}, {})", &self.x, &self.y)
    }
}

impl<T> Point<T> {
    fn new(x: T, y: T) -> Self {
        Self { x, y }
    }
}

fn main() {
    let a: Point<u32> = Point::new(10, 10);
    let b: Point<u32> = Point::new(10, 10);
    println!("{:?}", a + b);
    println!("{:?}", a - b);

    let c: Point<f32> = Point::new(5.4, -1.05);
    let d: Point<f32> = Point::new(1.01, -5.1);
    println!("{:?}", c + d);
    println!("{:?}", c - d);

    let e: Point<i32> = Point::new(1, -1);
    let f: Point<i32> = Point::new(1, -1);
    println!("{:?}", e + f);
    println!("{:?}", e - f);

    let g: Point<Cow<str>> = Point::new(Cow::from("Hello,"), Cow::from("Hi"));
    let h: Point<Cow<str>> = Point::new(Cow::from("World"), Cow::from("There"));
    println!("{:?}", g + h);

    let g: Point<String> = Point::new("Hello,".to_owned(), "Hi".to_owned());
    let h: Point<&str> = Point::new(" world", " there");
    println!("{:?}", g + h);

}
Hello
trait HelloWorld {
    fn hello() -> Self;
    fn print_func(self) -> Self;
    fn main_func(self) -> Self;
    fn source_code(self) -> Self;
}

struct Java(String);

impl HelloWorld for Java {
    fn hello() -> Self {
        Self("hello, Java".to_string())
    }

    fn print_func(self) -> Self {
        Self(format!("    System.out.println!(\"{}\");", self.0))
    }

    fn main_func(self) -> Self {
        Self(format!(
            "\n    public static void main(String[] args){{\n    {}\n    }}",
            self.0
        ))
    }

    fn source_code(self) -> Self {
        Self(format!("Class HelloWorld {{{}\n}}", self.0))
    }
}

struct C(String);

impl HelloWorld for C {
    fn hello() -> Self {
        Self("hello, C".to_string())
    }

    fn print_func(self) -> Self {
        Self(format!("printf(\"{}\");", self.0))
    }
    fn main_func(self) -> Self {
        Self(format!("\nint main(void){{\n    {}\n}}", self.0))
    }

    fn source_code(self) -> Self {
        Self(format!("#include<stdin.h>\n    {}", self.0))
    }
}

struct Python(String);

impl HelloWorld for Python {
    fn hello() -> Self {
        Self("hello, Python".to_string())
    }

    fn print_func(self) -> Self {
        Self(format!("print('{}')", self.0))
    }

    fn main_func(self) -> Self {
        Self(format!("def min():\n    {}", self.0))
    }

    fn source_code(self) -> Self {
        Self(format!("{}\n\nif __init__ == '__main__':\n    main()", self.0))
    }

}
2
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
2