diff --git a/CHANGELOG.rst b/CHANGELOG.rst index a7a7c85..796ac48 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -15,6 +15,10 @@ Release Notes using input catalogs that failed to align in the expanded reference catalog. [#195] +- Reduce memory & compute needed by _xy_2dhist by pruning distant + pairs with a kdtree. This is a purely internal change that does not + affect the results of the algorithm. [#196] + 0.8.5 (30-November-2023) ======================== diff --git a/tweakwcs/matchutils.py b/tweakwcs/matchutils.py index 19bc7d0..d846494 100644 --- a/tweakwcs/matchutils.py +++ b/tweakwcs/matchutils.py @@ -16,6 +16,7 @@ from astropy.utils.exceptions import AstropyDeprecationWarning from stsci.stimage import xyxymatch +from scipy import spatial from . import __version__ # noqa: F401 @@ -277,10 +278,19 @@ def __call__(self, refcat, imcat, tp_pscale=1.0, tp_units=None, **kwargs): def _xy_2dhist(imgxy, refxy, r): - # This code replaces the C version (arrxyzero) from carrutils.c - # It is about 5-8 times slower than the C version. - dx = np.subtract.outer(imgxy[:, 0], refxy[:, 0]).ravel() - dy = np.subtract.outer(imgxy[:, 1], refxy[:, 1]).ravel() + # trim to only pairs within (r+0.5) * np.sqrt(2) using a kdtree + # to avoid computing differences for many widely separated pairs. + kdtree = spatial.KDTree(refxy) + neighbors = kdtree.query_ball_point(imgxy, (r + 0.5) * np.sqrt(2)) + lens = [len(n) for n in neighbors] + mi = np.repeat(np.arange(imgxy.shape[0]), lens) + if len(mi) > 0: + mr = np.concatenate([n for n in neighbors if len(n) > 0]) + else: + mr = mi + + dx = imgxy[mi, 0] - refxy[mr, 0] + dy = imgxy[mi, 1] - refxy[mr, 1] idx = np.where((dx < r + 0.5) & (dx >= -r - 0.5) & (dy < r + 0.5) & (dy >= -r - 0.5)) r = int(np.ceil(r))