Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bugs introduced by BlockArray changes #282

Merged
merged 14 commits into from
May 6, 2022
Merged

Conversation

Michael-T-McCann
Copy link
Contributor

@Michael-T-McCann Michael-T-McCann commented Apr 28, 2022

Running the examples revealed several bugs introduced by the BlockArray changes in #259 . Specifically,

  • x.sum() cannot be used to fully reduce a BlockArray x because as per the docs, all BlockArray methods map over the blocks, except a few reductions in snp. snp.sum(x) is required instead.
  • snp.iscomplexobj cannot be used to used to check if a BlockArray is complex, because it maps over blocks. In the future, we may want mixed types in BlockArray so such a check cannot work. In the short term, snp.util.is_complex_dtype(x.dtype) will work.

This PR adds tests that would have caught these problems and fixes them.

Closes #281 .

@codecov
Copy link

codecov bot commented Apr 28, 2022

Codecov Report

Merging #282 (be51f8f) into main (4dc86b8) will not change coverage.
The diff coverage is 100.00%.

@@           Coverage Diff           @@
##             main     #282   +/-   ##
=======================================
  Coverage   94.15%   94.15%           
=======================================
  Files          49       49           
  Lines        3269     3269           
=======================================
  Hits         3078     3078           
  Misses        191      191           
Flag Coverage Δ
unittests 94.15% <100.00%> (ø)

Flags with carried forward coverage won't be shown. Click here to find out more.

Impacted Files Coverage Δ
scico/denoiser.py 89.02% <100.00%> (ø)
scico/functional/_indicator.py 100.00% <100.00%> (ø)
scico/functional/_norm.py 100.00% <100.00%> (ø)
scico/linop/_linop.py 98.09% <100.00%> (ø)
scico/loss.py 94.97% <100.00%> (ø)
scico/numpy/_wrapped_function_lists.py 100.00% <100.00%> (ø)
scico/optimize/pgm.py 96.46% <100.00%> (ø)
scico/solver.py 98.52% <100.00%> (ø)

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 4dc86b8...be51f8f. Read the comment docs.

@bwohlberg bwohlberg added the bug Something isn't working label Apr 28, 2022
@Michael-T-McCann Michael-T-McCann force-pushed the mike/BA-example-bug branch 3 times, most recently from 3402386 to 8c2cc64 Compare May 2, 2022 16:33
@tbalke
Copy link
Contributor

tbalke commented May 3, 2022

Running the examples now, I still get these issues (below).


denoise_tv_iso_pgm.py
Solving on CPU

Traceback (most recent call last):
  File "/Users/thilobalke/pythonModules/scico/examples/scripts/denoise_tv_iso_pgm.py", line 154, in <module>
    x = solver_iso.solve()
  File "/Users/thilobalke/pythonModules/scico/scico/optimize/pgm.py", line 533, in solve
    self.step()
  File "/Users/thilobalke/pythonModules/scico/scico/optimize/pgm.py", line 611, in step
    self.L = self.step_size.update(self.v)
  File "/Users/thilobalke/pythonModules/scico/scico/optimize/pgm.py", line 362, in update
    z = self.pgm.x_step(y, L)
  File "/Users/thilobalke/opt/anaconda3/envs/scico/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 165, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/thilobalke/opt/anaconda3/envs/scico/lib/python3.9/site-packages/jax/_src/api.py", line 429, in cache_miss
    out_flat = xla.xla_call(
  File "/Users/thilobalke/opt/anaconda3/envs/scico/lib/python3.9/site-packages/jax/core.py", line 1683, in bind
    return call_bind(self, fun, *args, **params)
  File "/Users/thilobalke/opt/anaconda3/envs/scico/lib/python3.9/site-packages/jax/core.py", line 1695, in call_bind
    outs = top_trace.process_call(primitive, fun, tracers, params)
  File "/Users/thilobalke/opt/anaconda3/envs/scico/lib/python3.9/site-packages/jax/core.py", line 594, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/Users/thilobalke/opt/anaconda3/envs/scico/lib/python3.9/site-packages/jax/_src/dispatch.py", line 142, in _xla_call_impl
    compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
  File "/Users/thilobalke/opt/anaconda3/envs/scico/lib/python3.9/site-packages/jax/linear_util.py", line 272, in memoized_fun
    ans = call(fun, *args)
  File "/Users/thilobalke/opt/anaconda3/envs/scico/lib/python3.9/site-packages/jax/_src/dispatch.py", line 169, in _xla_callable_uncached
    return lower_xla_callable(fun, device, backend, name, donated_invars,
  File "/Users/thilobalke/opt/anaconda3/envs/scico/lib/python3.9/site-packages/jax/_src/profiler.py", line 206, in wrapper
    return func(*args, **kwargs)
  File "/Users/thilobalke/opt/anaconda3/envs/scico/lib/python3.9/site-packages/jax/_src/dispatch.py", line 197, in lower_xla_callable
    jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(
  File "/Users/thilobalke/opt/anaconda3/envs/scico/lib/python3.9/site-packages/jax/_src/profiler.py", line 206, in wrapper
    return func(*args, **kwargs)
  File "/Users/thilobalke/opt/anaconda3/envs/scico/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 1680, in trace_to_jaxpr_final
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
  File "/Users/thilobalke/opt/anaconda3/envs/scico/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 1657, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
  File "/Users/thilobalke/opt/anaconda3/envs/scico/lib/python3.9/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/thilobalke/pythonModules/scico/scico/optimize/pgm.py", line 441, in x_step
    return self.g.prox(v - 1.0 / L * self.f.grad(v), 1.0 / L)
  File "/Users/thilobalke/pythonModules/scico/scico/functional/_functional.py", line 126, in grad
    return self._grad(x)
  File "/Users/thilobalke/pythonModules/scico/scico/_autograd.py", line 53, in conjugated_grad
    jg = jax_grad(*args, **kwargs)
  File "/Users/thilobalke/opt/anaconda3/envs/scico/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 165, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/thilobalke/opt/anaconda3/envs/scico/lib/python3.9/site-packages/jax/_src/api.py", line 993, in grad_f
    _, g = value_and_grad_f(*args, **kwargs)
  File "/Users/thilobalke/opt/anaconda3/envs/scico/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 165, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/thilobalke/opt/anaconda3/envs/scico/lib/python3.9/site-packages/jax/_src/api.py", line 1069, in value_and_grad_f
    ans, vjp_py = _vjp(f_partial, *dyn_args, reduce_axes=reduce_axes)
  File "/Users/thilobalke/opt/anaconda3/envs/scico/lib/python3.9/site-packages/jax/_src/api.py", line 2522, in _vjp
    out_primal, out_vjp = ad.vjp(
  File "/Users/thilobalke/opt/anaconda3/envs/scico/lib/python3.9/site-packages/jax/interpreters/ad.py", line 116, in vjp
    out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
  File "/Users/thilobalke/opt/anaconda3/envs/scico/lib/python3.9/site-packages/jax/interpreters/ad.py", line 103, in linearize
    jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
  File "/Users/thilobalke/opt/anaconda3/envs/scico/lib/python3.9/site-packages/jax/_src/profiler.py", line 206, in wrapper
    return func(*args, **kwargs)
  File "/Users/thilobalke/opt/anaconda3/envs/scico/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 522, in trace_to_jaxpr
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
  File "/Users/thilobalke/opt/anaconda3/envs/scico/lib/python3.9/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/thilobalke/opt/anaconda3/envs/scico/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 165, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/thilobalke/opt/anaconda3/envs/scico/lib/python3.9/site-packages/jax/_src/api.py", line 427, in cache_miss
    _check_arg(arg)
  File "/Users/thilobalke/opt/anaconda3/envs/scico/lib/python3.9/site-packages/jax/_src/api.py", line 2924, in _check_arg
    raise TypeError(f"Argument '{arg}' of type {type(arg)} is not a valid JAX type.")
jax._src.traceback_util.UnfilteredStackTrace: TypeError: Argument '<class '__main__.DualTVLoss'>
has_eval = True
has_prox = False
        ' of type <class '__main__.DualTVLoss'> is not a valid JAX type.

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/thilobalke/pythonModules/scico/examples/scripts/denoise_tv_iso_pgm.py", line 154, in <module>
    x = solver_iso.solve()
  File "/Users/thilobalke/pythonModules/scico/scico/optimize/pgm.py", line 533, in solve
    self.step()
  File "/Users/thilobalke/pythonModules/scico/scico/optimize/pgm.py", line 611, in step
    self.L = self.step_size.update(self.v)
  File "/Users/thilobalke/pythonModules/scico/scico/optimize/pgm.py", line 362, in update
    z = self.pgm.x_step(y, L)
  File "/Users/thilobalke/pythonModules/scico/scico/optimize/pgm.py", line 441, in x_step
    return self.g.prox(v - 1.0 / L * self.f.grad(v), 1.0 / L)
  File "/Users/thilobalke/pythonModules/scico/scico/functional/_functional.py", line 126, in grad
    return self._grad(x)
  File "/Users/thilobalke/pythonModules/scico/scico/_autograd.py", line 53, in conjugated_grad
    jg = jax_grad(*args, **kwargs)
TypeError: Argument '<class '__main__.DualTVLoss'>
has_eval = True
has_prox = False
        ' of type <class '__main__.DualTVLoss'> is not a valid JAX type.

sparsecode_poisson_pgm.py
Solving on CPU

============================================================
Running solver with step size of class:  PGMStepSize
L0 (Specifically chosen so that convergence occurs):  1000.0 

Iter  Time      Objective  L          Residual 
-----------------------------------------------
   0  2.48e-01  1.358e+03  1.000e+03  1.253e-01
  10  6.17e-01  1.339e+03  1.000e+03  1.304e-02
  20  7.74e-01  1.336e+03  1.000e+03  3.918e-03
  30  9.31e-01  1.336e+03  1.000e+03  1.620e-03
  40  1.09e+00  1.336e+03  1.000e+03  1.098e-03
  49  1.23e+00  1.336e+03  1.000e+03  5.685e-04
===================================================
Running solver with step size of class:  BBStepSize
L0 (Arbitrary Initialization):  90.0 

Iter  Time      Objective  L          Residual 
-----------------------------------------------
   0  3.71e-01  1.479e+03  9.000e+01  9.151e-01
  10  6.10e-01        nan  9.000e+01        nan
  20  8.14e-01        nan  9.000e+01        nan
  30  1.02e+00        nan  9.000e+01        nan
  40  1.22e+00        nan  9.000e+01        nan
  49  1.41e+00        nan  9.000e+01        nan
/Users/thilobalke/opt/anaconda3/envs/scico/lib/python3.9/site-packages/matplotlib/axes/_base.py:2475: UserWarning: Warning: converting a masked element to nan.
  xys = np.asarray(xys)
===========================================================
Running solver with step size of class:  AdaptiveBBStepSize
L0 (Arbitrary Initialization):  90.0 

Iter  Time      Objective  L          Residual 
-----------------------------------------------
   0  1.08e-01  1.479e+03  9.000e+01  9.151e-01
  10  3.02e-01        nan  9.000e+01        nan
  20  5.07e-01        nan  9.000e+01        nan
  30  7.09e-01        nan  9.000e+01        nan
  40  9.13e-01        nan  9.000e+01        nan
  49  1.11e+00        nan  9.000e+01        nan
/Users/thilobalke/opt/anaconda3/envs/scico/lib/python3.9/site-packages/matplotlib/axes/_base.py:2475: UserWarning: Warning: converting a masked element to nan.
  xys = np.asarray(xys)
===========================================================
Running solver with step size of class:  LineSearchStepSize
L0 (Arbitrary Initialization):  90.0 

Iter  Time      Objective  L          Residual 
-----------------------------------------------
   0  2.51e-01  1.351e+03  4.644e+02  2.699e-01
  10  5.05e-01  1.337e+03  4.644e+02  1.291e-02
  20  7.34e-01  1.336e+03  4.644e+02  3.084e-03
  30  9.65e-01  1.336e+03  4.644e+02  9.443e-04
  40  1.20e+00  1.336e+03  4.644e+02  8.498e-04
  49  1.42e+00  1.336e+03  5.573e+02  3.144e-04
=================================================================
Running solver with step size of class:  RobustLineSearchStepSize
L0 (Arbitrary Initialization):  90.0 

Iter  Time      Objective  L          Residual 
-----------------------------------------------
   0  1.57e-01  1.353e+03  6.480e+02  1.934e-01
  10  3.66e-01  1.342e+03  2.259e+02  3.863e-02
  20  5.97e-01  1.338e+03  3.151e+02  1.287e-02
  30  8.18e-01  1.337e+03  2.198e+02  1.610e-02
  40  1.04e+00  1.337e+03  1.532e+02  1.619e-02
  49  1.24e+00  1.336e+03  2.375e+02  6.250e-03

Waiting for input to close figures and exit

video_rpca_admm.py
Traceback (most recent call last):
  File "/Users/thilobalke/pythonModules/scico/examples/scripts/video_rpca_admm.py", line 56, in <module>
    A = linop.Sum(sum_axis=0, input_shape=(2,) + y.shape)
  File "/Users/thilobalke/pythonModules/scico/scico/linop/_linop.py", line 349, in __init__
    super().__init__(input_shape, input_dtype=input_dtype, jit=jit)  # type: ignore
  File "/Users/thilobalke/pythonModules/scico/scico/_generic_operators.py", line 453, in __init__
    super().__init__(
  File "/Users/thilobalke/pythonModules/scico/scico/_generic_operators.py", line 144, in __init__
    tmp = self(snp.zeros(self.input_shape, dtype=input_dtype))
  File "/Users/thilobalke/pythonModules/scico/scico/_generic_operators.py", line 580, in __call__
    return super().__call__(x)
  File "/Users/thilobalke/pythonModules/scico/scico/_generic_operators.py", line 198, in __call__
    return self._eval(x)
  File "/Users/thilobalke/pythonModules/scico/scico/linop/_linop.py", line 348, in <lambda>
    self._eval = lambda x: f(x, *args, **kwargs)
  File "/Users/thilobalke/pythonModules/scico/scico/numpy/_wrappers.py", line 137, in wrapped
    bound_args = sig.bind(*args, **kwargs)
  File "/Users/thilobalke/opt/anaconda3/envs/scico/lib/python3.9/inspect.py", line 3043, in bind
    return self._bind(args, kwargs)
  File "/Users/thilobalke/opt/anaconda3/envs/scico/lib/python3.9/inspect.py", line 3032, in _bind
    raise TypeError(
TypeError: got an unexpected keyword argument 'sum_axis'

@Michael-T-McCann Michael-T-McCann merged commit e2a9a4e into main May 6, 2022
@Michael-T-McCann Michael-T-McCann deleted the mike/BA-example-bug branch May 9, 2022 20:48
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working
Projects
None yet
Development

Successfully merging this pull request may close these issues.

iscomplexobj(BlockArray) after #259
3 participants