This repository contains an implementation of continuous-time recurrent neural networks (CT-RNNs) in the Julia programming language using the SciML ecosystem. Specifically, the architecture and training of CT-RNNs are implemented with the Lux.jl, SciMLSensitivity.jl, and OrdinaryDiffEq.jl packages.
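For orientation, here is a minimal, hedged sketch (not the repo's actual model or configuration) of how CT-RNN dynamics of the form τ dx/dt = -x + tanh(W x + b) can be expressed with a Lux layer and integrated with OrdinaryDiffEq; the layer size, time constant, and solver below are illustrative assumptions:

```julia
# Minimal sketch (not the repo's exact model): CT-RNN dynamics
#   τ dx/dt = -x + tanh(W x + b)
# with the recurrent weights held in a Lux layer and integrated by OrdinaryDiffEq.
using Lux, OrdinaryDiffEq, Random

rng = Random.default_rng()
cell = Dense(8 => 8, tanh)              # illustrative 8-unit recurrent layer
ps, st = Lux.setup(rng, cell)
τ = 10.0f0                              # assumed time constant

function ctrnn_rhs(x, p, t)
    y, _ = cell(x, p, st)               # tanh(W x + b) via the Lux layer
    return (-x .+ y) ./ τ
end

x0 = zeros(Float32, 8)
prob = ODEProblem(ctrnn_rhs, x0, (0.0f0, 100.0f0), ps)
sol = solve(prob, Tsit5(); saveat = 1.0f0)
```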
Clone the repo and execute the following commands:
using Pkg
Pkg.activate("RecurrentNetworks")
Pkg.status()
Pkg.build()
using RecurrentNetworks
To train the model on a sample task, execute the following code:
RecurrentNetworks.schedule("./data/setup_data.jld2", "./data/models/", 1)
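The setup file is a JLD2 archive. As a hedged sketch (the keys it is expected to contain are defined by `create_scheduler.jl` and `training_testing_data.jl`, not assumed here), its contents can be inspected with JLD2.jl:

```julia
# Hedged sketch: peek at the variables stored in the JLD2 setup file.
using JLD2
setup = load("./data/setup_data.jld2")   # returns a Dict of stored variables
println(keys(setup))
```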
The `scripts` directory contains example scripts to train and examine trained models. All scripts can be executed as Pluto notebooks (see the launch sketch after the list below).
- `create_scheduler.jl` - initializes the parameter file for training
- `examine_saved_models.jl` - exploratory analysis of a trained model
- `training_testing_data.jl` - data initialization for the sample task
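Assuming a standard Pluto workflow (not verified against this repo's own instructions), launching the notebook server looks like:

```julia
# Start Pluto, then open one of the scripts
# (e.g. scripts/create_scheduler.jl) from the file picker in the browser UI.
using Pluto
Pluto.run()
```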
I wouldn't recommend using Julia for training CT-RNNs. My initial impression was that Julia and the SciML ecosystem could train CT-RNNs faster than PyTorch or TensorFlow in Python; however, I've since found the JAX ecosystem in Python to be strictly better than Julia and SciML for this purpose. Check out keith-murray/ctrnn-jax for a JAX implementation of CT-RNNs that is faster and more usable than this repo.