Rust で数値計算をしていて意外にハマるのが, f32
と f64
でコードを共有しようとしてもなかなかコンパイルが通らないという点です. 傾向と対策をまとめます.
戦略
まずそもそもコードを共通化と言ってもジェネリクスを用いた関数
fn some_calculation<T>(x: T) -> T {
// ここを実装する //
}
を定義しようとしているのか, なんらかのトレイト Foo
を f32
と f64
に実装しようとしているのか, という 2 パターンがあり得ます. 後者の場合は
macro_rules! impl_foo {
($t: ty) => {
impl Foo for $t {
fn foo(&self) -> Self {
// ここを実装する //
}
}
};
}
impl_foo!(f32);
impl_foo!(f64);
という感じでしょうか. 基本的にマクロなら難しいポイントはないのですが, ジェネリック関数では適切なトレイトを探したりしなければならずなかなか大変です. それならば, とマクロでふたつの関数を定義しようとすると, 残念ながら concat_idents!
が関数名に使えず詰みます (paste クレートを使うという手はあります).
値の取得
Rust は暗黙の型変換を嫌うため, 値を取得するだけでもなかなか手間です.
0 と 1
マクロ中なら展開後に数値リテラルが適切な型として解釈されるので単に 0.
や 1.
と書けば大丈夫です. 一方ジェネリック関数では num-traits
のトレイト num_traits::Zero
, num_traits::One
を使うのが簡単です.
use num_traits::One;
fn get_one<T: One>() -> T {
T::one()
}
もちろん次小節で述べるより一般的な方法を使うこともできます.
整数を変換
任意の整数が与えられてそれを浮動小数点数として計算に使いたいということは頻繁にあります. マクロ中では単に as $t
とすれば良いので簡単です. 一方ジェネリック関数では案外クセがあって, From
が使えればそれが簡単なのですが, 型によって使えたり使えなかったりします.
-
f32
:From<u8>
,From<i8>
,From<u16>
,From<i16>
-
f64
:From<u8>
,From<i8>
,From<u16>
,From<i16>
,From<u32>
,From<i32>
つまり一番使うと思われる i32
については, f64
は From<i32>
ですが f32
はそうでないのです. この違いは, From<U>
は「情報喪失なし」に型変換できることを表すトレイトですが, f32
では有効数字が足りず i32
から変換することができないことに由来します.
ではどうするかというと, これも num-traits
が適切なトレイトを提供してくれています: num_traits::FromPrimitive
です.
use num_traits::FromPrimitive;
fn get<T: FromPrimitive>(x: i32) -> T {
T::from_i32(x).unwrap()
}
fn main() {
let _: f32 = get(7);
}
ちなみに, num_complex::Complex
, num_rational::Ratio
ともに Zero
, One
, FromPrimitive
です.
四則演算や関数
四則演算は std に std::ops::{Add, AddAssign, Sub, SubAssign, Neg, Mul, MulAssign, Div, DivAssign, Rem, RemAssign}
があるため, これを使えば大丈夫です. 書くのが面倒だという方は num_traits::{NumOps, NumOpsAssign}
を利用してください.
一方, sin
や exp
などの数学関数は std ではトレイトとしては用意されていないため, num_traits::Float
が必要です. このトレイトには num_traits::Num
を通じて PartialEq
や四則演算 NumOps
などが含まれています (NumOpsAssign
は入っていないため別枠で指定する必要があります). また Float
には PartialOrd
も入っています.
まとめ
f32
と f64
に使えるジェネリック関数を定義する場合, トレイト境界として T: Float + FromPrimitive
を指定すればたいていのことはできます. 必要に応じて NumOpsAssign
を補ってください.
use num_traits::{Float, FromPrimitive};
fn one_plus_cos<T: Float + FromPrimitive>(x: T, n: i32) -> T {
T::one() + ( x * T::from_i32(n).unwrap() ).cos()
}