Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

75 manager sampler add minimum tokens per second accepted #94

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion apps/python/sampler/activities/lmeh/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@
from packages.python.lmeh.utils import sql as lmeh_sql
from packages.python.lmeh.utils import task_config as open_llm_config
from packages.python.lmeh.utils.common import get_task_manager
from packages.python.protocol.protocol import PocketNetworkTaskRequest
from packages.python.protocol.protocol import (
LLMTimeouts,
PocketNetworkTaskRequest,
TimeoutHandler,
)


@activity.defn
Expand All @@ -19,6 +23,26 @@ async def lmeh_sample(args: PocketNetworkTaskRequest) -> bool:
eval_logger = get_app_logger("sample")
config = get_app_config()["config"]
wf_id = activity.info().workflow_id
# check if config has timeouts
if "timeouts" in config:
try:
timeouts = LLMTimeouts(**config["timeouts"][args.requester_args.service])
timeout_handler = TimeoutHandler(timeouts=timeouts)
except Exception as e:
eval_logger.error(
"Error creating TimeoutHandler",
error=e,
timeouts=config["timeouts"],
service=args.requester_args.service,
)
raise ApplicationError(
"Error creating TimeoutHandler",
str(e),
type="TimeoutHandler",
non_retryable=True,
)
else:
timeout_handler = TimeoutHandler()

eval_logger.info(
"Starting activity lmeh_sample",
Expand Down Expand Up @@ -139,6 +163,7 @@ async def lmeh_sample(args: PocketNetworkTaskRequest) -> bool:
mongo_client=mongo_client,
args=args,
eval_logger=eval_logger,
timeout_handler=timeout_handler,
)
eval_logger.info("LM generated successfully.")

Expand Down
20 changes: 19 additions & 1 deletion docker-compose/dev/apps/config/sampler.json
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,23 @@
"port": 7233,
"namespace": "pocket-ml-testbench",
"task_queue": "sampler"
}
},
"timeouts": {
RawthiL marked this conversation as resolved.
Show resolved Hide resolved
"random": {
"ttft": {
"x": [
RawthiL marked this conversation as resolved.
Show resolved Hide resolved
0,
8192,
32768
],
"y": [
0,
2,
10
]
},
"tpot": 0.336,
"queue": 30
}
}
}
49 changes: 48 additions & 1 deletion docker-compose/morse-poc/apps_configs/sampler.json
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,52 @@
"max_concurrent_workflow_tasks": 4,
"max_concurrent_workflow_task_polls": 4,
"max_concurrent_activity_task_polls": 4
}
},
"timeouts": {
AguirreNicolas marked this conversation as resolved.
Show resolved Hide resolved
"A100": {
"ttft": {
"prompt_lenght": [0, 8192, 32768],
"sla_time": [0, 2, 10]
},
"tpot": 0.336,
"queue": 30,
"type": "llm"
},
"A101": {
"ttft": {
"prompt_lenght": [0, 8192, 32768],
"sla_time": [0, 2, 10]
},
"tpot": 0.336,
"queue": 30,
"type": "llm"
},
"A102": {
"ttft": {
"prompt_lenght": [0, 8192, 32768],
"sla_time": [0, 2, 10]
},
"tpot": 0.336,
"queue": 30,
"type": "llm"
},
"A103": {
"ttft": {
"prompt_lenght": [0, 8192, 32768],
"sla_time": [0, 2, 10]
},
"tpot": 0.336,
"queue": 30,
"type": "llm"
}
},
"A1FF": {
"ttft": {
"prompt_lenght": [0, 8192, 32768],
"sla_time": [0, 2, 10]
},
"tpot": 0.336,
"queue": 30,
"type": "llm"
}
}
19 changes: 19 additions & 0 deletions packages/python/lmeh/utils/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
PocketNetworkMongoDBResultNumerical,
PocketNetworkMongoDBTask,
PocketNetworkTaskRequest,
TimeoutHandler,
)


Expand Down Expand Up @@ -162,6 +163,7 @@ async def generate_requests(
apply_chat_template: bool = False,
fewshot_as_multiturn: bool = False,
eval_logger: Optional[logging.Logger] = None,
timeout_handler=TimeoutHandler,
):
"""Generate and save in mongoDB: Task->Instances->Prompts

Expand Down Expand Up @@ -309,13 +311,30 @@ async def generate_requests(
exclude_defaults=True,
exclude={"ctxlen", "context_enc", "continuation_enc"},
)
# Timeout
prefill = pocket_req.ctxlen
decode = (
lm.max_gen_toks if instance.request_type == "generate_until" else 2
)
timeout = int(
timeout_handler.get_timeout(prefill=prefill, decode=decode)
)
eval_logger.debug(
"Timeout:",
timeout=timeout,
prefill=pocket_req.ctxlen,
decode=decode,
request_type=instance.request_type,
)
# Prompt
prompt_mongo = PocketNetworkMongoDBPrompt(
data=data,
task_id=task_mongodb.id,
instance_id=instance_id,
ctxlen=pocket_req.ctxlen,
context_enc=pocket_req.context_enc,
continuation_enc=pocket_req.continuation_enc,
timeout=timeout,
)
insert_mongo_prompts.append(prompt_mongo.model_dump(by_alias=True))
try:
Expand Down
70 changes: 69 additions & 1 deletion packages/python/protocol/protocol.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import time
import uuid
from datetime import datetime
from typing import Dict, List, Literal, Optional, Union
from typing import Any, Callable, Dict, List, Literal, Optional, Union

from bson import ObjectId
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
Expand Down Expand Up @@ -310,3 +310,71 @@ class PocketNetworkMongoDBConfig(BaseModel):

class Config:
arbitrary_types_allowed = True


######################
# TIMEOUT HANDLER
######################


class TTFT(BaseModel):
prompt_lenght: List[int]
sla_time: List[int]


class LLMTimeouts(BaseModel):
ttft: TTFT
tpot: float
queue: float
type: str = "llm"


class TimeoutHandler(BaseModel):
model_config = ConfigDict(extra="allow")
timeouts: Optional[LLMTimeouts] = None

def llm_timeout(self, prefill: int, decode: int) -> float:
timeout = self.ttft(prefill) + (self.tpot * decode) + self.queue
return float(timeout)

# whenever a new timeout type is added, add a new function here

def chain_default(self, prefill: int, decode: int) -> float:
return 60

@model_validator(mode="after")
def map_timeouts(self):
if self.timeouts:
chain_timeouts: Dict[str, Callable[[], str]] = {
"llm": self.llm_timeout,
}
self._timeout_fn = chain_timeouts.get(
self.timeouts.type, self.chain_default
)
else:
# if timeouts are not defined, means default
self._timeout_fn = self.chain_default
return

def model_post_init(self, __context: Any) -> None:
if self.timeouts is None:
# if timeouts are not defined, means default
return
if self.timeouts.type == "llm":
try:
import numpy as np

x = self.timeouts.ttft.prompt_lenght
y = self.timeouts.ttft.sla_time
self.queue = self.timeouts.queue
z = np.polyfit(x, y, 2)
self.ttft = np.poly1d(z)
self.tpot = self.timeouts.tpot
except Exception as e:
raise ValueError(f"Error creating timeout function: {e}")
# whenever a new timeout type is added, add new post init here
# to define attributes.
return

def get_timeout(self, **kwargs) -> float:
return self._timeout_fn(**kwargs)
Loading