From bf4174a6e715eb833450f8fe4ba2ec955e5a929c Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 29 Mar 2023 13:51:09 +0100 Subject: [PATCH] Lock-based thread safety --- doc/source/changelog.rst | 6 +- doc/source/index.rst | 6 +- zict/__init__.py | 1 - zict/buffer.py | 91 +++++++++++++++++++++------ zict/cache.py | 21 ++----- zict/common.py | 69 +++++++++++++++++++-- zict/file.py | 51 ++++++++------- zict/func.py | 10 +-- zict/lmdb.py | 1 + zict/lru.py | 126 +++++++++++++++++++++++++------------- zict/sieve.py | 56 +++++++++-------- zict/tests/conftest.py | 22 +++++++ zict/tests/test_buffer.py | 27 +++++++- zict/tests/test_cache.py | 4 ++ zict/tests/test_common.py | 65 ++++++++++++++++++-- zict/tests/test_file.py | 10 +++ zict/tests/test_lru.py | 36 +++++++++-- zict/tests/test_sieve.py | 4 ++ zict/tests/test_utils.py | 66 +------------------- zict/tests/utils_test.py | 9 +++ zict/utils.py | 92 ---------------------------- zict/zip.py | 1 + 22 files changed, 449 insertions(+), 325 deletions(-) create mode 100644 zict/tests/conftest.py diff --git a/doc/source/changelog.rst b/doc/source/changelog.rst index b6a7ddc..40a6be5 100644 --- a/doc/source/changelog.rst +++ b/doc/source/changelog.rst @@ -14,10 +14,10 @@ Changelog (:pr:`78`) `Guido Imperiale`_ - ``LMDB`` now uses memory-mapped I/O on MacOSX and is usable on Windows. (:pr:`78`) `Guido Imperiale`_ -- The library is now partially thread-safe. - (:pr:`82`, :pr:`90`, :pr:`93`) `Guido Imperiale`_ +- The library is now almost completely thread-safe. + (:pr:`82`, :pr:`90`, :pr:`92`, :pr:`93`) `Guido Imperiale`_ - :class:`LRU` and :class:`Buffer` now support delayed eviction. - New objects :class:`Accumulator` and :class:`InsertionSortedSet`. + New object :class:`InsertionSortedSet`. 
(:pr:`87`) `Guido Imperiale`_ - All mappings now return proper KeysView, ItemsView, and ValuesView objects from their keys(), items(), and values() methods (:pr:`93`) `Guido Imperiale`_ diff --git a/doc/source/index.rst b/doc/source/index.rst index eb38de0..428bcb0 100644 --- a/doc/source/index.rst +++ b/doc/source/index.rst @@ -36,8 +36,8 @@ zlib-compressed, directory of files. Thread-safety ------------- -This library is only partially thread-safe. Refer to the documentation of the individual -mappings for details. +Most classes in this library are thread-safe. +Refer to the documentation of the individual mappings for exceptions. API --- @@ -64,8 +64,6 @@ API Additionally, **zict** makes available the following general-purpose objects: -.. autoclass:: Accumulator - :members: .. autoclass:: InsertionSortedSet :members: .. autoclass:: WeakValueMapping diff --git a/zict/__init__.py b/zict/__init__.py index cbcdbd4..a13f0d1 100644 --- a/zict/__init__.py +++ b/zict/__init__.py @@ -6,7 +6,6 @@ from zict.lmdb import LMDB as LMDB from zict.lru import LRU as LRU from zict.sieve import Sieve as Sieve -from zict.utils import Accumulator as Accumulator from zict.utils import InsertionSortedSet as InsertionSortedSet from zict.zip import Zip as Zip diff --git a/zict/buffer.py b/zict/buffer.py index 33c9afb..23d3e60 100644 --- a/zict/buffer.py +++ b/zict/buffer.py @@ -7,7 +7,7 @@ ValuesView, ) -from zict.common import KT, VT, ZictBase, close, flush +from zict.common import KT, VT, ZictBase, close, discard, flush, locked from zict.lru import LRU @@ -37,8 +37,10 @@ class Buffer(ZictBase[KT, VT]): Notes ----- - ``__contains__`` and ``__len__`` are thread-safe if the same methods on both - ``fast`` and ``slow`` are thread-safe. All other methods are not thread-safe. + If you call methods of this class from multiple threads, access will be fast as long + as all methods of ``fast``, plus ``slow.__contains__`` and ``slow.__delitem__``, are + fast. 
``slow.__getitem__``, ``slow.__setitem__`` and callbacks are not protected + by locks. Examples -------- @@ -58,6 +60,7 @@ class Buffer(ZictBase[KT, VT]): weight: Callable[[KT, VT], float] fast_to_slow_callbacks: list[Callable[[KT, VT], None]] slow_to_fast_callbacks: list[Callable[[KT, VT], None]] + _cancel_restore: dict[KT, bool] def __init__( self, @@ -72,7 +75,14 @@ def __init__( | list[Callable[[KT, VT], None]] | None = None, ): - self.fast = LRU(n, fast, weight=weight, on_evict=[self.fast_to_slow]) + super().__init__() + self.fast = LRU( + n, + fast, + weight=weight, + on_evict=[self.fast_to_slow], + on_cancel_evict=[self._cancel_evict], + ) self.slow = slow self.weight = weight if callable(fast_to_slow_callbacks): @@ -81,6 +91,7 @@ def __init__( slow_to_fast_callbacks = [slow_to_fast_callbacks] self.fast_to_slow_callbacks = fast_to_slow_callbacks or [] self.slow_to_fast_callbacks = slow_to_fast_callbacks or [] + self._cancel_restore = {} @property def n(self) -> float: @@ -98,16 +109,38 @@ def fast_to_slow(self, key: KT, value: VT) -> None: raise def slow_to_fast(self, key: KT) -> VT: - value = self.slow[key] + self._cancel_restore[key] = False + try: + with self.unlock(): + value = self.slow[key] + if self._cancel_restore[key]: + raise KeyError(key) + finally: + self._cancel_restore.pop(key) + # Avoid useless movement for heavy values w = self.weight(key, value) if w <= self.n: + # Multithreaded edge case: + # - Thread 1 starts slow_to_fast(x) and puts it at the top of fast + # - This causes the eviction of older key(s) + # - While thread 1 is evicting older keys, thread 2 is loading fast with + # set_noevict() + # - By the time the eviction of the older key(s) is done, there is + # enough weight in fast that thread 1 will spill x + # - If the below code was just `self.fast[key] = value; del + # self.slow[key]` now the key would be in neither slow nor fast! 
+ self.fast.set_noevict(key, value) del self.slow[key] - self.fast[key] = value - for cb in self.slow_to_fast_callbacks: - cb(key, value) + + with self.unlock(): + self.fast.evict_until_below_target() + for cb in self.slow_to_fast_callbacks: + cb(key, value) + return value + @locked def __getitem__(self, key: KT) -> VT: try: return self.fast[key] @@ -115,31 +148,41 @@ def __getitem__(self, key: KT) -> VT: return self.slow_to_fast(key) def __setitem__(self, key: KT, value: VT) -> None: - try: - del self.slow[key] - except KeyError: - pass - # This may trigger an eviction from fast to slow of older keys. - # If the weight is individually greater than n, then key/value will be stored - # into self.slow instead (see LRU.__setitem__). + with self.lock: + discard(self.slow, key) + if key in self._cancel_restore: + self._cancel_restore[key] = True self.fast[key] = value + @locked def set_noevict(self, key: KT, value: VT) -> None: """Variant of ``__setitem__`` that does not move keys from fast to slow if the total weight exceeds n """ - try: - del self.slow[key] - except KeyError: - pass + discard(self.slow, key) + if key in self._cancel_restore: + self._cancel_restore[key] = True self.fast.set_noevict(key, value) + def evict_until_below_target(self, n: float | None = None) -> None: + """Wrapper around :meth:`zict.LRU.evict_until_below_target`. + Presented here to allow easier overriding. 
+ """ + self.fast.evict_until_below_target(n) + + @locked def __delitem__(self, key: KT) -> None: + if key in self._cancel_restore: + self._cancel_restore[key] = True try: del self.fast[key] except KeyError: del self.slow[key] + @locked + def _cancel_evict(self, key: KT, value: VT) -> None: + discard(self.slow, key) + def values(self) -> ValuesView[VT]: return BufferValuesView(self) @@ -147,7 +190,15 @@ def items(self) -> ItemsView[KT, VT]: return BufferItemsView(self) def __len__(self) -> int: - return len(self.fast) + len(self.slow) + with self.lock, self.fast.lock: + return ( + len(self.fast) + + len(self.slow) + - sum( + k in self.fast and k in self.slow + for k in chain(self._cancel_restore, self.fast._cancel_evict) + ) + ) def __iter__(self) -> Iterator[KT]: """Make sure that the iteration is not disrupted if you evict/restore a key in diff --git a/zict/cache.py b/zict/cache.py index 5489569..afe2591 100644 --- a/zict/cache.py +++ b/zict/cache.py @@ -4,7 +4,7 @@ from collections.abc import Iterator, MutableMapping from typing import TYPE_CHECKING -from zict.common import KT, VT, ZictBase, close, flush +from zict.common import KT, VT, ZictBase, close, discard, flush class Cache(ZictBase[KT, VT]): @@ -22,14 +22,6 @@ class Cache(ZictBase[KT, VT]): If True (default), the cache will be updated both when writing and reading. If False, update the cache when reading, but just invalidate it when writing. - Notes - ----- - All methods are thread-safe if all methods on both ``data`` and ``cache`` are - thread-safe; however, only one thread can call ``__setitem__`` and ``__delitem__`` - at any given time. - ``__contains__`` and ``__len__`` are thread-safe if the same methods on ``data`` are - thread-safe. 
- Examples -------- Keep the latest 100 accessed values in memory @@ -51,6 +43,7 @@ def __init__( cache: MutableMapping[KT, VT], update_on_set: bool = True, ): + super().__init__() self.data = data self.cache = cache self.update_on_set = update_on_set @@ -67,20 +60,14 @@ def __getitem__(self, key: KT) -> VT: def __setitem__(self, key: KT, value: VT) -> None: # If the item was already in cache and data.__setitem__ fails, e.g. because it's # a File and the disk is full, make sure that the cache is invalidated. - try: - del self.cache[key] - except KeyError: - pass + discard(self.cache, key) self.data[key] = value if self.update_on_set: self.cache[key] = value def __delitem__(self, key: KT) -> None: - try: - del self.cache[key] - except KeyError: - pass + discard(self.cache, key) del self.data[key] def __len__(self) -> int: diff --git a/zict/common.py b/zict/common.py index 1907e0e..07a8a44 100644 --- a/zict/common.py +++ b/zict/common.py @@ -1,18 +1,24 @@ from __future__ import annotations -from collections.abc import Iterable, Mapping +import threading +from collections.abc import Callable, Iterable, Iterator, Mapping +from contextlib import contextmanager from enum import Enum +from functools import wraps from itertools import chain from typing import MutableMapping # TODO move to collections.abc (needs Python >=3.9) -from typing import TYPE_CHECKING, Any, TypeVar +from typing import TYPE_CHECKING, Any, TypeVar, cast T = TypeVar("T") KT = TypeVar("KT") VT = TypeVar("VT") if TYPE_CHECKING: - # TODO import from typing (needs Python >=3.11) - from typing_extensions import Self + # TODO import ParamSpec from typing (needs Python >=3.10) + # TODO import Self from typing (needs Python >=3.11) + from typing_extensions import ParamSpec, Self + + P = ParamSpec("P") class NoDefault(Enum): @@ -25,6 +31,20 @@ class NoDefault(Enum): class ZictBase(MutableMapping[KT, VT]): """Base class for zict mappings""" + lock: threading.RLock + + def __init__(self) -> None: + self.lock = 
threading.RLock() + + def __getstate__(self) -> dict[str, Any]: + state = self.__dict__.copy() + del state["lock"] + return state + + def __setstate__(self, state: dict[str, Any]) -> None: + self.__dict__ = state + self.lock = threading.RLock() + def update( # type: ignore[override] self, other: Mapping[KT, VT] | Iterable[tuple[KT, VT]] = (), @@ -41,6 +61,12 @@ def _do_update(self, items: Iterable[tuple[KT, VT]]) -> None: for k, v in items: self[k] = v + def discard(self, key: KT) -> None: + """Flush *key* if possible. + Not the same as ``m.pop(key, None)``, as it doesn't trigger ``__getitem__``. + """ + discard(self, key) + def close(self) -> None: """Release any system resources held by this object""" @@ -53,6 +79,17 @@ def __exit__(self, *args: Any) -> None: def __del__(self) -> None: self.close() + @contextmanager + def unlock(self) -> Iterator[None]: + """To be used in a method decorated by ``@locked``. + Temporarily releases the mapping's RLock. + """ + self.lock.release() + try: + yield + finally: + self.lock.acquire() + def close(*z: Any) -> None: """Close *z* if possible.""" @@ -66,3 +103,27 @@ def flush(*z: Any) -> None: for zi in z: if hasattr(zi, "flush"): zi.flush() + + +def discard(m: MutableMapping[KT, VT], key: KT) -> None: + """Flush *key* if possible. + Not the same as ``m.pop(key, None)``, as it doesn't trigger ``__getitem__``. + """ + try: + del m[key] + except KeyError: + pass + + +def locked(func: Callable[P, VT]) -> Callable[P, VT]: + """Decorator for a method of ZictBase, which wraps the whole method in a + mapping-global rlock. 
+ """ + + @wraps(func) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> VT: + self = cast(ZictBase, args[0]) + with self.lock: + return func(*args, **kwargs) + + return wrapper diff --git a/zict/file.py b/zict/file.py index 9ee1652..e59786e 100644 --- a/zict/file.py +++ b/zict/file.py @@ -6,7 +6,7 @@ from collections.abc import Iterator from urllib.parse import quote, unquote -from zict.common import ZictBase +from zict.common import ZictBase, locked class File(ZictBase[str, bytes]): @@ -27,8 +27,13 @@ class File(ZictBase[str, bytes]): Notes ----- - This class is fully thread-safe, with only one caveat: you can't have two calls to - ``__setitem__``, on the same key, at the same time from two different threads. + If you call methods of this class from multiple threads, access will be fast as long + as atomic disk access such as ``open``, ``os.fstat``, and ``os.remove`` is fast. + This is not always the case, e.g. in case of slow network mounts or spun-down + magnetic drives. + Bytes read/write in the files is not protected by locks; this could cause failures + on Windows, NFS, and in general whenever it's not OK to delete a file while there + are file descriptors open on it. Examples -------- @@ -54,6 +59,7 @@ class File(ZictBase[str, bytes]): _inc: int def __init__(self, directory: str | pathlib.Path, memmap: bool = False): + super().__init__() self.directory = str(directory) self.memmap = memmap self.filenames = {} @@ -89,6 +95,7 @@ def __str__(self) -> str: __repr__ = __str__ + @locked def __getitem__(self, key: str) -> bytearray | memoryview: fn = os.path.join(self.directory, self.filenames[key]) @@ -99,21 +106,19 @@ def __getitem__(self, key: str) -> bytearray | memoryview: # Note that this is a dask-specific feature; vanilla pickle.loads will instead # return an array with flags.writeable=False. 
- try: - if self.memmap: - with open(fn, "r+b") as fh: - return memoryview(mmap.mmap(fh.fileno(), 0)) - else: - with open(fn, "rb") as fh: - size = os.fstat(fh.fileno()).st_size - buf = bytearray(size) + if self.memmap: + with open(fn, "r+b") as fh: + return memoryview(mmap.mmap(fh.fileno(), 0)) + else: + with open(fn, "rb") as fh: + size = os.fstat(fh.fileno()).st_size + buf = bytearray(size) + with self.unlock(): nread = fh.readinto(buf) - assert nread == size - return buf - - except FileNotFoundError: # pragma: nocover - raise KeyError(key) # Race condition with __setitem__ or __delitem__ + assert nread == size + return buf + @locked def __setitem__( self, key: str, @@ -123,13 +128,9 @@ def __setitem__( | list[bytes | bytearray | memoryview] | tuple[bytes | bytearray | memoryview, ...], ) -> None: - try: - del self[key] - except KeyError: - pass - + self.discard(key) fn = self._safe_key(key) - with open(os.path.join(self.directory, fn), "wb") as fh: + with open(os.path.join(self.directory, fn), "wb") as fh, self.unlock(): if isinstance(value, (tuple, list)): fh.writelines(value) else: @@ -142,12 +143,10 @@ def __contains__(self, key: object) -> bool: def __iter__(self) -> Iterator[str]: return iter(self.filenames) + @locked def __delitem__(self, key: str) -> None: fn = self.filenames.pop(key) - try: - os.remove(os.path.join(self.directory, fn)) - except FileNotFoundError: # pragma: nocover - raise KeyError(key) # Race condition with __setitem__ or __delitem__ + os.remove(os.path.join(self.directory, fn)) def __len__(self) -> int: return len(self.filenames) diff --git a/zict/func.py b/zict/func.py index 33d24a0..b45cfa0 100644 --- a/zict/func.py +++ b/zict/func.py @@ -19,15 +19,6 @@ class Func(ZictBase[KT, VT], Generic[KT, VT, WT]): Function to call on value as we pull it from the mapping d: MutableMapping - Notes - ----- - ``__contains__, ``__delitem__``, and ``__len__`` are thread-safe if the same methods - on ``d`` are thread-safe. 
- ``__setitem__`` and ``update`` are thread-safe if the same methods on ``d`` as well - ``dump`` are thread-safe. - ``__getitem__`` is thread-safe if both ``d.__getitem__`` and ``load`` are - thread-safe. - Examples -------- >>> def double(x): @@ -55,6 +46,7 @@ def __init__( load: Callable[[WT], VT], d: MutableMapping[KT, WT], ): + super().__init__() self.dump = dump self.load = load self.d = d diff --git a/zict/lmdb.py b/zict/lmdb.py index 01f80fe..15324c4 100644 --- a/zict/lmdb.py +++ b/zict/lmdb.py @@ -50,6 +50,7 @@ class LMDB(ZictBase[str, bytes]): def __init__(self, directory: str | pathlib.Path, map_size: int | None = None): import lmdb + super().__init__() if map_size is None: if sys.platform != "win32": map_size = min(2**40, sys.maxsize // 4) diff --git a/zict/lru.py b/zict/lru.py index 4115b53..a996897 100644 --- a/zict/lru.py +++ b/zict/lru.py @@ -9,8 +9,8 @@ ValuesView, ) -from zict.common import KT, VT, NoDefault, ZictBase, close, flush, nodefault -from zict.utils import Accumulator, InsertionSortedSet +from zict.common import KT, VT, NoDefault, ZictBase, close, flush, locked, nodefault +from zict.utils import InsertionSortedSet class LRU(ZictBase[KT, VT]): @@ -23,23 +23,24 @@ class LRU(ZictBase[KT, VT]): d: MutableMapping Dict-like in which to hold elements. There are no expectations on its internal ordering. Iteration on the LRU follows the order of the underlying mapping. - on_evict: list of callables - Function:: k, v -> action to call on key value pairs prior to eviction + on_evict: callable or list of callables + Function:: k, v -> action to call on key/value pairs prior to eviction If an exception occurs during an on_evict callback (e.g a callback tried storing to disk and raised a disk full error) the key will remain in the LRU. + on_cancel_evict: callable or list of callables + Function:: k, v -> action to call on key/value pairs if they're deleted or + updated from a thread while the on_evict callables are being executed in + another. 
+ If you're not accessing the LRU from multiple threads, ignore this parameter. weight: callable Function:: k, v -> number to determine the size of keeping the item in the mapping. Defaults to ``(k, v) -> 1`` Notes ----- - Most methods are thread-safe if the same methods on ``d`` are thread-safe. - ``__setitem__``, ``__delitem__``, :meth:`evict`, and - :meth:`evict_until_below_capacity` also require all callables in ``on_evict`` to be - thread-safe and should not be called from different threads for the same - key. It's OK to set/delete different keys from different threads, it's OK to set a - key in a thread and read it from many other threads, but it's not OK to set/delete - the same key from different threads at the same time. + If you call methods of this class from multiple threads, access will be fast as long + as all methods of ``d`` are fast. Callbacks are not protected by locks and can be + arbitrarily slow. Examples -------- @@ -54,45 +55,57 @@ class LRU(ZictBase[KT, VT]): order: InsertionSortedSet[KT] heavy: InsertionSortedSet[KT] on_evict: list[Callable[[KT, VT], None]] + on_cancel_evict: list[Callable[[KT, VT], None]] weight: Callable[[KT, VT], float] n: float weights: dict[KT, float] closed: bool - total_weight: Accumulator + total_weight: float + _cancel_evict: dict[KT, bool] def __init__( self, n: float, d: MutableMapping[KT, VT], + *, on_evict: Callable[[KT, VT], None] | list[Callable[[KT, VT], None]] | None = None, + on_cancel_evict: Callable[[KT, VT], None] + | list[Callable[[KT, VT], None]] + | None = None, weight: Callable[[KT, VT], float] = lambda k, v: 1, ): + super().__init__() self.d = d self.n = n + if callable(on_evict): on_evict = [on_evict] self.on_evict = on_evict or [] + if callable(on_cancel_evict): + on_cancel_evict = [on_cancel_evict] + self.on_cancel_evict = on_cancel_evict or [] + self.weight = weight self.weights = {k: weight(k, v) for k, v in d.items()} - self.total_weight = Accumulator(sum(self.weights.values())) + 
self.total_weight = sum(self.weights.values()) self.order = InsertionSortedSet(d) self.heavy = InsertionSortedSet(k for k, v in self.weights.items() if v >= n) self.closed = False + self._cancel_evict = {} + @locked def __getitem__(self, key: KT) -> VT: result = self.d[key] - # Don't use .remove() to prevent race condition which can happen during - # multithreaded access - self.order.discard(key) + self.order.remove(key) self.order.add(key) return result def __setitem__(self, key: KT, value: VT) -> None: self.set_noevict(key, value) try: - self.evict_until_below_capacity() + self.evict_until_below_target() except Exception: if self.weights[key] > self.n and key not in self.heavy: # weight(value) > n and evicting the key we just inserted failed. @@ -104,18 +117,17 @@ def __setitem__(self, key: KT, value: VT) -> None: pass raise + @locked def set_noevict(self, key: KT, value: VT) -> None: """Variant of ``__setitem__`` that does not evict if the total weight exceeds n. Unlike ``__setitem__``, this method does not depend on the ``on_evict`` functions to be thread-safe for its own thread-safety. It also is not prone to re-raising exceptions from the ``on_evict`` callbacks. """ - try: - del self[key] - except KeyError: - pass - + self.discard(key) weight = self.weight(key, value) + if key in self._cancel_evict: + self._cancel_evict[key] = True self.d[key] = value self.order.add(key) if weight > self.n: @@ -123,12 +135,23 @@ def set_noevict(self, key: KT, value: VT) -> None: self.weights[key] = weight self.total_weight += weight - def evict_until_below_capacity(self) -> None: - """Evict key/value pairs until the total weight falls below n""" - while self.total_weight > self.n and not self.closed: + def evict_until_below_target(self, n: float | None = None) -> None: + """Evict key/value pairs until the total weight falls below n + + Parameters + ---------- + n: float, optional + Total weight threshold to achieve. Defaults to self.n. 
+ """ + if n is None: + n = self.n + while self.total_weight > n and not self.closed: self.evict() - def evict(self, key: KT | NoDefault = nodefault) -> tuple[KT, VT, float]: + @locked + def evict( + self, key: KT | NoDefault = nodefault + ) -> tuple[KT, VT, float] | tuple[None, None, float]: """Evict least recently used key, or least recently inserted key with individual weight > n, if any. You may also evict a specific key. @@ -138,43 +161,58 @@ def evict(self, key: KT | NoDefault = nodefault) -> tuple[KT, VT, float]: Returns ------- Tuple of (key, value, weight) + + Or (None, None, 0) if the key that was being evicted was updated or deleted from + another thread while the on_evict callbacks were being executed. This outcome is + only possible in multithreaded access. """ + if key is nodefault: + try: + key = next(iter(self.heavy or self.order)) + except StopIteration: + raise KeyError("evict(): dictionary is empty") + + if key in self._cancel_evict: + return None, None, 0 + # For the purpose of multithreaded access, it's important that the value remains # in self.d until all callbacks are successful. # When this is used inside a Buffer, there must never be a moment when the key # is neither in fast nor in slow. - if key is nodefault: - while True: - try: - key = next(iter(self.heavy or self.order)) - value = self.d[key] - break - except StopIteration: - raise KeyError("evict(): dictionary is empty") - except (KeyError, RuntimeError): # pragma: nocover - pass # Race condition caused by multithreading - else: - value = self.d[key] + value = self.d[key] # If we are evicting a heavy key we just inserted and one of the callbacks # fails, put it at the bottom of the LRU instead of the top. This way lighter # keys will have a chance to be evicted first and make space. self.heavy.discard(key) - # This may raise; e.g. 
if a callback tries storing to a full disk - for cb in self.on_evict: - cb(key, value) + self._cancel_evict[key] = False + try: + with self.unlock(): + # This may raise; e.g. if a callback tries storing to a full disk + for cb in self.on_evict: + cb(key, value) + + if self._cancel_evict[key]: + for cb in self.on_cancel_evict: + cb(key, value) + return None, None, 0 + finally: + self._cancel_evict.pop(key) - self.d.pop(key, None) # type: ignore[arg-type] - self.order.discard(key) + del self.d[key] + self.order.remove(key) weight = self.weights.pop(key) self.total_weight -= weight return key, value, weight + @locked def __delitem__(self, key: KT) -> None: + if key in self._cancel_evict: + self._cancel_evict[key] = True del self.d[key] - self.order.discard(key) + self.order.remove(key) self.heavy.discard(key) self.total_weight -= self.weights.pop(key) diff --git a/zict/sieve.py b/zict/sieve.py index 44ff845..e32262d 100644 --- a/zict/sieve.py +++ b/zict/sieve.py @@ -4,7 +4,7 @@ from collections.abc import Callable, Iterable, Iterator, Mapping, MutableMapping from typing import Generic, TypeVar -from zict.common import KT, VT, ZictBase, close, flush +from zict.common import KT, VT, ZictBase, close, flush, locked MKT = TypeVar("MKT") @@ -24,9 +24,10 @@ class Sieve(ZictBase[KT, VT], Generic[KT, VT, MKT]): Notes ----- - ``__contains__`` is thread-safe. - ``__len__`` is thread-safe if the same method on all mappings is thread-safe. - All other methods are not thread-safe. + If you call methods of this class from multiple threads, access will be fast as long + as the ``__contains__`` and ``__delitem__`` methods of all underlying mappings are + fast. ``__getitem__`` and ``__setitem__`` methods of the underlying mappings are not + protected by locks. Examples -------- @@ -36,10 +37,6 @@ class Sieve(ZictBase[KT, VT], Generic[KT, VT, MKT]): >>> def is_small(key, value): # doctest: +SKIP ... 
return sys.getsizeof(value) < 10000 # doctest: +SKIP >>> d = Sieve(mappings, is_small) # doctest: +SKIP - - See Also - -------- - Buffer """ mappings: Mapping[MKT, MutableMapping[KT, VT]] @@ -51,6 +48,7 @@ def __init__( mappings: Mapping[MKT, MutableMapping[KT, VT]], selector: Callable[[KT, VT], MKT], ): + super().__init__() self.mappings = mappings self.selector = selector self.key_to_mapping = {} @@ -59,38 +57,42 @@ def __getitem__(self, key: KT) -> VT: return self.key_to_mapping[key][key] def __setitem__(self, key: KT, value: VT) -> None: - old_mapping = self.key_to_mapping.get(key) - mkey = self.selector(key, value) - mapping = self.mappings[mkey] - if old_mapping is not None and old_mapping is not mapping: - del old_mapping[key] + with self.lock: + old_mapping = self.key_to_mapping.get(key) + mkey = self.selector(key, value) + mapping = self.mappings[mkey] + if old_mapping is not None and old_mapping is not mapping: + del old_mapping[key] + self.key_to_mapping[key] = mapping + mapping[key] = value - self.key_to_mapping[key] = mapping + @locked def __delitem__(self, key: KT) -> None: del self.key_to_mapping.pop(key)[key] def _do_update(self, items: Iterable[tuple[KT, VT]]) -> None: # Optimized update() implementation issuing a single update() # call per underlying mapping. 
- updates = defaultdict(list) - mapping_ids = {id(m): m for m in self.mappings.values()} - - for key, value in items: - old_mapping = self.key_to_mapping.get(key) - mkey = self.selector(key, value) - mapping = self.mappings[mkey] - if old_mapping is not None and old_mapping is not mapping: - del old_mapping[key] - # Can't hash a mutable mapping, so use its id() instead - updates[id(mapping)].append((key, value)) + with self.lock: + updates = defaultdict(list) + mapping_ids = {id(m): m for m in self.mappings.values()} + + for key, value in items: + old_mapping = self.key_to_mapping.get(key) + mkey = self.selector(key, value) + mapping = self.mappings[mkey] + if old_mapping is not None and old_mapping is not mapping: + del old_mapping[key] + # Can't hash a mutable mapping, so use its id() instead + updates[id(mapping)].append((key, value)) + self.key_to_mapping[key] = mapping for mid, mitems in updates.items(): mapping = mapping_ids[mid] mapping.update(mitems) - for key, _ in mitems: - self.key_to_mapping[key] = mapping + @locked def __len__(self) -> int: return sum(map(len, self.mappings.values())) diff --git a/zict/tests/conftest.py b/zict/tests/conftest.py new file mode 100644 index 0000000..e3597bd --- /dev/null +++ b/zict/tests/conftest.py @@ -0,0 +1,22 @@ +from __future__ import annotations + +from concurrent.futures import ThreadPoolExecutor + +import pytest + + +@pytest.fixture +def is_locked(): + """Callable that returns True if the parameter zict mapping has its RLock engaged""" + with ThreadPoolExecutor(1) as ex: + + def __is_locked(d): + out = d.lock.acquire(blocking=False) + if out: + d.lock.release() + return not out + + def _is_locked(d): + return ex.submit(__is_locked, d).result() + + yield _is_locked diff --git a/zict/tests/test_buffer.py b/zict/tests/test_buffer.py index 253658b..15b151b 100644 --- a/zict/tests/test_buffer.py +++ b/zict/tests/test_buffer.py @@ -185,7 +185,7 @@ def s2f_cb(k, v): assert b == {"x": 1} # Add key > n, again total weight 
> n this will move everything to slow except w - # that stays in fast due after callback raise + # that stays in fast due to callback raising with pytest.raises(MyError): buff["w"] = 11 @@ -216,7 +216,7 @@ def test_set_noevict(): assert b == {} assert f2s == s2f == [] - buff.fast.evict_until_below_capacity() + buff.evict_until_below_target() assert a == {"y": 3} assert b == {"z": 6, "x": 3} assert f2s == ["z", "x"] @@ -229,6 +229,13 @@ def test_set_noevict(): assert b == {"z": 6} assert f2s == s2f == [] + # Custom target; 0 != None + buff.evict_until_below_target(0) + assert a == {} + assert b == {"z": 6, "x": 1, "y": 3} + assert f2s == ["y", "x"] + assert s2f == [] + def test_evict_restore_during_iter(): """Test that __iter__ won't be disrupted if another thread evicts or restores a key""" @@ -242,3 +249,19 @@ def test_evict_restore_during_iter(): assert next(it) == "z" with pytest.raises(StopIteration): next(it) + + +def test_cancel_evict(): + ... # TODO + + +def test_cancel_restore(): + ... # TODO + + +def test_callbacks_are_unlocked(): + ... # TODO + + +def test_stress_threadsafe(): + ... # TODO diff --git a/zict/tests/test_cache.py b/zict/tests/test_cache.py index 57569df..68c3bbf 100644 --- a/zict/tests/test_cache.py +++ b/zict/tests/test_cache.py @@ -139,3 +139,7 @@ def test_mapping(): buff = Cache({}, {}) utils_test.check_mapping(buff) utils_test.check_closing(buff) + + +def test_stress_threadsafe(): + ... 
# TODO diff --git a/zict/tests/test_common.py b/zict/tests/test_common.py index e10f7b5..4d7a2b8 100644 --- a/zict/tests/test_common.py +++ b/zict/tests/test_common.py @@ -1,12 +1,15 @@ -from collections import UserDict +import pickle -from zict.common import ZictBase +import pytest + +from zict.common import locked +from zict.tests.utils_test import SimpleDict def test_close_on_del(): closed = False - class D(ZictBase, UserDict): + class D(SimpleDict): def close(self): nonlocal closed closed = True @@ -19,7 +22,7 @@ def close(self): def test_context(): closed = False - class D(ZictBase, UserDict): + class D(SimpleDict): def close(self): nonlocal closed closed = True @@ -33,7 +36,7 @@ def close(self): def test_update(): items = [] - class D(ZictBase, UserDict): + class D(SimpleDict): def _do_update(self, items_): nonlocal items items = items_ @@ -51,3 +54,55 @@ def _do_update(self, items_): # Special kwargs can't overwrite positional-only parameters d.update(self=1, other=2) assert list(items) == [("self", 1), ("other", 2)] + + +def test_discard(): + class D(SimpleDict): + def __getitem__(self, key): + raise AssertionError() + + d = D() + d["x"] = 1 + d["z"] = 2 + d.discard("x") + d.discard("y") + assert d.data == {"z": 2} + + +def test_pickle(): + d = SimpleDict() + d["x"] = 1 + d2 = pickle.loads(pickle.dumps(d)) + assert d2.data == {"x": 1} + + +def test_lock(is_locked): + class CustomError(Exception): + pass + + class D(SimpleDict): + @locked + def f(self, crash): + assert is_locked(self) + with self.unlock(): + assert not is_locked(self) + assert is_locked(self) + + # context manager re-acquires the lock on failure + with pytest.raises(CustomError): + with self.unlock(): + raise CustomError() + assert is_locked(self) + + if crash: + raise CustomError() + + d = D() + assert not is_locked(d) + d.f(crash=False) + assert not is_locked(d) + + # decorator releases the lock on failure + with pytest.raises(CustomError): + d.f(crash=True) + assert not is_locked(d) diff 
--git a/zict/tests/test_file.py b/zict/tests/test_file.py index bddefab..5b4e1c8 100644 --- a/zict/tests/test_file.py +++ b/zict/tests/test_file.py @@ -1,5 +1,6 @@ import os import pathlib +import sys import time from concurrent.futures import ThreadPoolExecutor from threading import Barrier @@ -158,3 +159,12 @@ def worker(key, start): assert f2.result() > 100 assert not z + + +def test_reads_writes_are_unlocked(): + ... # TODO + + +@pytest.mark.skipif(sys.platform == "win32", reason="Can't delete file with open fd") +def test_stress_threadsafe(tmpdir): + ... # TODO diff --git a/zict/tests/test_lru.py b/zict/tests/test_lru.py index 57344c4..aa9d670 100644 --- a/zict/tests/test_lru.py +++ b/zict/tests/test_lru.py @@ -1,11 +1,9 @@ -from collections import UserDict from concurrent.futures import ThreadPoolExecutor from threading import Barrier import pytest from zict import LRU -from zict.common import ZictBase from zict.tests import utils_test @@ -215,7 +213,7 @@ def test_weight(): assert d == {"y": 4} -def test_noevict(): +def test_manual_eviction(): a = [] lru = LRU(100, {}, weight=lambda k, v: v, on_evict=lambda k, v: a.append(k)) lru.set_noevict("x", 70) @@ -226,13 +224,23 @@ def test_noevict(): assert list(lru.order) == ["x", "y", "z"] assert a == [] - lru.evict_until_below_capacity() + lru.evict_until_below_target() assert dict(lru) == {"y": 50} assert a == ["z", "x"] assert lru.weights == {"y": 50} assert lru.order == {"y"} assert not lru.heavy + lru.evict_until_below_target() # No-op + assert dict(lru) == {"y": 50} + lru.evict_until_below_target(50) # Custom target + assert dict(lru) == {"y": 50} + lru.evict_until_below_target(0) # 0 != None + assert not lru + assert not lru.order + assert not lru.weights + assert a == ["z", "x", "y"] + def test_explicit_evict(): d = {} @@ -283,7 +291,7 @@ def test_getitem_is_threasafe(): def f(_): barrier.wait() - for _ in range(5_000_000): + for _ in range(500_000): assert lru["x"] == 1 barrier = Barrier(2) @@ -316,7 +324,7 
@@ def test_flush_close(): flushed = 0 closed = False - class D(ZictBase, UserDict): + class D(utils_test.SimpleDict): def flush(self): nonlocal flushed flushed += 1 @@ -330,3 +338,19 @@ def close(self): assert flushed == 1 assert closed + + +def test_cancel_evict(): + ... # TODO + + +def test_slow_writes_are_unlocked(): + ... # TODO + + +def test_callbacks_are_unlocked(): + ... # TODO + + +def test_stress_threadsafe(): + ... # TODO diff --git a/zict/tests/test_sieve.py b/zict/tests/test_sieve.py index 9b6be69..a529622 100644 --- a/zict/tests/test_sieve.py +++ b/zict/tests/test_sieve.py @@ -72,3 +72,7 @@ def selector(key, value): z = Sieve(mappings, selector) utils_test.check_mapping(z) utils_test.check_closing(z) + + +def test_stress_threadsafe(): + ... # TODO diff --git a/zict/tests/test_utils.py b/zict/tests/test_utils.py index 434afb2..a34757d 100644 --- a/zict/tests/test_utils.py +++ b/zict/tests/test_utils.py @@ -3,8 +3,7 @@ import pytest -from zict import Accumulator, InsertionSortedSet -from zict.utils import ATOMIC_INT_IADD +from zict import InsertionSortedSet def test_insertion_sorted_set(): @@ -108,66 +107,3 @@ def t(): # On Windows, we've seen as little as 2300. 
assert f1.result() > 100 assert f2.result() > 100 - - -def test_accumulator(): - acc = Accumulator() - assert acc == 0 - acc = Accumulator(123) - assert acc == 123 - assert repr(acc) == "123" - acc += 1 - assert acc == 124 - acc -= 1 - assert acc == 123 - acc += 0.5 - assert acc == 123.5 - - # Test operators - assert int(acc) == 123 - assert float(acc) == 123.5 - assert not acc != 123.5 - assert acc >= 123.5 - assert not acc >= 124 - assert acc > 123 - assert not acc > 123.5 - assert acc <= 123.5 - assert not acc <= 123 - assert acc < 124 - assert not acc < 123 - assert acc + 1 == 124.5 - assert acc - 1 == 122.5 - assert acc * 2 == 247 - assert acc / 2 == 61.75 - assert hash(acc) == hash(123.5) - - -@pytest.mark.parametrize("dtype", [int, float]) -def test_accumulator_threadsafe(dtype): - acc = Accumulator(dtype(2)) - if ATOMIC_INT_IADD: - # CPython >= 3.10 - assert isinstance(acc, dtype) - N = 10_000_000 - expect = 99999970000002 - else: - assert isinstance(acc, Accumulator) - N = 1_000_000 - expect = 999997000002 - - barrier = Barrier(2) - - def t(): - nonlocal acc - barrier.wait() - for i in range(N): - acc += i - acc -= 1 - assert acc >= 0 - - with ThreadPoolExecutor(2) as ex: - f1 = ex.submit(t) - f2 = ex.submit(t) - f1.result() - f2.result() - assert acc == expect diff --git a/zict/tests/utils_test.py b/zict/tests/utils_test.py index bd5c777..8eab9ca 100644 --- a/zict/tests/utils_test.py +++ b/zict/tests/utils_test.py @@ -1,9 +1,12 @@ import random import string +from collections import UserDict from collections.abc import ItemsView, KeysView, MutableMapping, ValuesView import pytest +from zict.common import ZictBase + def generate_random_strings(n, min_len, max_len): r = random.Random(42) @@ -131,3 +134,9 @@ def check_mapping(z): def check_closing(z): z.close() + + +class SimpleDict(ZictBase, UserDict): + def __init__(self): + ZictBase.__init__(self) + UserDict.__init__(self) diff --git a/zict/utils.py b/zict/utils.py index 4f2033d..438310b 100644 --- 
a/zict/utils.py +++ b/zict/utils.py @@ -1,11 +1,6 @@ from __future__ import annotations -import platform -import sys -import threading -from collections import defaultdict from collections.abc import Iterable, Iterator -from numbers import Number from typing import MutableSet # TODO import from collections.abc (needs Python >=3.9) from zict.common import T @@ -70,90 +65,3 @@ def popright(self) -> T: def clear(self) -> None: self._d.clear() - - -ATOMIC_INT_IADD = ( - platform.python_implementation() == "CPython" and sys.version_info >= (3, 10) -) - - -class Accumulator(Number): - """A lockless thread-safe accumulator""" - - _values: defaultdict[int, float] - __slots__ = ("_values",) - - def __new__(cls, value: float = 0) -> Accumulator: - if ATOMIC_INT_IADD: - # int.__iadd__ and float.__iadd__ are GIL-atomic starting from CPython 3.10. - # We can get rid of the whole class and just use them instead. - # This is an implementation detail. - return value # type: ignore[return-value] - - self = object.__new__(cls) - # Don't return float unless you actually added floats. - # This behaviour is consistent with sum(). - self._values = defaultdict(int) - self._values[threading.get_ident()] = value - return self - - def _value(self) -> float: - """Return accumulator total across all threads. - The return type is float if any float elements were added, otherwise it's int. - """ - while True: - try: - return sum(self._values.values()) - except RuntimeError: # dictionary changed size during iteration - pass # pragma: nocover - - def __iadd__(self, other: float) -> Accumulator: - self._values[threading.get_ident()] += other - return self - - def __isub__(self, other: float) -> Accumulator: - self._values[threading.get_ident()] -= other - return self - - # Trivial wrappers around self._value(). - # Since they are magic methods, they can't be implemented with __getattr__ - # or with accessor classes. 
- - def __repr__(self) -> str: - return repr(self._value()) - - def __int__(self) -> int: - return int(self._value()) - - def __float__(self) -> float: - return float(self._value()) - - def __eq__(self, other: object) -> bool: - return self._value() == other - - def __gt__(self, other: float) -> bool: - return self._value() > other - - def __ge__(self, other: float) -> bool: - return self._value() >= other - - def __lt__(self, other: float) -> bool: - return self._value() < other - - def __le__(self, other: float) -> bool: - return self._value() <= other - - def __add__(self, other: float) -> float: - return self._value() + other - - def __sub__(self, other: float) -> float: - return self._value() - other - - def __mul__(self, other: float) -> float: - return self._value() * other - - def __truediv__(self, other: float) -> float: - return self._value() / other - - def __hash__(self) -> int: - return hash(self._value()) diff --git a/zict/zip.py b/zict/zip.py index d2902d4..199221b 100644 --- a/zict/zip.py +++ b/zict/zip.py @@ -41,6 +41,7 @@ class Zip(MutableMapping[str, bytes]): _file: zipfile.ZipFile | None def __init__(self, filename: str, mode: FileMode = "a"): + super().__init__() self.filename = filename self.mode = mode self._file = None