Skip to content

Commit

Permalink
Added naming convention check
Browse files Browse the repository at this point in the history
  • Loading branch information
RunDevelopment committed Jul 23, 2023
1 parent ae088fd commit 3174472
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 51 deletions.
11 changes: 11 additions & 0 deletions .github/workflows/test-backend.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,17 @@ jobs:
env:
TYPE_CHECK_LEVEL: 'error'

# Boots the backend once (run.py --close-after-start) with
# NAME_CHECK_LEVEL=error so that naming-convention violations fail CI.
backend-name-check-test:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: '3.9'
- run: python ./backend/src/run.py --close-after-start
env:
NAME_CHECK_LEVEL: 'error'

backend-bootstrap:
runs-on: ubuntu-latest
strategy:
Expand Down
61 changes: 38 additions & 23 deletions backend/src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,17 @@

from base_types import InputId, OutputId
from custom_types import NodeType, RunFn
from node_check import (
NAME_CHECK_LEVEL,
TYPE_CHECK_LEVEL,
CheckFailedError,
CheckLevel,
check_naming_conventions,
check_schema_types,
)
from nodes.base_input import BaseInput
from nodes.base_output import BaseOutput
from nodes.group import Group, GroupId, NestedGroup, NestedIdGroup
from type_checking import (
TypeCheckLevel,
TypeMismatchError,
get_type_check_level,
typeValidateSchema,
)

KB = 1024**1
MB = 1024**2
Expand Down Expand Up @@ -121,22 +123,35 @@ def register(
if isinstance(see_also, str):
see_also = [see_also]

def run_check(level: CheckLevel, run: Callable[[bool], None]):
    """Execute a single node check at the given strictness level.

    NONE skips the check entirely. The check itself is told whether it
    may auto-fix (level == FIX). On failure, ERROR escalates to a
    CheckFailedError with the schema id prepended; every other level
    only logs a warning.
    """
    if level == CheckLevel.NONE:
        return

    allow_fix = level == CheckLevel.FIX
    try:
        run(allow_fix)
    except CheckFailedError as failure:
        message = f"Error in {schema_id}: {failure}"
        if level != CheckLevel.ERROR:
            logger.warning(message)
            return
        # pylint: disable=raise-missing-from
        raise CheckFailedError(message)

def inner_wrapper(wrapped_func: T) -> T:
p_inputs, group_layout = _process_inputs(inputs)
p_outputs = _process_outputs(outputs)

TYPE_CHECK_LEVEL = get_type_check_level()

if TYPE_CHECK_LEVEL != TypeCheckLevel.NONE:
try:
typeValidateSchema(wrapped_func, node_type, p_inputs, p_outputs)
except TypeMismatchError as e:
full_error_message = f"Error in {schema_id}: {e}"
if TYPE_CHECK_LEVEL == TypeCheckLevel.WARN:
logger.warning(full_error_message)
elif TYPE_CHECK_LEVEL == TypeCheckLevel.ERROR:
# pylint: disable=raise-missing-from
raise TypeMismatchError(full_error_message)
run_check(
TYPE_CHECK_LEVEL,
lambda _: check_schema_types(
wrapped_func, node_type, p_inputs, p_outputs
),
)
run_check(
NAME_CHECK_LEVEL,
lambda fix: check_naming_conventions(
wrapped_func, node_type, name, fix
),
)

if decorators is not None:
for decorator in decorators:
Expand Down Expand Up @@ -268,7 +283,7 @@ def add(self, package: Package) -> Package:

def load_nodes(self, current_file: str):
import_errors: List[ImportError] = []
type_errors: List[TypeMismatchError] = []
failed_checks: List[CheckFailedError] = []

for package in list(self.packages.values()):
for file_path in _iter_py_files(os.path.dirname(package.where)):
Expand All @@ -285,12 +300,12 @@ def load_nodes(self, current_file: str):
logger.warning(f"Failed to load {module} ({file_path}): {e}")
except ValueError as e:
logger.warning(f"Failed to load {module} ({file_path}): {e}")
except TypeMismatchError as e:
except CheckFailedError as e:
logger.error(e)
type_errors.append(e)
failed_checks.append(e)

if len(type_errors) > 0:
raise RuntimeError(f"Type errors occurred in {len(type_errors)} node(s)")
if len(failed_checks) > 0:
raise RuntimeError(f"Checks failed in {len(failed_checks)} node(s)")

logger.info(import_errors)
self._refresh_nodes()
Expand Down
114 changes: 87 additions & 27 deletions backend/src/type_checking.py → backend/src/node_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import ast
import inspect
import os
import pathlib
from enum import Enum
from typing import Any, Callable, Dict, List, NewType, Set, Union, cast, get_args

Expand All @@ -13,28 +14,42 @@
_Ty = NewType("_Ty", object)


class TypeMismatchError(Exception):
class CheckFailedError(Exception):
    """Raised when a node check (schema types or naming conventions) fails."""

    pass


# Enum for the strictness level of a node check
class TypeCheckLevel(Enum):
class CheckLevel(Enum):
    """Strictness level of a node check.

    NONE disables the check, WARN logs failures, FIX lets the check
    repair problems automatically, and ERROR raises on failure.
    """

    NONE = "none"
    WARN = "warn"
    FIX = "fix"
    ERROR = "error"

    @staticmethod
    def parse(s: str) -> CheckLevel:
        """Parse a level name (case-insensitive, whitespace-tolerant)."""
        normalized = s.strip().lower()
        for level in CheckLevel:
            if normalized == level.value:
                return level
        raise ValueError(f"Invalid check level: {normalized}")


def _get_check_level(name: str, default: CheckLevel) -> CheckLevel:
    """Read a check level from the environment variable ``name``.

    Falls back to ``default`` when the variable is unset or holds a
    value that ``CheckLevel.parse`` rejects.
    """
    raw = os.environ.get(name, default.value)
    try:
        return CheckLevel.parse(raw)
    except ValueError:
        # An invalid value deliberately falls back to the default instead
        # of aborting startup. Catch only ValueError (what parse raises)
        # rather than a bare except, so real bugs and KeyboardInterrupt
        # are not swallowed.
        return default

# If it's stupid but it works, it's not stupid
def get_type_check_level() -> TypeCheckLevel:
type_check_level = os.environ.get("TYPE_CHECK_LEVEL", TypeCheckLevel.NONE.value)
if type_check_level.lower() == TypeCheckLevel.NONE.value:
return TypeCheckLevel.NONE
elif type_check_level.lower() == TypeCheckLevel.WARN.value:
return TypeCheckLevel.WARN
elif type_check_level.lower() == TypeCheckLevel.ERROR.value:
return TypeCheckLevel.ERROR
else:
return TypeCheckLevel.NONE

# Global default level for all node checks; the specific checks below fall
# back to it when their own environment variable is unset or invalid.
CHECK_LEVEL = _get_check_level("CHECK_LEVEL", CheckLevel.NONE)
# Level for naming-convention checks (NAME_CHECK_LEVEL env var).
NAME_CHECK_LEVEL = _get_check_level("NAME_CHECK_LEVEL", CHECK_LEVEL)
# Level for schema type checks (TYPE_CHECK_LEVEL env var).
TYPE_CHECK_LEVEL = _get_check_level("TYPE_CHECK_LEVEL", CHECK_LEVEL)


class TypeTransformer(ast.NodeTransformer):
Expand Down Expand Up @@ -120,39 +135,39 @@ def get_type_annotations(fn: Callable) -> Dict[str, _Ty]:
def validate_return_type(return_type: _Ty, outputs: list[BaseOutput]):
    """Verify that a node function's return annotation matches its outputs.

    Raises CheckFailedError when the annotated return type is not
    compatible with the declared outputs: no outputs require ``None``,
    a single output requires a subset of its associated type, and
    multiple outputs require a ``typing.Tuple`` with one matching
    element per output.
    """
    count = len(outputs)

    if count == 0:
        # Nothing to return, so the annotation must be None.
        if return_type is not None:  # type: ignore
            raise CheckFailedError(
                "Return type should be 'None' because there are no outputs"
            )
        return

    if count == 1:
        expected = outputs[0].associated_type
        if expected is not None and not is_subset_of(return_type, expected):
            raise CheckFailedError(
                f"Return type '{return_type}' must be a subset of '{expected}'"
            )
        return

    # Multiple outputs: the function has to return a typed tuple with one
    # element per output.
    if not str(return_type).startswith("typing.Tuple["):
        raise CheckFailedError(
            f"Return type '{return_type}' must be a tuple because there are multiple outputs"
        )

    element_types = get_args(return_type)
    if len(element_types) != count:
        raise CheckFailedError(
            f"Return type '{return_type}' must have the same number of arguments as there are outputs"
        )

    for output, element_type in zip(outputs, element_types):
        expected = output.associated_type
        if expected is not None and not is_subset_of(element_type, expected):
            raise CheckFailedError(
                f"Return type of {output.label} '{element_type}' must be a subset of '{expected}'"
            )


def typeValidateSchema(
def check_schema_types(
wrapped_func: Callable,
node_type: NodeType,
inputs: list[BaseInput],
Expand All @@ -173,7 +188,7 @@ def typeValidateSchema(
arg_spec = inspect.getfullargspec(wrapped_func)
for arg in arg_spec.args:
if not arg in ann:
raise TypeMismatchError(f"Missing type annotation for '{arg}'")
raise CheckFailedError(f"Missing type annotation for '{arg}'")

if node_type == "iteratorHelper":
# iterator helpers have inputs that do not describe the arguments of the function, so we can't check them
Expand All @@ -184,13 +199,13 @@ def typeValidateSchema(
context = [*ann.keys()][-1]
context_type = ann.pop(context)
if str(context_type) != "<class 'process.IteratorContext'>":
raise TypeMismatchError(
raise CheckFailedError(
f"Last argument of an iterator must be an IteratorContext, not '{context_type}'"
)

if arg_spec.varargs is not None:
if not arg_spec.varargs in ann:
raise TypeMismatchError(f"Missing type annotation for '{arg_spec.varargs}'")
raise CheckFailedError(f"Missing type annotation for '{arg_spec.varargs}'")
va_type = ann.pop(arg_spec.varargs)

# split inputs by varargs and non-varargs
Expand All @@ -203,7 +218,7 @@ def typeValidateSchema(

if associated_type is not None:
if not is_subset_of(associated_type, va_type):
raise TypeMismatchError(
raise CheckFailedError(
f"Input type of {i.label} '{associated_type}' is not assignable to varargs type '{va_type}'"
)

Expand All @@ -217,17 +232,62 @@ def typeValidateSchema(
if total is not None:
total_type = union_types(total)
if total_type != va_type:
raise TypeMismatchError(
raise CheckFailedError(
f"Varargs type '{va_type}' should be equal to the union of all arguments '{total_type}'"
)

if len(ann) != len(inputs):
raise TypeMismatchError(
raise CheckFailedError(
f"Number of inputs and arguments don't match: {len(ann)=} != {len(inputs)=}"
)
for (a_name, a_type), i in zip(ann.items(), inputs):
associated_type = i.associated_type
if associated_type is not None and a_type != associated_type:
raise TypeMismatchError(
raise CheckFailedError(
f"Expected type of {i.label} ({a_name}) to be '{associated_type}' but found '{a_type}'"
)


def check_naming_conventions(
    wrapped_func: Callable,
    node_type: NodeType,
    name: str,
    fix: bool,
):
    """Check (and optionally fix) a node's function and file names.

    The node's display ``name`` is normalized to snake_case; the wrapped
    function must be called ``<normalized>_node`` and — except for
    iterator helpers — live in ``<normalized>.py``. When ``fix`` is True
    the source file is rewritten/renamed instead of raising
    CheckFailedError.
    """
    # Normalization order matters: " (iterator)" must be stripped before
    # the space and parenthesis rewrites run.
    expected_name = name.lower()
    for old, new in (
        (" (iterator)", ""),
        (" ", "_"),
        ("-", "_"),
        ("(", ""),
        (")", ""),
        ("&", "and"),
    ):
        expected_name = expected_name.replace(old, new)

    if node_type == "iteratorHelper":
        expected_name = "iterator_helper_" + expected_name

    func_name = wrapped_func.__name__
    file_path = pathlib.Path(inspect.getfile(wrapped_func))

    # check function name
    expected_func_name = expected_name + "_node"
    if func_name != expected_func_name:
        if not fix:
            raise CheckFailedError(
                f"Function name is '{func_name}', but it should be '{expected_name}_node'"
            )

        # Rewrite the `def` line in place in the source file.
        source = file_path.read_text(encoding="utf-8")
        file_path.write_text(
            source.replace(f"def {func_name}(", f"def {expected_func_name}("),
            encoding="utf-8",
        )

    # check file name
    if node_type != "iteratorHelper" and file_path.stem != expected_name:
        if not fix:
            raise CheckFailedError(
                f"File name is '{file_path.stem}.py', but it should be '{expected_name}.py'"
            )

        os.rename(file_path, file_path.with_name(expected_name + ".py"))
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"scripts": {
"start": "electron-forge start -- --devtools",
"frontend": "electron-forge start -- --remote-backend=http://127.0.0.1:8000 --devtools",
"dev": "concurrently \"cd backend/src && cross-env TYPE_CHECK_LEVEL=warn nodemon ./run.py 8000\" \"electron-forge start -- --remote-backend=http://127.0.0.1:8000 --refresh --devtools\"",
"dev": "concurrently \"cd backend/src && cross-env CHECK_LEVEL=fix nodemon ./run.py 8000\" \"electron-forge start -- --remote-backend=http://127.0.0.1:8000 --refresh --devtools\"",
"debug": "concurrently \"npm run debug:py\" \"electron-forge start -- --remote-backend=http://127.0.0.1:8000 --refresh --devtools\"",
"debug:py": "cd backend/src && nodemon --exec \"python -m debugpy --listen 5678\" ./run.py 8000",
"package": "cross-env NODE_ENV=production electron-forge package",
Expand Down

0 comments on commit 3174472

Please sign in to comment.