Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

convert dict entry to dataclass & typing #239

Merged
merged 13 commits into from
Oct 17, 2024
4 changes: 2 additions & 2 deletions .github/workflows/ci-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ jobs:

- name: Unit tests (local)
if: matrix.backend == 'local'
run: pytest -m "not mongo"
run: pytest -m "not mongo" --cov=cachier --cov-report=term --cov-report=xml:cov.xml

- name: Setup docker (missing on MacOS)
if: runner.os == 'macOS' && matrix.backend == 'db'
Expand All @@ -77,7 +77,7 @@ jobs:
docker ps -a
- name: Unit tests (DB)
if: matrix.backend == 'db'
run: pytest -m "mongo"
run: pytest -m "mongo" --cov=cachier --cov-report=term --cov-report=xml:cov.xml
- name: Speed eval
run: python tests/speed_eval.py

Expand Down
3 changes: 0 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,6 @@ norecursedirs = [
]
addopts = [
"--color=yes",
"--cov=cachier",
"--cov-report=term",
"--cov-report=xml:cov.xml",
"-r a",
"-v",
"-s",
Expand Down
14 changes: 13 additions & 1 deletion src/cachier/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
import hashlib
import os
import pickle
import threading
from collections.abc import Mapping
from dataclasses import dataclass, replace
from typing import Optional, Union
from typing import Any, Optional, Union

from ._types import Backend, HashFunc, Mongetter

Expand Down Expand Up @@ -38,6 +39,17 @@ class Params:
_global_params = Params()


@dataclass
class CacheEntry:
    """Data class for cache entries.

    One instance holds the cached result for a single key together with the
    bookkeeping fields that the cores write and the decorator wrapper reads
    (``value``, ``time`` and ``being_calculated``).
    """

    # The cached function result. Cores store None here when no result is
    # available yet, e.g. for an entry created only to mark a calculation
    # as being in progress.
    value: Any
    # Timestamp of the entry; the wrapper subtracts it from the current time
    # and compares the difference against the stale-after threshold.
    # NOTE(review): the Mongo core builds entries with
    # `res.get("time", None)`, so this may be None at runtime despite the
    # annotation -- confirm whether Optional is intended.
    # NOTE(review): confirm that `datetime` names the class (not the module)
    # in this file; the binding import is not visible in this hunk.
    time: datetime
    # Staleness flag. The in-memory core always writes False; the Mongo core
    # reads it from the stored record and defaults to False.
    stale: bool
    # True while a calculation for this key is marked as in progress.
    being_calculated: bool
    # Condition the in-memory core uses to notify threads waiting on the
    # calculation; reset to None once waiters were notified. The Mongo core
    # does not pass it and so keeps the default.
    condition: Optional[threading.Condition] = None

def _update_with_defaults(
param, name: str, func_kwargs: Optional[dict] = None
):
Expand Down
14 changes: 7 additions & 7 deletions src/cachier/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,17 +258,17 @@ def func_wrapper(*args, **kwds):
_print("No entry found. No current calc. Calling like a boss.")
return _calc_entry(core, key, func, args, kwds)
_print("Entry found.")
if _allow_none or entry.get("value", None) is not None:
if _allow_none or entry.value is not None:
_print("Cached result found.")
now = datetime.datetime.now()
if now - entry["time"] <= _stale_after:
if now - entry.time <= _stale_after:
_print("And it is fresh!")
return entry["value"]
return entry.value
_print("But it is stale... :(")
if entry["being_calculated"]:
if entry.being_calculated:
if _next_time:
_print("Returning stale.")
return entry["value"] # return stale val
return entry.value # return stale val
_print("Already calc. Waiting on change.")
try:
return core.wait_on_entry_calc(key)
Expand All @@ -283,10 +283,10 @@ def func_wrapper(*args, **kwds):
)
finally:
core.mark_entry_not_calculated(key)
return entry["value"]
return entry.value
_print("Calling decorated function and waiting")
return _calc_entry(core, key, func, args, kwds)
if entry["being_calculated"]:
if entry.being_calculated:
_print("No value but being calculated. Waiting.")
try:
return core.wait_on_entry_calc(key)
Expand Down
20 changes: 10 additions & 10 deletions src/cachier/cores/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
import abc # for the _BaseCore abstract base class
import inspect
import threading
from typing import Callable
from typing import Callable, Optional, Tuple

from .._types import HashFunc
from ..config import _update_with_defaults
from ..config import CacheEntry, _update_with_defaults


class RecalculationNeeded(Exception):
Expand Down Expand Up @@ -51,7 +51,7 @@ def get_key(self, args, kwds):
"""Return a unique key based on the arguments provided."""
return self.hash_func(args, kwds)

def get_entry(self, args, kwds):
def get_entry(self, args, kwds) -> Tuple[str, Optional[CacheEntry]]:
"""Get entry based on given arguments.

Return the result mapped to the given arguments in this core's cache,
Expand All @@ -76,7 +76,7 @@ def check_calc_timeout(self, time_spent):
raise RecalculationNeeded()

@abc.abstractmethod
def get_entry_by_key(self, key):
def get_entry_by_key(self, key: str) -> Tuple[str, Optional[CacheEntry]]:
"""Get entry based on given key.

Return the result mapped to the given key in this core's cache, if such
Expand All @@ -85,25 +85,25 @@ def get_entry_by_key(self, key):
"""

@abc.abstractmethod
def set_entry(self, key, func_res):
def set_entry(self, key: str, func_res):
"""Map the given result to the given key in this core's cache."""

@abc.abstractmethod
def mark_entry_being_calculated(self, key):
def mark_entry_being_calculated(self, key: str) -> None:
"""Mark the entry mapped by the given key as being calculated."""

@abc.abstractmethod
def mark_entry_not_calculated(self, key):
def mark_entry_not_calculated(self, key: str) -> None:
"""Mark the entry mapped by the given key as not being calculated."""

@abc.abstractmethod
def wait_on_entry_calc(self, key):
def wait_on_entry_calc(self, key: str) -> None:
"""Wait on the entry with keys being calculated and returns result."""

@abc.abstractmethod
def clear_cache(self):
def clear_cache(self) -> None:
"""Clear the cache of this core."""

@abc.abstractmethod
def clear_being_calculated(self):
def clear_being_calculated(self) -> None:
"""Mark all entries in this cache as not being calculated."""
76 changes: 40 additions & 36 deletions src/cachier/cores/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

import threading
from datetime import datetime
from typing import Any, Optional, Tuple

from .._types import HashFunc
from ..config import CacheEntry
from .base import _BaseCore, _get_func_str


Expand All @@ -14,76 +16,78 @@ def __init__(self, hash_func: HashFunc, wait_for_calc_timeout: int):
super().__init__(hash_func, wait_for_calc_timeout)
self.cache = {}

def _hash_func_key(self, key):
def _hash_func_key(self, key: str) -> str:
return f"{_get_func_str(self.func)}:{key}"

def get_entry_by_key(self, key, reload=False):
def get_entry_by_key(
self, key: str, reload=False
) -> Tuple[str, Optional[CacheEntry]]:
with self.lock:
return key, self.cache.get(self._hash_func_key(key), None)

def set_entry(self, key, func_res):
def set_entry(self, key: str, func_res: Any) -> None:
with self.lock:
try:
# we need to retain the existing condition so that
# mark_entry_not_calculated can notify all possibly-waiting
# threads about it
cond = self.cache[self._hash_func_key(key)]["condition"]
cond = self.cache[self._hash_func_key(key)].condition
except KeyError: # pragma: no cover
cond = None
self.cache[self._hash_func_key(key)] = {
"value": func_res,
"time": datetime.now(),
"stale": False,
"being_calculated": False,
"condition": cond,
}
self.cache[self._hash_func_key(key)] = CacheEntry(
value=func_res,
time=datetime.now(),
stale=False,
being_calculated=False,
condition=cond,
)

def mark_entry_being_calculated(self, key):
def mark_entry_being_calculated(self, key: str) -> None:
with self.lock:
condition = threading.Condition()
# condition.acquire()
try:
self.cache[self._hash_func_key(key)]["being_calculated"] = True
self.cache[self._hash_func_key(key)]["condition"] = condition
self.cache[self._hash_func_key(key)].being_calculated = True
self.cache[self._hash_func_key(key)].condition = condition
except KeyError:
self.cache[self._hash_func_key(key)] = {
"value": None,
"time": datetime.now(),
"stale": False,
"being_calculated": True,
"condition": condition,
}
self.cache[self._hash_func_key(key)] = CacheEntry(
value=None,
time=datetime.now(),
stale=False,
being_calculated=True,
condition=condition,
)

def mark_entry_not_calculated(self, key):
def mark_entry_not_calculated(self, key: str) -> None:
with self.lock:
try:
entry = self.cache[self._hash_func_key(key)]
except KeyError: # pragma: no cover
return # that's ok, we don't need an entry in that case
entry["being_calculated"] = False
cond = entry["condition"]
entry.being_calculated = False
cond = entry.condition
if cond:
cond.acquire()
cond.notify_all()
cond.release()
entry["condition"] = None
entry.condition = None

def wait_on_entry_calc(self, key):
def wait_on_entry_calc(self, key: str) -> Any:
with self.lock: # pragma: no cover
entry = self.cache[self._hash_func_key(key)]
if not entry["being_calculated"]:
return entry["value"]
entry["condition"].acquire()
entry["condition"].wait()
entry["condition"].release()
return self.cache[self._hash_func_key(key)]["value"]
if not entry.being_calculated:
return entry.value
entry.condition.acquire()
entry.condition.wait()
entry.condition.release()
return self.cache[self._hash_func_key(key)].value

def clear_cache(self):
def clear_cache(self) -> None:
with self.lock:
self.cache.clear()

def clear_being_calculated(self):
def clear_being_calculated(self) -> None:
with self.lock:
for entry in self.cache.values():
entry["being_calculated"] = False
entry["condition"] = None
entry.being_calculated = False
entry.condition = None
44 changes: 23 additions & 21 deletions src/cachier/cores/mongo.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@
import warnings # to warn if pymongo is missing
from contextlib import suppress
from datetime import datetime
from typing import Any, Optional, Tuple

from .._types import HashFunc, Mongetter
from ..config import CacheEntry

with suppress(ImportError):
from bson.binary import Binary # to save binary data to mongodb
Expand Down Expand Up @@ -65,29 +67,29 @@ def __init__(
def _func_str(self) -> str:
return _get_func_str(self.func)

def get_entry_by_key(self, key):
def get_entry_by_key(self, key: str) -> Tuple[str, Optional[CacheEntry]]:
res = self.mongo_collection.find_one(
{"func": self._func_str, "key": key}
)
if not res:
return key, None
try:
entry = {
"value": pickle.loads(res["value"]), # noqa: S301
"time": res.get("time", None),
"stale": res.get("stale", False),
"being_calculated": res.get("being_calculated", False),
}
entry = CacheEntry(
value=pickle.loads(res["value"]), # noqa: S301
time=res.get("time", None),
stale=res.get("stale", False),
being_calculated=res.get("being_calculated", False),
)
except KeyError:
entry = {
"value": None,
"time": res.get("time", None),
"stale": res.get("stale", False),
"being_calculated": res.get("being_calculated", False),
}
entry = CacheEntry(
value=None,
time=res.get("time", None),
stale=res.get("stale", False),
being_calculated=res.get("being_calculated", False),
)
return key, entry

def set_entry(self, key, func_res):
def set_entry(self, key: str, func_res: Any) -> None:
thebytes = pickle.dumps(func_res)
self.mongo_collection.update_one(
filter={"func": self._func_str, "key": key},
Expand All @@ -104,14 +106,14 @@ def set_entry(self, key, func_res):
upsert=True,
)

def mark_entry_being_calculated(self, key):
def mark_entry_being_calculated(self, key: str) -> None:
self.mongo_collection.update_one(
filter={"func": self._func_str, "key": key},
update={"$set": {"being_calculated": True}},
upsert=True,
)

def mark_entry_not_calculated(self, key):
def mark_entry_not_calculated(self, key: str) -> None:
with suppress(OperationFailure): # don't care in this case
self.mongo_collection.update_one(
filter={
Expand All @@ -122,22 +124,22 @@ def mark_entry_not_calculated(self, key):
upsert=False, # should not insert in this case
)

def wait_on_entry_calc(self, key):
def wait_on_entry_calc(self, key: str) -> Any:
time_spent = 0
while True:
time.sleep(MONGO_SLEEP_DURATION_IN_SEC)
time_spent += MONGO_SLEEP_DURATION_IN_SEC
key, entry = self.get_entry_by_key(key)
if entry is None:
raise RecalculationNeeded()
if not entry["being_calculated"]:
return entry["value"]
if not entry.being_calculated:
return entry.value
self.check_calc_timeout(time_spent)

def clear_cache(self):
def clear_cache(self) -> None:
self.mongo_collection.delete_many(filter={"func": self._func_str})

def clear_being_calculated(self):
def clear_being_calculated(self) -> None:
self.mongo_collection.update_many(
filter={
"func": self._func_str,
Expand Down
Loading
Loading