A cryptographically secure pseudo-random number generator for JAX based on the 20-round ChaCha cipher.
The ChaCha cipher was introduced by Daniel J. Bernstein in "ChaCha, a variant of Salsa20".
The implementation follows the specification in IRTF RFC 7539: "ChaCha20 and Poly1305 for IETF Protocols".
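For orientation, the core of the cipher as specified in RFC 7539 (the quarter round and the 20-round block function) can be sketched in pure Python. This is a minimal illustrative sketch, not the package's implementation, which ships native CPU/CUDA kernels. The names rotl32, quarter_round and chacha20_block are hypothetical.

```python
import struct

MASK32 = 0xFFFFFFFF
# "expand 32-byte k" as four little-endian 32-bit words (RFC 7539, Section 2.3)
CONSTANTS = (0x61707865, 0x3320646E, 0x79622D32, 0x6B206574)

# Word indices for the four column rounds followed by the four diagonal rounds.
ROUND_INDICES = (
    (0, 4, 8, 12), (1, 5, 9, 13), (2, 6, 10, 14), (3, 7, 11, 15),
    (0, 5, 10, 15), (1, 6, 11, 12), (2, 7, 8, 13), (3, 4, 9, 14),
)


def rotl32(x: int, n: int) -> int:
    """Rotate a 32-bit word left by n bits."""
    return ((x << n) | (x >> (32 - n))) & MASK32


def quarter_round(a: int, b: int, c: int, d: int) -> tuple:
    """ChaCha quarter round on four 32-bit words (RFC 7539, Section 2.1)."""
    a = (a + b) & MASK32
    d = rotl32(d ^ a, 16)
    c = (c + d) & MASK32
    b = rotl32(b ^ c, 12)
    a = (a + b) & MASK32
    d = rotl32(d ^ a, 8)
    c = (c + d) & MASK32
    b = rotl32(b ^ c, 7)
    return a, b, c, d


def chacha20_block(key: bytes, counter: int, nonce: bytes) -> bytes:
    """Compute one 64-byte ChaCha20 keystream block (RFC 7539, Section 2.3)."""
    if len(key) != 32 or len(nonce) != 12:
        raise ValueError("expected a 256-bit key and a 96-bit nonce")
    # State: 4 constant words, 8 key words, 1 counter word, 3 nonce words.
    state = (list(CONSTANTS) + list(struct.unpack("<8I", key))
             + [counter & MASK32] + list(struct.unpack("<3I", nonce)))
    working = list(state)
    for _ in range(10):  # 10 double rounds = 20 rounds
        for a, b, c, d in ROUND_INDICES:
            working[a], working[b], working[c], working[d] = quarter_round(
                working[a], working[b], working[c], working[d])
    # Add the input state word-wise and serialize as little-endian bytes.
    return struct.pack("<16I", *((w + s) & MASK32 for w, s in zip(working, state)))
```

The test vectors in RFC 7539, Sections 2.1.1 and 2.3.2, can be used to check a sketch like this one.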
Note that the implementation is not security-hardened. Our threat model assumes that the machine on which the code is executed is a trusted environment, and we keep key values, cipher states and plaintexts in plain memory.
The package is split into two modules:
- The cipher module is a full implementation of the ChaCha20 cipher.
- The random module provides a JAX-style API for the CSPRNG based on the cipher.
The package currently exposes basic RNG functions using the same interface as JAX:
- chacha.random.PRNGKey: Equivalent to jax.random.PRNGKey. Given a seed of up to 256 bits, it returns a PRNGKey object from which randomness can be generated.
- chacha.random.split: Equivalent to jax.random.split. Splits a given PRNGKey into the desired number of fresh PRNGKey instances.
- chacha.random.fold_in: !Deprecated! Equivalent to jax.random.fold_in. Deterministically derives a new PRNGKey from a given one and additional data.
- chacha.random.random_bits: Equivalent to jax._src.random._random_bits. Raw access to random bits, returned as an array of unsigned integers.
- chacha.random.uniform: Equivalent to jax.random.uniform. Uniformly sampled floating-point numbers in the range [0, 1).
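One common way to turn raw random bits into floats in [0, 1), and the one JAX's own jax.random.uniform uses, is to place the top 23 bits of a 32-bit word into the mantissa of a float32 in [1, 2) and then subtract 1. Whether chacha.random.uniform performs exactly this conversion is an assumption here. The helper name bits_to_unit_float is hypothetical, and the sketch works on a single Python integer rather than a JAX array.

```python
import struct


def bits_to_unit_float(bits: int) -> float:
    """Map a 32-bit unsigned integer to a float32 value in [0, 1).

    The top 23 bits become the mantissa of a float32 in [1, 2);
    subtracting 1.0 then shifts the value into [0, 1).
    """
    mantissa = (bits & 0xFFFFFFFF) >> 9          # keep the 23 most significant bits
    float32_bits = mantissa | 0x3F800000         # sign 0, exponent of 1.0f
    value_in_1_2 = struct.unpack("<f", struct.pack("<I", float32_bits))[0]
    return value_in_1_2 - 1.0                    # exact: result is a multiple of 2**-23
```

Because only 23 bits survive, the outputs lie on a grid of spacing 2**-23, and 1.0 itself can never be produced.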
Note: PRNGKey instances of this ChaCha20-based RNG are not interoperable with those of jax.random, i.e., you cannot mix them.
Security notice: Versions prior to 3.0.0 may repeat random states via the split and fold_in functions.
Per the conventions of pseudo-random number generation in the JAX framework, the functions random_bits and uniform are deterministic given the randomness state (the PRNGKey object). The user needs to split the state using split before each call to random_bits or uniform to get proper pseudo-random numbers. For more details, see the JAX documentation on this topic.
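The usage discipline above can be illustrated with a tiny toy model. Everything in it is hypothetical: a toy key is a pair (cipher key, nonce), and SHA-256 from the standard library stands in for the ChaCha20 block function purely to keep the sketch short. It is not the chacha.random implementation, and the toy keys are not PRNGKey objects.

```python
import hashlib
from typing import List, Tuple

# Hypothetical toy key: (256-bit cipher key, 96-bit nonce).
ToyKey = Tuple[bytes, bytes]


def _stand_in_block(key: ToyKey, counter: int) -> bytes:
    """Stand-in for the cipher's block function: 32 pseudo-random bytes."""
    cipher_key, nonce = key
    return hashlib.sha256(cipher_key + nonce + counter.to_bytes(4, "little")).digest()


def toy_prng_key(seed: int) -> ToyKey:
    """Use the seed (up to 256 bits) as the cipher key, starting from a zero nonce."""
    return seed.to_bytes(32, "little"), bytes(12)


def toy_random_bits(key: ToyKey, n_words: int) -> List[int]:
    """Return n_words 32-bit words; a pure function of the key."""
    n_blocks = (4 * n_words + 31) // 32
    stream = b"".join(_stand_in_block(key, c) for c in range(n_blocks))
    return [int.from_bytes(stream[4 * i:4 * i + 4], "little") for i in range(n_words)]


def toy_split(key: ToyKey, num: int = 2) -> List[ToyKey]:
    """Derive num keys that keep the cipher key but get fresh 96-bit nonces."""
    cipher_key, _ = key
    words = toy_random_bits(key, 3 * num)
    return [(cipher_key, b"".join(w.to_bytes(4, "little") for w in words[3 * i:3 * i + 3]))
            for i in range(num)]


key = toy_prng_key(42)
# Same state, same output: drawing twice from one key repeats the numbers.
assert toy_random_bits(key, 4) == toy_random_bits(key, 4)
# Proper use: split first, consume the subkey, keep the other key for later splits.
key, subkey = toy_split(key)
sample = toy_random_bits(subkey, 4)
```

In this toy the split nonces are drawn from the parent key's own stream, so a key that has been split should not also be used for sampling, which matches the JAX convention of never reusing a consumed key.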
The PRNGKey object is the state of the ChaCha cipher, with the randomness seed provided by the user as the cipher key. Every invocation of split generates random numbers from the given state to use as fresh nonces in the derived states, while the key is kept constant. The counter in the state is used internally within calls to random_bits and uniform, which can thus output up to 256 GiB of random values for each state.
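The 256 GiB figure follows from the RFC 7539 state layout: the 32-bit block counter can take 2**32 values, and each block-function call yields 64 bytes. The arithmetic below assumes that the full counter range is available to random_bits and uniform for a single state, and that split refreshes the whole 96-bit nonce; both are readings of the description above rather than confirmed implementation details.

```python
# ChaCha20 state per RFC 7539: sixteen 32-bit words.
CONSTANT_WORDS, KEY_WORDS, COUNTER_WORDS, NONCE_WORDS = 4, 8, 1, 3
WORD_BITS = 32
BLOCK_BYTES = 16 * WORD_BITS // 8  # one block-function call yields 64 bytes

assert CONSTANT_WORDS + KEY_WORDS + COUNTER_WORDS + NONCE_WORDS == 16

key_bits = KEY_WORDS * WORD_BITS        # 256-bit key: holds the user's seed
nonce_bits = NONCE_WORDS * WORD_BITS    # 96-bit nonce: assumed refreshed by split
counter_values = 2 ** (COUNTER_WORDS * WORD_BITS)
max_bytes_per_state = counter_values * BLOCK_BYTES

print(key_bits, nonce_bits, max_bytes_per_state // 2**30)  # 256 96 256
```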
The following methods for direct use of the ChaCha20 cipher for encryption are available:
Quick use functions:
- chacha.cipher.encrypt_with_key: Encrypt a message of any length by providing a 256-bit key, a 96-bit nonce/IV and an optional 32-bit initial counter value.
- chacha.cipher.decrypt_with_key: Decrypt a message of any length by providing a 256-bit key, a 96-bit nonce/IV and an optional 32-bit initial counter value.
State construction and use:
- chacha.cipher.setup_state: Create a ChaCha state structure by providing a 256-bit key, a 96-bit nonce/IV and a 32-bit initial counter value.
- chacha.cipher.encrypt: Encrypt a message of any length using a ChaCha state structure.
- chacha.cipher.decrypt: Decrypt a message of any length using a ChaCha state structure.
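ChaCha20 is a stream cipher: encryption XORs the message with the keystream produced by the block function for successive counter values, so decryption is the same operation. The pure-Python sketch below illustrates this construction from a 256-bit key, a 96-bit nonce and an initial counter, following RFC 7539. The function chacha20_xor and its initial_counter default of 0 are hypothetical and are not the signatures of chacha.cipher.encrypt_with_key or chacha.cipher.setup_state.

```python
import struct

MASK32 = 0xFFFFFFFF
CONSTANTS = (0x61707865, 0x3320646E, 0x79622D32, 0x6B206574)  # "expand 32-byte k"
ROUND_INDICES = (
    (0, 4, 8, 12), (1, 5, 9, 13), (2, 6, 10, 14), (3, 7, 11, 15),  # column rounds
    (0, 5, 10, 15), (1, 6, 11, 12), (2, 7, 8, 13), (3, 4, 9, 14),  # diagonal rounds
)


def _rotl32(x: int, n: int) -> int:
    return ((x << n) | (x >> (32 - n))) & MASK32


def _block(key: bytes, counter: int, nonce: bytes) -> bytes:
    """One 64-byte ChaCha20 keystream block (RFC 7539, Section 2.3)."""
    state = (list(CONSTANTS) + list(struct.unpack("<8I", key))
             + [counter] + list(struct.unpack("<3I", nonce)))
    x = list(state)
    for _ in range(10):  # 10 double rounds = 20 rounds
        for a, b, c, d in ROUND_INDICES:  # quarter round, applied in place
            x[a] = (x[a] + x[b]) & MASK32
            x[d] = _rotl32(x[d] ^ x[a], 16)
            x[c] = (x[c] + x[d]) & MASK32
            x[b] = _rotl32(x[b] ^ x[c], 12)
            x[a] = (x[a] + x[b]) & MASK32
            x[d] = _rotl32(x[d] ^ x[a], 8)
            x[c] = (x[c] + x[d]) & MASK32
            x[b] = _rotl32(x[b] ^ x[c], 7)
    return struct.pack("<16I", *((w + s) & MASK32 for w, s in zip(x, state)))


def chacha20_xor(key: bytes, nonce: bytes, data: bytes, initial_counter: int = 0) -> bytes:
    """Encrypt or decrypt data by XOR-ing it with the ChaCha20 keystream."""
    if len(key) != 32 or len(nonce) != 12:
        raise ValueError("expected a 256-bit key and a 96-bit nonce")
    n_blocks = (len(data) + 63) // 64
    if initial_counter < 0 or initial_counter + n_blocks > 2 ** 32:
        raise ValueError("the 32-bit block counter would overflow")
    out = bytearray()
    for i in range(n_blocks):
        chunk = data[64 * i:64 * (i + 1)]
        keystream = _block(key, initial_counter + i, nonce)
        out.extend(p ^ k for p, k in zip(chunk, keystream))
    return bytes(out)
```

Calling chacha20_xor a second time on the ciphertext with the same key, nonce and initial counter recovers the plaintext. This is also why a key/nonce pair must never be reused for two different messages.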
For the latest stable version, install via pip:
pip install jax-chacha-prng
Binaries for glibc-based 64-bit Linux systems (manylinux wheels) are compiled with CPU and CUDA 11 support (you will have to install JAX with CUDA support to benefit from this). Binaries for all other systems are compiled for CPU execution only, because JAX does not provide CUDA libraries for these systems either.
However, you can instruct pip to compile the package from source instead via
pip install --no-binary :all: jax-chacha-prng
or install it directly from the v2-stable branch:
pip install git+https://github.com/DPBayes/jax-chacha-prng@v2-stable#egg=jax-chacha-prng
This will compile CUDA kernels if the CUDA library is present on the system; otherwise, only CPU kernels will be built. To check whether CUDA kernels were built and installed, you can check the return value of chacha.native.cuda_supported().
JAX is still under ongoing development, and its developers currently give no guarantee that the API remains stable between releases. However, recent releases were mostly stable in the interfaces required by JAX-ChaCha-PRNG. To allow using JAX-ChaCha-PRNG with the most recent JAX release, we therefore currently do not constrain the JAX version from above in our dependency list.
However, should you encounter issues with a new JAX release at some point, you can use the compatible-jax installation target to force usage of the latest JAX version known to be compatible with JAX-ChaCha-PRNG (run from the root of a repository checkout):
pip install .[compatible-jax]
The latest version of JAX-ChaCha-PRNG is currently known to work reliably with JAX versions 0.4.1 to 0.4.14. We regularly check the compatible version range, but new JAX releases may not be tested immediately.
Version numbers adhere to Semantic Versioning. Changes between releases are tracked in ChangeLog.txt.
The software is licensed under the Apache License 2.0. You can find the full license text in LICENSES/Apache-2.0.txt.
Single files included from third parties may be under a different license, which is annotated in the file itself, with the full license text included in the LICENSES directory. The repository is fully REUSE-compliant.
We thank the NVIDIA AI Technology Center Finland for their contribution of GPU performance benchmarking and helpful discussions on optimisation.
We welcome any fixes, improvements or other contributions via pull request to this repository.
Before submitting your changes, please make sure to run our Python unit tests via pytest tests/ and ensure that they all succeed. If you add new functionality, please also add tests.
If you made changes to the native C++/CUDA code, please also compile and run the native tests:
mkdir build
cd build
cmake -DBUILD_TESTING=On ..
make -j
./cpu_kernel_tests
./gpu_kernel_tests # if you have CUDA installed and a GPU available