今回は wgpu で計算処理を実施してみました。
前回 の画像処理と基本的な処理の流れは同じで、違いは以下を使う点です。
- コンピューティングシェーダー
- コンピューティングパイプライン
- コンピューティングパス
なお、コンピューティングシェーダーは決められた戻り値がなく、バインドグループで設定したストレージバッファ等を入出力として使うようです。
コンピューティングシェーダー
前回 と同様にシェーダーを WGSL で実装します。
コンピューティングシェーダーは @compute
を付与した関数で定義します。
このシェーダーは X, Y, Z の 3次元のサイズがあるワークグループという単位に分割して呼び出されるようになっており、実行の際にワークグループのサイズを指定する必要がありました。
また、ビルトイン引数として次のようなものが使え、local_invocation_xxx
はワークグループ内の(相対的な)id や index になるようです。
name | type |
---|---|
local_invocation_id | vec3<u32> |
local_invocation_index | u32 |
global_invocation_id | vec3<u32> |
workgroup_id | vec3<u32> |
num_workgroups | vec3<u32> |
以下では、ワークグループのサイズを (4, 1, 1)
に設定し、読み書き可のストレージバッファを 1つ使って、各要素を 3倍して +1 した値で更新するような処理を実装してみました。
算出対象とする要素の位置は global_invocation_id の値(下記の global_id.x
)から取得しています。
@group(0) @binding(0)
var<storage, read_write> data: array<u32>;
@compute @workgroup_size(4)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
data[global_id.x] = 3u * data[global_id.x] + 1u;
}
@group(0)
と @binding(0)
はバインドグループのための設定で、Rust 側で作成したリソースをバインドする事になります。
実装
上記のシェーダーを使って計算するための処理は次のようになります。
- シェーダーモジュールを作成し、それを使ってコンピューティングパイプラインを作成(entry_point で関数名を指定)
- ストレージバッファを作成して初期化
- バインドグループを作成してストレージバッファを設定
- コンピューティングパスでパイプラインとバインドグループを設定し、
dispatch_workgroups
で実行するワークグループの数を指定
バインドグループを使ってシェーダー側の data
変数と Rust 側で作成したストレージバッファをバインドします。
layout: &pipeline.get_bind_group_layout(0)
がシェーダーの group(0)
に、binding: 0
が @binding(0)
に対応しています。
ワークグループの数を指定して実行するのがポイントで、今回は 8つの要素に対して、(4, 1, 1)
サイズのワークグループで処理するため、dispatch_workgroups するワークグループ数は (2, 1, 1)
となります。
また、output_slice.get_mapped_range().to_vec()
で結果を取得すると型が Vec<u8>
になってしまうので、bytemuck::cast_slice::<_, u32>(...)
を使って to_vec する前に u32 へキャストしています。
use std::borrow::Cow;
use wgpu::util::DeviceExt;
type Result<T> = std::result::Result<T, Box<dyn std::error::Error>>;
#[tokio::main]
async fn main() -> Result<()> {
env_logger::init();
let workgroup_size = 4;
// 計算対象の値
let vs: Vec<u32> = vec![1, 2, 3, 4, 5, 6, 7, 8];
let instance = wgpu::Instance::default();
let opts = wgpu::RequestAdapterOptions::default();
let adapter = instance
.request_adapter(&opts)
.await
.ok_or("notfound adapter")?;
let desc = wgpu::DeviceDescriptor::default();
let (device, queue) = adapter.request_device(&desc, None).await?;
let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
label: None,
source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(include_str!("shader.wgsl"))),
});
// コンピューティングパイプラインの作成
let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: None,
layout: None,
module: &shader,
entry_point: "main",
});
// ストレージバッファの作成と初期化
let storage_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: None,
contents: bytemuck::cast_slice(&vs),
usage: wgpu::BufferUsages::STORAGE
| wgpu::BufferUsages::COPY_DST
| wgpu::BufferUsages::COPY_SRC,
});
let output_buf = device.create_buffer(&wgpu::BufferDescriptor {
label: None,
size: storage_buf.size(),
usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
// バインドグループ作成
let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
label: None,
layout: &pipeline.get_bind_group_layout(0),
entries: &[wgpu::BindGroupEntry {
binding: 0,
resource: storage_buf.as_entire_binding(),
}],
});
let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor::default());
{
// コンピューティングパスの開始
let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: None,
timestamp_writes: None,
});
pass.set_pipeline(&pipeline);
// バインドグループの設定
pass.set_bind_group(0, &bind_group, &[]);
// 実行するワークグループの数を指定
pass.dispatch_workgroups((vs.len() / workgroup_size) as u32, 1, 1);
}
// ストレージバッファの内容を出力用のバッファへコピー
encoder.copy_buffer_to_buffer(&storage_buf, 0, &output_buf, 0, storage_buf.size());
queue.submit(Some(encoder.finish()));
let output_slice = output_buf.slice(..);
output_slice.map_async(wgpu::MapMode::Read, |_| {});
device.poll(wgpu::MaintainBase::Wait);
// u32 へキャストして Vec 化
let res = bytemuck::cast_slice::<_, u32>(&output_slice.get_mapped_range()).to_vec();
println!("{:?}", res);
output_buf.unmap();
Ok(())
}
[dependencies]
env_logger = "0.10"
wgpu = "0.18"
bytemuck = "1"
tokio = { version = "1", features = ["full"] }
実行結果はこのようになりました。
$ cargo run
・・・ 省略
[4, 7, 10, 13, 16, 19, 22, 25]