Skip to content

Commit

Permalink
refactor: replace custom SuppressOutput implementation with import fr…
Browse files Browse the repository at this point in the history
…om hssm.utils
  • Loading branch information
cpaniaguam committed Oct 21, 2024
1 parent 929d2b8 commit 7e4a5a2
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 22 deletions.
32 changes: 32 additions & 0 deletions src/hssm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@
_parse_bambi().
"""

import contextlib
import itertools
import logging
import os
from copy import deepcopy
from typing import Any, Literal, cast

Expand Down Expand Up @@ -548,3 +550,33 @@ def _rearrange_data(data: pd.DataFrame | np.ndarray) -> pd.DataFrame | np.ndarra
def _split_array(data: np.ndarray | list[int], divisor: int) -> list[np.ndarray]:
num_splits = len(data) // divisor + (1 if len(data) % divisor != 0 else 0)
return [tmp.astype(int) for tmp in np.array_split(data, num_splits)]


class SuppressOutput:
"""Context manager for output suppressing.
Example
```python
with SuppressOutput():
grad_func = pytensor.function(
[v, a, z, t],
grad,
mode=nan_guard_mode,
)
```
"""

def __enter__(self): # noqa: D105
self._null_file = open(os.devnull, "w")
self._stdout_context = contextlib.redirect_stdout(self._null_file)
self._stderr_context = contextlib.redirect_stderr(self._null_file)
self._stdout_context.__enter__()
self._stderr_context.__enter__()
logging.disable(logging.CRITICAL) # Disable logging

def __exit__(self, exc_type, exc_value, traceback): # noqa: D105
self._stdout_context.__exit__(exc_type, exc_value, traceback)
self._stderr_context.__exit__(exc_type, exc_value, traceback)
self._null_file.close()
logging.disable(logging.NOTSET) # Re-enable logging
26 changes: 4 additions & 22 deletions tests/test_likelihoods.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,9 @@
old implementation of WFPT from (https://github.com/hddm-devs/hddm)
"""

import contextlib
import logging
import os
from pathlib import Path
from itertools import product
from hssm.utils import SuppressOutput

import numpy as np
import pandas as pd
Expand All @@ -28,25 +26,6 @@

hssm.set_floatX("float32")


# Temporary measure to suppress output from pytensor.function
# See issues #594 in hssm and #1037 in pymc-devs/pytensor repos
class SuppressOutput:
def __enter__(self):
self._null_file = open(os.devnull, "w")
self._stdout_context = contextlib.redirect_stdout(self._null_file)
self._stderr_context = contextlib.redirect_stderr(self._null_file)
self._stdout_context.__enter__()
self._stderr_context.__enter__()
logging.disable(logging.CRITICAL) # Disable logging

def __exit__(self, exc_type, exc_value, traceback):
self._stdout_context.__exit__(exc_type, exc_value, traceback)
self._stderr_context.__exit__(exc_type, exc_value, traceback)
self._null_file.close()
logging.disable(logging.NOTSET) # Re-enable logging


# def test_logp(data_fixture):
# """
# This function compares new and old implementation of logp calculation
Expand Down Expand Up @@ -132,6 +111,8 @@ def test_analytical_gradient():
logp = logp_ddm(cav_data_numpy, v, a, z, t).sum()
grad = pt.grad(logp, wrt=[v, a, z, t])

# Temporary measure to suppress output from pytensor.function
# See issues #594 in hssm and #1037 in pymc-devs/pytensor repos
with SuppressOutput():
grad_func = pytensor.function(
[v, a, z, t],
Expand All @@ -147,6 +128,7 @@ def test_analytical_gradient():

assert np.all(np.isfinite(grad), axis=None), "Gradient contains non-finite values."

# Also temporary
with SuppressOutput():
grad_func_sdv = pytensor.function(
[v, a, z, t, sv],
Expand Down

0 comments on commit 7e4a5a2

Please sign in to comment.