0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

連続最適化アルゴリズムの検証用目的関数を自動生成する

0
Last updated at Posted at 2026-02-11

目的

「最適化のアルゴリズムを検証したいが、良い目的関数を作ることが難しい...」という悩み(俺だけ?)に答えます。
良い検証用の目的関数とは、次のことを満たすものと考えています。

  • 大域最適解だけでなく、複数の局所最適解をもつ
  • 事前に最適解を知ることができる

実際にできたもの

作成方針

今回は上記を満たすような目的関数として、複数のガウス関数を重ね合わせた目的関数を生成するコードを書いてみました。
実際のコードと使用方法は非常に長いので、記事の最後にまとめて記載しています。

要件 実装方法
大域最適解だけでなく、複数の局所最適解をもつ 複数のガウス関数を重ね合わせる。
事前に最適解を知ることができる ユーザが最適解、局所最適解の座標を指定。山の重なりによる最適解シフトは、勾配法によって事前に把握する。

これらを実際に実装して得た目的関数の例が下になります。
わかりやすさのために、2次元で作成しており、5つのガウス関数を重ね合わせています。
左が等高線図、右が3Dの局面になります。
等高線図は黄色いほどピークが大きくなっており、右側の⭐️が最適解の位置です。
また、△は生成された局所最適解を含むピーク位置となっています。
seed56.png

注意点

上の図を使って説明を続けます。

私が指定した最適解は❌(3.0, -1.5)で、実際のもの⭐️とずれています。
これが上表の「山の重なりによる最適解シフト」です。
複数のガウス関数を重ね合わせるので、このような最適解シフトが起こり得て、事前に正確な最適解を把握することは難しいです。
そこで、それぞれの指定したピークの場所⚪️から勾配法によって近隣のピーク位置を求めます。
それらの中から最大値を示す場所を最適解として、ユーザーに提供します。

※それぞれのピーク位置は、指定した座標から大きくずれることはなく、かつ付近の勾配をたどればピークに到達するという仮定を置いています。

目的関数例

これらの目的関数の詳細な形は、乱数によって制御しています。
上の図はseed=56を指定して生成したものです。

下の図に様々な乱数を用いて生成した2次元の目的関数を出します。
seed=14の場合、私が指定した❌の位置(右下)ではない場所に最適解が生成されていることがわかります。
このように、全く違う位置に最適解が生成されてしまう可能性もあるので、事前にチェックしておくことは必要だと思います。

map.png

使い方

以下は Gemini様による説明です。

実装されている GaussianMixtureFunction の詳細な使い方と、その背後にある数理的な動作について解説します。

1. 概要

GaussianMixtureFunction は、複数のガウス関数(正規分布の確率密度関数の形状)を重ね合わせることで、多峰性(複数の山がある)を持つ複雑な最適化問題を作成するためのクラスです。

数式としては、以下のように $N$ 個のガウス関数の和で表されます。

$$
f(x) = \sum_{i=1}^{N} A_i \exp\left( -\sum_{d=1}^{D} \frac{(x_d - \mu_{i,d})^2}{2\sigma_{i,d}^2} \right)
$$

ここで、

  • $A_i$: 各山の高さ(振幅)
  • $\mu_{i,d}$: 各山の中心座標
  • $\sigma_{i,d}$: 各山の広がり(標準偏差)

複数の山を重ね合わせることで、大域的最適解(Global Optimum)以外に、局所最適解(Local Optima)を多数配置し、最適化アルゴリズムを「騙す」ような地形を作ることが可能です。

2. 実装上の重要な特徴「ピークシフト」

このクラスの最も特徴的、かつ注意が必要な点は、「設定したガウス関数の中心 $\mu_i$ と、実際の関数の最大値の座標は微妙にずれる」 ということです。

なぜずれるのか?

複数のガウス関数が近接して配置されると、それぞれの裾野(スロープ)が重なり合います。ある山の頂点付近において、隣の山のスロープが加算されるため、結果として合成された関数の頂点は、元のガウス関数の中心からわずかに移動します。

コードによる解決策

GaussianMixtureFunction は、初期化時(__init__)に以下の処理を自動的に行い、この問題を解決しています。

  1. ユーザー指定の配置: ユーザーが optimal_solution(大域最適解の希望位置)を指定すると、その位置に最も振幅の大きいガウス関数を配置します。
  2. 全ピークの微修正(Refinement): 重ね合わせによって頂点がずれることを前提とし、各ガウス関数の中心座標を初期値として勾配法(L-BFGS)を実行し、その付近での 「真の極大値」 を探索します。
  3. 大域最適解の特定: 探索で見つかったすべてのピークの中から、最も関数値が高いものを真の大域最適解として特定します。
  4. 情報の更新:
    • 見つかったすべての極大値の座標は refined_centers に、その値は refined_values に保存されます。
    • その中で最も高いピークの座標と値が optimal_solution および optimal_value として設定されます。

したがって、このクラスを使用する際は、「コンストラクタに渡した座標と、インスタンスが保持する refined_centers(ひいては optimal_solution)は微小に異なる場合がある」 ことを理解しておく必要があります。また、各ピークの微修正の結果、意図した大域最適解よりも、他のピークの方が高くなる(大域最適解が入れ替わる)可能性も考慮し、自動的に検証・設定されます。

3. 基本的な使い方

3.1 インスタンスの作成

最もシンプルな使い方は、次元数と山の数を指定してランダムに生成する方法です。

import torch
from verification_objective_function.gaussian_mixture import GaussianMixtureFunction

# 2次元、5つの山を持つ関数を作成
func = GaussianMixtureFunction(
    n_dim=2,
    n_peaks=5,
    seed=42  # 再現性のためシードを指定可能
)

# 生成された真の最適解を確認
print(f"真の最適解の座標: {func.optimal_solution}")
print(f"真の最適解の値: {func.optimal_value}")

# 全てのピーク(ローカル解)の正確な位置を確認
print(f"全ピークの微修正後座標: {func.refined_centers}")
print(f"全ピークの値: {func.refined_values}")

3.2 大域最適解と局所解(トラップ)の指定

検証用関数として、「正解(大域最適解)」と「ひっかけ(局所最適解)」の位置をコントロールしたい場合は、以下のように引数を指定します。

# 大域最適解の希望位置(ここ付近に最大ピークができる)
opt_sol = torch.tensor([3.0, 3.0])

# 局所最適解(トラップ)の希望位置
local_sols = torch.tensor([
    [-2.0, -2.0],
    [ 2.0, -2.0]
])

func = GaussianMixtureFunction(
    n_dim=2,
    n_peaks=5,  # 全体で5個(指定した3個 + ランダム2個)
    optimal_solution=opt_sol,
    optimal_value=10.0,            # 最大値の目安
    local_optima_solutions=local_sols,
    local_optima_values=[8.0, 6.0] # トラップの高さの目安
)

注意:

  • optimal_valuelocal_optima_values で高さを指定しても、近隣の山との干渉により、実際の関数の値はこれよりわずかに大きくなることがあります。
  • クラス内部で、振幅パラメータ上は大域最適解が最も高くなるように自動調整されますが、最終的な地形のプロパティは func.optimal_solutionfunc.refined_values で確認してください。

3.3 関数の評価

作成したインスタンスは関数として呼び出すことができます。PyTorchのTensorを入力として受け取ります。

# 1点の評価
x = torch.tensor([0.0, 0.0])
val = func(x)
print(val)

# バッチ評価(複数の点を一度に計算)
# 形状: (バッチサイズ, 次元数)
inputs = torch.tensor([
    [0.0, 0.0],
    [1.0, 1.0],
    [3.0, 3.0]
])
vals = func(inputs)
print(vals) # -> shape: (3,)

4. 可視化の例

2次元関数の場合、等高線図などで地形を確認することができます。

import matplotlib.pyplot as plt

# グリッドデータの作成
coords = torch.linspace(-5, 5, 100)
X, Y = torch.meshgrid(coords, coords, indexing="ij")
grid = torch.stack([X.flatten(), Y.flatten()], dim=1)

# 値の計算
with torch.no_grad():
    Z = func(grid).view(100, 100)

# プロット
plt.contourf(X.numpy(), Y.numpy(), Z.numpy(), levels=20)
# 真の最適解をプロット
plt.scatter(
    func.optimal_solution[0], 
    func.optimal_solution[1], 
    c='red', marker='*', s=200, label='Global Optimum'
)
plt.legend()
plt.show()

5. まとめ

  • GaussianMixtureFunction は、複数のガウス分布を重ねて複雑な地形を作ります。
  • ピークシフト: 重ね合わせの影響で、ピーク位置はガウス関数の中心からわずかにずれます。
  • 自動補正: クラス初期化時に勾配法を用いて、すべてのピークの「真の位置(修正後の中心)」を特定します。その結果は refined_centers に保存されます。
  • 大域最適解の特定: 修正された各ピークの中から、最も関数値が高いものを自動的に大域最適解として特定し、func.optimal_solution に設定します。
  • ユーザーは「だいたいこの辺に解をおきたい」という意図で optimal_solution を指定し、正確な正解値は func.optimal_solution を参照するという使い方が推奨されます。

コード

以下はコードになります。
antigravity (Gemini 3.0)様を用いて実装しています。

環境

pyproject.toml
[project]
name = "verification-objective-functions"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.13"
dependencies = [
    "torch>=2.10.0",
    "numpy>=2.4.2",
]

[dependency-groups]
dev = [
    "matplotlib>=3.10.8",
]

基本的に Torch を使って記述していますが、後述の to_numpy_function 関数を使えば NumPy に変換することが可能です。

以下2つのファイルを同じディレクトリに入れてください。

目的関数抽象クラス

今回作成した目的関数の形だけではなく、その他の様々な目的関数に対して拡張性を持たせたいので、抽象クラスBaseObjectiveFunctionを作成しています。(そのせいで長い...)

base.py
from abc import ABC, abstractmethod
from typing import Tuple, Optional, Union, Callable
import torch


class BaseObjectiveFunction(ABC):
    """
    Abstract base class for objective functions used in optimization benchmarks.
    """

    def __init__(
        self,
        dim: int,
        bounds: Optional[Tuple[float, float]] = None,
        device: Union[str, torch.device] = "cpu",
        dtype: torch.dtype = torch.float32,
        optimal_solution: Optional[torch.Tensor] = None,
        optimal_value: Optional[float] = None,
    ):
        """
        Args:
            dim (int): Dimension of the input space.
            bounds (Optional[Tuple[float, float]]): The range (min, max) for the function domain.
            device (Union[str, torch.device]): Device to store tensors on.
            dtype (torch.dtype): Data type for tensors.
            optimal_solution (Optional[torch.Tensor]): The coordinates of the global optimum.
            optimal_value (Optional[float]): The value of the global optimum.
        """
        self.dim = dim
        self.bounds = bounds
        self.device = device
        self.dtype = dtype
        self.optimal_solution = optimal_solution
        self.optimal_value = optimal_value

    @abstractmethod
    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """
        Evaluate the objective function at points x.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, dim) or (dim,).

        Returns:
            torch.Tensor: Output tensor of shape (batch_size,) or scalar.
        """
        pass

    def refine_solution(
        self,
        initial_guess: torch.Tensor,
        lr: float = 0.1,  # LBFGS default lr
        steps: int = 20,
        optimizer_cls: Optional[torch.optim.Optimizer] = None,
    ) -> Tuple[torch.Tensor, float]:
        """
        Refines the optimal solution near the initial guess using gradient descent.
        This helps correct shifts in the optimum caused by function superposition.

        Args:
            initial_guess (torch.Tensor): The starting point for optimization.
            lr (float): Learning rate for the optimizer.
            steps (int): Number of optimization steps.
            optimizer_cls (Optional[torch.optim.Optimizer]): Optimizer class to use (default: LBFGS).

        Returns:
            Tuple[torch.Tensor, float]: The refined optimal solution and its value.
        """
        # Ensure initial_guess is a leaf tensor requiring grad
        x_param = (
            initial_guess.clone()
            .detach()
            .to(self.device, self.dtype)
            .requires_grad_(True)
        )

        if optimizer_cls is None:
            # LBFGS is usually good for finding precise local optima in smooth functions
            optimizer = torch.optim.LBFGS(
                [x_param], lr=lr, max_iter=20, history_size=100
            )
        else:
            optimizer = optimizer_cls([x_param], lr=lr)

        def closure():
            optimizer.zero_grad()
            # We want to MAXIMIZE the function, so we MINIMIZE the negative
            val = self(x_param)
            loss = -val
            loss.backward()
            return loss

        # Optimization loop
        for _ in range(steps):
            # LBFGS step requires the closure
            optimizer.step(closure)

        # Final evaluation
        with torch.no_grad():
            final_val = self(x_param).item()
            final_x = x_param.detach()

        return final_x, final_val

    def to(
        self,
        device: Optional[Union[str, torch.device]] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """
        Move the objective function parameters to the specified device and dtype.

        Args:
            device (Optional[Union[str, torch.device]]): Target device.
            dtype (Optional[torch.dtype]): Target data type.

        Returns:
            self
        """
        if device is not None:
            self.device = device
        if dtype is not None:
            self.dtype = dtype

        # Subclasses should override this if they have tensors to move,
        # but they should call super().to(device, dtype) first.
        return self

    def to_numpy_function(self) -> Callable:
        """
        Returns a callable version of this objective function that works with NumPy arrays.

        The returned function handles conversion between NumPy arrays and PyTorch tensors,
        allowing seamless integration with NumPy-based tools (e.g., SciPy).

        Returns:
            Callable[[Union[np.ndarray, float, list]], Union[np.ndarray, float]]:
                A wrapper function that inputs and outputs NumPy arrays or scalars.
        """
        try:
            import numpy as np
        except ImportError:
            raise ImportError("NumPy is required to use to_numpy_function()")

        def wrapper(x):
            # Convert input to tensor with correct device and dtype
            # as_tensor handles numpy arrays, lists, and scalars robustly
            x_tensor = torch.as_tensor(x, device=self.device, dtype=self.dtype)

            # Evaluate (gradient disabled as NumPy doesn't support it)
            with torch.no_grad():
                result_tensor = self(x_tensor)

            # Convert result back to NumPy
            # .cpu().numpy() handles both scalar and array tensors.
            return result_tensor.cpu().numpy()

        return wrapper

複数のガウス関数による目的関数

gaussian_mixture.py
import torch
from typing import Tuple, Optional, Union
from .base import BaseObjectiveFunction


class GaussianMixtureFunction(BaseObjectiveFunction):
    """
    A callable function representing a superposition of multiple Gaussian functions.
    This function is useful as an objective function for optimization benchmarks.

    The function is defined as:
    f(x) = sum_{i=1}^{n_peaks} A_i * exp( -sum_{d=1}^{n_dim} (x_d - mu_{i,d})^2 / (2 * sigma_{i,d}^2) )

    Attributes:
        refined_centers (torch.Tensor): The actual locations of the local maxima found by gradient ascent, starting from the Gaussian centers.
        refined_values (torch.Tensor): The function values at the refined_centers.
    """

    def __init__(
        self,
        n_dim: int,
        n_peaks: int,
        bounds: Tuple[float, float] = (-10.0, 10.0),
        magnitude_range: Tuple[float, float] = (1.0, 10.0),
        std_dev_range: Tuple[float, float] = (1.0, 5.0),
        seed: Optional[int] = None,
        device: Union[str, torch.device] = "cpu",
        dtype: torch.dtype = torch.float32,
        optimal_solution: Optional[torch.Tensor] = None,
        optimal_value: Optional[float] = None,
        local_optima_solutions: Optional[torch.Tensor] = None,
        local_optima_values: Optional[torch.Tensor] = None,
    ):
        """
        Args:
            n_dim (int): Dimension of the input space.
            n_peaks (int): Number of Gaussian peaks to superimpose.
            bounds (Tuple[float, float]): The range (min, max) for center initialization per dimension (and function domain).
            magnitude_range (Tuple[float, float]): The range (min, max) for the absolute magnitude of peaks.
            std_dev_range (Tuple[float, float]): The range (min, max) for the standard deviation of peaks.
            seed (Optional[int]): Random seed for reproducibility.
            device (Union[str, torch.device]): Device to store tensors on.
            dtype (torch.dtype): Data type for tensors.
            optimal_solution (Optional[torch.Tensor]): User-specified location for the global optimum.
            optimal_value (Optional[float]): User-specified value for the global optimum.
            local_optima_solutions (Optional[torch.Tensor]): User-specified locations for other peaks (local optima traps).
            local_optima_values (Optional[torch.Tensor]): User-specified values for other peaks.
        """
        super().__init__(
            dim=n_dim,
            bounds=bounds,
            device=device,
            dtype=dtype,
            optimal_solution=optimal_solution,
            optimal_value=optimal_value,
        )

        if seed is not None:
            torch.manual_seed(seed)

        self.n_peaks = n_peaks

        # New arguments handling:
        # We need to process them after random initialization to override parts of it.
        # But we must validate them first or during processing.
        # However, they are not passed to super().__init__ so we just use them in body.

        # Initialize centers uniformly within bounds
        low, high = bounds
        self.centers = (
            torch.rand(n_peaks, n_dim, device=device, dtype=dtype) * (high - low) + low
        )

        # Override first center if optimal_solution is specified
        if optimal_solution is not None:
            if isinstance(optimal_solution, list):
                optimal_solution = torch.tensor(
                    optimal_solution, device=device, dtype=dtype
                )
            else:
                optimal_solution = optimal_solution.to(device=device, dtype=dtype)

            if optimal_solution.numel() != n_dim:
                raise ValueError(f"optimal_solution must have {n_dim} elements")

            # Place the first center at the optimal solution
            self.centers[0] = optimal_solution

        # Override subsequent centers if local_optima_solutions is specified
        if local_optima_solutions is not None:
            if isinstance(local_optima_solutions, list):
                local_optima_solutions = torch.tensor(
                    local_optima_solutions, device=device, dtype=dtype
                )
            else:
                local_optima_solutions = local_optima_solutions.to(
                    device=device, dtype=dtype
                )

            # Check shape: (n_local_optima, n_dim)
            if (
                local_optima_solutions.dim() != 2
                or local_optima_solutions.shape[1] != n_dim
            ):
                raise ValueError(f"local_optima_solutions must be shape (n, {n_dim})")

            n_local = local_optima_solutions.shape[0]
            if n_local + 1 > n_peaks:
                raise ValueError(
                    f"Total specified peaks (1 global + {n_local} local) exceeds n_peaks={n_peaks}"
                )

            self.centers[1 : 1 + n_local] = local_optima_solutions

        # Initialize standard deviations (diagonal covariance) uniformly
        std_low, std_high = std_dev_range
        self.std_devs = (
            torch.rand(n_peaks, n_dim, device=device, dtype=dtype)
            * (std_high - std_low)
            + std_low
        )

        # Initialize magnitudes (amplitudes)
        mag_low, mag_high = magnitude_range
        magnitudes = (
            torch.rand(n_peaks, device=device, dtype=dtype) * (mag_high - mag_low)
            + mag_low
        )

        # Random signs (-1 or 1)
        signs = (torch.randint(0, 2, (n_peaks,), device=device) * 2 - 1).to(dtype)

        self.amplitudes = signs * magnitudes

        # --- Configure Global Optimum at Index 0 ---

        # 1. Set Amplitude for Global Optimum
        if optimal_value is not None:
            # User specified exact value
            self.amplitudes[0] = optimal_value
        else:
            # Auto-selection: Find the "best" amplitude among generated and move it to index 0
            best_idx = torch.argmax(self.amplitudes)

            if best_idx != 0:
                # Swap amplitudes to put the best one at index 0
                # We don't swap centers or std_devs, effectively assigning the best amplitude
                # to the center at index 0 (which might be user-specified)
                temp = self.amplitudes[0].clone()
                self.amplitudes[0] = self.amplitudes[best_idx]
                self.amplitudes[best_idx] = temp

        # 2. Assign Local Optima Amplitudes if specified
        if local_optima_values is not None:
            # This requires local_optima_solutions to be specified to make sense directionally,
            # but strictly speaking we just assign to indices 1, 2, ...

            if isinstance(local_optima_values, list):
                local_optima_values = torch.tensor(
                    local_optima_values, device=device, dtype=dtype
                )
            else:
                local_optima_values = local_optima_values.to(device=device, dtype=dtype)

            n_local_vals = local_optima_values.numel()

            if local_optima_solutions is None:
                raise ValueError(
                    "local_optima_values requires local_optima_solutions to be specified"
                )

            n_local_sol = local_optima_solutions.shape[0]
            if n_local_vals > n_local_sol:
                raise ValueError(
                    f"Number of local_optima_values ({n_local_vals}) exceeds number of local_optima_solutions ({n_local_sol})"
                )

            if n_local_vals + 1 > n_peaks:
                raise ValueError(
                    "Total specified values (1 global + local) exceeds n_peaks"
                )

            self.amplitudes[1 : 1 + n_local_vals] = local_optima_values

        # 3. Enforce Strict Optimality (Traps must be strictly worse)
        if n_peaks > 1:
            epsilon = 1e-2  # Minimum gap
            others = self.amplitudes[1:]

            # For maximization, A[0] must be strictly greater than all others
            allowed_max = self.amplitudes[0] - epsilon
            # Clamp others to be at most allowed_max
            # Clamp others to be at most allowed_max
            self.amplitudes[1:] = torch.min(others, allowed_max.detach())

        # --- 4. Refine All Peaks ---
        # Due to superposition, the peak of the mixture may shift slightly from the center of the Gaussian.
        # We perform a local optimization starting from EACH Gaussian center to find the *true* local maxima.

        self.refined_centers = self.centers.clone()
        self.refined_values = torch.zeros(n_peaks, device=device, dtype=dtype)

        for i in range(n_peaks):
            initial_guess = self.centers[i]
            # Use the base class method to find the precise peak
            # Note: We use a fresh optimizer for each peak to avoid history carry-over
            refined_x, refined_val = self.refine_solution(initial_guess)
            self.refined_centers[i] = refined_x
            self.refined_values[i] = refined_val

        # Update the attributes to reflect the TRUE global optimum found among all refined peaks.
        # Even though we intended index 0 to be the global max, superposition might have made another peak higher.
        best_idx = torch.argmax(self.refined_values)
        self.optimal_solution = self.refined_centers[best_idx]
        self.optimal_value = self.refined_values[best_idx].item()

    def to(
        self,
        device: Optional[Union[str, torch.device]] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """
        Move internal tensors to the specified device and/or dtype.
        """
        super().to(device=device, dtype=dtype)
        if hasattr(self, "centers"):  # Safety check, though initialized in __init__
            self.centers = self.centers.to(device=self.device, dtype=self.dtype)
        if hasattr(self, "std_devs"):
            self.std_devs = self.std_devs.to(device=self.device, dtype=self.dtype)
        if hasattr(self, "amplitudes"):
            self.amplitudes = self.amplitudes.to(device=self.device, dtype=self.dtype)
        if hasattr(self, "refined_centers"):
            self.refined_centers = self.refined_centers.to(
                device=self.device, dtype=self.dtype
            )
        if hasattr(self, "refined_values"):
            self.refined_values = self.refined_values.to(
                device=self.device, dtype=self.dtype
            )
        return self

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """
        Calculate the value of the Gaussian mixture at points x.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, n_dim) or (n_dim,).

        Returns:
            torch.Tensor: Output tensor of shape (batch_size,) or scalar.
        """
        # Ensure input is at least 2D (batch_size, n_dim)
        original_shape = x.shape
        if x.dim() == 1:
            x = x.unsqueeze(0)

        # Ensure x is on correct device/dtype if not already
        # This is a convenience, but calling code should ideally handle this.
        if x.device != self.centers.device or x.dtype != self.centers.dtype:
            x = x.to(device=self.centers.device, dtype=self.centers.dtype)

        # x: (B, D) -> (B, 1, D)
        x_expanded = x.unsqueeze(1)

        # centers: (P, D) -> (1, P, D)
        centers_expanded = self.centers.unsqueeze(0)

        # std_devs: (P, D) -> (1, P, D)
        std_expanded = self.std_devs.unsqueeze(0)

        # Calculate exponent: -0.5 * sum( ((x - mu)/sigma)^2 )
        # (B, 1, D) - (1, P, D) -> (B, P, D)
        diff = x_expanded - centers_expanded
        squared_normalized_diff = (diff / std_expanded).pow(2)
        exponent = -0.5 * squared_normalized_diff.sum(dim=-1)  # (B, P)

        # Calculate Gaussians
        gaussians = torch.exp(exponent)  # (B, P)

        # Weighted sum
        # amplitudes: (P,) -> (1, P)
        weighted_sum = (gaussians * self.amplitudes.unsqueeze(0)).sum(dim=-1)  # (B,)

        if len(original_shape) == 1:
            return weighted_sum.squeeze(0)

        return weighted_sum

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?