Implementation of continuous-time recurrent neural networks (CT-RNNs) in Julia's SciML ecosystem

keith-murray/ctrnn-julia

CT-RNN Implementation in Julia's SciML Ecosystem

This repository contains an implementation of continuous-time recurrent neural networks (CT-RNNs) in the Julia programming language and SciML ecosystem. Specifically, the architecture and training of CT-RNNs are implemented with the Lux.jl, SciMLSensitivity.jl, and OrdinaryDiffEq.jl packages.
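
For readers new to CT-RNNs, the following is a minimal sketch of the kind of dynamics such a network integrates. It assumes the common formulation τ dx/dt = -x + W tanh(x) + W_in u + b and uses a plain forward-Euler loop in base Julia rather than the Lux.jl and OrdinaryDiffEq.jl machinery this repository relies on. The names `CTRNNParams`, `ctrnn_rhs`, and `simulate` are hypothetical and exist only for this illustration; the exact equations and solver used by the package may differ.

```julia
using Random

# Hypothetical parameter container for this sketch (not a type from this repo).
struct CTRNNParams
    W::Matrix{Float64}     # recurrent weights, n × n
    W_in::Matrix{Float64}  # input weights, n × m
    b::Vector{Float64}     # bias, length n
    τ::Float64             # time constant
end

# Right-hand side of the assumed dynamics:
#   τ dx/dt = -x + W * tanh.(x) + W_in * u + b
ctrnn_rhs(x, u, p::CTRNNParams) =
    (-x .+ p.W * tanh.(x) .+ p.W_in * u .+ p.b) ./ p.τ

# Forward-Euler rollout. `inputs` holds one input vector per column,
# one column per time step. Returns the state trajectory, n × (T + 1).
function simulate(p::CTRNNParams, x0::Vector{Float64},
                  inputs::Matrix{Float64}, dt::Float64)
    n, T = length(x0), size(inputs, 2)
    xs = Matrix{Float64}(undef, n, T + 1)
    xs[:, 1] = x0
    for t in 1:T
        x = xs[:, t]
        xs[:, t + 1] = x .+ dt .* ctrnn_rhs(x, inputs[:, t], p)
    end
    return xs
end

# Example: 8 hidden units, 2 inputs, 100 Euler steps of size 0.01.
rng = MersenneTwister(0)
n, m, T = 8, 2, 100
p = CTRNNParams(randn(rng, n, n) ./ sqrt(n), randn(rng, n, m), zeros(n), 1.0)
xs = simulate(p, zeros(n), randn(rng, m, T), 0.01)
println(size(xs))  # (8, 101)
```

Forward Euler keeps the sketch free of dependencies. In practice an adaptive solver from OrdinaryDiffEq.jl would usually replace the hand-written loop, which is the role that package plays here.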

Installation

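Clone the repo, start Julia from the repository root, and execute the following commands:

using Pkg
Pkg.activate("RecurrentNetworks")  # activate the package's project environment
Pkg.instantiate()                  # install the dependencies the project declares
Pkg.status()                       # list the packages in the active environment
Pkg.build()                        # run any build steps for those packages
using RecurrentNetworks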

Usage

To train the model on a sample task, execute the following code:

RecurrentNetworks.schedule("./data/setup_data.jld2", "./data/models/", 1)

Examples

The scripts directory contains example scripts for training models and examining trained models. All scripts can be run as Pluto notebooks.

  • create_scheduler.jl - initializes parameter file for training
  • examine_saved_models.jl - exploratory analysis of a trained model
  • training_testing_data.jl - data initialization for sample task

A technical note

I wouldn't recommend using Julia for training CT-RNNs. My initial impression was that Julia and the SciML ecosystem could train CT-RNNs faster than PyTorch or TensorFlow in Python; however, I've since found the JAX ecosystem in Python to be strictly better than Julia and SciML. Check out keith-murray/ctrnn-jax for a JAX implementation of CT-RNNs that is faster and easier to use than this repo.
