Skip to content

Commit

Permalink
fix gtfs_ride_stops to allow querying on single gtfs_ride_id without …
Browse files Browse the repository at this point in the history
…times
  • Loading branch information
OriHoch committed Jun 19, 2024
1 parent e557d65 commit 858bcab
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 9 deletions.
7 changes: 5 additions & 2 deletions open_bus_stride_api/routers/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,13 +442,16 @@ def get_route_params_with_prefix(name_prefix, doc_param_prefix, route_params):
]


def add_api_router_list(router, tag, pydantic_model, what_plural, route_params):
def add_api_router_list(router, tag, pydantic_model, what_plural, route_params, description=None):

def _decorator(func):
func.__signature__ = inspect.signature(func).replace(parameters=[
route_param.get_signature_parameter() for route_param in route_params
])
router.add_api_route("/list", func, tags=[tag], response_model=typing.List[pydantic_model], description=f'List of {what_plural}.')
router.add_api_route(
"/list", func, tags=[tag], response_model=typing.List[pydantic_model],
description=description if description else f'List of {what_plural}.'
)
return func

return _decorator
Expand Down
29 changes: 22 additions & 7 deletions open_bus_stride_api/routers/gtfs_ride_stops.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import datetime
from textwrap import dedent

import pydantic
from fastapi import APIRouter, HTTPException

from open_bus_stride_db import model
from pydantic.fields import Undefined

from . import common, gtfs_rides, gtfs_stops, gtfs_routes

Expand Down Expand Up @@ -67,12 +67,12 @@ def _post_session_query_hook(session_query):
gtfs_ride_stop_filter_params = [
common.RouteParam(
'arrival_time_from', datetime.datetime,
common.DocParam('arrival time from', filter_type='datetime_from', default=Undefined),
common.DocParam('arrival time from', filter_type='datetime_from'),
{'type': 'datetime_from', 'field': model.GtfsRideStop.arrival_time},
),
common.RouteParam(
'arrival_time_to', datetime.datetime,
common.DocParam('arrival time to', filter_type='datetime_to', default=Undefined),
common.DocParam('arrival time to', filter_type='datetime_to'),
{'type': 'datetime_to', 'field': model.GtfsRideStop.arrival_time},
),
common.RouteParam(
Expand Down Expand Up @@ -112,11 +112,26 @@ def _post_session_query_hook(session_query):
]


@common.add_api_router_list(router, TAG, GtfsRideStopWithRelatedPydanticModel, WHAT_PLURAL, gtfs_ride_stop_list_params)
@common.add_api_router_list(
router, TAG, GtfsRideStopWithRelatedPydanticModel, WHAT_PLURAL, gtfs_ride_stop_list_params,
description=dedent("""
List of gtfs ride stops.
Due to large number of items in the table, you must filter the results by at least one of the following:
1. gtfs_ride_ids - containing a single gtfs ride id.
2. arrival_time_from and arrival_time_to - containing a time range.
Additional filters can be applied in addition to one of the above options to narrow down the results.
""").strip()
)
def list_(**kwargs):
# Validate arrival_time range is no longer than 30 days (to avoid heavy queries)
if (kwargs['arrival_time_to'] - kwargs['arrival_time_from']).days > 30:
raise HTTPException(status_code=400, detail="Time range is longer than 30 days")
if not kwargs.get('gtfs_ride_ids') or ',' in kwargs['gtfs_ride_ids']:
# Validate arrival_time range is no longer than 30 days (to avoid heavy queries)
if not kwargs.get('arrival_time_from') or not kwargs.get('arrival_time_to'):
raise HTTPException(status_code=400, detail="arrival_time_from and arrival_time_to are required")
if (kwargs['arrival_time_to'] - kwargs['arrival_time_from']).days > 30:
raise HTTPException(status_code=400, detail="Time range is longer than 30 days")

return common.get_list(
SQL_MODEL, kwargs['limit'], kwargs['offset'],
Expand Down

0 comments on commit 858bcab

Please sign in to comment.