diff --git a/src/cehrbert/models/hf_models/hf_cehrbert.py b/src/cehrbert/models/hf_models/hf_cehrbert.py index 4d07361e..4b187a15 100644 --- a/src/cehrbert/models/hf_models/hf_cehrbert.py +++ b/src/cehrbert/models/hf_models/hf_cehrbert.py @@ -96,7 +96,7 @@ def forward( merged = torch.where( concept_value_masks.to(torch.bool), - concept_embeddings_with_val, + gelu_new(concept_embeddings_with_val), concept_embeddings, ) @@ -320,10 +320,8 @@ def forward( if self.config.include_value_prediction: mlm_masks = labels != -100 predicted_values = self.concept_value_decoder_layer(cehrbert_output.last_hidden_state) - num_items = torch.sum(concept_value_masks.to(torch.float32), dim=-1) + 1e-6 values_ = (predicted_values.squeeze(-1) - concept_values) ** 2 - masked_mse = torch.sum(values_ * concept_value_masks * mlm_masks, dim=-1) / num_items - total_loss += torch.mean(masked_mse) + total_loss += torch.mean(values_ * concept_value_masks * mlm_masks) return CehrBertModelOutput( loss=total_loss,