-
Notifications
You must be signed in to change notification settings - Fork 994
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #793 from randydl/feat-multi-gpu
Add multi_gpu process project
- Loading branch information
Showing
6 changed files
with
161 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
## 项目简介 | ||
本项目提供基于 LitServe 的多 GPU 并行处理方案。LitServe 是一个简便且灵活的 AI 模型服务引擎,基于 FastAPI 构建。它为 FastAPI 增强了批处理、流式传输和 GPU 自动扩展等功能,无需为每个模型单独重建 FastAPI 服务器。 | ||
|
||
## 环境配置 | ||
请使用以下命令配置所需的环境: | ||
```bash | ||
pip install -U litserve python-multipart filetype | ||
pip install -U magic-pdf[full] --extra-index-url https://wheels.myhloli.com | ||
pip install paddlepaddle-gpu==3.0.0b1 -i https://www.paddlepaddle.org.cn/packages/stable/cu118 | ||
``` | ||
|
||
## 快速使用 | ||
### 1. 启动服务端 | ||
以下示例展示了如何启动服务端,支持自定义设置: | ||
```python | ||
server = ls.LitServer( | ||
MinerUAPI(output_dir='/tmp'), # 可自定义输出文件夹 | ||
accelerator='cuda', # 启用 GPU 加速 | ||
devices='auto', # "auto" 使用所有 GPU | ||
workers_per_device=1, # 每个 GPU 启动一个服务实例 | ||
timeout=False # 设置为 False 以禁用超时 | ||
) | ||
server.run(port=8000) # 设定服务端口为 8000 | ||
``` | ||
|
||
启动服务端命令: | ||
```bash | ||
python server.py | ||
``` | ||
|
||
### 2. 启动客户端 | ||
以下代码展示了客户端的使用方式,可根据需求修改配置: | ||
```python | ||
files = ['demo/small_ocr.pdf'] # 替换为文件路径,支持 jpg/jpeg、png、pdf 文件 | ||
n_jobs = np.clip(len(files), 1, 8) # 设置并发线程数,此处最大为 8,可根据自身修改 | ||
results = Parallel(n_jobs, prefer='threads', verbose=10)( | ||
delayed(do_parse)(p) for p in files | ||
) | ||
print(results) | ||
``` | ||
|
||
启动客户端命令: | ||
```bash | ||
python client.py | ||
``` | ||
好了,你的文件会自动在多个 GPU 上并行处理!🍻🍻🍻 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
import base64 | ||
import requests | ||
import numpy as np | ||
from loguru import logger | ||
from joblib import Parallel, delayed | ||
|
||
|
||
def to_b64(file_path): | ||
try: | ||
with open(file_path, 'rb') as f: | ||
return base64.b64encode(f.read()).decode('utf-8') | ||
except Exception as e: | ||
raise Exception(f'File: {file_path} - Info: {e}') | ||
|
||
|
||
def do_parse(file_path, url='http://127.0.0.1:8000/predict', **kwargs): | ||
try: | ||
response = requests.post(url, json={ | ||
'file': to_b64(file_path), | ||
'kwargs': kwargs | ||
}) | ||
|
||
if response.status_code == 200: | ||
output = response.json() | ||
output['file_path'] = file_path | ||
return output | ||
else: | ||
raise Exception(response.text) | ||
except Exception as e: | ||
logger.error(f'File: {file_path} - Info: {e}') | ||
|
||
|
||
if __name__ == '__main__': | ||
files = ['small_ocr.pdf'] | ||
n_jobs = np.clip(len(files), 1, 8) | ||
results = Parallel(n_jobs, prefer='threads', verbose=10)( | ||
delayed(do_parse)(p) for p in files | ||
) | ||
print(results) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
import os | ||
import fitz | ||
import torch | ||
import base64 | ||
import litserve as ls | ||
from uuid import uuid4 | ||
from fastapi import HTTPException | ||
from filetype import guess_extension | ||
from magic_pdf.tools.common import do_parse | ||
from magic_pdf.model.doc_analyze_by_custom_model import ModelSingleton | ||
|
||
|
||
class MinerUAPI(ls.LitAPI): | ||
def __init__(self, output_dir='/tmp'): | ||
self.output_dir = output_dir | ||
|
||
def setup(self, device): | ||
if device.startswith('cuda'): | ||
os.environ['CUDA_VISIBLE_DEVICES'] = device.split(':')[-1] | ||
if torch.cuda.device_count() > 1: | ||
raise RuntimeError("Remove any CUDA actions before setting 'CUDA_VISIBLE_DEVICES'.") | ||
|
||
model_manager = ModelSingleton() | ||
model_manager.get_model(True, False) | ||
model_manager.get_model(False, False) | ||
print(f'Model initialization complete on {device}!') | ||
|
||
def decode_request(self, request): | ||
file = request['file'] | ||
file = self.to_pdf(file) | ||
opts = request.get('kwargs', {}) | ||
opts.setdefault('debug_able', False) | ||
opts.setdefault('parse_method', 'auto') | ||
return file, opts | ||
|
||
def predict(self, inputs): | ||
try: | ||
do_parse(self.output_dir, pdf_name := str(uuid4()), inputs[0], [], **inputs[1]) | ||
return pdf_name | ||
except Exception as e: | ||
raise HTTPException(status_code=500, detail=str(e)) | ||
finally: | ||
self.clean_memory() | ||
|
||
def encode_response(self, response): | ||
return {'output_dir': response} | ||
|
||
def clean_memory(self): | ||
import gc | ||
if torch.cuda.is_available(): | ||
torch.cuda.empty_cache() | ||
torch.cuda.ipc_collect() | ||
gc.collect() | ||
|
||
def to_pdf(self, file_base64): | ||
try: | ||
file_bytes = base64.b64decode(file_base64) | ||
file_ext = guess_extension(file_bytes) | ||
with fitz.open(stream=file_bytes, filetype=file_ext) as f: | ||
if f.is_pdf: return f.tobytes() | ||
return f.convert_to_pdf() | ||
except Exception as e: | ||
raise HTTPException(status_code=500, detail=str(e)) | ||
|
||
|
||
if __name__ == '__main__': | ||
server = ls.LitServer( | ||
MinerUAPI(output_dir='/tmp'), | ||
accelerator='cuda', | ||
devices='auto', | ||
workers_per_device=1, | ||
timeout=False | ||
) | ||
server.run(port=8000) |
Binary file not shown.