LoginSignup
3
1

Rustのwgpuで計算処理

Posted at

今回は wgpu で計算処理を実施してみました。

前回 の画像処理と基本的な処理の流れは同じで、違いは以下を使う点です。

  • コンピューティングシェーダー
  • コンピューティングパイプライン
  • コンピューティングパス

なお、コンピューティングシェーダーは決められた戻り値がなく、バインドグループで設定したストレージバッファ等を入出力として使うようです。

コンピューティングシェーダー

前回 と同様にシェーダーを WGSL で実装します。

コンピューティングシェーダーは @compute を付与した関数で定義します。

このシェーダーは X, Y, Z の 3次元のサイズがあるワークグループという単位に分割して呼び出されるようになっており、実行の際にワークグループのサイズを指定する必要がありました。

また、ビルトイン引数として次のようなものが使え、local_invocation_xxx はワークグループ内の(相対的な)id や index になるようです。

name type
local_invocation_id vec3<u32>
local_invocation_index u32
global_invocation_id vec3<u32>
workgroup_id vec3<u32>
num_workgroups vec3<u32>

以下では、ワークグループのサイズを (4, 1, 1) に設定し、読み書き可のストレージバッファを 1つ使って、各要素を 3倍して +1 した値で更新するような処理を実装してみました。

算出対象とする要素の位置は global_invocation_id の値(下記の global_id.x)から取得しています。

src/shader.wgsl
@group(0) @binding(0)
var<storage, read_write> data: array<u32>;

@compute @workgroup_size(4)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    data[global_id.x] = 3u * data[global_id.x] + 1u;
}

@group(0)@binding(0) はバインドグループのための設定で、Rust 側で作成したリソースをバインドする事になります。

実装

上記のシェーダーを使って計算するための処理は次のようになります。

  • シェーダーモジュールを作成し、それを使ってコンピューティングパイプラインを作成(entry_point で関数名を指定)
  • ストレージバッファを作成して初期化
  • バインドグループを作成してストレージバッファを設定
  • コンピューティングパスでパイプラインとバインドグループを設定し、dispatch_workgroups で実行するワークグループの数を指定

バインドグループを使ってシェーダー側の data 変数と Rust 側で作成したストレージバッファをバインドします。
layout: &pipeline.get_bind_group_layout(0) がシェーダーの group(0) に、binding: 0@binding(0) に対応しています。

ワークグループの数を指定して実行するのがポイントで、今回は 8つの要素に対して、(4, 1, 1) サイズのワークグループで処理するため、dispatch_workgroups するワークグループ数は (2, 1, 1) となります。

また、output_slice.get_mapped_range().to_vec() で結果を取得すると型が Vec<u8> になってしまうので、bytemuck::cast_slice::<_, u32>(...) を使って to_vec する前に u32 へキャストしています。

src/main.rs
use std::borrow::Cow;
use wgpu::util::DeviceExt;

type Result<T> = std::result::Result<T, Box<dyn std::error::Error>>;

#[tokio::main]
async fn main() -> Result<()> {
    env_logger::init();

    let workgroup_size = 4;
    // 計算対象の値
    let vs: Vec<u32> = vec![1, 2, 3, 4, 5, 6, 7, 8];

    let instance = wgpu::Instance::default();

    let opts = wgpu::RequestAdapterOptions::default();
    let adapter = instance
        .request_adapter(&opts)
        .await
        .ok_or("notfound adapter")?;

    let desc = wgpu::DeviceDescriptor::default();
    let (device, queue) = adapter.request_device(&desc, None).await?;

    let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
        label: None,
        source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(include_str!("shader.wgsl"))),
    });

    // コンピューティングパイプラインの作成
    let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
        label: None,
        layout: None,
        module: &shader,
        entry_point: "main",
    });

    // ストレージバッファの作成と初期化
    let storage_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
        label: None,
        contents: bytemuck::cast_slice(&vs),
        usage: wgpu::BufferUsages::STORAGE
            | wgpu::BufferUsages::COPY_DST
            | wgpu::BufferUsages::COPY_SRC,
    });

    let output_buf = device.create_buffer(&wgpu::BufferDescriptor {
        label: None,
        size: storage_buf.size(),
        usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
        mapped_at_creation: false,
    });

    // バインドグループ作成
    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: None,
        layout: &pipeline.get_bind_group_layout(0),
        entries: &[wgpu::BindGroupEntry {
            binding: 0,
            resource: storage_buf.as_entire_binding(),
        }],
    });

    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor::default());

    {
        // コンピューティングパスの開始
        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: None,
            timestamp_writes: None,
        });

        pass.set_pipeline(&pipeline);
        // バインドグループの設定
        pass.set_bind_group(0, &bind_group, &[]);
        // 実行するワークグループの数を指定
        pass.dispatch_workgroups((vs.len() / workgroup_size) as u32, 1, 1);
    }
    // ストレージバッファの内容を出力用のバッファへコピー
    encoder.copy_buffer_to_buffer(&storage_buf, 0, &output_buf, 0, storage_buf.size());

    queue.submit(Some(encoder.finish()));

    let output_slice = output_buf.slice(..);

    output_slice.map_async(wgpu::MapMode::Read, |_| {});

    device.poll(wgpu::MaintainBase::Wait);
    // u32 へキャストして Vec 化
    let res = bytemuck::cast_slice::<_, u32>(&output_slice.get_mapped_range()).to_vec();

    println!("{:?}", res);

    output_buf.unmap();

    Ok(())
}
Cargo.toml
[dependencies]
env_logger = "0.10"
wgpu = "0.18"
bytemuck = "1"
tokio = { version = "1", features = ["full"] }

実行結果はこのようになりました。

実行結果
$ cargo run
・・・ 省略
[4, 7, 10, 13, 16, 19, 22, 25]
3
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
1