
Python code for stochastic Viterbi sampling

Posted at 2026-01-15

Motivation for the code and why I am publishing it

I am working on PPO reinforcement learning with a CRF in the image-captioning field. In reinforcement learning, the generated caption used to compute the sampled log-probabilities and the reward is usually sampled stochastically with multinomial, whereas the inference caption used for evaluation is usually generated greedily with argmax (max). My impression is that the Viterbi algorithm is greedy, so it can be used for the inference caption but not for the stochastically sampled caption that the reward is computed from. That led me to wonder whether a Viterbi algorithm that samples stochastically could be devised, and I would like to publish the result here. Whether it is theoretically correct, I leave to the readers' judgment.
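To make the idea concrete before the full code, here is a minimal toy sketch of a single Viterbi step (my own illustration, not taken from the training code): the greedy step keeps, for each next beam state, the single best predecessor, while the stochastic variant draws the predecessor from a softmax over the same incoming scores.

import torch
import torch.nn.functional as F

# toy scores for one step: scores[j, k] = score of moving from previous beam state j to next beam state k
scores = torch.randn(5, 5)  # beam size 5 in this toy example

# greedy Viterbi step: argmax over predecessors for each next state
greedy_score, greedy_prev = scores.max(dim=0)

# stochastic step: sample the predecessor of each next state from a softmax over the same scores
probs = F.softmax(scores.t(), dim=-1)             # one distribution over predecessors per next state
sampled_prev = torch.multinomial(probs, 1)[:, 0]  # one sampled predecessor index per next state
sampled_score = scores.t().gather(1, sampled_prev[:, None])[:, 0]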

As for the reinforcement learning itself, I have not yet reached results worth publishing online, so here I publish only the stochastic Viterbi sampling code. Even if it turns out not to help in reinforcement learning, I think it can at least add one more variant of Viterbi decoding.

Captions actually generated

Captions actually generated during reinforcement learning: refe is the target caption, hypo is the caption generated with the greedy Viterbi algorithm, and samp is the caption generated with stochastic Viterbi sampling.

refe: [CLS] in this picture there are boats and there is a building. at the back there are trees and there are trees on the mountains. at the top there is sky. at the bottom there is water and sand. [SEP]
hypo: [CLS] in this image we can see the water left side of the water, we can see the background we can see the middle of the ground and there are trees. [SEP]
samp: [CLS] in this picture we can see the water, some trees. i can see the right side of the background we can see the middle of the ground and trees. [SEP]

Code to generate a single sample

The function input emissions is assumed to be BERT's last_hidden_state passed through nn.LayerNorm and nn.Linear, giving shape ( bsz, seq_len, vocab_size ).
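For reference, a minimal sketch of that emission head (the names emission_head and last_hidden_state are illustrative and not from the article; only the LayerNorm -> Linear -> (bsz, seq_len, vocab_size) shape follows the description above):

import torch
import torch.nn as nn

hidden_size, vocab_size = 768, 30000  # assumed sizes for illustration

# LayerNorm over BERT's last_hidden_state, then a linear projection to the vocabulary
emission_head = nn.Sequential(
    nn.LayerNorm(hidden_size),
    nn.Linear(hidden_size, vocab_size),
)

last_hidden_state = torch.randn(8, 97, hidden_size)  # stand-in for bert(...).last_hidden_state
emissions = emission_head(last_hidden_state)          # (8, 97, 30000), the input to the sampler below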

import torch
import torch.nn as nn
import torch.nn.functional as F

class StochasticViterbiSample(nn.Module):
    def __init__(self, num_embedding, low_rank=32, beam_size=256, temp = 1.0):
        super().__init__()

        self.E1 = nn.Embedding(num_embedding, low_rank)
        self.E2 = nn.Embedding(num_embedding, low_rank)

        self.rank = low_rank
        self.beam = beam_size
        self.temp = temp

    def _compute_stochastic_viterbi_sample(self, emissions, beam=None):

        eps = 1e-8
        device = emissions.device
        
        beam = beam if beam is not None else self.beam
        
        beam_emission_scores, beam_targets = torch.topk( emissions, beam, 2)  # keep the top-`beam` vocabulary entries per position
        
        batch_size, seq_len = beam_emission_scores.size()[:2]

        # low-rank CRF transition scores between the top-K candidates of adjacent positions:
        # score(prev token, next token) = E1(prev) . E2(next)
        beam_transition_score1 = self.E1(beam_targets[:, :-1])  # B x (T-1) x K x D
        beam_transition_score2 = self.E2(beam_targets[:, 1:])   # B x (T-1) x K x D
        beam_transition_matrix = torch.bmm(
            beam_transition_score1.view(-1, beam, self.rank),
            beam_transition_score2.view(-1, beam, self.rank).transpose(1, 2))
        beam_transition_matrix = beam_transition_matrix.view(batch_size, -1, beam, beam) # bsz, seq_len-1, beam, beam

        traj_tokens, traj_scores = [], []
        finalized_tokens, finalized_scores = [], []

        score = beam_emission_scores[:, 0]  # B x K
        
        for i in range(1, seq_len):
            traj_scores.append(score)
            _score = score[:, :, None] + beam_transition_matrix[:, i-1] # bsz, beam, beam

            # greedy selection
            #_score, _index = _score.max(dim=1) # bsz, beam     bsz, beam 

            # multinomial selection: for each next beam state, sample a predecessor
            # from a softmax over the incoming scores instead of taking the argmax
            B, C, W = _score.shape                                 # C = W = beam
            flat_score = _score.permute(0, 2, 1).reshape(-1, C)    # (B * beam, beam)
            probs = F.softmax(flat_score / self.temp, dim=-1)
            _index_flat = torch.multinomial(probs, num_samples=1)  # sampled predecessor per next state
            _score_flat = torch.gather(flat_score, -1, _index_flat)
            _index = _index_flat.view(B, W)
            _score = _score_flat.view(B, W)

            _score = _score + beam_emission_scores[:, i] # bsz, beam
            
            #if masks is not None:
            #    score = torch.where(masks[:, i: i+1], _score, score)
            #    index = torch.where(masks[:, i: i+1], _index, dummy)
            #else:
            score, index = _score, _index
            traj_tokens.append(index)
        
        # scores of every beam state at every step; copy the list so that traj_scores,
        # which is reused for back-tracing below, is not mutated
        all_scores = traj_scores + [score]
        all_scores = torch.stack( all_scores, dim = 0 ).transpose( 0, 1 ).to(device)  # bsz, seq_len, beam
        beam_probs = F.softmax( all_scores, dim = 2 )

        # now run the back-tracing, starting from the best final beam state
        best_score, best_index = score.max(dim=1)
        finalized_tokens.append(best_index[:, None])
        finalized_scores.append(best_score[:, None])

        for idx, scs in zip(reversed(traj_tokens), reversed(traj_scores)):
            previous_index = finalized_tokens[-1]
            finalized_tokens.append(idx.gather(1, previous_index))
            finalized_scores.append(scs.gather(1, previous_index))

        finalized_tokens.reverse()
        sampled_beam_idx = torch.cat(finalized_tokens, 1)  # bsz, seq_len
        finalized_tokens = beam_targets.gather(2, sampled_beam_idx[:,:,None])[:, :, 0]  # map beam indices back to vocabulary ids

        finalized_scores.reverse()
        finalized_scores = torch.cat(finalized_scores, 1)
        finalized_scores[:, 1:] = finalized_scores[:, 1:] - finalized_scores[:, :-1]

        return beam_probs, sampled_beam_idx.unsqueeze(-1), finalized_tokens
        

test = StochasticViterbiSample( 30000 )

emissions = torch.randn( ( 8, 97, 30000 ) )

beam_probs, sampled_beam_idx, finalized_tokens = test._compute_stochastic_viterbi_sample( emissions )

print( beam_probs.size() )        # torch.Size([8, 97, 256])
print( sampled_beam_idx.size() )  # torch.Size([8, 97, 1])
print( finalized_tokens.size() )  # torch.Size([8, 97])
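As a possible follow-up, here is one way the returned tensors might be turned into per-token log-probabilities for PPO, continuing from the example above. This is my assumption about how beam_probs is meant to be consumed, not something stated in the article:

# probability assigned to the sampled beam index at each position, then the log for PPO
token_probs = beam_probs.gather(2, sampled_beam_idx)       # (8, 97, 1)
log_probs = token_probs.clamp_min(1e-8).log().squeeze(-1)  # (8, 97)
print( log_probs.size() )  # torch.Size([8, 97])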

Code to generate multiple samples

Computing the GRPO baseline requires multiple samples, so I wrote this version. The number of samples is set to 8.
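Before the class itself, a small sketch of how such a group of samples feeds a GRPO-style baseline (the rewards here are placeholder values; in practice they would come from a caption metric computed elsewhere): the group mean of the per-sample rewards acts as the baseline, and each sample's advantage is its deviation from that mean, often normalized by the group standard deviation.

import torch

num_samples = 8
rewards = torch.rand(4, num_samples)          # (bsz, num_samples), placeholder rewards per sampled caption
baseline = rewards.mean(dim=1, keepdim=True)  # per-image group baseline
advantages = (rewards - baseline) / (rewards.std(dim=1, keepdim=True) + 1e-8)

The class below generates the num_samples token sequences whose rewards would fill such a table.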

class StochasticViterbiSamples(nn.Module):
    def __init__(self, num_embedding, low_rank=32, beam_size=256, temp = 1.0, num_samples = 8 ):
        super().__init__()

        self.E1 = nn.Embedding(num_embedding, low_rank)
        self.E2 = nn.Embedding(num_embedding, low_rank)

        self.rank = low_rank
        self.beam = beam_size
        self.temp = temp
        self.num_samples = num_samples

    def _compute_grpo_samples(self, emissions, masks=None, beam=None):

        eps = 1e-8
        device = emissions.device
        
        beam = beam if beam is not None else self.beam
        
        beam_emission_scores, beam_targets = torch.topk( emissions, beam, 2)  # keep the top-`beam` vocabulary entries per position
        
        batch_size, seq_len = beam_emission_scores.size()[:2]

        # low-rank CRF transition scores between the top-K candidates of adjacent positions:
        # score(prev token, next token) = E1(prev) . E2(next)
        beam_transition_score1 = self.E1(beam_targets[:, :-1])  # B x (T-1) x K x D
        beam_transition_score2 = self.E2(beam_targets[:, 1:])   # B x (T-1) x K x D
        beam_transition_matrix = torch.bmm(
            beam_transition_score1.view(-1, beam, self.rank),
            beam_transition_score2.view(-1, beam, self.rank).transpose(1, 2))
        beam_transition_matrix = beam_transition_matrix.view(batch_size, -1, beam, beam) # bsz, seq_len-1, beam, beam

        traj_tokens, traj_scores = [], []
        finalized_tokens, finalized_scores = [], []

        score = beam_emission_scores[:, 0, :, None].expand( -1, -1, self.num_samples )  # B x K x N
        
        for i in range(1, seq_len):
            traj_scores.append(score)
            _score = score[:, :, None] + beam_transition_matrix[:, i-1, :, :, None] # bsz, beam, beam, N

            # greedy selection
            #_score, _index = _score.max(dim=1) # bsz, beam     bsz, beam 

            # multinomial selection: one independent predecessor draw per (next state, sample)
            B, C, W, _ = _score.shape                                # C = W = beam
            N = self.num_samples
            flat_score = _score.permute(0, 2, 3, 1).reshape(-1, C)   # (B * beam * N, beam)
            probs = F.softmax(flat_score / self.temp, dim=-1)
            _index_flat = torch.multinomial(probs, num_samples=1, replacement=True)
            _score_flat = torch.gather(flat_score, -1, _index_flat)
            _index = _index_flat.view(B, W, N)
            _score = _score_flat.view(B, W, N)

            _score = _score + beam_emission_scores[:, i, :, None] # bsz, beam, N
            
            #if masks is not None:
            #    score = torch.where(masks[:, i: i+1], _score, score)
            #    index = torch.where(masks[:, i: i+1], _index, dummy)
            #else:
            score, index = _score, _index # bsz, beam, N
            traj_tokens.append(index) 
        
        # per-step scores; copy the list so that traj_scores is not mutated before back-tracing
        all_scores = traj_scores + [score]
        all_scores = torch.stack( all_scores, dim = 0 ).transpose( 0, 1 ).to(device) #bsz, seq_len, beam, N
        beam_probs = F.softmax( all_scores, dim = 2 ) #bsz, seq_len, beam, N
        
        # now run the back-tracing, starting from the best final beam state of each sample
        best_score, best_index = score.max(dim=1) # max over beam -> bsz, N
        finalized_tokens.append(best_index[:, None, :]) # bsz, 1, N
        finalized_scores.append(best_score[:, None, :]) # bsz, 1, N

        for idx, scs in zip(reversed(traj_tokens), reversed(traj_scores)): # each of seq_len -1, bsz, beam, N 
            previous_index = finalized_tokens[-1]
            finalized_tokens.append(idx.gather(1, previous_index))
            finalized_scores.append(scs.gather(1, previous_index))

        finalized_tokens.reverse() # list of seq_len tensors, each bsz, 1, N
        sampled_beam_idx = torch.cat(finalized_tokens, 1) # bsz, seq_len, N
        finalized_tokens = beam_targets.gather(2, sampled_beam_idx) # map beam indices back to vocabulary ids
        
        finalized_scores.reverse()
        finalized_scores = torch.cat(finalized_scores, 1)
        finalized_scores[:, 1:] = finalized_scores[:, 1:] - finalized_scores[:, :-1]
        
        return beam_probs, sampled_beam_idx, finalized_tokens

test2 = StochasticViterbiSamples( 30000 )

emissions = torch.randn( ( 8, 97, 30000 ) )
beam_probs, sampled_beam_idx, finalized_tokens = test2._compute_grpo_samples( emissions )

print( beam_probs.size() )        # torch.Size([8, 97, 256, 8])
print( sampled_beam_idx.size() )  # torch.Size([8, 97, 8])
print( finalized_tokens.size() )  # torch.Size([8, 97, 8])

I hope this is useful to you. Thank you for reading.
