Refer to the code first.
.
# This beam search only handles a batch size of 1.
    def beam_search(self, pixel_value, max_length):
        """Decode a single image with beam search and return the best sequence.

        Decoding starts from ``self.model.config.decoder_start_token_id``. The
        search keeps ``self.cfg.num_beams`` candidate sequences. At each step
        it extends every candidate by every vocabulary token and keeps the
        ``num_beams`` continuations with the highest cumulative
        log-probability.

        Scores are raw sums of token log-probabilities. Neither temperature
        nor length normalisation is applied. There is no early stop on an
        end-of-sequence token: every beam is extended to the full length.

        NOTE(review): gradients are not disabled here. The method runs in
        whatever autograd mode the caller has set, so wrap the call in
        ``torch.no_grad()`` for inference if the caller does not already.

        Args:
            pixel_value: Image tensor whose first dimension is the batch
                dimension. The batch size must be 1.
            max_length: Number of tokens to generate after the start token.
                At least one token is always generated.

        Returns:
            A tuple ``(best_sequence, max_score)``:

            - ``best_sequence``: a 1-D tensor of token ids, starting with the
              start token.
            - ``max_score``: the summed log-probability of that sequence, as
              a 0-d tensor.

        Raises:
            ValueError: If the batch dimension of ``pixel_value`` is not 1.
        """
        if pixel_value.shape[0] != 1:
            raise ValueError(
                f"beam_search supports batch size 1 only, got {pixel_value.shape[0]}"
            )
        beam_size = self.cfg.num_beams

        # Step 1: run the decoder on the start token alone.
        first_sequence = torch.full(
            (pixel_value.shape[0], 1),
            self.model.config.decoder_start_token_id,
            dtype=torch.long,
            device=pixel_value.device,
        )
        outputs = self.forward_pass(pixel_value, first_sequence)
        # Only the logits of the last position predict the next token.
        # log_softmax is used instead of log(softmax(...)): the latter
        # underflows to -inf for very unlikely tokens.
        next_token_logprobs = F.log_softmax(outputs['logits'][:, -1, :], dim=-1)
        top_k_logprobs, top_k_ids = torch.topk(next_token_logprobs, beam_size)

        # Seed one beam per top-k token; each beam is [start, token_k].
        next_sequences = torch.cat(
            [first_sequence.repeat_interleave(beam_size, dim=0), top_k_ids.view(-1, 1)],
            dim=-1,
        )
        # Cumulative log-probability of each beam.
        sequence_scores = top_k_logprobs.view(-1)
        # One copy of the image per beam. The copies are identical, so they
        # never need to be reordered when beams are reselected below.
        pixel_value = pixel_value.repeat_interleave(beam_size, dim=0)

        for _ in range(max_length - 1):  # the first token is already generated
            outputs = self.forward_pass(pixel_value, next_sequences)
            next_token_logprobs = F.log_softmax(outputs['logits'][:, -1, :], dim=-1)
            # Take the vocabulary size from the logits themselves, so that
            # splitting the flat index below always matches the model output.
            vocab_size = next_token_logprobs.shape[-1]
            # Score every (beam, token) extension: shape (beam_size, vocab_size).
            candidate_scores = sequence_scores.unsqueeze(1) + next_token_logprobs
            top_k_scores, flat_indices = torch.topk(candidate_scores.view(-1), beam_size)
            # Recover which beam and which token each winner came from.
            beam_indices = flat_indices // vocab_size
            token_indices = flat_indices % vocab_size
            next_sequences = torch.cat(
                [next_sequences[beam_indices], token_indices.unsqueeze(1)], dim=-1
            )
            sequence_scores = top_k_scores

        # Return the highest-scoring beam and its score.
        max_score, max_score_idx = torch.max(sequence_scores, 0)
        best_sequence = next_sequences[max_score_idx]
        return best_sequence, max_score
..
This is a portion of my class.
Some code is omitted, especially `forward_pass`; however, the code will work properly if you adapt it carefully.
You can also borrow ideas from it.
Thank you.
🙇🏻‍♂️
www.marearts.com
 
 
 
 
No comments:
Post a Comment