Skip to content

Commit

Permalink
tested masking lab values independent of MLM
Browse files Browse the repository at this point in the history
  • Loading branch information
ChaoPang committed Sep 12, 2024
1 parent fd52c62 commit 1ab268c
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions src/cehrbert/models/hf_models/hf_cehrbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,17 +320,15 @@ def forward(

# In addition to MLM, we also predict the values associated with the masked concepts
if self.config.include_value_prediction:
mlm_masks = labels != -100
# The intersection of the MLM masks and concept_value_masks is taken into account
# for this loss term Basically we only want to calculate the MLE loss for the lab
# concepts that are selected for MLM
masks = torch.logical_and(mlm_masks, concept_value_masks.to(torch.bool))
probability_matrix = torch.full(labels.shape, self.config.mlm_probability)
masked_indices = torch.bernoulli(probability_matrix).bool().to(labels.device)
masks = torch.logical_and(masked_indices, concept_value_masks.to(torch.bool))
masks = masks.to(torch.float32)
predicted_values = self.concept_value_decoder_layer(cehrbert_output.last_hidden_state)
num_items = torch.sum(masks, dim=-1) + 1e-6
values_ = (predicted_values.squeeze(-1) - concept_values) ** 2
masked_mse = torch.sum(values_ * masks, dim=-1) / num_items
total_loss += torch.mean(masked_mse)
total_loss += torch.mean((predicted_values.squeeze(-1) - concept_values) ** 2 * masks)

return CehrBertModelOutput(
loss=total_loss,
Expand Down

0 comments on commit 1ab268c

Please sign in to comment.