From 51a3735a8620154046c11f04d63096b42978e34f Mon Sep 17 00:00:00 2001
From: Mehdi Ataei
Date: Thu, 24 Oct 2024 03:22:52 -0400
Subject: [PATCH] JAX installation is now handled correctly for different
 configurations (CPU, CUDA, TPU)

---
 CHANGELOG.md |  1 +
 README.md    | 21 ++++++++++++++++++++-
 setup.py     | 13 ++++++++++---
 3 files changed, 31 insertions(+), 4 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 8b25184..1d036de 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ### Fixed
 
 - mkdocs is now configured correctly for the new project structure
+- JAX installation is now handled correctly for different configurations (CPU, CUDA, TPU)
 
 ## [0.2.0] - 2024-10-22
 
diff --git a/README.md b/README.md
index 749f7cd..a23c7dc 100644
--- a/README.md
+++ b/README.md
@@ -12,11 +12,30 @@ XLB can now be installed via pip: `pip install xlb`.
 XLB is a fully differentiable 2D/3D Lattice Boltzmann Method (LBM) library that leverages hardware acceleration. It supports [JAX](https://github.com/google/jax) and [NVIDIA Warp](https://github.com/NVIDIA/warp) backends, and is specifically designed to solve fluid dynamics problems in a computationally efficient and differentiable manner. Its unique combination of features positions it as an exceptionally suitable tool for applications in physics-based machine learning. With the new Warp backend, XLB now offers state-of-the-art performance for even faster simulations.
 
 ## Getting Started
-To get started with XLB, you can install it using pip:
+To get started with XLB, you can install it using pip. Different installation options are available depending on your hardware:
+
+### Basic Installation (CPU-only)
 
 ```bash
 pip install xlb
 ```
 
+### Installation with CUDA support (for NVIDIA GPUs)
+This option installs the JAX backend with CUDA support:
+```bash
+pip install "xlb[cuda]"
+```
+
+### Installation with TPU support
+This option installs the JAX backend with TPU support:
+```bash
+pip install "xlb[tpu]"
+```
+
+### Notes
+- For Mac users: use the basic CPU installation command, as JAX GPU support is not available on macOS
+- The NVIDIA Warp backend is included in all installation options and uses CUDA automatically when it is available
+- The CUDA and TPU options affect only the JAX backend
+
 To install the latest development version from source:
 ```bash
diff --git a/setup.py b/setup.py
index a7f06f9..d893d39 100644
--- a/setup.py
+++ b/setup.py
@@ -2,7 +2,7 @@
 
 setup(
     name='xlb',
-    version='0.2.0',
+    version='0.2.1',
     description='XLB: Accelerated Lattice Boltzmann (XLB) for Physics-based ML',
     long_description=open('README.md').read(),
     long_description_content_type='text/markdown',
@@ -11,7 +11,6 @@
     license='Apache License 2.0',
     packages=find_packages(),
     install_requires=[
-        'jax[cuda]>=0.4.34',
         'matplotlib>=3.9.2',
         'numpy>=2.1.2',
         'pyvista>=0.44.1',
@@ -19,7 +18,15 @@
         'warp-lang>=1.4.0',
         'numpy-stl>=3.1.2',
         'pydantic>=2.9.1',
-        'ruff>=0.6.5'
+        'ruff>=0.6.5',
+        'jax>=0.4.34'  # Base JAX CPU-only requirement
     ],
+    extras_require={
+        'cuda': ['jax[cuda12]>=0.4.34'],  # For CUDA installations
+        'tpu': ['jax[tpu]>=0.4.34'],  # For TPU installations
+    },
     python_requires='>=3.10',
+    dependency_links=[
+        'https://storage.googleapis.com/jax-releases/libtpu_releases.html'
+    ],
 )
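
After installing one of these extras, it can be worth confirming that pip pulled in the intended jaxlib variant. The snippet below is a minimal sketch and not part of the patch; the script name is illustrative, and it assumes only a standard JAX installation:

```python
# verify_backend.py -- illustrative helper, not part of this patch.
# Prints which JAX backend and devices are active after installation.
import jax

# "cpu" after `pip install xlb`, "gpu" after `pip install "xlb[cuda]"`,
# "tpu" after `pip install "xlb[tpu]"` (assuming matching hardware and drivers).
print("Default backend:", jax.default_backend())

# Enumerates the devices JAX detected, e.g. [CpuDevice(id=0)] on a CPU-only install.
print("Devices:", jax.devices())
```

Since the Warp backend ships its own CUDA support, this check is only relevant to the JAX backend.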