This repository has been archived by the owner on Nov 6, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* feat(mlflow): Add model/model version mapping Co-authored-by: knikitiuk <[email protected]>
- Loading branch information
1 parent
d48ef66
commit e8a8816
Showing
24 changed files
with
793 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
platform_host_url: http://localhost:8080 | ||
default_pulling_interval: 10 | ||
token: "" | ||
plugins: | ||
- type: mlflow | ||
name: mlflow_adapter | ||
dev_mode: False | ||
tracking_uri: str | ||
registry_uri: str | ||
filter_experiments: None # List of experiment names to filter; if omitted, fetches all experiments
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
# ml-flow | ||
|
||
## Preview: | ||
- [Keyword Definition](#keyword-definition) | ||
- [Available parameters](#available-parameters) | ||
- [Current client flow](#current-client-flow) | ||
- [Job info specifics](#job-info-specifics) | ||
|
||
## Keyword Definition | ||
An experiment (pipeline) is a sequence of jobs (runs).
|
||
A job (run) is a transformation task or method applied to input data.
|
||
Experiment can be composed of several jobs or a single one. Job can include one or a set of transformations. | ||
|
||
## Available Parameters | ||
Each Experiment in mlflow has: | ||
|
||
| Parameter | | ||
|----------------------------------------| | ||
| artifact_location | | ||
| creation_time | | ||
| experiment_id | | ||
| last_update_time | | ||
| lifecycle_stage | | ||
| name | | ||
| tags |
|
||
Each Run(job) in mlflow has: | ||
|
||
| Parameter | | ||
|----------------------------| | ||
| metrics | | ||
| params | | ||
| tags | | ||
| last_update_time | | ||
| lifecycle_stage | | ||
| artifact_uri | | ||
| end_time | | ||
| experiment_id | | ||
| lifecycle_stage | | ||
| run_id | | ||
| run_name | | ||
| run_uuid | | ||
| start_time | | ||
| status | | ||
| user_id | | ||
|
||
|
||
## Current client flow | ||
The MlFlow client returns two types of entities: the Experiment entity and the Job entity.
|
||
1. Client connects to mlflow tracking uri by link that is set in [config](https://github.com/opendatadiscovery/odd-collector/blob/40d218a0ab8d0644884f06b5b55577094577ba48/odd_collector/domain/plugin.py) as host. | ||
|
||
2. Client requests necessary experiments by name in case the list was specified in [config](https://github.com/opendatadiscovery/odd-collector/blob/40d218a0ab8d0644884f06b5b55577094577ba48/odd_collector/domain/plugin.py) as pipelines, | ||
otherwise we request the full list of experiments. | ||
|
||
3. For each experiment fetch list of jobs and general information with the list of parameters that could be found [here](odd_collector/adapters/mlflow/domain/job.py) | ||
|
||
## Job info specifics | ||
|
||
1. Request list of jobs for each experiment_id. | ||
(An experiment's jobs are requested as a dataframe. The list of jobs (runs) that belong to the specified experiment is generated while iterating over its rows.)
|
||
2. For each job_id get job_info (run_info) and request list of artifacts. | ||
(As we can't fetch input and output artifacts separately, we request the list of artifacts for each job.
If an artifact is a folder, we iterate through its files and append each one to the list.)
|
||
If it is necessary to specify exactly which artifacts are input/output, log them in your experiment as params with the keys 'input_artifacts' and 'output_artifacts'.
|
||
For instance: mlflow.log_param('input_artifacts',['https://']) | ||
|
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
from typing import Dict, Type | ||
from urllib.parse import urlparse | ||
|
||
from funcy import lconcat, lpluck_attr | ||
from odd_collector_sdk.domain.adapter import AbstractAdapter | ||
from odd_models.models import DataEntity, DataEntityList | ||
|
||
from odd_collector.domain.plugin import MlflowPlugin | ||
|
||
from .client import Client, MlflowClientBase | ||
from .generator import MlFlowGenerator | ||
from .mappers import map_experiment, map_model, map_model_version, map_run | ||
|
||
|
||
class Adapter(AbstractAdapter):
    """Collects MLFlow experiments, runs, registered models and model versions
    and maps them to ODD data entities.
    """

    def __init__(self, config: MlflowPlugin, client: Type[MlflowClientBase] = Client):
        self.config = config
        self.client = client(config)
        # The oddrn generator is keyed by the tracking server host only.
        self.generator = MlFlowGenerator(
            host_settings=urlparse(config.tracking_uri).netloc
        )

    def get_data_source_oddrn(self) -> str:
        return self.generator.get_data_source_oddrn()

    def get_data_entity_list(self) -> DataEntityList:
        """Build the full entity list: runs, experiments, model versions, models.

        Each model version is linked as an output of the run that produced it
        (when that run was fetched) and collected into its model's entity group.
        """
        experiment_entities: Dict[str, DataEntity] = {}
        runs_entities: Dict[str, DataEntity] = {}
        models_entities: Dict[str, DataEntity] = {}
        model_versions_entities: Dict[str, DataEntity] = {}

        for experiment in self.client.get_experiments():
            self.generator.set_oddrn_paths(experiments=experiment.name)

            # run_id -> DataEntity for every run of this experiment.
            runs: Dict[str, DataEntity] = dict(
                map_run(self.generator, run) for run in experiment.runs
            )

            experiment_id, experiment_entity = map_experiment(
                self.generator,
                lpluck_attr("oddrn", runs.values()),
                experiment,
            )

            experiment_entities[experiment_id] = experiment_entity
            runs_entities.update(runs)

        for model in self.client.get_models():
            self.generator.set_oddrn_paths(models=model.name)

            model_entity: DataEntity = map_model(self.generator, [], model)

            for model_version in model.model_versions:
                model_version_entity = map_model_version(self.generator, model_version)

                # Link the version to the run that produced it, if we saw that run.
                if model_version.run_id in runs_entities:
                    runs_entities[model_version.run_id].data_transformer.outputs.append(
                        model_version_entity.oddrn
                    )
                model_entity.data_entity_group.entities_list.append(
                    model_version_entity.oddrn
                )

                model_versions_entities[model_version.full_name] = model_version_entity

            # Fix: model entities were previously stored in model_versions_entities
            # (keyed by model.name), which left models_entities always empty and
            # risked key collisions with version full names.
            models_entities[model.name] = model_entity

        return DataEntityList(
            data_source_oddrn=self.get_data_source_oddrn(),
            items=lconcat(
                runs_entities.values(),
                experiment_entities.values(),
                model_versions_entities.values(),
                models_entities.values(),
            ),
        )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,186 @@ | ||
import logging | ||
import tempfile | ||
from abc import ABC | ||
from functools import partial | ||
from typing import Callable, Iterable, Optional, TypeVar | ||
|
||
import mlflow | ||
from mlflow.entities import Experiment as MlFlowExperiment | ||
from mlflow.entities import FileInfo | ||
from mlflow.entities import Run as MlFLowRun | ||
from mlflow.entities.model_registry import ModelVersion as MlFlowModelVersion | ||
from mlflow.entities.model_registry import RegisteredModel | ||
from mlflow.exceptions import MlflowException | ||
from mlflow.store.entities import PagedList | ||
from odd_collector_sdk.errors import DataSourceError | ||
|
||
from ...domain.plugin import MlflowPlugin | ||
from .domain.experiment import Experiment | ||
from .domain.model import Model | ||
from .domain.model_version import ModelVersion | ||
from .domain.odd_metadata import OddMetadata | ||
from .domain.run import Run | ||
|
||
T = TypeVar("T", MlFlowExperiment, MlFLowRun, RegisteredModel, MlFlowModelVersion) | ||
|
||
|
||
class MlflowClientBase(ABC):
    """Abstract interface for clients that fetch MLFlow entities."""

    def __init__(self, config: MlflowPlugin):
        # Plugin configuration (tracking/registry URIs, experiment filter, ...).
        self.config = config

    def get_experiments(self) -> Iterable[Experiment]:
        """Yield domain Experiments together with their runs."""
        raise NotImplementedError

    def get_models(self) -> Iterable[Model]:
        """Yield registered domain Models together with their versions."""
        raise NotImplementedError
|
||
|
||
class Client(MlflowClientBase):
    """Concrete MLFlow client built on top of mlflow.MlflowClient."""

    # Optional run artifact with explicit input/output lists logged by the user.
    METADATA_FILE_NAME = "odd_metadata.json"

    def __init__(self, config: MlflowPlugin):
        super().__init__(config)
        try:
            self._client = mlflow.MlflowClient(config.tracking_uri, config.registry_uri)
            # mlflow.artifacts.download_artifacts uses the globally active URIs,
            # so they are set in addition to the client instance.
            mlflow.set_tracking_uri(config.tracking_uri)
            mlflow.set_registry_uri(config.registry_uri)
        except MlflowException as e:
            raise DataSourceError from e

    def get_experiments(self) -> Iterable[Experiment]:
        """Load and map mlflow Experiments with runs.

        Returns:
            Iterable[Experiment]: mapped to domain Experiments
        """
        names = self.config.filter_experiments
        search_experiments = partial(
            self._client.search_experiments, filter_string=self._filter_string()
        )

        for experiment in self._fetch(fn=search_experiments):
            # Client-side name filter: MLFlow filter strings only support
            # AND-joined clauses, so several alternative names cannot be
            # expressed server-side (see _filter_string).
            if names is not None and experiment.name not in names:
                continue
            runs = list(self._get_runs(experiment))
            yield Experiment.from_mlflow(experiment, runs)

    def _filter_string(self) -> Optional[str]:
        """Server-side filter string, usable only for a single experiment name.

        Bug fix: the previous implementation AND-joined several
        `name == '...'` clauses, which can never match any experiment when
        more than one name is configured. With multiple (or zero) names we
        return None and filter client-side in get_experiments instead.

        Returns:
            Optional[str]
        """
        names = self.config.filter_experiments
        if not names or len(names) > 1:
            return None
        return f"name == '{names[0]}'"

    def get_models(self) -> Iterable[Model]:
        """Search registered models and all model versions for them (not only latest).

        Returns:
            Iterable[Model]: mapped to domain Models
        """
        search_models = partial(self._client.search_registered_models)

        for model in self._fetch(fn=search_models):
            model = Model.from_mlflow(model)
            model.model_versions = list(self._get_model_versions_by(model.name))

            yield model

    def _get_model_versions_by(self, model_name: str) -> Iterable[ModelVersion]:
        """Search all model versions by model name.

        Args:
            model_name (str): MLFlow Model's name
        Returns:
            Iterable[ModelVersion]: mapped to domain ModelVersions
        """
        search_model_versions = partial(
            self._client.search_model_versions, filter_string=f"name='{model_name}'"
        )

        for mv in self._fetch(fn=search_model_versions):
            yield ModelVersion.from_mlflow(mv)

    def _get_runs(self, experiment: MlFlowExperiment) -> Iterable[Run]:
        """Yield domain Runs (with artifacts and optional ODD metadata) for an experiment."""
        search_runs = partial(
            self._client.search_runs, experiment_ids=[experiment.experiment_id]
        )

        for run in self._fetch(fn=search_runs):
            artifacts = self._get_artifacts(run.info.run_id)
            odd_artifacts = self._load_odd_artifact(run.info.run_id)

            yield Run.from_mlflow(run, list(artifacts), odd_artifacts)

    def _get_artifacts(self, run_id: str) -> Iterable[FileInfo]:
        """
        Collect the list of all artifacts, expanding directories depth-first.
        MLFlow cannot fetch input/output artifacts separately, so every file
        is collected.

        Args:
            run_id: str
        Returns:
            list of artifacts for the specified run_id
        """

        def _walk(entry: FileInfo) -> Iterable[FileInfo]:
            # Renamed from the shadowed `file_info` loop variable for clarity.
            if entry.is_dir:
                for child in self._client.list_artifacts(run_id, path=entry.path):
                    yield from _walk(child)
            else:
                yield entry

        for entry in self._client.list_artifacts(run_id):
            yield from _walk(entry)

    def _fetch(self, fn: Callable[[Optional[str]], PagedList[T]]) -> Iterable[T]:
        """Helper fetching paginated resources by following the page token.

        Args:
            fn (function): MLFlow partially applied search function
        Returns:
            Iterable[T]: every item from every page
        """
        first_page: PagedList = fn()
        yield from first_page

        token = first_page.token
        while token:
            next_page = fn(page_token=token)

            yield from next_page

            token = next_page.token

    def _load_odd_artifact(self, run_id: str) -> OddMetadata:
        """
        When the MLFlow user logged additional information for a Run as:
            mlflow.log_dict(
                {"inputs": ["s3://training/wine/winequality-red.csv"], "outputs": []},
                "odd_metadata.json",
            )
        the adapter tries to find the artifact named odd_metadata.json,
        downloads it to a temporary directory and parses it into OddMetadata.
        The artifact is optional: any failure falls back to an empty OddMetadata.
        """
        try:
            with tempfile.TemporaryDirectory() as tmp_dir:
                file_path = mlflow.artifacts.download_artifacts(
                    run_id=run_id,
                    artifact_path=self.METADATA_FILE_NAME,
                    dst_path=tmp_dir,
                )

                return OddMetadata.parse_file(file_path)
        except Exception:
            # Fix: the original passed the exception object as a %-style arg
            # without a placeholder in the format string; exc_info=True already
            # records the traceback.
            logging.debug(
                "Could not read metadata file %s for run %s",
                self.METADATA_FILE_NAME,
                run_id,
                exc_info=True,
            )
            return OddMetadata()
Empty file.
Oops, something went wrong.