import os
class FileHandler:
    def __init__(self, file_path, content):
        self.file_path = file_path
        # Create the file when the object is created
        with open(self.file_path, 'w') as f:
            f.write(content)
        print(f"File created: {self.file_path}")

    def __del__(self):
        # Delete the file when the object is about to be garbage collected
        if os.path.exists(self.file_path):
            os.remove(self.file_path)
            print(f"File deleted: {self.file_path}")
# Example usage
file_handler = FileHandler("example.txt", "This is some content.")
# The file will be deleted automatically when the object is garbage collected
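Note that `__del__` only runs when the object is garbage collected, and the timing is not guaranteed (reference cycles and interpreter shutdown can delay or even skip it). For deterministic cleanup, a context manager is an alternative; a minimal sketch:

import os
from contextlib import contextmanager

@contextmanager
def temporary_file(file_path, content):
    # Create the file on entry
    with open(file_path, "w") as f:
        f.write(content)
    try:
        yield file_path
    finally:
        # Always delete the file on exit, even if an exception was raised
        if os.path.exists(file_path):
            os.remove(file_path)

with temporary_file("example.txt", "This is some content.") as path:
    print(f"Working with {path}")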
import pickle
import tempfile
import os
class MyModel:
    def __init__(self, data, model):
        self.data = data
        self.model = model  # This model has .save_model() and .load_model()

    def __getstate__(self):
        state = self.__dict__.copy()
        # Use a temporary file to save the model
        tmp = tempfile.NamedTemporaryFile(delete=False)
        tmp.close()  # Close the handle so save_model can write to the path (required on Windows)
        try:
            self.model.save_model(tmp.name)
            with open(tmp.name, 'rb') as f:
                state['model'] = f.read()  # Read the binary content of the file
        finally:
            os.unlink(tmp.name)  # Clean up the temp file
        return state

    def __setstate__(self, state):
        # Write the binary data to a temporary file and load the model from it
        with tempfile.NamedTemporaryFile(delete=False) as tmp:
            tmp.write(state['model'])
            tmp_path = tmp.name
        try:
            model = SomeModelClass()  # Instantiate the model class
            model.load_model(tmp_path)
        finally:
            os.unlink(tmp_path)  # Clean up the temp file
        state['model'] = model
        self.__dict__.update(state)
# Example usage
data = 'some data'
model = SomeModelClass() # This class would have .save_model() and .load_model()
obj = MyModel(data, model)
# Pickle the object
with open('my_object.pkl', 'wb') as f:
    pickle.dump(obj, f)
# Unpickle the object
with open('my_object.pkl', 'rb') as f:
    loaded_obj = pickle.load(f)
import pickle
import io
class MyModel:
    def __init__(self, data, model):
        self.data = data
        self.model = model  # Assume this attribute has .save_model() and .load_model()

    def __getstate__(self):
        state = self.__dict__.copy()
        # Save the model to a BytesIO stream
        model_stream = io.BytesIO()
        self.model.save_model(model_stream)
        model_stream.seek(0)  # Rewind (needed if you read(); getvalue() returns the whole buffer regardless)
        state['model'] = model_stream.getvalue()  # Save the byte data
        return state

    def __setstate__(self, state):
        # Load the model from the byte data
        model_stream = io.BytesIO(state['model'])
        model = SomeModelClass()  # Assuming you have a way to instantiate it
        model.load_model(model_stream)
        state['model'] = model
        self.__dict__.update(state)
# Example usage
data = 'some data'
model = SomeModelClass() # This class would have .save_model() and .load_model()
obj = MyModel(data, model)
# Pickle the object
with open('my_object.pkl', 'wb') as f:
    pickle.dump(obj, f)
# Unpickle the object
with open('my_object.pkl', 'rb') as f:
    loaded_obj = pickle.load(f)
import polars as pl
def sample_groups(df: pl.DataFrame, group_col: str, n: int) -> pl.DataFrame:
    """
    Take `n` samples from each group in a specified categorical column of a Polars DataFrame.

    Parameters:
        df (pl.DataFrame): The input Polars DataFrame.
        group_col (str): The column name to group by.
        n (int): The number of samples to take from each group.

    Returns:
        pl.DataFrame: A new DataFrame with `n` samples from each group.
    """
    sampled_dfs = []
    # Get the unique groups in group_col
    groups = df[group_col].unique()
    for group in groups:
        # Filter the DataFrame down to the current group
        group_df = df.filter(pl.col(group_col) == group)
        # Sample n rows from the current group (with replacement, so n may exceed the group size)
        sampled_group_df = group_df.sample(n, with_replacement=True)
        sampled_dfs.append(sampled_group_df)
    # Concatenate all the sampled DataFrames
    return pl.concat(sampled_dfs)
# Example usage:
df = pl.DataFrame({
    "category": ["A", "A", "A", "B", "B", "B", "C", "C", "C"],
    "value": [1, 2, 3, 4, 5, 6, 7, 8, 9],
})
# Take 2 samples from each group in the "category" column
sampled_df = sample_groups(df, "category", 2)
print(sampled_df)
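For reproducible samples, pass a `seed` to `sample`. On recent Polars versions the same per-group sampling (without replacement) can also be expressed as a single window expression, avoiding the Python loop; a sketch:

n = 2
sampled_df = df.filter(
    # Shuffle row indices within each group and keep the first n rows per group
    pl.int_range(pl.len()).shuffle(seed=42).over("category") < n
)
print(sampled_df)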
import pytest
import polars as pl
from google.cloud import bigquery
from unittest.mock import patch, MagicMock
def test_execute_sql_file():
    # Mock the BigQuery client
    with patch('google.cloud.bigquery.Client', return_value=MagicMock()) as mock_client:
        # Mock the read_file_contents function
        with patch('your_module.read_file_contents', return_value="SELECT * FROM table WHERE date >= '{start_date}'") as mock_read_file:
            # Mock the query_polars_df function
            with patch('your_module.query_polars_df', return_value=pl.DataFrame({"column1": [1, 2, 3]})) as mock_query:
                from your_module import PROJECT_ID, execute_sql_file

                # Call the function with the SQL file and kwargs
                df = execute_sql_file("query.sql", start_date='2024-01-01')
                # Check that the BigQuery client was instantiated
                mock_client.assert_called_once_with(project=PROJECT_ID)
                # Check that the SQL file was read
                mock_read_file.assert_called_once_with("query.sql")
                # Check that the query was executed with the formatted SQL
                mock_query.assert_called_once_with("SELECT * FROM table WHERE date >= '2024-01-01'", client=mock_client.return_value)
                # Check that the returned DataFrame is correct
                assert df.shape == (3, 1)
                assert df.columns == ["column1"]
                assert df["column1"].to_list() == [1, 2, 3]
Process Overview

1. Calculating RFM (Recency, Frequency, Monetary value): you compute RFM values for your customer database over a specified period.
2. Using models for prediction: you fit the Beta-Geometric/Negative Binomial Distribution (BG/NBD) model to predict customer churn, and the Gamma-Gamma model to predict monetary value over a future horizon (e.g., the next 3 months).
3. Testing predictions: you test the models by comparing their predictions to actual values for the past 3 months. This requires data for two periods: the actuals for the past 3 months, and the year before that (the training or reference period).
4. Updating models for production: after validation, you retrain the models on the latest data and deploy them for production use.
Addressing the Concerns

Validation Process
- Historical validation: it is correct to validate your models on historical data. For example, in January 2024 you could train on data from January 2023 to September 2023, predict October 2023 to December 2023, and compare against the actuals.
- Real-time validation: once validated, you retrain the models on the latest data for production use. These models cannot be validated immediately, since the future data they predict has not arrived yet.

Model Retraining and Updating
- Retraining models on the most recent data is standard practice and ensures they capture the latest trends and behaviors.
- These updated models, although trained on the latest data, can only be validated once the future period has elapsed and predictions can be compared against actuals.
Best Practices for Training and Validation

Backtesting
- Use historical data to simulate predictions and compare them against actual outcomes; this validates the model's performance before deployment.
- Example: train on data from January 2022 to September 2022, predict October 2022 to December 2022, and compare against the actuals for those months, as in the sketch below.
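A minimal backtesting sketch using the `lifetimes` library, assuming a `transactions` DataFrame with `customer_id` and `order_date` columns (the column names and penalizer value are illustrative):

import pandas as pd
from lifetimes import BetaGeoFitter
from lifetimes.utils import calibration_and_holdout_data

def backtest(transactions: pd.DataFrame, calibration_end: str, holdout_end: str) -> pd.DataFrame:
    # Split each customer's history into a calibration (training) and a holdout period
    summary = calibration_and_holdout_data(
        transactions,
        "customer_id",
        "order_date",
        calibration_period_end=calibration_end,
        observation_period_end=holdout_end,
    )
    # Fit BG/NBD on the calibration period only
    bgf = BetaGeoFitter(penalizer_coef=0.01)
    bgf.fit(summary["frequency_cal"], summary["recency_cal"], summary["T_cal"])
    # Predict purchases over the holdout window and keep the actuals for comparison
    summary["predicted_purchases"] = bgf.conditional_expected_number_of_purchases_up_to_time(
        summary["duration_holdout"],
        summary["frequency_cal"],
        summary["recency_cal"],
        summary["T_cal"],
    )
    return summary[["predicted_purchases", "frequency_holdout"]]

# Train on Jan-Sep 2022, validate on Oct-Dec 2022:
# results = backtest(transactions, "2022-09-30", "2022-12-31")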
Rolling Window Validation
- Use a rolling-window approach to continually validate the model: evaluate it over multiple 3-month holdout periods in the past, not just the most recent one (see the loop sketched below).
- This checks the model's robustness across different periods and reduces the risk of overfitting to a particular window.
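A rolling-window sketch, reusing the hypothetical `backtest` helper from the previous example:

# Evaluate several 3-month holdout windows, not just the most recent one
windows = [
    ("2022-03-31", "2022-06-30"),
    ("2022-06-30", "2022-09-30"),
    ("2022-09-30", "2022-12-31"),
]
for calibration_end, holdout_end in windows:
    results = backtest(transactions, calibration_end, holdout_end)
    mae = (results["predicted_purchases"] - results["frequency_holdout"]).abs().mean()
    print(f"holdout ending {holdout_end}: MAE = {mae:.3f}")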
Continuous Monitoring
- Once the models are in production, continuously monitor their performance and re-validate them as new data becomes available.
- Set up a mechanism to automatically retrain and validate the models at regular intervals (e.g., monthly or quarterly).
import logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
formatter = logging.Formatter(
    fmt="%(levelname)s (%(asctime)s): %(message)s (func: %(funcName)s line: %(lineno)d [%(filename)s])",
    datefmt="%H:%M:%S",
)
# TODO move logging files to a log folder in root
# Print DEBUG level to console
stream_handler = logging.StreamHandler()
stream_handler.setFormatter(formatter)
logger.addHandler(stream_handler)
logger.debug(f"Print hi from crypto_labelling {__name__=}")
class ColorFormatter(logging.Formatter):
    # ANSI color codes per level name (the custom levels such as FINISHED assume
    # logging.addLevelName was called elsewhere)
    COLORS = {
        "DEBUG": "\033[34m",  # Blue
        "INFO": "\033[0m",  # Default terminal color
        "FINISHED": "\033[92m",  # Green
        "WARNING": "\033[93m",  # Yellow
        "ERROR": "\033[91m",  # Red
        "CRITICAL": "\033[1;31m",  # Bold Red
        "METRICS_LEVEL": "\033[36m",  # Cyan
        "PARAMS_LEVEL": "\033[95m",  # Bright Magenta
        "ARTIFACT_LEVEL": "\033[35m",  # Magenta
    }
    RESET = "\033[0m"  # Reset code

    def format(self, record):
        levelname = record.levelname
        message = super().format(record)
        color = self.COLORS.get(levelname, self.RESET)
        return f"{color}{message}{self.RESET}"

formatter = ColorFormatter(
    fmt="%(levelname)s (%(asctime)s): %(message)s (func: %(funcName)s line: %(lineno)d [%(filename)s])",
    datefmt="%H:%M:%S",
)
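To actually apply the colors, reattach this formatter to the console handler defined earlier:

stream_handler.setFormatter(formatter)
logger.info("Colored console logging enabled")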
# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "psutil",
#     "snakeviz",
# ]
# ///
import argparse
import cProfile
import os
import subprocess
import threading
import time
import psutil
def log_resource_usage(interval=1, output_file="resource_usage.log"):
    """Log total and per-core CPU usage plus memory usage (percent and MB) every `interval` seconds."""
    with open(output_file, "w") as f:
        core_columns = "\t".join(
            f"Core {i}[%]" for i in range(psutil.cpu_count())
        )
        f.write(f"Total CPU[%]\tMemory[%]\tMemory[MB]\t{core_columns}\tTime\n")
        while True:
            # Overall CPU and memory usage
            total_cpu_usage = psutil.cpu_percent()
            memory = psutil.virtual_memory()
            memory_usage_percent = memory.percent
            memory_usage_mb = memory.used / (1024**2)  # Convert to MB
            # Per-core CPU usage
            per_core_cpu_usage = psutil.cpu_percent(percpu=True)
            per_core_cpu_str = "\t".join(map(str, per_core_cpu_usage))
            # Current timestamp
            timestamp = time.strftime("%Y-%m-%d %H:%M:%S")
            # Log the data to the file
            f.write(
                f"{total_cpu_usage}\t{memory_usage_percent}\t{memory_usage_mb:.2f}\t{per_core_cpu_str}\t{timestamp}\n"
            )
            f.flush()
            time.sleep(interval)
def profile_program(script, profiler_file="output.pstats"):
    """Runs the specified program with cProfile."""
    # cProfile.run(f"exec(open('{file}').read())", profiler_file)
    subprocess.run(
        ["python", "-m", "cProfile", "-o", profiler_file, script], check=True
    )
if __name__ == "__main__":
    # Set up the CLI argument parser
    parser = argparse.ArgumentParser(
        description="Profile a Python script and log resource usage."
    )
    parser.add_argument("script", help="The Python script to execute and profile.")
    parser.add_argument(
        "--profiler_file",
        default="logs/output.pstats",
        help="File to save profiler results.",
    )
    parser.add_argument(
        "--resource_file",
        default="logs/resource_usage.log",
        help="File to log CPU and memory usage.",
    )
    parser.add_argument(
        "--interval",
        type=int,
        default=1,
        help="Interval (in seconds) between resource usage logs.",
    )
    args = parser.parse_args()

    # Make sure the output directories exist (the defaults live under logs/)
    for path in (args.profiler_file, args.resource_file):
        directory = os.path.dirname(path)
        if directory:
            os.makedirs(directory, exist_ok=True)

    # Start resource usage logging in a separate thread
    resource_thread = threading.Thread(
        target=log_resource_usage,
        kwargs=dict(interval=args.interval, output_file=args.resource_file),
    )
    resource_thread.daemon = True  # Daemonize the thread so it exits with the program
    resource_thread.start()

    # Profile the main program
    profile_program(args.script, args.profiler_file)
    print(
        "Finished execution. To visualize the results in an interactive window, run:\n"
        f"uvx snakeviz {args.profiler_file}"
    )
# Target for building the Docker image
# NOTE: VAR_NAME below is a placeholder build-arg name; align it with the ARG in your Dockerfile
.PHONY: build
build:
	@echo "Fetching environment variables from Python script..."
	IMAGE_NAME=$$(python3 -c 'from env_vars import IMAGE_NAME; print(IMAGE_NAME)') && \
	ENV_VALUE=$$(python3 -c 'from env_vars import ENV_VALUE; print(ENV_VALUE)') && \
	DOCKER_BUILDKIT=1 docker build --no-cache --progress=plain -t $$IMAGE_NAME -f Dockerfile --build-arg VAR_NAME=$$ENV_VALUE .
# Target for cleaning up
.PHONY: clean
clean:
	@echo "Cleaning Docker build artifacts..."
	docker system prune -f
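The build target assumes an `env_vars.py` module importable from the repo root; a minimal sketch (values are illustrative):

# env_vars.py -- single source of truth for build-time variables
IMAGE_NAME = "my-app:latest"
ENV_VALUE = "production"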
import atexit
import json
import logging
import os
import pickle
from os.path import splitext
from typing import Any, Callable, Optional
import mlflow
import mlflow.tracking.client
import pandas as pd
import plotly.graph_objects as go
import yaml
from mlflow.entities.experiment import Experiment
from mlflow.entities.model_registry.model_version import ModelVersion
from mlflow.entities.run import Run
from mlflow.models.model import ModelInfo
from mlflow.pyfunc import PythonModel
from mlflow.tracking.client import MlflowClient
from consts import EXPERIMENT_TAGS
from src.utils.search_file import search_file
logger = logging.getLogger("models")
LOAD_FUNCTIONS = {
    "catboost": mlflow.catboost.load_model,
    "fastai": mlflow.fastai.load_model,
    "gluon": mlflow.gluon.load_model,
    "h2o": mlflow.h2o.load_model,
    "keras": mlflow.keras.load_model,
    "lightgbm": mlflow.lightgbm.load_model,
    "onnx": mlflow.onnx.load_model,
    "pyfunc": mlflow.pyfunc.load_model,
    "pytorch": mlflow.pytorch.load_model,
    "sklearn": mlflow.sklearn.load_model,
    "spacy": mlflow.spacy.load_model,
    "spark": mlflow.spark.load_model,
    "statsmodels": mlflow.statsmodels.load_model,
    "tensorflow": mlflow.tensorflow.load_model,
    "xgboost": mlflow.xgboost.load_model,
    "paddle": mlflow.paddle.load_model,
    "prophet": mlflow.prophet.load_model,
    "pmdarima": mlflow.pmdarima.load_model,
}
LOG_FUNCTIONS = {
    "catboost": mlflow.catboost.log_model,
    "fastai": mlflow.fastai.log_model,
    "gluon": mlflow.gluon.log_model,
    "h2o": mlflow.h2o.log_model,
    "keras": mlflow.keras.log_model,
    "lightgbm": mlflow.lightgbm.log_model,
    "onnx": mlflow.onnx.log_model,
    "pyfunc": mlflow.pyfunc.log_model,
    "pytorch": mlflow.pytorch.log_model,
    "sklearn": mlflow.sklearn.log_model,
    "spacy": mlflow.spacy.log_model,
    "spark": mlflow.spark.log_model,
    "statsmodels": mlflow.statsmodels.log_model,
    "tensorflow": mlflow.tensorflow.log_model,
    "xgboost": mlflow.xgboost.log_model,
    "paddle": mlflow.paddle.log_model,
    "prophet": mlflow.prophet.log_model,
    "pmdarima": mlflow.pmdarima.log_model,
}
@search_file
def save_pickle(
    model: Any,
    local_model_path: str,
):
    with open(local_model_path, "wb") as f:
        pickle.dump(model, f)
class InvalidModelName(Exception):
    """Invalid Model Name Exception"""


class ModelFlavorNotSupported(Exception):
    """MLflow Model Flavor Not Supported"""
class MlFlowLogger:
    """
    A class used to interact with MLflow's tracking and model registry components.
    """

    load_functions: dict[str, Callable] = LOAD_FUNCTIONS
    log_functions: dict[str, Callable] = LOG_FUNCTIONS

    def __init__(self, tracking_uri: str | None = None):
        if tracking_uri is not None:
            mlflow.set_tracking_uri(tracking_uri)
        self.client = MlflowClient()  # Initialize the client

    def start_run(self, experiment_name: str, run_name: str | None = None):
        experiment = self.get_or_create_experiment(experiment_name)
        experiment_id = experiment.experiment_id
        self.run = mlflow.start_run(experiment_id=experiment_id, run_name=run_name)
        self.run_id = self.run.info.run_id
        logger.info(f"Starting mlflow run with run_id={self.run_id}")
        atexit.register(self.end_run)
        return self.run_id

    def end_run(self):
        mlflow.end_run()
    def get_experiment_id(self, experiment_name: str) -> str:
        """
        Retrieves the ID of an experiment given its name.

        Parameters
        ----------
        experiment_name : str
            The name of the experiment.

        Returns
        -------
        str
            The ID of the experiment.
        """
        experiment = self.get_experiment(experiment_name)
        return experiment.experiment_id
    def get_experiment(self, experiment_name: str) -> Experiment:
        """
        Retrieves an experiment given its name.

        Parameters
        ----------
        experiment_name : str
            The name of the experiment.

        Returns
        -------
        Experiment
            The retrieved experiment.
        """
        return self.client.get_experiment_by_name(experiment_name)
    def get_experiment_runs(self, experiment_name: str) -> list[Run]:
        """
        Retrieves all runs of an experiment given its name.

        Parameters
        ----------
        experiment_name : str
            The name of the experiment.

        Returns
        -------
        list[Run]
            A list of all runs of the experiment.
        """
        experiment_id = self.get_experiment_id(experiment_name)
        return self.client.search_runs(experiment_id)
    def get_latest_run(self, experiment_name: str) -> Run:
        """
        Retrieves the latest run of an experiment given its name.

        Parameters
        ----------
        experiment_name : str
            The name of the experiment.

        Returns
        -------
        Run
            The latest run of the experiment.
        """
        experiment_id = self.get_experiment_id(experiment_name)
        latest_runs = self.client.search_runs(
            experiment_id, order_by=["attribute.start_time DESC"], max_results=1
        )
        return latest_runs[0]
    def register_latest_model(
        self,
        experiment_name: str,
        model_registry_name: str,
        mlflow_pyfunc_model_path: Optional[str] = None,
    ):
        """
        Registers the latest model of an experiment.

        Parameters
        ----------
        experiment_name : str
            The name of the experiment.
        model_registry_name : str
            The name to register the model under in the model registry.
        mlflow_pyfunc_model_path : Optional[str]
            The artifact path of the model in the MLflow format. If None, it is
            recovered from the run's `mlflow.log-model.history` tag.

        Returns
        -------
        ModelVersion
            The registered model version.
        """
        latest_run = self.get_latest_run(experiment_name)
        run_id = latest_run.info.run_id
        if mlflow_pyfunc_model_path is None:
            run_data = latest_run.data.tags
            [contents] = json.loads(run_data["mlflow.log-model.history"])
            mlflow_pyfunc_model_path = contents["artifact_path"]
        model_version = mlflow.register_model(
            f"runs:/{run_id}/{mlflow_pyfunc_model_path}", model_registry_name
        )
        return model_version
    def get_latest_model(
        self, model_registry_name: str, stages: str | list[str] | None = None
    ) -> ModelVersion:
        """
        Retrieves the latest model from the model registry.

        Parameters
        ----------
        model_registry_name : str
            The name of the model in the model registry.
        stages : str | list[str] | None, optional
            The stages to consider when retrieving the latest model, by default None
            (which falls back to ["Staging", "Production"]).

        Returns
        -------
        ModelVersion
            The latest model version.
        """
        if stages is None:
            stages = ["Staging", "Production"]
        elif isinstance(stages, str):
            stages = [stages]
        latest_models = self.client.get_latest_versions(
            model_registry_name, stages=stages
        )
        latest_models = [
            model for model in latest_models if model.name == model_registry_name
        ]
        if len(latest_models) == 0:
            raise Exception(
                f"No model found matching the description. {model_registry_name=}"
            )
        return max(latest_models, key=lambda x: int(x.version))
    def create_new_experiment(
        self,
        experiment_name: str,
        experiment_description: str,
        experiment_tags: dict[str, str],
    ) -> str:
        """Create an experiment.

        Parameters
        ----------
        experiment_name : str
            The experiment name. Must be unique.
        experiment_description : str
            The experiment description to be displayed in the UI tab.
        experiment_tags : dict[str, str]
            A dictionary of key-value pairs that are converted into
            :py:class:`mlflow.entities.ExperimentTag` objects, set as
            experiment tags upon experiment creation.

        Returns
        -------
        str
            String form of the integer ID of the created experiment.

        Examples
        --------
        .. code-block:: python
            :caption: Example

            # Create a new experiment (if it already exists, a warning is logged
            # and the existing experiment ID is returned).
            experiment_name = "cltv-lifetime-models-nested-runs"
            experiment_description = (
                "Project SAMS-CLTV Segmentation and Lifetime Model predictions. "
                "This project segments the memberships into customer segments and associate_types"
            )
            # Provide searchable tags that define characteristics of the runs in this experiment
            experiment_tags = {
                "project_name": "cltv-rfm-lifetime-models",
                "store": "SAMS",
                "model_groups": "associate_types",
                "team": "wmt-mx-dl-iaml",
                "project_quarter": "Q3-2023",
                "mlflow.note.content": experiment_description,
            }
        """
        retrieved_experiment = self.get_experiment(experiment_name)
        if retrieved_experiment is not None:
            logger.warning(
                f"Experiment '{experiment_name}' already exists: {retrieved_experiment}"
            )
            return retrieved_experiment.experiment_id
        experiment_tags["mlflow.note.content"] = experiment_description
        experiment_id = self.client.create_experiment(
            name=experiment_name, tags=experiment_tags
        )
        return experiment_id
    def get_or_create_experiment(
        self, experiment_name: str, tags: Optional[dict] = None
    ) -> Experiment:
        """
        Retrieves an experiment if it exists, otherwise creates a new one.

        Parameters
        ----------
        experiment_name : str
            The name of the experiment.
        tags : Optional[dict]
            Tags for the experiment if it has to be created; defaults to
            EXPERIMENT_TAGS when not provided.

        Returns
        -------
        Experiment
            The retrieved or created experiment.
        """
        retrieved_experiment = self.get_experiment(experiment_name)
        if retrieved_experiment is not None:
            return retrieved_experiment
        logger.warning(
            f"'{experiment_name}' not found. Creating a new experiment with default metadata."
        )
        if tags is None:
            tags = EXPERIMENT_TAGS
        self.client.create_experiment(name=experiment_name, tags=tags)
        return self.get_experiment(experiment_name)
    def log_model(
        self, model: Any, artifact_path: str, flavor: str, **kwargs
    ) -> ModelInfo:
        if flavor not in self.log_functions:
            raise ModelFlavorNotSupported(f"Model flavor {flavor} is not valid.")
        model_info = self.log_functions[flavor](model, artifact_path, **kwargs)
        return model_info

    def log_pyfunc_model(
        self,
        model: Any,
        local_model_path: str,
        mlflow_pyfunc_model_path: str,
        python_model: type[PythonModel] = PythonModel,
        code_path: list[str] | tuple[str, ...] = ("./src", "./config", "consts.py"),
        **kwargs,
    ):
        if "signature" in kwargs:
            signature = kwargs.pop("signature")
        else:
            signature = getattr(python_model, "signature", None)
        code_path = list(code_path)
        local_model_path = self._assert_local_model_store(
            model, local_model_path, mlflow_pyfunc_model_path
        )
        artifacts = {"model_path": local_model_path}
        model_info = mlflow.pyfunc.log_model(
            artifact_path=mlflow_pyfunc_model_path,
            python_model=python_model(),
            code_path=code_path,
            artifacts=artifacts,
            signature=signature,
            **kwargs,
        )
        return model_info
    def _assert_local_model_store(
        self, model: Any, local_model_path: str, mlflow_pyfunc_model_path: Optional[str]
    ):
        """
        Asserts that the locally stored model is valid.

        Parameters
        ----------
        model : Any
            The model to check.
        local_model_path : str
            The local path to the model (a .pkl file or a directory).
        mlflow_pyfunc_model_path : Optional[str]
            The path to the model in the MLflow format.

        Returns
        -------
        str
            The validated local model path.
        """
        assert not isinstance(model, str)
        # If local_model_path is a directory, construct the filename from
        # mlflow_pyfunc_model_path
        if not local_model_path.endswith(".pkl"):
            if isinstance(mlflow_pyfunc_model_path, str):
                local_model_path = f"{local_model_path}/{mlflow_pyfunc_model_path}.pkl"
            else:
                logger.error(
                    "When passing local_model_path as a directory, "
                    "mlflow_pyfunc_model_path must be provided too."
                )
                raise InvalidModelName(
                    f"Could not reconstruct the model name from {local_model_path=} & {mlflow_pyfunc_model_path=}"
                )
        # If no model object is provided, it must already be saved to disk
        if model is None:
            if os.path.isfile(local_model_path) and local_model_path.endswith(".pkl"):
                return local_model_path
            logger.error(
                "If the model argument is not provided, it must already be saved to disk."
            )
            raise FileNotFoundError(f"Model not found in '{local_model_path}'")
        save_pickle(model=model, local_model_path=local_model_path)
        logger.info(f"Model saved to {local_model_path}. ({model=})")
        return local_model_path
    def register_model(
        self, run_id: str, mlflow_pyfunc_model_path: str, model_registry_name: str
    ):
        """
        Registers a model to the model registry.

        Parameters
        ----------
        run_id : str
            The ID of the run.
        mlflow_pyfunc_model_path : str
            The path to the model in the MLflow format.
        model_registry_name : str
            The name to register the model under in the model registry.

        Returns
        -------
        ModelVersion
            The registered model version.
        """
        model_version = mlflow.register_model(
            f"runs:/{run_id}/{mlflow_pyfunc_model_path}", model_registry_name
        )
        return model_version

    def log_metrics(self, metrics: dict[str, float], step: int | None = None):
        mlflow.log_metrics(metrics, step=step)

    def log_params(self, params: dict[str, Any]):
        mlflow.log_params(params)

    def set_tags(self, tags: dict[str, Any]):
        mlflow.set_tags(tags)
    def register_model_to_stage(
        self,
        run_id: str,
        model_registry_name: str,
        mlflow_pyfunc_model_path: str,
        stage: None | str = None,
    ):
        """
        Registers a model to the MLflow Model Registry and optionally sets its stage.

        Parameters
        ----------
        run_id : str
            The ID of the run that produced the model.
        model_registry_name : str
            The name to register the model under in the model registry.
        mlflow_pyfunc_model_path : str
            The path to the model in the MLflow format.
        stage : str, optional
            The stage to set for the model in the model registry. If `None`,
            no stage transition is performed.

        Returns
        -------
        mlflow.entities.model_registry.ModelVersion
            The registered model version.

        Raises
        ------
        MlflowException
            If an error occurs while registering the model or transitioning its stage.
        """
        model_version = self.register_model(
            run_id, mlflow_pyfunc_model_path, model_registry_name
        )
        if stage is not None:
            model_version = self.promote_model_to_stage(
                model_version.name,
                model_version.version,
                stage=stage,
            )
        return model_version

    def promote_model_to_stage(self, model_name: str, model_version: str, stage: str):
        # Archive older versions only when promoting to Staging or Production
        archive_existing_versions = stage in ["Production", "Staging"]
        model_version = self.client.transition_model_version_stage(
            model_name,
            model_version,
            stage=stage,
            archive_existing_versions=archive_existing_versions,
        )
        return model_version
    def load_latest_model(
        self, model_registry_name: str, model_flavor: str = "pyfunc"
    ) -> mlflow.pyfunc.PyFuncModel:
        latest_model = self.get_latest_model(model_registry_name)
        loaded_model = self.load_functions[model_flavor](
            f"models:/{latest_model.name}/{latest_model.version}"
        )
        return loaded_model

    def load_model_version(self, model_registry_name: str, model_version: str):
        model_info = self.client.get_model_version(model_registry_name, model_version)
        loaded_model = mlflow.pyfunc.load_model(
            f"models:/{model_info.name}/{model_info.version}"
        )
        return loaded_model

    def load_model_from_run(self, run_id: str) -> mlflow.pyfunc.PyFuncModel:
        run_obj = self.client.get_run(run_id)
        tags = run_obj.data.tags
        [model_history] = json.loads(tags["mlflow.log-model.history"])
        artifact_path = model_history["artifact_path"]
        logged_model = f"runs:/{run_id}/{artifact_path}"
        # Load the model as a PyFuncModel
        loaded_model = mlflow.pyfunc.load_model(logged_model)
        return loaded_model
    def log_table(self, df: pd.DataFrame, table_name: str):
        fig = go.Figure(
            data=[
                go.Table(
                    header=dict(
                        values=list(df.columns),
                        fill_color="paleturquoise",
                        align="left",
                    ),
                    cells=dict(
                        values=[df[col] for col in df],
                        fill_color="lavender",
                        align="left",
                    ),
                )
            ]
        )
        if not table_name.endswith(".html"):
            root, _ = splitext(table_name)
            table_name = f"{root}.html"
        mlflow.log_figure(fig, table_name)

    def log_parquet(self, df: pd.DataFrame, local_path: str, artifact_path: str):
        df.to_parquet(local_path, index=False)
        mlflow.log_artifact(local_path, artifact_path)

    def log_yaml(self, data: Any, local_path: str, artifact_path: str):
        with open(local_path, "w", encoding="utf-8") as file:
            yaml.dump(data, file, default_flow_style=False)
        mlflow.log_artifact(local_path, artifact_path)
if __name__ == "__main__":
    pass
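A minimal usage sketch (assumes the tracking URI is configured, e.g. via the MLFLOW_TRACKING_URI environment variable, and that `model` is a fitted scikit-learn estimator):

mlf_logger = MlFlowLogger()
mlf_logger.start_run("my-experiment", run_name="baseline")
mlf_logger.log_params({"n_estimators": 100})
mlf_logger.log_metrics({"rmse": 0.42})
mlf_logger.log_model(model, "model", flavor="sklearn")
mlf_logger.end_run()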