feat: replace promote with prompt ab testing
willydouhard committed Aug 6, 2024
1 parent c35302f commit 8d159b0
Showing 5 changed files with 164 additions and 77 deletions.
115 changes: 74 additions & 41 deletions literalai/api/__init__.py
@@ -10,27 +10,12 @@
Literal,
Optional,
TypeVar,
Union, cast,
Union,
cast,
)

from typing_extensions import deprecated

from literalai.context import active_steps_var, active_thread_var
from literalai.evaluation.dataset import Dataset, DatasetType
from literalai.evaluation.dataset_experiment import DatasetExperiment, DatasetExperimentItem
from literalai.observability.filter import (
generations_filters,
generations_order_by,
scores_filters,
scores_order_by,
steps_filters,
steps_order_by,
threads_filters,
threads_order_by,
users_filters,
)
from literalai.prompt_engineering.prompt import Prompt, ProviderSettings

from literalai.api.attachment_helpers import (
AttachmentUpload,
create_attachment_helper,
@@ -51,12 +36,17 @@
get_dataset_item_helper,
update_dataset_helper,
)
from literalai.api.generation_helpers import create_generation_helper, get_generations_helper
from literalai.api.generation_helpers import (
create_generation_helper,
get_generations_helper,
)
from literalai.api.prompt_helpers import (
PromptRollout,
create_prompt_helper,
create_prompt_lineage_helper,
get_prompt_ab_testing_helper,
get_prompt_helper,
promote_prompt_helper,
update_prompt_ab_testing_helper,
)
from literalai.api.score_helpers import (
ScoreUpdate,
@@ -91,18 +81,45 @@
get_users_helper,
update_user_helper,
)
from literalai.context import active_steps_var, active_thread_var
from literalai.evaluation.dataset import Dataset, DatasetType
from literalai.evaluation.dataset_experiment import (
DatasetExperiment,
DatasetExperimentItem,
)
from literalai.observability.filter import (
generations_filters,
generations_order_by,
scores_filters,
scores_order_by,
steps_filters,
steps_order_by,
threads_filters,
threads_order_by,
users_filters,
)
from literalai.prompt_engineering.prompt import Prompt, ProviderSettings

if TYPE_CHECKING:
from typing import Tuple # noqa: F401

import httpx

from literalai.my_types import (
Environment,
PaginatedResponse,
from literalai.my_types import Environment, PaginatedResponse
from literalai.observability.generation import (
ChatGeneration,
CompletionGeneration,
GenerationMessage,
)
from literalai.observability.step import (
Attachment,
Score,
ScoreDict,
ScoreType,
Step,
StepDict,
StepType,
)
from literalai.observability.generation import GenerationMessage, CompletionGeneration, ChatGeneration
from literalai.observability.step import Step, StepDict, StepType, ScoreType, ScoreDict, Score, Attachment

logger = logging.getLogger(__name__)

@@ -678,8 +695,7 @@ def upload_file(
fields: Dict = request_dict.get("fields", {})
object_key: Optional[str] = fields.get("key")
upload_type: Literal["raw", "multipart"] = cast(
Literal["raw", "multipart"],
request_dict.get("uploadType", "multipart")
Literal["raw", "multipart"], request_dict.get("uploadType", "multipart")
)
signed_url: Optional[str] = json_res.get("signedUrl")

@@ -1344,21 +1360,33 @@ def get_prompt(
else:
raise ValueError("Either the `id` or the `name` must be provided.")

def promote_prompt(self, name: str, version: int) -> str:
def get_prompt_ab_testing(self, name: str) -> List[PromptRollout]:
"""
Promotes the prompt with name to target version.
Get the A/B testing configuration for a prompt lineage.
Args:
name (str): The name of the prompt lineage.
version (int): The version number to promote.
Returns:
str: The champion prompt ID.
List[PromptRollout]
"""
lineage = self.get_or_create_prompt_lineage(name)
lineage_id = lineage["id"]
return self.gql_helper(*get_prompt_ab_testing_helper(name=name))

return self.gql_helper(*promote_prompt_helper(lineage_id, version))
def update_prompt_ab_testing(
self, name: str, rollouts: List[PromptRollout]
) -> Dict:
"""
Update the A/B testing configuration for a prompt lineage.
Args:
name (str): The name of the prompt lineage.
rollouts (List[PromptRollout]): The percentage rollout for each prompt version.
Returns:
Dict
"""
return self.gql_helper(
*update_prompt_ab_testing_helper(name=name, rollouts=rollouts)
)

# Misc API

@@ -1912,8 +1940,7 @@ async def upload_file(
fields: Dict = request_dict.get("fields", {})
object_key: Optional[str] = fields.get("key")
upload_type: Literal["raw", "multipart"] = cast(
Literal["raw", "multipart"],
request_dict.get("uploadType", "multipart")
Literal["raw", "multipart"], request_dict.get("uploadType", "multipart")
)
signed_url: Optional[str] = json_res.get("signedUrl")

@@ -2529,13 +2556,19 @@ async def get_prompt(

get_prompt.__doc__ = LiteralAPI.get_prompt.__doc__

async def promote_prompt(self, name: str, version: int) -> str:
lineage = await self.get_or_create_prompt_lineage(name)
lineage_id = lineage["id"]
async def update_prompt_ab_testing(
self, name: str, rollouts: List[PromptRollout]
) -> Dict:
return await self.gql_helper(
*update_prompt_ab_testing_helper(name=name, rollouts=rollouts)
)

update_prompt_ab_testing.__doc__ = LiteralAPI.update_prompt_ab_testing.__doc__

return await self.gql_helper(*promote_prompt_helper(lineage_id, version))
async def get_prompt_ab_testing(self, name: str) -> List[PromptRollout]:
return await self.gql_helper(*get_prompt_ab_testing_helper(name=name))

promote_prompt.__doc__ = LiteralAPI.promote_prompt.__doc__
get_prompt_ab_testing.__doc__ = LiteralAPI.get_prompt_ab_testing.__doc__

# Misc API

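A minimal usage sketch of the A/B testing methods added above (not part of the diff; assumes a LiteralClient with valid credentials, e.g. LITERAL_API_KEY in the environment, and a prompt lineage named "my-prompt" that already has versions 0 and 1):

from literalai import LiteralClient

client = LiteralClient()  # assumption: credentials are picked up from the environment

# Split traffic 60/40 between versions 0 and 1 of the "my-prompt" lineage.
client.api.update_prompt_ab_testing(
    name="my-prompt",
    rollouts=[{"version": 0, "rollout": 60}, {"version": 1, "rollout": 40}],
)

# Read the configuration back as a list of PromptRollout dicts.
for rollout in client.api.get_prompt_ab_testing(name="my-prompt"):
    print(rollout["version"], rollout["rollout"])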
35 changes: 27 additions & 8 deletions literalai/api/gql.py
@@ -1031,16 +1031,35 @@
}
"""

PROMOTE_PROMPT_VERSION = """mutation promotePromptVersion(
$lineageId: String!
$version: Int!
GET_PROMPT_AB_TESTING = """query getPromptLineageRollout($projectId: String, $lineageName: String!) {
promptLineageRollout(projectId: $projectId, lineageName: $lineageName) {
pageInfo {
startCursor
endCursor
}
edges {
node {
version
rollout
}
}
}
}
"""

UPDATE_PROMPT_AB_TESTING = """mutation updatePromptLineageRollout(
$projectId: String
$name: String!
$rollouts: [PromptVersionRolloutInput!]!
) {
promotePromptVersion(
lineageId: $lineageId
version: $version
updatePromptLineageRollout(
projectId: $projectId
name: $name
rollouts: $rollouts
) {
id
championId
ok
message
errorCode
}
}"""

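Illustrative only (not from the diff): the per-call variables built by the helpers in prompt_helpers.py for the two operations above, and the slice of the response the SDK reads, assuming a lineage named "my-prompt" with two versions:

# Variables built by get_prompt_ab_testing_helper for GET_PROMPT_AB_TESTING.
get_variables = {"lineageName": "my-prompt"}

# The part of the query response that gets flattened into List[PromptRollout].
get_response_data = {
    "promptLineageRollout": {
        "edges": [
            {"node": {"version": 0, "rollout": 60}},
            {"node": {"version": 1, "rollout": 40}},
        ]
    }
}

# Variables built by update_prompt_ab_testing_helper for UPDATE_PROMPT_AB_TESTING.
update_variables = {
    "name": "my-prompt",
    "rollouts": [{"version": 0, "rollout": 60}, {"version": 1, "rollout": 40}],
}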
35 changes: 25 additions & 10 deletions literalai/api/prompt_helpers.py
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Dict, List, Optional
from typing import TYPE_CHECKING, Dict, List, Optional, TypedDict

from literalai.observability.generation import GenerationMessage
from literalai.prompt_engineering.prompt import Prompt, ProviderSettings
@@ -61,16 +61,31 @@ def process_response(response):
return gql.GET_PROMPT_VERSION, description, variables, process_response


def promote_prompt_helper(
lineage_id: str,
version: int,
class PromptRollout(TypedDict):
version: int
rollout: int


def get_prompt_ab_testing_helper(
name: Optional[str] = None,
):
variables = {"lineageId": lineage_id, "version": version}
variables = {"lineageName": name}

def process_response(response) -> List[PromptRollout]:
response_data = response["data"]["promptLineageRollout"]
return list(map(lambda x: x["node"], response_data["edges"]))

description = "get prompt A/B testing"

return gql.GET_PROMPT_AB_TESTING, description, variables, process_response


def update_prompt_ab_testing_helper(name: str, rollouts: List[PromptRollout]):
variables = {"name": name, "rollouts": rollouts}

def process_response(response) -> str:
prompt = response["data"]["promotePromptVersion"]
return prompt["championId"] if prompt else None
def process_response(response) -> Dict:
return response["data"]["updatePromptLineageRollout"]

description = "promote prompt version"
description = "update prompt A/B testing"

return gql.PROMOTE_PROMPT_VERSION, description, variables, process_response
return gql.UPDATE_PROMPT_AB_TESTING, description, variables, process_response
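
Purely illustrative, not part of the SDK or this diff: one way to read the rollout percentages returned by get_prompt_ab_testing_helper is as a traffic split, e.g. a weighted draw over prompt versions (how the split is actually applied is not defined in this change):

import random
from typing import List

# Example rollout configuration: ~60% of traffic to version 0, ~40% to version 1.
rollouts: List[dict] = [
    {"version": 0, "rollout": 60},
    {"version": 1, "rollout": 40},
]

# Pick a version in proportion to its rollout percentage.
chosen_version = random.choices(
    population=[r["version"] for r in rollouts],
    weights=[r["rollout"] for r in rollouts],
    k=1,
)[0]
print(f"serving prompt version {chosen_version}")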
10 changes: 1 addition & 9 deletions literalai/prompt_engineering/prompt.py
@@ -3,9 +3,8 @@
from importlib.metadata import version
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional

from typing_extensions import deprecated, TypedDict

import chevron
from typing_extensions import TypedDict, deprecated

if TYPE_CHECKING:
from literalai.api import LiteralAPI
@@ -117,13 +116,6 @@ def from_dict(cls, api: "LiteralAPI", prompt_dict: PromptDict) -> "Prompt":
variables_default_values=prompt_dict.get("variablesDefaultValues"),
)

def promote(self) -> "Prompt":
"""
Promotes this prompt to champion.
"""
self.api.promote_prompt(self.name, self.version)
return self

def format_messages(self, **kwargs: Any) -> List[Any]:
"""
Formats the prompt's template messages with the given variables.
46 changes: 37 additions & 9 deletions tests/e2e/test_e2e.py
@@ -3,12 +3,13 @@
import secrets
import time
import uuid
from typing import List

import pytest

from literalai import AsyncLiteralClient, LiteralClient
from literalai.context import active_steps_var
from literalai.observability.generation import ChatGeneration
from literalai.observability.generation import ChatGeneration, GenerationMessage
from literalai.observability.thread import Thread

"""
@@ -599,16 +600,43 @@ async def test_prompt(self, async_client: AsyncLiteralClient):
assert messages[0]["content"] == expected

@pytest.mark.timeout(5)
async def test_champion_prompt(self, client: LiteralClient):
new_prompt = client.api.get_or_create_prompt(
name="Python SDK E2E Tests",
template_messages=[{"role": "user", "content": "Hello"}],
async def test_prompt_ab_testing(self, client: LiteralClient):
prompt_name = "Python SDK E2E Tests"

v0: List[GenerationMessage] = [{"role": "user", "content": "Hello"}]
v1: List[GenerationMessage] = [{"role": "user", "content": "Hello 2"}]

prompt_v0 = client.api.get_or_create_prompt(
name=prompt_name,
template_messages=v0,
)
new_prompt.promote()

prompt = client.api.get_prompt(name="Python SDK E2E Tests")
assert prompt is not None
assert prompt.version == new_prompt.version
client.api.update_prompt_ab_testing(
prompt_v0.name, rollouts=[{"version": 0, "rollout": 100}]
)

ab_testing = client.api.get_prompt_ab_testing(name=prompt_v0.name)
assert len(ab_testing) == 1
assert ab_testing[0]["version"] == 0
assert ab_testing[0]["rollout"] == 100

prompt_v1 = client.api.get_or_create_prompt(
name=prompt_name,
template_messages=v1,
)

client.api.update_prompt_ab_testing(
name=prompt_v1.name,
rollouts=[{"version": 0, "rollout": 60}, {"version": 1, "rollout": 40}],
)

ab_testing = client.api.get_prompt_ab_testing(name=prompt_v1.name)

assert len(ab_testing) == 2
assert ab_testing[0]["version"] == 0
assert ab_testing[0]["rollout"] == 60
assert ab_testing[1]["version"] == 1
assert ab_testing[1]["rollout"] == 40

@pytest.mark.timeout(5)
async def test_gracefulness(self, broken_client: LiteralClient):
