Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

75 manager sampler add minimum tokens per second accepted #94

Merged
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions apps/python/sampler/activities/lmeh/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@
from packages.python.lmeh.utils import sql as lmeh_sql
from packages.python.lmeh.utils import task_config as open_llm_config
from packages.python.lmeh.utils.common import get_task_manager
from packages.python.protocol.protocol import PocketNetworkTaskRequest
from packages.python.protocol.protocol import (
LLMTimeouts,
PocketNetworkTaskRequest,
TimeoutHandler,
)


@activity.defn
Expand All @@ -19,7 +23,10 @@ async def lmeh_sample(args: PocketNetworkTaskRequest) -> bool:
eval_logger = get_app_logger("sample")
config = get_app_config()["config"]
wf_id = activity.info().workflow_id

timeouts = LLMTimeouts(**config["timeouts"][args.requester_args.service])
AguirreNicolas marked this conversation as resolved.
Show resolved Hide resolved
timeout_handler = TimeoutHandler(
service=args.requester_args.service, timeouts=timeouts
)
eval_logger.info(
"Starting activity lmeh_sample",
task_name=args.tasks,
Expand Down Expand Up @@ -139,6 +146,7 @@ async def lmeh_sample(args: PocketNetworkTaskRequest) -> bool:
mongo_client=mongo_client,
args=args,
eval_logger=eval_logger,
timeout_handler=timeout_handler,
)
eval_logger.info("LM generated successfully.")

Expand Down
20 changes: 19 additions & 1 deletion docker-compose/dev/apps/config/sampler.json
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,23 @@
"port": 7233,
"namespace": "pocket-ml-testbench",
"task_queue": "sampler"
}
},
"timeouts": {
RawthiL marked this conversation as resolved.
Show resolved Hide resolved
"random": {
"ttft": {
"x": [
RawthiL marked this conversation as resolved.
Show resolved Hide resolved
0,
8192,
32768
],
"y": [
0,
2,
10
]
},
"tpot": 0.336,
"queue": 30
}
}
}
84 changes: 83 additions & 1 deletion docker-compose/morse-poc/apps_configs/sampler.json
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,87 @@
"max_concurrent_workflow_tasks": 4,
"max_concurrent_workflow_task_polls": 4,
"max_concurrent_activity_task_polls": 4
}
},
"timeouts": {
AguirreNicolas marked this conversation as resolved.
Show resolved Hide resolved
"A100": {
"ttft": {
"x": [
0,
8192,
32768
],
"y": [
0,
2,
10
]
},
"tpot": 0.336,
"queue": 30
},
"A101": {
"ttft": {
"x": [
0,
8192,
32768
],
"y": [
0,
2,
10
]
},
"tpot": 0.336,
"queue": 30
},
"A102": {
"ttft": {
"x": [
0,
8192,
32768
],
"y": [
0,
2,
10
]
},
"tpot": 0.336,
"queue": 30
},
"A103": {
"ttft": {
"x": [
0,
8192,
32768
],
"y": [
0,
2,
10
]
},
"tpot": 0.336,
"queue": 30
}
},
"A1FF": {
"ttft": {
"x": [
0,
8192,
32768
],
"y": [
0,
2,
10
]
},
"tpot": 0.336,
"queue": 30
}
}
19 changes: 19 additions & 0 deletions packages/python/lmeh/utils/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
PocketNetworkMongoDBResultNumerical,
PocketNetworkMongoDBTask,
PocketNetworkTaskRequest,
TimeoutHandler,
)


Expand Down Expand Up @@ -162,6 +163,7 @@ async def generate_requests(
apply_chat_template: bool = False,
fewshot_as_multiturn: bool = False,
eval_logger: Optional[logging.Logger] = None,
timeout_handler=TimeoutHandler,
):
"""Generate and save in mongoDB: Task->Instances->Prompts

Expand Down Expand Up @@ -309,13 +311,30 @@ async def generate_requests(
exclude_defaults=True,
exclude={"ctxlen", "context_enc", "continuation_enc"},
)
# Timeout
prefill = pocket_req.ctxlen
decode = (
lm.max_gen_toks if instance.request_type == "generate_until" else 2
)
timeout = int(
timeout_handler.get_timeout(prefill=prefill, decode=decode)
)
eval_logger.debug(
"Timeout:",
timeout=timeout,
prefill=pocket_req.ctxlen,
decode=decode,
request_type=instance.request_type,
)
# Prompt
prompt_mongo = PocketNetworkMongoDBPrompt(
data=data,
task_id=task_mongodb.id,
instance_id=instance_id,
ctxlen=pocket_req.ctxlen,
context_enc=pocket_req.context_enc,
continuation_enc=pocket_req.continuation_enc,
timeout=timeout,
)
insert_mongo_prompts.append(prompt_mongo.model_dump(by_alias=True))
try:
Expand Down
62 changes: 61 additions & 1 deletion packages/python/protocol/protocol.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import time
import uuid
from datetime import datetime
from typing import Dict, List, Literal, Optional, Union
from typing import Callable, Dict, List, Literal, Optional, Union

import numpy as np
from bson import ObjectId
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator

Expand Down Expand Up @@ -310,3 +311,62 @@ class PocketNetworkMongoDBConfig(BaseModel):

class Config:
arbitrary_types_allowed = True


######################
# TIMEOUT HANDLER
######################


class TTFT(BaseModel):
x: List[int]
y: List[int]


class LLMTimeouts(BaseModel):
ttft: TTFT
tpot: float
queue: float


class TimeoutHandler(BaseModel):
service: str
timeouts: LLMTimeouts

# Define the functions for each chain
def timeouts_a100(self, prefill: int, decode: int) -> float:
x = self.timeouts.ttft.x
y = self.timeouts.ttft.y
queue = self.timeouts.queue
z = np.polyfit(x, y, 2)
AguirreNicolas marked this conversation as resolved.
Show resolved Hide resolved
ttft = np.poly1d(z)
tpot = self.timeouts.tpot
timeout = ttft(prefill) + (tpot * decode) + queue
return float(timeout)

# In case of new chains redefine the functions below
def timeouts_a101(self, prefill: int, decode: int) -> float:
return self.timeouts_a100(prefill, decode)

def timeouts_a102(self, prefill: int, decode: int) -> float:
return self.timeouts_a100(prefill, decode)

def timeouts_a103(self, prefill: int, decode: int) -> float:
return self.timeouts_a100(prefill, decode)

def chain_default(self, prefill: int, decode: int) -> float:
return 60

@model_validator(mode="after")
def map_timeouts(self):
chain_timeouts: Dict[str, Callable[[], str]] = {
"A100": self.timeouts_a100,
"A101": self.timeouts_a101,
"A102": self.timeouts_a102,
"A103": self.timeouts_a103,
}
self._timeout_fn = chain_timeouts.get(self.service, self.chain_default)
return

def get_timeout(self, prefill: int, decode: int) -> float:
return self._timeout_fn(prefill, decode)
Loading