More than 5 years have passed since last update.

FFT in Rust (2): FFTs for Any Number of Elements

Overview

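The common (radix-2) FFT can only be applied when the number of elements is a power of two, $2^n$.
This article shows methods that compute the DFT quickly for any number of elements, not only $2^n$.
The structs below implement the `AlgorithmName` and `DftAlgorithm` traits, which are not defined in this article. The code also uses the older `num` API, in which `Complex::from_polar` takes its arguments by reference (num-complex 0.2). In num-complex 0.3 and later it takes them by value, so those calls need adjusting for newer versions.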
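A simple reference is useful for checking the implementations below. The sketch here computes the DFT directly from its definition, $X_k = \sum_{j=0}^{N-1} x_j e^{-2\pi i jk/N}$, in $O(N^2)$. The inverse uses the same $1/N$ normalization that `convert_back` applies in the code below.

This is not part of the original code. To stay dependency-free it uses a hypothetical `C = (f64, f64)` tuple in place of `num::Complex`, and the name `dft` is illustrative.

```rust
use std::f64::consts::PI;

/// Complex number as (re, im); a stand-in for num::Complex<f64>.
type C = (f64, f64);

fn add(a: C, b: C) -> C {
    (a.0 + b.0, a.1 + b.1)
}

fn mul(a: C, b: C) -> C {
    (a.0 * b.0 - a.1 * b.1, a.0 * b.1 + a.1 * b.0)
}

/// DFT evaluated directly from its definition, O(N^2).
/// Forward: X[k] = sum_j x[j] * exp(-2*pi*i*j*k/N)
/// Inverse: x[j] = (1/N) * sum_k X[k] * exp(+2*pi*i*j*k/N)
fn dft(x: &[C], inverse: bool) -> Vec<C> {
    let n = x.len();
    let sign = if inverse { 1.0 } else { -1.0 };
    (0..n)
        .map(|k| {
            let mut acc: C = (0.0, 0.0);
            for (j, &xj) in x.iter().enumerate() {
                // Reduce j*k mod N first so the angle stays small
                let t = sign * 2.0 * PI * ((j * k) % n) as f64 / n as f64;
                acc = add(acc, mul(xj, (t.cos(), t.sin())));
            }
            if inverse {
                (acc.0 / n as f64, acc.1 / n as f64)
            } else {
                acc
            }
        })
        .collect()
}
```

You can compare `convert` and `convert_back` against `dft` on a few lengths (prime, composite, power of two). This is a cheap way to catch index-mapping mistakes.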

Prime Factor FFT

use num::{Complex, Float, cast, one, zero};
use num_traits::float::FloatConst;
use std::marker::PhantomData;
use std::collections::HashMap;

pub struct PrimeFactorBasic<T> {
    phantom: PhantomData<T>,
}

/// Infinite iterator over the primes (incremental sieve of Eratosthenes)
fn primes() -> Box<dyn Iterator<Item = usize>> {
    let mut d = HashMap::new();
    Box::new((2..).take(1).chain(
        (1..).map(|i| i * 2 + 1).filter_map(move |i| {
            let (result, factor) = match d.remove(&i) {
                Some(f) => (None, f),
                None => (Some(i), i * 2),
            };
            // Schedule the next multiple of i (step `factor`) that is not already taken
            let next = (1..).map(|j| i + j * factor).find(|m| !d.contains_key(m));
            if let Some(m) = next {
                d.insert(m, factor);
            }
            result
        }),
    ))
}

impl<T: Float + FloatConst> PrimeFactorBasic<T> {
    pub fn new() -> Self {
        Self { phantom: PhantomData }
    }

    fn convert_inner(&mut self, source: &[Complex<T>], is_back: bool) -> Vec<Complex<T>> {
        fn converter<T: Float + FloatConst>(source: &mut [Complex<T>], is_back: bool) {
            if source.len() <= 1 {
                return;
            }

            let mut output = (0..source.len()).map(|_| zero()).collect::<Vec<_>>();

            // Factorization: n1 = p^k for the smallest prime factor p (length of this DFT stage), n2 = remaining cofactor
            let mut n1 = source.len();
            for p in primes() {
                if source.len() % p == 0 {
                    n1 = p;
                    break;
                }
                if p * p > source.len() {
                    break;
                }
            }
            let mut n2 = source.len();
            while n2 % n1 == 0 {
                n2 /= n1;
            }
            n1 = source.len() / n2;

            // DFT
            for j2 in 0..n2 {
                for k1 in 0..n1 {
                    output[n2 * k1 + j2] = if is_back {
                        (1..n1)
                            .fold(source[n1 * j2], |x, j1| {
                                x +
                                    Complex::<T>::from_polar(
                                        &one(),
                                        &(cast::<_, T>(2 * n2 * j1 * k1).unwrap() * T::PI() /
                                              cast(source.len()).unwrap()),
                                    ) *
                                        source[(n1 * j2 + n2 * j1) % source.len()]
                            })
                            .unscale(cast(n1).unwrap())
                    } else {
                        (1..n1).fold(source[n1 * j2], |x, j1| {
                            x +
                                Complex::<T>::from_polar(
                                    &one(),
                                    &(-cast::<_, T>(2 * n2 * j1 * k1).unwrap() * T::PI() /
                                          cast(source.len()).unwrap()),
                                ) *
                                    source[(n1 * j2 + n2 * j1) % source.len()]
                        })
                    }
                }
            }
            for j in 0..n1 {
                converter(&mut output[(j * n2)..((j + 1) * n2)], is_back);
            }

            // Reorder into the natural output order
            for k1 in 0..n1 {
                let j1 = n2 * k1 % n1;
                for k2 in 0..n2 {
                    let k = (n2 * k1 + n1 * k2) % source.len();
                    source[k] = output[n2 * j1 + (n1 * k2 % n2)];
                }
            }
        }
        let mut ret = source.to_vec();
        converter(&mut ret, is_back);
        return ret;
    }
}

impl<T> AlgorithmName for PrimeFactorBasic<T> {
    fn name(&self) -> &'static str {
        "Prime Factor Algorithm"
    }
}

impl<T: Float + FloatConst> DftAlgorithm<T> for PrimeFactorBasic<T> {
    fn convert(&mut self, source: &[Complex<T>]) -> Vec<Complex<T>> {
        self.convert_inner(source, false)
    }

    fn convert_back(&mut self, source: &[Complex<T>]) -> Vec<Complex<T>> {
        self.convert_inner(source, true)
    }
}

Evaluation

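This is a port of Ooura's code.
As he notes on his site, this code is slow: each length-$n_1$ DFT is evaluated directly from its definition.
To make it fast, a dedicated DFT kernel would have to be written for each radix.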
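The prime factor (Good-Thomas) algorithm relies on splitting $N = n_1 n_2$ with $\gcd(n_1, n_2) = 1$. Two index maps make this work:

- The input index is $(n_2 i_1 + n_1 i_2) \bmod N$, the same form the code above uses.
- The output index $k$ is chosen by the Chinese remainder theorem, so that $k \equiv k_1 \pmod{n_1}$ and $k \equiv k_2 \pmod{n_2}$.

With these maps, the 1-D DFT becomes a 2-D DFT with no twiddle factors between the two stages. The sketch below illustrates only that mapping. It is not Ooura's code or `PrimeFactorBasic`, and `small_dft`, `mod_inv`, and `good_thomas` are hypothetical names. The short DFTs are evaluated directly, which is the slow part discussed above.

```rust
use std::f64::consts::PI;

type C = (f64, f64); // (re, im), stand-in for num::Complex<f64>

fn add(a: C, b: C) -> C { (a.0 + b.0, a.1 + b.1) }
fn mul(a: C, b: C) -> C { (a.0 * b.0 - a.1 * b.1, a.0 * b.1 + a.1 * b.0) }

/// Short forward DFT from the definition (the per-radix stage).
fn small_dft(x: &[C]) -> Vec<C> {
    let n = x.len();
    (0..n)
        .map(|k| {
            x.iter().enumerate().fold((0.0, 0.0), |acc, (j, &xj)| {
                let t = -2.0 * PI * ((j * k) % n) as f64 / n as f64;
                add(acc, mul(xj, (t.cos(), t.sin())))
            })
        })
        .collect()
}

/// Modular inverse by brute force (moduli are small here); returns 0 for m == 1.
fn mod_inv(a: usize, m: usize) -> usize {
    (0..m).find(|&v| a * v % m == 1 % m).expect("a and m must be coprime")
}

/// Good-Thomas DFT for N = n1 * n2 with gcd(n1, n2) = 1.
fn good_thomas(x: &[C], n1: usize, n2: usize) -> Vec<C> {
    let n = n1 * n2;
    assert_eq!(x.len(), n);
    // a[i1][i2] = x[(n2*i1 + n1*i2) mod N], row-major with rows of length n2
    let mut a: Vec<C> = (0..n1)
        .flat_map(|i1| (0..n2).map(move |i2| (i1, i2)))
        .map(|(i1, i2)| x[(n2 * i1 + n1 * i2) % n])
        .collect();
    // Length-n1 DFTs down each column
    for i2 in 0..n2 {
        let col: Vec<C> = (0..n1).map(|i1| a[i1 * n2 + i2]).collect();
        for (k1, v) in small_dft(&col).into_iter().enumerate() {
            a[k1 * n2 + i2] = v;
        }
    }
    // Length-n2 DFTs along each row; no twiddle multiplication in between
    for k1 in 0..n1 {
        let row = small_dft(&a[k1 * n2..(k1 + 1) * n2]);
        a[k1 * n2..(k1 + 1) * n2].copy_from_slice(&row);
    }
    // CRT output map: e1 = 1 mod n1, 0 mod n2; e2 = 0 mod n1, 1 mod n2
    let e1 = n2 * mod_inv(n2 % n1, n1);
    let e2 = n1 * mod_inv(n1 % n2, n2);
    let mut out = vec![(0.0, 0.0); n];
    for k1 in 0..n1 {
        for k2 in 0..n2 {
            out[(e1 * k1 + e2 * k2) % n] = a[k1 * n2 + k2];
        }
    }
    out
}
```

In the article's code, $n_1 = p^k$ and $n_2$ is always coprime to it. That is why the factorization step strips every factor of $p$ from $n_2$.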

Mixed Radix FFT

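The Cooley-Tukey and Stockham FFTs shown for the $2^n$-element case can in fact handle any number of elements. Instead of splitting only by 2 (or 4), the transform is split along each prime factor of $N$.
The following is a mixed-radix FFT based on decimation-in-time Cooley-Tukey.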
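The key identity is simple. For $N = pm$, split the input into $p$ interleaved subsequences $x_{pj+r}$. Then $X_k = \sum_{r=0}^{p-1} e^{-2\pi i rk/N} S_r[k \bmod m]$, where $S_r$ is the length-$m$ DFT of the $r$-th subsequence.

The sketch below is a minimal recursive version of that identity, under a few assumptions:

- It uses a tuple complex type.
- It always splits off the smallest prime factor. The article's code instead works iteratively, in place, and prefers radix 4.
- `smallest_factor` and `fft_mixed` are hypothetical names.

```rust
use std::f64::consts::PI;

type C = (f64, f64); // (re, im), stand-in for num::Complex<f64>

fn add(a: C, b: C) -> C { (a.0 + b.0, a.1 + b.1) }
fn mul(a: C, b: C) -> C { (a.0 * b.0 - a.1 * b.1, a.0 * b.1 + a.1 * b.0) }
fn cis(t: f64) -> C { (t.cos(), t.sin()) }

/// Smallest prime factor of n (n >= 2).
fn smallest_factor(n: usize) -> usize {
    (2..).take_while(|p| p * p <= n).find(|p| n % p == 0).unwrap_or(n)
}

/// Recursive mixed-radix decimation-in-time FFT (forward, unnormalized).
fn fft_mixed(x: &[C]) -> Vec<C> {
    let n = x.len();
    if n <= 1 {
        return x.to_vec();
    }
    let p = smallest_factor(n);
    let m = n / p;
    // parts[r] = length-m DFT of the subsequence x[r], x[r+p], x[r+2p], ...
    let parts: Vec<Vec<C>> = (0..p)
        .map(|r| fft_mixed(&x.iter().skip(r).step_by(p).copied().collect::<Vec<_>>()))
        .collect();
    // Combine: X[k] = sum_r exp(-2*pi*i*r*k/N) * parts[r][k mod m]
    (0..n)
        .map(|k| {
            (0..p).fold((0.0, 0.0), |acc, r| {
                let t = -2.0 * PI * ((r * k) % n) as f64 / n as f64;
                add(acc, mul(parts[r][k % m], cis(t)))
            })
        })
        .collect()
}
```

The combining step costs $O(Np)$ per factor $p$. A length whose prime factors are all small therefore stays fast, while a large prime factor pushes the cost toward $O(N^2)$.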

use num::{Complex, Float, cast, one};
use num_traits::float::FloatConst;
use std::cmp;
use num_traits::NumAssign;

pub struct MixedRadix<T> {
    len: usize,
    ids: Vec<usize>,
    omega: Vec<Complex<T>>,
    omega_back: Vec<Complex<T>>,
    factors: Vec<usize>,
}

impl<T: Float + FloatConst + NumAssign> MixedRadix<T> {
    pub fn new() -> Self {
        Self {
            len: 0,
            ids: Vec::new(),
            omega: Vec::new(),
            omega_back: Vec::new(),
            factors: Vec::new(),
        }
    }

    fn initialize(&mut self, len: usize) {
        self.len = len;

        // Precompute ω (and its conjugate for the inverse transform)
        self.omega = Self::calc_omega(len);
        self.omega_back = Vec::with_capacity(len + 1);
        for i in (0..len + 1).rev() {
            self.omega_back.push(self.omega[i]);
        }
        // Prime factorization
        self.factors = Self::prime_factorization(len);

        // Digit-reversal permutation (mixed-radix generalization of bit reversal)
        self.ids = Vec::<usize>::with_capacity(len);
        let mut llen = 1_usize;
        self.ids.push(0);
        for &k in &self.factors {
            for i in 0..llen {
                self.ids[i] *= k;
            }
            for i in 1..k {
                for j in 0..llen {
                    let id = self.ids[j] + i;
                    self.ids.push(id);
                }
            }
            llen *= k;
        }
    }

    // Precompute ω^k = exp(-2πik/len) for k = 0..=len
    fn calc_omega(len: usize) -> Vec<Complex<T>> {
        let mut omega = Vec::with_capacity(len + 1);
        omega.push(one());
        match len & 3 {
            0 => {
                // Divisible by 4 (lowest two bits are 0)
                let q = len >> 2;
                let h = len >> 1;
                // Compute the first quarter of the circle
                for i in 1..q {
                    omega.push(Complex::from_polar(
                        &one(),
                        &(cast::<_, T>(-2.0).unwrap() * T::PI() / cast(len).unwrap() *
                              cast(i).unwrap()),
                    ));
                }

                // Quarter to half of the circle: multiply by -i
                for i in q..h {
                    let tmp = omega[i - q];
                    omega.push(Complex::new(tmp.im, -tmp.re));
                }

                // Second half of the circle: negate
                for i in h..len {
                    let tmp = omega[i - h];
                    omega.push(Complex::new(-tmp.re, -tmp.im));
                }

            }
            2 => {
                // Divisible by 2 (lowest bit is 0)
                let h = cmp::max(len >> 1, 1);
                // Compute the first half of the circle
                for i in 1..h {
                    omega.push(Complex::from_polar(
                        &one(),
                        &(cast::<_, T>(-2.0).unwrap() * T::PI() / cast(len).unwrap() *
                              cast(i).unwrap()),
                    ));
                }

                // Second half of the circle: negate
                for i in h..len {
                    let tmp = omega[i - h];
                    omega.push(Complex::new(-tmp.re, -tmp.im));
                }
            }
            _ => {
                for i in 1..len {
                    omega.push(Complex::from_polar(
                        &one(),
                        &(cast::<_, T>(-2.0).unwrap() * T::PI() / cast(len).unwrap() *
                              cast(i).unwrap()),
                    ));
                }
            }
        }
        // Exactly one full turn
        omega.push(one());
        return omega;
    }

    // Prime factorization (factors of 4 first, then 2, then odd primes)
    fn prime_factorization(mut value: usize) -> Vec<usize> {
        let mut factors = Vec::<usize>::with_capacity(32); // 2^32=4G
        if value == 0 {
            return factors;
        }
        while value & 3 == 0 {
            factors.push(4);
            value >>= 2;
        }
        while value & 1 == 0 {
            factors.push(2);
            value >>= 1;
        }
        let mut prime = 3;
        while value >= prime * prime {
            if value % prime == 0 {
                factors.push(prime);
                value /= prime;
            } else {
                prime += 2;
            }
        }
        if value > 1 {
            factors.push(value);
        }
        return factors;
    }

    fn convert_inner(&mut self, source: &[Complex<T>], is_back: bool) -> Vec<Complex<T>> {
        let len = source.len();

        // With one element or fewer, return the input unchanged
        if len <= 1 {
            return source.to_vec();
        }

        if len != self.len {
            self.initialize(len);
        }

        let omega = if is_back {
            &self.omega_back
        } else {
            &self.omega
        };

        let mut ret = self.ids.iter().map(|&i| if is_back {
            source[i].unscale(cast(source.len()).unwrap()) // For the inverse transform, divide by N here
        } else {
            source[i]
        }).collect::<Vec<_>>();

        // FFT
        let mut rot = Vec::new();
        let mut po2 = 1;
        let mut rad = len;
        let mut rot_len = 0;

        let im_one = if is_back { -Complex::i() } else { Complex::i() };

        for &factor in &self.factors {
            let po2m = po2;
            po2 *= factor;
            rad /= factor;

            for mut j in 0..po2m {
                let wpos = rad * j;

                while j < len {
                    match factor {
                        2 => {
                            let pos1 = j + po2m;
                            let z1 = ret[pos1] * omega[wpos];
                            ret[pos1] = ret[j] - z1;
                            ret[j] += z1;
                        }
                        3 => {
                            let pos1 = j + po2m;
                            let pos2 = pos1 + po2m;
                            let z1 = ret[pos1] * omega[wpos];
                            let z2 = ret[pos2] * omega[wpos << 1];
                            let t1 = z1 + z2;
                            let t2 = ret[j] - t1.scale(cast(0.5).unwrap());
                            let t3 = (z1 - z2) *
                                im_one.scale(
                                    (cast::<_, T>(-2.0).unwrap() * T::PI() / cast(3.0).unwrap())
                                        .sin(),
                                );
                            ret[j] += t1;
                            ret[pos1] = t2 + t3;
                            ret[pos2] = t2 - t3;
                        }
                        4 => {
                            let w1 = omega[wpos];
                            let w2 = omega[wpos << 1];
                            let w3 = omega[wpos * 3];

                            let pos1 = j + po2m;
                            let pos2 = pos1 + po2m;
                            let pos3 = pos2 + po2m;
                            let wfa = ret[j];
                            let wfb = ret[pos2] * w2;
                            let wfab = wfa + wfb;
                            let wfamb = wfa - wfb;
                            let wfc = ret[pos1] * w1;
                            let wfd = ret[pos3] * w3;
                            let wfcd = wfc + wfd;
                            let wfcimdi = (wfc - wfd) * im_one;

                            ret[j] = wfab + wfcd;
                            ret[pos1] = wfamb - wfcimdi;
                            ret[pos2] = wfab - wfcd;
                            ret[pos3] = wfamb + wfcimdi;
                        }
                        5 => {
                            let w1 = omega[wpos];
                            let w2 = omega[wpos << 1];
                            let w3 = omega[wpos * 3];
                            let w4 = omega[wpos << 2];

                            let pos2 = j + po2m;
                            let pos3 = pos2 + po2m;
                            let pos4 = pos3 + po2m;
                            let pos5 = pos4 + po2m;

                            let z0 = ret[j];
                            let z1 = ret[pos2] * w1;
                            let z2 = ret[pos3] * w2;
                            let z3 = ret[pos4] * w3;
                            let z4 = ret[pos5] * w4;

                            let t1 = z1 + z4;
                            let t2 = z2 + z3;
                            let t3 = z1 - z4;
                            let t4 = z2 - z3;
                            let t5 = t1 + t2;
                            let t6 = (t1 - t2).scale(
                                cast::<_, T>(0.25).unwrap() *
                                    cast::<_, T>(5.0).unwrap().sqrt(),
                            );
                            let t7 = z0 - t5.scale(cast(0.25).unwrap());
                            let t8 = t6 + t7;
                            let t9 = t7 - t6;
                            let t10 = (t3.scale((cast::<_, T>(-0.4).unwrap() * T::PI()).sin()) +
                                           t4.scale(
                                    (cast::<_, T>(-0.2).unwrap() * T::PI()).sin(),
                                )) * im_one;
                            let t11 = (t3.scale((cast::<_, T>(-0.2).unwrap() * T::PI()).sin()) -
                                           t4.scale(
                                    (cast::<_, T>(-0.4).unwrap() * T::PI()).sin(),
                                )) * im_one;

                            ret[j] = z0 + t5;
                            ret[pos2] = t8 + t10;
                            ret[pos3] = t9 + t11;
                            ret[pos4] = t9 - t11;
                            ret[pos5] = t8 - t10;
                        }
                        _ => {
                            if rot_len != factor {
                                rot_len = factor;
                                let rot_width = len / factor;
                                rot = (0..factor)
                                    .map(|i| omega[rot_width * i])
                                    .collect::<Vec<_>>();
                            }

                            // Generic radix: DFT evaluated directly from the definition
                            let pos = (0..factor).map(|i| j + po2m * i).collect::<Vec<_>>();
                            let z = (0..factor)
                                .map(|i| ret[pos[i]] * omega[wpos * i])
                                .collect::<Vec<_>>();
                            for i in 0..factor {
                                ret[pos[i]] =
                                    (1..factor).fold(z[0], |x, l| x + z[l] * rot[(i * l) % factor]);
                            }
                        }
                    }
                    j += po2;
                }
            }
        }
        return ret;
    }
}

impl<T> AlgorithmName for MixedRadix<T> {
    fn name(&self) -> &'static str {
        "Cooley-Tukey (Mixed Radix)"
    }
}

impl<T: Float + FloatConst + NumAssign> DftAlgorithm<T> for MixedRadix<T> {
    fn convert(&mut self, source: &[Complex<T>]) -> Vec<Complex<T>> {
        self.convert_inner(source, false)
    }

    fn convert_back(&mut self, source: &[Complex<T>]) -> Vec<Complex<T>> {
        self.convert_inner(source, true)
    }
}
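The `ids` table built in `initialize` is labeled as bit reversal, but for mixed factors it is a digit reversal. Position $d_1 + f_1 d_2 + f_1 f_2 d_3 + \dots$ receives source index $d_1 f_2 f_3 \cdots + d_2 f_3 \cdots + \dots + d_m$.

The sketch below extracts the same loop as a standalone function so the table can be inspected; the name `digit_reversal` is hypothetical. With all factors equal to 2 it reduces to ordinary bit reversal.

```rust
/// Mixed-radix digit reversal, built the same way as `ids` in MixedRadix::initialize.
fn digit_reversal(factors: &[usize]) -> Vec<usize> {
    let mut ids = vec![0usize];
    for &k in factors {
        let llen = ids.len();
        // Existing entries become multiples of k
        for id in ids.iter_mut() {
            *id *= k;
        }
        // Append copies of those entries shifted by 1, 2, ..., k-1
        for i in 1..k {
            for j in 0..llen {
                let id = ids[j] + i;
                ids.push(id);
            }
        }
    }
    ids
}
```

For example, `digit_reversal(&[2, 3])` returns `[0, 3, 1, 4, 2, 5]`, and `digit_reversal(&[2, 2, 2])` returns the 3-bit reversal `[0, 4, 2, 6, 1, 5, 3, 7]`.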

Evaluation

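As with the Prime Factor FFT, dedicated DFT kernels are needed for each radix to make this fast.
The code above has dedicated kernels for radices 2, 3, 4, and 5; other prime factors fall back to the direct DFT definition. There is still room for tuning.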
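To see what a dedicated kernel looks like, the radix-3 case can be checked in isolation. With $w = e^{-2\pi i/3} = -\tfrac12 - \tfrac{\sqrt3}{2}i$, the 3-point DFT is:

- $X_0 = z_0 + (z_1 + z_2)$
- $X_{1,2} = z_0 - \tfrac12(z_1 + z_2) \mp \tfrac{\sqrt3}{2}\,i\,(z_1 - z_2)$

The two non-trivial outputs therefore share the terms $t_2$ and $t_3$, instead of needing general complex multiplications as in the generic arm.

The sketch below mirrors the `3 =>` arm of `convert_inner`. It uses a tuple complex type, and `butterfly3` and `mul_i` are hypothetical names.

```rust
use std::f64::consts::PI;

type C = (f64, f64); // (re, im), stand-in for num::Complex<f64>

fn add(a: C, b: C) -> C { (a.0 + b.0, a.1 + b.1) }
fn sub(a: C, b: C) -> C { (a.0 - b.0, a.1 - b.1) }
fn scale(a: C, s: f64) -> C { (a.0 * s, a.1 * s) }

/// Multiply by i (forward) or -i (inverse), like `im_one` in the article.
fn mul_i(a: C, inverse: bool) -> C {
    if inverse { (a.1, -a.0) } else { (-a.1, a.0) }
}

/// Radix-3 butterfly on already-twiddled inputs (unnormalized).
fn butterfly3(z0: C, z1: C, z2: C, inverse: bool) -> [C; 3] {
    let t1 = add(z1, z2);
    let t2 = sub(z0, scale(t1, 0.5)); // z0 + cos(2π/3) * (z1 + z2)
    let s = (-2.0 * PI / 3.0).sin(); // -sqrt(3)/2
    let t3 = scale(mul_i(sub(z1, z2), inverse), s);
    [add(z0, t1), add(t2, t3), sub(t2, t3)]
}
```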

Chirp-Z Transform

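The Chirp-Z transform, in the form known as Bluestein's algorithm, rewrites a DFT of arbitrary length $N$ as a convolution. That convolution can be zero-padded to a power-of-two length.
Power-of-two FFTs are very fast. The transform therefore multiplies the input by a chirp, computes the convolution with power-of-two FFTs, and multiplies by the chirp again, giving a fast transform for any $N$.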
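The identity behind this is $jk = \frac{j^2 + k^2 - (k-j)^2}{2}$. Define the chirp $w_m = e^{-\pi i m^2/N}$. Then:

$$X_k = w_k \sum_{j=0}^{N-1} (x_j w_j)\,\overline{w_{k-j}}$$

This is a convolution of $x_j w_j$ with $\overline{w_m}$. `ChirpZ` pads it to a power of two and computes it with FFTs.

The sketch below evaluates that convolution directly in $O(N^2)$, only to show that the identity holds. `chirp` and `dft_by_chirp` are hypothetical names, and the tuple type stands in for `num::Complex`. Reducing $m^2$ modulo $2N$ matches `src_omega[(i * i) % (srclen << 1)]` in the code below.

```rust
use std::f64::consts::PI;

type C = (f64, f64); // (re, im), stand-in for num::Complex<f64>

fn add(a: C, b: C) -> C { (a.0 + b.0, a.1 + b.1) }
fn mul(a: C, b: C) -> C { (a.0 * b.0 - a.1 * b.1, a.0 * b.1 + a.1 * b.0) }
fn conj(a: C) -> C { (a.0, -a.1) }

/// Chirp w_m = exp(-pi*i*m^2/N); m^2 is periodic mod 2N, so reduce it first.
fn chirp(m: i64, n: usize) -> C {
    let r = (m * m).rem_euclid(2 * n as i64);
    let t = -PI * r as f64 / n as f64;
    (t.cos(), t.sin())
}

/// DFT via X[k] = w_k * sum_j (x[j] * w_j) * conj(w_{k-j}).
/// The convolution is evaluated directly here, not with FFTs.
fn dft_by_chirp(x: &[C]) -> Vec<C> {
    let n = x.len();
    let a: Vec<C> = (0..n).map(|j| mul(x[j], chirp(j as i64, n))).collect();
    (0..n)
        .map(|k| {
            let conv = (0..n).fold((0.0, 0.0), |acc, j| {
                add(acc, mul(a[j], conj(chirp(k as i64 - j as i64, n))))
            });
            mul(chirp(k as i64, n), conv)
        })
        .collect()
}
```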

use num::{Complex, Float, cast, one, zero};
use num_traits::float::FloatConst;
use std::cmp;

#[derive(Debug)]
pub struct ChirpZ<T> {
    level: usize,
    ids: Vec<usize>,
    omega: Vec<Complex<T>>,
    src_omega: Vec<Complex<T>>,
    src_len: usize,
}

impl<T: Float + FloatConst> ChirpZ<T> {
    pub fn new() -> Self {
        Self {
            level: 0,
            ids: Vec::new(),
            omega: Vec::new(),
            src_omega: Vec::new(),
            src_len: 0,
        }
    }

    // Precompute ω^k = exp(-2πik/len) for k = 0..=len
    fn calc_omega(len: usize) -> Vec<Complex<T>> {
        let mut omega = Vec::with_capacity(len + 1);
        omega.push(one());
        if len & 3 == 0 {
            // Divisible by 4 (lowest two bits are 0)
            let q = len >> 2;
            let h = len >> 1;
            // Compute the first quarter of the circle
            for i in 1..q {
                omega.push(Complex::from_polar(
                    &one(),
                    &(cast::<_, T>(-2.0).unwrap() * T::PI() / cast(len).unwrap() *
                          cast(i).unwrap()),
                ));
            }

            // Quarter to half of the circle: multiply by -i
            for i in q..h {
                let tmp = omega[i - q];
                omega.push(Complex::new(tmp.im, -tmp.re));
            }

            // Second half of the circle: negate
            for i in h..len {
                let tmp = omega[i - h];
                omega.push(Complex::new(-tmp.re, -tmp.im));
            }

        } else if len & 1 == 0 {
            // Divisible by 2 (lowest bit is 0)
            let h = cmp::max(len >> 1, 1);
            // Compute the first half of the circle
            for i in 1..h {
                omega.push(Complex::from_polar(
                    &one(),
                    &(cast::<_, T>(-2.0).unwrap() * T::PI() / cast(len).unwrap() *
                          cast(i).unwrap()),
                ));
            }

            // Second half of the circle: negate
            for i in h..len {
                let tmp = omega[i - h];
                omega.push(Complex::new(-tmp.re, -tmp.im));
            }
        } else {
            for i in 1..len {
                omega.push(Complex::from_polar(
                    &one(),
                    &(cast::<_, T>(-2.0).unwrap() * T::PI() / cast(len).unwrap() *
                          cast(i).unwrap()),
                ));
            }
        }
        // Exactly one full turn
        omega.push(one());
        return omega;
    }

    fn initialize(&mut self, len: usize) {
        let pow2len = len.next_power_of_two() << 1;
        let lv = pow2len.trailing_zeros() as usize;

        if lv != self.level {
            self.level = lv;
            self.omega = Self::calc_omega(pow2len);
        }

        self.src_len = len;
        self.src_omega = Self::calc_omega(len << 1);

        // Bit-reversal permutation
        self.ids = Vec::with_capacity(pow2len);
        self.ids.push(0);
        for _ in 0..lv {
            for j in 0..self.ids.len() {
                let id = self.ids[j] << 1;
                self.ids[j] = id;
                self.ids.push(id + 1);
            }
        }
    }

    fn convert_rad2(&mut self, source: &[Complex<T>], is_back: bool) -> Vec<Complex<T>> {

        // Reorder the input (bit reversal)
        let mut ret = self.ids.iter().map(|&i| source[i]).collect::<Vec<_>>();

        // FFT
        let mut po2 = 1;
        if (self.level & 1) == 1 {
            po2 = 2;
            for j in 0..(ret.len() >> 1) {
                let pos_a = j << 1;
                let pos_b = pos_a + 1;
                let wfa = ret[pos_a];
                let wfb = ret[pos_b];
                ret[pos_a] = wfa + wfb;
                ret[pos_b] = wfa - wfb;
            }
        }

        let im_one = if is_back { -Complex::i() } else { Complex::i() };

        for i in 1..((self.level >> 1) + 1) {
            let po2m = po2;
            po2 <<= 2;

            for k in 0..po2m {
                let pos_w = (1 << ((self.level & !1) - (i << 1))) * k;

                let (w, w2, w3) = if is_back {
                    (
                        self.omega[pos_w].conj(),
                        self.omega[pos_w << 1].conj(),
                        self.omega[pos_w * 3].conj(),
                    )
                } else {
                    (
                        self.omega[pos_w],
                        self.omega[pos_w << 1],
                        self.omega[pos_w * 3],
                    )
                };
                let mut j = k;
                while j < ret.len() {
                    let pos_b = j + po2m;
                    let pos_c = j + (po2m << 1);
                    let pos_d = j + (po2m * 3);
                    let wfa = ret[j];
                    let wfb = ret[pos_b] * w2;
                    let wfab = wfa + wfb;
                    let wfamb = wfa - wfb;
                    let wfc = ret[pos_c] * w;
                    let wfd = ret[pos_d] * w3;
                    let wfcd = wfc + wfd;
                    let wfcimdi = (wfc - wfd) * im_one;

                    ret[j] = wfab + wfcd;
                    ret[pos_b] = wfamb - wfcimdi;
                    ret[pos_c] = wfab - wfcd;
                    ret[pos_d] = wfamb + wfcimdi;
                    j += po2;
                }
            }
        }
        if is_back {
            for i in 0..ret.len() {
                ret[i] = ret[i].unscale(cast(ret.len()).unwrap());
            }
        }
        return ret;
    }

    fn convert_inner(&mut self, source: &[Complex<T>], is_back: bool) -> Vec<Complex<T>> {
        let srclen = source.len();

        // With one element or fewer, return the input unchanged
        if srclen <= 1 {
            return source.to_vec();
        }

        if self.src_len != srclen {
            self.initialize(srclen);
        }

        let len = 1 << self.level;

        let mut a = Vec::with_capacity(len);
        let mut b = Vec::with_capacity(len);

        for i in 0..source.len() {
            let w = if is_back {
                self.src_omega[(i * i) % (srclen << 1)].conj()
            } else {
                self.src_omega[(i * i) % (srclen << 1)]
            };
            a.push(source[i] * w);
            b.push(w.conj());
        }

        let hlen = (len >> 1) + 1;
        for _ in srclen..hlen {
            a.push(zero());
            b.push(zero());
        }
        for i in hlen..len {
            a.push(zero());
            let t = b[len - i];
            b.push(t);
        }

        let aa = self.convert_rad2(&a, false);
        let bb = self.convert_rad2(&b, false);

        let gg = aa.iter().zip(&bb).map(|(&x, &y)| x * y).collect::<Vec<_>>();
        let g = self.convert_rad2(&gg, true);

        // Multiply phase factor
        return (0..srclen)
            .map(|i| if is_back {
                (g[i] * b[i].conj()).unscale(cast(srclen).unwrap())
            } else {
                g[i] * b[i].conj()
            })
            .collect::<Vec<_>>();
    }
}

impl<T> AlgorithmName for ChirpZ<T> {
    fn name(&self) -> &'static str {
        "Chirp-Z Transform"
    }
}

impl<T: Float + FloatConst> DftAlgorithm<T> for ChirpZ<T> {
    fn convert(&mut self, source: &[Complex<T>]) -> Vec<Complex<T>> {
        self.convert_inner(source, false)
    }

    fn convert_back(&mut self, source: &[Complex<T>]) -> Vec<Complex<T>> {
        self.convert_inner(source, true)
    }
}

Evaluation

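This code applies the Chirp-Z transform and then performs a radix-4, decimation-in-time Cooley-Tukey FFT, with one radix-2 stage when the exponent of the power of two is odd.
The code is wasteful in places and needs fixing. For example, it still goes through the Chirp-Z transform when the number of elements is already $2^n$.
Even so, when the number of elements is large and prime, the Chirp-Z transform is the fastest of the methods shown here.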
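One way to remove the waste described above is to dispatch on the length. Use the power-of-two FFT directly when $N$ is already $2^n$, and fall back to Chirp-Z only otherwise.

The sketch below does this under a few assumptions:

- It uses an iterative radix-2 FFT rather than the article's radix-4 `convert_rad2`.
- The convolution length is the smallest power of two $M \ge 2N-1$.
- The inverse FFT is obtained as $\overline{\mathrm{FFT}(\overline{y})}/M$.
- `fft_pow2` and `fft_any` are hypothetical names, and only the forward transform is shown.

```rust
use std::f64::consts::PI;

type C = (f64, f64); // (re, im), stand-in for num::Complex<f64>

fn add(a: C, b: C) -> C { (a.0 + b.0, a.1 + b.1) }
fn sub(a: C, b: C) -> C { (a.0 - b.0, a.1 - b.1) }
fn mul(a: C, b: C) -> C { (a.0 * b.0 - a.1 * b.1, a.0 * b.1 + a.1 * b.0) }
fn conj(a: C) -> C { (a.0, -a.1) }

/// In-place iterative radix-2 FFT (forward, unnormalized); a.len() must be 2^n.
fn fft_pow2(a: &mut [C]) {
    let n = a.len();
    // Bit-reversal permutation
    let mut j = 0;
    for i in 1..n {
        let mut bit = n >> 1;
        while j & bit != 0 {
            j ^= bit;
            bit >>= 1;
        }
        j |= bit;
        if i < j {
            a.swap(i, j);
        }
    }
    // Decimation-in-time butterflies
    let mut len = 2;
    while len <= n {
        let ang = -2.0 * PI / len as f64;
        let half = len / 2;
        for start in (0..n).step_by(len) {
            for k in 0..half {
                let w = ((ang * k as f64).cos(), (ang * k as f64).sin());
                let u = a[start + k];
                let v = mul(a[start + k + half], w);
                a[start + k] = add(u, v);
                a[start + k + half] = sub(u, v);
            }
        }
        len <<= 1;
    }
}

/// Forward DFT of any length: plain radix-2 FFT for 2^n, Chirp-Z otherwise.
fn fft_any(x: &[C]) -> Vec<C> {
    let n = x.len();
    if n <= 1 {
        return x.to_vec();
    }
    if n.is_power_of_two() {
        // No chirp needed for 2^n
        let mut a = x.to_vec();
        fft_pow2(&mut a);
        return a;
    }
    let m = (2 * n - 1).next_power_of_two();
    // Chirp w_j = exp(-pi*i*j^2/N), with j^2 reduced mod 2N
    let w: Vec<C> = (0..n)
        .map(|j| {
            let t = -PI * ((j * j) % (2 * n)) as f64 / n as f64;
            (t.cos(), t.sin())
        })
        .collect();
    let mut a = vec![(0.0, 0.0); m];
    let mut b = vec![(0.0, 0.0); m];
    for j in 0..n {
        a[j] = mul(x[j], w[j]);
        b[j] = conj(w[j]);
        if j > 0 {
            // Mirror into the tail so the cyclic convolution equals the linear one
            b[m - j] = conj(w[j]);
        }
    }
    fft_pow2(&mut a);
    fft_pow2(&mut b);
    // Pointwise product, then inverse FFT as conj(FFT(conj(.))) / M
    let mut c: Vec<C> = a.iter().zip(&b).map(|(&p, &q)| conj(mul(p, q))).collect();
    fft_pow2(&mut c);
    (0..n)
        .map(|k| {
            let g = conj(c[k]);
            mul((g.0 / m as f64, g.1 / m as f64), w[k])
        })
        .collect()
}
```

Because $M \ge 2N - 1$, the mirrored tail of `b` never overlaps its first $N$ entries. The cyclic convolution therefore gives the linear one for the outputs $k < N$.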

Conclusion

Next time, I would like to discuss how to speed these up.
