Rust で DFT のプログラムを書いてみた
高専4年最後の実験「フーリエ解析」で離散フーリエ変換のプログラムを書く必要があったので Rust 入門してみました.
初心者なので見苦しいコードがあるかもしれまん.ごめんなさい.
リポジトリ
https://github.com/yangniao23/dft
MIT ライセンスで公開しています.
WAVE ファイルの生成
今回は WAVE ファイルを DFT するプログラムを作成したいので,テストデータを生成するプログラムをまず書きました.
WAVEファイルの生成には hound を使っています.
use hound;
use std::f64::consts::PI;
use std::i16;
const SAMPLE_RATE: u32 = 44100;
fn write_wav(path: &str, samples: Vec<f64>) {
let spec = hound::WavSpec {
channels: 1,
sample_rate: SAMPLE_RATE as u32,
bits_per_sample: 16,
sample_format: hound::SampleFormat::Int,
};
let mut writer = hound::WavWriter::create(path, spec).unwrap();
for s in samples {
let s = (s * i16::MAX as f64) as i16;
writer.write_sample(s).unwrap();
}
}
fn generate_signal(freq: f64, duration: f64) -> Vec<f64> {
let dt = 1.0 / SAMPLE_RATE as f64;
let n = (duration / dt) as usize;
let mut signal = vec![0.0; n];
for i in 0..n {
let t = i as f64 * dt;
signal[i] = (2.0 * PI * freq * t).sin();
}
signal
}
fn main() {
// 1000Hz, 2000Hz, 3000Hzの正弦波を合成した信号を生成する
let signal = generate_signal(1000.0, 0.1)
.iter()
.zip(generate_signal(2000.0, 0.1).iter())
.zip(generate_signal(3000.0, 0.1).iter())
.map(|((&s1, &s2), &s3)| s1 + 2.0 * s2 + 3.0 * s3)
.collect::<Vec<f64>>();
// 信号の正規化
let abs_max = signal
.iter()
.fold(f64::MIN, |max, &s| f64::max(max, s.abs()));
let signal = signal.iter().map(|s| s / abs_max).collect::<Vec<f64>>();
write_wav("./sample.wav", signal);
}
write_wav
関数はサンプリングレート44.1kHz,量子化ビット数16bitでWAVEファイルを生成する関数です.
ただし,samples
の値域は[-1.0:1.0]
に正規化する必要があります.これは let s = (s * i16::MAX as f64) as i16
を実行することで f64 から i16 へ変換しているためです.
generate_signal
関数は周波数 freq [Hz] と長さ duration [s] で振幅1の正弦波を出力する関数です.
Rustの三角関数の書き方は独特で,f64型のメソッドとして定義されているため x.sin()
のように呼び出します.
signal
を生成する部分をコマンドライン引数で可変にしたらさらに使いやすくなりそうです.
ともかく,これでWAVEファイルが完成しました.
波形のダンプ
WAVEファイルが正しく生成されているか確認するために,WAVEファイルをダンプするプログラムも作成しました.
use hound;
use num::complex::Complex;
use std::f64::consts::PI;
use std::i16;
use std::thread;
// read_wav はモノラル16bitのwavファイルを読み込み、Vec<f64>に変換して返す関数
fn read_wav(path: &str) -> (hound::WavSpec, Vec<f64>) {
let mut reader = hound::WavReader::open(path).unwrap();
let spec = reader.spec();
let samples = reader
.samples::<i16>()
.map(|s| s.unwrap() as f64 / (1 << (spec.bits_per_sample - 1)) as f64) // 正規化
.collect::<Vec<f64>>();
(spec, samples)
}
fn dft(signal: &Vec<f64>, dt: f64) -> Vec<Complex<f64>> {
let n = signal.len();
let mut spectrum = vec![Complex::new(0.0, 0.0); n];
let mut handles = vec![];
for l in 0..n {
let signal = signal.clone();
let handle = thread::spawn(move || {
let mut sum = Complex::new(0.0, 0.0);
for k in 0..n {
let wn = -2.0 * PI / n as f64;
let theta = wn * (k * l) as f64;
sum += signal[k] * Complex::new(theta.cos(), theta.sin());
}
sum * dt
});
handles.push(handle);
}
for (l, handle) in handles.into_iter().enumerate() {
spectrum[l] = handle.join().unwrap();
}
spectrum
}
fn inverse_dft(spectrum: Vec<Complex<f64>>, dt: f64) -> Vec<f64> {
let n = spectrum.len();
let mut signal = vec![0.0; n];
let mut handles = vec![];
for l in 0..n {
let spectrum = spectrum.clone();
let handle = thread::spawn(move || {
let mut sum = Complex::new(0.0, 0.0);
for k in 0..n {
let wn = 2.0 * PI / n as f64;
let theta = wn * (k * l) as f64;
sum += spectrum[k] * Complex::new(theta.cos(), theta.sin());
}
sum.re / (n as f64 * dt)
});
handles.push(handle);
}
for (l, handle) in handles.into_iter().enumerate() {
signal[l] = handle.join().unwrap();
}
signal
}
fn main() {
let (spec, signal) = read_wav("./sample.wav");
let dt = 1.0 / spec.sample_rate as f64;
// t は 0.0, 1/44100, 2/44100, ..., 0.1 となる
let t = (0..signal.len())
.map(|i| i as f64 * dt)
.collect::<Vec<f64>>();
let spectrum = dft(&signal, dt);
let invspectrum = inverse_dft(spectrum, dt);
for i in 0..signal.len() {
println!("{:.7},{:.7},{:.7}", t[i], signal[i], invspectrum[i]);
}
}
read_wav
関数では 量子化ビット数 16bitのWAVE ファイルを読み取って値をVec<f64>
に詰めています.このとき,WAVEファイルの値をそのまま入れると値域が[-32768:32768]
になって直感的ではないので,[-1:1]に正規化しています.
dft
関数については後で紹介します.愚直な実装ですが,一応スレッドハンドラを使って並列処理できるようにしています.Rustはコンパイル時に
RUSTFLAGS='target-feature=+avx2'
を渡してあげるとベクトル化してくれるらしいので,それなりに大きいデータでも処理できました.
(実際に$N=4410, O(N^2) \approx 10^8$ でも数秒で処理できました.)
inverse_dft
関数も同様に愚直な逆離散フーリエ変換を並列処理化したものです.
main
関数では,時刻とWAVEファイルから読み取った振幅,離散フーリエ変換したものを逆変換して得られた振幅をそれぞれ7桁で表示しています.
WAVEファイルの情報はhound::WavSpec
構造体に詰められるので,そこからサンプリングレートを取得しています.
ここに先に用意した $\sin(t) + 2\sin(2t) + 3\sin(3t)$をダンプしたものを示します.
0.0000000,0.0000000,0.0000000
0.0000227,0.3640137,0.3640137
0.0000454,0.6772156,0.6772156
0.0000680,0.8969727,0.8969727
0.0000907,0.9957275,0.9957275
0.0001134,0.9650879,0.9650879
0.0001361,0.8170471,0.8170471
0.0001587,0.5812683,0.5812683
0.0001814,0.2997742,0.2997742
0.0002041,0.0196228,0.0196228
... (続く)
離散フーリエ変換・逆離散フーリエ変換が正しく行われていそうですね.次に波形を示します.gnuplotでグラフを出しました.
良さそうですね. $\max(\sin(1000\omega t) + 2\sin(2000\omega t) + 3\sin(3000\omega t)) \approx 5.35$ なので理論値を出すときには 5.35 で割ってます.
1/30 修正:式が間違っていました.
誤:$\sin(t) + 2\sin(2t) + 3\sin(3t)$
正:$\sin(1000\omega t) + 2\sin(2000\omega t) + 3\sin(3000\omega t)$
set datafile separator ","
version = ""
name = "waveform".version
pi = 3.14159265
set terminal pdf enhanced
set output name.".pdf"
set style fill solid
set autoscale
set size square
set key
set key outside right
set key width 0
set grid
set xr [0:0.002]
set yr [-1.1:1.1]
set mxtics 10
set mytics
set xlabel "t [s]"
set ylabel "Amplitude"
#set logscale x
set samples 12000
plot name.".csv" using 1:2 w l title "振幅",\
(sin(2*pi*1000*x) + 2*sin(2*2*pi*1000*x) + 3*sin(3*2*pi*1000*x))/5.35 w l title "理論値"
離散フーリエ変換
さて,下準備が完了したところで実際に離散フーリエ変換を行いましょう.
use hound;
use num::complex::Complex;
use std::f64::consts::PI;
use std::i16;
use std::thread;
// read_wav はモノラル16bitのwavファイルを読み込み、Vec<f64>に変換して返す関数
fn read_wav(path: &str) -> (hound::WavSpec, Vec<f64>) {
let mut reader = hound::WavReader::open(path).unwrap();
let spec = reader.spec();
let samples = reader
.samples::<i16>()
.map(|s| s.unwrap() as f64 / (1 << (spec.bits_per_sample - 1)) as f64) // 正規化
.collect::<Vec<f64>>();
(spec, samples)
}
fn dft(signal: Vec<f64>, dt: f64) -> Vec<Complex<f64>> {
let n = signal.len();
let mut spectrum = vec![Complex::new(0.0, 0.0); n];
let mut handles = vec![];
for l in 0..n {
let signal = signal.clone();
let handle = thread::spawn(move || {
let mut sum = Complex::new(0.0, 0.0);
for k in 0..n {
let wn = -2.0 * PI / n as f64;
let theta = wn * (k * l) as f64;
sum += signal[k] * Complex::new(theta.cos(), theta.sin()) * dt;
}
sum
});
handles.push(handle);
}
for (l, handle) in handles.into_iter().enumerate() {
spectrum[l] = handle.join().unwrap();
}
spectrum
}
fn main() {
let (spec, signal) = read_wav("./sample.wav");
let len: usize = signal.len();
let dt = 1.0 / spec.sample_rate as f64;
let f = (0..(signal.len() / 2))
.map(|i| i as f64 * spec.sample_rate as f64 / len as f64)
.collect::<Vec<f64>>();
let spectrum = dft(signal, dt);
// スペクトルの最大値を1に正規化
let abs_max = spectrum
.iter()
.fold(f64::MIN, |max, &s| f64::max(max, s.norm()));
let spectrum = spectrum
.iter()
.map(|s| s / abs_max)
.collect::<Vec<Complex<f64>>>();
for i in 0..(len / 2) {
//println!("{:.7},{:.7}", f[i], 2.0 * spectrum[i].norm() / len as f64);
println!("{:.7},{:.7}", f[i], spectrum[i].norm());
}
}
read_wav
関数は先程と同じです.
dft
関数は離散フーリエ変換を愚直に実装しています.離散フーリエ変換の定義式は
$$
\begin{align}
X_l = \sum^{N-1}_{k=0} x_k (W_N)^{kl} \Delta t && \text{ただし$W_N = \exp(-i\frac{2\pi}{N}$)}
\end{align}
$$
なので,その通りに計算しています.
また,信号をWAVEファイルに変換する段階で1への正規化を行ったので,こちらでもスペクトルを1に正規化しています.これによって複数のファイルでスペクトル強度を比較できなくなりますが,今回はそれが不要だったので,結果のグラフを綺麗にするために行っています.
離散フーリエ変換では折り返しが発生するので,出力する範囲はサンプリング周波数の半分まで,すなわち信号長の半分まで出力します.
結果を確認します.
0.0000000,0.0000000
(中略)
990.0000000,0.0000000
1000.0000000,0.3333426
1010.0000000,0.0000000
(中略)
1990.0000000,0.0000000
2000.0000000,0.6666778
2010.0000000,0.0000000
(中略)
2990.0000000,0.0000000
3000.0000000,1.0000000
3010.0000000,0.0000000
(後略)
1kHz, 2kHz, 3kHz でスペクトルが確認できます.
これをgnuplotで描画したら以下のようになりました.ただし,出力結果を3倍しています.
set datafile separator ","
version = ""
name = "spectrum".version
set terminal pdf enhanced
set output name.".pdf"
set style fill solid
set autoscale
set size square
set key
set key outside right
set key width 0
set grid
set xr [0:5000]
set yr [0:3.1]
set mxtics 10
set mytics
set xlabel "f [Hz]"
set ylabel "Amplitude Spectrum"
#set logscale x
# DFT で 1/6, 1/3, 1/2 -> 1/3, 2/3, 1 されているから3倍して出力
plot name.".dat" using 1:($2*3) w boxes title "スペクトル"
1/30 修正
横軸の単位を [kHz] から [Hz] に修正
アホみたいに高周波になってました…
あとがき
初 Rust にはいい入門だったと思います.Rust の cargo すごいですね.モダンでよいシステムだと思ったので今後も使っていきたいです.
1/30 追記 png 形式での出力方法
gnuplot で png 形式のグラフを出すときは set terminal pngcario enhanced
と指定すると綺麗に出せました.