From 7dde582dfd05e2d668d233f623a0a3c368e80814 Mon Sep 17 00:00:00 2001 From: Pierre Marcenac Date: Fri, 7 Jul 2023 14:03:39 +0000 Subject: [PATCH] Clean naming across files --- .../ml_croissant/_src/datasets.py | 29 ++++++++++++------- .../_src/operation_graph/__init__.py | 4 +-- .../_src/operation_graph/base_operation.py | 4 ++- .../_src/operation_graph/graph.py | 14 +++++---- .../scripts/{generate.py => load.py} | 9 +++++- 5 files changed, 40 insertions(+), 20 deletions(-) rename python/ml_croissant/scripts/{generate.py => load.py} (86%) diff --git a/python/ml_croissant/ml_croissant/_src/datasets.py b/python/ml_croissant/ml_croissant/_src/datasets.py index 9dad227c5..731b017da 100644 --- a/python/ml_croissant/ml_croissant/_src/datasets.py +++ b/python/ml_croissant/ml_croissant/_src/datasets.py @@ -10,7 +10,7 @@ from ml_croissant._src.core.graphs import utils as graphs_utils from ml_croissant._src.core.issues import Issues, ValidationError from ml_croissant._src.operation_graph import ( - ComputationGraph, + OperationGraph, ) from ml_croissant._src.operation_graph.operations import ( GroupRecordSet, @@ -32,7 +32,7 @@ class Validator: file_or_file_path: epath.PathLike issues: Issues = dataclasses.field(default_factory=Issues) file: dict = dataclasses.field(init=False) - operations: ComputationGraph | None = None + operations: OperationGraph | None = None def run_static_analysis(self, debug: bool = False): try: @@ -41,7 +41,7 @@ def run_static_analysis(self, debug: bool = False): nodes, parents = from_jsonld_to_nodes(self.issues, json_ld) # Print all nodes for debugging purposes. if debug: - logging.info('Found the following nodes during static analysis.') + logging.info("Found the following nodes during static analysis.") for node in nodes: logging.info(node) entry_node, structure_graph = from_nodes_to_structure_graph( @@ -54,7 +54,7 @@ def run_static_analysis(self, debug: bool = False): # features. if entry_node.uid == "Movielens-25M": return - self.operations = ComputationGraph.from_nodes( + self.operations = OperationGraph.from_nodes( issues=self.issues, metadata=entry_node, graph=structure_graph, @@ -78,10 +78,12 @@ class Dataset: Args: file: A JSON object or a path to a Croissant file (string or pathlib.Path). + operations: The operation graph class. None by default. + debug: Whether to print debug hints. False by default. """ file: epath.PathLike - operations: ComputationGraph | None = None + operations: OperationGraph | None = None debug: bool = False def __post_init__(self): @@ -92,7 +94,7 @@ def __post_init__(self): self.operations = self.validator.operations def records(self, record_set: str) -> Records: - return Records(self, record_set) + return Records(self, record_set, debug=self.debug) @dataclasses.dataclass @@ -102,10 +104,12 @@ class Records: Args: dataset: The parent dataset. record_set: The name of the record set. + debug: Whether to print debug hints. """ dataset: Dataset record_set: str + debug: bool def __iter__(self): """Executes all operations, runs dynamic analysis and yields examples. @@ -114,9 +118,12 @@ def __iter__(self): record_set. """ results: Mapping[str, Any] = {} - operations = self.dataset.operations.graph + operations = self.dataset.operations.operations + if self.debug: + graphs_utils.pretty_print_graph(operations) for operation in nx.topological_sort(operations): - logging.info('Executing "%s"', operation) + if self.debug: + logging.info('Executing "%s"', operation) kwargs = operations.nodes[operation].get("kwargs", {}) previous_results = [ results[previous_operation] @@ -141,9 +148,11 @@ def __iter__(self): read_fields = [] for read_field in operations.successors(operation): assert isinstance(read_field, ReadField) - logging.info('Executing "%s"', read_field) + if self.debug: + logging.info('Executing "%s"', read_field) read_fields.append(read_field(line, **kwargs)) - logging.info('Executing "%s"', operation) + if self.debug: + logging.info('Executing "%s"', operation) yield operation(*read_fields, **kwargs) else: if isinstance(operation, ReadField) and not previous_results: diff --git a/python/ml_croissant/ml_croissant/_src/operation_graph/__init__.py b/python/ml_croissant/ml_croissant/_src/operation_graph/__init__.py index 4999c63a6..cc7bc78e0 100644 --- a/python/ml_croissant/ml_croissant/_src/operation_graph/__init__.py +++ b/python/ml_croissant/ml_croissant/_src/operation_graph/__init__.py @@ -1,5 +1,5 @@ -from ml_croissant._src.operation_graph.graph import ComputationGraph +from ml_croissant._src.operation_graph.graph import OperationGraph __all__ = [ - "ComputationGraph", + "OperationGraph", ] diff --git a/python/ml_croissant/ml_croissant/_src/operation_graph/base_operation.py b/python/ml_croissant/ml_croissant/_src/operation_graph/base_operation.py index ad3435208..54986490c 100644 --- a/python/ml_croissant/ml_croissant/_src/operation_graph/base_operation.py +++ b/python/ml_croissant/ml_croissant/_src/operation_graph/base_operation.py @@ -1,12 +1,13 @@ """Base operation module.""" +import abc import dataclasses from ml_croissant._src.structure_graph.base_node import Node @dataclasses.dataclass(frozen=True, repr=False) -class Operation: +class Operation(abc.ABC): """Generic base class to define an operation. `@dataclass(frozen=True)` allows having a hashable operation for NetworkX to use @@ -22,6 +23,7 @@ class Operation: node: Node + @abc.abstractmethod def __call__(self, *args, **kwargs): raise NotImplementedError diff --git a/python/ml_croissant/ml_croissant/_src/operation_graph/graph.py b/python/ml_croissant/ml_croissant/_src/operation_graph/graph.py index abee23dfa..d26914a5b 100644 --- a/python/ml_croissant/ml_croissant/_src/operation_graph/graph.py +++ b/python/ml_croissant/ml_croissant/_src/operation_graph/graph.py @@ -149,11 +149,11 @@ def _add_operations_for_file_object( @dataclasses.dataclass(frozen=True) -class ComputationGraph: +class OperationGraph: """Graph of dependent operations to execute to generate the dataset.""" issues: Issues - graph: nx.MultiDiGraph + operations: nx.MultiDiGraph @classmethod def from_nodes( @@ -163,7 +163,7 @@ def from_nodes( graph: nx.MultiDiGraph, croissant_folder: epath.Path, rdf_namespace_manager: namespace.NamespaceManager, - ) -> "ComputationGraph": + ) -> "OperationGraph": """Builds the ComputationGraph from the nodes. This is done by: @@ -208,13 +208,15 @@ def from_nodes( init_operation = InitOperation(node=metadata) for entry_operation in entry_operations: operations.add_edge(init_operation, entry_operation) - return ComputationGraph(issues=issues, graph=operations) + return OperationGraph(issues=issues, operations=operations) def check_graph(self): """Checks the computation graph for issues.""" - if not self.graph.is_directed(): + if not self.operations.is_directed(): self.issues.add_error("Computation graph is not directed.") - selfloops = [operation.uid for operation, _ in nx.selfloop_edges(self.graph)] + selfloops = [ + operation.uid for operation, _ in nx.selfloop_edges(self.operations) + ] if selfloops: self.issues.add_error( f"The following operations refered to themselves: {selfloops}" diff --git a/python/ml_croissant/scripts/generate.py b/python/ml_croissant/scripts/load.py similarity index 86% rename from python/ml_croissant/scripts/generate.py rename to python/ml_croissant/scripts/load.py index 2f863b3ce..9c3b9131a 100644 --- a/python/ml_croissant/scripts/generate.py +++ b/python/ml_croissant/scripts/load.py @@ -26,6 +26,12 @@ "The number of records to generate. Use `-1` to generate the whole dataset.", ) +flags.DEFINE_bool( + "debug", + False, + "Whether to print debug hints.", +) + flags.mark_flag_as_required("file") @@ -37,7 +43,8 @@ def main(argv): file = FLAGS.file record_set = FLAGS.record_set num_records = FLAGS.num_records - dataset = Dataset(file) + debug = FLAGS.debug + dataset = Dataset(file, debug=debug) records = dataset.records(record_set) print(f"Generating the first {num_records} records.") for i, record in enumerate(records):