Skip to content

Commit

Permalink
fixed the masked mse loss for cehrbert when value_prediction is enabled
Browse files Browse the repository at this point in the history
  • Loading branch information
ChaoPang committed Sep 13, 2024
1 parent 84d7e34 commit f254111
Showing 1 changed file with 2 additions and 4 deletions.
6 changes: 2 additions & 4 deletions src/cehrbert/models/hf_models/hf_cehrbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def forward(

merged = torch.where(
concept_value_masks.to(torch.bool),
concept_embeddings_with_val,
gelu_new(concept_embeddings_with_val),
concept_embeddings,
)

Expand Down Expand Up @@ -320,10 +320,8 @@ def forward(
if self.config.include_value_prediction:
mlm_masks = labels != -100
predicted_values = self.concept_value_decoder_layer(cehrbert_output.last_hidden_state)
num_items = torch.sum(concept_value_masks.to(torch.float32), dim=-1) + 1e-6
values_ = (predicted_values.squeeze(-1) - concept_values) ** 2
masked_mse = torch.sum(values_ * concept_value_masks * mlm_masks, dim=-1) / num_items
total_loss += torch.mean(masked_mse)
total_loss += torch.mean(values_ * concept_value_masks * mlm_masks)

return CehrBertModelOutput(
loss=total_loss,
Expand Down

0 comments on commit f254111

Please sign in to comment.