Refer to the code below:
.
def model_forward(self, pixel_values, labels):
    """Run the wrapped encoder-decoder model and return the decoder's last hidden state.

    Args:
        pixel_values: Image batch, forwarded to ``self.model`` unchanged.
        labels: Target token ids, forwarded to ``self.model`` unchanged.
            NOTE(review): presumably the model uses these to build its
            decoder inputs -- confirm against the model's documentation.

    Returns:
        The last entry of ``outputs.decoder_hidden_states``. The original
        post describes its shape as (batch_size, seq_len, hidden_size),
        e.g. (5, 15, 768).
    """
    # Original ViT encoder-decoder outputs; hidden states are requested so
    # that ``decoder_hidden_states`` can be read from the result below.
    outputs = self.model(
        pixel_values=pixel_values,
        labels=labels,
        output_hidden_states=True,
    )
    # Keep only the final entry of the decoder hidden-state sequence.
    return outputs.decoder_hidden_states[-1]
def fc_part(self, last_hidden_state):
    """Project the decoder hidden states through the custom fully connected layer.

    Args:
        last_hidden_state: Tensor whose trailing dimension equals
            ``self.model.config.decoder.hidden_size``; the original post
            describes it as (batch_size, seq_len, hidden_size).

    Returns:
        The output of ``self.custom_decoder_fc`` applied to the flattened
        input, i.e. one row per (batch, position) pair. The original post
        describes the shape as (batch_size * seq_len, vocab_size).
    """
    hidden_size = self.model.config.decoder.hidden_size
    # ``reshape`` rather than ``view``: ``view`` raises on non-contiguous
    # tensors (e.g. after a transpose/permute), while ``reshape`` returns a
    # view when it can and copies only when it has to.
    flat_hidden = last_hidden_state.reshape(-1, hidden_size)
    # Apply the fully connected layer to every flattened position.
    return self.custom_decoder_fc(flat_hidden)
def compute_loss(self, new_logits, labels):
    """Compute the loss between flattened logits and flattened labels.

    Args:
        new_logits: Logits with one row per target position; the original
            post gives (batch_size * seq_len, vocab_size), e.g. (70, 13).
        labels: Target ids; the original post gives (batch_size, seq_len).
            They are flattened to one dimension before the loss call.

    Returns:
        Whatever ``self.loss_f(new_logits, flat_labels)`` returns; the
        original post describes it as a scalar tensor.
    """
    # (batch_size, seq_len) -> (batch_size * seq_len,). ``reshape`` is used
    # instead of ``view`` so that non-contiguous label tensors (slices,
    # transposes) are accepted instead of raising.
    flat_labels = labels.reshape(-1)
    # e.g. logits [70, 13] against labels [70]
    return self.loss_f(new_logits, flat_labels)
def forward_pass(self, pixel_values, labels):
    """Run the full pipeline: model forward, custom FC layer, then loss.

    Args:
        pixel_values: Image batch passed to ``self.model_forward``.
        labels: Target ids with at least two dimensions; ``labels.shape[0]``
            and ``labels.shape[1]`` are used as batch size and sequence
            length when restoring the logits' shape.

    Returns:
        A dict with two keys:
            ``'logits'``: the FC-layer output reshaped to
                (labels.shape[0], labels.shape[1], -1).
            ``'loss'``: the value returned by ``self.compute_loss``.
    """
    # Step 1: decoder last hidden state (batch_size, seq_len, hidden_size).
    last_hidden_state = self.model_forward(pixel_values, labels)
    # Step 2: flattened logits (batch_size * seq_len, vocab_size).
    flat_logits = self.fc_part(last_hidden_state)
    # Step 3: loss is computed on the flattened logits.
    loss = self.compute_loss(flat_logits, labels)
    # Step 4: restore the (batch_size, seq_len, vocab_size) layout so the
    # logits line up with ``labels``. ``reshape`` avoids depending on the
    # FC output being contiguous.
    batch_size, seq_len = labels.shape[0], labels.shape[1]
    logits = flat_logits.reshape(batch_size, seq_len, -1)
    return {'logits': logits, 'loss': loss}
..
`forward_pass` runs the process step by step:
at the end it returns the logits computed from the last hidden state, together with the loss.
Thank you.
www.marearts.com
🙇🏻‍♂️
No comments:
Post a Comment