From 8d159b058894ffed4bfe6948611517663c55756d Mon Sep 17 00:00:00 2001 From: Willy Douhard Date: Tue, 6 Aug 2024 14:20:44 +0200 Subject: [PATCH] feat: replace promote with prompt ab testing --- literalai/api/__init__.py | 115 ++++++++++++++++--------- literalai/api/gql.py | 35 ++++++-- literalai/api/prompt_helpers.py | 35 +++++--- literalai/prompt_engineering/prompt.py | 10 +-- tests/e2e/test_e2e.py | 46 ++++++++-- 5 files changed, 164 insertions(+), 77 deletions(-) diff --git a/literalai/api/__init__.py b/literalai/api/__init__.py index 73f71da..502f056 100644 --- a/literalai/api/__init__.py +++ b/literalai/api/__init__.py @@ -10,27 +10,12 @@ Literal, Optional, TypeVar, - Union, cast, + Union, + cast, ) from typing_extensions import deprecated -from literalai.context import active_steps_var, active_thread_var -from literalai.evaluation.dataset import Dataset, DatasetType -from literalai.evaluation.dataset_experiment import DatasetExperiment, DatasetExperimentItem -from literalai.observability.filter import ( - generations_filters, - generations_order_by, - scores_filters, - scores_order_by, - steps_filters, - steps_order_by, - threads_filters, - threads_order_by, - users_filters, -) -from literalai.prompt_engineering.prompt import Prompt, ProviderSettings - from literalai.api.attachment_helpers import ( AttachmentUpload, create_attachment_helper, @@ -51,12 +36,17 @@ get_dataset_item_helper, update_dataset_helper, ) -from literalai.api.generation_helpers import create_generation_helper, get_generations_helper +from literalai.api.generation_helpers import ( + create_generation_helper, + get_generations_helper, +) from literalai.api.prompt_helpers import ( + PromptRollout, create_prompt_helper, create_prompt_lineage_helper, + get_prompt_ab_testing_helper, get_prompt_helper, - promote_prompt_helper, + update_prompt_ab_testing_helper, ) from literalai.api.score_helpers import ( ScoreUpdate, @@ -91,18 +81,45 @@ get_users_helper, update_user_helper, ) +from literalai.context import active_steps_var, active_thread_var +from literalai.evaluation.dataset import Dataset, DatasetType +from literalai.evaluation.dataset_experiment import ( + DatasetExperiment, + DatasetExperimentItem, +) +from literalai.observability.filter import ( + generations_filters, + generations_order_by, + scores_filters, + scores_order_by, + steps_filters, + steps_order_by, + threads_filters, + threads_order_by, + users_filters, +) +from literalai.prompt_engineering.prompt import Prompt, ProviderSettings if TYPE_CHECKING: from typing import Tuple # noqa: F401 import httpx -from literalai.my_types import ( - Environment, - PaginatedResponse, +from literalai.my_types import Environment, PaginatedResponse +from literalai.observability.generation import ( + ChatGeneration, + CompletionGeneration, + GenerationMessage, +) +from literalai.observability.step import ( + Attachment, + Score, + ScoreDict, + ScoreType, + Step, + StepDict, + StepType, ) -from literalai.observability.generation import GenerationMessage, CompletionGeneration, ChatGeneration -from literalai.observability.step import Step, StepDict, StepType, ScoreType, ScoreDict, Score, Attachment logger = logging.getLogger(__name__) @@ -678,8 +695,7 @@ def upload_file( fields: Dict = request_dict.get("fields", {}) object_key: Optional[str] = fields.get("key") upload_type: Literal["raw", "multipart"] = cast( - Literal["raw", "multipart"], - request_dict.get("uploadType", "multipart") + Literal["raw", "multipart"], request_dict.get("uploadType", "multipart") ) signed_url: Optional[str] = json_res.get("signedUrl") @@ -1344,21 +1360,33 @@ def get_prompt( else: raise ValueError("Either the `id` or the `name` must be provided.") - def promote_prompt(self, name: str, version: int) -> str: + def get_prompt_ab_testing(self, name: str) -> List[PromptRollout]: """ - Promotes the prompt with name to target version. + Get the A/B testing configuration for a prompt lineage. Args: name (str): The name of the prompt lineage. - version (int): The version number to promote. - Returns: - str: The champion prompt ID. + List[PromptRollout] """ - lineage = self.get_or_create_prompt_lineage(name) - lineage_id = lineage["id"] + return self.gql_helper(*get_prompt_ab_testing_helper(name=name)) - return self.gql_helper(*promote_prompt_helper(lineage_id, version)) + def update_prompt_ab_testing( + self, name: str, rollouts: List[PromptRollout] + ) -> Dict: + """ + Update the A/B testing configuration for a prompt lineage. + + Args: + name (str): The name of the prompt lineage. + rollouts (List[PromptRollout]): The percentage rollout for each prompt version. + + Returns: + Dict + """ + return self.gql_helper( + *update_prompt_ab_testing_helper(name=name, rollouts=rollouts) + ) # Misc API @@ -1912,8 +1940,7 @@ async def upload_file( fields: Dict = request_dict.get("fields", {}) object_key: Optional[str] = fields.get("key") upload_type: Literal["raw", "multipart"] = cast( - Literal["raw", "multipart"], - request_dict.get("uploadType", "multipart") + Literal["raw", "multipart"], request_dict.get("uploadType", "multipart") ) signed_url: Optional[str] = json_res.get("signedUrl") @@ -2529,13 +2556,19 @@ async def get_prompt( get_prompt.__doc__ = LiteralAPI.get_prompt.__doc__ - async def promote_prompt(self, name: str, version: int) -> str: - lineage = await self.get_or_create_prompt_lineage(name) - lineage_id = lineage["id"] + async def update_prompt_ab_testing( + self, name: str, rollouts: List[PromptRollout] + ) -> Dict: + return await self.gql_helper( + *update_prompt_ab_testing_helper(name=name, rollouts=rollouts) + ) + + update_prompt_ab_testing.__doc__ = LiteralAPI.update_prompt_ab_testing.__doc__ - return await self.gql_helper(*promote_prompt_helper(lineage_id, version)) + async def get_prompt_ab_testing(self, name: str) -> List[PromptRollout]: + return await self.gql_helper(*get_prompt_ab_testing_helper(name=name)) - promote_prompt.__doc__ = LiteralAPI.promote_prompt.__doc__ + get_prompt_ab_testing.__doc__ = LiteralAPI.get_prompt_ab_testing.__doc__ # Misc API diff --git a/literalai/api/gql.py b/literalai/api/gql.py index d98b2e9..8369835 100644 --- a/literalai/api/gql.py +++ b/literalai/api/gql.py @@ -1031,16 +1031,35 @@ } """ -PROMOTE_PROMPT_VERSION = """mutation promotePromptVersion( - $lineageId: String! - $version: Int! +GET_PROMPT_AB_TESTING = """query getPromptLineageRollout($projectId: String, $lineageName: String!) { + promptLineageRollout(projectId: $projectId, lineageName: $lineageName) { + pageInfo { + startCursor + endCursor + } + edges { + node { + version + rollout + } + } + } + } +""" + +UPDATE_PROMPT_AB_TESTING = """mutation updatePromptLineageRollout( + $projectId: String + $name: String! + $rollouts: [PromptVersionRolloutInput!]! ) { - promotePromptVersion( - lineageId: $lineageId - version: $version + updatePromptLineageRollout( + projectId: $projectId + name: $name + rollouts: $rollouts ) { - id - championId + ok + message + errorCode } }""" diff --git a/literalai/api/prompt_helpers.py b/literalai/api/prompt_helpers.py index 830176e..e754406 100644 --- a/literalai/api/prompt_helpers.py +++ b/literalai/api/prompt_helpers.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Dict, List, Optional +from typing import TYPE_CHECKING, Dict, List, Optional, TypedDict from literalai.observability.generation import GenerationMessage from literalai.prompt_engineering.prompt import Prompt, ProviderSettings @@ -61,16 +61,31 @@ def process_response(response): return gql.GET_PROMPT_VERSION, description, variables, process_response -def promote_prompt_helper( - lineage_id: str, - version: int, +class PromptRollout(TypedDict): + version: int + rollout: int + + +def get_prompt_ab_testing_helper( + name: Optional[str] = None, ): - variables = {"lineageId": lineage_id, "version": version} + variables = {"lineageName": name} + + def process_response(response) -> List[PromptRollout]: + response_data = response["data"]["promptLineageRollout"] + return list(map(lambda x: x["node"], response_data["edges"])) + + description = "get prompt A/B testing" + + return gql.GET_PROMPT_AB_TESTING, description, variables, process_response + + +def update_prompt_ab_testing_helper(name: str, rollouts: List[PromptRollout]): + variables = {"name": name, "rollouts": rollouts} - def process_response(response) -> str: - prompt = response["data"]["promotePromptVersion"] - return prompt["championId"] if prompt else None + def process_response(response) -> Dict: + return response["data"]["updatePromptLineageRollout"] - description = "promote prompt version" + description = "update prompt A/B testing" - return gql.PROMOTE_PROMPT_VERSION, description, variables, process_response + return gql.UPDATE_PROMPT_AB_TESTING, description, variables, process_response diff --git a/literalai/prompt_engineering/prompt.py b/literalai/prompt_engineering/prompt.py index 93b15b9..fdaf609 100644 --- a/literalai/prompt_engineering/prompt.py +++ b/literalai/prompt_engineering/prompt.py @@ -3,9 +3,8 @@ from importlib.metadata import version from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional -from typing_extensions import deprecated, TypedDict - import chevron +from typing_extensions import TypedDict, deprecated if TYPE_CHECKING: from literalai.api import LiteralAPI @@ -117,13 +116,6 @@ def from_dict(cls, api: "LiteralAPI", prompt_dict: PromptDict) -> "Prompt": variables_default_values=prompt_dict.get("variablesDefaultValues"), ) - def promote(self) -> "Prompt": - """ - Promotes this prompt to champion. - """ - self.api.promote_prompt(self.name, self.version) - return self - def format_messages(self, **kwargs: Any) -> List[Any]: """ Formats the prompt's template messages with the given variables. diff --git a/tests/e2e/test_e2e.py b/tests/e2e/test_e2e.py index f7dae21..73dd7e5 100644 --- a/tests/e2e/test_e2e.py +++ b/tests/e2e/test_e2e.py @@ -3,12 +3,13 @@ import secrets import time import uuid +from typing import List import pytest from literalai import AsyncLiteralClient, LiteralClient from literalai.context import active_steps_var -from literalai.observability.generation import ChatGeneration +from literalai.observability.generation import ChatGeneration, GenerationMessage from literalai.observability.thread import Thread """ @@ -599,16 +600,43 @@ async def test_prompt(self, async_client: AsyncLiteralClient): assert messages[0]["content"] == expected @pytest.mark.timeout(5) - async def test_champion_prompt(self, client: LiteralClient): - new_prompt = client.api.get_or_create_prompt( - name="Python SDK E2E Tests", - template_messages=[{"role": "user", "content": "Hello"}], + async def test_prompt_ab_testing(self, client: LiteralClient): + prompt_name = "Python SDK E2E Tests" + + v0: List[GenerationMessage] = [{"role": "user", "content": "Hello"}] + v1: List[GenerationMessage] = [{"role": "user", "content": "Hello 2"}] + + prompt_v0 = client.api.get_or_create_prompt( + name=prompt_name, + template_messages=v0, ) - new_prompt.promote() - prompt = client.api.get_prompt(name="Python SDK E2E Tests") - assert prompt is not None - assert prompt.version == new_prompt.version + client.api.update_prompt_ab_testing( + prompt_v0.name, rollouts=[{"version": 0, "rollout": 100}] + ) + + ab_testing = client.api.get_prompt_ab_testing(name=prompt_v0.name) + assert len(ab_testing) == 1 + assert ab_testing[0]["version"] == 0 + assert ab_testing[0]["rollout"] == 100 + + prompt_v1 = client.api.get_or_create_prompt( + name=prompt_name, + template_messages=v1, + ) + + client.api.update_prompt_ab_testing( + name=prompt_v1.name, + rollouts=[{"version": 0, "rollout": 60}, {"version": 1, "rollout": 40}], + ) + + ab_testing = client.api.get_prompt_ab_testing(name=prompt_v1.name) + + assert len(ab_testing) == 2 + assert ab_testing[0]["version"] == 0 + assert ab_testing[0]["rollout"] == 60 + assert ab_testing[1]["version"] == 1 + assert ab_testing[1]["rollout"] == 40 @pytest.mark.timeout(5) async def test_gracefulness(self, broken_client: LiteralClient):