【Unity】Compute Shaderで波動方程式シミュレーション

UnityのCompute Shaderを使って、波動方程式のシミュレーションを実装しました。


using UnityEngine;

/// <summary>
/// Runs a 2D wave simulation on the GPU with a compute shader and shows the
/// result on the renderer of <c>plane</c>.
/// Each physics tick dispatches three kernels in order: AddWave (inject a
/// disturbance), Update (advance the simulation one step) and Draw (convert
/// the height field to colours).
/// </summary>
public class Wave : MonoBehaviour
{
    // Edge length, in texels, of the square simulation and display textures.
    private const int TextureSize = 256;

    // Object whose Renderer displays the simulation result.
    [SerializeField] GameObject plane;
    // Compute shader that provides the Initialize, AddWave, Update and Draw kernels.
    [SerializeField] ComputeShader computeShader;
    // Uploaded to the shader as "deltaSize" (the denominator of the update coefficient).
    [SerializeField] float deltaSize = 0.1f;
    // Uploaded to the shader as "waveCoef" (scales the neighbour-difference term).
    [SerializeField] float waveCoef = 1.0f;
    private RenderTexture waveTexture, drawTexture;

    private int kernelInitialize, kernelAddWave, kernelUpdate, kernelDraw;
    private ThreadSize threadSizeInitialize, threadSizeAddWave, threadSizeUpdate, threadSizeDraw;

    /// <summary>Thread-group dimensions of one kernel, stored as signed ints for dispatch arithmetic.</summary>
    struct ThreadSize
    {
        public int x;
        public int y;
        public int z;

        public ThreadSize(uint x, uint y, uint z)
        {
            this.x = (int)x;
            this.y = (int)y;
            this.z = (int)z;
        }
    }

    /// <summary>
    /// Looks up the kernels, allocates the textures, clears the height field
    /// and binds the display texture to the plane's material.
    /// </summary>
    private void Start()
    {
        // Resolve kernel ids.
        kernelInitialize = computeShader.FindKernel("Initialize");
        kernelAddWave = computeShader.FindKernel("AddWave");
        kernelUpdate = computeShader.FindKernel("Update");
        kernelDraw = computeShader.FindKernel("Draw");

        // Texture that stores the wave heights.
        // NOTE(review): RG32 is documented as 16 bits per channel, unsigned
        // normalized, so values outside [0, 1] cannot be stored. Consider
        // RenderTextureFormat.RGFloat if negative displacements are required;
        // left unchanged here because it alters the visual result.
        waveTexture = new RenderTexture(TextureSize, TextureSize, 0, RenderTextureFormat.RG32);
        waveTexture.wrapMode = TextureWrapMode.Clamp;
        waveTexture.enableRandomWrite = true;
        // Allocate the GPU resource up front; enableRandomWrite must be set before Create().
        waveTexture.Create();

        // Texture that receives the rendered colours.
        drawTexture = new RenderTexture(TextureSize, TextureSize, 0, RenderTextureFormat.ARGB32);
        drawTexture.enableRandomWrite = true;
        drawTexture.Create();

        // Query each kernel's own thread-group size.
        threadSizeInitialize = GetThreadSize(kernelInitialize);
        threadSizeAddWave = GetThreadSize(kernelAddWave);
        threadSizeUpdate = GetThreadSize(kernelUpdate);
        threadSizeDraw = GetThreadSize(kernelDraw);

        // Clear the wave heights.
        computeShader.SetTexture(kernelInitialize, "waveTexture", waveTexture);
        DispatchOverTexture(kernelInitialize, threadSizeInitialize);

        // The texture reference never changes, so bind it once instead of
        // on every physics tick.
        plane.GetComponent<Renderer>().material.mainTexture = drawTexture;
    }

    /// <summary>Advances the simulation by one step and refreshes the display texture.</summary>
    private void FixedUpdate()
    {
        // Inject the wave source.
        computeShader.SetFloat("time", Time.time);
        computeShader.SetTexture(kernelAddWave, "waveTexture", waveTexture);
        DispatchOverTexture(kernelAddWave, threadSizeAddWave);

        // Advance the wave heights.
        computeShader.SetFloat("deltaSize", deltaSize);
        computeShader.SetFloat("deltaTime", Time.deltaTime * 2.0f);
        computeShader.SetFloat("waveCoef", waveCoef);
        computeShader.SetTexture(kernelUpdate, "waveTexture", waveTexture);
        DispatchOverTexture(kernelUpdate, threadSizeUpdate);

        // Convert the wave heights into the display texture.
        computeShader.SetTexture(kernelDraw, "waveTexture", waveTexture);
        computeShader.SetTexture(kernelDraw, "drawTexture", drawTexture);
        DispatchOverTexture(kernelDraw, threadSizeDraw);
    }

    /// <summary>Releases the GPU memory held by the render textures.</summary>
    private void OnDestroy()
    {
        if (waveTexture != null)
        {
            waveTexture.Release();
        }

        if (drawTexture != null)
        {
            drawTexture.Release();
        }
    }

    /// <summary>Reads the thread-group size declared by the given kernel.</summary>
    private ThreadSize GetThreadSize(int kernel)
    {
        uint sizeX, sizeY, sizeZ;
        computeShader.GetKernelThreadGroupSizes(kernel, out sizeX, out sizeY, out sizeZ);
        return new ThreadSize(sizeX, sizeY, sizeZ);
    }

    /// <summary>
    /// Dispatches <paramref name="kernel"/> with enough thread groups to cover
    /// every texel of the wave texture.
    /// </summary>
    private void DispatchOverTexture(int kernel, ThreadSize groupSize)
    {
        // Integer ceiling division. Dividing two ints and then calling
        // Mathf.CeilToInt would truncate before the ceiling is applied.
        int groupsX = (waveTexture.width + groupSize.x - 1) / groupSize.x;
        int groupsY = (waveTexture.height + groupSize.y - 1) / groupSize.y;
        computeShader.Dispatch(kernel, groupsX, groupsY, 1);
    }
}
// Kernel entry points; the host script looks each one up by name.
#pragma kernel Initialize
#pragma kernel AddWave
#pragma kernel Update
#pragma kernel Draw

// Simulation state, one texel per grid cell. The Update kernel stores the
// newly computed height in .x and the height it read in .y.
RWTexture2D<float2> waveTexture;
// Colour output written by the Draw kernel.
RWTexture2D<float4> drawTexture;
// waveCoef, deltaSize and deltaTime are combined in the Update kernel as
// (deltaTime^2 * waveCoef^2) / deltaSize^2.
float waveCoef;
float deltaSize;
float deltaTime;
// Drives the position of the moving wave source in the AddWave kernel.
float time;

// Resets the simulation state: writes (0, 0) to the texel addressed by this
// thread, clearing both channels of waveTexture.
[numthreads(8, 8, 1)]
void Initialize(uint3 dispatchThreadId : SV_DispatchThreadID)
{
	waveTexture[dispatchThreadId.xy] = float2(0, 0);
}

// Adds a small bump to the height channel (.x) around a source point that
// moves on a circle of radius 0.7 in normalized [-1, 1] texture coordinates.
// Texels farther than 0.05 from the source receive no contribution.
[numthreads(8, 8, 1)]
void AddWave(uint3 dispatchThreadId : SV_DispatchThreadID)
{
	float width, height;
	waveTexture.GetDimensions(width, height);

	// Ignore threads that fall outside the texture when the dispatch size
	// is rounded up to a whole number of thread groups.
	if (dispatchThreadId.x >= width || dispatchThreadId.y >= height)
	{
		return;
	}

	// Map the texel index to [-1, 1) on both axes.
	float2 position = float2(
		(dispatchThreadId.x / width) * 2.0 - 1.0,
		(dispatchThreadId.y / height) * 2.0 - 1.0);

	// Current source position on the circle.
	float2 center = 0.7 * float2(cos(time * 0.5), sin(time * 0.5));

	float r = length(center - position);

	// Square-root falloff inside the 0.05 radius, zero outside it.
	float h = 5.0 * sqrt(max(0.05 - r, 0.0));
	waveTexture[dispatchThreadId.xy] += float2(h, 0);
}

// Advances one texel by one time step using an explicit finite-difference
// scheme: next = 2 * current - previous + a * (sum of 4 neighbours - 4 * current)
// minus a damping term proportional to (current - previous).
// .x holds the current height and .y the previous one; the result is stored
// as (next, current).
//
// NOTE(review): neighbours are read from the same texture that other threads
// of this dispatch are writing to, so a neighbour may already hold its
// next-step value. Removing that hazard needs separate read/write textures
// bound by the host script; it is not changed here.
[numthreads(8, 8, 1)]
void Update(uint3 dispatchThreadId : SV_DispatchThreadID)
{
	uint width, height;
	waveTexture.GetDimensions(width, height);

	uint2 id = dispatchThreadId.xy;

	// Ignore threads that fall outside the texture when the dispatch size
	// is rounded up to a whole number of thread groups.
	if (id.x >= width || id.y >= height)
	{
		return;
	}

	float2 wave = waveTexture[id];
	float a = (deltaTime * deltaTime * waveCoef * waveCoef) / (deltaSize * deltaSize);

	// At the border the missing neighbour is replaced by the texel's own
	// height, so it contributes nothing to the difference term.
	float left  = (id.x != 0)         ? waveTexture[uint2(id.x - 1, id.y)].x : wave.x;
	float right = (id.x + 1 < width)  ? waveTexture[uint2(id.x + 1, id.y)].x : wave.x;
	float down  = (id.y != 0)         ? waveTexture[uint2(id.x, id.y - 1)].x : wave.x;
	float up    = (id.y + 1 < height) ? waveTexture[uint2(id.x, id.y + 1)].x : wave.x;

	float h = 2.0 * wave.x - wave.y
		+ a * (left + right + down + up - 4.0 * wave.x)
		- 0.1 * deltaTime * (wave.x - wave.y);

	waveTexture[id] = float2(h, wave.x);
}

// Converts the height stored in waveTexture.x into a colour: 0 maps to opaque
// black, 1 maps to opaque cyan, and values outside [0, 1] are clamped.
[numthreads(8, 8, 1)]
void Draw(uint3 dispatchThreadId : SV_DispatchThreadID)
{
	drawTexture[dispatchThreadId.xy] = lerp(
		float4(0, 0, 0, 1),
		float4(0, 1, 1, 1),
		saturate(waveTexture[dispatchThreadId.xy].x)
	);
}

プログラムの構成は以前書いたライフゲームのCompute Shader実装と同じようになっています。




