
lkwq007/flux-flax


FLUX-Flax

A JAX port of FLUX.1 models using flax.nnx.

Important

The current codebase is designed to maintain consistency with the original implementation, with minimal modifications. While it works as expected, it may not be the most efficient implementation. I plan to release an updated version soon that better adheres to JAX conventions and best practices.


Status

Only tested on GPU so far.

There is currently no quantization support and no PyTorch-style CPU offloading support.

PRs are welcome.
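Because the port has only been tested on GPU, it can help to confirm that JAX actually sees a GPU before running `main.py`. The following is a minimal sketch that is not part of this repository: `describe_backend` is a hypothetical helper, and the check assumes a CUDA-enabled `jaxlib` is installed in the environment (on a CPU-only install it will simply report `cpu`).

```python
# Minimal sketch (hypothetical helper, not from this repo):
# confirm which backend JAX selected before loading FLUX weights.
import jax


def describe_backend() -> tuple[str, list[str]]:
    """Return JAX's default backend name and the platform of each visible device."""
    backend = jax.default_backend()  # typically "gpu" with a CUDA jaxlib, "cpu" otherwise
    platforms = [d.platform for d in jax.devices()]
    return backend, platforms


if __name__ == "__main__":
    backend, platforms = describe_backend()
    print(f"default backend: {backend}; devices: {platforms}")
    if backend != "gpu":
        # This port has only been tested on GPU, so warn rather than fail.
        print("warning: JAX is not using a GPU")
```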

Local installation

git clone https://github.com/lkwq007/flux-flax.git
cd flux-flax
mamba create -p ./env python=3.10
mamba activate ./env
pip install -r requirements.txt

Usage

For interactive sampling, run:

python main.py --name <name>

Or, to generate a single sample, run the following (not recommended, since each run pays the JIT compilation cost again):

python main.py --name <name> \
  --height <height> --width <width> --nonloop \
  --prompt "<prompt>"
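To illustrate why the single-sample mode is slow, here is a toy sketch using plain `jax.jit` (not code from this repository): the first call traces and compiles the function, later calls with the same input shape and dtype reuse the compiled executable, and a new input shape triggers another compilation. `toy_step` and `timed_call` are hypothetical stand-ins for the model's sampling step.

```python
# Toy sketch (not from this repo): JIT compilation cost vs. cached reuse.
import time

import jax
import jax.numpy as jnp


@jax.jit
def toy_step(x: jax.Array) -> jax.Array:
    """Hypothetical stand-in for a denoising step; the real model is far larger."""
    return jnp.tanh(x @ x.T).mean()


def timed_call(x: jax.Array) -> float:
    """Run toy_step once and return wall-clock seconds, waiting for the device."""
    start = time.perf_counter()
    toy_step(x).block_until_ready()
    return time.perf_counter() - start


x = jnp.ones((256, 256), dtype=jnp.float32)
first = timed_call(x)   # traces + compiles, then runs
second = timed_call(x)  # same shape/dtype: reuses the cached executable
y = jnp.ones((128, 256), dtype=jnp.float32)
third = timed_call(y)   # new input shape: triggers a fresh compilation
print(f"first: {first:.4f}s, second: {second:.4f}s, new shape: {third:.4f}s")
```

In a single `--nonloop` run, the model-sized equivalent of `first` is paid every time the script starts, whereas an interactive session can amortize it over many prompts.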
