Skip to content

Commit

Permalink
fix another delay related bug
Browse files Browse the repository at this point in the history
  • Loading branch information
maedoc committed Jul 2, 2024
1 parent 22a4360 commit 4aa15bb
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion vbjax/loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ def make_continuation(run_chunk, chunk_len, max_lag, n_from, n_svar, stochastic=
from vbjax import randn

# need to be compile time constants for dynamic_*
i0 = chunk_len - 1
i0 = chunk_len
i1 = max_lag + 1

@jax.jit
Expand Down
6 changes: 3 additions & 3 deletions vbjax/tests/test_loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,12 @@ def f(buf, x, t, p):
nh = 16
dt = 0.1
clen = 30
buf = jp.zeros((nh + clen, 2, 1)) + 1.0
buf = jp.zeros((nh + 1 + clen, 2, 1)) + 1.0
p = vb.mpr_default_theta._replace(eta=-1.0)._replace(tau=3.0)

# run it
_, loop = vb.make_sdde(dt, nh, f, 0.0)
cc = vb.make_continuation(loop, buf.shape[0] - nh, nh, 1, 1, stochastic=False)
cc = vb.make_continuation(loop, clen, nh, 1, 1, stochastic=False)
# b, xs = jax.lax.scan(lambda buf,key: cc(buf,p,key), buf, vb.keys[:3])
xs = []
for i in range(3):
Expand All @@ -102,7 +102,7 @@ def f(buf, x, t, p):
xs = np.array(xs)

numpy.testing.assert_allclose(
loop(jp.zeros((nh + clen*3, 2, 1)) + 1.0, p)[1][:-2, 1, 0],
loop(jp.zeros((nh + 1 + clen*3, 2, 1)) + 1.0, p)[1][:, 1, 0],
xs.reshape(-1, 2)[:, 1],
1e-6, 2e-5
)

0 comments on commit 4aa15bb

Please sign in to comment.