0
1

胸部単純X線写真からの画像キャプショニング:GIT編

Last updated at Posted at 2024-07-03

モチベーション

ググッときた。

この技術が発展するとどう嬉しいのか

  • 画像確認レポートや、診断レポートのテンプレにできるかもしれない
  • 仕事が忙しいときや、疲れているときにでも、気づきを与えてくれる可能性

アプローチ

今回は、Image To Textで、GITを利用する。
この他にも、大きなフレームワークとしては、次のようなものがある。

Text To Text

キーワードから、文章を予測させる試みなど。
例えば、MeSHのような、標準化された用語をいくつか入力すると、レポートの文章が自動で作られるようなアプローチ。
逆もありで、長文を要約するアプローチも可能。
BERT、GPT、T5などのモデルで検証可能。

Image To Text

画像を入力して、画像から直接、文章を予測させるアプローチ。
CNN+LSTM、CNN+GPT、GITモデルなどで検証可能。

Image To Text To Text

これは、私の妄想フレームワーク。
例えば、画像から、キーワードだけをいくつか予測させておき、このキーワードから、Text To Textで文章を作成させるアプローチ。
画像から精度の高い文章が予測させられないケースでは有用になるのではないか。

データセット

Chest X-rays (Indiana University)
Open-i dataset taken from openi.nlm.nih.gov
https://www.kaggle.com/datasets/raddar/chest-xrays-indiana-university

モデル

GIT
@article{wang2022git, title={GIT: A Generative Image-to-text Transformer for Vision and Language}, author={Wang, Jianfeng and Yang, Zhengyuan and Hu, Xiaowei and Li, Linjie and Lin, Kevin and Gan, Zhe and Liu, Zicheng and Liu, Ce and Wang, Lijuan}, journal={arXiv preprint arXiv:2205.14100}, year={2022} }

コードを書いていく

必要なライブラリを用意する。

# !pip install torch torchvision
# !pip install transformers datasets evaluate -q
# !pip install jiwer -q # いらないかも
# !pip install nltk

ダウンロードしたデータセットを解凍する。

# !unzip archive.zip

パッケージ等のインポート

import pandas as pd
import numpy as np
import os, time, copy, re, html, string

from PIL import Image
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
import warnings
warnings.filterwarnings('ignore')
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from transformers import AdamW

デバイスの識別:cpu or gpu

if torch.cuda.is_available():
    device = torch.device("cuda")
    print(f"Using GPU: {torch.cuda.get_device_name(0)}")
else:
    device = torch.device("cpu")
    print("Using CPU")

画像パス、画像の属性、および、レポートファイルを読み込んでいく。

img_base_dir = 'images/images_normalized/'
reports_df = pd.read_csv('indiana_reports.csv')
reports_df.head()
projections_df = pd.read_csv('indiana_projections.csv')
projections_df.head()

サイズを確認する。

#images 7470
#reports (3851, 8)
#projections (7466, 3)
print('images',len(os.listdir(img_base_dir)))
print('reports', reports_df.shape)
print('projections', projections_df.shape)

CSVをマージする。

'''
merge
'''
import math
def merge_df(reports, projections):
    images_captions_df = pd.DataFrame()
    for i in range(len(projections)):
        uid = projections.iloc[i]['uid']
        img = projections.iloc[i]['filename']
        proj = projections.iloc[i]['projection']
        label = reports[reports['uid']==uid]
        mesh = label['MeSH'].values[0]
        # Problems	image	indication	comparison	findings	impression
        problem = label['Problems'].values[0]
        img_attr = label['image'].values[0]
        indication = label['indication'].values[0]
        comparison = label['comparison'].values[0]
        findings = label['findings'].values[0]
        impression = label['impression'].values[0]
#         if findings is np.nan:
#             continue
        row = {
            'uid':[uid],
            'filename':[img],
            'projection':[proj],
            'MeSH':[mesh],
            'Problems':[problem],
            'image':[img_attr],
            'indication':[indication],
            'comparison':[comparison],
            'findings':[findings],
            'impression':[impression],
        }
        row_df = pd.DataFrame(row)
        images_captions_df = pd.concat([images_captions_df, row_df], ignore_index=True)
    
    return images_captions_df
df = merge_df(reports_df, projections_df)

キャプションとして扱いたい項目(ラベルとする文章)を指定して、NaNをDropする。

def select_df(df_merged, text_col):
    df_ = df_merged.dropna(subset=[text_col])
    return df_
'''
caption対象からNaNを落とす
'''
df_s = select_df(df, 'findings')

サイズ確認

print(df_s.shape)
print(df_s['findings'].isnull().values.sum()) # 0 なので、NaNは綺麗になくなった

サンプルを表示する。(正面のみと、正面・側面がある場合がある)
findingsも出力。(ここでは画像のみ表示した例)

def display_sample(uid, img_base_dir, df_merged):
    plt.figure(figsize=(10, 5))
    print("UID:", uid)
    samples = df_merged[df_merged['uid']==uid]
    findings = samples.iloc[0]['findings']
    for i in range(samples.shape[0]):
        filepath = img_base_dir + samples.iloc[i]['filename']
        proj = samples.iloc[i]['projection']
        plt.subplot(1, samples.shape[0], i+1)
        img = Image.open(filepath)
        rgb = img.convert('RGB')
        plt.title(proj)
        plt.imshow(rgb)
    plt.show()
    print("Findings:", findings)
display_sample(1, img_base_dir,df)

image.png

画像のファイル名(filename)を相対パスに変換する。

df_s['filename'] = img_base_dir + df_s['filename']

テキストの前処理

'''
text preprocessing
'''
from nltk.stem import PorterStemmer, WordNetLemmatizer
from nltk.tokenize import word_tokenize

def remove_chars(text):
    '''
    for html, xml
    '''
    re1 = re.compile(r'  +')
    x1 = text.lower().replace('#39;', "'").replace('amp;', '&').replace('#146;', "'").replace(
        'nbsp;', ' ').replace('#36;', '$').replace('\\n', "\n").replace('quot;', "'").replace(
        '<br />', "\n").replace('\\"', '"').replace('<unk>', 'u_n').replace(' @.@ ', '.').replace(
        ' @-@ ', '-').replace('\\', ' \\ ')
    return re1.sub(' ', html.unescape(x1))


def remove_non_ascii(text):
    '''
    remove euro accent, unicode(partial), emoji.
    '''
    return re.sub(r'[^\x00-\x7F]+', '', text)


def to_lowercase(text):
    return text.lower()


def remove_punctuation(text):
    '''
    remove: !"#$%&'()*+,-./:;<=>?@[\]^_`{|}~
    '''
    translator = str.maketrans('', '', string.punctuation)
    return text.translate(translator)


def replace_numbers(text):
    return re.sub(r'\d+', '', text)


def remove_start_end_whitespaces(text):
    '''
    "   Hello, world!   " to "Hello, world!"
    '''
    return text.strip()


def remove_stopwords(words, stop_words):
    return [word for word in words if word not in stop_words]


def stem_words(words):
    '''
    running: run
    ran: ran
    runs: run
    runner: runner
    easily: easili
    fairly: fairli
    fairness: fair
    '''
    stemmer = PorterStemmer()
    return [stemmer.stem(word) for word in words]


def text2words(text):
    '''
    This is an example to ['This', 'is', 'an', 'example']
    '''
    return word_tokenize(text)


def lemmatize_words(words):
    lemmatizer = WordNetLemmatizer()
    return [lemmatizer.lemmatize(word) for word in words]


def lemmatize_verbs(words):
    lemmatizer = WordNetLemmatizer()
    return ' '.join([lemmatizer.lemmatize(word, pos='v') for word in words])


def normalize_text(text):
    text = remove_chars(text)
    text = remove_non_ascii(text)
    text = remove_punctuation(text)
    text = to_lowercase(text)
    text = replace_numbers(text)
    text = remove_start_end_whitespaces(text)
    words = text2words(text)
    words = remove_stopwords(words, stop_words=['xxxx'])
    #words = stem_words(words)# Either stem ovocar lemmatize
    words = lemmatize_words(words)
#     words = lemmatize_verbs(words)
    return ' '.join(words)

sample

'''
text preprocess sample
'''
for txt in df_s.iloc[:5]['findings']:
    print(txt)
    print()
    print("\033[32m"+normalize_text(txt)+"\033[0m")
    print()
The cardiac silhouette and mediastinum size are within normal limits. There is no pulmonary edema. There is no focal consolidation. There are no XXXX of a pleural effusion. There is no evidence of pneumothorax.

the cardiac silhouette and mediastinum size are within normal limit there is no pulmonary edema there is no focal consolidation there are no of a pleural effusion there is no evidence of pneumothorax

The cardiac silhouette and mediastinum size are within normal limits. There is no pulmonary edema. There is no focal consolidation. There are no XXXX of a pleural effusion. There is no evidence of pneumothorax.

the cardiac silhouette and mediastinum size are within normal limit there is no pulmonary edema there is no focal consolidation there are no of a pleural effusion there is no evidence of pneumothorax

Borderline cardiomegaly. Midline sternotomy XXXX. Enlarged pulmonary arteries. Clear lungs. Inferior XXXX XXXX XXXX.

borderline cardiomegaly midline sternotomy enlarged pulmonary artery clear lung inferior

Borderline cardiomegaly. Midline sternotomy XXXX. Enlarged pulmonary arteries. Clear lungs. Inferior XXXX XXXX XXXX.

borderline cardiomegaly midline sternotomy enlarged pulmonary artery clear lung inferior

There are diffuse bilateral interstitial and alveolar opacities consistent with chronic obstructive lung disease and bullous emphysema. There are irregular opacities in the left lung apex, that could represent a cavitary lesion in the left lung apex.There are streaky opacities in the right upper lobe, XXXX scarring. The cardiomediastinal silhouette is normal in size and contour. There is no pneumothorax or large pleural effusion.

there are diffuse bilateral interstitial and alveolar opacity consistent with chronic obstructive lung disease and bullous emphysema there are irregular opacity in the left lung apex that could represent a cavitary lesion in the left lung apexthere are streaky opacity in the right upper lobe scarring the cardiomediastinal silhouette is normal in size and contour there is no pneumothorax or large pleural effusion

ちょっと力技だけど、先にキャプションを前処理して置き換える。

findings = []
for i in range(df_s.shape[0]):
    txt = df_s.iloc[i]['findings']
    findings.append(normalize_text(txt))
df_s['findings'] = findings

データセットの分割

from sklearn.model_selection import GroupShuffleSplit
gss_t = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=None)
gss_v = GroupShuffleSplit(n_splits=1, test_size=0.1, random_state=None)
train_index, test_index = next(iter(gss_t.split(df_s, None, df_s['uid'])))
train_index, val_index = next(iter(gss_v.split(df_s.iloc[train_index], None, df_s.iloc[train_index]['uid'])))

train_df = df_s.iloc[train_index]
val_df = df_s.iloc[val_index]
test_df = df_s.iloc[test_index]
print(train_df.shape) # (4656, 10)
print(val_df.shape) # (520, 10)
print(test_df.shape) # (1293, 10)

GITのプロセッサーをロード。

from transformers import AutoProcessor

# GIT (GenerativeImage2Text), base-sized
checkpoint = "microsoft/git-base"
processor = AutoProcessor.from_pretrained(checkpoint)# transformers.models.git.processing_git.GitProcessor

カスタムデータセットを定義。

class CustomImageTextDataset(Dataset):
    def __init__(self, image_paths, texts, processor):
        self.image_paths = image_paths.values
        self.texts = texts.values
        self.processor = processor
        # ここは必要
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor()])
    
    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self, idx):
        image = Image.open(self.image_paths[idx]).convert("RGB")
        image = self.transform(image)
        text = self.texts[idx]
        return image, text

モデルの精度評価用に、単語誤り率を利用。

from evaluate import load

# 単語誤り率(単語を対象とした[1-正解率])
wer = load("wer")

def compute_metrics(logits, labels):
    '''
    Word Error Rate
    '''
    predicted = logits.argmax(-1)
    decoded_labels = processor.batch_decode(labels, skip_special_tokens=True)
    decoded_predictions = processor.batch_decode(predicted, skip_special_tokens=True)
    wer_score = wer.compute(predictions=decoded_predictions, references=decoded_labels)
    return wer_score

バッチサイズを決める。

'''
BATCH SIZE
※ over 8 slices, will occur out of memory...
'''
batch_size = 8

データローダを定義。

train_dataset = CustomImageTextDataset(train_df['filename'], train_df['findings'], processor)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataset = CustomImageTextDataset(val_df['filename'], val_df['findings'], processor)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
test_dataset = CustomImageTextDataset(test_df['filename'], test_df['findings'], processor)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

訓練のコードを書く

def train_model(model,
                device,
                optimizer,
                train_loader,
                val_loader,
                scheduler,
                num_epochs=25,
                early_stop=5):
    since = time.time()
    model.to(device)
    best_model_wts = copy.deepcopy(model.state_dict())
    # 学習状況を記録するためのリスト
    train_loss_list = []
    train_wer_list = []
    val_loss_list = []
    val_wer_list = []
    
    # 最小の損失とその時のモデルを記録する変数
    min_loss = float('inf')
    best_wer = float('inf')
    patience = early_stop  # 検証の損失が改善するまで許容するエポック数
    early_stop_counter = 0
    
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch+1, num_epochs))
        print('-' * 10)

        # 各エポックには訓練フェーズと検証フェーズがあります
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # モデルを訓練モードに設定します
                dataloader = train_loader
            else:
                model.eval()   # モデルを評価モードに設定します(dropoutやbatch normのため)
                dataloader = val_loader

            running_loss = 0.0
            running_wers = 0
            samples = 0
            
            for images, texts in dataloader:
                inputs = processor(text=texts,
                                   images=images,
                                   return_tensors='pt',
                                   padding=True,       # パディングを有効にする
                                   #truncation=True,    # トランケーションを有効にする
                                   #max_length=709 #512      # 必要に応じて最大長を指定
                                  )
                
                images = inputs['pixel_values'].to(device)
                texts = inputs['input_ids'].to(device)
                attention_masks = inputs['attention_mask'].to(device)
                labels = torch.tensor(texts, device=device)
                
                optimizer.zero_grad()
                
                # 順伝播
                # 訓練の時だけ、勾配を計算します
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(input_ids=texts, 
                                    attention_mask=attention_masks, 
                                    pixel_values=images, 
                                    labels=labels)
                    loss = outputs.loss # returned when labels is provided
                    wer_in_batch = compute_metrics(outputs.logits, labels)

                    # 訓練の時だけ逆伝播+最適化関数による更新を行います
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                
                # 損失を計算します
                running_loss += loss.item()  # lossはバッチ単位で計算
                running_wers += wer_in_batch # バッチ単位のWER
                samples += len(images)
            
            epoch_loss = running_loss / samples
            epoch_wer = running_wers / samples
            
            if phase == 'train':
                train_loss_list.append(epoch_loss)
                train_wer_list.append(epoch_wer)
            else:
                scheduler.step(epoch_loss)
                val_loss_list.append(epoch_loss)
                val_wer_list.append(epoch_wer)
            
            print('{} Loss: {:.4f} WER: {:.4f}'.format(phase, epoch_loss, epoch_wer))
            
            # モデルをディープ・コピーします
            if phase == 'val' and epoch_loss < min_loss:
                min_loss = epoch_loss
                early_stop_counter = 0  # カウンターをリセット
                if epoch_wer <= best_wer:
                    best_wer = epoch_wer
                    best_model_wts = copy.deepcopy(model.state_dict())
            else:
                if phase == 'val':
                    early_stop_counter += 1
        
        # Early Stoppingのチェック
        if early_stop_counter >= patience:
            print()
            print(f'Early stopping at epoch {epoch+1}')
            break
        
        print()
        
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Loss: {:.4f} WER: {:.4f}'.format(min_loss, best_wer))
    print()
    # AccuracyとLossの可視化
    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    plt.plot(train_loss_list, label='Training Loss')
    plt.plot(val_loss_list, label='Val Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.subplot(1, 2, 2)
    plt.plot(train_wer_list, label='Training WER')
    plt.plot(val_wer_list, label='Val WER')
    plt.xlabel('Epoch')
    plt.ylabel('WER')
    plt.legend()
    plt.show()

    # ベストモデルの重みをロードします
    model.load_state_dict(best_model_wts)
    return model

訓練済みGITモデルをロードして、転移学習。

from transformers import AutoModelForCausalLM
model = None
model = AutoModelForCausalLM.from_pretrained(checkpoint)

model_id = '5e-5'

learning_rate = 5e-5 # keep small (1e-3 was failed in my experience)
num_epochs = 30 
# Define the optimizer
optimizer = AdamW(model.parameters(), lr=learning_rate)

# 学習中、損失が改善しなくなったら、自動的に学習率を0.3倍するようにスケジュールします。
from torch.optim import lr_scheduler
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.3, patience=2, min_lr=1e-6)

GIT_Che = train_model(model,
                      device,
                      optimizer,
                      train_loader,
                      val_loader,
                      scheduler,
                      num_epochs=num_epochs,
                      early_stop=5)        

download.png

↑werのカーブが0.8前後で動かなくなってしまっているが、学習を何回か試すと、WERが下に綺麗に下がることもある。

モデルを保存

model_name = model_id+'_'+'model_weights.pth'
torch.save(GIT_Che.state_dict(), model_name)

ロードしたいときはこのようにする。

# load example
test_model = AutoModelForCausalLM.from_pretrained(checkpoint)
test_model.load_state_dict(torch.load(model_name))

テストのためのコードを書く。

'''
TEST
'''
def test(model,
         device,
         test_loader):
    since = time.time()
    model.to(device)
    
    model.eval()   # モデルを評価モードに設定します(dropoutやbatch normのため)
    dataloader = test_loader
    running_loss = 0.0
    running_wers = 0
    samples = 0
           
    for images, texts in dataloader:
        inputs = processor(text=texts,
                           images=images,
                           return_tensors='pt',
                           padding=True,       # パディングを有効にする
                           #truncation=True,    # トランケーションを有効にする
                           # max_length=709 #512      # 必要に応じて最大長を指定
                           )
                
        images = inputs['pixel_values'].to(device)
        texts = inputs['input_ids'].to(device)
        attention_masks = inputs['attention_mask'].to(device)
        labels = torch.tensor(texts, device=device)
        optimizer.zero_grad()
            
        with torch.set_grad_enabled(False):
            outputs = model(input_ids=texts, 
                            attention_mask=attention_masks, 
                            pixel_values=images, 
                            labels=labels)
            loss = outputs.loss # returned when labels is provided
            wer_in_batch = compute_metrics(outputs.logits, labels)
            # 損失を計算します
            running_loss += loss.item()  # バッチ単位のロス
            running_wers += wer_in_batch # バッチ単位のWER
            samples += len(images)
            
    loss = running_loss / samples
    wer = running_wers / samples
    
    print('{} Loss: {:.4f} WER: {:.4f}'.format('TEST', loss, wer))
            
    time_elapsed = time.time() - since
    print('Test complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print()

テスト時のlossとwerを確認する。これが小さくなるほど良い。

test(test_model,
     device,
     test_loader)
TEST Loss: 0.0435 WER: 0.8338
Test complete in 1m 16s

バッチでなく、1例ずつテストしたいときはこのようにもできる。

def test_sample(index, df_s, model, processor, device):
    ```
    indexはdf_sの行番号
    ```
    sample = df_s.iloc[index,]
    print("UID:", sample.uid)
    findings = sample['findings']
    filepath = sample['filename']
    proj = sample['projection']
    img = Image.open(filepath)
    rgb = img.convert('RGB')
    plt.title(proj)
    plt.imshow(rgb)
    plt.show()
    plt.pause(0.01)
    model.to(device)
    model.eval()
    
    transform = transforms.Compose([transforms.Resize((224, 224)),
                                    transforms.ToTensor()])
    
    image_tensor = transform(img).unsqueeze(0).to(device)
    
    pixel_values = processor(images=img, return_tensors="pt").pixel_values.to(device)
    with torch.no_grad():
        generated_ids = model.generate(pixel_values=pixel_values, max_length=50)
        generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        print('='*50)
        print("Findings:\n", findings)
        print('='*50)
        print('Prediction:')
        print("\033[34m"+generated_caption+"\033[0m")
test_sample(5000, df_s, test_model, processor, device)

image.png

Findings:
 stable appearance of previous sternotomy stable cardiomegaly stable mild bilateral interstitial opacity in which may represent mild pulmonary edema no evidence of large pleural effusion or pneumothorax
 (肺は両側とも明瞭であり、特に局所的な圧密気胸や胸水貯留の所見はない。 胸郭の骨格構造も視認され、縦隔シルエットに異常はない。)
Prediction:
the lung are clear bilaterally specifically no evidence of focal consolidation pneumothorax or pleural effusion cardio mediastinal silhouette is unremarkable visualized osseous structure of the thorax are without acute
(安定した心肥大 安定した軽度の両側間質性混濁 軽度の肺水腫の可能性あり 大きな胸水や気胸の所見なし)

どのようなトークン(学習させた文字)が注目されたかを知るには、attention mapを使う。

def calc_attention_map(index, df_s, model, processor, device):
    sample = df_s.iloc[index,]
    filepath = sample['filename']
    img = Image.open(filepath)
    text = sample['findings']
    
    inputs = processor(text=text, images=img, return_tensors='pt',padding="max_length")
    images = inputs['pixel_values'].to(device)
    texts = inputs['input_ids'].to(device)
    
    pixel_values = processor(images=img, return_tensors="pt").pixel_values.to(device)
    outputs = model(input_ids=texts, pixel_values=pixel_values, output_attentions=True)
    # Extract attention weights
    attentions = outputs.attentions
    '''
     attentions: [layer](batch_size, num_heads, sequence_length, sequence_length).
    '''
    # print(len(attentions)) # 6 layers
    # Example: Plot attention map for layer 0, head 0
    plot_attention_map(attentions, layer=0, head=0)
    
def plot_attention_map(attentions, layer, head):
    '''
    token to token visualization map
    e.g, https://captum.ai/tutorials/Bert_SQUAD_Interpret2
    
    layer: 0 to 5
    batch: バッチ次元
    head: 0 to 11 (マルチヘッドアテンション(multi-head attention)のヘッド)
    sequence_length: モデルが処理するトークン(単語やサブワードなど)の数
    '''
    attn_at_layer = attentions[layer]
    # torch.Size([1, 12, 709, 709])
    # print(attn_at_layer.shape) # batch, heads, sequence_length, sequence_length
    attention = attn_at_layer[0, head].detach().cpu().numpy()
    fig, ax = plt.subplots(figsize=(10, 10))
    cax = ax.matshow(attention, cmap='viridis')
    # fig.colorbar(cax)
    plt.title(f"Layer {layer+1}, Head {head+1}")
    plt.show()
calc_attention_map(5000, df_s, test_model, processor, device)

download.png

確認したい層やヘッドは適宜変える。

おわり

これを、CTやMRIなどに応用すれば良い。

References

stay visionary

0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1