Skip to content

Commit

Permalink
Changed tests and transaction atomic using according to pull request …
Browse files Browse the repository at this point in the history
…message
  • Loading branch information
dvdria committed Apr 30, 2024
1 parent 76d2d1b commit 223b8de
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 39 deletions.
4 changes: 2 additions & 2 deletions django_celery_results/backends/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from celery.result import GroupResult, allow_join_result, result_from_tuple
from celery.utils.log import get_logger
from celery.utils.serialization import b64decode, b64encode
from django.db import connection, transaction
from django.db import connection, router, transaction
from django.db.utils import InterfaceError
from kombu.exceptions import DecodeError

Expand Down Expand Up @@ -246,7 +246,7 @@ def on_chord_part_return(self, request, state, result, **kwargs):
if not gid or not tid:
return
call_callback = False
with transaction.atomic(using=ChordCounter.objects.db):
with transaction.atomic(using=router.db_for_write(ChordCounter)):
# We need to know if `count` hits 0.
# wrap the update in a transaction
# with a `select_for_update` lock to prevent race conditions.
Expand Down
24 changes: 24 additions & 0 deletions t/proj/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
'NAME': 'postgres',
'USER': 'postgres',
'PASSWORD': 'postgres',
"PORT": 5434,
'OPTIONS': {
'connect_timeout': 1000,
}
Expand All @@ -42,13 +43,29 @@
'NAME': 'postgres',
'USER': 'postgres',
'PASSWORD': 'postgres',
"PORT": 5434,
'OPTIONS': {
'connect_timeout': 1000,
},
'TEST': {
'MIRROR': 'default',
},
},
'read-only': {
'ENGINE': 'django.db.backends.postgresql',
'HOST': 'localhost',
'NAME': 'read-only-database',
'USER': 'postgres',
'PASSWORD': 'postgres',
"PORT": 5434,
'OPTIONS': {
'connect_timeout': 1000,
'options': '-c default_transaction_read_only=on',
},
'TEST': {
'MIRROR': 'default',
},
}
}
except ImportError:
DATABASES = {
Expand All @@ -66,6 +83,13 @@
'timeout': 1000,
}
},
'read-only': {
'ENGINE': 'django.db.backends.sqlite3',
'NAME': os.path.join(BASE_DIR, 'db.sqlite3'),
'OPTIONS': {
'timeout': 1000,
}
}
}

# Quick-start development settings - unsuitable for production
Expand Down
99 changes: 62 additions & 37 deletions t/unit/backends/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -922,8 +922,24 @@ def test_backend_result_extended_is_false(self):
assert tr.task_kwargs is None


class DjangoCeleryResultRouter:
route_app_labels = {"django_celery_results"}

def db_for_read(self, model, **hints):
"""Route read access to the read-only database"""
if model._meta.app_label in self.route_app_labels:
return "read-only"
return None

def db_for_write(self, model, **hints):
"""Route write access to the default database"""
if model._meta.app_label in self.route_app_labels:
return "default"
return None


class ChordPartReturnTestCase(TransactionTestCase):
databases = "__all__"
databases = {"default", "read-only"}

def setUp(self):
super().setUp()
Expand All @@ -938,39 +954,48 @@ def test_on_chord_part_return_multiple_databases(self):
Test if the ChordCounter is properly decremented and the callback is
triggered after all chord parts have returned with multiple databases
"""
gid = uuid()
tid1 = uuid()
tid2 = uuid()
subtasks = [AsyncResult(tid1), AsyncResult(tid2)]
group = GroupResult(id=gid, results=subtasks)
self.b.apply_chord(group, self.add.s())

chord_counter = ChordCounter.objects.using(
"secondary"
).get(group_id=gid)
assert chord_counter.count == 2

request = mock.MagicMock()
request.id = subtasks[0].id
request.group = gid
request.task = "my_task"
request.args = ["a", 1, "password"]
request.kwargs = {"c": 3, "d": "e", "password": "password"}
request.argsrepr = "argsrepr"
request.kwargsrepr = "kwargsrepr"
request.hostname = "celery@ip-0-0-0-0"
request.properties = {"periodic_task_name": "my_periodic_task"}
request.ignore_result = False
result = {"foo": "baz"}

self.b.mark_as_done(tid1, result, request=request)

chord_counter.refresh_from_db()
assert chord_counter.count == 1

self.b.mark_as_done(tid2, result, request=request)

with pytest.raises(ChordCounter.DoesNotExist):
ChordCounter.objects.using("secondary").get(group_id=gid)

request.chord.delay.assert_called_once()
with self.settings(DATABASE_ROUTERS=[DjangoCeleryResultRouter()]):
gid = uuid()
tid1 = uuid()
tid2 = uuid()
subtasks = [AsyncResult(tid1), AsyncResult(tid2)]
group = GroupResult(id=gid, results=subtasks)

assert ChordCounter.objects.count() == 0
assert ChordCounter.objects.using("read-only").count() == 0
assert ChordCounter.objects.using("default").count() == 0

self.b.apply_chord(group, self.add.s())

# Check if the ChordCounter was created in the correct database
assert ChordCounter.objects.count() == 1
assert ChordCounter.objects.using("read-only").count() == 1
assert ChordCounter.objects.using("default").count() == 1

chord_counter = ChordCounter.objects.get(group_id=gid)
assert chord_counter.count == 2

request = mock.MagicMock()
request.id = subtasks[0].id
request.group = gid
request.task = "my_task"
request.args = ["a", 1, "password"]
request.kwargs = {"c": 3, "d": "e", "password": "password"}
request.argsrepr = "argsrepr"
request.kwargsrepr = "kwargsrepr"
request.hostname = "celery@ip-0-0-0-0"
request.properties = {"periodic_task_name": "my_periodic_task"}
request.ignore_result = False
result = {"foo": "baz"}

self.b.mark_as_done(tid1, result, request=request)

chord_counter.refresh_from_db()
assert chord_counter.count == 1

self.b.mark_as_done(tid2, result, request=request)

with pytest.raises(ChordCounter.DoesNotExist):
ChordCounter.objects.get(group_id=gid)

request.chord.delay.assert_called_once()

0 comments on commit 223b8de

Please sign in to comment.