Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bugfix: Apply Updaters to Graph Edges During Edge Creation #3836

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 59 additions & 84 deletions manim/mobject/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import itertools as it
from collections.abc import Hashable, Iterable
from copy import copy
from copy import copy, deepcopy
from typing import TYPE_CHECKING, Any, Literal, Protocol, cast

import networkx as nx
Expand Down Expand Up @@ -617,6 +617,7 @@ def __init__(

self.vertices = {v: vertex_type(**self._vertex_config[v]) for v in vertices}
self.vertices.update(vertex_mobjects)
self.add(*self.vertices.values())

self.change_layout(
layout=layout,
Expand All @@ -626,37 +627,16 @@ def __init__(
root_vertex=root_vertex,
)

# build edge_config
if edge_config is None:
edge_config = {}
default_tip_config = {}
default_edge_config = {}
if edge_config:
default_tip_config = edge_config.pop("tip_config", {})
default_edge_config = {
k: v
for k, v in edge_config.items()
if not isinstance(
k, tuple
) # everything that is not an edge is an option
}
self._edge_config = {}
self._tip_config = {}
for e in edges:
if e in edge_config:
self._tip_config[e] = edge_config[e].pop(
"tip_config", copy(default_tip_config)
)
self._edge_config[e] = edge_config[e]
else:
self._tip_config[e] = copy(default_tip_config)
self._edge_config[e] = copy(default_edge_config)

self.default_edge_config = default_edge_config
self._populate_edge_dict(edges, edge_type)
self.edges = {}
self._edge_config = {}
self.default_edge_config, _ = GenericGraph._split_out_child_configs(
edge_config, lambda e: e in edges
)

self.add(*self.vertices.values())
self.add(*self.edges.values())
self.add_edges(*edges, edge_type=edge_type, edge_config=edge_config)

self.add_updater(self.update_edges)

Expand All @@ -665,11 +645,11 @@ def _empty_networkx_graph() -> nx.classes.graph.Graph:
"""Return an empty networkx graph for the given graph type."""
raise NotImplementedError("To be implemented in concrete subclasses")

def _populate_edge_dict(
self, edges: list[tuple[Hashable, Hashable]], edge_type: type[Mobject]
):
"""Helper method for populating the edges of the graph."""
raise NotImplementedError("To be implemented in concrete subclasses")
@staticmethod
def _split_out_child_configs(config: dict, is_child_key) -> tuple[dict, dict]:
parent_config = {k: v for k, v in config.items() if not is_child_key(k)}
child_configs = {k: v for k, v in config.items() if is_child_key(k)}
return parent_config, child_configs

def __getitem__(self: Graph, v: Hashable) -> Mobject:
return self.vertices[v]
Expand Down Expand Up @@ -1019,28 +999,20 @@ def _add_edge(

"""
if edge_config is None:
edge_config = self.default_edge_config.copy()
added_mobjects = []
for v in edge:
if v not in self.vertices:
added_mobjects.append(self._add_vertex(v))
u, v = edge
edge_config = {}

added_vertices = [self._add_vertex(v) for v in edge if v not in self.vertices]

u, v = edge
self._graph.add_edge(u, v)

base_edge_config = self.default_edge_config.copy()
base_edge_config.update(edge_config)
edge_config = base_edge_config
self._edge_config[(u, v)] = edge_config
self._edge_config[edge] = {**self.default_edge_config, **edge_config}
edge_mobject = self._create_edge_mobject(edge, edge_type)

edge_mobject = edge_type(
self[u].get_center(), self[v].get_center(), z_index=-1, **edge_config
)
self.edges[(u, v)] = edge_mobject

self.add(edge_mobject)
added_mobjects.append(edge_mobject)
return self.get_group_class()(*added_mobjects)

return self.get_group_class()(*added_vertices, edge_mobject)

def add_edges(
self,
Expand Down Expand Up @@ -1078,13 +1050,12 @@ def add_edges(
"""
if edge_config is None:
edge_config = {}
non_edge_settings = {k: v for (k, v) in edge_config.items() if k not in edges}
base_edge_config = self.default_edge_config.copy()
base_edge_config.update(non_edge_settings)
base_edge_config = {e: base_edge_config.copy() for e in edges}
for e in edges:
base_edge_config[e].update(edge_config.get(e, {}))
edge_config = base_edge_config
else:
edge_config = deepcopy(edge_config)

batch_default_config, custom_configs = GenericGraph._split_out_child_configs(
edge_config, lambda e: e in edges
)

edge_vertices = set(it.chain(*edges))
new_vertices = [v for v in edge_vertices if v not in self.vertices]
Expand All @@ -1095,7 +1066,10 @@ def add_edges(
self._add_edge(
edge,
edge_type=edge_type,
edge_config=edge_config[edge],
edge_config={
**batch_default_config,
**custom_configs.get(edge, {}),
},
).submobjects
for edge in edges
),
Expand Down Expand Up @@ -1136,7 +1110,7 @@ def _remove_edge(self, edge: tuple[Hashable]):
edge_mobject = self.edges.pop(edge)

self._graph.remove_edge(*edge)
self._edge_config.pop(edge, None)
self._edge_config.pop(edge)

self.remove(edge_mobject)
return edge_mobject
Expand Down Expand Up @@ -1535,18 +1509,14 @@ def construct(self):
def _empty_networkx_graph() -> nx.Graph:
return nx.Graph()

def _populate_edge_dict(
self, edges: list[tuple[Hashable, Hashable]], edge_type: type[Mobject]
):
self.edges = {
(u, v): edge_type(
self[u].get_center(),
self[v].get_center(),
z_index=-1,
**self._edge_config[(u, v)],
)
for (u, v) in edges
}
def _create_edge_mobject(self, edge, edge_type):
u, v = edge
return edge_type(
self[u].get_center(),
self[v].get_center(),
z_index=-1,
**self._edge_config[(u, v)],
)

def update_edges(self, graph):
for (u, v), edge in graph.edges.items():
Expand Down Expand Up @@ -1742,21 +1712,26 @@ def construct(self):
def _empty_networkx_graph() -> nx.DiGraph:
return nx.DiGraph()

def _populate_edge_dict(
self, edges: list[tuple[Hashable, Hashable]], edge_type: type[Mobject]
):
self.edges = {
(u, v): edge_type(
self[u],
self[v],
z_index=-1,
**self._edge_config[(u, v)],
)
for (u, v) in edges
}
@staticmethod
def _split_out_tip_configs(config: dict) -> tuple[dict, dict]:
edge_config, tip_config = GenericGraph._split_out_child_configs(
config, lambda k: k == "tip_config"
)
return edge_config, tip_config.get("tip_config", {})

for (u, v), edge in self.edges.items():
edge.add_tip(**self._tip_config[(u, v)])
def _create_edge_mobject(self, edge, edge_type):
edge_config, tip_config = DiGraph._split_out_tip_configs(
self._edge_config[edge]
)
u, v = edge
edge_mobject = edge_type(
self[u],
self[v],
z_index=-1,
**edge_config,
)
edge_mobject.add_tip(**tip_config)
return edge_mobject

def update_edges(self, graph):
"""Updates the edges to stick at their corresponding vertices.
Expand Down
Binary file not shown.
37 changes: 37 additions & 0 deletions tests/test_graphical_units/test_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from __future__ import annotations

from manim import *
from manim.utils.testing.frames_comparison import frames_comparison

__module_test__ = "graph"


@frames_comparison
def test_digraph_add_edges(scene):
vertices = range(5)
edges = [
(0, 1),
(1, 2),
(3, 2),
(3, 4),
]

edge_config = {
"stroke_width": 2,
"tip_config": {
"tip_shape": ArrowSquareTip,
"tip_length": 0.15,
},
(3, 4): {"color": RED, "tip_config": {"tip_length": 0.25, "tip_width": 0.25}},
}

g = DiGraph(
vertices,
[],
labels=True,
layout="circular",
).scale(1.4)

g.add_edges(*edges, edge_config=edge_config)

scene.add(g)
Loading