diff --git a/src/hssm/utils.py b/src/hssm/utils.py
index 79adf6cb..388d21c6 100644
--- a/src/hssm/utils.py
+++ b/src/hssm/utils.py
@@ -9,8 +9,10 @@
 _parse_bambi().
 """
 
+import contextlib
 import itertools
 import logging
+import os
 from copy import deepcopy
 from typing import Any, Literal, cast
 
@@ -548,3 +550,33 @@ def _rearrange_data(data: pd.DataFrame | np.ndarray) -> pd.DataFrame | np.ndarra
 def _split_array(data: np.ndarray | list[int], divisor: int) -> list[np.ndarray]:
     num_splits = len(data) // divisor + (1 if len(data) % divisor != 0 else 0)
     return [tmp.astype(int) for tmp in np.array_split(data, num_splits)]
+
+
+class SuppressOutput:
+    """Context manager for output suppressing.
+
+    Example
+
+    ```python
+    with SuppressOutput():
+        grad_func = pytensor.function(
+            [v, a, z, t],
+            grad,
+            mode=nan_guard_mode,
+        )
+    ```
+    """
+
+    def __enter__(self):  # noqa: D105
+        self._null_file = open(os.devnull, "w")
+        self._stdout_context = contextlib.redirect_stdout(self._null_file)
+        self._stderr_context = contextlib.redirect_stderr(self._null_file)
+        self._stdout_context.__enter__()
+        self._stderr_context.__enter__()
+        logging.disable(logging.CRITICAL)  # Disable logging
+
+    def __exit__(self, exc_type, exc_value, traceback):  # noqa: D105
+        self._stdout_context.__exit__(exc_type, exc_value, traceback)
+        self._stderr_context.__exit__(exc_type, exc_value, traceback)
+        self._null_file.close()
+        logging.disable(logging.NOTSET)  # Re-enable logging
diff --git a/tests/test_likelihoods.py b/tests/test_likelihoods.py
index 8e782430..61f0a383 100644
--- a/tests/test_likelihoods.py
+++ b/tests/test_likelihoods.py
@@ -4,11 +4,9 @@
 old implementation of WFPT from (https://github.com/hddm-devs/hddm)
 """
 
-import contextlib
-import logging
-import os
 from pathlib import Path
 from itertools import product
+from hssm.utils import SuppressOutput
 
 import numpy as np
 import pandas as pd
@@ -28,25 +26,6 @@
 
 hssm.set_floatX("float32")
 
-
-# Temporary measure to suppress output from pytensor.function
-# See issues #594 in hssm and #1037 in pymc-devs/pytensor repos
-class SuppressOutput:
-    def __enter__(self):
-        self._null_file = open(os.devnull, "w")
-        self._stdout_context = contextlib.redirect_stdout(self._null_file)
-        self._stderr_context = contextlib.redirect_stderr(self._null_file)
-        self._stdout_context.__enter__()
-        self._stderr_context.__enter__()
-        logging.disable(logging.CRITICAL)  # Disable logging
-
-    def __exit__(self, exc_type, exc_value, traceback):
-        self._stdout_context.__exit__(exc_type, exc_value, traceback)
-        self._stderr_context.__exit__(exc_type, exc_value, traceback)
-        self._null_file.close()
-        logging.disable(logging.NOTSET)  # Re-enable logging
-
-
 # def test_logp(data_fixture):
 #     """
 #     This function compares new and old implementation of logp calculation
@@ -132,6 +111,8 @@ def test_analytical_gradient():
     logp = logp_ddm(cav_data_numpy, v, a, z, t).sum()
     grad = pt.grad(logp, wrt=[v, a, z, t])
 
+    # Temporary measure to suppress output from pytensor.function
+    # See issues #594 in hssm and #1037 in pymc-devs/pytensor repos
     with SuppressOutput():
         grad_func = pytensor.function(
            [v, a, z, t],
@@ -147,6 +128,7 @@
 
     assert np.all(np.isfinite(grad), axis=None), "Gradient contains non-finite values."
 
+    # Also temporary
     with SuppressOutput():
         grad_func_sdv = pytensor.function(
             [v, a, z, t, sv],
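
Below is a minimal, hypothetical usage sketch of the relocated `SuppressOutput` helper (not part of the patch above). It assumes `hssm` is installed and relies only on the behaviour shown in the diff: stdout, stderr, and logging at CRITICAL and below are silenced inside the `with` block and restored on exit.

```python
# Hypothetical usage sketch; not part of the diff above.
import logging

from hssm.utils import SuppressOutput

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

print("before: visible")  # normal stdout
with SuppressOutput():
    print("inside: hidden")  # redirected to os.devnull
    logger.warning("inside: hidden")  # logging.disable(CRITICAL) swallows this
print("after: visible")  # stdout/stderr and logging restored on exit
```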