0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

FSLのFASTアルゴリズムを概念的に実装してみる

Last updated at Posted at 2025-05-23

動機

備忘録として。

できること

画像を指定したクラス数にガウシアンフレームワークでセグメントする。

使ったもの

  • NotebookLM
  • 論文:Zhang Y, Brady M, Smith S. Segmentation of brain MR images through a hidden Markov random field model and the expectation-maximization algorithm. IEEE Trans Med Imaging. 2001 Jan;20(1):45-57. doi: 10.1109/42.906424. PMID: 11293691.

何がしたかったか

FSLのFASTで利用されているアルゴリズムを概念的に理解して、初歩的なコードに落とし込みたかった。
画像を統計的に扱う上で重要な数学的解釈や表現が詰まっているとても勉強になる論文だと思う。正直、マルコフランダム場、有限混合モデル、など、筆者にとってはあまり馴染みのない概念が論文の中で登場するので筆者には読み解くのがとても難しかった。例えば、解説されている数式が現実世界のどの変数に当てはまるのか、そして、自分で予想したそれは正しいのかどうかを確かめていく作業は、一連の知識を理解するためにトライアンドエラーを繰り返すことを必要とする。
しかし、NotebookLMのおかげで高等数学に弱い私でも概念的に理解できた。

概要

要は、ピクセル1つ1つを操作して、その近傍ピクセルがどのクラスに属しているかの割合を計算し、セグメンテーションラベルを更新していく。セグメンテーションする各クラスの平均と標準偏差を先に与えておくことで、各ピクセルがあるクラスに属する確率を求めるガウシアンアプローチができる。

概念的実装(NotebookLM利用+最終化は筆者)

import numpy as np
import random
import matplotlib.pyplot as plt

# 仮定する関数や変数:
# - y: 観測された画像データ (numpy array)
# - initial_labels: 初期ラベル配置 (numpy array, yと同じ形状)
# - K: クラスの数
# - mu: 各クラスの輝度平均 (リストや配列, サイズK)
# - sigma: 各クラスの輝度標準偏差 (リストや配列, サイズK)

def mrf_map_icm(y, initial_labels, K, mu, sigma, get_neighbors, max_iterations=100):
    """
    MRF-MAP推定のためのICMアルゴリズムの概念的実装
    """
    labels = initial_labels.copy()
    height, width = y.shape # 画像のサイズ (2Dを仮定)
    sites = [(i, j) for i in range(height) for j in range(width)] # 全サイトのリスト

    for iteration in range(max_iterations):
        print(f"Iteration {iteration+1}/{max_iterations}...")
        labels_changed_count = 0

        # サイトをランダムな順序で処理するとより良いことがあるが、ここでは簡単な順序で
        for s in sites:
            r, c = s # サイトの座標

            current_label = labels[r, c]
            min_energy = float('inf')
            best_label = current_label

            # 可能なすべてのクラスラベルを試す
            for k in range(K):
                # サイトsのラベルをkに仮定した場合の局所的な事後エネルギーを計算
                # 事後エネルギー = 事前エネルギー + 尤度エネルギー + const
                # ICMでは、サイトs以外のラベルは固定なので、
                # 変化するのはサイトsに関する尤度エネルギー項と、
                # サイトsを含む事前エネルギー項のみ

                # 尤度エネルギー項(式 17 のサイトsに関する項)
                # GHMRFでガウス分布を仮定 (式 15, 17)
                # E_L_s = (y_s - mu_{x_s})^2 / (2*sigma_{x_s}^2) + log(sqrt(2*pi)*sigma_{x_s})
                # 定数項 log(sqrt(2*pi)) はラベルkに依存しないので省略可能
                if sigma[k] == 0: # 分散がゼロの場合は無限大エネルギー (起こりえないラベル)
                     likelihood_energy_s = float('inf')
                else:
                     likelihood_energy_s = (y[r, c] - mu[k])**2 / (2 * sigma[k]**2) + np.log(sigma[k]) # np.log(np.sqrt(2*np.pi))は定数なので省略

                # 事前エネルギー項の変化分を計算
                # これはサイトsのラベルが現在のcurrent_labelからkに変わった際に
                # 事前エネルギーE(x)がどれだけ変化するかを表す。
                # 実際の計算は、MRFのポテンシャル関数E(x)の定義(式6など)に依存する。
                # 通常、隣接サイトのラベルとの関係に基づいて計算される。
                # ここでは概念的な関数呼び出しとする。
                # 例: 隣接サイトとのラベルが異なる場合にペナルティが加算される場合など
                neighbors_s = get_neighbors(s, height, width) # get_neighbors関数は実装による
                prior_energy_change = calculate_prior_energy_s_neighbors_pairwise(s, current_label, labels, neighbors_s)

                # サイトsのラベルをkとした場合の全体の事後エネルギーの変化または局所エネルギーを計算
                # ここでは、サイトsのラベルをkとした場合の局所的なエネルギーとして計算
                # 局所的事後エネルギーは、サイトsに関する尤度エネルギーと、サイトsを含むクリックの事前エネルギーの合計に比例
                # 簡易的には、サイトsの尤度エネルギー + サイトsと隣接サイトの事前エネルギー項の合計
                # ICM更新では、delta Energy を計算する方が効率的だが、概念的には局所エネルギーを評価する
                # 局所エネルギー = 尤度エネルギー(s) + 事前エネルギー(sと隣接サイト)
                # 事前エネルギー(sと隣接サイト) の計算は calculate_prior_energy_s_neighbors に任せる (実装詳細による)
                local_prior_energy = calculate_prior_energy_s_neighbors_pairwise(s, k, labels, neighbors_s) # サイトsのラベルをkとした場合の局所事前エネルギー計算 (概念)

                local_posterior_energy = likelihood_energy_s + local_prior_energy

                # 最小エネルギーとなるラベルを見つける
                if local_posterior_energy < min_energy:
                    min_energy = local_posterior_energy
                    best_label = k

            # 最適なラベルに更新
            if best_label != current_label:
                labels[r, c] = best_label
                labels_changed_count += 1

        # ラベルが全く変化しなければ収束
        if labels_changed_count == 0:
            print("ICM converged.")
            break

    return labels

# 隣接サイト取得関数 (例: 4近傍)
def get_neighbors_4(s, h, w):
    r, c = s
    neighbors = []
    for dr, dc in [(-1, 0), (1, 0), (0, -1), (0, 1)]:
        nr, nc = r + dr, c + dc
        if 0 <= nr < h and 0 <= nc < w:
            neighbors.append((nr, nc))
    return neighbors

# 局所事前エネルギー計算関数 (例: pairwise MRF, ラベルが一致するとエネルギーが低くなる)
def calculate_prior_energy_s_neighbors_pairwise(s, label_s, labels, neighbors_s, beta=1.0):
    energy = 0
    for nr, nc in neighbors_s:
        if label_s != labels[nr, nc]: # ラベルが一致しない場合にペナルティ
            energy += beta
    return energy

def generate_cow_like_pattern(width, height, brightness_val1, brightness_val2, brightness_val3):
    """
    指定された3つの輝度値を使用して、牛の模様風の画像を生成します。

    Args:
        width (int): 画像の幅 (ピクセル単位)。
        height (int): 画像の高さ (ピクセル単位)。
        brightness_val1 (int): 1つ目の輝度値 (0-255)。
        brightness_val2 (int): 2つ目の輝度値 (0-255)。
        brightness_val3 (int): 3つ目の輝度値 (0-255)。
        filename (str): 保存する画像ファイル名。
    """

    # 輝度値をソートして役割を割り当てます(背景、主たる斑点、副次的な斑点)。
    values = sorted([brightness_val1, brightness_val2, brightness_val3])
    bg_color = values[2]      # 最も明るい値を背景色とします (例: 220)
    spot_color_main = values[0] # 最も暗い値を主要な斑点の色とします (例: 25)
    spot_color_secondary = values[1] # 中間の値を副次的な斑点の色とします (例: 125)

    # 背景色で画像を初期化
    image_array = np.full((height, width), bg_color, dtype=np.uint8)

    # 複数の円を重ねて不規則な「ブロブ」(斑点)を描画する関数
    def draw_blob(arr, center_x, center_y, max_blob_radius, num_circles_per_blob, color_val):
        for _ in range(num_circles_per_blob):
            # ブロブを構成する個々の円の半径
            r = random.uniform(max_blob_radius * 0.3, max_blob_radius)

            # ブロブの中心から少しずらして、より不規則な形状を生成
            offset_x = random.uniform(-max_blob_radius * 0.5, max_blob_radius * 0.5)
            offset_y = random.uniform(-max_blob_radius * 0.5, max_blob_radius * 0.5)

            circ_center_x = center_x + offset_x
            circ_center_y = center_y + offset_y

            # 円のバウンディングボックスを取得してピクセル処理を効率化
            min_x = max(0, int(circ_center_x - r))
            max_x = min(width, int(circ_center_x + r) + 1)
            min_y = max(0, int(circ_center_y - r))
            max_y = min(height, int(circ_center_y + r) + 1)

            for y_coord in range(min_y, max_y):
                for x_coord in range(min_x, max_x):
                    # 円の内側のピクセルであれば色を塗る
                    if (x_coord - circ_center_x)**2 + (y_coord - circ_center_y)**2 < r**2:
                        arr[y_coord, x_coord] = color_val

    # 主要な斑点(暗い色)のパラメータと描画
    num_main_blobs = random.randint(4, 7) # 斑点の数
    for _ in range(num_main_blobs):
        blob_cx = random.randint(0, width - 1)  # 斑点の中心X座標
        blob_cy = random.randint(0, height - 1) # 斑点の中心Y座標
        # 斑点の最大半径 (画像の幅に対する相対値)
        blob_radius = random.uniform(width * 0.1, width * 0.25)
        circles_in_blob = random.randint(3, 7) # 1つの斑点を構成する円の数 (多いほど複雑)
        draw_blob(image_array, blob_cx, blob_cy, blob_radius, circles_in_blob, spot_color_main)

    # 副次的な斑点(中間色)のパラメータと描画
    num_secondary_blobs = random.randint(2, 4)
    for _ in range(num_secondary_blobs):
        blob_cx = random.randint(0, width - 1)
        blob_cy = random.randint(0, height - 1)
        blob_radius = random.uniform(width * 0.05, width * 0.15) # やや小さめの斑点
        circles_in_blob = random.randint(2, 5)
        draw_blob(image_array, blob_cx, blob_cy, blob_radius, circles_in_blob, spot_color_secondary)
    return image_array

テスト

# データとパラメータを準備 (例)
# 予測対象画像
y = generate_cow_like_pattern(100,100, 25,125,220)

print("y")
plt.imshow(y)
plt.show()

K = 3 # 3クラス
initial_labels = np.random.randint(0, K, size=y.shape) # ランダムな初期ラベル

mu = [25, 125, 220] # 各クラスの平均輝度
sigma = [0.3, 0.3, 0.3] # 各クラスの標準偏差

# ICMを実行
estimated_labels = mrf_map_icm(y, initial_labels, K, mu, sigma, get_neighbors_4, max_iterations=10)

# 結果として estimated_labels に推定されたラベル配置が得られる
print("estimate result")
plt.imshow(estimated_labels)
plt.show()

y画像
image.png

セグメンテーション結果
download.png

MRI画像(T1w)でテスト

import requests
import zipfile
import io
import numpy as np
import tifffile # TIFFファイルの読み込みに特化したライブラリ

def download_and_extract_tiff_zip_to_ndarray(url: str) -> np.ndarray | None:
    """
    指定されたURLからZIPファイルをダウンロードし、その中のTIFFファイルを読み込んで
    NumPy配列として返します。
    ZIPファイルには単一のマルチページTIFF、または複数のシングルページTIFFが含まれていることを想定しています。

    Args:
        url (str): TIFFファイルを含むZIPファイルのURL。

    Returns:
        np.ndarray | None: 成功した場合は画像データを含むNumPy配列、失敗した場合はNone。
    """
    try:
        # 1. ZIPファイルをダウンロード
        print(f"'{url}' からZIPファイルをダウンロードしています...")
        response = requests.get(url, timeout=30)
        response.raise_for_status()  # HTTPエラーがあれば例外を発生させる
        zip_file_bytes = io.BytesIO(response.content)
        print("ダウンロード完了。")

        tiff_image_data = None

        with zipfile.ZipFile(zip_file_bytes) as zf:
            # ZIPファイル内のTIFFファイル名を取得 (.tif または .tiff)
            tiff_file_names = [name for name in zf.namelist() if name.lower().endswith(('.tif', '.tiff'))]

            if not tiff_file_names:
                print("エラー: ZIPファイル内にTIFFファイルが見つかりませんでした。")
                return None

            print(f"ZIP内で見つかったTIFFファイル: {tiff_file_names}")

            # ケース1: ZIP内に単一のTIFFファイルが存在する場合 (多くはマルチページTIFFスタック)
            if len(tiff_file_names) == 1:
                file_name = tiff_file_names[0]
                print(f"単一のTIFFファイル '{file_name}' を処理しています...")
                with zf.open(file_name) as specific_file_bytes:
                    # tifffile.imread はファイルライクオブジェクトから直接読み込める
                    tiff_image_data = tifffile.imread(specific_file_bytes)
                print(f"'{file_name}' の読み込み成功。形状: {tiff_image_data.shape}, データ型: {tiff_image_data.dtype}")

            # ケース2: ZIP内に複数のTIFFファイルが存在する場合 (各ファイルがスライスであると仮定)
            else:
                print(f"複数のTIFFファイルが見つかりました。これらをスライスとして結合します: {tiff_file_names}")
                # ファイル名でソートして、スライスの順序を正しく保つ (例: slice01.tif, slice02.tif ...)
                tiff_file_names.sort()

                slices = []
                for file_name in tiff_file_names:
                    print(f"スライス '{file_name}' を読み込んでいます...")
                    with zf.open(file_name) as specific_file_bytes:
                        slice_data = tifffile.imread(specific_file_bytes)
                        slices.append(slice_data)

                if not slices:
                    print("エラー: TIFFスライスからデータを読み込めませんでした。")
                    return None

                try:
                    # スライスをNumPy配列としてスタック (通常、スライス方向が最初の次元になるように)
                    tiff_image_data = np.stack(slices, axis=0)
                    print(f"全{len(slices)}スライスのスタック成功。最終的な形状: {tiff_image_data.shape}, データ型: {tiff_image_data.dtype}")
                except ValueError as e:
                    print(f"エラー: スライスのスタックに失敗しました。各スライスの形状が異なる可能性があります。詳細: {e}")
                    return None

        return tiff_image_data

    except requests.exceptions.RequestException as e:
        print(f"ダウンロードエラー: {e}")
        return None
    except zipfile.BadZipFile:
        print("エラー: ZIPファイルが壊れているか、不正な形式です。")
        return None
    except tifffile.TiffFileError as e: # tifffile固有のエラー
        print(f"TIFファイルの読み込みエラー: {e}")
        return None
    except Exception as e:
        print(f"予期せぬエラーが発生しました: {e}")
        return None

実行

t1 = download_and_extract_tiff_zip_to_ndarray("https://imagej.net/ij/images/t1-head.zip")

plt.imshow(t1[63, :,:])
plt.show()

K = 3 # 3クラス(白質、灰白質、脳脊髄液)
initial_labels = np.random.randint(0, K, size=(256,256)) # ランダムな初期ラベル

mu = [180, 340, 50] # 各クラスの平均輝度
sigma = [20, 7, 12] # 各クラスの標準偏差

# ICMを実行
estimated_labels = mrf_map_icm(t1[63, :,:], initial_labels, K, mu, sigma, get_neighbors_4, max_iterations=100)
# 結果として estimated_labels に推定されたラベル配置が得られる

print("estimate result")
plt.imshow(estimated_labels)
plt.show()

download.png
download.png

今回は背景(空気や頭蓋骨など脳以外の領域)を含めているが、実際は頭蓋骨除去を行ってから処理するので、背景情報は無視できる。
平均値と標準偏差を適切に指定したり、処理前にデノイズしておいたりすることで、より精度を高めることができる。

免責

コードの間違いや解釈の違いによる実装の違いが含まれている可能性があります。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?