diff --git a/src/hssm/utils.py b/src/hssm/utils.py
index 79adf6cb..9c83716b 100644
--- a/src/hssm/utils.py
+++ b/src/hssm/utils.py
@@ -9,8 +9,10 @@ _parse_bambi().
 """
 
+import contextlib
 import itertools
 import logging
+import os
 from copy import deepcopy
 from typing import Any, Literal, cast
@@ -548,3 +550,43 @@ def _rearrange_data(data: pd.DataFrame | np.ndarray) -> pd.DataFrame | np.ndarra
 def _split_array(data: np.ndarray | list[int], divisor: int) -> list[np.ndarray]:
     num_splits = len(data) // divisor + (1 if len(data) % divisor != 0 else 0)
     return [tmp.astype(int) for tmp in np.array_split(data, num_splits)]
+
+
+class SuppressOutput:
+    """Context manager for suppressing output.
+
+    This context manager redirects both stdout and stderr to `os.devnull`,
+    effectively silencing all output during the execution of the block.
+    It also disables logging by setting the logging level to `CRITICAL`.
+
+    Examples
+    --------
+    >>> with SuppressOutput():
+    ...     grad_func = pytensor.function(
+    ...         [v, a, z, t],
+    ...         grad,
+    ...         mode=nan_guard_mode,
+    ...     )
+
+    Methods
+    -------
+    __enter__()
+        Redirects stdout and stderr, and disables logging.
+
+    __exit__(exc_type, exc_value, traceback)
+        Restores stdout, stderr, and logging upon exit.
+    """
+
+    def __enter__(self):  # noqa: D105
+        self._null_file = open(os.devnull, "w")
+        self._stdout_context = contextlib.redirect_stdout(self._null_file)
+        self._stderr_context = contextlib.redirect_stderr(self._null_file)
+        self._stdout_context.__enter__()
+        self._stderr_context.__enter__()
+        logging.disable(logging.CRITICAL)  # Disable logging
+
+    def __exit__(self, exc_type, exc_value, traceback):  # noqa: D105
+        self._stdout_context.__exit__(exc_type, exc_value, traceback)
+        self._stderr_context.__exit__(exc_type, exc_value, traceback)
+        self._null_file.close()
+        logging.disable(logging.NOTSET)  # Re-enable logging
diff --git a/tests/test_likelihoods.py b/tests/test_likelihoods.py
index 964fe65c..61f0a383 100644
--- a/tests/test_likelihoods.py
+++ b/tests/test_likelihoods.py
@@ -6,6 +6,7 @@ from pathlib import Path
 from itertools import product
 
+from hssm.utils import SuppressOutput
 import numpy as np
 import pandas as pd
@@ -109,11 +110,15 @@ def test_analytical_gradient():
     size = cav_data_numpy.shape[0]
     logp = logp_ddm(cav_data_numpy, v, a, z, t).sum()
     grad = pt.grad(logp, wrt=[v, a, z, t])
-    grad_func = pytensor.function(
-        [v, a, z, t],
-        grad,
-        mode=nan_guard_mode,
-    )
+
+    # Temporary measure to suppress output from pytensor.function
+    # See issues #594 in hssm and #1037 in pymc-devs/pytensor repos
+    with SuppressOutput():
+        grad_func = pytensor.function(
+            [v, a, z, t],
+            grad,
+            mode=nan_guard_mode,
+        )
     v_test = np.random.normal(size=size)
     a_test = np.random.uniform(0.0001, 2, size=size)
     z_test = np.random.uniform(0.1, 1.0, size=size)
@@ -123,13 +128,15 @@
     assert np.all(np.isfinite(grad), axis=None), "Gradient contains non-finite values."
 
-    grad_func_sdv = pytensor.function(
-        [v, a, z, t, sv],
-        pt.grad(
-            logp_ddm_sdv(cav_data_numpy, v, a, z, t, sv).sum(), wrt=[v, a, z, t, sv]
-        ),
-        mode=nan_guard_mode,
-    )
+    # Also temporary
+    with SuppressOutput():
+        grad_func_sdv = pytensor.function(
+            [v, a, z, t, sv],
+            pt.grad(
+                logp_ddm_sdv(cav_data_numpy, v, a, z, t, sv).sum(), wrt=[v, a, z, t, sv]
+            ),
+            mode=nan_guard_mode,
+        )
     grad_sdv = np.array(grad_func_sdv(v_test, a_test, z_test, t_test, sv_test))
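
Below is a minimal usage sketch of the new SuppressOutput context manager in isolation, assuming hssm with this patch is importable; the print and logging calls are illustrative only and are not part of the patch.

# Illustrative sketch only -- not part of the patch above.
# Assumes hssm (with this change) is installed; the print/logging
# calls exist only to demonstrate what gets suppressed.
import logging

from hssm.utils import SuppressOutput

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

print("visible: printed before entering the context manager")
with SuppressOutput():
    print("silenced: stdout is redirected to os.devnull")
    logger.info("silenced: logging is disabled inside the block")
print("visible again: stdout, stderr, and logging are restored on exit")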