diff --git a/src/core/access.py b/src/core/access.py
index 90f1676..a45238c 100644
--- a/src/core/access.py
+++ b/src/core/access.py
@@ -1,15 +1,14 @@
-from typing import Any
-
 from database.users import User, UserGroup
 from schemas.datasets.openml import Visibility
+from sqlalchemy.engine import Row


 def _user_has_access(
-    dataset: dict[str, Any],
+    dataset: Row,
     user: User | None = None,
 ) -> bool:
     """Determine if `user` has the right to view `dataset`."""
-    is_public = dataset["visibility"] == Visibility.PUBLIC
+    is_public = dataset.visibility == Visibility.PUBLIC
     return is_public or (
-        user is not None and (user.user_id == dataset["uploader"] or UserGroup.ADMIN in user.groups)
+        user is not None and (user.user_id == dataset.uploader or UserGroup.ADMIN in user.groups)
     )
diff --git a/src/core/formatting.py b/src/core/formatting.py
index e298b3e..7eb2447 100644
--- a/src/core/formatting.py
+++ b/src/core/formatting.py
@@ -1,7 +1,7 @@
 import html
-from typing import Any

 from schemas.datasets.openml import DatasetFileFormat
+from sqlalchemy.engine import Row

 from core.errors import DatasetError

@@ -20,18 +20,18 @@ def _format_error(*, code: DatasetError, message: str) -> dict[str, str]:
     return {"code": str(code), "message": message}


-def _format_parquet_url(dataset: dict[str, Any]) -> str | None:
-    if dataset["format"].lower() != DatasetFileFormat.ARFF:
+def _format_parquet_url(dataset: Row) -> str | None:
+    if dataset.format.lower() != DatasetFileFormat.ARFF:
         return None

     minio_base_url = "https://openml1.win.tue.nl"
-    return f"{minio_base_url}/dataset{dataset['did']}/dataset_{dataset['did']}.pq"
+    return f"{minio_base_url}/dataset{dataset.did}/dataset_{dataset.did}.pq"


-def _format_dataset_url(dataset: dict[str, Any]) -> str:
+def _format_dataset_url(dataset: Row) -> str:
     base_url = "https://test.openml.org"
-    filename = f"{html.escape(dataset['name'])}.{dataset['format'].lower()}"
-    return f"{base_url}/data/v1/download/{dataset['file_id']}/{filename}"
+    filename = f"{html.escape(dataset.name)}.{dataset.format.lower()}"
+    return f"{base_url}/data/v1/download/{dataset.file_id}/{filename}"


 def _safe_unquote(text: str | None) -> str | None:
diff --git a/src/database/datasets.py b/src/database/datasets.py
index fc40dc3..bc16ff2 100644
--- a/src/database/datasets.py
+++ b/src/database/datasets.py
@@ -1,12 +1,11 @@
 """ Translation from https://github.com/openml/OpenML/blob/c19c9b99568c0fabb001e639ff6724b9a754bbc9/openml_OS/models/api/v1/Api_data.php#L707"""
 import datetime
 from collections import defaultdict
-from typing import Any, Iterable
+from typing import Iterable

 from schemas.datasets.openml import Feature, Quality
 from sqlalchemy import Connection, text
-
-from database.meta import get_column_names
+from sqlalchemy.engine import Row


 def get_qualities_for_dataset(dataset_id: int, connection: Connection) -> list[Quality]:
@@ -23,11 +22,12 @@ def get_qualities_for_dataset(dataset_id: int, connection: Connection) -> list[Q
     return [Quality(name=row.quality, value=row.value) for row in rows]


-def get_qualities_for_datasets(
+def _get_qualities_for_datasets(
     dataset_ids: Iterable[int],
     qualities: Iterable[str],
     connection: Connection,
 ) -> dict[int, list[Quality]]:
+    """Don't call with user-provided input, as query is not parameterized."""
     qualities_filter = ",".join(f"'{q}'" for q in qualities)
     dids = ",".join(str(did) for did in dataset_ids)
     qualities_query = text(
@@ -35,7 +35,7 @@
         SELECT `data`, `quality`, `value`
         FROM data_quality
         WHERE `data` in ({dids}) AND `quality` IN ({qualities_filter})
-        """,  # nosec - similar to above, no user input
+        """,  # nosec - dids and qualities are not user-provided
     )
     rows = connection.execute(qualities_query)
     qualities_by_id = defaultdict(list)
@@ -59,8 +59,7 @@ def list_all_qualities(connection: Connection) -> list[str]:
     return [quality.quality for quality in qualities]


-def get_dataset(dataset_id: int, connection: Connection) -> dict[str, Any] | None:
-    columns = get_column_names(connection, "dataset")
+def get_dataset(dataset_id: int, connection: Connection) -> Row | None:
     row = connection.execute(
         text(
             """
@@ -71,11 +70,10 @@ def get_dataset(dataset_id: int, connection: Connection) -> dict[str, Any] | Non
         ),
         parameters={"dataset_id": dataset_id},
     )
-    return dict(zip(columns, result[0], strict=True)) if (result := list(row)) else None
+    return row.one_or_none()


-def get_file(file_id: int, connection: Connection) -> dict[str, Any] | None:
-    columns = get_column_names(connection, "file")
+def get_file(file_id: int, connection: Connection) -> Row | None:
     row = connection.execute(
         text(
             """
@@ -86,11 +84,10 @@ def get_file(file_id: int, connection: Connection) -> dict[str, Any] | None:
         ),
         parameters={"file_id": file_id},
     )
-    return dict(zip(columns, result[0], strict=True)) if (result := list(row)) else None
+    return row.one_or_none()


 def get_tags(dataset_id: int, connection: Connection) -> list[str]:
-    columns = get_column_names(connection, "dataset_tag")
     rows = connection.execute(
         text(
             """
@@ -101,7 +98,7 @@ def get_tags(dataset_id: int, connection: Connection) -> list[str]:
         ),
         parameters={"dataset_id": dataset_id},
     )
-    return [dict(zip(columns, row, strict=True))["tag"] for row in rows]
+    return [row.tag for row in rows]


 def tag_dataset(user_id: int, dataset_id: int, tag: str, connection: Connection) -> None:
@@ -123,8 +120,7 @@ def tag_dataset(user_id: int, dataset_id: int, tag: str, connection: Connection)
 def get_latest_dataset_description(
     dataset_id: int,
     connection: Connection,
-) -> dict[str, Any] | None:
-    columns = get_column_names(connection, "dataset_description")
+) -> Row | None:
     row = connection.execute(
         text(
             """
@@ -136,10 +132,10 @@ def get_latest_dataset_description(
         ),
         parameters={"dataset_id": dataset_id},
     )
-    return dict(zip(columns, result[0], strict=True)) if (result := list(row)) else None
+    return row.one_or_none()


-def get_latest_status_update(dataset_id: int, connection: Connection) -> dict[str, Any] | None:
+def get_latest_status_update(dataset_id: int, connection: Connection) -> Row | None:
     row = connection.execute(
         text(
             """
@@ -151,11 +147,10 @@ def get_latest_status_update(dataset_id: int, connection: Connection) -> dict[st
         ),
         parameters={"dataset_id": dataset_id},
     )
-    return next(row.mappings(), None)
+    return row.first()


-def get_latest_processing_update(dataset_id: int, connection: Connection) -> dict[str, Any] | None:
-    columns = get_column_names(connection, "data_processed")
+def get_latest_processing_update(dataset_id: int, connection: Connection) -> Row | None:
     row = connection.execute(
         text(
             """
@@ -167,9 +162,7 @@ def get_latest_processing_update(dataset_id: int, connection: Connection) -> dic
         ),
         parameters={"dataset_id": dataset_id},
     )
-    return (
-        dict(zip(columns, result[0], strict=True), strict=True) if (result := list(row)) else None
-    )
+    return row.one_or_none()


 def get_features_for_dataset(dataset_id: int, connection: Connection) -> list[Feature]:
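Note on the renamed `_get_qualities_for_datasets` above: its new docstring warns that the query is not parameterized. If parameterization is ever wanted, SQLAlchemy's expanding bind parameters could keep the `IN` lists bound; a rough sketch, not part of this patch (the helper name is made up, the table and columns are taken from the query above):

```python
from typing import Iterable

from sqlalchemy import Connection, bindparam, text


def get_qualities_parameterized(
    dataset_ids: Iterable[int],
    qualities: Iterable[str],
    connection: Connection,
):
    # Expanding bind parameters render each IN list as bound values,
    # so no values are interpolated into the SQL string itself.
    query = text(
        """
        SELECT `data`, `quality`, `value`
        FROM data_quality
        WHERE `data` IN :dids AND `quality` IN :qualities
        """,
    ).bindparams(
        bindparam("dids", expanding=True),
        bindparam("qualities", expanding=True),
    )
    return connection.execute(
        query,
        parameters={"dids": list(dataset_ids), "qualities": list(qualities)},
    ).all()
```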
diff --git a/src/database/evaluations.py b/src/database/evaluations.py
index 5dfbc4c..63ea740 100644
--- a/src/database/evaluations.py
+++ b/src/database/evaluations.py
@@ -1,24 +1,27 @@
-from typing import Any, Iterable
+from typing import Sequence, cast

 from core.formatting import _str_to_bool
 from schemas.datasets.openml import EstimationProcedure
-from sqlalchemy import Connection, CursorResult, text
+from sqlalchemy import Connection, Row, text


-def get_math_functions(function_type: str, connection: Connection) -> CursorResult[Any]:
-    return connection.execute(
-        text(
-            """
+def get_math_functions(function_type: str, connection: Connection) -> Sequence[Row]:
+    return cast(
+        Sequence[Row],
+        connection.execute(
+            text(
+                """
             SELECT *
             FROM math_function
             WHERE `functionType` = :function_type
             """,
-        ),
-        parameters={"function_type": function_type},
+            ),
+            parameters={"function_type": function_type},
+        ).all(),
     )


-def get_estimation_procedures(connection: Connection) -> Iterable[EstimationProcedure]:
+def get_estimation_procedures(connection: Connection) -> list[EstimationProcedure]:
     rows = connection.execute(
         text(
             """
diff --git a/src/database/flows.py b/src/database/flows.py
index c369210..4889f13 100644
--- a/src/database/flows.py
+++ b/src/database/flows.py
@@ -1,22 +1,25 @@
-from typing import Any
+from typing import Sequence, cast

-from sqlalchemy import Connection, CursorResult, text
+from sqlalchemy import Connection, Row, text


-def get_flow_subflows(flow_id: int, expdb: Connection) -> CursorResult[Any]:
-    return expdb.execute(
-        text(
-            """
+def get_flow_subflows(flow_id: int, expdb: Connection) -> Sequence[Row]:
+    return cast(
+        Sequence[Row],
+        expdb.execute(
+            text(
+                """
             SELECT child as child_id, identifier
             FROM implementation_component
             WHERE parent = :flow_id
             """,
+            ),
+            parameters={"flow_id": flow_id},
         ),
-        parameters={"flow_id": flow_id},
     )


-def get_flow_tags(flow_id: int, expdb: Connection) -> CursorResult[Any]:
+def get_flow_tags(flow_id: int, expdb: Connection) -> list[str]:
     tag_rows = expdb.execute(
         text(
             """
@@ -30,20 +33,23 @@ def get_flow_tags(flow_id: int, expdb: Connection) -> CursorResult[Any]:
     return [tag.tag for tag in tag_rows]


-def get_flow_parameters(flow_id: int, expdb: Connection) -> CursorResult[Any]:
-    return expdb.execute(
-        text(
-            """
+def get_flow_parameters(flow_id: int, expdb: Connection) -> Sequence[Row]:
+    return cast(
+        Sequence[Row],
+        expdb.execute(
+            text(
+                """
             SELECT *, defaultValue as default_value, dataType as data_type
             FROM input
             WHERE implementation_id = :flow_id
             """,
+            ),
+            parameters={"flow_id": flow_id},
         ),
-        parameters={"flow_id": flow_id},
     )


-def get_flow(flow_id: int, expdb: Connection) -> CursorResult[Any]:
+def get_flow(flow_id: int, expdb: Connection) -> Row | None:
     return expdb.execute(
         text(
             """
@@ -53,4 +59,4 @@ def get_flow(flow_id: int, expdb: Connection) -> CursorResult[Any]:
             """,
         ),
         parameters={"flow_id": flow_id},
-    )
+    ).one_or_none()
diff --git a/src/database/meta.py b/src/database/meta.py
deleted file mode 100644
index 65e688b..0000000
--- a/src/database/meta.py
+++ /dev/null
@@ -1,17 +0,0 @@
-from sqlalchemy import Connection, text
-
-
-def get_column_names(connection: Connection, table: str) -> list[str]:
-    *_, database_name = str(connection.engine.url).split("/")
-    result = connection.execute(
-        text(
-            """
-            SELECT column_name
-            FROM INFORMATION_SCHEMA.COLUMNS
-            WHERE TABLE_NAME = :table_name AND TABLE_SCHEMA = :database
-            ORDER BY ORDINAL_POSITION
-            """,
-        ),
-        parameters={"table_name": table, "database": database_name},
-    )
-    return [colname for colname, in result.all()]
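The `cast(Sequence[Row], ...).all()` pattern above also changes how results behave for callers: a raw `CursorResult` streams rows once and is then exhausted, while `.all()` returns a plain list that can be iterated repeatedly. A minimal sketch of the difference, assuming an open connection named `expdb` (illustrative only, not part of the patch):

```python
from sqlalchemy import Connection, text


def demo_result_consumption(expdb: Connection) -> None:
    query = text("SELECT `child`, `identifier` FROM implementation_component")

    cursor = expdb.execute(query)  # CursorResult: rows stream from the driver
    list(cursor)                   # consumes every row
    list(cursor)                   # empty now - the cursor is already exhausted

    rows = expdb.execute(query).all()  # .all() materializes a list of Row objects
    for _ in rows:                     # safe to iterate...
        pass
    for _ in rows:                     # ...as often as needed
        pass
```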
diff --git a/src/database/studies.py b/src/database/studies.py
index 6eaebd7..3c7c166 100644
--- a/src/database/studies.py
+++ b/src/database/studies.py
@@ -1,6 +1,6 @@
 import re
 from datetime import datetime
-from typing import cast
+from typing import Sequence, cast

 from schemas.study import CreateStudy, StudyType
 from sqlalchemy import Connection, Row, text
@@ -8,7 +8,7 @@
 from database.users import User


-def get_study_by_id(study_id: int, connection: Connection) -> Row:
+def get_study_by_id(study_id: int, connection: Connection) -> Row | None:
     return connection.execute(
         text(
             """
@@ -18,10 +18,10 @@ def get_study_by_id(study_id: int, connection: Connection) -> Row:
             """,
         ),
         parameters={"study_id": study_id},
-    ).fetchone()
+    ).one_or_none()


-def get_study_by_alias(alias: str, connection: Connection) -> Row:
+def get_study_by_alias(alias: str, connection: Connection) -> Row | None:
     return connection.execute(
         text(
             """
@@ -31,13 +31,13 @@ def get_study_by_alias(alias: str, connection: Connection) -> Row:
             """,
         ),
         parameters={"study_id": alias},
-    ).fetchone()
+    ).one_or_none()


-def get_study_data(study: Row, expdb: Connection) -> list[Row]:
+def get_study_data(study: Row, expdb: Connection) -> Sequence[Row]:
     if study.type_ == StudyType.TASK:
         return cast(
-            list[Row],
+            Sequence[Row],
             expdb.execute(
                 text(
                     """
@@ -47,10 +47,10 @@ def get_study_data(study: Row, expdb: Connection) -> list[Row]:
                     """,
                 ),
                 parameters={"study_id": study.id},
-            ).fetchall(),
+            ).all(),
         )
     return cast(
-        list[Row],
+        Sequence[Row],
         expdb.execute(
             text(
                 """
@@ -68,7 +68,7 @@ def get_study_data(study: Row, expdb: Connection) -> list[Row]:
             """,
             ),
             parameters={"study_id": study.id},
-        ).fetchall(),
+        ).all(),
     )


@@ -96,7 +96,7 @@ def create_study(study: CreateStudy, user: User, expdb: Connection) -> int:
             "benchmark_suite": study.benchmark_suite,
         },
    )
-    (study_id,) = expdb.execute(text("""SELECT LAST_INSERT_ID();""")).fetchone()
+    (study_id,) = expdb.execute(text("""SELECT LAST_INSERT_ID();""")).one()
     return cast(int, study_id)

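A quick note on the fetch methods this patch standardizes on: `.one_or_none()` returns a single `Row` or `None` and raises `MultipleResultsFound` if the query matches more than one row, whereas `.first()` (kept for the latest-status lookup in `database/datasets.py`) simply takes the first row and discards the rest. A small sketch of the distinction; the table and column names here are placeholders, not the real schema:

```python
from sqlalchemy import Connection, text
from sqlalchemy.engine import Row


def by_primary_key(study_id: int, connection: Connection) -> Row | None:
    # At most one row can match a primary key; more than one would be a data
    # error, and .one_or_none() surfaces it by raising MultipleResultsFound.
    return connection.execute(
        text("SELECT * FROM study WHERE id = :study_id"),  # placeholder table
        parameters={"study_id": study_id},
    ).one_or_none()


def newest_of_many(dataset_id: int, connection: Connection) -> Row | None:
    # Several rows may match; order them and let .first() keep only the newest.
    return connection.execute(
        text(
            """
            SELECT * FROM dataset_status  -- placeholder table/column names
            WHERE `did` = :dataset_id
            ORDER BY `status_date` DESC
            """,
        ),
        parameters={"dataset_id": dataset_id},
    ).first()
```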
diff --git a/src/database/tasks.py b/src/database/tasks.py
index 07615f2..69ce220 100644
--- a/src/database/tasks.py
+++ b/src/database/tasks.py
@@ -1,10 +1,10 @@
-from typing import Any
+from typing import Sequence, cast

-from sqlalchemy import Connection, CursorResult, MappingResult, RowMapping, text
+from sqlalchemy import Connection, Row, text


-def get_task(task_id: int, expdb: Connection) -> RowMapping | None:
-    task_row = expdb.execute(
+def get_task(task_id: int, expdb: Connection) -> Row | None:
+    return expdb.execute(
         text(
             """
             SELECT *
@@ -13,24 +13,25 @@ def get_task(task_id: int, expdb: Connection) -> RowMapping | None:
             """,
         ),
         parameters={"task_id": task_id},
-    )
-    return next(task_row.mappings(), None)
+    ).one_or_none()


-def get_task_types(expdb: Connection) -> list[dict[str, str | int]]:
-    rows = expdb.execute(
-        text(
-            """
+def get_task_types(expdb: Connection) -> Sequence[Row]:
+    return cast(
+        Sequence[Row],
+        expdb.execute(
+            text(
+                """
             SELECT `ttid`, `name`, `description`, `creator`
             FROM task_type
             """,
-        ),
+            ),
+        ).all(),
     )
-    return list(rows.mappings())


-def get_task_type(task_type_id: int, expdb: Connection) -> RowMapping | None:
-    row = expdb.execute(
+def get_task_type(task_type_id: int, expdb: Connection) -> Row | None:
+    return expdb.execute(
         text(
             """
             SELECT *
@@ -39,47 +40,54 @@ def get_task_type(task_type_id: int, expdb: Connection) -> RowMapping | None:
             """,
         ),
         parameters={"ttid": task_type_id},
-    )
-    return next(row, None)
+    ).one_or_none()


-def get_input_for_task_type(task_type_id: int, expdb: Connection) -> CursorResult[Any]:
-    return expdb.execute(
-        text(
-            """
+def get_input_for_task_type(task_type_id: int, expdb: Connection) -> Sequence[Row]:
+    return cast(
+        Sequence[Row],
+        expdb.execute(
+            text(
+                """
             SELECT *
             FROM task_type_inout
             WHERE `ttid`=:ttid AND `io`='input'
             """,
-        ),
-        parameters={"ttid": task_type_id},
+            ),
+            parameters={"ttid": task_type_id},
+        ).all(),
     )


-def get_input_for_task(task_id: int, expdb: Connection) -> MappingResult:
-    rows = expdb.execute(
-        text(
-            """
+def get_input_for_task(task_id: int, expdb: Connection) -> Sequence[Row]:
+    return cast(
+        Sequence[Row],
+        expdb.execute(
+            text(
+                """
             SELECT `input`, `value`
             FROM task_inputs
             WHERE task_id = :task_id
             """,
-        ),
-        parameters={"task_id": task_id},
+            ),
+            parameters={"task_id": task_id},
+        ).all(),
     )
-    return rows.mappings()


-def get_task_type_inout_with_template(task_type: int, expdb: Connection) -> CursorResult[Any]:
-    return expdb.execute(
-        text(
-            """
+def get_task_type_inout_with_template(task_type: int, expdb: Connection) -> Sequence[Row]:
+    return cast(
+        Sequence[Row],
+        expdb.execute(
+            text(
+                """
             SELECT *
             FROM task_type_inout
             WHERE `ttid`=:ttid AND `template_api` IS NOT NULL
             """,
-        ),
-        parameters={"ttid": task_type},
+            ),
+            parameters={"ttid": task_type},
+        ).all(),
     )
diff --git a/src/database/users.py b/src/database/users.py
index 5152703..38fbd21 100644
--- a/src/database/users.py
+++ b/src/database/users.py
@@ -5,8 +5,6 @@
 from pydantic import StringConstraints
 from sqlalchemy import Connection, text

-from database.meta import get_column_names
-
 # Enforces str is 32 hexadecimal characters, does not check validity.
 APIKey = Annotated[str, StringConstraints(pattern=r"^[0-9a-fA-F]{32}$")]

@@ -18,8 +16,7 @@ class UserGroup(IntEnum):


 def get_user_id_for(*, api_key: APIKey, connection: Connection) -> int | None:
-    columns = get_column_names(connection, "users")
-    row = connection.execute(
+    user = connection.execute(
         text(
             """
             SELECT *
@@ -28,13 +25,11 @@ def get_user_id_for(*, api_key: APIKey, connection: Connection) -> int | None:
             """,
         ),
         parameters={"api_key": api_key},
-    )
-    if not (user := next(row, None)):
-        return None
-    return int(dict(zip(columns, user, strict=True))["id"])
+    ).one_or_none()
+    return user.id if user else None


-def get_user_groups_for(*, user_id: int, connection: Connection) -> list[int]:
+def get_user_groups_for(*, user_id: int, connection: Connection) -> list[UserGroup]:
     row = connection.execute(
         text(
             """
@@ -45,7 +40,7 @@ def get_user_groups_for(*, user_id: int, connection: Connection) -> list[int]:
         ),
         parameters={"user_id": user_id},
     )
-    return [group for group, in row]
+    return [UserGroup(group) for group, in row]


 @dataclasses.dataclass
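`get_user_groups_for` now returns `UserGroup` members instead of bare ints, so checks such as `UserGroup.ADMIN in user.groups` compare enum values. One side effect worth keeping in mind: `IntEnum` construction raises `ValueError` for ids not defined on the enum. A tiny sketch with made-up member values (the real ones live in `database/users.py`):

```python
from enum import IntEnum


class UserGroup(IntEnum):
    # Hypothetical values, purely for illustration.
    ADMIN = 1
    READ_WRITE = 2


raw_ids = [1, 2]                              # as returned by the groups query
groups = [UserGroup(gid) for gid in raw_ids]  # -> [UserGroup.ADMIN, UserGroup.READ_WRITE]
assert UserGroup.ADMIN in groups              # membership now compares enum members

try:
    UserGroup(99)  # an id that is not defined on the enum
except ValueError:
    pass           # raises instead of silently passing the unknown id through
```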
diff --git a/src/routers/openml/datasets.py b/src/routers/openml/datasets.py
index 5903a28..650a198 100644
--- a/src/routers/openml/datasets.py
+++ b/src/routers/openml/datasets.py
@@ -17,24 +17,25 @@
     _format_parquet_url,
     _safe_unquote,
 )
-from database.datasets import get_dataset as db_get_dataset
 from database.datasets import (
+    _get_qualities_for_datasets,
     get_feature_values,
     get_features_for_dataset,
     get_file,
     get_latest_dataset_description,
     get_latest_processing_update,
     get_latest_status_update,
-    get_qualities_for_datasets,
     get_tags,
     insert_status_for_dataset,
     remove_deactivated_status,
 )
+from database.datasets import get_dataset as db_get_dataset
 from database.datasets import tag_dataset as db_tag_dataset
 from database.users import User, UserGroup
 from fastapi import APIRouter, Body, Depends, HTTPException
 from schemas.datasets.openml import DatasetMetadata, DatasetStatus, Feature, FeatureType
 from sqlalchemy import Connection, text
+from sqlalchemy.engine import Row

 from routers.dependencies import Pagination, expdb_connection, fetch_user, userdb_connection
 from routers.types import CasualString128, IntegerRange, SystemString64, integer_range_regex
@@ -240,7 +241,7 @@ def quality_clause(quality: str, range_: str | None) -> str:
         "NumberOfNumericFeatures",
         "NumberOfSymbolicFeatures",
     ]
-    qualities_by_dataset = get_qualities_for_datasets(
+    qualities_by_dataset = _get_qualities_for_datasets(
         dataset_ids=datasets.keys(),
         qualities=qualities_to_show,
         connection=expdb_db,
@@ -261,9 +262,9 @@ def _get_processing_information(dataset_id: int, connection: Connection) -> Proc
     if not (data_processed := get_latest_processing_update(dataset_id, connection)):
         return ProcessingInformation(date=None, warning=None, error=None)

-    date_processed = data_processed["processing_date"]
-    warning = data_processed["warning"].strip() if data_processed["warning"] else None
-    error = data_processed["error"].strip() if data_processed["error"] else None
+    date_processed = data_processed.processing_date
+    warning = data_processed.warning.strip() if data_processed.warning else None
+    error = data_processed.error.strip() if data_processed.error else None
     return ProcessingInformation(date=date_processed, warning=warning, error=error)


@@ -271,7 +272,7 @@ def _get_dataset_raise_otherwise(
     dataset_id: int,
     user: User | None,
     expdb: Connection,
-) -> dict[str, Any]:
+) -> Row:
     """Fetches the dataset from the database if it exists and the user has permissions.

     Raises HTTPException if the dataset does not exist or the user can not access it.
@@ -305,7 +306,7 @@ def get_dataset_features(
             273,
             "Dataset not processed yet. The dataset was not processed yet, features are not yet available. Please wait for a few minutes.",  # noqa: E501
         )
-    elif processing_state.get("error"):
+    elif processing_state.error:
         code, msg = 274, "No features found. Additionally, dataset processed with error"
     else:
         code, msg = (
@@ -336,7 +337,7 @@ def update_dataset_status(

     dataset = _get_dataset_raise_otherwise(dataset_id, user, expdb)

-    can_deactivate = dataset["uploader"] == user.user_id or UserGroup.ADMIN in user.groups
+    can_deactivate = dataset.uploader == user.user_id or UserGroup.ADMIN in user.groups
     if status == DatasetStatus.DEACTIVATED and not can_deactivate:
         raise HTTPException(
             status_code=http.client.FORBIDDEN,
@@ -349,7 +350,7 @@ def update_dataset_status(
         )

     current_status = get_latest_status_update(dataset_id, expdb)
-    if current_status and current_status["status"] == status:
+    if current_status and current_status.status == status:
         raise HTTPException(
             status_code=http.client.PRECONDITION_FAILED,
             detail={"code": 694, "message": "Illegal status transition."},
@@ -363,7 +364,7 @@ def update_dataset_status(
     #  - deactivated => active (delete a row)
     if current_status is None or status == DatasetStatus.DEACTIVATED:
         insert_status_for_dataset(dataset_id, user.user_id, status, expdb)
-    elif current_status["status"] == DatasetStatus.DEACTIVATED:
+    elif current_status.status == DatasetStatus.DEACTIVATED:
         remove_deactivated_status(dataset_id, expdb)
     else:
         raise HTTPException(
@@ -385,7 +386,7 @@ def get_dataset(
     expdb_db: Annotated[Connection, Depends(expdb_connection)] = None,
 ) -> DatasetMetadata:
     dataset = _get_dataset_raise_otherwise(dataset_id, user, expdb_db)
-    if not (dataset_file := get_file(dataset["file_id"], user_db)):
+    if not (dataset_file := get_file(dataset.file_id, user_db)):
         error = _format_error(
             code=DatasetError.NO_DATA_FILE,
             message="No data file found",
@@ -397,20 +398,20 @@ def get_dataset(

     processing_result = _get_processing_information(dataset_id, expdb_db)
     status = get_latest_status_update(dataset_id, expdb_db)
-    status_ = DatasetStatus(status["status"]) if status else DatasetStatus.IN_PREPARATION
+    status_ = DatasetStatus(status.status) if status else DatasetStatus.IN_PREPARATION

     description_ = ""
     if description:
-        description_ = description["description"].replace("\r", "").strip()
+        description_ = description.description.replace("\r", "").strip()

     dataset_url = _format_dataset_url(dataset)
     parquet_url = _format_parquet_url(dataset)

-    contributors = _csv_as_list(dataset["contributor"], unquote_items=True)
-    creators = _csv_as_list(dataset["creator"], unquote_items=True)
-    ignore_attribute = _csv_as_list(dataset["ignore_attribute"], unquote_items=True)
-    row_id_attribute = _csv_as_list(dataset["row_id_attribute"], unquote_items=True)
-    original_data_url = _csv_as_list(dataset["original_data_url"], unquote_items=True)
+    contributors = _csv_as_list(dataset.contributor, unquote_items=True)
+    creators = _csv_as_list(dataset.creator, unquote_items=True)
+    ignore_attribute = _csv_as_list(dataset.ignore_attribute, unquote_items=True)
+    row_id_attribute = _csv_as_list(dataset.row_id_attribute, unquote_items=True)
+    original_data_url = _csv_as_list(dataset.original_data_url, unquote_items=True)

     # Not sure which properties are set by this bit:
     #   foreach( $this->xml_fields_dataset['csv'] as $field ) {
@@ -418,34 +419,34 @@ def get_dataset(
     #   }

     return DatasetMetadata(
-        id=dataset["did"],
-        visibility=dataset["visibility"],
+        id=dataset.did,
+        visibility=dataset.visibility,
         status=status_,
-        name=dataset["name"],
-        licence=dataset["licence"],
-        version=dataset["version"],
+        name=dataset.name,
+        licence=dataset.licence,
+        version=dataset.version,
-        version_label=dataset["version_label"] or "",
-        language=dataset["language"] or "",
+        version_label=dataset.version_label or "",
+        language=dataset.language or "",
         creator=creators,
         contributor=contributors,
-        citation=dataset["citation"] or "",
-        upload_date=dataset["upload_date"],
+        citation=dataset.citation or "",
+        upload_date=dataset.upload_date,
         processing_date=processing_result.date,
         warning=processing_result.warning,
         error=processing_result.error,
         description=description_,
-        description_version=description["version"] if description else 0,
+        description_version=description.version if description else 0,
         tag=tags,
-        default_target_attribute=_safe_unquote(dataset["default_target_attribute"]),
+        default_target_attribute=_safe_unquote(dataset.default_target_attribute),
         ignore_attribute=ignore_attribute,
         row_id_attribute=row_id_attribute,
         url=dataset_url,
         parquet_url=parquet_url,
         minio_url=parquet_url,
-        file_id=dataset["file_id"],
-        format=dataset["format"].lower(),
-        paper_url=dataset["paper_url"] or None,
+        file_id=dataset.file_id,
+        format=dataset.format.lower(),
+        paper_url=dataset.paper_url or None,
         original_data_url=original_data_url,
-        collection_date=dataset["collection_date"],
-        md5_checksum=dataset_file["md5_hash"],
+        collection_date=dataset.collection_date,
+        md5_checksum=dataset_file.md5_hash,
     )
diff --git a/src/routers/openml/flows.py b/src/routers/openml/flows.py
index 89b6e8b..ca95102 100644
--- a/src/routers/openml/flows.py
+++ b/src/routers/openml/flows.py
@@ -15,8 +15,8 @@

 @router.get("/{flow_id}")
 def get_flow(flow_id: int, expdb: Annotated[Connection, Depends(expdb_connection)] = None) -> Flow:
-    flow_rows = db_get_flow(flow_id, expdb)
-    if not (flow := next(flow_rows, None)):
+    flow = db_get_flow(flow_id, expdb)
+    if not flow:
         raise HTTPException(status_code=http.client.NOT_FOUND, detail="Flow not found")

     parameter_rows = get_flow_parameters(flow_id, expdb)
diff --git a/src/routers/openml/tasks.py b/src/routers/openml/tasks.py
index 6afe4f4..e01b661 100644
--- a/src/routers/openml/tasks.py
+++ b/src/routers/openml/tasks.py
@@ -156,7 +156,7 @@ def get_task(
 ) -> Task:
     if not (task := db_get_task(task_id, expdb)):
         raise HTTPException(status_code=http.client.NOT_FOUND, detail="Task not found")
-    if not (task_type := get_task_type(task["ttid"], expdb)):
+    if not (task_type := get_task_type(task.ttid, expdb)):
         raise HTTPException(
             status_code=http.client.INTERNAL_SERVER_ERROR,
             detail="Task type not found",
@@ -182,7 +182,7 @@ def get_task(
     name = f"Task {task_id} ({task_type.name})"
     dataset_id = task_inputs.get("source_data")
     if dataset_id and (dataset := get_dataset(dataset_id, expdb)):
-        name = f"Task {task_id}: {dataset['name']} ({task_type.name})"
+        name = f"Task {task_id}: {dataset.name} ({task_type.name})"

     return Task(
         id_=task.task_id,
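The tasktype changes below rely on `Row._mapping`, which is the supported way to get dict-style access from a `Row` when keys need to be iterated (as `_normalize_task_type` does). A short sketch of the two access styles, assuming an open `expdb` connection (illustrative only; `name` is one of the columns selected from `task_type` elsewhere in this patch):

```python
from sqlalchemy import Connection, text


def show_row_access(expdb: Connection) -> None:
    row = expdb.execute(
        text("SELECT * FROM task_type WHERE `ttid` = :ttid"),
        parameters={"ttid": 1},
    ).one()

    print(row.name)              # attribute access, used throughout this patch
    print(row._mapping["name"])  # read-only mapping view keyed by column name
    print(dict(row._mapping))    # plain dict copy, e.g. for building a payload
    for key, value in row._mapping.items():  # what _normalize_task_type iterates over
        print(key, value)
```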
diff --git a/src/routers/openml/tasktype.py b/src/routers/openml/tasktype.py
index d1306c3..2cd7792 100644
--- a/src/routers/openml/tasktype.py
+++ b/src/routers/openml/tasktype.py
@@ -5,17 +5,18 @@
 from database.tasks import get_input_for_task_type, get_task_types
 from database.tasks import get_task_type as db_get_task_type
 from fastapi import APIRouter, Depends, HTTPException
-from sqlalchemy import Connection
+from sqlalchemy import Connection, Row

 from routers.dependencies import expdb_connection

 router = APIRouter(prefix="/tasktype", tags=["tasks"])


-def _normalize_task_type(task_type: dict[str, str | int]) -> dict[str, str | list[Any]]:
-    ttype: dict[str, str | list[Any]] = {
+def _normalize_task_type(task_type: Row) -> dict[str, str | None | list[Any]]:
+    # Task types may contain multi-line fields which have either \r\n or \n line endings
+    ttype: dict[str, str | None | list[Any]] = {
         k: str(v).replace("\r\n", "\n").strip() if v is not None else v
-        for k, v in task_type.items()
+        for k, v in task_type._mapping.items()
         if k != "id"
     }
     ttype["id"] = ttype.pop("ttid")
@@ -27,8 +28,11 @@ def _normalize_task_type(task_type: dict[str, str | int]) -> dict[str, str | lis
 @router.get(path="/list")
 def list_task_types(
     expdb: Annotated[Connection, Depends(expdb_connection)] = None,
-) -> dict[Literal["task_types"], dict[Literal["task_type"], list[dict[str, str | list[Any]]]]]:
-    task_types: list[dict[str, str | list[Any]]] = [
+) -> dict[
+    Literal["task_types"],
+    dict[Literal["task_type"], list[dict[str, str | None | list[Any]]]],
+]:
+    task_types: list[dict[str, str | None | list[Any]]] = [
         _normalize_task_type(ttype) for ttype in get_task_types(expdb)
     ]
     return {"task_types": {"task_type": task_types}}
@@ -38,7 +42,7 @@ def list_task_types(
 def get_task_type(
     task_type_id: int,
     expdb: Annotated[Connection, Depends(expdb_connection)],
-) -> dict[Literal["task_type"], dict[str, str | list[str] | list[dict[str, str]]]]:
+) -> dict[Literal["task_type"], dict[str, str | None | list[str] | list[dict[str, str]]]]:
     task_type_record = db_get_task_type(task_type_id, expdb)
     if task_type_record is None:
         raise HTTPException(
@@ -46,10 +50,7 @@ def get_task_type(
             detail={"code": "241", "message": "Unknown task type."},
         ) from None

-    # TODO: This below now is a RowMapping instead of a dictionary.
-    # In general: rerun all integration tests to make sure everything works after some
-    # recent task changes of dict to RowMapping/CursorResult[Any] types.
-    task_type = _normalize_task_type(task_type_record._asdict())
+    task_type = _normalize_task_type(task_type_record)
     # Some names are quoted, or have typos in their comma-separation (e.g. 'A ,B')
     task_type["creator"] = [
         creator.strip(' "') for creator in cast(str, task_type["creator"]).split(",")
diff --git a/tests/routers/openml/datasets_test.py b/tests/routers/openml/datasets_test.py
index 37e8242..0650171 100644
--- a/tests/routers/openml/datasets_test.py
+++ b/tests/routers/openml/datasets_test.py
@@ -28,6 +28,45 @@ def test_error_unknown_dataset(
     assert {"code": "111", "message": "Unknown dataset"} == response.json()["detail"]


+def test_get_dataset(py_api: TestClient) -> None:
+    response = py_api.get("/datasets/1")
+    assert response.status_code == http.client.OK
+    description = response.json()
+    assert description.pop("description").startswith("**Author**:")
+
+    assert description == {
+        "id": 1,
+        "name": "anneal",
+        "version": 1,
+        "format": "arff",
+        "description_version": 1,
+        "upload_date": "2014-04-06T23:19:24",
+        "licence": "Public",
+        "url": "https://test.openml.org/data/v1/download/1/anneal.arff",
+        "parquet_url": "https://openml1.win.tue.nl/dataset1/dataset_1.pq",
+        "file_id": 1,
+        "default_target_attribute": "class",
+        "version_label": "1",
+        "tag": ["study_14"],
+        "visibility": "public",
+        "minio_url": "https://openml1.win.tue.nl/dataset1/dataset_1.pq",
+        "status": "in_preparation",
+        "processing_date": "2023-10-12T09:08:38",
+        "md5_checksum": "4eaed8b6ec9d8211024b6c089b064761",
+        "row_id_attribute": [],
+        "ignore_attribute": [],
+        "language": "",
+        "error": None,
+        "warning": None,
+        "citation": "",
+        "collection_date": None,
+        "contributor": [],
+        "creator": [],
+        "paper_url": None,
+        "original_data_url": [],
+    }
+
+
 @pytest.mark.parametrize(
     ("api_key", "response_code"),
     [