Skip to content

Commit

Permalink
[Serve] Support headers in Readiness Probe (#3552)
Browse files Browse the repository at this point in the history
* inti

* probe_str remobve heades and delete env vars

* remove header values in replica manager logging
  • Loading branch information
cblmemo authored May 21, 2024
1 parent 7be7f6a commit cf840dc
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 6 deletions.
42 changes: 42 additions & 0 deletions llm/vllm/service-with-auth.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# service.yaml
# The newly-added `service` section to the `serve-openai-api.yaml` file.
service:
# Specifying the path to the endpoint to check the readiness of the service.
readiness_probe:
path: /v1/models
# Set authorization headers here if needed.
headers:
Authorization: Bearer $AUTH_TOKEN
# How many replicas to manage.
replicas: 1

# Fields below are the same with `serve-openai-api.yaml`.
envs:
MODEL_NAME: meta-llama/Llama-2-7b-chat-hf
HF_TOKEN: # TODO: Fill with your own huggingface token, or use --env to pass.
AUTH_TOKEN: # TODO: Fill with your own auth token (a random string), or use --env to pass.

resources:
accelerators: {L4:1, A10G:1, A10:1, A100:1, A100-80GB:1}
ports: 8000

setup: |
conda activate vllm
if [ $? -ne 0 ]; then
conda create -n vllm python=3.10 -y
conda activate vllm
fi
pip install transformers==4.38.0
pip install vllm==0.3.2
python -c "import huggingface_hub; huggingface_hub.login('${HF_TOKEN}')"
run: |
conda activate vllm
echo 'Starting vllm openai api server...'
python -m vllm.entrypoints.openai.api_server \
--model $MODEL_NAME --tokenizer hf-internal-testing/llama-tokenizer \
--host 0.0.0.0 --port 8000 --api-key $AUTH_TOKEN
19 changes: 16 additions & 3 deletions sky/serve/replica_managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,7 @@ def probe(
self,
readiness_path: str,
post_data: Optional[Dict[str, Any]],
headers: Optional[Dict[str, str]],
) -> Tuple['ReplicaInfo', bool, float]:
"""Probe the readiness of the replica.
Expand All @@ -513,12 +514,14 @@ def probe(
msg += 'POST'
response = requests.post(
readiness_path,
headers=headers,
json=post_data,
timeout=serve_constants.READINESS_PROBE_TIMEOUT_SECONDS)
else:
msg += 'GET'
response = requests.get(
readiness_path,
headers=headers,
timeout=serve_constants.READINESS_PROBE_TIMEOUT_SECONDS)
msg += (f' request to {replica_identity} returned status '
f'code {response.status_code}')
Expand Down Expand Up @@ -565,9 +568,13 @@ def __init__(self, service_name: str,
self._service_name: str = service_name
self._uptime: Optional[float] = None
self._update_mode = serve_utils.DEFAULT_UPDATE_MODE
header_keys = None
if spec.readiness_headers is not None:
header_keys = list(spec.readiness_headers.keys())
logger.info(f'Readiness probe path: {spec.readiness_path}\n'
f'Initial delay seconds: {spec.initial_delay_seconds}\n'
f'Post data: {spec.post_data}')
f'Post data: {spec.post_data}\n'
f'Readiness header keys: {header_keys}')

# Newest version among the currently provisioned and launched replicas
self.latest_version: int = serve_constants.INITIAL_VERSION
Expand Down Expand Up @@ -1033,8 +1040,11 @@ def _probe_all_replicas(self) -> None:
probe_futures.append(
pool.apply_async(
info.probe,
(self._get_readiness_path(
info.version), self._get_post_data(info.version)),
(
self._get_readiness_path(info.version),
self._get_post_data(info.version),
self._get_readiness_headers(info.version),
),
),)
logger.info(f'Replicas to probe: {", ".join(replica_to_probe)}')

Expand Down Expand Up @@ -1215,5 +1225,8 @@ def _get_readiness_path(self, version: int) -> str:
def _get_post_data(self, version: int) -> Optional[Dict[str, Any]]:
return self._get_version_spec(version).post_data

def _get_readiness_headers(self, version: int) -> Optional[Dict[str, str]]:
return self._get_version_spec(version).readiness_headers

def _get_initial_delay_seconds(self, version: int) -> int:
return self._get_version_spec(version).initial_delay_seconds
18 changes: 16 additions & 2 deletions sky/serve/service_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def __init__(
max_replicas: Optional[int] = None,
target_qps_per_replica: Optional[float] = None,
post_data: Optional[Dict[str, Any]] = None,
readiness_headers: Optional[Dict[str, str]] = None,
dynamic_ondemand_fallback: Optional[bool] = None,
base_ondemand_fallback_replicas: Optional[int] = None,
upscale_delay_seconds: Optional[int] = None,
Expand Down Expand Up @@ -81,6 +82,7 @@ def __init__(
self._max_replicas: Optional[int] = max_replicas
self._target_qps_per_replica: Optional[float] = target_qps_per_replica
self._post_data: Optional[Dict[str, Any]] = post_data
self._readiness_headers: Optional[Dict[str, str]] = readiness_headers
self._dynamic_ondemand_fallback: Optional[
bool] = dynamic_ondemand_fallback
self._base_ondemand_fallback_replicas: Optional[
Expand Down Expand Up @@ -111,11 +113,13 @@ def from_yaml_config(config: Dict[str, Any]) -> 'SkyServiceSpec':
service_config['readiness_path'] = readiness_section
initial_delay_seconds = None
post_data = None
readiness_headers = None
else:
service_config['readiness_path'] = readiness_section['path']
initial_delay_seconds = readiness_section.get(
'initial_delay_seconds', None)
post_data = readiness_section.get('post_data', None)
readiness_headers = readiness_section.get('headers', None)
if initial_delay_seconds is None:
initial_delay_seconds = constants.DEFAULT_INITIAL_DELAY_SECONDS
service_config['initial_delay_seconds'] = initial_delay_seconds
Expand All @@ -129,6 +133,7 @@ def from_yaml_config(config: Dict[str, Any]) -> 'SkyServiceSpec':
'`readiness_probe` section of your service YAML.'
) from e
service_config['post_data'] = post_data
service_config['readiness_headers'] = readiness_headers

policy_section = config.get('replica_policy', None)
simplified_policy_section = config.get('replicas', None)
Expand Down Expand Up @@ -204,6 +209,7 @@ def add_if_not_none(section, key, value, no_empty: bool = False):
add_if_not_none('readiness_probe', 'initial_delay_seconds',
self.initial_delay_seconds)
add_if_not_none('readiness_probe', 'post_data', self.post_data)
add_if_not_none('readiness_probe', 'headers', self._readiness_headers)
add_if_not_none('replica_policy', 'min_replicas', self.min_replicas)
add_if_not_none('replica_policy', 'max_replicas', self.max_replicas)
add_if_not_none('replica_policy', 'target_qps_per_replica',
Expand All @@ -220,8 +226,12 @@ def add_if_not_none(section, key, value, no_empty: bool = False):

def probe_str(self):
if self.post_data is None:
return f'GET {self.readiness_path}'
return f'POST {self.readiness_path} {json.dumps(self.post_data)}'
method = f'GET {self.readiness_path}'
else:
method = f'POST {self.readiness_path} {json.dumps(self.post_data)}'
headers = ('' if self.readiness_headers is None else
' with custom headers')
return f'{method}{headers}'

def spot_policy_str(self):
policy_strs = []
Expand Down Expand Up @@ -287,6 +297,10 @@ def target_qps_per_replica(self) -> Optional[float]:
def post_data(self) -> Optional[Dict[str, Any]]:
return self._post_data

@property
def readiness_headers(self) -> Optional[Dict[str, str]]:
return self._readiness_headers

@property
def base_ondemand_fallback_replicas(self) -> Optional[int]:
return self._base_ondemand_fallback_replicas
Expand Down
8 changes: 7 additions & 1 deletion sky/utils/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,13 @@ def get_service_schema():
}, {
'type': 'object',
}]
}
},
'headers': {
'type': 'object',
'additionalProperties': {
'type': 'string'
}
},
}
}]
},
Expand Down

0 comments on commit cf840dc

Please sign in to comment.