概要
一般的なFFTは要素数が$2^n$でのみ適用可能である。
以下では、要素数が$2^n$に限らず任意の数で高速に処理する方法を示す。
Prime Factor型 FFT
use num::{Complex, Float, cast, one, zero};
use num_traits::float::FloatConst;
use std::marker::PhantomData;
use std::collections::HashMap;
/// Prime Factor (Good–Thomas) FFT usable for sequences of any length.
///
/// The transformer keeps no state of its own; the type parameter `T` only
/// selects the floating-point type of the `Complex<T>` samples it processes.
pub struct PrimeFactorBasic<T> {
    /// Zero-sized marker that ties the otherwise unused `T` to the type.
    phantom: PhantomData<T>,
}
/// Returns an unbounded iterator over the primes in ascending order
/// (2, 3, 5, 7, 11, ...), produced by an incremental Sieve of Eratosthenes.
///
/// After 2, only odd candidates are examined. `composites` maps each pending
/// odd composite to the stride (twice the prime that marked it) used to jump to
/// that prime's next odd multiple.
fn primes() -> Box<dyn Iterator<Item = usize>> {
    let mut composites: HashMap<usize, usize> = HashMap::new();
    let odd_primes = (1..).map(|i| 2 * i + 1).filter_map(move |n: usize| {
        // A number found in the map is composite and carries its stride. A number
        // absent from the map is a new prime; its first odd multiple to strike is 3n.
        let (result, stride) = match composites.remove(&n) {
            Some(step) => (None, step),
            None => (Some(n), 2 * n),
        };
        // Reschedule the stride at the next odd multiple not already claimed by
        // another prime. A skipped multiple stays marked by that other prime.
        let mut next = n + stride;
        while composites.contains_key(&next) {
            next += stride;
        }
        composites.insert(next, stride);
        result
    });
    Box::new(std::iter::once(2).chain(odd_primes))
}
impl<T: Float + FloatConst> PrimeFactorBasic<T> {
    /// Creates a new transformer. It is stateless, so construction is free.
    pub fn new() -> Self {
        Self { phantom: PhantomData }
    }

    /// Computes the DFT of `source` (`is_back == false`) or its inverse
    /// (`is_back == true`, normalized by `1 / source.len()`) into a new vector.
    ///
    /// The length is split as `N = n1 * n2`, where `n1` is the full power of the
    /// smallest prime factor of `N`, so `gcd(n1, n2) == 1`. Following the
    /// Good–Thomas mapping:
    /// - the `n1`-point DFTs are evaluated directly,
    /// - the `n2`-point DFTs are evaluated recursively,
    /// - the result is reordered with the Chinese remainder theorem.
    ///
    /// No twiddle factors are needed between the stages.
    fn convert_inner(&mut self, source: &[Complex<T>], is_back: bool) -> Vec<Complex<T>> {
        // Returns (n1, n2) with n1 * n2 == len, where n1 is the largest power of the
        // smallest prime factor of `len` dividing `len` (hence gcd(n1, n2) == 1).
        // Requires len >= 2.
        fn split_coprime(len: usize) -> (usize, usize) {
            // Smallest prime factor; `len` itself if no factor <= sqrt(len) exists.
            let mut p_min = len;
            for p in primes() {
                if len % p == 0 {
                    p_min = p;
                    break;
                }
                if p * p > len {
                    break;
                }
            }
            let mut n2 = len;
            while n2 % p_min == 0 {
                n2 /= p_min;
            }
            (len / n2, n2)
        }

        // In-place transform of `data` (see the doc comment above).
        fn converter<T: Float + FloatConst>(data: &mut [Complex<T>], is_back: bool) {
            let len = data.len();
            if len <= 1 {
                return;
            }
            let (n1, n2) = split_coprime(len);
            let n1_t: T = cast(n1).unwrap();

            // Twiddles of the n1-point DFT: tw[m] = exp(∓2πi·m/n1).
            // Indexing them by (j1*k1) mod n1 keeps every angle in [0, 2π), which
            // preserves accuracy. It also avoids forming the product 2*n2*j1*k1,
            // which overflows usize on 32-bit targets for large prime lengths.
            let sign: T = if is_back { one() } else { -one::<T>() };
            let tw: Vec<Complex<T>> = (0..n1)
                .map(|m| {
                    Complex::<T>::from_polar(
                        &one(),
                        &(sign * cast::<_, T>(2 * m).unwrap() * T::PI() / n1_t),
                    )
                })
                .collect();

            // Stage 1: n1-point DFTs over the Good–Thomas input map
            // n = (n1*j2 + n2*j1) mod len. Block k1 of `output` (n2 contiguous
            // entries) receives the k1-th output bin for every j2.
            let mut output: Vec<Complex<T>> = vec![zero(); len];
            for j2 in 0..n2 {
                for k1 in 0..n1 {
                    let mut acc = data[n1 * j2];
                    // (j1 * k1) mod n1, advanced incrementally so it never overflows.
                    let mut m = 0;
                    for j1 in 1..n1 {
                        m = (m + k1) % n1;
                        acc = acc + tw[m] * data[(n1 * j2 + n2 * j1) % len];
                    }
                    output[n2 * k1 + j2] = if is_back { acc.unscale(n1_t) } else { acc };
                }
            }

            // Stage 2: n2-point DFT of every block, recursively.
            for block in output.chunks_mut(n2) {
                converter(block, is_back);
            }

            // Stage 3: CRT output map. X[k] is in block (k mod n1) at offset (k mod n2).
            for (k, x) in data.iter_mut().enumerate() {
                *x = output[n2 * (k % n1) + k % n2];
            }
        }

        let mut ret = source.to_vec();
        converter(&mut ret, is_back);
        ret
    }
}
impl<T> AlgorithmName for PrimeFactorBasic<T> {
    /// Display name used to identify this transformer.
    fn name(&self) -> &'static str {
        const NAME: &str = "Prime Factor Algorithm";
        NAME
    }
}
impl<T: Float + FloatConst> DftAlgorithm<T> for PrimeFactorBasic<T> {
    /// Forward DFT of `source`.
    fn convert(&mut self, source: &[Complex<T>]) -> Vec<Complex<T>> {
        let inverse = false;
        Self::convert_inner(self, source, inverse)
    }

    /// Inverse DFT of `source`, normalized by `1 / source.len()`.
    fn convert_back(&mut self, source: &[Complex<T>]) -> Vec<Complex<T>> {
        let inverse = true;
        Self::convert_inner(self, source, inverse)
    }
}
評価
大浦氏のコードを移植した。
氏のサイトでも言及されている通り、このコードは遅い。各素因数のべきに対するDFTを定義式どおりに計算しているため、特に要素数が素数や素数のべき（$2^n$を含む）の場合には計算量が$O(N^2)$となる。
本来であれば、基数毎のDFTを専用に作成し、高速化を図る必要がある。
Mixed Radix FFT
要素数$2^n$のFFTで示したCooley-Tukey型やStockham型のFFTは、実際には任意の要素数で処理が可能である。要素数$N$を素因数分解し、各素因数を基数とするバタフライ演算を段ごとに適用すればよい。
以下にCooley-Tukey型の時間引き（decimation in time）におけるMixed Radix FFTのコードを示す。
use num::{Complex, Float, cast, one};
use num_traits::float::FloatConst;
use std::cmp;
use num_traits::NumAssign;
/// Mixed-radix, decimation-in-time Cooley–Tukey FFT for arbitrary lengths.
///
/// The tables below are cached for the most recently used length and are
/// rebuilt when a transform of a different length is requested.
pub struct MixedRadix<T> {
    /// Length the cached tables were built for (0 until the first transform).
    len: usize,
    /// Radices of the butterfly stages, in the order they are applied.
    factors: Vec<usize>,
    /// Input permutation (mixed-radix digit reversal).
    ids: Vec<usize>,
    /// Forward twiddles `exp(-2πi·k/len)` for `k = 0..=len`.
    omega: Vec<Complex<T>>,
    /// Inverse twiddles `exp(+2πi·k/len)` for `k = 0..=len`.
    omega_back: Vec<Complex<T>>,
}
impl<T: Float + FloatConst + NumAssign> MixedRadix<T> {
    /// Creates a transformer with no cached tables. They are built on the first
    /// transform and rebuilt whenever the input length changes.
    pub fn new() -> Self {
        Self {
            len: 0,
            ids: Vec::new(),
            omega: Vec::new(),
            omega_back: Vec::new(),
            factors: Vec::new(),
        }
    }

    /// Builds the twiddle tables, stage radices and input permutation for `len`.
    fn initialize(&mut self, len: usize) {
        self.len = len;
        // Forward twiddles: omega[k] = exp(-2πi·k/len), k = 0..=len.
        self.omega = Self::calc_omega(len);
        // Inverse twiddles: omega_back[k] = omega[len - k] = exp(+2πi·k/len).
        self.omega_back = self.omega.iter().rev().cloned().collect();
        // Radices of the butterfly stages.
        self.factors = Self::prime_factorization(len);
        // Mixed-radix digit reversal: for each radix k, the indices built so far
        // are multiplied by k and then replicated with offsets 1..k.
        self.ids = Vec::with_capacity(len);
        self.ids.push(0);
        for &k in &self.factors {
            let llen = self.ids.len();
            for id in self.ids.iter_mut() {
                *id *= k;
            }
            for i in 1..k {
                for j in 0..llen {
                    let id = self.ids[j] + i;
                    self.ids.push(id);
                }
            }
        }
    }

    /// Returns `exp(-2πi·k/len)` for `k = 0..=len`; the final entry is exactly 1.
    /// When `len` is divisible by 4 or 2, quarter/half-period symmetries replace
    /// most of the trigonometric evaluations.
    fn calc_omega(len: usize) -> Vec<Complex<T>> {
        let mut omega = Vec::with_capacity(len + 1);
        omega.push(one());
        match len & 3 {
            0 => {
                // len divisible by 4 (low two bits clear).
                let q = len >> 2;
                let h = len >> 1;
                // First quarter period, computed directly.
                for i in 1..q {
                    omega.push(Complex::from_polar(
                        &one(),
                        &(cast::<_, T>(-2.0).unwrap() * T::PI() / cast(len).unwrap() *
                              cast(i).unwrap()),
                    ));
                }
                // Second quarter: the first quarter rotated by -90° (times -i).
                for i in q..h {
                    let tmp = omega[i - q];
                    omega.push(Complex::new(tmp.im, -tmp.re));
                }
                // Second half: the first half negated.
                for i in h..len {
                    let tmp = omega[i - h];
                    omega.push(Complex::new(-tmp.re, -tmp.im));
                }
            }
            2 => {
                // len divisible by 2 only (low bit clear).
                let h = cmp::max(len >> 1, 1);
                // First half period, computed directly.
                for i in 1..h {
                    omega.push(Complex::from_polar(
                        &one(),
                        &(cast::<_, T>(-2.0).unwrap() * T::PI() / cast(len).unwrap() *
                              cast(i).unwrap()),
                    ));
                }
                // Second half: the first half negated.
                for i in h..len {
                    let tmp = omega[i - h];
                    omega.push(Complex::new(-tmp.re, -tmp.im));
                }
            }
            _ => {
                for i in 1..len {
                    omega.push(Complex::from_polar(
                        &one(),
                        &(cast::<_, T>(-2.0).unwrap() * T::PI() / cast(len).unwrap() *
                              cast(i).unwrap()),
                    ));
                }
            }
        }
        // Exactly one full turn.
        omega.push(one());
        omega
    }

    /// Splits `value` into stage radices: as many 4s as possible, at most one 2,
    /// then the odd prime factors in ascending order. Returns an empty vector for
    /// 0 and 1.
    fn prime_factorization(mut value: usize) -> Vec<usize> {
        let mut factors = Vec::<usize>::with_capacity(32); // 2^32=4G
        if value == 0 {
            return factors;
        }
        while value & 3 == 0 {
            factors.push(4);
            value >>= 2;
        }
        // Once all 4s are removed, at most one factor 2 can remain.
        if value & 1 == 0 {
            factors.push(2);
            value >>= 1;
        }
        let mut prime = 3;
        // `prime <= value / prime` is `prime * prime <= value` without overflow.
        while prime <= value / prime {
            if value % prime == 0 {
                factors.push(prime);
                value /= prime;
            } else {
                prime += 2;
            }
        }
        if value > 1 {
            factors.push(value);
        }
        factors
    }

    /// Shared forward/inverse transform.
    ///
    /// The inverse (`is_back == true`) uses the conjugate twiddles and is
    /// normalized by `1 / source.len()`.
    fn convert_inner(&mut self, source: &[Complex<T>], is_back: bool) -> Vec<Complex<T>> {
        let len = source.len();
        // With zero or one element the transform is the identity.
        if len <= 1 {
            return source.to_vec();
        }
        if len != self.len {
            self.initialize(len);
        }
        let omega = if is_back { &self.omega_back } else { &self.omega };

        // Permute the input into digit-reversed order. The inverse transform's
        // 1/len normalization is applied here.
        let len_t: T = cast(len).unwrap();
        let mut ret = self
            .ids
            .iter()
            .map(|&i| if is_back { source[i].unscale(len_t) } else { source[i] })
            .collect::<Vec<_>>();

        // Constants shared by every butterfly, hoisted out of the loops.
        let im_one = if is_back { -Complex::i() } else { Complex::i() };
        let half: T = cast(0.5).unwrap();
        // Radix 3: ±i·sin(-2π/3).
        let r3_rot = im_one.scale(
            (cast::<_, T>(-2.0).unwrap() * T::PI() / cast(3.0).unwrap()).sin(),
        );
        // Radix 5: 1/4, √5/4, sin(-2π/5) and sin(-π/5).
        let quarter: T = cast(0.25).unwrap();
        let r5_c = cast::<_, T>(0.25).unwrap() * cast::<_, T>(5.0).unwrap().sqrt();
        let r5_s1 = (cast::<_, T>(-0.4).unwrap() * T::PI()).sin();
        let r5_s2 = (cast::<_, T>(-0.2).unwrap() * T::PI()).sin();

        // Scratch buffers for radices without a dedicated butterfly, reused
        // across butterflies instead of being reallocated each time.
        let mut rot: Vec<Complex<T>> = Vec::new();
        let mut z: Vec<Complex<T>> = Vec::new();

        // po2: length of the sub-transforms completed after the current stage.
        // rad: len / po2, the stride that maps stage twiddles into `omega`.
        let mut po2 = 1;
        let mut rad = len;
        for &factor in &self.factors {
            let po2m = po2;
            po2 *= factor;
            rad /= factor;
            if factor > 5 {
                // rot[i] = exp(∓2πi·i/factor): the roots of unity of a factor-point DFT.
                let rot_width = len / factor;
                rot.clear();
                rot.extend((0..factor).map(|i| omega[rot_width * i]));
            }
            for k in 0..po2m {
                // omega[wpos * i] = exp(∓2πi·k·i/po2): twiddle of input i in column k.
                let wpos = rad * k;
                let mut j = k;
                while j < len {
                    match factor {
                        2 => {
                            let pos1 = j + po2m;
                            let z1 = ret[pos1] * omega[wpos];
                            ret[pos1] = ret[j] - z1;
                            ret[j] += z1;
                        }
                        3 => {
                            let pos1 = j + po2m;
                            let pos2 = pos1 + po2m;
                            let z1 = ret[pos1] * omega[wpos];
                            let z2 = ret[pos2] * omega[wpos << 1];
                            let t1 = z1 + z2;
                            let t2 = ret[j] - t1.scale(half);
                            let t3 = (z1 - z2) * r3_rot;
                            ret[j] += t1;
                            ret[pos1] = t2 + t3;
                            ret[pos2] = t2 - t3;
                        }
                        4 => {
                            let w1 = omega[wpos];
                            let w2 = omega[wpos << 1];
                            let w3 = omega[wpos * 3];
                            let pos1 = j + po2m;
                            let pos2 = pos1 + po2m;
                            let pos3 = pos2 + po2m;
                            let wfa = ret[j];
                            let wfb = ret[pos2] * w2;
                            let wfab = wfa + wfb;
                            let wfamb = wfa - wfb;
                            let wfc = ret[pos1] * w1;
                            let wfd = ret[pos3] * w3;
                            let wfcd = wfc + wfd;
                            let wfcimdi = (wfc - wfd) * im_one;
                            ret[j] = wfab + wfcd;
                            ret[pos1] = wfamb - wfcimdi;
                            ret[pos2] = wfab - wfcd;
                            ret[pos3] = wfamb + wfcimdi;
                        }
                        5 => {
                            let w1 = omega[wpos];
                            let w2 = omega[wpos << 1];
                            let w3 = omega[wpos * 3];
                            let w4 = omega[wpos << 2];
                            let pos2 = j + po2m;
                            let pos3 = pos2 + po2m;
                            let pos4 = pos3 + po2m;
                            let pos5 = pos4 + po2m;
                            let z0 = ret[j];
                            let z1 = ret[pos2] * w1;
                            let z2 = ret[pos3] * w2;
                            let z3 = ret[pos4] * w3;
                            let z4 = ret[pos5] * w4;
                            let t1 = z1 + z4;
                            let t2 = z2 + z3;
                            let t3 = z1 - z4;
                            let t4 = z2 - z3;
                            let t5 = t1 + t2;
                            let t6 = (t1 - t2).scale(r5_c);
                            let t7 = z0 - t5.scale(quarter);
                            let t8 = t6 + t7;
                            let t9 = t7 - t6;
                            let t10 = (t3.scale(r5_s1) + t4.scale(r5_s2)) * im_one;
                            let t11 = (t3.scale(r5_s2) - t4.scale(r5_s1)) * im_one;
                            ret[j] = z0 + t5;
                            ret[pos2] = t8 + t10;
                            ret[pos3] = t9 + t11;
                            ret[pos4] = t9 - t11;
                            ret[pos5] = t8 - t10;
                        }
                        _ => {
                            // Direct (definition-based) factor-point DFT of the twiddled inputs.
                            z.clear();
                            z.extend((0..factor).map(|i| ret[j + po2m * i] * omega[wpos * i]));
                            for i in 0..factor {
                                ret[j + po2m * i] = (1..factor)
                                    .fold(z[0], |acc, l| acc + z[l] * rot[(i * l) % factor]);
                            }
                        }
                    }
                    j += po2;
                }
            }
        }
        ret
    }
}
impl<T> AlgorithmName for MixedRadix<T> {
    /// Display name used to identify this transformer.
    fn name(&self) -> &'static str {
        // Named after Cooley and Tukey (previously misspelt "Turkey").
        "Cooley-Tukey (Mixed Radix)"
    }
}
impl<T: Float + FloatConst + NumAssign> DftAlgorithm<T> for MixedRadix<T> {
    /// Forward DFT of `source`.
    fn convert(&mut self, source: &[Complex<T>]) -> Vec<Complex<T>> {
        let inverse = false;
        Self::convert_inner(self, source, inverse)
    }

    /// Inverse DFT of `source`, normalized by `1 / source.len()`.
    fn convert_back(&mut self, source: &[Complex<T>]) -> Vec<Complex<T>> {
        let inverse = true;
        Self::convert_inner(self, source, inverse)
    }
}
評価
Prime Factor型と同様に、基数毎のDFTを専用に作成し、高速化を図る必要がある。
上記では基数2,3,4,5におけるDFTを作成し、高速化を図っているが、まだチューニングの余地は残されている。
Chirp-Z 変換
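Chirp-Z 変換（Bluestein のアルゴリズム）は、任意の要素数$N$のDFTを、要素数$2^n$のFFTで計算できる畳み込みへと変換する処理である。
$W = e^{-2\pi i/N}$ とすると、$nk = \frac{n^2 + k^2 - (k-n)^2}{2}$ より $X_k = W^{k^2/2} \sum_{n=0}^{N-1} \left(x_n W^{n^2/2}\right) W^{-(k-n)^2/2}$ と書ける。$2^n$のFFTはかなり高速なため、入力にチャープ$W^{n^2/2}$を乗じて畳み込みの形に変換し、その畳み込みを要素数$2^n \ge 2N-1$のFFTで計算し、最後に$W^{k^2/2}$を乗じて元に戻すことにより、任意の要素数でも高速に処理できる。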
use num::{Complex, Float, cast, one, zero};
use num_traits::float::FloatConst;
use std::cmp;
/// Arbitrary-length DFT via the chirp-z (Bluestein) algorithm: the transform is
/// rewritten as a circular convolution that is evaluated with power-of-two FFTs.
#[derive(Debug)]
pub struct ChirpZ<T> {
    /// log2 of the power-of-two FFT length used for the convolution.
    level: usize,
    /// Bit-reversal permutation for the power-of-two FFT.
    ids: Vec<usize>,
    /// Power-of-two FFT twiddles `exp(-2πi·k/2^level)` for `k = 0..=2^level`.
    omega: Vec<Complex<T>>,
    /// Chirp table `exp(-2πi·k/(2·src_len))` for `k = 0..=2·src_len`.
    src_omega: Vec<Complex<T>>,
    /// Input length the tables were built for (0 until the first transform).
    src_len: usize,
}
impl<T: Float + FloatConst> ChirpZ<T> {
pub fn new() -> Self {
Self {
level: 0,
ids: Vec::new(),
omega: Vec::new(),
src_omega: Vec::new(),
src_len: 0,
}
}
// ωの事前計算
fn calc_omega(len: usize) -> Vec<Complex<T>> {
let mut omega = Vec::with_capacity(len + 1);
omega.push(one());
if len & 3 == 0 {
// 4で割り切れる(下位2ビットが0)ならば
let q = len >> 2;
let h = len >> 1;
// 1/4周分を計算
for i in 1..q {
omega.push(Complex::from_polar(
&one(),
&(cast::<_, T>(-2.0).unwrap() * T::PI() / cast(len).unwrap() *
cast(i).unwrap()),
));
}
// 1/4~1/2周分を計算
for i in q..h {
let tmp = omega[i - q];
omega.push(Complex::new(tmp.im, -tmp.re));
}
// 1/2周目から計算
for i in h..len {
let tmp = omega[i - h];
omega.push(Complex::new(-tmp.re, -tmp.im));
}
} else if len & 1 == 0 {
// 2で割り切れる(下位1ビットが0)ならば
let h = cmp::max(len >> 1, 1);
// 1/2周分を計算
for i in 1..h {
omega.push(Complex::from_polar(
&one(),
&(cast::<_, T>(-2.0).unwrap() * T::PI() / cast(len).unwrap() *
cast(i).unwrap()),
));
}
// 1/2周目から計算
for i in h..len {
let tmp = omega[i - h];
omega.push(Complex::new(-tmp.re, -tmp.im));
}
} else {
for i in 1..len {
omega.push(Complex::from_polar(
&one(),
&(cast::<_, T>(-2.0).unwrap() * T::PI() / cast(len).unwrap() *
cast(i).unwrap()),
));
}
}
// 1周ちょうど
omega.push(one());
return omega;
}
fn initialize(&mut self, len: usize) {
let pow2len = len.next_power_of_two() << 1;
let lv = pow2len.trailing_zeros() as usize;
if lv != self.level {
self.level = lv;
self.omega = Self::calc_omega(pow2len);
}
self.src_len = len;
self.src_omega = Self::calc_omega(len << 1);
// ビットリバースの計算
self.ids = Vec::with_capacity(pow2len);
self.ids.push(0);
for _ in 0..lv {
for j in 0..self.ids.len() {
let id = self.ids[j] << 1;
self.ids[j] = id;
self.ids.push(id + 1);
}
}
}
fn convert_rad2(&mut self, source: &[Complex<T>], is_back: bool) -> Vec<Complex<T>> {
// 入力の並び替え
let mut ret = self.ids.iter().map(|&i| source[i]).collect::<Vec<_>>();
// FFT
let mut po2 = 1;
if (self.level & 1) == 1 {
po2 = 2;
for j in 0..(ret.len() >> 1) {
let pos_a = j << 1;
let pos_b = pos_a + 1;
let wfa = ret[pos_a];
let wfb = ret[pos_b];
ret[pos_a] = wfa + wfb;
ret[pos_b] = wfa - wfb;
}
}
let im_one = if is_back { -Complex::i() } else { Complex::i() };
for i in 1..((self.level >> 1) + 1) {
let po2m = po2;
po2 <<= 2;
for k in 0..po2m {
let pos_w = (1 << ((self.level & !1) - (i << 1))) * k;
let (w, w2, w3) = if is_back {
(
self.omega[pos_w].conj(),
self.omega[pos_w << 1].conj(),
self.omega[pos_w * 3].conj(),
)
} else {
(
self.omega[pos_w],
self.omega[pos_w << 1],
self.omega[pos_w * 3],
)
};
let mut j = k;
while j < ret.len() {
let pos_b = j + po2m;
let pos_c = j + (po2m << 1);
let pos_d = j + (po2m * 3);
let wfa = ret[j];
let wfb = ret[pos_b] * w2;
let wfab = wfa + wfb;
let wfamb = wfa - wfb;
let wfc = ret[pos_c] * w;
let wfd = ret[pos_d] * w3;
let wfcd = wfc + wfd;
let wfcimdi = (wfc - wfd) * im_one;
ret[j] = wfab + wfcd;
ret[pos_b] = wfamb - wfcimdi;
ret[pos_c] = wfab - wfcd;
ret[pos_d] = wfamb + wfcimdi;
j += po2;
}
}
}
if is_back {
for i in 0..ret.len() {
ret[i] = ret[i].unscale(cast(ret.len()).unwrap());
}
}
return ret;
}
fn convert_inner(&mut self, source: &[Complex<T>], is_back: bool) -> Vec<Complex<T>> {
let srclen = source.len();
// 1要素以下ならば入力値をそのまま返す
if srclen <= 1 {
return source.to_vec();
}
if self.src_len != srclen {
self.initialize(srclen);
}
let len = 1 << self.level;
let mut a = Vec::with_capacity(len);
let mut b = Vec::with_capacity(len);
for i in 0..source.len() {
let w = if is_back {
self.src_omega[(i * i) % (srclen << 1)].conj()
} else {
self.src_omega[(i * i) % (srclen << 1)]
};
a.push(source[i] * w);
b.push(w.conj());
}
let hlen = (len >> 1) + 1;
for _ in srclen..hlen {
a.push(zero());
b.push(zero());
}
for i in hlen..len {
a.push(zero());
let t = b[len - i];
b.push(t);
}
let aa = self.convert_rad2(&a, false);
let bb = self.convert_rad2(&b, false);
let gg = aa.iter().zip(&bb).map(|(&x, &y)| x * y).collect::<Vec<_>>();
let g = self.convert_rad2(&gg, true);
// Multiply phase factor
return (0..srclen)
.map(|i| if is_back {
(g[i] * b[i].conj()).unscale(cast(srclen).unwrap())
} else {
g[i] * b[i].conj()
})
.collect::<Vec<_>>();
}
}
impl<T> AlgorithmName for ChirpZ<T> {
    /// Display name used to identify this transformer.
    fn name(&self) -> &'static str {
        const NAME: &str = "Chirp-Z 変換";
        NAME
    }
}
impl<T: Float + FloatConst> DftAlgorithm<T> for ChirpZ<T> {
    /// Forward DFT of `source`.
    fn convert(&mut self, source: &[Complex<T>]) -> Vec<Complex<T>> {
        let inverse = false;
        Self::convert_inner(self, source, inverse)
    }

    /// Inverse DFT of `source`, normalized by `1 / source.len()`.
    fn convert_back(&mut self, source: &[Complex<T>]) -> Vec<Complex<T>> {
        let inverse = true;
        Self::convert_inner(self, source, inverse)
    }
}
評価
本コードでは、Chirp-Z変換を行なった後、基数4（段数が奇数の場合は初段のみ基数2）、時間引きのCooley-Tukey型FFTを実施している。
要素数が$2^n$の時も変換している等、無駄の多いコードとなっているため修正の必要があるが、
要素数が大きく素数である場合などでは、Chirp-Z変換が最速となる。
終わりに
次回は高速化について言及したい。