Simplify BlockArray implementation (#259)
Co-authored-by: Thilo Balke <[email protected]>
Michael-T-McCann and tbalke authored Apr 22, 2022
1 parent a1838ee commit 4381be5
Showing 55 changed files with 1,348 additions and 3,017 deletions.
4 changes: 2 additions & 2 deletions docs/source/notes.rst
@@ -185,7 +185,7 @@ When evaluating the gradient of ``f`` at 0, :func:`scico.grad` returns ``nan``:
>>> import scico
>>> import scico.numpy as snp
>>> f = lambda x: snp.linalg.norm(x)**2
>>> scico.grad(f)(snp.zeros(2)) #
>>> scico.grad(f)(snp.zeros(2, dtype=snp.float32)) #
DeviceArray([nan, nan], dtype=float32)

This can be fixed by defining the squared :math:`\ell_2` norm directly as
@@ -194,7 +194,7 @@ This can be fixed by defining the squared :math:`\ell_2` norm directly as
::

>>> g = lambda x: snp.sum(x**2)
>>> scico.grad(g)(snp.zeros(2))
>>> scico.grad(g)(snp.zeros(2, dtype=snp.float32))
DeviceArray([0., 0.], dtype=float32)

An alternative is to define a `custom derivative rule <https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html#enforcing-a-differentiation-convention>`_ to enforce a particular derivative convention at a point.
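A rough sketch of that approach, using a hypothetical ``safe_norm`` (not part of SCICO) whose custom JVP defines the derivative at the origin to be zero::

    import jax
    import jax.numpy as jnp

    @jax.custom_jvp
    def safe_norm(x):
        return jnp.linalg.norm(x)

    @safe_norm.defjvp
    def safe_norm_jvp(primals, tangents):
        (x,), (x_dot,) = primals, tangents
        y = safe_norm(x)
        # Substitute a safe denominator where y == 0 so no nan is produced.
        safe_y = jnp.where(y == 0, 1.0, y)
        y_dot = jnp.where(y == 0, 0.0, jnp.dot(x, x_dot) / safe_y)
        return y, y_dot

    jax.grad(lambda x: safe_norm(x) ** 2)(jnp.zeros(2, dtype=jnp.float32))
    # DeviceArray([0., 0.], dtype=float32)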
2 changes: 1 addition & 1 deletion docs/source/operator.rst
@@ -13,7 +13,7 @@ Each :class:`.Operator` object has an ``input_shape`` and ``output_shape``; these
The ``matrix_shape`` attribute describes the shape of the :class:`.LinearOperator` if it were to act on vectorized, or flattened, inputs.


For example, consider a two dimensional array :math:`\mb{x} \in \mathbb{R}^{n \times m}`.
For example, consider a two-dimensional array :math:`\mb{x} \in \mathbb{R}^{n \times m}`.
We compute the discrete differences of :math:`\mb{x}` in the horizontal and vertical directions,
generating two new arrays: :math:`\mb{x}_h \in \mathbb{R}^{n \times (m-1)}` and :math:`\mb{x}_v \in
\mathbb{R}^{(n-1) \times m}`. We represent this linear operator by
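A rough sketch of the shapes in this example, computed with ``jnp.diff`` rather than the SCICO operator classes::

    import jax.numpy as jnp

    n, m = 4, 5
    x = jnp.ones((n, m))
    x_h = jnp.diff(x, axis=1)  # horizontal differences, shape (n, m - 1)
    x_v = jnp.diff(x, axis=0)  # vertical differences, shape (n - 1, m)

    # Acting on flattened inputs, the stacked operator has matrix_shape
    # (n * (m - 1) + (n - 1) * m, n * m).
    matrix_shape = (x_h.size + x_v.size, x.size)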
12 changes: 7 additions & 5 deletions examples/scripts/denoise_tv_iso_pgm.py
@@ -1,4 +1,4 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# This file is part of the SCICO package. Details of the copyright
# and user license can be found in the 'LICENSE.txt' file distributed
@@ -39,8 +39,8 @@
import scico.numpy as snp
import scico.random
from scico import functional, linop, loss, operator, plot
from scico.array import ensure_on_device
from scico.blockarray import BlockArray
from scico.numpy import BlockArray
from scico.numpy.util import ensure_on_device
from scico.optimize.pgm import AcceleratedPGM, RobustLineSearchStepSize
from scico.typing import JaxArray
from scico.util import device_info
@@ -96,6 +96,7 @@ def __init__(
super().__init__(y=y, A=A, scale=1.0)
self.lmbda = lmbda

def __call__(self, x: Union[JaxArray, BlockArray]) -> float:

xint = self.y - self.lmbda * self.A(x)
@@ -117,14 +117,15 @@ class IsoProjector(functional.Functional):
def __call__(self, x: Union[JaxArray, BlockArray]) -> float:
return 0.0

def prox(self, v: JaxArray, lam: float, **kwargs) -> JaxArray:
norm_v_ptp = jnp.sqrt(jnp.sum(jnp.abs(v) ** 2, axis=0))

x_out = v / jnp.maximum(jnp.ones(v.shape), norm_v_ptp)
out1 = v[0, :, -1] / jnp.maximum(jnp.ones(v[0, :, -1].shape), jnp.abs(v[0, :, -1]))
x_out_1 = jax.ops.index_update(x_out, jax.ops.index[0, :, -1], out1)
x_out = x_out.at[0, :, -1].set(out1)
out2 = v[1, -1, :] / jnp.maximum(jnp.ones(v[1, -1, :].shape), jnp.abs(v[1, -1, :]))
x_out = jax.ops.index_update(x_out_1, jax.ops.index[1, -1, :], out2)
x_out = x_out.at[1, -1, :].set(out2)

return x_out

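# Sketch (not in the original script): the `.at` form used above is the
# functional replacement for the removed `jax.ops.index_update`; both return
# a new array and leave the input unchanged.
#   old: x_out = jax.ops.index_update(x_out, jax.ops.index[0, :, -1], out1)
#   new: x_out = x_out.at[0, :, -1].set(out1)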
4 changes: 2 additions & 2 deletions examples/scripts/sparsecode_poisson_pgm.py
@@ -22,7 +22,7 @@
$I(\mathbf{x}^{(0)} \geq 0)$ is the non-negative indicator.
This example also demonstrates the application of
[blockarray.BlockArray](../_autosummary/scico.blockarray.rst#scico.blockarray.BlockArray),
[numpy.BlockArray](../_autosummary/scico.numpy.rst#scico.numpy.BlockArray),
[functional.SeparableFunctional](../_autosummary/scico.functional.rst#scico.functional.SeparableFunctional),
and
[functional.ZeroFunctional](../_autosummary/scico.functional.rst#scico.functional.ZeroFunctional)
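A minimal sketch of how these pieces compose, assuming the post-change `snp.blockarray` constructor and the `functional.NonNegativeIndicator` class (neither is shown in this hunk):

    import scico.numpy as snp
    from scico import functional

    # One functional per block: the non-negative indicator acts on the first
    # block and the zero functional on the second.
    g = functional.SeparableFunctional(
        [functional.NonNegativeIndicator(), functional.ZeroFunctional()]
    )
    x = snp.blockarray([snp.ones((3,)), snp.zeros((2, 2))])
    g(x)  # sum of the per-block values -> 0.0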
@@ -40,7 +40,7 @@
import scico.numpy as snp
import scico.random
from scico import functional, loss, plot
from scico.blockarray import BlockArray
from scico.numpy import BlockArray
from scico.operator import Operator
from scico.optimize.pgm import (
AcceleratedPGM,
2 changes: 1 addition & 1 deletion scico/_flax.py
@@ -17,7 +17,7 @@
from flax.core import Scope # noqa
from flax.linen.module import _Sentinel # noqa

from scico.blockarray import BlockArray
from scico.numpy import BlockArray
from scico.typing import JaxArray

# The imports of Scope and _Sentinel (above) and the definition of Module
14 changes: 7 additions & 7 deletions scico/_generic_operators.py
@@ -23,8 +23,8 @@

import scico.numpy as snp
from scico._autograd import linear_adjoint
from scico.array import is_complex_dtype, is_nested
from scico.blockarray import BlockArray, block_sizes
from scico.numpy import BlockArray
from scico.numpy.util import is_complex_dtype, is_nested, shape_to_size
from scico.typing import BlockShape, DType, JaxArray, Shape


@@ -152,8 +152,8 @@ def __init__(
# Determine the shape of the "vectorized" operator (as an element of ℝ^{n × m}).
# If the function returns a BlockArray we need to compute the size of each block,
# then sum.
self.input_size = int(np.sum(block_sizes(self.input_shape)))
self.output_size = int(np.sum(block_sizes(self.output_shape)))
self.input_size = shape_to_size(self.input_shape)
self.output_size = shape_to_size(self.output_shape)

self.shape = (self.output_shape, self.input_shape)
self.matrix_shape = (self.output_size, self.input_size)
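# Sketch of the behavior `shape_to_size` needs (not the actual
# implementation): total the element counts, handling nested (block) shapes.
#
#     def shape_to_size(shape):
#         if any(isinstance(s, tuple) for s in shape):  # block shape
#             return sum(int(np.prod(s)) for s in shape)
#         return int(np.prod(shape))
#
#     shape_to_size((3, 4))          # 12
#     shape_to_size(((3, 4), (5,)))  # 17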
@@ -320,8 +320,8 @@ def freeze(self, argnum: int, val: Union[JaxArray, BlockArray]) -> Operator:
def concat_args(args):
# Creates a blockarray with args and the frozen value in the correct place
# Eg if this operator takes a blockarray with two blocks, then
# concat_args(args) = BlockArray.array([val, args]) if argnum = 0
# concat_args(args) = BlockArray.array([args, val]) if argnum = 1
# concat_args(args) = snp.blockarray([val, args]) if argnum = 0
# concat_args(args) = snp.blockarray([args, val]) if argnum = 1

if isinstance(args, (DeviceArray, np.ndarray)):
# In the case that the original operator takes a blockarray with two
@@ -336,7 +336,7 @@ def concat_args(args):
arg_list.append(args[i - 1])
else:
arg_list.append(val)
return BlockArray.array(arg_list)
return snp.blockarray(arg_list)

return Operator(
input_shape=input_shape,
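A rough illustration of what ``concat_args`` builds for a two-block operator (illustrative shapes; ``argnum == 0`` pins the frozen value as the first block):

    import scico.numpy as snp

    val = snp.ones((2, 2))   # the frozen argument
    args = snp.zeros((3,))   # the remaining free argument
    x = snp.blockarray([val, args])  # concat_args(args) when argnum == 0
    # The original operator then receives the full two-block input x.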