【Gaussian Splatting】NianticのSPZファイルをPythonで読み込む【Scaniverse】

Last updated at Posted at 2024-11-07


先日、ポケモンGOで有名なNiantic(界隈では定番の3DスキャンアプリScaniverseの開発元ですね)から、Gaussian Splattingを効率的に保存するためのファイル形式であるSPZの仕様と、その読み書き用C++ライブラリ(nianticlabs/spz)が公開されました。本記事では、そのC++実装を参考にして、SPZファイルをPythonで読み込むコードを紹介します。




import struct
import gzip
import numpy as np
from dataclasses import dataclass
from typing import List, Optional

@dataclass(eq=False)
class GaussianCloud:
    """Decoded Gaussian splat data as produced by the SPZ loader.

    Array attributes stay ``None`` until a loader fills them in.  Shapes use
    N = ``num_points``.  ``eq=False`` keeps identity-based equality and
    hashing: element-wise ``==`` on numpy arrays is ambiguous as a bool.
    """

    num_points: int = 0
    sh_degree: int = 0
    antialiased: bool = False
    positions: Optional[np.ndarray] = None  # (N, 3)
    scales: Optional[np.ndarray] = None     # (N, 3)
    rotations: Optional[np.ndarray] = None  # (N, 4), components ordered (x, y, z, w)
    alphas: Optional[np.ndarray] = None     # (N,)
    colors: Optional[np.ndarray] = None     # (N, 3)
    sh: Optional[np.ndarray] = None         # (N, sh_dim, 3)

@dataclass(eq=False)
class PackedGaussiansHeader:
    """Fields of the SPZ file header.

    The loader unpacks ``magic`` and ``version`` as little-endian uint32s
    (``'<II'``).  It then unpacks ``num_points`` as a uint32, followed by
    ``sh_degree``, ``fractional_bits``, ``flags`` and ``reserved`` as one
    byte each (``'<IBBBB'``).
    """

    magic: int = 0x5053474e  # the bytes b"NGSP" read as a little-endian uint32
    version: int = 2
    num_points: int = 0
    sh_degree: int = 0
    fractional_bits: int = 0  # fixed-point positions are scaled by 1 / 2**fractional_bits
    flags: int = 0            # bit 0 is read by the loader as the "antialiased" flag
    reserved: int = 0

def dim_for_degree(degree: int) -> int:
    """Return the number of SH coefficients stored per colour channel.

    For degree ``d`` this is ``(d + 1) ** 2 - 1``, i.e. 0, 3, 8 and 15 for
    degrees 0 to 3.  The degree-0 term is excluded; SPZ stores it separately
    as the per-point colour.

    Raises:
        ValueError: if ``degree`` is outside the supported range 0..3.
    """
    if not 0 <= degree <= 3:
        raise ValueError(f"Unsupported SH degree: {degree}")
    return (degree + 1) ** 2 - 1

def unquantize_sh(x: np.ndarray) -> np.ndarray:
    """Map quantized SH coefficient bytes back to float32 values.

    Byte value 128 maps to 0.0 and each step is 1/128, so uint8 input lands
    in [-1.0, 127/128].
    """
    as_float = x.astype(np.float32)
    centered = as_float - 128.0
    return centered / 128.0

def inv_sigmoid(x: np.ndarray) -> np.ndarray:
    """Return the logit ``log(x / (1 - x))``, the inverse of the logistic sigmoid.

    Inputs are clipped to ``[1e-7, 1 - 1e-7]`` first, so 0 and 1 map to large
    finite values instead of -inf/+inf and the division never hits zero.
    """
    lower = 1e-7
    upper = 1 - 1e-7
    p = np.clip(x, lower, upper)
    odds = p / (1.0 - p)
    return np.log(odds)

def load_spz(filename: str) -> "Optional[GaussianCloud]":
    """Load a Gaussian Splat from an SPZ file on disk.

    Args:
        filename: Path to a gzip-compressed ``.spz`` file.

    Returns:
        The decoded cloud from :func:`load_spz_from_bytes`, or ``None`` if the
        file cannot be read or its contents fail to decode.  Failures are
        reported on stdout with an ``[SPZ ERROR]`` prefix rather than raised.
    """
    # Keep the guarded region to the file I/O; decoding errors are already
    # caught and reported by load_spz_from_bytes itself.
    try:
        with open(filename, 'rb') as f:
            data = f.read()
    except Exception as e:  # best-effort API boundary: report and return None
        print(f"[SPZ ERROR] Failed to load {filename}: {e}")
        return None
    return load_spz_from_bytes(data)

_SPZ_MAGIC = 0x5053474e  # b"NGSP" read as a little-endian uint32
# Header: magic, version, num_points (uint32 each) + sh_degree,
# fractional_bits, flags, reserved (uint8 each) = 16 bytes for versions 1 and 2.
_SPZ_HEADER_SIZE = 16


def _parse_spz_header(buf: bytes) -> "Optional[PackedGaussiansHeader]":
    """Validate and decode the SPZ header; print an error and return None on failure."""
    if len(buf) < 8:
        print("[SPZ ERROR] Data too short for header")
        return None

    header = PackedGaussiansHeader()
    header.magic, header.version = struct.unpack_from('<II', buf, 0)
    if header.magic != _SPZ_MAGIC:  # "NGSP"
        print(f"[SPZ ERROR] Invalid magic number: {header.magic:08x}")
        return None
    if header.version not in (1, 2):
        print(f"[SPZ ERROR] Unsupported version: {header.version}")
        return None

    # Both versions share the same header layout; they only differ in how the
    # position section is encoded.
    if len(buf) < _SPZ_HEADER_SIZE:
        print(f"[SPZ ERROR] Data too short for version {header.version} header")
        return None
    (header.num_points, header.sh_degree, header.fractional_bits,
     header.flags, header.reserved) = struct.unpack_from('<IBBBB', buf, 8)
    return header


def _decode_positions(raw, uses_float16: bool, fractional_bits: int) -> np.ndarray:
    """Decode the packed position section into an (N, 3) float32 array."""
    if uses_float16:
        # Version 1: little-endian half-precision floats, three per point.
        halves = np.frombuffer(raw, dtype='<f2')
        return halves.astype(np.float32).reshape(-1, 3)

    # Version 2: each coordinate is a 24-bit little-endian two's-complement
    # fixed-point value.  Widen the bytes to int32 *before* shifting so the
    # shifts cannot overflow the uint8 dtype.
    b = np.frombuffer(raw, dtype=np.uint8).reshape(-1, 3).astype(np.int32)
    fixed = b[:, 0] | (b[:, 1] << 8) | (b[:, 2] << 16)
    # Sign-extend from 24 bits: values with bit 23 set are negative.
    fixed = np.where(fixed & 0x800000, fixed - (1 << 24), fixed)
    scale = np.float32(1.0 / (1 << fractional_bits))
    return (fixed.astype(np.float32) * scale).reshape(-1, 3)


def load_spz_from_bytes(data: bytes) -> Optional[GaussianCloud]:
    """Load a Gaussian Splat from the raw (gzip-compressed) bytes of an SPZ file.

    The decompressed payload is a header followed by per-attribute sections
    in the order: positions, alphas, colors, scales, rotations, SH.

    Args:
        data: Contents of an ``.spz`` file, still gzip-compressed.

    Returns:
        The decoded :class:`GaussianCloud`, or ``None`` if decompression,
        header validation or parsing fails.  Errors are reported on stdout
        with an ``[SPZ ERROR]`` prefix (plus a traceback on stderr for
        unexpected parse failures) rather than raised.
    """
    try:
        decompressed = gzip.decompress(data)
    except Exception as e:
        print(f"[SPZ ERROR] Failed to decompress data: {e}")
        return None

    try:
        header = _parse_spz_header(decompressed)
        if header is None:
            return None

        n = header.num_points
        uses_float16 = header.version == 1
        antialiased = (header.flags & 0x1) != 0
        sh_dim = dim_for_degree(header.sh_degree)

        print(f"[SPZ DEBUG] Version: {header.version}")
        print(f"[SPZ DEBUG] Num points: {header.num_points}")
        print(f"[SPZ DEBUG] SH degree: {header.sh_degree}")
        print(f"[SPZ DEBUG] SH dim: {sh_dim}")
        print(f"[SPZ DEBUG] Uses float16: {uses_float16}")

        # Section sizes in bytes, listed in on-disk order.
        layout = (
            ("positions", n * 3 * (2 if uses_float16 else 3)),
            ("alphas", n),
            ("colors", n * 3),
            ("scales", n * 3),
            ("rotations", n * 3),
            ("sh", n * sh_dim * 3),
        )
        expected_size = _SPZ_HEADER_SIZE + sum(size for _, size in layout)
        if len(decompressed) < expected_size:
            print(f"[SPZ ERROR] Data too short. Expected {expected_size} bytes, got {len(decompressed)}")
            return None

        # Zero-copy slices of each section; every decoded array below is a
        # fresh copy, so nothing returned aliases the decompressed buffer.
        view = memoryview(decompressed)
        sections = {}
        offset = _SPZ_HEADER_SIZE
        for name, size in layout:
            sections[name] = view[offset:offset + size]
            offset += size

        # Assign attributes individually so this works whether or not
        # GaussianCloud accepts keyword arguments.
        result = GaussianCloud()
        result.num_points = n
        result.sh_degree = header.sh_degree
        result.antialiased = antialiased

        result.positions = _decode_positions(
            sections["positions"], uses_float16, header.fractional_bits)

        scales = np.frombuffer(sections["scales"], dtype=np.uint8)
        result.scales = (scales.reshape(-1, 3) / 16.0 - 10.0).astype(np.float32)

        # Rotations store x, y, z in [-1, 1]; w is reconstructed as the
        # non-negative value that gives the quaternion unit length.
        rot = np.frombuffer(sections["rotations"], dtype=np.uint8).reshape(-1, 3)
        xyz = rot.astype(np.float32) / 127.5 - 1.0
        w = np.sqrt(np.maximum(0.0, 1.0 - np.sum(xyz * xyz, axis=1)))
        result.rotations = np.column_stack([xyz, w])

        alphas = np.frombuffer(sections["alphas"], dtype=np.uint8)
        result.alphas = inv_sigmoid(alphas.astype(np.float32) / 255.0)

        colors = np.frombuffer(sections["colors"], dtype=np.uint8)
        result.colors = ((colors.reshape(-1, 3).astype(np.float32) / 255.0) - 0.5) / 0.15

        # Coefficients are laid out per point as (sh_dim, 3): channel innermost.
        if sh_dim > 0:
            sh = np.frombuffer(sections["sh"], dtype=np.uint8)
            result.sh = unquantize_sh(sh).reshape(-1, sh_dim, 3)
        else:
            result.sh = None
        return result
    except Exception as e:
        print(f"[SPZ ERROR] Failed to parse data: {e}")
        import traceback
        traceback.print_exc()
        return None



# Example usage: replace the placeholder with the path to a real .spz file.
# Loading prints the [SPZ DEBUG] header summary (example output shown below).
splats = load_spz(filename=r"<path to spz>.spz")
[SPZ DEBUG] Version: 2
[SPZ DEBUG] Num points: 400281
[SPZ DEBUG] SH degree: 3
[SPZ DEBUG] SH dim: 15
[SPZ DEBUG] Uses float16: False


load_spz が返す GaussianCloud は以下の属性を持ちます(N は点の数)。

  • num_points: int
  • sh_degree: int
  • antialiased: bool
  • positions: np.ndarray # (N, 3)
  • scales: np.ndarray # (N, 3)
  • rotations: np.ndarray # (N, 4)
  • alphas: np.ndarray # (N,)
  • colors: np.ndarray # (N, 3)
  • sh: np.ndarray # (N, sh_dim, 3)、SH次数が0の場合は None





