Skip to content

Commit

Permalink
Merge pull request #104 from mlcommons/origin/refacto/pierremarcenac-3
Browse files Browse the repository at this point in the history
Clean naming across files for operations
  • Loading branch information
marcenacp authored Jul 7, 2023
2 parents 646d560 + 41a3de8 commit 08b5150
Show file tree
Hide file tree
Showing 7 changed files with 42 additions and 22 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -84,4 +84,4 @@ jobs:
run: pip install .

- name: Generate JSON-LD files
run: python scripts/generate.py --file ../../datasets/titanic/metadata.json --record_set passengers
run: python scripts/load.py --file ../../datasets/titanic/metadata.json --record_set passengers
2 changes: 1 addition & 1 deletion python/ml_croissant/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ The command:
Similarly, you can generate a dataset by launching:

```bash
python scripts/generate.py \
python scripts/load.py \
--file ../../datasets/titanic/metadata.json \
--record_set passengers \
--num_records 10
Expand Down
29 changes: 19 additions & 10 deletions python/ml_croissant/ml_croissant/_src/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from ml_croissant._src.core.graphs import utils as graphs_utils
from ml_croissant._src.core.issues import Issues, ValidationError
from ml_croissant._src.operation_graph import (
ComputationGraph,
OperationGraph,
)
from ml_croissant._src.operation_graph.operations import (
GroupRecordSet,
Expand All @@ -32,7 +32,7 @@ class Validator:
file_or_file_path: epath.PathLike
issues: Issues = dataclasses.field(default_factory=Issues)
file: dict = dataclasses.field(init=False)
operations: ComputationGraph | None = None
operations: OperationGraph | None = None

def run_static_analysis(self, debug: bool = False):
try:
Expand All @@ -41,7 +41,7 @@ def run_static_analysis(self, debug: bool = False):
nodes, parents = from_jsonld_to_nodes(self.issues, json_ld)
# Print all nodes for debugging purposes.
if debug:
logging.info('Found the following nodes during static analysis.')
logging.info("Found the following nodes during static analysis.")
for node in nodes:
logging.info(node)
entry_node, structure_graph = from_nodes_to_structure_graph(
Expand All @@ -54,7 +54,7 @@ def run_static_analysis(self, debug: bool = False):
# features.
if entry_node.uid == "Movielens-25M":
return
self.operations = ComputationGraph.from_nodes(
self.operations = OperationGraph.from_nodes(
issues=self.issues,
metadata=entry_node,
graph=structure_graph,
Expand All @@ -78,10 +78,12 @@ class Dataset:
Args:
file: A JSON object or a path to a Croissant file (string or pathlib.Path).
operations: The operation graph class. None by default.
debug: Whether to print debug hints. False by default.
"""

file: epath.PathLike
operations: ComputationGraph | None = None
operations: OperationGraph | None = None
debug: bool = False

def __post_init__(self):
Expand All @@ -92,7 +94,7 @@ def __post_init__(self):
self.operations = self.validator.operations

def records(self, record_set: str) -> Records:
return Records(self, record_set)
return Records(self, record_set, debug=self.debug)


@dataclasses.dataclass
Expand All @@ -102,10 +104,12 @@ class Records:
Args:
dataset: The parent dataset.
record_set: The name of the record set.
debug: Whether to print debug hints.
"""

dataset: Dataset
record_set: str
debug: bool

def __iter__(self):
"""Executes all operations, runs dynamic analysis and yields examples.
Expand All @@ -114,9 +118,12 @@ def __iter__(self):
record_set.
"""
results: Mapping[str, Any] = {}
operations = self.dataset.operations.graph
operations = self.dataset.operations.operations
if self.debug:
graphs_utils.pretty_print_graph(operations)
for operation in nx.topological_sort(operations):
logging.info('Executing "%s"', operation)
if self.debug:
logging.info('Executing "%s"', operation)
kwargs = operations.nodes[operation].get("kwargs", {})
previous_results = [
results[previous_operation]
Expand All @@ -141,9 +148,11 @@ def __iter__(self):
read_fields = []
for read_field in operations.successors(operation):
assert isinstance(read_field, ReadField)
logging.info('Executing "%s"', read_field)
if self.debug:
logging.info('Executing "%s"', read_field)
read_fields.append(read_field(line, **kwargs))
logging.info('Executing "%s"', operation)
if self.debug:
logging.info('Executing "%s"', operation)
yield operation(*read_fields, **kwargs)
else:
if isinstance(operation, ReadField) and not previous_results:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from ml_croissant._src.operation_graph.graph import ComputationGraph
from ml_croissant._src.operation_graph.graph import OperationGraph

__all__ = [
"ComputationGraph",
"OperationGraph",
]
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
"""Base operation module."""

import abc
import dataclasses

from ml_croissant._src.structure_graph.base_node import Node


@dataclasses.dataclass(frozen=True, repr=False)
class Operation:
class Operation(abc.ABC):
"""Generic base class to define an operation.
`@dataclass(frozen=True)` allows having a hashable operation for NetworkX to use
Expand All @@ -22,6 +23,7 @@ class Operation:

node: Node

@abc.abstractmethod
def __call__(self, *args, **kwargs):
raise NotImplementedError

Expand Down
14 changes: 8 additions & 6 deletions python/ml_croissant/ml_croissant/_src/operation_graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,11 +149,11 @@ def _add_operations_for_file_object(


@dataclasses.dataclass(frozen=True)
class ComputationGraph:
class OperationGraph:
"""Graph of dependent operations to execute to generate the dataset."""

issues: Issues
graph: nx.MultiDiGraph
operations: nx.MultiDiGraph

@classmethod
def from_nodes(
Expand All @@ -163,7 +163,7 @@ def from_nodes(
graph: nx.MultiDiGraph,
croissant_folder: epath.Path,
rdf_namespace_manager: namespace.NamespaceManager,
) -> "ComputationGraph":
) -> "OperationGraph":
"""Builds the OperationGraph from the nodes.
This is done by:
Expand Down Expand Up @@ -208,13 +208,15 @@ def from_nodes(
init_operation = InitOperation(node=metadata)
for entry_operation in entry_operations:
operations.add_edge(init_operation, entry_operation)
return ComputationGraph(issues=issues, graph=operations)
return OperationGraph(issues=issues, operations=operations)

def check_graph(self):
"""Checks the operation graph for issues."""
if not self.graph.is_directed():
if not self.operations.is_directed():
self.issues.add_error("Computation graph is not directed.")
selfloops = [operation.uid for operation, _ in nx.selfloop_edges(self.graph)]
selfloops = [
operation.uid for operation, _ in nx.selfloop_edges(self.operations)
]
if selfloops:
self.issues.add_error(
f"The following operations refered to themselves: {selfloops}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@
"The number of records to generate. Use `-1` to generate the whole dataset.",
)

flags.DEFINE_bool(
"debug",
False,
"Whether to print debug hints.",
)

flags.mark_flag_as_required("file")


Expand All @@ -37,7 +43,8 @@ def main(argv):
file = FLAGS.file
record_set = FLAGS.record_set
num_records = FLAGS.num_records
dataset = Dataset(file)
debug = FLAGS.debug
dataset = Dataset(file, debug=debug)
records = dataset.records(record_set)
print(f"Generating the first {num_records} records.")
for i, record in enumerate(records):
Expand Down

0 comments on commit 08b5150

Please sign in to comment.