
[Serve] Proxy w/ retry #3395

Merged: 34 commits from serve-proxy-prototype into master on May 14, 2024

Conversation

cblmemo (Collaborator) commented Mar 31, 2024

Use a proxy on our load balancer, with retry. This is useful for spot-based serving.
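
For intuition, here is a minimal sketch of the idea (not the actual SkyPilot implementation; `READY_REPLICA_URLS`, the catch-all route, and the retry policy are illustrative): the load balancer proxies each incoming request to a ready replica with `httpx`, and tries the next replica if the connection fails before a response is obtained.

```python
import fastapi
import httpx
import uvicorn
from starlette.background import BackgroundTask

# Normally synced from the controller; hard-coded here for illustration.
READY_REPLICA_URLS = ['http://0.0.0.0:7001', 'http://0.0.0.0:7002']

app = fastapi.FastAPI()
client = httpx.AsyncClient(timeout=None)


@app.api_route('/{path:path}', methods=['GET', 'POST'])
async def proxy_with_retry(request: fastapi.Request, path: str):
    body = await request.body()
    for replica_url in READY_REPLICA_URLS:  # retry over ready replicas
        try:
            proxy_request = client.build_request(
                request.method, f'{replica_url}/{path}',
                headers=request.headers.raw, content=body)
            response = await client.send(proxy_request, stream=True)
            # Stream the replica's response back to the client; close the
            # upstream connection once the client is done (or disconnects).
            return fastapi.responses.StreamingResponse(
                response.aiter_raw(),
                status_code=response.status_code,
                headers=dict(response.headers),
                background=BackgroundTask(response.aclose))
        except httpx.RequestError:
            continue  # this replica failed before responding; try the next one
    raise fastapi.HTTPException(status_code=503, detail='No ready replicas.')


if __name__ == '__main__':
    uvicorn.run(app, host='0.0.0.0', port=7000)
```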

TODO:

  • Test and make sure it works
  • Benchmark, compare w/ previous implementation
  • Support streaming
  • Make it more robust
  • Update documentation

Tested (run the relevant ones):

  • Code formatting: bash format.sh
  • Any manual or new tests for this PR (please specify below)
```python
import fastapi, uvicorn, asyncio
import multiprocessing
from sky.serve import load_balancer

REPLICA_URLS = []
PROCESSES = []
CONTROLLER_PORT = 20018
WORD_TO_STREAM = 'Hello world! Nice to meet you!'
TIME_TO_SLEEP = 0.2

def _start_streaming_replica(port):
    app = fastapi.FastAPI()

    @app.get('/')
    async def stream():
        async def generate_words():
            for word in WORD_TO_STREAM.split():
                yield word + "\n"
                await asyncio.sleep(TIME_TO_SLEEP)
        
        return fastapi.responses.StreamingResponse(generate_words(), media_type="text/plain")

    @app.get('/non-stream')
    async def non_stream():
        return {'message': WORD_TO_STREAM}

    @app.get('/error')
    async def error():
        raise fastapi.HTTPException(status_code=500, detail='Internal Server Error')

    uvicorn.run(app, host='0.0.0.0', port=port)

def _start_streaming_replica_in_process(port):
    global PROCESSES, REPLICA_URLS
    STREAMING_REPLICA_PROCESS = multiprocessing.Process(target=_start_streaming_replica, args=(port,))
    STREAMING_REPLICA_PROCESS.start()
    PROCESSES.append(STREAMING_REPLICA_PROCESS)
    REPLICA_URLS.append(f'http://0.0.0.0:{port}')

def _start_controller():
    app = fastapi.FastAPI()

    @app.post('/controller/load_balancer_sync')
    async def lb_sync(request: fastapi.Request):
        return {'ready_replica_urls': REPLICA_URLS}

    uvicorn.run(app, host='0.0.0.0', port=CONTROLLER_PORT)

def _start_controller_in_process():
    global PROCESSES
    CONTROLLER_PROCESS = multiprocessing.Process(target=_start_controller)
    CONTROLLER_PROCESS.start()
    PROCESSES.append(CONTROLLER_PROCESS)

if __name__ == '__main__':
    try:
        _start_streaming_replica_in_process(7001)
        _start_streaming_replica_in_process(7002)
        _start_streaming_replica_in_process(7003)
        _start_streaming_replica_in_process(7004)
        _start_controller_in_process()
        lb = load_balancer.SkyServeLoadBalancer(
            controller_url=f'http://0.0.0.0:{CONTROLLER_PORT}',
            load_balancer_port=7000)
        lb.run()
    finally:
        for p in PROCESSES:
            p.terminate()
```
  • All skyserve smoke tests: pytest tests/test_smoke.py --serve
  • Relevant individual smoke tests: pytest tests/test_smoke.py::test_fill_in_the_name
  • Backward compatibility tests: bash tests/backward_comaptibility_tests.sh

@cblmemo cblmemo changed the title [Serve] Proxy prototype w/ LeastConnectionPolicy [Serve][Do not Merge] Proxy prototype w/ LeastConnectionPolicy Mar 31, 2024
@cblmemo cblmemo changed the title [Serve][Do not Merge] Proxy prototype w/ LeastConnectionPolicy [Serve][Do not Merge] Proxy prototype Mar 31, 2024
@cblmemo cblmemo changed the title [Serve][Do not Merge] Proxy prototype [Serve][Do not Merge] Proxy prototype w/ retry Mar 31, 2024
@cblmemo cblmemo changed the title [Serve][Do not Merge] Proxy prototype w/ retry [Serve] Proxy w/ retry May 1, 2024
@cblmemo cblmemo requested a review from Michaelvll May 1, 2024 15:32
cblmemo (Collaborator, Author) commented May 1, 2024

@Michaelvll This is ready for a look now 🫡 I'm still running smoke tests and adding a new streaming test; will report back later.

@cblmemo cblmemo marked this pull request as ready for review May 1, 2024 15:32
Michaelvll (Collaborator) left a comment:

Awesome @cblmemo! Went through the PR and the code mostly looks good to me! Left several comments. I will give it a try soon.

Resolved review threads: docs/source/serving/sky-serve.rst (×3)

@@ -33,7 +33,7 @@ Client disconnected, stopping computation.
You can also run

```bash
-curl -L http://<endpoint>/
+curl http://<endpoint>/
```

and manually Ctrl + C to cancel the request and see logs.
Collaborator:

Can we use an OpenAI API client to stream the output and check if abortion/cancellation work?

cblmemo (Collaborator, Author):

The smoke test test_skyserve_cancel passed. Do you think that is enough?

Collaborator:

I mean we should test the integration with the OpenAI API on both the client and the server (vllm.entrypoints.openai.api_server) sides, to make sure our solution works with the most commonly used library.

cblmemo (Collaborator, Author):

Just tested with the llama2 example (both a service YAML using vLLM and the openai client), with and without streaming, and it works well 🫡
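
For reference, a sketch of this kind of manual check (the endpoint is a placeholder, and the model name follows the Llama-2 YAML later in this thread): stream a chat completion through the service endpoint with the OpenAI client, then Ctrl+C mid-stream to verify the cancellation propagates to the replica.

```python
import openai

# '<endpoint>' is a placeholder for the SkyServe endpoint (host:port).
client = openai.OpenAI(base_url='http://<endpoint>/v1', api_key='EMPTY')

stream = client.chat.completions.create(
    model='meta-llama/Llama-2-7b-chat-hf',
    messages=[{'role': 'user', 'content': 'Tell me a long story.'}],
    stream=True,
)
for chunk in stream:
    # Ctrl+C here; the load balancer should close the upstream connection
    # to the replica so the generation stops as well.
    print(chunk.choices[0].delta.content or '', end='', flush=True)
```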

Collaborator:

Does cancellation work for it as well?

Resolved review threads: sky/serve/README.md, sky/serve/load_balancer.py (×2)

Comment on lines 143 to 145:

```python
response = await self._proxy_request_to(ready_replica_url, request)
if response is not None:
    return response
```
Collaborator:

What will happen if the replica dies before the streaming is finished? Will we retry here?

cblmemo (Collaborator, Author):

Currently we don't. One challenge here is recovering the state of the stream: after a recovery, the new replica needs to resume streaming from the middle so that the client is unaware an error happened. For example, if the full stream is [hello, nice, to, meet, you] and an error happens after [hello, nice], the new replica needs to stream [to, meet, you], not the whole array [hello, nice, to, meet, you]. But how could the new replica know where to start? It might be possible to do some special-case handling for LLM workloads (record some information, ...), but it would be really hard to support general streaming workloads.
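
A toy illustration of the difficulty (purely hypothetical, not code from this PR): resuming would require the proxy to track how much of the stream was already delivered, and the new replica to regenerate exactly the same prefix, which a generic HTTP backend cannot guarantee.

```python
# Delivered to the client before the replica failed.
words_already_sent = ['hello', 'nice']
# What a fresh replica would produce from scratch.
full_stream = ['hello', 'nice', 'to', 'meet', 'you']

def resume_stream(full: list[str], sent: list[str]) -> list[str]:
    # Only valid if the new replica reproduces the exact prefix already sent,
    # which is not guaranteed for general (or sampled LLM) workloads.
    if full[:len(sent)] != sent:
        raise RuntimeError('Cannot resume: the regenerated stream diverged.')
    return full[len(sent):]

print(resume_stream(full_stream, words_already_sent))  # ['to', 'meet', 'you']
```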

Collaborator:

Makes sense. In that case, if an underlying replica is scaled down or preempted, will the client just hang and take a long time to time out? If so, it might be good to include a TODO here to optimize for replica failure.

cblmemo (Collaborator, Author):

It will be closed immediately because the connection is lost:

```
Hello
world!
Nice
to
curl: (18) transfer closed with outstanding read data remaining
```

Contributor:

If you decide to store state:

> ... special case handling for the LLM workloads (record some information, ...) ...

It would be ideal to be able to turn this off, so that one can run with zero retention for security.

cblmemo (Collaborator, Author):

Yeah, definitely! We will provide an option to enable or disable it if we implement this feature. Thanks for the advice!

Resolved review threads: tests/skyserve/streaming/send_streaming_request.py, tests/test_smoke.py

@cblmemo cblmemo requested a review from Michaelvll May 3, 2024 08:44
cblmemo (Collaborator, Author) commented May 3, 2024

All skyserve smoke tests passed 🫡

Michaelvll (Collaborator) left a comment:

Thanks for the update @cblmemo! Mostly looks good to me. Left several comments.

Review threads: sky/serve/load_balancer.py (×4)
Resolved review threads: sky/serve/load_balancing_policies.py, sky/serve/constants.py
@Michaelvll Michaelvll self-requested a review May 6, 2024 21:20
cblmemo (Collaborator, Author) commented May 9, 2024

TODO:

@Michaelvll Michaelvll self-requested a review May 9, 2024 00:08
Comment on lines 99 to 100:

```python
for client in client_to_close:
    asyncio.run(client.aclose())
```
Collaborator:

Could you profile the performance?

cblmemo (Collaborator, Author):

I benchmarked the time to close a client and it is on the order of milliseconds, mostly ranging from 2 to 6 ms. I also changed it to asynchronous execution, running in the background, as a safeguard.
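
A minimal sketch of that change (the names here are illustrative, not the PR's actual code): instead of blocking on `asyncio.run(client.aclose())`, schedule the close on the running event loop so stale clients are cleaned up in the background.

```python
import asyncio
import httpx

async def close_clients_in_background(clients_to_close: list[httpx.AsyncClient]) -> None:
    for client in clients_to_close:
        # Fire-and-forget: closing a client only takes a few milliseconds,
        # but scheduling it keeps the controller-sync loop responsive.
        asyncio.create_task(client.aclose())

async def main() -> None:
    stale = [httpx.AsyncClient(), httpx.AsyncClient()]
    await close_clients_in_background(stale)
    await asyncio.sleep(0.1)  # give the background tasks time to finish

asyncio.run(main())
```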

cblmemo (Collaborator, Author) commented May 9, 2024

Tested cancellation and it works as well. Use the following script to launch the LB & worker, curl http://0.0.0.0:7000/, then Ctrl+C the curl command. The ===========WORKER logging will stop.

UPD: I'll test e2e LLM workloads later.

```python
import fastapi, uvicorn, asyncio
import multiprocessing
from sky.serve import load_balancer

REPLICA_URLS = []
PROCESSES = []
CONTROLLER_PORT = 20018
WORD_TO_STREAM = 'Hello world! Nice to meet you!'
TIME_TO_SLEEP = 0.2

def _start_streaming_replica(port):
    app = fastapi.FastAPI()

    @app.get('/')
    async def stream():
        async def generate_words():
            for word in WORD_TO_STREAM.split()*1000:
                yield word + "\n"
                print('===========WORKER', word)
                await asyncio.sleep(TIME_TO_SLEEP)
        
        return fastapi.responses.StreamingResponse(generate_words(), media_type="text/plain")

    @app.get('/non-stream')
    async def non_stream():
        return {'message': WORD_TO_STREAM}

    @app.get('/error')
    async def error():
        raise fastapi.HTTPException(status_code=500, detail='Internal Server Error')

    uvicorn.run(app, host='0.0.0.0', port=port)

def _start_streaming_replica_in_process(port):
    global PROCESSES, REPLICA_URLS
    STREAMING_REPLICA_PROCESS = multiprocessing.Process(target=_start_streaming_replica, args=(port,))
    STREAMING_REPLICA_PROCESS.start()
    PROCESSES.append(STREAMING_REPLICA_PROCESS)
    REPLICA_URLS.append(f'http://0.0.0.0:{port}')

def _start_controller():
    app = fastapi.FastAPI()
    flip_flop = False

    @app.post('/controller/load_balancer_sync')
    async def lb_sync(request: fastapi.Request):
        return {'ready_replica_urls': REPLICA_URLS}

    uvicorn.run(app, host='0.0.0.0', port=CONTROLLER_PORT)

def _start_controller_in_process():
    global PROCESSES
    CONTROLLER_PROCESS = multiprocessing.Process(target=_start_controller)
    CONTROLLER_PROCESS.start()
    PROCESSES.append(CONTROLLER_PROCESS)

if __name__ == '__main__':
    try:
        _start_streaming_replica_in_process(7001)
        _start_controller_in_process()
        lb = load_balancer.SkyServeLoadBalancer(
            controller_url=f'http://0.0.0.0:{CONTROLLER_PORT}',
            load_balancer_port=7000)
        lb.run()
    finally:
        for p in PROCESSES:
            p.terminate()
```

cblmemo (Collaborator, Author) commented May 9, 2024

Just tested with a modified version of FastChat and cancellation works well. I used the OpenAI client here and manually Ctrl+C'd to abort the request.

YAML I used:

```yaml
service:
  readiness_probe: /v1/models
  replicas: 1

resources:
  ports: 8087
  memory: 32+
  accelerators: L4:1
  disk_size: 1024
  disk_tier: best

envs:
  MODEL_SIZE: 7
  HF_TOKEN: <huggingface-token>

setup: |
  conda activate chatbot
  if [ $? -ne 0 ]; then
    conda create -n chatbot python=3.9 -y
    conda activate chatbot
  fi

  # Install dependencies
  git clone https://github.com/cblmemo/fschat-print-streaming.git fschat
  cd fschat
  git switch print-stream
  pip install -e ".[model_worker,webui]"
  python -c "import huggingface_hub; huggingface_hub.login('${HF_TOKEN}')"

run: |
  conda activate chatbot
  
  echo 'Starting controller...'
  python -u -m fastchat.serve.controller --host 0.0.0.0 > ~/controller.log 2>&1 &
  sleep 10
  echo 'Starting model worker...'
  python -u -m fastchat.serve.model_worker --host 0.0.0.0 \
            --model-path meta-llama/Llama-2-${MODEL_SIZE}b-chat-hf \
            --num-gpus $SKYPILOT_NUM_GPUS_PER_NODE 2>&1 \
            | tee model_worker.log &

  echo 'Waiting for model worker to start...'
  while ! `cat model_worker.log | grep -q 'Uvicorn running on'`; do sleep 1; done

  echo 'Starting openai api server...'
  python -u -m fastchat.serve.openai_api_server --host 0.0.0.0 --port 8087 | tee ~/openai_api_server.log
```

Michaelvll (Collaborator) left a comment:

This is awesome and an important update @cblmemo! The code looks good to me. Once the tests pass, I think it should be good to go.

Resolved review thread: sky/serve/load_balancer.py
cblmemo (Collaborator, Author) commented May 14, 2024

> This is awesome and an important update @cblmemo! The code looks good to me. Once the tests pass, I think it should be good to go.

Thanks! Rerunning all smoke tests now, will merge after all of them pass 🫡

cblmemo (Collaborator, Author) commented May 14, 2024

Fixed a bug introduced in #3484. Merging now

@cblmemo cblmemo merged commit 5a2f1b8 into master May 14, 2024
20 checks passed
@cblmemo cblmemo deleted the serve-proxy-prototype branch May 14, 2024 13:44
Michaelvll (Collaborator):

> Fixed a bug introduced in #3484. Merging now

@cblmemo could you elaborate on the bug a bit for future reference?

cblmemo (Collaborator, Author) commented May 14, 2024

> > Fixed a bug introduced in #3484. Merging now
>
> @cblmemo could you elaborate on the bug a bit for future reference?

Sure. The smoke test test_skyserve_auto_restart uses a GCP command to manually kill an instance, so it requires that the replica is running on GCP. However, in #3484 we accidentally removed the cloud: gcp field in auto_restart.yaml. I added it back ; )

anishchopra commented Jul 4, 2024

@cblmemo @Michaelvll This seems to have broken my service, which exposes the Gradio UI rather than the API itself. Here's what my service config looks like:

```yaml
envs:
  MODEL_NAME: meta-llama/Meta-Llama-3-8B-Instruct
  HF_TOKEN: ...

service:
  replica_policy:
    min_replicas: 1
    max_replicas: 3
    target_qps_per_replica: 5

  # An actual request for readiness probe.
  readiness_probe:
    initial_delay_seconds: 1800
    path: /
 
resources:
  cloud: aws
  accelerators: A10G:1
  disk_size: 512  # Ensure model checkpoints can fit.
  ports: 8001  # Expose to internet traffic.

setup: |
  conda activate vllm
  if [ $? -ne 0 ]; then
    conda create -n vllm python=3.10 -y
    conda activate vllm
  fi

  pip install vllm==0.4.2
  # Install Gradio for web UI.
  pip install gradio openai
  pip install flash-attn==2.5.9.post1
  pip install numpy==1.26.4


run: |
  conda activate vllm
  echo 'Starting vllm api server...'

  # https://github.com/vllm-project/vllm/issues/3098
  export PATH=$PATH:/sbin

  # NOTE: --gpu-memory-utilization 0.95 needed for 4-GPU nodes.
  python -u -m vllm.entrypoints.openai.api_server \
    --port 8000 \
    --model $MODEL_NAME \
    --trust-remote-code --tensor-parallel-size $SKYPILOT_NUM_GPUS_PER_NODE \
    --gpu-memory-utilization 0.95 \
    --max-num-seqs 64 \
    2>&1 | tee api_server.log &

  while ! `cat api_server.log | grep -q 'Uvicorn running on'`; do
    echo 'Waiting for vllm api server to start...'
    sleep 5
  done

  echo 'Starting gradio server...'
  git clone https://github.com/vllm-project/vllm.git || true
  python vllm/examples/gradio_openai_chatbot_webserver.py \
    -m $MODEL_NAME \
    --port 8001 \
    --host 0.0.0.0 \
    --model-url http://localhost:8000/v1 \
    --stop-token-ids 128009,128001
```

The UI looks all messed up and I just get a bunch of error messages when trying to use it. If I open the URL of the instance directly, it works fine.

cblmemo (Collaborator, Author) commented Jul 7, 2024

> @cblmemo @Michaelvll This seems to have broken my service, which exposes the Gradio UI rather than the API itself. Here's what my service config looks like: [...]
>
> The UI looks all messed up and I just get a bunch of error messages when trying to use it. If I open the URL of the instance directly, it works fine.

Hi @anishchopra! Thanks for the feedback. This is probably because the proxy sends queries to both replicas while the requests carry some session-related information, which makes them error out; a single-replica deployment would therefore work fine. Previously, a user would be redirected to one replica and all interactions were made with that redirected replica. For now, we would suggest using SkyServe to host your API endpoint and launching the Gradio server manually, pointing it at the service endpoint. If you have further suggestions or requirements, could you help file an issue for this? Thanks!

anishchopra:

@cblmemo I actually only had one replica running in this case. Your suggestion of launching a Gradio server separately does work; however, I bring up this issue because it points to something not being proxied correctly.

cblmemo (Collaborator, Author) commented Jul 13, 2024

> @cblmemo I actually only had one replica running in this case. Your suggestion of launching a Gradio server separately does work; however, I bring up this issue because it points to something not being proxied correctly.

Hmm, could you share the output of sky -v and sky -c? I actually tried on the latest master and it works well, so maybe it is a version issue. And thanks for pointing it out! I just filed issue #3749 to keep track of this :))
