Refer to the code first.
.
# This beam search only handles a batch size of 1.
def beam_search(self, pixel_value, max_length):
    """Decode one image with beam search and return the best hypothesis.

    The decoder is started from ``self.model.config.decoder_start_token_id``
    and every beam is extended by exactly one token per step. There is no
    end-of-sequence handling: all hypotheses always have the same length.
    Because of that, length normalization would not change the final
    ranking, and neither it nor temperature scaling is applied.

    Autograd is not disabled here; wrap the call in ``torch.no_grad()``
    when gradients through the score are not needed.

    Args:
        pixel_value: Image tensor whose first dimension is the batch
            dimension; it must be 1. (Presumably ``(1, C, H, W)`` -- the
            remaining dimensions are only passed through to
            ``self.forward_pass``.)
        max_length: Number of tokens to generate. At least one token is
            always generated, even if ``max_length`` is smaller than 1.

    Returns:
        A tuple ``(best_sequence, max_score)``: ``best_sequence`` is a 1-D
        tensor of token ids that starts with the decoder start token, and
        ``max_score`` is a 0-D tensor holding the sum of the token
        log-probabilities of that sequence.

    Raises:
        ValueError: If ``pixel_value`` holds more than one image.
    """
    if pixel_value.shape[0] != 1:
        # With several images the flat top-k below would mix beams that
        # belong to different images, so fail loudly instead.
        raise ValueError(
            "beam_search only supports batch size 1, got "
            f"{pixel_value.shape[0]}"
        )

    sequences = torch.full(
        (1, 1),
        self.model.config.decoder_start_token_id,
        dtype=torch.long,
        device=pixel_value.device,
    )

    # First step: expand the single start sequence into the initial beams.
    # Only the logits of the last position are needed.
    logits = self.forward_pass(pixel_value, sequences)['logits'][:, -1, :]
    # Use the real logits width to unravel flat indices later; a separately
    # configured vocabulary size could silently disagree with it.
    vocab_size = logits.shape[-1]
    # There cannot be more distinct one-token hypotheses than tokens.
    beam_size = min(self.cfg.num_beams, vocab_size)

    # log_softmax avoids the underflow to -inf of log(softmax(x)).
    log_probs = F.log_softmax(logits, dim=-1)
    top_scores, top_ids = torch.topk(log_probs, beam_size)

    sequences = torch.cat(
        [sequences.repeat_interleave(beam_size, dim=0), top_ids.view(-1, 1)],
        dim=-1,
    )
    # Cumulative log-probability of every live hypothesis.
    sequence_scores = top_scores.view(-1)
    # Every beam decodes the same image, so one copy per beam is enough and
    # it never needs to be re-ordered when beams are re-selected.
    pixel_value = pixel_value.repeat_interleave(beam_size, dim=0)

    for _ in range(max_length - 1):  # one token was generated above
        logits = self.forward_pass(pixel_value, sequences)['logits'][:, -1, :]
        candidate_scores = (
            sequence_scores.unsqueeze(1) + F.log_softmax(logits, dim=-1)
        )

        # Pick the best continuations over all (beam, token) pairs at once.
        sequence_scores, flat_indices = torch.topk(
            candidate_scores.view(-1), beam_size
        )
        # Row-major flattening: index = beam * vocab_size + token.
        beam_indices = flat_indices // vocab_size
        token_indices = flat_indices % vocab_size

        sequences = torch.cat(
            [sequences[beam_indices], token_indices.unsqueeze(1)], dim=-1
        )

    max_score, max_score_idx = torch.max(sequence_scores, 0)
    return sequences[max_score_idx], max_score
..
Some code is omitted, notably `forward_pass`; however, the code will work properly if you adapt it carefully.
You can also take some ideas from it.
Thank you.