Skip to content

Commit

Permalink
75 manager sampler add minimum tokens per second accepted (#94)
Browse files Browse the repository at this point in the history
* fix names

* fix names

* tested code

* * Added try and exceptions
* Enhanced TimeoutHander construction.
* Added type into the timeout.
  • Loading branch information
AguirreNicolas authored Aug 1, 2024
1 parent 0bf077d commit 0efbae8
Show file tree
Hide file tree
Showing 5 changed files with 181 additions and 4 deletions.
27 changes: 26 additions & 1 deletion apps/python/sampler/activities/lmeh/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@
from packages.python.lmeh.utils import sql as lmeh_sql
from packages.python.lmeh.utils import task_config as open_llm_config
from packages.python.lmeh.utils.common import get_task_manager
from packages.python.protocol.protocol import PocketNetworkTaskRequest
from packages.python.protocol.protocol import (
LLMTimeouts,
PocketNetworkTaskRequest,
TimeoutHandler,
)


@activity.defn
Expand All @@ -19,6 +23,26 @@ async def lmeh_sample(args: PocketNetworkTaskRequest) -> bool:
eval_logger = get_app_logger("sample")
config = get_app_config()["config"]
wf_id = activity.info().workflow_id
# check if config has timeouts
if "timeouts" in config:
try:
timeouts = LLMTimeouts(**config["timeouts"][args.requester_args.service])
timeout_handler = TimeoutHandler(timeouts=timeouts)
except Exception as e:
eval_logger.error(
"Error creating TimeoutHandler",
error=e,
timeouts=config["timeouts"],
service=args.requester_args.service,
)
raise ApplicationError(
"Error creating TimeoutHandler",
str(e),
type="TimeoutHandler",
non_retryable=True,
)
else:
timeout_handler = TimeoutHandler()

eval_logger.info(
"Starting activity lmeh_sample",
Expand Down Expand Up @@ -139,6 +163,7 @@ async def lmeh_sample(args: PocketNetworkTaskRequest) -> bool:
mongo_client=mongo_client,
args=args,
eval_logger=eval_logger,
timeout_handler=timeout_handler,
)
eval_logger.info("LM generated successfully.")

Expand Down
20 changes: 19 additions & 1 deletion docker-compose/dev/apps/config/sampler.json
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,23 @@
"port": 7233,
"namespace": "pocket-ml-testbench",
"task_queue": "sampler"
}
},
"timeouts": {
"random": {
"ttft": {
"x": [
0,
8192,
32768
],
"y": [
0,
2,
10
]
},
"tpot": 0.336,
"queue": 30
}
}
}
49 changes: 48 additions & 1 deletion docker-compose/morse-poc/apps_configs/sampler.json
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,52 @@
"max_concurrent_workflow_tasks": 4,
"max_concurrent_workflow_task_polls": 4,
"max_concurrent_activity_task_polls": 4
}
},
"timeouts": {
"A100": {
"ttft": {
"prompt_lenght": [0, 8192, 32768],
"sla_time": [0, 2, 10]
},
"tpot": 0.336,
"queue": 30,
"type": "llm"
},
"A101": {
"ttft": {
"prompt_lenght": [0, 8192, 32768],
"sla_time": [0, 2, 10]
},
"tpot": 0.336,
"queue": 30,
"type": "llm"
},
"A102": {
"ttft": {
"prompt_lenght": [0, 8192, 32768],
"sla_time": [0, 2, 10]
},
"tpot": 0.336,
"queue": 30,
"type": "llm"
},
"A103": {
"ttft": {
"prompt_lenght": [0, 8192, 32768],
"sla_time": [0, 2, 10]
},
"tpot": 0.336,
"queue": 30,
"type": "llm"
}
},
"A1FF": {
"ttft": {
"prompt_lenght": [0, 8192, 32768],
"sla_time": [0, 2, 10]
},
"tpot": 0.336,
"queue": 30,
"type": "llm"
}
}
19 changes: 19 additions & 0 deletions packages/python/lmeh/utils/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
PocketNetworkMongoDBResultNumerical,
PocketNetworkMongoDBTask,
PocketNetworkTaskRequest,
TimeoutHandler,
)


Expand Down Expand Up @@ -162,6 +163,7 @@ async def generate_requests(
apply_chat_template: bool = False,
fewshot_as_multiturn: bool = False,
eval_logger: Optional[logging.Logger] = None,
timeout_handler=TimeoutHandler,
):
"""Generate and save in mongoDB: Task->Instances->Prompts
Expand Down Expand Up @@ -309,13 +311,30 @@ async def generate_requests(
exclude_defaults=True,
exclude={"ctxlen", "context_enc", "continuation_enc"},
)
# Timeout
prefill = pocket_req.ctxlen
decode = (
lm.max_gen_toks if instance.request_type == "generate_until" else 2
)
timeout = int(
timeout_handler.get_timeout(prefill=prefill, decode=decode)
)
eval_logger.debug(
"Timeout:",
timeout=timeout,
prefill=pocket_req.ctxlen,
decode=decode,
request_type=instance.request_type,
)
# Prompt
prompt_mongo = PocketNetworkMongoDBPrompt(
data=data,
task_id=task_mongodb.id,
instance_id=instance_id,
ctxlen=pocket_req.ctxlen,
context_enc=pocket_req.context_enc,
continuation_enc=pocket_req.continuation_enc,
timeout=timeout,
)
insert_mongo_prompts.append(prompt_mongo.model_dump(by_alias=True))
try:
Expand Down
70 changes: 69 additions & 1 deletion packages/python/protocol/protocol.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import time
import uuid
from datetime import datetime
from typing import Dict, List, Literal, Optional, Union
from typing import Any, Callable, Dict, List, Literal, Optional, Union

from bson import ObjectId
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
Expand Down Expand Up @@ -310,3 +310,71 @@ class PocketNetworkMongoDBConfig(BaseModel):

class Config:
arbitrary_types_allowed = True


######################
# TIMEOUT HANDLER
######################


class TTFT(BaseModel):
prompt_lenght: List[int]
sla_time: List[int]


class LLMTimeouts(BaseModel):
ttft: TTFT
tpot: float
queue: float
type: str = "llm"


class TimeoutHandler(BaseModel):
model_config = ConfigDict(extra="allow")
timeouts: Optional[LLMTimeouts] = None

def llm_timeout(self, prefill: int, decode: int) -> float:
timeout = self.ttft(prefill) + (self.tpot * decode) + self.queue
return float(timeout)

# whenever a new timeout type is added, add a new function here

def chain_default(self, prefill: int, decode: int) -> float:
return 60

@model_validator(mode="after")
def map_timeouts(self):
if self.timeouts:
chain_timeouts: Dict[str, Callable[[], str]] = {
"llm": self.llm_timeout,
}
self._timeout_fn = chain_timeouts.get(
self.timeouts.type, self.chain_default
)
else:
# if timeouts are not defined, means default
self._timeout_fn = self.chain_default
return

def model_post_init(self, __context: Any) -> None:
if self.timeouts is None:
# if timeouts are not defined, means default
return
if self.timeouts.type == "llm":
try:
import numpy as np

x = self.timeouts.ttft.prompt_lenght
y = self.timeouts.ttft.sla_time
self.queue = self.timeouts.queue
z = np.polyfit(x, y, 2)
self.ttft = np.poly1d(z)
self.tpot = self.timeouts.tpot
except Exception as e:
raise ValueError(f"Error creating timeout function: {e}")
# whenever a new timeout type is added, add new post init here
# to define attributes.
return

def get_timeout(self, **kwargs) -> float:
return self._timeout_fn(**kwargs)

0 comments on commit 0efbae8

Please sign in to comment.