Skip to content

Commit

Permalink
Fixing merge conflict
Browse files Browse the repository at this point in the history
  • Loading branch information
dfm committed Feb 28, 2024
2 parents b23ada6 + 0f8b68a commit 7d171c9
Show file tree
Hide file tree
Showing 4 changed files with 221 additions and 85 deletions.
2 changes: 1 addition & 1 deletion src/jaxoplanet/light_curves/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""This module contains models for computing and transforming light curve models"""

from jaxoplanet.light_curves import exposure_time as exposure_time
from jaxoplanet.light_curves import transforms as transforms
from jaxoplanet.light_curves.limb_dark import (
limb_dark_light_curve as limb_dark_light_curve,
)
84 changes: 0 additions & 84 deletions src/jaxoplanet/light_curves/exposure_time.py

This file was deleted.

196 changes: 196 additions & 0 deletions src/jaxoplanet/light_curves/transforms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
__all__ = ["integrate", "interpolate"]

from functools import wraps
from typing import Any, Optional, Union

import jax
import jax.numpy as jnp
import jpu.numpy as jnpu
from jpu.core import is_quantity

from jaxoplanet import units
from jaxoplanet.light_curves.types import LightCurveFunc
from jaxoplanet.light_curves.utils import vectorize
from jaxoplanet.types import Array, Quantity
from jaxoplanet.units import unit_registry as ureg

try:
from jax.extend import linear_util as lu
except ImportError:
from jax import linear_util as lu # type: ignore


@units.quantity_input(exposure_time=ureg.d)
def integrate(
func: LightCurveFunc,
exposure_time: Optional[Quantity] = None,
order: int = 0,
num_samples: int = 7,
) -> LightCurveFunc:
"""Transform a light curve function to apply exposure time integration
This transformation applies a fixed stencil numerical integration scheme to the input
function ``func`` to convolve the light curve with a top hat exposure time centered
on the input time, with a full width of ``exposure_time``.
The order of the integration scheme is set using the ``order`` parameter which must
be ``0``, ``1``, or ``2``. The default (``0``) uses the "resampling" scheme discussed
by `Kipping (2010) <https://arxiv.org/abs/1004.3741>`_. The higher order schemes
``1`` and ``2`` apply the trapezoid and Simpson's rules respectively, but won't
necessarily provide higher accuracy results because of discontinuities at the
contact points.
In practice, the parameter ``num_samples`` which sets the number of function
evaluations per integral has the most significant effect on the accuracy of this
integral, trading off against higher computational cost.
Args:
func: A light curve function which takes a time ``Quantity`` as the first
argument
exposure_time (Quantity): The exposure time (in days, by default)
order (int): The order of the integration scheme as discussed above
num_samples (int): The number of function evaluations made per integral,
controlling the accuracy of the numerics
Returns:
A new light curve function with the same signature as ``func``, computing the
exposure time integrated flux
"""
if exposure_time is None:
return func

if jnpu.ndim(exposure_time) != 0:
raise ValueError(
"The exposure time passed to 'integrate_exposure_time' has shape "
f"{jnpu.shape(exposure_time)}, but a scalar was expected; "
"To use exposure time integration with different exposures at different "
"times, manually 'vmap' or 'vectorize' the function"
)

# Ensure 'num_samples' is an odd number
num_samples = int(num_samples)
num_samples += 1 - num_samples % 2
stencil = jnp.ones(num_samples)

# Construct exposure time integration stencil
if order == 0:
dt = jnp.linspace(-0.5, 0.5, 2 * num_samples + 1)[1:-1:2]
elif order == 1:
dt = jnp.linspace(-0.5, 0.5, num_samples)
stencil = 2 * stencil
stencil = stencil.at[0].set(1)
stencil = stencil.at[-1].set(1)
elif order == 2:
dt = jnp.linspace(-0.5, 0.5, num_samples)
stencil = stencil.at[1:-1:2].set(4)
stencil = stencil.at[2:-1:2].set(2)
else:
raise ValueError(
"The parameter 'order' in 'integrate_exposure_time' must be 0, 1, or 2"
)
dt = dt * exposure_time
stencil /= jnp.sum(stencil)

@wraps(func)
@units.quantity_input(time=ureg.d)
@vectorize
def wrapped(time: Quantity, *args: Any, **kwargs: Any) -> Union[Array, Quantity]:
if jnpu.ndim(time) != 0:
raise ValueError(
"The time passed to 'integrate_exposure_time' has shape "
f"{jnpu.shape(time)}, but a scalar was expected; "
"this shouldn't typically happen so please open an issue "
"on GitHub demonstrating the problem"
)

f = lu.wrap_init(jax.vmap(func, in_axes=(0,) + (None,) * len(args)))
f = apply_exposure_time_integration(f, stencil, dt) # type: ignore
return f.call_wrapped(time, args, kwargs) # type: ignore

return wrapped


@lu.transformation # type: ignore
def apply_exposure_time_integration(stencil, dt, time, args, kwargs):
result = yield (time + dt,) + args, kwargs
yield jnpu.dot(stencil, result)


@units.quantity_input(period=ureg.day, time_transit=ureg.day, duration=ureg.day)
def interpolate(
func: LightCurveFunc,
*,
period: Quantity,
time_transit: Quantity,
num_samples: int,
duration: Optional[Quantity] = None,
args: tuple[Any, ...] = (),
kwargs: Optional[dict[str, Any]] = None,
) -> LightCurveFunc:
"""Transform a light curve function to pre-compute the model on a grid
Sometimes it can be useful to precompute the light curve on a grid near a transit,
and then interpolate those computations to the required phases when computing the
full model. This can speed things up a lot when you have many transits, or a lot of
out of transit data. This transform uses linear interpolation.
.. note:: Unlike some other transforms, this function requires that any upstream
``*args`` and ``**kwargs`` be passed directly to the transform, rather than when
calling the transformed function. This is necessary because the model is
pre-computed when it is tranformed.
Args:
func: A light curve function which takes a time ``Quantity`` as the first
argument
period (Quantity): The period of the orbit. Used to wrap the input times into the
domain of the pre-computed model
time_transit (Quantity): The transit time of the orbit. Used to wrap the input
times into the domain of the pre-computed model
duration (Quantity): The duration centered on the transit to pre-compute. By
default, the full period will be evaluated
num_samples (int): The number of points in the time grid used for pre-computation
args (tuple): Any extra positional arguments that should be passed to ``func``
kwargs (dict): Any extra keyword arguments that should be passed to ``func``
Returns:
A new light curve function with the same signature as ``func``, computing the
flux by interpolating a pre-computed model
"""

kwargs = kwargs or {}
if duration is None:
duration = period
time_grid = time_transit + duration * jnp.linspace(-0.5, 0.5, num_samples)
flux_grid = func(time_grid, *args, **kwargs)

if is_quantity(flux_grid):
flux_magnitude = flux_grid.magnitude
flux_units = flux_grid.units
else:
flux_magnitude = flux_grid
flux_units = None

@wraps(func)
@units.quantity_input(time=ureg.d)
@vectorize
def wrapped(time: Quantity, *args: Any, **kwargs: Any) -> Union[Array, Quantity]:
del args, kwargs
time_wrapped = (
jnpu.mod(time - time_transit + 0.5 * period, period)
+ 0.5 * period
+ time_transit
)
flux = jnp.interp(
time_wrapped.magnitude,
time_grid.magnitude,
flux_magnitude,
left=flux_magnitude[0],
right=flux_magnitude[-1],
period=period.magnitude,
)
if flux_units is None:
return flux
else:
return flux * flux_units

return wrapped
24 changes: 24 additions & 0 deletions tests/light_curves/transforms_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import jax.numpy as jnp
import pytest

from jaxoplanet.light_curves import transforms
from jaxoplanet.test_utils import assert_allclose
from jaxoplanet.units import quantity_input, unit_registry as ureg


@quantity_input(time=ureg.day)
def lc_func1(time):
return jnp.ones_like(time.magnitude)


@quantity_input(time=ureg.day)
def lc_func2(time):
return jnp.stack([0.5 * time.magnitude + 0.1, -1.5 * time.magnitude + 3.6], axis=-1)


@pytest.mark.parametrize("order", [0, 1, 2])
@pytest.mark.parametrize("lc_func", [lc_func1, lc_func2])
def test_integrate_invariant(order, lc_func):
time = jnp.linspace(0, 10, 50)
int_func = transforms.integrate(lc_func, exposure_time=0.1, order=order)
assert_allclose(int_func(time), lc_func(time))

0 comments on commit 7d171c9

Please sign in to comment.