Merge pull request #27 from bioAI-Oslo/main
Update paper branch
JakobSonstebo authored Sep 18, 2023
2 parents 1a75b51 + 6fc997c commit 22d239e
Showing 12 changed files with 70 additions and 39 deletions.
2 changes: 1 addition & 1 deletion docs/benchmarks/benchmarks.rst
@@ -19,7 +19,7 @@ One exception is the :class:`BernoulliGLM`, which also includes a refractory per
an additional :math:`N_{neurons}` synapses.

We also need to store the number of spikes of each neuron per time step, which by default consumes 32 bits.
-In most cases, however, we don't expect the number of spikes per time step for any neuron to exceed 127, which means we can safely reduce the memory conumption to
+In most cases, however, we don't expect the number of spikes per time step for any neuron to exceed 127, which means we can safely reduce the memory consumption to
8 bits by passing :code:`torch.int8` as the :code:`store_as_dtype` argument of the :meth:`simulate` method if we need to save memory.
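
For instance, a minimal sketch (assuming a :code:`model` and :code:`network` constructed as in the tutorials):

.. code-block:: python

    import torch

    # Store each spike count in one byte instead of four.
    spikes = model.simulate(network, n_steps=10_000, store_as_dtype=torch.int8)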

Concretely, the total memory usage (in bytes) of GLM models can be estimated as
5 changes: 3 additions & 2 deletions docs/introduction/introduction.rst
@@ -88,14 +88,15 @@ in the :class:`BernoulliGLM` class, using the same parameters as in the original
.. code-block:: python
model = BernoulliGLM(
-    alpha= 0.2, # Decay rate of the coupling strength between neurons (1/ms)
-    beta= 0.5, # Decay rate of the self-inhibition during the relative refractory period (1/ms)
+    alpha=0.2, # Decay rate of the coupling strength between neurons (1/ms)
+    beta=0.5, # Decay rate of the self-inhibition during the relative refractory period (1/ms)
abs_ref_scale=3, # Absolute refractory period in time steps
rel_ref_scale=7, # Relative refractory period in time steps
abs_ref_strength=-100, # Strength of the self-inhibition during the absolute refractory period
rel_ref_strength=-30, # Initial strength of the self-inhibition during the relative refractory period
coupling_window=5, # Length of coupling window in time steps
theta=5, # Threshold for firing
+    r=1, # Parameter controlling the recurrence strength
dt=1, # Length of time step (ms)
)
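
A short usage sketch, assuming a :code:`network` object prepared as in the later tutorials:

.. code-block:: python

    spikes = model.simulate(network, n_steps=1000)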
3 changes: 3 additions & 0 deletions docs/tutorials/stimuli.rst
@@ -38,6 +38,7 @@ After we've defined the model and the stimulus, we can simulate the network and
rel_ref_strength=-30,
alpha=0.2,
beta=0.5,
+    r=1
)
# Define stimulus and add it to the model
@@ -92,6 +93,7 @@ Before we add the stimulus to the model, we'll run a simulation without it to se
rel_ref_strength=-30,
alpha=0.2,
beta=0.5,
+    r=1
)
spikes = model.simulate(network, n_steps=n_steps)
@@ -156,6 +158,7 @@ that is close to the frequency of the stimulus.
rel_ref_strength=-30,
alpha=0.2,
beta=0.5,
+    r=1
)
stimulus = SinStimulus(
3 changes: 2 additions & 1 deletion examples/large_scale_example.ipynb
@@ -550,6 +550,7 @@
" rel_ref_scale=5,\n",
" rel_ref_strength=-30,\n",
" beta=0.1,\n",
" r=1\n",
")"
]
},
@@ -842,7 +843,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
"version": "3.10.11"
},
"orig_nbformat": 4
},
25 changes: 14 additions & 11 deletions examples/simulate_with_stimulus.ipynb

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions examples/working_with_stimulus.ipynb
@@ -84,6 +84,7 @@
" rel_ref_strength=-30, \n",
" alpha=0.2,\n",
" beta=0.5,\n",
" r=1,\n",
")\n",
"model.add_stimulus(stim)\n",
"\n",
@@ -233,6 +234,7 @@
" rel_ref_strength=-30, \n",
" alpha=0.2,\n",
" beta=0.5,\n",
" r=1,\n",
")"
]
},
@@ -428,6 +430,7 @@
" rel_ref_strength=-30, \n",
" alpha=0.2,\n",
" beta=0.5,\n",
" r=1,\n",
")\n",
"\n",
"\n",
28 changes: 16 additions & 12 deletions spikeometric/models/bernoulli_glm_model.py
@@ -13,12 +13,12 @@ class BernoulliGLM(BaseModel):
More formally, the model can be broken into three steps, each of which is implemented as a separate method in this class:
-    #. .. math:: g_i(t+1) = \sum_{\tau=0}^{T-1} \left(X_i(t-\tau)r(\tau) + \sum_{j \in \mathcal{N}(i)} (W_0)_{j, i} X_j(t-\tau) c(\tau)\right) + \mathcal{E}_i(t+1)
+    #. .. math:: g_i(t+1) = \sum_{\tau=0}^{T-1} \left(X_i(t-\tau)ref(\tau) + r\sum_{j \in \mathcal{N}(i)} (W_0)_{j, i} X_j(t-\tau) c(\tau)\right) + \mathcal{E}_i(t+1)
#. .. math:: p_i(t+1) = \sigma(g_i(t+1) - \theta) \Delta t
#. .. math:: X_i(t+1) \sim \text{Bernoulli}(p_i(t+1))
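
    As a rough illustration, the last two steps amount to something like this sketch (hypothetical tensors; values mirror the introduction's example parameters):

    .. code-block:: python

        import torch

        g = torch.randn(100)               # stand-in for the summed input from step 1
        theta, dt = 5.0, 1.0               # firing threshold and time step (ms)
        p = torch.sigmoid(g - theta) * dt  # step 2: spike probability
        X = torch.bernoulli(p)             # step 3: sample spikes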
The first equation is implemented in the :meth:`input` method and gives us the input to the neuron :math:`i` at time :math:`t+1` as a sum of the refractory, synaptic and external inputs.
-    The refractory input is calculated by convolving the spike history of the neuron itself with a refractory filter :math:`r`, the synaptic input is obtained by convolving the spike history
+    The refractory input is calculated by convolving the spike history of the neuron itself with a refractory filter :math:`ref`, the synaptic input is obtained by convolving the spike history
of the neuron's neighbors with the coupling filter :math:`c`, weighted by the synaptic weights :math:`W_0`, and the external input is given by evaluating an external input function :math:`\mathcal{E}` at time :math:`t+1`.
The second equation is implemented in :meth:`non_linearity` which computes the probability that the neuron :math:`i` spikes at time :math:`t+1` by passing
@@ -40,13 +40,15 @@ class BernoulliGLM(BaseModel):
abs_ref_scale : int
The absolute refractory period of the neurons :math:`A_{ref}` in time steps
abs_ref_strength : float
-        The large negative activation :math:`a` added to the neurons during the absolute refractory period
+        The large negative activation :math:`abs` added to the neurons during the absolute refractory period
rel_ref_scale : int
The relative refractory period of the neurons :math:`R_{ref}` in time steps
rel_ref_strength : float
-        The negative activation :math:`r` added to the neurons during the relative refractory period (tunable)
+        The negative activation :math:`rel` added to the neurons during the relative refractory period (tunable)
beta : float
The decay rate :math:`\beta` of the weights. (tunable)
+    r : float
+        The scaling of the recurrent connections. (tunable)
rng : torch.Generator
The random number generator for sampling from the Bernoulli distribution.
"""
@@ -61,6 +63,7 @@ def __init__(self,
rel_ref_scale: int,
rel_ref_strength: int,
beta: float,
+        r: float,
rng=None
):
super().__init__()
@@ -76,6 +79,7 @@ def __init__(self,

# Parameters are used to store tensors that will be tunable
self.register_parameter("theta", nn.Parameter(torch.tensor(theta, dtype=torch.float)))
self.register_parameter("r", torch.nn.Parameter(torch.tensor(r, dtype=torch.float)))
self.register_parameter("beta", nn.Parameter(torch.tensor(beta, dtype=torch.float)))
self.register_parameter("alpha", nn.Parameter(torch.tensor(alpha, dtype=torch.float)))
self.register_parameter("rel_ref_strength", nn.Parameter(torch.tensor(rel_ref_strength, dtype=torch.float)))
@@ -88,7 +92,7 @@ def input(self, edge_index: torch.Tensor, W: torch.Tensor, state: torch.Tensor,
Computes the input at time step :obj:`t+1` by adding together the synaptic input from neighboring neurons and the stimulus input.
.. math::
-        g_i(t+1) = \sum_{\tau=0}^{T-1} \left(X_i(t-\tau)r(\tau) + \sum_{j \in \mathcal{N}(i)} (W_0)_{j, i} X_j(t-\tau) c(\tau)\right) + \mathcal{E}_i(t+1)
+        g_i(t+1) = \sum_{\tau=0}^{T-1} \left(X_i(t-\tau)ref(\tau) + \sum_{j \in \mathcal{N}(i)} (W_0)_{j, i} X_j(t-\tau) c(\tau)\right) + \mathcal{E}_i(t+1)
Parameters
----------
@@ -106,7 +110,7 @@ def input(self, edge_index: torch.Tensor, W: torch.Tensor, state: torch.Tensor,
synaptic_input : torch.Tensor [n_neurons, 1]
"""
-        return self.synaptic_input(edge_index, W, state=state) + self.stimulus_input(t)
+        return self.r * self.synaptic_input(edge_index, W, state=state) + self.stimulus_input(t)

def non_linearity(self, input: torch.Tensor) -> torch.Tensor:
r"""
@@ -150,7 +154,7 @@ def connectivity_filter(self, W0: torch.Tensor, edge_index: torch.Tensor) -> tor
r"""
The connectivity filter constructs a tensor holding the weights of the edges in the network.
This is done by filtering the initial coupling weights :math:`W_0` with the coupling filter :math:`c`
-    and using a refractory filter :math:`r` as self-edge weights to emulate the refractory period.
+    and using a refractory filter :math:`ref` as self-edge weights to emulate the refractory period.
For the coupling edges, we are given an initial weight :math:`(W_0)_{i,j}` for each edge. This
tells us how strong the connection between neurons :math:`i` and :math:`j` is immediately after a spike event.
@@ -172,16 +176,16 @@
This is modeled by weighting spike events by :math:`r e^{-\alpha t \Delta t}` for
the next :math:`R_{ref}` time steps.
-    That is, the refractory filter :math:`r` is given by
+    That is, the refractory filter :math:`ref` is given by
.. math::
-        r(t) = \begin{cases}
-            a & \text{if } t < A_{ref} \\
-            r e^{-\alpha t \Delta t} & \text{if } A_{ref} \leq t < A_{ref} + R_{ref} \\
+        ref(t) = \begin{cases}
+            abs & \text{if } t < A_{ref} \\
+            rel e^{-\alpha t \Delta t} & \text{if } A_{ref} \leq t < A_{ref} + R_{ref} \\
0 & \text{if } A_{ref} + R_{ref} \leq t
\end{cases}
-    And we set `W_{i, i}(t) = r(t)` for all neurons :math:`i`.
+    And we set `W_{i, i}(t) = ref(t)` for all neurons :math:`i`.
All of this information can be represented by a tensor :math:`W` of shape :math:`N\times N\times T`, where
:code:`W[i, j, t]` is the weight of the edge from neuron :math:`i` to neuron :math:`j` at time step :math:`t` after a spike event.
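
    A sketch of this refractory filter in code, using the introduction's example parameter values as assumptions:

    .. code-block:: python

        import torch

        A_ref, R_ref = 3, 7           # abs_ref_scale, rel_ref_scale (time steps)
        abs_s, rel_s = -100.0, -30.0  # abs_ref_strength, rel_ref_strength
        alpha, dt = 0.2, 1.0
        t = torch.arange(A_ref + R_ref, dtype=torch.float)
        ref = torch.where(t < A_ref, torch.full_like(t, abs_s), rel_s * torch.exp(-alpha * t * dt))
        # Beyond A_ref + R_ref the filter is zero, so only these entries need storing.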
21 changes: 11 additions & 10 deletions spikeometric/models/sa_model.py
@@ -51,11 +51,11 @@ def simulate(self, data: Data, n_steps: int, verbose: bool =True, equilibration_
device = edge_index.device

# If verbose is True, a progress bar is shown
-        pbar = tqdm(range(n_steps + equilibration_steps), colour="#3E5641") if verbose else range(n_steps + equilibration_steps)
+        pbar = tqdm(range(n_steps), colour="#3E5641") if verbose else range(n_steps)

# Initialize the state of the network
-        x = torch.zeros(n_neurons, n_steps + equilibration_steps, device=device, dtype=store_as_dtype)
-        initial_activation = torch.rand((n_neurons,1), device=device)
+        x = torch.zeros(n_neurons, n_steps, device=device, dtype=store_as_dtype)
+        initial_activation = torch.rand((n_neurons, T), device=device)
activation = self.equilibrate(edge_index, W, initial_activation, equilibration_steps, store_as_dtype=store_as_dtype)

# Simulate the network
@@ -176,7 +176,7 @@ def tune(

self.requires_grad_(False) # Freeze the parameters

-    def equilibrate(self, edge_index: torch.Tensor, W: torch.Tensor, inital_state: torch.Tensor, n_steps=100, store_as_dtype: torch.dtype = torch.int) -> torch.Tensor:
+    def equilibrate(self, edge_index: torch.Tensor, W: torch.Tensor, initial_state: torch.Tensor, n_steps=100, store_as_dtype: torch.dtype = torch.int) -> torch.Tensor:
"""
Equilibrate the network to a given connectivity matrix.
@@ -186,7 +186,7 @@ def equilibrate(self, edge_index: torch.Tensor, W: torch.Tensor, inital_state: t
The connectivity of the network
W: torch.Tensor
The connectivity filter
-        inital_state: torch.Tensor
+        initial_state: torch.Tensor
The initial state of the network
n_steps: int
The number of time steps to equilibrate for
@@ -198,13 +198,14 @@ def equilibrate(self, edge_index: torch.Tensor, W: torch.Tensor, inital_state: t
x: torch.Tensor
The state of the network at each time step
"""
-        n_neurons = inital_state.shape[0]
-        device = inital_state.device
+        n_neurons = initial_state.shape[0]
+        device = initial_state.device
x_equi = torch.zeros((n_neurons, self.T + n_steps), device=device, dtype=store_as_dtype)
-        x_equi[:, self.T-1] = inital_state.squeeze()
+        activation_equi = initial_state

# Equilibrate the network
for t in range(self.T, self.T + n_steps):
-            x_equi[:, t] = self(edge_index=edge_index, W=W, state=x_equi[:, t-self.T:t])
+            x_equi[:, t] = self(edge_index=edge_index, W=W, state=activation_equi)
+            activation_equi = self.update_activation(spikes=x_equi[:, t:t+self.T], activation=activation_equi)

-        return x_equi[:, -self.T:]
+        return activation_equi
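
The practical upshot, sketched with a hypothetical model and network (mirroring the new test further down): equilibration steps now run before recording starts and are not included in the returned spike tensor.

.. code-block:: python

    X = model.simulate(network, n_steps=100, equilibration_steps=10)
    assert X.shape[1] == 100  # only the recorded steps are stored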
1 change: 1 addition & 0 deletions tests/conftest.py
@@ -96,6 +96,7 @@ def bernoulli_glm():
rel_ref_scale=7,
rel_ref_strength=-30.,
alpha=0.2,
+        r=1,
rng=rng,
)
return model
Binary file modified tests/test_data/stim_plan.pt
Binary file not shown.
2 changes: 1 addition & 1 deletion tests/test_models.py
@@ -87,7 +87,7 @@ def test_save_load(bernoulli_glm):
from spikeometric.models import BernoulliGLM
with NamedTemporaryFile() as f:
bernoulli_glm.save(f.name)
-        loaded_model = BernoulliGLM(1, 1, 1, 1, 1, 1, 1, 1, 1)
+        loaded_model = BernoulliGLM(1, 1, 1, 1, 1, 1, 1, 1, 1, 1)
loaded_model.load(f.name)
for param, loaded_param in zip(bernoulli_glm.parameters(), loaded_model.parameters()):
assert_close(param, loaded_param)
16 changes: 15 additions & 1 deletion tests/test_simulation.py
@@ -28,4 +28,18 @@ def test_storing_different_dtypes(bernoulli_glm, example_data):
import torch
for dtype in [torch.uint8, torch.int, torch.float, torch.double, torch.bool]:
X = bernoulli_glm.simulate(example_data, n_steps=1, verbose=False, store_as_dtype=dtype)
-        assert X.dtype == dtype
\ No newline at end of file
+        assert X.dtype == dtype

+@pytest.mark.parametrize(
+    "model,example_data",
+    [
+        (pytest.lazy_fixture('bernoulli_glm'), pytest.lazy_fixture('bernoulli_glm_network')),
+        (pytest.lazy_fixture('poisson_glm'), pytest.lazy_fixture('poisson_glm_network')),
+        (pytest.lazy_fixture('rectified_lnp'), pytest.lazy_fixture('rectified_lnp_network')),
+        (pytest.lazy_fixture('threshold_sam'), pytest.lazy_fixture('threshold_sam_network')),
+        (pytest.lazy_fixture('rectified_sam'), pytest.lazy_fixture('rectified_sam_network')),
+    ],
+)
+def test_does_not_store_equilibration_steps(model, example_data):
+    X = model.simulate(example_data, n_steps=100, verbose=False, equilibration_steps=10)
+    assert X.shape[1] == 100
