diff --git a/bindings/pydrake/BUILD.bazel b/bindings/pydrake/BUILD.bazel
index a96e80aa751c..ad38f114383e 100644
--- a/bindings/pydrake/BUILD.bazel
+++ b/bindings/pydrake/BUILD.bazel
@@ -392,20 +392,14 @@ drake_py_unittest(
     tags = ["lint"],
 )
 
-# TODO(jwnimmer-tri) Once this a real test, switch it to drake_py_unittest.
-drake_py_binary(
+drake_py_unittest(
     name = "memory_leak_test",
-    srcs = ["test/memory_leak_test.py"],
-    add_test_rule = True,
     data = [
         "@drake_models//:iiwa_description",
         "@drake_models//:manipulation_station",
         "@drake_models//:veggies",
         "@drake_models//:wsg_50_description",
     ],
-    test_rule_args = [
-        "--count=2",
-    ],
     deps = [
         ":all_py",
     ],
diff --git a/bindings/pydrake/test/memory_leak_test.py b/bindings/pydrake/test/memory_leak_test.py
index b6dadb46c58c..c05990e847fb 100644
--- a/bindings/pydrake/test/memory_leak_test.py
+++ b/bindings/pydrake/test/memory_leak_test.py
@@ -1,17 +1,18 @@
-"""Eventually this program might grow up to be an actual regression test for
-memory leaks, but for now it merely serves to demonstrate such leaks.
+"""Regression test for memory leaks.
 
-Currently, it neither asserts the absence of leaks (i.e., a real test) nor the
-presence of leaks (i.e., an expect-fail test) -- instead, it's a demonstration
-that we can instrument and observe by hand, to gain traction on the problem.
+The test contains examples of pydrake code that may leak (DUTs),
+instrumentation to detect leaks, and optional additional debug printing under
+an internal verbose option.
 """
 
-import argparse
 import dataclasses
 import functools
 import gc
+import platform
 import sys
 import textwrap
+import unittest
+import weakref
 
 from pydrake.planning import RobotDiagramBuilder
 from pydrake.systems.analysis import Simulator
@@ -40,32 +41,106 @@
 from pydrake.visualization import ApplyVisualizationConfig, VisualizationConfig
 
 
-@dataclasses.dataclass
-class RepetitionDetail:
-    """Captures some details of an instrumented run: an iteration counter, and
-    the count of allocated memory blocks."""
-    i: int
-    blocks: int | None = None
+# Developer-only configuration.
+VERBOSE = False
+
+
+@functools.cache
+def _get_meshcat_singleton():
+    return Meshcat()
+
+
+@dataclasses.dataclass(frozen=True)
+class _Sentinel:
+    """_Sentinel uses `weakref.finalize` to spy on the end of an object's
+    lifetime. The test will use this information to determine whether objects
+    of interest were properly garbage collected or not, and to provide logging
+    of exactly when objects are finalized.
+
+    See also: https://docs.python.org/3/library/weakref.html#weakref.finalize
+    """
+    finalizer: weakref.finalize
+    name: str
+
+
+def _make_sentinel(obj, name):
+    """Makes a _Sentinel for `obj` using `name` for debugging messages. If
+    the module-level `VERBOSE` flag is False (the default), no messages will
+    be printed, but the _Sentinel will still track the object.
+    """
+    if VERBOSE:
+        print(f"sentinel: tracked {name} {hex(id(obj))}")
+
+    def done(oid):
+        if VERBOSE:
+            print(f"sentinel: unmade {name} {hex(oid)}")
+    return _Sentinel(finalizer=weakref.finalize(obj, done, id(obj)), name=name)
+
+
+def _make_sentinels_from_locals(dut_name, locals_dict):
+    """Makes _Sentinels for all local variables of interest."""
+    # Skip specific types not supported by weakref, as needed.
+    return {_make_sentinel(value, f"{dut_name}::{key}")
+            for key, value in locals_dict.items()
+            if not any(isinstance(value, typ) for typ in [list, str])}
+
+
+def _report_sentinels(sentinels, message: str):
+    """Prints extensive debug information for a sequence of _Sentinels.
+    The message string can provide additional context that may be available at
+    the call site.
+    """
+    print(message)
+    for sentinel in sentinels:
+        print(f"sentinel for {sentinel.name}")
+        finalizer = sentinel.finalizer
+        print(f"sentinel alive? {finalizer.alive}")
+        if finalizer.alive:
+            o = finalizer.peek()[0]
+            is_tracked = gc.is_tracked(o)
+            print(f"is_tracked: {is_tracked}")
+            if is_tracked:
+                print(f"generation: {_object_generation(o)}")
+                print(f"referrers: {gc.get_referrers(o)}")
+                print(f"referents: {gc.get_referents(o)}")
+
+
+def _object_generation(o) -> int | None:
+    """Returns the garbage collection generation of the passed object, or None
+    if the object is not tracked by garbage collection.
+
+    See also: https://github.com/python/cpython/blob/main/InternalDocs/garbage_collector.md#optimization-generations # noqa
+    """
+    for gen in range(3):
+        gen_list = gc.get_objects(generation=gen)
+        if any([x is o for x in gen_list]):
+            return gen
+    return None
 
 
 def _dut_simple_source():
     """A device under test that creates and destroys a leaf system."""
     source = ConstantVectorSource([1.0])
+    return {_make_sentinel(source, "simple source")}
 
 
 def _dut_trivial_simulator():
-    """A device under test that creates and destroys a simulator that contains
-    only a single, simple subsystem."""
+    """A device under test that creates and destroys a simulator that
+    contains only a single, simple subsystem.
+    """
     builder = DiagramBuilder()
-    builder.AddSystem(ConstantVectorSource([1.0]))
+    source = builder.AddSystem(ConstantVectorSource([1.0]))
+    source2 = builder.AddSystem(ConstantVectorSource([1.0]))
     diagram = builder.Build()
     simulator = Simulator(system=diagram)
     simulator.AdvanceTo(1.0)
+    return _make_sentinels_from_locals("trivial_simulator", locals())
 
 
 def _dut_mixed_language_simulator():
-    """A device under test that creates and destroys a simulator that contains
-    subsystems written in both C++ and Python."""
+    """A device under test that creates and destroys a simulator that
+    contains subsystems written in both C++ and Python.
+    """
     builder = RobotDiagramBuilder()
     builder.builder().AddSystem(ConstantVectorSource([1.0]))
     diagram = builder.Build()
@@ -75,16 +150,13 @@ def _dut_mixed_language_simulator():
     plant = diagram.plant()
     plant_context = plant.GetMyContextFromRoot(context)
     plant.EvalSceneGraphInspector(plant_context)
-
-
-@functools.cache
-def _get_meshcat_singleton():
-    return Meshcat()
+    return _make_sentinels_from_locals("mixed_language_simulator", locals())
 
 
 def _dut_full_example():
-    """A device under test that creates and destroys a simulator that contains
-    everything a full-stack simulation would ever use."""
+    """A device under test that creates and destroys a simulator that
+    contains everything a full-stack simulation would ever use.
+    """
     builder = DiagramBuilder()
     plant, scene_graph = AddMultibodyPlant(
         plant_config=MultibodyPlantConfig(
@@ -192,56 +264,69 @@ def _dut_full_example():
     random = RandomGenerator(22)
     diagram.SetRandomContext(simulator.get_mutable_context(), random)
     simulator.AdvanceTo(0.5)
+    return _make_sentinels_from_locals("full_example", locals())
+
+
+def _repeat(*, dut: callable, count: int):
+    """Calls dut() for count times in a row, performing a full garbage
+    collection before and after each call. Tracks memory leaks of interest;
+    the count of leaks is returned. If `VERBOSE=True`, additional debug
+    information will be printed.
 
+    Args:
+        dut() -> Sequence[_Sentinel]: a callable function containing code to
+                                      test for leaks, and returning _Sentinels
+                                      for data of interest.
+        count: the number of times to invoke `dut`.
 
-def _repeat(*, dut: callable, count: int) -> list[RepetitionDetail]:
-    """Returns the details of calling dut() for count times in a row."""
-    # Pre-allocate all of our return values.
-    details = [RepetitionDetail(i=i) for i in range(count)]
+    Returns:
+        int: the total number of leaked objects detected by examining returned
+             _Sentinels.
+    """
    gc.collect()
-    tare_blocks = sys.getallocatedblocks()
-    # Call the dut repeatedly, keeping stats as we go.
+    # Call the dut repeatedly, observing the returned sentinels.
+    leaks = 0
     for i in range(count):
-        dut()
+        sentinels = dut()
+        if VERBOSE:
+            _report_sentinels(sentinels, "before collect")
         gc.collect()
-        details[i].blocks = sys.getallocatedblocks() - tare_blocks
-    return details
-
-
-def _main():
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--count",
-        metavar="N",
-        type=int,
-        default=5,
-        help="Number of iterations to run",
-    )
-    parser.add_argument(
-        "--dut",
-        metavar="NAME",
-        help="Chooses a device under test; when not given, all DUTs are run.",
-    )
-    args = parser.parse_args()
-    all_duts = dict([
-        (dut.__name__[5:], dut)
-        for dut in [
-            _dut_simple_source,
-            _dut_trivial_simulator,
-            _dut_mixed_language_simulator,
-            _dut_full_example,
-        ]
-    ])
-    if args.dut:
-        run_duts = {args.dut: all_duts[args.dut]}
-    else:
-        run_duts = all_duts
-    for name, dut in run_duts.items():
-        details = _repeat(dut=dut, count=args.count)
-        print(f"RUNNING: {name}")
-        for x in details:
-            print(x)
-
-
-assert __name__ == "__main__", __name__
-sys.exit(_main())
+        if VERBOSE:
+            _report_sentinels(sentinels, "after collect")
+        leaks += any(
+            [sentinel.finalizer.alive for sentinel in sentinels])
+    return leaks
+
+
+class TestMemoryLeaks(unittest.TestCase):
+    def do_test(self, *, dut, count, leaks_allowed=0, leaks_required=0):
+        """Runs the requested `dut` (see _repeat() above) for `count`
+        iterations. Checks that leaks detected <= leaks_allowed. In addition,
+        checks that leaks_required <= the actual leaks measured. Using a non-0
+        leaks_required will cause the test to fail if fixes get implemented. In
+        that case, the test can likely be updated to be more strict.
+        """
+        leaks = _repeat(dut=dut, count=count)
+        self.assertLessEqual(leaks, leaks_allowed)
+        self.assertGreaterEqual(leaks, leaks_required)
+
+    def test_simple_source(self):
+        self.do_test(dut=_dut_simple_source, count=10)
+
+    def test_trivial_simulator(self):
+        self.do_test(
+            dut=_dut_trivial_simulator,
+            count=5,
+            # TODO(rpoyner-tri): Allow 0 leaks.
+            leaks_allowed=5, leaks_required=1)
+
+    def test_mixed_language_simulator(self):
+        self.do_test(
+            dut=_dut_mixed_language_simulator,
+            count=5,
+            # TODO(rpoyner-tri): Allow 0 leaks.
+            leaks_allowed=4, leaks_required=1)
+
+    def test_full_example(self):
+        # Note: this test doesn't invoke the #14355 deliberate cycle.
+        self.do_test(dut=_dut_full_example, count=2)
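
Note on the detection strategy: the patch's leak check reduces to registering a
`weakref.finalize` callback on each object a DUT creates, forcing `gc.collect()`
after the DUT returns, and counting any finalizer that is still alive as a leak.
The following standalone sketch illustrates that same pattern outside of
pydrake; it is not part of the patch, and the `_Payload` class, `_global_cache`
list, and `count_leaks` helper are illustrative names only.

    import gc
    import weakref


    class _Payload:
        """Illustrative object standing in for a pydrake object of interest."""
        def __init__(self):
            self.data = [1.0, 2.0, 3.0]


    # Hypothetical stand-in for whatever keeps an object alive after its DUT
    # returns (e.g. an unwanted module-level cache or a C++-side keep-alive).
    _global_cache = []


    def _make_sentinel(obj, name):
        # As in the patch: the finalizer stays `alive` until `obj` is reclaimed.
        return weakref.finalize(obj, print, f"collected: {name}")


    def _dut_collectable():
        payload = _Payload()
        return [_make_sentinel(payload, "collectable payload")]


    def _dut_leaky():
        payload = _Payload()
        _global_cache.append(payload)  # Simulates the unwanted keep-alive.
        return [_make_sentinel(payload, "leaky payload")]


    def count_leaks(dut):
        sentinels = dut()
        gc.collect()
        # A finalizer that is still alive means its object was never reclaimed.
        return sum(s.alive for s in sentinels)


    assert count_leaks(_dut_collectable) == 0
    assert count_leaks(_dut_leaky) == 1

The same mechanism explains why `_make_sentinels_from_locals` skips `list` and
`str` values: instances of those built-in types do not support weak references,
so they cannot be tracked with `weakref.finalize`.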