Skip to content

Commit

Permalink
Improve tests
Browse files Browse the repository at this point in the history
  • Loading branch information
bwohlberg committed Oct 15, 2024
1 parent f49430d commit 5c9c974
Showing 1 changed file with 17 additions and 2 deletions.
19 changes: 17 additions & 2 deletions scico/test/linop/xray/test_xray.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np

import jax
import jax.numpy as jnp

import pytest
Expand Down Expand Up @@ -76,8 +77,9 @@ def test_apply_adjoint():
@pytest.mark.parametrize("det_count_factor", [1.02 / np.sqrt(2.0), 1.0])
def test_fbp(dx, det_count_factor):
N = 256
x_gt = np.zeros((256, 256), dtype=np.float32)
x_gt[64:-64, 64:-64] = 1.0
x_gt = np.zeros((N, N), dtype=np.float32)
N4 = N // 4
x_gt[N4:-N4, N4:-N4] = 1.0

det_count = int(det_count_factor * N)
n_proj = 360
Expand All @@ -88,6 +90,19 @@ def test_fbp(dx, det_count_factor):
assert psnr(x_gt, x_fbp) > 28


def test_fbp_jit():
N = 64
x_gt = np.ones((N, N), dtype=np.float32)

det_count = N
n_proj = 90
angles = np.linspace(0, np.pi, n_proj, endpoint=False)
A = XRayTransform2D(x_gt.shape, angles, det_count=det_count)
y = A(x_gt)
fbp = jax.jit(A.fbp)
x_fbp = fbp(y)


def test_3d_scaling():
x = jnp.zeros((4, 4, 1))
x = x.at[1:3, 1:3, 0].set(1.0)
Expand Down

0 comments on commit 5c9c974

Please sign in to comment.