diff --git a/events/api.py b/events/api.py index f26ab0665..d03d59f93 100644 --- a/events/api.py +++ b/events/api.py @@ -1553,32 +1553,6 @@ def perform_destroy(self, instance): register_view(ImageViewSet, "image", base_name="image") -def parse_duration_string(duration): - """ - Parse duration string expressed in format - 86400 or 86400s (24 hours) - 180m or 3h (3 hours) - 3d (3 days) - """ - m = re.match(r"(\d+)\s*(d|h|m|s)?$", duration.strip().lower()) - if not m: - raise ParseError("Invalid duration supplied. Try '1d', '2h' or '180m'.") - val, unit = m.groups() - if not unit: - unit = "s" - - if unit == "s": - mul = 1 - elif unit == "m": - mul = 60 - elif unit == "h": - mul = 3600 - elif unit == "d": - mul = 24 * 3600 - - return int(val) * mul - - def _terms_to_regex(terms, operator: Literal["AND", "OR"]): """ Create a compiled regex from of the provided terms of the form @@ -2111,18 +2085,6 @@ def _filter_event_queryset(queryset, params, srs=None): # noqa: C901 # to change this to actually filter only subevents of recurring events? queryset = queryset.exclude(super_event_type=Event.SuperEventType.RECURRING) - val = params.get("max_duration", None) - if val: - dur = parse_duration_string(val) - cond = "end_time - start_time <= %s :: interval" - queryset = queryset.extra(where=[cond], params=[str(dur)]) - - val = params.get("min_duration", None) - if val: - dur = parse_duration_string(val) - cond = "end_time - start_time >= %s :: interval" - queryset = queryset.extra(where=[cond], params=[str(dur)]) - # Filter by publisher, multiple sources separated by comma val = params.get("publisher", None) if val: diff --git a/events/filters.py b/events/filters.py index 7259b15cd..0b7a5e615 100644 --- a/events/filters.py +++ b/events/filters.py @@ -1,4 +1,5 @@ -from datetime import datetime +import re +from datetime import datetime, timedelta from datetime import timezone as datetime_timezone from functools import partial from typing import Iterable, Optional @@ -9,7 +10,16 @@ from django.contrib.gis.geos import Point from django.contrib.gis.measure import D from django.contrib.postgres.search import SearchQuery, SearchRank -from django.db.models import Case, Exists, F, OuterRef, Q, When +from django.db.models import ( + Case, + DurationField, + Exists, + ExpressionWrapper, + F, + OuterRef, + Q, + When, +) from django.db.models import DateTimeField as ModelDateTimeField from django.utils import timezone from django.utils.translation import gettext_lazy as _ @@ -157,6 +167,32 @@ def filter_division(queryset, name: str, value: Iterable[str]): return queryset.filter(**{name + "__name__in": names}) +def parse_duration_string(duration) -> timedelta: + """ + Parse duration string expressed in format + 86400 or 86400s (24 hours) + 180m or 3h (3 hours) + 3d (3 days) + """ + m = re.match(r"(\d+)\s*(d|h|m|s)?$", duration.strip().lower()) + if not m: + raise ParseError("Invalid duration supplied. Try '1d', '2h' or '180m'.") + val, unit = m.groups() + if not unit: + unit = "s" + + if unit == "s": + mul = 1 + elif unit == "m": + mul = 60 + elif unit == "h": + mul = 3600 + elif unit == "d": + mul = 24 * 3600 + + return timedelta(seconds=int(val) * mul) + + class EventOrderingFilter(LinkedEventsOrderingFilter): def filter_queryset(self, request, queryset, view): ordering = self.get_ordering(request, queryset, view) @@ -264,6 +300,16 @@ class EventFilter(django_filters.rest_framework.FilterSet): "Search for events where enrolment is open on a given date or datetime." ), ) + min_duration = django_filters.CharFilter( + method="filter_min_duration", + help_text=_("Search for events that are at least as long as given duration."), + ) + max_duration = django_filters.CharFilter( + method="filter_max_duration", + help_text=_( + "Search for events that are at most as long as the given duration." + ), + ) class Meta: model = Event @@ -272,6 +318,28 @@ class Meta: "registration__remaining_waiting_list_capacity": ["gte", "isnull"], } + def filter_max_duration(self, queryset, name, value): + if not value: + return queryset + + max_duration = parse_duration_string(value) + return queryset.annotate( + duration=ExpressionWrapper( + F("end_time") - F("start_time"), output_field=DurationField() + ) + ).filter(duration__lte=max_duration) + + def filter_min_duration(self, queryset, name, value): + if not value: + return queryset + + min_duration = parse_duration_string(value) + return queryset.annotate( + duration=ExpressionWrapper( + F("end_time") - F("start_time"), output_field=DurationField() + ) + ).filter(duration__gte=min_duration) + def filter_enrolment_open_on(self, queryset, name, value: datetime): value = value.astimezone(ZoneInfo(settings.TIME_ZONE)) diff --git a/events/tests/test_event_filters.py b/events/tests/test_event_filters.py index d74aa285a..6c6b493ba 100644 --- a/events/tests/test_event_filters.py +++ b/events/tests/test_event_filters.py @@ -147,3 +147,33 @@ def test_get_event_list_ongoing(ongoing): assert event_end_after in qs else: assert event_end_before in qs + + +@pytest.mark.django_db +def test_get_event_list_min_duration(): + EventFactory(start_time=timezone.now(), end_time=timezone.now()) + event_long = EventFactory( + start_time=timezone.now(), end_time=timezone.now() + timedelta(hours=1) + ) + + filter_set = EventFilter() + + qs = filter_set.filter_min_duration(Event.objects.all(), "min_duration", "1h") + + assert qs.count() == 1 + assert event_long in qs + + +@pytest.mark.django_db +def test_get_event_list_max_duration(): + event_short = EventFactory(start_time=timezone.now(), end_time=timezone.now()) + EventFactory( + start_time=timezone.now(), end_time=timezone.now() + timedelta(hours=1) + ) + + filter_set = EventFilter() + + qs = filter_set.filter_max_duration(Event.objects.all(), "max_duration", "1h") + + assert qs.count() == 1 + assert event_short in qs diff --git a/events/tests/test_event_get.py b/events/tests/test_event_get.py index 33f874924..98f92590d 100644 --- a/events/tests/test_event_get.py +++ b/events/tests/test_event_get.py @@ -10,6 +10,7 @@ from django.db import DEFAULT_DB_ALIAS, connections from django.test import TestCase from django.test.utils import CaptureQueriesContext +from django.utils import timezone from django.utils.timezone import localtime from freezegun import freeze_time from rest_framework import status @@ -18,7 +19,12 @@ from events.models import Event, Language, PublicationStatus from events.tests.conftest import APIClient from events.tests.factories import EventFactory, OfferFactory -from events.tests.utils import assert_fields_exist, datetime_zone_aware, get +from events.tests.utils import ( + assert_fields_exist, + create_super_event, + datetime_zone_aware, + get, +) from events.tests.utils import versioned_reverse as reverse from registrations.tests.factories import OfferPriceGroupFactory, RegistrationFactory @@ -698,6 +704,32 @@ def test_publication_status_filter( ) +@pytest.mark.django_db +def test_get_event_list_max_duration_with_hide_recurring_children( + api_client, event, event2 +): + """ + When there are hide_recurring_children filter it makes joins which with previous + solution caused malformed sql query. This test makes sure that the query is valid + and returns correct results with max_duration with hide_recurring_children filters. + """ + event.start_time = timezone.now() + timedelta(days=10) + event.end_time = event.start_time + timedelta(days=10, hours=1) + event.super_event_type = Event.SuperEventType.RECURRING + event2.data_source = event.data_source + event2.start_time = timezone.now() + timedelta(days=20) + event2.end_time = timezone.now() + timedelta(days=20, hours=1) + event2.save() + super_event = create_super_event([event, event2], event.data_source) + super_event.start_time = timezone.now() + timedelta(hours=1) + super_event.end_time = timezone.now() + timedelta(hours=2) + super_event.save() + + get_list_and_assert_events( + "max_duration=2h&hide_recurring_children=true", [super_event] + ) + + @pytest.mark.django_db def test_event_status_filter( api_client, event, event2, event3, event4, user, organization, data_source