diff --git a/pyomo/solvers/plugins/solvers/xpress_direct.py b/pyomo/solvers/plugins/solvers/xpress_direct.py index 1064a985a72..0e7f4dd7ce3 100644 --- a/pyomo/solvers/plugins/solvers/xpress_direct.py +++ b/pyomo/solvers/plugins/solvers/xpress_direct.py @@ -257,12 +257,24 @@ def __init__(self, **kwds): def available(self, exception_flag=True): """True if the solver is available.""" - if exception_flag and not xpress_available: - xpress.log_import_warning(logger=__name__) - raise ApplicationError( - "No Python bindings available for %s solver plugin" % (type(self),) - ) - return bool(xpress_available) + if not xpress_available: + if exception_flag: + xpress.log_import_warning(logger=__name__) + raise ApplicationError( + "No Python bindings available for %s solver plugin" % (type(self),) + ) + return False + + # Check that there is a valid license + try: + xpress.init() + return True + except: + if exception_flag: + raise + return False + finally: + xpress.free() def _apply_solver(self): StaleFlagManager.mark_all_as_stale() diff --git a/pyomo/solvers/tests/checks/test_xpress_persistent.py b/pyomo/solvers/tests/checks/test_xpress_persistent.py index 50966ab6184..d5107f9dda7 100644 --- a/pyomo/solvers/tests/checks/test_xpress_persistent.py +++ b/pyomo/solvers/tests/checks/test_xpress_persistent.py @@ -8,12 +8,18 @@ # rights in this software. # This software is distributed under the 3-clause BSD License. # ___________________________________________________________________________ +import logging import pyomo.common.unittest as unittest import pyomo.environ as pe +import pyomo.solvers.plugins.solvers.xpress_direct as xpd + +from pyomo.common.log import LoggingIntercept from pyomo.core.expr.taylor_series import taylor_series_expansion -from pyomo.solvers.plugins.solvers.xpress_direct import xpress_available from pyomo.opt.results.solver import TerminationCondition, SolverStatus +from pyomo.solvers.plugins.solvers.xpress_persistent import XpressPersistent + +xpress_available = pe.SolverFactory('xpress_persistent').available(False) class TestXpressPersistent(unittest.TestCase): @@ -329,3 +335,52 @@ def test_nonconvexqp_infeasible(self): self.assertEqual( results.solver.termination_condition, TerminationCondition.infeasible ) + + def test_available(self): + class mock_xpress(object): + def __init__(self, importable, initable): + self._initable = initable + xpd.xpress_available = importable + + def log_import_warning(self, logger): + logging.getLogger(logger).warning("import warning") + + def init(self): + if not self._initable: + raise RuntimeError("init failed") + + def free(self): + pass + + orig = xpd.xpress, xpd.xpress_available + try: + _xpress_persistent = XpressPersistent + xpd.xpress = mock_xpress(True, True) + with LoggingIntercept() as LOG: + self.assertTrue(XpressPersistent().available(True)) + self.assertTrue(XpressPersistent().available(False)) + self.assertEqual(LOG.getvalue(), "") + + xpd.xpress = mock_xpress(False, False) + with LoggingIntercept() as LOG: + self.assertFalse(XpressPersistent().available(False)) + self.assertEqual(LOG.getvalue(), "") + with LoggingIntercept() as LOG: + with self.assertRaisesRegex( + xpd.ApplicationError, + "No Python bindings available for .*XpressPersistent.* " + "solver plugin", + ): + XpressPersistent().available(True) + self.assertEqual(LOG.getvalue(), "import warning\n") + + xpd.xpress = mock_xpress(True, False) + with LoggingIntercept() as LOG: + self.assertFalse(XpressPersistent().available(False)) + self.assertEqual(LOG.getvalue(), "") + with LoggingIntercept() as LOG: + with self.assertRaisesRegex(RuntimeError, "init failed"): + XpressPersistent().available(True) + self.assertEqual(LOG.getvalue(), "") + finally: + xpd.xpress, xpd.xpress_available = orig