Skip to content

Commit

Permalink
Set up PyType checks
Browse files Browse the repository at this point in the history
  • Loading branch information
marcenacp committed Jul 10, 2023
1 parent 08b5150 commit f97ef8e
Show file tree
Hide file tree
Showing 21 changed files with 123 additions and 64 deletions.
20 changes: 20 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,26 @@ jobs:
- name: PyLint
run: pylint **/*.py

pytype-test:
name: PyType / Python 3.10
runs-on: ubuntu-latest
defaults:
run:
working-directory: ./python/ml_croissant
steps:
- uses: actions/checkout@v3

- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.10'

- name: Install library
run: pip install .[dev]

- name: PyType
run: pytype --verbosity 2 .

validation-test:
name: Validation / JSON-LD Tests / Python 3.11
runs-on: ubuntu-latest
Expand Down
3 changes: 1 addition & 2 deletions python/ml_croissant/ml_croissant/_src/datasets.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""datasets module."""
from __future__ import annotations

from collections.abc import Mapping
import dataclasses
from typing import Any

Expand Down Expand Up @@ -117,7 +116,7 @@ def __iter__(self):
Warning: at the moment, this method yields examples from the first explored
record_set.
"""
results: Mapping[str, Any] = {}
results: dict[str, Any] = {}
operations = self.dataset.operations.operations
if self.debug:
graphs_utils.pretty_print_graph(operations)
Expand Down
13 changes: 6 additions & 7 deletions python/ml_croissant/ml_croissant/_src/operation_graph/graph.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""graph module."""

from collections.abc import Mapping
import dataclasses

from etils import epath
Expand Down Expand Up @@ -47,7 +46,7 @@ def _add_operations_for_field_with_source(
issues: Issues,
graph: nx.MultiDiGraph,
operations: nx.MultiDiGraph,
last_operation: Mapping[Node, Operation],
last_operation: dict[Node, Operation],
node: Field,
rdf_namespace_manager: namespace.NamespaceManager,
):
Expand Down Expand Up @@ -89,7 +88,7 @@ def _add_operations_for_field_with_source(
def _add_operations_for_field_with_data(
graph: nx.MultiDiGraph,
operations: nx.MultiDiGraph,
last_operation: Mapping[Node, Operation],
last_operation: dict[Node, Operation],
node: Field,
):
"""Adds a `Data` operation for a node of type `Field` with data.
Expand All @@ -105,8 +104,8 @@ def _add_operations_for_field_with_data(
def _add_operations_for_file_object(
graph: nx.MultiDiGraph,
operations: nx.MultiDiGraph,
last_operation: Mapping[Node, Operation],
node: Node,
last_operation: dict[Node, Operation],
node: FileObject,
croissant_folder: epath.Path,
):
"""Adds all operations for a node of type `FileObject`.
Expand All @@ -125,7 +124,7 @@ def _add_operations_for_file_object(
# Extract the file if needed
if (
node.encoding_format == "application/x-tar"
and isinstance(successor, (FileObject, FileSet))
and isinstance(successor, FileSet)
and successor.encoding_format != "application/x-tar"
):
untar = Untar(node=node, target_node=successor)
Expand Down Expand Up @@ -172,7 +171,7 @@ def from_nodes(
2. Building the computation graph by exploring the structure graph layers by
layers in a breadth-first search.
"""
last_operation: Mapping[Node, Operation] = {}
last_operation: dict[Node, Operation] = {}
operations = nx.MultiDiGraph()
# Find all fields
for node in nx.topological_sort(graph):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""graph_test module."""

from ml_croissant._src.operation_graph.operations import ReadField
from ml_croissant._src.tests.nodes import empty_node
from ml_croissant._src.tests.nodes import empty_field
import pytest
import rdflib
from rdflib import namespace
Expand All @@ -11,7 +11,9 @@ def test_find_data_type():
sc = rdflib.Namespace("https://schema.org/")
rdf_namespace_manager = namespace.NamespaceManager(rdflib.Graph())
rdf_namespace_manager.bind("sc", sc)
read_field = ReadField(node=empty_node, rdf_namespace_manager=rdf_namespace_manager)
read_field = ReadField(
node=empty_field, rdf_namespace_manager=rdf_namespace_manager
)
assert read_field.find_data_type("sc:Boolean") == bool
assert read_field.find_data_type(["sc:Boolean", "bar"]) == bool
assert read_field.find_data_type(["bar", "sc:Boolean"]) == bool
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
"""data_test module."""

from ml_croissant._src.operation_graph.operations import data
from ml_croissant._src.tests.nodes import empty_node
from ml_croissant._src.tests.nodes import empty_field


def test_str_representation():
operation = data.Data(node=empty_node)
assert str(operation) == "Data(node_name)"
operation = data.Data(node=empty_field)
assert str(operation) == "Data(field_name)"
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class Untar(Operation):
"""Un-tars "application/x-tar" and yields filtered lines."""

node: FileObject
target_node: FileObject | FileSet
target_node: FileSet

def __call__(self):
includes = fnmatch.translate(self.target_node.includes)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
"""extract_test module."""

from ml_croissant._src.operation_graph.operations import extract
from ml_croissant._src.tests.nodes import empty_node
from ml_croissant._src.tests.nodes import empty_file_object, empty_file_set


def test_str_representation():
operation = extract.Untar(node=empty_node, target_node=empty_node)
assert str(operation) == "Untar(node_name)"
operation = extract.Untar(node=empty_file_object, target_node=empty_file_set)
assert str(operation) == "Untar(file_object_name)"
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class ReadField(Operation):
node: Field
rdf_namespace_manager: namespace.NamespaceManager

def find_data_type(self, data_types: list[str] | tuple[str] | str) -> type:
def find_data_type(self, data_types: list[str] | tuple[str, ...] | str) -> type:
"""Finds the data type by expanding its name from the namespace manager.
In some cases, we specify a list of data types. In that case, we take the first
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
"""field_test module."""

from ml_croissant._src.operation_graph.operations import field
from ml_croissant._src.tests.nodes import empty_node
from ml_croissant._src.tests.nodes import empty_field
from rdflib import namespace


def test_str_representation():
operation = field.ReadField(
node=empty_node, rdf_namespace_manager=namespace.NamespaceManager
node=empty_field, rdf_namespace_manager=namespace.NamespaceManager
)
assert str(operation) == "ReadField(node_name)"
assert str(operation) == "ReadField(field_name)"
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@
import pandas as pd


def apply_transform_fn(value: str, source: Source | None = None) -> Callable[..., Any]:
def apply_transform_fn(value: str, source: Source | None = None) -> str:
if source is None:
return value
if source.apply_transform_regex is not None:
source_regex = re.compile(source.apply_transform_regex)
match = source_regex.match(value)
if match is None:
return value
for group in match.groups():
if group is not None:
return group
Expand All @@ -32,8 +34,12 @@ def __call__(
if len(args) == 1:
return args[0]
elif len(args) == 2:
assert left.reference is not None, (
f'Reference for "{self.node.uid}" is None. It should be a valid'
assert left is not None and left.reference is not None, (
f'Left reference for "{self.node.uid}" is None. It should be a valid'
" reference."
)
assert right is not None and right.reference is not None, (
f'Right reference for "{self.node.uid}" is None. It should be a valid'
" reference."
)
left_key = left.reference[1]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class Merge(Operation):

node: FileSet

def __call__(self, *args: list[pd.DataFrame]) -> pd.DataFrame:
def __call__(self, *args: pd.DataFrame) -> pd.DataFrame:
assert len(args) > 0, "No dataframe to merge."
df = args[0]
for other_df in args[1:]:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
"""merge_test module."""

from ml_croissant._src.operation_graph.operations import merge
from ml_croissant._src.tests.nodes import empty_node
from ml_croissant._src.tests.nodes import empty_file_set


def test_str_representation():
operation = merge.Merge(node=empty_node)
assert str(operation) == "Merge(node_name)"
operation = merge.Merge(node=empty_file_set)
assert str(operation) == "Merge(file_set_name)"
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def __post_init__(self):
def _edges_from_node(self):
return self.graph.edges(self.node, keys=True)

def assert_has_mandatory_properties(self, *mandatory_properties: list[str]):
def assert_has_mandatory_properties(self, *mandatory_properties: str):
"""Checks a node in the graph for existing properties with constraints.
Args:
Expand All @@ -90,7 +90,7 @@ def assert_has_mandatory_properties(self, *mandatory_properties: list[str]):
)
self.add_error(error)

def assert_has_optional_properties(self, *optional_properties: list[str]):
def assert_has_optional_properties(self, *optional_properties: str):
"""Checks a node in the graph for existing properties with constraints.
Args:
Expand All @@ -106,7 +106,7 @@ def assert_has_optional_properties(self, *optional_properties: list[str]):
)
self.add_warning(error)

def assert_has_exclusive_properties(self, *exclusive_properties: list[list[str]]):
def assert_has_exclusive_properties(self, *exclusive_properties: list[str]):
"""Checks a node in the graph for existing properties with constraints.
Args:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,15 @@


def test_there_exists_at_least_one_property():
@dataclasses.dataclass
class Node:
property1: str
property2: str
@dataclasses.dataclass(frozen=True, repr=False)
class Node(base_node.Node):
property1: str = ""
property2: str = ""

def check(self):
pass

node = Node(property1="property1", property2="property2")
node = Node(issues=Issues(), property1="property1", property2="property2")
assert base_node.there_exists_at_least_one_property(
node, ["property0", "property1"]
)
Expand Down
Loading

0 comments on commit f97ef8e

Please sign in to comment.