Skip to content

Commit

Permalink
Refactor db layer to access results in a consistent way (#137)
Browse files Browse the repository at this point in the history
* Return row instead of dict when fetching a dataset from database

This allows for attribute access instead of mapping-style access.
Attribute access more closely resembles how we will later interact
with the ORM objects.

* Add simple get_dataset test case

* Change more dataset functions to return Row instead of Dict

This makes attribute access possible, which will be more similar
to how we access results once we switch over to the ORM.

* Use one_or_none in get_flow to return a Row instead

Since we are selecting based on the primary key,
we should always get at most one result.

* Use Row results directly

* Prefer SQLAlchemy 2.0 method names

* Use newer-style SQLAlchemy methods instead of dictionaries and old-style methods

* Change CursorMapping to Sequence[Row] for consistency
  • Loading branch information
PGijsbers authored Jan 3, 2024
1 parent 78073cd commit 99b5376
Show file tree
Hide file tree
Showing 14 changed files with 210 additions and 182 deletions.
9 changes: 4 additions & 5 deletions src/core/access.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
from typing import Any

from database.users import User, UserGroup
from schemas.datasets.openml import Visibility
from sqlalchemy.engine import Row


def _user_has_access(
dataset: dict[str, Any],
dataset: Row,
user: User | None = None,
) -> bool:
"""Determine if `user` has the right to view `dataset`."""
is_public = dataset["visibility"] == Visibility.PUBLIC
is_public = dataset.visibility == Visibility.PUBLIC
return is_public or (
user is not None and (user.user_id == dataset["uploader"] or UserGroup.ADMIN in user.groups)
user is not None and (user.user_id == dataset.uploader or UserGroup.ADMIN in user.groups)
)
14 changes: 7 additions & 7 deletions src/core/formatting.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import html
from typing import Any

from schemas.datasets.openml import DatasetFileFormat
from sqlalchemy.engine import Row

from core.errors import DatasetError

Expand All @@ -20,18 +20,18 @@ def _format_error(*, code: DatasetError, message: str) -> dict[str, str]:
return {"code": str(code), "message": message}


def _format_parquet_url(dataset: dict[str, Any]) -> str | None:
if dataset["format"].lower() != DatasetFileFormat.ARFF:
def _format_parquet_url(dataset: Row) -> str | None:
if dataset.format.lower() != DatasetFileFormat.ARFF:
return None

minio_base_url = "https://openml1.win.tue.nl"
return f"{minio_base_url}/dataset{dataset['did']}/dataset_{dataset['did']}.pq"
return f"{minio_base_url}/dataset{dataset.did}/dataset_{dataset.did}.pq"


def _format_dataset_url(dataset: dict[str, Any]) -> str:
def _format_dataset_url(dataset: Row) -> str:
base_url = "https://test.openml.org"
filename = f"{html.escape(dataset['name'])}.{dataset['format'].lower()}"
return f"{base_url}/data/v1/download/{dataset['file_id']}/{filename}"
filename = f"{html.escape(dataset.name)}.{dataset.format.lower()}"
return f"{base_url}/data/v1/download/{dataset.file_id}/{filename}"


def _safe_unquote(text: str | None) -> str | None:
Expand Down
39 changes: 16 additions & 23 deletions src/database/datasets.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
""" Translation from https://github.com/openml/OpenML/blob/c19c9b99568c0fabb001e639ff6724b9a754bbc9/openml_OS/models/api/v1/Api_data.php#L707"""
import datetime
from collections import defaultdict
from typing import Any, Iterable
from typing import Iterable

from schemas.datasets.openml import Feature, Quality
from sqlalchemy import Connection, text

from database.meta import get_column_names
from sqlalchemy.engine import Row


def get_qualities_for_dataset(dataset_id: int, connection: Connection) -> list[Quality]:
Expand All @@ -23,19 +22,20 @@ def get_qualities_for_dataset(dataset_id: int, connection: Connection) -> list[Q
return [Quality(name=row.quality, value=row.value) for row in rows]


def get_qualities_for_datasets(
def _get_qualities_for_datasets(
dataset_ids: Iterable[int],
qualities: Iterable[str],
connection: Connection,
) -> dict[int, list[Quality]]:
"""Don't call with user-provided input, as query is not parameterized."""
qualities_filter = ",".join(f"'{q}'" for q in qualities)
dids = ",".join(str(did) for did in dataset_ids)
qualities_query = text(
f"""
SELECT `data`, `quality`, `value`
FROM data_quality
WHERE `data` in ({dids}) AND `quality` IN ({qualities_filter})
""", # nosec - similar to above, no user input
""", # nosec - dids and qualities are not user-provided
)
rows = connection.execute(qualities_query)
qualities_by_id = defaultdict(list)
Expand All @@ -59,8 +59,7 @@ def list_all_qualities(connection: Connection) -> list[str]:
return [quality.quality for quality in qualities]


def get_dataset(dataset_id: int, connection: Connection) -> dict[str, Any] | None:
columns = get_column_names(connection, "dataset")
def get_dataset(dataset_id: int, connection: Connection) -> Row | None:
row = connection.execute(
text(
"""
Expand All @@ -71,11 +70,10 @@ def get_dataset(dataset_id: int, connection: Connection) -> dict[str, Any] | Non
),
parameters={"dataset_id": dataset_id},
)
return dict(zip(columns, result[0], strict=True)) if (result := list(row)) else None
return row.one_or_none()


def get_file(file_id: int, connection: Connection) -> dict[str, Any] | None:
columns = get_column_names(connection, "file")
def get_file(file_id: int, connection: Connection) -> Row | None:
row = connection.execute(
text(
"""
Expand All @@ -86,11 +84,10 @@ def get_file(file_id: int, connection: Connection) -> dict[str, Any] | None:
),
parameters={"file_id": file_id},
)
return dict(zip(columns, result[0], strict=True)) if (result := list(row)) else None
return row.one_or_none()


def get_tags(dataset_id: int, connection: Connection) -> list[str]:
columns = get_column_names(connection, "dataset_tag")
rows = connection.execute(
text(
"""
Expand All @@ -101,7 +98,7 @@ def get_tags(dataset_id: int, connection: Connection) -> list[str]:
),
parameters={"dataset_id": dataset_id},
)
return [dict(zip(columns, row, strict=True))["tag"] for row in rows]
return [row.tag for row in rows]


def tag_dataset(user_id: int, dataset_id: int, tag: str, connection: Connection) -> None:
Expand All @@ -123,8 +120,7 @@ def tag_dataset(user_id: int, dataset_id: int, tag: str, connection: Connection)
def get_latest_dataset_description(
dataset_id: int,
connection: Connection,
) -> dict[str, Any] | None:
columns = get_column_names(connection, "dataset_description")
) -> Row | None:
row = connection.execute(
text(
"""
Expand All @@ -136,10 +132,10 @@ def get_latest_dataset_description(
),
parameters={"dataset_id": dataset_id},
)
return dict(zip(columns, result[0], strict=True)) if (result := list(row)) else None
return row.one_or_none()


def get_latest_status_update(dataset_id: int, connection: Connection) -> dict[str, Any] | None:
def get_latest_status_update(dataset_id: int, connection: Connection) -> Row | None:
row = connection.execute(
text(
"""
Expand All @@ -151,11 +147,10 @@ def get_latest_status_update(dataset_id: int, connection: Connection) -> dict[st
),
parameters={"dataset_id": dataset_id},
)
return next(row.mappings(), None)
return row.first()


def get_latest_processing_update(dataset_id: int, connection: Connection) -> dict[str, Any] | None:
columns = get_column_names(connection, "data_processed")
def get_latest_processing_update(dataset_id: int, connection: Connection) -> Row | None:
row = connection.execute(
text(
"""
Expand All @@ -167,9 +162,7 @@ def get_latest_processing_update(dataset_id: int, connection: Connection) -> dic
),
parameters={"dataset_id": dataset_id},
)
return (
dict(zip(columns, result[0], strict=True), strict=True) if (result := list(row)) else None
)
return row.one_or_none()


def get_features_for_dataset(dataset_id: int, connection: Connection) -> list[Feature]:
Expand Down
21 changes: 12 additions & 9 deletions src/database/evaluations.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,27 @@
from typing import Any, Iterable
from typing import Sequence, cast

from core.formatting import _str_to_bool
from schemas.datasets.openml import EstimationProcedure
from sqlalchemy import Connection, CursorResult, text
from sqlalchemy import Connection, Row, text


def get_math_functions(function_type: str, connection: Connection) -> CursorResult[Any]:
return connection.execute(
text(
"""
def get_math_functions(function_type: str, connection: Connection) -> Sequence[Row]:
return cast(
Sequence[Row],
connection.execute(
text(
"""
SELECT *
FROM math_function
WHERE `functionType` = :function_type
""",
),
parameters={"function_type": function_type},
),
parameters={"function_type": function_type},
).all(),
)


def get_estimation_procedures(connection: Connection) -> Iterable[EstimationProcedure]:
def get_estimation_procedures(connection: Connection) -> list[EstimationProcedure]:
rows = connection.execute(
text(
"""
Expand Down
36 changes: 21 additions & 15 deletions src/database/flows.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,25 @@
from typing import Any
from typing import Sequence, cast

from sqlalchemy import Connection, CursorResult, text
from sqlalchemy import Connection, Row, text


def get_flow_subflows(flow_id: int, expdb: Connection) -> CursorResult[Any]:
return expdb.execute(
text(
"""
def get_flow_subflows(flow_id: int, expdb: Connection) -> Sequence[Row]:
return cast(
Sequence[Row],
expdb.execute(
text(
"""
SELECT child as child_id, identifier
FROM implementation_component
WHERE parent = :flow_id
""",
),
parameters={"flow_id": flow_id},
),
parameters={"flow_id": flow_id},
)


def get_flow_tags(flow_id: int, expdb: Connection) -> CursorResult[Any]:
def get_flow_tags(flow_id: int, expdb: Connection) -> list[str]:
tag_rows = expdb.execute(
text(
"""
Expand All @@ -30,20 +33,23 @@ def get_flow_tags(flow_id: int, expdb: Connection) -> CursorResult[Any]:
return [tag.tag for tag in tag_rows]


def get_flow_parameters(flow_id: int, expdb: Connection) -> CursorResult[Any]:
return expdb.execute(
text(
"""
def get_flow_parameters(flow_id: int, expdb: Connection) -> Sequence[Row]:
return cast(
Sequence[Row],
expdb.execute(
text(
"""
SELECT *, defaultValue as default_value, dataType as data_type
FROM input
WHERE implementation_id = :flow_id
""",
),
parameters={"flow_id": flow_id},
),
parameters={"flow_id": flow_id},
)


def get_flow(flow_id: int, expdb: Connection) -> CursorResult[Any]:
def get_flow(flow_id: int, expdb: Connection) -> Row | None:
return expdb.execute(
text(
"""
Expand All @@ -53,4 +59,4 @@ def get_flow(flow_id: int, expdb: Connection) -> CursorResult[Any]:
""",
),
parameters={"flow_id": flow_id},
)
).one_or_none()
17 changes: 0 additions & 17 deletions src/database/meta.py

This file was deleted.

22 changes: 11 additions & 11 deletions src/database/studies.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import re
from datetime import datetime
from typing import cast
from typing import Sequence, cast

from schemas.study import CreateStudy, StudyType
from sqlalchemy import Connection, Row, text

from database.users import User


def get_study_by_id(study_id: int, connection: Connection) -> Row:
def get_study_by_id(study_id: int, connection: Connection) -> Row | None:
return connection.execute(
text(
"""
Expand All @@ -18,10 +18,10 @@ def get_study_by_id(study_id: int, connection: Connection) -> Row:
""",
),
parameters={"study_id": study_id},
).fetchone()
).one_or_none()


def get_study_by_alias(alias: str, connection: Connection) -> Row:
def get_study_by_alias(alias: str, connection: Connection) -> Row | None:
return connection.execute(
text(
"""
Expand All @@ -31,13 +31,13 @@ def get_study_by_alias(alias: str, connection: Connection) -> Row:
""",
),
parameters={"study_id": alias},
).fetchone()
).one_or_none()


def get_study_data(study: Row, expdb: Connection) -> list[Row]:
def get_study_data(study: Row, expdb: Connection) -> Sequence[Row]:
if study.type_ == StudyType.TASK:
return cast(
list[Row],
Sequence[Row],
expdb.execute(
text(
"""
Expand All @@ -47,10 +47,10 @@ def get_study_data(study: Row, expdb: Connection) -> list[Row]:
""",
),
parameters={"study_id": study.id},
).fetchall(),
).all(),
)
return cast(
list[Row],
Sequence[Row],
expdb.execute(
text(
"""
Expand All @@ -68,7 +68,7 @@ def get_study_data(study: Row, expdb: Connection) -> list[Row]:
""",
),
parameters={"study_id": study.id},
).fetchall(),
).all(),
)


Expand Down Expand Up @@ -96,7 +96,7 @@ def create_study(study: CreateStudy, user: User, expdb: Connection) -> int:
"benchmark_suite": study.benchmark_suite,
},
)
(study_id,) = expdb.execute(text("""SELECT LAST_INSERT_ID();""")).fetchone()
(study_id,) = expdb.execute(text("""SELECT LAST_INSERT_ID();""")).one()
return cast(int, study_id)


Expand Down
Loading

0 comments on commit 99b5376

Please sign in to comment.