Skip to content

Commit

Permalink
Merge pull request #54 from peterchenadded/performance/53-improve-hea…
Browse files Browse the repository at this point in the history
…pq-usage-performance

Performance/53 improve heapq usage performance
  • Loading branch information
brean authored Jan 13, 2024
2 parents 2990b64 + bb83c7d commit a091843
Show file tree
Hide file tree
Showing 15 changed files with 305 additions and 14 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,6 @@ target/

# ipython notebook
.ipynb_checkpoints

# python virtual env
venv/
17 changes: 17 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,23 @@ flow:

You can run the tests locally using pytest. Take a look at the `test`-folder

You can follow below steps to setup your virtual environment and run the tests.

```bash
# Go to repo
cd python-pathfinding

# Setup virtual env and activate it - Mac/Linux for windows use source venv/Scripts/activate
python3 -m venv venv
source venv/bin/activate

# Install test requirements
pip install -r test/requirements.txt

# Run all the tests
pytest
```

## Contributing

Please use the [issue tracker](https://github.com/brean/python-pathfinding/issues) to submit bug reports and feature requests. Please use merge requests as described [here](/CONTRIBUTING.md) to add/adapt functionality.
Expand Down
File renamed without changes.
112 changes: 112 additions & 0 deletions notebooks/performance.ipynb

Large diffs are not rendered by default.

87 changes: 87 additions & 0 deletions pathfinding/core/heap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
"""Simple heap with ordering and removal."""
import heapq
from .graph import Graph
from .grid import Grid
from .world import World

class SimpleHeap:
"""Simple wrapper around open_list that keeps track of order and removed nodes automatically."""

def __init__(self, node, grid):
self.grid = grid
self.open_list = [self._get_node_tuple(node, 0)]
self.removed_node_tuples = set()
self.heap_order = {}
self.number_pushed = 0

def _get_node_tuple(self, node, heap_order):
if isinstance(self.grid, Graph):
return (node.f, heap_order, node.node_id)
elif isinstance(self.grid, Grid):
return (node.f, heap_order, node.x, node.y)
elif isinstance(self.grid, World):
return (node.f, heap_order, node.x, node.y, node.grid_id)
else:
assert False, "unsupported heap node node=%s" % node

def _get_node_id(self, node):
if isinstance(self.grid, Graph):
return node.node_id
elif isinstance(self.grid, Grid):
return (node.x, node.y)
elif isinstance(self.grid, World):
return (node.x, node.y, node.grid_id)


def pop_node(self):
"""
Pops node off the heap. i.e. returns the one with the lowest f.
Notes:
1. Checks if that values is in removed_node_tuples first, if not tries again.
2. We use this approach to avoid invalidating the heap structure.
"""
node_tuple = heapq.heappop(self.open_list)
while node_tuple in self.removed_node_tuples:
node_tuple = heapq.heappop(self.open_list)

if isinstance(self.grid, Graph):
node = self.grid.node(node_tuple[2])
elif isinstance(self.grid, Grid):
node = self.grid.node(node_tuple[2], node_tuple[3])
elif isinstance(self.grid, World):
node = self.grid.grids[node_tuple[4]].node(node_tuple[2], node_tuple[3])

return node

def push_node(self, node):
"""
Push node into heap.
:param node: The node to push.
"""
self.number_pushed = self.number_pushed + 1
node_tuple = self._get_node_tuple(node, self.number_pushed)
node_id = self._get_node_id(node)

self.heap_order[node_id] = self.number_pushed

heapq.heappush(self.open_list, node_tuple)

def remove_node(self, node, f):
"""
Remove the node from the heap.
This just stores it in a set and we just ignore the node if it does get popped from the heap.
:param node: The node to remove.
:param f: The old f value of the node.
"""
node_id = self._get_node_id(node)
heap_order = self.heap_order[node_id]
node_tuple = self._get_node_tuple(node, heap_order)
self.removed_node_tuples.add(node_tuple)

def __len__(self):
"""Returns the length of the open_list."""
return len(self.open_list)
4 changes: 1 addition & 3 deletions pathfinding/finder/a_star.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import heapq # used for the so colled "open list" that stores known nodes
from .finder import BY_END, Finder, MAX_RUNS, TIME_LIMIT
from ..core.diagonal_movement import DiagonalMovement
from ..core.heuristic import manhattan, octile
Expand Down Expand Up @@ -50,8 +49,7 @@ def check_neighbors(self, start, end, graph, open_list,
:param open_list: stores nodes that will be processed next
"""
# pop node with minimum 'f' value
node = heapq.nsmallest(1, open_list)[0]
open_list.remove(node)
node = open_list.pop_node()
node.closed = True

# if reached the end position, construct the path and return it
Expand Down
5 changes: 3 additions & 2 deletions pathfinding/finder/bi_a_star.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from .a_star import AStarFinder
from .finder import BY_END, BY_START, MAX_RUNS, TIME_LIMIT
from ..core.diagonal_movement import DiagonalMovement
from ..core.heap import SimpleHeap


class BiAStarFinder(AStarFinder):
Expand Down Expand Up @@ -45,12 +46,12 @@ def find_path(self, start, end, grid):
self.start_time = time.time() # execution time limitation
self.runs = 0 # count number of iterations

start_open_list = [start]
start_open_list = SimpleHeap(start, grid)
start.g = 0
start.f = 0
start.opened = BY_START

end_open_list = [end]
end_open_list = SimpleHeap(end, grid)
end.g = 0
end.f = 0
end.opened = BY_END
Expand Down
4 changes: 2 additions & 2 deletions pathfinding/finder/breadth_first.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def __init__(self, heuristic=None, weight=1,
self.diagonalMovement = DiagonalMovement.never

def check_neighbors(self, start, end, grid, open_list):
node = open_list.pop(0)
node = open_list.pop_node()
node.closed = True

if node == end:
Expand All @@ -30,6 +30,6 @@ def check_neighbors(self, start, end, grid, open_list):
if neighbor.closed or neighbor.opened:
continue

open_list.append(neighbor)
open_list.push_node(neighbor)
neighbor.opened = True
neighbor.parent = node
10 changes: 6 additions & 4 deletions pathfinding/finder/finder.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import heapq # used for the so colled "open list" that stores known nodes
import time # for time limitation
from ..core.diagonal_movement import DiagonalMovement
from ..core.heap import SimpleHeap


# max. amount of tries we iterate until we abort the search
Expand Down Expand Up @@ -107,20 +108,21 @@ def process_node(
ng = parent.g + graph.calc_cost(parent, node, self.weighted)

if not node.opened or ng < node.g:
old_f = node.f
node.g = ng
node.h = node.h or self.apply_heuristic(node, end)
# f is the estimated total cost from start to goal
node.f = node.g + node.h
node.parent = parent
if not node.opened:
heapq.heappush(open_list, node)
open_list.push_node(node)
node.opened = open_value
else:
# the node can be reached with smaller cost.
# Since its f value has been updated, we have to
# update its position in the open list
open_list.remove(node)
heapq.heappush(open_list, node)
open_list.remove_node(node, old_f)
open_list.push_node(node)

def check_neighbors(self, start, end, graph, open_list,
open_value=True, backtrace_by=None):
Expand Down Expand Up @@ -150,7 +152,7 @@ def find_path(self, start, end, grid):
self.runs = 0 # count number of iterations
start.opened = True

open_list = [start]
open_list = SimpleHeap(start, grid)

while len(open_list) > 0:
self.runs += 1
Expand Down
6 changes: 3 additions & 3 deletions pathfinding/finder/msp.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from collections import deque, namedtuple
from ..core import heuristic
from ..finder.finder import Finder
from ..core.heap import SimpleHeap


class MinimumSpanningTree(Finder):
Expand Down Expand Up @@ -31,14 +32,13 @@ def itertree(self, grid, start):

start.opened = True

open_list = [start]
open_list = SimpleHeap(start, grid)

while len(open_list) > 0:
self.runs += 1
self.keep_running()

node = heapq.nsmallest(1, open_list)[0]
open_list.remove(node)
node = open_list.pop_node()
node.closed = True
yield node

Expand Down
6 changes: 6 additions & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
[pytest]
pythonpath = .
log_cli = 1
log_cli_level = INFO
log_cli_format = %(asctime)s.%(msecs)03d [%(levelname)8s] (%(filename)s:%(lineno)s) %(message)s
log_cli_date_format = %Y-%m-%d %H:%M:%S
5 changes: 5 additions & 0 deletions test/path_test_scenarios.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
[
{
"name": "s1",
"startX": 0,
"startY": 0,
"endX": 1,
Expand All @@ -10,6 +11,7 @@
"expectedDiagonalLength": 2
},
{
"name": "s2",
"startX": 1,
"startY": 1,
"endX": 4,
Expand All @@ -24,6 +26,7 @@
"expectedDiagonalLength": 5
},
{
"name": "s3",
"startX": 0,
"startY": 3,
"endX": 3,
Expand All @@ -38,6 +41,7 @@
"expectedDiagonalLength": 6
},
{
"name": "s4",
"startX": 4,
"startY": 4,
"endX": 19,
Expand Down Expand Up @@ -66,6 +70,7 @@
"expectedDiagonalLength": 16
},
{
"name": "s5",
"startX": 0,
"startY": 0,
"endX": 4,
Expand Down
6 changes: 6 additions & 0 deletions test/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pytest==7.4.4
snakeviz==2.2.0
pytest-profiling==1.7.0
numpy==1.26.3
pandas==2.1.4
matplotlib==3.8.2
25 changes: 25 additions & 0 deletions test/test_heap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from pathfinding.core.heap import SimpleHeap
from pathfinding.core.grid import Grid

def test_heap():
grid = Grid(width=10, height=10)
start = grid.node(0, 0)
open_list = SimpleHeap(start, grid)

# Test pop
assert open_list.pop_node() == start
assert len(open_list) == 0

# Test push
open_list.push_node(grid.node(1, 1))
open_list.push_node(grid.node(1, 2))
open_list.push_node(grid.node(1, 3))

# Test removal and pop
assert len(open_list) == 3
open_list.remove_node(grid.node(1, 2), 0)
assert len(open_list) == 3

assert open_list.pop_node() == grid.node(1, 1)
assert open_list.pop_node() == grid.node(1, 3)
assert len(open_list) == 0
29 changes: 29 additions & 0 deletions test/test_performance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import numpy
from pathfinding.core.diagonal_movement import DiagonalMovement
from pathfinding.core.grid import Grid
from pathfinding.finder.a_star import AStarFinder


def _add_block(g: numpy.ndarray, x: int, y: int, padding: int):
for i in range(x - padding, x + padding):
for j in range(y - padding, y + padding):
g[j][i] = 0

def test_a_star():
"""Test performance."""
# Get a 500 x 500 grid
grid = numpy.ones((500, 500), numpy.int32)

# Add a block at the center
_add_block(grid, 250, 250, 50)

finder_grid = Grid(matrix=grid)
start = finder_grid.node(0, 0)
end = finder_grid.node(400, 400)

finder = AStarFinder(diagonal_movement=DiagonalMovement.never)
path, runs = finder.find_path(start, end, finder_grid)

assert path[0] == start
assert path[-1] == end
assert len(path) == 801

0 comments on commit a091843

Please sign in to comment.