From cc013a5aba3ec20a2f2e9d924f40363e2615b7ef Mon Sep 17 00:00:00 2001 From: Henrik Skov Midtiby Date: Mon, 4 Nov 2024 20:51:22 +0100 Subject: [PATCH 01/16] Handled mypy issues in utils/bezier.py --- manim/utils/bezier.py | 107 +++++++++++++++++++++++++++--------------- mypy.ini | 5 +- 2 files changed, 73 insertions(+), 39 deletions(-) diff --git a/manim/utils/bezier.py b/manim/utils/bezier.py index a424b617e1..3b7129b0f3 100644 --- a/manim/utils/bezier.py +++ b/manim/utils/bezier.py @@ -35,6 +35,8 @@ BezierPoints, BezierPoints_Array, ColVector, + InternalPoint3D, + InternalPoint3D_Array, MatrixMN, Point3D, Point3D_Array, @@ -54,10 +56,12 @@ def bezier( @overload def bezier( points: Sequence[Point3D_Array], -) -> Callable[[float | ColVector], Point3D_Array]: ... +) -> Callable[[float | ColVector], Point3D | Point3D_Array]: ... -def bezier(points): +def bezier( + points: Point3D_Array | Sequence[Point3D_Array], +) -> Callable[[float | ColVector], Point3D | Point3D_Array]: """Classic implementation of a Bézier curve. Parameters @@ -111,21 +115,21 @@ def bezier(points): if degree == 0: - def zero_bezier(t): + def zero_bezier(t: float | ColVector) -> Point3D | Point3D_Array: return np.ones_like(t) * P[0] return zero_bezier if degree == 1: - def linear_bezier(t): + def linear_bezier(t: float | ColVector) -> Point3D | Point3D_Array: return P[0] + t * (P[1] - P[0]) return linear_bezier if degree == 2: - def quadratic_bezier(t): + def quadratic_bezier(t: float | ColVector) -> Point3D | Point3D_Array: t2 = t * t mt = 1 - t mt2 = mt * mt @@ -135,7 +139,7 @@ def quadratic_bezier(t): if degree == 3: - def cubic_bezier(t): + def cubic_bezier(t: float | ColVector) -> Point3D | Point3D_Array: t2 = t * t t3 = t2 * t mt = 1 - t @@ -145,11 +149,12 @@ def cubic_bezier(t): return cubic_bezier - def nth_grade_bezier(t): + def nth_grade_bezier(t: float | ColVector) -> Point3D | Point3D_Array: is_scalar = not isinstance(t, np.ndarray) if is_scalar: B = np.empty((1, *P.shape)) else: + assert isinstance(t, np.ndarray) t = t.reshape(-1, *[1 for dim in P.shape]) B = np.empty((t.shape[0], *P.shape)) B[:] = P @@ -162,7 +167,8 @@ def nth_grade_bezier(t): # In the end, there shall be the evaluation at t of a single Bezier curve of # grade d, stored in the first slot of B if is_scalar: - return B[0, 0] + val: Point3D = B[0, 0] + return val return B[:, 0] return nth_grade_bezier @@ -700,7 +706,7 @@ def split_bezier(points: BezierPoints, t: float) -> Point3D_Array: # Memos explained in subdivide_bezier docstring -SUBDIVISION_MATRICES = [{} for i in range(4)] +SUBDIVISION_MATRICES: list[dict[int, npt.NDArray]] = [{} for i in range(4)] def _get_subdivision_matrix(n_points: int, n_divisions: int) -> MatrixMN: @@ -812,7 +818,9 @@ def _get_subdivision_matrix(n_points: int, n_divisions: int) -> MatrixMN: return subdivision_matrix -def subdivide_bezier(points: BezierPoints, n_divisions: int) -> Point3D_Array: +def subdivide_bezier( + points: InternalPoint3D_Array, n_divisions: int +) -> InternalPoint3D_Array: r"""Subdivide a Bézier curve into :math:`n` subcurves which have the same shape. The points at which the curve is split are located at the @@ -1012,14 +1020,22 @@ def interpolate(start: float, end: float, alpha: ColVector) -> ColVector: ... @overload -def interpolate(start: Point3D, end: Point3D, alpha: float) -> Point3D: ... +def interpolate( + start: InternalPoint3D, end: InternalPoint3D, alpha: float +) -> Point3D: ... @overload -def interpolate(start: Point3D, end: Point3D, alpha: ColVector) -> Point3D_Array: ... +def interpolate( + start: InternalPoint3D, end: InternalPoint3D, alpha: ColVector +) -> Point3D_Array: ... -def interpolate(start, end, alpha): +def interpolate( + start: float | InternalPoint3D, + end: float | InternalPoint3D, + alpha: float | ColVector, +) -> float | ColVector | Point3D | Point3D_Array: """Linearly interpolates between two values ``start`` and ``end``. Parameters @@ -1099,10 +1115,12 @@ def mid(start: float, end: float) -> float: ... @overload -def mid(start: Point3D, end: Point3D) -> Point3D: ... +def mid(start: InternalPoint3D, end: InternalPoint3D) -> Point3D: ... -def mid(start: float | Point3D, end: float | Point3D) -> float | Point3D: +def mid( + start: float | InternalPoint3D, end: float | InternalPoint3D +) -> float | Point3D: """Returns the midpoint between two values. Parameters @@ -1124,15 +1142,21 @@ def inverse_interpolate(start: float, end: float, value: float) -> float: ... @overload -def inverse_interpolate(start: float, end: float, value: Point3D) -> Point3D: ... +def inverse_interpolate( + start: float, end: float, value: InternalPoint3D +) -> InternalPoint3D: ... @overload -def inverse_interpolate(start: Point3D, end: Point3D, value: Point3D) -> Point3D: ... +def inverse_interpolate( + start: InternalPoint3D, end: InternalPoint3D, value: InternalPoint3D +) -> Point3D: ... def inverse_interpolate( - start: float | Point3D, end: float | Point3D, value: float | Point3D + start: float | InternalPoint3D, + end: float | InternalPoint3D, + value: float | InternalPoint3D, ) -> float | Point3D: """Perform inverse interpolation to determine the alpha values that would produce the specified ``value`` @@ -1186,7 +1210,7 @@ def match_interpolate( new_end: float, old_start: float, old_end: float, - old_value: Point3D, + old_value: InternalPoint3D, ) -> Point3D: ... @@ -1195,7 +1219,7 @@ def match_interpolate( new_end: float, old_start: float, old_end: float, - old_value: float | Point3D, + old_value: float | InternalPoint3D, ) -> float | Point3D: """Interpolate a value from an old range to a new range. @@ -1227,7 +1251,7 @@ def match_interpolate( return interpolate( new_start, new_end, - old_alpha, # type: ignore[arg-type] + old_alpha, ) @@ -1263,7 +1287,8 @@ def get_smooth_cubic_bezier_handle_points( # they can only be an interpolation of these two anchors with alphas # 1/3 and 2/3, which will draw a straight line between the anchors. if n_anchors == 2: - return interpolate(anchors[0], anchors[1], np.array([[1 / 3], [2 / 3]])) + val = interpolate(anchors[0], anchors[1], np.array([[1 / 3], [2 / 3]])) + return (val[0], val[1]) # Handle different cases depending on whether the points form a closed # curve or not @@ -1738,7 +1763,12 @@ def get_quadratic_approximation_of_cubic( ) -> QuadraticBezierPoints_Array: ... -def get_quadratic_approximation_of_cubic(a0, h0, h1, a1): +def get_quadratic_approximation_of_cubic( + a0: Point3D | Point3D_Array, + h0: Point3D | Point3D_Array, + h1: Point3D | Point3D_Array, + a1: Point3D | Point3D_Array, +) -> QuadraticBezierPoints_Array: r"""If ``a0``, ``h0``, ``h1`` and ``a1`` are the control points of a cubic Bézier curve, approximate the curve with two quadratic Bézier curves and return an array of 6 points, where the first 3 points represent the first @@ -1842,33 +1872,33 @@ def get_quadratic_approximation_of_cubic(a0, h0, h1, a1): If ``a0``, ``h0``, ``h1`` and ``a1`` have different dimensions, or if their number of dimensions is not 1 or 2. """ - a0 = np.asarray(a0) - h0 = np.asarray(h0) - h1 = np.asarray(h1) - a1 = np.asarray(a1) - - if all(arr.ndim == 1 for arr in (a0, h0, h1, a1)): - num_curves, dim = 1, a0.shape[0] - elif all(arr.ndim == 2 for arr in (a0, h0, h1, a1)): - num_curves, dim = a0.shape + a0c = np.asarray(a0) + h0c = np.asarray(h0) + h1c = np.asarray(h1) + a1c = np.asarray(a1) + + if all(arr.ndim == 1 for arr in (a0c, h0c, h1c, a1c)): + num_curves, dim = 1, a0c.shape[0] + elif all(arr.ndim == 2 for arr in (a0c, h0c, h1c, a1c)): + num_curves, dim = a0c.shape else: raise ValueError("All arguments must be Point3D or Point3D_Array.") - m0 = 0.25 * (3 * h0 + a0) - m1 = 0.25 * (3 * h1 + a1) + m0 = 0.25 * (3 * h0c + a0c) + m1 = 0.25 * (3 * h1c + a1c) k = 0.5 * (m0 + m1) result = np.empty((6 * num_curves, dim)) - result[0::6] = a0 + result[0::6] = a0c result[1::6] = m0 result[2::6] = k result[3::6] = k result[4::6] = m1 - result[5::6] = a1 + result[5::6] = a1c return result -def is_closed(points: Point3D_Array) -> bool: +def is_closed(points: InternalPoint3D_Array) -> bool: """Returns ``True`` if the spline given by ``points`` is closed, by checking if its first and last points are close to each other, or``False`` otherwise. @@ -1938,7 +1968,8 @@ def is_closed(points: Point3D_Array) -> bool: return False if abs(end[1] - start[1]) > tolerance[1]: return False - return abs(end[2] - start[2]) <= tolerance[2] + val: bool = abs(end[2] - start[2]) <= tolerance[2] + return val def proportions_along_bezier_curve_for_point( diff --git a/mypy.ini b/mypy.ini index 65a77f1d00..2bc7cc2d27 100644 --- a/mypy.ini +++ b/mypy.ini @@ -71,7 +71,7 @@ ignore_errors = True ignore_errors = True [mypy-manim.mobject.geometry.*] -ignore_errors = False +ignore_errors = True [mypy-manim.plugins.*] ignore_errors = True @@ -85,6 +85,9 @@ ignore_errors = True [mypy-manim.utils.*] ignore_errors = True +[mypy-manim.utils.bezier.*] +ignore_errors = False + [mypy-manim.utils.iterables] ignore_errors = False warn_return_any = False From 8daf40f8313cdfb55a3bf916d1d7738fd9408b2e Mon Sep 17 00:00:00 2001 From: Henrik Skov Midtiby Date: Mon, 4 Nov 2024 20:51:57 +0100 Subject: [PATCH 02/16] Disable mypy errors in manim.utils.* --- mypy.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mypy.ini b/mypy.ini index 2bc7cc2d27..69c061f5ff 100644 --- a/mypy.ini +++ b/mypy.ini @@ -83,7 +83,7 @@ ignore_errors = True ignore_errors = True [mypy-manim.utils.*] -ignore_errors = True +ignore_errors = False [mypy-manim.utils.bezier.*] ignore_errors = False From 059b71eee929fdfe3083aca2ef776c512a4e80ca Mon Sep 17 00:00:00 2001 From: Henrik Skov Midtiby Date: Mon, 4 Nov 2024 21:07:05 +0100 Subject: [PATCH 03/16] Fix mypy errors in utils/unit.py --- manim/utils/unit.py | 11 ++++++----- mypy.ini | 5 ++++- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/manim/utils/unit.py b/manim/utils/unit.py index f629e070db..c4755f0191 100644 --- a/manim/utils/unit.py +++ b/manim/utils/unit.py @@ -5,20 +5,21 @@ import numpy as np from .. import config, constants +from ..typing import Vector3D __all__ = ["Pixels", "Degrees", "Munits", "Percent"] class _PixelUnits: - def __mul__(self, val): + def __mul__(self, val: float) -> float: return val * config.frame_width / config.pixel_width - def __rmul__(self, val): + def __rmul__(self, val: float) -> float: return val * config.frame_width / config.pixel_width class Percent: - def __init__(self, axis): + def __init__(self, axis: Vector3D) -> None: if np.array_equal(axis, constants.X_AXIS): self.length = config.frame_width if np.array_equal(axis, constants.Y_AXIS): @@ -26,10 +27,10 @@ def __init__(self, axis): if np.array_equal(axis, constants.Z_AXIS): raise NotImplementedError("length of Z axis is undefined") - def __mul__(self, val): + def __mul__(self, val: float) -> float: return val / 100 * self.length - def __rmul__(self, val): + def __rmul__(self, val: float) -> float: return val / 100 * self.length diff --git a/mypy.ini b/mypy.ini index 69c061f5ff..3ced4b3430 100644 --- a/mypy.ini +++ b/mypy.ini @@ -83,11 +83,14 @@ ignore_errors = True ignore_errors = True [mypy-manim.utils.*] -ignore_errors = False +ignore_errors = True [mypy-manim.utils.bezier.*] ignore_errors = False +[mypy-manim.utils.unit.*] +ignore_errors = False + [mypy-manim.utils.iterables] ignore_errors = False warn_return_any = False From c330912b77b25d7ad0d7e173271efa4cad6c67e7 Mon Sep 17 00:00:00 2001 From: Henrik Skov Midtiby Date: Mon, 4 Nov 2024 21:07:46 +0100 Subject: [PATCH 04/16] Handle mypy errors in utils/debug.py --- manim/mobject/text/numbers.py | 5 ++++- manim/utils/debug.py | 13 ++++++++----- mypy.ini | 3 +++ 3 files changed, 15 insertions(+), 6 deletions(-) diff --git a/manim/mobject/text/numbers.py b/manim/mobject/text/numbers.py index 5283c24a20..1c74cf5f0a 100644 --- a/manim/mobject/text/numbers.py +++ b/manim/mobject/text/numbers.py @@ -5,6 +5,7 @@ __all__ = ["DecimalNumber", "Integer", "Variable"] from collections.abc import Sequence +from typing import Any import numpy as np @@ -327,7 +328,9 @@ def construct(self): self.add(Integer(number=6.28).set_x(-1.5).set_y(-2).set_color(YELLOW).scale(1.4)) """ - def __init__(self, number=0, num_decimal_places=0, **kwargs): + def __init__( + self, number: float = 0, num_decimal_places: int = 0, **kwargs: Any + ) -> None: super().__init__(number=number, num_decimal_places=num_decimal_places, **kwargs) def get_value(self): diff --git a/manim/utils/debug.py b/manim/utils/debug.py index a2c56a69d9..bca9537a7b 100644 --- a/manim/utils/debug.py +++ b/manim/utils/debug.py @@ -5,14 +5,17 @@ __all__ = ["print_family", "index_labels"] +from typing import Any + from manim.mobject.mobject import Mobject from manim.mobject.text.numbers import Integer +from manim.utils.color import ManimColor from ..mobject.types.vectorized_mobject import VGroup from .color import BLACK -def print_family(mobject, n_tabs=0): +def print_family(mobject: Mobject, n_tabs: int = 0) -> None: """For debugging purposes""" print("\t" * n_tabs, mobject, id(mobject)) for submob in mobject.submobjects: @@ -22,10 +25,10 @@ def print_family(mobject, n_tabs=0): def index_labels( mobject: Mobject, label_height: float = 0.15, - background_stroke_width=5, - background_stroke_color=BLACK, - **kwargs, -): + background_stroke_width: float = 5, + background_stroke_color: ManimColor = BLACK, + **kwargs: Any, +) -> VGroup: r"""Returns a :class:`~.VGroup` of :class:`~.Integer` mobjects that shows the index of each submobject. diff --git a/mypy.ini b/mypy.ini index 3ced4b3430..6aa8a808e0 100644 --- a/mypy.ini +++ b/mypy.ini @@ -88,6 +88,9 @@ ignore_errors = True [mypy-manim.utils.bezier.*] ignore_errors = False +[mypy-manim.utils.debug.*] +ignore_errors = False + [mypy-manim.utils.unit.*] ignore_errors = False From b9f18b79fbb885b8eae3da2b296d078c3f926324 Mon Sep 17 00:00:00 2001 From: Henrik Skov Midtiby Date: Mon, 4 Nov 2024 21:58:23 +0100 Subject: [PATCH 05/16] Fix mypy issues in utils.color.* --- manim/utils/color/core.py | 63 +++++++++++++++++++++++---------------- mypy.ini | 3 ++ 2 files changed, 40 insertions(+), 26 deletions(-) diff --git a/manim/utils/color/core.py b/manim/utils/color/core.py index 63c1605dc1..8e76cb9fcc 100644 --- a/manim/utils/color/core.py +++ b/manim/utils/color/core.py @@ -74,6 +74,7 @@ RGBA_Tuple_Float, RGBA_Tuple_Int, ) +from manim.utils.color.manim_colors import WHITE from ...utils.space_ops import normalize @@ -237,7 +238,7 @@ def _internal_value(self, value: ManimColorInternal) -> None: self.__value: ManimColorInternal = value @classmethod - def _construct_from_space(cls, _space) -> Self: + def _construct_from_space(cls, _space: npt.NDArray[ManimFloat]) -> Self: """ This function is used as a proxy for constructing a color with an internal value, this can be used by subclasses to hook into the construction of new objects using the internal value format @@ -560,7 +561,7 @@ def to_hsl(self) -> HSL_Array_Float: """ return np.array(colorsys.rgb_to_hls(*self.to_rgb())) - def invert(self, with_alpha=False) -> Self: + def invert(self, with_alpha: bool = False) -> Self: """Returns an linearly inverted version of the color (no inplace changes) Parameters @@ -795,10 +796,10 @@ def parse( Either a list of colors or a singular color depending on the input """ - def is_sequence(colors) -> TypeGuard[Sequence[ParsableManimColor]]: + def is_sequence(colors: Any) -> TypeGuard[Sequence[ParsableManimColor]]: return isinstance(colors, (list, tuple)) - def is_parsable(color) -> TypeGuard[ParsableManimColor]: + def is_parsable(color: Any) -> TypeGuard[ParsableManimColor]: return not isinstance(color, (list, tuple)) if is_sequence(color): @@ -807,9 +808,11 @@ def is_parsable(color) -> TypeGuard[ParsableManimColor]: ] elif is_parsable(color): return cls._from_internal(ManimColor(color, alpha)._internal_value) + else: + return cls._from_internal(ManimColor(WHITE, alpha)._internal_value) @staticmethod - def gradient(colors: list[ManimColor], length: int): + def gradient(colors: list[ManimColor], length: int) -> None: """This is not implemented by now refer to :func:`color_gradient` for a working implementation for now""" # TODO: implement proper gradient, research good implementation for this or look at 3b1b implementation raise NotImplementedError @@ -825,7 +828,8 @@ def __eq__(self, other: object) -> bool: raise TypeError( f"Cannot compare {self.__class__.__name__} with {other.__class__.__name__}" ) - return np.allclose(self._internal_value, other._internal_value) + value: bool = np.allclose(self._internal_value, other._internal_value) + return value def __add__(self, other: int | float | Self) -> Self: if isinstance(other, (int, float)): @@ -911,7 +915,8 @@ def __int__(self) -> int: return self.to_integer() def __getitem__(self, index: int) -> float: - return self._internal_space[index] + value: float = self._internal_space[index] + return value def __and__(self, other: Self) -> Self: return self._construct_from_space( @@ -945,7 +950,7 @@ def __init__( if len(hsv) == 3: self.__hsv: HSVA_Array_Float = np.asarray((*hsv, alpha)) elif len(hsv) == 4: - self.__hsv: HSVA_Array_Float = np.asarray(hsv) + self.__hsv = np.asarray(hsv) else: raise ValueError("HSV Color must be an array of 3 values") @@ -958,48 +963,54 @@ def _from_internal(cls, value: ManimColorInternal) -> Self: @property def hue(self) -> float: - return self.__hsv[0] - - @property - def saturation(self) -> float: - return self.__hsv[1] - - @property - def value(self) -> float: - return self.__hsv[2] + value: float = self.__hsv[0] + return value @hue.setter def hue(self, value: float) -> None: self.__hsv[0] = value + @property + def saturation(self) -> float: + value: float = self.__hsv[1] + return value + @saturation.setter def saturation(self, value: float) -> None: self.__hsv[1] = value + @property + def value(self) -> float: + value: float = self.__hsv[2] + return value + @value.setter def value(self, value: float) -> None: self.__hsv[2] = value @property def h(self) -> float: - return self.__hsv[0] - - @property - def s(self) -> float: - return self.__hsv[1] - - @property - def v(self) -> float: - return self.__hsv[2] + value: float = self.__hsv[0] + return value @h.setter def h(self, value: float) -> None: self.__hsv[0] = value + @property + def s(self) -> float: + value: float = self.__hsv[1] + return value + @s.setter def s(self, value: float) -> None: self.__hsv[1] = value + @property + def v(self) -> float: + value: float = self.__hsv[2] + return value + @v.setter def v(self, value: float) -> None: self.__hsv[2] = value diff --git a/mypy.ini b/mypy.ini index 6aa8a808e0..4550abeaca 100644 --- a/mypy.ini +++ b/mypy.ini @@ -88,6 +88,9 @@ ignore_errors = True [mypy-manim.utils.bezier.*] ignore_errors = False +[mypy-manim.utils.color.*] +ignore_errors = False + [mypy-manim.utils.debug.*] ignore_errors = False From 0aa4318ec1d86fb874079e8b8cb20ab34af0baa2 Mon Sep 17 00:00:00 2001 From: Henrik Skov Midtiby Date: Mon, 4 Nov 2024 22:09:52 +0100 Subject: [PATCH 06/16] Avoid circular import. --- manim/utils/color/core.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/manim/utils/color/core.py b/manim/utils/color/core.py index 8e76cb9fcc..c4ebbe0bf9 100644 --- a/manim/utils/color/core.py +++ b/manim/utils/color/core.py @@ -74,7 +74,6 @@ RGBA_Tuple_Float, RGBA_Tuple_Int, ) -from manim.utils.color.manim_colors import WHITE from ...utils.space_ops import normalize @@ -809,7 +808,7 @@ def is_parsable(color: Any) -> TypeGuard[ParsableManimColor]: elif is_parsable(color): return cls._from_internal(ManimColor(color, alpha)._internal_value) else: - return cls._from_internal(ManimColor(WHITE, alpha)._internal_value) + return cls._from_internal(ManimColor("WHITE", alpha)._internal_value) @staticmethod def gradient(colors: list[ManimColor], length: int) -> None: From f09d0831b0f24844f7cb67e7784582143a1a2303 Mon Sep 17 00:00:00 2001 From: Henrik Skov Midtiby Date: Tue, 5 Nov 2024 09:54:12 +0100 Subject: [PATCH 07/16] Handle mypy errors in utils.simple_functions.* --- manim/utils/simple_functions.py | 20 +++++++++++++++----- mypy.ini | 3 +++ 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/manim/utils/simple_functions.py b/manim/utils/simple_functions.py index 898c1d527b..152a6c276e 100644 --- a/manim/utils/simple_functions.py +++ b/manim/utils/simple_functions.py @@ -11,7 +11,7 @@ from functools import lru_cache -from typing import Callable +from typing import Callable, overload import numpy as np from scipy import special @@ -54,7 +54,7 @@ def binary_search( """ lh = lower_bound rh = upper_bound - mh = np.mean(np.array([lh, rh])) + mh: float = np.mean(np.array([lh, rh])) while abs(rh - lh) > tolerance: mh = np.mean(np.array([lh, rh])) lx, mx, rx = (function(h) for h in (lh, mh, rh)) @@ -88,10 +88,19 @@ def choose(n: int, k: int) -> int: - https://en.wikipedia.org/wiki/Combination - https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.comb.html """ - return special.comb(n, k, exact=True) + value: int = special.comb(n, k, exact=True) + return value -def clip(a, min_a, max_a): +@overload +def clip(a: float, min_a: float, max_a: float) -> float: ... + + +@overload +def clip(a: str, min_a: str, max_a: str) -> str: ... + + +def clip(a, min_a, max_a): # type: ignore[no-untyped-def] """Clips ``a`` to the interval [``min_a``, ``max_a``]. Accepts any comparable objects (i.e. those that support <, >). @@ -125,4 +134,5 @@ def sigmoid(x: float) -> float: - https://en.wikipedia.org/wiki/Sigmoid_function - https://en.wikipedia.org/wiki/Logistic_function """ - return 1.0 / (1 + np.exp(-x)) + value: float = 1.0 / (1 + np.exp(-x)) + return value diff --git a/mypy.ini b/mypy.ini index 4550abeaca..90e06a65d9 100644 --- a/mypy.ini +++ b/mypy.ini @@ -94,6 +94,9 @@ ignore_errors = False [mypy-manim.utils.debug.*] ignore_errors = False +[mypy-manim.utils.simple_functions.*] +ignore_errors = False + [mypy-manim.utils.unit.*] ignore_errors = False From 9e0edf065424b8d54c56cc0562531f133ab4fe79 Mon Sep 17 00:00:00 2001 From: Henrik Skov Midtiby Date: Tue, 5 Nov 2024 11:05:28 +0100 Subject: [PATCH 08/16] Handle my errors in utils.testing.* --- manim/renderer/cairo_renderer.py | 3 +- manim/scene/scene_file_writer.py | 19 ++++++--- manim/utils/testing/_frames_testers.py | 14 ++++--- manim/utils/testing/_show_diff.py | 2 +- manim/utils/testing/_test_class_makers.py | 49 +++++++++++++++-------- manim/utils/testing/frames_comparison.py | 33 +++++++++------ mypy.ini | 3 ++ 7 files changed, 81 insertions(+), 42 deletions(-) diff --git a/manim/renderer/cairo_renderer.py b/manim/renderer/cairo_renderer.py index b97fa50299..45076f7d13 100644 --- a/manim/renderer/cairo_renderer.py +++ b/manim/renderer/cairo_renderer.py @@ -3,6 +3,7 @@ import typing import numpy as np +import numpy.typing as npt from manim.utils.hashing import get_hash_from_play_call @@ -160,7 +161,7 @@ def render(self, scene, time, moving_mobjects): self.update_frame(scene, moving_mobjects) self.add_frame(self.get_frame()) - def get_frame(self): + def get_frame(self) -> npt.NDArray: """ Gets the current frame as NumPy array. diff --git a/manim/scene/scene_file_writer.py b/manim/scene/scene_file_writer.py index 4293de7105..69651d0a8a 100644 --- a/manim/scene/scene_file_writer.py +++ b/manim/scene/scene_file_writer.py @@ -20,7 +20,9 @@ from pydub import AudioSegment from manim import __version__ -from manim.typing import PixelArray +from manim.renderer.cairo_renderer import CairoRenderer +from manim.renderer.opengl_renderer import OpenGLRenderer +from manim.typing import PixelArray, StrPath from .. import config, logger from .._config.logger_utils import set_file_logger @@ -104,7 +106,12 @@ class SceneFileWriter: force_output_as_scene_name = False - def __init__(self, renderer, scene_name, **kwargs): + def __init__( + self, + renderer: CairoRenderer | OpenGLRenderer, + scene_name: StrPath, + **kwargs: Any, + ) -> None: self.renderer = renderer self.init_output_directories(scene_name) self.init_audio() @@ -118,7 +125,7 @@ def __init__(self, renderer, scene_name, **kwargs): name="autocreated", type_=DefaultSectionType.NORMAL, skip_animations=False ) - def init_output_directories(self, scene_name): + def init_output_directories(self, scene_name: StrPath): """Initialise output directories. Notes @@ -378,7 +385,9 @@ def add_sound( self.add_audio_segment(new_segment, time, **kwargs) # Writers - def begin_animation(self, allow_write: bool = False, file_path=None): + def begin_animation( + self, allow_write: bool = False, file_path: StrPath | None = None + ) -> None: """ Used internally by manim to stream the animation to FFMPEG for displaying or writing to a file. @@ -391,7 +400,7 @@ def begin_animation(self, allow_write: bool = False, file_path=None): if write_to_movie() and allow_write: self.open_partial_movie_stream(file_path=file_path) - def end_animation(self, allow_write: bool = False): + def end_animation(self, allow_write: bool = False) -> None: """ Internally used by Manim to stop streaming to FFMPEG gracefully. diff --git a/manim/utils/testing/_frames_testers.py b/manim/utils/testing/_frames_testers.py index bef6184937..811bcff762 100644 --- a/manim/utils/testing/_frames_testers.py +++ b/manim/utils/testing/_frames_testers.py @@ -3,7 +3,9 @@ import contextlib import logging import warnings +from collections.abc import Generator from pathlib import Path +from typing import Any import numpy as np @@ -16,7 +18,7 @@ class _FramesTester: - def __init__(self, file_path: Path, show_diff=False) -> None: + def __init__(self, file_path: Path, show_diff: bool = False) -> None: self._file_path = file_path self._show_diff = show_diff self._frames: np.ndarray @@ -24,7 +26,7 @@ def __init__(self, file_path: Path, show_diff=False) -> None: self._frames_compared = 0 @contextlib.contextmanager - def testing(self): + def testing(self) -> Generator[Any, Any, Any]: with np.load(self._file_path) as data: self._frames = data["frame_data"] # For backward compatibility, when the control data contains only one frame (<= v0.8.0) @@ -38,7 +40,7 @@ def testing(self): f"when there are {self._number_frames} control frames for this test." ) - def check_frame(self, frame_number: int, frame: np.ndarray): + def check_frame(self, frame_number: int, frame: np.ndarray) -> None: assert frame_number < self._number_frames, ( f"The tested scene is at frame number {frame_number} " f"when there are {self._number_frames} control frames." @@ -86,17 +88,17 @@ def __init__(self, file_path: Path, size_frame: tuple) -> None: self._number_frames_written: int = 0 # Actually write a frame. - def check_frame(self, index: int, frame: np.ndarray): + def check_frame(self, index: int, frame: np.ndarray) -> None: frame = frame[np.newaxis, ...] self.frames = np.concatenate((self.frames, frame)) self._number_frames_written += 1 @contextlib.contextmanager - def testing(self): + def testing(self) -> Generator[Any, Any, Any]: yield self.save_contol_data() - def save_contol_data(self): + def save_contol_data(self) -> None: self.frames = self.frames.astype("uint8") np.savez_compressed(self.file_path, frame_data=self.frames) logger.info( diff --git a/manim/utils/testing/_show_diff.py b/manim/utils/testing/_show_diff.py index 0cb2aab0f5..7c8abf8698 100644 --- a/manim/utils/testing/_show_diff.py +++ b/manim/utils/testing/_show_diff.py @@ -11,7 +11,7 @@ def show_diff_helper( frame_data: np.ndarray, expected_frame_data: np.ndarray, control_data_filename: str, -): +) -> None: """Will visually display with matplotlib differences between frame generated and the one expected.""" import matplotlib.gridspec as gridspec import matplotlib.pyplot as plt diff --git a/manim/utils/testing/_test_class_makers.py b/manim/utils/testing/_test_class_makers.py index fe127be1c7..17bb4bd379 100644 --- a/manim/utils/testing/_test_class_makers.py +++ b/manim/utils/testing/_test_class_makers.py @@ -1,9 +1,14 @@ from __future__ import annotations -from typing import Callable +from typing import Any, Callable +import numpy.typing as npt + +from manim.renderer.cairo_renderer import CairoRenderer +from manim.renderer.opengl_renderer import OpenGLRenderer from manim.scene.scene import Scene from manim.scene.scene_file_writer import SceneFileWriter +from manim.typing import StrPath from ._frames_testers import _FramesTester @@ -11,13 +16,14 @@ def _make_test_scene_class( base_scene: type[Scene], construct_test: Callable[[Scene], None], - test_renderer, + test_renderer: CairoRenderer | OpenGLRenderer | None, ) -> type[Scene]: - class _TestedScene(base_scene): - def __init__(self, *args, **kwargs): + # TODO: Get the type annotation right for the base_scene argument. + class _TestedScene(base_scene): # type: ignore[valid-type, misc] + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, renderer=test_renderer, **kwargs) - def construct(self): + def construct(self) -> None: construct_test(self) # Manim hack to render the very last frame (normally the last frame is not the very end of the animation) @@ -28,7 +34,7 @@ def construct(self): return _TestedScene -def _make_test_renderer_class(from_renderer): +def _make_test_renderer_class(from_renderer: type) -> Any: # Just for inheritance. class _TestRenderer(from_renderer): pass @@ -39,38 +45,49 @@ class _TestRenderer(from_renderer): class DummySceneFileWriter(SceneFileWriter): """Delegate of SceneFileWriter used to test the frames.""" - def __init__(self, renderer, scene_name, **kwargs): + def __init__( + self, + renderer: CairoRenderer | OpenGLRenderer, + scene_name: StrPath, + **kwargs: Any, + ) -> None: super().__init__(renderer, scene_name, **kwargs) self.i = 0 - def init_output_directories(self, scene_name): + def init_output_directories(self, scene_name: StrPath) -> None: pass - def add_partial_movie_file(self, hash_animation): + def add_partial_movie_file(self, hash_animation: str) -> None: pass - def begin_animation(self, allow_write=True): + def begin_animation( + self, allow_write: bool = True, file_path: StrPath | None = None + ) -> Any: pass - def end_animation(self, allow_write): + def end_animation(self, allow_write: bool = False) -> None: pass - def combine_to_movie(self): + def combine_to_movie(self) -> None: pass - def combine_to_section_videos(self): + def combine_to_section_videos(self) -> None: pass - def clean_cache(self): + def clean_cache(self) -> None: pass - def write_frame(self, frame_or_renderer, num_frames=1): + def write_frame( + self, frame_or_renderer: npt.NDArray | OpenGLRenderer, num_frames: int = 1 + ) -> None: self.i += 1 def _make_scene_file_writer_class(tester: _FramesTester) -> type[SceneFileWriter]: class TestSceneFileWriter(DummySceneFileWriter): - def write_frame(self, frame_or_renderer, num_frames=1): + def write_frame( + self, frame_or_renderer: npt.NDArray | OpenGLRenderer, num_frames: int = 1 + ) -> None: tester.check_frame(self.i, frame_or_renderer) super().write_frame(frame_or_renderer, num_frames=num_frames) diff --git a/manim/utils/testing/frames_comparison.py b/manim/utils/testing/frames_comparison.py index a298585b29..1197481281 100644 --- a/manim/utils/testing/frames_comparison.py +++ b/manim/utils/testing/frames_comparison.py @@ -3,7 +3,7 @@ import functools import inspect from pathlib import Path -from typing import Callable +from typing import Any, Callable import cairo import pytest @@ -14,7 +14,9 @@ from manim._config.utils import ManimConfig from manim.camera.three_d_camera import ThreeDCamera from manim.renderer.cairo_renderer import CairoRenderer +from manim.renderer.opengl_renderer import OpenGLRenderer from manim.scene.three_d_scene import ThreeDScene +from manim.typing import StrPath from ._frames_testers import _ControlDataWriter, _FramesTester from ._test_class_makers import ( @@ -31,13 +33,13 @@ def frames_comparison( - func=None, + func: Callable | None = None, *, last_frame: bool = True, - renderer_class=CairoRenderer, - base_scene=Scene, - **custom_config, -): + renderer_class: type[CairoRenderer | OpenGLRenderer] = CairoRenderer, + base_scene: type[Scene] = Scene, + **custom_config: Any, +) -> Any: """Compares the frames generated by the test with control frames previously registered. If there is no control frames for this test, the test will fail. To generate @@ -60,7 +62,7 @@ def frames_comparison( If the scene has a moving animation, then the test must set last_frame to False. """ - def decorator_maker(tested_scene_construct): + def decorator_maker(tested_scene_construct: Callable) -> Callable: if ( SCENE_PARAMETER_NAME not in inspect.getfullargspec(tested_scene_construct).args @@ -79,11 +81,14 @@ def decorator_maker(tested_scene_construct): "There is no module test name indicated for the graphical unit test. You have to declare __module_test__ in the test file.", ) module_name = tested_scene_construct.__globals__.get("__module_test__") + assert isinstance(module_name, str) test_name = tested_scene_construct.__name__[len("test_") :] @functools.wraps(tested_scene_construct) # The "request" parameter is meant to be used as a fixture by pytest. See below. - def wrapper(*args, request: FixtureRequest, tmp_path, **kwargs): + def wrapper( + *args: Any, request: FixtureRequest, tmp_path: StrPath, **kwargs: Any + ) -> None: # check for cairo version if ( renderer_class is CairoRenderer @@ -146,13 +151,13 @@ def wrapper(*args, request: FixtureRequest, tmp_path, **kwargs): inspect.Parameter("tmp_path", inspect.Parameter.KEYWORD_ONLY), ] new_sig = old_sig.replace(parameters=parameters) - wrapper.__signature__ = new_sig + wrapper.__signature__ = new_sig # type: ignore[attr-defined] # Reach a bit into pytest internals to hoist the marks from our wrapped # function. - wrapper.pytestmark = [] + wrapper.pytestmark = [] # type: ignore[attr-defined] new_marks = getattr(tested_scene_construct, "pytestmark", []) - wrapper.pytestmark = new_marks + wrapper.pytestmark = new_marks # type: ignore[attr-defined] return wrapper # Case where the decorator is called with and without parentheses. @@ -193,7 +198,9 @@ def _make_test_comparing_frames( The pytest test. """ if is_set_test_data_test: - frames_tester = _ControlDataWriter(file_path, size_frame=size_frame) + frames_tester: _FramesTester = _ControlDataWriter( + file_path, size_frame=size_frame + ) else: frames_tester = _FramesTester(file_path, show_diff=show_diff) @@ -204,7 +211,7 @@ def _make_test_comparing_frames( ) testRenderer = _make_test_renderer_class(renderer_class) - def real_test(): + def real_test() -> None: with frames_tester.testing(): sceneTested = _make_test_scene_class( base_scene=base_scene, diff --git a/mypy.ini b/mypy.ini index 90e06a65d9..9facfb8c8d 100644 --- a/mypy.ini +++ b/mypy.ini @@ -97,6 +97,9 @@ ignore_errors = False [mypy-manim.utils.simple_functions.*] ignore_errors = False +[mypy-manim.utils.testing.*] +ignore_errors = False + [mypy-manim.utils.unit.*] ignore_errors = False From f5a5411e25e142046ae08747ae33a55095540e60 Mon Sep 17 00:00:00 2001 From: Henrik Skov Midtiby Date: Tue, 5 Nov 2024 11:13:41 +0100 Subject: [PATCH 09/16] Avoid circular import. --- manim/scene/scene_file_writer.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/manim/scene/scene_file_writer.py b/manim/scene/scene_file_writer.py index 69651d0a8a..903b80cfd6 100644 --- a/manim/scene/scene_file_writer.py +++ b/manim/scene/scene_file_writer.py @@ -20,8 +20,6 @@ from pydub import AudioSegment from manim import __version__ -from manim.renderer.cairo_renderer import CairoRenderer -from manim.renderer.opengl_renderer import OpenGLRenderer from manim.typing import PixelArray, StrPath from .. import config, logger @@ -108,7 +106,7 @@ class SceneFileWriter: def __init__( self, - renderer: CairoRenderer | OpenGLRenderer, + renderer: Any, scene_name: StrPath, **kwargs: Any, ) -> None: From 8e4cf8092c81300d6c4ce0f9691d0569cd2655ff Mon Sep 17 00:00:00 2001 From: Henrik Skov Midtiby Date: Wed, 6 Nov 2024 12:44:48 +0100 Subject: [PATCH 10/16] Handle mypy errors in utils/family_ops.py --- manim/utils/family_ops.py | 16 ++++++++++++---- mypy.ini | 3 +++ 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/manim/utils/family_ops.py b/manim/utils/family_ops.py index 8d4af9d5a5..e7cf011161 100644 --- a/manim/utils/family_ops.py +++ b/manim/utils/family_ops.py @@ -2,20 +2,26 @@ import itertools as it +from manim.mobject.mobject import Mobject + __all__ = [ "extract_mobject_family_members", "restructure_list_to_exclude_certain_family_members", ] -def extract_mobject_family_members(mobject_list, only_those_with_points=False): +def extract_mobject_family_members( + mobject_list: list[Mobject], only_those_with_points: bool = False +) -> list[Mobject]: result = list(it.chain(*(mob.get_family() for mob in mobject_list))) if only_those_with_points: result = [mob for mob in result if mob.has_points()] return result -def restructure_list_to_exclude_certain_family_members(mobject_list, to_remove): +def restructure_list_to_exclude_certain_family_members( + mobject_list: list[Mobject], to_remove: list[Mobject] +) -> list[Mobject]: """ Removes anything in to_remove from mobject_list, but in the event that one of the items to be removed is a member of the family of an item in mobject_list, @@ -25,10 +31,12 @@ def restructure_list_to_exclude_certain_family_members(mobject_list, to_remove): but one of its submobjects is removed, e.g. scene.remove(m1), it's useful for the list of mobject_list to be edited to contain other submobjects, but not m1. """ - new_list = [] + new_list: list[Mobject] = [] to_remove = extract_mobject_family_members(to_remove) - def add_safe_mobjects_from_list(list_to_examine, set_to_remove): + def add_safe_mobjects_from_list( + list_to_examine: list[Mobject], set_to_remove: set[Mobject] + ) -> None: for mob in list_to_examine: if mob in set_to_remove: continue diff --git a/mypy.ini b/mypy.ini index 9facfb8c8d..ddeaac4cb8 100644 --- a/mypy.ini +++ b/mypy.ini @@ -94,6 +94,9 @@ ignore_errors = False [mypy-manim.utils.debug.*] ignore_errors = False +[mypy-manim.utils.family_ops.*] +ignore_errors = False + [mypy-manim.utils.simple_functions.*] ignore_errors = False From 5a6a8bf7db72ed0e94f35a7873cc556d5624cf3a Mon Sep 17 00:00:00 2001 From: Henrik Skov Midtiby Date: Wed, 6 Nov 2024 12:46:38 +0100 Subject: [PATCH 11/16] Handle mypy errors in utils/parameter_parsing.py --- manim/utils/parameter_parsing.py | 2 +- mypy.ini | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/manim/utils/parameter_parsing.py b/manim/utils/parameter_parsing.py index 7966906bb1..d3676c7301 100644 --- a/manim/utils/parameter_parsing.py +++ b/manim/utils/parameter_parsing.py @@ -23,7 +23,7 @@ def flatten_iterable_parameters( :class:`list` The flattened list of parameters. """ - flattened_parameters = [] + flattened_parameters: list[T] = [] for arg in args: if isinstance(arg, (Iterable, GeneratorType)): flattened_parameters.extend(arg) diff --git a/mypy.ini b/mypy.ini index ddeaac4cb8..0e9e56b534 100644 --- a/mypy.ini +++ b/mypy.ini @@ -97,6 +97,9 @@ ignore_errors = False [mypy-manim.utils.family_ops.*] ignore_errors = False +[mypy-manim.utils.parameter_parsing.*] +ignore_errors = False + [mypy-manim.utils.simple_functions.*] ignore_errors = False From c2ef7be353d3ccb5dc7eaabe01cb705991120a66 Mon Sep 17 00:00:00 2001 From: Henrik Skov Midtiby Date: Wed, 6 Nov 2024 13:20:01 +0100 Subject: [PATCH 12/16] Handle some of the mypy errors in utils.docbuild.* --- manim/_config/utils.py | 2 +- manim/utils/docbuild/autocolor_directive.py | 2 ++ manim/utils/docbuild/manim_directive.py | 12 +++++++----- manim/utils/docbuild/module_parsing.py | 14 ++++++++------ 4 files changed, 18 insertions(+), 12 deletions(-) diff --git a/manim/_config/utils.py b/manim/_config/utils.py index b453b290e2..06bf7616a7 100644 --- a/manim/_config/utils.py +++ b/manim/_config/utils.py @@ -1395,7 +1395,7 @@ def renderer(self, value: str | RendererType) -> None: self._set_from_enum("renderer", renderer, RendererType) @property - def media_dir(self) -> str: + def media_dir(self) -> str | Path: """Main output directory. See :meth:`ManimConfig.get_dir`.""" return self._d["media_dir"] diff --git a/manim/utils/docbuild/autocolor_directive.py b/manim/utils/docbuild/autocolor_directive.py index 3a699f54d4..e88af71457 100644 --- a/manim/utils/docbuild/autocolor_directive.py +++ b/manim/utils/docbuild/autocolor_directive.py @@ -26,6 +26,8 @@ class ManimColorModuleDocumenter(Directive): has_content = True def add_directive_header(self, sig: str) -> None: + # TODO: The Directive class has no method named + # add_directive_header. super().add_directive_header(sig) def run(self) -> list[nodes.Element]: diff --git a/manim/utils/docbuild/manim_directive.py b/manim/utils/docbuild/manim_directive.py index 4576be81dd..c679993d9f 100644 --- a/manim/utils/docbuild/manim_directive.py +++ b/manim/utils/docbuild/manim_directive.py @@ -118,12 +118,14 @@ class SkipManimNode(nodes.Admonition, nodes.Element): def visit(self: SkipManimNode, node: nodes.Element, name: str = "") -> None: + # TODO: The SkipManimNode class have no method named visit_admonition. self.visit_admonition(node, name) if not isinstance(node[0], nodes.title): node.insert(0, nodes.title("skip-manim", "Example Placeholder")) def depart(self: SkipManimNode, node: nodes.Element) -> None: + # TODO: The SkipManimNode class have no method named depart_admonition. self.depart_admonition(node) @@ -245,7 +247,7 @@ def run(self) -> list[nodes.Element]: if not dest_dir.exists(): dest_dir.mkdir(parents=True, exist_ok=True) - source_block = [ + source_block_in = [ ".. code-block:: python", "", " from manim import *\n", @@ -258,7 +260,7 @@ def run(self) -> list[nodes.Element]: "", " ", ] - source_block = "\n".join(source_block) + source_block = "\n".join(source_block_in) config.media_dir = (Path(setup.confdir) / "media").absolute() config.images_dir = "{media_dir}/images" @@ -345,7 +347,7 @@ def run(self) -> list[nodes.Element]: rendering_times_file_path = Path("../rendering_times.csv") -def _write_rendering_stats(scene_name: str, run_time: str, file_name: str) -> None: +def _write_rendering_stats(scene_name: str, run_time: float, file_name: str) -> None: with rendering_times_file_path.open("a") as file: csv.writer(file).writerow( [ @@ -369,9 +371,9 @@ def _log_rendering_times(*args: tuple[Any]) -> None: data = [row for row in data if row] max_file_length = max(len(row[0]) for row in data) - for key, group in it.groupby(data, key=lambda row: row[0]): + for key, group_iter in it.groupby(data, key=lambda row: row[0]): key = key.ljust(max_file_length + 1, ".") - group = list(group) + group = list(group_iter) if len(group) == 1: row = group[0] print(f"{key}{row[2].rjust(7, '.')}s {row[1]}") diff --git a/manim/utils/docbuild/module_parsing.py b/manim/utils/docbuild/module_parsing.py index 57ac9a56aa..98acadbdcc 100644 --- a/manim/utils/docbuild/module_parsing.py +++ b/manim/utils/docbuild/module_parsing.py @@ -3,7 +3,9 @@ from __future__ import annotations import ast +from ast import Attribute, Name, Subscript from pathlib import Path +from typing import Any from typing_extensions import TypeAlias @@ -86,10 +88,10 @@ def parse_module_attributes() -> tuple[AliasDocsDict, DataDict, TypeVarDict]: return ALIAS_DOCS_DICT, DATA_DICT, TYPEVAR_DICT for module_path in MANIM_ROOT.rglob("*.py"): - module_name = module_path.resolve().relative_to(MANIM_ROOT) - module_name = list(module_name.parts) - module_name[-1] = module_name[-1].removesuffix(".py") - module_name = ".".join(module_name) + module_name_t1 = module_path.resolve().relative_to(MANIM_ROOT) + module_name_t2 = list(module_name_t1.parts) + module_name_t2[-1] = module_name_t2[-1].removesuffix(".py") + module_name = ".".join(module_name_t2) module_content = module_path.read_text(encoding="utf-8") @@ -153,7 +155,7 @@ def parse_module_attributes() -> tuple[AliasDocsDict, DataDict, TypeVarDict]: ) ) ): - inner_nodes = node.body + inner_nodes: list[Any] = node.body else: inner_nodes = [node] @@ -208,7 +210,7 @@ def parse_module_attributes() -> tuple[AliasDocsDict, DataDict, TypeVarDict]: # Does the assignment have a target of type Name? Then # it could be considered a definition of a module attribute. if type(node) is ast.AnnAssign: - target = node.target + target: Name | Attribute | Subscript | ast.expr | None = node.target elif type(node) is ast.Assign and len(node.targets) == 1: target = node.targets[0] else: From 30e5f36fb365e80af42b86f05416ca45277d09a2 Mon Sep 17 00:00:00 2001 From: Henrik Skov Midtiby Date: Wed, 6 Nov 2024 13:26:33 +0100 Subject: [PATCH 13/16] Handle mypy errors for utils/config_ops.py --- manim/utils/config_ops.py | 21 +++++++++++---------- mypy.ini | 3 +++ 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/manim/utils/config_ops.py b/manim/utils/config_ops.py index 6e1f09990e..e698dcf979 100644 --- a/manim/utils/config_ops.py +++ b/manim/utils/config_ops.py @@ -10,11 +10,12 @@ import itertools as it +from typing import Any import numpy as np -def merge_dicts_recursively(*dicts): +def merge_dicts_recursively(*dicts: dict) -> dict: """ Creates a dict whose keyset is the union of all the input dictionaries. The value for each key is based @@ -24,7 +25,7 @@ def merge_dicts_recursively(*dicts): When values are dictionaries, it is applied recursively """ - result = {} + result: dict = {} all_items = it.chain(*(d.items() for d in dicts)) for key, value in all_items: if key in result and isinstance(result[key], dict) and isinstance(value, dict): @@ -34,7 +35,7 @@ def merge_dicts_recursively(*dicts): return result -def update_dict_recursively(current_dict, *others): +def update_dict_recursively(current_dict: dict, *others: dict) -> None: updated_dict = merge_dicts_recursively(current_dict, *others) current_dict.update(updated_dict) @@ -44,7 +45,7 @@ def update_dict_recursively(current_dict, *others): class DictAsObject: - def __init__(self, dictin): + def __init__(self, dictin: dict): self.__dict__ = dictin @@ -53,13 +54,13 @@ class _Data: self.data attributes must be arrays. """ - def __set_name__(self, obj, name): + def __set_name__(self, obj: Any, name: str) -> None: self.name = name - def __get__(self, obj, owner): + def __get__(self, obj: Any, owner: Any) -> Any: return obj.data[self.name] - def __set__(self, obj, array: np.ndarray): + def __set__(self, obj: Any, array: np.ndarray) -> None: obj.data[self.name] = array @@ -68,11 +69,11 @@ class _Uniforms: self.uniforms attributes must be floats. """ - def __set_name__(self, obj, name): + def __set_name__(self, obj: Any, name: str) -> None: self.name = name - def __get__(self, obj, owner): + def __get__(self, obj: Any, owner: Any) -> Any: return obj.__dict__["uniforms"][self.name] - def __set__(self, obj, num: float): + def __set__(self, obj: Any, num: float) -> None: obj.__dict__["uniforms"][self.name] = num diff --git a/mypy.ini b/mypy.ini index 0e9e56b534..c2db342f46 100644 --- a/mypy.ini +++ b/mypy.ini @@ -88,6 +88,9 @@ ignore_errors = True [mypy-manim.utils.bezier.*] ignore_errors = False +[mypy-manim.utils.config_ops.*] +ignore_errors = False + [mypy-manim.utils.color.*] ignore_errors = False From 133e10b2d416c25c5a974e929ac25bf40e104a07 Mon Sep 17 00:00:00 2001 From: Henrik Skov Midtiby Date: Wed, 6 Nov 2024 13:30:03 +0100 Subject: [PATCH 14/16] Handle mypy errors from utils/commands.py --- manim/utils/commands.py | 9 +++++++-- mypy.ini | 6 ++++++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/manim/utils/commands.py b/manim/utils/commands.py index 8a15889510..b518106f81 100644 --- a/manim/utils/commands.py +++ b/manim/utils/commands.py @@ -4,9 +4,12 @@ from collections.abc import Generator from pathlib import Path from subprocess import run +from typing import Any import av +from manim.typing import StrOrBytesPath + __all__ = [ "capture", "get_video_metadata", @@ -14,7 +17,9 @@ ] -def capture(command, cwd=None, command_input=None): +def capture( + command: str, cwd: StrOrBytesPath | None = None, command_input: str | None = None +) -> tuple[str, str, int]: p = run( command, cwd=cwd, @@ -27,7 +32,7 @@ def capture(command, cwd=None, command_input=None): return out, err, p.returncode -def get_video_metadata(path_to_video: str | os.PathLike) -> dict[str]: +def get_video_metadata(path_to_video: str | os.PathLike) -> dict[str, Any]: with av.open(str(path_to_video)) as container: stream = container.streams.video[0] ctxt = stream.codec_context diff --git a/mypy.ini b/mypy.ini index c2db342f46..a88ab364f9 100644 --- a/mypy.ini +++ b/mypy.ini @@ -88,6 +88,9 @@ ignore_errors = True [mypy-manim.utils.bezier.*] ignore_errors = False +[mypy-manim.utils.commands.*] +ignore_errors = False + [mypy-manim.utils.config_ops.*] ignore_errors = False @@ -97,6 +100,9 @@ ignore_errors = False [mypy-manim.utils.debug.*] ignore_errors = False +[mypy-manim.utils.docbuild.*] +ignore_errors = True + [mypy-manim.utils.family_ops.*] ignore_errors = False From 5ca896f5b6187ca6753a3fb407ad282cf5ff573d Mon Sep 17 00:00:00 2001 From: Henrik Skov Midtiby Date: Wed, 6 Nov 2024 13:33:24 +0100 Subject: [PATCH 15/16] Handle mypy errors in utils/tex_templates.py --- manim/utils/tex.py | 3 +++ manim/utils/tex_templates.py | 2 +- mypy.ini | 3 +++ 3 files changed, 7 insertions(+), 1 deletion(-) diff --git a/manim/utils/tex.py b/manim/utils/tex.py index 41e19bcd57..39ef55d8f1 100644 --- a/manim/utils/tex.py +++ b/manim/utils/tex.py @@ -36,6 +36,9 @@ class TexTemplate: tex_compiler: str = "latex" """The TeX compiler to be used, e.g. ``latex``, ``pdflatex`` or ``lualatex``.""" + description: str = "" + """A description of the template""" + output_format: str = ".dvi" """The output format resulting from compilation, e.g. ``.dvi`` or ``.pdf``.""" diff --git a/manim/utils/tex_templates.py b/manim/utils/tex_templates.py index ed9db91c3f..7273b67d0f 100644 --- a/manim/utils/tex_templates.py +++ b/manim/utils/tex_templates.py @@ -12,7 +12,7 @@ # This file makes TexTemplateLibrary and TexFontTemplates available for use in manim Tex and MathTex objects. -def _new_ams_template(): +def _new_ams_template() -> TexTemplate: """Returns a simple Tex Template with only basic AMS packages""" preamble = r""" \usepackage[english]{babel} diff --git a/mypy.ini b/mypy.ini index a88ab364f9..cc298f009f 100644 --- a/mypy.ini +++ b/mypy.ini @@ -115,6 +115,9 @@ ignore_errors = False [mypy-manim.utils.testing.*] ignore_errors = False +[mypy-manim.utils.tex_templates.*] +ignore_errors = False + [mypy-manim.utils.unit.*] ignore_errors = False From 47a0661c20ea7362783fa1c0a34ab9286b149194 Mon Sep 17 00:00:00 2001 From: Henrik Skov Midtiby Date: Wed, 6 Nov 2024 13:45:15 +0100 Subject: [PATCH 16/16] Handle mypy errors in utils/space_ops.py --- manim/utils/space_ops.py | 37 +++++++++++++++++++++++-------------- mypy.ini | 3 +++ 2 files changed, 26 insertions(+), 14 deletions(-) diff --git a/manim/utils/space_ops.py b/manim/utils/space_ops.py index 0017a53ef7..431f33fcd2 100644 --- a/manim/utils/space_ops.py +++ b/manim/utils/space_ops.py @@ -4,7 +4,7 @@ import itertools as it from collections.abc import Sequence -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Callable import numpy as np from mapbox_earcut import triangulate_float32 as earcut @@ -59,7 +59,8 @@ def norm_squared(v: float) -> float: - return np.dot(v, v) + val: float = np.dot(v, v) + return val def cross(v1: Vector3D, v2: Vector3D) -> Vector3D: @@ -200,7 +201,7 @@ def rotate_vector( return rotation_matrix(angle, axis) @ vector -def thick_diagonal(dim: int, thickness=2) -> np.ndarray: +def thick_diagonal(dim: int, thickness: float = 2) -> np.ndarray: row_indices = np.arange(dim).repeat(dim).reshape((dim, dim)) col_indices = np.transpose(row_indices) return (np.abs(row_indices - col_indices) < thickness).astype("uint8") @@ -318,8 +319,10 @@ def angle_of_vector(vector: Sequence[float] | np.ndarray) -> float: c_vec = np.empty(vector.shape[1], dtype=np.complex128) c_vec.real = vector[0] c_vec.imag = vector[1] - return np.angle(c_vec) - return np.angle(complex(*vector[:2])) + val1: float = np.angle(c_vec) + return val1 + val: float = np.angle(complex(*vector[:2])) + return val def angle_between_vectors(v1: np.ndarray, v2: np.ndarray) -> float: @@ -338,13 +341,17 @@ def angle_between_vectors(v1: np.ndarray, v2: np.ndarray) -> float: float The angle between the vectors. """ - return 2 * np.arctan2( + val: float = 2 * np.arctan2( np.linalg.norm(normalize(v1) - normalize(v2)), np.linalg.norm(normalize(v1) + normalize(v2)), ) + return val -def normalize(vect: np.ndarray | tuple[float], fall_back=None) -> np.ndarray: + +def normalize( + vect: np.ndarray | tuple[float], fall_back: np.ndarray | None = None +) -> np.ndarray: norm = np.linalg.norm(vect) if norm > 0: return np.array(vect) / norm @@ -487,11 +494,11 @@ def R3_to_complex(point: Sequence[float]) -> np.ndarray: return complex(*point[:2]) -def complex_func_to_R3_func(complex_func): +def complex_func_to_R3_func(complex_func: Callable) -> Callable: return lambda p: complex_to_R3(complex_func(R3_to_complex(p))) -def center_of_mass(points: Sequence[float]) -> np.ndarray: +def center_of_mass(points: list[Sequence[float]]) -> np.ndarray: """Gets the center of mass of the points in space. Parameters @@ -619,12 +626,13 @@ def get_winding_number(points: Sequence[np.ndarray]) -> float: >>> get_winding_number(polygon.get_vertices()) 0.0 """ - total_angle = 0 + total_angle: float = 0 for p1, p2 in adjacent_pairs(points): d_angle = angle_of_vector(p2) - angle_of_vector(p1) d_angle = ((d_angle + PI) % TAU) - PI total_angle += d_angle - return total_angle / TAU + val: float = total_angle / TAU + return val def shoelace(x_y: np.ndarray) -> float: @@ -637,7 +645,8 @@ def shoelace(x_y: np.ndarray) -> float: """ x = x_y[:, 0] y = x_y[:, 1] - return np.trapz(y, x) + val: float = np.trapz(y, x) + return val def shoelace_direction(x_y: np.ndarray) -> str: @@ -758,7 +767,7 @@ def earclip_triangulation(verts: np.ndarray, ring_ends: list) -> list: raise Exception("Could not find a ring to attach") # Setup linked list - after = [] + after: list = [] end0 = 0 for end1 in ring_ends: after.extend(range(end0 + 1, end1)) @@ -829,7 +838,7 @@ def spherical_to_cartesian(spherical: Sequence[float]) -> np.ndarray: def perpendicular_bisector( line: Sequence[np.ndarray], - norm_vector=OUT, + norm_vector: Vector3D = OUT, ) -> Sequence[np.ndarray]: """Returns a list of two points that correspond to the ends of the perpendicular bisector of the diff --git a/mypy.ini b/mypy.ini index cc298f009f..66824dd1b6 100644 --- a/mypy.ini +++ b/mypy.ini @@ -109,6 +109,9 @@ ignore_errors = False [mypy-manim.utils.parameter_parsing.*] ignore_errors = False +[mypy-manim.utils.space_ops.*] +ignore_errors = False + [mypy-manim.utils.simple_functions.*] ignore_errors = False