Modified equilibration for SA models to return only the final activation
JakobSonstebo committed Sep 17, 2023
1 parent c6b4961 commit a327de8
Showing 2 changed files with 26 additions and 11 deletions.
21 changes: 11 additions & 10 deletions spikeometric/models/sa_model.py
@@ -51,11 +51,11 @@ def simulate(self, data: Data, n_steps: int, verbose: bool =True, equilibration_
         device = edge_index.device
 
         # If verbose is True, a progress bar is shown
-        pbar = tqdm(range(n_steps + equilibration_steps), colour="#3E5641") if verbose else range(n_steps + equilibration_steps)
+        pbar = tqdm(range(n_steps), colour="#3E5641") if verbose else range(n_steps)
 
         # Initialize the state of the network
-        x = torch.zeros(n_neurons, n_steps + equilibration_steps, device=device, dtype=store_as_dtype)
-        initial_activation = torch.rand((n_neurons,1), device=device)
+        x = torch.zeros(n_neurons, n_steps, device=device, dtype=store_as_dtype)
+        initial_activation = torch.rand((n_neurons, T), device=device)
         activation = self.equilibrate(edge_index, W, initial_activation, equilibration_steps, store_as_dtype=store_as_dtype)
 
         # Simulate the network
@@ -176,7 +176,7 @@ def tune(
 
         self.requires_grad_(False) # Freeze the parameters
 
-    def equilibrate(self, edge_index: torch.Tensor, W: torch.Tensor, inital_state: torch.Tensor, n_steps=100, store_as_dtype: torch.dtype = torch.int) -> torch.Tensor:
+    def equilibrate(self, edge_index: torch.Tensor, W: torch.Tensor, initial_state: torch.Tensor, n_steps=100, store_as_dtype: torch.dtype = torch.int) -> torch.Tensor:
         """
         Equilibrate the network to a given connectivity matrix.
@@ -186,7 +186,7 @@ def equilibrate(self, edge_index: torch.Tensor, W: torch.Tensor, inital_state: t
            The connectivity of the network
        W: torch.Tensor
            The connectivity filter
-       inital_state: torch.Tensor
+       initial_state: torch.Tensor
            The initial state of the network
        n_steps: int
            The number of time steps to equilibrate for
@@ -198,13 +198,14 @@ def equilibrate(self, edge_index: torch.Tensor, W: torch.Tensor, inital_state: t
        x: torch.Tensor
            The state of the network at each time step
        """
-        n_neurons = inital_state.shape[0]
-        device = inital_state.device
+        n_neurons = initial_state.shape[0]
+        device = initial_state.device
         x_equi = torch.zeros((n_neurons, self.T + n_steps), device=device, dtype=store_as_dtype)
-        x_equi[:, self.T-1] = inital_state.squeeze()
+        activation_equi = initial_state
 
         # Equilibrate the network
         for t in range(self.T, self.T + n_steps):
-            x_equi[:, t] = self(edge_index=edge_index, W=W, state=x_equi[:, t-self.T:t])
+            x_equi[:, t] = self(edge_index=edge_index, W=W, state=activation_equi)
+            activation_equi = self.update_activation(spikes=x_equi[:, t:t+self.T], activation=activation_equi)
 
-        return x_equi[:, -self.T:]
+        return activation_equi
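For reference, here is the reworked equilibrate assembled from the hunks above (docstring elided; lines outside the hunks appear only as diff context and may differ slightly in the full file). The warm-up spikes are still buffered in x_equi, but the method now carries a running activation and returns only that final activation, which is why simulate above allocates only n_steps columns and no longer stores the equilibration period. Note that the Returns section of the docstring still describes the old return value.

def equilibrate(self, edge_index: torch.Tensor, W: torch.Tensor, initial_state: torch.Tensor, n_steps=100, store_as_dtype: torch.dtype = torch.int) -> torch.Tensor:
    n_neurons = initial_state.shape[0]
    device = initial_state.device
    x_equi = torch.zeros((n_neurons, self.T + n_steps), device=device, dtype=store_as_dtype)
    activation_equi = initial_state

    # Run the network forward for n_steps warm-up steps, updating the running activation
    for t in range(self.T, self.T + n_steps):
        x_equi[:, t] = self(edge_index=edge_index, W=W, state=activation_equi)
        activation_equi = self.update_activation(spikes=x_equi[:, t:t+self.T], activation=activation_equi)

    # Only the final activation is returned; the warm-up spikes are discarded
    return activation_equi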
16 changes: 15 additions & 1 deletion tests/test_simulation.py
@@ -28,4 +28,18 @@ def test_storing_different_dtypes(bernoulli_glm, example_data):
     import torch
     for dtype in [torch.uint8, torch.int, torch.float, torch.double, torch.bool]:
         X = bernoulli_glm.simulate(example_data, n_steps=1, verbose=False, store_as_dtype=dtype)
-        assert X.dtype == dtype
+        assert X.dtype == dtype
+
+@pytest.mark.parametrize(
+    "model,example_data",
+    [
+        (pytest.lazy_fixture('bernoulli_glm'), pytest.lazy_fixture('bernoulli_glm_network')),
+        (pytest.lazy_fixture('poisson_glm'), pytest.lazy_fixture('poisson_glm_network')),
+        (pytest.lazy_fixture('rectified_lnp'), pytest.lazy_fixture('rectified_lnp_network')),
+        (pytest.lazy_fixture('threshold_sam'), pytest.lazy_fixture('threshold_sam_network')),
+        (pytest.lazy_fixture('rectified_sam'), pytest.lazy_fixture('rectified_sam_network')),
+    ],
+)
+def test_does_not_store_equilibration_steps(model, example_data):
+    X = model.simulate(example_data, n_steps=100, verbose=False, equilibration_steps=10)
+    assert X.shape[1] == 100
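The new test pins down the storage contract after this change: equilibration steps are executed but never stored, so the returned spike train has exactly n_steps columns. The lazy_fixture references come from the pytest-lazy-fixture plugin. To run just this test locally, one option is a small driver script (a sketch, assuming pytest and that plugin are installed):

import sys
import pytest

# Select only the new parametrized test; -q keeps the output short.
sys.exit(pytest.main(["tests/test_simulation.py::test_does_not_store_equilibration_steps", "-q"]))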
