diff --git a/include/gridpp.h b/include/gridpp.h index 59f8489e..b082bece 100644 --- a/include/gridpp.h +++ b/include/gridpp.h @@ -1236,6 +1236,10 @@ namespace gridpp { private: point p; }; + // Checks that a lat-coordinate is valid (based on the coordinate type) + bool check_lat(float lat) const; + // Checks that a lon-coordinate is valid (based on the coordinate type) + bool check_lon(float lon) const; }; /** Represents a vector of locations and their metadata */ diff --git a/src/api/kdtree.cpp b/src/api/kdtree.cpp index 4691e8b6..916d60b3 100644 --- a/src/api/kdtree.cpp +++ b/src/api/kdtree.cpp @@ -1,4 +1,5 @@ #include "gridpp.h" +#include using namespace gridpp; @@ -117,6 +118,11 @@ bool gridpp::KDTree::convert_coordinates(const vec& lats, const vec& lons, vec& } bool gridpp::KDTree::convert_coordinates(float lat, float lon, float& x_coord, float& y_coord, float& z_coord) const { + if(!check_lat(lat) || !check_lon(lon)) { + std::stringstream ss; + ss << "Invalid coords: " << lat << "," << lon << std::endl; + throw std::invalid_argument(ss.str()); + } if(mType == gridpp::Cartesian) { x_coord = lon; y_coord = lat; @@ -246,3 +252,12 @@ bool gridpp::KDTree::is_not_equal::operator()(value const& v) const { float z0 = v.first.get<2>(); return p.get<0>() != x0 || p.get<1>() != y0 || p.get<2>() != z0; } + +bool gridpp::KDTree::check_lat(float lat) const { + if(get_coordinate_type() == gridpp::Cartesian) + return gridpp::is_valid(lat); + return gridpp::is_valid(lat) && (lat >= -90.001) && (lat <= 90.001); +}; +bool gridpp::KDTree::check_lon(float lon) const { + return gridpp::is_valid(lon); +} diff --git a/tests/kdtree_test.py b/tests/kdtree_test.py index e7d74bfd..7e83512e 100644 --- a/tests/kdtree_test.py +++ b/tests/kdtree_test.py @@ -89,6 +89,44 @@ def test_empty_constructor(self): tree = gridpp.KDTree() self.assertEqual(tree.get_coordinate_type(), gridpp.Geodetic) + def test_invalid_coords(self): + lats = [91, -91, np.nan, 0] + lons = [0, 0, 0, np.nan] + for i in range(len(lats)): + curr_lats = [lats[i]] + curr_lons = [lons[i]] + with self.subTest(lat=lats[i], lon=lons[i]): + with self.assertRaises(ValueError) as e: + tree = gridpp.KDTree(curr_lats, curr_lons, gridpp.Geodetic) + + def test_valid_coords(self): + lats = [90.000001, -90.0000001] + lons = [0, 0] + for i in range(len(lats)): + curr_lats = [0, lats[i]] + curr_lons = [0, lons[i]] + with self.subTest(lat=lats[i], lon=lons[i]): + tree = gridpp.KDTree(curr_lats, curr_lons, gridpp.Geodetic) + I = tree.get_nearest_neighbour(0, 0) + self.assertEqual(I, 0) + + def test_wrap_lon(self): + """Check that longitudes outside [-180, 180] are correctly handled""" + lons = [-360, 0, 360] + for i in range(len(lons)): + curr_lats = [0] + curr_lons = [lons[i]] + with self.subTest(lon=lons[i]): + tree = gridpp.KDTree(curr_lats, curr_lons, gridpp.Geodetic) + I, dist = tree.get_neighbours_with_distance(0, 0, 1e9) + self.assertEqual(I[0], 0) + self.assertAlmostEqual(dist[0], 0) + + I, dist = tree.get_neighbours_with_distance(0, 180, 1e9) + self.assertEqual(I[0], 0) + diameter_of_earth = 12756274.0 + self.assertAlmostEqual(dist[0], diameter_of_earth) + if __name__ == '__main__': unittest.main()