Skip to content

Commit

Permalink
Merge branch 'main' of github.com:Chainlit/literalai-python into willy/eng-1754-fix-mistralai-instrumentation-for-100
Browse files Browse the repository at this point in the history
  • Loading branch information
willydouhard committed Aug 8, 2024
2 parents 4c4ce14 + 5011c6d commit 4a0b063
Show file tree
Hide file tree
Showing 5 changed files with 149 additions and 80 deletions.
102 changes: 59 additions & 43 deletions literalai/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,25 +16,6 @@

from typing_extensions import deprecated

from literalai.context import active_steps_var, active_thread_var
from literalai.evaluation.dataset import Dataset, DatasetType
from literalai.evaluation.dataset_experiment import (
DatasetExperiment,
DatasetExperimentItem,
)
from literalai.observability.filter import (
generations_filters,
generations_order_by,
scores_filters,
scores_order_by,
steps_filters,
steps_order_by,
threads_filters,
threads_order_by,
users_filters,
)
from literalai.prompt_engineering.prompt import Prompt, ProviderSettings

from literalai.api.attachment_helpers import (
AttachmentUpload,
create_attachment_helper,
Expand All @@ -60,10 +41,12 @@
get_generations_helper,
)
from literalai.api.prompt_helpers import (
PromptRollout,
create_prompt_helper,
create_prompt_lineage_helper,
get_prompt_ab_testing_helper,
get_prompt_helper,
promote_prompt_helper,
update_prompt_ab_testing_helper,
)
from literalai.api.score_helpers import (
ScoreUpdate,
Expand Down Expand Up @@ -98,29 +81,44 @@
get_users_helper,
update_user_helper,
)
from literalai.context import active_steps_var, active_thread_var
from literalai.evaluation.dataset import Dataset, DatasetType
from literalai.evaluation.dataset_experiment import (
DatasetExperiment,
DatasetExperimentItem,
)
from literalai.observability.filter import (
generations_filters,
generations_order_by,
scores_filters,
scores_order_by,
steps_filters,
steps_order_by,
threads_filters,
threads_order_by,
users_filters,
)
from literalai.prompt_engineering.prompt import Prompt, ProviderSettings

if TYPE_CHECKING:
from typing import Tuple # noqa: F401

import httpx

from literalai.my_types import (
Environment,
PaginatedResponse,
)
from literalai.my_types import Environment, PaginatedResponse
from literalai.observability.generation import (
GenerationMessage,
CompletionGeneration,
ChatGeneration,
CompletionGeneration,
GenerationMessage,
)
from literalai.observability.step import (
Attachment,
Score,
ScoreDict,
ScoreType,
Step,
StepDict,
StepType,
ScoreType,
ScoreDict,
Score,
Attachment,
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -1365,21 +1363,33 @@ def get_prompt(
else:
raise ValueError("Either the `id` or the `name` must be provided.")

def get_prompt_ab_testing(self, name: str) -> List[PromptRollout]:
    """
    Get the A/B testing configuration for a prompt lineage.

    Args:
        name (str): The name of the prompt lineage.

    Returns:
        List[PromptRollout]: the rollout percentage configured per prompt version.
    """
    # The rollout is keyed by lineage name server-side, so no
    # get_or_create_prompt_lineage round-trip is needed here.
    return self.gql_helper(*get_prompt_ab_testing_helper(name=name))
def update_prompt_ab_testing(
    self, name: str, rollouts: List[PromptRollout]
) -> Dict:
    """
    Update the A/B testing configuration for a prompt lineage.

    Args:
        name (str): The name of the prompt lineage.
        rollouts (List[PromptRollout]): The percentage rollout for each prompt version.

    Returns:
        Dict
    """
    # The helper returns (query, description, variables, process_response);
    # forward each piece to the generic GraphQL executor.
    query, description, variables, process_response = update_prompt_ab_testing_helper(
        name=name, rollouts=rollouts
    )
    return self.gql_helper(query, description, variables, process_response)

# Misc API

Expand Down Expand Up @@ -2552,13 +2562,19 @@ async def get_prompt(

get_prompt.__doc__ = LiteralAPI.get_prompt.__doc__

async def update_prompt_ab_testing(
    self, name: str, rollouts: List[PromptRollout]
) -> Dict:
    # Async mirror of LiteralAPI.update_prompt_ab_testing.
    return await self.gql_helper(
        *update_prompt_ab_testing_helper(name=name, rollouts=rollouts)
    )

update_prompt_ab_testing.__doc__ = LiteralAPI.update_prompt_ab_testing.__doc__

async def get_prompt_ab_testing(self, name: str) -> List[PromptRollout]:
    # Async mirror of LiteralAPI.get_prompt_ab_testing.
    return await self.gql_helper(*get_prompt_ab_testing_helper(name=name))

get_prompt_ab_testing.__doc__ = LiteralAPI.get_prompt_ab_testing.__doc__

# Misc API

Expand Down
35 changes: 27 additions & 8 deletions literalai/api/gql.py
Original file line number Diff line number Diff line change
Expand Up @@ -1041,16 +1041,35 @@
}
"""

# GraphQL query returning the per-version rollout percentages of a prompt
# lineage, looked up by lineage name (the A/B testing configuration).
GET_PROMPT_AB_TESTING = """query getPromptLineageRollout($projectId: String, $lineageName: String!) {
  promptLineageRollout(projectId: $projectId, lineageName: $lineageName) {
    pageInfo {
      startCursor
      endCursor
    }
    edges {
      node {
        version
        rollout
      }
    }
  }
}
"""

# GraphQL mutation replacing the rollout percentages of a prompt lineage
# (identified by name). Returns an ok/message/errorCode status payload.
UPDATE_PROMPT_AB_TESTING = """mutation updatePromptLineageRollout(
  $projectId: String
  $name: String!
  $rollouts: [PromptVersionRolloutInput!]!
) {
  updatePromptLineageRollout(
    projectId: $projectId
    name: $name
    rollouts: $rollouts
  ) {
    ok
    message
    errorCode
  }
}"""

Expand Down
35 changes: 25 additions & 10 deletions literalai/api/prompt_helpers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Dict, List, Optional
from typing import TYPE_CHECKING, Dict, List, Optional, TypedDict

from literalai.observability.generation import GenerationMessage
from literalai.prompt_engineering.prompt import Prompt, ProviderSettings
Expand Down Expand Up @@ -61,16 +61,31 @@ def process_response(response):
return gql.GET_PROMPT_VERSION, description, variables, process_response


class PromptRollout(TypedDict):
    """Percentage rollout assigned to one version of a prompt lineage."""

    version: int  # prompt version number within the lineage
    rollout: int  # percentage of traffic (0-100) routed to this version


def get_prompt_ab_testing_helper(
    name: Optional[str] = None,
):
    """
    Build the GraphQL request for reading a lineage's A/B testing config.

    Args:
        name (Optional[str]): The prompt lineage name.

    Returns:
        Tuple of (query, description, variables, process_response) consumed
        by the generic gql_helper executor.
    """
    # Lookup is by lineage name only; no lineageId/version variables.
    variables = {"lineageName": name}

    def process_response(response) -> List[PromptRollout]:
        # Unwrap the relay-style connection into a flat list of nodes.
        response_data = response["data"]["promptLineageRollout"]
        return [edge["node"] for edge in response_data["edges"]]

    description = "get prompt A/B testing"

    return gql.GET_PROMPT_AB_TESTING, description, variables, process_response


def update_prompt_ab_testing_helper(name: str, rollouts: List[PromptRollout]):
    """
    Build the GraphQL request for updating a lineage's A/B testing config.

    Args:
        name (str): The prompt lineage name.
        rollouts (List[PromptRollout]): Percentage rollout per prompt version.

    Returns:
        Tuple of (mutation, description, variables, process_response) consumed
        by the generic gql_helper executor.
    """
    variables = {"name": name, "rollouts": rollouts}

    def process_response(response) -> Dict:
        # The mutation returns an ok/message/errorCode status payload.
        return response["data"]["updatePromptLineageRollout"]

    description = "update prompt A/B testing"

    return gql.UPDATE_PROMPT_AB_TESTING, description, variables, process_response
10 changes: 1 addition & 9 deletions literalai/prompt_engineering/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@
from importlib.metadata import version
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional

from typing_extensions import deprecated, TypedDict

import chevron
from typing_extensions import TypedDict, deprecated

if TYPE_CHECKING:
from literalai.api import LiteralAPI
Expand Down Expand Up @@ -117,13 +116,6 @@ def from_dict(cls, api: "LiteralAPI", prompt_dict: PromptDict) -> "Prompt":
variables_default_values=prompt_dict.get("variablesDefaultValues"),
)

def promote(self) -> "Prompt":
    """
    Promotes this prompt to champion.
    """
    # NOTE(review): the surrounding diff shows this method being removed in
    # favor of the A/B testing API (update_prompt_ab_testing); the
    # api.promote_prompt call it relies on no longer exists after this commit.
    self.api.promote_prompt(self.name, self.version)
    return self

def format_messages(self, **kwargs: Any) -> List[Any]:
"""
Formats the prompt's template messages with the given variables.
Expand Down
47 changes: 37 additions & 10 deletions tests/e2e/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
import secrets
import time
import uuid
from typing import List

import pytest

from literalai import AsyncLiteralClient, LiteralClient
from literalai.context import active_steps_var
from literalai.observability.generation import ChatGeneration
from literalai.observability.generation import ChatGeneration, GenerationMessage
from literalai.observability.thread import Thread

"""
Expand Down Expand Up @@ -384,7 +385,6 @@ def step_decorated():
async def test_nested_run_steps(
self, client: LiteralClient, async_client: AsyncLiteralClient
):

@async_client.run(name="foo")
def run_decorated():
s = async_client.get_current_step()
Expand Down Expand Up @@ -627,16 +627,43 @@ async def test_prompt(self, async_client: AsyncLiteralClient):
assert messages[0]["content"] == expected

@pytest.mark.timeout(5)
async def test_prompt_ab_testing(self, client: LiteralClient):
    """E2E: configure and read back A/B testing rollouts for a prompt lineage."""
    prompt_name = "Python SDK E2E Tests"

    v0: List[GenerationMessage] = [{"role": "user", "content": "Hello"}]
    v1: List[GenerationMessage] = [{"role": "user", "content": "Hello 2"}]

    prompt_v0 = client.api.get_or_create_prompt(
        name=prompt_name,
        template_messages=v0,
    )

    # Route 100% of traffic to version 0, then verify the stored config.
    client.api.update_prompt_ab_testing(
        prompt_v0.name, rollouts=[{"version": 0, "rollout": 100}]
    )

    ab_testing = client.api.get_prompt_ab_testing(name=prompt_v0.name)
    assert len(ab_testing) == 1
    assert ab_testing[0]["version"] == 0
    assert ab_testing[0]["rollout"] == 100

    # Different template messages create a second version on the same lineage.
    prompt_v1 = client.api.get_or_create_prompt(
        name=prompt_name,
        template_messages=v1,
    )

    # Split traffic 60/40 across the two versions and verify.
    client.api.update_prompt_ab_testing(
        name=prompt_v1.name,
        rollouts=[{"version": 0, "rollout": 60}, {"version": 1, "rollout": 40}],
    )

    ab_testing = client.api.get_prompt_ab_testing(name=prompt_v1.name)

    assert len(ab_testing) == 2
    assert ab_testing[0]["version"] == 0
    assert ab_testing[0]["rollout"] == 60
    assert ab_testing[1]["version"] == 1
    assert ab_testing[1]["rollout"] == 40

@pytest.mark.timeout(5)
async def test_gracefulness(self, broken_client: LiteralClient):
Expand Down

0 comments on commit 4a0b063

Please sign in to comment.