CLIPをFine-Tuneして病理画像分類に挑戦してみた
こんにちは、しゅんです。
今回は、気分転換も兼ねて、Kaggleからダウンロードした NCT-CRC-HE-100K (約15.56 GB)データセットを使用し、OpenAIの CLIP (clip-vit-base-patch32) モデルをファインチューニングして、病理画像の分類に挑戦した実験についてシェアします。
正直なところ、業界やKaggleコンペでは同じようなアプローチを試みている方も多いかもしれません。僕自身は、特別な研究者ではなく、ただの気分転換として、そして自己成長のために取り組んでいる小さな挑戦です。
母のがん診断や、従兄弟が天性不治の病で亡くなったという、個人的に非常につらい経験があるからこそ、人を助ける技術に触れたいという思いは、ずっと胸に秘めています。
さらに、12年間日本語と技術を学んできた経験を活かし、現在は日本で留学中ですが、いつか日本で恩返しをしながら、社会に貢献できる技術を開発できればと考えています。
データセットの詳細
今回利用する NCT-CRC-HE-100K データセットは、Zenodoで公開されている内容に基づいています。
以下のような特徴があります:
-
概要:
86枚のHE染色されたヒト癌組織スライドと正常組織から抽出された、100,000枚の重複しない画像パッチで構成。
すべての画像は224×224ピクセル、0.5ミクロン/ピクセルの解像度で、Macenkoの手法により色正規化が施されています。 -
組織クラス:
画像は以下の9つのクラスに分類されています。- ADI: Adipose tissue(脂肪組織)
- BACK: Background(背景)
- DEB: Debris(細胞破片)
- LYM: Lymphocytes(リンパ球)
- MUC: Mucin(粘液)
- MUS: Smooth muscle(平滑筋、筋組織)
- NORM: Normal colon mucosa(正常大腸粘膜)
- STR: Cancer-associated stroma(癌関連間質)
- TUM: Colorectal adenocarcinoma epithelium(大腸腺癌上皮)
-
背景と倫理:
これらの画像は、ドイツのNCT Biobank(Heidelberg)およびUniversity Medical Center Mannheimの病理アーカイブから抽出され、専門の病理医により手動で組織領域が区分されています。
すべてのサンプルは倫理審査委員会の承認の下で匿名化されており、検証用の CRC-VAL-HE-7K や色正規化を行っていないバージョン NCT-CRC-HE-100K-NONORM も存在します。
詳細はZenodoのレコードを参照してください。
環境構築手順
以下の手順で作業環境を構築してください。
ダウンロードしたデータを必ずdata
フォルダに保存してください。
# 1. 任意のフォルダを作成して移動
mkdir my_clip_project
cd my_clip_project
# 2. 仮想環境の作成と有効化
python -m venv .venv
source .venv/bin/activate
# 3. pipのアップグレードと必要ライブラリのインストール
pip install --upgrade pip
pip install transformers pillow torch tqdm torchvision
# 4. dataフォルダを作成し、KaggleからダウンロードしたNCT-CRC-HE-100Kデータを配置
mkdir data
# 5. ソースコード用のファイルを作成
touch main.py
touch fine_tune.py
fine_tune.py のコードと解説
以下は、CLIPモデルに線形分類器を追加し、病理画像分類用にファインチューニングを行うコードです。
進捗表示にはtqdm
を使用しています。
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from transformers import CLIPProcessor, CLIPModel
from tqdm import tqdm # tqdmのインポート
# --- 1. CLIP用の前処理クラス ---
class CLIPTransform:
"""
CLIPProcessorを利用して画像を前処理するクラスです。
PIL.Imageを受け取り、モデルが受け取る形式(Tensor)に変換します。
"""
def __init__(self, processor):
self.processor = processor
def __call__(self, image):
pixel_values = self.processor(images=image, return_tensors="pt")["pixel_values"]
return pixel_values.squeeze(0)
# --- 2. ファインチューニング用のモデル定義 ---
class CLIPFineTuner(nn.Module):
"""
CLIPの画像エンコーダの出力に対して、線形層(分類器)を追加します。
fine_tune_clip=Trueの場合は、CLIP本体もファインチューニング対象にします。
"""
def __init__(self, clip_model, num_classes, fine_tune_clip=False):
super().__init__()
self.clip_model = clip_model
if not fine_tune_clip:
for param in self.clip_model.parameters():
param.requires_grad = False
self.classifier = nn.Linear(512, num_classes)
def forward(self, pixel_values):
image_features = self.clip_model.get_image_features(pixel_values=pixel_values)
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
logits = self.classifier(image_features)
return logits
# --- 3. 学習ループ ---
def train(model, dataloader, optimizer, device, epochs=1):
model.train()
loss_fn = nn.CrossEntropyLoss()
for epoch in range(epochs):
running_loss = 0.0
for images, labels in tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}"):
images = images.to(device)
labels = labels.to(device)
optimizer.zero_grad()
outputs = model(images)
loss = loss_fn(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
avg_loss = running_loss / len(dataloader)
print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")
# --- 4. メイン処理 ---
if __name__ == "__main__":
device = torch.device("mps" if torch.mps.is_available() else "cpu")
print("Using device:", device)
model_name = "openai/clip-vit-base-patch32"
processor = CLIPProcessor.from_pretrained(model_name)
clip_model = CLIPModel.from_pretrained(model_name)
clip_model.to(device)
dataset_dir = "data/NCT-CRC-HE-100K"
transform = CLIPTransform(processor)
dataset = ImageFolder(root=dataset_dir, transform=transform)
print("クラスラベル:", dataset.classes)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=2)
num_classes = len(dataset.classes)
model_ft = CLIPFineTuner(clip_model, num_classes, fine_tune_clip=False)
model_ft.to(device)
optimizer = optim.Adam(model_ft.parameters(), lr=1e-4)
epochs = 10
train(model_ft, dataloader, optimizer, device, epochs=epochs)
os.makedirs("model", exist_ok=True)
model_path = os.path.join("model", "clip_finetuned.pth")
torch.save(model_ft.state_dict(), model_path)
print(f"Fine-tuning完了!モデルを {model_path} として保存しました。")
fine_tune.py のポイント
-
CLIPTransform:
CLIPProcessorを使って画像をTensorに変換し、モデルが受け入れる形式に整えます。 -
CLIPFineTuner:
事前学習済みのCLIPモデルに対して、分類器として線形層を追加。
fine_tune_clip
フラグで、CLIP本体のパラメータを更新するか否かを選択できます(ここでは固定しています)。 -
train関数:
tqdmを利用して学習進捗を表示しながら、クロスエントロピー損失とAdamオプティマイザで学習を実施します。 -
メイン処理:
データセットの読み込み、モデルの学習、そして学習済みモデルの保存を行います。
main.py のコードと解説
以下は、学習済みファインチューニングモデルと元のCLIPモデルの推論結果を比較するコードです。
import os
import random
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision.datasets import ImageFolder
from transformers import CLIPProcessor, CLIPModel
from fine_tune import CLIPTransform, CLIPFineTuner
def main():
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print("Using device:", device)
model_name = "openai/clip-vit-base-patch32"
processor = CLIPProcessor.from_pretrained(model_name)
clip_model = CLIPModel.from_pretrained(model_name).to(device)
dataset_dir = "data/NCT-CRC-HE-100K"
transform = CLIPTransform(processor)
dataset = ImageFolder(root=dataset_dir, transform=transform)
classes = dataset.classes
print("クラスラベル:", classes)
sample_index = random.randint(0, len(dataset) - 1)
image_tensor, true_label = dataset[sample_index]
print("選ばれた画像の正解ラベル:", classes[true_label])
text_prompts = [f"{cls} tissue" for cls in classes]
sample_image_path = dataset.imgs[sample_index][0]
pil_image = Image.open(sample_image_path).convert("RGB")
inputs_clip = processor(text=text_prompts, images=pil_image, return_tensors="pt", padding=True).to(device)
with torch.no_grad():
outputs_clip = clip_model(**inputs_clip)
logits_clip = outputs_clip.logits_per_image
probs_clip = logits_clip.softmax(dim=1).squeeze(0)
print("\n--- オリジナルCLIPのゼロショット分類結果 ---")
for prompt, prob in zip(text_prompts, probs_clip.tolist()):
print(f"{prompt}: {prob:.4f}")
num_classes = len(classes)
model_ft = CLIPFineTuner(clip_model, num_classes, fine_tune_clip=False).to(device)
model_path = os.path.join("model", "clip_finetuned.pth")
state_dict = torch.load(model_path, map_location=device)
model_ft.load_state_dict(state_dict)
model_ft.eval()
image_tensor = image_tensor.unsqueeze(0).to(device)
with torch.no_grad():
logits_ft = model_ft(image_tensor)
probs_ft = F.softmax(logits_ft, dim=1).squeeze(0)
print("\n--- ファインチューニング済みモデルの分類結果 ---")
for cls, prob in zip(classes, probs_ft.tolist()):
print(f"{cls}: {prob:.4f}")
if __name__ == "__main__":
main()
main.py のポイント
-
データとモデルの準備:
ImageFolderとCLIPTransformを用いてデータセットを読み込み、CLIPのプロセッサとモデルを初期化します。 -
ゼロショット推論:
テキストプロンプトと画像の類似度から、元のCLIPモデルによるゼロショット分類結果を出力します。 -
ファインチューニング済みモデルの推論:
fine_tune.pyで学習したモデルのパラメータをロードし、同じ画像に対する分類結果を表示。
これにより、ファインチューニング前後の性能の違いが明確に比較できます。
実験結果と考察
結果
最初は、fine-tuningを5エポックで実施し、その結果をmain.pyで2回確認しました。その後、tdqmをインストールしてから再度fine-tuningを行い、今回はエポック数を10に設定しました。同様に、main.pyを2回実行して結果を確認しました。つまり、5エポックと10エポックそれぞれでfine-tuningを行い、各条件でmain.pyを2回実行してその経過を比較しました。
見やすくためにまとめた表
Fine-tuning Loss Summary
Experiment | Epochs | Final Loss |
---|---|---|
初回 Fine-tuning | 5 | 0.6909 |
Fine-tuning (tdqm導入, 再実行) | 10 | 0.4705 |
Inference Results (main.py)
Experiment | Run | 正解ラベル | オリジナルCLIP (ゼロショット) | ファインチューニング済みモデル |
---|---|---|---|---|
初回 Fine-tuning (5 エポック) | 1 | STR | STR tissue: 0.0011 | STR: 0.2128 |
初回 Fine-tuning (5 エポック) | 2 | DEB | DEB tissue: 0.0004 | DEB: 0.5304 |
Fine-tuning (tdqm, 10 エポック) | 1 | MUS | MUS tissue: 0.4051 | MUS: 0.6806 |
Fine-tuning (tdqm, 10 エポック) | 2 | BACK | BACK tissue: 0.0433 | BACK: 0.9936 |
これらの表は、最初に5エポックで実施したファインチューニングと、tdqmを導入してエポック数を10に増やした再実行時の学習損失の変化および、main.pyで確認した推論結果の違いを示しています。
詳細の経過
python fine_tune.py
Using device: mps
クラスラベル: ['ADI', 'BACK', 'DEB', 'LYM', 'MUC', 'MUS', 'NORM', 'STR', 'TUM']
Epoch 1/5, Loss: 1.7266
Epoch 2/5, Loss: 1.1949
Epoch 3/5, Loss: 0.9388
Epoch 4/5, Loss: 0.7889
Epoch 5/5, Loss: 0.6909
Fine-tuning完了!モデルを model/clip_finetuned.pth として保存しました.
python main.py
Using device: mps
クラスラベル: ['ADI', 'BACK', 'DEB', 'LYM', 'MUC', 'MUS', 'NORM', 'STR', 'TUM']
選ばれた画像の正解ラベル: STR
--- オリジナルCLIPのゼロショット分類結果 ---
ADI tissue: 0.1119
BACK tissue: 0.0006
DEB tissue: 0.0004
LYM tissue: 0.0081
MUC tissue: 0.3725
MUS tissue: 0.4815
NORM tissue: 0.0004
STR tissue: 0.0011
TUM tissue: 0.0235
--- ファインチューニング済みモデルの分類結果 ---
ADI: 0.0060
BACK: 0.0027
DEB: 0.0570
LYM: 0.0037
MUC: 0.1215
MUS: 0.5262
NORM: 0.0222
STR: 0.2128
TUM: 0.0479
python main.py
Using device: mps
クラスラベル: ['ADI', 'BACK', 'DEB', 'LYM', 'MUC', 'MUS', 'NORM', 'STR', 'TUM']
選ばれた画像の正解ラベル: DEB
--- オリジナルCLIPのゼロショット分類結果 ---
ADI tissue: 0.1286
BACK tissue: 0.0007
DEB tissue: 0.0004
LYM tissue: 0.0072
MUC tissue: 0.3560
MUS tissue: 0.4992
NORM tissue: 0.0007
STR tissue: 0.0013
TUM tissue: 0.0058
--- ファインチューニング済みモデルの分類結果 ---
ADI: 0.0081
BACK: 0.0062
DEB: 0.5304
LYM: 0.0161
MUC: 0.0417
MUS: 0.1267
NORM: 0.0245
STR: 0.1345
TUM: 0.1117
# (tdqmをインストール後、エポック数を10に変更して再実行)
python fine_tune.py
Using device: mps
クラスラベル: ['ADI', 'BACK', 'DEB', 'LYM', 'MUC', 'MUS', 'NORM', 'STR', 'TUM']
Epoch 1/10: 100%|█████████████████████████████████████████████| 6250/6250 [08:16<00:00, 12.59it/s]
Epoch 1/10, Loss: 1.7281
Epoch 2/10: 100%|█████████████████████████████████████████████| 6250/6250 [08:02<00:00, 12.94it/s]
Epoch 2/10, Loss: 1.1944
Epoch 3/10: 100%|█████████████████████████████████████████████| 6250/6250 [07:53<00:00, 13.21it/s]
Epoch 3/10, Loss: 0.9385
Epoch 4/10: 100%|█████████████████████████████████████████████| 6250/6250 [08:04<00:00, 12.91it/s]
Epoch 4/10, Loss: 0.7887
Epoch 5/10: 100%|█████████████████████████████████████████████| 6250/6250 [07:52<00:00, 13.22it/s]
Epoch 5/10, Loss: 0.6907
Epoch 6/10: 100%|█████████████████████████████████████████████| 6250/6250 [07:54<00:00, 13.17it/s]
Epoch 6/10, Loss: 0.6216
Epoch 7/10: 100%|█████████████████████████████████████████████| 6250/6250 [08:01<00:00, 12.99it/s]
Epoch 7/10, Loss: 0.5699
Epoch 8/10: 100%|█████████████████████████████████████████████| 6250/6250 [08:03<00:00, 12.92it/s]
Epoch 8/10, Loss: 0.5297
Epoch 9/10: 100%|█████████████████████████████████████████████| 6250/6250 [08:04<00:00, 12.90it/s]
Epoch 9/10, Loss: 0.4972
Epoch 10/10: 100%|████████████████████████████████████████████| 6250/6250 [08:16<00:00, 12.58it/s]
Epoch 10/10, Loss: 0.4705
Fine-tuning完了!モデルを model/clip_finetuned.pth として保存しました。
python main.py
Using device: mps
クラスラベル: ['ADI', 'BACK', 'DEB', 'LYM', 'MUC', 'MUS', 'NORM', 'STR', 'TUM']
選ばれた画像の正解ラベル: MUS
--- オリジナルCLIPのゼロショット分類結果 ---
ADI tissue: 0.0947
BACK tissue: 0.0007
DEB tissue: 0.0005
LYM tissue: 0.0149
MUC tissue: 0.4666
MUS tissue: 0.4051
NORM tissue: 0.0005
STR tissue: 0.0014
TUM tissue: 0.0156
--- ファインチューニング済みモデルの分類結果 ---
ADI: 0.0030
BACK: 0.0002
DEB: 0.0068
LYM: 0.0002
MUC: 0.0387
MUS tissue: 0.6806
NORM tissue: 0.0048
STR tissue: 0.2524
TUM tissue: 0.0133
python main.py
Using device: mps
クラスラベル: ['ADI', 'BACK', 'DEB', 'LYM', 'MUC', 'MUS', 'NORM', 'STR', 'TUM']
選ばれた画像の正解ラベル: BACK
--- オリジナルCLIPのゼロショット分類結果 ---
ADI tissue: 0.2283
BACK tissue: 0.0433
DEB tissue: 0.0421
LYM tissue: 0.1669
MUC tissue: 0.1467
MUS tissue: 0.2233
NORM tissue: 0.0309
STR tissue: 0.0819
TUM tissue: 0.0364
--- ファインチューニング済みモデルの分類結果 ---
ADI: 0.0016
BACK: 0.9936
DEB: 0.0010
LYM tissue: 0.0002
MUC tissue: 0.0002
MUS tissue: 0.0031
NORM tissue: 0.0002
STR tissue: 0.0001
TUM tissue: 0.0001
経過と結果の違い
-
エポック数の違い:
最初は5エポックで実施した結果、損失が約1.73から0.69に低下しました。
その後、エポック数を10に増やすことで、損失はさらに0.47まで下がり、ファインチューニング済みモデルの分類性能が向上しました。 -
ゼロショット vs ファインチューニング:
元のCLIPモデル(ゼロショット)では、病理画像に対して正解クラスの確率が非常に低く、誤ったクラスに高い確率が割り当てられることが多かったのに対し、
ファインチューニング済みモデルは正しいクラスに対して大幅に高い確率を示すようになりました。
streamlitでアプリも作りました
import os
import random
import torch
import torch.nn.functional as F
import streamlit as st
import pandas as pd
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
from fine_tune import CLIPTransform, CLIPFineTuner
@st.cache_resource
def load_models():
# 使用デバイスの設定(MPSが利用可能ならMPS、なければCPU)
device = torch.device("mps" if torch.mps.is_available() else "cpu")
# CLIPモデルとプロセッサの読み込み
model_name = "openai/clip-vit-base-patch32"
processor = CLIPProcessor.from_pretrained(model_name)
clip_model = CLIPModel.from_pretrained(model_name).to(device)
# Fine-tuning済みモデルの読み込み(CLIPFineTunerを用いて分類層を追加)
num_classes = 9 # NCT-CRC-HE-100K の 9 クラス
model_ft = CLIPFineTuner(clip_model, num_classes, fine_tune_clip=False).to(device)
model_path = os.path.join("model", "clip_finetuned.pth")
state_dict = torch.load(model_path, map_location=device)
model_ft.load_state_dict(state_dict)
model_ft.eval()
return device, processor, clip_model, model_ft
def main():
st.title("CLIP Fine-Tuning 病理画像分類アプリ")
st.write("画像をアップロードすると、元の CLIP(ゼロショット)と Fine-tuning 済みモデルで分類予測を行います。")
# モデルの読み込み
device, processor, clip_model, model_ft = load_models()
classes = ['ADI', 'BACK', 'DEB', 'LYM', 'MUC', 'MUS', 'NORM', 'STR', 'TUM']
text_prompts = [f"{cls} tissue" for cls in classes]
# 画像アップロード
uploaded_file = st.file_uploader("画像を選択 (jpg, jpeg, png, tif, tiff)", type=["jpg", "jpeg", "png", "tif", "tiff"])
if uploaded_file is not None:
image = Image.open(uploaded_file).convert("RGB")
st.image(image, caption="アップロードされた画像", use_container_width=True)
# --- オリジナル CLIP (ゼロショット) での予測 ---
inputs_clip = processor(text=text_prompts, images=image, return_tensors="pt", padding=True).to(device)
with torch.no_grad():
outputs_clip = clip_model(**inputs_clip)
logits_clip = outputs_clip.logits_per_image
probs_clip = logits_clip.softmax(dim=1).squeeze(0)
# --- Fine-tuning 済みモデルでの予測 ---
transform = CLIPTransform(processor)
img_tensor = transform(image).unsqueeze(0).to(device)
with torch.no_grad():
logits_ft = model_ft(img_tensor)
probs_ft = F.softmax(logits_ft, dim=1).squeeze(0)
# 結果を DataFrame で表示
df_clip = pd.DataFrame({
"クラス": classes,
"ゼロショット確率": [f"{p:.4f}" for p in probs_clip.tolist()]
})
df_ft = pd.DataFrame({
"クラス": classes,
"Fine-tuned 確率": [f"{p:.4f}" for p in probs_ft.tolist()]
})
st.subheader("オリジナル CLIP (ゼロショット) の予測結果")
st.table(df_clip)
st.subheader("Fine-tuning 済みモデルの予測結果")
st.table(df_ft)
pred_zero = classes[probs_clip.argmax().item()]
pred_ft = classes[probs_ft.argmax().item()]
st.write(f"**オリジナル CLIP の予測:** {pred_zero}")
st.write(f"**Fine-tuned モデルの予測:** {pred_ft}")
if __name__ == "__main__":
main()
stream run app.py
結果はXで投稿したぜひ見てください
streamlitを使って比較した録画です!!おおお!!!本当にすごい!!!!! pic.twitter.com/bN2zTcXtut
— SYUN@気は長く、勤めは堅く、昨日を越えて、今日を生きる (@syun88AI) March 13, 2025
自分なりの考察
業界やKaggleコンペで同様の手法を試みている方も多いかもしれません。
僕自身は、特別な研究者というわけではなく、ただの気分転換と自己研鑽のためにこのプロジェクトに取り組んでいます。
母のがん診断や、何年前に従兄弟が天性不治の病で亡くなったという個人的な経験から、「人を助ける」ことを夢見てきました。
また、12年間にわたり日本語と技術の勉強を続け、現在は日本で留学中です。いつか日本で恩返しをしながら、社会に貢献できる技術を開発するのが僕の目標です。
正直、最近は躁鬱の診断や人間関係の悩みもあり、気持ちが沈むこともありますが、そうした経験もまた、自分を成長させる大切なプロセスだと思っています。
まとめ
今回の記事では、NCT-CRC-HE-100Kデータセットを使用して、CLIPモデルのファインチューニングによる病理画像分類に挑戦した手順、コード、実験結果、そして考察を詳しく紹介しました。
データセットは、86枚のHE染色スライドから抽出された100,000枚の非重複画像パッチで構成され、9つの組織クラス(ADI, BACK, DEB, LYM, MUC, MUS, NORM, STR, TUM)に分類されています。
これらの画像は、Macenkoの手法で色正規化され、倫理的な承認のもと収集されています。
業界や本格的なKaggleコンペでの取り組みと比べれば、僕の試みは大したことではないかもしれません。しかし、僕にとっては母のがんや従兄弟の悲しい他にも色々の経験を背景に、人を助ける夢を胸に、気分転換と自己成長のために取り組んだ小さな挑戦です。
また、現在は日本で留学中で、12年間の勉強の成果を少しずつ形にしながら、いつか日本で恩返しし、社会に貢献できる技術を開発できればと考えています。
これからも得た知見を活かし、日本に来てさらに恩を返したい方もできたので、もっと精進し、前向きに歩んでいきたいと思います。
今回も最後までお読みいただき、ありがとうございました!