Skip to content

Commit

Permalink
add access to linear residual
Browse files Browse the repository at this point in the history
  • Loading branch information
gardner48 committed May 17, 2024
1 parent a5bf7ac commit 1b3805b
Show file tree
Hide file tree
Showing 7 changed files with 52 additions and 12 deletions.
12 changes: 9 additions & 3 deletions examples/arkode/C_serial/ark_analytic.c
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,21 @@ static int f(sunrealtype t, N_Vector y, N_Vector ydot, void* user_data);
static int Jac(sunrealtype t, N_Vector y, N_Vector fy, SUNMatrix J,
void* user_data, N_Vector tmp1, N_Vector tmp2, N_Vector tmp3);

/* >>> Function to access delta <<< */
int Access(long int step, int stage, int iter, N_Vector delta, void* user_data)
/* >>> Function to access delta and linear residuals <<< */
int Access(long int step, int stage, int iter, N_Vector delta,
N_Vector ls_res, N_Vector ls_res_r, void* user_data)
{
printf("Step %li = \n", step);
printf("Stage %i = \n", stage);
printf("Iter %i = \n", iter);
printf("Delta:\n");
N_VPrintFile(delta, stdout);
printf("LS Res:\n");
N_VPrintFile(ls_res, stdout);
printf("LS Res Relaxed:\n");
N_VPrintFile(ls_res_r, stdout);
printf("\n");
return 0;
}

/* Private function to check function return values */
Expand Down Expand Up @@ -142,7 +148,7 @@ int main(void)
if (check_flag(&flag, "ARKodeSetLinear", 1)) { return 1; }

/* >>> Attach function to access delta <<< */
flag = ARKStepSetAccessDeltaFn(arkode_mem, Access);
flag = ARKStepSetAccessFn(arkode_mem, Access);
if (check_flag(&flag, "ARKStepSetAccessDeltaFn", 1)) { return 1; }

/* Open output stream for results, output comment line */
Expand Down
7 changes: 4 additions & 3 deletions include/arkode/arkode_arkstep.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,9 @@ extern "C" {
#endif

/* Callback to access delta */
typedef int (*ARKStepAccessDeltaFn)(long int step, int stage, int iter,
N_Vector delta, void* user_data);
typedef int (*ARKStepAccessFn)(long int step, int stage, int iter,
N_Vector delta, N_Vector ls_res,
N_Vector ls_res2, void* user_data);

/* -----------------
* ARKStep Constants
Expand Down Expand Up @@ -106,7 +107,7 @@ SUNDIALS_EXPORT int ARKStepCreateMRIStepInnerStepper(void* arkode_mem,
MRIStepInnerStepper* stepper);

SUNDIALS_EXPORT
int ARKStepSetAccessDeltaFn(void* arkode_mem, ARKStepAccessDeltaFn access_fn);
int ARKStepSetAccessFn(void* arkode_mem, ARKStepAccessFn access_fn);

/* --------------------------------------------------------------------------
* Deprecated Functions -- all are superseded by shared ARKODE-level routines
Expand Down
2 changes: 1 addition & 1 deletion src/arkode/arkode_arkstep_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ typedef struct ARKodeARKStepMemRec
sunrealtype* stage_times; /* workspace for applying forcing */
sunrealtype* stage_coefs; /* workspace for applying forcing */

ARKStepAccessDeltaFn access_fn;
ARKStepAccessFn access_fn;

}* ARKodeARKStepMem;

Expand Down
2 changes: 1 addition & 1 deletion src/arkode/arkode_arkstep_io.c
Original file line number Diff line number Diff line change
Expand Up @@ -2408,7 +2408,7 @@ int ARKStepGetNumRelaxSolveIters(void* arkode_mem, long int* iters)
EOF
===============================================================*/

int ARKStepSetAccessDeltaFn(void* arkode_mem, ARKStepAccessDeltaFn access_fn)
int ARKStepSetAccessFn(void* arkode_mem, ARKStepAccessFn access_fn)
{
ARKodeMem ark_mem;
ARKodeARKStepMem step_mem;
Expand Down
8 changes: 7 additions & 1 deletion src/arkode/arkode_arkstep_nls.c
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,19 @@

int arkAccessDeltaFn(int iter, N_Vector delta, void* arkode_mem)
{
int retval = 0;
ARKodeMem ark_mem = (ARKodeMem)arkode_mem;

ARKodeARKStepMem step_mem;
int retval = arkStep_AccessStepMem(ark_mem, __func__, &step_mem);
retval = arkStep_AccessStepMem(ark_mem, __func__, &step_mem);
if (retval != ARK_SUCCESS) { return retval; }

ARKLsMem arkls_mem;
retval = arkLs_AccessLMem(ark_mem, __func__, &arkls_mem);
if (retval != ARK_SUCCESS) { return (retval); }

retval = step_mem->access_fn(ark_mem->nst, step_mem->istage, iter, delta,
arkls_mem->ytemp, arkls_mem->ytemp2,
ark_mem->user_data);
if (retval != ARK_SUCCESS) { return retval; }

Expand Down
32 changes: 29 additions & 3 deletions src/arkode/arkode_ls.c
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@

#include "arkode_impl.h"
#include "arkode_ls_impl.h"
#include "sundials/sundials_matrix.h"
#include "sundials/sundials_nvector.h"

/* constants */
#define MIN_INC_MULT SUN_RCONST(1000.0)
Expand Down Expand Up @@ -264,6 +266,15 @@ int ARKodeSetLinearSolver(void* arkode_mem, SUNLinearSolver LS, SUNMatrix A)
return (ARKLS_MEM_FAIL);
}

if (!arkAllocVec(ark_mem, ark_mem->tempv1, &(arkls_mem->ytemp2)))
{
arkProcessError(ark_mem, ARKLS_MEM_FAIL, __LINE__, __func__, __FILE__,
MSG_LS_MEM_FAIL);
free(arkls_mem);
arkls_mem = NULL;
return (ARKLS_MEM_FAIL);
}

if (!arkAllocVec(ark_mem, ark_mem->tempv1, &(arkls_mem->x)))
{
arkProcessError(ark_mem, ARKLS_MEM_FAIL, __LINE__, __func__, __FILE__,
Expand Down Expand Up @@ -3383,9 +3394,12 @@ int arkLsSolve(ARKodeMem ark_mem, N_Vector b, sunrealtype tnow, N_Vector ynow,
}
}

/* Call solver, and copy x to b */
/* Call solver */
retval = SUNLinSolSolve(arkls_mem->LS, arkls_mem->A, arkls_mem->x, b, delta);
N_VScale(ONE, arkls_mem->x, b);

/* compute the residual r = Ax - b */
SUNMatMatvec(arkls_mem->savedJ, arkls_mem->x, arkls_mem->ytemp);
N_VLinearSum(ONE, arkls_mem->ytemp, -ONE, b, arkls_mem->ytemp);

/* If using a direct or matrix-iterative solver, scale the correction to
account for change in gamma (this is only beneficial if M==I) */
Expand All @@ -3399,9 +3413,16 @@ int arkLsSolve(ARKodeMem ark_mem, N_Vector b, sunrealtype tnow, N_Vector ynow,
__FILE__, "An error occurred in ark_step_getgammas");
return (arkls_mem->last_flag);
}
if (gamrat != ONE) { N_VScale(TWO / (ONE + gamrat), b, b); }
if (gamrat != ONE) { N_VScale(TWO / (ONE + gamrat), arkls_mem->x, arkls_mem->x); }
}

/* compute the relaxed residual r = Ax - b */
SUNMatMatvec(arkls_mem->savedJ, arkls_mem->x, arkls_mem->ytemp2);
N_VLinearSum(ONE, arkls_mem->ytemp2, -ONE, b, arkls_mem->ytemp2);

/* copy x to b */
N_VScale(ONE, arkls_mem->x, b);

/* Retrieve statistics from iterative linear solvers */
resnorm = ZERO;
nli_inc = 0;
Expand Down Expand Up @@ -3498,6 +3519,11 @@ int arkLsFree(ARKodeMem ark_mem)
N_VDestroy(arkls_mem->ytemp);
arkls_mem->ytemp = NULL;
}
if (arkls_mem->ytemp2)
{
N_VDestroy(arkls_mem->ytemp2);
arkls_mem->ytemp2 = NULL;
}
if (arkls_mem->x)
{
N_VDestroy(arkls_mem->x);
Expand Down
1 change: 1 addition & 0 deletions src/arkode/arkode_ls_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ typedef struct ARKLsMemRec
SUNMatrix A; /* A = M - gamma * df/dy */
SUNMatrix savedJ; /* savedJ = old Jacobian */
N_Vector ytemp; /* temp vector passed to jtimes and psolve */
N_Vector ytemp2;
N_Vector x; /* solution vector used by SUNLinearSolver */
N_Vector ycur; /* ptr to current y vector in ARKLs solve */
N_Vector fcur; /* ptr to current fcur = fI(tcur, ycur) */
Expand Down

0 comments on commit 1b3805b

Please sign in to comment.