2
5

4次元サイン関数のアニメーション。

Posted at

スクリーンショット 2024-08-30 045807.png

スクリーンショット 2024-08-30 045820.png

スクリーンショット 2024-08-30 045832.png

4次元サイン関数をアニメーションさせ、各グリッドの値を色で表現してます。このコードでは、TensorFlow.jsを使用して並列化された計算を行い、Three.jsで3Dのヒートマップを表示します。

<!DOCTYPE html>
<html lang="ja">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>3D Heatmap Animation with TensorFlow.js</title>
    <style>
        body { margin: 0; }
        canvas { display: block; }
    </style>
</head>
<body>
    <!-- three.jsとTensorFlow.jsのCDNを読み込み -->
    <script src="https://cdnjs.cloudflare.com/ajax/libs/three.js/r128/three.min.js"></script>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script>
    <script>
        // シーン、カメラ、レンダラーの設定
        const scene = new THREE.Scene();
        const camera = new THREE.PerspectiveCamera(75, window.innerWidth / window.innerHeight, 0.1, 1000);
        const renderer = new THREE.WebGLRenderer();
        renderer.setSize(window.innerWidth, window.innerHeight);
        document.body.appendChild(renderer.domElement);

        // グリッドのサイズと波長、振幅
        const gridSize = 50;
        const wavelength = 10;
        const amplitude = 1;
        let time = 0;

        // 平面ジオメトリを作成
        const geometry = new THREE.PlaneGeometry(gridSize, gridSize, gridSize - 1, gridSize - 1);
        const material = new THREE.MeshBasicMaterial({ vertexColors: true });
        const plane = new THREE.Mesh(geometry, material);
        scene.add(plane);

        // カメラの位置を設定
        camera.position.z = 30;
        camera.position.y = 20;
        camera.lookAt(0, 0, 0);

        // 三辺数関数を使用してサイン波を計算する関数
        function updateTriangularWave(time) {
            const positionAttribute = geometry.getAttribute('position');
            const colorAttribute = geometry.getAttribute('color');
            const positions = positionAttribute.array;
            const colors = colorAttribute.array;

            // 並列化された計算
            tf.tidy(() => {
                const x = tf.linspace(-wavelength, wavelength, gridSize);
                const y = tf.linspace(-wavelength, wavelength, gridSize);
                const xGrid = tf.tile(x.reshape([gridSize, 1]), [1, gridSize]);
                const yGrid = tf.tile(y.reshape([1, gridSize]), [gridSize, 1]);

                // 三辺数関数とサイン波の組み合わせでZ軸の値を計算
                const zGrid = tf.mul(tf.sin(tf.mul(tf.add(xGrid, yGrid), time)), amplitude);

                // 色の計算
                const colorGrid = tf.div(tf.add(zGrid, 1), 2);

                const zValues = zGrid.arraySync();
                const colorValues = colorGrid.arraySync();

                let i = 0;
                for (let j = 0; j < gridSize; j++) {
                    for (let k = 0; k < gridSize; k++) {
                        positions[i + 2] = zValues[j][k];
                        const colorIntensity = colorValues[j][k];
                        colors[i] = colorIntensity; // 赤
                        colors[i + 1] = 1 - colorIntensity; // 緑
                        colors[i + 2] = 0.5; // 青
                        i += 3;
                    }
                }
            });

            positionAttribute.needsUpdate = true;
            colorAttribute.needsUpdate = true;
        }

        // カラー属性の設定
        const colors = new Float32Array(geometry.attributes.position.count * 3);
        geometry.setAttribute('color', new THREE.BufferAttribute(colors, 3));

        // アニメーションループ
        function animate() {
            time += 0.05;
            updateTriangularWave(time);

            plane.rotation.x = -Math.PI / 4;
            plane.rotation.z = time / 2;

            renderer.render(scene, camera);
            requestAnimationFrame(animate);
        }

        // アニメーション開始
        animate();

        // ウィンドウサイズ変更時にレンダラーをリサイズ
        window.addEventListener('resize', () => {
            const width = window.innerWidth;
            const height = window.innerHeight;
            renderer.setSize(width, height);
            camera.aspect = width / height;
            camera.updateProjectionMatrix();
        });
    </script>
</body>
</html>

tf.tidy()は、TensorFlow.jsにおけるメモリ管理を効率化するための重要な関数です。この関数を使用することで、不要になったテンソル(Tensor)を自動的に解放し、メモリリークを防ぎ、計算を効率化することができます。

tf.tidy()の役割
TensorFlow.jsでは、テンソルを生成するたびに、そのテンソルがメモリに保持されます。計算が複雑になればなるほど、多くのテンソルが生成され、その結果、メモリ消費が増大します。特にGPUを使用する場合、メモリの効率的な管理が重要です。tf.tidy()は、このメモリ管理を自動化するツールです。

使用方法
tf.tidy()は、コールバック関数を引数として取り、そのコールバック内で生成されたテンソルを一度計算が終了した後に自動的に破棄します。これにより、メモリ消費を最小限に抑えます。

基本的な使用例


const result = tf.tidy(() => {
  const a = tf.tensor1d([1, 2, 3]);
  const b = tf.tensor1d([4, 5, 6]);
  
  // 'c'はこの計算の結果として一時的に生成される
  const c = a.add(b);

  // この関数内で生成されたテンソルは'tf.tidy'の外に出ると自動的に破棄される
  return c;
});

// 'result'は破棄されずに残ります
result.print();  // [5, 7, 9]

具体的な動作

テンソルの作成:

tf.tidy()の内部で作成されたテンソルは、計算が終わると同時に自動的に解放されます。

メモリ管理:

メモリを効率的に解放することで、GPUのメモリ不足などの問題を回避し、計算のパフォーマンスを向上させます。

利用シーン
繰り返し計算: ループの中で大量のテンソルを生成する場合、tf.tidy()を使うことで、各ループで生成されるテンソルが自動的に解放され、メモリリークを防げます。

2
5
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
5