diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 80cdf894..2ccbcbad 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,19 +1,43 @@ -## Contributing to exoplanet-jax +# Contributor Guide -### Reporting issues +Thank you for your interest in improving this project. This project is +open-source and welcomes contributions in the form of bug reports, feature +requests, and pull requests. -If you find a bug or other unexpected behavior while using `exoplanet-jax`, -open an issue on the [GitHub repository](https://github.com/exoplanet-dev/exoplanet-jax/issues) -and we will try to respond and (hopefully) solve the problem in a timely manner. -Similarly, if you have a feature request or question about the library, the best -place to post those is currently on GitHub as an issue, but that is likely -change if the user community grows. If you report an issue, please give the -details needed to reproduce the problem (version of exoplanet, its dependencies, -and your platform) and a small standalone piece of code that demonstrates the -problem clearly. +Here is a list of important resources for contributors: -### Contributing code +- [Source Code](https://github.com/exoplanet-dev/jaxoplanet) +- [Documentation](https://jax.exoplanet.codes) +- [Issue Tracker](https://github.com/exoplanet-dev/jaxoplanet/issues) -We welcome contributions to the codebase of all scales from typo fixes to new features, -but if you would like to add a substantial feature, it would be a good idea to first -open an issue that describes your plan so that we can discuss in advance. +## How to report a bug + +Report bugs on the [Issue Tracker](https://github.com/exoplanet-dev/jaxoplanet/issues). + +When filing an issue, make sure to answer these questions: + +- Which operating system and Python version are you using? +- Which version of this project are you using? +- What did you do? +- What did you expect to see? +- What did you see instead? + +The best way to get your bug fixed is to provide a test case, and/or steps to +reproduce the issue. In particular, please include a [Minimal, Reproducible +Example](https://stackoverflow.com/help/minimal-reproducible-example). + +## How to request a feature + +Feel free to request features on the [Issue +Tracker](https://github.com/exoplanet-dev/jaxoplanet/issues). + +## How to test the project + +```bash +python -m pip install nox +python -m nox -s test +``` + +## How to submit changes + +Open a [Pull Request](https://github.com/exoplanet-dev/jaxoplanet/pulls). diff --git a/README.md b/README.md index 0cd159f1..09a49720 100644 --- a/README.md +++ b/README.md @@ -4,18 +4,23 @@ _Astronomical time series analysis with JAX_ --- -_jaxoplanet_ is a [functional-programming](https://en.wikipedia.org/wiki/Functional_programming)-forward implementation of many features from the +_jaxoplanet_ is a +[functional-programming](https://en.wikipedia.org/wiki/Functional_programming)-forward +implementation of many features from the [exoplanet](https://docs.exoplanet.codes/en/latest/) and [starry](https://starry.readthedocs.io/en/latest/) packages built on top of [JAX](https://jax.readthedocs.io/en/latest/). *jaxoplanet* includes fast and robust implementations of many exoplanet-specific operations, including solving Kepler's equation, and computing limb-darkened -light curves. Since *jaxoplanet* is built on top of JAX it has have first-class -support for hardware acceleration using GPUs and TPUs, and it also integrates seamlessly -with modeling tools like [NumPyro](https://numpyro.readthedocs.io/en/latest/), -and [Flax](https://flax.readthedocs.io/en/latest/). +light curves. Since *jaxoplanet* is built on top of JAX it has first-class +support for hardware acceleration using GPUs and TPUs, and it also integrates +seamlessly with modeling tools like +[NumPyro](https://numpyro.readthedocs.io/en/latest/), and +[Flax](https://flax.readthedocs.io/en/latest/). +**For the most complete documentation, check out the documentation page at +[jax.exoplanet.codes](https://jax.exoplanet.codes).** ## Installation @@ -33,14 +38,17 @@ Then install _jaxoplanet_ with: python -m pip install jaxoplanet ``` -If you run into issues with installing *jaxoplanet*, take a look at the **Troubleshooting on *jaxoplanet* installation** page. +If you run into issues with installing *jaxoplanet*, take a look at [the +installation instructions](https://jax.exoplanet.codes/en/latest/install). ## Quick start ## Attribution -While we don't yet have a citation for *jaxoplanet*, please reference the GitHub repository if you find -this code useful in your research. The BibTeX entry for the repo is: + +While we don't yet have a citation for *jaxoplanet*, please reference the GitHub +repository if you find this code useful in your research. The BibTeX entry for +the repo is: ``` @software{jaxoplanet, @@ -56,4 +64,5 @@ this code useful in your research. The BibTeX entry for the repo is: ## License Copyright (c) 2021-2024 Simons Foundation, Inc. -*jaxoplanet* is free software made available under the MIT License. For details see the LICENSE file. +*jaxoplanet* is free software made available under the MIT License. For details +see the LICENSE file. diff --git a/docs/code-of-conduct.md b/docs/code-of-conduct.md new file mode 100644 index 00000000..58fd373b --- /dev/null +++ b/docs/code-of-conduct.md @@ -0,0 +1,3 @@ +```{include} ../CODE_OF_CONDUCT.md + +``` diff --git a/docs/contributing.md b/docs/contributing.md new file mode 100644 index 00000000..8307ac45 --- /dev/null +++ b/docs/contributing.md @@ -0,0 +1,12 @@ +(contributing)= + +```{include} ../CONTRIBUTING.md + +``` + +```{toctree} +:maxdepth: 1 +:hidden: + +code-of-conduct +``` diff --git a/docs/guide.md b/docs/guide.md new file mode 100644 index 00000000..6e35cf20 --- /dev/null +++ b/docs/guide.md @@ -0,0 +1,15 @@ +(guide)= + +# User Guide + +The following pages give some background on the context within which `jaxoplanet` +exists, as well as detailed installation and API documentation. Click through +for all the details, or head over to the {ref}`tutorials` for a more hands-on +experience. + +```{toctree} +:maxdepth: 1 + +install +troubleshooting +``` diff --git a/docs/index.md b/docs/index.md index 98211b43..85d139e2 100644 --- a/docs/index.md +++ b/docs/index.md @@ -8,6 +8,7 @@ maxdepth: 1 --- -installation +guide tutorials +contributing ``` diff --git a/docs/install.md b/docs/install.md new file mode 100644 index 00000000..47634b1d --- /dev/null +++ b/docs/install.md @@ -0,0 +1,39 @@ +(install)= + +# Installation Guide + +`jaxoplanet` is built on top of [`jax`](https://github.com/google/jax) so that's the +primary dependency that you'll need. All of the methods below will install any +required dependencies, but if you want to take advantage of your GPU, that might +take a little more setup. `jaxoplanet` doesn't have any GPU-specific code, so it +should be enough to just [follow the installation instructions for CUDA support +in the `jax` README](https://github.com/google/jax/#installation). + +## Using pip + +The easiest way to install the most recent stable version of `jaxoplanet` is +with [pip](https://pip.pypa.io): + +```bash +python -m pip install jaxoplanet +``` + +## From source + +Alternatively, you can get the source: + +```bash +git clone https://github.com/exoplanet-dev/jaxoplanet.git +cd jaxoplanet +python -m pip install -e . +``` + +## Tests + +If you installed from source, you can run the unit tests. From the root of the +source directory, run: + +```bash +python -m pip install nox +python -m nox -s test +``` diff --git a/docs/installation.md b/docs/installation.md deleted file mode 100644 index 6a71d4f5..00000000 --- a/docs/installation.md +++ /dev/null @@ -1,65 +0,0 @@ -# Troubleshooting on `jaxoplanet` installation - -Follow these steps to install `jax` and `jaxoplanet` on your system, with special instructions for Mac M1 chip or newer. - -## Step 1: Check Your Python Architecture - -First, verify your Python architecture to determine if you're on an ARM or Intel-based chip. - -```bash -import platform -platform.machine() -``` - -- If the output is "arm64", proceed with the following steps. -- If the output is "x86-64", you're running an Intel emulator on an Apple silicon chip and need to switch to arm64. - -For the best performance of `jax`, it will be better to install `python` under the arm64 archecture. - -Install Miniforge for ARM64. - -Download Miniforge from the official Conda Forge page (https://conda-forge.org/miniforge/). -Run the installer script: -```bash -bash Mambaforge-23.11.0-0-MacOSX-arm64.sh -``` - -Restart the terminal. - -## Step 2: Install `jax` and `jaxoplanet` - -Create a New Environment (Optional) in miniforge. - -It's a good practice to create a new environment for your projects: - -```bash -conda create --name jaxoplanet -conda activate jaxoplanet -conda install pip -python -m pip install "jax[cpu]" - ``` - -You have two options to install jaxoplanet: - -Option 1: Direct Installation via `pip` - -```bash -python -m pip install jaxoplanet -``` - -Option 2: Install from Source -If you prefer to install jaxoplanet from source: - -Clone the jaxoplanet repository: -```bash -git clone https://github.com/exoplanet-dev/jaxoplanet.git -cd jaxoplanet -pip install . -``` - -You could install `jaxoplanet` and its depedencies with -```bash -pip install ".[docs]" -``` - -Follow these instructions to successfully install jax and jaxoplanet on your system. diff --git a/docs/troubleshooting.md b/docs/troubleshooting.md new file mode 100644 index 00000000..ae178981 --- /dev/null +++ b/docs/troubleshooting.md @@ -0,0 +1,80 @@ +(troubleshooting)= + +# Troubleshooting + +This page includes some tips for troubleshooting issues that you might run into +when using `jaxoplanet`. This is a work-in-progress, so if you don't see your +issue listed here, feel free to open an issue on the [GitHub repository issue +tracker](https://github.com/exoplanet-dev/jaxoplanet/issues). + +## NaNs and infinities + +It's not that uncommon to hit NaNs or infinities when using `jax` and +`jaxoplanet`. This is often caused by numerical precision issues, and this can +be exacerbated by the fact that, by default, `jax` disables double precision +calculations. You can enable double precision [a few different ways as described +in the `jax` +docs](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision), +and the way we do it in these docs is to add the following, when necessary: + +```python +import jax + +jax.config.update("jax_enable_x64", True) +``` + +If enabling double precision doesn't do the trick, this often means that there's +an issue with the parameter or modeling choices that you're making. `jax`'s +["debug NaNs" mode](https://jax.readthedocs.io/en/latest/debugging/flags.html) +can help diagnose issues here. + +## Installation issues on ARM Macs + +Follow these steps to install `jax` and `jaxoplanet` on your system, with +special instructions for Mac M1 chip or newer. + +### Step 1: Check Your Python Architecture + +First, verify your Python architecture to determine if you're on an ARM or +Intel-based chip. + +```bash +import platform +platform.machine() +``` + +- If the output is "arm64", proceed with the following steps. +- If the output is "x86-64", you're running an Intel emulator on an Apple + silicon chip and need to switch to arm64. + +For the best performance of `jax`, it will be better to install Python under +the arm64 architecture. + +**Install Miniforge for ARM64.** + +Download Miniforge from the official Conda Forge page (https://conda-forge.org/miniforge/). +Run the installer script: + +```bash +bash Mambaforge-23.11.0-0-MacOSX-arm64.sh +``` + +Restart the terminal. + +### Step 2: Install `jax` and `jaxoplanet` + +Create a New Environment (Optional) in miniforge. + +It's a good practice to create a new environment for your projects: + +```bash +conda create --name jaxoplanet +conda activate jaxoplanet +conda install pip + +conda install jax +# or +python -m pip install "jax[cpu]" +``` + +Then install `jaxoplanet` using the instructions in {ref}`install`. diff --git a/docs/tutorials/rv.ipynb b/docs/tutorials/rv.ipynb index 1bab3f02..56298454 100644 --- a/docs/tutorials/rv.ipynb +++ b/docs/tutorials/rv.ipynb @@ -19,10 +19,10 @@ }, "outputs": [], "source": [ - "# Double precision with JAX\n", "import jax\n", - "from tqdm.autonotebook import tqdm\n", + "import numpyro\n", "\n", + "numpyro.set_host_device_count(2)\n", "jax.config.update(\"jax_enable_x64\", True)" ] }, @@ -133,10 +133,6 @@ "metadata": {}, "outputs": [], "source": [ - "import numpyro\n", - "\n", - "numpyro.set_host_device_count(2)\n", - "\n", "from numpyro import distributions as dist, infer\n", "\n", "\n", @@ -219,7 +215,7 @@ "\n", "plot_data()\n", "plt.plot(over_time, posterior_rvs.mean(0), \"C0\")\n", - "plt.fill_between(\n", + "_ = plt.fill_between(\n", " over_time,\n", " *np.percentile(posterior_rvs, [16, 84], axis=0),\n", " alpha=0.3,\n", @@ -242,7 +238,7 @@ "source": [ "import corner\n", "\n", - "corner.corner(\n", + "_ = corner.corner(\n", " samples,\n", " var_names=[\"mass\", \"period\"],\n", " truths=[truth[\"mass\"], truth[\"period\"]],\n", @@ -251,6 +247,13 @@ " title_fmt=\".4f\",\n", ")" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/src/jaxoplanet/light_curves/__init__.py b/src/jaxoplanet/light_curves/__init__.py index cac37729..bcbfb14f 100644 --- a/src/jaxoplanet/light_curves/__init__.py +++ b/src/jaxoplanet/light_curves/__init__.py @@ -1,6 +1,4 @@ -"""This module contains models for computing and transforming light curve models""" +__all__ = ["transforms", "limb_dark_light_curve"] from jaxoplanet.light_curves import transforms as transforms -from jaxoplanet.light_curves.limb_dark import ( - limb_dark_light_curve as limb_dark_light_curve, -) +from jaxoplanet.light_curves.limb_dark import light_curve as limb_dark_light_curve diff --git a/src/jaxoplanet/light_curves/limb_dark.py b/src/jaxoplanet/light_curves/limb_dark.py index 8c7587a3..f7994da7 100644 --- a/src/jaxoplanet/light_curves/limb_dark.py +++ b/src/jaxoplanet/light_curves/limb_dark.py @@ -1,3 +1,5 @@ +__all__ = ["light_curve"] + from functools import partial from typing import Callable @@ -12,9 +14,25 @@ from jaxoplanet.units import unit_registry as ureg -def limb_dark_light_curve( +def light_curve( orbit: LightCurveOrbit, *u: Array, order: int = 10 ) -> Callable[[Quantity], Array]: + """Compute the light curve for arbitrary polynomial limb darkening + + See `Agol et al. (2020) `_ and + :func:`jaxoplanet.core.limb_dark.light_curve` for more technical details. + + Args: + orbit (LightCurveOrbit): An orbit object that can be used to evaluate the + relative positions of the transiting body with respect to the light source. + u (Array): The coefficients of the polynomial limb darkening + order (int): The order of the numerical integration used by the backend; see + :func:`jaxoplanet.core.limb_dark.light_curve` + + Returns: + A function which takes the time in days as input and returns the light curve flux + """ + if u: ld_u = jnp.concatenate([jnp.atleast_1d(jnp.asarray(u_)) for u_ in u], axis=0) else: diff --git a/src/jaxoplanet/object_stack.py b/src/jaxoplanet/object_stack.py index 065577a1..d1108d61 100644 --- a/src/jaxoplanet/object_stack.py +++ b/src/jaxoplanet/object_stack.py @@ -1,3 +1,5 @@ +__all__ = ["ObjectStack"] + from collections.abc import Callable, Sequence from functools import wraps from typing import Any, Generic, Optional, TypeVar, Union diff --git a/src/jaxoplanet/proto.py b/src/jaxoplanet/proto.py index 354eacf2..d725ec74 100644 --- a/src/jaxoplanet/proto.py +++ b/src/jaxoplanet/proto.py @@ -1,57 +1,20 @@ -from typing import Optional, Protocol - -from jaxoplanet.types import Quantity +__all__ = ["LightCurveOrbit"] +from typing import Protocol -class LightCurveBody(Protocol): - @property - def shape(self) -> tuple[int, ...]: ... - - @property - def radius(self) -> Quantity: ... +from jaxoplanet.types import Quantity class LightCurveOrbit(Protocol): - @property - def shape(self) -> tuple[int, ...]: ... - - @property - def radius(self) -> Quantity: ... - - def relative_position(self, t: Quantity) -> tuple[Quantity, Quantity, Quantity]: ... + """An interface for orbits that can be used to compute light curves""" - @property - def central_radius(self) -> Quantity: ... - - -class Orbit(Protocol): @property def shape(self) -> tuple[int, ...]: ... @property def radius(self) -> Quantity: ... - def relative_position( - self, t: Quantity, parallax: Optional[Quantity] = None - ) -> tuple[Quantity, Quantity, Quantity]: ... + def relative_position(self, t: Quantity) -> tuple[Quantity, Quantity, Quantity]: ... @property def central_radius(self) -> Quantity: ... - - def position( - self, t: Quantity, parallax: Optional[Quantity] = None - ) -> tuple[Quantity, Quantity, Quantity]: ... - - def central_position( - self, t: Quantity, parallax: Optional[Quantity] = None - ) -> tuple[Quantity, Quantity, Quantity]: ... - - def velocity(self, t: Quantity) -> tuple[Quantity, Quantity, Quantity]: ... - - def central_velocity(self, t: Quantity) -> tuple[Quantity, Quantity, Quantity]: ... - - def relative_velocity(self, t: Quantity) -> tuple[Quantity, Quantity, Quantity]: ... - - def radial_velocity( - self, t: Quantity, semiamplitude: Optional[Quantity] = None - ) -> tuple[Quantity, Quantity, Quantity]: ... diff --git a/src/jaxoplanet/test_utils.py b/src/jaxoplanet/test_utils.py index 5dbcacba..d9cefccb 100644 --- a/src/jaxoplanet/test_utils.py +++ b/src/jaxoplanet/test_utils.py @@ -1,9 +1,18 @@ +__all__ = [ + "assert_allclose", + "assert_quantity_allclose", + "assert_quantity_pytree_allclose", +] + from jax import tree_util from jax._src.public_test_util import check_close from jpu.core import is_quantity def assert_allclose(calculated, expected, *args, **kwargs): + """ + Check that two floating point arrays are equal within a dtype-dependent tolerance + """ kwargs["rtol"] = kwargs.get( "rtol", { @@ -15,6 +24,12 @@ def assert_allclose(calculated, expected, *args, **kwargs): def assert_quantity_allclose(calculated, expected, *args, convert=False, **kwargs): + """ + Check that two floating point quantities are equal within a dtype-dependent tolerance + + By default, the units are required to also be equal, but a conversion will be + attempted if the ``convert`` argument is set to ``True``. + """ if not is_quantity(calculated) and not is_quantity(expected): assert_allclose(calculated, expected, *args, **kwargs) elif convert: @@ -32,6 +47,10 @@ def assert_quantity_allclose(calculated, expected, *args, convert=False, **kwarg def assert_quantity_pytree_allclose( calculated, expected, *args, is_leaf=is_quantity, **kwargs ): + """ + Check that two Pytrees with floating point quantities or arrays as leaves are equal + within a dtype-dependent tolerance + """ leaves1, treedef1 = tree_util.tree_flatten(calculated, is_leaf=is_leaf) leaves2, treedef2 = tree_util.tree_flatten(expected, is_leaf=is_leaf) assert treedef1 == treedef2 diff --git a/src/jaxoplanet/units/decorator.py b/src/jaxoplanet/units/decorator.py index 3378cb1e..035caa4e 100644 --- a/src/jaxoplanet/units/decorator.py +++ b/src/jaxoplanet/units/decorator.py @@ -1,3 +1,5 @@ +__all__ = ["quantity_input"] + import inspect from functools import partial, wraps from typing import Any, Callable, Optional