diff --git a/examples/fastapi.py b/examples/fastapi.py new file mode 100644 index 0000000..71d3778 --- /dev/null +++ b/examples/fastapi.py @@ -0,0 +1,44 @@ +import random +import numpy as np +from typing import Dict, List + +from fastapi import FastAPI # type: ignore +from joblib import load + +from sklearn.datasets import fetch_20newsgroups +from sklearn.model_selection import train_test_split + +from ubatch import AsyncUBatch + + +ngd = fetch_20newsgroups(subset="all") + +X = ngd.data +y = ngd.target +_, X_test, _, _ = train_test_split(X, y, test_size=0.33) + + +model = load("xgbregressor.joblib") + + +def predict(data: List[np.array]) -> List[np.float32]: + return model.predict(np.array(data)) # type: ignore + + +predict_ubatch = AsyncUBatch[List[np.array], np.float32](max_size=100, timeout=0.01) +predict_ubatch.set_handler(predict) +predict_ubatch.start() + +app = FastAPI() + + +@app.post("/predict_ubatch") # type: ignore +async def predict_ubatch_post() -> Dict[str, float]: + output = await predict_ubatch.ubatch(random.choice(X_test)) + return {"prediction": float(output)} + + +@app.post("/predict") # type: ignore +def predict_post() -> Dict[str, float]: + output = predict([random.choice(X_test)])[0] + return {"prediction": float(output)} diff --git a/examples/flask_app.py b/examples/flask_app.py index f70fe4c..84defc8 100644 --- a/examples/flask_app.py +++ b/examples/flask_app.py @@ -6,11 +6,8 @@ from flask import Flask from flask_restx import Resource, Api -# from numpy import genfromtxt - from ubatch import ubatch_decorator -# from keras.models import load_model from sklearn.datasets import fetch_20newsgroups from sklearn.model_selection import train_test_split @@ -25,7 +22,6 @@ model = load("xgbregressor.joblib") -# X_test = genfromtxt("xgbregressor_inputs.csv", delimiter=",") app = Flask(__name__) api = Api(app) @@ -37,14 +33,14 @@ def predict(data: List[np.array]) -> List[np.float32]: @api.route("/predict_ubatch") -class BatchPredict(Resource): 
+class BatchPredict(Resource): # type: ignore def post(self) -> Dict[str, float]: - output = predict.ubatch(random.choice(X_test)) + output: np.array = predict.ubatch(random.choice(X_test)) return {"prediction": float(output)} @api.route("/predict") -class Predict(Resource): +class Predict(Resource): # type: ignore def post(self) -> Dict[str, float]: - output = predict([random.choice(X_test)])[0] + output: np.array = predict([random.choice(X_test)])[0] return {"prediction": float(output)} diff --git a/poetry.lock b/poetry.lock index 5bb873e..5eb4e50 100644 --- a/poetry.lock +++ b/poetry.lock @@ -159,7 +159,7 @@ toml = ["toml"] name = "dataclasses" version = "0.8" description = "A backport of the dataclasses module for Python 3.6" -category = "dev" +category = "main" optional = false python-versions = ">=3.6, <3.7" @@ -171,6 +171,24 @@ category = "dev" optional = false python-versions = "*" +[[package]] +name = "fastapi" +version = "0.61.2" +description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production" +category = "main" +optional = true +python-versions = ">=3.6" + +[package.dependencies] +pydantic = ">=1.0.0,<2.0.0" +starlette = "0.13.6" + +[package.extras] +all = ["requests (>=2.24.0,<3.0.0)", "aiofiles (>=0.5.0,<0.6.0)", "jinja2 (>=2.11.2,<3.0.0)", "python-multipart (>=0.0.5,<0.0.6)", "itsdangerous (>=1.1.0,<2.0.0)", "pyyaml (>=5.3.1,<6.0.0)", "graphene (>=2.1.8,<3.0.0)", "ujson (>=3.0.0,<4.0.0)", "orjson (>=3.2.1,<4.0.0)", "email_validator (>=1.1.1,<2.0.0)", "uvicorn (>=0.11.5,<0.12.0)", "async_exit_stack (>=1.0.1,<2.0.0)", "async_generator (>=1.10,<2.0.0)"] +dev = ["python-jose[cryptography] (>=3.1.0,<4.0.0)", "passlib[bcrypt] (>=1.7.2,<2.0.0)", "autoflake (>=1.3.1,<2.0.0)", "flake8 (>=3.8.3,<4.0.0)", "uvicorn (>=0.11.5,<0.12.0)", "graphene (>=2.1.8,<3.0.0)"] +doc = ["mkdocs (>=1.1.2,<2.0.0)", "mkdocs-material (>=5.5.0,<6.0.0)", "markdown-include (>=0.5.1,<0.6.0)", "mkdocs-markdownextradata-plugin (>=0.1.7,<0.2.0)", 
"typer (>=0.3.0,<0.4.0)", "typer-cli (>=0.0.9,<0.0.10)", "pyyaml (>=5.3.1,<6.0.0)"] +test = ["pytest (==5.4.3)", "pytest-cov (==2.10.0)", "pytest-asyncio (>=0.14.0,<0.15.0)", "mypy (==0.782)", "flake8 (>=3.8.3,<4.0.0)", "black (==19.10b0)", "isort (>=5.0.6,<6.0.0)", "requests (>=2.24.0,<3.0.0)", "httpx (>=0.14.0,<0.15.0)", "email_validator (>=1.1.1,<2.0.0)", "sqlalchemy (>=1.3.18,<2.0.0)", "peewee (>=3.13.3,<4.0.0)", "databases[sqlite] (>=0.3.2,<0.4.0)", "orjson (>=3.2.1,<4.0.0)", "async_exit_stack (>=1.0.1,<2.0.0)", "async_generator (>=1.10,<2.0.0)", "python-multipart (>=0.0.5,<0.0.6)", "aiofiles (>=0.5.0,<0.6.0)", "flask (>=1.1.2,<2.0.0)"] + [[package]] name = "filelock" version = "3.0.12" @@ -345,6 +363,14 @@ gevent = ["gevent (>=0.13)"] setproctitle = ["setproctitle"] tornado = ["tornado (>=0.2)"] +[[package]] +name = "h11" +version = "0.11.0" +description = "A pure-Python, bring-your-own-I/O implementation of HTTP/1.1" +category = "main" +optional = true +python-versions = "*" + [[package]] name = "h5py" version = "2.10.0" @@ -700,6 +726,22 @@ category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +[[package]] +name = "pydantic" +version = "1.7.2" +description = "Data validation and settings management using python 3.6 type hinting" +category = "main" +optional = true +python-versions = ">=3.6" + +[package.dependencies] +dataclasses = {version = ">=0.6", markers = "python_version < \"3.7\""} + +[package.extras] +dotenv = ["python-dotenv (>=0.10.4)"] +email = ["email-validator (>=1.0.3)"] +typing_extensions = ["typing-extensions (>=3.7.2)"] + [[package]] name = "pyflakes" version = "2.2.0" @@ -930,6 +972,17 @@ category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +[[package]] +name = "starlette" +version = "0.13.6" +description = "The little ASGI library that shines." 
+category = "main" +optional = true +python-versions = ">=3.6" + +[package.extras] +full = ["aiofiles", "graphene", "itsdangerous", "jinja2", "python-multipart", "pyyaml", "requests", "ujson"] + [[package]] name = "stevedore" version = "3.2.2" @@ -1040,7 +1093,7 @@ python-versions = "*" name = "typing-extensions" version = "3.7.4.3" description = "Backported and Experimental Type Hints for Python 3.5+" -category = "dev" +category = "main" optional = false python-versions = "*" @@ -1057,6 +1110,22 @@ brotli = ["brotlipy (>=0.6.0)"] secure = ["pyOpenSSL (>=0.14)", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "certifi", "ipaddress"] socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"] +[[package]] +name = "uvicorn" +version = "0.12.2" +description = "The lightning-fast ASGI server." +category = "main" +optional = true +python-versions = "*" + +[package.dependencies] +click = ">=7.0.0,<8.0.0" +h11 = ">=0.8" +typing-extensions = {version = "*", markers = "python_version < \"3.8\""} + +[package.extras] +standard = ["websockets (>=8.0.0,<9.0.0)", "watchgod (>=0.6,<0.7)", "python-dotenv (>=0.13)", "PyYAML (>=5.1)", "httptools (>=0.1.0,<0.2.0)", "uvloop (>=0.14.0)", "colorama (>=0.4)"] + [[package]] name = "virtualenv" version = "20.1.0" @@ -1129,12 +1198,12 @@ docs = ["sphinx", "jaraco.packaging (>=3.2)", "rst.linker (>=1.9)"] testing = ["pytest (>=3.5,!=3.7.3)", "pytest-checkdocs (>=1.2.3)", "pytest-flake8", "pytest-cov", "jaraco.test (>=3.2.0)", "jaraco.itertools", "func-timeout", "pytest-black (>=0.3.7)", "pytest-mypy"] [extras] -benchmark = ["flask-restx", "scikit-learn", "gunicorn", "xgboost", "keras", "tensorflow"] +benchmark = ["flask-restx", "scikit-learn", "gunicorn", "xgboost", "keras", "tensorflow", "fastapi", "uvicorn"] [metadata] lock-version = "1.1" python-versions = ">= 3.6, < 3.9" -content-hash = "3d924689c0590ade71c1838bbc3beffe480523f4e96a7f8b9dc63fd80a072aae" +content-hash = "d683d622a9efe796a02f627f982221468aa8abd6949bf529a5ce2f815688d40d" [metadata.files] absl-py 
= [ @@ -1236,6 +1305,10 @@ distlib = [ {file = "distlib-0.3.1-py2.py3-none-any.whl", hash = "sha256:8c09de2c67b3e7deef7184574fc060ab8a793e7adbb183d942c389c8b13c52fb"}, {file = "distlib-0.3.1.zip", hash = "sha256:edf6116872c863e1aa9d5bb7cb5e05a022c519a4594dc703843343a9ddd9bff1"}, ] +fastapi = [ + {file = "fastapi-0.61.2-py3-none-any.whl", hash = "sha256:8c8517680a221e69eb34073adf46c503092db2f24845b7bdc7f85b54f24ff0df"}, + {file = "fastapi-0.61.2.tar.gz", hash = "sha256:9e0494fcbba98f85b8cc9b2606bb6b625246e1b12f79ca61f508b0b00843eca6"}, +] filelock = [ {file = "filelock-3.0.12-py3-none-any.whl", hash = "sha256:929b7d63ec5b7d6b71b0fa5ac14e030b3f70b75747cef1b10da9b879fef15836"}, {file = "filelock-3.0.12.tar.gz", hash = "sha256:18d82244ee114f543149c66a6e0c14e9c4f8a1044b5cdaadd0f82159d6a6ff59"}, @@ -1333,6 +1406,10 @@ gunicorn = [ {file = "gunicorn-20.0.4-py2.py3-none-any.whl", hash = "sha256:cd4a810dd51bf497552cf3f863b575dabd73d6ad6a91075b65936b151cbf4f9c"}, {file = "gunicorn-20.0.4.tar.gz", hash = "sha256:1904bb2b8a43658807108d59c3f3d56c2b6121a701161de0ddf9ad140073c626"}, ] +h11 = [ + {file = "h11-0.11.0-py2.py3-none-any.whl", hash = "sha256:ab6c335e1b6ef34b205d5ca3e228c9299cc7218b049819ec84a388c2525e5d87"}, + {file = "h11-0.11.0.tar.gz", hash = "sha256:3c6c61d69c6f13d41f1b80ab0322f1872702a3ba26e12aa864c928f6a43fbaab"}, +] h5py = [ {file = "h5py-2.10.0-cp27-cp27m-macosx_10_6_intel.whl", hash = "sha256:ecf4d0b56ee394a0984de15bceeb97cbe1fe485f1ac205121293fc44dcf3f31f"}, {file = "h5py-2.10.0-cp27-cp27m-manylinux1_i686.whl", hash = "sha256:86868dc07b9cc8cb7627372a2e6636cdc7a53b7e2854ad020c9e9d8a4d3fd0f5"}, @@ -1584,6 +1661,30 @@ pycodestyle = [ {file = "pycodestyle-2.6.0-py2.py3-none-any.whl", hash = "sha256:2295e7b2f6b5bd100585ebcb1f616591b652db8a741695b3d8f5d28bdc934367"}, {file = "pycodestyle-2.6.0.tar.gz", hash = "sha256:c58a7d2815e0e8d7972bf1803331fb0152f867bd89adf8a01dfd55085434192e"}, ] +pydantic = [ + {file = "pydantic-1.7.2-cp36-cp36m-macosx_10_9_x86_64.whl", hash 
= "sha256:dfaa6ed1d509b5aef4142084206584280bb6e9014f01df931ec6febdad5b200a"}, + {file = "pydantic-1.7.2-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:2182ba2a9290964b278bcc07a8d24207de709125d520efec9ad6fa6f92ee058d"}, + {file = "pydantic-1.7.2-cp36-cp36m-manylinux2014_i686.whl", hash = "sha256:0fe8b45d31ae53d74a6aa0bf801587bd49970070eac6a6326f9fa2a302703b8a"}, + {file = "pydantic-1.7.2-cp36-cp36m-manylinux2014_x86_64.whl", hash = "sha256:01f0291f4951580f320f7ae3f2ecaf0044cdebcc9b45c5f882a7e84453362420"}, + {file = "pydantic-1.7.2-cp36-cp36m-win_amd64.whl", hash = "sha256:4ba6b903e1b7bd3eb5df0e78d7364b7e831ed8b4cd781ebc3c4f1077fbcb72a4"}, + {file = "pydantic-1.7.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b11fc9530bf0698c8014b2bdb3bbc50243e82a7fa2577c8cfba660bcc819e768"}, + {file = "pydantic-1.7.2-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:a3c274c49930dc047a75ecc865e435f3df89715c775db75ddb0186804d9b04d0"}, + {file = "pydantic-1.7.2-cp37-cp37m-manylinux2014_i686.whl", hash = "sha256:c68b5edf4da53c98bb1ccb556ae8f655575cb2e676aef066c12b08c724a3f1a1"}, + {file = "pydantic-1.7.2-cp37-cp37m-manylinux2014_x86_64.whl", hash = "sha256:95d4410c4e429480c736bba0db6cce5aaa311304aea685ebcf9ee47571bfd7c8"}, + {file = "pydantic-1.7.2-cp37-cp37m-win_amd64.whl", hash = "sha256:a2fc7bf77ed4a7a961d7684afe177ff59971828141e608f142e4af858e07dddc"}, + {file = "pydantic-1.7.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:b9572c0db13c8658b4a4cb705dcaae6983aeb9842248b36761b3fbc9010b740f"}, + {file = "pydantic-1.7.2-cp38-cp38-manylinux1_i686.whl", hash = "sha256:f83f679e727742b0c465e7ef992d6da4a7e5268b8edd8fdaf5303276374bef52"}, + {file = "pydantic-1.7.2-cp38-cp38-manylinux2014_i686.whl", hash = "sha256:e5fece30e80087d9b7986104e2ac150647ec1658c4789c89893b03b100ca3164"}, + {file = "pydantic-1.7.2-cp38-cp38-manylinux2014_x86_64.whl", hash = "sha256:ce2d452961352ba229fe1e0b925b41c0c37128f08dddb788d0fd73fd87ea0f66"}, + {file = "pydantic-1.7.2-cp38-cp38-win_amd64.whl", hash = 
"sha256:fc21a37ff3f545de80b166e1735c4172b41b017948a3fb2d5e2f03c219eac50a"}, + {file = "pydantic-1.7.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:c9760d1556ec59ff745f88269a8f357e2b7afc75c556b3a87b8dda5bc62da8ba"}, + {file = "pydantic-1.7.2-cp39-cp39-manylinux1_i686.whl", hash = "sha256:2c1673633ad1eea78b1c5c420a47cd48717d2ef214c8230d96ca2591e9e00958"}, + {file = "pydantic-1.7.2-cp39-cp39-manylinux2014_i686.whl", hash = "sha256:388c0c26c574ff49bad7d0fd6ed82fbccd86a0473fa3900397d3354c533d6ebb"}, + {file = "pydantic-1.7.2-cp39-cp39-manylinux2014_x86_64.whl", hash = "sha256:ab1d5e4d8de00575957e1c982b951bffaedd3204ddd24694e3baca3332e53a23"}, + {file = "pydantic-1.7.2-cp39-cp39-win_amd64.whl", hash = "sha256:f045cf7afb3352a03bc6cb993578a34560ac24c5d004fa33c76efec6ada1361a"}, + {file = "pydantic-1.7.2-py3-none-any.whl", hash = "sha256:6665f7ab7fbbf4d3c1040925ff4d42d7549a8c15fe041164adfe4fc2134d4cce"}, + {file = "pydantic-1.7.2.tar.gz", hash = "sha256:c8200aecbd1fb914e1bd061d71a4d1d79ecb553165296af0c14989b89e90d09b"}, +] pyflakes = [ {file = "pyflakes-2.2.0-py2.py3-none-any.whl", hash = "sha256:0d94e0e05a19e57a99444b6ddcf9a6eb2e5c68d3ca1e98e90707af8152c90a92"}, {file = "pyflakes-2.2.0.tar.gz", hash = "sha256:35b2d75ee967ea93b55750aa9edbbf72813e06a66ba54438df2cfac9e3c27fc8"}, @@ -1749,6 +1850,10 @@ smmap = [ {file = "smmap-3.0.4-py2.py3-none-any.whl", hash = "sha256:54c44c197c819d5ef1991799a7e30b662d1e520f2ac75c9efbeb54a742214cf4"}, {file = "smmap-3.0.4.tar.gz", hash = "sha256:9c98bbd1f9786d22f14b3d4126894d56befb835ec90cef151af566c7e19b5d24"}, ] +starlette = [ + {file = "starlette-0.13.6-py3-none-any.whl", hash = "sha256:bd2ffe5e37fb75d014728511f8e68ebf2c80b0fa3d04ca1479f4dc752ae31ac9"}, + {file = "starlette-0.13.6.tar.gz", hash = "sha256:ebe8ee08d9be96a3c9f31b2cb2a24dbdf845247b745664bd8a3f9bd0c977fdbc"}, +] stevedore = [ {file = "stevedore-3.2.2-py3-none-any.whl", hash = "sha256:5e1ab03eaae06ef6ce23859402de785f08d97780ed774948ef16c4652c41bc62"}, {file = 
"stevedore-3.2.2.tar.gz", hash = "sha256:f845868b3a3a77a2489d226568abe7328b5c2d4f6a011cc759dfa99144a521f0"}, @@ -1819,6 +1924,10 @@ urllib3 = [ {file = "urllib3-1.26.2-py2.py3-none-any.whl", hash = "sha256:d8ff90d979214d7b4f8ce956e80f4028fc6860e4431f731ea4a8c08f23f99473"}, {file = "urllib3-1.26.2.tar.gz", hash = "sha256:19188f96923873c92ccb987120ec4acaa12f0461fa9ce5d3d0772bc965a39e08"}, ] +uvicorn = [ + {file = "uvicorn-0.12.2-py3-none-any.whl", hash = "sha256:e5dbed4a8a44c7b04376021021d63798d6a7bcfae9c654a0b153577b93854fba"}, + {file = "uvicorn-0.12.2.tar.gz", hash = "sha256:8ff7495c74b8286a341526ff9efa3988ebab9a4b2f561c7438c3cb420992d7dd"}, +] virtualenv = [ {file = "virtualenv-20.1.0-py2.py3-none-any.whl", hash = "sha256:b0011228208944ce71052987437d3843e05690b2f23d1c7da4263fde104c97a2"}, {file = "virtualenv-20.1.0.tar.gz", hash = "sha256:b8d6110f493af256a40d65e29846c69340a947669eec8ce784fcf3dd3af28380"}, diff --git a/pyproject.toml b/pyproject.toml index 351020b..6ecd040 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,8 @@ readme = "README.md" [tool.poetry.dependencies] python = ">= 3.6, < 3.9" flask-restx = { version = "^0.2", optional = true } +fastapi = { version = "^0.61", optional = true } +uvicorn = { version = "^0.12", optional = true } scikit-learn = { version = "^0.23", optional = true } gunicorn = { version = "^20.0", optional = true } xgboost = { version = "^1.2", optional = true } @@ -50,7 +52,7 @@ pytest-timeout = "^1.4" pytest-reraise = "^1.0" [tool.poetry.extras] -benchmark = ["flask-restx", "scikit-learn", "gunicorn", "xgboost", "keras", "tensorflow"] +benchmark = ["flask-restx", "scikit-learn", "gunicorn", "xgboost", "keras", "tensorflow", "fastapi", "uvicorn"] [tool.black] line-length = 88 diff --git a/ubatch/__init__.py b/ubatch/__init__.py index 575193e..56886e2 100644 --- a/ubatch/__init__.py +++ b/ubatch/__init__.py @@ -1,7 +1,10 @@ +from ubatch.async_ubatch import AsyncUBatch from ubatch.decorators import ubatch_decorator 
-from ubatch.ubatch import BadBatchOutputSize, HandlerAlreadySet, HandlerNotSet, UBatch +from ubatch.exceptions import BadBatchOutputSize, HandlerAlreadySet, HandlerNotSet +from ubatch.ubatch import UBatch __all__ = [ + "AsyncUBatch", "UBatch", "HandlerNotSet", "HandlerAlreadySet", diff --git a/ubatch/async_ubatch.py b/ubatch/async_ubatch.py new file mode 100644 index 0000000..500b77d --- /dev/null +++ b/ubatch/async_ubatch.py @@ -0,0 +1,143 @@ +import asyncio +import concurrent.futures +import logging +import time +from collections import deque +from functools import partial +from typing import Callable, Deque, Generic, List, Optional + +from ubatch.data_request import DataRequest, DataRequestBuffer, S, T +from ubatch.exceptions import BadBatchOutputSize, HandlerAlreadySet, HandlerNotSet + +logger = logging.getLogger(__name__) + + +CHECK_INTERVAL = 0.001 # Time to wait (in seconds) if queue is empty. +MONITOR_INTERVAL = 5 # Time to wait (in seconds) for logging statistics. + + +class AsyncUBatch(Generic[T, S]): + def __init__(self, max_size: int, timeout: float): + """Join multiple individual inputs into one batch of inputs. + + Args: + handler: User function that handle batches. + max_size: Maximum size of inputs to pass to the handler. + timeout: Maximum time (in seconds) to wait for inputs before + starting to process them. + """ + + self.max_size = max_size # Maximum size of handler inputs. + self.timeout = timeout # Maximum time (in seconds) of inputs to wait. + + self._handler: Optional[Callable[[List[T]], List[S]]] = None + # TODO: let select users if run in thread or process, + # some c libs release GIL, what about corrutine? 
+ self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) + self._queue: Deque[DataRequest[T, S]] = deque() + self.pending: int = 0 # Pending batch to being processed + + async def _monitor(self) -> None: + while True: + logging.info( + "queue size: %s, pending batch: %s", len(self._queue), self.pending + ) + await asyncio.sleep(MONITOR_INTERVAL) + + def set_handler(self, handler: Callable[[List[T]], List[S]]) -> None: + """Set user function to handle inputs data + + Args: + handler: Any callable to handle input data and return output data + """ + if self._handler: + raise HandlerAlreadySet() + + self._handler = handler + + async def _run_in_executor(self, buffer: DataRequestBuffer[T, S]) -> None: + if not self._handler: + raise HandlerNotSet + + loop = asyncio.get_event_loop() + + logging.debug("process send to pool with buffer: %s", buffer) + data = [x.data for x in buffer] + + try: + self.pending += 1 + # TODO: python 3.9 asyncio.to_thread + outputs = await loop.run_in_executor( + self._executor, partial(self._handler, data) + ) + self.pending -= 1 + + if len(outputs) != len(data): + # This exception is going to be set in every DataRequest + raise BadBatchOutputSize(len(data), len(outputs)) + except Exception as ex: + for dr in buffer: + dr.exception = ex + else: + for dr, o in zip(buffer, outputs): + dr.output = o + + logging.debug("end pool with buffer: %s", buffer) + + async def _process_queue(self) -> None: + loop = asyncio.get_event_loop() + + while True: + buffer = DataRequestBuffer[T, S](size=self.max_size) + + # Wait for at least 1 item is in buffer + while len(buffer) < 1: + try: + buffer.append(self._queue.pop()) + except IndexError: + await asyncio.sleep(CHECK_INTERVAL) + + _timeout = time.time() + self.timeout + _timeouted = False + + while not (buffer.full() or _timeouted): + _timeouted = time.time() > _timeout + try: + buffer.append(self._queue.pop()) + except IndexError: + await asyncio.sleep(CHECK_INTERVAL) + + # If thread/process 
is busy keep adding elements to buffer + while not buffer.full() and self.pending != 0: + try: + buffer.append(self._queue.pop()) + except IndexError: + await asyncio.sleep(CHECK_INTERVAL) + + if buffer: + logger.debug("processing (len): %s", len(buffer)) + loop.create_task(self._run_in_executor(buffer)) + + async def ubatch(self, data: T) -> S: + # Async UBatch do not use DataRequest timeout + data_request = DataRequest[T, S](data=data, timeout=0) + + self._queue.append(data_request) + + while not data_request.ready: + await asyncio.sleep(CHECK_INTERVAL) + + logger.debug("Request ready: total time: %s", data_request.latency) + + return data_request.output + + def start(self) -> "AsyncUBatch[T, S]": # pragma: no cover + if not self._handler: + raise HandlerNotSet() + + loop = asyncio.get_event_loop() + + loop.create_task(self._monitor()) + loop.create_task(self._process_queue()) + + return self diff --git a/ubatch/exceptions.py b/ubatch/exceptions.py new file mode 100644 index 0000000..801d713 --- /dev/null +++ b/ubatch/exceptions.py @@ -0,0 +1,22 @@ +class BadBatchOutputSize(Exception): + def __init__(self, input_size: int, output_size: int): + """Raised when output size of handler differs from input size + + Args: + input_size: Size of input + output_size: Size of output + """ + self.input_size = input_size + self.output_size = output_size + self.message = ( + f"Output size: {output_size} differs from the input size: {input_size}" + ) + super().__init__(self.message) + + +class HandlerNotSet(Exception): + """Raised when not handler is set in MicroBatch""" + + +class HandlerAlreadySet(Exception): + """Raised when trying to change handler""" diff --git a/ubatch/ubatch.py b/ubatch/ubatch.py index 933582e..adbce94 100644 --- a/ubatch/ubatch.py +++ b/ubatch/ubatch.py @@ -5,34 +5,11 @@ from typing import Callable, Generic, List, Optional from ubatch.data_request import DataRequest, DataRequestBuffer, S, T +from ubatch.exceptions import BadBatchOutputSize, 
HandlerAlreadySet, HandlerNotSet logger = logging.getLogger(__name__) -class BadBatchOutputSize(Exception): - def __init__(self, input_size: int, output_size: int): - """Raised when output size of handler differs from input size - - Args: - input_size: Size of input - output_size: Size of output - """ - self.input_size = input_size - self.output_size = output_size - self.message = ( - f"Output size: {output_size} differs from the input size: {input_size}" - ) - super().__init__(self.message) - - -class HandlerNotSet(Exception): - """Raised when not handler is set in UBatch""" - - -class HandlerAlreadySet(Exception): - """Raised when trying to change handler""" - - class UBatch(Generic[T, S]): CHECK_INTERVAL = 0.002 # Time to wait (in seconds) if queue is empty.