diff --git a/django_celery_beat/admin.py b/django_celery_beat/admin.py index 2c462656..89db3757 100644 --- a/django_celery_beat/admin.py +++ b/django_celery_beat/admin.py @@ -249,11 +249,18 @@ def run_tasks(self, request, queryset): return task_ids = [ - task.apply_async(args=args, kwargs=kwargs, queue=queue, - periodic_task_name=periodic_task_name) + task.apply_async( + args=args, + kwargs=kwargs, + queue=queue, + headers={'periodic_task_name': periodic_task_name} + ) if queue and len(queue) - else task.apply_async(args=args, kwargs=kwargs, - periodic_task_name=periodic_task_name) + else task.apply_async( + args=args, + kwargs=kwargs, + headers={'periodic_task_name': periodic_task_name} + ) for task, args, kwargs, queue, periodic_task_name in tasks ] tasks_run = len(task_ids) diff --git a/django_celery_beat/schedulers.py b/django_celery_beat/schedulers.py index 846b97a9..30b1ae76 100644 --- a/django_celery_beat/schedulers.py +++ b/django_celery_beat/schedulers.py @@ -77,8 +77,9 @@ def __init__(self, model, app=None): if getattr(model, 'expires_', None): self.options['expires'] = getattr(model, 'expires_') - self.options['headers'] = loads(model.headers or '{}') - self.options['periodic_task_name'] = model.name + headers = loads(model.headers or '{}') + headers['periodic_task_name'] = model.name + self.options['headers'] = headers self.total_run_count = model.total_run_count self.model = model diff --git a/t/unit/test_schedulers.py b/t/unit/test_schedulers.py index d070bb45..109f2e31 100644 --- a/t/unit/test_schedulers.py +++ b/t/unit/test_schedulers.py @@ -130,8 +130,8 @@ def test_entry(self): assert e.options['exchange'] == 'foo' assert e.options['routing_key'] == 'cpu' assert e.options['priority'] == 1 - assert e.options['headers'] == {'_schema_name': 'foobar'} - assert e.options['periodic_task_name'] == m.name + assert e.options['headers']['_schema_name'] == 'foobar' + assert e.options['headers']['periodic_task_name'] == m.name right_now = self.app.now() m2 = self.create_model_interval( @@ -869,3 +869,16 @@ def test_run_tasks(self): assert len(self.request._messages._queued_messages) == 1 queued_message = self.request._messages._queued_messages[0].message assert queued_message == '2 tasks were successfully run' + + @pytest.mark.timeout(5) + def test_run_task_headers(self, monkeypatch): + def mock_apply_async(*args, **kwargs): + self.captured_headers = kwargs.get('headers', {}) + + monkeypatch.setattr('celery.app.task.Task.apply_async', + mock_apply_async) + ma = PeriodicTaskAdmin(PeriodicTask, self.site) + self.request = self.patch_request(self.request_factory.get('/')) + ma.run_tasks(self.request, PeriodicTask.objects.filter(id=self.m1.id)) + assert 'periodic_task_name' in self.captured_headers + assert self.captured_headers['periodic_task_name'] == self.m1.name