
Commit

Merge pull request #595 from lnccbrown/fix-suppress-output-pytensor-function

fix: suppress buggy output from pytensor.function
cpaniaguam authored Oct 23, 2024
2 parents fa12018 + a74ebe5 commit 8e0908b
Showing 2 changed files with 61 additions and 12 deletions.
42 changes: 42 additions & 0 deletions src/hssm/utils.py
@@ -9,8 +9,10 @@
_parse_bambi().
"""

import contextlib
import itertools
import logging
import os
from copy import deepcopy
from typing import Any, Literal, cast

@@ -548,3 +550,43 @@ def _rearrange_data(data: pd.DataFrame | np.ndarray) -> pd.DataFrame | np.ndarray:
def _split_array(data: np.ndarray | list[int], divisor: int) -> list[np.ndarray]:
    num_splits = len(data) // divisor + (1 if len(data) % divisor != 0 else 0)
    return [tmp.astype(int) for tmp in np.array_split(data, num_splits)]


class SuppressOutput:
    """Context manager for suppressing output.

    This context manager redirects both stdout and stderr to `os.devnull`,
    effectively silencing all output during the execution of the block.
    It also suppresses logging messages by calling `logging.disable(logging.CRITICAL)`.

    Examples
    --------
    >>> with SuppressOutput():
    ...     grad_func = pytensor.function(
    ...         [v, a, z, t],
    ...         grad,
    ...         mode=nan_guard_mode,
    ...     )

    Methods
    -------
    __enter__()
        Redirects stdout and stderr, and disables logging.
    __exit__(exc_type, exc_value, traceback)
        Restores stdout, stderr, and logging upon exit.
    """

    def __enter__(self):  # noqa: D105
        self._null_file = open(os.devnull, "w")
        self._stdout_context = contextlib.redirect_stdout(self._null_file)
        self._stderr_context = contextlib.redirect_stderr(self._null_file)
        self._stdout_context.__enter__()
        self._stderr_context.__enter__()
        logging.disable(logging.CRITICAL)  # Disable logging

    def __exit__(self, exc_type, exc_value, traceback):  # noqa: D105
        self._stdout_context.__exit__(exc_type, exc_value, traceback)
        self._stderr_context.__exit__(exc_type, exc_value, traceback)
        self._null_file.close()
        logging.disable(logging.NOTSET)  # Re-enable logging
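
For context, here is a minimal sketch of how the context manager behaves on its own (not part of this commit; it only assumes that hssm is importable):

import logging

from hssm.utils import SuppressOutput

print("before")  # printed normally
with SuppressOutput():
    print("inside")  # redirected to os.devnull, never shown
    logging.getLogger("demo").warning("also hidden")  # logging is disabled inside the block
print("after")  # stdout/stderr and logging are restored on exit

Note that `contextlib.redirect_stdout` only intercepts output written through Python's `sys.stdout`; text emitted directly to the process file descriptors by compiled code would not be caught.
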
31 changes: 19 additions & 12 deletions tests/test_likelihoods.py
@@ -6,6 +6,7 @@

from pathlib import Path
from itertools import product
from hssm.utils import SuppressOutput

import numpy as np
import pandas as pd
@@ -109,11 +110,15 @@ def test_analytical_gradient():
    size = cav_data_numpy.shape[0]
    logp = logp_ddm(cav_data_numpy, v, a, z, t).sum()
    grad = pt.grad(logp, wrt=[v, a, z, t])
    grad_func = pytensor.function(
        [v, a, z, t],
        grad,
        mode=nan_guard_mode,
    )

    # Temporary measure to suppress output from pytensor.function
    # See issues #594 in hssm and #1037 in pymc-devs/pytensor repos
    with SuppressOutput():
        grad_func = pytensor.function(
            [v, a, z, t],
            grad,
            mode=nan_guard_mode,
        )
    v_test = np.random.normal(size=size)
    a_test = np.random.uniform(0.0001, 2, size=size)
    z_test = np.random.uniform(0.1, 1.0, size=size)
@@ -123,13 +128,15 @@

    assert np.all(np.isfinite(grad), axis=None), "Gradient contains non-finite values."

    grad_func_sdv = pytensor.function(
        [v, a, z, t, sv],
        pt.grad(
            logp_ddm_sdv(cav_data_numpy, v, a, z, t, sv).sum(), wrt=[v, a, z, t, sv]
        ),
        mode=nan_guard_mode,
    )
    # Also temporary
    with SuppressOutput():
        grad_func_sdv = pytensor.function(
            [v, a, z, t, sv],
            pt.grad(
                logp_ddm_sdv(cav_data_numpy, v, a, z, t, sv).sum(), wrt=[v, a, z, t, sv]
            ),
            mode=nan_guard_mode,
        )

    grad_sdv = np.array(grad_func_sdv(v_test, a_test, z_test, t_test, sv_test))

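If one wanted to verify the suppression itself, a hypothetical pytest sketch (not part of this commit) using pytest's built-in capsys fixture could look like this:

from hssm.utils import SuppressOutput


def test_suppress_output_silences_stdout(capsys):
    # Anything printed inside the block is redirected to os.devnull,
    # so pytest's capture sees nothing.
    with SuppressOutput():
        print("this should never be captured")
    captured = capsys.readouterr()
    assert captured.out == ""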
