0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

MedGemmaに触れてみる

Posted at

MedGemma

MedGemma[1]は、Geminiよりも軽量かつオープンなコミュニティモデルであるGemmaを医療用に訓練したモデル。執筆時点(2025/09/23)では、Gemma3の派生モデルの位置づけとなっている。

MedGemmaは入力にテキストと画像を扱えるマルチモーダルバージョンであり、画像に対応する入力部は、胸部X線、皮膚科画像、眼科画像、組織病理学スライドなど、様々な匿名化医療データで事前学習されたSigLIP(という画像エンコーダ)を利用してるとのこと。LLMコンポーネントには、医療テキスト、医療Q&Aペア、FHIRベースの電子健康記録データ(27Bバージョンのみ)、放射線画像、組織病理学パッチ、眼科画像、皮膚科画像など、多様な医療データで訓練されているらしい。LLMが画像で訓練されているというのは、ViTみたいなニュアンスなのだろう。

MedGemma 4Bは事前学習済み(末尾に-ptがついているバージョン)と指示学習済み(-itがついているバージョン)が利用可能。指示学習済みバージョンはほとんどのアプリケーションにおいてより良い出発点となるとのことで、広義な意味での推奨なのだろう。事前学習済みバージョンはモデルを深く実験したいユーザー向けに提供されているとのことで、蔵人向けのようだ。

早速試しに使ってみよう。

事前準備

HuggingFaceにサインインした後、AccessToken(read権限のみ)を作成する。AccessTokenは、ユーザー設定から行ける。後から使うので、コピーしとく。

# 以下、インストールしておく。Restartも必要。
# Python(3.12.x)を利用。
# First, install the Transformers library. Gemma 3 is supported starting from transformers 4.50.0.
!pip install transformers==4.50
!pip install timm==0.9.12
!pip install accelerate==0.26

サンプルコード

まずは、最短で、最も簡単な使い方を試したい。

最初に、HuggingFaceからMedGemmaモデルをダウンロードするために、HuggingFaceにログインする。

from huggingface_hub import notebook_login
notebook_login() # ここで入力画面が開くのでこぴーしておいたTokenを入力する。

モデルをロードする。

# Load model directly
from transformers import AutoProcessor, AutoModelForImageTextToText

processor = AutoProcessor.from_pretrained("google/medgemma-4b-it")
model = AutoModelForImageTextToText.from_pretrained("google/medgemma-4b-it")

推論実行する。

messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/p-blog/candy.JPG"},
            {"type": "text", "text": "What animal is on the candy?"}
        ]
    },
]
inputs = processor.apply_chat_template(
	messages,
	add_generation_prompt=True,
	tokenize=True,
	return_dict=True,
	return_tensors="pt",
).to(model.device)

outputs = model.generate(**inputs, max_new_tokens=800)
print(processor.decode(outputs[0][inputs["input_ids"].shape[-1]:]))

結果が得られる。

The animal on the candy is a **bird**.
<end_of_turn>

オープンデータで試す(脳MRI画像)

医療系の画像の場合、一枚では済まないことが多い。
上記の例では画像は1枚のみの入力となっている。
ここでは、脳MRI画像3枚を使って、推論結果を得られるようほんの少しだけ工夫した。

# この画像で試しました。
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np

# ソース:https://imagej.net/ij/images/mri-stack.zip
image_paths = [
    "/home/tatsunidas/AI aided Image Interpretation/mri-stack18.jpg",
    "/home/tatsunidas/AI aided Image Interpretation/mri-stack19.jpg",
    "/home/tatsunidas/AI aided Image Interpretation/mri-stack20.jpg"
]
# 2. Pillowライブラリを使って、各画像ファイルを読み込みます
try:
    images = [Image.open(path).convert("RGB") for path in image_paths]
except FileNotFoundError:
    print("エラー: 指定された画像ファイルが見つかりません。パスを確認してください。")
    # ここでプログラムを終了するか、エラー処理を続けます
    exit()

# processorにchannels-lastで渡してOKとのこと。
plt.figure()
for i in range(len(image_paths)):
    if i == 0 : print(np.array(images[i]).shape)
    plt.subplot(1,3,(i+1))
    plt.imshow(np.array(images[i]))
plt.title("MRI-Stack (ImageJ Sample)")
plt.show()

bd9fef3b-3fcd-47bf-bdd5-8dc8131c5c80.png

実は、先述の方法を少し改変してやればいいだろうと思っていたのだが、そのままではだめだった。
複数画像(ここでは3枚)を入力すると、入力サイズのデータ量がモデル(medgemma-4b-it)のキャパを超えてしまった(この点、Llavaの方が進んでいるのかも。?)。
そのため、一枚ずつ入力することにして、最後にTextToTextでSummarizeすることにした。

# 画像枚数分処理を繰り返して、最後に結果の文章をひとつにまとめるため。

def prediction(model, processor, image):
    # 3. モデルに送るプロンプト(質問)を記述します
    prompt = "医学系研究・教育用途で質問します。まず、この脳MRI画像に異常所見があるかないかを示してください。異常所見がある場合、鑑別すべき疾患を挙げてください。正常の場合、所見がない可能性が高いことをその理由とともに示してください。200字以内に要約してください。"

    # 4. messagesを作成します
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": prompt},
                {"type": "image", "image": image},
                # {"type": "image", "image": image2}, # 複数入力する場合はこのように増やす(たぶん)
            ]
        }
    ]

    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)

    outputs = model.generate(
        **inputs, 
        #do_sample=True,     # サンプリングを有効化。temperatureやtop_pを有効にするには、この設定が必要
        #temperature=0.85,   # 創造性を少し高める
        #top_p=0.95,         # 確率の高い単語の中からランダムに選ぶ手法(Nucleus Sampling). temperatureの代わりによく使われます。
        max_new_tokens=512)
    return processor.decode(outputs[0][inputs["input_ids"].shape[-1]:])


def summarize(model, processor, responses):
    prompt = f"医学系研究・教育用途で質問します。脳MRI画像に異常所見があるかを調べました。次の{len(responses)}枚分の画像の調査結果文章を要約してください。\n"
    for i, res in enumerate(responses):
        prompt += str(i+1)+'番目の画像についての結果です。\n'
        prompt += res+'\n'

    # 4. messagesを作成します
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": prompt}
            ]
        }
    ]

    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)

    outputs = model.generate(
        **inputs, 
        #do_sample=True,     # サンプリングを有効化。temperatureやtop_pを有効にするには、この設定が必要
        #temperature=0.85,   # 創造性を少し高める
        #top_p=0.95,         # 確率の高い単語の中からランダムに選ぶ手法(Nucleus Sampling). temperatureの代わりによく使われます。
        max_new_tokens=1024
        )
    return processor.decode(outputs[0][inputs["input_ids"].shape[-1]:])

画像枚数分処理を繰り返して、最後に結果の文章をひとつにまとめる。

# Load model directly
from transformers import AutoProcessor, AutoModelForImageTextToText
from PIL import Image
import torch

# --------------------------------------------------------------------------
# ★★★★★ GPU利用を保証するコード (ここから) ★★★★★
# --------------------------------------------------------------------------

# GPU (CUDA) が利用可能かを確認し、利用不可ならエラーで停止
assert torch.cuda.is_available(), "GPU (CUDA) が利用できません。CPUでの実行は行いません。"

# 使用するデバイスを"cuda"に設定
device = torch.device("cuda")
print(f"✅ GPU ({torch.cuda.get_device_name(0)}) を利用します。")

# 1. ローカルにある複数枚の画像ファイルのパスをリストに記述します
#    お使いの環境に合わせてパスを修正してください
image_paths = [
    "/home/tatsunidas/AI aided Image Interpretation/mri-stack18.jpg",
    "/home/tatsunidas/AI aided Image Interpretation/mri-stack19.jpg",
    "/home/tatsunidas/AI aided Image Interpretation/mri-stack20.jpg"
]

# 2. Pillowライブラリを使って、各画像ファイルを読み込みます
try:
    images = [Image.open(path).convert("RGB") for path in image_paths]
except FileNotFoundError:
    print("エラー: 指定された画像ファイルが見つかりません。パスを確認してください。")
    # ここでプログラムを終了するか、エラー処理を続けます
    exit()

# 再度読み込んでいるが、先に呼び出しているのならば、ここで再度呼び出す必要はない。明示化のため。
processor = AutoProcessor.from_pretrained("google/medgemma-4b-it")
model = AutoModelForImageTextToText.from_pretrained("google/medgemma-4b-it")
model.to(device)

responses = []
for im in images:
    res = prediction(model, processor, im)
    responses.append(res)

# summarize
print(summarize(model,processor,responses))

結果が得られる。これあったら使うかも。ローカルで実行できるから、何回質問しても電気代以外はかからない。GoogleCloud経由でGeminiを利用するケースだと、Geminiモデルへ1クエリ2円くらいかかると思われる。

要約:

3枚の脳MRI画像において、1枚目には右側頭葉に小さな斑点状の所見があり、小脳出血、血管奇形、脳腫瘍、炎症性疾患の鑑別診断が考えられます。2枚目には右側頭葉に複数の小さな、明るい領域(おそらく白質)があり、白質病変(脳梗塞、多発性硬化症、血管性病変など)または腫瘍の可能性が考えられます。3枚目(T1強調画像)は異常所見なしです。
<end_of_turn>

その他

複数画像を入力して試したが(messageのcontentにimage属性を追加するだけ)、入力サイズが大きくなりすぎてエラーが出てしまった。llava(ラヴァ)[過去の自分の学習記録記事][2]も進化しているので試すと良さそう。
Geminiに聞くと、「Ollama(オラマ)」というツールを使ってWebサーバー経由で実行すると楽だよとのこと。選択肢として持っておくには良さそう。

References

[1]:Sellergren et al. "MedGemma Technical Report." arXiv preprint arXiv:2507.05201 (2025).
[2]:https://qiita.com/tatsunidas/items/063f5399439839156fd6

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?