Skip to content

Commit

Permalink
added logic for removing records that don't have any clinical events
Browse files Browse the repository at this point in the history
  • Loading branch information
ChaoPang committed Oct 11, 2024
1 parent 4d234b0 commit 7e4ec42
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 0 deletions.
16 changes: 16 additions & 0 deletions src/cehrbert/data_generators/hf_data_generator/hf_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,14 @@ def create_cehrbert_pretraining_dataset(
streaming=data_args.streaming,
)

# Remove patients without any records
dataset = dataset.filter(
lambda batch: [num_of_concepts > 0 for num_of_concepts in batch["num_of_visits"]],
num_proc=data_args.preprocessing_num_workers if not data_args.streaming else None,
batched=True,
batch_size=data_args.preprocessing_batch_size,
)

if not data_args.streaming:
if isinstance(dataset, DatasetDict):
all_columns = dataset["train"].column_names
Expand Down Expand Up @@ -91,6 +99,14 @@ def create_cehrbert_finetuning_dataset(
streaming=data_args.streaming,
)

# Remove patients without any records
dataset = dataset.filter(
lambda batch: [num_of_concepts > 0 for num_of_concepts in batch["num_of_visits"]],
num_proc=data_args.preprocessing_num_workers if not data_args.streaming else None,
batched=True,
batch_size=data_args.preprocessing_batch_size,
)

if not data_args.streaming:
if isinstance(dataset, DatasetDict):
all_columns = dataset["train"].column_names
Expand Down
2 changes: 2 additions & 0 deletions src/cehrbert/data_generators/hf_data_generator/meds_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,13 +256,15 @@ def _create_cehrbert_data_from_meds(
writer_batch_size=data_args.preprocessing_batch_size,
streaming=data_args.streaming,
)

# Remove patients without any records
dataset = dataset.filter(
lambda batch: [num_of_concepts > 0 for num_of_concepts in batch["num_of_visits"]],
num_proc=data_args.preprocessing_num_workers if not data_args.streaming else None,
batched=True,
batch_size=data_args.preprocessing_batch_size,
)

# Convert the CehrBertPatient to CehrBert data inputs
dataset = apply_cehrbert_dataset_mapping(
dataset,
Expand Down

0 comments on commit 7e4ec42

Please sign in to comment.