Motivation for this code and why I am publishing it
In image captioning, I am running PPO-based reinforcement learning that takes a CRF into account. In this kind of reinforcement learning, the captions used to compute sampled log-probabilities and rewards are usually drawn stochastically with multinomial sampling, while the captions used to evaluate the generated output are usually produced greedily with argmax (max). My impression is that the Viterbi algorithm is greedy: it works for the inference captions, but not for the stochastically sampled captions needed to compute rewards. This led me to ask whether a Viterbi algorithm that samples stochastically could be devised, and I would like to publish the result here. Whether it is theoretically sound is left to the reader's judgment.
The reinforcement-learning side has not yet produced results good enough to publish online, so here I am only releasing the stochastic Viterbi sampling code. Even if it turns out not to help in reinforcement learning, I think it can at least add one more variant of Viterbi decoding.
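To make the contrast concrete, here is a minimal sketch of the usual token-level setup described above, assuming per-position logits from some decoder; the shapes and variable names are illustrative only and are not taken from my training code.

import torch
import torch.nn.functional as F

logits = torch.randn(8, 97, 30000)     # stand-in for decoder output: (bsz, seq_len, vocab_size)
# inference caption: greedy argmax at every position
greedy_tokens = logits.argmax(dim=-1)  # (bsz, seq_len)
# reward caption: multinomial sampling at every position
probs = F.softmax(logits, dim=-1)
sampled_tokens = torch.multinomial(
    probs.view(-1, probs.size(-1)), num_samples=1).view(8, 97)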
Actual generated captions
Captions actually generated during reinforcement learning. refe is the target caption, hypo is the caption generated with the greedy Viterbi algorithm, and samp is the caption generated with stochastic Viterbi sampling.
refe: [CLS] in this picture there are boats and there is a building. at the back there are trees and there are trees on the mountains. at the top there is sky. at the bottom there is water and sand. [SEP]
hypo: [CLS] in this image we can see the water left side of the water, we can see the background we can see the middle of the ground and there are trees. [SEP]
samp: [CLS] in this picture we can see the water, some trees. i can see the right side of the background we can see the middle of the ground and trees. [SEP]
Code that generates a single sample
The emissions argument of the function is assumed to be BERT's last_hidden_state projected with nn.LayerNorm and nn.Linear to shape (bsz, seq_len, vocab_size).
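For reference, emissions can be prepared along these lines. This is only a sketch of the projection described above; the helper name EmissionHead, the hidden size of 768, and the random stand-in for last_hidden_state are my own assumptions.

import torch
import torch.nn as nn

class EmissionHead(nn.Module):
    # hypothetical helper: LayerNorm followed by a Linear projection to the vocabulary
    def __init__(self, hidden_size=768, vocab_size=30000):
        super().__init__()
        self.norm = nn.LayerNorm(hidden_size)
        self.proj = nn.Linear(hidden_size, vocab_size)

    def forward(self, last_hidden_state):                # (bsz, seq_len, hidden_size)
        return self.proj(self.norm(last_hidden_state))   # (bsz, seq_len, vocab_size)

head = EmissionHead()
last_hidden_state = torch.randn(8, 97, 768)   # stand-in for bert(...).last_hidden_state
emissions = head(last_hidden_state)           # (8, 97, 30000)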
import torch
import torch.nn as nn
import torch.nn.functional as F


class StochasticViterbiSample(nn.Module):
    def __init__(self, num_embedding, low_rank=32, beam_size=256, temp=1.0):
        super().__init__()
        # low-rank factorization of the (vocab x vocab) transition matrix
        self.E1 = nn.Embedding(num_embedding, low_rank)
        self.E2 = nn.Embedding(num_embedding, low_rank)
        self.rank = low_rank
        self.beam = beam_size
        self.temp = temp

    def _compute_stochastic_viterbi_sample(self, emissions, beam=None):
        device = emissions.device
        beam = beam if beam is not None else self.beam
        # keep only the top-`beam` vocabulary candidates per position
        beam_emission_scores, beam_targets = torch.topk(emissions, beam, 2)
        batch_size, seq_len = beam_emission_scores.size()[:2]
        # transition scores between adjacent positions, restricted to the beam
        beam_transition_score1 = self.E1(beam_targets[:, :-1])  # B x (T-1) x K x D
        beam_transition_score2 = self.E2(beam_targets[:, 1:])   # B x (T-1) x K x D
        beam_transition_matrix = torch.bmm(
            beam_transition_score1.view(-1, beam, self.rank),
            beam_transition_score2.view(-1, beam, self.rank).transpose(1, 2))
        beam_transition_matrix = beam_transition_matrix.view(
            batch_size, -1, beam, beam)  # bsz, seq_len - 1, beam, beam

        traj_tokens, traj_scores = [], []
        finalized_tokens, finalized_scores = [], []
        score = beam_emission_scores[:, 0]  # B x K
        for i in range(1, seq_len):
            traj_scores.append(score)
            _score = score[:, :, None] + beam_transition_matrix[:, i - 1]  # bsz, beam, beam
            # greedy selection
            # _score, _index = _score.max(dim=1)  # bsz, beam
            # multinomial selection: sample the previous beam state instead of taking the max
            B, C, W = _score.shape
            flat_score = _score.permute(0, 2, 1).reshape(-1, C)
            probs = F.softmax(flat_score / self.temp, dim=-1)
            _index_flat = torch.multinomial(probs, num_samples=1)
            _score_flat = torch.gather(flat_score, -1, _index_flat)
            _index = _index_flat.view(B, W)
            _score = _score_flat.view(B, W)
            _score = _score + beam_emission_scores[:, i]  # bsz, beam
            # if masks is not None:
            #     score = torch.where(masks[:, i: i+1], _score, score)
            #     index = torch.where(masks[:, i: i+1], _index, dummy)
            # else:
            score, index = _score, _index
            traj_tokens.append(index)

        # copy rather than alias, so traj_scores keeps seq_len - 1 entries for the backtrace
        all_scores = traj_scores + [score]
        all_scores = torch.stack(all_scores, dim=0).transpose(0, 1).to(device)  # bsz, seq_len, beam
        beam_probs = F.softmax(all_scores, dim=2)

        # now running the back-tracing and find the best
        best_score, best_index = score.max(dim=1)
        finalized_tokens.append(best_index[:, None])
        finalized_scores.append(best_score[:, None])
        for idx, scs in zip(reversed(traj_tokens), reversed(traj_scores)):
            previous_index = finalized_tokens[-1]
            finalized_tokens.append(idx.gather(1, previous_index))
            finalized_scores.append(scs.gather(1, previous_index))
        finalized_tokens.reverse()
        sampled_beam_idx = torch.cat(finalized_tokens, 1)  # bsz, seq_len (indices into the beam)
        finalized_tokens = beam_targets.gather(2, sampled_beam_idx[:, :, None])[:, :, 0]  # vocab ids
        finalized_scores.reverse()
        finalized_scores = torch.cat(finalized_scores, 1)
        finalized_scores[:, 1:] = finalized_scores[:, 1:] - finalized_scores[:, :-1]
        return beam_probs, sampled_beam_idx.unsqueeze(-1), finalized_tokens
test = StochasticViterbiSample(30000)
emissions = torch.randn((8, 97, 30000))
beam_probs, sampled_beam_idx, finalized_tokens = test._compute_stochastic_viterbi_sample(emissions)
print(beam_probs.size())
print(sampled_beam_idx.size())
print(finalized_tokens.size())
torch.Size([8, 97, 256])
torch.Size([8, 97, 1])
torch.Size([8, 97])
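One way these outputs might feed a PPO-style objective (this is my own assumption, not part of the code above) is to gather, for each position, the probability assigned to the sampled beam index:

eps = 1e-8
token_probs = beam_probs.gather(2, sampled_beam_idx).squeeze(-1)   # (bsz, seq_len)
token_log_probs = (token_probs + eps).log()                        # log-probs of the sampled path
sequence_log_prob = token_log_probs.sum(dim=1)                     # (bsz,)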
Code that generates multiple samples
Computing the baseline in GRPO requires multiple samples, so I wrote this version. The number of samples is set to 8.
class StochasticViterbiSamples(nn.Module):
    def __init__(self, num_embedding, low_rank=32, beam_size=256, temp=1.0, num_samples=8):
        super().__init__()
        self.E1 = nn.Embedding(num_embedding, low_rank)
        self.E2 = nn.Embedding(num_embedding, low_rank)
        self.rank = low_rank
        self.beam = beam_size
        self.temp = temp
        self.num_samples = num_samples

    def _compute_grpo_samples(self, emissions, masks=None, beam=None):
        device = emissions.device
        beam = beam if beam is not None else self.beam
        beam_emission_scores, beam_targets = torch.topk(emissions, beam, 2)
        batch_size, seq_len = beam_emission_scores.size()[:2]
        beam_transition_score1 = self.E1(beam_targets[:, :-1])  # B x (T-1) x K x D
        beam_transition_score2 = self.E2(beam_targets[:, 1:])   # B x (T-1) x K x D
        beam_transition_matrix = torch.bmm(
            beam_transition_score1.view(-1, beam, self.rank),
            beam_transition_score2.view(-1, beam, self.rank).transpose(1, 2))
        beam_transition_matrix = beam_transition_matrix.view(
            batch_size, -1, beam, beam)  # bsz, seq_len - 1, beam, beam

        traj_tokens, traj_scores = [], []
        finalized_tokens, finalized_scores = [], []
        # replicate the initial scores along a new sample dimension
        score = beam_emission_scores[:, 0, :, None].expand(-1, -1, self.num_samples)  # B x K x N
        for i in range(1, seq_len):
            traj_scores.append(score)
            _score = score[:, :, None] + beam_transition_matrix[:, i - 1, :, :, None]  # bsz, beam, beam, N
            # greedy selection
            # _score, _index = _score.max(dim=1)
            # multinomial selection: draw one previous beam state per (batch, current state, sample)
            B, C, W, _ = _score.shape
            N = self.num_samples
            flat_score = _score.permute(0, 2, 3, 1).reshape(-1, C)  # B * W * N, C
            probs = F.softmax(flat_score / self.temp, dim=-1)
            _index_flat = torch.multinomial(probs, num_samples=1, replacement=True)
            _score_flat = torch.gather(flat_score, -1, _index_flat)
            _index = _index_flat.view(B, W, N)
            _score = _score_flat.view(B, W, N)
            _score = _score + beam_emission_scores[:, i, :, None]  # bsz, beam, N
            # if masks is not None:
            #     score = torch.where(masks[:, i: i+1], _score, score)
            #     index = torch.where(masks[:, i: i+1], _index, dummy)
            # else:
            score, index = _score, _index  # bsz, beam, N
            traj_tokens.append(index)

        # copy rather than alias, so traj_scores keeps seq_len - 1 entries for the backtrace
        all_scores = traj_scores + [score]
        all_scores = torch.stack(all_scores, dim=0).transpose(0, 1).to(device)  # bsz, seq_len, beam, N
        beam_probs = F.softmax(all_scores, dim=2)  # bsz, seq_len, beam, N

        # now running the back-tracing and find the best
        best_score, best_index = score.max(dim=1)  # bsz, N
        finalized_tokens.append(best_index[:, None, :])  # bsz, 1, N
        finalized_scores.append(best_score[:, None, :])  # bsz, 1, N
        for idx, scs in zip(reversed(traj_tokens), reversed(traj_scores)):  # each: bsz, beam, N
            previous_index = finalized_tokens[-1]
            finalized_tokens.append(idx.gather(1, previous_index))
            finalized_scores.append(scs.gather(1, previous_index))
        finalized_tokens.reverse()  # list of seq_len tensors, each bsz x 1 x N
        sampled_beam_idx = torch.cat(finalized_tokens, 1)  # bsz, seq_len, N (indices into the beam)
        finalized_tokens = beam_targets.gather(2, sampled_beam_idx)  # bsz, seq_len, N (vocab ids)
        finalized_scores.reverse()
        finalized_scores = torch.cat(finalized_scores, 1)
        finalized_scores[:, 1:] = finalized_scores[:, 1:] - finalized_scores[:, :-1]
        return beam_probs, sampled_beam_idx, finalized_tokens
test2 = StochasticViterbiSamples(30000)
emissions = torch.randn((8, 97, 30000))
beam_probs, sampled_beam_idx, finalized_tokens = test2._compute_grpo_samples(emissions)
print(beam_probs.size())
print(sampled_beam_idx.size())
print(finalized_tokens.size())
torch.Size([8, 97, 256, 8])
torch.Size([8, 97, 8])
torch.Size([8, 97, 8])
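For completeness, a minimal sketch of the GRPO-style baseline mentioned above. The reward function compute_caption_reward is hypothetical (for example a CIDEr scorer applied to each of the num_samples decoded captions); only the group-mean baseline and normalized advantages are shown.

# rewards = compute_caption_reward(finalized_tokens)  # hypothetical scorer -> (bsz, num_samples)
rewards = torch.randn(8, 8)                           # stand-in rewards for illustration
baseline = rewards.mean(dim=1, keepdim=True)          # group mean over num_samples
advantages = (rewards - baseline) / (rewards.std(dim=1, keepdim=True) + 1e-8)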
I hope this proves useful. Thank you for reading.