A JAX port of FLUX.1 models using flax.nnx
> [!IMPORTANT]
> The current codebase is designed to stay consistent with the original implementation, with minimal modifications. While it works as expected, it may not be the most efficient implementation; an updated version that better follows JAX conventions and best practices is planned.
>
> - Currently only tested on GPU.
> - No quantization support and no torch-like CPU offloading support yet (see the sketch below for the general idea of offloading).
> - PRs are welcome.
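For context, "CPU offloading" here means keeping model weights in host memory and moving them to the accelerator only when they are needed. A minimal, hypothetical sketch of the idea in JAX (the parameter pytree and sizes below are made up and not part of this codebase):

```python
import jax
import jax.numpy as jnp

cpu = jax.devices("cpu")[0]
gpu = jax.devices("gpu")[0]

# Keep the (hypothetical) parameter pytree resident in host memory...
params = jax.device_put({"w": jnp.ones((1024, 1024))}, cpu)

# ...and transfer it to the accelerator only for the duration of a forward pass.
params_on_gpu = jax.device_put(params, gpu)
```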
```bash
git clone https://github.com/lkwq007/flux-flax.git
cd flux-flax
mamba create -p ./env python=3.10
mamba activate ./env
pip install -r requirements.txt
```
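Since the port has only been tested on GPU, it can be worth confirming that JAX actually picked up the GPU backend before sampling (a quick, optional check):

```python
import jax

print(jax.devices())          # should list GPU devices, e.g. [CudaDevice(id=0)]
print(jax.default_backend())  # expected: "gpu"
```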
For interactive sampling, run:

```bash
python main.py --name <name>
```
Or, to generate a single sample, run (not recommended, as JIT compilation takes time for a one-off run):

```bash
python main.py --name <name> \
  --height <height> --width <width> --nonloop \
  --prompt "<prompt>"
```
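The one-off mode is slow mainly because JAX traces and compiles the model functions on the first call; the interactive loop amortizes that cost across samples. A generic illustration of the effect (not the actual sampling code; the function below is just a stand-in):

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def denoise_step(x):
    # Stand-in for one model step; any jitted computation behaves the same way.
    return jnp.tanh(x @ x.T)

x = jnp.ones((2048, 2048))

t0 = time.time()
denoise_step(x).block_until_ready()   # first call: traces + compiles
print(f"first call:  {time.time() - t0:.2f}s")

t0 = time.time()
denoise_step(x).block_until_ready()   # later calls reuse the compiled kernel
print(f"second call: {time.time() - t0:.2f}s")
```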