diff --git a/datarobot_provider/__init__.py b/datarobot_provider/__init__.py
index 74fa37e..cbd9a3d 100644
--- a/datarobot_provider/__init__.py
+++ b/datarobot_provider/__init__.py
@@ -7,10 +7,10 @@
 # Released under the terms of DataRobot Tool and Utility Agreement.
 def get_provider_info():
     return {
-        "package-name": "airflow-provider-datarobot",
+        "package-name": "airflow-provider-datarobot-early-access",
         "name": "DataRobot Airflow Provider",
         "description": "DataRobot Airflow provider.",
-        "versions": ["0.0.8"],
+        "versions": ["0.0.8.1"],
         "connection-types": [
             {
                 "hook-class-name": "datarobot_provider.hooks.datarobot.DataRobotHook",
diff --git a/datarobot_provider/example_dags/custom_job/__init__.py b/datarobot_provider/example_dags/custom_job/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/datarobot_provider/example_dags/custom_job/job.py b/datarobot_provider/example_dags/custom_job/job.py
new file mode 100644
index 0000000..ddd8be5
--- /dev/null
+++ b/datarobot_provider/example_dags/custom_job/job.py
@@ -0,0 +1,31 @@
+import os
+from datarobot import Deployment
+
+
+def main():
+    print(f"Running python code: {__file__}")
+
+    # Using this job's runtime parameters
+    print()
+    print("Runtime parameters:")
+    print("-------------------")
+    string_param = os.environ.get("STRING_PARAMETER", None)
+    print(f"string param: {string_param}")
+
+    deployment_param = os.environ.get("DEPLOYMENT", None)
+    print(f"deployment_param: {deployment_param}")
+
+    model_package_param = os.environ.get("MODEL_PACKAGE", None)
+    print(f"model_package_param: {model_package_param}")
+
+    # An example of using the python client to list deployments
+    deployments = Deployment.list()
+    print()
+    print("List of all deployments")
+    print("-----------------------")
+    for deployment in deployments:
+        print(deployment)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/datarobot_provider/example_dags/custom_job/metadata.yaml b/datarobot_provider/example_dags/custom_job/metadata.yaml
new file mode 100644
index 0000000..dee0789
--- /dev/null
+++ b/datarobot_provider/example_dags/custom_job/metadata.yaml
@@ -0,0 +1,12 @@
+name: runtime-params
+
+runtimeParameterDefinitions:
+  - fieldName: MODEL_PACKAGE
+    type: modelPackage
+    description: Model package that will be used to store key values
+  - fieldName: DEPLOYMENT
+    type: deployment
+    description: Deployment that will be used to make predictions
+  - fieldName: STRING_PARAMETER
+    type: string
+    description: An example of a string parameter
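The metadata.yaml above declares the runtime parameters that job.py reads back as environment variables. A minimal sketch (not part of this patch) of taking the DEPLOYMENT parameter one step further and resolving it with the DataRobot Python client, assuming the injected value is a plain deployment ID:

import os

from datarobot import Deployment

deployment_id = os.environ.get("DEPLOYMENT")
if deployment_id:
    # Deployment.get() is the standard client call for fetching a deployment by ID
    deployment = Deployment.get(deployment_id=deployment_id)
    print(f"Resolved deployment: {deployment.label}")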
diff --git a/datarobot_provider/example_dags/custom_job/run.sh b/datarobot_provider/example_dags/custom_job/run.sh
new file mode 100644
index 0000000..3032fdc
--- /dev/null
+++ b/datarobot_provider/example_dags/custom_job/run.sh
@@ -0,0 +1,60 @@
+#!/bin/bash
+
+echo "Job Starting: ($0)"
+
+echo "===== Runtime Parameters ======"
+echo "Model Package: $MODEL_PACKAGE"
+echo "Deployment: $DEPLOYMENT"
+echo "STRING_PARAMETER: $STRING_PARAMETER"
+echo
+echo
+echo "===== Generic Variables ==========================="
+echo "CURRENT_CUSTOM_JOB_RUN_ID: $CURRENT_CUSTOM_JOB_RUN_ID"
+echo "CURRENT_CUSTOM_JOB_ID: $CURRENT_CUSTOM_JOB_ID"
+echo "DATAROBOT_ENDPOINT: $DATAROBOT_ENDPOINT"
+echo "DATAROBOT_API_TOKEN: Use the environment variable \$DATAROBOT_API_TOKEN"
+echo "==================================================="
+
+echo
+echo "How to check how much memory your job has"
+memory_limit_bytes=$(cat /sys/fs/cgroup/memory/memory.limit_in_bytes)
+memory_limit_megabytes=$((memory_limit_bytes / 1024 / 1024))
+echo "Memory Limit (in Megabytes): $memory_limit_megabytes"
+echo
+
+# Uncomment the following if you want to check if the job has network access
+## Define the IP address of an external server to ping (e.g., Google's DNS)
+#external_server="8.8.8.8"
+#echo "Checking internet connection"
+## Try to ping the external server
+#ping -c 1 $external_server > /dev/null 2>&1
+#
+## Check the exit status of the ping command
+#if [ $? -eq 0 ]; then
+#    echo "Internet connection is available."
+#else
+#    echo "No internet connection."
+#fi
+#echo
+#echo
+
+# Run the code in job.py
+dir_path=$(dirname "$0")
+echo "Entrypoint is at $dir_path - cd into it"
+cd "$dir_path" || exit 1
+
+if command -v python3 &>/dev/null; then
+    echo "python3 is installed and available."
+else
+    echo "Error: python3 is not installed or not available."
+    exit 1
+fi
+
+python_file="job.py"
+if [ -f "$python_file" ]; then
+    echo "Found $python_file .. running it"
+    python3 ./job.py
+else
+    echo "File $python_file does not exist"
+    exit 1
+fi
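run.sh reads the memory limit from the cgroup v1 path (/sys/fs/cgroup/memory/memory.limit_in_bytes). On hosts running cgroup v2 that file does not exist; the limit is exposed as /sys/fs/cgroup/memory.max instead and can read "max" for unlimited. A hedged sketch, in Python, that handles both layouts:

from pathlib import Path

def memory_limit_mb() -> str:
    """Return the container memory limit in megabytes, or 'unlimited'/'unknown'."""
    candidates = (
        "/sys/fs/cgroup/memory/memory.limit_in_bytes",  # cgroup v1
        "/sys/fs/cgroup/memory.max",                    # cgroup v2
    )
    for candidate in candidates:
        path = Path(candidate)
        if path.exists():
            raw = path.read_text().strip()
            if raw == "max":
                return "unlimited"
            return str(int(raw) // 1024 // 1024)
    return "unknown"

print(f"Memory Limit (in Megabytes): {memory_limit_mb()}")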
diff --git a/datarobot_provider/example_dags/datarobot_custom_job_dag.py b/datarobot_provider/example_dags/datarobot_custom_job_dag.py
new file mode 100644
index 0000000..9d6fc75
--- /dev/null
+++ b/datarobot_provider/example_dags/datarobot_custom_job_dag.py
@@ -0,0 +1,92 @@
+# Copyright 2024 DataRobot, Inc. and its affiliates.
+#
+# All rights reserved.
+#
+# This is proprietary source code of DataRobot, Inc. and its affiliates.
+#
+# Released under the terms of DataRobot Tool and Utility Agreement.
+
+from datetime import datetime
+
+from airflow.decorators import dag
+
+from datarobot_provider.operators.custom_job import CreateCustomJobOperator
+from datarobot_provider.operators.custom_job import AddFilesToCustomJobOperator
+from datarobot_provider.operators.custom_job import SetCustomJobExecutionEnvironmentOperator
+from datarobot_provider.operators.custom_job import SetCustomJobRuntimeParametersOperator
+from datarobot_provider.operators.custom_job import RunCustomJobOperator
+from datarobot_provider.sensors.client import BaseAsyncResolutionSensor
+
+
+@dag(
+    schedule=None,
+    start_date=datetime(2023, 1, 1),
+    tags=['example', 'custom job'],
+    params={},
+)
+def create_custom_job():
+    create_custom_job_op = CreateCustomJobOperator(
+        task_id='create_custom_job',
+        name="airflow-test-create-custom-job-v556",
+        description="demo-test-demonstration",
+    )
+
+    add_files_to_custom_job_op = AddFilesToCustomJobOperator(
+        task_id='add_files_to_custom_job',
+        custom_job_id=create_custom_job_op.output,
+        files_path="custom_job/",
+    )
+
+    # list_execution_env_op = ListExecutionEnvironmentOperator(
+    #     task_id='list_execution_env',
+    #     search_for="Python 3.9 PyTorch Drop-In"
+    # )
+
+    set_env_to_custom_job_op = SetCustomJobExecutionEnvironmentOperator(
+        task_id='set_env_to_custom_job',
+        custom_job_id=create_custom_job_op.output,
+        environment_id='5e8c888007389fe0f466c72b',
+        environment_version_id='65c1db901800cd9782d7ac07',
+    )
+
+    set_runtime_parameters_op = SetCustomJobRuntimeParametersOperator(
+        task_id='set_runtime_parameters',
+        custom_job_id=create_custom_job_op.output,
+        runtime_parameter_values=[
+            {"fieldName": "DEPLOYMENT", "type": "deployment", "value": "650ef15944f21ea1a3c91a25"},
+            {
+                "fieldName": "MODEL_PACKAGE",
+                "type": "modelPackage",
+                "value": "654b9b228404a39b5c8da5b2",
+            },
+            {"fieldName": "STRING_PARAMETER", "type": "string", "value": 'my test string'},
+        ],
+    )
+
+    run_custom_job_op = RunCustomJobOperator(
+        task_id='run_custom_job',
+        custom_job_id=create_custom_job_op.output,
+    )
+
+    custom_job_complete_sensor = BaseAsyncResolutionSensor(
+        task_id="check_custom_job_complete",
+        job_id=run_custom_job_op.output,
+        poke_interval=5,
+        mode="reschedule",
+        timeout=3600,
+    )
+
+    (
+        create_custom_job_op
+        >> add_files_to_custom_job_op
+        >> set_env_to_custom_job_op
+        >> set_runtime_parameters_op
+        >> run_custom_job_op
+        >> custom_job_complete_sensor
+    )
+
+
+create_custom_job_dag = create_custom_job()
+
+if __name__ == "__main__":
+    create_custom_job_dag.test()
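In the DAG above, create_custom_job_op.output is Airflow's XComArg shorthand: it resolves at runtime to the upstream task's return value (the custom job ID returned by CreateCustomJobOperator.execute). Since custom_job_id is a templated field, the equivalent explicit wiring, shown here only for illustration, would be:

add_files_to_custom_job_op = AddFilesToCustomJobOperator(
    task_id='add_files_to_custom_job',
    # pulls the return_value XCom of the upstream task at render time
    custom_job_id="{{ ti.xcom_pull(task_ids='create_custom_job') }}",
    files_path="custom_job/",
)

Unlike the XComArg form, the Jinja form does not imply a task dependency; here the explicit `>>` chain at the bottom of the DAG already declares it.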
diff --git a/datarobot_provider/operators/custom_job.py b/datarobot_provider/operators/custom_job.py
new file mode 100644
index 0000000..7cf7ac0
--- /dev/null
+++ b/datarobot_provider/operators/custom_job.py
@@ -0,0 +1,462 @@
+# Copyright 2023 DataRobot, Inc. and its affiliates.
+#
+# All rights reserved.
+#
+# This is proprietary source code of DataRobot, Inc. and its affiliates.
+#
+# Released under the terms of DataRobot Tool and Utility Agreement.
+import json
+import os
+from pathlib import Path
+from typing import Any
+from typing import Dict
+from typing import Iterable
+from typing import List
+from typing import Optional
+from typing import Union
+
+import datarobot as dr
+from airflow.exceptions import AirflowException, AirflowFailException
+from airflow.models import BaseOperator
+
+from datarobot_provider.hooks.datarobot import DataRobotHook
+
+
+class CreateCustomJobOperator(BaseOperator):
+    """
+    Create a DataRobot custom job.
+
+    :param name: custom job name
+    :type name: str
+    :param description: custom job description
+    :type description: str, optional
+    :return: custom job ID
+    :rtype: str
+    """
+
+    # Specify the arguments that are allowed to be parsed with Jinja templating
+    template_fields: Iterable[str] = [
+        "name",
+        "description",
+    ]
+    template_fields_renderers: Dict[str, str] = {}
+    template_ext: Iterable[str] = ()
+    ui_color = '#f4a460'
+
+    def __init__(
+        self,
+        *,
+        name: str,
+        description: Optional[str] = None,
+        datarobot_conn_id: str = "datarobot_default",
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.name = name
+        self.description = description
+        self.datarobot_conn_id = datarobot_conn_id
+        if kwargs.get('xcom_push') is not None:
+            raise AirflowException(
+                "'xcom_push' was deprecated, use 'BaseOperator.do_xcom_push' instead"
+            )
+
+    def execute(self, context: Dict[str, Any]) -> str:
+        # Initialize DataRobot client
+        DataRobotHook(datarobot_conn_id=self.datarobot_conn_id).run()
+
+        if self.name is None:
+            raise ValueError("Invalid or missing custom job name")
+
+        response = dr.client.get_client().post(
+            "customJobs/",
+            data={
+                "name": self.name,
+                "description": self.description,
+            },
+        )
+
+        if response.status_code == 201:
+            custom_job_id = response.json()["id"]
+            self.log.info(f"Custom job created, custom_job_id={custom_job_id}")
+            return custom_job_id
+        else:
+            e_msg = "Server unexpectedly returned status code {}"
+            raise AirflowFailException(e_msg.format(response.status_code))
+
+
+class AddFilesToCustomJobOperator(BaseOperator):
+    """
+    Add files to a custom job from the specified location.
+
+    :param custom_job_id: custom job ID
+    :type custom_job_id: str
+    :param files_path: location of the files to add
+    :type files_path: str
+    :return: list of files added to the custom job
+    :rtype: List[str]
+    """
+
+    # Specify the arguments that are allowed to be parsed with Jinja templating
+    template_fields: Iterable[str] = [
+        "custom_job_id",
+        "files_path",
+    ]
+    template_fields_renderers: Dict[str, str] = {}
+    template_ext: Iterable[str] = ()
+    ui_color = '#f4a460'
+
+    def __init__(
+        self,
+        *,
+        custom_job_id: Optional[str] = None,
+        files_path: Optional[str] = None,
+        datarobot_conn_id: str = "datarobot_default",
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.custom_job_id = custom_job_id
+        self.files_path = files_path
+        self.datarobot_conn_id = datarobot_conn_id
+        if kwargs.get('xcom_push') is not None:
+            raise AirflowException(
+                "'xcom_push' was deprecated, use 'BaseOperator.do_xcom_push' instead"
+            )
+
+    def upload_custom_job_file(self, file_path, filename):
+        with open(file_path, "rb") as file_payload:
+            files = {
+                'file': file_payload,
+                'filePath': filename,
+            }
+
+            response = dr.client.get_client().build_request_with_file(
+                form_data=files,
+                fname=filename,
+                filelike=file_payload,
+                url=f"customJobs/{self.custom_job_id}/",
+                method="patch",
+            )
+
+        return response.status_code
+
+    def execute(self, context: Dict[str, Any]) -> List[str]:
+        # Initialize DataRobot client
+        DataRobotHook(datarobot_conn_id=self.datarobot_conn_id).run()
+
+        # Validate inputs before building the path, so a missing files_path
+        # is caught instead of being interpolated into the path as "None"
+        if self.custom_job_id is None:
+            raise ValueError("Invalid or missing custom_job_id")
+
+        if self.files_path is None:
+            raise ValueError("Invalid or missing files_path")
+
+        airflow_home = os.environ.get('AIRFLOW_HOME')
+        self.files_path = Path(f'{airflow_home}/dags/{self.files_path}')
+
+        uploaded_files = []
+        if os.path.isdir(self.files_path):
+            for file_name in os.listdir(self.files_path):
+                self.upload_custom_job_file(self.files_path / file_name, file_name)
+                uploaded_files.append(file_name)
+        return uploaded_files
+
+
+class SetCustomJobExecutionEnvironmentOperator(BaseOperator):
+    """
+    Set the execution environment for a custom job.
+
+    :param custom_job_id: DataRobot custom job ID
+    :type custom_job_id: str
+    :param environment_id: execution environment ID
+    :type environment_id: str
+    :param environment_version_id: execution environment version ID
+    :type environment_version_id: str
+    :return: operation response
+    :rtype: dict
+    """
+
+    # Specify the arguments that are allowed to be parsed with Jinja templating
+    template_fields: Iterable[str] = [
+        "custom_job_id",
+        "environment_id",
+        "environment_version_id",
+    ]
+    template_fields_renderers: Dict[str, str] = {}
+    template_ext: Iterable[str] = ()
+    ui_color = '#f4a460'
+
+    def __init__(
+        self,
+        *,
+        custom_job_id: str,
+        environment_id: str,
+        environment_version_id: str,
+        datarobot_conn_id: str = "datarobot_default",
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.custom_job_id = custom_job_id
+        self.environment_id = environment_id
+        self.environment_version_id = environment_version_id
+        self.datarobot_conn_id = datarobot_conn_id
+        if kwargs.get('xcom_push') is not None:
+            raise AirflowException(
+                "'xcom_push' was deprecated, use 'BaseOperator.do_xcom_push' instead"
+            )
+
+    def execute(self, context: Dict[str, Any]) -> Dict[str, Any]:
+        # Initialize DataRobot client
+        DataRobotHook(datarobot_conn_id=self.datarobot_conn_id).run()
+
+        if self.custom_job_id is None:
+            raise ValueError("Invalid or missing custom_job_id")
+
+        if self.environment_id is None:
+            raise ValueError("Invalid or missing environment_id")
+
+        if self.environment_version_id is None:
+            raise ValueError("Invalid or missing environment_version_id")
+
+        form_data = {
+            "environmentId": self.environment_id,
+            "environmentVersionId": self.environment_version_id,
+        }
+
+        response = dr.client.get_client().patch(
+            url=f"customJobs/{self.custom_job_id}/",
+            data=form_data,
+        )
+
+        if response.status_code == 201:
+            response_json = response.json()
+            custom_job_id = response_json["id"]
+            environment_id = response_json["environmentId"]
+            environment_version_id = response_json["environmentVersionId"]
+            self.log.info(
+                f"Custom job_id={custom_job_id} environment updated "
+                f"with environment_id={environment_id} "
+                f"and environment_version_id={environment_version_id}"
+            )
+            return response_json
+        else:
+            e_msg = "Server unexpectedly returned status code {}"
+            raise AirflowFailException(e_msg.format(response.status_code))
+
+
+class SetCustomJobRuntimeParametersOperator(BaseOperator):
+    """
+    Set runtime parameter values for a custom job.
+
+    :param custom_job_id: custom job ID
+    :type custom_job_id: str
+    :param runtime_parameter_values: runtime parameter values to set, as a list of
+        {"fieldName": ..., "type": ..., "value": ...} dicts or a pre-serialized
+        JSON string
+    :type runtime_parameter_values: List[dict] or str
+    :return: operation response
+    :rtype: dict
+    """
+
+    # Specify the arguments that are allowed to be parsed with Jinja templating
+    template_fields: Iterable[str] = ["custom_job_id", "runtime_parameter_values"]
+    template_fields_renderers: Dict[str, str] = {}
+    template_ext: Iterable[str] = ()
+    ui_color = '#f4a460'
+
+    def __init__(
+        self,
+        *,
+        custom_job_id: str,
+        runtime_parameter_values: Optional[Union[str, List[Dict[str, Any]]]] = None,
+        datarobot_conn_id: str = "datarobot_default",
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.custom_job_id = custom_job_id
+        self.runtime_parameter_values = runtime_parameter_values
+        self.datarobot_conn_id = datarobot_conn_id
+        if kwargs.get('xcom_push') is not None:
+            raise AirflowException(
+                "'xcom_push' was deprecated, use 'BaseOperator.do_xcom_push' instead"
+            )
+
+    def execute(self, context: Dict[str, Any]) -> Dict[str, Any]:
+        # Initialize DataRobot client
+        DataRobotHook(datarobot_conn_id=self.datarobot_conn_id).run()
+
+        if self.runtime_parameter_values is None:
+            raise ValueError("Invalid or missing runtime_parameter_values")
+
+        # A list of parameter dicts is serialized to JSON for the PATCH payload
+        if isinstance(self.runtime_parameter_values, list):
+            self.runtime_parameter_values = json.dumps(self.runtime_parameter_values)
+
+        form_data = {
+            "runtimeParameterValues": self.runtime_parameter_values,
+        }
+
+        response = dr.client.get_client().patch(
+            url=f"customJobs/{self.custom_job_id}/",
+            data=form_data,
+        )
+
+        if response.status_code == 201:
+            response_json = response.json()
+            custom_job_id = response_json["id"]
+            self.log.info(f"Custom job_id={custom_job_id} runtime parameters updated")
+            return response_json
+        else:
+            e_msg = "Server unexpectedly returned status code {}"
+            raise AirflowFailException(e_msg.format(response.status_code))
+
+
+class RunCustomJobOperator(BaseOperator):
+    """
+    Run a custom job and return an ID that can be used to check the job status.
+
+    :param custom_job_id: custom job ID
+    :type custom_job_id: str
+    :param setup_dependencies: setup dependencies flag
+    :type setup_dependencies: bool, optional
+    :return: status check ID
+    :rtype: str
+    """
+
+    # Specify the arguments that are allowed to be parsed with Jinja templating
+    template_fields: Iterable[str] = [
+        "custom_job_id",
+    ]
+    template_fields_renderers: Dict[str, str] = {}
+    template_ext: Iterable[str] = ()
+    ui_color = '#f4a460'
+
+    def __init__(
+        self,
+        *,
+        custom_job_id: Optional[str] = None,
+        setup_dependencies: bool = False,
+        datarobot_conn_id: str = "datarobot_default",
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.custom_job_id = custom_job_id
+        self.setup_dependencies = setup_dependencies
+        self.datarobot_conn_id = datarobot_conn_id
+        if kwargs.get('xcom_push') is not None:
+            raise AirflowException(
+                "'xcom_push' was deprecated, use 'BaseOperator.do_xcom_push' instead"
+            )
+
+    def execute(self, context: Dict[str, Any]) -> str:
+        # Initialize DataRobot client
+        DataRobotHook(datarobot_conn_id=self.datarobot_conn_id).run()
+
+        if self.custom_job_id is None:
+            raise ValueError("Invalid or missing custom job id")
+
+        response = dr.client.get_client().post(
+            f"customJobs/{self.custom_job_id}/runs/",
+            data={"setupDependencies": self.setup_dependencies},
+        )
+
+        if response.status_code == 201:
+            response_json = response.json()
+            job_status_id = response_json["jobStatusId"]
+            custom_job_id = response_json["id"]
+            self.log.info(f"Custom job run started, custom_job_id={custom_job_id}")
+            return job_status_id
+        else:
+            e_msg = "Server unexpectedly returned status code {}"
+            raise AirflowFailException(e_msg.format(response.status_code))
+
+
+class ListExecutionEnvironmentOperator(BaseOperator):
+    """
+    List all existing execution environments that match the search condition.
+
+    :param search_for: the string for filtering execution environments - only execution
+        environments that contain the string in name or description will
+        be returned.
+    :type search_for: str, optional
+    :return: list of execution environment IDs
+    :rtype: List[str]
+    """
+
+    # Specify the arguments that are allowed to be parsed with Jinja templating
+    template_fields: Iterable[str] = [
+        "search_for",
+    ]
+    template_fields_renderers: Dict[str, str] = {}
+    template_ext: Iterable[str] = ()
+    ui_color = '#f4a460'
+
+    def __init__(
+        self,
+        *,
+        search_for: Optional[str] = None,
+        datarobot_conn_id: str = "datarobot_default",
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.search_for = search_for
+        self.datarobot_conn_id = datarobot_conn_id
+        if kwargs.get('xcom_push') is not None:
+            raise AirflowException(
+                "'xcom_push' was deprecated, use 'BaseOperator.do_xcom_push' instead"
+            )
+
+    def execute(self, context: Dict[str, Any]) -> List[str]:
+        # Initialize DataRobot client
+        DataRobotHook(datarobot_conn_id=self.datarobot_conn_id).run()
+
+        execution_environments = dr.ExecutionEnvironment.list(search_for=self.search_for)
+        execution_environment_ids = [
+            execution_environment.id for execution_environment in execution_environments
+        ]
+        self.log.info(f"List of execution environment ids = {execution_environment_ids}")
+
+        return execution_environment_ids
+
+
+class ListExecutionEnvironmentVersionsOperator(BaseOperator):
+    """
+    List all existing execution environment versions that match the search condition.
+
+    :param environment_id: execution environment ID
+    :type environment_id: str
+    :param search_for: the string for filtering execution environment versions - only
+        versions that contain the string in name or description will
+        be returned.
+    :type search_for: str, optional
+    :return: list of execution environment version IDs
+    :rtype: List[str]
+    """
+
+    # Specify the arguments that are allowed to be parsed with Jinja templating
+    template_fields: Iterable[str] = [
+        "environment_id",
+        "search_for",
+    ]
+    template_fields_renderers: Dict[str, str] = {}
+    template_ext: Iterable[str] = ()
+    ui_color = '#f4a460'
+
+    def __init__(
+        self,
+        *,
+        environment_id: str,
+        search_for: Optional[str] = None,
+        datarobot_conn_id: str = "datarobot_default",
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.environment_id = environment_id
+        self.search_for = search_for
+        self.datarobot_conn_id = datarobot_conn_id
+        if kwargs.get('xcom_push') is not None:
+            raise AirflowException(
+                "'xcom_push' was deprecated, use 'BaseOperator.do_xcom_push' instead"
+            )
+
+    def execute(self, context: Dict[str, Any]) -> List[str]:
+        # Initialize DataRobot client
+        DataRobotHook(datarobot_conn_id=self.datarobot_conn_id).run()
+
+        if self.environment_id is None:
+            raise ValueError("Invalid or missing environment_id")
+
+        execution_environment_versions = dr.ExecutionEnvironmentVersion.list(
+            execution_environment_id=self.environment_id,
+            # search_for=self.search_for
+        )
+        execution_environment_version_ids = [
+            version.id for version in execution_environment_versions
+        ]
+        self.log.info(
+            f"List of execution environment version ids = {execution_environment_version_ids}"
+        )
+
+        return execution_environment_version_ids
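The hardcoded environment_id/environment_version_id in the example DAG could instead be discovered with the two List operators above, as the commented-out list_execution_env task in the DAG hints. A hedged sketch (the search string and version ID are illustrative; the operator returns a list of IDs, so the Jinja expression below assumes the first match is the one wanted):

list_execution_env_op = ListExecutionEnvironmentOperator(
    task_id='list_execution_env',
    search_for="Python 3.9 PyTorch Drop-In",
)

set_env_to_custom_job_op = SetCustomJobExecutionEnvironmentOperator(
    task_id='set_env_to_custom_job',
    custom_job_id=create_custom_job_op.output,
    # environment_id is a templated field, so the first listed ID can be pulled via Jinja
    environment_id="{{ ti.xcom_pull(task_ids='list_execution_env')[0] }}",
    environment_version_id='65c1db901800cd9782d7ac07',
)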
diff --git a/setup.py b/setup.py
index e64a25d..6c5703d 100644
--- a/setup.py
+++ b/setup.py
@@ -12,8 +12,8 @@
 """Perform the package airflow-provider-datarobot setup."""
 
 setup(
-    name='airflow-provider-datarobot',
-    version="0.0.8",
+    name='airflow-provider-datarobot-early-access',
+    version="0.0.8.1",
     description='DataRobot Airflow provider.',
     long_description=long_description,
     long_description_content_type='text/markdown',
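For debugging outside Airflow, the same customJobs routes the operators call can be exercised directly with the DataRobot client. A minimal sketch (the payload fields mirror the operator code above; nothing beyond those fields is guaranteed by this patch, and the job name is illustrative):

import datarobot as dr

dr.Client()  # configured from DATAROBOT_ENDPOINT / DATAROBOT_API_TOKEN
client = dr.client.get_client()

# Create a custom job, then start a run, exactly as the operators do
response = client.post("customJobs/", data={"name": "smoke-test-job", "description": "manual check"})
custom_job_id = response.json()["id"]

run_response = client.post(f"customJobs/{custom_job_id}/runs/", data={"setupDependencies": False})
print(f"jobStatusId: {run_response.json()['jobStatusId']}")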