-
Notifications
You must be signed in to change notification settings - Fork 0
/
place_cells.py
executable file
·135 lines (102 loc) · 5.09 KB
/
place_cells.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
# -*- coding: utf-8 -*-
import numpy as np
import scipy
import scipy.interpolate
import tensorflow as tf
class PlaceCells(object):
    '''Population of Np place cells with randomly placed centers in a box.

    Activations are Gaussian (optionally difference-of-Gaussians) functions of
    the distance from a position to each cell center, normalized with softmax.
    Distances respect the configured topology: 'torus', 'klein', or (any other
    value) the plain open box.
    '''

    def __init__(self, options):
        '''
        Store model options and tile place cell centers across the environment.

        Args:
            options: Object exposing the attributes Np, place_cell_rf,
                surround_scale, box_width, box_height, DoG and topology.
        '''
        self.Np = options.Np
        self.sigma = options.place_cell_rf
        self.surround_scale = options.surround_scale
        self.box_width = options.box_width
        self.box_height = options.box_height
        self.DoG = options.DoG
        self.topology = options.topology

        # Randomly tile place cell centers across environment.
        # NOTE: this sets TensorFlow's *global* seed so the centers are
        # reproducible; it also affects tf.random ops run after construction.
        tf.random.set_seed(0)
        usx = tf.random.uniform((self.Np,), -self.box_width/2, self.box_width/2, dtype=tf.float64)
        usy = tf.random.uniform((self.Np,), -self.box_height/2, self.box_height/2, dtype=tf.float64)
        # Centers, shape [Np, 2], columns (x, y).
        self.us = tf.stack([usx, usy], axis=-1)

    def get_activation(self, pos):
        '''
        Get place cell activations for a given position.

        Args:
            pos: 2d position of shape [batch_size, sequence_length, 2]. It is
                cast to the dtype of the centers (float64 as created here).

        Returns:
            outputs: Place cell activations with shape
                [batch_size, sequence_length, Np]. Along the last axis they are
                non-negative and sum to 1.
        '''
        # Match the centers' dtype so float32 tensors don't fail on subtraction.
        pos = tf.cast(pos, self.us.dtype)
        # Per-axis absolute displacement to every center: [batch, seq, Np, 2].
        d = tf.abs(pos[:, :, tf.newaxis, :] - self.us[tf.newaxis, tf.newaxis, ...])
        dx = d[..., 0]
        dy = d[..., 1]

        if self.topology == 'torus':
            # Periodic in both axes: take the shorter way around each axis.
            dx = tf.minimum(dx, self.box_width - dx)
            dy = tf.minimum(dy, self.box_height - dy)
            norm2 = dx**2 + dy**2
        elif self.topology == 'klein':
            # Direct image of each center: periodic in x only.
            dxp = tf.minimum(dx, self.box_width - dx)
            norm2_direct = dxp**2 + dy**2
            # Second image: the center mirrored to -x, reached across the y
            # boundary (complementary y distance), still periodic in x.
            dyk = self.box_height - dy
            mirrored_x = -self.us[:, 0]
            dxk = tf.abs(pos[:, :, tf.newaxis, 0] - mirrored_x[tf.newaxis, tf.newaxis, :])
            dxk = tf.minimum(dxk, self.box_width - dxk)
            norm2_mirror = dxk**2 + dyk**2
            # Use whichever image is closer.
            norm2 = tf.minimum(norm2_direct, norm2_mirror)
        else:
            # Open box: plain squared Euclidean distance.
            norm2 = dx**2 + dy**2

        # Normalize place cell outputs with prefactor alpha=1/2/np.pi/self.sigma**2,
        # or, simply normalize with softmax, which yields same normalization on
        # average and seems to speed up training.
        outputs = tf.nn.softmax(-norm2 / (2 * self.sigma**2))

        if self.DoG:
            # Again, normalize with prefactor
            # beta=1/2/np.pi/self.sigma**2/self.surround_scale, or use softmax.
            outputs -= tf.nn.softmax(-norm2 / (2 * self.surround_scale * self.sigma**2))
            # The difference of two softmaxes sums to zero, so its minimum is
            # <= 0; subtracting it makes every output non-negative. Then
            # rescale so the outputs sum to 1 (and hence lie in [0, 1]).
            outputs -= tf.reduce_min(outputs, axis=-1, keepdims=True)
            outputs /= tf.reduce_sum(outputs, axis=-1, keepdims=True)
        return outputs

    def get_nearest_cell_pos(self, activation, k=3):
        '''
        Decode position using centers of k maximally active place cells.

        Args:
            activation: Place cell activations of shape [batch_size, sequence_length, Np].
            k: Number of maximally active place cells with which to decode position.

        Returns:
            pred_pos: Predicted 2d position with shape [batch_size, sequence_length, 2],
                the unweighted mean of the k selected centers.
        '''
        _, idxs = tf.math.top_k(activation, k=k)
        # gather -> [batch, seq, k, 2]; average over the k centers.
        pred_pos = tf.reduce_mean(tf.gather(self.us, idxs), axis=-2)
        return pred_pos

    def grid_pc(self, pc_outputs, res=32):
        '''
        Interpolate place cell outputs onto a regular res x res grid.

        Args:
            pc_outputs: Activations whose total size is a multiple of Np; all
                leading axes are flattened to rows of length Np. A tf.Tensor
                or anything np.asarray accepts.
            res: Number of grid points along each side of the box.

        Returns:
            pc: NumPy array of shape [T, res, res], where T is the number of
                flattened rows of pc_outputs. Each map is flipped vertically so
                row 0 corresponds to y = +box_height/2. Grid points outside the
                convex hull of the centers get griddata's fill value (NaN by
                default).
        '''
        coordsx = np.linspace(-self.box_width/2, self.box_width/2, res)
        coordsy = np.linspace(-self.box_height/2, self.box_height/2, res)
        grid_x, grid_y = np.meshgrid(coordsx, coordsy)
        # Query points, shape [res*res, 2], columns (x, y).
        grid = np.stack([grid_x.ravel(), grid_y.ravel()], axis=-1)

        # Convert to numpy (works for eager tf.Tensors and plain arrays alike).
        us_np = np.asarray(self.us)
        frames = np.asarray(pc_outputs).reshape(-1, self.Np)
        T = frames.shape[0]  # number of activation vectors to interpolate

        pc = np.zeros([T, res, res])
        for i, frame in enumerate(frames):
            gridval = scipy.interpolate.griddata(us_np, frame, grid)
            pc[i] = np.flip(gridval.reshape([res, res]), axis=0)
        return pc

    def compute_covariance(self, res=30):
        '''
        Compute the translation-averaged spatial covariance of place cell outputs.

        Activations are evaluated on a res x res grid over the box. For every
        reference location, its covariance map (dot products with all other
        locations) is rolled so the reference sits at index (0, 0). The maps
        are accumulated, then rolled so the origin lands at (res//2, res//2).

        Args:
            res: Number of grid points along each side of the box.

        Returns:
            Cmean: NumPy array of shape [res, res]. Note this is the *sum*
                over the res*res reference locations, not divided by their count.
        '''
        xs = np.linspace(-self.box_width/2, self.box_width/2, res)
        ys = np.linspace(-self.box_height/2, self.box_height/2, res)
        # pos[i, j] = (xs[i], ys[j]); shape [res, res, 2], fed to
        # get_activation as [batch, seq, 2]. No float32 down-cast:
        # get_activation casts to the centers' dtype itself.
        pos = np.array(np.meshgrid(xs, ys)).T

        pc_outputs = tf.reshape(self.get_activation(pos), (-1, self.Np))

        C = pc_outputs @ tf.transpose(pc_outputs)
        # Convert once instead of once per loop iteration below.
        Csquare = tf.reshape(C, (res, res, res, res)).numpy()

        Cmean = np.zeros([res, res])
        for i in range(res):
            for j in range(res):
                Cmean += np.roll(np.roll(Csquare[i, j], -i, axis=0), -j, axis=1)

        Cmean = np.roll(np.roll(Cmean, res//2, axis=0), res//2, axis=1)
        return Cmean