Skip to content

Commit

Permalink
Issue #931 add MultiResult wrapper and Connection support
Browse files Browse the repository at this point in the history
  • Loading branch information
soxofaan committed Oct 30, 2024
1 parent d42e6f9 commit e1e5e09
Show file tree
Hide file tree
Showing 9 changed files with 443 additions and 25 deletions.
16 changes: 9 additions & 7 deletions openeo/internal/graph_building.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import sys
from contextlib import nullcontext
from pathlib import Path
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple, Union

from openeo.api.process import Parameter
from openeo.internal.process_graph_visitor import (
Expand Down Expand Up @@ -244,7 +244,7 @@ def walk(x) -> Iterator[PGNode]:
yield from walk(self.arguments)


def as_flat_graph(x: Union[dict, FlatGraphableMixin, Path, Any]) -> Dict[str, dict]:
def as_flat_graph(x: Union[dict, FlatGraphableMixin, Path, List[FlatGraphableMixin], Any]) -> Dict[str, dict]:
"""
Convert given object to an internal flat dict graph representation.
"""
Expand All @@ -253,12 +253,15 @@ def as_flat_graph(x: Union[dict, FlatGraphableMixin, Path, Any]) -> Dict[str, di
# including `{"process_graph": {nodes}}` ("process graph")
# or just the raw process graph nodes?
if isinstance(x, dict):
# Assume given dict is already a flat graph representation
return x
elif isinstance(x, FlatGraphableMixin):
return x.flat_graph()
elif isinstance(x, (str, Path)):
# Assume a JSON resource (raw JSON, path to local file, JSON url, ...)
return load_json_resource(x)
elif isinstance(x, (list, tuple)) and all(isinstance(i, FlatGraphableMixin) for i in x):
return MultiLeafGraph(x).flat_graph()
raise ValueError(x)


Expand Down Expand Up @@ -450,14 +453,13 @@ def _process_from_parameter(self, name: str) -> Any:
return self._parameters[name]


class MultiResult(FlatGraphableMixin):
class MultiLeafGraph(FlatGraphableMixin):
"""
Handler of use cases where there are multiple result nodes
(or other leaf nodes) in a process graph.
Container for process graphs with multiple leaf/result nodes.
"""

def __init__(self, leaves: List[FlatGraphableMixin]):
self._leaves = leaves
def __init__(self, leaves: Iterable[FlatGraphableMixin]):
self._leaves = list(leaves)

def flat_graph(self) -> Dict[str, dict]:
flattener = GraphFlattener(multi_input_mode=True)
Expand Down
58 changes: 55 additions & 3 deletions openeo/rest/_testing.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from __future__ import annotations

import collections
import json
import re
from typing import Callable, Iterator, Optional, Sequence, Union
from typing import Callable, Iterable, Iterator, Optional, Sequence, Union

from openeo import Connection, DataCube
from openeo.rest.vectorcube import VectorCube
Expand All @@ -19,8 +21,12 @@ class DummyBackend:
and allows inspection of posted process graphs
"""

# TODO: move to openeo.testing

__slots__ = (
"_requests_mock",
"connection",
"file_formats",
"sync_requests",
"batch_jobs",
"validation_requests",
Expand All @@ -33,8 +39,14 @@ class DummyBackend:
# Default result (can serve both as JSON or binary data)
DEFAULT_RESULT = b'{"what?": "Result data"}'

def __init__(self, requests_mock, connection: Connection):
def __init__(
self,
requests_mock,
connection: Connection,
):
self._requests_mock = requests_mock
self.connection = connection
self.file_formats = {"input": {}, "output": {}}
self.sync_requests = []
self.batch_jobs = {}
self.validation_requests = []
Expand Down Expand Up @@ -69,6 +81,35 @@ def __init__(self, requests_mock, connection: Connection):
)
requests_mock.post(connection.build_url("/validation"), json=self._handle_post_validation)

@classmethod
def at(cls, root_url: str, *, requests_mock, capabilities: Optional[dict] = None) -> DummyBackend:
    """
    Factory to build a dummy backend at the given root URL,
    including creation of the connection and mocking of the capabilities document.

    :param root_url: root URL of the dummy backend (normalized to end with exactly one slash)
    :param requests_mock: request mocker on which the capabilities document is registered
        and that is passed on to the constructor
    :param capabilities: optional keyword arguments to pass to ``build_capabilities``
    :return: dummy backend instance bound to a new connection to ``root_url``
    """
    root_url = root_url.rstrip("/") + "/"
    # Fall back to an empty dict (not `None`) when no capabilities are given:
    # `**None` is not a mapping and would raise a TypeError.
    requests_mock.get(root_url, json=build_capabilities(**(capabilities or {})))
    connection = Connection(root_url)
    return cls(requests_mock=requests_mock, connection=connection)

def setup_collection(self, collection_id: str):
    """
    Mock the ``/collections/{collection_id}`` endpoint with a minimal metadata document.

    :param collection_id: id of the collection to mock
    :return: this dummy backend instance (allows chaining of setup calls)
    """
    # TODO: also mock `/collections` overview
    # TODO: add more metadata?
    metadata = {"id": collection_id}
    collection_url = self.connection.build_url(f"/collections/{collection_id}")
    self._requests_mock.get(collection_url, json=metadata)
    return self

def setup_file_format(self, name: str, type: str = "output", gis_data_types: Iterable[str] = ("raster",)):
    """
    Register a file format and (re)mock the ``/file_formats`` endpoint with the current listing.

    :param name: name of the file format (also used as its title)
    :param type: key of ``self.file_formats`` to register the format under (default "output")
    :param gis_data_types: GIS data types of the format, stored as a list
    :return: this dummy backend instance (allows chaining of setup calls)
    """
    format_metadata = {
        "title": name,
        "gis_data_types": list(gis_data_types),
        "parameters": {},
    }
    self.file_formats[type][name] = format_metadata
    # Register the endpoint again so it serves the full, updated listing.
    file_formats_url = self.connection.build_url("/file_formats")
    self._requests_mock.get(file_formats_url, json=self.file_formats)
    return self

def _handle_post_result(self, request, context):
"""handler of `POST /result` (synchronous execute)"""
pg = request.json()["process"]["process_graph"]
Expand Down Expand Up @@ -147,10 +188,21 @@ def get_sync_pg(self) -> dict:
return self.sync_requests[0]

def get_batch_pg(self) -> dict:
    """
    Return the process graph of the one and only batch job.

    Fails with an assertion error when there is no batch job, or more than one.
    """
    job_count = len(self.batch_jobs)
    assert job_count == 1
    job_id = max(self.batch_jobs)
    return self.batch_jobs[job_id]["pg"]

def get_validation_pg(self) -> dict:
    """
    Return the process graph of the one and only validation request.

    Fails with an assertion error when there is no validation request, or more than one.
    """
    recorded = self.validation_requests
    assert len(recorded) == 1
    return recorded[0]

def get_pg(self, process_id: Optional[str] = None) -> dict:
"""
Get one and only batch process graph (sync or batch)
Expand Down
44 changes: 38 additions & 6 deletions openeo/rest/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
List,
Optional,
Sequence,
Set,
Tuple,
Union,
)
Expand Down Expand Up @@ -53,7 +54,7 @@
OpenEoClientException,
OpenEoRestError,
)
from openeo.rest._datacube import build_child_callback
from openeo.rest._datacube import _ProcessGraphAbstraction, build_child_callback
from openeo.rest.auth.auth import BasicBearerAuth, BearerAuth, NullAuth, OidcBearerAuth
from openeo.rest.auth.config import AuthConfig, RefreshTokenStore
from openeo.rest.auth.oidc import (
Expand Down Expand Up @@ -1128,7 +1129,9 @@ def user_defined_process(self, user_defined_process_id: str) -> RESTUserDefinedP
"""
return RESTUserDefinedProcess(user_defined_process_id=user_defined_process_id, connection=self)

def validate_process_graph(self, process_graph: Union[dict, FlatGraphableMixin, Any]) -> List[dict]:
def validate_process_graph(
self, process_graph: Union[dict, FlatGraphableMixin, str, Path, List[FlatGraphableMixin]]
) -> List[dict]:
"""
Validate a process graph without executing it.
Expand Down Expand Up @@ -1608,12 +1611,19 @@ def upload_file(
metadata = resp.json()
return UserFile.from_metadata(metadata=metadata, connection=self)

def _build_request_with_process_graph(self, process_graph: Union[dict, FlatGraphableMixin, Any], **kwargs) -> dict:
def _build_request_with_process_graph(
self,
process_graph: Union[dict, FlatGraphableMixin, str, Path, List[FlatGraphableMixin]],
**kwargs,
) -> dict:
"""
Prepare a json payload with a process graph to submit to /result, /services, /jobs, ...
:param process_graph: flat dict representing a "process graph with metadata" ({"process": {"process_graph": ...}, ...})
"""
# TODO: make this a more general helper (like `as_flat_graph`)
connections = extract_connections(process_graph)
if any(c != self for c in connections):
raise OpenEoClientException(f"Mixing different connections: {self} and {connections}.")
result = kwargs
process_graph = as_flat_graph(process_graph)
if "process_graph" not in process_graph:
Expand Down Expand Up @@ -1656,7 +1666,7 @@ def _preflight_validation(self, pg_with_metadata: dict, *, validate: Optional[bo
# TODO: unify `download` and `execute` better: e.g. `download` always writes to disk, `execute` returns result (raw or as JSON decoded dict)
def download(
self,
graph: Union[dict, FlatGraphableMixin, str, Path],
graph: Union[dict, FlatGraphableMixin, str, Path, List[FlatGraphableMixin]],
outputfile: Union[Path, str, None] = None,
*,
timeout: Optional[int] = None,
Expand Down Expand Up @@ -1695,7 +1705,7 @@ def download(

def execute(
self,
process_graph: Union[dict, str, Path],
process_graph: Union[dict, FlatGraphableMixin, str, Path, List[FlatGraphableMixin]],
*,
timeout: Optional[int] = None,
validate: Optional[bool] = None,
Expand Down Expand Up @@ -1732,7 +1742,7 @@ def execute(

def create_job(
self,
process_graph: Union[dict, str, Path, FlatGraphableMixin],
process_graph: Union[dict, FlatGraphableMixin, str, Path, List[FlatGraphableMixin]],
*,
title: Optional[str] = None,
description: Optional[str] = None,
Expand Down Expand Up @@ -1968,3 +1978,25 @@ def paginate(con: Connection, url: str, params: Optional[dict] = None, callback:
url = next_links[0]["href"]
page += 1
params = {}


def extract_connections(
    data: Union[_ProcessGraphAbstraction, Sequence[_ProcessGraphAbstraction], Any]
) -> Set[Connection]:
    """
    Collect the :py:class:`Connection` object(s) referenced by a given data construct.

    Typical use case is getting the connection of a single :py:class:`DataCube`,
    but a list, tuple or set of such objects is supported too,
    in which case the connections of all items are gathered.
    Objects that are not a ``_ProcessGraphAbstraction`` or that have no connection are ignored.

    :param data: single process graph abstraction, a list/tuple/set of these, or anything else
    :return: set of connections found (possibly empty)
    """
    # TODO: define some kind of "Connected" interface/mixin/protocol
    #       for objects that contain a connection instead of just checking for _ProcessGraphAbstraction
    # TODO: also support extracting connections from other objects like BatchJob, ...
    if isinstance(data, _ProcessGraphAbstraction) and data.connection:
        return {data.connection}

    if isinstance(data, (list, tuple, set)):
        return {
            item.connection
            for item in data
            if isinstance(item, _ProcessGraphAbstraction) and item.connection
        }

    return set()
74 changes: 74 additions & 0 deletions openeo/rest/multiresult.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from __future__ import annotations

from typing import Dict, List, Optional

from openeo import BatchJob
from openeo.internal.graph_building import FlatGraphableMixin, MultiLeafGraph
from openeo.rest import OpenEoClientException
from openeo.rest.connection import Connection, extract_connections


class MultiResult(FlatGraphableMixin):
    """
    Adapter to create/run batch jobs from process graphs with multiple end/result/leaf nodes.

    Usage example:

    .. code-block:: python

        cube1 = ...
        cube2 = ...
        multi_result = MultiResult([cube1, cube2])
        job = multi_result.create_job()
    """

    def __init__(self, leaves: List[FlatGraphableMixin], connection: Optional[Connection] = None):
        """
        :param leaves: leaf/result nodes to combine in a single process graph
        :param connection: optional explicit connection to use;
            has to match the connection found on the leaves (if any)
        :raises OpenEoClientException: when no connection, or more than one distinct connection, is found
        """
        # Materialize once: `leaves` is walked twice below (graph building and connection lookup),
        # so a one-shot iterable would be exhausted after the first pass,
        # and the connection lookup only inspects list/tuple/set containers.
        leaves = list(leaves)
        self._multi_leaf_graph = MultiLeafGraph(leaves=leaves)
        self._connection = self._common_connection(leaves=leaves, connection=connection)

    @staticmethod
    def _common_connection(leaves: List[FlatGraphableMixin], connection: Optional[Connection] = None) -> Connection:
        """
        Find the single connection shared by the given leaves and the (optional) explicit connection.

        :raises OpenEoClientException: when there is no connection at all, or multiple different ones
        """
        connections = set()
        if connection:
            connections.add(connection)
        connections.update(extract_connections(leaves))

        if len(connections) == 1:
            return connections.pop()
        elif len(connections) == 0:
            raise OpenEoClientException("No connection in any of the MultiResult leaves")
        else:
            raise OpenEoClientException("MultiResult with multiple different connections")

    def flat_graph(self) -> Dict[str, dict]:
        """Build the flat graph representation of the combined multi-leaf process graph."""
        return self._multi_leaf_graph.flat_graph()

    def create_job(
        self,
        *,
        title: Optional[str] = None,
        description: Optional[str] = None,
        job_options: Optional[dict] = None,
        validate: Optional[bool] = None,
    ) -> BatchJob:
        """
        Create a batch job for the combined process graph through the common connection.

        :param title: optional job title
        :param description: optional job description
        :param job_options: optional job options, passed as ``additional`` to ``Connection.create_job``
        :param validate: optional validation toggle, passed on to ``Connection.create_job``
        :return: the batch job as returned by ``Connection.create_job``
        """
        return self._connection.create_job(
            process_graph=self._multi_leaf_graph,
            title=title,
            description=description,
            additional=job_options,
            validate=validate,
        )

    def execute_batch(
        self,
        *,
        title: Optional[str] = None,
        description: Optional[str] = None,
        job_options: Optional[dict] = None,
        validate: Optional[bool] = None,
    ) -> BatchJob:
        """
        Create a batch job (see :py:meth:`create_job`) and run it with ``BatchJob.run_synchronous()``.

        :return: the result of ``run_synchronous()`` on the created job
        """
        job = self.create_job(title=title, description=description, job_options=job_options, validate=validate)
        return job.run_synchronous()
3 changes: 2 additions & 1 deletion openeo/rest/udp.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import typing
from pathlib import Path
from typing import List, Optional, Union

from openeo.api.process import Parameter
Expand All @@ -16,7 +17,7 @@


def build_process_dict(
process_graph: Union[dict, FlatGraphableMixin],
process_graph: Union[dict, FlatGraphableMixin, Path, List[FlatGraphableMixin]],
process_id: Optional[str] = None,
summary: Optional[str] = None,
description: Optional[str] = None,
Expand Down
Loading

0 comments on commit e1e5e09

Please sign in to comment.