C#
Unity
ComputeShader

【Unity】Compute Shaderで波動方程式シミュレーション

UnityのCompute Shaderで波動方程式のシミュレーションの実装です。

wave.gif

波動方程式に関しては以下の記事を参考にしました。
波動方程式の数値解法 - Qiita

Wave.cs
using UnityEngine;

/// <summary>
/// Drives a 2D wave-equation simulation on the GPU and shows the result on <c>plane</c>.
/// Every physics step: AddWave injects a moving circular source, Update advances the
/// height field by one step, and Draw converts the heights into a color texture that is
/// used as the plane's main texture.
/// </summary>
public class Wave : MonoBehaviour
{
    // Resolution (in texels) of both the simulation grid and the rendered texture.
    private const int TextureSize = 256;

    [SerializeField] GameObject plane;
    [SerializeField] ComputeShader computeShader;
    // Grid spacing, sent to the shader as "deltaSize".
    [SerializeField] float deltaSize = 0.1f;
    // Wave propagation coefficient, sent to the shader as "waveCoef".
    [SerializeField] float waveCoef = 1.0f;
    private RenderTexture waveTexture, drawTexture;

    private int kernelInitialize, kernelAddWave, kernelUpdate, kernelDraw;
    private ThreadSize threadSizeInitialize, threadSizeAddWave, threadSizeUpdate, threadSizeDraw;

    // Thread-group dimensions declared by a kernel's [numthreads] attribute.
    struct ThreadSize
    {
        public int x;
        public int y;
        public int z;

        public ThreadSize(uint x, uint y, uint z)
        {
            this.x = (int)x;
            this.y = (int)y;
            this.z = (int)z;
        }
    }

    private void Start()
    {
        // Look up the kernel indices.
        kernelInitialize = computeShader.FindKernel("Initialize");
        kernelAddWave = computeShader.FindKernel("AddWave");
        kernelUpdate = computeShader.FindKernel("Update");
        kernelDraw = computeShader.FindKernel("Draw");

        // Wave state: x = current height, y = previous height.
        // Heights go negative (and can exceed 1), so a signed float format is required;
        // RG32 is an unsigned normalized format and would clamp every height to [0, 1].
        waveTexture = new RenderTexture(TextureSize, TextureSize, 0, RenderTextureFormat.RGFloat);
        waveTexture.wrapMode = TextureWrapMode.Clamp;
        waveTexture.enableRandomWrite = true;
        waveTexture.Create();
        // Color texture written by the Draw kernel and displayed on the plane.
        drawTexture = new RenderTexture(TextureSize, TextureSize, 0, RenderTextureFormat.ARGB32);
        drawTexture.enableRandomWrite = true;
        drawTexture.Create();

        // Query each kernel's own group size so every dispatch matches the kernel it launches.
        threadSizeInitialize = GetThreadSize(kernelInitialize);
        threadSizeAddWave = GetThreadSize(kernelAddWave);
        threadSizeUpdate = GetThreadSize(kernelUpdate);
        threadSizeDraw = GetThreadSize(kernelDraw);

        // The draw texture object never changes, so bind it to the material once.
        plane.GetComponent<Renderer>().material.mainTexture = drawTexture;

        // Zero all wave heights.
        computeShader.SetTexture(kernelInitialize, "waveTexture", waveTexture);
        DispatchOverWaveTexture(kernelInitialize, threadSizeInitialize);
    }

    private void FixedUpdate()
    {
        // Inject the moving circular source.
        computeShader.SetFloat("time", Time.time);
        computeShader.SetTexture(kernelAddWave, "waveTexture", waveTexture);
        DispatchOverWaveTexture(kernelAddWave, threadSizeAddWave);

        // Advance the wave equation by one step.
        // Inside FixedUpdate, Time.deltaTime is the fixed physics step; the shader uses twice that.
        computeShader.SetFloat("deltaSize", deltaSize);
        computeShader.SetFloat("deltaTime", Time.deltaTime * 2.0f);
        computeShader.SetFloat("waveCoef", waveCoef);
        computeShader.SetTexture(kernelUpdate, "waveTexture", waveTexture);
        DispatchOverWaveTexture(kernelUpdate, threadSizeUpdate);

        // Convert the heights into the displayed color texture.
        // Textures are rebound every step in case another component shares this ComputeShader asset.
        computeShader.SetTexture(kernelDraw, "waveTexture", waveTexture);
        computeShader.SetTexture(kernelDraw, "drawTexture", drawTexture);
        DispatchOverWaveTexture(kernelDraw, threadSizeDraw);
    }

    private void OnDestroy()
    {
        // RenderTextures are native GPU resources and are not garbage collected; free them explicitly.
        ReleaseTexture(waveTexture);
        ReleaseTexture(drawTexture);
    }

    // Returns the [numthreads] group size declared by the given kernel.
    private ThreadSize GetThreadSize(int kernel)
    {
        uint x, y, z;
        computeShader.GetKernelThreadGroupSizes(kernel, out x, out y, out z);
        return new ThreadSize(x, y, z);
    }

    // Dispatches the kernel with enough thread groups to cover every texel of waveTexture.
    // Uses ceiling division, so sizes that are not a multiple of the group size are fully covered.
    private void DispatchOverWaveTexture(int kernel, ThreadSize groupSize)
    {
        computeShader.Dispatch(
            kernel,
            CeilDiv(waveTexture.width, groupSize.x),
            CeilDiv(waveTexture.height, groupSize.y),
            1);
    }

    // Integer ceiling of numerator / denominator, for positive operands.
    private static int CeilDiv(int numerator, int denominator)
    {
        return (numerator + denominator - 1) / denominator;
    }

    // Releases the GPU memory of a render texture and destroys the object, if it exists.
    private static void ReleaseTexture(RenderTexture texture)
    {
        if (texture != null)
        {
            texture.Release();
            Destroy(texture);
        }
    }
}
Wave.compute
// Kernel entry points; each one is looked up by name from C# (ComputeShader.FindKernel).
#pragma kernel Initialize
#pragma kernel AddWave
#pragma kernel Update
#pragma kernel Draw

// Per-texel wave state: x = current height, y = height at the previous step.
RWTexture2D<float2> waveTexture;
// Color output written by the Draw kernel.
RWTexture2D<float4> drawTexture;
// Wave propagation coefficient (the c in u_tt = c^2 * laplacian(u)).
float waveCoef;
// Grid spacing between neighboring texels.
float deltaSize;
// Time step of one Update dispatch (also scales the damping term).
float deltaTime;
// Elapsed time; drives the position of the moving source in AddWave.
float time;

// Resets the wave state (current and previous height) of every texel to zero.
// Threads that fall outside the texture exit early. Such threads appear when the
// dispatch is rounded up for a size that is not a multiple of the 8x8 group.
[numthreads(8, 8, 1)]
void Initialize(uint3 dispatchThreadId : SV_DispatchThreadID)
{
    uint width, height;
    waveTexture.GetDimensions(width, height);
    if (dispatchThreadId.x >= width || dispatchThreadId.y >= height)
    {
        return;
    }

    waveTexture[dispatchThreadId.xy] = float2(0, 0);
}

// Adds a circular bump to the current height (x channel).
// The source point moves on a circle of radius 0.7 (in [-1, 1] normalized
// coordinates) with angle time * 0.5. Texels within distance 0.05 of the source
// receive 5 * sqrt(0.05 - r); all other texels are unchanged.
[numthreads(8, 8, 1)]
void AddWave(uint3 dispatchThreadId : SV_DispatchThreadID)
{
    uint width, height;
    waveTexture.GetDimensions(width, height);
    // The read-modify-write below must not run for threads past the texture edge.
    if (dispatchThreadId.x >= width || dispatchThreadId.y >= height)
    {
        return;
    }

    // Texel index -> normalized position in [-1, 1).
    float2 pos = (float2)dispatchThreadId.xy / float2(width, height) * 2.0 - 1.0;
    // Current position of the orbiting source.
    float2 center = 0.7 * float2(cos(time * 0.5), sin(time * 0.5));

    float r = length(center - pos);

    // sqrt is the direct form of pow(x, 0.5) for the non-negative argument produced by max().
    float h = 5.0 * sqrt(max(0.05 - r, 0.0));
    waveTexture[dispatchThreadId.xy] += float2(h, 0);
}

// Advances the wave equation by one explicit finite-difference step.
// Reads  x = u(t) (current height) and y = u(t - dt) (previous height) of each texel.
// Writes float2(u(t + dt), u(t)) back to the same texel.
// With c = waveCoef, dx = deltaSize, dt = deltaTime and a 5-point Laplacian:
//   u(t+dt) = 2u(t) - u(t-dt) + (c*dt/dx)^2 * (sum of 4 neighbors - 4u(t))
//             - 0.1 * dt * (u(t) - u(t-dt))      <- damping proportional to the change per step
// At the texture border a missing neighbor is replaced by the texel itself
// (a zero-gradient-like boundary).
// NOTE(review): neighbors' x is read from waveTexture while other threads in the
// same dispatch overwrite their own texels, so a neighbor may already hold its new
// value; the result depends on GPU scheduling. A race-free version needs a separate
// output texture (ping-pong buffers) bound from the C# side.
[numthreads(8, 8, 1)]
void Update(uint3 dispatchThreadId : SV_DispatchThreadID)
{
    float width, height;
    waveTexture.GetDimensions(width, height);

    float2 wave = waveTexture[dispatchThreadId.xy];
    // a = (c * dt / dx)^2 (squared Courant number); explicit schemes like this one are
    // only stable while it stays small enough (CFL condition).
    float a = (deltaTime * deltaTime * waveCoef * waveCoef) / (deltaSize * deltaSize); 
    // uint2(-1, 0) / uint2(0, -1) wrap to 0xFFFFFFFF, so the addition steps one texel
    // back via unsigned wraparound; the != 0 checks keep that from running past the edge.
    // The trailing "+" on the last neighbor line combines with "- 4.0 * wave.x",
    // i.e. it adds (-4.0 * wave.x).
    float h = 2.0 * wave.x - wave.y + a * (
        (dispatchThreadId.x != 0 ?          waveTexture[dispatchThreadId.xy + uint2(-1, 0)].x : waveTexture[dispatchThreadId.xy].x) +
        (dispatchThreadId.x < width - 1 ?   waveTexture[dispatchThreadId.xy + uint2( 1, 0)].x : waveTexture[dispatchThreadId.xy].x) +
        (dispatchThreadId.y != 0 ?          waveTexture[dispatchThreadId.xy + uint2(0, -1)].x : waveTexture[dispatchThreadId.xy].x) +
        (dispatchThreadId.y < height - 1 ?  waveTexture[dispatchThreadId.xy + uint2(0,  1)].x : waveTexture[dispatchThreadId.xy].x) +
        - 4.0 * wave.x) - 0.1 * deltaTime * (wave.x - wave.y);

    // Shift the history: the new height becomes current, the old current becomes previous.
    waveTexture[dispatchThreadId.xy] = float2(h, wave.x);
}

// Maps the current wave height (x channel) to a color.
// Height <= 0 gives black, height >= 1 gives cyan (0, 1, 1), and values in between
// are blended linearly. Alpha is always 1.
[numthreads(8, 8, 1)]
void Draw(uint3 dispatchThreadId : SV_DispatchThreadID)
{
    uint waveWidth, waveHeight, drawWidth, drawHeight;
    waveTexture.GetDimensions(waveWidth, waveHeight);
    drawTexture.GetDimensions(drawWidth, drawHeight);
    // Only texels present in both textures are processed; extra threads from a
    // rounded-up dispatch (or a smaller drawTexture) exit early.
    uint2 size = min(uint2(waveWidth, waveHeight), uint2(drawWidth, drawHeight));
    if (any(dispatchThreadId.xy >= size))
    {
        return;
    }

    float level = clamp(waveTexture[dispatchThreadId.xy].x, 0, 1);
    drawTexture[dispatchThreadId.xy] = lerp(
        float4(0, 0, 0, 1),
        float4(0, 1, 1, 1),
        level
    );
}

プログラムの構成は以前書いたライフゲームのCompute Shader実装と同じようになっています。

波の高さを格納するwaveTextureには、xに現在の波の高さを、yに1つ前の波の高さを格納します。xとyの2チャンネルしか使用しないので、フォーマットには2チャンネルのフォーマットを使用しています。なお、波の高さは負の値もとるため、RGFloatのような符号付きの浮動小数点フォーマットを選ぶ必要があります（RG32などの正規化フォーマットでは、負の値が0にクランプされてしまいます）。

Wave.computeのAddWaveカーネルでは、円周上を移動する点を中心として、円状に新たな波を追加しています。
Updateカーネルで波の高さの更新を行い、Drawカーネルでは波の高さをもとにマテリアルに渡すテクスチャに値を設定しています。