アルゴリズム
Rust
FFT

RustでFFT (2)任意の要素数に適用できるFFT

More than 1 year has passed since last update.

概要

一般的なFFTは、要素数が$2^n$の場合にのみ適用可能である。
以下では、要素数が$2^n$に限らず任意の数で高速に処理する方法を示す。

Prime Factor型 FFT

use num::{Complex, Float, cast, one, zero};
use num_traits::float::FloatConst;
use std::marker::PhantomData;
use std::collections::HashMap;

/// DFT implementation based on the prime-factor decomposition of the input length,
/// generic over the floating-point scalar type `T`.
///
/// The struct carries no data of its own; all work happens per call.
pub struct PrimeFactorBasic<T> {
    // Zero-sized marker that ties the struct to `T` without storing a value of that type.
    phantom: PhantomData<T>,
}

/// Returns an endless iterator over the prime numbers in ascending order (2, 3, 5, 7, ...).
///
/// Odd candidates are sieved incrementally. `pending` maps an upcoming odd composite to the
/// step (twice one of its prime factors) that is used to schedule that factor's next odd
/// multiple. A candidate that is absent from `pending` is prime.
fn primes() -> Box<dyn Iterator<Item = usize>> {
    let mut pending: HashMap<usize, usize> = HashMap::new();
    let odd_primes = (1_usize..).map(|i| i * 2 + 1).filter_map(move |candidate| {
        let (result, step) = match pending.remove(&candidate) {
            // Known composite: keep using the step of the factor that marked it.
            Some(step) => (None, step),
            // New prime: its odd multiples are `candidate + j * (2 * candidate)`.
            None => (Some(candidate), candidate * 2),
        };
        // Schedule the next multiple that has not already been claimed by another factor.
        let next_free = (1..)
            .map(|j| candidate + j * step)
            .find(|multiple| !pending.contains_key(multiple));
        if let Some(multiple) = next_free {
            pending.insert(multiple, step);
        }
        result
    });
    // 2 is the only even prime; every other prime comes from the odd sieve.
    Box::new((2..).take(1).chain(odd_primes))
}

impl<T: Float + FloatConst> PrimeFactorBasic<T> {
    /// Creates a new instance; the only field is the zero-sized `PhantomData` marker.
    pub fn new() -> Self {
        Self { phantom: PhantomData }
    }

    /// Computes the DFT of `source`, or the inverse DFT when `is_back` is true, and returns
    /// the result as a new vector. `source` itself is not modified.
    ///
    /// For the inverse transform every recursion level divides by its own sub-length, so the
    /// overall result is scaled by `1 / source.len()`.
    fn convert_inner(&mut self, source: &[Complex<T>], is_back: bool) -> Vec<Complex<T>> {
        /// Transforms `source` in place by splitting its length into two coprime factors,
        /// running a definition-based DFT over the first and recursing on the second.
        fn converter<T: Float + FloatConst>(source: &mut [Complex<T>], is_back: bool) {
            let len = source.len();
            if len <= 1 {
                return;
            }

            // Factorisation: n1 is the full power of the smallest prime factor of `len`,
            // n2 is the remaining cofactor (so n1 and n2 are coprime and n1 * n2 == len).
            let mut n1 = len;
            for p in primes() {
                if len % p == 0 {
                    n1 = p;
                    break;
                }
                // No divisor up to sqrt(len): `len` itself is prime.
                if p * p > len {
                    break;
                }
            }
            let mut n2 = len;
            while n2 % n1 == 0 {
                n2 /= n1;
            }
            n1 = len / n2;

            // The n1-th roots of unity, computed once per call. The original twiddle
            // exp(±2πi·n2·j1·k1/len) equals exp(±2πi·((j1·k1) mod n1)/n1), so reducing the
            // exponent keeps the angle small (better precision, no overflow in the product).
            let sign: T = if is_back { one() } else { -one::<T>() };
            let twiddles: Vec<Complex<T>> = (0..n1)
                .map(|m| {
                    Complex::<T>::from_polar(
                        &one(),
                        &(sign * cast::<_, T>(2 * m).unwrap() * T::PI() / cast(n1).unwrap()),
                    )
                })
                .collect();

            // DFT of size n1, evaluated from the definition for each of the n2 residue classes.
            let mut output: Vec<Complex<T>> = vec![zero(); len];
            for j2 in 0..n2 {
                for k1 in 0..n1 {
                    // The j1 == 0 term has a twiddle of exactly 1 and seeds the fold.
                    let sum = (1..n1).fold(source[n1 * j2], |acc, j1| {
                        acc + twiddles[(j1 * k1) % n1] * source[(n1 * j2 + n2 * j1) % len]
                    });
                    output[n2 * k1 + j2] = if is_back {
                        sum.unscale(cast(n1).unwrap())
                    } else {
                        sum
                    };
                }
            }

            // Each consecutive run of n2 values is transformed recursively.
            for chunk in output.chunks_mut(n2) {
                converter(chunk, is_back);
            }

            // Reordering: copy the results back into `source` at k = (n2*k1 + n1*k2) mod len.
            for k1 in 0..n1 {
                let j1 = n2 * k1 % n1;
                for k2 in 0..n2 {
                    let k = (n2 * k1 + n1 * k2) % len;
                    source[k] = output[n2 * j1 + (n1 * k2 % n2)];
                }
            }
        }

        let mut ret = source.to_vec();
        converter(&mut ret, is_back);
        ret
    }
}

// Human-readable label for this algorithm.
impl<T> AlgorithmName for PrimeFactorBasic<T> {
    /// Returns the fixed display name of the algorithm.
    fn name(&self) -> &'static str {
        "Prime Factor Algorithm"
    }
}

// Public transform entry points; both delegate to `convert_inner`.
impl<T: Float + FloatConst> DftAlgorithm<T> for PrimeFactorBasic<T> {
    /// Forward transform: calls `convert_inner` with `is_back = false`.
    fn convert(&mut self, source: &[Complex<T>]) -> Vec<Complex<T>> {
        self.convert_inner(source, false)
    }

    /// Inverse transform: calls `convert_inner` with `is_back = true`.
    fn convert_back(&mut self, source: &[Complex<T>]) -> Vec<Complex<T>> {
        self.convert_inner(source, true)
    }
}

評価

大浦氏のコードを移植した。
氏のサイトでも言及されている通り、このコードは遅い。
本来であれば、基数毎のDFTを専用に作成し、高速化を図る必要がある。

Mixed Radix FFT

要素数$2^n$のFFTで示したCooley-Tukey型やStockham型のFFTは実際には任意の要素数で処理が可能である。
以下に、Cooley-Tukey型・時間間引きのMixed Radix FFTのコードを示す。

use num::{Complex, Float, cast, one};
use num_traits::float::FloatConst;
use std::cmp;
use num_traits::NumAssign;

/// Mixed-radix FFT that caches its lookup tables for the most recently used input length.
pub struct MixedRadix<T> {
    /// Input length the cached tables below were built for (0 until first use).
    len: usize,
    /// Input permutation: output slot `n` of the reordering step reads `source[ids[n]]`.
    ids: Vec<usize>,
    /// Twiddle table produced by `calc_omega(len)`, used by the forward transform.
    omega: Vec<Complex<T>>,
    /// `omega` in reverse order, used by the inverse transform.
    omega_back: Vec<Complex<T>>,
    /// Radix of each FFT stage, as returned by `prime_factorization(len)`.
    factors: Vec<usize>,
}

impl<T: Float + FloatConst + NumAssign> MixedRadix<T> {
    /// Creates an instance with empty tables; they are built on the first transform call
    /// whose input length differs from the stored `len`.
    pub fn new() -> Self {
        Self {
            len: 0,
            ids: Vec::new(),
            omega: Vec::new(),
            omega_back: Vec::new(),
            factors: Vec::new(),
        }
    }

    /// Rebuilds every cached table (twiddles, radix list, input permutation) for inputs of
    /// length `len`.
    fn initialize(&mut self, len: usize) {
        self.len = len;

        // Twiddle factors: the forward table, and the same table reversed for the inverse.
        self.omega = Self::calc_omega(len);
        self.omega_back = self.omega.iter().rev().cloned().collect();

        // Radix sequence of the FFT stages.
        self.factors = Self::prime_factorization(len);

        // Input permutation (the mixed-radix counterpart of bit reversal), grown one radix
        // at a time: existing entries are scaled, then copies offset by 1..radix are appended.
        let mut permutation = Vec::<usize>::with_capacity(len);
        permutation.push(0);
        for &radix in &self.factors {
            let filled = permutation.len();
            for entry in permutation.iter_mut() {
                *entry *= radix;
            }
            for digit in 1..radix {
                for pos in 0..filled {
                    let id = permutation[pos] + digit;
                    permutation.push(id);
                }
            }
        }
        self.ids = permutation;
    }

    /// Precomputes the twiddle table: entry `k` is `exp(-2πi·k/len)`.
    ///
    /// The table starts and ends with 1 (zero turns and one full turn), giving `len + 1`
    /// entries for `len >= 1`. When `len` is a multiple of 4 or 2, only the first quarter or
    /// half turn is evaluated with sin/cos; the rest is derived by rotating earlier entries.
    fn calc_omega(len: usize) -> Vec<Complex<T>> {
        // Direct sin/cos evaluation of the k-th point on the unit circle.
        let direct = |k: usize| -> Complex<T> {
            Complex::from_polar(
                &one(),
                &(cast::<_, T>(-2.0).unwrap() * T::PI() / cast(len).unwrap() *
                      cast(k).unwrap()),
            )
        };

        let mut table: Vec<Complex<T>> = Vec::with_capacity(len + 1);
        table.push(one());

        let low_bits = len & 3;
        if low_bits == 0 {
            // Divisible by 4: evaluate the first quarter turn only.
            let quarter = len >> 2;
            let half = len >> 1;
            for k in 1..quarter {
                table.push(direct(k));
            }
            // Quarter to half turn: multiply the entry a quarter turn earlier by -i.
            for k in quarter..half {
                let base = table[k - quarter];
                table.push(Complex::new(base.im, -base.re));
            }
            // Second half turn: negate the entry half a turn earlier.
            for k in half..len {
                let base = table[k - half];
                table.push(Complex::new(-base.re, -base.im));
            }
        } else if low_bits == 2 {
            // Divisible by 2 but not by 4: evaluate the first half turn only.
            let half = cmp::max(len >> 1, 1);
            for k in 1..half {
                table.push(direct(k));
            }
            // Second half turn: negate the entry half a turn earlier.
            for k in half..len {
                let base = table[k - half];
                table.push(Complex::new(-base.re, -base.im));
            }
        } else {
            // Odd length: every entry is evaluated directly.
            for k in 1..len {
                table.push(direct(k));
            }
        }

        // Exactly one full turn.
        table.push(one());
        table
    }

    /// Splits `value` into the radices used by the FFT stages.
    ///
    /// Powers of two come first, paired up as radix-4 entries with at most one trailing 2,
    /// followed by the odd prime factors in ascending order. The product of the returned
    /// factors equals `value`; `0` and `1` both yield an empty vector.
    fn prime_factorization(mut value: usize) -> Vec<usize> {
        // Capacity hint only; the vector grows if a value has more factors than this.
        let mut factors = Vec::<usize>::with_capacity(32);
        if value == 0 {
            return factors;
        }

        // Take factors of two in pairs so the dedicated radix-4 butterfly can be used.
        while value & 3 == 0 {
            factors.push(4);
            value >>= 2;
        }
        // After the loop above at most one factor of two can remain.
        if value & 1 == 0 {
            factors.push(2);
            value >>= 1;
        }

        // Trial division by odd numbers. `prime <= value / prime` is equivalent to
        // `prime * prime <= value` but cannot overflow for values close to `usize::MAX`.
        let mut prime = 3;
        while prime <= value / prime {
            if value % prime == 0 {
                factors.push(prime);
                value /= prime;
            } else {
                prime += 2;
            }
        }

        // Whatever is left (if greater than 1) is itself prime.
        if value > 1 {
            factors.push(value);
        }
        factors
    }

    /// Runs the mixed-radix FFT over `source` and returns the result as a new vector.
    ///
    /// With `is_back == true` the inverse transform is computed: the reversed twiddle table
    /// is used and the input is divided by `source.len()` while it is being reordered.
    /// Cached tables are rebuilt whenever the input length differs from the stored `len`.
    fn convert_inner(&mut self, source: &[Complex<T>], is_back: bool) -> Vec<Complex<T>> {
        let len = source.len();

        // With at most one element the input is returned unchanged.
        if len <= 1 {
            return source.to_vec();
        }

        if len != self.len {
            self.initialize(len);
        }

        // Twiddle table for the requested direction.
        let omega = if is_back {
            &self.omega_back
        } else {
            &self.omega
        };

        // Reorder the input according to the precomputed permutation.
        let mut ret = self.ids.iter().map(|&i| if is_back {
            source[i].unscale(cast(source.len()).unwrap()) // for the inverse transform the 1/len scaling is applied here
        } else {
            source[i]
        }).collect::<Vec<_>>();

        // FFT
        // `rot` caches the roots of unity for the generic-radix branch; `rot_len` records
        // which radix it was built for.
        let mut rot = Vec::new();
        // Product of the radices processed so far (sub-transform length after a stage).
        let mut po2 = 1;
        // len / po2: stride into the twiddle table for the current stage.
        let mut rad = len;
        let mut rot_len = 0;

        // +i for the forward transform, -i for the inverse.
        let im_one = if is_back { -Complex::i() } else { Complex::i() };

        for &factor in &self.factors {
            // Sub-transform length before this stage.
            let po2m = po2;
            po2 *= factor;
            rad /= factor;

            for mut j in 0..po2m {
                // Twiddle index for offset j; its multiples serve the other butterfly inputs.
                let wpos = rad * j;

                // Visit every butterfly of this stage that sits at offset j.
                while j < len {
                    match factor {
                        2 => {
                            // Radix-2 butterfly.
                            let pos1 = j + po2m;
                            let z1 = ret[pos1] * omega[wpos];
                            ret[pos1] = ret[j] - z1;
                            ret[j] += z1;
                        }
                        3 => {
                            // Radix-3 butterfly.
                            let pos1 = j + po2m;
                            let pos2 = pos1 + po2m;
                            let z1 = ret[pos1] * omega[wpos];
                            let z2 = ret[pos2] * omega[wpos << 1];
                            let t1 = z1 + z2;
                            let t2 = ret[j] - t1.scale(cast(0.5).unwrap());
                            let t3 = (z1 - z2) *
                                im_one.scale(
                                    (cast::<_, T>(-2.0).unwrap() * T::PI() / cast(3.0).unwrap())
                                        .sin(),
                                );
                            ret[j] += t1;
                            ret[pos1] = t2 + t3;
                            ret[pos2] = t2 - t3;
                        }
                        4 => {
                            // Radix-4 butterfly.
                            let w1 = omega[wpos];
                            let w2 = omega[wpos << 1];
                            let w3 = omega[wpos * 3];

                            let pos1 = j + po2m;
                            let pos2 = pos1 + po2m;
                            let pos3 = pos2 + po2m;
                            let wfa = ret[j];
                            let wfb = ret[pos2] * w2;
                            let wfab = wfa + wfb;
                            let wfamb = wfa - wfb;
                            let wfc = ret[pos1] * w1;
                            let wfd = ret[pos3] * w3;
                            let wfcd = wfc + wfd;
                            let wfcimdi = (wfc - wfd) * im_one;

                            ret[j] = wfab + wfcd;
                            ret[pos1] = wfamb - wfcimdi;
                            ret[pos2] = wfab - wfcd;
                            ret[pos3] = wfamb + wfcimdi;
                        }
                        5 => {
                            // Radix-5 butterfly.
                            let w1 = omega[wpos];
                            let w2 = omega[wpos << 1];
                            let w3 = omega[wpos * 3];
                            let w4 = omega[wpos << 2];

                            let pos2 = j + po2m;
                            let pos3 = pos2 + po2m;
                            let pos4 = pos3 + po2m;
                            let pos5 = pos4 + po2m;

                            let z0 = ret[j];
                            let z1 = ret[pos2] * w1;
                            let z2 = ret[pos3] * w2;
                            let z3 = ret[pos4] * w3;
                            let z4 = ret[pos5] * w4;

                            let t1 = z1 + z4;
                            let t2 = z2 + z3;
                            let t3 = z1 - z4;
                            let t4 = z2 - z3;
                            let t5 = t1 + t2;
                            let t6 = (t1 - t2).scale(
                                cast::<_, T>(0.25).unwrap() *
                                    cast::<_, T>(5.0).unwrap().sqrt(),
                            );
                            let t7 = z0 - t5.scale(cast(0.25).unwrap());
                            let t8 = t6 + t7;
                            let t9 = t7 - t6;
                            let t10 = (t3.scale((cast::<_, T>(-0.4).unwrap() * T::PI()).sin()) +
                                           t4.scale(
                                    (cast::<_, T>(-0.2).unwrap() * T::PI()).sin(),
                                )) * im_one;
                            let t11 = (t3.scale((cast::<_, T>(-0.2).unwrap() * T::PI()).sin()) -
                                           t4.scale(
                                    (cast::<_, T>(-0.4).unwrap() * T::PI()).sin(),
                                )) * im_one;

                            ret[j] = z0 + t5;
                            ret[pos2] = t8 + t10;
                            ret[pos3] = t9 + t11;
                            ret[pos4] = t9 - t11;
                            ret[pos5] = t8 - t10;
                        }
                        _ => {
                            // Any other radix: build (or reuse) the table of `factor`
                            // equally spaced entries of `omega`.
                            if rot_len != factor {
                                rot_len = factor;
                                let rot_width = len / factor;
                                rot = (0..factor)
                                    .map(|i| omega[rot_width * i])
                                    .collect::<Vec<_>>();
                            }

                            // DFT evaluated straight from the definition.
                            let pos = (0..factor).map(|i| j + po2m * i).collect::<Vec<_>>();
                            let z = (0..factor)
                                .map(|i| ret[pos[i]] * omega[wpos * i])
                                .collect::<Vec<_>>();
                            for i in 0..factor {
                                ret[pos[i]] =
                                    (1..factor).fold(z[0], |x, l| x + z[l] * rot[(i * l) % factor]);
                            }
                        }
                    }
                    // Advance to the next butterfly at the same offset.
                    j += po2;
                }
            }
        }
        return ret;
    }
}

// Human-readable label for this algorithm.
impl<T> AlgorithmName for MixedRadix<T> {
    /// Returns the fixed display name of the algorithm.
    fn name(&self) -> &'static str {
        // The algorithm is named after Cooley and Tukey; the previous spelling was a typo.
        "Cooley-Tukey (Mixed Radix)"
    }
}

// Public transform entry points; both delegate to `convert_inner`.
impl<T: Float + FloatConst + NumAssign> DftAlgorithm<T> for MixedRadix<T> {
    /// Forward transform: calls `convert_inner` with `is_back = false`.
    fn convert(&mut self, source: &[Complex<T>]) -> Vec<Complex<T>> {
        self.convert_inner(source, false)
    }

    /// Inverse transform: calls `convert_inner` with `is_back = true`.
    fn convert_back(&mut self, source: &[Complex<T>]) -> Vec<Complex<T>> {
        self.convert_inner(source, true)
    }
}

評価

Prime Factor型と同様に、基数毎のDFTを専用に作成し、高速化を図る必要がある。
上記では基数2,3,4,5におけるDFTを作成し、高速化を図っているが、まだチューニングの余地は残されている。

Chirp-Z 変換

Chirp-Z 変換は、任意の要素数のDFTを畳み込み演算に書き換え、要素数を$2^n$に拡張して計算できるようにするための変換処理である。
$2^n$のFFTはかなり高速な為、変換を実施し、FFTを行い、逆変換を実施する事により、高速に処理できる。

use num::{Complex, Float, cast, one, zero};
use num_traits::float::FloatConst;
use std::cmp;

/// Chirp-Z based DFT: inputs of any length are transformed via power-of-two FFTs.
/// Lookup tables are cached between calls.
#[derive(Debug)]
pub struct ChirpZ<T> {
    /// log2 of the padded power-of-two FFT length.
    level: usize,
    /// Bit-reversal permutation for the padded length.
    ids: Vec<usize>,
    /// Twiddle table for the padded power-of-two FFT.
    omega: Vec<Complex<T>>,
    /// Twiddle table built for twice the source length; the chirp factors are read from it.
    src_omega: Vec<Complex<T>>,
    /// Source length the cached tables were built for (0 until first use).
    src_len: usize,
}

impl<T: Float + FloatConst> ChirpZ<T> {
    /// Creates an instance with empty tables; they are built on the first transform call
    /// whose input length differs from the stored `src_len`.
    pub fn new() -> Self {
        Self {
            level: 0,
            ids: Vec::new(),
            omega: Vec::new(),
            src_omega: Vec::new(),
            src_len: 0,
        }
    }

    /// Precomputes the twiddle table: entry `k` is `exp(-2πi·k/len)`.
    ///
    /// The first and last entries are both 1 (zero turns and one full turn), giving
    /// `len + 1` entries for `len >= 1`. For lengths divisible by 4 or 2 only the first
    /// quarter or half turn is evaluated with sin/cos; the remainder is derived from
    /// entries already in the table.
    fn calc_omega(len: usize) -> Vec<Complex<T>> {
        // Evaluates the k-th point on the unit circle directly.
        let point = |k: usize| -> Complex<T> {
            Complex::from_polar(
                &one(),
                &(cast::<_, T>(-2.0).unwrap() * T::PI() / cast(len).unwrap() *
                      cast(k).unwrap()),
            )
        };

        let mut table: Vec<Complex<T>> = Vec::with_capacity(len + 1);
        table.push(one());

        match len & 3 {
            0 => {
                // Divisible by 4 (the two lowest bits are zero).
                let quarter = len >> 2;
                let half = len >> 1;
                // First quarter turn: direct evaluation.
                table.extend((1..quarter).map(&point));
                // Quarter to half turn: the entry a quarter turn earlier times -i.
                for k in quarter..half {
                    let earlier = table[k - quarter];
                    table.push(Complex::new(earlier.im, -earlier.re));
                }
                // Second half turn: the entry half a turn earlier, negated.
                for k in half..len {
                    let earlier = table[k - half];
                    table.push(Complex::new(-earlier.re, -earlier.im));
                }
            }
            2 => {
                // Divisible by 2 but not by 4.
                let half = cmp::max(len >> 1, 1);
                // First half turn: direct evaluation.
                table.extend((1..half).map(&point));
                // Second half turn: the entry half a turn earlier, negated.
                for k in half..len {
                    let earlier = table[k - half];
                    table.push(Complex::new(-earlier.re, -earlier.im));
                }
            }
            _ => {
                // Odd length: every entry is evaluated directly.
                table.extend((1..len).map(&point));
            }
        }

        // Exactly one full turn.
        table.push(one());
        table
    }

    /// Rebuilds the cached tables for source inputs of length `len`.
    ///
    /// The internal FFT length is twice the next power of two of `len`. Its twiddle table is
    /// only recomputed when that padded size changes, while the chirp table and the
    /// bit-reversal permutation are rebuilt on every call.
    fn initialize(&mut self, len: usize) {
        let padded = len.next_power_of_two() << 1;
        let new_level = padded.trailing_zeros() as usize;

        if self.level != new_level {
            self.level = new_level;
            self.omega = Self::calc_omega(padded);
        }

        self.src_len = len;
        self.src_omega = Self::calc_omega(len << 1);

        // Bit-reversal permutation: each round doubles the existing entries and appends
        // their odd partners.
        let mut reversed = Vec::with_capacity(padded);
        reversed.push(0_usize);
        for _ in 0..new_level {
            let filled = reversed.len();
            for pos in 0..filled {
                let doubled = reversed[pos] << 1;
                reversed[pos] = doubled;
                reversed.push(doubled + 1);
            }
        }
        self.ids = reversed;
    }

    /// Power-of-two FFT over `source`, using the cached `ids` permutation and `omega` table.
    ///
    /// `source` is indexed through `ids`, so it must hold at least as many elements as the
    /// padded length the tables were built for. With `is_back == true` the conjugated
    /// twiddles are used and the result is divided by its length.
    fn convert_rad2(&mut self, source: &[Complex<T>], is_back: bool) -> Vec<Complex<T>> {

        // Reorder the input.
        let mut ret = self.ids.iter().map(|&i| source[i]).collect::<Vec<_>>();

        // FFT
        let mut po2 = 1;
        // An odd `level` leaves one factor of two: run a single twiddle-free radix-2 stage.
        if (self.level & 1) == 1 {
            po2 = 2;
            for j in 0..(ret.len() >> 1) {
                let pos_a = j << 1;
                let pos_b = pos_a + 1;
                let wfa = ret[pos_a];
                let wfb = ret[pos_b];
                ret[pos_a] = wfa + wfb;
                ret[pos_b] = wfa - wfb;
            }
        }

        // +i for the forward transform, -i for the inverse.
        let im_one = if is_back { -Complex::i() } else { Complex::i() };

        // Remaining stages are radix-4; there are `level / 2` of them.
        for i in 1..((self.level >> 1) + 1) {
            // Sub-transform length before this stage.
            let po2m = po2;
            po2 <<= 2;

            for k in 0..po2m {
                // Index of the base twiddle for offset k in this stage.
                let pos_w = (1 << ((self.level & !1) - (i << 1))) * k;

                // The inverse transform uses the complex conjugates of the table entries.
                let (w, w2, w3) = if is_back {
                    (
                        self.omega[pos_w].conj(),
                        self.omega[pos_w << 1].conj(),
                        self.omega[pos_w * 3].conj(),
                    )
                } else {
                    (
                        self.omega[pos_w],
                        self.omega[pos_w << 1],
                        self.omega[pos_w * 3],
                    )
                };
                let mut j = k;
                // Visit every radix-4 butterfly at offset k.
                while j < ret.len() {
                    let pos_b = j + po2m;
                    let pos_c = j + (po2m << 1);
                    let pos_d = j + (po2m * 3);
                    let wfa = ret[j];
                    let wfb = ret[pos_b] * w2;
                    let wfab = wfa + wfb;
                    let wfamb = wfa - wfb;
                    let wfc = ret[pos_c] * w;
                    let wfd = ret[pos_d] * w3;
                    let wfcd = wfc + wfd;
                    let wfcimdi = (wfc - wfd) * im_one;

                    ret[j] = wfab + wfcd;
                    ret[pos_b] = wfamb - wfcimdi;
                    ret[pos_c] = wfab - wfcd;
                    ret[pos_d] = wfamb + wfcimdi;
                    j += po2;
                }
            }
        }
        // Normalise the inverse transform by the FFT length.
        if is_back {
            for i in 0..ret.len() {
                ret[i] = ret[i].unscale(cast(ret.len()).unwrap());
            }
        }
        return ret;
    }

    /// Computes the DFT of `source` (inverse DFT when `is_back` is true) for any length.
    ///
    /// The input is multiplied by a chirp sequence, convolved with the conjugate chirp by
    /// means of three power-of-two FFTs, and multiplied by the chirp once more. The inverse
    /// transform uses the conjugated chirp and divides the result by `source.len()`.
    fn convert_inner(&mut self, source: &[Complex<T>], is_back: bool) -> Vec<Complex<T>> {
        let srclen = source.len();

        // With at most one element the input is returned unchanged.
        if srclen <= 1 {
            return source.to_vec();
        }

        if self.src_len != srclen {
            self.initialize(srclen);
        }

        let len: usize = 1 << self.level;
        let chirp_period = srclen << 1;

        // a: chirp-modulated input, b: conjugate chirp (the convolution kernel).
        let mut a: Vec<Complex<T>> = Vec::with_capacity(len);
        let mut b: Vec<Complex<T>> = Vec::with_capacity(len);
        for (i, &sample) in source.iter().enumerate() {
            let entry = self.src_omega[(i * i) % chirp_period];
            let w = if is_back { entry.conj() } else { entry };
            a.push(sample * w);
            b.push(w.conj());
        }

        // Zero-pad `a` to the full FFT length and `b` up to just past the midpoint...
        let hlen = (len >> 1) + 1;
        a.resize(len, zero());
        b.resize(hlen, zero());
        // ...then mirror `b` into its upper part so the kernel wraps around cyclically.
        for i in hlen..len {
            let mirrored = b[len - i];
            b.push(mirrored);
        }

        // Cyclic convolution through the frequency domain.
        let spectrum_a = self.convert_rad2(&a, false);
        let spectrum_b = self.convert_rad2(&b, false);
        let product = spectrum_a
            .iter()
            .zip(&spectrum_b)
            .map(|(&x, &y)| x * y)
            .collect::<Vec<_>>();
        let g = self.convert_rad2(&product, true);

        // Multiply phase factor
        (0..srclen)
            .map(|i| {
                let value = g[i] * b[i].conj();
                if is_back {
                    value.unscale(cast(srclen).unwrap())
                } else {
                    value
                }
            })
            .collect::<Vec<_>>()
    }
}

// Human-readable label for this algorithm.
impl<T> AlgorithmName for ChirpZ<T> {
    /// Returns the fixed display name of the algorithm.
    fn name(&self) -> &'static str {
        "Chirp-Z 変換"
    }
}

// Public transform entry points; both delegate to `convert_inner`.
impl<T: Float + FloatConst> DftAlgorithm<T> for ChirpZ<T> {
    /// Forward transform: calls `convert_inner` with `is_back = false`.
    fn convert(&mut self, source: &[Complex<T>]) -> Vec<Complex<T>> {
        self.convert_inner(source, false)
    }

    /// Inverse transform: calls `convert_inner` with `is_back = true`.
    fn convert_back(&mut self, source: &[Complex<T>]) -> Vec<Complex<T>> {
        self.convert_inner(source, true)
    }
}

評価

本コードでは、Chirp-Z変換を行なった後、基数4、時間間引きのCooley-Tukey型FFTを実施している。
要素数が$2^n$の時も変換している等、無駄の多いコードとなっているため修正の必要があるが、
要素数が大きく素数である場合などでは、Chirp-Z変換が最速となる。

終わりに

次回は高速化について言及したい。