Skip to content

Commit

Permalink
Move filters from Dataset init to self.records (#720)
Browse files Browse the repository at this point in the history
  • Loading branch information
ccl-core authored Jul 22, 2024
1 parent 6fc0adb commit a6d96d9
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 11 deletions.
21 changes: 12 additions & 9 deletions python/mlcroissant/mlcroissant/_src/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,17 +61,13 @@ class Dataset:
downloads. If `document.csv` is the FileObject and you downloaded it to
`~/Downloads/document.csv`, you can specify `mapping={"document.csv":
"~/Downloads/document.csv"}`.
filters: A dictionary mapping a field ID to the value we want to filter in. For
example, when writing {'data/split': 'train'}, we want to keep all records
whose field `data/split` takes the value `train`.
"""

jsonld: epath.PathLike | str | dict[str, Any] | None
operations: OperationGraph = dataclasses.field(init=False)
metadata: Metadata = dataclasses.field(init=False)
debug: bool = False
mapping: Mapping[str, epath.PathLike] | None = None
filters: Filters | None = None

def __post_init__(self):
"""Runs the static analysis of `file`."""
Expand All @@ -89,8 +85,6 @@ def __post_init__(self):
# Draw the operations graph for debugging purposes.
if self.debug:
graphs_utils.pretty_print_graph(self.operations.operations, simplify=False)
if self.filters:
_validate_filters(self.filters)

@classmethod
def from_metadata(cls, metadata: Metadata) -> Dataset:
Expand All @@ -100,8 +94,17 @@ def from_metadata(cls, metadata: Metadata) -> Dataset:
dataset.operations = get_operations(metadata.ctx, metadata)
return dataset

def records(self, record_set: str) -> Records:
"""Accesses all records with @id==record_set if it exists."""
def records(self, record_set: str, filters: Filters | None = None) -> Records:
"""Accesses all records with @id==record_set if it exists.
Args:
record_set: The name of the record set to access.
filters: A dictionary mapping a field ID to the value we want to filter in. For
example, when writing {'data/split': 'train'}, we want to keep all records
whose field `data/split` takes the value `train`.
"""
if filters:
_validate_filters(filters)

if not any(rs for rs in self.metadata.record_sets if rs.uuid == record_set):
ids = [record_set.uuid for record_set in self.metadata.record_sets]
error_msg = f"did not find any record set with the name `{record_set}`. "
Expand All @@ -113,7 +116,7 @@ def records(self, record_set: str) -> Records:
return Records(
dataset=self,
record_set=record_set,
filters=self.filters,
filters=filters,
debug=self.debug,
)

Expand Down
4 changes: 2 additions & 2 deletions python/mlcroissant/mlcroissant/_src/datasets_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,8 @@ def load_records_and_test_equality(
with output_file.open("rb") as f:
lines = f.readlines()
expected_records = [json.loads(line) for line in lines]
dataset = datasets.Dataset(config, filters=filters)
records = dataset.records(record_set_name)
dataset = datasets.Dataset(config)
records = dataset.records(record_set_name, filters=filters)
records = iter(records)
length = 0
for i, record in enumerate(records):
Expand Down

0 comments on commit a6d96d9

Please sign in to comment.