Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding type annotations to manim.utils.* #3999

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion manim/mobject/text/numbers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
__all__ = ["DecimalNumber", "Integer", "Variable"]

from collections.abc import Sequence
from typing import Any

import numpy as np

Expand Down Expand Up @@ -327,7 +328,9 @@ def construct(self):
self.add(Integer(number=6.28).set_x(-1.5).set_y(-2).set_color(YELLOW).scale(1.4))
"""

def __init__(self, number=0, num_decimal_places=0, **kwargs):
def __init__(
self, number: float = 0, num_decimal_places: int = 0, **kwargs: Any
) -> None:
super().__init__(number=number, num_decimal_places=num_decimal_places, **kwargs)

def get_value(self):
Expand Down
3 changes: 2 additions & 1 deletion manim/renderer/cairo_renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
import typing

import numpy as np
import numpy.typing as npt

from manim.utils.hashing import get_hash_from_play_call

from .. import config, logger
from ..camera.camera import Camera
from ..mobject.mobject import Mobject
from ..scene.scene_file_writer import SceneFileWriter

Check failure

Code scanning / CodeQL

Module-level cyclic import Error

'SceneFileWriter' may not be defined if module
manim.scene.scene_file_writer
is imported before module
manim.renderer.cairo_renderer
, as the
definition
of SceneFileWriter occurs after the cyclic
import
of manim.renderer.cairo_renderer.
'SceneFileWriter' may not be defined if module
manim.scene.scene_file_writer
is imported before module
manim.renderer.cairo_renderer
, as the
definition
of SceneFileWriter occurs after the cyclic
import
of manim.renderer.cairo_renderer.
'SceneFileWriter' may not be defined if module
manim.scene.scene_file_writer
is imported before module
manim.renderer.cairo_renderer
, as the
definition
of SceneFileWriter occurs after the cyclic
import
of manim.renderer.cairo_renderer.
from ..utils.exceptions import EndSceneEarlyException
from ..utils.iterables import list_update

Expand Down Expand Up @@ -160,7 +161,7 @@
self.update_frame(scene, moving_mobjects)
self.add_frame(self.get_frame())

def get_frame(self):
def get_frame(self) -> npt.NDArray:
"""
Gets the current frame as NumPy array.

Expand Down
19 changes: 14 additions & 5 deletions manim/scene/scene_file_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
from pydub import AudioSegment

from manim import __version__
from manim.typing import PixelArray
from manim.renderer.cairo_renderer import CairoRenderer
Fixed Show fixed Hide fixed
from manim.renderer.opengl_renderer import OpenGLRenderer
Fixed Show fixed Hide fixed
from manim.typing import PixelArray, StrPath

from .. import config, logger
from .._config.logger_utils import set_file_logger
Expand Down Expand Up @@ -104,7 +106,12 @@

force_output_as_scene_name = False

def __init__(self, renderer, scene_name, **kwargs):
def __init__(
self,
renderer: CairoRenderer | OpenGLRenderer,
scene_name: StrPath,
**kwargs: Any,
) -> None:
self.renderer = renderer
self.init_output_directories(scene_name)
self.init_audio()
Expand All @@ -118,7 +125,7 @@
name="autocreated", type_=DefaultSectionType.NORMAL, skip_animations=False
)

def init_output_directories(self, scene_name):
def init_output_directories(self, scene_name: StrPath):
"""Initialise output directories.

Notes
Expand Down Expand Up @@ -378,7 +385,9 @@
self.add_audio_segment(new_segment, time, **kwargs)

# Writers
def begin_animation(self, allow_write: bool = False, file_path=None):
def begin_animation(
self, allow_write: bool = False, file_path: StrPath | None = None
) -> None:
"""
Used internally by manim to stream the animation to FFMPEG for
displaying or writing to a file.
Expand All @@ -391,7 +400,7 @@
if write_to_movie() and allow_write:
self.open_partial_movie_stream(file_path=file_path)

def end_animation(self, allow_write: bool = False):
def end_animation(self, allow_write: bool = False) -> None:
"""
Internally used by Manim to stop streaming to
FFMPEG gracefully.
Expand Down
107 changes: 69 additions & 38 deletions manim/utils/bezier.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
BezierPoints,
BezierPoints_Array,
ColVector,
InternalPoint3D,
InternalPoint3D_Array,
MatrixMN,
Point3D,
Point3D_Array,
Expand All @@ -54,10 +56,12 @@
@overload
def bezier(
points: Sequence[Point3D_Array],
) -> Callable[[float | ColVector], Point3D_Array]: ...
) -> Callable[[float | ColVector], Point3D | Point3D_Array]: ...

Check notice

Code scanning / CodeQL

Statement has no effect Note

This statement has no effect.


def bezier(points):
def bezier(
points: Point3D_Array | Sequence[Point3D_Array],
) -> Callable[[float | ColVector], Point3D | Point3D_Array]:
"""Classic implementation of a Bézier curve.

Parameters
Expand Down Expand Up @@ -111,21 +115,21 @@

if degree == 0:

def zero_bezier(t):
def zero_bezier(t: float | ColVector) -> Point3D | Point3D_Array:
return np.ones_like(t) * P[0]

return zero_bezier

if degree == 1:

def linear_bezier(t):
def linear_bezier(t: float | ColVector) -> Point3D | Point3D_Array:
return P[0] + t * (P[1] - P[0])

return linear_bezier

if degree == 2:

def quadratic_bezier(t):
def quadratic_bezier(t: float | ColVector) -> Point3D | Point3D_Array:
t2 = t * t
mt = 1 - t
mt2 = mt * mt
Expand All @@ -135,7 +139,7 @@

if degree == 3:

def cubic_bezier(t):
def cubic_bezier(t: float | ColVector) -> Point3D | Point3D_Array:
t2 = t * t
t3 = t2 * t
mt = 1 - t
Expand All @@ -145,11 +149,12 @@

return cubic_bezier

def nth_grade_bezier(t):
def nth_grade_bezier(t: float | ColVector) -> Point3D | Point3D_Array:
is_scalar = not isinstance(t, np.ndarray)
if is_scalar:
B = np.empty((1, *P.shape))
else:
assert isinstance(t, np.ndarray)
t = t.reshape(-1, *[1 for dim in P.shape])
B = np.empty((t.shape[0], *P.shape))
B[:] = P
Expand All @@ -162,7 +167,8 @@
# In the end, there shall be the evaluation at t of a single Bezier curve of
# grade d, stored in the first slot of B
if is_scalar:
return B[0, 0]
val: Point3D = B[0, 0]
return val
return B[:, 0]

return nth_grade_bezier
Expand Down Expand Up @@ -700,7 +706,7 @@


# Memos explained in subdivide_bezier docstring
SUBDIVISION_MATRICES = [{} for i in range(4)]
SUBDIVISION_MATRICES: list[dict[int, npt.NDArray]] = [{} for i in range(4)]


def _get_subdivision_matrix(n_points: int, n_divisions: int) -> MatrixMN:
Expand Down Expand Up @@ -812,7 +818,9 @@
return subdivision_matrix


def subdivide_bezier(points: BezierPoints, n_divisions: int) -> Point3D_Array:
def subdivide_bezier(
points: InternalPoint3D_Array, n_divisions: int
) -> InternalPoint3D_Array:
r"""Subdivide a Bézier curve into :math:`n` subcurves which have the same shape.

The points at which the curve is split are located at the
Expand Down Expand Up @@ -1012,14 +1020,22 @@


@overload
def interpolate(start: Point3D, end: Point3D, alpha: float) -> Point3D: ...
def interpolate(
start: InternalPoint3D, end: InternalPoint3D, alpha: float
) -> Point3D: ...

Check notice

Code scanning / CodeQL

Statement has no effect Note

This statement has no effect.


@overload
def interpolate(start: Point3D, end: Point3D, alpha: ColVector) -> Point3D_Array: ...
def interpolate(
start: InternalPoint3D, end: InternalPoint3D, alpha: ColVector
) -> Point3D_Array: ...

Check notice

Code scanning / CodeQL

Statement has no effect Note

This statement has no effect.


def interpolate(start, end, alpha):
def interpolate(
start: float | InternalPoint3D,
end: float | InternalPoint3D,
alpha: float | ColVector,
) -> float | ColVector | Point3D | Point3D_Array:
"""Linearly interpolates between two values ``start`` and ``end``.

Parameters
Expand Down Expand Up @@ -1099,10 +1115,12 @@


@overload
def mid(start: Point3D, end: Point3D) -> Point3D: ...
def mid(start: InternalPoint3D, end: InternalPoint3D) -> Point3D: ...

Check notice

Code scanning / CodeQL

Statement has no effect Note

This statement has no effect.


def mid(start: float | Point3D, end: float | Point3D) -> float | Point3D:
def mid(
start: float | InternalPoint3D, end: float | InternalPoint3D
) -> float | Point3D:
"""Returns the midpoint between two values.

Parameters
Expand All @@ -1124,15 +1142,21 @@


@overload
def inverse_interpolate(start: float, end: float, value: Point3D) -> Point3D: ...
def inverse_interpolate(
start: float, end: float, value: InternalPoint3D
) -> InternalPoint3D: ...

Check notice

Code scanning / CodeQL

Statement has no effect Note

This statement has no effect.


@overload
def inverse_interpolate(start: Point3D, end: Point3D, value: Point3D) -> Point3D: ...
def inverse_interpolate(
start: InternalPoint3D, end: InternalPoint3D, value: InternalPoint3D
) -> Point3D: ...

Check notice

Code scanning / CodeQL

Statement has no effect Note

This statement has no effect.


def inverse_interpolate(
start: float | Point3D, end: float | Point3D, value: float | Point3D
start: float | InternalPoint3D,
end: float | InternalPoint3D,
value: float | InternalPoint3D,
) -> float | Point3D:
"""Perform inverse interpolation to determine the alpha
values that would produce the specified ``value``
Expand Down Expand Up @@ -1186,7 +1210,7 @@
new_end: float,
old_start: float,
old_end: float,
old_value: Point3D,
old_value: InternalPoint3D,
) -> Point3D: ...


Expand All @@ -1195,7 +1219,7 @@
new_end: float,
old_start: float,
old_end: float,
old_value: float | Point3D,
old_value: float | InternalPoint3D,
) -> float | Point3D:
"""Interpolate a value from an old range to a new range.

Expand Down Expand Up @@ -1227,7 +1251,7 @@
return interpolate(
new_start,
new_end,
old_alpha, # type: ignore[arg-type]
old_alpha,
)


Expand Down Expand Up @@ -1263,7 +1287,8 @@
# they can only be an interpolation of these two anchors with alphas
# 1/3 and 2/3, which will draw a straight line between the anchors.
if n_anchors == 2:
return interpolate(anchors[0], anchors[1], np.array([[1 / 3], [2 / 3]]))
val = interpolate(anchors[0], anchors[1], np.array([[1 / 3], [2 / 3]]))
return (val[0], val[1])

# Handle different cases depending on whether the points form a closed
# curve or not
Expand Down Expand Up @@ -1738,7 +1763,12 @@
) -> QuadraticBezierPoints_Array: ...


def get_quadratic_approximation_of_cubic(a0, h0, h1, a1):
def get_quadratic_approximation_of_cubic(
a0: Point3D | Point3D_Array,
h0: Point3D | Point3D_Array,
h1: Point3D | Point3D_Array,
a1: Point3D | Point3D_Array,
) -> QuadraticBezierPoints_Array:
r"""If ``a0``, ``h0``, ``h1`` and ``a1`` are the control points of a cubic
Bézier curve, approximate the curve with two quadratic Bézier curves and
return an array of 6 points, where the first 3 points represent the first
Expand Down Expand Up @@ -1842,33 +1872,33 @@
If ``a0``, ``h0``, ``h1`` and ``a1`` have different dimensions, or
if their number of dimensions is not 1 or 2.
"""
a0 = np.asarray(a0)
h0 = np.asarray(h0)
h1 = np.asarray(h1)
a1 = np.asarray(a1)

if all(arr.ndim == 1 for arr in (a0, h0, h1, a1)):
num_curves, dim = 1, a0.shape[0]
elif all(arr.ndim == 2 for arr in (a0, h0, h1, a1)):
num_curves, dim = a0.shape
a0c = np.asarray(a0)
h0c = np.asarray(h0)
h1c = np.asarray(h1)
a1c = np.asarray(a1)

if all(arr.ndim == 1 for arr in (a0c, h0c, h1c, a1c)):
num_curves, dim = 1, a0c.shape[0]
elif all(arr.ndim == 2 for arr in (a0c, h0c, h1c, a1c)):
num_curves, dim = a0c.shape
else:
raise ValueError("All arguments must be Point3D or Point3D_Array.")

m0 = 0.25 * (3 * h0 + a0)
m1 = 0.25 * (3 * h1 + a1)
m0 = 0.25 * (3 * h0c + a0c)
m1 = 0.25 * (3 * h1c + a1c)
k = 0.5 * (m0 + m1)

result = np.empty((6 * num_curves, dim))
result[0::6] = a0
result[0::6] = a0c
result[1::6] = m0
result[2::6] = k
result[3::6] = k
result[4::6] = m1
result[5::6] = a1
result[5::6] = a1c
return result


def is_closed(points: Point3D_Array) -> bool:
def is_closed(points: InternalPoint3D_Array) -> bool:
"""Returns ``True`` if the spline given by ``points`` is closed, by
checking if its first and last points are close to each other, or``False``
otherwise.
Expand Down Expand Up @@ -1938,7 +1968,8 @@
return False
if abs(end[1] - start[1]) > tolerance[1]:
return False
return abs(end[2] - start[2]) <= tolerance[2]
val: bool = abs(end[2] - start[2]) <= tolerance[2]
return val


def proportions_along_bezier_curve_for_point(
Expand Down
Loading
Loading