Skip to content

Commit

Permalink
matlab: modernize
Browse files Browse the repository at this point in the history
  • Loading branch information
scivision committed Mar 26, 2024
1 parent a218ae3 commit 7dd757b
Show file tree
Hide file tree
Showing 9 changed files with 52 additions and 44 deletions.
6 changes: 4 additions & 2 deletions +airtools/kaczmarz.m
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,10 @@
end

% The sizes of A, b and x must match.
if size(b,1) ~= m || size(b,2) ~= 1
error('The sizes of A and b do not match')
if size(b,1) ~= m
error("rows of A " + int2str(m) + " and columns of b " + int2str(size(b,1)) + " do not match")
elseif size(b,2) ~= 1
error('b must be a column vector')
elseif size(x0,1) ~= n || size(x0,2) ~= 1
error('The size of x0 does not match the problem')
end
Expand Down
6 changes: 3 additions & 3 deletions +airtools/logmart.m
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@
y(y<=1e-8) = 1e-8;

if isempty(x0)
x=(A'*y)./sum(A(:));
xA=A*x;
x=x.*max(y(:))/max(xA(:));
x = (A'*y) ./ sum(A(:));
xA = A*x;
x = x.*max(y(:)) / max(xA(:));
else
x=x0;
end
Expand Down
4 changes: 2 additions & 2 deletions +airtools/maxent.m
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
tau0 = 1e-3; % Initial threshold used in secant root finder.

% Initialization.
[m,n] = size(A); x_lambda = zeros(n,length(lambda)); F = zeros(maxit,1);
n = size(A,2); x_lambda = zeros(n,length(lambda)); F = zeros(maxit,1);
if (min(lambda) <= 0)
error('Regularization parameter lambda must be positive')
end
Expand All @@ -51,7 +51,7 @@

% Start the nonlinear CG iteration here.
delta_x = x; dF = 1; it = 0; phi0 = p'*g;
while (norm(delta_x) > minstep*norm(x) && dF > flat && it < maxit && phi0 < 0)
while (all(norm(delta_x) > minstep*norm(x), 'all') && dF > flat && it < maxit && all(phi0 < 0, 'all'))
it = it + 1;

% Compute some CG quantities.
Expand Down
29 changes: 18 additions & 11 deletions +airtools/unitTest.m
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ function test_condition(tc, name)
%% well-posed, well-conditioned problem?
A = tc.TestData.(name).A;
b = tc.TestData.(name).b;
[U,s] = csvd(A);
picard(U, s, b);
[U,s] = airtools.csvd(A);
airtools.picard(U, s, b);
disp("Condition #: " + num2str(cond(A)))
end

Expand All @@ -70,7 +70,7 @@ function test_pseudoinv(tc, name)
x_true = tc.TestData.(name).x_true;

x_pinv = pinv(A)*b;
tc.verifyEqual(x_pinv, x_true, 'RelTol', 0.005)
tc.verifyEqual(x_pinv, x_true, RelTol=0.005)
end

function test_logmart(tc, name)
Expand All @@ -80,19 +80,26 @@ function test_logmart(tc, name)

tc.assumeGreaterThanOrEqual(b, 0)

x_logmart = logmart(b,A);
tc.verifyEqual(x_logmart, x_true, 'RelTol', 0.1)
x_logmart = airtools.logmart(b,A);
tc.verifyEqual(x_logmart, x_true, RelTol=0.1)
end

function test_maxent(tc, name)
A = tc.TestData.(name).A;
b = tc.TestData.(name).b;
x_true = tc.TestData.(name).x_true;
% x_python = py.airtools.maxent.maxent(A,b,0.00002)
% py.numpy.testing.assert_array_almost_equal(x_python,x_true)

x_maxent = maxent(A,b,0.001);
tc.verifyEqual(x_maxent, x_true, 'RelTol', 0.05, 'maxent')
x_maxent = airtools.maxent(A,b,0.001);
tc.verifyEqual(x_maxent, x_true, RelTol=0.05)
end

function test_maxent_python(tc, name)
A = tc.TestData.(name).A;
b = tc.TestData.(name).b;
x_true = tc.TestData.(name).x_true;

x_python = py.airtools.maxent.maxent(A,b,0.00002)
tc.verifyEqual(x_python, x_true)
end

function test_kart(tc,name)
Expand All @@ -102,8 +109,8 @@ function test_kart(tc,name)
% x_python = py.airtools.kaczmarz.kaczmarz(A,b,200)[0]
% py.numpy.testing.assert_array_almost_equal(x_python,x_true)
%
x_kaczmarz = kaczmarz(A,b,250);
tc.verifyEqual(x_kaczmarz, x_true, 'RelTol', 0.05, 'kaczmarz ART')
x_kaczmarz = airtools.kaczmarz(A,b,250);
tc.verifyEqual(x_kaczmarz, x_true, RelTol=0.05)
end

end
Expand Down
1 change: 0 additions & 1 deletion src/airtools/maxent.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ def maxent(A, b, lamb, w=None, x0=None) -> tuple:

# Treat each lambda separately.
for j in range(Nlambda):

# Prepare for nonlinear CG iteration.
l2 = lamb[j] ** 2.0
x = x0
Expand Down
1 change: 0 additions & 1 deletion src/airtools/picard.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@


def picard(U, s, b, d=0) -> tuple:

n, ps = np.atleast_2d(s).T.shape

beta = np.abs(np.asfortranarray(U[:, :n]).T.dot(b))
Expand Down
4 changes: 3 additions & 1 deletion src/airtools/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ def matrices(name: str) -> np.ndarray:
[0.126491, 0.357771, 1.41421, 4.0],
]
),
"fiedler": np.array([[0, 1, 2, 3], [1, 0, 1, 2], [2, 1, 0, 1], [3, 2, 1, 0]]),
"fiedler": np.array(
[[0, 1, 2, 3], [1, 0, 1, 2], [2, 1, 0, 1], [3, 2, 1, 0]], dtype=np.float64
),
"hilbert": np.array(
[
[1.0, 1 / 2, 1 / 3, 1 / 4],
Expand Down
13 changes: 0 additions & 13 deletions src/airtools/tests/matlab_engine.py

This file was deleted.

32 changes: 22 additions & 10 deletions src/airtools/tests/test_matlab.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from pathlib import Path
import functools

import numpy as np

Expand All @@ -8,7 +9,7 @@
import airtools

try:
from .matlab_engine import matlab_engine
import matlab.engine

matlab_skip = False
except (ImportError, RuntimeError):
Expand All @@ -26,18 +27,27 @@
used = ("identity", "fiedler")


@functools.cache
def matlab_engine():
"""
cached to use in multiple tests without restarting
"""
eng = matlab.engine.start_matlab("-nojvm")
eng.addpath(eng.genpath(str(Rmatlab)), nargout=0)
return eng


@pytest.mark.skipif(matlab_skip, reason="Matlab Engine not available")
@pytest.mark.parametrize("name", used)
def test_maxent(matrices, name):

eng = matlab_engine()
eng.addpath(eng.genpath(str(Rmatlab)), nargout=0)

A = matrices
b = A @ x
lamb = 2.5e-5

x_matlab = eng.airtools.maxent(A, b, lamb).squeeze()
x_matlab = eng.airtools.maxent(A, b, lamb)
x_matlab = np.asarray(x_matlab).squeeze()

assert x_matlab == approx(x, rel=0.01)

Expand All @@ -48,17 +58,20 @@ def test_maxent(matrices, name):
@pytest.mark.skipif(matlab_skip, reason="Matlab Engine not available")
@pytest.mark.parametrize("name", used)
def test_kaczmarz(matrices, name):

eng = matlab_engine()
eng.addpath(eng.genpath(str(Rmatlab)), nargout=0)

A = matrices
b = A @ x
max_iter = 200
lamb = 1.0
x0 = np.zeros_like(x)

x_matlab = eng.airtools.kaczmarz(A, b, max_iter, x0, {"lambda": lamb}).squeeze()
print("kaczmarz: A.shape", A.shape)
print("kaczmarz: b.shape", b.shape)
print("kaczmarz: x0.shape", x0.shape)

x_matlab = eng.airtools.kaczmarz(A, b[:, None], max_iter, x0[:, None], {"lambda": lamb})
x_matlab = np.asarray(x_matlab).squeeze()
assert x_matlab == approx(x, rel=0.01)

x_est = airtools.kaczmarz(A, b, x0=x0, max_iter=max_iter, lamb=lamb)[0]
Expand All @@ -68,9 +81,7 @@ def test_kaczmarz(matrices, name):
@pytest.mark.skipif(matlab_skip, reason="Matlab Engine not available")
@pytest.mark.parametrize("name", used)
def test_logmart(matrices, name):

eng = matlab_engine()
eng.addpath(eng.genpath(str(Rmatlab)), nargout=0)

A = matrices

Expand All @@ -79,7 +90,8 @@ def test_logmart(matrices, name):
max_iter = 2000
sigma = 1.0

x_matlab = eng.airtools.logmart(b, A, relax, [], sigma, max_iter).squeeze()
x_matlab = eng.airtools.logmart(b, A, relax, eng.double.empty(), sigma, max_iter)
x_matlab = np.asarray(x_matlab).squeeze()
assert x_matlab == approx(x, rel=0.01)

x_est = airtools.logmart(A, b, relax=relax, sigma=sigma, max_iter=max_iter)[0]
Expand Down

0 comments on commit 7dd757b

Please sign in to comment.