Skip to content

Commit

Permalink
Merge branch 'main' into benchmark
Browse files Browse the repository at this point in the history
  • Loading branch information
ZedongPeng authored Feb 1, 2024
2 parents 8d051b0 + d5aca8b commit 66f5025
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 34 deletions.
50 changes: 16 additions & 34 deletions pyomo/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
import textwrap
import types

from operator import attrgetter

from pyomo.common.collections import Sequence, Mapping
from pyomo.common.deprecation import (
deprecated,
Expand Down Expand Up @@ -1688,11 +1690,9 @@ def __call__(
ans.reset()
else:
# Copy over any Dict definitions
for k in self._decl_order:
for k, v in self._data.items():
if preserve_implicit or k in self._declared:
v = self._data[k]
ans._data[k] = _tmp = v(preserve_implicit=preserve_implicit)
ans._decl_order.append(k)
if k in self._declared:
ans._declared.add(k)
_tmp._parent = ans
Expand Down Expand Up @@ -2383,12 +2383,7 @@ class ConfigDict(ConfigBase, Mapping):

content_filters = {None, 'all', 'userdata'}

__slots__ = (
'_decl_order',
'_declared',
'_implicit_declaration',
'_implicit_domain',
)
__slots__ = ('_declared', '_implicit_declaration', '_implicit_domain')
_all_slots = set(__slots__ + ConfigBase.__slots__)

def __init__(
Expand All @@ -2399,7 +2394,6 @@ def __init__(
implicit_domain=None,
visibility=0,
):
self._decl_order = []
self._declared = set()
self._implicit_declaration = implicit
if (
Expand Down Expand Up @@ -2478,18 +2472,17 @@ def __delitem__(self, key):
_key = str(key).replace(' ', '_')
del self._data[_key]
# Clean up the other data structures
self._decl_order.remove(_key)
self._declared.discard(_key)

def __contains__(self, key):
_key = str(key).replace(' ', '_')
return _key in self._data

def __len__(self):
return self._decl_order.__len__()
return len(self._data)

def __iter__(self):
return (self._data[key]._name for key in self._decl_order)
return map(attrgetter('_name'), self._data.values())

def __getattr__(self, name):
# Note: __getattr__ is only called after all "usual" attribute
Expand Down Expand Up @@ -2526,13 +2519,12 @@ def keys(self):

def values(self):
self._userAccessed = True
for key in self._decl_order:
yield self[key]
return map(self.__getitem__, self._data)

def items(self):
self._userAccessed = True
for key in self._decl_order:
yield (self._data[key]._name, self[key])
for key, val in self._data.items():
yield (val._name, self[key])

@deprecated('The iterkeys method is deprecated. Use dict.keys().', version='6.0')
def iterkeys(self):
Expand Down Expand Up @@ -2561,7 +2553,6 @@ def _add(self, name, config):
% (name, self.name(True))
)
self._data[_name] = config
self._decl_order.append(_name)
config._parent = self
config._name = name
return config
Expand Down Expand Up @@ -2613,10 +2604,7 @@ def add(self, name, config):
def value(self, accessValue=True):
if accessValue:
self._userAccessed = True
return {
cfg._name: cfg.value(accessValue)
for cfg in map(self._data.__getitem__, self._decl_order)
}
return {cfg._name: cfg.value(accessValue) for cfg in self._data.values()}

def set_value(self, value, skip_implicit=False):
if value is None:
Expand All @@ -2636,7 +2624,7 @@ def set_value(self, value, skip_implicit=False):
_key = str(key).replace(' ', '_')
if _key in self._data:
# str(key) may not be key... store the mapping so that
# when we later iterate over the _decl_order, we can map
# when we later iterate over the _data, we can map
# the local keys back to the incoming value keys.
_decl_map[_key] = key
else:
Expand All @@ -2659,7 +2647,7 @@ def set_value(self, value, skip_implicit=False):
# We want to set the values in declaration order (so that
# things are deterministic and in case a validation depends
# on the order)
for key in self._decl_order:
for key in self._data:
if key in _decl_map:
self[key] = value[_decl_map[key]]
# implicit data is declared at the end (in sorted order)
Expand All @@ -2675,16 +2663,11 @@ def set_value(self, value, skip_implicit=False):
def reset(self):
# Reset the values in the order they were declared. This
# allows reset functions to have a deterministic ordering.
def _keep(self, key):
keep = key in self._declared
if keep:
self._data[key].reset()
for key, val in list(self._data.items()):
if key in self._declared:
val.reset()
else:
del self._data[key]
return keep

# this is an in-place slice of a list...
self._decl_order[:] = [x for x in self._decl_order if _keep(self, x)]
self._userAccessed = False
self._userSet = False

Expand All @@ -2695,8 +2678,7 @@ def _data_collector(self, level, prefix, visibility=None, docMode=False):
yield (level, prefix, None, self)
if level is not None:
level += 1
for key in self._decl_order:
cfg = self._data[key]
for cfg in self._data.values():
yield from cfg._data_collector(level, cfg._name + ': ', visibility, docMode)


Expand Down
41 changes: 41 additions & 0 deletions pyomo/common/tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3027,6 +3027,47 @@ def fcn(self):
self.assertEqual(add_docstring_list("", ExampleClass.CONFIG), ref)
self.assertIn('add_docstring_list is deprecated', LOG.getvalue())

def test_declaration_in_init(self):
class CustomConfig(ConfigDict):
def __init__(
self,
description=None,
doc=None,
implicit=False,
implicit_domain=None,
visibility=0,
):
super().__init__(
description=description,
doc=doc,
implicit=implicit,
implicit_domain=implicit_domain,
visibility=visibility,
)

self.declare('time_limit', ConfigValue(domain=NonNegativeFloat))
self.declare('stream_solver', ConfigValue(domain=bool))

cfg = CustomConfig()
OUT = StringIO()
cfg.display(ostream=OUT)
# Note: pypy outputs "None" as "null"
self.assertEqual(
"time_limit: None\nstream_solver: None\n",
OUT.getvalue().replace('null', 'None'),
)

# Test that creating a copy of a ConfigDict with declared fields
# in the __init__ does not result in duplicate outputs in the
# display (reported in PR #3113)
cfg2 = cfg({'time_limit': 10, 'stream_solver': 0})
OUT = StringIO()
cfg2.display(ostream=OUT)
self.assertEqual(
"time_limit: 10.0\nstream_solver: false\n",
OUT.getvalue().replace('null', 'None'),
)


if __name__ == "__main__":
unittest.main()

0 comments on commit 66f5025

Please sign in to comment.