From 12249729bb1f4ce02cd78663c35c88c58766765f Mon Sep 17 00:00:00 2001 From: Shane Elipot Date: Sat, 18 May 2024 12:29:52 -0400 Subject: [PATCH] chore: Convert xarray variables to numpy arrays for consistent operations in kinematics.velocity_from_position and kinematics.position_from_velocity (#443) --- clouddrift/kinematics.py | 9 ++++++++ tests/kinematics_tests.py | 46 ++++++++++++++++++++++++++++++++++++--- 2 files changed, 52 insertions(+), 3 deletions(-) diff --git a/clouddrift/kinematics.py b/clouddrift/kinematics.py index 3292d110..18671abd 100644 --- a/clouddrift/kinematics.py +++ b/clouddrift/kinematics.py @@ -659,6 +659,15 @@ def velocity_from_position( y_ = np.swapaxes(y, time_axis, -1) time_ = np.swapaxes(time, time_axis, -1) + # Convert to numpy arrays to insure consistent operations + if isinstance(x_, xr.DataArray): + x_ = x_.to_numpy() + if isinstance(y_, xr.DataArray): + y_ = y_.to_numpy() + if isinstance(time_, xr.DataArray): + time_ = time_.to_numpy() + + # Initialize arrays for dx, dy, and dt dx = np.empty(x_.shape) dy = np.empty(y_.shape) dt = np.empty(time_.shape) diff --git a/tests/kinematics_tests.py b/tests/kinematics_tests.py index 193abc43..6a2d9257 100644 --- a/tests/kinematics_tests.py +++ b/tests/kinematics_tests.py @@ -275,7 +275,7 @@ def test_velocity_position_roundtrip_centered(self): self.assertTrue(np.allclose(lon, self.lon, atol=1e-2)) self.assertTrue(np.allclose(lat, self.lat, atol=1e-2)) - def test_works_with_xarray(self): + def test_works_with_xarray_forward(self): lon, lat = position_from_velocity( xr.DataArray(data=self.uf), xr.DataArray(data=self.vf), @@ -287,6 +287,30 @@ def test_works_with_xarray(self): self.assertTrue(np.allclose(lon, self.lon)) self.assertTrue(np.allclose(lat, self.lat)) + def test_works_with_xarray_backward(self): + lon, lat = position_from_velocity( + xr.DataArray(data=self.ub), + xr.DataArray(data=self.vb), + xr.DataArray(data=self.time), + self.lon[0], + self.lat[0], + integration_scheme="backward", + ) + self.assertTrue(np.allclose(lon, self.lon)) + self.assertTrue(np.allclose(lat, self.lat)) + + def test_works_with_xarray_centered(self): + lon, lat = position_from_velocity( + xr.DataArray(data=self.uc), + xr.DataArray(data=self.vc), + xr.DataArray(data=self.time), + self.lon[0], + self.lat[0], + integration_scheme="centered", + ) + self.assertTrue(np.allclose(lon, self.lon, atol=1e-2)) + self.assertTrue(np.allclose(lat, self.lat, atol=1e-2)) + def test_works_with_2d_array(self): uf = np.reshape(np.tile(self.uf, 4), (4, self.uf.size)) vf = np.reshape(np.tile(self.vf, 4), (4, self.vf.size)) @@ -388,14 +412,30 @@ def test_result_value(self): self.assertTrue(np.all(np.isclose(self.ub, u_expected))) self.assertTrue(np.all(np.isclose(self.uc, u_expected))) - def test_works_with_xarray(self): + def test_works_with_xarray_forward(self): lon = xr.DataArray(data=self.lon, coords={"time": self.time}) lat = xr.DataArray(data=self.lat, coords={"time": self.time}) time = xr.DataArray(data=self.time, coords={"time": self.time}) - uf, vf = velocity_from_position(lon, lat, time) + uf, vf = velocity_from_position(lon, lat, time, difference_scheme="forward") self.assertTrue(np.all(uf == self.uf)) self.assertTrue(np.all(vf == self.vf)) + def test_works_with_xarray_backward(self): + lon = xr.DataArray(data=self.lon, coords={"time": self.time}) + lat = xr.DataArray(data=self.lat, coords={"time": self.time}) + time = xr.DataArray(data=self.time, coords={"time": self.time}) + ub, vb = velocity_from_position(lon, lat, time, difference_scheme="backward") + self.assertTrue(np.all(ub == self.ub)) + self.assertTrue(np.all(vb == self.vb)) + + def test_works_with_xarray_centered(self): + lon = xr.DataArray(data=self.lon, coords={"time": self.time}) + lat = xr.DataArray(data=self.lat, coords={"time": self.time}) + time = xr.DataArray(data=self.time, coords={"time": self.time}) + uc, vc = velocity_from_position(lon, lat, time, difference_scheme="centered") + self.assertTrue(np.all(uc == self.uc)) + self.assertTrue(np.all(vc == self.vc)) + def test_works_with_2d_array(self): lon = np.reshape(np.tile(self.lon, 4), (4, self.lon.size)) lat = np.reshape(np.tile(self.lat, 4), (4, self.lat.size))