From e274b6c7cae92b2e12d1d9ca8fc5d7a4110c547c Mon Sep 17 00:00:00 2001 From: Adam Hopkins Date: Mon, 19 Apr 2021 21:26:58 +0300 Subject: [PATCH] Route group (#24) * Revamp routes to be nested inside RouteGroup * Cleanup some imports * Pass all Sanic routing tests * Passing all tests on main repo * Fix type annotations --- sanic_routing/__init__.py | 5 +- sanic_routing/exceptions.py | 2 +- sanic_routing/group.py | 122 ++++++++++++++++++ sanic_routing/route.py | 101 ++++++--------- sanic_routing/router.py | 131 +++++++++++++------- sanic_routing/tree.py | 241 ++++++++++++++++++++++++------------ tests/test_router_source.py | 6 +- tests/test_routing.py | 27 ++-- 8 files changed, 434 insertions(+), 201 deletions(-) create mode 100644 sanic_routing/group.py diff --git a/sanic_routing/__init__.py b/sanic_routing/__init__.py index 1579220..560af9d 100644 --- a/sanic_routing/__init__.py +++ b/sanic_routing/__init__.py @@ -1,5 +1,6 @@ +from .group import RouteGroup from .route import Route from .router import BaseRouter -__version__ = "0.5.2" -__all__ = ("BaseRouter", "Route") +__version__ = "0.6.0" +__all__ = ("BaseRouter", "Route", "RouteGroup") diff --git a/sanic_routing/exceptions.py b/sanic_routing/exceptions.py index ac6e9c0..9b9b1d3 100644 --- a/sanic_routing/exceptions.py +++ b/sanic_routing/exceptions.py @@ -22,7 +22,7 @@ class BadMethod(BaseException): class NoMethod(BaseException): def __init__( self, - message: str, + message: str = "Method does not exist", method: Optional[str] = None, allowed_methods: Optional[Set[str]] = None, ): diff --git a/sanic_routing/group.py b/sanic_routing/group.py new file mode 100644 index 0000000..ade0dad --- /dev/null +++ b/sanic_routing/group.py @@ -0,0 +1,122 @@ +from sanic_routing.utils import Immutable + +from .exceptions import InvalidUsage, RouteExists + + +class RouteGroup: + methods_index: Immutable + + def __init__(self, *routes) -> None: + if len(set(route.parts for route in routes)) > 1: + raise InvalidUsage("Cannot group routes with differing paths") + + if any(routes[-1].strict != route.strict for route in routes): + raise InvalidUsage("Cannot group routes with differing strictness") + + route_list = list(routes) + route_list.pop() + + self._routes = routes + self.pattern_idx = 0 + + def __str__(self): + display = ( + f"path={self.path or self.router.delimiter} len={len(self.routes)}" + ) + return f"<{self.__class__.__name__}: {display}>" + + def __iter__(self): + return iter(self.routes) + + def __getitem__(self, key): + return self.routes[key] + + def finalize(self): + self.methods_index = Immutable( + { + method: route + for route in self._routes + for method in route.methods + } + ) + + def reset(self): + self.methods_index = dict(self.methods_index) + + def merge(self, group, overwrite: bool = False, append: bool = False): + _routes = list(self._routes) + for other_route in group.routes: + for current_route in self: + if ( + current_route == other_route + or ( + current_route.requirements + and not other_route.requirements + ) + or ( + not current_route.requirements + and other_route.requirements + ) + ) and not append: + if not overwrite: + raise RouteExists( + f"Route already registered: {self.raw_path} " + f"[{','.join(self.methods)}]" + ) + else: + _routes.append(other_route) + self._routes = tuple(_routes) + + @property + def labels(self): + return self[0].labels + + @property + def methods(self): + return frozenset( + [method for route in self for method in route.methods] + ) + + @property + def params(self): + return self[0].params + + @property + def parts(self): + return self[0].parts + + @property + def path(self): + return self[0].path + + @property + def pattern(self): + return self[0].pattern + + @property + def raw_path(self): + return self[0].raw_path + + @property + def regex(self): + return self[0].regex + + @property + def requirements(self): + return [route.requirements for route in self if route.requirements] + + @property + def routes(self): + return self._routes + + @property + def router(self): + return self[0].router + + @property + def strict(self): + return self[0].strict + + @property + def unquote(self): + return self[0].unquote diff --git a/sanic_routing/route.py b/sanic_routing/route.py index ce6b668..80f41f5 100644 --- a/sanic_routing/route.py +++ b/sanic_routing/route.py @@ -1,9 +1,9 @@ import re import typing as t -from collections import defaultdict, namedtuple +from collections import namedtuple from types import SimpleNamespace -from .exceptions import InvalidUsage, ParameterNameConflicts, RouteExists +from .exceptions import InvalidUsage, ParameterNameConflicts from .patterns import REGEX_TYPES from .utils import Immutable, parts_to_path, path_to_parts @@ -12,7 +12,7 @@ ) -class Requirements(dict): +class Requirements(Immutable): def __hash__(self): return hash(frozenset(self.items())) @@ -22,10 +22,11 @@ class Route: "_params", "_raw_path", "ctx", - "handlers", + "handler", "labels", "methods", "name", + "overloaded", "params", "parts", "path", @@ -36,7 +37,6 @@ class Route: "static", "strict", "unquote", - "overloaded", ) def __init__( @@ -44,6 +44,9 @@ def __init__( router, raw_path: str, name: str, + handler: t.Callable[..., t.Any], + methods: t.Iterable[str], + requirements: t.Dict[str, t.Any] = None, strict: bool = False, unquote: bool = False, static: bool = False, @@ -52,10 +55,14 @@ def __init__( ): self.router = router self.name = name - self.handlers = defaultdict(lambda: defaultdict(list)) # type: ignore + self.handler = handler + self.methods = frozenset(methods) + self.requirements = Requirements(requirements or {}) + + self.ctx = SimpleNamespace() + self._params: t.Dict[int, ParamInfo] = {} self._raw_path = raw_path - self.ctx = SimpleNamespace() parts = path_to_parts(raw_path, self.router.delimiter) self.path = parts_to_path(parts, delimiter=self.router.delimiter) @@ -66,10 +73,11 @@ def __init__( self.pattern = None self.strict: bool = strict self.unquote: bool = unquote - self.requirements: t.Dict[int, t.Any] = {} self.labels: t.Optional[t.List[str]] = None - def __repr__(self): + self._setup_params() + + def __str__(self): display = ( f"name={self.name} path={self.path or self.router.delimiter}" if self.name and self.name != self.path @@ -77,53 +85,30 @@ def __repr__(self): ) return f"<{self.__class__.__name__}: {display}>" - def get_handler(self, raw_path, method, idx): - method = method or self.router.DEFAULT_METHOD - raw_path = raw_path.lstrip(self.router.delimiter) - try: - return self.handlers[raw_path][method][idx] - except (IndexError, KeyError): - raise self.router.method_handler_exception( - f"Method '{method}' not found on {self}", - method=method, - allowed_methods=self.methods, + def __eq__(self, other) -> bool: + if not isinstance(other, self.__class__): + return False + return bool( + ( + self.parts, + self.requirements, + ) + == ( + other.parts, + other.requirements, ) + and (self.methods & other.methods) + ) - def add_handler( - self, - raw_path, - handler, - method, - requirements, - overwrite: bool = False, - ): + def _setup_params(self): key_path = parts_to_path( - path_to_parts(raw_path, self.router.delimiter), + path_to_parts(self.raw_path, self.router.delimiter), self.router.delimiter, ) - - if ( - not self.router.stacking - and self.handlers.get(key_path, {}).get(method) - and ( - requirements is None - or Requirements(requirements) in self.requirements.values() - ) - and not overwrite - ): - raise RouteExists( - f"Route already registered: {key_path} [{method}]" - ) - - idx = len(self.handlers[key_path][method.upper()]) - self.handlers[key_path][method.upper()].append(handler) - if requirements is not None: - self.requirements[idx] = Requirements(requirements) - if not self.static: parts = path_to_parts(key_path, self.router.delimiter) for idx, part in enumerate(parts): - if "<" in part and len(self.handlers[key_path]) == 1: + if "<" in part: if ":" in part: ( name, @@ -173,17 +158,6 @@ def _finalize_params(self): sorted(params.items(), key=lambda param: self._sorting(param[1])) ) - def _finalize_methods(self): - self.methods = set() - for handlers in self.handlers.values(): - self.methods.update(set(key.upper() for key in handlers.keys())) - - def _finalize_handlers(self): - self.handlers = Immutable(self.handlers) - - def _reset_handlers(self): - self.handlers = dict(self.handlers) - def _compile_regex(self): components = [] @@ -225,11 +199,14 @@ def finalize(self): self._finalize_params() if self.regex: self._compile_regex() - self._finalize_methods() - self._finalize_handlers() + self.requirements = Immutable(self.requirements) def reset(self): - self._reset_handlers() + self.requirements = dict(self.requirements) + + @property + def defined_params(self): + return self._params @property def raw_path(self): diff --git a/sanic_routing/router.py b/sanic_routing/router.py index f77bc13..1d6f274 100644 --- a/sanic_routing/router.py +++ b/sanic_routing/router.py @@ -2,6 +2,8 @@ from abc import ABC, abstractmethod from types import SimpleNamespace +from sanic_routing.group import RouteGroup + from .exceptions import BadMethod, FinalizationError, NoMethod, NotFound from .line import Line from .patterns import REGEX_TYPES @@ -28,19 +30,21 @@ def __init__( exception: t.Type[NotFound] = NotFound, method_handler_exception: t.Type[NoMethod] = NoMethod, route_class: t.Type[Route] = Route, + group_class: t.Type[RouteGroup] = RouteGroup, stacking: bool = False, cascade_not_found: bool = False, ) -> None: self._find_route = None self._matchers = None - self.static_routes: t.Dict[t.Tuple[str, ...], Route] = {} - self.dynamic_routes: t.Dict[t.Tuple[str, ...], Route] = {} - self.regex_routes: t.Dict[t.Tuple[str, ...], Route] = {} + self.static_routes: t.Dict[t.Tuple[str, ...], RouteGroup] = {} + self.dynamic_routes: t.Dict[t.Tuple[str, ...], RouteGroup] = {} + self.regex_routes: t.Dict[t.Tuple[str, ...], RouteGroup] = {} self.name_index: t.Dict[str, Route] = {} self.delimiter = delimiter self.exception = exception self.method_handler_exception = method_handler_exception self.route_class = route_class + self.group_class = group_class self.tree = Tree() self.finalized = False self.stacking = stacking @@ -61,9 +65,9 @@ def resolve( ): try: route, param_basket = self.find_route( - path, self, {"__handler_idx__": 0, "__params__": {}}, extra + path, method, self, {"__params__": {}}, extra ) - except NotFound as e: + except (NotFound, NoMethod) as e: if path.endswith(self.delimiter): return self.resolve( path=path[:-1], @@ -73,17 +77,29 @@ def resolve( ) raise self.exception(str(e), path=path) - handler = None - handler_idx = param_basket.pop("__handler_idx__") - raw_path = param_basket.pop("__raw_path__") + if isinstance(route, RouteGroup): + try: + route = route.methods_index[method] + except KeyError: + raise self.method_handler_exception( + f"Method '{method}' not found on {route}", + method=method, + allowed_methods=route.methods, + ) + params = param_basket.pop("__params__") if route.strict and orig and orig[-1] != route.path[-1]: raise self.exception("Path not found", path=path) - handler = route.get_handler(raw_path, method, handler_idx) + if method not in route.methods: + raise self.method_handler_exception( + f"Method '{method}' not found on {route}", + method=method, + allowed_methods=route.methods, + ) - return route, handler, params + return route, route.handler, params def add( self, @@ -95,7 +111,13 @@ def add( strict: bool = False, unquote: bool = False, # noqa overwrite: bool = False, + append: bool = False, ) -> Route: + if overwrite and append: + raise FinalizationError( + "Cannot add a route with both overwrite and append equal " + "to True" + ) if not methods: methods = [self.DEFAULT_METHOD] @@ -140,11 +162,15 @@ def add( self, path, name or "", + handler=handler, + methods=methods, + requirements=requirements, strict=strict, unquote=unquote, static=static, regex=regex, ) + group = self.group_class(route) # Catch the scenario where a route is overloaded with and # and without requirements, first as dynamic then as static @@ -154,20 +180,20 @@ def add( # Catch the reverse scenario where a route is overload first as static # and then as dynamic if not static and route.parts in self.static_routes: - route = self.static_routes.pop(route.parts) - self.dynamic_routes[route.parts] = route + existing_group = self.static_routes.pop(route.parts) + group.merge(existing_group, overwrite, append) else: if route.parts in routes: - route = routes[route.parts] - else: - routes[route.parts] = route + existing_group = routes[route.parts] + group.merge(existing_group, overwrite, append) + + routes[route.parts] = group if name: self.name_index[name] = route - for method in methods: - route.add_handler(path, handler, method, requirements, overwrite) + group.finalize() return route @@ -178,8 +204,14 @@ def finalize(self, do_compile: bool = True): raise FinalizationError("Cannot finalize with no routes defined.") self.finalized = True - for route in self.routes.values(): - route.finalize() + for group in ( + list(self.static_routes.values()) + + list(self.dynamic_routes.values()) + + list(self.regex_routes.values()) + ): + group.finalize() + for route in group.routes: + route.finalize() self._generate_tree() self._render(do_compile) @@ -189,16 +221,25 @@ def reset(self): self.tree = Tree() self._find_route = None - for route in self.routes.values(): - route.reset() + for group in ( + list(self.static_routes.values()) + + list(self.dynamic_routes.values()) + + list(self.regex_routes.values()) + ): + group.reset() + for route in group.routes: + route.reset() def _generate_tree(self) -> None: - self.tree.generate(self.dynamic_routes) + self.tree.generate( + list(self.dynamic_routes.values()) + + list(self.regex_routes.values()) + ) self.tree.finalize() def _render(self, do_compile: bool = True) -> None: src = [ - Line("def find_route(path, router, basket, extra):", 0), + Line("def find_route(path, method, router, basket, extra):", 0), Line("parts = tuple(path[1:].split(router.delimiter))", 1), ] delayed = [] @@ -210,9 +251,12 @@ def _render(self, do_compile: bool = True) -> None: # potentially has an impact on performance src += [ Line("try:", 1), - Line("route = router.static_routes[parts]", 2), + Line( + "group = router.static_routes[parts]", + 2, + ), Line("basket['__raw_path__'] = path", 2), - Line("return route, basket", 2), + Line("return group, basket", 2), Line("except KeyError:", 1), Line("pass", 2), ] @@ -228,11 +272,6 @@ def _render(self, do_compile: bool = True) -> None: # Line("basket['__raw_path__'] = route.path", 2), # Line("return route, basket", 2), # ] - - if self.dynamic_routes: - src += [Line("num = len(parts)", 1)] - src += self.tree.render() - if self.regex_routes: routes = sorted( self.regex_routes.values(), @@ -240,25 +279,15 @@ def _render(self, do_compile: bool = True) -> None: reverse=True, ) delayed.append(Line("matchers = [", 0)) - for idx, route in enumerate(routes): - delayed.append(Line(f"re.compile(r'^{route.pattern}$'),", 1)) - src.extend( - [ - Line(f"match = router.matchers[{idx}].match(path)", 1), - Line("if match:", 1), - Line("basket['__params__'] = match.groupdict()", 2), - Line(f"basket['__raw_path__'] = '{route.path}'", 2), - Line( - ( - f"return router.name_index['{route.name}'], " - "basket" - ), - 2, - ), - ] - ) + for idx, group in enumerate(routes): + group.pattern_idx = idx + delayed.append(Line(f"re.compile(r'^{group.pattern}$'),", 1)) delayed.append(Line("]", 0)) + if self.dynamic_routes or self.regex_routes: + src += [Line("num = len(parts)", 1)] + src += self.tree.render() + src.append(Line("raise NotFound", 1)) src.extend(delayed) @@ -297,13 +326,19 @@ def matchers(self): return self._matchers @property - def routes(self): + def groups(self): return { **self.static_routes, **self.dynamic_routes, **self.regex_routes, } + @property + def routes(self): + return tuple( + [route for group in self.groups.values() for route in group] + ) + def optimize(self, src: t.List[Line]) -> None: """ Insert NotFound exceptions to be able to bail as quick as possible, diff --git a/sanic_routing/tree.py b/sanic_routing/tree.py index c65afaa..f2f8ec5 100644 --- a/sanic_routing/tree.py +++ b/sanic_routing/tree.py @@ -1,9 +1,9 @@ import typing as t from logging import getLogger +from .group import RouteGroup from .line import Line from .patterns import REGEX_PARAM_NAME, REGEX_TYPES -from .route import Route logger = getLogger("sanic.root") @@ -19,16 +19,17 @@ def __init__( self.children: t.Dict[str, "Node"] = {} self.level = 0 self.offset = 0 - self.route: t.Optional[Route] = None + self.group: t.Optional[RouteGroup] = None self.dynamic = False self.first = False self.last = False self.children_basketed = False + self.children_param_injected = False - def __repr__(self) -> str: + def __str__(self) -> str: internals = ", ".join( f"{prop}={getattr(self, prop)}" - for prop in ["part", "level", "route", "dynamic"] + for prop in ["part", "level", "group", "dynamic"] if getattr(self, prop) or prop in ["level"] ) return f"" @@ -53,16 +54,18 @@ def display(self) -> None: for child in self.children.values(): child.display() - def render(self) -> t.List[Line]: + def render(self) -> t.Tuple[t.List[Line], t.List[Line]]: + output: t.List[Line] = [] + delayed: t.List[Line] = [] + final: t.List[Line] = [] + if not self.root: - output, delayed = self.to_src() - else: - output = [] - delayed = [] + output, delayed, final = self.to_src() for child in self.children.values(): - output += child.render() - output += delayed - return output + o, f = child.render() + output += o + final += f + return output + delayed, final def apply_offset(self, amt, apply_self=True, apply_children=False): if apply_self: @@ -71,9 +74,10 @@ def apply_offset(self, amt, apply_self=True, apply_children=False): for child in self.children.values(): child.apply_offset(amt, apply_children=True) - def to_src(self) -> t.Tuple[t.List[Line], t.List[Line]]: + def to_src(self) -> t.Tuple[t.List[Line], t.List[Line], t.List[Line]]: indent = (self.level + 1) * 2 - 3 + self.offset delayed: t.List[Line] = [] + final: t.List[Line] = [] src: t.List[Line] = [] level = self.level - 1 @@ -81,31 +85,37 @@ def to_src(self) -> t.Tuple[t.List[Line], t.List[Line]]: len_check = "" return_bump = 1 - if self.first or self.root: - src = [] - operation = ">" - use_level = level - conditional = "if" - if ( - self.last - and self.route - and not self.children - and not self.route.requirements - and not self.route.router.regex_routes - ): - use_level = self.level - operation = "==" - equality_check = True - conditional = "elif" - src.extend( - [ - Line(f"if num > {use_level}:", indent), - Line("raise NotFound", indent + 1), - ] + if self.first: + if level == 0: + if self.group: + src.append(Line("if True:", indent)) + else: + src.append(Line("if parts[0]:", indent)) + else: + operation = ">" + use_level = level + conditional = "if" + if ( + self.last + and self.group + and not self.children + and not self.group.requirements + ): + use_level = self.level + operation = "==" + equality_check = True + conditional = "elif" + src.extend( + [ + Line(f"if num > {use_level}:", indent), + Line("...", indent + 1) + if self._has_nested_path(self) + else Line("raise NotFound", indent + 1), + ] + ) + src.append( + Line(f"{conditional} num {operation} {use_level}:", indent) ) - src.append( - Line(f"{conditional} num {operation} {use_level}:", indent) - ) if self.dynamic: if not self.parent.children_basketed: @@ -133,62 +143,102 @@ def to_src(self) -> t.Tuple[t.List[Line], t.List[Line]]: if self.children: return_bump += 1 - if self.route and not self.route.regex: + if self.group: location = delayed if self.children else src - if self.route.requirements: + route_idx: t.Union[int, str] = 0 + if self.group.requirements: + route_idx = "route_idx" self._inject_requirements( location, indent + return_bump + bool(not self.children) ) - if self.route.params: + if self.group.params and not self.group.regex: if not self.last: return_bump += 1 self._inject_params( - location, indent + return_bump + bool(not self.children) + location, + indent + return_bump + bool(not self.children), + not self.parent.children_param_injected, ) - param_offset = bool(self.route.params) + if not self.parent.children_param_injected: + self.parent.children_param_injected = True + param_offset = bool(self.group.params) return_indent = ( indent + return_bump + bool(not self.children) + param_offset ) + if route_idx == 0 and len(self.group.routes) > 1: + route_idx = "route_idx" + for i, route in enumerate(self.group.routes): + if_stmt = "if" if i == 0 else "elif" + location.extend( + [ + Line( + f"{if_stmt} method in {route.methods}:", + return_indent, + ), + Line(f"route_idx = {i}", return_indent + 1), + ] + ) + location.extend( + [ + Line("else:", return_indent), + Line("raise NoMethod", return_indent + 1), + ] + ) + if self.group.regex: + if self._has_nested_path(self, recursive=False): + location.append(Line("...", return_indent - 1)) + return_indent = 2 + location = final + self._inject_regex( + location, + return_indent, + not self.parent.children_param_injected, + ) + routes = "regex_routes" if self.group.regex else "dynamic_routes" + route_return = ( + "" if self.group.router.stacking else f"[{route_idx}]" + ) location.extend( [ - Line( - (f"basket['__raw_path__'] = '{self.route.path}'"), - return_indent, - ), Line( ( - f"return router.dynamic_routes[{self.route.parts}]" - ", basket" + f"return router.{routes}[{self.group.parts}]" + f"{route_return}, basket" ), return_indent, ), - # Line("...", return_indent - 1, render=True), ] ) - if self.route.requirements and self.last and len_check: - location.append(Line("raise NotFound", return_indent - 1)) - - if self.route.params: - location.append(Line("...", return_indent - 1, render=False)) - if self.last: - location.append( - Line("...", return_indent - 2, render=False), - ) - return src, delayed + return src, delayed, final def add_child(self, child: "Node") -> None: self._children[child.part] = child def _inject_requirements(self, location, indent): - for k, (idx, reqs) in enumerate(self.route.requirements.items()): + for k, route in enumerate(self.group): + if k == 0: + location.extend( + [ + Line(f"if num > {self.level}:", indent), + Line("raise NotFound", indent + 1), + ] + ) + conditional = "if" if k == 0 else "elif" location.extend( [ - Line((f"{conditional} extra == {reqs}:"), indent), - Line((f"basket['__handler_idx__'] = {idx}"), indent + 1), + Line( + ( + f"{conditional} extra == {route.requirements} " + f"and method in {route.methods}:" + ), + indent, + ), + Line((f"route_idx = {k}"), indent + 1), ] ) + location.extend( [ Line(("else:"), indent), @@ -196,20 +246,40 @@ def _inject_requirements(self, location, indent): ] ) - def _inject_params(self, location, indent): + def _inject_regex(self, location, indent, first_params): + location.extend( + [ + Line( + ( + "match = router.matchers" + f"[{self.group.pattern_idx}].match(path)" + ), + indent - 1, + ), + Line("if match:", indent - 1), + Line( + "basket['__params__'] = match.groupdict()", + indent, + ), + ] + ) + + def _inject_params(self, location, indent, first_params): if self.last: lines = [ Line(f"if num > {self.level}:", indent), Line("raise NotFound", indent + 1), ] else: - lines = [ - Line(f"if num == {self.level}:", indent - 1), - ] + lines = [] + if first_params: + lines.append( + Line(f"if num == {self.level}:", indent - 1), + ) lines.append(Line("try:", indent)) - for idx, param in self.route.params.items(): - unquote_start = "unquote(" if self.route.unquote else "" - unquote_end = ")" if self.route.unquote else "" + for idx, param in self.group.params.items(): + unquote_start = "unquote(" if self.group.unquote else "" + unquote_end = ")" if self.group.unquote else "" lines.append( Line( f"basket['__params__']['{param.name}'] = " @@ -228,8 +298,20 @@ def _inject_params(self, location, indent): ] ) + def _has_nested_path(self, node, recursive=True): + if node.group and ( + (node.group.labels and "path" in node.group.labels) + or (node.group.pattern and r"/" in node.group.pattern) + ): + return True + if recursive and node.children: + for child in node.children: + if self._has_nested_path(child): + return True + return False + @staticmethod - def _sorting(item) -> t.Tuple[bool, int, str, int]: + def _sorting(item) -> t.Tuple[bool, int, str, bool, int]: key, child = item type_ = 0 if child.dynamic: @@ -240,7 +322,13 @@ def _sorting(item) -> t.Tuple[bool, int, str, int]: type_ = list(REGEX_TYPES.keys()).index(param_type) except ValueError: type_ = len(list(REGEX_TYPES.keys())) - return child.dynamic, len(child._children), key, type_ * -1 + return ( + child.dynamic, + len(child._children), + key, + bool(child.group and child.group.regex), + type_ * -1, + ) class Tree: @@ -248,10 +336,10 @@ def __init__(self) -> None: self.root = Node(root=True) self.root.level = 0 - def generate(self, routes: t.Dict[t.Tuple[str, ...], Route]) -> None: - for route in routes.values(): + def generate(self, groups: t.Iterable[RouteGroup]) -> None: + for group in groups: current = self.root - for level, part in enumerate(route.parts): + for level, part in enumerate(group.parts): if part not in current._children: current.add_child(Node(part=part, parent=current)) current = current._children[part] @@ -261,7 +349,7 @@ def generate(self, routes: t.Dict[t.Tuple[str, ...], Route]) -> None: if current.dynamic and not REGEX_PARAM_NAME.match(part): raise ValueError(f"Invalid declaration: {part}") - current.route = route + current.group = group def display(self) -> None: """ @@ -270,7 +358,8 @@ def display(self) -> None: self.root.display() def render(self) -> t.List[Line]: - return self.root.render() + o, f = self.root.render() + return o + f def finalize(self): self.root.finalize_children() diff --git a/tests/test_router_source.py b/tests/test_router_source.py index 0b79803..d6ebdf3 100644 --- a/tests/test_router_source.py +++ b/tests/test_router_source.py @@ -1,4 +1,5 @@ import pytest + from sanic_routing import BaseRouter @@ -10,8 +11,8 @@ def get(self, path, method): @pytest.mark.parametrize( "cascade,lines,not_founds", ( - (True, 32, 6), - (False, 30, 4), + (True, 32, 8), + (False, 28, 4), ), ) def test_route_correct_coercion(cascade, lines, not_founds): @@ -23,5 +24,6 @@ def handler(): router.add("//two/three", handler) router.finalize() + assert router.find_route_src.count("\n") == lines assert router.find_route_src.count("raise NotFound") == not_founds diff --git a/tests/test_routing.py b/tests/test_routing.py index c8712ce..cb0b8ef 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -2,6 +2,7 @@ from datetime import date import pytest + from sanic_routing import BaseRouter from sanic_routing.exceptions import NoMethod, NotFound, RouteExists @@ -43,14 +44,18 @@ def test_add_duplicate_route_fails(): router = Router() router.add("/foo/bar", lambda *args, **kwargs: ...) + assert len(router.routes) == 1 with pytest.raises(RouteExists): router.add("/foo/bar", lambda *args, **kwargs: ...) router.add("/foo/bar", lambda *args, **kwargs: ..., overwrite=True) + assert len(router.routes) == 1 router.add("/foo/", lambda *args, **kwargs: ...) + assert len(router.routes) == 2 with pytest.raises(RouteExists): router.add("/foo/", lambda *args, **kwargs: ...) router.add("/foo/", lambda *args, **kwargs: ..., overwrite=True) + assert len(router.routes) == 2 def test_add_duplicate_route_alt_method(): @@ -63,14 +68,12 @@ def test_add_duplicate_route_alt_method(): assert len(router.static_routes) == 1 assert len(router.dynamic_routes) == 2 - static_handlers = list( - list(router.static_routes.values())[0].handlers.values() - ) - assert len(static_handlers[0]) == 2 + static_handlers = list(router.static_routes.values())[0] + assert len(static_handlers.routes) == 2 - for route in router.dynamic_routes.values(): - assert len(list(route.handlers.values())) == 1 - assert len(list(route.handlers.values())) == 1 + for group in router.dynamic_routes.values(): + assert len(group.routes) == 1 + assert len(group.routes) == 1 def test_route_does_not_exist(): @@ -140,7 +143,11 @@ def test_casting(handler, label, value, cast_type): def test_conditional_check_proper_compile(handler): router = Router() router.add("//", handler, strict=True) - router.add("//", handler, strict=True, requirements={"foo": "bar"}) + + with pytest.raises(RouteExists): + router.add( + "//", handler, strict=True, requirements={"foo": "bar"} + ) router.finalize() assert router.finalized @@ -160,7 +167,7 @@ def test_use_param_name(handler, param_name): path_part_with_param = f"<{param_name}>" router.add(f"/path/{path_part_with_param}", handler) route = list(router.routes)[0] - assert ("path", path_part_with_param) == route + assert ("path", path_part_with_param) == route.parts @pytest.mark.parametrize( @@ -177,7 +184,7 @@ def test_use_param_name_with_casing(handler, param_name): path_part_with_param = f"<{param_name}:str>" router.add(f"/path/{path_part_with_param}", handler) route = list(router.routes)[0] - assert ("path", path_part_with_param) == route + assert ("path", path_part_with_param) == route.parts def test_use_route_contains_children(handler):