In this article, we build a simplified VLM (Vision Language Model) from scratch.
References
This article was written with reference to the following sources.
VLM
Source: Implement a Vision Language Model from Scratch
Importing the required libraries
import base64
import io
import pandas as pd
from PIL import Image
import torchvision.transforms as transforms
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn import init
Setting the device
# Use the GPU when available, otherwise fall back to the CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
Setting the parameters
batch_size = 16 # how many independent sequences will we process in parallel?
block_size = 32 # what is the maximum context length for predictions?
max_iters = 100
eval_interval = 10
learning_rate = 1e-3
epochs = 1
eval_iters = 40
num_blks = 3
head_size = 16
n_embd = 128
n_head = 8
n_layer = 8
dropout = 0.1
img_size = 96
patch_size = 16
num_hiddens = 512
image_embed_dim = 512
emb_dropout = blk_dropout = 0.1
Dataset
The following datasets are used in this article:
- input.txt
- inputs.csv
Both files can be downloaded from here.
Tokenizer
# To build the encoding and decoding functions we use the tinyshakespeare dataset. However, for the sake of brevity we do not pretrain the decoder model on it.
# The training function would be able to do that without an issue as well, since it can take both images and text.
text_path = "./input.txt"
with open(text_path, 'r', encoding='utf-8') as f:
text = f.read()
# here are all the unique characters that occur in this text
chars = sorted(list(set(text)))
# create a mapping from characters to integers
stoi = { ch:i for i,ch in enumerate(chars) }
# stoi[''] = 65
stoi['<pad>'] = 65
itos = { i:ch for i,ch in enumerate(chars) }
# itos[65] = ''
itos[65] = '<pad>'
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string
vocab_size = len(stoi.keys())
print(vocab_size)
# 66
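As a quick sanity check (my own addition, not part of the original article), encode and decode should round-trip any string whose characters all appear in input.txt.
# Round-trip check for the character-level tokenizer
sample = "a red sports car"
ids = encode(sample)  # one integer id per character
print(decode(ids) == sample)
# True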
Data loader
def base64_to_tensor(base64_str, img_size=96):
image = Image.open(io.BytesIO(base64.b64decode(base64_str)))
if image.mode != 'RGB':
image = image.convert('RGB')
transform = transforms.Compose([
transforms.Resize((img_size, img_size)),
transforms.ToTensor(),
# transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
return transform(image).unsqueeze(0) # Add batch dimension
#Adjusting the data loader from makemore for multimodal data
def get_batch(df, batch_size, split='train', img_size=96, val_batch_size=8):
# Split data into training and validation sets
n = int(0.9 * len(df)) # first 90% will be train, rest val
df_train = df.iloc[:n]
df_val = df.iloc[n:]
data = df_train if split == 'train' else df_val
batch_size = batch_size if split == 'train' else val_batch_size
replace = False if split == 'train' else True
batch = data.sample(n=batch_size, replace=replace)
images = torch.cat([base64_to_tensor(img, img_size) for img in batch['b64string_images']], dim=0).to(device)
text_indices = [torch.tensor(encode(desc), dtype=torch.long) for desc in batch['caption']]
max_length = max(len(t) for t in text_indices)
padded_text = torch.full((batch_size, max_length), fill_value=stoi['<pad>'], dtype=torch.long).to(device)
for i, text in enumerate(text_indices):
padded_text[i, :len(text)] = text
targets = torch.cat([padded_text[:, 1:], torch.full((batch_size, 1), fill_value=stoi['<pad>'], dtype=torch.long, device=device)], dim=1)
# Truncate or pad targets to match the length of padded_text
if targets.size(1) > padded_text.size(1):
targets = targets[:, :padded_text.size(1)]
elif targets.size(1) < padded_text.size(1):
targets = torch.cat([targets, torch.full((batch_size, padded_text.size(1) - targets.size(1)), fill_value=stoi['<pad>'], dtype=torch.long, device=device)], dim=1)
return images, padded_text, targets
df = pd.read_csv("images/inputs.csv")
# Expanding dataframe so that there's enough data to test. This is just duplicating data. A real dataset would have more rows
df = pd.concat([df] * 30)[['b64string_images', 'caption']]
print(df.shape)
# (90, 2)
images, idx, targets = get_batch(df, batch_size, 'train', img_size)
print(images.shape)
# torch.Size([16, 3, 96, 96])
print(idx.shape)
# torch.Size([16, 32])
print(targets.shape)
# torch.Size([16, 32])
from PIL import Image
import matplotlib.pyplot as plt
plt.imshow(images[0].to('cpu').detach().numpy().copy().transpose(1,2,0))
plt.show()
print(idx[0])
# tensor([57, 54, 39, 41, 43, 57, 46, 47, 54, 1, 53, 56, 40, 47, 58, 47, 52, 45, 1, 25, 39, 56, 57, 65, 65, 65, 65, 65, 65, 65, 65, 65])
print(decode(idx[0].to('cpu').detach().numpy().copy()))
# spaceship orbiting Mars<pad><pad><pad><pad><pad><pad><pad><pad><pad>
print(targets[0])
# tensor([54, 39, 41, 43, 57, 46, 47, 54, 1, 53, 56, 40, 47, 58, 47, 52, 45, 1, 25, 39, 56, 57, 65, 65, 65, 65, 65, 65, 65, 65, 65, 65])
print(decode(targets[0].to('cpu').detach().numpy().copy()))
# paceship orbiting Mars<pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>
Implementing the VLM
The VLM we develop here is the VisionLanguageModel class shown below.
class VisionLanguageModel(nn.Module):
def __init__(self, n_embd, image_embed_dim, vocab_size, n_layer, img_size, patch_size, num_heads, num_blks, emb_dropout, blk_dropout):
super().__init__()
# Set num_hiddens equal to image_embed_dim
num_hiddens = image_embed_dim
# Assert that num_hiddens is divisible by num_heads
assert num_hiddens % num_heads == 0, "num_hiddens must be divisible by num_heads"
# Initialize the vision encoder (ViT)
self.vision_encoder = ViT(img_size, patch_size, num_hiddens, num_heads, num_blks, emb_dropout, blk_dropout)
# Initialize the language model decoder (DecoderLanguageModel)
self.decoder = DecoderLanguageModel(n_embd, image_embed_dim, vocab_size, num_heads, n_layer, use_images=True)
def forward(self, img_array, idx, targets=None):
# Get the image embeddings from the vision encoder
image_embeds = self.vision_encoder(img_array)
# Check if the image embeddings are valid
if image_embeds.nelement() == 0 or image_embeds.shape[1] == 0:
raise ValueError("Something is wrong with the ViT model. It's returning an empty tensor or the embedding dimension is empty.")
if targets is not None:
# If targets are provided, compute the logits and loss
logits, loss = self.decoder(idx, image_embeds, targets)
return logits, loss
else:
# If targets are not provided, compute only the logits
logits = self.decoder(idx, image_embeds)
return logits
def generate(self, img_array, idx, max_new_tokens):
# Get the image embeddings from the vision encoder
image_embeds = self.vision_encoder(img_array)
# Check if the image embeddings are valid
if image_embeds.nelement() == 0 or image_embeds.shape[1] == 0:
raise ValueError("Something is wrong with the ViT model. It's returning an empty tensor or the embedding dimension is empty.")
# Generate new tokens using the language model decoder
generated_tokens = self.decoder.generate(idx, image_embeds, max_new_tokens)
return generated_tokens
In the following, we implement the four components required by the VisionLanguageModel class:
- TransfomerBlock: the Transformer block used in the implementations of both DecoderLanguageModel and ViT
- DecoderLanguageModel: a decoder-only language model that generates text
- MultiModalProjector: a vision-language projector that projects image features into the text embedding space
- ViT: an image encoder that extracts image features from the input image
TransfomerBlock
class TransfomerBlock(nn.Module):
def __init__(self, n_embd, num_heads, dropout=0.1, is_decoder=False):
super().__init__()
# Layer normalization for the input to the attention layer
self.ln1 = nn.LayerNorm(n_embd)
# Multi-head attention module
self.attn = MultiHeadAttention(n_embd, num_heads, dropout, is_decoder)
# Layer normalization for the input to the FFN
self.ln2 = nn.LayerNorm(n_embd)
# Feed-forward neural network (FFN)
self.ffn = nn.Sequential(
nn.Linear(n_embd, 4 * n_embd), # Expand the dimension
nn.GELU(), # Activation function
nn.Linear(4 * n_embd, n_embd), # Project back to the original dimension
)
def forward(self, x):
original_x = x # Save the input for the residual connection
# Apply layer normalization to the input
x = self.ln1(x)
# Apply multi-head attention
attn_output = self.attn(x)
# Add the residual connection (original input) to the attention output
x = original_x + attn_output
# Apply layer normalization to the input to the FFN
x = self.ln2(x)
# Apply the FFN
ffn_output = self.ffn(x)
# Add the residual connection (input to FFN) to the FFN output
x = x + ffn_output
return x
MultiHeadAttention
class MultiHeadAttention(nn.Module):
def __init__(self, n_embd, num_heads, dropout=0.1, is_decoder=False):
super().__init__()
# Ensure that the embedding dimension is divisible by the number of heads
assert n_embd % num_heads == 0, "n_embd must be divisible by num_heads"
# Create a ModuleList of attention heads
self.heads = nn.ModuleList([
Head(n_embd, n_embd // num_heads, dropout, is_decoder)
for _ in range(num_heads)
])
# Linear layer for projecting the concatenated head outputs
self.proj = nn.Linear(n_embd, n_embd)
# Dropout layer for regularization
self.dropout = nn.Dropout(dropout)
def forward(self, x):
# Apply each attention head to the input tensor
head_outputs = [h(x) for h in self.heads]
# Concatenate the outputs from all heads along the last dimension
out = torch.cat(head_outputs, dim=-1)
# Apply the projection layer to the concatenated outputs
out = self.proj(out)
# Apply dropout to the projected outputs for regularization
out = self.dropout(out)
return out
Head
class Head(nn.Module):
def __init__(self, n_embd, head_size, dropout=0.1, is_decoder=False):
super().__init__()
# Linear layer for key projection
self.key = nn.Linear(n_embd, head_size, bias=False)
# Linear layer for query projection
self.query = nn.Linear(n_embd, head_size, bias=False)
# Linear layer for value projection
self.value = nn.Linear(n_embd, head_size, bias=False)
# Dropout layer for regularization
self.dropout = nn.Dropout(dropout)
# Flag indicating whether this head is used in the decoder
self.is_decoder = is_decoder
def forward(self, x):
# Get the batch size (B), sequence length (T), and embedding dimension (C) from the input tensor
B, T, C = x.shape
# Compute key, query, and value projections
k = self.key(x) # Shape: [B, T, head_size]
q = self.query(x) # Shape: [B, T, head_size]
v = self.value(x) # Shape: [B, T, head_size]
# Compute attention scores by taking the dot product of query and key
# and scaling by the square root of the embedding dimension
wei = q @ k.transpose(-2, -1) * (C ** -0.5) # Shape: [B, T, T]
if self.is_decoder:
# If this head is used in the decoder, apply a causal mask to the attention scores
# to prevent attending to future positions
tril = torch.tril(torch.ones(T, T, dtype=torch.bool, device=x.device))
wei = wei.masked_fill(tril == 0, float('-inf'))
# Apply softmax to the attention scores to obtain attention probabilities
wei = F.softmax(wei, dim=-1) # Shape: [B, T, T]
# Apply dropout to the attention probabilities for regularization
wei = self.dropout(wei)
# Perform weighted aggregation of values using the attention probabilities
out = wei @ v # Shape: [B, T, head_size]
return out
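The causal mask in Head.forward is what keeps a decoder position from attending to future tokens. A minimal sketch of my own, with a toy sequence length of T = 4, shows its effect:
# Each row i of the attention matrix keeps positions 0..i and zeroes out everything after them
T = 4
tril = torch.tril(torch.ones(T, T, dtype=torch.bool))
scores = torch.zeros(T, T).masked_fill(tril == 0, float('-inf'))
print(F.softmax(scores, dim=-1))
# Row 0 puts all weight on position 0, row 1 spreads it over positions 0-1, and so on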
DecoderLanguageModel
class DecoderLanguageModel(nn.Module):
def __init__(self, n_embd, image_embed_dim, vocab_size, num_heads, n_layer, use_images=False):
super().__init__()
self.use_images = use_images
# Token embedding table
self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
# Position embedding table
self.position_embedding_table = nn.Embedding(1000, n_embd)
if use_images:
# Image projection layer to align image embeddings with text embeddings
self.image_projection = MultiModalProjector(n_embd, image_embed_dim)
# Stack of transformer decoder blocks
self.blocks = nn.Sequential(*[TransfomerBlock(n_embd, num_heads, is_decoder=True) for _ in range(n_layer)])
# Final layer normalization
self.ln_f = nn.LayerNorm(n_embd)
# Language modeling head
self.lm_head = nn.Linear(n_embd, vocab_size)
def forward(self, idx, image_embeds=None, targets=None):
# Get token embeddings from the input indices
tok_emb = self.token_embedding_table(idx)
if self.use_images and image_embeds is not None:
# Project and concatenate image embeddings with token embeddings
img_emb = self.image_projection(image_embeds).unsqueeze(1)
tok_emb = torch.cat([img_emb, tok_emb], dim=1)
# Get position embeddings
pos_emb = self.position_embedding_table(torch.arange(tok_emb.size(1), device=device)).unsqueeze(0)
# Add position embeddings to token embeddings
x = tok_emb + pos_emb
# Pass through the transformer decoder blocks
x = self.blocks(x)
# Apply final layer normalization
x = self.ln_f(x)
# Get the logits from the language modeling head
logits = self.lm_head(x)
if targets is not None:
if self.use_images and image_embeds is not None:
# Prepare targets by concatenating a dummy target for the image embedding
batch_size = idx.size(0)
targets = torch.cat([torch.full((batch_size, 1), -100, dtype=torch.long, device=device), targets], dim=1)
# Compute the cross-entropy loss
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-100)
return logits, loss
return logits
def generate(self, idx, image_embeds, max_new_tokens):
# Get the batch size and sequence length
B, T = idx.shape
# Initialize the generated sequence with the input indices
generated = idx
if self.use_images and image_embeds is not None:
# Project and concatenate image embeddings with token embeddings
img_emb = self.image_projection(image_embeds).unsqueeze(1)
current_output = torch.cat([img_emb, self.token_embedding_table(idx)], dim=1)
else:
current_output = self.token_embedding_table(idx)
# Generate new tokens iteratively
for i in range(max_new_tokens):
# Get the current sequence length
T_current = current_output.size(1)
# Get position embeddings for the current sequence length
current_pos_emb = self.position_embedding_table(torch.arange(T_current, device=device)).unsqueeze(0)
# Add position embeddings to the current output
current_output += current_pos_emb
# Pass through the transformer decoder blocks
for block in self.blocks:
current_output = block(current_output)
# Get the logits for the last token
logits = self.lm_head(current_output[:, -1, :])
# Apply softmax to get probabilities
probs = F.softmax(logits, dim=-1)
# Sample the next token based on the probabilities
idx_next = torch.multinomial(probs, num_samples=1)
# Concatenate the generated token to the generated sequence
generated = torch.cat((generated, idx_next), dim=1)
# Get the embeddings for the generated token
idx_next_emb = self.token_embedding_table(idx_next)
# Concatenate the generated token embeddings to the current output
current_output = torch.cat((current_output, idx_next_emb), dim=1)
return generated
MultiModalProjector
class MultiModalProjector(nn.Module):
def __init__(self, n_embd, image_embed_dim, dropout=0.1):
super().__init__()
# Define the projection network
self.net = nn.Sequential(
# Linear layer to expand the image embedding dimension
nn.Linear(image_embed_dim, 4 * image_embed_dim),
# GELU activation function
nn.GELU(),
# Linear layer to project the expanded image embeddings to the text embedding dimension
nn.Linear(4 * image_embed_dim, n_embd),
# Dropout layer for regularization
nn.Dropout(dropout)
)
def forward(self, x):
# Pass the input through the projection network
x = self.net(x)
return x
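As a quick shape check (a sketch of my own, reusing the hyperparameters defined at the top), the projector maps a batch of 512-dimensional image embeddings into the 128-dimensional text embedding space:
proj = MultiModalProjector(n_embd, image_embed_dim)
print(proj(torch.randn(4, image_embed_dim)).shape)
# torch.Size([4, 128]) -> one projected embedding per image in the batch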
ViT
The representation that is finally returned is the embedding corresponding to the CLS token, and it is used to condition text generation in the language decoder.
Source: Implement a Vision Language Model from Scratch
class ViT(nn.Module):
def __init__(self, img_size, patch_size, num_hiddens, num_heads, num_blks, emb_dropout, blk_dropout):
super().__init__()
# Patch embedding layer to convert the input image into patches
self.patch_embedding = PatchEmbeddings(img_size, patch_size, num_hiddens)
# Learnable classification token
self.cls_token = nn.Parameter(torch.zeros(1, 1, num_hiddens))
# Calculate the number of patches
num_patches = (img_size // patch_size) ** 2
# Learnable position embedding
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, num_hiddens))
# Dropout layer for the embeddings
self.dropout = nn.Dropout(emb_dropout)
# Stack of transformer blocks
self.blocks = nn.ModuleList([TransfomerBlock(num_hiddens, num_heads, blk_dropout, is_decoder=False) for _ in range(num_blks)])
# Layer normalization for the final representation
self.layer_norm = nn.LayerNorm(num_hiddens)
def forward(self, X):
# Convert the input image into patch embeddings
x = self.patch_embedding(X)
# Expand the classification token to match the batch size
cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
# Concatenate the classification token with the patch embeddings
x = torch.cat((cls_tokens, x), dim=1)
# Add the position embedding to the patch embeddings
x += self.pos_embedding
# Apply dropout to the embeddings
x = self.dropout(x)
# Pass the embeddings through the transformer blocks
for block in self.blocks:
x = block(x)
# Apply layer normalization to the final representation
x = self.layer_norm(x[:, 0])
return x
PatchEmbeddings
In the code below, the input image is split into (img_size // patch_size) ** 2 patches by a convolutional layer, and each patch is projected into a 512-dimensional vector.
class PatchEmbeddings(nn.Module):
def __init__(self, img_size=96, patch_size=16, hidden_dim=512):
super().__init__()
# Store the input image size
self.img_size = img_size
# Store the size of each patch
self.patch_size = patch_size
# Calculate the total number of patches
self.num_patches = (img_size // patch_size) ** 2
# Create a convolutional layer to extract patch embeddings
# in_channels=3 assumes the input image has 3 color channels (RGB)
# out_channels=hidden_dim sets the number of output channels to match the hidden dimension
# kernel_size=patch_size and stride=patch_size ensure each patch is separately embedded
self.conv = nn.Conv2d(in_channels=3, out_channels=hidden_dim,
kernel_size=patch_size, stride=patch_size)
def forward(self, X):
# Extract patch embeddings from the input image
X = self.conv(X)
# Flatten the spatial dimensions (height and width) of the patch embeddings
# This step flattens the patch dimensions into a single dimension
X = X.flatten(2)
# Transpose the dimensions to obtain the shape [batch_size, num_patches, hidden_dim]
# This step brings the num_patches dimension to the second position
X = X.transpose(1, 2)
return X
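With PatchEmbeddings in place, a short shape check (my own sketch, reusing the hyperparameters defined at the top) confirms the behaviour described above: the 96x96 image is split into (96 // 16) ** 2 = 36 patch embeddings, and the ViT collapses them into a single 512-dimensional CLS embedding per image.
patcher = PatchEmbeddings(img_size, patch_size, num_hiddens)
print(patcher(torch.randn(1, 3, img_size, img_size)).shape)
# torch.Size([1, 36, 512]) -> 36 patches, each embedded into 512 dimensions
vit = ViT(img_size, patch_size, num_hiddens, n_head, num_blks, emb_dropout, blk_dropout)
print(vit(torch.randn(2, 3, img_size, img_size)).shape)
# torch.Size([2, 512]) -> one CLS embedding per image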
Training the VLM
# Adjusting the training loop from makemore for multimodal data
def train_model(model, df, epochs, vocab_size, img_size=96):
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
model.to(device)
for epoch in range(epochs):
model.train()
for _ in range(max_iters):
images, idx, targets = get_batch(df, batch_size, 'train', img_size)
optimizer.zero_grad()
logits, loss = model(images, idx, targets)
loss.backward()
optimizer.step()
if _ % eval_interval == 0:
print(f"Loss at iteration {_}: {loss.item()}")
val_loss = estimate_loss(model, df, 'val', img_size, val_batch_size=8)
print(f"Validation Loss after epoch {epoch+1}: {val_loss}")
def estimate_loss(model, df, split, img_size=96, val_batch_size=8):
losses = []
model.eval()
for _ in range(eval_iters):
images, idx, targets = get_batch(df, batch_size, split, img_size, val_batch_size=val_batch_size)
_, loss = model(images, idx, targets)
losses.append(loss.item())
return sum(losses) / len(losses)
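The call to train_model below assumes that a model instance already exists; its construction is not shown above, so here is a minimal sketch using the hyperparameters defined at the top (n_head is passed as num_heads). Note that train_model moves the model to device internally.
# Instantiate the VLM with the hyperparameters defined earlier
model = VisionLanguageModel(n_embd, image_embed_dim, vocab_size, n_layer, img_size, patch_size, n_head, num_blks, emb_dropout, blk_dropout)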
train_model(model, df, epochs, vocab_size, img_size)
# Loss at iteration 0: 4.2010297775268555
# Loss at iteration 10: 0.48016276955604553
# Loss at iteration 20: 0.09957580268383026
# Loss at iteration 30: 0.051468685269355774
# Loss at iteration 40: 0.044466760009527206
# Loss at iteration 50: 0.03906780481338501
# Loss at iteration 60: 0.03154975175857544
# Loss at iteration 70: 0.02567865699529648
# Loss at iteration 80: 0.016583602875471115
# Loss at iteration 90: 0.0224172230809927
# Validation Loss after epoch 1: 0.021268921508453786
Inference with the VLM
images, idx, _ = get_batch(df, batch_size, 'val', img_size)
print(images.shape)
# torch.Size([8, 3, 96, 96])
print(idx.shape)
# torch.Size([8, 32])
from PIL import Image
import matplotlib.pyplot as plt
plt.imshow(images[0].to('cpu').detach().numpy().copy().transpose(1,2,0))
plt.show()
print(idx[0])
# tensor([39, 1, 56, 43, 42, 1, 57, 54, 53, 56, 58, 57, 1, 41, 39, 56, 1, 61, 47, 58, 46, 1, 40, 47, 45, 1, 61, 46, 43, 43, 50, 57])
print(decode(idx[0].to('cpu').detach().numpy().copy()))
# a red sports car with big wheels
output_ids = model.generate(images, idx, 8)
print(decode(output_ids[0].to('cpu').detach().numpy().copy()))
# a red sports car with big wheels<pad><pad><pad><pad><pad><pad><pad><pad>
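Because the prompt above already contains the full caption, the model mostly emits <pad> tokens. As a sketch of my own that exercises the image conditioning a bit more, we can start generation from a single character and let the model continue based on the image embedding:
# Generate 30 new tokens for the first validation image, starting from the character "a"
start = torch.tensor([encode("a")], dtype=torch.long, device=device)
output_ids = model.generate(images[:1], start, 30)
print(decode(output_ids[0].to('cpu').detach().numpy().copy()))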
Summary
In this article, we developed a simplified VLM (Vision Language Model) from scratch.
For more details, please refer to the article linked below.
Going forward, I would also like to build an ALM (Audio Language Model), Stable Diffusion (an image generation model), and similar models from scratch.