Skip to content

Commit

Permalink
Fixed a bug in calculating the MLM value prediction, where num_items was miscalculated
Browse files Browse the repository at this point in the history
  • Loading branch information
ChaoPang committed Sep 11, 2024
1 parent 9ecaca5 commit a77b216
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions src/cehrbert/models/hf_models/hf_cehrbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def forward(

merged = torch.where(
concept_value_masks.to(torch.bool),
concept_embeddings_with_val,
gelu_new(concept_embeddings_with_val),
concept_embeddings,
)

Expand Down Expand Up @@ -319,10 +319,15 @@ def forward(
# In addition to MLM, we also predict the values associated with the masked concepts
if self.config.include_value_prediction:
mlm_masks = labels != -100
# The intersection of the MLM masks and concept_value_masks is taken into account
# for this loss term. Basically, we only want to calculate the MSE loss for the lab
# concepts that are selected for MLM.
masks = torch.logical_and(mlm_masks, concept_value_masks.to(torch.bool))
masks = masks.to(torch.float32)
predicted_values = self.concept_value_decoder_layer(cehrbert_output.last_hidden_state)
num_items = torch.sum(concept_value_masks.to(torch.float32), dim=-1) + 1e-6
num_items = torch.sum(masks, dim=-1) + 1e-6
values_ = (predicted_values.squeeze(-1) - concept_values) ** 2
masked_mse = torch.sum(values_ * concept_value_masks * mlm_masks, dim=-1) / num_items
masked_mse = torch.sum(values_ * masks, dim=-1) / num_items
total_loss += torch.mean(masked_mse)

return CehrBertModelOutput(
Expand Down

0 comments on commit a77b216

Please sign in to comment.