Skip to content

Commit

Permalink
Merge pull request spacetelescope#196 from schlafly/kdtree
Browse files Browse the repository at this point in the history
Use KDTree to reduce memory & compute in _xy_2dhist
  • Loading branch information
mcara authored Jan 8, 2024
2 parents f7ad5a1 + a4e6bb9 commit d71a5dd
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 4 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ Release Notes
using input catalogs that failed to align in the expanded reference
catalog. [#195]

- Reduce memory & compute needed by _xy_2dhist by pruning distant
pairs with a kdtree. This is a purely internal change that does not
affect the results of the algorithm. [#196]


0.8.5 (30-November-2023)
========================
Expand Down
18 changes: 14 additions & 4 deletions tweakwcs/matchutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from astropy.utils.exceptions import AstropyDeprecationWarning

from stsci.stimage import xyxymatch
from scipy import spatial

from . import __version__ # noqa: F401

Expand Down Expand Up @@ -277,10 +278,19 @@ def __call__(self, refcat, imcat, tp_pscale=1.0, tp_units=None, **kwargs):


def _xy_2dhist(imgxy, refxy, r):
# This code replaces the C version (arrxyzero) from carrutils.c
# It is about 5-8 times slower than the C version.
dx = np.subtract.outer(imgxy[:, 0], refxy[:, 0]).ravel()
dy = np.subtract.outer(imgxy[:, 1], refxy[:, 1]).ravel()
# trim to only pairs within (r+0.5) * np.sqrt(2) using a kdtree
# to avoid computing differences for many widely separated pairs.
kdtree = spatial.KDTree(refxy)
neighbors = kdtree.query_ball_point(imgxy, (r + 0.5) * np.sqrt(2))
lens = [len(n) for n in neighbors]
mi = np.repeat(np.arange(imgxy.shape[0]), lens)
if len(mi) > 0:
mr = np.concatenate([n for n in neighbors if len(n) > 0])
else:
mr = mi

dx = imgxy[mi, 0] - refxy[mr, 0]
dy = imgxy[mi, 1] - refxy[mr, 1]
idx = np.where((dx < r + 0.5) & (dx >= -r - 0.5) &
(dy < r + 0.5) & (dy >= -r - 0.5))
r = int(np.ceil(r))
Expand Down

0 comments on commit d71a5dd

Please sign in to comment.