Skip to content

Commit

Permalink
Fix atomic transaction not routing to the correct DB in DatabaseB…
Browse files Browse the repository at this point in the history
…ackend.on_chord_part_return transaction.atomic (#427)

* using ChordCounter.objects.db in DatabaseBackend.on_chord_part_return transaction.atomic

* WIP testing on chord part return with multiple databases

* WIP testing on chord part return with multiple databases pre-committed

* Completed testing on chord part return with multiple databases

* Changed the tests and the `using` argument of transaction.atomic as requested in the pull request

* Removed ports from settings

---------

Co-authored-by: Davide Ria <[email protected]>
  • Loading branch information
gianbot and dvdria authored May 4, 2024
1 parent 23265e6 commit d72cad3
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 2 deletions.
4 changes: 2 additions & 2 deletions django_celery_results/backends/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from celery.result import GroupResult, allow_join_result, result_from_tuple
from celery.utils.log import get_logger
from celery.utils.serialization import b64decode, b64encode
from django.db import connection, transaction
from django.db import connection, router, transaction
from django.db.utils import InterfaceError
from kombu.exceptions import DecodeError

Expand Down Expand Up @@ -246,7 +246,7 @@ def on_chord_part_return(self, request, state, result, **kwargs):
if not gid or not tid:
return
call_callback = False
with transaction.atomic():
with transaction.atomic(using=router.db_for_write(ChordCounter)):
# We need to know if `count` hits 0.
# wrap the update in a transaction
# with a `select_for_update` lock to prevent race conditions.
Expand Down
21 changes: 21 additions & 0 deletions t/proj/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,20 @@
'MIRROR': 'default',
},
},
'read-only': {
'ENGINE': 'django.db.backends.postgresql',
'HOST': 'localhost',
'NAME': 'read-only-database',
'USER': 'postgres',
'PASSWORD': 'postgres',
'OPTIONS': {
'connect_timeout': 1000,
'options': '-c default_transaction_read_only=on',
},
'TEST': {
'MIRROR': 'default',
},
}
}
except ImportError:
DATABASES = {
Expand All @@ -66,6 +80,13 @@
'timeout': 1000,
}
},
'read-only': {
'ENGINE': 'django.db.backends.sqlite3',
'NAME': os.path.join(BASE_DIR, 'db.sqlite3'),
'OPTIONS': {
'timeout': 1000,
}
}
}

# Quick-start development settings - unsuitable for production
Expand Down
80 changes: 80 additions & 0 deletions t/unit/backends/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from celery.utils.serialization import b64decode
from celery.worker.request import Request
from celery.worker.strategy import hybrid_to_proto2
from django.test import TransactionTestCase

from django_celery_results.backends.database import DatabaseBackend
from django_celery_results.models import ChordCounter, TaskResult
Expand Down Expand Up @@ -919,3 +920,82 @@ def test_backend_result_extended_is_false(self):
tr = TaskResult.objects.get(task_id=tid2)
assert tr.task_args is None
assert tr.task_kwargs is None


class DjangoCeleryResultRouter:
    """Database router used by the multiple-database chord test.

    Models whose app label is listed in ``route_app_labels`` are read from
    the ``"read-only"`` alias and written to the ``"default"`` alias.  For
    every other model both methods return ``None``, i.e. this router
    expresses no routing opinion.
    """

    # App labels whose models are handled by this router.
    route_app_labels = {"django_celery_results"}

    def db_for_read(self, model, **hints):
        """Route read access to the read-only database."""
        if model._meta.app_label in self.route_app_labels:
            return "read-only"
        return None

    def db_for_write(self, model, **hints):
        """Route write access to the default database."""
        if model._meta.app_label in self.route_app_labels:
            return "default"
        return None


class ChordPartReturnTestCase(TransactionTestCase):
    """Chord bookkeeping test with separate read and write database aliases.

    NOTE(review): ``self.app`` and ``self.add`` are not defined in this
    class; presumably they are injected by a shared test fixture -- confirm.
    """

    # Both aliases must be declared so the test is allowed to query them.
    databases = {"default", "read-only"}

    def setUp(self):
        """Configure the Celery app and build the backend under test."""
        super().setUp()
        self.app.conf.result_serializer = 'json'
        self.app.conf.result_backend = (
            'django_celery_results.backends:DatabaseBackend')
        self.app.conf.result_extended = True
        self.b = DatabaseBackend(app=self.app)

    def test_on_chord_part_return_multiple_databases(self):
        """
        Test if the ChordCounter is properly decremented and the callback is
        triggered after all chord parts have returned with multiple databases
        """
        # Install the router only for the duration of this test so reads and
        # writes of django_celery_results models go to different aliases.
        with self.settings(DATABASE_ROUTERS=[DjangoCeleryResultRouter()]):
            gid = uuid()
            tid1 = uuid()
            tid2 = uuid()
            subtasks = [AsyncResult(tid1), AsyncResult(tid2)]
            group = GroupResult(id=gid, results=subtasks)

            # No counter may exist yet, whichever alias is queried.
            assert ChordCounter.objects.count() == 0
            assert ChordCounter.objects.using("read-only").count() == 0
            assert ChordCounter.objects.using("default").count() == 0

            self.b.apply_chord(group, self.add.s())

            # Check if the ChordCounter was created in the correct database
            assert ChordCounter.objects.count() == 1
            assert ChordCounter.objects.using("read-only").count() == 1
            assert ChordCounter.objects.using("default").count() == 1

            # One pending part per subtask in the group.
            chord_counter = ChordCounter.objects.get(group_id=gid)
            assert chord_counter.count == 2

            # Stand-in for the task request handed to the backend.
            request = mock.MagicMock()
            request.id = subtasks[0].id
            request.group = gid
            request.task = "my_task"
            request.args = ["a", 1, "password"]
            request.kwargs = {"c": 3, "d": "e", "password": "password"}
            request.argsrepr = "argsrepr"
            request.kwargsrepr = "kwargsrepr"
            request.hostname = "celery@ip-0-0-0-0"
            request.properties = {"periodic_task_name": "my_periodic_task"}
            request.ignore_result = False
            result = {"foo": "baz"}

            # First part returns: the counter drops from 2 to 1.
            self.b.mark_as_done(tid1, result, request=request)

            chord_counter.refresh_from_db()
            assert chord_counter.count == 1

            # Second (last) part returns: the counter row must be gone ...
            self.b.mark_as_done(tid2, result, request=request)

            with pytest.raises(ChordCounter.DoesNotExist):
                ChordCounter.objects.get(group_id=gid)

            # ... and the chord callback must have been dispatched once.
            request.chord.delay.assert_called_once()

0 comments on commit d72cad3

Please sign in to comment.